The generative function interface
import sys

if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"
import jax
from jax import jit
from genjax import ChoiceMapBuilder as C
from genjax import (
    Diff,
    NoChange,
    UnknownChange,
    bernoulli,
    beta,
    gen,
    pretty,
)
from genjax._src.generative_functions.static import MissingAddress
key = jax.random.key(0)
pretty()
# Define a generative function
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return 2 * v
- Generate a traced sample and construct choicemaps
There's an entire cookbook entry on this in choicemap_creation_selection.
key, subkey = jax.random.split(key)
trace = jax.jit(beta_bernoulli_process.simulate)(subkey, (0.5,))
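The returned trace bundles together the arguments, the sampled choices, their joint log-density (the score), and the return value. As a conceptual sketch in plain Python (MiniTrace is a hypothetical stand-in, not GenJAX's actual Trace type, which is richer):

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-in for a trace; GenJAX's real Trace type also records
# the generative function itself, subtraces, etc.
@dataclass
class MiniTrace:
    args: tuple      # arguments the model was called with
    choices: dict    # address -> sampled value, e.g. {"p": ..., "v": ...}
    score: float     # joint log-density of the choices
    retval: Any      # the model's return value


t = MiniTrace(args=(0.5,), choices={"p": 0.7, "v": 1}, score=-1.2, retval=2)
t.choices["v"]  # accessing a sampled value by address
```

The accessor methods demonstrated below (get_score, get_choices, get_args) read out exactly these pieces from the real trace.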
- Compute log probabilities
2.1 Print the log probability of the trace
trace.get_score()
2.2 Print the log probability of an observation encoded as a ChoiceMap under the model
assess returns both the log probability and the return value.
chm = C["p"].set(0.5) | C["v"].set(1)
args = (0.5,)
beta_bernoulli_process.assess(chm, args)
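Conceptually, assess sums the log-densities of every random choice in the map. A hand-rolled sketch in plain Python (assumption: bernoulli(p) is parameterized by the probability p; some GenJAX distributions use logits instead, so check the docs before relying on this):

```python
import math


# Hand-rolled joint log-density for the beta-bernoulli model above,
# mirroring what assess computes (under the parameterization assumption).
def joint_logpdf(p, v, u):
    # log Beta(p; a=1, b=u); log B(1, u) computed via lgamma
    log_beta_norm = math.lgamma(1.0) + math.lgamma(u) - math.lgamma(1.0 + u)
    log_p_p = (u - 1.0) * math.log1p(-p) - log_beta_norm
    # log Bernoulli(v; p)
    log_p_v = math.log(p) if v == 1 else math.log1p(-p)
    return log_p_p + log_p_v


joint_logpdf(0.5, 1, 0.5)  # log density of the choicemap {p: 0.5, v: 1}
```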
Note that the ChoiceMap should be complete, i.e. every random choice in the model must be given a value.
chm_2 = C["v"].set(1)
try:
    beta_bernoulli_process.assess(chm_2, args)
except MissingAddress as e:
    print(e)
p
- Generate a sample conditioned on the observations
We can also use a partial ChoiceMap as a constraint/observation and generate a full trace with these constraints.
key, subkey = jax.random.split(key)
partial_chm = C["v"].set(1) # Creates a ChoiceMap of observations
args = (0.5,)
trace, weight = beta_bernoulli_process.importance(
    subkey, partial_chm, args
)  # Runs importance sampling
This returns a pair containing the new trace and the log importance weight of the produced trace under the model given the constraints.
trace.get_choices()
weight
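A sketch of what this weight means in this model. Assuming the internal proposal fills in the unconstrained "p" from its prior (and that bernoulli(p) is parameterized by the probability p), the prior terms cancel and the weight reduces to the log-likelihood of the constrained choice:

```python
import math


# Hypothetical helper, not part of GenJAX: the importance weight we would
# expect here, given the value of "p" that actually landed in the trace.
# Assumption: unconstrained choices are proposed from the prior, so the
# weight is log Bernoulli(v; p).
def expected_weight(p_in_trace: float, v: int = 1) -> float:
    return math.log(p_in_trace) if v == 1 else math.log1p(-p_in_trace)


expected_weight(0.25)  # weight if the trace happened to sample p = 0.25
```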
- Update a trace.
We can also update a trace. This is, for instance, useful as a performance optimization in Metropolis-Hastings algorithms, where most of the trace often doesn't change between steps.
We first define a model for which changing the argument will force a change in the trace.
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return 2 * v
We then create a trace to be updated, together with the constraints.
key, subkey = jax.random.split(key)
jitted = jit(beta_bernoulli_process.simulate)
old_trace = jitted(subkey, (1.0,))
constraint = C["v"].set(1)
Now the update uses a form of incremental computation. It works by tracking the differences between the old and new values of the arguments. Just like for differentiation, this is achieved by providing, for each argument, a tuple containing the new value and its change relative to the old value.
If there's no change for an argument, the change is set to NoChange.
arg_diff = (Diff(1.0, NoChange),)
If there is any change, the change is set to UnknownChange.
arg_diff = (Diff(5.0, UnknownChange),)
We finally use the update method by passing it a key, the trace to be updated, and the update to be performed.
jitted_update = jit(beta_bernoulli_process.update)
key, subkey = jax.random.split(key)
new_trace, weight_diff, ret_diff, discard_choice = jitted_update(
    subkey, old_trace, constraint, arg_diff
)
We can compare the old and new values for the samples and notice that they have not changed.
old_trace.get_choices() == new_trace.get_choices()
We can also see that the weight has changed. In fact, we can check that the relation new_weight = old_weight + weight_diff holds.
weight_diff, old_trace.get_score() + weight_diff == new_trace.get_score()
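This weight_diff is the model-side term a Metropolis-Hastings acceptance ratio needs. A schematic accept step in plain Python (mh_accept is a hypothetical helper; a real kernel must also fold in the proposal's forward/backward correction terms):

```python
import math
import random


def mh_accept(weight_diff: float) -> bool:
    # Accept with probability min(1, exp(weight_diff)); the min(0.0, ...)
    # guard avoids overflow for large positive weight_diff.
    return random.random() < math.exp(min(0.0, weight_diff))


# If accepted, keep new_trace for the next step; otherwise keep old_trace.
```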
- A few more convenient methods
5.1 propose
It uses the same inputs as simulate but returns the sample, the score, and the return value.
key, subkey = jax.random.split(key)
sample, score, retval = jit(beta_bernoulli_process.propose)(subkey, (0.5,))
sample, score, retval
5.2 get_gen_fn
It returns the generative function that produced the trace.
trace.get_gen_fn()
5.3 get_args
It returns the arguments passed to the generative function used to produce the trace.
trace.get_args()
5.4 get_subtrace
It takes a StaticAddress as argument and returns the sub-trace of the trace rooted at that address.
subtrace = trace.get_subtrace("p")
subtrace, subtrace.get_choices()