Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ogcore/default_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -4348,13 +4348,13 @@
"type": "float",
"value": [
{
"value": 1e-09
"value": 1e-03
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vahid-ahmadi I echo @nikhilwoodruff, let's not change the default values in OG-Core. These parameters can be adjusted for particular runs (e.g., you can set a looser tolerance if you want a quicker solution).

}
],
"validators": {
"range": {
"min": 1e-13,
"max": 0.001
"max": 0.1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This max value seems too high to allow. For example, module units for labor supply are typically less than 0.5. So this would mean an error of 20% or more is ok. I think that's too much for any simulation.

}
}
},
Expand Down Expand Up @@ -4384,13 +4384,13 @@
"type": "float",
"value": [
{
"value": 1e-08
"value": 1e-04
}
],
"validators": {
"range": {
"min": 1e-13,
"max": 0.001
"max": 0.1
}
}
},
Expand Down
217 changes: 148 additions & 69 deletions ogcore/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,56 @@
import logging
from ogcore import config

# Try to import Numba for JIT compilation (optional speedup)
try:
from numba import njit, prange
HAS_NUMBA = True
except ImportError:
HAS_NUMBA = False
# Fallback: decorator that does nothing
def njit(*args, **kwargs):
def decorator(func):
return func
return decorator
prange = range

"""
------------------------------------------------------------------------
Functions
------------------------------------------------------------------------
"""


# Numba-optimized core computation for marginal utility of consumption
@njit(cache=True)
def _marg_ut_cons_numba(c_flat, sigma, epsilon=0.003):
"""
Numba-optimized core computation for marginal utility of consumption.

Args:
c_flat (1D array): flattened consumption array
sigma (float): coefficient of relative risk aversion
epsilon (float): threshold for constraint handling

Returns:
MU_c (1D array): marginal utility values
"""
n = c_flat.shape[0]
MU_c = np.zeros(n)
b2 = (-sigma * (epsilon ** (-sigma - 1))) / 2
b1 = (epsilon ** (-sigma)) - 2 * b2 * epsilon

for i in prange(n):
if c_flat[i] < epsilon:
# Quadratic extrapolation for constrained values
MU_c[i] = 2 * b2 * c_flat[i] + b1
else:
# Normal marginal utility
MU_c[i] = c_flat[i] ** (-sigma)

return MU_c


def marg_ut_cons(c, sigma):
r"""
Compute the marginal utility of consumption.
Expand All @@ -34,105 +77,141 @@ def marg_ut_cons(c, sigma):
"""
if np.ndim(c) == 0:
c = np.array([c])
epsilon = 0.003
cvec_cnstr = c < epsilon
MU_c = np.zeros(c.shape)
MU_c[~cvec_cnstr] = c[~cvec_cnstr] ** (-sigma)
b2 = (-sigma * (epsilon ** (-sigma - 1))) / 2
b1 = (epsilon ** (-sigma)) - 2 * b2 * epsilon
MU_c[cvec_cnstr] = 2 * b2 * c[cvec_cnstr] + b1
output = MU_c

original_shape = c.shape
c_flat = c.ravel()

# Use Numba-optimized function
MU_c_flat = _marg_ut_cons_numba(c_flat, sigma)

# Reshape back to original shape
output = MU_c_flat.reshape(original_shape)
output = np.squeeze(output)

return output


def marg_ut_labor(n, chi_n, p):
r"""
Compute the marginal disutility of labor.

.. math::
MDU_{l} = \chi^n_{s}\biggl(\frac{b}{\tilde{l}}\biggr)
\biggl(\frac{n_{j,s,t}}{\tilde{l}}\biggr)^{\upsilon-1}
\Biggl[1-\biggl(\frac{n_{j,s,t}}{\tilde{l}}\biggr)^\upsilon
\Biggr]^{\frac{1-\upsilon}{\upsilon}}
# Numba-optimized core computation for marginal disutility of labor
@njit(cache=True)
def _marg_ut_labor_numba(nvec_flat, b_ellipse, ltilde, upsilon):
"""
Numba-optimized core computation for marginal disutility of labor.

Args:
n (array_like): household labor supply
chi_n (array_like): utility weights on disutility of labor
p (OG-Core Specifications object): model parameters
nvec_flat (1D array): flattened labor supply array
b_ellipse (float): ellipse parameter
ltilde (float): time endowment
upsilon (float): Frisch elasticity parameter

Returns:
output (array_like): marginal disutility of labor supply

MDU_n (1D array): marginal disutility values
"""
nvec = n
if np.ndim(nvec) == 0:
nvec = np.array([nvec])
n = nvec_flat.shape[0]
MDU_n = np.zeros(n)
eps_low = 0.000001
eps_high = p.ltilde - 0.000001
nvec_low = nvec < eps_low
nvec_high = nvec > eps_high
nvec_uncstr = np.logical_and(~nvec_low, ~nvec_high)
MDU_n = np.zeros(nvec.shape)
MDU_n[nvec_uncstr] = (
(p.b_ellipse / p.ltilde)
* ((nvec[nvec_uncstr] / p.ltilde) ** (p.upsilon - 1))
* (
(1 - ((nvec[nvec_uncstr] / p.ltilde) ** p.upsilon))
** ((1 - p.upsilon) / p.upsilon)
)
)
eps_high = ltilde - 0.000001

# Pre-compute coefficients for low constraint
b2 = (
0.5
* p.b_ellipse
* (p.ltilde ** (-p.upsilon))
* (p.upsilon - 1)
* (eps_low ** (p.upsilon - 2))
* b_ellipse
* (ltilde ** (-upsilon))
* (upsilon - 1)
* (eps_low ** (upsilon - 2))
* (
(1 - ((eps_low / p.ltilde) ** p.upsilon))
** ((1 - p.upsilon) / p.upsilon)
(1 - ((eps_low / ltilde) ** upsilon))
** ((1 - upsilon) / upsilon)
)
* (
1
+ ((eps_low / p.ltilde) ** p.upsilon)
* ((1 - ((eps_low / p.ltilde) ** p.upsilon)) ** (-1))
+ ((eps_low / ltilde) ** upsilon)
* ((1 - ((eps_low / ltilde) ** upsilon)) ** (-1))
)
)
b1 = (p.b_ellipse / p.ltilde) * (
(eps_low / p.ltilde) ** (p.upsilon - 1)
b1 = (b_ellipse / ltilde) * (
(eps_low / ltilde) ** (upsilon - 1)
) * (
(1 - ((eps_low / p.ltilde) ** p.upsilon))
** ((1 - p.upsilon) / p.upsilon)
) - (
2 * b2 * eps_low
)
MDU_n[nvec_low] = 2 * b2 * nvec[nvec_low] + b1
(1 - ((eps_low / ltilde) ** upsilon))
** ((1 - upsilon) / upsilon)
) - (2 * b2 * eps_low)

# Pre-compute coefficients for high constraint
d2 = (
0.5
* p.b_ellipse
* (p.ltilde ** (-p.upsilon))
* (p.upsilon - 1)
* (eps_high ** (p.upsilon - 2))
* b_ellipse
* (ltilde ** (-upsilon))
* (upsilon - 1)
* (eps_high ** (upsilon - 2))
* (
(1 - ((eps_high / p.ltilde) ** p.upsilon))
** ((1 - p.upsilon) / p.upsilon)
(1 - ((eps_high / ltilde) ** upsilon))
** ((1 - upsilon) / upsilon)
)
* (
1
+ ((eps_high / p.ltilde) ** p.upsilon)
* ((1 - ((eps_high / p.ltilde) ** p.upsilon)) ** (-1))
+ ((eps_high / ltilde) ** upsilon)
* ((1 - ((eps_high / ltilde) ** upsilon)) ** (-1))
)
)
d1 = (p.b_ellipse / p.ltilde) * (
(eps_high / p.ltilde) ** (p.upsilon - 1)
d1 = (b_ellipse / ltilde) * (
(eps_high / ltilde) ** (upsilon - 1)
) * (
(1 - ((eps_high / p.ltilde) ** p.upsilon))
** ((1 - p.upsilon) / p.upsilon)
) - (
2 * d2 * eps_high
(1 - ((eps_high / ltilde) ** upsilon))
** ((1 - upsilon) / upsilon)
) - (2 * d2 * eps_high)

for i in prange(n):
nval = nvec_flat[i]
if nval < eps_low:
MDU_n[i] = 2 * b2 * nval + b1
elif nval > eps_high:
MDU_n[i] = 2 * d2 * nval + d1
else:
# Unconstrained case
MDU_n[i] = (
(b_ellipse / ltilde)
* ((nval / ltilde) ** (upsilon - 1))
* (
(1 - ((nval / ltilde) ** upsilon))
** ((1 - upsilon) / upsilon)
)
)

return MDU_n


def marg_ut_labor(n, chi_n, p):
r"""
Compute the marginal disutility of labor.

.. math::
MDU_{l} = \chi^n_{s}\biggl(\frac{b}{\tilde{l}}\biggr)
\biggl(\frac{n_{j,s,t}}{\tilde{l}}\biggr)^{\upsilon-1}
\Biggl[1-\biggl(\frac{n_{j,s,t}}{\tilde{l}}\biggr)^\upsilon
\Biggr]^{\frac{1-\upsilon}{\upsilon}}

Args:
n (array_like): household labor supply
chi_n (array_like): utility weights on disutility of labor
p (OG-Core Specifications object): model parameters

Returns:
output (array_like): marginal disutility of labor supply

"""
nvec = n
if np.ndim(nvec) == 0:
nvec = np.array([nvec])

original_shape = nvec.shape
nvec_flat = nvec.ravel()

# Use Numba-optimized function
MDU_n_flat = _marg_ut_labor_numba(
nvec_flat, p.b_ellipse, p.ltilde, p.upsilon
)
MDU_n[nvec_high] = 2 * d2 * nvec[nvec_high] + d1

# Reshape and apply chi_n weights
MDU_n = MDU_n_flat.reshape(original_shape)
output = MDU_n * np.squeeze(chi_n)
output = np.squeeze(output)
return output
Expand Down
60 changes: 56 additions & 4 deletions ogcore/tax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,42 @@
from ogcore import utils, pensions
from ogcore.txfunc import get_tax_rates

# Try to import Numba for JIT compilation
try:
from numba import njit, vectorize, float64
HAS_NUMBA = True
except ImportError:
HAS_NUMBA = False
# Create dummy decorators if Numba not available
def njit(*args, **kwargs):
def decorator(func):
return func
if len(args) == 1 and callable(args[0]):
return args[0]
return decorator
def vectorize(*args, **kwargs):
def decorator(func):
return np.vectorize(func)
return decorator
float64 = None


# Numba-optimized wealth tax functions
if HAS_NUMBA:
@vectorize([float64(float64, float64, float64, float64)], nopython=True, cache=True)
def _etr_wealth_numba(b, h_wealth, m_wealth, p_wealth):
"""Numba-optimized effective tax rate on wealth (scalar)."""
return (p_wealth * h_wealth * b) / (h_wealth * b + m_wealth)

@vectorize([float64(float64, float64, float64, float64)], nopython=True, cache=True)
def _mtr_wealth_numba(b, h_wealth, m_wealth, p_wealth):
"""Numba-optimized marginal tax rate on wealth (scalar)."""
etr = (p_wealth * h_wealth * b) / (h_wealth * b + m_wealth)
return etr * 2 - ((h_wealth**2 * p_wealth * b**2) / ((b * h_wealth + m_wealth) ** 2))
else:
_etr_wealth_numba = None
_mtr_wealth_numba = None

"""
------------------------------------------------------------------------
Functions
Expand All @@ -34,7 +70,15 @@ def ETR_wealth(b, h_wealth, m_wealth, p_wealth):
tau_w (Numpy array): effective tax rate on wealth, size = SxJ

"""
tau_w = (p_wealth * h_wealth * b) / (h_wealth * b + m_wealth)
if HAS_NUMBA and _etr_wealth_numba is not None:
# Ensure arrays are float64 for Numba
b_arr = np.asarray(b, dtype=np.float64)
h_arr = np.float64(h_wealth) if np.isscalar(h_wealth) else np.asarray(h_wealth, dtype=np.float64)
m_arr = np.float64(m_wealth) if np.isscalar(m_wealth) else np.asarray(m_wealth, dtype=np.float64)
p_arr = np.float64(p_wealth) if np.isscalar(p_wealth) else np.asarray(p_wealth, dtype=np.float64)
tau_w = _etr_wealth_numba(b_arr, h_arr, m_arr, p_arr)
else:
tau_w = (p_wealth * h_wealth * b) / (h_wealth * b + m_wealth)

return tau_w

Expand All @@ -57,9 +101,17 @@ def MTR_wealth(b, h_wealth, m_wealth, p_wealth):
tau_prime (Numpy array): marginal tax rate on wealth, size = SxJ

"""
tau_prime = ETR_wealth(b, h_wealth, m_wealth, p_wealth) * 2 - (
(h_wealth**2 * p_wealth * b**2) / ((b * h_wealth + m_wealth) ** 2)
)
if HAS_NUMBA and _mtr_wealth_numba is not None:
# Ensure arrays are float64 for Numba
b_arr = np.asarray(b, dtype=np.float64)
h_arr = np.float64(h_wealth) if np.isscalar(h_wealth) else np.asarray(h_wealth, dtype=np.float64)
m_arr = np.float64(m_wealth) if np.isscalar(m_wealth) else np.asarray(m_wealth, dtype=np.float64)
p_arr = np.float64(p_wealth) if np.isscalar(p_wealth) else np.asarray(p_wealth, dtype=np.float64)
tau_prime = _mtr_wealth_numba(b_arr, h_arr, m_arr, p_arr)
else:
tau_prime = ETR_wealth(b, h_wealth, m_wealth, p_wealth) * 2 - (
(h_wealth**2 * p_wealth * b**2) / ((b * h_wealth + m_wealth) ** 2)
)

return tau_prime

Expand Down
Loading