import arviz as az
import pymc as pm
from pymc.math import log, sqr
10. The Zero Trick and Custom Likelihoods*#
Zero-trick Jeremy#
This introduces the “zero trick”, which is a method for specifying custom likelihoods in BUGS. For a more detailed treatment of these methods, see Ntzoufras [2009], page 276, which is where I got this explanation. These tricks are unnecessary in PyMC since we can just define custom distributions directly, but they do seem to work.
Adapted from Unit 6: zerotrickjeremy.odc.
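Here’s the model we’re using, with a single observation \(y = 98\):
\[
\begin{aligned}
y \mid \theta &\sim N(\theta, \sigma^2), \quad \sigma^2 = 80 \\
\theta &\sim N(\mu, \tau^2), \quad \mu = 110, \ \tau^2 = 120
\end{aligned}
\]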
Of course, BUGS can handle this model just fine. But let’s pretend for a moment that there is no built-in normal distribution. The zero trick takes advantage of some properties of the Poisson distribution to recreate an arbitrary likelihood.
Given a log-likelihood of the form \(l_i = \log f(y; \theta)\),
–Ntzoufras [2009] page 276.
But the rate, \(\lambda\), can’t be negative. So we need to add a constant, C, to keep that from happening.
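As a short derivation sketch (a reconstruction from the Poisson probability mass function, not a quotation from the book): a Poisson variable with rate \(\lambda_i\) equals zero with probability \(e^{-\lambda_i}\). So, assuming we set \(\lambda_i = -l_i + C\) and observe \(z_i = 0\), each pseudo-observation contributes a term proportional to the original likelihood:
\[
P(z_i = 0 \mid \lambda_i) = \frac{e^{-\lambda_i} \lambda_i^{0}}{0!} = e^{-\lambda_i} = e^{l_i - C} \propto f(y_i; \theta),
\qquad \lambda_i = -l_i + C > 0.
\]
The factor \(e^{-C}\) does not depend on \(\theta\), so it drops out of the posterior.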
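The normal log-likelihood is:
\[
l = -\frac{1}{2}\log(2\pi) - \log\sigma - \frac{1}{2}\left(\frac{y - \theta}{\sigma}\right)^2
\]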
But the constant terms won’t affect the posterior.
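To see why the constants don’t matter, here is a minimal NumPy sketch (not part of the original notebook). It compares the full normal log-density with the log-probability of a zero from a Poisson whose rate is the negative log-likelihood, constants dropped, plus C. The function names and the grid of θ values are illustrative assumptions.

```python
import numpy as np

y, sigma = 98.0, 80.0**0.5
C = 10_000.0


def normal_logpdf(y, theta, sigma):
    """Full normal log-density, constants included."""
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((y - theta) / sigma) ** 2


def zero_trick_logp(y, theta, sigma, C):
    """Log-probability of a zero from Poisson(rate), where rate = -l + C (2*pi term dropped)."""
    rate = np.log(sigma) + 0.5 * ((y - theta) / sigma) ** 2 + C
    return -rate  # log P(Z = 0 | rate) = -rate


# Illustrative grid of theta values (an assumption, any grid works)
theta_grid = np.linspace(80.0, 120.0, 9)
diff = normal_logpdf(y, theta_grid, sigma) - zero_trick_logp(y, theta_grid, sigma, C)
print(diff)  # the same value at every theta
```

The difference is the same for every θ, so the two log-densities differ only by an additive constant and lead to the same posterior.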
Here’s the model in PyMC:
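# observed value and hyperparameters
y = 98
μ = 110
σ = 80**0.5  # likelihood standard deviation (variance 80)
τ = 120**0.5  # prior standard deviation (variance 120)
C = 10000  # constant that keeps the Poisson rates positive
inits = {"θ": 100}

with pm.Model() as m:
    θ = pm.Flat("θ")
    # Poisson rates: negative log-likelihood and negative log-prior, plus C
    λ1 = pm.Deterministic("λ1", log(σ) + 0.5 * sqr(((y - θ) / σ)) + C)
    λ2 = pm.Deterministic("λ2", log(τ) + 0.5 * sqr(((θ - μ) / τ)) + C)
    # each observed zero contributes exp(-λ) to the likelihood
    z1 = pm.Poisson("z1", λ1, observed=0)
    z2 = pm.Poisson("z2", λ2, observed=0)
    trace = pm.sample(5000, tune=1000, initvals=inits, target_accept=0.88)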
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [θ]
Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 1 seconds.
az.summary(trace, hdi_prob=0.95, var_names="θ", kind="stats")
| | mean | sd | hdi_2.5% | hdi_97.5% |
|---|---|---|---|---|
| θ | 102.742 | 6.841 | 89.26 | 115.994 |
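As a sanity check, the sampled values can be compared with the closed-form answer. This is a minimal sketch, assuming the flat prior plus the two zero-trick terms is equivalent to a normal likelihood with a normal prior, so the standard conjugate formulas apply. The variable names are illustrative.

```python
y, mu = 98.0, 110.0
sigma2, tau2 = 80.0, 120.0  # likelihood and prior variances

# Normal-normal conjugate update: precisions add, means are precision-weighted
post_var = 1.0 / (1.0 / sigma2 + 1.0 / tau2)
post_mean = post_var * (y / sigma2 + mu / tau2)
post_sd = post_var**0.5

print(post_mean, post_sd)
```

The analytic posterior mean is about 102.8 with a standard deviation of about 6.93, close to the sampled values in the table above.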
Custom likelihoods in PyMC#
PyMC is a lot more flexible. For one, many more built-in distributions are available, including mixtures and zero-inflated models, which is where I’ve seen the zero trick used most often (Karlis and Ntzoufras [2008], Ntzoufras [2009] page 288).
If you need a custom likelihood, take a look at pm.Potential or pm.CustomDist. pm.Potential is used to adjust the likelihood, as we can see in the censored model ahead in Unit 8.6. For now, let’s take a look at how to use pm.CustomDist.
pm.CustomDist
Motivating Example: Survival Model#
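We can alternatively build the survival model in Unit 6.9 with the custom distribution based on this post (but with equations corrected):
\[
L(\lambda) = \prod_{i} \left(\lambda e^{-\lambda t_i}\right)^{c_i} \left(e^{-\lambda t_i}\right)^{1 - c_i} = \prod_{i} \lambda^{c_i} e^{-\lambda t_i}
\]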
c is an indicator variable. The data point is censored if c=0, and a failure or non-survival is denoted by c=1. When the point is not censored, the probability distribution is exponential. When the point is censored, the probability is the exponential distribution’s complementary CDF, which accounts for the probability to the right of the censored point.
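The two cases described above can be written out in plain NumPy to check that they collapse into one compact expression. This is a minimal sketch, not part of the original notebook. The function names and the test value of λ are illustrative assumptions, and t and c are the data arrays used later in this section.

```python
import numpy as np

t = np.array([2, 72, 51, 60, 33, 27, 14, 24, 4, 21])  # observed times
c = np.array([1, 0, 1, 0, 1, 1, 1, 1, 1, 0])  # 1 = failure observed, 0 = censored


def loglik_piecewise(lam, t, c):
    """Exponential log-density for failures, log complementary CDF for censored points."""
    log_pdf = np.log(lam) - lam * t
    log_sf = -lam * t
    return np.where(c == 1, log_pdf, log_sf).sum()


def loglik_compact(lam, t, c):
    """The same log-likelihood in one expression: sum of c * log(lam) - lam * t."""
    return (c * np.log(lam) - lam * t).sum()


lam = 0.02  # arbitrary rate, chosen only for the comparison
print(loglik_piecewise(lam, t, c), loglik_compact(lam, t, c))
```

Both functions return the same value, which is why a single expression is enough for the log-probability of the custom distribution.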
To build this as a custom distribution, we write a logp, or log-probability, function, since PyMC works with log-probabilities. One tricky part for this distribution is that both observed data arrays (t and c) need to be stacked into a single matrix and passed to the observed argument. The summary results match those in Unit 6.9, which uses pm.Censored instead.
import numpy as np
# gamma dist parameters
α = 0.01
β = 0.1
# observed life within experiment dates
t = np.array([2, 72, 51, 60, 33, 27, 14, 24, 4, 21])
# censored indicator
c = np.array([1, 0, 1, 0, 1, 1, 1, 1, 1, 0])
# CustomDist requires all data to be together in a matrix
# (np.concat needs NumPy >= 2.0; np.concatenate is the equivalent older name)
val = np.concat([[t], [c]])


# log-probability of custom distribution
def logp(value, lam):
    t = value[0, :]  # observed times
    c = value[1, :]  # 1 = failure observed, 0 = censored
    return (c * log(lam) - lam * t).sum()


with pm.Model() as m:
    λ = pm.Gamma("λ", α, β)
    μ = pm.Deterministic("μ", 1 / λ)
    exp_surv = pm.CustomDist("exp_surv", λ, logp=logp, observed=val)
    trace = pm.sample(5000)
az.summary(trace)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [λ]
Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 1 seconds.
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| λ | 0.023 | 0.008 | 0.008 | 0.038 | 0.00 | 0.000 | 7786.0 | 11573.0 | 1.0 |
| μ | 51.157 | 22.767 | 20.026 | 90.397 | 0.26 | 0.613 | 7786.0 | 11573.0 | 1.0 |
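These results can also be checked against a closed form. This is a minimal sketch, assuming the Gamma prior uses the shape and rate parameterization (as pm.Gamma does with its alpha and beta arguments). Under that assumption, the Gamma prior is conjugate to the censored exponential likelihood. The variable names are illustrative.

```python
import numpy as np

alpha, beta = 0.01, 0.1  # Gamma prior: shape and rate
t = np.array([2, 72, 51, 60, 33, 27, 14, 24, 4, 21])
c = np.array([1, 0, 1, 0, 1, 1, 1, 1, 1, 0])

# Likelihood is proportional to lam**sum(c) * exp(-lam * sum(t)),
# so the posterior is Gamma(alpha + sum(c), beta + sum(t))
post_alpha = alpha + c.sum()
post_beta = beta + t.sum()

post_mean_lam = post_alpha / post_beta  # E[lam]
post_mean_mu = post_beta / (post_alpha - 1)  # E[1 / lam], valid when post_alpha > 1

print(post_mean_lam, post_mean_mu)
```

The analytic posterior means, roughly 0.0228 for λ and 51.26 for μ, are in line with the sampled means in the table above.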
%load_ext watermark
%watermark -n -u -v -iv -p pytensor
Last updated: Fri Jun 13 2025
Python implementation: CPython
Python version : 3.13.3
IPython version : 9.2.0
pytensor: 2.30.3
pymc : 5.22.0
numpy: 2.2.6
arviz: 0.21.0