Conversation

LeiWang1999
Member

@LeiWang1999 LeiWang1999 commented Oct 1, 2025

Summary by CodeRabbit

  • New Features

    • Replication-aware layout inference with primary/secondary/backup paths and automatic fallback.
    • New public check to detect fully replicated fragments.
    • Expanded diagnostic warnings and debug logging for layout decisions.
  • Bug Fixes

    • Prevents incorrect layout propagation from fully replicated non-reducer buffers.
  • Refactor

    • Buffer validation streamlined to skip non-fragment buffers early.
  • Tests

    • Reduced test workload, simplified imports; many example mains now accept dimension parameters with defaults.

…e in ParallelOpNode

- Introduced IsCompletedReplicated method in FragmentNode to check if a buffer is fully replicated.
- Enhanced InferLayout in ParallelOpNode to handle layout inference for replicated buffers, ensuring only fragment[0] access is allowed.
- Updated error handling for non-zero index access in fragment buffers to improve robustness.
…allel.cc

- Enhanced formatting in FragmentNode's IsCompletedReplicated method for better clarity.
- Updated InferLayout method in ParallelOpNode to improve code readability by adjusting line breaks and indentation.
- Ensured consistent formatting across conditional statements and comments for improved maintainability.
Contributor

coderabbitai bot commented Oct 1, 2025

Walkthrough

Adds FragmentNode::IsCompletedReplicated() and replication-aware layout inference in ParallelOp strict-mode; refactors a Python transform to skip non-fragment buffers; updates many example mains to accept M,N,K params and reduces some test prepare_output usage.

Changes

| Cohort / File(s) | Change summary |
| --- | --- |
| Layout API extension: src/layout/layout.h, src/layout/layout.cc | Adds FragmentNode::IsCompletedReplicated() const which simplifies forward_thread_ and checks deep-equality against ReplicationPlaceholder; adds a clarifying comment. |
| Parallel inference flow updates: src/op/parallel.cc | Extends strict-mode inference to detect fully-replicated local.fragment buffers, create forward vars/index/thread and replication-aware Fragment entries; prefers non-replicated write/read sources, adds replicated-write backup path, free-mode cross-thread/pure-buffer-store checks, expanded logging and adjusted compute_loop_layout_from_buffer logic and fallbacks. |
| Python transform refactor: tilelang/transform/add_bufstore_wrapper.py | Early-continue for non-local.fragment buffers in fragment-buffer validation loop to make control flow explicit while preserving behavior. |
| Examples (GEMM): examples/gemm/example_gemm_intrinsics.py, examples/gemm/example_gemm_persistent.py, examples/gemm/example_gemm_autotune.py, examples/gemm/test_example_gemm.py | Change main() signatures to accept M,N,K with defaults (and rename lowercase params to uppercase in autotune); update callers/tests to pass explicit dimensions. |
| Examples (Warp specialize): examples/warp_specialize/* (multiple files) | Multiple main() entry points now accept M,N,K defaults (various values); tests updated to call with explicit sizes; one decorator formatting/comment removed. |
| Tests / GDN example: examples/gdn/test_example_gdn_compilation.py, examples/warp_specialize/test_example_warp_specialize.py | Reduce test problem sizes (e.g., S from 32768→1024) for faster tests and remove some prepare_output usages in tests; update invocations to pass explicit dimensions. |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller as ParallelOp::InferLayout
  participant Scanner as BufferScanner
  participant Classifier as ReplicaChecker
  participant Layout as compute_loop_layout_from_buffer
  participant FreeMode as FreeModeFallback

  rect rgba(240,248,255,0.9)
  note over Caller,Scanner: Strict-mode scan of fragment buffers
  Caller->>Scanner: enumerate read/write buffers & indices
  Scanner-->>Classifier: buffer entries (scope, indices)
  Classifier-->>Caller: mark fully-replicated (all idx==0 & fragment completed) or non-replicated
  end

  rect rgba(245,255,250,0.9)
  note over Caller: Multi-path source selection
  Caller->>Caller: choose source: 1) non-replicated write, 2) non-replicated read, 3) replicated-write backup
  alt non-replicated write chosen
    Caller->>Layout: compute layout (common path)
  else non-replicated read chosen
    Caller->>Layout: compute layout (read-based)
  else replicated write backup
    Caller->>Layout: attempt replication-aware inference (use forward thread/index)
  else no source
    Caller->>FreeMode: fall back to free-mode inference (cross-thread checks / post-process)
  end
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • chengyupku
  • xysmlx

Poem

I twitch my whiskers at threads that bind,
Copies march where indices find.
I sniff each fragment, forward-thread in tow,
Stitch layouts neat where replicated rows grow.
Hop-hop — a rabbit’s cheer for code that flows. 🐰✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title succinctly highlights the primary change in the strict layout inference path, namely annotating completed replication for fragments with constant indices, and it aligns directly with the implementation of IsCompletedReplicated and the updated kStrict inference logic in the layout module. |
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

Comment @coderabbitai help to get the list of available commands and usage tips.

github-actions bot commented Oct 1, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
src/op/parallel.cc (3)

285-304: Verify the source buffer selection priority logic.

The logic now distinguishes between:

  • Fully replicated reducer write buffers (allowed as sources)
  • Non-replicated write buffers (preferred sources)
  • Fully replicated non-reducer write buffers (backup only)

Line 295: Complex condition - The condition (!is_fully_replicated && source_buffer.defined()) appears to prefer non-replicated buffers over previously selected replicated ones. However, the outer condition already ensures we only enter when !is_fully_replicated || is_reducer, so if we reach line 295 with is_fully_replicated=true, then is_reducer must be true.

Potential issue: If we first encounter a replicated reducer (which sets source_buffer), then encounter a non-replicated non-reducer, the condition at line 295 will override the reducer choice. Is this intended?

Please verify the intended priority order:

  1. Non-replicated write buffer (highest priority)
  2. Replicated reducer write buffer
  3. Non-replicated read buffer
  4. Replicated non-reducer write buffer (backup)

The current logic may not correctly implement this priority if buffers arrive in different orders.

Consider restructuring for clearer priority:

if (buffer_is_write_.count(buffer)) {
  if (!is_fully_replicated) {
    // Non-replicated write buffer: highest priority
    source_buffer = buffer;
  } else if (is_reducer && !source_buffer.defined()) {
    // Replicated reducer: lower priority, only if no source yet
    source_buffer = buffer;
  } else if (!is_reducer && !replicated_write_buffer.defined()) {
    // Replicated non-reducer: backup only
    replicated_write_buffer = buffer;
    DLOG(INFO) << "Found fully replicated non-reducer write buffer ...";
  }
}
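
The order-dependence concern can be checked with a small model. The sketch below is a plain-Python stand-in for the C++ selection loop, not the code in src/op/parallel.cc; `Buf`, `select_sources`, and the boolean fields are hypothetical names chosen for illustration. It follows the restructured branch order suggested above and runs it on two arrival orders.

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple


@dataclass(frozen=True)
class Buf:
    """Hypothetical stand-in for a write buffer seen by the selection loop."""
    name: str
    fully_replicated: bool
    reducer: bool


def select_sources(write_buffers: Iterable[Buf]) -> Tuple[Optional[Buf], Optional[Buf]]:
    """Return (source_buffer, replicated_write_buffer) using the suggested branches."""
    source = None
    backup = None
    for buf in write_buffers:
        if not buf.fully_replicated:
            # Non-replicated write buffer: highest priority, always wins.
            source = buf
        elif buf.reducer and source is None:
            # Replicated reducer: only taken when nothing is selected yet.
            source = buf
        elif not buf.reducer and backup is None:
            # Replicated non-reducer: kept as a backup only.
            backup = buf
    return source, backup


plain = Buf("C_local", fully_replicated=False, reducer=False)
red = Buf("acc", fully_replicated=True, reducer=True)
rep = Buf("scale", fully_replicated=True, reducer=False)

# The same buffers in two different arrival orders.
print(select_sources([red, plain, rep])[0].name)
print(select_sources([plain, red, rep])[0].name)
```

In this model the non-replicated buffer is selected in both orders, and a replicated reducer is only used when no non-replicated write buffer appears at all.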

457-472: Review the backup buffer fallback mechanism.

Lines 457-472 implement a backup path when no standard source buffer is available:

Design question: The code attempts to infer from replicated_write_buffer but catches LayoutConflictException and falls back to free mode. This seems reasonable, but:

  1. Line 470: Silent failure - After catching the exception and clearing replicated_write_buffer, the code continues to line 474's free mode check. This silent fallback might mask configuration issues.

  2. Error propagation - The caught exception message is logged at WARNING level, but the user might not see why inference ultimately uses free mode.

Consider whether the exception should be re-thrown with additional context, or if a more explicit failure mode would be clearer:

} catch (const LayoutConflictException &e) {
  DLOG(WARNING) << "Failed to infer from replicated buffer: " << e.what()
                << ". Falling back to free mode";
  // Optionally: track that we fell back due to conflict
  replicated_write_buffer = Buffer(); // Clear to trigger free mode below
}

This is already implemented correctly, but consider adding a metric or counter for monitoring how often this fallback path is taken in production.
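
A minimal sketch of the monitoring idea, written in Python rather than C++ and using hypothetical names throughout (`LayoutConflictError`, `infer_with_backup`, the `free_mode_fallback` key): the fallback is counted each time the replicated-buffer path raises, so the frequency is visible outside debug logs.

```python
from collections import Counter


class LayoutConflictError(Exception):
    """Hypothetical Python analogue of the C++ LayoutConflictException."""


def infer_with_backup(infer_from_replicated, infer_free_mode, stats):
    """Try the replicated-buffer path; on conflict, count it and use free mode."""
    try:
        return infer_from_replicated()
    except LayoutConflictError as err:
        # Record the fallback so it is visible outside debug logs.
        stats["free_mode_fallback"] += 1
        print(f"Failed to infer from replicated buffer: {err}. Falling back to free mode")
        return infer_free_mode()


def conflicting_inference():
    # Simulates the replicated-buffer path hitting a layout conflict.
    raise LayoutConflictError("conflicting thread mapping")


stats = Counter()
print(infer_with_backup(conflicting_inference, lambda: "free-mode layout", stats))
print(infer_with_backup(lambda: "replicated layout", lambda: "free-mode layout", stats))
print(stats["free_mode_fallback"])
```

Passing the counter in as an argument keeps the sketch free of global state; a real implementation would more likely hook into whatever metrics facility the project already uses.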


216-258: Fix typo in replication comment
Apply this diff to correct the comment in src/op/parallel.cc:

-    // Deduce buffers that shoule be complicated replicated.
+    // Deduce buffers that should be completely replicated.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f09e91e and 6464a9d.

📒 Files selected for processing (4)
  • src/layout/layout.cc (1 hunks)
  • src/layout/layout.h (1 hunks)
  • src/op/parallel.cc (6 hunks)
  • tilelang/transform/add_bufstore_wrapper.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/layout/layout.h (1)
src/layout/layout.cc (2)
  • IsCompletedReplicated (330-334)
  • IsCompletedReplicated (330-330)
src/op/parallel.cc (3)
src/transform/legalize_safe_memory_access.cc (8)
  • buffer (80-88)
  • buffer (80-80)
  • buffer (91-130)
  • buffer (91-92)
  • buffer (236-240)
  • buffer (236-236)
  • buffer (242-245)
  • buffer (242-242)
src/layout/layout.cc (2)
  • Fragment (293-315)
  • Fragment (317-327)
src/op/parallel.h (1)
  • LayoutConflictException (27-34)
🪛 Ruff (0.13.2)
tilelang/transform/add_bufstore_wrapper.py

141-143: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
🔇 Additional comments (5)
tilelang/transform/add_bufstore_wrapper.py (1)

137-143: LGTM: Clearer control flow with early filtering.

The refactor to use early-continue for non-fragment buffers improves readability while preserving the original validation logic for fragment buffers.
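
The early-continue shape can be sketched as follows. This is an illustration only, not the code in add_bufstore_wrapper.py: `FakeBuffer` and `validate_fragment_stores` are hypothetical, and the rejected condition (any non-zero index) is assumed from the fragment[0] rule stated in the PR description.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass(frozen=True)
class FakeBuffer:
    """Hypothetical stand-in for a TIR buffer; only name and scope are modeled."""
    name: str
    scope: str


def validate_fragment_stores(store_indices: Dict[FakeBuffer, Sequence[int]]) -> List[str]:
    """Return the names of validated fragment buffers; raise on a bad access."""
    checked = []
    for buffer, indices in store_indices.items():
        if buffer.scope != "local.fragment":
            # Early-continue: non-fragment buffers are skipped before any checks.
            continue
        if any(index != 0 for index in indices):
            raise ValueError(f"non-zero fragment index on {buffer.name}: {list(indices)}")
        checked.append(buffer.name)
    return checked


frag = FakeBuffer("acc", "local.fragment")
shared = FakeBuffer("A_shared", "shared")
# The shared buffer is ignored even though its index is non-zero.
print(validate_fragment_stores({frag: [0], shared: [3]}))
```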

src/layout/layout.h (1)

104-104: LGTM: Clean API addition.

The new IsCompletedReplicated() method provides a clear interface for detecting fully replicated fragments. The const-correctness and naming are appropriate.

src/layout/layout.cc (1)

329-334: LGTM: Correct implementation of replication detection.

The implementation correctly uses Analyzer::Simplify before comparing with ReplicationPlaceholder(). The use of ExprDeepEqual is appropriate for symbolic expression comparison.

src/op/parallel.cc (2)

474-519: Review the free mode inference path.

The free mode inference (when no source buffer is available) has been refactored:

Good practices observed:

  • Lines 478-480: Correctly uses remapped buffers for vectorization analysis
  • Lines 490-496: Adjusts vector size to ensure divisibility
  • Lines 499-512: Validates against coalesced_width annotation

Line 474: Condition logic - The check !loop_layout_.defined() handles both the case where no source buffer existed and the case where backup buffer inference failed. This is correct.

Completeness: The free mode path appears complete and handles edge cases properly (vectorization, thread bounds, predicates for replication).
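
The two vector-size adjustments mentioned above can be illustrated with a toy model. This is a sketch under stated assumptions, not the logic in src/op/parallel.cc: the halving strategy, the function name `adjust_vector_size`, and the use of `ValueError` are all hypothetical. Only the ideas of making the vector size divide the loop extent and checking it against a coalesced_width annotation come from the review.

```python
from typing import Optional


def adjust_vector_size(vector_size: int, loop_total_size: int,
                       coalesced_width: Optional[int] = None) -> int:
    """Shrink vector_size until it divides the loop extent, then apply the annotation."""
    # Assumed strategy: halve until the vector size divides the total loop size.
    while vector_size > 1 and loop_total_size % vector_size != 0:
        vector_size //= 2
    if coalesced_width is not None:
        # The annotated width must divide the inferred vector size.
        if vector_size % coalesced_width != 0:
            raise ValueError(
                f"vector size {vector_size} is not divisible by "
                f"coalesced width {coalesced_width}")
        vector_size = coalesced_width
    return vector_size


print(adjust_vector_size(8, 12))      # 8 does not divide 12, 4 does
print(adjust_vector_size(8, 64, 4))   # annotation narrows 8 to 4
```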


326-395: No issues with uses_loop_var detection or warning logic.

The uses_loop_var flag only matches variables in loop_vars_, so the rep iteration var cannot trigger a false positive, and the initial warning for fully replicated non-reducers is acceptable since a LayoutConflictException is thrown later when appropriate.

Comment on lines +259 to +272
  auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
    if (buffer.scope() != "local.fragment")
      return false;
    auto frag = T.layout_map[buffer].as<Fragment>().value();
    // buffer indices should be IntImm
    for (const auto &index : indice_map_[buffer]) {
      if (!index.as<IntImmNode>()) {
        return false;
      } else if (index.as<IntImmNode>()->value != 0) {
        LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
      }
    }
    return frag->IsCompletedReplicated();
  };
Contributor

⚠️ Potential issue | 🟠 Major

Review the replication detection logic.

The buffer_is_completed_replicated lambda checks if a buffer is fully replicated by:

  1. Checking scope is "local.fragment"
  2. Verifying all indices are IntImm with value 0
  3. Calling IsCompletedReplicated() on the fragment

Line 268: Error handling concern - The code uses LOG(FATAL) when a non-zero index is found, but this happens inside a lambda that returns bool. This means the function will terminate the program rather than returning false, which seems overly aggressive.

Consider whether LOG(FATAL) is appropriate here. If non-zero indices on replicated buffers should be caught earlier (e.g., in strict mode at lines 248-253), this check may be redundant. Otherwise, consider returning false instead:

       } else if (index.as<IntImmNode>()->value != 0) {
-        LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+        return false;
       }

Alternatively, if this truly represents an invariant violation, add a comment explaining why program termination is necessary.
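
A small Python model of the predicate with the suggested behaviour, where a non-zero constant index yields False instead of terminating. The names are hypothetical (`is_completed_replicated`, and `fragment_is_replicated` standing in for the result of `frag->IsCompletedReplicated()`); integers model IntImm indices and strings model symbolic ones.

```python
from typing import Sequence, Union

Index = Union[int, str]  # int models an IntImm, str models a symbolic index


def is_completed_replicated(scope: str, indices: Sequence[Index],
                            fragment_is_replicated: bool) -> bool:
    """Pure predicate: never aborts, only reports."""
    if scope != "local.fragment":
        return False
    for index in indices:
        if not isinstance(index, int):
            return False  # symbolic index: not a constant access
        if index != 0:
            return False  # suggested change: report rather than abort
    return fragment_is_replicated


print(is_completed_replicated("local.fragment", [0], True))
print(is_completed_replicated("local.fragment", [2], True))
print(is_completed_replicated("shared", [0], True))
```

With this shape the caller decides what a negative answer means, which is what allows the same predicate to be reused outside strict mode.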

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
   auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
     if (buffer.scope() != "local.fragment")
       return false;
     auto frag = T.layout_map[buffer].as<Fragment>().value();
     // buffer indices should be IntImm
     for (const auto &index : indice_map_[buffer]) {
       if (!index.as<IntImmNode>()) {
         return false;
       } else if (index.as<IntImmNode>()->value != 0) {
-        LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+        return false;
       }
     }
     return frag->IsCompletedReplicated();
   };
🤖 Prompt for AI Agents
In src/op/parallel.cc around lines 259 to 272, the lambda
buffer_is_completed_replicated currently calls LOG(FATAL) when an index IntImm
value != 0 which will abort the program; change that LOG(FATAL) to return false
so the lambda simply reports the buffer is not completed replicated (treat
non-zero indices as not replicated) and rely on earlier strict-mode checks for
fatal validation, or if you intend this to be an invariant that must abort,
replace the LOG(FATAL) with a clear comment explaining why termination is
required.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
src/op/parallel.cc (1)

259-272: Reconsider LOG(FATAL) in predicate lambda.

The buffer_is_completed_replicated lambda is used as a predicate to check whether a buffer is fully replicated. However, line 268 calls LOG(FATAL) when encountering a non-zero IntImm index, which will abort the program rather than returning false. This is inconsistent with the function's role as a boolean predicate and overly aggressive for non-strict inference modes (kCommon/kFree).

In strict mode (lines 248-253), non-zero indices already trigger a fatal error with a clear message. This lambda should simply return false for non-zero indices, letting earlier validation catch true invariant violations.

Apply this diff:

       } else if (index.as<IntImmNode>()->value != 0) {
-        LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+        return false;
       }

Based on learnings (past review comment on lines 259-272).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6464a9d and 470c13f.

📒 Files selected for processing (1)
  • src/op/parallel.cc (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/parallel.cc (2)
src/transform/legalize_safe_memory_access.cc (8)
  • buffer (80-88)
  • buffer (80-80)
  • buffer (91-130)
  • buffer (91-92)
  • buffer (236-240)
  • buffer (236-236)
  • buffer (242-245)
  • buffer (242-242)
src/layout/layout.cc (2)
  • Fragment (293-315)
  • Fragment (317-327)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
  • GitHub Check: bot-task
🔇 Additional comments (7)
src/op/parallel.cc (7)

216-257: Strict mode replication handling looks correct.

The logic properly identifies fragment buffers with constant zero indices and creates replicated layouts for them. The fatal error for non-zero indices (lines 249-253) is appropriate for strict mode, as it enforces the invariant that only fragment[0] accesses are allowed in parallel loops.


274-290: LGTM: Fragment buffer collection logic.

The code correctly identifies fragment buffers with constant indices versus those with variable indices, which is essential for determining whether layout propagation should apply.


292-301: Well-reasoned layout propagation guard.

The allow_layout_propgate flag correctly distinguishes scenarios where layout propagation is needed (when fragments have non-constant indices) from those where it isn't (all fragments use constant indices like frag[0]). The inline examples clarify the distinction effectively.
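
One way to picture the flag is as a comparison between all fragment buffers and those accessed only at constant indices. The sketch below assumes that reading of the review; it is not the C++ implementation, and the function name and the int/str encoding of indices are hypothetical.

```python
from typing import Dict, Sequence, Union

Index = Union[int, str]  # int models a constant index, str a loop-variable expression


def allow_layout_propagate(fragment_indices: Dict[str, Sequence[Index]]) -> bool:
    """True when some fragment buffer is accessed with a non-constant index."""
    const_index_buffers = [
        name for name, indices in fragment_indices.items()
        if all(isinstance(i, int) for i in indices)
    ]
    # Propagation is only needed if at least one fragment is not constant-indexed.
    return len(fragment_indices) > len(const_index_buffers)


# Only frag[0]-style accesses: nothing to propagate.
print(allow_layout_propagate({"mask": [0], "scale": [0]}))
# acc[i, j] depends on loop variables: propagation is needed.
print(allow_layout_propagate({"acc": ["i", "j"], "scale": [0]}))
```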


307-338: Source buffer selection logic looks sound.

The updated selection logic appropriately deprioritizes fully replicated buffers when choosing source buffers for layout inference. This aligns with the comment on lines 363-367 describing the preference order.


339-362: LGTM: Enhanced debug logging.

The additions to compute_loop_layout_from_buffer are debug logging statements that improve observability without changing the core logic.


425-470: Free inference fallback logic preserved.

The free inference path (when no source buffer is available) remains intact and correctly handles vectorization constraints and thread bounds. The restructuring doesn't affect this logic.


363-373: Verify layout assignment for constant-index fragments when allow_layout_propgate is false.

The && allow_layout_propgate guards skip both source buffers when every fragment buffer is accessed only at constant indices; confirm that layouts for constant-index-only fragments are always set (e.g., in strict mode) before reaching this code path, or add a comment documenting the intended behavior.

    return {};
  if (level == InferLevel::kStrict) {
    LayoutMap results;
    // Deduce buffers that shoule be complicated replicated.
Contributor

⚠️ Potential issue | 🟡 Minor

Fix typo in comment.

The comment contains a typo: "shoule be complicated" should be "should be completely".

Apply this diff:

-    // Deduce buffers that shoule be complicated replicated.
+    // Deduce buffers that should be completely replicated.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    // Deduce buffers that shoule be complicated replicated.
+    // Deduce buffers that should be completely replicated.
🤖 Prompt for AI Agents
In src/op/parallel.cc around line 218, update the comment text to fix the typo:
replace "Deduce buffers that shoule be complicated replicated." with "Deduce
buffers that should be completely replicated." — ensure exact wording and
punctuation are applied to the comment line only.


// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
Buffer replicated_write_buffer; // Backup: fully replicated write buffer
Contributor

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Unused variable: replicated_write_buffer.

The variable replicated_write_buffer is declared with a comment indicating it's a "backup: fully replicated write buffer", but it's never assigned or used anywhere in the function. This suggests incomplete implementation or dead code.

If this variable is intended for future use, consider adding a TODO comment. Otherwise, remove it:

-  Buffer replicated_write_buffer; // Backup: fully replicated write buffer

Run the following script to confirm it's not used elsewhere:


🏁 Script executed:

#!/bin/bash
# Search for any usage of replicated_write_buffer in the codebase
rg -n "replicated_write_buffer" --type=cpp

Length of output: 97


Remove unused variable

Delete the declaration of replicated_write_buffer at line 305:

- Buffer replicated_write_buffer; // Backup: fully replicated write buffer
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-  Buffer replicated_write_buffer; // Backup: fully replicated write buffer
🤖 Prompt for AI Agents
In src/op/parallel.cc around line 305, the local declaration "Buffer
replicated_write_buffer;" is unused; remove that single declaration line from
the file, verify there are no remaining references to replicated_write_buffer
elsewhere in the file or project, and rebuild to ensure no compile errors from
its removal.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)

325-346: Unused variable and logic inconsistency with comment.

Line 325 computes is_fully_replicated but this variable is never used in the subsequent logic. Additionally, the comment at lines 333–334 states "if the buffer is completed replicated, we don't need to infer the layout from this buffer," but the code does not check is_fully_replicated before updating read_source_buffer at lines 335–339. This means fully replicated buffers may still be selected as read sources, contradicting the stated intent.

Apply this diff to enforce the comment's intent:

       } else {
         // Keep the buffer with largest number of indices
         // (which means the inference based on that buffer is more accurate)
         // as read_source_buffer to get more accurate layout
         // if the buffer is completed replicated, we don't need to infer the
         // layout from this buffer.
-        if ((!read_source_buffer.defined() ||
-             indice_map_[buffer].size() >
-                 indice_map_[read_source_buffer].size())) {
+        if (!is_fully_replicated &&
+            (!read_source_buffer.defined() ||
+             indice_map_[buffer].size() >
+                 indice_map_[read_source_buffer].size())) {
           read_source_buffer = buffer;
         }
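
The effect of the added guard can be seen in a short model. This is a Python sketch with hypothetical names (`pick_read_source`, the tuple layout of `ReadBuffer`), not the C++ loop itself; it keeps the "largest number of indices wins" rule and optionally skips fully replicated buffers.

```python
from typing import Iterable, Optional, Tuple

ReadBuffer = Tuple[str, int, bool]  # (name, number of indices, fully replicated?)


def pick_read_source(read_buffers: Iterable[ReadBuffer],
                     skip_replicated: bool = True) -> Optional[str]:
    """Pick the read buffer with the most indices, optionally ignoring replicated ones."""
    best_name = None
    best_rank = -1
    for name, num_indices, fully_replicated in read_buffers:
        if skip_replicated and fully_replicated:
            # Guard from the suggested diff: replicated buffers carry no layout to infer from.
            continue
        if best_name is None or num_indices > best_rank:
            best_name, best_rank = name, num_indices
    return best_name


reads = [("A_frag", 1, False), ("stats", 2, True)]
print(pick_read_source(reads, skip_replicated=False))  # unguarded behaviour
print(pick_read_source(reads))                         # guarded behaviour
```

Without the guard the replicated buffer wins purely because it has more indices; with it, the non-replicated buffer is selected.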
♻️ Duplicate comments (3)
src/op/parallel.cc (3)

218-218: Fix typo in comment.

The comment still contains the typo flagged in previous reviews: "shoule be complicated" should read "should be completely".

Apply this diff:

-    // Deduce buffers that shoule be complicated replicated.
+    // Deduce buffers that should be completely replicated.

269-282: Replace LOG(FATAL) with return false in boolean lambda.

The buffer_is_completed_replicated lambda is designed to return a boolean indicating whether a buffer meets replication criteria. Using LOG(FATAL) at line 278 aborts the program instead of returning false, which is inconsistent with the lambda's purpose. Non-zero indices should simply be treated as "not replicated" and rely on the strict-mode validation at lines 236–240 for fatal checks.

Apply this diff:

       } else if (index.as<IntImmNode>()->value != 0) {
-        LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+        return false;
       }

315-315: Remove unused variable.

The variable replicated_write_buffer is declared but never assigned or used anywhere in the function, as confirmed by previous reviews.

Apply this diff:

-  Buffer replicated_write_buffer; // Backup: fully replicated write buffer
🧹 Nitpick comments (1)
src/op/parallel.cc (1)

236-240: Consider throwing a structured exception instead of aborting.

Using LOG(FATAL) terminates the entire program. For consistency with other error handling in this file (e.g., LayoutConflictException at line 518), consider throwing a structured exception that can be caught and tested.

Apply this diff:

-            LOG(FATAL)
-                << "Fragment buffer access with non-zero index [" << imm->value
-                << "] is not supported. "
-                << "Only fragment[0] access is allowed within T.Parallel loop.";
+            std::ostringstream oss;
+            oss << "Fragment buffer " << buffer 
+                << " access with non-zero index [" << imm->value
+                << "] is not supported. "
+                << "Only fragment[0] access is allowed within T.Parallel loop.";
+            throw LayoutConflictException(oss.str());
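
The testability argument can be illustrated in Python, where the same contrast exists between exiting the process and raising an exception. The names below (`LayoutConflictError`, `check_fragment_index`, `rejects`) are hypothetical analogues, not part of TileLang.

```python
class LayoutConflictError(Exception):
    """Hypothetical Python analogue of LayoutConflictException."""


def check_fragment_index(buffer_name: str, index: int) -> None:
    """Raise instead of aborting when a fragment is accessed at a non-zero index."""
    if index != 0:
        raise LayoutConflictError(
            f"Fragment buffer {buffer_name} access with non-zero index [{index}] "
            "is not supported. Only fragment[0] access is allowed within T.Parallel loop."
        )


def rejects(buffer_name: str, index: int) -> bool:
    """Small test helper: True if the check raises for this access."""
    try:
        check_fragment_index(buffer_name, index)
    except LayoutConflictError:
        return True
    return False


print(rejects("acc", 0))
print(rejects("acc", 3))
```

Because the failure is an ordinary exception, a unit test can assert on it and continue, which is not possible when the check terminates the process.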
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 470c13f and 3d36cfc.

📒 Files selected for processing (1)
  • src/op/parallel.cc (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/parallel.cc (3)
src/transform/legalize_safe_memory_access.cc (8)
  • buffer (80-88)
  • buffer (80-80)
  • buffer (91-130)
  • buffer (91-92)
  • buffer (236-240)
  • buffer (236-236)
  • buffer (242-245)
  • buffer (242-242)
src/layout/layout.cc (2)
  • Fragment (293-315)
  • Fragment (317-327)
tilelang/layout/fragment.py (1)
  • Fragment (13-207)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
🔇 Additional comments (4)
src/op/parallel.cc (4)

283-311: LGTM: Fragment buffer classification logic is sound.

The code correctly classifies fragment buffers into const-index and all-fragment categories, and the allow_layout_propgate flag accurately reflects whether common layout propagation is needed based on the presence of non-constant indices.


349-371: LGTM: Debug logging added for layout computation.

The addition of debug logs at the entry and exit of compute_loop_layout_from_buffer improves traceability during layout inference without altering the core logic.


373-383: LGTM: Multi-path inference order is well-structured.

The addition of the allow_layout_propgate guard on lines 379 and 382 correctly restricts layout propagation to cases where non-constant index fragment buffers exist, aligning with the logic defined at lines 310–311. The comment at lines 373–377 clearly documents the inference priority.


435-480: LGTM: Free inference fallback correctly guarded.

Wrapping the free inference block with if (!loop_layout_.defined()) ensures it only executes when no layout was inferred from source buffers, preventing unintended overwrites.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (1)

51-51: Consider aligning default dimensions across warp specialization examples.

This module uses M=N=K=1024 as defaults, while example_warp_specialize_gemm_copy_1_gemm_0.py and other related modules use M=N=K=16384. If there's no specific reason for smaller defaults here, consider aligning with the larger dimensions for consistency.
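
One way to keep defaults aligned is a single shared constant that every example's main() reads, while tests always pass explicit small sizes. The sketch below is illustrative only: `DEFAULT_DIM` and the body of `main` are hypothetical, and the FLOP count stands in for compiling and running a kernel.

```python
DEFAULT_DIM = 16384  # hypothetical shared default for all warp-specialize examples


def main(M: int = DEFAULT_DIM, N: int = DEFAULT_DIM, K: int = DEFAULT_DIM) -> int:
    """Placeholder entry point: returns the GEMM FLOP count for the given shape."""
    # A real example would build and launch the kernel here.
    return 2 * M * N * K


def test_example():
    # Tests pass explicit small dimensions so they stay fast.
    assert main(M=128, N=128, K=128) == 2 * 128 ** 3


test_example()
print(main(M=128, N=128, K=128))
```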

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71f98af and 4f0efee.

📒 Files selected for processing (9)
  • examples/gemm/example_gemm_intrinsics.py (2 hunks)
  • examples/gemm/example_gemm_persistent.py (2 hunks)
  • examples/gemm/test_example_gemm.py (1 hunks)
  • examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1 hunks)
  • examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (1 hunks)
  • examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (1 hunks)
  • examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (1 hunks)
  • examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (1 hunks)
  • examples/warp_specialize/test_example_warp_specialize.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (3)
examples/gemm/example_gemm_intrinsics.py (1)
  • main (165-181)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2)
  • main (14-49)
  • main (54-84)
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2)
  • main (18-47)
  • main (52-83)
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (3)
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (2)
  • main (26-57)
  • main (62-100)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2)
  • main (14-49)
  • main (54-84)
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2)
  • main (18-47)
  • main (52-83)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (3)
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (2)
  • main (18-46)
  • main (51-82)
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2)
  • main (18-47)
  • main (52-83)
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (2)
  • main (11-41)
  • main (46-81)
examples/warp_specialize/test_example_warp_specialize.py (4)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2)
  • main (14-49)
  • main (54-84)
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (2)
  • main (18-46)
  • main (51-82)
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2)
  • main (18-47)
  • main (52-83)
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (2)
  • main (11-41)
  • main (46-81)
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (1)
tilelang/jit/kernel.py (1)
  • out_idx (446-447)
examples/gemm/test_example_gemm.py (2)
examples/gemm/example_gemm_intrinsics.py (1)
  • main (165-181)
examples/gemm/example_gemm_autotune.py (2)
  • main (122-145)
  • main (239-262)
examples/gemm/example_gemm_intrinsics.py (4)
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py (1)
  • main (217-219)
examples/gemm_fp8/example_tilelang_gemm_fp8.py (1)
  • main (59-61)
examples/gemm/example_gemm_schedule.py (1)
  • main (43-65)
examples/gemm/example_gemm.py (1)
  • main (30-52)
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (1)
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (2)
  • main (26-57)
  • main (62-100)
examples/gemm/example_gemm_persistent.py (1)
examples/gemm/example_gemm_intrinsics.py (1)
  • main (165-181)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: bot-task
  • GitHub Check: format-check
🔇 Additional comments (14)
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (1)

10-10: LGTM! Formatting improvement.

The consolidation of decorator arguments onto a single line improves code compactness without affecting functionality.

examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (1)

52-52: LGTM! Parameterization improves testability.

The updated signature with defaults (M=16384, N=16384, K=16384) allows callers to override dimensions while preserving backward compatibility.
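A minimal sketch of the pattern this comment describes, assuming a hypothetical stand-in `main` that only reports the problem shape (the real example compiles and runs a TileLang kernel, which is omitted here):

```python
def main(M: int = 16384, N: int = 16384, K: int = 16384):
    """Hypothetical stand-in for an example entry point."""
    # The real example would build and check a GEMM of shape (M, K) x (K, N);
    # here the shape is simply returned so the binding can be inspected.
    return (M, N, K)


# Existing callers that pass nothing still get the original sizes.
default_shape = main()

# A test can shrink the problem to keep runtime low.
test_shape = main(M=1024, N=1024, K=1024)
```

Because every parameter has a default, calls written before the signature change keep working unchanged.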

examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)

54-54: LGTM! Consistent parameterization.

The signature update with M=16384, N=16384, K=16384 aligns with other warp specialization examples and enables dimension customization in tests.

examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (1)

46-46: LGTM! Parameterization enables flexible testing.

The signature change with defaults M=16384, N=16384, K=16384 is consistent with the broader refactoring across warp specialization examples.

examples/warp_specialize/test_example_warp_specialize.py (4)

19-19: LGTM! Test updated to match new signature.

The explicit M=1024, N=1024, K=1024 arguments correctly invoke the parameterized main() and override the default 16384 dimensions for faster test execution.


25-25: LGTM! Consistent test parameterization.

Test invocation updated to explicitly pass M=1024, N=1024, K=1024, matching the pattern across all warp specialization tests.


31-31: LGTM! Test invocation consistent with new signature.

Explicit dimensions (M=1024, N=1024, K=1024) override the module's default 16384 values for efficient test execution.


37-37: LGTM! Test updated to use parameterized entry point.

The explicit M=1024, N=1024, K=1024 arguments ensure consistent test dimensions across all warp specialization modules.

examples/gemm/example_gemm_persistent.py (2)

121-121: LGTM: Signature parameterization enables flexible dimension passing.

The updated signature allows callers to specify M, N, K explicitly while preserving default behavior. This aligns with the broader PR pattern to parameterize GEMM entry points.


153-159: LGTM: Argparse integration enables CLI dimension control.

The command-line defaults (8192) intentionally differ from the function defaults (4096) to provide larger problem sizes when the script is run standalone, while preserving smaller defaults for programmatic invocation.
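The two sets of defaults can coexist as sketched below. This is an illustration only: the flag names `--M`, `--N`, `--K` and the stand-in `main` are assumptions, not the actual contents of example_gemm_persistent.py.

```python
import argparse


def main(M: int = 4096, N: int = 4096, K: int = 4096):
    # Hypothetical stand-in: return the shape instead of launching a kernel.
    return (M, N, K)


def build_parser() -> argparse.ArgumentParser:
    # Assumed flag names; the larger defaults apply only to standalone runs.
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", type=int, default=8192)
    parser.add_argument("--N", type=int, default=8192)
    parser.add_argument("--K", type=int, default=8192)
    return parser


# Programmatic call (e.g. from a test): the function defaults apply.
programmatic = main()

# Standalone run with no flags given: the parser defaults apply.
args = build_parser().parse_args([])
standalone = main(args.M, args.N, args.K)
```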

examples/gemm/test_example_gemm.py (2)

10-10: LGTM: Test updated to pass explicit dimensions.

The test now supplies explicit M=1024, N=1024, K=1024 to match the parameterized main() signature in example_gemm_autotune, using smaller dimensions for faster test execution.


14-14: LGTM: Test updated to pass explicit dimensions.

The test now supplies explicit M=1024, N=1024, K=1024 to match the parameterized main() signature in example_gemm_intrinsics, using smaller dimensions for faster test execution.

examples/gemm/example_gemm_intrinsics.py (2)

165-165: LGTM: Signature parameterization enables flexible dimension passing.

The updated signature allows callers to specify M, N, K explicitly while preserving default behavior. This aligns with the broader PR pattern to parameterize GEMM entry points.


185-185: LGTM: Standalone invocation updated to match parameterized signature.

The explicit dimension passing preserves the original default behavior while aligning with the new signature.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
examples/gemm/example_gemm_autotune.py (1)

280-280: Consider aligning CLI argument names with function parameters.

The function call main(args.m, args.n, args.k, ...) works correctly via positional binding, but there's a naming mismatch: CLI arguments use lowercase (--m, --n, --k) while the function signature now uses uppercase (M, N, K).

For consistency and clarity, consider either:

  1. Updating argparse to use uppercase: parser.add_argument("--M", ...)
  2. Using keyword arguments: main(M=args.m, N=args.n, K=args.k, ...)

Apply this diff if you prefer keyword arguments for explicitness:

-    main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
+    main(M=args.m, N=args.n, K=args.k, use_autotune=args.use_autotune, with_roller=args.with_roller)
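As a sketch of why both call styles behave the same today, the snippet below uses a hypothetical stand-in `main` that echoes its bound arguments. The flag definitions are assumptions (only the attribute names `args.m`, `args.n`, `args.k`, `args.use_autotune`, and `args.with_roller` appear in the review).

```python
import argparse


def main(M=4096, N=4096, K=4096, use_autotune=False, with_roller=False):
    # Hypothetical stand-in that reports how its parameters were bound.
    return {"M": M, "N": N, "K": K,
            "use_autotune": use_autotune, "with_roller": with_roller}


parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=4096)
parser.add_argument("--n", type=int, default=4096)
parser.add_argument("--k", type=int, default=4096)
parser.add_argument("--use_autotune", action="store_true")
parser.add_argument("--with_roller", action="store_true")
args = parser.parse_args(["--m", "256", "--n", "512", "--k", "128", "--use_autotune"])

# Positional binding depends on the parameter order of main().
positional = main(args.m, args.n, args.k, args.use_autotune, args.with_roller)

# Keyword binding names each parameter, so it survives a reordered signature.
keyword = main(M=args.m, N=args.n, K=args.k,
               use_autotune=args.use_autotune, with_roller=args.with_roller)
```

The keyword form costs a longer line but makes the lowercase-to-uppercase mapping visible at the call site.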
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4f0efee and e4b8a95.

📒 Files selected for processing (1)
  • examples/gemm/example_gemm_autotune.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/gemm/example_gemm_autotune.py (3)
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (2)
  • main (294-347)
  • main (393-437)
examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (2)
  • main (314-371)
  • main (473-538)
tilelang/tileop/gemm/gemm_base.py (3)
  • M (33-34)
  • N (37-38)
  • K (41-42)

Comment on lines +239 to +241
def main(M: int = 4096,
         N: int = 4096,
         K: int = 4096,

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Good standardization on uppercase matrix dimensions.

The parameter renaming from lowercase m, n, k to uppercase M, N, K aligns with conventional mathematical notation for matrix dimensions and matches the naming in tilelang/tileop/gemm/gemm_base.py properties.

However, check whether the CLI argument names at lines 266-268 should also be updated to --M, --N, --K for consistency.

Run this script to check how other GEMM examples handle CLI argument naming:


🏁 Script executed:

#!/bin/bash
# Check CLI argument naming patterns across GEMM examples
fd -e py . examples/gemm/ --exec rg -n 'parser\.add_argument.*--(m|n|k|M|N|K)' {}

Length of output: 494


Sync CLI flags with uppercase parameters

Change --m, --n, --k to --M, --N, --K in the parser definitions (lines 266–268) and update the main call to args.M, args.N, args.K (line 280) to match the uppercase signature and align with other GEMM examples.
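A rough sketch of what the suggested change would look like, assuming a hypothetical stand-in `main` and made-up default values. It relies on argparse deriving the attribute name from the long flag with case preserved, so `--M` is read back as `args.M`.

```python
import argparse


def main(M: int = 4096, N: int = 4096, K: int = 4096):
    # Hypothetical stand-in for the example's entry point.
    return (M, N, K)


parser = argparse.ArgumentParser()
# Uppercase flags mirror the uppercase parameter names of main().
parser.add_argument("--M", type=int, default=4096)
parser.add_argument("--N", type=int, default=4096)
parser.add_argument("--K", type=int, default=4096)

args = parser.parse_args(["--M", "1024", "--N", "2048", "--K", "512"])
result = main(args.M, args.N, args.K)
```

Note that renaming the flags changes the script's command-line interface, so any invocation that still passes `--m` would need updating.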

🤖 Prompt for AI Agents
In examples/gemm/example_gemm_autotune.py, around line 239 (the main signature),
the parser code at lines 266–268, and the main invocation at line 280: the CLI
flags currently use lowercase --m/--n/--k while the main function signature
uses uppercase M, N, K. Change the parser.add_argument flags to --M, --N, --K
(and any corresponding dest names, if present) and update the call to main to
pass args.M, args.N, args.K instead of args.m, args.n, args.k, so that the CLI
matches the function signature and the other GEMM examples.

@LeiWang1999 LeiWang1999 merged commit fc4bd45 into tile-ai:main Oct 2, 2025
6 checks passed