
Inference

Conditioning a probability distribution is a common operation: it is how users express Bayesian inference problems. Conditioning is also a subroutine of other operations, such as marginalization.

The language of inference

In GenJAX, inference problems are specified by constructing Target distributions. Their solutions are approximated using Algorithm families.

genjax.inference.Target

Bases: Generic[R], Pytree

A Target represents an unnormalized target distribution induced by conditioning a generative function on a genjax.ChoiceMap.

Targets are created by providing a generative function, arguments to the generative function, and a constraint.

Examples:

Creating a target from a generative function, by providing arguments and a constraint:

import genjax
from genjax import ChoiceMapBuilder as C
from genjax.inference import Target


@genjax.gen
def model():
    x = genjax.normal(0.0, 1.0) @ "x"
    y = genjax.normal(x, 1.0) @ "y"
    return x


target = Target(model, (), C["y"].set(3.0))
print(target.render_html())

Source code in src/genjax/_src/inference/sp.py
@Pytree.dataclass
class Target(Generic[R], Pytree):
    """
    A `Target` represents an unnormalized target distribution induced by conditioning a generative function on a [`genjax.ChoiceMap`][].

    Targets are created by providing a generative function, arguments to the generative function, and a constraint.

    Examples:
        Creating a target from a generative function, by providing arguments and a constraint:
        ```python exec="yes" html="true" source="material-block" session="core"
        import genjax
        from genjax import ChoiceMapBuilder as C
        from genjax.inference import Target


        @genjax.gen
        def model():
            x = genjax.normal(0.0, 1.0) @ "x"
            y = genjax.normal(x, 1.0) @ "y"
            return x


        target = Target(model, (), C["y"].set(3.0))
        print(target.render_html())
        ```
    """

    p: Annotated[GenerativeFunction[R], Is[validate_non_marginal]]
    args: tuple[Any, ...]
    constraint: ChoiceMap

    def importance(
        self, key: PRNGKey, constraint: ChoiceMap
    ) -> tuple[Trace[R], Weight]:
        merged = self.constraint.merge(constraint)
        return self.p.importance(key, merged, self.args)

    def filter_to_unconstrained(self, choice_map):
        selection = ~self.constraint.get_selection()
        return choice_map.filter(selection)

    def __getitem__(self, addr):
        return self.constraint[addr]
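Mathematically, the target in the example above is the function that maps x to p(x, y = 3), known only up to its normalizing constant Z = p(y = 3). The stdlib-only Python sketch below is not GenJAX code (`normal_logpdf` and `normalize_on_grid` are illustrative helpers). It evaluates that unnormalized density and normalizes it numerically on a grid, which is feasible only because this model has a single latent variable. For this model the exact answers are known: y is marginally N(0, 2), and the posterior over x has mean 1.5 and variance 0.5.

```python
import math

Y_OBS = 3.0  # mirrors the constraint C["y"].set(3.0)


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def unnormalized_log_target(x: float) -> float:
    """log p(x, y=3) for x ~ N(0, 1), y ~ N(x, 1): the joint with y pinned by the constraint."""
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(Y_OBS, x, 1.0)


def normalize_on_grid(log_f, lo: float = -10.0, hi: float = 10.0, n: int = 20001):
    """Trapezoid-rule estimates of Z = integral of exp(log_f) and of the normalized mean."""
    h = (hi - lo) / (n - 1)
    z_sum = 0.0
    mean_sum = 0.0
    for i in range(n):
        x = lo + i * h
        w = 0.5 if i in (0, n - 1) else 1.0  # trapezoid end-point weights
        f = math.exp(log_f(x))
        z_sum += w * f
        mean_sum += w * f * x
    return h * z_sum, mean_sum / z_sum


Z, posterior_mean = normalize_on_grid(unnormalized_log_target)
exact_Z = math.exp(normal_logpdf(Y_OBS, 0.0, math.sqrt(2.0)))  # y ~ N(0, 2) marginally
print(Z, exact_Z, posterior_mean)
```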

Algorithms inherit from a class called SampleDistribution. These are objects that implement the stochastic probability interface [Lew23], meaning they expose methods that produce samples paired with density estimates, and methods that estimate densities at given values.

genjax.inference.SampleDistribution (module attribute)

SampleDistribution = Distribution[ChoiceMap]

SampleDistribution is an alias for Distribution[ChoiceMap], the type of distributions whose return values are ChoiceMaps. It is the base type of both Algorithm and Marginal.

Algorithm families implement the stochastic probability interface. Their Distribution methods accept Target instances, and produce samples and density estimates for approximate posteriors.
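The shape of this interface can be sketched with a Python Protocol. Everything here is hypothetical and stdlib-only: a dict stands in for ChoiceMap, random.Random stands in for a PRNG key, and `ExactNormal` is an illustrative distribution over a single address "x". Because it reports its exact log density, both methods trivially meet the interface's unbiasedness requirements. Note the return order of random_weighted, (weight, choices), which matches the tuple[Score, ChoiceMap] annotation in the source; weights are in log space.

```python
import math
import random
from typing import Any, Dict, Protocol, Tuple

Choices = Dict[str, float]  # hypothetical stand-in for genjax.ChoiceMap


class SampleDistributionLike(Protocol):
    """Hypothetical mirror of the stochastic probability interface (log-space weights)."""

    def random_weighted(self, rng: random.Random, *args: Any) -> Tuple[float, Choices]: ...

    def estimate_logpdf(self, rng: random.Random, v: Choices, *args: Any) -> float: ...


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


class ExactNormal:
    """Samples {"x": N(mu, sigma**2)} and reports its exact log density (zero-variance weights)."""

    def __init__(self, mu: float, sigma: float) -> None:
        self.mu = mu
        self.sigma = sigma

    def random_weighted(self, rng: random.Random, *args: Any) -> Tuple[float, Choices]:
        x = rng.gauss(self.mu, self.sigma)
        # Order matters: (log weight, choices), like tuple[Score, ChoiceMap].
        return normal_logpdf(x, self.mu, self.sigma), {"x": x}

    def estimate_logpdf(self, rng: random.Random, v: Choices, *args: Any) -> float:
        return normal_logpdf(v["x"], self.mu, self.sigma)


rng = random.Random(0)
dist: SampleDistributionLike = ExactNormal(1.5, math.sqrt(0.5))
log_w, choices = dist.random_weighted(rng)
print(log_w, choices, dist.estimate_logpdf(rng, choices))
```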

genjax.inference.Algorithm

Bases: Generic[R], SampleDistribution

Algorithm is the type of inference algorithms: probabilistic programs which provide interfaces for sampling from posterior approximations, and estimating densities.

The stochastic probability interface for Algorithm

Inference algorithms implement the stochastic probability interface:

  • Algorithm.random_weighted exposes sampling from the approximation that the algorithm represents: it accepts as input a Target, representing the unnormalized distribution, and returns a sample from an approximation to the normalized distribution, along with a density estimate at that sample.

  • Algorithm.estimate_logpdf exposes density estimation for the approximation that Algorithm.random_weighted samples from: it accepts as input a value on the support of the approximation together with the Target that induced the approximation, and returns an estimate of the density of the approximation at that value.

Optional methods for gradient estimators

Subclasses of type Algorithm can also implement two optional methods designed to support effective gradient estimators for variational objectives (estimate_normalizing_constant and estimate_reciprocal_normalizing_constant).

Methods:

  • random_weighted: Given a Target, return a ChoiceMap from an approximation to the normalized distribution of the target, and a random Weight estimate of the normalized density of the target at the sample.

  • estimate_logpdf: Given a ChoiceMap and a Target, return a random Weight estimate of the normalized density of the target at the sample.

Source code in src/genjax/_src/inference/sp.py
class Algorithm(Generic[R], SampleDistribution):
    """`Algorithm` is the type of inference
    algorithms: probabilistic programs which provide interfaces for sampling from
    posterior approximations, and estimating densities.

    **The stochastic probability interface for `Algorithm`**

    Inference algorithms implement the stochastic probability interface:

    * `Algorithm.random_weighted` exposes sampling from the approximation
    which the algorithm represents: it accepts a `Target` as input, representing the
    unnormalized distribution, and returns a sample from an approximation to
    the normalized distribution, along with a density estimate of the normalized distribution.

    * `Algorithm.estimate_logpdf` exposes density estimation for the
    approximation which `Algorithm.random_weighted` samples from:
    it accepts a value on the support of the approximation, and the `Target` which
    induced the approximation as input, and returns an estimate of the density of
    the approximation.

    **Optional methods for gradient estimators**

    Subclasses of type `Algorithm` can also implement two optional methods
    designed to support effective gradient estimators for variational objectives
    (`estimate_normalizing_constant` and `estimate_reciprocal_normalizing_constant`).
    """

    #########
    # GenSP #
    #########

    @abstractmethod
    def random_weighted(
        self,
        key: PRNGKey,
        *args: Any,
    ) -> tuple[Score, ChoiceMap]:
        """
        Given a [`Target`][genjax.inference.Target], return a [`ChoiceMap`][genjax.core.ChoiceMap] from an approximation to the normalized distribution of the target, and a random [`Weight`][genjax.core.Weight] estimate of the normalized density of the target at the sample.

        The `sample` is a sample on the support of `target.gen_fn` which _are not in_ `target.constraints`, produced by running the inference algorithm.

        Let $T_P(a, c)$ denote the target, with $P$ the distribution on samples represented by `target.gen_fn`, and $S$ denote the sample. Let $w$ denote the weight `w`. The weight $w$ is a random weight such that $w$ satisfies:

        $$
        \\mathbb{E}\\big[\\frac{1}{w} \\mid S \\big] = \\frac{1}{P(S \\mid c; a)}
        $$

        This interface corresponds to **(Defn 3.2) Unbiased Density Sampler** in [[Lew23](https://dl.acm.org/doi/pdf/10.1145/3591290)].
        """
        assert isinstance(args[0], Target)

    @abstractmethod
    def estimate_logpdf(
        self, key: PRNGKey, v: ChoiceMap, *args: tuple[Any, ...]
    ) -> Score:
        """
        Given a [`ChoiceMap`][genjax.core.ChoiceMap] and a [`Target`][genjax.inference.Target], return a random [`Weight`][genjax.core.Weight] estimate of the normalized density of the target at the sample.

        Let $T_P(a, c)$ denote the target, with $P$ the distribution on samples represented by `target.gen_fn`, and $S$ denote the sample. Let $w$ denote the weight `w`. The weight $w$ is a random weight such that $w$ satisfies:

        $$
        \\mathbb{E}[w] = P(S \\mid c, a)
        $$

        This interface corresponds to **(Defn 3.1) Positive Unbiased Density Estimator** in [[Lew23](https://dl.acm.org/doi/pdf/10.1145/3591290)].
        """

    ################
    # VI via GRASP #
    ################

    @abstractmethod
    def estimate_normalizing_constant(
        self,
        key: PRNGKey,
        target: Target[R],
    ) -> Weight:
        pass

    @abstractmethod
    def estimate_reciprocal_normalizing_constant(
        self,
        key: PRNGKey,
        target: Target[R],
        latent_choices: ChoiceMap,
        w: Weight,
    ) -> Weight:
        pass

random_weighted (abstract method)

random_weighted(
    key: PRNGKey, *args: Any
) -> tuple[Score, ChoiceMap]

Given a Target, return a ChoiceMap from an approximation to the normalized distribution of the target, and a random Weight estimate of the normalized density of the target at the sample.

The sample is a choice map over the addresses of target.p that are not in target.constraint, produced by running the inference algorithm. The weight is returned first, as a log weight.

Let \(T_P(a, c)\) denote the target, with \(P\) the distribution on samples represented by target.p, and \(S\) denote the sample. Let \(w\) denote the exponential of the returned log weight. Then \(w\) is a random weight satisfying:

\[ \mathbb{E}\big[\frac{1}{w} \mid S \big] = \frac{1}{P(S \mid c; a)} \]

This interface corresponds to (Defn 3.2) Unbiased Density Sampler in [Lew23].
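A stdlib-only sketch of one such algorithm, sampling-importance-resampling (SIR) with the prior as proposal, shows where the weight comes from. The names (`sir_random_weighted`, `log_joint`) are hypothetical; only the final formula, the chosen particle's score minus the log marginal likelihood estimate, mirrors SMCAlgorithm.random_weighted in the source. With a single particle the approximation is simply the prior, and the returned log weight collapses to the prior's exact log density at the sample.

```python
import math
import random

Y_OBS = 3.0  # toy model from the Target example: x ~ N(0, 1), y ~ N(x, 1), y = 3


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def log_joint(x: float) -> float:
    """Trace score log p(x, y = Y_OBS)."""
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(Y_OBS, x, 1.0)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def sir_random_weighted(rng: random.Random, k: int):
    """Hypothetical k-particle SIR with the prior as proposal; returns (log weight, choices)."""
    xs = [rng.gauss(0.0, 1.0) for _ in range(k)]
    # Importance log weights: log p(x, y) - log q(x), with q the prior.
    log_ws = [log_joint(x) - normal_logpdf(x, 0.0, 1.0) for x in xs]
    log_z_hat = logsumexp(log_ws) - math.log(k)  # log marginal likelihood estimate
    m = max(log_ws)
    i = rng.choices(range(k), weights=[math.exp(lw - m) for lw in log_ws])[0]
    # As in SMCAlgorithm.random_weighted: particle score minus log Z estimate.
    return log_joint(xs[i]) - log_z_hat, {"x": xs[i]}


rng = random.Random(1)
# One particle: the weight equals the prior log density at the sample.
w1, c1 = sir_random_weighted(rng, 1)
print(w1, normal_logpdf(c1["x"], 0.0, 1.0))
# Many particles: samples concentrate near the posterior mean 1.5.
xs = [sir_random_weighted(rng, 100)[1]["x"] for _ in range(1000)]
print(sum(xs) / len(xs))
```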

estimate_logpdf (abstract method)

estimate_logpdf(
    key: PRNGKey, v: ChoiceMap, *args: tuple[Any, ...]
) -> Score

Given a ChoiceMap and a Target, return a random Weight estimate of the normalized density of the target at the sample.

Let \(T_P(a, c)\) denote the target, with \(P\) the distribution on samples represented by target.p, and \(S\) denote the sample. Let \(w\) denote the exponential of the returned log weight. Then \(w\) is a random weight satisfying:

\[ \mathbb{E}[w] = P(S \mid c; a) \]

This interface corresponds to (Defn 3.1) Positive Unbiased Density Estimator in [Lew23].
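For SIR the density of the approximation has no closed form, so this stdlib-only sketch (hypothetical names throughout) checks a property that does not require it. It follows the conditional idea behind run_csmc: the value v being scored is retained as one of the k particles, and the estimate is v's score minus the log marginal likelihood estimate. If the exponential of the estimate is unbiased for a normalized density, then averaging exp(estimate) / r(v) over v drawn from any reference distribution r (here N(1, 1.5**2), an arbitrary choice) should give approximately 1.

```python
import math
import random

Y_OBS = 3.0  # toy model from the Target example: x ~ N(0, 1), y ~ N(x, 1), y = 3


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def log_joint(x: float) -> float:
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(Y_OBS, x, 1.0)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def sir_estimate_logpdf(rng: random.Random, v: dict, k: int) -> float:
    """Hypothetical conditional SIR: v is retained alongside k - 1 fresh prior particles."""
    xs = [v["x"]] + [rng.gauss(0.0, 1.0) for _ in range(k - 1)]
    log_ws = [log_joint(x) - normal_logpdf(x, 0.0, 1.0) for x in xs]
    log_z_hat = logsumexp(log_ws) - math.log(k)
    return log_joint(v["x"]) - log_z_hat  # retained particle's score minus log Z estimate


def integral_estimate(rng: random.Random, k: int, n: int = 20000) -> float:
    """Importance-sampling estimate of the integral over v of E[exp(estimate at v)]."""
    ref_mu, ref_sigma = 1.0, 1.5
    total = 0.0
    for _ in range(n):
        x = rng.gauss(ref_mu, ref_sigma)
        log_est = sir_estimate_logpdf(rng, {"x": x}, k)
        total += math.exp(log_est - normal_logpdf(x, ref_mu, ref_sigma))
    return total / n


rng = random.Random(2)
print(integral_estimate(rng, k=5))
```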

By virtue of the stochastic probability interface, GenJAX also exposes marginalization as a first-class concept.

genjax.inference.Marginal

Bases: Generic[R], SampleDistribution

The Marginal class represents the marginal distribution of a generative function over a selection of addresses.

Methods: random_weighted, estimate_logpdf
Source code in src/genjax/_src/inference/sp.py
@Pytree.dataclass
class Marginal(Generic[R], SampleDistribution):
    """The `Marginal` class represents the marginal distribution of a generative function over
    a selection of addresses.
    """

    gen_fn: GenerativeFunction[R]
    selection: Selection = Pytree.field(default=Selection.all())
    algorithm: Algorithm[R] | None = Pytree.field(default=None)

    def random_weighted(
        self,
        key: PRNGKey,
        *args: Any,
    ) -> tuple[Score, ChoiceMap]:
        key, sub_key = jax.random.split(key)
        tr = self.gen_fn.simulate(sub_key, args)
        choices: ChoiceMap = tr.get_choices()
        latent_choices = choices.filter(self.selection)
        key, sub_key = jax.random.split(key)
        bwd_request = ~self.selection
        weight = tr.project(sub_key, bwd_request)
        if self.algorithm is None:
            return weight, latent_choices
        else:
            target = Target(self.gen_fn, args, latent_choices)
            other_choices = choices.filter(~self.selection)
            Z = self.algorithm.estimate_reciprocal_normalizing_constant(
                key, target, other_choices, weight
            )

            return (Z, latent_choices)

    def estimate_logpdf(
        self,
        key: PRNGKey,
        v: ChoiceMap,
        *args: tuple[Any, ...],
    ) -> Score:
        if self.algorithm is None:
            _, weight = self.gen_fn.importance(key, v, args)
            return weight
        else:
            target = Target(self.gen_fn, args, v)
            Z = self.algorithm.estimate_normalizing_constant(key, target)
            return Z
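For a Marginal without an algorithm, the output is the selected part of a full simulation, and the weight accounts for the discarded choices. The stdlib-only sketch below uses a hypothetical two-variable model, x ~ N(0, 1), y ~ N(x, 2**2), with selection {"y"}. It does not reproduce Trace.project; it computes the weight that the random_weighted contract calls for in this model, log p(y | x). Given the returned y, the discarded x is distributed as p(x | y), so averaging exp(-w) over that conditional should recover 1 / p(y), where p(y) = N(y; 0, 5). The standard deviation 2 for y (rather than 1 as in the Target example) is chosen so that this reciprocal estimate has finite variance.

```python
import math
import random

SIGMA_Y = 2.0  # hypothetical model: x ~ N(0, 1), y ~ N(x, SIGMA_Y**2)


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def marginal_y_random_weighted(rng: random.Random):
    """Sketch of a marginal over {"y"}: simulate everything, keep y, weight by log p(y | x)."""
    x = rng.gauss(0.0, 1.0)
    y = rng.gauss(x, SIGMA_Y)
    return normal_logpdf(y, x, SIGMA_Y), {"y": y}


def mean_reciprocal_weight(rng: random.Random, y: float, n: int = 20000) -> float:
    """Average exp(-w) with x drawn from the exact conjugate conditional p(x | y)."""
    post_var = 1.0 / (1.0 + 1.0 / SIGMA_Y**2)
    post_mean = post_var * y / SIGMA_Y**2
    total = 0.0
    for _ in range(n):
        x = rng.gauss(post_mean, math.sqrt(post_var))
        total += math.exp(-normal_logpdf(y, x, SIGMA_Y))
    return total / n


rng = random.Random(3)
w, chm = marginal_y_random_weighted(rng)
y = 1.0
estimate = mean_reciprocal_weight(rng, y)
exact = math.exp(-normal_logpdf(y, 0.0, math.sqrt(1.0 + SIGMA_Y**2)))  # 1 / p(y)
print(estimate, exact)
```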

random_weighted

random_weighted(
    key: PRNGKey, *args: Any
) -> tuple[Score, ChoiceMap]

estimate_logpdf

estimate_logpdf(
    key: PRNGKey, v: ChoiceMap, *args: tuple[Any, ...]
) -> Score

The SMC inference library

Sequential Monte Carlo (SMC) is a popular family of algorithms for approximate inference in probabilistic models.

genjax.inference.smc.SMCAlgorithm

Bases: Generic[R], Algorithm[R]

Abstract class for SMC algorithms.

Source code in src/genjax/_src/inference/smc.py
class SMCAlgorithm(Generic[R], Algorithm[R]):
    """Abstract class for SMC algorithms."""

    @abstractmethod
    def get_num_particles(self) -> int:
        pass

    @abstractmethod
    def get_final_target(self) -> Target[R]:
        pass

    @abstractmethod
    def run_smc(
        self,
        key: PRNGKey,
    ) -> ParticleCollection[R]:
        pass

    @abstractmethod
    def run_csmc(
        self,
        key: PRNGKey,
        retained: ChoiceMap,
    ) -> ParticleCollection[R]:
        pass

    # Convenience method for returning an estimate of the normalizing constant
    # of the target.
    def log_marginal_likelihood_estimate(
        self,
        key: PRNGKey,
        target: Target[R] | None = None,
    ):
        if target:
            algorithm = ChangeTarget(self, target)
        else:
            algorithm = self
        key, sub_key = jrandom.split(key)
        particle_collection = algorithm.run_smc(sub_key)
        return particle_collection.get_log_marginal_likelihood_estimate()

    #########
    # GenSP #
    #########

    def random_weighted(
        self,
        key: PRNGKey,
        *args: Any,
    ) -> tuple[Score, ChoiceMap]:
        assert isinstance(args[0], Target)

        target: Target[R] = args[0]
        algorithm = ChangeTarget(self, target)
        key, sub_key = jrandom.split(key)
        particle_collection = algorithm.run_smc(key)
        particle = particle_collection.sample_particle(sub_key)
        log_density_estimate = (
            particle.get_score()
            - particle_collection.get_log_marginal_likelihood_estimate()
        )
        chm = target.filter_to_unconstrained(particle.get_choices())
        return log_density_estimate, chm

    def estimate_logpdf(
        self,
        key: PRNGKey,
        v: ChoiceMap,
        *args: tuple[Any, ...],
    ) -> Score:
        assert isinstance(args[0], Target)

        target: Target[R] = args[0]
        algorithm = ChangeTarget(self, target)
        key, sub_key = jrandom.split(key)
        particle_collection = algorithm.run_csmc(key, v)
        particle = particle_collection.sample_particle(sub_key)
        log_density_estimate = (
            particle.get_score()
            - particle_collection.get_log_marginal_likelihood_estimate()
        )
        return log_density_estimate

    ################
    # VI via GRASP #
    ################

    def estimate_normalizing_constant(
        self,
        key: PRNGKey,
        target: Target[R],
    ) -> FloatArray:
        algorithm = ChangeTarget(self, target)
        key, sub_key = jrandom.split(key)
        particle_collection = algorithm.run_smc(sub_key)
        return particle_collection.get_log_marginal_likelihood_estimate()

    def estimate_reciprocal_normalizing_constant(
        self,
        key: PRNGKey,
        target: Target[R],
        latent_choices: ChoiceMap,
        w: FloatArray,
    ) -> FloatArray:
        algorithm = ChangeTarget(self, target)
        # Special, for ChangeTarget -- to avoid a redundant reweighting step,
        # when we have `w` which (with `latent_choices`) is already properly weighted
        # for the `target`.
        return algorithm.run_csmc_for_normalizing_constant(key, latent_choices, w)
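The two GRASP methods pair an estimate of Z with an estimate of 1 / Z. A stdlib-only sketch with plain importance sampling (prior proposal, k particles, hypothetical helper names) shows both on a self-contained conjugate model, x ~ N(0, 1), y ~ N(x, 2**2), y = 1. Averaged in linear space, the fresh-particle estimate is unbiased for Z = p(y). For the reciprocal, one particle is a retained posterior sample (drawn exactly here, standing in for latent_choices), and averaging 1 / Z_hat estimates 1 / Z. Unlike the GenJAX method, this sketch recomputes the retained particle's weight rather than accepting w as an argument.

```python
import math
import random

SIGMA_Y, Y_OBS = 2.0, 1.0  # hypothetical model: x ~ N(0, 1), y ~ N(x, SIGMA_Y**2), y = Y_OBS


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def log_z_hat(rng: random.Random, k: int, retained_x=None) -> float:
    """Log of the k-particle importance estimate of p(y); optionally keep one given particle."""
    n_fresh = k if retained_x is None else k - 1
    xs = [rng.gauss(0.0, 1.0) for _ in range(n_fresh)]
    if retained_x is not None:
        xs.append(retained_x)
    # With the prior as proposal, log p(x, y) - log p(x) = log p(y | x).
    log_ws = [normal_logpdf(Y_OBS, x, SIGMA_Y) for x in xs]
    return logsumexp(log_ws) - math.log(k)


rng = random.Random(4)
n, k = 20000, 5
post_var = 1.0 / (1.0 + 1.0 / SIGMA_Y**2)
post_mean = post_var * Y_OBS / SIGMA_Y**2
z_exact = math.exp(normal_logpdf(Y_OBS, 0.0, math.sqrt(1.0 + SIGMA_Y**2)))
z_mean = sum(math.exp(log_z_hat(rng, k)) for _ in range(n)) / n
recip_mean = sum(
    math.exp(-log_z_hat(rng, k, rng.gauss(post_mean, math.sqrt(post_var))))
    for _ in range(n)
) / n
print(z_mean, z_exact, recip_mean, 1.0 / z_exact)
```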

genjax.inference.smc.Importance

Bases: Generic[R], SMCAlgorithm[R]

Accepts as input a target: Target and, optionally, a proposal q: SampleDistribution. q should accept a Target as input and return a choice map on a subset of the addresses of target.p that are not in target.constraint.

This initializes a 1-particle ParticleCollection by importance sampling from target using q.

Any choices of target.p not sampled by q are sampled from the internal proposal distribution of target.p, given target.constraint and the choices sampled by q.

Source code in src/genjax/_src/inference/smc.py
@Pytree.dataclass
class Importance(Generic[R], SMCAlgorithm[R]):
    """Accepts as input a `target: Target` and, optionally, a proposal `q: SampleDistribution`.
    `q` should accept a `Target` as input and return a choicemap on a subset
    of the addresses in `target.gen_fn` not in `target.constraints`.

    This initializes a 1-particle `ParticleCollection` by importance sampling from `target` using `q`.

    Any choices in `target.p` not in `q` will be sampled from the internal proposal distribution of `p`,
    given `target.constraints` and the choices sampled by `q`.
    """

    target: Target[R]
    q: SampleDistribution | None = Pytree.field(default=None)

    def get_num_particles(self):
        return 1

    def get_final_target(self):
        return self.target

    def run_smc(self, key: PRNGKey):
        key, sub_key = jrandom.split(key)
        if self.q is not None:
            log_weight, choice = self.q.random_weighted(sub_key, self.target)
            tr, target_score = self.target.importance(key, choice)
        else:
            log_weight = 0.0
            tr, target_score = self.target.importance(key, ChoiceMap.empty())
        return ParticleCollection(
            jtu.tree_map(lambda v: jnp.expand_dims(v, axis=0), tr),
            jnp.array([target_score - log_weight]),
            jnp.array(True),
        )

    def run_csmc(self, key: PRNGKey, retained: ChoiceMap):
        key, sub_key = jrandom.split(key)
        if self.q:
            q_score = self.q.estimate_logpdf(sub_key, retained, self.target)
        else:
            q_score = 0.0
        target_trace, target_score = self.target.importance(key, retained)
        return ParticleCollection(
            jtu.tree_map(lambda v: jnp.expand_dims(v, axis=0), target_trace),
            jnp.array([target_score - q_score]),
            jnp.array(True),
        )
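The proposal q controls the variance of these weights. In this stdlib-only sketch (hypothetical helper names; toy model from the Target example), the one-particle log weight is the target's log joint density minus the proposal's log density. Proposing from the prior gives widely spread weights, while proposing from the exact posterior N(1.5, 0.5) makes every weight equal log p(y = 3), because p(x, y) / p(x | y) = p(y) for every x.

```python
import math
import random

Y_OBS = 3.0  # toy model from the Target example: x ~ N(0, 1), y ~ N(x, 1), y = 3


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def log_joint(x: float) -> float:
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(Y_OBS, x, 1.0)


def importance_log_weight(rng: random.Random, q_mu: float, q_sigma: float) -> float:
    """One-particle log weight: target log joint minus proposal log density."""
    x = rng.gauss(q_mu, q_sigma)
    return log_joint(x) - normal_logpdf(x, q_mu, q_sigma)


log_z = normal_logpdf(Y_OBS, 0.0, math.sqrt(2.0))  # exact log p(y = 3)
rng = random.Random(5)
prior_ws = [importance_log_weight(rng, 0.0, 1.0) for _ in range(1000)]
posterior_ws = [importance_log_weight(rng, 1.5, math.sqrt(0.5)) for _ in range(1000)]
print(min(prior_ws), max(prior_ws))
print(min(posterior_ws), max(posterior_ws), log_z)
```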

genjax.inference.smc.ImportanceK

Bases: Generic[R], SMCAlgorithm[R]

Given a target: Target, an optional proposal q: SampleDistribution, and the number of particles k_particles: int (default 2), initialize a particle collection of k_particles particles using importance sampling.

Source code in src/genjax/_src/inference/smc.py
@Pytree.dataclass
class ImportanceK(Generic[R], SMCAlgorithm[R]):
    """Given a `target: Target` and a proposal `q: SampleDistribution`, as well as the
    number of particles `k_particles: int`, initialize a particle collection using
    importance sampling."""

    target: Target[R]
    q: SampleDistribution | None = Pytree.field(default=None)
    k_particles: int = Pytree.static(default=2)

    def get_num_particles(self):
        return self.k_particles

    def get_final_target(self):
        return self.target

    def run_smc(self, key: PRNGKey):
        key, sub_key = jrandom.split(key)
        sub_keys = jrandom.split(sub_key, self.get_num_particles())
        if self.q is not None:
            log_weights, choices = vmap(self.q.random_weighted, in_axes=(0, None))(
                sub_keys, self.target
            )
            trs, target_scores = vmap(self.target.importance)(sub_keys, choices)
        else:
            log_weights = 0.0
            trs, target_scores = vmap(self.target.importance, in_axes=(0, None))(
                sub_keys, ChoiceMap.empty()
            )
        return ParticleCollection(
            trs,
            target_scores - log_weights,
            jnp.array(True),
        )

    def run_csmc(self, key: PRNGKey, retained: ChoiceMap):
        key, sub_key = jrandom.split(key)
        sub_keys = jrandom.split(sub_key, self.get_num_particles() - 1)
        if self.q:
            log_scores, choices = vmap(self.q.random_weighted, in_axes=(0, None))(
                sub_keys, self.target
            )
            retained_choice_score = self.q.estimate_logpdf(key, retained, self.target)
            stacked_choices = jtu.tree_map(stack_to_first_dim, choices, retained)
            stacked_scores = jtu.tree_map(
                stack_to_first_dim, log_scores, retained_choice_score
            )
            sub_keys = jrandom.split(key, self.get_num_particles())
            target_traces, target_scores = vmap(self.target.importance)(
                sub_keys, stacked_choices
            )
        else:
            ignored_traces, ignored_scores = vmap(
                self.target.importance, in_axes=(0, None)
            )(sub_keys, ChoiceMap.empty())
            retained_trace, retained_choice_score = self.target.importance(
                key, retained
            )
            target_scores = jtu.tree_map(
                stack_to_first_dim, ignored_scores, retained_choice_score
            )
            stacked_scores = 0.0
            target_traces = jtu.tree_map(
                stack_to_first_dim, ignored_traces, retained_trace
            )
        return ParticleCollection(
            target_traces,
            target_scores - stacked_scores,
            jnp.array(True),
        )
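With k particles, the log marginal likelihood estimate is the log of the average importance weight. Computing it from log weights calls for a log-sum-exp shift, because in larger models log weights can sit hundreds or thousands of units below zero, where exponentiating them naively underflows to 0.0. The stdlib-only sketch below (hypothetical helpers; Target example model; prior proposal, as when q is None) computes the estimate stably and compares it with the exact value.

```python
import math
import random

Y_OBS = 3.0  # toy model from the Target example: x ~ N(0, 1), y ~ N(x, 1), y = 3


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def log_mean_exp(log_ws):
    """Numerically stable log(mean(exp(log_ws))): shift by the max before exponentiating."""
    m = max(log_ws)
    return m + math.log(sum(math.exp(lw - m) for lw in log_ws)) - math.log(len(log_ws))


def importance_k_log_z(rng: random.Random, k: int) -> float:
    """k-particle importance estimate of log p(y), proposing x from the prior."""
    xs = [rng.gauss(0.0, 1.0) for _ in range(k)]
    # Target score minus prior log density reduces to log p(y | x).
    return log_mean_exp([normal_logpdf(Y_OBS, x, 1.0) for x in xs])


# Shifting keeps tiny weights usable even though exp(-1000.0) itself underflows to 0.0.
tiny = log_mean_exp([-1000.0, -1001.0])
rng = random.Random(6)
estimate = importance_k_log_z(rng, 50000)
print(tiny, estimate, normal_logpdf(Y_OBS, 0.0, math.sqrt(2.0)))
```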

The VI inference library

Variational inference is an approach to inference that solves optimization problems over spaces of distributions. For a posterior inference problem, the goal is to find the distribution within some parametrized family (often called the guide family) that is closest to the posterior under some divergence, commonly the Kullback-Leibler divergence.

Variational inference problems typically involve objective functions defined as expectations, and these expectations and their exact gradients are often intractable to compute. Therefore, unbiased estimators of the gradients are used in place of the true gradients.

The genjax.inference.vi module provides automation for constructing variational losses and deriving their gradient estimators. The architecture is shown below.

Fig. 1: How variational inference works in GenJAX (GenJAX VI architecture diagram).

genjax.inference.vi.adev_distribution

adev_distribution(
    adev_primitive: ADEVPrimitive,
    differentiable_logpdf: Callable[..., Any],
    name: str,
) -> ExactDensity[Any]

Return an ExactDensity distribution whose sampler invokes an ADEV sampling primitive, with a provided differentiable log density function.

Exact densities created using this function can be used as distributions in variational guide programs.

Source code in src/genjax/_src/inference/vi.py
def adev_distribution(
    adev_primitive: ADEVPrimitive, differentiable_logpdf: Callable[..., Any], name: str
) -> ExactDensity[Any]:
    """
    Return an [`ExactDensity`][genjax.ExactDensity] distribution whose sampler invokes an ADEV sampling primitive, with a provided differentiable log density function.

    Exact densities created using this function can be used as distributions in variational guide programs.
    """

    def sampler(key: PRNGKey, *args: Any) -> Any:
        return sample_primitive(adev_primitive, *args, key=key)

    def logpdf(v: Any, *args: Any) -> FloatArray:
        lp = differentiable_logpdf(v, *args)
        # Branching here is statically resolved.
        if lp.shape:
            return jnp.sum(lp)
        else:
            return lp

    return exact_density(sampler, logpdf, name)
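Roughly, an ADEV sampling primitive couples a sampler with a strategy for differentiating expectations taken under it. This stdlib-only sketch illustrates, by hand, two standard strategies for a normal sampler: the reparameterization trick and the score-function (REINFORCE) estimator. It does not use ADEV, and which strategy a given GenJAX primitive implements is not shown here. Both functions estimate d/dmu E[x**2] for x ~ N(mu, sigma**2), whose exact value is 2 * mu.

```python
import random


def f(x: float) -> float:
    """Integrand: E[x**2] = mu**2 + sigma**2 under N(mu, sigma**2)."""
    return x * x


def reparam_grad(rng: random.Random, mu: float, sigma: float, n: int) -> float:
    """Reparameterization: x = mu + sigma * eps, so d/dmu f(x) = f'(x) = 2 * x."""
    total = 0.0
    for _ in range(n):
        x = mu + sigma * rng.gauss(0.0, 1.0)
        total += 2.0 * x
    return total / n


def reinforce_grad(rng: random.Random, mu: float, sigma: float, n: int) -> float:
    """Score function: f(x) * d/dmu log N(x; mu, sigma**2) = f(x) * (x - mu) / sigma**2."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(mu, sigma)
        total += f(x) * (x - mu) / sigma**2
    return total / n


rng = random.Random(7)
mu, sigma, n = 1.0, 1.0, 100000
g_rep = reparam_grad(rng, mu, sigma, n)
g_rf = reinforce_grad(rng, mu, sigma, n)
print(g_rep, g_rf, 2.0 * mu)
```

Both estimators are unbiased here; in this example the score-function estimator has noticeably higher variance, which is one reason a primitive's gradient strategy matters.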

genjax.inference.vi.ELBO

ELBO(
    guide: SampleDistribution,
    make_target: Callable[..., Target[Any]],
) -> Callable[[PRNGKey, Arguments], GradientEstimate]

Return a function that computes the gradient estimate of the ELBO loss term.

Source code in src/genjax/_src/inference/vi.py
def ELBO(
    guide: SampleDistribution,
    make_target: Callable[..., Target[Any]],
) -> Callable[[PRNGKey, Arguments], GradientEstimate]:
    """
    Return a function that computes the gradient estimate of the ELBO loss term.
    """

    def grad_estimate(
        key: PRNGKey,
        args: tuple[Any, ...],
    ) -> tuple[Any, ...]:
        # In the source language of ADEV.
        @expectation
        def _loss(*args):
            target = make_target(*args)
            guide_alg = Importance(target, guide)
            w = guide_alg.estimate_normalizing_constant(key, target)
            return -w

        return _loss.grad_estimate(key, args)

    return grad_estimate
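Mathematically, the quantity differentiated here is an expectation. With Importance(target, guide), the estimate is a single-particle log weight, log p(x, y) - log q(x) with x ~ q, whose expectation is the ELBO; the returned loss is its negation. This stdlib-only sketch (hypothetical helpers; Target example model; a Gaussian guide N(m, s**2)) estimates the ELBO by Monte Carlo, compares it with its closed form, and shows the bound: the ELBO lies below log p(y) unless the guide equals the posterior N(1.5, 0.5).

```python
import math
import random

Y_OBS = 3.0  # toy model from the Target example: x ~ N(0, 1), y ~ N(x, 1), y = 3


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def log_joint(x: float) -> float:
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(Y_OBS, x, 1.0)


def elbo_estimate(rng: random.Random, m: float, s: float, n: int) -> float:
    """Monte Carlo average of log p(x, y) - log q(x) with x ~ q = N(m, s**2)."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(m, s)
        total += log_joint(x) - normal_logpdf(x, m, s)
    return total / n


def elbo_closed_form(m: float, s: float) -> float:
    """E_q[log p(x)] + E_q[log p(y | x)] + entropy of q, all Gaussian integrals."""
    half_log_2pi = 0.5 * math.log(2.0 * math.pi)
    e_log_prior = -half_log_2pi - 0.5 * (m * m + s * s)
    e_log_lik = -half_log_2pi - 0.5 * ((Y_OBS - m) ** 2 + s * s)
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * s * s)
    return e_log_prior + e_log_lik + entropy


log_z = normal_logpdf(Y_OBS, 0.0, math.sqrt(2.0))
rng = random.Random(8)
mc = elbo_estimate(rng, 0.0, 1.0, 20000)
print(mc, elbo_closed_form(0.0, 1.0), elbo_closed_form(1.5, math.sqrt(0.5)), log_z)
```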

genjax.inference.vi.IWELBO

IWELBO(
    proposal: SampleDistribution,
    make_target: Callable[[Any], Target[Any]],
    N: int,
) -> Callable[[PRNGKey, Arguments], GradientEstimate]

Return a function that computes the gradient estimate of the IWELBO loss term.

Source code in src/genjax/_src/inference/vi.py
def IWELBO(
    proposal: SampleDistribution,
    make_target: Callable[[Any], Target[Any]],
    N: int,
) -> Callable[[PRNGKey, Arguments], GradientEstimate]:
    """
    Return a function that computes the gradient estimate of the IWELBO loss term.
    """

    def grad_estimate(
        key: PRNGKey,
        args: Arguments,
    ) -> GradientEstimate:
        # In the source language of ADEV.
        @expectation
        def _loss(*args):
            target = make_target(*args)
            guide = ImportanceK(target, proposal, N)
            w = guide.estimate_normalizing_constant(key, target)
            return -w

        return _loss.grad_estimate(key, args)

    return grad_estimate
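IWELBO replaces the single guide sample with ImportanceK(target, proposal, N), so the objective becomes the expected log of an N-particle importance estimate of p(y). By Jensen's inequality this expectation is at most log p(y), and it tightens as N grows. This stdlib-only sketch (hypothetical helpers; Target example model; prior proposal) estimates the bound for several particle counts.

```python
import math
import random

Y_OBS = 3.0  # toy model from the Target example: x ~ N(0, 1), y ~ N(x, 1), y = 3


def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma**2) at v; sigma is a standard deviation."""
    z = (v - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def log_mean_exp(log_ws):
    m = max(log_ws)
    return m + math.log(sum(math.exp(lw - m) for lw in log_ws)) - math.log(len(log_ws))


def iwelbo_estimate(rng: random.Random, k: int, reps: int) -> float:
    """Average over reps of log Z_hat_k, each computed from k prior-proposal particles."""
    total = 0.0
    for _ in range(reps):
        log_ws = [normal_logpdf(Y_OBS, rng.gauss(0.0, 1.0), 1.0) for _ in range(k)]
        total += log_mean_exp(log_ws)
    return total / reps


rng = random.Random(9)
log_z = normal_logpdf(Y_OBS, 0.0, math.sqrt(2.0))
bounds = {k: iwelbo_estimate(rng, k, 1000) for k in (1, 10, 100)}
print(bounds, log_z)
```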

genjax.inference.vi.PWake

PWake(
    posterior_approx: SampleDistribution,
    make_target: Callable[[Any], Target[Any]],
) -> Callable[[PRNGKey, Arguments], GradientEstimate]

Return a function that computes the gradient estimate of the PWake loss term.

Source code in src/genjax/_src/inference/vi.py
def PWake(
    posterior_approx: SampleDistribution,
    make_target: Callable[[Any], Target[Any]],
) -> Callable[[PRNGKey, Arguments], GradientEstimate]:
    """
    Return a function that computes the gradient estimate of the PWake loss term.
    """

    def grad_estimate(
        key: PRNGKey,
        args: tuple[Any, ...],
    ) -> tuple[Any, ...]:
        key, sub_key1, sub_key2 = jax.random.split(key, 3)

        # In the source language of ADEV.
        @expectation
        def _loss(*target_args):
            target = make_target(*target_args)
            _, sample = posterior_approx.random_weighted(sub_key1, target)
            tr, _ = target.importance(sub_key2, sample)
            return -tr.get_score()

        return _loss.grad_estimate(key, args)

    return grad_estimate
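The PWake loss evaluates the model's log joint density at a sample from posterior_approx and negates it. When posterior_approx is exact, the expected gradient of this loss with respect to a model parameter equals the gradient of -log p(y; theta) (Fisher's identity). The stdlib-only sketch below checks this on a hypothetical parameterized variant of the Target example, x ~ N(theta, 1), y ~ N(x, 1), y = 3, with the gradient written out by hand; nothing here uses ADEV.

```python
import math
import random

Y_OBS = 3.0  # hypothetical model: x ~ N(theta, 1), y ~ N(x, 1)


def pwake_grad_estimate(rng: random.Random, theta: float, n: int) -> float:
    """Average of d/dtheta [-log p(x, y; theta)] = -(x - theta) over x ~ p(x | y; theta)."""
    post_mean = 0.5 * (theta + Y_OBS)  # exact conjugate posterior N((theta + y) / 2, 1 / 2)
    post_sd = math.sqrt(0.5)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(post_mean, post_sd)
        total += -(x - theta)  # only the prior term log N(x; theta, 1) depends on theta
    return total / n


theta = 0.0
rng = random.Random(10)
estimate = pwake_grad_estimate(rng, theta, 20000)
exact = -(Y_OBS - theta) / 2.0  # d/dtheta of -log N(y; theta, 2)
print(estimate, exact)
```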

genjax.inference.vi.QWake

QWake(
    proposal: SampleDistribution,
    posterior_approx: SampleDistribution,
    make_target: Callable[[Any], Target[Any]],
) -> Callable[[PRNGKey, Arguments], GradientEstimate]

Return a function that computes the gradient estimate of the QWake loss term.

Source code in src/genjax/_src/inference/vi.py
def QWake(
    proposal: SampleDistribution,
    posterior_approx: SampleDistribution,
    make_target: Callable[[Any], Target[Any]],
) -> Callable[[PRNGKey, Arguments], GradientEstimate]:
    """
    Return a function that computes the gradient estimate of the QWake loss term.
    """

    def grad_estimate(
        key: PRNGKey,
        args: tuple[Any, ...],
    ) -> tuple[Any, ...]:
        key, sub_key1, sub_key2 = jax.random.split(key, 3)

        # In the source language of ADEV.
        @expectation
        def _loss(*target_args):
            target = make_target(*target_args)
            _, sample = posterior_approx.random_weighted(sub_key1, target)
            w = proposal.estimate_logpdf(sub_key2, sample, target)
            return -w

        return _loss.grad_estimate(key, args)

    return grad_estimate
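QWake instead scores samples from posterior_approx under the proposal, so minimizing it fits the proposal to the posterior by maximum likelihood on (approximate) posterior samples. In this stdlib-only sketch (hypothetical; Target example model; a Gaussian proposal N(m, 1) with a hand-written gradient in place of ADEV), minibatch stochastic gradient descent on the mean m, fed exact posterior samples, settles near the posterior mean 1.5.

```python
import math
import random

POST_MEAN, POST_SD = 1.5, math.sqrt(0.5)  # exact posterior of the Target example


def qwake_grad(rng: random.Random, m: float, batch: int) -> float:
    """Minibatch average of d/dm [-log N(x; m, 1)] = -(x - m), with x from the posterior."""
    total = 0.0
    for _ in range(batch):
        x = rng.gauss(POST_MEAN, POST_SD)
        total += -(x - m)
    return total / batch


rng = random.Random(11)
m, lr = -2.0, 0.1
for _ in range(300):
    m -= lr * qwake_grad(rng, m, batch=50)
print(m)
```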