Debugging
¶
How can I debug my code? I want to add breakpoints or print statements to my JAX/GenJAX code, but they don't seem to work because of traced values and/or JIT compilation.
import sys
if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"
import jax
from genjax import bernoulli, beta, gen
key = jax.random.key(0)
TL;DR: inside generative functions, use jax.debug.print and jax.debug.breakpoint() instead of print() statements.
We also recommend the official JAX debugging guide, which applies to GenJAX as well:
https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html
Example of printing
@gen
def beta_bernoulli_process(u):
    p = beta(0.0, u) @ "p"
    v = bernoulli(p) @ "v"
    print("Bad looking printing:", v)  # runs at trace time: prints a tracer, not the sampled value
    jax.debug.print("Better looking printing: {v}", v=v)  # runs at execution time: prints the concrete value
    return v
non_jitted = beta_bernoulli_process.simulate
key, subkey = jax.random.split(key)
tr = non_jitted(subkey, (1.0,))
key, subkey = jax.random.split(key)
jitted = jax.jit(beta_bernoulli_process.simulate)
tr = jitted(subkey, (1.0,))
Bad looking printing: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>
Better looking printing: 0
Bad looking printing: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>
Better looking printing: 0
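The same contrast can be reproduced in plain JAX, without GenJAX. In this minimal sketch, the function double and the global counter trace_count are hypothetical helpers added only to show that Python-level code, including print, runs while JAX traces the function, whereas jax.debug.print is staged into the compiled computation and runs on every call.

```python
import jax

trace_count = 0  # hypothetical counter: incremented only when JAX traces `double`


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    print("trace-time print:", x)  # shows a tracer, not a number
    jax.debug.print("runtime print: {x}", x=x)  # shows the concrete value on every call
    return x * 2


a = double(1.0)  # first call: traces, compiles, then executes
b = double(2.0)  # same shape and dtype: reuses the compiled function, no retrace
print("times traced:", trace_count)
```

Running it, "trace-time print" appears once (during the first trace) and shows a tracer, while "runtime print" appears for both calls with the values 1.0 and 2.0. Because jitted computations run asynchronously, the runtime prints may interleave differently with the final print.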
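A plain print() runs in Python while JAX traces the function, so it only sees abstract tracers (and, under jax.jit, runs only when the function is traced, not on every call). jax.debug.print is instead staged into the computation and prints the concrete values each time it executes. It works inside generative functions and is compatible with the JAX transformations and higher-order functions such as jax.jit, jax.grad, jax.vmap, and jax.lax.scan.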
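As a sketch of that compatibility, the plain-JAX functions below (step and square are illustrative names, not part of GenJAX) place jax.debug.print inside a jax.lax.scan body and inside a function that is then transformed with jax.vmap, jax.grad, and jax.jit. The prints do not change any of the results.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    carry = carry + x
    # Fires once per scan iteration with concrete values
    jax.debug.print("x={x}, running sum={c}", x=x, c=carry)
    return carry, carry


# jax.lax.scan: returns the final carry and the stacked per-step outputs
total, partial_sums = jax.lax.scan(step, jnp.float32(0.0), jnp.arange(4.0))


def square(x):
    jax.debug.print("square called with {x}", x=x)
    return x**2


squares = jax.vmap(square)(jnp.arange(3.0))  # batched inputs
dsquare = jax.grad(square)(3.0)  # derivative 2 * x at x = 3
jitted_square = jax.jit(square)(2.0)  # compiled version
```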
Running the cell below opens a pdb-like debugger prompt in the terminal, where you can inspect the variables in scope at the breakpoint. Type a variable's name and press Enter to see its value, type c and press Enter to continue execution, q to exit the debugger, and h to list the available commands. jax.debug.breakpoint() also works inside jitted functions, although it may affect performance, and it is compatible with the other JAX transformations and higher-order functions, though you can expect some sharp edges.
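A breakpoint that fires on every call quickly gets in the way. One option, adapted from the pattern in the JAX debugging guide, is to wrap jax.debug.breakpoint() in jax.lax.cond so the debugger only opens when a condition holds. In this sketch, breakpoint_if_nonfinite and safe_log are illustrative names, and the condition (the presence of NaN or inf values) is just one possible choice.

```python
import jax
import jax.numpy as jnp


def breakpoint_if_nonfinite(x):
    # Open the debugger only when x contains NaN or inf values
    is_finite = jnp.isfinite(x).all()

    def true_fn(x):
        pass  # all values finite: do nothing

    def false_fn(x):
        jax.debug.breakpoint()  # non-finite values: open the debugger

    jax.lax.cond(is_finite, true_fn, false_fn, x)


@jax.jit
def safe_log(x):
    y = jnp.log(x)
    breakpoint_if_nonfinite(y)
    return y


out = safe_log(jnp.array([1.0, 2.0]))  # all outputs finite, so no breakpoint is hit
```

Calling safe_log on an input with a negative entry would produce a NaN and drop into the debugger at that point instead.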
# Example of breakpoint
@gen
def beta_bernoulli_process(u):
    p = beta(0.0, u) @ "p"
    v = bernoulli(p) @ "v"
    jax.debug.breakpoint()
    return v
non_jitted = beta_bernoulli_process.simulate
key, subkey = jax.random.split(key)
tr = non_jitted(subkey, (1.0,))