Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- bump: minor
changes:
changed:
- Optimised enum encoding using searchsorted instead of np.select.
- Optimised empty_clone using object.__new__() instead of dynamic type creation.
- Added lru_cache to period and instant string parsing.
- Vectorised random() function using PCG hash instead of per-entity RNG instantiation. Note that this changes random value sequences - simulations using random() will produce different (but still deterministic) values.
- Added warning logging for invalid enum string values during encoding.
24 changes: 13 additions & 11 deletions policyengine_core/commons/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,17 +333,19 @@ def random(population):
# Get entity IDs for the period
entity_ids = population(f"{population.entity.key}_id", period)

# Generate random values for each entity
values = np.array(
[
np.random.default_rng(
seed=int(
abs(id * 100 + population.simulation.count_random_calls)
)
).random()
for id in entity_ids
]
)
# Generate deterministic random values using vectorised hash
seeds = np.abs(
entity_ids * 100 + population.simulation.count_random_calls
).astype(np.uint64)

# PCG-style mixing function for high-quality pseudo-random generation
x = seeds * np.uint64(0x5851F42D4C957F2D)
x = x ^ (x >> np.uint64(33))
x = x * np.uint64(0xC4CEB9FE1A85EC53)
x = x ^ (x >> np.uint64(33))

# Convert to float in [0, 1) using upper 53 bits for full double precision
values = (x >> np.uint64(11)).astype(np.float64) / (2**53)

return values

Expand Down
14 changes: 1 addition & 13 deletions policyengine_core/commons/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,7 @@ def empty_clone(original: T) -> T:
True

"""

Dummy: object
new: T

Dummy = type(
"Dummy",
(original.__class__,),
{"__init__": lambda self: None},
)

new = Dummy()
new.__class__ = original.__class__
return new
return object.__new__(original.__class__)


def stringify_array(array: ArrayType) -> str:
Expand Down
69 changes: 44 additions & 25 deletions policyengine_core/enums/enum.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations
import enum
from typing import Union
import logging
from functools import lru_cache
from typing import Tuple, Union
import numpy as np
from .config import ENUM_ARRAY_DTYPE
from .enum_array import EnumArray

log = logging.getLogger(__name__)


class Enum(enum.Enum):
"""
Expand All @@ -23,6 +27,18 @@ def __init__(self, name: str) -> None:
__eq__ = object.__eq__
__hash__ = object.__hash__

@classmethod
@lru_cache(maxsize=None)
def _get_sorted_lookup_arrays(cls) -> Tuple[np.ndarray, np.ndarray]:
"""Build cached sorted arrays for fast searchsorted-based lookup."""
name_to_index = {item.name: item.index for item in cls}
sorted_names = sorted(name_to_index.keys())
sorted_names_arr = np.array(sorted_names)
sorted_indices = np.array(
[name_to_index[n] for n in sorted_names], dtype=ENUM_ARRAY_DTYPE
)
return sorted_names_arr, sorted_indices

@classmethod
def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray:
"""
Expand All @@ -49,34 +65,37 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray:
if isinstance(array, EnumArray):
return array

# First, convert byte-string arrays to Unicode-string arrays
# Confusingly, Numpy uses "S" to refer to byte-string arrays
# and "U" to refer to Unicode-string arrays, which are also
# referred to as the "str" type
if isinstance(array[0], Enum):
array = np.array([item.name for item in array])
# Handle Enum item arrays by extracting indices directly
if len(array) > 0 and isinstance(array[0], Enum):
indices = np.array(
[item.index for item in array], dtype=ENUM_ARRAY_DTYPE
)
return EnumArray(indices, cls)

# Convert byte-strings or object arrays to Unicode strings
if array.dtype.kind == "S" or array.dtype == object:
# Convert boolean array to string array
array = array.astype(str)

if isinstance(array, np.ndarray) and array.dtype.kind in {"U", "S"}:
# String array
indices = np.select(
[array == item.name for item in cls],
[item.index for item in cls],
)
elif isinstance(array, np.ndarray) and array.dtype.kind == "O":
# Enum items array
if len(array) > 0:
first_item = array[0]
if cls.__name__ == type(first_item).__name__:
# Use the same Enum class as the array items
cls = type(first_item)
indices = np.select(
[array == item for item in cls],
[item.index for item in cls],
)
# String array - use searchsorted for O(n log m) lookup
sorted_names, sorted_indices = cls._get_sorted_lookup_arrays()
positions = np.searchsorted(sorted_names, array)
# Clip positions to valid range to avoid IndexError
positions = np.clip(positions, 0, len(sorted_names) - 1)
# For non-matches, return 0 (first enum value) to match old np.select behaviour
matches = sorted_names[positions] == array
indices = np.where(matches, sorted_indices[positions], 0)
# Log warning for invalid values
invalid_mask = ~matches
if np.any(invalid_mask):
invalid_values = np.unique(array[invalid_mask])
log.warning(
f"Invalid values for enum {cls.__name__}: "
f"{invalid_values.tolist()}. "
f"These will be encoded as index 0."
)
elif array.dtype.kind in {"i", "u"}:
# Integer array
# Integer array - already indices
indices = array
else:
raise ValueError(f"Unsupported array dtype: {array.dtype}")
Expand Down
Loading