Skip to content

Commit 9d31c15

Browse files
committed
add concat
1 parent 3e2237b commit 9d31c15

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,26 @@ def test_op_unsqueeze(self):
220220
torch.tensor([2], dtype=torch.int64),
221221
)
222222

223+
def test_op_concat(self):
224+
model = oh.make_model(
225+
oh.make_graph(
226+
[oh.make_node("Concat", ["X", "Y"], ["Z"], axis=2)],
227+
"dummy",
228+
[
229+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", 1, "d"]),
230+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", 1, "d"]),
231+
],
232+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "d"])],
233+
),
234+
ir_version=9,
235+
opset_imports=[oh.make_opsetid("", 18)],
236+
)
237+
self._finalize_test(
238+
model,
239+
torch.rand((4, 5, 1, 7), dtype=torch.float32),
240+
torch.rand((4, 5, 2, 7), dtype=torch.float32),
241+
)
242+
223243

224244
if __name__ == "__main__":
225245
unittest.main(verbosity=2)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._op_run import OpRun, OpRunValue
22
from .access_ops import Slice_13
33
from .binary_ops import Add_1, Div_1, MatMul_1, Mul_1, Sub_1
4-
from .other_ops import Cast_6, Transpose_1
4+
from .other_ops import Cast_6, Concat_1, Transpose_1
55
from .shape_ops import Reshape_5, Shape_15, Squeeze_13, Unsqueeze_13

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ def run(self, data: OpRunValue) -> OpRunValue:
1616
return OpRunValue(data.tensor.to(self.to))
1717

1818

19+
class Concat_1(OpRun):
20+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
21+
super().__init__(node, version)
22+
axis = self.get_attribute_int(node, "axis", 0)
23+
assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
24+
self.axis = axis
25+
26+
def run(self, *data: OpRunValue) -> OpRunValue:
27+
return OpRunValue(torch.cat([t.tensor for t in data], axis=self.axis))
28+
29+
1930
class Transpose_1(OpRun):
2031
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
2132
super().__init__(node, version)

0 commit comments

Comments
 (0)