Fix type issue introduced by #28 #39

Open
HaileyStorm wants to merge 1 commit into xjdr-alt:main from HaileyStorm:patch-1

HaileyStorm (Contributor) commented Oct 7, 2024

Commit xjdr-alt#28 added a `dtype` parameter to `apply_rotary_embed`, defaulting to float32, and forces the attention softmax to run in float32. Because `attention` doesn't pass `dtype` when calling `apply_rotary_embed`, and the output matmul doesn't cast the float32 scores back to the dtype of `values`, the dtypes no longer match when running in BF16.

This PR passes the existing `xq.dtype` as the `dtype` argument when calling `apply_rotary_embed` (alternatively, we could cast `keys` to float32 in `scores = torch.matmul(xq, keys)`), and casts `scores` to match `values` at the output matmul.
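
For anyone who wants to reproduce the mismatch outside the repo, here is a minimal sketch of the two code paths. It is not the project's actual code: `rotary_stub` and `attention_sketch` are hypothetical stand-ins, the rotation itself is omitted, and it assumes keys and values are held in the model dtype (BF16 here). Only the dtype handling described above is modeled.

```python
import torch
import torch.nn.functional as F


def rotary_stub(xq, xk, dtype=torch.float32):
    """Hypothetical stand-in for apply_rotary_embed.

    The rotation is omitted; only the upcast to float32 and the final
    cast to `dtype` (default float32, as described for #28) are modeled.
    """
    return xq.float().to(dtype), xk.float().to(dtype)


def attention_sketch(xq, xk, xv, patched):
    """Toy attention over (batch, heads, seq, head_dim) tensors."""
    model_dtype = xq.dtype
    if patched:
        # Fix 1: pass the existing xq.dtype through to the rotary call.
        q, k = rotary_stub(xq, xk, dtype=xq.dtype)
    else:
        # Pre-fix behavior: dtype falls back to the float32 default.
        q, k = rotary_stub(xq, xk)

    # Assumption: keys and values are kept in the model dtype.
    keys = k.to(model_dtype).transpose(-2, -1)
    values = xv

    scores = torch.matmul(q, keys) / (q.shape[-1] ** 0.5)
    scores = F.softmax(scores.float(), dim=-1)  # softmax forced to float32
    if patched:
        # Fix 2: cast scores back so they match values at the output matmul.
        scores = scores.to(values.dtype)
    return torch.matmul(scores, values)


torch.manual_seed(0)
xq, xk, xv = (torch.randn(1, 2, 4, 8, dtype=torch.bfloat16) for _ in range(3))

out = attention_sketch(xq, xk, xv, patched=True)
print("patched output dtype:", out.dtype)

try:
    attention_sketch(xq, xk, xv, patched=False)
    unpatched_failed = False
except RuntimeError as err:
    # torch.matmul does not promote mixed float32/bfloat16 operands.
    unpatched_failed = True
    print("unpatched path raised:", type(err).__name__)
```

With BF16 inputs the unpatched path should fail at the first matmul, since the float32 query meets BF16 keys; the patched path keeps both matmuls in BF16 while still running the softmax in float32.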
xjdr-alt (Owner) commented Oct 8, 2024

@Arrabonae could you take a look

citizenhicks (Contributor) left a comment

tested this, works well. thanks for spotting this issue!
