
Combinators: structured patterns of composition

While the programmatic genjax.StaticGenerativeFunction language is powerful, its restrictions can be limiting. Combinators express common patterns of composition more concisely, and give generative computations access to effects which are common in JAX (like jax.vmap).

Each of the combinators below is implemented as a method on genjax.GenerativeFunction and as a standalone decorator.

You should strongly prefer the method form. Here is an example of the vmap combinator created via the genjax.GenerativeFunction.vmap method; square_many below accepts an array and returns an array:

import jax, genjax

@genjax.gen
def square(x):
    return x * x

square_many = square.vmap()
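
For instance (a minimal sketch, assuming the standard simulate and get_retval interface shown elsewhere on this page), square_many can be simulated with an array argument:

import jax.numpy as jnp

key = jax.random.key(0)

# Each element of the array is squared independently.
tr = jax.jit(square_many.simulate)(key, (jnp.arange(3.0),))
print(tr.get_retval())  # expected: [0. 1. 4.]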

Here is square_many defined with genjax.vmap, the decorator version of the vmap method:

@genjax.vmap()
@genjax.gen
def square_many_decorator(x):
    return x * x

Warning

We do not recommend this style, since the original building block generative function won't be available by itself. Please prefer using the combinator methods, or the transformation style shown below.

If you insist on using the decorator form, you can preserve the original function like this:

@genjax.gen
def square(x):
    return x * x

# Use the decorator as a transformation:
square_many_better = genjax.vmap()(square)
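
With the transformation style, both the building block and its vectorized form stay usable (a minimal sketch; square_many_better behaves like square_many above):

import jax.numpy as jnp

key = jax.random.key(0)

# The original single-input generative function is still available:
tr_one = jax.jit(square.simulate)(key, (3.0,))

# ... and so is the vectorized version:
tr_many = jax.jit(square_many_better.simulate)(key, (jnp.arange(4.0),))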

vmap-like Combinators

genjax.vmap

vmap(
    *, in_axes: InAxes = 0
) -> Callable[[GenerativeFunction[R]], Vmap[R]]

Returns a decorator that wraps a GenerativeFunction and returns a new GenerativeFunction that performs a vectorized map over the argument specified by in_axes. Traced values are nested under an index, and the retval is vectorized.

Parameters:

  • in_axes (InAxes, default: 0): Selector specifying which input arguments (or index into them) should be vectorized. in_axes must match (or prefix) the Pytree type of the argument tuple for the underlying gen_fn. Defaults to 0, i.e., the first argument. See this link for more detail.

Returns:

  • Callable[[GenerativeFunction[R]], Vmap[R]]: A decorator that converts a genjax.GenerativeFunction into a new genjax.GenerativeFunction that accepts an argument of one-higher dimension at the position specified by in_axes.

Examples:

import jax, genjax
import jax.numpy as jnp


@genjax.vmap(in_axes=0)
@genjax.gen
def vmapped_model(x):
    v = genjax.normal(x, 1.0) @ "v"
    return genjax.normal(v, 0.01) @ "q"


key = jax.random.key(314159)
arr = jnp.ones(100)

# `vmapped_model` accepts an array of numbers:
tr = jax.jit(vmapped_model.simulate)(key, (arr,))

print(tr.render_html())
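
The in_axes argument follows the jax.vmap convention. As a hedged sketch (pair_model and the (0, None) spec below are illustrative, not from the original docs, and assume the vmap method accepts the same in_axes argument as the decorator), a two-argument model can map over its first argument while broadcasting the second:

@genjax.gen
def pair_model(x, sigma):
    return genjax.normal(x, sigma) @ "v"


# Vectorize over `x`; hold `sigma` fixed across the batch.
batched_pair = pair_model.vmap(in_axes=(0, None))

tr = jax.jit(batched_pair.simulate)(key, (jnp.arange(5.0), 1.0))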
Source code in src/genjax/_src/generative_functions/combinators/vmap.py
def vmap(*, in_axes: InAxes = 0) -> Callable[[GenerativeFunction[R]], Vmap[R]]:
    """
    Returns a decorator that wraps a [`GenerativeFunction`][genjax.GenerativeFunction] and returns a new `GenerativeFunction` that performs a vectorized map over the argument specified by `in_axes`. Traced values are nested under an index, and the retval is vectorized.

    Args:
        in_axes: Selector specifying which input arguments (or index into them) should be vectorized. `in_axes` must match (or prefix) the `Pytree` type of the argument tuple for the underlying `gen_fn`. Defaults to 0, i.e., the first argument. See [this link](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees) for more detail.

    Returns:
        A decorator that converts a [`genjax.GenerativeFunction`][] into a new [`genjax.GenerativeFunction`][] that accepts an argument of one-higher dimension at the position specified by `in_axes`.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="vmap"
        import jax, genjax
        import jax.numpy as jnp


        @genjax.vmap(in_axes=0)
        @genjax.gen
        def vmapped_model(x):
            v = genjax.normal(x, 1.0) @ "v"
            return genjax.normal(v, 0.01) @ "q"


        key = jax.random.key(314159)
        arr = jnp.ones(100)

        # `vmapped_model` accepts an array of numbers:
        tr = jax.jit(vmapped_model.simulate)(key, (arr,))

        print(tr.render_html())
        ```
    """

    def decorator(gen_fn: GenerativeFunction[R]) -> Vmap[R]:
        return Vmap(gen_fn, in_axes)

    return decorator

genjax.repeat

repeat(
    *, n: int
) -> Callable[
    [GenerativeFunction[R]], GenerativeFunction[R]
]

Returns a decorator that wraps a genjax.GenerativeFunction gen_fn of type a -> b and returns a new GenerativeFunction of type a -> [b] that samples from gen_fn n times, returning a vector of n results.

The values traced by each call gen_fn will be nested under an integer index that matches the loop iteration index that generated it.

This combinator is useful for creating multiple samples from the same generative model in a batched manner.

Parameters:

  • n (int, required): The number of times to sample from the generative function.

Returns:

  • Callable[[GenerativeFunction[R]], GenerativeFunction[R]]: A new genjax.GenerativeFunction that samples from the original function n times.

Examples:

import genjax, jax


@genjax.repeat(n=10)
@genjax.gen
def normal_draws(mean):
    return genjax.normal(mean, 1.0) @ "x"


key = jax.random.key(314159)

# Generate 10 draws from a normal distribution with mean 2.0
tr = jax.jit(normal_draws.simulate)(key, (2.0,))
print(tr.render_html())
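
The trace's return value is then the vector of draws (a sketch; assumes the standard Trace.get_retval accessor):

samples = tr.get_retval()
print(samples.shape)  # expected: (10,)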
Source code in src/genjax/_src/generative_functions/combinators/repeat.py
def repeat(*, n: int) -> Callable[[GenerativeFunction[R]], GenerativeFunction[R]]:
    """
    Returns a decorator that wraps a [`genjax.GenerativeFunction`][] `gen_fn` of type `a -> b` and returns a new `GenerativeFunction` of type `a -> [b]` that samples from `gen_fn` `n` times, returning a vector of `n` results.

    The values traced by each call `gen_fn` will be nested under an integer index that matches the loop iteration index that generated it.

    This combinator is useful for creating multiple samples from the same generative model in a batched manner.

    Args:
        n: The number of times to sample from the generative function.

    Returns:
        A new [`genjax.GenerativeFunction`][] that samples from the original function `n` times.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="repeat"
        import genjax, jax


        @genjax.repeat(n=10)
        @genjax.gen
        def normal_draws(mean):
            return genjax.normal(mean, 1.0) @ "x"


        key = jax.random.key(314159)

        # Generate 10 draws from a normal distribution with mean 2.0
        tr = jax.jit(normal_draws.simulate)(key, (2.0,))
        print(tr.render_html())
        ```
    """

    def decorator(gen_fn: GenerativeFunction[R]) -> GenerativeFunction[R]:
        return RepeatCombinator(gen_fn, n=n)

    return decorator

scan-like Combinators

genjax.scan

scan(*, n: int | None = None) -> Callable[
    [GenerativeFunction[tuple[Carry, Y]]],
    GenerativeFunction[tuple[Carry, Y]],
]

Returns a decorator that wraps a genjax.GenerativeFunction of type (c, a) -> (c, b) and returns a new genjax.GenerativeFunction of type (c, [a]) -> (c, [b]) where:

  • c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • a may be a primitive, an array type or a pytree (container) type with array leaves
  • b may be a primitive, an array type or a pytree (container) type with array leaves.

The values traced by each call to the original generative function will be nested under an integer index that matches the loop iteration index that generated it.

For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

When the type of xs in the snippet below (denoted [a] above) is an array type or None, and the type of ys in the snippet below (denoted [b] above) is an array type, the semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

Unlike that Python version, both xs and ys may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. None is actually a special case of this, as it represents an empty pytree.

The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters:

  • n (int | None, default: None): optional integer specifying the number of loop iterations, which (if supplied) must agree with the sizes of leading axes of the arrays in the returned function's second argument. If supplied then the returned generative function can take None as its second argument.

Returns:

  • Callable[[GenerativeFunction[tuple[Carry, Y]]], GenerativeFunction[tuple[Carry, Y]]]: A new genjax.GenerativeFunction that takes a loop-carried value and a new input, and returns a new loop-carried value along with either None or an output to be collected into the second return value.

Examples:

Scan for 1000 iterations with no array input:

import jax
import genjax


@genjax.scan(n=1000)
@genjax.gen
def random_walk(prev, _):
    x = genjax.normal(prev, 1.0) @ "x"
    return x, None


init = 0.5
key = jax.random.key(314159)

tr = jax.jit(random_walk.simulate)(key, (init, None))
print(tr.render_html())

Scan across an input array:

import jax.numpy as jnp


@genjax.scan()
@genjax.gen
def add_and_square_all(sum, x):
    new_sum = sum + x
    return new_sum, sum * sum


init = 0.0
xs = jnp.ones(10)

tr = jax.jit(add_and_square_all.simulate)(key, (init, xs))

# The retval has the final carry and an array of all `sum*sum` returned.
print(tr.render_html())
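
The return value of the scanned function is the pair of the final carry and the stacked per-step outputs (a sketch; assumes the standard Trace.get_retval accessor):

final_sum, squares = tr.get_retval()
print(final_sum)      # expected: 10.0
print(squares.shape)  # expected: (10,)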

Source code in src/genjax/_src/generative_functions/combinators/scan.py
def scan(
    *, n: int | None = None
) -> Callable[
    [GenerativeFunction[tuple[Carry, Y]]], GenerativeFunction[tuple[Carry, Y]]
]:
    """Returns a decorator that wraps a [`genjax.GenerativeFunction`][] of type
    `(c, a) -> (c, b)` and returns a new [`genjax.GenerativeFunction`][] of type
    `(c, [a]) -> (c, [b])` where:

    - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `a` may be a primitive, an array type or a pytree (container) type with array leaves
    - `b` may be a primitive, an array type or a pytree (container) type with array leaves.

    The values traced by each call to the original generative function will be nested under an integer index that matches the loop iteration index that generated it.

    For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

    When the type of `xs` in the snippet below (denoted `[a]` above) is an array type or None, and the type of `ys` in the snippet below (denoted `[b]` above) is an array type, the semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

    ```python
    def scan(f, init, xs, length=None):
        if xs is None:
            xs = [None] * length
        carry = init
        ys = []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)
    ```

    Unlike that Python version, both `xs` and `ys` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. `None` is actually a special case of this, as it represents an empty pytree.

    The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Args:
        n: optional integer specifying the number of loop iterations, which (if supplied) must agree with the sizes of leading axes of the arrays in the returned function's second argument. If supplied then the returned generative function can take `None` as its second argument.

    Returns:
        A new [`genjax.GenerativeFunction`][] that takes a loop-carried value and a new input, and returns a new loop-carried value along with either `None` or an output to be collected into the second return value.

    Examples:
        Scan for 1000 iterations with no array input:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax


        @genjax.scan(n=1000)
        @genjax.gen
        def random_walk(prev, _):
            x = genjax.normal(prev, 1.0) @ "x"
            return x, None


        init = 0.5
        key = jax.random.key(314159)

        tr = jax.jit(random_walk.simulate)(key, (init, None))
        print(tr.render_html())
        ```

        Scan across an input array:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax.numpy as jnp


        @genjax.scan()
        @genjax.gen
        def add_and_square_all(sum, x):
            new_sum = sum + x
            return new_sum, sum * sum


        init = 0.0
        xs = jnp.ones(10)

        tr = jax.jit(add_and_square_all.simulate)(key, (init, xs))

        # The retval has the final carry and an array of all `sum*sum` returned.
        print(tr.render_html())
        ```
    """

    def decorator(f: GenerativeFunction[tuple[Carry, Y]]):
        return Scan[Carry, Y](f, length=n)

    return decorator

genjax.accumulate

accumulate() -> Callable[
    [GenerativeFunction[Carry]],
    GenerativeFunction[Carry],
]

Returns a decorator that wraps a genjax.GenerativeFunction of type (c, a) -> c and returns a new genjax.GenerativeFunction of type (c, [a]) -> [c] where:

  • c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • [c] is an array of all loop-carried values seen during iteration (including the first)
  • a may be a primitive, an array type or a pytree (container) type with array leaves

All traced values are nested under an index.

For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation (note the similarity to itertools.accumulate):

def accumulate(f, init, xs):
    carry = init
    carries = [init]
    for x in xs:
        carry = f(carry, x)
        carries.append(carry)
    return carries

Unlike that Python version, both xs and carries may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Examples:

accumulate a running total:

import jax
import genjax
import jax.numpy as jnp


@genjax.accumulate()
@genjax.gen
def add(sum, x):
    new_sum = sum + x
    return new_sum


init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)

tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
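
Because the initial carry is included, the result for 10 inputs has 11 entries (a sketch; assumes the standard Trace.get_retval accessor):

carries = tr.get_retval()
print(carries.shape)  # expected: (11,), holding 0.0, 1.0, ..., 10.0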

Source code in src/genjax/_src/generative_functions/combinators/scan.py
def accumulate() -> Callable[[GenerativeFunction[Carry]], GenerativeFunction[Carry]]:
    """Returns a decorator that wraps a [`genjax.GenerativeFunction`][] of type
    `(c, a) -> c` and returns a new [`genjax.GenerativeFunction`][] of type
    `(c, [a]) -> [c]` where:

    - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `[c]` is an array of all loop-carried values seen during iteration (including the first)
    - `a` may be a primitive, an array type or a pytree (container) type with array leaves

    All traced values are nested under an index.

    For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation (note the similarity to [`itertools.accumulate`](https://docs.python.org/3/library/itertools.html#itertools.accumulate)):

    ```python
    def accumulate(f, init, xs):
        carry = init
        carries = [init]
        for x in xs:
            carry = f(carry, x)
            carries.append(carry)
        return carries
    ```

    Unlike that Python version, both `xs` and `carries` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

    The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Examples:
        accumulate a running total:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax
        import jax.numpy as jnp


        @genjax.accumulate()
        @genjax.gen
        def add(sum, x):
            new_sum = sum + x
            return new_sum


        init = 0.0
        key = jax.random.key(314159)
        xs = jnp.ones(10)

        tr = jax.jit(add.simulate)(key, (init, xs))
        print(tr.render_html())
        ```
    """

    def decorator(f: GenerativeFunction[Carry]) -> GenerativeFunction[Carry]:
        return (
            f.map(lambda ret: (ret, ret))
            .scan()
            .dimap(pre=lambda *args: args, post=prepend_initial_acc)
        )

    return decorator

genjax.reduce

reduce() -> Callable[
    [GenerativeFunction[Carry]],
    GenerativeFunction[Carry],
]

Returns a decorator that wraps a genjax.GenerativeFunction of type (c, a) -> c and returns a new genjax.GenerativeFunction of type (c, [a]) -> c where:

  • c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • a may be a primitive, an array type or a pytree (container) type with array leaves

All traced values are nested under an index.

For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation (note the similarity to functools.reduce):

def reduce(f, init, xs):
    carry = init
    for x in xs:
        carry = f(carry, x)
    return carry

Unlike that Python version, both xs and carry may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Examples:

sum an array of numbers:

import jax
import genjax
import jax.numpy as jnp


@genjax.reduce()
@genjax.gen
def add(sum, x):
    new_sum = sum + x
    return new_sum


init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)

tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
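
Only the final carry is returned (a sketch; assumes the standard Trace.get_retval accessor):

total = tr.get_retval()
print(total)  # expected: 10.0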

Source code in src/genjax/_src/generative_functions/combinators/scan.py
def reduce() -> Callable[[GenerativeFunction[Carry]], GenerativeFunction[Carry]]:
    """Returns a decorator that wraps a [`genjax.GenerativeFunction`][] of type
    `(c, a) -> c` and returns a new [`genjax.GenerativeFunction`][] of type
    `(c, [a]) -> c` where:

    - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `a` may be a primitive, an array type or a pytree (container) type with array leaves

    All traced values are nested under an index.

    For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation (note the similarity to [`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce)):

    ```python
    def reduce(f, init, xs):
        carry = init
        for x in xs:
            carry = f(carry, x)
        return carry
    ```

    Unlike that Python version, both `xs` and `carry` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

    The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Examples:
        sum an array of numbers:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax
        import jax.numpy as jnp


        @genjax.reduce()
        @genjax.gen
        def add(sum, x):
            new_sum = sum + x
            return new_sum


        init = 0.0
        key = jax.random.key(314159)
        xs = jnp.ones(10)

        tr = jax.jit(add.simulate)(key, (init, xs))
        print(tr.render_html())
        ```
    """

    def decorator(f: GenerativeFunction[Carry]) -> GenerativeFunction[Carry]:
        def pre(ret: Carry):
            return ret, None

        def post(ret: tuple[Carry, None]):
            return ret[0]

        return f.map(pre).scan().map(post)

    return decorator

genjax.iterate

iterate(
    *, n: int
) -> Callable[
    [GenerativeFunction[Y]], GenerativeFunction[Y]
]

Returns a decorator that wraps a genjax.GenerativeFunction of type a -> a and returns a new genjax.GenerativeFunction of type a -> [a] where:

  • a is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • [a] is an array of all a, f(a), f(f(a)) etc. values seen during iteration.

All traced values are nested under an index.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:

def iterate(f, n, init):
    input = init
    seen = [init]
    for _ in range(n):
        input = f(input)
        seen.append(input)
    return seen

init may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

The iterated value a must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters:

  • n (int, required): the number of iterations to run.

Examples:

iterative addition, returning all intermediate sums:

import jax
import genjax


@genjax.iterate(n=100)
@genjax.gen
def inc(x):
    return x + 1


init = 0.0
key = jax.random.key(314159)

tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
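
The result includes the initial value, so 100 iterations yield 101 entries (a sketch; assumes the standard Trace.get_retval accessor):

sums = tr.get_retval()
print(sums.shape)  # expected: (101,), holding 0.0, 1.0, ..., 100.0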

Source code in src/genjax/_src/generative_functions/combinators/scan.py
def iterate(*, n: int) -> Callable[[GenerativeFunction[Y]], GenerativeFunction[Y]]:
    """Returns a decorator that wraps a [`genjax.GenerativeFunction`][] of type
    `a -> a` and returns a new [`genjax.GenerativeFunction`][] of type `a ->
    [a]` where:

    - `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `[a]` is an array of all `a`, `f(a)`, `f(f(a))` etc. values seen during iteration.

    All traced values are nested under an index.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

    ```python
    def iterate(f, n, init):
        input = init
        seen = [init]
        for _ in range(n):
            input = f(input)
            seen.append(input)
        return seen
    ```

    `init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

    The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Args:
        n: the number of iterations to run.

    Examples:
        iterative addition, returning all intermediate sums:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax


        @genjax.iterate(n=100)
        @genjax.gen
        def inc(x):
            return x + 1


        init = 0.0
        key = jax.random.key(314159)

        tr = jax.jit(inc.simulate)(key, (init,))
        print(tr.render_html())
        ```
    """

    def decorator(f: GenerativeFunction[Y]) -> GenerativeFunction[Y]:
        # strip off the JAX-supplied `None` on the way in, accumulate `ret` on the way out.
        return (
            f.dimap(
                pre=lambda *args: args[:-1],
                post=lambda _args, _xformed, ret: (ret, ret),
            )
            .scan(n=n)
            .dimap(pre=lambda *args: (*args, None), post=prepend_initial_acc)
        )

    return decorator

genjax.iterate_final

iterate_final(
    *, n: int
) -> Callable[
    [GenerativeFunction[Y]], GenerativeFunction[Y]
]

Returns a decorator that wraps a genjax.GenerativeFunction of type a -> a and returns a new genjax.GenerativeFunction of type a -> a where:

  • a is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • the original function is invoked n times with each input coming from the previous invocation's output, so that the new function returns f^n(a)

All traced values are nested under an index.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:

def iterate_final(f, n, init):
    ret = init
    for _ in range(n):
        ret = f(ret)
    return ret

init may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

The iterated value a must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters:

  • n (int, required): the number of iterations to run.

Examples:

iterative addition:

import jax
import genjax


@genjax.iterate_final(n=100)
@genjax.gen
def inc(x):
    return x + 1


init = 0.0
key = jax.random.key(314159)

tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
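
Only the final iterate is returned (a sketch; assumes the standard Trace.get_retval accessor):

print(tr.get_retval())  # expected: 100.0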

Source code in src/genjax/_src/generative_functions/combinators/scan.py
def iterate_final(
    *, n: int
) -> Callable[[GenerativeFunction[Y]], GenerativeFunction[Y]]:
    """Returns a decorator that wraps a [`genjax.GenerativeFunction`][] of type
    `a -> a` and returns a new [`genjax.GenerativeFunction`][] of type `a -> a`
    where:

    - `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - the original function is invoked `n` times with each input coming from the previous invocation's output, so that the new function returns $f^n(a)$

    All traced values are nested under an index.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

    ```python
    def iterate_final(f, n, init):
        ret = init
        for _ in range(n):
            ret = f(ret)
        return ret
    ```

    `init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

    The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Args:
        n: the number of iterations to run.

    Examples:
        iterative addition:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax


        @genjax.iterate_final(n=100)
        @genjax.gen
        def inc(x):
            return x + 1


        init = 0.0
        key = jax.random.key(314159)

        tr = jax.jit(inc.simulate)(key, (init,))
        print(tr.render_html())
        ```
    """

    def decorator(f: GenerativeFunction[Y]) -> GenerativeFunction[Y]:
        # strip off the JAX-supplied `None` on the way in, no accumulation on the way out.
        def pre_post(_, _xformed, ret: Y):
            return ret, None

        def post_post(_, _xformed, ret: tuple[Y, None]):
            return ret[0]

        return (
            f.dimap(pre=lambda *args: args[:-1], post=pre_post)
            .scan(n=n)
            .dimap(pre=lambda *args: (*args, None), post=post_post)
        )

    return decorator

genjax.masked_iterate

masked_iterate() -> (
    Callable[
        [GenerativeFunction[Y]], GenerativeFunction[Y]
    ]
)

Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a list of values of type a.

The original function is modified to accept an additional boolean argument mask indicating whether each step should be masked. The new function returns the list of results of the original operation, using the input [mask] as the mask.

All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from [mask] with a flag of True will be accounted for in inference, notably for score computations.

Example
import jax
import jax.numpy as jnp
import genjax

masks = jnp.array([True, False, True])


# Create a kernel generative function
@genjax.gen
def step(x):
    _ = (
        genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
        @ "rats"
    )
    return x


# Create a model using masked_iterate
model = genjax.masked_iterate()(step)

# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
def masked_iterate() -> Callable[[GenerativeFunction[Y]], GenerativeFunction[Y]]:
    """
    Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a list of values of type `a`.

    The original function is modified to accept an additional boolean argument `mask` indicating whether each step should be masked. The new function returns the list of results of the original operation, using the input `[mask]` as the mask.

    All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from `[mask]` with a flag of `True` will be accounted for in inference, notably for score computations.

    Example:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import jax.numpy as jnp
        import genjax

        masks = jnp.array([True, False, True])


        # Create a kernel generative function
        @genjax.gen
        def step(x):
            _ = (
                genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
                @ "rats"
            )
            return x


        # Create a model using masked_iterate
        model = genjax.masked_iterate()(step)

        # Simulate from the model
        key = jax.random.key(0)
        mask_steps = jnp.arange(10) < 5
        tr = model.simulate(key, (0.0, mask_steps))
        print(tr.render_html())
        ```
    """

    def decorator(step: GenerativeFunction[Y]) -> GenerativeFunction[Y]:
        def pre(state, flag: Flag):
            return flag, state

        def post(_unused_args, _xformed, masked_retval: Mask[Y]):
            v = masked_retval.value
            return v, v

        # scan_step: (a, bool) -> a
        scan_step = step.mask().dimap(pre=pre, post=post)
        return scan_step.scan().dimap(pre=lambda *args: args, post=prepend_initial_acc)

    return decorator

genjax.masked_iterate_final

masked_iterate_final() -> (
    Callable[
        [GenerativeFunction[Y]], GenerativeFunction[Y]
    ]
)

Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a value of type a.

The original function is modified to accept an additional argument mask, which is a boolean value indicating whether the operation should be masked or not. The function returns the result of the original operation if mask is True, and the original input if mask is False.

All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from [mask] with a flag of True will be accounted for in inference, notably for score computations.

Example
import jax
import jax.numpy as jnp
import genjax

masks = jnp.array([True, False, True])


# Create a kernel generative function
@genjax.gen
def step(x):
    _ = (
        genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
        @ "rats"
    )
    return x


# Create a model using masked_iterate_final
model = genjax.masked_iterate_final()(step)

# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
Source code in src/genjax/_src/generative_functions/combinators/scan.py
def masked_iterate_final() -> Callable[[GenerativeFunction[Y]], GenerativeFunction[Y]]:
    """
    Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a value of type `a`.

    The original function is modified to accept an additional argument `mask`, which is a boolean value indicating whether the operation should be masked or not. The function returns the result of the original operation if `mask` is `True`, and the original input if `mask` is `False`.

    All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from `[mask]` with a flag of `True` will be accounted for in inference, notably for score computations.

    Example:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import jax.numpy as jnp
        import genjax

        masks = jnp.array([True, False, True])


        # Create a kernel generative function
        @genjax.gen
        def step(x):
            _ = (
                genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
                @ "rats"
            )
            return x


        # Create a model using masked_iterate_final
        model = genjax.masked_iterate_final()(step)

        # Simulate from the model
        key = jax.random.key(0)
        mask_steps = jnp.arange(10) < 5
        tr = model.simulate(key, (0.0, mask_steps))
        print(tr.render_html())
        ```
    """

    def decorator(step: GenerativeFunction[Y]) -> GenerativeFunction[Y]:
        def pre(state, flag: Flag):
            return flag, state

        def post(_unused_args, _xformed, masked_retval: Mask[Y]):
            return masked_retval.value, None

        # scan_step: (a, bool) -> a
        scan_step = step.mask().dimap(pre=pre, post=post)
        return scan_step.scan().map(lambda ret: ret[0])

    return decorator

Control Flow Combinators

genjax.or_else

or_else(
    if_gen_fn: GenerativeFunction[R],
    else_gen_fn: GenerativeFunction[R],
) -> GenerativeFunction[R]

Given two genjax.GenerativeFunctions if_gen_fn and else_gen_fn, returns a new genjax.GenerativeFunction that accepts

  • a boolean argument
  • an argument tuple for if_gen_fn
  • an argument tuple for the supplied else_gen_fn

and acts like if_gen_fn when the boolean is True or else_gen_fn otherwise.

Parameters:

  • if_gen_fn (GenerativeFunction[R], required): called when the boolean argument is True.
  • else_gen_fn (GenerativeFunction[R], required): called when the boolean argument is False.

Returns:

  • GenerativeFunction[R]: A genjax.GenerativeFunction modified for conditional execution.

Examples:

import jax
import jax.numpy as jnp
import genjax


@genjax.gen
def if_model(x):
    return genjax.normal(x, 1.0) @ "if_value"


@genjax.gen
def else_model(x):
    return genjax.normal(x, 5.0) @ "else_value"


or_else_model = genjax.or_else(if_model, else_model)


@genjax.gen
def model(toss: bool):
    # Note that `or_else_model` takes a new boolean predicate in
    # addition to argument tuples for each branch.
    return or_else_model(toss, (1.0,), (10.0,)) @ "tossed"


key = jax.random.key(314159)

tr = jax.jit(model.simulate)(key, (True,))

print(tr.render_html())
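
Passing False instead routes execution through else_model (a minimal sketch using the same model and key):

tr_else = jax.jit(model.simulate)(key, (False,))
print(tr_else.render_html())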
Source code in src/genjax/_src/generative_functions/combinators/or_else.py
def or_else(
    if_gen_fn: GenerativeFunction[R],
    else_gen_fn: GenerativeFunction[R],
) -> GenerativeFunction[R]:
    """
    Given two [`genjax.GenerativeFunction`][]s `if_gen_fn` and `else_gen_fn`, returns a new [`genjax.GenerativeFunction`][] that accepts

    - a boolean argument
    - an argument tuple for `if_gen_fn`
    - an argument tuple for the supplied `else_gen_fn`

    and acts like `if_gen_fn` when the boolean is `True` or `else_gen_fn` otherwise.

    Args:
        else_gen_fn: called when the boolean argument is `False`.

    Returns:
        A [`genjax.GenerativeFunction`][] modified for conditional execution.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="or_else"
        import jax
        import jax.numpy as jnp
        import genjax


        @genjax.gen
        def if_model(x):
            return genjax.normal(x, 1.0) @ "if_value"


        @genjax.gen
        def else_model(x):
            return genjax.normal(x, 5.0) @ "else_value"


        or_else_model = genjax.or_else(if_model, else_model)


        @genjax.gen
        def model(toss: bool):
            # Note that `or_else_model` takes a new boolean predicate in
            # addition to argument tuples for each branch.
            return or_else_model(toss, (1.0,), (10.0,)) @ "tossed"


        key = jax.random.key(314159)

        tr = jax.jit(model.simulate)(key, (True,))

        print(tr.render_html())
        ```
    """

    def argument_mapping(
        b: ScalarFlag, if_args: tuple[Any, ...], else_args: tuple[Any, ...]
    ):
        # Note that `True` maps to 0 to select the "if" branch, `False` to 1.
        idx = jnp.array(jnp.logical_not(b), dtype=int)
        return (idx, if_args, else_args)

    return if_gen_fn.switch(else_gen_fn).contramap(argument_mapping)

genjax.switch

switch(*gen_fns: GenerativeFunction[R]) -> Switch[R]

Given n genjax.GenerativeFunction inputs, returns a genjax.GenerativeFunction that accepts n+1 arguments:

  • an index in the range [0, n)
  • a tuple of arguments for each of the input generative functions (n total tuples)

and executes the generative function at the supplied index with its provided arguments.

If index is out of bounds, index is clamped to within bounds.

Parameters:

  • gen_fns (GenerativeFunction[R], variadic): generative functions that the Switch will select from.

Examples:

Create a Switch via the genjax.switch method:

import jax, genjax


@genjax.gen
def branch_1():
    x = genjax.normal(0.0, 1.0) @ "x1"


@genjax.gen
def branch_2():
    x = genjax.bernoulli(probs=0.3) @ "x2"


switch = genjax.switch(branch_1, branch_2)

key = jax.random.key(314159)
jitted = jax.jit(switch.simulate)

# Select `branch_2` by providing 1:
tr = jitted(key, (1, (), ()))

print(tr.render_html())
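
The index may also be a value computed at runtime; a minimal sketch (assumes jax.numpy is imported):

import jax.numpy as jnp

idx = jnp.int32(0)  # dynamically select `branch_1`
tr = jitted(key, (idx, (), ()))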

Source code in src/genjax/_src/generative_functions/combinators/switch.py
def switch(
    *gen_fns: GenerativeFunction[R],
) -> Switch[R]:
    """
    Given `n` [`genjax.GenerativeFunction`][] inputs, returns a [`genjax.GenerativeFunction`][] that accepts `n+1` arguments:

    - an index in the range $[0, n)$
    - a tuple of arguments for each of the input generative functions (`n` total tuples)

    and executes the generative function at the supplied index with its provided arguments.

    If `index` is out of bounds, `index` is clamped to within bounds.

    Args:
        gen_fns: generative functions that the `Switch` will select from.

    Examples:
        Create a `Switch` via the [`genjax.switch`][] method:
        ```python exec="yes" html="true" source="material-block" session="switch"
        import jax, genjax


        @genjax.gen
        def branch_1():
            x = genjax.normal(0.0, 1.0) @ "x1"


        @genjax.gen
        def branch_2():
            x = genjax.bernoulli(probs=0.3) @ "x2"


        switch = genjax.switch(branch_1, branch_2)

        key = jax.random.key(314159)
        jitted = jax.jit(switch.simulate)

        # Select `branch_2` by providing 1:
        tr = jitted(key, (1, (), ()))

        print(tr.render_html())
        ```
    """
    return Switch[R](gen_fns)

Argument and Return Transformations

genjax.dimap

dimap(
    *,
    pre: Callable[..., ArgTuple] = lambda *args: args,
    post: Callable[
        [tuple[Any, ...], ArgTuple, R], S
    ] = lambda _, _xformed, retval: retval
) -> Callable[
    [GenerativeFunction[R]], Dimap[ArgTuple, R, S]
]

Returns a decorator that wraps a genjax.GenerativeFunction and applies pre- and post-processing functions to its arguments and return value.

Info

Prefer genjax.map if you only need to transform the return value, or genjax.contramap if you need to transform the arguments.

Parameters:

  • pre (Callable[..., ArgTuple], default: lambda *args: args): A callable that preprocesses the arguments before passing them to the wrapped function. Note that pre must return a tuple of arguments, not a bare argument. Default is the identity function.
  • post (Callable[[tuple[Any, ...], ArgTuple, R], S], default: lambda _, _xformed, retval: retval): A callable that postprocesses the return value of the wrapped function. Default is the identity function.

Returns:

  • Callable[[GenerativeFunction[R]], Dimap[ArgTuple, R, S]]: A decorator that takes a genjax.GenerativeFunction and returns a new genjax.GenerativeFunction with the same behavior but with the arguments and return value transformed according to pre and post.

Examples:

import jax, genjax


# Define pre- and post-processing functions
def pre_process(x, y):
    return (x + 1, y * 2)


def post_process(args, xformed, retval):
    return retval**2


# Apply dimap to a generative function
@genjax.dimap(pre=pre_process, post=post_process)
@genjax.gen
def dimap_model(x, y):
    return genjax.normal(x, y) @ "z"


# Use the dimap model
key = jax.random.key(0)
trace = dimap_model.simulate(key, (2.0, 3.0))

print(trace.render_html())
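
Here pre_process turns the arguments (2.0, 3.0) into (3.0, 6.0) before the draw at "z", and post_process squares the return value; a sketch of inspecting the result (assumes the standard Trace.get_retval accessor):

# The return value is the square of a draw from normal(3.0, 6.0).
print(trace.get_retval())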
Source code in src/genjax/_src/generative_functions/combinators/dimap.py
def dimap(
    *,
    pre: Callable[..., ArgTuple] = lambda *args: args,
    post: Callable[[tuple[Any, ...], ArgTuple, R], S] = lambda _,
    _xformed,
    retval: retval,
) -> Callable[[GenerativeFunction[R]], Dimap[ArgTuple, R, S]]:
    """
    Returns a decorator that wraps a [`genjax.GenerativeFunction`][] and applies pre- and post-processing functions to its arguments and return value.

    !!! info
        Prefer [`genjax.map`][] if you only need to transform the return value, or [`genjax.contramap`][] if you need to transform the arguments.

    Args:
        pre: A callable that preprocesses the arguments before passing them to the wrapped function. Note that `pre` must return a _tuple_ of arguments, not a bare argument. Default is the identity function.
        post: A callable that postprocesses the return value of the wrapped function. Default is the identity function.

    Returns:
        A decorator that takes a [`genjax.GenerativeFunction`][] and returns a new [`genjax.GenerativeFunction`][] with the same behavior but with the arguments and return value transformed according to `pre` and `post`.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="dimap"
        import jax, genjax


        # Define pre- and post-processing functions
        def pre_process(x, y):
            return (x + 1, y * 2)


        def post_process(args, xformed, retval):
            return retval**2


        # Apply dimap to a generative function
        @genjax.dimap(pre=pre_process, post=post_process)
        @genjax.gen
        def dimap_model(x, y):
            return genjax.normal(x, y) @ "z"


        # Use the dimap model
        key = jax.random.key(0)
        trace = dimap_model.simulate(key, (2.0, 3.0))

        print(trace.render_html())
        ```
    """

    def decorator(f: GenerativeFunction[R]) -> Dimap[ArgTuple, R, S]:
        return Dimap(f, pre, post)

    return decorator

genjax.map

map(
    f: Callable[[R], S]
) -> Callable[
    [GenerativeFunction[R]], Dimap[tuple[Any, ...], R, S]
]

Returns a decorator that wraps a genjax.GenerativeFunction and applies a post-processing function to its return value.

This is a specialized version of genjax.dimap where only the post-processing function is applied.

Parameters:

  • f (Callable[[R], S], required): A callable that postprocesses the return value of the wrapped function.

Returns:

  • Callable[[GenerativeFunction[R]], Dimap[tuple[Any, ...], R, S]]: A decorator that takes a genjax.GenerativeFunction and returns a new genjax.GenerativeFunction with the same behavior but with the return value transformed according to f.

Examples:

import jax, genjax


# Define a post-processing function
def square(x):
    return x**2


# Apply map to a generative function
@genjax.map(square)
@genjax.gen
def map_model(x):
    return genjax.normal(x, 1.0) @ "z"


# Use the map model
key = jax.random.key(0)
trace = map_model.simulate(key, (2.0,))

print(trace.render_html())
Source code in src/genjax/_src/generative_functions/combinators/dimap.py
def map(
    f: Callable[[R], S],
) -> Callable[[GenerativeFunction[R]], Dimap[tuple[Any, ...], R, S]]:
    """
    Returns a decorator that wraps a [`genjax.GenerativeFunction`][] and applies a post-processing function to its return value.

    This is a specialized version of [`genjax.dimap`][] where only the post-processing function is applied.

    Args:
        f: A callable that postprocesses the return value of the wrapped function.

    Returns:
        A decorator that takes a [`genjax.GenerativeFunction`][] and returns a new [`genjax.GenerativeFunction`][] with the same behavior but with the return value transformed according to `f`.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="map"
        import jax, genjax


        # Define a post-processing function
        def square(x):
            return x**2


        # Apply map to a generative function
        @genjax.map(square)
        @genjax.gen
        def map_model(x):
            return genjax.normal(x, 1.0) @ "z"


        # Use the map model
        key = jax.random.key(0)
        trace = map_model.simulate(key, (2.0,))

        print(trace.render_html())
        ```
    """

    def post(_args, _xformed, x: R) -> S:
        return f(x)

    return dimap(pre=lambda *args: args, post=post)

genjax.contramap

contramap(
    f: Callable[..., ArgTuple]
) -> Callable[
    [GenerativeFunction[R]], Dimap[ArgTuple, R, R]
]

Returns a decorator that wraps a genjax.GenerativeFunction and applies a pre-processing function to its arguments.

This is a specialized version of genjax.dimap where only the pre-processing function is applied.

Parameters:

  • f (Callable[..., ArgTuple], required): A callable that preprocesses the arguments of the wrapped function. Note that f must return a tuple of arguments, not a bare argument.

Returns:

  • Callable[[GenerativeFunction[R]], Dimap[ArgTuple, R, R]]: A decorator that takes a genjax.GenerativeFunction and returns a new genjax.GenerativeFunction with the same behavior but with the arguments transformed according to f.

Examples:

import jax, genjax


# Define a pre-processing function.
# Note that this function must return a tuple of arguments!
def add_one(x):
    return (x + 1,)


# Apply contramap to a generative function
@genjax.contramap(add_one)
@genjax.gen
def contramap_model(x):
    return genjax.normal(x, 1.0) @ "z"


# Use the contramap model
key = jax.random.key(0)
trace = contramap_model.simulate(key, (2.0,))

print(trace.render_html())
Source code in src/genjax/_src/generative_functions/combinators/dimap.py
def contramap(
    f: Callable[..., ArgTuple],
) -> Callable[[GenerativeFunction[R]], Dimap[ArgTuple, R, R]]:
    """
    Returns a decorator that wraps a [`genjax.GenerativeFunction`][] and applies a pre-processing function to its arguments.

    This is a specialized version of [`genjax.dimap`][] where only the pre-processing function is applied.

    Args:
        f: A callable that preprocesses the arguments of the wrapped function. Note that `f` must return a _tuple_ of arguments, not a bare argument.

    Returns:
        A decorator that takes a [`genjax.GenerativeFunction`][] and returns a new [`genjax.GenerativeFunction`][] with the same behavior but with the arguments transformed according to `f`.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="contramap"
        import jax, genjax


        # Define a pre-processing function.
        # Note that this function must return a tuple of arguments!
        def add_one(x):
            return (x + 1,)


        # Apply contramap to a generative function
        @genjax.contramap(add_one)
        @genjax.gen
        def contramap_model(x):
            return genjax.normal(x, 1.0) @ "z"


        # Use the contramap model
        key = jax.random.key(0)
        trace = contramap_model.simulate(key, (2.0,))

        print(trace.render_html())
        ```
    """
    return dimap(pre=f, post=lambda _args, _xformed, ret: ret)

The Rest

genjax.mask

mask(f: GenerativeFunction[R]) -> MaskCombinator[R]

Combinator which enables dynamic masking of generative functions. Takes a genjax.GenerativeFunction and returns a new genjax.GenerativeFunction which accepts an additional boolean first argument.

If True, the invocation has the same semantics as invoking the generative function without masking. If False, the invocation is masked, and its contribution to the score is ignored.

The return value type is a Mask, with a flag value equal to the supplied boolean.

Parameters:

  • f (GenerativeFunction[R], required): The generative function to be masked.

Returns:

  • MaskCombinator[R]: The masked version of the input generative function.

Examples:

Masking a normal draw:

import genjax, jax


@genjax.mask
@genjax.gen
def masked_normal_draw(mean):
    return genjax.normal(mean, 1.0) @ "x"


key = jax.random.key(314159)
tr = jax.jit(masked_normal_draw.simulate)(
    key,
    (
        False,
        2.0,
    ),
)
print(tr.render_html())
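
The trace's return value is a Mask; a sketch of reading it (assumes Mask exposes flag and value fields, as the combinator source below suggests):

masked = tr.get_retval()
print(masked.flag)   # False: this draw does not contribute to the score
print(masked.value)  # the underlying draw is still materialized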

Source code in src/genjax/_src/generative_functions/combinators/mask.py
def mask(f: GenerativeFunction[R]) -> MaskCombinator[R]:
    """
    Combinator which enables dynamic masking of generative functions. Takes a [`genjax.GenerativeFunction`][] and returns a new [`genjax.GenerativeFunction`][] which accepts an additional boolean first argument.

    If `True`, the invocation has the same semantics as invoking the generative function without masking. If `False`, the invocation is masked, and its contribution to the score is ignored.

    The return value type is a `Mask`, with a flag value equal to the supplied boolean.

    Args:
        f: The generative function to be masked.

    Returns:
        The masked version of the input generative function.

    Examples:
        Masking a normal draw:
        ```python exec="yes" html="true" source="material-block" session="mask"
        import genjax, jax


        @genjax.mask
        @genjax.gen
        def masked_normal_draw(mean):
            return genjax.normal(mean, 1.0) @ "x"


        key = jax.random.key(314159)
        tr = jax.jit(masked_normal_draw.simulate)(
            key,
            (
                False,
                2.0,
            ),
        )
        print(tr.render_html())
        ```
    """
    return MaskCombinator(f)

genjax.mix

mix(*gen_fns: GenerativeFunction[R]) -> GenerativeFunction[R]

Creates a mixture model from a set of generative functions.

This function takes multiple generative functions as input and returns a new generative function that represents a mixture model.

The returned generative function takes the following arguments:

  • mixture_logits: Logits for the categorical distribution used to select a component.
  • *args: Argument tuples for each of the input generative functions

and samples from one of the input generative functions based on a draw from a categorical distribution defined by the provided mixture logits.

Parameters:

  • *gen_fns (GenerativeFunction[R], variadic): Variable number of genjax.GenerativeFunctions to be mixed.

Returns:

  • GenerativeFunction[R]: A new genjax.GenerativeFunction representing the mixture model.

Examples:

import jax
import genjax


# Define component generative functions
@genjax.gen
def component1(x):
    return genjax.normal(x, 1.0) @ "y"


@genjax.gen
def component2(x):
    return genjax.normal(x, 2.0) @ "y"


# Create mixture model
mixture = genjax.mix(component1, component2)

# Use the mixture model
key = jax.random.key(0)
logits = jax.numpy.array([0.3, 0.7])  # Favors component2
trace = mixture.simulate(key, (logits, (0.0,), (7.0,)))
print(trace.render_html())
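
The selected component index is traced at "mixture_component" and the chosen component's own draws live under "component_sample" (a sketch; assumes the trace's choice map supports bracket indexing by address):

chm = trace.get_choices()
print(chm["mixture_component"])  # index of the sampled component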
Source code in src/genjax/_src/generative_functions/combinators/mixture.py
def mix(*gen_fns: GenerativeFunction[R]) -> GenerativeFunction[R]:
    """
    Creates a mixture model from a set of generative functions.

    This function takes multiple generative functions as input and returns a new generative function that represents a mixture model.

    The returned generative function takes the following arguments:

    - `mixture_logits`: Logits for the categorical distribution used to select a component.
    - `*args`: Argument tuples for each of the input generative functions

    and samples from one of the input generative functions based on a draw from a categorical distribution defined by the provided mixture logits.

    Args:
        *gen_fns: Variable number of [`genjax.GenerativeFunction`][]s to be mixed.

    Returns:
        A new [`genjax.GenerativeFunction`][] representing the mixture model.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="mix"
        import jax
        import genjax


        # Define component generative functions
        @genjax.gen
        def component1(x):
            return genjax.normal(x, 1.0) @ "y"


        @genjax.gen
        def component2(x):
            return genjax.normal(x, 2.0) @ "y"


        # Create mixture model
        mixture = genjax.mix(component1, component2)

        # Use the mixture model
        key = jax.random.key(0)
        logits = jax.numpy.array([0.3, 0.7])  # Favors component2
        trace = mixture.simulate(key, (logits, (0.0,), (7.0,)))
        print(trace.render_html())
        ```
    """

    inner_combinator_closure = switch(*gen_fns)

    def mixture_model(mixture_logits, *args) -> R:
        mix_idx = categorical(logits=mixture_logits) @ "mixture_component"
        v = inner_combinator_closure(mix_idx, *args) @ "component_sample"
        return v

    return gen(mixture_model)