IndexRequest
Speed Gains Part 2: Optimizing updates for vmap
import timeit
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import (
IndexRequest,
StaticRequest,
Update,
gen,
normal,
pretty,
)
from genjax._src.core.pytree import Const
# NOTE(review): presumably switches on GenJAX's pretty-printing for notebook output — confirm.
pretty()
# Root PRNG key for the notebook; later cells split it before each random operation.
key = jax.random.key(0)
As we discussed in the previous cookbook entries, a main point of update is to be used for incremental computation: update performs algebraic simplifications of the logpdf-ratios computed in the weight that it returns. This is tracked through the Diff system.
A limitation of the current automation is that if an address "x" has a tensor value, and any index of "x" changes, the system will consider that "x" has changed without capturing a finer description of what exactly changed.
However, we can manually specify how something has changed in a more specific way.
@gen
def model(size_model_const: Const[int]):
    """Generative model with a scalar latent, three vector latents and one observation.

    Args:
        size_model_const: Length of the vector-valued choices "a", "b" and
            "c", wrapped in `Const`; it is unwrapped here and used as the
            array size.

    Returns:
        The value sampled at address "obs".
    """
    n = size_model_const.unwrap()
    zeros = jnp.zeros(n)
    ones = jnp.ones(n)

    x = normal(0.0, 1.0) @ "x"
    # Three vmapped normals with identical arguments, each under its own address.
    a = normal.vmap()(zeros, ones) @ "a"
    b = normal.vmap()(zeros, ones) @ "b"
    c = normal.vmap()(zeros, ones) @ "c"

    # First argument of the observation's normal: the sum of every latent value.
    obs_loc = jnp.sum(a) + jnp.sum(b) + jnp.sum(c) + x
    return normal(obs_loc, 5.0) @ "obs"
Let's create a trace from our model.
# Constrain only the "obs" address; `importance` fills in the rest of the trace.
obs = C["obs"].set(1.0)

# Length of each of the vector-valued choices "a", "b" and "c".
size_model = 10_000
args = (Const(size_model),)

key, subkey = jax.random.split(key)
tr, _ = model.importance(subkey, obs, args)
Let's first see an equivalent way to do what update does.
Just like update generalizes importance, there is yet another more general interface, edit, which generalizes update.
We will go into the details of edit in a follow up cookbook.
For now, let's see the equivalent of update using edit. For this, we introduce a Request to change the trace.
edit will then answer the Request and change the trace following the logic of the request.
To mimic update, we will perform an Update request.
# New value for the whole "a" vector, applied through both routes below.
change_in_value_for_a = jnp.ones(size_model)

# Route 1: the usual `update`, constraining "a" and declaring the arguments unchanged.
constraints = C["a"].set(change_in_value_for_a)
argdiffs = genjax.Diff.no_change(args)
key, subkey = jax.random.split(key)
new_tr1, _, _, _ = tr.update(subkey, constraints, argdiffs)

# Route 2: the same change expressed as an `Update` request answered by `edit`.
request = StaticRequest({"a": Update(C.v(change_in_value_for_a))})
key, subkey = jax.random.split(key)
new_tr2, _, _, _ = request.edit(subkey, tr, args)

# Compare the choice maps of both resulting traces leaf by leaf.
choices_agree = jax.tree.map(
    jnp.allclose, new_tr1.get_choices(), new_tr2.get_choices()
)
jax.tree_util.tree_all(choices_agree)
Now let's see how we can efficiently change the value of "a" at a specific index.
For that, we create a more specific Request called an IndexRequest. This request expects another request for what to do at the given index.
# Address "a" -> index 3 -> update the value stored there to 42.0.
set_index_3 = IndexRequest(jnp.array(3), Update(C.v(42.0)))
request = StaticRequest({"a": set_index_3})
key, subkey = jax.random.split(key)
new_tr, _, _, _ = request.edit(subkey, tr, args)

# Sanity check: exactly one entry of "a" in the new choice map should equal 42.
updated_a = new_tr.get_choices()["a"]
jnp.sum(updated_a == 42.0) == 1
Now, let's compare the 3 options: naive density ratio computation vs update vs IndexRequest.
For this, we will do a comparison of doing an MH move on a specific variable in the model as we did in the previous cookbook, but this time for a specific index of the traced value "a".
We will also compare how the three approaches scale as the size of the model grows.
# Index of the vector-valued choice "a" that the Metropolis-Hastings move rewrites.
IDX_WHERE_CHANGE_A = 3


@gen
def rejuv_a(a):
    """Proposal that samples a new value at address "a" from `normal(a, 1.0)`.

    Args:
        a: Value passed as the first argument of `normal`.

    Returns:
        The value sampled at address "a".
    """
    return normal(a, 1.0) @ "a"
def compute_ratio_slow(key, fwd_choice, fwd_weight, model_args, chm):
    """Compute the MH log-ratio by scoring the old and the new choice map in full.

    No incremental computation is used: `model.assess` runs once on `chm` and
    once on a choice map whose "a" entry at `IDX_WHERE_CHANGE_A` is replaced
    by the proposed value.

    Args:
        key: Unused in this function (the sibling `compute_ratio_*` functions
            take the same parameters).
        fwd_choice: Choice map produced by the proposal; its "a" entry is the
            proposed value.
        fwd_weight: Weight returned by the forward proposal.
        model_args: Arguments forwarded to `model.assess`.
        chm: Current choice map of the model.

    Returns:
        `score_new - score_old - fwd_weight + bwd_weight`.
    """
    proposed = fwd_choice["a"]

    # Overlay the modified "a" on the current choices.
    # NOTE(review): assumes the left operand of `|` wins on overlapping addresses — confirm.
    proposed_a = chm["a"].at[IDX_WHERE_CHANGE_A].set(proposed)
    proposed_chm = C["a"].set(proposed_a) | chm

    score_old, _ = model.assess(chm, model_args)
    score_new, _ = model.assess(proposed_chm, model_args)

    # Reverse move: score the old value under the proposal built from the new one.
    reverse_choice = C["a"].set(chm["a", IDX_WHERE_CHANGE_A])
    bwd_weight, _ = rejuv_a.assess(reverse_choice, (proposed,))

    return score_new - score_old - fwd_weight + bwd_weight
def compute_ratio_fast(key, fwd_choice, fwd_weight, model_args, trace):
    """Compute the MH log-ratio with `model.update` applied to the whole "a" vector.

    The constraint replaces the full "a" array (with one entry changed), so
    the update is told that "a" changed without saying at which index.

    Args:
        key: PRNG key passed to `model.update`.
        fwd_choice: Choice map produced by the proposal; its "a" entry is the
            proposed value.
        fwd_weight: Weight returned by the forward proposal.
        model_args: Model arguments, marked as unchanged for the update.
        trace: Current model trace.

    Returns:
        `weight - fwd_weight + bwd_weight`, where `weight` comes from `model.update`.
    """
    proposed = fwd_choice["a"]

    updated_a = trace.get_choices()["a"].at[IDX_WHERE_CHANGE_A].set(proposed)
    _, weight, _, discard = model.update(
        key,
        trace,
        C["a"].set(updated_a),
        genjax.Diff.no_change(model_args),
    )

    # Reverse move: the value read back from `discard` is scored under the
    # proposal built from the new value.
    reverse_choice = C["a"].set(discard["a", IDX_WHERE_CHANGE_A])
    bwd_weight, _ = rejuv_a.assess(reverse_choice, (proposed,))

    return weight - fwd_weight + bwd_weight
def compute_ratio_very_fast(key, fwd_choice, fwd_weight, model_args, trace):
    """Compute the MH log-ratio with an `IndexRequest` targeting a single index of "a".

    Args:
        key: PRNG key passed to `request.edit`.
        fwd_choice: Choice map produced by the proposal; its "a" entry is the
            proposed value.
        fwd_weight: Weight returned by the forward proposal.
        model_args: Model arguments forwarded to `request.edit`.
        trace: Current model trace.

    Returns:
        `weight - fwd_weight + bwd_weight`, where `weight` comes from `request.edit`.
    """
    proposed = fwd_choice["a"]

    # Describe the change precisely: only index IDX_WHERE_CHANGE_A of "a" is updated.
    index_update = IndexRequest(jnp.array(IDX_WHERE_CHANGE_A), Update(C.v(proposed)))
    request = StaticRequest({"a": index_update})
    _, weight, _, _ = request.edit(key, trace, model_args)

    # Reverse move: the old value is read from the input trace and scored
    # under the proposal built from the new value.
    reverse_choice = C["a"].set(trace.get_choices()["a", IDX_WHERE_CHANGE_A])
    bwd_weight, _ = rejuv_a.assess(reverse_choice, (proposed,))

    return weight - fwd_weight + bwd_weight
def metropolis_hastings_move(key, trace, which_move):
    """Run one Metropolis-Hastings move on index `IDX_WHERE_CHANGE_A` of "a".

    Args:
        key: PRNG key; split once each for the proposal, the ratio
            computation and the accept/reject draw.
        trace: Current model trace.
        which_move: Selects the ratio computation: 0 uses
            `compute_ratio_slow`, 1 uses `compute_ratio_fast`, and any other
            value uses `compute_ratio_very_fast`. It is compared with Python
            `==`, and the jitted caller in this file marks it static.

    Returns:
        A choice map holding only "a": the proposed array if the move is
        accepted, otherwise the current one.
    """
    model_args = trace.get_args()

    # Forward proposal built from the current value at the target index.
    current_value = trace.get_choices()["a", IDX_WHERE_CHANGE_A]
    key, subkey = jax.random.split(key)
    fwd_choice, fwd_weight, _ = rejuv_a.propose(subkey, (current_value,))

    # Pick the ratio implementation; only the slow one takes a choice map
    # instead of a trace.
    key, subkey = jax.random.split(key)
    if which_move == 0:
        ratio_fn, state = compute_ratio_slow, trace.get_choices()
    elif which_move == 1:
        ratio_fn, state = compute_ratio_fast, trace
    else:
        ratio_fn, state = compute_ratio_very_fast, trace
    log_ratio = ratio_fn(subkey, fwd_choice, fwd_weight, model_args, state)

    # Candidate results: "a" unchanged, and "a" with the proposed value written in.
    current_a = trace.get_choices()["a"]
    old_chm = C["a"].set(current_a)
    new_chm = C["a"].set(current_a.at[IDX_WHERE_CHANGE_A].set(fwd_choice["a"]))

    # Accept when log(u) < log-ratio, with u drawn from `jax.random.uniform`.
    key, subkey = jax.random.split(key)
    log_u = jnp.log(jax.random.uniform(subkey))
    return jax.lax.cond(log_u < log_ratio, lambda: new_chm, lambda: old_chm)
model_sizes = [1000, 10000, 100000, 1000000, 10000000, 100000000]
slow_times = []
fast_times = []
very_fast_times = []


def _average_move_seconds(move_fn, key, trace, which_move, num_trials):
    """Return the mean wall-clock seconds of one `move_fn(key, trace, which_move)` call.

    JAX dispatches work asynchronously, so each call is wrapped in
    `jax.block_until_ready`; without it `timeit` would stop the clock when
    the call has been enqueued rather than when its result is available.

    Args:
        move_fn: Jitted callable to time.
        key: PRNG key passed to `move_fn` on every call.
        trace: Trace passed to `move_fn` on every call.
        which_move: Static selector forwarded to `move_fn`.
        num_trials: Number of timed calls to average over.

    Returns:
        Total measured time divided by `num_trials`, in seconds.
    """

    def run_once():
        return jax.block_until_ready(move_fn(key, trace, which_move))

    # Warm-up run to trigger jit compilation so that it is not timed.
    run_once()
    return timeit.timeit(run_once, number=num_trials) / num_trials


# `which_move` (positional argument 2) is static: each value compiles its own version.
jitted = jax.jit(metropolis_hastings_move, static_argnums=2)
# The observation constraint does not depend on the model size, so build it once.
obs = C["obs"].set(1.0)

for model_size in model_sizes:
    # The larger models get far fewer timed repetitions.
    num_trials = 10000 if model_size <= 1000000 else 200
    model_args = (Const(model_size),)

    key, subkey = jax.random.split(key)
    # create a trace from the model of the right size
    tr, _ = jax.jit(model.importance, static_argnums=2)(subkey, obs, model_args)

    # measure time for each algorithm
    slow_times.append(_average_move_seconds(jitted, subkey, tr, 0, num_trials))
    fast_times.append(_average_move_seconds(jitted, subkey, tr, 1, num_trials))
    very_fast_times.append(_average_move_seconds(jitted, subkey, tr, 2, num_trials))
plt.figure(figsize=(20, 5))

# (legend label, measured average times in seconds), in plotting order.
timing_series = [
    ("No incremental computation", slow_times),
    ("Default incremental computation", fast_times),
    ("Optimized incremental computation", very_fast_times),
]

# One subplot per half of the model sizes.
panels = [
    (
        1,
        True,
        "Average Execution Time of MH move for different model sizes (First Half)",
    ),
    (
        2,
        False,
        "Average Execution Time of MH move for different model sizes (Second Half)",
    ),
]

for position, is_first_half, title in panels:
    plt.subplot(1, 2, position)

    size_split = len(model_sizes) // 2
    sizes = model_sizes[:size_split] if is_first_half else model_sizes[size_split:]

    for label, times in timing_series:
        time_split = len(times) // 2
        window = times[:time_split] if is_first_half else times[time_split:]
        # Seconds -> milliseconds for the y axis.
        plt.plot(sizes, [t * 1000 for t in window], marker="o", label=label)

    plt.xscale("log")
    plt.xlabel("Argument (n)")
    plt.ylabel("Average Time (milliseconds)")
    plt.title(title)
    plt.grid(True)
    plt.legend()

plt.show()