|
4 | 4 | import unittest |
5 | 5 | import torch |
6 | 6 | from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout |
| 7 | +from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite |
7 | 8 | from onnx_diagnostic.torch_export_patches.patch_module import ( |
8 | 9 | transform_method, |
9 | 10 | inplace_add_parent, |
10 | 11 | ) |
11 | 12 |
|
12 | 13 |
|
| 14 | +class _ModelForATest(torch.nn.Module): |
| 15 | + def forward(self, x, y): |
| 16 | + if x.sum() > 0: |
| 17 | + return x + y |
| 18 | + else: |
| 19 | + return torch.abs(x) + y + 1 |
| 20 | + |
| 21 | + |
| 22 | +def _single_forward(x, y): |
| 23 | + if x.sum() > 0: |
| 24 | + return x + y |
| 25 | + else: |
| 26 | + return torch.abs(x) + y + 1 |
| 27 | + |
| 28 | + |
13 | 29 | class TestPatchModule(ExtTestCase): |
14 | 30 | def test_parent(self): |
15 | 31 | class Model(torch.nn.Module): |
@@ -361,8 +377,71 @@ def test_rewrite_PLBartEncoderLayer(self): |
361 | 377 | ), |
362 | 378 | rewritten.code, |
363 | 379 | ) |
364 | | - print() |
365 | | - print(rewritten.code) |
| 380 | + |
| 381 | + @hide_stdout() |
| 382 | + def test_torch_export_patch_method_tuple(self): |
| 383 | + class Model(torch.nn.Module): |
| 384 | + def forward(self, x, y): |
| 385 | + if x.sum() > 0: |
| 386 | + return x + y |
| 387 | + else: |
| 388 | + return torch.abs(x) + y + 1 |
| 389 | + |
| 390 | + model = Model() |
| 391 | + x, y = torch.rand((4, 5)), torch.rand((4, 5)) |
| 392 | + expected = model(x, y) |
| 393 | + DYN = torch.export.Dim.DYNAMIC |
| 394 | + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) |
| 395 | + with torch_export_patches(rewrite_methods=[(Model, "forward")], verbose=2): |
| 396 | + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) |
| 397 | + got = ep.module()(x, y) |
| 398 | + self.assertEqualArray(expected, got) |
| 399 | + |
| 400 | + @hide_stdout() |
| 401 | + def test_torch_export_rewrite_method_tuple(self): |
| 402 | + class Model(torch.nn.Module): |
| 403 | + def forward(self, x, y): |
| 404 | + if x.sum() > 0: |
| 405 | + return x + y |
| 406 | + else: |
| 407 | + return torch.abs(x) + y + 1 |
| 408 | + |
| 409 | + model = Model() |
| 410 | + x, y = torch.rand((4, 5)), torch.rand((4, 5)) |
| 411 | + expected = model(x, y) |
| 412 | + DYN = torch.export.Dim.DYNAMIC |
| 413 | + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) |
| 414 | + with torch_export_rewrite(rewrite_methods=[(Model, "forward")], verbose=1): |
| 415 | + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) |
| 416 | + got = ep.module()(x, y) |
| 417 | + self.assertEqualArray(expected, got) |
| 418 | + |
| 419 | + def test_torch_export_rewrite_method_only(self): |
| 420 | + model = _ModelForATest() |
| 421 | + x, y = torch.rand((4, 5)), torch.rand((4, 5)) |
| 422 | + expected = model(x, y) |
| 423 | + DYN = torch.export.Dim.DYNAMIC |
| 424 | + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) |
| 425 | + with torch_export_rewrite(rewrite_methods=[_ModelForATest.forward], verbose=0): |
| 426 | + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) |
| 427 | + got = ep.module()(x, y) |
| 428 | + self.assertEqualArray(expected, got) |
| 429 | + |
| 430 | + @hide_stdout() |
| 431 | + def test_torch_export_rewrite_function(self): |
| 432 | + class Model(torch.nn.Module): |
| 433 | + def forward(self, x, y): |
| 434 | + return _single_forward(x, y) |
| 435 | + |
| 436 | + model = Model() |
| 437 | + x, y = torch.rand((4, 5)), torch.rand((4, 5)) |
| 438 | + expected = model(x, y) |
| 439 | + DYN = torch.export.Dim.DYNAMIC |
| 440 | + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) |
| 441 | + with torch_export_rewrite(rewrite_methods=[_single_forward], verbose=1): |
| 442 | + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) |
| 443 | + got = ep.module()(x, y) |
| 444 | + self.assertEqualArray(expected, got) |
366 | 445 |
|
367 | 446 |
|
368 | 447 | if __name__ == "__main__": |
|
0 commit comments