diff --git a/pytools/__init__.py b/pytools/__init__.py index 8f3c9739..de00ed0b 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -36,14 +36,22 @@ from functools import reduce, wraps from sys import intern from typing import ( - Any, Callable, ClassVar, Dict, Generic, Hashable, Iterable, List, Mapping, - Optional, Set, Tuple, Type, TypeVar, Union) + TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Generic, Hashable, + ItemsView, Iterable, KeysView, List, Mapping, Optional, Sequence, Set, Tuple, + Type, TypeVar, Union, ValuesView, cast) + + +if TYPE_CHECKING: + import numpy as np + +if getattr(sys, "_BUILDING_SPHINX_DOCS", False): + import numpy as np # noqa: F811 try: - from typing import Concatenate, SupportsIndex + from typing import Concatenate except ImportError: - from typing_extensions import Concatenate, SupportsIndex + from typing_extensions import Concatenate try: from typing import ParamSpec @@ -227,7 +235,7 @@ def __init__(self, f: F, deadline: Optional[Union[int, str]] = None) -> None: self.f = f self.deadline = deadline - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: from warnings import warn warn(f"This function is deprecated and will go away in {self.deadline}. " f"Use {self.f.__module__}.{self.f.__name__} instead.", @@ -238,7 +246,7 @@ def __call__(self, *args, **kwargs): def deprecate_keyword(oldkey: str, newkey: Optional[str] = None, *, - deadline: Optional[str] = None): + deadline: Optional[str] = None) -> Callable[[Any], Any]: """Decorator used to deprecate function keyword arguments. :arg oldkey: deprecated argument name. @@ -250,9 +258,9 @@ def deprecate_keyword(oldkey: str, if deadline is None: deadline = "the future" - def wrapper(func): + def wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]: @wraps(func) - def inner_wrapper(*args, **kwargs): + def inner_wrapper(*args: Any, **kwargs: Any) -> Any: if oldkey in kwargs: if newkey is None: warn(f"The '{oldkey}' keyword is deprecated and will " @@ -282,7 +290,7 @@ def inner_wrapper(*args, **kwargs): # {{{ math -def delta(x, y): +def delta(x: float, y: float) -> float: if x == y: return 1 else: @@ -303,95 +311,33 @@ def levi_civita(tup: Tuple[int, ...]) -> int: factorial = MovedFunctionDeprecationWrapper(math.factorial, deadline=2023) -try: - # NOTE: only available in python >= 3.8 - perm = MovedFunctionDeprecationWrapper(math.perm, deadline=2023) -except AttributeError: - def _unchecked_perm(n, k): - result = 1 - while k: - result *= n - n -= 1 - k -= 1 +perm = MovedFunctionDeprecationWrapper(math.perm, deadline=2023) - return result +comb = MovedFunctionDeprecationWrapper(math.comb, deadline=2023) - def perm(n: SupportsIndex, # type: ignore[misc] - k: Optional[SupportsIndex] = None) -> int: - """ - :returns: :math:`P(n, k)`, the number of permutations of length :math:`k` - drawn from :math:`n` choices. - """ - from warnings import warn - warn("This function is deprecated and will go away in 2023. " - "Use `math.perm` instead, which is available from Python 3.8.", - DeprecationWarning, stacklevel=2) - if k is None: - return math.factorial(n) - - import operator - n, k = operator.index(n), operator.index(k) - if k > n: - return 0 - - if k < 0: - raise ValueError("k must be a non-negative integer") - - if n < 0: - raise ValueError("n must be a non-negative integer") - - from numbers import Integral - if not isinstance(k, Integral): - raise TypeError(f"'{type(k).__name__}' object cannot be interpreted " - "as an integer") - - if not isinstance(n, Integral): - raise TypeError(f"'{type(n).__name__}' object cannot be interpreted " - "as an integer") - - return _unchecked_perm(n, k) - -try: - # NOTE: only available in python >= 3.8 - comb = MovedFunctionDeprecationWrapper(math.comb, deadline=2023) -except AttributeError: - def comb(n: SupportsIndex, # type: ignore[misc] - k: SupportsIndex) -> int: - """ - :returns: :math:`C(n, k)`, the number of combinations (subsets) - of length :math:`k` drawn from :math:`n` choices. - """ - from warnings import warn - warn("This function is deprecated and will go away in 2023. " - "Use `math.comb` instead, which is available from Python 3.8.", - DeprecationWarning, stacklevel=2) - - return _unchecked_perm(n, k) // math.factorial(k) - - -def norm_1(iterable): +def norm_1(iterable: Iterable[float]) -> float: return sum(abs(x) for x in iterable) -def norm_2(iterable): - return sum(x**2 for x in iterable)**0.5 +def norm_2(iterable: Iterable[float]) -> float: + return math.sqrt(sum(x**2 for x in iterable)) -def norm_inf(iterable): +def norm_inf(iterable: Iterable[float]) -> float: return max(abs(x) for x in iterable) -def norm_p(iterable, p): - return sum(i**p for i in iterable)**(1/p) +def norm_p(iterable: Iterable[float], p: float) -> float: + return math.pow(sum(i**p for i in iterable), 1/p) class Norm: - def __init__(self, p): + def __init__(self, p: float) -> None: self.p = p - def __call__(self, iterable): - return sum(i**self.p for i in iterable)**(1/self.p) + def __call__(self, iterable: Iterable[float]) -> float: + return math.pow(sum(i**self.p for i in iterable), 1/self.p) # }}} @@ -408,7 +354,8 @@ class RecordWithoutPickling: __slots__: ClassVar[List[str]] = [] fields: ClassVar[Set[str]] - def __init__(self, valuedict=None, exclude=None, **kwargs): + def __init__(self, valuedict: Optional[Dict[str, Any]] = None, + exclude: Optional[List[str]] = None, **kwargs: Any) -> None: assert self.__class__ is not Record if exclude is None: @@ -427,7 +374,7 @@ def __init__(self, valuedict=None, exclude=None, **kwargs): fields.add(key) setattr(self, key, value) - def get_copy_kwargs(self, **kwargs): + def get_copy_kwargs(self, **kwargs: Any) -> Any: for f in self.__class__.fields: if f not in kwargs: try: @@ -436,17 +383,17 @@ def get_copy_kwargs(self, **kwargs): pass return kwargs - def copy(self, **kwargs): + def copy(self, **kwargs: Any) -> "RecordWithoutPickling": return self.__class__(**self.get_copy_kwargs(**kwargs)) - def __repr__(self): + def __repr__(self) -> str: return "{}({})".format( self.__class__.__name__, ", ".join(f"{fld}={getattr(self, fld)!r}" for fld in self.__class__.fields if hasattr(self, fld))) - def register_fields(self, new_fields): + def register_fields(self, new_fields: Set[str]) -> None: try: fields = self.__class__.fields except AttributeError: @@ -454,7 +401,7 @@ def register_fields(self, new_fields): fields.update(new_fields) - def __getattr__(self, name): + def __getattr__(self, name: str) -> None: # This method is implemented to avoid pylint 'no-member' errors for # attribute access. raise AttributeError( @@ -465,13 +412,13 @@ def __getattr__(self, name): class Record(RecordWithoutPickling): __slots__: ClassVar[List[str]] = [] - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { key: getattr(self, key) for key in self.__class__.fields if hasattr(self, key)} - def __setstate__(self, valuedict): + def __setstate__(self, valuedict: Dict[str, Any]) -> None: try: fields = self.__class__.fields except AttributeError: @@ -481,30 +428,33 @@ def __setstate__(self, valuedict): fields.add(key) setattr(self, key, value) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self is other: return True + if not isinstance(other, Record): + return False return (self.__class__ == other.__class__ and self.__getstate__() == other.__getstate__()) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) class ImmutableRecordWithoutPickling(RecordWithoutPickling): """Hashable record. Does not explicitly enforce immutability.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: RecordWithoutPickling.__init__(self, *args, **kwargs) - self._cached_hash = None + self._cached_hash: Optional[int] = None - def __hash__(self): + def __hash__(self) -> int: # This attribute may vanish during pickling. if getattr(self, "_cached_hash", None) is None: self._cached_hash = hash( (type(self),) + tuple(getattr(self, field) for field in self.__class__.fields)) - return self._cached_hash + from typing import cast + return cast(int, self._cached_hash) class ImmutableRecord(ImmutableRecordWithoutPickling, Record): @@ -514,73 +464,76 @@ class ImmutableRecord(ImmutableRecordWithoutPickling, Record): class Reference: - def __init__(self, value): + def __init__(self, value: Any) -> None: self.value = value - def get(self): + def get(self) -> Any: from warnings import warn warn("Reference.get() is deprecated -- use ref.value instead") return self.value - def set(self, value): + def set(self, value: Any) -> None: self.value = value class FakeList: - def __init__(self, f, length): + def __init__(self, f: Callable[[int], Any], length: int) -> None: self._Length = length self._Function = f - def __len__(self): + def __len__(self) -> int: return self._Length - def __getitem__(self, index): + def __getitem__(self, index: Union[slice, int]) -> Any: try: + assert isinstance(index, slice) return [self._Function(i) for i in range(*index.indices(self._Length))] except AttributeError: + assert isinstance(index, int) return self._Function(index) # {{{ dependent dictionary class DependentDictionary: - def __init__(self, f, start=None): + def __init__(self, f: Callable[[Any, Any], Any], + start: Optional[Dict[Any, Any]] = None) -> None: if start is None: start = {} self._Function = f self._Dictionary = start.copy() - def copy(self): + def copy(self) -> "DependentDictionary": return DependentDictionary(self._Function, self._Dictionary) - def __contains__(self, key): + def __contains__(self, key: Hashable) -> bool: try: self[key] # pylint: disable=pointless-statement return True except KeyError: return False - def __getitem__(self, key): + def __getitem__(self, key: Hashable) -> Any: try: return self._Dictionary[key] except KeyError: return self._Function(self._Dictionary, key) - def __setitem__(self, key, value): + def __setitem__(self, key: Hashable, value: Any) -> None: self._Dictionary[key] = value - def genuineKeys(self): # noqa + def genuineKeys(self) -> List[Hashable]: # noqa return list(self._Dictionary.keys()) - def iteritems(self): + def iteritems(self) -> ItemsView[Any, Any]: return self._Dictionary.items() - def iterkeys(self): + def iterkeys(self) -> KeysView[Any]: return self._Dictionary.keys() - def itervalues(self): + def itervalues(self) -> ValuesView[Any]: return self._Dictionary.values() # }}} @@ -600,7 +553,7 @@ def one(iterable: Iterable[T]) -> T: except StopIteration: raise ValueError("empty iterable passed to 'one()'") - def no_more(): + def no_more() -> bool: try: next(it) raise ValueError("iterable with more than one entry passed to 'one()'") @@ -631,9 +584,9 @@ def is_single_valued( all_equal = is_single_valued -def all_roughly_equal(iterable, threshold): +def all_roughly_equal(iterable: Iterable[T], threshold: float) -> bool: return is_single_valued(iterable, - equality_pred=lambda a, b: abs(a-b) < threshold) + equality_pred=lambda a, b: abs(cast(float, a)-cast(float, b)) < threshold) def single_valued( @@ -649,7 +602,7 @@ def single_valued( except StopIteration: raise ValueError("empty iterable passed to 'single_valued()'") - def others_same(): + def others_same() -> bool: for other_item in it: if not equality_pred(other_item, first_item): return False @@ -663,7 +616,7 @@ def others_same(): # {{{ memoization / attribute storage -def memoize(*args: F, **kwargs: Any) -> F: +def memoize(*args: F, **kwargs: Any) -> Callable[[Any], Any]: """Stores previously computed function values in a cache. Two keyword-only arguments are supported: @@ -680,7 +633,7 @@ def memoize(*args: F, **kwargs: Any) -> F: default_key_func: Optional[Callable[..., Any]] if use_kw: - def default_key_func(*inner_args, **inner_kwargs): # noqa pylint:disable=function-redefined + def default_key_func(*inner_args: Any, **inner_kwargs: Any) -> Any: # noqa pylint:disable=function-redefined return inner_args, frozenset(inner_kwargs.items()) else: default_key_func = None @@ -693,19 +646,19 @@ def default_key_func(*inner_args, **inner_kwargs): # noqa pylint:disable=functi ", ".join(kwargs.keys()))) if key_func is not None: - def _decorator(func): - def wrapper(*args, **kwargs): + def _decorator(func: F) -> Callable[[Any], Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: key = key_func(*args, **kwargs) try: - return func._memoize_dic[key] # noqa: E501 # pylint: disable=protected-access + return func._memoize_dic[key] # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access except AttributeError: # _memoize_dic doesn't exist yet. result = func(*args, **kwargs) - func._memoize_dic = {key: result} # noqa: E501 # pylint: disable=protected-access + func._memoize_dic = {key: result} # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access return result except KeyError: result = func(*args, **kwargs) - func._memoize_dic[key] = result # noqa: E501 # pylint: disable=protected-access + func._memoize_dic[key] = result # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access return result from functools import update_wrapper @@ -713,18 +666,18 @@ def wrapper(*args, **kwargs): return wrapper else: - def _decorator(func): - def wrapper(*args): + def _decorator(func: F) -> Callable[[Any], Any]: + def wrapper(*args: Any) -> Any: try: - return func._memoize_dic[args] # noqa: E501 # pylint: disable=protected-access + return func._memoize_dic[args] # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access except AttributeError: # _memoize_dic doesn't exist yet. result = func(*args) - func._memoize_dic = {args: result} # noqa: E501 # pylint:disable=protected-access + func._memoize_dic = {args: result} # type: ignore[attr-defined] # noqa: E501 # pylint:disable=protected-access return result except KeyError: result = func(*args) - func._memoize_dic[args] = result # noqa: E501 # pylint: disable=protected-access + func._memoize_dic[args] = result # type: ignore[attr-defined] # noqa: E501 # pylint: disable=protected-access return result from functools import update_wrapper @@ -732,7 +685,7 @@ def wrapper(*args): return wrapper if not args: - return _decorator # type: ignore + return _decorator if callable(args[0]) and len(args) == 1: return _decorator(args[0]) raise TypeError( @@ -768,7 +721,7 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: assert cache_dict_name is not None try: - return getattr(obj, cache_dict_name)[key] + return cast(R, getattr(obj, cache_dict_name)[key]) except AttributeError: attribute_error = True except KeyError: @@ -782,7 +735,8 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: getattr(obj, cache_dict_name)[key] = result return result - def clear_cache(obj): + def clear_cache(obj: object) -> None: + assert cache_dict_name is not None object.__delattr__(obj, cache_dict_name) from functools import update_wrapper @@ -792,7 +746,7 @@ def clear_cache(obj): # into the function's dict is moderately sketchy. new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined] - return new_wrapper + return cast(Callable[Concatenate[T, P], R], new_wrapper) def memoize_method( @@ -848,7 +802,7 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: assert cache_dict_name is not None try: - return getattr(obj, cache_dict_name)[cache_key] + return cast(R, getattr(obj, cache_dict_name)[cache_key]) except AttributeError: result = function(obj, *args, **kwargs) object.__setattr__(obj, cache_dict_name, {cache_key: result}) @@ -858,17 +812,18 @@ def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: getattr(obj, cache_dict_name)[cache_key] = result return result - def clear_cache(obj): + def clear_cache(obj: Any) -> None: + assert cache_dict_name is not None object.__delattr__(obj, cache_dict_name) from functools import update_wrapper new_wrapper = update_wrapper(wrapper, function) new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined] - return new_wrapper + return cast(Callable[Concatenate[T, P], R], new_wrapper) -class keyed_memoize_method(keyed_memoize_on_first_arg): # noqa: N801 +class keyed_memoize_method(keyed_memoize_on_first_arg[T, P, R]): # noqa: N801 """Like :class:`memoize_method`, but additionally uses a function *key* to compute the key under which the function result is stored. @@ -884,7 +839,7 @@ class keyed_memoize_method(keyed_memoize_on_first_arg): # noqa: N801 Can memoize methods on classes that do not allow setting attributes (e.g. by overwritting ``__setattr__``), e.g. frozen :mod:`dataclasses`. """ - def _default_cache_dict_name(self, function): + def _default_cache_dict_name(self, function: Callable[[Any], Any]) -> str: return intern(f"_memoize_dic_{function.__name__}") @@ -920,7 +875,7 @@ def __init__(self, container: Any, identifier: Hashable) -> None: object.__setattr__(container, "_pytools_memoize_in_dict", memoize_in_dict) - self.cache_dict = memoize_in_dict.setdefault(identifier, {}) + self.cache_dict: Dict[P.args, R] = memoize_in_dict.setdefault(identifier, {}) def __call__(self, inner: Callable[P, R]) -> Callable[P, R]: @wraps(inner) @@ -934,10 +889,7 @@ def new_inner(*args: P.args, **kwargs: P.kwargs) -> R: self.cache_dict[args] = result return result - # NOTE: mypy gets confused because it types `wraps` as - # Callable[[VarArg(Any)], Any] - # which, for some reason, is not compatible with `F` - return new_inner # type: ignore[return-value] + return new_inner class keyed_memoize_in(Generic[P, R]): # noqa @@ -960,7 +912,7 @@ def __init__(self, object.__setattr__(container, "_pytools_keyed_memoize_in_dict", memoize_in_dict) - self.cache_dict = memoize_in_dict.setdefault(identifier, {}) + self.cache_dict: Dict[P.args, R] = memoize_in_dict.setdefault(identifier, {}) self.key = key def __call__(self, inner: Callable[P, R]) -> Callable[P, R]: @@ -989,28 +941,29 @@ class InfixOperator: Following a recipe from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/384122 """ - def __init__(self, function): + def __init__(self, function: Callable[..., Any]) -> None: self.function = function - def __rlshift__(self, other): + def __rlshift__(self, other: Any) -> "InfixOperator": return InfixOperator(lambda x: self.function(other, x)) - def __rshift__(self, other): + def __rshift__(self, other: Any) -> Any: return self.function(other) - def call(self, a, b): + def call(self, a: Any, b: Any) -> Any: return self.function(a, b) -def monkeypatch_method(cls): +def monkeypatch_method(cls: Any) -> Callable[..., Any]: # from GvR, http://mail.python.org/pipermail/python-dev/2008-January/076194.html - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: setattr(cls, func.__name__, func) return func return decorator -def monkeypatch_class(_name, bases, namespace): +def monkeypatch_class(_name: Any, bases: Sequence[Any], + namespace: Dict[str, Any]) -> Any: # from GvR, http://mail.python.org/pipermail/python-dev/2008-January/076194.html assert len(bases) == 1, "Exactly one base class required" @@ -1025,15 +978,15 @@ def monkeypatch_class(_name, bases, namespace): # {{{ generic utilities -def add_tuples(t1, t2): +def add_tuples(t1: Tuple[Any, ...], t2: Tuple[Any, ...]) -> Tuple[Any, ...]: return tuple([t1v + t2v for t1v, t2v in zip(t1, t2)]) -def negate_tuple(t1): +def negate_tuple(t1: Tuple[Any, ...]) -> Tuple[Any, ...]: return tuple([-t1v for t1v in t1]) -def shift(vec, dist): +def shift(vec: List[Any], dist: int) -> List[Any]: """Return a copy of *vec* shifted by *dist* such that .. code:: python @@ -1054,11 +1007,11 @@ def shift(vec, dist): return result -def len_iterable(iterable): +def len_iterable(iterable: Iterable[Any]) -> int: return sum(1 for i in iterable) -def flatten(iterable): +def flatten(iterable: Iterable[Any]) -> Generator[Any, None, None]: """For an iterable of sub-iterables, generate each member of each sub-iterable in turn, i.e. a flattened version of that super-iterable. @@ -1068,18 +1021,19 @@ def flatten(iterable): yield from sublist -def general_sum(sequence): +def general_sum(sequence: Sequence[float]) -> float: return reduce(operator.add, sequence) -def linear_combination(coefficients, vectors): +def linear_combination(coefficients: Sequence[float], + vectors: Sequence[float]) -> float: result = coefficients[0] * vectors[0] for c, v in zip(coefficients[1:], vectors[1:]): result += c*v return result -def common_prefix(iterable, empty=None): +def common_prefix(iterable: Iterable[Any], empty: Optional[Any] = None) -> Any: it = iter(iterable) try: pfx = next(it) @@ -1097,11 +1051,13 @@ def common_prefix(iterable, empty=None): return pfx -def decorate(function, iterable): +def decorate(function: Callable[[Any], Any], + iterable: Iterable[Any]) -> List[Tuple[Any, ...]]: return [(x, function(x)) for x in iterable] -def partition(criterion, iterable): +def partition(criterion: Callable[[Any], Any], + iterable: Iterable[Any]) -> Tuple[Any, Any]: part_true = [] part_false = [] for i in iterable: @@ -1112,7 +1068,7 @@ def partition(criterion, iterable): return part_true, part_false -def partition2(iterable): +def partition2(iterable: Iterable[Any]) -> Tuple[Any, Any]: part_true = [] part_false = [] for pred, i in iterable: @@ -1128,7 +1084,7 @@ def product(iterable: Iterable[Any]) -> Any: return reduce(mul, iterable, 1) -def reverse_dictionary(the_dict): +def reverse_dictionary(the_dict: Dict[Any, Any]) -> Dict[Any, Any]: result = {} for key, value in the_dict.items(): if value in result: @@ -1138,16 +1094,17 @@ def reverse_dictionary(the_dict): return result -def set_sum(set_iterable): +def set_sum(set_iterable: Iterable[Set[Any]]) -> Set[Any]: from operator import or_ return reduce(or_, set_iterable, set()) -def div_ceil(nr, dr): +def div_ceil(nr: int, dr: int) -> int: return -(-nr // dr) -def uniform_interval_splitting(n, granularity, max_intervals): +def uniform_interval_splitting(n: int, granularity: int, + max_intervals: int) -> Tuple[int, int]: """ Return *(interval_size, num_intervals)* such that:: num_intervals * interval_size >= n @@ -1173,7 +1130,8 @@ def uniform_interval_splitting(n, granularity, max_intervals): return interval_size, num_intervals -def find_max_where(predicate, prec=1e-5, initial_guess=1, fail_bound=1e38): +def find_max_where(predicate: Callable[[Any], bool], prec: float = 1e-5, + initial_guess: float = 1, fail_bound: float = 1e38) -> float: """Find the largest value for which a predicate is true, along a half-line. 0 is assumed to be the lower bound.""" @@ -1230,7 +1188,8 @@ def find_max_where(predicate, prec=1e-5, initial_guess=1, fail_bound=1e38): # {{{ argmin, argmax -def argmin2(iterable, return_value=False): +def argmin2(iterable: Iterable[Sequence[Any]], return_value: bool = False) \ + -> Union[Any, Tuple[int, Any]]: it = iter(iterable) try: current_argmin, current_min = next(it) @@ -1248,7 +1207,8 @@ def argmin2(iterable, return_value=False): return current_argmin -def argmax2(iterable, return_value=False): +def argmax2(iterable: Iterable[Sequence[Any]], return_value: bool = False) \ + -> Union[Any, Tuple[int, Any]]: it = iter(iterable) try: current_argmax, current_max = next(it) @@ -1266,11 +1226,11 @@ def argmax2(iterable, return_value=False): return current_argmax -def argmin(iterable): +def argmin(iterable: Iterable[Any]) -> Any: return argmin2(enumerate(iterable)) -def argmax(iterable): +def argmax(iterable: Iterable[Any]) -> Any: return argmax2(enumerate(iterable)) # }}} @@ -1278,7 +1238,8 @@ def argmax(iterable): # {{{ cartesian products etc. -def cartesian_product(*args): +def cartesian_product(*args: Iterable[Any]) \ + -> Generator[Tuple[Any, ...], None, None]: if len(args) == 1: for arg in args[0]: yield (arg,) @@ -1289,14 +1250,16 @@ def cartesian_product(*args): yield prod + (i,) -def distinct_pairs(list1, list2): +def distinct_pairs(list1: Sequence[Any], list2: Sequence[Any]) \ + -> Generator[Tuple[Any, Any], None, None]: for i, xi in enumerate(list1): for j, yj in enumerate(list2): if i != j: yield (xi, yj) -def cartesian_product_sum(list1, list2): +def cartesian_product_sum(list1: List[int], list2: List[int]) \ + -> Generator[int, None, None]: """This routine returns a list of sums of each element of list1 with each element of list2. Also works with lists. """ @@ -1309,7 +1272,7 @@ def cartesian_product_sum(list1, list2): # {{{ elementary statistics -def average(iterable): +def average(iterable: Iterable[float]) -> float: """Return the average of the values in iterable. iterable may not be empty. @@ -1334,20 +1297,20 @@ class VarianceAggregator: See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Adheres to pysqlite's aggregate interface. """ - def __init__(self, entire_pop): + def __init__(self, entire_pop: int) -> None: self.n = 0 - self.mean = 0 - self.m2 = 0 + self.mean = 0.0 + self.m2 = 0.0 self.entire_pop = entire_pop - def step(self, x): + def step(self, x: float) -> None: self.n += 1 delta_ = x - self.mean self.mean += delta_/self.n self.m2 += delta_*(x - self.mean) - def finalize(self): + def finalize(self) -> Optional[float]: if self.entire_pop: if self.n == 0: return None @@ -1360,7 +1323,7 @@ def finalize(self): return self.m2/(self.n - 1) -def variance(iterable, entire_pop): +def variance(iterable: Iterable[float], entire_pop: int) -> Optional[float]: v_comp = VarianceAggregator(entire_pop) for x in iterable: @@ -1369,21 +1332,22 @@ def variance(iterable, entire_pop): return v_comp.finalize() -def std_deviation(iterable, finite_pop): +def std_deviation(iterable: Iterable[float], finite_pop: int) -> Optional[float]: from math import sqrt - return sqrt(variance(iterable, finite_pop)) + return sqrt(cast(float, variance(iterable, finite_pop))) # }}} # {{{ permutations, tuples, integer sequences -def wandering_element(length, wanderer=1, landscape=0): +def wandering_element(length: int, wanderer: float = 1, landscape: float = 0) \ + -> Generator[Tuple[float, ...], None, None]: for i in range(length): yield i*(landscape,) + (wanderer,) + (length-1-i)*(landscape,) -def indices_in_shape(shape): +def indices_in_shape(shape: Sequence[int]) -> Generator[Tuple[int, ...], None, None]: from warnings import warn warn("indices_in_shape is deprecated. You should prefer numpy.ndindex.", DeprecationWarning, stacklevel=2) @@ -1403,17 +1367,22 @@ def indices_in_shape(shape): yield (i,)+rest -def generate_nonnegative_integer_tuples_below(n, length=None, least=0): +def generate_nonnegative_integer_tuples_below(n: Union[int, Sequence[int]], + length: Optional[int] = None, + least: int = 0) \ + -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: """n may be a sequence, in which case length must be None.""" if length is None: if not n: yield () return + n = cast(Sequence[int], n) my_n = n[0] n = n[1:] next_length = None else: + n = cast(int, n) my_n = n assert length >= 0 @@ -1426,14 +1395,17 @@ def generate_nonnegative_integer_tuples_below(n, length=None, least=0): for i in range(least, my_n): my_part = (i,) for base in generate_nonnegative_integer_tuples_below(n, next_length, least): + base = cast(Tuple[Any], base) yield my_part + base def generate_decreasing_nonnegative_tuples_summing_to( - n, length, min_value=0, max_value=None): + n: int, length: int, min_value: int = 0, max_value: Optional[int] = None) \ + -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: if length == 0: yield () elif length == 1: + max_value = cast(int, max_value) if n <= max_value: #print "MX", n, max_value yield (n,) @@ -1450,7 +1422,8 @@ def generate_decreasing_nonnegative_tuples_summing_to( yield (i,) + remainder -def generate_nonnegative_integer_tuples_summing_to_at_most(n, length): +def generate_nonnegative_integer_tuples_summing_to_at_most(n: int, length: int) \ + -> Generator[Union[Tuple[int, ...], Tuple[()]], None, None]: """Enumerate all non-negative integer tuples summing to at most n, exhausting the search space by varying the first entry fastest, and the last entry the slowest. @@ -1469,7 +1442,8 @@ def generate_nonnegative_integer_tuples_summing_to_at_most(n, length): generate_positive_integer_tuples_below = generate_nonnegative_integer_tuples_below -def _pos_and_neg_adaptor(tuple_iter): +def _pos_and_neg_adaptor(tuple_iter: Iterable[Tuple[float, ...]]) \ + -> Generator[Tuple[float, ...], None, None]: for tup in tuple_iter: nonzero_indices = [i for i in range(len(tup)) if tup[i] != 0] for do_neg_tup in generate_nonnegative_integer_tuples_below( @@ -1481,12 +1455,13 @@ def _pos_and_neg_adaptor(tuple_iter): yield tuple(this_result) -def generate_all_integer_tuples_below(n, length, least_abs=0): +def generate_all_integer_tuples_below(n: int, length: int, least_abs: int = 0) \ + -> Generator[Tuple[float, ...], None, None]: return _pos_and_neg_adaptor(generate_nonnegative_integer_tuples_below( n, length, least_abs)) -def generate_permutations(original): +def generate_permutations(original: List[Any]) -> Generator[List[Any], None, None]: """Generate all permutations of the list *original*. Nicked from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/252178 @@ -1500,7 +1475,8 @@ def generate_permutations(original): yield perm_[:i] + original[0:1] + perm_[i:] -def generate_unique_permutations(original): +def generate_unique_permutations(original: List[Any]) \ + -> Generator[List[Any], None, None]: """Generate all unique permutations of the list *original*. """ @@ -1512,16 +1488,17 @@ def generate_unique_permutations(original): yield perm_ -def enumerate_basic_directions(dimensions): +def enumerate_basic_directions(dimensions: int) -> List[List[int]]: coordinate_list = [[0], [1], [-1]] - return reduce(cartesian_product_sum, [coordinate_list] * dimensions)[1:] + return reduce(cartesian_product_sum, [coordinate_list] * dimensions)[1:] # type: ignore[arg-type] # noqa: E501 # }}} # {{{ index mangling -def get_read_from_map_from_permutation(original, permuted): +def get_read_from_map_from_permutation(original: List[int], permuted: List[int]) \ + -> Tuple[int, ...]: """With a permutation given by *original* and *permuted*, generate a list *rfm* of indices such that ``permuted[i] == original[rfm[i]]``. @@ -1548,7 +1525,8 @@ def get_read_from_map_from_permutation(original, permuted): return tuple(where_in_original[pi] for pi in permuted) -def get_write_to_map_from_permutation(original, permuted): +def get_write_to_map_from_permutation(original: List[int], permuted: List[int]) \ + -> Tuple[int, ...]: """With a permutation given by *original* and *permuted*, generate a list *wtm* of indices such that ``permuted[wtm[i]] == original[i]``. @@ -1652,7 +1630,7 @@ def _get_alignments(self) -> Tuple[str, ...]: + (self.alignments[-1],) * (self.ncolumns - len(self.alignments)) ) - def _get_column_widths(self, rows) -> Tuple[int, ...]: + def _get_column_widths(self, rows: List[Tuple[Any, ...]]) -> Tuple[int, ...]: return tuple([ max(len(row[i]) for row in rows) for i in range(self.ncolumns) ]) @@ -1818,7 +1796,7 @@ def merge_tables(*tables: Table, if isinstance(skip_columns, int): skip_columns = (skip_columns,) - def remove_columns(i, row): + def remove_columns(i: int, row: Tuple[Any, ...]) -> Tuple[Any, ...]: if i == 0 or skip_columns is None: return row else: @@ -1833,7 +1811,7 @@ def remove_columns(i, row): result = Table(alignments=alignments) for i in range(tables[0].nrows): - row = [] + row: List[Any] = [] for j, tbl in enumerate(tables): row.extend(remove_columns(j, tbl.rows[i])) @@ -1847,8 +1825,10 @@ def remove_columns(i, row): # {{{ histogram formatting def string_histogram( # pylint: disable=too-many-arguments,too-many-locals - iterable, min_value=None, max_value=None, - bin_count=20, width=70, bin_starts=None, use_unicode=True): + iterable: Iterable[float], min_value: Optional[float] = None, + max_value: Optional[float] = None, bin_count: int = 20, + width: int = 70, bin_starts: Optional[Sequence[float]] = None, + use_unicode: bool = True) -> str: if bin_starts is None: if min_value is None or max_value is None: iterable = list(iterable) @@ -1875,7 +1855,7 @@ def string_histogram( # pylint: disable=too-many-arguments,too-many-locals from math import ceil, floor if use_unicode: - def format_bar(cnt): + def format_bar(cnt: int) -> str: scaled = cnt*width/max_count full = int(floor(scaled)) eighths = int(ceil((scaled-full)*8)) @@ -1884,7 +1864,7 @@ def format_bar(cnt): else: return full*chr(0x2588) else: - def format_bar(cnt): + def format_bar(cnt: int) -> str: return int(ceil(cnt*width/max_count))*"#" max_count = max(bins) @@ -1899,7 +1879,7 @@ def format_bar(cnt): # }}} -def word_wrap(text, width, wrap_using="\n"): +def word_wrap(text: str, width: int, wrap_using: str = "\n") -> str: # http://code.activestate.com/recipes/148061-one-liner-word-wrap-function/ r""" A word-wrap function that preserves existing line breaks @@ -1922,10 +1902,10 @@ def word_wrap(text, width, wrap_using="\n"): # {{{ command line interfaces -def _exec_arg(arg, execenv): +def _exec_arg(arg: str, execenv: Dict[str, Any]) -> None: import os if os.access(arg, os.F_OK): - exec(compile(open(arg), arg, "exec"), execenv) + exec(compile(open(arg).read(), arg, "exec"), execenv) else: exec(compile(arg, "", "exec"), execenv) @@ -1934,7 +1914,9 @@ class CPyUserInterface: class Parameters(Record): pass - def __init__(self, variables, constants=None, doc=None): + def __init__(self, variables: Dict[str, Any], + constants: Optional[Dict[str, Any]] = None, + doc: Optional[Dict[str, Any]] = None) -> None: if constants is None: constants = {} if doc is None: @@ -1943,7 +1925,7 @@ def __init__(self, variables, constants=None, doc=None): self.constants = constants self.doc = doc - def show_usage(self, progname): + def show_usage(self, progname: str) -> None: print(f"usage: {progname} ") print() print("FILE-OR-STATEMENTS may either be Python statements of the form") @@ -1965,7 +1947,7 @@ def show_usage(self, progname): if c in self.doc: print(f" {self.doc[c]}") - def gather(self, argv=None): + def gather(self, argv: Optional[Sequence[Any]] = None) -> Parameters: if argv is None: argv = sys.argv @@ -1997,7 +1979,7 @@ def gather(self, argv=None): self.validate(result) return result - def validate(self, setup): + def validate(self, setup: Any) -> None: pass # }}} @@ -2006,18 +1988,19 @@ def validate(self, setup): # {{{ debugging class StderrToStdout: - def __enter__(self): + def __enter__(self) -> None: # pylint: disable=attribute-defined-outside-init self.stderr_backup = sys.stderr sys.stderr = sys.stdout - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: sys.stderr = self.stderr_backup del self.stderr_backup def typedump(val: Any, max_seq: int = 5, - special_handlers: Optional[Mapping[Type, Callable]] = None, + special_handlers: + Optional[Mapping[Type[Any], Callable[[Any], str]]] = None, fully_qualified_name: bool = True) -> str: """ Return a string representation of the type of *val*, recursing into @@ -2085,7 +2068,8 @@ def objname(obj: Any) -> str: return objname(val) -def invoke_editor(s, filename="edit.txt", descr="the file"): +def invoke_editor(s: str, filename: str = "edit.txt", descr: str = "the file") \ + -> str: from tempfile import mkdtemp tempdir = mkdtemp() @@ -2126,7 +2110,8 @@ class ProgressBar: # pylint: disable=too-many-instance-attributes .. automethod:: __enter__ .. automethod:: __exit__ """ - def __init__(self, descr, total, initial=0, length=40): + def __init__(self, descr: str, total: float, initial: float = 0, + length: float = 40) -> None: import time self.description = descr self.total = total @@ -2139,9 +2124,9 @@ def __init__(self, descr, total, initial=0, length=40): self.speed_meas_start_time = self.start_time self.speed_meas_start_done = initial - self.time_per_step = None + self.time_per_step: Optional[float] = None - def draw(self): + def draw(self) -> None: import time now = time.time() @@ -2168,27 +2153,27 @@ def draw(self): sys.stderr.write("{:<20} [{}] ETA {}\r".format( self.description, - squares*"#"+(self.length-squares)*" ", + squares*"#"+int(self.length-squares)*" ", eta_str)) self.last_squares = squares self.last_update_time = now - def progress(self, steps=1): + def progress(self, steps: int = 1) -> None: self.set_progress(self.done + steps) - def set_progress(self, done): + def set_progress(self, done: float) -> None: self.done = done self.draw() - def finished(self): + def finished(self) -> None: self.set_progress(self.total) sys.stderr.write("\n") - def __enter__(self): + def __enter__(self) -> None: self.draw() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.finished() # }}} @@ -2196,13 +2181,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): # {{{ file system related -def assert_not_a_file(name): +def assert_not_a_file(name: str) -> None: import os if os.access(name, os.F_OK): raise OSError(f"file `{name}' already exists") -def add_python_path_relative_to_script(rel_path): +def add_python_path_relative_to_script(rel_path: str) -> None: from os.path import abspath, dirname, join script_name = sys.argv[0] @@ -2215,10 +2200,13 @@ def add_python_path_relative_to_script(rel_path): # {{{ numpy dtype mangling -def common_dtype(dtypes, default=None): +def common_dtype(dtypes: Sequence["np.dtype[Any]"], + default: Optional["np.dtype[Any]"] = None) -> "np.dtype[Any]": + import numpy as np + dtypes = list(dtypes) if dtypes: - return argmax2((dtype, dtype.num) for dtype in dtypes) + return cast(np.dtype[Any], argmax2((dtype, dtype.num) for dtype in dtypes)) else: if default is not None: return default @@ -2227,12 +2215,13 @@ def common_dtype(dtypes, default=None): "cannot find common dtype of empty dtype list") -def to_uncomplex_dtype(dtype): +def to_uncomplex_dtype(dtype: "np.dtype[Any]") -> "Type[Any]": import numpy as np return np.array(1, dtype=dtype).real.dtype.type -def match_precision(dtype, dtype_to_match): +def match_precision(dtype: "np.dtype[Any]", dtype_to_match: "np.dtype[Any]") \ + -> "np.dtype[Any]": import numpy tgt_is_double = dtype_to_match in [ @@ -2255,7 +2244,7 @@ def match_precision(dtype, dtype_to_match): # {{{ unique name generation -def generate_unique_names(prefix): +def generate_unique_names(prefix: str) -> Generator[str, None, None]: yield prefix try_num = 0 @@ -2387,19 +2376,19 @@ def __call__(self, based_on: str = "id") -> str: # {{{ recursion limit class MinRecursionLimit: - def __init__(self, min_rec_limit): + def __init__(self, min_rec_limit: int) -> None: self.min_rec_limit = min_rec_limit - def __enter__(self): + def __enter__(self) -> None: # pylint: disable=attribute-defined-outside-init self.prev_recursion_limit = sys.getrecursionlimit() new_limit = max(self.prev_recursion_limit, self.min_rec_limit) sys.setrecursionlimit(new_limit) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # Deep recursion can produce deeply nested data structures - # (or long chains of to-be gc'd generators) that cannot be + # (or long chains of to-be gc'd generators) that cannot # undergo garbage collection with a lower recursion limit. # # As a result, it doesn't seem possible to lower the recursion limit @@ -2415,7 +2404,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): # {{{ download from web if not present -def download_from_web_if_not_present(url, local_name=None): +def download_from_web_if_not_present(url: str, local_name: Optional[str] = None) \ + -> None: """ .. versionadded:: 2017.5 """ @@ -2443,7 +2433,8 @@ def download_from_web_if_not_present(url, local_name=None): # {{{ find git revisions -def find_git_revision(tree_root): # pylint: disable=too-many-locals +def find_git_revision(tree_root: str) \ + -> Optional[bytes]: # pylint: disable=too-many-locals # Keep this routine self-contained so that it can be copy-pasted into # setup.py. @@ -2473,7 +2464,7 @@ def find_git_revision(tree_root): # pylint: disable=too-many-locals cwd=tree_root, env=env) (git_rev, _) = p.communicate() - git_rev = git_rev.decode() + git_rev = git_rev.decode() # type: ignore[assignment] git_rev = git_rev.rstrip() @@ -2487,7 +2478,7 @@ def find_git_revision(tree_root): # pylint: disable=too-many-locals return git_rev -def find_module_git_revision(module_file, n_levels_up): +def find_module_git_revision(module_file: str, n_levels_up: int) -> Optional[bytes]: from os.path import dirname, join tree_root = join(*([dirname(module_file)] + [".." * n_levels_up])) @@ -2498,7 +2489,8 @@ def find_module_git_revision(module_file, n_levels_up): # {{{ create a reshaped view of a numpy array -def reshaped_view(a, newshape): +def reshaped_view(a: "np.ndarray[Any, Any]", newshape: Tuple[int, ...]) \ + -> "np.ndarray[Any, Any]": """ Create a new view object with shape ``newshape`` without copying the data of ``a``. This function is different from ``numpy.reshape`` by raising an exception when data copy is necessary. @@ -2536,30 +2528,33 @@ class ProcessTimer: .. versionadded:: 2018.5 """ - def __init__(self): + def __init__(self) -> None: import time self.perf_counter_start = time.perf_counter() self.process_time_start = time.process_time() - self.wall_elapsed = None - self.process_elapsed = None + self.wall_elapsed: Optional[float] = None + self.process_elapsed: Optional[float] = None - def __enter__(self): + def __enter__(self) -> "ProcessTimer": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.done() - def done(self): + def done(self) -> None: import time self.wall_elapsed = time.perf_counter() - self.perf_counter_start self.process_elapsed = time.process_time() - self.process_time_start - def __str__(self): - cpu = self.process_elapsed / self.wall_elapsed + def __str__(self) -> str: + if self.process_elapsed and self.wall_elapsed: + cpu = (self.process_elapsed / self.wall_elapsed) + else: + cpu = 0.0 return f"{self.wall_elapsed:.2f}s wall {cpu:.2f}x CPU" - def __repr__(self): + def __repr__(self) -> str: wall = self.wall_elapsed process = self.process_elapsed @@ -2596,8 +2591,9 @@ class ProcessLogger: # pylint: disable=too-many-instance-attributes default_noisy_level = logging.INFO def __init__( # pylint: disable=too-many-arguments - self, logger, description, - silent_level=None, noisy_level=None, long_threshold_seconds=None): + self, logger: logging.Logger, description: str, + silent_level: Optional[int] = None, noisy_level: Optional[int] = + None, long_threshold_seconds: Optional[float] = None) -> None: self.logger = logger self.description = description self.silent_level = silent_level or logging.DEBUG @@ -2655,13 +2651,14 @@ def __init__( # pylint: disable=too-many-arguments self.timer = ProcessTimer() def done( # pylint: disable=keyword-arg-before-vararg - self, extra_msg=None, *extra_fmt_args): + self, extra_msg: Optional[str] = None, *extra_fmt_args: Any) -> None: self.timer.done() self._done_indicator[0] = True completion_level = ( self.noisy_level - if self.timer.wall_elapsed > self.long_threshold_seconds + if (self.timer.wall_elapsed is not None + and self.timer.wall_elapsed > self.long_threshold_seconds) else self.silent_level) msg = "%s: completed (%s)" @@ -2673,10 +2670,10 @@ def done( # pylint: disable=keyword-arg-before-vararg self.logger.log(completion_level, msg, *fmt_args) - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.done() @@ -2692,13 +2689,14 @@ class log_process: # noqa: N801 .. automethod:: __call__ """ - def __init__(self, logger, description=None, long_threshold_seconds=None): + def __init__(self, logger: logging.Logger, description: Optional[str] = None, + long_threshold_seconds: Optional[float] = None) -> None: self.logger = logger self.description = description self.long_threshold_seconds = long_threshold_seconds - def __call__(self, wrapped): - def wrapper(*args, **kwargs): + def __call__(self, wrapped: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: with ProcessLogger( self.logger, self.description or wrapped.__name__, @@ -2715,7 +2713,7 @@ def wrapper(*args, **kwargs): # {{{ sorting in natural order -def natorder(item): +def natorder(item: Any) -> List[Any]: """Return a key for natural order string comparison. See :func:`natsorted`. @@ -2737,7 +2735,8 @@ def natorder(item): return result -def natsorted(iterable, key=None, reverse=False): +def natsorted(iterable: Iterable[Any], key: Optional[Callable[[Any], Any]] = None, + reverse: bool = False) -> List[Any]: """Sort using natural order [1]_, as opposed to lexicographic order. Example:: @@ -2761,7 +2760,9 @@ def natsorted(iterable, key=None, reverse=False): """ if key is None: key = lambda x: x - return sorted(iterable, key=lambda y: natorder(key(y)), reverse=reverse) + + # type-ignore-reason: mypy thinks key could be None + return sorted(iterable, key=lambda y: natorder(key(y)), reverse=reverse) # type: ignore[misc] # noqa: E501 # }}} @@ -2775,7 +2776,7 @@ def natsorted(iterable, key=None, reverse=False): del _DOTTED_WORDS -def resolve_name(name): +def resolve_name(name: str) -> object: """A backport of :func:`pkgutil.resolve_name` (added in Python 3.9). .. versionadded:: 2021.1.2 @@ -2824,7 +2825,8 @@ def resolve_name(name): # {{{ unordered_hash -def unordered_hash(hash_instance, iterable, hash_constructor=None): +def unordered_hash(hash_instance: Any, iterable: Iterable[Any], + hash_constructor: Optional[Callable[[], Any]] = None) -> Any: """Using a hash algorithm given by the parameter-less constructor *hash_constructor*, return a hash object whose internal state depends on the entries of *iterable*, but not their order. If *hash* @@ -2869,7 +2871,8 @@ def unordered_hash(hash_instance, iterable, hash_constructor=None): # {{{ sphere_sample -def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0): +def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0) \ + -> "np.ndarray[Any, Any]": """Generate points regularly distributed on a sphere based on https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf. @@ -2878,7 +2881,7 @@ def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0): """ import numpy as np - points: List[np.ndarray] = [] + points: List[np.ndarray[Any, Any]] = [] count = 0 a = 4 * np.pi / npoints_approx @@ -2918,7 +2921,7 @@ def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0): def sphere_sample_fibonacci( npoints: int, r: float = 1.0, *, - optimize: Optional[str] = None): + optimize: Optional[str] = None) -> "np.ndarray[Any, Any]": """Generate points on a sphere based on an offset Fibonacci lattice from [2]_. .. [2] http://extremelearning.com.au/how-to-evenly-distribute-points-on-a-sphere-more-effectively-than-the-canonical-fibonacci-lattice/ @@ -2994,7 +2997,7 @@ def strtobool(val: Optional[str], default: Optional[bool] = None) -> bool: # }}} -def _test(): +def _test() -> None: import doctest doctest.testmod() diff --git a/pytools/datatable.py b/pytools/datatable.py index 4fcb03e6..2f36e220 100644 --- a/pytools/datatable.py +++ b/pytools/datatable.py @@ -11,8 +11,7 @@ """ -# type-ignore-reason: Record is untyped -class Row(Record): # type: ignore[misc] +class Row(Record): pass diff --git a/pytools/graph.py b/pytools/graph.py index 054c5444..7ac43be8 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -255,8 +255,9 @@ def __init__(self, node: NodeT, key: Any) -> None: self.node = node self.key = key - def __lt__(self, other: "HeapEntry") -> bool: - return self.key < other.key + def __lt__(self, other: Any) -> bool: + from typing import cast + return cast(bool, self.key < other.key) def compute_topological_order(graph: GraphT[NodeT], diff --git a/pytools/mpi.py b/pytools/mpi.py index f74c1307..392ee30d 100644 --- a/pytools/mpi.py +++ b/pytools/mpi.py @@ -33,10 +33,10 @@ """ from contextlib import AbstractContextManager, contextmanager -from typing import Generator, Tuple, Type, Union +from typing import Any, Callable, Generator, Sequence, Tuple, Type, Union -def check_for_mpi_relaunch(argv): +def check_for_mpi_relaunch(argv: Sequence[Any]) -> None: if argv[1] != "--mpi-relaunch": return @@ -48,7 +48,9 @@ def check_for_mpi_relaunch(argv): sys.exit() -def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None): +def run_with_mpi_ranks(py_script: str, ranks: int, + callable_: Callable[[Any], Any], args: Any = (), + kwargs: Any = None) -> None: if kwargs is None: kwargs = {} @@ -70,7 +72,7 @@ def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None): def pytest_raises_on_rank(my_rank: int, fail_rank: int, expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]]) \ - -> Generator[AbstractContextManager, None, None]: + -> Generator[Any, None, None]: """ Like :func:`pytest.raises`, but only expect an exception on rank *fail_rank*. """ @@ -79,7 +81,7 @@ def pytest_raises_on_rank(my_rank: int, fail_rank: int, import pytest if my_rank == fail_rank: - cm: AbstractContextManager = pytest.raises(expected_exception) + cm: AbstractContextManager[Any] = pytest.raises(expected_exception) else: cm = nullcontext() diff --git a/run-mypy.sh b/run-mypy.sh index 39055a8c..f3e96420 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -4,4 +4,4 @@ set -ex mypy --show-error-codes pytools -mypy --strict --follow-imports=skip pytools/datatable.py +mypy --strict pytools/datatable.py pytools/graph.py pytools/mpi.py pytools/__init__.py diff --git a/setup.py b/setup.py index 852fd244..1df4e668 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ from setuptools import setup + ver_dic = {} version_file = open("pytools/version.py") try: