{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "6a0dc113",
"metadata": {},
"outputs": [],
"source": [
"import pymc as pm\n",
"import numpy as np\n",
"import arviz as az\n",
"import pandas as pd\n",
"import pytensor.tensor.subtensor as st\n",
"from itertools import combinations\n",
"\n",
"%load_ext lab_black"
]
},
{
"cell_type": "markdown",
"id": "d8a18872",
"metadata": {},
"source": [
"# 9. Simvastatin*\n",
"\n",
"This example fits a two-way factorial design (two-way ANOVA with interaction) under two identifiability schemes: sum-to-zero constraints and corner constraints.\n",
"\n",
"Adapted from [Unit 7: simvastatin.odc](https://raw.githubusercontent.com/areding/6420-pymc/main/original_examples/Codes4Unit7/simvastatin.odc).\n",
"\n",
"Data can be found [here](https://raw.githubusercontent.com/areding/6420-pymc/main/data/simvastatin_data.tsv).\n",
"\n",
"Thanks to [Anthony Miyaguchi](https://github.com/acmiyaguchi) for updating this example!"
]
},
{
"cell_type": "markdown",
"id": "dc5f944e",
"metadata": {},
"source": [
"In Quantitative Physiology Lab II at Georgia Tech, students were asked to find a therapeutic agent to test on the MC3T3-E1 cell line to enhance osteoblastic growth. They chose Simvastatin, a cholesterol-lowering drug. Cells were treated with a control and three concentrations of the drug: $10^{-9}$, $10^{-8}$, and $10^{-7}$ M. The cells were plated on four 24-well plates, one plate per treatment. Osteoblastic differentiation was measured with the pNPP assay, which quantifies alkaline phosphatase activity (`apa`): the higher the activity, the further the cells have differentiated toward a bone-like phenotype. The assay was performed 6 times over 11 days, and each time four wells from each plate were used, giving $4 \\times 6 \\times 4 = 96$ observations.\n"
]
},
{
"cell_type": "markdown",
"id": "8fbbed4a",
"metadata": {},
"source": [
"## Notes\n",
"\n",
"A [good explanation](https://stats.stackexchange.com/questions/257778/sum-to-zero-constraint-in-one-way-anova) of sum-to-zero (STZ) constraints in one-way ANOVA.\n"
]
},
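{
"cell_type": "markdown",
"id": "stz-corner-demo-md",
"metadata": {},
"source": [
"The cell below is a minimal numerical sketch of the two constraint systems. It uses a made-up $2 \\times 3$ table of cell means; these are hypothetical numbers, not the simvastatin data. Under sum-to-zero constraints the effects are deviations from the grand and marginal means. Under corner constraints they are contrasts against the first row and first column. Both decompositions rebuild exactly the same cell means, so they are two parameterizations of the same fit. The main effects they report, however, are different."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stz-corner-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical 2 x 3 table of cell means (rows: levels of factor A, columns: levels of factor B)\n",
"m = np.array([[1.0, 2.0, 6.0], [3.0, 5.0, 4.0]])\n",
"\n",
"# sum-to-zero decomposition: deviations from the grand and marginal means\n",
"mu_stz = m.mean()\n",
"alpha_stz = m.mean(axis=1) - mu_stz\n",
"beta_stz = m.mean(axis=0) - mu_stz\n",
"ab_stz = m - mu_stz - alpha_stz[:, None] - beta_stz[None, :]\n",
"\n",
"# corner decomposition: contrasts against the first level of each factor\n",
"mu_c = m[0, 0]\n",
"alpha_c = m[:, 0] - mu_c\n",
"beta_c = m[0, :] - mu_c\n",
"ab_c = m - mu_c - alpha_c[:, None] - beta_c[None, :]\n",
"\n",
"# both parameterizations rebuild the same cell means\n",
"recon_stz = mu_stz + alpha_stz[:, None] + beta_stz[None, :] + ab_stz\n",
"recon_c = mu_c + alpha_c[:, None] + beta_c[None, :] + ab_c\n",
"print(np.allclose(recon_stz, m), np.allclose(recon_c, m))\n",
"\n",
"# STZ interaction rows and columns sum to zero; corner interactions vanish in row 0 and column 0\n",
"print(np.allclose(ab_stz.sum(axis=0), 0), np.allclose(ab_stz.sum(axis=1), 0))\n",
"print(np.allclose(ab_c[0, :], 0), np.allclose(ab_c[:, 0], 0))\n",
"\n",
"# the main effects themselves differ between the two schemes\n",
"print(alpha_stz, alpha_c)"
]
},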
{
"cell_type": "code",
"execution_count": 12,
"id": "26142563",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"     apa  conc  time\n",
"0  0.062     1     1\n",
"1  0.517     1     1\n",
"2  0.261     1     1"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv(\"../data/simvastatin_data.tsv\", sep=\"\\t\")\n",
"data.head(3)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a091a38f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 1,\n",
" 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,\n",
" 3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0,\n",
" 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1,\n",
" 2, 2, 2, 2, 3, 3, 3, 3]),\n",
" array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4,\n",
" 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,\n",
" 5, 5, 5, 5, 5, 5, 5, 5]),\n",
" {'conc': Int64Index([1, 2, 3, 4], dtype='int64'),\n",
" 'time': Int64Index([1, 2, 3, 4, 5, 6], dtype='int64'),\n",
" 'id': RangeIndex(start=0, stop=96, step=1)})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# map each factor level to a 0-based integer index and use the level labels as coords\n",
"conc_idx, conc = pd.factorize(data[\"conc\"])\n",
"time_idx, time = pd.factorize(data[\"time\"])\n",
"coords = {\"conc\": conc, \"time\": time, \"id\": data.index}\n",
"\n",
"conc_idx, time_idx, coords"
]
},
{
"cell_type": "markdown",
"id": "de2cd6c1",
"metadata": {},
"source": [
"## Model 1 with sum-to-zero constraints"
]
},
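{
"cell_type": "markdown",
"id": "stz-model-equations",
"metadata": {},
"source": [
"Let $y_{ijk}$ be the alkaline phosphatase activity in well $k$ ($k = 1, \\dots, 4$) at concentration level $i$ ($i = 1, \\dots, 4$) and assay time $j$ ($j = 1, \\dots, 6$). The model is\n",
"\n",
"$$y_{ijk} = \\mu_0 + \\alpha_i + \\beta_j + (\\alpha\\beta)_{ij} + \\varepsilon_{ijk}, \\qquad \\varepsilon_{ijk} \\sim N(0, 1/\\tau),$$\n",
"\n",
"where $1/\\tau$ is the error variance, so $\\sigma = 1/\\sqrt{\\tau}$. The sum-to-zero constraints are\n",
"\n",
"$$\\sum_{i=1}^{4} \\alpha_i = 0, \\qquad \\sum_{j=1}^{6} \\beta_j = 0, \\qquad \\sum_{i=1}^{4} (\\alpha\\beta)_{ij} = 0 \\ \\text{for each } j, \\qquad \\sum_{j=1}^{6} (\\alpha\\beta)_{ij} = 0 \\ \\text{for each } i.$$\n",
"\n",
"Under these constraints, $\\mu_0$ is the grand mean of the cell means and each effect is a deviation from it. The vague priors use precision $0.0001$, which corresponds to a standard deviation of $1/\\sqrt{0.0001} = 100$."
]
},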
{
"cell_type": "code",
"execution_count": 4,
"id": "891c822f",
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [mu0, _alpha, _beta, _alphabeta, tau]\n"
]
},
{
"data": {
"text/plain": [
"100.00% [12000/12000 00:07<00:00 Sampling 4 chains, 0 divergences]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 8 seconds.\n"
]
}
],
"source": [
"def differences(var, index):\n",
"    \"\"\"Register pairwise differences between levels as Deterministics.\n",
"\n",
"    Each difference is named like \"alpha[1] - alpha[2]\".\n",
"\n",
"    var: pytensor.tensor.TensorVariable with one entry per level\n",
"    index: pandas.Index of level labels, in the same order as var\n",
"    \"\"\"\n",
"    name = var.name\n",
"    for i, j in combinations(range(index.size), 2):\n",
"        a, b = index[i], index[j]\n",
"        pm.Deterministic(f\"{name}[{a}] - {name}[{b}]\", var[i] - var[j])\n",
"\n",
"\n",
"with pm.Model(coords=coords) as m:\n",
"    apa_data = pm.Data(\"apa_data\", data.apa, mutable=False)\n",
"    time_idx_data = pm.Data(\"time_idx_data\", time_idx, dims=\"id\", mutable=False)\n",
"    conc_idx_data = pm.Data(\"conc_idx_data\", conc_idx, dims=\"id\", mutable=False)\n",
"\n",
"    # vague priors: precision tau=0.0001 is a standard deviation of 100\n",
"    mu0 = pm.Normal(\"mu0\", 0, tau=0.0001)\n",
"    _alpha = pm.Normal(\"_alpha\", 0, tau=0.0001, dims=\"conc\")\n",
"    _beta = pm.Normal(\"_beta\", 0, tau=0.0001, dims=\"time\")\n",
"    _alphabeta = pm.Normal(\"_alphabeta\", 0, tau=0.0001, dims=(\"conc\", \"time\"))\n",
"    tau = pm.Gamma(\"tau\", 0.001, 0.001)\n",
"    sigma = pm.Deterministic(\"sigma\", 1 / tau**0.5)\n",
"\n",
"    # sum-to-zero constraints\n",
"    # sets the first element of a dimension to the negative sum of the rest\n",
"    sst_1d_0 = lambda var: st.set_subtensor(var[0], -var[1:].sum(axis=0))\n",
"    sst_2d_0 = lambda var: st.set_subtensor(var[0, :], -var[1:, :].sum(axis=0))\n",
"    sst_2d_1 = lambda var: st.set_subtensor(var[:, 0], -var[:, 1:].sum(axis=1))\n",
"\n",
"    alpha = pm.Deterministic(\"alpha\", sst_1d_0(_alpha), dims=\"conc\")\n",
"    beta = pm.Deterministic(\"beta\", sst_1d_0(_beta), dims=\"time\")\n",
"    # make every row sum to zero, then every column; the first row still sums\n",
"    # to zero afterwards because all the other rows already do\n",
"    _alphabeta = sst_2d_1(_alphabeta)\n",
"    alphabeta = pm.Deterministic(\n",
"        \"alphabeta\", sst_2d_0(_alphabeta), dims=(\"conc\", \"time\")\n",
"    )\n",
"\n",
"    mu = (\n",
"        mu0\n",
"        + alpha[conc_idx_data]\n",
"        + beta[time_idx_data]\n",
"        + alphabeta[conc_idx_data, time_idx_data]\n",
"    )\n",
"    pm.Normal(\"apa\", mu, tau=tau, observed=apa_data, dims=\"id\")\n",
"\n",
"    # calculate differences between levels with appropriate names\n",
"    differences(alpha, coords[\"conc\"])\n",
"    differences(beta, coords[\"time\"])\n",
"\n",
"    trace = pm.sample(2000)"
]
},
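{
"cell_type": "markdown",
"id": "stz-2d-check-md",
"metadata": {},
"source": [
"The next cell is a quick NumPy check of the two-step construction used for `alphabeta` in Model 1. A random $4 \\times 6$ matrix stands in for the unconstrained `_alphabeta` draws. The cell mirrors the `set_subtensor` calls but is not the model code itself. First, column 0 is overwritten with minus the sum of the other columns. Then row 0 is overwritten with minus the sum of the other rows. Afterwards every row and every column sums to zero, and only the $(4-1)(6-1) = 15$ entries outside the first row and column survive. This means that in the PyMC model, the first row and column of `_alphabeta`, like `_alpha[0]` and `_beta[0]`, do not enter the likelihood. Their values are informed by the priors alone."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stz-2d-check-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"raw = rng.normal(size=(4, 6))  # stands in for the unconstrained _alphabeta draws\n",
"\n",
"ab = raw.copy()\n",
"# step 1, like sst_2d_1: column 0 = -(sum of columns 1..5), so each row sums to zero\n",
"ab[:, 0] = -ab[:, 1:].sum(axis=1)\n",
"# step 2, like sst_2d_0: row 0 = -(sum of rows 1..3), so each column sums to zero\n",
"ab[0, :] = -ab[1:, :].sum(axis=0)\n",
"\n",
"print(np.allclose(ab.sum(axis=1), 0), np.allclose(ab.sum(axis=0), 0))\n",
"# only the 3 x 5 block outside the first row and column is left untouched\n",
"print(np.array_equal(ab[1:, 1:], raw[1:, 1:]))"
]
},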
{
"cell_type": "code",
"execution_count": 5,
"id": "ef27e6a9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"                       mean     sd  hdi_3%  hdi_97%\n",
"mu0                   0.239  0.025   0.192    0.286\n",
"tau                  17.233  2.874  12.020   22.775\n",
"sigma                 0.243  0.021   0.206    0.283\n",
"alpha[1]              0.049  0.043  -0.033    0.130\n",
"alpha[2]              0.068  0.044  -0.014    0.149\n",
"alpha[3]             -0.075  0.043  -0.153    0.008\n",
"alpha[4]             -0.042  0.042  -0.121    0.037\n",
"beta[1]               0.046  0.055  -0.058    0.149\n",
"beta[2]              -0.150  0.055  -0.258   -0.047\n",
"beta[3]              -0.019  0.055  -0.118    0.090\n",
"beta[4]               0.230  0.055   0.123    0.329\n",
"beta[5]              -0.022  0.056  -0.129    0.081\n",
"beta[6]              -0.086  0.057  -0.190    0.023\n",
"alphabeta[1, 1]      -0.085  0.097  -0.267    0.099\n",
"alphabeta[1, 2]      -0.083  0.096  -0.265    0.095\n",
"alphabeta[1, 3]      -0.210  0.098  -0.393   -0.028\n",
"alphabeta[1, 4]       0.501  0.095   0.316    0.673\n",
"alphabeta[1, 5]      -0.019  0.097  -0.207    0.160\n",
"alphabeta[1, 6]      -0.104  0.096  -0.276    0.081\n",
"alphabeta[2, 1]       0.150  0.097  -0.027    0.338\n",
"alphabeta[2, 2]      -0.100  0.097  -0.283    0.081\n",
"alphabeta[2, 3]       0.354  0.097   0.173    0.535\n",
"alphabeta[2, 4]      -0.176  0.095  -0.362   -0.003\n",
"alphabeta[2, 5]      -0.187  0.094  -0.358   -0.006\n",
"alphabeta[2, 6]      -0.042  0.098  -0.225    0.141\n",
"alphabeta[3, 1]      -0.043  0.098  -0.226    0.140\n",
"alphabeta[3, 2]       0.056  0.097  -0.119    0.241\n",
"alphabeta[3, 3]      -0.046  0.097  -0.228    0.133\n",
"alphabeta[3, 4]      -0.138  0.096  -0.310    0.051\n",
"alphabeta[3, 5]       0.158  0.094  -0.020    0.335\n",
"alphabeta[3, 6]       0.013  0.098  -0.171    0.198\n",
"alphabeta[4, 1]      -0.022  0.095  -0.196    0.158\n",
"alphabeta[4, 2]       0.126  0.094  -0.056    0.301\n",
"alphabeta[4, 3]      -0.098  0.095  -0.283    0.076\n",
"alphabeta[4, 4]      -0.187  0.097  -0.353    0.012\n",
"alphabeta[4, 5]       0.048  0.095  -0.132    0.223\n",
"alphabeta[4, 6]       0.133  0.095  -0.047    0.311\n",
"alpha[1] - alpha[2]  -0.019  0.071  -0.156    0.112\n",
"alpha[1] - alpha[3]   0.125  0.069  -0.004    0.258\n",
"alpha[1] - alpha[4]   0.092  0.069  -0.034    0.226\n",
"alpha[2] - alpha[3]   0.144  0.071   0.010    0.275\n",
"alpha[2] - alpha[4]   0.111  0.070  -0.021    0.241\n",
"alpha[3] - alpha[4]  -0.033  0.069  -0.164    0.098\n",
"beta[1] - beta[2]     0.195  0.086   0.041    0.365\n",
"beta[1] - beta[3]     0.064  0.086  -0.100    0.224\n",
"beta[1] - beta[4]    -0.185  0.086  -0.343   -0.019\n",
"beta[1] - beta[5]     0.068  0.086  -0.093    0.229\n",
"beta[1] - beta[6]     0.131  0.087  -0.028    0.297\n",
"beta[2] - beta[3]    -0.131  0.085  -0.292    0.027\n",
"beta[2] - beta[4]    -0.380  0.085  -0.542   -0.225\n",
"beta[2] - beta[5]    -0.128  0.088  -0.293    0.039\n",
"beta[2] - beta[6]    -0.064  0.087  -0.227    0.099\n",
"beta[3] - beta[4]    -0.249  0.086  -0.404   -0.088\n",
"beta[3] - beta[5]     0.003  0.086  -0.160    0.158\n",
"beta[3] - beta[6]     0.067  0.088  -0.093    0.235\n",
"beta[4] - beta[5]     0.252  0.086   0.081    0.406\n",
"beta[4] - beta[6]     0.316  0.087   0.149    0.475\n",
"beta[5] - beta[6]     0.064  0.087  -0.099    0.228"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
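"# \"~_\" with filter_vars=\"like\" hides every variable whose name contains \"_\",\n",
"# i.e. the unconstrained _alpha, _beta and _alphabeta\n",
"az.summary(trace, var_names=\"~_\", filter_vars=\"like\", kind=\"stats\")"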
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f9ad9ec3",
"metadata": {},
"outputs": [
],
"source": [
"az.plot_forest(trace, var_names=[\"alpha\"], combined=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "73781c6c",
"metadata": {},
"outputs": [
],
"source": [
"az.plot_forest(trace, var_names=[\"beta\"], combined=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b9ba2d64",
"metadata": {},
"outputs": [
],
"source": [
"az.plot_forest(trace, var_names=[\"alphabeta\"], combined=True)"
]
},
{
"cell_type": "markdown",
"id": "32850285",
"metadata": {},
"source": [
"## Model 2 with corner constraints"
]
},
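{
"cell_type": "markdown",
"id": "corner-model-equations",
"metadata": {},
"source": [
"The mean structure is the same as in Model 1. The difference is that the constraints now make the first level of each factor the reference:\n",
"\n",
"$$\\alpha_1 = 0, \\qquad \\beta_1 = 0, \\qquad (\\alpha\\beta)_{1j} = 0 \\ \\text{for all } j, \\qquad (\\alpha\\beta)_{i1} = 0 \\ \\text{for all } i.$$\n",
"\n",
"Under these constraints, $\\mu_0$ is the expected response in cell $(1, 1)$. A difference such as $\\alpha_i - \\alpha_{i'}$ compares two concentrations at the first assay time only, whereas under sum-to-zero constraints it compares them averaged over all six times. The effect estimates of the two models are therefore not directly comparable, while the residual scale `sigma` should come out essentially the same."
]
},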
{
"cell_type": "code",
"execution_count": 9,
"id": "80dc9379",
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [mu0, _alpha, _beta, _alphabeta, tau]\n"
]
},
{
"data": {
"text/plain": [
"100.00% [12000/12000 00:10<00:00 Sampling 4 chains, 0 divergences]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 10 seconds.\n"
]
}
],
"source": [
"with pm.Model(coords=coords) as m:\n",
"    apa_data = pm.Data(\"apa_data\", data.apa, mutable=False)\n",
"    time_idx_data = pm.Data(\"time_idx_data\", time_idx, dims=\"id\", mutable=False)\n",
"    conc_idx_data = pm.Data(\"conc_idx_data\", conc_idx, dims=\"id\", mutable=False)\n",
"\n",
"    mu0 = pm.Normal(\"mu0\", 0, tau=0.0001)\n",
"    _alpha = pm.Normal(\"_alpha\", 0, tau=0.0001, dims=\"conc\")\n",
"    _beta = pm.Normal(\"_beta\", 0, tau=0.0001, dims=\"time\")\n",
"    _alphabeta = pm.Normal(\"_alphabeta\", 0, tau=0.0001, dims=(\"conc\", \"time\"))\n",
"    tau = pm.Gamma(\"tau\", 0.001, 0.001)\n",
"    sigma = pm.Deterministic(\"sigma\", 1 / tau**0.5)\n",
"\n",
"    # corner constraints: sets the first element of a dimension to zero\n",
"    alpha = pm.Deterministic(\"alpha\", st.set_subtensor(_alpha[0], 0), dims=\"conc\")\n",
"    beta = pm.Deterministic(\"beta\", st.set_subtensor(_beta[0], 0), dims=\"time\")\n",
"    _alphabeta = st.set_subtensor(_alphabeta[:, 0], 0)\n",
"    alphabeta = pm.Deterministic(\n",
"        \"alphabeta\", st.set_subtensor(_alphabeta[0, :], 0), dims=(\"conc\", \"time\")\n",
"    )\n",
"\n",
"    mu = (\n",
"        mu0\n",
"        + alpha[conc_idx_data]\n",
"        + beta[time_idx_data]\n",
"        + alphabeta[conc_idx_data, time_idx_data]\n",
"    )\n",
"    pm.Normal(\"apa\", mu, tau=tau, observed=apa_data, dims=\"id\")\n",
"\n",
"    differences(alpha, coords[\"conc\"])\n",
"    differences(beta, coords[\"time\"])\n",
"\n",
"    trace = pm.sample(2000)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c77d43a5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"                       mean     sd  hdi_3%  hdi_97%\n",
"mu0                   0.247  0.124   0.007    0.469\n",
"tau                  17.120  2.884  11.492   22.201\n",
"sigma                 0.244  0.021   0.206    0.282\n",
"alpha[1]              0.000  0.000   0.000    0.000\n",
"alpha[2]              0.257  0.176  -0.086    0.573\n",
"alpha[3]             -0.082  0.173  -0.400    0.243\n",
"alpha[4]             -0.025  0.176  -0.355    0.302\n",
"beta[1]               0.000  0.000   0.000    0.000\n",
"beta[2]              -0.191  0.174  -0.511    0.140\n",
"beta[3]              -0.186  0.174  -0.513    0.145\n",
"beta[4]               0.771  0.175   0.445    1.100\n",
"beta[5]              -0.001  0.175  -0.347    0.309\n",
"beta[6]              -0.146  0.177  -0.472    0.195\n",
"alphabeta[1, 1]       0.000  0.000   0.000    0.000\n",
"alphabeta[1, 2]       0.000  0.000   0.000    0.000\n",
"alphabeta[1, 3]       0.000  0.000   0.000    0.000\n",
"alphabeta[1, 4]       0.000  0.000   0.000    0.000\n",
"alphabeta[1, 5]       0.000  0.000   0.000    0.000\n",
"alphabeta[1, 6]       0.000  0.000   0.000    0.000\n",
"alphabeta[2, 1]       0.000  0.000   0.000    0.000\n",
"alphabeta[2, 2]      -0.256  0.246  -0.708    0.212\n",
"alphabeta[2, 3]       0.324  0.248  -0.137    0.787\n",
"alphabeta[2, 4]      -0.912  0.246  -1.384   -0.458\n",
"alphabeta[2, 5]      -0.406  0.251  -0.880    0.057\n",
"alphabeta[2, 6]      -0.180  0.250  -0.656    0.287\n",
"alphabeta[3, 1]       0.000  0.000   0.000    0.000\n",
"alphabeta[3, 2]       0.096  0.246  -0.370    0.552\n",
"alphabeta[3, 3]       0.120  0.244  -0.334    0.583\n",
"alphabeta[3, 4]      -0.682  0.248  -1.140   -0.199\n",
"alphabeta[3, 5]       0.138  0.245  -0.335    0.586\n",
"alphabeta[3, 6]       0.069  0.246  -0.395    0.527\n",
"alphabeta[4, 1]       0.000  0.000   0.000    0.000\n",
"alphabeta[4, 2]       0.141  0.249  -0.324    0.603\n",
"alphabeta[4, 3]       0.045  0.247  -0.428    0.502\n",
"alphabeta[4, 4]      -0.753  0.250  -1.230   -0.296\n",
"alphabeta[4, 5]      -0.002  0.249  -0.473    0.449\n",
"alphabeta[4, 6]       0.166  0.248  -0.283    0.639\n",
"alpha[1] - alpha[2]  -0.257  0.176  -0.573    0.086\n",
"alpha[1] - alpha[3]   0.082  0.173  -0.243    0.400\n",
"alpha[1] - alpha[4]   0.025  0.176  -0.302    0.355\n",
"alpha[2] - alpha[3]   0.338  0.176   0.013    0.672\n",
"alpha[2] - alpha[4]   0.282  0.175  -0.041    0.619\n",
"alpha[3] - alpha[4]  -0.057  0.175  -0.378    0.276\n",
"beta[1] - beta[2]     0.191  0.174  -0.140    0.511\n",
"beta[1] - beta[3]     0.186  0.174  -0.145    0.513\n",
"beta[1] - beta[4]    -0.771  0.175  -1.100   -0.445\n",
"beta[1] - beta[5]     0.001  0.175  -0.309    0.347\n",
"beta[1] - beta[6]     0.146  0.177  -0.195    0.472\n",
"beta[2] - beta[3]    -0.005  0.176  -0.334    0.323\n",
"beta[2] - beta[4]    -0.962  0.177  -1.297   -0.637\n",
"beta[2] - beta[5]    -0.190  0.175  -0.526    0.127\n",
"beta[2] - beta[6]    -0.045  0.176  -0.360    0.308\n",
"beta[3] - beta[4]    -0.956  0.173  -1.296   -0.645\n",
"beta[3] - beta[5]    -0.185  0.172  -0.515    0.126\n",
"beta[3] - beta[6]    -0.040  0.173  -0.363    0.285\n",
"beta[4] - beta[5]     0.771  0.175   0.458    1.103\n",
"beta[4] - beta[6]     0.916  0.177   0.600    1.264\n",
"beta[5] - beta[6]     0.145  0.176  -0.184    0.474"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"az.summary(trace, var_names=\"~_\", filter_vars=\"like\", kind=\"stats\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8c1ff385",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last updated: Wed Mar 22 2023\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.11.0\n",
"IPython version : 8.9.0\n",
"\n",
"aesara: 2.8.10\n",
"aeppl : 0.1.1\n",
"\n",
"numpy : 1.24.2\n",
"pandas : 1.5.3\n",
"pytensor: 2.10.1\n",
"arviz : 0.14.0\n",
"pymc : 5.1.2\n",
"\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark -n -u -v -iv -p aesara,aeppl"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
},
"vscode": {
"interpreter": {
"hash": "e197428f119775a30e0221ede525e07580bbbcb52f3c1ab01042e9594a2688a6"
},
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}