From 8a14dfbe571dbbb70ae6e53a5ad9e9e3b48b9f34 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 13 Mar 2023 20:55:01 -0500 Subject: [PATCH 1/8] add (strict) typing to Record --- pytools/__init__.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index a2b2d410..700b2309 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -408,7 +408,8 @@ class RecordWithoutPickling: __slots__: ClassVar[List[str]] = [] fields: ClassVar[Set[str]] - def __init__(self, valuedict=None, exclude=None, **kwargs): + def __init__(self, valuedict: Optional[Dict[str, Any]] = None, + exclude: Optional[List[str]] = None, **kwargs: Any) -> None: assert self.__class__ is not Record if exclude is None: @@ -427,7 +428,7 @@ def __init__(self, valuedict=None, exclude=None, **kwargs): fields.add(key) setattr(self, key, value) - def get_copy_kwargs(self, **kwargs): + def get_copy_kwargs(self, **kwargs: Any) -> Any: for f in self.__class__.fields: if f not in kwargs: try: @@ -436,17 +437,17 @@ def get_copy_kwargs(self, **kwargs): pass return kwargs - def copy(self, **kwargs): + def copy(self, **kwargs: Any) -> "RecordWithoutPickling": return self.__class__(**self.get_copy_kwargs(**kwargs)) - def __repr__(self): + def __repr__(self) -> str: return "{}({})".format( self.__class__.__name__, ", ".join(f"{fld}={getattr(self, fld)!r}" for fld in self.__class__.fields if hasattr(self, fld))) - def register_fields(self, new_fields): + def register_fields(self, new_fields: Set[str]) -> None: try: fields = self.__class__.fields except AttributeError: @@ -454,7 +455,7 @@ def register_fields(self, new_fields): fields.update(new_fields) - def __getattr__(self, name): + def __getattr__(self, name: str) -> None: # This method is implemented to avoid pylint 'no-member' errors for # attribute access. raise AttributeError( @@ -465,13 +466,13 @@ def __getattr__(self, name): class Record(RecordWithoutPickling): __slots__: ClassVar[List[str]] = [] - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { key: getattr(self, key) for key in self.__class__.fields if hasattr(self, key)} - def __setstate__(self, valuedict): + def __setstate__(self, valuedict: Dict[str, Any]) -> None: try: fields = self.__class__.fields except AttributeError: @@ -481,30 +482,33 @@ def __setstate__(self, valuedict): fields.add(key) setattr(self, key, value) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self is other: return True + if not isinstance(other, Record): + return False return (self.__class__ == other.__class__ and self.__getstate__() == other.__getstate__()) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) class ImmutableRecordWithoutPickling(RecordWithoutPickling): """Hashable record. Does not explicitly enforce immutability.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: RecordWithoutPickling.__init__(self, *args, **kwargs) - self._cached_hash = None + self._cached_hash: Optional[int] = None - def __hash__(self): + def __hash__(self) -> int: # This attribute may vanish during pickling. if getattr(self, "_cached_hash", None) is None: self._cached_hash = hash( (type(self),) + tuple(getattr(self, field) for field in self.__class__.fields)) - return self._cached_hash + from typing import cast + return cast(int, self._cached_hash) class ImmutableRecord(ImmutableRecordWithoutPickling, Record): From 0367e6110ec72d8603dc2586d79986a9769bcd17 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 13 Mar 2023 21:01:26 -0500 Subject: [PATCH 2/8] add graph.py --- pytools/graph.py | 5 +++-- run-mypy.sh | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytools/graph.py b/pytools/graph.py index 054c5444..7ac43be8 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -255,8 +255,9 @@ def __init__(self, node: NodeT, key: Any) -> None: self.node = node self.key = key - def __lt__(self, other: "HeapEntry") -> bool: - return self.key < other.key + def __lt__(self, other: Any) -> bool: + from typing import cast + return cast(bool, self.key < other.key) def compute_topological_order(graph: GraphT[NodeT], diff --git a/run-mypy.sh b/run-mypy.sh index 39055a8c..0a300069 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -4,4 +4,4 @@ set -ex mypy --show-error-codes pytools -mypy --strict --follow-imports=skip pytools/datatable.py +mypy --strict --follow-imports=skip pytools/datatable.py pytools/graph.py From cc402df59c7babab1134b1a14c9d1f598bedcb9d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 13 Mar 2023 21:13:32 -0500 Subject: [PATCH 3/8] add mpi.py --- pytools/datatable.py | 3 +-- pytools/mpi.py | 12 +++++++----- run-mypy.sh | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytools/datatable.py b/pytools/datatable.py index 41c56909..a2aa21d7 100644 --- a/pytools/datatable.py +++ b/pytools/datatable.py @@ -11,8 +11,7 @@ """ -# type-ignore-reason: Record is untyped -class Row(Record): # type: ignore[misc] +class Row(Record): pass diff --git a/pytools/mpi.py b/pytools/mpi.py index f74c1307..27a0cb9d 100644 --- a/pytools/mpi.py +++ b/pytools/mpi.py @@ -33,10 +33,10 @@ """ from contextlib import AbstractContextManager, contextmanager -from typing import Generator, Tuple, Type, Union +from typing import Any, Callable, Generator, Sequence, Tuple, Type, Union -def check_for_mpi_relaunch(argv): +def check_for_mpi_relaunch(argv: Sequence[Any]) -> None: if argv[1] != "--mpi-relaunch": return @@ -48,7 +48,9 @@ def check_for_mpi_relaunch(argv): sys.exit() -def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None): +def run_with_mpi_ranks(py_script: str, ranks: int, + callable_: Callable[[Any], Any], args: Any = (), + kwargs: Any = None) -> None: if kwargs is None: kwargs = {} @@ -70,7 +72,7 @@ def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None): def pytest_raises_on_rank(my_rank: int, fail_rank: int, expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]]) \ - -> Generator[AbstractContextManager, None, None]: + -> Generator[AbstractContextManager[Any], None, None]: """ Like :func:`pytest.raises`, but only expect an exception on rank *fail_rank*. """ @@ -79,7 +81,7 @@ def pytest_raises_on_rank(my_rank: int, fail_rank: int, import pytest if my_rank == fail_rank: - cm: AbstractContextManager = pytest.raises(expected_exception) + cm: AbstractContextManager[Any] = pytest.raises(expected_exception) else: cm = nullcontext() diff --git a/run-mypy.sh b/run-mypy.sh index 0a300069..53754863 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -4,4 +4,4 @@ set -ex mypy --show-error-codes pytools -mypy --strict --follow-imports=skip pytools/datatable.py pytools/graph.py +mypy --strict pytools/datatable.py pytools/graph.py pytools/mpi.py From f39354cda7410de29dc57375728fd505639f592b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 13 Mar 2023 23:46:27 -0500 Subject: [PATCH 4/8] pass 1 through __init__.py --- pytools/__init__.py | 333 +++++++++++++++++++++++--------------------- run-mypy.sh | 2 +- 2 files changed, 174 insertions(+), 161 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 700b2309..649cd09b 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -36,8 +36,12 @@ from functools import reduce, wraps from sys import intern from typing import ( - Any, Callable, ClassVar, Dict, Generic, Hashable, Iterable, List, Mapping, - Optional, Set, Tuple, Type, TypeVar, Union) + Any, Callable, cast, ClassVar, Dict, Generic, Hashable, Iterable, List, Mapping, + Optional, Set, Tuple, Type, TypeVar, Union, ValuesView, KeysView, ItemsView, Sequence, Generator, TYPE_CHECKING) + + +if TYPE_CHECKING: + import numpy as np try: @@ -227,7 +231,7 @@ def __init__(self, f: F, deadline: Optional[Union[int, str]] = None) -> None: self.f = f self.deadline = deadline - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: from warnings import warn warn(f"This function is deprecated and will go away in {self.deadline}. " f"Use {self.f.__module__}.{self.f.__name__} instead.", @@ -238,7 +242,7 @@ def __call__(self, *args, **kwargs): def deprecate_keyword(oldkey: str, newkey: Optional[str] = None, *, - deadline: Optional[str] = None): + deadline: Optional[str] = None) -> Callable[[Any], Any]: """Decorator used to deprecate function keyword arguments. :arg oldkey: deprecated argument name. @@ -250,9 +254,9 @@ def deprecate_keyword(oldkey: str, if deadline is None: deadline = "the future" - def wrapper(func): + def wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]: @wraps(func) - def inner_wrapper(*args, **kwargs): + def inner_wrapper(*args: Any, **kwargs: Any) -> Any: if oldkey in kwargs: if newkey is None: warn(f"The '{oldkey}' keyword is deprecated and will " @@ -282,7 +286,7 @@ def inner_wrapper(*args, **kwargs): # {{{ math -def delta(x, y): +def delta(x: float, y: float) -> float: if x == y: return 1 else: @@ -307,7 +311,7 @@ def levi_civita(tup: Tuple[int, ...]) -> int: # NOTE: only available in python >= 3.8 perm = MovedFunctionDeprecationWrapper(math.perm, deadline=2023) except AttributeError: - def _unchecked_perm(n, k): + def _unchecked_perm(n: SupportsIndex, k: SupportsIndex) -> int: result = 1 while k: result *= n @@ -370,28 +374,28 @@ def comb(n: SupportsIndex, # type: ignore[misc] return _unchecked_perm(n, k) // math.factorial(k) -def norm_1(iterable): +def norm_1(iterable: Iterable[float]) -> float: return sum(abs(x) for x in iterable) -def norm_2(iterable): - return sum(x**2 for x in iterable)**0.5 +def norm_2(iterable: Iterable[float]) -> float: + return math.sqrt(sum(x**2 for x in iterable)) -def norm_inf(iterable): +def norm_inf(iterable: Iterable[float]) -> float: return max(abs(x) for x in iterable) -def norm_p(iterable, p): - return sum(i**p for i in iterable)**(1/p) +def norm_p(iterable: Iterable[float], p: float) -> float: + return math.pow(sum(i**p for i in iterable), 1/p) class Norm: - def __init__(self, p): + def __init__(self, p: float) -> None: self.p = p - def __call__(self, iterable): - return sum(i**self.p for i in iterable)**(1/self.p) + def __call__(self, iterable: Iterable[float]) -> float: + return math.pow(sum(i**self.p for i in iterable), 1/self.p) # }}} @@ -518,27 +522,27 @@ class ImmutableRecord(ImmutableRecordWithoutPickling, Record): class Reference: - def __init__(self, value): + def __init__(self, value: Any) -> None: self.value = value - def get(self): + def get(self) -> Any: from warnings import warn warn("Reference.get() is deprecated -- use ref.value instead") return self.value - def set(self, value): + def set(self, value: Any) -> None: self.value = value class FakeList: - def __init__(self, f, length): + def __init__(self, f: Callable[[int], Any], length: int) -> None: self._Length = length self._Function = f - def __len__(self): + def __len__(self) -> int: return self._Length - def __getitem__(self, index): + def __getitem__(self, index: int) -> Any: try: return [self._Function(i) for i in range(*index.indices(self._Length))] @@ -549,42 +553,42 @@ def __getitem__(self, index): # {{{ dependent dictionary class DependentDictionary: - def __init__(self, f, start=None): + def __init__(self, f: Callable[[Any], Any], start: Optional[Dict[Any, Any]] = None): if start is None: start = {} self._Function = f self._Dictionary = start.copy() - def copy(self): + def copy(self) -> "DependentDictionary": return DependentDictionary(self._Function, self._Dictionary) - def __contains__(self, key): + def __contains__(self, key: Hashable) -> bool: try: self[key] # pylint: disable=pointless-statement return True except KeyError: return False - def __getitem__(self, key): + def __getitem__(self, key: Hashable) -> Any: try: return self._Dictionary[key] except KeyError: return self._Function(self._Dictionary, key) - def __setitem__(self, key, value): + def __setitem__(self, key: Hashable, value: Any) -> None: self._Dictionary[key] = value - def genuineKeys(self): # noqa + def genuineKeys(self) -> List[Hashable]: # noqa return list(self._Dictionary.keys()) - def iteritems(self): + def iteritems(self) -> ItemsView[Any, Any]: return self._Dictionary.items() - def iterkeys(self): + def iterkeys(self) -> KeysView[Any]: return self._Dictionary.keys() - def itervalues(self): + def itervalues(self) -> ValuesView[Any]: return self._Dictionary.values() # }}} @@ -604,7 +608,7 @@ def one(iterable: Iterable[T]) -> T: except StopIteration: raise ValueError("empty iterable passed to 'one()'") - def no_more(): + def no_more() -> bool: try: next(it) raise ValueError("iterable with more than one entry passed to 'one()'") @@ -635,7 +639,7 @@ def is_single_valued( all_equal = is_single_valued -def all_roughly_equal(iterable, threshold): +def all_roughly_equal(iterable: Iterable[T], threshold: float) -> bool: return is_single_valued(iterable, equality_pred=lambda a, b: abs(a-b) < threshold) @@ -653,7 +657,7 @@ def single_valued( except StopIteration: raise ValueError("empty iterable passed to 'single_valued()'") - def others_same(): + def others_same() -> bool: for other_item in it: if not equality_pred(other_item, first_item): return False @@ -684,7 +688,7 @@ def memoize(*args: F, **kwargs: Any) -> F: default_key_func: Optional[Callable[..., Any]] if use_kw: - def default_key_func(*inner_args, **inner_kwargs): # noqa pylint:disable=function-redefined + def default_key_func(*inner_args: Any, **inner_kwargs: Any) -> Any: # noqa pylint:disable=function-redefined return inner_args, frozenset(inner_kwargs.items()) else: default_key_func = None @@ -697,19 +701,19 @@ def default_key_func(*inner_args, **inner_kwargs): # noqa pylint:disable=functi ", ".join(kwargs.keys()))) if key_func is not None: - def _decorator(func): - def wrapper(*args, **kwargs): + def _decorator(func: F) -> F: + def wrapper(*args: Any, **kwargs: Any) -> Any: key = key_func(*args, **kwargs) try: - return func._memoize_dic[key] # noqa: E501 # pylint: disable=protected-access + return func._memoize_dic[key] # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access except AttributeError: # _memoize_dic doesn't exist yet. result = func(*args, **kwargs) - func._memoize_dic = {key: result} # noqa: E501 # pylint: disable=protected-access + func._memoize_dic = {key: result} # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access return result except KeyError: result = func(*args, **kwargs) - func._memoize_dic[key] = result # noqa: E501 # pylint: disable=protected-access + func._memoize_dic[key] = result # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access return result from functools import update_wrapper @@ -717,18 +721,18 @@ def wrapper(*args, **kwargs): return wrapper else: - def _decorator(func): - def wrapper(*args): + def _decorator(func: F) -> F: + def wrapper(*args: Any) -> Any: try: - return func._memoize_dic[args] # noqa: E501 # pylint: disable=protected-access + return func._memoize_dic[args] # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access except AttributeError: # _memoize_dic doesn't exist yet. result = func(*args) - func._memoize_dic = {args: result} # noqa: E501 # pylint:disable=protected-access + func._memoize_dic = {args: result} # type: ignore[attr-defined] # noqa: E501 # pylint:disable=protected-access return result except KeyError: result = func(*args) - func._memoize_dic[args] = result # noqa: E501 # pylint: disable=protected-access + func._memoize_dic[args] = result # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access return result from functools import update_wrapper @@ -786,7 +790,8 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: getattr(obj, cache_dict_name)[key] = result return result - def clear_cache(obj): + def clear_cache(obj: object) -> None: + assert cache_dict_name is not None object.__delattr__(obj, cache_dict_name) from functools import update_wrapper @@ -862,7 +867,8 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: getattr(obj, cache_dict_name)[cache_key] = result return result - def clear_cache(obj): + def clear_cache(obj: Any) -> None: + assert cache_dict_name is not None object.__delattr__(obj, cache_dict_name) from functools import update_wrapper @@ -993,28 +999,28 @@ class InfixOperator: Following a recipe from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/384122 """ - def __init__(self, function): + def __init__(self, function: Callable[..., Any]) -> None: self.function = function - def __rlshift__(self, other): + def __rlshift__(self, other: Any) -> "InfixOperator": return InfixOperator(lambda x: self.function(other, x)) - def __rshift__(self, other): + def __rshift__(self, other: Any) -> Any: return self.function(other) - def call(self, a, b): + def call(self, a: Any, b: Any) -> Any: return self.function(a, b) -def monkeypatch_method(cls): +def monkeypatch_method(cls: Any) -> Callable[..., Any]: # from GvR, http://mail.python.org/pipermail/python-dev/2008-January/076194.html - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: setattr(cls, func.__name__, func) return func return decorator -def monkeypatch_class(_name, bases, namespace): +def monkeypatch_class(_name: Any, bases: Sequence[Any], namespace: Dict[str, Any]) -> Any: # from GvR, http://mail.python.org/pipermail/python-dev/2008-January/076194.html assert len(bases) == 1, "Exactly one base class required" @@ -1029,15 +1035,15 @@ def monkeypatch_class(_name, bases, namespace): # {{{ generic utilities -def add_tuples(t1, t2): +def add_tuples(t1: Tuple[Any, ...], t2: Tuple[Any, ...]) -> Tuple[Any, ...]: return tuple([t1v + t2v for t1v, t2v in zip(t1, t2)]) -def negate_tuple(t1): +def negate_tuple(t1: Tuple[Any, ...]) -> Tuple[Any, ...]: return tuple([-t1v for t1v in t1]) -def shift(vec, dist): +def shift(vec: List[Any], dist: int) -> List[Any]: """Return a copy of *vec* shifted by *dist* such that .. code:: python @@ -1058,11 +1064,11 @@ def shift(vec, dist): return result -def len_iterable(iterable): +def len_iterable(iterable: Iterable[Any]) -> int: return sum(1 for i in iterable) -def flatten(iterable): +def flatten(iterable: Iterable[Any]) -> Generator[Any, None, None]: """For an iterable of sub-iterables, generate each member of each sub-iterable in turn, i.e. a flattened version of that super-iterable. @@ -1072,18 +1078,18 @@ def flatten(iterable): yield from sublist -def general_sum(sequence): +def general_sum(sequence: Sequence[float]) -> float: return reduce(operator.add, sequence) -def linear_combination(coefficients, vectors): +def linear_combination(coefficients: Sequence[float], vectors: Sequence[float]) -> float: result = coefficients[0] * vectors[0] for c, v in zip(coefficients[1:], vectors[1:]): result += c*v return result -def common_prefix(iterable, empty=None): +def common_prefix(iterable: Iterable[Any], empty: Optional[Any] = None) -> Any: it = iter(iterable) try: pfx = next(it) @@ -1101,11 +1107,11 @@ def common_prefix(iterable, empty=None): return pfx -def decorate(function, iterable): +def decorate(function: Callable[[Any], Any], iterable: Iterable[Any]) -> List[Tuple[Any, ...]]: return [(x, function(x)) for x in iterable] -def partition(criterion, iterable): +def partition(criterion: Callable[[Any], Any], iterable: Iterable[Any]) -> Tuple[Any, Any]: part_true = [] part_false = [] for i in iterable: @@ -1116,7 +1122,7 @@ def partition(criterion, iterable): return part_true, part_false -def partition2(iterable): +def partition2(iterable: Iterable[Any]) -> Tuple[Any, Any]: part_true = [] part_false = [] for pred, i in iterable: @@ -1132,7 +1138,7 @@ def product(iterable: Iterable[Any]) -> Any: return reduce(mul, iterable, 1) -def reverse_dictionary(the_dict): +def reverse_dictionary(the_dict: Dict[Any, Any]) -> Dict[Any, Any]: result = {} for key, value in the_dict.items(): if value in result: @@ -1142,16 +1148,16 @@ def reverse_dictionary(the_dict): return result -def set_sum(set_iterable): +def set_sum(set_iterable: Iterable[Set[Any]]) -> Set[Any]: from operator import or_ return reduce(or_, set_iterable, set()) -def div_ceil(nr, dr): +def div_ceil(nr: int, dr: int) -> int: return -(-nr // dr) -def uniform_interval_splitting(n, granularity, max_intervals): +def uniform_interval_splitting(n: int, granularity: int, max_intervals: int) -> Tuple[int, int]: """ Return *(interval_size, num_intervals)* such that:: num_intervals * interval_size >= n @@ -1177,7 +1183,7 @@ def uniform_interval_splitting(n, granularity, max_intervals): return interval_size, num_intervals -def find_max_where(predicate, prec=1e-5, initial_guess=1, fail_bound=1e38): +def find_max_where(predicate: Callable[[Any], bool], prec: float = 1e-5, initial_guess: float = 1, fail_bound: float = 1e38) -> float: """Find the largest value for which a predicate is true, along a half-line. 0 is assumed to be the lower bound.""" @@ -1234,7 +1240,7 @@ def find_max_where(predicate, prec=1e-5, initial_guess=1, fail_bound=1e38): # {{{ argmin, argmax -def argmin2(iterable, return_value=False): +def argmin2(iterable: Iterable[Sequence[Any]], return_value: bool = False) -> Union[Any, Tuple[int, Any]]: it = iter(iterable) try: current_argmin, current_min = next(it) @@ -1252,7 +1258,7 @@ def argmin2(iterable, return_value=False): return current_argmin -def argmax2(iterable, return_value=False): +def argmax2(iterable: Iterable[Sequence[Any]], return_value: bool = False) -> Union[Any, Tuple[int, Any]]: it = iter(iterable) try: current_argmax, current_max = next(it) @@ -1270,11 +1276,11 @@ def argmax2(iterable, return_value=False): return current_argmax -def argmin(iterable): +def argmin(iterable: Iterable[Any]) -> Any: return argmin2(enumerate(iterable)) -def argmax(iterable): +def argmax(iterable: Iterable[Any]) -> Any: return argmax2(enumerate(iterable)) # }}} @@ -1282,7 +1288,7 @@ def argmax(iterable): # {{{ cartesian products etc. -def cartesian_product(*args): +def cartesian_product(*args: Iterable[Any]) -> Generator[Tuple[Any, ...], None, None]: if len(args) == 1: for arg in args[0]: yield (arg,) @@ -1293,14 +1299,14 @@ def cartesian_product(*args): yield prod + (i,) -def distinct_pairs(list1, list2): +def distinct_pairs(list1: Sequence[Any], list2: Sequence[Any]) -> Generator[Tuple[Any, Any], None, None]: for i, xi in enumerate(list1): for j, yj in enumerate(list2): if i != j: yield (xi, yj) -def cartesian_product_sum(list1, list2): +def cartesian_product_sum(list1: Sequence[Any], list2: Sequence[Any]) -> Generator[Any, None, None]: """This routine returns a list of sums of each element of list1 with each element of list2. Also works with lists. """ @@ -1313,7 +1319,7 @@ def cartesian_product_sum(list1, list2): # {{{ elementary statistics -def average(iterable): +def average(iterable: Iterable[float]) -> float: """Return the average of the values in iterable. iterable may not be empty. @@ -1338,20 +1344,20 @@ class VarianceAggregator: See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Adheres to pysqlite's aggregate interface. """ - def __init__(self, entire_pop): + def __init__(self, entire_pop: int) -> None: self.n = 0 - self.mean = 0 - self.m2 = 0 + self.mean = 0.0 + self.m2 = 0.0 self.entire_pop = entire_pop - def step(self, x): + def step(self, x: float) -> None: self.n += 1 delta_ = x - self.mean self.mean += delta_/self.n self.m2 += delta_*(x - self.mean) - def finalize(self): + def finalize(self) -> Optional[float]: if self.entire_pop: if self.n == 0: return None @@ -1364,7 +1370,7 @@ def finalize(self): return self.m2/(self.n - 1) -def variance(iterable, entire_pop): +def variance(iterable: Iterable[float], entire_pop: int) -> Optional[float]: v_comp = VarianceAggregator(entire_pop) for x in iterable: @@ -1373,9 +1379,9 @@ def variance(iterable, entire_pop): return v_comp.finalize() -def std_deviation(iterable, finite_pop): +def std_deviation(iterable: Iterable[float], finite_pop: int) -> Optional[float]: from math import sqrt - return sqrt(variance(iterable, finite_pop)) + return sqrt(cast(float, variance(iterable, finite_pop))) # }}} @@ -1490,7 +1496,7 @@ def generate_all_integer_tuples_below(n, length, least_abs=0): n, length, least_abs)) -def generate_permutations(original): +def generate_permutations(original: Sequence[Any]) -> Generator[List[Any], None, None]: """Generate all permutations of the list *original*. Nicked from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/252178 @@ -1516,7 +1522,7 @@ def generate_unique_permutations(original): yield perm_ -def enumerate_basic_directions(dimensions): +def enumerate_basic_directions(dimensions: int): coordinate_list = [[0], [1], [-1]] return reduce(cartesian_product_sum, [coordinate_list] * dimensions)[1:] @@ -1656,7 +1662,7 @@ def _get_alignments(self) -> Tuple[str, ...]: + (self.alignments[-1],) * (self.ncolumns - len(self.alignments)) ) - def _get_column_widths(self, rows) -> Tuple[int, ...]: + def _get_column_widths(self, rows: Tuple[Any, ...]) -> Tuple[int, ...]: return tuple([ max(len(row[i]) for row in rows) for i in range(self.ncolumns) ]) @@ -1822,7 +1828,7 @@ def merge_tables(*tables: Table, if isinstance(skip_columns, int): skip_columns = (skip_columns,) - def remove_columns(i, row): + def remove_columns(i: int, row: Tuple[Any, ...]) -> Tuple[Any, ...]: if i == 0 or skip_columns is None: return row else: @@ -1837,7 +1843,7 @@ def remove_columns(i, row): result = Table(alignments=alignments) for i in range(tables[0].nrows): - row = [] + row: List[Any] = [] for j, tbl in enumerate(tables): row.extend(remove_columns(j, tbl.rows[i])) @@ -1851,8 +1857,8 @@ def remove_columns(i, row): # {{{ histogram formatting def string_histogram( # pylint: disable=too-many-arguments,too-many-locals - iterable, min_value=None, max_value=None, - bin_count=20, width=70, bin_starts=None, use_unicode=True): + iterable: Iterable[float], min_value: Optional[float] = None, max_value: Optional[float] = None, + bin_count: int = 20, width: int = 70, bin_starts: Optional[Sequence[float]] = None, use_unicode: bool = True) -> str: if bin_starts is None: if min_value is None or max_value is None: iterable = list(iterable) @@ -1879,7 +1885,7 @@ def string_histogram( # pylint: disable=too-many-arguments,too-many-locals from math import ceil, floor if use_unicode: - def format_bar(cnt): + def format_bar(cnt: int) -> str: scaled = cnt*width/max_count full = int(floor(scaled)) eighths = int(ceil((scaled-full)*8)) @@ -1888,7 +1894,7 @@ def format_bar(cnt): else: return full*chr(0x2588) else: - def format_bar(cnt): + def format_bar(cnt: int) -> str: return int(ceil(cnt*width/max_count))*"#" max_count = max(bins) @@ -1903,7 +1909,7 @@ def format_bar(cnt): # }}} -def word_wrap(text, width, wrap_using="\n"): +def word_wrap(text: str, width: int, wrap_using: str = "\n") -> str: # http://code.activestate.com/recipes/148061-one-liner-word-wrap-function/ r""" A word-wrap function that preserves existing line breaks @@ -1926,7 +1932,7 @@ def word_wrap(text, width, wrap_using="\n"): # {{{ command line interfaces -def _exec_arg(arg, execenv): +def _exec_arg(arg: str, execenv: Dict[str, Any]) -> None: import os if os.access(arg, os.F_OK): exec(compile(open(arg), arg, "exec"), execenv) @@ -1938,7 +1944,7 @@ class CPyUserInterface: class Parameters(Record): pass - def __init__(self, variables, constants=None, doc=None): + def __init__(self, variables: Dict[str, Any], constants: Optional[Dict[str, Any]] = None, doc: Optional[Dict[str, Any]] = None) -> None: if constants is None: constants = {} if doc is None: @@ -1947,7 +1953,7 @@ def __init__(self, variables, constants=None, doc=None): self.constants = constants self.doc = doc - def show_usage(self, progname): + def show_usage(self, progname: str) -> None: print(f"usage: {progname} ") print() print("FILE-OR-STATEMENTS may either be Python statements of the form") @@ -1969,7 +1975,7 @@ def show_usage(self, progname): if c in self.doc: print(f" {self.doc[c]}") - def gather(self, argv=None): + def gather(self, argv: Optional[Sequence[Any]] = None) -> Parameters: if argv is None: argv = sys.argv @@ -2001,7 +2007,7 @@ def gather(self, argv=None): self.validate(result) return result - def validate(self, setup): + def validate(self, setup: Any) -> None: pass # }}} @@ -2010,18 +2016,18 @@ def validate(self, setup): # {{{ debugging class StderrToStdout: - def __enter__(self): + def __enter__(self) -> None: # pylint: disable=attribute-defined-outside-init self.stderr_backup = sys.stderr sys.stderr = sys.stdout - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: sys.stderr = self.stderr_backup del self.stderr_backup def typedump(val: Any, max_seq: int = 5, - special_handlers: Optional[Mapping[Type, Callable]] = None, + special_handlers: Optional[Mapping[Type[Any], Callable[[Any], str]]] = None, fully_qualified_name: bool = True) -> str: """ Return a string representation of the type of *val*, recursing into @@ -2089,7 +2095,7 @@ def objname(obj: Any) -> str: return objname(val) -def invoke_editor(s, filename="edit.txt", descr="the file"): +def invoke_editor(s: str, filename: str = "edit.txt", descr: str = "the file") -> str: from tempfile import mkdtemp tempdir = mkdtemp() @@ -2130,7 +2136,7 @@ class ProgressBar: # pylint: disable=too-many-instance-attributes .. automethod:: __enter__ .. automethod:: __exit__ """ - def __init__(self, descr, total, initial=0, length=40): + def __init__(self, descr: str, total: float, initial: float = 0, length: float = 40): import time self.description = descr self.total = total @@ -2143,9 +2149,9 @@ def __init__(self, descr, total, initial=0, length=40): self.speed_meas_start_time = self.start_time self.speed_meas_start_done = initial - self.time_per_step = None + self.time_per_step: Optional[float] = None - def draw(self): + def draw(self) -> None: import time now = time.time() @@ -2172,27 +2178,27 @@ def draw(self): sys.stderr.write("{:<20} [{}] ETA {}\r".format( self.description, - squares*"#"+(self.length-squares)*" ", + squares*"#"+int(self.length-squares)*" ", eta_str)) self.last_squares = squares self.last_update_time = now - def progress(self, steps=1): + def progress(self, steps: int = 1) -> None: self.set_progress(self.done + steps) - def set_progress(self, done): + def set_progress(self, done: float) -> None: self.done = done self.draw() - def finished(self): + def finished(self) -> None: self.set_progress(self.total) sys.stderr.write("\n") - def __enter__(self): + def __enter__(self) -> None: self.draw() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.finished() # }}} @@ -2200,13 +2206,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): # {{{ file system related -def assert_not_a_file(name): +def assert_not_a_file(name: str) -> None: import os if os.access(name, os.F_OK): raise OSError(f"file `{name}' already exists") -def add_python_path_relative_to_script(rel_path): +def add_python_path_relative_to_script(rel_path: str) -> None: from os.path import abspath, dirname, join script_name = sys.argv[0] @@ -2219,10 +2225,12 @@ def add_python_path_relative_to_script(rel_path): # {{{ numpy dtype mangling -def common_dtype(dtypes, default=None): +def common_dtype(dtypes: Sequence["np.dtype[Any]"], default: Optional["np.dtype[Any]"] = None) -> "np.dtype[Any]": + import numpy as np + dtypes = list(dtypes) if dtypes: - return argmax2((dtype, dtype.num) for dtype in dtypes) + return cast(np.dtype[Any], argmax2((dtype, dtype.num) for dtype in dtypes)) else: if default is not None: return default @@ -2231,12 +2239,12 @@ def common_dtype(dtypes, default=None): "cannot find common dtype of empty dtype list") -def to_uncomplex_dtype(dtype): +def to_uncomplex_dtype(dtype: "np.dtype[Any]") -> "Type[Any]": import numpy as np return np.array(1, dtype=dtype).real.dtype.type -def match_precision(dtype, dtype_to_match): +def match_precision(dtype: "np.dtype[Any]", dtype_to_match: "np.dtype[Any]") -> "np.dtype[Any]": import numpy tgt_is_double = dtype_to_match in [ @@ -2259,7 +2267,7 @@ def match_precision(dtype, dtype_to_match): # {{{ unique name generation -def generate_unique_names(prefix): +def generate_unique_names(prefix: str) -> Generator[str, None, None]: yield prefix try_num = 0 @@ -2391,19 +2399,19 @@ def __call__(self, based_on: str = "id") -> str: # {{{ recursion limit class MinRecursionLimit: - def __init__(self, min_rec_limit): + def __init__(self, min_rec_limit: int) -> None: self.min_rec_limit = min_rec_limit - def __enter__(self): + def __enter__(self) -> None: # pylint: disable=attribute-defined-outside-init self.prev_recursion_limit = sys.getrecursionlimit() new_limit = max(self.prev_recursion_limit, self.min_rec_limit) sys.setrecursionlimit(new_limit) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # Deep recursion can produce deeply nested data structures - # (or long chains of to-be gc'd generators) that cannot be + # (or long chains of to-be gc'd generators) that cannot # undergo garbage collection with a lower recursion limit. # # As a result, it doesn't seem possible to lower the recursion limit @@ -2419,7 +2427,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # {{{ download from web if not present -def download_from_web_if_not_present(url, local_name=None): +def download_from_web_if_not_present(url: str, local_name: Optional[str] = None) -> None: """ .. versionadded:: 2017.5 """ @@ -2447,7 +2455,7 @@ def download_from_web_if_not_present(url, local_name=None): # {{{ find git revisions -def find_git_revision(tree_root): # pylint: disable=too-many-locals +def find_git_revision(tree_root): # type: ignore[no-untyped-def] # pylint: disable=too-many-locals # Keep this routine self-contained so that it can be copy-pasted into # setup.py. @@ -2477,7 +2485,7 @@ def find_git_revision(tree_root): # pylint: disable=too-many-locals cwd=tree_root, env=env) (git_rev, _) = p.communicate() - git_rev = git_rev.decode() + git_rev = git_rev.decode() # type: ignore[assignment] git_rev = git_rev.rstrip() @@ -2491,18 +2499,18 @@ def find_git_revision(tree_root): # pylint: disable=too-many-locals return git_rev -def find_module_git_revision(module_file, n_levels_up): +def find_module_git_revision(module_file, n_levels_up): # type: ignore[no-untyped-def] from os.path import dirname, join tree_root = join(*([dirname(module_file)] + [".." * n_levels_up])) - return find_git_revision(tree_root) + return find_git_revision(tree_root) # type: ignore[no-untyped-call] # }}} # {{{ create a reshaped view of a numpy array -def reshaped_view(a, newshape): +def reshaped_view(a: "np.ndarray[Any, Any]", newshape: Tuple[int, ...]) -> "np.ndarray[Any, Any]": """ Create a new view object with shape ``newshape`` without copying the data of ``a``. This function is different from ``numpy.reshape`` by raising an exception when data copy is necessary. @@ -2540,30 +2548,33 @@ class ProcessTimer: .. versionadded:: 2018.5 """ - def __init__(self): + def __init__(self) -> None: import time self.perf_counter_start = time.perf_counter() self.process_time_start = time.process_time() - self.wall_elapsed = None - self.process_elapsed = None + self.wall_elapsed: Optional[float] = None + self.process_elapsed: Optional[float] = None - def __enter__(self): + def __enter__(self) -> "ProcessTimer": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.done() - def done(self): + def done(self) -> None: import time self.wall_elapsed = time.perf_counter() - self.perf_counter_start self.process_elapsed = time.process_time() - self.process_time_start - def __str__(self): - cpu = self.process_elapsed / self.wall_elapsed + def __str__(self) -> str: + if self.process_elapsed and self.wall_elapsed: + cpu = (self.process_elapsed / self.wall_elapsed) + else: + cpu = 0.0 return f"{self.wall_elapsed:.2f}s wall {cpu:.2f}x CPU" - def __repr__(self): + def __repr__(self) -> str: wall = self.wall_elapsed process = self.process_elapsed @@ -2588,8 +2599,8 @@ class ProcessLogger: # pylint: disable=too-many-instance-attributes default_noisy_level = logging.INFO def __init__( # pylint: disable=too-many-arguments - self, logger, description, - silent_level=None, noisy_level=None, long_threshold_seconds=None): + self, logger: logging.Logger, description: str, + silent_level: Optional[int] =None, noisy_level: Optional[int] = None, long_threshold_seconds: Optional[float] = None) -> None: self.logger = logger self.description = description self.silent_level = silent_level or logging.DEBUG @@ -2643,7 +2654,7 @@ def __init__( # pylint: disable=too-many-arguments self.timer = ProcessTimer() - def _log_start_if_long(self): + def _log_start_if_long(self) -> None: from time import sleep sleep_duration = 10*self.long_threshold_seconds @@ -2656,13 +2667,13 @@ def _log_start_if_long(self): sleep_duration) def done( # pylint: disable=keyword-arg-before-vararg - self, extra_msg=None, *extra_fmt_args): + self, extra_msg: Optional[str] = None, *extra_fmt_args: Any) -> None: self.timer.done() self.is_done = True completion_level = ( self.noisy_level - if self.timer.wall_elapsed > self.long_threshold_seconds + if self.timer.wall_elapsed is not None and self.timer.wall_elapsed > self.long_threshold_seconds else self.silent_level) msg = "%s: completed (%s)" @@ -2674,10 +2685,10 @@ def done( # pylint: disable=keyword-arg-before-vararg self.logger.log(completion_level, msg, *fmt_args) - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.done() @@ -2693,13 +2704,13 @@ class log_process: # noqa: N801 .. automethod:: __call__ """ - def __init__(self, logger, description=None, long_threshold_seconds=None): + def __init__(self, logger: logging.Logger, description: Optional[str] = None, long_threshold_seconds: Optional[float] = None) -> None: self.logger = logger self.description = description self.long_threshold_seconds = long_threshold_seconds - def __call__(self, wrapped): - def wrapper(*args, **kwargs): + def __call__(self, wrapped: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: with ProcessLogger( self.logger, self.description or wrapped.__name__, @@ -2716,7 +2727,7 @@ def wrapper(*args, **kwargs): # {{{ sorting in natural order -def natorder(item): +def natorder(item: Any) -> List[Any]: """Return a key for natural order string comparison. See :func:`natsorted`. @@ -2738,7 +2749,7 @@ def natorder(item): return result -def natsorted(iterable, key=None, reverse=False): +def natsorted(iterable: Iterable[Any], key: Optional[Callable[[Any], Any]] = None, reverse: bool = False) -> List[Any]: """Sort using natural order [1]_, as opposed to lexicographic order. Example:: @@ -2762,7 +2773,9 @@ def natsorted(iterable, key=None, reverse=False): """ if key is None: key = lambda x: x - return sorted(iterable, key=lambda y: natorder(key(y)), reverse=reverse) + + # type-ignore-reason: mypy thinks key could be None + return sorted(iterable, key=lambda y: natorder(key(y)), reverse=reverse) # type: ignore[misc] # }}} @@ -2776,7 +2789,7 @@ def natsorted(iterable, key=None, reverse=False): del _DOTTED_WORDS -def resolve_name(name): +def resolve_name(name: str) -> object: """A backport of :func:`pkgutil.resolve_name` (added in Python 3.9). .. versionadded:: 2021.1.2 @@ -2825,7 +2838,7 @@ def resolve_name(name): # {{{ unordered_hash -def unordered_hash(hash_instance, iterable, hash_constructor=None): +def unordered_hash(hash_instance: Any, iterable: Iterable[Any], hash_constructor: Optional[Callable[[], Any]] = None) -> Any: """Using a hash algorithm given by the parameter-less constructor *hash_constructor*, return a hash object whose internal state depends on the entries of *iterable*, but not their order. If *hash* @@ -2870,7 +2883,7 @@ def unordered_hash(hash_instance, iterable, hash_constructor=None): # {{{ sphere_sample -def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0): +def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0) -> "np.ndarray[Any, Any]": """Generate points regularly distributed on a sphere based on https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf. @@ -2879,7 +2892,7 @@ def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0): """ import numpy as np - points: List[np.ndarray] = [] + points: List[np.ndarray[Any, Any]] = [] count = 0 a = 4 * np.pi / npoints_approx @@ -2919,7 +2932,7 @@ def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0): def sphere_sample_fibonacci( npoints: int, r: float = 1.0, *, - optimize: Optional[str] = None): + optimize: Optional[str] = None) -> "np.ndarray[Any, Any]": """Generate points on a sphere based on an offset Fibonacci lattice from [2]_. .. [2] http://extremelearning.com.au/how-to-evenly-distribute-points-on-a-sphere-more-effectively-than-the-canonical-fibonacci-lattice/ @@ -2995,7 +3008,7 @@ def strtobool(val: Optional[str], default: Optional[bool] = None) -> bool: # }}} -def _test(): +def _test() -> None: import doctest doctest.testmod() diff --git a/run-mypy.sh b/run-mypy.sh index 53754863..5e3ad909 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -4,4 +4,4 @@ set -ex mypy --show-error-codes pytools -mypy --strict pytools/datatable.py pytools/graph.py pytools/mpi.py +mypy --strict --follow-imports=skip pytools/datatable.py pytools/graph.py pytools/mpi.py From 0d3f910ab9eb05f3cf3d7ac4c0d5fd32879f4652 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Mar 2023 00:01:43 -0500 Subject: [PATCH 5/8] pass 2 --- pytools/__init__.py | 93 ++++++++------------------------------------- 1 file changed, 15 insertions(+), 78 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 649cd09b..c3df2f4b 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -307,71 +307,9 @@ def levi_civita(tup: Tuple[int, ...]) -> int: factorial = MovedFunctionDeprecationWrapper(math.factorial, deadline=2023) -try: - # NOTE: only available in python >= 3.8 - perm = MovedFunctionDeprecationWrapper(math.perm, deadline=2023) -except AttributeError: - def _unchecked_perm(n: SupportsIndex, k: SupportsIndex) -> int: - result = 1 - while k: - result *= n - n -= 1 - k -= 1 - - return result - - def perm(n: SupportsIndex, # type: ignore[misc] - k: Optional[SupportsIndex] = None) -> int: - """ - :returns: :math:`P(n, k)`, the number of permutations of length :math:`k` - drawn from :math:`n` choices. - """ - from warnings import warn - warn("This function is deprecated and will go away in 2023. " - "Use `math.perm` instead, which is available from Python 3.8.", - DeprecationWarning, stacklevel=2) - - if k is None: - return math.factorial(n) - - import operator - n, k = operator.index(n), operator.index(k) - if k > n: - return 0 - - if k < 0: - raise ValueError("k must be a non-negative integer") +perm = MovedFunctionDeprecationWrapper(math.perm, deadline=2023) - if n < 0: - raise ValueError("n must be a non-negative integer") - - from numbers import Integral - if not isinstance(k, Integral): - raise TypeError(f"'{type(k).__name__}' object cannot be interpreted " - "as an integer") - - if not isinstance(n, Integral): - raise TypeError(f"'{type(n).__name__}' object cannot be interpreted " - "as an integer") - - return _unchecked_perm(n, k) - -try: - # NOTE: only available in python >= 3.8 - comb = MovedFunctionDeprecationWrapper(math.comb, deadline=2023) -except AttributeError: - def comb(n: SupportsIndex, # type: ignore[misc] - k: SupportsIndex) -> int: - """ - :returns: :math:`C(n, k)`, the number of combinations (subsets) - of length :math:`k` drawn from :math:`n` choices. - """ - from warnings import warn - warn("This function is deprecated and will go away in 2023. " - "Use `math.comb` instead, which is available from Python 3.8.", - DeprecationWarning, stacklevel=2) - - return _unchecked_perm(n, k) // math.factorial(k) +comb = MovedFunctionDeprecationWrapper(math.comb, deadline=2023) def norm_1(iterable: Iterable[float]) -> float: @@ -542,18 +480,20 @@ def __init__(self, f: Callable[[int], Any], length: int) -> None: def __len__(self) -> int: return self._Length - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index: Union[slice, int]) -> Any: try: + assert isinstance(index, slice) return [self._Function(i) for i in range(*index.indices(self._Length))] except AttributeError: + assert isinstance(index, int) return self._Function(index) # {{{ dependent dictionary class DependentDictionary: - def __init__(self, f: Callable[[Any], Any], start: Optional[Dict[Any, Any]] = None): + def __init__(self, f: Callable[[Any, Any], Any], start: Optional[Dict[Any, Any]] = None): if start is None: start = {} @@ -641,7 +581,7 @@ def is_single_valued( def all_roughly_equal(iterable: Iterable[T], threshold: float) -> bool: return is_single_valued(iterable, - equality_pred=lambda a, b: abs(a-b) < threshold) + equality_pred=lambda a, b: abs(cast(float, a)-cast(float, b)) < threshold) def single_valued( @@ -671,7 +611,7 @@ def others_same() -> bool: # {{{ memoization / attribute storage -def memoize(*args: F, **kwargs: Any) -> F: +def memoize(*args: F, **kwargs: Any) -> Callable[[Any], Any]: """Stores previously computed function values in a cache. Two keyword-only arguments are supported: @@ -701,7 +641,7 @@ def default_key_func(*inner_args: Any, **inner_kwargs: Any) -> Any: # noqa pyli ", ".join(kwargs.keys()))) if key_func is not None: - def _decorator(func: F) -> F: + def _decorator(func: F) -> Callable[[Any], Any]: def wrapper(*args: Any, **kwargs: Any) -> Any: key = key_func(*args, **kwargs) try: @@ -721,7 +661,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper else: - def _decorator(func: F) -> F: + def _decorator(func: F) -> Callable[[Any], Any]: def wrapper(*args: Any) -> Any: try: return func._memoize_dic[args] # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access @@ -740,7 +680,7 @@ def wrapper(*args: Any) -> Any: return wrapper if not args: - return _decorator # type: ignore + return _decorator if callable(args[0]) and len(args) == 1: return _decorator(args[0]) raise TypeError( @@ -894,7 +834,7 @@ class keyed_memoize_method(keyed_memoize_on_first_arg): # noqa: N801 Can memoize methods on classes that do not allow setting attributes (e.g. by overwritting ``__setattr__``), e.g. frozen :mod:`dataclasses`. """ - def _default_cache_dict_name(self, function): + def _default_cache_dict_name(self, function: Callable[[Any], Any]) -> str: return intern(f"_memoize_dic_{function.__name__}") @@ -944,10 +884,7 @@ def new_inner(*args: P.args, **kwargs: P.kwargs) -> R: self.cache_dict[args] = result return result - # NOTE: mypy gets confused because it types `wraps` as - # Callable[[VarArg(Any)], Any] - # which, for some reason, is not compatible with `F` - return new_inner # type: ignore[return-value] + return new_inner class keyed_memoize_in(Generic[P, R]): # noqa @@ -1388,12 +1325,12 @@ def std_deviation(iterable: Iterable[float], finite_pop: int) -> Optional[float] # {{{ permutations, tuples, integer sequences -def wandering_element(length, wanderer=1, landscape=0): +def wandering_element(length: int, wanderer: float = 1, landscape: float = 0) -> Generator[Tuple[float, ...], None, None]: for i in range(length): yield i*(landscape,) + (wanderer,) + (length-1-i)*(landscape,) -def indices_in_shape(shape): +def indices_in_shape(shape: Sequence[int]) -> Generator[Tuple[int, ...], None, None]: from warnings import warn warn("indices_in_shape is deprecated. You should prefer numpy.ndindex.", DeprecationWarning, stacklevel=2) From 4c6c266d35c715f7d8094487bd274c5d36020000 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Mar 2023 17:32:36 -0500 Subject: [PATCH 6/8] pass 3 --- pytools/__init__.py | 46 ++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index c3df2f4b..93a060ee 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -716,7 +716,7 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: assert cache_dict_name is not None try: - return getattr(obj, cache_dict_name)[key] + return cast(R, getattr(obj, cache_dict_name)[key]) except AttributeError: attribute_error = True except KeyError: @@ -741,7 +741,7 @@ def clear_cache(obj: object) -> None: # into the function's dict is moderately sketchy. new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined] - return new_wrapper + return cast(Callable[Concatenate[T, P], R], new_wrapper) def memoize_method( @@ -797,7 +797,7 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: assert cache_dict_name is not None try: - return getattr(obj, cache_dict_name)[cache_key] + return cast(R, getattr(obj, cache_dict_name)[cache_key]) except AttributeError: result = function(obj, *args, **kwargs) object.__setattr__(obj, cache_dict_name, {cache_key: result}) @@ -815,10 +815,10 @@ def clear_cache(obj: Any) -> None: new_wrapper = update_wrapper(wrapper, function) new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined] - return new_wrapper + return cast(Callable[Concatenate[T, P], R], new_wrapper) -class keyed_memoize_method(keyed_memoize_on_first_arg): # noqa: N801 +class keyed_memoize_method(keyed_memoize_on_first_arg[T, P, R]): # noqa: N801 """Like :class:`memoize_method`, but additionally uses a function *key* to compute the key under which the function result is stored. @@ -870,7 +870,7 @@ def __init__(self, container: Any, identifier: Hashable) -> None: object.__setattr__(container, "_pytools_memoize_in_dict", memoize_in_dict) - self.cache_dict = memoize_in_dict.setdefault(identifier, {}) + self.cache_dict: Dict[P.args, R] = memoize_in_dict.setdefault(identifier, {}) def __call__(self, inner: Callable[P, R]) -> Callable[P, R]: @wraps(inner) @@ -907,7 +907,7 @@ def __init__(self, object.__setattr__(container, "_pytools_keyed_memoize_in_dict", memoize_in_dict) - self.cache_dict = memoize_in_dict.setdefault(identifier, {}) + self.cache_dict: Dict[P.args, R] = memoize_in_dict.setdefault(identifier, {}) self.key = key def __call__(self, inner: Callable[P, R]) -> Callable[P, R]: @@ -1243,7 +1243,7 @@ def distinct_pairs(list1: Sequence[Any], list2: Sequence[Any]) -> Generator[Tupl yield (xi, yj) -def cartesian_product_sum(list1: Sequence[Any], list2: Sequence[Any]) -> Generator[Any, None, None]: +def cartesian_product_sum(list1: List[int], list2: List[int]) -> Generator[int, None, None]: """This routine returns a list of sums of each element of list1 with each element of list2. Also works with lists. """ @@ -1350,17 +1350,19 @@ def indices_in_shape(shape: Sequence[int]) -> Generator[Tuple[int, ...], None, N yield (i,)+rest -def generate_nonnegative_integer_tuples_below(n, length=None, least=0): +def generate_nonnegative_integer_tuples_below(n: Union[int, Sequence[int]], length: Optional[int] = None, least: int = 0) -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: """n may be a sequence, in which case length must be None.""" if length is None: if not n: yield () return + n = cast(Sequence[int], n) my_n = n[0] n = n[1:] next_length = None else: + n = cast(int, n) my_n = n assert length >= 0 @@ -1373,14 +1375,16 @@ def generate_nonnegative_integer_tuples_below(n, length=None, least=0): for i in range(least, my_n): my_part = (i,) for base in generate_nonnegative_integer_tuples_below(n, next_length, least): + base = cast(Tuple[Any], base) yield my_part + base def generate_decreasing_nonnegative_tuples_summing_to( - n, length, min_value=0, max_value=None): + n: int, length: int, min_value: int = 0, max_value: Optional[int] = None) -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: if length == 0: yield () elif length == 1: + max_value = cast(int, max_value) if n <= max_value: #print "MX", n, max_value yield (n,) @@ -1397,7 +1401,7 @@ def generate_decreasing_nonnegative_tuples_summing_to( yield (i,) + remainder -def generate_nonnegative_integer_tuples_summing_to_at_most(n, length): +def generate_nonnegative_integer_tuples_summing_to_at_most(n: int, length: int) -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: """Enumerate all non-negative integer tuples summing to at most n, exhausting the search space by varying the first entry fastest, and the last entry the slowest. @@ -1416,7 +1420,7 @@ def generate_nonnegative_integer_tuples_summing_to_at_most(n, length): generate_positive_integer_tuples_below = generate_nonnegative_integer_tuples_below -def _pos_and_neg_adaptor(tuple_iter): +def _pos_and_neg_adaptor(tuple_iter: Iterable[Tuple[float, ...]]) -> Generator[Tuple[float, ...], None, None]: for tup in tuple_iter: nonzero_indices = [i for i in range(len(tup)) if tup[i] != 0] for do_neg_tup in generate_nonnegative_integer_tuples_below( @@ -1428,12 +1432,12 @@ def _pos_and_neg_adaptor(tuple_iter): yield tuple(this_result) -def generate_all_integer_tuples_below(n, length, least_abs=0): +def generate_all_integer_tuples_below(n: int, length: int, least_abs: int = 0) -> Generator[Tuple[float, ...], None, None]: return _pos_and_neg_adaptor(generate_nonnegative_integer_tuples_below( n, length, least_abs)) -def generate_permutations(original: Sequence[Any]) -> Generator[List[Any], None, None]: +def generate_permutations(original: List[Any]) -> Generator[List[Any], None, None]: """Generate all permutations of the list *original*. Nicked from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/252178 @@ -1447,7 +1451,7 @@ def generate_permutations(original: Sequence[Any]) -> Generator[List[Any], None, yield perm_[:i] + original[0:1] + perm_[i:] -def generate_unique_permutations(original): +def generate_unique_permutations(original: List[Any]) -> Generator[List[Any], None, None]: """Generate all unique permutations of the list *original*. """ @@ -1459,16 +1463,16 @@ def generate_unique_permutations(original): yield perm_ -def enumerate_basic_directions(dimensions: int): +def enumerate_basic_directions(dimensions: int) -> List[List[int]]: coordinate_list = [[0], [1], [-1]] - return reduce(cartesian_product_sum, [coordinate_list] * dimensions)[1:] + return reduce(cartesian_product_sum, [coordinate_list] * dimensions)[1:] # type: ignore[arg-type] # }}} # {{{ index mangling -def get_read_from_map_from_permutation(original, permuted): +def get_read_from_map_from_permutation(original: List[int], permuted: List[int]) -> Tuple[int, ...]: """With a permutation given by *original* and *permuted*, generate a list *rfm* of indices such that ``permuted[i] == original[rfm[i]]``. @@ -1495,7 +1499,7 @@ def get_read_from_map_from_permutation(original, permuted): return tuple(where_in_original[pi] for pi in permuted) -def get_write_to_map_from_permutation(original, permuted): +def get_write_to_map_from_permutation(original: List[int], permuted: List[int]) -> Tuple[int, ...]: """With a permutation given by *original* and *permuted*, generate a list *wtm* of indices such that ``permuted[wtm[i]] == original[i]``. @@ -1599,7 +1603,7 @@ def _get_alignments(self) -> Tuple[str, ...]: + (self.alignments[-1],) * (self.ncolumns - len(self.alignments)) ) - def _get_column_widths(self, rows: Tuple[Any, ...]) -> Tuple[int, ...]: + def _get_column_widths(self, rows: List[Tuple[Any, ...]]) -> Tuple[int, ...]: return tuple([ max(len(row[i]) for row in rows) for i in range(self.ncolumns) ]) @@ -1872,7 +1876,7 @@ def word_wrap(text: str, width: int, wrap_using: str = "\n") -> str: def _exec_arg(arg: str, execenv: Dict[str, Any]) -> None: import os if os.access(arg, os.F_OK): - exec(compile(open(arg), arg, "exec"), execenv) + exec(compile(open(arg).read(), arg, "exec"), execenv) else: exec(compile(arg, "", "exec"), execenv) From c59cbd01a29f007872c965c16c7473ced967e0af Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Mar 2023 17:49:39 -0500 Subject: [PATCH 7/8] pass 4 --- pytools/__init__.py | 132 +++++++++++++++++++++++++++++--------------- run-mypy.sh | 2 +- 2 files changed, 88 insertions(+), 46 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 93a060ee..8ca00551 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -37,7 +37,8 @@ from sys import intern from typing import ( Any, Callable, cast, ClassVar, Dict, Generic, Hashable, Iterable, List, Mapping, - Optional, Set, Tuple, Type, TypeVar, Union, ValuesView, KeysView, ItemsView, Sequence, Generator, TYPE_CHECKING) + Optional, Set, Tuple, Type, TypeVar, Union, ValuesView, KeysView, + ItemsView, Sequence, Generator, TYPE_CHECKING) if TYPE_CHECKING: @@ -45,9 +46,9 @@ try: - from typing import Concatenate, SupportsIndex + from typing import Concatenate except ImportError: - from typing_extensions import Concatenate, SupportsIndex + from typing_extensions import Concatenate try: from typing import ParamSpec @@ -493,7 +494,8 @@ def __getitem__(self, index: Union[slice, int]) -> Any: # {{{ dependent dictionary class DependentDictionary: - def __init__(self, f: Callable[[Any, Any], Any], start: Optional[Dict[Any, Any]] = None): + def __init__(self, f: Callable[[Any, Any], Any], + start: Optional[Dict[Any, Any]] = None) -> None: if start is None: start = {} @@ -581,7 +583,7 @@ def is_single_valued( def all_roughly_equal(iterable: Iterable[T], threshold: float) -> bool: return is_single_valued(iterable, - equality_pred=lambda a, b: abs(cast(float, a)-cast(float, b)) < threshold) + equality_pred=lambda a, b: abs(cast(float, a)-cast(float, b)) < threshold) def single_valued( @@ -957,7 +959,8 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: return decorator -def monkeypatch_class(_name: Any, bases: Sequence[Any], namespace: Dict[str, Any]) -> Any: +def monkeypatch_class(_name: Any, bases: Sequence[Any], + namespace: Dict[str, Any]) -> Any: # from GvR, http://mail.python.org/pipermail/python-dev/2008-January/076194.html assert len(bases) == 1, "Exactly one base class required" @@ -1019,7 +1022,8 @@ def general_sum(sequence: Sequence[float]) -> float: return reduce(operator.add, sequence) -def linear_combination(coefficients: Sequence[float], vectors: Sequence[float]) -> float: +def linear_combination(coefficients: Sequence[float], + vectors: Sequence[float]) -> float: result = coefficients[0] * vectors[0] for c, v in zip(coefficients[1:], vectors[1:]): result += c*v @@ -1044,11 +1048,13 @@ def common_prefix(iterable: Iterable[Any], empty: Optional[Any] = None) -> Any: return pfx -def decorate(function: Callable[[Any], Any], iterable: Iterable[Any]) -> List[Tuple[Any, ...]]: +def decorate(function: Callable[[Any], Any], + iterable: Iterable[Any]) -> List[Tuple[Any, ...]]: return [(x, function(x)) for x in iterable] -def partition(criterion: Callable[[Any], Any], iterable: Iterable[Any]) -> Tuple[Any, Any]: +def partition(criterion: Callable[[Any], Any], + iterable: Iterable[Any]) -> Tuple[Any, Any]: part_true = [] part_false = [] for i in iterable: @@ -1094,7 +1100,8 @@ def div_ceil(nr: int, dr: int) -> int: return -(-nr // dr) -def uniform_interval_splitting(n: int, granularity: int, max_intervals: int) -> Tuple[int, int]: +def uniform_interval_splitting(n: int, granularity: int, + max_intervals: int) -> Tuple[int, int]: """ Return *(interval_size, num_intervals)* such that:: num_intervals * interval_size >= n @@ -1120,7 +1127,8 @@ def uniform_interval_splitting(n: int, granularity: int, max_intervals: int) -> return interval_size, num_intervals -def find_max_where(predicate: Callable[[Any], bool], prec: float = 1e-5, initial_guess: float = 1, fail_bound: float = 1e38) -> float: +def find_max_where(predicate: Callable[[Any], bool], prec: float = 1e-5, + initial_guess: float = 1, fail_bound: float = 1e38) -> float: """Find the largest value for which a predicate is true, along a half-line. 0 is assumed to be the lower bound.""" @@ -1177,7 +1185,8 @@ def find_max_where(predicate: Callable[[Any], bool], prec: float = 1e-5, initial # {{{ argmin, argmax -def argmin2(iterable: Iterable[Sequence[Any]], return_value: bool = False) -> Union[Any, Tuple[int, Any]]: +def argmin2(iterable: Iterable[Sequence[Any]], return_value: bool = False) \ + -> Union[Any, Tuple[int, Any]]: it = iter(iterable) try: current_argmin, current_min = next(it) @@ -1195,7 +1204,8 @@ def argmin2(iterable: Iterable[Sequence[Any]], return_value: bool = False) -> Un return current_argmin -def argmax2(iterable: Iterable[Sequence[Any]], return_value: bool = False) -> Union[Any, Tuple[int, Any]]: +def argmax2(iterable: Iterable[Sequence[Any]], return_value: bool = False) \ + -> Union[Any, Tuple[int, Any]]: it = iter(iterable) try: current_argmax, current_max = next(it) @@ -1225,7 +1235,8 @@ def argmax(iterable: Iterable[Any]) -> Any: # {{{ cartesian products etc. -def cartesian_product(*args: Iterable[Any]) -> Generator[Tuple[Any, ...], None, None]: +def cartesian_product(*args: Iterable[Any]) \ + -> Generator[Tuple[Any, ...], None, None]: if len(args) == 1: for arg in args[0]: yield (arg,) @@ -1236,14 +1247,16 @@ def cartesian_product(*args: Iterable[Any]) -> Generator[Tuple[Any, ...], None, yield prod + (i,) -def distinct_pairs(list1: Sequence[Any], list2: Sequence[Any]) -> Generator[Tuple[Any, Any], None, None]: +def distinct_pairs(list1: Sequence[Any], list2: Sequence[Any]) \ + -> Generator[Tuple[Any, Any], None, None]: for i, xi in enumerate(list1): for j, yj in enumerate(list2): if i != j: yield (xi, yj) -def cartesian_product_sum(list1: List[int], list2: List[int]) -> Generator[int, None, None]: +def cartesian_product_sum(list1: List[int], list2: List[int]) \ + -> Generator[int, None, None]: """This routine returns a list of sums of each element of list1 with each element of list2. Also works with lists. """ @@ -1325,7 +1338,8 @@ def std_deviation(iterable: Iterable[float], finite_pop: int) -> Optional[float] # {{{ permutations, tuples, integer sequences -def wandering_element(length: int, wanderer: float = 1, landscape: float = 0) -> Generator[Tuple[float, ...], None, None]: +def wandering_element(length: int, wanderer: float = 1, landscape: float = 0) \ + -> Generator[Tuple[float, ...], None, None]: for i in range(length): yield i*(landscape,) + (wanderer,) + (length-1-i)*(landscape,) @@ -1350,7 +1364,10 @@ def indices_in_shape(shape: Sequence[int]) -> Generator[Tuple[int, ...], None, N yield (i,)+rest -def generate_nonnegative_integer_tuples_below(n: Union[int, Sequence[int]], length: Optional[int] = None, least: int = 0) -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: +def generate_nonnegative_integer_tuples_below(n: Union[int, Sequence[int]], + length: Optional[int] = None, + least: int = 0) \ + -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: """n may be a sequence, in which case length must be None.""" if length is None: if not n: @@ -1380,7 +1397,8 @@ def generate_nonnegative_integer_tuples_below(n: Union[int, Sequence[int]], leng def generate_decreasing_nonnegative_tuples_summing_to( - n: int, length: int, min_value: int = 0, max_value: Optional[int] = None) -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: + n: int, length: int, min_value: int = 0, max_value: Optional[int] = None) \ + -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: if length == 0: yield () elif length == 1: @@ -1401,7 +1419,8 @@ def generate_decreasing_nonnegative_tuples_summing_to( yield (i,) + remainder -def generate_nonnegative_integer_tuples_summing_to_at_most(n: int, length: int) -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: +def generate_nonnegative_integer_tuples_summing_to_at_most(n: int, length: int) \ + -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: """Enumerate all non-negative integer tuples summing to at most n, exhausting the search space by varying the first entry fastest, and the last entry the slowest. @@ -1420,7 +1439,8 @@ def generate_nonnegative_integer_tuples_summing_to_at_most(n: int, length: int) generate_positive_integer_tuples_below = generate_nonnegative_integer_tuples_below -def _pos_and_neg_adaptor(tuple_iter: Iterable[Tuple[float, ...]]) -> Generator[Tuple[float, ...], None, None]: +def _pos_and_neg_adaptor(tuple_iter: Iterable[Tuple[float, ...]]) \ + -> Generator[Tuple[float, ...], None, None]: for tup in tuple_iter: nonzero_indices = [i for i in range(len(tup)) if tup[i] != 0] for do_neg_tup in generate_nonnegative_integer_tuples_below( @@ -1432,7 +1452,8 @@ def _pos_and_neg_adaptor(tuple_iter: Iterable[Tuple[float, ...]]) -> Generator[T yield tuple(this_result) -def generate_all_integer_tuples_below(n: int, length: int, least_abs: int = 0) -> Generator[Tuple[float, ...], None, None]: +def generate_all_integer_tuples_below(n: int, length: int, least_abs: int = 0) \ + -> Generator[Tuple[float, ...], None, None]: return _pos_and_neg_adaptor(generate_nonnegative_integer_tuples_below( n, length, least_abs)) @@ -1451,7 +1472,8 @@ def generate_permutations(original: List[Any]) -> Generator[List[Any], None, Non yield perm_[:i] + original[0:1] + perm_[i:] -def generate_unique_permutations(original: List[Any]) -> Generator[List[Any], None, None]: +def generate_unique_permutations(original: List[Any]) \ + -> Generator[List[Any], None, None]: """Generate all unique permutations of the list *original*. """ @@ -1465,14 +1487,15 @@ def generate_unique_permutations(original: List[Any]) -> Generator[List[Any], No def enumerate_basic_directions(dimensions: int) -> List[List[int]]: coordinate_list = [[0], [1], [-1]] - return reduce(cartesian_product_sum, [coordinate_list] * dimensions)[1:] # type: ignore[arg-type] + return reduce(cartesian_product_sum, [coordinate_list] * dimensions)[1:] # type: ignore[arg-type] # noqa: E501 # }}} # {{{ index mangling -def get_read_from_map_from_permutation(original: List[int], permuted: List[int]) -> Tuple[int, ...]: +def get_read_from_map_from_permutation(original: List[int], permuted: List[int]) \ + -> Tuple[int, ...]: """With a permutation given by *original* and *permuted*, generate a list *rfm* of indices such that ``permuted[i] == original[rfm[i]]``. @@ -1499,7 +1522,8 @@ def get_read_from_map_from_permutation(original: List[int], permuted: List[int]) return tuple(where_in_original[pi] for pi in permuted) -def get_write_to_map_from_permutation(original: List[int], permuted: List[int]) -> Tuple[int, ...]: +def get_write_to_map_from_permutation(original: List[int], permuted: List[int]) \ + -> Tuple[int, ...]: """With a permutation given by *original* and *permuted*, generate a list *wtm* of indices such that ``permuted[wtm[i]] == original[i]``. @@ -1798,8 +1822,10 @@ def remove_columns(i: int, row: Tuple[Any, ...]) -> Tuple[Any, ...]: # {{{ histogram formatting def string_histogram( # pylint: disable=too-many-arguments,too-many-locals - iterable: Iterable[float], min_value: Optional[float] = None, max_value: Optional[float] = None, - bin_count: int = 20, width: int = 70, bin_starts: Optional[Sequence[float]] = None, use_unicode: bool = True) -> str: + iterable: Iterable[float], min_value: Optional[float] = None, + max_value: Optional[float] = None, bin_count: int = 20, + width: int = 70, bin_starts: Optional[Sequence[float]] = None, + use_unicode: bool = True) -> str: if bin_starts is None: if min_value is None or max_value is None: iterable = list(iterable) @@ -1885,7 +1911,9 @@ class CPyUserInterface: class Parameters(Record): pass - def __init__(self, variables: Dict[str, Any], constants: Optional[Dict[str, Any]] = None, doc: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, variables: Dict[str, Any], + constants: Optional[Dict[str, Any]] = None, + doc: Optional[Dict[str, Any]] = None) -> None: if constants is None: constants = {} if doc is None: @@ -1968,7 +1996,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def typedump(val: Any, max_seq: int = 5, - special_handlers: Optional[Mapping[Type[Any], Callable[[Any], str]]] = None, + special_handlers: + Optional[Mapping[Type[Any], Callable[[Any], str]]] = None, fully_qualified_name: bool = True) -> str: """ Return a string representation of the type of *val*, recursing into @@ -2036,7 +2065,8 @@ def objname(obj: Any) -> str: return objname(val) -def invoke_editor(s: str, filename: str = "edit.txt", descr: str = "the file") -> str: +def invoke_editor(s: str, filename: str = "edit.txt", descr: str = "the file") \ + -> str: from tempfile import mkdtemp tempdir = mkdtemp() @@ -2077,7 +2107,8 @@ class ProgressBar: # pylint: disable=too-many-instance-attributes .. automethod:: __enter__ .. automethod:: __exit__ """ - def __init__(self, descr: str, total: float, initial: float = 0, length: float = 40): + def __init__(self, descr: str, total: float, initial: float = 0, + length: float = 40) -> None: import time self.description = descr self.total = total @@ -2166,7 +2197,8 @@ def add_python_path_relative_to_script(rel_path: str) -> None: # {{{ numpy dtype mangling -def common_dtype(dtypes: Sequence["np.dtype[Any]"], default: Optional["np.dtype[Any]"] = None) -> "np.dtype[Any]": +def common_dtype(dtypes: Sequence["np.dtype[Any]"], + default: Optional["np.dtype[Any]"] = None) -> "np.dtype[Any]": import numpy as np dtypes = list(dtypes) @@ -2185,7 +2217,8 @@ def to_uncomplex_dtype(dtype: "np.dtype[Any]") -> "Type[Any]": return np.array(1, dtype=dtype).real.dtype.type -def match_precision(dtype: "np.dtype[Any]", dtype_to_match: "np.dtype[Any]") -> "np.dtype[Any]": +def match_precision(dtype: "np.dtype[Any]", dtype_to_match: "np.dtype[Any]") \ + -> "np.dtype[Any]": import numpy tgt_is_double = dtype_to_match in [ @@ -2368,7 +2401,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # {{{ download from web if not present -def download_from_web_if_not_present(url: str, local_name: Optional[str] = None) -> None: +def download_from_web_if_not_present(url: str, local_name: Optional[str] = None) \ + -> None: """ .. versionadded:: 2017.5 """ @@ -2396,7 +2430,8 @@ def download_from_web_if_not_present(url: str, local_name: Optional[str] = None) # {{{ find git revisions -def find_git_revision(tree_root): # type: ignore[no-untyped-def] # pylint: disable=too-many-locals +def find_git_revision(tree_root: str) \ + -> Optional[bytes]: # pylint: disable=too-many-locals # Keep this routine self-contained so that it can be copy-pasted into # setup.py. @@ -2440,7 +2475,7 @@ def find_git_revision(tree_root): # type: ignore[no-untyped-def] # pylint: dis return git_rev -def find_module_git_revision(module_file, n_levels_up): # type: ignore[no-untyped-def] +def find_module_git_revision(module_file, n_levels_up): from os.path import dirname, join tree_root = join(*([dirname(module_file)] + [".." * n_levels_up])) @@ -2451,7 +2486,8 @@ def find_module_git_revision(module_file, n_levels_up): # type: ignore[no-untyp # {{{ create a reshaped view of a numpy array -def reshaped_view(a: "np.ndarray[Any, Any]", newshape: Tuple[int, ...]) -> "np.ndarray[Any, Any]": +def reshaped_view(a: "np.ndarray[Any, Any]", newshape: Tuple[int, ...]) \ + -> "np.ndarray[Any, Any]": """ Create a new view object with shape ``newshape`` without copying the data of ``a``. This function is different from ``numpy.reshape`` by raising an exception when data copy is necessary. @@ -2541,7 +2577,8 @@ class ProcessLogger: # pylint: disable=too-many-instance-attributes def __init__( # pylint: disable=too-many-arguments self, logger: logging.Logger, description: str, - silent_level: Optional[int] =None, noisy_level: Optional[int] = None, long_threshold_seconds: Optional[float] = None) -> None: + silent_level: Optional[int] = None, noisy_level: Optional[int] = + None, long_threshold_seconds: Optional[float] = None) -> None: self.logger = logger self.description = description self.silent_level = silent_level or logging.DEBUG @@ -2614,7 +2651,8 @@ def done( # pylint: disable=keyword-arg-before-vararg completion_level = ( self.noisy_level - if self.timer.wall_elapsed is not None and self.timer.wall_elapsed > self.long_threshold_seconds + if (self.timer.wall_elapsed is not None + and self.timer.wall_elapsed > self.long_threshold_seconds) else self.silent_level) msg = "%s: completed (%s)" @@ -2645,7 +2683,8 @@ class log_process: # noqa: N801 .. automethod:: __call__ """ - def __init__(self, logger: logging.Logger, description: Optional[str] = None, long_threshold_seconds: Optional[float] = None) -> None: + def __init__(self, logger: logging.Logger, description: Optional[str] = None, + long_threshold_seconds: Optional[float] = None) -> None: self.logger = logger self.description = description self.long_threshold_seconds = long_threshold_seconds @@ -2690,7 +2729,8 @@ def natorder(item: Any) -> List[Any]: return result -def natsorted(iterable: Iterable[Any], key: Optional[Callable[[Any], Any]] = None, reverse: bool = False) -> List[Any]: +def natsorted(iterable: Iterable[Any], key: Optional[Callable[[Any], Any]] = None, + reverse: bool = False) -> List[Any]: """Sort using natural order [1]_, as opposed to lexicographic order. Example:: @@ -2716,7 +2756,7 @@ def natsorted(iterable: Iterable[Any], key: Optional[Callable[[Any], Any]] = Non key = lambda x: x # type-ignore-reason: mypy thinks key could be None - return sorted(iterable, key=lambda y: natorder(key(y)), reverse=reverse) # type: ignore[misc] + return sorted(iterable, key=lambda y: natorder(key(y)), reverse=reverse) # type: ignore[misc] # noqa: E501 # }}} @@ -2779,7 +2819,8 @@ def resolve_name(name: str) -> object: # {{{ unordered_hash -def unordered_hash(hash_instance: Any, iterable: Iterable[Any], hash_constructor: Optional[Callable[[], Any]] = None) -> Any: +def unordered_hash(hash_instance: Any, iterable: Iterable[Any], + hash_constructor: Optional[Callable[[], Any]] = None) -> Any: """Using a hash algorithm given by the parameter-less constructor *hash_constructor*, return a hash object whose internal state depends on the entries of *iterable*, but not their order. If *hash* @@ -2824,7 +2865,8 @@ def unordered_hash(hash_instance: Any, iterable: Iterable[Any], hash_constructor # {{{ sphere_sample -def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0) -> "np.ndarray[Any, Any]": +def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0) \ + -> "np.ndarray[Any, Any]": """Generate points regularly distributed on a sphere based on https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf. diff --git a/run-mypy.sh b/run-mypy.sh index 5e3ad909..53754863 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -4,4 +4,4 @@ set -ex mypy --show-error-codes pytools -mypy --strict --follow-imports=skip pytools/datatable.py pytools/graph.py pytools/mpi.py +mypy --strict pytools/datatable.py pytools/graph.py pytools/mpi.py From 953870b06f65556666df186e73b77050e5d28d64 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Mar 2023 17:53:47 -0500 Subject: [PATCH 8/8] pass 5 --- doc/conf.py | 15 +++++++++++++++ pytools/__init__.py | 13 ++++++++----- pytools/mpi.py | 2 +- run-mypy.sh | 2 +- setup.py | 1 + 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 659ed9e9..cb5d5ead 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,5 +1,6 @@ from urllib.request import urlopen + _conf_url = \ "https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py" with urlopen(_conf_url) as _inf: @@ -41,3 +42,17 @@ autodoc_type_aliases = {"GraphT": "pytools.graph.GraphT", "NodeT": "pytools.graph.NodeT", } + +# Some modules need to import things just so that sphinx can resolve symbols in +# type annotations. Often, we do not want these imports (e.g. of PyOpenCL) when +# in normal use (because they would introduce unintended side effects or hard +# dependencies). This flag exists so that these imports only occur during doc +# build. Since sphinx appears to resolve type hints lexically (as it should), +# this needs to be cross-module (since, e.g. an inherited pytools +# docstring can be read by sphinx when building meshmode, a dependent package), +# this needs a setting of the same name across all packages involved, that's +# why this name is as global-sounding as it is. +import sys + + +sys._BUILDING_SPHINX_DOCS = True diff --git a/pytools/__init__.py b/pytools/__init__.py index 8ca00551..3a1b0fa1 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -36,14 +36,17 @@ from functools import reduce, wraps from sys import intern from typing import ( - Any, Callable, cast, ClassVar, Dict, Generic, Hashable, Iterable, List, Mapping, - Optional, Set, Tuple, Type, TypeVar, Union, ValuesView, KeysView, - ItemsView, Sequence, Generator, TYPE_CHECKING) + TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Generic, Hashable, + ItemsView, Iterable, KeysView, List, Mapping, Optional, Sequence, Set, Tuple, + Type, TypeVar, Union, ValuesView, cast) if TYPE_CHECKING: import numpy as np +if getattr(sys, "_BUILDING_SPHINX_DOCS", False): + import numpy as np # noqa: F811 + try: from typing import Concatenate @@ -2475,11 +2478,11 @@ def find_git_revision(tree_root: str) \ return git_rev -def find_module_git_revision(module_file, n_levels_up): +def find_module_git_revision(module_file: str, n_levels_up: int) -> Optional[bytes]: from os.path import dirname, join tree_root = join(*([dirname(module_file)] + [".." * n_levels_up])) - return find_git_revision(tree_root) # type: ignore[no-untyped-call] + return find_git_revision(tree_root) # }}} diff --git a/pytools/mpi.py b/pytools/mpi.py index 27a0cb9d..392ee30d 100644 --- a/pytools/mpi.py +++ b/pytools/mpi.py @@ -72,7 +72,7 @@ def run_with_mpi_ranks(py_script: str, ranks: int, def pytest_raises_on_rank(my_rank: int, fail_rank: int, expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]]) \ - -> Generator[AbstractContextManager[Any], None, None]: + -> Generator[Any, None, None]: """ Like :func:`pytest.raises`, but only expect an exception on rank *fail_rank*. """ diff --git a/run-mypy.sh b/run-mypy.sh index 53754863..f3e96420 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -4,4 +4,4 @@ set -ex mypy --show-error-codes pytools -mypy --strict pytools/datatable.py pytools/graph.py pytools/mpi.py +mypy --strict pytools/datatable.py pytools/graph.py pytools/mpi.py pytools/__init__.py diff --git a/setup.py b/setup.py index 852fd244..1df4e668 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ from setuptools import setup + ver_dic = {} version_file = open("pytools/version.py") try: