Introduction
import sys
if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"
GenJAX is a Swiss Army knife for probabilistic machine learning: it's designed to support probabilistic modeling workflows and to make the resulting code extremely fast and parallelizable via JAX.
In this introduction, we'll focus on one such workflow: writing a latent variable model (often called a generative model), which describes a joint probability distribution over latent variables and data, and then asking questions about the conditional distribution of the latent variables given the data.
In the following, we'll often shorten GenJAX to Gen -- because GenJAX implements Gen.
import genstudio.plot as Plot
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from jax import jit, vmap
from jax import random as jrand
import genjax
from genjax import gen, normal, pretty
sns.set_theme(style="white")
plt.rcParams["figure.facecolor"] = "none"
plt.rcParams["savefig.transparent"] = True
%config InlineBackend.figure_format = 'svg'
pretty() # pretty print the types
Generative functions
@gen
def model():
    x = normal(0.0, 1.0) @ "x"
    normal(x, 1.0) @ "y"
model
In Gen, probabilistic models are represented by a computational object called a generative function. Once we create one of these objects, we can use one of several interfaces to compute with the probability distribution it represents.
Here's one interface: simulate. It samples from the probability distribution that the program represents and stores the result, along with other data about the invocation of the function, in a data structure called a Trace.
key = jrand.key(0)
tr = model.simulate(key, ())
tr
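To make concrete what simulate records, here is a minimal plain-JAX sketch of ancestral sampling for the same two-variable model. It is not GenJAX's implementation: sample_model is a hypothetical helper, and a plain dict stands in for the choice map. It draws x first, then y given x, splitting the PRNG key so that each random choice gets independent randomness.

```python
import jax.numpy as jnp
from jax import random as jrand

def sample_model(key):
    """Hypothetical plain-JAX sketch (not GenJAX internals) of sampling
    x ~ Normal(0, 1), then y ~ Normal(x, 1)."""
    key_x, key_y = jrand.split(key)
    # First random choice, recorded at address "x": a standard normal draw.
    x = jrand.normal(key_x)
    # Second random choice, address "y": unit-variance noise around x.
    y = x + jrand.normal(key_y)
    # A dict keyed by address plays the role of the choice map here.
    return {"x": x, "y": y}

choices = sample_model(jrand.key(0))
print(choices["x"], choices["y"])
```

Because JAX randomness is explicit, calling sample_model again with the same key reproduces the same choices, which mirrors why simulate takes a key argument.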
We can dig around in this object using its interfaces:
chm = tr.get_choices()
chm
A ChoiceMap is a representation of the sample from the probability distribution that the generative function represents. We can ask what values were sampled at the addresses (the "x" and "y" syntax in our model):
(chm["x"], chm["y"])
Neat -- all of our interfaces are JAX compatible, so we can draw 1000 samples just by using jax.vmap (and compile the batched call with jax.jit):
sub_keys = jrand.split(jrand.key(0), 1000)
tr = jit(vmap(model.simulate, in_axes=(0, None)))(sub_keys, ())
tr
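The same batching pattern works for any pure JAX sampler. The sketch below repeats the hypothetical plain-JAX sample_model (so the block is self-contained) and vmaps it over 1000 split keys; each leaf of the returned dict gains a leading batch axis, much as the fields of the batched trace do. As a sanity check on the model itself, x should look like a draw from N(0, 1) and y like a draw from N(0, 2), since y is x plus independent unit-variance noise.

```python
import jax.numpy as jnp
from jax import jit, vmap
from jax import random as jrand

def sample_model(key):
    # Hypothetical plain-JAX sampler for x ~ N(0, 1), y ~ N(x, 1).
    key_x, key_y = jrand.split(key)
    x = jrand.normal(key_x)
    y = x + jrand.normal(key_y)
    return {"x": x, "y": y}

# One key per sample; vmap maps over the leading axis of sub_keys.
sub_keys = jrand.split(jrand.key(0), 1000)
samples = jit(vmap(sample_model))(sub_keys)

print(samples["x"].shape)  # (1000,)
# Marginally Var(y) = Var(x) + 1 = 2, so the empirical variance should be near 2.
print(samples["x"].mean(), samples["y"].var())
```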
Let's plot our samples to get a sense of the distribution we wrote down.
chm = tr.get_choices()
Plot.dot({"x": chm["x"], "y": chm["y"]})
Traces also keep track of other data, like the score of the execution: the log of the joint probability density of the random choices under the distribution. Since tr now holds a batch of 1000 traces, this returns one score per sample:
tr.get_score()
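For this model every random choice is traced, so we would expect the score to equal the log joint density log N(x; 0, 1) + log N(y; x, 1), evaluated at the sampled x and y. The sketch below writes that density out with jax.scipy.stats.norm; log_joint is a hypothetical helper for checking the score by hand, not part of the GenJAX API.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_joint(x, y):
    """Hypothetical helper: log N(x; 0, 1) + log N(y; x, 1) for the model above."""
    return norm.logpdf(x, loc=0.0, scale=1.0) + norm.logpdf(y, loc=x, scale=1.0)

# At x = y = 0 each term is -0.5 * log(2 * pi), so the sum is -log(2 * pi).
print(log_joint(0.0, 0.0))
print(-jnp.log(2 * jnp.pi))
```

With the batched trace above, comparing log_joint(chm["x"], chm["y"]) against tr.get_score() elementwise is one way to confirm this interpretation of the score.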
Composition of generative functions
Generative functions are probabilistic building blocks. You can combine them into larger probability distributions:
# A regression distribution.
@gen
def regression(x, coefficients, sigma):
    basis_value = jnp.array([1.0, x, x**2])
    polynomial_value = jnp.sum(basis_value * coefficients)
    y = genjax.normal(polynomial_value, sigma) @ "v"
    return y
# Regression, with an outlier random variable.
@gen
def regression_with_outlier(x, coefficients):
    is_outlier = genjax.flip(0.1) @ "is_outlier"
    # Outliers get a much wider noise scale than inliers.
    sigma = jnp.where(is_outlier, 30.0, 0.3)
    return regression(x, coefficients, sigma) @ "y"
# The full model: sample coefficients for a curve, and then use
# them in independent draws from the regression submodel.
@gen
def full_model(xs):
    coefficients = (
        genjax.mv_normal(
            jnp.zeros(3, dtype=float),
            2.0 * jnp.identity(3),
        )
        @ "alpha"
    )
    ys = regression_with_outlier.vmap(in_axes=(0, None))(xs, coefficients) @ "ys"
    return ys
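To spell out the generative process that full_model describes, here is a plain-JAX sketch of the same steps. It assumes that the second argument to genjax.mv_normal is a covariance matrix and that genjax.flip(0.1) returns True with probability 0.1; sample_full_model is a hypothetical helper, not how GenJAX executes the model. The coefficients are drawn once and shared, while each point's outlier flag and noisy value are drawn independently, which is what the .vmap(in_axes=(0, None)) call expresses.

```python
import jax
import jax.numpy as jnp
from jax import random as jrand

def sample_full_model(key, xs):
    """Hypothetical plain-JAX sketch of full_model's generative process.
    Assumes mv_normal's second argument is a covariance matrix."""
    key_alpha, key_ys = jrand.split(key)
    # Shared quadratic coefficients, ordered for the basis [1, x, x**2].
    alpha = jrand.multivariate_normal(key_alpha, jnp.zeros(3), 2.0 * jnp.eye(3))

    def point(key, x):
        key_flip, key_v = jrand.split(key)
        # Each point is an outlier with probability 0.1.
        is_outlier = jrand.bernoulli(key_flip, 0.1)
        sigma = jnp.where(is_outlier, 30.0, 0.3)
        mean = jnp.dot(jnp.array([1.0, x, x**2]), alpha)
        return is_outlier, mean + sigma * jrand.normal(key_v)

    # Independent per-point draws: map over xs and keys, reuse alpha.
    keys = jrand.split(key_ys, xs.shape[0])
    is_outlier, ys = jax.vmap(point)(keys, xs)
    return alpha, is_outlier, ys

demo_xs = jnp.arange(0, 10, 0.5)
alpha_draw, outlier_draw, ys_draw = sample_full_model(jrand.key(0), demo_xs)
print(alpha_draw.shape, outlier_draw.shape, ys_draw.shape)  # (3,) (20,) (20,)
```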
Now, let's examine a sample from this model:
data = jnp.arange(0, 10, 0.5)
full_model.simulate(key, (data,)).get_choices()["ys", :, "y", "v"]
We can plot a few such samples.
key, *sub_keys = jrand.split(key, 10)
traces = vmap(lambda k: full_model.simulate(k, (data,)))(jnp.array(sub_keys))
ys = traces.get_choices()["ys", :, "y", "v"]
(
    Plot.dot(
        Plot.dimensions(ys, ["sample", "ys"], leaves="y"),
        {"x": Plot.repeat(data), "y": "y", "facetGrid": "sample"},
    )
    + Plot.frame()
)
These are samples from the distribution over curves which our generative function represents.
Inference in generative functions
So we've written a regression model, a distribution over curves, which includes an outlier component. If we observe some data for "y", can we predict which points might be outliers?
x = jnp.array([0.3, 0.7, 1.1, 1.4, 2.3, 2.5, 3.0, 4.0, 5.0])
y = 2.0 * x + 1.5 + x**2
y = y.at[2].set(50.0)
y
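The observed data is a noiseless quadratic, y = 1.5 + 2x + x^2, with one point overwritten. In the basis [1, x, x**2] that regression uses, that curve corresponds to the coefficient vector (1.5, 2.0, 1.0). The sketch below (true_coefficients and clean_y are illustrative names, not from GenJAX) checks this and shows that the only large residual is the planted outlier at index 2.

```python
import jax.numpy as jnp

x = jnp.array([0.3, 0.7, 1.1, 1.4, 2.3, 2.5, 3.0, 4.0, 5.0])
# Coefficients ordered to match the basis [1, x, x**2] used in `regression`.
true_coefficients = jnp.array([1.5, 2.0, 1.0])
basis = jnp.stack([jnp.ones_like(x), x, x**2], axis=-1)  # shape (9, 3)
clean_y = basis @ true_coefficients

y = 2.0 * x + 1.5 + x**2
y = y.at[2].set(50.0)
# Residuals are (numerically) zero except at the planted outlier, index 2.
print(y - clean_y)
```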
We've explored how generative functions represent joint distributions over random variables, but what about distributions induced by inference problems?
We can create an inference problem by pairing a generative function and its arguments with a constraint.
First, let's learn how to create one type of constraint: a choice map, just like the choice maps we saw earlier.
from genjax import ChoiceMapBuilder as C
chm = C["ys", :, "y", "v"].set(y)
chm["ys", :, "y", "v"]
The choice map holds the observed values, placed at the addresses where our generative function samples "v" for each data point. Choice maps are a lot like arrays, with a bit of extra metadata.
Now, we can specify an inference target.
from genjax import Target
target = Target(full_model, (x,), chm)
target
A Target represents an unnormalized distribution: in this case, the posterior of the distribution represented by our generative function with arguments args = (x,), conditioned on the constraints in chm.
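Written out for this model, the target is the joint density with the observations plugged in. As a sketch, assuming that genjax.mv_normal takes a covariance matrix, that genjax.normal takes a standard deviation, and writing $o_i \in \{0, 1\}$ for the outlier flag of point $i$ and $\alpha$ for the coefficients, the unnormalized posterior is

$$
\tilde{\pi}(\alpha, o_{1:n}) = \mathcal{N}(\alpha; \mathbf{0}, 2I) \prod_{i=1}^{n} 0.1^{o_i}\, 0.9^{1 - o_i}\, \mathcal{N}\!\left(y_i;\ \alpha_0 + \alpha_1 x_i + \alpha_2 x_i^2,\ \sigma(o_i)^2\right),
\qquad \sigma(o_i) = \begin{cases} 30 & o_i = 1 \\ 0.3 & o_i = 0 \end{cases}
$$

and the posterior is $\tilde{\pi} / Z$ with $Z = p(\mathbf{y} \mid \mathbf{x})$. Approximate inference lets us work with $\tilde{\pi}$ without computing $Z$.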
Now, we can approximate the solution to the inference problem using an inference algorithm. GenJAX exposes a standard library of approximate inference algorithms: let's use $K$-particle importance sampling for this one.
from genjax.inference.smc import ImportanceK
alg = ImportanceK(target, k_particles=100)
alg
sub_keys = jrand.split(key, 50)
posterior_samples = jit(vmap(alg(target)))(sub_keys)
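The idea behind $K$-particle importance sampling can be sketched on the simpler two-variable model from earlier, where the exact posterior is known: with x ~ N(0, 1) and y ~ N(x, 1), observing y gives x | y ~ N(y/2, 1/2). The sketch below proposes K particles from the prior, weights each by the likelihood of the observation, and resamples one particle in proportion to its weight. This is one common variant (sampling importance resampling), written as a hypothetical importance_resample helper; we don't claim it matches how ImportanceK is implemented internally.

```python
import jax
from jax import random as jrand
from jax.scipy.stats import norm

def importance_resample(key, y_obs, k_particles=100):
    """Hypothetical sketch: K-particle importance sampling with resampling for
    x ~ N(0, 1), y ~ N(x, 1), using the prior over x as the proposal."""
    key_prop, key_pick = jrand.split(key)
    # Propose K particles for the latent x from the prior.
    xs = jrand.normal(key_prop, (k_particles,))
    # With a prior proposal, each log weight is the log likelihood of y_obs.
    log_w = norm.logpdf(y_obs, loc=xs, scale=1.0)
    # Pick one particle with probability proportional to exp(log_w).
    idx = jrand.categorical(key_pick, log_w)
    return xs[idx]

y_obs = 2.0
keys = jrand.split(jrand.key(0), 2000)
draws = jax.jit(jax.vmap(lambda k: importance_resample(k, y_obs)))(keys)
# Exact posterior is N(1.0, 0.5); the draws should roughly match it.
print(draws.mean(), draws.var())
```

Increasing k_particles makes each resampled draw a better approximation to a posterior sample, at the cost of more likelihood evaluations per draw.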
With samples from our approximate posterior in hand, we can answer queries such as "estimate the probability that a point is an outlier":
posterior_samples["ys", :, "is_outlier"]
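Averaging these indicators over the posterior samples estimates, for each data point, the probability that it is an outlier. Our approximate posterior assigns high probability to the event "the 3rd data point is an outlier". Remember, we set this point to be far away from the other points.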
posterior_samples["ys", :, "is_outlier"].mean(axis=0)
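Averaging boolean samples is a Monte Carlo estimate of a probability, and its precision depends on the number of samples: the standard error is roughly sqrt(p(1 - p) / N). With the 50 posterior samples drawn above, that is at most sqrt(0.25 / 50), about 0.07 per data point. The sketch below, with a hypothetical estimate_probability helper, checks the estimator on synthetic Bernoulli(0.3) indicators where the true probability is known.

```python
import jax.numpy as jnp
from jax import random as jrand

def estimate_probability(indicators):
    """Hypothetical helper: Monte Carlo estimate of P(event) from boolean samples
    along axis 0, with standard error sqrt(p_hat * (1 - p_hat) / N)."""
    n = indicators.shape[0]
    p_hat = indicators.mean(axis=0)
    stderr = jnp.sqrt(p_hat * (1.0 - p_hat) / n)
    return p_hat, stderr

# Synthetic check with a known probability of 0.3.
samples = jrand.bernoulli(jrand.key(0), 0.3, shape=(10_000,))
p_hat, stderr = estimate_probability(samples)
print(p_hat, stderr)
```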
We can also plot the sampled curves against the data.
def polynomial_at_x(x, coefficients):
    basis_values = jnp.array([1.0, x, x**2])
    polynomial_value = jnp.sum(coefficients * basis_values)
    return polynomial_value
jitted = jit(vmap(polynomial_at_x, in_axes=(None, 0)))
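coefficients = posterior_samples["alpha"]
evaluation_points = jnp.arange(0, 5, 0.01)
# One (x, y) pair per evaluation point and posterior sample; px, py avoid reusing the data names x, y.
points = [(px, py) for px in evaluation_points for py in jitted(px, coefficients).tolist()]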
(
Plot.dot(points, fill="gold", opacity=0.25, r=0.5)
+ Plot.dot({"x": x, "y": y})
+ Plot.frame()
)
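The list comprehension above calls the jitted function once per evaluation point. Because every curve is linear in its coefficients, all curves can instead be evaluated at all points with a single matrix product between a (P, 3) basis matrix and the transposed (S, 3) coefficient array. The sketch below uses random demo_coefficients as a hypothetical stand-in for posterior_samples["alpha"], and flattens the result in the same point-major order as the comprehension.

```python
import jax.numpy as jnp
from jax import random as jrand

def evaluate_polynomials(evaluation_points, coefficients):
    """Evaluate every sampled quadratic at every point in one matrix product.
    evaluation_points: shape (P,); coefficients: shape (S, 3), ordered [1, x, x**2].
    Returns an array of shape (P, S)."""
    basis = jnp.stack(
        [jnp.ones_like(evaluation_points), evaluation_points, evaluation_points**2],
        axis=-1,
    )  # shape (P, 3)
    return basis @ coefficients.T

# Hypothetical stand-in for posterior_samples["alpha"]: 50 coefficient vectors.
demo_coefficients = jrand.normal(jrand.key(0), (50, 3))
evaluation_points = jnp.arange(0, 5, 0.01)
values = evaluate_polynomials(evaluation_points, demo_coefficients)

# Flatten to parallel x/y arrays: each point repeated once per sampled curve.
flat_x = jnp.repeat(evaluation_points, demo_coefficients.shape[0])
flat_y = values.reshape(-1)
```

The two flat arrays could then be passed as columns, as in Plot.dot({"x": flat_x, "y": flat_y}, ...), in place of the list of tuples.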
Summary
We’ve covered a lot of ground in this notebook. Please reflect, re-read, and post issues!
- We discussed generative functions - the main computational object of Gen, and how these objects represent probability distributions.
- We showed how to create generative functions.
- We showed how to use interfaces on generative functions to perform common operations on distributions, such as sampling and scoring.
- We created a generative function to model a data-generating process based on sampling and evaluating random polynomials on input data, representing a regression task.
- We showed how to create inference problems from generative functions.
- We created an inference problem from our regression model.
- We showed how to create approximate inference solutions to inference problems, and sample from them.
- We investigated the approximate posterior samples and visually checked that they match the inferences we might draw, both about the polynomials we expect to have produced the data and about which data points might be outliers.
This is just the beginning! There’s a lot more to learn, but this is plenty to chew on (for now).