From e075101e48326bc291ed6952f755ad5f8f442e7e Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 30 Nov 2025 21:42:58 +0000 Subject: [PATCH 1/6] Optimise import and microsimulation init performance Three changes that together reduce import + Microsimulation() time by ~40%: 1. Enum encoding: replace np.select (O(n*m)) with np.searchsorted (O(n log m)) plus cached lookup arrays 2. empty_clone: replace dynamic type creation with object.__new__() 3. Period/instant parsing: add lru_cache to avoid repeated strptime calls --- policyengine_core/commons/misc.py | 14 +- policyengine_core/enums/enum.py | 46 ++++--- policyengine_core/periods/helpers.py | 187 +++++++++++++++------------ 3 files changed, 132 insertions(+), 115 deletions(-) diff --git a/policyengine_core/commons/misc.py b/policyengine_core/commons/misc.py index 668997ae2..6b2cb4790 100644 --- a/policyengine_core/commons/misc.py +++ b/policyengine_core/commons/misc.py @@ -28,19 +28,7 @@ def empty_clone(original: T) -> T: True """ - - Dummy: object - new: T - - Dummy = type( - "Dummy", - (original.__class__,), - {"__init__": lambda self: None}, - ) - - new = Dummy() - new.__class__ = original.__class__ - return new + return object.__new__(original.__class__) def stringify_array(array: ArrayType) -> str: diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index b66d7f237..42853240f 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum -from typing import Union +from functools import lru_cache +from typing import Tuple, Union import numpy as np from .config import ENUM_ARRAY_DTYPE from .enum_array import EnumArray @@ -23,6 +24,18 @@ def __init__(self, name: str) -> None: __eq__ = object.__eq__ __hash__ = object.__hash__ + @classmethod + @lru_cache(maxsize=None) + def _get_sorted_lookup_arrays(cls) -> Tuple[np.ndarray, np.ndarray]: + """Build cached sorted arrays for fast searchsorted-based lookup.""" + name_to_index = {item.name: item.index for item in cls} + sorted_names = sorted(name_to_index.keys()) + sorted_names_arr = np.array(sorted_names) + sorted_indices = np.array( + [name_to_index[n] for n in sorted_names], dtype=ENUM_ARRAY_DTYPE + ) + return sorted_names_arr, sorted_indices + @classmethod def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: """ @@ -49,34 +62,31 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: if isinstance(array, EnumArray): return array - # First, convert byte-string arrays to Unicode-string arrays - # Confusingly, Numpy uses "S" to refer to byte-string arrays - # and "U" to refer to Unicode-string arrays, which are also - # referred to as the "str" type - if isinstance(array[0], Enum): + # Handle Enum item arrays by extracting names + if len(array) > 0 and isinstance(array[0], Enum): array = np.array([item.name for item in array]) + + # Convert byte-strings or object arrays to Unicode strings if array.dtype.kind == "S" or array.dtype == object: - # Convert boolean array to string array array = array.astype(str) + if isinstance(array, np.ndarray) and array.dtype.kind in {"U", "S"}: - # String array - indices = np.select( - [array == item.name for item in cls], - [item.index for item in cls], - ) + # String array - use searchsorted for O(n log m) lookup + sorted_names, sorted_indices = cls._get_sorted_lookup_arrays() + positions = np.searchsorted(sorted_names, array) + indices = sorted_indices[positions] elif isinstance(array, np.ndarray) and array.dtype.kind == "O": - # Enum items array + # Object array containing Enum items if len(array) > 0: first_item = array[0] if cls.__name__ == type(first_item).__name__: - # Use the same Enum class as the array items cls = type(first_item) - indices = np.select( - [array == item for item in cls], - [item.index for item in cls], + # Extract indices directly from enum items + indices = np.array( + [item.index for item in array], dtype=ENUM_ARRAY_DTYPE ) elif array.dtype.kind in {"i", "u"}: - # Integer array + # Integer array - already indices indices = array else: raise ValueError(f"Unsupported array dtype: {array.dtype}") diff --git a/policyengine_core/periods/helpers.py b/policyengine_core/periods/helpers.py index 26bdc3d26..be675f556 100644 --- a/policyengine_core/periods/helpers.py +++ b/policyengine_core/periods/helpers.py @@ -1,10 +1,29 @@ import datetime import os +from functools import lru_cache from policyengine_core import periods from policyengine_core.periods import config +@lru_cache(maxsize=1024) +def _instant_from_string(instant_str: str) -> "periods.Instant": + """Cached parsing of instant strings.""" + if not config.INSTANT_PATTERN.match(instant_str): + raise ValueError( + "'{}' is not a valid instant. Instants are described using the 'YYYY-MM-DD' format, for instance '2015-06-15'.".format( + instant_str + ) + ) + parts = instant_str.split("-", 2)[:3] + if len(parts) == 1: + return periods.Instant((int(parts[0]), 1, 1)) + elif len(parts) == 2: + return periods.Instant((int(parts[0]), int(parts[1]), 1)) + else: + return periods.Instant((int(parts[0]), int(parts[1]), int(parts[2]))) + + def instant(instant): """Return a new instant, aka a triple of integers (year, month, day). @@ -28,15 +47,7 @@ def instant(instant): if isinstance(instant, periods.Instant): return instant if isinstance(instant, str): - if not config.INSTANT_PATTERN.match(instant): - raise ValueError( - "'{}' is not a valid instant. Instants are described using the 'YYYY-MM-DD' format, for instance '2015-06-15'.".format( - instant - ) - ) - instant = periods.Instant( - int(fragment) for fragment in instant.split("-", 2)[:3] - ) + return _instant_from_string(instant) elif isinstance(instant, datetime.date): instant = periods.Instant((instant.year, instant.month, instant.day)) elif isinstance(instant, int): @@ -67,108 +78,77 @@ def instant_date(instant): return instant_date -def period(value): - """Return a new period, aka a triple (unit, start_instant, size). - - >>> period('2014') - Period((YEAR, Instant((2014, 1, 1)), 1)) - >>> period('year:2014') - Period((YEAR, Instant((2014, 1, 1)), 1)) - - >>> period('2014-2') - Period((MONTH, Instant((2014, 2, 1)), 1)) - >>> period('2014-02') - Period((MONTH, Instant((2014, 2, 1)), 1)) - >>> period('month:2014-2') - Period((MONTH, Instant((2014, 2, 1)), 1)) - - >>> period('year:2014-2') - Period((YEAR, Instant((2014, 2, 1)), 1)) - """ - if isinstance(value, periods.Period): - return value - - if isinstance(value, periods.Instant): - return periods.Period((config.DAY, value, 1)) - - def parse_simple_period(value): - """ - Parses simple periods respecting the ISO format, such as 2012 or 2015-03 - """ +@lru_cache(maxsize=1024) +def _parse_simple_period(value: str): + """Cached parsing of simple periods respecting the ISO format.""" + try: + date = datetime.datetime.strptime(value, "%Y") + except ValueError: try: - date = datetime.datetime.strptime(value, "%Y") + date = datetime.datetime.strptime(value, "%Y-%m") except ValueError: try: - date = datetime.datetime.strptime(value, "%Y-%m") + date = datetime.datetime.strptime(value, "%Y-%m-%d") except ValueError: - try: - date = datetime.datetime.strptime(value, "%Y-%m-%d") - except ValueError: - return None - else: - return periods.Period( - ( - config.DAY, - periods.Instant((date.year, date.month, date.day)), - 1, - ) - ) + return None else: return periods.Period( ( - config.MONTH, - periods.Instant((date.year, date.month, 1)), + config.DAY, + periods.Instant((date.year, date.month, date.day)), 1, ) ) else: return periods.Period( - (config.YEAR, periods.Instant((date.year, date.month, 1)), 1) + ( + config.MONTH, + periods.Instant((date.year, date.month, 1)), + 1, + ) ) - - def raise_error(value): - message = os.linesep.join( - [ - "Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: '{}'.".format( - value - ), - "Learn more about legal period formats in OpenFisca:", - ".", - ] - ) - raise ValueError(message) - - if value == "ETERNITY" or value == config.ETERNITY: + else: return periods.Period( - ("eternity", instant(datetime.date.min), float("inf")) + (config.YEAR, periods.Instant((date.year, date.month, 1)), 1) ) - # check the type - if isinstance(value, int): - return periods.Period((config.YEAR, periods.Instant((value, 1, 1)), 1)) - if not isinstance(value, str): - raise_error(value) +def _raise_period_error(value): + message = os.linesep.join( + [ + "Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: '{}'.".format( + value + ), + "Learn more about legal period formats in OpenFisca:", + ".", + ] + ) + raise ValueError(message) + + +@lru_cache(maxsize=1024) +def _period_from_string(value: str) -> "periods.Period": + """Cached parsing of period strings.""" # try to parse as a simple period - period = parse_simple_period(value) - if period is not None: - return period + result = _parse_simple_period(value) + if result is not None: + return result # complex period must have a ':' in their strings if ":" not in value: - raise_error(value) + _raise_period_error(value) components = value.split(":") # left-most component must be a valid unit unit = components[0] if unit not in (config.DAY, config.MONTH, config.YEAR): - raise_error(value) + _raise_period_error(value) # middle component must be a valid iso period - base_period = parse_simple_period(components[1]) + base_period = _parse_simple_period(components[1]) if not base_period: - raise_error(value) + _raise_period_error(value) # period like year:2015-03 have a size of 1 if len(components) == 2: @@ -178,18 +158,57 @@ def raise_error(value): try: size = int(components[2]) except ValueError: - raise_error(value) + _raise_period_error(value) # if there is more than 2 ":" in the string, the period is invalid else: - raise_error(value) + _raise_period_error(value) # reject ambiguous period such as month:2014 if unit_weight(base_period.unit) > unit_weight(unit): - raise_error(value) + _raise_period_error(value) return periods.Period((unit, base_period.start, size)) +def period(value): + """Return a new period, aka a triple (unit, start_instant, size). + + >>> period('2014') + Period((YEAR, Instant((2014, 1, 1)), 1)) + >>> period('year:2014') + Period((YEAR, Instant((2014, 1, 1)), 1)) + + >>> period('2014-2') + Period((MONTH, Instant((2014, 2, 1)), 1)) + >>> period('2014-02') + Period((MONTH, Instant((2014, 2, 1)), 1)) + >>> period('month:2014-2') + Period((MONTH, Instant((2014, 2, 1)), 1)) + + >>> period('year:2014-2') + Period((YEAR, Instant((2014, 2, 1)), 1)) + """ + if isinstance(value, periods.Period): + return value + + if isinstance(value, periods.Instant): + return periods.Period((config.DAY, value, 1)) + + if value == "ETERNITY" or value == config.ETERNITY: + return periods.Period( + ("eternity", instant(datetime.date.min), float("inf")) + ) + + # check the type + if isinstance(value, int): + return periods.Period((config.YEAR, periods.Instant((value, 1, 1)), 1)) + + if isinstance(value, str): + return _period_from_string(value) + + _raise_period_error(value) + + def key_period_size(period): """ Defines a key in order to sort periods by length. It uses two aspects : first unit then size From f0e8f665cc1f40b232c248dde061f11700a2e4fb Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 30 Nov 2025 21:45:17 +0000 Subject: [PATCH 2/6] Add changelog entry --- changelog_entry.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..6863aab54 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: minor + changes: + changed: + - Optimised enum encoding using searchsorted instead of np.select. + - Optimised empty_clone using object.__new__() instead of dynamic type creation. + - Added lru_cache to period and instant string parsing. From 1b58bdd82a5db988bc51520c27a16f2bea08d2a2 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 30 Nov 2025 22:01:40 +0000 Subject: [PATCH 3/6] Vectorise random() function with PCG hash --- changelog_entry.yaml | 1 + policyengine_core/commons/formulas.py | 24 +++++++++++++----------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index 6863aab54..e356fb89a 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -4,3 +4,4 @@ - Optimised enum encoding using searchsorted instead of np.select. - Optimised empty_clone using object.__new__() instead of dynamic type creation. - Added lru_cache to period and instant string parsing. + - Vectorised random() function using PCG hash instead of per-entity RNG instantiation. diff --git a/policyengine_core/commons/formulas.py b/policyengine_core/commons/formulas.py index 34d89beab..2bd1c494d 100644 --- a/policyengine_core/commons/formulas.py +++ b/policyengine_core/commons/formulas.py @@ -333,17 +333,19 @@ def random(population): # Get entity IDs for the period entity_ids = population(f"{population.entity.key}_id", period) - # Generate random values for each entity - values = np.array( - [ - np.random.default_rng( - seed=int( - abs(id * 100 + population.simulation.count_random_calls) - ) - ).random() - for id in entity_ids - ] - ) + # Generate deterministic random values using vectorised hash + seeds = np.abs( + entity_ids * 100 + population.simulation.count_random_calls + ).astype(np.uint64) + + # PCG-style mixing function for high-quality pseudo-random generation + x = seeds * np.uint64(0x5851F42D4C957F2D) + x = x ^ (x >> np.uint64(33)) + x = x * np.uint64(0xC4CEB9FE1A85EC53) + x = x ^ (x >> np.uint64(33)) + + # Convert to float in [0, 1) using upper 53 bits for full double precision + values = (x >> np.uint64(11)).astype(np.float64) / (2**53) return values From 24858df3e978d3393f0e72b79ee4a7176568defd Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 30 Nov 2025 22:03:46 +0000 Subject: [PATCH 4/6] Fix enum encoding edge cases per review comments --- policyengine_core/enums/enum.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index 42853240f..fd0b1fc8b 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -62,9 +62,12 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: if isinstance(array, EnumArray): return array - # Handle Enum item arrays by extracting names + # Handle Enum item arrays by extracting indices directly if len(array) > 0 and isinstance(array[0], Enum): - array = np.array([item.name for item in array]) + indices = np.array( + [item.index for item in array], dtype=ENUM_ARRAY_DTYPE + ) + return EnumArray(indices, cls) # Convert byte-strings or object arrays to Unicode strings if array.dtype.kind == "S" or array.dtype == object: @@ -74,17 +77,13 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: # String array - use searchsorted for O(n log m) lookup sorted_names, sorted_indices = cls._get_sorted_lookup_arrays() positions = np.searchsorted(sorted_names, array) + # Clip positions to valid range to avoid IndexError + positions = np.clip(positions, 0, len(sorted_names) - 1) + # Validate that we found exact matches + if not np.all(sorted_names[positions] == array): + invalid = array[sorted_names[positions] != array] + raise ValueError(f"Invalid enum values: {invalid[:5]}") indices = sorted_indices[positions] - elif isinstance(array, np.ndarray) and array.dtype.kind == "O": - # Object array containing Enum items - if len(array) > 0: - first_item = array[0] - if cls.__name__ == type(first_item).__name__: - cls = type(first_item) - # Extract indices directly from enum items - indices = np.array( - [item.index for item in array], dtype=ENUM_ARRAY_DTYPE - ) elif array.dtype.kind in {"i", "u"}: # Integer array - already indices indices = array From 77878fb6f2b93a98559479c38ef3498349977af0 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 30 Nov 2025 22:09:49 +0000 Subject: [PATCH 5/6] Return 0 for invalid enum values to match old np.select behaviour --- policyengine_core/enums/enum.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index fd0b1fc8b..00e6e498f 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -79,11 +79,9 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: positions = np.searchsorted(sorted_names, array) # Clip positions to valid range to avoid IndexError positions = np.clip(positions, 0, len(sorted_names) - 1) - # Validate that we found exact matches - if not np.all(sorted_names[positions] == array): - invalid = array[sorted_names[positions] != array] - raise ValueError(f"Invalid enum values: {invalid[:5]}") - indices = sorted_indices[positions] + # For non-matches, return 0 (first enum value) to match old np.select behaviour + matches = sorted_names[positions] == array + indices = np.where(matches, sorted_indices[positions], 0) elif array.dtype.kind in {"i", "u"}: # Integer array - already indices indices = array From 84762859aa6b3885500b516cbf8bf12bbe17eb94 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Mon, 1 Dec 2025 10:12:47 -0500 Subject: [PATCH 6/6] Add warning for invalid enum values and document random() change MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Log warning when encoding invalid enum string values (they default to 0) - Add tests for invalid enum value warning - Document in changelog that random() now produces different sequences 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- changelog_entry.yaml | 3 ++- policyengine_core/enums/enum.py | 12 +++++++++ tests/core/enums/test_enum.py | 44 +++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e356fb89a..cfe3043b5 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -4,4 +4,5 @@ - Optimised enum encoding using searchsorted instead of np.select. - Optimised empty_clone using object.__new__() instead of dynamic type creation. - Added lru_cache to period and instant string parsing. - - Vectorised random() function using PCG hash instead of per-entity RNG instantiation. + - Vectorised random() function using PCG hash instead of per-entity RNG instantiation. Note that this changes random value sequences - simulations using random() will produce different (but still deterministic) values. + - Added warning logging for invalid enum string values during encoding. diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index 00e6e498f..3072d9a27 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -1,11 +1,14 @@ from __future__ import annotations import enum +import logging from functools import lru_cache from typing import Tuple, Union import numpy as np from .config import ENUM_ARRAY_DTYPE from .enum_array import EnumArray +log = logging.getLogger(__name__) + class Enum(enum.Enum): """ @@ -82,6 +85,15 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: # For non-matches, return 0 (first enum value) to match old np.select behaviour matches = sorted_names[positions] == array indices = np.where(matches, sorted_indices[positions], 0) + # Log warning for invalid values + invalid_mask = ~matches + if np.any(invalid_mask): + invalid_values = np.unique(array[invalid_mask]) + log.warning( + f"Invalid values for enum {cls.__name__}: " + f"{invalid_values.tolist()}. " + f"These will be encoded as index 0." + ) elif array.dtype.kind in {"i", "u"}: # Integer array - already indices indices = array diff --git a/tests/core/enums/test_enum.py b/tests/core/enums/test_enum.py index 0dde9bac4..6c8833b38 100644 --- a/tests/core/enums/test_enum.py +++ b/tests/core/enums/test_enum.py @@ -1,5 +1,6 @@ import pytest import numpy as np +import logging from policyengine_core.enums.enum import Enum from policyengine_core.enums.enum_array import EnumArray @@ -37,3 +38,46 @@ class Sample(Enum): assert len(encoded_array) == 3 assert isinstance(encoded_array, EnumArray) assert encoded_array.dtype.kind == "i" + + +def test_enum_encode_invalid_values_logs_warning(caplog): + """Test that encoding invalid enum string values logs a warning.""" + + class Sample(Enum): + MAXWELL = "maxwell" + DWORKIN = "dworkin" + + # Array with invalid values mixed in + array_with_invalid = np.array(["MAXWELL", "INVALID_VALUE", "DWORKIN"]) + + with caplog.at_level(logging.WARNING): + encoded = Sample.encode(array_with_invalid) + + # Should still return an array (with 0 for invalid) + assert len(encoded) == 3 + assert encoded[0] == Sample.MAXWELL.index + assert encoded[1] == 0 # Invalid defaults to 0 + assert encoded[2] == Sample.DWORKIN.index + + # Should have logged a warning + assert any("INVALID_VALUE" in record.message for record in caplog.records) + assert any("Sample" in record.message for record in caplog.records) + + +def test_enum_encode_all_invalid_logs_warning(caplog): + """Test that encoding all invalid values logs a warning.""" + + class Sample(Enum): + MAXWELL = "maxwell" + DWORKIN = "dworkin" + + all_invalid = np.array(["FOO", "BAR", "BAZ"]) + + with caplog.at_level(logging.WARNING): + encoded = Sample.encode(all_invalid) + + # All should be 0 + assert all(encoded == 0) + + # Should have logged warnings + assert len(caplog.records) > 0