Skip to content

Commit 544327a

Browse files
authored
Adding support of QEFFAutoModelForSequenceClassification (quic#729)
Added support of model [Llama-Prompt-Guard-2-22M](https://huggingface.co/meta-llama/Llama-Prompt-Guard-2-22M). PyTorch vs AIC MAD -> 0.0031892061233520508 --------- Signed-off-by: Amit Raj <amitraj@qti.qualcomm.com>
1 parent fc42332 commit 544327a

File tree

11 files changed

+791
-2
lines changed

11 files changed

+791
-2
lines changed

QEfficient/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
QEFFAutoModelForCausalLM,
2525
QEFFAutoModelForCTC,
2626
QEFFAutoModelForImageTextToText,
27+
QEFFAutoModelForSequenceClassification,
2728
QEFFAutoModelForSpeechSeq2Seq,
2829
QEFFCommonLoader,
2930
)
@@ -53,6 +54,7 @@
5354
"QEFFAutoModelForCTC",
5455
"QEffAutoPeftModelForCausalLM",
5556
"QEFFAutoModelForImageTextToText",
57+
"QEFFAutoModelForSequenceClassification",
5658
"QEFFAutoModelForSpeechSeq2Seq",
5759
"QEFFCommonLoader",
5860
"QEffFluxPipeline",

QEfficient/base/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
QEFFAutoModelForCausalLM,
1212
QEFFAutoModelForCTC,
1313
QEFFAutoModelForImageTextToText,
14+
QEFFAutoModelForSequenceClassification,
1415
QEFFAutoModelForSpeechSeq2Seq,
1516
)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
7+
8+
import torch
9+
from torch import nn
10+
from transformers.models.deberta_v2.modeling_deberta_v2 import (
11+
DisentangledSelfAttention,
12+
)
13+
14+
15+
def make_log_bucket_position_onnx(relative_pos, bucket_size: int, max_position: int):
16+
sign = torch.sign(relative_pos)
17+
mid = bucket_size // 2
18+
abs_pos = torch.abs(relative_pos)
19+
20+
# Instead of torch.where with complex conditions, use mask-based approach
21+
# Original: torch.where((relative_pos < mid) & (relative_pos > -mid), mid-1, abs_pos)
22+
is_in_mid_range = abs_pos < mid
23+
abs_pos_clamped = torch.where(is_in_mid_range, torch.tensor(mid - 1).type_as(relative_pos), abs_pos)
24+
25+
# Compute log position
26+
log_pos = (
27+
torch.ceil(torch.log(abs_pos_clamped / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1))
28+
+ mid
29+
)
30+
31+
# Select between relative_pos and log_pos based on whether abs_pos <= mid
32+
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
33+
return bucket_pos
34+
35+
36+
def build_relative_position_onnx(query_layer, key_layer, bucket_size: int = -1, max_position: int = -1):
37+
"""
38+
Build relative position according to the query and key.
39+
"""
40+
query_size = query_layer.size(-2)
41+
key_size = key_layer.size(-2)
42+
43+
q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
44+
k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
45+
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
46+
47+
if bucket_size > 0 and max_position > 0:
48+
rel_pos_ids = make_log_bucket_position_onnx(rel_pos_ids, bucket_size, max_position)
49+
50+
rel_pos_ids = rel_pos_ids.to(torch.long)
51+
rel_pos_ids = rel_pos_ids[:query_size, :]
52+
rel_pos_ids = rel_pos_ids.unsqueeze(0)
53+
return rel_pos_ids
54+
55+
56+
def c2p_dynamic_expand_onnx(c2p_pos, query_layer, relative_pos):
57+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
58+
59+
60+
def p2c_dynamic_expand_onnx(c2p_pos, query_layer, key_layer):
61+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
62+
63+
64+
def pos_dynamic_expand_onnx(pos_index, p2c_att, key_layer):
65+
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
66+
67+
68+
def scaled_size_sqrt_onnx(query_layer: torch.Tensor, scale_factor: int):
69+
return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
70+
71+
72+
def build_rpos_onnx(query_layer, key_layer, relative_pos, position_buckets: int, max_relative_positions: int):
73+
"""
74+
ONNX-compatible version of build_rpos.
75+
76+
Removes @torch.jit.script and conditional logic that depends on tensor sizes.
77+
Instead, we always compute the relative position to avoid dynamic branching.
78+
"""
79+
# Original had: if key_layer.size(-2) != query_layer.size(-2):
80+
# This creates a dynamic condition in ONNX. Instead, we'll always use relative_pos
81+
# if it's provided, otherwise compute it.
82+
if relative_pos is None:
83+
return build_relative_position_onnx(
84+
key_layer,
85+
key_layer,
86+
bucket_size=position_buckets,
87+
max_position=max_relative_positions,
88+
)
89+
else:
90+
return relative_pos
91+
92+
93+
class QEffDisentangledSelfAttention(DisentangledSelfAttention):
94+
"""
95+
ONNX-compatible version of DisentangledSelfAttention.
96+
97+
Overrides methods to use ONNX-compatible helper functions without @torch.jit.script.
98+
"""
99+
100+
def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
101+
"""
102+
Override to use ONNX-compatible functions.
103+
"""
104+
if relative_pos is None:
105+
relative_pos = build_relative_position_onnx(
106+
query_layer,
107+
key_layer,
108+
bucket_size=self.position_buckets,
109+
max_position=self.max_relative_positions,
110+
)
111+
if relative_pos.dim() == 2:
112+
relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
113+
elif relative_pos.dim() == 3:
114+
relative_pos = relative_pos.unsqueeze(1)
115+
elif relative_pos.dim() != 4:
116+
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
117+
118+
att_span = self.pos_ebd_size
119+
relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
120+
121+
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
122+
if self.share_att_key:
123+
pos_query_layer = self.transpose_for_scores(
124+
self.query_proj(rel_embeddings), self.num_attention_heads
125+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
126+
pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
127+
query_layer.size(0) // self.num_attention_heads, 1, 1
128+
)
129+
else:
130+
if "c2p" in self.pos_att_type:
131+
pos_key_layer = self.transpose_for_scores(
132+
self.pos_key_proj(rel_embeddings), self.num_attention_heads
133+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
134+
if "p2c" in self.pos_att_type:
135+
pos_query_layer = self.transpose_for_scores(
136+
self.pos_query_proj(rel_embeddings), self.num_attention_heads
137+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
138+
139+
score = 0
140+
# content->position
141+
if "c2p" in self.pos_att_type:
142+
scale = scaled_size_sqrt_onnx(pos_key_layer, scale_factor)
143+
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
144+
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
145+
c2p_att = torch.gather(
146+
c2p_att,
147+
dim=-1,
148+
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
149+
)
150+
score += c2p_att / scale.to(dtype=c2p_att.dtype)
151+
152+
# position->content
153+
if "p2c" in self.pos_att_type:
154+
scale = scaled_size_sqrt_onnx(pos_query_layer, scale_factor)
155+
r_pos = build_rpos_onnx(
156+
query_layer,
157+
key_layer,
158+
relative_pos,
159+
self.position_buckets,
160+
self.max_relative_positions,
161+
)
162+
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
163+
p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
164+
p2c_att = torch.gather(
165+
p2c_att,
166+
dim=-1,
167+
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
168+
).transpose(-1, -2)
169+
score += p2c_att / scale.to(dtype=p2c_att.dtype)
170+
171+
return score
172+
173+
def forward(
174+
self,
175+
hidden_states,
176+
attention_mask,
177+
output_attentions=False,
178+
query_states=None,
179+
relative_pos=None,
180+
rel_embeddings=None,
181+
):
182+
"""
183+
Forward pass using ONNX-compatible attention bias computation.
184+
"""
185+
if query_states is None:
186+
query_states = hidden_states
187+
query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
188+
key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
189+
value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
190+
191+
rel_att = None
192+
# Take the dot product between "query" and "key" to get the raw attention scores.
193+
scale_factor = 1
194+
if "c2p" in self.pos_att_type:
195+
scale_factor += 1
196+
if "p2c" in self.pos_att_type:
197+
scale_factor += 1
198+
scale = scaled_size_sqrt_onnx(query_layer, scale_factor)
199+
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
200+
if self.relative_attention:
201+
rel_embeddings = self.pos_dropout(rel_embeddings)
202+
rel_att = self.disentangled_attention_bias(
203+
query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
204+
)
205+
206+
if rel_att is not None:
207+
attention_scores = attention_scores + rel_att
208+
attention_scores = attention_scores
209+
attention_scores = attention_scores.view(
210+
-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
211+
)
212+
213+
attention_mask = attention_mask.bool()
214+
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
215+
# bsz x height x length x dimension
216+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
217+
218+
attention_probs = self.dropout(attention_probs)
219+
context_layer = torch.bmm(
220+
attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
221+
)
222+
context_layer = (
223+
context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
224+
.permute(0, 2, 1, 3)
225+
.contiguous()
226+
)
227+
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
228+
context_layer = context_layer.view(new_context_layer_shape)
229+
if not output_attentions:
230+
return (context_layer, None)
231+
return (context_layer, attention_probs)

0 commit comments

Comments
 (0)