diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index de0b5fe4..050e0be0 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -244,8 +244,18 @@ def __call__(self, key: Any) -> str: # understandably don't like it. def update_for_type(self, key_hash: Hash, key: type) -> None: - key_hash.update( - f"{key.__module__}.{key.__qualname__}.{key.__name__}".encode()) + try: + module = sys.modules[key.__module__] + resolved = module + for attr in key.__qualname__.split("."): + resolved = getattr(resolved, attr) # pyright: ignore[reportAny] + if resolved is key: + # Globally accessible: hash based on name + self.rec(key_hash, f"{key.__module__}.{key.__qualname__}") + else: + raise ValueError(f"Cannot hash function-local class '{key}'") + except (KeyError, AttributeError): + raise ValueError(f"Cannot hash function-local class '{key}'") from None update_for_ABCMeta = update_for_type diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index 2cdd2a7c..e06c0978 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -503,8 +503,6 @@ def test_ABC_hashing() -> None: # noqa: N802 class MyABC(ABC): # noqa: B024 pass - assert keyb(MyABC) != keyb(ABC) - with pytest.raises(TypeError): keyb(MyABC()) @@ -515,19 +513,29 @@ class MyABC2(MyABC): def update_persistent_hash(self, key_hash, key_builder): key_builder.rec(key_hash, 42) - assert keyb(MyABC2) != keyb(MyABC) assert keyb(MyABC2()) class MyABC3(metaclass=ABCMeta): # noqa: B024 def update_persistent_hash(self, key_hash, key_builder): key_builder.rec(key_hash, 42) - assert keyb(MyABC3) != keyb(MyABC) != keyb(MyABC3()) + assert keyb(MyABC3()) + + +class WithoutUpdateMethodGlobal: + pass def test_class_hashing() -> None: keyb = KeyBuilder() + assert keyb(WithoutUpdateMethodGlobal) == keyb(WithoutUpdateMethodGlobal) + assert keyb(WithoutUpdateMethodGlobal) == "49c4673089d30507" + + with pytest.raises(TypeError): + # does not have update_persistent_hash() = > will raise + keyb(WithoutUpdateMethodGlobal()) + class WithUpdateMethod: def update_persistent_hash(self, key_hash, key_builder): # Only called for instances of this class, not for the class itself @@ -540,15 +548,11 @@ class TagClass(Tag): class TagClass2(Tag): pass - assert keyb(WithUpdateMethod) != keyb(WithUpdateMethod()) - assert keyb(TagClass) != keyb(TagClass()) - assert keyb(TagClass2) != keyb(TagClass2()) + assert keyb(WithUpdateMethod()) == keyb(WithUpdateMethod()) - assert keyb(TagClass) != keyb(TagClass2) assert keyb(TagClass()) != keyb(TagClass2()) assert keyb(TagClass()) == "7b3e4e66503438f6" - assert keyb(TagClass2) == "690b86bbf51aad83" @tag_dataclass class TagClass3(Tag):