Conditionals
How do I use conditionals in (Gen)JAX?
¶
# Install GenJAX (with the GenStudio extra) only when running inside Google Colab.
import sys

if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"

import jax
import jax.numpy as jnp
from genjax import bernoulli, gen, normal, or_else, pretty, switch

# Root PRNG key; later cells split it with jax.random.split before each use.
key = jax.random.key(0)
# NOTE(review): presumably turns on GenJAX's pretty-printing of displayed objects — confirm.
pretty()
In pure Python, we can use the usual `if`/`else` conditionals:
def simple_cond_python(p):
    """Return ``2 * p`` when ``p`` is positive, otherwise ``-p``."""
    # Guard clause for the non-positive case (written with `not` so NaN
    # takes this path, exactly as in an if/else on `p > 0`).
    if not (p > 0):
        return -p
    return 2 * p


simple_cond_python(0.3), simple_cond_python(-0.4)
In pure JAX, we write conditionals with `jax.lax.cond` as follows:
def branch_1(p):
    """Branch used when the predicate holds: double the operand."""
    return 2 * p


def branch_2(p):
    """Branch used otherwise: negate the operand."""
    return -p


def simple_cond_jax(p):
    """Choose between ``branch_1`` and ``branch_2`` with ``jax.lax.cond``.

    The predicate is ``p > 0`` and both branches receive ``p`` as their
    single operand.
    """
    return jax.lax.cond(p > 0, branch_1, branch_2, p)


simple_cond_jax(0.3), simple_cond_jax(-0.4)
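Compiled JAX code is often much faster than Python code. For a small scalar loop like the one below, however, the speedup may not materialize: in the recorded timings, the JAX version is actually slower.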
def python_loop(x):
    """Apply the update 40000 times in plain Python.

    Each step doubles ``x`` while it is below 100.0 and otherwise
    subtracts 97.0.
    """
    for _ in range(40000):
        x = 2 * x if x < 100.0 else x - 97.0
    return x
@jax.jit
def jax_loop(x):
    """Jitted counterpart of ``python_loop`` built from ``fori_loop`` and ``cond``."""

    def step(_, val):
        # Same update as python_loop: double below 100.0, otherwise subtract 97.0.
        return jax.lax.cond(
            val < 100.0,
            lambda v: 2 * v,
            lambda v: v - 97.0,
            val,
        )

    return jax.lax.fori_loop(0, 40000, step, x)
%timeit python_loop(1.0)
# Get the JIT time out of the way
jax_loop(1.0)
%timeit jax_loop(1.0)
2.03 ms ± 5.13 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.25 ms ± 8.41 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
One restriction is that both branches must return outputs with the same pytree structure:
def failing_simple_cond_1(p):
    """Show that ``jax.lax.cond`` rejects branches with different pytree structures.

    The first branch returns a 2-tuple while the second returns a single
    leaf, so the ``jax.lax.cond`` call raises a ``TypeError``.
    """

    def branch_1(q):
        return q, q

    def branch_2(q):
        return -q

    return jax.lax.cond(p > 0, branch_1, branch_2, p)


try:
    print(failing_simple_cond_1(0.3))
except TypeError as err:
    print(err)
true_fun output must have same type structure as false_fun output, but there are differences: * at output, true_fun output has <class 'tuple'> and false_fun output has pytree leaf, so their Python types differ
Another restriction is that the branches' outputs must have identical types (for example, the same dtype):
def failing_simple_cond_2(p):
    """Show that ``jax.lax.cond`` rejects branches whose output types differ.

    The first branch returns ``2 * p`` (a float) while the second returns
    the integer literal ``7``. The output dtypes do not match, so the
    ``jax.lax.cond`` call raises a ``TypeError``.
    """

    def branch_1(q):
        return 2 * q

    def branch_2(_q):
        return 7

    return jax.lax.cond(p > 0, branch_1, branch_2, p)


try:
    print(failing_simple_cond_2(0.3))
except TypeError as err:
    print(err)
true_fun output and false_fun output must have identical types, got DIFFERENT ShapedArray(float32[], weak_type=True) vs. ShapedArray(int32[], weak_type=True).
In GenJAX, the syntax is a bit different again. Just as JAX has a custom primitive, `jax.lax.cond`, that creates a conditional by "composing" two functions seen as branches, GenJAX has a custom combinator, `genjax.or_else`, that "composes" two generative functions.
We can first define the two branches as generative functions
@gen
def branch_1(p):
    """Sample ``bernoulli(p)`` at address ``"v1"`` and return the draw."""
    return bernoulli(p) @ "v1"


@gen
def branch_2(p):
    """Sample ``bernoulli(-p)`` at address ``"v2"`` and return the draw."""
    # NOTE(review): assumes bernoulli's argument is a logit, so a negated
    # input is valid — confirm against the GenJAX docs.
    return bernoulli(-p) @ "v2"
Then we use the combinator to compose them
@gen
def cond_model(p):
    """Run ``branch_1`` when ``p > 0`` and ``branch_2`` otherwise, via ``or_else``.

    Both branches receive the same argument tuple ``(p,)``. The chosen
    branch's choices live under the address ``"cond"``.
    """
    branch_args = (p,)
    v = or_else(branch_1, branch_2)(p > 0, branch_args, branch_args) @ "cond"
    return v


key, subkey = jax.random.split(key)
jitted = jax.jit(cond_model.simulate)
tr = jitted(subkey, (0.0,))
tr.get_choices()
Alternatively, we can use `or_else` as a method on the first branch:
@gen
def cond_model_v2(p):
    """Same conditional as ``cond_model``, using the method form ``branch_1.or_else(branch_2)``."""
    combined = branch_1.or_else(branch_2)
    shared_args = (p,)
    v = combined(p > 0, shared_args, shared_args) @ "cond"
    return v


key, subkey = jax.random.split(key)
cond_model_v2.simulate(subkey, (0.0,))
Note that it is possible to write the following, but in general it will not do what you want!
# TODO: find a way to make it fail to better show the point.
@gen
def simple_cond_genjax(p):
    # Anti-pattern kept for illustration: addressed random choices
    # (`bernoulli(...) @ addr`) are made inside plain-Python branches passed to
    # `jax.lax.cond`, instead of composing generative functions with `or_else`.
    # NOTE(review): presumably GenJAX does not record these branch choices the
    # way the `or_else` combinator does — confirm against the GenJAX docs.
    def branch_1(p):
        # Taken when p > 0; samples at address "v1".
        return bernoulli(p) @ "v1"

    def branch_2(p):
        # Taken otherwise; samples at address "v2" with the negated argument.
        return bernoulli(-p) @ "v2"

    cond = jax.lax.cond(p > 0, branch_1, branch_2, p)
    return cond


# Run once with a positive and once with a negative argument, each with a fresh subkey.
key, subkey = jax.random.split(key)
tr1 = simple_cond_genjax.simulate(subkey, (0.3,))
key, subkey = jax.random.split(key)
tr2 = simple_cond_genjax.simulate(subkey, (-0.4,))
# Only the return values are inspected here.
tr1.get_retval(), tr2.get_retval()
If we have more than two branches, in JAX we can use the `jax.lax.switch` function instead.
def simple_switch_jax(p):
    """Dispatch on ``floor(|p|) % 3`` with ``jax.lax.switch``.

    Index 0 doubles ``p``, index 1 negates it, and index 2 returns it
    unchanged.
    """
    idx = jnp.floor(jnp.abs(p)).astype(jnp.int32) % 3
    return jax.lax.switch(
        idx,
        [
            lambda q: 2 * q,
            lambda q: -q,
            lambda q: q,
        ],
        p,
    )


simple_switch_jax(0.3), simple_switch_jax(1.1), simple_switch_jax(2.3)
Likewise, in GenJAX we can use the `switch` combinator when we have more than two branches.
We can first define three branches as generative functions
@gen
def branch_1(p):
    """Sample ``normal(p, 1.0)`` at address ``"v1"`` and return the draw."""
    return normal(p, 1.0) @ "v1"


@gen
def branch_2(p):
    """Sample ``normal(-p, 1.0)`` at address ``"v2"`` and return the draw."""
    return normal(-p, 1.0) @ "v2"


@gen
def branch_3(p):
    """Sample ``normal(p * p, 1.0)`` at address ``"v3"`` and return the draw."""
    return normal(p * p, 1.0) @ "v3"
Then we use the combinator to compose them.
@gen
def switch_model(p):
    """Select one of three branches by ``floor(|p|) % 3`` using ``switch``.

    Every branch receives ``(p,)``. The chosen branch's choices live under
    the address ``"s"``.
    """
    index = jnp.floor(jnp.abs(p)).astype(jnp.int32) % 3
    branch_args = (p,)
    combined = switch(branch_1, branch_2, branch_3)
    v = combined(index, branch_args, branch_args, branch_args) @ "s"
    return v


key, subkey = jax.random.split(key)
jitted = jax.jit(switch_model.simulate)
# p = 0.0, 1.1 and 2.2 select branch indices 0, 1 and 2 respectively.
tr1 = jitted(subkey, (0.0,))
key, subkey = jax.random.split(key)
tr2 = jitted(subkey, (1.1,))
key, subkey = jax.random.split(key)
tr3 = jitted(subkey, (2.2,))
(
    tr1.get_choices()["s", "v1"],
    tr2.get_choices()["s", "v2"],
    tr3.get_choices()["s", "v3"],
)
We can rewrite the above a bit more elegantly by unpacking lists of branches and arguments with the `*` syntax:
@gen
def switch_model_v2(p):
    """Like ``switch_model``, but passing branches and arguments via ``*`` unpacking.

    The branch index is ``floor(|p|) % 3``. Each branch receives ``(p,)``,
    and the chosen branch's choices live under the address ``"switch"``.
    """
    index = jnp.floor(jnp.abs(p)).astype(jnp.int32) % 3
    branches = [branch_1, branch_2, branch_3]
    # One argument tuple per branch, in the same order as `branches`.
    args = [(p,)] * len(branches)
    v = switch(*branches)(index, *args) @ "switch"
    return v


# Compile `simulate` with jax.jit, as in the switch_model cell, so that the
# name `jitted` actually refers to a jitted function.
jitted = jax.jit(switch_model_v2.simulate)
key, subkey = jax.random.split(key)
tr = jitted(subkey, (0.0,))
tr.get_choices()["switch", "v1"]