Skip to content

Commit 2cec313

Browse files
committed
fix: skip output node in FX graph cleanup to fix erase_node crash (Granite4 GPTQ)
Signed-off-by: gillesturpin <turpingilles@orange.fr>
1 parent a9847e0 commit 2cec313

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/llmcompressor/pipelines/sequential/transformers_helpers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,9 +1478,28 @@ def to_meta(value):
14781478
to_delete = collections.OrderedDict()
14791479
while to_visit:
14801480
n = to_visit.pop(0)
1481+
# Never mark the output node for deletion — instead
1482+
# we will strip references to deleted nodes from it.
1483+
if n.op == "output":
1484+
continue
14811485
to_delete[n] = None
14821486
to_visit += list(n.users.keys())
14831487

1488+
# Remove references to dead nodes from the output node
1489+
for out_node in reversed(self.graph.nodes):
1490+
if out_node.op == "output":
1491+
1492+
def _strip(a):
1493+
if isinstance(a, torch.fx.Node) and a in to_delete:
1494+
return None
1495+
return a
1496+
1497+
new_args = torch.fx.node.map_aggregate(
1498+
out_node.args, _strip
1499+
)
1500+
out_node.args = new_args
1501+
break
1502+
14841503
for user in reversed(to_delete.keys()):
14851504
self.graph.erase_node(user)
14861505

0 commit comments

Comments (0)