JAX Basics
¶
import sys
if "google.colab" in sys.modules:
%pip install --quiet "genjax[genstudio]"
import multiprocessing
import subprocess
import time
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, random
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import beta, gen, pretty
key = jax.random.key(0)
pretty()
- JAX expects arrays/tuples everywhere
@gen
def f(p):
    """Sample a Bernoulli choice at address "v" (called with `probs=p`) and return it."""
    return genjax.bernoulli(probs=p) @ "v"
# First way of failing: the arguments are a bare float instead of a tuple.
key, subkey = jax.random.split(key)
try:
    # Use the fresh subkey (not the parent key), like every other call below.
    f.simulate(subkey, 0.5)
except Exception as e:
    print(e)

# Second way of failing: a list is not accepted where a tuple is expected.
key, subkey = jax.random.split(key)
try:
    f.simulate(subkey, [0.5])
except Exception as e:
    print(e)

# Third way of failing: (0.5) is only a parenthesized float, not a tuple.
key, subkey = jax.random.split(key)
try:
    f.simulate(subkey, (0.5))
except Exception as e:
    print(e)

# Correct way: a one-element tuple needs the trailing comma.
key, subkey = jax.random.split(key)
f.simulate(subkey, (0.5,)).get_retval()
Method genjax._src.generative_functions.static.StaticGenerativeFunction.simulate() parameter args=0.5 violates type hint tuple[typing.Any, ...], as float 0.5 not instance of tuple. Method genjax._src.generative_functions.static.StaticGenerativeFunction.simulate() parameter args=[0.5] violates type hint tuple[typing.Any, ...], as list [0.5] not instance of tuple. Method genjax._src.generative_functions.static.StaticGenerativeFunction.simulate() parameter args=0.5 violates type hint tuple[typing.Any, ...], as float 0.5 not instance of tuple.
- GenJAX relies on TensorFlow Probability, and it sometimes does unintuitive things.
The Bernoulli distribution uses logits instead of probabilities
@gen
def g(p):
    """Sample a Bernoulli choice at address "v" and return the sampled value."""
    # NOTE(review): the accompanying text treats `p` as a logit, yet it is
    # passed through the `probs=` keyword here — confirm which is intended.
    return genjax.bernoulli(probs=p) @ "v"
key, subkey = jax.random.split(key)
arg = (3.0,) # 3 is not a valid probability but a valid logit
keys = jax.random.split(subkey, 30)
# simulate 30 times
jax.vmap(g.simulate, in_axes=(0, None))(keys, arg).get_choices()
Values which are strictly greater than $0$ are considered to be the value True.
This means that observing that the value of "v"
is $3$ will be considered possible, while intuitively "v"
should only have support on $0$ and $1$.
chm = C["v"].set(3)
g.assess(chm, (0.5,))[0] # This should be -inf.
Alternatively, we can use the flip function which uses probabilities instead of logits.
@gen
def h(p):
    """Sample a `genjax.flip(p)` choice at address "v" and return it."""
    return genjax.flip(p) @ "v"
key, subkey = jax.random.split(key)
arg = (0.3,) # 0.3 is a valid probability
keys = jax.random.split(subkey, 30)
# simulate 30 times
jax.vmap(h.simulate, in_axes=(0, None))(keys, arg).get_choices()
Categorical distributions also use logits instead of probabilities
@gen
def i(p):
    """Sample a `genjax.categorical(p)` choice at address "v" and return it."""
    return genjax.categorical(p) @ "v"
key, subkey = jax.random.split(key)
arg = ([3.0, 1.0, 2.0],) # lists of 3 logits
keys = jax.random.split(subkey, 30)
# simulate 30 times
jax.vmap(i.simulate, in_axes=(0, None))(keys, arg).get_choices()
- JAX code can be compiled for better performance.
jit
is the way to force JAX to compile the code.
It can be used as a decorator.
@jit
def f_v1(p):
    """Square `p` elementwise via `jax.lax.cond`; both branches do the same thing."""

    def square(t):
        return t * t

    predicate = p.sum()
    return jax.lax.cond(predicate, square, square, p)
Or as a function
# Same computation as the decorated version, with jit applied as a plain call.
f_v2 = jit(
    lambda arr: jax.lax.cond(
        arr.sum(),
        lambda t: t * t,
        lambda t: t * t,
        arr,
    )
)
Testing the effect. Notice that the first and second have the same performance while the third is much slower (~50x on a mac m2 cpu)
# Baseline
def f_v3(p):
    """Un-jitted baseline: square `p` elementwise via `jax.lax.cond`.

    Returns the result of the `cond`, so the baseline yields the same value
    as the jitted variants it is compared against.
    """
    return jax.lax.cond(p.sum(), lambda p: p * p, lambda p: p * p, p)
arg = jax.numpy.eye(500)
# Warmup to force jit compilation
f_v1(arg)
f_v2(arg)
# Runtime comparison
%timeit f_v1(arg)
%timeit f_v2(arg)
%timeit f_v3(arg)
#
153 μs ± 926 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
152 μs ± 961 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
22.8 ms ± 138 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
- Going from Python to JAX
4.1 For loops
def python_loop(x):
    """Double `x` one hundred times with an ordinary Python `for` loop."""
    for _ in range(100):
        x = 2 * x
    return x


def jax_loop(x):
    """Double `x` one hundred times with `jax.lax.fori_loop`.

    JAX counterpart of `python_loop`. The loop result is returned; without
    the `return` the function would yield `None`.
    """
    return jax.lax.fori_loop(0, 100, lambda i, x: 2 * x, x)
4.2 Conditional statements
def python_cond(x):
    """Return `x * x` when the sum of `x` is positive, otherwise `x` unchanged."""
    if x.sum() > 0:
        return x * x
    else:
        return x


def jax_cond(x):
    """Return `x * x` when the sum of `x` is positive, otherwise `x` unchanged.

    JAX counterpart of `python_cond`, written with `jax.lax.cond`. The
    predicate is the same `x.sum() > 0` test, so a negative sum selects the
    identity branch, and the selected value is returned.
    """
    return jax.lax.cond(x.sum() > 0, lambda x: x * x, lambda x: x, x)
4.3 While loops
def python_while(x):
    """Square `x` elementwise for as long as its sum stays positive."""
    while x.sum() > 0:
        x = x * x
    return x


def jax_while(x):
    """Square `x` elementwise for as long as its sum stays positive.

    JAX counterpart of `python_while`, written with `jax.lax.while_loop`.
    The final loop value is returned; without the `return` the function
    would yield `None`.
    """
    return jax.lax.while_loop(lambda x: x.sum() > 0, lambda x: x * x, x)
- Is my thing compiling or is it blocked at traced time?
In JAX, the first time you run a function, it is traced, which produces a Jaxpr, a representation of the computation that JAX can optimize.
So, in order to debug whether a function runs or not, once it passes the first check (Python lets you write it), you can check whether it can be traced before actually running it on data.
This is done by calling make_jaxpr
on the function. If it returns a Jaxpr, then the function is traced and ready to be run on data.
def im_fine(x):
    """Return the square of `x`."""
    squared = x * x
    return squared
jax.make_jaxpr(im_fine)(1.0)
{ lambda ; a:f32[]. let b:f32[] = mul a a in (b,) }
def i_wont_be_so_fine(x):
    """Keep squaring `x` inside `jax.lax.while_loop` while it is positive.

    For an input such as 1.0 the square is again 1.0, so the loop condition
    never becomes false.
    """

    def still_positive(value):
        return value > 0

    def square(value):
        return value * value

    return jax.lax.while_loop(still_positive, square, x)
jax.make_jaxpr(i_wont_be_so_fine)(1.0)
{ lambda ; a:f32[]. let b:f32[] = while[ body_jaxpr={ lambda ; c:f32[]. let d:f32[] = mul c c in (d,) } body_nconsts=0 cond_jaxpr={ lambda ; e:f32[]. let f:bool[] = gt e 0.0 in (f,) } cond_nconsts=0 ] a in (b,) }
Try running the function for 8 seconds
def run_process(timeout=8.0):
    """Run `i_wont_be_so_fine(1.0)` in a spawned process for a bounded time.

    Args:
        timeout: Maximum number of seconds to wait for the child process
            before reporting that it is still running and terminating it.
    """
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=i_wont_be_so_fine, args=(1.0,))
    p.start()
    # join(timeout) returns when the child exits or the timeout elapses,
    # whichever comes first, so the wait is bounded by `timeout` seconds.
    p.join(timeout)
    if p.is_alive():
        print("I'm still running")
        p.terminate()
        # Reap the terminated child.
        p.join()
# Run the companion script with the same interpreter as this notebook, so it
# sees the same installed packages regardless of what "python" is on PATH.
result = subprocess.run(
    [sys.executable, "genjax/docs/sharp-edges-notebooks/basics/script.py"],
    capture_output=True,
    text=True,
)
# Print the output
result.stdout
- Using random keys for generative functions
In GenJAX, we use explicit random keys to generate random numbers. This is done by splitting a key into multiple keys, and using them to generate random numbers.
@gen
def beta_bernoulli_process(u):
    """Sample "p" from a beta distribution, then a Bernoulli "v" with `probs=p`.

    Returns the sampled value of "v".
    """
    # Beta concentration parameters must be strictly positive; a first
    # parameter of 0.0 lies outside the distribution's domain.
    p = beta(1.0, u) @ "p"
    v = genjax.bernoulli(probs=p) @ "v"  # sweet
    return v
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, 20)
jitted = jit(beta_bernoulli_process.simulate)
jax.vmap(jitted, in_axes=(0, None))(keys, (0.5,)).get_choices()
- JAX uses 32-bit floats by default
key, subkey = jax.random.split(key)
x = random.uniform(subkey, (1000,), dtype=jnp.float64)
print("surprise surprise: ", x.dtype)
surprise surprise: float32
/tmp/ipykernel_8832/1255632939.py:2: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. x = random.uniform(subkey, (1000,), dtype=jnp.float64)
A common TypeError occurs when one uses np instead of jnp (the JAX version of numpy): the former uses 64-bit floats by default, while the JAX version uses 32-bit floats by default.
This on its own gives a UserWarning
jnp.array([1, 2, 3], dtype=np.float64)
/tmp/ipykernel_8832/403521608.py:1: UserWarning: Explicitly requested dtype <class 'numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. jnp.array([1, 2, 3], dtype=np.float64)
Using an array from numpy
instead of jax.numpy
will truncate the array to 32-bit floats and also give a UserWarning when used in JAX code
innocent_looking_array = np.array([1.0, 2.0, 3.0], dtype=np.float64)


@jax.jit
def innocent_looking_function(x):
    """Return `x * x` or the module-level NumPy array, chosen by `x.sum()`."""

    def square(values):
        return values * values

    def use_numpy_constant(_):
        # Looked up when the branch is traced, like the original closure.
        return innocent_looking_array

    return jax.lax.cond(x.sum(), square, use_numpy_constant, x)


input = jnp.array([1.0, 2.0, 3.0])
innocent_looking_function(input)
try:
# This looks fine so far but...
innocent_looking_array = np.array([1, 2, 3], dtype=np.float64)
# This actually raises a TypeError: the integer input makes the `x * x`
# branch int32, while the NumPy float64 constant returned by the other
# branch is truncated to float32, so the branch outputs differ in type.
@jax.jit
def innocent_looking_function(x):
return jax.lax.cond(
x.sum(), lambda x: x * x, lambda x: innocent_looking_array, x
)
# Integer literals: this array is int32, unlike the float input used before.
input = jnp.array([1, 2, 3])
innocent_looking_function(input)
except Exception as e:
# Print the error message instead of aborting the notebook.
print(e)
true_fun output and false_fun output must have identical types, got DIFFERENT ShapedArray(int32[3]) vs. ShapedArray(float32[3]).
- Beware of OOM on the GPU, which happens sooner than you might think
Here's a simple HMM model that can be run on the GPU. By simply changing $N$ from $300$ to $1000$, the code will typically run out of memory on the GPU as it will take ~300GB of memory
N = 300
n_repeats = 100
variance = jnp.eye(N)
key, subkey = jax.random.split(key)
initial_state = jax.random.normal(subkey, (N,))
@genjax.gen
def hmm_step(x, _):
    """One HMM step: sample `genjax.mv_normal(x, variance)` at address "new_x".

    The second argument is ignored. Returns the pair `(new state, None)`.
    """
    next_state = genjax.mv_normal(x, variance) @ "new_x"
    return next_state, None
hmm = hmm_step.scan(n=100)
key, subkey = jax.random.split(key)
jitted = jit(hmm.repeat(n=n_repeats).simulate)
trace = jitted(subkey, (initial_state, None))
key, subkey = jax.random.split(key)
%timeit jitted(subkey, (initial_state, None))
499 ms ± 1.46 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
If you are running out of memory, you can try de-batching one of the computations, or using a smaller batch size. For instance, in this example, we can de-batch the repeat
combinator, which will reduce the memory usage by a factor of $100$, at the cost of some performance.
jitted = jit(hmm.simulate)
def hmm_debatched(key, initial_state):
    """Run the jitted HMM simulation `n_repeats` times, one call per key.

    Splits `key` into `n_repeats` subkeys and returns a dict mapping each
    repeat index to the trace produced for that subkey.
    """
    subkeys = jax.random.split(key, n_repeats)
    return {
        idx: jitted(subkeys[idx], (initial_state, None))
        for idx in range(n_repeats)
    }
key, subkey = jax.random.split(key)
# About 4x slower on arm64 CPU and 40x on a Google Colab GPU
%timeit hmm_debatched(subkey, initial_state)
1.04 s ± 32.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- Fast sampling can be inaccurate and yield NaN/wrong results.
As an example, truncating a normal distribution outside 5.5 standard deviations from its mean can yield NaNs. Many default TFP/JAX implementations that run on the GPU use fast 32-bit implementations. If one really needs samples that far into the tails, one could use slower implementations that use 64 bits and an exponential tilting Monte Carlo algorithm.
genjax.truncated_normal.sample(
jax.random.key(2), 0.5382424, 0.05, 0.83921564 - 0.03, 0.83921564 + 0.03
)
minv = 0.83921564 - 0.03
maxv = 0.83921564 + 0.03
mean = 0.5382424
std = 0.05
def raw_jax_truncated(key, minv, maxv, mean, std):
# Draw one float32 scalar from a normal(mean, std) restricted to [minv, maxv],
# using jax.random.truncated_normal directly.
# Express the bounds in units of standard deviations from the mean, since
# jax.random.truncated_normal samples a standard normal between its bounds.
low = (minv - mean) / std
high = (maxv - mean) / std
# Map the standardized sample back to the original location and scale.
return std * jax.random.truncated_normal(key, low, high, (), jnp.float32) + mean
raw_jax_truncated(jax.random.key(2), minv, maxv, mean, std)
# ==> Array(0.80921566, dtype=float32)
jax.jit(raw_jax_truncated)(jax.random.key(2), minv, maxv, mean, std)
# ==> Array(nan, dtype=float32)