Custom distribution
How do I create a custom distribution in GenJAX?
¶
import sys

# Install GenJAX (with the genstudio extra) only when running inside Google Colab.
# NOTE: `%pip` is an IPython magic, so this cell only runs in a notebook, not as a plain script.
if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from genjax import ChoiceMapBuilder as C
from genjax import Distribution, ExactDensity, Pytree, Weight, gen, normal, pretty
from genjax.typing import PRNGKey

# Short alias for TFP's JAX-backed distributions module.
tfd = tfp.distributions
# Root PRNG key; later cells derive fresh keys from it with `jax.random.split`.
key = jax.random.key(0)
# Presumably turns on GenJAX's pretty-printing of displayed results — confirm in GenJAX docs.
pretty()
In GenJAX, there are two simple ways to extend the language with custom distributions that the rest of the system can use seamlessly.
The first way is to add a distribution whose density can be computed exactly, by subclassing `ExactDensity`. In this case the API is what one would expect: one method (`sample`) to draw a value and one method (`logpdf`) to compute its log density.
@Pytree.dataclass
class NormalInverseGamma(ExactDensity):
    """Joint distribution over a pair ``(x, y)`` whose two components are independent.

    ``x ~ Normal(μ, σ)`` and ``y ~ InverseGamma(α, β)`` are drawn separately, so the
    joint log density is the sum of the two marginal log densities. Despite the
    name, the normal component does not depend on ``y`` (unlike the textbook
    normal-inverse-gamma distribution).
    """

    def sample(self, key: PRNGKey, μ, σ, α, β):
        """Draw ``(x, y)``, giving each component its own key split from ``key``."""
        normal_key, gamma_key = jax.random.split(key)
        return (
            tfd.Normal(μ, σ).sample(seed=normal_key),
            tfd.InverseGamma(α, β).sample(seed=gamma_key),
        )

    def logpdf(self, v, μ, σ, α, β):
        """Return ``log Normal(x; μ, σ) + log InverseGamma(y; α, β)`` for ``v = (x, y)``."""
        x, y = v
        log_px = tfd.Normal(μ, σ).log_prob(x)
        log_py = tfd.InverseGamma(α, β).log_prob(y)
        return log_px + log_py
Testing
# Create a particular instance of the distribution
norm_inv_gamma = NormalInverseGamma()


@gen
def model():
    """Draw ``(x, y)`` from the custom distribution at "xy", then ``z`` at "z"."""
    (x, y) = norm_inv_gamma(0.0, 1.0, 1.0, 1.0) @ "xy"
    # NOTE(review): y (an InverseGamma draw) is passed as normal's second argument;
    # if that argument is a standard deviation, jnp.sqrt(y) may be intended — confirm.
    z = normal(x, y) @ "z"
    return z


# Sampling from the model: consume the fresh `subkey` and keep `key` only for
# further splitting, so no key is both used for sampling and split again later.
key, subkey = jax.random.split(key)
jax.jit(model.simulate)(subkey, ())
# Computing density of joint
jax.jit(model.assess)(C["xy"].set((2.0, 2.0)) | C["z"].set(2.0), ())
The second way is to create a distribution by subclassing the `Distribution` class.
Here, the `logpdf` method is replaced by the more general `estimate_logpdf` method, which also receives a PRNG key. The distribution is asked to return an estimate of its log density at the provided value, such that exponentiating it gives an unbiased estimate of the density.
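The `sample` method is replaced by `random_weighted`. It returns a pair `(weight, sample)`: a sample $x$ from the distribution together with a weight whose exponential is an unbiased estimate of the reciprocal density $\frac{1}{p(x)}$. When the density is known exactly, as in the example below, the weight is simply $\log \frac{1}{p(x)} = -\log p(x)$.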
Here we'll create a simple mixture of Gaussians.
@Pytree.dataclass
class GaussianMixture(Distribution):
    """One-dimensional mixture of Gaussians with an optional static mean offset.

    A component index ``i`` is drawn from ``Categorical(probs)``, and the value is
    then drawn from ``Normal(means[i] + bias, vars[i])``. Note that, despite its
    name, ``vars`` is passed to ``tfd.Normal`` as its ``scale`` argument. The
    density is computed exactly, so both methods below return exact quantities.
    """

    # It can have static args: a constant offset added to every component mean.
    bias: float = Pytree.static(default=0.0)

    def _log_density(self, x, probs, means, vars):
        """Return ``log p(x) = log sum_i w_i * Normal(means[i] + bias, vars[i]).pdf(x)``.

        ``w = probs / sum(probs)``, so the density stays normalized even if
        ``probs`` does not sum exactly to one. ``x`` may be a scalar or an array of
        points; the mixture density is evaluated elementwise over ``x``.
        """
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
        weights = probs / jnp.sum(probs)
        # Add a trailing axis to x so it broadcasts against the K components:
        # the result has shape (..., K).
        component_log_probs = tfd.Normal(
            loc=means + jnp.asarray(self.bias), scale=vars
        ).log_prob(jnp.expand_dims(jnp.asarray(x), -1))
        # Weight component i by its own mixture weight; log-sum-exp avoids
        # underflow when x is far from every component mean.
        return jax.nn.logsumexp(component_log_probs, axis=-1, b=weights)

    # For distributions that can compute their densities exactly, `random_weighted`
    # should return the exact log reciprocal density log(1/p(x)) = -log p(x),
    # together with the sample x, as the pair (weight, x).
    def random_weighted(
        self, key: PRNGKey, probs, means, vars
    ) -> tuple[Weight, jax.Array]:
        """Sample ``x`` from the mixture and return ``(-log p(x), x)``."""
        # Making sure that the inputs are jnp arrays for jax compatibility.
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
        # Split up front so the component choice and the Gaussian draw consume
        # independent keys (a key must never be both used and split).
        cat_key, normal_key = jax.random.split(key)
        cat_index = tfd.Categorical(probs=probs).sample(seed=cat_key)
        normal_sample = tfd.Normal(
            loc=means[cat_index] + jnp.asarray(self.bias), scale=vars[cat_index]
        ).sample(seed=normal_key)
        weight_recip = -self._log_density(normal_sample, probs, means, vars)
        return weight_recip, normal_sample

    # For distributions that can compute their densities exactly, `estimate_logpdf`
    # should return the exact log density at x.
    def estimate_logpdf(self, key: PRNGKey, x, probs, means, vars) -> Weight:
        """Return the exact log density ``log p(x)``; ``key`` is unused."""
        return self._log_density(x, probs, means, vars)
Testing:
# Mixture instance whose static `bias` field is set (positionally) to 0.0.
gauss_mix = GaussianMixture(0.0)


@gen
def model(probs):
    """Make two draws, at "mix1" and "mix2", from the same two-component mixture."""
    component_means = jnp.array([0.0, 1.0])
    component_scales = jnp.array([1.0, 1.0])
    mix1 = gauss_mix(probs, component_means, component_scales) @ "mix1"
    mix2 = gauss_mix(probs, component_means, component_scales) @ "mix2"
    return mix1, mix2


probs = jnp.array([0.5, 0.5])
# Sampling from the model, with a freshly split key.
key, subkey = jax.random.split(key)
jax.jit(model.simulate)(subkey, (probs,))
# Computing density of joint: both addresses are constrained, so importance
# should not need to sample anything new (NOTE(review): confirm in GenJAX docs).
key, subkey = jax.random.split(key)
jax.jit(model.importance)(subkey, C["mix1"].set(3.0) | C["mix2"].set(4.0), (probs,))