# Stochastic probabilities math

Details on `random_weighted` and `estimate_logpdf`
```python
import sys

if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"
```
Let's start with `estimate_logpdf`.
The marginal density of the returned value $x$ (the sample from the normal distribution) is given by
$$p(x) = \sum_i p(x\mid z=i)\, p(z=i)$$
where the sum ranges over the possible values of the categorical distribution, $p(x\mid z=i)$ is the density of the $i$-th normal at $x$, and $p(z=i)$ is the probability that the categorical assigns to the value $i$.
This sum can be rewritten as the expectation under the categorical distribution $p(z)$:
$$\sum_i p(x\mid z=i)p(z=i) = \mathbb{E}_{z\sim p(z)}[p(x\mid z)]$$
This means we can get an unbiased estimate of this expectation by simply sampling a $z \sim p(z)$ and returning $p(x\mid z)$: the average value of this procedure is its expectation, by the very definition of expectation.
In other words, we have shown that the estimation strategy used in `estimate_logpdf` indeed returns an unbiased estimate of the exact marginal density.
Lastly, as we discussed above, we cannot in general invert an unbiased estimate to get an unbiased estimate of the reciprocal, so one may be suspicious that the weight returned by `random_weighted` looks like the negation (in logspace) of the one returned by `estimate_logpdf`.
Here the argument is different, and relies on the following identity:
$$\frac{1}{p(x)} = \mathbb{E}_{z\sim p(z\mid x)}\left[\frac{1}{p(x\mid z)}\right]$$
The idea is that we can get an unbiased estimate if we can sample from the posterior $p(z\mid x)$. For a given $x$, this is an intractable sampling problem in general. However, in `random_weighted`, we sample a $z$ together with the $x$, and relative to that $x$, this $z$ is an exact posterior sample that we get "for free".
To finish the explanation, here is a compact proof of the identity.
$$
\begin{aligned}
\frac{1}{p(x)}
&= \frac{1}{p(x)} \, \mathbb{E}_{z \sim B}[p(z)] && \text{$p(z)$ is a density w.r.t. the base measure $B$, of total mass $1$}\\
&= \frac{1}{p(x)} \, \mathbb{E}_{z \sim p(z\mid x)}\!\left[\frac{p(z)}{p(z\mid x)}\right] && \text{seeing $p(z\mid x)$ as an importance sampler for $B$}\\
&= \mathbb{E}_{z \sim p(z\mid x)}\!\left[\frac{p(z)}{p(z\mid x)\,p(x)}\right] && \text{$p(x)$ does not depend on $z$, so it moves inside the expectation}\\
&= \mathbb{E}_{z \sim p(z\mid x)}\!\left[\frac{p(z)}{p(z,x)}\right] && \text{definition of the joint density}\\
&= \mathbb{E}_{z \sim p(z\mid x)}\!\left[\frac{p(z)}{p(z)\,p(x\mid z)}\right] && \text{definition of the conditional density}\\
&= \mathbb{E}_{z \sim p(z\mid x)}\!\left[\frac{1}{p(x\mid z)}\right] && \text{simplification}
\end{aligned}
$$