Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 46 additions & 23 deletions code_to_optimize/discrete_riccati.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Utility functions used in CompEcon
"""Utility functions used in CompEcon

Based routines found in the CompEcon toolbox by Miranda and Fackler.

Expand All @@ -9,14 +8,16 @@
and Finance, MIT Press, 2002.

"""

from functools import reduce

import numba as nb
import numpy as np
import torch


def ckron(*arrays):
"""
Repeatedly applies the np.kron function to an arbitrary number of
"""Repeatedly applies the np.kron function to an arbitrary number of
input arrays

Parameters
Expand All @@ -43,8 +44,7 @@ def ckron(*arrays):


def gridmake(*arrays):
"""
Expands one or more vectors (or matrices) into a matrix where rows span the
"""Expands one or more vectors (or matrices) into a matrix where rows span the
cartesian product of combinations of the input arrays. Each column of the
input arrays will correspond to one column of the output matrix.

Expand Down Expand Up @@ -79,13 +79,11 @@ def gridmake(*arrays):
out = _gridmake2(out, arr)

return out
else:
raise NotImplementedError("Come back here")
raise NotImplementedError("Come back here")


def _gridmake2(x1, x2):
"""
Expands two vectors (or matrices) into a matrix where rows span the
"""Expands two vectors (or matrices) into a matrix where rows span the
cartesian product of combinations of the input arrays. Each column of the
input arrays will correspond to one column of the output matrix.

Expand Down Expand Up @@ -114,19 +112,14 @@ def _gridmake2(x1, x2):

"""
if x1.ndim == 1 and x2.ndim == 1:
return np.column_stack([np.tile(x1, x2.shape[0]),
np.repeat(x2, x1.shape[0])])
elif x1.ndim > 1 and x2.ndim == 1:
first = np.tile(x1, (x2.shape[0], 1))
second = np.repeat(x2, x1.shape[0])
return np.column_stack([first, second])
else:
raise NotImplementedError("Come back here")
return _gridmake2_1d_1d(x1, x2)
if x1.ndim > 1 and x2.ndim == 1:
return _gridmake2_2d_1d(x1, x2)
raise NotImplementedError("Come back here")


def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
"""
PyTorch version of _gridmake2.
"""PyTorch version of _gridmake2.

Expands two tensors into a matrix where rows span the cartesian product
of combinations of the input tensors. Each column of the input tensors
Expand Down Expand Up @@ -161,10 +154,40 @@ def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
first = x1.tile(x2.shape[0])
second = x2.repeat_interleave(x1.shape[0])
return torch.column_stack([first, second])
elif x1.dim() > 1 and x2.dim() == 1:
if x1.dim() > 1 and x2.dim() == 1:
# tile x1 along first dimension
first = x1.tile(x2.shape[0], 1)
second = x2.repeat_interleave(x1.shape[0])
return torch.column_stack([first, second])
else:
raise NotImplementedError("Come back here")
raise NotImplementedError("Come back here")


@nb.njit
def _gridmake2_1d_1d(x1, x2):
n1 = x1.shape[0]
n2 = x2.shape[0]
out = np.empty((n1 * n2, 2), dtype=x1.dtype)

for i in range(n2):
for j in range(n1):
out[i * n1 + j, 0] = x1[j]
out[i * n1 + j, 1] = x2[i]

return out


@nb.njit
def _gridmake2_2d_1d(x1, x2):
n1 = x1.shape[0]
n2 = x2.shape[0]
n_cols = x1.shape[1]
out = np.empty((n1 * n2, n_cols + 1), dtype=x1.dtype)

for i in range(n2):
for j in range(n1):
idx = i * n1 + j
for k in range(n_cols):
out[idx, k] = x1[j, k]
out[idx, n_cols] = x2[i]

return out
Loading