Skip to content

Commit 30e59cd

Browse files
Chilleezou3519
authored andcommitted
[functorch] Add wraps to jacfwd, fix partitioning heurisitc
1 parent 7143f06 commit 30e59cd

File tree

2 files changed

+3
-12
lines changed

2 files changed

+3
-12
lines changed

functorch/functorch/_src/eager_transforms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,7 @@ def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False):
929929
>>> assert torch.allclose(jacobian[1], expectedY)
930930
931931
"""
932+
@wraps(func)
932933
def wrapper_fn(*args):
933934
f_wrapper, primals = _argnums_partial(func, args, argnums)
934935
flat_primals, primals_spec = tree_flatten(primals)

functorch/functorch/_src/partitioners.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -251,23 +251,13 @@ def classify_nodes(joint_module):
251251
return required_fw_nodes, required_bw_nodes, unclaimed_nodes
252252

253253
required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module)
254-
255-
def dist_from_bw(node):
256-
print(node)
257-
if node not in required_fw_nodes:
258-
return 0
259-
dist = int(1e9)
260-
for n in node.users:
261-
dist = min(dist_from_bw(n) + 1, dist)
262-
return dist
263-
264254
for node in reversed(joint_module.graph.nodes):
265255
if node not in required_fw_nodes:
266256
node.dist_from_bw = 0
267257
else:
268258
node.dist_from_bw = int(1e9)
269-
for n in node.users:
270-
node.dist_from_bw = min(node.dist_from_bw, n.dist_from_bw + 1)
259+
for user in node.users:
260+
node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)
271261

272262
aten = torch.ops.aten
273263

0 commit comments

Comments
 (0)