Commit b8a6074

small fix

1 parent 2863442 commit b8a6074

3 files changed: 58 additions, 63 deletions

cmake/BuildFlags.cmake

Lines changed: 0 additions & 2 deletions
@@ -129,8 +129,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
 
   set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS})
 
-  # set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} -fno-sycl-instrument-device-code)
-  # set(SYCL_OFFLINE_COMPILER_CG_OPTIONS ${SYCL_OFFLINE_COMPILER_CG_OPTIONS} ${SYCL_LINK_FLAGS})
   set(SYCL_OFFLINE_COMPILER_FLAGS "${SYCL_OFFLINE_COMPILER_AOT_OPTIONS}${SYCL_OFFLINE_COMPILER_CG_OPTIONS}")
 else()
   message("Not compiling with XPU. Currently only support GCC compiler on Linux as CXX compiler.")

python/sgl_kernel/flash_attn.py

Lines changed: 0 additions & 2 deletions
@@ -198,8 +198,6 @@ def flash_attn_with_kvcache(
             )
         ).to(torch.int32)
 
-    import pdb; pdb.set_trace()
-
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
         k_cache,
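
The deleted line was an unconditional pdb.set_trace() on the library call path, which would stop every caller at this point. As an aside (not part of this commit), Python's built-in breakpoint() from PEP 553 is a more forgiving hook for temporary debugging because it can be silenced with PYTHONBREAKPOINT=0; the environment flag in the sketch below is hypothetical, not an sgl-kernel setting.

import os

def _maybe_break() -> None:
    # SGL_DEBUG_FLASH_ATTN is an illustrative, made-up opt-in flag.
    if os.environ.get("SGL_DEBUG_FLASH_ATTN") == "1":
        # breakpoint() honors PYTHONBREAKPOINT (PEP 553): setting it to 0
        # turns the call into a no-op, so a stray hook cannot stall runs.
        breakpoint()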

tests/test_flash_attention.py

Lines changed: 58 additions & 59 deletions
@@ -921,65 +921,64 @@ def test_flash_attn_kvcache(
 
     # # Check that FlashAttention's numerical error is at most twice the numerical error
    # # of a Pytorch implementation.
-    # if new_kv:
-        # if page_size is None:
-            # k_cache_select = (
-                # k_cache.to(dtype_ref)
-                # if not has_batch_idx
-                # else k_cache.to(dtype_ref)[cache_batch_idx]
-            # )
-            # v_cache_select = (
-                # v_cache.to(dtype_ref)
-                # if not has_batch_idx
-                # else v_cache.to(dtype_ref)[cache_batch_idx]
-            # )
-        # else:
-            # k_cache_select = rearrange(
-                # k_cache_paged.to(dtype_ref)[
-                    # (
-                        # page_table
-                        # if not has_batch_idx
-                        # else page_table[cache_batch_idx]
-                    # ).flatten()
-                # ],
-                # "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-                # b=batch_size,
-            # )[:, :seqlen_k].to(dtype_ref)
-            # v_cache_select = rearrange(
-                # v_cache_paged.to(dtype_ref)[
-                    # (
-                        # page_table
-                        # if not has_batch_idx
-                        # else page_table[cache_batch_idx]
-                    # ).flatten()
-                # ],
-                # "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-                # b=batch_size,
-            # )[:, :seqlen_k].to(dtype_ref)
-        # k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
-        # v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
-        # # if dtype is not torch.float8_e4m3fn:
-            # # import pdb; pdb.set_trace()
-            # # assert torch.equal(v_cache_select, v_cache_ref)
-        # # else:
-            # # assert torch.allclose(
-                # # v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
-            # # )
-        # # breakpoint()
-        # # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
-        # # if rotary_dim == 0:
-            # # assert torch.equal(k_cache_select, k_cache_ref)
-        # # else:
-            # # # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
-            # # # breakpoint()
-            # # if dtype is not torch.float8_e4m3fn:
-                # # assert torch.allclose(
-                    # # k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3
-                # # )
-            # # else:
-                # # assert torch.allclose(
-                    # # k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1
-                # # )
+    if new_kv:
+        if page_size is None:
+            k_cache_select = (
+                k_cache.to(dtype_ref)
+                if not has_batch_idx
+                else k_cache.to(dtype_ref)[cache_batch_idx]
+            )
+            v_cache_select = (
+                v_cache.to(dtype_ref)
+                if not has_batch_idx
+                else v_cache.to(dtype_ref)[cache_batch_idx]
+            )
+        else:
+            k_cache_select = rearrange(
+                k_cache_paged.to(dtype_ref)[
+                    (
+                        page_table
+                        if not has_batch_idx
+                        else page_table[cache_batch_idx]
+                    ).flatten()
+                ],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k].to(dtype_ref)
+            v_cache_select = rearrange(
+                v_cache_paged.to(dtype_ref)[
+                    (
+                        page_table
+                        if not has_batch_idx
+                        else page_table[cache_batch_idx]
+                    ).flatten()
+                ],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k].to(dtype_ref)
+        k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
+        v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
+        if dtype is not torch.float8_e4m3fn:
+            import pdb; pdb.set_trace()
+            assert torch.equal(v_cache_select, v_cache_ref)
+        else:
+            assert torch.allclose(
+                v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
+            )
+        breakpoint()
+        if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
+            assert torch.equal(k_cache_select, k_cache_ref)
+        else:
+            # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
+            # breakpoint()
+            if dtype is not torch.float8_e4m3fn:
+                assert torch.allclose(
+                    k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3
+                )
+            else:
+                assert torch.allclose(
+                    k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1
+                )
     mult = 4 if dtype == torch.float8_e4m3fn else 2
     assert (out - out_ref).abs().max().item() <= mult * (
         out_pt - out_ref
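
For the paged path (page_size set), the re-enabled checks rebuild a contiguous per-sequence cache from k_cache_paged / v_cache_paged by gathering blocks through page_table and folding the block axes back together with einops.rearrange, then compare the result against the reference cache. Below is a minimal standalone sketch of that gather-and-rearrange step, with toy shapes and an identity page table; the variable names mirror the test, but the values and layout here are illustrative, not taken from the commit.

import torch
from einops import rearrange

# Toy dimensions (illustrative only).
batch_size, nblocks, block_size, nheads, headdim = 2, 3, 4, 1, 8
seqlen_k = 10  # the test only compares the first seqlen_k positions

# Paged cache: every block of every sequence stacked along dim 0.
k_cache_paged = torch.randn(batch_size * nblocks, block_size, nheads, headdim)

# Page table: for each sequence, the physical block indices it owns
# (identity layout here; the real test can shuffle blocks).
page_table = torch.arange(batch_size * nblocks).reshape(batch_size, nblocks)

# Gather the blocks row by row, then merge (nblocks, block_size) into one
# contiguous sequence axis, using the same rearrange pattern as the test.
k_cache_select = rearrange(
    k_cache_paged[page_table.flatten()],
    "(b nblocks) block_size ... -> b (nblocks block_size) ...",
    b=batch_size,
)[:, :seqlen_k]

assert k_cache_select.shape == (batch_size, seqlen_k, nheads, headdim)

The surviving context lines then bound the kernel's output error by mult times the PyTorch reference error, with mult = 4 for float8_e4m3fn and 2 otherwise.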
