
Commit e8dd142

enable ut

1 parent dfc80c7

File tree: 4 files changed, +23 -6 lines

pyproject.toml
src/torch_extension_sycl.cc
tests/test_flash_attention.py
tests/utils.py

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 [build-system]
 requires = [
     "scikit-build-core>=0.10",
-    "pytorch-triton-xpu @ https://download.pytorch.org/whl/test/pytorch_triton_xpu-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
     "wheel",
 ]
 build-backend = "scikit_build_core.build"

src/torch_extension_sycl.cc

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@ limitations under the License.
 #include <torch/all.h>
 #include <torch/library.h>
 
+#include "sgl_kernel_torch_shim.h"
+
+#include "sgl_flash_kernel_ops.h"
 #include "sgl_kernel_ops.h"
 
 TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
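Including sgl_flash_kernel_ops.h here makes the flash-attention declarations visible to the bindings inside the TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) block, so those ops end up in PyTorch's sgl_kernel op namespace. A minimal sketch of how such registered ops are reached from Python, assuming the SYCL extension has been built; the library path and op name are illustrative placeholders, not taken from this commit:

import torch

# Hypothetical path to the built extension; adjust to your build layout.
torch.ops.load_library("build/libtorch_extension_sycl.so")

# Ops registered under TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) become callable as
# torch.ops.sgl_kernel.<op_name>; the concrete op names are not shown in this diff.
# out = torch.ops.sgl_kernel.<op_name>(...)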

tests/test_flash_attention.py

Lines changed: 10 additions & 5 deletions
@@ -8,6 +8,10 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
+import utils
+
+device = utils.get_device()
+
 apply_rotary_emb = None
 
 
@@ -25,11 +29,14 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    if torch.cuda.is_available():
+        return (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
-
+        ) and (torch.version.cuda >= "12.3")
+    elif torch.xpu.is_available():
+        device_name = torch.xpu.get_device_properties(0).name
+        return "B580" in device_name or "e211" in device_name
 
 DISABLE_BACKWARD = True
 # For CI test, we close them to True.
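For readability, here is the helper as it reads after this hunk. This is a reconstruction from the diff above with indentation normalized (the diff view flattened it); note that, as written, the function falls through and returns None when neither CUDA nor XPU is available:

import torch

def is_fa3_supported(device=None) -> bool:
    # CUDA: fa3 is built for sm80/sm86/sm89/sm90a, so require compute
    # capability 8.x or 9.x and CUDA >= 12.3.
    if torch.cuda.is_available():
        return (
            torch.cuda.get_device_capability(device)[0] == 9
            or torch.cuda.get_device_capability(device)[0] == 8
        ) and (torch.version.cuda >= "12.3")
    # XPU: gate on known device names instead of compute capability.
    elif torch.xpu.is_available():
        device_name = torch.xpu.get_device_properties(0).name
        return "B580" in device_name or "e211" in device_name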
@@ -551,7 +558,6 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if rotary_fraction == 0.0 and has_rotary_seqlens:
         pytest.skip()
-    device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 5
@@ -1077,7 +1083,6 @@ def test_flash_attn_varlen_output(
 ):
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
-    device = "cuda"
     # set seed
     torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
     # batch_size = 40

tests/utils.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+import torch
+
+def get_device() :
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+    else:
+        device = torch.device("cpu")
+    return device
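A minimal usage sketch of the new helper, mirroring how tests/test_flash_attention.py now resolves its device once at import time instead of hard-coding device = "cuda". The tensor shape is illustrative only, and the import assumes the script runs from the tests/ directory:

import torch
import utils  # tests/utils.py

# Resolves to "cuda", "xpu", or "cpu" depending on what the machine exposes.
device = utils.get_device()

torch.random.manual_seed(0)
q = torch.randn(1, 8, 128, 64, device=device)  # illustrative shape
print(q.device)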
