Skip to content

Commit 86cd5e0

Browse files
committed
ind
1 parent 8a78df2 commit 86cd5e0

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,13 +444,13 @@ def test_op_reduce_sum(self):
444444
model,
445445
torch.rand(3, 4, 5, dtype=torch.float32),
446446
torch.tensor([1], dtype=torch.int64),
447-
atol=1e-6,
447+
atol=1e-5,
448448
)
449449
self._finalize_test(
450450
model,
451451
torch.rand(3, 4, 5, dtype=torch.float32),
452452
torch.tensor([1, 2], dtype=torch.int64),
453-
atol=1e-6,
453+
atol=1e-5,
454454
)
455455

456456

onnx_diagnostic/reference/torch_ops/access_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def run(self, x, indices):
1818
return torch.empty((0,), dtype=x.tensor.dtype, device=x.tensor.device)
1919
ind = [slice(0, s) for s in x.shape]
2020
ind[self.axis] = indices.tensor
21-
return OpRunValue(x.tensor[*ind])
21+
return OpRunValue(x.tensor[tuple(ind)])
2222

2323

2424
class Slice_13(OpRun):

0 commit comments

Comments
 (0)