@@ -36,8 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
-      " Tensor! out, Tensor exp_sums, Tensor max_logits,"
-      " Tensor tmp_out, Tensor query, Tensor key_cache,"
+      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
       " Tensor value_cache, int num_kv_heads, float scale,"
       " Tensor block_tables, Tensor seq_lens, int block_size,"
       " int max_seq_len, Tensor? alibi_slopes,"
@@ -73,7 +73,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
 
   // prepare_inputs advance_step
-  ops.def("advance_step", &advance_step);
+  ops.def(
+      "advance_step(int num_seqs, int num_queries, int block_size, "
+      "Tensor! input_tokens, Tensor sampled_token_ids, "
+      "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
+      "Tensor block_tables) -> ()");
   ops.impl("advance_step", torch::kCUDA, &advance_step);
 
   // Layernorm
@@ -110,27 +114,56 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantization ops
 #ifndef USE_ROCM
   // Quantized GEMM for AQLM.
-  ops.def("aqlm_gemm", &aqlm_gemm);
+  ops.def(
+      "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
+      "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
+      "-> Tensor");
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
 
   // Decompression method for AQLM.
-  ops.def("aqlm_dequant", &aqlm_dequant);
+  ops.def(
+      "aqlm_dequant(Tensor codes, Tensor codebooks, "
+      "int[] codebook_partition_sizes) -> Tensor");
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
 
   // Quantized GEMM for AWQ.
-  ops.def("awq_gemm", &awq_gemm);
+  ops.def(
+      "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
-  ops.def("awq_dequantize", &awq_dequantize);
+  ops.def(
+      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
+  // Note about marlin kernel 'workspace' arguments:
+  // Technically these should be mutable since they are modified by the kernel.
+  // But since they are set back to zero once the kernel is finished we can
+  // hand wave and say that they have no net effect.
+  //
+  // The reason to mark 'workspace' as immutable is so that they don't interfere
+  // with using ScalarType arguments in the ops. If they are marked as mutable,
+  // pytorch throws an assert in
+  // 'torch._higher_order_ops._register_effectful_op' that prevents these
+  // kernels from being torch.compile'd.
+  // See the following document for more info on custom types and ops that use
+  // custom types:
+  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
+
   // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
-  ops.def("marlin_gemm", &marlin_gemm);
+  ops.def(
+      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.def(
+      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+      "Tensor b_scales, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
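The `__torch__.torch.classes._core_C.ScalarType` parameter above is a TorchBind custom class; a schema string can reference any registered class by its fully qualified path, which is what the workspace note is protecting. A rough sketch of how such a class gets exposed (the `demo_classes`/`BitWidth` names are invented for illustration; vLLM's real registration lives alongside its `ScalarType` sources):

#include <torch/custom_class.h>

// Hypothetical TorchBind class standing in for vLLM's ScalarType.
struct BitWidth : torch::CustomClassHolder {
  int64_t bits;
  explicit BitWidth(int64_t b) : bits(b) {}
};

TORCH_LIBRARY(demo_classes, m) {
  // Once registered, op schemas may refer to this type as
  // "__torch__.torch.classes.demo_classes.BitWidth".
  m.class_<BitWidth>("BitWidth")
      .def(torch::init<int64_t>())
      .def("bits",
           [](const c10::intrusive_ptr<BitWidth>& self) { return self->bits; });
}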
@@ -149,35 +182,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
 
   // gptq_marlin Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.def(
+      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k, bool is_k_full, "
+      "bool has_zp, bool use_fp32_reduce) -> Tensor");
   ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
 
   // gptq_marlin repack from GPTQ.
-  ops.def("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.def(
+      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
   ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
+  ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
 
   // awq_marlin repack from AWQ.
-  ops.def("awq_marlin_repack", &awq_marlin_repack);
+  ops.def(
+      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+      "SymInt size_n, int num_bits) -> Tensor");
   ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
+  ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
 
   // Dequantization for GGML.
-  ops.def("ggml_dequantize", &ggml_dequantize);
+  ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
   ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
 
   // mmvq kernel for GGML.
-  ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
+  ops.def(
+      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
+      "-> Tensor");
   ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
 
   // mmq kernel for GGML.
-  ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
+  ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
   ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
 
   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
-  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.def(
+      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int num_bits, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
 
   // marlin_qqq_gemm for QQQ.
-  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
+  ops.def(
+      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+      "Tensor! workspace, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
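The new `torch::kMeta` registrations are what make these repack ops traceable by torch.compile: a meta kernel launches nothing and only reports the output's shape and dtype, with `SymInt` sizes keeping dynamic shapes symbolic. A sketch of what `gptq_marlin_repack_meta` plausibly looks like (the tile size of 16 and the 32/num_bits pack factor are assumptions here; the kernel sources are authoritative):

// Meta kernel sketch: compute the output shape symbolically, allocate on
// whatever (meta) device the inputs are on, and return immediately.
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                      torch::Tensor& perm, c10::SymInt size_k,
                                      c10::SymInt size_n, int64_t num_bits) {
  constexpr int64_t tile_size = 16;           // assumed Marlin tile size
  int64_t const pack_factor = 32 / num_bits;  // quantized values per int32
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / tile_size, size_n * tile_size / pack_factor}, options);
}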
@@ -199,16 +252,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
   // capability
-  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
-  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
-           &cutlass_scaled_mm_supports_fp8);
+  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
+  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+
   // Mamba selective scan kernel
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
       "Tensor! A, Tensor! B, Tensor! C,"
       "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
       "bool delta_softplus,"
-      "Tensor? index_, Tensor? x) -> Tensor[]");
+      "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
   ops.def(
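The `selective_scan_fwd` change above uses the richer alias-annotation syntax: `Tensor(a! -> *)? x` places the optional `x` in alias set `a`, marks it written (`!`), and the `-> *` wildcard says it may alias anything once the op returns, while the `Tensor(a)[]` return type records that the output list aliases that same set. A toy declaration exercising the same vocabulary (op and argument names are made up; a def without an impl is enough for the schema parser to validate it):

#include <torch/library.h>

TORCH_LIBRARY_FRAGMENT(demo_ext, m) {
  // state: mutated in place; cache: optional, written, may alias the
  // returned tensors, which belong to the same alias set 'a'.
  m.def(
      "demo_scan(Tensor! state, Tensor input,"
      " Tensor(a! -> *)? cache) -> Tensor(a)[]");
}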
@@ -230,7 +283,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #endif
 
   // Quantized GEMM for GPTQ.
-  ops.def("gptq_gemm", &gptq_gemm);
+  // Note: even though the C++ inferred schema is correct for this op, it
+  // seems to prevent meta function registration.
+  ops.def(
+      "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
+      "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
+      "-> Tensor");
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
 
   // Post processing for GPTQ.
@@ -250,8 +308,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
   ops.def(
-      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
-      "scale, Tensor? scale_ub) -> "
+      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
+      "Tensor! scale, Tensor? scale_ub) -> "
       "()");
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
@@ -288,8 +346,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
   // Copy the cache blocks from src to dst.
   cache_ops.def(
-      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
-      "block_mapping) -> ()");
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 
   // Reshape the key and value tensors and cache them.
@@ -314,33 +372,37 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
-      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
-      "kv_cache_dtype) -> ()");
+      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
+      "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
   // Cuda utils
 
   // Gets the specified device attribute.
-  cuda_utils.def("get_device_attribute", &get_device_attribute);
-  cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
+  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
+  cuda_utils.impl("get_device_attribute", &get_device_attribute);
 
   // Gets the maximum shared memory per block device attribute.
-  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
-                 &get_max_shared_memory_per_block_device_attribute);
+  cuda_utils.def(
+      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
   cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
-                  torch::kCUDA,
                   &get_max_shared_memory_per_block_device_attribute);
 }
 
 #ifndef USE_ROCM
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   // Custom all-reduce kernels
-  custom_ar.def("init_custom_ar", &init_custom_ar);
+  custom_ar.def(
+      "init_custom_ar(Tensor meta, Tensor rank_data, "
+      "str[] handles, int[] offsets, int rank, "
+      "bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
 
-  custom_ar.def("should_custom_ar", &should_custom_ar);
+  custom_ar.def(
+      "should_custom_ar(Tensor inp, int max_size, int world_size, "
+      "bool full_nvlink) -> bool");
   custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
 
   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
@@ -352,21 +414,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
 
   custom_ar.def("dispose", &dispose);
-  custom_ar.impl("dispose", torch::kCPU, &dispose);
-
   custom_ar.def("meta_size", &meta_size);
-  custom_ar.impl("meta_size", torch::kCPU, &meta_size);
 
-  custom_ar.def("register_buffer", &register_buffer);
+  custom_ar.def(
+      "register_buffer(int fa, Tensor t, str[] handles, "
+      "int[] offsets) -> ()");
   custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
 
   custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
-  custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
-                 &get_graph_buffer_ipc_meta);
-
   custom_ar.def("register_graph_buffers", &register_graph_buffers);
-  custom_ar.impl("register_graph_buffers", torch::kCPU,
-                 &register_graph_buffers);
 }
 #endif
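The deleted `kCPU` impls at the end drop no functionality: when `def` receives a function pointer, PyTorch infers the schema and registers that function as a catch-all kernel for every backend, so device-agnostic utilities like `dispose` and `meta_size` need no per-device `impl` at all. A minimal sketch with a hypothetical op:

#include <torch/library.h>

// Hypothetical device-agnostic utility.
int64_t demo_meta_size() { return 64; }

TORCH_LIBRARY_FRAGMENT(demo_ext, m) {
  // def() with a function pointer infers "demo_meta_size() -> int" and
  // registers the function as a catch-all kernel, callable regardless of
  // dispatch key -- no impl() lines required.
  m.def("demo_meta_size", &demo_meta_size);
}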