Compute gains via incremental computation, or how to avoid computing log pdfs
import timeit
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import gen, normal, pretty
from genjax._src.core.pytree import Const
pretty()
key = jax.random.PRNGKey(0)
In the previous cookbooks, we have seen that importance and update perform algebraic simplifications in the weight ratios that they compute. Let's first see the difference in the case of importance, by comparing a naive version of sampling importance resampling (SIR) to one using importance.
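Concretely, with the default proposal, q samples the unconstrained addresses from the model's own prior, so the importance weight p(x, a, b, c, obs) / q(x, a, b, c) cancels term by term down to p(obs | x, a, b, c): only the log pdf of the constrained address ever needs to be evaluated. The naive version instead evaluates every log pdf in both the model and the proposal.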
Let's define a model to be used in the rest of the cookbook.
@gen
def model(size_model: Const[int]):
    size_model = size_model.unwrap()
    x = normal(0.0, 1.0) @ "x"
    a = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "a"
    b = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "b"
    c = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "c"
    obs = normal(jnp.sum(a) + jnp.sum(b) + jnp.sum(c) + x, 5.0) @ "obs"
    return obs
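As a quick sanity check (a minimal sketch reusing the definitions above), we can call importance once on a small instance of the model; with the default proposal, the returned weight is the log density of the constrained "obs" address given the sampled latents.
# Sketch: run importance once with "obs" constrained to 1.0 and size 10.
# `w` is the log importance weight, i.e. log p(obs | x, a, b, c) here.
key, subkey = jax.random.split(key)
tr, w = model.importance(subkey, C["obs"].set(1.0), (Const(10),))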
To compare naive SIR to the version using importance with the default proposal, let's define the default proposal manually:
@gen
def default_proposal(size_model: Const[int]):
    size_model = size_model.unwrap()
    _ = normal(0.0, 1.0) @ "x"
    _ = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "a"
    _ = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "b"
    _ = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "c"
    return None
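To make the redundant work of the slow path concrete, here is a small sketch (using the definitions above) of the simulate-then-assess round trip it performs:
# Sketch: simulate the proposal once, then rescore its own choices with
# assess. This re-evaluates every log pdf, which is the work the fast path skips.
key, subkey = jax.random.split(key)
q_tr = default_proposal.simulate(subkey, (Const(10),))
q_score, _ = default_proposal.assess(q_tr.get_choices(), (Const(10),))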
Let's now write SIR with a parameter controlling whether to call the slow or fast version.
obs = C["obs"].set(1.0)
def sir(key, N: int, use_fast: bool, size_model):
    if use_fast:
        # Fast version: importance computes the weight with the simplified ratio.
        traces, weights = jax.vmap(model.importance, in_axes=(0, None, None))(
            jax.random.split(key, N), obs, size_model
        )
    else:
        # Slow version: sample from the proposal, then rescore both the proposal
        # and the model densities in full to form the importance weights.
        traces = jax.vmap(default_proposal.simulate, in_axes=(0, None))(
            jax.random.split(key, N), size_model
        )
        chm_proposal = traces.get_choices()
        q_weights, _ = jax.vmap(
            lambda idx: default_proposal.assess(
                jax.tree_util.tree_map(lambda v: v[idx], chm_proposal), size_model
            )
        )(jnp.arange(N))
        chm_model = chm_proposal | C["obs"].set(jnp.ones(N) * obs["obs"])
        p_weights, _ = jax.vmap(
            lambda idx: model.assess(
                jax.tree_util.tree_map(lambda v: v[idx], chm_model), size_model
            )
        )(jnp.arange(N))
        weights = p_weights - q_weights

    # Resample one particle in proportion to its importance weight.
    idx = genjax.categorical.simulate(key, (weights,)).get_retval()
    samples = traces.get_choices()
    resampled = jax.tree_util.tree_map(lambda v: v[idx], samples)
    return resampled
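Before timing anything, a usage sketch: each variant returns the choices of a single particle, resampled in proportion to its importance weight.
# Example usage (sketch): one SIR run with N = 100 particles, model size 10.
key, subkey = jax.random.split(key)
resampled = sir(subkey, 100, True, (Const(10),))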
Let's now compare the speed of the two versions (beware, there is some variance in the estimate, but adding more trials makes the runtime comparison take a while).
obs = C["obs"].set(1.0)

model_sizes = [10, 100, 1000]
N_sir = 100
num_trials = 30

slow_times = []
fast_times = []
for model_size in model_sizes:
    model_size = Const(model_size)
    key, subkey = jax.random.split(key)
    # warm-up runs to trigger jit compilation
    jitted = jax.jit(sir, static_argnums=(1, 2))
    jitted(subkey, N_sir, False, (model_size,))
    jitted(subkey, N_sir, True, (model_size,))
    # measure time for each algorithm
    key, subkey = jax.random.split(key)
    total_time_slow = timeit.timeit(
        lambda: jitted(subkey, N_sir, False, (model_size,)), number=num_trials
    )
    total_time_fast = timeit.timeit(
        lambda: jitted(subkey, N_sir, True, (model_size,)), number=num_trials
    )
    slow_times.append(total_time_slow / num_trials)
    fast_times.append(total_time_fast / num_trials)
plt.plot(model_sizes, slow_times, marker="o", label="Slow Algorithm")
plt.plot(model_sizes, fast_times, marker="o", label="Fast Algorithm")
plt.xscale("log")
plt.xlabel("Model size (n)")
plt.ylabel("Average Time (seconds)")
plt.title("Average Execution Time of SIR for different model sizes")
plt.grid(True)
plt.legend()
plt.show()
When doing inference with iterative algorithms like MCMC, we often need to make small adjustments to the choice map. We have seen that update can be used to compute part of the MH acceptance ratio. So let's now compare two versions of an MH move: one computing the ratio naively, and one using update.
Let's create a very basic kernel to rejuvenate the variable "x" in an MH algorithm.
@gen
def rejuv_x(x):
    x = normal(x, 1.0) @ "x"
    return x
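The MH move below relies on propose, which runs the kernel and returns its choices, their log density, and the kernel's return value; a one-call sketch:
# Sketch: propose a new "x" centered at the current value, here 0.0.
key, subkey = jax.random.split(key)
fwd_choice, fwd_weight, proposed_x = rejuv_x.propose(subkey, (0.0,))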
Let's now write two versions of the MH acceptance ratio computation, as well as the MH algorithm that rejuvenates the variable "x".
def compute_ratio_slow(key, fwd_choice, fwd_weight, model_args, chm):
    # Rescore the full old and new choice maps from scratch.
    model_weight_old, _ = model.assess(chm, model_args)
    new_chm = fwd_choice | chm
    model_weight_new, _ = model.assess(new_chm, model_args)
    old_x = C["x"].set(chm["x"])
    proposal_args_backward = (fwd_choice["x"],)
    bwd_weight, _ = rejuv_x.assess(old_x, proposal_args_backward)
    α = model_weight_new - model_weight_old - fwd_weight + bwd_weight
    return α


def compute_ratio_fast(key, fwd_choice, fwd_weight, model_args, trace):
    # update only rescores the parts of the model affected by the change to "x".
    argdiffs = genjax.Diff.no_change(model_args)
    _, weight, _, discard = model.update(key, trace, fwd_choice, argdiffs)
    proposal_args_backward = (fwd_choice["x"],)
    bwd_weight, _ = rejuv_x.assess(discard, proposal_args_backward)
    α = weight - fwd_weight + bwd_weight
    return α
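Both functions compute the same acceptance ratio, so before benchmarking we can check (a sketch on a small trace built with importance) that they agree numerically:
# Sketch: the slow and fast ratios should match up to floating-point error.
key, subkey = jax.random.split(key)
tr, _ = model.importance(subkey, obs, (Const(10),))
key, subkey = jax.random.split(key)
fwd_choice, fwd_weight, _ = rejuv_x.propose(subkey, (tr.get_choices()["x"],))
α_slow = compute_ratio_slow(subkey, fwd_choice, fwd_weight, tr.get_args(), tr.get_choices())
α_fast = compute_ratio_fast(subkey, fwd_choice, fwd_weight, tr.get_args(), tr)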
def metropolis_hastings_move(key, trace, use_fast):
    model_args = trace.get_args()
    proposal_args_forward = (trace.get_choices()["x"],)
    key, subkey = jax.random.split(key)
    fwd_choice, fwd_weight, _ = rejuv_x.propose(subkey, proposal_args_forward)

    key, subkey = jax.random.split(key)
    if use_fast:
        α = compute_ratio_fast(subkey, fwd_choice, fwd_weight, model_args, trace)
    else:
        chm = trace.get_choices()
        α = compute_ratio_slow(subkey, fwd_choice, fwd_weight, model_args, chm)

    old_choice = C["x"].set(trace.get_choices()["x"])
    key, subkey = jax.random.split(key)
    # Accept the proposed "x" with probability min(1, exp(α)).
    ret_choice = jax.lax.cond(
        jnp.log(jax.random.uniform(subkey)) < α, lambda: fwd_choice, lambda: old_choice
    )
    return ret_choice
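A single-step usage sketch, mirroring what the benchmark below does repeatedly:
# Sketch: build a small trace, then apply one MH move using the fast ratio.
key, subkey = jax.random.split(key)
tr, _ = model.importance(subkey, obs, (Const(10),))
key, subkey = jax.random.split(key)
accepted_x = metropolis_hastings_move(subkey, tr, use_fast=True)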
Let's measure the performance of each variant.
model_sizes = [1000, 10000, 100000, 1000000, 10000000, 100000000]
slow_times = []
fast_times = []
for model_size in model_sizes:
    # fewer trials for the largest models to keep the benchmark tractable
    num_trials = 5000 if model_size <= 1000000 else 100
    model_size = Const(model_size)
    obs = C["obs"].set(1.0)
    key, subkey = jax.random.split(key)
    # create a trace from the model of the right size
    tr, _ = jax.jit(model.importance, static_argnums=(2))(subkey, obs, (model_size,))
    # warm-up runs to trigger jit compilation
    jitted = jax.jit(metropolis_hastings_move, static_argnums=(2))
    jitted(subkey, tr, False)
    jitted(subkey, tr, True)
    # measure time for each algorithm
    key, subkey = jax.random.split(key)
    total_time_slow = timeit.timeit(lambda: jitted(subkey, tr, False), number=num_trials)
    total_time_fast = timeit.timeit(lambda: jitted(subkey, tr, True), number=num_trials)
    slow_times.append(total_time_slow / num_trials)
    fast_times.append(total_time_fast / num_trials)
Plotting the results.
plt.figure(figsize=(20, 5))

# First half of the values
plt.subplot(1, 2, 1)
plt.plot(
    model_sizes[: len(model_sizes) // 2],
    [time * 1000 for time in slow_times[: len(slow_times) // 2]],
    marker="o",
    label="No incremental computation",
)
plt.plot(
    model_sizes[: len(model_sizes) // 2],
    [time * 1000 for time in fast_times[: len(fast_times) // 2]],
    marker="o",
    label="Default incremental computation",
)
plt.xscale("log")
plt.xlabel("Model size (n)")
plt.ylabel("Average Time (milliseconds)")
plt.title("Average Execution Time of MH move for different model sizes (First Half)")
plt.grid(True)
plt.legend()

# Second half of the values
plt.subplot(1, 2, 2)
plt.plot(
    model_sizes[len(model_sizes) // 2 :],
    [time * 1000 for time in slow_times[len(slow_times) // 2 :]],
    marker="o",
    label="No incremental computation",
)
plt.plot(
    model_sizes[len(model_sizes) // 2 :],
    [time * 1000 for time in fast_times[len(fast_times) // 2 :]],
    marker="o",
    label="Default incremental computation",
)
plt.xscale("log")
plt.xlabel("Model size (n)")
plt.ylabel("Average Time (milliseconds)")
plt.title("Average Execution Time of MH move for different model sizes (Second Half)")
plt.grid(True)
plt.legend()
plt.show()