Stochastic probabilities
How to create and use distributions with inexact likelihood evaluations
import sys
if "google.colab" in sys.modules:
%pip install --quiet "genjax[genstudio]"
This notebook builds on top of the custom_distribution
one.
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
import genjax
from genjax import ChoiceMapBuilder as C
from genjax import Pytree, Weight, pretty
from genjax._src.generative_functions.distributions.distribution import Distribution
from genjax.typing import Any
tfd = tfp.distributions
key = jax.random.key(0)
pretty()
Recall how we defined a distribution for a Gaussian mixture, using the Distribution
class.
@Pytree.dataclass
class GaussianMixture(Distribution):
def random_weighted(
self, key: jax.random.key, probs, means, vars
) -> tuple[Weight, Any]:
probs = jnp.asarray(probs)
means = jnp.asarray(means)
vars = jnp.asarray(vars)
cat = tfd.Categorical(probs=probs)
cat_index = jnp.asarray(cat.sample(seed=key))
normal = tfd.Normal(loc=means[cat_index], scale=vars[cat_index])
key, subkey = jax.random.split(key)
normal_sample = normal.sample(seed=subkey)
zipped = jnp.stack([jnp.arange(0, len(probs)), means, vars], axis=1)
weight_recip = -jax.scipy.special.logsumexp(
jax.vmap(
lambda z: tfd.Normal(loc=z[1], scale=z[2]).log_prob(normal_sample)
+ tfd.Categorical(probs=probs).log_prob(z[0])
)(zipped)
)
return weight_recip, normal_sample
def estimate_logpdf(self, key: jax.random.key, x, probs, means, vars) -> Weight:
zipped = jnp.stack([jnp.arange(0, len(probs)), means, vars], axis=1)
return jax.scipy.special.logsumexp(
jax.vmap(
lambda z: tfd.Normal(loc=z[1], scale=z[2]).log_prob(x)
+ tfd.Categorical(probs=probs).log_prob(z[0])
)(zipped)
)
In the class above, note in estimate_logpdf
how we computed the density as a sum over all the possible paths in the mixture that could lead to a particular outcome x
.
In fact, the same occurs in random_weighted
: even though we know exactly the path we took to get to the sample normal_sample
, when evaluating the reciprocal density, we also sum over all possible paths that could lead to that value
.
Precisely, this required summing over all the possible values of the categorical distribution cat
. We technically sampled two random values cat_index
and normal_sample
, but we are only interested in the distribution on normal_sample
: we marginalized out the intermediate random variable cat_index
.
Mathematically, we have
p(normal_sample) = sum_{cat_index} p(normal_sample, cat_index)
.
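As a quick sanity check, here is a small sketch (with made-up example parameters, suffixed _ex) showing that this sum over components agrees with TFP's built-in mixture density:
# Hedged check: the sum over cat_index matches tfd.MixtureSameFamily's density.
# The *_ex values are made up for illustration.
probs_ex = jnp.array([0.3, 0.7])
means_ex = jnp.array([0.0, 2.0])
vars_ex = jnp.array([1.0, 0.5])
x_ex = 1.0
log_terms = jax.vmap(
    lambda i: tfd.Categorical(probs=probs_ex).log_prob(i)
    + tfd.Normal(loc=means_ex[i], scale=vars_ex[i]).log_prob(x_ex)
)(jnp.arange(len(probs_ex)))
mixture = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=probs_ex),
    components_distribution=tfd.Normal(loc=means_ex, scale=vars_ex),
)
jax.scipy.special.logsumexp(log_terms), mixture.log_prob(x_ex)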
GenJAX supports a more general kind of distribution, which only needs to be able to estimate its density. The correctness criterion for this to be valid is that the estimate must be unbiased, i.e. correct on average.
More precisely, estimate_logpdf
should return an unbiased density estimate, while random_weighted
should return an unbiased estimate for the reciprocal density. In general you can't get one from the other, as the following example shows.
Flip a coin and with $50\%$ chance return $1$, otherwise return $3$. This gives an unbiased estimator of $2$. If we instead return $\frac{1}{1}$ with $50\%$ chance and $\frac{1}{3}$ otherwise, the average value is $\frac{2}{3}$, which is not $\frac{1}{2}$: an unbiased estimate of a quantity does not yield an unbiased estimate of its reciprocal.
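Here is a minimal sketch checking this numerically (the sample size is arbitrary):
# Numerical illustration: the mean of the samples is about 2, but the mean of
# their reciprocals is about 2/3, not 1/2.
key, subkey = jax.random.split(key)
coin = jax.random.bernoulli(subkey, 0.5, shape=(100_000,))
x_samples = jnp.where(coin, 1.0, 3.0)
jnp.mean(x_samples), jnp.mean(1.0 / x_samples)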
Let's now define a Gaussian mixture distribution that only estimates its density.
@Pytree.dataclass
class StochasticGaussianMixture(Distribution):
def random_weighted(
self, key: jax.random.key, probs, means, vars
) -> tuple[Weight, Any]:
probs = jnp.asarray(probs)
means = jnp.asarray(means)
vars = jnp.asarray(vars)
cat = tfd.Categorical(probs=probs)
cat_index = jnp.asarray(cat.sample(seed=key))
normal = tfd.Normal(loc=means[cat_index], scale=vars[cat_index])
key, subkey = jax.random.split(key)
normal_sample = normal.sample(seed=subkey)
# We can estimate the reciprocal (marginal) density in constant time. Math magic explained at the end!
weight_recip = -tfd.Normal(
loc=means[cat_index], scale=vars[cat_index]
).log_prob(normal_sample)
return weight_recip, normal_sample
# Given a sample `x`, we can also estimate the density in constant time
# Math again explained at the end.
# TODO: we could probably improve further with a better proposal
def estimate_logpdf(self, key: jax.random.key, x, probs, means, vars) -> Weight:
cat = tfd.Categorical(probs=probs)
cat_index = jnp.asarray(cat.sample(seed=key))
return tfd.Normal(loc=means[cat_index], scale=vars[cat_index]).log_prob(x)
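Before wiring it into a model, here is a minimal sketch calling the two methods directly (the demo_* values are made up for illustration):
# Hedged usage sketch: one draw plus a density estimate for it.
sgm_demo = StochasticGaussianMixture()
demo_probs = jnp.array([0.5, 0.5])
demo_means = jnp.array([-1.0, 1.0])
demo_vars = jnp.array([1.0, 1.0])
key, subkey = jax.random.split(key)
w_recip, sample = sgm_demo.random_weighted(subkey, demo_probs, demo_means, demo_vars)
key, subkey = jax.random.split(key)
logpdf_est = sgm_demo.estimate_logpdf(subkey, sample, demo_probs, demo_means, demo_vars)
w_recip, sample, logpdf_est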
To test, we start by creating a generative function using our new distribution.
sgm = StochasticGaussianMixture()
@genjax.gen
def model(cat_probs, means, vars):
x = sgm(cat_probs, means, vars) @ "x"
y_means = jnp.repeat(x, len(means))
y = sgm(cat_probs, y_means, vars) @ "y"
return (x, y)
We can then simulate from the model, assess a trace, or use importance sampling with the default proposal, seamlessly.
cat_probs = jnp.array([0.1, 0.4, 0.2, 0.3])
means = jnp.array([0.0, 1.0, 2.0, 3.0])
vars = jnp.array([1.0, 1.0, 1.0, 1.0])
key, subkey = jax.random.split(key)
tr = model.simulate(subkey, (cat_probs, means, vars))
tr
# TODO: assess currently raises a not implemented error, but we can use importance with a full trace instead
# model.assess(tr.get_choices(), (cat_probs, means, vars))
key, subkey = jax.random.split(key)
_, w = model.importance(subkey, tr.get_choices(), (cat_probs, means, vars))
w
y = 2.0
key, subkey = jax.random.split(key)
model.importance(subkey, C["y"].set(y), (cat_probs, means, vars))
Let's also check that estimate_logpdf
from our distribution sgm
indeed correctly estimates the density.
gm = GaussianMixture()
x = 2.0
N = 42
n_estimates = 2000000
cat_probs = jnp.array(jnp.arange(1.0 / N, 1.0 + 1.0 / N, 1.0 / N))
cat_probs = cat_probs / jnp.sum(cat_probs)
means = jnp.arange(0.0, N * 1.0, 1.0)
vars = jnp.ones(N) / N
key, subkey = jax.random.split(key)
log_density = gm.estimate_logpdf(subkey, x, cat_probs, means, vars) # exact value
log_density
jitted = jax.jit(jax.vmap(sgm.estimate_logpdf, in_axes=(0, None, None, None, None)))
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, n_estimates)
estimates = jitted(keys, x, cat_probs, means, vars)
log_mean_estimates = jax.scipy.special.logsumexp(estimates) - jnp.log(len(estimates))
log_density, log_mean_estimates
One benefit of using density estimates instead of exact densities is that they can be much faster to compute. Here's a way to test it, though the gain will not show on this example, as it is too simple; we will explore examples in other notebooks where this shines more brightly.
N = 30000
n_estimates = 10
cat_probs = jnp.array(jnp.arange(1.0 / N, 1.0 + 1.0 / N, 1.0 / N))
cat_probs = cat_probs / jnp.sum(cat_probs)
means = jnp.arange(0.0, N * 1.0, 1.0)
vars = jnp.ones(N) / N
jitted_exact = jax.jit(gm.estimate_logpdf)
jitted_approx = jax.jit(
lambda key, x, cat_probs, means, vars: jax.scipy.special.logsumexp(
jax.vmap(sgm.estimate_logpdf, in_axes=(0, None, None, None, None))(
key, x, cat_probs, means, vars
)
)
- jnp.log(n_estimates)
)
# warmup the jit
key, subkey = jax.random.split(key)
jitted_exact(subkey, x, cat_probs, means, vars)
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, n_estimates)
jitted_approx(keys, x, cat_probs, means, vars)
key, subkey = jax.random.split(key)
%timeit jitted_exact(subkey, x, cat_probs, means, vars)
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, n_estimates)
%timeit jitted_approx(keys, x, cat_probs, means, vars)
2.92 ms ± 45.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.94 ms ± 37 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Now, the reason we need both methods random_weighted
and estimate_logpdf
is that they are used at different times, notably depending on whether we use the distribution in a proposal or in a model, as we show next.
Let's define a simple model and a proposal which both use our sgm
distribution.
@genjax.gen
def model(cat_probs, means, vars):
x = sgm(cat_probs, means, vars) @ "x"
y_means = jnp.repeat(x, len(means))
y = sgm(cat_probs, y_means, vars) @ "y"
return (x, y)
@genjax.gen
def proposal(obs, cat_probs, means, vars):
y = obs["y"]
# simple logic to propose a new x: its mean was presumably closer to y
new_means = jax.vmap(lambda m: (m + y) / 2)(means)
x = sgm(cat_probs, new_means, vars) @ "x"
return (x, y)
Let's define importance sampling once again. Note that it is exactly the same as the usual one!
This is because behind the scenes GenJAX implements simulate
using random_weighted
and assess
using estimate_logpdf
.
def gensp_importance_sampling(target, obs, proposal):
def _inner(key, target_args, proposal_args):
key, subkey = jax.random.split(key)
trace = proposal.simulate(key, *proposal_args)
        chm = obs | trace.get_choices()
proposal_logpdf = trace.get_score()
# TODO: using importance instead of assess, as assess is not implemented
_, target_logpdf = target.importance(subkey, chm, *target_args)
importance_weight = target_logpdf - proposal_logpdf
return (trace, importance_weight)
return _inner
Testing
obs = C["y"].set(2.0)
key, subkey = jax.random.split(key)
gensp_importance_sampling(model, obs, proposal)(
subkey, ((cat_probs, means, vars),), ((obs, cat_probs, means, vars),)
)
Finally, for those curious about the math magic that enabled us to correctly (meaning unbiasedly) estimate the pdf and its reciprocal, there's a follow-up cookbook on this!