Custom proposal
I'm doing importance sampling as advised, but the results are poor. What can I do?
import sys
if "google.colab" in sys.modules:
%pip install --quiet "genjax[genstudio]"
One option is to write a custom proposal for importance sampling. The idea is to sample from this proposal instead of the default one that genjax uses in model.importance. The default proposal is only informed by the structure of the model, not by the posterior defined by both the model and the observations.
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.special import logsumexp
from genjax import ChoiceMapBuilder as C
from genjax import Target, gen, normal, pretty, smc
key = jax.random.key(0)
pretty()
Let's first define a simple model with a broad normal prior and some observations.
@gen
def model():
    # Initially, the prior is a pretty broad normal distribution centred at 0
    x = normal(0.0, 100.0) @ "x"
    # We add some observations, which will shift the posterior towards these values
    _ = normal(x, 1.0) @ "obs1"
    _ = normal(x, 1.0) @ "obs2"
    _ = normal(x, 1.0) @ "obs3"
    return x
# We create some data: 3 observed values at 234.0
obs = C["obs1"].set(234.0) | C["obs2"].set(234.0) | C["obs3"].set(234.0)
We then run importance sampling with the default proposal, and print the average weight of the samples to get a sense of how well the proposal is doing.
key, *sub_keys = jax.random.split(key, 1000 + 1)
sub_keys = jnp.array(sub_keys)
args = ()
jitted = jit(vmap(model.importance, in_axes=(0, None, None)))
trace, weight = jitted(sub_keys, obs, args)
print("The average weight is", logsumexp(weight) - jnp.log(len(weight)))
print("The maximum weight is", weight.max())
The average weight is -9.859367 The maximum weight is -2.951612
We can see that both the average and even the maximum weight are quite low, which means the proposal is not doing a great job. Note that the printed values are in log space: the code reports $\log\left(\frac{1}{N}\sum_i w_i\right)$ and $\max_i \log w_i$.
If there are no observations, the weight should ideally centre around 1 and be quite concentrated around that value.
A weight much higher than 1 means that the proposal is too narrow and is missing modes. Indeed, for that to happen, one has to sample a value that is very unlikely under the proposal but very likely under the target.
A weight much lower than 1 means that the proposal is too broad and is wasting samples. This is what happens here: the default proposal samples from the broad prior normal(0.0, 100.0), which is far from the observed values centred around $234.0$.
If there are observations, as is the case above, the weight should centre around the marginal on the observations. More precisely, if the model has density $p(x,y)$ where $y$ are the observations and the proposal has density $q(x)$, then a weight is given by $w = \frac{p(x,y)}{q(x)}$, whose average value over many runs (an expectation under the proposal) is $p(y)$.
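As a sanity check (a sketch using standard conjugate-Gaussian updates; the helper exact_log_marginal is not part of genjax), we can compute $\log p(y)$ in closed form for this model and compare it with the estimated average weight above.
from jax.scipy.stats import norm
def exact_log_marginal(ys, prior_mean=0.0, prior_var=100.0**2, obs_var=1.0):
    # Chain rule: log p(y) = sum_i log p(y_i | y_1, ..., y_{i-1}), where each
    # predictive distribution is normal with the current posterior mean and
    # the current posterior variance plus the observation noise variance
    log_p, mean, var = 0.0, prior_mean, prior_var
    for y in ys:
        log_p += norm.logpdf(y, mean, jnp.sqrt(var + obs_var))
        # Conjugate posterior update after observing y
        var_new = 1.0 / (1.0 / var + 1.0 / obs_var)
        mean = var_new * (mean / var + y / obs_var)
        var = var_new
    return log_p
print("Exact log p(y):", exact_log_marginal([234.0, 234.0, 234.0]))
For these observations this evaluates to roughly $-10.6$; the estimate above is in the right ballpark but noisy, and a better proposal should concentrate the weights around this value.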
We now define a custom proposal: a normal distribution centred around the observed values.
@gen
def proposal(obs):
    avg_val = jnp.array(obs).mean()
    std = jnp.array(obs).std()
    # To avoid a degenerate proposal, we add a small value to the standard deviation
    x = normal(avg_val, 0.1 + std) @ "x"
    return x
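As a quick check of the proposal on its own (not part of the original flow), we can simulate it once and inspect the sampled value and its score:
key, subkey = jax.random.split(key)
tr = proposal.simulate(subkey, (jnp.array([234.0, 234.0, 234.0]),))
print("sampled x:", tr.get_choices()["x"])
print("proposal log-density:", tr.get_score())
# Since all observations equal 234.0, their std is 0 and the proposal is normal(234.0, 0.1)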
To do things by hand first, let's reimplement the importance function. It samples from the proposal and then computes the importance weight.
def importance_sample(target, obs, proposal):
    def _inner(key, target_args, proposal_args):
        trace = proposal.simulate(key, *proposal_args)
        # The full choice map under which we evaluate the model combines
        # the values sampled from the proposal and the observed values
        chm = obs | trace.get_choices()
        proposal_logpdf = trace.get_score()
        target_logpdf, _ = target.assess(chm, *target_args)
        importance_weight = target_logpdf - proposal_logpdf
        return (trace, importance_weight)

    return _inner
We then run importance sampling with the custom proposal.
key, *sub_keys = jax.random.split(key, 1000 + 1)
sub_keys = jnp.array(sub_keys)
args_for_model = ()
args_for_proposal = (jnp.array([obs["obs1"], obs["obs2"], obs["obs3"]]),)
jitted = jit(vmap(importance_sample(model, obs, proposal), in_axes=(0, None, None)))
trace, new_weight = jitted(sub_keys, (args_for_model,), (args_for_proposal,))
We again print the average and maximum weight. The weights are now far more concentrated: the gap between the maximum and the average log weight shrinks from roughly 7 nats to under 4, which means the custom proposal is much better matched to the posterior and wastes far fewer samples.
print("The new average weight is", logsumexp(new_weight) - jnp.log(len(new_weight)))
print("The new maximum weight is", new_weight.max())
The new average weight is -11.4855 The new maximum weight is -7.8311706
We can also do the same using the library functions.
To do this, let's first create a target posterior distribution. It consists of the model, arguments for it, and observations.
target_posterior = Target(model, args_for_model, obs)
Next, we redefine the proposal slightly to take the target as an argument. This way, it can extract the observations from the target, as we did manually before. The proposal could also depend on the arguments passed to the model, for instance.
@gen
def proposal(target: Target):
    model_obs = target.constraint
    used_obs = jnp.array([model_obs["obs1"], model_obs["obs2"], model_obs["obs3"]])
    avg_val = used_obs.mean()
    std = used_obs.std()
    x = normal(avg_val, 0.1 + std) @ "x"
    return x
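Before handing this version to the library, we can sanity-check the new signature directly against target_posterior (a quick sketch, mirroring the check above):
key, subkey = jax.random.split(key)
tr = proposal.simulate(subkey, (target_posterior,))
print("proposed x:", tr.get_choices()["x"])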
Now, similarly to the importance_sampling notebook, we create an inference algorithm: it specifies a strategy to approximate our posterior of interest, target_posterior, using importance sampling with k_particles particles and our custom proposal. To specify that all the traced variables from proposal are to be used in importance sampling (we will revisit in the ravi_stack notebook why that may not always be the case), we write proposal.marginal(), which indicates that no traced variable from proposal is marginalized out.
k_particles = 1000
alg = smc.ImportanceK(target_posterior, q=proposal.marginal(), k_particles=k_particles)
This performs sampling importance resampling (SIR): $1000$ intermediate particles are created, and a single one is resampled from these and returned at the end. Let's test it:
jitted = jit(alg.simulate)
key, subkey = jax.random.split(key)
posterior_samples = jitted(subkey, (target_posterior,))
posterior_samples
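As a final, hypothetical check (this assumes the algorithm trace's return value is the sampled latent choice map, as in the importance_sampling notebook), we can draw many SIR samples and verify that the posterior mean of "x" is close to the observed value.
key, *sub_keys = jax.random.split(key, 100 + 1)
sub_keys = jnp.array(sub_keys)
many = jit(vmap(alg.simulate, in_axes=(0, None)))(sub_keys, (target_posterior,))
# Assumption: get_retval() yields the latent choice map sampled by SIR
xs = many.get_retval()["x"]
print("estimated posterior mean:", xs.mean())  # should be close to 234.0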