import pymc as pm
import numpy as np
import arviz as az
import pytensor.tensor as pt
17. Multinomial regression*#
Adapted from Unit 7: NHANESmulti.odc
Problem statement#
The National Health and Nutrition Examination Survey (NHANES) is a program of studies designed to assess the health and nutritional status of adults and children in the United States. The survey is unique in that it combines interviews and physical examinations.
Assume that \(N\) subjects each select a choice from \(K\) categories. The \(i\text{-th}\) subject is characterized by 3 covariates x[i, 1], x[i, 2], and x[i, 3]. Given the covariates, model the probability of subject \(i\) selecting category \(k\), where \(k \in \{1,...,K\}\).
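One standard way to write this is the multinomial logit (a sketch consistent with the model below, which fixes category 1 as the reference):

```latex
P(y_i = k \mid x_i) = \frac{\exp(x_i^\top \beta_k)}{\sum_{j=1}^{K} \exp(x_i^\top \beta_j)},
\qquad \beta_1 = \mathbf{0} \quad \text{(reference category, for identifiability)}
```

Adding any constant vector to all the \(\beta_k\) leaves these probabilities unchanged, which is why one category's coefficients must be pinned to zero.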
# data
# fmt: off
y = np.array([[1, 0, 0, 0, 0],
              [0, 1, 0, 0, 0],
              [1, 0, 0, 0, 0],
              [0, 0, 1, 0, 0],
              [0, 1, 0, 0, 0],
              [0, 0, 1, 0, 0],
              [0, 0, 0, 1, 0],
              [0, 0, 0, 0, 1],
              [0, 0, 0, 0, 1],
              [0, 0, 0, 1, 0]])
X = np.array([[2, 4, 9],
              [1, 5, 10],
              [1, 6, 14],
              [2, 4, 21],
              [2, 4, 22],
              [2, 6, 30],
              [3, 3, 33],
              [3, 2, 36],
              [3, 1, 40],
              [4, 1, 44]])
# fmt: on
# N = 10, P = 4 (intercept + 3 predictors)
X_aug = np.concatenate((np.ones((X.shape[0], 1)), X), axis=1)
N, P = X_aug.shape
K = y.shape[1]
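As a quick sanity check (not part of the original notebook), we can verify the shapes and confirm that each response row is one-hot; the data are rewritten compactly here so the snippet stands alone:

```python
import numpy as np

# same data as above, written compactly: each entry indexes the chosen category
y = np.eye(5)[[0, 1, 0, 2, 1, 2, 3, 4, 4, 3]]  # one-hot responses
X = np.array([[2, 4, 9], [1, 5, 10], [1, 6, 14], [2, 4, 21], [2, 4, 22],
              [2, 6, 30], [3, 3, 33], [3, 2, 36], [3, 1, 40], [4, 1, 44]])

# prepend an intercept column
X_aug = np.concatenate((np.ones((X.shape[0], 1)), X), axis=1)

assert X_aug.shape == (10, 4)      # N = 10, P = 4 (intercept + 3 predictors)
assert y.shape == (10, 5)          # K = 5 categories
assert (y.sum(axis=1) == 1).all()  # each subject picks exactly one category
```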
The Multinomial distribution has two parameters, \(n\) and \(p\). See the PyMC documentation for the Multinomial distribution class.
| Argument | What it means | In this model |
| --- | --- | --- |
| `n` | Total number of independent trials per replicate. A Multinomial with \(n = 1\) is a one-hot vector. Must equal the per-row sum of the corresponding observed values. | Each survey respondent makes exactly one choice, so `n=1`. |
| `p` | Vector (or matrix) of category probabilities. Must sum to 1 along the last axis. | We generate `p` by applying a softmax to the linear predictor \(X\beta\). |
| `observed` | Either a one-hot matrix (\(n = 1\)) or a count matrix (if \(n > 1\)). Shape must match `p`. | For this example every row of `y` is one-hot. |
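The distinction between one-hot draws (\(n = 1\)) and count draws (\(n > 1\)) can be seen directly with NumPy (an illustrative sketch; the logit values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([0.0, 1.0, 2.0, 0.5, -1.0])  # arbitrary linear predictor
p = np.exp(logits) / np.exp(logits).sum()      # softmax: a valid probability vector

draw = rng.multinomial(n=1, pvals=p)           # one trial -> a one-hot vector
assert draw.sum() == 1 and draw.max() == 1

counts = rng.multinomial(n=100, pvals=p)       # n > 1 -> a count vector instead
assert counts.sum() == 100
```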
with pm.Model() as m:
    y_data = pm.Data("y", y)
    X_data = pm.Data("X", X_aug)

    _beta = pm.Normal("_beta", mu=0, tau=0.1, shape=(P, K - 1))
    # fix the reference category's coefficients at zero for identifiability
    beta = pt.concatenate([pt.zeros((P, 1)), _beta], axis=1)

    eta = pm.math.dot(X_data, beta)
    p = pm.math.softmax(eta, axis=1)

    # shape follows p so posterior predictions track the current X data
    pm.Multinomial("likelihood", n=1, p=p, observed=y_data, shape=p.shape)

    trace = pm.sample(10000)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [_beta]
Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 23 seconds.
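The zero column prepended to `_beta` pins category 1 as the reference. This is necessary because the softmax is invariant to adding a constant to every logit, so without the constraint the coefficients would not be identified. A NumPy sketch of that invariance:

```python
import numpy as np

def softmax(eta):
    e = np.exp(eta - eta.max())  # subtract max for numerical stability
    return e / e.sum()

eta = np.array([0.0, 0.3, -1.2, 2.0, 0.7])
shifted = eta + 5.0              # shift every logit by the same constant

# identical probabilities: the model is only identified up to a shift,
# hence one category's coefficients are fixed at zero
assert np.allclose(softmax(eta), softmax(shifted))
```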
X_new = np.array([1, 3, 3, 30]).reshape((1, 4))

with m:
    pm.set_data({"X": X_new})
    ppc = pm.sample_posterior_predictive(trace, predictions=True)
Sampling: [likelihood]
az.summary(ppc.predictions)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| likelihood[0, 0] | 0.007 | 0.086 | 0.0 | 0.0 | 0.000 | 0.002 | 39808.0 | 39808.0 | 1.0 |
| likelihood[0, 1] | 0.038 | 0.191 | 0.0 | 0.0 | 0.001 | 0.002 | 39180.0 | 39180.0 | 1.0 |
| likelihood[0, 2] | 0.116 | 0.320 | 0.0 | 1.0 | 0.002 | 0.002 | 40635.0 | 40000.0 | 1.0 |
| likelihood[0, 3] | 0.686 | 0.464 | 0.0 | 1.0 | 0.002 | 0.001 | 40664.0 | 40000.0 | 1.0 |
| likelihood[0, 4] | 0.153 | 0.360 | 0.0 | 1.0 | 0.002 | 0.002 | 40308.0 | 40000.0 | 1.0 |
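Because each posterior predictive draw is one-hot, the per-category mean of the draws is a Monte Carlo estimate of that category's selection probability; in the summary above, the new subject is predicted to choose category 4 (index 3) with probability about 0.69. A toy NumPy sketch of that averaging (`p_true` is illustrative, not taken from the fit):

```python
import numpy as np

rng = np.random.default_rng(1)
p_true = np.array([0.01, 0.04, 0.12, 0.68, 0.15])  # illustrative probabilities
draws = rng.multinomial(1, p_true, size=40_000)    # one-hot draws, like the PPC samples

est = draws.mean(axis=0)            # mean of one-hot draws estimates p_true
assert np.allclose(est, p_true, atol=0.02)
assert np.isclose(est.sum(), 1.0)   # estimates still sum to 1
```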