Skip to content

Commit 3b8030a

Browse files
Merge pull request #80 from tensorcircuit/perf-fix-arg-alias-18006180529273122311
⚡ Optimize arg_alias and fix string alias bug
2 parents 95f85b5 + b53978e commit 3b8030a

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

tensorcircuit/utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,24 @@ def arg_alias(
149149
:rtype: Callable[..., Any]
150150
"""
151151

152+
flat_aliases = []
153+
normalized_alias_dict = {}
154+
for k, vs in alias_dict.items():
155+
if isinstance(vs, str):
156+
vs_list = [vs]
157+
else:
158+
vs_list = vs # type: ignore
159+
normalized_alias_dict[k] = vs_list
160+
for v in vs_list:
161+
flat_aliases.append((k, v))
162+
152163
@wraps(f)
153164
def wrapper(*args: Any, **kws: Any) -> Any:
154-
for k, vs in alias_dict.items():
155-
if isinstance(vs, str):
156-
vs = []
157-
for v in vs:
158-
if v in kws:
159-
# in case it is None by design!
160-
kws[k] = kws[v]
161-
del kws[v]
165+
for k, v in flat_aliases:
166+
if v in kws:
167+
# in case it is None by design!
168+
kws[k] = kws[v]
169+
del kws[v]
162170
return f(*args, **kws)
163171

164172
if fix_doc:
@@ -176,15 +184,15 @@ def wrapper(*args: Any, **kws: Any) -> Any:
176184

177185
if line.strip().startswith(":param "):
178186
param = re.findall(r":param\s([^\s]+):", line)[0]
179-
if param in alias_dict:
187+
if param in normalized_alias_dict:
180188
j = 1
181189
while True:
182190
ndoc.append(doc[i + j])
183191
if doc[i + j].strip().startswith(":type"):
184192
break
185193
j += 1
186194
skip = True
187-
for v in alias_dict[param]:
195+
for v in normalized_alias_dict[param]:
188196
ndoc.append(
189197
re.sub(
190198
r"(.*:param\s)([^\s]+)(:.*)",

tests/test_miscs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,20 @@ def f(theta: float, beta: float) -> float:
192192
print(f.__doc__)
193193
assert len(f.__doc__.strip().split("\n")) == 12
194194

195+
@partial(tc.utils.arg_alias, alias_dict={"theta": "alpha"})
196+
def g(theta: float) -> float:
197+
"""
198+
g doc
199+
200+
:param theta: theta angle
201+
:type theta: float
202+
:return: theta
203+
"""
204+
return theta
205+
206+
np.testing.assert_allclose(g(alpha=1.0), 1.0, atol=1e-5)
207+
assert "alpha: alias for the argument ``theta``" in g.__doc__
208+
195209

196210
def test_finite_difference_tf(tfb):
197211
def f(param1, param2):

0 commit comments

Comments
 (0)