Update
Compositional Incremental Weight Computation
import jax
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import gen, normal, pretty
pretty()
key = jax.random.key(0)
Let's now see how `importance` and `update` are related. The high-level picture:

- `importance` starts with an empty trace, adds the constraints, and then samples the missing values to form a full choicemap under the model.
- `update` starts with any trace, overwrites the choices given by the constraints, and samples the missing ones. The missing ones can come from the initial trace possibly having an incomplete choicemap for the model, but also from constraints forcing the computation in the model to take a different path which has different sampled values.
  - It also returns a weight ratio which generalizes the one from `importance`.
  - It also takes and returns additional data to make it compositional, i.e. `update` is defined inductively on the structure of the model.
@gen
def model(x):
y = normal(x, 1.0) @ "y"
z = normal(y, 1.0) @ "z"
return y + z
Let's first check that `update` does not do anything if we provide no changes and no constraints.
args = (1.0,)
key, subkey = jax.random.split(key)
tr = model.simulate(subkey, args)
constraints = C.n()
argdiffs = genjax.Diff.no_change(args)
key, subkey = jax.random.split(key)
new_trace, _, _, _ = model.update(subkey, tr, constraints, argdiffs)
new_trace == tr
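Since nothing changed, the weight should be trivial as well. A quick sanity check (assuming, as elsewhere in this cookbook, that the weight is returned in log space, so "no change" corresponds to 0.0):

# Rerun the same update, this time keeping the weight.
_, weight, _, _ = model.update(subkey, tr, constraints, argdiffs)
weight == 0.0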
Let's now check that it returns a trace where the constraints overwrite the value from the initial trace.
key, subkey = jax.random.split(key)
constraints = C["y"].set(3.0)
new_tr, _, _, _ = model.update(subkey, tr, constraints, argdiffs)
new_tr.get_choices()["y"] == 3.0 and new_tr.get_choices()["z"] == tr.get_choices()["z"]
Next, let's look at the new input and new outputs compared to `importance`.
args = (1.0,)
key, subkey = jax.random.split(key)
tr = model.simulate(subkey, args)
constraints = C["z"].set(3.0)
argdiffs = genjax.Diff.no_change(args)
new_trace, weight, ret_diff, discarded = model.update(subkey, tr, constraints, argdiffs)
argdiffs, ret_diff, discarded
`discarded` represents a choicemap of the choices that were overwritten by the constraints.
discarded["z"] == tr.get_choices()["z"]
`argdiffs` and `ret_diff` use a special `Diff` type, a simpler analogue of the dual numbers from automatic differentiation (AD). They represent a pair of a primal value and a tangent value.
In AD, the primal is the point at which we're differentiating the function and the tangent is the derivative of the current variable w.r.t. an outside variable.
Here, the tangent type is much simpler, essentially Boolean: it is either the `NoChange()` tag or the `UnknownChange` tag.
This is inspired by the literature on incremental computation, and is only there to compute the density ratio `weight` efficiently, by doing algebraic simplifications at compile time, as we briefly saw for `importance` in the previous cookbook.
The idea is that a change in the argument `x` of the generative function implies a change to the distribution on `y`. So given a value of `y`, when we want to compute its density we need to know the value of `x`. A change in `x` might also force resampling a different value for `y`, which would in turn change the distribution on `z`. That's the basic idea behind the `Diff` system and why it needs to be compositional: it is a form of dependency tracking that checks which distributions might have changed given a change in arguments, and, importantly, which ones surely did not change, so that we can apply some algebraic simplifications.
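To signal a change in arguments, we tag the new value with `UnknownChange`. A minimal sketch (assuming `genjax.Diff.unknown_change` is the counterpart of the `genjax.Diff.no_change` used above):

# Change the argument from 1.0 to 2.0; update recomputes densities
# for the distributions that may have changed (here, the one on "y").
new_args = (2.0,)
new_argdiffs = genjax.Diff.unknown_change(new_args)
key, subkey = jax.random.split(key)
moved_tr, moved_weight, _, _ = model.update(subkey, tr, C.n(), new_argdiffs)
moved_tr.get_args()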
Now what about the weight? What does it compute?
Let's denote a trace by a pair $(x, t)$ of the argument $x$ and the choicemap $t$.
Given a trace $(x, t)$, a new argument $x'$, and a map of constraints $u$, `update` returns a new trace $(x', t')$ that is consistent with $u$. The values of choices in $t'$ are copied from $t$ and $u$ (with $u$ taking precedence) or sampled from the internal proposal $q$ (i.e. the equivalent of the `constrained_model` that we have seen in the `importance` cookbook).
The weight $w$ satisfies $$w_{update} = \frac{p(t' ; x')}{q(t' ; x', t+u)\, p(t ; x)}$$
where $t+u$ denotes the choicemap where $u$ overwrites the values in $t$ on their common addresses.
Let's contrast it with the weight computed by `importance`, which we can write as
$$w_{importance} = \frac{p(t' ; x)}{q(t' ; x, u)}$$
and which we can see as the special case of `update` with an empty starting trace $t$.
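We can see this correspondence concretely: calling `importance` with the same constraints plays the role of `update` starting from an empty trace. A sketch, using `model.importance` as in the previous cookbook:

key, subkey = jax.random.split(key)
imp_tr, imp_weight = model.importance(subkey, C["z"].set(3.0), args)
imp_weight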
What to do with the weight from `update`?
One simple observation: given a trace with choicemap $y$ and a full choicemap $y'$ used as a constraint, `update` will not need to call the internal proposal $q$, and the returned weight will be $\frac{p(y')}{p(y)}$. This is a useful quantity that appears in many SMC algorithms, and for instance in the MH acceptance ratio $\alpha$.
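We can check this on the traces from earlier: the constraint on `"z"` plus the copied value of `"y"` fully determine the new choicemap, so $q$ is never called and the weight is just the difference of scores (assuming `get_score()` returns $\log p(t ; x)$):

import jax.numpy as jnp

# w_update == log p(t') - log p(t) when the new choicemap is fully determined.
jnp.allclose(weight, new_trace.get_score() - tr.get_score())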
Given a current value $y$ for the choicemap and a proposed value $u$ for a change in some variables of the choicemap, if we call the model $p$ and the proposal $q$ (a kernel which may depend on $y$ and proposes the value $u$), we write $y' := y+u$. Then the MH acceptance ratio is defined as
$$\alpha := \frac{q(y \mid y')\, p(y')}{p(y)\, q(y' \mid y)} = \frac{q(y \mid y')}{q(y' \mid y)}\, w_{update}$$
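Here's what this looks like in code. A minimal sketch of one Metropolis-Hastings step on the address `"y"`, using a symmetric random-walk proposal so that the $q$-ratio cancels and $\log \alpha$ is exactly the `update` weight (the `mh_step` helper below is illustrative, not part of GenJAX):

import jax.numpy as jnp

def mh_step(key, tr):
    # Symmetric random-walk proposal on "y": q(y' | y) = q(y | y'),
    # so log(alpha) reduces to the update weight w.
    k1, k2, k3 = jax.random.split(key, 3)
    y_new = tr.get_choices()["y"] + 0.1 * jax.random.normal(k1)
    argdiffs = genjax.Diff.no_change(tr.get_args())
    new_tr, w, _, _ = model.update(k2, tr, C["y"].set(y_new), argdiffs)
    # Eager accept/reject; inside jit one would use jax.lax.cond instead.
    accept = jnp.log(jax.random.uniform(k3)) < w
    return new_tr if accept else tr

key, subkey = jax.random.split(key)
tr = mh_step(subkey, tr)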
A convenient usage of `update`
`update` has a convenient derived usage. If you have a trace `tr` and want to perform some inference move, e.g. proposing a new value for a specific address such as `"y"`, you can obtain a new trace with the new value using `update`:

key, subkey = jax.random.split(key)
new_val_for_y = 3.5
new_tr, _, _, _ = model.update(subkey, tr, C["y"].set(new_val_for_y), genjax.Diff.no_change(args))
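The weight returned alongside the new trace is exactly the $w_{update}$ from above, ready to be combined with the proposal ratio in an acceptance step like the MH sketch earlier.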