import pymc as pm
import numpy as np
import arviz as az

# IPython extension that auto-formats code cells with black in JupyterLab.
%load_ext lab_black

4. Rasch*

Adapted from Unit 10: rasch.odc.

Data can be found here (this notebook loads it from `../data/rasch.txt`).

  • True/false questions

  • 1 if answered correctly, 0 otherwise

  • n students

  • k questions

  • Assess the (relative) ability of students

  • Assess the (relative) difficulty of questions

  • Originally motivated by educational testing, but applicable in other contexts

Notes:

  • The model works well and matches the BUGS results using simple broadcasting. What remains is finding a better way to display the results.

# Load the 0/1 response matrix and record its dimensions:
# n = number of students (rows), k = number of questions (columns).
y = np.loadtxt("../data/rasch.txt")
n, k = np.shape(y)
n, k
(162, 33)
# Rasch model: P(student i answers question j correctly) = logistic(alpha_i - delta_j),
# where alpha is student ability and delta is question difficulty.
with pm.Model() as m:
    # Diffuse Gamma(0.01, 0.01) priors on the precisions of ability and difficulty.
    tau_alpha = pm.Gamma("tau_alpha", 0.01, 0.01)
    var_alpha = pm.Deterministic("var_alpha", 1 / tau_alpha)
    tau_delta = pm.Gamma("tau_delta", 0.01, 0.01)
    # The BUGS source (rasch.odc) has a typo in the prior for mu; it is not
    # reproduced here.
    mu_delta = pm.Normal("mu_delta", 0, tau=0.001)

    # Shapes are taken from the data (n students, k questions) rather than
    # hard-coded. The singleton axes make delta (1, k) and alpha (n, 1), so
    # alpha - delta broadcasts to the (n, k) shape of the response matrix y.
    delta = pm.Normal("delta", mu_delta, tau=tau_delta, shape=(1, k))
    alpha = pm.Normal("alpha", 0, tau=tau_alpha, shape=(n, 1))

    # Linear predictor on the logit scale.
    p = alpha - delta

    pm.Bernoulli("likelihood", logit_p=p, observed=y)

    trace = pm.sample(3000)
Hide code cell output
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [tau_alpha, tau_delta, mu_delta, delta, alpha]
100.00% [16000/16000 00:14<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 3_000 draw iterations (4_000 + 12_000 draws total) took 15 seconds.
# Posterior summary of the question difficulties delta (one entry per question).
az.summary(data=trace, var_names=["delta"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
delta[0, 0] 0.951 0.231 0.509 1.375 0.005 0.003 2540.0 6221.0 1.0
delta[0, 1] 0.574 0.222 0.171 1.010 0.005 0.003 2302.0 4440.0 1.0
delta[0, 2] 0.255 0.219 -0.163 0.659 0.004 0.003 2399.0 5453.0 1.0
delta[0, 3] 1.451 0.245 0.982 1.901 0.005 0.003 2841.0 6140.0 1.0
delta[0, 4] -1.018 0.225 -1.428 -0.584 0.005 0.003 2242.0 4786.0 1.0
delta[0, 5] 0.684 0.229 0.246 1.103 0.005 0.003 2355.0 5108.0 1.0
delta[0, 6] 0.953 0.231 0.511 1.378 0.005 0.003 2375.0 4748.0 1.0
delta[0, 7] 0.759 0.227 0.338 1.189 0.004 0.003 2814.0 5476.0 1.0
delta[0, 8] 0.649 0.227 0.219 1.070 0.005 0.003 2536.0 5835.0 1.0
delta[0, 9] 1.682 0.251 1.217 2.158 0.005 0.003 2722.0 5779.0 1.0
delta[0, 10] 1.152 0.235 0.713 1.595 0.004 0.003 2770.0 5242.0 1.0
delta[0, 11] -1.201 0.227 -1.627 -0.779 0.005 0.003 2415.0 5362.0 1.0
delta[0, 12] -0.187 0.223 -0.620 0.225 0.005 0.003 2194.0 5197.0 1.0
delta[0, 13] -0.088 0.215 -0.500 0.301 0.005 0.003 1996.0 4825.0 1.0
delta[0, 14] 0.873 0.229 0.445 1.305 0.005 0.003 2496.0 5520.0 1.0
delta[0, 15] 2.271 0.274 1.743 2.776 0.005 0.003 3313.0 6834.0 1.0
delta[0, 16] 1.543 0.243 1.084 2.002 0.004 0.003 2953.0 6441.0 1.0
delta[0, 17] -0.291 0.220 -0.724 0.091 0.005 0.003 2148.0 5395.0 1.0
delta[0, 18] 0.049 0.218 -0.360 0.461 0.004 0.003 2413.0 5464.0 1.0
delta[0, 19] -1.315 0.233 -1.742 -0.873 0.005 0.003 2466.0 5697.0 1.0
delta[0, 20] 2.805 0.300 2.209 3.342 0.005 0.004 3432.0 6759.0 1.0
delta[0, 21] -0.458 0.217 -0.854 -0.043 0.005 0.003 2296.0 4904.0 1.0
delta[0, 22] 0.573 0.224 0.134 0.984 0.004 0.003 2548.0 5782.0 1.0
delta[0, 23] 0.359 0.218 -0.053 0.769 0.004 0.003 2473.0 5795.0 1.0
delta[0, 24] -1.767 0.243 -2.206 -1.291 0.005 0.003 2708.0 6129.0 1.0
delta[0, 25] -1.474 0.232 -1.906 -1.038 0.005 0.003 2537.0 6302.0 1.0
delta[0, 26] 0.872 0.226 0.447 1.293 0.004 0.003 2635.0 5833.0 1.0
delta[0, 27] 0.084 0.219 -0.329 0.494 0.005 0.003 2220.0 5476.0 1.0
delta[0, 28] 1.319 0.239 0.881 1.774 0.005 0.003 2536.0 6135.0 1.0
delta[0, 29] -0.394 0.219 -0.804 0.013 0.005 0.003 2209.0 4431.0 1.0
delta[0, 30] 0.398 0.224 -0.029 0.810 0.005 0.003 2232.0 5299.0 1.0
delta[0, 31] -0.874 0.222 -1.279 -0.455 0.005 0.003 2324.0 5660.0 1.0
delta[0, 32] 0.723 0.227 0.308 1.159 0.005 0.003 2448.0 5629.0 1.0
# Record the run date and the Python, IPython, imported-package and pytensor
# versions so the results above can be reproduced.
%load_ext watermark
%watermark -n -u -v -iv -p pytensor
Last updated: Wed Mar 22 2023

Python implementation: CPython
Python version       : 3.11.0
IPython version      : 8.9.0

pytensor: 2.10.1

arviz: 0.15.1
numpy: 1.24.2
pymc : 5.1.2