Mixture
How can I write a mixture of models in GenJAX?
import sys
if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"
from jax import random
from genjax import flip, gen, inverse_gamma, mix, normal
key = random.key(0)
We simply use the mix combinator. Note that the trace of the mixture is the join of the traces of the different components.
We first define the three components of the mixture model as generative functions.
@gen
def mixture_component_1(p):
    x = normal(p, 1.0) @ "x"
    return x

@gen
def mixture_component_2(p):
    b = flip(p) @ "b"
    return b

@gen
def mixture_component_3(p):
    y = inverse_gamma(p, 0.1) @ "y"
    return y
The mix combinator takes as input the logits of the mixture components, followed by one tuple of arguments for each component of the mixture.
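The tuple (0.3, 0.5, 0.2) used in the model below sums to 1, so it is tempting to read it as a list of probabilities. If mix treats these numbers as logits in the usual sense, the actual mixture weights are different. Under that convention, a categorical distribution parameterized by logits normalizes them with a softmax. The following standard-library sketch computes the implied weights. It also shows that passing the logarithms of the intended probabilities recovers those probabilities. The softmax helper is our own illustration, not a GenJAX function.

```python
import math

def softmax(logits):
    """Normalize logits into probabilities (shifting by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# The numbers read as logits: the implied weights are not 0.3 / 0.5 / 0.2.
logits = (0.3, 0.5, 0.2)
print([round(q, 3) for q in softmax(logits)])  # [0.32, 0.391, 0.289]

# To get mixture weights of exactly 0.3, 0.5 and 0.2, pass their logarithms instead.
probs = (0.3, 0.5, 0.2)
log_probs = tuple(math.log(q) for q in probs)
print([round(q, 3) for q in softmax(log_probs)])  # [0.3, 0.5, 0.2]
```

Whether to keep the raw numbers or their logarithms depends on which weights the model is meant to have. Both are valid logits; they simply describe different mixtures.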
@gen
def mixture_model(p):
    z = normal(p, 1.0) @ "z"
    # Logits (unnormalized log-weights) of the three components.
    logits = (0.3, 0.5, 0.2)
    # One argument tuple per component, in the same order as the components.
    arg_1 = (p,)
    arg_2 = (p,)
    arg_3 = (p,)
    a = (
        mix(mixture_component_1, mixture_component_2, mixture_component_3)(
            logits, arg_1, arg_2, arg_3
        )
        @ "a"
    )
    return a + z

key, subkey = random.split(key)
tr = mixture_model.simulate(subkey, (0.4,))
print("return value:", tr.get_retval())
print("value for z:", tr.get_choices()["z"])
return value: 0.06437492
value for z: -0.9356251
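Since mixture_model returns a + z, the value contributed by the mixture can be recovered from the two printed numbers. Here is a quick check using the values copied from the output above:

```python
# Values copied from the printed output above.
retval = 0.06437492
z = -0.9356251

# mixture_model returns a + z, so the mixture's contribution is the difference.
a = retval - z
print(round(a, 6))  # 1.0
```

Up to floating-point rounding, the mixture returned exactly 1. That matches a Boolean True from the flip component being added to a float. A draw from the normal or inverse-gamma component would almost surely not be an exact integer.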
The combinator records the index of the sampled component at the fixed address "mixture_component", nested under the address at which the mixture was called ("a" here).
print("value for the mixture_component:", tr.get_choices()["a", "mixture_component"])
value for the mixture_component: 1
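Assuming the index is 0-based and follows the order in which the components were passed to mix, the value 1 refers to mixture_component_2, the flip. To make the generative process behind mix concrete, the following plain-Python sketch first draws a component index from the softmax of the logits and then runs only that component. It is an illustration, not GenJAX's implementation, and it uses standard-library stand-ins for normal, flip and inverse_gamma. The names component_1, component_2, component_3, softmax and sample_mixture are hypothetical helpers. The inverse-gamma stand-in assumes a (concentration, scale) parameterization, as in TensorFlow Probability's InverseGamma.

```python
import math
import random

# Hypothetical plain-Python stand-ins for the three GenJAX components.
# Each returns (return_value, choices); choices is keyed by the component's own address.
def component_1(p, rng):
    x = rng.gauss(p, 1.0)  # like normal(p, 1.0)
    return x, {"x": x}

def component_2(p, rng):
    b = rng.random() < p  # like flip(p): True with probability p
    return b, {"b": b}

def component_3(p, rng):
    # Assumption: inverse_gamma(p, 0.1) has concentration p and scale 0.1.
    # If G ~ Gamma(shape=p, rate=0.1), then 1 / G ~ InverseGamma(p, scale=0.1).
    # random.gammavariate takes a *scale* (1 / rate), hence 1 / 0.1.
    y = 1.0 / rng.gammavariate(p, 1.0 / 0.1)
    return y, {"y": y}

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_mixture(components, logits, args, rng):
    """Draw a component index from softmax(logits), then run only that component."""
    probs = softmax(logits)
    idx = rng.choices(range(len(components)), weights=probs)[0]
    retval, choices = components[idx](*args[idx], rng)
    return idx, retval, choices

rng = random.Random(0)
components = [component_1, component_2, component_3]
idx, retval, choices = sample_mixture(
    components, (0.3, 0.5, 0.2), [(0.4,), (0.4,), (0.4,)], rng
)
print("component index:", idx, "return value:", retval, "choices:", choices)
```

Passing [(0.4,), (0.4,), (0.4,)] mirrors arg_1, arg_2 and arg_3 in mixture_model. Giving each component its own argument tuple lets components with different signatures be combined in one mixture.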