Skip to content

Commit f7c88be

Browse files
committed
Updated SoftMax calculation precision for all modeling files.
Enabled CI tests for fp16-based LMs, embedding, and sequence classification models. Modified the CI-based config for LLM tests. Embedding models have a high MAD for fp16 exported models (~0.015). Certain CausalLMs cause a token mismatch after a few tokens in the fp16 setup. The Whisper model has a Clip operator issue for fp16 exported models, so it is not enabled yet. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
1 parent df78631 commit f7c88be

30 files changed

+665
-369
lines changed

QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def forward(
156156
attention_scores = torch.where(
157157
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=self.config.torch_dtype), attention_scores
158158
)
159-
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
159+
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=torch.float32).to(
160+
query_layer.dtype
161+
)
160162
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
161163
attn_output = attention_scores @ value_layer
162164

QEfficient/transformers/models/gemma/modeling_gemma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def eager_attention_forward(
114114
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
115115
)
116116

117-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
117+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
118118
attn_output = torch.matmul(attn_weights, value_states)
119119
attn_output = attn_output.transpose(1, 2).contiguous()
120120

QEfficient/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def eager_attention_forward(
121121
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
122122
)
123123

124-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
124+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
125125
attn_output = torch.matmul(attn_weights, value_states)
126126
attn_output = attn_output.transpose(1, 2).contiguous()
127127

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def eager_attention_forward(
166166
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
167167
)
168168

169-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
169+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
170170
attn_output = torch.matmul(attn_weights, value_states)
171171
attn_output = attn_output.transpose(1, 2).contiguous()
172172

@@ -277,7 +277,7 @@ def forward(
277277
)
278278

279279
# upcast attention to fp32
280-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=self.config.torch_dtype).to(query_states.dtype)
280+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
281281
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
282282
attn_output = torch.matmul(attn_weights, value_states)
283283

QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def eager_attention_forward(
8686
attn_weights = torch.where(
8787
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
8888
)
89-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
89+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
9090
attn_output = torch.matmul(attn_weights, value_states)
9191
attn_output = attn_output.transpose(1, 2).contiguous()
9292

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def eager_attention_forward_blocked(
661661
)
662662
combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
663663
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
664-
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=module.config.torch_dtype)
664+
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
665665
curr_attn_weights = curr_attn_weights[..., :-1]
666666
out_block = torch.matmul(curr_attn_weights, value_states)
667667
outs.append(out_block)
@@ -724,7 +724,7 @@ def opt_eager_attention_forward_blocked(
724724
)
725725
combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
726726
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
727-
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=module.config.torch_dtype)
727+
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
728728
curr_attn_weights = curr_attn_weights[..., :-1]
729729
out_block = torch.matmul(curr_attn_weights, v_block)
730730
outs.append(out_block)

QEfficient/transformers/models/granite/modeling_granite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def eager_attention_forward(
113113
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
114114
)
115115

116-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
116+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
117117
attn_output = torch.matmul(attn_weights, value_states)
118118
attn_output = attn_output.transpose(1, 2).contiguous()
119119

QEfficient/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def eager_attention_forward(
190190
)
191191

192192
# upcast attention to fp32
193-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
193+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
194194
attn_output = torch.matmul(attn_weights, value_states)
195195
attn_output = attn_output.transpose(1, 2).contiguous()
196196
return attn_output, attn_weights

QEfficient/transformers/models/grok_1/modeling_grok1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def forward(self, hidden_states: torch.Tensor):
149149
hidden_states = hidden_states.view(-1, hidden_dim)
150150
router_logits = self.gate(hidden_states)
151151

152-
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
152+
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
153153
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
154154
# Creating experts mask and routing weights masked
155155
awesome_experts_mask_1 = (

QEfficient/transformers/models/llama/modeling_llama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def eager_attention_forward(
114114
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
115115
)
116116

117-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
117+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
118118
attn_output = torch.matmul(attn_weights, value_states)
119119
attn_output = attn_output.transpose(1, 2).contiguous()
120120

0 commit comments

Comments (0)