From c01e0f9b47501c96c23cd6e3c8b6c056864ea7e1 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 3 Dec 2025 16:59:06 -0500 Subject: [PATCH] Raise ValueError for invalid enum values instead of warning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change the behavior from logging a warning and returning index 0 to raising a ValueError with a clear message. This prevents silent data corruption when incorrect enum values are passed to simulations. The previous behavior (introduced in the searchsorted refactor) would: 1. Log a warning about invalid values 2. Return 0 (first enum value) for the invalid entries 3. Continue execution with corrupted data This was an improvement over the original np.select behavior (which silently returned 0 without any warning), but still allowed simulations to run with incorrect data. Now invalid enum values will raise: ValueError: Invalid value(s) ['MARRIED'] for enum FilingStatus. Valid values are: ['SINGLE', 'JOINT', 'SEPARATE', ...] Fixes #410 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- changelog_entry.yaml | 4 +++ policyengine_core/enums/enum.py | 10 +++--- tests/core/enums/test_enum.py | 57 ++++++++++++++++++++------------- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..c3008017 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + changed: + - Invalid enum values now raise ValueError instead of logging a warning and returning index 0. This prevents silent data corruption when incorrect enum strings are passed to simulations. diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index 3072d9a2..5b6e4b9e 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -85,14 +85,14 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: # For non-matches, return 0 (first enum value) to match old np.select behaviour matches = sorted_names[positions] == array indices = np.where(matches, sorted_indices[positions], 0) - # Log warning for invalid values + # Raise error for invalid values to prevent silent data corruption invalid_mask = ~matches if np.any(invalid_mask): invalid_values = np.unique(array[invalid_mask]) - log.warning( - f"Invalid values for enum {cls.__name__}: " - f"{invalid_values.tolist()}. " - f"These will be encoded as index 0." + valid_names = [item.name for item in cls] + raise ValueError( + f"Invalid value(s) {invalid_values.tolist()} for enum " + f"{cls.__name__}. Valid values are: {valid_names}" ) elif array.dtype.kind in {"i", "u"}: # Integer array - already indices diff --git a/tests/core/enums/test_enum.py b/tests/core/enums/test_enum.py index 6c8833b3..3605ef6d 100644 --- a/tests/core/enums/test_enum.py +++ b/tests/core/enums/test_enum.py @@ -1,6 +1,5 @@ import pytest import numpy as np -import logging from policyengine_core.enums.enum import Enum from policyengine_core.enums.enum_array import EnumArray @@ -40,8 +39,8 @@ class Sample(Enum): assert encoded_array.dtype.kind == "i" -def test_enum_encode_invalid_values_logs_warning(caplog): - """Test that encoding invalid enum string values logs a warning.""" +def test_enum_encode_invalid_values_raises_error(): + """Test that encoding invalid enum string values raises ValueError.""" class Sample(Enum): MAXWELL = "maxwell" @@ -50,22 +49,18 @@ class Sample(Enum): # Array with invalid values mixed in array_with_invalid = np.array(["MAXWELL", "INVALID_VALUE", "DWORKIN"]) - with caplog.at_level(logging.WARNING): - encoded = Sample.encode(array_with_invalid) + with pytest.raises(ValueError) as exc_info: + Sample.encode(array_with_invalid) - # Should still return an array (with 0 for invalid) - assert len(encoded) == 3 - assert encoded[0] == Sample.MAXWELL.index - assert encoded[1] == 0 # Invalid defaults to 0 - assert encoded[2] == Sample.DWORKIN.index + error_message = str(exc_info.value) + assert "INVALID_VALUE" in error_message + assert "Sample" in error_message + assert "MAXWELL" in error_message # Valid values listed + assert "DWORKIN" in error_message # Valid values listed - # Should have logged a warning - assert any("INVALID_VALUE" in record.message for record in caplog.records) - assert any("Sample" in record.message for record in caplog.records) - -def test_enum_encode_all_invalid_logs_warning(caplog): - """Test that encoding all invalid values logs a warning.""" +def test_enum_encode_all_invalid_raises_error(): + """Test that encoding all invalid values raises ValueError.""" class Sample(Enum): MAXWELL = "maxwell" @@ -73,11 +68,29 @@ class Sample(Enum): all_invalid = np.array(["FOO", "BAR", "BAZ"]) - with caplog.at_level(logging.WARNING): - encoded = Sample.encode(all_invalid) + with pytest.raises(ValueError) as exc_info: + Sample.encode(all_invalid) + + error_message = str(exc_info.value) + # Should mention all unique invalid values + assert ( + "FOO" in error_message + or "BAR" in error_message + or "BAZ" in error_message + ) + + +def test_enum_encode_empty_string_raises_error(): + """Test that encoding empty strings raises ValueError.""" + + class Sample(Enum): + MAXWELL = "maxwell" + DWORKIN = "dworkin" + + array_with_empty = np.array(["MAXWELL", "", "DWORKIN"]) - # All should be 0 - assert all(encoded == 0) + with pytest.raises(ValueError) as exc_info: + Sample.encode(array_with_empty) - # Should have logged warnings - assert len(caplog.records) > 0 + # Empty string should be in the error message (represented as '') + assert "''" in str(exc_info.value) or '""' in str(exc_info.value)