[TileOp] Implement WGMMA for T.gemm_v2 #813
Conversation
- Added support for the WGMMA intrinsic in the TileLang framework, enabling efficient matrix multiplication on newer architectures.
- Refactored GEMM layout functions to accept a boolean parameter for K dimension handling, improving flexibility in layout generation.
- Updated layout inference logic to accommodate new WGMMA configurations and ensure compatibility with existing GEMM operations.
- Enhanced Python bindings for layout functions, allowing for better integration and usability in user-defined operations.
- Improved documentation for layout functions and GEMM operations to clarify usage and parameters.

These changes enhance the performance and usability of GEMM operations, particularly for advanced architectures, while maintaining backward compatibility with existing implementations.
…bility
- Improved code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated layout function signatures to enhance clarity, particularly in `gemm_layouts.cc`, `layout.cc`, and `layout.h`.
- Refactored lambda functions in `builtin.cc` and `gemm_py.cc` for improved structure and maintainability.
- Enhanced comments and documentation in layout-related files to clarify usage and parameters.

These changes contribute to a cleaner codebase and improved maintainability of layout functions in the TileLang framework.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note
Other AI code review bot(s) detected: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough
Replaces numeric k-factor GEMM layout parameters with a boolean k_inner across layout APIs and callers; adds WGMMA support (intrinsics, PTX/CUDA codegen, descriptors), new Python layout/descriptor helpers and equality checks, and a runtime GemmInst dispatch with a new GemmWGMMA backend.
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant GemmPy
participant FFI as _ffi_api.GemmPyGemmInst
participant Dispatcher as GemmPy.dispatch
participant Impl as GemmMMA/GemmWGMMA
participant Emitter as TensorCoreIntrinEmitter
participant TIR as TIR Lowering
participant CG as CUDA/PTX Codegen
User->>GemmPy: infer_layout / lower(request)
GemmPy->>FFI: GemmPyGemmInst(thread_nums, target)
FFI-->>GemmPy: GemmInst (MMA / WGMMA / MFMA)
GemmPy->>Dispatcher: select implementation
Dispatcher->>Impl: dispatch to GemmMMA or GemmWGMMA
Impl->>Emitter: configure(dtypes, trans, tiles, thread_var)
Impl->>TIR: emit TIR with intrinsics (ptx_mma_* or ptx_wgmma_*)
TIR->>CG: emit descriptor init / wgmma PTX assembly
CG-->>User: compiled kernel
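The dispatch step in the diagram above can be sketched as a plain-Python simplification. This is a hypothetical mock, not the real selector (which lives behind `_ffi_api.GemmPyGemmInst`); the priority order mirrors the one described later in this review (WGMMA on Hopper with a full 128-thread warp group, MFMA on CDNA, MMA otherwise), and the `arch` strings are assumptions.

```python
from enum import Enum

class GemmInst(Enum):
    MMA = "mma"
    WGMMA = "wgmma"
    MFMA = "mfma"

def select_gemm_inst(arch: str, thread_nums: int) -> GemmInst:
    """Toy dispatcher mirroring the documented priority order."""
    if arch == "sm_90" and thread_nums >= 128:
        # WGMMA executes per 128-thread warp group on Hopper
        return GemmInst.WGMMA
    if arch.startswith("gfx9"):
        # CDNA (AMD) targets use MFMA
        return GemmInst.MFMA
    # CUDA default / fallback
    return GemmInst.MMA
```

With this sketch, `select_gemm_inst("sm_90", 256)` picks WGMMA while `select_gemm_inst("sm_80", 128)` falls back to MMA.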
sequenceDiagram
autonumber
participant TIR
participant TL as tl.builtin
participant CUDA as codegen_cuda
participant PTX as codegen_ptx
note over TIR,PTX: WGMMA-SS descriptor flow
TIR->>TL: initialize_descriptor(descA, A_smem_ptr, layout/offsets)
TIR->>TL: initialize_descriptor(descB, B_smem_ptr, layout/offsets)
TIR->>TL: ptx_wgmma_ss(..., A_desc, A_off, B_desc, B_off, C_ptr, ...)
TL-->>CUDA: intrinsic calls with descriptors
CUDA->>PTX: PrintWGMMAAssembly(a_is_shared=true, validated config)
PTX-->>CUDA: PTX assembly string
CUDA-->>TIR: inline asm emitted into kernel
Estimated code review effort
🎯 5 (Critical) | ⏱️ ~120 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip
👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
…GMMA
- Introduced new TileLang builtins `initialize_descriptor` and `increase_descriptor_offset` to facilitate descriptor management for WGMMA operations.
- Updated `builtin.cc` and `builtin.h` to define and document the new builtins, enhancing the framework's capabilities for descriptor handling.
- Modified `codegen_cuda.cc` and `ptx.cc` to integrate the new builtins into the code generation process, ensuring proper assembly generation for WGMMA operations.
- Enhanced the `GemmWGMMA` class to utilize the new descriptor functionalities, improving the efficiency of matrix multiplication operations.
- Updated related tests and documentation to reflect the new features and ensure comprehensive coverage.

These changes enhance the TileLang framework's support for advanced matrix operations on newer architectures, improving performance and usability.
- Enhanced code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated function signatures and comments in `builtin.h`, `codegen_cuda.cc`, and `ptx.cc` to improve clarity.
- Refactored descriptor initialization and offset manipulation functions in `builtin.py` and `wgmma_macro_generator.py` for improved structure.
- Cleaned up unnecessary whitespace and improved alignment in `common.h` and `allocate.py`.

These changes contribute to a cleaner and more maintainable codebase in the TileLang framework.
- Updated the subproject commit for `cutlass` to indicate a dirty state.
- Refactored the `UpdateAnalyzer` function in `layout.cc` to call `LayoutNode::getVarMap()` instead of `getVarMap()`, improving clarity and ensuring proper context for variable mapping.

These changes enhance the maintainability and clarity of the layout handling in the TileLang framework.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/customize.py (1)
193-197
: Bug: extent computation uses list max instead of elementwise max
max(src_extent, dst_extent)
returns one of the lists lexicographically, not per-dimension. This can mis-size regions and corrupt memory ops.
Apply:
- extent = max(src_extent, dst_extent)
+ # elementwise extent unification
+ extent = [T.max(a, b) for a, b in zip(src_extent, dst_extent)]
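A standalone illustration of the bug class, with plain integer lists standing in for TIR extents (`elementwise_max` is a hypothetical helper, not TileLang API):

```python
def elementwise_max(a, b):
    # Per-dimension maximum, analogous to [T.max(x, y) for x, y in zip(...)]
    return [max(x, y) for x, y in zip(a, b)]

src_extent = [2, 9]
dst_extent = [3, 1]

# Built-in max on lists compares lexicographically and returns one whole list,
# so the second dimension ends up under-sized here.
picked = max(src_extent, dst_extent)               # [3, 1]
unified = elementwise_max(src_extent, dst_extent)  # [3, 9]
```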
🧹 Nitpick comments (32)
tilelang/language/__init__.py (1)
45-46
: Export looks good; clean up unused noqa.
- Re-exporting alloc_descriptor is correct and matches allocate.py.
- Ruff flags Line 45 for unused
# noqa: F401
(RUF100). Drop it here (others weren’t flagged in this diff).
Apply:
- alloc_descriptor,  # noqa: F401
+ alloc_descriptor,
tilelang/layout/layout.py (3)
92-94
: Add docstring/return type for get_forward_index for parity with getters.
Minor consistency nit. Consider:
- def get_forward_index(self):
-     return self.index
+ def get_forward_index(self) -> PrimExpr | list[PrimExpr]:
+     """Return the computed forward index expression(s)."""
+     return self.index
136-146
: API parity: provide Pythonic equality too.
Keep is_equal, but also implement __eq__ delegating to it for ergonomic comparisons; keep __hash__ = None to avoid hashing mutable objects.
  def is_equal(self, other: "Layout") -> bool:
      """
      Check if the current layout is equal to another layout.
      """
      return _ffi_api.Layout_is_equal(self, other)
+ def __eq__(self, other: object) -> bool:
+     if not isinstance(other, Layout):
+         return NotImplemented
+     return self.is_equal(other)
+ __hash__ = None
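The suggested pattern in isolation, as a runnable sketch (the class below mocks the FFI-backed Layout with a plain `shape` field; names are illustrative):

```python
class Layout:
    """Mock of the FFI-backed Layout, holding only a shape."""

    def __init__(self, shape):
        self.shape = shape

    def is_equal(self, other: "Layout") -> bool:
        # Stand-in for _ffi_api.Layout_is_equal
        return self.shape == other.shape

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Layout):
            return NotImplemented
        return self.is_equal(other)

    # A mutable object with custom __eq__ should not be hashable.
    __hash__ = None
```

Comparing against a non-Layout returns NotImplemented, so Python falls back to the other operand and ultimately yields False instead of raising; `hash(layout)` raises TypeError as intended.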
147-148
: __repr__: avoid huge dumps when vars/index grow.
Consider truncating sequences for readability in logs.
- return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>"
+ fv = self.get_forward_vars()
+ fi = self.get_forward_index()
+ def _short(x):
+     s = str(x)
+     return s if len(s) <= 120 else s[:117] + "..."
+ return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {_short(fv)} -> {_short(fi)}>"
src/op/builtin.h (2)
164-185
: Fix WGMMA doc comments to match actual RS/SS operand forms.
Current comment for ptx_wgmma_rs refers to A_descriptor; RS variant uses A_buf (regular pointer/buffer) per Python wrappers. Also B_offset is documented as Var (should be PrimExpr). Update comments to avoid API confusion.
- * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
- * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
- * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
- * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
+ * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
+ * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
+ * StringImm accum_dtype_abbrv, Var A_buf, PrimExpr A_offset, Var
+ * B_descriptor, PrimExpr B_offset, Var C_data, PrimExpr C_offset, bool scale_out, bool
  * scale_in_a, bool scale_in_b);
344-361
: Descriptor intrinsic docs: correct operation wording.
increase_descriptor_offset increments the offset; the block comment says “setting the start address.” Align wording to avoid misuse.
- * \brief tilelang intrinsic for setting the start address of a descriptor
- * buffer for wgmma/utcmma.
+ * \brief tilelang intrinsic for increasing the offset of a descriptor
+ * buffer for wgmma/utcmma.
1064-1104
: Nit: Docstring says “wmma” but this is WGMMA; also consider briefly documenting operand order.
Purely cosmetic; helps future maintainers and avoids confusion with WMMA.
- """TVM intrinsic for ptx tensor core wmma instructions
+ """TVM intrinsic for PTX warp-group MMA (WGMMA) instructions
+ Operand order: prefix, trans_a, trans_b, a/b/accum dtype abbrvs,
+ A_desc, A_offset, B_desc, B_offset, C_data, C_offset, scale_out, scale_in_a, scale_in_b.
1106-1144
: Nit: Same WGMMA terminology/doc tweak as above; otherwise wrapper matches builtin.
LGTM functionally; mirrors 15-arg registration.
- return call_intrin(
+ # PTX warp-group MMA (WGMMA) RS variant: A from register, B from descriptor.
+ return call_intrin(
tilelang/layout/__init__.py (1)
6-13
: Remove unused “noqa: F401” directives or enable F401 in Ruff config.
Ruff flags these as unused (RUF100). Either drop them or configure Ruff to honor F401.
-from .swizzle import (
-    make_swizzled_layout,  # noqa: F401
-    make_wgmma_swizzled_layout,  # noqa: F401
-    make_full_bank_swizzled_layout,  # noqa: F401
-    make_half_bank_swizzled_layout,  # noqa: F401
-    make_quarter_bank_swizzled_layout,  # noqa: F401
-)
-from .gemm_sp import make_metadata_layout  # noqa: F401
+from .swizzle import (
+    make_swizzled_layout,
+    make_wgmma_swizzled_layout,
+    make_full_bank_swizzled_layout,
+    make_half_bank_swizzled_layout,
+    make_quarter_bank_swizzled_layout,
+)
+from .gemm_sp import make_metadata_layout
tilelang/tileop/gemm/gemm_mma.py (2)
60-77
: Use layout_map (or underscore the arg) to avoid ARG002 and keep parity with WGMMA.
Mirror WGMMA: when provided, feed A/B shared layouts into the emitter; otherwise prefix the arg as _layout_map to appease linters.
  def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
@@
  mma_emitter = TensorCoreIntrinEmitter(
@@
      thread_var=thread_var,
  )
+
+ # Optional: honor externally inferred layouts if present (parity with WGMMA)
+ if self.A in layout_map:
+     mma_emitter._assign_a_shared_layout(layout_map[self.A])
+ if self.B in layout_map:
+     mma_emitter._assign_b_shared_layout(layout_map[self.B])
90-91
: Replace assert with explicit validation (and check divisibility).
Python asserts can be stripped with -O; raise a ValueError and also enforce block_K % micro_size_k == 0 to match loop step.
- assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+ if block_K < micro_size_k or (block_K % micro_size_k) != 0:
+     raise ValueError(
+         f"Invalid K tile: block_K={block_K}, micro_size_k={micro_size_k} "
+         "(must be >= and divisible)."
+     )
src/op/gemm.cc (1)
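The suggested check as a self-contained helper (the function name is illustrative):

```python
def validate_k_tile(block_K: int, micro_size_k: int) -> None:
    # Unlike assert, this check survives `python -O`, which strips asserts.
    if block_K < micro_size_k or block_K % micro_size_k != 0:
        raise ValueError(
            f"Invalid K tile: block_K={block_K}, micro_size_k={micro_size_k} "
            "(must be >= and divisible)."
        )

validate_k_tile(64, 16)  # ok: 64 >= 16 and 64 % 16 == 0
```

`validate_k_tile(24, 16)` raises, since 24 is not a multiple of 16 even though it exceeds it.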
45-48
: Documentation inconsistency between comment and implementation.
The documentation states that kPack must be 1, but the implementation at lines 71-73 allows values of both 1 and 2. This creates confusion about the actual requirements.
Either update the documentation to reflect that kPack can be 1 or 2, or enforce the restriction that it must be 1:
- * @note If `kPack` is provided it must be 1; otherwise the constructor
- * fails with an ICHECK (runtime assertion). No other validation is
- * performed here.
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ * fails with an ICHECK (runtime assertion). No other validation is
+ * performed here.
375-375
: Consider moving error messages to exception classes.
While not critical, defining error messages inside exception classes improves maintainability and reusability.
Consider creating custom exception classes:
class InvalidDescriptorTypeError(TypeError):
    """Raised when descriptor is not a Buffer or BufferLoad."""
    def __init__(self):
        super().__init__("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")

class InvalidDescriptorShapeError(ValueError):
    """Raised when descriptor is not a 1D buffer of size 1."""
    def __init__(self):
        super().__init__("Descriptor must be a 1D buffer of size 1.")
Also applies to: 378-378, 401-401, 404-404
tilelang/layout/swizzle.py (2)
23-23
: Use Optional type annotation for nullable parameter.
The continuity parameter can be None but isn't typed as Optional. Add proper type annotation:
+from typing import Optional
+
 def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
+                               continuity: Optional[int] = None,
                                k_major: bool = True):
54-54
: Consider more descriptive error messages.
The error messages could be more helpful by describing what arguments are expected. Improve them:
- raise ValueError(f"Invalid arguments: {args}")
+ raise ValueError(f"Expected either a single buffer or (stride, continuous, element_size), got {len(args)} arguments: {args}")
Also applies to: 79-79, 104-104
src/layout/gemm_layouts.cc (2)
573-574
: Parameter name inconsistency with function documentation.
The parameter name k_inner doesn't match the documentation's description, which refers to whether the "K dimension is in the inner loop". Consider renaming to k_is_inner or is_k_inner for better clarity.
532-541
: Missing implementation of k_major parameter in Volta layout.
The function signature was updated to use bool k_inner, but the implementation still uses k_inner directly as a boolean flag without considering the k-major semantics that the rest of the codebase expects.
Based on the pattern in makeGemmABLayout and makeGemmABLayoutHopper, you may want to verify that this implementation correctly handles the k-major/k-inner semantics.
110-111
: Verify RS implementation for A in fragment and B in shared.
The _gemm_rsr function name and docstring mention loading data from "shared buffers A_shared and B_shared", but the RS variant should have A in registers/fragments. The docstring appears to be copied from the SS variant.
Apply this diff to fix the docstring:
 @T.prim_func
 def _gemm_rsr() -> None:
     """
-    The inner macro that loads data from shared buffers A_shared and
-    B_shared into local fragments, then issues Tensor Core mma ops,
-    accumulating into C_local.
+    The inner macro that uses data from local fragment A_local and
+    shared buffer B_shared, then issues Tensor Core mma ops,
+    accumulating into C_local.
     """
     mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)
tilelang/intrinsics/wgmma_macro_generator.py (3)
104-104
: Remove unused parameter n_dim from method signature.
The method _initialize_wgmma_prefix has an unused parameter n_dim=16 that shadows the instance variable self.n_dim.
Apply this diff to fix:
- def _initialize_wgmma_prefix(self, n_dim: int = 16):
+ def _initialize_wgmma_prefix(self):
      inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
133-143
: Extract swizzle mode detection to centralized utility.
The _determinate_swizzle_mode method performs layout equality checks against multiple swizzle patterns. This logic could be refactored into a centralized layout utility to avoid duplication if similar detection is needed elsewhere.
Would you like me to help create a centralized swizzle mode detection utility that could be reused across the codebase?
345-353
: Improve error message specificity for unsupported dtypes.
The error message could be more informative by including the actual bit width. Apply this diff:
 else:
-    raise ValueError(f"Unsupported dtype {dtype}")
+    raise ValueError(f"Unsupported dtype {dtype} with {dtype_bits} bits for MMA load layout")
src/layout/layout.cc (1)
495-505
: Parameter naming inconsistency in swizzled layout creation.
The make_swizzled_layout function uses allow_pad to choose between makeGemmABLayout and makeGemmABLayoutHopper. The parameter name allow_pad doesn't clearly convey that it's selecting between different hardware layout strategies (standard vs Hopper).
Consider renaming to be more descriptive:
 .def("tl.make_swizzled_layout",
      [](int stride, int continuous, int element_size, bool k_inner,
-        bool allow_pad = true) {
-       if (allow_pad) {
+        bool use_hopper_layout = false) {
+       if (!use_hopper_layout) {
         return makeGemmABLayout(stride, continuous, continuous, element_size, k_inner);
       } else {
         return makeGemmABLayoutHopper(stride, continuous, continuous, element_size, k_inner);
       }
     })
tilelang/tileop/gemm/__init__.py (2)
82-98
: Add parameter validation for thread_nums.
While the FFI call handles the selection logic, it would be good to validate that thread_nums is positive before passing it to the FFI.
 def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst:
     """Select the appropriate GEMM instruction based on target and thread configuration.

     The selection logic follows this priority:
     1. WGMMA for Hopper architecture with sufficient matrix size and warp count
     2. MFMA for CDNA (AMD) architecture
     3. MMA for CUDA architecture
     4. Fallback to MMA for other cases

     Args:
         thread_nums: Number of threads in the block
         target: Target architecture

     Returns:
         GemmInst: The selected GEMM instruction type
     """
+    if thread_nums <= 0:
+        raise ValueError(f"thread_nums must be positive, got {thread_nums}")
     return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target))
118-118
: Consider using a more specific exception message.
The error message could be more informative by including what implementations are available.
- raise NotImplementedError("MFMA is not implemented")
+ raise NotImplementedError("MFMA is not implemented. Available implementations: MMA, WGMMA")
src/target/ptx.cc (2)
1053-1168
: Complex operand generation but missing a_is_shared validation for register path.
The function generates WGMMA operands correctly but doesn't validate that when a_is_shared is false, the operation is actually supported for register-based A operands.
Consider adding validation:
 inline std::tuple<std::string, std::string, std::string, std::string>
 GetWGMMAOperands(int m, int n, int k, ptx::DataType dtype_a,
                  ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse,
                  bool a_is_shared) {
+  // WGMMA with register-based A operand has limitations
+  if (!a_is_shared) {
+    // Add any specific validation for register-based A operands if needed
+    // based on NVIDIA documentation
+  }
   std::stringstream templates, inputs, outputs, predicate;
1263-1266
: Consider extracting predicate setup as a constant.
The predicate setup code could be defined as a constant string for better maintainability.
+ constexpr const char* PREDICATE_SETUP =
+     "{.reg .pred p;\n"
+     "setp.ne.b32 p, {predicate}, 0;\n";
+
 std::string asm_code = R"(
 {
   __asm__ __volatile__(
-    "{.reg .pred p;\n"
-    "setp.ne.b32 p, {predicate}, 0;\n"
+    ")" PREDICATE_SETUP R"(
     "wgmma.mma_async{.sparse}.sync.aligned{.shape}{.dtype}{.atype}{.btype}"
     "{templates};\n}"
     : {outputs}
     : {inputs});
 }
 )";
tilelang/language/customize.py (2)
160-179
: Deduplicate get_extent; current helper misses BufferLoad and can diverge.
You re-implement get_extent but omit BufferLoad handling (supported in tilelang/language/copy.py). Prefer reusing the canonical helper to avoid drift.
Apply:
@@
- def get_extent(data):
-     """
-     Return the inferred extent (shape) of a buffer-like object.
-     ...
-     """
-     if isinstance(data, Var) and T.has_let_value(data):
-         data = T.get_let_value(data)
-     if isinstance(data, Buffer):
-         return data.shape
-     elif isinstance(data, BufferRegion):
-         return [x.extent for x in data.region]
-     else:
-         return None
-
- src_extent = get_extent(value)
- dst_extent = get_extent(dst)
+ src_extent = _get_extent(value)
+ dst_extent = _get_extent(dst)
Also add the import near the top of this file:
from tilelang.language.copy import get_extent as _get_extent
82-105
: Honor provided extents in buffer_region_to_tile_region.
The extents parameter is only asserted but ignored. This prevents aligning extents when mixing BufferRegion with BufferLoad.
Apply:
@@
- return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
+ # Override trailing dims with requested extents if provided
+ if extents:
+     region_extents = list(region_extents)
+     region_extents[-len(extents):] = extents
+ return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
tilelang/language/allocate.py (1)
129-136
: Add type hints for API clarity.
Tighten signature and return type. Apply:
-def alloc_descriptor(dtype="uint64", scope="local.descriptor"):
+def alloc_descriptor(dtype: str = "uint64", scope: str = "local.descriptor") -> T.Buffer:
     """Allocate a descriptor buffer for wgmma and utcmma.

     Returns:
         T.Buffer: A TVM buffer object allocated as a descriptor
     """
     return T.alloc_buffer([1], dtype, scope=scope)
src/layout/layout.h (3)
163-167
: Prefer explicit enum over boolean for K placement.
A boolean is easy to misuse. Consider a strongly-typed enum (e.g., enum class KPlacement { Inner, Outer };) to make call sites self-documenting and prevent silent int→bool coercions.
Example:
enum class KPlacement { Inner, Outer };
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size, KPlacement k_place = KPlacement::Inner);
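The same idea transposed to Python, to show what the enum buys over a bare bool (all names here are hypothetical; the real API is the C++ makeGemmABLayout):

```python
from enum import Enum

class KPlacement(Enum):
    INNER = "inner"
    OUTER = "outer"

def make_gemm_ab_layout(stride: int, continuous: int,
                        k_place: KPlacement = KPlacement.INNER) -> dict:
    # Reject bools/ints instead of silently coercing them, which is the
    # failure mode the review warns about with a boolean parameter.
    if not isinstance(k_place, KPlacement):
        raise TypeError(f"k_place must be KPlacement, got {type(k_place).__name__}")
    return {"stride": stride, "continuous": continuous,
            "k_inner": k_place is KPlacement.INNER}

layout = make_gemm_ab_layout(64, 32, KPlacement.OUTER)
```

Call sites now read `KPlacement.OUTER` instead of an opaque `False`, and a stray `True` is rejected with a TypeError rather than silently coerced.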
168-169
: CDNA kPack rename verified; code updated, fix lingering doc/comments.
makeGemmABLayoutCDNA (declaration/definition) and callers use int kPack, and Python bindings expose "kPack".
Update remaining comment references to the old name:
- src/op/gemm_py.h: comment referencing "k_pack" (around line ~30).
- src/op/gemm_sp.h: similar comment (around line ~25).
177-178
: Volta layout: no functional change required — callers pass explicit k_inner; remove or flip default for clarity.
makeGemmVoltaABLayout(..., bool k_inner = true) is declared in src/layout/layout.h; call sites pass explicit values (src/op/gemm.cc:479, 492).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (30)
src/layout/gemm_layouts.cc (4 hunks)
src/layout/layout.cc (4 hunks)
src/layout/layout.h (2 hunks)
src/op/builtin.cc (2 hunks)
src/op/builtin.h (2 hunks)
src/op/gemm.cc (7 hunks)
src/op/gemm_py.cc (4 hunks)
src/op/gemm_py.h (1 hunks)
src/target/codegen_cuda.cc (5 hunks)
src/target/ptx.cc (6 hunks)
src/target/ptx.h (1 hunks)
src/tl_templates/cuda/common.h (4 hunks)
src/transform/lower_device_storage_access_info.cc (1 hunks)
src/transform/storage_rewrite.cc (2 hunks)
tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
tilelang/language/__init__.py (4 hunks)
tilelang/language/allocate.py (1 hunks)
tilelang/language/ast/ir.py (2 hunks)
tilelang/language/builtin.py (2 hunks)
tilelang/language/customize.py (10 hunks)
tilelang/language/tir/ir.py (1 hunks)
tilelang/language/tir/op.py (1 hunks)
tilelang/layout/__init__.py (1 hunks)
tilelang/layout/fragment.py (1 hunks)
tilelang/layout/layout.py (2 hunks)
tilelang/layout/swizzle.py (1 hunks)
tilelang/tileop/gemm/__init__.py (3 hunks)
tilelang/tileop/gemm/gemm_base.py (2 hunks)
tilelang/tileop/gemm/gemm_mma.py (2 hunks)
tilelang/tileop/gemm/gemm_wgmma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (23)
tilelang/language/tir/ir.py (2)
  tilelang/language/tir/op.py (2)
    ptx_wgmma_ss (1064-1103)
    ptx_wgmma_rs (1106-1143)
  tilelang/language/ast/ir.py (1)
    _dtype_forward (1876-1884)
tilelang/language/allocate.py (2)
  src/transform/storage_rewrite.cc (4)
    dtype (696-702)
    dtype (696-696)
    scope (674-678)
    scope (674-674)
  tilelang/language/ast/ir.py (1)
    alloc_buffer (441-508)
src/transform/lower_device_storage_access_info.cc (1)
  src/transform/storage_rewrite.cc (2)
    scope (674-678)
    scope (674-674)
tilelang/layout/layout.py (1)
  tilelang/layout/fragment.py (1)
    is_equal (209-213)
tilelang/language/ast/ir.py (2)
  tilelang/language/tir/op.py (2)
    ptx_wgmma_ss (1064-1103)
    ptx_wgmma_rs (1106-1143)
  tilelang/language/tir/ir.py (1)
    _dtype_forward (156-164)
tilelang/layout/swizzle.py (2)
  tilelang/language/ast/ir.py (1)
    buffer (93-161)
  src/layout/swizzle.h (1)
    tvm (12-70)
tilelang/tileop/gemm/gemm_wgmma.py (6)
  tilelang/tileop/gemm/gemm_base.py (17)
    GemmBase (12-120)
    infer_layout (15-16)
    policy (119-120)
    M (34-35)
    N (38-39)
    in_dtype (54-56)
    accum_dtype (59-60)
    trans_A (46-47)
    trans_B (50-51)
    chunk (63-64)
    is_gemm_ss (21-22)
    K (42-43)
    A (67-68)
    B (71-72)
    C (75-76)
    lower (18-19)
    clear_accum (107-108)
  tilelang/layout/swizzle.py (1)
    make_wgmma_swizzled_layout (22-34)
  tilelang/intrinsics/wgmma_macro_generator.py (6)
    TensorCoreIntrinEmitter (63-477)
    make_mma_store_layout (423-477)
    make_mma_load_layout (311-421)
    _assign_a_shared_layout (96-98)
    _assign_b_shared_layout (100-102)
    wgmma (145-233)
  tilelang/utils/language.py (2)
    is_shared (25-39)
    is_fragment (68-78)
  tilelang/transform/simplify.py (1)
    _Simplify (30-49)
  tilelang/tileop/gemm/gemm_mma.py (3)
    infer_layout (15-58)
    is_gemm_ss (204-205)
    lower (60-202)
tilelang/intrinsics/wgmma_macro_generator.py (6)
  tilelang/utils/language.py (1)
    is_fragment (68-78)
  tilelang/layout/swizzle.py (3)
    make_full_bank_swizzled_layout (39-59)
    make_half_bank_swizzled_layout (64-84)
    make_quarter_bank_swizzled_layout (89-109)
  tilelang/layout/fragment.py (4)
    is_equal (209-213)
    Fragment (13-213)
    replicate (146-160)
    repeat (123-144)
  tilelang/language/builtin.py (1)
    initialize_descriptor (355-386)
  tilelang/language/tir/op.py (2)
    ptx_wgmma_ss (1064-1103)
    ptx_wgmma_rs (1106-1143)
  tilelang/intrinsics/mma_macro_generator.py (1)
    get_store_index_map (159-165)
src/target/codegen_cuda.cc (4)
  src/transform/storage_rewrite.cc (2)
    scope (674-678)
    scope (674-674)
  tilelang/language/tir/op.py (2)
    ptx_wgmma_ss (1064-1103)
    ptx_wgmma_rs (1106-1143)
  src/target/ptx.cc (2)
    PrintWGMMAAssembly (1235-1306)
    PrintWGMMAAssembly (1235-1244)
  tilelang/language/builtin.py (2)
    initialize_descriptor (355-386)
    increase_descriptor_offset (389-411)
src/op/builtin.h (2)
  tilelang/language/tir/op.py (2)
    ptx_wgmma_ss (1064-1103)
    ptx_wgmma_rs (1106-1143)
  tilelang/language/builtin.py (2)
    initialize_descriptor (355-386)
    increase_descriptor_offset (389-411)
src/tl_templates/cuda/common.h (3)
  src/tl_templates/cuda/copy_sm90.h (1)
    void (255-258)
  src/tl_templates/cuda/ldsm.h (12)
    void (7-14)
    void (16-23)
    void (25-33)
    void (35-42)
    void (44-52)
    void (54-62)
    void (64-70)
    void (72-79)
    void (81-89)
    void (91-98)
    void (100-108)
    void (110-119)
  tilelang/language/builtin.py (2)
    initialize_descriptor (355-386)
    increase_descriptor_offset (389-411)
src/target/ptx.h (1)
  src/target/ptx.cc (2)
    PrintWGMMAAssembly (1235-1306)
    PrintWGMMAAssembly (1235-1244)
tilelang/language/builtin.py (3)
  src/op/builtin.h (1)
    tvm (13-363)
  tilelang/language/ast/ir.py (1)
    evaluate (1319-1331)
  tilelang/language/tir/op.py (1)
    call_intrin (119-144)
src/op/gemm.cc (2)
  tilelang/tileop/gemm/gemm_base.py (4)
    trans_A (46-47)
    trans_B (50-51)
    A (67-68)
    B (71-72)
  src/layout/gemm_layouts.cc (2)
    makeGemmABLayout (573-592)
    makeGemmABLayout (573-574)
tilelang/language/customize.py (1)
  tilelang/language/copy.py (1)
    get_extent (105-118)
src/op/builtin.cc (2)
  tilelang/language/tir/op.py (2)
    ptx_wgmma_ss (1064-1103)
    ptx_wgmma_rs (1106-1143)
  tilelang/language/builtin.py (2)
    initialize_descriptor (355-386)
    increase_descriptor_offset (389-411)
tilelang/tileop/gemm/gemm_mma.py (3)
  tilelang/tileop/gemm/gemm_wgmma.py (1)
    lower (64-125)
  tilelang/tileop/gemm/__init__.py (1)
    lower (76-80)
  tilelang/tileop/gemm/gemm_base.py (1)
    lower (18-19)
tilelang/language/__init__.py (1)
  tilelang/language/allocate.py (1)
    alloc_descriptor (129-135)
tilelang/layout/__init__.py (2)
  tilelang/layout/swizzle.py (5)
    make_swizzled_layout (10-18)
    make_wgmma_swizzled_layout (22-34)
    make_full_bank_swizzled_layout (39-59)
    make_half_bank_swizzled_layout (64-84)
    make_quarter_bank_swizzled_layout (89-109)
  tilelang/layout/gemm_sp.py (1)
    make_metadata_layout (98-109)
tilelang/tileop/gemm/__init__.py (4)
  tilelang/ir.py (1)
    GemmWarpPolicy (30-39)
  tilelang/tileop/gemm/gemm_mma.py (3)
    GemmMMA (13-214)
    lower (60-202)
    infer_layout (15-58)
  tilelang/tileop/gemm/gemm_wgmma.py (3)
    GemmWGMMA (13-137)
    lower (64-125)
    infer_layout (15-62)
  tilelang/tileop/gemm/gemm_base.py (2)
    lower (18-19)
    infer_layout (15-16)
tilelang/layout/fragment.py (1)
  tilelang/layout/layout.py (3)
    get_input_shape (59-68)
    get_output_shape (70-79)
    is_equal (136-145)
src/layout/layout.h (1)
  src/layout/gemm_layouts.cc (4)
    makeGemmABLayoutHopper (594-615)
    makeGemmABLayoutHopper (594-595)
    makeGemmABLayoutCDNA (617-625)
    makeGemmABLayoutCDNA (617-618)
src/layout/layout.cc (1)
  src/layout/gemm_layouts.cc (10)
    makeGemmABLayout (573-592)
    makeGemmABLayout (573-574)
    makeGemmABLayoutHopper (594-615)
    makeGemmABLayoutHopper (594-595)
    makeFullBankSwizzleLayout (375-391)
    makeFullBankSwizzleLayout (375-375)
    makeHalfBankSwizzleLayout (356-372)
    makeHalfBankSwizzleLayout (356-356)
    makeQuarterBankSwizzleLayout (336-353)
    makeQuarterBankSwizzleLayout (336-337)
🪛 Ruff (0.12.2)
tilelang/layout/swizzle.py
23-23: PEP 484 prohibits implicit Optional. Convert to Optional[T]. (RUF013)
54-54: Avoid specifying long messages outside the exception class. (TRY003)
79-79: Avoid specifying long messages outside the exception class. (TRY003)
104-104: Avoid specifying long messages outside the exception class. (TRY003)
tilelang/tileop/gemm/gemm_wgmma.py
61-62: Avoid specifying long messages outside the exception class. (TRY003)
124-125: Avoid specifying long messages outside the exception class. (TRY003)
tilelang/intrinsics/wgmma_macro_generator.py
104-104: Unused method argument: n_dim. (ARG002)
143-143: Avoid specifying long messages outside the exception class. (TRY003)
353-353: Avoid specifying long messages outside the exception class. (TRY003)
tilelang/language/builtin.py
375-375: Avoid specifying long messages outside the exception class. (TRY003)
377-377: Parenthesize "a and b" expressions when chaining "and" and "or" together, to make the precedence clear. Parenthesize the "and" subexpression. (RUF021)
378-378: Avoid specifying long messages outside the exception class. (TRY003)
401-401: Avoid specifying long messages outside the exception class. (TRY003)
403-403: Parenthesize "a and b" expressions when chaining "and" and "or" together, to make the precedence clear. Parenthesize the "and" subexpression. (RUF021)
404-404: Avoid specifying long messages outside the exception class. (TRY003)
tilelang/tileop/gemm/gemm_mma.py
60-60: Unused method argument: layout_map. (ARG002)
tilelang/language/__init__.py
45-45: Unused noqa directive (non-enabled: F401). Remove unused noqa directive. (RUF100)
tilelang/layout/__init__.py
7-7: Unused noqa directive (non-enabled: F401). Remove unused noqa directive. (RUF100)
8-8: Unused noqa directive (non-enabled: F401). Remove unused noqa directive. (RUF100)
9-9: Unused noqa directive (non-enabled: F401). Remove unused noqa directive. (RUF100)
10-10: Unused noqa directive (non-enabled: F401). Remove unused noqa directive. (RUF100)
11-11: Unused noqa directive (non-enabled: F401). Remove unused noqa directive. (RUF100)
13-13: Unused noqa directive (non-enabled: F401). Remove unused noqa directive. (RUF100)
tilelang/tileop/gemm/__init__.py
120-120: Avoid specifying long messages outside the exception class. (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (43)
tilelang/tileop/gemm/gemm_base.py (2)
8-8
: Add PrimExpr import for updated clear_accum type annotation. The import is correctly added to support the type annotation change.
107-108
: Update type annotation from bool to PrimExpr. The change from `bool` to `PrimExpr` aligns with the WGMMA backend requirements, where clear_accum can be an expression rather than just a boolean value. This provides more flexibility for the accumulation control logic.
src/op/gemm_py.h (1)
112-113
: Clean private section reorganization. The reorganization of private members with an explicit `private:` label improves code organization and readability. The `GemmInst` enum and `GetGemmInst` method remain appropriately private while allowing FFI access through wrapper functions.
src/transform/lower_device_storage_access_info.cc (1)
: Correct exclusion of descriptor scope from storage lowering.The addition of
scope.tag != ".descriptor"
properly excludes descriptor buffers from the generic storage access lowering path, which is necessary for the WGMMA descriptor handling. This change aligns with the corresponding exclusion instorage_rewrite.cc
and allows descriptor buffers to maintain their special semantics.tilelang/language/tir/ir.py (1)
294-295
: Add WGMMA PTX intrinsic wrappers.The new
ptx_wgmma_ss
andptx_wgmma_rs
wrappers correctly follow the established pattern using_dtype_forward
decorator, providing consistent API access to the underlying PTX WGMMA intrinsics for shared-shared and register-shared variants.tilelang/layout/fragment.py (2)
207-207
: Enhanced Fragment representation with shape information.The updated
__repr__
method now includes input and output shapes, providing more comprehensive debugging information. The use ofget_input_shape()
andget_output_shape()
follows the established pattern from the base Layout class.
209-213
: Add Fragment equality check method.The new
is_equal
method provides a proper way to compare Fragment instances by delegating to the FFI implementation. This aligns with the corresponding method in the base Layout class and enables proper equality testing for Fragment objects.src/transform/storage_rewrite.cc (2)
677-678
: Exclude.descriptor
from “special tagged memory” — OK.This prevents special-merge and memory‑info constraints from applying to descriptor buffers, which is desirable.
847-849
: Skip memory‑info path for.descriptor
in NewAllocTagMerged — OK; add test.Change correctly treats descriptor allocations neutrally (GetMemoryInfo is not invoked for
.descriptor
). Add a minimal IR test that ensures:
- local.descriptor buffers are not grouped into special merges
- no GetMemoryInfo lookups occur for
.descriptor
scopesRelevant files: src/transform/storage_rewrite.cc (≈lines 846–850), src/transform/lower_device_storage_access_info.cc (≈line 48).
tilelang/language/ast/ir.py (1)
1895-1896
: Approve: WGMMA wrappers wired and registeredWrapper bindings exist in tilelang/language/ast/ir.py (ptx_wgmma_ss/rs), TIR op wrappers call _tvm_op.Op.get("tl.ptx_wgmma_ss/rs") in tilelang/language/tir/op.py, builtins are registered in src/op/builtin.cc/.h (TIR_DEFINE_TL_BUILTIN + declarations) and referenced by codegen in src/target/codegen_cuda.cc — no issues found. Add a small CI check to assert builtin registration if desired.
src/op/builtin.cc (2)
174-183
: Descriptor intrinsics API consistent with wrappers.5 inputs for initialize and 2 for increase align with the Python layer; opaque effect is appropriate given side effects. LGTM.
83-92
: WGMMA intrinsics: 15-arg arity verified; codegen enforces 15 operands. Python wrappers define ptx_wgmma_ss/rs and builtin registers .set_num_inputs(15); src/target/codegen_cuda.cc contains ICHECK_EQ(op->args.size(), 15U) for both handlers — operand alignment confirmed.src/op/gemm_py.cc (3)
109-111
: Unreachable-return safety: fine as-is.This satisfies compilers post-ICHECK; harmless.
255-257
: Unreachable-return safety: fine as-is.Consistent with the ICHECK case; no functional impact.
284-290
: FFI reflection for GemmInst selection: LGTM.Simple, side‑effect free binding; useful for runtime dispatch.
src/op/gemm.cc (1)
481-481
: LGTM! Consistent layout API migration to boolean flags.The transition from numeric codes (1/2) to boolean flags for k_inner parameter across all backends (Volta, Ampere/Turing/SM120, Hopper) improves API clarity and type safety.
Also applies to: 507-507, 520-520, 545-545, 547-547, 563-563, 565-565
src/target/codegen_cuda.cc (5)
898-898
: Including descriptor scope in pointer access logic. The addition of the `"local.descriptor"` scope alongside `"local.var"` is correct for treating descriptors as local objects.
1788-1789
: LGTM! Proper allocation of descriptor storage. The allocation of `tl::GmmaDescriptor` for descriptor scope is correctly implemented.
1305-1340
: Check for additional error conditions in WGMMA intrinsics. The WGMMA SS intrinsic implementation looks correct, but consider adding validation for the descriptor parameters. Should we verify that the descriptors are properly initialized before use? Consider adding checks to ensure `a_desc` and `b_desc` are valid descriptor handles.
1701-1713
: LGTM! Descriptor operations correctly implemented. The implementations of `initialize_descriptor` and `increase_descriptor_offset` are correct and properly forward all parameters to the TL template functions. Also applies to: 1714-1721
1823-1823
: Correct handling of descriptor scope in allocation check. The condition properly excludes `local.descriptor` from the unsupported-scope error path.
tilelang/language/builtin.py (1)
355-386: LGTM! Well-documented descriptor initialization function. The `initialize_descriptor` function is well implemented, with proper type checking, parameter documentation, and error handling.
tilelang/layout/swizzle.py (2)
10-18: LGTM! Well-structured swizzle layout functions. The updated `make_swizzled_layout` and the new `make_wgmma_swizzled_layout` functions are properly implemented with clear parameter forwarding to the FFI API. Also applies to: 22-34
39-109: LGTM! Flexible bank-swizzled layout helpers. The three bank-swizzled layout functions (full/half/quarter) are well implemented, with flexible argument handling that supports both buffer objects and explicit parameters.
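For readers unfamiliar with the "buffer or explicit params" convenience pattern the helpers use, here is a generic sketch; the function name, parameter names, and the fake buffer class are illustrative stand-ins, not the actual tilelang signatures.

```python
# Sketch of an overloaded layout helper: accept either a buffer-like
# object (deriving shape/element size from it) or explicit parameters.
def make_bank_swizzled_layout(buffer=None, *, stride=None, continuous=None,
                              element_size=None):
    if buffer is not None:
        # Derive the parameters from a 2-D buffer-like object.
        stride, continuous = buffer.shape
        element_size = buffer.element_size
    if None in (stride, continuous, element_size):
        raise ValueError("need a buffer or explicit stride/continuous/element_size")
    return (stride, continuous, element_size)

class FakeBuffer:
    """Minimal stand-in for a tvm.tir.Buffer used only for this demo."""
    shape = (64, 128)
    element_size = 2
```

Both calling conventions then resolve to the same parameter triple, which is the property the review is praising.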
src/tl_templates/cuda/common.h (1)
307-367: LGTM! Well-structured descriptor union implementation. The `GmmaDescriptor` union is properly designed, with:
- multiple access patterns via `desc_`, `reg32_[]`, and `reg16_[]`
- a clear bitfield layout matching the CUDA GMMA descriptor format
- proper copy and move semantics
- a convenient arithmetic operator for offset adjustments
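For intuition, the 64-bit descriptor word can be modeled in a few lines. The field widths and bit positions below are assumptions following the CUTLASS GmmaDescriptor convention (14-bit address/offset fields, 3-bit base offset, 2-bit layout type), not values taken from this PR.

```python
# Illustrative sketch only: packing the descriptor fields described above
# into a single 64-bit word. All byte quantities are pre-shifted by 4
# (i.e. expressed in 16-byte units), mirroring the ">> 4" in the C++ code.
def pack_gmma_descriptor(start_address, leading_byte_offset,
                         stride_byte_offset, base_offset, layout_type):
    desc = 0
    desc |= (start_address       & 0x3FFF)        # bits  0-13
    desc |= (leading_byte_offset & 0x3FFF) << 16  # bits 16-29
    desc |= (stride_byte_offset  & 0x3FFF) << 32  # bits 32-45
    desc |= (base_offset         & 0x7)    << 49  # bits 49-51
    desc |= (layout_type         & 0x3)    << 62  # bits 62-63
    return desc

def unpack_layout_type(desc):
    # Recover the 2-bit layout-type field from the packed word.
    return (desc >> 62) & 0x3
```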
src/layout/layout.cc (2)
461-478: LGTM! Equality-check methods properly implemented. The new `Layout_is_equal` and `Fragment_is_equal` FFI bindings correctly expose the underlying equality-check functionality with proper node casting.
506-511: LGTM! WGMMA swizzled layout properly wired. The `make_wgmma_swizzled_layout` correctly passes the continuity parameter through separately from `mat_continuous`, which is essential for WGMMA's layout requirements.
tilelang/tileop/gemm/gemm_wgmma.py (1)
36-37: Verify the continuity calculation for k-major layouts. tilelang/tileop/gemm/gemm_wgmma.py:36-37,50 — current code: `a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp` and `b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp`. Confirm whether the k-major branch should intentionally use `self.M`/`self.N` (instead of a K-derived continuity) and whether the `4 *` factor is correct; if intentional, add an inline comment explaining the rationale and a unit test, otherwise correct the formula to derive continuity from K.
tilelang/intrinsics/wgmma_macro_generator.py (1)
220-231: Use 64-bit arithmetic for shared-memory offset calculations. `A_offset`/`B_offset` in tilelang/intrinsics/wgmma_macro_generator.py (lines 220-231) perform multiplications like `i * 64 * A_buf.shape[-1]` that can overflow 32-bit arithmetic; ensure intermediate arithmetic uses 64 bits (cast the PrimExprs to int64, or add an explicit bounds check) before multiplying by `elems_in_bytes` and passing to `ptx_wgmma_ss`.
tilelang/tileop/gemm/__init__.py (6)
1-1: LGTM! The `IntEnum` import is appropriate for the new `GemmInst` class. The import is necessary for creating the strongly typed enumeration that represents the different GEMM instruction types.
10-11: LGTM! Import additions align with the new WGMMA support. The imports of `GemmWGMMA` and `_ffi_api` are necessary for the new architecture-aware dispatch mechanism.
21-24: Good addition of the `layout_map` parameter to support WGMMA requirements. The parameter addition correctly passes the layout information through to the lower-level implementation.
27-42: LGTM! Well-designed enumeration for GEMM instruction types. The `GemmInst` enumeration provides a clean abstraction over the different GEMM implementations, with convenient helper methods for type checking.
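The described pattern can be sketched in a few lines; the member names, values, and helper names below are illustrative guesses, not necessarily those in tilelang/tileop/gemm/__init__.py.

```python
from enum import IntEnum

# Minimal sketch of an instruction-type enum with type-checking helpers.
class GemmInst(IntEnum):
    MMA = 0
    WGMMA = 1
    MFMA = 2

    def is_mma(self) -> bool:
        return self == GemmInst.MMA

    def is_wgmma(self) -> bool:
        return self == GemmInst.WGMMA
```

Because `IntEnum` members compare equal to plain ints, this kind of enum can also be passed across an FFI boundary that expects an integer code.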
71-81: Architecture-aware dispatch implementation looks good. The new `infer_layout` and `lower` methods properly delegate to architecture-specific implementations using a clean dispatch pattern.
100-120: LGTM! Clean dispatch with proper error handling. The method provides a clear mapping from instruction types to implementation classes, with appropriate error handling for unsupported types.
src/target/ptx.cc (7)
149-156: LGTM! Useful helper function for integer type checking. It provides a convenient way to check whether a PTX data type represents an integer, which is needed for WGMMA scale-input validation.
188-197: LGTM! Appropriate layout-type conversion helper. The function correctly maps boolean values to layout types for WGMMA assembly generation.
278-291: Well-structured WGMMA configuration struct. The struct follows the same pattern as `MMAConfig` and provides proper equality comparison for configuration validation.
293-720: Comprehensive WGMMA configuration table. The extensive table covers all supported WGMMA operations, including dense and sparse variants across multiple data types. This aligns well with the PR objective of supporting int8, fp8, and tf32.
859-879: Solid configuration validation function. The function properly validates WGMMA configurations by checking data-type compatibility and searching for a matching entry in the list of valid configurations.
1119-1121: Fatal error for sparse WGMMA is reasonable. Since sparse WGMMA is not yet implemented, failing fast with a clear error message is the appropriate approach.
1235-1306: Well-structured WGMMA assembly generation. The function properly generates PTX assembly for WGMMA operations, with support for various configurations including scale parameters and layout options. The template-replacement approach is clean and maintainable.
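The template-replacement style referred to here can be sketched as below. This is a simplified stand-in for the C++ Replacer, and the placeholder names and instruction string are illustrative rather than the exact template used in ptx.cc.

```python
# Substitute placeholder keys in an assembly template in one pass.
def replace_all(template: str, substitutions: dict) -> str:
    for key, value in substitutions.items():
        template = template.replace(key, value)
    return template

asm = replace_all(
    "wgmma.mma_async.sync.aligned.{shape}.{dtype} {c}, {a}, {b};",
    {"{shape}": "m64n128k16", "{dtype}": "f32.f16.f16",
     "{c}": "acc", "{a}": "desc_a", "{b}": "desc_b"},
)
```

The appeal of this approach is that the instruction skeleton stays readable as one string, while the per-configuration pieces (shape, dtypes, operands) are filled in from validated configuration data.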
tilelang/language/allocate.py (1)
129-136
: New descriptor allocator looks good. Minimal, and aligned with the ".descriptor" scope handling in codegen. LGTM.
int element_size, bool k_inner = true);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
int continuity, int element_size,
bool k_inner = true);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
🧩 Analysis chain
Signature change to `bool k_inner`: audit all call sites
Converting `int kfactor` to `bool k_inner` changes semantics; old int literals (e.g., 2) now coerce to `true`. Please verify that all callers pass a boolean with the intended meaning.
Run:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Show call sites with context
rg -nPU -C3 --type=cpp --type=cc --type=cxx --type=h --type=cu '\bmakeGemmABLayout\s*\('
rg -nPU -C3 --type=cpp --type=cc --type=cxx --type=h --type=cu '\bmakeGemmABLayoutHopper\s*\('
# Heuristic: calls with 5+ args where last arg is not clearly true/false or named
rg -nPU --type=cpp --type=cc --type=cxx --type=cu '\bmakeGemm(ABLayout|ABLayoutHopper)\s*\((?:[^,]*,){4}\s*([^,)]+)\)' -r '$1' | grep -v -E '^\s*(true|false)\s*$' || true
Length of output: 26
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "1) Listing raw matches (context):"
if command -v rg >/dev/null 2>&1; then
rg -n -U -P '\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(' -C3 --hidden || true
else
grep -nR --line-number -E '\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(' . || true
fi
echo
echo "2) Extracting full call expression and printing the last argument (best-effort):"
python - <<'PY'
import re,subprocess
pat=re.compile(r'\b(makeGemmABLayout(?:Hopper|CDNA)?)\s*\(')
# try rg first, fallback to grep
try:
out=subprocess.check_output(['rg','-n','-U','-P',r'\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\('],stderr=subprocess.DEVNULL).decode()
except Exception:
try:
out=subprocess.check_output(['grep','-nR','-E',r'\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(','./'],stderr=subprocess.DEVNULL).decode()
except Exception:
out=''
for line in out.splitlines():
parts=line.split(':',2)
if len(parts)<3:
continue
path,lineno,content=parts[0],int(parts[1]),parts[2]
try:
with open(path,'r',encoding='utf-8') as f:
data=f.read()
except Exception:
continue
lines=data.splitlines()
idx=max(0,lineno-1)
window = '\n'.join(lines[idx:idx+200])
m=pat.search(window)
if not m:
continue
start_idx=m.end()-1
depth=0
end_pos=None
for i,ch in enumerate(window[start_idx:], start=start_idx):
if ch=='(':
depth+=1
elif ch==')':
depth-=1
if depth==0:
end_pos=i
break
call_text = window[m.start(): end_pos+1] if end_pos is not None else window[m.start():m.start()+400]
inside = call_text[call_text.find('(')+1: (call_text.rfind(')') if call_text.rfind(')')>call_text.find('(') else None )]
inside = re.sub(r'/\*.*?\*/|//.*?$','',inside,flags=re.S|re.M)
# split args by commas not inside nested constructs
args=[]
cur=''
depth=0
for ch in inside:
if ch in '([{<':
depth+=1
elif ch in ')]}>':
depth-=1
if ch==',' and depth==0:
args.append(cur.strip()); cur=''
else:
cur+=ch
if cur.strip():
args.append(cur.strip())
last = args[-1] if args else ''
# simplify last token
last_token = last.strip().split()[-1] if last.strip() else ''
print(f"{path}:{lineno}: func={m.group(1)} last_arg_raw='{last}' last_arg_token='{last_token}'")
PY
Length of output: 9413
Critical: callers passing an integer k-factor (1/2) must be fixed; makeGemmABLayout/Hopper now takes `bool k_inner`
src/op/gemm_sp.cc passes integer ternaries (`trans_A ? 1 : 2` and `trans_B ? 2 : 1`) at src/op/gemm_sp.cc:196-198 and src/op/gemm_sp.cc:209-211; both non-zero values now convert to `true`, breaking the intended semantics. Change these calls to pass an explicit boolean with the intended meaning (e.g., use `trans_A`/`trans_B` or `(expr) == 1`), or restore an int overload if the 1/2 semantics are required. Other callers (src/op/gemm.cc, src/layout/*.cc/h) were checked and pass booleans.
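The hazard is easy to demonstrate in any language with implicit int-to-bool conversion; the Python sketch below mirrors the C++ behavior (a non-zero int converts to `true`), with hypothetical values standing in for the call sites.

```python
# The old call sites pass 1 or 2 to select a layout variant, but after the
# signature change both coerce to the same boolean.
old_kfactors = [1, 2]                      # e.g. trans_A ? 1 : 2
as_bool = [bool(v) for v in old_kfactors]

# Both variants now look identical to the callee:
distinction_lost = as_bool[0] == as_bool[1]

# The suggested fix passes the predicate itself, preserving the distinction:
k_inner_fixed = [v == 2 for v in old_kfactors]
```

In C++ the same thing happens silently at the call site, which is why the review asks for an explicit boolean expression rather than relying on the implicit conversion.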
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
                        T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined()); |
🛠️ Refactor suggestion
🧩 Analysis chain
FFI lower signature now includes layout_map; ensure all providers updated.
The C++ call passes T.layout_map. Verify the Python FFI target “tl.gemm_py.lower” accepts (gemm_py, layout_map, target, thread_bounds, thread_var) everywhere to avoid runtime errors.
🏁 Script executed:
#!/bin/bash
# Confirm the new signature at Python side
rg -nP 'def\s+lower\((self|gemm_py)[^)]*layout_map' -C2
# Find the FFI registration or dispatcher for "tl.gemm_py.lower"
rg -n 'gemm_py\.lower' -S
# Check base class signatures to avoid LSP mismatches
rg -nP 'class\s+GemmBase\b.*?\n' -n tilelang/tileop/gemm -C2
rg -nP 'def\s+lower\(' tilelang/tileop/gemm -S
Length of output: 2364
Add layout_map to GemmBase.lower
Change tilelang/tileop/gemm/gemm_base.py:18 from
def lower(self, target, thread_nums, thread_var)
to
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var)
so the base signature matches the FFI tl.gemm_py.lower and existing subclass implementations.
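The suggested base-class shape can be sketched as follows; `target` and `thread_var` are left untyped here because the real `Target`/`tir.Var` types are TVM objects, and the class bodies are illustrative, not the actual tilelang implementations.

```python
# Sketch: base class and a subclass that keep the same positional order
# as the FFI call (gemm_py, layout_map, target, thread_bounds, thread_var).
class GemmBase:
    def infer_layout(self, target, thread_nums):
        raise NotImplementedError

    def lower(self, layout_map: dict, target, thread_nums: int, thread_var):
        raise NotImplementedError

class GemmWGMMA(GemmBase):
    def lower(self, layout_map, target, thread_nums, thread_var):
        # A real implementation would consult layout_map when emitting
        # WGMMA intrinsics; here we just echo the dispatch result.
        return ("wgmma", sorted(layout_map))
```

Keeping the extra parameter in the base signature avoids LSP mismatches between the base class and the WGMMA subclass.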
/*!
 * \brief Print WGMMA assembly string given parameters.
 * \param shape The shape string mMnNkK
 * \param A_layout The layout of multiplicand A, can be either "row" or "col".
 * \param B_layout The layout of multiplicand B, can be either "row" or "col".
 * \param A_dtype The data type of multiplicand A.
 * \param B_dtype The data type of multiplicand B.
 * \param C_dtype The data type of multiplicand C.
 */
std::string PrintWGMMAAssembly(
    const std::string &shape, const bool &A_layout, const bool &B_layout,
    const std::string &A_dtype, const std::string &B_dtype,
    const std::string &C_dtype, const std::string &a_desc,
    const std::string &A_offset, const std::string &b_desc,
    const std::string &B_offset, const std::string &c_ptr,
    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
    const bool &scale_in_b, const bool &a_is_shared,
    const std::string &metadata, const std::string &metadata_offset,
    const std::string &sparsity_selector, bool sparse);
Documentation incomplete and misaligned parameters.
The documentation comment only describes the first 6 parameters, but the function signature has 17 parameters. This incomplete documentation makes it difficult to understand the purpose of the remaining parameters.
Please update the documentation to include all parameters:
/*!
* \brief Print WGMMA assembly string given parameters.
* \param shape The shape string mMnNkK
* \param A_layout The layout of multiplicand A, can be either "row" or "col".
* \param B_layout The layout of multiplicand B, can be either "row" or "col".
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
+ * \param a_desc Descriptor for operand A (shared memory descriptor or register pointer).
+ * \param A_offset Offset for operand A.
+ * \param b_desc Descriptor for operand B.
+ * \param B_offset Offset for operand B.
+ * \param c_ptr Pointer to accumulator C.
+ * \param c_offset Offset for accumulator C.
+ * \param scale_out Scaling output flag.
+ * \param scale_in_a Scaling input A flag.
+ * \param scale_in_b Scaling input B flag.
+ * \param a_is_shared Whether operand A is in shared memory.
+ * \param metadata Pointer to metadata buffer (for sparse operations).
+ * \param metadata_offset Offset in metadata buffer.
+ * \param sparsity_selector Sparsity selector for sparse operations.
+ * \param sparse Whether this is a sparse WGMMA operation.
*/
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
/*!
 * \brief Print WGMMA assembly string given parameters.
 * \param shape The shape string mMnNkK
 * \param A_layout The layout of multiplicand A, can be either "row" or "col".
 * \param B_layout The layout of multiplicand B, can be either "row" or "col".
 * \param A_dtype The data type of multiplicand A.
 * \param B_dtype The data type of multiplicand B.
 * \param C_dtype The data type of multiplicand C.
 * \param a_desc Descriptor for operand A (shared memory descriptor or register pointer).
 * \param A_offset Offset for operand A.
 * \param b_desc Descriptor for operand B.
 * \param B_offset Offset for operand B.
 * \param c_ptr Pointer to accumulator C.
 * \param c_offset Offset for accumulator C.
 * \param scale_out Scaling output flag.
 * \param scale_in_a Scaling input A flag.
 * \param scale_in_b Scaling input B flag.
 * \param a_is_shared Whether operand A is in shared memory.
 * \param metadata Pointer to metadata buffer (for sparse operations).
 * \param metadata_offset Offset in metadata buffer.
 * \param sparsity_selector Sparsity selector for sparse operations.
 * \param sparse Whether this is a sparse WGMMA operation.
 */
std::string PrintWGMMAAssembly(
    const std::string &shape, const bool &A_layout, const bool &B_layout,
    const std::string &A_dtype, const std::string &B_dtype,
    const std::string &C_dtype, const std::string &a_desc,
    const std::string &A_offset, const std::string &b_desc,
    const std::string &B_offset, const std::string &c_ptr,
    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
    const bool &scale_in_b, const bool &a_is_shared,
    const std::string &metadata, const std::string &metadata_offset,
    const std::string &sparsity_selector, bool sparse);
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 68 to 86, the function comment documents only
the first six parameters while the declaration has 17; update the block comment
to include @param entries for every parameter in the signature (A_layout,
B_layout — explain what true/false means or that they represent "row"/"col";
a_desc and b_desc — describe descriptor strings for A/B; A_offset and B_offset —
offset expressions or registers for A/B; c_ptr and c_offset — pointer and offset
used for C output; scale_out, scale_in_a, scale_in_b — boolean flags controlling
scaling behavior; a_is_shared — whether A is in shared memory; metadata and
metadata_offset — metadata pointer/string and its offset; sparsity_selector —
sparsity control selector; sparse — whether sparse mode is enabled), keeping
each description concise and aligned with the existing doc style.
src/target/ptx.h
Outdated
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,
🛠️ Refactor suggestion
Inconsistent use of const references for boolean parameters.
The function signature uses `const bool &` for boolean parameters, which is unnecessary overhead for primitive types. This differs from the existing `PrintMMAAssembly` function, which passes booleans by value.
Apply this diff to pass booleans by value:
-std::string PrintWGMMAAssembly(
- const std::string &shape, const bool &A_layout, const bool &B_layout,
+std::string PrintWGMMAAssembly(
+ const std::string &shape, bool A_layout, bool B_layout,
const std::string &A_dtype, const std::string &B_dtype,
const std::string &C_dtype, const std::string &a_desc,
const std::string &A_offset, const std::string &b_desc,
const std::string &B_offset, const std::string &c_ptr,
- const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
- const bool &scale_in_b, const bool &a_is_shared,
+ const std::string &c_offset, bool scale_out, bool scale_in_a,
+ bool scale_in_b, bool a_is_shared,
const std::string &metadata, const std::string &metadata_offset,
const std::string &sparsity_selector, bool sparse);
📝 Committable suggestion
std::string PrintWGMMAAssembly(
    const std::string &shape, bool A_layout, bool B_layout,
    const std::string &A_dtype, const std::string &B_dtype,
    const std::string &C_dtype, const std::string &a_desc,
    const std::string &A_offset, const std::string &b_desc,
    const std::string &B_offset, const std::string &c_ptr,
    const std::string &c_offset, bool scale_out, bool scale_in_a,
    bool scale_in_b, bool a_is_shared,
    const std::string &metadata, const std::string &metadata_offset,
    const std::string &sparsity_selector, bool sparse);
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 78-79, the function signature uses `const bool
&A_layout` and `const bool &B_layout` which is unnecessary for primitive types;
change both parameters to pass by value (`bool A_layout, bool B_layout`) and
update any matching function declarations/definitions and callers to use value
semantics so signatures remain consistent with PrintMMAAssembly.
template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                     T *start_address) {
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;
  descriptor.bitfield.layout_type_ = layout_type;
  descriptor.bitfield.base_offset_ = 0;
  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}

template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
                                          T offset) {
  descriptor.reg32_[0] += (offset >> 4);
}
🛠️ Refactor suggestion
🧩 Analysis chain
Consider validating descriptor parameters.
The `initialize_descriptor` and `increase_descriptor_offset` functions directly manipulate hardware descriptor fields without validation. Consider adding assertions or checks for:
- Valid layout_type values (0-3)
- Byte offset alignment requirements
- Address alignment to 16-byte boundaries
🏁 Script executed:
#!/bin/bash
# Check if there are any existing validation patterns for descriptor initialization
ast-grep --pattern 'initialize_descriptor($$$) {
$$$
assert($$$)
$$$
}'
# Look for any documentation or comments about descriptor constraints
rg -A 3 -B 3 "descriptor.*align|align.*descriptor" --type cpp --type h
Length of output: 333
Validate GMMA descriptor parameters before writing hardware fields
These functions write descriptor bitfields with no checks; src/target/cuda.h documents a 64‑byte alignment requirement for tensor map descriptors. Add assertions/guards in src/tl_templates/cuda/common.h (initialize_descriptor and increase_descriptor_offset, ~lines 405–421):
- Assert layout_type is in the supported range (0–3).
- Assert start_address is 16-byte aligned (e.g., `((uintptr_t)start_address & 0xF) == 0`; the parentheses matter, since `==` binds tighter than `&` in C++).
- Assert leading_byte_offset, stride_byte_offset, and the offset passed to increase_descriptor_offset are multiples of 16 (since code uses >> 4) and fit the descriptor bitfield widths; also guard against reg32_[0] overflow when adding (offset >> 4).
Files to edit: src/tl_templates/cuda/common.h (initialize_descriptor, increase_descriptor_offset). Reference: src/target/cuda.h (tensor map descriptor alignment comment).
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 405 to 421, add validation before
writing descriptor fields: assert layout_type is within 0..3; assert
start_address is 16-byte aligned (((uintptr_t)start_address & 0xF) == 0); assert
leading_byte_offset and stride_byte_offset are multiples of 16 and within the
bitfield width limits documented for GmmaDescriptor; in
increase_descriptor_offset assert offset is a multiple of 16, compute delta =
offset >> 4 and check that descriptor.reg32_[0] + delta will not overflow the
destination bitfield (or reg32_) before applying it; if any check fails, either
trigger a fast-fail assertion or return/handle the error path so invalid values
are never written to hardware fields.
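A Python model of the suggested guards is sketched below. The 14-bit field width and the 16-byte granularity are assumptions mirroring the comment above (and the CUTLASS descriptor convention), not verified against the hardware specification.

```python
# Validate descriptor parameters before they would be packed into bitfields.
def validate_descriptor_params(start_address, layout_type,
                               leading_byte_offset, stride_byte_offset):
    assert 0 <= layout_type <= 3, "layout_type out of range"
    assert start_address % 16 == 0, "start address not 16-byte aligned"
    for off in (leading_byte_offset, stride_byte_offset):
        assert off % 16 == 0, "byte offset not a multiple of 16"
        # After the >> 4 shift the value must fit a (assumed) 14-bit field.
        assert (off >> 4) < (1 << 14), "offset exceeds 14-bit field"
    return True
```

In device code the same checks would be compile-time (`static_assert` on the template parameters) or debug-only assertions, since runtime checks in a hot GEMM loop are costly.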
    if not isinstance(descriptor, (BufferLoad, Buffer)):
        raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")

    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
Fix operator precedence ambiguity.
The condition mixes `and` and `or` operators without parentheses; since `and` binds tighter than `or` in Python, the Buffer type check does not guard the `descriptor.shape[0]` access, which can lead to unexpected behavior.
Apply this diff to fix the operator precedence:
- if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+ if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
The same issue exists in `increase_descriptor_offset`:
- if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+ if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
Also applies to: 403-403
🧰 Tools
🪛 Ruff (0.12.2)
377-377: Parenthesize `a and b` expressions when chaining `and` and `or` together, to make the precedence clear; parenthesize the `and` subexpression (RUF021)
🤖 Prompt for AI Agents
In tilelang/language/builtin.py around lines 377 and 403, the conditional mixes
"and" and "or" without parentheses causing ambiguous precedence; wrap the shape
checks in parentheses so the intent is clear — e.g. make the condition check
"isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1)" (and apply the same parenthesized grouping in
increase_descriptor_offset at line 403) to ensure the logical grouping is
correct.
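To see why the parentheses matter: `and` binds tighter than `or`, so the unparenthesized condition consults `descriptor.shape[0]` even when the `isinstance` test fails. A minimal, self-contained sketch (the `Shaped`/`NotShaped` classes are hypothetical stand-ins for the TVM buffer types, used only to make the precedence visible):

```python
# Hypothetical stand-ins for tvm.tir.Buffer; only the .shape attribute matters here.
class Shaped:
    def __init__(self, shape):
        self.shape = shape

class NotShaped:
    shape = [2]  # not a Shaped, but still has a .shape

def buggy_check(descriptor):
    # Parses as: (isinstance(...) and len(...) != 1) or descriptor.shape[0] != 1,
    # so .shape[0] is inspected even when descriptor is not a Shaped at all.
    return isinstance(descriptor, Shaped) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1

def fixed_check(descriptor):
    return isinstance(descriptor, Shaped) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1)

ok = Shaped([1])
bad = NotShaped()
print(buggy_check(ok), fixed_check(ok))    # False False  (both accept a 1-element 1-D buffer)
print(buggy_check(bad), fixed_check(bad))  # True False   (only the buggy version flags a non-Shaped)
```

Ruff's RUF021 flags exactly this pattern; the fix is purely a matter of grouping, not behavior for the intended Buffer case.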
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/target/ptx.h (1)
29-31: Missing headers for Replacer (compile blocker). Replacer uses std::vector and std::pair but <vector> (and ideally <utility>) aren't included.
Apply:
 #include <string>
 #include <tuple>
+#include <vector>
+#include <utility>

src/target/ptx.cc (1)
26-31: Missing <sstream> (compile blocker). std::stringstream is used but the header isn't included.
Apply:
 #include <algorithm>
+#include <sstream>
 #include <string>
 #include <tuple>
 #include <utility>
 #include <vector>
♻️ Duplicate comments (5)
src/target/ptx.h (2)
156-164: WGMMA Doxygen is incomplete and mismatched with the signature. Only the first 6 params are documented; names refer to A/B_layout strings while the API uses a_is_k_major/b_is_k_major and many more args.
Recommend adding @param docs for all parameters and aligning names/semantics with the signature (a_is_k_major, b_is_k_major, a_desc, A_offset, b_desc, B_offset, c_ptr, c_offset, scale_out, scale_in_a, scale_in_b, a_is_shared, metadata, metadata_offset, sparsity_selector, sparse). Want a ready-to-apply doc patch?
165-174: Pass booleans by value, not const references. Primitive bools shouldn't be passed by const&. Keep consistent with other APIs.
Apply:
-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse);

src/tl_templates/cuda/common.h (3)
438-454: Add basic parameter validation (layout_type range). Follow-up to earlier review on descriptor validation.
Add: assert(0 <= layout_type && layout_type <= 3).
450-454: increase_descriptor_offset risks overflow/carry into other fields. Directly adding to reg32_[0] can spill into adjacent fields; enforce 16B granularity and update the bitfield safely.
Apply:
 TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, T offset) {
-  descriptor.reg32_[0] += (offset >> 4);
+  assert((offset % 16) == 0);
+  uint32_t delta = uint32_t(offset >> 4);
+  uint32_t new_sa = uint32_t(descriptor.bitfield.start_address_) + delta;
+  // 14-bit field
+  assert((new_sa & ~0x3FFFu) == 0);
+  descriptor.bitfield.start_address_ = uint16_t(new_sa & 0x3FFFu);
 }
438-448: initialize_descriptor: missing >>4 for byte-offset fields; no validation. Bitfields exclude the 4 LSBs; current code writes raw byte offsets. Also lacks alignment/range checks.
Apply:
 TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, T *start_address) {
-  descriptor.bitfield.start_address_ =
-      cute::cast_smem_ptr_to_uint(start_address) >> 4;
+  auto sa = cute::cast_smem_ptr_to_uint(start_address);
+  // 16B alignment and 14-bit range
+  assert((sa & 0xF) == 0);
+  descriptor.bitfield.start_address_ = uint16_t((sa >> 4) & 0x3FFF);
   descriptor.bitfield.layout_type_ = layout_type;
   descriptor.bitfield.base_offset_ = 0;
-  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
-  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+  assert((leading_byte_offset % 16) == 0);
+  assert((stride_byte_offset % 16) == 0);
+  descriptor.bitfield.leading_byte_offset_ =
+      uint16_t(((leading_byte_offset >> 4) & 0x3FFF));
+  descriptor.bitfield.stride_byte_offset_ =
+      uint16_t(((stride_byte_offset >> 4) & 0x3FFF));
 }
Also assert layout_type in [0,3].
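The arithmetic being debated here can be checked on the host side. Below is a pure-Python model of the 16-byte-quantum encoding; the 14-bit field width follows the CUTLASS-style GmmaDescriptor layout assumed throughout this review and is an illustration, not authoritative hardware documentation:

```python
FIELD_MASK = 0x3FFF  # 14-bit descriptor fields, stored in 16-byte units

def encode_field(byte_value: int) -> int:
    """Encode a byte address/offset into a 14-bit field (4 LSBs stripped)."""
    assert byte_value % 16 == 0, "descriptor fields are in 16-byte quanta"
    encoded = byte_value >> 4
    assert encoded <= FIELD_MASK, "value exceeds the 14-bit field"
    return encoded

def increase_start_address(encoded_start: int, byte_offset: int) -> int:
    """Advance an already-encoded start address by a byte offset, with checks."""
    assert byte_offset % 16 == 0
    new_sa = encoded_start + (byte_offset >> 4)
    assert new_sa <= FIELD_MASK, "carry would spill into the adjacent bitfield"
    return new_sa

lbo = encode_field(128)                                 # 128 B -> 8 sixteen-byte atoms
sa = increase_start_address(encode_field(0x4000), 256)  # 0x4000 B start, +256 B
print(lbo, sa)  # 8 1040
```

This is also why adding a raw byte offset into the low register word (without the `>> 4`) mis-encodes by a factor of 16.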
🧹 Nitpick comments (14)
tilelang/language/tir/op.py (1)
1106-1144: Add a docstring for ptx_wgmma_rs; clarify the variant. Mirror the ss variant doc and note A is a buffer (not a descriptor).
Apply this diff:
 def ptx_wgmma_rs(
     dtype,
     wgmma_prefix,
     a_is_k_major,
     b_is_k_major,
     a_dtype_abbrv,
     b_dtype_abbrv,
     accum_dtype_abbrv,
     A_buf,
     A_offset,
     B_desc,
     B_offset,
     C_data,
     C_offset,
     scale_out,
     scale_in_a,
     scale_in_b,
 ):
-
-
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: rs (A uses buffer pointer, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """
Optional: consider renaming A_buf to A_data or A_ptr for consistency with C_data (only if no keyword callers rely on A_buf).
src/target/ptx.h (1)
45-69: Two DataType enums exist (ptx::DataType here and tl::DataType in common.h). Risk of drift and conversion friction.
Add explicit conversion helpers and static_assert ordinal equivalence in one place, or consolidate to a single enum declared in a shared header.
src/tl_templates/cuda/common.h (1)
18-19: Redundant SMEM pointer casters. Both cute::cast_smem_ptr_to_uint and local helpers (cast_smem_ptr_to_int / smem_ptr_to_uint) exist.
Pick one idiom and remove the duplicates to avoid confusion.
src/tl_templates/cuda/instruction/wgmma.h (1)
22-35: Don't rely on device printf; fail fast when unspecialized. Keep printf only under a debug guard; restore a compile-time error otherwise.
Apply:
-  ) {
-    printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, scaleB=%d\n",
-           (int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
-    // Temporarily comment out the static_assert to see the debug output
-    // static_assert(always_false_v<decltype(c)>,
-    //               "wgmma_ss: No specialization available for given template parameters!");
-  };
+  ) {
+#if defined(TL_DEBUG_WGMMA_FALLBACK)
+    printf("WGMMA fallback A=%d B=%d C=%d M=%d N=%d K=%d tnspA=%d tnspB=%d scaleA=%d scaleB=%d\n",
+           (int)A_type, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
+#else
+    static_assert(always_false_v<decltype(c)>,
+                  "wgmma_ss: No specialization available for given template parameters");
+#endif
+  };

src/target/ptx.cc (2)
171-177: Pass bool by value in LayoutTypeFromBool. Avoid const& for primitives.
Apply:
-LayoutType LayoutTypeFromBool(const bool &layout) {
+LayoutType LayoutTypeFromBool(bool layout) {
1186-1195: Unnecessary const& on bools in PrintWGMMAAssembly. Match the header change; pass by value.
Apply:
-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse) {
1306-1317
: Fix argument docs for tl::ptx_wgmma_ss branchThe comments don’t match the actual args (dtype is the call’s return type, not in
op->args
). Please correct to avoid future regressions.- // arg 0: dtype - // arg 1: shape - // arg 2: A_layout - // arg 3: B_layout - // arg 4: A_dtype - // arg 5: B_dtype - // arg 6: C_dtype - // arg 7: multiplicand_a - // arg 8: multiplicand_b - // arg 9: accumulator - // arg 10: saturate + // arg 0: wgmma_prefix (shape string, e.g. "m64n128k32") + // arg 1: a_is_k_major (bool) + // arg 2: b_is_k_major (bool) + // arg 3: a_dtype_abbrv (string) + // arg 4: b_dtype_abbrv (string) + // arg 5: accum_dtype_abbrv (string) + // arg 6: A_desc (descriptor, shared path) + // arg 7: A_offset (byte offset) + // arg 8: B_desc (descriptor) + // arg 9: B_offset (byte offset) + // arg 10: C_data (ptr) + // arg 11: C_offset (byte offset) + // arg 12: scale_out (bool) + // arg 13: scale_in_a (bool) + // arg 14: scale_in_b (bool)
1364-1375
: Align RS branch docs with actual argsMirror the corrected SS docs for RS; current comments mention “dtype/saturate”, which don’t exist here.
- // arg 0: dtype - // arg 1: shape - // arg 2: A_layout - // arg 3: B_layout - // arg 4: A_dtype - // arg 5: B_dtype - // arg 6: C_dtype - // arg 7: multiplicand_a - // arg 8: multiplicand_b - // arg 9: accumulator - // arg 10: saturate + // arg 0: wgmma_prefix (shape) + // arg 1: a_is_k_major (bool) + // arg 2: b_is_k_major (bool) + // arg 3: a_dtype_abbrv (string) + // arg 4: b_dtype_abbrv (string) + // arg 5: accum_dtype_abbrv (string) + // arg 6: A_buf (global pointer, non-shared path) + // arg 7: A_offset (byte offset) + // arg 8: B_desc (descriptor) + // arg 9: B_offset (byte offset) + // arg 10: C_data (ptr) + // arg 11: C_offset (byte offset) + // arg 12: scale_out (bool) + // arg 13: scale_in_a (bool) + // arg 14: scale_in_b (bool)
1724-1743: Descriptor intrinsics emission — consider an explicit offset type and statement-only usage.
- The template parameter for increase_descriptor_offset should be explicit-width to avoid ABI surprises on different toolchains.
- These are side-effect calls; ensure they're only used in EvaluateNode contexts.
-  os << "tl::increase_descriptor_offset<int>(" << PrintExpr(descriptor)
+  os << "tl::increase_descriptor_offset<int32_t>(" << PrintExpr(descriptor)
      << ", " << PrintExpr(offset) << ")";
Would you like me to scan call sites to confirm they're only used as statements?
tilelang/intrinsics/wgmma_macro_generator.py (5)
216-219: Use read access for descriptor base pointers. These descriptors are read by the instruction; use access_ptr("r"), not "w". Covered in the earlier diffs.
Also applies to: 283-283
104-109: Remove the unused n_dim parameter. _initialize_wgmma_prefix(self, n_dim: int = 16) doesn't use n_dim (Ruff ARG002). Simplify:
-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
 @@
-        self._initialize_wgmma_prefix(self.n_dim)
+        self._initialize_wgmma_prefix()
Also applies to: 94-95
333-333: Drop the duplicate import of is_fragment. Already imported at the top; remove the inner import.
-        from tilelang.utils import is_fragment
312-317: Fix docstring: this is a load layout for operand A, not a store layout.
-        Create a layout function for storing MMA results into a fragment buffer.
-        This layout is used in conjunction with `inverse_mma_store_layout` to
-        map fragment indices to threads and local indices.
+        Create a layout describing how to load MMA operand A from a fragment buffer.
+        This layout is used in conjunction with `inverse_mma_load_layout` to
+        map fragment indices to threads and per-thread local indices.
41-59: Clarify swizzle_atom_size() semantics or compute from bytes to avoid confusion. If you keep this helper, consider defining it as swizzle_byte_size() // 16 (16-byte atoms), to match descriptor-units terminology; the current // 16 on bits is non-obvious. Usage has been corrected in earlier diffs, but this improves readability:
 def swizzle_atom_size(self) -> int:
-    if self.is_swizzle_32b():
-        return 32 // 16
-    elif self.is_swizzle_64b():
-        return 64 // 16
-    elif self.is_swizzle_128b():
-        return 128 // 16
-    else:
-        return 1
+    # number of 16-byte atoms in the swizzle size (32B→2, 64B→4, 128B→8)
+    return self.swizzle_byte_size() // 16
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- src/op/builtin.h (2 hunks)
- src/target/codegen_cuda.cc (5 hunks)
- src/target/ptx.cc (9 hunks)
- src/target/ptx.h (2 hunks)
- src/tl_templates/cuda/common.h (5 hunks)
- src/tl_templates/cuda/gemm.h (1 hunks)
- src/tl_templates/cuda/instruction/wgmma.h (1 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
- tilelang/language/tir/op.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (5)
src/tl_templates/cuda/instruction/wgmma.h (2)
- src/op/builtin.h (1): tl (22-362)
- src/tl_templates/cuda/common.h (5): tl (306-400), DataType (315-360), int (97-100), int (135-142), uint32_t (118-120)
src/target/codegen_cuda.cc (4)
- src/transform/storage_rewrite.cc (6): scope (674-678), scope (674-674), n (366-370), n (366-366), n (371-375), n (371-371)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- src/target/ptx.cc (8): PrintWGMMAAssembly (1186-1257), PrintWGMMAAssembly (1186-1195), ParseMMAShape (140-148), ParseMMAShape (140-140), DTypeEnumToString (107-109), DTypeEnumToString (107-107), DTypeEnumToString (111-113), DTypeEnumToString (111-111)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)
src/tl_templates/cuda/common.h (3)
- src/target/ptx.h (1): DataType (45-91)
- src/tl_templates/cuda/ldsm.h (12): void (7-14), void (16-23), void (25-33), void (35-42), void (44-52), void (54-62), void (64-70), void (72-79), void (81-89), void (91-98), void (100-108), void (110-119)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)
tilelang/intrinsics/wgmma_macro_generator.py (6)
- tilelang/utils/language.py (1): is_fragment (68-78)
- tilelang/layout/swizzle.py (3): make_full_bank_swizzled_layout (39-59), make_half_bank_swizzled_layout (64-84), make_quarter_bank_swizzled_layout (89-109)
- tilelang/layout/fragment.py (3): Fragment (13-213), replicate (146-160), repeat (123-144)
- tilelang/language/builtin.py (1): initialize_descriptor (355-386)
- tilelang/language/tir/op.py (3): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143), any (1774-1790)
- tilelang/intrinsics/mma_macro_generator.py (1): get_store_index_map (159-165)
src/target/ptx.h (1)
- src/target/ptx.cc (10): DTypeFromString (54-104), DTypeFromString (54-54), DTypeEnumToString (107-109), DTypeEnumToString (107-107), DTypeEnumToString (111-113), DTypeEnumToString (111-111), ParseMMAShape (140-148), ParseMMAShape (140-140), PrintWGMMAAssembly (1186-1257), PrintWGMMAAssembly (1186-1195)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py
104-104: Unused method argument: n_dim
(ARG002)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
353-353: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (9)
tilelang/language/tir/op.py (1)
1064-1104: Change docstring from WMMA → WGMMA and update the link. File: tilelang/language/tir/op.py (ptx_wgmma_ss). The docstring currently says WMMA; it should reference WGMMA (warp-group). Verified tl.ptx_wgmma_* ops are registered and expect 15 args.
 def ptx_wgmma_ss(
 @@
-    """TVM intrinsic for ptx tensor core wmma instructions
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-wmma
-    """
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: ss (A uses descriptor, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """

src/tl_templates/cuda/instruction/wgmma.h (2)
417-452: INT8 coverage seems partial (only N=8/16). valid_wgmma_configs include N up to 256; if codegen may request N=32/64/… you'll hit the fallback at compile time.
Would you like me to generate the missing specializations (or a meta-emitter) for s8->s32 shapes N∈{32,64,96,128,192,256}?
457-491: FP8 E4M3→F16/F32: only N=8 variants present. Configs list N up to 256. Confirm the dispatcher never asks for larger N, or add the remaining specializations.
I can auto-generate these blocks to match the config table.
src/tl_templates/cuda/gemm.h (1)
5-6: Include order LGTM. Including wgmma.h in the sm90+ path before gemm_sm90.h is appropriate.
src/target/ptx.cc (1)
1222-1256: WGMMA operand builder: the sparse-path abort is fine; ensure scale_in use matches dtype. Scale inputs are emitted only for non-integer types; that matches the spec. No action needed.
Please confirm codegen never requests sparse WGMMA yet; otherwise return a diagnostic instead of LOG(FATAL) at higher layers.
src/tl_templates/cuda/common.h (1)
315-339: Enum duplication with ptx::DataType — parity unverified. No matches found for tvm::tl::codegen::ptx::DataType; ensure src/tl_templates/cuda/common.h (DataType, lines 315–339) has identical ordinals/values to the ptx definition. To verify locally, run:
rg -nP '\benum class DataType\b' -S
rg -nP 'ptx::DataType|namespace\s+ptx' -S

src/target/codegen_cuda.cc (2)
898-901: Local descriptor deref handling in GetBufferRef — LGTM. Returning the scalar vid for "local.descriptor" (same as "local.var") avoids bogus []/casts on tl::GmmaDescriptor. This matches the Python-side Buffer(shape=[1]) abstraction.
1810-1812: Ensure tl::GmmaDescriptor is visible from generated TUs. src/target/codegen_cuda.cc already emits "#include <tl_templates/cuda/gemm.h>" in Finish() (decl_stream at src/target/codegen_cuda.cc:186). No declaration of tl::GmmaDescriptor was found in the repo — confirm that gemm.h (or an included header like common.h) actually defines tl::GmmaDescriptor; if it does not, add an explicit include for the header that declares it in Finish().
tilelang/intrinsics/wgmma_macro_generator.py (1)
104-109: Validate inst_n in wgmma_prefix against supported shapes. Hardware supports specific N sizes (e.g., 8/16/32/64/128); inst_n = block_col_warps * warp_col_tiles may produce unsupported values. Add a guard or normalize.
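A guard along the lines this comment suggests could look like the sketch below; the supported-N list mirrors the sizes named above (8/16/32/64/128) and is an assumption for illustration, not the full PTX shape table:

```python
# Assumed subset of legal WGMMA N values; consult the PTX ISA for the real table.
SUPPORTED_WGMMA_N = (8, 16, 32, 64, 128)

def make_wgmma_prefix(inst_m: int, inst_n: int, inst_k: int) -> str:
    """Build an m{M}n{N}k{K} instruction prefix, failing fast on unsupported N."""
    if inst_n not in SUPPORTED_WGMMA_N:
        raise ValueError(
            f"unsupported WGMMA N={inst_n}; expected one of {SUPPORTED_WGMMA_N}")
    return f"m{inst_m}n{inst_n}k{inst_k}"

print(make_wgmma_prefix(64, 128, 16))  # m64n128k16
```

Failing at trace time with a named shape is much easier to debug than the unspecialized-template fallback discussed elsewhere in this review.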
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
  return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
  GmmaDescriptor ret;
  ret.reg32_[0] = reg32_[0] + uint32_t(offset);
  ret.reg32_[1] = reg32_[1];
  return ret;
}
operator+ adds raw bytes; descriptor fields are in 16‑byte quanta.
The descriptor encodes addresses/offsets with 4 LSBs stripped. Adding raw offset likely mis-encodes.
Apply:
- template <typename T>
- CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
+ template <typename T>
+ CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
GmmaDescriptor ret;
- ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+ ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
ret.reg32_[1] = reg32_[1];
return ret;
}
And document that offset is in bytes.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
  return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
  GmmaDescriptor ret;
  ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
  ret.reg32_[1] = reg32_[1];
  return ret;
}
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
Add divisibility assertion for K micro-tiles.
You iterate `range(k_dim // micro_size_k)` but don't enforce divisibility; tails would be silently dropped. Add:
- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
Apply in both SS and RS.
Also applies to: 255-256
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 168-169 (and
similarly around lines 255-256), you assert k_dim >= micro_size_k but do not
require k_dim to be divisible by micro_size_k, causing any tail to be silently
dropped when iterating range(k_dim // micro_size_k); add an assertion ensuring
k_dim % micro_size_k == 0 with a clear error message (e.g., "k_dim must be
divisible by micro_size_k, got k_dim: {k_dim}, micro_size_k: {micro_size_k}") in
both the SS and RS sections so the code fails fast instead of silently
truncating tails.
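The failure mode is easy to demonstrate in isolation: without the divisibility assert, `range(k_dim // micro_size_k)` silently drops the K tail instead of failing fast. A minimal sketch:

```python
def k_tile_starts(k_dim: int, micro_size_k: int):
    """Start offsets of the K micro-tiles, as the generator iterates them."""
    return [ki * micro_size_k for ki in range(k_dim // micro_size_k)]

print(k_tile_starts(64, 16))  # [0, 16, 32, 48] -> all 64 elements covered
print(k_tile_starts(72, 16))  # [0, 16, 32, 48] -> the last 8 elements silently skipped

def k_tile_starts_checked(k_dim: int, micro_size_k: int):
    # Fail fast instead of truncating, as the review suggests.
    assert k_dim % micro_size_k == 0, (
        f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}")
    return [ki * micro_size_k for ki in range(k_dim // micro_size_k)]
```

With the assert in place, a non-divisible K shape becomes a loud tracing-time error rather than a numerically wrong GEMM.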
elems_in_bytes = DataType(self.a_dtype).bits // 8

# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
                                                                       elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
                                                                          elems_in_bytes)
Fix LBO/SBO units and per-operand element sizes (incorrect descriptor math).
- You compute the element size once from A (elems_in_bytes) and reuse it for B. If a_dtype != b_dtype (e.g., int8 × int8, fp8, tf32), B's descriptor math is wrong.
- In swizzled MN-major paths you set LBO to swizzle_atom_size() (already in 16-B atoms) but still right-shift by 4 when initializing the descriptor, effectively dividing by 16 twice (LBO=0 for 32B/64B/128B). Use the swizzle byte size and keep the >> 4 conversion, or pass atom counts without shifting.
Apply this diff to make LBO/SBO correct and per‑operand:
- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+ b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
- a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
- elems_in_bytes)
- a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
- elems_in_bytes)
+ a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim *
+ a_elems_in_bytes)
+ a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 *
+ a_elems_in_bytes)
@@
- if not a_swizzle_mode.is_none():
+ if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if a_is_k_major:
a_leading_byte_offset = 16
else:
# MN Major
# LBO represents the distance between two atoms along the M dimension
# SBO represents the distance between two atoms along the K dimension
- a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size()
- a_stride_byte_offset = 8 * 64 * elems_in_bytes
+ a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
+ a_stride_byte_offset = 8 * 64 * a_elems_in_bytes
@@
- b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
- elems_in_bytes)
- b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
- elems_in_bytes)
+ b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+ b_elems_in_bytes)
+ b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+ b_elems_in_bytes)
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
else:
# MN Major
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
- b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
- b_stride_byte_offset = 8 * n_dim * elems_in_bytes
+ b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+ b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
- T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
+ T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
- T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+ T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
- accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4,
- desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data,
+ accum_dtype_abbrv, desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+ desc_b.data, (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data,
Also applies to: 184-195, 196-211, 216-219, 229-231
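The per-operand point above can be checked numerically. The sketch below derives the element size per dtype and reuses the non-swizzled LBO/SBO formulas from the quoted snippet; the dtype→bits table is an assumption for illustration:

```python
# Assumed dtype widths; the real generator gets these from tvm's DataType.
DTYPE_BITS = {"float16": 16, "bfloat16": 16, "int8": 8, "float8_e4m3": 8, "float32": 32}

def lbo_sbo(dtype: str, k_dim: int, mn_dim: int, is_k_major: bool):
    """Non-swizzled leading/stride byte offsets for one operand."""
    elems_in_bytes = DTYPE_BITS[dtype] // 8
    lbo = 8 * 8 * elems_in_bytes if is_k_major else 8 * mn_dim * elems_in_bytes
    sbo = 8 * k_dim * elems_in_bytes if is_k_major else 8 * 8 * elems_in_bytes
    return lbo, sbo

# Hypothetical mixed-precision case: fp16 operand A, int8 operand B.
# Reusing A's element size for B would double B's offsets.
print(lbo_sbo("float16", 32, 64, True))  # (128, 512)
print(lbo_sbo("int8", 32, 64, True))     # (64, 256)
```

This is exactly the factor-of-two (or factor-of-four for fp32 vs int8) error the review warns about when a single shared elems_in_bytes is used for both descriptors.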
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
Mirror the fixes in the RS path (per-operand sizes, swizzle LBO units, pointer mode).
Do the same corrections for RS:
- Use separate a_elems_in_bytes / b_elems_in_bytes.
- Use swizzle_byte_size() for LBO and keep the >> 4 conversion.
- Use access_ptr("r") for descriptors.
Also add the divisibility assert (see separate comment). Apply:
- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
@@
- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+ b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
- b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
- elems_in_bytes)
- b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
- elems_in_bytes)
+ b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+ b_elems_in_bytes)
+ b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+ b_elems_in_bytes)
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
else:
# MN Major
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
- b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
- b_stride_byte_offset = 8 * n_dim * elems_in_bytes
+ b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+ b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
- T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+ T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
- (B_offset * elems_in_bytes) >> 4,
+ (B_offset * b_elems_in_bytes) >> 4,
Also applies to: 257-278, 283-284, 301-302
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 255-256 (and
similarly apply changes to ranges 257-278, 283-284, 301-302), the RS path
currently mirrors the MS fixes incorrectly: replace any single shared
elems_in_bytes with separate a_elems_in_bytes and b_elems_in_bytes computed from
their per‑operand element sizes; for LBO swizzle sizes call
swizzle_byte_size(...) and retain the bitshift >> 4 as before; change descriptor
pointer creation to use access_ptr("r") instead of the current mode; and add the
divisibility assert that ensures k_dim (or relevant dim) is divisible by
micro_size_k (matching the other path). Ensure these edits are applied
consistently across the listed line ranges.
accum_dtype,
wgmma_prefix,
self.a_transposed,
not self.b_transposed,
a_dtype_abbrv,
🧩 Analysis chain
Likely wrong K‑major flags passed to RS intrinsic (inverted vs SS).
ptx_wgmma_rs takes the same a_is_k_major / b_is_k_major semantics as SS. SS passes not self.a_transposed and self.b_transposed, but RS passes their negations. Align them:
- self.a_transposed,
- not self.b_transposed,
+ (not self.a_transposed),
+ self.b_transposed,
Please confirm against the PTX WGMMA docs for RS.
🌐 Web query:
What are the `a_is_k_major` / `b_is_k_major` semantics for `wgmma.mma_async` RS vs SS variants in NVIDIA PTX?
💡 Result:
Short answer: those flags tell the instruction whether the A / B operand’s shared‑memory descriptor (SS) — or the register‑sourced operand’s logical layout (RS) — is K‑major (i.e. the 8×2 / 2×(N/8) “core‑matrix” tile stride is along the K dimension) or not. Concretely:
- SS variants (the .ss forms) take 64-bit matrix descriptors for A and/or B; the descriptor's "major-K" bit(s) (a_is_k_major / b_is_k_major) mean "this SMEM tile is laid out K-major", so the hardware interprets LBO/SBO in the descriptor accordingly (core matrices are arranged with K as the strided direction). Use these bits when your shared-memory tile is stored with K as the leading/strided dimension. [1][2]
- RS variants (the .rs forms) source one operand from registers (or a thread-local fragment) and the other from a descriptor; for RS the a_is_k_major / b_is_k_major semantic still denotes whether that operand's logical tile is K-major, but for the register-sourced operand the layout interpretation is applied to the fragment/register mapping (i.e. whether the fragment is arranged K-major vs M/N-major). In practice this means you must set the bit to match how your register fragment was produced (or how you packed it), otherwise the multiply/accumulate will use the wrong element ordering. [1][2]
References:
- PTX WGMMA / wgmma.mma_async description (matrix descriptors, K‑major vs MN‑major and SS/RS usage). [1]
- CUTLASS / WGMMA notes (explanation of descriptors, LBO/SBO and K‑major layout and the distinction of SS mode requiring descriptors). [2]
[1] NVIDIA PTX ISA (wgmma / wgmma.mma_async matrix descriptor & layout).
[2] CUTLASS / WGMMA tutorial and notes on shared‑memory descriptors and K‑major layout.
Fix inverted K‑major flags for RS WGMMA intrinsic
RS uses the same a_is_k_major / b_is_k_major semantics as SS; the call currently passes their negations. Update the args at tilelang/intrinsics/wgmma_macro_generator.py:291-295:
- self.a_transposed,
- not self.b_transposed,
+ (not self.a_transposed),
+ self.b_transposed,
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Before:
            accum_dtype,
            wgmma_prefix,
            self.a_transposed,
            not self.b_transposed,
            a_dtype_abbrv,
After:
            accum_dtype,
            wgmma_prefix,
            (not self.a_transposed),
            self.b_transposed,
            a_dtype_abbrv,
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 291-295 the K-major
flags for the RS WGMMA intrinsic are passed inverted relative to the SS
semantics; replace the current arguments so the a/b K-major booleans are the
negations of their current values (change self.a_transposed to not
self.a_transposed, and change not self.b_transposed to self.b_transposed) so the
call passes the correct a_is_k_major and b_is_k_major flags.
…e size calculations in wgmma macro generator. Enhanced handling of leading and stride byte offsets based on swizzle mode, improving clarity and performance in tensor core intrinsic emissions.
            print(f"b_leading_byte_offset: {b_leading_byte_offset}")
            print(f"b_stride_byte_offset: {b_stride_byte_offset}")
            # exit()
@T.macro |
        // Temporarily comment out the static_assert to inspect debug output
        // static_assert(always_false_v<decltype(c)>,
        //               "wgmma_ss: No specialization available for given template parameters!");
    };
class GemmInst(IntEnum):
    MMA = 0
    WGMMMA = 1
    MFMA = 2
Actionable comments posted: 1
♻️ Duplicate comments (6)
src/target/codegen_cuda.cc (1)
`1305-1362`: WGMMA-SS: emit the validated `asm_code`; drop the ad-hoc string (invalid arithmetic on descriptor, dead code)

`asm_code` from `PrintWGMMAAssembly(...)` is computed and then ignored. The custom `wgmma_asm_code` does pointer arithmetic on `tl::GmmaDescriptor` via `uint64_t((desc_a) + (A_offset))`, which is invalid. Unify with the RS path and stream `asm_code`.

Apply:

  const bool a_is_shared = true;
  this->PrintIndent();
  std::string asm_code = PrintWGMMAAssembly(
      shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
      A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
      scale_in_b, a_is_shared, "", "", "", false);
- auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
- std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
- // replace patterns
- tl::codegen::Replacer replacer;
- replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
- replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
- replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
- replacer.register_rule("(M)", std::to_string(m));
- replacer.register_rule("(N)", std::to_string(n));
- replacer.register_rule("(K)", std::to_string(k));
- replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true");
- replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
- replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
- replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
- replacer.register_rule("(desc_a)", a_desc);
- replacer.register_rule("(A_offset)", A_offset);
- replacer.register_rule("(desc_b)", b_desc);
- replacer.register_rule("(B_offset)", B_offset);
- replacer.register_rule("(C)", c_ref + " + " + c_offset);
- replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
- wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
- this->stream << wgmma_asm_code;
+ this->stream << asm_code;

tilelang/intrinsics/wgmma_macro_generator.py (5)
`159-169`: Add validations: `m_dim` multiple-of-64 and `k_dim` divisibility by `micro_size_k`.

Prevents silent no-op loops and tails.

  m_dim = self.block_row_warps * self.warp_row_tiles
  warp_cols = self.warp_cols
  micro_size_k = self.micro_size_k
  k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
  wgmma_prefix = self.wgmma_prefix
@@
- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert m_dim >= 64 and (m_dim % 64) == 0, f"m_dim must be a multiple of 64, got {m_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"
`176-219`: Fix LBO/SBO units, per-operand element sizes, and descriptor pointer modes (SS path).

Currently mixes A's element size for B, uses `swizzle_atom_size()` then `>> 4` (double-divide by 16), and passes `access_ptr("w")`. This will generate incorrect descriptors and offsets, especially for int8/fp8/tf32.

- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+ b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
- a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
- a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
+ a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim * a_elems_in_bytes)
+ a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 * a_elems_in_bytes)
@@
- else:
+ else:  # MN Major
      # LBO represents the distance between two atoms along the M dimension
      # SBO represents the distance between two atoms along the K dimension
-     a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size()
-     a_stride_byte_offset = 8 * 64 * elems_in_bytes
+     a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
+     a_stride_byte_offset = 8 * 64 * a_elems_in_bytes
@@
- b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
- b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * elems_in_bytes)
+ b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+ b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes)
  if not b_swizzle_mode.is_none():
@@
- else:
-     # MN Major, K * N
-     # LBO represents the distance between two atoms along the N dimension
-     # SBO represents the distance between two atoms along the K dimension
-     b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-     if b_n_axis_atoms <= 1:
-         b_leading_byte_offset = 0
-     else:
-         b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-     if b_n_axis_atoms <= 1:
-         b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-     else:
-         b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+ else:
+     # MN Major (N × K): use swizzle byte size for LBO; SBO spans N in bytes.
+     b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+     b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
  desc_a = T.alloc_descriptor()
  desc_b = T.alloc_descriptor()
- T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
+ T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
- T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+ T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
- B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+ B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
- accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4,
- desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data,
+ accum_dtype_abbrv, desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+ desc_b.data, (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data,

Also applies to: 239-254
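As a sanity check on the K-major branch above, here is a standalone sketch of the LBO/SBO arithmetic (assumed semantics: 8×8-element core matrices, descriptor offsets stored in 16-byte units):

```python
def k_major_offsets(k_dim: int, elem_bytes: int) -> tuple:
    """Return (LBO, SBO) in 16-byte units for a K-major operand tile."""
    leading_byte_offset = 8 * 8 * elem_bytes     # one 8x8 core matrix
    stride_byte_offset = 8 * k_dim * elem_bytes  # 8 rows spanning the K extent
    return leading_byte_offset >> 4, stride_byte_offset >> 4
```

For fp16 (2 bytes) with k_dim=64 this gives (8, 64); for int8 (1 byte) with k_dim=32 it gives (4, 16). The per-operand element size changes both values, which is why reusing A's size for B corrupts int8/fp8 descriptors.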
`278-279`: Add RS-side divisibility assert.

Mirror the SS check to avoid dropped tails.

- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"
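The failure mode this assert guards against is easy to see in isolation: the emission loop takes `k_dim // micro_size_k` steps, so any remainder along K is silently skipped rather than computed.

```python
def k_steps(k_dim: int, micro_size_k: int) -> int:
    """Number of WGMMA issues along K; a k_dim % micro_size_k tail is lost."""
    return k_dim // micro_size_k
```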
`280-309`: Mirror per-operand sizes and descriptor pointer mode (RS path).

Use `b_elems_in_bytes` and `access_ptr("r")`; fix LBO/SBO units for swizzle.

- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+ b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
- b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
- b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * elems_in_bytes)
+ b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+ b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes)
@@
- else:
-     # MN Major
-     # LBO represents the distance between two atoms along the N dimension
-     # SBO represents the distance between two atoms along the K dimension
-     b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-     if b_n_axis_atoms <= 1:
-         b_leading_byte_offset = 0
-     else:
-         b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-     if b_n_axis_atoms <= 1:
-         b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-     else:
-         b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+ else:
+     # MN Major
+     b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+     b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
- T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+ T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
- (B_offset * elems_in_bytes) >> 4,
+ (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 313-314, 332-332
`324-325`: Fix inverted K-major flags for RS.

RS uses the same a/b K-major semantics as SS.

- self.a_transposed,
- not self.b_transposed,
+ (not self.a_transposed),
+ self.b_transposed,
🧹 Nitpick comments (6)
src/target/codegen_cuda.cc (1)
`1363-1397`: WGMMA-RS path looks good; minor naming/comment nit

Logic mirrors SS but correctly streams `asm_code`. Consider renaming `A_layout`/`B_layout` to `a_is_k_major`/`b_is_k_major` and fixing the arg comments (arg 0 is shape; dtype is not an arg) for clarity.

tilelang/intrinsics/wgmma_macro_generator.py (5)
`104-109`: Remove unused arg and make prefix computation explicit.

`n_dim` is unused; also guard `inst_k` with a sanity check (tf32 can be tricky).

- def _initialize_wgmma_prefix(self, n_dim: int = 16):
+ def _initialize_wgmma_prefix(self):
      inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
-     # 256 bits per instruction
-     inst_k = 256 // DataType(self.a_dtype).bits
+     # k derived from input dtype; ensure it's one of the valid WGMMA shapes.
+     bits = DataType(self.a_dtype).bits
+     assert 256 % bits == 0, f"Unsupported dtype bits for WGMMA prefix: {bits}"
+     inst_k = 256 // bits
      self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"

- self._initialize_wgmma_prefix(self.n_dim)
+ self._initialize_wgmma_prefix()

Please verify tf32's effective width in your stack (if tvm "tf32" reports 19 bits, switch to an explicit dtype→k map).

Also applies to: 94-95
`221-234`: Remove debug prints (noise in codegen path).

These spam stdout during TIR lowering; drop or guard behind a verbose flag.

- print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
- print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
- print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
- print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")
- print(f"m_dim: {m_dim}")
- print(f"n_dim: {n_dim}")
- print(f"k_dim: {k_dim}")
- print(f"micro_size_k: {micro_size_k}")
- print(f"a_leading_byte_offset: {a_leading_byte_offset}")
- print(f"a_stride_byte_offset: {a_stride_byte_offset}")
- print(f"b_leading_byte_offset: {b_leading_byte_offset}")
- print(f"b_stride_byte_offset: {b_stride_byte_offset}")
`342-347`: Docstring mismatch: this is a load-layout helper.

Tweak wording to avoid confusion.

- Create a layout function for storing MMA results into a fragment buffer.
+ Create a layout function for loading operand A into a fragment buffer.
`364-367`: Drop redundant import inside method.

`is_fragment` is already imported at module scope.

- from tilelang.utils import is_fragment
`143-143`: Tidy exception messages per linter hint (TRY003).

Shorten or use a custom exception type; current messages are fine but flagged.

Also applies to: 384-384
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/target/codegen_cuda.cc
(5 hunks)tilelang/intrinsics/wgmma_macro_generator.py
(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/target/codegen_cuda.cc (4)
- src/transform/storage_rewrite.cc (6): scope (674-678), scope (674-674), n (366-370), n (366-366), n (371-375), n (371-371)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- src/target/ptx.cc (8): PrintWGMMAAssembly (1186-1257), PrintWGMMAAssembly (1186-1195), ParseMMAShape (140-148), ParseMMAShape (140-140), DTypeEnumToString (107-109), DTypeEnumToString (107-107), DTypeEnumToString (111-113), DTypeEnumToString (111-111)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)
tilelang/intrinsics/wgmma_macro_generator.py (8)
- tilelang/utils/language.py (1): is_fragment (68-78)
- tilelang/layout/layout.py (3): Layout (13-148), is_equal (136-145), inverse (125-134)
- tilelang/layout/swizzle.py (3): make_full_bank_swizzled_layout (39-59), make_half_bank_swizzled_layout (64-84), make_quarter_bank_swizzled_layout (89-109)
- tilelang/layout/fragment.py (4): is_equal (209-213), Fragment (13-213), replicate (146-160), repeat (123-144)
- tilelang/language/allocate.py (1): alloc_descriptor (129-135)
- tilelang/language/builtin.py (1): initialize_descriptor (355-386)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/intrinsics/mma_macro_generator.py (1): get_store_index_map (159-165)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py
104-104: Unused method argument: n_dim
(ARG002)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
384-384: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (5)
src/target/codegen_cuda.cc (5)
`898-901`: Descriptor refs: correct special-casing for `local.descriptor`

Returning the variable directly for descriptor buffers is the right behavior and aligns with how the descriptor is declared.

`1721-1721`: LGTM

No issues spotted in `tl_shuffle_elect` emission.

`1845-1845`: Scope guard update is correct

Skipping the unsupported-scope check for `local.descriptor` is necessary after introducing descriptor allocations.

`1723-1742`: Descriptor intrinsics emission LGTM — confirm offset width

`initialize_descriptor` templating is correct. For `increase_descriptor_offset(...)`, confirm the API expects a 32-bit offset; repository search did not locate the declaration, so either switch to `std::int32_t` (or a wider type) and add an explicit cast to avoid narrowing, or confirm plain `int` is intended. Location: src/target/codegen_cuda.cc:1739-1742

`1809-1811`: Descriptor allocation: ensure header is available

src/target/codegen_cuda.cc already writes `#include <tl_templates/cuda/gemm.h>` into the generated decls, but I couldn't find a definition for `tl::GmmaDescriptor` in the repo. Ensure `gemm.h` (or a transitive include) defines `tl::GmmaDescriptor`; if it does not, add `#include <tl_templates/cuda/common.h>` in Finish() so the generated file includes the descriptor definition.
    warp_row_tiles: int = 8,
    warp_col_tiles: int = 8,
    chunk: int = 16,
Default `warp_row_tiles` conflicts with your own assertion (must be ≥16).
Set the default to 16 to avoid constructor-time assertion failures via parent init.
- warp_row_tiles: int = 8,
+ warp_row_tiles: int = 16,
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Before:
    warp_row_tiles: int = 8,
    warp_col_tiles: int = 8,
    chunk: int = 16,
After:
    warp_row_tiles: int = 16,
    warp_col_tiles: int = 8,
    chunk: int = 16,
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 83 to 85, the
default value for warp_row_tiles is 8 which conflicts with the class assertion
requiring warp_row_tiles >= 16 and can trigger assertion failures during parent
initialization; change the default warp_row_tiles to 16 (update the parameter
default in the function/class signature) so it satisfies the assertion by
default and avoid constructor-time failures.
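The conflict is reproducible with a stripped-down stand-in (hypothetical class, not the real emitter):

```python
class EmitterSketch:
    """Stand-in illustrating why a default of 8 fails the >= 16 assertion."""

    def __init__(self, warp_row_tiles: int = 16):  # suggested default
        assert warp_row_tiles >= 16, "warp_row_tiles must be >= 16"
        self.warp_row_tiles = warp_row_tiles
```

With the old default of 8, every no-argument construction would raise before any user code runs; bumping the default to 16 keeps the assertion while restoring a usable default.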
TODO Items
- Use `T.alloc_descriptor` to create a static descriptor, allowing only `start_address` updates within tiled GEMM.
- Support more data types: `int8`, `fp8`, `tf32`.
- Support `gemm_rs`.

Summary by CodeRabbit
New Features
Layout & UX
API Changes