Dimap combinator
What is this magic?
In [1]:
Copied!
import sys

# Install genjax (with the "genstudio" extra) only when this notebook is
# running inside Google Colab; other environments are expected to have it.
if "google.colab" in sys.modules:
    # IPython `%pip` line magic (notebook syntax, not plain Python).
    %pip install --quiet "genjax[genstudio]"
import sys

# Install genjax (with the "genstudio" extra) only when this notebook is
# running inside Google Colab; other environments are expected to have it.
if "google.colab" in sys.modules:
    # IPython `%pip` line magic (notebook syntax, not plain Python).
    %pip install --quiet "genjax[genstudio]"
Warning: this combinator is only meant to be used by library authors. It is used to implement other combinators such as `or_else` and `repeat`.
In [2]:
Copied!
import jax
import jax.numpy as jnp
from genjax import gen, normal, pretty
from genjax._src.core.generative import GenerativeFunction
from genjax._src.core.typing import Callable, ScalarFlag
# Deterministic root PRNG key (seed 0) so every run of the notebook matches.
key = jax.random.key(seed=0)
# genjax display helper; presumably switches on rich rendering of outputs.
pretty()
import jax
import jax.numpy as jnp
from genjax import gen, normal, pretty
from genjax._src.core.generative import GenerativeFunction
from genjax._src.core.typing import Callable, ScalarFlag
# Deterministic root PRNG key (seed 0) so every run of the notebook matches.
key = jax.random.key(seed=0)
# genjax display helper; presumably switches on rich rendering of outputs.
pretty()
Here's an example of rewriting the `OrElseCombinator` combinator using `contramap` and `switch`.
In [3]:
Copied!
def NewOrElseCombinator(
    if_gen_fn: GenerativeFunction,
    else_gen_fn: GenerativeFunction,
) -> GenerativeFunction:
    """Rebuild the or-else combinator from ``switch`` and ``contramap``.

    The returned generative function is called as ``(b, if_args, else_args)``.
    The flag ``b`` is converted to a branch index (``True`` -> 0,
    ``False`` -> 1), and the call is forwarded to
    ``if_gen_fn.switch(else_gen_fn)`` as ``(index, if_args, else_args)``.
    Assuming ``switch`` numbers its branches in the order given and routes the
    i-th argument tuple to the i-th branch, a true flag runs ``if_gen_fn`` on
    ``if_args`` and a false flag runs ``else_gen_fn`` on ``else_args``.

    Args:
        if_gen_fn: generative function selected when ``b`` is true.
        else_gen_fn: generative function selected when ``b`` is false.

    Returns:
        A generative function taking ``(b, if_args, else_args)``.
    """
    # Two-branch switch: `if_gen_fn` first, `else_gen_fn` second.
    branches = if_gen_fn.switch(else_gen_fn)

    def to_switch_args(b: ScalarFlag, if_args: tuple, else_args: tuple):
        # Negate the flag so that True selects index 0 (the "if" branch).
        negated = jnp.logical_not(b)
        branch_index = jnp.array(negated, dtype=int)
        return branch_index, if_args, else_args

    # `contramap` adapts the caller-facing arguments to what `switch` expects.
    return branches.contramap(to_switch_args)
def NewOrElseCombinator(
    if_gen_fn: GenerativeFunction,
    else_gen_fn: GenerativeFunction,
) -> GenerativeFunction:
    """Rebuild the or-else combinator from ``switch`` and ``contramap``.

    The returned generative function is called as ``(b, if_args, else_args)``.
    The flag ``b`` is converted to a branch index (``True`` -> 0,
    ``False`` -> 1), and the call is forwarded to
    ``if_gen_fn.switch(else_gen_fn)`` as ``(index, if_args, else_args)``.
    Assuming ``switch`` numbers its branches in the order given and routes the
    i-th argument tuple to the i-th branch, a true flag runs ``if_gen_fn`` on
    ``if_args`` and a false flag runs ``else_gen_fn`` on ``else_args``.

    Args:
        if_gen_fn: generative function selected when ``b`` is true.
        else_gen_fn: generative function selected when ``b`` is false.

    Returns:
        A generative function taking ``(b, if_args, else_args)``.
    """
    # Two-branch switch: `if_gen_fn` first, `else_gen_fn` second.
    branches = if_gen_fn.switch(else_gen_fn)

    def to_switch_args(b: ScalarFlag, if_args: tuple, else_args: tuple):
        # Negate the flag so that True selects index 0 (the "if" branch).
        negated = jnp.logical_not(b)
        branch_index = jnp.array(negated, dtype=int)
        return branch_index, if_args, else_args

    # `contramap` adapts the caller-facing arguments to what `switch` expects.
    return branches.contramap(to_switch_args)
To add a version accessible as a decorator:
In [4]:
Copied!
def new_or_else(
    else_gen_fn: GenerativeFunction,
) -> Callable[[GenerativeFunction], GenerativeFunction]:
    """Decorator form of ``NewOrElseCombinator``.

    ``new_or_else(else_gen_fn)(if_gen_fn)`` returns
    ``NewOrElseCombinator(if_gen_fn, else_gen_fn)``. This allows
    ``@new_or_else(else_gen_fn)`` to be stacked on top of a ``@gen`` function.
    """

    def wrap(if_gen_fn) -> GenerativeFunction:
        # `else_gen_fn` is captured from the enclosing call.
        return NewOrElseCombinator(if_gen_fn, else_gen_fn)

    return wrap
def new_or_else(
    else_gen_fn: GenerativeFunction,
) -> Callable[[GenerativeFunction], GenerativeFunction]:
    """Decorator form of ``NewOrElseCombinator``.

    ``new_or_else(else_gen_fn)(if_gen_fn)`` returns
    ``NewOrElseCombinator(if_gen_fn, else_gen_fn)``. This allows
    ``@new_or_else(else_gen_fn)`` to be stacked on top of a ``@gen`` function.
    """

    def wrap(if_gen_fn) -> GenerativeFunction:
        # `else_gen_fn` is captured from the enclosing call.
        return NewOrElseCombinator(if_gen_fn, else_gen_fn)

    return wrap
To add a version accessible using postfix syntax, one would need to add the following method to the `GenerativeFunction` dataclass in `core.py`.
In [5]:
Copied!
def postfix_new_or_else(self, gen_fn: "GenerativeFunction", /) -> "GenerativeFunction":
    """Postfix form: ``self`` is the if-branch, ``gen_fn`` the else-branch.

    Intended to be added as a method on ``GenerativeFunction``. Delegates to
    the ``new_or_else`` decorator.
    """
    attach_else = new_or_else(gen_fn)
    return attach_else(self)
def postfix_new_or_else(self, gen_fn: "GenerativeFunction", /) -> "GenerativeFunction":
    """Postfix form: ``self`` is the if-branch, ``gen_fn`` the else-branch.

    Intended to be added as a method on ``GenerativeFunction``. Delegates to
    the ``new_or_else`` decorator.
    """
    attach_else = new_or_else(gen_fn)
    return attach_else(self)
Testing the rewritten version on an example
In [6]:
Copied!
# Two branches that differ only in the second argument to `normal`
# (1.0 vs 5.0, presumably the standard deviation) and in their choice address.
@gen
def if_model(x):
    return normal(x, 1.0) @ "if_value"


@gen
def else_model(x):
    return normal(x, 5.0) @ "else_value"


# `toss` selects the branch. The if-branch is given (1.0,) and the
# else-branch (10.0,); the chosen branch's choices land under "tossed".
@gen
def model(toss: bool):
    return NewOrElseCombinator(if_model, else_model)(toss, (1.0,), (10.0,)) @ "tossed"


# Keep `subkey` around: a later cell reuses it to reproduce this exact trace.
key, subkey = jax.random.split(key)
tr = jax.jit(model.simulate)(subkey, (True,))
# Last expression of the cell, shown as the cell output.
tr.get_choices()
# Two branches that differ only in the second argument to `normal`
# (1.0 vs 5.0, presumably the standard deviation) and in their choice address.
@gen
def if_model(x):
    return normal(x, 1.0) @ "if_value"


@gen
def else_model(x):
    return normal(x, 5.0) @ "else_value"


# `toss` selects the branch. The if-branch is given (1.0,) and the
# else-branch (10.0,); the chosen branch's choices land under "tossed".
@gen
def model(toss: bool):
    return NewOrElseCombinator(if_model, else_model)(toss, (1.0,), (10.0,)) @ "tossed"


# Keep `subkey` around: a later cell reuses it to reproduce this exact trace.
key, subkey = jax.random.split(key)
tr = jax.jit(model.simulate)(subkey, (True,))
# Last expression of the cell, shown as the cell output.
tr.get_choices()
Out[6]:
Checking that the two versions are equivalent on an example
In [7]:
Copied!
# Decorator version of the same model. `@gen` is applied first, and then
# `new_or_else(else_model)` pairs the result (as the if-branch) with `else_model`.
@new_or_else(else_model)
@gen
def or_else_model(x):
    return normal(x, 1.0) @ "if_value"


@gen
def model_v2(toss: bool):
    return or_else_model(toss, (1.0,), (10.0,)) @ "tossed"


# reusing subkey to get the same result
tr2 = jax.jit(model_v2.simulate)(subkey, (True,))
# NOTE(review): confirm that `==` on choice maps evaluates to a single boolean
# rather than a per-leaf (array) comparison.
tr.get_choices() == tr2.get_choices()
# Decorator version of the same model. `@gen` is applied first, and then
# `new_or_else(else_model)` pairs the result (as the if-branch) with `else_model`.
@new_or_else(else_model)
@gen
def or_else_model(x):
    return normal(x, 1.0) @ "if_value"


@gen
def model_v2(toss: bool):
    return or_else_model(toss, (1.0,), (10.0,)) @ "tossed"


# reusing subkey to get the same result
tr2 = jax.jit(model_v2.simulate)(subkey, (True,))
# NOTE(review): confirm that `==` on choice maps evaluates to a single boolean
# rather than a per-leaf (array) comparison.
tr.get_choices() == tr2.get_choices()
Out[7]: