import pymc as pm
import numpy as np
import arviz as az
import pytensor.tensor as pt

17. Multinomial regression*#

Adapted from Unit 7: NHANESmulti.odc

Problem statement#

The National Health and Nutrition Examination Survey (NHANES) is a program of studies designed to assess the health and nutritional status of adults and children in the United States. The survey is unique in that it combines interviews and physical examinations.

Assume that \(N\) subjects each select a choice from \(K\) categories. The \(i\text{-th}\) subject is characterized by 3 covariates x[i, 1], x[i, 2], and x[i, 3]. Given the covariates, model the probability that a subject selects category \(k\), where \(k \in \{1,...,K\}\).
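Concretely, this is a multinomial logit model: each category gets its own coefficient vector, and a softmax maps the linear predictors to probabilities. With category 1 as the reference:

\[
\eta_{ik} = x_i^\top \beta_k, \qquad
p_{ik} = \frac{\exp(\eta_{ik})}{\sum_{j=1}^{K} \exp(\eta_{ij})}, \qquad
\beta_1 = 0.
\]

Fixing \(\beta_1 = 0\) is needed for identifiability, since adding a constant vector to every \(\beta_k\) leaves the softmax unchanged.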

# data
# fmt: off
y = np.array([[1, 0, 0, 0, 0],
              [0, 1, 0, 0, 0],
              [1, 0, 0, 0, 0],
              [0, 0, 1, 0, 0],
              [0, 1, 0, 0, 0],
              [0, 0, 1, 0, 0],
              [0, 0, 0, 1, 0],
              [0, 0, 0, 0, 1],
              [0, 0, 0, 0, 1],
              [0, 0, 0, 1, 0]])

X = np.array([[2, 4, 9],
              [1, 5, 10],
              [1, 6, 14],
              [2, 4, 21],
              [2, 4, 22],
              [2, 6, 30],
              [3, 3, 33],
              [3, 2, 36],
              [3, 1, 40],
              [4, 1, 44]])
# fmt: on
# N = 10, P = 4  (intercept + 3 predictors)
X_aug = np.concatenate((np.ones((X.shape[0], 1)), X), axis=1)
N, P = X_aug.shape
K = y.shape[1]

The Multinomial distribution has two parameters, \(n\) and \(p\). See the PyMC documentation for the Multinomial distribution class.

| Argument | What it means | In this model |
| --- | --- | --- |
| n | Total number of independent trials per replicate. A Multinomial with \(n = 1\) is a one-hot vector. Must equal the per-row sum of the corresponding observed values. | Each survey respondent makes exactly one choice, so n = 1. |
| p | Vector (or matrix) of category probabilities. Must sum to 1 along the last axis. | We generate p with a softmax of the linear predictor \(\eta = X\beta\). Its shape is \((N, K)\), one probability vector per observation. |
| observed | Either a one-hot matrix (\(n = 1\)) or a count matrix (if \(n > 1\)). Shape must match p. | For this example every row of y is one-hot encoded. |
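As a quick sanity check before building the model, here is a standalone NumPy sketch (not part of the PyMC model) showing that a row-wise softmax turns arbitrary linear predictors into valid probability vectors:

```python
import numpy as np

def softmax(eta, axis=-1):
    # subtract the row max for numerical stability; the result is unchanged
    z = eta - eta.max(axis=axis, keepdims=True)
    ez = np.exp(z)
    return ez / ez.sum(axis=axis, keepdims=True)

eta = np.array([[0.0, 1.0, -2.0],
                [3.0, 3.0, 3.0]])
p = softmax(eta)
print(p.sum(axis=1))  # every row sums to 1
print(p[1])           # equal predictors give equal probabilities (all 1/3)
```

This is the same operation `pm.math.softmax(eta, axis=1)` performs symbolically inside the model.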

with pm.Model() as m:
    y_data = pm.Data("y", y)
    X_data = pm.Data("X", X_aug)

    # K - 1 free coefficient columns; the first category's column is pinned
    # to zero so it acts as the reference category (identifiability)
    _beta = pm.Normal("_beta", mu=0, tau=0.1, shape=(P, K - 1))
    beta = pt.concatenate([pt.zeros((P, 1)), _beta], axis=1)

    eta = pm.math.dot(X_data, beta)
    p = pm.math.softmax(eta, axis=1)

    # shape follows p (N, K) so predictions resize when X is swapped via set_data;
    # X_data.shape would be (N, P), which does not match the K categories
    pm.Multinomial("likelihood", n=1, p=p, observed=y_data, shape=p.shape)

    trace = pm.sample(10000)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [_beta]

Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 23 seconds.
X_new = np.array([1, 3, 3, 30]).reshape((1, 4))

with m:
    pm.set_data({"X": X_new})
    ppc = pm.sample_posterior_predictive(trace, predictions=True)
Sampling: [likelihood]

az.summary(ppc.predictions)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
likelihood[0, 0] 0.007 0.086 0.0 0.0 0.000 0.002 39808.0 39808.0 1.0
likelihood[0, 1] 0.038 0.191 0.0 0.0 0.001 0.002 39180.0 39180.0 1.0
likelihood[0, 2] 0.116 0.320 0.0 1.0 0.002 0.002 40635.0 40000.0 1.0
likelihood[0, 3] 0.686 0.464 0.0 1.0 0.002 0.001 40664.0 40000.0 1.0
likelihood[0, 4] 0.153 0.360 0.0 1.0 0.002 0.002 40308.0 40000.0 1.0
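The means above are directly interpretable as estimated category probabilities: with \(n = 1\), each posterior predictive draw is a one-hot vector, so averaging draws recovers \(p\). A small standalone simulation (with a hypothetical probability vector, not taken from this model's posterior) illustrates why:

```python
import numpy as np

rng = np.random.default_rng(42)
p_true = np.array([0.01, 0.04, 0.12, 0.68, 0.15])  # hypothetical, sums to 1

# draw many one-hot vectors, as the posterior predictive does when n = 1
draws = rng.multinomial(1, p_true, size=40_000)
print(draws.mean(axis=0))  # close to p_true
```

So here the model puts roughly 69% probability on the fourth category for the new covariate vector X_new.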

Author#

Aaron Reding, May 2022 (updated April 2025)

%load_ext watermark
%watermark -n -u -v -iv
Last updated: Sun Apr 20 2025

Python implementation: CPython
Python version       : 3.12.7
IPython version      : 8.29.0

numpy   : 1.26.4
pytensor: 2.30.2
arviz   : 0.21.0
pymc    : 5.22.0