Importance
Intro to the update logic¶
import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from genjax import gen, normal, pretty
pretty()
key = jax.random.key(0)
One of the most important building blocks of the library is the update method. Before investigating its details, let's look at the more user-friendly version called importance.
importance is a method on generative functions. It takes a key, constraints in the form of a choicemap, and arguments for the generative function. Let's first see how we use it and then explain what happened.
@gen
def model(x):
    y = normal(x, 1.0) @ "y"
    z = normal(y, 1.0) @ "z"
    return y + z
constraints = C.n()
args = (1.0,)
key, subkey = jax.random.split(key)
tr, w = model.importance(subkey, constraints, args)
We obtain a pair of a trace tr and a weight w. The trace tr is produced by the model, and its choicemap satisfies the constraints given by constraints.
Choices that are not constrained are sampled from the prior distribution given by the model.
# we expect normal(0., 1.) for y and constant 4. for z
constraints = C["z"].set(4.0)
args = (0.0,)
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, 100000)
trs, ws = jax.vmap(lambda key: model.importance(key, constraints, args))(keys)
import matplotlib.pyplot as plt
import numpy as np
ys = trs.get_choices()["y"]
zs = trs.get_choices()["z"]
plt.hist(ys, bins=200, density=True, alpha=0.5, color="b", label="ys")
plt.scatter(zs, np.zeros_like(zs), color="r", label="zs")
plt.title("Gaussian Distribution of ys and Constant z")
plt.legend()
plt.show()
The computed weights represent the ratio $\frac{P(y, 4. ; x)}{P(y ; x)}$, where $P(y, z ; x)$ is the joint density given by the model at the argument $x$, and $P(y ; x)$ is the density of the subpart of the model that does not contain the constrained variables. Since "z" is constrained in our example, that leaves only "y".
We can easily check this:
numerators, _ = jax.vmap(
    lambda y: model.assess(C["y"].set(y) | C["z"].set(4.0), args)
)(ys)
denominators = trs.get_subtrace("y").get_score()
# yeah, numerical stability of floats implies it's not even exactly equal ...
jnp.allclose(ws, numerators - denominators, atol=1e-7)
More generally, the denominator is the joint density on the sampled variables (the constraints are not sampled), and Gen has a way to automatically sample from the generative function obtained by replacing the sampling operations at the constrained addresses with the values of the constraints. For instance, in our example it would mean:
@gen
def constrained_model(x):
    y = normal(x, 1.0) @ "y"
    z = 4.0
    return y + z
Thanks to the factorisation $P(y, z ; x) = P(y ; x)P(z | y ; x)$, the weight ws simplifies to $P(z | y ; x)$.
In fact, we can easily check it:
ws == trs.get_subtrace("z").get_score()
And this time the equality is exact, as this is how importance computes it. The algebraic simplification $\frac{P(y ; x)}{P(y ; x)}=1$ is done automatically.
Let's review. importance completes a set of constraints given by a partial choicemap to a full choicemap under the model. It also efficiently computes a weight, which simplifies to a distribution of the form $P(\text{sampled } | \text{ constraints} ; \text{arguments})$.
The complex recursive nature of this formula becomes a bit more apparent in the following example:
@gen
def fancier_model(x):
    y1 = normal(x, 1.0) @ "y1"
    z1 = normal(y1, 1.0) @ "z1"
    y2 = normal(z1, 1.0) @ "y2"
    z2 = normal(z1 + y2, 1.0) @ "z2"
    return y2 + z2
# if we constrain `z1` to be 4. and `z2` to be 2., we'd get a constrained model as follows:
@gen
def constrained_fancier_model(x):
    y1 = normal(x, 1.0) @ "y1"
    z1 = 4.0
    y2 = normal(z1, 1.0) @ "y2"  # note how the sampled `y2` depends on a constraint
    z2 = 2.0
    return y2 + z2
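As a quick sanity check (a sketch reusing the get_subtrace accessor and the | choicemap merge seen above), the weight returned by importance on fancier_model should factor as $P(z_1 | y_1 ; x)\,P(z_2 | z_1, y_2 ; x)$, i.e. the sum of the scores of the two constrained subtraces:
# sketch: constrain z1 and z2, then compare the log weight with the sum of the
# scores of the constrained subtraces
constraints = C["z1"].set(4.0) | C["z2"].set(2.0)
args = (0.0,)
key, subkey = jax.random.split(key)
tr, w = fancier_model.importance(subkey, constraints, args)
jnp.allclose(w, tr.get_subtrace("z1").get_score() + tr.get_subtrace("z2").get_score())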
But what does this have to do with importance sampling?¶
What we effectively did was to sample a value y from the distribution constrained_model, which is called a proposal in importance sampling, often denoted $q$. We then computed a weight proportional to $\frac{p(y)}{q(y)}$ for some target model $p$.
Given that we constrained z, an equivalent view is that we observed z, and we have a posterior inference problem: we want to approximately sample from the posterior $P(y | z)$ (all for a given argument x).
Note that $P(y | z) = \frac{P(y,z)}{P(z)}$ by Bayes' rule. So our fraction $\frac{P(y, z ; x)}{P(y ; x)}$ for the weight rewrites as $\frac{P(y | z)P(z)}{q(y)}= P(z)\frac{p(y)}{q(y)}$ (1), where $p(y) := P(y | z)$ is the posterior and $q(y) := P(y ; x)$ is the proposal.
Also remember that the weight $\frac{dp}{dq}$ for importance sampling comes from the proper weight guarantee, i.e. it satisfies this equation: $$\forall f.\ \mathbb{E}_{y\sim p}[f(y)]= \mathbb{E}_{y\sim q}\left[\frac{dp}{dq}(y)f(y)\right] = \frac{1}{P(z)} \mathbb{E}_{y\sim q}[w(y)f(y)]$$
where in the last step we used (1) and called w the weight computed by importance.
By taking $f:= \lambda y.1$, we derive that $P(z) = \mathbb{E}_{y\sim q}[w(y)]$. That is, by sampling from our proposal distribution, we can estimate the marginal $P(z)$. Therefore, with the same samples, we can estimate any quantity $\mathbb{E}_{y\sim p}[f(y)]$ using our estimate of $\mathbb{E}_{y\sim q}[w(y)f(y)]$ and our estimate of $P(z)$. That's the essence of self-normalizing importance sampling.
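We can sketch this numerically by reusing the ws and ys computed above (where z was constrained to 4. and x = 0.). Under the model, the marginal of z is Normal(0., sqrt(2.)) and the posterior of y given z = 4. has mean 2., so we can compare the estimates with the exact values. This assumes, as the earlier check against assess suggests, that the weights are in log space.
# sketch: average the log weights with logsumexp to estimate log P(z = 4.)
# and compare with the exact marginal Normal(0., sqrt(2.))
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

log_marginal_est = logsumexp(ws) - jnp.log(ws.shape[0])
log_marginal_exact = norm.logpdf(4.0, loc=0.0, scale=jnp.sqrt(2.0))

# self-normalized importance sampling: normalize the weights and estimate the
# posterior mean of y given z = 4., which is exactly 2. for this model
posterior_mean_est = jnp.sum(jax.nn.softmax(ws) * ys)
log_marginal_est, log_marginal_exact, posterior_mean_est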
The special case of the fully constrained choicemap¶
In the case where the constraints form a full choicemap for the model, the weight returned by importance equals the density returned by assess.
args = (1.0,)
key, subkey = jax.random.split(key)
tr = model.simulate(key, args)
constraints = tr.get_choices()
new_tr, w = model.importance(subkey, constraints, args)
score, _ = model.assess(constraints, args)
w == score