Initialize Cutlass-SYCL support #6
Changes from 24 commits
```diff
@@ -3,11 +3,6 @@
 import torch
 import torch.nn as nn
 
-try:
-    from sgl_kernel import flash_ops
-except:
-    raise ImportError("Can not import sgl_kernel. Please check your installation.")
-
 
 def is_fa3_supported(device=None) -> bool:
     # There some fa3 FYI
```
```diff
@@ -18,10 +13,16 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
-        torch.cuda.get_device_capability(device)[0] == 9
-        or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    if torch.cuda.is_available():
+        return (
+            torch.cuda.get_device_capability(device)[0] == 9
+            or torch.cuda.get_device_capability(device)[0] == 8
+        ) and (torch.version.cuda >= "12.3")
+    elif torch.xpu.is_available():
+        device_name = torch.xpu.get_device_properties(0).name
+        return "B580" in device_name or "e211" in device_name
+    else:
+        return False
 
 
 def maybe_contiguous(x):
```
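For a quick sanity check of the new device gate, the helper can be queried directly. This is only an illustrative sketch, not part of the change; the import path is an assumption.

```python
import torch

from sgl_kernel.flash_attn import is_fa3_supported  # import path assumed

# True on SM80/SM90-class CUDA GPUs with CUDA >= 12.3, True on the Intel XPUs
# whose device name contains "B580" or "e211", and False everywhere else.
print(is_fa3_supported())

if hasattr(torch, "xpu") and torch.xpu.is_available():
    # The device-name string that the XPU branch above matches against.
    print(torch.xpu.get_device_properties(0).name)
```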
```
@@ -171,21 +172,47 @@ def flash_attn_with_kvcache(
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    rotary_seqlens = maybe_contiguous(rotary_seqlens)

    if cu_seqlens_q == None:  # !is_varlen_q
        cu_seqlens_q = torch.arange(
            0, q.size(0) + 1, dtype=torch.int, device=q.device
        ) * q.size(1)
        max_seqlen_q = q.size(1)
        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
    # if cu_seqlens_k_new is None and k is not None:  # !is_varlen_k_new
    #     cu_seqlens_k_new = torch.arange(
    #         0, k.size(0) + 1, dtype=torch.int, device=k.device
    #     )
    # elif k is None:
    #     cu_seqlens_k_new = torch.zeros_like(
    #         cu_seqlens_q, dtype=torch.int32, device=q.device
    #     )
    if cache_seqlens is not None:
        max_seqlen_k = cache_seqlens.max().item()
        assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
        # max_page_size_per_seq = page_table.size(1)
        # # will delete later
        # num_pages_per_seq = torch.arange(
        #     0,
        #     cache_seqlens.size(0) * max_page_size_per_seq,
        #     max_page_size_per_seq,
        #     device=cache_seqlens.device,
        # ).to(torch.int32)
        cu_seqlens_k = torch.concat(
            (
                torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
                torch.cumsum(cache_seqlens, 0),
            )
        ).to(torch.int32)

    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        cu_seqlens_k,
        max_seqlen_q,
        None,  # max_seqlen_k
        max_seqlen_k,
        page_table,
        cache_batch_idx,
        cache_leftpad,
```
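To make the intent of the new preamble in `flash_attn_with_kvcache` easier to follow, here is a small self-contained sketch (with made-up shapes, not the library API) of how a fixed-length query batch and a per-sequence `cache_seqlens` are converted into the cumulative-offset tensors passed to the kernel:

```python
import torch

batch, seqlen_q, heads, dim = 2, 4, 8, 64
q = torch.randn(batch, seqlen_q, heads, dim)
cache_seqlens = torch.tensor([7, 3], dtype=torch.int32)  # valid KV length per sequence

# Fixed-length queries become "varlen" form: cumulative offsets plus a flat view.
cu_seqlens_q = torch.arange(0, batch + 1, dtype=torch.int) * seqlen_q  # tensor([0, 4, 8])
q_flat = q.view(-1, q.size(-2), q.size(-1)).contiguous()               # [batch * seqlen_q, heads, dim]
max_seqlen_q = seqlen_q

# KV offsets are the running sum of the per-sequence cache lengths.
cu_seqlens_k = torch.concat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(cache_seqlens, 0))
).to(torch.int32)                          # tensor([0, 7, 10])
max_seqlen_k = cache_seqlens.max().item()  # 7
```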
```diff
@@ -235,13 +262,26 @@ def flash_attn_varlen_func(
 ):
     if not is_fa3_supported():
         raise NotImplementedError(
-            "flash_attn at sgl-kernel is only supported on sm90 and above"
+            "flash_attn at sgl-kernel-xpu is only supported on BMG and later"
         )
 
     if softmax_scale is None:
         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
             -0.5
         )
+    if cu_seqlens_q == None:  # !is_varlen_q
+        cu_seqlens_q = torch.arange(
+            0, q.size(0) + 1, dtype=torch.int, device=q.device
+        ) * q.size(1)
+        max_seqlen_q = q.size(1)
+        q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
+    batch_size = cu_seqlens_q.numel() - 1
+    page_table = (
+        torch.arange(0, batch_size, device=q.device)
+        .to(torch.int32)
+        .reshape([batch_size, 1])
+        .contiguous()
+    )
```
Comment on lines +255 to +267

Contributor: What extra functionality are we trying to provide?

Collaborator (Author): The current kernel implementation is aligned between vLLM and SGLang requests, so there will be some changes on the SGLang side.
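Since this thread is about the new page-table preamble, here is a small illustration (with an assumed batch size, not the kernel call itself) of what that trivial table contains: one entry per sequence, whose page index is simply the sequence's batch position.

```python
import torch

batch_size = 3  # assumed example value

# One "page" per sequence: sequence i maps to page i, giving a
# [batch_size, 1] int32 table with the same shape as in the diff above.
page_table = (
    torch.arange(0, batch_size)
    .to(torch.int32)
    .reshape([batch_size, 1])
    .contiguous()
)
print(page_table)  # tensor([[0], [1], [2]], dtype=torch.int32)
```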
```
    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
        q,
@@ -250,15 +290,13 @@ def flash_attn_varlen_func(
        None,  # k_new
        None,  # v_new
        qv,  # qv
        None,  # out
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # cu_seqlens_k_new
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        None,  # page_table,
        page_table,  # page_table,
        page_table,  # num_pages_per_seq
        None,  # kv_batch_idx
        None,  # leftpad_k
        None,  # rotary cos
```
Reviewer comment: Why are we changing the function signature?