Skip to content

Commit 3387c0d

Browse files
authored
Merge pull request microsoft#7669 from eendebakpt/is_function
Improve performance of is_function
2 parents 5923404 + f04096f commit 3387c0d

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

src/qcodes/parameters/command.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def __init__(
6868
):
6969
self.arg_count = arg_count
7070

71-
if no_cmd_function is not None and not is_function(no_cmd_function, arg_count):
71+
if no_cmd_function is not None and not is_function(
72+
no_cmd_function, arg_count, coroutine=None
73+
):
7274
raise TypeError(
7375
f"no_cmd_function must be None or a function "
7476
f"taking the same args as the command, not "

src/qcodes/utils/function_helpers.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from inspect import iscoroutinefunction, signature
1+
from inspect import CO_VARARGS, iscoroutinefunction, signature
22

33

4-
def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool:
4+
def is_function(f: object, arg_count: int, coroutine: bool | None = False) -> bool:
55
"""
66
Check and require a function that can accept the specified number of
77
positional arguments, which either is or is not a coroutine
@@ -19,15 +19,38 @@ def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool:
1919
if not isinstance(arg_count, int) or arg_count < 0:
2020
raise TypeError("arg_count must be a non-negative integer")
2121

22-
if not (callable(f) and bool(coroutine) is iscoroutinefunction(f)):
22+
if not callable(f):
2323
return False
24+
if coroutine is not None:
25+
if bool(coroutine) is not iscoroutinefunction(f):
26+
return False
2427

2528
if isinstance(f, type):
2629
# for type casting functions, eg int, str, float
2730
# only support the one-parameter form of these,
2831
# otherwise the user should make an explicit function.
2932
return arg_count == 1
3033

34+
if func_code := getattr(f, "__code__", None):
35+
# handle objects like functools.partial(f, ...)
36+
func_defaults = getattr(f, "__defaults__", None)
37+
number_of_defaults = len(func_defaults) if func_defaults is not None else 0
38+
39+
if getattr(f, "__self__", None) is not None:
40+
# bound method
41+
min_positional = func_code.co_argcount - 1 - number_of_defaults
42+
max_positional = func_code.co_argcount - 1
43+
else:
44+
min_positional = func_code.co_argcount - number_of_defaults
45+
max_positional = func_code.co_argcount
46+
47+
if func_code.co_flags & CO_VARARGS:
48+
# we have *args
49+
max_positional = 10e10
50+
51+
ev = min_positional <= arg_count <= max_positional
52+
return ev
53+
3154
try:
3255
sig = signature(f)
3356
except ValueError:

tests/utils/test_isfunction.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
from typing import NoReturn
23

34
import pytest
@@ -25,6 +26,10 @@ def f2(a: object, b: object) -> NoReturn:
2526
assert is_function(f1, 1)
2627
assert is_function(f2, 2)
2728

29+
assert is_function(f0, 0, coroutine=False)
30+
assert is_function(f1, 1, coroutine=False)
31+
assert is_function(f2, 2, coroutine=False)
32+
2833
assert not (is_function(f0, 1) or is_function(f0, 2))
2934
assert not (is_function(f1, 0) or is_function(f1, 2))
3035
assert not (is_function(f2, 0) or is_function(f2, 1))
@@ -36,6 +41,32 @@ def f2(a: object, b: object) -> NoReturn:
3641
is_function(f0, -1)
3742

3843

44+
def test_function_partial() -> None:
45+
def f0(one_arg: int) -> int:
46+
return one_arg
47+
48+
f = partial(f0, 1)
49+
assert is_function(f, 0)
50+
assert not is_function(f, 1)
51+
52+
53+
def test_function_varargs() -> None:
54+
def f(*args) -> None:
55+
return None
56+
57+
assert is_function(f, 0)
58+
assert is_function(f, 1)
59+
assert is_function(f, 100)
60+
61+
def g(a, b=1, *args) -> None:
62+
return None
63+
64+
assert not is_function(g, 0)
65+
assert is_function(g, 1)
66+
assert is_function(g, 2)
67+
assert is_function(g, 100)
68+
69+
3970
class AClass:
4071
def method_a(self) -> NoReturn:
4172
raise RuntimeError("function should not get called")
@@ -78,3 +109,4 @@ async def f_async() -> NoReturn:
78109
assert not is_function(f_async, 0, coroutine=False)
79110
assert is_function(f_async, 0, coroutine=True)
80111
assert not is_function(f_async, 0)
112+
assert is_function(f_async, 0, coroutine=None)

0 commit comments

Comments
 (0)