Skip to content

Commit 78092a1

Browse files
committed
ut
1 parent 951257d commit 78092a1

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

_unittests/ut_export/test_jit.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import unittest
23
import torch
34
from onnx_diagnostic.ext_test_case import (
@@ -15,6 +16,9 @@
1516
to_onnx = None
1617

1718

19+
has_scan_reverse = "reverse" in set(inspect.signature(torch.ops.higher_order.scan).parameters)
20+
21+
1822
@torch.jit.script_if_tracing
1923
def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
2024
copy = torch.zeros(padded.shape)
@@ -36,12 +40,12 @@ def pad_row(padded, p):
3640
row[: p.item()] = padded[: p.item()]
3741
return (row,)
3842

39-
return torch.ops.higher_order.scan(
40-
pad_row,
41-
[],
42-
[padded, pos],
43-
[],
44-
)
43+
if has_scan_reverse:
44+
# torch==2.6
45+
return torch.ops.higher_order.scan(
46+
pad_row, [], [padded, pos], additional_inputs=[], reverse=False, dim=0
47+
)
48+
return torch.ops.higher_order.scan(pad_row, [], [padded, pos], [])
4549

4650

4751
def select_when_exporting(f, f_scan):

0 commit comments

Comments
 (0)