Combinators: structured patterns of composition¶
While the programmatic `genjax.StaticGenerativeFunction` language is powerful, its restrictions can be limiting. Combinators express common patterns of composition more concisely, and give generative computations access to effects that are common in JAX (like `jax.vmap`).

Each of the combinators below is implemented as a method on `genjax.GenerativeFunction` and as a standalone decorator.
You should strongly prefer the method form. Here is the `vmap` combinator created by the `genjax.GenerativeFunction.vmap` method; `square_many` below accepts an array and returns an array:
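A minimal sketch of the method form (the `square` building block and its body are illustrative, not taken from the original example):

```python
import genjax


@genjax.gen
def square(x):
    return x * x


# The method form keeps `square` available as its own building block;
# `square_many` maps it over the leading axis of its array argument.
square_many = square.vmap(in_axes=0)
```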
Here is `square_many` defined with `genjax.vmap`, the decorator version of the `vmap` method:
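A sketch of the same model in decorator form (note that no standalone `square` is defined in this style):

```python
@genjax.vmap(in_axes=0)
@genjax.gen
def square_many(x):
    return x * x
```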
Warning
We do not recommend this style, since the original building block generative function won't be available by itself. Please prefer using the combinator methods, or the transformation style shown below.
If you insist on using the decorator form, you can preserve the original function like this:
@genjax.gen
def square(x):
return x * x
# Use the decorator as a transformation:
square_many_better = genjax.vmap()(square)
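A usage sketch for the transformed function (the key and array values here are arbitrary):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(314159)
# `square_many_better` accepts an array and returns an array:
tr = jax.jit(square_many_better.simulate)(key, (jnp.ones(10),))
print(tr.get_retval())
```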
vmap-like Combinators¶
genjax.vmap¶
vmap(
*, in_axes: InAxes = 0
) -> Callable[[GenerativeFunction[R]], Vmap[R]]
Returns a decorator that wraps a `GenerativeFunction` and returns a new `GenerativeFunction` that performs a vectorized map over the argument specified by `in_axes`. Traced values are nested under an index, and the retval is vectorized.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_axes` | `InAxes` | Selector specifying which input arguments (or index into them) should be vectorized. | `0` |
Returns:

| Type | Description |
| --- | --- |
| `Callable[[GenerativeFunction[R]], Vmap[R]]` | A decorator that converts a `genjax.GenerativeFunction` into a `Vmap` generative function that maps over the argument specified by `in_axes`. |
Examples:
import jax, genjax
import jax.numpy as jnp
@genjax.vmap(in_axes=0)
@genjax.gen
def vmapped_model(x):
v = genjax.normal(x, 1.0) @ "v"
return genjax.normal(v, 0.01) @ "q"
key = jax.random.key(314159)
arr = jnp.ones(100)
# `vmapped_model` accepts an array of numbers:
tr = jax.jit(vmapped_model.simulate)(key, (arr,))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/vmap.py
genjax.repeat¶
repeat(
*, n: int
) -> Callable[
[GenerativeFunction[R]], GenerativeFunction[R]
]
Returns a decorator that wraps a `genjax.GenerativeFunction` `gen_fn` of type `a -> b` and returns a new `GenerativeFunction` of type `a -> [b]` that samples from `gen_fn` `n` times, returning a vector of `n` results.

The values traced by each call to `gen_fn` will be nested under an integer index that matches the loop iteration index that generated it.

This combinator is useful for creating multiple samples from the same generative model in a batched manner.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | The number of times to sample from the generative function. | required |
Returns:

| Type | Description |
| --- | --- |
| `Callable[[GenerativeFunction[R]], GenerativeFunction[R]]` | A new `genjax.GenerativeFunction` that samples from the wrapped function `n` times and returns a vector of results. |
Examples:
import genjax, jax
@genjax.repeat(n=10)
@genjax.gen
def normal_draws(mean):
return genjax.normal(mean, 1.0) @ "x"
key = jax.random.key(314159)
# Generate 10 draws from a normal distribution with mean 2.0
tr = jax.jit(normal_draws.simulate)(key, (2.0,))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/repeat.py
scan-like Combinators¶
genjax.scan¶
scan(*, n: int | None = None) -> Callable[
[GenerativeFunction[tuple[Carry, Y]]],
GenerativeFunction[tuple[Carry, Y]],
]
Returns a decorator that wraps a `genjax.GenerativeFunction` of type `(c, a) -> (c, b)` and returns a new `genjax.GenerativeFunction` of type `(c, [a]) -> (c, [b])` where:

- `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- `a` may be a primitive, an array type or a pytree (container) type with array leaves
- `b` may be a primitive, an array type or a pytree (container) type with array leaves
The values traced by each call to the original generative function will be nested under an integer index that matches the loop iteration index that generated it.
For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
When the type of `xs` in the snippet below (denoted `[a]` above) is an array type or `None`, and the type of `ys` in the snippet below (denoted `[b]` above) is an array type, the semantics of the returned `genjax.GenerativeFunction` are given roughly by this Python implementation:
import numpy as np

def scan(f, init, xs, length=None):
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, np.stack(ys)
Unlike that Python version, both `xs` and `ys` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. `None` is actually a special case of this, as it represents an empty pytree.

The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int \| None` | Optional integer specifying the number of loop iterations, which (if supplied) must agree with the sizes of leading axes of the arrays in the returned function's second argument. If supplied, the returned generative function can take `None` as its second argument. | `None` |
Returns:

| Type | Description |
| --- | --- |
| `Callable[[GenerativeFunction[tuple[Carry, Y]]], GenerativeFunction[tuple[Carry, Y]]]` | A new `genjax.GenerativeFunction` of type `(c, [a]) -> (c, [b])` that scans the wrapped function over the leading axis of its second argument. |
Examples:
Scan for 1000 iterations with no array input:
import jax
import genjax
@genjax.scan(n=1000)
@genjax.gen
def random_walk(prev, _):
x = genjax.normal(prev, 1.0) @ "x"
return x, None
init = 0.5
key = jax.random.key(314159)
tr = jax.jit(random_walk.simulate)(key, (init, None))
print(tr.render_html())
Scan across an input array:
import jax.numpy as jnp
@genjax.scan()
@genjax.gen
def add_and_square_all(sum, x):
new_sum = sum + x
return new_sum, sum * sum
init = 0.0
xs = jnp.ones(10)
tr = jax.jit(add_and_square_all.simulate)(key, (init, xs))
# The retval has the final carry and an array of all `sum*sum` returned.
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
genjax.accumulate¶
accumulate() -> Callable[
[GenerativeFunction[Carry]],
GenerativeFunction[Carry],
]
Returns a decorator that wraps a `genjax.GenerativeFunction` of type `(c, a) -> c` and returns a new `genjax.GenerativeFunction` of type `(c, [a]) -> [c]` where:

- `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- `[c]` is an array of all loop-carried values seen during iteration (including the first)
- `a` may be a primitive, an array type or a pytree (container) type with array leaves
All traced values are nested under an index.
For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
The semantics of the returned `genjax.GenerativeFunction` are given roughly by this Python implementation (note the similarity to `itertools.accumulate`):
def accumulate(f, init, xs):
carry = init
carries = [init]
for x in xs:
carry = f(carry, x)
carries.append(carry)
return carries
Unlike that Python version, both `xs` and `carries` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Examples:
accumulate a running total:
import jax
import genjax
import jax.numpy as jnp
@genjax.accumulate()
@genjax.gen
def add(sum, x):
new_sum = sum + x
return new_sum
init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)
tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
genjax.reduce¶
reduce() -> Callable[
[GenerativeFunction[Carry]],
GenerativeFunction[Carry],
]
Returns a decorator that wraps a `genjax.GenerativeFunction` of type `(c, a) -> c` and returns a new `genjax.GenerativeFunction` of type `(c, [a]) -> c` where:

- `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- `a` may be a primitive, an array type or a pytree (container) type with array leaves
All traced values are nested under an index.
For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
The semantics of the returned `genjax.GenerativeFunction` are given roughly by a Python implementation in the spirit of `functools.reduce`.
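A minimal sketch of such an implementation, mirroring the `accumulate` version above but keeping only the final carry:

```python
def reduce(f, init, xs):
    carry = init
    for x in xs:
        carry = f(carry, x)
    return carry
```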
Unlike that Python version, both `xs` and `carry` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Examples:
sum an array of numbers:
import jax
import genjax
import jax.numpy as jnp
@genjax.reduce()
@genjax.gen
def add(sum, x):
new_sum = sum + x
return new_sum
init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)
tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
genjax.iterate¶
iterate(
*, n: int
) -> Callable[
[GenerativeFunction[Y]], GenerativeFunction[Y]
]
Returns a decorator that wraps a `genjax.GenerativeFunction` of type `a -> a` and returns a new `genjax.GenerativeFunction` of type `a -> [a]` where:

- `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- `[a]` is an array of all `a`, `f(a)`, `f(f(a))` etc. values seen during iteration
All traced values are nested under an index.
The semantics of the returned `genjax.GenerativeFunction` are given roughly by this Python implementation:
def iterate(f, n, init):
input = init
seen = [init]
for _ in range(n):
input = f(input)
seen.append(input)
return seen
`init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | The number of iterations to run. | required |
Examples:
iterative addition, returning all intermediate sums:
import jax
import genjax
@genjax.iterate(n=100)
@genjax.gen
def inc(x):
return x + 1
init = 0.0
key = jax.random.key(314159)
tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
genjax.iterate_final¶
iterate_final(
*, n: int
) -> Callable[
[GenerativeFunction[Y]], GenerativeFunction[Y]
]
Returns a decorator that wraps a `genjax.GenerativeFunction` of type `a -> a` and returns a new `genjax.GenerativeFunction` of type `a -> a` where:

- `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- the original function is invoked `n` times with each input coming from the previous invocation's output, so that the new function returns \(f^n(a)\)
All traced values are nested under an index.
The semantics of the returned `genjax.GenerativeFunction` are given roughly by a Python implementation analogous to the `iterate` version above, but returning only the final value.
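A minimal sketch of such an implementation:

```python
def iterate_final(f, n, init):
    input = init
    for _ in range(n):
        input = f(input)
    return input
```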
`init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | The number of iterations to run. | required |
Examples:
iterative addition:
import jax
import genjax
@genjax.iterate_final(n=100)
@genjax.gen
def inc(x):
return x + 1
init = 0.0
key = jax.random.key(314159)
tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
genjax.masked_iterate¶
masked_iterate() -> (
Callable[
[GenerativeFunction[Y]], GenerativeFunction[Y]
]
)
Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a list of values of type `a`.

The original function is modified to accept an additional argument `mask`, a boolean value indicating whether the operation should be masked or not. The function returns a masked list of results of the original operation, using the input `[mask]` as the mask.

All traced values from the kernel generative function are traced (with an added axis due to the scan), but only the indices of `[mask]` whose flag is `True` will be accounted for in inference, notably for score computations.
Example
import jax
import jax.numpy as jnp
import genjax
masks = jnp.array([True, False, True])
# Create a kernel generative function
@genjax.gen
def step(x):
_ = (
genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
@ "rats"
)
return x
# Create a model using masked_iterate
model = genjax.masked_iterate()(step)
# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
genjax.masked_iterate_final¶
masked_iterate_final() -> (
Callable[
[GenerativeFunction[Y]], GenerativeFunction[Y]
]
)
Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a value of type `a`.

The original function is modified to accept an additional argument `mask`, a boolean value indicating whether the operation should be masked or not. The function returns the result of the original operation if `mask` is `True`, and the original input if `mask` is `False`.

All traced values from the kernel generative function are traced (with an added axis due to the scan), but only the indices of `[mask]` whose flag is `True` will be accounted for in inference, notably for score computations.
Example
import jax
import jax.numpy as jnp
import genjax
masks = jnp.array([True, False, True])
# Create a kernel generative function
@genjax.gen
def step(x):
_ = (
genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
@ "rats"
)
return x
# Create a model using masked_iterate_final
model = genjax.masked_iterate_final()(step)
# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
Control Flow Combinators¶
genjax.or_else¶
or_else(
if_gen_fn: GenerativeFunction[R],
else_gen_fn: GenerativeFunction[R],
) -> GenerativeFunction[R]
Given two `genjax.GenerativeFunction`s `if_gen_fn` and `else_gen_fn`, returns a new `genjax.GenerativeFunction` that accepts

- a boolean argument
- an argument tuple for `if_gen_fn`
- an argument tuple for the supplied `else_gen_fn`

and acts like `if_gen_fn` when the boolean is `True` or `else_gen_fn` otherwise.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `if_gen_fn` | `GenerativeFunction[R]` | Called when the boolean argument is `True`. | required |
| `else_gen_fn` | `GenerativeFunction[R]` | Called when the boolean argument is `False`. | required |
Returns:

| Type | Description |
| --- | --- |
| `GenerativeFunction[R]` | A `genjax.GenerativeFunction` that acts like `if_gen_fn` when the boolean argument is `True`, and like `else_gen_fn` otherwise. |
Examples:
import jax
import jax.numpy as jnp
import genjax
@genjax.gen
def if_model(x):
return genjax.normal(x, 1.0) @ "if_value"
@genjax.gen
def else_model(x):
return genjax.normal(x, 5.0) @ "else_value"
or_else_model = genjax.or_else(if_model, else_model)
@genjax.gen
def model(toss: bool):
# Note that `or_else_model` takes a new boolean predicate in
# addition to argument tuples for each branch.
return or_else_model(toss, (1.0,), (10.0,)) @ "tossed"
key = jax.random.key(314159)
tr = jax.jit(model.simulate)(key, (True,))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/or_else.py
genjax.switch¶
switch(*gen_fns: GenerativeFunction[R]) -> Switch[R]
Given `n` `genjax.GenerativeFunction` inputs, returns a `genjax.GenerativeFunction` that accepts `n+1` arguments:

- an index in the range \([0, n)\)
- a tuple of arguments for each of the input generative functions (`n` total tuples)

and executes the generative function at the supplied index with its provided arguments.

If `index` is out of bounds, `index` is clamped to within bounds.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `gen_fns` | `GenerativeFunction[R]` | Generative functions that the resulting combinator will select from based on the supplied index. | `()` |
Examples:

Create a `Switch` via the `genjax.switch` method:
import jax, genjax
@genjax.gen
def branch_1():
x = genjax.normal(0.0, 1.0) @ "x1"
@genjax.gen
def branch_2():
x = genjax.bernoulli(probs=0.3) @ "x2"
switch = genjax.switch(branch_1, branch_2)
key = jax.random.key(314159)
jitted = jax.jit(switch.simulate)
# Select `branch_2` by providing 1:
tr = jitted(key, (1, (), ()))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/switch.py
Argument and Return Transformations¶
genjax.dimap¶
dimap(
*,
pre: Callable[..., ArgTuple] = lambda *args: args,
post: Callable[
[tuple[Any, ...], ArgTuple, R], S
] = lambda _, _xformed, retval: retval
) -> Callable[
[GenerativeFunction[R]], Dimap[ArgTuple, R, S]
]
Returns a decorator that wraps a `genjax.GenerativeFunction` and applies pre- and post-processing functions to its arguments and return value.

Info

Prefer `genjax.map` if you only need to transform the return value, or `genjax.contramap` if you need to transform the arguments.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pre` | `Callable[..., ArgTuple]` | A callable that preprocesses the arguments before passing them to the wrapped function. Note that `pre` must return a tuple of arguments, not a bare argument. Default is the identity function. | `lambda *args: args` |
| `post` | `Callable[[tuple[Any, ...], ArgTuple, R], S]` | A callable that postprocesses the return value of the wrapped function. Default is the identity function. | `lambda _, _xformed, retval: retval` |
Returns:

| Type | Description |
| --- | --- |
| `Callable[[GenerativeFunction[R]], Dimap[ArgTuple, R, S]]` | A decorator that takes a `genjax.GenerativeFunction` and returns a new one with `pre` applied to its arguments and `post` applied to its return value. |
Examples:
import jax, genjax
# Define pre- and post-processing functions
def pre_process(x, y):
return (x + 1, y * 2)
def post_process(args, xformed, retval):
return retval**2
# Apply dimap to a generative function
@genjax.dimap(pre=pre_process, post=post_process)
@genjax.gen
def dimap_model(x, y):
return genjax.normal(x, y) @ "z"
# Use the dimap model
key = jax.random.key(0)
trace = dimap_model.simulate(key, (2.0, 3.0))
print(trace.render_html())
Source code in src/genjax/_src/generative_functions/combinators/dimap.py
genjax.map¶
map(
f: Callable[[R], S]
) -> Callable[
[GenerativeFunction[R]], Dimap[tuple[Any, ...], R, S]
]
Returns a decorator that wraps a `genjax.GenerativeFunction` and applies a post-processing function to its return value.

This is a specialized version of `genjax.dimap` where only the post-processing function is applied.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f` | `Callable[[R], S]` | A callable that postprocesses the return value of the wrapped function. | required |
Returns:

| Type | Description |
| --- | --- |
| `Callable[[GenerativeFunction[R]], Dimap[tuple[Any, ...], R, S]]` | A decorator that takes a `genjax.GenerativeFunction` and returns a new one with `f` applied to its return value. |
Examples:
import jax, genjax
# Define a post-processing function
def square(x):
return x**2
# Apply map to a generative function
@genjax.map(square)
@genjax.gen
def map_model(x):
return genjax.normal(x, 1.0) @ "z"
# Use the map model
key = jax.random.key(0)
trace = map_model.simulate(key, (2.0,))
print(trace.render_html())
Source code in src/genjax/_src/generative_functions/combinators/dimap.py
genjax.contramap¶
contramap(
f: Callable[..., ArgTuple]
) -> Callable[
[GenerativeFunction[R]], Dimap[ArgTuple, R, R]
]
Returns a decorator that wraps a `genjax.GenerativeFunction` and applies a pre-processing function to its arguments.

This is a specialized version of `genjax.dimap` where only the pre-processing function is applied.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f` | `Callable[..., ArgTuple]` | A callable that preprocesses the arguments of the wrapped function. Note that `f` must return a tuple of arguments, not a bare argument. | required |
Returns:

| Type | Description |
| --- | --- |
| `Callable[[GenerativeFunction[R]], Dimap[ArgTuple, R, R]]` | A decorator that takes a `genjax.GenerativeFunction` and returns a new one with `f` applied to its arguments. |
Examples:
import jax, genjax
# Define a pre-processing function.
# Note that this function must return a tuple of arguments!
def add_one(x):
return (x + 1,)
# Apply contramap to a generative function
@genjax.contramap(add_one)
@genjax.gen
def contramap_model(x):
return genjax.normal(x, 1.0) @ "z"
# Use the contramap model
key = jax.random.key(0)
trace = contramap_model.simulate(key, (2.0,))
print(trace.render_html())
Source code in src/genjax/_src/generative_functions/combinators/dimap.py
The Rest¶
genjax.mask¶
mask(f: GenerativeFunction[R]) -> MaskCombinator[R]
Combinator which enables dynamic masking of generative functions. Takes a `genjax.GenerativeFunction` and returns a new `genjax.GenerativeFunction` which accepts an additional boolean first argument.

If `True`, the wrapped function behaves as if it were invoked without masking. If `False`, the invocation of the generative function is masked, and its contribution to the score is ignored.

The return value type is a `Mask`, with a flag value equal to the supplied boolean.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f` | `GenerativeFunction[R]` | The generative function to be masked. | required |
Returns:

| Type | Description |
| --- | --- |
| `MaskCombinator[R]` | The masked version of the input generative function. |
Examples:
Masking a normal draw:
import genjax, jax
@genjax.mask
@genjax.gen
def masked_normal_draw(mean):
return genjax.normal(mean, 1.0) @ "x"
key = jax.random.key(314159)
tr = jax.jit(masked_normal_draw.simulate)(
key,
(
False,
2.0,
),
)
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/mask.py
genjax.mix¶
mix(
*gen_fns: GenerativeFunction[R],
) -> GenerativeFunction[R]
Creates a mixture model from a set of generative functions.

This function takes multiple generative functions as input and returns a new generative function that represents a mixture model.

The returned generative function takes the following arguments:

- `mixture_logits`: Logits for the categorical distribution used to select a component.
- `*args`: Argument tuples for each of the input generative functions.

It samples from one of the input generative functions based on a draw from a categorical distribution defined by the provided mixture logits.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `gen_fns` | `GenerativeFunction[R]` | Variable number of `genjax.GenerativeFunction`s to be mixed. | `()` |
Returns:

| Type | Description |
| --- | --- |
| `GenerativeFunction[R]` | A new `genjax.GenerativeFunction` representing the mixture model. |
Examples:
import jax
import genjax
# Define component generative functions
@genjax.gen
def component1(x):
return genjax.normal(x, 1.0) @ "y"
@genjax.gen
def component2(x):
return genjax.normal(x, 2.0) @ "y"
# Create mixture model
mixture = genjax.mix(component1, component2)
# Use the mixture model
key = jax.random.key(0)
logits = jax.numpy.array([0.3, 0.7]) # Favors component2
trace = mixture.simulate(key, (logits, (0.0,), (7.0,)))
print(trace.render_html())