Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
fixed:
- Optimisation improvements for loading tax-benefit systems (caching).
33 changes: 31 additions & 2 deletions policyengine_core/parameters/at_instant_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from policyengine_core import periods
from policyengine_core.periods import Instant

# Cache for instant -> string conversions used in get_at_instant
_instant_str_cache: dict = {}


class AtInstantLike(abc.ABC):
"""
Expand All @@ -14,8 +17,34 @@ def __call__(self, instant: Instant) -> Any:
return self.get_at_instant(instant)

def get_at_instant(self, instant: Instant) -> Any:
instant = str(periods.instant(instant))
return self._get_at_instant(instant)
# Fast path for Instant objects - use their __str__ which is cached
if isinstance(instant, Instant):
return self._get_at_instant(str(instant))

# For other types, use a cache to avoid repeated conversions
# Create a hashable cache key
cache_key = None
if isinstance(instant, str):
cache_key = instant
elif isinstance(instant, tuple):
cache_key = instant
elif isinstance(instant, int):
cache_key = (instant,)
elif hasattr(instant, "year"): # datetime.date
cache_key = (instant.year, instant.month, instant.day)

if cache_key is not None:
cached_str = _instant_str_cache.get(cache_key)
if cached_str is not None:
return self._get_at_instant(cached_str)
instant_obj = periods.instant(instant)
instant_str = str(instant_obj)
_instant_str_cache[cache_key] = instant_str
return self._get_at_instant(instant_str)

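# Fallback for inputs with no cache key (e.g. a list, which is unhashable).
# Note: Period is a tuple subclass, so it is caught by the tuple branch above.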
instant_str = str(periods.instant(instant))
return self._get_at_instant(instant_str)

@abc.abstractmethod
def _get_at_instant(self, instant): ...
21 changes: 14 additions & 7 deletions policyengine_core/parameters/operations/uprate_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,33 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
parameter.values_list[0].instant_str
)

# Pre-compute values that don't change in the loop
last_instant_str = str(last_instant)
value_at_start = parameter(last_instant)
uprater_at_start = uprating_parameter(last_instant)

if uprater_at_start is None:
raise ValueError(
f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} because the uprating parameter is not defined at {last_instant}."
)

# Check once whether rounding metadata is present, instead of on every loop iteration
has_rounding = "rounding" in meta

# For each defined instant in the uprating parameter
for entry in uprating_parameter.values_list[::-1]:
entry_instant = instant(entry.instant_str)
# If the uprater instant is defined after the last parameter instant
if entry_instant > last_instant:
# Apply the uprater and add to the parameter
value_at_start = parameter(last_instant)
uprater_at_start = uprating_parameter(last_instant)
if uprater_at_start is None:
raise ValueError(
f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} at {entry_instant} because the uprating parameter is not defined at {last_instant}."
)
uprater_at_entry = uprating_parameter(
entry_instant
)
uprater_change = (
uprater_at_entry / uprater_at_start
)
uprated_value = value_at_start * uprater_change
if "rounding" in meta:
if has_rounding:
uprated_value = round_uprated_value(
meta, uprated_value
)
Expand Down
42 changes: 31 additions & 11 deletions policyengine_core/periods/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from policyengine_core import periods
from policyengine_core.periods import config

# Module-level cache mapping (year, month, day) keys to Instant objects,
# so the same Instant is not constructed repeatedly
_instant_cache: dict = {}

@lru_cache(maxsize=1024)

@lru_cache(maxsize=10000)
def _instant_from_string(instant_str: str) -> "periods.Instant":
"""Cached parsing of instant strings."""
if not config.INSTANT_PATTERN.match(instant_str):
Expand Down Expand Up @@ -48,18 +51,35 @@ def instant(instant):
return instant
if isinstance(instant, str):
return _instant_from_string(instant)

# For other types, create a cache key and check the cache
cache_key = None
# Check Period before tuple since Period is a subclass of tuple
if isinstance(instant, periods.Period):
return instant.start
elif isinstance(instant, datetime.date):
instant = periods.Instant((instant.year, instant.month, instant.day))
cache_key = (instant.year, instant.month, instant.day)
elif isinstance(instant, int):
instant = (instant,)
elif isinstance(instant, list):
assert 1 <= len(instant) <= 3
instant = tuple(instant)
elif isinstance(instant, periods.Period):
instant = instant.start
else:
assert isinstance(instant, tuple), instant
assert 1 <= len(instant) <= 3
cache_key = (instant, 1, 1)
elif isinstance(instant, (tuple, list)):
if len(instant) == 1:
cache_key = (instant[0], 1, 1)
elif len(instant) == 2:
cache_key = (instant[0], instant[1], 1)
elif len(instant) == 3:
cache_key = tuple(instant)

if cache_key is not None:
cached = _instant_cache.get(cache_key)
if cached is not None:
return cached
result = periods.Instant(cache_key)
_instant_cache[cache_key] = result
return result

# Fallback: reached only when no cache key was built (an unsupported type,
# or a tuple/list whose length is not between 1 and 3)
assert isinstance(instant, tuple), instant
assert 1 <= len(instant) <= 3
if len(instant) == 1:
return periods.Instant((instant[0], 1, 1))
if len(instant) == 2:
Expand Down
17 changes: 7 additions & 10 deletions policyengine_core/periods/period_.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import calendar
from datetime import datetime
from datetime import datetime, date, timedelta
from typing import List

from policyengine_core import periods
Expand Down Expand Up @@ -463,15 +463,12 @@ def stop(self) -> periods.Instant:
return periods.Instant((float("inf"), float("inf"), float("inf")))
if unit == "day":
if size > 1:
day += size - 1
month_last_day = calendar.monthrange(year, month)[1]
while day > month_last_day:
month += 1
if month == 13:
year += 1
month = 1
day -= month_last_day
month_last_day = calendar.monthrange(year, month)[1]
# Use datetime arithmetic for efficient day calculation
start_date = date(year, month, day)
end_date = start_date + timedelta(days=size - 1)
return periods.Instant(
(end_date.year, end_date.month, end_date.day)
)
else:
if unit == "month":
month += size
Expand Down