Commit dfc80c7

add bindings
1 parent 7e0bb01 commit dfc80c7

File tree

1 file changed: +39 −0 lines changed

src/torch_extension_sycl.cc

Lines changed: 39 additions & 0 deletions
@@ -38,6 +38,45 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype,
   // -> Tensor");
   // m.impl("fp8_blockwise_scaled_mm", torch::kXPU, &fp8_blockwise_scaled_mm);
+
+  /*
+   * From cutlass attention
+   */
+  m.def(
+      "fwd(Tensor! q,"
+      " Tensor k,"
+      " Tensor v,"
+      " Tensor? k_new,"
+      " Tensor? v_new,"
+      " Tensor? q_v,"
+      " Tensor!? out,"
+      " Tensor? cu_seqlens_q,"
+      " Tensor? cu_seqlens_k,"
+      " Tensor? cu_seqlens_k_new,"
+      " Tensor? seqused_q,"
+      " Tensor? seqused_k,"
+      " int? max_seqlen_q,"
+      " int? max_seqlen_k,"
+      " Tensor? page_table,"
+      " Tensor? kv_batch_idx,"
+      " Tensor? leftpad_k,"
+      " Tensor? rotary_cos,"
+      " Tensor? rotary_sin,"
+      " Tensor? seqlens_rotary,"
+      " Tensor? q_descale,"
+      " Tensor? k_descale,"
+      " Tensor? v_descale,"
+      " float softmax_scale,"
+      " bool is_causal,"
+      " int window_size_left,"
+      " int window_size_right,"
+      " float softcap,"
+      " bool is_rotary_interleaved,"
+      " Tensor? scheduler_metadata,"
+      " int num_splits,"
+      " bool? pack_gqa,"
+      " int sm_margin) -> Tensor[]");
+  m.impl("fwd", torch::kXPU, make_pytorch_shim(&mha_fwd));
 }

 REGISTER_EXTENSION(common_ops)
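
Context for the binding: the schema above is defined in the sgl_kernel library fragment and registered for the XPU backend, so once the built extension is loaded the op should be callable as torch.ops.sgl_kernel.fwd. The Python sketch below illustrates such a call. It is not taken from this commit: the import name, tensor shapes/dtypes, and the "default-like" values (window sizes of -1, num_splits of 1, sm_margin of 0) are assumptions following common flash-attention conventions.

# Hedged sketch: assumes the built sgl_kernel extension is importable and that
# q/k/v use the usual flash-attention layout (batch, seqlen, num_heads, head_dim).
import torch
import sgl_kernel  # assumption: importing the package registers the sgl_kernel ops

batch, seqlen_q, seqlen_k, heads, head_dim = 1, 128, 128, 8, 64
device, dtype = "xpu", torch.float16  # the op is registered for torch::kXPU

q = torch.randn(batch, seqlen_q, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_k, heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_k, heads, head_dim, device=device, dtype=dtype)

# The schema declares no defaults, so every argument must be supplied;
# optional tensors and ints take None.
outputs = torch.ops.sgl_kernel.fwd(
    q, k, v,
    None, None,          # k_new, v_new
    None,                # q_v
    None,                # out (assumed to be allocated by the kernel when omitted)
    None, None, None,    # cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new
    None, None,          # seqused_q, seqused_k
    None, None,          # max_seqlen_q, max_seqlen_k
    None, None, None,    # page_table, kv_batch_idx, leftpad_k
    None, None, None,    # rotary_cos, rotary_sin, seqlens_rotary
    None, None, None,    # q_descale, k_descale, v_descale
    head_dim ** -0.5,    # softmax_scale
    True,                # is_causal
    -1, -1,              # window_size_left, window_size_right (assumed: -1 disables windowing)
    0.0,                 # softcap
    False,               # is_rotary_interleaved
    None,                # scheduler_metadata
    1,                   # num_splits
    None,                # pack_gqa
    0,                   # sm_margin
)
out = outputs[0]  # schema returns Tensor[]; the attention output is assumed to come first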
