import arviz as az
import numpy as np
import xarray as xr
import pandas as pd
import pymc as pm
from pymc.math import dot, invlogit
import seaborn as sns

15. GLM Examples*#

Arrhythmia#

A logistic regression example, adapted from Unit 7: arrhythmia.odc.

Data descriptions#

I mirrored the data here.

| Variable ID | Name | Description |
|---|---|---|
| Y | Fibrillation | Outcome variable: presence of fibrillation. |
| X1 | Age | Age of the patient. |
| X2 | Aortic Cross Clamp Time | Duration of time the aortic valve is clamped during surgery. |
| X3 | Cardiopulmonary Bypass Time | Bypass of the heart and lungs. Involves diverting blood through a heart-lung machine, which performs the functions of the heart and lungs. |
| X4 | ICU Time | Time spent in the Intensive Care Unit. |
| X5 | Avg Heart Rate | Average heart rate of the patient. |
| X6 | Left Ventricle Ejection Fraction | Measure of how well the left ventricle pumps blood out to the body. |
| X7 | Hypertension | Binary: Presence (1) or absence (0) of high blood pressure. |
| X8 | Gender | Binary: 1 for female; 0 for male. |
| X9 | Diabetes | Binary: Presence (1) or absence (0) of diabetes. |
| X10 | Previous MI | Binary: Presence (1) or absence (0) of a previous myocardial infarction (heart attack). |

Background#

Patients who undergo Coronary Artery Bypass Graft Surgery (CABG) have an approximate 19–40% chance of developing atrial fibrillation (AF). AF can lead to the formation of blood clots, resulting in increased in-hospital mortality, strokes, and longer hospital stays. While drugs can prevent this condition, they are expensive and can be dangerous if not warranted. Ideally, identifying several risk factors that indicate an increased risk of developing AF could save lives and money by showing which patients need pharmacological intervention. Researchers have begun collecting data such as demographics, heart rate, cholesterol, and operation time from CABG patients during their hospital stays. They have also recorded which patients developed AF. The goal now is to identify the data points that signal a high risk of AF. In the past, factors such as age, hypertension, and body surface area (BSA) have been useful indicators, although they have not provided a satisfactory solution on their own.

Fibrillation occurs when the heart muscle begins a quivering motion instead of maintaining a normal, healthy pumping rhythm. Fibrillation can affect either the atrium (atrial fibrillation) or the ventricle (ventricular fibrillation); the latter is imminently life-threatening.

Atrial fibrillation involves quivering, chaotic motion in the upper chambers of the heart, known as the atria. It is often linked to serious underlying medical conditions and should be evaluated by a physician. Although it is not typically a medical emergency, it still requires medical attention.

Ventricular fibrillation occurs in the ventricles (lower chambers) of the heart and is always a medical emergency. If left untreated, ventricular fibrillation (VF, or V-fib) can lead to death within minutes. When the heart enters V-fib, effective blood pumping ceases. V-fib is considered a form of cardiac arrest, and an individual experiencing it will not survive unless immediate cardiopulmonary resuscitation (CPR) and defibrillation are administered.

Model#

This is a logistic regression model. We consider each patient’s outcome a single Bernoulli event.

\[\begin{split} \begin{align*} y_i | \boldsymbol{\beta}, \mathbf{x}_i &\sim \text{Bernoulli}(p_i) \\ g(p_i) &= \beta_0 + \sum_{j=1}^p \beta_j x_{ij} \\ \beta_0 &\sim \mathcal{N}(0, \sigma_0^2) && \text{Intercept variance may be higher than coefficient variance.}\\ \beta_j &\sim \mathcal{N}(0, \sigma^2), \quad j = 1, \ldots, p \end{align*} \end{split}\]

where \(g(\cdot)\) is the logit function: \(\text{logit}(p) = \ln\left(\frac{p}{1-p}\right)\), and its inverse \(g^{-1}(\cdot)\) is the logistic function: \(\text{logistic}(x) = \frac{1}{1 + e^{-x}}\).
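As a quick numerical illustration of these two definitions, here is a minimal NumPy sketch. The helper names `logit` and `logistic` are chosen for this example only (the model code on this page uses `invlogit` from `pymc.math` instead). It maps a few probabilities to log-odds and back.

```python
import numpy as np


def logit(p):
    """Log-odds of a probability p in (0, 1)."""
    p = np.asarray(p, dtype=float)
    return np.log(p / (1.0 - p))


def logistic(x):
    """Inverse of the logit: maps any real number to (0, 1)."""
    x = np.asarray(x, dtype=float)
    return 1.0 / (1.0 + np.exp(-x))


probs = np.array([0.1, 0.5, 0.9])
log_odds = logit(probs)          # 0.5 maps to a log-odds of 0
recovered = logistic(log_odds)   # round trip back to the probabilities

print(log_odds)
print(recovered)
```

The round trip returns the original probabilities, which is all that "inverse link" means here: the linear predictor lives on the log-odds scale and the logistic function brings it back to a probability.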

data_df = pd.read_csv("../data/arrhythmia.csv")
data_df.info()
X = data_df.iloc[:, 1:]
y = data_df["Fibrillation"]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 81 entries, 0 to 80
Data columns (total 11 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   Fibrillation                   81 non-null     float64
 1   Age                            81 non-null     float64
 2   AorticCrossClampTime           81 non-null     float64
 3   CardiopulmonaryBypassTime      81 non-null     float64
 4   ICUTime                        81 non-null     float64
 5   AvgHeartRate                   81 non-null     float64
 6   LeftVentricleEjectionFraction  81 non-null     float64
 7   Hypertension                   81 non-null     float64
 8   Gender                         81 non-null     float64
 9   Diabetes                       81 non-null     float64
 10  PreviousMI                     81 non-null     float64
dtypes: float64(11)
memory usage: 7.1 KB
data_df.describe()
Fibrillation Age AorticCrossClampTime CardiopulmonaryBypassTime ICUTime AvgHeartRate LeftVentricleEjectionFraction Hypertension Gender Diabetes PreviousMI
count 81.000000 81.000000 81.000000 81.000000 81.000000 81.000000 81.000000 81.000000 81.000000 81.000000 81.000000
mean 0.345679 66.654321 81.753086 131.123457 16.148716 85.683951 56.401235 0.666667 0.308642 0.419753 0.469136
std 0.478552 10.429718 30.322241 56.196170 3.672736 11.847557 13.634153 0.474342 0.464811 0.496593 0.502156
min 0.000000 44.000000 0.000000 0.000000 2.000000 50.000000 18.000000 0.000000 0.000000 0.000000 0.000000
25% 0.000000 61.000000 67.000000 109.000000 13.500000 77.200000 50.000000 0.000000 0.000000 0.000000 0.000000
50% 0.000000 69.000000 82.000000 128.000000 16.000000 86.700000 59.000000 1.000000 0.000000 0.000000 0.000000
75% 1.000000 73.000000 98.000000 148.000000 19.000000 94.800000 65.000000 1.000000 1.000000 1.000000 1.000000
max 1.000000 88.000000 193.000000 487.000000 23.000000 111.800000 82.000000 1.000000 1.000000 1.000000 1.000000

Our predictors have very different scales. With the non-informative priors the professor uses in the BUGS model, the coefficients should in principle have no trouble fitting the data. However, PyMC uses a different sampling algorithm (NUTS, as the sampler log below shows), and it seems to have trouble with the shape of the posterior. This used to work fine in PyMC (as of version 5.1.2, at least), but students in Fall 2023 discovered that PyMC could no longer sample this model without divergences (using version 5.9.0 or above).
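One way to get a feel for the scale problem is to look at what a single prior draw does on the raw scale. The sketch below is only an illustration under stated assumptions, not a diagnosis of the sampler: it draws one coefficient from a Normal(0, 5) prior (the `sigma=5` used in the model below), multiplies it by a raw bypass time of about 131 minutes (the mean from `describe()`), and compares that with a value of 0.5, which is where a point one standard deviation above the mean lands after two-standard-deviation scaling. The `1e-6` saturation threshold is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 10_000

# Prior draws for one coefficient, sigma=5 as in the model
beta = rng.normal(0.0, 5.0, size=n_draws)

x_raw = 131.0  # roughly the mean CardiopulmonaryBypassTime
x_scaled = 0.5  # one SD above the mean after dividing by 2 SDs


def invlogit(z):
    """Numerically stable logistic function (avoids overflow in exp)."""
    return 0.5 * (1.0 + np.tanh(z / 2.0))


def fraction_saturated(p, eps=1e-6):
    """Share of probabilities that are numerically pinned at 0 or 1."""
    return np.mean((p < eps) | (p > 1.0 - eps))


frac_saturated_raw = fraction_saturated(invlogit(beta * x_raw))
frac_saturated_std = fraction_saturated(invlogit(beta * x_scaled))

print(f"raw scale:    {frac_saturated_raw:.3f}")
print(f"scaled input: {frac_saturated_std:.3f}")
```

On the raw scale almost every prior draw pushes the probability to 0 or 1, so tiny changes in a coefficient swing the likelihood dramatically; on the scaled input this essentially never happens. That is at least consistent with the sampler struggling on the unscaled model.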

with pm.Model() as m:
    X_data = pm.Data("X_data", X, mutable=True)
    y_data = pm.Data("y_data", y, mutable=False)

    alpha = pm.Normal("alpha", mu=0, sigma=10)
    betas = pm.Normal("beta", mu=0, sigma=5, shape=X.shape[1])

    p = invlogit(alpha + dot(X_data, betas))

    pm.Bernoulli("y", p=p, observed=y_data)

    trace = pm.sample(5000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]
100.00% [24000/24000 00:14<00:00 Sampling 4 chains, 15,000 divergences]
Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 14 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
There were 15000 divergences after tuning. Increase `target_accept` or reparameterize.
az.summary(trace, hdi_prob=0.95)
mean sd hdi_2.5% hdi_97.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha -2.562 5.195 -14.336 0.710 2.383 1.801 4.0 26.0 3.64
beta[0] -0.108 0.460 -0.602 0.499 0.230 0.176 4.0 4.0 7.19
beta[1] 0.065 0.488 -0.374 0.870 0.244 0.187 4.0 11.0 4.32
beta[2] 0.213 0.409 -0.277 0.812 0.204 0.157 4.0 4.0 8.46
beta[3] 0.098 0.578 -0.706 0.790 0.288 0.220 4.0 12.0 5.32
beta[4] -0.299 0.256 -0.689 0.023 0.127 0.098 4.0 4.0 6.35
beta[5] 0.396 0.325 -0.003 0.890 0.162 0.124 4.0 28.0 6.60
beta[6] -0.039 0.577 -1.239 0.701 0.242 0.181 6.0 26.0 2.32
beta[7] -0.366 0.414 -1.124 0.559 0.126 0.092 10.0 26.0 1.97
beta[8] 0.668 0.624 -0.213 1.748 0.268 0.201 6.0 4.0 2.11
beta[9] 0.051 0.625 -0.566 1.087 0.270 0.203 6.0 26.0 2.10
m.to_graphviz()
[Figure: Graphviz diagram of the model graph.]

With that many divergences (15,000 out of 20,000 draws), there’s no way the model fit correctly, and the summary statistics bear that out: every r_hat is well above 1.01 and the effective sample sizes are tiny. So we may need to standardize our data. Gelman [2008] suggests standardizing by two standard deviations.

def standardize(X_df: pd.DataFrame) -> pd.DataFrame:
    """
    Standardize input variables by 2 std dev.

    See https://stat.columbia.edu/~gelman/research/published/standardizing7.pdf.
    """
    # find and store means and std, then standardize
    means = X_df.mean(axis=0)
    stdevs = X_df.std(axis=0)
    X_standardized = (X_df - means) / (2 * stdevs)

    return X_standardized
X_std = standardize(data_df.iloc[:, 1:])
with m:
    pm.set_data({"X_data": X_std})
    trace_std = pm.sample(5000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]
100.00% [24000/24000 00:02<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.

Looks like the model fit just fine this time: no divergences, and none of the r_hat or effective sample size warnings.

az.summary(trace_std, hdi_prob=0.95)
mean sd hdi_2.5% hdi_97.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha -1.213 0.358 -1.929 -0.520 0.003 0.002 13937.0 12849.0 1.0
beta[0] 3.712 0.978 1.857 5.627 0.008 0.006 14098.0 12890.0 1.0
beta[1] 1.578 1.376 -1.146 4.235 0.013 0.010 10709.0 12360.0 1.0
beta[2] -2.126 1.588 -5.299 0.844 0.015 0.011 11193.0 13133.0 1.0
beta[3] -1.066 0.693 -2.411 0.287 0.005 0.004 18532.0 14805.0 1.0
beta[4] 0.105 0.727 -1.402 1.455 0.005 0.005 18952.0 14915.0 1.0
beta[5] 0.638 0.745 -0.845 2.063 0.006 0.004 16462.0 15661.0 1.0
beta[6] -0.596 0.631 -1.857 0.623 0.004 0.004 19765.0 14749.0 1.0
beta[7] -0.278 0.622 -1.520 0.922 0.004 0.004 21618.0 15623.0 1.0
beta[8] 1.233 0.670 -0.062 2.546 0.005 0.004 16973.0 14099.0 1.0
beta[9] 0.395 0.684 -0.942 1.738 0.005 0.004 17938.0 14891.0 1.0
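These coefficients are on the standardized scale, so each one describes a change of two standard deviations in its predictor. If per-unit coefficients are wanted, the transformation can be undone algebraically: since each standardized input is \((x_j - \bar{x}_j) / (2 s_j)\), the raw-scale slope is \(\beta_j / (2 s_j)\) and the raw-scale intercept is \(\alpha - \sum_j \beta_j \bar{x}_j / (2 s_j)\). The sketch below applies this to the posterior means above, using the means and standard deviations from `describe()`. The function name `to_original_scale` is made up for this example, and in practice the same conversion would be applied to every posterior draw rather than only to the means.

```python
import numpy as np


def to_original_scale(alpha_std, beta_std, means, stdevs):
    """Convert coefficients fit on 2-SD-standardized inputs to raw-scale ones."""
    beta_std = np.asarray(beta_std, dtype=float)
    means = np.asarray(means, dtype=float)
    stdevs = np.asarray(stdevs, dtype=float)
    beta_raw = beta_std / (2.0 * stdevs)
    alpha_raw = alpha_std - np.sum(beta_raw * means)
    return alpha_raw, beta_raw


# Predictor means and standard deviations copied from data_df.describe()
means = np.array([66.654321, 81.753086, 131.123457, 16.148716, 85.683951,
                  56.401235, 0.666667, 0.308642, 0.419753, 0.469136])
stdevs = np.array([10.429718, 30.322241, 56.196170, 3.672736, 11.847557,
                   13.634153, 0.474342, 0.464811, 0.496593, 0.502156])

# Posterior means copied from the summary of trace_std
alpha_std = -1.213
beta_std = np.array([3.712, 1.578, -2.126, -1.066, 0.105,
                     0.638, -0.596, -0.278, 1.233, 0.395])

alpha_raw, beta_raw = to_original_scale(alpha_std, beta_std, means, stdevs)

# Sanity check: both parameterizations give the same linear predictor
x = means + stdevs  # an arbitrary test point on the raw scale
eta_std = alpha_std + ((x - means) / (2.0 * stdevs)) @ beta_std
eta_raw = alpha_raw + x @ beta_raw

print(beta_raw[0])  # Age: about 0.178 log-odds per year
```

Because the conversion is linear, it does not change the fit at all; it only changes the units in which the coefficients are reported.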

Ants#

An example of Poisson regression, adapted from Unit 7: ants.odc.

Data description#

Data can be found here.

The data discussed in Gotelli and Ellison (2002) provide the ant species richness (number of ant species) found in 64-square-meter sampling grids in 22 forests (coded as 1) and 22 bogs (coded as 2) surrounding the forests in Connecticut, Massachusetts, and Vermont. The sites span 3 degrees of latitude in New England. There are 44 observations on three variables (columns in the data set):

  • Ants: number of species,

  • Habitat: forests (1) and bogs (2),

  • Elevation: in meters above sea level.

(a) Using Poisson regression, model the number of ant species (Ants) with covariates Habitat and Elevation.
(b) For a sampling grid unit located in a forest at an elevation of 100 m, how many species does the model from (a) predict? Report 95% credible sets for the model coefficients and for the prediction.

Poisson regression model#

\[\begin{split} \begin{align*} y_i | \boldsymbol{\beta}, \mathbf{x}_i &\sim \text{Poisson}(\lambda_i) \\ \lambda_i &= g^{-1}\left(\beta_0 + \sum_{j=1}^p \beta_j x_{ij}\right) \\ \beta_0 &\sim \mathcal{N}(0, \sigma_0^2) && \text{Intercept variance may be higher than coefficient variance.}\\ \beta_j &\sim \mathcal{N}(0, \sigma^2), \quad j = 1, \ldots, p \end{align*} \end{split}\]

For Poisson regression our link function \(g(\cdot)\) is the natural logarithm and its inverse \(g^{-1}(\cdot)\) is the exponential function.
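A practical consequence of the log link is that coefficients act multiplicatively on the expected count: raising a predictor by one unit multiplies \(\lambda\) by \(e^{\beta_j}\). The following minimal sketch checks this numerically. The coefficient values and the helper `poisson_rate` are hypothetical, chosen only for illustration, and are not estimates from the ants data.

```python
import numpy as np


def poisson_rate(beta0, betas, x):
    """Expected count under a log link: exp of the linear predictor."""
    return np.exp(beta0 + np.dot(x, betas))


# Hypothetical coefficients (not fitted values)
beta0 = 1.0
betas = np.array([-0.5, 0.01])

x = np.array([1.0, 100.0])
x_plus = x + np.array([1.0, 0.0])  # raise the first predictor by one unit

rate = poisson_rate(beta0, betas, x)
rate_plus = poisson_rate(beta0, betas, x_plus)
ratio = rate_plus / rate

print(ratio, np.exp(betas[0]))  # the two numbers agree
```

The exponential also guarantees a positive rate for any value of the linear predictor, which is why no constraint on the coefficients is needed.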

data = pd.read_csv("../data/ants.csv")
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 44 entries, 0 to 43
Data columns (total 3 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   ants       44 non-null     int64
 1   habitat    44 non-null     int64
 2   elevation  44 non-null     int64
dtypes: int64(3)
memory usage: 1.2 KB
with pm.Model() as m:
    ant_species = pm.Data("ant_species", data["ants"].to_numpy(), mutable=False)
    habitat = pm.Data("habitat", data["habitat"].to_numpy(), mutable=True)
    elevation = pm.Data("elevation", data["elevation"].to_numpy(), mutable=True)

    beta0 = pm.Normal("beta0_intercept", mu=0, tau=0.0001)
    beta1 = pm.Normal("beta1_habitat", mu=0, tau=0.0001)
    beta2 = pm.Normal("beta2_elevation", mu=0, tau=0.0001)

    μ = pm.math.exp(beta0 + beta1 * habitat + beta2 * elevation)

    y = pm.Poisson("y", mu=μ, observed=ant_species)

    trace = pm.sample(5000, tune=2000, init="adapt_diag")
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta0_intercept, beta1_habitat, beta2_elevation]
100.00% [28000/28000 00:04<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 2_000 tune and 5_000 draw iterations (8_000 + 20_000 draws total) took 4 seconds.
az.summary(trace)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
beta0_intercept 3.171 0.183 2.819 3.506 0.002 0.002 7068.0 8835.0 1.0
beta1_habitat -0.638 0.119 -0.863 -0.418 0.001 0.001 7274.0 8587.0 1.0
beta2_elevation -0.001 0.000 -0.002 -0.001 0.000 0.000 9650.0 9345.0 1.0
# prediction for a forest (habitat = 1) at an elevation of 100 m
with m:
    pm.set_data({"habitat": [1], "elevation": [100]})
    ppc = pm.sample_posterior_predictive(trace, predictions=True)
Sampling: [y]
100.00% [20000/20000 00:00<00:00]
ppc.predictions
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 5000, y_dim_2: 44)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 4993 4994 4995 4996 4997 4998 4999
  * y_dim_2  (y_dim_2) int64 0 1 2 3 4 5 6 7 8 9 ... 35 36 37 38 39 40 41 42 43
Data variables:
    y        (chain, draw, y_dim_2) int64 8 11 11 13 6 8 13 ... 10 17 10 17 6 9
Attributes:
    created_at:                 2023-10-28T23:56:19.150908
    arviz_version:              0.16.1
    inference_library:          pymc
    inference_library_version:  5.9.0
az.summary(ppc.predictions).mean()
mean            10.875045
sd               3.406659
hdi_3%           5.000000
hdi_97%         17.000000
mcse_mean        0.024955
mcse_sd          0.017636
ess_bulk     18690.204545
ess_tail     19293.545455
r_hat            1.000000
dtype: float64
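Note that the summary above reports ArviZ's default 94% HDI (the `hdi_3%` and `hdi_97%` columns), while part (b) asks for a 95% credible set. One option is to pass `hdi_prob=0.95` to `az.summary`, as in the arrhythmia example. Another is to compute an equal-tailed interval directly from the predictive draws with NumPy, as sketched below. The draws here are synthetic stand-ins generated from independent normals loosely centered on the coefficient summary, with an assumed spread for the elevation coefficient; they ignore posterior correlations and are not the actual trace, so the printed numbers are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 20_000

# Synthetic stand-ins for posterior draws (assumption: independent normals;
# the real draws would come from trace.posterior)
beta0 = rng.normal(3.171, 0.183, size=n)
beta1 = rng.normal(-0.638, 0.119, size=n)
beta2 = rng.normal(-0.001, 0.0004, size=n)  # spread assumed for illustration

# Expected count for a forest (habitat = 1) at 100 m, one value per draw
lam = np.exp(beta0 + beta1 * 1 + beta2 * 100)

# Posterior predictive draws: one Poisson count per posterior draw
y_pred = rng.poisson(lam)

# Equal-tailed 95% credible set from the 2.5th and 97.5th percentiles
lower, upper = np.percentile(y_pred, [2.5, 97.5])

print(f"predictive mean: {y_pred.mean():.2f}")
print(f"95% equal-tailed interval: [{lower:.0f}, {upper:.0f}]")
```

With the real trace, the same two lines (`rng.poisson` on the per-draw rate, then `np.percentile`) would give one predictive draw per posterior draw rather than the 44 per draw that `ppc.predictions` contains.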
%load_ext watermark
%watermark -n -u -v -iv -p pytensor
Last updated: Sat Oct 28 2023

Python implementation: CPython
Python version       : 3.11.5
IPython version      : 8.15.0

pytensor: 2.17.1

pandas : 2.1.0
seaborn: 0.13.0
arviz  : 0.16.1
pymc   : 5.9.0
xarray : 2023.8.0
numpy  : 1.25.2