Skip to content
Open
2 changes: 1 addition & 1 deletion src/qcodes/parameters/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
):
self.arg_count = arg_count

if no_cmd_function is not None and not is_function(no_cmd_function, arg_count):
if no_cmd_function is not None and not is_function(no_cmd_function, arg_count, coroutine = None):
raise TypeError(
f"no_cmd_function must be None or a function "
f"taking the same args as the command, not "
Expand Down
29 changes: 26 additions & 3 deletions src/qcodes/utils/function_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from inspect import iscoroutinefunction, signature
from inspect import CO_VARARGS, iscoroutinefunction, signature


def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool:
def is_function(f: object, arg_count: int, coroutine: bool | None = False) -> bool:
"""
Check and require a function that can accept the specified number of
positional arguments, which either is or is not a coroutine
Expand All @@ -19,15 +19,38 @@ def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool:
if not isinstance(arg_count, int) or arg_count < 0:
raise TypeError("arg_count must be a non-negative integer")

if not (callable(f) and bool(coroutine) is iscoroutinefunction(f)):
if not callable(f):
return False
if coroutine is not None:
if bool(coroutine) is not iscoroutinefunction(f):
return False

if isinstance(f, type):
# for type casting functions, eg int, str, float
# only support the one-parameter form of these,
# otherwise the user should make an explicit function.
return arg_count == 1

if func_code := getattr(f, "__code__", None):
# handle objects like functools.partial(f, ...)
func_defaults = getattr(f, "__defaults__", None)
number_of_defaults = len(func_defaults) if func_defaults is not None else 0

if getattr(f, "__self__", None) is not None:
# bound method
min_positional = func_code.co_argcount - 1 - number_of_defaults
max_positional = func_code.co_argcount - 1
else:
min_positional = func_code.co_argcount - number_of_defaults
max_positional = func_code.co_argcount

if func_code.co_flags & CO_VARARGS:
# we have *args
max_positional = 10e10

ev = min_positional <= arg_count <= max_positional
return ev

try:
sig = signature(f)
except ValueError:
Expand Down
32 changes: 32 additions & 0 deletions tests/utils/test_isfunction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import NoReturn

import pytest
Expand Down Expand Up @@ -25,6 +26,10 @@ def f2(a: object, b: object) -> NoReturn:
assert is_function(f1, 1)
assert is_function(f2, 2)

assert is_function(f0, 0, coroutine = False)
assert is_function(f1, 1, coroutine = False)
assert is_function(f2, 2, coroutine = False)

assert not (is_function(f0, 1) or is_function(f0, 2))
assert not (is_function(f1, 0) or is_function(f1, 2))
assert not (is_function(f2, 0) or is_function(f2, 1))
Expand All @@ -36,6 +41,32 @@ def f2(a: object, b: object) -> NoReturn:
is_function(f0, -1)


def test_function_partial() -> None:
def f0(one_arg: int) -> int:
return one_arg

f = partial(f0, 1)
assert is_function(f, 0)
assert not is_function(f, 1)


def test_function_varargs() -> None:
def f(*args) -> None:
return None

assert is_function(f, 0)
assert is_function(f, 1)
assert is_function(f, 100)

def g(a, b=1, *args) -> None:
return None

assert not is_function(g, 0)
assert is_function(g, 1)
assert is_function(g, 2)
assert is_function(g, 100)


class AClass:
def method_a(self) -> NoReturn:
raise RuntimeError("function should not get called")
Expand Down Expand Up @@ -78,3 +109,4 @@ async def f_async() -> NoReturn:
assert not is_function(f_async, 0, coroutine=False)
assert is_function(f_async, 0, coroutine=True)
assert not is_function(f_async, 0)
assert is_function(f_async, 0, coroutine=None)