|
| 1 | +# ----------------------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. |
| 4 | +# SPDX-License-Identifier: BSD-3-Clause |
| 5 | +# |
| 6 | +# ----------------------------------------------------------------------------- |
| 7 | + |
| 8 | +import torch |
| 9 | +from torch import nn |
| 10 | +from transformers.models.deberta_v2.modeling_deberta_v2 import ( |
| 11 | + DisentangledSelfAttention, |
| 12 | +) |
| 13 | + |
| 14 | + |
| 15 | +def make_log_bucket_position_onnx(relative_pos, bucket_size: int, max_position: int): |
| 16 | + sign = torch.sign(relative_pos) |
| 17 | + mid = bucket_size // 2 |
| 18 | + abs_pos = torch.abs(relative_pos) |
| 19 | + |
| 20 | + # Instead of torch.where with complex conditions, use mask-based approach |
| 21 | + # Original: torch.where((relative_pos < mid) & (relative_pos > -mid), mid-1, abs_pos) |
| 22 | + is_in_mid_range = abs_pos < mid |
| 23 | + abs_pos_clamped = torch.where(is_in_mid_range, torch.tensor(mid - 1).type_as(relative_pos), abs_pos) |
| 24 | + |
| 25 | + # Compute log position |
| 26 | + log_pos = ( |
| 27 | + torch.ceil(torch.log(abs_pos_clamped / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) |
| 28 | + + mid |
| 29 | + ) |
| 30 | + |
| 31 | + # Select between relative_pos and log_pos based on whether abs_pos <= mid |
| 32 | + bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign) |
| 33 | + return bucket_pos |
| 34 | + |
| 35 | + |
| 36 | +def build_relative_position_onnx(query_layer, key_layer, bucket_size: int = -1, max_position: int = -1): |
| 37 | + """ |
| 38 | + Build relative position according to the query and key. |
| 39 | + """ |
| 40 | + query_size = query_layer.size(-2) |
| 41 | + key_size = key_layer.size(-2) |
| 42 | + |
| 43 | + q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device) |
| 44 | + k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device) |
| 45 | + rel_pos_ids = q_ids[:, None] - k_ids[None, :] |
| 46 | + |
| 47 | + if bucket_size > 0 and max_position > 0: |
| 48 | + rel_pos_ids = make_log_bucket_position_onnx(rel_pos_ids, bucket_size, max_position) |
| 49 | + |
| 50 | + rel_pos_ids = rel_pos_ids.to(torch.long) |
| 51 | + rel_pos_ids = rel_pos_ids[:query_size, :] |
| 52 | + rel_pos_ids = rel_pos_ids.unsqueeze(0) |
| 53 | + return rel_pos_ids |
| 54 | + |
| 55 | + |
| 56 | +def c2p_dynamic_expand_onnx(c2p_pos, query_layer, relative_pos): |
| 57 | + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) |
| 58 | + |
| 59 | + |
| 60 | +def p2c_dynamic_expand_onnx(c2p_pos, query_layer, key_layer): |
| 61 | + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) |
| 62 | + |
| 63 | + |
| 64 | +def pos_dynamic_expand_onnx(pos_index, p2c_att, key_layer): |
| 65 | + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) |
| 66 | + |
| 67 | + |
| 68 | +def scaled_size_sqrt_onnx(query_layer: torch.Tensor, scale_factor: int): |
| 69 | + return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) |
| 70 | + |
| 71 | + |
| 72 | +def build_rpos_onnx(query_layer, key_layer, relative_pos, position_buckets: int, max_relative_positions: int): |
| 73 | + """ |
| 74 | + ONNX-compatible version of build_rpos. |
| 75 | +
|
| 76 | + Removes @torch.jit.script and conditional logic that depends on tensor sizes. |
| 77 | + Instead, we always compute the relative position to avoid dynamic branching. |
| 78 | + """ |
| 79 | + # Original had: if key_layer.size(-2) != query_layer.size(-2): |
| 80 | + # This creates a dynamic condition in ONNX. Instead, we'll always use relative_pos |
| 81 | + # if it's provided, otherwise compute it. |
| 82 | + if relative_pos is None: |
| 83 | + return build_relative_position_onnx( |
| 84 | + key_layer, |
| 85 | + key_layer, |
| 86 | + bucket_size=position_buckets, |
| 87 | + max_position=max_relative_positions, |
| 88 | + ) |
| 89 | + else: |
| 90 | + return relative_pos |
| 91 | + |
| 92 | + |
| 93 | +class QEffDisentangledSelfAttention(DisentangledSelfAttention): |
| 94 | + """ |
| 95 | + ONNX-compatible version of DisentangledSelfAttention. |
| 96 | +
|
| 97 | + Overrides methods to use ONNX-compatible helper functions without @torch.jit.script. |
| 98 | + """ |
| 99 | + |
| 100 | + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): |
| 101 | + """ |
| 102 | + Override to use ONNX-compatible functions. |
| 103 | + """ |
| 104 | + if relative_pos is None: |
| 105 | + relative_pos = build_relative_position_onnx( |
| 106 | + query_layer, |
| 107 | + key_layer, |
| 108 | + bucket_size=self.position_buckets, |
| 109 | + max_position=self.max_relative_positions, |
| 110 | + ) |
| 111 | + if relative_pos.dim() == 2: |
| 112 | + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) |
| 113 | + elif relative_pos.dim() == 3: |
| 114 | + relative_pos = relative_pos.unsqueeze(1) |
| 115 | + elif relative_pos.dim() != 4: |
| 116 | + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") |
| 117 | + |
| 118 | + att_span = self.pos_ebd_size |
| 119 | + relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long) |
| 120 | + |
| 121 | + rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) |
| 122 | + if self.share_att_key: |
| 123 | + pos_query_layer = self.transpose_for_scores( |
| 124 | + self.query_proj(rel_embeddings), self.num_attention_heads |
| 125 | + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) |
| 126 | + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( |
| 127 | + query_layer.size(0) // self.num_attention_heads, 1, 1 |
| 128 | + ) |
| 129 | + else: |
| 130 | + if "c2p" in self.pos_att_type: |
| 131 | + pos_key_layer = self.transpose_for_scores( |
| 132 | + self.pos_key_proj(rel_embeddings), self.num_attention_heads |
| 133 | + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) |
| 134 | + if "p2c" in self.pos_att_type: |
| 135 | + pos_query_layer = self.transpose_for_scores( |
| 136 | + self.pos_query_proj(rel_embeddings), self.num_attention_heads |
| 137 | + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) |
| 138 | + |
| 139 | + score = 0 |
| 140 | + # content->position |
| 141 | + if "c2p" in self.pos_att_type: |
| 142 | + scale = scaled_size_sqrt_onnx(pos_key_layer, scale_factor) |
| 143 | + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) |
| 144 | + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) |
| 145 | + c2p_att = torch.gather( |
| 146 | + c2p_att, |
| 147 | + dim=-1, |
| 148 | + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), |
| 149 | + ) |
| 150 | + score += c2p_att / scale.to(dtype=c2p_att.dtype) |
| 151 | + |
| 152 | + # position->content |
| 153 | + if "p2c" in self.pos_att_type: |
| 154 | + scale = scaled_size_sqrt_onnx(pos_query_layer, scale_factor) |
| 155 | + r_pos = build_rpos_onnx( |
| 156 | + query_layer, |
| 157 | + key_layer, |
| 158 | + relative_pos, |
| 159 | + self.position_buckets, |
| 160 | + self.max_relative_positions, |
| 161 | + ) |
| 162 | + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) |
| 163 | + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) |
| 164 | + p2c_att = torch.gather( |
| 165 | + p2c_att, |
| 166 | + dim=-1, |
| 167 | + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), |
| 168 | + ).transpose(-1, -2) |
| 169 | + score += p2c_att / scale.to(dtype=p2c_att.dtype) |
| 170 | + |
| 171 | + return score |
| 172 | + |
| 173 | + def forward( |
| 174 | + self, |
| 175 | + hidden_states, |
| 176 | + attention_mask, |
| 177 | + output_attentions=False, |
| 178 | + query_states=None, |
| 179 | + relative_pos=None, |
| 180 | + rel_embeddings=None, |
| 181 | + ): |
| 182 | + """ |
| 183 | + Forward pass using ONNX-compatible attention bias computation. |
| 184 | + """ |
| 185 | + if query_states is None: |
| 186 | + query_states = hidden_states |
| 187 | + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) |
| 188 | + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) |
| 189 | + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) |
| 190 | + |
| 191 | + rel_att = None |
| 192 | + # Take the dot product between "query" and "key" to get the raw attention scores. |
| 193 | + scale_factor = 1 |
| 194 | + if "c2p" in self.pos_att_type: |
| 195 | + scale_factor += 1 |
| 196 | + if "p2c" in self.pos_att_type: |
| 197 | + scale_factor += 1 |
| 198 | + scale = scaled_size_sqrt_onnx(query_layer, scale_factor) |
| 199 | + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype)) |
| 200 | + if self.relative_attention: |
| 201 | + rel_embeddings = self.pos_dropout(rel_embeddings) |
| 202 | + rel_att = self.disentangled_attention_bias( |
| 203 | + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor |
| 204 | + ) |
| 205 | + |
| 206 | + if rel_att is not None: |
| 207 | + attention_scores = attention_scores + rel_att |
| 208 | + attention_scores = attention_scores |
| 209 | + attention_scores = attention_scores.view( |
| 210 | + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) |
| 211 | + ) |
| 212 | + |
| 213 | + attention_mask = attention_mask.bool() |
| 214 | + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) |
| 215 | + # bsz x height x length x dimension |
| 216 | + attention_probs = nn.functional.softmax(attention_scores, dim=-1) |
| 217 | + |
| 218 | + attention_probs = self.dropout(attention_probs) |
| 219 | + context_layer = torch.bmm( |
| 220 | + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer |
| 221 | + ) |
| 222 | + context_layer = ( |
| 223 | + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) |
| 224 | + .permute(0, 2, 1, 3) |
| 225 | + .contiguous() |
| 226 | + ) |
| 227 | + new_context_layer_shape = context_layer.size()[:-2] + (-1,) |
| 228 | + context_layer = context_layer.view(new_context_layer_shape) |
| 229 | + if not output_attentions: |
| 230 | + return (context_layer, None) |
| 231 | + return (context_layer, attention_probs) |
0 commit comments