Importance sampling
I want to do my first inference task. How do I do it?
import sys

if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"
We will do it with importance sampling, which works as follows. We choose a distribution $q$, called a proposal, that we will sample from, and we have a distribution $p$ of interest, typically representing the posterior of a model that has received observations.
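Concretely, we draw samples $x_i \sim q$ and attach to each one the importance weight $w_i = p(x_i)/q(x_i)$, computed in log space as $\log p(x_i) - \log q(x_i)$ for numerical stability. The weighted average $\frac{\sum_i w_i f(x_i)}{\sum_i w_i}$ is then a consistent estimate of $\mathbb{E}_p[f(x)]$; the weights are what correct for sampling from $q$ instead of $p$.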
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import jit, vmap
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import Target, bernoulli, beta, gen, pretty, smc
key = jax.random.key(0)
pretty()
Let's first look at a simple python version of the algorithm to get the idea.
def importance_sample(model, proposal):
    def _inner(key, model_args, proposal_args):
        # we sample from the easy distribution, the proposal `q`
        trace = proposal.simulate(key, *proposal_args)
        chm = trace.get_choices()
        # we evaluate the score of the easy distribution q(x)
        proposal_logpdf = trace.get_score()
        # we evaluate the score of the hard distribution p(x)
        model_logpdf, _ = model.assess(chm, *model_args)
        # the importance weight is p(x)/q(x), which corrects for the bias
        # from sampling from q instead of p
        importance_weight = model_logpdf - proposal_logpdf
        # we return the trace and the importance weight
        return (trace, importance_weight)

    return _inner
We can test this on a very simple example.
model = genjax.normal
proposal = genjax.normal
model_args = (0.0, 1.0)
proposal_args = (3.0, 4.0)
key, subkey = jax.random.split(key)
sample, importance_weight = jit(importance_sample(model, proposal))(
    subkey, (model_args,), (proposal_args,)
)
print(importance_weight, sample.get_choices())
-18.546162 Choice(v=<jax.Array(-6.769823, dtype=float32)>)
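As a quick sanity check (a sketch, not part of the original example), we can recompute this log weight by hand with jax.scipy.stats: it is $\log \mathcal{N}(x; 0, 1) - \log \mathcal{N}(x; 3, 4)$ evaluated at the sampled value, which should match the printed weight of about -18.55.
from jax.scipy.stats import norm

x = sample.get_retval()  # the value drawn from the proposal
norm.logpdf(x, loc=0.0, scale=1.0) - norm.logpdf(x, loc=3.0, scale=4.0)  # ~ -18.55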
We can also run it in parallel!
jitted = jit(
    vmap(
        importance_sample(model, proposal),
        in_axes=(0, None, None),
    )
)
key, *sub_keys = jax.random.split(key, 100 + 1)
sub_keys = jnp.array(sub_keys)
(sample, importance_weight) = jitted(sub_keys, (model_args,), (proposal_args,))
sample.get_choices(), importance_weight
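The log weights are what make these samples usable. As a minimal sketch (not part of the original example), here is one way to turn the 100 weighted samples into a self-normalized estimate of $\mathbb{E}_p[x]$, which should be near 0 for this model although noisy given the mismatched proposal, together with an effective sample size that quantifies how well the proposal matches the target.
from jax.scipy.special import logsumexp

xs = sample.get_retval()  # one proposed value per particle
log_w = importance_weight - logsumexp(importance_weight)  # normalize in log space
w = jnp.exp(log_w)
snis_mean = jnp.sum(w * xs)  # self-normalized estimate of E_p[x]
ess = 1.0 / jnp.sum(w**2)  # effective sample size, at most 100
snis_mean, ess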
In GenJAX, every generative function comes equipped with a default proposal which we can use for importance sampling.
Let's define a generative function.
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return v
By constraining some of the random samples, which we call observations, we obtain a posterior inference problem: the goal is to infer the distribution of the random variables that are not observed.
obs = C["v"].set(1)
args = (0.5,)
The method .importance defines a default proposal based on the generative function that targets the posterior distribution we just defined. It returns a pair containing a trace and the log incremental weight. This weight corrects for the bias introduced by sampling from the proposal instead of the intractable posterior distribution.
key, subkey = jax.random.split(key)
trace, weight = beta_bernoulli_process.importance(subkey, obs, args)
trace, weight
N = 1000
K = 100


def SIR(N, K, model, chm):
    # Sampling importance resampling: draw N weighted particles using the
    # model's default proposal, then resample K of them according to their weights.
    @jit
    def _inner(key, args):
        key, subkey = jax.random.split(key)
        # importance sampling step: N particles, each with a log weight
        traces, weights = vmap(model.importance, in_axes=(0, None, None))(
            jax.random.split(key, N), chm, args
        )
        # resampling step: draw K particle indices proportionally to the weights
        # (genjax.categorical is parameterized by logits, i.e. the log weights)
        idxs = vmap(jax.jit(genjax.categorical.simulate), in_axes=(0, None))(
            jax.random.split(subkey, K), (weights,)
        ).get_retval()
        samples = traces.get_choices()
        resampled_samples = vmap(lambda idx: jtu.tree_map(lambda v: v[idx], samples))(
            idxs
        )
        return resampled_samples

    return _inner
Testing
chm = C["v"].set(1)
args = (0.5,)
key, subkey = jax.random.split(key)
samples = jit(SIR(N, K, beta_bernoulli_process, chm))(subkey, args)
samples
Another way to do basically the same thing is to use library functions.
To do this, we first define a Target for importance sampling, i.e. the posterior inference problem we're targeting. It consists of a generative function, arguments to the generative function, and observations.
target_posterior = Target(beta_bernoulli_process, (args,), chm)
Next, we define an inference strategy (an Algorithm) to approximate the target distribution; in our case, importance sampling with $N$ particles.
alg = smc.ImportanceK(target_posterior, k_particles=N)
To get a different sense of what's going on, the hierarchy of classes is as follows:
ImportanceK <: SMCAlgorithm <: Algorithm <: SampleDistribution <: Distribution <: GenerativeFunction <: Pytree
In words, importance sampling (ImportanceK) is a particular instance of Sequential Monte Carlo (SMCAlgorithm). The latter is one instance of an approximate inference strategy (Algorithm).
An inference strategy in particular produces samples for a distribution (SampleDistribution), which is a distribution (Distribution) whose return value is the sample. A distribution here follows the definition from GenSP (Lew et al. 2023), which has two methods, random_weighted and estimate_logpdf. See the appropriate cookbook for details on these.
Finally, a distribution is a particular case of a generative function (GenerativeFunction), and these are all pytrees (Pytree) so that they are JAX-compatible and in particular jittable.
To get K independent samples from the approximate posterior distribution, we can for instance use vmap.
# It's a bit different from the previous example, because each of the final
# K samples is obtained by running a different set of N-particles.
# This can of course be optimized but we keep it simple here.
jitted = jit(vmap(alg.simulate, in_axes=(0, None)))
Testing
key, *sub_keys = jax.random.split(key, K + 1)
sub_keys = jnp.array(sub_keys)
posterior_samples = jitted(sub_keys, (target_posterior,)).get_retval()
# This only does the importance sampling step, not the resampling step
# Therefore the shape is (K, N, 1)
posterior_samples["p"]
We can check the mean value estimate for "p".
posterior_samples["p"].mean(axis=(0, 1))
And we can compare it with the estimate obtained using the previous method by computing the relative difference.
100.0 * jnp.abs(
    samples["p"].mean() - posterior_samples["p"].mean(axis=(0, 1))
) / posterior_samples["p"].mean(axis=(0, 1))  # about 2% difference