Skip to content

Commit 45d0e6b

Browse files
Chilleezou3519
authored andcommitted
[functorch] fixed some static argnums stuff
1 parent 3c07478 commit 45d0e6b

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

functorch/functorch/_src/aot_autograd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .partitioners import default_partition
1212
from .named_members_polyfill import _named_parameters, _named_buffers
1313
from typing import Callable, List, Dict, Any, Tuple, Optional
14+
from functools import wraps
1415

1516
try:
1617
from torchdynamo import disable as disable_torchdynamo
@@ -261,6 +262,7 @@ def rearrange(tensor_args, static_args, static_argnums):
261262
else:
262263
args.append(tensor_args[tensor_index])
263264
tensor_index += 1
265+
index += 1
264266

265267
while tensor_index < len(tensor_args):
266268
args.append(tensor_args[tensor_index])
@@ -375,6 +377,7 @@ def aot_function(
375377
static_argnums = list(static_argnums)
376378
static_argnums.sort()
377379

380+
@wraps(fn)
378381
def returned_function(*args, **kwargs):
379382
global compile_cache
380383
nonlocal cached_res

0 commit comments

Comments
 (0)