Skip to content

Commit add3086

Browse files
committed
dis
1 parent 43ad01c commit add3086

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

_unittests/ut_export/test_control_flow.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from onnxscript import script, FLOAT, INT64
55
from onnxscript import opset18 as op
6-
from onnx_diagnostic.ext_test_case import ExtTestCase
6+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
77
from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
88
from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
99
from onnx_diagnostic.export.api import to_onnx
@@ -47,6 +47,7 @@ def concatenation(N: INT64[1], x: FLOAT[None]) -> FLOAT[None, 1]:
4747
onx = concatenation.to_model_proto()
4848
self.dump_onnx("test_onnxscript_loop.onnx", onx)
4949

50+
@requires_torch("2.9.99")
5051
def test_loop_one_custom(self):
5152
class Model(torch.nn.Module):
5253
def forward(self, n_iter, x):
@@ -77,6 +78,7 @@ def body(i, x):
7778
self.dump_onnx("test_loop_one_custom.onnx", onx)
7879
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
7980

81+
@requires_torch("2.9.99")
8082
def test_loop_two_custom(self):
8183
class Model(torch.nn.Module):
8284
def forward(self, n_iter, x):
@@ -108,6 +110,7 @@ def body(i, x):
108110
self.dump_onnx("test_loop_one_custom.onnx", onx)
109111
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
110112

113+
@requires_torch("2.9.99")
111114
def test_loop_two_custom_reduction_dim(self):
112115
class Model(torch.nn.Module):
113116
def forward(self, n_iter, x):

0 commit comments

Comments
 (0)