|
10 | 10 |
|
11 | 11 |
|
12 | 12 | TFLOAT = onnx.TensorProto.FLOAT |
| 13 | +TINT64 = onnx.TensorProto.INT64 |
13 | 14 |
|
14 | 15 |
|
15 | 16 | class TestTorchEvaluator(ExtTestCase): |
@@ -72,6 +73,31 @@ def test_binary_ops(self): |
72 | 73 | else: |
73 | 74 | self.assertEmpty(v.value) |
74 | 75 |
|
| 76 | + def test_slice_squeeze(self): |
| 77 | + X = oh.make_tensor_value_info("X", TFLOAT, [None, None]) |
| 78 | + starts = oh.make_tensor_value_info("starts", TINT64, [None]) |
| 79 | + ends = oh.make_tensor_value_info("ends", TINT64, [None]) |
| 80 | + axes = oh.make_tensor_value_info("axes", TINT64, [None]) |
| 81 | + Y = oh.make_tensor_value_info("Y", TINT64, [None]) |
| 82 | + nodes = [ |
| 83 | + oh.make_node("Slice", ["X", "starts", "ends", "axes"], ["T"]), |
| 84 | + oh.make_node("Squeeze", ["T", "axes"], ["Y"]), |
| 85 | + ] |
| 86 | + graph = oh.make_graph(nodes, "g", [X, starts, ends, axes], [Y]) |
| 87 | + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)]) |
| 88 | + feeds = { |
| 89 | + "X": torch.tensor([[0]], dtype=torch.int64), |
| 90 | + "starts": torch.tensor([0], dtype=torch.int64), |
| 91 | + "ends": torch.tensor([1], dtype=torch.int64), |
| 92 | + "axes": torch.tensor([0], dtype=torch.int64), |
| 93 | + } |
| 94 | + expected = ExtendedReferenceEvaluator(model).run( |
| 95 | + None, {k: v.numpy() for k, v in feeds.items()} |
| 96 | + ) |
| 97 | + rt = TorchEvaluator(model) |
| 98 | + got = rt.run(None, feeds) |
| 99 | + self.assertEqualAny(expected, [g.detach().numpy() for g in got]) |
| 100 | + |
75 | 101 |
|
76 | 102 | if __name__ == "__main__": |
77 | 103 | unittest.main(verbosity=2) |
0 commit comments