Skip to content

The menagerie of GenerativeFunction

Generative functions are probabilistic building blocks. They allow you to express complex probability distributions, and automate several operations on them. GenJAX exports a standard library of generative functions, and this page catalogues them and their usage.

The venerable & reliable Distribution

To start, distributions are generative functions.

genjax.Distribution

Bases: Generic[R], GenerativeFunction[R]

Methods:

Name Description
random_weighted
estimate_logpdf
Source code in src/genjax/_src/generative_functions/distributions/distribution.py
class Distribution(Generic[R], GenerativeFunction[R]):
    @abstractmethod
    def random_weighted(
        self,
        key: PRNGKey,
        *args,
    ) -> tuple[Score, R]:
        pass

    @abstractmethod
    def estimate_logpdf(
        self,
        key: PRNGKey,
        v: R,
        *args,
    ) -> Score:
        pass

    def simulate(
        self,
        key: PRNGKey,
        args: tuple[Any, ...],
    ) -> Trace[R]:
        (w, v) = self.random_weighted(key, *args)
        tr = DistributionTrace(self, args, v, w)
        return tr

    def generate_choice_map(
        self,
        key: PRNGKey,
        chm: ChoiceMap,
        args: tuple[Any, ...],
    ) -> tuple[Trace[R], Weight]:
        v = chm.get_value()
        match v:
            case None:
                tr = self.simulate(key, args)
                return tr, jnp.array(0.0)

            case Mask(value, flag):

                def _simulate(key, v):
                    score, new_v = self.random_weighted(key, *args)
                    w = 0.0
                    return (score, w, new_v)

                def _importance(key, v):
                    w = self.estimate_logpdf(key, v, *args)
                    return (w, w, v)

                score, w, new_v = jax.lax.cond(flag, _importance, _simulate, key, value)
                tr = DistributionTrace(self, args, new_v, score)
                return tr, w

            case _:
                w = self.estimate_logpdf(key, v, *args)
                tr = DistributionTrace(self, args, v, w)
                return tr, w

    def generate(
        self,
        key: PRNGKey,
        constraint: ChoiceMap,
        args: tuple[Any, ...],
    ) -> tuple[Trace[R], Weight]:
        match constraint:
            case ChoiceMap():
                tr, w = self.generate_choice_map(key, constraint, args)

            case _:
                raise Exception("Unhandled type.")
        return tr, w

    def edit_empty(
        self,
        trace: Trace[R],
        argdiffs: Argdiffs,
    ) -> tuple[Trace[R], Weight, Retdiff[R], Update]:
        sample = trace.get_choices()
        primals = Diff.tree_primal(argdiffs)
        new_score, _ = self.assess(sample, primals)
        new_trace = DistributionTrace(self, primals, sample.get_value(), new_score)
        return (
            new_trace,
            new_score - trace.get_score(),
            Diff.no_change(trace.get_retval()),
            Update(ChoiceMap.empty()),
        )

    def edit_update_with_constraint(
        self,
        key: PRNGKey,
        trace: Trace[R],
        constraint: ChoiceMap,
        argdiffs: Argdiffs,
    ) -> tuple[Trace[R], Weight, Retdiff[R], Update]:
        primals = Diff.tree_primal(argdiffs)
        match constraint:
            case ChoiceMap():
                match constraint.get_value():
                    case Mask() as masked_value:

                        def _true_branch(key, new_value: R, _):
                            fwd = self.estimate_logpdf(key, new_value, *primals)
                            bwd = trace.get_score()
                            w = fwd - bwd
                            return (new_value, w, fwd)

                        def _false_branch(key, _, old_value: R):
                            fwd = self.estimate_logpdf(key, old_value, *primals)
                            bwd = trace.get_score()
                            w = fwd - bwd
                            return (old_value, w, fwd)

                        flag = masked_value.primal_flag()
                        new_value: R = masked_value.value
                        old_choices = trace.get_choices()
                        old_value: R = old_choices.get_value()

                        new_value, w, score = FlagOp.cond(
                            flag,
                            _true_branch,
                            _false_branch,
                            key,
                            new_value,
                            old_value,
                        )
                        return (
                            DistributionTrace(self, primals, new_value, score),
                            w,
                            Diff.unknown_change(new_value),
                            Update(
                                old_choices.mask(flag),
                            ),
                        )
                    case None:
                        value_chm = trace.get_choices()
                        v = value_chm.get_value()
                        fwd = self.estimate_logpdf(key, v, *primals)
                        bwd = trace.get_score()
                        w = fwd - bwd
                        new_tr = DistributionTrace(self, primals, v, fwd)
                        retval_diff = Diff.no_change(v)
                        return (new_tr, w, retval_diff, Update(ChoiceMap.empty()))

                    case v:
                        fwd = self.estimate_logpdf(key, v, *primals)
                        bwd = trace.get_score()
                        w = fwd - bwd
                        new_tr = DistributionTrace(self, primals, v, fwd)
                        discard = trace.get_choices()
                        retval_diff = Diff.unknown_change(v)
                        return (new_tr, w, retval_diff, Update(discard))
            case _:
                raise Exception(f"Unhandled constraint in edit: {type(constraint)}.")

    def project(
        self,
        key: PRNGKey,
        trace: Trace[R],
        selection: Selection,
    ) -> Weight:
        return jnp.where(
            selection.check(),
            trace.get_score(),
            jnp.array(0.0),
        )

    def edit_regenerate(
        self,
        key: PRNGKey,
        trace: Trace[R],
        selection: Selection,
        argdiffs: Argdiffs,
    ) -> tuple[Trace[R], Weight, Retdiff[R], EditRequest]:
        check = () in selection
        if FlagOp.concrete_true(check):
            primals = Diff.tree_primal(argdiffs)
            w, new_v = self.random_weighted(key, *primals)
            incremental_w = w - trace.get_score()
            old_v = trace.get_retval()
            new_trace = DistributionTrace(self, primals, new_v, w)
            return (
                new_trace,
                incremental_w,
                Diff.unknown_change(new_v),
                Update(ChoiceMap.choice(old_v)),
            )
        elif FlagOp.concrete_false(check):
            if Diff.static_check_no_change(argdiffs):
                return (
                    trace,
                    jnp.array(0.0),
                    Diff.no_change(trace.get_retval()),
                    Update(ChoiceMap.empty()),
                )
            else:
                chm = trace.get_choices()
                primals = Diff.tree_primal(argdiffs)
                new_score, _ = self.assess(chm, primals)
                new_trace = DistributionTrace(self, primals, chm.get_value(), new_score)
                return (
                    new_trace,
                    new_score - trace.get_score(),
                    Diff.no_change(trace.get_retval()),
                    Update(
                        ChoiceMap.empty(),
                    ),
                )
        else:
            raise NotImplementedError

    def edit_update(
        self,
        key: PRNGKey,
        trace: Trace[R],
        constraint: ChoiceMap,
        argdiffs: Argdiffs,
    ) -> tuple[Trace[R], Weight, Retdiff[R], Update]:
        match constraint:
            case ChoiceMap():
                return self.edit_update_with_constraint(
                    key, trace, constraint, argdiffs
                )

            case _:
                raise Exception(f"Not implement fwd problem: {constraint}.")

    def edit(
        self,
        key: PRNGKey,
        trace: Trace[R],
        edit_request: EditRequest,
        argdiffs: Argdiffs,
    ) -> tuple[Trace[R], Weight, Retdiff[R], EditRequest]:
        match edit_request:
            case Update(chm):
                return self.edit_update(
                    key,
                    trace,
                    chm,
                    argdiffs,
                )
            case Regenerate(selection):
                return self.edit_regenerate(
                    key,
                    trace,
                    selection,
                    argdiffs,
                )

            case _:
                raise NotSupportedEditRequest(edit_request)

    def assess(
        self,
        sample: ChoiceMap,
        args: tuple[Any, ...],
    ):
        raise NotImplementedError

random_weighted abstractmethod

random_weighted(key: PRNGKey, *args) -> tuple[Score, R]
Source code in src/genjax/_src/generative_functions/distributions/distribution.py
@abstractmethod
def random_weighted(self, key: PRNGKey, *args) -> tuple[Score, R]:
    """Sample a value given `args` and return `(score, value)`.

    Abstract hook: every concrete `Distribution` must override it.
    """

estimate_logpdf abstractmethod

estimate_logpdf(key: PRNGKey, v: R, *args) -> Score
Source code in src/genjax/_src/generative_functions/distributions/distribution.py
@abstractmethod
def estimate_logpdf(self, key: PRNGKey, v: R, *args) -> Score:
    """Return a log density for `v` given `args`; it may be an estimate.

    Abstract hook: every concrete `Distribution` must override it.
    """

Distributions intentionally expose a permissive interface (random_weighted and estimate_logpdf), which doesn't assume exact density evaluation. genjax.ExactDensity is a more restrictive interface, which assumes exact density evaluation.

genjax.ExactDensity

Bases: Generic[R], Distribution[R]

Methods:

Name Description
random_weighted

Given arguments to the distribution, sample from the distribution, and return the exact log density of the sample, and the sample.

estimate_logpdf

Given a sample and arguments to the distribution, return the exact log density of the sample.

Source code in src/genjax/_src/generative_functions/distributions/distribution.py
class ExactDensity(Generic[R], Distribution[R]):
    @abstractmethod
    def sample(self, key: PRNGKey, *args) -> R:
        pass

    @abstractmethod
    def logpdf(self, v: R, *args, **kwargs) -> Score:
        pass

    def __abstract_call__(self, *args):
        return to_shape_fn(self.sample, jnp.zeros)(_fake_key, *args)

    def random_weighted(
        self,
        key: PRNGKey,
        *args,
    ) -> tuple[Score, R]:
        """
        Given arguments to the distribution, sample from the distribution, and return the exact log density of the sample, and the sample.
        """
        v = self.sample(key, *args)
        w = self.estimate_logpdf(key, v, *args)
        return (w, v)

    def estimate_logpdf(
        self,
        key: PRNGKey,
        v: R,
        *args,
    ) -> Weight:
        """
        Given a sample and arguments to the distribution, return the exact log density of the sample.
        """
        w = self.logpdf(v, *args)
        if w.shape:
            return jnp.sum(w)
        else:
            return w

    def assess(
        self,
        sample: ChoiceMap,
        args: tuple[Any, ...],
    ) -> tuple[Weight, R]:
        key = jax.random.key(0)
        v = sample.get_value()
        match v:
            case Mask(value, flag):

                def _check():
                    checkify.check(
                        bool(flag),
                        "Attempted to unmask when a mask flag is False: the masked value is invalid.\n",
                    )

                optional_check(_check)
                w = self.estimate_logpdf(key, value, *args)
                return w, value
            case _:
                w = self.estimate_logpdf(key, v, *args)
                return w, v

random_weighted

random_weighted(key: PRNGKey, *args) -> tuple[Score, R]

Given arguments to the distribution, sample from the distribution, and return the exact log density of the sample, and the sample.

Source code in src/genjax/_src/generative_functions/distributions/distribution.py
def random_weighted(self, key: PRNGKey, *args) -> tuple[Score, R]:
    """Draw a sample and pair it with its exact log density.

    Returns `(log_density, sample)`. The same `key` is passed on to
    `estimate_logpdf`, which for `ExactDensity` does not use it.
    """
    drawn = self.sample(key, *args)
    return self.estimate_logpdf(key, drawn, *args), drawn

estimate_logpdf

estimate_logpdf(key: PRNGKey, v: R, *args) -> Weight

Given a sample and arguments to the distribution, return the exact log density of the sample.

Source code in src/genjax/_src/generative_functions/distributions/distribution.py
def estimate_logpdf(
    self,
    key: PRNGKey,
    v: R,
    *args,
) -> Weight:
    """
    Given a sample and arguments to the distribution, return the exact log density of the sample.

    `key` is unused because `logpdf` is exact. A non-scalar `logpdf` result
    is summed into a single scalar log density.
    """
    # `jnp.asarray` accepts plain Python/NumPy scalars (which have no `.shape`)
    # as well as JAX arrays, so a hand-written `logpdf` returning a float works.
    w = jnp.asarray(self.logpdf(v, *args))
    return jnp.sum(w) if w.ndim else w

GenJAX exports a long list of exact density distributions, which use the functionality of tfp.distributions. A list of these is shown below.

genjax.generative_functions.distributions

Attributes:

Name Type Description
bernoulli

A tfp_distribution generative function which wraps the tfd.Bernoulli distribution from TensorFlow Probability distributions.

beta

A tfp_distribution generative function which wraps the tfd.Beta distribution from TensorFlow Probability distributions.

beta_binomial

A tfp_distribution generative function which wraps the tfd.BetaBinomial distribution from TensorFlow Probability distributions.

beta_quotient

A tfp_distribution generative function which wraps the tfd.BetaQuotient distribution from TensorFlow Probability distributions.

binomial ExactDensity[Array]

A tfp_distribution generative function which wraps the tfd.Binomial distribution from TensorFlow Probability distributions.

categorical

A tfp_distribution generative function which wraps the tfd.Categorical distribution from TensorFlow Probability distributions.

cauchy

A tfp_distribution generative function which wraps the tfd.Cauchy distribution from TensorFlow Probability distributions.

chi

A tfp_distribution generative function which wraps the tfd.Chi distribution from TensorFlow Probability distributions.

chi2

A tfp_distribution generative function which wraps the tfd.Chi2 distribution from TensorFlow Probability distributions.

dirichlet

A tfp_distribution generative function which wraps the tfd.Dirichlet distribution from TensorFlow Probability distributions.

dirichlet_multinomial

A tfp_distribution generative function which wraps the tfd.DirichletMultinomial distribution from TensorFlow Probability distributions.

double_sided_maxwell

A tfp_distribution generative function which wraps the tfd.DoublesidedMaxwell distribution from TensorFlow Probability distributions.

exp_gamma

A tfp_distribution generative function which wraps the tfd.ExpGamma distribution from TensorFlow Probability distributions.

exp_inverse_gamma

A tfp_distribution generative function which wraps the tfd.ExpInverseGamma distribution from TensorFlow Probability distributions.

exponential

A tfp_distribution generative function which wraps the tfd.Exponential distribution from TensorFlow Probability distributions.

flip

A tfp_distribution generative function which wraps the tfd.Bernoulli distribution from TensorFlow Probability distributions, but is constructed using a probability value and not a logit.

gamma

A tfp_distribution generative function which wraps the tfd.Gamma distribution from TensorFlow Probability distributions.

geometric

A tfp_distribution generative function which wraps the tfd.Geometric distribution from TensorFlow Probability distributions.

gumbel

A tfp_distribution generative function which wraps the tfd.Gumbel distribution from TensorFlow Probability distributions.

half_cauchy

A tfp_distribution generative function which wraps the tfd.HalfCauchy distribution from TensorFlow Probability distributions.

half_normal

A tfp_distribution generative function which wraps the tfd.HalfNormal distribution from TensorFlow Probability distributions.

half_student_t

A tfp_distribution generative function which wraps the tfd.HalfStudentT distribution from TensorFlow Probability distributions.

inverse_gamma

A tfp_distribution generative function which wraps the tfd.InverseGamma distribution from TensorFlow Probability distributions.

kumaraswamy

A tfp_distribution generative function which wraps the tfd.Kumaraswamy distribution from TensorFlow Probability distributions.

lambert_w_normal

A tfp_distribution generative function which wraps the tfd.LambertWNormal distribution from TensorFlow Probability distributions.

laplace

A tfp_distribution generative function which wraps the tfd.Laplace distribution from TensorFlow Probability distributions.

log_normal

A tfp_distribution generative function which wraps the tfd.LogNormal distribution from TensorFlow Probability distributions.

logit_normal

A tfp_distribution generative function which wraps the tfd.LogitNormal distribution from TensorFlow Probability distributions.

moyal

A tfp_distribution generative function which wraps the tfd.Moyal distribution from TensorFlow Probability distributions.

multinomial

A tfp_distribution generative function which wraps the tfd.Multinomial distribution from TensorFlow Probability distributions.

mv_normal

A tfp_distribution generative function which wraps the tfd.MultivariateNormalFullCovariance distribution from TensorFlow Probability distributions.

mv_normal_diag

A tfp_distribution generative function which wraps the tfd.MultivariateNormalDiag distribution from TensorFlow Probability distributions.

negative_binomial

A tfp_distribution generative function which wraps the tfd.NegativeBinomial distribution from TensorFlow Probability distributions.

non_central_chi2

A tfp_distribution generative function which wraps the tfd.NoncentralChi2 distribution from TensorFlow Probability distributions.

normal

A tfp_distribution generative function which wraps the tfd.Normal distribution from TensorFlow Probability distributions.

poisson

A tfp_distribution generative function which wraps the tfd.Poisson distribution from TensorFlow Probability distributions.

power_spherical

A tfp_distribution generative function which wraps the tfd.PowerSpherical distribution from TensorFlow Probability distributions.

skellam

A tfp_distribution generative function which wraps the tfd.Skellam distribution from TensorFlow Probability distributions.

student_t

A tfp_distribution generative function which wraps the tfd.StudentT distribution from TensorFlow Probability distributions.

truncated_cauchy

A tfp_distribution generative function which wraps the tfd.TruncatedCauchy distribution from TensorFlow Probability distributions.

truncated_normal

A tfp_distribution generative function which wraps the tfd.TruncatedNormal distribution from TensorFlow Probability distributions.

uniform

A tfp_distribution generative function which wraps the tfd.Uniform distribution from TensorFlow Probability distributions.

von_mises

A tfp_distribution generative function which wraps the tfd.VonMises distribution from TensorFlow Probability distributions.

von_mises_fisher

A tfp_distribution generative function which wraps the tfd.VonMisesFisher distribution from TensorFlow Probability distributions.

weibull

A tfp_distribution generative function which wraps the tfd.Weibull distribution from TensorFlow Probability distributions.

zipf

A tfp_distribution generative function which wraps the tfd.Zipf distribution from TensorFlow Probability distributions.

bernoulli module-attribute

# Exact-density generative function wrapping `tfd.Bernoulli`, parameterized by
# logits (log-odds). NOTE(review): `implicit_logit_warning` is defined
# elsewhere; presumably it warns that the argument is read as logits — confirm.
bernoulli = tfp_distribution(
    implicit_logit_warning(Bernoulli), name="bernoulli"
)

A tfp_distribution generative function which wraps the tfd.Bernoulli distribution from TensorFlow Probability distributions.

Takes an N-D Tensor representing the log-odds of a 1 event. Each entry in the Tensor parameterizes an independent Bernoulli distribution where the probability of an event is sigmoid(logits).

(Note that this is the logits argument to the tfd.Bernoulli constructor.)

beta module-attribute

# Exact-density generative function wrapping `tfd.Beta`.
beta = tfp_distribution(Beta)

A tfp_distribution generative function which wraps the tfd.Beta distribution from TensorFlow Probability distributions.

beta_binomial module-attribute

# Exact-density generative function wrapping `tfd.BetaBinomial`.
beta_binomial = tfp_distribution(BetaBinomial)

A tfp_distribution generative function which wraps the tfd.BetaBinomial distribution from TensorFlow Probability distributions.

beta_quotient module-attribute

# Exact-density generative function wrapping `tfd.BetaQuotient`.
beta_quotient = tfp_distribution(BetaQuotient)

A tfp_distribution generative function which wraps the tfd.BetaQuotient distribution from TensorFlow Probability distributions.

binomial module-attribute

# Exact-density generative function wrapping `tfd.Binomial`; explicitly
# annotated as `ExactDensity[Array]`.
binomial: ExactDensity[Array] = tfp_distribution(Binomial)

A tfp_distribution generative function which wraps the tfd.Binomial distribution from TensorFlow Probability distributions.

categorical module-attribute

# Exact-density generative function wrapping `tfd.Categorical`.
# NOTE(review): wrapped with `implicit_logit_warning` (defined elsewhere),
# which suggests the argument is interpreted as logits — confirm.
categorical = tfp_distribution(
    implicit_logit_warning(Categorical), name="categorical"
)

A tfp_distribution generative function which wraps the tfd.Categorical distribution from TensorFlow Probability distributions.

cauchy module-attribute

# Exact-density generative function wrapping `tfd.Cauchy`.
cauchy = tfp_distribution(Cauchy)

A tfp_distribution generative function which wraps the tfd.Cauchy distribution from TensorFlow Probability distributions.

chi module-attribute

# Exact-density generative function wrapping `tfd.Chi`.
chi = tfp_distribution(Chi)

A tfp_distribution generative function which wraps the tfd.Chi distribution from TensorFlow Probability distributions.

chi2 module-attribute

# Exact-density generative function wrapping `tfd.Chi2`.
chi2 = tfp_distribution(Chi2)

A tfp_distribution generative function which wraps the tfd.Chi2 distribution from TensorFlow Probability distributions.

dirichlet module-attribute

# Exact-density generative function wrapping `tfd.Dirichlet`.
dirichlet = tfp_distribution(Dirichlet)

A tfp_distribution generative function which wraps the tfd.Dirichlet distribution from TensorFlow Probability distributions.

dirichlet_multinomial module-attribute

# Exact-density generative function wrapping `tfd.DirichletMultinomial`.
dirichlet_multinomial = tfp_distribution(
    DirichletMultinomial
)

A tfp_distribution generative function which wraps the tfd.DirichletMultinomial distribution from TensorFlow Probability distributions.

double_sided_maxwell module-attribute

# Exact-density generative function wrapping `tfd.DoublesidedMaxwell`.
double_sided_maxwell = tfp_distribution(DoublesidedMaxwell)

A tfp_distribution generative function which wraps the tfd.DoublesidedMaxwell distribution from TensorFlow Probability distributions.

exp_gamma module-attribute

# Exact-density generative function wrapping `tfd.ExpGamma`.
exp_gamma = tfp_distribution(ExpGamma)

A tfp_distribution generative function which wraps the tfd.ExpGamma distribution from TensorFlow Probability distributions.

exp_inverse_gamma module-attribute

# Exact-density generative function wrapping `tfd.ExpInverseGamma`.
exp_inverse_gamma = tfp_distribution(ExpInverseGamma)

A tfp_distribution generative function which wraps the tfd.ExpInverseGamma distribution from TensorFlow Probability distributions.

exponential module-attribute

# Exact-density generative function wrapping `tfd.Exponential`.
exponential = tfp_distribution(Exponential)

A tfp_distribution generative function which wraps the tfd.Exponential distribution from TensorFlow Probability distributions.

flip module-attribute

# Exact-density generative function wrapping `tfd.Bernoulli` parameterized by a
# probability `p` (not a logit); samples are produced with dtype `bool_`.
flip = tfp_distribution(
    lambda p: Bernoulli(probs=p, dtype=bool_), name="flip"
)

A tfp_distribution generative function which wraps the tfd.Bernoulli distribution from TensorFlow Probability distributions, but is constructed using a probability value and not a logit.

Takes an N-D Tensor representing the probability of a 1 event. Each entry in the Tensor parameterizes an independent Bernoulli distribution.

(Note that this is the probs argument to the tfd.Bernoulli constructor.)

gamma module-attribute

# Exact-density generative function wrapping `tfd.Gamma`.
gamma = tfp_distribution(Gamma)

A tfp_distribution generative function which wraps the tfd.Gamma distribution from TensorFlow Probability distributions.

geometric module-attribute

# Exact-density generative function wrapping `tfd.Geometric`.
geometric = tfp_distribution(Geometric)

A tfp_distribution generative function which wraps the tfd.Geometric distribution from TensorFlow Probability distributions.

gumbel module-attribute

# Exact-density generative function wrapping `tfd.Gumbel`.
gumbel = tfp_distribution(Gumbel)

A tfp_distribution generative function which wraps the tfd.Gumbel distribution from TensorFlow Probability distributions.

half_cauchy module-attribute

# Exact-density generative function wrapping `tfd.HalfCauchy`.
half_cauchy = tfp_distribution(HalfCauchy)

A tfp_distribution generative function which wraps the tfd.HalfCauchy distribution from TensorFlow Probability distributions.

half_normal module-attribute

# Exact-density generative function wrapping `tfd.HalfNormal`.
half_normal = tfp_distribution(HalfNormal)

A tfp_distribution generative function which wraps the tfd.HalfNormal distribution from TensorFlow Probability distributions.

half_student_t module-attribute

# Exact-density generative function wrapping `tfd.HalfStudentT`.
half_student_t = tfp_distribution(HalfStudentT)

A tfp_distribution generative function which wraps the tfd.HalfStudentT distribution from TensorFlow Probability distributions.

inverse_gamma module-attribute

# Exact-density generative function wrapping `tfd.InverseGamma`.
inverse_gamma = tfp_distribution(InverseGamma)

A tfp_distribution generative function which wraps the tfd.InverseGamma distribution from TensorFlow Probability distributions.

kumaraswamy module-attribute

# Exact-density generative function wrapping `tfd.Kumaraswamy`.
kumaraswamy = tfp_distribution(Kumaraswamy)

A tfp_distribution generative function which wraps the tfd.Kumaraswamy distribution from TensorFlow Probability distributions.

lambert_w_normal module-attribute

# Exact-density generative function wrapping `tfd.LambertWNormal`.
lambert_w_normal = tfp_distribution(LambertWNormal)

A tfp_distribution generative function which wraps the tfd.LambertWNormal distribution from TensorFlow Probability distributions.

laplace module-attribute

# Exact-density generative function wrapping `tfd.Laplace`.
laplace = tfp_distribution(Laplace)

A tfp_distribution generative function which wraps the tfd.Laplace distribution from TensorFlow Probability distributions.

log_normal module-attribute

# Exact-density generative function wrapping `tfd.LogNormal`.
log_normal = tfp_distribution(LogNormal)

A tfp_distribution generative function which wraps the tfd.LogNormal distribution from TensorFlow Probability distributions.

logit_normal module-attribute

# Exact-density generative function wrapping `tfd.LogitNormal`.
logit_normal = tfp_distribution(LogitNormal)

A tfp_distribution generative function which wraps the tfd.LogitNormal distribution from TensorFlow Probability distributions.

moyal module-attribute

# Exact-density generative function wrapping `tfd.Moyal`.
moyal = tfp_distribution(Moyal)

A tfp_distribution generative function which wraps the tfd.Moyal distribution from TensorFlow Probability distributions.

multinomial module-attribute

# Exact-density generative function wrapping `tfd.Multinomial`.
multinomial = tfp_distribution(Multinomial)

A tfp_distribution generative function which wraps the tfd.Multinomial distribution from TensorFlow Probability distributions.

mv_normal module-attribute

# Exact-density generative function wrapping `tfd.MultivariateNormalFullCovariance`.
mv_normal = tfp_distribution(
    MultivariateNormalFullCovariance
)

A tfp_distribution generative function which wraps the tfd.MultivariateNormalFullCovariance distribution from TensorFlow Probability distributions.

mv_normal_diag module-attribute

# Exact-density generative function wrapping `tfd.MultivariateNormalDiag`.
mv_normal_diag = tfp_distribution(MultivariateNormalDiag)

A tfp_distribution generative function which wraps the tfd.MultivariateNormalDiag distribution from TensorFlow Probability distributions.

negative_binomial module-attribute

# Exact-density generative function wrapping `tfd.NegativeBinomial`.
negative_binomial = tfp_distribution(NegativeBinomial)

A tfp_distribution generative function which wraps the tfd.NegativeBinomial distribution from TensorFlow Probability distributions.

non_central_chi2 module-attribute

# Exact-density generative function wrapping `tfd.NoncentralChi2`.
non_central_chi2 = tfp_distribution(NoncentralChi2)

A tfp_distribution generative function which wraps the tfd.NoncentralChi2 distribution from TensorFlow Probability distributions.

normal module-attribute

# Exact-density generative function wrapping `tfd.Normal`.
normal = tfp_distribution(Normal)

A tfp_distribution generative function which wraps the tfd.Normal distribution from TensorFlow Probability distributions.

poisson module-attribute

# Exact-density generative function wrapping `tfd.Poisson`.
poisson = tfp_distribution(Poisson)

A tfp_distribution generative function which wraps the tfd.Poisson distribution from TensorFlow Probability distributions.

power_spherical module-attribute

# Exact-density generative function wrapping `tfd.PowerSpherical`.
power_spherical = tfp_distribution(PowerSpherical)

A tfp_distribution generative function which wraps the tfd.PowerSpherical distribution from TensorFlow Probability distributions.

skellam module-attribute

# Exact-density generative function wrapping `tfd.Skellam`.
skellam = tfp_distribution(Skellam)

A tfp_distribution generative function which wraps the tfd.Skellam distribution from TensorFlow Probability distributions.

student_t module-attribute

# Exact-density generative function wrapping `tfd.StudentT`.
student_t = tfp_distribution(StudentT)

A tfp_distribution generative function which wraps the tfd.StudentT distribution from TensorFlow Probability distributions.

truncated_cauchy module-attribute

# Exact-density generative function wrapping `tfd.TruncatedCauchy`.
truncated_cauchy = tfp_distribution(TruncatedCauchy)

A tfp_distribution generative function which wraps the tfd.TruncatedCauchy distribution from TensorFlow Probability distributions.

truncated_normal module-attribute

# Exact-density generative function wrapping `tfd.TruncatedNormal`.
truncated_normal = tfp_distribution(TruncatedNormal)

A tfp_distribution generative function which wraps the tfd.TruncatedNormal distribution from TensorFlow Probability distributions.

uniform module-attribute

# Exact-density generative function wrapping `tfd.Uniform`.
uniform = tfp_distribution(Uniform)

A tfp_distribution generative function which wraps the tfd.Uniform distribution from TensorFlow Probability distributions.

von_mises module-attribute

# Exact-density generative function wrapping `tfd.VonMises`.
von_mises = tfp_distribution(VonMises)

A tfp_distribution generative function which wraps the tfd.VonMises distribution from TensorFlow Probability distributions.

von_mises_fisher module-attribute

# Exact-density generative function wrapping `tfd.VonMisesFisher`.
von_mises_fisher = tfp_distribution(VonMisesFisher)

A tfp_distribution generative function which wraps the tfd.VonMisesFisher distribution from TensorFlow Probability distributions.

weibull module-attribute

# Exact-density generative function wrapping `tfd.Weibull`.
weibull = tfp_distribution(Weibull)

A tfp_distribution generative function which wraps the tfd.Weibull distribution from TensorFlow Probability distributions.

zipf module-attribute

# Exact-density generative function wrapping `tfd.Zipf`.
zipf = tfp_distribution(Zipf)

A tfp_distribution generative function which wraps the tfd.Zipf distribution from TensorFlow Probability distributions.

ExactDensity

Bases: Generic[R], Distribution[R]

Source code in src/genjax/_src/generative_functions/distributions/distribution.py
class ExactDensity(Generic[R], Distribution[R]):
    @abstractmethod
    def sample(self, key: PRNGKey, *args) -> R:
        pass

    @abstractmethod
    def logpdf(self, v: R, *args, **kwargs) -> Score:
        pass

    def __abstract_call__(self, *args):
        return to_shape_fn(self.sample, jnp.zeros)(_fake_key, *args)

    def random_weighted(
        self,
        key: PRNGKey,
        *args,
    ) -> tuple[Score, R]:
        """
        Given arguments to the distribution, sample from the distribution, and return the exact log density of the sample, and the sample.
        """
        v = self.sample(key, *args)
        w = self.estimate_logpdf(key, v, *args)
        return (w, v)

    def estimate_logpdf(
        self,
        key: PRNGKey,
        v: R,
        *args,
    ) -> Weight:
        """
        Given a sample and arguments to the distribution, return the exact log density of the sample.
        """
        w = self.logpdf(v, *args)
        if w.shape:
            return jnp.sum(w)
        else:
            return w

    def assess(
        self,
        sample: ChoiceMap,
        args: tuple[Any, ...],
    ) -> tuple[Weight, R]:
        key = jax.random.key(0)
        v = sample.get_value()
        match v:
            case Mask(value, flag):

                def _check():
                    checkify.check(
                        bool(flag),
                        "Attempted to unmask when a mask flag is False: the masked value is invalid.\n",
                    )

                optional_check(_check)
                w = self.estimate_logpdf(key, value, *args)
                return w, value
            case _:
                w = self.estimate_logpdf(key, v, *args)
                return w, v

estimate_logpdf

estimate_logpdf(key: PRNGKey, v: R, *args) -> Weight

Given a sample and arguments to the distribution, return the exact log density of the sample.

Source code in src/genjax/_src/generative_functions/distributions/distribution.py
def estimate_logpdf(
    self,
    key: PRNGKey,
    v: R,
    *args,
) -> Weight:
    """
    Return the exact log density of the sample `v` under the distribution parameterized by `args`.

    The `key` argument exists for interface compatibility and is not used.
    A log density with a non-empty shape is summed down to a scalar.
    """
    log_density = self.logpdf(v, *args)
    # Non-empty shape tuple -> reduce to a single scalar weight.
    return jnp.sum(log_density) if log_density.shape else log_density

random_weighted

random_weighted(key: PRNGKey, *args) -> tuple[Score, R]

Given arguments to the distribution, sample from the distribution, and return the exact log density of the sample, and the sample.

Source code in src/genjax/_src/generative_functions/distributions/distribution.py
def random_weighted(
    self,
    key: PRNGKey,
    *args,
) -> tuple[Score, R]:
    """
    Draw a sample using `args` and pair it with its exact log density.

    Returns a `(log_density, sample)` tuple. The same `key` is handed to both
    `sample` and `estimate_logpdf`.
    """
    drawn = self.sample(key, *args)
    log_density = self.estimate_logpdf(key, drawn, *args)
    return log_density, drawn

exact_density

exact_density(
    sample: Callable[..., R],
    logpdf: Callable[..., Score],
    name: str | None = None,
) -> ExactDensity[R]

Construct a new type, a subclass of ExactDensity, with the given name (with genjax. prepended, to avoid confusion with the underlying object, which may not share the same interface), and attach the supplied functions as the sample and logpdf methods. The return value is an instance of this new type and should be treated as a singleton.

Source code in src/genjax/_src/generative_functions/distributions/distribution.py
def exact_density(
    sample: Callable[..., R], logpdf: Callable[..., Score], name: str | None = None
) -> ExactDensity[R]:
    """Construct a new type, a subclass of ExactDensity, with the given name
    (with `genjax.` prepended, to avoid confusion with the underlying object,
    which may not share the same interface), and attach the supplied functions
    as the `sample` and `logpdf` methods. The return value is an instance of
    this new type and should be treated as a singleton.

    Args:
        sample: Called as `sample(key, *args, **kwargs)` to draw a value.
        logpdf: Called as `logpdf(v, *args, **kwargs)` to score a value.
        name: Name for the generated type. If omitted, a warning is emitted
            and `"unknown"` is used.

    Returns:
        A `Pytree.dataclass` instance of the newly created `ExactDensity` subclass.
    """
    if name is None:
        # stacklevel=2 attributes the warning to the caller of `exact_density`
        # instead of to this line inside the library.
        warnings.warn(
            "You should supply a name argument to exact_density", stacklevel=2
        )
        name = "unknown"

    def kwargle(f, a0, args, kwargs):
        """Keyword arguments currently get unusual treatment in GenJAX: when
        a keyword argument is provided to a generative function, the function
        is asked to provide a new version of itself which receives a different
        signature: `(args, kwargs)` instead of `(*args, **kwargs)`. The
        replacement of the GF with a new object may cause JAX to believe that
        the implementations are materially different. To avoid this, we
        reply to the handle_kwargs request with self and infer kwargs handling
        by seeing whether we were passed a 2-tuple with a dict in the [1] slot.
        We are assuming that this will not represent a useful argument package
        to any of the TF distributions."""
        if len(args) == 2 and isinstance(args[1], dict):
            return f(a0, *args[0], **args[1])
        else:
            return f(a0, *args, **kwargs)

    T = type(
        canonicalize_distribution_name(name),
        (ExactDensity,),
        {
            "sample": lambda self, key, *args, **kwargs: kwargle(
                sample, key, args, kwargs
            ),
            "logpdf": lambda self, v, *args, **kwargs: kwargle(logpdf, v, args, kwargs),
            # Return self so the object identity (and JAX's view of it) is stable;
            # `kwargle` detects the `(args, kwargs)` packaging instead.
            "handle_kwargs": lambda self: self,
        },
    )

    return Pytree.dataclass(T)()

tfp_distribution

tfp_distribution(
    dist: Callable[..., Distribution],
    name: str | None = None,
) -> ExactDensity[Array]

Creates a generative function from a TensorFlow Probability distribution.

Parameters:

Name Type Description Default

dist

Callable[..., Distribution]

A callable that returns a TensorFlow Probability distribution.

required

Returns:

Type Description
ExactDensity[Array]

A generative function wrapping the TensorFlow Probability distribution.

This function creates a generative function that encapsulates the sampling and log probability computation of a TensorFlow Probability distribution. It uses the distribution's sample and log_prob methods to define the generative function's behavior.

Source code in src/genjax/_src/generative_functions/distributions/tensorflow_probability/__init__.py
def tfp_distribution(
    dist: Callable[..., "dist.Distribution"], name: str | None = None
) -> ExactDensity[Array]:
    """
    Creates a generative function from a TensorFlow Probability distribution.

    Args:
        dist: A callable that returns a TensorFlow Probability distribution.
        name: Optional name for the generated type. Defaults to `dist.__name__`
            when available; otherwise `exact_density` falls back to its default
            name (with a warning).

    Returns:
        A generative function wrapping the TensorFlow Probability distribution.

    This function creates a generative function that encapsulates the sampling and log probability
    computation of a TensorFlow Probability distribution. It uses the distribution's `sample` and
    `log_prob` methods to define the generative function's behavior.
    """

    def sampler(key, *args, **kwargs):
        # `sample_shape` is consumed here; it is not a constructor argument of `dist`.
        sample_shape = kwargs.pop("sample_shape", ())
        d = dist(*args, **kwargs)
        return d.sample(seed=key, sample_shape=Const.unwrap(sample_shape))

    def logpdf(v, *args, **kwargs):
        # Remove unused kwarg to match sampler function behavior
        kwargs.pop("sample_shape", ())
        d = dist(*args, **kwargs)

        return d.log_prob(v)

    # Callables such as `functools.partial` objects have no `__name__`; fall back
    # to None so `exact_density` warns and picks a default instead of raising.
    return exact_density(sampler, logpdf, name or getattr(dist, "__name__", None))

StaticGenerativeFunction: a programmatic language

For any serious work, you'll want a way to combine generative functions together, mixing deterministic functions with sampling. StaticGenerativeFunction is a way to do that: it supports the use of a JAX compatible subset of Python to author generative functions. It also supports the ability to invoke other generative functions: instances of this type (and any other type of generative function) can then be used in larger generative programs.

genjax.StaticGenerativeFunction

Bases: Generic[R], GenerativeFunction[R]

A StaticGenerativeFunction is a generative function which relies on program transformations applied to JAX-compatible Python programs to implement the generative function interface.

By virtue of the implementation, any source program which is provided to this generative function must be JAX traceable, meaning all the footguns for programs that JAX exposes apply to the source program.

Language restrictions

In addition to JAX footguns, there are a few more which are specific to the generative function interface semantics. Here is the full list of language restrictions (and capabilities):

  • One is allowed to use jax.lax control flow primitives so long as the functions provided to the primitives do not contain trace invocations. In other words, control flow primitives used within a StaticGenerativeFunction's source program may only be given deterministic computations.

  • The above restriction also applies to jax.vmap.

  • Source programs are allowed to utilize untraced randomness, with one restriction: it must come from jax.random and JAX's PRNG capabilities. To utilize untraced randomness, you'll need to pass in an extra key as an argument to your model.

    @gen
    def model(key: PRNGKey):
        # `key` is an ordinary argument; `some_untraced_call` is a placeholder
        # for any function that draws untraced randomness from it via jax.random.
        v = some_untraced_call(key)
        # Traced random choice recorded at address "x".
        x = trace("x", genjax.normal)(v, 1.0)
        return x
    

Methods:

Name Description
simulate
assess

Attributes:

Name Type Description
source Closure[R]

The source program of the generative function. This is a JAX-compatible Python program.

Source code in src/genjax/_src/generative_functions/static.py
@Pytree.dataclass
class StaticGenerativeFunction(Generic[R], GenerativeFunction[R]):
    """A `StaticGenerativeFunction` is a generative function which relies on program
    transformations applied to JAX-compatible Python programs to implement the generative
    function interface.

    By virtue of the implementation, any source program which is provided to this generative function *must* be JAX traceable, meaning [all the footguns for programs that JAX exposes](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) apply to the source program.

    **Language restrictions**

    In addition to JAX footguns, there are a few more which are specific to the generative function interface semantics. Here is the full list of language restrictions (and capabilities):

    * One is allowed to use `jax.lax` control flow primitives _so long as the functions provided to the primitives do not contain `trace` invocations_. In other words, utilizing control flow primitives within the source of a `StaticGenerativeFunction`'s source program requires that the control flow primitives get *deterministic* computation.

    * The above restriction also applies to `jax.vmap`.

    * Source programs are allowed to utilize untraced randomness, although there are restrictions (which we discuss below). It is required to use [`jax.random`](https://jax.readthedocs.io/en/latest/jax.random.html) and JAX's PRNG capabilities. To utilize untraced randomness, you'll need to pass in an extra key as an argument to your model.

        ```python
        @gen
        def model(key: PRNGKey):
            v = some_untraced_call(key)
            x = trace("x", genjax.normal)(v, 1.0)
            return x
        ```
    """

    source: Closure[R]
    """
    The source program of the generative function. This is a JAX-compatible Python program.
    """

    def __get__(self, instance, _klass) -> "StaticGenerativeFunction[R]":
        """
        This method allows the @genjax.gen decorator to transform instance methods, turning them into `StaticGenerativeFunction[R]` calls.

        NOTE: if you assign an already-created `StaticGenerativeFunction` to a variable inside of a class, it will always receive the instance as its first method.
        """
        return self.partial_apply(instance) if instance else self

    # To get the type of return value, just invoke
    # the source (with abstract tracer arguments).
    def __abstract_call__(self, *args) -> Any:
        return to_shape_fn(self.source, jnp.zeros)(*args)

    def __post_init__(self):
        wrapped = self.source.fn
        # Preserve the original function's docstring and name
        for k in _WRAPPER_ASSIGNMENTS:
            v = getattr(wrapped, k, None)
            if v is not None:
                object.__setattr__(self, k, v)

        object.__setattr__(self, "__wrapped__", wrapped)

    def handle_kwargs(self) -> "StaticGenerativeFunction[R]":
        @Pytree.partial()
        def kwarged_source(args, kwargs):
            return self.source(*args, **kwargs)

        return StaticGenerativeFunction(kwarged_source)

    def simulate(
        self,
        key: PRNGKey,
        args: tuple[Any, ...],
    ) -> StaticTrace[R]:
        (args, retval, traces) = simulate_transform(self.source)(key, args)
        return StaticTrace(self, args, retval, traces)

    def generate(
        self,
        key: PRNGKey,
        constraint: ChoiceMap,
        args: tuple[Any, ...],
    ) -> tuple[StaticTrace[R], Weight]:
        (
            weight,
            # Trace.
            (
                args,
                retval,
                traces,
            ),
        ) = generate_transform(self.source)(key, constraint, args)
        return StaticTrace(self, args, retval, traces), weight

    def project(
        self,
        key: PRNGKey,
        trace: Trace[Any],
        selection: Selection,
    ) -> Weight:
        assert isinstance(trace, StaticTrace)

        weight = jnp.array(0.0)
        for addr in trace.subtraces.keys():
            subprojection = selection(addr)
            subtrace = trace.get_subtrace(addr)
            weight += subtrace.project(key, subprojection)
        return weight

    def edit_update(
        self,
        key: PRNGKey,
        trace: StaticTrace[R],
        constraint: ChoiceMap,
        argdiffs: Argdiffs,
    ) -> tuple[StaticTrace[R], Weight, Retdiff[R], EditRequest]:
        (
            (
                retval_diffs,
                weight,
                (
                    arg_primals,
                    retval_primals,
                    traces,
                ),
                bwd_requests,
            ),
        ) = update_transform(self.source)(key, trace, constraint, argdiffs)
        if not Diff.static_check_tree_diff(retval_diffs):
            retval_diffs = Diff.no_change(retval_diffs)

        def make_bwd_request(traces, subconstraints):
            addresses = traces.keys()
            chm = ChoiceMap.from_mapping(zip(addresses, subconstraints))
            return Update(chm)

        bwd_request = make_bwd_request(traces, bwd_requests)
        return (
            StaticTrace(
                self,
                arg_primals,
                retval_primals,
                traces,
            ),
            weight,
            retval_diffs,
            bwd_request,
        )

    def edit_static_edit_request(
        self,
        key: PRNGKey,
        trace: StaticTrace[R],
        addressed: StaticDict,
        argdiffs: Argdiffs,
    ) -> tuple[StaticTrace[R], Weight, Retdiff[R], EditRequest]:
        (
            (
                retval_diffs,
                weight,
                (
                    arg_primals,
                    retval_primals,
                    traces,
                ),
                bwd_requests,
            ),
        ) = static_edit_request_transform(self.source)(key, trace, addressed, argdiffs)

        def make_bwd_request(
            traces: dict[StaticAddress, Trace[R]],
            subrequests: list[EditRequest],
        ):
            return StaticRequest(dict(zip(traces.keys(), subrequests)))

        bwd_request = make_bwd_request(traces, bwd_requests)
        return (
            StaticTrace(
                self,
                arg_primals,
                retval_primals,
                traces,
            ),
            weight,
            retval_diffs,
            bwd_request,
        )

    def edit_regenerate(
        self,
        key: PRNGKey,
        trace: StaticTrace[R],
        selection: Selection,
        edit_request: EditRequest,
        argdiffs: Argdiffs,
    ) -> tuple[StaticTrace[R], Weight, Retdiff[R], EditRequest]:
        (
            (
                retval_diffs,
                weight,
                (
                    arg_primals,
                    retval_primals,
                    traces,
                ),
                bwd_requests,
            ),
        ) = regenerate_transform(self.source)(
            key, trace, selection, edit_request, argdiffs
        )

        def make_bwd_request(
            traces: dict[StaticAddress, Trace[R]],
            subrequests: list[EditRequest],
        ):
            return StaticRequest(dict(zip(traces.keys(), subrequests)))

        bwd_request = make_bwd_request(traces, bwd_requests)
        return (
            StaticTrace(
                self,
                arg_primals,
                retval_primals,
                traces,
            ),
            weight,
            retval_diffs,
            bwd_request,
        )

    def edit(
        self,
        key: PRNGKey,
        trace: Trace[R],
        edit_request: EditRequest,
        argdiffs: Argdiffs,
    ) -> tuple[StaticTrace[R], Weight, Retdiff[R], EditRequest]:
        assert isinstance(trace, StaticTrace)
        match edit_request:
            case Update(constraint):
                return self.edit_update(
                    key,
                    trace,
                    constraint,
                    argdiffs,
                )

            case StaticRequest(addressed):
                return self.edit_static_edit_request(
                    key,
                    trace,
                    addressed,
                    argdiffs,
                )
            case Regenerate(selection):
                return self.edit_regenerate(
                    key,
                    trace,
                    selection,
                    edit_request,
                    argdiffs,
                )
            case _:
                raise NotSupportedEditRequest(edit_request)

    def assess(
        self,
        sample: ChoiceMap,
        args: tuple[Any, ...],
    ) -> tuple[Score, R]:
        (retval, score) = assess_transform(self.source)(sample, args)
        return (score, retval)

    def inline(self, *args):
        return self.source(*args)

    @property
    def partial_args(self) -> tuple[Any, ...]:
        """
        Returns the partially applied arguments of the generative function.

        This method retrieves the dynamically applied arguments that were used to create
        this StaticGenerativeFunction instance through partial application.

        Returns:
            tuple[Any, ...]: A tuple containing the partially applied arguments.

        Note:
            This method is particularly useful when working with partially applied
            generative functions, allowing access to the pre-filled arguments.
        """
        return self.source.dyn_args

    def partial_apply(self, *args) -> "StaticGenerativeFunction[R]":
        """
        Returns a new [`StaticGenerativeFunction`][] with the given arguments partially applied.

        This method creates a new [`StaticGenerativeFunction`][] that has some of its arguments pre-filled. When called, the new function will use the pre-filled arguments along with any additional arguments provided.

        Args:
            *args: Variable length argument list to be partially applied to the function.

        Returns:
            A new [`StaticGenerativeFunction`][] with partially applied arguments.

        Example:
            ```python
            @gen
            def my_model(x, y):
                z = normal(x, 1.0) @ "z"
                return y * z


            partially_applied_model = my_model.partial_apply(2.0)
            # Now `partially_applied_model` is equivalent to a model that only takes 'y' as an argument
            ```
        """
        all_args = self.source.dyn_args + args
        return gen(Closure[R](all_args, self.source.fn))

source instance-attribute

source: Closure[R]

The source program of the generative function. This is a JAX-compatible Python program.

simulate

simulate(
    key: PRNGKey, args: tuple[Any, ...]
) -> StaticTrace[R]
Source code in src/genjax/_src/generative_functions/static.py
def simulate(
    self,
    key: PRNGKey,
    args: tuple[Any, ...],
) -> StaticTrace[R]:
    """Run the source program forward under `simulate_transform` and wrap the
    resulting arguments, return value and subtraces in a `StaticTrace`."""
    run = simulate_transform(self.source)
    traced_args, return_value, subtraces = run(key, args)
    return StaticTrace(self, traced_args, return_value, subtraces)

assess

assess(
    sample: ChoiceMap, args: tuple[Any, ...]
) -> tuple[Score, R]
Source code in src/genjax/_src/generative_functions/static.py
def assess(
    self,
    sample: ChoiceMap,
    args: tuple[Any, ...],
) -> tuple[Score, R]:
    """Score `sample` under the source program called with `args`.

    Returns `(score, retval)`; note the transform itself yields them in the
    opposite order.
    """
    scorer = assess_transform(self.source)
    return_value, total_score = scorer(sample, args)
    return total_score, return_value