import arviz as az
import numpy as np
import xarray as xr
import pandas as pd
import pymc as pm
from pymc.math import dot, invlogit
import seaborn as sns
15. GLM Examples*#
Arrhythmia#
A logistic regression example, adapted from Unit 7: arrhythmia.odc.
Data descriptions#
I mirrored the data here.
| Variable ID | Name | Description |
|---|---|---|
| Y | Fibrillation | Outcome variable: presence of fibrillation. |
| X1 | Age | Age of the patient. |
| X2 | Aortic Cross Clamp Time | Duration of time the aorta is clamped during surgery. |
| X3 | Cardiopulmonary Bypass Time | Time spent on cardiopulmonary bypass, during which blood is diverted through a heart-lung machine that performs the functions of the heart and lungs. |
| X4 | ICU Time | Time spent in the Intensive Care Unit. |
| X5 | Avg Heart Rate | Average heart rate of the patient. |
| X6 | Left Ventricle Ejection Fraction | Measure of how well the left ventricle pumps blood out to the body. |
| X7 | Hypertension | Binary: presence (1) or absence (0) of high blood pressure. |
| X8 | Gender | Binary: 1 for female; 0 for male. |
| X9 | Diabetes | Binary: presence (1) or absence (0) of diabetes. |
| X10 | Previous MI | Binary: presence (1) or absence (0) of a previous myocardial infarction (heart attack). |
Background#
Patients who undergo Coronary Artery Bypass Graft Surgery (CABG) have an approximate 19–40% chance of developing atrial fibrillation (AF). AF can lead to the formation of blood clots, resulting in increased in-hospital mortality, strokes, and longer hospital stays. While drugs can prevent this condition, they are expensive and can be dangerous if not warranted. Ideally, identifying several risk factors that indicate an increased risk of developing AF could save lives and money by showing which patients need pharmacological intervention. Researchers have begun collecting data such as demographics, heart rate, cholesterol, and operation time from CABG patients during their hospital stays. They have also recorded which patients developed AF. The goal now is to identify the data points that signal a high risk of AF. In the past, factors such as age, hypertension, and body surface area (BSA) have been useful indicators, although they have not provided a satisfactory solution on their own.
Fibrillation occurs when the heart muscle begins a quivering motion instead of maintaining a normal, healthy pumping rhythm. Fibrillation can affect either the atrium (atrial fibrillation) or the ventricle (ventricular fibrillation); the latter is imminently life-threatening.
Atrial fibrillation involves quivering, chaotic motion in the upper chambers of the heart, known as the atria. It is often linked to serious underlying medical conditions and should be evaluated by a physician. Although it is not typically a medical emergency, it still requires medical attention.
Ventricular fibrillation occurs in the ventricles (lower chambers) of the heart and is always a medical emergency. If left untreated, ventricular fibrillation (VF, or V-fib) can lead to death within minutes. When the heart enters V-fib, effective blood pumping ceases. V-fib is considered a form of cardiac arrest, and an individual experiencing it will not survive unless immediate cardiopulmonary resuscitation (CPR) and defibrillation are administered.
Model#
This is a logistic regression model. We consider each patient’s outcome a single Bernoulli event:

$$y_i \sim \text{Bernoulli}(p_i), \qquad g(p_i) = \alpha + \beta_1 x_{i1} + \dots + \beta_k x_{ik},$$

where \(k\) is the number of predictors and \(g(\cdot)\) is the logit function, \(\text{logit}(p) = \ln\left(\frac{p}{1-p}\right)\); its inverse \(g^{-1}(\cdot)\) is the logistic function, \(\text{logistic}(x) = \frac{1}{1 + e^{-x}}\).
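As a quick numeric check (not in the original notebook), the logistic function undoes the logit:

# logistic(logit(p)) recovers p.
p = np.array([0.1, 0.5, 0.9])
logit_p = np.log(p / (1 - p))
print(1 / (1 + np.exp(-logit_p)))  # [0.1 0.5 0.9]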
If your data is in an aggregated format, consider a Binomial likelihood instead. The model can be stated equivalently as

$$y_j \sim \text{Binomial}(n_j, p_j), \qquad g(p_j) = \alpha + \beta_1 x_{j1} + \dots + \beta_k x_{jk},$$

where \(n_j\) is the number of trials and \(y_j\) the number of successes in group \(j\).
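For illustration, here is a minimal sketch of that aggregated formulation using a tiny made-up dataset (the names agg_df and m_agg are mine; this is not part of the arrhythmia data):

# Hypothetical aggregated data: three groups, each with a single predictor x,
# n trials, and y successes.
agg_df = pd.DataFrame({"x": [-1.0, 0.0, 1.0], "n": [20, 25, 30], "y": [4, 11, 21]})

with pm.Model() as m_agg:
    alpha_agg = pm.Normal("alpha", mu=0, sigma=10)
    beta_agg = pm.Normal("beta", mu=0, sigma=5)
    p_agg = invlogit(alpha_agg + beta_agg * agg_df["x"].to_numpy())
    pm.Binomial("y", n=agg_df["n"].to_numpy(), p=p_agg, observed=agg_df["y"].to_numpy())
    # trace_agg = pm.sample()  # uncomment to sample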
data_df = pd.read_csv("../data/arrhythmia.csv")
data_df.info()
X = data_df.iloc[:, 1:]
y = data_df["Fibrillation"]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 81 entries, 0 to 80
Data columns (total 11 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Fibrillation 81 non-null float64
1 Age 81 non-null float64
2 AorticCrossClampTime 81 non-null float64
3 CardiopulmonaryBypassTime 81 non-null float64
4 ICUTime 81 non-null float64
5 AvgHeartRate 81 non-null float64
6 LeftVentricleEjectionFraction 81 non-null float64
7 Hypertension 81 non-null float64
8 Gender 81 non-null float64
9 Diabetes 81 non-null float64
10 PreviousMI 81 non-null float64
dtypes: float64(11)
memory usage: 7.1 KB
data_df.describe()
| | Fibrillation | Age | AorticCrossClampTime | CardiopulmonaryBypassTime | ICUTime | AvgHeartRate | LeftVentricleEjectionFraction | Hypertension | Gender | Diabetes | PreviousMI |
|---|---|---|---|---|---|---|---|---|---|---|---|
count | 81.000000 | 81.000000 | 81.000000 | 81.000000 | 81.000000 | 81.000000 | 81.000000 | 81.000000 | 81.000000 | 81.000000 | 81.000000 |
mean | 0.345679 | 66.654321 | 81.753086 | 131.123457 | 16.148716 | 85.683951 | 56.401235 | 0.666667 | 0.308642 | 0.419753 | 0.469136 |
std | 0.478552 | 10.429718 | 30.322241 | 56.196170 | 3.672736 | 11.847557 | 13.634153 | 0.474342 | 0.464811 | 0.496593 | 0.502156 |
min | 0.000000 | 44.000000 | 0.000000 | 0.000000 | 2.000000 | 50.000000 | 18.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
25% | 0.000000 | 61.000000 | 67.000000 | 109.000000 | 13.500000 | 77.200000 | 50.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
50% | 0.000000 | 69.000000 | 82.000000 | 128.000000 | 16.000000 | 86.700000 | 59.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 |
75% | 1.000000 | 73.000000 | 98.000000 | 148.000000 | 19.000000 | 94.800000 | 65.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
max | 1.000000 | 88.000000 | 193.000000 | 487.000000 | 23.000000 | 111.800000 | 82.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
Our predictors have very different scales. With the non-informative priors the professor uses in the BUGS model, the coefficients can still fit the data, but PyMC uses a different sampling algorithm (NUTS) and it struggles with the shape of the resulting posterior. This model used to sample fine in PyMC (as of version 5.1.2, at least), but students in Fall 2023 discovered that PyMC could no longer sample it without divergences (using version 5.9.0 or above).
with pm.Model() as m:
    X_data = pm.Data("X_data", X, mutable=True)
    y_data = pm.Data("y_data", y, mutable=False)

    alpha = pm.Normal("alpha", mu=0, sigma=10)
    betas = pm.Normal("beta", mu=0, sigma=5, shape=X.shape[1])

    p = invlogit(alpha + dot(X_data, betas))

    pm.Bernoulli("y", p=p, observed=y_data)

    trace = pm.sample(5000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]
Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 14 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
There were 15000 divergences after tuning. Increase `target_accept` or reparameterize.
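You can also count the post-tuning divergences programmatically from the trace’s sampler statistics (a quick check, not in the original notebook):

# Number of divergent transitions recorded across all chains after tuning
# (the sampler log above reported 15000).
n_divergent = int(trace.sample_stats["diverging"].sum())
print(f"Divergent transitions: {n_divergent}")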
az.summary(trace, hdi_prob=0.95)
| | mean | sd | hdi_2.5% | hdi_97.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
alpha | -2.562 | 5.195 | -14.336 | 0.710 | 2.383 | 1.801 | 4.0 | 26.0 | 3.64 |
beta[0] | -0.108 | 0.460 | -0.602 | 0.499 | 0.230 | 0.176 | 4.0 | 4.0 | 7.19 |
beta[1] | 0.065 | 0.488 | -0.374 | 0.870 | 0.244 | 0.187 | 4.0 | 11.0 | 4.32 |
beta[2] | 0.213 | 0.409 | -0.277 | 0.812 | 0.204 | 0.157 | 4.0 | 4.0 | 8.46 |
beta[3] | 0.098 | 0.578 | -0.706 | 0.790 | 0.288 | 0.220 | 4.0 | 12.0 | 5.32 |
beta[4] | -0.299 | 0.256 | -0.689 | 0.023 | 0.127 | 0.098 | 4.0 | 4.0 | 6.35 |
beta[5] | 0.396 | 0.325 | -0.003 | 0.890 | 0.162 | 0.124 | 4.0 | 28.0 | 6.60 |
beta[6] | -0.039 | 0.577 | -1.239 | 0.701 | 0.242 | 0.181 | 6.0 | 26.0 | 2.32 |
beta[7] | -0.366 | 0.414 | -1.124 | 0.559 | 0.126 | 0.092 | 10.0 | 26.0 | 1.97 |
beta[8] | 0.668 | 0.624 | -0.213 | 1.748 | 0.268 | 0.201 | 6.0 | 4.0 | 2.11 |
beta[9] | 0.051 | 0.625 | -0.566 | 1.087 | 0.270 | 0.203 | 6.0 | 26.0 | 2.10 |
m.to_graphviz()
With that many divergences, there’s no way the model fit correctly, and that’s borne out in the summary statistics: every r_hat is well above 1.01. So we may need to standardize our data. Andrew Gelman [Gelman, 2008] suggests standardizing by two standard deviations.
def standardize(X_df: pd.DataFrame) -> pd.DataFrame:
    """
    Standardize input variables by 2 std dev.

    See https://stat.columbia.edu/~gelman/research/published/standardizing7.pdf.
    """
    # find and store means and std, then standardize
    means = X_df.mean(axis=0)
    stdevs = X_df.std(axis=0)
    X_standardized = (X_df - means) / (2 * stdevs)
    return X_standardized
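As a quick sanity check (not in the original notebook), every standardized column should end up with mean approximately 0 and standard deviation 0.5:

# After dividing by two standard deviations, each column has mean ~0 and std ~0.5.
check = standardize(data_df.iloc[:, 1:])
assert np.allclose(check.mean(axis=0), 0.0, atol=1e-10)
assert np.allclose(check.std(axis=0), 0.5)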
X_std = standardize(data_df.iloc[:, 1:])

with m:
    pm.set_data({"X_data": X_std})
    trace_std = pm.sample(5000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]
Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
Looks like the model fit just fine this time.
az.summary(trace_std, hdi_prob=0.95)
| | mean | sd | hdi_2.5% | hdi_97.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
alpha | -1.213 | 0.358 | -1.929 | -0.520 | 0.003 | 0.002 | 13937.0 | 12849.0 | 1.0 |
beta[0] | 3.712 | 0.978 | 1.857 | 5.627 | 0.008 | 0.006 | 14098.0 | 12890.0 | 1.0 |
beta[1] | 1.578 | 1.376 | -1.146 | 4.235 | 0.013 | 0.010 | 10709.0 | 12360.0 | 1.0 |
beta[2] | -2.126 | 1.588 | -5.299 | 0.844 | 0.015 | 0.011 | 11193.0 | 13133.0 | 1.0 |
beta[3] | -1.066 | 0.693 | -2.411 | 0.287 | 0.005 | 0.004 | 18532.0 | 14805.0 | 1.0 |
beta[4] | 0.105 | 0.727 | -1.402 | 1.455 | 0.005 | 0.005 | 18952.0 | 14915.0 | 1.0 |
beta[5] | 0.638 | 0.745 | -0.845 | 2.063 | 0.006 | 0.004 | 16462.0 | 15661.0 | 1.0 |
beta[6] | -0.596 | 0.631 | -1.857 | 0.623 | 0.004 | 0.004 | 19765.0 | 14749.0 | 1.0 |
beta[7] | -0.278 | 0.622 | -1.520 | 0.922 | 0.004 | 0.004 | 21618.0 | 15623.0 | 1.0 |
beta[8] | 1.233 | 0.670 | -0.062 | 2.546 | 0.005 | 0.004 | 16973.0 | 14099.0 | 1.0 |
beta[9] | 0.395 | 0.684 | -0.942 | 1.738 | 0.005 | 0.004 | 17938.0 | 14891.0 | 1.0 |
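One way to interpret these standardized coefficients (not in the original notebook) is to exponentiate them into odds ratios: the multiplicative change in the odds of fibrillation for a two-standard-deviation increase in each predictor.

# Posterior mean odds ratio per 2-SD increase in each standardized predictor.
odds_ratios = np.exp(trace_std.posterior["beta"]).mean(dim=("chain", "draw"))
print(odds_ratios.values.round(2))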
Ants#
An example of Poisson regression, adapted from Unit 7: ants.odc.
Data description#
Data can be found here.
The data discussed in Gotelli and Ellison (2002) provide the ant species richness (number of ant species) found in 64-square-meter sampling grids in 22 bogs (coded as 2) and in the 22 surrounding forests (coded as 1) in Connecticut, Massachusetts, and Vermont. The sites span 3 degrees of latitude in New England. There are 44 observations on three variables (columns in the data set):
Ants: number of species,
Habitat: forests (1) and bogs (2),
Elevation: in meters above sea level.
(a) Using Poisson regression, model the number of ant species (Ants) with covariates Habitat and Elevation.
(b) For a sampling grid unit located in a forest at an elevation of 100 m, how many species does the model from (a) predict? Report 95% credible sets for the model coefficients and the prediction.
Poisson regression model#
For Poisson regression, our link function \(g(\cdot)\) is the natural logarithm and its inverse \(g^{-1}(\cdot)\) is the exponential function, so the model is

$$y_i \sim \text{Poisson}(\lambda_i), \qquad \log(\lambda_i) = \beta_0 + \beta_1 \cdot \text{habitat}_i + \beta_2 \cdot \text{elevation}_i.$$
data = pd.read_csv("../data/ants.csv")
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 44 entries, 0 to 43
Data columns (total 3 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 ants 44 non-null int64
1 habitat 44 non-null int64
2 elevation 44 non-null int64
dtypes: int64(3)
memory usage: 1.2 KB
with pm.Model() as m:
    ant_species = pm.Data("ant_species", data["ants"].to_numpy(), mutable=False)
    habitat = pm.Data("habitat", data["habitat"].to_numpy(), mutable=True)
    elevation = pm.Data("elevation", data["elevation"].to_numpy(), mutable=True)

    beta0 = pm.Normal("beta0_intercept", mu=0, tau=0.0001)
    beta1 = pm.Normal("beta1_habitat", mu=0, tau=0.0001)
    beta2 = pm.Normal("beta2_elevation", mu=0, tau=0.0001)

    μ = pm.math.exp(beta0 + beta1 * habitat + beta2 * elevation)

    y = pm.Poisson("y", mu=μ, observed=ant_species)

    trace = pm.sample(5000, tune=2000, init="adapt_diag")
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta0_intercept, beta1_habitat, beta2_elevation]
Sampling 4 chains for 2_000 tune and 5_000 draw iterations (8_000 + 20_000 draws total) took 4 seconds.
az.summary(trace)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
beta0_intercept | 3.171 | 0.183 | 2.819 | 3.506 | 0.002 | 0.002 | 7068.0 | 8835.0 | 1.0 |
beta1_habitat | -0.638 | 0.119 | -0.863 | -0.418 | 0.001 | 0.001 | 7274.0 | 8587.0 | 1.0 |
beta2_elevation | -0.001 | 0.000 | -0.002 | -0.001 | 0.000 | 0.000 | 9650.0 | 9345.0 | 1.0 |
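To interpret the habitat coefficient (a quick aid, not in the original notebook), exponentiate it: moving from forest (1) to bog (2) multiplies the expected species count by \(e^{\beta_1}\).

# Posterior samples of the habitat rate ratio exp(beta1), with a 95% interval.
rr = np.exp(trace.posterior["beta1_habitat"].values.ravel())
print(rr.mean().round(2), np.percentile(rr, [2.5, 97.5]).round(2))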
# prediction
with m:
    pm.set_data({"habitat": [1], "elevation": [100]})
    ppc = pm.sample_posterior_predictive(trace, predictions=True)
Sampling: [y]
ppc.predictions
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 5000, y_dim_2: 44)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 4993 4994 4995 4996 4997 4998 4999
  * y_dim_2  (y_dim_2) int64 0 1 2 3 4 5 6 7 8 9 ... 35 36 37 38 39 40 41 42 43
Data variables:
    y        (chain, draw, y_dim_2) int64 8 11 11 13 6 8 13 ... 10 17 10 17 6 9
Attributes:
    created_at:                 2023-10-28T23:56:19.150908
    arviz_version:              0.16.1
    inference_library:          pymc
    inference_library_version:  5.9.0
az.summary(ppc.predictions).mean()
mean 10.875045
sd 3.406659
hdi_3% 5.000000
hdi_97% 17.000000
mcse_mean 0.024955
mcse_sd 0.017636
ess_bulk 18690.204545
ess_tail 19293.545455
r_hat 1.000000
dtype: float64
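As an alternative check (not in the original notebook), we can compute the Poisson mean for a forest site at 100 m directly from the posterior samples and compare it with the posterior predictive summary above:

# Expected species count lambda = exp(beta0 + beta1*1 + beta2*100),
# evaluated over the posterior samples, with a 95% interval.
post = trace.posterior
lam = np.exp(
    post["beta0_intercept"] + post["beta1_habitat"] * 1 + post["beta2_elevation"] * 100
).values.ravel()
print(lam.mean().round(2), np.percentile(lam, [2.5, 97.5]).round(2))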
%load_ext watermark
%watermark -n -u -v -iv -p pytensor
Last updated: Sat Oct 28 2023
Python implementation: CPython
Python version : 3.11.5
IPython version : 8.15.0
pytensor: 2.17.1
pandas : 2.1.0
seaborn: 0.13.0
arviz : 0.16.1
pymc : 5.9.0
xarray : 2023.8.0
numpy : 1.25.2