Journey to the center of genjax.core

This page describes the set of core concepts and datatypes in GenJAX, including Gen's generative datatypes and concepts (GenerativeFunction, Trace, ChoiceMap, and EditRequest), the core JAX compatibility datatypes (Pytree, Const, and Closure), as well as functionally inspired Pytree extensions (Mask), and GenJAX's approach to "static" (JAX tracing time) typechecking.

genjax.core.GenerativeFunction

Bases: Generic[R], Pytree

GenerativeFunction is the type of generative functions, the main computational object in Gen.

Generative functions are a type of probabilistic program. In terms of their mathematical specification, they come equipped with a few ingredients:

  • (Distribution over samples) \(P(\cdot_t, \cdot_r; a)\) - a probability distribution over samples \(t\) and untraced randomness \(r\), indexed by arguments \(a\). This ingredient is involved in all the interfaces and specifies the distribution over samples which the generative function represents.
  • (Family of K/L proposals) \((K(\cdot_t, \cdot_{K_r}; u, t), L(\cdot_t, \cdot_{L_r}; u, t)) = \mathcal{F}(u, t)\) - a family of pairs of probabilistic programs (referred to as K and L), indexed by EditRequest \(u\) and an existing sample \(t\). This ingredient supports the edit and importance interfaces, and is used to specify an SMCP3 move which the generative function must provide in response to an edit request. K and L must satisfy additional properties, described further in edit.
  • (Return value function) \(f(t, r, a)\) - a deterministic return value function, which maps samples and untraced randomness to return values.
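To make the three ingredients concrete for one small model, here is a plain-Python sketch (stdlib only, not GenJAX code) for the uniform/flip model used in the example below. It assumes a toy encoding in which a sample \(t\) is a dict from addresses to values, there is no untraced randomness \(r\), and the return value function \(f\) simply reads \(p\) (the GenJAX model below returns None). The helper names `log_density`, `return_value`, `simulate_toy` and `assess_toy` are hypothetical.

```python
import math
import random

# Hypothetical toy encoding (not the GenJAX API): a sample t is a dict
# mapping addresses to values, and there is no untraced randomness r.


def log_density(t):
    """log P(t; a) for p ~ uniform(0, 1), f1 ~ flip(p), f2 ~ flip(p)."""
    p = t["p"]
    if not 0.0 <= p <= 1.0:
        return -math.inf  # off the support of uniform(0, 1)
    logp = 0.0  # uniform(0, 1) has density 1 on [0, 1]
    for addr in ("f1", "f2"):
        prob = p if t[addr] else 1.0 - p
        logp += math.log(prob) if prob > 0.0 else -math.inf
    return logp


def return_value(t):
    """Deterministic return value function f(t, r, a); here it just reads p."""
    return t["p"]


def simulate_toy(rng):
    """Sample t ~ P, then return (sample, score, retval), like a trace would."""
    p = rng.random()
    t = {"p": p, "f1": rng.random() < p, "f2": rng.random() < p}
    return t, log_density(t), return_value(t)


def assess_toy(t):
    """Score and return value for a fully specified sample t."""
    return log_density(t), return_value(t)


rng = random.Random(0)
t, score, retval = simulate_toy(rng)
print(assess_toy(t) == (score, retval))  # True
```

Because `assess_toy` re-evaluates the same density that `simulate_toy` used, the two agree on any simulated sample, mirroring the expected relationship between assess and simulate.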

Generative functions also support a family of Target distributions - a Target distribution is a (possibly unnormalized) distribution, typically induced by inference problems.

  • \(\delta_\emptyset\) - the empty target, whose only possible value is the empty sample, with density 1.
  • (Family of targets induced by \(P\)) \(T_P(a, c)\) - a family of targets indexed by arguments \(a\) and constraints (ChoiceMap), created by pairing the distribution over samples \(P\) with arguments and constraint.
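Informally, and ignoring untraced randomness \(r\), the targets above can be read as the unnormalized densities below, where \(t \models c\) (notation introduced here) means that \(t\) agrees with the constraint \(c\) at every constrained address. For the uniform/flip model in the example below, with \(c = \{f_1 = \mathrm{True}, f_2 = \mathrm{True}\}\), the target over the remaining choice \(p\) works out to \(p^2\), i.e. a Beta(3, 1) posterior with mean \(3/4\).

```latex
\begin{aligned}
\tilde{T}_P(a, c)(t) &= P(t; a)\,\mathbb{1}[t \models c],
  \qquad \delta_\emptyset(t) = \mathbb{1}[t = \emptyset], \\
\tilde{T}_P((), c)(p) &= 1 \cdot p \cdot p = p^2,
  \qquad Z = \int_0^1 p^2\,dp = \tfrac{1}{3},
  \qquad \mathbb{E}[p \mid c] = \frac{1}{Z}\int_0^1 p \cdot p^2\,dp = \tfrac{3}{4}.
\end{aligned}
```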

Generative functions expose computations using these ingredients through the generative function interface (the methods which are documented below).

Examples:

The interface methods can be used to implement inference algorithms directly. Here's a simple example of bootstrap importance sampling:

import jax
from jax.scipy.special import logsumexp
import jax.tree_util as jtu
from genjax import ChoiceMapBuilder as C
from genjax import gen, uniform, flip, categorical


@gen
def model():
    p = uniform(0.0, 1.0) @ "p"
    f1 = flip(p) @ "f1"
    f2 = flip(p) @ "f2"


# Bootstrap importance sampling.
def importance_sampling(key, constraint):
    key, sub_key = jax.random.split(key)
    sub_keys = jax.random.split(sub_key, 5)
    tr, log_weights = jax.vmap(model.importance, in_axes=(0, None, None))(
        sub_keys, constraint, ()
    )
    logits = log_weights - logsumexp(log_weights)
    idx = categorical(logits)(key)
    return jtu.tree_map(lambda v: v[idx], tr.get_choices())


sub_keys = jax.random.split(jax.random.key(0), 50)
samples = jax.jit(jax.vmap(importance_sampling, in_axes=(0, None)))(
    sub_keys, C.kw(f1=True, f2=True)
)
print(samples.render_html())
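With the prior used as the internal proposal for p (which is what makes this bootstrap importance sampling), each particle's weight should be the likelihood of the constraint, \(p^2\) (log weight \(2 \log p\)). The plain-Python sketch below (stdlib only, not GenJAX; `bootstrap_posterior_mean` is a hypothetical name) uses the same weights, but forms a self-normalized estimate of the posterior mean instead of resampling one particle. It should approach the exact Beta(3, 1) mean of 0.75.

```python
import math
import random


def bootstrap_posterior_mean(n_particles, seed=0):
    """Self-normalized importance sampling estimate of E[p | f1=True, f2=True].

    Proposal: p ~ uniform(0, 1), i.e. the prior (bootstrap importance sampling).
    Log weight: log P(f1=True | p) + log P(f2=True | p) = 2 * log(p).
    """
    rng = random.Random(seed)
    ps, log_ws = [], []
    for _ in range(n_particles):
        p = rng.random()
        ps.append(p)
        log_ws.append(2.0 * math.log(p) if p > 0.0 else -math.inf)
    # Subtract the max log weight before exponentiating (a stdlib stand-in
    # for logsumexp-style normalization) to avoid underflow.
    max_lw = max(log_ws)
    ws = [math.exp(lw - max_lw) for lw in log_ws]
    total = sum(ws)
    return sum(w * p for w, p in zip(ws, ps)) / total


estimate = bootstrap_posterior_mean(100_000)
print(estimate)  # should be close to the exact posterior mean 0.75
```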

Methods:

Name Description
__abstract_call__

Used to support JAX tracing, although this default implementation involves no JAX operations: it returns the return value of the trace built by get_zero_trace.

accumulate

When called on a genjax.GenerativeFunction of type (c, a) -> c, returns a new genjax.GenerativeFunction of type (c, [a]) -> [c] where

assess

Return the score and the return value when the generative function is invoked with the provided arguments, and constrained to take the provided sample as the sampled value.

contramap

Specialized version of genjax.GenerativeFunction.dimap where only the pre-processing function is applied.

dimap

Returns a new genjax.GenerativeFunction and applies pre- and post-processing functions to its arguments and return value.

edit

Update a trace in response to an EditRequest, returning a new Trace, an incremental Weight for the new target, a Retdiff return value tagged with change information, and a backward EditRequest which requests the reverse move (to go back to the original trace).

get_zero_trace

Returns a trace with zero values for all leaves, generated without executing the generative function.

handle_kwargs

Returns a new GenerativeFunction like self, but where all GFI methods accept a tuple of arguments and a dictionary of keyword arguments.

importance

Returns a properly weighted pair, a Trace and a Weight, properly weighted for the target induced by the generative function for the provided constraint and arguments.

iterate

When called on a genjax.GenerativeFunction of type a -> a, returns a new genjax.GenerativeFunction of type a -> [a] where

iterate_final

Returns a decorator that wraps a genjax.GenerativeFunction of type a -> a and returns a new genjax.GenerativeFunction of type a -> a where

map

Specialized version of genjax.dimap where only the post-processing function is applied.

mask

Enables dynamic masking of generative functions. Returns a new genjax.GenerativeFunction like self, but which accepts an additional boolean first argument.

masked_iterate

Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a list of values of type a.

masked_iterate_final

Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a value of type a.

mix

Takes any number of genjax.GenerativeFunctions and returns a new genjax.GenerativeFunction that represents a mixture model.

or_else

Returns a GenerativeFunction that accepts

propose

Samples a ChoiceMap and any untraced randomness \(r\) from the generative function's distribution over samples (\(P\)), and returns the sample, its Score under the distribution, and the return value (of type R) produced by the return value function \(f(t, r, a)\) for the sample and untraced randomness.

reduce

When called on a genjax.GenerativeFunction of type (c, a) -> c, returns a new genjax.GenerativeFunction of type (c, [a]) -> c where

repeat

Returns a genjax.GenerativeFunction that samples from self n times, returning a vector of n results.

scan

When called on a genjax.GenerativeFunction of type (c, a) -> (c, b), returns a new genjax.GenerativeFunction of type (c, [a]) -> (c, [b]) where

simulate

Execute the generative function, sampling from its distribution over samples, and return a Trace.

switch

Given n genjax.GenerativeFunction inputs, returns a new genjax.GenerativeFunction that accepts n+2 arguments:

vmap

Returns a GenerativeFunction that performs a vectorized map over the argument specified by in_axes. Traced values are nested under an index, and the retval is vectorized.
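The scan, accumulate and reduce signatures above follow familiar fold patterns. As a deterministic, plain-Python analogy only (the helpers `scan_like`, `accumulate_like` and `reduce_like` are hypothetical, and the real GenJAX combinators are generative functions that additionally trace each step's random choices), they relate as follows. Whether accumulate includes the initial carry in its output is an assumption of this sketch, not something the summaries above state.

```python
from typing import Callable, List, Tuple, TypeVar

A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")


def scan_like(step: Callable[[C, A], Tuple[C, B]], init: C, xs: List[A]) -> Tuple[C, List[B]]:
    """(c, a) -> (c, b)  becomes  (c, [a]) -> (c, [b]): thread the carry, collect outputs."""
    carry = init
    outs: List[B] = []
    for x in xs:
        carry, out = step(carry, x)
        outs.append(out)
    return carry, outs


def accumulate_like(step: Callable[[C, A], C], init: C, xs: List[A]) -> List[C]:
    """(c, a) -> c  becomes  (c, [a]) -> [c]: the carry after each step.

    Assumption: the initial carry is NOT included in the output list.
    """

    def emit_carry(c: C, a: A) -> Tuple[C, C]:
        new_c = step(c, a)  # call step once and reuse the result
        return new_c, new_c

    return scan_like(emit_carry, init, xs)[1]


def reduce_like(step: Callable[[C, A], C], init: C, xs: List[A]) -> C:
    """(c, a) -> c  becomes  (c, [a]) -> c: only the final carry."""

    def emit_nothing(c: C, a: A) -> Tuple[C, None]:
        return step(c, a), None

    return scan_like(emit_nothing, init, xs)[0]


def add(c: int, a: int) -> int:
    return c + a


print(accumulate_like(add, 0, [1, 2, 3]))  # [1, 3, 6]
print(reduce_like(add, 0, [1, 2, 3]))  # 6
print(scan_like(lambda c, a: (c + a, c * a), 1, [1, 2, 3]))  # (7, [1, 4, 12])
```

Writing accumulate and reduce in terms of scan mirrors how the signatures nest: both are scan with a particular choice of per-step output.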

Source code in src/genjax/_src/core/generative/generative_function.py
class GenerativeFunction(Generic[R], Pytree):
    """
    `GenerativeFunction` is the type of _generative functions_, the main computational object in Gen.

    Generative functions are a type of probabilistic program. In terms of their mathematical specification, they come equipped with a few ingredients:

    * (**Distribution over samples**) $P(\\cdot_t, \\cdot_r; a)$ - a probability distribution over samples $t$ and untraced randomness $r$, indexed by arguments $a$. This ingredient is involved in all the interfaces and specifies the distribution over samples which the generative function represents.
    * (**Family of K/L proposals**) $(K(\\cdot_t, \\cdot_{K_r}; u, t), L(\\cdot_t, \\cdot_{L_r}; u, t)) = \\mathcal{F}(u, t)$ - a family of pairs of probabilistic programs (referred to as K and L), indexed by [`EditRequest`][genjax.core.EditRequest] $u$ and an existing sample $t$. This ingredient supports the [`edit`][genjax.core.GenerativeFunction.edit] and [`importance`][genjax.core.GenerativeFunction.importance] interfaces, and is used to specify an SMCP3 move which the generative function must provide in response to an edit request. K and L must satisfy additional properties, described further in [`edit`][genjax.core.GenerativeFunction.edit].
    * (**Return value function**) $f(t, r, a)$ - a deterministic return value function, which maps samples and untraced randomness to return values.

    Generative functions also support a family of [`Target`][genjax.inference.Target] distributions - a [`Target`][genjax.inference.Target] distribution is a (possibly unnormalized) distribution, typically induced by inference problems.

    * $\\delta_\\emptyset$ - the empty target, whose only possible value is the empty sample, with density 1.
    * (**Family of targets induced by $P$**) $T_P(a, c)$ - a family of targets indexed by arguments $a$ and constraints (`ChoiceMap`), created by pairing the distribution over samples $P$ with arguments and constraint.

    Generative functions expose computations using these ingredients through the _generative function interface_ (the methods which are documented below).

    Examples:
        The interface methods can be used to implement inference algorithms directly. Here's a simple example of bootstrap importance sampling:
        ```python exec="yes" html="true" source="material-block" session="core"
        import jax
        from jax.scipy.special import logsumexp
        import jax.tree_util as jtu
        from genjax import ChoiceMapBuilder as C
        from genjax import gen, uniform, flip, categorical


        @gen
        def model():
            p = uniform(0.0, 1.0) @ "p"
            f1 = flip(p) @ "f1"
            f2 = flip(p) @ "f2"


        # Bootstrap importance sampling.
        def importance_sampling(key, constraint):
            key, sub_key = jax.random.split(key)
            sub_keys = jax.random.split(sub_key, 5)
            tr, log_weights = jax.vmap(model.importance, in_axes=(0, None, None))(
                sub_keys, constraint, ()
            )
            logits = log_weights - logsumexp(log_weights)
            idx = categorical(logits)(key)
            return jtu.tree_map(lambda v: v[idx], tr.get_choices())


        sub_keys = jax.random.split(jax.random.key(0), 50)
        samples = jax.jit(jax.vmap(importance_sampling, in_axes=(0, None)))(
            sub_keys, C.kw(f1=True, f2=True)
        )
        print(samples.render_html())
        ```
    """

    def __call__(self, *args, **kwargs) -> "GenerativeFunctionClosure[R]":
        return GenerativeFunctionClosure(self, args, kwargs)

    def __abstract_call__(self, *args) -> R:
        """Used to support JAX tracing, although this default implementation involves no
        JAX operations: it returns the return value of the trace built by `get_zero_trace`.

        Generative functions may customize this to improve compilation time.
        """
        return self.get_zero_trace(*args).get_retval()

    def handle_kwargs(self) -> "GenerativeFunction[R]":
        """
        Returns a new GenerativeFunction like `self`, but where all GFI methods accept a tuple of arguments and a dictionary of keyword arguments.

        The returned GenerativeFunction can be invoked with `__call__` with no special argument handling (just like the original).

        In place of `args` tuples in GFI methods, the new GenerativeFunction expects a 2-tuple containing:

        1. A tuple containing the original positional arguments.
        2. A dictionary containing the keyword arguments.

        This allows for more flexible argument passing, especially useful in contexts where
        keyword arguments need to be handled separately or passed through multiple layers.

        Returns:
            A new GenerativeFunction that accepts (args_tuple, kwargs_dict) for all GFI methods.

        Example:
            ```python exec="yes" html="true" source="material-block" session="core"
            import genjax
            import jax


            @genjax.gen
            def model(x, y, z=1.0):
                _ = genjax.normal(x + y, z) @ "v"
                return x + y + z


            key = jax.random.key(0)
            kw_model = model.handle_kwargs()

            tr = kw_model.simulate(key, ((1.0, 2.0), {"z": 3.0}))
            print(tr.render_html())
            ```
        """
        return IgnoreKwargs(self)

    def get_zero_trace(self, *args, **_kwargs) -> Trace[R]:
        """
        Returns a trace with zero values for all leaves, generated without executing the generative function.

        This method is useful for static analysis and shape inference without executing the generative function. It returns a trace with the same structure as a real trace, but filled with zero or default values.

        Args:
            *args: The arguments to the generative function.
            **_kwargs: Ignored keyword arguments.

        Returns:
            A trace with zero values, matching the structure of a real trace.

        Note:
            This method uses the `empty_trace` utility function, which creates a trace without spending any FLOPs. The resulting trace has the correct structure but contains placeholder zero values.

        Example:
            ```python exec="yes" html="true" source="material-block" session="core"
            @genjax.gen
            def weather_model():
                temperature = genjax.normal(20.0, 5.0) @ "temperature"
                is_sunny = genjax.bernoulli(0.7) @ "is_sunny"
                return {"temperature": temperature, "is_sunny": is_sunny}


            zero_trace = weather_model.get_zero_trace()
            print("Zero trace structure:")
            print(zero_trace.render_html())

            print("\nActual simulation:")
            key = jax.random.key(0)
            actual_trace = weather_model.simulate(key, ())
            print(actual_trace.render_html())
            ```
        """
        return empty_trace(self, args)

    @abstractmethod
    def simulate(
        self,
        key: PRNGKey,
        args: Arguments,
    ) -> Trace[R]:
        """
        Execute the generative function, sampling from its distribution over samples, and return a [`Trace`][genjax.core.Trace].

        ## More on traces

        The [`Trace`][genjax.core.Trace] returned by `simulate` implements its own interface.

        It is responsible for storing the arguments of the invocation ([`genjax.Trace.get_args`][]), the return value of the generative function ([`genjax.Trace.get_retval`][]), the identity of the generative function which produced the trace ([`genjax.Trace.get_gen_fn`][]), the sample of traced random choices produced during the invocation ([`genjax.Trace.get_choices`][]) and _the score_ of the sample ([`genjax.Trace.get_score`][]).

        Examples:
            ```python exec="yes" html="true" source="material-block" session="core"
            import genjax
            import jax
            from jax import vmap, jit
            from jax.random import split


            @genjax.gen
            def model():
                x = genjax.normal(0.0, 1.0) @ "x"
                return x


            key = jax.random.key(0)
            tr = model.simulate(key, ())
            print(tr.render_html())
            ```

            Another example, using the same model, composed into [`genjax.repeat`](combinators.md#genjax.repeat) - which creates a new generative function, which has the same interface:
            ```python exec="yes" html="true" source="material-block" session="core"
            @genjax.gen
            def model():
                x = genjax.normal(0.0, 1.0) @ "x"
                return x


            key = jax.random.key(0)
            tr = model.repeat(n=10).simulate(key, ())
            print(tr.render_html())
            ```

            (**Fun, flirty, fast ... parallel?**) Feel free to use `jax.jit` and `jax.vmap`!
            ```python exec="yes" html="true" source="material-block" session="core"
            key = jax.random.key(0)
            sub_keys = split(key, 10)
            sim = model.repeat(n=10).simulate
            tr = jit(vmap(sim, in_axes=(0, None)))(sub_keys, ())
            print(tr.render_html())
            ```
        """

    @abstractmethod
    def assess(
        self,
        sample: ChoiceMap,
        args: Arguments,
    ) -> tuple[Score, R]:
        """
        Return [the score][genjax.core.Trace.get_score] and [the return value][genjax.core.Trace.get_retval] when the generative function is invoked with the provided arguments, and constrained to take the provided sample as the sampled value.

        It is an error if the provided sample value is off the support of the distribution over the `ChoiceMap` type, or otherwise induces a partial constraint on the execution of the generative function (which would require the generative function to provide an `edit` implementation which responds to the `EditRequest` induced by the [`importance`][genjax.core.GenerativeFunction.importance] interface).

        Examples:
            This method is similar to density evaluation interfaces for distributions.
            ```python exec="yes" html="true" source="material-block" session="core"
            from genjax import normal
            from genjax import ChoiceMapBuilder as C

            sample = C.v(1.0)
            score, retval = normal.assess(sample, (1.0, 1.0))
            print((score, retval))
            ```

            But it also works with generative functions that sample from spaces with more structure:

            ```python exec="yes" html="true" source="material-block" session="core"
            from genjax import gen
            from genjax import normal
            from genjax import ChoiceMapBuilder as C


            @gen
            def model():
                v1 = normal(0.0, 1.0) @ "v1"
                v2 = normal(v1, 1.0) @ "v2"


            sample = C.kw(v1=1.0, v2=0.0)
            score, retval = model.assess(sample, ())
            print((score, retval))
            ```
        """

    @abstractmethod
    def generate(
        self,
        key: PRNGKey,
        constraint: ChoiceMap,
        args: Arguments,
    ) -> tuple[Trace[R], Weight]:
        pass

    @abstractmethod
    def project(
        self,
        key: PRNGKey,
        trace: Trace[R],
        selection: Selection,
    ) -> Weight:
        pass

    @abstractmethod
    def edit(
        self,
        key: PRNGKey,
        trace: Trace[R],
        edit_request: EditRequest,
        argdiffs: Argdiffs,
    ) -> tuple[Trace[R], Weight, Retdiff[R], EditRequest]:
        """
        Update a trace in response to an [`EditRequest`][genjax.core.EditRequest], returning a new [`Trace`][genjax.core.Trace], an incremental [`Weight`][genjax.core.Weight] for the new target, a [`Retdiff`][genjax.core.Retdiff] return value tagged with change information, and a backward [`EditRequest`][genjax.core.EditRequest] which requests the reverse move (to go back to the original trace).

        The specification of this interface is parametric over the kind of `EditRequest` -- responding to an `EditRequest` instance requires that the generative function provides an implementation of a sequential Monte Carlo move in the [SMCP3](https://proceedings.mlr.press/v206/lew23a.html) framework. Users of inference algorithms are not expected to understand the ingredients, but inference algorithm developers are.

        Examples:
            Updating a trace in response to a request for a [`Target`][genjax.inference.Target] change induced by a change to the arguments:
            ```python exec="yes" source="material-block" session="core"
            import jax
            from genjax import gen, normal, Diff, Update, ChoiceMap as C

            key = jax.random.key(0)


            @gen
            def model(var):
                v1 = normal(0.0, 1.0) @ "v1"
                v2 = normal(v1, var) @ "v2"
                return v2


            # Generating an initial trace properly weighted according
            # to the target induced by the constraint.
            constraint = C.kw(v2=1.0)
            initial_tr, w = model.importance(key, constraint, (1.0,))

            # Updating the trace to a new target.
            new_tr, inc_w, retdiff, bwd_prob = model.edit(
                key,
                initial_tr,
                Update(
                    C.empty(),
                ),
                Diff.unknown_change((3.0,)),
            )
            ```

            Now, let's inspect the trace:
            ```python exec="yes" html="true" source="material-block" session="core"
            # Inspect the trace, the sampled values should not have changed!
            sample = new_tr.get_choices()
            print(sample["v1"], sample["v2"])
            ```

            And the return value diff:
            ```python exec="yes" html="true" source="material-block" session="core"
            # The return value also should not have changed!
            print(retdiff.render_html())
            ```

            As expected, neither has changed, but the incremental weight is non-zero:
            ```python exec="yes" html="true" source="material-block" session="core"
            print(inc_w)
            ```

        ## Mathematical ingredients behind edit

        The `edit` interface exposes [SMCP3 moves](https://proceedings.mlr.press/v206/lew23a.html). Here, we omit the measure theoretic description, and refer interested readers to [the paper](https://proceedings.mlr.press/v206/lew23a.html). Informally, the ingredients of such a move are:

        * The previous target $T$.
        * The new target $T'$.
        * A pair of kernel probabilistic programs, called $K$ and $L$:
            * The K kernel is a kernel probabilistic program which accepts a previous sample $x_{t-1}$ from $T$ as an argument, may sample auxiliary randomness $u_K$, and returns a new sample $x_t$ approximately distributed according to $T'$, along with transformed randomness $u_L$.
            * The L kernel is a kernel probabilistic program which accepts the new sample $x_t$, and provides a density evaluator for the auxiliary randomness $u_L$ which K returns, and an inverter $x_t \\mapsto x_{t-1}$ which is _almost everywhere_ the identity function.

        The specification of these ingredients is encapsulated in the type signature of the `edit` interface.

        ## Understanding the `edit` interface

        The `edit` interface uses the mathematical ingredients described above to perform probability-aware mutations and incremental [`Weight`][genjax.core.Weight] computations on [`Trace`][genjax.core.Trace] instances, which allows Gen to provide automation to support inference algorithms like importance sampling, SMC, MCMC and many more.

        An `EditRequest` denotes a function $tr \\mapsto (T, T')$ from traces to a pair of targets (the previous [`Target`][genjax.inference.Target] $T$, and the final [`Target`][genjax.inference.Target] $T'$).

        Several common types of moves can be requested via the `Update` type:

        ```python exec="yes" source="material-block" session="core"
        from genjax import Update
        from genjax import ChoiceMap

        g = Update(
            ChoiceMap.empty(),  # Constraint
        )
        ```

        `Update` wraps a constraint which specifies an additional move to be performed; changes to the arguments of the generative function are supplied separately to `edit` as [`Argdiffs`][genjax.core.Argdiffs].

        ```python exec="yes" html="true" source="material-block" session="core"
        new_tr, inc_w, retdiff, bwd_prob = model.edit(
            key,
            initial_tr,
            Update(
                C.kw(v1=3.0),
            ),
            Diff.unknown_change((3.0,)),
        )
        print((new_tr.get_choices()["v1"], inc_w))
        ```

        **Additional notes on [`Argdiffs`][genjax.core.Argdiffs]**

        Argument changes induce changes to the distribution over samples, internal K and L proposals, and (by virtue of changes to $P$) target distributions. The [`Argdiffs`][genjax.core.Argdiffs] type denotes the type of values attached with a _change type_, a piece of data which indicates how the value has changed from the arguments which created the trace. Generative functions can utilize change type information to inform efficient [`edit`][genjax.core.GenerativeFunction.edit] implementations.
        """
        pass

    ######################
    # Derived interfaces #
    ######################

    def update(
        self,
        key: PRNGKey,
        trace: Trace[R],
        constraint: ChoiceMap,
        argdiffs: Argdiffs,
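    ) -> tuple[Trace[R], Weight, Retdiff[R], ChoiceMap]:
        """
        Convenience wrapper around [`edit`][genjax.core.GenerativeFunction.edit] for an `Update` request: applies `constraint` and `argdiffs` to `trace`, and returns the new trace, the incremental weight, the return value diff, and the constraint of the backward `Update` request (rather than the request itself).
        """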
        request = Update(
            constraint,
        )
        tr, w, rd, bwd = request.edit(
            key,
            trace,
            argdiffs,
        )
        assert isinstance(bwd, Update), type(bwd)
        return tr, w, rd, bwd.constraint

    def importance(
        self,
        key: PRNGKey,
        constraint: ChoiceMap,
        args: Arguments,
    ) -> tuple[Trace[R], Weight]:
        """
        Returns a properly weighted pair, a [`Trace`][genjax.core.Trace] and a [`Weight`][genjax.core.Weight], properly weighted for the target induced by the generative function for the provided constraint and arguments.

        Examples:
            (**Full constraints**) A simple example using the `importance` interface on distributions:
            ```python exec="yes" html="true" source="material-block" session="core"
            import jax
            from genjax import normal
            from genjax import ChoiceMapBuilder as C

            key = jax.random.key(0)

            tr, w = normal.importance(key, C.v(1.0), (0.0, 1.0))
            print(tr.get_choices().render_html())
            ```

            (**Internal proposal for partial constraints**) Specifying a _partial_ constraint on a [`StaticGenerativeFunction`][genjax.StaticGenerativeFunction]:
            ```python exec="yes" html="true" source="material-block" session="core"
            from genjax import flip, uniform, gen
            from genjax import ChoiceMapBuilder as C


            @gen
            def model():
                p = uniform(0.0, 1.0) @ "p"
                f1 = flip(p) @ "f1"
                f2 = flip(p) @ "f2"


            tr, w = model.importance(key, C.kw(f1=True, f2=True), ())
            print(tr.get_choices().render_html())
            ```

        Under the hood, this creates an [`EditRequest`][genjax.core.EditRequest] which requests that the generative function respond with a move from the _empty_ trace (the only possible value for the _empty_ target $\\delta_\\emptyset$) to the target induced by the generative function for constraint $C$ with arguments $a$.
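
        When every address of a distribution is constrained, the weight is expected to reduce to the log density of the constrained value, as with Gen's `generate` under full constraints. Below is a minimal stdlib sketch of that arithmetic for the `normal.importance(key, C.v(1.0), (0.0, 1.0))` example above. It assumes those full-constraint semantics, and `normal_logpdf` is a hypothetical helper, not a GenJAX API:

        ```python
        import math


        def normal_logpdf(x, mu, sigma):
            # Hypothetical helper: log density of Normal(mu, sigma) evaluated at x.
            z = (x - mu) / sigma
            return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


        # Every address is constrained, so nothing is left to propose and the
        # weight is just the log density of the constrained value 1.0.
        w = normal_logpdf(1.0, 0.0, 1.0)
        print(round(w, 4))  # -1.4189
        ```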
        """

        return self.generate(
            key,
            constraint,
            args,
        )

    def propose(
        self,
        key: PRNGKey,
        args: Arguments,
    ) -> tuple[ChoiceMap, Score, R]:
        """
        Samples a [`ChoiceMap`][genjax.core.ChoiceMap] and any untraced randomness $r$ from the generative function's distribution over samples ($P$). Returns the sample, its [`Score`][genjax.core.Score] under that distribution, and the return value (of type `R`) produced by the return value function $f(t, r, a)$ for the sample and untraced randomness.
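
        Here is a plain-Python sketch of this contract (not GenJAX code) for a single normal draw with no untraced randomness. `propose` amounts to simulating, then reading the sample, its log density, and the return value off the trace. `toy_simulate` and `toy_propose` are hypothetical stand-ins for `simulate` and `propose`:

        ```python
        import math
        import random


        def toy_simulate(rng, mu, sigma):
            # Hypothetical "trace": a dict holding the choice, its score, and the retval.
            x = rng.gauss(mu, sigma)
            z = (x - mu) / sigma
            score = -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
            return {"choices": {"x": x}, "score": score, "retval": x}


        def toy_propose(rng, mu, sigma):
            # Mirrors the structure of `propose`: simulate, then unpack the trace.
            tr = toy_simulate(rng, mu, sigma)
            return tr["choices"], tr["score"], tr["retval"]


        sample, score, retval = toy_propose(random.Random(0), 0.0, 1.0)
        print(sample, score, retval)
        ```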
        """
        tr = self.simulate(key, args)
        sample = tr.get_choices()
        score = tr.get_score()
        retval = tr.get_retval()
        return sample, score, retval

    ######################################################
    # Convenience: postfix syntax for combinators / DSLs #
    ######################################################

    ###############
    # Combinators #
    ###############

    # TODO think through, or note, that the R that comes out will have to be bounded by pytree.
    def vmap(self, /, *, in_axes: InAxes = 0) -> "GenerativeFunction[R]":
        """
        Returns a [`GenerativeFunction`][genjax.GenerativeFunction] that performs a vectorized map over the argument specified by `in_axes`. Traced values are nested under an index, and the retval is vectorized.

        Args:
            in_axes: Selector specifying which input arguments (or index into them) should be vectorized. Defaults to 0, i.e., the first argument. See [this link](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees) for more detail.

        Returns:
            A new [`GenerativeFunction`][genjax.GenerativeFunction] that accepts an argument of one-higher dimension at the position specified by `in_axes`.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="gen-fn"
            import jax
            import jax.numpy as jnp
            import genjax


            @genjax.gen
            def model(x):
                v = genjax.normal(x, 1.0) @ "v"
                return genjax.normal(v, 0.01) @ "q"


            vmapped = model.vmap(in_axes=0)

            key = jax.random.key(314159)
            arr = jnp.ones(100)

            # `vmapped` accepts an array of numbers instead of the original
            # single number that `model` accepted.
            tr = jax.jit(vmapped.simulate)(key, (arr,))

            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.vmap(in_axes=in_axes)(self)

    def repeat(self, /, *, n: int) -> "GenerativeFunction[R]":
        """
        Returns a [`genjax.GenerativeFunction`][] that samples from `self` `n` times, returning a vector of `n` results.

        The values traced by each call to `self` will be nested under an integer index that matches the loop iteration index that generated it.

        This combinator is useful for creating multiple samples from `self` in a batched manner.

        Args:
            n: The number of times to sample from the generative function.

        Returns:
            A new [`genjax.GenerativeFunction`][] that samples from the original function `n` times.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="repeat"
            import genjax, jax


            @genjax.gen
            def normal_draw(mean):
                return genjax.normal(mean, 1.0) @ "x"


            normal_draws = normal_draw.repeat(n=10)

            key = jax.random.key(314159)

            # Generate 10 draws from a normal distribution with mean 2.0
            tr = jax.jit(normal_draws.simulate)(key, (2.0,))
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.repeat(n=n)(self)

    def scan(
        self: "GenerativeFunction[tuple[Carry, Y]]",
        /,
        *,
        n: int | None = None,
    ) -> "GenerativeFunction[tuple[Carry, Y]]":
        """
        When called on a [`genjax.GenerativeFunction`][] of type `(c, a) -> (c, b)`, returns a new [`genjax.GenerativeFunction`][] of type `(c, [a]) -> (c, [b])` where

        - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
        - `a` may be a primitive, an array type or a pytree (container) type with array leaves
        - `b` may be a primitive, an array type or a pytree (container) type with array leaves.

        The values traced by each call to the original generative function will be nested under an integer index that matches the loop iteration index that generated it.

        For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

        When the type of `xs` in the snippet below (denoted `[a]` above) is an array type or None, and the type of `ys` in the snippet below (denoted `[b]` above) is an array type, the semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

        ```python
        import numpy as np


        def scan(f, init, xs, length=None):
            if xs is None:
                xs = [None] * length
            carry = init
            ys = []
            for x in xs:
                carry, y = f(carry, x)
                ys.append(y)
            return carry, np.stack(ys)
        ```

        Unlike that Python version, both `xs` and `ys` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. `None` is actually a special case of this, as it represents an empty pytree.

        The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
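
        The sketch below runs the reference loop above on the deterministic `add_and_square_step` kernel from the examples below, to show what the retval should contain: the final carry and the per-step outputs. It is plain Python, with a list standing in for `np.stack`, and `scan_reference` is a hypothetical name:

        ```python
        def scan_reference(f, init, xs, length=None):
            # Same loop as the reference above; returns a list instead of np.stack(ys).
            if xs is None:
                xs = [None] * length
            carry = init
            ys = []
            for x in xs:
                carry, y = f(carry, x)
                ys.append(y)
            return carry, ys


        def add_and_square_step(total, x):
            # Carry the running sum; emit the square of the *previous* sum.
            return total + x, total * total


        final, ys = scan_reference(add_and_square_step, 0.0, [1.0] * 10)
        print(final)  # 10.0
        print(ys)  # [0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0]
        ```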

        Args:
            n: optional integer specifying the number of loop iterations, which (if supplied) must agree with the sizes of leading axes of the arrays in the returned function's second argument. If supplied then the returned generative function can take `None` as its second argument.

        Returns:
            A new [`genjax.GenerativeFunction`][] that takes an initial loop-carried value and the inputs `xs` (or `None`). It returns the final loop-carried value together with the per-iteration outputs `ys`, stacked along a new leading axis.

        Examples:
            Scan for 1000 iterations with no array input:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax
            import genjax


            @genjax.gen
            def random_walk_step(prev, _):
                x = genjax.normal(prev, 1.0) @ "x"
                return x, None


            random_walk = random_walk_step.scan(n=1000)

            init = 0.5
            key = jax.random.key(314159)

            tr = jax.jit(random_walk.simulate)(key, (init, None))
            print(tr.render_html())
            ```

            Scan across an input array:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax.numpy as jnp


            @genjax.gen
            def add_and_square_step(sum, x):
                new_sum = sum + x
                return new_sum, sum * sum


            # notice no `n` parameter supplied:
            add_and_square_all = add_and_square_step.scan()
            init = 0.0
            xs = jnp.ones(10)

            tr = jax.jit(add_and_square_all.simulate)(key, (init, xs))

            # The retval has the final carry and an array of all `sum*sum` returned.
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.scan(n=n)(self)

    def accumulate(self) -> "GenerativeFunction[R]":
        """
        When called on a [`genjax.GenerativeFunction`][] of type `(c, a) -> c`, returns a new [`genjax.GenerativeFunction`][] of type `(c, [a]) -> [c]` where

        - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
        - `[c]` is an array of all loop-carried values seen during iteration (including the first)
        - `a` may be a primitive, an array type or a pytree (container) type with array leaves

        All traced values are nested under an index.

        For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

        The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation (note the similarity to [`itertools.accumulate`](https://docs.python.org/3/library/itertools.html#itertools.accumulate)):

        ```python
        def accumulate(f, init, xs):
            carry = init
            carries = [init]
            for x in xs:
                carry = f(carry, x)
                carries.append(carry)
            return carries
        ```

        Unlike that Python version, both `xs` and `carries` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

        The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
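
        For the `add` kernel in the example below (initial carry `0.0`, ten ones as input), the reference semantics give eleven carries, because the initial value is included. The same result can be reproduced with `itertools.accumulate`, whose `initial` keyword (Python 3.8+) plays the role of `init`. This is a stdlib sketch, not GenJAX code:

        ```python
        from itertools import accumulate


        def add(total, x):
            # Same kernel as the `add` example below: the new carry is the running sum.
            return total + x


        # `initial` prepends the initial carry, matching `carries = [init]` above.
        carries = list(accumulate([1.0] * 10, add, initial=0.0))
        print(len(carries), carries[-1])  # 11 10.0
        ```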

        Examples:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax
            import genjax
            import jax.numpy as jnp


            @genjax.accumulate()
            @genjax.gen
            def add(sum, x):
                new_sum = sum + x
                return new_sum


            init = 0.0
            key = jax.random.key(314159)
            xs = jnp.ones(10)

            tr = jax.jit(add.simulate)(key, (init, xs))
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.accumulate()(self)

    def reduce(self) -> "GenerativeFunction[R]":
        """
        When called on a [`genjax.GenerativeFunction`][] of type `(c, a) -> c`, returns a new [`genjax.GenerativeFunction`][] of type `(c, [a]) -> c` where

        - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
        - `a` may be a primitive, an array type or a pytree (container) type with array leaves

        All traced values are nested under an index.

        For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

        The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation (note the similarity to [`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce)):

        ```python
        def reduce(f, init, xs):
            carry = init
            for x in xs:
                carry = f(carry, x)
            return carry
        ```

        Unlike that Python version, both `xs` and `carry` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

        The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
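
        The reference loop computes the same thing as `functools.reduce(f, xs, init)`, though the argument order differs. Here is a stdlib sketch that reproduces the final carry expected from the example below:

        ```python
        from functools import reduce


        def add(total, x):
            # Same kernel as the `add` example below.
            return total + x


        # Initial carry 0.0, ten ones as input.
        final = reduce(add, [1.0] * 10, 0.0)
        print(final)  # 10.0
        ```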

        Examples:
            Sum an array of numbers:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax
            import genjax
            import jax.numpy as jnp


            @genjax.reduce()
            @genjax.gen
            def add(sum, x):
                new_sum = sum + x
                return new_sum


            init = 0.0
            key = jax.random.key(314159)
            xs = jnp.ones(10)

            tr = jax.jit(add.simulate)(key, (init, xs))
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.reduce()(self)

    def iterate(
        self,
        /,
        *,
        n: int,
    ) -> "GenerativeFunction[R]":
        """
        When called on a [`genjax.GenerativeFunction`][] of type `a -> a`, returns a new [`genjax.GenerativeFunction`][] of type `a -> [a]` where

        - `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
        - `[a]` is an array of all `a`, `f(a)`, `f(f(a))` etc. values seen during iteration.

        All traced values are nested under an index.

        The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

        ```python
        def iterate(f, n, init):
            input = init
            seen = [init]
            for _ in range(n):
                input = f(input)
                seen.append(input)
            return seen
        ```

        `init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

        The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
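
        The result has `n + 1` entries because `init` is included. Here is a plain-Python sketch that runs the reference loop on the `inc` kernel from the example below. `iterate_reference` is a hypothetical name, and the loop variable is renamed to avoid shadowing the builtin `input`:

        ```python
        def iterate_reference(f, n, init):
            # Keep the initial value plus n successive applications of f.
            value = init
            seen = [init]
            for _ in range(n):
                value = f(value)
                seen.append(value)
            return seen


        seen = iterate_reference(lambda x: x + 1, 100, 0.0)
        print(len(seen), seen[0], seen[-1])  # 101 0.0 100.0
        ```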

        Args:
            n: the number of iterations to run.

        Examples:
            Iterative addition, returning all intermediate sums:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax
            import genjax


            @genjax.iterate(n=100)
            @genjax.gen
            def inc(x):
                return x + 1


            init = 0.0
            key = jax.random.key(314159)

            tr = jax.jit(inc.simulate)(key, (init,))
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.iterate(n=n)(self)

    def iterate_final(
        self,
        /,
        *,
        n: int,
    ) -> "GenerativeFunction[R]":
        """
        When called on a [`genjax.GenerativeFunction`][] of type `a -> a`, returns a new [`genjax.GenerativeFunction`][] of type `a -> a` where

        - `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
        - the original function is invoked `n` times with each input coming from the previous invocation's output, so that the new function returns $f^n(a)$

        All traced values are nested under an index.

        The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

        ```python
        def iterate_final(f, n, init):
            ret = init
            for _ in range(n):
                ret = f(ret)
            return ret
        ```

        `init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

        The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
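
        With the `inc` kernel from the example below, the returned value is $f^{100}(0.0) = 100.0$. That is also the last entry `iterate` would return for the same inputs. A plain-Python sketch follows (`iterate_final_reference` is a hypothetical name):

        ```python
        def iterate_final_reference(f, n, init):
            # Apply f n times, keeping only the last value, i.e. f^n(init).
            ret = init
            for _ in range(n):
                ret = f(ret)
            return ret


        print(iterate_final_reference(lambda x: x + 1, 100, 0.0))  # 100.0
        ```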

        Args:
            n: the number of iterations to run.

        Examples:
            Iterative addition:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax
            import genjax


            @genjax.iterate_final(n=100)
            @genjax.gen
            def inc(x):
                return x + 1


            init = 0.0
            key = jax.random.key(314159)

            tr = jax.jit(inc.simulate)(key, (init,))
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.iterate_final(n=n)(self)

    def masked_iterate(self) -> "GenerativeFunction[R]":
        """
        Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a list of values of type `a`.

        The returned function accepts an additional argument `[mask]`: an array of booleans with one entry per iteration, indicating whether that iteration is active. It returns a `Mask` wrapping the list of per-iteration results, with the input `[mask]` as its flags.

        All values traced by the kernel generative function are recorded (with an added leading axis due to the scan). However, only the iterations whose `[mask]` flag is `True` are accounted for in inference, notably in score computations.

        Example:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax
            import genjax
            import jax.numpy as jnp

            masks = jnp.array([True, False, True])


            # Create a kernel generative function
            @genjax.gen
            def step(x):
                _ = (
                    genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
                    @ "rats"
                )
                return x


            # Create a model using masked_iterate
            model = step.masked_iterate()

            # Simulate from the model
            key = jax.random.key(0)
            mask_steps = jnp.arange(10) < 5
            tr = model.simulate(key, (0.0, mask_steps))
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.masked_iterate()(self)

    def masked_iterate_final(self) -> "GenerativeFunction[R]":
        """
        Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a value of type `a`.

        The returned function accepts an additional argument `[mask]`, an array of booleans with one entry per iteration. At each iteration, it returns the result of the original operation if the corresponding flag is `True`, and passes its input through unchanged if the flag is `False`.

        All values traced by the kernel generative function are recorded (with an added leading axis due to the scan). However, only the iterations whose `[mask]` flag is `True` are accounted for in inference, notably in score computations.
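
        The per-step rule for the returned value can be sketched in plain Python: apply the kernel where the flag is `True` and pass the value through where it is `False`. This sketch ignores tracing and scores entirely and models only the returned value. `masked_iterate_final_reference` is a hypothetical name, and it assumes the kernel runs once per mask entry, as in a scan over `[mask]`:

        ```python
        def masked_iterate_final_reference(f, init, mask):
            # One step per mask entry; masked-off steps leave the value unchanged.
            value = init
            for flag in mask:
                value = f(value) if flag else value
            return value


        # Mirrors `mask_steps = jnp.arange(10) < 5`: only the first five steps apply.
        mask_steps = [i < 5 for i in range(10)]
        print(masked_iterate_final_reference(lambda x: x + 1.0, 0.0, mask_steps))  # 5.0
        ```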

        Example:
            ```python exec="yes" html="true" source="material-block" session="scan"
            import jax
            import genjax
            import jax.numpy as jnp

            masks = jnp.array([True, False, True])


            # Create a kernel generative function
            @genjax.gen
            def step(x):
                _ = (
                    genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
                    @ "rats"
                )
                return x


            # Create a model using masked_iterate_final
            model = step.masked_iterate_final()

            # Simulate from the model
            key = jax.random.key(0)
            mask_steps = jnp.arange(10) < 5
            tr = model.simulate(key, (0.0, mask_steps))
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.masked_iterate_final()(self)

    def mask(self, /) -> "GenerativeFunction[genjax.Mask[R]]":
        """
        Enables dynamic masking of generative functions. Returns a new [`genjax.GenerativeFunction`][] like `self`, but which accepts an additional boolean first argument.

        If `True`, `self` is invoked as usual and contributes to the score, with the same semantics as invoking `self` without masking. If `False`, the invocation of `self` is masked out, and its contribution to the score is ignored.

        The return value type is a `Mask`, with a flag value equal to the supplied boolean.

        Returns:
            The masked version of the original [`genjax.GenerativeFunction`][].

        Examples:
            Masking a normal draw:
            ```python exec="yes" html="true" source="material-block" session="mask"
            import genjax, jax


            @genjax.gen
            def normal_draw(mean):
                return genjax.normal(mean, 1.0) @ "x"


            masked_normal_draw = normal_draw.mask()

            key = jax.random.key(314159)
            tr = jax.jit(masked_normal_draw.simulate)(
                key,
                (
                    False,
                    2.0,
                ),
            )
            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.mask(self)

    def or_else(self, gen_fn: "GenerativeFunction[R]", /) -> "GenerativeFunction[R]":
        """
        Returns a [`GenerativeFunction`][genjax.GenerativeFunction] that accepts

        - a boolean argument
        - an argument tuple for `self`
        - an argument tuple for the supplied `gen_fn`

        and acts like `self` when the boolean is `True` or like `gen_fn` otherwise.

        Args:
            gen_fn: called when the boolean argument is `False`.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="gen-fn"
            import jax
            import jax.numpy as jnp
            import genjax


            @genjax.gen
            def if_model(x):
                return genjax.normal(x, 1.0) @ "if_value"


            @genjax.gen
            def else_model(x):
                return genjax.normal(x, 5.0) @ "else_value"


            @genjax.gen
            def model(toss: bool):
                # Note that the returned model takes a new boolean predicate in
                # addition to argument tuples for each branch.
                return if_model.or_else(else_model)(toss, (1.0,), (10.0,)) @ "tossed"


            key = jax.random.key(314159)

            tr = jax.jit(model.simulate)(key, (True,))

            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.or_else(self, gen_fn)

    def switch(self, *branches: "GenerativeFunction[R]") -> "genjax.Switch[R]":
        """
        Given `n` [`genjax.GenerativeFunction`][] inputs, returns a new [`genjax.GenerativeFunction`][] that accepts `n+2` arguments:

        - an index in the range $[0, n+1)$
        - a tuple of arguments for `self` and each of the input generative functions (`n+1` total tuples)

        and executes the generative function at the supplied index with its provided arguments.

        If `index` is out of bounds, `index` is clamped to within bounds.
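
        The clamping rule can be sketched in plain Python. With `self` plus `n` extra branches there are `n + 1` branches, so valid indices run from `0` to `n`. Any index outside that range is moved to the nearest valid one. `switch_reference` is a hypothetical stand-in that ignores tracing:

        ```python
        def switch_reference(index, branches, branch_args):
            # Clamp the index into [0, len(branches) - 1], then run that branch
            # with its own argument tuple.
            i = min(max(index, 0), len(branches) - 1)
            return branches[i](*branch_args[i])


        branches = [lambda: "branch_1", lambda: "branch_2"]
        print(switch_reference(1, branches, [(), ()]))  # branch_2
        print(switch_reference(7, branches, [(), ()]))  # branch_2 (clamped)
        print(switch_reference(-3, branches, [(), ()]))  # branch_1 (clamped)
        ```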

        Examples:
            ```python exec="yes" html="true" source="material-block" session="switch"
            import jax, genjax


            @genjax.gen
            def branch_1():
                x = genjax.normal(0.0, 1.0) @ "x1"


            @genjax.gen
            def branch_2():
                x = genjax.bernoulli(0.3) @ "x2"


            switch = branch_1.switch(branch_2)

            key = jax.random.key(314159)
            jitted = jax.jit(switch.simulate)

            # Select `branch_2` by providing 1:
            tr = jitted(key, (1, (), ()))

            print(tr.render_html())
            ```
        """
        import genjax

        return genjax.switch(self, *branches)

    def mix(self, *fns: "GenerativeFunction[R]") -> "GenerativeFunction[R]":
        """
        Takes any number of [`genjax.GenerativeFunction`][]s and returns a new [`genjax.GenerativeFunction`][] that represents a mixture model.

        The returned generative function takes the following arguments:

        - `mixture_logits`: Logits for the categorical distribution used to select a component.
        - `*args`: Argument tuples for `self` and each of the input generative functions

        and samples from `self` or one of the input generative functions based on a draw from a categorical distribution defined by the provided mixture logits.
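
        Logits define the categorical distribution through a softmax. For the `[0.3, 0.7]` logits in the example below, `component2` is therefore chosen with probability $e^{0.7}/(e^{0.3}+e^{0.7}) = 1/(1+e^{-0.4})$, about 0.599. Here is a stdlib sketch of that arithmetic and of the selection step (`softmax` and `pick_component` are hypothetical helpers, not GenJAX APIs):

        ```python
        import math
        import random


        def softmax(logits):
            # Subtract the max for numerical stability before exponentiating.
            m = max(logits)
            exps = [math.exp(z - m) for z in logits]
            total = sum(exps)
            return [e / total for e in exps]


        def pick_component(rng, logits):
            # Draw a component index from the categorical distribution given by logits.
            return rng.choices(range(len(logits)), weights=softmax(logits))[0]


        probs = softmax([0.3, 0.7])
        print([round(p, 3) for p in probs])  # [0.401, 0.599]
        idx = pick_component(random.Random(0), [0.3, 0.7])
        ```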

        Args:
            *fns: Variable number of [`genjax.GenerativeFunction`][]s to be mixed with `self`.

        Returns:
            A new [`genjax.GenerativeFunction`][] representing the mixture model.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="mix"
            import jax
            import genjax


            # Define component generative functions
            @genjax.gen
            def component1(x):
                return genjax.normal(x, 1.0) @ "y"


            @genjax.gen
            def component2(x):
                return genjax.normal(x, 2.0) @ "y"


            # Create mixture model
            mixture = component1.mix(component2)

            # Use the mixture model
            key = jax.random.key(0)
            logits = jax.numpy.array([0.3, 0.7])  # Favors component2
            trace = mixture.simulate(key, (logits, (0.0,), (7.0,)))
            print(trace.render_html())
            ```
        """
        import genjax

        return genjax.mix(self, *fns)

    def dimap(
        self,
        /,
        *,
        pre: Callable[..., ArgTuple],
        post: Callable[[tuple[Any, ...], ArgTuple, R], S],
    ) -> "GenerativeFunction[S]":
        """
        Returns a new [`genjax.GenerativeFunction`][] that applies pre- and post-processing functions to the arguments and return value of `self`.

        !!! info
            Prefer [`genjax.GenerativeFunction.map`][] if you only need to transform the return value, or [`genjax.GenerativeFunction.contramap`][] if you only need to transform the arguments.
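
        Ignoring tracing, the wrapper composes as `post(args, pre(*args), f(*pre(*args)))`. `pre` maps the caller's arguments to a new argument tuple. `post` sees the original arguments, the transformed tuple, and the return value, consistent with the type hint for `post` in the signature. Below is a plain-Python sketch with a deterministic stand-in for the wrapped function; `dimap_reference` and `mean_only` are hypothetical names:

        ```python
        def dimap_reference(f, pre, post):
            # Hypothetical stand-in for `dimap`, on ordinary Python functions.
            def wrapped(*args):
                xformed = pre(*args)  # must be a tuple of arguments
                retval = f(*xformed)
                return post(args, xformed, retval)

            return wrapped


        def pre_process(x, y):
            return (x + 1, y * 2)


        def post_process(args, xformed, retval):
            return retval**2


        def mean_only(x, y):
            # Deterministic stand-in for `model`: returns the mean instead of sampling.
            return x


        g = dimap_reference(mean_only, pre_process, post_process)
        print(g(2.0, 3.0))  # 9.0
        ```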

        Args:
            pre: A callable that preprocesses the arguments before passing them to the wrapped function. Note that `pre` must return a _tuple_ of arguments, not a bare argument.
            post: A callable that postprocesses the return value of the wrapped function. It receives the original arguments, the preprocessed argument tuple, and the return value.

        Returns:
            A new [`genjax.GenerativeFunction`][] with `pre` and `post` applied.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="dimap"
            import jax, genjax


            # Define pre- and post-processing functions
            def pre_process(x, y):
                return (x + 1, y * 2)


            def post_process(args, xformed, retval):
                return retval**2


            @genjax.gen
            def model(x, y):
                return genjax.normal(x, y) @ "z"


            dimap_model = model.dimap(pre=pre_process, post=post_process)

            # Use the dimap model
            key = jax.random.key(0)
            trace = dimap_model.simulate(key, (2.0, 3.0))

            print(trace.render_html())
            ```
        """
        import genjax

        return genjax.dimap(pre=pre, post=post)(self)

    def map(self, f: Callable[[R], S]) -> "GenerativeFunction[S]":
        """
        Specialized version of [`genjax.dimap`][] where only the post-processing function is applied.

        Args:
            f: A callable that postprocesses the return value of the wrapped function.

        Returns:
            A [`genjax.GenerativeFunction`][] that acts like `self`, with a post-processing function applied to its return value.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="map"
            import jax, genjax


            # Define a post-processing function
            def square(x):
                return x**2


            @genjax.gen
            def model(x):
                return genjax.normal(x, 1.0) @ "z"


            map_model = model.map(square)

            # Use the map model
            key = jax.random.key(0)
            trace = map_model.simulate(key, (2.0,))

            print(trace.render_html())
            ```
        """
        import genjax

        return genjax.map(f=f)(self)

    def contramap(self, f: Callable[..., ArgTuple]) -> "GenerativeFunction[R]":
        """
        Specialized version of [`genjax.GenerativeFunction.dimap`][] where only the pre-processing function is applied.

        Args:
            f: A callable that preprocesses the arguments of the wrapped function. Note that `f` must return a _tuple_ of arguments, not a bare argument.

        Returns:
            A [`genjax.GenerativeFunction`][] that acts like `self`, with a pre-processing function applied to its arguments.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="contramap"
            import jax, genjax


            # Define a pre-processing function.
            # Note that this function must return a tuple of arguments!
            def add_one(x):
                return (x + 1,)


            @genjax.gen
            def model(x):
                return genjax.normal(x, 1.0) @ "z"


            contramap_model = model.contramap(add_one)

            # Use the contramap model
            key = jax.random.key(0)
            trace = contramap_model.simulate(key, (2.0,))

            print(trace.render_html())
            ```
        """
        import genjax

        return genjax.contramap(f=f)(self)

    #####################
    # GenSP / inference #
    #####################

    def marginal(
        self,
        /,
        *,
        selection: Any | None = None,
        algorithm: Any | None = None,
    ) -> "genjax.Marginal[R]":
        from genjax import Selection, marginal

        if selection is None:
            selection = Selection.all()

        return marginal(selection=selection, algorithm=algorithm)(self)

__abstract_call__

__abstract_call__(*args) -> R

Used to support JAX tracing, although this default implementation involves no JAX operations: it returns the return value of the zero-filled trace built by get_zero_trace.

Generative functions may customize this to improve compilation time.

Source code in src/genjax/_src/core/generative/generative_function.py
def __abstract_call__(self, *args) -> R:
    """Used to support JAX tracing, although this default implementation involves no
    JAX operations: it returns the return value of the zero-filled trace built by
    `get_zero_trace`.

    Generative functions may customize this to improve compilation time.
    """
    return self.get_zero_trace(*args).get_retval()

accumulate

accumulate() -> GenerativeFunction[R]

When called on a genjax.GenerativeFunction of type (c, a) -> c, returns a new genjax.GenerativeFunction of type (c, [a]) -> [c] where

  • c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • [c] is an array of all loop-carried values seen during iteration (including the first)
  • a may be a primitive, an array type or a pytree (container) type with array leaves

All traced values are nested under an index.

For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation (note the similarity to itertools.accumulate):

def accumulate(f, init, xs):
    carry = init
    carries = [init]
    for x in xs:
        carry = f(carry, x)
        carries.append(carry)
    return carries

Unlike that Python version, both xs and carries may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Examples:

import jax
import genjax
import jax.numpy as jnp


@genjax.accumulate()
@genjax.gen
def add(sum, x):
    new_sum = sum + x
    return new_sum


init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)

tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
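
To make the loop-carried semantics concrete without any tracing or randomness, here is a plain-Python sketch (not GenJAX code) that mirrors the deterministic body of add from the example above. It follows the reference semantics given earlier; accumulate_reference is a hypothetical name. The result has len(xs) + 1 entries because the initial carry is included, and Python's itertools.accumulate with initial= produces the same sequence.

```python
import itertools


def accumulate_reference(f, init, xs):
    # Hypothetical reference helper: plain-Python version of the semantics above.
    carry = init
    carries = [init]  # the initial carry is part of the output
    for x in xs:
        carry = f(carry, x)
        carries.append(carry)
    return carries


def add(total, x):
    # Deterministic stand-in for the body of the `add` generative function.
    return total + x


xs = [1.0] * 10
carries = accumulate_reference(add, 0.0, xs)

# itertools.accumulate with `initial=` yields the same sequence.
assert carries == list(itertools.accumulate(xs, add, initial=0.0))
print(len(carries), carries[-1])  # 11 10.0
```

In the GenJAX version, the returned array would correspondingly carry a leading axis of length 11 for these inputs.
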
Source code in src/genjax/_src/core/generative/generative_function.py
def accumulate(self) -> "GenerativeFunction[R]":
    """
    When called on a [`genjax.GenerativeFunction`][] of type `(c, a) -> c`, returns a new [`genjax.GenerativeFunction`][] of type `(c, [a]) -> [c]` where

    - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `[c]` is an array of all loop-carried values seen during iteration (including the first)
    - `a` may be a primitive, an array type or a pytree (container) type with array leaves

    All traced values are nested under an index.

    For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation (note the similarity to [`itertools.accumulate`](https://docs.python.org/3/library/itertools.html#itertools.accumulate)):

    ```python
    def accumulate(f, init, xs):
        carry = init
        carries = [init]
        for x in xs:
            carry = f(carry, x)
            carries.append(carry)
        return carries
    ```

    Unlike that Python version, both `xs` and `carries` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

    The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Examples:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax
        import jax.numpy as jnp


        @genjax.accumulate()
        @genjax.gen
        def add(sum, x):
            new_sum = sum + x
            return new_sum


        init = 0.0
        key = jax.random.key(314159)
        xs = jnp.ones(10)

        tr = jax.jit(add.simulate)(key, (init, xs))
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.accumulate()(self)

assess abstractmethod

assess(
    sample: ChoiceMap, args: Arguments
) -> tuple[Score, R]

Return the score and the return value when the generative function is invoked with the provided arguments, and constrained to take the provided sample as the sampled value.

It is an error if the provided sample value is off the support of the distribution over the ChoiceMap type, or otherwise induces a partial constraint on the execution of the generative function (which would require the generative function to provide an edit implementation which responds to the EditRequest induced by the importance interface).

Examples:

This method is similar to density evaluation interfaces for distributions.

from genjax import normal
from genjax import ChoiceMapBuilder as C

sample = C.v(1.0)
score, retval = normal.assess(sample, (1.0, 1.0))
print((score, retval))
(Array(-0.9189385, dtype=float32), 1.0)

But it also works with generative functions that sample from spaces with more structure:

from genjax import gen
from genjax import normal
from genjax import ChoiceMapBuilder as C


@gen
def model():
    v1 = normal(0.0, 1.0) @ "v1"
    v2 = normal(v1, 1.0) @ "v2"


sample = C.kw(v1=1.0, v2=0.0)
score, retval = model.assess(sample, ())
print((score, retval))
(Array(-2.837877, dtype=float32), None)
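
The scores printed above can be reproduced by hand. A minimal sketch, assuming that for a fully constrained sample the score is the sum of the log densities of each constrained choice; the Gaussian log density is written out explicitly here and is not a GenJAX call:

```python
import math


def normal_logpdf(x, mu, sigma):
    # Log density of a Normal(mu, sigma) distribution evaluated at x.
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


# normal.assess(C.v(1.0), (1.0, 1.0)): a single choice at the mean.
score_single = normal_logpdf(1.0, 1.0, 1.0)

# model.assess(C.kw(v1=1.0, v2=0.0), ()): v1 ~ N(0, 1), v2 ~ N(v1, 1).
score_model = normal_logpdf(1.0, 0.0, 1.0) + normal_logpdf(0.0, 1.0, 1.0)

print(round(score_single, 6), round(score_model, 6))  # -0.918939 -2.837877
```
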
Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def assess(
    self,
    sample: ChoiceMap,
    args: Arguments,
) -> tuple[Score, R]:
    """
    Return [the score][genjax.core.Trace.get_score] and [the return value][genjax.core.Trace.get_retval] when the generative function is invoked with the provided arguments, and constrained to take the provided sample as the sampled value.

    It is an error if the provided sample value is off the support of the distribution over the `ChoiceMap` type, or otherwise induces a partial constraint on the execution of the generative function (which would require the generative function to provide an `edit` implementation which responds to the `EditRequest` induced by the [`importance`][genjax.core.GenerativeFunction.importance] interface).

    Examples:
        This method is similar to density evaluation interfaces for distributions.
        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import normal
        from genjax import ChoiceMapBuilder as C

        sample = C.v(1.0)
        score, retval = normal.assess(sample, (1.0, 1.0))
        print((score, retval))
        ```

        But it also works with generative functions that sample from spaces with more structure:

        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import gen
        from genjax import normal
        from genjax import ChoiceMapBuilder as C


        @gen
        def model():
            v1 = normal(0.0, 1.0) @ "v1"
            v2 = normal(v1, 1.0) @ "v2"


        sample = C.kw(v1=1.0, v2=0.0)
        score, retval = model.assess(sample, ())
        print((score, retval))
        ```
    """

contramap

contramap(
    f: Callable[..., ArgTuple]
) -> GenerativeFunction[R]

Specialized version of genjax.GenerativeFunction.dimap where only the pre-processing function is applied.

Parameters:

  • f (Callable[..., ArgTuple], required): A callable that preprocesses the arguments of the wrapped function. Note that f must return a tuple of arguments, not a bare argument.

Returns:

  • GenerativeFunction[R]: A genjax.GenerativeFunction that acts like self, with a pre-processing function applied to its arguments.

Examples:

import jax, genjax


# Define a pre-processing function.
# Note that this function must return a tuple of arguments!
def add_one(x):
    return (x + 1,)


@genjax.gen
def model(x):
    return genjax.normal(x, 1.0) @ "z"


contramap_model = model.contramap(add_one)

# Use the contramap model
key = jax.random.key(0)
trace = contramap_model.simulate(key, (2.0,))

print(trace.render_html())
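
Why must f return a tuple? Conceptually, the pre-processed tuple is splatted into the wrapped function's arguments. The plain-Python sketch below illustrates that convention with a hypothetical contramap_reference helper and a deterministic stand-in for model (it returns the mean x rather than sampling); it is not GenJAX's implementation.

```python
def contramap_reference(f, g):
    # Hypothetical helper: pre-process the arguments with f, then splat into g.
    def wrapped(*args):
        new_args = f(*args)
        if not isinstance(new_args, tuple):
            raise TypeError("the pre-processing function must return a tuple")
        return g(*new_args)

    return wrapped


def add_one(x):
    return (x + 1,)


def model_mean(x):
    # Deterministic stand-in: the mean of normal(x, 1.0) in `model`.
    return x


contramap_mean = contramap_reference(add_one, model_mean)
print(contramap_mean(2.0))  # 3.0
```

With the arguments (2.0,) from the example, the wrapped model therefore sees x = 3.0, so "z" is drawn around 3.0.
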
Source code in src/genjax/_src/core/generative/generative_function.py
def contramap(self, f: Callable[..., ArgTuple]) -> "GenerativeFunction[R]":
    """
    Specialized version of [`genjax.GenerativeFunction.dimap`][] where only the pre-processing function is applied.

    Args:
        f: A callable that preprocesses the arguments of the wrapped function. Note that `f` must return a _tuple_ of arguments, not a bare argument.

    Returns:
        A [`genjax.GenerativeFunction`][] that acts like `self`, with a pre-processing function applied to its arguments.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="contramap"
        import jax, genjax


        # Define a pre-processing function.
        # Note that this function must return a tuple of arguments!
        def add_one(x):
            return (x + 1,)


        @genjax.gen
        def model(x):
            return genjax.normal(x, 1.0) @ "z"


        contramap_model = model.contramap(add_one)

        # Use the contramap model
        key = jax.random.key(0)
        trace = contramap_model.simulate(key, (2.0,))

        print(trace.render_html())
        ```
    """
    import genjax

    return genjax.contramap(f=f)(self)

dimap

dimap(
    *,
    pre: Callable[..., ArgTuple],
    post: Callable[[tuple[Any, ...], ArgTuple, R], S]
) -> GenerativeFunction[S]

Returns a new genjax.GenerativeFunction that applies pre- and post-processing functions to the arguments and return value of self.

Info

Prefer genjax.GenerativeFunction.map if you only need to transform the return value, or genjax.GenerativeFunction.contramap if you only need to transform the arguments.

Parameters:

  • pre (Callable[..., ArgTuple], required): A callable that preprocesses the arguments before passing them to the wrapped function. Note that pre must return a tuple of arguments, not a bare argument.
  • post (Callable[[tuple[Any, ...], ArgTuple, R], S], required): A callable that postprocesses the return value of the wrapped function. It receives the original arguments, the pre-processed arguments, and the return value.

Returns:

  • GenerativeFunction[S]: A new genjax.GenerativeFunction with pre and post applied.

Examples:

import jax, genjax


# Define pre- and post-processing functions
def pre_process(x, y):
    return (x + 1, y * 2)


def post_process(args, xformed, retval):
    return retval**2


@genjax.gen
def model(x, y):
    return genjax.normal(x, y) @ "z"


dimap_model = model.dimap(pre=pre_process, post=post_process)

# Use the dimap model
key = jax.random.key(0)
trace = dimap_model.simulate(key, (2.0, 3.0))

print(trace.render_html())
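
The argument order of post follows its type, Callable[[tuple[Any, ...], ArgTuple, R], S]: the original arguments, the pre-processed arguments, then the return value. A plain-Python sketch of that data flow, using a hypothetical dimap_reference helper and a deterministic stand-in for model that returns the mean x instead of sampling (not GenJAX's implementation):

```python
def dimap_reference(pre, post, g):
    # Hypothetical helper: pre-process, call g, then post-process.
    def wrapped(*args):
        xformed = pre(*args)  # must be a tuple
        retval = g(*xformed)
        return post(args, xformed, retval)

    return wrapped


def pre_process(x, y):
    return (x + 1, y * 2)


def post_process(args, xformed, retval):
    return retval**2


def model_mean(x, y):
    # Deterministic stand-in: the mean of normal(x, y) in `model`.
    return x


dimap_mean = dimap_reference(pre_process, post_process, model_mean)
print(dimap_mean(2.0, 3.0))  # 9.0
```

So with (2.0, 3.0), the example draws "z" from normal(3.0, 6.0) and returns its square.
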
Source code in src/genjax/_src/core/generative/generative_function.py
def dimap(
    self,
    /,
    *,
    pre: Callable[..., ArgTuple],
    post: Callable[[tuple[Any, ...], ArgTuple, R], S],
) -> "GenerativeFunction[S]":
    """
    Returns a new [`genjax.GenerativeFunction`][] that applies pre- and post-processing functions to the arguments and return value of `self`.

    !!! info
        Prefer [`genjax.GenerativeFunction.map`][] if you only need to transform the return value, or [`genjax.GenerativeFunction.contramap`][] if you only need to transform the arguments.

    Args:
        pre: A callable that preprocesses the arguments before passing them to the wrapped function. Note that `pre` must return a _tuple_ of arguments, not a bare argument.
        post: A callable that postprocesses the return value of the wrapped function. It receives the original arguments, the pre-processed arguments, and the return value.

    Returns:
        A new [`genjax.GenerativeFunction`][] with `pre` and `post` applied.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="dimap"
        import jax, genjax


        # Define pre- and post-processing functions
        def pre_process(x, y):
            return (x + 1, y * 2)


        def post_process(args, xformed, retval):
            return retval**2


        @genjax.gen
        def model(x, y):
            return genjax.normal(x, y) @ "z"


        dimap_model = model.dimap(pre=pre_process, post=post_process)

        # Use the dimap model
        key = jax.random.key(0)
        trace = dimap_model.simulate(key, (2.0, 3.0))

        print(trace.render_html())
        ```
    """
    import genjax

    return genjax.dimap(pre=pre, post=post)(self)

edit abstractmethod

edit(
    key: PRNGKey,
    trace: Trace[R],
    edit_request: EditRequest,
    argdiffs: Argdiffs,
) -> tuple[Trace[R], Weight, Retdiff[R], EditRequest]

Update a trace in response to an EditRequest, returning a new Trace, an incremental Weight for the new target, a Retdiff return value tagged with change information, and a backward EditRequest which requests the reverse move (to go back to the original trace).

The specification of this interface is parametric over the kind of EditRequest -- responding to an EditRequest instance requires that the generative function provides an implementation of a sequential Monte Carlo move in the SMCP3 framework. Users of inference algorithms are not expected to understand the ingredients, but inference algorithm developers are.

Examples:

Updating a trace in response to a request for a Target change induced by a change to the arguments:

import jax
from genjax import gen, normal, Diff, Update, ChoiceMap as C

key = jax.random.key(0)


@gen
def model(var):
    v1 = normal(0.0, 1.0) @ "v1"
    v2 = normal(v1, var) @ "v2"
    return v2


# Generating an initial trace properly weighted according
# to the target induced by the constraint.
constraint = C.kw(v2=1.0)
initial_tr, w = model.importance(key, constraint, (1.0,))

# Updating the trace to a new target.
new_tr, inc_w, retdiff, bwd_request = model.edit(
    key,
    initial_tr,
    Update(
        C.empty(),
    ),
    Diff.unknown_change((3.0,)),
)

Now, let's inspect the trace:

# Inspect the trace, the sampled values should not have changed!
sample = new_tr.get_choices()
print(sample["v1"], sample["v2"])
-2.4424558 1.0

And the return value diff:

# The return value also should not have changed!
print(retdiff.render_html())

As expected, neither has changed, but the incremental weight is non-zero:

print(inc_w)

Mathematical ingredients behind edit

The edit interface exposes SMCP3 moves. Here, we omit the measure-theoretic description, and refer interested readers to the paper. Informally, the ingredients of such a move are:

  • The previous target \(T\).
  • The new target \(T'\).
  • A pair of kernel probabilistic programs, called \(K\) and \(L\):
    • The K kernel is a kernel probabilistic program which accepts a previous sample \(x_{t-1}\) from \(T\) as an argument, may sample auxiliary randomness \(u_K\), and returns a new sample \(x_t\) approximately distributed according to \(T'\), along with transformed randomness \(u_L\).
    • The L kernel is a kernel probabilistic program which accepts the new sample \(x_t\), and provides a density evaluator for the auxiliary randomness \(u_L\) which K returns, and an inverter \(x_t \mapsto x_{t-1}\) which is almost everywhere the identity function.

The specification of these ingredients is encapsulated in the type signature of the edit interface.

Understanding the edit interface

The edit interface uses the mathematical ingredients described above to perform probability-aware mutations and incremental Weight computations on Trace instances, which allows Gen to provide automation to support inference algorithms like importance sampling, SMC, MCMC and many more.

An EditRequest denotes a function \(tr \mapsto (T, T')\) from traces to a pair of targets (the previous Target \(T\), and the final Target \(T'\)).

Several common types of moves can be requested via the Update type:

from genjax import Update
from genjax import ChoiceMap

g = Update(
    ChoiceMap.empty(),  # Constraint
)

Update carries a constraint which specifies an additional move to be performed; changes to the arguments of the generative function (Argdiffs) are passed separately, as the last argument to edit.

new_tr, inc_w, retdiff, bwd_request = model.edit(
    key,
    initial_tr,
    Update(
        C.kw(v1=3.0),
    ),
    Diff.unknown_change((3.0,)),
)
print((new_tr.get_choices()["v1"], inc_w))

Additional notes on Argdiffs

Argument changes induce changes to the distribution over samples, internal K and L proposals, and (by virtue of changes to \(P\)) target distributions. The Argdiffs type denotes the type of values attached to a change type, a piece of data which indicates how the value has changed from the arguments which created the trace. Generative functions can utilize change type information to inform efficient edit implementations.
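
For the argument-change example above, the incremental weight can be worked out by hand. A minimal sketch, assuming (as in Gen's update semantics) that when no choices are added or removed, the incremental weight is the log ratio of the new and old joint densities of the same choices. The N(v1; 0, 1) factor is unchanged, so only the v2 term, whose standard deviation moves from 1.0 to 3.0, contributes. The value of v1 is taken from the printed trace; the Gaussian log density is written out by hand and is not a GenJAX call.

```python
import math


def normal_logpdf(x, mu, sigma):
    # Log density of a Normal(mu, sigma) distribution evaluated at x.
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


v1, v2 = -2.4424558, 1.0  # choices printed from new_tr above

# Importance weight from `model.importance`: v1 came from the prior,
# so only the constrained v2 contributes.
w = normal_logpdf(v2, v1, 1.0)

# Incremental weight for var: 1.0 -> 3.0 with no new constraints.
# The N(v1; 0, 1) factor cancels between the new and old densities.
inc_w = normal_logpdf(v2, v1, 3.0) - normal_logpdf(v2, v1, 1.0)

print(round(w, 4), round(inc_w, 4))  # -6.8442 4.1683
```

The first value agrees with the importance weight printed earlier; the positive incremental weight reflects that, with v1 far from v2, the observed v2 = 1.0 is more plausible under the wider noise.
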

Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def edit(
    self,
    key: PRNGKey,
    trace: Trace[R],
    edit_request: EditRequest,
    argdiffs: Argdiffs,
) -> tuple[Trace[R], Weight, Retdiff[R], EditRequest]:
    """
    Update a trace in response to an [`EditRequest`][genjax.core.EditRequest], returning a new [`Trace`][genjax.core.Trace], an incremental [`Weight`][genjax.core.Weight] for the new target, a [`Retdiff`][genjax.core.Retdiff] return value tagged with change information, and a backward [`EditRequest`][genjax.core.EditRequest] which requests the reverse move (to go back to the original trace).

    The specification of this interface is parametric over the kind of `EditRequest` -- responding to an `EditRequest` instance requires that the generative function provides an implementation of a sequential Monte Carlo move in the [SMCP3](https://proceedings.mlr.press/v206/lew23a.html) framework. Users of inference algorithms are not expected to understand the ingredients, but inference algorithm developers are.

    Examples:
        Updating a trace in response to a request for a [`Target`][genjax.inference.Target] change induced by a change to the arguments:
        ```python exec="yes" source="material-block" session="core"
        import jax
        from genjax import gen, normal, Diff, Update, ChoiceMap as C

        key = jax.random.key(0)


        @gen
        def model(var):
            v1 = normal(0.0, 1.0) @ "v1"
            v2 = normal(v1, var) @ "v2"
            return v2


        # Generating an initial trace properly weighted according
        # to the target induced by the constraint.
        constraint = C.kw(v2=1.0)
        initial_tr, w = model.importance(key, constraint, (1.0,))

        # Updating the trace to a new target.
        new_tr, inc_w, retdiff, bwd_request = model.edit(
            key,
            initial_tr,
            Update(
                C.empty(),
            ),
            Diff.unknown_change((3.0,)),
        )
        ```

        Now, let's inspect the trace:
        ```python exec="yes" html="true" source="material-block" session="core"
        # Inspect the trace, the sampled values should not have changed!
        sample = new_tr.get_choices()
        print(sample["v1"], sample["v2"])
        ```

        And the return value diff:
        ```python exec="yes" html="true" source="material-block" session="core"
        # The return value also should not have changed!
        print(retdiff.render_html())
        ```

        As expected, neither has changed, but the incremental weight is non-zero:
        ```python exec="yes" html="true" source="material-block" session="core"
        print(inc_w)
        ```

    ## Mathematical ingredients behind edit

    The `edit` interface exposes [SMCP3 moves](https://proceedings.mlr.press/v206/lew23a.html). Here, we omit the measure-theoretic description, and refer interested readers to [the paper](https://proceedings.mlr.press/v206/lew23a.html). Informally, the ingredients of such a move are:

    * The previous target $T$.
    * The new target $T'$.
    * A pair of kernel probabilistic programs, called $K$ and $L$:
        * The K kernel is a kernel probabilistic program which accepts a previous sample $x_{t-1}$ from $T$ as an argument, may sample auxiliary randomness $u_K$, and returns a new sample $x_t$ approximately distributed according to $T'$, along with transformed randomness $u_L$.
        * The L kernel is a kernel probabilistic program which accepts the new sample $x_t$, and provides a density evaluator for the auxiliary randomness $u_L$ which K returns, and an inverter $x_t \\mapsto x_{t-1}$ which is _almost everywhere_ the identity function.

    The specification of these ingredients is encapsulated in the type signature of the `edit` interface.

    ## Understanding the `edit` interface

    The `edit` interface uses the mathematical ingredients described above to perform probability-aware mutations and incremental [`Weight`][genjax.core.Weight] computations on [`Trace`][genjax.core.Trace] instances, which allows Gen to provide automation to support inference algorithms like importance sampling, SMC, MCMC and many more.

    An `EditRequest` denotes a function $tr \\mapsto (T, T')$ from traces to a pair of targets (the previous [`Target`][genjax.inference.Target] $T$, and the final [`Target`][genjax.inference.Target] $T'$).

    Several common types of moves can be requested via the `Update` type:

    ```python exec="yes" source="material-block" session="core"
    from genjax import Update
    from genjax import ChoiceMap

    g = Update(
        ChoiceMap.empty(),  # Constraint
    )
    ```

    `Update` carries a constraint which specifies an additional move to be performed; changes to the arguments of the generative function ([`Argdiffs`][genjax.core.Argdiffs]) are passed separately, as the last argument to `edit`.

    ```python exec="yes" html="true" source="material-block" session="core"
    new_tr, inc_w, retdiff, bwd_request = model.edit(
        key,
        initial_tr,
        Update(
            C.kw(v1=3.0),
        ),
        Diff.unknown_change((3.0,)),
    )
    print((new_tr.get_choices()["v1"], inc_w))
    ```

    **Additional notes on [`Argdiffs`][genjax.core.Argdiffs]**

    Argument changes induce changes to the distribution over samples, internal K and L proposals, and (by virtue of changes to $P$) target distributions. The [`Argdiffs`][genjax.core.Argdiffs] type denotes the type of values attached to a _change type_, a piece of data which indicates how the value has changed from the arguments which created the trace. Generative functions can utilize change type information to inform efficient [`edit`][genjax.core.GenerativeFunction.edit] implementations.
    """
    pass

get_zero_trace

get_zero_trace(*args, **_kwargs) -> Trace[R]

Returns a trace with zero values for all leaves, generated without executing the generative function.

This method is useful for static analysis and shape inference without executing the generative function. It returns a trace with the same structure as a real trace, but filled with zero or default values.

Parameters:

  • *args: The arguments to the generative function.
  • **_kwargs: Ignored keyword arguments.

Returns:

A trace with zero values, matching the structure of a real trace.

Note

This method uses the empty_trace utility function, which creates a trace without spending any FLOPs. The resulting trace has the correct structure but contains placeholder zero values.

Example

import jax
import genjax


@genjax.gen
def weather_model():
    temperature = genjax.normal(20.0, 5.0) @ "temperature"
    is_sunny = genjax.bernoulli(0.7) @ "is_sunny"
    return {"temperature": temperature, "is_sunny": is_sunny}


zero_trace = weather_model.get_zero_trace()
print("Zero trace structure:")
print(zero_trace.render_html())

print("\nActual simulation:")
key = jax.random.key(0)
actual_trace = weather_model.simulate(key, ())
print(actual_trace.render_html())
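
Conceptually, a zero trace keeps the container structure and leaf types of a real trace while replacing every leaf value. The plain-Python sketch below shows that idea on a value shaped like the return value of weather_model, using a hypothetical zeros_like_structure helper over nested dicts, lists and tuples. It is not how empty_trace is implemented; in JAX code the analogous operation would typically be a jax.tree_util.tree_map over the leaves. The numbers in sample are made up for illustration.

```python
def zeros_like_structure(tree):
    # Hypothetical helper: keep the container structure, zero every leaf.
    if isinstance(tree, dict):
        return {k: zeros_like_structure(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(zeros_like_structure(v) for v in tree)
    # Leaves: build the zero value of the leaf's own type (0.0, False, 0, ...).
    return type(tree)(0)


# A value shaped like the one returned by `weather_model`.
sample = {"temperature": 21.7, "is_sunny": True}
print(zeros_like_structure(sample))  # {'temperature': 0.0, 'is_sunny': False}
```
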

Source code in src/genjax/_src/core/generative/generative_function.py
def get_zero_trace(self, *args, **_kwargs) -> Trace[R]:
    """
    Returns a trace with zero values for all leaves, generated without executing the generative function.

    This method is useful for static analysis and shape inference without executing the generative function. It returns a trace with the same structure as a real trace, but filled with zero or default values.

    Args:
        *args: The arguments to the generative function.
        **_kwargs: Ignored keyword arguments.

    Returns:
        A trace with zero values, matching the structure of a real trace.

    Note:
        This method uses the `empty_trace` utility function, which creates a trace without spending any FLOPs. The resulting trace has the correct structure but contains placeholder zero values.

    Example:
        ```python exec="yes" html="true" source="material-block" session="core"
        import jax
        import genjax


        @genjax.gen
        def weather_model():
            temperature = genjax.normal(20.0, 5.0) @ "temperature"
            is_sunny = genjax.bernoulli(0.7) @ "is_sunny"
            return {"temperature": temperature, "is_sunny": is_sunny}


        zero_trace = weather_model.get_zero_trace()
        print("Zero trace structure:")
        print(zero_trace.render_html())

        print("\nActual simulation:")
        key = jax.random.key(0)
        actual_trace = weather_model.simulate(key, ())
        print(actual_trace.render_html())
        ```
    """
    return empty_trace(self, args)

handle_kwargs

handle_kwargs() -> GenerativeFunction[R]

Returns a new GenerativeFunction like self, but where all GFI methods accept a tuple of arguments and a dictionary of keyword arguments.

The returned GenerativeFunction can be invoked with __call__ with no special argument handling (just like the original).

In place of args tuples in GFI methods, the new GenerativeFunction expects a 2-tuple containing:

  1. A tuple containing the original positional arguments.
  2. A dictionary containing the keyword arguments.

This allows for more flexible argument passing, especially useful in contexts where keyword arguments need to be handled separately or passed through multiple layers.

Returns:

  • GenerativeFunction[R]: A new GenerativeFunction that accepts (args_tuple, kwargs_dict) for all GFI methods.

Example
import genjax
import jax


@genjax.gen
def model(x, y, z=1.0):
    _ = genjax.normal(x + y, z) @ "v"
    return x + y + z


key = jax.random.key(0)
kw_model = model.handle_kwargs()

tr = kw_model.simulate(key, ((1.0, 2.0), {"z": 3.0}))
print(tr.render_html())
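
The packed calling convention can be mimicked in plain Python. In the sketch below, handle_kwargs_reference is a hypothetical helper (not GenJAX's IgnoreKwargs), and model_retval is a deterministic stand-in that computes only the return value of model; it does not sample "v".

```python
def handle_kwargs_reference(fn):
    # Hypothetical helper: accept a single (args_tuple, kwargs_dict) pair.
    def wrapped(packed):
        args, kwargs = packed
        return fn(*args, **kwargs)

    return wrapped


def model_retval(x, y, z=1.0):
    # Deterministic stand-in for the return value of `model`.
    return x + y + z


kw_model_retval = handle_kwargs_reference(model_retval)
print(kw_model_retval(((1.0, 2.0), {"z": 3.0})))  # 6.0
print(kw_model_retval(((1.0, 2.0), {})))  # 4.0
```

Passing an empty dictionary falls back to the default z=1.0.
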
Source code in src/genjax/_src/core/generative/generative_function.py
def handle_kwargs(self) -> "GenerativeFunction[R]":
    """
    Returns a new GenerativeFunction like `self`, but where all GFI methods accept a tuple of arguments and a dictionary of keyword arguments.

    The returned GenerativeFunction can be invoked with `__call__` with no special argument handling (just like the original).

    In place of `args` tuples in GFI methods, the new GenerativeFunction expects a 2-tuple containing:

    1. A tuple containing the original positional arguments.
    2. A dictionary containing the keyword arguments.

    This allows for more flexible argument passing, especially useful in contexts where
    keyword arguments need to be handled separately or passed through multiple layers.

    Returns:
        A new GenerativeFunction that accepts (args_tuple, kwargs_dict) for all GFI methods.

    Example:
        ```python exec="yes" html="true" source="material-block" session="core"
        import genjax
        import jax


        @genjax.gen
        def model(x, y, z=1.0):
            _ = genjax.normal(x + y, z) @ "v"
            return x + y + z


        key = jax.random.key(0)
        kw_model = model.handle_kwargs()

        tr = kw_model.simulate(key, ((1.0, 2.0), {"z": 3.0}))
        print(tr.render_html())
        ```
    """
    return IgnoreKwargs(self)

importance

importance(
    key: PRNGKey, constraint: ChoiceMap, args: Arguments
) -> tuple[Trace[R], Weight]

Returns a Trace and a Weight, as a pair that is properly weighted for the target induced by the generative function for the provided constraint and arguments.

Examples:

(Full constraints) A simple example using the importance interface on distributions:

import jax
from genjax import normal
from genjax import ChoiceMapBuilder as C

key = jax.random.key(0)

tr, w = normal.importance(key, C.v(1.0), (0.0, 1.0))
print(tr.get_choices().render_html())

(Internal proposal for partial constraints) Specifying a partial constraint on a StaticGenerativeFunction:

from genjax import flip, uniform, gen
from genjax import ChoiceMapBuilder as C


@gen
def model():
    p = uniform(0.0, 1.0) @ "p"
    f1 = flip(p) @ "f1"
    f2 = flip(p) @ "f2"


tr, w = model.importance(key, C.kw(f1=True, f2=True), ())
print(tr.get_choices().render_html())

Under the hood, creates an EditRequest which requests that the generative function respond with a move from the empty trace (the only possible value for empty target \(\delta_\emptyset\)) to the target induced by the generative function for constraint \(C\) with arguments \(a\).
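
For the partial-constraint example, here is a minimal sketch of what "properly weighted" means, assuming the internal proposal samples the unconstrained choice p from its prior, uniform(0.0, 1.0). Under that assumption the log weight is the log probability of the constrained flips, 2 * log(p), and averaging exp(weight) over many samples estimates the marginal probability P(f1 = True, f2 = True) = E[p^2] = 1/3. The sketch uses Python's random module, not JAX's PRNG, and importance_sample is a hypothetical name.

```python
import math
import random


def importance_sample(rng):
    # Propose the unconstrained choice from its prior: p ~ uniform(0, 1).
    # Using 1 - random() keeps p in (0, 1], so log(p) is always defined.
    p = 1.0 - rng.random()
    # Constrained choices f1 = True and f2 = True each contribute log(p).
    log_weight = 2.0 * math.log(p)
    return p, log_weight


rng = random.Random(0)
n = 100_000
samples = [importance_sample(rng) for _ in range(n)]

# Average of the importance weights estimates P(f1=True, f2=True) = 1/3.
estimate = sum(math.exp(lw) for _, lw in samples) / n
print(round(estimate, 2))
```
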

Source code in src/genjax/_src/core/generative/generative_function.py
def importance(
    self,
    key: PRNGKey,
    constraint: ChoiceMap,
    args: Arguments,
) -> tuple[Trace[R], Weight]:
    """
    Returns a [`Trace`][genjax.core.Trace] and a [`Weight`][genjax.core.Weight], as a pair that is properly weighted for the target induced by the generative function for the provided constraint and arguments.

    Examples:
        (**Full constraints**) A simple example using the `importance` interface on distributions:
        ```python exec="yes" html="true" source="material-block" session="core"
        import jax
        from genjax import normal
        from genjax import ChoiceMapBuilder as C

        key = jax.random.key(0)

        tr, w = normal.importance(key, C.v(1.0), (0.0, 1.0))
        print(tr.get_choices().render_html())
        ```

        (**Internal proposal for partial constraints**) Specifying a _partial_ constraint on a [`StaticGenerativeFunction`][genjax.StaticGenerativeFunction]:
        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import flip, uniform, gen
        from genjax import ChoiceMapBuilder as C


        @gen
        def model():
            p = uniform(0.0, 1.0) @ "p"
            f1 = flip(p) @ "f1"
            f2 = flip(p) @ "f2"


        tr, w = model.importance(key, C.kw(f1=True, f2=True), ())
        print(tr.get_choices().render_html())
        ```

    Under the hood, this creates an [`EditRequest`][genjax.core.EditRequest] that asks the generative function to respond with a move from the _empty_ trace (the only possible value of the _empty_ target $\\delta_\\emptyset$) to the target induced by the generative function for constraint $C$ with arguments $a$.
    """

    return self.generate(
        key,
        constraint,
        args,
    )

iterate

iterate(*, n: int) -> GenerativeFunction[R]

When called on a genjax.GenerativeFunction of type a -> a, returns a new genjax.GenerativeFunction of type a -> [a] where

  • a is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • [a] is an array of all the values a, f(a), f(f(a)), etc. seen during iteration, starting with the initial value.

The values traced by each iteration are nested under that iteration's integer index.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:

def iterate(f, n, init):
    input = init
    seen = [init]
    for _ in range(n):
        input = f(input)
        seen.append(input)
    return seen

init may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

The iterated value a must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters:

  • n (int, required): the number of iterations to run.

Examples:

iterative addition, returning all intermediate sums:

import jax
import genjax


@genjax.iterate(n=100)
@genjax.gen
def inc(x):
    return x + 1


init = 0.0
key = jax.random.key(314159)

tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
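Following the reference semantics given above, the return value of the `inc` example should hold the initial value followed by one entry per iteration, i.e. 101 values from 0.0 to 100.0. The plain-Python sketch below (no GenJAX; `iterate_reference` is a hypothetical name) replays that semantics to make the length and contents explicit.

```python
def iterate_reference(f, n, init):
    """Plain-Python restatement of the iterate semantics above."""
    value = init
    seen = [init]  # the initial value is included in the output
    for _ in range(n):
        value = f(value)
        seen.append(value)
    return seen


def inc(x):
    return x + 1


result = iterate_reference(inc, 100, 0.0)
print(len(result), result[0], result[-1])  # 101 0.0 100.0
```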

Source code in src/genjax/_src/core/generative/generative_function.py
def iterate(
    self,
    /,
    *,
    n: int,
) -> "GenerativeFunction[R]":
    """
    When called on a [`genjax.GenerativeFunction`][] of type `a -> a`, returns a new [`genjax.GenerativeFunction`][] of type `a -> [a]` where

    - `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `[a]` is an array of all `a`, `f(a)`, `f(f(a))` etc. values seen during iteration.

    All traced values are nested under an index.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

    ```python
    def iterate(f, n, init):
        input = init
        seen = [init]
        for _ in range(n):
            input = f(input)
            seen.append(input)
        return seen
    ```

    `init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

    The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Args:
        n: the number of iterations to run.

    Examples:
        iterative addition, returning all intermediate sums:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax


        @genjax.iterate(n=100)
        @genjax.gen
        def inc(x):
            return x + 1


        init = 0.0
        key = jax.random.key(314159)

        tr = jax.jit(inc.simulate)(key, (init,))
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.iterate(n=n)(self)

iterate_final

iterate_final(*, n: int) -> GenerativeFunction[R]

When called on a genjax.GenerativeFunction of type a -> a, returns a new genjax.GenerativeFunction of type a -> a where

  • a is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • the original function is invoked n times with each input coming from the previous invocation's output, so that the new function returns \(f^n(a)\)

The values traced by each iteration are nested under that iteration's integer index.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:

def iterate_final(f, n, init):
    ret = init
    for _ in range(n):
        ret = f(ret)
    return ret

init may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

The iterated value a must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters:

  • n (int, required): the number of iterations to run.

Examples:

iterative addition:

import jax
import genjax


@genjax.iterate_final(n=100)
@genjax.gen
def inc(x):
    return x + 1


init = 0.0
key = jax.random.key(314159)

tr = jax.jit(inc.simulate)(key, (init,))
print(tr.render_html())
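Under the reference semantics above, iterate_final should agree with the last element of iterate's output, so the example should return \(f^{100}(0.0) = 100.0\). A plain-Python check of that relationship (no GenJAX; both function names are hypothetical):

```python
def iterate_final_reference(f, n, init):
    """Plain-Python restatement of the iterate_final semantics above."""
    ret = init
    for _ in range(n):
        ret = f(ret)
    return ret


def iterate_reference(f, n, init):
    """Plain-Python restatement of the iterate semantics (includes init)."""
    seen = [init]
    for _ in range(n):
        seen.append(f(seen[-1]))
    return seen


def inc(x):
    return x + 1


final = iterate_final_reference(inc, 100, 0.0)
print(final)  # 100.0
```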

Source code in src/genjax/_src/core/generative/generative_function.py
def iterate_final(
    self,
    /,
    *,
    n: int,
) -> "GenerativeFunction[R]":
    """
    When called on a [`genjax.GenerativeFunction`][] of type `a -> a`, returns a new [`genjax.GenerativeFunction`][] of type `a -> a` where

    - `a` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - the original function is invoked `n` times with each input coming from the previous invocation's output, so that the new function returns $f^n(a)$

    All traced values are nested under an index.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

    ```python
    def iterate_final(f, n, init):
        ret = init
        for _ in range(n):
            ret = f(ret)
        return ret
    ```

    `init` may be an arbitrary pytree value, and so multiple arrays can be iterated over at once and produce multiple output arrays.

    The iterated value `a` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Args:
        n: the number of iterations to run.

    Examples:
        iterative addition:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax


        @genjax.iterate_final(n=100)
        @genjax.gen
        def inc(x):
            return x + 1


        init = 0.0
        key = jax.random.key(314159)

        tr = jax.jit(inc.simulate)(key, (init,))
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.iterate_final(n=n)(self)

map

map(f: Callable[[R], S]) -> GenerativeFunction[S]

Specialized version of genjax.dimap where only the post-processing function is applied.

Parameters:

  • f (Callable[[R], S], required): A callable that postprocesses the return value of the wrapped function.

Returns:

  • GenerativeFunction[S]: A genjax.GenerativeFunction that acts like self, with a post-processing function applied to its return value.

Examples:

import jax, genjax


# Define a post-processing function
def square(x):
    return x**2


@genjax.gen
def model(x):
    return genjax.normal(x, 1.0) @ "z"


map_model = model.map(square)

# Use the map model
key = jax.random.key(0)
trace = map_model.simulate(key, (2.0,))

print(trace.render_html())
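Because map only post-processes the return value, the traced choices and the score should be the same as those of self. The plain-Python sketch below illustrates that reading with a toy trace record. `ToyTrace`, `toy_model`, and `map_toy` are hypothetical stand-ins, not GenJAX APIs.

```python
import math
import random
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyTrace:
    choices: dict
    score: float  # log density of the choices
    retval: Any


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def toy_model(rng: random.Random, x: float) -> ToyTrace:
    """Toy analogue of `model` above: z ~ normal(x, 1.0), returns z."""
    z = rng.gauss(x, 1.0)
    return ToyTrace({"z": z}, normal_logpdf(z, x, 1.0), z)


def map_toy(gen_fn: Callable[..., ToyTrace], f: Callable[[Any], Any]):
    """Post-process only the return value; choices and score pass through."""
    def mapped(rng: random.Random, *args) -> ToyTrace:
        tr = gen_fn(rng, *args)
        return ToyTrace(tr.choices, tr.score, f(tr.retval))
    return mapped


square_model = map_toy(toy_model, lambda v: v**2)
tr = square_model(random.Random(0), 2.0)
print(tr.choices["z"], tr.retval)
```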
Source code in src/genjax/_src/core/generative/generative_function.py
def map(self, f: Callable[[R], S]) -> "GenerativeFunction[S]":
    """
    Specialized version of [`genjax.dimap`][] where only the post-processing function is applied.

    Args:
        f: A callable that postprocesses the return value of the wrapped function.

    Returns:
        A [`genjax.GenerativeFunction`][] that acts like `self`, with a post-processing function applied to its return value.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="map"
        import jax, genjax


        # Define a post-processing function
        def square(x):
            return x**2


        @genjax.gen
        def model(x):
            return genjax.normal(x, 1.0) @ "z"


        map_model = model.map(square)

        # Use the map model
        key = jax.random.key(0)
        trace = map_model.simulate(key, (2.0,))

        print(trace.render_html())
        ```
    """
    import genjax

    return genjax.map(f=f)(self)

mask

mask() -> GenerativeFunction[Mask[R]]

Enables dynamic masking of generative functions. Returns a new genjax.GenerativeFunction like self, but which accepts an additional boolean first argument.

If True, it has the same semantics as invoking self without masking. If False, the invocation of self is masked, and its contribution to the score is ignored.

The return value type is a Mask, with a flag value equal to the supplied boolean.

Returns:

  • GenerativeFunction[Mask[R]]: The masked version of the original genjax.GenerativeFunction.

Examples:

Masking a normal draw:

import genjax, jax


@genjax.gen
def normal_draw(mean):
    return genjax.normal(mean, 1.0) @ "x"


masked_normal_draw = normal_draw.mask()

key = jax.random.key(314159)
tr = jax.jit(masked_normal_draw.simulate)(
    key,
    (
        False,
        2.0,
    ),
)
print(tr.render_html())

Source code in src/genjax/_src/core/generative/generative_function.py
def mask(self, /) -> "GenerativeFunction[genjax.Mask[R]]":
    """
    Enables dynamic masking of generative functions. Returns a new [`genjax.GenerativeFunction`][] like `self`, but which accepts an additional boolean first argument.

    If `True`, it has the same semantics as invoking `self` without masking. If `False`, the invocation of `self` is masked, and its contribution to the score is ignored.

    The return value type is a `Mask`, with a flag value equal to the supplied boolean.

    Returns:
        The masked version of the original [`genjax.GenerativeFunction`][].

    Examples:
        Masking a normal draw:
        ```python exec="yes" html="true" source="material-block" session="mask"
        import genjax, jax


        @genjax.gen
        def normal_draw(mean):
            return genjax.normal(mean, 1.0) @ "x"


        masked_normal_draw = normal_draw.mask()

        key = jax.random.key(314159)
        tr = jax.jit(masked_normal_draw.simulate)(
            key,
            (
                False,
                2.0,
            ),
        )
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.mask(self)

masked_iterate

masked_iterate() -> GenerativeFunction[R]

Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a list of values of type a.

The original function is modified to accept an additional argument mask, a boolean value indicating whether the operation should be masked. The function returns a Mask-wrapped list of the results of the original operation, with the input [mask] as its flags.

All values from the kernel generative function are traced (with an added axis due to the scan), but only the indices whose [mask] entry is True are accounted for in inference, notably in score computations.
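The score accounting described above can be sketched in plain Python: every step is still traced, but a step contributes its log density only when its mask entry is True. The per-step log densities below are made-up values, `masked_iterate_score` is a hypothetical helper, and the sketch makes no claim about how GenJAX arranges the carry or the returned Mask values.

```python
def masked_iterate_score(step_log_densities, mask):
    """Sum per-step log densities, skipping steps whose mask entry is False.
    Masked steps are still traced; they just do not contribute to the score."""
    if len(step_log_densities) != len(mask):
        raise ValueError("need exactly one mask entry per step")
    return sum(lp for lp, active in zip(step_log_densities, mask) if active)


# Mirrors `mask_steps = jnp.arange(10) < 5` in the example that follows:
# only the first five steps count toward the score.
mask_steps = [i < 5 for i in range(10)]
step_log_densities = [-1.0 - 0.1 * i for i in range(10)]  # made-up per-step values
print(masked_iterate_score(step_log_densities, mask_steps))
```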

Example
import jax
import genjax
import jax.numpy as jnp

masks = jnp.array([True, False, True])


# Create a kernel generative function
@genjax.gen
def step(x):
    _ = (
        genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
        @ "rats"
    )
    return x


# Create a model using masked_iterate
model = step.masked_iterate()

# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
Source code in src/genjax/_src/core/generative/generative_function.py
def masked_iterate(self) -> "GenerativeFunction[R]":
    """
    Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a list of values of type `a`.

    The original function is modified to accept an additional argument `mask`, which is a boolean value indicating whether the operation should be masked or not. The function returns a Masked list of results of the original operation with the input [mask] as mask.

    All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from [mask] with a flag of True will be accounted for in inference, notably in score computations.

    Example:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax
        import jax.numpy as jnp

        masks = jnp.array([True, False, True])


        # Create a kernel generative function
        @genjax.gen
        def step(x):
            _ = (
                genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
                @ "rats"
            )
            return x


        # Create a model using masked_iterate
        model = step.masked_iterate()

        # Simulate from the model
        key = jax.random.key(0)
        mask_steps = jnp.arange(10) < 5
        tr = model.simulate(key, (0.0, mask_steps))
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.masked_iterate()(self)

masked_iterate_final

masked_iterate_final() -> GenerativeFunction[R]

Transforms a generative function that takes a single argument of type a and returns a value of type a, into a function that takes a tuple of arguments (a, [mask]) and returns a value of type a.

The original function is modified to accept an additional argument mask, which is a boolean value indicating whether the operation should be masked or not. The function returns the result of the original operation if mask is True, and the original input if mask is False.

All values from the kernel generative function are traced (with an added axis due to the scan), but only the indices whose [mask] entry is True are accounted for in inference, notably in score computations.
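Reading the description above per step (one mask entry per iteration, an interpretation assumed here rather than confirmed by this page), the carry update can be sketched in plain Python: a step whose entry is True applies the kernel, and a step whose entry is False passes its input through unchanged. `masked_iterate_final_reference` is a hypothetical name, and no GenJAX is involved.

```python
def masked_iterate_final_reference(f, init, mask):
    """Apply f at each step whose mask entry is True; otherwise keep the input."""
    value = init
    for active in mask:
        value = f(value) if active else value
    return value


def inc(x):
    return x + 1


mask_steps = [i < 5 for i in range(10)]  # like jnp.arange(10) < 5
print(masked_iterate_final_reference(inc, 0.0, mask_steps))  # 5.0
```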

Example
import jax
import genjax
import jax.numpy as jnp

masks = jnp.array([True, False, True])


# Create a kernel generative function
@genjax.gen
def step(x):
    _ = (
        genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
        @ "rats"
    )
    return x


# Create a model using masked_iterate_final
model = step.masked_iterate_final()

# Simulate from the model
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
tr = model.simulate(key, (0.0, mask_steps))
print(tr.render_html())
Source code in src/genjax/_src/core/generative/generative_function.py
def masked_iterate_final(self) -> "GenerativeFunction[R]":
    """
    Transforms a generative function that takes a single argument of type `a` and returns a value of type `a`, into a function that takes a tuple of arguments `(a, [mask])` and returns a value of type `a`.

    The original function is modified to accept an additional argument `mask`, which is a boolean value indicating whether the operation should be masked or not. The function returns the result of the original operation if `mask` is `True`, and the original input if `mask` is `False`.

    All traced values from the kernel generative function are traced (with an added axis due to the scan), but only those indices from [mask] with a flag of True will be accounted for in inference, notably in score computations.

    Example:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax
        import jax.numpy as jnp

        masks = jnp.array([True, False, True])


        # Create a kernel generative function
        @genjax.gen
        def step(x):
            _ = (
                genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
                @ "rats"
            )
            return x


        # Create a model using masked_iterate_final
        model = step.masked_iterate_final()

        # Simulate from the model
        key = jax.random.key(0)
        mask_steps = jnp.arange(10) < 5
        tr = model.simulate(key, (0.0, mask_steps))
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.masked_iterate_final()(self)

mix

Takes any number of genjax.GenerativeFunctions and returns a new genjax.GenerativeFunction that represents a mixture model.

The returned generative function takes the following arguments:

  • mixture_logits: Logits for the categorical distribution used to select a component.
  • *args: Argument tuples for self and each of the input generative functions

and samples from self or one of the input generative functions based on a draw from a categorical distribution defined by the provided mixture logits.

Parameters:

  • *fns (GenerativeFunction[R], default ()): Variable number of genjax.GenerativeFunctions to be mixed with self.

Returns:

  • GenerativeFunction[R]: A new genjax.GenerativeFunction representing the mixture model.

Examples:

import jax
import genjax


# Define component generative functions
@genjax.gen
def component1(x):
    return genjax.normal(x, 1.0) @ "y"


@genjax.gen
def component2(x):
    return genjax.normal(x, 2.0) @ "y"


# Create mixture model
mixture = component1.mix(component2)

# Use the mixture model
key = jax.random.key(0)
logits = jax.numpy.array([0.3, 0.7])  # Favors component2
trace = mixture.simulate(key, (logits, (0.0,), (7.0,)))
print(trace.render_html())
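Because mixture_logits are logits rather than probabilities, the component probabilities are their softmax. For [0.3, 0.7] that is roughly 0.40 and 0.60, so the example favors component2, but less strongly than the raw numbers might suggest. The plain-Python sketch below (no GenJAX; all names are hypothetical toy stand-ins) shows the selection rule described above: draw a component index from the categorical distribution defined by the logits, then run that component on its own argument tuple.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_mixture(rng, logits, components, component_args):
    """Draw a component index from categorical(softmax(logits)), then run
    that component on its own argument tuple. Returns (index, value)."""
    probs = softmax(logits)
    idx = rng.choices(range(len(components)), weights=probs, k=1)[0]
    return idx, components[idx](rng, *component_args[idx])


# Toy stand-ins for component1 / component2 from the example above.
def component1(rng, x):
    return rng.gauss(x, 1.0)


def component2(rng, x):
    return rng.gauss(x, 2.0)


probs = softmax([0.3, 0.7])
print([round(p, 3) for p in probs])  # [0.401, 0.599]

# Empirical fraction of draws that come from component2 (index 1).
rng = random.Random(0)
draws = [
    sample_mixture(rng, [0.3, 0.7], [component1, component2], [(0.0,), (7.0,)])[0]
    for _ in range(20_000)
]
frac = sum(draws) / len(draws)
print(frac)
```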

Source code in src/genjax/_src/core/generative/generative_function.py
def mix(self, *fns: "GenerativeFunction[R]") -> "GenerativeFunction[R]":
    """
    Takes any number of [`genjax.GenerativeFunction`][]s and returns a new [`genjax.GenerativeFunction`][] that represents a mixture model.

    The returned generative function takes the following arguments:

    - `mixture_logits`: Logits for the categorical distribution used to select a component.
    - `*args`: Argument tuples for `self` and each of the input generative functions

    and samples from `self` or one of the input generative functions based on a draw from a categorical distribution defined by the provided mixture logits.

    Args:
        *fns: Variable number of [`genjax.GenerativeFunction`][]s to be mixed with `self`.

    Returns:
        A new [`genjax.GenerativeFunction`][] representing the mixture model.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="mix"
        import jax
        import genjax


        # Define component generative functions
        @genjax.gen
        def component1(x):
            return genjax.normal(x, 1.0) @ "y"


        @genjax.gen
        def component2(x):
            return genjax.normal(x, 2.0) @ "y"


        # Create mixture model
        mixture = component1.mix(component2)

        # Use the mixture model
        key = jax.random.key(0)
        logits = jax.numpy.array([0.3, 0.7])  # Favors component2
        trace = mixture.simulate(key, (logits, (0.0,), (7.0,)))
        print(trace.render_html())
        ```
    """
    import genjax

    return genjax.mix(self, *fns)

or_else

Returns a GenerativeFunction that accepts

  • a boolean argument
  • an argument tuple for self
  • an argument tuple for the supplied gen_fn

and acts like self when the boolean is True or like gen_fn otherwise.

Parameters:

  • gen_fn (GenerativeFunction[R], required): called when the boolean argument is False.

Examples:

import jax
import jax.numpy as jnp
import genjax


@genjax.gen
def if_model(x):
    return genjax.normal(x, 1.0) @ "if_value"


@genjax.gen
def else_model(x):
    return genjax.normal(x, 5.0) @ "else_value"


@genjax.gen
def model(toss: bool):
    # Note that the returned model takes a new boolean predicate in
    # addition to argument tuples for each branch.
    return if_model.or_else(else_model)(toss, (1.0,), (10.0,)) @ "tossed"


key = jax.random.key(314159)

tr = jax.jit(model.simulate)(key, (True,))

print(tr.render_html())
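The branch selection that or_else describes can be restated in plain Python: the boolean picks which branch runs, and each branch receives only its own argument tuple. This sketch uses hypothetical toy branches in place of if_model and else_model, and it says nothing about how GenJAX traces or compiles the branch that is not taken.

```python
import random


def or_else_reference(if_branch, else_branch):
    """Combine two branches into one function of (flag, if_args, else_args)."""
    def combined(rng, flag, if_args, else_args):
        # In this sketch only the selected branch runs, with its own arguments.
        if flag:
            return if_branch(rng, *if_args)
        return else_branch(rng, *else_args)
    return combined


def if_model(rng, x):
    return ("if_value", rng.gauss(x, 1.0))


def else_model(rng, x):
    return ("else_value", rng.gauss(x, 5.0))


tossed = or_else_reference(if_model, else_model)
print(tossed(random.Random(0), True, (1.0,), (10.0,))[0])  # if_value
print(tossed(random.Random(0), False, (1.0,), (10.0,))[0])  # else_value
```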
Source code in src/genjax/_src/core/generative/generative_function.py
def or_else(self, gen_fn: "GenerativeFunction[R]", /) -> "GenerativeFunction[R]":
    """
    Returns a [`GenerativeFunction`][genjax.GenerativeFunction] that accepts

    - a boolean argument
    - an argument tuple for `self`
    - an argument tuple for the supplied `gen_fn`

    and acts like `self` when the boolean is `True` or like `gen_fn` otherwise.

    Args:
        gen_fn: called when the boolean argument is `False`.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="gen-fn"
        import jax
        import jax.numpy as jnp
        import genjax


        @genjax.gen
        def if_model(x):
            return genjax.normal(x, 1.0) @ "if_value"


        @genjax.gen
        def else_model(x):
            return genjax.normal(x, 5.0) @ "else_value"


        @genjax.gen
        def model(toss: bool):
            # Note that the returned model takes a new boolean predicate in
            # addition to argument tuples for each branch.
            return if_model.or_else(else_model)(toss, (1.0,), (10.0,)) @ "tossed"


        key = jax.random.key(314159)

        tr = jax.jit(model.simulate)(key, (True,))

        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.or_else(self, gen_fn)

propose

propose(
    key: PRNGKey, args: Arguments
) -> tuple[ChoiceMap, Score, R]

Samples a ChoiceMap \(t\) and any untraced randomness \(r\) from the generative function's distribution over samples (\(P\)), and returns a triple: the sampled ChoiceMap, the Score of that sample under the distribution, and the return value (of type R) computed by the return value function \(f(t, r, a)\) for the sample and untraced randomness.

Source code in src/genjax/_src/core/generative/generative_function.py
def propose(
    self,
    key: PRNGKey,
    args: Arguments,
) -> tuple[ChoiceMap, Score, R]:
    """
    Samples a [`ChoiceMap`][genjax.core.ChoiceMap] $t$ and any untraced randomness $r$ from the generative function's distribution over samples ($P$), and returns a triple: the sampled choice map, the [`Score`][genjax.core.Score] of that sample under the distribution, and the return value (of type `R`) computed by the return value function $f(t, r, a)$ for the sample and untraced randomness.
    """
    tr = self.simulate(key, args)
    sample = tr.get_choices()
    score = tr.get_score()
    retval = tr.get_retval()
    return sample, score, retval

reduce

reduce() -> GenerativeFunction[R]

When called on a genjax.GenerativeFunction of type (c, a) -> c, returns a new genjax.GenerativeFunction of type (c, [a]) -> c where

  • c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • a may be a primitive, an array type or a pytree (container) type with array leaves

The values traced by each iteration are nested under that iteration's integer index.

For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

The semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation (note the similarity to functools.reduce):

def reduce(f, init, xs):
    carry = init
    for x in xs:
        carry = f(carry, x)
    return carry

Unlike that Python version, both xs and carry may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Examples:

sum an array of numbers:

import jax
import genjax
import jax.numpy as jnp


@genjax.reduce()
@genjax.gen
def add(sum, x):
    new_sum = sum + x
    return new_sum


init = 0.0
key = jax.random.key(314159)
xs = jnp.ones(10)

tr = jax.jit(add.simulate)(key, (init, xs))
print(tr.render_html())
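As noted above, the reference semantics match functools.reduce. A quick plain-Python check (no GenJAX; `reduce_reference` is a hypothetical name, and a list stands in for `jnp.ones(10)`) confirms that for the example's inputs, init = 0.0 and ten ones, the final carry should be 10.0.

```python
import functools


def reduce_reference(f, init, xs):
    """Plain-Python restatement of the reduce semantics above."""
    carry = init
    for x in xs:
        carry = f(carry, x)
    return carry


def add(total, x):
    return total + x


xs = [1.0] * 10
print(reduce_reference(add, 0.0, xs), functools.reduce(add, xs, 0.0))  # 10.0 10.0
```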

Source code in src/genjax/_src/core/generative/generative_function.py
def reduce(self) -> "GenerativeFunction[R]":
    """
    When called on a [`genjax.GenerativeFunction`][] of type `(c, a) -> c`, returns a new [`genjax.GenerativeFunction`][] of type `(c, [a]) -> c` where

    - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `a` may be a primitive, an array type or a pytree (container) type with array leaves

    All traced values are nested under an index.

    For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

    The semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation (note the similarity to [`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce)):

    ```python
    def reduce(f, init, xs):
        carry = init
        for x in xs:
            carry = f(carry, x)
        return carry
    ```

    Unlike that Python version, both `xs` and `carry` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays.

    The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Examples:
        sum an array of numbers:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax
        import jax.numpy as jnp


        @genjax.reduce()
        @genjax.gen
        def add(sum, x):
            new_sum = sum + x
            return new_sum


        init = 0.0
        key = jax.random.key(314159)
        xs = jnp.ones(10)

        tr = jax.jit(add.simulate)(key, (init, xs))
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.reduce()(self)

repeat

repeat(*, n: int) -> GenerativeFunction[R]

Returns a genjax.GenerativeFunction that samples from self n times, returning a vector of n results.

The values traced by each call to self will be nested under an integer index that matches the loop iteration index that generated it.

This combinator is useful for creating multiple samples from self in a batched manner.

Parameters:

  • n (int, required): The number of times to sample from the generative function.

Returns:

  • GenerativeFunction[R]: A new genjax.GenerativeFunction that samples from the original function n times.

Examples:

import genjax, jax


@genjax.gen
def normal_draw(mean):
    return genjax.normal(mean, 1.0) @ "x"


normal_draws = normal_draw.repeat(n=10)

key = jax.random.key(314159)

# Generate 10 draws from a normal distribution with mean 2.0
tr = jax.jit(normal_draws.simulate)(key, (2.0,))
print(tr.render_html())
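The indexing described above can be sketched in plain Python: n independent calls on the same arguments, with each call's choices stored under its integer index and the total score summing the per-call scores. The toy `normal_draw` and `repeat_reference` are hypothetical stand-ins, not GenJAX APIs.

```python
import math
import random


def normal_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def normal_draw(rng, mean):
    """Toy analogue of `normal_draw`: returns (choices, score, retval)."""
    x = rng.gauss(mean, 1.0)
    return {"x": x}, normal_logpdf(x, mean, 1.0), x


def repeat_reference(gen_fn, n):
    """Run gen_fn n times on the same arguments, indexing each call's choices."""
    def repeated(rng, *args):
        choices, score, retvals = {}, 0.0, []
        for i in range(n):
            c, s, r = gen_fn(rng, *args)
            choices[i] = c  # nested under the iteration index
            score += s  # independent calls: scores add
            retvals.append(r)
        return choices, score, retvals
    return repeated


choices, score, retvals = repeat_reference(normal_draw, 10)(random.Random(314159), 2.0)
print(sorted(choices), len(retvals))
```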
Source code in src/genjax/_src/core/generative/generative_function.py
def repeat(self, /, *, n: int) -> "GenerativeFunction[R]":
    """
    Returns a [`genjax.GenerativeFunction`][] that samples from `self` `n` times, returning a vector of `n` results.

    The values traced by each call to `self` will be nested under an integer index that matches the loop iteration index that generated it.

    This combinator is useful for creating multiple samples from `self` in a batched manner.

    Args:
        n: The number of times to sample from the generative function.

    Returns:
        A new [`genjax.GenerativeFunction`][] that samples from the original function `n` times.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="repeat"
        import genjax, jax


        @genjax.gen
        def normal_draw(mean):
            return genjax.normal(mean, 1.0) @ "x"


        normal_draws = normal_draw.repeat(n=10)

        key = jax.random.key(314159)

        # Generate 10 draws from a normal distribution with mean 2.0
        tr = jax.jit(normal_draws.simulate)(key, (2.0,))
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.repeat(n=n)(self)

scan

scan(
    *, n: int | None = None
) -> GenerativeFunction[tuple[Carry, Y]]

When called on a genjax.GenerativeFunction of type (c, a) -> (c, b), returns a new genjax.GenerativeFunction of type (c, [a]) -> (c, [b]) where

  • c is a loop-carried value, which must hold a fixed shape and dtype across all iterations
  • a may be a primitive, an array type or a pytree (container) type with array leaves
  • b may be a primitive, an array type or a pytree (container) type with array leaves.

The values traced by each call to the original generative function will be nested under an integer index that matches the loop iteration index that generated it.

For any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

When the type of xs in the snippet below (denoted [a] above) is an array type or None, and the type of ys in the snippet below (denoted [b] above) is an array type, the semantics of the returned genjax.GenerativeFunction are given roughly by this Python implementation:

import numpy as np


def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

Unlike that Python version, both xs and ys may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. None is actually a special case of this, as it represents an empty pytree.

The loop-carried value c must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters:

  • n (int | None, default None): optional integer specifying the number of loop iterations, which (if supplied) must agree with the sizes of the leading axes of the arrays in the returned function's second argument. If supplied, the returned generative function can take None as its second argument.

Returns:

  • GenerativeFunction[tuple[Carry, Y]]: A new genjax.GenerativeFunction that takes an initial loop-carried value and the inputs to scan over (or None when n is supplied), and returns the final loop-carried value together with the per-iteration outputs stacked along a new leading axis.

Examples:

Scan for 1000 iterations with no array input:

import jax
import genjax


@genjax.gen
def random_walk_step(prev, _):
    x = genjax.normal(prev, 1.0) @ "x"
    return x, None


random_walk = random_walk_step.scan(n=1000)

init = 0.5
key = jax.random.key(314159)

tr = jax.jit(random_walk.simulate)(key, (init, None))
print(tr.render_html())

Scan across an input array:

import jax.numpy as jnp


@genjax.gen
def add_and_square_step(sum, x):
    new_sum = sum + x
    return new_sum, sum * sum


# notice no `n` parameter supplied:
add_and_square_all = add_and_square_step.scan()
init = 0.0
xs = jnp.ones(10)

tr = jax.jit(add_and_square_all.simulate)(key, (init, xs))

# The retval has the final carry and an array of all `sum*sum` returned.
print(tr.render_html())
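To see what the deterministic part of that return value contains, here is a sketch that runs the same arithmetic through plain `jax.lax.scan`. The generative function is replaced by an ordinary Python function, and `sum` is renamed to `total` only to avoid shadowing the builtin. Because the GenJAX step above traces no random choices, one would expect the scanned generative function's retval to hold the same pair. That is an assumption; the rendered trace does not show it.

```python
import jax
import jax.numpy as jnp


def add_and_square_step(total, x):
    # Same arithmetic as the generative step above, without traced choices.
    new_total = total + x
    return new_total, total * total  # squares the carry *before* adding x


final_total, squares = jax.lax.scan(
    add_and_square_step, jnp.float32(0.0), jnp.ones(10)
)
print(final_total)  # final carry after ten additions of 1.0
print(squares)      # one `total * total` per iteration
```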

Source code in src/genjax/_src/core/generative/generative_function.py
def scan(
    self: "GenerativeFunction[tuple[Carry, Y]]",
    /,
    *,
    n: int | None = None,
) -> "GenerativeFunction[tuple[Carry, Y]]":
    """
    When called on a [`genjax.GenerativeFunction`][] of type `(c, a) -> (c, b)`, returns a new [`genjax.GenerativeFunction`][] of type `(c, [a]) -> (c, [b])` where

    - `c` is a loop-carried value, which must hold a fixed shape and dtype across all iterations
    - `a` may be a primitive, an array type or a pytree (container) type with array leaves
    - `b` may be a primitive, an array type or a pytree (container) type with array leaves.

    The values traced by each call to the original generative function will be nested under an integer index that matches the loop iteration index that generated it.

    For any array type specifier `t`, `[t]` represents the type with an additional leading axis, and if `t` is a pytree (container) type with array leaves then `[t]` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

    When the type of `xs` in the snippet below (denoted `[a]` above) is an array type or None, and the type of `ys` in the snippet below (denoted `[b]` above) is an array type, the semantics of the returned [`genjax.GenerativeFunction`][] are given roughly by this Python implementation:

    ```python
    def scan(f, init, xs, length=None):
        if xs is None:
            xs = [None] * length
        carry = init
        ys = []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)
    ```

    Unlike that Python version, both `xs` and `ys` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. `None` is actually a special case of this, as it represents an empty pytree.

    The loop-carried value `c` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `c` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

    Args:
        n: optional integer specifying the number of loop iterations, which (if supplied) must agree with the sizes of leading axes of the arrays in the returned function's second argument. If supplied then the returned generative function can take `None` as its second argument.

    Returns:
        A new [`genjax.GenerativeFunction`][] that takes a loop-carried value and a new input, and returns a new loop-carried value along with either `None` or an output to be collected into the second return value.

    Examples:
        Scan for 1000 iterations with no array input:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax
        import genjax


        @genjax.gen
        def random_walk_step(prev, _):
            x = genjax.normal(prev, 1.0) @ "x"
            return x, None


        random_walk = random_walk_step.scan(n=1000)

        init = 0.5
        key = jax.random.key(314159)

        tr = jax.jit(random_walk.simulate)(key, (init, None))
        print(tr.render_html())
        ```

        Scan across an input array:
        ```python exec="yes" html="true" source="material-block" session="scan"
        import jax.numpy as jnp


        @genjax.gen
        def add_and_square_step(sum, x):
            new_sum = sum + x
            return new_sum, sum * sum


        # notice no `n` parameter supplied:
        add_and_square_all = add_and_square_step.scan()
        init = 0.0
        xs = jnp.ones(10)

        tr = jax.jit(add_and_square_all.simulate)(key, (init, xs))

        # The retval has the final carry and an array of all `sum*sum` returned.
        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.scan(n=n)(self)

simulate abstractmethod

simulate(key: PRNGKey, args: Arguments) -> Trace[R]

Execute the generative function, sampling from its distribution over samples, and return a Trace.

More on traces

The Trace returned by simulate implements its own interface.

It is responsible for storing the arguments of the invocation (genjax.Trace.get_args), the return value of the generative function (genjax.Trace.get_retval), the identity of the generative function which produced the trace (genjax.Trace.get_gen_fn), the sample of traced random choices produced during the invocation (genjax.Trace.get_choices) and the score of the sample (genjax.Trace.get_score).
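The following is a hypothetical plain-Python record that makes this interface concrete without depending on GenJAX internals. It holds the five pieces of data listed above for the one-choice model used in the examples below. `ToyTrace` and `toy_simulate` are illustrative names, not GenJAX API. Real GenJAX traces are JAX pytrees produced by `simulate`, and this sketch only mirrors what their accessors return.

```python
import math
import random
from dataclasses import dataclass
from typing import Any


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    # Log density of a normal distribution with mean `loc` and std `scale`.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


@dataclass(frozen=True)
class ToyTrace:
    """Hypothetical record of one execution; NOT genjax.Trace."""

    gen_fn: str      # placeholder for the generative function's identity
    args: tuple      # arguments of the invocation
    retval: Any      # return value
    choices: dict    # traced random choices, keyed by address
    score: float     # log P(t; a) of the traced choices

    def get_gen_fn(self) -> str:
        return self.gen_fn

    def get_args(self) -> tuple:
        return self.args

    def get_retval(self) -> Any:
        return self.retval

    def get_choices(self) -> dict:
        return self.choices

    def get_score(self) -> float:
        return self.score


def toy_simulate(seed: int) -> ToyTrace:
    """Hand-written 'simulate' for: x ~ normal(0, 1); return x."""
    x = random.Random(seed).gauss(0.0, 1.0)
    return ToyTrace(
        gen_fn="model",
        args=(),
        retval=x,
        choices={"x": x},
        score=normal_logpdf(x, 0.0, 1.0),
    )


tr = toy_simulate(0)
print(tr.get_choices(), tr.get_retval(), tr.get_score())
```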

Examples:

import genjax
import jax
from jax import vmap, jit
from jax.random import split


@genjax.gen
def model():
    x = genjax.normal(0.0, 1.0) @ "x"
    return x


key = jax.random.key(0)
tr = model.simulate(key, ())
print(tr.render_html())

Another example, using the same model composed into genjax.repeat, which creates a new generative function with the same interface:

@genjax.gen
def model():
    x = genjax.normal(0.0, 1.0) @ "x"
    return x


key = jax.random.key(0)
tr = model.repeat(n=10).simulate(key, ())
print(tr.render_html())

(Fun, flirty, fast ... parallel?) Feel free to use jax.jit and jax.vmap!

key = jax.random.key(0)
sub_keys = split(key, 10)
sim = model.repeat(n=10).simulate
tr = jit(vmap(sim, in_axes=(0, None)))(sub_keys, ())
print(tr.render_html())

Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def simulate(
    self,
    key: PRNGKey,
    args: Arguments,
) -> Trace[R]:
    """
    Execute the generative function, sampling from its distribution over samples, and return a [`Trace`][genjax.core.Trace].

    ## More on traces

    The [`Trace`][genjax.core.Trace] returned by `simulate` implements its own interface.

    It is responsible for storing the arguments of the invocation ([`genjax.Trace.get_args`][]), the return value of the generative function ([`genjax.Trace.get_retval`][]), the identity of the generative function which produced the trace ([`genjax.Trace.get_gen_fn`][]), the sample of traced random choices produced during the invocation ([`genjax.Trace.get_choices`][]) and _the score_ of the sample ([`genjax.Trace.get_score`][]).

    Examples:
        ```python exec="yes" html="true" source="material-block" session="core"
        import genjax
        import jax
        from jax import vmap, jit
        from jax.random import split


        @genjax.gen
        def model():
            x = genjax.normal(0.0, 1.0) @ "x"
            return x


        key = jax.random.key(0)
        tr = model.simulate(key, ())
        print(tr.render_html())
        ```

        Another example, using the same model, composed into [`genjax.repeat`](combinators.md#genjax.repeat) - which creates a new generative function, which has the same interface:
        ```python exec="yes" html="true" source="material-block" session="core"
        @genjax.gen
        def model():
            x = genjax.normal(0.0, 1.0) @ "x"
            return x


        key = jax.random.key(0)
        tr = model.repeat(n=10).simulate(key, ())
        print(tr.render_html())
        ```

        (**Fun, flirty, fast ... parallel?**) Feel free to use `jax.jit` and `jax.vmap`!
        ```python exec="yes" html="true" source="material-block" session="core"
        key = jax.random.key(0)
        sub_keys = split(key, 10)
        sim = model.repeat(n=10).simulate
        tr = jit(vmap(sim, in_axes=(0, None)))(sub_keys, ())
        print(tr.render_html())
        ```
    """

switch

switch(*branches: GenerativeFunction[R]) -> Switch[R]

Given n genjax.GenerativeFunction inputs, returns a new genjax.GenerativeFunction that accepts n+2 arguments:

  • an index in the range \([0, n+1)\)
  • a tuple of arguments for self and each of the input generative functions (n+1 total tuples)

and executes the generative function at the supplied index with its provided arguments.

If index is out of bounds, it is clamped into the valid range \([0, n]\).
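The calling convention (one index plus one argument tuple per function, with `self` first) and the clamping rule can be modeled in plain Python. The sketch below is hypothetical and only models which branch runs and with which arguments. The real combinator works under `jax.jit` with a traced index, which this plain-Python version does not attempt.

```python
from typing import Any, Callable, Sequence


def switch_reference(
    fns: Sequence[Callable[..., Any]], index: int, *arg_tuples: tuple
) -> Any:
    """Hypothetical model of the switch calling convention.

    `fns` holds self followed by the n extra branches (n + 1 functions), so the
    switched function takes n + 2 arguments: the index plus one tuple per branch.
    """
    if len(arg_tuples) != len(fns):
        raise ValueError("expected one argument tuple per branch")
    # Out-of-range indices are clamped into [0, n].
    clamped = min(max(index, 0), len(fns) - 1)
    return fns[clamped](*arg_tuples[clamped])


branches = [lambda: "branch_1", lambda p: f"branch_2({p})"]
print(switch_reference(branches, 1, (), (0.3,)))   # selects the second branch
print(switch_reference(branches, 7, (), (0.3,)))   # clamped to the last branch
print(switch_reference(branches, -2, (), (0.3,)))  # clamped to the first branch
```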

Examples:

import jax, genjax


@genjax.gen
def branch_1():
    x = genjax.normal(0.0, 1.0) @ "x1"


@genjax.gen
def branch_2():
    x = genjax.bernoulli(0.3) @ "x2"


switch = branch_1.switch(branch_2)

key = jax.random.key(314159)
jitted = jax.jit(switch.simulate)

# Select `branch_2` by providing 1:
tr = jitted(key, (1, (), ()))

print(tr.render_html())
Source code in src/genjax/_src/core/generative/generative_function.py
def switch(self, *branches: "GenerativeFunction[R]") -> "genjax.Switch[R]":
    """
    Given `n` [`genjax.GenerativeFunction`][] inputs, returns a new [`genjax.GenerativeFunction`][] that accepts `n+2` arguments:

    - an index in the range $[0, n+1)$
    - a tuple of arguments for `self` and each of the input generative functions (`n+1` total tuples)

    and executes the generative function at the supplied index with its provided arguments.

    If `index` is out of bounds, `index` is clamped to within bounds.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="switch"
        import jax, genjax


        @genjax.gen
        def branch_1():
            x = genjax.normal(0.0, 1.0) @ "x1"


        @genjax.gen
        def branch_2():
            x = genjax.bernoulli(0.3) @ "x2"


        switch = branch_1.switch(branch_2)

        key = jax.random.key(314159)
        jitted = jax.jit(switch.simulate)

        # Select `branch_2` by providing 1:
        tr = jitted(key, (1, (), ()))

        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.switch(self, *branches)

vmap

vmap(*, in_axes: InAxes = 0) -> GenerativeFunction[R]

Returns a GenerativeFunction that performs a vectorized map over the argument specified by in_axes. Traced values are nested under an index, and the retval is vectorized.

Parameters:

  • in_axes (InAxes, default 0): Selector specifying which input arguments (or index into them) should be vectorized. Defaults to 0, i.e., the first argument. See the JAX documentation on applying optional parameters to pytrees (https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees) for more detail.

Returns:

  • GenerativeFunction[R]: A new GenerativeFunction that accepts an argument with one additional dimension at the position specified by in_axes.
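The `in_axes` selector is easiest to picture with plain `jax.vmap`, whose pytree conventions are described on the linked JAX page. The sketch below uses only JAX, not GenJAX. It assumes that `GenerativeFunction.vmap` interprets `in_axes` the same way for its argument tuple, which is not shown here. `shift` is an illustrative helper.

```python
import jax
import jax.numpy as jnp


def shift(x, offset):
    # An ordinary (non-generative) function of two arguments.
    return x + offset


# Single-argument case, like `model.vmap(in_axes=0)` below: map over axis 0.
plus_one = jax.vmap(lambda x: x + 1.0, in_axes=0)(jnp.arange(3.0))

# Two arguments: batch only the first one (axis 0) and broadcast `offset`.
shifted = jax.vmap(shift, in_axes=(0, None))(jnp.arange(3.0), 100.0)

print(plus_one)  # elementwise 0+1, 1+1, 2+1
print(shifted)   # elementwise 0+100, 1+100, 2+100
```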

Examples:

import jax
import jax.numpy as jnp
import genjax


@genjax.gen
def model(x):
    v = genjax.normal(x, 1.0) @ "v"
    return genjax.normal(v, 0.01) @ "q"


vmapped = model.vmap(in_axes=0)

key = jax.random.key(314159)
arr = jnp.ones(100)

# `vmapped` accepts an array of numbers instead of the original
# single number that `model` accepted.
tr = jax.jit(vmapped.simulate)(key, (arr,))

print(tr.render_html())
Source code in src/genjax/_src/core/generative/generative_function.py
def vmap(self, /, *, in_axes: InAxes = 0) -> "GenerativeFunction[R]":
    """
    Returns a [`GenerativeFunction`][genjax.GenerativeFunction] that performs a vectorized map over the argument specified by `in_axes`. Traced values are nested under an index, and the retval is vectorized.

    Args:
        in_axes: Selector specifying which input arguments (or index into them) should be vectorized. Defaults to 0, i.e., the first argument. See [this link](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees) for more detail.

    Returns:
        A new [`GenerativeFunction`][genjax.GenerativeFunction] that accepts an argument of one-higher dimension at the position specified by `in_axes`.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="gen-fn"
        import jax
        import jax.numpy as jnp
        import genjax


        @genjax.gen
        def model(x):
            v = genjax.normal(x, 1.0) @ "v"
            return genjax.normal(v, 0.01) @ "q"


        vmapped = model.vmap(in_axes=0)

        key = jax.random.key(314159)
        arr = jnp.ones(100)

        # `vmapped` accepts an array of numbers instead of the original
        # single number that `model` accepted.
        tr = jax.jit(vmapped.simulate)(key, (arr,))

        print(tr.render_html())
        ```
    """
    import genjax

    return genjax.vmap(in_axes=in_axes)(self)

Traces are data structures which record (execution and inference) data about the invocation of generative functions. Traces are often specialized to a generative function language, to take advantage of data locality, and other representation optimizations. Traces support a trace interface: a set of accessor methods designed to provide convenient manipulation when handling traces in inference algorithms. We document this interface below for the Trace data type.

genjax.core.Trace

Bases: Generic[R], Pytree

Trace is the type of traces of generative functions.

A trace is a data structure used to represent sampled executions of generative functions. Traces track metadata associated with the probabilities of choices, as well as other data associated with the invocation of a generative function, including the arguments it was invoked with, its return value, and the identity of the generative function itself.

Methods:

  • edit: This method calls out to the underlying GenerativeFunction.edit method - see EditRequest and edit for more information.
  • get_args: Returns the Arguments for the GenerativeFunction invocation which created the Trace.
  • get_choices: Retrieves the random choices made in a trace in the form of a genjax.ChoiceMap.
  • get_gen_fn: Returns the GenerativeFunction whose invocation created the Trace.
  • get_inner_trace: Override this method to provide Trace.get_subtrace support for those trace types that have substructure that can be addressed in this way.
  • get_retval: Returns the R from the GenerativeFunction invocation which created the Trace.
  • get_score: Return the Score of the Trace.
  • get_subtrace: Return the subtrace having the supplied address. Specifying multiple addresses will apply the operation recursively.
  • update: This method calls out to the underlying GenerativeFunction.update method with the supplied constraint - see EditRequest and edit for more information.

Source code in src/genjax/_src/core/generative/generative_function.py
class Trace(Generic[R], Pytree):
    """
    `Trace` is the type of traces of generative functions.

    A trace is a data structure used to represent sampled executions of
    generative functions. Traces track metadata associated with the probabilities
    of choices, as well as other data associated with
    the invocation of a generative function, including the arguments it
    was invoked with, its return value, and the identity of the generative function itself.
    """

    @abstractmethod
    def get_args(self) -> Arguments:
        """Returns the [`Arguments`][genjax.core.Arguments] for the [`GenerativeFunction`][genjax.core.GenerativeFunction] invocation which created the [`Trace`][genjax.core.Trace]."""

    @abstractmethod
    def get_retval(self) -> R:
        """Returns the `R` from the [`GenerativeFunction`][genjax.core.GenerativeFunction] invocation which created the [`Trace`][genjax.core.Trace]."""

    @abstractmethod
    def get_score(self) -> Score:
        """Return the [`Score`][genjax.core.Score] of the `Trace`.

        The score must satisfy a particular mathematical specification: it's either an exact density evaluation of $P$ (the distribution over samples) for the sample returned by [`genjax.Trace.get_choices`][], or _a sample from an estimator_ (a density estimate) if the generative function contains _untraced randomness_.

        Let $s$ be the score, $t$ the sample, and $a$ the arguments: when the generative function contains no _untraced randomness_, the score (in logspace) is given by:

        $$
        \\log s := \\log P(t; a)
        $$

        (**With untraced randomness**) Gen allows for the possibility of sources of randomness _which are not traced_. When these sources are included in generative computations, the score is defined so that the following property holds:

        $$
        \\mathbb{E}_{r\\sim~P(r | t; a)}\\big[\\frac{1}{s}\\big] = \\frac{1}{P(t; a)}
        $$

        This property is the one you'd want to be true if you were using a generative function with untraced randomness _as a proposal_ in a routine which uses importance sampling, for instance.

        In GenJAX, one way you might encounter this is by using pseudo-random routines in your modeling code:
        ```python
        # notice how the key is explicit
        @genjax.gen
        def model_with_untraced_randomness(key: PRNGKey):
            x = genjax.normal(0.0, 1.0) @ "x"
            v = some_random_process(key, x)
            y = genjax.normal(v, 1.0) @ "y"
        ```

        In this case, the score (in logspace) is given by:

        $$
        \\log s := \\log P(r, t; a) - \\log Q(r; a)
        $$

        which satisfies the requirement by virtue of the fact:

        $$
        \\begin{aligned}
        \\mathbb{E}_{r\\sim~P(r | t; a)}\\big[\\frac{1}{s}\\big] &= \\mathbb{E}_{r\\sim P(r | t; a)}\\big[\\frac{Q(r; a)}{P(r, t; a)} \\big] \\\\ &= \\frac{1}{P(t; a)} \\mathbb{E}_{r\\sim P(r | t; a)}\\big[\\frac{Q(r; a)}{P(r | t; a)}\\big] \\\\
        &= \\frac{1}{P(t; a)}
        \\end{aligned}
        $$

        """

    @abstractmethod
    def get_choices(self) -> "genjax.ChoiceMap":
        """Retrieves the random choices made in a trace in the form of a [`genjax.ChoiceMap`][]."""
        pass

    @nobeartype
    @deprecated(reason="Use .get_choices() instead.", version="0.8.1")
    def get_sample(self):
        return self.get_choices()

    @abstractmethod
    def get_gen_fn(self) -> "GenerativeFunction[R]":
        """Returns the [`GenerativeFunction`][genjax.core.GenerativeFunction] whose invocation created the [`Trace`][genjax.core.Trace]."""
        pass

    def edit(
        self,
        key: PRNGKey,
        request: EditRequest,
        argdiffs: tuple[Any, ...] | None = None,
    ) -> tuple[Self, Weight, Retdiff[R], EditRequest]:
        """
        This method calls out to the underlying [`GenerativeFunction.edit`][genjax.core.GenerativeFunction.edit] method - see [`EditRequest`][genjax.core.EditRequest] and [`edit`][genjax.core.GenerativeFunction.edit] for more information.
        """
        return request.edit(
            key,
            self,
            Diff.no_change(self.get_args()) if argdiffs is None else argdiffs,
        )  # pyright: ignore[reportReturnType]

    def update(
        self,
        key: PRNGKey,
        constraint: ChoiceMap,
        argdiffs: tuple[Any, ...] | None = None,
    ) -> tuple[Self, Weight, Retdiff[R], ChoiceMap]:
        """
        This method calls out to the underlying [`GenerativeFunction.edit`][genjax.core.GenerativeFunction.edit] method - see [`EditRequest`][genjax.core.EditRequest] and [`edit`][genjax.core.GenerativeFunction.edit] for more information.
        """
        return self.get_gen_fn().update(
            key,
            self,
            constraint,
            Diff.no_change(self.get_args()) if argdiffs is None else argdiffs,
        )  # pyright: ignore[reportReturnType]

    def project(
        self,
        key: PRNGKey,
        selection: Selection,
    ) -> Weight:
        gen_fn = self.get_gen_fn()
        return gen_fn.project(
            key,
            self,
            selection,
        )

    def get_subtrace(self, *addresses: Address) -> "Trace[Any]":
        """
        Return the subtrace having the supplied address. Specifying multiple addresses
        will apply the operation recursively.

        GenJAX does not guarantee the validity of any inference computations performed
        using information from the returned subtrace. In other words, it is safe to
        inspect the data of subtraces -- but it is not safe to use that data to make decisions
        about inference. This is true of all the methods on the subtrace, including
        `Trace.get_args`, `Trace.get_score`, `Trace.get_retval`, etc. It is safe to look,
        but don't use the data for non-trivial things!"""

        return functools.reduce(
            lambda tr, addr: tr.get_inner_trace(addr), addresses, self
        )

    def get_inner_trace(self, _address: Address) -> "Trace[Any]":
        """Override this method to provide `Trace.get_subtrace` support
        for those trace types that have substructure that can be addressed
        in this way.

        NOTE: `get_inner_trace` takes a full `Address` because, unlike `ChoiceMap`, if a user traces to a tupled address like ("a", "b"), then the resulting `StaticTrace` will store a sub-trace at this address, vs flattening it out.

        As a result, `tr.get_inner_trace(("a", "b"))` does not equal `tr.get_inner_trace("a").get_inner_trace("b")`."""
        raise NotImplementedError(
            "This type of Trace object does not possess subtraces."
        )

    ###################
    # Batch semantics #
    ###################

    @property
    def batch_shape(self):
        return len(self.get_score())

edit

edit(
    key: PRNGKey,
    request: EditRequest,
    argdiffs: tuple[Any, ...] | None = None,
) -> tuple[Self, Weight, Retdiff[R], EditRequest]

This method calls out to the underlying GenerativeFunction.edit method - see EditRequest and edit for more information.

Source code in src/genjax/_src/core/generative/generative_function.py
def edit(
    self,
    key: PRNGKey,
    request: EditRequest,
    argdiffs: tuple[Any, ...] | None = None,
) -> tuple[Self, Weight, Retdiff[R], EditRequest]:
    """
    This method calls out to the underlying [`GenerativeFunction.edit`][genjax.core.GenerativeFunction.edit] method - see [`EditRequest`][genjax.core.EditRequest] and [`edit`][genjax.core.GenerativeFunction.edit] for more information.
    """
    return request.edit(
        key,
        self,
        Diff.no_change(self.get_args()) if argdiffs is None else argdiffs,
    )  # pyright: ignore[reportReturnType]

get_args abstractmethod

get_args() -> Arguments

Returns the Arguments for the GenerativeFunction invocation which created the Trace.

Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def get_args(self) -> Arguments:
    """Returns the [`Arguments`][genjax.core.Arguments] for the [`GenerativeFunction`][genjax.core.GenerativeFunction] invocation which created the [`Trace`][genjax.core.Trace]."""

get_choices abstractmethod

get_choices() -> ChoiceMap

Retrieves the random choices made in a trace in the form of a genjax.ChoiceMap.

Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def get_choices(self) -> "genjax.ChoiceMap":
    """Retrieves the random choices made in a trace in the form of a [`genjax.ChoiceMap`][]."""
    pass

get_gen_fn abstractmethod

get_gen_fn() -> GenerativeFunction[R]

Returns the GenerativeFunction whose invocation created the Trace.

Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def get_gen_fn(self) -> "GenerativeFunction[R]":
    """Returns the [`GenerativeFunction`][genjax.core.GenerativeFunction] whose invocation created the [`Trace`][genjax.core.Trace]."""
    pass

get_inner_trace

get_inner_trace(_address: Address) -> Trace[Any]

Override this method to provide Trace.get_subtrace support for those trace types that have substructure that can be addressed in this way.

NOTE: get_inner_trace takes a full Address because, unlike ChoiceMap, if a user traces to a tupled address like ("a", "b"), the resulting StaticTrace stores a sub-trace at that address rather than flattening it out.

As a result, tr.get_inner_trace(("a", "b")) does not equal tr.get_inner_trace("a").get_inner_trace("b").

Source code in src/genjax/_src/core/generative/generative_function.py
def get_inner_trace(self, _address: Address) -> "Trace[Any]":
    """Override this method to provide `Trace.get_subtrace` support
    for those trace types that have substructure that can be addressed
    in this way.

    NOTE: `get_inner_trace` takes a full `Address` because, unlike `ChoiceMap`, if a user traces to a tupled address like ("a", "b"), then the resulting `StaticTrace` will store a sub-trace at this address, vs flattening it out.

    As a result, `tr.get_inner_trace(("a", "b"))` does not equal `tr.get_inner_trace("a").get_inner_trace("b")`."""
    raise NotImplementedError(
        "This type of Trace object does not possess subtraces."
    )

get_retval abstractmethod

get_retval() -> R

Returns the R from the GenerativeFunction invocation which created the Trace.

Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def get_retval(self) -> R:
    """Returns the `R` from the [`GenerativeFunction`][genjax.core.GenerativeFunction] invocation which created the [`Trace`][genjax.core.Trace]."""

get_score abstractmethod

get_score() -> Score

Return the Score of the Trace.

The score must satisfy a particular mathematical specification: it's either an exact density evaluation of \(P\) (the distribution over samples) for the sample returned by genjax.Trace.get_choices, or a sample from an estimator (a density estimate) if the generative function contains untraced randomness.

Let \(s\) be the score, \(t\) the sample, and \(a\) the arguments: when the generative function contains no untraced randomness, the score (in logspace) is given by:

\[ \log s := \log P(t; a) \]
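For a program with several traced choices and no untraced randomness, \(P(t; a)\) factorizes into the densities of the individual choices in the order they are sampled, so the score is a sum of per-choice log densities. Below is a minimal standard-library sketch for a hypothetical two-choice model (x ~ normal(0, 1), then y ~ normal(x, 1)). It is not GenJAX code.

```python
import math


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    # Log density of a normal distribution with mean `loc` and std `scale`.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def score_of(sample: dict) -> float:
    # log P(t; a) = log P(x) + log P(y | x) for the hypothetical model.
    x, y = sample["x"], sample["y"]
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(y, x, 1.0)


sample = {"x": 0.5, "y": -0.25}
log_s = score_of(sample)
print(log_s)
```

The same number is the log density of the equivalent bivariate normal over (x, y), which the factorized form avoids having to write down.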

(With untraced randomness) Gen allows for the possibility of sources of randomness which are not traced. When these sources are included in generative computations, the score is defined so that the following property holds:

\[ \mathbb{E}_{r\sim~P(r | t; a)}\big[\frac{1}{s}\big] = \frac{1}{P(t; a)} \]

This property is the one you'd want to be true if you were using a generative function with untraced randomness as a proposal in a routine which uses importance sampling, for instance.

In GenJAX, one way you might encounter this is by using pseudo-random routines in your modeling code:

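import jax
import genjax

def some_random_process(key, x):  # hypothetical stand-in for an untraced routine
    return x + jax.random.normal(key)

# notice how the key is explicit
@genjax.gen
def model_with_untraced_randomness(key: jax.Array):
    x = genjax.normal(0.0, 1.0) @ "x"
    v = some_random_process(key, x)
    y = genjax.normal(v, 1.0) @ "y"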

In this case, writing \(r\) for the untraced randomness and \(Q(r; a)\) for a normalized density over \(r\) (zero wherever \(P(r | t; a)\) is zero), the score (in logspace) is given by:

\[ \log s := \log P(r, t; a) - \log Q(r; a) \]

which satisfies the requirement by virtue of the fact:

\[ \begin{aligned} \mathbb{E}_{r\sim~P(r | t; a)}\big[\frac{1}{s}\big] &= \mathbb{E}_{r\sim P(r | t; a)}\big[\frac{Q(r; a)}{P(r, t; a)} \big] \\ &= \frac{1}{P(t; a)} \mathbb{E}_{r\sim P(r | t; a)}\big[\frac{Q(r; a)}{P(r | t; a)}\big] \\ &= \frac{1}{P(t; a)} \int Q(r; a) \, dr \\ &= \frac{1}{P(t; a)} \end{aligned} \]
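The identity can be checked numerically. The sketch below uses a hypothetical conjugate example, not taken from GenJAX. The untraced randomness is r ~ normal(0, 1), the traced sample is t = y ~ normal(r, 1), and Q(r; a) is chosen as normal(0, 0.5), an arbitrary choice that keeps the Monte Carlo variance finite. Because everything is Gaussian, P(r | t; a) = normal(y/2, sqrt(1/2)) and P(t; a) = normal(y; 0, sqrt(2)) are known in closed form. That lets the average of 1/s over draws of r from P(r | t; a) be compared against 1/P(t; a).

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    # Elementwise log density of a normal distribution.
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


rng = np.random.default_rng(0)
y = 1.0        # the traced sample t
n = 200_000    # number of Monte Carlo draws of r

# Draw r from its exact conditional P(r | t; a) for this conjugate model.
r = rng.normal(loc=y / 2, scale=np.sqrt(0.5), size=n)

log_joint = normal_logpdf(r, 0.0, 1.0) + normal_logpdf(y, r, 1.0)  # log P(r, t; a)
log_q = normal_logpdf(r, 0.0, 0.5)                                  # log Q(r; a)
log_s = log_joint - log_q                                           # score per draw

estimate = np.mean(np.exp(-log_s))                     # Monte Carlo E[1/s]
exact = np.exp(-normal_logpdf(y, 0.0, np.sqrt(2.0)))   # 1 / P(t; a)
print(estimate, exact)  # should agree up to Monte Carlo error
```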
Source code in src/genjax/_src/core/generative/generative_function.py
@abstractmethod
def get_score(self) -> Score:
    """Return the [`Score`][genjax.core.Score] of the `Trace`.

    The score must satisfy a particular mathematical specification: it's either an exact density evaluation of $P$ (the distribution over samples) for the sample returned by [`genjax.Trace.get_choices`][], or _a sample from an estimator_ (a density estimate) if the generative function contains _untraced randomness_.

    Let $s$ be the score, $t$ the sample, and $a$ the arguments: when the generative function contains no _untraced randomness_, the score (in logspace) is given by:

    $$
    \\log s := \\log P(t; a)
    $$

    (**With untraced randomness**) Gen allows for the possibility of sources of randomness _which are not traced_. When these sources are included in generative computations, the score is defined so that the following property holds:

    $$
    \\mathbb{E}_{r\\sim~P(r | t; a)}\\big[\\frac{1}{s}\\big] = \\frac{1}{P(t; a)}
    $$

    This property is the one you'd want to be true if you were using a generative function with untraced randomness _as a proposal_ in a routine which uses importance sampling, for instance.

    In GenJAX, one way you might encounter this is by using pseudo-random routines in your modeling code:
    ```python
    # notice how the key is explicit
    @genjax.gen
    def model_with_untraced_randomness(key: PRNGKey):
        x = genjax.normal(0.0, 1.0) @ "x"
        v = some_random_process(key, x)
        y = genjax.normal(v, 1.0) @ "y"
    ```

    In this case, the score (in logspace) is given by:

    $$
    \\log s := \\log P(r, t; a) - \\log Q(r; a)
    $$

    which satisfies the requirement by virtue of the fact:

    $$
    \\begin{aligned}
    \\mathbb{E}_{r\\sim~P(r | t; a)}\\big[\\frac{1}{s}\\big] &= \\mathbb{E}_{r\\sim P(r | t; a)}\\big[\\frac{Q(r; a)}{P(r, t; a)} \\big] \\\\ &= \\frac{1}{P(t; a)} \\mathbb{E}_{r\\sim P(r | t; a)}\\big[\\frac{Q(r; a)}{P(r | t; a)}\\big] \\\\
    &= \\frac{1}{P(t; a)}
    \\end{aligned}
    $$

    """

get_subtrace

get_subtrace(*addresses: Address) -> Trace[Any]

Return the subtrace having the supplied address. Specifying multiple addresses will apply the operation recursively.

GenJAX does not guarantee the validity of any inference computations performed using information from the returned subtrace. In other words, it is safe to inspect the data of subtraces, but it is not safe to use that data to make decisions about inference. This is true of all the methods on the subtrace, including Trace.get_args, Trace.get_score, Trace.get_retval, etc. It is safe to look, but don't use the data for non-trivial things!
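As the `functools.reduce` call in the source shows, the recursion in get_subtrace is a left fold of get_inner_trace over the addresses. The hypothetical toy class below copies that fold. It is not one of GenJAX's trace types. It also illustrates the note on get_inner_trace: a subtrace stored under the tupled address ("a", "b") is found by one lookup of the tuple, not by looking up "a" and then "b".

```python
import functools
from dataclasses import dataclass, field
from typing import Any, Hashable


@dataclass
class ToyTrace:
    """Hypothetical trace with addressable subtraces; NOT a GenJAX class."""

    value: Any = None
    children: dict = field(default_factory=dict)

    def get_inner_trace(self, address: Hashable) -> "ToyTrace":
        # Exactly one lookup per call; a tuple address is one dictionary key.
        return self.children[address]

    def get_subtrace(self, *addresses: Hashable) -> "ToyTrace":
        # Same fold as Trace.get_subtrace: apply get_inner_trace left to right.
        return functools.reduce(
            lambda tr, addr: tr.get_inner_trace(addr), addresses, self
        )


# A subtrace stored at "a", with its own subtrace at "b".
nested = ToyTrace(children={"a": ToyTrace(children={"b": ToyTrace(value=1.0)})})
# A subtrace stored at the single tupled address ("a", "b").
tupled = ToyTrace(children={("a", "b"): ToyTrace(value=2.0)})

print(nested.get_subtrace("a", "b").value)    # recursive: "a", then "b"
print(tupled.get_subtrace(("a", "b")).value)  # one lookup of the tuple key

try:
    tupled.get_subtrace("a", "b")  # nothing is stored at "a" alone
    chained_lookup_found = True
except KeyError:
    chained_lookup_found = False
print(chained_lookup_found)
```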

Source code in src/genjax/_src/core/generative/generative_function.py
def get_subtrace(self, *addresses: Address) -> "Trace[Any]":
    """
    Return the subtrace having the supplied address. Specifying multiple addresses
    will apply the operation recursively.

    GenJAX does not guarantee the validity of any inference computations performed
    using information from the returned subtrace. In other words, it is safe to
    inspect the data of subtraces -- but it is not safe to use that data to make decisions
    about inference. This is true of all the methods on the subtrace, including
    `Trace.get_args`, `Trace.get_score`, `Trace.get_retval`, etc. It is safe to look,
    but don't use the data for non-trivial things!"""

    return functools.reduce(
        lambda tr, addr: tr.get_inner_trace(addr), addresses, self
    )

update

update(
    key: PRNGKey,
    constraint: ChoiceMap,
    argdiffs: tuple[Any, ...] | None = None,
) -> tuple[Self, Weight, Retdiff[R], ChoiceMap]

This method calls out to the underlying GenerativeFunction.update method with the supplied constraint - see EditRequest and edit for more information.

Source code in src/genjax/_src/core/generative/generative_function.py
def update(
    self,
    key: PRNGKey,
    constraint: ChoiceMap,
    argdiffs: tuple[Any, ...] | None = None,
) -> tuple[Self, Weight, Retdiff[R], ChoiceMap]:
    """
    This method calls out to the underlying [`GenerativeFunction.edit`][genjax.core.GenerativeFunction.edit] method - see [`EditRequest`][genjax.core.EditRequest] and [`edit`][genjax.core.GenerativeFunction.edit] for more information.
    """
    return self.get_gen_fn().update(
        key,
        self,
        constraint,
        Diff.no_change(self.get_args()) if argdiffs is None else argdiffs,
    )  # pyright: ignore[reportReturnType]
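The only logic `update` adds on top of the generative function is the default for `argdiffs`. A minimal sketch of that defaulting rule, using a hypothetical `ToyDiff` tuple in place of GenJAX's `Diff` (the real `Diff.no_change` representation may differ; only the "None means unchanged" rule is taken from the source):

```python
from typing import Any, NamedTuple


class ToyDiff(NamedTuple):
    """Hypothetical stand-in for a Diff: a value tagged with a change flag."""
    value: Any
    changed: bool


def no_change(args):
    # Tag every argument as unchanged, playing the role of Diff.no_change(args)
    return tuple(ToyDiff(a, False) for a in args)


def resolve_argdiffs(current_args, argdiffs=None):
    # Same defaulting rule as Trace.update: None means "arguments did not change"
    return no_change(current_args) if argdiffs is None else argdiffs


args = (1.0, 2)
assert resolve_argdiffs(args) == (ToyDiff(1.0, False), ToyDiff(2, False))

explicit = (ToyDiff(5.0, True), ToyDiff(2, False))
assert resolve_argdiffs(args, explicit) is explicit  # caller-supplied diffs pass through
```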

genjax.core.EditRequest

Bases: Pytree

An EditRequest is a request to edit a trace of a generative function. Generative functions respond to instances of subtypes of EditRequest by providing an edit implementation.

Updating a trace is a common operation in inference processes, but naively mutating the trace will invalidate the mathematical invariants that Gen maintains. EditRequest instances denote requests for SMC moves in the framework of SMCP3, which preserve these invariants.

Source code in src/genjax/_src/core/generative/concepts.py
class EditRequest(Pytree):
    """
    An `EditRequest` is a request to edit a trace of a generative function. Generative functions respond to instances of subtypes of `EditRequest` by providing an [`edit`][genjax.core.GenerativeFunction.edit] implementation.

    Updating a trace is a common operation in inference processes, but naively mutating the trace will invalidate the mathematical invariants that Gen retains. `EditRequest` instances denote requests for _SMC moves_ in the framework of [SMCP3](https://proceedings.mlr.press/v206/lew23a.html), which preserve these invariants.
    """

    @abstractmethod
    def edit(
        self,
        key: PRNGKey,
        tr: "genjax.Trace[R]",
        argdiffs: Argdiffs,
    ) -> "tuple[genjax.Trace[R], Weight, Retdiff[R], EditRequest]":
        pass

    def dimap(
        self,
        /,
        *,
        pre: Callable[[Argdiffs], Argdiffs] = lambda v: v,
        post: Callable[[Retdiff[R]], Retdiff[R]] = lambda v: v,
    ) -> "genjax.DiffAnnotate[Self]":
        from genjax import DiffAnnotate

        return DiffAnnotate(self, argdiff_fn=pre, retdiff_fn=post)

    def map(
        self,
        post: Callable[[Retdiff[R]], Retdiff[R]],
    ) -> "genjax.DiffAnnotate[Self]":
        return self.dimap(post=post)

    def contramap(
        self,
        pre: Callable[[Argdiffs], Argdiffs],
    ) -> "genjax.DiffAnnotate[Self]":
        return self.dimap(pre=pre)
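`dimap` wraps a request so that a `pre` function transforms the incoming argdiffs and a `post` function transforms the outgoing retdiff; `map` and `contramap` are the one-sided special cases. The sketch below models that composition with hypothetical `ToyRequest` and `ToyAnnotate` classes. The signatures are simplified (the real `edit` also takes a PRNG key and a trace, and the real wrapper is `DiffAnnotate`), and the "pre, then edit, then post" order is read off the parameter types in the source rather than confirmed from `DiffAnnotate`'s implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyRequest:
    """Hypothetical edit request whose edit maps argdiffs to a retdiff."""

    def edit(self, argdiffs):
        # Toy semantics: the retdiff is the sum of the argument values
        return sum(argdiffs)

    def dimap(self, pre=lambda v: v, post=lambda v: v):
        return ToyAnnotate(self, pre, post)

    def map(self, post):
        return self.dimap(post=post)

    def contramap(self, pre):
        return self.dimap(pre=pre)


@dataclass
class ToyAnnotate:
    """Hypothetical analogue of DiffAnnotate: transform inputs, edit, transform output."""
    inner: ToyRequest
    pre: Callable[[Any], Any]
    post: Callable[[Any], Any]

    def edit(self, argdiffs):
        return self.post(self.inner.edit(self.pre(argdiffs)))


req = ToyRequest()
doubled_inputs = req.contramap(lambda ds: tuple(2 * d for d in ds))
negated_output = req.map(lambda r: -r)
both = req.dimap(pre=lambda ds: ds[:1], post=lambda r: r + 100)

assert req.edit((1, 2)) == 3
assert doubled_inputs.edit((1, 2)) == 6
assert negated_output.edit((1, 2)) == -3
assert both.edit((1, 2)) == 101
```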

Generative functions with addressed random choices

Generative functions will often include addressed random choices. These are random choices that are given a name via an addressing syntax (for example, bernoulli(0.3) @ "x"), and they can be retrieved by name through the ChoiceMap interfaces that support addressing.

genjax.core.ChoiceMap

Bases: Pytree

The type ChoiceMap denotes a map-like value which can be sampled from generative functions.

Generative functions which utilize ChoiceMap as their sample representation typically support a notion of addressing for the random choices they make. ChoiceMap stores addressed random choices, and provides a data language for querying and manipulating these choices.

Examples:

(Making choice maps) Choice maps can be constructed using the ChoiceMapBuilder interface:

from genjax import ChoiceMapBuilder as C

chm = C["x"].set(3.0)
print(chm.render_html())

(Getting submaps) Hierarchical choice maps support __call__, which allows for the retrieval of submaps at addresses:

from genjax import ChoiceMapBuilder as C

chm = C["x", "y"].set(3.0)
submap = chm("x")
print(submap.render_html())

(Getting values) Choice maps support __getitem__, which allows for the retrieval of values at addresses:

from genjax import ChoiceMapBuilder as C

chm = C["x", "y"].set(3.0)
value = chm["x", "y"]
print(value)
3.0

(Making vectorized choice maps) Choice maps can be constructed using jax.vmap:

from genjax import ChoiceMapBuilder as C
from jax import vmap
import jax.numpy as jnp

vec_chm = vmap(lambda idx, v: C["x", idx].set(v))(jnp.arange(10), jnp.ones(10))
print(vec_chm.render_html())
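The difference between `__call__` (returns a submap) and `__getitem__` (returns a value, or raises when the address holds no value) can be sketched with a hypothetical nested-dict `ToyChoiceMap`. It is not GenJAX's implementation: it raises `KeyError` where GenJAX raises `ChoiceMapNoValueAtAddress`, and it does not model empty submaps for missing addresses.

```python
class ToyChoiceMap:
    """Hypothetical toy: a choice map as nested dicts whose leaves hold values."""

    def __init__(self, tree):
        self.tree = tree

    def __call__(self, *addrs):
        # Like ChoiceMap.__call__: return the *submap* at the address
        node = self.tree
        for a in addrs:
            node = node[a]
        return ToyChoiceMap(node)

    def __getitem__(self, addr):
        # Like ChoiceMap.__getitem__: return the *value*, or raise if there is none
        addrs = addr if isinstance(addr, tuple) else (addr,)
        node = self(*addrs).tree
        if isinstance(node, dict):
            raise KeyError(f"no value at address {addr!r}")
        return node


chm = ToyChoiceMap({"x": {"y": 3.0}})
sub = chm("x")  # the submap {"y": 3.0}
assert sub["y"] == 3.0
assert chm["x", "y"] == 3.0

try:
    chm["x"]  # "x" holds a submap, not a value
except KeyError:
    pass
else:
    raise AssertionError("expected KeyError")
```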

Methods:

  • __call__ - Alias for get_submap(*addresses).
  • choice - Creates a ChoiceMap containing a single value.
  • d - Creates a ChoiceMap from a dictionary.
  • empty - Returns a ChoiceMap with no values or submaps.
  • entry - Creates a ChoiceMap with a single value at a specified address.
  • extend - Returns a new ChoiceMap with the given address component as its root.
  • filter - Filter the choice map on the Selection. The resulting choice map only contains the addresses that return True when presented to the selection.
  • from_mapping - Creates a ChoiceMap from an iterable of address-value pairs.
  • get_selection - Returns a Selection representing the structure of this ChoiceMap.
  • invalid_subset - Identifies the subset of choices that are invalid for a given generative function and its arguments.
  • kw - Creates a ChoiceMap from keyword arguments.
  • mask - Returns a new ChoiceMap with values masked by a boolean flag.
  • merge - Merges this ChoiceMap with another ChoiceMap.
  • simplify - Previously pushed down filters, now acts as identity.
  • static_is_empty - Returns True if this ChoiceMap is equal to ChoiceMap.empty(), False otherwise.
  • switch - Creates a ChoiceMap that switches between multiple ChoiceMaps based on an index.
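The switch and mask entries above are related: switching can be pictured as combining every branch while masking out all but the branch at the chosen index. The sketch below uses plain dicts and a concrete Python `int` index; the helper names `toy_mask` and `toy_switch` are hypothetical. Under JAX tracing the index is not known concretely, which is why GenJAX's `switch` instead returns Mask-wrapped values (hence the `.unmask()` calls in its docstring example).

```python
def toy_mask(chm, flag):
    # Like ChoiceMap.mask: keep the entries when flag is True, else behave as empty
    return dict(chm) if flag else {}


def toy_switch(idx, chms):
    # Combine every branch, masking out all but the one at position idx
    out = {}
    for i, chm in enumerate(chms):
        out.update(toy_mask(chm, i == idx))
    return out


chm1 = {"x": 1, "y": 2}
chm2 = {"x": 3, "y": 4}
chm3 = {"x": 5, "y": 6}

assert toy_mask(chm1, True) == chm1
assert toy_mask(chm1, False) == {}
assert toy_switch(1, [chm1, chm2, chm3]) == {"x": 3, "y": 4}
```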

Attributes:

  • at (_ChoiceMapBuilder) - Returns a _ChoiceMapBuilder instance for constructing nested ChoiceMaps.

Source code in src/genjax/_src/core/generative/choice_map.py
class ChoiceMap(Pytree):
    """The type `ChoiceMap` denotes a map-like value which can be sampled from
    generative functions.

    Generative functions which utilize `ChoiceMap` as their sample representation typically support a notion of _addressing_ for the random choices they make. `ChoiceMap` stores addressed random choices, and provides a data language for querying and manipulating these choices.

    Examples:
        (**Making choice maps**) Choice maps can be constructed using the `ChoiceMapBuilder` interface
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import ChoiceMapBuilder as C

        chm = C["x"].set(3.0)
        print(chm.render_html())
        ```

        (**Getting submaps**) Hierarchical choice maps support `__call__`, which allows for the retrieval of _submaps_ at addresses:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import ChoiceMapBuilder as C

        chm = C["x", "y"].set(3.0)
        submap = chm("x")
        print(submap.render_html())
        ```

        (**Getting values**) Choice maps support `__getitem__`, which allows for the retrieval of _values_ at addresses:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import ChoiceMapBuilder as C

        chm = C["x", "y"].set(3.0)
        value = chm["x", "y"]
        print(value)
        ```

        (**Making vectorized choice maps**) Choice maps can be constructed using `jax.vmap`:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import ChoiceMapBuilder as C
        from jax import vmap
        import jax.numpy as jnp

        vec_chm = vmap(lambda idx, v: C["x", idx].set(v))(jnp.arange(10), jnp.ones(10))
        print(vec_chm.render_html())
        ```
    """

    #######################
    # Map-like interfaces #
    #######################

    @abstractmethod
    def filter(self, selection: Selection | Flag) -> "ChoiceMap":
        """
        Filter the choice map on the `Selection`. The resulting choice map only contains the addresses that return True when presented to the selection.

        Args:
            selection: The Selection to filter the choice map with.

        Returns:
            A new ChoiceMap containing only the addresses selected by the given Selection.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            import jax
            import genjax
            from genjax import bernoulli
            from genjax import SelectionBuilder as S


            @genjax.gen
            def model():
                x = bernoulli(0.3) @ "x"
                y = bernoulli(0.3) @ "y"
                return x


            key = jax.random.key(314159)
            tr = model.simulate(key, ())
            chm = tr.get_choices()
            selection = S["x"]
            filtered = chm.filter(selection)
            assert "y" not in filtered
            ```
        """

    @abstractmethod
    def get_value(self) -> Any:
        pass

    @abstractmethod
    def get_inner_map(
        self,
        addr: AddressComponent,
    ) -> "ChoiceMap":
        pass

    def get_submap(self, *addresses: Address) -> "ChoiceMap":
        addr = tuple(
            label for a in addresses for label in (a if isinstance(a, tuple) else (a,))
        )
        addr: tuple[AddressComponent, ...] = _validate_addr(
            addr, allow_partial_slice=True
        )
        return functools.reduce(lambda chm, addr: chm.get_inner_map(addr), addr, self)

    def has_value(self) -> bool:
        return self.get_value() is not None

    ######################################
    # Convenient syntax for construction #
    ######################################

    builder: Final[_ChoiceMapBuilder] = _ChoiceMapBuilder(None, [])

    @staticmethod
    def empty() -> "ChoiceMap":
        """
        Returns a ChoiceMap with no values or submaps.

        Returns:
            An empty ChoiceMap.
        """
        return _empty

    @staticmethod
    def choice(v: Any) -> "ChoiceMap":
        """
        Creates a ChoiceMap containing a single value.

        This method creates and returns an instance of Choice, which represents
        a ChoiceMap with a single value at the root level.

        Args:
            v: The value to be stored in the ChoiceMap.

        Returns:
            A ChoiceMap containing the single value.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            from genjax import ChoiceMap

            value_chm = ChoiceMap.choice(42)
            assert value_chm.get_value() == 42
            ```
        """
        return Choice.build(v)

    @staticmethod
    @nobeartype
    @deprecated("Use ChoiceMap.choice() instead.")
    def value(v: Any) -> "ChoiceMap":
        return ChoiceMap.choice(v)

    @staticmethod
    def entry(
        v: "dict[K_addr, Any] | ChoiceMap | Any", *addrs: AddressComponent
    ) -> "ChoiceMap":
        """
        Creates a ChoiceMap with a single value at a specified address.

        This method creates and returns a ChoiceMap with a new ChoiceMap stored at
        the given address.

        - if the provided value is already a ChoiceMap, it will be used directly;
        - `dict` values will be passed to `ChoiceMap.d`;
        - any other value will be passed to `ChoiceMap.choice`.

        Args:
            v: The value to be stored in the ChoiceMap. Can be any value, a dict or a ChoiceMap.
            addrs: The address at which to store the value. Can be a static or dynamic address component.

        Returns:
            A ChoiceMap with the value stored at the specified address.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            import genjax
            import jax.numpy as jnp
            from genjax import ChoiceMap

            # Using an existing ChoiceMap
            nested_chm = ChoiceMap.entry(ChoiceMap.choice(42), "x")
            assert nested_chm["x"] == 42

            # Using a dict generates a new `ChoiceMap.d` call
            nested_chm = ChoiceMap.entry({"y": 42}, "x")
            assert nested_chm["x", "y"] == 42

            # Static address
            static_chm = ChoiceMap.entry(42, "x")
            assert static_chm["x"] == 42

            # Dynamic address
            dynamic_chm = ChoiceMap.entry(
                jnp.array([1.1, 2.2, 3.3]), jnp.array([1, 2, 3])
            )
            assert dynamic_chm[1] == genjax.Mask(1.1, True)
            ```
        """
        if isinstance(v, ChoiceMap):
            chm = v
        elif isinstance(v, dict):
            chm = ChoiceMap.d(v)
        else:
            chm = ChoiceMap.choice(v)

        return chm.extend(*addrs)

    @staticmethod
    def from_mapping(pairs: Iterable[tuple[K_addr, Any]]) -> "ChoiceMap":
        """
        Creates a ChoiceMap from an iterable of address-value pairs.

        This method constructs a ChoiceMap by iterating through the provided pairs,
        where each pair consists of an address (or address component) and a corresponding value.
        The resulting ChoiceMap will contain all the values at their respective addresses.

        Args:
            pairs: An iterable of tuples, where each tuple contains an address (or address component) and its corresponding value. The address can be a single component or a tuple of components.

        Returns:
            A ChoiceMap containing all the address-value pairs from the input.

        Example:
            ```python
            from genjax import ChoiceMap

            pairs = [("x", 42), (("y", "z"), 10), ("w", [1, 2, 3])]
            chm = ChoiceMap.from_mapping(pairs)
            assert chm["x"] == 42
            assert chm["y", "z"] == 10
            assert chm["w"] == [1, 2, 3]
            ```

        Note:
            If multiple pairs have the same address, later pairs will overwrite earlier ones.
        """
        acc = ChoiceMap.empty()

        for addr, v in pairs:
            addr = addr if isinstance(addr, tuple) else (addr,)
            acc |= ChoiceMap.entry(v, *addr)

        return acc

    @staticmethod
    def d(d: dict[K_addr, Any]) -> "ChoiceMap":
        """
        Creates a ChoiceMap from a dictionary.

        This method creates and returns a ChoiceMap based on the key-value pairs in the provided dictionary. Each key in the dictionary becomes an address in the ChoiceMap, and the corresponding value is stored at that address.

        Dict-shaped values are recursively converted to ChoiceMap instances.

        Args:
            d: A dictionary where keys are addresses and values are the corresponding data to be stored in the ChoiceMap.

        Returns:
            A ChoiceMap containing the key-value pairs from the input dictionary.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            from genjax import ChoiceMap

            dict_chm = ChoiceMap.d({"x": 42, "y": {"z": [1, 2, 3]}})
            assert dict_chm["x"] == 42
            assert dict_chm["y", "z"] == [1, 2, 3]
            ```
        """
        return ChoiceMap.from_mapping(d.items())

    @staticmethod
    def kw(**kwargs) -> "ChoiceMap":
        """
        Creates a ChoiceMap from keyword arguments.

        This method creates and returns a ChoiceMap based on the provided keyword arguments.
        Each keyword argument becomes an address in the ChoiceMap, and its value is stored at that address.

        Dict-shaped values are recursively converted to ChoiceMap instances with calls to `ChoiceMap.d`.

        Returns:
            A ChoiceMap containing the key-value pairs from the input keyword arguments.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            kw_chm = ChoiceMap.kw(x=42, y=[1, 2, 3], z={"w": 10.0})
            assert kw_chm["x"] == 42
            assert kw_chm["y"] == [1, 2, 3]
            assert kw_chm["z", "w"] == 10.0
            ```
        """
        return ChoiceMap.d(kwargs)

    @staticmethod
    def switch(idx: int | IntArray, chms: Iterable["ChoiceMap"]) -> "ChoiceMap":
        """
        Creates a ChoiceMap that switches between multiple ChoiceMaps based on an index.

        This method creates a new ChoiceMap that selectively includes values from a sequence of
        input ChoiceMaps based on the provided index. The resulting ChoiceMap will contain
        values from the ChoiceMap at the position specified by the index, while masking out
        values from all other ChoiceMaps.

        Args:
            idx: An index or array of indices specifying which ChoiceMap(s) to select from.
            chms: An iterable of ChoiceMaps to switch between.

        Returns:
            A new ChoiceMap containing values from the selected ChoiceMap(s).

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            chm1 = ChoiceMap.d({"x": 1, "y": 2})
            chm2 = ChoiceMap.d({"x": 3, "y": 4})
            chm3 = ChoiceMap.d({"x": 5, "y": 6})

            switched = ChoiceMap.switch(jnp.array(1), [chm1, chm2, chm3])
            assert switched["x"].unmask() == 3
            assert switched["y"].unmask() == 4
            ```
        """
        return Switch.build(idx, chms)

    ######################
    # Combinator methods #
    ######################

    def mask(self, flag: Flag) -> "ChoiceMap":
        """
        Returns a new ChoiceMap with values masked by a boolean flag.

        This method creates a new ChoiceMap where the values are conditionally
        included based on the provided flag. If the flag is True, the original
        values are retained; if False, the ChoiceMap behaves as if it's empty.

        Args:
            flag: A boolean flag determining whether to include the values.

        Returns:
            A new ChoiceMap with values conditionally masked.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            original_chm = ChoiceMap.choice(42)
            masked_chm = original_chm.mask(True)
            assert masked_chm.get_value() == 42

            masked_chm = original_chm.mask(False)
            assert masked_chm.get_value() is None
            ```
        """
        return self.filter(flag)

    def extend(self, *addrs: AddressComponent) -> "ChoiceMap":
        """
        Returns a new ChoiceMap with the given address component as its root.

        This method creates a new ChoiceMap where the current ChoiceMap becomes a submap
        under the specified address component. It effectively adds a new level of hierarchy
        to the ChoiceMap structure.

        Args:
            addrs: The address components to use as the new root.

        Returns:
            A new ChoiceMap with the current ChoiceMap nested under the given address.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            original_chm = ChoiceMap.choice(42)
            indexed_chm = original_chm.extend("x")
            assert indexed_chm["x"] == 42
            ```
        """
        acc = self
        for addr in reversed(addrs):
            if isinstance(addr, StaticAddressComponent):
                acc = Static.build({addr: acc})
            else:
                acc = Indexed.build(acc, addr)

        return acc

    def merge(self, other: "ChoiceMap") -> "ChoiceMap":
        """
        Merges this ChoiceMap with another ChoiceMap.

        This method combines the current ChoiceMap with another ChoiceMap using the OR operation (`|`). It creates a new ChoiceMap that contains all addresses from both input ChoiceMaps. The `^` operator is deprecated and now behaves like `|`.

        Args:
            other: The ChoiceMap to merge with the current one.

        Returns:
            A new ChoiceMap resulting from the merge operation.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            chm1 = ChoiceMap.choice(5).extend("x")
            chm2 = ChoiceMap.choice(10).extend("y")
            merged_chm = chm1.merge(chm2)
            assert merged_chm["x"] == 5
            assert merged_chm["y"] == 10
            ```

        Note:
            This method is equivalent to using the | operator between two ChoiceMaps.
        """
        return self | other

    def get_selection(self) -> Selection:
        """
        Returns a Selection representing the structure of this ChoiceMap.

        This method creates a Selection that matches the hierarchical structure
        of the current ChoiceMap. The resulting Selection can be used to filter
        or query other ChoiceMaps with the same structure.

        Returns:
            A Selection object representing the structure of this ChoiceMap.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            chm = ChoiceMap.choice(5).extend("x")
            sel = chm.get_selection()
            assert sel["x"] == True
            assert sel["y"] == False
            ```
        """
        return ChmSel.build(self)

    def static_is_empty(self) -> bool:
        """
        Returns True if this ChoiceMap is equal to `ChoiceMap.empty()`, False otherwise.
        """
        return False

    ###########
    # Dunders #
    ###########

    @nobeartype
    @deprecated(
        reason="^ is deprecated, please use | or _.merge(...) instead.",
        version="0.8.0",
    )
    def __xor__(self, other: "ChoiceMap") -> "ChoiceMap":
        return self | other

    def __or__(self, other: "ChoiceMap") -> "ChoiceMap":
        return Or.build(self, other)

    def __and__(self, other: "ChoiceMap") -> "ChoiceMap":
        return other.filter(self.get_selection())

    def __add__(self, other: "ChoiceMap") -> "ChoiceMap":
        return self | other

    def __call__(
        self,
        *addresses: Address,
    ) -> "ChoiceMap":
        """Alias for `get_submap(*addresses)`."""
        return self.get_submap(*addresses)

    def __getitem__(
        self,
        addr: Address,
    ):
        submap = self.get_submap(addr)
        v = submap.get_value()
        if v is None:
            raise ChoiceMapNoValueAtAddress(addr)
        else:
            return v

    def __contains__(
        self,
        addr: Address,
    ) -> bool:
        return self.get_submap(addr).has_value()

    @property
    def at(self) -> _ChoiceMapBuilder:
        """
        Returns a _ChoiceMapBuilder instance for constructing nested ChoiceMaps.

        This property allows for a fluent interface to build complex ChoiceMaps
        by chaining address components and setting values.

        Returns:
            A builder object for constructing ChoiceMaps.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            from genjax import ChoiceMap

            chm = ChoiceMap.d({("x", "y"): 3.0, "z": 12.0})
            updated = chm.at["x", "y"].set(4.0)

            assert updated["x", "y"] == 4.0
            assert updated["z"] == chm["z"]
            ```
        """
        return _ChoiceMapBuilder(self, [])

    @nobeartype
    @deprecated(
        reason="Acts as identity; filters are now automatically pushed down.",
        version="0.8.0",
    )
    def simplify(self) -> "ChoiceMap":
        """Previously pushed down filters, now acts as identity."""
        return self

    def invalid_subset(
        self,
        gen_fn: "genjax.GenerativeFunction[Any]",
        args: tuple[Any, ...],
    ) -> "ChoiceMap | None":
        """
        Identifies the subset of choices that are invalid for a given generative function and its arguments.

        This method checks if all choices in the current ChoiceMap are valid for the given
        generative function and its arguments.

        Args:
            gen_fn: The generative function to check against.
            args: The arguments to the generative function.

        Returns:
            A ChoiceMap containing any extra choices not reachable in the course of `gen_fn`'s execution, or None if no extra choices are found.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            @genjax.gen
            def model(x):
                y = bernoulli(0.5) @ "y"
                return x + y


            chm = ChoiceMap.d({"y": 1, "z": 2})
            extras = chm.invalid_subset(model, (1,))
            assert "z" in extras  # "z" is an extra choice not in the model
            ```
        """
        shape_chm = gen_fn.get_zero_trace(*args).get_choices()
        shape_sel = _shape_selection(shape_chm)
        extras = self.filter(~shape_sel)
        if not extras.static_is_empty():
            return extras
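`invalid_subset` reads the set of addresses the model can visit off a zero trace, then keeps only the constraints outside that set (`self.filter(~shape_sel)`), returning `None` when nothing is left. A flat-dict sketch of the same logic; `toy_invalid_subset` is hypothetical and the address set is passed in directly instead of being derived from `gen_fn.get_zero_trace(*args)`:

```python
def toy_invalid_subset(choices, model_addresses):
    """Hypothetical sketch of ChoiceMap.invalid_subset on flat dicts.

    choices: dict mapping address -> value (the constraints a user supplies)
    model_addresses: set of addresses the model can visit
    """
    # Keep only the choices whose address the model never visits
    extras = {addr: v for addr, v in choices.items() if addr not in model_addresses}
    # Mirror the source: return the extras, or None when there are none
    return extras if extras else None


# Mirrors the docstring example: the model only samples "y"
assert toy_invalid_subset({"y": 1, "z": 2}, {"y"}) == {"z": 2}
assert toy_invalid_subset({"y": 1}, {"y"}) is None
```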

at property

at: _ChoiceMapBuilder

Returns a _ChoiceMapBuilder instance for constructing nested ChoiceMaps.

This property allows for a fluent interface to build complex ChoiceMaps by chaining address components and setting values.

Returns:

  • _ChoiceMapBuilder - A builder object for constructing ChoiceMaps.

Example
from genjax import ChoiceMap

chm = ChoiceMap.d({("x", "y"): 3.0, "z": 12.0})
updated = chm.at["x", "y"].set(4.0)

assert updated["x", "y"] == 4.0
assert updated["z"] == chm["z"]
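The example above reads `chm["z"]` after building `updated`, which fits the functional-update style JAX pytrees use: `at[...].set(...)` produces a new choice map. A nested-dict sketch of that behavior, assuming (not confirmed from the builder's internals) that the original is left untouched; `toy_at_set` is a hypothetical helper:

```python
import copy


def toy_at_set(tree, addr, value):
    """Hypothetical sketch of chm.at[addr].set(value) on nested dicts.

    Returns a new tree and leaves the input unmodified.
    """
    new_tree = copy.deepcopy(tree)
    node = new_tree
    for a in addr[:-1]:
        node = node.setdefault(a, {})  # walk (or create) intermediate levels
    node[addr[-1]] = value
    return new_tree


chm = {"x": {"y": 3.0}, "z": 12.0}
updated = toy_at_set(chm, ("x", "y"), 4.0)

assert updated["x"]["y"] == 4.0
assert updated["z"] == chm["z"]
assert chm["x"]["y"] == 3.0  # original untouched
```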

__call__

__call__(*addresses: Address) -> ChoiceMap

Alias for get_submap(*addresses).

Source code in src/genjax/_src/core/generative/choice_map.py
def __call__(
    self,
    *addresses: Address,
) -> "ChoiceMap":
    """Alias for `get_submap(*addresses)`."""
    return self.get_submap(*addresses)
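Because `__call__` delegates to `get_submap`, it inherits that method's address flattening: tuple arguments are spliced in, so `chm("x", ("y", "z"))`, `chm(("x", "y"), "z")` and `chm("x", "y", "z")` all name the same submap. The comprehension below is copied from `get_submap`; the wrapper name `flatten_address` is hypothetical, and the subsequent `_validate_addr` check is omitted.

```python
def flatten_address(*addresses):
    # Same comprehension as ChoiceMap.get_submap: tuples are spliced in,
    # single components are kept as-is
    return tuple(
        label for a in addresses for label in (a if isinstance(a, tuple) else (a,))
    )


assert flatten_address("x", ("y", "z")) == ("x", "y", "z")
assert flatten_address(("x", "y"), "z") == ("x", "y", "z")
assert flatten_address("x") == ("x",)
```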

choice staticmethod

choice(v: Any) -> ChoiceMap

Creates a ChoiceMap containing a single value.

This method creates and returns an instance of Choice, which represents a ChoiceMap with a single value at the root level.

Parameters:

  • v (Any, required) - The value to be stored in the ChoiceMap.

Returns:

  • ChoiceMap - A ChoiceMap containing the single value.

Example
from genjax import ChoiceMap

value_chm = ChoiceMap.choice(42)
assert value_chm.get_value() == 42
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def choice(v: Any) -> "ChoiceMap":
    """
    Creates a ChoiceMap containing a single value.

    This method creates and returns an instance of Choice, which represents
    a ChoiceMap with a single value at the root level.

    Args:
        v: The value to be stored in the ChoiceMap.

    Returns:
        A ChoiceMap containing the single value.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import ChoiceMap

        value_chm = ChoiceMap.choice(42)
        assert value_chm.get_value() == 42
        ```
    """
    return Choice.build(v)
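A map built by `choice` holds its value at the root, with no address. It becomes addressable once it is nested with `extend` (which `entry` calls), and `extend` wraps from the innermost address outward, so the first address given ends up at the top. A nested-dict sketch with hypothetical helpers; it ignores the distinction GenJAX's `extend` draws between static components (`Static.build`) and dynamic ones (`Indexed.build`):

```python
def toy_choice(v):
    # A leaf: a value at the root and no children (like ChoiceMap.choice)
    return {"value": v, "children": {}}


def toy_extend(chm, *addrs):
    # Like ChoiceMap.extend: iterate over reversed(addrs), so the first
    # address in `addrs` ends up as the new root component
    acc = chm
    for addr in reversed(addrs):
        acc = {"value": None, "children": {addr: acc}}
    return acc


def toy_get(chm, *addrs):
    # Walk down the children, then read the value stored there (if any)
    for a in addrs:
        chm = chm["children"][a]
    return chm["value"]


leaf = toy_choice(42)
assert toy_get(leaf) == 42  # value at the root

nested = toy_extend(leaf, "x", "y")
assert toy_get(nested, "x", "y") == 42
assert toy_get(nested, "x") is None  # "x" holds a submap, not a value
```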

d staticmethod

d(d: dict[K_addr, Any]) -> ChoiceMap

Creates a ChoiceMap from a dictionary.

This method creates and returns a ChoiceMap based on the key-value pairs in the provided dictionary. Each key in the dictionary becomes an address in the ChoiceMap, and the corresponding value is stored at that address.

Dict-shaped values are recursively converted to ChoiceMap instances.

Parameters:

  • d (dict[K_addr, Any], required) - A dictionary where keys are addresses and values are the corresponding data to be stored in the ChoiceMap.

Returns:

  • ChoiceMap - A ChoiceMap containing the key-value pairs from the input dictionary.

Example
from genjax import ChoiceMap

dict_chm = ChoiceMap.d({"x": 42, "y": {"z": [1, 2, 3]}})
assert dict_chm["x"] == 42
assert dict_chm["y", "z"] == [1, 2, 3]
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def d(d: dict[K_addr, Any]) -> "ChoiceMap":
    """
    Creates a ChoiceMap from a dictionary.

    This method creates and returns a ChoiceMap based on the key-value pairs in the provided dictionary. Each key in the dictionary becomes an address in the ChoiceMap, and the corresponding value is stored at that address.

    Dict-shaped values are recursively converted to ChoiceMap instances.

    Args:
        d: A dictionary where keys are addresses and values are the corresponding data to be stored in the ChoiceMap.

    Returns:
        A ChoiceMap containing the key-value pairs from the input dictionary.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import ChoiceMap

        dict_chm = ChoiceMap.d({"x": 42, "y": {"z": [1, 2, 3]}})
        assert dict_chm["x"] == 42
        assert dict_chm["y", "z"] == [1, 2, 3]
        ```
    """
    return ChoiceMap.from_mapping(d.items())
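`d` hands the dictionary's items to `from_mapping`, which treats tuple keys as multi-component addresses and passes dict values through `entry`, which in turn calls `d` again; nested dicts are therefore converted recursively. A plain-dict sketch of the resulting address layout; `toy_d` is hypothetical and returns a flat dict from full address tuples to leaf values rather than a ChoiceMap:

```python
def toy_d(d):
    """Hypothetical sketch of the address layout produced by ChoiceMap.d."""
    out = {}
    for addr, v in d.items():
        # Tuple keys are multi-component addresses; others are single components
        prefix = addr if isinstance(addr, tuple) else (addr,)
        if isinstance(v, dict):
            # Dict values are converted recursively, nested under the prefix
            for sub_addr, leaf in toy_d(v).items():
                out[prefix + sub_addr] = leaf
        else:
            out[prefix] = v
    return out


assert toy_d({"x": 42, "y": {"z": [1, 2, 3]}}) == {("x",): 42, ("y", "z"): [1, 2, 3]}
assert toy_d({("x", "y"): 3.0, "z": 12.0}) == {("x", "y"): 3.0, ("z",): 12.0}
```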

empty staticmethod

empty() -> ChoiceMap

Returns a ChoiceMap with no values or submaps.

Returns:

  • ChoiceMap - An empty ChoiceMap.

Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def empty() -> "ChoiceMap":
    """
    Returns a ChoiceMap with no values or submaps.

    Returns:
        An empty ChoiceMap.
    """
    return _empty

entry staticmethod

entry(
    v: dict[K_addr, Any] | ChoiceMap | Any,
    *addrs: AddressComponent
) -> ChoiceMap

Creates a ChoiceMap with a single value at a specified address.

This method creates and returns a ChoiceMap with a new ChoiceMap stored at the given address.

  • if the provided value is already a ChoiceMap, it will be used directly;
  • dict values will be passed to ChoiceMap.d;
  • any other value will be passed to ChoiceMap.value.

Parameters:

  • v (dict[K_addr, Any] | ChoiceMap | Any, required): The value to be stored in the ChoiceMap. Can be any value, a dict, or a ChoiceMap.
  • addrs (AddressComponent, default ()): The address at which to store the value. Each component can be static or dynamic.

Returns:

  • ChoiceMap: A ChoiceMap with the value stored at the specified address.

Example
import genjax
import jax.numpy as jnp
from genjax import ChoiceMap

# Using an existing ChoiceMap
nested_chm = ChoiceMap.entry(ChoiceMap.value(42), "x")
assert nested_chm["x"] == 42

# Using a dict generates a new `ChoiceMap.d` call
nested_chm = ChoiceMap.entry({"y": 42}, "x")
assert nested_chm["x", "y"] == 42

# Static address
static_chm = ChoiceMap.entry(42, "x")
assert static_chm["x"] == 42

# Dynamic address
dynamic_chm = ChoiceMap.entry(
    jnp.array([1.1, 2.2, 3.3]), jnp.array([1, 2, 3])
)
assert dynamic_chm[1] == genjax.Mask(1.1, True)
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def entry(
    v: "dict[K_addr, Any] | ChoiceMap | Any", *addrs: AddressComponent
) -> "ChoiceMap":
    """
    Creates a ChoiceMap with a single value at a specified address.

    This method creates and returns a ChoiceMap with a new ChoiceMap stored at
    the given address.

    - if the provided value is already a ChoiceMap, it will be used directly;
    - `dict` values will be passed to `ChoiceMap.d`;
    - any other value will be passed to `ChoiceMap.value`.

    Args:
        v: The value to be stored in the ChoiceMap. Can be any value, a dict or a ChoiceMap.
        addrs: The address at which to store the value. Can be a static or dynamic address component.

    Returns:
        A ChoiceMap with the value stored at the specified address.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        import genjax
        import jax.numpy as jnp

        # Using an existing ChoiceMap
        nested_chm = ChoiceMap.entry(ChoiceMap.value(42), "x")
        assert nested_chm["x"] == 42

        # Using a dict generates a new `ChoiceMap.d` call
        nested_chm = ChoiceMap.entry({"y": 42}, "x")
        assert nested_chm["x", "y"] == 42

        # Static address
        static_chm = ChoiceMap.entry(42, "x")
        assert static_chm["x"] == 42

        # Dynamic address
        dynamic_chm = ChoiceMap.entry(
            jnp.array([1.1, 2.2, 3.3]), jnp.array([1, 2, 3])
        )
        assert dynamic_chm[1] == genjax.Mask(1.1, True)
        ```
    """
    if isinstance(v, ChoiceMap):
        chm = v
    elif isinstance(v, dict):
        chm = ChoiceMap.d(v)
    else:
        chm = ChoiceMap.choice(v)

    return chm.extend(*addrs)
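The three-way dispatch in `entry` can be sketched with a toy nested-dict representation. This is an illustration under simplifying assumptions, not GenJAX's implementation: `ToyChm`, `Leaf`, and `from_dict` are hypothetical stand-ins for ChoiceMap, a single-value choice map, and `ChoiceMap.d`.

```python
# Toy model (not GenJAX code) of the dispatch in ChoiceMap.entry.
# ToyChm, Leaf and from_dict are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Leaf:
    value: Any  # plays the role of a single-value choice map


@dataclass
class ToyChm:
    tree: Any  # a Leaf, or a dict mapping address components to subtrees

    def extend(self, *addrs: Any) -> "ToyChm":
        acc = self.tree
        for addr in reversed(addrs):
            acc = {addr: acc}
        return ToyChm(acc)

    def __getitem__(self, addr: Any) -> Any:
        addr = addr if isinstance(addr, tuple) else (addr,)
        node = self.tree
        for comp in addr:
            node = node[comp]
        return node.value


def from_dict(d: dict) -> ToyChm:
    # Stand-in for ChoiceMap.d: every value goes back through `entry`.
    return ToyChm({k: entry(v).tree for k, v in d.items()})


def entry(v: Any, *addrs: Any) -> ToyChm:
    if isinstance(v, ToyChm):
        chm = v  # already a choice map: used directly
    elif isinstance(v, dict):
        chm = from_dict(v)  # dicts go through the ChoiceMap.d analogue
    else:
        chm = ToyChm(Leaf(v))  # any other value becomes a single leaf
    return chm.extend(*addrs)


nested = entry({"y": 42}, "x")
```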

extend

extend(*addrs: AddressComponent) -> ChoiceMap

Returns a new ChoiceMap with the given address component as its root.

This method creates a new ChoiceMap where the current ChoiceMap becomes a submap under the specified address component. It effectively adds a new level of hierarchy to the ChoiceMap structure.

Parameters:

  • addrs (AddressComponent, default ()): The address components to use as the new root.

Returns:

  • ChoiceMap: A new ChoiceMap with the current ChoiceMap nested under the given address.

Example
from genjax import ChoiceMap

original_chm = ChoiceMap.value(42)
indexed_chm = original_chm.extend("x")
assert indexed_chm["x"] == 42
Source code in src/genjax/_src/core/generative/choice_map.py
def extend(self, *addrs: AddressComponent) -> "ChoiceMap":
    """
    Returns a new ChoiceMap with the given address component as its root.

    This method creates a new ChoiceMap where the current ChoiceMap becomes a submap
    under the specified address component. It effectively adds a new level of hierarchy
    to the ChoiceMap structure.

    Args:
        addrs: The address components to use as the new root.

    Returns:
        A new ChoiceMap with the current ChoiceMap nested under the given address.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        original_chm = ChoiceMap.value(42)
        indexed_chm = original_chm.extend("x")
        assert indexed_chm["x"] == 42
        ```
    """
    acc = self
    for addr in reversed(addrs):
        if isinstance(addr, StaticAddressComponent):
            acc = Static.build({addr: acc})
        else:
            acc = Indexed.build(acc, addr)

    return acc
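The loop in `extend` walks the address components in reverse, so the first component ends up outermost. A minimal sketch with plain dicts (a toy model, not the real `Static`/`Indexed` builders):

```python
# Toy model (not GenJAX code): nest a value under address components the way
# the extend loop does, wrapping from the innermost component outwards.
from typing import Any


def extend(tree: Any, *addrs: Any) -> Any:
    acc = tree
    # Iterate in reverse so that addrs[0] ends up as the outermost key.
    for addr in reversed(addrs):
        acc = {addr: acc}
    return acc


nested = extend(42, "a", "b")
print(nested)  # {'a': {'b': 42}}
# Extending by a path equals extending one component at a time, innermost first.
assert nested == extend(extend(42, "b"), "a")
```

Building from the innermost component outwards is what lets a single call accept a whole address path.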

filter abstractmethod

filter(selection: Selection | Flag) -> ChoiceMap

Filter the choice map on the Selection. The resulting choice map only contains the addresses that return True when presented to the selection.

Parameters:

  • selection (Selection | Flag, required): The Selection to filter the choice map with.

Returns:

  • ChoiceMap: A new ChoiceMap containing only the addresses selected by the given Selection.

Examples:

import jax
import genjax
from genjax import bernoulli
from genjax import SelectionBuilder as S


@genjax.gen
def model():
    x = bernoulli(0.3) @ "x"
    y = bernoulli(0.3) @ "y"
    return x


key = jax.random.key(314159)
tr = model.simulate(key, ())
chm = tr.get_choices()
selection = S["x"]
filtered = chm.filter(selection)
assert "y" not in filtered
Source code in src/genjax/_src/core/generative/choice_map.py
@abstractmethod
def filter(self, selection: Selection | Flag) -> "ChoiceMap":
    """
    Filter the choice map on the `Selection`. The resulting choice map only contains the addresses that return True when presented to the selection.

    Args:
        selection: The Selection to filter the choice map with.

    Returns:
        A new ChoiceMap containing only the addresses selected by the given Selection.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        import jax
        import genjax
        from genjax import bernoulli
        from genjax import SelectionBuilder as S


        @genjax.gen
        def model():
            x = bernoulli(0.3) @ "x"
            y = bernoulli(0.3) @ "y"
            return x


        key = jax.random.key(314159)
        tr = model.simulate(key, ())
        chm = tr.get_choices()
        selection = S["x"]
        filtered = chm.filter(selection)
        assert "y" not in filtered
        ```
    """

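Conceptually, filtering keeps exactly the addresses a selection accepts. The sketch below models this with a flat `{address tuple: value}` dict and a predicate. `select_prefix` and `filter_chm` are hypothetical helpers. The real method works on ChoiceMap/Selection trees and also accepts a Flag, which this toy ignores.

```python
# Toy model (not GenJAX code): a choice map as a flat {address tuple: value}
# dict and a selection as a predicate over address tuples.
from typing import Any, Callable

Address = tuple
SelectionFn = Callable[[Address], bool]


def select_prefix(*prefix: Any) -> SelectionFn:
    # Matches any address starting with `prefix`, loosely like S["x"].
    return lambda addr: addr[: len(prefix)] == prefix


def filter_chm(chm: dict[Address, Any], sel: SelectionFn) -> dict[Address, Any]:
    # Keep only the addresses for which the selection returns True.
    return {addr: v for addr, v in chm.items() if sel(addr)}


choices = {("x",): True, ("y",): False}
filtered = filter_chm(choices, select_prefix("x"))
```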
from_mapping staticmethod

from_mapping(
    pairs: Iterable[tuple[K_addr, Any]]
) -> ChoiceMap

Creates a ChoiceMap from an iterable of address-value pairs.

This method constructs a ChoiceMap by iterating through the provided pairs, where each pair consists of an address (or address component) and a corresponding value. The resulting ChoiceMap will contain all the values at their respective addresses.

Parameters:

  • pairs (Iterable[tuple[K_addr, Any]], required): An iterable of tuples, each containing an address and its corresponding value. The address can be a single component or a tuple of components.

Returns:

  • ChoiceMap: A ChoiceMap containing all the address-value pairs from the input.

Example
from genjax import ChoiceMap

pairs = [("x", 42), (("y", "z"), 10), ("w", [1, 2, 3])]
chm = ChoiceMap.from_mapping(pairs)
assert chm["x"] == 42
assert chm["y", "z"] == 10
assert chm["w"] == [1, 2, 3]
Note

If multiple pairs have the same address, later pairs will overwrite earlier ones.

Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def from_mapping(pairs: Iterable[tuple[K_addr, Any]]) -> "ChoiceMap":
    """
    Creates a ChoiceMap from an iterable of address-value pairs.

    This method constructs a ChoiceMap by iterating through the provided pairs,
    where each pair consists of an address (or address component) and a corresponding value.
    The resulting ChoiceMap will contain all the values at their respective addresses.

    Args:
        pairs: An iterable of tuples, where each tuple contains an address (or address component) and its corresponding value. The address can be a single component or a tuple of components.

    Returns:
        A ChoiceMap containing all the address-value pairs from the input.

    Example:
        ```python
        pairs = [("x", 42), (("y", "z"), 10), ("w", [1, 2, 3])]
        chm = ChoiceMap.from_mapping(pairs)
        assert chm["x"] == 42
        assert chm["y", "z"] == 10
        assert chm["w"] == [1, 2, 3]
        ```

    Note:
        If multiple pairs have the same address, later pairs will overwrite earlier ones.
    """
    acc = ChoiceMap.empty()

    for addr, v in pairs:
        addr = addr if isinstance(addr, tuple) else (addr,)
        acc |= ChoiceMap.entry(v, *addr)

    return acc
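The core of the loop is address normalization: a bare component such as `"x"` is treated as the one-element path `("x",)`. A toy version over a flat dict shows the resulting paths. It is hypothetical, not GenJAX code, and uses plain dict assignment. The example avoids duplicate addresses because this toy is not meant to model the real `|`-based accumulation.

```python
# Toy model (not GenJAX code) of the loop in ChoiceMap.from_mapping over a
# flat {address tuple: value} dict. The hypothetical `from_mapping_toy` uses
# plain dict assignment and does not model the real `|`-based accumulation.
from typing import Any, Iterable


def from_mapping_toy(pairs: Iterable[tuple[Any, Any]]) -> dict[tuple, Any]:
    acc: dict[tuple, Any] = {}
    for addr, v in pairs:
        # A bare component such as "x" becomes the one-element path ("x",).
        addr = addr if isinstance(addr, tuple) else (addr,)
        acc[addr] = v
    return acc


chm = from_mapping_toy([("x", 42), (("y", "z"), 10), ("w", [1, 2, 3])])
```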

get_selection

get_selection() -> Selection

Returns a Selection representing the structure of this ChoiceMap.

This method creates a Selection that matches the hierarchical structure of the current ChoiceMap. The resulting Selection can be used to filter or query other ChoiceMaps with the same structure.

Returns:

  • Selection: A Selection object representing the structure of this ChoiceMap.

Example
from genjax import ChoiceMap

chm = ChoiceMap.value(5).extend("x")
sel = chm.get_selection()
assert sel["x"] == True
assert sel["y"] == False
Source code in src/genjax/_src/core/generative/choice_map.py
def get_selection(self) -> Selection:
    """
    Returns a Selection representing the structure of this ChoiceMap.

    This method creates a Selection that matches the hierarchical structure
    of the current ChoiceMap. The resulting Selection can be used to filter
    or query other ChoiceMaps with the same structure.

    Returns:
        A Selection object representing the structure of this ChoiceMap.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        chm = ChoiceMap.value(5).extend("x")
        sel = chm.get_selection()
        assert sel["x"] == True
        assert sel["y"] == False
        ```
    """
    return ChmSel.build(self)
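The intended use, deriving a selection from one map and reusing it on another with the same structure, can be sketched with flat dicts. `get_selection` below is a hypothetical helper that only approximates `ChmSel`. It matches exactly the addresses the first map defines.

```python
# Toy model (not GenJAX code): a selection derived from the addresses a
# choice map defines, reused to filter another map with the same structure.
# `get_selection` here is a hypothetical helper over flat dicts.
from typing import Any, Callable


def get_selection(chm: dict[tuple, Any]) -> Callable[[Any], bool]:
    defined = set(chm)

    def sel(addr: Any) -> bool:
        addr = addr if isinstance(addr, tuple) else (addr,)
        return addr in defined

    return sel


sel = get_selection({("x",): 5})
other = {("x",): 7, ("y",): 8}
# Keep only the entries of `other` at addresses the first map defines.
filtered = {a: v for a, v in other.items() if sel(a)}
```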

invalid_subset

invalid_subset(
    gen_fn: GenerativeFunction[Any], args: tuple[Any, ...]
) -> ChoiceMap | None

Identifies the subset of choices that are invalid for a given generative function and its arguments.

This method checks if all choices in the current ChoiceMap are valid for the given generative function and its arguments.

Parameters:

  • gen_fn (GenerativeFunction[Any], required): The generative function to check against.
  • args (tuple[Any, ...], required): The arguments to the generative function.

Returns:

  • ChoiceMap | None: A ChoiceMap containing any extra choices not reachable in the course of gen_fn's execution, or None if no extra choices are found.

Example
import genjax
from genjax import ChoiceMap, bernoulli


@genjax.gen
def model(x):
    y = bernoulli(0.5) @ "y"
    return x + y


chm = ChoiceMap.d({"y": 1, "z": 2})
extras = chm.invalid_subset(model, (1,))
assert "z" in extras  # "z" is an extra choice not in the model
Source code in src/genjax/_src/core/generative/choice_map.py
def invalid_subset(
    self,
    gen_fn: "genjax.GenerativeFunction[Any]",
    args: tuple[Any, ...],
) -> "ChoiceMap | None":
    """
    Identifies the subset of choices that are invalid for a given generative function and its arguments.

    This method checks if all choices in the current ChoiceMap are valid for the given
    generative function and its arguments.

    Args:
        gen_fn: The generative function to check against.
        args: The arguments to the generative function.

    Returns:
        A ChoiceMap containing any extra choices not reachable in the course of `gen_fn`'s execution, or None if no extra choices are found.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        @genjax.gen
        def model(x):
            y = bernoulli(0.5) @ "y"
            return x + y


        chm = ChoiceMap.d({"y": 1, "z": 2})
        extras = chm.invalid_subset(model, (1,))
        assert "z" in extras  # "z" is an extra choice not in the model
        ```
    """
    shape_chm = gen_fn.get_zero_trace(*args).get_choices()
    shape_sel = _shape_selection(shape_chm)
    extras = self.filter(~shape_sel)
    if not extras.static_is_empty():
        return extras
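The implementation filters the choice map with the complement (`~`) of a selection built from the model's zero trace. With flat address sets, that reduces to a set difference. The sketch below uses a hypothetical `model_addresses` set in place of the addresses of `gen_fn.get_zero_trace(*args).get_choices()`.

```python
# Toy model (not GenJAX code): constraints whose addresses the model never
# visits are "extra". `model_addresses` is a hypothetical stand-in for the
# address structure of gen_fn.get_zero_trace(*args).get_choices().
from typing import Any, Optional


def invalid_subset(
    constraints: dict[tuple, Any], model_addresses: set[tuple]
) -> Optional[dict[tuple, Any]]:
    # For flat addresses, this is filtering with the complement selection.
    extras = {a: v for a, v in constraints.items() if a not in model_addresses}
    # Mirror the documented contract: None when nothing is extra.
    return extras or None


model_addresses = {("y",)}  # the toy model only samples at "y"
extras = invalid_subset({("y",): 1, ("z",): 2}, model_addresses)
no_extras = invalid_subset({("y",): 1}, model_addresses)
```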

kw staticmethod

kw(**kwargs) -> ChoiceMap

Creates a ChoiceMap from keyword arguments.

This method creates and returns a ChoiceMap based on the provided keyword arguments. Each keyword argument becomes an address in the ChoiceMap, and its value is stored at that address.

Dict-shaped values are recursively converted to ChoiceMap instances with calls to ChoiceMap.d.

Returns:

  • ChoiceMap: A ChoiceMap containing the key-value pairs from the input keyword arguments.

Example
from genjax import ChoiceMap

kw_chm = ChoiceMap.kw(x=42, y=[1, 2, 3], z={"w": 10.0})
assert kw_chm["x"] == 42
assert kw_chm["y"] == [1, 2, 3]
assert kw_chm["z", "w"] == 10.0
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def kw(**kwargs) -> "ChoiceMap":
    """
    Creates a ChoiceMap from keyword arguments.

    This method creates and returns a ChoiceMap based on the provided keyword arguments.
    Each keyword argument becomes an address in the ChoiceMap, and its value is stored at that address.

    Dict-shaped values are recursively converted to ChoiceMap instances with calls to `ChoiceMap.d`.

    Returns:
        A ChoiceMap containing the key-value pairs from the input keyword arguments.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        kw_chm = ChoiceMap.kw(x=42, y=[1, 2, 3], z={"w": 10.0})
        assert kw_chm["x"] == 42
        assert kw_chm["y"] == [1, 2, 3]
        assert kw_chm["z", "w"] == 10.0
        ```
    """
    return ChoiceMap.d(kwargs)

mask

mask(flag: Flag) -> ChoiceMap

Returns a new ChoiceMap with values masked by a boolean flag.

This method creates a new ChoiceMap where the values are conditionally included based on the provided flag. If the flag is True, the original values are retained; if False, the ChoiceMap behaves as if it's empty.

Parameters:

  • flag (Flag, required): A boolean flag determining whether to include the values.

Returns:

  • ChoiceMap: A new ChoiceMap with values conditionally masked.

Example
from genjax import ChoiceMap

original_chm = ChoiceMap.value(42)
masked_chm = original_chm.mask(True)
assert masked_chm.get_value() == 42

masked_chm = original_chm.mask(False)
assert masked_chm.get_value() is None
Source code in src/genjax/_src/core/generative/choice_map.py
def mask(self, flag: Flag) -> "ChoiceMap":
    """
    Returns a new ChoiceMap with values masked by a boolean flag.

    This method creates a new ChoiceMap where the values are conditionally
    included based on the provided flag. If the flag is True, the original
    values are retained; if False, the ChoiceMap behaves as if it's empty.

    Args:
        flag: A boolean flag determining whether to include the values.

    Returns:
        A new ChoiceMap with values conditionally masked.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        original_chm = ChoiceMap.value(42)
        masked_chm = original_chm.mask(True)
        assert masked_chm.get_value() == 42

        masked_chm = original_chm.mask(False)
        assert masked_chm.get_value() is None
        ```
    """
    return self.filter(flag)
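With a concrete Python boolean, the behavior of `mask` reduces to keep-or-empty, as in this toy. It is not GenJAX code, and `ToyChm` is hypothetical. It deliberately does not model a traced JAX flag, where Python-level branching like this is not possible.

```python
# Toy model (not GenJAX code) of ChoiceMap.mask with a concrete Python bool.
# `ToyChm` is a hypothetical single-value choice map.
from typing import Any, Optional


class ToyChm:
    def __init__(self, value: Any = None, has_value: bool = False):
        self._value = value
        self._has_value = has_value

    def get_value(self) -> Optional[Any]:
        # An empty map reports no value, matching `get_value() is None` above.
        return self._value if self._has_value else None

    def mask(self, flag: bool) -> "ToyChm":
        # True keeps the values; False behaves like an empty choice map.
        return self if flag else ToyChm()


original = ToyChm(42, has_value=True)
kept = original.mask(True)
dropped = original.mask(False)
```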

merge

Merges this ChoiceMap with another ChoiceMap.

This method combines the current ChoiceMap with another ChoiceMap using the XOR operation (^). It creates a new ChoiceMap that contains all addresses from both input ChoiceMaps; any overlapping addresses will trigger an error on access at the address via [<addr>] or get_value(). Use | if you don't want this behavior.

Parameters:

  • other (ChoiceMap, required): The ChoiceMap to merge with the current one.

Returns:

  • ChoiceMap: A new ChoiceMap resulting from the merge operation.

Example
from genjax import ChoiceMap

chm1 = ChoiceMap.value(5).extend("x")
chm2 = ChoiceMap.value(10).extend("y")
merged_chm = chm1.merge(chm2)
assert merged_chm["x"] == 5
assert merged_chm["y"] == 10
Note

This method is equivalent to using the ^ operator between two ChoiceMaps.

Source code in src/genjax/_src/core/generative/choice_map.py
def merge(self, other: "ChoiceMap") -> "ChoiceMap":
    """
    Merges this ChoiceMap with another ChoiceMap.

    This method combines the current ChoiceMap with another ChoiceMap using the XOR operation (^). It creates a new ChoiceMap that contains all addresses from both input ChoiceMaps; any overlapping addresses will trigger an error on access at the address via `[<addr>]` or `get_value()`. Use `|` if you don't want this behavior.

    Args:
        other: The ChoiceMap to merge with the current one.

    Returns:
        A new ChoiceMap resulting from the merge operation.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        chm1 = ChoiceMap.value(5).extend("x")
        chm2 = ChoiceMap.value(10).extend("y")
        merged_chm = chm1.merge(chm2)
        assert merged_chm["x"] == 5
        assert merged_chm["y"] == 10
        ```

    Note:
        This method is equivalent to using the ^ operator between two ChoiceMaps.
    """
    return self ^ other
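The XOR merge semantics described for this method are lazy: building the merge never fails, and only reading an address that both operands define is an error. Below is a toy model over flat dicts using a hypothetical `XorMerged` class. It raises `KeyError` on conflict; GenJAX's actual error type is not modeled here.

```python
# Toy model (not GenJAX code) of XOR-merge semantics: construction always
# succeeds, but reading an address defined on both sides raises an error.
from typing import Any


class XorMerged:
    def __init__(self, left: dict[str, Any], right: dict[str, Any]):
        self.left = left
        self.right = right

    def __getitem__(self, addr: str) -> Any:
        in_left, in_right = addr in self.left, addr in self.right
        if in_left and in_right:
            raise KeyError(f"address {addr!r} is defined in both choice maps")
        if in_left:
            return self.left[addr]
        return self.right[addr]  # KeyError if neither side defines addr


merged = XorMerged({"x": 5}, {"y": 10})

overlap = XorMerged({"x": 1}, {"x": 2})  # construction still succeeds
try:
    overlap["x"]
    overlap_raised = False
except KeyError:
    overlap_raised = True
```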

simplify

simplify() -> ChoiceMap

Previously pushed down filters, now acts as identity.

Source code in src/genjax/_src/core/generative/choice_map.py
@nobeartype
@deprecated(
    reason="Acts as identity; filters are now automatically pushed down.",
    version="0.8.0",
)
def simplify(self) -> "ChoiceMap":
    """Previously pushed down filters, now acts as identity."""
    return self

static_is_empty

static_is_empty() -> bool

Returns True if this ChoiceMap is equal to ChoiceMap.empty(), False otherwise.

Source code in src/genjax/_src/core/generative/choice_map.py
def static_is_empty(self) -> bool:
    """
    Returns True if this ChoiceMap is equal to `ChoiceMap.empty()`, False otherwise.
    """
    return False

switch staticmethod

switch(
    idx: int | IntArray, chms: Iterable[ChoiceMap]
) -> ChoiceMap

Creates a ChoiceMap that switches between multiple ChoiceMaps based on an index.

This method creates a new ChoiceMap that selectively includes values from a sequence of input ChoiceMaps based on the provided index. The resulting ChoiceMap will contain values from the ChoiceMap at the position specified by the index, while masking out values from all other ChoiceMaps.

Parameters:

  • idx (int | IntArray, required): An index or array of indices specifying which ChoiceMap(s) to select from.
  • chms (Iterable[ChoiceMap], required): An iterable of ChoiceMaps to switch between.

Returns:

  • ChoiceMap: A new ChoiceMap containing values from the selected ChoiceMap(s).

Example
import jax.numpy as jnp
from genjax import ChoiceMap

chm1 = ChoiceMap.d({"x": 1, "y": 2})
chm2 = ChoiceMap.d({"x": 3, "y": 4})
chm3 = ChoiceMap.d({"x": 5, "y": 6})

switched = ChoiceMap.switch(jnp.array(1), [chm1, chm2, chm3])
assert switched["x"].unmask() == 3
assert switched["y"].unmask() == 4
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def switch(idx: int | IntArray, chms: Iterable["ChoiceMap"]) -> "ChoiceMap":
    """
    Creates a ChoiceMap that switches between multiple ChoiceMaps based on an index.

    This method creates a new ChoiceMap that selectively includes values from a sequence of
    input ChoiceMaps based on the provided index. The resulting ChoiceMap will contain
    values from the ChoiceMap at the position specified by the index, while masking out
    values from all other ChoiceMaps.

    Args:
        idx: An index or array of indices specifying which ChoiceMap(s) to select from.
        chms: An iterable of ChoiceMaps to switch between.

    Returns:
        A new ChoiceMap containing values from the selected ChoiceMap(s).

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        chm1 = ChoiceMap.d({"x": 1, "y": 2})
        chm2 = ChoiceMap.d({"x": 3, "y": 4})
        chm3 = ChoiceMap.d({"x": 5, "y": 6})

        switched = ChoiceMap.switch(jnp.array(1), [chm1, chm2, chm3])
        assert switched["x"].unmask() == 3
        assert switched["y"].unmask() == 4
        ```
    """
    return Switch.build(idx, chms)
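Here is a sketch of the switching behavior with a concrete Python index, assuming a flat `{address: value}` representation. `Masked` and `switch` are hypothetical stand-ins that loosely mirror the masked values implied by `switched["x"].unmask()`. The real method also accepts a traced index array, which this toy does not model.

```python
# Toy model (not GenJAX code) of ChoiceMap.switch with a concrete int index.
# Masked and switch are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Masked:
    value: Any
    flag: bool  # True only for entries coming from the selected map

    def unmask(self) -> Any:
        if not self.flag:
            raise ValueError("value is masked out")
        return self.value


def switch(idx: int, chms: list[dict[str, Any]]) -> dict[str, Masked]:
    out: dict[str, Masked] = {}
    for i, chm in enumerate(chms):
        for addr, v in chm.items():
            # Entries from the selected map always win; otherwise keep the
            # first masked-out entry seen for this address.
            if i == idx or addr not in out:
                out[addr] = Masked(v, i == idx)
    return out


chms = [{"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 5, "y": 6}]
switched = switch(1, chms)
```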

genjax.core.Selection

Bases: Pytree

A class representing a selection of addresses in a ChoiceMap.

Selection objects are used to filter and manipulate ChoiceMaps by specifying which addresses should be included or excluded.

Selection instances support union (via |), intersection (via &), and complement (via ~), which can be combined to build complex selection criteria.

Methods:

  • all(): Creates a Selection that includes all addresses.
  • none(): Creates a Selection that includes no addresses.
  • at: A builder instance for creating Selection objects using indexing syntax.

Examples:

Creating selections:

import genjax
from genjax import Selection

# Select all addresses
all_sel = Selection.all()

# Select no addresses
none_sel = Selection.none()

# Select specific addresses
specific_sel = Selection.at["x", "y"]

# Match (<wildcard>, "y")
wildcard_sel = Selection.at[..., "y"]

# Combine selections
combined_sel = specific_sel | Selection.at["z"]

Querying selections:

from genjax import Selection

# Create a selection
sel = Selection.at["x", "y"]

# Querying the selection using () returns a sub-selection
assert sel("x") == Selection.at["y"]
assert sel("z") == Selection.none()

# Querying the selection using [] returns a `bool` representing whether or not the input matches:
assert sel["x"] == False
assert sel["x", "y"] == True

# Querying the selection using "in" acts the same:
assert not "x" in sel
assert ("x", "y") in sel

# Nested querying
nested_sel = Selection.at["a", "b", "c"]
assert nested_sel("a")("b") == Selection.at["c"]

Selection objects can be passed to a ChoiceMap via the filter method to filter and manipulate data based on address patterns.

Attributes:

  • at (Final[_SelectionBuilder]): A builder instance for creating Selection objects.
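To see how the combinators compose, the toy below treats a selection as a predicate over full address tuples. `ToySel` and `at` are hypothetical. `at("x", "y")` loosely mimics `Selection.at["x", "y"]` by matching any address that begins with the given components. That prefix behavior is an assumption of this sketch.

```python
# Toy model (not GenJAX code): a selection as a predicate on address tuples,
# showing how |, & and ~ combine selections.
from typing import Any, Callable

Address = tuple


class ToySel:
    def __init__(self, pred: Callable[[Address], bool]):
        self.pred = pred

    def __contains__(self, addr: Any) -> bool:
        addr = addr if isinstance(addr, tuple) else (addr,)
        return self.pred(addr)

    def __or__(self, other: "ToySel") -> "ToySel":  # union
        return ToySel(lambda a: self.pred(a) or other.pred(a))

    def __and__(self, other: "ToySel") -> "ToySel":  # intersection
        return ToySel(lambda a: self.pred(a) and other.pred(a))

    def __invert__(self) -> "ToySel":  # complement
        return ToySel(lambda a: not self.pred(a))


def at(*prefix: Any) -> ToySel:
    # Hypothetical helper: matches addresses that begin with `prefix`.
    return ToySel(lambda a: a[: len(prefix)] == prefix)


xy_or_z = at("x", "y") | at("z")
both = at("x") & at("x", "y")
not_z = ~at("z")
```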

Source code in src/genjax/_src/core/generative/choice_map.py
class Selection(Pytree):
    """
    A class representing a selection of addresses in a ChoiceMap.

    Selection objects are used to filter and manipulate ChoiceMaps by specifying which addresses should be included or excluded.

    Selection instances support union (via `|`), intersection (via `&`), and complement (via `~`), which can be combined to build complex selection criteria.

    Methods:
        all(): Creates a Selection that includes all addresses.
        none(): Creates a Selection that includes no addresses.
        at: A builder instance for creating Selection objects using indexing syntax.

    Examples:
        Creating selections:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        import genjax
        from genjax import Selection

        # Select all addresses
        all_sel = Selection.all()

        # Select no addresses
        none_sel = Selection.none()

        # Select specific addresses
        specific_sel = Selection.at["x", "y"]

        # Match (<wildcard>, "y")
        wildcard_sel = Selection.at[..., "y"]

        # Combine selections
        combined_sel = specific_sel | Selection.at["z"]
        ```

        Querying selections:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        # Create a selection
        sel = Selection.at["x", "y"]

        # Querying the selection using () returns a sub-selection
        assert sel("x") == Selection.at["y"]
        assert sel("z") == Selection.none()

        # Querying the selection using [] returns a `bool` representing whether or not the input matches:
        assert sel["x"] == False
        assert sel["x", "y"] == True

        # Querying the selection using "in" acts the same:
        assert not "x" in sel
        assert ("x", "y") in sel

        # Nested querying
        nested_sel = Selection.at["a", "b", "c"]
        assert nested_sel("a")("b") == Selection.at["c"]
        ```

    Selection objects can be passed to a `ChoiceMap` via the `filter` method to filter and manipulate data based on address patterns.
    """

    #################################################
    # Convenient syntax for constructing selections #
    #################################################

    at: Final[_SelectionBuilder] = _SelectionBuilder()
    """A builder instance for creating Selection objects.

    `at` provides a convenient interface for constructing Selection objects
    using a familiar indexing syntax. It allows for the creation of complex
    selections by chaining multiple address components.

    Examples:
        Creating a selection:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import Selection
        Selection.at["x", "y"]
        ```
    """

    @staticmethod
    def all() -> "Selection":
        """
        Returns a Selection that selects all addresses.

        Returns:
            A Selection that selects everything.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            from genjax import Selection

            all_selection = Selection.all()
            assert all_selection["any_address"] == True
            ```
        """
        return AllSel()

    @staticmethod
    def none() -> "Selection":
        """
        Returns a Selection that selects no addresses.

        Returns:
            A Selection that selects nothing.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            none_selection = Selection.none()
            assert none_selection["any_address"] == False
            ```
        """
        return NoneSel()

    @staticmethod
    def leaf() -> "Selection":
        """
        Returns a Selection that selects only leaf addresses.

        A leaf address is an address that doesn't have any sub-addresses.
        This selection is useful when you want to target only the final elements in a nested structure.

        Returns:
            A Selection that selects only leaf addresses.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            leaf_selection = Selection.leaf().extend("a", "b")
            assert leaf_selection["a", "b"]
            assert not leaf_selection["a", "b", "anything"]
            ```
        """
        return LeafSel()

    ######################
    # Combinator methods #
    ######################

    def __or__(self, other: "Selection") -> "Selection":
        return OrSel.build(self, other)

    def __and__(self, other: "Selection") -> "Selection":
        return AndSel.build(self, other)

    def __invert__(self) -> "Selection":
        return ComplementSel.build(self)

    def complement(self) -> "Selection":
        return ~self

    def filter(self, sample: "ChoiceMap") -> "ChoiceMap":
        """
        Returns a new ChoiceMap filtered with this Selection.

        This method applies the current Selection to the given ChoiceMap, effectively filtering out addresses that are not matched.

        Args:
            sample: The ChoiceMap to be filtered.

        Returns:
            A new ChoiceMap containing only the addresses selected by this Selection.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            selection = Selection.at["x"]

            chm = ChoiceMap.kw(x=1, y=2)
            filtered_chm = selection.filter(chm)

            assert "x" in filtered_chm
            assert "y" not in filtered_chm
            ```
        """
        return sample.filter(self)

    def extend(self, *addrs: ExtendedStaticAddressComponent) -> "Selection":
        """
        Returns a new Selection that is prefixed by the given address components.

        This method creates a new Selection that applies the current selection
        to the specified address components. It handles both static and dynamic
        address components.

        Note that `...` as an address component will match any supplied address.

        Args:
            addrs: The address components under which to nest the selection.

        Returns:
            A new Selection extended by the given address component.

        Example:
            ```python exec="yes" html="true" source="material-block" session="choicemap"
            base_selection = Selection.all()
            indexed_selection = base_selection.extend("x")
            assert indexed_selection["x", "any_subaddress"] == True
            assert indexed_selection["y"] == False
            ```
        """
        acc = self
        for addr in reversed(addrs):
            acc = StaticSel.build(acc, addr)
        return acc

    def __call__(
        self,
        addr: StaticAddress,
    ) -> "Selection":
        addr = addr if isinstance(addr, tuple) else (addr,)
        subselection = self
        for comp in addr:
            subselection = subselection.get_subselection(comp)
        return subselection

    def __getitem__(
        self,
        addr: StaticAddress,
    ) -> bool:
        return self(addr).check()

    def __contains__(
        self,
        addr: StaticAddress,
    ) -> bool:
        return self[addr]

    @abstractmethod
    def check(self) -> bool:
        pass

    @abstractmethod
    def get_subselection(self, addr: StaticAddressComponent) -> "Selection":
        pass
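The query protocol above can be re-created in a few small classes: `sel(addr)` walks `get_subselection` once per component, and `sel[addr]` calls `check()` on what remains. The class names echo those in the source (`AllSel`, `NoneSel`, `LeafSel`, `StaticSel`). However, this is a simplified sketch whose behavior is inferred from the documented examples, not GenJAX's code.

```python
# Simplified re-creation (not GenJAX code) of the Selection query protocol.
from typing import Any


class Sel:
    def __call__(self, addr: Any) -> "Sel":
        addr = addr if isinstance(addr, tuple) else (addr,)
        sub = self
        for comp in addr:  # one get_subselection step per component
            sub = sub.get_subselection(comp)
        return sub

    def __getitem__(self, addr: Any) -> bool:
        return self(addr).check()

    def extend(self, *addrs: Any) -> "Sel":
        acc = self
        for addr in reversed(addrs):
            acc = StaticSel(addr, acc)
        return acc

    def check(self) -> bool:
        raise NotImplementedError

    def get_subselection(self, comp: Any) -> "Sel":
        raise NotImplementedError


class AllSel(Sel):
    def check(self) -> bool:
        return True

    def get_subselection(self, comp: Any) -> Sel:
        return self  # everything below is selected too


class NoneSel(Sel):
    def check(self) -> bool:
        return False

    def get_subselection(self, comp: Any) -> Sel:
        return self


class LeafSel(Sel):
    def check(self) -> bool:
        return True

    def get_subselection(self, comp: Any) -> Sel:
        return NoneSel()  # nothing below a leaf is selected


class StaticSel(Sel):
    def __init__(self, addr: Any, inner: Sel):
        self.addr, self.inner = addr, inner

    def check(self) -> bool:
        return False  # the address path has not been fully consumed

    def get_subselection(self, comp: Any) -> Sel:
        # `...` acts as a wildcard that matches any component.
        if self.addr is Ellipsis or self.addr == comp:
            return self.inner
        return NoneSel()


sel = AllSel().extend("x", "y")
leaf = LeafSel().extend("a", "b")
wild = AllSel().extend(..., "y")
```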

at class-attribute instance-attribute

at: Final[_SelectionBuilder] = _SelectionBuilder()

A builder instance for creating Selection objects.

at provides a convenient interface for constructing Selection objects using a familiar indexing syntax. It allows for the creation of complex selections by chaining multiple address components.

Examples:

Creating a selection:

from genjax import Selection
Selection.at["x", "y"]

all staticmethod

all() -> Selection

Returns a Selection that selects all addresses.

Returns:

  • Selection: A Selection that selects everything.

Example
from genjax import Selection

all_selection = Selection.all()
assert all_selection["any_address"] == True
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def all() -> "Selection":
    """
    Returns a Selection that selects all addresses.

    Returns:
        A Selection that selects everything.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        from genjax import Selection

        all_selection = Selection.all()
        assert all_selection["any_address"] == True
        ```
    """
    return AllSel()

extend

extend(*addrs: ExtendedStaticAddressComponent) -> Selection

Returns a new Selection that is prefixed by the given address components.

This method creates a new Selection that applies the current selection to the specified address components. It handles both static and dynamic address components.

Note that ... as an address component will match any supplied address.

Parameters:

  • addrs (ExtendedStaticAddressComponent, default ()): The address components under which to nest the selection.

Returns:

  • Selection: A new Selection extended by the given address components.

Example
from genjax import Selection

base_selection = Selection.all()
indexed_selection = base_selection.extend("x")
assert indexed_selection["x", "any_subaddress"] == True
assert indexed_selection["y"] == False
Source code in src/genjax/_src/core/generative/choice_map.py
def extend(self, *addrs: ExtendedStaticAddressComponent) -> "Selection":
    """
    Returns a new Selection that is prefixed by the given address components.

    This method creates a new Selection that applies the current selection
    to the specified address components. It handles both static and dynamic
    address components.

    Note that `...` as an address component will match any supplied address.

    Args:
        addrs: The address components under which to nest the selection.

    Returns:
        A new Selection extended by the given address components.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        base_selection = Selection.all()
        indexed_selection = base_selection.extend("x")
        assert indexed_selection["x", "any_subaddress"] == True
        assert indexed_selection["y"] == False
        ```
    """
    acc = self
    for addr in reversed(addrs):
        acc = StaticSel.build(acc, addr)
    return acc
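To make the nesting order of extend concrete, here is a minimal pure-Python sketch that models a selection as a predicate over tuples of address components. The helpers all_sel, static_sel, and extend are hypothetical stand-ins, not GenJAX's AllSel or StaticSel. The sketch assumes that a static selection consumes one leading component (with ... matching any component) and defers the remaining components to its inner selection. Iterating over reversed(addrs) wraps the innermost component first, so the first argument ends up outermost.

```python
from typing import Callable, Tuple

# Toy model: a selection is a predicate over a tuple of address components.
Sel = Callable[[Tuple], bool]


def all_sel() -> Sel:
    # Selects every (remaining) address.
    return lambda addr: True


def static_sel(inner: Sel, comp) -> Sel:
    # Match the first component against `comp` (Ellipsis matches anything),
    # then check the remaining components against `inner`.
    def pred(addr: Tuple) -> bool:
        if not addr:
            return False
        head, *tail = addr
        if comp is not Ellipsis and head != comp:
            return False
        return inner(tuple(tail))

    return pred


def extend(sel: Sel, *addrs) -> Sel:
    # Same loop shape as Selection.extend: wrap from the innermost component out.
    acc = sel
    for addr in reversed(addrs):
        acc = static_sel(acc, addr)
    return acc


nested = extend(all_sel(), "x", "y")
print(nested(("x", "y", "z")))  # True
print(nested(("y", "x")))       # False

wild = extend(all_sel(), ..., "y")
print(wild(("anything", "y")))  # True
```

With a single component, the toy reproduces the documented example: extend(all_sel(), "x") accepts ("x", "any_subaddress") and rejects ("y",).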

filter

filter(sample: ChoiceMap) -> ChoiceMap

Returns a new ChoiceMap filtered with this Selection.

This method applies the current Selection to the given ChoiceMap, effectively filtering out addresses that are not matched.

Parameters:

sample (ChoiceMap): The ChoiceMap to be filtered. Required.

Returns:

ChoiceMap: A new ChoiceMap containing only the addresses selected by this Selection.

Example
selection = Selection.at["x"]

chm = ChoiceMap.kw(x=1, y=2)
filtered_chm = selection.filter(chm)

assert "x" in filtered_chm
assert "y" not in filtered_chm
Source code in src/genjax/_src/core/generative/choice_map.py
def filter(self, sample: "ChoiceMap") -> "ChoiceMap":
    """
    Returns a new ChoiceMap filtered with this Selection.

    This method applies the current Selection to the given ChoiceMap, effectively filtering out addresses that are not matched.

    Args:
        sample: The ChoiceMap to be filtered.

    Returns:
        A new ChoiceMap containing only the addresses selected by this Selection.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        selection = Selection.at["x"]

        chm = ChoiceMap.kw(x=1, y=2)
        filtered_chm = selection.filter(chm)

        assert "x" in filtered_chm
        assert "y" not in filtered_chm
        ```
    """
    return sample.filter(self)
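For intuition about what filtering does, here is a minimal sketch that models a choice map as a nested Python dict and a selection as a predicate over full address tuples. Both are hypothetical stand-ins for GenJAX's ChoiceMap and Selection, and filter_choices is an illustrative helper, not library code. Filtering walks the leaves and keeps only those whose full address the selection matches. In this toy, sub-maps that end up empty are dropped.

```python
from typing import Callable, Dict, Tuple


def filter_choices(chm: Dict, sel: Callable[[Tuple], bool], prefix: Tuple = ()) -> Dict:
    # Recursively keep only the leaves whose full address is selected.
    out = {}
    for key, val in chm.items():
        addr = prefix + (key,)
        if isinstance(val, dict):
            sub = filter_choices(val, sel, addr)
            if sub:  # drop sub-maps with no selected leaves
                out[key] = sub
        elif sel(addr):
            out[key] = val
    return out


# Toy analogue of Selection.at["x"]: assumed to select "x" and everything beneath it.
select_x = lambda addr: len(addr) >= 1 and addr[0] == "x"

chm = {"x": 1, "y": 2, "z": {"x": 3}}
print(filter_choices(chm, select_x))  # {'x': 1}
```

Note that the leaf at ("z", "x") is not kept: the selection is matched against the full address, not just the last component.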

leaf staticmethod

leaf() -> Selection

Returns a Selection that selects only leaf addresses.

A leaf address is an address that doesn't have any sub-addresses. This selection is useful when you want to target only the final elements in a nested structure.

Returns:

Selection: A Selection that selects only leaf addresses.

Example
leaf_selection = Selection.leaf().extend("a", "b")
assert leaf_selection["a", "b"]
assert not leaf_selection["a", "b", "anything"]
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def leaf() -> "Selection":
    """
    Returns a Selection that selects only leaf addresses.

    A leaf address is an address that doesn't have any sub-addresses.
    This selection is useful when you want to target only the final elements in a nested structure.

    Returns:
        A Selection that selects only leaf addresses.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        leaf_selection = Selection.leaf().extend("a", "b")
        assert leaf_selection["a", "b"]
        assert not leaf_selection["a", "b", "anything"]
        ```
    """
    return LeafSel()
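The difference between a leaf selection and an all selection can be sketched with the same kind of toy predicate model. The names leaf_sel, all_sel, and extend are hypothetical, not GenJAX's LeafSel. A leaf selection here matches only when no address components remain, so after extending by "a", "b" it accepts exactly ("a", "b"), while the all selection also accepts anything beneath it. This simplified extend compares a fixed prefix and does not handle ... wildcards.

```python
from typing import Callable, Tuple

Sel = Callable[[Tuple], bool]


def leaf_sel() -> Sel:
    # Matches only when nothing remains, i.e. the address stops exactly here.
    return lambda addr: len(addr) == 0


def all_sel() -> Sel:
    # Matches any remaining address, including deeper sub-addresses.
    return lambda addr: True


def extend(inner: Sel, *comps) -> Sel:
    # Require the given prefix, then defer the remaining components to `inner`.
    def pred(addr: Tuple) -> bool:
        n = len(comps)
        return tuple(addr[:n]) == tuple(comps) and inner(tuple(addr[n:]))

    return pred


leaf_ab = extend(leaf_sel(), "a", "b")
all_ab = extend(all_sel(), "a", "b")
print(leaf_ab(("a", "b")), leaf_ab(("a", "b", "anything")))  # True False
print(all_ab(("a", "b")), all_ab(("a", "b", "anything")))    # True True
```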

none staticmethod

none() -> Selection

Returns a Selection that selects no addresses.

Returns:

Selection: A Selection that selects nothing.

Example
none_selection = Selection.none()
assert none_selection["any_address"] == False
Source code in src/genjax/_src/core/generative/choice_map.py
@staticmethod
def none() -> "Selection":
    """
    Returns a Selection that selects no addresses.

    Returns:
        A Selection that selects nothing.

    Example:
        ```python exec="yes" html="true" source="material-block" session="choicemap"
        none_selection = Selection.none()
        assert none_selection["any_address"] == False
        ```
    """
    return NoneSel()

JAX-compatible data via Pytree

JAX natively works with arrays, and with instances of Python classes which can be broken down into lists of arrays. JAX's Pytree system provides a way to register a class with methods that can break instances of the class down into a list of arrays (canonically referred to as flattening), and build an instance back up given a list of arrays (canonically referred to as unflattening).

GenJAX provides an abstract class called Pytree which automates the implementation of the flatten / unflatten methods for a class. GenJAX's Pytree inherits from penzai.Struct to support pretty printing, and provides convenience methods to annotate which data is static (part of the Pytree type itself, never broken down into a JAX array) and which data is dynamic.
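For intuition about what an automated flatten / unflatten pair does, here is a minimal pure-Python sketch for a standard-library dataclass. It reuses the metadata={"pytree_node": False} marker that appears in the source of Pytree.static further down this page. The helpers flatten and unflatten and the Point class are hypothetical and far simpler than JAX's jax.tree_util machinery: there is no nesting and no registration. Dynamic fields become the list of leaves. Static fields travel with the class in a "treedef"-like tuple.

```python
from dataclasses import dataclass, field, fields
from typing import Any, List, Tuple


@dataclass(frozen=True)
class Point:
    x: float  # dynamic leaf
    y: float  # dynamic leaf
    name: str = field(default="p", metadata={"pytree_node": False})  # static


def flatten(obj) -> Tuple[List[Any], Tuple]:
    # Split fields into dynamic leaves and static structure.
    leaves, static = [], []
    for f in fields(obj):
        val = getattr(obj, f.name)
        if f.metadata.get("pytree_node", True):
            leaves.append(val)
        else:
            static.append((f.name, val))
    return leaves, (type(obj), tuple(static))


def unflatten(treedef: Tuple, leaves: List[Any]):
    # Rebuild an instance from the static structure plus a fresh list of leaves.
    cls, static = treedef
    dyn_names = [f.name for f in fields(cls) if f.metadata.get("pytree_node", True)]
    kwargs = dict(zip(dyn_names, leaves))
    kwargs.update(static)
    return cls(**kwargs)


p = Point(1.0, 2.0, name="origin")
leaves, treedef = flatten(p)
print(leaves)                          # [1.0, 2.0]
print(unflatten(treedef, [3.0, 4.0]))  # Point(x=3.0, y=4.0, name='origin')
```

Two instances that differ only in a static field get different treedefs, which is why static values behave like part of the type rather than like data.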

genjax.core.Pytree

Bases: Struct

Pytree is an abstract base class which registers a class with JAX's Pytree system. JAX's Pytree system tracks how data classes should behave across JAX-transformed function boundaries, like jax.jit or jax.vmap.

Inheriting this class provides the implementor with the freedom to declare how the subfields of a class should behave:

  • Pytree.static(...): the value of the field cannot be a JAX traced value; it must be a Python literal or a constant. The values of static fields are embedded in the PyTreeDef of any instance of the class.
  • Pytree.field(...) or no annotation: the value may be a JAX traced value, and JAX will attempt to convert it to tracer values inside of its transformations.

If a field points to another Pytree, it should not be declared as Pytree.static(), as the Pytree interface will automatically handle the Pytree fields as dynamic fields.
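The advice about nested Pytrees can be illustrated with a pure-Python sketch. Here leaves_of is a hypothetical helper, not part of GenJAX or JAX. It collects dynamic leaves using the same metadata={"pytree_node": False} marker that Pytree.static sets in the source below, and it recurses into dynamic dataclass fields. Marking the nested field static hides its leaves inside the static structure. Assuming JAX behaves analogously, a transformation such as jax.jit would then not see those values as traced data.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, List


def leaves_of(obj) -> List[Any]:
    # Collect dynamic leaves, recursing into dynamic dataclass fields.
    if not is_dataclass(obj):
        return [obj]
    out = []
    for f in fields(obj):
        if f.metadata.get("pytree_node", True):
            out.extend(leaves_of(getattr(obj, f.name)))
    return out


@dataclass(frozen=True)
class Inner:
    w: float


@dataclass(frozen=True)
class GoodOuter:
    inner: Inner  # dynamic: Inner's leaves stay visible


@dataclass(frozen=True)
class BadOuter:
    inner: Inner = field(metadata={"pytree_node": False})  # static: leaves hidden


print(leaves_of(GoodOuter(Inner(1.5))))  # [1.5]
print(leaves_of(BadOuter(Inner(1.5))))   # []
```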

Methods:

dataclass: Denote that a class (which is inheriting Pytree) should be treated as a dataclass, meaning it can hold data in fields which are declared as part of the class.

static: Declare a field of a Pytree dataclass to be static. Users can provide additional keyword argument options, like default or default_factory.

field: Declare a field of a Pytree dataclass to be dynamic. Alternatively, one can leave the annotation off in the declaration.

Source code in src/genjax/_src/core/pytree.py
class Pytree(pz.Struct):
    """`Pytree` is an abstract base class which registers a class with JAX's `Pytree`
    system. JAX's `Pytree` system tracks how data classes should behave across JAX-transformed function boundaries, like `jax.jit` or `jax.vmap`.

    Inheriting this class provides the implementor with the freedom to declare how the subfields of a class should behave:

    * `Pytree.static(...)`: the value of the field cannot be a JAX traced value; it must be a Python literal or a constant. The values of static fields are embedded in the `PyTreeDef` of any instance of the class.
    * `Pytree.field(...)` or no annotation: the value may be a JAX traced value, and JAX will attempt to convert it to tracer values inside of its transformations.

    If a field _points to another `Pytree`_, it should not be declared as `Pytree.static()`, as the `Pytree` interface will automatically handle the `Pytree` fields as dynamic fields.

    """

    @staticmethod
    @overload
    def dataclass(
        incoming: None = None,
        /,
        **kwargs,
    ) -> Callable[[type[R]], type[R]]: ...

    @staticmethod
    @overload
    def dataclass(
        incoming: type[R],
        /,
        **kwargs,
    ) -> type[R]: ...

    @dataclass_transform(
        frozen_default=True,
    )
    @staticmethod
    def dataclass(
        incoming: type[R] | None = None,
        /,
        **kwargs,
    ) -> type[R] | Callable[[type[R]], type[R]]:
        """
        Denote that a class (which is inheriting `Pytree`) should be treated as a dataclass, meaning it can hold data in fields which are declared as part of the class.

        A dataclass is to be distinguished from a "methods only" `Pytree` class, which does not have fields, but may define methods.
        The latter cannot be instantiated, but can be inherited from, while the former can be instantiated:
        the `Pytree.dataclass` declaration informs the system _how to instantiate_ the class as a dataclass,
        and how to automatically define JAX's `Pytree` interfaces (`tree_flatten`, `tree_unflatten`, etc.) for the dataclass, based on the fields declared in the class, and possibly `Pytree.static(...)` or `Pytree.field(...)` annotations (or lack thereof, the default is that all fields are `Pytree.field(...)`).

        All `Pytree` dataclasses support pretty printing, as well as rendering to HTML.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="core"
            from genjax import Pytree
            from genjax.typing import FloatArray
            import jax.numpy as jnp


            @Pytree.dataclass
            # Enforces type annotations on instantiation.
            class MyClass(Pytree):
                my_static_field: int = Pytree.static()
                my_dynamic_field: FloatArray


            print(MyClass(10, jnp.array(5.0)).render_html())
            ```
        """

        return pz.pytree_dataclass(
            incoming,
            overwrite_parent_init=True,
            **kwargs,
        )

    @staticmethod
    def static(**kwargs):
        """Declare a field of a `Pytree` dataclass to be static. Users can provide additional keyword argument options,
        like `default` or `default_factory`, to customize how the field is instantiated when an instance of
        the dataclass is instantiated. Fields which are provided with default values must come after required fields in the dataclass declaration.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="core"
            @Pytree.dataclass
            # Enforces type annotations on instantiation.
            class MyClass(Pytree):
                my_dynamic_field: FloatArray
                my_static_field: int = Pytree.static(default=0)


            print(MyClass(jnp.array(5.0)).render_html())
            ```

        """
        return field(metadata={"pytree_node": False}, **kwargs)

    @staticmethod
    def field(**kwargs):
        "Declare a field of a `Pytree` dataclass to be dynamic. Alternatively, one can leave the annotation off in the declaration."
        return field(**kwargs)

    ##############################
    # Utility class constructors #
    ##############################

    @staticmethod
    def const(v):
        # The value must be concrete!
        # It cannot be a JAX traced value.
        assert static_check_is_concrete(v)
        if isinstance(v, Const):
            return v
        else:
            return Const(v)

    # Safe: will not wrap a Const in another Const, and will not
    # wrap dynamic values.
    @staticmethod
    def tree_const(v):
        def _inner(v):
            if isinstance(v, Const):
                return v
            elif static_check_is_concrete(v):
                return Const(v)
            else:
                return v

        return jtu.tree_map(
            _inner,
            v,
            is_leaf=lambda v: isinstance(v, Const),
        )

    @staticmethod
    def tree_const_unwrap(v):
        def _inner(v):
            if isinstance(v, Const):
                return v.val
            else:
                return v

        return jtu.tree_map(
            _inner,
            v,
            is_leaf=lambda v: isinstance(v, Const),
        )

    @staticmethod
    def partial(*args) -> Callable[[Callable[..., R]], "Closure[R]"]:
        return lambda fn: Closure[R](args, fn)

    def treedef(self):
        return jtu.tree_structure(self)

    #################
    # Static checks #
    #################

    @staticmethod
    def static_check_tree_structure_equivalence(trees: list[Any]):
        if not trees:
            return True
        else:
            fst, *rest = trees
            treedef = jtu.tree_structure(fst)
            check = all(map(lambda v: treedef == jtu.tree_structure(v), rest))
            return check

    def treescope_color(self) -> str:
        """Computes a CSS color to display for this object in treescope.

        This function can be overridden to change the color for a particular object
        in treescope, without having to register a new handler.

        (note that we are overriding the Penzai base class's implementation so that ALL structs receive colors, not just classes with `__call__` implemented.)

        Returns:
          A CSS color string to use as a background/highlight color for this object.
          Alternatively, a tuple of (border, fill) CSS colors.
        """
        type_string = type(self).__module__ + "." + type(self).__qualname__
        return formatting_util.color_from_string(type_string)

    def render_html(self):
        return treescope.render_to_html(
            self,
            roundtrip_mode=False,
        )

dataclass staticmethod

dataclass(
    incoming: None = None, /, **kwargs
) -> Callable[[type[R]], type[R]]
dataclass(incoming: type[R], /, **kwargs) -> type[R]
dataclass(
    incoming: type[R] | None = None, /, **kwargs
) -> type[R] | Callable[[type[R]], type[R]]

Denote that a class (which is inheriting Pytree) should be treated as a dataclass, meaning it can hold data in fields which are declared as part of the class.

A dataclass is distinct from a "methods only" Pytree class, which has no fields but may define methods: the latter cannot be instantiated but can be inherited from, while the former can be instantiated. The Pytree.dataclass declaration tells the system how to instantiate the class as a dataclass, and how to automatically define JAX's Pytree interfaces (tree_flatten, tree_unflatten, etc.) for it, based on the fields declared in the class and their Pytree.static(...) or Pytree.field(...) annotations (fields without an annotation default to Pytree.field(...)).

All Pytree dataclasses support pretty printing, as well as rendering to HTML.

Examples:

from genjax import Pytree
from genjax.typing import FloatArray
import jax.numpy as jnp


@Pytree.dataclass
# Enforces type annotations on instantiation.
class MyClass(Pytree):
    my_static_field: int = Pytree.static()
    my_dynamic_field: FloatArray


print(MyClass(10, jnp.array(5.0)).render_html())
Source code in src/genjax/_src/core/pytree.py
@dataclass_transform(
    frozen_default=True,
)
@staticmethod
def dataclass(
    incoming: type[R] | None = None,
    /,
    **kwargs,
) -> type[R] | Callable[[type[R]], type[R]]:
    """
    Denote that a class (which is inheriting `Pytree`) should be treated as a dataclass, meaning it can hold data in fields which are declared as part of the class.

    A dataclass is to be distinguished from a "methods only" `Pytree` class, which does not have fields, but may define methods.
    The latter cannot be instantiated, but can be inherited from, while the former can be instantiated:
    the `Pytree.dataclass` declaration informs the system _how to instantiate_ the class as a dataclass,
    and how to automatically define JAX's `Pytree` interfaces (`tree_flatten`, `tree_unflatten`, etc.) for the dataclass, based on the fields declared in the class, and possibly `Pytree.static(...)` or `Pytree.field(...)` annotations (or lack thereof, the default is that all fields are `Pytree.field(...)`).

    All `Pytree` dataclasses support pretty printing, as well as rendering to HTML.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import Pytree
        from genjax.typing import FloatArray
        import jax.numpy as jnp


        @Pytree.dataclass
        # Enforces type annotations on instantiation.
        class MyClass(Pytree):
            my_static_field: int = Pytree.static()
            my_dynamic_field: FloatArray


        print(MyClass(10, jnp.array(5.0)).render_html())
        ```
    """

    return pz.pytree_dataclass(
        incoming,
        overwrite_parent_init=True,
        **kwargs,
    )

static staticmethod

static(**kwargs)

Declare a field of a Pytree dataclass to be static. Users can provide additional keyword argument options, like default or default_factory, to customize how the field is initialized when the dataclass is instantiated. Fields which are provided with default values must come after required fields in the dataclass declaration.

Examples:

@Pytree.dataclass
# Enforces type annotations on instantiation.
class MyClass(Pytree):
    my_dynamic_field: FloatArray
    my_static_field: int = Pytree.static(default=0)


print(MyClass(jnp.array(5.0)).render_html())
Source code in src/genjax/_src/core/pytree.py
@staticmethod
def static(**kwargs):
    """Declare a field of a `Pytree` dataclass to be static. Users can provide additional keyword argument options,
    like `default` or `default_factory`, to customize how the field is instantiated when an instance of
    the dataclass is instantiated. Fields which are provided with default values must come after required fields in the dataclass declaration.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="core"
        @Pytree.dataclass
        # Enforces type annotations on instantiation.
        class MyClass(Pytree):
            my_dynamic_field: FloatArray
            my_static_field: int = Pytree.static(default=0)


        print(MyClass(jnp.array(5.0)).render_html())
        ```

    """
    return field(metadata={"pytree_node": False}, **kwargs)
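The ordering rule for defaults is the same one the standard library's dataclasses module enforces. This stdlib-only sketch shows the error a misordered declaration produces and the working order. It assumes that Pytree.dataclass inherits this rule: the source above shows it delegates to penzai's pytree_dataclass, which we believe builds on dataclasses.dataclass. The class names Bad and Good are illustrative only.

```python
import dataclasses

raised = False
try:
    @dataclasses.dataclass
    class Bad:
        my_static_field: int = 0  # has a default...
        my_dynamic_field: float   # ...but this required field comes after it
except TypeError as err:
    raised = True
    print("TypeError:", err)


@dataclasses.dataclass
class Good:
    my_dynamic_field: float   # required fields first
    my_static_field: int = 0  # defaulted fields last


print(Good(5.0))  # Good(my_dynamic_field=5.0, my_static_field=0)
```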

field staticmethod

field(**kwargs)

Declare a field of a Pytree dataclass to be dynamic. Alternatively, one can leave the annotation off in the declaration.

Source code in src/genjax/_src/core/pytree.py
@staticmethod
def field(**kwargs):
    "Declare a field of a `Pytree` dataclass to be dynamic. Alternatively, one can leave the annotation off in the declaration."
    return field(**kwargs)

genjax.core.Const

Bases: Generic[R], Pytree

JAX-compatible way to tag a value as a constant. Valid constants include Python literals, strings, and essentially anything that won't hold JAX arrays inside of a computation.

Examples:

Instances of Const can be created using the Pytree.const static method:

from genjax import Pytree

c = Pytree.const(5)
print(c.render_html())

Constants can be freely used across jax.jit boundaries:

import jax
from genjax import Pytree


def f(c):
    if c.unwrap() == 5:
        return 10.0
    else:
        return 5.0


c = Pytree.const(5)
r = jax.jit(f)(c)
print(r)
10.0
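The example can branch on c.unwrap() inside jax.jit because Const stores its value in a static field. That value lives in the PyTreeDef rather than among the traced leaves, so it is an ordinary Python value at tracing time. Assuming JAX caches compiled code on that static structure, each distinct constant gets its own specialization. The following pure-Python sketch mimics that caching. ToyConst and toy_jit are hypothetical stand-ins, and this is not how JAX is implemented.

```python
from functools import partial
from typing import Any, Callable, Dict


class ToyConst:
    """Hypothetical stand-in for genjax.core.Const (not the real class)."""

    def __init__(self, val: Any):
        self.val = val

    def unwrap(self) -> Any:
        return self.val


def toy_jit(fn: Callable[..., Any]) -> Callable[..., Any]:
    # One "specialization" per distinct static value, standing in for
    # JAX caching compiled code on the static structure of its inputs.
    cache: Dict[Any, Callable[..., Any]] = {}

    def wrapped(c: ToyConst, *dyn_args: Any) -> Any:
        key = c.unwrap()  # static: a plain Python value, usable as a cache key
        if key not in cache:
            cache[key] = partial(fn, c)
        return cache[key](*dyn_args)

    wrapped.cache = cache
    return wrapped


def f(c: ToyConst, x: float) -> float:
    # Branching on the static value is ordinary Python.
    return 10.0 * x if c.unwrap() == 5 else 5.0 * x


jf = toy_jit(f)
print(jf(ToyConst(5), 1.0), jf(ToyConst(5), 2.0), jf(ToyConst(3), 1.0))  # 10.0 20.0 5.0
print(len(jf.cache))  # 2
```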

Methods:

unwrap: Unwrap a constant value from a Const instance.

Source code in src/genjax/_src/core/pytree.py
@Pytree.dataclass
class Const(Generic[R], Pytree):
    """
    JAX-compatible way to tag a value as a constant. Valid constants include Python literals, strings, essentially anything **that won't hold JAX arrays** inside of a computation.

    Examples:
        Instances of `Const` can be created using a `Pytree` classmethod:
        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import Pytree

        c = Pytree.const(5)
        print(c.render_html())
        ```

        Constants can be freely used across [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) boundaries:
        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import Pytree


        def f(c):
            if c.unwrap() == 5:
                return 10.0
            else:
                return 5.0


        c = Pytree.const(5)
        r = jax.jit(f)(c)
        print(r)
        ```
    """

    val: R = Pytree.static()

    def __call__(self, *args):
        assert isinstance(self.val, Callable), (
            f"Wrapped `val` {self.val} is not Callable."
        )
        return self.val(*args)

    def unwrap(self: Any) -> R:
        """Unwrap a constant value from a `Const` instance.

        This method can be used as an instance method or as a static method. When used as a static method, it returns the input value unchanged if it is not a `Const` instance.

        Returns:
            R: The unwrapped value if self is a `Const`, otherwise returns self unchanged.

        Examples:
            ```python exec="yes" html="true" source="material-block" session="core"
            from genjax import Pytree, Const

            c = Pytree.const(5)
            val = c.unwrap()  # Returns 5

            # Can also be used as static method
            val = Const.unwrap(10)  # Returns 10 unchanged
            ```
        """
        if isinstance(self, Const):
            return self.val
        else:
            return self

unwrap

unwrap() -> R

Unwrap a constant value from a Const instance.

This method can be used as an instance method or as a static method. When used as a static method, it returns the input value unchanged if it is not a Const instance.

Returns:

R: The unwrapped value if self is a Const, otherwise self unchanged.

Examples:

from genjax import Pytree, Const

c = Pytree.const(5)
val = c.unwrap()  # Returns 5

# Can also be used as static method
val = Const.unwrap(10)  # Returns 10 unchanged
Source code in src/genjax/_src/core/pytree.py
def unwrap(self: Any) -> R:
    """Unwrap a constant value from a `Const` instance.

    This method can be used as an instance method or as a static method. When used as a static method, it returns the input value unchanged if it is not a `Const` instance.

    Returns:
        R: The unwrapped value if self is a `Const`, otherwise returns self unchanged.

    Examples:
        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import Pytree, Const

        c = Pytree.const(5)
        val = c.unwrap()  # Returns 5

        # Can also be used as static method
        val = Const.unwrap(10)  # Returns 10 unchanged
        ```
    """
    if isinstance(self, Const):
        return self.val
    else:
        return self

genjax.core.Closure

Bases: Generic[R], Pytree

JAX-compatible closure type. It's a closure as a Pytree - meaning the static source code / callable is separated from dynamic data (which must be tracked by JAX).

Examples:

Instances of Closure can be created using Pytree.partial -- note the order of the "closed over" arguments:

import jax
from genjax import Pytree


def g(y):
    @Pytree.partial(y)  # dynamic values come first
    def f(v, x):
        # v will be bound to the value of y
        return x * (v * 5.0)

    return f


clos = jax.jit(g)(5.0)
print(clos.render_html())

Closures can be invoked / JIT compiled in other code:

r = jax.jit(lambda x: clos(x))(3.0)
print(r)
75.0

Source code in src/genjax/_src/core/pytree.py
@Pytree.dataclass
class Closure(Generic[R], Pytree):
    """
    JAX-compatible closure type. It's a closure _as a [`Pytree`][genjax.core.Pytree]_ - meaning the static _source code_ / _callable_ is separated from dynamic data (which must be tracked by JAX).

    Examples:
        Instances of `Closure` can be created using `Pytree.partial` -- note the order of the "closed over" arguments:
        ```python exec="yes" html="true" source="material-block" session="core"
        from genjax import Pytree


        def g(y):
            @Pytree.partial(y)  # dynamic values come first
            def f(v, x):
                # v will be bound to the value of y
                return x * (v * 5.0)

            return f


        clos = jax.jit(g)(5.0)
        print(clos.render_html())
        ```

        Closures can be invoked / JIT compiled in other code:
        ```python exec="yes" html="true" source="material-block" session="core"
        r = jax.jit(lambda x: clos(x))(3.0)
        print(r)
        ```
    """

    dyn_args: tuple[Any, ...]
    fn: Callable[..., R] = Pytree.static()

    def __call__(self, *args, **kwargs) -> R:
        return self.fn(*self.dyn_args, *args, **kwargs)
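The same split can be sketched in plain Python. In this sketch, ToyClosure and toy_partial are hypothetical, standard-library-only stand-ins for Closure and Pytree.partial. The sketch keeps the callable as static data and the closed-over values as a tuple of dynamic arguments, and prepends them on every call. That is why the values passed to Pytree.partial bind to the leading parameters of the decorated function.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class ToyClosure:
    dyn_args: Tuple[Any, ...]  # dynamic data (what JAX would trace)
    fn: Callable[..., Any]     # static code (what would live in the treedef)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Closed-over values come first, then the call-site arguments.
        return self.fn(*self.dyn_args, *args, **kwargs)


def toy_partial(*args: Any) -> Callable[[Callable[..., Any]], ToyClosure]:
    # Decorator factory with the same shape as Pytree.partial.
    return lambda fn: ToyClosure(args, fn)


def g(y: float) -> ToyClosure:
    @toy_partial(y)
    def f(v: float, x: float) -> float:
        return x * (v * 5.0)

    return f


clos = g(5.0)
print(clos(3.0))      # 75.0
print(clos.dyn_args)  # (5.0,)
```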

Dynamism in JAX: masks and sum types

The semantics of Gen are defined independently of any particular computational substrate or implementation - but JAX (and XLA through JAX) is a unique substrate, offering high performance, the ability to transform code ahead-of-time via program transformations, and ... a rather unique set of restrictions.

JAX is a two-phase system

While not yet formally modelled, it's appropriate to think of JAX as separating computation into two phases:

  • The statics phase (which occurs at JAX tracing / transformation time).
  • The runtime phase (which occurs when a computation written in JAX is actually deployed via XLA and executed on a physical device somewhere in the world).

JAX has different rules for handling values depending on which phase we are in.

For instance, JAX disallows usage of runtime values to resolve Python control flow at tracing time (intuition: we don't actually know the value yet!) and will error if the user attempts to trace through a Python program with incorrect usage of runtime values.
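The tracing-time restriction can be mimicked in plain Python. In this hypothetical sketch, ToyTracer stands for a value whose concrete contents are unknown while tracing. Comparing it produces another symbolic value, and converting that to a bool raises. This is in the spirit of, but not identical to, the error JAX reports when Python control flow depends on a traced value. A select that records both branches works, because the choice is deferred to runtime. In real JAX code that role is played by primitives such as jax.numpy.where or jax.lax.cond.

```python
class ToyTracer:
    """Hypothetical placeholder for a runtime value seen at tracing time."""

    def __init__(self, name: str):
        self.name = name

    def __gt__(self, other):
        # Comparisons build a new symbolic value instead of a Python bool.
        return ToyTracer(f"({self.name} > {other!r})")

    def __bool__(self):
        raise TypeError(f"cannot use {self.name} in Python control flow while tracing")


def branchy(x):
    if x > 0:  # needs a concrete bool, which a tracer cannot provide
        return "pos"
    return "nonpos"


def select(pred, on_true, on_false):
    # Records both branches symbolically; the choice happens at runtime.
    return ("select", pred.name, on_true, on_false)


x = ToyTracer("x")
msg = ""
try:
    branchy(x)
except TypeError as err:
    msg = str(err)
    print(msg)  # cannot use (x > 0) in Python control flow while tracing
print(select(x > 0, "pos", "nonpos"))  # ('select', '(x > 0)', 'pos', 'nonpos')
```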

In GenJAX, we take advantage of JAX's tracing to construct code which, when traced, produces specialized code depending on static information. At the same time, we are careful to encode Gen's interfaces to respect JAX's rules which govern how static / runtime values can be used.

The most primitive way to encode runtime uncertainty about a piece of data is to attach a bool to it, which indicates whether the data is "on" or "off".

GenJAX contains a system for tagging data with flags, to indicate if the data is valid or invalid during inference interface computations at runtime. The key data structure which supports this system is genjax.core.Mask.
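A bool per datum is enough because a computation can fold the flag into its arithmetic instead of branching on it, so the same code path runs whether the data is on or off. The helper masked_total is hypothetical and uses plain Python lists in place of arrays. It is only meant to show the idea, not how Mask is consumed internally.

```python
from typing import List


def masked_total(values: List[float], flags: List[bool]) -> float:
    # Every entry goes through the same arithmetic; invalid entries
    # contribute 0.0 instead of being skipped by an `if`.
    # With arrays, a select such as jax.numpy.where is safer than
    # multiplying, since 0.0 * nan is still nan.
    return sum(v * float(f) for v, f in zip(values, flags))


print(masked_total([1.0, 2.0, 4.0], [True, False, True]))  # 5.0
```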

genjax.core.Mask

Bases: Generic[R], Pytree

The Mask datatype pairs a value with a Boolean flag which denotes whether the data is valid or invalid to use in inference computations.

Masks can be used in a variety of ways as part of generative computations; their primary role is to denote data which is valid under inference computations. Valid data can be used as ChoiceMap leaves, and can participate in generative and inference computations (such as scores, importance weights, or density ratios). A Mask with a False flag should be considered unusable, and should be handled with care.

If a flag has a non-scalar shape, that implies that the mask is vectorized, and that the ArrayLike value, or each leaf in the pytree, must have the flag's shape as its prefix (i.e., must have been created with a jax.vmap call or via a GenJAX vmap combinator).
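The prefix rule can be checked with shapes alone. This sketch restates the logic of Mask._validate_init, shown in the source below, over plain shape tuples. check_flag_prefix is a hypothetical helper that takes shapes rather than arrays and pytrees.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def check_flag_prefix(flag_shape: Shape, leaf_shapes: Sequence[Shape]) -> None:
    # A scalar flag applies to the whole value: nothing to check.
    if flag_shape == ():
        return
    n = len(flag_shape)
    for shape in leaf_shapes:
        if tuple(shape[:n]) != tuple(flag_shape):
            raise ValueError(f"flag shape {flag_shape} is not a prefix of leaf shape {shape}")


check_flag_prefix((3,), [(3,), (3, 4), (3, 2, 1)])  # OK: every leaf starts with 3

msg = ""
try:
    check_flag_prefix((3,), [(3, 4), (4, 3)])
except ValueError as err:
    msg = str(err)
    print(msg)  # flag shape (3,) is not a prefix of leaf shape (4, 3)
```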

Encountering Mask in your computation

When users see Mask in their computations, they are expected to interact with them by either:

  • Unmasking them using the Mask.unmask interface, a potentially unsafe operation.

  • Destructuring them manually, and handling the cases.

Usage of invalid data

If you use invalid Mask(data, False) data in inference computations, you may encounter silently incorrect results.

Methods:

unmask: Unmask the Mask, returning the value within.

Source code in src/genjax/_src/core/generative/functional_types.py
@Pytree.dataclass(match_args=True, init=False)
class Mask(Generic[R], Pytree):
    """The `Mask` datatype wraps a value in a Boolean flag which denotes whether the data is valid or invalid to use in inference computations.

    Masks can be used in a variety of ways as part of generative computations - their primary role is to denote data which is valid under inference computations. Valid data can be used as `ChoiceMap` leaves, and participate in generative and inference computations (like scores, and importance weights or density ratios). A Mask with a False flag **should** be considered unusable, and should be handled with care.

    If a `flag` has a non-scalar shape, that implies that the mask is vectorized, and that the `ArrayLike` value, or each leaf in the pytree, must have the flag's shape as its prefix (i.e., must have been created with a `jax.vmap` call or via a GenJAX `vmap` combinator).

    ## Encountering `Mask` in your computation

    When users see `Mask` in their computations, they are expected to interact with them by either:

    * Unmasking them using the `Mask.unmask` interface, a potentially unsafe operation.

    * Destructuring them manually, and handling the cases.

    ## Usage of invalid data

    If you use invalid `Mask(data, False)` data in inference computations, you may encounter silently incorrect results.
    """

    value: R
    flag: Flag | Diff[Flag]

    ################
    # Constructors #
    ################

    def __init__(self, value: R, flag: Flag | Diff[Flag] = True) -> None:
        assert not isinstance(value, Mask), (
            f"Mask should not be instantiated with another Mask! found {value}"
        )
        Mask._validate_init(value, flag)

        self.value, self.flag = value, flag  # pyright: ignore[reportAttributeAccessIssue]

    @staticmethod
    def _validate_init(value: R, flag: Flag | Diff[Flag]) -> None:
        """Validates that non-scalar flags are only used with vectorized masks.

        When a flag has a non-scalar shape (e.g. shape (3,)), this indicates the mask is vectorized.
        In this case, each leaf value in the pytree must have the flag's shape as a prefix of its own shape.
        For example, if flag has shape (3,), then array leaves must have shapes like (3,), (3,4), (3,2,1) etc.

        This ensures that vectorized flags properly align with vectorized data.

        Args:
            value: The value to be masked, can be a pytree
            flag: The flag to apply, either a scalar or array flag

        Raises:
            ValueError: If a non-scalar flag's shape is not a prefix of all leaf value shapes
        """
        flag = flag.get_primal() if isinstance(flag, Diff) else flag
        f_shape = jnp.shape(flag)
        if f_shape == ():
            return None

        leaf_shapes = [jnp.shape(leaf) for leaf in jtu.tree_leaves(value)]
        prefix_len = len(f_shape)

        for shape in leaf_shapes:
            if shape[:prefix_len] != f_shape:
                raise ValueError(
                    f"Vectorized flag {flag}'s shape {f_shape} must be a prefix of all leaf shapes. Found {shape}."
                )

    @staticmethod
    def _validate_leaf_shapes(this: R, other: R):
        """Validates that two values have matching shapes at each leaf.

        Used by __or__, __xor__ etc. to ensure we only combine masks with values whose leaves have matching shapes.
        Broadcasting is not supported - array shapes must match exactly.

        Args:
            this: First value to compare
            other: Second value to compare

        Raises:
            ValueError: If any leaf shapes don't match exactly
        """

        # Check array shapes match exactly (no broadcasting)
        def check_leaf_shapes(x, y):
            x_shape = jnp.shape(x)
            y_shape = jnp.shape(y)
            if x_shape != y_shape:
                raise ValueError(
                    f"Cannot combine masks with different array shapes: {x_shape} vs {y_shape}"
                )
            return None

        jtu.tree_map(check_leaf_shapes, this, other)

    def _validate_mask_shapes(self, other: "Mask[R]") -> None:
        """Used by __or__, __xor__ etc. to ensure we only combine masks with matching pytree shape and matching leaf shapes."""
        if jtu.tree_structure(self.value) != jtu.tree_structure(other.value):
            raise ValueError("Cannot combine masks with different tree structures!")

        Mask._validate_leaf_shapes(self, other)
        return None

    @staticmethod
    def build(v: "R | Mask[R]", f: Flag | Diff[Flag] = True) -> "Mask[R]":
        """
        Create a Mask instance, potentially from an existing Mask or a raw value.

        This method allows for the creation of a new Mask or the modification of an existing one. If the input is already a Mask, it combines the new flag with the existing one using a logical AND operation.

        Args:
            v: The value to be masked. Can be a raw value or an existing Mask.
            f: The flag to be applied to the value.

        Returns:
            A new Mask instance with the given value and flag.

        Note:
            If `v` is already a Mask, the new flag is combined with the existing one using a logical AND, ensuring that the resulting Mask is only valid if both input flags are valid.
        """
        match v:
            case Mask(value, g):
                assert not isinstance(f, Diff) and not isinstance(g, Diff)
                assert FlagOp.is_scalar(f) or (jnp.shape(f) == jnp.shape(g)), (
                    f"Can't build a Mask with non-matching Flag shapes {jnp.shape(f)} and {jnp.shape(g)}"
                )
                return Mask[R](value, FlagOp.and_(f, g))
            case _:
                return Mask[R](v, f)

    @staticmethod
    def maybe_mask(v: "R | Mask[R]", f: Flag) -> "R | Mask[R] | None":
        """
        Create a Mask instance or return the original value based on the flag.

        This method is similar to `build`, but it handles concrete flag values differently. For concrete True flags, it returns the original value without wrapping it in a Mask. For concrete False flags, it returns None. For non-concrete flags, it creates a new Mask instance.

        Args:
            v: The value to be potentially masked. Can be a raw value or an existing Mask.
            f: The flag to be applied to the value.

        Returns:
            - The original value `v` if `f` is concretely True.
            - None if `f` is concretely False.
            - A new Mask instance with the given value and flag if `f` is not concrete.
        """
        return Mask.build(v, f).flatten()

    #############
    # Accessors #
    #############

    def __getitem__(self, path) -> "Mask[R]":
        path = path if isinstance(path, tuple) else (path,)

        f = self.primal_flag()
        if isinstance(f, Array) and f.shape:
            # A non-scalar flag must have been produced via vectorization. Because a scalar flag can
            # wrap a non-scalar value, only use the vectorized components of the path to index into the flag...
            f = f[path[: len(f.shape)]]

        # but use the full path to index into the value.
        v_idx = jtu.tree_map(lambda v: v[path], self.value)

        # Reconstruct Diff if needed
        if isinstance(self.flag, Diff):
            f = Diff(f, self.flag.tangent)

        return Mask.build(v_idx, f)

    def flatten(self) -> "R | Mask[R] | None":
        """
        Flatten a Mask instance into its underlying value or None.

        "Flattening" occurs when the flag value is a concrete Boolean (True/False). In these cases, the Mask is simplified to either its raw value or None. If the flag is not concrete (i.e., a symbolic/traced value), the Mask remains intact.

        This method evaluates the mask's flag and returns:
        - None if the flag is concretely False or the value is None
        - The raw value if the flag is concretely True
        - The Mask instance itself if the flag is not concrete

        Returns:
            The flattened result based on the mask's flag state.
        """
        flag = self.primal_flag()
        if FlagOp.concrete_false(flag):
            return None
        elif FlagOp.concrete_true(flag):
            return self.value
        else:
            return self

    def unmask(self, default: R | None = None) -> R:
        """
        Unmask the `Mask`, returning the value within.

        This operation is inherently unsafe with respect to inference semantics if no default value is provided. It is only valid if the `Mask` wraps valid data at runtime, or if a default value is supplied.

        Args:
            default: An optional default value to return if the mask is invalid.

        Returns:
            The unmasked value if valid, or the default value if provided and the mask is invalid.
        """
        if default is None:

            def _check():
                checkify.check(
                    jnp.all(self.primal_flag()),
                    "Attempted to unmask when a mask flag (or some flag in a vectorized mask) is False: the unmasked value is invalid.\n",
                )

            optional_check(_check)
            return self.value
        else:

            def inner(true_v: ArrayLike, false_v: ArrayLike) -> Array:
                return jnp.where(self.primal_flag(), true_v, false_v)

            return jtu.tree_map(inner, self.value, default)

    def primal_flag(self) -> Flag:
        """
        Returns the primal flag of the mask.

        This method retrieves the primal (non-`Diff`-wrapped) flag value. If the flag
        is a Diff type (which contains both primal and tangent components), it returns
        the primal component. Otherwise, it returns the flag as is.

        Returns:
            The primal flag value.
        """
        match self.flag:
            case Diff(primal, _):
                return primal
            case flag:
                return flag

    ###############
    # Combinators #
    ###############

    def _or_idx(self, first: Flag, second: Flag):
        """Converts a pair of flag arrays into an array of indices for selecting between two values.

        This function implements a truth table for selecting between two values based on their flags:

        first | second | output | meaning
        ------+--------+--------+-----------------------------------
          0   |   0    |   -1   | neither valid
          1   |   0    |    0   | first valid only
          0   |   1    |    1   | second valid only
          1   |   1    |    0   | both valid for OR, invalid for XOR

        The output index is used to select between the corresponding values:
            0 -> select first value
            1 -> select second value

        Args:
            first: The flag for the first value
            second: The flag for the second value

        Returns:
            An Array of indices (-1, 0, or 1) indicating which value to select from each side.
        """
        # Note that the validation has already run to check that these flags have the same shape.
        return first + 2 * FlagOp.and_(FlagOp.not_(first), second) - 1

    def __or__(self, other: "Mask[R]") -> "Mask[R]":
        self._validate_mask_shapes(other)

        match self.primal_flag(), other.primal_flag():
            case True, _:
                return self
            case False, _:
                return other
            case self_flag, other_flag:
                idx = self._or_idx(self_flag, other_flag)
                return tree_choose(idx, [self, other])

    def __xor__(self, other: "Mask[R]") -> "Mask[R]":
        self._validate_mask_shapes(other)

        match self.primal_flag(), other.primal_flag():
            case (False, False) | (True, True):
                return Mask.build(self, False)
            case True, False:
                return self
            case False, True:
                return other
            case self_flag, other_flag:
                idx = self._or_idx(self_flag, other_flag)

                # Note that `idx` above selects the correct side in the FF, FT and TF cases,
                # but equals 0 for TT flags. We use `FlagOp.xor_` to override the resulting flag
                # to False in that case, since neither side's flag is False when both are True.
                chosen = tree_choose(idx, [self.value, other.value])
                return Mask(chosen, FlagOp.xor_(self_flag, other_flag))

    def __invert__(self) -> "Mask[R]":
        not_flag = jtu.tree_map(FlagOp.not_, self.flag)
        return Mask(self.value, not_flag)

    @staticmethod
    def or_n(mask: "Mask[R]", *masks: "Mask[R]") -> "Mask[R]":
        """Performs an n-ary OR operation on a sequence of Mask objects.

        Args:
            mask: The first mask to combine
            *masks: Variable number of additional masks to combine with OR

        Returns:
            A new Mask combining all inputs with OR operations
        """
        return functools.reduce(lambda a, b: a | b, masks, mask)

    @staticmethod
    def xor_n(mask: "Mask[R]", *masks: "Mask[R]") -> "Mask[R]":
        """Performs an n-ary XOR operation on a sequence of Mask objects.

        Args:
            mask: The first mask to combine
            *masks: Variable number of additional masks to combine with XOR

        Returns:
            A new Mask combining all inputs with XOR operations
        """
        return functools.reduce(lambda a, b: a ^ b, masks, mask)
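The truth table in `_or_idx` and the way `__or__` and `__xor__` use it can be checked numerically. The following is a minimal NumPy sketch, not GenJAX code: `or_idx`, `masked_or`, and `masked_xor` are hypothetical helpers, and `np.where` stands in for GenJAX's `tree_choose` on a single array leaf. Entries where neither flag is set receive an arbitrary value, which is harmless because their resulting flag is False.

```python
import numpy as np


def or_idx(first, second):
    """Hypothetical NumPy mirror of Mask._or_idx: first + 2 * (not first and second) - 1."""
    first = np.asarray(first, dtype=bool)
    second = np.asarray(second, dtype=bool)
    return first.astype(int) + 2 * (~first & second).astype(int) - 1


def masked_or(v1, f1, v2, f2):
    """Elementwise OR of two masked arrays; np.where stands in for tree_choose."""
    idx = or_idx(f1, f2)
    value = np.where(idx == 1, v2, v1)  # idx == -1 entries are invalid either way
    flag = np.asarray(f1, dtype=bool) | np.asarray(f2, dtype=bool)
    return value, flag


def masked_xor(v1, f1, v2, f2):
    """Elementwise XOR: same selection as OR, but the TT case becomes invalid."""
    idx = or_idx(f1, f2)
    value = np.where(idx == 1, v2, v1)
    flag = np.asarray(f1, dtype=bool) ^ np.asarray(f2, dtype=bool)
    return value, flag


# One entry per row of the truth table: FF, TF, FT, TT.
f1 = np.array([False, True, False, True])
f2 = np.array([False, False, True, True])
v1 = np.array([10, 11, 12, 13])
v2 = np.array([20, 21, 22, 23])

print(or_idx(f1, f2).tolist())  # [-1, 0, 1, 0]
or_value, or_flag = masked_or(v1, f1, v2, f2)
print(or_value.tolist(), or_flag.tolist())  # [10, 11, 22, 13] [False, True, True, True]
xor_value, xor_flag = masked_xor(v1, f1, v2, f2)
print(xor_flag.tolist())  # [False, True, True, False]
```

The OR and XOR results share the same selected values; they differ only in the TT entry, whose XOR flag is forced to False, which is exactly the override described in the `__xor__` comment.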

unmask

unmask(default: R | None = None) -> R

Unmask the Mask, returning the value within.

This operation is inherently unsafe with respect to inference semantics if no default value is provided. It is only valid if the Mask wraps valid data at runtime, or if a default value is supplied.

Parameters:

Name      Type       Description                                                    Default
default   R | None   An optional default value to return if the mask is invalid.   None

Returns:

Type   Description
R      The unmasked value if valid, or the default value if provided and the mask is invalid.
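When a default is supplied, `unmask` selects leaf by leaf between the masked value and the default using `jnp.where` inside `jtu.tree_map`. The sketch below reproduces that selection with NumPy, assuming a pytree that is simply a dict of 1-D arrays and a flag of the same length; `unmask_with_default` is a hypothetical helper, not part of GenJAX.

```python
import numpy as np


def unmask_with_default(value: dict, flag, default: dict) -> dict:
    """Hypothetical helper: take value where flag is True and default elsewhere, per leaf."""
    if value.keys() != default.keys():
        raise ValueError("value and default must have the same structure")
    # Equivalent in spirit to jtu.tree_map(lambda t, f: jnp.where(flag, t, f), value, default).
    return {k: np.where(flag, value[k], default[k]) for k in value}


flag = np.array([True, False, True])
value = {"x": np.array([1.0, 2.0, 3.0]), "y": np.array([10, 20, 30])}
default = {"x": np.zeros(3), "y": np.full(3, -1)}

result = unmask_with_default(value, flag, default)
print(result["x"].tolist())  # [1.0, 0.0, 3.0]
print(result["y"].tolist())  # [10, -1, 30]
```

Supplying a default makes the result well defined for every entry, which is why only the no-default path needs the `checkify` guard.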

Source code in src/genjax/_src/core/generative/functional_types.py

Static typing with genjax.typing, a.k.a. 🐻beartype🐻

GenJAX uses beartype to perform type checking at JAX tracing (compile) time. Beartype is normally a fast runtime type checker; in GenJAX its checks run while JAX traces a function, verifying that arguments and return values have the expected types. Because the checks happen during tracing rather than inside the compiled computation, they add no cost when the compiled code runs.

Generative interface types

genjax.core.Arguments module-attribute

Arguments = tuple

Arguments is the type of argument values to generative functions. It is a type alias for tuple, used to improve the readability and parsing of interface specifications.
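Because Arguments is a plain tuple, a generative function's arguments are packed into a single tuple, and one argument on its own needs a trailing comma. A minimal plain-Python sketch; the local `Arguments` alias and the `count_args` helper are hypothetical stand-ins, not imports from GenJAX.

```python
# Hypothetical local alias mirroring genjax.core.Arguments.
Arguments = tuple


def count_args(args: Arguments) -> int:
    """Hypothetical helper: all arguments arrive packed in one tuple."""
    return len(args)


single: Arguments = (0.5,)  # trailing comma makes a 1-tuple
not_a_tuple = (0.5)         # just a float in parentheses

print(isinstance(single, tuple), isinstance(not_a_tuple, tuple))  # True False
print(count_args((1.0, 2.0)))  # 2
```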

genjax.core.Score module-attribute

Score = FloatArray

A score is a density ratio, described fully in simulate.

The type Score does not enforce any meaningful mathematical invariants, but is used to denote the type of scores in the GenJAX system, to improve readability and parsing of interface specifications.

genjax.core.Weight module-attribute

Weight = FloatArray

A weight is a density ratio that often arises in the context of proper weighting for Target distributions, or in Gen's edit interface; its mathematical content is described in edit.

The type Weight does not enforce any meaningful mathematical invariants, but is used to denote the type of weights in GenJAX, to improve readability and parsing of interface specifications / expectations.

genjax.core.Retdiff module-attribute

Retdiff = Annotated[
    R, Is[lambda x: static_check_tree_diff(x)]
]

Retdiff is the type of return values with an attached ChangeType (cf. edit).

When used under type checking, Retdiff assumes that the return value is a Pytree (either defined via GenJAX's Pytree interface or registered with JAX's Pytree system), and checks that its leaves are Diff values carrying an attached ChangeType.

genjax.core.Argdiffs module-attribute

Argdiffs = Annotated[
    tuple[Any, ...],
    Is[lambda x: static_check_tree_diff(x)],
]

Argdiffs is the type of argument values with an attached ChangeType (cf. edit).

When used under type checking, Argdiffs assumes that the argument values are Pytrees (either defined via GenJAX's Pytree interface or registered with JAX's Pytree system). For each argument, it checks that the leaves are Diff values carrying an attached ChangeType.
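Based on the descriptions above, the check behind Retdiff and Argdiffs can be pictured as treating Diff values as leaves and requiring every leaf to be one. The sketch below is pure Python and hypothetical: the `Diff` dataclass (with the `primal` and `tangent` fields seen in the `Mask` listing), the string change tags, and the simplified `tree_leaves` (tuples, lists, and dicts only) stand in for GenJAX's Diff, its ChangeType values, and JAX's pytree machinery. It is not the implementation of `static_check_tree_diff`.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Diff:
    """Hypothetical stand-in for genjax's Diff: a primal value plus a change tag."""
    primal: Any
    tangent: Any


def tree_leaves(tree: Any, is_leaf: Callable[[Any], bool]) -> list:
    """Simplified flattening over tuples, lists and dicts (not JAX's pytree registry)."""
    if is_leaf(tree):
        return [tree]
    if isinstance(tree, (tuple, list)):
        return [leaf for sub in tree for leaf in tree_leaves(sub, is_leaf)]
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in tree_leaves(tree[k], is_leaf)]
    return [tree]


def is_diff(x: Any) -> bool:
    return isinstance(x, Diff)


def check_tree_diff(tree: Any) -> bool:
    """Stop flattening at Diff objects, then require every leaf to be a Diff."""
    return all(is_diff(leaf) for leaf in tree_leaves(tree, is_diff))


# Hypothetical change tags standing in for ChangeType values.
NO_CHANGE, UNKNOWN_CHANGE = "no_change", "unknown_change"

good = (Diff(1.0, NO_CHANGE), {"a": Diff(2.0, UNKNOWN_CHANGE)})
bad = (Diff(1.0, NO_CHANGE), 2.0)  # second argument lacks a change annotation
print(check_tree_diff(good))  # True
print(check_tree_diff(bad))   # False
```

Stopping the flattening at Diff objects matters: a Diff's own fields are never inspected as separate leaves, so a Diff wrapping an arbitrary value still counts as one annotated leaf.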