From bd2307af5aad7c02c7049fcb6d6d2816841b396a Mon Sep 17 00:00:00 2001 From: kp992 Date: Thu, 25 Dec 2025 18:06:57 -0800 Subject: [PATCH 1/2] update lecture to use JAX jit and remove numba code --- lectures/mix_model.md | 356 +++++++++++++++++++++++++----------------- 1 file changed, 211 insertions(+), 145 deletions(-) diff --git a/lectures/mix_model.md b/lectures/mix_model.md index cac49c608..6a5350f88 100644 --- a/lectures/mix_model.md +++ b/lectures/mix_model.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.17.2 + jupytext_version: 1.16.7 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -124,29 +124,25 @@ As usual, we'll start by importing some Python tools. :hide-output: false import matplotlib.pyplot as plt +import seaborn as sns import numpy as np -from numba import vectorize, jit -from math import gamma import pandas as pd import scipy.stats as sp from scipy.integrate import quad -import seaborn as sns -colors = sns.color_palette() - import numpyro +import jax import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS - import jax.numpy as jnp from jax import random +from jax.scipy.special import gamma np.random.seed(142857) +# Enable JAX to use 64-bit float operations +jax.config.update("jax_enable_x64", True) -@jit -def set_seed(): - np.random.seed(142857) -set_seed() +colors = sns.color_palette() ``` Let's use Python to generate two beta distributions @@ -158,35 +154,31 @@ Let's use Python to generate two beta distributions F_a, F_b = 1, 1 G_a, G_b = 3, 1.2 -@vectorize + +@jax.jit def p(x, a, b): r = gamma(a + b) / (gamma(a) * gamma(b)) - return r * x** (a-1) * (1 - x) ** (b-1) + return r * x ** (a - 1) * (1 - x) ** (b - 1) + # The two density functions. -f = jit(lambda x: p(x, F_a, F_b)) -g = jit(lambda x: p(x, G_a, G_b)) +f = jax.jit(lambda x: p(x, F_a, F_b)) +g = jax.jit(lambda x: p(x, G_a, G_b)) ``` ```{code-cell} ipython3 :hide-output: false -@jit def simulate(a, b, T=50, N=500): - ''' + """ Generate N sets of T observations of the likelihood ratio, return as N x T matrix. - - ''' - - l_arr = np.empty((N, T)) - - for i in range(N): - - for j in range(T): - w = np.random.beta(a, b) - l_arr[i, j] = f(w) / g(w) - + """ + # create any random seed. Using some combination of a, b, T, N + # so that the seed becomes unique. + random_seed = int(a * b + T + N) + w = jax.random.beta(jax.random.PRNGKey(random_seed), a, b, (N, T)) + l_arr = f(w) / g(w) return l_arr ``` @@ -248,25 +240,35 @@ See the [Mr. P Solver video on Monte Carlo simulation](https://www.google.com/se In the Python code below, we'll use both of our methods and confirm that each of them does a good job of sampling from our target mixture distribution. - ```{code-cell} ipython3 -@jit -def draw_lottery(p, N): - "Draw from the compound lottery directly." 
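+# A note on the replacement implementation defined below: it vectorizes
+# nature's coin flip by drawing N uniforms and N candidate draws from each
+# of F and G up front, then uses `jnp.where` to keep the F-draw whenever
+# u <= p and the G-draw otherwise.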
- - draws = [] - for i in range(0, N): - if np.random.rand()<=p: - draws.append(np.random.beta(F_a, F_b)) - else: - draws.append(np.random.beta(G_a, G_b)) - return np.array(draws) +def draw_lottery(p, N, random_seed_counter=0): + """Draw from the compound lottery directly.""" + # generate random key using some combination of p and N + random_seed = (p * N + 1234 + random_seed_counter).astype(int) + # Split key for all random draws + key_uniform, key_beta1, key_beta2 = jax.random.split( + jax.random.PRNGKey(random_seed), 3 + ) + + # Draw all random numbers at once + u = jax.random.uniform(key_uniform, shape=(N,)) + f_draws = jax.random.beta(key_beta1, F_a, F_b, shape=(N,)) + g_draws = jax.random.beta(key_beta2, G_a, G_b, shape=(N,)) + + # Select based on condition + return jnp.where(u <= p, f_draws, g_draws) + + +# JAX-JIT function with `N` as the static argument +draw_lottery = jax.jit(draw_lottery, static_argnums=(1,)) +``` +```{code-cell} ipython3 def draw_lottery_MC(p, N): "Draw from the compound lottery using the Monte Carlo trick." - xs = np.linspace(1e-8,1-(1e-8),10000) - CDF = p*sp.beta.cdf(xs, F_a, F_b) + (1-p)*sp.beta.cdf(xs, G_a, G_b) + xs = np.linspace(1e-8, 1 - (1e-8), 10000) + CDF = p * sp.beta.cdf(xs, F_a, F_b) + (1 - p) * sp.beta.cdf(xs, G_a, G_b) Us = np.random.rand(N) draws = xs[np.searchsorted(CDF[:-1], Us)] @@ -282,11 +284,11 @@ sample1 = draw_lottery(α, N) sample2 = draw_lottery_MC(α, N) # plot draws and density function -plt.hist(sample1, 50, density=True, alpha=0.5, label='direct draws') -plt.hist(sample2, 50, density=True, alpha=0.5, label='MC draws') +plt.hist(sample1, 50, density=True, alpha=0.5, label="direct draws") +plt.hist(sample2, 50, density=True, alpha=0.5, label="MC draws") -xs = np.linspace(0,1,1000) -plt.plot(xs, α*f(xs)+(1-α)*g(xs), color='red', label='density') +xs = np.linspace(0, 1, 1000) +plt.plot(xs, α * f(xs) + (1 - α) * g(xs), color="red", label="density") plt.legend() plt.show() @@ -326,13 +328,11 @@ likelihood ratio $ \ell $ according to recursion {eq}`eq:recur1` ```{code-cell} ipython3 :hide-output: false -@jit +@jax.jit def update(π, l): - "Update π using likelihood l" - + """Update π using likelihood l""" # Update belief π = π * l / (π * l + 1 - π) - return π ``` @@ -415,11 +415,11 @@ def simulate_mixed(α, T=50, N=500): return as N x T matrix, when the true density is mixed h;α """ - w_s = draw_lottery(α, N*T).reshape(N, T) + w_s = draw_lottery(α, N * T).reshape(N, T) l_arr = f(w_s) / g(w_s) - return l_arr + def plot_π_seq(α, π1=0.2, π2=0.8, T=200): """ Compute and plot π_seq and the log likelihood ratio process @@ -430,33 +430,41 @@ def plot_π_seq(α, π1=0.2, π2=0.8, T=200): l_seq_mixed = np.cumprod(l_arr_mixed, axis=1) T = l_arr_mixed.shape[1] - π_seq_mixed = np.empty((2, T+1)) + π_seq_mixed = np.empty((2, T + 1)) π_seq_mixed[:, 0] = π1, π2 for t in range(T): for i in range(2): - π_seq_mixed[i, t+1] = update(π_seq_mixed[i, t], l_arr_mixed[0, t]) + π_seq_mixed[i, t + 1] = update( + π_seq_mixed[i, t], l_arr_mixed[0, t] + ) # plot fig, ax1 = plt.subplots() for i in range(2): - ax1.plot(range(T+1), π_seq_mixed[i, :], label=rf"$\pi_0$={π_seq_mixed[i, 0]}") - - ax1.plot(np.nan, np.nan, '--', color='b', label='Log likelihood ratio process') + ax1.plot( + range(T + 1), + π_seq_mixed[i, :], + label=rf"$\pi_0$={π_seq_mixed[i, 0]}", + ) + + ax1.plot( + np.nan, np.nan, "--", color="b", label="Log likelihood ratio process" + ) ax1.set_ylabel(r"$\pi_t$") ax1.set_xlabel("t") ax1.legend() ax1.set_title("when $\\alpha F + (1-\\alpha)G$ governs data") ax2 = 
ax1.twinx() - ax2.plot(range(1, T+1), np.log(l_seq_mixed[0, :]), '--', color='b') + ax2.plot(range(1, T + 1), np.log(l_seq_mixed[0, :]), "--", color="b") ax2.set_ylabel("$log(L(w^{t}))$") plt.show() ``` ```{code-cell} ipython3 -plot_π_seq(α = 0.6) +plot_π_seq(α=0.6) ``` The above graph shows a sample path of the log likelihood ratio process as the blue dotted line, together with @@ -466,7 +474,7 @@ sample paths of $\pi_t$ that start from two distinct initial conditions. Let's see what happens when we change $\alpha$ ```{code-cell} ipython3 -plot_π_seq(α = 0.2) +plot_π_seq(α=0.2) ``` Evidently, $\alpha$ is having a big effect on the destination of $\pi_t$ as $t \rightarrow + \infty$ @@ -501,70 +509,84 @@ The only possible limits are $0$ and $1$. As $t \rightarrow +\infty$, $\pi_t$ goes to one if and only if $KL_f < KL_g$ ```{code-cell} ipython3 -@vectorize +@jax.jit def KL_g(α): "Compute the KL divergence KL(h, g)." - err = 1e-8 # to avoid 0 at end points - ws = np.linspace(err, 1-err, 10000) + err = 1e-8 # to avoid 0 at end points + ws = jnp.linspace(err, 1 - err, 10000) gs, fs = g(ws), f(ws) - hs = α*fs + (1-α)*gs - return np.sum(np.log(hs/gs)*hs)/10000 + hs = α * fs + (1 - α) * gs + return jnp.sum(jnp.log(hs / gs) * hs) / 10000 -@vectorize + +KL_g_v = jax.vmap(KL_g) + + +@jax.jit def KL_f(α): "Compute the KL divergence KL(h, f)." - err = 1e-8 # to avoid 0 at end points - ws = np.linspace(err, 1-err, 10000) + err = 1e-8 # to avoid 0 at end points + ws = jnp.linspace(err, 1 - err, 10000) gs, fs = g(ws), f(ws) - hs = α*fs + (1-α)*gs - return np.sum(np.log(hs/fs)*hs)/10000 + hs = α * fs + (1 - α) * gs + return jnp.sum(jnp.log(hs / fs) * hs) / 10000 + + +KL_f_v = jax.vmap(KL_f) # compute KL using quad in Scipy def KL_g_quad(α): "Compute the KL divergence KL(h, g) using scipy.integrate." - h = lambda x: α*f(x) + (1-α)*g(x) - return quad(lambda x: h(x) * np.log(h(x)/g(x)), 0, 1)[0] + h = lambda x: α * f(x) + (1 - α) * g(x) + return quad(lambda x: h(x) * np.log(h(x) / g(x)), 0, 1)[0] + def KL_f_quad(α): "Compute the KL divergence KL(h, f) using scipy.integrate." - h = lambda x: α*f(x) + (1-α)*g(x) - return quad(lambda x: h(x) * np.log(h(x)/f(x)), 0, 1)[0] + h = lambda x: α * f(x) + (1 - α) * g(x) + return quad(lambda x: h(x) * np.log(h(x) / f(x)), 0, 1)[0] + # vectorize KL_g_quad_v = np.vectorize(KL_g_quad) KL_f_quad_v = np.vectorize(KL_f_quad) -# Let us find the limit point +@jax.jit def π_lim(α, T=5000, π_0=0.4): - "Find limit of π sequence." - π_seq = np.zeros(T+1) - π_seq[0] = π_0 + """Find limit of π sequence.""" + # Get lottery draws l_arr = simulate_mixed(α, T, N=1)[0] - for t in range(T): - π_seq[t+1] = update(π_seq[t], l_arr[t]) - return π_seq[-1] + def scan_fn(π_prev, l_t): + π_new = update(π_prev, l_t) + return π_new, π_new + + # Scan over lottery draws + π_final, _ = jax.lax.scan(scan_fn, π_0, l_arr) + + return π_final + -π_lim_v = np.vectorize(π_lim) +π_lim_v = jax.vmap(π_lim) ``` Let us first plot the KL divergences $KL_g\left(\alpha\right), KL_f\left(\alpha\right)$ for each $\alpha$. 
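+
+As a quick sanity check, we can compare the grid-based `KL_f` and `KL_g`
+with the `quad`-based versions defined above; this is a minimal sketch
+with two illustrative values of $\alpha$, and the printed pairs should
+nearly coincide:
+
+```{code-cell} ipython3
+for α_test in (0.2, 0.8):
+    print(f"α = {α_test}:")
+    print("  KL(h, g):", float(KL_g(α_test)), "vs quad:", KL_g_quad(α_test))
+    print("  KL(h, f):", float(KL_f(α_test)), "vs quad:", KL_f_quad(α_test))
+```
+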
 ```{code-cell} ipython3
 α_arr = np.linspace(0, 1, 100)
-KL_g_arr = KL_g(α_arr)
-KL_f_arr = KL_f(α_arr)
+KL_g_arr = KL_g_v(α_arr)
+KL_f_arr = KL_f_v(α_arr)
 
 fig, ax = plt.subplots(1, figsize=[10, 6])
-ax.plot(α_arr, KL_g_arr, label='KL(h, g)')
-ax.plot(α_arr, KL_f_arr, label='KL(h, f)')
-ax.set_ylabel('KL divergence')
-ax.set_xlabel(r'$\alpha$')
+ax.plot(α_arr, KL_g_arr, label="KL(h, g)")
+ax.plot(α_arr, KL_f_arr, label="KL(h, f)")
+ax.set_ylabel("KL divergence")
+ax.set_xlabel(r"$\alpha$")
 
-ax.legend(loc='upper right')
+ax.legend(loc="upper right")
 plt.show()
 ```
 
@@ -572,7 +594,7 @@ Let's compute an $\alpha$ for which the KL divergence between $h$ and $g$ is t
 
 ```{code-cell} ipython3
 # where KL_f = KL_g
-discretion = α_arr[np.argmin(np.abs(KL_g_arr-KL_f_arr))]
+discretion = α_arr[np.argmin(np.abs(KL_g_arr - KL_f_arr))]
 ```
 
 We can compute and plot the convergence point $\pi_{\infty}$ for each $\alpha$ to verify that the convergence is indeed governed by the KL divergence.
 
 recorded on the $x$ axis.
 
 Thus, the graph below confirms how a minimum KL divergence governs what our type 1 agent eventually learns.
 
-
 ```{code-cell} ipython3
-α_arr_x = α_arr[(α_arr<discretion)|(α_arr>discretion)]
+α_arr_x = α_arr[(α_arr < discretion) | (α_arr > discretion)]
 π_lim_arr = π_lim_v(α_arr_x)
 
 # plot
 fig, ax = plt.subplots(1, figsize=[10, 6])
 
-ax.plot(α_arr, KL_g_arr, label='KL(h, g)')
-ax.plot(α_arr, KL_f_arr, label='KL(h, f)')
-ax.set_ylabel('KL divergence')
-ax.set_xlabel(r'$\alpha$')
+ax.plot(α_arr, KL_g_arr, label="KL(h, g)")
+ax.plot(α_arr, KL_f_arr, label="KL(h, f)")
+ax.set_ylabel("KL divergence")
+ax.set_xlabel(r"$\alpha$")
 
 # plot KL
 ax2 = ax.twinx()
 
 # plot limit point
-ax2.scatter(α_arr_x, π_lim_arr,
-            facecolors='none',
-            edgecolors='tab:blue',
-            label=r'$\pi$ lim')
-ax2.set_ylabel('π lim')
+ax2.scatter(
+    α_arr_x,
+    π_lim_arr,
+    facecolors="none",
+    edgecolors="tab:blue",
+    label=r"$\pi$ lim",
+)
+ax2.set_ylabel("π lim")
 
 ax.legend(loc=[0.85, 0.8])
 ax2.legend(loc=[0.85, 0.73])
 plt.show()
 ```
 
@@ -664,11 +688,19 @@ We use the `Mixture` class in numpyro to construct the likelihood function.
 data = draw_lottery(α, 1000)
 sizes = [5, 20, 50, 200, 1000, 25000]
 
+
 def model(w):
-    α = numpyro.sample('α', dist.Uniform(low=0.0, high=1.0))
+    α = numpyro.sample("α", dist.Uniform(low=0.0, high=1.0))
+
+    y_samp = numpyro.sample(
+        "w",
+        dist.Mixture(
+            dist.Categorical(jnp.array([α, 1 - α])),
+            [dist.Beta(F_a, F_b), dist.Beta(G_a, G_b)],
+        ),
+        obs=w,
+    )
 
-    y_samp = numpyro.sample('w',
-                            dist.Mixture(dist.Categorical(jnp.array([α, 1-α])), [dist.Beta(F_a, F_b), dist.Beta(G_a, G_b)]), obs=w)
 
 def MCMC_run(ws):
     "Compute posterior using MCMC with observed ws"
@@ -678,24 +710,30 @@ def MCMC_run(ws):
     mcmc.run(rng_key=random.PRNGKey(142857), w=jnp.array(ws))
 
     sample = mcmc.get_samples()
-    return sample['α']
+    return sample["α"]
 ```
 
 The following code generates the graph below that displays Bayesian posteriors for $\alpha$ at various history lengths.
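+
+Before generating that graph, here is a quick sanity check (a minimal
+sketch reusing the parameters defined above, with an illustrative mixing
+weight of $0.5$) that the `Mixture` object assigns the same density as the
+mixture $\alpha f + (1-\alpha) g$; the two printed values should match:
+
+```{code-cell} ipython3
+mix = dist.Mixture(
+    dist.Categorical(jnp.array([0.5, 0.5])),
+    [dist.Beta(F_a, F_b), dist.Beta(G_a, G_b)],
+)
+w_test = jnp.array(0.3)
+print(jnp.exp(mix.log_prob(w_test)))
+print(0.5 * f(w_test) + 0.5 * g(w_test))
+```
+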
```{code-cell} ipython3 - fig, ax = plt.subplots(figsize=(10, 6)) for i in range(len(sizes)): - sample = MCMC_run(data[:sizes[i]]) + sample = MCMC_run(data[: sizes[i]]) sns.histplot( - data=sample, kde=True, stat='density', alpha=0.2, ax=ax, - color=colors[i], binwidth=0.02, linewidth=0.05, label=f't={sizes[i]}' + data=sample, + kde=True, + stat="density", + alpha=0.2, + ax=ax, + color=colors[i], + binwidth=0.02, + linewidth=0.05, + label=f"t={sizes[i]}", ) -ax.set_title(r'$\pi_t(\alpha)$ as $t$ increases') +ax.set_title(r"$\pi_t(\alpha)$ as $t$ increases") ax.legend() -ax.set_xlabel(r'$\alpha$') +ax.set_xlabel(r"$\alpha$") plt.show() ``` @@ -783,47 +821,62 @@ T_mix = 200 # Three different priors with means 0.25, 0.5, 0.75 prior_params = [(1, 3), (1, 1), (3, 1)] -prior_means = [a/(a+b) for a, b in prior_params] +prior_means = [a / (a + b) for a, b in prior_params] w_mix = draw_lottery(x_true, T_mix) ``` ```{code-cell} ipython3 -@jit +@jax.jit def learn_x_bayesian(observations, α0, β0, grid_size=2000): """ Sequential Bayesian learning of the mixing probability x using a grid approximation. """ - w = np.asarray(observations) + w = jnp.asarray(observations) T = w.size - x_grid = np.linspace(1e-3, 1 - 1e-3, grid_size) + x_grid = jnp.linspace(1e-3, 1 - 1e-3, grid_size) # Log prior - log_prior = (α0 - 1) * np.log(x_grid) + (β0 - 1) * np.log1p(-x_grid) - - μ_path = np.empty(T + 1) - μ_path[0] = α0 / (α0 + β0) + log_prior = (α0 - 1) * jnp.log(x_grid) + (β0 - 1) * jnp.log1p(-x_grid) - log_post = log_prior.copy() - - for t in range(T): - wt = w[t] + def scan_fn(log_post, wt): # P(w_t | x) = x f(w_t) + (1 - x) g(w_t) like = x_grid * f(wt) + (1 - x_grid) * g(wt) - log_post += np.log(like) + log_post = log_post + jnp.log(like) + + # Normalize using log-sum-exp trick + log_post = log_post - jax.nn.logsumexp(log_post) + post = jnp.exp(log_post) + + # Compute posterior mean + μ = x_grid @ post + + return log_post, μ - # normalize - log_post -= log_post.max() - post = np.exp(log_post) - post /= post.sum() + # Initial posterior mean + μ_0 = α0 / (α0 + β0) - μ_path[t + 1] = x_grid @ post + # Scan over observations + _, μ_path = jax.lax.scan(scan_fn, log_prior, w) - return μ_path + # Prepend initial value + return jnp.concatenate([jnp.array([μ_0]), μ_path]) -x_posterior_means = [learn_x_bayesian(w_mix, α0, β0) for α0, β0 in prior_params] + +# Vectorize over different prior parameters +def compute_all_posteriors(observations, prior_params): + """Compute posterior means for all prior parameter pairs.""" + + def single_posterior(params): + α0, β0 = params + return learn_x_bayesian(observations, α0, β0) + + return jax.vmap(single_posterior)(jnp.array(prior_params)) + + +x_posterior_means = compute_all_posteriors(w_mix, jnp.array(prior_params)) ``` Let's visualize how the posterior mean of $x$ evolves over time, starting from three different prior beliefs. 
@@ -832,14 +885,23 @@ Let's visualize how the posterior mean of $x$ evolves over time, starting from t fig, ax = plt.subplots(figsize=(10, 6)) for i, (x_means, mean0) in enumerate(zip(x_posterior_means, prior_means)): - ax.plot(range(T_mix + 1), x_means, - label=fr'Prior mean = ${mean0:.2f}$', - color=colors[i], linewidth=2) - -ax.axhline(y=x_true, color='black', linestyle='--', - label=f'True x = {x_true}', linewidth=2) -ax.set_xlabel('$t$') -ax.set_ylabel('Posterior mean of $x$') + ax.plot( + range(T_mix + 1), + x_means, + label=rf"Prior mean = ${mean0:.2f}$", + color=colors[i], + linewidth=2, + ) + +ax.axhline( + y=x_true, + color="black", + linestyle="--", + label=f"True x = {x_true}", + linewidth=2, +) +ax.set_xlabel("$t$") +ax.set_ylabel("Posterior mean of $x$") ax.legend() plt.show() ``` @@ -849,21 +911,25 @@ The plot shows that regardless of the initial prior belief, all three posterior Next, let's look at multiple simulations with a longer time horizon, all starting from a uniform prior. ```{code-cell} ipython3 -set_seed() n_paths = 20 T_long = 10_000 fig, ax = plt.subplots(figsize=(10, 5)) for j in range(n_paths): - w_path = draw_lottery(x_true, T_long) + w_path = draw_lottery(x_true, T_long, j) x_means = learn_x_bayesian(w_path, 1, 1) # Uniform prior ax.plot(range(T_long + 1), x_means, alpha=0.5, linewidth=1) -ax.axhline(y=x_true, color='red', linestyle='--', - label=f'True x = {x_true}', linewidth=2) -ax.set_ylabel('Posterior mean of $x$') -ax.set_xlabel('$t$') +ax.axhline( + y=x_true, + color="red", + linestyle="--", + label=f"True x = {x_true}", + linewidth=2, +) +ax.set_ylabel("Posterior mean of $x$") +ax.set_xlabel("$t$") ax.legend() plt.tight_layout() plt.show() From edcf32c83a38f773331e3801d4dde259287f5811 Mon Sep 17 00:00:00 2001 From: kp992 Date: Thu, 25 Dec 2025 18:10:21 -0800 Subject: [PATCH 2/2] fix minor typos --- lectures/mix_model.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lectures/mix_model.md b/lectures/mix_model.md index 6a5350f88..08d364fea 100644 --- a/lectures/mix_model.md +++ b/lectures/mix_model.md @@ -145,7 +145,7 @@ jax.config.update("jax_enable_x64", True) colors = sns.color_palette() ``` -Let's use Python to generate two beta distributions +Let's use Python to generate two Beta distributions. ```{code-cell} ipython3 :hide-output: false @@ -182,7 +182,7 @@ def simulate(a, b, T=50, N=500): return l_arr ``` -We’ll also use the following Python code to prepare some informative simulations +We'll also use the following Python code to prepare some informative simulations. ```{code-cell} ipython3 :hide-output: false @@ -296,7 +296,7 @@ plt.show() ## Type 1 Agent -We'll now study what our type 1 agent learns +We'll now study what our type 1 agent learns. Remember that our type 1 agent uses the wrong statistical model, thinking that nature mixed between $f$ and $g$ once and for all at time $-1$. @@ -471,13 +471,13 @@ The above graph shows a sample path of the log likelihood ratio process as the b sample paths of $\pi_t$ that start from two distinct initial conditions. -Let's see what happens when we change $\alpha$ +Let's see what happens when we change $\alpha$. ```{code-cell} ipython3 plot_π_seq(α=0.2) ``` -Evidently, $\alpha$ is having a big effect on the destination of $\pi_t$ as $t \rightarrow + \infty$ +Evidently, $\alpha$ is having a big effect on the destination of $\pi_t$ as $t \rightarrow +\infty$. 
## Kullback-Leibler Divergence Governs Limit of $\pi_t$ @@ -506,7 +506,7 @@ $$ \min_{f,g} \{KL_g, KL_f\} $$ The only possible limits are $0$ and $1$. -As $t \rightarrow +\infty$, $\pi_t$ goes to one if and only if $KL_f < KL_g$ +As $t \rightarrow +\infty$, $\pi_t$ goes to one if and only if $KL_f < KL_g$. ```{code-cell} ipython3 @jax.jit