|
3 | 3 | import torch |
4 | 4 | from onnxscript import script, FLOAT, INT64 |
5 | 5 | from onnxscript import opset18 as op |
6 | | -from onnx_diagnostic.ext_test_case import ExtTestCase |
| 6 | +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch |
7 | 7 | from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for |
8 | 8 | from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r |
9 | 9 | from onnx_diagnostic.export.api import to_onnx |
@@ -47,6 +47,7 @@ def concatenation(N: INT64[1], x: FLOAT[None]) -> FLOAT[None, 1]: |
47 | 47 | onx = concatenation.to_model_proto() |
48 | 48 | self.dump_onnx("test_onnxscript_loop.onnx", onx) |
49 | 49 |
|
| 50 | + @requires_torch("2.9.99") |
50 | 51 | def test_loop_one_custom(self): |
51 | 52 | class Model(torch.nn.Module): |
52 | 53 | def forward(self, n_iter, x): |
@@ -77,6 +78,7 @@ def body(i, x): |
77 | 78 | self.dump_onnx("test_loop_one_custom.onnx", onx) |
78 | 79 | self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x)) |
79 | 80 |
|
| 81 | + @requires_torch("2.9.99") |
80 | 82 | def test_loop_two_custom(self): |
81 | 83 | class Model(torch.nn.Module): |
82 | 84 | def forward(self, n_iter, x): |
@@ -108,6 +110,7 @@ def body(i, x): |
108 | 110 | self.dump_onnx("test_loop_one_custom.onnx", onx) |
109 | 111 | self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x)) |
110 | 112 |
|
| 113 | + @requires_torch("2.9.99") |
111 | 114 | def test_loop_two_custom_reduction_dim(self): |
112 | 115 | class Model(torch.nn.Module): |
113 | 116 | def forward(self, n_iter, x): |
|
0 commit comments