Generative functions
What is a generative function and how do you use it?
In [1]:
import jax
from genjax import bernoulli, beta, gen, pretty
pretty()
The following is a simple example of a beta-Bernoulli process. We use the @gen decorator to create generative functions.
In [2]:
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return v
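To make the model concrete, here is a minimal sketch in plain JAX (no GenJAX) of what one run of `beta_bernoulli_process` samples. It draws p from Beta(1, u) and then v from Bernoulli(p). The function name `sample_beta_bernoulli` is hypothetical. GenJAX's own samplers and key handling may differ, so this sketch is not expected to reproduce the values that `simulate` returns.

```python
import jax


def sample_beta_bernoulli(key, u):
    """Hypothetical plain-JAX counterpart of beta_bernoulli_process.

    Draws p ~ Beta(1, u), then v ~ Bernoulli(p), and returns both.
    """
    # One independent subkey per random choice ("p" and "v").
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)   # p lies in [0, 1]
    v = jax.random.bernoulli(key_v, p)   # boolean draw with P(v = True) = p
    return p, v


p_demo, v_demo = sample_beta_bernoulli(jax.random.key(0), 1.0)
```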
Before calling the generative function, we create a JAX random key to seed it:
In [3]:
key = jax.random.key(0)
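JAX randomness is explicit. A key fully determines the numbers drawn from it, which is why the cells below split off a fresh `subkey` before each call. The short sketch below illustrates this using plain `jax.random` calls. The names `demo_key`, `a`, `b` and `c` are illustrative only and are chosen so that they do not overwrite the notebook's `key`.

```python
import jax

demo_key = jax.random.key(0)

# Reusing the same key yields the same "random" number every time.
a = jax.random.uniform(demo_key)
b = jax.random.uniform(demo_key)

# Splitting produces fresh keys: keep one for future splits and
# consume the other in exactly one random call.
demo_key, demo_subkey = jax.random.split(demo_key)
c = jax.random.uniform(demo_subkey)
```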
Running the function with simulate returns a trace, which records the arguments, the random choices made, and the return value.
In [4]:
key, subkey = jax.random.split(key)
tr = beta_bernoulli_process.simulate(subkey, (1.0,))
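As a rough mental model of what `simulate` hands back, the sketch below runs the same model once and records the three things a trace is described as holding: the arguments, the random choices keyed by their addresses ("p" and "v"), and the return value. `SketchTrace` and `simulate_sketch` are hypothetical stand-ins written for illustration. They are not GenJAX's trace class or its implementation.

```python
from typing import Any, NamedTuple

import jax


class SketchTrace(NamedTuple):
    """Hypothetical stand-in for a trace (not GenJAX's Trace class)."""
    args: tuple
    choices: dict  # address -> sampled value
    retval: Any


def simulate_sketch(key, args):
    """Run the beta-Bernoulli model once and record what happened."""
    (u,) = args
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)   # random choice at address "p"
    v = jax.random.bernoulli(key_v, p)   # random choice at address "v"
    return SketchTrace(args=args, choices={"p": p, "v": v}, retval=v)


tr_sketch = simulate_sketch(jax.random.key(0), (1.0,))
```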
We can inspect the trace to see what happened:
In [5]:
tr.args, tr.get_retval(), tr.get_choices()
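In Gen-style systems, a trace usually also carries the log density of its random choices. For this model, that density is log Beta(p; 1, u) + log Bernoulli(v; p). The sketch below computes this joint log density by hand with `jax.scipy.stats`, using made-up values for p, v and u. The helper `log_joint` is hypothetical. The module is imported under an alias so that it does not shadow GenJAX's `beta` and `bernoulli` from the first cell.

```python
import jax.numpy as jnp
import jax.scipy.stats as jstats


def log_joint(p, v, u):
    """Log density of (p, v) under p ~ Beta(1, u), v ~ Bernoulli(p)."""
    return jstats.beta.logpdf(p, 1.0, u) + jstats.bernoulli.logpmf(v, p)


# With u = 1, Beta(1, 1) is uniform on [0, 1], so its log density is 0
# and the joint reduces to log P(v = 1 | p = 0.3) = log 0.3.
score = log_joint(0.3, 1, 1.0)
```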
GenJAX functions can be accelerated with jit compilation. The less efficient way is to apply jax.jit to the function body, inside the @gen decorator:
In [6]:
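@gen
@jax.jit
def fast_beta_bernoulli_process(u):
    # Same model as beta_bernoulli_process, but with a jitted body.
    # Beta parameters must be positive, so alpha is 1.0 (not 0.0).
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return v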
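A plain-JAX analogue of the two placements is sketched below. `inner_body` is jitted on its own and called from un-jitted Python glue. `run_outer` instead jits the whole entry point, so the glue and the body are compiled together. Treating `simulate`'s bookkeeping as the "glue" in this picture is our assumption about why the inner placement is slower, not a description of GenJAX internals. All names here are hypothetical.

```python
import jax


@jax.jit
def inner_body(key, u):
    """Jitted body: compiled, but called from un-jitted Python code."""
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)
    return jax.random.bernoulli(key_v, p)


def run_inner(key, u):
    # Any Python-level work placed here runs eagerly on every call,
    # outside the compiled program.
    return inner_body(key, u)


# Jitting the outermost entry point compiles glue and body as one program.
run_outer = jax.jit(run_inner)

demo_key = jax.random.key(1)
v_inner = run_inner(demo_key, 1.0)
v_outer = run_outer(demo_key, 1.0)
```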
The better way is to jit the final function we aim to run, here simulate:
In [7]:
jitted = jax.jit(beta_bernoulli_process.simulate)
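Because `simulate` takes its arguments as a tuple such as `(1.0,)`, `jax.jit` treats that tuple as a pytree of traced values. The consequences are as follows:

* Calling the jitted function again with a different value of u reuses the compiled program.
* Changing the structure of the argument tuple triggers a fresh trace and compilation.

The sketch below shows this with a hypothetical function `simulate_like` that uses the same `(key, args)` calling convention. It counts traces through a Python side effect, which runs only while JAX is tracing.

```python
import jax

trace_count = [0]  # incremented only while JAX traces the function


def simulate_like(key, args):
    """Hypothetical function with a (key, args) calling convention."""
    trace_count[0] += 1  # Python side effect: runs at trace time, not per call
    u = args[0]
    return jax.random.beta(key, 1.0, u)


jitted_like = jax.jit(simulate_like)
demo_key = jax.random.key(2)

jitted_like(demo_key, (1.0,))      # first call: trace and compile
jitted_like(demo_key, (2.0,))      # same structure and dtype: cache hit
after_same_structure = trace_count[0]
jitted_like(demo_key, (1.0, 5.0))  # different pytree structure: retrace
after_new_structure = trace_count[0]
```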
We can then compare the speed of the three functions. For a fair comparison, we first run each jitted function once, so that compilation time is not included in the timings.
In [8]:
key, subkey = jax.random.split(key)
fast_beta_bernoulli_process.simulate(subkey, (1.0,))
key, subkey = jax.random.split(key)
jitted(subkey, (1.0,))
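The warm-up matters because the first call to a jitted function includes tracing and XLA compilation. Later calls with the same argument structure reuse the cached program. JAX also dispatches work asynchronously, so a timer should wait for the result before stopping. Otherwise it may measure only the dispatch. The sketch below times a first and a second call of a plain-JAX sampler using `time.perf_counter` and `jax.block_until_ready`. `sample` and `timed_call` are hypothetical helpers, and the measured numbers will depend on your hardware.

```python
import time

import jax


def sample(key, args):
    """Plain-JAX beta-Bernoulli sampler with a (key, args) signature."""
    (u,) = args
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)
    return jax.random.bernoulli(key_v, p)


sample_jit = jax.jit(sample)
demo_key = jax.random.key(3)


def timed_call(fn, *args):
    """Wall-clock seconds for one call, waiting until the result is ready."""
    start = time.perf_counter()
    # Block so we time the computation, not just the asynchronous dispatch.
    jax.block_until_ready(fn(*args))
    return time.perf_counter() - start


first = timed_call(sample_jit, demo_key, (1.0,))   # includes trace + compile
second = timed_call(sample_jit, demo_key, (1.0,))  # reuses compiled program
```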
In [9]:
key, subkey = jax.random.split(key)
%timeit beta_bernoulli_process.simulate(subkey, (1.0,))
key, subkey = jax.random.split(key)
%timeit fast_beta_bernoulli_process.simulate(subkey, (1.0,))
key, subkey = jax.random.split(key)
%timeit jitted(subkey, (1.0,))
431 ms ± 4.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
466 μs ± 1.32 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
81.4 μs ± 612 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)