lectures/exchangeable.md: 96 additions & 86 deletions
@@ -3,8 +3,10 @@ jupytext:
  text_representation:
    extension: .md
    format_name: myst
+    format_version: 0.13
+    jupytext_version: 1.16.7
kernelspec:
-  display_name: Python 3
+  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---
@@ -66,16 +68,21 @@ Below, we'll often use

Let’s start with some imports:

-```{code-cell} ipython
----
-tags: [hide-output]
----
+```{code-cell} ipython3
+:tags: [hide-output]

import matplotlib.pyplot as plt
-from numba import jit, vectorize
-from math import gamma
import scipy.optimize as op
from scipy.integrate import quad
import numpy as np
+import jax.numpy as jnp
+from jax.scipy.special import gamma
+import jax
```

+```{code-cell} ipython3
+# Enable JAX to use 64-bit float operations
+jax.config.update("jax_enable_x64", True)
+```

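An editorial note, not part of the PR: the cell above matters because quantities like cumulative likelihood ratios are products of many densities, and such products can underflow 32-bit floats. A minimal sketch of the failure mode:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# 0.1**50 = 1e-50 is representable in float64 but lies below the
# float32 range, where the same product flushes to zero.
x = jnp.full(50, 0.1)
print(jnp.prod(x))                       # ≈ 1e-50
print(jnp.prod(x.astype(jnp.float32)))   # 0.0 (underflow)
```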
## Independently and Identically Distributed
@@ -402,67 +409,69 @@ nature has chosen distribution $f$ – works.
To create the Python infrastructure to do our work for us, we construct a wrapper function that displays informative graphs
given parameters of $f$ and $g$.

-```{code-cell} python3
-@vectorize
+```{code-cell} ipython3
+@jax.jit
def p(x, a, b):
"The general beta distribution function."
"""The general beta distribution function."""
r = gamma(a + b) / (gamma(a) * gamma(b))
-    return r * x ** (a-1) * (1 - x) ** (b-1)
+    return r * x ** (a - 1) * (1 - x) ** (b - 1)
```

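As a quick editorial sanity check (not part of the PR), `p` should integrate to one for any admissible parameters. The lecture itself passes JIT-compiled functions straight to `quad`, so the same works here:

```python
from scipy.integrate import quad

# Verify that p(·, a, b) defined above is a proper density on [0, 1]
for a, b in [(1, 1), (3, 1.2), (2, 1.6)]:
    total = quad(lambda x: p(x, a, b), 0, 1)[0]
    print(f"a={a}, b={b}: integral ≈ {total:.6f}")
```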
```{code-cell} ipython3
def learning_example(F_a=1, F_b=1, G_a=3, G_b=1.2):
"""
A wrapper function that displays the updating rule of belief π,
given the parameters which specify F and G distributions.
"""

-    f = jit(lambda x: p(x, F_a, F_b))
-    g = jit(lambda x: p(x, G_a, G_b))
+    f = jax.jit(lambda x: p(x, F_a, F_b))
+    g = jax.jit(lambda x: p(x, G_a, G_b))

# l(w) = f(w) / g(w)
l = lambda w: f(w) / g(w)
# objective function for solving l(w) = 1
-    obj = lambda w: l(w) - 1
+    obj = jax.jit(lambda w: l(w) - 1)

x_grid = np.linspace(0, 1, 100)
-    π_grid = np.linspace(1e-3, 1-1e-3, 100)
+    π_grid = np.linspace(1e-3, 1 - 1e-3, 100)

w_max = 1
-    w_grid = np.linspace(1e-12, w_max-1e-12, 100)
+    w_grid = np.linspace(1e-12, w_max - 1e-12, 100)

# the mode of beta distribution
# use this to divide w into two intervals for root finding
G_mode = (G_a - 1) / (G_a + G_b - 2)
roots = np.empty(2)
roots[0] = op.root_scalar(obj, bracket=[1e-10, G_mode]).root
-    roots[1] = op.root_scalar(obj, bracket=[G_mode, 1-1e-10]).root
+    roots[1] = op.root_scalar(obj, bracket=[G_mode, 1 - 1e-10]).root

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

-    ax1.plot(l(w_grid), w_grid, label='$l$', lw=2)
-    ax1.vlines(1., 0., 1., linestyle="--")
-    ax1.hlines(roots, 0., 2., linestyle="--")
-    ax1.set_xlim([0., 2.])
+    ax1.plot(l(w_grid), w_grid, label="$l$", lw=2)
+    ax1.vlines(1.0, 0.0, 1.0, linestyle="--")
+    ax1.hlines(roots, 0.0, 2.0, linestyle="--")
+    ax1.set_xlim([0.0, 2.0])
ax1.legend(loc=4)
-    ax1.set(xlabel='$l(w)=f(w)/g(w)$', ylabel='$w$')
+    ax1.set(xlabel="$l(w)=f(w)/g(w)$", ylabel="$w$")

-    ax2.plot(f(x_grid), x_grid, label='$f$', lw=2)
-    ax2.plot(g(x_grid), x_grid, label='$g$', lw=2)
-    ax2.vlines(1., 0., 1., linestyle="--")
-    ax2.hlines(roots, 0., 2., linestyle="--")
+    ax2.plot(f(x_grid), x_grid, label="$f$", lw=2)
+    ax2.plot(g(x_grid), x_grid, label="$g$", lw=2)
+    ax2.vlines(1.0, 0.0, 1.0, linestyle="--")
+    ax2.hlines(roots, 0.0, 2.0, linestyle="--")
ax2.legend(loc=4)
-    ax2.set(xlabel='$f(w), g(w)$', ylabel='$w$')
+    ax2.set(xlabel="$f(w), g(w)$", ylabel="$w$")

area1 = quad(f, 0, roots[0])[0]
area2 = quad(g, roots[0], roots[1])[0]
area3 = quad(f, roots[1], 1)[0]

ax2.text((f(0) + f(roots[0])) / 4, roots[0] / 2, f"{area1: .3g}")
-    ax2.fill_between([0, 1], 0, roots[0], color='blue', alpha=0.15)
+    ax2.fill_between([0, 1], 0, roots[0], color="blue", alpha=0.15)
ax2.text(np.mean(g(roots)) / 2, np.mean(roots), f"{area2: .3g}")
w_roots = np.linspace(roots[0], roots[1], 20)
-    ax2.fill_betweenx(w_roots, 0, g(w_roots), color='orange', alpha=0.15)
+    ax2.fill_betweenx(w_roots, 0, g(w_roots), color="orange", alpha=0.15)
ax2.text((f(roots[1]) + f(1)) / 4, (roots[1] + 1) / 2, f"{area3: .3g}")
-    ax2.fill_between([0, 1], roots[1], 1, color='blue', alpha=0.15)
+    ax2.fill_between([0, 1], roots[1], 1, color="blue", alpha=0.15)

W = np.arange(0.01, 0.99, 0.08)
Π = np.arange(0.01, 0.99, 0.08)
@@ -474,13 +483,13 @@ def learning_example(F_a=1, F_b=1, G_a=3, G_b=1.2):
lw = l(w)
ΔΠ[i, j] = π * (lw / (π * lw + 1 - π) - 1)

-    q = ax3.quiver(Π, W, ΔΠ, ΔW, scale=2, color='r', alpha=0.8)
+    q = ax3.quiver(Π, W, ΔΠ, ΔW, scale=2, color="r", alpha=0.8)

-    ax3.fill_between(π_grid, 0, roots[0], color='blue', alpha=0.15)
-    ax3.fill_between(π_grid, roots[0], roots[1], color='green', alpha=0.15)
-    ax3.fill_between(π_grid, roots[1], w_max, color='blue', alpha=0.15)
-    ax3.hlines(roots, 0., 1., linestyle="--")
-    ax3.set(xlabel=r'$\pi$', ylabel='$w$')
+    ax3.fill_between(π_grid, 0, roots[0], color="blue", alpha=0.15)
+    ax3.fill_between(π_grid, roots[0], roots[1], color="green", alpha=0.15)
+    ax3.fill_between(π_grid, roots[1], w_max, color="blue", alpha=0.15)
+    ax3.hlines(roots, 0.0, 1.0, linestyle="--")
+    ax3.set(xlabel=r"$\pi$", ylabel="$w$")
ax3.grid()

plt.show()
Expand All @@ -490,14 +499,14 @@ Now we'll create a group of graphs that illustrate dynamics induced by Bayes' L

We'll begin with Python function default values of various objects, then change them in a subsequent example.

-```{code-cell} python3
+```{code-cell} ipython3
learning_example()
```

Please look at the three graphs above created for an instance in which $f$ is a uniform distribution on $[0,1]$
(i.e., a Beta distribution with parameters $F_a=1, F_b=1$), while $g$ is a Beta distribution with the default parameter values $G_a=3, G_b=1.2$.

-The graph on the left plots the likelihood ratio $l(w)$ as the absciassa axis against $w$ as the ordinate.
+The graph on the left plots the likelihood ratio $l(w)$ on the abscissa against $w$ on the ordinate.

The middle graph plots both $f(w)$ and $g(w)$ against $w$, with the horizontal dotted lines showing values
of $w$ at which the likelihood ratio equals $1$.
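To make the role of those dotted lines concrete, here is a self-contained editorial sketch (not part of the PR) that recomputes the two roots of $l(w)=1$ for the default parameters with the same scipy calls the lecture uses, then checks that a draw where $l(w)>1$ pushes the belief up while a draw where $l(w)<1$ pushes it down:

```python
from math import gamma
import scipy.optimize as op

def beta_pdf(x, a, b):
    return gamma(a + b) / (gamma(a) * gamma(b)) * x**(a - 1) * (1 - x)**(b - 1)

f = lambda w: beta_pdf(w, 1, 1)    # the default F: uniform on [0, 1]
g = lambda w: beta_pdf(w, 3, 1.2)  # the default G
l = lambda w: f(w) / g(w)

# bracket the two solutions of l(w) = 1 on either side of g's mode
G_mode = (3 - 1) / (3 + 1.2 - 2)
r0 = op.root_scalar(lambda w: l(w) - 1, bracket=[1e-10, G_mode]).root
r1 = op.root_scalar(lambda w: l(w) - 1, bracket=[G_mode, 1 - 1e-10]).root

def bayes_update(π, w):
    """One application of Bayes' law to a prior π after observing w."""
    return π * f(w) / (π * f(w) + (1 - π) * g(w))

print(r0, r1)                            # the dashed horizontal lines
print(bayes_update(0.5, r0 / 2))         # l(w) > 1 here: belief rises
print(bayes_update(0.5, (r0 + r1) / 2))  # l(w) < 1 here: belief falls
```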
@@ -531,7 +540,7 @@ Next we use our code to create graphs for another instance of our model.
We keep $F$ the same as in the preceding instance, namely a uniform distribution, but now assume that $G$
is a Beta distribution with parameters $G_a=2, G_b=1.6$.

-```{code-cell} python3
+```{code-cell} ipython3
learning_example(G_a=2, G_b=1.6)
```

@@ -552,71 +561,72 @@ Outcomes depend on a peculiar property of likelihood ratio processes discussed

To proceed, we create some Python code.

-```{code-cell} python3
+```{code-cell} ipython3
def function_factory(F_a=1, F_b=1, G_a=3, G_b=1.2):

# define f and g
-    f = jit(lambda x: p(x, F_a, F_b))
-    g = jit(lambda x: p(x, G_a, G_b))
-
-    @jit
-    def update(a, b, π):
-        "Update π by drawing from beta distribution with parameters a and b"
+    f = jax.jit(lambda x: p(x, F_a, F_b))
+    g = jax.jit(lambda x: p(x, G_a, G_b))
+
+    @jax.jit
+    def update(key, a, b, π):
+        """Update π by drawing from beta distribution with parameters a and b"""
# Draw
-        w = np.random.beta(a, b)
-
+        w = jax.random.beta(key, a, b)
# Update belief
π = 1 / (1 + ((1 - π) * g(w)) / (π * f(w)))

return π

-    @jit
-    def simulate_path(a, b, T=50):
-        "Simulates a path of beliefs π with length T"
-
-        π = np.empty(T+1)
+    @jax.jit
+    def simulate_path(keys, a, b):
+        """Simulate a path of beliefs π; the path length is set by len(keys)"""

-        # initial condition
-        π[0] = 0.5
+        # One Bayes update step, in the (carry, x) -> (carry, y) form
+        def scan_fn(π_prev, subkey):
+            π_new = update(subkey, a, b, π_prev)
+            return π_new, π_new

-        for t in range(1, T+1):
-            π[t] = update(a, b, π[t-1])
-
-        return π
+        # Scan the update over the keys, starting from the belief π = 0.5
+        _, π_path = jax.lax.scan(scan_fn, 0.5, keys)
+        # Prepend initial condition
+        return jnp.concatenate([jnp.array([0.5]), π_path])

def simulate(a=1, b=1, T=50, N=200, display=True):
"Simulates N paths of beliefs π with length T"
+        # Vectorize simulate_path over the first (path) axis of keys
+        simulate_path_vmap = jax.vmap(simulate_path, in_axes=(0, None, None))
+        # Derive a seed from a, b, T, N so that different experiment
+        # configurations get different draws
+        random_seed = int(a * b + T + N)
+        keys = jax.random.split(jax.random.PRNGKey(random_seed), (N, T))
+        # Simulate all N paths at once over the leading axis of keys
+        π_paths = simulate_path_vmap(keys, a, b)

-        π_paths = np.empty((N, T+1))
+        if display:
+            fig = plt.figure()
+            for i in range(N):
+                plt.plot(
+                    range(T + 1), π_paths[i], color="b", lw=0.8, alpha=0.5
+                )

-        for i in range(N):
-            π_paths[i] = simulate_path(a=a, b=b, T=T)
-            if display:
-                plt.plot(range(T+1), π_paths[i], color='b', lw=0.8, alpha=0.5)

if display:
plt.show()

return π_paths

return simulate
```

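Two JAX ingredients in the factory above deserve a gloss: `jax.lax.scan` threads the belief forward through a sequence of PRNG keys, replacing the sequential Python loop of the numba version, and `jax.vmap` batches whole paths over the leading axis of a key array. The toy sketch below (editorial, with made-up dynamics $x_{t+1} = 0.5\,x_t + \varepsilon_{t+1}$) isolates that pattern:

```python
import jax
import jax.numpy as jnp

def step(x, key):
    """One transition; the new state is both the carry and the recorded output."""
    x_new = 0.5 * x + 0.1 * jax.random.normal(key)
    return x_new, x_new

def path(keys):
    init = jnp.array(1.0)
    _, xs = jax.lax.scan(step, init, keys)     # one step per key
    return jnp.concatenate([init[None], xs])   # prepend the initial state

N, T = 4, 5
keys = jax.random.split(jax.random.PRNGKey(0), (N, T))  # one key per (path, step)
paths = jax.vmap(path)(keys)   # map over the path axis
print(paths.shape)             # (4, 6)
```

The same shape logic explains the `(N, T)` key array in `simulate`: `vmap` consumes axis 0 and `scan` consumes axis 1.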
-```{code-cell} python3
+```{code-cell} ipython3
simulate = function_factory()
```

We begin by generating $N$ simulated $\{\pi_t\}$ paths with $T$
periods when the sequence is truly IID draws from $F$. We set an initial prior $\pi_{-1} = .5$.

-```{code-cell} python3
+```{code-cell} ipython3
T = 50
```

-```{code-cell} python3
+```{code-cell} ipython3
# when nature selects F
π_paths_F = simulate(a=1, b=1, T=T, N=1000)
```
@@ -629,7 +639,7 @@ discovers the truth for most of our paths.
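An editorial check of that claim, using the `π_paths_F` computed above:

```python
# Share of paths whose terminal belief has essentially settled on the
# truth; the 0.99 threshold is an arbitrary editorial choice.
print(np.mean(π_paths_F[:, -1] > 0.99))
```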
Next, we generate paths with $T$
periods when the sequence is truly IID draws from $G$. Again, we set the initial prior $\pi_{-1} = .5$.

-```{code-cell} python3
+```{code-cell} ipython3
# when nature selects G
π_paths_G = simulate(a=3, b=1.2, T=T, N=1000)
```
@@ -647,9 +657,9 @@ Using $N$ simulated $\pi_t$ paths, we compute
$1 - \frac{1}{N}\sum_{i=1}^{N}\pi_{i,t}$ at each $t$ when the data are generated as draws from $F$
and compute $\frac{1}{N}\sum_{i=1}^{N}\pi_{i,t}$ when the data are generated as draws from $G$.

-```{code-cell} python3
-plt.plot(range(T+1), 1 - np.mean(π_paths_F, 0), label='F generates')
-plt.plot(range(T+1), np.mean(π_paths_G, 0), label='G generates')
+```{code-cell} ipython3
+plt.plot(range(T + 1), 1 - np.mean(π_paths_F, 0), label="F generates")
+plt.plot(range(T + 1), np.mean(π_paths_G, 0), label="G generates")
plt.legend()
plt.title("convergence");
```
@@ -673,23 +683,23 @@ where $a = f, g$.

The following code approximates the integral above:

-```{code-cell} python3
+```{code-cell} ipython3
def expected_ratio(F_a=1, F_b=1, G_a=3, G_b=1.2):

# define f and g
-    f = jit(lambda x: p(x, F_a, F_b))
-    g = jit(lambda x: p(x, G_a, G_b))
+    f = jax.jit(lambda x: p(x, F_a, F_b))
+    g = jax.jit(lambda x: p(x, G_a, G_b))

-    l = lambda w: f(w) / g(w)
-    integrand_f = lambda w, π: f(w) * l(w) / (π * l(w) + 1 - π)
-    integrand_g = lambda w, π: g(w) * l(w) / (π * l(w) + 1 - π)
+    l = jax.jit(lambda w: f(w) / g(w))
+    integrand_f = jax.jit(lambda w, π: f(w) * l(w) / (π * l(w) + 1 - π))
+    integrand_g = jax.jit(lambda w, π: g(w) * l(w) / (π * l(w) + 1 - π))

π_grid = np.linspace(0.02, 0.98, 100)

    expected_ratio = np.empty(len(π_grid))
for q, inte in zip(["f", "g"], [integrand_f, integrand_g]):
for i, π in enumerate(π_grid):
-            expected_rario[i]= quad(inte, 0, 1, args=(π,))[0]
+            expected_ratio[i] = quad(inte, 0, 1, args=(π,))[0]
        plt.plot(π_grid, expected_ratio, label=f"{q} generates")

plt.hlines(1, 0, 1, linestyle="--")
@@ -703,20 +713,20 @@ First, consider the case where $F_a=F_b=1$ and
First, consider the case where $F_a=F_b=1$ and
$G_a=3, G_b=1.2$.

-```{code-cell} python3
+```{code-cell} ipython3
expected_ratio()
```

The graph above shows that when $F$ generates the data, $\pi_t$ on average always heads north, while
when $G$ generates the data, $\pi_t$ heads south.

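That drift rests on a basic property of likelihood ratios: $E_g[l(w)] = \int g(w) \frac{f(w)}{g(w)} \, dw = 1$, while $E_f[l(w)] = \int f(w)^2/g(w) \, dw \geq 1$, with strict inequality whenever $f \neq g$. A quick editorial quadrature check, using $F_a=2, F_b=1$ (the parameters of the last instance below) so that the second integral is finite:

```python
from math import gamma
from scipy.integrate import quad

def beta_pdf(x, a, b):
    return gamma(a + b) / (gamma(a) * gamma(b)) * x**(a - 1) * (1 - x)**(b - 1)

f = lambda w: beta_pdf(w, 2, 1)
g = lambda w: beta_pdf(w, 3, 1.2)
l = lambda w: f(w) / g(w)

print(quad(lambda w: g(w) * l(w), 0, 1)[0])  # E_g[l] = ∫ f dw = 1
print(quad(lambda w: f(w) * l(w), 0, 1)[0])  # E_f[l] ≈ 1.18 > 1
```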
-Next, we'll look at a degenerate case in whcih $f$ and $g$ are identical beta
+Next, we'll look at a degenerate case in which $f$ and $g$ are identical beta
distributions, and $F_a=G_a=3, F_b=G_b=1.2$.

In a sense, here there
is nothing to learn.

-```{code-cell} python3
+```{code-cell} ipython3
expected_ratio(F_a=3, F_b=1.2)
```

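When $f = g$ the likelihood ratio is identically one, so a Bayes step leaves any belief unchanged and the expected ratio sits exactly on the dashed line (a one-line editorial check):

```python
# With l(w) = 1 for every w:  π' = π·1 / (π·1 + 1 - π) = π
π, lw = 0.3, 1.0
print(π * lw / (π * lw + 1 - π))  # 0.3
```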
@@ -726,7 +736,7 @@ Finally, let's look at a case in which $f$ and $g$ are neither very
different nor identical, in particular one in which $F_a=2, F_b=1$ and
$G_a=3, G_b=1.2$.

-```{code-cell} python3
+```{code-cell} ipython3
expected_ratio(F_a=2, F_b=1, G_a=3, G_b=1.2)
```
