From ac4a0fd7a7119e81969bc873bf700dcae019855b Mon Sep 17 00:00:00 2001 From: kp992 Date: Thu, 25 Dec 2025 16:57:43 -0800 Subject: [PATCH 1/2] Update lecture to use JAX and remove numba --- lectures/exchangeable.md | 178 +++++++++++++++++++++------------------ 1 file changed, 94 insertions(+), 84 deletions(-) diff --git a/lectures/exchangeable.md b/lectures/exchangeable.md index 564cbb1cc..8b924b685 100644 --- a/lectures/exchangeable.md +++ b/lectures/exchangeable.md @@ -3,8 +3,10 @@ jupytext: text_representation: extension: .md format_name: myst + format_version: 0.13 + jupytext_version: 1.16.7 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 --- @@ -66,16 +68,21 @@ Below, we'll often use Let’s start with some imports: -```{code-cell} ipython ---- -tags: [hide-output] ---- +```{code-cell} ipython3 +:tags: [hide-output] + import matplotlib.pyplot as plt -from numba import jit, vectorize -from math import gamma import scipy.optimize as op from scipy.integrate import quad import numpy as np +import jax.numpy as jnp +from jax.scipy.special import gamma +import jax +``` + +```{code-cell} ipython3 +# Enable JAX to use 64-bit float operations +jax.config.update("jax_enable_x64", True) ``` ## Independently and Identically Distributed @@ -402,67 +409,69 @@ nature has chosen distribution $f$ – works. To create the Python infrastructure to do our work for us, we construct a wrapper function that displays informative graphs given parameters of $f$ and $g$. -```{code-cell} python3 -@vectorize +```{code-cell} ipython3 +@jax.jit def p(x, a, b): - "The general beta distribution function." + """The general beta distribution function.""" r = gamma(a + b) / (gamma(a) * gamma(b)) - return r * x ** (a-1) * (1 - x) ** (b-1) + return r * x ** (a - 1) * (1 - x) ** (b - 1) +``` +```{code-cell} ipython3 def learning_example(F_a=1, F_b=1, G_a=3, G_b=1.2): """ A wrapper function that displays the updating rule of belief π, given the parameters which specify F and G distributions. 
""" - f = jit(lambda x: p(x, F_a, F_b)) - g = jit(lambda x: p(x, G_a, G_b)) + f = jax.jit(lambda x: p(x, F_a, F_b)) + g = jax.jit(lambda x: p(x, G_a, G_b)) # l(w) = f(w) / g(w) l = lambda w: f(w) / g(w) # objective function for solving l(w) = 1 - obj = lambda w: l(w) - 1 + obj = jax.jit(lambda w: l(w) - 1) x_grid = np.linspace(0, 1, 100) - π_grid = np.linspace(1e-3, 1-1e-3, 100) + π_grid = np.linspace(1e-3, 1 - 1e-3, 100) w_max = 1 - w_grid = np.linspace(1e-12, w_max-1e-12, 100) + w_grid = np.linspace(1e-12, w_max - 1e-12, 100) # the mode of beta distribution # use this to divide w into two intervals for root finding G_mode = (G_a - 1) / (G_a + G_b - 2) roots = np.empty(2) roots[0] = op.root_scalar(obj, bracket=[1e-10, G_mode]).root - roots[1] = op.root_scalar(obj, bracket=[G_mode, 1-1e-10]).root + roots[1] = op.root_scalar(obj, bracket=[G_mode, 1 - 1e-10]).root fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5)) - ax1.plot(l(w_grid), w_grid, label='$l$', lw=2) - ax1.vlines(1., 0., 1., linestyle="--") - ax1.hlines(roots, 0., 2., linestyle="--") - ax1.set_xlim([0., 2.]) + ax1.plot(l(w_grid), w_grid, label="$l$", lw=2) + ax1.vlines(1.0, 0.0, 1.0, linestyle="--") + ax1.hlines(roots, 0.0, 2.0, linestyle="--") + ax1.set_xlim([0.0, 2.0]) ax1.legend(loc=4) - ax1.set(xlabel='$l(w)=f(w)/g(w)$', ylabel='$w$') + ax1.set(xlabel="$l(w)=f(w)/g(w)$", ylabel="$w$") - ax2.plot(f(x_grid), x_grid, label='$f$', lw=2) - ax2.plot(g(x_grid), x_grid, label='$g$', lw=2) - ax2.vlines(1., 0., 1., linestyle="--") - ax2.hlines(roots, 0., 2., linestyle="--") + ax2.plot(f(x_grid), x_grid, label="$f$", lw=2) + ax2.plot(g(x_grid), x_grid, label="$g$", lw=2) + ax2.vlines(1.0, 0.0, 1.0, linestyle="--") + ax2.hlines(roots, 0.0, 2.0, linestyle="--") ax2.legend(loc=4) - ax2.set(xlabel='$f(w), g(w)$', ylabel='$w$') + ax2.set(xlabel="$f(w), g(w)$", ylabel="$w$") area1 = quad(f, 0, roots[0])[0] area2 = quad(g, roots[0], roots[1])[0] area3 = quad(f, roots[1], 1)[0] ax2.text((f(0) + f(roots[0])) / 4, roots[0] / 2, f"{area1: .3g}") - ax2.fill_between([0, 1], 0, roots[0], color='blue', alpha=0.15) + ax2.fill_between([0, 1], 0, roots[0], color="blue", alpha=0.15) ax2.text(np.mean(g(roots)) / 2, np.mean(roots), f"{area2: .3g}") w_roots = np.linspace(roots[0], roots[1], 20) - ax2.fill_betweenx(w_roots, 0, g(w_roots), color='orange', alpha=0.15) + ax2.fill_betweenx(w_roots, 0, g(w_roots), color="orange", alpha=0.15) ax2.text((f(roots[1]) + f(1)) / 4, (roots[1] + 1) / 2, f"{area3: .3g}") - ax2.fill_between([0, 1], roots[1], 1, color='blue', alpha=0.15) + ax2.fill_between([0, 1], roots[1], 1, color="blue", alpha=0.15) W = np.arange(0.01, 0.99, 0.08) Π = np.arange(0.01, 0.99, 0.08) @@ -474,13 +483,13 @@ def learning_example(F_a=1, F_b=1, G_a=3, G_b=1.2): lw = l(w) ΔΠ[i, j] = π * (lw / (π * lw + 1 - π) - 1) - q = ax3.quiver(Π, W, ΔΠ, ΔW, scale=2, color='r', alpha=0.8) + q = ax3.quiver(Π, W, ΔΠ, ΔW, scale=2, color="r", alpha=0.8) - ax3.fill_between(π_grid, 0, roots[0], color='blue', alpha=0.15) - ax3.fill_between(π_grid, roots[0], roots[1], color='green', alpha=0.15) - ax3.fill_between(π_grid, roots[1], w_max, color='blue', alpha=0.15) - ax3.hlines(roots, 0., 1., linestyle="--") - ax3.set(xlabel=r'$\pi$', ylabel='$w$') + ax3.fill_between(π_grid, 0, roots[0], color="blue", alpha=0.15) + ax3.fill_between(π_grid, roots[0], roots[1], color="green", alpha=0.15) + ax3.fill_between(π_grid, roots[1], w_max, color="blue", alpha=0.15) + ax3.hlines(roots, 0.0, 1.0, linestyle="--") + ax3.set(xlabel=r"$\pi$", ylabel="$w$") ax3.grid() plt.show() @@ 
-490,7 +499,7 @@ Now we'll create a group of graphs that illustrate dynamics induced by Bayes' L

-We'll begin with Python function default values of various objects, then
-change them in a subsequent example.
+We'll begin with the Python function's default parameter values, then change
+them in a subsequent example.

-```{code-cell} python3
+```{code-cell} ipython3
 learning_example()
 ```

@@ -531,7 +540,7 @@ Next we use our code to create graphs for another instance of our model.
 We keep $F$ the same as in the preceding instance, namely a uniform
 distribution, but now assume that $G$ is a Beta distribution with parameters
 $G_a=2, G_b=1.6$.

-```{code-cell} python3
+```{code-cell} ipython3
 learning_example(G_a=2, G_b=1.6)
 ```

@@ -552,71 +561,72 @@ Outcomes depend on a peculiar property of likelihood ratio processes discussed

 To proceed, we create some Python code.

-```{code-cell} python3
+```{code-cell} ipython3
 def function_factory(F_a=1, F_b=1, G_a=3, G_b=1.2):

     # define f and g
-    f = jit(lambda x: p(x, F_a, F_b))
-    g = jit(lambda x: p(x, G_a, G_b))
-
-    @jit
-    def update(a, b, π):
-        "Update π by drawing from beta distribution with parameters a and b"
+    f = jax.jit(lambda x: p(x, F_a, F_b))
+    g = jax.jit(lambda x: p(x, G_a, G_b))
+
+    @jax.jit
+    def update(key, a, b, π):
+        """Update π by drawing from beta distribution with parameters a and b"""
         # Draw
-        w = np.random.beta(a, b)
-
+        w = jax.random.beta(key, a, b)
         # Update belief
         π = 1 / (1 + ((1 - π) * g(w)) / (π * f(w)))
-
         return π

-    @jit
-    def simulate_path(a, b, T=50):
-        "Simulates a path of beliefs π with length T"
-
-        π = np.empty(T+1)
+    @jax.jit
+    def simulate_path(keys, a, b):
+        """Simulates one path of beliefs π, with one update per key in keys"""

-        # initial condition
-        π[0] = 0.5
+        # One Bayes update per random key; scan threads the belief forward
+        def scan_fn(π_prev, subkey):
+            π_new = update(subkey, a, b, π_prev)
+            return π_new, π_new

-        for t in range(1, T+1):
-            π[t] = update(a, b, π[t-1])
-
-        return π
+        # Scan over the keys
+        _, π_path = jax.lax.scan(scan_fn, 0.5, keys)
+        # Prepend initial condition
+        return jnp.concatenate([jnp.array([0.5]), π_path])

     def simulate(a=1, b=1, T=50, N=200, display=True):
         "Simulates N paths of beliefs π with length T"
+        # vectorize simulate_path over its first argument (the key array)
+        simulate_path_vmap = jax.vmap(simulate_path, in_axes=(0, None, None))
+        # derive a deterministic seed from a, b, T and N so that repeated
+        # calls with the same arguments are reproducible
+        random_seed = int(a * b + T + N)
+        keys = jax.random.split(jax.random.PRNGKey(random_seed), (N, T))
+        # Call the vectorized function over the first axis of keys
+        π_paths = simulate_path_vmap(keys, a, b)

-        π_paths = np.empty((N, T+1))
         if display:
             fig = plt.figure()
+            for i in range(N):
+                plt.plot(
+                    range(T + 1), π_paths[i], color="b", lw=0.8, alpha=0.5
+                )

-        for i in range(N):
-            π_paths[i] = simulate_path(a=a, b=b, T=T)
-            if display:
-                plt.plot(range(T+1), π_paths[i], color='b', lw=0.8, alpha=0.5)
-
-        if display:
             plt.show()
-
         return π_paths

     return simulate
 ```

-```{code-cell} python3
+```{code-cell} ipython3
 simulate = function_factory()
 ```

 We begin by generating $N$ simulated $\{\pi_t\}$ paths with $T$ periods when
 the sequence is truly IID draws from $F$. We set an initial prior
 $\pi_{-1} = .5$.

-```{code-cell} python3
+```{code-cell} ipython3
 T = 50
 ```

-```{code-cell} python3
+```{code-cell} ipython3
 # when nature selects F
 π_paths_F = simulate(a=1, b=1, T=T, N=1000)
 ```
@@ -629,7 +639,7 @@ discovers the truth for most of our paths.

 Next, we generate paths with $T$ periods when the sequence is truly IID
 draws from $G$. Again, we set the initial prior $\pi_{-1} = .5$.
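+
+As an aside, here is a minimal sketch of the `jax.lax.scan` pattern used in
+`simulate_path` above: the carry is the current belief π and each step applies
+one Bayes update. The constant likelihood ratio `l_w = 2.0` is a hypothetical
+value chosen only to make the drift in beliefs visible; it is not part of the
+model.
+
+```{code-cell} ipython3
+def bayes_step(π, _):
+    # one Bayes update with a fixed, hypothetical likelihood ratio l(w) = 2
+    l_w = 2.0
+    π_new = π * l_w / (π * l_w + 1 - π)
+    return π_new, π_new   # (carry, output)
+
+# run five updates starting from the prior π = 0.5
+_, π_demo = jax.lax.scan(bayes_step, 0.5, None, length=5)
+print(π_demo)
+```
+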
-```{code-cell} python3
+```{code-cell} ipython3
 # when nature selects G
 π_paths_G = simulate(a=3, b=1.2, T=T, N=1000)
 ```
@@ -647,9 +657,9 @@ Using $N$ simulated $\pi_t$ paths, we compute
-$1 - \sum_{i=1}^{N}\pi_{i,t}$ at each $t$ when the data are generated as
-draws from $F$ and compute $\sum_{i=1}^{N}\pi_{i,t}$ when the data are
-generated as draws from $G$.
+$1 - \frac{1}{N}\sum_{i=1}^{N}\pi_{i,t}$ at each $t$ when the data are
+generated as draws from $F$ and compute
+$\frac{1}{N}\sum_{i=1}^{N}\pi_{i,t}$ when the data are generated as draws
+from $G$.

-```{code-cell} python3
-plt.plot(range(T+1), 1 - np.mean(π_paths_F, 0), label='F generates')
-plt.plot(range(T+1), np.mean(π_paths_G, 0), label='G generates')
+```{code-cell} ipython3
+plt.plot(range(T + 1), 1 - np.mean(π_paths_F, 0), label="F generates")
+plt.plot(range(T + 1), np.mean(π_paths_G, 0), label="G generates")
 plt.legend()
 plt.title("convergence");
 ```
@@ -673,23 +683,23 @@ where $a =f,g$.

 The following code approximates the integral above:

-```{code-cell} python3
+```{code-cell} ipython3
 def expected_ratio(F_a=1, F_b=1, G_a=3, G_b=1.2):

     # define f and g
-    f = jit(lambda x: p(x, F_a, F_b))
-    g = jit(lambda x: p(x, G_a, G_b))
+    f = jax.jit(lambda x: p(x, F_a, F_b))
+    g = jax.jit(lambda x: p(x, G_a, G_b))

-    l = lambda w: f(w) / g(w)
-    integrand_f = lambda w, π: f(w) * l(w) / (π * l(w) + 1 - π)
-    integrand_g = lambda w, π: g(w) * l(w) / (π * l(w) + 1 - π)
+    l = jax.jit(lambda w: f(w) / g(w))
+    integrand_f = jax.jit(lambda w, π: f(w) * l(w) / (π * l(w) + 1 - π))
+    integrand_g = jax.jit(lambda w, π: g(w) * l(w) / (π * l(w) + 1 - π))

     π_grid = np.linspace(0.02, 0.98, 100)
-    expected_rario = np.empty(len(π_grid))
+    expected_ratio = np.empty(len(π_grid))

     for q, inte in zip(["f", "g"], [integrand_f, integrand_g]):
         for i, π in enumerate(π_grid):
-            expected_rario[i]= quad(inte, 0, 1, args=(π,))[0]
-        plt.plot(π_grid, expected_rario, label=f"{q} generates")
+            expected_ratio[i] = quad(inte, 0, 1, args=(π,))[0]
+        plt.plot(π_grid, expected_ratio, label=f"{q} generates")

     plt.hlines(1, 0, 1, linestyle="--")
@@ -703,7 +713,7 @@ def expected_ratio(F_a=1, F_b=1, G_a=3, G_b=1.2):

 First, consider the case where $F_a=F_b=1$ and $G_a=3, G_b=1.2$.

-```{code-cell} python3
+```{code-cell} ipython3
 expected_ratio()
 ```
@@ -716,7 +726,7 @@ distributions, and $F_a=G_a=3, F_b=G_b=1.2$.

 In a sense, here there is nothing to learn.

-```{code-cell} python3
+```{code-cell} ipython3
 expected_ratio(F_a=3, F_b=1.2)
 ```
@@ -726,7 +736,7 @@ Finally, let's look at a case in which $f$ and $g$ are neither very
 different nor identical, in particular one in which $F_a=2, F_b=1$ and
 $G_a=3, G_b=1.2$.

-```{code-cell} python3
+```{code-cell} ipython3
 expected_ratio(F_a=2, F_b=1, G_a=3, G_b=1.2)
 ```

From 5098c0e8571c8604a64377dbe5c37a7c683a94c7 Mon Sep 17 00:00:00 2001
From: kp992
Date: Thu, 25 Dec 2025 17:01:08 -0800
Subject: [PATCH 2/2] fix typos

---
 lectures/exchangeable.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lectures/exchangeable.md b/lectures/exchangeable.md
index 8b924b685..012405082 100644
--- a/lectures/exchangeable.md
+++ b/lectures/exchangeable.md
@@ -506,7 +506,7 @@ learning_example()
 ```

 Please look at the three graphs above created for an instance in which $f$ is a uniform distribution on $[0,1]$ (i.e., a Beta distribution with parameters $F_a=1, F_b=1$), while $g$ is a Beta distribution with the default parameter values $G_a=3, G_b=1.2$.

-The graph on the left plots the likelihood ratio $l(w)$ as the absciassa axis against $w$ as the ordinate.
+The graph on the left plots the likelihood ratio $l(w)$ on the abscissa against $w$ on the ordinate.

 The middle graph plots both $f(w)$ and $g(w)$ against $w$, with the horizontal dotted lines showing values of $w$ at which the likelihood ratio equals $1$.
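+
+As a quick check, the cell below evaluates the likelihood ratio on a fine grid
+with the default parameters and counts the sign changes of $l(w) - 1$. It
+reuses only the density `p` defined earlier; the grid bounds and size are
+arbitrary choices for this illustration.
+
+```{code-cell} ipython3
+# under the defaults f is uniform, so l(w) = 1 / g(w), which is U-shaped
+w = jnp.linspace(0.01, 0.9999, 500)
+l_w = p(w, 1.0, 1.0) / p(w, 3.0, 1.2)
+# count sign changes of l(w) - 1 across the grid
+crossings = jnp.sum(jnp.sign(l_w - 1)[:-1] != jnp.sign(l_w - 1)[1:])
+print(crossings)  # expect 2, matching the two dotted lines
+```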
@@ -720,7 +720,7 @@ expected_ratio()
 ```

-The above graphs shows that when $F$ generates the data, $\pi_t$ on average
-always heads north, while when $G$ generates the data, $\pi_t$ heads south.
+The above graph shows that when $F$ generates the data, $\pi_t$ on average
+heads north, while when $G$ generates the data, $\pi_t$ heads south.

-Next, we'll look at a degenerate case in whcih $f$ and $g$ are identical beta
+Next, we'll look at a degenerate case in which $f$ and $g$ are identical beta
 distributions, and $F_a=G_a=3, F_b=G_b=1.2$.

 In a sense, here there is nothing to learn.
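+
+As a minimal check of that claim, the next cell applies one Bayes update with
+identical densities, reusing the density `p` defined earlier; the particular
+draw `w` and prior `π` below are arbitrary illustrative values.
+
+```{code-cell} ipython3
+# with F_a = G_a = 3 and F_b = G_b = 1.2, f = g, so the likelihood ratio is 1
+w, π = 0.7, 0.3                       # arbitrary draw and prior
+f_w = p(w, 3, 1.2)
+g_w = p(w, 3, 1.2)
+π_new = π * f_w / (π * f_w + (1 - π) * g_w)
+print(π_new)                          # recovers the prior: nothing is learned
+```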