Choice maps
import sys
if "google.colab" in sys.modules:
%pip install --quiet "genjax[genstudio]"
import jax
import jax.numpy as jnp
import jax.random as random
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import (
bernoulli,
beta,
gen,
mix,
normal,
or_else,
pretty,
repeat,
scan,
vmap,
)
# Switch genjax into its `pretty()` display mode for the outputs below.
pretty()

# Root PRNG key for the notebook; cells below split it before sampling.
key = random.key(0)
Choice maps are dictionary-like data structures that accumulate the random choices produced by generative functions and traced by the system, i.e. the choices marked with an address such as `@ "p"` in generative functions.
They also serve as a set of constraints/observations when one tries to do inference: given the constraints, inference provides plausible values to complete a choice map into a full trace of a generative model (one value per traced random sample).
@gen
def beta_bernoulli_process(u):
    """Draw "p" from beta(1.0, u), then "v" from bernoulli(p); return 2 * v."""
    prob = beta(1.0, u) @ "p"
    outcome = bernoulli(prob) @ "v"
    return 2 * outcome
Simulating from a model produces a trace, which contains a choice map.
# Draw one trace with u = 0.5, JIT-compiling `simulate` first.
key, subkey = jax.random.split(key)
jitted_simulate = jax.jit(beta_bernoulli_process.simulate)
trace = jitted_simulate(subkey, (0.5,))
From that trace, we can recover the choice map with `get_choices()` (below it is called twice, and both calls give the same choice map):
# Display the trace's choice map from two separate `get_choices()` calls.
(trace.get_choices(), trace.get_choices())
We can also print specific subparts of the choice map.
# Index the choice map by address to read the single choice stored at "p".
choices = trace.get_choices()
choices["p"]
Then, we can create a choice map of observations and perform various operations on it. We can set the value at an address in the choice map. For instance, we can combine two choice maps with `|`, which behaves similarly to the union of two dictionaries.
# Build one single-address choice map per variable, then combine them with `|`.
chm_p = C["p"].set(0.5)
chm_v = C["v"].set(1)
chm = chm_p | chm_v
chm
A couple of extra ways to achieve the same result.
# Chained at/set, mirroring JAX's `array.at[index].set(value)` update pattern.
chm_equiv_1 = C["p"].set(0.5).at["v"].set(1)
# The same choices given as an address -> value dictionary.
chm_equiv_2 = C.d({"p": 0.5, "v": 1})
assert chm == chm_equiv_1
assert chm_equiv_1 == chm_equiv_2
This also works for hierarchical addresses:
# A multi-part key addresses a nested location: "v" under "p".
chm = C["p", "v"].set(1)
# The same structure spelled out by nesting one dict-built choice map in another.
inner = C.d({"v": 1})
eq_chm = C.d({"p": inner})
assert chm == eq_chm
chm
We can also directly create a choice map that holds a single value, without any address:
# `C.v` wraps a bare value as a choice map; no address is supplied here.
chm = C.v(5.0)
chm
We can also create an empty choice map:
# `C.n()` builds a choice map containing no choices.
chm = C.n()
chm
Choice maps can also be built iteratively, by merging in one choice at a time:
# Start from an empty choice map and merge in one addressed choice per step.
# `|` is the supported merge operator; `^` is deprecated since genjax 0.8.0
# and emits a DeprecationWarning on every call.
chm = C.n()
for i in range(10):
    chm = chm | C[f"p{i}"].set(i)
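A more JAX-friendly way to build a similar choice map uses `jax.vmap`. Note that it uses integer indices as addresses (with float values) rather than the string addresses `"p0"`, …, `"p9"`: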
def _indexed_choice(idx):
    # One-entry choice map at integer address `idx`, holding idx as a float.
    return C[idx].set(idx.astype(float))


# vmap over the indices yields a single batched choice map.
chm = jax.vmap(_indexed_choice)(jnp.arange(10))
In fact, we can directly use NumPy-style slice notation to create such an indexed choice map:
# A `:` slice sets every indexed entry at once from an array of values.
indexed_values = jnp.arange(10.0)
chm = C[:].set(indexed_values)
chm
For a nested `vmap` combinator, creating a choice map can be a bit trickier.
@gen
def noisy_pixel(pixel):
    """Sample one pixel from normal(pixel, 1.0) at address "new_pixel"."""
    return normal(pixel, 1.0) @ "new_pixel"


# The outer vmap maps over axis 0 (rows) of the image; the inner vmap maps
# over the entries of each row.
per_row = genjax.vmap(in_axes=(0,))(noisy_pixel)
sample_image = genjax.vmap(in_axes=(0,))(per_row)

image = jnp.zeros([4, 4], dtype=jnp.float32)
key, subkey = jax.random.split(key)
trace = sample_image.simulate(subkey, (image,))
trace.get_choices()
Creating a choice map with a few values is simple.
# Constrain two individual pixels of the 4x4 image. `|` merges the two choice
# maps; it replaces the deprecated `^` operator (deprecated since genjax 0.8.0).
chm = C[1, 2, "new_pixel"].set(1.0) | C[0, 2, "new_pixel"].set(1.0)
key, subkey = jax.random.split(key)
tr, w = jax.jit(sample_image.importance)(subkey, chm, (image,))
w
However, because of the nested `vmap`, the address hierarchy can sometimes lead to unintuitive results, e.g. because there is no bounds check on the address. Below we seemingly add a new constraint (at index 5 on the second axis of a 4×4 image), but we obtain the same weight as before, meaning that the new choice was not used for inference.
# Index 5 lies outside the 4x4 image and is not bounds-checked, so this extra
# constraint is silently ignored and the weight matches the previous cell.
# `|` replaces the deprecated `^` merge operator.
chm = chm | C[1, 5, "new_pixel"].set(1.0)
tr, w = jax.jit(sample_image.importance)(
    subkey, chm, (image,)
)  # reusing the key to make comparisons easier
w
A different way to create a choice map that is compatible with the nested `vmap` in this case is to slice both indexed axes:
# Slicing both vmapped axes constrains every pixel in one step.
all_ones = jnp.ones((4, 4), dtype=jnp.float32)
chm = C[:, :, "new_pixel"].set(all_ones)
key, subkey = jax.random.split(key)
importance_fn = jax.jit(sample_image.importance)
tr, w = importance_fn(subkey, chm, (image,))
w
More generally, some combinators introduce an `Indexed` choice map. These are mainly `vmap` and `scan`, as well as those derived from these two, such as `iterate` and `repeat`.
An `Indexed` choice map introduces an integer level into the hierarchy of addresses, at the place where the combinator is used. For instance:
@genjax.gen
def submodel():
    """Draw a vector "x" of 50 exponentials with parameters 1.0, 2.0, ..., 50.0."""
    params = 1.0 + jnp.arange(50, dtype=jnp.float32)
    return genjax.exponential.vmap()(params) @ "x"
@genjax.gen
def model():
    """Call `submodel` five times via `repeat`, recording choices under "xs"."""
    repeated_submodel = submodel.repeat(n=5)
    xs = repeated_submodel() @ "xs"
    return xs
# Sample a full trace of `model`, then display its choice map.
key, subkey = jax.random.split(key)
tr = model.simulate(subkey, ())
chm = tr.get_choices()
chm
In this case, we can create a hierarchical choice map as follows:
# Address path: "xs" -> repeat index -> "x" -> vmap index. Both integer levels
# are sliced with `:` so a (5, 50) array fills every entry.
observed = jnp.ones((5, 50))
chm = C["xs", :, "x", :].set(observed)
key, subkey = jax.random.split(key)
model.importance(subkey, chm, ())
We can also construct an indexed choice map with more than one variable in it using the following syntax:
# Fixed parameters used by `step` below.
_phi, _q, _beta, _r = (0.9, 1.0, 0.5, 1.0)


@genjax.gen
def step(state):
    """One transition of the state (x, z).

    Samples "x" from normal(_phi * x_prev, _q), sets z = _beta * z_prev + x,
    samples "y" from normal(z, _r), and returns the new state (x, z).
    """
    prev_x, prev_z = state
    new_x = genjax.normal(_phi * prev_x, _q) @ "x"
    new_z = _beta * prev_z + new_x
    genjax.normal(new_z, _r) @ "y"
    return (new_x, new_z)
max_T = 20
model = step.iterate_final(n=max_T)

# Observed values for the max_T steps, derived from max_T rather than a
# hard-coded 20: x is shifted by +1 on steps 10-14, and y is shifted by +1
# from step 15 to the end.
steps = jnp.arange(max_T)
x_range = 1.0 * jnp.where((steps >= 10) & (steps < 15), steps + 1, steps)
y_range = 1.0 * jnp.where(steps >= 15, steps + 1, steps)

# One choice map holding both "x" and "y"; each step index is paired with the
# corresponding entries of x_range and y_range.
xy = C["x"].set(x_range).at["y"].set(y_range)
chm4 = C[jnp.arange(max_T)].set(xy)
chm4
key, subkey = jax.random.split(key)
# Initial state (x_prev, z_prev) = (0.5, 0.5).
model.importance(subkey, chm4, ((0.5, 0.5),))
Accessing the right elements in the trace can become non-trivial when one creates hierarchical generative functions. Here are minimal examples showing how to select choices for each combinator.
# For `or_else` combinator
@gen
def model(p):
    """Sample under "s" from one of two branches, selected by the flag `p > 0`.

    The first branch records its draw at "v1", the second at "v2".
    """

    @gen
    def positive_branch(q):
        return bernoulli(q) @ "v1"

    @gen
    def negative_branch(q):
        return bernoulli(-q) @ "v2"

    chooser = or_else(positive_branch, negative_branch)
    return chooser(p > 0, (p,), (p,)) @ "s"


key, subkey = jax.random.split(key)
trace = jax.jit(model.simulate)(subkey, (0.5,))
# The branch's choice is nested under the combinator's address "s".
trace.get_choices()["s", "v1"]
# For `vmap` combinator
@gen
def unit_noise_pixel(pixel):
    return normal(pixel, 1.0) @ "new_pixel"


sample_image = vmap(in_axes=(0,))(vmap(in_axes=(0,))(unit_noise_pixel))
image = jnp.zeros([2, 3], dtype=jnp.float32)
key, subkey = jax.random.split(key)
trace = sample_image.simulate(subkey, (image,))
# Slice both vmapped axes to read every sampled pixel.
trace.get_choices()[:, :, "new_pixel"]
# For `scan_combinator`
@scan(n=10)
@gen
def hmm(x, c):
    """One scan step: latent "z" around the carry x, observation "y" around z.

    Returns (y, None): y is the next carry and the step emits no output.
    """
    latent = normal(x, 1.0) @ "z"
    observed = normal(latent, 1.0) @ "y"
    return observed, None


key, subkey = jax.random.split(key)
# Initial carry 0.0; None is passed as the scanned-over input.
trace = hmm.simulate(subkey, (0.0, None))
hmm_choices = trace.get_choices()
# Every "z" via a slice over the scan index, and the single "y" at index 3.
hmm_choices[:, "z"], hmm_choices[3, "y"]
# For `repeat_combinator`
@repeat(n=10)
@gen
def model(y):
    """Sample "x" from normal(y, 0.01), then "y" from normal(x, 0.01)."""
    x = normal(y, 0.01) @ "x"
    y_new = normal(x, 0.01) @ "y"
    return y_new


key, subkey = jax.random.split(key)
trace = model.simulate(subkey, (0.3,))
# One "x" per repetition, gathered with a slice over the repeat index.
trace.get_choices()[:, "x"]
# For `mixture_combinator`
@gen
def mixture_model(p):
    """Return a + z, where "z" ~ normal(p, 1.0) and "a" comes from a 3-way mixture.

    Each mixture component is a normal centred at p with scale 1.0, 2.0 or 3.0.
    """
    z = normal(p, 1.0) @ "z"
    # NOTE(review): passed to `mix` under the name `logits`, but the values sum
    # to 1 like probabilities; confirm which parameterization was intended.
    logits = (0.3, 0.5, 0.2)
    components = (
        gen(lambda p: normal(p, 1.0) @ "x1"),
        gen(lambda p: normal(p, 2.0) @ "x2"),
        gen(lambda p: normal(p, 3.0) @ "x3"),
    )
    # Every component receives the same argument tuple (p,).
    component_args = [(p,)] * len(components)
    a = mix(*components)(logits, *component_args) @ "a"
    return a + z


key, subkey = jax.random.split(key)
trace = mixture_model.simulate(subkey, (0.4,))
# The combinator uses a fixed address "mixture_component" for the components of the mixture model.
trace.get_choices()["a", "mixture_component"]
Similarly, creating a batch of traces with `jax.vmap` will in general not produce a valid batched trace; e.g. the score will not be a single float. Such batches can still be very useful for inference, though.
@genjax.gen
def random_walk_step(prev, _):
    """One random-walk step: sample "x" from normal(prev, 1.0); carry it forward."""
    x = genjax.normal(prev, 1.0) @ "x"
    return x, None


random_walk = random_walk_step.scan(n=1000)
init = 0.5
# Advance `key` before splitting so this batch does not reuse the same keys as
# later cells that also split `key`.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, 10)
# Batch `simulate` over the keys (axis 0); the arguments are shared (None).
trs = jax.vmap(random_walk.simulate, (0, None))(keys, (init, None))
# JAX arrays are never Python `float`s, so an isinstance(..., float) test would
# fail even for a scalar score. Check the score's rank instead: a score that is
# not 0-d means there is one value per batch element, not a single score.
score = trs.get_score()
if jnp.ndim(score) != 0:
    print("Expected a float value for the score.")
Expected a float value for the score.
However, with a little extra step we can recover information in individual traces.
def _choices_of(tr):
    # Choice map of a single trace; vmap applies this across the batch.
    return tr.get_choices()


jax.vmap(_choices_of)(trs)
Note that this limitation depends on the model, and the simpler approach may work anyway for some classes of models.
# Batch `simulate` of the most recently defined `model` over keys (axis 0),
# sharing the same arguments across the batch (None), then JIT-compile it.
jitted = jax.jit(jax.vmap(model.simulate, in_axes=(0, None)))
# Advance `key` first; splitting it directly would reproduce keys already used.
key, subkey = jax.random.split(key)
keys = random.split(subkey, 10)
traces = jitted(keys, (0.5,))
traces.get_choices()