Skip to content

Commit 8bdf782

Browse files
Chillee authored and zou3519 committed
[functorch] fix distance heuristic
1 parent 6aacc28 commit 8bdf782

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

functorch/functorch/_src/decompositions.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,21 @@
1212
# Decompositions have been ported to torch._decomp inside of PyTorch core. The only decompositions here are temporary or hacks. Please submit your contributions to PyTorch core!
1313

1414

15-
@register_decomposition(aten.trace.default)
15+
def maybe_register_decomposition(op):
16+
def decorator(f):
17+
try:
18+
return register_decomposition(op)(f)
19+
except Exception:
20+
return f
21+
return decorator
22+
23+
24+
@maybe_register_decomposition(aten.trace.default)
1625
def trace(self: Tensor) -> Tensor:
1726
return torch.sum(torch.diag(self))
1827

1928

20-
@register_decomposition(aten.log_sigmoid_forward)
29+
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
2130
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
2231
min = torch.minimum(self.new_zeros(()), self)
2332
z = torch.exp(-torch.abs(self))

functorch/functorch/_src/partitioners.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -251,21 +251,23 @@ def classify_nodes(joint_module):
251251
return required_fw_nodes, required_bw_nodes, unclaimed_nodes
252252

253253
required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module)
254-
cache = {}
255254

256-
def dist_from_fw(node):
257-
if node in cache:
258-
return cache[node]
255+
def dist_from_bw(node):
256+
print(node)
259257
if node not in required_fw_nodes:
260258
return 0
261259
dist = int(1e9)
262260
for n in node.users:
263-
dist = min(dist_from_fw(n) + 1, dist)
264-
cache[node] = dist
261+
dist = min(dist_from_bw(n) + 1, dist)
265262
return dist
266263

267-
for node in joint_module.graph.nodes:
268-
node.dist_from_fw = dist_from_fw(node)
264+
for node in reversed(joint_module.graph.nodes):
265+
if node not in required_fw_nodes:
266+
node.dist_from_bw = 0
267+
else:
268+
node.dist_from_bw = int(1e9)
269+
for n in node.users:
270+
node.dist_from_bw = min(node.dist_from_bw, n.dist_from_bw + 1)
269271

270272
aten = torch.ops.aten
271273

@@ -325,7 +327,7 @@ def get_node_weight(node):
325327
mem_sz = _size_of(node.meta['tensor_meta'])
326328

327329
# Heuristic to bias towards nodes closer to the backwards pass
328-
mem_sz = int(mem_sz + node.dist_from_fw)
330+
mem_sz = int(mem_sz + node.dist_from_bw)
329331

330332
if is_materialized(node):
331333
return mem_sz

0 commit comments

Comments
 (0)