Skip to content

Commit df6f2c0

Browse files
committed
add gather
1 parent 9d31c15 commit df6f2c0

File tree

4 files changed

+46
-1
lines changed

4 files changed

+46
-1
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,26 @@ def test_op_concat(self):
240240
torch.rand((4, 5, 2, 7), dtype=torch.float32),
241241
)
242242

243+
def test_op_gather(self):
244+
model = oh.make_model(
245+
oh.make_graph(
246+
[oh.make_node("Gather", ["X", "Y"], ["Z"], axis=1)],
247+
"dummy",
248+
[
249+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c", "d"]),
250+
oh.make_tensor_value_info("Y", TINT64, ["f"]),
251+
],
252+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "d"])],
253+
),
254+
ir_version=9,
255+
opset_imports=[oh.make_opsetid("", 18)],
256+
)
257+
self._finalize_test(
258+
model,
259+
torch.rand((5, 4, 3, 2), dtype=torch.float32),
260+
torch.tensor([0, 1, 3], dtype=torch.int64),
261+
)
262+
243263

244264
if __name__ == "__main__":
245265
unittest.main(verbosity=2)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._op_run import OpRun, OpRunValue
2-
from .access_ops import Slice_13
2+
from .access_ops import Gather_1, Slice_13
33
from .binary_ops import Add_1, Div_1, MatMul_1, Mul_1, Sub_1
44
from .other_ops import Cast_6, Concat_1, Transpose_1
55
from .shape_ops import Reshape_5, Shape_15, Squeeze_13, Unsqueeze_13

onnx_diagnostic/reference/torch_ops/access_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,26 @@
11
from typing import Optional
2+
import onnx
3+
import torch
24
from . import OpRun, OpRunValue
35

46

7+
class Gather_1(OpRun):
8+
"Gather"
9+
10+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
11+
super().__init__(node, version)
12+
axis = self.get_attribute_int(node, "axis", 0)
13+
assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
14+
self.axis = axis
15+
16+
def run(self, x, indices):
17+
if indices.tensor.numel() == 0:
18+
return torch.empty((0,), dtype=x.tensor.dtype, device=x.tensor.device)
19+
ind = [slice(0, s) for s in x.shape]
20+
ind[self.axis] = indices.tensor
21+
return OpRunValue(x.tensor[*ind])
22+
23+
524
class Slice_13(OpRun):
625
"Slice"
726

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77

88
class Cast_6(OpRun):
9+
"Cast"
10+
911
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
1012
super().__init__(node, version)
1113
to = self.get_attribute_int(node, "to", 0)
@@ -17,6 +19,8 @@ def run(self, data: OpRunValue) -> OpRunValue:
1719

1820

1921
class Concat_1(OpRun):
22+
"Concat"
23+
2024
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
2125
super().__init__(node, version)
2226
axis = self.get_attribute_int(node, "axis", 0)
@@ -28,6 +32,8 @@ def run(self, *data: OpRunValue) -> OpRunValue:
2832

2933

3034
class Transpose_1(OpRun):
35+
"Transpose"
36+
3137
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
3238
super().__init__(node, version)
3339
self.perm = self.get_attribute_ints(node, "perm", None)

0 commit comments

Comments
 (0)