|
| 1 | +# Copyright 2024 Arm Limited and/or its affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import operator |
| 8 | + |
| 9 | +import torch |
| 10 | +from executorch.backends.arm._passes.arm_pass_utils import create_node |
| 11 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 12 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 13 | + |
| 14 | + |
| 15 | +def get_layer_norm_decomposition(op) -> tuple: |
| 16 | + if op == exir_ops.edge.aten.native_layer_norm.default: |
| 17 | + return ( |
| 18 | + exir_ops.edge.aten.mean.dim, |
| 19 | + exir_ops.edge.aten.sub.Tensor, |
| 20 | + exir_ops.edge.aten.var.correction, |
| 21 | + exir_ops.edge.aten.full.default, |
| 22 | + exir_ops.edge.aten.add.Tensor, |
| 23 | + exir_ops.edge.aten.rsqrt.default, |
| 24 | + exir_ops.edge.aten.mul.Tensor, |
| 25 | + exir_ops.edge.aten.view_copy.default, |
| 26 | + ) |
| 27 | + if op == torch.ops.aten.layer_norm.default: |
| 28 | + return ( |
| 29 | + torch.ops.aten.mean.dim, |
| 30 | + torch.ops.aten.sub.Tensor, |
| 31 | + torch.ops.aten.var.correction, |
| 32 | + torch.ops.aten.full.default, |
| 33 | + torch.ops.aten.add.Tensor, |
| 34 | + torch.ops.aten.rsqrt.default, |
| 35 | + torch.ops.aten.mul.Tensor, |
| 36 | + torch.ops.aten.view_copy.default, |
| 37 | + ) |
| 38 | + raise RuntimeError(f"Can't get layer_norm composition for op {op}") |
| 39 | + |
| 40 | + |
| 41 | +class DecomposeLayerNormPass(ExportPass): |
| 42 | + """ |
| 43 | + layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias |
| 44 | + Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of: |
| 45 | + mean = op_mean(x, dims) # E[x] |
| 46 | + var = op_var(x, dims) # Var[x] |
| 47 | + denominator = op_sub(x, mean) # (x - E[x]) |
| 48 | + add = op_add(var, eps) # Var[x] + eps |
| 49 | + rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps) |
| 50 | + mul = op_mul(denominator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths |
| 51 | + bias = op_add(mul, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias |
| 52 | +
|
| 53 | + Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html |
| 54 | + """ |
| 55 | + |
| 56 | + def call(self, graph_module: torch.fx.GraphModule): |
| 57 | + for node in graph_module.graph.nodes: |
| 58 | + if node.op != "call_function" or node.target not in ( |
| 59 | + exir_ops.edge.aten.native_layer_norm.default, |
| 60 | + torch.ops.aten.layer_norm.default, |
| 61 | + ): |
| 62 | + continue |
| 63 | + |
| 64 | + # epsilon default value |
| 65 | + epsilon = torch.finfo().eps |
| 66 | + weights = None |
| 67 | + bias = None |
| 68 | + args = node.args |
| 69 | + meta = node.meta |
| 70 | + match len(args): |
| 71 | + case 5: |
| 72 | + x, normalized_shape, weights, bias, epsilon = args |
| 73 | + case 4: |
| 74 | + x, normalized_shape, weights, bias = args |
| 75 | + case 3: |
| 76 | + x, normalized_shape, weights = args |
| 77 | + case _: |
| 78 | + x, normalized_shape = args |
| 79 | + |
| 80 | + n_dims = len(normalized_shape) |
| 81 | + if isinstance(meta["val"], tuple): |
| 82 | + shape = meta["val"][0].size() |
| 83 | + else: |
| 84 | + shape = meta["val"].size() |
| 85 | + dtype = meta["val"][0].dtype |
| 86 | + rank = len(shape) |
| 87 | + dims = list(range(-1, -1 * (n_dims + 1), -1)) |
| 88 | + dims = [dim % rank for dim in dims] |
| 89 | + weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)] |
| 90 | + epsilon_reshaped_shape = [1] * rank |
| 91 | + |
| 92 | + ( |
| 93 | + mean_op, |
| 94 | + sub_op, |
| 95 | + var_op, |
| 96 | + full_op, |
| 97 | + add_op, |
| 98 | + rsqrt_op, |
| 99 | + mul_op, |
| 100 | + view_op, |
| 101 | + ) = get_layer_norm_decomposition(node.target) |
| 102 | + with graph_module.graph.inserting_before(node): |
| 103 | + keepdim = True |
| 104 | + mean = create_node(graph_module.graph, mean_op, args=(x, dims, keepdim)) |
| 105 | + sub = create_node(graph_module.graph, sub_op, args=(x, mean)) |
| 106 | + var = create_node( |
| 107 | + graph_module.graph, |
| 108 | + var_op, |
| 109 | + args=(x, dims), |
| 110 | + kwargs={"correction": 0, "keepdim": keepdim}, |
| 111 | + ) |
| 112 | + full = create_node( |
| 113 | + graph_module.graph, |
| 114 | + full_op, |
| 115 | + args=(epsilon_reshaped_shape, epsilon), |
| 116 | + kwargs={"dtype": dtype}, |
| 117 | + ) |
| 118 | + add0 = create_node(graph_module.graph, add_op, args=(var, full)) |
| 119 | + rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,)) |
| 120 | + mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt)) |
| 121 | + if weights is not None: |
| 122 | + weights_reshaped = create_node( |
| 123 | + graph_module.graph, |
| 124 | + view_op, |
| 125 | + args=(weights, weights_reshaped_shape), |
| 126 | + ) |
| 127 | + mul1 = create_node( |
| 128 | + graph_module.graph, mul_op, args=(mul0, weights_reshaped) |
| 129 | + ) |
| 130 | + else: |
| 131 | + mul1 = mul0 |
| 132 | + output = mul1 |
| 133 | + if bias is not None: |
| 134 | + bias_reshaped_shape = weights_reshaped_shape |
| 135 | + bias_reshaped = create_node( |
| 136 | + graph_module.graph, view_op, args=(bias, bias_reshaped_shape) |
| 137 | + ) |
| 138 | + output = create_node( |
| 139 | + graph_module.graph, add_op, args=(mul1, bias_reshaped) |
| 140 | + ) |
| 141 | + |
| 142 | + users = [user for user in node.users if node != user] |
| 143 | + node.replace_all_uses_with(output) |
| 144 | + for user in users: |
| 145 | + if user.target == operator.getitem: |
| 146 | + user.replace_all_uses_with(output) |
| 147 | + graph_module.graph.erase_node(node) |
| 148 | + graph_module.graph.eliminate_dead_code() |
| 149 | + graph_module.recompile() |
| 150 | + graph_module = super().call(graph_module).graph_module |
| 151 | + |
| 152 | + return PassResult(graph_module, True) |
0 commit comments