Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
changed:
- Invalid enum values now raise ValueError instead of logging a warning and returning index 0. This prevents silent data corruption when incorrect enum strings are passed to simulations.
10 changes: 5 additions & 5 deletions policyengine_core/enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray:
# Non-matches get placeholder index 0 here (as the old np.select did); they are rejected with a ValueError below
matches = sorted_names[positions] == array
indices = np.where(matches, sorted_indices[positions], 0)
# Log warning for invalid values
# Raise error for invalid values to prevent silent data corruption
invalid_mask = ~matches
if np.any(invalid_mask):
invalid_values = np.unique(array[invalid_mask])
log.warning(
f"Invalid values for enum {cls.__name__}: "
f"{invalid_values.tolist()}. "
f"These will be encoded as index 0."
valid_names = [item.name for item in cls]
raise ValueError(
f"Invalid value(s) {invalid_values.tolist()} for enum "
f"{cls.__name__}. Valid values are: {valid_names}"
)
elif array.dtype.kind in {"i", "u"}:
# Integer array - already indices
Expand Down
57 changes: 35 additions & 22 deletions tests/core/enums/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
import numpy as np
import logging
from policyengine_core.enums.enum import Enum
from policyengine_core.enums.enum_array import EnumArray

Expand Down Expand Up @@ -40,8 +39,8 @@ class Sample(Enum):
assert encoded_array.dtype.kind == "i"


def test_enum_encode_invalid_values_logs_warning(caplog):
"""Test that encoding invalid enum string values logs a warning."""
def test_enum_encode_invalid_values_raises_error():
"""Test that encoding invalid enum string values raises ValueError."""

class Sample(Enum):
MAXWELL = "maxwell"
Expand All @@ -50,34 +49,48 @@ class Sample(Enum):
# Array with invalid values mixed in
array_with_invalid = np.array(["MAXWELL", "INVALID_VALUE", "DWORKIN"])

with caplog.at_level(logging.WARNING):
encoded = Sample.encode(array_with_invalid)
with pytest.raises(ValueError) as exc_info:
Sample.encode(array_with_invalid)

# Should still return an array (with 0 for invalid)
assert len(encoded) == 3
assert encoded[0] == Sample.MAXWELL.index
assert encoded[1] == 0 # Invalid defaults to 0
assert encoded[2] == Sample.DWORKIN.index
error_message = str(exc_info.value)
assert "INVALID_VALUE" in error_message
assert "Sample" in error_message
assert "MAXWELL" in error_message # Valid values listed
assert "DWORKIN" in error_message # Valid values listed

# Should have logged a warning
assert any("INVALID_VALUE" in record.message for record in caplog.records)
assert any("Sample" in record.message for record in caplog.records)


def test_enum_encode_all_invalid_logs_warning(caplog):
"""Test that encoding all invalid values logs a warning."""
def test_enum_encode_all_invalid_raises_error():
"""Test that encoding all invalid values raises ValueError."""

class Sample(Enum):
MAXWELL = "maxwell"
DWORKIN = "dworkin"

all_invalid = np.array(["FOO", "BAR", "BAZ"])

with caplog.at_level(logging.WARNING):
encoded = Sample.encode(all_invalid)
with pytest.raises(ValueError) as exc_info:
Sample.encode(all_invalid)

error_message = str(exc_info.value)
# Should mention at least one of the invalid values (the check below is an `or`, not an `and`)
assert (
"FOO" in error_message
or "BAR" in error_message
or "BAZ" in error_message
)


def test_enum_encode_empty_string_raises_error():
"""Test that encoding empty strings raises ValueError."""

class Sample(Enum):
MAXWELL = "maxwell"
DWORKIN = "dworkin"

array_with_empty = np.array(["MAXWELL", "", "DWORKIN"])

# All should be 0
assert all(encoded == 0)
with pytest.raises(ValueError) as exc_info:
Sample.encode(array_with_empty)

# Should have logged warnings
assert len(caplog.records) > 0
# Empty string should be in the error message (represented as '')
assert "''" in str(exc_info.value) or '""' in str(exc_info.value)