Skip to content

Commit 12eb1ba

Browse files
committed "fix a few things"
commit 12eb1ba (1 parent: 118388a)

File tree

5 files changed

+130
-13
lines changed

5 files changed

+130
-13
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,23 @@ def test_op_reduce_min(self):
435435
atol=1e-6,
436436
)
437437

438+
def test_op_reduce_min_no_axes(self):
439+
model = oh.make_model(
440+
oh.make_graph(
441+
[oh.make_node("ReduceMin", ["X"], ["Z"], keepdims=0)],
442+
"dummy",
443+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
444+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
445+
),
446+
ir_version=9,
447+
opset_imports=[oh.make_opsetid("", 18)],
448+
)
449+
self._finalize_test(
450+
model,
451+
torch.rand(3, 4, 5, dtype=torch.float32),
452+
atol=1e-6,
453+
)
454+
438455
def test_op_reduce_min_17(self):
439456
model = oh.make_model(
440457
oh.make_graph(
@@ -452,6 +469,23 @@ def test_op_reduce_min_17(self):
452469
atol=1e-6,
453470
)
454471

472+
def test_op_reduce_min_17_no_axes(self):
473+
model = oh.make_model(
474+
oh.make_graph(
475+
[oh.make_node("ReduceMin", ["X"], ["Z"])],
476+
"dummy",
477+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
478+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
479+
),
480+
ir_version=9,
481+
opset_imports=[oh.make_opsetid("", 17)],
482+
)
483+
self._finalize_test(
484+
model,
485+
torch.rand(3, 4, 5, dtype=torch.float32),
486+
atol=1e-6,
487+
)
488+
455489
def test_op_reduce_sum(self):
456490
model = oh.make_model(
457491
oh.make_graph(
@@ -1123,6 +1157,42 @@ def test_conv(self):
11231157
B = torch.tensor([[[[0]]]], dtype=torch.float32)
11241158
self._finalize_test(model, X, W, B, use_ort=True)
11251159

1160+
def test_conv_autopad_valid(self):
1161+
model = oh.make_model(
1162+
oh.make_graph(
1163+
[
1164+
oh.make_node(
1165+
"Conv",
1166+
["X", "W", "B"],
1167+
["Y"],
1168+
dilations=[1, 1],
1169+
strides=[2, 2],
1170+
auto_pad="VALID",
1171+
)
1172+
],
1173+
"g",
1174+
[
1175+
oh.make_tensor_value_info("X", TFLOAT, [None, None, None, None]),
1176+
oh.make_tensor_value_info("W", TFLOAT, [None, None, None, None]),
1177+
oh.make_tensor_value_info("B", TFLOAT, [None, None, None, None]),
1178+
],
1179+
[oh.make_tensor_value_info("Y", TFLOAT, [None, None, None, None])],
1180+
),
1181+
opset_imports=[oh.make_opsetid("", 18)],
1182+
ir_version=10,
1183+
)
1184+
sH, sW = 5, 5
1185+
i = sH // 2
1186+
j = sW // 2
1187+
X = torch.zeros((1, 1, sH, sW), dtype=torch.float32)
1188+
X[0, 0, i, j] = 1.0
1189+
W = torch.zeros((1, 1, 3, 3), dtype=torch.float32)
1190+
W[0, 0, :, :] = torch.minimum(
1191+
2 ** torch.arange(9).reshape((3, -1)), torch.tensor([256])
1192+
)
1193+
B = torch.tensor([[[[0]]]], dtype=torch.float32)
1194+
self._finalize_test(model, X, W, B, use_ort=True)
1195+
11261196
def test_nonzero(self):
11271197
model = oh.make_model(
11281198
oh.make_graph(

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ class OpRunTensor(OpRunValue):
2828
"""
2929

3030
def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
31-
assert isinstance(tensor, torch.Tensor), f"Unexpected type {type(tensor)}"
32-
assert tensor is None or tensor.numel() > 1 or tensor.item() != -666666
31+
assert isinstance(tensor, torch.Tensor), (
32+
f"Unexpected type {type(tensor)}, "
33+
f"__name__={getattr(tensor, '__name__', 'no name')}"
34+
)
35+
assert tensor is None or tensor.numel() != 1 or tensor.item() != -666666
3336
self.tensor = (
3437
tensor.cpu()
3538
if may_cpu
@@ -197,6 +200,7 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
197200
), f"Cannot guess version from name={self.__class__.__name__!r}"
198201
version = int(name[1])
199202
self.version = version
203+
self.name = node.name
200204

201205
def __str__(self) -> str:
202206
"usual"
@@ -291,6 +295,15 @@ def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torc
291295
return None
292296
return to_tensor(att.t)
293297

298+
def same_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
299+
"""Puts all tensors on the same device."""
300+
devices = [t.get_device() for t in tensors]
301+
if len(set(devices)) == 1:
302+
return tuple(tensors)
303+
index = devices.index(max(devices))
304+
device = tensors[index].device
305+
return tuple(t.to(device) for t in tensors)
306+
294307

295308
class OpRunFunction(OpRun):
296309
"""

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,30 @@ def run(self, x, w, b=None):
6565
), f"conv not implemented for kernel_shape={kernel_shape} and w.shape={w.shape}"
6666
dilations = self.dilations or [1 for _ in x.shape[2:]]
6767
strides = self.strides or [1 for _ in x.shape[2:]]
68-
pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
69-
assert (
70-
self.auto_pad == "NOTSET"
71-
), f"conv not implemented for auto_pad={self.auto_pad!r}"
72-
assert len(set(pads)) == 1, f"conv not implemented for pads={pads}"
68+
69+
if self.auto_pad in {"SAME_LOWER", "SAME_UPPER"}:
70+
head = []
71+
tail = []
72+
for i in range(len(x.shape) - 2):
73+
d = x.shape[i + 2]
74+
target_size = (d + strides[i] - 1) // strides[i]
75+
pad_needed = (target_size - 1) * strides[i] + kernel_shape[i] - d
76+
pad_head = (
77+
(pad_needed + 1) // 2 if self.auto_pad == "SAME_LOWER" else pad_needed // 2
78+
)
79+
pad_tail = pad_needed - pad_head
80+
head.append(pad_head)
81+
tail.append(pad_tail)
82+
pads = head + tail
83+
else:
84+
pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
85+
86+
assert len(set(pads)) == 1, (
87+
f"conv not implemented for pads={pads}, "
88+
f"auto_pad={self.auto_pad!r}, strides={strides}, "
89+
f"x.shape={x.shape}, kernel_shape={kernel_shape}"
90+
)
91+
7392
if b is None:
7493
bias = None
7594
else:

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,21 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
4242
self.axis = axis
4343

4444
def run(self, *data: OpRunTensor) -> OpRunTensor:
45-
return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
45+
assert data, f"No tensor to concatenate in node name {self.name!r}"
46+
devices = [d.get_device() for d in data]
47+
if len(set(devices)) == 1:
48+
return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
49+
if (
50+
data[0].dtype == torch.int64
51+
and self.axis == 0
52+
and max(d.tensor.ndim for d in data) == 1
53+
and max(d.tensor.numel() for d in data) <= 8
54+
):
55+
# This is a shape
56+
return OpRunTensor(torch.cat([t.tensor.cpu() for t in data], axis=self.axis))
57+
index = devices.index(max(devices))
58+
device = data[index].tensor.device
59+
return OpRunTensor(torch.cat([t.tensor.to(device) for t in data], axis=self.axis))
4660

4761

4862
class NonZero_13(OpRun):
@@ -88,4 +102,5 @@ class Where_9(OpRun):
88102
"Where"
89103

90104
def run(self, cond: OpRunTensor, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
91-
return OpRunTensor(torch.where(cond.tensor, x.tensor, y.tensor))
105+
tcond, tx, ty = self.same_device(cond.tensor, x.tensor, y.tensor)
106+
return OpRunTensor(torch.where(tcond, tx, ty))

onnx_diagnostic/reference/torch_ops/reduce_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor
4242
assert (
4343
not self.keepdims
4444
), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
45-
return OpRunTensor(torch.max(x.tensor).values)
45+
return OpRunTensor(x.tensor.max())
4646
taxes = axes.as_tuple_int
4747
if len(taxes) == 1:
4848
t = x.tensor.max(taxes[0], keepdim=self.keepdims)
@@ -81,7 +81,7 @@ def run(self, x: OpRunTensor) -> OpRunTensor:
8181
assert (
8282
not self.keepdims
8383
), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
84-
return OpRunTensor(torch.min(x.tensor).values)
84+
return OpRunTensor(x.tensor.min())
8585
taxes = tuple(axes)
8686
if len(taxes) == 1:
8787
t = x.tensor.min(taxes[0], keepdim=self.keepdims)
@@ -100,8 +100,8 @@ def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor
100100
if axes is None:
101101
assert (
102102
not self.keepdims
103-
), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
104-
return OpRunTensor(torch.min(x.tensor).values)
103+
), f"axes is empty, keepdims={self.keepdims} for {self.__class__.__name__}"
104+
return OpRunTensor(torch.min(x.tensor))
105105
taxes = axes.as_tuple_int
106106
if len(taxes) == 1:
107107
t = x.tensor.min(taxes[0], keepdim=self.keepdims)

0 commit comments

Comments (0)