Skip to content

Commit 69494f7

Browse files
committed
reduce
1 parent 12eb1ba commit 69494f7

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def test_op_reduce_min_17(self):
472472
def test_op_reduce_min_17_no_axes(self):
473473
model = oh.make_model(
474474
oh.make_graph(
475-
[oh.make_node("ReduceMin", ["X"], ["Z"])],
475+
[oh.make_node("ReduceMin", ["X"], ["Z"], keepdims=0)],
476476
"dummy",
477477
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
478478
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,24 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
122122

123123
def run(self, x, scale, bias=None):
124124
original_dtype = x.dtype
125-
xt = x.tensor.to(self.stash_type)
126-
res = torch.nn.functional.layer_norm(
127-
xt,
128-
xt.shape[self.axis :],
129-
weight=scale.tensor,
130-
bias=None if bias is None else bias.tensor,
131-
eps=self.epsilon,
132-
)
125+
if self.stash_type == torch.float32 and x.tensor.dtype != torch.float64:
126+
xt = x.tensor
127+
res = torch.nn.functional.layer_norm(
128+
xt,
129+
xt.shape[self.axis :],
130+
weight=scale.tensor,
131+
bias=None if bias is None else bias.tensor,
132+
eps=self.epsilon,
133+
)
134+
else:
135+
xt = x.tensor.to(self.stash_type)
136+
res = torch.nn.functional.layer_norm(
137+
xt,
138+
xt.shape[self.axis :],
139+
weight=scale.tensor.to(self.stash_type),
140+
bias=None if bias is None else bias.tensor.to(self.stash_type),
141+
eps=self.epsilon,
142+
)
133143
if not self.compute_std:
134144
return OpRunTensor(res.to(original_dtype))
135145
axes = tuple(range(len(xt.shape)))[self.axis :]

0 commit comments

Comments
 (0)