Masking
I want more dynamic features, but JAX only accepts arrays with statically known sizes. What do I do?
import sys
if "google.colab" in sys.modules:
%pip install --quiet "genjax[genstudio]"
import jax
import jax.numpy as jnp
from PIL import Image
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import bernoulli, categorical, gen, normal, or_else, pretty
# Call GenJAX's `pretty` helper (presumably enables nicer notebook display
# of GenJAX objects — confirm against the GenJAX docs).
pretty()
# Root PRNG key for the notebook; later cells split it before each random call.
key = jax.random.key(0)
One classic trick is to encode all the options as an array and pick the desired value from the array with a dynamic index.
Here's a first example:
@gen
def model(i, means, vars):
    """Sample ``"x"`` from a normal whose parameters are looked up at index ``i``.

    All candidate parameters are provided up front as arrays, so their shapes
    stay static; only the index ``i`` used to pick from them is dynamic.
    """
    # NOTE(review): `vars` shadows the Python builtin, and it is passed as the
    # second argument of `normal` — confirm whether that argument is a
    # variance or a standard deviation.
    return normal(means[i], vars[i]) @ "x"
key, subkey = jax.random.split(key)
# Index 7 into ten candidate means (0.0 .. 9.0), each paired with a 1.0.
model_args = (7, jnp.arange(10, dtype=jnp.float32), jnp.ones(10))
model.simulate(subkey, model_args)
Now, what if there's a value we may or may not want to get depending on a dynamic value?
In this case, we can use masking. Let's look at an example in JAX.
# A 3x3 matrix holding 0..8 in row-major order.
non_masked = jnp.arange(9).reshape((3, 3))
non_masked
# Index arrays for the upper-triangular entries (diagonal included); indexing
# the matrix with them keeps exactly those entries, as a flat array.
mask = jnp.mask_indices(3, jnp.triu)
non_masked[mask]
We can use similar logic for generative functions in GenJAX.
Let's create an HMM using the scan combinator.
state_size = 10  # dimension of each state vector
length = 10  # number of HMM steps
# Identity matrix handed to `mv_normal` as its second argument in the step.
variance = jnp.eye(state_size)
key, subkey = jax.random.split(key)
# Random starting state, one standard-normal draw per dimension.
initial_state = jax.random.normal(subkey, shape=(state_size,))
@gen
def hmm_step(x):
    """Advance the chain by one step from state ``x``.

    Samples ``"new_x"`` from ``mv_normal(x, variance)``, where ``variance`` is
    the module-level matrix, and returns the sample.
    """
    return genjax.mv_normal(x, variance) @ "new_x"


# Apply the step `length` times in sequence (GenJAX `iterate_final` combinator).
hmm = hmm_step.iterate_final(n=length)
When we run it, we get a full trace.
# Wrap the simulate entry point in `jax.jit` before calling it.
jitted = jax.jit(hmm.simulate)
key, subkey = jax.random.split(key)
hmm_args = (initial_state,)
trace = jitted(subkey, hmm_args)
trace.get_choices()
To get the partial results in the HMM instead, we can use the masked version of iterate_final
as follows:
stop_at_index = 5
# One boolean per HMM step: True for the first `stop_at_index` steps, False
# afterwards. The mask has one entry per step, so it is sized by `length`
# (number of steps) rather than `state_size` (state dimension); the two only
# happen to share the value 10 in this notebook.
pairs = jnp.arange(length) < stop_at_index
# Masked variant of the iteration: takes (initial state, per-step mask).
masked_hmm = hmm_step.masked_iterate_final()
key, subkey = jax.random.split(key)
choices = masked_hmm.simulate(subkey, (initial_state, pairs)).get_choices()
choices
We see that we obtain a filtered choice map, with a selection representing the boolean mask array. Within the filtered choice map, we have a static choice map where all the results are computed, without the mask applied to them. This is generally what happens behind the scenes in GenJAX: results tend to be computed and then ignored, which is often more efficient on the GPU than trying too eagerly to avoid doing the computations in the first place.
Let's now use it in a bigger computation where the masking index is dynamic and comes from a sampled value.
@gen
def larger_model(init, probs):
    """Run the masked HMM with a randomly sampled number of active steps.

    Samples an index at address ``"i"`` from ``categorical(probs)``, builds a
    boolean step mask that is True for positions below that index, and runs
    ``masked_hmm`` from ``init`` under that mask at address ``"x"``.

    Returns:
        The value returned by the masked HMM.
    """
    i = categorical(probs) @ "i"
    # One mask entry per HMM step, so size it with the shared `length`
    # constant instead of a hard-coded 10.
    mask = jnp.arange(length) < i
    x = masked_hmm(init, mask) @ "x"
    return x
key, subkey = jax.random.split(key)
init = jax.random.normal(subkey, (state_size,))
# Weights proportional to the index, normalised to sum to 1. `jnp.sum`
# reduces in a single array operation, whereas the Python builtin `sum`
# iterates over the JAX array element by element.
unnormalized = jnp.arange(state_size)
probs = unnormalized / jnp.sum(unnormalized)
# NOTE(review): GenJAX's `categorical` may interpret its argument as logits
# rather than probabilities — confirm, and pass `jnp.log(probs)` if so.
key, subkey = jax.random.split(key)
choices = larger_model.simulate(subkey, (init, probs)).get_choices()
choices
We have already seen how to use conditionals in GenJAX models in the conditionals
notebook. Behind the scenes, it uses the same logic with masks.
@gen
def cond_model(p):
    """Sample ``"v1"`` under ``"cond"`` from one of two Bernoulli branches.

    The branch is selected by the predicate ``p > 0``: one branch calls
    ``bernoulli(p)`` and the other ``bernoulli(-p)``. Both branches are given
    the same single argument ``p``.
    """
    first_branch = gen(lambda p: bernoulli(p) @ "v1")
    second_branch = gen(lambda p: bernoulli(-p) @ "v1")
    switch = or_else(first_branch, second_branch)
    # Arguments are (predicate, args for first branch, args for second branch).
    return switch(p > 0, (p,), (p,)) @ "cond"
key, subkey = jax.random.split(key)
# Simulate with p = 0.5, then inspect the recorded choices.
cond_trace = cond_model.simulate(subkey, (0.5,))
choices = cond_trace.get_choices()
choices
We see that both branches get evaluated and a mask is applied to each branch; the mask's value depends on the evaluation of the boolean predicate `pred`.
What's happening behind the scenes for masked values in the trace? Simply put, even though the system computes values, they are ignored w.r.t. the math of inference.
We can check this on a simple example with two versions of a model, where one has an extra masked variable `y`.
Let's first define the two versions of the model.
@gen
def simple_model():
    """Baseline model: a single ``normal(0.0, 1.0)`` choice at address ``"x"``."""
    return normal(0.0, 1.0) @ "x"
@gen
def submodel():
    """Nested model: a single ``normal(0.0, 1.0)`` choice at address ``"y"``."""
    return normal(0.0, 1.0) @ "y"
@gen
def model_with_mask():
    """Same ``"x"`` choice as the simple model, plus a masked call to ``submodel``.

    The masked submodel is always invoked with the flag ``False`` and its
    result is discarded; only ``x`` is returned.
    """
    x = normal(0.0, 1.0) @ "x"
    masked_submodel = submodel.mask()
    # The flag is the constant False, so the nested call is always masked out.
    _ = masked_submodel(False) @ "y"
    return x
@gen
def proposal(_: genjax.Target):
    """Propose ``"x"`` from ``normal(3.0, 1.0)``; the target argument is unused."""
    return normal(3.0, 1.0) @ "x"
Let's now test that on the same key, they return the exact same score:
key, subkey = jax.random.split(key)
# Unconstrained targets for the two model variants.
simple_target = genjax.Target(simple_model, (), C.n())
masked_target = genjax.Target(model_with_mask, (), C.n())
# Importance-sampling algorithms over each target, using the same proposal.
simple_alg = genjax.smc.Importance(simple_target, q=proposal.marginal())
masked_alg = genjax.smc.Importance(masked_target, q=proposal.marginal())
# TODO: something's fishy here with the math. Get the same whether I mask or not.
# The same subkey is deliberately used for both runs so the scores are comparable.
simple_score = simple_alg.simulate(subkey, (simple_target,)).get_score()
masked_score = masked_alg.simulate(subkey, (masked_target,)).get_score()
simple_score == masked_score
masked_alg.simulate(subkey, (masked_target,))
Let's see a final example for an unknown number of objects that may evolve over time.
For this, we can use `vmap` over a masked object, and we get to choose which ones are masked or not.
Let's create a model consisting of a 2D image where each pixel is traced.
@gen
def single_pixel():
    """Model one pixel as a ``normal(0.0, 1.0)`` choice at address ``"pixel"``."""
    return normal(0.0, 1.0) @ "pixel"


# Mask the pixel model, then vmap it twice over axis 0 of its argument so the
# resulting model takes a 2D array of mask flags (one per pixel).
masked_pixel = single_pixel.mask()
pixel_row = masked_pixel.vmap(in_axes=(0,))
image_model = pixel_row.vmap(in_axes=(0,))
Let's create a circular mask around the image.
import matplotlib.animation as animation
import matplotlib.pyplot as plt
def create_circle_mask(size=200, center=None, radius=80):
    """Return a ``size`` x ``size`` boolean mask that is True inside a disc.

    Args:
        size: Side length of the square grid.
        center: Disc centre as ``(x, y)``, i.e. (column, row) in the returned
            array. Defaults to the middle of the grid.
        radius: Disc radius; cells at exactly this distance are included.

    Returns:
        A 2D boolean array indexed as ``[row, column]``.
    """
    if center is None:
        half = size // 2
        center = (half, half)
    # Open grids: `rows` has shape (size, 1) and `cols` has shape (1, size),
    # so the arithmetic below broadcasts to (size, size).
    rows, cols = jnp.ogrid[:size, :size]
    dx = cols - center[0]
    dy = rows - center[1]
    return jnp.sqrt(dx**2 + dy**2) <= radius
# Build the mask with the defaults: 200x200 grid, centred disc of radius 80.
circle_mask = create_circle_mask()
# Display the boolean mask as a grayscale image.
plt.imshow(circle_mask, cmap="gray")
plt.show()
Let's now sample from the masked image and play with the mask and inference.
key, subkey = jax.random.split(key)
tr = image_model.simulate(subkey, (circle_mask,))
# Look up the per-pixel masked choices once; they carry both a flag and a value.
pixels = tr.get_choices()[:, :, "pixel"]
flag = pixels.flag
# Multiply by the flag so pixels outside the mask are zeroed in the display.
im = flag * pixels.value
plt.imshow(im, cmap="gray", vmin=0, vmax=1)
plt.show()
We can create a small animation by updating the mask over time, using GenJAX's update mechanism (`tr.edit` with an `Update` request) to ensure that the probabilistic parts are properly taken into account.
number_iter = 10
fig, ax = plt.subplots()

# Load the background image.
image_path = "./ending_dynamic_computation.png"  # Update with your image path
image = Image.open(image_path)
# Convert to grayscale, then resize to the dimensions of the sampled image
# `im`. PIL's `resize` takes (width, height), hence the reversed array shape.
image = image.convert("L")
image = image.resize(im.shape[1::-1])
# Scale the pixel values by 1/255 into a JAX array.
image_array = jnp.array(image) / 255.0

# The update request carries no new constraints and is identical on every
# iteration, so it is built once outside the loop.
constraints = C.n()
update_problem = genjax.Update(constraints)

images = []
for i in range(number_iter):
    key, subkey = jax.random.split(key)
    # Grow the disc by 10 pixels of radius per frame (radius 0 on frame 0).
    new_circle_mask = create_circle_mask(radius=10 * i)
    arg_diff = (genjax.Diff(new_circle_mask, genjax.UnknownChange),)
    # Use the freshly split `subkey`. Passing `key` here would consume the
    # same key that is split again on the next iteration (PRNG key reuse).
    tr, _, _, _ = tr.edit(subkey, update_problem, arg_diff)
    pixels = tr.get_choices()[:, :, "pixel"]
    flag = pixels.flag
    # Damped noise on top of the background image, zeroed outside the mask.
    new_im = flag * (pixels.value / 5.0 + image_array)
    images.append([ax.imshow(new_im, cmap="gray", vmin=0, vmax=1, animated=True)])

ani = animation.ArtistAnimation(fig, images, interval=200, blit=True, repeat_delay=1000)
# Save the animation as a GIF
ani.save("masked_image_animation.gif", writer="pillow")
# Display the animation in the notebook
from IPython.display import HTML

HTML(ani.to_jshtml())