diff --git a/entropix/torch_model.py b/entropix/torch_model.py
index 0ebb3e9..0d06ac6 100644
--- a/entropix/torch_model.py
+++ b/entropix/torch_model.py
@@ -47,14 +47,17 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po
   xq = torch.permute(xq, (0, 2, 1, 3)) # (bs, n_heads, seqlen, head_dim)
   keys = torch.permute(keys, (0, 2, 3, 1)) # (bs, n_heads, head_dim, cache_len + seqlen)
   values = torch.permute(values, (0, 2, 1, 3)) # (bs, n_heads, cache_len + seqlen, head_dim)
+  xq = xq.to(torch.bfloat16)
+  keys = keys.to(torch.bfloat16)
+
   scores = torch.matmul(xq, keys)
   pre_scores = scores / math.sqrt(model_params.head_dim)
-  scores = pre_scores.to(torch.float32) # Always do attention softmax at float32
+  scores = pre_scores.to(torch.bfloat16) # Attention softmax now done at bfloat16
   if cur_pos == 0:
     scores = scores + attn_mask
   mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
   padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE)
-  scores = F.softmax(padded_logits, dim=-1).to(torch.float32)
+  scores = F.softmax(padded_logits, dim=-1).to(torch.bfloat16)
   output = torch.matmul(scores, values)
   output = output.transpose(1, 2).reshape(xq.shape[0], xq.shape[2], -1)
   out = F.linear(output, layer_weights.wo)
@@ -77,4 +80,4 @@ def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: torch.Ten
     h = h + h_attn
     h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i])
   logits = F.linear(rms_norm(h, xfmr_weights.norm), xfmr_weights.output)
-  return logits, kvcache, scores, attn_stats
\ No newline at end of file
+  return logits, kvcache, scores, attn_stats
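
Reviewer note (not part of the patch): this change drops the float32 upcast around the attention softmax and runs it in bfloat16 instead. Below is a minimal standalone sketch of what that trades away numerically; the logit values and shapes are hypothetical, and only PyTorch is assumed.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical pre-softmax attention logits, already scaled by
# 1/sqrt(head_dim): shape (bs=1, n_heads=2, seqlen=8, cache_len + seqlen=8).
logits = torch.randn(1, 2, 8, 8) * (8.0 / math.sqrt(128))

# Before this patch: softmax computed in float32.
probs_fp32 = F.softmax(logits.to(torch.float32), dim=-1)

# After this patch: softmax computed in bfloat16. bfloat16 keeps float32's
# exponent range (so exp() overflow behavior matches), but has only 8
# significand bits, i.e. roughly 2-3 significant decimal digits.
probs_bf16 = F.softmax(logits.to(torch.bfloat16), dim=-1)

# Rows still sum to ~1, but individual probabilities can drift on the
# order of bfloat16's epsilon (2**-7 ≈ 7.8e-3).
print((probs_fp32 - probs_bf16.float()).abs().max())
print(probs_bf16.float().sum(dim=-1).sub(1.0).abs().max())
```

Presumably the motivation is memory and bandwidth savings on bf16-native hardware: since bfloat16 shares float32's dynamic range, the softmax does not overflow any sooner, and the cost is confined to the lost mantissa precision shown above.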