Skip to content

Commit 03a43d3

Browse files
committed
support none
1 parent a7de659 commit 03a43d3

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,30 @@ def forward(self, x, y):
325325
self.assertEqualAny(expected_0, ep.module()(x, -y))
326326
self.assertEqualAny(expected_1, ep.module()(-x, -y))
327327

328+
def test_rewrite_forward_none(self):
329+
330+
class Model(torch.nn.Module):
331+
def forward(self, x, y):
332+
if x is None:
333+
x = torch.abs(y)
334+
return x + y
335+
336+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
337+
expected, expected_ = Model()(x, y), Model()(-x, y)
338+
339+
rewritten = transform_method(Model.forward, verbose=self.verbose)
340+
self.assertIn("torch.abs(", rewritten.code)
341+
self.assertIn("abs", rewritten.dump)
342+
Model.forward = rewritten.func
343+
self.assertEqualAny(expected, Model()(x, y))
344+
self.assertEqualAny(expected_, Model()(-x, y))
345+
346+
DYN = torch.export.Dim.DYNAMIC
347+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
348+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
349+
self.assertEqualAny(expected, ep.module()(x, y))
350+
self.assertEqualAny(expected_, ep.module()(-x, y))
351+
328352

329353
if __name__ == "__main__":
330354
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _find_id(self, exprs: List["ast.Node"]) -> List[str]:
8080
for n in ast.walk(expr):
8181
if (
8282
isinstance(n, ast.Name)
83-
and isinstance(n.ctx, ast.Load)
83+
# and isinstance(n.ctx, ast.Load)
8484
and n.id not in self.skip_objects
8585
):
8686
vars.append(n.id)
@@ -267,9 +267,11 @@ def visit_If(self, node):
267267
tuple(t for t in tgt.elts if t.id not in dropped), ctx=ast.Store()
268268
)
269269

270+
added = {tgt.id} if isinstance(tgt, ast.Name) else set(t.id for t in tgt.elts)
270271
assign = ast.Assign(targets=[tgt], value=call)
271272
ast.copy_location(assign, node)
272273
ast.fix_missing_locations(assign)
274+
self.local_variables = known_local_variables | added
273275
return [then_def, else_def, assign]
274276

275277
# Case 2: return in both branches, we assume both branches return the same results.

0 commit comments

Comments
 (0)