[Layout] Strict annotate completed replicated layout for fragment with constant index #929
…e in ParallelOpNode
- Introduced IsCompletedReplicated method in FragmentNode to check if a buffer is fully replicated.
- Enhanced InferLayout in ParallelOpNode to handle layout inference for replicated buffers, ensuring only fragment[0] access is allowed.
- Updated error handling for non-zero index access in fragment buffers to improve robustness.
…allel.cc
- Enhanced formatting in FragmentNode's IsCompletedReplicated method for better clarity.
- Updated InferLayout method in ParallelOpNode to improve code readability by adjusting line breaks and indentation.
- Ensured consistent formatting across conditional statements and comments for improved maintainability.
Walkthrough
Adds FragmentNode::IsCompletedReplicated() and replication-aware layout inference in ParallelOp strict mode; refactors a Python transform to skip non-fragment buffers; updates many example mains to accept M, N, K parameters; and reduces some test prepare_output usage.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Caller as ParallelOp::InferLayout
participant Scanner as BufferScanner
participant Classifier as ReplicaChecker
participant Layout as compute_loop_layout_from_buffer
participant FreeMode as FreeModeFallback
rect rgba(240,248,255,0.9)
note over Caller,Scanner: Strict-mode scan of fragment buffers
Caller->>Scanner: enumerate read/write buffers & indices
Scanner-->>Classifier: buffer entries (scope, indices)
Classifier-->>Caller: mark fully-replicated (all idx==0 & fragment completed) or non-replicated
end
rect rgba(245,255,250,0.9)
note over Caller: Multi-path source selection
Caller->>Caller: choose source: 1) non-replicated write, 2) non-replicated read, 3) replicated-write backup
alt non-replicated write chosen
Caller->>Layout: compute layout (common path)
else non-replicated read chosen
Caller->>Layout: compute layout (read-based)
else replicated write backup
Caller->>Layout: attempt replication-aware inference (use forward thread/index)
else no source
Caller->>FreeMode: fall back to free-mode inference (cross-thread checks / post-process)
end
end
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (3)
src/op/parallel.cc (3)
285-304: Verify the source buffer selection priority logic.
The logic now distinguishes between:
- Fully replicated reducer write buffers (allowed as sources)
- Non-replicated write buffers (preferred sources)
- Fully replicated non-reducer write buffers (backup only)
Line 295: Complex condition - The condition (!is_fully_replicated && source_buffer.defined()) appears to prefer non-replicated buffers over previously selected replicated ones. However, the outer condition already ensures we only enter when !is_fully_replicated || is_reducer, so if we reach line 295 with is_fully_replicated=true, then is_reducer must be true.
Potential issue: If we first encounter a replicated reducer (which sets source_buffer), then encounter a non-replicated non-reducer, the condition at line 295 will override the reducer choice. Is this intended?
Please verify the intended priority order:
- Non-replicated write buffer (highest priority)
- Replicated reducer write buffer
- Non-replicated read buffer
- Replicated non-reducer write buffer (backup)
The current logic may not correctly implement this priority if buffers arrive in different orders.
Consider restructuring for clearer priority:
if (buffer_is_write_.count(buffer)) {
  if (!is_fully_replicated) {
    // Non-replicated write buffer: highest priority
    source_buffer = buffer;
  } else if (is_reducer && !source_buffer.defined()) {
    // Replicated reducer: lower priority, only if no source yet
    source_buffer = buffer;
  } else if (!is_reducer && !replicated_write_buffer.defined()) {
    // Replicated non-reducer: backup only
    replicated_write_buffer = buffer;
    DLOG(INFO) << "Found fully replicated non-reducer write buffer ...";
  }
}
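The four-level priority order listed above can be made independent of buffer arrival order by ranking each candidate and taking the minimum. The following is a minimal Python sketch of that idea, not the PR's C++ code: `BufferInfo`, `priority`, and `select_source` are hypothetical names introduced only for illustration, and the flags stand in for the `is_fully_replicated` / `is_reducer` checks discussed in the review.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class BufferInfo:
    """Hypothetical stand-in for a buffer plus the flags the review discusses."""
    name: str
    is_write: bool
    is_fully_replicated: bool
    is_reducer: bool = False


def priority(buf: BufferInfo) -> int:
    """Lower rank wins; follows the four-level order listed in the review."""
    if buf.is_write and not buf.is_fully_replicated:
        return 0  # non-replicated write buffer (highest priority)
    if buf.is_write and buf.is_reducer:
        return 1  # replicated reducer write buffer
    if not buf.is_write and not buf.is_fully_replicated:
        return 2  # non-replicated read buffer
    if buf.is_write:
        return 3  # replicated non-reducer write buffer (backup)
    return 4  # replicated read buffer: assumed never to be a source


def select_source(buffers: Iterable[BufferInfo]) -> Optional[BufferInfo]:
    """Pick the best-ranked candidate regardless of iteration order."""
    candidates = [b for b in buffers if priority(b) < 4]
    # min() keeps the first-seen buffer among candidates of equal rank.
    return min(candidates, key=priority, default=None)
```

Because the rank depends only on the buffer's own flags, a replicated reducer seen first can no longer be overridden or kept by accident of ordering; the same winner comes out for any permutation of the inputs.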
457-472: Review the backup buffer fallback mechanism.
Lines 457-472 implement a backup path when no standard source buffer is available.
Design question: The code attempts to infer from replicated_write_buffer but catches LayoutConflictException and falls back to free mode. This seems reasonable, but:
- Line 470: Silent failure - After catching the exception and clearing replicated_write_buffer, the code continues to line 474's free mode check. This silent fallback might mask configuration issues.
- Error propagation - The caught exception message is logged at WARNING level, but the user might not see why inference ultimately uses free mode.
Consider whether the exception should be re-thrown with additional context, or if a more explicit failure mode would be clearer:
} catch (const LayoutConflictException &e) {
  DLOG(WARNING) << "Failed to infer from replicated buffer: " << e.what()
                << ". Falling back to free mode";
  // Optionally: track that we fell back due to conflict
  replicated_write_buffer = Buffer(); // Clear to trigger free mode below
}
This is already implemented correctly, but consider adding a metric or counter for monitoring how often this fallback path is taken in production.
216-258: Fix typo in replication comment
Apply this diff to correct the comment in src/op/parallel.cc:
- // Deduce buffers that shoule be complicated replicated.
+ // Deduce buffers that should be completely replicated.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/layout/layout.cc (1 hunks)
- src/layout/layout.h (1 hunks)
- src/op/parallel.cc (6 hunks)
- tilelang/transform/add_bufstore_wrapper.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/layout/layout.h (1)
- src/layout/layout.cc (2): IsCompletedReplicated (330-334), IsCompletedReplicated (330-330)
src/op/parallel.cc (3)
- src/transform/legalize_safe_memory_access.cc (8): buffer (80-88), buffer (80-80), buffer (91-130), buffer (91-92), buffer (236-240), buffer (236-236), buffer (242-245), buffer (242-242)
- src/layout/layout.cc (2): Fragment (293-315), Fragment (317-327)
- src/op/parallel.h (1): LayoutConflictException (27-34)
🪛 Ruff (0.13.2)
tilelang/transform/add_bufstore_wrapper.py
141-143: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (5)
tilelang/transform/add_bufstore_wrapper.py (1)
137-143: LGTM: Clearer control flow with early filtering.
The refactor to use early-continue for non-fragment buffers improves readability while preserving the original validation logic for fragment buffers.
src/layout/layout.h (1)
104-104: LGTM: Clean API addition.
The new IsCompletedReplicated() method provides a clear interface for detecting fully replicated fragments. The const-correctness and naming are appropriate.
src/layout/layout.cc (1)
329-334: LGTM: Correct implementation of replication detection.
The implementation correctly uses Analyzer::Simplify before comparing with ReplicationPlaceholder(). The use of ExprDeepEqual is appropriate for symbolic expression comparison.
src/op/parallel.cc (2)
474-519: Review the free mode inference path.
The free mode inference (when no source buffer is available) has been refactored.
Good practices observed:
- Lines 478-480: Correctly uses remapped buffers for vectorization analysis
- Lines 490-496: Adjusts vector size to ensure divisibility
- Lines 499-512: Validates against coalesced_width annotation
Line 474: Condition logic - The check !loop_layout_.defined() handles both the case where no source buffer existed and the case where backup buffer inference failed. This is correct.
Completeness: The free mode path appears complete and handles edge cases properly (vectorization, thread bounds, predicates for replication).
326-395: No issues with uses_loop_var detection or warning logic.
The uses_loop_var flag only matches variables in loop_vars_, so the rep iteration var cannot trigger a false positive, and the initial warning for fully replicated non-reducers is acceptable since a LayoutConflictException is thrown later when appropriate.
auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
  if (buffer.scope() != "local.fragment")
    return false;
  auto frag = T.layout_map[buffer].as<Fragment>().value();
  // buffer indices should be IntImm
  for (const auto &index : indice_map_[buffer]) {
    if (!index.as<IntImmNode>()) {
      return false;
    } else if (index.as<IntImmNode>()->value != 0) {
      LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
    }
  }
  return frag->IsCompletedReplicated();
};
Review the replication detection logic.
The buffer_is_completed_replicated lambda checks if a buffer is fully replicated by:
- Checking scope is "local.fragment"
- Verifying all indices are IntImm with value 0
- Calling IsCompletedReplicated() on the fragment
Line 268: Error handling concern - The code uses LOG(FATAL) when a non-zero index is found, but this happens inside a lambda that returns bool. This means the function will terminate the program rather than returning false, which seems overly aggressive.
Consider whether LOG(FATAL) is appropriate here. If non-zero indices on replicated buffers should be caught earlier (e.g., in strict mode at lines 248-253), this check may be redundant. Otherwise, consider returning false instead:
} else if (index.as<IntImmNode>()->value != 0) {
- LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+ return false;
}
Alternatively, if this truly represents an invariant violation, add a comment explaining why program termination is necessary.
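The behavior the review asks for (a predicate that reports "not replicated" instead of aborting) can be modeled in a few lines. This is a hedged Python sketch under stated assumptions, not the C++ lambda itself: `is_completed_replicated` is a hypothetical name, indices are modeled as either a constant `int` or a symbolic name (`str`), and `fragment_is_replicated` stands in for the result of `frag->IsCompletedReplicated()`.

```python
from typing import Sequence, Union

# Assumption for this sketch: an index is a constant int or a symbolic name.
Index = Union[int, str]


def is_completed_replicated(scope: str,
                            indices: Sequence[Index],
                            fragment_is_replicated: bool) -> bool:
    """Return True only for fragment buffers accessed solely at constant 0."""
    if scope != "local.fragment":
        return False
    for index in indices:
        if not isinstance(index, int):
            return False  # symbolic (non-constant) index
        if index != 0:
            return False  # non-zero constant: report False, do not abort
    return fragment_is_replicated
```

Keeping the predicate side-effect free leaves the decision to abort with whichever caller owns the invariant, which in the review's reading is the strict-mode check.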
📝 Committable suggestion
auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
  if (buffer.scope() != "local.fragment")
    return false;
  auto frag = T.layout_map[buffer].as<Fragment>().value();
  // buffer indices should be IntImm
  for (const auto &index : indice_map_[buffer]) {
    if (!index.as<IntImmNode>()) {
      return false;
    } else if (index.as<IntImmNode>()->value != 0) {
      return false;
    }
  }
  return frag->IsCompletedReplicated();
};
🤖 Prompt for AI Agents
In src/op/parallel.cc around lines 259 to 272, the lambda
buffer_is_completed_replicated currently calls LOG(FATAL) when an index IntImm
value != 0 which will abort the program; change that LOG(FATAL) to return false
so the lambda simply reports the buffer is not completed replicated (treat
non-zero indices as not replicated) and rely on earlier strict-mode checks for
fatal validation, or if you intend this to be an invariant that must abort,
replace the LOG(FATAL) with a clear comment explaining why termination is
required.
Actionable comments posted: 2
♻️ Duplicate comments (1)
src/op/parallel.cc (1)
259-272: Reconsider LOG(FATAL) in predicate lambda.
The buffer_is_completed_replicated lambda is used as a predicate to check whether a buffer is fully replicated. However, line 268 calls LOG(FATAL) when encountering a non-zero IntImm index, which will abort the program rather than returning false. This is inconsistent with the function's role as a boolean predicate and overly aggressive for non-strict inference modes (kCommon/kFree).
In strict mode (lines 248-253), non-zero indices already trigger a fatal error with a clear message. This lambda should simply return false for non-zero indices, letting earlier validation catch true invariant violations.
Apply this diff:
  } else if (index.as<IntImmNode>()->value != 0) {
-   LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+   return false;
  }
Based on learnings (past review comment on lines 259-272).
📜 Review details
📒 Files selected for processing (1)
- src/op/parallel.cc (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/parallel.cc (2)
- src/transform/legalize_safe_memory_access.cc (8): buffer (80-88), buffer (80-80), buffer (91-130), buffer (91-92), buffer (236-240), buffer (236-236), buffer (242-245), buffer (242-242)
- src/layout/layout.cc (2): Fragment (293-315), Fragment (317-327)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-amd
- GitHub Check: format-check
- GitHub Check: bot-task
🔇 Additional comments (7)
src/op/parallel.cc (7)
216-257: Strict mode replication handling looks correct.
The logic properly identifies fragment buffers with constant zero indices and creates replicated layouts for them. The fatal error for non-zero indices (lines 249-253) is appropriate for strict mode, as it enforces the invariant that only fragment[0] accesses are allowed in parallel loops.
274-290: LGTM: Fragment buffer collection logic.
The code correctly identifies fragment buffers with constant indices versus those with variable indices, which is essential for determining whether layout propagation should apply.
292-301: Well-reasoned layout propagation guard.
The allow_layout_propgate flag correctly distinguishes scenarios where layout propagation is needed (when fragments have non-constant indices) from those where it isn't (all fragments use constant indices like frag[0]). The inline examples clarify the distinction effectively.
307-338: Source buffer selection logic looks sound.
The updated selection logic appropriately deprioritizes fully replicated buffers when choosing source buffers for layout inference. This aligns with the comment on lines 363-367 describing the preference order.
339-362: LGTM: Enhanced debug logging.
The additions to compute_loop_layout_from_buffer are debug logging statements that improve observability without changing the core logic.
425-470: Free inference fallback logic preserved.
The free inference path (when no source buffer is available) remains intact and correctly handles vectorization constraints and thread bounds. The restructuring doesn't affect this logic.
363-373: Verify layout assignment for constant-index fragments when allow_layout_propagate is false.
The && allow_layout_propagate guards skip both source buffers when all index fragments are constant; confirm that layouts for constant-index-only fragments are always set (e.g., in strict mode) before reaching this code path, or add a comment documenting the intended behavior.
  return {};
if (level == InferLevel::kStrict) {
  LayoutMap results;
  // Deduce buffers that shoule be complicated replicated.
Fix typo in comment.
The comment contains a typo: "shoule be complicated" should be "should be completely".
Apply this diff:
- // Deduce buffers that shoule be complicated replicated.
+ // Deduce buffers that should be completely replicated.
📝 Committable suggestion
// Deduce buffers that should be completely replicated.
🤖 Prompt for AI Agents
In src/op/parallel.cc around line 218, update the comment text to fix the typo:
replace "Deduce buffers that shoule be complicated replicated." with "Deduce
buffers that should be completely replicated." — ensure exact wording and
punctuation are applied to the comment line only.
// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
Buffer replicated_write_buffer; // Backup: fully replicated write buffer
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Unused variable: replicated_write_buffer.
The variable replicated_write_buffer is declared with a comment indicating it's a "backup: fully replicated write buffer", but it's never assigned or used anywhere in the function. This suggests incomplete implementation or dead code.
If this variable is intended for future use, consider adding a TODO comment. Otherwise, remove it:
- Buffer replicated_write_buffer; // Backup: fully replicated write buffer
Run the following script to confirm it's not used elsewhere:
🏁 Script executed:
#!/bin/bash
# Search for any usage of replicated_write_buffer in the codebase
rg -n "replicated_write_buffer" --type=cpp
Length of output: 97
Remove unused variable
Delete the declaration of replicated_write_buffer at line 305:
- Buffer replicated_write_buffer; // Backup: fully replicated write buffer
📝 Committable suggestion
Buffer replicated_write_buffer; // Backup: fully replicated write buffer
🤖 Prompt for AI Agents
In src/op/parallel.cc around line 305, the local declaration "Buffer
replicated_write_buffer;" is unused; remove that single declaration line from
the file, verify there are no remaining references to replicated_write_buffer
elsewhere in the file or project, and rebuild to ensure no compile errors from
its removal.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)
325-346
: Unused variable and logic inconsistency with comment.Line 325 computes
is_fully_replicated
but this variable is never used in the subsequent logic. Additionally, the comment at lines 333–334 states "if the buffer is completed replicated, we don't need to infer the layout from this buffer," but the code does not checkis_fully_replicated
before updatingread_source_buffer
at lines 335–339. This means fully replicated buffers may still be selected as read sources, contradicting the stated intent.Apply this diff to enforce the comment's intent:
} else { // Keep the buffer with largest number of indices // (which means the inference based on that buffer is more accurate) // as read_source_buffer to get more accurate layout // if the buffer is completed replicated, we don't need to infer the // layout from this buffer. - if ((!read_source_buffer.defined() || - indice_map_[buffer].size() > - indice_map_[read_source_buffer].size())) { + if (!is_fully_replicated && + (!read_source_buffer.defined() || + indice_map_[buffer].size() > + indice_map_[read_source_buffer].size())) { read_source_buffer = buffer; }
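The selection rule in the diff above (skip fully replicated buffers, otherwise keep the read buffer with the most indices) can be illustrated with a small Python model. This is only a sketch under assumptions: `pick_read_source` is a hypothetical helper, `index_counts` plays the role of `indice_map_[buffer].size()`, and `fully_replicated` is the set of buffers for which `is_fully_replicated` would be true.

```python
from typing import Dict, Optional, Set


def pick_read_source(index_counts: Dict[str, int],
                     fully_replicated: Set[str]) -> Optional[str]:
    """Choose the non-replicated read buffer with the most indices."""
    best: Optional[str] = None
    for name, count in index_counts.items():
        if name in fully_replicated:
            continue  # replicated buffers carry no layout information
        # Strict '>' keeps the first-seen buffer when counts tie.
        if best is None or count > index_counts[best]:
            best = name
    return best
```

With the guard in place, a fully replicated buffer is never chosen even when it has the largest index count, and the function returns `None` when only replicated buffers are available.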
♻️ Duplicate comments (3)
src/op/parallel.cc (3)
218-218: Fix typo in comment.
The comment still contains the typo flagged in previous reviews: "shoule be complicated" should read "should be completely".
Apply this diff:
- // Deduce buffers that shoule be complicated replicated.
+ // Deduce buffers that should be completely replicated.
269-282: Replace LOG(FATAL) with return false in boolean lambda.
The buffer_is_completed_replicated lambda is designed to return a boolean indicating whether a buffer meets replication criteria. Using LOG(FATAL) at line 278 aborts the program instead of returning false, which is inconsistent with the lambda's purpose. Non-zero indices should simply be treated as "not replicated" and rely on the strict-mode validation at lines 236–240 for fatal checks.
Apply this diff:
  } else if (index.as<IntImmNode>()->value != 0) {
-   LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+   return false;
  }
315-315: Remove unused variable.
The variable replicated_write_buffer is declared but never assigned or used anywhere in the function, as confirmed by previous reviews.
Apply this diff:
- Buffer replicated_write_buffer; // Backup: fully replicated write buffer
🧹 Nitpick comments (1)
src/op/parallel.cc (1)
236-240: Consider throwing a structured exception instead of aborting.
Using LOG(FATAL) terminates the entire program. For consistency with other error handling in this file (e.g., LayoutConflictException at line 518), consider throwing a structured exception that can be caught and tested.
Apply this diff:
- LOG(FATAL)
-     << "Fragment buffer access with non-zero index [" << imm->value
-     << "] is not supported. "
-     << "Only fragment[0] access is allowed within T.Parallel loop.";
+ std::ostringstream oss;
+ oss << "Fragment buffer " << buffer
+     << " access with non-zero index [" << imm->value
+     << "] is not supported. "
+     << "Only fragment[0] access is allowed within T.Parallel loop.";
+ throw LayoutConflictException(oss.str());
📜 Review details
📒 Files selected for processing (1)
- src/op/parallel.cc (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/parallel.cc (3)
- src/transform/legalize_safe_memory_access.cc (8): buffer (80-88), buffer (80-80), buffer (91-130), buffer (91-92), buffer (236-240), buffer (236-236), buffer (242-245), buffer (242-242)
- src/layout/layout.cc (2): Fragment (293-315), Fragment (317-327)
- tilelang/layout/fragment.py (1): Fragment (13-207)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (4)
src/op/parallel.cc (4)
283-311: LGTM: Fragment buffer classification logic is sound.
The code correctly classifies fragment buffers into const-index and all-fragment categories, and the allow_layout_propgate flag accurately reflects whether common layout propagation is needed based on the presence of non-constant indices.
349-371: LGTM: Debug logging added for layout computation.
The addition of debug logs at the entry and exit of compute_loop_layout_from_buffer improves traceability during layout inference without altering the core logic.
373-383: LGTM: Multi-path inference order is well-structured.
The addition of the allow_layout_propgate guard on lines 379 and 382 correctly restricts layout propagation to cases where non-constant index fragment buffers exist, aligning with the logic defined at lines 310–311. The comment at lines 373–377 clearly documents the inference priority.
435-480: LGTM: Free inference fallback correctly guarded.
Wrapping the free inference block with if (!loop_layout_.defined()) ensures it only executes when no layout was inferred from source buffers, preventing unintended overwrites.
Actionable comments posted: 0
🧹 Nitpick comments (1)
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (1)
51-51: Consider aligning default dimensions across warp specialization examples.
This module uses M=N=K=1024 as defaults, while example_warp_specialize_gemm_copy_1_gemm_0.py and other related modules use M=N=K=16384. If there's no specific reason for smaller defaults here, consider aligning with the larger dimensions for consistency.
📜 Review details
📒 Files selected for processing (9)
- examples/gemm/example_gemm_intrinsics.py (2 hunks)
- examples/gemm/example_gemm_persistent.py (2 hunks)
- examples/gemm/test_example_gemm.py (1 hunks)
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1 hunks)
- examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (1 hunks)
- examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (1 hunks)
- examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (1 hunks)
- examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (1 hunks)
- examples/warp_specialize/test_example_warp_specialize.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (3)
- examples/gemm/example_gemm_intrinsics.py (1): main (165-181)
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2): main (14-49), main (54-84)
- examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2): main (18-47), main (52-83)
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (3)
- examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (2): main (26-57), main (62-100)
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2): main (14-49), main (54-84)
- examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2): main (18-47), main (52-83)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (3)
- examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (2): main (18-46), main (51-82)
- examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2): main (18-47), main (52-83)
- examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (2): main (11-41), main (46-81)
examples/warp_specialize/test_example_warp_specialize.py (4)
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2): main (14-49), main (54-84)
- examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py (2): main (18-46), main (51-82)
- examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (2): main (18-47), main (52-83)
- examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (2): main (11-41), main (46-81)
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (1)
- tilelang/jit/kernel.py (1): out_idx (446-447)
examples/gemm/test_example_gemm.py (2)
- examples/gemm/example_gemm_intrinsics.py (1): main (165-181)
- examples/gemm/example_gemm_autotune.py (2): main (122-145), main (239-262)
examples/gemm/example_gemm_intrinsics.py (4)
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py (1): main (217-219)
- examples/gemm_fp8/example_tilelang_gemm_fp8.py (1): main (59-61)
- examples/gemm/example_gemm_schedule.py (1): main (43-65)
- examples/gemm/example_gemm.py (1): main (30-52)
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (1)
- examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (2): main (26-57), main (62-100)
examples/gemm/example_gemm_persistent.py (1)
- examples/gemm/example_gemm_intrinsics.py (1): main (165-181)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: bot-task
- GitHub Check: format-check
🔇 Additional comments (14)
examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py (1)
10-10: LGTM! Formatting improvement.
The consolidation of decorator arguments onto a single line improves code compactness without affecting functionality.
examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py (1)
52-52: LGTM! Parameterization improves testability.
The updated signature with defaults (M=16384, N=16384, K=16384) allows callers to override dimensions while preserving backward compatibility.
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
54-54: LGTM! Consistent parameterization.
The signature update with M=16384, N=16384, K=16384 aligns with other warp specialization examples and enables dimension customization in tests.
examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py (1)
46-46: LGTM! Parameterization enables flexible testing.
The signature change with defaults M=16384, N=16384, K=16384 is consistent with the broader refactoring across warp specialization examples.
examples/warp_specialize/test_example_warp_specialize.py (4)
19-19: LGTM! Test updated to match new signature.
The explicit M=1024, N=1024, K=1024 arguments correctly invoke the parameterized main() and override the default 16384 dimensions for faster test execution.
25-25: LGTM! Consistent test parameterization.
Test invocation updated to explicitly pass M=1024, N=1024, K=1024, matching the pattern across all warp specialization tests.
31-31: LGTM! Test invocation consistent with new signature.
Explicit dimensions (M=1024, N=1024, K=1024) override the module's default 16384 values for efficient test execution.
37-37: LGTM! Test updated to use parameterized entry point.
The explicit M=1024, N=1024, K=1024 arguments ensure consistent test dimensions across all warp specialization modules.
examples/gemm/example_gemm_persistent.py (2)
121-121: LGTM: Signature parameterization enables flexible dimension passing.
The updated signature allows callers to specify M, N, K explicitly while preserving default behavior. This aligns with the broader PR pattern to parameterize GEMM entry points.
153-159: LGTM: Argparse integration enables CLI dimension control.
The command-line defaults (8192) intentionally differ from the function defaults (4096) to provide larger problem sizes when the script is run standalone, while preserving smaller defaults for programmatic invocation.
examples/gemm/test_example_gemm.py (2)
10-10: LGTM: Test updated to pass explicit dimensions.
The test now supplies explicit M=1024, N=1024, K=1024 to match the parameterized main() signature in example_gemm_autotune, using smaller dimensions for faster test execution.
14-14: LGTM: Test updated to pass explicit dimensions.
The test now supplies explicit M=1024, N=1024, K=1024 to match the parameterized main() signature in example_gemm_intrinsics, using smaller dimensions for faster test execution.
examples/gemm/example_gemm_intrinsics.py (2)
165-165: LGTM: Signature parameterization enables flexible dimension passing.
The updated signature allows callers to specify M, N, K explicitly while preserving default behavior. This aligns with the broader PR pattern to parameterize GEMM entry points.
185-185: LGTM: Standalone invocation updated to match parameterized signature.
The explicit dimension passing preserves the original default behavior while aligning with the new signature.
Actionable comments posted: 1
🧹 Nitpick comments (1)
examples/gemm/example_gemm_autotune.py (1)
280-280: Consider aligning CLI argument names with function parameters.
The function call main(args.m, args.n, args.k, ...) works correctly via positional binding, but there's a naming mismatch: CLI arguments use lowercase (--m, --n, --k) while the function signature now uses uppercase (M, N, K).
For consistency and clarity, consider either:
- Updating argparse to use uppercase: parser.add_argument("--M", ...)
- Using keyword arguments: main(M=args.m, N=args.n, K=args.k, ...)
Apply this diff if you prefer keyword arguments for explicitness:
- main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
+ main(M=args.m, N=args.n, K=args.k, use_autotune=args.use_autotune, with_roller=args.with_roller)
📜 Review details
📒 Files selected for processing (1)
- examples/gemm/example_gemm_autotune.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/gemm/example_gemm_autotune.py (3)
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (2): main (294-347), main (393-437)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (2): main (314-371), main (473-538)
- tilelang/tileop/gemm/gemm_base.py (3): M (33-34), N (37-38), K (41-42)
def main(M: int = 4096,
         N: int = 4096,
         K: int = 4096,
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Good standardization on uppercase matrix dimensions.
The parameter renaming from lowercase m, n, k to uppercase M, N, K aligns with conventional mathematical notation for matrix dimensions and matches the naming in tilelang/tileop/gemm/gemm_base.py properties.
However, verify that the CLI argument names at lines 266-268 shouldn't also be updated to --M, --N, --K for consistency.
Run this script to check how other GEMM examples handle CLI argument naming:
🏁 Script executed:
#!/bin/bash
# Check CLI argument naming patterns across GEMM examples
fd -e py . examples/gemm/ --exec rg -n 'parser\.add_argument.*--(m|n|k|M|N|K)' {}
Length of output: 494
Sync CLI flags with uppercase parameters
Change --m, --n, --k to --M, --N, --K in the parser definitions (lines 266–268) and update the main call to args.M, args.N, args.K (line 280) to match the uppercase signature and align with other GEMM examples.
🤖 Prompt for AI Agents
In examples/gemm/example_gemm_autotune.py around lines 239 and specifically
update parser code at lines 266–268 and the main invocation at line 280: the CLI
flags currently use lowercase --m/--n/--k but the main function signature
expects uppercase M, N, K; change the parser.add_argument flags to use --M, --N,
--K (and corresponding dest/argument names if present) and update the call to
main to pass args.M, args.N, args.K instead of args.m, args.n, args.k so the CLI
matches the function signature and other GEMM examples.