{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3d127ad1", "metadata": {}, "outputs": [], "source": [ "import pymc as pm\n", "import numpy as np\n", "import arviz as az\n", "\n", "%load_ext lab_black" ] }, { "cell_type": "markdown", "id": "122c2e96", "metadata": {}, "source": [ "# Rasch*\n", "\n", "Adapted from [Unit 10: rasch.odc](https://raw.githubusercontent.com/areding/6420-pymc/main/original_examples/Codes4Unit10/rasch.odc).\n", "\n", "Data can be found [here](https://raw.githubusercontent.com/areding/6420-pymc/main/data/rasch.txt).\n" ] }, { "cell_type": "markdown", "id": "7dbd4c14", "metadata": {}, "source": [ "* True/False Questions \n", "* 1 if answered correctly, 0 otherwise\n", "* n students\n", "* k questions\n", "* Assess (relative) ability of students\n", "* Assess (relative) difficulty of questions\n", "* Originally motivated by testing/education, applicable in different contexts\n", "\n", "\n", "## Notes: \n", "\n", "- Model works well and matches BUGS results with simple broadcasting. Just need to figure out a better way to display the results." ] }, { "cell_type": "code", "execution_count": 2, "id": "ea318d60", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(162, 33)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Response matrix: one row per student, one column per question (1 = correct).\n", "y = np.loadtxt(fname=\"../data/rasch.txt\")\n", "n, k = np.shape(y)\n", "(n, k)" ] }, { "cell_type": "code", "execution_count": 3, "id": "49bca796", "metadata": { "tags": [ "hide-output" ] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [tau_alpha, tau_delta, mu_delta, delta, alpha]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [16000/16000 00:14<00:00 Sampling 4 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 3_000 draw iterations (4_000 + 12_000 draws total) took 15 seconds.\n" ] } ], "source": [ "with pm.Model() as m:\n", "    # Gamma(0.01, 0.01) priors on the precisions of abilities and difficulties\n", "    tau_alpha = pm.Gamma(\"tau_alpha\", 0.01, 0.01)\n", "    var_alpha = pm.Deterministic(\"var_alpha\", 1 / tau_alpha)\n", "    tau_delta = pm.Gamma(\"tau_delta\", 0.01, 0.01)\n", "    # there's a typo for mu in BUGS version\n", "    mu_delta = pm.Normal(\"mu_delta\", 0, tau=0.001)\n", "\n", "    # sizes come from the data (n, k = y.shape) instead of being hard-coded;\n", "    # the 1s in the shapes let alpha (n, 1) and delta (1, k) broadcast to (n, k)\n", "    delta = pm.Normal(\"delta\", mu_delta, tau=tau_delta, shape=(1, k))\n", "    alpha = pm.Normal(\"alpha\", 0, tau=tau_alpha, shape=(n, 1))\n", "\n", "    # p is on the logit scale: student ability minus question difficulty\n", "    p = alpha - delta\n", "\n", "    pm.Bernoulli(\"likelihood\", logit_p=p, observed=y)\n", "\n", "    # fixed seed so Restart & Run All reproduces the same draws\n", "    trace = pm.sample(3000, random_seed=42)" ] }, { "cell_type": "code", "execution_count": 4, "id": "3119daff", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
delta[0, 0]0.9510.2310.5091.3750.0050.0032540.06221.01.0
delta[0, 1]0.5740.2220.1711.0100.0050.0032302.04440.01.0
delta[0, 2]0.2550.219-0.1630.6590.0040.0032399.05453.01.0
delta[0, 3]1.4510.2450.9821.9010.0050.0032841.06140.01.0
delta[0, 4]-1.0180.225-1.428-0.5840.0050.0032242.04786.01.0
delta[0, 5]0.6840.2290.2461.1030.0050.0032355.05108.01.0
delta[0, 6]0.9530.2310.5111.3780.0050.0032375.04748.01.0
delta[0, 7]0.7590.2270.3381.1890.0040.0032814.05476.01.0
delta[0, 8]0.6490.2270.2191.0700.0050.0032536.05835.01.0
delta[0, 9]1.6820.2511.2172.1580.0050.0032722.05779.01.0
delta[0, 10]1.1520.2350.7131.5950.0040.0032770.05242.01.0
delta[0, 11]-1.2010.227-1.627-0.7790.0050.0032415.05362.01.0
delta[0, 12]-0.1870.223-0.6200.2250.0050.0032194.05197.01.0
delta[0, 13]-0.0880.215-0.5000.3010.0050.0031996.04825.01.0
delta[0, 14]0.8730.2290.4451.3050.0050.0032496.05520.01.0
delta[0, 15]2.2710.2741.7432.7760.0050.0033313.06834.01.0
delta[0, 16]1.5430.2431.0842.0020.0040.0032953.06441.01.0
delta[0, 17]-0.2910.220-0.7240.0910.0050.0032148.05395.01.0
delta[0, 18]0.0490.218-0.3600.4610.0040.0032413.05464.01.0
delta[0, 19]-1.3150.233-1.742-0.8730.0050.0032466.05697.01.0
delta[0, 20]2.8050.3002.2093.3420.0050.0043432.06759.01.0
delta[0, 21]-0.4580.217-0.854-0.0430.0050.0032296.04904.01.0
delta[0, 22]0.5730.2240.1340.9840.0040.0032548.05782.01.0
delta[0, 23]0.3590.218-0.0530.7690.0040.0032473.05795.01.0
delta[0, 24]-1.7670.243-2.206-1.2910.0050.0032708.06129.01.0
delta[0, 25]-1.4740.232-1.906-1.0380.0050.0032537.06302.01.0
delta[0, 26]0.8720.2260.4471.2930.0040.0032635.05833.01.0
delta[0, 27]0.0840.219-0.3290.4940.0050.0032220.05476.01.0
delta[0, 28]1.3190.2390.8811.7740.0050.0032536.06135.01.0
delta[0, 29]-0.3940.219-0.8040.0130.0050.0032209.04431.01.0
delta[0, 30]0.3980.224-0.0290.8100.0050.0032232.05299.01.0
delta[0, 31]-0.8740.222-1.279-0.4550.0050.0032324.05660.01.0
delta[0, 32]0.7230.2270.3081.1590.0050.0032448.05629.01.0
\n", "
" ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", "delta[0, 0] 0.951 0.231 0.509 1.375 0.005 0.003 2540.0 \n", "delta[0, 1] 0.574 0.222 0.171 1.010 0.005 0.003 2302.0 \n", "delta[0, 2] 0.255 0.219 -0.163 0.659 0.004 0.003 2399.0 \n", "delta[0, 3] 1.451 0.245 0.982 1.901 0.005 0.003 2841.0 \n", "delta[0, 4] -1.018 0.225 -1.428 -0.584 0.005 0.003 2242.0 \n", "delta[0, 5] 0.684 0.229 0.246 1.103 0.005 0.003 2355.0 \n", "delta[0, 6] 0.953 0.231 0.511 1.378 0.005 0.003 2375.0 \n", "delta[0, 7] 0.759 0.227 0.338 1.189 0.004 0.003 2814.0 \n", "delta[0, 8] 0.649 0.227 0.219 1.070 0.005 0.003 2536.0 \n", "delta[0, 9] 1.682 0.251 1.217 2.158 0.005 0.003 2722.0 \n", "delta[0, 10] 1.152 0.235 0.713 1.595 0.004 0.003 2770.0 \n", "delta[0, 11] -1.201 0.227 -1.627 -0.779 0.005 0.003 2415.0 \n", "delta[0, 12] -0.187 0.223 -0.620 0.225 0.005 0.003 2194.0 \n", "delta[0, 13] -0.088 0.215 -0.500 0.301 0.005 0.003 1996.0 \n", "delta[0, 14] 0.873 0.229 0.445 1.305 0.005 0.003 2496.0 \n", "delta[0, 15] 2.271 0.274 1.743 2.776 0.005 0.003 3313.0 \n", "delta[0, 16] 1.543 0.243 1.084 2.002 0.004 0.003 2953.0 \n", "delta[0, 17] -0.291 0.220 -0.724 0.091 0.005 0.003 2148.0 \n", "delta[0, 18] 0.049 0.218 -0.360 0.461 0.004 0.003 2413.0 \n", "delta[0, 19] -1.315 0.233 -1.742 -0.873 0.005 0.003 2466.0 \n", "delta[0, 20] 2.805 0.300 2.209 3.342 0.005 0.004 3432.0 \n", "delta[0, 21] -0.458 0.217 -0.854 -0.043 0.005 0.003 2296.0 \n", "delta[0, 22] 0.573 0.224 0.134 0.984 0.004 0.003 2548.0 \n", "delta[0, 23] 0.359 0.218 -0.053 0.769 0.004 0.003 2473.0 \n", "delta[0, 24] -1.767 0.243 -2.206 -1.291 0.005 0.003 2708.0 \n", "delta[0, 25] -1.474 0.232 -1.906 -1.038 0.005 0.003 2537.0 \n", "delta[0, 26] 0.872 0.226 0.447 1.293 0.004 0.003 2635.0 \n", "delta[0, 27] 0.084 0.219 -0.329 0.494 0.005 0.003 2220.0 \n", "delta[0, 28] 1.319 0.239 0.881 1.774 0.005 0.003 2536.0 \n", "delta[0, 29] -0.394 0.219 -0.804 0.013 0.005 0.003 2209.0 \n", "delta[0, 30] 0.398 0.224 -0.029 0.810 
0.005 0.003 2232.0 \n", "delta[0, 31] -0.874 0.222 -1.279 -0.455 0.005 0.003 2324.0 \n", "delta[0, 32] 0.723 0.227 0.308 1.159 0.005 0.003 2448.0 \n", "\n", " ess_tail r_hat \n", "delta[0, 0] 6221.0 1.0 \n", "delta[0, 1] 4440.0 1.0 \n", "delta[0, 2] 5453.0 1.0 \n", "delta[0, 3] 6140.0 1.0 \n", "delta[0, 4] 4786.0 1.0 \n", "delta[0, 5] 5108.0 1.0 \n", "delta[0, 6] 4748.0 1.0 \n", "delta[0, 7] 5476.0 1.0 \n", "delta[0, 8] 5835.0 1.0 \n", "delta[0, 9] 5779.0 1.0 \n", "delta[0, 10] 5242.0 1.0 \n", "delta[0, 11] 5362.0 1.0 \n", "delta[0, 12] 5197.0 1.0 \n", "delta[0, 13] 4825.0 1.0 \n", "delta[0, 14] 5520.0 1.0 \n", "delta[0, 15] 6834.0 1.0 \n", "delta[0, 16] 6441.0 1.0 \n", "delta[0, 17] 5395.0 1.0 \n", "delta[0, 18] 5464.0 1.0 \n", "delta[0, 19] 5697.0 1.0 \n", "delta[0, 20] 6759.0 1.0 \n", "delta[0, 21] 4904.0 1.0 \n", "delta[0, 22] 5782.0 1.0 \n", "delta[0, 23] 5795.0 1.0 \n", "delta[0, 24] 6129.0 1.0 \n", "delta[0, 25] 6302.0 1.0 \n", "delta[0, 26] 5833.0 1.0 \n", "delta[0, 27] 5476.0 1.0 \n", "delta[0, 28] 6135.0 1.0 \n", "delta[0, 29] 4431.0 1.0 \n", "delta[0, 30] 5299.0 1.0 \n", "delta[0, 31] 5660.0 1.0 \n", "delta[0, 32] 5629.0 1.0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "az.summary(trace, var_names=[\"delta\"])" ] }, { "cell_type": "code", "execution_count": 5, "id": "cacc39d4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Last updated: Wed Mar 22 2023\n", "\n", "Python implementation: CPython\n", "Python version : 3.11.0\n", "IPython version : 8.9.0\n", "\n", "pytensor: 2.10.1\n", "\n", "arviz: 0.15.1\n", "numpy: 1.24.2\n", "pymc : 5.1.2\n", "\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -n -u -v -iv -p pytensor" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }