IndexRequest
Speed Gains Part 2: Optimizing updates for vmap
import timeit
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import (
IndexRequest,
StaticRequest,
Update,
gen,
normal,
pretty,
)
from genjax._src.core.pytree import Const
# Call GenJAX's `pretty()` helper before running the examples below.
pretty()
# Root PRNG key; each random step below splits a fresh subkey off it.
key = jax.random.key(seed=0)
As discussed in the previous cookbook entries, a main purpose of `update` is incremental computation: `update` algebraically simplifies the log-density ratios that make up the weight it returns. Which parts of the trace changed is tracked through the `Diff` system.
A limitation of the current automation is that if an address "x" holds a tensor value and any single index of "x" changes, the system treats all of "x" as changed, without capturing a finer description of what exactly changed.
However, we can manually give a more precise description of how something has changed.
@gen
def model(size_model_const: Const[int]):
    """Toy model used to benchmark MH moves on a vectorized address.

    Samples a scalar "x" ~ Normal(0, 1), then three length-`n` vectors
    "a", "b", "c" of independent Normal(0, 1) draws (via `normal.vmap()`),
    where `n = size_model_const.unwrap()`. It then observes
    "obs" ~ Normal(sum(a) + sum(b) + sum(c) + x, 5.0) and returns it.
    """
    n = size_model_const.unwrap()
    loc = jnp.zeros(n)
    scale = jnp.ones(n)
    vector_normal = normal.vmap()

    x = normal(0.0, 1.0) @ "x"
    a = vector_normal(loc, scale) @ "a"
    b = vector_normal(loc, scale) @ "b"
    c = vector_normal(loc, scale) @ "c"
    # Keep the original summation order so the observation mean is bit-identical.
    mean = jnp.sum(a) + jnp.sum(b) + jnp.sum(c) + x
    return normal(mean, 5.0) @ "obs"
Let's create a trace from our model.
# Constrain only the observation; `importance` samples every other address.
obs = C["obs"].set(1.0)
size_model = 10_000
args = (Const(size_model),)
key, subkey = jax.random.split(key)
tr, _ = model.importance(subkey, obs, args)
Let's first see an equivalent way to do what `update` does.
Just like `update` generalizes `importance`, there is yet another, more general interface, `edit`, which generalizes `update`.
We will go into the details of `edit` in a follow-up cookbook.
For now, let's see the equivalent of `update` using `edit`. For this, we introduce a `Request` to change the trace.
`edit` will then answer the `Request` and change the trace following the logic of the request.
To mimic `update`, we will perform an `Update` request.
change_in_value_for_a = jnp.ones(size_model)

# Route 1: the classic `update` interface, driven by a constraint choice map.
constraints = C["a"].set(change_in_value_for_a)
argdiffs = genjax.Diff.no_change(args)
key, subkey = jax.random.split(key)
new_tr1, _, _, _ = tr.update(subkey, constraints, argdiffs)

# Route 2: the same change expressed as an `Update` request answered by `edit`.
val = C.v(change_in_value_for_a)
request = StaticRequest({"a": Update(val)})
key, subkey = jax.random.split(key)
new_tr2, _, _, _ = request.edit(subkey, tr, args)

# Both routes should yield the same choices: compare every leaf with allclose.
leafwise_close = jax.tree.map(
    jnp.allclose, new_tr1.get_choices(), new_tr2.get_choices()
)
jax.tree_util.tree_all(leafwise_close)
Now let's see how we can efficiently change the value of "a" at a specific index.
For that, we create a more specific `Request` called an `IndexRequest`. This request expects another request describing what to do at the given index.
# Replace only entry 3 of "a" with 42.0; the inner `Update` says what to do at that index.
request = StaticRequest({"a": IndexRequest(jnp.array(3), Update(C.v(42.0)))})
key, subkey = jax.random.split(key)
new_tr, _, _, _ = request.edit(subkey, tr, args)
# Check that exactly one entry of "a" changed, that it is entry 3, and that it now
# holds 42. Merely counting entries equal to 42 would not show that the other
# entries were left untouched.
old_a = tr.get_choices()["a"]
new_a = new_tr.get_choices()["a"]
changed = new_a != old_a
(jnp.sum(changed) == 1) & changed[3] & (new_a[3] == 42.0)
Now, let's compare the three options: naive density-ratio computation vs `update` vs `IndexRequest`.
For this, we will perform an MH move on a specific variable in the model, as we did in the previous cookbook, but this time on a single index of the traced value "a".
We will also compare how the three approaches scale as the model size grows.
# Position inside the vectorized address "a" that every MH move below targets.
IDX_WHERE_CHANGE_A: int = 3
@gen
def rejuv_a(a):
    """Random-walk proposal: sample address "a" from Normal(a, 1.0) and return it."""
    return normal(a, 1.0) @ "a"
def compute_ratio_slow(key, fwd_choice, fwd_weight, model_args, chm):
    """MH log acceptance ratio obtained by scoring the whole model twice with `assess`.

    `key` is unused; it is accepted so that all three ratio functions share one
    signature. `chm` holds the current choices. `fwd_choice["a"]` is the
    proposed value for entry IDX_WHERE_CHANGE_A of "a", and `fwd_weight` is its
    forward proposal score.
    """
    proposed_entry = fwd_choice["a"]

    # Model score under the current choices.
    old_score, _ = model.assess(chm, model_args)

    # Splice the proposal into "a"; the other addresses are taken from `chm`.
    # NOTE(review): relies on `|` preferring the left operand for "a" — confirm.
    spliced_a = chm["a"].at[IDX_WHERE_CHANGE_A].set(proposed_entry)
    proposed_chm = C["a"].set(spliced_a) | chm
    new_score, _ = model.assess(proposed_chm, model_args)

    # Backward proposal score: proposing the old entry, starting from the new one.
    old_entry = C["a"].set(chm["a", IDX_WHERE_CHANGE_A])
    bwd_weight, _ = rejuv_a.assess(old_entry, (proposed_entry,))

    return new_score - old_score - fwd_weight + bwd_weight
def compute_ratio_fast(key, fwd_choice, fwd_weight, model_args, trace):
    """MH log acceptance ratio using the default incremental `update`.

    The whole "a" vector is passed as a constraint. The default Diff tracking
    therefore treats all of "a" as changed, even though only entry
    IDX_WHERE_CHANGE_A differs.
    """
    current_a = trace.get_choices()["a"]
    proposed_a = current_a.at[IDX_WHERE_CHANGE_A].set(fwd_choice["a"])
    _, weight, _, discard = model.update(
        key, trace, C["a"].set(proposed_a), genjax.Diff.no_change(model_args)
    )

    # The discard holds the replaced values, so the old entry is read back from it.
    old_entry = C["a"].set(discard["a", IDX_WHERE_CHANGE_A])
    bwd_weight, _ = rejuv_a.assess(old_entry, (fwd_choice["a"],))

    return weight - fwd_weight + bwd_weight
def compute_ratio_very_fast(key, fwd_choice, fwd_weight, model_args, trace):
    """MH log acceptance ratio using an `IndexRequest` edit on a single entry of "a".

    The request tells `edit` that only entry IDX_WHERE_CHANGE_A of "a" is
    replaced (by `fwd_choice["a"]`).
    """
    proposed_entry = fwd_choice["a"]
    index_request = IndexRequest(
        jnp.array(IDX_WHERE_CHANGE_A), Update(C.v(proposed_entry))
    )
    _, weight, _, _ = StaticRequest({"a": index_request}).edit(key, trace, model_args)

    # Backward proposal score for returning to the entry currently in the trace.
    old_entry = trace.get_choices()["a", IDX_WHERE_CHANGE_A]
    bwd_weight, _ = rejuv_a.assess(C["a"].set(old_entry), (proposed_entry,))

    return weight - fwd_weight + bwd_weight
def metropolis_hastings_move(key, trace, which_move):
    """Run one MH step on entry IDX_WHERE_CHANGE_A of "a"; return the resulting "a" choice map.

    `which_move` selects how the log acceptance ratio is computed:
    0 -> `compute_ratio_slow`, 1 -> `compute_ratio_fast`,
    any other value -> `compute_ratio_very_fast`.
    Only a choice map for "a" (accepted or current) is returned; no new trace
    is built.
    """
    model_args = trace.get_args()

    # Forward proposal, centred on the current value of the targeted entry.
    key, proposal_key = jax.random.split(key)
    fwd_choice, fwd_weight, _ = rejuv_a.propose(
        proposal_key, (trace.get_choices()["a", IDX_WHERE_CHANGE_A],)
    )

    # Log acceptance ratio, computed with the requested strategy.
    key, ratio_key = jax.random.split(key)
    ratio_args = (ratio_key, fwd_choice, fwd_weight, model_args)
    if which_move == 0:
        log_alpha = compute_ratio_slow(*ratio_args, trace.get_choices())
    elif which_move == 1:
        log_alpha = compute_ratio_fast(*ratio_args, trace)
    else:
        log_alpha = compute_ratio_very_fast(*ratio_args, trace)

    # Candidates: keep the current "a", or splice the proposed entry into it.
    rejected = C["a"].set(trace.get_choices()["a"])
    accepted = C["a"].set(
        rejected["a"].at[IDX_WHERE_CHANGE_A].set(fwd_choice["a"])
    )

    key, accept_key = jax.random.split(key)
    accept = jnp.log(jax.random.uniform(accept_key)) < log_alpha
    return jax.lax.cond(accept, lambda: accepted, lambda: rejected)
model_sizes = [1000, 10000, 100000, 1000000, 10000000, 100000000]
slow_times = []
fast_times = []
very_fast_times = []


def _average_move_time(jitted_move, key, trace, which_move, num_trials):
    """Return the mean wall-clock time in seconds of one jitted MH move.

    JAX dispatches work asynchronously. Without `jax.block_until_ready`,
    `timeit` would mostly measure dispatch overhead rather than the actual
    computation, which would hide the differences between the three strategies.
    """
    # Warm-up call: compiles this variant so compilation is excluded from timing.
    jax.block_until_ready(jitted_move(key, trace, which_move))
    total = timeit.timeit(
        lambda: jax.block_until_ready(jitted_move(key, trace, which_move)),
        number=num_trials,
    )
    return total / num_trials


# The observation does not depend on the model size, so build it once.
obs = C["obs"].set(1.0)
# `which_move` (arg 2) selects a Python branch, so it must be static.
jitted = jax.jit(metropolis_hastings_move, static_argnums=(2,))

for model_size in model_sizes:
    # Moves on the largest models are expensive; use fewer trials there.
    num_trials = 10000 if model_size <= 1000000 else 200
    size_const = Const(model_size)
    key, subkey = jax.random.split(key)
    # Create a trace from the model of the right size.
    tr, _ = jax.jit(model.importance, static_argnums=(2,))(
        subkey, obs, (size_const,)
    )
    # Measure the average time of each algorithm.
    slow_times.append(_average_move_time(jitted, subkey, tr, 0, num_trials))
    fast_times.append(_average_move_time(jitted, subkey, tr, 1, num_trials))
    very_fast_times.append(_average_move_time(jitted, subkey, tr, 2, num_trials))
def _halve(values, first):
    """Return the first or second half of `values`, split at len(values) // 2."""
    mid = len(values) // 2
    return values[:mid] if first else values[mid:]


# (measured average times in seconds, legend label) for each strategy.
_series = (
    (slow_times, "No incremental computation"),
    (fast_times, "Default incremental computation"),
    (very_fast_times, "Optimized incremental computation"),
)
# One panel for the smaller model sizes, one for the larger ones.
_panel_titles = (
    "Average Execution Time of MH move for different model sizes (First Half)",
    "Average Execution Time of MH move for different model sizes (Second Half)",
)

plt.figure(figsize=(20, 5))
for panel_index, panel_title in enumerate(_panel_titles, start=1):
    is_first = panel_index == 1
    plt.subplot(1, 2, panel_index)
    for times, series_label in _series:
        plt.plot(
            _halve(model_sizes, is_first),
            [seconds * 1000 for seconds in _halve(times, is_first)],
            marker="o",
            label=series_label,
        )
    plt.xscale("log")
    plt.xlabel("Argument (n)")
    plt.ylabel("Average Time (milliseconds)")
    plt.title(panel_title)
    plt.grid(True)
    plt.legend()
plt.show()