
Commit 0677c69

aviator19941 authored and xintin committed
Move scaled_dot_product_attention_bhsd under iree.turbine (iree-org#870)
This PR moves `scaled_dot_product_attention_bhsd` under iree.turbine so that all of the reference kernels live in a single util file.

---------

Signed-off-by: aviator19941 <[email protected]>
Signed-off-by: xintin <[email protected]>
1 parent 43551dd commit 0677c69
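With this change, tests import the reference kernel from the library module rather than from the shared test utilities. The before/after below summarizes the import change; the exact relative form of the old import varies by test file (see the diffs that follow):

# Before: pulled in through the shared test utils
# from ..common.utils import scaled_dot_product_attention_bhsd

# After: imported from the new library module
from iree.turbine.kernel.wave.utils.reference_kernel_utils import (
    scaled_dot_product_attention_bhsd,
)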

File tree

6 files changed (+77, -65 lines)

iree/turbine/kernel/wave/utils/reference_kernel_utils.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# Copyright 2025 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def scaled_dot_product_attention_bhsd(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    is_causal: bool = False,
+    sliding_window: int = -1,
+    custom_mask: Tensor | None = None,
+) -> Tensor:
+    """
+    This version mimics PyTorch's `torch.nn.functional.scaled_dot_product_attention`
+    with optional causal masking and improved numerical stability.
+    Intended for comparison and debugging purposes.
+    Args:
+        query (Tensor): query tensor of shape [B, H, S_q, D].
+        key (Tensor): key tensor of shape [B, H, S_k, D].
+        value (Tensor): value tensor of shape [B, H, S_k, D].
+        is_causal (bool): If True, applies causal masking to the attention logits.
+    Returns:
+        Tensor: Output tensor of shape [B, H, S_q, D] after applying attention.
+    """
+    if query.dtype != torch.float32:
+        query = query.to(torch.float32)
+    if key.dtype != torch.float32:
+        key = key.to(torch.float32)
+    if value.dtype != torch.float32:
+        value = value.to(torch.float32)
+
+    scale: float = query.shape[-1] ** -0.5
+    attn_logits: Tensor = torch.matmul(query, key.transpose(-2, -1)) * scale
+
+    if sliding_window >= 0:
+        assert is_causal, f"Sliding window only supported with causal"
+
+    if is_causal:
+        seq_len_q, seq_len_k = attn_logits.shape[-2], attn_logits.shape[-1]
+        causal_mask: Tensor = torch.tril(
+            torch.ones(
+                (seq_len_q, seq_len_k), device=attn_logits.device, dtype=torch.bool
+            )
+        )
+        if sliding_window >= 0:
+            causal_mask = causal_mask.triu(-sliding_window)
+        attn_logits = attn_logits.masked_fill(~causal_mask, float("-inf"))
+
+    if custom_mask is not None:
+        bool_mask = custom_mask.to(torch.bool)
+        bool_mask = bool_mask[:, None, :, None]
+        assert bool_mask.shape == (query.shape[0], 1, query.shape[2], 1)
+        attn_logits = attn_logits.masked_fill(bool_mask, float("-inf"))
+
+    # Improve numerical stability using log-sum-exp trick
+    attn_logits = attn_logits - attn_logits.max(dim=-1, keepdim=True).values
+    attn_weights: Tensor = F.softmax(attn_logits, dim=-1)
+    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
+
+    return torch.matmul(attn_weights, value)
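For comparison and debugging, the relocated helper can be checked directly against PyTorch's built-in SDPA. The snippet below is a minimal sketch, not part of this commit; the shapes and tolerances are illustrative assumptions:

import torch
import torch.nn.functional as F

from iree.turbine.kernel.wave.utils.reference_kernel_utils import (
    scaled_dot_product_attention_bhsd,
)

# Arbitrary [B, H, S, D] shapes, chosen only for illustration.
q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# Both paths compute causal attention in float32, so they should agree closely.
ref = scaled_dot_product_attention_bhsd(q, k, v, is_causal=True)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(ref, out, rtol=1e-4, atol=1e-4)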

tests/kernel/wave/attention/gqa_vanilla_attention_test.py

Lines changed: 3 additions & 1 deletion
@@ -26,7 +26,6 @@
     enable_scheduling_barriers,
     require_e2e,
     require_cdna3,
-    scaled_dot_product_attention_bhsd,
 )
 from ..common.shapes import get_test_shapes
 from iree.turbine.kernel.wave.templates.gqa_vanilla_attention import (
@@ -35,6 +34,9 @@
 from iree.turbine.kernel.wave.templates.attention_common import AttentionShape
 from iree.turbine.kernel.wave.scheduling.schedule import SchedulingType
 from iree.turbine.kernel.wave.compile import wave_compile, WaveCompileOptions
+from iree.turbine.kernel.wave.utils.reference_kernel_utils import (
+    scaled_dot_product_attention_bhsd,
+)
 
 
 @require_e2e

tests/kernel/wave/attention/vanilla_attention_test.py

Lines changed: 3 additions & 1 deletion
@@ -34,7 +34,6 @@
     param_bool,
     require_cdna3,
     require_e2e,
-    scaled_dot_product_attention_bhsd,
 )
 from ..common.shapes import get_test_shapes
 from iree.turbine.kernel.wave.templates.vanilla_attention import (
@@ -45,6 +44,9 @@
 from iree.turbine.kernel.wave.templates.attention_common import AttentionShape
 from iree.turbine.kernel.wave.scheduling.schedule import SchedulingType
 from iree.turbine.kernel.wave.compile import wave_compile, WaveCompileOptions
+from iree.turbine.kernel.wave.utils.reference_kernel_utils import (
+    scaled_dot_product_attention_bhsd,
+)
 
 
 @require_e2e

tests/kernel/wave/common/utils.py

Lines changed: 0 additions & 63 deletions
@@ -9,9 +9,6 @@
 from iree.turbine.kernel.wave.utils.run_utils import (
     get_default_arch,
 )
-import torch
-import torch.nn.functional as F
-from torch import Tensor
 
 require_e2e = pytest.mark.require_e2e
 expensive_test = pytest.mark.expensive_test
@@ -42,63 +39,3 @@ def param_bool(name, shortname=None, values=None):
     values = values or [False, True]
     ids = [f"{shortname}" if v else f"no_{shortname}" for v in values]
     return pytest.mark.parametrize(name, [pytest.param(v) for v in values], ids=ids)
-
-
-def scaled_dot_product_attention_bhsd(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    is_causal: bool = False,
-    sliding_window: int = -1,
-    custom_mask: Tensor | None = None,
-) -> Tensor:
-    """
-    This version mimics PyTorch's `torch.nn.functional.scaled_dot_product_attention`
-    with optional causal masking and improved numerical stability.
-    Intended for comparison and debugging purposes.
-
-    Args:
-        query (Tensor): query tensor of shape [B, H, S_q, D].
-        key (Tensor): key tensor of shape [B, H, S_k, D].
-        value (Tensor): value tensor of shape [B, H, S_k, D].
-        is_causal (bool): If True, applies causal masking to the attention logits.
-
-    Returns:
-        Tensor: Output tensor of shape [B, H, S_q, D] after applying attention.
-    """
-    if query.dtype != torch.float32:
-        query = query.to(torch.float32)
-    if key.dtype != torch.float32:
-        key = key.to(torch.float32)
-    if value.dtype != torch.float32:
-        value = value.to(torch.float32)
-
-    scale: float = query.shape[-1] ** -0.5
-    attn_logits: Tensor = torch.matmul(query, key.transpose(-2, -1)) * scale
-
-    if sliding_window >= 0:
-        assert is_causal, f"Sliding window only supported with causal"
-
-    if is_causal:
-        seq_len_q, seq_len_k = attn_logits.shape[-2], attn_logits.shape[-1]
-        causal_mask: Tensor = torch.tril(
-            torch.ones(
-                (seq_len_q, seq_len_k), device=attn_logits.device, dtype=torch.bool
-            )
-        )
-        if sliding_window >= 0:
-            causal_mask = causal_mask.triu(-sliding_window)
-        attn_logits = attn_logits.masked_fill(~causal_mask, float("-inf"))
-
-    if custom_mask is not None:
-        bool_mask = custom_mask.to(torch.bool)
-        bool_mask = bool_mask[:, None, :, None]
-        assert bool_mask.shape == (query.shape[0], 1, query.shape[2], 1)
-        attn_logits = attn_logits.masked_fill(bool_mask, float("-inf"))
-
-    # Improve numerical stability using log-sum-exp trick
-    attn_logits = attn_logits - attn_logits.max(dim=-1, keepdim=True).values
-    attn_weights: Tensor = F.softmax(attn_logits, dim=-1)
-    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
-
-    return torch.matmul(attn_weights, value)

tests/kernel/wave/nn/functional/wave_quant_attention_test.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 from ...common.utils import (
     require_e2e,
     require_cdna3,
+)
+from iree.turbine.kernel.wave.utils.reference_kernel_utils import (
     scaled_dot_product_attention_bhsd,
 )
 
tests/kernel/wave/nn/functional/wave_sdpa_test.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 from ...common.utils import (
     require_e2e,
     require_cdna3,
+)
+from iree.turbine.kernel.wave.utils.reference_kernel_utils import (
     scaled_dot_product_attention_bhsd,
 )
 
0 commit comments