diff --git a/src/qcodes/parameters/command.py b/src/qcodes/parameters/command.py index 18815f6beb41..cfae86cd7cff 100644 --- a/src/qcodes/parameters/command.py +++ b/src/qcodes/parameters/command.py @@ -68,7 +68,7 @@ def __init__( ): self.arg_count = arg_count - if no_cmd_function is not None and not is_function(no_cmd_function, arg_count): + if no_cmd_function is not None and not is_function(no_cmd_function, arg_count, coroutine = None): raise TypeError( f"no_cmd_function must be None or a function " f"taking the same args as the command, not " diff --git a/src/qcodes/utils/function_helpers.py b/src/qcodes/utils/function_helpers.py index d4b4d7a7a6ad..67e6f75a37d6 100644 --- a/src/qcodes/utils/function_helpers.py +++ b/src/qcodes/utils/function_helpers.py @@ -1,7 +1,7 @@ -from inspect import iscoroutinefunction, signature +from inspect import CO_VARARGS, iscoroutinefunction, signature -def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool: +def is_function(f: object, arg_count: int, coroutine: bool | None = False) -> bool: """ Check and require a function that can accept the specified number of positional arguments, which either is or is not a coroutine @@ -19,8 +19,11 @@ def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool: if not isinstance(arg_count, int) or arg_count < 0: raise TypeError("arg_count must be a non-negative integer") - if not (callable(f) and bool(coroutine) is iscoroutinefunction(f)): + if not callable(f): return False + if coroutine is not None: + if bool(coroutine) is not iscoroutinefunction(f): + return False if isinstance(f, type): # for type casting functions, eg int, str, float @@ -28,6 +31,26 @@ def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool: # otherwise the user should make an explicit function. return arg_count == 1 + if func_code := getattr(f, "__code__", None): + # handle objects like functools.partial(f, ...) + func_defaults = getattr(f, "__defaults__", None) + number_of_defaults = len(func_defaults) if func_defaults is not None else 0 + + if getattr(f, "__self__", None) is not None: + # bound method + min_positional = func_code.co_argcount - 1 - number_of_defaults + max_positional = func_code.co_argcount - 1 + else: + min_positional = func_code.co_argcount - number_of_defaults + max_positional = func_code.co_argcount + + if func_code.co_flags & CO_VARARGS: + # we have *args + max_positional = 10e10 + + ev = min_positional <= arg_count <= max_positional + return ev + try: sig = signature(f) except ValueError: diff --git a/tests/utils/test_isfunction.py b/tests/utils/test_isfunction.py index fc2a3dc368be..6692d6675f0c 100644 --- a/tests/utils/test_isfunction.py +++ b/tests/utils/test_isfunction.py @@ -1,3 +1,4 @@ +from functools import partial from typing import NoReturn import pytest @@ -25,6 +26,10 @@ def f2(a: object, b: object) -> NoReturn: assert is_function(f1, 1) assert is_function(f2, 2) + assert is_function(f0, 0, coroutine = False) + assert is_function(f1, 1, coroutine = False) + assert is_function(f2, 2, coroutine = False) + assert not (is_function(f0, 1) or is_function(f0, 2)) assert not (is_function(f1, 0) or is_function(f1, 2)) assert not (is_function(f2, 0) or is_function(f2, 1)) @@ -36,6 +41,32 @@ def f2(a: object, b: object) -> NoReturn: is_function(f0, -1) +def test_function_partial() -> None: + def f0(one_arg: int) -> int: + return one_arg + + f = partial(f0, 1) + assert is_function(f, 0) + assert not is_function(f, 1) + + +def test_function_varargs() -> None: + def f(*args) -> None: + return None + + assert is_function(f, 0) + assert is_function(f, 1) + assert is_function(f, 100) + + def g(a, b=1, *args) -> None: + return None + + assert not is_function(g, 0) + assert is_function(g, 1) + assert is_function(g, 2) + assert is_function(g, 100) + + class AClass: def method_a(self) -> NoReturn: raise RuntimeError("function should not get called") @@ -78,3 +109,4 @@ async def f_async() -> NoReturn: assert not is_function(f_async, 0, coroutine=False) assert is_function(f_async, 0, coroutine=True) assert not is_function(f_async, 0) + assert is_function(f_async, 0, coroutine=None)