from typing import Optional
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRun, OpRunValue


class LayerNormalization_17(OpRun):
    "LayerNormalization"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
        self.stash_type = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )
        self.compute_std = len(node.output) > 1

    def run(self, x, scale, bias=None):
        original_dtype = x.dtype
        # Normalize in stash_type (float32 by default) for numerical stability,
        # then cast the result back to the input dtype.
        xt = x.tensor.to(self.stash_type)
        res = torch.nn.functional.layer_norm(
            xt,
            xt.shape[self.axis :],
            weight=scale.tensor,
            bias=None if bias is None else bias.tensor,
            eps=self.epsilon,
        )
        if not self.compute_std:
            return OpRunValue(res.to(original_dtype))
        # Recompute the saved Mean and InvStdDev outputs over the normalized
        # axes. torch.var returns a single tensor, so torch.var_mean is needed
        # to get both statistics; the ONNX spec uses biased variance
        # (correction=0) and keeps the reduced axes with size 1.
        axes = tuple(range(len(xt.shape)))[self.axis :]
        var, mean = torch.var_mean(xt, dim=axes, correction=0, keepdim=True)
        x_inv_std_dev = torch.reciprocal(torch.sqrt(var + self.epsilon))
        return OpRunValue(res.to(original_dtype)), OpRunValue(mean), OpRunValue(x_inv_std_dev)
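
    # Hedged usage sketch (illustrative, not part of the source; assumes the
    # OpRun/OpRunValue API used above and that onnx.helper is importable):
    #
    #     node = onnx.helper.make_node(
    #         "LayerNormalization", ["X", "Scale"], ["Y"], axis=-1, epsilon=1e-5
    #     )
    #     op = LayerNormalization_17(node)
    #     y = op.run(OpRunValue(torch.randn(2, 4)), OpRunValue(torch.ones(4)))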


class Softmax_13(OpRun):
    "Softmax"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        self.axis = self.get_attribute_int(node, "axis", -1)
        assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
        # stash_type is out of spec for Softmax; when set, the computation runs
        # in that dtype and the result is returned in it as well.
        stash_type = self.get_attribute_int(node, "stash_type", None)
        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)

    def run(self, data: OpRunValue) -> OpRunValue:
        return OpRunValue(
            torch.nn.functional.softmax(data.tensor, dim=self.axis, dtype=self.stash_type)
        )
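
    # Softmax along `axis`: softmax(x)_i = exp(x_i) / sum_j exp(x_j).
    # Hedged usage sketch (illustrative, not part of the source):
    #
    #     node = onnx.helper.make_node("Softmax", ["X"], ["Y"], axis=-1)
    #     probs = Softmax_13(node).run(OpRunValue(torch.randn(2, 3)))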


class Tanh_6(OpRun):
    "Tanh"

    def run(self, data: OpRunValue) -> OpRunValue:
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        return OpRunValue(torch.tanh(data.tensor))