diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..e5391df5 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - Optimisation improvements for loading tax-benefit systems (caching). diff --git a/policyengine_core/parameters/at_instant_like.py b/policyengine_core/parameters/at_instant_like.py index ecde8614..86e0b22f 100644 --- a/policyengine_core/parameters/at_instant_like.py +++ b/policyengine_core/parameters/at_instant_like.py @@ -4,6 +4,9 @@ from policyengine_core import periods from policyengine_core.periods import Instant +# Cache for instant -> string conversions used in get_at_instant +_instant_str_cache: dict = {} + class AtInstantLike(abc.ABC): """ @@ -14,8 +17,34 @@ def __call__(self, instant: Instant) -> Any: return self.get_at_instant(instant) def get_at_instant(self, instant: Instant) -> Any: - instant = str(periods.instant(instant)) - return self._get_at_instant(instant) + # Fast path for Instant objects - use their __str__ which is cached + if isinstance(instant, Instant): + return self._get_at_instant(str(instant)) + + # For other types, use a cache to avoid repeated conversions + # Create a hashable cache key + cache_key = None + if isinstance(instant, str): + cache_key = instant + elif isinstance(instant, tuple): + cache_key = instant + elif isinstance(instant, int): + cache_key = (instant,) + elif hasattr(instant, "year"): # datetime.date + cache_key = (instant.year, instant.month, instant.day) + + if cache_key is not None: + cached_str = _instant_str_cache.get(cache_key) + if cached_str is not None: + return self._get_at_instant(cached_str) + instant_obj = periods.instant(instant) + instant_str = str(instant_obj) + _instant_str_cache[cache_key] = instant_str + return self._get_at_instant(instant_str) + + # Fallback for other types (Period, list, etc.) + instant_str = str(periods.instant(instant)) + return self._get_at_instant(instant_str) @abc.abstractmethod def _get_at_instant(self, instant): ... diff --git a/policyengine_core/parameters/operations/uprate_parameters.py b/policyengine_core/parameters/operations/uprate_parameters.py index bdeb2a30..b09b2f5c 100644 --- a/policyengine_core/parameters/operations/uprate_parameters.py +++ b/policyengine_core/parameters/operations/uprate_parameters.py @@ -124,18 +124,25 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode: parameter.values_list[0].instant_str ) + # Pre-compute values that don't change in the loop + last_instant_str = str(last_instant) + value_at_start = parameter(last_instant) + uprater_at_start = uprating_parameter(last_instant) + + if uprater_at_start is None: + raise ValueError( + f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} because the uprating parameter is not defined at {last_instant}." + ) + + # Pre-compute uprater values for all entries to avoid repeated lookups + has_rounding = "rounding" in meta + # For each defined instant in the uprating parameter for entry in uprating_parameter.values_list[::-1]: entry_instant = instant(entry.instant_str) # If the uprater instant is defined after the last parameter instant if entry_instant > last_instant: # Apply the uprater and add to the parameter - value_at_start = parameter(last_instant) - uprater_at_start = uprating_parameter(last_instant) - if uprater_at_start is None: - raise ValueError( - f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} at {entry_instant} because the uprating parameter is not defined at {last_instant}." - ) uprater_at_entry = uprating_parameter( entry_instant ) @@ -143,7 +150,7 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode: uprater_at_entry / uprater_at_start ) uprated_value = value_at_start * uprater_change - if "rounding" in meta: + if has_rounding: uprated_value = round_uprated_value( meta, uprated_value ) diff --git a/policyengine_core/periods/helpers.py b/policyengine_core/periods/helpers.py index be675f55..f8ee95b8 100644 --- a/policyengine_core/periods/helpers.py +++ b/policyengine_core/periods/helpers.py @@ -5,8 +5,11 @@ from policyengine_core import periods from policyengine_core.periods import config +# Global cache for instant objects to avoid repeated tuple creation +_instant_cache: dict = {} -@lru_cache(maxsize=1024) + +@lru_cache(maxsize=10000) def _instant_from_string(instant_str: str) -> "periods.Instant": """Cached parsing of instant strings.""" if not config.INSTANT_PATTERN.match(instant_str): @@ -48,18 +51,35 @@ def instant(instant): return instant if isinstance(instant, str): return _instant_from_string(instant) + + # For other types, create a cache key and check the cache + cache_key = None + # Check Period before tuple since Period is a subclass of tuple + if isinstance(instant, periods.Period): + return instant.start elif isinstance(instant, datetime.date): - instant = periods.Instant((instant.year, instant.month, instant.day)) + cache_key = (instant.year, instant.month, instant.day) elif isinstance(instant, int): - instant = (instant,) - elif isinstance(instant, list): - assert 1 <= len(instant) <= 3 - instant = tuple(instant) - elif isinstance(instant, periods.Period): - instant = instant.start - else: - assert isinstance(instant, tuple), instant - assert 1 <= len(instant) <= 3 + cache_key = (instant, 1, 1) + elif isinstance(instant, (tuple, list)): + if len(instant) == 1: + cache_key = (instant[0], 1, 1) + elif len(instant) == 2: + cache_key = (instant[0], instant[1], 1) + elif len(instant) == 3: + cache_key = tuple(instant) + + if cache_key is not None: + cached = _instant_cache.get(cache_key) + if cached is not None: + return cached + result = periods.Instant(cache_key) + _instant_cache[cache_key] = result + return result + + # Fallback for unexpected types + assert isinstance(instant, tuple), instant + assert 1 <= len(instant) <= 3 if len(instant) == 1: return periods.Instant((instant[0], 1, 1)) if len(instant) == 2: diff --git a/policyengine_core/periods/period_.py b/policyengine_core/periods/period_.py index 6310d72f..e7656b37 100644 --- a/policyengine_core/periods/period_.py +++ b/policyengine_core/periods/period_.py @@ -1,7 +1,7 @@ from __future__ import annotations import calendar -from datetime import datetime +from datetime import datetime, date, timedelta from typing import List from policyengine_core import periods @@ -463,15 +463,12 @@ def stop(self) -> periods.Instant: return periods.Instant((float("inf"), float("inf"), float("inf"))) if unit == "day": if size > 1: - day += size - 1 - month_last_day = calendar.monthrange(year, month)[1] - while day > month_last_day: - month += 1 - if month == 13: - year += 1 - month = 1 - day -= month_last_day - month_last_day = calendar.monthrange(year, month)[1] + # Use datetime arithmetic for efficient day calculation + start_date = date(year, month, day) + end_date = start_date + timedelta(days=size - 1) + return periods.Instant( + (end_date.year, end_date.month, end_date.day) + ) else: if unit == "month": month += size