|
37 | 37 | extern "C" { |
38 | 38 | PT_EXPORT {{kernel_call_signature}} { |
39 | 39 | try { |
40 | | - int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; |
| 40 | + int B = {{kernel.size(Y, 0, -3, default_value=1)}}; |
41 | 41 | using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; |
42 | 42 | using coord_t = cutlass::gemm::GemmCoord::Index; |
43 | 43 | static cutlass::KernelHardwareInfo hw_info; |
|
154 | 154 | extern "C" { |
155 | 155 | PT_EXPORT {{kernel_call_signature}} { |
156 | 156 | try { |
157 | | - int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; |
| 157 | + int B = {{kernel.size(Y, 0, -3, default_value=1)}}; |
158 | 158 | using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; |
159 | 159 | using coord_t = cutlass::gemm::GemmCoord::Index; |
160 | 160 | static cutlass::KernelHardwareInfo hw_info; |
|
266 | 266 | // Initialize GemmSparse arguments. |
267 | 267 | arguments = { |
268 | 268 | { |
269 | | - static_cast<coord_t>({{M}}), |
270 | | - static_cast<coord_t>({{N}}), |
| 269 | + static_cast<coord_t>(M), |
| 270 | + static_cast<coord_t>(N), |
271 | 271 | static_cast<coord_t>(2 * K), |
272 | 272 | }, // GemmCoord problem_size |
273 | 273 | X_ref, // TensorRef<ElementA const, LayoutA> ref_A |
|
304 | 304 | if (block.size()<=0) return false; |
305 | 305 | Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min)); |
306 | 306 | cutlass::reference::device::BlockFillRandomUniform( |
307 | | - block.get(), block.size(), seed, scope_max, scope_min, 0); |
| 307 | + (Element*)block.get(), block.size(), seed, scope_max, scope_min, 0); |
308 | 308 |
|
309 | 309 | return true; |
310 | 310 | } |
311 | 311 |
|
| 312 | +{% if Meta is defined and Meta is not none %} |
| 313 | +template <class Element> |
| 314 | +bool initialize_block_meta( |
| 315 | + cutlass::DeviceAllocation<Element>& block, |
| 316 | + uint64_t seed) { |
| 317 | + if (block.size()<=0) return false; |
| 318 | + cutlass::reference::device::BlockFillRandomSparseMeta( |
| 319 | + (Element*)block.get(), block.size(), seed, {{instance_type}}::kMetaSizeInBits); |
| 320 | + return true; |
| 321 | +} |
| 322 | +{% endif %} |
| 323 | +
|
312 | 324 | extern "C" int run_standalone(uint64_t seed, int repetitions) { |
313 | 325 | std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl; |
314 | 326 | size_t workspace_size = 0; |
315 | 327 | size_t* workspace_size_ptr = &workspace_size; |
316 | 328 |
|
| 329 | + int M = {{kernel.get_layout_args()[0]}}; |
| 330 | + int N = {{kernel.get_layout_args()[1]}}; |
| 331 | + int K = {{kernel.get_layout_args()[2]}}; |
| 332 | + int lda = {{kernel.get_layout_args()[3]}}; |
| 333 | + int ldb = {{kernel.get_layout_args()[4]}}; |
| 334 | + int ldc = {{kernel.get_layout_args()[5]}}; |
| 335 | + int ldd = {{kernel.get_layout_args()[6]}}; |
| 336 | +
|
317 | 337 | using ElementA = {{kernel.cutlass_dtype(X)}}; |
318 | 338 | using ElementB = {{kernel.cutlass_dtype(W)}}; |
319 | 339 | using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void |
320 | 340 | using ElementD = {{kernel.cutlass_dtype(Y)}}; |
| 341 | + {% if Meta is defined and Meta is not none %} |
| 342 | + using ElementE = {{kernel.cutlass_dtype(Meta)}}; |
| 343 | + {% endif %} |
321 | 344 |
|
322 | 345 | cutlass::DeviceAllocation<ElementA> X_data({{kernel.max_valid_index(X)+1}}); |
323 | 346 | initialize_block(X_data, seed++); |
|
326 | 349 | cutlass::DeviceAllocation<ElementC> Bias_data({{kernel.max_valid_index(Bias)+1}}); |
327 | 350 | initialize_block(Bias_data, seed++); |
328 | 351 | cutlass::DeviceAllocation<ElementD> Y_data({{kernel.max_valid_index(Y)+1}}); |
| 352 | + {% if Meta is defined and Meta is not none %} |
| 353 | + cutlass::DeviceAllocation<ElementE> Meta_data({{kernel.max_valid_index(Meta)+1}}); |
| 354 | + initialize_block_meta(Meta_data, seed++); |
| 355 | + {% endif %} |
329 | 356 |
|
330 | 357 | cutlass::DeviceAllocation<uint8_t> workspace_data; |
331 | 358 | // Call once with workspace_size_ptr set to get workspace size |
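
Note on the C++ template hunks above: the generated run_standalone driver now reads M, N, K and the leading dimensions from kernel.get_layout_args() at render time and, when a Meta input is present, also allocates and randomly fills the sparse-metadata buffer via initialize_block_meta. The sparse-only pieces are wrapped in "{% if Meta is defined and Meta is not none %}" guards. A minimal, self-contained sketch of how that guard behaves under plain Jinja rendering (illustrative only; the real codegen goes through Inductor's template machinery, and the snippet below is made up for demonstration):

# Illustrative sketch: how the "Meta is defined and Meta is not none" guard
# used in the hunks above renders, depending on whether Meta is supplied.
import jinja2

snippet = jinja2.Environment().from_string(
    "{% if Meta is defined and Meta is not none %}"
    "initialize_block_meta(Meta_data, seed++);"
    "{% endif %}"
)

print(repr(snippet.render()))                 # '' -> Meta not defined, block skipped
print(repr(snippet.render(Meta=None)))        # '' -> Meta is None, block skipped
print(repr(snippet.render(Meta="Meta_buf")))  # metadata-initialization line emitted
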
@@ -466,6 +493,14 @@ def _get_extra_inputs_and_names( |
466 | 493 | ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: |
467 | 494 | raise NotImplementedError |
468 | 495 |
|
| 496 | + @abstractmethod |
| 497 | + def _update_arg_names_for_test_call_statement( |
| 498 | + self, |
| 499 | + arg_names: list[str], |
| 500 | + input_nodes: list[Buffer], |
| 501 | + ) -> list[str]: |
| 502 | + raise NotImplementedError |
| 503 | + |
469 | 504 | def _add_cutlass_gemm_choices( |
470 | 505 | self, |
471 | 506 | choices: list[ChoiceCaller], |
@@ -980,13 +1015,14 @@ def test_call_statement( |
980 | 1015 | """ |
981 | 1016 | _, __, arg_types = kernel.args.cpp_argdefs() |
982 | 1017 | arg_names = [name.strip() for name in names_str.strip().split(",")] |
983 | | - if input_nodes[2] is None: |
984 | | - del arg_names[2] |
| 1018 | + arg_names = self._update_arg_names_for_test_call_statement( |
| 1019 | + arg_names, input_nodes |
| 1020 | + ) |
985 | 1021 | arguments = [ |
986 | 1022 | f"(({arg_type}){arg_name}_data.get())" |
987 | 1023 | for arg_type, arg_name in zip(arg_types, arg_names) |
988 | 1024 | ] |
989 | | - return f"{kernel.kernel_name}({', '.join(arguments)}, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" |
| 1025 | + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, lda, ldb, ldc, ldd, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 |
990 | 1026 |
|
991 | 1027 |
|
992 | 1028 | class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): |
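
The Python hunk above replaces the hard-coded del arg_names[2] in test_call_statement with a per-template hook, _update_arg_names_for_test_call_statement, and extends the generated call to forward M, N, K, lda, ldb, ldc, ldd at runtime. A minimal sketch of the resulting hook pattern, using simplified, hypothetical class names that mirror the two implementations added in the hunks below:

from abc import ABC, abstractmethod
from typing import Optional


class GemmTemplateSketch(ABC):
    """Simplified stand-in for the template classes touched in this diff."""

    @abstractmethod
    def _update_arg_names_for_test_call_statement(
        self,
        arg_names: list[str],
        input_nodes: list[Optional[object]],
    ) -> list[str]:
        """Drop arg names whose optional input buffers were not provided."""
        raise NotImplementedError


class DenseGemmSketch(GemmTemplateSketch):
    # Mirrors the CUTLASS3xGemmTemplate implementation: the optional input at
    # index 2 is dropped from the standalone call when it is None.
    def _update_arg_names_for_test_call_statement(self, arg_names, input_nodes):
        if input_nodes[2] is None:
            del arg_names[2]
        return arg_names


class SparseGemmSketch(GemmTemplateSketch):
    # Mirrors the sparse-template implementation: two optional inputs at
    # indices 2 and 3 may be absent; the higher index is deleted first so the
    # lower index is still valid after the first deletion.
    def _update_arg_names_for_test_call_statement(self, arg_names, input_nodes):
        if input_nodes[3] is None:
            del arg_names[3]
        if input_nodes[2] is None:
            del arg_names[2]
        return arg_names


# Quick check of the sparse behavior with placeholder argument names:
names = ["X", "W", "arg2", "arg3", "Y"]
nodes = [object(), object(), None, None, object()]
print(SparseGemmSketch()._update_arg_names_for_test_call_statement(names, nodes))
# -> ['X', 'W', 'Y']
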
@@ -1206,6 +1242,15 @@ def _get_extra_inputs_and_names( |
1206 | 1242 | names: list[str] = [] |
1207 | 1243 | return (Bias, inputs, names) |
1208 | 1244 |
|
| 1245 | + def _update_arg_names_for_test_call_statement( |
| 1246 | + self, |
| 1247 | + arg_names: list[str], |
| 1248 | + input_nodes: list[Buffer], |
| 1249 | + ) -> list[str]: |
| 1250 | + if input_nodes[2] is None: |
| 1251 | + del arg_names[2] |
| 1252 | + return arg_names |
| 1253 | + |
1209 | 1254 | def render_gemm_arguments( |
1210 | 1255 | self, |
1211 | 1256 | argument_template: str, |
@@ -1482,6 +1527,17 @@ def _get_extra_inputs_and_names( |
1482 | 1527 | names = ["Meta"] |
1483 | 1528 | return (Bias, inputs, names) |
1484 | 1529 |
|
| 1530 | + def _update_arg_names_for_test_call_statement( |
| 1531 | + self, |
| 1532 | + arg_names: list[str], |
| 1533 | + input_nodes: list[Buffer], |
| 1534 | + ) -> list[str]: |
| 1535 | + if input_nodes[3] is None: |
| 1536 | + del arg_names[3] |
| 1537 | + if input_nodes[2] is None: |
| 1538 | + del arg_names[2] |
| 1539 | + return arg_names |
| 1540 | + |
1485 | 1541 | def render_gemm_arguments( |
1486 | 1542 | self, |
1487 | 1543 | instance_type: str, |
|