diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..c3008017 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + changed: + - Invalid enum values now raise ValueError instead of logging a warning and returning index 0. This prevents silent data corruption when incorrect enum strings are passed to simulations. diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index 3072d9a2..5b6e4b9e 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -85,14 +85,14 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: # For non-matches, return 0 (first enum value) to match old np.select behaviour matches = sorted_names[positions] == array indices = np.where(matches, sorted_indices[positions], 0) - # Log warning for invalid values + # Raise error for invalid values to prevent silent data corruption invalid_mask = ~matches if np.any(invalid_mask): invalid_values = np.unique(array[invalid_mask]) - log.warning( - f"Invalid values for enum {cls.__name__}: " - f"{invalid_values.tolist()}. " - f"These will be encoded as index 0." + valid_names = [item.name for item in cls] + raise ValueError( + f"Invalid value(s) {invalid_values.tolist()} for enum " + f"{cls.__name__}. 
Valid values are: {valid_names}" ) elif array.dtype.kind in {"i", "u"}: # Integer array - already indices diff --git a/tests/core/enums/test_enum.py b/tests/core/enums/test_enum.py index 6c8833b3..3605ef6d 100644 --- a/tests/core/enums/test_enum.py +++ b/tests/core/enums/test_enum.py @@ -1,6 +1,5 @@ import pytest import numpy as np -import logging from policyengine_core.enums.enum import Enum from policyengine_core.enums.enum_array import EnumArray @@ -40,8 +39,8 @@ class Sample(Enum): assert encoded_array.dtype.kind == "i" -def test_enum_encode_invalid_values_logs_warning(caplog): - """Test that encoding invalid enum string values logs a warning.""" +def test_enum_encode_invalid_values_raises_error(): + """Test that encoding invalid enum string values raises ValueError.""" class Sample(Enum): MAXWELL = "maxwell" @@ -50,22 +49,18 @@ class Sample(Enum): # Array with invalid values mixed in array_with_invalid = np.array(["MAXWELL", "INVALID_VALUE", "DWORKIN"]) - with caplog.at_level(logging.WARNING): - encoded = Sample.encode(array_with_invalid) + with pytest.raises(ValueError) as exc_info: + Sample.encode(array_with_invalid) - # Should still return an array (with 0 for invalid) - assert len(encoded) == 3 - assert encoded[0] == Sample.MAXWELL.index - assert encoded[1] == 0 # Invalid defaults to 0 - assert encoded[2] == Sample.DWORKIN.index + error_message = str(exc_info.value) + assert "INVALID_VALUE" in error_message + assert "Sample" in error_message + assert "MAXWELL" in error_message # Valid values listed + assert "DWORKIN" in error_message # Valid values listed - # Should have logged a warning - assert any("INVALID_VALUE" in record.message for record in caplog.records) - assert any("Sample" in record.message for record in caplog.records) - -def test_enum_encode_all_invalid_logs_warning(caplog): - """Test that encoding all invalid values logs a warning.""" +def test_enum_encode_all_invalid_raises_error(): + """Test that encoding all invalid values raises 
ValueError.""" class Sample(Enum): MAXWELL = "maxwell" @@ -73,11 +68,29 @@ class Sample(Enum): all_invalid = np.array(["FOO", "BAR", "BAZ"]) - with caplog.at_level(logging.WARNING): - encoded = Sample.encode(all_invalid) + with pytest.raises(ValueError) as exc_info: + Sample.encode(all_invalid) + + error_message = str(exc_info.value) + # Should mention at least one of the invalid values + assert ( + "FOO" in error_message + or "BAR" in error_message + or "BAZ" in error_message + ) + + +def test_enum_encode_empty_string_raises_error(): + """Test that encoding empty strings raises ValueError.""" + + class Sample(Enum): + MAXWELL = "maxwell" + DWORKIN = "dworkin" + + array_with_empty = np.array(["MAXWELL", "", "DWORKIN"]) - # All should be 0 - assert all(encoded == 0) + with pytest.raises(ValueError) as exc_info: + Sample.encode(array_with_empty) - # Should have logged warnings - assert len(caplog.records) > 0 + # Empty string should be in the error message (represented as '') + assert "''" in str(exc_info.value) or '""' in str(exc_info.value)