
Conversation

LeiWang1999
Member

@LeiWang1999 LeiWang1999 commented Sep 14, 2025

TODO Items

  • Introduce T.alloc_descriptor to create a static descriptor, allowing only start_address updates within tiled GEMM.
  • Add support for additional data types: int8, fp8, tf32.
  • Implement gemm_rs.
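The first TODO item can be pictured with a rough pure-Python model of Hopper's 64-bit shared-memory (GMMA) descriptor. The field layout below follows the CUTLASS `GmmaDescriptor` convention and is an assumption for illustration, not taken from the TileLang sources; the point is that all fields except the 14-bit start-address are static, which is what makes a descriptor with address-only updates attractive inside a tiled GEMM loop:

```python
SWIZZLE_NONE, SWIZZLE_128B, SWIZZLE_64B, SWIZZLE_32B = 0, 1, 2, 3

def make_descriptor(smem_addr, leading_byte_offset, stride_byte_offset,
                    base_offset=0, swizzle=SWIZZLE_NONE):
    """Pack a 64-bit GMMA shared-memory descriptor.

    Addresses/offsets are encoded in 16-byte units (>> 4), per the
    CUTLASS GmmaDescriptor convention (an assumption here)."""
    desc = 0
    desc |= (smem_addr >> 4) & 0x3FFF                   # bits 0..13: start address
    desc |= ((leading_byte_offset >> 4) & 0x3FFF) << 16  # bits 16..29
    desc |= ((stride_byte_offset >> 4) & 0x3FFF) << 32   # bits 32..45
    desc |= (base_offset & 0x7) << 49                    # bits 49..51
    desc |= (swizzle & 0x3) << 62                        # bits 62..63: swizzle mode
    return desc

def increase_descriptor_offset(desc, byte_offset):
    """Advance only the start-address field; every other field stays static."""
    addr = desc & 0x3FFF
    rest = desc & ~0x3FFF
    return rest | ((addr + (byte_offset >> 4)) & 0x3FFF)
```

Sliding the descriptor across a tile then only touches the low 14 bits, so the rest of the descriptor can be built once per kernel.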

Summary by CodeRabbit

  • New Features

    • Added a WGMMA-based GEMM backend with automatic instruction selection and lowering.
    • New PTX WGMMA intrinsics and descriptor helpers (initialize/increase) plus alloc_descriptor for descriptor buffers.
    • Added a WMMA/WGMMA instruction emitter and device implementations.
  • Layout & UX

    • Expanded swizzled layout builders (wgmma/full/half/quarter bank), k-major/allow-pad options, improved layout/fragment repr and equality checks.
  • API Changes

    • Replaced numeric k-factor with boolean k_inner in GEMM layout builders.
    • GEMM lowering now accepts a layout_map argument.
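The bank-swizzle builders mentioned above generally reduce to XOR-ing low row bits into the column index so that consecutive rows land in different shared-memory banks. A minimal, self-contained illustration of the idea (not TileLang's actual layout math):

```python
def swizzle_full_bank(row: int, chunk_col: int) -> int:
    """XOR-swizzle a (row, 16-byte-chunk) coordinate.

    Simplified model of a full-bank (128-byte) swizzle over 8 chunks
    per row; illustrative only, not the TileLang builder."""
    return chunk_col ^ (row % 8)
```

Because XOR with a fixed value is a bijection, each row is a permutation of the chunk columns, and any 8 consecutive rows of a fixed column map to 8 distinct chunks, avoiding bank conflicts.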

- Added support for the WGMMA intrinsic in the TileLang framework, enabling efficient matrix multiplication on newer architectures.
- Refactored GEMM layout functions to accept a boolean parameter for K dimension handling, improving flexibility in layout generation.
- Updated layout inference logic to accommodate new WGMMA configurations and ensure compatibility with existing GEMM operations.
- Enhanced Python bindings for layout functions, allowing for better integration and usability in user-defined operations.
- Improved documentation for layout functions and GEMM operations to clarify usage and parameters.

These changes enhance the performance and usability of GEMM operations, particularly for advanced architectures, while maintaining backward compatibility with existing implementations.
…bility

- Improved code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated layout function signatures to enhance clarity, particularly in `gemm_layouts.cc`, `layout.cc`, and `layout.h`.
- Refactored lambda functions in `builtin.cc` and `gemm_py.cc` for improved structure and maintainability.
- Enhanced comments and documentation in layout-related files to clarify usage and parameters.

These changes contribute to a cleaner codebase and improved maintainability of layout functions in the TileLang framework.

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `bash format.sh` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

coderabbitai bot commented Sep 14, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Replaces numeric k-factor GEMM layout parameters with boolean k_inner across layout APIs and callers; adds WGMMA support (intrinsics, PTX/CUDA codegen, descriptors), new Python layout/descriptor helpers and equality checks, and a runtime GemmInst dispatch with a new GemmWGMMA backend.

Changes

Cohort / File(s) Summary
GEMM layout API (kfactor → k_inner)
src/layout/gemm_layouts.cc, src/layout/layout.h, src/op/gemm.cc, tilelang/layout/swizzle.py, tilelang/layout/__init__.py
Replace integer kfactor/transpose codes with boolean k_inner / k_major flags in layout constructors and callers; swizzle FFI extended to accept k_major/allow_pad and new wgmma swizzle entrypoint exposed.
Layout / Fragment equality & repr
src/layout/layout.cc, tilelang/layout/layout.py, tilelang/layout/fragment.py
Added FFI bindings tl.Layout_is_equal and tl.Fragment_is_equal; added Layout.is_equal/get_forward_index/repr and Fragment.is_equal (and expanded Fragment.repr).
WGMMA intrinsics and PTX backend
src/op/builtin.h, src/op/builtin.cc, src/target/ptx.h, src/target/ptx.cc, src/target/codegen_cuda.cc, src/tl_templates/cuda/common.h, src/tl_templates/cuda/instruction/wgmma.h
Added tl.ptx_wgmma_ss/rs and descriptor intrinsics; introduced PTX-side DataType enum, WGMMA config validation, PrintWGMMAAssembly, and extensive wgmma instruction specializations; CUDA codegen emits descriptor ops and wgmma intrinsics; common.h adds GmmaDescriptor and helpers.
Descriptor allocation & high-level helpers
tilelang/language/allocate.py, tilelang/language/__init__.py, tilelang/language/builtin.py
Added alloc_descriptor (scope "local.descriptor") and high-level builtin wrappers initialize_descriptor / increase_descriptor_offset with argument normalization and intrinsic lowering.
GEMM dispatch and WGMMA backend
tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_mma.py, tilelang/tileop/gemm/gemm_wgmma.py, src/op/gemm_py.cc
Runtime GemmInst selection via FFI GemmPyGemmInst; gemm infer/lower dispatches to GemmMMA or the new GemmWGMMA implementation; gemm_py.lower and GemmMMA.lower signatures gain layout_map; new GemmWGMMA implements infer_layout and lower for SS/RS variants.
Transforms: descriptor scope handling
src/transform/lower_device_storage_access_info.cc, src/transform/storage_rewrite.cc
Treat ".descriptor" scope as excluded from special-tag memory handling and memory-info-driven merging; Allocate lowering skips descriptor-tagged buffers.
WGMMA macro emitter (Python)
tilelang/intrinsics/wgmma_macro_generator.py
New TensorCoreIntrinEmitter with swizzle modes, descriptor-aware operand handling, wgmma/rs emission, and fragment load/store layout helpers.
TL / TIR wrappers & builtins
tilelang/language/ast/ir.py, tilelang/language/tir/ir.py, tilelang/language/tir/op.py, tilelang/language/builtin.py
Added ptx_wgmma_ss/rs wrappers and new op wrappers; added high-level builtin descriptor helpers.
Layout helpers re-exports & Python API surface
tilelang/layout/__init__.py, tilelang/layout/swizzle.py, tilelang/layout/fragment.py, tilelang/layout/layout.py
Re-exported new swizzle builders; extended make_swizzled_layout signature; added bank-swizzle helpers; removed a module helper and added instance equality/repr helpers.
Templates / headers / misc
src/tl_templates/cuda/common.h, src/tl_templates/cuda/gemm.h, src/tl_templates/cuda/instruction/wgmma.h, tilelang/tileop/gemm/gemm_base.py, various small edits
Added descriptor types and includes; include wgmma.h for sm90+; added ElementBits/WgmmaOutputRegs traits and wgmma impl; minor typing/docstring/formatting tweaks.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant GemmPy
  participant FFI as _ffi_api.GemmPyGemmInst
  participant Dispatcher as GemmPy.dispatch
  participant Impl as GemmMMA/GemmWGMMA
  participant Emitter as TensorCoreIntrinEmitter
  participant TIR as TIR Lowering
  participant CG as CUDA/PTX Codegen

  User->>GemmPy: infer_layout / lower(request)
  GemmPy->>FFI: GemmPyGemmInst(thread_nums, target)
  FFI-->>GemmPy: GemmInst (MMA / WGMMA / MFMA)
  GemmPy->>Dispatcher: select implementation
  Dispatcher->>Impl: dispatch to GemmMMA or GemmWGMMA
  Impl->>Emitter: configure(dtypes, trans, tiles, thread_var)
  Impl->>TIR: emit TIR with intrinsics (ptx_mma_* or ptx_wgmma_*)
  TIR->>CG: emit descriptor init / wgmma PTX assembly
  CG-->>User: compiled kernel
sequenceDiagram
  autonumber
  participant TIR
  participant TL as tl.builtin
  participant CUDA as codegen_cuda
  participant PTX as codegen_ptx

  note over TIR,PTX: WGMMA-SS descriptor flow
  TIR->>TL: initialize_descriptor(descA, A_smem_ptr, layout/offsets)
  TIR->>TL: initialize_descriptor(descB, B_smem_ptr, layout/offsets)
  TIR->>TL: ptx_wgmma_ss(..., A_desc, A_off, B_desc, B_off, C_ptr, ...)
  TL-->>CUDA: intrinsic calls with descriptors
  CUDA->>PTX: PrintWGMMAAssembly(a_is_shared=true, validated config)
  PTX-->>CUDA: PTX assembly string
  CUDA-->>TIR: inline asm emitted into kernel

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • tzj-fxz

Poem

I nibble bits and flip K's side,
Descriptors snug where bytes reside.
Swizzles hop and warps align,
WGMMA hums — tiles combine.
Hop, compile, my carrot—speed! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 32.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check — ✅ Passed: the title "[TileOp] Implement WGMMA for T.gemm_v2" is concise and directly describes the primary change set; the diff adds WGMMA support across TileOp/GEMM (intrinsics, codegen, layouts, descriptor support, and new gemm_wgmma lowering), so the title accurately reflects the main intent and scope of the PR.

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

…GMMA

- Introduced new TileLang builtins `initialize_descriptor` and `increase_descriptor_offset` to facilitate descriptor management for WGMMA operations.
- Updated `builtin.cc` and `builtin.h` to define and document the new builtins, enhancing the framework's capabilities for descriptor handling.
- Modified `codegen_cuda.cc` and `ptx.cc` to integrate the new builtins into the code generation process, ensuring proper assembly generation for WGMMA operations.
- Enhanced the `GemmWGMMA` class to utilize the new descriptor functionalities, improving the efficiency of matrix multiplication operations.
- Updated related tests and documentation to reflect the new features and ensure comprehensive coverage.

These changes enhance the TileLang framework's support for advanced matrix operations on newer architectures, improving performance and usability.
- Enhanced code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated function signatures and comments in `builtin.h`, `codegen_cuda.cc`, and `ptx.cc` to improve clarity.
- Refactored descriptor initialization and offset manipulation functions in `builtin.py` and `wgmma_macro_generator.py` for improved structure.
- Cleaned up unnecessary whitespace and improved alignment in `common.h` and `allocate.py`.

These changes contribute to a cleaner and more maintainable codebase in the TileLang framework.
- Updated the subproject commit for `cutlass` to indicate a dirty state.
- Refactored the `UpdateAnalyzer` function in `layout.cc` to call `LayoutNode::getVarMap()` instead of `getVarMap()`, improving clarity and ensuring proper context for variable mapping.

These changes enhance the maintainability and clarity of the layout handling in the TileLang framework.
@LeiWang1999 LeiWang1999 marked this pull request as ready for review September 16, 2025 08:12
cursor[bot]

This comment was marked as outdated.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/customize.py (1)

193-197: Bug: extent computation uses list max instead of elementwise max

max(src_extent, dst_extent) returns one of the lists lexicographically, not per-dimension. This can mis-size regions and corrupt memory ops.

Apply:

-    extent = max(src_extent, dst_extent)
+    # elementwise extent unification
+    extent = [T.max(a, b) for a, b in zip(src_extent, dst_extent)]
🧹 Nitpick comments (32)
tilelang/language/__init__.py (1)

45-46: Export looks good; clean up unused noqa.

  • Re-exporting alloc_descriptor is correct and matches allocate.py.
  • Ruff flags Line 45 for unused # noqa: F401 (RUF100). Drop it here (others weren’t flagged in this diff).

Apply:

-    alloc_descriptor,  # noqa: F401
+    alloc_descriptor,
tilelang/layout/layout.py (3)

92-94: Add docstring/return type for get_forward_index for parity with getters.

Minor consistency nit. Consider:

-    def get_forward_index(self):
-        return self.index
+    def get_forward_index(self) -> PrimExpr | list[PrimExpr]:
+        """Return the computed forward index expression(s)."""
+        return self.index

136-146: API parity: provide Pythonic equality too.

Keep is_equal, but also implement eq delegating to it for ergonomic comparisons; keep hash=None to avoid hashing mutable objects.

     def is_equal(self, other: "Layout") -> bool:
         """
         Check if the current layout is equal to another layout.
         """
         return _ffi_api.Layout_is_equal(self, other)

+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Layout):
+            return NotImplemented
+        return self.is_equal(other)
+    __hash__ = None

147-148: repr: avoid huge dumps when vars/index grow.

Consider truncating sequences for readability in logs.

-        return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>"
+        fv = self.get_forward_vars()
+        fi = self.get_forward_index()
+        def _short(x): 
+            s = str(x)
+            return s if len(s) <= 120 else s[:117] + "..."
+        return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {_short(fv)} -> {_short(fi)}>"
src/op/builtin.h (2)

164-185: Fix WGMMA doc comments to match actual RS/SS operand forms.

Current comment for ptx_wgmma_rs refers to A_descriptor; RS variant uses A_buf (regular pointer/buffer) per Python wrappers. Also B_offset is documented as Var (should be PrimExpr). Update comments to avoid API confusion.

- *  void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
- * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
- * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
- * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
+ *  void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
+ * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
+ * StringImm accum_dtype_abbrv, Var A_buf, PrimExpr A_offset, Var
+ * B_descriptor, PrimExpr B_offset, Var C_data, PrimExpr C_offset, bool scale_out, bool
  * scale_in_a, bool scale_in_b);

344-361: Descriptor intrinsic docs: correct operation wording.

increase_descriptor_offset increments the offset; the block comment says “setting the start address.” Align wording to avoid misuse.

- * \brief tilelang intrinsic for setting the start address of a descriptor
- * buffer for wgmma/utcmma.
+ * \brief tilelang intrinsic for increasing the offset of a descriptor
+ * buffer for wgmma/utcmma.
tilelang/language/tir/op.py (2)

1064-1104: Nit: Docstring says “wmma” but this is WGMMA; also consider briefly documenting operand order.

Purely cosmetic; helps future maintainers and avoids confusion with WMMA.

-    """TVM intrinsic for ptx tensor core wmma instructions
+    """TVM intrinsic for PTX warp-group MMA (WGMMA) instructions
+    Operand order: prefix, trans_a, trans_b, a/b/accum dtype abbrvs,
+    A_desc, A_offset, B_desc, B_offset, C_data, C_offset, scale_out, scale_in_a, scale_in_b.

1106-1144: Nit: Same WGMMA terminology/doc tweak as above; otherwise wrapper matches builtin.

LGTM functionally; mirrors 15-arg registration.

-    return call_intrin(
+    # PTX warp-group MMA (WGMMA) RS variant: A from register, B from descriptor.
+    return call_intrin(
tilelang/layout/__init__.py (1)

6-13: Remove unused “noqa: F401” directives or enable F401 in Ruff config.

Ruff flags these as unused (RUF100). Either drop them or configure Ruff to honor F401.

-from .swizzle import (
-    make_swizzled_layout,  # noqa: F401
-    make_wgmma_swizzled_layout,  # noqa: F401
-    make_full_bank_swizzled_layout,  # noqa: F401
-    make_half_bank_swizzled_layout,  # noqa: F401
-    make_quarter_bank_swizzled_layout,  # noqa: F401
-)
-from .gemm_sp import make_metadata_layout  # noqa: F401
+from .swizzle import (
+    make_swizzled_layout,
+    make_wgmma_swizzled_layout,
+    make_full_bank_swizzled_layout,
+    make_half_bank_swizzled_layout,
+    make_quarter_bank_swizzled_layout,
+)
+from .gemm_sp import make_metadata_layout
tilelang/tileop/gemm/gemm_mma.py (2)

60-77: Use layout_map (or underscore the arg) to avoid ARG002 and keep parity with WGMMA.

Mirror WGMMA: when provided, feed A/B shared layouts into the emitter; otherwise prefix arg as _layout_map to appease linters.

-    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
+    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
@@
         mma_emitter = TensorCoreIntrinEmitter(
@@
             thread_var=thread_var,
         )
+
+        # Optional: honor externally inferred layouts if present (parity with WGMMA)
+        if self.A in layout_map:
+            mma_emitter._assign_a_shared_layout(layout_map[self.A])
+        if self.B in layout_map:
+            mma_emitter._assign_b_shared_layout(layout_map[self.B])

90-91: Replace assert with explicit validation (and check divisibility).

Python asserts can be stripped with -O; raise a ValueError and also enforce block_K % micro_size_k == 0 to match loop step.

-        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        if block_K < micro_size_k or (block_K % micro_size_k) != 0:
+            raise ValueError(
+                f"Invalid K tile: block_K={block_K}, micro_size_k={micro_size_k} "
+                "(must be >= and divisible)."
+            )
src/op/gemm.cc (1)

45-48: Documentation inconsistency between comment and implementation.

The documentation states that kPack must be 1, but the implementation at lines 71-73 allows values of both 1 and 2. This creates confusion about the actual requirements.

Either update the documentation to reflect that kPack can be 1 or 2, or enforce the restriction that it must be 1:

- * @note If `kPack` is provided it must be 1; otherwise the constructor
- *       fails with an ICHECK (runtime assertion). No other validation is
- *       performed here.
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ *       fails with an ICHECK (runtime assertion). No other validation is
+ *       performed here.
tilelang/language/builtin.py (1)

375-375: Consider moving error messages to exception classes.

While not critical, defining error messages inside exception classes improves maintainability and reusability.

Consider creating custom exception classes:

class InvalidDescriptorTypeError(TypeError):
    """Raised when descriptor is not a Buffer or BufferLoad."""
    def __init__(self):
        super().__init__("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")

class InvalidDescriptorShapeError(ValueError):
    """Raised when descriptor is not a 1D buffer of size 1."""
    def __init__(self):
        super().__init__("Descriptor must be a 1D buffer of size 1.")

Also applies to: 378-378, 401-401, 404-404

tilelang/layout/swizzle.py (2)

23-23: Use Optional type annotation for nullable parameter.

The continuity parameter can be None but isn't typed as Optional.

Add proper type annotation:

+from typing import Optional
+
 def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
+                               continuity: Optional[int] = None,
                                k_major: bool = True):

54-54: Consider more descriptive error messages.

The error messages could be more helpful by describing what arguments are expected.

Improve error messages:

-        raise ValueError(f"Invalid arguments: {args}")
+        raise ValueError(f"Expected either a single buffer or (stride, continuous, element_size), got {len(args)} arguments: {args}")

Also applies to: 79-79, 104-104

src/layout/gemm_layouts.cc (2)

573-574: Parameter name inconsistency with function documentation.

The parameter name k_inner doesn't match the documentation's description which refers to whether the "K dimension is in the inner loop". Consider renaming to k_is_inner or is_k_inner for better clarity.


532-541: Missing implementation of k_major parameter in Volta layout.

The function signature was updated to use bool k_inner but the implementation still uses k_inner directly as a boolean flag without considering the k-major semantics that the rest of the codebase expects.

Based on the pattern in makeGemmABLayout and makeGemmABLayoutHopper, you may want to verify that this implementation correctly handles the k-major/k-inner semantics.

tilelang/tileop/gemm/gemm_wgmma.py (1)

110-111: Verify RS implementation for A in fragment and B in shared.

The _gemm_rsr function name and docstring mention loading data from "shared buffers A_shared and B_shared", but the RS variant should have A in registers/fragments. The docstring appears to be copied from the SS variant.

Apply this diff to fix the docstring:

             @T.prim_func
             def _gemm_rsr() -> None:
                 """
-                The inner macro that loads data from shared buffers A_shared and
-                B_shared into local fragments, then issues Tensor Core mma ops,
-                accumulating into C_local.
+                The inner macro that uses data from local fragment A_local and
+                shared buffer B_shared, then issues Tensor Core mma ops,
+                accumulating into C_local.
                 """
                 mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)
tilelang/intrinsics/wgmma_macro_generator.py (3)

104-104: Remove unused parameter n_dim from method signature.

The method _initialize_wgmma_prefix has an unused parameter n_dim=16 that shadows the instance variable self.n_dim.

Apply this diff to fix:

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles

133-143: Extract swizzle mode detection to centralized utility.

The _determinate_swizzle_mode method performs layout equality checks against multiple swizzle patterns. This logic could be refactored into a centralized layout utility to avoid duplication if similar detection is needed elsewhere.

Would you like me to help create a centralized swizzle mode detection utility that could be reused across the codebase?


345-353: Improve error message specificity for unsupported dtypes.

The error message for unsupported dtypes could be more informative by including the actual bit width.

Apply this diff:

         else:
-            raise ValueError(f"Unsupported dtype {dtype}")
+            raise ValueError(f"Unsupported dtype {dtype} with {dtype_bits} bits for MMA load layout")
src/layout/layout.cc (1)

495-505: Parameter naming inconsistency in swizzled layout creation.

The make_swizzled_layout function uses allow_pad to choose between makeGemmABLayout and makeGemmABLayoutHopper. The parameter name allow_pad doesn't clearly convey that it's selecting between different hardware layout strategies (standard vs Hopper).

Consider renaming to be more descriptive:

       .def("tl.make_swizzled_layout",
            [](int stride, int continuous, int element_size, bool k_inner,
-              bool allow_pad = true) {
-             if (allow_pad) {
+              bool use_hopper_layout = false) {
+             if (!use_hopper_layout) {
                return makeGemmABLayout(stride, continuous, continuous,
                                        element_size, k_inner);
              } else {
                return makeGemmABLayoutHopper(stride, continuous, continuous,
                                              element_size, k_inner);
              }
            })
tilelang/tileop/gemm/__init__.py (2)

82-98: Add parameter validation for thread_nums.

While the FFI call handles the selection logic, it would be good to validate that thread_nums is positive before passing it to the FFI.

 def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst:
     """Select the appropriate GEMM instruction based on target and thread configuration.
     
     The selection logic follows this priority:
     1. WGMMA for Hopper architecture with sufficient matrix size and warp count
     2. MFMA for CDNA (AMD) architecture
     3. MMA for CUDA architecture
     4. Fallback to MMA for other cases
     
     Args:
         thread_nums: Number of threads in the block
         target: Target architecture
         
     Returns:
         GemmInst: The selected GEMM instruction type
     """
+    if thread_nums <= 0:
+        raise ValueError(f"thread_nums must be positive, got {thread_nums}")
     return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target))

118-118: Consider using a more specific exception message.

The error message could be more informative by including what implementations are available.

-            raise NotImplementedError("MFMA is not implemented")
+            raise NotImplementedError("MFMA is not implemented. Available implementations: MMA, WGMMA")
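The selection priority spelled out in that docstring can be mocked in a few lines of plain Python. The real decision lives in the C++ FFI GemmPyGemmInst, so the arch strings and the warp-group condition below are illustrative assumptions, not the actual predicate:

```python
from enum import Enum

class GemmInst(Enum):
    MMA = "mma"
    WGMMA = "wgmma"
    MFMA = "mfma"

def select_gemm_instruction(arch: str, thread_nums: int) -> GemmInst:
    # Hopper with at least one full warp group (128 threads) -> WGMMA.
    if arch == "sm90" and thread_nums >= 128 and thread_nums % 128 == 0:
        return GemmInst.WGMMA
    # AMD CDNA matrix cores -> MFMA.
    if arch.startswith("cdna"):
        return GemmInst.MFMA
    # Default CUDA tensor-core path.
    return GemmInst.MMA
```

For instance, a 256-thread block on an assumed "sm90" target would pick the WGMMA path, while the same block on "sm80" falls back to MMA.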
src/target/ptx.cc (2)

1053-1168: Complex operand generation but missing a_is_shared validation for register path.

The function generates WGMMA operands correctly but doesn't validate that when a_is_shared is false, the operation is actually supported for register-based A operands.

Consider adding validation:

 inline std::tuple<std::string, std::string, std::string, std::string>
 GetWGMMAOperands(int m, int n, int k, ptx::DataType dtype_a,
                  ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse,
                  bool a_is_shared) {
+  // WGMMA with register-based A operand has limitations
+  if (!a_is_shared) {
+    // Add any specific validation for register-based A operands if needed
+    // based on NVIDIA documentation
+  }
   std::stringstream templates, inputs, outputs, predicate;

1263-1266: Consider extracting predicate setup as a constant.

The predicate setup code could be defined as a constant string for better maintainability.

+  constexpr const char* PREDICATE_SETUP = 
+      "{.reg .pred p;\n"
+      "setp.ne.b32 p, {predicate}, 0;\n";
+
   std::string asm_code = R"(
   {
     __asm__ __volatile__(
-      "{.reg .pred p;\n"
-      "setp.ne.b32 p, {predicate}, 0;\n"
+      ")" PREDICATE_SETUP R"(
       "wgmma.mma_async{.sparse}.sync.aligned{.shape}{.dtype}{.atype}{.btype}"
       "{templates};\n}"
       : {outputs}
       : {inputs});
   }
 )";
tilelang/language/customize.py (2)

160-179: Deduplicate get_extent; current helper misses BufferLoad and can diverge

You re-implement get_extent but omit BufferLoad handling (supported in tilelang/language/copy.py). Prefer reusing the canonical helper to avoid drift.

Apply:

@@
-    def get_extent(data):
-        """
-        Return the inferred extent (shape) of a buffer-like object.
-        ...
-        """
-        if isinstance(data, Var) and T.has_let_value(data):
-            data = T.get_let_value(data)
-        if isinstance(data, Buffer):
-            return data.shape
-        elif isinstance(data, BufferRegion):
-            return [x.extent for x in data.region]
-        else:
-            return None
-
-    src_extent = get_extent(value)
-    dst_extent = get_extent(dst)
+    src_extent = _get_extent(value)
+    dst_extent = _get_extent(dst)

Also add the import near the top of this file:

from tilelang.language.copy import get_extent as _get_extent

82-105: Honor provided extents in buffer_region_to_tile_region

The extents parameter is only asserted but ignored. This prevents aligning extents when mixing BufferRegion with BufferLoad.

Apply:

@@
-    return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
+    # Override trailing dims with requested extents if provided
+    if extents:
+        region_extents = list(region_extents)
+        region_extents[-len(extents):] = extents
+    return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
tilelang/language/allocate.py (1)

129-136: Add type hints for API clarity

Tighten signature and return type.

Apply:

-def alloc_descriptor(dtype="uint64", scope="local.descriptor"):
+def alloc_descriptor(dtype: str = "uint64", scope: str = "local.descriptor") -> T.Buffer:
     """Allocate a descriptor buffer for wgmma and utcmma.
 
     Returns:
         T.Buffer: A TVM buffer object allocated as a descriptor
     """
     return T.alloc_buffer([1], dtype, scope=scope)
src/layout/layout.h (3)

163-167: Prefer explicit enum over boolean for K placement

A boolean is easy to misuse. Consider a strongly-typed enum (e.g., enum class KPlacement { Inner, Outer };) to make call sites self-documenting and prevent silent int→bool coercions.

Example:

enum class KPlacement { Inner, Outer };
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size, KPlacement k_place = KPlacement::Inner);

168-169: CDNA kPack rename verified — code updated; fix lingering doc/comments

makeGemmABLayoutCDNA (declaration/definition) and callers use int kPack and Python bindings expose "kPack".

  • Update remaining comment references to the old name:
    • src/op/gemm_py.h — comment referencing "k_pack" (around line ~30).
    • src/op/gemm_sp.h — similar comment (around line ~25).

177-178: Volta layout: no functional change required — callers pass explicit k_inner; remove or flip default for clarity.
makeGemmVoltaABLayout(..., bool k_inner = true) is declared in src/layout/layout.h; call sites pass explicit values (src/op/gemm.cc:479, 492).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ae9b706 and eac5433.

📒 Files selected for processing (30)
  • src/layout/gemm_layouts.cc (4 hunks)
  • src/layout/layout.cc (4 hunks)
  • src/layout/layout.h (2 hunks)
  • src/op/builtin.cc (2 hunks)
  • src/op/builtin.h (2 hunks)
  • src/op/gemm.cc (7 hunks)
  • src/op/gemm_py.cc (4 hunks)
  • src/op/gemm_py.h (1 hunks)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/target/ptx.cc (6 hunks)
  • src/target/ptx.h (1 hunks)
  • src/tl_templates/cuda/common.h (4 hunks)
  • src/transform/lower_device_storage_access_info.cc (1 hunks)
  • src/transform/storage_rewrite.cc (2 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
  • tilelang/language/__init__.py (4 hunks)
  • tilelang/language/allocate.py (1 hunks)
  • tilelang/language/ast/ir.py (2 hunks)
  • tilelang/language/builtin.py (2 hunks)
  • tilelang/language/customize.py (10 hunks)
  • tilelang/language/tir/ir.py (1 hunks)
  • tilelang/language/tir/op.py (1 hunks)
  • tilelang/layout/__init__.py (1 hunks)
  • tilelang/layout/fragment.py (1 hunks)
  • tilelang/layout/layout.py (2 hunks)
  • tilelang/layout/swizzle.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (3 hunks)
  • tilelang/tileop/gemm/gemm_base.py (2 hunks)
  • tilelang/tileop/gemm/gemm_mma.py (2 hunks)
  • tilelang/tileop/gemm/gemm_wgmma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (23)
tilelang/language/tir/ir.py (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/ast/ir.py (1)
  • _dtype_forward (1876-1884)
tilelang/language/allocate.py (2)
src/transform/storage_rewrite.cc (4)
  • dtype (696-702)
  • dtype (696-696)
  • scope (674-678)
  • scope (674-674)
tilelang/language/ast/ir.py (1)
  • alloc_buffer (441-508)
src/transform/lower_device_storage_access_info.cc (1)
src/transform/storage_rewrite.cc (2)
  • scope (674-678)
  • scope (674-674)
tilelang/layout/layout.py (1)
tilelang/layout/fragment.py (1)
  • is_equal (209-213)
tilelang/language/ast/ir.py (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/tir/ir.py (1)
  • _dtype_forward (156-164)
tilelang/layout/swizzle.py (2)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
src/layout/swizzle.h (1)
  • tvm (12-70)
tilelang/tileop/gemm/gemm_wgmma.py (6)
tilelang/tileop/gemm/gemm_base.py (17)
  • GemmBase (12-120)
  • infer_layout (15-16)
  • policy (119-120)
  • M (34-35)
  • N (38-39)
  • in_dtype (54-56)
  • accum_dtype (59-60)
  • trans_A (46-47)
  • trans_B (50-51)
  • chunk (63-64)
  • is_gemm_ss (21-22)
  • K (42-43)
  • A (67-68)
  • B (71-72)
  • C (75-76)
  • lower (18-19)
  • clear_accum (107-108)
tilelang/layout/swizzle.py (1)
  • make_wgmma_swizzled_layout (22-34)
tilelang/intrinsics/wgmma_macro_generator.py (6)
  • TensorCoreIntrinEmitter (63-477)
  • make_mma_store_layout (423-477)
  • make_mma_load_layout (311-421)
  • _assign_a_shared_layout (96-98)
  • _assign_b_shared_layout (100-102)
  • wgmma (145-233)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
tilelang/tileop/gemm/gemm_mma.py (3)
  • infer_layout (15-58)
  • is_gemm_ss (204-205)
  • lower (60-202)
tilelang/intrinsics/wgmma_macro_generator.py (6)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/swizzle.py (3)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/fragment.py (4)
  • is_equal (209-213)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/language/builtin.py (1)
  • initialize_descriptor (355-386)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (2)
  • scope (674-678)
  • scope (674-674)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (2)
  • PrintWGMMAAssembly (1235-1306)
  • PrintWGMMAAssembly (1235-1244)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/op/builtin.h (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/tl_templates/cuda/common.h (3)
src/tl_templates/cuda/copy_sm90.h (1)
  • void (255-258)
src/tl_templates/cuda/ldsm.h (12)
  • void (7-14)
  • void (16-23)
  • void (25-33)
  • void (35-42)
  • void (44-52)
  • void (54-62)
  • void (64-70)
  • void (72-79)
  • void (81-89)
  • void (91-98)
  • void (100-108)
  • void (110-119)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/target/ptx.h (1)
src/target/ptx.cc (2)
  • PrintWGMMAAssembly (1235-1306)
  • PrintWGMMAAssembly (1235-1244)
tilelang/language/builtin.py (3)
src/op/builtin.h (1)
  • tvm (13-363)
tilelang/language/ast/ir.py (1)
  • evaluate (1319-1331)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
src/op/gemm.cc (2)
tilelang/tileop/gemm/gemm_base.py (4)
  • trans_A (46-47)
  • trans_B (50-51)
  • A (67-68)
  • B (71-72)
src/layout/gemm_layouts.cc (2)
  • makeGemmABLayout (573-592)
  • makeGemmABLayout (573-574)
tilelang/language/customize.py (1)
tilelang/language/copy.py (1)
  • get_extent (105-118)
src/op/builtin.cc (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
tilelang/tileop/gemm/gemm_mma.py (3)
tilelang/tileop/gemm/gemm_wgmma.py (1)
  • lower (64-125)
tilelang/tileop/gemm/__init__.py (1)
  • lower (76-80)
tilelang/tileop/gemm/gemm_base.py (1)
  • lower (18-19)
tilelang/language/__init__.py (1)
tilelang/language/allocate.py (1)
  • alloc_descriptor (129-135)
tilelang/layout/__init__.py (2)
tilelang/layout/swizzle.py (5)
  • make_swizzled_layout (10-18)
  • make_wgmma_swizzled_layout (22-34)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (98-109)
tilelang/tileop/gemm/__init__.py (4)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
tilelang/tileop/gemm/gemm_mma.py (3)
  • GemmMMA (13-214)
  • lower (60-202)
  • infer_layout (15-58)
tilelang/tileop/gemm/gemm_wgmma.py (3)
  • GemmWGMMA (13-137)
  • lower (64-125)
  • infer_layout (15-62)
tilelang/tileop/gemm/gemm_base.py (2)
  • lower (18-19)
  • infer_layout (15-16)
tilelang/layout/fragment.py (1)
tilelang/layout/layout.py (3)
  • get_input_shape (59-68)
  • get_output_shape (70-79)
  • is_equal (136-145)
src/layout/layout.h (1)
src/layout/gemm_layouts.cc (4)
  • makeGemmABLayoutHopper (594-615)
  • makeGemmABLayoutHopper (594-595)
  • makeGemmABLayoutCDNA (617-625)
  • makeGemmABLayoutCDNA (617-618)
src/layout/layout.cc (1)
src/layout/gemm_layouts.cc (10)
  • makeGemmABLayout (573-592)
  • makeGemmABLayout (573-574)
  • makeGemmABLayoutHopper (594-615)
  • makeGemmABLayoutHopper (594-595)
  • makeFullBankSwizzleLayout (375-391)
  • makeFullBankSwizzleLayout (375-375)
  • makeHalfBankSwizzleLayout (356-372)
  • makeHalfBankSwizzleLayout (356-356)
  • makeQuarterBankSwizzleLayout (336-353)
  • makeQuarterBankSwizzleLayout (336-337)
🪛 Ruff (0.12.2)
tilelang/layout/swizzle.py

23-23: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)


104-104: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_wgmma.py

61-62: Avoid specifying long messages outside the exception class

(TRY003)


124-125: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/wgmma_macro_generator.py

104-104: Unused method argument: n_dim

(ARG002)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


353-353: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/builtin.py

375-375: Avoid specifying long messages outside the exception class

(TRY003)


377-377: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


378-378: Avoid specifying long messages outside the exception class

(TRY003)


401-401: Avoid specifying long messages outside the exception class

(TRY003)


403-403: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


404-404: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_mma.py

60-60: Unused method argument: layout_map

(ARG002)

tilelang/language/__init__.py

45-45: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/tileop/gemm/__init__.py

120-120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (43)
tilelang/tileop/gemm/gemm_base.py (2)

8-8: Add PrimExpr import for updated clear_accum type annotation.

The import is correctly added to support the type annotation change.


107-108: Update type annotation from bool to PrimExpr.

The change from bool to PrimExpr aligns with the WGMMA backend requirements where clear_accum can be an expression rather than just a boolean value. This provides more flexibility for the accumulation control logic.

src/op/gemm_py.h (1)

112-113: Clean private section reorganization.

The reorganization of private members with explicit private: labeling improves code organization and readability. The GemmInst enum and GetGemmInst method remain appropriately private while allowing FFI access through wrapper functions.

src/transform/lower_device_storage_access_info.cc (1)

47-48: Correct exclusion of descriptor scope from storage lowering.

The addition of scope.tag != ".descriptor" properly excludes descriptor buffers from the generic storage access lowering path, which is necessary for the WGMMA descriptor handling. This change aligns with the corresponding exclusion in storage_rewrite.cc and allows descriptor buffers to maintain their special semantics.

tilelang/language/tir/ir.py (1)

294-295: Add WGMMA PTX intrinsic wrappers.

The new ptx_wgmma_ss and ptx_wgmma_rs wrappers correctly follow the established pattern using _dtype_forward decorator, providing consistent API access to the underlying PTX WGMMA intrinsics for shared-shared and register-shared variants.

tilelang/layout/fragment.py (2)

207-207: Enhanced Fragment representation with shape information.

The updated __repr__ method now includes input and output shapes, providing more comprehensive debugging information. The use of get_input_shape() and get_output_shape() follows the established pattern from the base Layout class.


209-213: Add Fragment equality check method.

The new is_equal method provides a proper way to compare Fragment instances by delegating to the FFI implementation. This aligns with the corresponding method in the base Layout class and enables proper equality testing for Fragment objects.

src/transform/storage_rewrite.cc (2)

677-678: Exclude .descriptor from “special tagged memory” — OK.

This prevents special-merge and memory‑info constraints from applying to descriptor buffers, which is desirable.


847-849: Skip memory‑info path for .descriptor in NewAllocTagMerged — OK; add test.

Change correctly treats descriptor allocations neutrally (GetMemoryInfo is not invoked for .descriptor). Add a minimal IR test that ensures:

  • local.descriptor buffers are not grouped into special merges
  • no GetMemoryInfo lookups occur for .descriptor scopes

Relevant files: src/transform/storage_rewrite.cc (≈lines 846–850), src/transform/lower_device_storage_access_info.cc (≈line 48).

tilelang/language/ast/ir.py (1)

1895-1896: Approve: WGMMA wrappers wired and registered

Wrapper bindings exist in tilelang/language/ast/ir.py (ptx_wgmma_ss/rs), TIR op wrappers call _tvm_op.Op.get("tl.ptx_wgmma_ss/rs") in tilelang/language/tir/op.py, builtins are registered in src/op/builtin.cc/.h (TIR_DEFINE_TL_BUILTIN + declarations) and referenced by codegen in src/target/codegen_cuda.cc — no issues found. Add a small CI check to assert builtin registration if desired.

src/op/builtin.cc (2)

174-183: Descriptor intrinsics API consistent with wrappers.

5 inputs for initialize and 2 for increase align with the Python layer; opaque effect is appropriate given side effects. LGTM.


83-92: WGMMA intrinsics: 15-arg arity verified; codegen enforces 15 operands. The Python wrappers define ptx_wgmma_ss/rs, the builtin registration sets .set_num_inputs(15), and src/target/codegen_cuda.cc contains ICHECK_EQ(op->args.size(), 15U) for both handlers, confirming operand alignment.

src/op/gemm_py.cc (3)

109-111: Unreachable-return safety: fine as-is.

This satisfies compilers post-ICHECK; harmless.


255-257: Unreachable-return safety: fine as-is.

Consistent with the ICHECK case; no functional impact.


284-290: FFI reflection for GemmInst selection: LGTM.

Simple, side‑effect free binding; useful for runtime dispatch.

src/op/gemm.cc (1)

481-481: LGTM! Consistent layout API migration to boolean flags.

The transition from numeric codes (1/2) to boolean flags for the k_inner parameter across all backends (Volta, Ampere/Turing/SM120, Hopper) improves API clarity and type safety.

Also applies to: 507-507, 520-520, 545-545, 547-547, 563-563, 565-565

src/target/codegen_cuda.cc (5)

898-898: Including descriptor scope in pointer access logic.

The addition of "local.descriptor" scope alongside "local.var" is correct for treating descriptors as local objects.


1788-1789: LGTM! Proper allocation of descriptor storage.

The allocation of tl::GmmaDescriptor for descriptor scope is correctly implemented.


1305-1340: Check for additional error conditions in WGMMA intrinsics.

The WGMMA SS intrinsic implementation looks correct, but consider adding validation for the descriptor parameters.

Should we verify that the descriptors are properly initialized before use? Consider adding checks to ensure a_desc and b_desc are valid descriptor handles.


1701-1713: LGTM! Descriptor operations correctly implemented.

The implementation of initialize_descriptor and increase_descriptor_offset intrinsics is correct and properly forwards all parameters to the TL template functions.

Also applies to: 1714-1721


1823-1823: Correct handling of descriptor scope in allocation check.

The condition properly excludes local.descriptor from the unsupported scope error path.

tilelang/language/builtin.py (1)

355-386: LGTM! Well-documented descriptor initialization function.

The initialize_descriptor function is well-implemented with proper type checking, parameter documentation, and error handling.

tilelang/layout/swizzle.py (2)

10-18: LGTM! Well-structured swizzle layout functions.

The updated make_swizzled_layout and new make_wgmma_swizzled_layout functions are properly implemented with clear parameter forwarding to the FFI API.

Also applies to: 22-34


39-109: LGTM! Flexible bank-swizzled layout helpers.

The three bank-swizzled layout functions (full/half/quarter) are well-implemented with flexible argument handling that supports both buffer objects and explicit parameters.

src/tl_templates/cuda/common.h (1)

307-367: LGTM! Well-structured descriptor union implementation.

The GmmaDescriptor union is properly designed with:

  • Multiple access patterns via desc_, reg32_[], and reg16_[]
  • Clear bitfield layout matching the CUDA GMMA descriptor format
  • Proper copy and move semantics
  • Convenient arithmetic operator for offset adjustments
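As a rough illustration of the aliasing the union provides (bit positions and the example descriptor value below are hypothetical, not the actual GMMA field layout), a 64-bit word can be reinterpreted as four 16-bit registers, mirroring the desc_/reg16_[] access pattern and the offset-adjusting operator:

```python
import struct

def desc_as_reg16(desc: int) -> list[int]:
    """View a 64-bit descriptor as four little-endian 16-bit registers."""
    return list(struct.unpack("<4H", struct.pack("<Q", desc)))

def bump_start_address(desc: int, delta: int) -> int:
    """Mimic the union's arithmetic operator: adjust only the low 16-bit
    register (assumed here to hold the start address), leaving the
    remaining fields untouched."""
    regs = desc_as_reg16(desc)
    regs[0] = (regs[0] + delta) & 0xFFFF
    return struct.unpack("<Q", struct.pack("<4H", *regs))[0]

desc = 0x4000_0000_0010_0080  # made-up value for demonstration
assert desc_as_reg16(desc) == [0x0080, 0x0010, 0x0000, 0x4000]
assert desc_as_reg16(bump_start_address(desc, 0x40))[0] == 0x00C0
```

The point is that all three views (desc_, reg32_[], reg16_[]) address the same storage, so a cheap 16-bit add suffices for start-address updates.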
src/layout/layout.cc (2)

461-478: LGTM! Equality check methods properly implemented.

The new Layout_is_equal and Fragment_is_equal FFI bindings correctly expose the underlying equality check functionality with proper node casting.


506-511: LGTM! WGMMA swizzled layout properly wired.

The make_wgmma_swizzled_layout correctly passes through the continuity parameter separately from mat_continuous, which is essential for WGMMA's layout requirements.

tilelang/tileop/gemm/gemm_wgmma.py (1)

36-37: Verify continuity calculation for k-major layouts

tilelang/tileop/gemm/gemm_wgmma.py:36-37,50 — current code: a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp and b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp. Confirm whether the k-major branch should intentionally use self.M/self.N (instead of a K-derived continuity) and whether the 4 * factor is correct; if intentional, add an inline comment explaining the rationale and add a unit test, otherwise correct the formula to derive continuity from K.

tilelang/intrinsics/wgmma_macro_generator.py (1)

220-231: Use 64-bit arithmetic for shared-memory offset calculations.

A_offset / B_offset in tilelang/intrinsics/wgmma_macro_generator.py (lines 220–231) perform multiplications like i * 64 * A_buf.shape[-1] that can overflow 32-bit; ensure intermediate arithmetic uses 64-bit (cast the PrimExprs to int64 or use an explicit int64 cast / bounds check) before multiplying by elems_in_bytes and passing to ptx_wgmma_ss.
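To make the risk concrete (the tile sizes below are hypothetical, chosen only to show the wrap), a plausible worst-case offset product already leaves the signed 32-bit range:

```python
INT32_MAX = 2**31 - 1

# Hypothetical large-tile values: block row index, row stride in elements,
# and bytes per element (fp16).
i, row_stride, elems_in_bytes = 1024, 64 * 16384, 2
offset_bytes = i * row_stride * elems_in_bytes  # exact value in Python

def wrap_i32(x: int) -> int:
    """What a 32-bit signed computation would produce."""
    return (x + 2**31) % 2**32 - 2**31

assert offset_bytes == 2**31            # one past INT32_MAX
assert wrap_i32(offset_bytes) == -2**31  # silent wrap to a negative offset
```

Casting the intermediate PrimExprs to int64 before the multiply avoids this wrap.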

tilelang/tileop/gemm/__init__.py (6)

1-1: LGTM! Addition of IntEnum import is appropriate for the new GemmInst class.

The import is necessary for creating the strongly-typed enumeration that will be used to represent different GEMM instruction types.


10-11: LGTM! Import additions align with the new WGMMA support.

The imports for GemmWGMMA and _ffi_api are necessary for the new architecture-aware dispatch mechanism.


21-24: Good addition of layout_map parameter to support WGMMA requirements.

The parameter addition correctly passes the layout information through to the lower-level implementation.


27-42: LGTM! Well-designed enumeration for GEMM instruction types.

The GemmInst enumeration provides a clean abstraction for different GEMM implementations with convenient helper methods for type checking.
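The pattern reads roughly like the following sketch (member names and values are illustrative; the real definition lives in tilelang/tileop/gemm/__init__.py):

```python
from enum import IntEnum

class GemmInst(IntEnum):
    # Hypothetical values for illustration only.
    MMA = 0
    WGMMA = 1
    MFMA = 2

    def is_mma(self) -> bool:
        return self == GemmInst.MMA

    def is_wgmma(self) -> bool:
        return self == GemmInst.WGMMA

# Dispatch can then branch on the helpers rather than raw ints.
assert GemmInst.WGMMA.is_wgmma() and not GemmInst.WGMMA.is_mma()
```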


71-81: Architecture-aware dispatch implementation looks good.

The new infer_layout and lower methods properly delegate to architecture-specific implementations using a clean dispatch pattern.


100-120: LGTM! Clean dispatch with proper error handling.

The method provides a clear mapping from instruction types to implementation classes with appropriate error handling for unsupported types.

src/target/ptx.cc (7)

149-156: LGTM! Useful helper function for integer type checking.

This function provides a convenient way to check if a PTX data type represents an integer, which is needed for WGMMA scale input validation.


188-197: LGTM! Appropriate layout type conversion helper.

The function correctly maps boolean values to layout types for WGMMA assembly generation.


278-291: Well-structured WGMMA configuration struct.

The struct follows the same pattern as MMAConfig and provides proper equality comparison for configuration validation.


293-720: Comprehensive WGMMA configuration table.

The extensive configuration table covers all supported WGMMA operations including dense and sparse variants across multiple data types. This aligns well with the PR objective of supporting int8, fp8, and tf32 data types.


859-879: Solid configuration validation function.

The function properly validates WGMMA configurations by checking data type compatibility and searching for matching configurations in the valid list.


1119-1121: Fatal error for sparse WGMMA is reasonable.

Since sparse WGMMA is not yet implemented, failing fast with a clear error message is the appropriate approach.


1235-1306: Well-structured WGMMA assembly generation.

The function properly generates PTX assembly for WGMMA operations with support for various configurations including scale parameters and layout options. The template replacement approach is clean and maintainable.
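A toy sketch of the template-replacement approach (the placeholder names and template text are illustrative, not the actual strings in ptx.cc):

```python
def render_wgmma(shape: str, a_dtype: str, b_dtype: str, c_dtype: str,
                 scale_out: bool) -> str:
    """Fill a PTX instruction template by straight string substitution."""
    template = (
        "wgmma.mma_async.sync.aligned.{shape}.{c}.{a}.{b} "
        "{{C}}, desc_a, desc_b, {scale_out};"
    )
    return template.format(shape=shape, a=a_dtype, b=b_dtype, c=c_dtype,
                           scale_out=int(scale_out))

asm = render_wgmma("m64n128k16", "f16", "f16", "f32", scale_out=True)
assert "m64n128k16" in asm and asm.endswith("1;")
```

Because each configuration is validated against the table first, the substitution step itself can stay dumb and maintainable.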

tilelang/language/allocate.py (1)

129-136: New descriptor allocator looks good

Minimal, aligned with ".descriptor" scope handling in codegen. LGTM.

Comment on lines +163 to 167
int element_size, bool k_inner = true);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
int continuity, int element_size,
bool k_inner = true);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,

⚠️ Potential issue

🧩 Analysis chain

Signature change to bool k_inner: audit all call sites

Converting int kfactor → bool k_inner changes semantics; old int literals (e.g., 2) now coerce to true. Please verify all callers pass a boolean with the intended meaning.

Run:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Show call sites with context
rg -nPU -C3 --type=cpp --type=cc --type=cxx --type=h --type=cu '\bmakeGemmABLayout\s*\('
rg -nPU -C3 --type=cpp --type=cc --type=cxx --type=h --type=cu '\bmakeGemmABLayoutHopper\s*\('

# Heuristic: calls with 5+ args where last arg is not clearly true/false or named
rg -nPU --type=cpp --type=cc --type=cxx --type=cu '\bmakeGemm(ABLayout|ABLayoutHopper)\s*\((?:[^,]*,){4}\s*([^,)]+)\)' -r '$1' | grep -v -E '^\s*(true|false)\s*$' || true

Length of output: 26


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "1) Listing raw matches (context):"
if command -v rg >/dev/null 2>&1; then
  rg -n -U -P '\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(' -C3 --hidden || true
else
  grep -nR --line-number -E '\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(' . || true
fi

echo
echo "2) Extracting full call expression and printing the last argument (best-effort):"
python - <<'PY'
import re,subprocess
pat=re.compile(r'\b(makeGemmABLayout(?:Hopper|CDNA)?)\s*\(')
# try rg first, fallback to grep
try:
    out=subprocess.check_output(['rg','-n','-U','-P',r'\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\('],stderr=subprocess.DEVNULL).decode()
except Exception:
    try:
        out=subprocess.check_output(['grep','-nR','-E',r'\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(','./'],stderr=subprocess.DEVNULL).decode()
    except Exception:
        out=''

for line in out.splitlines():
    parts=line.split(':',2)
    if len(parts)<3:
        continue
    path,lineno,content=parts[0],int(parts[1]),parts[2]
    try:
        with open(path,'r',encoding='utf-8') as f:
            data=f.read()
    except Exception:
        continue
    lines=data.splitlines()
    idx=max(0,lineno-1)
    window = '\n'.join(lines[idx:idx+200])
    m=pat.search(window)
    if not m:
        continue
    start_idx=m.end()-1
    depth=0
    end_pos=None
    for i,ch in enumerate(window[start_idx:], start=start_idx):
        if ch=='(':
            depth+=1
        elif ch==')':
            depth-=1
            if depth==0:
                end_pos=i
                break
    call_text = window[m.start(): end_pos+1] if end_pos is not None else window[m.start():m.start()+400]
    inside = call_text[call_text.find('(')+1: (call_text.rfind(')') if call_text.rfind(')')>call_text.find('(') else None )]
    inside = re.sub(r'/\*.*?\*/|//.*?$','',inside,flags=re.S|re.M)
    # split args by commas not inside nested constructs
    args=[]
    cur=''
    depth=0
    for ch in inside:
        if ch in '([{<':
            depth+=1
        elif ch in ')]}>':
            depth-=1
        if ch==',' and depth==0:
            args.append(cur.strip()); cur=''
        else:
            cur+=ch
    if cur.strip():
        args.append(cur.strip())
    last = args[-1] if args else ''
    # simplify last token
    last_token = last.strip().split()[-1] if last.strip() else ''
    print(f"{path}:{lineno}: func={m.group(1)} last_arg_raw='{last}' last_arg_token='{last_token}'")
PY

Length of output: 9413


Critical: callers passing integer k-factor (1/2) must be fixed — makeGemmABLayout/Hopper now takes bool k_inner

src/op/gemm_sp.cc passes integer ternaries (trans_A ? 1 : 2 and trans_B ? 2 : 1) at src/op/gemm_sp.cc:196-198 and src/op/gemm_sp.cc:209-211; both non-zero values now convert to true and break intended semantics. Change these calls to pass an explicit boolean with the intended meaning (e.g., use trans_A/trans_B or (expr)==1), or restore an int overload if 1/2 semantics are required. Other callers (src/op/gemm.cc, src/layout/*.cc/h) were checked and pass booleans.
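The hazard is easy to demonstrate: any nonzero integer coerces to true, so the old 1/2 encoding collapses to a single value (Python shown for brevity; C++ implicit int→bool conversion behaves the same way):

```python
# Old call sites encoded K placement as 1 or 2 -- both are truthy,
# so the distinction is silently lost after the signature change.
old_kfactor_trans = 1
old_kfactor_no_trans = 2

assert bool(old_kfactor_trans) is True
assert bool(old_kfactor_no_trans) is True  # "2" no longer means anything distinct
```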

Comment on lines +231 to 233
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());

🛠️ Refactor suggestion

🧩 Analysis chain

FFI lower signature now includes layout_map; ensure all providers updated.

The C++ call passes T.layout_map. Verify the Python FFI target “tl.gemm_py.lower” accepts (gemm_py, layout_map, target, thread_bounds, thread_var) everywhere to avoid runtime errors.


🏁 Script executed:

#!/bin/bash
# Confirm the new signature at Python side
rg -nP 'def\s+lower\((self|gemm_py)[^)]*layout_map' -C2

# Find the FFI registration or dispatcher for "tl.gemm_py.lower"
rg -n 'gemm_py\.lower' -S

# Check base class signatures to avoid LSP mismatches
rg -nP 'class\s+GemmBase\b.*?\n' -n tilelang/tileop/gemm -C2
rg -nP 'def\s+lower\(' tilelang/tileop/gemm -S

Length of output: 2364


Add layout_map to GemmBase.lower

Change tilelang/tileop/gemm/gemm_base.py:18 from
def lower(self, target, thread_nums, thread_var)
to
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var)
so the base signature matches the FFI tl.gemm_py.lower and existing subclass implementations.

Comment on lines 68 to 86
/*!
* \brief Print WGMMA assembly string given parameters.
* \param shape The shape string mMnNkK
* \param A_layout The layout of multiplicand A, can be either "row" or "col".
* \param B_layout The layout of multiplicand B, can be either "row" or "col".
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
*/
std::string PrintWGMMAAssembly(
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,
const std::string &C_dtype, const std::string &a_desc,
const std::string &A_offset, const std::string &b_desc,
const std::string &B_offset, const std::string &c_ptr,
const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
const bool &scale_in_b, const bool &a_is_shared,
const std::string &metadata, const std::string &metadata_offset,
const std::string &sparsity_selector, bool sparse);

⚠️ Potential issue

Documentation incomplete and misaligned parameters.

The documentation comment only describes the first 6 parameters, but the function signature has 17 parameters. This incomplete documentation makes it difficult to understand the purpose of the remaining parameters.

Please update the documentation to include all parameters:

 /*!
  * \brief Print WGMMA assembly string given parameters.
  * \param shape The shape string mMnNkK
  * \param A_layout The layout of multiplicand A, can be either "row" or "col".
  * \param B_layout The layout of multiplicand B, can be either "row" or "col".
  * \param A_dtype The data type of multiplicand A.
  * \param B_dtype The data type of multiplicand B.
  * \param C_dtype The data type of multiplicand C.
+ * \param a_desc Descriptor for operand A (shared memory descriptor or register pointer).
+ * \param A_offset Offset for operand A.
+ * \param b_desc Descriptor for operand B.
+ * \param B_offset Offset for operand B.
+ * \param c_ptr Pointer to accumulator C.
+ * \param c_offset Offset for accumulator C.
+ * \param scale_out Scaling output flag.
+ * \param scale_in_a Scaling input A flag.
+ * \param scale_in_b Scaling input B flag.
+ * \param a_is_shared Whether operand A is in shared memory.
+ * \param metadata Pointer to metadata buffer (for sparse operations).
+ * \param metadata_offset Offset in metadata buffer.
+ * \param sparsity_selector Sparsity selector for sparse operations.
+ * \param sparse Whether this is a sparse WGMMA operation.
  */
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
/*!
* \brief Print WGMMA assembly string given parameters.
* \param shape The shape string mMnNkK
* \param A_layout The layout of multiplicand A, can be either "row" or "col".
* \param B_layout The layout of multiplicand B, can be either "row" or "col".
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
*/
std::string PrintWGMMAAssembly(
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,
const std::string &C_dtype, const std::string &a_desc,
const std::string &A_offset, const std::string &b_desc,
const std::string &B_offset, const std::string &c_ptr,
const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
const bool &scale_in_b, const bool &a_is_shared,
const std::string &metadata, const std::string &metadata_offset,
const std::string &sparsity_selector, bool sparse);
/*!
* \brief Print WGMMA assembly string given parameters.
* \param shape The shape string mMnNkK
* \param A_layout The layout of multiplicand A, can be either "row" or "col".
* \param B_layout The layout of multiplicand B, can be either "row" or "col".
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
* \param a_desc Descriptor for operand A (shared memory descriptor or register pointer).
* \param A_offset Offset for operand A.
* \param b_desc Descriptor for operand B.
* \param B_offset Offset for operand B.
* \param c_ptr Pointer to accumulator C.
* \param c_offset Offset for accumulator C.
* \param scale_out Scaling output flag.
* \param scale_in_a Scaling input A flag.
* \param scale_in_b Scaling input B flag.
* \param a_is_shared Whether operand A is in shared memory.
* \param metadata Pointer to metadata buffer (for sparse operations).
* \param metadata_offset Offset in metadata buffer.
* \param sparsity_selector Sparsity selector for sparse operations.
* \param sparse Whether this is a sparse WGMMA operation.
*/
std::string PrintWGMMAAssembly(
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,
const std::string &C_dtype, const std::string &a_desc,
const std::string &A_offset, const std::string &b_desc,
const std::string &B_offset, const std::string &c_ptr,
const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
const bool &scale_in_b, const bool &a_is_shared,
const std::string &metadata, const std::string &metadata_offset,
const std::string &sparsity_selector, bool sparse);
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 68 to 86, the function comment documents only
the first six parameters while the declaration has 17; update the block comment
to include @param entries for every parameter in the signature (A_layout,
B_layout — explain what true/false means or that they represent "row"/"col";
a_desc and b_desc — describe descriptor strings for A/B; A_offset and B_offset —
offset expressions or registers for A/B; c_ptr and c_offset — pointer and offset
used for C output; scale_out, scale_in_a, scale_in_b — boolean flags controlling
scaling behavior; a_is_shared — whether A is in shared memory; metadata and
metadata_offset — metadata pointer/string and its offset; sparsity_selector —
sparsity control selector; sparse — whether sparse mode is enabled), keeping
each description concise and aligned with the existing doc style.

src/target/ptx.h Outdated
Comment on lines 78 to 79
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,
🛠️ Refactor suggestion

Inconsistent use of const references for boolean parameters.

The function signature uses const bool & for boolean parameters, which is unnecessary overhead for primitive types. This differs from the existing PrintMMAAssembly function that passes booleans by value.

Apply this diff to pass booleans by value:

-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &A_layout, const bool &B_layout,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool A_layout, bool B_layout,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,
std::string PrintWGMMAAssembly(
const std::string &shape, bool A_layout, bool B_layout,
const std::string &A_dtype, const std::string &B_dtype,
const std::string &C_dtype, const std::string &a_desc,
const std::string &A_offset, const std::string &b_desc,
const std::string &B_offset, const std::string &c_ptr,
const std::string &c_offset, bool scale_out, bool scale_in_a,
bool scale_in_b, bool a_is_shared,
const std::string &metadata, const std::string &metadata_offset,
const std::string &sparsity_selector, bool sparse);
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 78-79, the function signature uses `const bool
&A_layout` and `const bool &B_layout` which is unnecessary for primitive types;
change both parameters to pass by value (`bool A_layout, bool B_layout`) and
update any matching function declarations/definitions and callers to use value
semantics so signatures remain consistent with PrintMMAAssembly.

Comment on lines +405 to +421
template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4;
descriptor.bitfield.layout_type_ = layout_type;
descriptor.bitfield.base_offset_ = 0;
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}

template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
T offset) {
descriptor.reg32_[0] += (offset >> 4);
}
🛠️ Refactor suggestion

🧩 Analysis chain

Consider validating descriptor parameters.

The initialize_descriptor and increase_descriptor_offset functions directly manipulate hardware descriptor fields without validation. Consider adding assertions or checks for:

  • Valid layout_type values (0-3)
  • Byte offset alignment requirements
  • Address alignment to 16-byte boundaries

🏁 Script executed:

#!/bin/bash
# Check if there are any existing validation patterns for descriptor initialization
ast-grep --pattern 'initialize_descriptor($$$) {
  $$$
  assert($$$)
  $$$
}'

# Look for any documentation or comments about descriptor constraints
rg -A 3 -B 3 "descriptor.*align|align.*descriptor" --type cpp --type h

Length of output: 333


Validate GMMA descriptor parameters before writing hardware fields

These functions write descriptor bitfields with no checks; src/target/cuda.h documents a 64‑byte alignment requirement for tensor map descriptors. Add assertions/guards in src/tl_templates/cuda/common.h (initialize_descriptor and increase_descriptor_offset, ~lines 405–421):

  • Assert layout_type is in the supported range (0–3).
  • Assert start_address is 16‑byte aligned (e.g., (uintptr_t)start_address & 0xF == 0).
  • Assert leading_byte_offset, stride_byte_offset, and the offset passed to increase_descriptor_offset are multiples of 16 (since code uses >> 4) and fit the descriptor bitfield widths; also guard against reg32_[0] overflow when adding (offset >> 4).

Files to edit: src/tl_templates/cuda/common.h (initialize_descriptor, increase_descriptor_offset). Reference: src/target/cuda.h (tensor map descriptor alignment comment).

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 405 to 421, add validation before
writing descriptor fields: assert layout_type is within 0..3; assert
start_address is 16-byte aligned ((uintptr_t)start_address & 0xF == 0); assert
leading_byte_offset and stride_byte_offset are multiples of 16 and within the
bitfield width limits documented for GmmaDescriptor; in
increase_descriptor_offset assert offset is a multiple of 16, compute delta =
offset >> 4 and check that descriptor.reg32_[0] + delta will not overflow the
destination bitfield (or reg32_) before applying it; if any check fails, either
trigger a fast-fail assertion or return/handle the error path so invalid values
are never written to hardware fields.
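For illustration only, the requested checks can be modeled outside CUDA. The field layout below (14-bit fields, a `>> 4` encoding, and the layout type in the top bits) is an assumption pieced together from this thread, not the authoritative `GmmaDescriptor` definition:

```python
def make_descriptor(start_address, layout_type, leading_byte_offset, stride_byte_offset):
    # Validations suggested above; field widths and bit positions are
    # assumed for illustration, not taken from the hardware spec.
    assert 0 <= layout_type <= 3, "layout_type must be in 0..3"
    assert start_address % 16 == 0, "start address must be 16-byte aligned"
    for off in (leading_byte_offset, stride_byte_offset):
        assert off % 16 == 0, "byte offsets must be multiples of 16"
        assert (off >> 4) < (1 << 14), "offset does not fit a 14-bit field"
    desc = (start_address >> 4) & 0x3FFF                   # bits 0..13
    desc |= ((leading_byte_offset >> 4) & 0x3FFF) << 16    # bits 16..29
    desc |= ((stride_byte_offset >> 4) & 0x3FFF) << 32     # bits 32..45
    desc |= (layout_type & 0x3) << 62                      # bits 62..63
    return desc

d = make_descriptor(0x100, 1, 32, 128)
assert d & 0x3FFF == 0x100 >> 4
```

A misaligned address (e.g. `0x101`) or an out-of-range `layout_type` now fails fast instead of silently corrupting the descriptor.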

if not isinstance(descriptor, (BufferLoad, Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")

if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
⚠️ Potential issue

Fix operator precedence ambiguity.

The condition mixes and and or operators without proper parentheses, which can lead to unexpected behavior.

Apply this diff to fix the operator precedence:

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):

The same issue exists in increase_descriptor_offset:

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):

Also applies to: 403-403

🧰 Tools
🪛 Ruff (0.12.2)

377-377: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)

🤖 Prompt for AI Agents
In tilelang/language/builtin.py around lines 377 and 403, the conditional mixes
"and" and "or" without parentheses causing ambiguous precedence; wrap the shape
checks in parentheses so the intent is clear — e.g. make the condition check
"isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1)" (and apply the same parenthesized grouping in
increase_descriptor_offset at line 403) to ensure the logical grouping is
correct.
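The pitfall is easy to reproduce in isolation: `and` binds tighter than `or`, so the unparenthesized form shape-checks objects that are not `Buffer`s at all. A minimal sketch with stand-in classes (the real `Buffer`/`BufferLoad` live in `tvm.tir`):

```python
class Buffer:            # stand-in for tvm.tir.Buffer
    def __init__(self, shape):
        self.shape = shape

class BufferLoad:        # stand-in for tvm.tir.BufferLoad; also carries a shape
    def __init__(self, shape):
        self.shape = shape

def bad_check(descriptor):
    # flagged form: parsed as (a and b) or c
    return isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1

def good_check(descriptor):
    # suggested form: a and (b or c)
    return isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1)

load = BufferLoad(shape=(2,))
assert bad_check(load) is True     # wrongly flags a BufferLoad
assert good_check(load) is False   # only Buffers are shape-checked
```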

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/target/ptx.h (1)

29-31: Missing headers for Replacer (compile blocker).

Replacer uses std::vector and std::pair but <vector> (and ideally <utility>) aren’t included.

Apply:

 #include <string>
 #include <tuple>
+#include <vector>
+#include <utility>
src/target/ptx.cc (1)

26-31: Missing <sstream> (compile blocker).

std::stringstream is used but the <sstream> header isn’t included.

Apply:

 #include <algorithm>
+#include <sstream>
 #include <string>
 #include <tuple>
 #include <utility>
 #include <vector>
♻️ Duplicate comments (5)
src/target/ptx.h (2)

156-164: WGMMA Doxygen is incomplete and mismatched with the signature.

Only the first 6 params are documented; names refer to A/B_layout strings while the API uses a_is_k_major/b_is_k_major and many more args.

Recommend adding @param docs for all parameters and aligning names/semantics with the signature (a_is_k_major, b_is_k_major, a_desc, A_offset, b_desc, B_offset, c_ptr, c_offset, scale_out, scale_in_a, scale_in_b, a_is_shared, metadata, metadata_offset, sparsity_selector, sparse). Want a ready-to-apply doc patch?


165-174: Pass booleans by value, not const references.

Primitive bools shouldn’t be passed by const&. Keep consistent with other APIs.

Apply:

-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse);
src/tl_templates/cuda/common.h (3)

438-454: Add basic parameter validation (layout_type range).

Follow-up to earlier review on descriptor validation.

Add: assert(0 <= layout_type && layout_type <= 3).


450-454: increase_descriptor_offset risks overflow/carry into other fields.

Directly adding to reg32_[0] can spill into adjacent fields; enforce 16B granularity and update the bitfield safely.

Apply:

 TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
                                           T offset) {
-  descriptor.reg32_[0] += (offset >> 4);
+  assert((offset % 16) == 0);
+  uint32_t delta = uint32_t(offset >> 4);
+  uint32_t new_sa = uint32_t(descriptor.bitfield.start_address_) + delta;
+  // 14-bit field
+  assert((new_sa & ~0x3FFFu) == 0);
+  descriptor.bitfield.start_address_ = uint16_t(new_sa & 0x3FFFu);
 }

438-448: initialize_descriptor: missing >>4 for byte-offset fields; no validation.

Bitfields exclude 4 LSBs; current code writes raw byte offsets. Also lacks alignment/range checks.

Apply:

 TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                      T *start_address) {
-  descriptor.bitfield.start_address_ =
-      cute::cast_smem_ptr_to_uint(start_address) >> 4;
+  auto sa = cute::cast_smem_ptr_to_uint(start_address);
+  // 16B alignment and 14-bit range
+  assert((sa & 0xF) == 0);
+  descriptor.bitfield.start_address_ = uint16_t((sa >> 4) & 0x3FFF);
   descriptor.bitfield.layout_type_ = layout_type;
   descriptor.bitfield.base_offset_ = 0;
-  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
-  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+  assert((leading_byte_offset % 16) == 0);
+  assert((stride_byte_offset % 16) == 0);
+  descriptor.bitfield.leading_byte_offset_ =
+      uint16_t(((leading_byte_offset >> 4) & 0x3FFF));
+  descriptor.bitfield.stride_byte_offset_ =
+      uint16_t(((stride_byte_offset >> 4) & 0x3FFF));
 }

Also assert layout_type in [0,3].

🧹 Nitpick comments (14)
tilelang/language/tir/op.py (1)

1106-1144: Add a docstring for ptx_wgmma_rs; clarify variant.

Mirror the ss variant doc and note A is a buffer (not a descriptor).

Apply this diff:

 def ptx_wgmma_rs(
     dtype,
     wgmma_prefix,
     a_is_k_major,
     b_is_k_major,
     a_dtype_abbrv,
     b_dtype_abbrv,
     accum_dtype_abbrv,
     A_buf,
     A_offset,
     B_desc,
     B_offset,
     C_data,
     C_offset,
     scale_out,
     scale_in_a,
     scale_in_b,
 ):
-
-
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: rs (A uses buffer pointer, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """

Optional: consider renaming A_buf to A_data or A_ptr for consistency with C_data (only if no keyword callers rely on A_buf).

src/target/ptx.h (1)

45-69: Two DataType enums exist (ptx::DataType here and tl::DataType in common.h).

Risk of drift and conversion friction.

Add explicit conversion helpers and static_assert ordinal equivalence in one place, or consolidate to a single enum declared in a shared header.

src/tl_templates/cuda/common.h (1)

18-19: Redundant SMEM pointer casters.

Both cute::cast_smem_ptr_to_uint and local helpers exist (cast_smem_ptr_to_int/smem_ptr_to_uint).

Pick one idiom and remove duplicates to avoid confusion.

src/tl_templates/cuda/instruction/wgmma.h (1)

22-35: Don’t rely on device printf; fail fast when unspecialized.

Keep printf only under a debug guard; restore a compile-time error otherwise.

Apply:

-    ) {
-        printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, scaleB=%d\n",
-               (int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
-        // Temporarily commented out static_assert to inspect the debug output
-        // static_assert(always_false_v<decltype(c)>,
-        //     "wgmma_ss: No specialization available for given template parameters!");
-    };
+    ) {
+#if defined(TL_DEBUG_WGMMA_FALLBACK)
+        printf("WGMMA fallback A=%d B=%d C=%d M=%d N=%d K=%d tnspA=%d tnspB=%d scaleA=%d scaleB=%d\n",
+               (int)A_type, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
+#else
+        static_assert(always_false_v<decltype(c)>,
+                      "wgmma_ss: No specialization available for given template parameters");
+#endif
+    };
src/target/ptx.cc (2)

171-177: Pass bool by value in LayoutTypeFromBool.

Avoid const& for primitives.

Apply:

-LayoutType LayoutTypeFromBool(const bool &layout) {
+LayoutType LayoutTypeFromBool(bool layout) {

1186-1195: Unnecessary const& on bools in PrintWGMMAAssembly.

Match the header change; pass by value.

Apply:

-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse) {
src/target/codegen_cuda.cc (3)

1306-1317: Fix argument docs for tl::ptx_wgmma_ss branch

The comments don’t match the actual args (dtype is the call’s return type, not in op->args). Please correct to avoid future regressions.

-    // arg 0: dtype
-    // arg 1: shape
-    // arg 2: A_layout
-    // arg 3: B_layout
-    // arg 4: A_dtype
-    // arg 5: B_dtype
-    // arg 6: C_dtype
-    // arg 7: multiplicand_a
-    // arg 8: multiplicand_b
-    // arg 9: accumulator
-    // arg 10: saturate
+    // arg 0: wgmma_prefix (shape string, e.g. "m64n128k32")
+    // arg 1: a_is_k_major (bool)
+    // arg 2: b_is_k_major (bool)
+    // arg 3: a_dtype_abbrv (string)
+    // arg 4: b_dtype_abbrv (string)
+    // arg 5: accum_dtype_abbrv (string)
+    // arg 6: A_desc (descriptor, shared path)
+    // arg 7: A_offset (byte offset)
+    // arg 8: B_desc (descriptor)
+    // arg 9: B_offset (byte offset)
+    // arg 10: C_data (ptr)
+    // arg 11: C_offset (byte offset)
+    // arg 12: scale_out (bool)
+    // arg 13: scale_in_a (bool)
+    // arg 14: scale_in_b (bool)

1364-1375: Align RS branch docs with actual args

Mirror the corrected SS docs for RS; current comments mention “dtype/saturate”, which don’t exist here.

-    // arg 0: dtype
-    // arg 1: shape
-    // arg 2: A_layout
-    // arg 3: B_layout
-    // arg 4: A_dtype
-    // arg 5: B_dtype
-    // arg 6: C_dtype
-    // arg 7: multiplicand_a
-    // arg 8: multiplicand_b
-    // arg 9: accumulator
-    // arg 10: saturate
+    // arg 0: wgmma_prefix (shape)
+    // arg 1: a_is_k_major (bool)
+    // arg 2: b_is_k_major (bool)
+    // arg 3: a_dtype_abbrv (string)
+    // arg 4: b_dtype_abbrv (string)
+    // arg 5: accum_dtype_abbrv (string)
+    // arg 6: A_buf (global pointer, non-shared path)
+    // arg 7: A_offset (byte offset)
+    // arg 8: B_desc (descriptor)
+    // arg 9: B_offset (byte offset)
+    // arg 10: C_data (ptr)
+    // arg 11: C_offset (byte offset)
+    // arg 12: scale_out (bool)
+    // arg 13: scale_in_a (bool)
+    // arg 14: scale_in_b (bool)

1724-1743: Descriptor intrinsics emission — consider explicit offset type and statement-only usage

  • Template parameter for increase_descriptor_offset should be explicit-width to avoid ABI surprises on different toolchains.
  • These are side-effect calls; ensure they’re only used in EvaluateNode contexts.
-    os << "tl::increase_descriptor_offset<int>(" << PrintExpr(descriptor)
+    os << "tl::increase_descriptor_offset<int32_t>(" << PrintExpr(descriptor)
        << ", " << PrintExpr(offset) << ")";

Would you like me to scan call sites to confirm they’re only used as statements?

tilelang/intrinsics/wgmma_macro_generator.py (5)

216-219: Use read access for descriptor base pointers.

These descriptors are read by the instruction; use access_ptr("r"), not "w". Covered in the earlier diffs.

Also applies to: 283-283


104-109: Remove unused n_dim parameter.

_initialize_wgmma_prefix(self, n_dim: int = 16) doesn’t use n_dim (Ruff ARG002). Simplify:

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
@@
-        self._initialize_wgmma_prefix(self.n_dim)
+        self._initialize_wgmma_prefix()

Also applies to: 94-95


333-333: Drop duplicate import of is_fragment.

Already imported at the top; remove the inner import.

-        from tilelang.utils import is_fragment

312-317: Fix docstring: this is a load layout for operand A, not a store layout.

-        Create a layout function for storing MMA results into a fragment buffer.
-        This layout is used in conjunction with `inverse_mma_store_layout` to
-        map fragment indices to threads and local indices.
+        Create a layout describing how to load MMA operand A from a fragment buffer.
+        This layout is used in conjunction with `inverse_mma_load_layout` to
+        map fragment indices to threads and per-thread local indices.

41-59: Clarify swizzle_atom_size() semantics or compute from bytes to avoid confusion.

If you keep this helper, consider defining it as swizzle_byte_size() // 16 (16‑byte atoms), to match descriptor units terminology. Current // 16 on bits is non‑obvious. Usage has been corrected in earlier diffs, but this improves readability:

     def swizzle_atom_size(self) -> int:
-        if self.is_swizzle_32b():
-            return 32 // 16
-        elif self.is_swizzle_64b():
-            return 64 // 16
-        elif self.is_swizzle_128b():
-            return 128 // 16
-        else:
-            return 1
+        # number of 16-byte atoms in the swizzle size (32B→2, 64B→4, 128B→8)
+        return self.swizzle_byte_size() // 16
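A quick sanity check of that definition (pure arithmetic, no TileLang dependency):

```python
def swizzle_atom_size(swizzle_byte_size):
    # number of 16-byte atoms in the swizzle span: 32B -> 2, 64B -> 4, 128B -> 8
    return swizzle_byte_size // 16

assert [swizzle_atom_size(b) for b in (32, 64, 128)] == [2, 4, 8]
```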
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eac5433 and 51fcf15.

📒 Files selected for processing (9)
  • src/op/builtin.h (2 hunks)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/target/ptx.cc (9 hunks)
  • src/target/ptx.h (2 hunks)
  • src/tl_templates/cuda/common.h (5 hunks)
  • src/tl_templates/cuda/gemm.h (1 hunks)
  • src/tl_templates/cuda/instruction/wgmma.h (1 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
  • tilelang/language/tir/op.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (5)
src/tl_templates/cuda/instruction/wgmma.h (2)
src/op/builtin.h (1)
  • tl (22-362)
src/tl_templates/cuda/common.h (5)
  • tl (306-400)
  • DataType (315-360)
  • int (97-100)
  • int (135-142)
  • uint32_t (118-120)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (6)
  • scope (674-678)
  • scope (674-674)
  • n (366-370)
  • n (366-366)
  • n (371-375)
  • n (371-371)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (8)
  • PrintWGMMAAssembly (1186-1257)
  • PrintWGMMAAssembly (1186-1195)
  • ParseMMAShape (140-148)
  • ParseMMAShape (140-140)
  • DTypeEnumToString (107-109)
  • DTypeEnumToString (107-107)
  • DTypeEnumToString (111-113)
  • DTypeEnumToString (111-111)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/tl_templates/cuda/common.h (3)
src/target/ptx.h (1)
  • DataType (45-91)
src/tl_templates/cuda/ldsm.h (12)
  • void (7-14)
  • void (16-23)
  • void (25-33)
  • void (35-42)
  • void (44-52)
  • void (54-62)
  • void (64-70)
  • void (72-79)
  • void (81-89)
  • void (91-98)
  • void (100-108)
  • void (110-119)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
tilelang/intrinsics/wgmma_macro_generator.py (6)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/swizzle.py (3)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/fragment.py (3)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/language/builtin.py (1)
  • initialize_descriptor (355-386)
tilelang/language/tir/op.py (3)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
  • any (1774-1790)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
src/target/ptx.h (1)
src/target/ptx.cc (10)
  • DTypeFromString (54-104)
  • DTypeFromString (54-54)
  • DTypeEnumToString (107-109)
  • DTypeEnumToString (107-107)
  • DTypeEnumToString (111-113)
  • DTypeEnumToString (111-111)
  • ParseMMAShape (140-148)
  • ParseMMAShape (140-140)
  • PrintWGMMAAssembly (1186-1257)
  • PrintWGMMAAssembly (1186-1195)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py

104-104: Unused method argument: n_dim

(ARG002)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


353-353: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (9)
tilelang/language/tir/op.py (1)

1064-1104: Change docstring from WMMA → WGMMA and update link

File: tilelang/language/tir/op.py (ptx_wgmma_ss). Docstring currently says WMMA; it should reference WGMMA (warp-group). Verified tl.ptx_wgmma_* ops are registered and expect 15 args.

 def ptx_wgmma_ss(
 @@
-    """TVM intrinsic for ptx tensor core wmma instructions
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-wmma
-    """
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: ss (A uses descriptor, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """
src/tl_templates/cuda/instruction/wgmma.h (2)

417-452: INT8 coverage seems partial (only N=8/16).

valid_wgmma_configs include N up to 256. If codegen may request N=32/64/… you’ll hit the fallback at compile time.

Would you like me to generate the missing specializations (or a meta-emitter) for s8->s32 shapes N∈{32,64,96,128,192,256}?


457-491: FP8 E4M3→F16/F32: only N=8 variants present.

Configs list N up to 256. Confirm dispatcher never asks for larger N, or add the remaining specializations.

I can auto-generate these blocks to match the config table.

src/tl_templates/cuda/gemm.h (1)

5-6: Include order LGTM.

Including wgmma.h in the sm90+ path before gemm_sm90.h is appropriate.

src/target/ptx.cc (1)

1222-1256: WGMMA operand builder: sparse path abort is fine; ensure scale_in use matches dtype.

Scale inputs are emitted only for non-integer types; that matches spec. No action needed.

Please confirm codegen never requests sparse WGMMA yet; otherwise return a diagnostic instead of LOG(FATAL) at higher layers.

src/tl_templates/cuda/common.h (1)

315-339: Enum duplication with ptx::DataType — parity unverified

No matches found for tvm::tl::codegen::ptx::DataType; ensure src/tl_templates/cuda/common.h (DataType, lines 315–339) has identical ordinals/values to the ptx definition. To verify locally, run:

rg -nP '\benum class DataType\b' -S
rg -nP 'ptx::DataType|namespace\s+ptx' -S
src/target/codegen_cuda.cc (2)

898-901: Local descriptor deref handling in GetBufferRef — LGTM

Returning the scalar vid for "local.descriptor" (same as "local.var") avoids bogus []/casts on tl::GmmaDescriptor. This matches the Python-side Buffer(shape=[1]) abstraction.


1810-1812: Ensure tl::GmmaDescriptor is visible from generated TU

src/target/codegen_cuda.cc already emits "#include <tl_templates/cuda/gemm.h>" in Finish() (decl_stream at src/target/codegen_cuda.cc:186). No declaration of tl::GmmaDescriptor was found in the repo—confirm that gemm.h (or an included header like common.h) actually defines tl::GmmaDescriptor; if it does not, add an explicit include for the header that declares it in Finish().

tilelang/intrinsics/wgmma_macro_generator.py (1)

104-109: Validate inst_n in wgmma_prefix against supported shapes.

Hardware supports specific N sizes (e.g., 8/16/32/64/128). inst_n = block_col_warps * warp_col_tiles may produce unsupported values. Add a guard or normalize.
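One possible shape of such a guard (the N set below is illustrative; the authoritative list is whatever valid_wgmma_configs contains):

```python
SUPPORTED_WGMMA_N = (8, 16, 32, 64, 128, 192, 256)  # illustrative, not the real table

def normalize_inst_n(block_col_warps, warp_col_tiles):
    inst_n = block_col_warps * warp_col_tiles
    if inst_n in SUPPORTED_WGMMA_N:
        return inst_n
    # otherwise fall back to the largest supported N that tiles the request evenly
    for n in sorted(SUPPORTED_WGMMA_N, reverse=True):
        if n <= inst_n and inst_n % n == 0:
            return n
    raise ValueError(f"inst_n={inst_n} cannot be normalized to a supported WGMMA N")

assert normalize_inst_n(2, 64) == 128
assert normalize_inst_n(3, 16) == 16   # 48 -> tiled as 3 x N=16
```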

Comment on lines +389 to +399
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
GmmaDescriptor ret;
ret.reg32_[0] = reg32_[0] + uint32_t(offset);
ret.reg32_[1] = reg32_[1];
return ret;
}
⚠️ Potential issue

operator+ adds raw bytes; descriptor fields are in 16‑byte quanta.

The descriptor encodes addresses/offsets with 4 LSBs stripped. Adding raw offset likely mis-encodes.

Apply:

-  template <typename T>
-  CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
+  template <typename T>
+  CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
     GmmaDescriptor ret;
-    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+    ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
     ret.reg32_[1] = reg32_[1];
     return ret;
   }

And document that offset is in bytes.

📝 Committable suggestion


Suggested change
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
GmmaDescriptor ret;
ret.reg32_[0] = reg32_[0] + uint32_t(offset);
ret.reg32_[1] = reg32_[1];
return ret;
}
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
GmmaDescriptor ret;
ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
ret.reg32_[1] = reg32_[1];
return ret;
}
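The unit mismatch is pure arithmetic: the low field stores `address >> 4` (16-byte quanta), so advancing by `offset` bytes must add `offset >> 4`, not `offset`. A sketch of the two behaviors (the 14-bit field width is assumed from the discussion above):

```python
FIELD_MASK = 0x3FFF  # assumed 14-bit start_address field

def advance_raw(field, offset_bytes):
    # buggy: adds a byte count to a field measured in 16-byte quanta
    return (field + offset_bytes) & FIELD_MASK

def advance_quantized(field, offset_bytes):
    assert offset_bytes % 16 == 0, "descriptor offsets move in 16-byte steps"
    return (field + (offset_bytes >> 4)) & FIELD_MASK

base = 0x100 >> 4                                   # smem address 0x100 as stored in the field
assert advance_quantized(base, 64) == 0x140 >> 4    # lands on 0x140 as intended
assert advance_raw(base, 64) << 4 == 0x500          # decodes to 0x500, far past 0x140
```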

Comment on lines +168 to +169
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"

⚠️ Potential issue

Add divisibility assertion for K micro-tiles.

You iterate range(k_dim // micro_size_k) but don’t enforce divisibility; tails would be silently dropped. Add:

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"

Apply in both SS and RS.

Also applies to: 255-256

🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 168-169 (and
similarly around lines 255-256), you assert k_dim >= micro_size_k but do not
require k_dim to be divisible by micro_size_k, causing any tail to be silently
dropped when iterating range(k_dim // micro_size_k); add an assertion ensuring
k_dim % micro_size_k == 0 with a clear error message (e.g., "k_dim must be
divisible by micro_size_k, got k_dim: {k_dim}, micro_size_k: {micro_size_k}") in
both the SS and RS sections so the code fails fast instead of silently
truncating tails.
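The suggested checks can be exercised standalone; a minimal sketch, borrowing the `k_dim`/`micro_size_k` names from the review (the real code lives in `wgmma_macro_generator.py`):

```python
# Hypothetical standalone version of the suggested fail-fast checks.
def check_k_tiling(k_dim: int, micro_size_k: int) -> int:
    assert k_dim >= micro_size_k, (
        f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}")
    assert k_dim % micro_size_k == 0, (
        f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}")
    return k_dim // micro_size_k  # number of K micro-tiles the loop will emit

# Without the second assert, k_dim=48 with micro_size_k=32 would silently run
# one iteration and drop the 16-element tail instead of failing fast.
```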

Comment on lines +176 to +183
elems_in_bytes = DataType(self.a_dtype).bits // 8

# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix LBO/SBO units and per-operand element sizes (incorrect descriptor math).

  • You compute element size once from A (elems_in_bytes) and reuse it for B. If a_dtype != b_dtype (e.g., int8 × int8, fp8, tf32), B’s descriptor math is wrong.
  • In swizzled MN‑major paths you set LBO to swizzle_atom_size() (already in 16‑B atoms) but still right‑shift by 4 when initializing the descriptor, effectively dividing by 16 twice (LBO=0 for 32B/64B/128B). Use swizzle byte size and keep the >> 4 conversion, or pass atom counts without shifting.

Apply this diff to make LBO/SBO correct and per‑operand:

-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
-                                                                               elems_in_bytes)
-        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim *
+                                                                                 a_elems_in_bytes)
+        a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 *
+                                                                                    a_elems_in_bytes)
@@
-        if not a_swizzle_mode.is_none():
+        if not a_swizzle_mode.is_none():
             # swizzle mode doesn't require LBO/SBO to be 1
             # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
             if a_is_k_major:
                 a_leading_byte_offset = 16
             else:
                 # MN Major
                 # LBO represents the distance between two atoms along the M dimension
                 # SBO represents the distance between two atoms along the K dimension
-                a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size()
-                a_stride_byte_offset = 8 * 64 * elems_in_bytes
+                a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
+                a_stride_byte_offset = 8 * 64 * a_elems_in_bytes
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+                                                                                 b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+                                                                                    b_elems_in_bytes)
         if not b_swizzle_mode.is_none():
             # swizzle mode doesn't require LBO/SBO to be 1
             # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
             if b_is_k_major:
                 b_leading_byte_offset = 16
             else:
                 # MN Major
                 # LBO represents the distance between two atoms along the N dimension
                 # SBO represents the distance between two atoms along the K dimension
-                b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
-                b_stride_byte_offset = 8 * n_dim * elems_in_bytes
+                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+                b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-            T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
+            T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
                                     int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                                   accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4,
-                                   desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data,
+                                   accum_dtype_abbrv, desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+                                   desc_b.data, (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data,

Also applies to: 184-195, 196-211, 216-219, 229-231
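The per-operand sizing and the single 16-byte conversion can be sketched in isolation (illustrative names and shapes, not the module's actual helpers):

```python
# Sketch of the non-swizzled LBO/SBO formulas from the diff above, with
# per-operand element sizes kept separate.
def leading_stride_bytes(elem_bytes: int, k_dim: int, mn_dim: int, is_k_major: bool):
    lbo = 8 * 8 * elem_bytes if is_k_major else 8 * mn_dim * elem_bytes
    sbo = 8 * k_dim * elem_bytes if is_k_major else 8 * 8 * elem_bytes
    return lbo, sbo

def to_descriptor_units(byte_offset: int) -> int:
    # WGMMA descriptors encode offsets in 16-byte units, hence the >> 4;
    # converting to atom units and then shifting again is the double-divide bug.
    return byte_offset >> 4

# fp16 A (2 bytes/elem) vs int8 B (1 byte/elem): reusing A's element size for
# B would double every B offset.
a_lbo, a_sbo = leading_stride_bytes(2, k_dim=64, mn_dim=128, is_k_major=True)
b_lbo, b_sbo = leading_stride_bytes(1, k_dim=64, mn_dim=128, is_k_major=True)
```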

Comment on lines +255 to +256
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"


⚠️ Potential issue

Mirror the fixes in RS path (per‑operand sizes, swizzle LBO units, pointer mode).

Do the same corrections for RS:

  • Use separate a_elems_in_bytes/b_elems_in_bytes.
  • Use swizzle_byte_size() for LBO and keep >>4.
  • Use access_ptr("r") for descriptors.

Also add the divisibility assert (see separate comment). Apply:

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
@@
-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+                                                                                 b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+                                                                                    b_elems_in_bytes)
         if not b_swizzle_mode.is_none():
             # swizzle mode doesn't require LBO/SBO to be 1
             # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
             if b_is_k_major:
                 b_leading_byte_offset = 16
             else:
                 # MN Major
                 # LBO represents the distance between two atoms along the N dimension
                 # SBO represents the distance between two atoms along the K dimension
-                b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
-                b_stride_byte_offset = 8 * n_dim * elems_in_bytes
+                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+                b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                        (B_offset * elems_in_bytes) >> 4,
+                        (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 257-278, 283-284, 301-302

🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 255-256 (and
similarly apply changes to ranges 257-278, 283-284, 301-302), the RS path
currently mirrors the MS fixes incorrectly: replace any single shared
elems_in_bytes with separate a_elems_in_bytes and b_elems_in_bytes computed from
their per‑operand element sizes; for LBO swizzle sizes call
swizzle_byte_size(...) and retain the bitshift >> 4 as before; change descriptor
pointer creation to use access_ptr("r") instead of the current mode; and add the
divisibility assert that ensures k_dim (or relevant dim) is divisible by
micro_size_k (matching the other path). Ensure these edits are applied
consistently across the listed line ranges.

Comment on lines +291 to +295
accum_dtype,
wgmma_prefix,
self.a_transposed,
not self.b_transposed,
a_dtype_abbrv,

⚠️ Potential issue

🧩 Analysis chain

Likely wrong K‑major flags passed to RS intrinsic (inverted vs SS).

ptx_wgmma_rs takes the same a_is_k_major/b_is_k_major semantics as SS. SS passes not self.a_transposed and self.b_transposed, but RS passes their negations. Align them:

-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,

Please confirm against the PTX WGMMA docs for RS.


🌐 Web query:

What are the `a_is_k_major` / `b_is_k_major` semantics for `wgmma.mma_async` RS vs SS variants in NVIDIA PTX?

💡 Result:

Short answer: those flags tell the instruction whether the A / B operand’s shared‑memory descriptor (SS) — or the register‑sourced operand’s logical layout (RS) — is K‑major (i.e. the 8×2 / 2×(N/8) “core‑matrix” tile stride is along the K dimension) or not. Concretely:

  • SS variants (the .ss forms) take 64‑bit matrix descriptors for A and/or B; the descriptor’s “major‑K” bit(s) (a_is_k_major / b_is_k_major) mean “this SMEM tile is laid out K‑major” so the hardware will interpret LBO/SBO in the descriptor accordingly (core matrices are arranged with K as the strided direction). Use these bits when your shared‑memory tile is stored with K as the leading/strided dimension. [1][2]

  • RS variants (the .rs forms) source one operand from registers (or a thread‑local fragment) and the other from a descriptor; for RS the a_is_k_major / b_is_k_major semantic still denotes whether that operand’s logical tile is K‑major, but for the register‑sourced operand the layout interpretation is applied to the fragment/register mapping (i.e. whether the fragment is arranged as K‑major vs M/N‑major). In practice this means you must set the bit to match how your register fragment was produced (or how you packed it), otherwise the multiply/accumulate will use the wrong element ordering. [1][2]

References:

  • PTX WGMMA / wgmma.mma_async description (matrix descriptors, K‑major vs MN‑major and SS/RS usage). [1]
  • CUTLASS / WGMMA notes (explanation of descriptors, LBO/SBO and K‑major layout and the distinction of SS mode requiring descriptors). [2]

[1] NVIDIA PTX ISA (wgmma / wgmma.mma_async matrix descriptor & layout).
[2] CUTLASS / WGMMA tutorial and notes on shared‑memory descriptors and K‑major layout.


Fix inverted K‑major flags for RS WGMMA intrinsic

RS uses the same a_is_k_major / b_is_k_major semantics as SS; the call currently passes their negations. Update the args at tilelang/intrinsics/wgmma_macro_generator.py:291-295:

-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
accum_dtype,
wgmma_prefix,
self.a_transposed,
not self.b_transposed,
a_dtype_abbrv,
accum_dtype,
wgmma_prefix,
(not self.a_transposed),
self.b_transposed,
a_dtype_abbrv,
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 291-295 the K-major
flags for the RS WGMMA intrinsic are passed inverted relative to the SS
semantics; replace the current arguments so the a/b K-major booleans are the
negations of their current values (change self.a_transposed to not
self.a_transposed, and change not self.b_transposed to self.b_transposed) so the
call passes the correct a_is_k_major and b_is_k_major flags.

…e size calculations in wgmma macro generator. Enhanced handling of leading and stride byte offsets based on swizzle mode, improving clarity and performance in tensor core intrinsic emissions.
print(f"b_leading_byte_offset: {b_leading_byte_offset}")
print(f"b_stride_byte_offset: {b_stride_byte_offset}")
# exit()
@T.macro

Bug: Debugging Artifacts in Code

Debugging print statements and a commented-out exit() call were accidentally committed. These temporary artifacts emit debug output at runtime, polluting the console.


// Temporarily comment out the static_assert to inspect debug output
// static_assert(always_false_v<decltype(c)>,
// "wgmma_ss: No specialization available for given template parameters!");
};

Bug: Debug Code and Masked Errors in Template

The WgmmaSSImpl fallback template includes a debug printf statement that pollutes output and a commented-out static_assert. This static_assert (noted by a Chinese comment as temporary debugging code) could mask compilation errors.


class GemmInst(IntEnum):
MMA = 0
WGMMMA = 1
MFMA = 2

Bug: Enum Mismatch and Typo Cause Runtime Errors

The GemmInst Python enum has a value mismatch with its C++ counterpart, leading to runtime errors when C++ returns an undefined enum value (e.g., for MFMA or kUTCMMA). There's also a typo in the WGMMMA enum member, which should be WGMMA.


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (6)
src/target/codegen_cuda.cc (1)

1305-1362: WGMMA-SS: emit validated asm_code; drop ad‑hoc string (invalid arithmetic on descriptor, dead code)

asm_code from PrintWGMMAAssembly(...) is computed and then ignored. The custom wgmma_asm_code does pointer arithmetic on tl::GmmaDescriptor via uint64_t((desc_a) + (A_offset)), which is invalid. Unify with the RS path and stream asm_code.

Apply:

   const bool a_is_shared = true;
   this->PrintIndent();
   std::string asm_code = PrintWGMMAAssembly(
       shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc, A_offset,
       b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
       a_is_shared, "", "", "", false);
-  auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
-  std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
-  // replace patterns
-  tl::codegen::Replacer replacer;
-  replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
-  replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
-  replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
-  replacer.register_rule("(M)", std::to_string(m));
-  replacer.register_rule("(N)", std::to_string(n));
-  replacer.register_rule("(K)", std::to_string(k));
-  replacer.register_rule("(tnspA)", a_is_k_major? "false": "true");
-  replacer.register_rule("(tnspB)", b_is_k_major? "false": "true");
-  replacer.register_rule("(scaleA)", scale_in_a? "1": "-1");
-  replacer.register_rule("(scaleB)", scale_in_b? "1": "-1");
-  replacer.register_rule("(desc_a)", a_desc);
-  replacer.register_rule("(A_offset)", A_offset);
-  replacer.register_rule("(desc_b)", b_desc);
-  replacer.register_rule("(B_offset)", B_offset);
-  replacer.register_rule("(C)", c_ref + " + " + c_offset);
-  replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
-  wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
-  this->stream << wgmma_asm_code;
+  this->stream << asm_code;
tilelang/intrinsics/wgmma_macro_generator.py (5)

159-169: Add validations: m_dim multiple-of-64 and k_dim divisibility by micro_size_k.

Prevents silent no-op loops and tails.

         m_dim = self.block_row_warps * self.warp_row_tiles
         warp_cols = self.warp_cols
         micro_size_k = self.micro_size_k
         k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
         wgmma_prefix = self.wgmma_prefix
@@
-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert m_dim >= 64 and (m_dim % 64) == 0, f"m_dim must be a multiple of 64, got {m_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"

176-219: Fix LBO/SBO units, per-operand element sizes, and descriptor pointer modes (SS path).

Currently mixes A’s element size for B, uses swizzle_atom_size() then >>4 (double-divide by 16), and passes access_ptr("w"). This will generate incorrect descriptors and offsets, especially for int8/fp8/tf32.

-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
-                                                                               elems_in_bytes)
-        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim *
+                                                                                 a_elems_in_bytes)
+        a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 *
+                                                                                    a_elems_in_bytes)
@@
-            else:
+            else:
                 # MN Major
                 # LBO represents the distance between two atoms along the M dimension
                 # SBO represents the distance between two atoms along the K dimension
-                a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size()
-                a_stride_byte_offset = 8 * 64 * elems_in_bytes
+                a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
+                a_stride_byte_offset = 8 * 64 * a_elems_in_bytes
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+                                                                                 b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+                                                                                    b_elems_in_bytes)
         if not b_swizzle_mode.is_none():
@@
-            else:
-                # MN Major, K * N
-                # LBO represents the distance between two atoms along the N dimension
-                # SBO represents the distance between two atoms along the K dimension
-                b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-                if b_n_axis_atoms <= 1:
-                    b_leading_byte_offset = 0
-                else:
-                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-
-                if b_n_axis_atoms <= 1:
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-                else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+            else:
+                # MN Major (N × K): use swizzle byte size for LBO; SBO spans N in bytes.
+                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+                b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-            desc_a = T.alloc_descriptor()
-            desc_b = T.alloc_descriptor()
-            T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
+            desc_a = T.alloc_descriptor()
+            desc_b = T.alloc_descriptor()
+            T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
                                     int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                    B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                    B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
-                                   accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4,
-                                   desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data,
+                                   accum_dtype_abbrv, desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+                                   desc_b.data, (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data,

Also applies to: 239-254


278-279: Add RS-side divisibility assert.

Mirror the SS check to avoid dropped tails.

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"

280-309: Mirror per-operand sizes and descriptor pointer mode (RS path).

Use b_elems_in_bytes and access_ptr("r"); fix LBO/SBO units for swizzle.

-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+                                                                                 b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+                                                                                    b_elems_in_bytes)
@@
-            else:
-                # MN Major
-                # LBO represents the distance between two atoms along the N dimension
-                # SBO represents the distance between two atoms along the K dimension
-                b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-                if b_n_axis_atoms <= 1:
-                    b_leading_byte_offset = 0
-                else:
-                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-
-                if b_n_axis_atoms <= 1:
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-                else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+            else:
+                # MN Major
+                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+                b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                        (B_offset * elems_in_bytes) >> 4,
+                        (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 313-314, 332-332


324-325: Fix inverted K-major flags for RS.

RS uses the same a/b K-major semantics as SS.

-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,
🧹 Nitpick comments (6)
src/target/codegen_cuda.cc (1)

1363-1397: WGMMA-RS path looks good; minor naming/comment nit

Logic mirrors SS but correctly streams asm_code. Consider renaming A_layout/B_layout to a_is_k_major/b_is_k_major and fixing the arg comments (arg 0 is shape; dtype is not an arg) for clarity.

tilelang/intrinsics/wgmma_macro_generator.py (5)

104-109: Remove unused arg and make prefix computation explicit.

n_dim is unused; also guard inst_k with a sanity check (tf32 can be tricky).

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
-        # 256 bits per instruction
-        inst_k = 256 // DataType(self.a_dtype).bits
+        # k derived from input dtype; ensure it's one of the valid WGMMA shapes.
+        bits = DataType(self.a_dtype).bits
+        assert 256 % bits == 0, f"Unsupported dtype bits for WGMMA prefix: {bits}"
+        inst_k = 256 // bits
         self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"
-        self._initialize_wgmma_prefix(self.n_dim)
+        self._initialize_wgmma_prefix()

Please verify tf32’s effective width in your stack (if tvm “tf32” reports 19 bits, switch to an explicit dtype→k map).

Also applies to: 94-95


221-234: Remove debug prints (noise in codegen path).

These spam stdout during TIR lowering; drop or guard behind a verbose flag.

-        print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
-        print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
-
-        print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
-        print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")
-        print(f"m_dim: {m_dim}")
-        print(f"n_dim: {n_dim}")
-        print(f"k_dim: {k_dim}")
-        print(f"micro_size_k: {micro_size_k}")
-        print(f"a_leading_byte_offset: {a_leading_byte_offset}")
-        print(f"a_stride_byte_offset: {a_stride_byte_offset}")
-        print(f"b_leading_byte_offset: {b_leading_byte_offset}")
-        print(f"b_stride_byte_offset: {b_stride_byte_offset}")

342-347: Docstring mismatch: this is a load-layout helper.

Tweak wording to avoid confusion.

-        Create a layout function for storing MMA results into a fragment buffer.
+        Create a layout function for loading operand A into a fragment buffer.

364-367: Drop redundant import inside method.

is_fragment is already imported at module scope.

-        from tilelang.utils import is_fragment

143-143: Tidy exception messages per linter hint (TRY003).

Shorten or use a custom exception type; current messages are fine but flagged.

Also applies to: 384-384

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 51fcf15 and ce9f545.

📒 Files selected for processing (2)
  • src/target/codegen_cuda.cc (5 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (6)
  • scope (674-678)
  • scope (674-674)
  • n (366-370)
  • n (366-366)
  • n (371-375)
  • n (371-371)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (8)
  • PrintWGMMAAssembly (1186-1257)
  • PrintWGMMAAssembly (1186-1195)
  • ParseMMAShape (140-148)
  • ParseMMAShape (140-140)
  • DTypeEnumToString (107-109)
  • DTypeEnumToString (107-107)
  • DTypeEnumToString (111-113)
  • DTypeEnumToString (111-111)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
tilelang/intrinsics/wgmma_macro_generator.py (8)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/layout.py (3)
  • Layout (13-148)
  • is_equal (136-145)
  • inverse (125-134)
tilelang/layout/swizzle.py (3)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/fragment.py (4)
  • is_equal (209-213)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/language/allocate.py (1)
  • alloc_descriptor (129-135)
tilelang/language/builtin.py (1)
  • initialize_descriptor (355-386)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py

104-104: Unused method argument: n_dim

(ARG002)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


384-384: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (5)
src/target/codegen_cuda.cc (5)

898-901: Descriptor refs: correct special‑casing for local.descriptor

Returning the variable directly for descriptor buffers is the right behavior and aligns with how the descriptor is declared.


1721-1721: LGTM

No issues spotted in tl_shuffle_elect emission.


1845-1845: Scope guard update is correct

Skipping the unsupported-scope check for local.descriptor is necessary after introducing descriptor allocations.


1723-1742: Descriptor intrinsics emission LGTM — confirm offset width

initialize_descriptor templating is correct. For increase_descriptor_offset(...), confirm that the API expects a 32-bit offset; a repository search did not locate the declaration. Either switch to std::int32_t (or a wider type) and add an explicit cast to avoid narrowing, or confirm that plain int is intended. Location: src/target/codegen_cuda.cc:1739-1742
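One way to sidestep the narrowing question is to have the codegen emit the cast itself. A sketch of that idea (the emitter function and `tl::` call spelling here are hypothetical, not the actual codegen API):

```python
def emit_increase_descriptor_offset(desc_var: str, offset_expr: str) -> str:
    # Emit an explicit cast so the generated C++ never relies on an implicit
    # int -> int32_t narrowing conversion, regardless of the declared
    # parameter type of increase_descriptor_offset.
    return (f"tl::increase_descriptor_offset("
            f"{desc_var}, static_cast<std::int32_t>({offset_expr}));")


line = emit_increase_descriptor_offset("desc", "k * 1024")
```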


1809-1811: Descriptor allocation: ensure header is available

src/target/codegen_cuda.cc already writes #include <tl_templates/cuda/gemm.h> into the generated declarations, but no definition of tl::GmmaDescriptor was found in the repository. Ensure that gemm.h (or a transitive include) defines tl::GmmaDescriptor; if it does not, add #include <tl_templates/cuda/common.h> in Finish() so the generated file picks up the descriptor definition.

Comment on lines +83 to +85
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,

⚠️ Potential issue

Default warp_row_tiles conflicts with your own assertion (must be ≥16).

Set the default to 16 to avoid constructor-time assertion failures via parent init.

-        warp_row_tiles: int = 8,
+        warp_row_tiles: int = 16,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
warp_row_tiles: int = 16,
warp_col_tiles: int = 8,
chunk: int = 16,
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 83 to 85, the
default value for warp_row_tiles is 8 which conflicts with the class assertion
requiring warp_row_tiles >= 16 and can trigger assertion failures during parent
initialization; change the default warp_row_tiles to 16 (update the parameter
default in the function/class signature) so it satisfies the assertion by
default and avoid constructor-time failures.
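The failure mode is easy to reproduce with a stand-in class (the class and assertion below are illustrative of the emitter's constructor check, not the real `wgmma_macro_generator` code):

```python
class WgmmaEmitterStub:
    """Illustrative stand-in for the WGMMA emitter's constructor check."""

    def __init__(self, warp_row_tiles: int = 8):
        # WGMMA tiles the M dimension in multiples of 16 rows per warp
        # group, so the emitter asserts the minimum at construction time.
        assert warp_row_tiles >= 16, "warp_row_tiles must be >= 16"
        self.warp_row_tiles = warp_row_tiles


# With a default of 8, a no-argument construction trips the assertion:
try:
    WgmmaEmitterStub()
    default_failed = False
except AssertionError:
    default_failed = True

# With the suggested default of 16, construction succeeds:
ok = WgmmaEmitterStub(warp_row_tiles=16)
```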
