Journey to the center of genjax.core
This page describes the set of core concepts and datatypes in GenJAX, including Gen's generative datatypes and concepts (GenerativeFunction, Trace, ChoiceMap, and EditRequest), the core JAX compatibility datatypes (Pytree, Const, and Closure), as well as functionally inspired Pytree extensions (Mask), and GenJAX's approach to "static" (JAX tracing time) typechecking.
genjax.core.GenerativeFunction
Bases: Generic[R], Pytree
GenerativeFunction is the type of generative functions, the main computational object in Gen.
Generative functions are a type of probabilistic program. In terms of their mathematical specification, they come equipped with a few ingredients:
- (Distribution over samples) \(P(\cdot_t, \cdot_r; a)\) - a probability distribution over samples \(t\) and untraced randomness \(r\), indexed by arguments \(a\). This ingredient is involved in all the interfaces and specifies the distribution over samples which the generative function represents.
- (Family of K/L proposals) \((K(\cdot_t, \cdot_{K_r}; u, t), L(\cdot_t, \cdot_{L_r}; u, t)) = \mathcal{F}(u, t)\) - a family of pairs of probabilistic programs (referred to as K and L), indexed by EditRequest \(u\) and an existing sample \(t\). This ingredient supports the edit and importance interfaces, and is used to specify an SMCP3 move which the generative function must provide in response to an edit request. K and L must satisfy additional properties, described further in edit.
- (Return value function) \(f(t, r, a)\) - a deterministic return value function, which maps samples, untraced randomness, and arguments to return values.
Generative functions also support a family of Target distributions. A Target distribution is a (possibly unnormalized) distribution, typically induced by inference problems.
- \(\delta_\emptyset\) - the empty target, whose only possible value is the empty sample, with density 1.
- (Family of targets induced by \(P\)) \(T_P(a, c)\) - a family of targets indexed by arguments \(a\) and constraints \(c\) (ChoiceMap), created by pairing the distribution over samples \(P\) with arguments and constraints.
Generative functions expose computations using these ingredients through the generative function interface (the methods which are documented below).
Examples:
The interface methods can be used to implement inference algorithms directly. Here's a simple example using bootstrap importance sampling:
import jax
from jax.scipy.special import logsumexp
import jax.tree_util as jtu
from genjax import ChoiceMapBuilder as C
from genjax import gen, uniform, flip, categorical
@gen
def model():
    p = uniform(0.0, 1.0) @ "p"
    f1 = flip(p) @ "f1"
    f2 = flip(p) @ "f2"
# Bootstrap importance sampling.
def importance_sampling(key, constraint):
    key, sub_key = jax.random.split(key)
    sub_keys = jax.random.split(sub_key, 5)
    # Run importance on 5 particles in parallel.
    tr, log_weights = jax.vmap(model.importance, in_axes=(0, None, None))(
        sub_keys, constraint, ()
    )
    # Normalize the log weights and resample a single particle.
    logits = log_weights - logsumexp(log_weights)
    idx = categorical(logits)(key)
    return jtu.tree_map(lambda v: v[idx], tr.get_choices())
sub_keys = jax.random.split(jax.random.key(0), 50)
samples = jax.jit(jax.vmap(importance_sampling, in_axes=(0, None)))(
    sub_keys, C.kw(f1=True, f2=True)
)
print(samples.render_html())
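The resampling step above normalizes the log weights with logsumexp and then draws one particle index from the resulting categorical distribution. The following stdlib-only sketch shows that arithmetic. The helper names and the particle weights are hypothetical and are not part of the GenJAX API:

```python
import math
import random


def logsumexp(xs):
    # Numerically stable log(sum(exp(x))): shift by the max before exponentiating.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def resample_index(log_weights, rng):
    # Hypothetical helper mirroring `log_weights - logsumexp(log_weights)` above.
    total = logsumexp(log_weights)
    probs = [math.exp(w - total) for w in log_weights]
    # Draw one particle index from the categorical distribution over particles.
    idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return idx, probs


# Five hypothetical particles; the third carries almost all of the weight.
log_weights = [-10.0, -12.0, -0.1, -9.0, -11.0]
idx, probs = resample_index(log_weights, random.Random(0))
print(idx, [round(p, 4) for p in probs])
```

Subtracting the maximum inside logsumexp keeps `math.exp` from overflowing when log weights are large in magnitude.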
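Methods:
Name | Description |
---|---|
__abstract_call__ | Used to support JAX tracing, although this default implementation involves no JAX operations (it takes a fixed-key sample from the return value). |
accumulate | When called on a genjax.GenerativeFunction of type (c, a) -> c, returns a new genjax.GenerativeFunction of type (c, [a]) -> [c]. |
assess | Return the score and the return value when the generative function is invoked with the provided arguments, and constrained to take the provided sample as the sampled value. |
contramap | Specialized version of genjax.GenerativeFunction.dimap where only the pre-processing function is applied. |
dimap | Returns a new genjax.GenerativeFunction that applies pre- and post-processing functions to its arguments and return value. |
edit | Update a trace in response to an EditRequest, returning a new Trace, an incremental Weight, a Retdiff, and a backward EditRequest. |
get_zero_trace | Returns a trace with zero values for all leaves, generated without executing the generative function. |
handle_kwargs | Returns a new GenerativeFunction like self, but where all GFI methods accept a tuple of arguments and a dictionary of keyword arguments. |
importance | Returns a properly weighted pair, a Trace and a Weight, for the target induced by the generative function for the provided constraint and arguments. |
iterate | When called on a genjax.GenerativeFunction of type a -> a, returns a new genjax.GenerativeFunction of type a -> [a]. |
iterate_final | When called on a genjax.GenerativeFunction of type a -> a, returns a new genjax.GenerativeFunction of type a -> a that applies it n times. |
map | Specialized version of genjax.dimap where only the post-processing function is applied. |
mask | Enables dynamic masking of generative functions. Returns a new genjax.GenerativeFunction like self, but which accepts an additional boolean first argument. |
masked_iterate | Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a list of values of type a. |
masked_iterate_final | Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a value of type a. |
mix | Takes any number of genjax.GenerativeFunctions and returns a new genjax.GenerativeFunction that represents a mixture model. |
or_else | Returns a GenerativeFunction that accepts a boolean argument and an argument tuple for each of self and gen_fn, and acts like self when the boolean is True or like gen_fn otherwise. |
propose | Samples a ChoiceMap and any untraced randomness from the generative function's distribution over samples, and returns the Score of that sample along with the return value. |
reduce | When called on a … |
repeat | Returns a … |
scan | When called on a … |
simulate | Execute the generative function, sampling from its distribution over samples, and return a Trace. |
switch | Given … |
vmap | Returns a … |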
Source code in src/genjax/_src/core/generative/generative_function.py
__abstract_call__
Used to support JAX tracing, although this default implementation involves no JAX operations (it takes a fixed-key sample from the return value).
Generative functions may customize this to improve compilation time.
accumulate
accumulate() -> GenerativeFunction[R]
When called on a genjax.GenerativeFunction of type (c, a) -> c, returns a new genjax.GenerativeFunction of type (c, [a]) -> [c] where
- c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- [c] is an array of all loop-carried values seen during iteration (including the first)
- a may be a primitive, an array type or a pytree (container) type with array leaves
All traced values are nested under an index.
For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation (note the similarity to itertools.accumulate):
def accumulate(f, init, xs):
    carry = init
    carries = [init]
    for x in xs:
        carry = f(carry, x)
        carries.append(carry)
    return carries
Unlike that Python version, both xs and carries may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.
The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Examples:
import jax
import genjax
import jax.numpy as jnp
@genjax.accumulate()
@genjax.gen
def add(sum, x):
    new_sum = sum + x
    return new_sum
init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)
tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
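According to the semantics described above, the add example should produce 11 loop-carried values: init, followed by one value per element of xs. The following stdlib sketch uses itertools.accumulate as a plain-Python analogy for that expected return value. It is not GenJAX code:

```python
import itertools

# Plain-Python analogue of the add example: init = 0.0, xs = ten ones.
init = 0.0
xs = [1.0] * 10
# `initial=` prepends init, matching "including the first" in the description above.
carries = list(itertools.accumulate(xs, lambda total, x: total + x, initial=init))
print(carries)
```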
assess
abstractmethod
Return the score and the return value when the generative function is invoked with the provided arguments, and constrained to take the provided sample as the sampled value.
It is an error if the provided sample value is off the support of the distribution over the ChoiceMap type, or otherwise induces a partial constraint on the execution of the generative function (which would require the generative function to provide an edit implementation which responds to the EditRequest induced by the importance interface).
Examples:
This method is similar to density evaluation interfaces for distributions.
from genjax import normal
from genjax import ChoiceMapBuilder as C
sample = C.v(1.0)
score, retval = normal.assess(sample, (1.0, 1.0))
print((score, retval))
But it also works with generative functions that sample from spaces with more structure:
from genjax import gen
from genjax import normal
from genjax import ChoiceMapBuilder as C
@gen
def model():
    v1 = normal(0.0, 1.0) @ "v1"
    v2 = normal(v1, 1.0) @ "v2"
sample = C.kw(v1=1.0, v2=0.0)
score, retval = model.assess(sample, ())
print((score, retval))
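The score returned by assess is a log density, so both printed scores can be checked by hand. The following stdlib sketch assumes that normal(loc, scale) takes a standard deviation as its second argument. Both examples use a scale of 1.0, so that assumption does not change the numbers:

```python
import math


def normal_logpdf(x, loc, scale):
    # log N(x; loc, scale^2), treating scale as a standard deviation (assumption).
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


# normal.assess(C.v(1.0), (1.0, 1.0)): the value sits exactly at the mean.
single = normal_logpdf(1.0, 1.0, 1.0)

# model.assess(C.kw(v1=1.0, v2=0.0), ()): the score sums both choices' log densities.
joint = normal_logpdf(1.0, 0.0, 1.0) + normal_logpdf(0.0, 1.0, 1.0)
print(single, joint)
```

The second score works out to \(-1 - \log 2\pi\), because each choice sits one standard deviation from its mean.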
contramap
contramap(
f: Callable[..., ArgTuple]
) -> GenerativeFunction[R]
Specialized version of genjax.GenerativeFunction.dimap where only the pre-processing function is applied.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
f | Callable[..., ArgTuple] | A callable that preprocesses the arguments of the wrapped function. Note that f must return a tuple of arguments. | required |
Returns:
Type | Description |
---|---|
GenerativeFunction[R] | A new genjax.GenerativeFunction with f applied to its arguments. |
Examples:
import jax, genjax
# Define a pre-processing function.
# Note that this function must return a tuple of arguments!
def add_one(x):
    return (x + 1,)
@genjax.gen
def model(x):
    return genjax.normal(x, 1.0) @ "z"
contramap_model = model.contramap(add_one)
# Use the contramap model
key = jax.random.key(0)
trace = contramap_model.simulate(key, (2.0,))
print(trace.render_html())
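The tuple requirement makes sense if the pre-processed arguments are splatted into the wrapped function. The following plain-Python sketch assumes that calling convention. contramap_sketch and model_mean are hypothetical names, not GenJAX API. Under this reading, z in the example above is sampled around 3.0 rather than 2.0:

```python
def contramap_sketch(f, pre):
    # Hypothetical analogue of contramap: pre maps the caller's arguments to a
    # new argument tuple, which is then splatted into f.
    def wrapped(*args):
        new_args = pre(*args)
        if not isinstance(new_args, tuple):
            raise TypeError("pre-processing function must return a tuple of arguments")
        return f(*new_args)

    return wrapped


def add_one(x):
    return (x + 1,)


def model_mean(x):
    # Hypothetical deterministic stand-in for `model`: the mean it samples around.
    return x


shifted = contramap_sketch(model_mean, add_one)
print(shifted(2.0))
```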
dimap
dimap(
*,
pre: Callable[..., ArgTuple],
post: Callable[[tuple[Any, ...], ArgTuple, R], S]
) -> GenerativeFunction[S]
Returns a new genjax.GenerativeFunction that applies pre- and post-processing functions to its arguments and return value.
Info
Prefer genjax.GenerativeFunction.map if you only need to transform the return value, or genjax.GenerativeFunction.contramap if you only need to transform the arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pre | Callable[..., ArgTuple] | A callable that preprocesses the arguments before passing them to the wrapped function. Note that pre must return a tuple of arguments. | required |
post | Callable[[tuple[Any, ...], ArgTuple, R], S] | A callable that postprocesses the return value of the wrapped function. It is called with the original arguments, the pre-processed arguments, and the return value. | required |
Returns:
Type | Description |
---|---|
GenerativeFunction[S] | A new genjax.GenerativeFunction with pre and post applied. |
Examples:
import jax, genjax
# Define pre- and post-processing functions
def pre_process(x, y):
    return (x + 1, y * 2)
def post_process(args, xformed, retval):
    return retval**2
@genjax.gen
def model(x, y):
    return genjax.normal(x, y) @ "z"
dimap_model = model.dimap(pre=pre_process, post=post_process)
# Use the dimap model
key = jax.random.key(0)
trace = dimap_model.simulate(key, (2.0, 3.0))
print(trace.render_html())
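The post signature in the example, post_process(args, xformed, retval), indicates that post sees the original arguments, the transformed arguments, and the return value. The following plain-Python sketch shows that data flow under that assumption. dimap_sketch and show_inputs are hypothetical names. In the example above, this means normal is called with loc 3.0 and scale 6.0, and the returned value is the square of the draw:

```python
def dimap_sketch(f, pre, post):
    # Hypothetical analogue of dimap: post receives the original arguments,
    # the pre-processed arguments, and the wrapped function's return value.
    def wrapped(*args):
        xformed = pre(*args)
        retval = f(*xformed)
        return post(args, xformed, retval)

    return wrapped


def pre_process(x, y):
    return (x + 1, y * 2)


def show_inputs(args, xformed, retval):
    # Return everything post can see, to make the calling convention visible.
    return args, xformed, retval


# Hypothetical deterministic stand-in for the model: returns its loc argument.
probe = dimap_sketch(lambda loc, scale: loc, pre_process, show_inputs)
print(probe(2.0, 3.0))
```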
edit
abstractmethod
edit(
key: PRNGKey,
trace: Trace[R],
edit_request: EditRequest,
argdiffs: Argdiffs,
) -> tuple[Trace[R], Weight, Retdiff[R], EditRequest]
Update a trace in response to an EditRequest, returning a new Trace, an incremental Weight for the new target, a Retdiff return value tagged with change information, and a backward EditRequest which requests the reverse move (to go back to the original trace).
The specification of this interface is parametric over the kind of EditRequest: responding to an EditRequest instance requires that the generative function provide an implementation of a sequential Monte Carlo move in the SMCP3 framework. Users of inference algorithms are not expected to understand the ingredients, but inference algorithm developers are.
Examples:
Updating a trace in response to a request for a Target change induced by a change to the arguments:
import jax
from genjax import gen, normal, Diff, Update, ChoiceMap as C
key = jax.random.key(0)
@gen
def model(var):
    v1 = normal(0.0, 1.0) @ "v1"
    v2 = normal(v1, var) @ "v2"
    return v2
# Generating an initial trace properly weighted according
# to the target induced by the constraint.
constraint = C.kw(v2=1.0)
initial_tr, w = model.importance(key, constraint, (1.0,))
# Updating the trace to a new target.
new_tr, inc_w, retdiff, bwd_prob = model.edit(
    key,
    initial_tr,
    Update(
        C.empty(),
    ),
    Diff.unknown_change((3.0,)),
)
Now, let's inspect the trace:
# Inspect the trace, the sampled values should not have changed!
sample = new_tr.get_choices()
print(sample["v1"], sample["v2"])
And the return value diff:
print(retdiff)
As expected, neither has changed, but the weight is non-zero:
print(inc_w)
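Why is the weight non-zero? When only the arguments change and no constraint is supplied, the sampled choices are kept. Assuming the standard Gen semantics for this case, the incremental weight is then the log ratio of the new and old joint densities at the same choices. In this model only v2's density depends on var, which is passed as the scale of normal. The following stdlib sketch shows that arithmetic. The helper names are hypothetical:

```python
import math


def normal_logpdf(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def arg_change_weight(v1, v2, old_scale, new_scale):
    # Hypothetical helper. Choices are kept, so the weight is
    # log p_new(choices) - log p_old(choices). The v1 term does not depend
    # on `var` and cancels; only v2's term remains.
    return normal_logpdf(v2, v1, new_scale) - normal_logpdf(v2, v1, old_scale)


# If v1 happened to equal the constrained v2 = 1.0, the weight reduces to -log(3).
w = arg_change_weight(v1=1.0, v2=1.0, old_scale=1.0, new_scale=3.0)
print(w)
```

The weight is exactly zero only when the old and new scales coincide. That is why changing var from 1.0 to 3.0 produces a non-zero weight even though no sampled value moved.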
Mathematical ingredients behind edit
The edit interface exposes SMCP3 moves. Here, we omit the measure-theoretic description and refer interested readers to the SMCP3 paper. Informally, the ingredients of such a move are:
- The previous target \(T\).
- The new target \(T'\).
- A pair of kernel probabilistic programs, called \(K\) and \(L\):
- The K kernel is a kernel probabilistic program which accepts a previous sample \(x_{t-1}\) from \(T\) as an argument, may sample auxiliary randomness \(u_K\), and returns a new sample \(x_t\) approximately distributed according to \(T'\), along with transformed randomness \(u_L\).
- The L kernel is a kernel probabilistic program which accepts the new sample \(x_t\), and provides a density evaluator for the auxiliary randomness \(u_L\) which K returns, and an inverter \(x_t \mapsto x_{t-1}\) which is almost everywhere the identity function.
The specification of these ingredients is encapsulated in the type signature of the edit interface.
Understanding the edit interface
The edit interface uses the mathematical ingredients described above to perform probability-aware mutations and incremental Weight computations on Trace instances, which allows Gen to provide automation to support inference algorithms like importance sampling, SMC, MCMC and many more.
An EditRequest denotes a function \(tr \mapsto (T, T')\) from traces to a pair of targets (the previous Target \(T\), and the final Target \(T'\)).
Several common types of moves can be requested via the Update type:
from genjax import Update
from genjax import ChoiceMap
g = Update(
    ChoiceMap.empty(),  # Constraint
)
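Update carries a constraint which specifies an additional move to be performed. Changes to the arguments of the generative function (Argdiffs) are supplied separately, as the argdiffs argument of edit. For example, constraining v1 to a new value while also changing the argument: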
new_tr, inc_w, retdiff, bwd_prob = model.edit(
    key,
    initial_tr,
    Update(
        C.kw(v1=3.0),
    ),
    Diff.unknown_change((3.0,)),
)
# The constrained value of v1, and the incremental weight of this edit.
print((new_tr.get_choices()["v1"], inc_w))
Additional notes on Argdiffs
Argument changes induce changes to the distribution over samples, internal K and L proposals, and (by virtue of changes to \(P\)) target distributions. The Argdiffs type denotes the type of values attached with a change type, a piece of data which indicates how the value has changed from the arguments which created the trace. Generative functions can utilize change type information to inform efficient edit implementations.
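To make the efficiency point concrete, here is a stdlib sketch of how an edit implementation might use change tags to skip work. Tagged, edit_score, and the toy density are hypothetical. GenJAX's own Diff types and edit logic are not reproduced here:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tagged:
    # Hypothetical stand-in for a change-tagged argument (not GenJAX's Diff type).
    value: float
    changed: bool


def edit_score(old_score, tagged_args, log_density, choices):
    # With no argument changes (and no constraint), the old score can be reused
    # without re-evaluating any densities; the incremental weight is zero.
    if not any(arg.changed for arg in tagged_args):
        return old_score, 0.0
    new_score = log_density(choices, *(arg.value for arg in tagged_args))
    return new_score, new_score - old_score


calls = []


def toy_log_density(choices, a):
    # Toy unnormalized log density with one choice x: -a * x. Records each call.
    calls.append(a)
    return -a * choices["x"]


choices = {"x": 2.0}
old = toy_log_density(choices, 1.0)
print(edit_score(old, (Tagged(1.0, False),), toy_log_density, choices))
print(edit_score(old, (Tagged(3.0, True),), toy_log_density, choices))
```

The unchanged call returns without evaluating the density, which is the kind of saving change information enables.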
get_zero_trace
get_zero_trace(*args, **_kwargs) -> Trace[R]
Returns a trace with zero values for all leaves, generated without executing the generative function.
This method is useful for static analysis and shape inference without executing the generative function. It returns a trace with the same structure as a real trace, but filled with zero or default values.
Args:
*args: The arguments to the generative function.
**_kwargs: Ignored keyword arguments.
Returns:
A trace with zero values, matching the structure of a real trace.
Note:
This method uses the `empty_trace` utility function, which creates a trace without spending any FLOPs. The resulting trace has the correct structure but contains placeholder zero values.
Example:
```python
import jax
import genjax
@genjax.gen
def weather_model():
    temperature = genjax.normal(20.0, 5.0) @ "temperature"
    is_sunny = genjax.bernoulli(0.7) @ "is_sunny"
    return {"temperature": temperature, "is_sunny": is_sunny}
zero_trace = weather_model.get_zero_trace()
print("Zero trace structure:")
print(zero_trace.render_html())
print("\nActual simulation:")
key = jax.random.key(0)
actual_trace = weather_model.simulate(key, ())
print(actual_trace.render_html())
```
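The idea of a structure-preserving zero value can be illustrated on plain Python containers. This sketch is only an analogy. It zeroes the leaves of an existing value, whereas the docstring above says get_zero_trace obtains the structure without executing the generative function. zeros_like_tree is a hypothetical helper, not the empty_trace utility:

```python
def zeros_like_tree(tree):
    # Hypothetical helper: replace every leaf with a zero of the same type,
    # keeping the container structure. bool is checked before int because
    # bool is a subclass of int.
    if isinstance(tree, dict):
        return {k: zeros_like_tree(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(zeros_like_tree(v) for v in tree)
    if isinstance(tree, bool):
        return False
    if isinstance(tree, (int, float)):
        return type(tree)(0)
    raise TypeError(f"unsupported leaf: {tree!r}")


# Hypothetical value shaped like weather_model's return value.
example = {"temperature": 21.3, "is_sunny": True}
print(zeros_like_tree(example))
```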
handle_kwargs
handle_kwargs() -> GenerativeFunction[R]
Returns a new GenerativeFunction like self, but where all GFI methods accept a tuple of arguments and a dictionary of keyword arguments.
The returned GenerativeFunction can be invoked with __call__ with no special argument handling (just like the original).
In place of args tuples in GFI methods, the new GenerativeFunction expects a 2-tuple containing:
- A tuple containing the original positional arguments.
- A dictionary containing the keyword arguments.
This allows for more flexible argument passing, especially useful in contexts where keyword arguments need to be handled separately or passed through multiple layers.
Returns:
Type | Description |
---|---|
GenerativeFunction[R] | A new GenerativeFunction that accepts (args_tuple, kwargs_dict) for all GFI methods. |
Example
import genjax
import jax
@genjax.gen
def model(x, y, z=1.0):
    _ = genjax.normal(x + y, z) @ "v"
    return x + y + z
key = jax.random.key(0)
kw_model = model.handle_kwargs()
tr = kw_model.simulate(key, ((1.0, 2.0), {"z": 3.0}))
print(tr.render_html())
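The packed calling convention can be mimicked in plain Python. In the following sketch, handle_kwargs_sketch and model_retval are hypothetical names, and the wrapped function simply returns the example model's deterministic return value, x + y + z:

```python
def handle_kwargs_sketch(f):
    # Hypothetical analogue: the wrapped function takes a single packed
    # argument, (args_tuple, kwargs_dict), and unpacks it before calling f.
    def wrapped(packed):
        args, kwargs = packed
        return f(*args, **kwargs)

    return wrapped


def model_retval(x, y, z=1.0):
    # Deterministic stand-in for the example model's return value.
    return x + y + z


kw_model_retval = handle_kwargs_sketch(model_retval)
print(kw_model_retval(((1.0, 2.0), {"z": 3.0})))
print(kw_model_retval(((1.0, 2.0), {})))
```

An empty dictionary falls back to the function's own keyword defaults, here z=1.0.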
importance
Returns a properly weighted pair, a Trace and a Weight, for the target induced by the generative function for the provided constraint and arguments.
Examples:
(Full constraints) A simple example using the importance interface on distributions:
import jax
from genjax import normal
from genjax import ChoiceMapBuilder as C
key = jax.random.key(0)
tr, w = normal.importance(key, C.v(1.0), (0.0, 1.0))
print(tr.get_choices().render_html())
(Internal proposal for partial constraints) Specifying a partial constraint on a StaticGenerativeFunction:
import jax
from genjax import flip, uniform, gen
from genjax import ChoiceMapBuilder as C
key = jax.random.key(0)
@gen
def model():
    p = uniform(0.0, 1.0) @ "p"
    f1 = flip(p) @ "f1"
    f2 = flip(p) @ "f2"
# Only f1 and f2 are constrained; p is proposed internally.
tr, w = model.importance(key, C.kw(f1=True, f2=True), ())
print(tr.get_choices().render_html())
Under the hood, importance creates an EditRequest which requests that the generative function respond with a move from the empty trace (the only possible value for the empty target \(\delta_\emptyset\)) to the target \(T_P(a, c)\) induced by the generative function for constraint \(c\) with arguments \(a\).
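Consider the partially constrained model above, and assume the internal proposal samples the unconstrained choice p from its prior (the usual default for static generative functions in Gen). Then each weight is \(\log p(f_1 = \text{True}) + \log p(f_2 = \text{True}) = 2 \log p\). Averaging \(e^w\) therefore estimates the marginal likelihood \(\int_0^1 p^2\,dp = 1/3\). The following stdlib Monte Carlo sketch illustrates that reasoning. It does not call GenJAX:

```python
import math
import random

rng = random.Random(0)
n = 100_000
weights = []
for _ in range(n):
    # Unconstrained choice: p is drawn from its prior, uniform on (0, 1].
    p = 1.0 - rng.random()
    # Constrained choices f1 = True and f2 = True each contribute log p.
    weights.append(2.0 * math.log(p))

# The mean of exp(weight) estimates the marginal likelihood P(f1, f2) = E[p^2] = 1/3.
estimate = sum(math.exp(w) for w in weights) / n
print(estimate)
```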
iterate
iterate(*, n: int) -> GenerativeFunction[R]
When called on a genjax.GenerativeFunction of type a -> a, returns a new genjax.GenerativeFunction of type a -> [a] where
- a is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- [a] is an array of all a, f(a), f(f(a)) etc. values seen during iteration.
All traced values are nested under an index.
The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:
def iterate(f, n, init):
    value = init
    seen = [init]
    for _ in range(n):
        value = f(value)
        seen.append(value)
    return seen
init may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.
The iterated value a must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | int | the number of iterations to run. | required |
Examples:
Iterative addition, returning all intermediate sums:
import jax
import genjax
@genjax.iterate(n=100)
@genjax.gen
def inc(x):
    return x + 1
init = 0.0
key = jax.random.key(314159)
tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
iterate_final
iterate_final(*, n: int) -> GenerativeFunction[R]
When called on a genjax.GenerativeFunction of type a -> a, returns a new genjax.GenerativeFunction of type a -> a where
- a is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- the original function is invoked n times with each input coming from the previous invocation's output, so that the new function returns \(f^n(a)\)
All traced values are nested under an index.
The semantics of the returned genjax.GenerativeFunction match iterate, except that only the final value \(f^n(a)\) is returned rather than every intermediate value.
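The following plain-Python sketch illustrates these semantics, written by analogy with the iterate sketch. It is an illustration, not the GenJAX source. Its result is the last element of the list iterate would return:

```python
def iterate_final(f, n, init):
    # Sketch: apply f n times, threading each output into the next call: f^n(init).
    value = init
    for _ in range(n):
        value = f(value)
    return value


def inc(x):
    return x + 1


print(iterate_final(inc, 100, 0.0))
```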
init may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.
The iterated value a must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | int | the number of iterations to run. | required |
Examples:
Iterative addition:
import jax
import genjax
@genjax.iterate_final(n=100)
@genjax.gen
def inc(x):
    return x + 1
init = 0.0
key = jax.random.key(314159)
tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
map
map(f: Callable[[R], S]) -> GenerativeFunction[S]
Specialized version of genjax.dimap where only the post-processing function is applied.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
f | Callable[[R], S] | A callable that postprocesses the return value of the wrapped function. | required |
Returns:
Type | Description |
---|---|
GenerativeFunction[S] | A new genjax.GenerativeFunction with f applied to its return value. |
Examples:
import jax, genjax
# Define a post-processing function
def square(x):
    return x**2
@genjax.gen
def model(x):
    return genjax.normal(x, 1.0) @ "z"
map_model = model.map(square)
# Use the map model
key = jax.random.key(0)
trace = map_model.simulate(key, (2.0,))
print(trace.render_html())
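In plain-Python terms, map leaves the arguments untouched and only transforms what comes back. It behaves like dimap with an identity pre-processing step and a post-processing step that ignores the arguments. The following sketch uses hypothetical names and a deterministic stand-in for the model:

```python
def map_sketch(f, post):
    # Hypothetical analogue of map: only the return value is transformed;
    # the arguments reach f unchanged.
    def wrapped(*args):
        return post(f(*args))

    return wrapped


def square(x):
    return x**2


# Hypothetical deterministic stand-in for the model: returns its mean argument.
squared_mean = map_sketch(lambda x: x, square)
print(squared_mean(2.0))
```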
mask
mask() -> GenerativeFunction[Mask[R]]
Enables dynamic masking of generative functions. Returns a new genjax.GenerativeFunction like self, but which accepts an additional boolean first argument.
If True, the invocation has the same semantics as if one was invoking self without masking. If False, the invocation of self is masked, and its contribution to the score is ignored.
The return value type is a Mask, with a flag value equal to the supplied boolean.
Returns:
Type | Description |
---|---|
GenerativeFunction[Mask[R]] | The masked version of the original genjax.GenerativeFunction. |
Examples:
Masking a normal draw:
import genjax, jax
@genjax.gen
def normal_draw(mean):
    return genjax.normal(mean, 1.0) @ "x"
masked_normal_draw = normal_draw.mask()
key = jax.random.key(314159)
tr = jax.jit(masked_normal_draw.simulate)(
    key,
    (
        False,
        2.0,
    ),
)
print(tr.render_html())
masked_iterate
masked_iterate() -> GenerativeFunction[R]
Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a list of values of type a.
The original function is modified to accept an additional argument mask, which is a boolean value indicating whether the operation should be masked or not. The function returns the results of the original operation as a Mask, with the input [mask] as its flags.
All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from [mask] with a flag of True will be accounted for in inference, notably for score computations.
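The scoring rule in the last paragraph can be sketched in a few lines of stdlib Python. In this sketch, masked_total_score and the per-step scores are hypothetical. Every step is still recorded, but only steps flagged True count toward the total:

```python
def masked_total_score(step_scores, mask):
    # Hypothetical helper: only steps whose mask flag is True contribute to
    # the total score; the other steps are still traced but ignored here.
    if len(step_scores) != len(mask):
        raise ValueError("one mask flag per step is required")
    return sum(s for s, m in zip(step_scores, mask) if m)


# Hypothetical per-step log scores; the mask mirrors `jnp.arange(10) < 5`.
scores = [-1.0 * (i + 1) for i in range(10)]
mask_steps = [i < 5 for i in range(10)]
print(masked_total_score(scores, mask_steps))
```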
Example
import jax
import jax.numpy as jnp
import genjax
masks = jnp.array([True, False, True])
# Create a kernel generative function
@genjax.gen
def step(x):
    _ = (
        genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
        @ "rats"
    )
    return x
# Create a model using masked_iterate
model = step.masked_iterate()
# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
masked_iterate_final
masked_iterate_final() -> GenerativeFunction[R]
Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a value of type a.
The original function is modified to accept an additional argument mask, which is a boolean value indicating whether the operation should be masked or not. The function returns the result of the original operation if mask is True, and the original input if mask is False.
All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from [mask] with a flag of True will be accounted for in inference, notably for score computations.
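The pass-through rule above can be illustrated with a plain-Python sketch. The function and the example inputs are hypothetical: at each step the kernel runs when the flag is True, and the input is carried forward unchanged when it is False:

```python
def masked_iterate_final(f, init, mask):
    # Sketch: apply f at steps whose flag is True; at False steps the input
    # passes through unchanged.
    value = init
    for m in mask:
        value = f(value) if m else value
    return value


def inc(x):
    return x + 1


mask_steps = [i < 5 for i in range(10)]
print(masked_iterate_final(inc, 0.0, mask_steps))
print(masked_iterate_final(inc, 0.0, [True, False, True, False]))
```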
Example
import jax
import jax.numpy as jnp
import genjax
masks = jnp.array([True, False, True])
# Create a kernel generative function
@genjax.gen
def step(x):
    _ = (
        genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
        @ "rats"
    )
    return x
# Create a model using masked_iterate_final
model = step.masked_iterate_final()
# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
mix
mix(*fns: GenerativeFunction[R]) -> GenerativeFunction[R]
Takes any number of genjax.GenerativeFunctions and returns a new genjax.GenerativeFunction that represents a mixture model.
The returned generative function takes the following arguments:
- mixture_logits: Logits for the categorical distribution used to select a component.
- *args: Argument tuples for self and each of the input generative functions
and samples from self or one of the input generative functions based on a draw from a categorical distribution defined by the provided mixture logits.
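Parameters:
Name | Type | Description | Default |
---|---|---|---|
*fns | GenerativeFunction[R] | Variable number of genjax.GenerativeFunctions to be mixed with self. | () |
Returns:
Type | Description |
---|---|
GenerativeFunction[R] | A new genjax.GenerativeFunction representing the mixture model. |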
Examples:
```python
import jax
import genjax

# Define component generative functions
@genjax.gen
def component1(x):
    return genjax.normal(x, 1.0) @ "y"

@genjax.gen
def component2(x):
    return genjax.normal(x, 2.0) @ "y"

# Create mixture model
mixture = component1.mix(component2)

# Use the mixture model
key = jax.random.key(0)
logits = jax.numpy.array([0.3, 0.7])  # Favors component2
trace = mixture.simulate(key, (logits, (0.0,), (7.0,)))
print(trace.render_html())
```
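The selection step can be sketched in plain Python. Apply a softmax to the logits, draw a component index from the resulting categorical distribution, and call that component with its own argument tuple. This is a hedged sketch, not GenJAX's implementation. mix_sketch and softmax are hypothetical helpers, traces and scores are ignored, and the argument tuples are assumed to be ordered with self first, as in the example above.

```python
import math
import random


def softmax(logits):
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix_sketch(components, rng, mixture_logits, *arg_tuples):
    """Hypothetical plain-Python model of a mixture; returns (index, value)."""
    assert len(arg_tuples) == len(components), "one argument tuple per component"
    probs = softmax(mixture_logits)
    idx = rng.choices(range(len(components)), weights=probs, k=1)[0]
    return idx, components[idx](*arg_tuples[idx])


rng = random.Random(0)
components = [lambda x: x + 1.0, lambda x: x - 1.0]
# exp(-1000) underflows to 0.0, so the first component is always chosen here.
print(mix_sketch(components, rng, [0.0, -1000.0], (0.0,), (7.0,)))  # (0, 1.0)
```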
Source code in src/genjax/_src/core/generative/generative_function.py
or_else
or_else(
gen_fn: GenerativeFunction[R],
) -> GenerativeFunction[R]
Returns a GenerativeFunction that accepts
- a boolean argument
- an argument tuple for self
- an argument tuple for the supplied gen_fn
and acts like self when the boolean is True or like gen_fn otherwise.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gen_fn | GenerativeFunction[R] | Called when the boolean argument is False. | required |
Examples:
import jax
import jax.numpy as jnp
import genjax

@genjax.gen
def if_model(x):
    return genjax.normal(x, 1.0) @ "if_value"

@genjax.gen
def else_model(x):
    return genjax.normal(x, 5.0) @ "else_value"

@genjax.gen
def model(toss: bool):
    # Note that the returned model takes a new boolean predicate in
    # addition to argument tuples for each branch.
    return if_model.or_else(else_model)(toss, (1.0,), (10.0,)) @ "tossed"

key = jax.random.key(314159)
tr = jax.jit(model.simulate)(key, (True,))
print(tr.render_html())
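In plain Python, the control flow of or_else reduces to choosing one branch and passing it its own argument tuple. The sketch below is illustrative only. or_else_sketch is hypothetical, and it ignores traces, scores, and the constraints JAX places on branch outputs under jax.jit.

```python
def or_else_sketch(if_fn, else_fn):
    """Hypothetical plain-Python model of `if_fn.or_else(else_fn)`."""

    def combined(flag, if_args, else_args):
        # Only the selected branch runs, with its own argument tuple.
        return if_fn(*if_args) if flag else else_fn(*else_args)

    return combined


branch = or_else_sketch(lambda x: ("if", x), lambda x: ("else", x))
print(branch(True, (1.0,), (10.0,)))   # ('if', 1.0)
print(branch(False, (1.0,), (10.0,)))  # ('else', 10.0)
```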
Source code in src/genjax/_src/core/generative/generative_function.py
propose
Samples a ChoiceMap and any untraced randomness \(r\) from the generative function's distribution over samples (\(P\)). Returns the sample together with its Score under that distribution and the return value (of type R) computed by the generative function's return value function \(f(t, r, a)\) for the sample and untraced randomness.
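A toy model makes the three outputs concrete. The plain-Python sketch below is purely illustrative. propose_sketch is a hypothetical stand-in for a generative function with a single Bernoulli choice and no untraced randomness. The (sample, score, retval) ordering of its result is an assumption based on the description above.

```python
import math
import random


def propose_sketch(rng, p=0.3):
    """Hypothetical model with one traced choice x ~ Bernoulli(p) at address "x".

    Returns (sample, score, retval), where score = log P(t; a) and the
    return value function is f(t) = not x.
    """
    x = rng.random() < p
    sample = {"x": x}
    score = math.log(p) if x else math.log(1.0 - p)
    retval = not x
    return sample, score, retval


sample, score, retval = propose_sketch(random.Random(0))
print(sample, score, retval)
```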
Source code in src/genjax/_src/core/generative/generative_function.py
reduce
reduce() -> GenerativeFunction[R]
When called on a genjax.GenerativeFunction of type (c, a) -> c, returns a new genjax.GenerativeFunction of type (c, [a]) -> c where
- c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- a may be a primitive, an array type or a pytree (container) type with array leaves
All traced values are nested under an index.
For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation (note the similarity to functools.reduce):
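The reference implementation mentioned above is missing from this copy of the page. The following is a minimal sketch of what it plausibly looks like, reconstructed by analogy with the scan implementation documented below rather than taken from the GenJAX source. reduce_sketch is a hypothetical name.

```python
import functools


def reduce_sketch(f, init, xs):
    """Reference semantics for reduce: thread a carry through f, return the final carry."""
    carry = init
    for x in xs:
        carry = f(carry, x)
    return carry


def add(total, x):
    return total + x


xs = [1.0] * 10
print(reduce_sketch(add, 0.0, xs))  # 10.0
# Same result as functools.reduce with an initial value.
assert reduce_sketch(add, 0.0, xs) == functools.reduce(add, xs, 0.0)
```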
Unlike that Python version, both xs and carry may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.
The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Examples:
Sum an array of numbers:
import jax
import genjax
import jax.numpy as jnp

@genjax.reduce()
@genjax.gen
def add(sum, x):
    new_sum = sum + x
    return new_sum

init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)
tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
Source code in src/genjax/_src/core/generative/generative_function.py
repeat
repeat(*, n: int) -> GenerativeFunction[R]
Returns a genjax.GenerativeFunction that samples from self n times, returning a vector of n results.
The values traced by each call to self will be nested under an integer index that matches the loop iteration index that generated it.
This combinator is useful for creating multiple samples from self in a batched manner.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | int | The number of times to sample from the generative function. | required |
Returns:
Type | Description |
---|---|
GenerativeFunction[R] | A new genjax.GenerativeFunction that samples from self n times. |
Examples:
import genjax, jax

@genjax.gen
def normal_draw(mean):
    return genjax.normal(mean, 1.0) @ "x"

normal_draws = normal_draw.repeat(n=10)

key = jax.random.key(314159)

# Generate 10 draws from a normal distribution with mean 2.0
tr = jax.jit(normal_draws.simulate)(key, (2.0,))
print(tr.render_html())
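Plain dictionaries can sketch how traced values are nested under iteration indices. This illustration rests on assumptions and is not GenJAX's implementation. repeat_sketch and draw are hypothetical, each call is assumed to return a (choices, retval) pair, and a deterministic counter replaces random sampling so the output is reproducible.

```python
def repeat_sketch(gen_fn, n):
    """Hypothetical model of `gen_fn.repeat(n=n)`.

    `gen_fn(*args)` is assumed to return (choices, retval), where choices is a dict.
    """

    def repeated(*args):
        choices, retvals = {}, []
        for i in range(n):
            sub_choices, retval = gen_fn(*args)
            choices[i] = sub_choices  # nested under the iteration index
            retvals.append(retval)
        return choices, retvals

    return repeated


counter = iter(range(100))


def draw(mean):
    # Deterministic stand-in for a random draw, so the output is reproducible.
    x = mean + next(counter)
    return {"x": x}, x


choices, retvals = repeat_sketch(draw, 3)(2.0)
print(choices)  # {0: {'x': 2.0}, 1: {'x': 3.0}, 2: {'x': 4.0}}
print(retvals)  # [2.0, 3.0, 4.0]
```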
Source code in src/genjax/_src/core/generative/generative_function.py
scan
scan(
*, n: int | None = None
) -> GenerativeFunction[tuple[Carry, Y]]
When called on a genjax.GenerativeFunction of type (c, a) -> (c, b), returns a new genjax.GenerativeFunction of type (c, [a]) -> (c, [b]) where
- c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
- a may be a primitive, an array type or a pytree (container) type with array leaves
- b may be a primitive, an array type or a pytree (container) type with array leaves.
The values traced by each call to the original generative function will be nested under an integer index that matches the loop iteration index that generated it.
For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
When the type of xs in the snippet below (denoted [a] above) is an array type or None, and the type of ys in the snippet below (denoted [b] above) is an array type, the semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:
import numpy as np

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
Unlike that Python version, both xs and ys may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. None is actually a special case of this, as it represents an empty pytree.
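To make the pytree generalization concrete, here is a dependency-free sketch that extends the reference implementation above to nested dicts and tuples of equal-length lists. It makes several assumptions. Dicts and tuples stand in for pytree containers, lists stand in for array leaves, and None is the empty pytree. The helper names (tree_leaves, tree_index, tree_stack, scan_pytree) are hypothetical, not JAX or GenJAX APIs.

```python
def tree_leaves(tree):
    # dicts and tuples are containers; None is the empty pytree; a list is an array leaf.
    if tree is None:
        return []
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in tree_leaves(tree[k])]
    if isinstance(tree, tuple):
        return [leaf for sub in tree for leaf in tree_leaves(sub)]
    return [tree]


def tree_index(tree, i):
    # Select slice i along the leading axis of every leaf.
    if tree is None:
        return None
    if isinstance(tree, dict):
        return {k: tree_index(v, i) for k, v in tree.items()}
    if isinstance(tree, tuple):
        return tuple(tree_index(sub, i) for sub in tree)
    return tree[i]


def tree_stack(trees):
    # Stack same-structured pytrees leafwise along a new leading axis.
    first = trees[0]
    if first is None:
        return None
    if isinstance(first, dict):
        return {k: tree_stack([t[k] for t in trees]) for k in first}
    if isinstance(first, tuple):
        return tuple(tree_stack([t[j] for t in trees]) for j in range(len(first)))
    return list(trees)


def scan_pytree(f, init, xs, length=None):
    lengths = {len(leaf) for leaf in tree_leaves(xs)}
    if length is not None:
        lengths.add(length)
    if len(lengths) != 1:
        raise ValueError("could not determine a single loop length from xs and length")
    (n,) = lengths
    carry, ys = init, []
    for i in range(n):
        carry, y = f(carry, tree_index(xs, i))
        ys.append(y)
    return carry, (tree_stack(ys) if ys else None)


# Scan over two arrays at once and produce two output arrays.
xs = {"a": [1, 2, 3], "b": [10, 20, 30]}
carry, ys = scan_pytree(lambda c, x: (c + x["a"] * x["b"], (c, x["a"])), 0, xs)
print(carry, ys)  # 140 ([0, 10, 50], [1, 2, 3])
```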
The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | int \| None | Optional integer specifying the number of loop iterations, which (if supplied) must agree with the sizes of leading axes of the arrays in the returned function's second argument. If supplied then the returned generative function can take None as its second argument. | None |
Returns:
Type | Description |
---|---|
GenerativeFunction[tuple[Carry, Y]] | A new genjax.GenerativeFunction that performs the scan described above. |
Examples:
Scan for 1000 iterations with no array input:
import jax
import genjax

@genjax.gen
def random_walk_step(prev, _):
    x = genjax.normal(prev, 1.0) @ "x"
    return x, None

random_walk = random_walk_step.scan(n=1000)

init = 0.5
key = jax.random.key(314159)

tr = jax.jit(random_walk.simulate)(key, (init, None))
print(tr.render_html())
Scan across an input array:
import jax.numpy as jnp

@genjax.gen
def add_and_square_step(sum, x):
    new_sum = sum + x
    return new_sum, sum * sum

# notice no `n` parameter supplied:
add_and_square_all = add_and_square_step.scan()
init = 0.0
xs = jnp.ones(10)

tr = jax.jit(add_and_square_all.simulate)(key, (init, xs))

# The retval has the final carry and an array of all `sum*sum` returned.
print(tr.render_html())
Source code in src/genjax/_src/core/generative/generative_function.py
simulate
abstractmethod
Execute the generative function, sampling from its distribution over samples, and return a Trace.
More on traces
The Trace returned by simulate implements its own interface. It is responsible for storing the arguments of the invocation (genjax.Trace.get_args), the return value of the generative function (genjax.Trace.get_retval), the identity of the generative function which produced the trace (genjax.Trace.get_gen_fn), the sample of traced random choices produced during the invocation (genjax.Trace.get_choices) and the score of the sample (genjax.Trace.get_score).
Examples:
import genjax
import jax
from jax import vmap, jit
from jax.random import split

@genjax.gen
def model():
    x = genjax.normal(0.0, 1.0) @ "x"
    return x

key = jax.random.key(0)
tr = model.simulate(key, ())
print(tr.render_html())
Another example, using the same model composed with genjax.repeat, which creates a new generative function with the same interface:
@genjax.gen
def model():
    x = genjax.normal(0.0, 1.0) @ "x"
    return x

key = jax.random.key(0)
tr = model.repeat(n=10).simulate(key, ())
print(tr.render_html())
(Fun, flirty, fast ... parallel?) Feel free to use jax.jit and jax.vmap!
key = jax.random.key(0)
sub_keys = split(key, 10)
sim = model.repeat(n=10).simulate
tr = jit(vmap(sim, in_axes=(0, None)))(sub_keys, ())
print(tr.render_html())
Source code in src/genjax/_src/core/generative/generative_function.py
switch
switch(*branches: GenerativeFunction[R]) -> Switch[R]
Given n genjax.GenerativeFunction inputs, returns a new genjax.GenerativeFunction that accepts n+2 arguments:
- an index in the range \([0, n+1)\)
- a tuple of arguments for self and each of the input generative functions (n+1 total tuples)
and executes the generative function at the supplied index with its provided arguments.
If index is out of bounds, index is clamped to within bounds.
Examples:
import jax, genjax

@genjax.gen
def branch_1():
    x = genjax.normal(0.0, 1.0) @ "x1"

@genjax.gen
def branch_2():
    x = genjax.bernoulli(0.3) @ "x2"

switch = branch_1.switch(branch_2)

key = jax.random.key(314159)
jitted = jax.jit(switch.simulate)

# Select `branch_2` by providing 1:
tr = jitted(key, (1, (), ()))
print(tr.render_html())
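The dispatch and clamping rule can be modeled in a few lines of plain Python. This is a sketch under stated assumptions. switch_sketch is hypothetical, it ignores tracing and scores, and it assumes "clamped to within bounds" means clamping to the range [0, n], as jax.lax.switch does for its index.

```python
def switch_sketch(*branches):
    """Hypothetical plain-Python model of `branches[0].switch(*branches[1:])`."""

    def switched(index, *arg_tuples):
        assert len(arg_tuples) == len(branches), "one argument tuple per branch"
        i = min(max(index, 0), len(branches) - 1)  # clamp out-of-bounds indices
        return branches[i](*arg_tuples[i])

    return switched


s = switch_sketch(lambda: "branch_1", lambda: "branch_2")
print(s(1, (), ()))   # branch_2
print(s(7, (), ()))   # branch_2 (index clamped to 1)
print(s(-3, (), ()))  # branch_1 (index clamped to 0)
```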
Source code in src/genjax/_src/core/generative/generative_function.py
vmap
vmap(*, in_axes: InAxes = 0) -> GenerativeFunction[R]
Returns a GenerativeFunction that performs a vectorized map over the argument specified by in_axes. Traced values are nested under an index, and the retval is vectorized.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_axes | InAxes | Selector specifying which input arguments (or index into them) should be vectorized. Defaults to 0, i.e., the leading axis of every argument. See the in_axes parameter of jax.vmap for more detail. | 0 |
Returns:
Type | Description |
---|---|
GenerativeFunction[R] | A new genjax.GenerativeFunction that maps self over the arguments selected by in_axes. |
Examples:
import jax
import jax.numpy as jnp
import genjax

@genjax.gen
def model(x):
    v = genjax.normal(x, 1.0) @ "v"
    return genjax.normal(v, 0.01) @ "q"

vmapped = model.vmap(in_axes=0)

key = jax.random.key(314159)
arr = jnp.ones(100)

# `vmapped` accepts an array of numbers instead of the original
# single number that `model` accepted.
tr = jax.jit(vmapped.simulate)(key, (arr,))
print(tr.render_html())
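A plain-Python sketch of the two most common in_axes settings follows. In jax.vmap, 0 maps over an argument's leading axis and None broadcasts the argument unchanged. vmap_sketch is hypothetical. It supports only these two settings on lists and returns a Python list rather than a vectorized retval.

```python
def vmap_sketch(f, in_axes=0):
    """Hypothetical plain-Python model of in_axes handling, supporting only 0 and None."""

    def batched(*args):
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        assert len(axes) == len(args) and all(ax in (0, None) for ax in axes)
        sizes = {len(a) for a, ax in zip(args, axes) if ax == 0}
        assert len(sizes) == 1, "mapped arguments must share a leading axis size"
        (n,) = sizes
        return [
            f(*(a[i] if ax == 0 else a for a, ax in zip(args, axes)))
            for i in range(n)
        ]

    return batched


scale = vmap_sketch(lambda x, s: x * s, in_axes=(0, None))
print(scale([1.0, 2.0, 3.0], 10.0))  # [10.0, 20.0, 30.0]
```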
Source code in src/genjax/_src/core/generative/generative_function.py
Traces are data structures which record (execution and inference) data about the invocation of generative functions. Traces are often specialized to a generative function language, to take advantage of data locality and other representation optimizations. Traces support a trace interface: a set of accessor methods designed to provide convenient manipulation when handling traces in inference algorithms. We document this interface below for the Trace data type.
genjax.core.Trace
Bases: Generic[R], Pytree
Trace is the type of traces of generative functions.
A trace is a data structure used to represent sampled executions of generative functions. Traces track metadata associated with the probabilities of choices, as well as other data associated with the invocation of a generative function, including the arguments it was invoked with, its return value, and the identity of the generative function itself.
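Methods:
Name | Description |
---|---|
edit | This method calls out to the underlying GenerativeFunction.edit method. |
get_args | Returns the Arguments for the GenerativeFunction invocation which created the Trace. |
get_choices | Retrieves the random choices made in a trace in the form of a genjax.ChoiceMap. |
get_gen_fn | Returns the GenerativeFunction whose invocation created the Trace. |
get_inner_trace | Override this method to provide Trace.get_subtrace support for those trace types that have substructure that can be addressed in this way. |
get_retval | Returns the R from the GenerativeFunction invocation which created the Trace. |
get_score | Return the Score of the Trace. |
get_subtrace | Return the subtrace having the supplied address. Specifying multiple addresses will apply the operation recursively. |
update | This method calls out to the underlying GenerativeFunction.edit method. |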
Source code in src/genjax/_src/core/generative/generative_function.py
edit
edit(
key: PRNGKey,
request: EditRequest,
argdiffs: tuple[Any, ...] | None = None,
) -> tuple[Self, Weight, Retdiff[R], EditRequest]
This method calls out to the underlying GenerativeFunction.edit method; see EditRequest and edit for more information.
Source code in src/genjax/_src/core/generative/generative_function.py
get_args
abstractmethod
get_args() -> Arguments
Returns the Arguments for the GenerativeFunction invocation which created the Trace.
Source code in src/genjax/_src/core/generative/generative_function.py
get_choices
abstractmethod
get_choices() -> ChoiceMap
Retrieves the random choices made in a trace in the form of a genjax.ChoiceMap.
get_gen_fn
abstractmethod
get_gen_fn() -> GenerativeFunction[R]
Returns the GenerativeFunction whose invocation created the Trace.
get_inner_trace
get_inner_trace(_address: Address) -> Trace[Any]
Override this method to provide Trace.get_subtrace support for those trace types that have substructure that can be addressed in this way.
NOTE: get_inner_trace takes a full Address because, unlike ChoiceMap, if a user traces to a tupled address like ("a", "b"), then the resulting StaticTrace will store a sub-trace at this address rather than flattening it out. As a result, tr.get_inner_trace(("a", "b")) does not equal tr.get_inner_trace("a").get_inner_trace("b").
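A toy model of this storage scheme, keyed by full tuple addresses, shows why the two lookups differ. ToyStaticTrace is hypothetical. It only illustrates the note above and is not GenJAX's StaticTrace.

```python
class ToyStaticTrace:
    """Hypothetical trace that stores sub-traces under full (possibly tupled) addresses."""

    def __init__(self, subtraces):
        # e.g. {("a", "b"): ...} rather than {"a": {"b": ...}}
        self.subtraces = subtraces

    def get_inner_trace(self, address):
        return self.subtraces[address]


tr = ToyStaticTrace({("a", "b"): "subtrace traced at ('a', 'b')"})
print(tr.get_inner_trace(("a", "b")))  # subtrace traced at ('a', 'b')
print("a" in tr.subtraces)  # False: there is no intermediate trace at "a"
```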
Source code in src/genjax/_src/core/generative/generative_function.py
get_retval
abstractmethod
Returns the R from the GenerativeFunction invocation which created the Trace.
get_score
abstractmethod
get_score() -> Score
Return the Score of the Trace.
The score must satisfy a particular mathematical specification: it's either an exact density evaluation of \(P\) (the distribution over samples) for the sample returned by genjax.Trace.get_choices, or a sample from an estimator (a density estimate) if the generative function contains untraced randomness.
Let \(s\) be the score, \(t\) the sample, and \(a\) the arguments: when the generative function contains no untraced randomness, the score (in logspace) is given by:
\[
\log s := \log P(t; a)
\]
(With untraced randomness) Gen allows for the possibility of sources of randomness which are not traced. When these sources are included in generative computations, the score is defined so that the following property holds:
\[
\mathbb{E}_{r \sim P(r \mid t; a)}\left[\frac{1}{s}\right] = \frac{1}{P(t; a)}
\]
This property is the one you'd want to be true if you were using a generative function with untraced randomness as a proposal in a routine which uses importance sampling, for instance.
In GenJAX, one way you might encounter this is by using pseudo-random routines in your modeling code:
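import jax
import genjax

# notice how the key is explicit
@genjax.gen
def model_with_untraced_randomness(key: jax.Array):
    x = genjax.normal(0.0, 1.0) @ "x"
    # `some_random_process` stands in for any pseudo-random routine that
    # consumes `key`; its randomness is not traced.
    v = some_random_process(key, x)
    y = genjax.normal(v, 1.0) @ "y"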
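In this case, the score (in logspace) is given by:
\[
\log s := \log P(r, t; a) - \log Q(r; t, a)
\]
where \(Q(r; t, a)\) is the (normalized) density of the untraced randomness \(r\) produced by some_random_process. This satisfies the requirement by virtue of the fact:
\[
\begin{aligned}
\mathbb{E}_{r \sim P(r \mid t; a)}\left[\frac{1}{s}\right]
  &= \mathbb{E}_{r \sim P(r \mid t; a)}\left[\frac{Q(r; t, a)}{P(r, t; a)}\right] \\
  &= \frac{1}{P(t; a)} \mathbb{E}_{r \sim P(r \mid t; a)}\left[\frac{Q(r; t, a)}{P(r \mid t; a)}\right] \\
  &= \frac{1}{P(t; a)} \int Q(r; t, a)\, dr = \frac{1}{P(t; a)}
\end{aligned}
\]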
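The identity \(\mathbb{E}_{r \sim P(r \mid t; a)}[1/s] = 1/P(t; a)\) can be checked exactly on a small discrete example. Everything here is an assumption chosen for illustration: the joint table P, the density Q, and the helper name are hypothetical. Fraction keeps the arithmetic exact.

```python
from fractions import Fraction as F

# Hypothetical joint distribution P(r, t) over untraced r in {0, 1, 2} and sample t in {0, 1}.
P = {
    (0, 0): F(1, 10), (1, 0): F(2, 10), (2, 0): F(1, 10),
    (0, 1): F(3, 10), (1, 1): F(1, 10), (2, 1): F(2, 10),
}
# Any normalized density Q over r works; here it is deliberately unrelated to P.
Q = {0: F(1, 2), 1: F(1, 4), 2: F(1, 4)}


def expected_inverse_score(t):
    p_t = sum(P[(r, t)] for r in range(3))  # marginal P(t)
    total = F(0)
    for r in range(3):
        p_r_given_t = P[(r, t)] / p_t  # P(r | t)
        inv_s = Q[r] / P[(r, t)]       # 1/s = Q(r) / P(r, t)
        total += p_r_given_t * inv_s
    return total, p_t


for t in (0, 1):
    lhs, p_t = expected_inverse_score(t)
    print(t, lhs, 1 / p_t)  # prints "0 5/2 5/2", then "1 5/3 5/3"
```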
Source code in src/genjax/_src/core/generative/generative_function.py
get_subtrace
get_subtrace(*addresses: Address) -> Trace[Any]
Return the subtrace having the supplied address. Specifying multiple addresses will apply the operation recursively.
GenJAX does not guarantee the validity of any inference computations performed using information from the returned subtrace. In other words, it is safe to inspect the data of subtraces, but it is not safe to use that data to make decisions about inference. This is true of all the methods on the subtrace, including Trace.get_args, Trace.get_score, Trace.get_retval, etc. It is safe to look, but don't use the data for non-trivial things!
Source code in src/genjax/_src/core/generative/generative_function.py
update
update(
key: PRNGKey,
constraint: ChoiceMap,
argdiffs: tuple[Any, ...] | None = None,
) -> tuple[Self, Weight, Retdiff[R], ChoiceMap]
This method calls out to the underlying GenerativeFunction.edit method; see EditRequest and edit for more information.
Source code in src/genjax/_src/core/generative/generative_function.py
genjax.core.EditRequest
Bases: Pytree
An EditRequest is a request to edit a trace of a generative function. Generative functions respond to instances of subtypes of EditRequest by providing an edit implementation.
Updating a trace is a common operation in inference processes, but naively mutating the trace will invalidate the mathematical invariants that Gen retains. EditRequest instances denote requests for SMC moves in the framework of SMCP3, which preserve these invariants.
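Here is a deliberately simplified picture of the invariant at stake. It uses a single Bernoulli choice with no proposal randomness, where the score must equal the log density of the stored choice. Everything in the sketch (the dict-based trace, log_bernoulli, edit_set_choice) is hypothetical and much simpler than GenJAX's edit protocol. It only contrasts naive mutation with an edit that recomputes the score and reports an incremental importance weight.

```python
import math


def log_bernoulli(x, p):
    # Log probability of a Bernoulli(p) outcome.
    return math.log(p) if x else math.log(1.0 - p)


# Hypothetical single-choice trace for x ~ Bernoulli(0.3).
trace = {"choices": {"x": True}, "score": log_bernoulli(True, 0.3)}

# Naive mutation: the stored score no longer matches the stored choice.
naive = {"choices": {"x": False}, "score": trace["score"]}
assert naive["score"] != log_bernoulli(naive["choices"]["x"], 0.3)


def edit_set_choice(trace, x_new, p=0.3):
    """Hypothetical edit that constrains x to x_new.

    The score is recomputed, and the returned weight is the log ratio
    log P(t'; a) - log P(t; a), the incremental importance weight for this
    fully constrained, deterministic move.
    """
    new_score = log_bernoulli(x_new, p)
    new_trace = {"choices": {"x": x_new}, "score": new_score}
    return new_trace, new_score - trace["score"]


new_trace, weight = edit_set_choice(trace, False)
print(new_trace["score"], weight)
```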
Source code in src/genjax/_src/core/generative/concepts.py
Generative functions with addressed random choices
Generative functions will often include addressed random choices. These are random choices which are given a name via an addressing syntax, and can be accessed by name via extended interfaces on the ChoiceMap type which supports the addressing.
genjax.core.ChoiceMap
Bases: Pytree
The type ChoiceMap denotes a map-like value which can be sampled from generative functions.
Generative functions which utilize ChoiceMap as their sample representation typically support a notion of addressing for the random choices they make. ChoiceMap stores addressed random choices, and provides a data language for querying and manipulating these choices.
Examples:
(Making choice maps) Choice maps can be constructed using the ChoiceMapBuilder interface.
(Getting submaps) Hierarchical choice maps support __call__, which allows for the retrieval of submaps at addresses:
from genjax import ChoiceMapBuilder as C
chm = C["x", "y"].set(3.0)
submap = chm("x")
print(submap.render_html())
(Getting values) Choice maps support __getitem__, which allows for the retrieval of values at addresses:
from genjax import ChoiceMapBuilder as C
chm = C["x", "y"].set(3.0)
value = chm["x", "y"]
print(value)
(Making vectorized choice maps) Choice maps can be constructed using jax.vmap:
from genjax import ChoiceMapBuilder as C
from jax import vmap
import jax.numpy as jnp
vec_chm = vmap(lambda idx, v: C["x", idx].set(v))(jnp.arange(10), jnp.ones(10))
print(vec_chm.render_html())
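Nested dictionaries can mimic the difference between __call__ (which returns submaps) and __getitem__ (which returns values). ToyChoiceMap below is a hypothetical illustration of hierarchical addressing only. GenJAX's ChoiceMap additionally supports vectorization, masking, and selections.

```python
class ToyChoiceMap:
    """Hypothetical nested-dict model of hierarchical addressing."""

    def __init__(self, tree):
        self.tree = tree

    def __call__(self, *addrs):
        # Like chm("x"): return the submap under an address prefix.
        node = self.tree
        for a in addrs:
            node = node[a]
        return ToyChoiceMap(node)

    def __getitem__(self, addr):
        # Like chm["x", "y"]: return the value stored at a full address.
        addrs = addr if isinstance(addr, tuple) else (addr,)
        node = self.tree
        for a in addrs:
            node = node[a]
        if isinstance(node, dict):
            raise KeyError(f"{addr!r} is a submap, not a value")
        return node


chm = ToyChoiceMap({"x": {"y": 3.0}})
print(chm("x").tree)  # {'y': 3.0}
print(chm["x", "y"])  # 3.0
```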
Methods:
Name | Description |
---|---|
__call__ | Alias for submap retrieval: returns the submap at the given address. |
choice | Creates a ChoiceMap containing a single value. |
d | Creates a ChoiceMap from a dictionary. |
empty | Returns a ChoiceMap with no values or submaps. |
entry | Creates a ChoiceMap with a single value at a specified address. |
extend | Returns a new ChoiceMap with the given address component as its root. |
filter | Filter the choice map on the Selection. |
from_mapping | Creates a ChoiceMap from an iterable of address-value pairs. |
get_selection | Returns a Selection representing the structure of this ChoiceMap. |
invalid_subset | Identifies the subset of choices that are invalid for a given generative function and its arguments. |
kw | Creates a ChoiceMap from keyword arguments. |
mask | Returns a new ChoiceMap with values masked by a boolean flag. |
merge | Merges this ChoiceMap with another ChoiceMap. |
simplify | Previously pushed down filters, now acts as identity. |
static_is_empty | Returns True if this ChoiceMap is equal to ChoiceMap.empty(), False otherwise. |
switch | Creates a ChoiceMap that switches between multiple ChoiceMaps based on an index. |
Attributes:
Name | Type | Description |
---|---|---|
at | _ChoiceMapBuilder | Returns a _ChoiceMapBuilder instance for constructing nested ChoiceMaps. |
Source code in src/genjax/_src/core/generative/choice_map.py
at
property
Returns a _ChoiceMapBuilder instance for constructing nested ChoiceMaps.
This property allows for a fluent interface to build complex ChoiceMaps by chaining address components and setting values.
Returns:
Type | Description |
---|---|
_ChoiceMapBuilder | A builder object for constructing ChoiceMaps. |
choice
staticmethod
Creates a ChoiceMap containing a single value.
This method creates and returns an instance of Choice, which represents a ChoiceMap with a single value at the root level.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Any | The value to be stored in the ChoiceMap. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A ChoiceMap containing the single value. |
Example
Source code in src/genjax/_src/core/generative/choice_map.py
d
staticmethod
Creates a ChoiceMap from a dictionary.
This method creates and returns a ChoiceMap based on the key-value pairs in the provided dictionary. Each key in the dictionary becomes an address in the ChoiceMap, and the corresponding value is stored at that address.
Dict-shaped values are recursively converted to ChoiceMap instances.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | dict[K_addr, Any] | A dictionary where keys are addresses and values are the corresponding data to be stored in the ChoiceMap. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A ChoiceMap containing the key-value pairs from the input dictionary. |
Example
from genjax import ChoiceMap
dict_chm = ChoiceMap.d({"x": 42, "y": {"z": [1, 2, 3]}})
assert dict_chm["x"] == 42
assert dict_chm["y", "z"] == [1, 2, 3]
Source code in src/genjax/_src/core/generative/choice_map.py
empty
staticmethod
empty() -> ChoiceMap
Returns a ChoiceMap with no values or submaps.
Returns:
Type | Description |
---|---|
ChoiceMap | An empty ChoiceMap. |
entry
staticmethod
Creates a ChoiceMap with a single value at a specified address.
This method creates and returns a ChoiceMap with a new ChoiceMap stored at the given address.
- if the provided value is already a ChoiceMap, it will be used directly;
- dict values will be passed to ChoiceMap.d;
- any other value will be passed to ChoiceMap.value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | dict[K_addr, Any] \| ChoiceMap \| Any | The value to be stored in the ChoiceMap. Can be any value, a dict or a ChoiceMap. | required |
| | AddressComponent | The address at which to store the value. Can be a static or dynamic address component. | () |
Returns:
Type | Description |
---|---|
ChoiceMap | A ChoiceMap with the value stored at the specified address. |
Example
import genjax
import jax.numpy as jnp
from genjax import ChoiceMap

# Using an existing ChoiceMap
nested_chm = ChoiceMap.entry(ChoiceMap.value(42), "x")
assert nested_chm["x"] == 42

# Using a dict generates a new `ChoiceMap.d` call
nested_chm = ChoiceMap.entry({"y": 42}, "x")
assert nested_chm["x", "y"] == 42

# Static address
static_chm = ChoiceMap.entry(42, "x")
assert static_chm["x"] == 42

# Dynamic address
dynamic_chm = ChoiceMap.entry(
    jnp.array([1.1, 2.2, 3.3]), jnp.array([1, 2, 3])
)
assert dynamic_chm[1] == genjax.Mask(1.1, True)
Source code in src/genjax/_src/core/generative/choice_map.py
extend
extend(*addrs: AddressComponent) -> ChoiceMap
Returns a new ChoiceMap with the given address component as its root.
This method creates a new ChoiceMap where the current ChoiceMap becomes a submap under the specified address component. It effectively adds a new level of hierarchy to the ChoiceMap structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*addrs | AddressComponent | The address components to use as the new root. | () |
Returns:
Type | Description |
---|---|
ChoiceMap | A new ChoiceMap with the current ChoiceMap nested under the given address. |
Example
from genjax import ChoiceMap

original_chm = ChoiceMap.value(42)
indexed_chm = original_chm.extend("x")
assert indexed_chm["x"] == 42
Source code in src/genjax/_src/core/generative/choice_map.py
filter
abstractmethod
Filter the choice map on the Selection. The resulting choice map only contains the addresses that return True when presented to the selection.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Selection \| Flag | The Selection to filter the choice map with. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A new ChoiceMap containing only the addresses selected by the given Selection. |
Examples:
import jax
import genjax
from genjax import bernoulli
from genjax import SelectionBuilder as S

@genjax.gen
def model():
    x = bernoulli(0.3) @ "x"
    y = bernoulli(0.3) @ "y"
    return x

key = jax.random.key(314159)
tr = model.simulate(key, ())
chm = tr.get_choices()
selection = S["x"]
filtered = chm.filter(selection)
assert "y" not in filtered
Source code in src/genjax/_src/core/generative/choice_map.py
from_mapping
staticmethod
Creates a ChoiceMap from an iterable of address-value pairs.
This method constructs a ChoiceMap by iterating through the provided pairs, where each pair consists of an address (or address component) and a corresponding value. The resulting ChoiceMap will contain all the values at their respective addresses.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
| | Iterable[tuple[K_addr, Any]] | An iterable of tuples, where each tuple contains an address (or address component) and its corresponding value. The address can be a single component or a tuple of components. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A ChoiceMap containing all the address-value pairs from the input. |
Example
Note
If multiple pairs have the same address, later pairs will overwrite earlier ones.
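Nested dictionaries standing in for choice maps can illustrate the overwrite rule. from_mapping_sketch is hypothetical. It assumes that a bare address means a one-component address and that no address is a strict prefix of another.

```python
def from_mapping_sketch(pairs):
    """Hypothetical model of ChoiceMap.from_mapping using nested dicts."""
    root = {}
    for addr, value in pairs:
        path = addr if isinstance(addr, tuple) else (addr,)
        node = root
        for component in path[:-1]:
            node = node.setdefault(component, {})
        node[path[-1]] = value  # later pairs overwrite earlier ones
    return root


print(from_mapping_sketch([("x", 1), (("y", "z"), 2), ("x", 3)]))
# {'x': 3, 'y': {'z': 2}}
```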
Source code in src/genjax/_src/core/generative/choice_map.py
get_selection
get_selection() -> Selection
Returns a Selection representing the structure of this ChoiceMap.
This method creates a Selection that matches the hierarchical structure of the current ChoiceMap. The resulting Selection can be used to filter or query other ChoiceMaps with the same structure.
Returns:
Type | Description |
---|---|
Selection | A Selection object representing the structure of this ChoiceMap. |
Example
from genjax import ChoiceMap

chm = ChoiceMap.value(5).extend("x")
sel = chm.get_selection()
assert sel["x"] == True
assert sel["y"] == False
Source code in src/genjax/_src/core/generative/choice_map.py
invalid_subset
invalid_subset(
gen_fn: GenerativeFunction[Any], args: tuple[Any, ...]
) -> ChoiceMap | None
Identifies the subset of choices that are invalid for a given generative function and its arguments.
This method checks if all choices in the current ChoiceMap are valid for the given generative function and its arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gen_fn | GenerativeFunction[Any] | The generative function to check against. | required |
args | tuple[Any, ...] | The arguments to the generative function. | required |
Returns:
Type | Description |
---|---|
ChoiceMap \| None | A ChoiceMap containing any extra choices not reachable in the course of executing gen_fn on args, or None if there are no such choices. |
Example
import genjax
from genjax import ChoiceMap, bernoulli

@genjax.gen
def model(x):
    y = bernoulli(0.5) @ "y"
    return x + y

chm = ChoiceMap.d({"y": 1, "z": 2})
extras = chm.invalid_subset(model, (1,))
assert "z" in extras  # "z" is an extra choice not in the model
Source code in src/genjax/_src/core/generative/choice_map.py
kw
staticmethod
kw(**kwargs) -> ChoiceMap
Creates a ChoiceMap from keyword arguments.
This method creates and returns a ChoiceMap based on the provided keyword arguments. Each keyword argument becomes an address in the ChoiceMap, and its value is stored at that address.
Dict-shaped values are recursively converted to ChoiceMap instances with calls to ChoiceMap.d.
Returns:
Type | Description |
---|---|
ChoiceMap | A ChoiceMap containing the key-value pairs from the input keyword arguments. |
Example
from genjax import ChoiceMap

kw_chm = ChoiceMap.kw(x=42, y=[1, 2, 3], z={"w": 10.0})
assert kw_chm["x"] == 42
assert kw_chm["y"] == [1, 2, 3]
assert kw_chm["z", "w"] == 10.0
Source code in src/genjax/_src/core/generative/choice_map.py
mask
Returns a new ChoiceMap with values masked by a boolean flag.
This method creates a new ChoiceMap where the values are conditionally included based on the provided flag. If the flag is True, the original values are retained; if False, the ChoiceMap behaves as if it's empty.
Parameters:
Type | Description | Default |
---|---|---|
Flag | A boolean flag determining whether to include the values. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A new ChoiceMap with values conditionally masked. |
Example
from genjax import ChoiceMap
original_chm = ChoiceMap.value(42)
masked_chm = original_chm.mask(True)
assert masked_chm.get_value() == 42
masked_chm = original_chm.mask(False)
assert masked_chm.get_value() is None
Source code in src/genjax/_src/core/generative/choice_map.py
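When the flag is a concrete Python bool, as in the example above, masking can be resolved immediately. The sketch below covers only that case. ToyValueMap is a hypothetical name, and the sketch is not GenJAX code. A flag that is only known at runtime could not be resolved this way and would have to travel with the value, as Mask does.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class ToyValueMap:
    """Hypothetical single-value choice map (not GenJAX code)."""

    value: Optional[Any]

    def mask(self, flag: bool) -> "ToyValueMap":
        # A concrete bool is decided on the spot: True keeps the value,
        # False leaves an empty map behind.
        return self if flag else ToyValueMap(None)

    def get_value(self) -> Optional[Any]:
        return self.value


original = ToyValueMap(42)
assert original.mask(True).get_value() == 42
assert original.mask(False).get_value() is None
```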
merge
¶
Merges this ChoiceMap with another ChoiceMap.
This method combines the current ChoiceMap with another ChoiceMap using the XOR operation (^). It creates a new ChoiceMap that contains all addresses from both input ChoiceMaps; any overlapping addresses will trigger an error on access at the address via [<addr>] or get_value(). Use | if you don't want this behavior.
Parameters:
Type | Description | Default |
---|---|---|
ChoiceMap | The ChoiceMap to merge with the current one. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A new ChoiceMap resulting from the merge operation. |
Example
from genjax import ChoiceMap
chm1 = ChoiceMap.value(5).extend("x")
chm2 = ChoiceMap.value(10).extend("y")
merged_chm = chm1.merge(chm2)
assert merged_chm["x"] == 5
assert merged_chm["y"] == 10
Note
This method is equivalent to using the ^ operator between two ChoiceMaps.
Source code in src/genjax/_src/core/generative/choice_map.py
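The XOR behavior can be modelled lazily in plain Python: merging always succeeds, but reading an address defined on both sides is an error. ToyXorMerge is hypothetical and uses flat string-keyed dicts. It raises KeyError only as an arbitrary choice of exception, and it does not model the | alternative.

```python
from typing import Any


class ToyXorMerge:
    """Hypothetical sketch of XOR-style merge semantics (not GenJAX code).

    Both operands are kept; an address only causes an error when it is read
    while present in both maps.
    """

    def __init__(self, left: dict[str, Any], right: dict[str, Any]) -> None:
        self.left, self.right = left, right

    def __getitem__(self, addr: str) -> Any:
        in_left, in_right = addr in self.left, addr in self.right
        if in_left and in_right:
            raise KeyError(f"address {addr!r} is present in both maps")
        if in_left:
            return self.left[addr]
        return self.right[addr]  # raises KeyError if neither map has addr


merged = ToyXorMerge({"x": 5}, {"y": 10})
assert merged["x"] == 5 and merged["y"] == 10

clash = ToyXorMerge({"x": 5}, {"x": 6})  # building the merge is fine...
try:
    clash["x"]  # ...but reading the shared address fails
except KeyError as err:
    print("conflict:", err)
```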
simplify
¶
simplify() -> ChoiceMap
Previously pushed filters down through the ChoiceMap; it now acts as the identity, returning the ChoiceMap unchanged.
Source code in src/genjax/_src/core/generative/choice_map.py
static_is_empty
¶
Returns True if this ChoiceMap is equal to ChoiceMap.empty(), False otherwise.
switch
staticmethod
¶
Creates a ChoiceMap that switches between multiple ChoiceMaps based on an index.
This method creates a new ChoiceMap that selectively includes values from a sequence of input ChoiceMaps based on the provided index. The resulting ChoiceMap will contain values from the ChoiceMap at the position specified by the index, while masking out values from all other ChoiceMaps.
Parameters:
Type | Description | Default |
---|---|---|
int \| IntArray | An index or array of indices specifying which ChoiceMap(s) to select from. | required |
Iterable[ChoiceMap] | An iterable of ChoiceMaps to switch between. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A new ChoiceMap containing values from the selected ChoiceMap(s). |
Example
import jax.numpy as jnp
from genjax import ChoiceMap
chm1 = ChoiceMap.d({"x": 1, "y": 2})
chm2 = ChoiceMap.d({"x": 3, "y": 4})
chm3 = ChoiceMap.d({"x": 5, "y": 6})
switched = ChoiceMap.switch(jnp.array(1), [chm1, chm2, chm3])
assert switched["x"].unmask() == 3
assert switched["y"].unmask() == 4
Source code in src/genjax/_src/core/generative/choice_map.py
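The masking behavior of switch can be pictured in plain Python. Every input map is kept, and each leaf is tagged with a flag that is true only for the selected position. In JAX the index may be a runtime value (an IntArray), so presumably all branches must be retained and masked rather than discarded. toy_switch and toy_unmask are hypothetical names, and the (flag, value) pairs stand in for Mask.

```python
from typing import Any, Sequence


def toy_switch(
    idx: int, chms: Sequence[dict[str, Any]]
) -> dict[str, list[tuple[bool, Any]]]:
    """Hypothetical sketch of switch: keep every branch, flag only chms[idx]."""
    out: dict[str, list[tuple[bool, Any]]] = {}
    for i, chm in enumerate(chms):
        for addr, value in chm.items():
            # Each (flag, value) pair plays the role of a Mask.
            out.setdefault(addr, []).append((i == idx, value))
    return out


def toy_unmask(branches: list[tuple[bool, Any]]) -> Any:
    # Return the value from the branch whose flag is True.
    for flag, value in branches:
        if flag:
            return value
    raise ValueError("no branch is valid at this address")


chm1, chm2, chm3 = {"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 5, "y": 6}
switched = toy_switch(1, [chm1, chm2, chm3])
assert toy_unmask(switched["x"]) == 3
assert toy_unmask(switched["y"]) == 4
```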
genjax.core.Selection
¶
Bases: Pytree
A class representing a selection of addresses in a ChoiceMap.
Selection objects are used to filter and manipulate ChoiceMaps by specifying which addresses should be included or excluded.
Selection instances support various operations such as union (via |), intersection (via &), and complement (via ~), allowing complex selection criteria to be constructed.
Methods:
Name | Description |
---|---|
all | Creates a Selection that includes all addresses. |
none | Creates a Selection that includes no addresses. |
at | A builder instance for creating Selection objects using indexing syntax. |
Examples:
Creating selections:
import genjax
from genjax import Selection
# Select all addresses
all_sel = Selection.all()
# Select no addresses
none_sel = Selection.none()
# Select specific addresses
specific_sel = Selection.at["x", "y"]
# Match (<wildcard>, "y")
wildcard_sel = Selection.at[..., "y"]
# Combine selections
combined_sel = specific_sel | Selection.at["z"]
Querying selections:
# Create a selection
sel = Selection.at["x", "y"]
# Querying the selection using () returns a sub-selection
assert sel("x") == Selection.at["y"]
assert sel("z") == Selection.none()
# Querying the selection using [] returns a `bool` representing whether or not the input matches:
assert sel["x"] == False
assert sel["x", "y"] == True
# Querying the selection using "in" acts the same:
assert "x" not in sel
assert ("x", "y") in sel
# Nested querying
nested_sel = Selection.at["a", "b", "c"]
assert nested_sel("a")("b") == Selection.at["c"]
Selection objects can be passed to a ChoiceMap via the filter method to filter and manipulate data based on address patterns.
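The set algebra on selections can be modelled by treating a selection as a predicate over address tuples. Under that model, | is logical or (union), & is logical and (intersection), and ~ is negation (complement). This is a hypothetical stdlib sketch; ToySelection and toy_at are not GenJAX names. It ignores wildcards and the sub-selection query sel(...).

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

Address = tuple[Any, ...]


@dataclass(frozen=True)
class ToySelection:
    """Hypothetical model of a selection as a predicate on addresses."""

    matches: Callable[[Address], bool]

    def __or__(self, other: ToySelection) -> ToySelection:  # union
        return ToySelection(lambda a: self.matches(a) or other.matches(a))

    def __and__(self, other: ToySelection) -> ToySelection:  # intersection
        return ToySelection(lambda a: self.matches(a) and other.matches(a))

    def __invert__(self) -> ToySelection:  # complement
        return ToySelection(lambda a: not self.matches(a))

    def __getitem__(self, addr: Any) -> bool:
        # sel["x", "y"] arrives as the tuple ("x", "y"); sel["x"] as "x".
        return self.matches(addr if isinstance(addr, tuple) else (addr,))


def toy_at(*prefix: Any) -> ToySelection:
    # Select every address that starts with `prefix`.
    return ToySelection(lambda a: a[: len(prefix)] == prefix)


xy, z = toy_at("x", "y"), toy_at("z")
assert (xy | z)["z"] and (xy | z)["x", "y"]
assert not (xy & z)["z"]
assert (~xy)["x"] and not (~xy)["x", "y"]
assert xy["x", "y"] and not xy["x"]
```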
Attributes:
Name | Type | Description |
---|---|---|
at | Final[_SelectionBuilder] | A builder instance for creating Selection objects. |
Source code in src/genjax/_src/core/generative/choice_map.py
at
class-attribute
instance-attribute
¶
A builder instance for creating Selection objects.
at provides a convenient interface for constructing Selection objects using a familiar indexing syntax. It allows for the creation of complex selections by chaining multiple address components.
Examples:
Creating a selection:
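Indexing syntax like at["x", "y"] works because Python passes the comma-separated components to __getitem__ as a single tuple. The sketch below shows that mechanism with hypothetical names (ToySelectionBuilder, ToyPrefixSelection). It is not GenJAX's _SelectionBuilder, and it models only prefix matching.

```python
from typing import Any


class ToyPrefixSelection:
    """Hypothetical selection matching every address that starts with a prefix."""

    def __init__(self, prefix: tuple[Any, ...]) -> None:
        self.prefix = prefix

    def __contains__(self, addr: Any) -> bool:
        addr = addr if isinstance(addr, tuple) else (addr,)
        return addr[: len(self.prefix)] == self.prefix


class ToySelectionBuilder:
    """Hypothetical stand-in for an at-style builder: indexing builds a selection."""

    def __getitem__(self, addr: Any) -> ToyPrefixSelection:
        # builder["x", "y"] arrives here as the tuple ("x", "y"); builder["x"] as "x".
        return ToyPrefixSelection(addr if isinstance(addr, tuple) else (addr,))


at = ToySelectionBuilder()
sel = at["x", "y"]
assert ("x", "y") in sel
assert "x" not in sel
```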
all
staticmethod
¶
all() -> Selection
Returns a Selection that selects all addresses.
Returns:
Type | Description |
---|---|
Selection | A Selection that selects everything. |
Example
from genjax import Selection
all_selection = Selection.all()
assert all_selection["any_address"] == True
Source code in src/genjax/_src/core/generative/choice_map.py
extend
¶
extend(*addrs: ExtendedStaticAddressComponent) -> Selection
Returns a new Selection that is prefixed by the given address components.
This method creates a new Selection that nests the current selection under the specified address components. It handles both static and dynamic address components.
Note that ... (Python's Ellipsis) as an address component will match any supplied address component at that position.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*addrs | ExtendedStaticAddressComponent | The address components under which to nest the selection. | () |
Returns:
Type | Description |
---|---|
Selection | A new Selection extended by the given address components. |
Example
from genjax import Selection
base_selection = Selection.all()
indexed_selection = base_selection.extend("x")
assert indexed_selection["x", "any_subaddress"] == True
assert indexed_selection["y"] == False
Source code in src/genjax/_src/core/generative/choice_map.py
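Nesting a selection under a prefix, with ... acting as a wildcard component, can be sketched with predicates over address tuples. toy_extend and the predicate encoding are hypothetical, not GenJAX internals.

```python
from typing import Any, Callable

Address = tuple[Any, ...]
Predicate = Callable[[Address], bool]


def toy_extend(inner: Predicate, *prefix: Any) -> Predicate:
    """Hypothetical sketch: nest a selection predicate under `prefix`.

    `...` (Ellipsis) in the prefix matches any single address component.
    """

    def matches(addr: Address) -> bool:
        if len(addr) < len(prefix):
            return False
        head, rest = addr[: len(prefix)], addr[len(prefix):]
        if not all(p is ... or p == a for p, a in zip(prefix, head)):
            return False
        return inner(rest)

    return matches


def select_all(addr: Address) -> bool:
    return True


sel = toy_extend(select_all, "x")
assert sel(("x", "any_subaddress"))
assert not sel(("y",))

wild = toy_extend(select_all, ..., "y")
assert wild(("a", "y")) and wild(("b", "y"))
assert not wild(("a", "z"))
```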
filter
¶
Returns a new ChoiceMap filtered with this Selection.
This method applies the current Selection to the given ChoiceMap, effectively filtering out addresses that are not matched.
Parameters:
Type | Description | Default |
---|---|---|
ChoiceMap | The ChoiceMap to be filtered. | required |
Returns:
Type | Description |
---|---|
ChoiceMap | A new ChoiceMap containing only the addresses selected by this Selection. |
Example
from genjax import ChoiceMap, Selection
selection = Selection.at["x"]
chm = ChoiceMap.kw(x=1, y=2)
filtered_chm = selection.filter(chm)
assert "x" in filtered_chm
assert "y" not in filtered_chm
Source code in src/genjax/_src/core/generative/choice_map.py
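With concrete Python data, filtering reduces to keeping the entries whose address the selection accepts. This is a hypothetical sketch; toy_filter and select_x are illustrative names. It covers only the case where every match is statically known.

```python
from typing import Any, Callable

Address = tuple[Any, ...]


def toy_filter(
    matches: Callable[[Address], bool], chm: dict[Address, Any]
) -> dict[Address, Any]:
    """Hypothetical sketch: keep only the choices whose address is selected."""
    return {addr: value for addr, value in chm.items() if matches(addr)}


def select_x(addr: Address) -> bool:
    # Plays the role of Selection.at["x"]: any address starting with "x".
    return addr[:1] == ("x",)


filtered = toy_filter(select_x, {("x",): 1, ("y",): 2})
assert ("x",) in filtered
assert ("y",) not in filtered
```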
leaf
staticmethod
¶
leaf() -> Selection
Returns a Selection that selects only leaf addresses.
A leaf address is an address that doesn't have any sub-addresses. This selection is useful when you want to target only the final elements in a nested structure.
Returns:
Type | Description |
---|---|
Selection | A Selection that selects only leaf addresses. |
Example
from genjax import Selection
leaf_selection = Selection.leaf().extend("a", "b")
assert leaf_selection["a", "b"]
assert not leaf_selection["a", "b", "anything"]
Source code in src/genjax/_src/core/generative/choice_map.py
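One way to read Selection.leaf().extend("a", "b") is as a predicate that strips the prefix and then accepts only an empty remainder. This is a hypothetical sketch of that reading; toy_leaf and toy_nest are illustrative names.

```python
from typing import Any, Callable

Address = tuple[Any, ...]
Predicate = Callable[[Address], bool]


def toy_leaf(addr: Address) -> bool:
    # Relative to where it is nested, a leaf selection matches only the empty
    # remaining address, i.e. nothing deeper.
    return len(addr) == 0


def toy_nest(inner: Predicate, *prefix: Any) -> Predicate:
    # Match addresses that start with `prefix` and whose remainder `inner` accepts.
    return lambda addr: addr[: len(prefix)] == prefix and inner(addr[len(prefix):])


leaf_selection = toy_nest(toy_leaf, "a", "b")
assert leaf_selection(("a", "b"))
assert not leaf_selection(("a", "b", "anything"))
assert not leaf_selection(("a",))
```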
none
staticmethod
¶
none() -> Selection
Returns a Selection that selects no addresses.
Returns:
Type | Description |
---|---|
Selection
|
A Selection that selects nothing. |
Source code in src/genjax/_src/core/generative/choice_map.py
JAX-compatible data via Pytree
¶
JAX natively works with arrays, and with instances of Python classes which can be broken down into lists of arrays. JAX's Pytree system provides a way to register a class with methods that can break instances of the class down into a list of arrays (canonically referred to as flattening), and build an instance back up given a list of arrays (canonically referred to as unflattening).
GenJAX provides an abstract class called Pytree which automates the implementation of the flatten / unflatten methods for a class. GenJAX's Pytree inherits from penzai.Struct to support pretty printing. It also provides convenience methods to annotate which data should be part of the Pytree type (static fields, which are not broken down into JAX arrays) and which data should be considered dynamic.
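The flatten / unflatten split between static and dynamic data can be imitated with the standard dataclasses module. This is a hypothetical sketch, not how GenJAX or JAX implement it. The static helper, Point, flatten and unflatten are illustrative names. JAX's real structure object is a PyTreeDef rather than the plain tuple used here.

```python
from dataclasses import dataclass, field, fields
from typing import Any


def static(**kwargs: Any) -> Any:
    # Hypothetical analogue of Pytree.static(...): tag the field in its metadata.
    return field(metadata={"static": True}, **kwargs)


@dataclass
class Point:
    label: str = static()  # static: becomes part of the structure
    x: float = 0.0         # dynamic: becomes a leaf
    y: float = 0.0         # dynamic: becomes a leaf


def flatten(obj: Any) -> tuple[list[Any], tuple[type, tuple[tuple[str, Any], ...]]]:
    """Split an instance into dynamic leaves and static structure."""
    leaves, static_items = [], []
    for f in fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("static"):
            static_items.append((f.name, value))
        else:
            leaves.append(value)
    return leaves, (type(obj), tuple(static_items))


def unflatten(treedef: tuple[type, tuple[tuple[str, Any], ...]], leaves: list[Any]) -> Any:
    """Rebuild an instance from static structure plus (possibly new) leaves."""
    cls, static_items = treedef
    dynamic_names = [f.name for f in fields(cls) if not f.metadata.get("static")]
    return cls(**dict(static_items), **dict(zip(dynamic_names, leaves)))


p = Point("origin", 1.0, 2.0)
leaves, treedef = flatten(p)
print(leaves)  # [1.0, 2.0]
assert unflatten(treedef, [3.0, 4.0]) == Point("origin", 3.0, 4.0)
```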
genjax.core.Pytree
¶
Bases: Struct
Pytree is an abstract base class which registers a class with JAX's Pytree system. JAX's Pytree system tracks how data classes should behave across JAX-transformed function boundaries, like jax.jit or jax.vmap.
Inheriting this class provides the implementor with the freedom to declare how the subfields of a class should behave:
- Pytree.static(...): the value of the field cannot be a JAX traced value; it must be a Python literal or a constant. The values of static fields are embedded in the PyTreeDef of any instance of the class.
- Pytree.field(...) or no annotation: the value may be a JAX traced value, and JAX will attempt to convert it to tracer values inside of its transformations.
If a field points to another Pytree, it should not be declared as Pytree.static(), as the Pytree interface will automatically handle Pytree fields as dynamic fields.
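Methods:
Name | Description |
---|---|
dataclass | Denote that a class (which is inheriting Pytree) should be treated as a dataclass. |
static | Declare a field of a Pytree dataclass to be static. |
field | Declare a field of a Pytree dataclass to be dynamic. |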
Source code in src/genjax/_src/core/pytree.py
dataclass
staticmethod
¶
Denote that a class (which is inheriting Pytree) should be treated as a dataclass, meaning it can hold data in fields which are declared as part of the class.
A dataclass is to be distinguished from a "methods only" Pytree class, which does not have fields, but may define methods. The latter cannot be instantiated, but can be inherited from, while the former can be instantiated. The Pytree.dataclass declaration informs the system how to instantiate the class as a dataclass. It also tells the system how to automatically define JAX's Pytree interfaces (tree_flatten, tree_unflatten, etc.) for the dataclass, based on the fields declared in the class and their Pytree.static(...) or Pytree.field(...) annotations (a field without an annotation defaults to Pytree.field(...)).
All Pytree dataclasses support pretty printing, as well as rendering to HTML.
Examples:
from genjax import Pytree
from genjax.typing import FloatArray
import jax.numpy as jnp

# Pytree.dataclass enforces type annotations on instantiation.
@Pytree.dataclass
class MyClass(Pytree):
    my_static_field: int = Pytree.static()
    my_dynamic_field: FloatArray

print(MyClass(10, jnp.array(5.0)).render_html())
Source code in src/genjax/_src/core/pytree.py
static
staticmethod
¶
Declare a field of a Pytree dataclass to be static. Users can provide additional keyword arguments, like default or default_factory, to customize how the field is initialized when an instance of the dataclass is created. Fields which are provided with default values must come after required fields in the dataclass declaration.
Examples:
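from genjax import Pytree
from genjax.typing import FloatArray
import jax.numpy as jnp

# Pytree.dataclass enforces type annotations on instantiation.
@Pytree.dataclass
class MyClass(Pytree):
    my_dynamic_field: FloatArray
    my_static_field: int = Pytree.static(default=0)

print(MyClass(jnp.array(5.0)).render_html())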
Source code in src/genjax/_src/core/pytree.py
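The ordering rule appears to be the standard one from Python's dataclasses module. That Pytree.dataclass builds on dataclasses is an assumption here. The rule can be seen with the standard library alone. The "static" metadata tag below is only a hypothetical marker with no effect of its own.

```python
from dataclasses import dataclass, field


# Required field first, defaulted field second: accepted.
@dataclass
class Ok:
    dynamic_value: float
    static_value: int = field(default=0, metadata={"static": True})


# Defaulted field before a required one: rejected when the class is created.
rejected = False
try:
    @dataclass
    class Bad:
        static_value: int = field(default=0, metadata={"static": True})
        dynamic_value: float
except TypeError as err:
    rejected = True
    print(err)
```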
field
staticmethod
¶
Declare a field of a Pytree dataclass to be dynamic. Alternatively, one can leave the annotation off in the declaration.
genjax.core.Const
¶
Bases: Generic[R]
, Pytree
JAX-compatible way to tag a value as a constant. Valid constants include Python literals, strings, essentially anything that won't hold JAX arrays inside of a computation.
Examples:
Instances of Const can be created using the Pytree.const classmethod, as in Pytree.const(5).
Constants can be freely used across jax.jit boundaries:
import jax
from genjax import Pytree

def f(c):
    # c is static, so Python control flow on its unwrapped value is allowed under jit.
    if c.unwrap() == 5:
        return 10.0
    else:
        return 5.0

c = Pytree.const(5)
r = jax.jit(f)(c)
print(r)
Methods:
Name | Description |
---|---|
unwrap | Unwrap a constant value from a Const instance. |
Source code in src/genjax/_src/core/pytree.py
unwrap
¶
Unwrap a constant value from a Const instance.
This method can be used as an instance method or as a static method. When used as a static method, it returns the input value unchanged if it is not a Const instance.
Returns:
Name | Type | Description |
---|---|---|
R | R | The unwrapped value if self is a Const; otherwise the input value, unchanged. |
Examples:
from genjax import Pytree, Const
c = Pytree.const(5)
val = c.unwrap() # Returns 5
# Can also be used as static method
val = Const.unwrap(10) # Returns 10 unchanged
Source code in src/genjax/_src/core/pytree.py
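In Python, calling a method through the class, as in Const.unwrap(10), binds the argument to self. A single method body can therefore serve both uses. Below is a hypothetical sketch of that technique; ToyConst is illustrative, and GenJAX's implementation may differ.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyConst:
    """Hypothetical stand-in for Const: wraps a value that should stay static."""

    val: Any

    def unwrap(self) -> Any:
        # Called on an instance, `self` is a ToyConst. Called on the class, as
        # ToyConst.unwrap(x), Python binds `self` to x, which may be any value.
        return self.val if isinstance(self, ToyConst) else self


assert ToyConst(5).unwrap() == 5
assert ToyConst.unwrap(10) == 10
assert ToyConst.unwrap(ToyConst(7)) == 7
```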
genjax.core.Closure
¶
Bases: Generic[R]
, Pytree
JAX-compatible closure type. It represents a closure as a Pytree, meaning the static source code / callable is separated from the dynamic data (which must be tracked by JAX).
Examples:
Instances of Closure can be created using Pytree.partial; note the order of the "closed over" arguments:
import jax
from genjax import Pytree

def g(y):
    @Pytree.partial(y)  # dynamic values come first
    def f(v, x):
        # v will be bound to the value of y
        return x * (v * 5.0)
    return f

clos = jax.jit(g)(5.0)
print(clos.render_html())
Closures can be invoked / JIT compiled in other code:
Source code in src/genjax/_src/core/pytree.py
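The separation between static code and dynamic data in Closure can be sketched in plain Python, including the rule that closed-over values come first. ToyClosure and toy_partial are hypothetical names. A real Pytree would also register dyn_args as leaves and fn as static structure, which this sketch omits.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyClosure:
    """Hypothetical closure-as-data: a static callable plus dynamic arguments."""

    dyn_args: tuple[Any, ...]  # data a tracer would need to track
    fn: Callable[..., Any]     # code, treated as static

    def __call__(self, *args: Any) -> Any:
        # Closed-over values are passed first, then the call-site arguments.
        return self.fn(*self.dyn_args, *args)


def toy_partial(*dyn_args: Any) -> Callable[[Callable[..., Any]], ToyClosure]:
    # Decorator factory mirroring the @Pytree.partial(y) usage above.
    def wrap(fn: Callable[..., Any]) -> ToyClosure:
        return ToyClosure(dyn_args, fn)

    return wrap


def g(y: float) -> ToyClosure:
    @toy_partial(y)
    def f(v: float, x: float) -> float:
        return x * (v * 5.0)

    return f


clos = g(5.0)
print(clos(2.0))  # 50.0
```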
Dynamism in JAX: masks and sum types¶
The semantics of Gen are defined independently of any particular computational substrate or implementation, but JAX (and XLA through JAX) is a distinctive substrate. It offers high performance, the ability to transform code ahead of time via program transformations, and ... a rather unique set of restrictions.
JAX is a two-phase system¶
While not yet formally modelled, it's appropriate to think of JAX as separating computation into two phases:
- The statics phase (which occurs at JAX tracing / transformation time).
- The runtime phase (which occurs when a computation written in JAX is actually deployed via XLA and executed on a physical device somewhere in the world).
JAX has different rules for handling values depending on which phase we are in.
For instance, JAX disallows usage of runtime values to resolve Python control flow at tracing time (intuition: we don't actually know the value yet!) and will error if the user attempts to trace through a Python program with incorrect usage of runtime values.
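This restriction can be imitated with a toy "traced value" whose truth value is unknown, so it refuses to be converted to a bool. This is a hypothetical stdlib sketch. ToyTracer is not a JAX class, and JAX raises its own error type rather than the plain TypeError used here.

```python
class ToyTracer:
    """Hypothetical stand-in for a traced value: its runtime value is unknown."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __gt__(self, other: object) -> "ToyTracer":
        # Comparisons build a new symbolic value instead of producing a bool.
        return ToyTracer(f"({self.name} > {other!r})")

    def __bool__(self) -> bool:
        # Python control flow needs a concrete bool, which tracing cannot supply.
        raise TypeError(f"cannot use traced value {self.name} in Python control flow")


def f(x):
    if x > 0:  # fine for a concrete number, an error for a ToyTracer
        return x
    return -x


assert f(3) == 3
try:
    f(ToyTracer("x"))
except TypeError as err:
    print(err)  # cannot use traced value (x > 0) in Python control flow
```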
In GenJAX, we take advantage of JAX's tracing to construct code which, when traced, produces specialized code depending on static information. At the same time, we are careful to encode Gen's interfaces to respect JAX's rules which govern how static / runtime values can be used.
The most primitive way to encode runtime uncertainty about a piece of data is to attach a bool to it, which indicates whether the data is "on" or "off".
GenJAX contains a system for tagging data with flags, to indicate if the data is valid or invalid during inference interface computations at runtime. The key data structure which supports this system is genjax.core.Mask.
genjax.core.Mask
¶
Bases: Generic[R]
, Pytree
The Mask datatype wraps a value together with a Boolean flag which denotes whether the data is valid or invalid to use in inference computations.
Masks can be used in a variety of ways as part of generative computations; their primary role is to denote data which is valid under inference computations. Valid data can be used as ChoiceMap leaves, and participate in generative and inference computations (like scores, and importance weights or density ratios). A Mask with a False flag should be considered unusable, and should be handled with care.
If a flag has a non-scalar shape, the mask is vectorized. The ArrayLike value, or each leaf in the pytree, must then have the flag's shape as a prefix of its own shape (i.e., it must have been created with a jax.vmap call or via a GenJAX vmap combinator).
Encountering Mask in your computation¶
When users see Mask in their computations, they are expected to interact with them by either:
- Unmasking them using the Mask.unmask interface, a potentially unsafe operation.
- Destructuring them manually, and handling the cases.
Usage of invalid data¶
If you use invalid Mask(data, False) data in inference computations, you may encounter silently incorrect results.
Methods:
Name | Description |
---|---|
unmask | Unmask the Mask, returning the value within. |
Source code in src/genjax/_src/core/generative/functional_types.py
unmask
¶
unmask(default: R | None = None) -> R
Unmask the Mask, returning the value within.
This operation is inherently unsafe with respect to inference semantics if no default value is provided. It is only valid if the Mask wraps valid data at runtime, or if a default value is supplied.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
default | R \| None | An optional default value to return if the mask is invalid. | None |
Returns:
Type | Description |
---|---|
R | The unmasked value if valid, or the default value if provided and the mask is invalid. |
Source code in src/genjax/_src/core/generative/functional_types.py
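The ways of consuming a Mask described above can be sketched with a concrete Python bool flag: unmasking with or without a default, and destructuring by hand. ToyMask is hypothetical. The no-default path returns the wrapped value unchecked, which is an assumption modelled on the "inherently unsafe" wording. With a runtime JAX flag, the default case would need an array select such as jnp.where rather than a Python if.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

R = TypeVar("R")


@dataclass(frozen=True)
class ToyMask(Generic[R]):
    """Hypothetical stand-in for Mask, using a concrete Python bool flag."""

    value: R
    flag: bool

    def unmask(self, default: Optional[R] = None) -> R:
        if default is None:
            # Unsafe path: the value is returned whether or not it is valid.
            return self.value
        # Safe path: fall back to the default when the flag is False.
        return self.value if self.flag else default


valid, invalid = ToyMask(3.0, True), ToyMask(3.0, False)
assert valid.unmask(default=0.0) == 3.0
assert invalid.unmask(default=0.0) == 0.0
assert invalid.unmask() == 3.0  # silently returns invalid data

# Manual destructuring: inspect the flag and handle each case yourself.
result = invalid.value if invalid.flag else float("nan")
```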
Static typing with genjax.typing a.k.a. 🐻beartype🐻¶
GenJAX uses beartype to perform type checking during JAX tracing / compile time. beartype is normally a fast runtime type checker; here its checks run when JAX traces a function, verifying that arguments and return values have the expected types, so they add no cost to the compiled computation.
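Under jax.jit, the Python body of a function runs when JAX traces it, not on cached, already-compiled calls. A Python-level argument check therefore runs at tracing time and is absent from the compiled computation. The decorator below is a much-simplified, hypothetical stand-in for that idea. toy_typecheck is not beartype's API, and it only checks plain classes, not array types such as FloatArray.

```python
import functools
import inspect
from typing import Any, Callable, get_type_hints


def toy_typecheck(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical, much-simplified stand-in for a beartype-style decorator."""
    hints = get_type_hints(fn)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = hints.get(name)
            # Only plain classes are checked; generics are skipped.
            if isinstance(expected, type) and not isinstance(value, expected):
                raise TypeError(f"{name}={value!r} does not match {expected.__name__}")
        return fn(*args, **kwargs)

    return wrapper


@toy_typecheck
def scale(x: float, k: int) -> float:
    return x * k


assert scale(2.0, 3) == 6.0
try:
    scale(2.0, "3")
except TypeError as err:
    print(err)  # k='3' does not match int
```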
Generative interface types¶
genjax.core.Arguments
module-attribute
¶
Arguments is the type of argument values to generative functions. It is a type alias for Tuple, and is used to improve readability and parsing of interface specifications.
genjax.core.Score
module-attribute
¶
A score is a density ratio, described fully in simulate.
The type Score does not enforce any meaningful mathematical invariants, but is used to denote the type of scores in the GenJAX system, to improve readability and parsing of interface specifications.
genjax.core.Weight
module-attribute
¶
A weight is a density ratio which often occurs in the context of proper weighting for Target distributions, or in Gen's edit interface, whose mathematical content is described in edit.
The type Weight does not enforce any meaningful mathematical invariants, but is used to denote the type of weights in GenJAX, to improve readability and parsing of interface specifications / expectations.
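Since a weight is a density ratio, it is convenient to work with it in log space, where the ratio becomes a difference of log densities. Gen systems conventionally report scores and weights as log quantities; treat that as an assumption for GenJAX here. Below is a small stdlib example with two illustrative normal densities.

```python
import math


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    # log N(x; mu, sigma^2) = -((x - mu) / sigma)^2 / 2 - log(sigma * sqrt(2 * pi))
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


x = 0.5
log_p = normal_logpdf(x, mu=0.0, sigma=1.0)  # target density (illustrative)
log_q = normal_logpdf(x, mu=1.0, sigma=2.0)  # proposal density (illustrative)

# The density ratio p(x) / q(x), stored as a log weight.
log_weight = log_p - log_q
print(math.exp(log_weight))
```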
genjax.core.Retdiff
module-attribute
¶
Retdiff is the type of return values with an attached ChangeType (cf. edit).
When used under type checking, Retdiff assumes that the return value is a Pytree (either defined via GenJAX's Pytree interface or registered with JAX's system). It checks that the leaves are of Diff type with an attached ChangeType.
genjax.core.Argdiffs
module-attribute
¶
Argdiffs is the type of argument values with an attached ChangeType (cf. edit).
When used under type checking, Argdiffs assumes that the argument values are Pytrees (either defined via GenJAX's Pytree interface or registered with JAX's system). For each argument, it checks that the leaves are of Diff type with an attached ChangeType.
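The leaf-wise check described for Retdiff and Argdiffs can be sketched with a minimal tree traversal. ToyDiff, its string change tags and check_argdiffs are hypothetical. GenJAX's Diff and ChangeType are richer, and real pytree flattening (jax.tree_util) handles more container types than the tuples, lists and dicts covered here.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyDiff:
    """Hypothetical stand-in for Diff: a value paired with a change tag."""

    primal: Any
    change: str  # e.g. "no_change" or "unknown_change" (illustrative tags)


def leaves(tree: Any) -> list[Any]:
    # Minimal pytree-like traversal over tuples, lists and dicts.
    if isinstance(tree, (tuple, list)):
        return [leaf for item in tree for leaf in leaves(item)]
    if isinstance(tree, dict):
        return [leaf for item in tree.values() for leaf in leaves(item)]
    return [tree]


def check_argdiffs(argdiffs: tuple[Any, ...]) -> bool:
    # Every leaf of every argument must carry a change tag.
    return all(isinstance(leaf, ToyDiff) for leaf in leaves(argdiffs))


good = (ToyDiff(1.0, "no_change"), {"a": ToyDiff(2.0, "unknown_change")})
bad = (ToyDiff(1.0, "no_change"), 2.0)
assert check_argdiffs(good)
assert not check_argdiffs(bad)
```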