[Layout] Introduce Flexible Parallel to Support T.serial and local buffers inside T.Parallel loop #844
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a PassContext toggle to disable Hopper WGMMA and wires it into GEMM selection; broadens loop/layout handling to support non-parallel inner loops with conflict detection and richer debug logging; tightens reducer annotations to parallel loops; lowers verbosity of some logs; rewrites AddWrapperForSingleBufStore and updates a small example.
Sequence Diagram(s)

sequenceDiagram
autonumber
participant Caller as GEMM selector
participant Ctxt as PassContext
participant Target
participant Config as kDisableWGMMA
Caller->>Target: Query TargetIsHopper()
alt Hopper
Caller->>Ctxt: GetConfig(Config, Optional<Bool>())
Ctxt-->>Caller: value_or(false)
alt Config == true
Note over Caller: allow_wgmma = false
else Config == false
Caller->>Caller: Evaluate M>=64 && (num_warps % 4 == 0)
Note over Caller: allow_wgmma = true iff constraints satisfied
end
else Non‑Hopper
Note over Caller: allow_wgmma = false
end
sequenceDiagram
autonumber
participant IL as InferInFreeMode
participant Comp as Component
participant State as InferenceState
participant Exec as RunInference
IL->>Comp: Iterate components
loop For each candidate root in component
IL->>State: Backup layout maps & lists
IL->>Exec: Run inference from candidate root
alt Exception thrown
Exec-->>IL: LayoutConflict/NormalizeIterException
IL->>State: Restore backup
Note over IL: Mark attempt invalid
else Success
Exec-->>IL: Updates + reg usage
IL->>IL: Track best plan (min reg usage)
IL->>State: Restore backup
end
end
IL->>State: Apply best plan for component
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format checks before review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the compiler's ability to handle complex loop structures by introducing "Flexible Parallel" support. This allows for the proper handling and layout inference of T.serial loops and local buffers nested within T.Parallel constructs, which was previously a limitation. Alongside this, a new configuration option has been added to enable or disable Hopper WGMMA instructions, offering more granular control over hardware-specific optimizations. The changes also include substantial additions to debug logging within the layout inference and memory allocation passes, making it easier to trace and understand the compiler's decisions.
Highlights
- Flexible Parallel Loop Support: Introduced mechanisms to allow T.serial loops and local buffers to be used effectively within T.Parallel loops, enhancing the flexibility of parallelization and improving how inner loop variables are tracked during layout inference.
- WGMMA Control: Added a new pass configuration option (tl.disable_wgmma) to explicitly disable the use of Hopper WGMMA instructions, providing finer control over hardware-specific optimizations and removing the kUTCMMA instruction type.
- Layout Inference Enhancements: Improved the layout inference process by refining the selection of optimal layouts based on register usage, adding extensive debug logging, and modifying the ReducerLayoutAnnotator to specifically target parallel loops.
- Codebase Refinements: Updated the .clang-tidy configuration to disable the clang-analyzer-deadcode.DeadStores check and converted several LOG(INFO) statements to DLOG(INFO) for better control over debug output in various transformation passes.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/op/gemm.h (1)

Lines 149-151: SEqualReduce compares offset_A to other->offset_B. Likely a copy/paste bug; compare offset_A with other->offset_A.

- equal(offset_A, other->offset_B) &&
+ equal(offset_A, other->offset_A) && equal(offset_B, other->offset_B) &&

src/op/gemm_py.h (1)

Lines 71-73: SEqualReduce compares offset_A to other->offset_B. Same bug as GemmNode.

- equal(offset_A, other->offset_B) &&
+ equal(offset_A, other->offset_A) && equal(offset_B, other->offset_B) &&
🧹 Nitpick comments (9)
.clang-tidy (1)
Line 45: Don't globally disable DeadStores. This hides real bugs; prefer targeted suppressions (NOLINT, [[maybe_unused]]) or narrow path scoping.
- -clang-analyzer-deadcode.DeadStores,
If you must keep it off, confirm it’s only for known false positives and document rationale in the PR.
src/transform/storage_rewrite.cc (1)
Lines 1792-1794: DLOG addition OK; add explicit logging include. Avoid relying on transitive includes for DLOG.

 @@
 #include <tvm/tir/transform.h>
+#include <tvm/runtime/logging.h>
Please confirm this file builds cleanly with -Werror if transitive includes change.
src/op/parallel.h (1)
Lines 143-144: inner_vars_ lifecycle and equality. Ensure inner_vars_ is rebuilt per traversal (no stale state) and doesn't affect semantic equality; if it does, include it in SEqual/Hash or compute on-demand only.

Add a brief docstring:

- // The inner_vars_
- Map<Var, IterVar> inner_vars_;
+ // Mapping of non-parallel (serial) loop vars to ordered IterVars collected during traversal.
+ Map<Var, IterVar> inner_vars_;

src/transform/merge_shared_memory_allocations.cc (1)
Line 643: Gate debug logs with verbose flag and unify style. Wrap DLOGs with verbose_ for consistency with later LOG(DEBUG) usage.

- DLOG(INFO) << "PlanAlignment";
+ if (verbose_) DLOG(INFO) << "PlanAlignment";
@@
- DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: " << call->op;
+ if (verbose_) DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: " << call->op;

Also applies to: 648-649
src/op/parallel.cc (2)
Lines 251-276: Consider enhancing the error message for better debugging. The layout conflict detection is correctly implemented; however, the error message could be more informative.

Consider enhancing the error message to include the actual inner variable name:

- oss << "loop_var_to_thread = " << loop_var_to_thread
-     << "contains inner var" << *opt_var;
+ oss << "Layout inference conflict: thread expression '" << loop_var_to_thread
+     << "' contains non-parallel inner loop variable '" << opt_var->name_hint
+     << "'. Non-parallel loops cannot participate in thread-level layout inference.";
Lines 340-376: Consider controlling debug logging via environment variable. The extensive DLOG statements are helpful for debugging, but they could affect performance even in debug builds when not needed.

Consider adding an environment variable check to control these verbose logs:

static bool IsVerboseLoggingEnabled() {
  static bool enabled = std::getenv("TL_VERBOSE_LAYOUT_INFER") != nullptr;
  return enabled;
}

// Then use:
if (IsVerboseLoggingEnabled()) {
  DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
}

src/transform/layout_inference.cc (2)
Line 595: Consider improving the error message for better diagnostics. The ICHECK error message could be more informative about which component failed.

- ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
+ ICHECK(min_reg_num < INT64_MAX)
+     << "Failed to find a valid layout for component with root " << root
+     << ". All " << members.size() << " candidate roots resulted in conflicts or exceptions.";
Lines 720-721: The FIXME indicates incomplete implementation. The comment suggests that distinguishing between in-Parallel and out-of-Parallel locals is not yet implemented. This could lead to incorrect behavior in certain edge cases.
The simplified parallelization logic needs verification that it correctly handles all local buffer scenarios. Can you confirm that the current implementation correctly handles:
- Local buffers declared inside T.Parallel loops
- Local buffers declared outside but used inside T.Parallel loops
- The interaction with the new T.serial support
Additionally, should this FIXME be tracked as a GitHub issue for proper resolution?
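For context, here is a minimal sketch (hypothetical, not taken from this PR's tests) of the kernel shape in question, assuming the usual TileLang idioms (T.Kernel, T.alloc_fragment, T.alloc_local, T.Parallel, T.serial); exact signatures may differ:

```python
import tilelang.language as T


@T.prim_func
def scale_rows(A: T.Tensor((128, 64), "float32"), B: T.Tensor((128, 64), "float32")):
    # Hypothetical kernel: a serial loop and a register-local scratch buffer
    # nested inside a T.Parallel loop, which is the pattern this PR enables.
    with T.Kernel(1, threads=128) as _:
        A_frag = T.alloc_fragment((128, 64), "float32")
        T.copy(A, A_frag)
        for i, j in T.Parallel(128, 64):
            acc = T.alloc_local((1,), "float32")  # local buffer inside T.Parallel
            acc[0] = 0.0
            for k in T.serial(4):                 # serial loop inside T.Parallel
                acc[0] += A_frag[i, j] * 0.25
            A_frag[i, j] = acc[0]
        T.copy(A_frag, B)
```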
src/op/gemm.cc (1)
Lines 99-101: Consider extracting the complex condition for readability. The WGMMA enablement condition combines multiple checks which could be clearer if separated.

- bool allow_wgmma =
-     !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
-     TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0);
+ bool wgmma_disabled = ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false);
+ bool hardware_supports_wgmma = TargetIsHopper(target);
+ bool dimensions_suitable = (this->M >= 64) && (num_warps % 4 == 0);
+ bool allow_wgmma = !wgmma_disabled && hardware_supports_wgmma && dimensions_suitable;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- .clang-tidy (1 hunks)
- src/op/builtin.cc (1 hunks)
- src/op/builtin.h (1 hunks)
- src/op/gemm.cc (1 hunks)
- src/op/gemm.h (1 hunks)
- src/op/gemm_py.h (1 hunks)
- src/op/parallel.cc (4 hunks)
- src/op/parallel.h (2 hunks)
- src/transform/layout_inference.cc (8 hunks)
- src/transform/layout_reducer.cc (1 hunks)
- src/transform/merge_shared_memory_allocations.cc (1 hunks)
- src/transform/storage_rewrite.cc (1 hunks)
- tilelang/transform/pass_config.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
- src/transform/layout_reducer.cc (1) references src/transform/layout_inference.cc (12): op (42-48), op (42-42), op (317-377), op (317-317), op (401-425), op (401-401), op (427-444), op (427-427), op (446-455), op (446-446), op (641-653), op (641-641)
- src/op/parallel.cc (4) references src/transform/layout_inference.cc (18): op (42-48), op (42-42), op (317-377), op (317-317), op (401-425), op (401-401), op (427-444), op (427-427), op (446-455), op (446-446), op (641-653), op (641-641), op (690-763), op (690-690), op (765-774), op (765-765), buffer (393-399), buffer (393-393); src/op/parallel.h (1): LayoutConflictException (27-34); src/layout/layout.cc (2): Fragment (274-296), Fragment (298-308); src/transform/loop_partition.cc (4): PlanLoopPartition (187-191), PlanLoopPartition (187-188), PlanLoopPartition (193-199), PlanLoopPartition (193-194)
- src/op/gemm.cc (1) references src/target/utils.cc (4): TargetGetWarpSize (114-119), TargetGetWarpSize (114-114), TargetIsHopper (49-54), TargetIsHopper (49-49)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: format-check
- GitHub Check: bot-task
- GitHub Check: format-check
🔇 Additional comments (10)
src/op/builtin.h (1)

Line 47: kDisableWGMMA key added — verify registration and usage. No repo matches for registration or read of "tl.disable_wgmma"; ensure builtin.cc registers it (TVM_REGISTER_PASS_CONFIG_OPTION) and gemm.cc reads it (GetConfig/PassContext lookup).
src/op/gemm_py.h (1)

Line 110: Enum change — audit for removed kUTCMMA usages (Py node). src/op/gemm_py.h:110 removed kUTCMMA from GemmInst; repository grep for '\bkUTCMMA\b' returned no hits — verify the Python-binding node mirrors the C++ check and that generated artifacts/tests contain no remaining kUTCMMA references.
tilelang/transform/pass_config.py (1)

Lines 40-42: TL_DISABLE_WGMMA added — C++ constant matched; docs/examples missing.
- Confirmed: TL_DISABLE_WGMMA = "tl.disable_wgmma" matches C++ kDisableWGMMA (src/op/builtin.h:47).
- Repo search for the exact string 'tl.disable_wgmma' returned only tilelang/transform/pass_config.py:40 and src/op/builtin.h:47 — no docs/examples reference found. Add documentation or example usage for this PassConfig key or confirm it's intentionally internal-only.
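If docs are added, a minimal usage sketch could look like the following (hedged: the build entry point is elided; only the key string "tl.disable_wgmma" is confirmed by this PR):

```python
import tvm

# "tl.disable_wgmma" is the key registered in src/op/builtin.cc and mirrored as
# TL_DISABLE_WGMMA in tilelang/transform/pass_config.py.
with tvm.transform.PassContext(config={"tl.disable_wgmma": True}):
    ...  # lower/compile the TileLang program here; Hopper GEMMs then skip WGMMA
```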
src/op/gemm.h (1)

Line 189: No remaining kUTCMMA references — verify switch handling in GetGemmInst callers. rg found no occurrences of kUTCMMA; GetGemmInst is defined/used at:
- src/op/gemm.h:189
- src/op/gemm.cc:94 (def), 409, 472 (call sites)
- src/op/gemm_py.h:111
- src/op/gemm_py.cc:95 (def), 223 (call site)
Inspect the switch/case blocks in src/op/gemm.cc and src/op/gemm_py.cc for any fallthroughs or logic that assumed the removed enum value and update them if needed.
src/transform/layout_reducer.cc (1)

Lines 181-182: LGTM! The restriction is appropriate for reducer layout annotation. The additional check op->kind == ForKind::kParallel correctly restricts reducer layout annotation to parallel loops only. This aligns with the PR's goal of supporting T.serial inside T.Parallel loops while maintaining proper reducer semantics.

src/op/parallel.cc (1)
Lines 131-137: Good separation of parallel vs inner loop tracking. The distinction between loop_vars_ (for kParallel) and inner_vars_ (for non-parallel) appropriately tracks different loop types. The use of kDataPar vs kOrdered iteration types correctly reflects their different scheduling semantics.

src/transform/layout_inference.cc (2)
Lines 108-116: Debug logging looks good but ensure consistent formatting. The debug output is helpful for tracing layout inference steps.
Lines 534-593: Excellent implementation of per-component root selection! The new approach of trying each member as the root and selecting the best based on register usage is a significant improvement. The exception handling for LayoutConflictException and NormalizeIterException prevents crashes while maintaining robustness.
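For readers, the strategy roughly corresponds to the following illustrative Python pseudocode (all names are invented stand-ins; the real logic is the C++ InferInFreeMode in layout_inference.cc):

```python
class LayoutConflictError(Exception):
    """Stand-in for the C++ LayoutConflictException."""


class NormalizeIterError(Exception):
    """Stand-in for the C++ NormalizeIterException."""


def infer_component(members, state, run_inference, estimate_reg_usage):
    """Try every member as the inference root, keep the plan with the lowest
    register usage, and restore state between attempts (illustrative only)."""
    best_plan, best_regs = None, float("inf")
    for root in members:
        backup = state.snapshot()                  # back up layout maps/lists
        try:
            plan = run_inference(root)             # may raise a layout conflict
        except (LayoutConflictError, NormalizeIterError):
            state.restore(backup)                  # discard the failed attempt
            continue
        regs = estimate_reg_usage(plan)
        if regs < best_regs:
            best_plan, best_regs = plan, regs      # keep the cheapest valid plan
        state.restore(backup)                      # always reset before next root
    if best_plan is None:
        raise RuntimeError("no available layout found")
    state.apply(best_plan)                         # commit the winner
```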
src/op/builtin.cc (1)
Line 30: Config option properly registered. The kDisableWGMMA pass config option is correctly registered between related options.
src/op/gemm.cc (1)
Lines 95-101: WGMMA disabling logic is correctly implemented. The configuration-driven approach using PassContext is appropriate and the default behavior (enabled when config absent) maintains backward compatibility.
Code Review
This pull request introduces support for T.serial and local buffers within T.Parallel loops, which is a significant enhancement. The changes involve updates to layout inference, loop visitors, and GEMM operator logic. My review identified a critical issue in the GEMM operator where a correctness check for WGMMA seems to have been unintentionally removed. I also have a question regarding a FIXME in the layout inference logic related to handling local buffers. Other changes, including extensive debug logging and refactoring, look good.

  bool allow_wgmma =
      !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
      TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0);
The call to CheckWGMMA() has been removed from the condition for enabling WGMMA. This function performs crucial checks for data types, K-dimension divisibility, and transpose flags, which are required for the correctness of WGMMA on the Hopper architecture. Without these checks, the compiler might generate WGMMA instructions for unsupported configurations, potentially leading to compilation errors or runtime failures.

Suggested change:

  bool allow_wgmma =
      !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
-     TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0);
+     TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && CheckWGMMA();
  auto loop_layout = result_.for_map[root];
- bool parallel_loop = !is_register_store && !skip_thread_partition_;
+ // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
+ bool parallel_loop = !skip_thread_partition_;
The condition for parallel_loop no longer includes !is_register_store. This will cause loops that only store to register-local buffers ("local" scope) to be partitioned across threads. While this seems intended to support local buffers in T.Parallel, the FIXME on the preceding line indicates this may not be a complete solution.

Could you elaborate on the implications of this change? Forcing partitioning on loops that only use thread-local registers might be unnecessary if they are not shared. It would be helpful to understand the plan to address the FIXME and correctly differentiate between local buffers that require partitioning and those that do not.
…dling in casting example
…hance buffer store wrapper with detailed comments
…ble_parallel_0918
Actionable comments posted: 3
🧹 Nitpick comments (3)
tilelang/transform/add_bufstore_wrapper.py (3)
Line 19: Consider removing unused parameters or documenting their purpose. The function arguments mod and ctx are not used in the implementation. These are likely part of the TVM pass interface requirements.

If these parameters are required by the prim_func_pass interface, consider adding a comment to clarify:

 def pass_fn(func: PrimFunc, mod, ctx):
+    # mod and ctx are required by prim_func_pass interface but unused in this transformation

Alternatively, if they're truly unused, you could use underscore prefixes to indicate they're intentionally unused:

-def pass_fn(func: PrimFunc, mod, ctx):
+def pass_fn(func: PrimFunc, _mod, _ctx):
Lines 71-88: Potential issue: buffer indices mapping may be overwritten. The collect_buffer_indices function maps each buffer to its indices, but if a buffer is accessed multiple times with different indices in the same statement, only the last access will be recorded. This could miss validation of earlier accesses.

Consider collecting all unique indices for each buffer:

 def collect_buffer_indices(statement) -> dict[Buffer, list[int]]:
     """
     Maps each buffer to its access indices.

     Args:
         statement: The TIR statement to analyze

     Returns:
-        Dictionary mapping buffers to their access indices
+        Dictionary mapping buffers to all their access indices
     """
-    buffer_to_indices = {}
+    from collections import defaultdict
+    buffer_to_indices = defaultdict(list)

     def visit_buffer_access(node):
         if isinstance(node, (BufferLoad, BufferStore)):
-            buffer_to_indices[node.buffer] = node.indices
+            if node.indices not in buffer_to_indices[node.buffer]:
+                buffer_to_indices[node.buffer].append(node.indices)

     post_order_visit(statement, visit_buffer_access)
-    return buffer_to_indices
+    return dict(buffer_to_indices)
Lines 133-141: Fragment buffer validation needs adjustment for multiple accesses. The current validation logic checks fragment buffer indices, but due to the issue in collect_buffer_indices, it might only validate the last access if a buffer is accessed multiple times.

After fixing the collect_buffer_indices function to collect all accesses, update the validation:

 # Validate fragment buffer indices - only index 0 is supported
 buffer_indices = collect_buffer_indices(statement)
 for buffer, indices in buffer_indices.items():
     if buffer.scope() == "local.fragment":
-        for index in indices:
-            if isinstance(index, IntImm) and index != 0:
-                raise ValueError(
-                    f"Fragment buffer access with non-zero index [{index}] is not supported. "
-                    "Only fragment[0] access is allowed.")
+        for indices_list in indices:  # Now a list of index lists
+            for index in indices_list:
+                if isinstance(index, IntImm) and index.value != 0:
+                    raise ValueError(
+                        f"Fragment buffer '{buffer.name}' access with non-zero index [{index.value}] is not supported. "
+                        "Only fragment[0] access is allowed.")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (1 hunks)
- src/op/builtin.cc (1 hunks)
- src/op/builtin.h (1 hunks)
- src/op/parallel.h (2 hunks)
- tilelang/transform/add_bufstore_wrapper.py (1 hunks)
- tilelang/transform/pass_config.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- src/op/builtin.h
- tilelang/transform/pass_config.py
- examples/cast/example_group_per_split_token_cast_to_fp8.py
- src/op/parallel.h
🧰 Additional context used
🧬 Code graph analysis (1)
- tilelang/transform/add_bufstore_wrapper.py (1) references src/op/parallel.h (1): tvm (22-161)
🪛 Ruff (0.12.2)
tilelang/transform/add_bufstore_wrapper.py
- 19-19: Unused function argument: mod (ARG001)
- 19-19: Unused function argument: ctx (ARG001)
- 139-141: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (6)
src/op/builtin.cc (1)
Line 31: LGTM — pass-config key registration verified. C++ registration, header key, Python export, and GEMM gating via PassContext are consistent.
- src/op/builtin.cc:31 — TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
- src/op/builtin.h:48 — static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
- tilelang/transform/pass_config.py:48 — TL_DISABLE_WGMMA = "tl.disable_wgmma"
- src/op/gemm.cc:100 — ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) used to gate WGMMA
Confirm this change belongs in the "Flexible Parallel" PR or split it into a separate PR.
tilelang/transform/add_bufstore_wrapper.py (5)
Line 1: LGTM! Import additions align with new functionality. The added imports (BufferLoad, Buffer, IntImm) are appropriately used in the new helper functions for buffer analysis and index validation.

Lines 7-17: Well-documented transformation purpose. The docstring clearly explains the transformation's purpose and constraints, providing good documentation for maintainers.

Lines 25-42: Efficient variable collection implementation. The get_used_variables function correctly collects all Var nodes using post-order traversal.

Lines 90-100: Clear tile operation detection logic. The function correctly identifies parallel loops and loops with num_stages annotation as tile operations.

Lines 102-111: Proper tracking of thread bindings and tile depth. The pre-visit function correctly tracks thread-bound variables and increments the tile operation depth counter.
def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]:
    """
    Categorizes buffers accessed in the statement by their scope.

    Args:
        statement: The TIR statement to analyze

    Returns:
        Tuple of (local_buffers, fragment_buffers)
    """
    accessed_buffers = set()

    def visit_buffer_access(node):
        if isinstance(node, (BufferLoad, BufferStore)):
            accessed_buffers.add(node.buffer)

    post_order_visit(statement, visit_buffer_access)

    local_buffers = []
    fragment_buffers = []
    for buffer in accessed_buffers:
        if buffer.scope() == "local.fragment":
            fragment_buffers.append(buffer)
        elif buffer.scope().startswith("local"):
            local_buffers.append(buffer)
    return local_buffers, fragment_buffers
🛠️ Refactor suggestion
🧩 Analysis chain
Consider handling additional local buffer scopes.

The function correctly identifies fragment and local buffers, but only checks for the "local" prefix for non-fragment buffers. Depending on your buffer scope naming conventions, you might miss other local scopes.
Let me verify the buffer scope naming conventions in the codebase:
🏁 Script executed:
#!/bin/bash
# Search for buffer scope patterns in the codebase
rg -n --type=py 'scope\(\)' -A 2 -B 2 | head -30
# Look for buffer scope string patterns
rg -n --type=py '"local\.' | head -20
# Check for other scope patterns
rg -n --type=py 'buffer\.scope\(\)' -A 3 -B 3 | head -30
Length of output: 5680
Use startswith for fragment scopes and match general local scopes via the 'local.' prefix.
utils/language.py uses startswith("local.fragment") and allocate.py defines "local.var" — replace strict equality and tighten the general local check.
- Replace buffer.scope() == "local.fragment" with buffer.scope().startswith("local.fragment") (tilelang/transform/add_bufstore_wrapper.py — occurrences around lines 65 and 136).
- Change elif buffer.scope().startswith("local") to elif buffer.scope().startswith("local.") (tilelang/transform/add_bufstore_wrapper.py — around line 67).
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 44 to 69 (and also
update the other occurrence near line 136), the buffer scope checks are too
strict: replace buffer.scope() == "local.fragment" with
buffer.scope().startswith("local.fragment") to catch sub-scopes, and change elif
buffer.scope().startswith("local") to elif buffer.scope().startswith("local.")
to avoid matching unrelated scopes; apply these two replacements at both
locations mentioned.
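A small sketch of the tightened checks (variable names follow the code quoted earlier in this review; a bare "local" scope, if it ever occurs, would need its own branch):

```python
def categorize(accessed_buffers):
    # Sketch of the suggested scope checks; `accessed_buffers` is any iterable
    # of TIR Buffer objects collected as in collect_buffer_accesses above.
    local_buffers, fragment_buffers = [], []
    for buffer in accessed_buffers:
        scope = buffer.scope()
        if scope.startswith("local.fragment"):   # also catches fragment sub-scopes
            fragment_buffers.append(buffer)
        elif scope.startswith("local."):         # e.g. "local.var"; avoids false matches
            local_buffers.append(buffer)
    return local_buffers, fragment_buffers
```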
 for buffer, indices in buffer_indices.items():
     if buffer.scope() == "local.fragment":
         for index in indices:
             if isinstance(index, IntImm) and index != 0:
IntImm comparison might be incorrect.

When comparing IntImm objects, you should compare their values, not the objects directly. The comparison index != 0 might not work as intended.

Fix the comparison to use the value attribute:
-if isinstance(index, IntImm) and index != 0:
+if isinstance(index, IntImm) and index.value != 0:
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- if isinstance(index, IntImm) and index != 0:
+ if isinstance(index, IntImm) and index.value != 0:
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around line 138, the code checks
`if isinstance(index, IntImm) and index != 0:` which compares the IntImm object
to an int; change the comparison to use the IntImm's numeric field (e.g.,
`index.value != 0`) so you compare the underlying integer value rather than the
object.
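As a standalone illustration (hedged: the exact behavior of `!=` on an IntImm depends on TVM's deferred comparison operators, so comparing `.value` is the unambiguous form):

```python
from tvm import tir

idx = tir.IntImm("int32", 0)
# Comparing the underlying integer value is unambiguous, unlike `idx != 0`,
# which goes through TVM's expression-building comparison operators.
if isinstance(idx, tir.IntImm) and idx.value != 0:
    raise ValueError(f"non-zero fragment index [{idx.value}] is not supported")
print("index 0 accepted")  # reached, since idx.value == 0
```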
 # Wrap fragment[0] access with T.Parallel loop
 return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement)
🧩 Analysis chain
Verify parallel loop wrapper semantics.
The transformation wraps qualifying buffer stores with a single-iteration parallel loop. Please verify that this achieves the intended semantics for supporting T.serial and local buffers inside T.Parallel loops.
Let me check how this transformation interacts with the broader parallel loop handling:
🏁 Script executed:
#!/bin/bash
# Search for related parallel loop handling
rg -n --type=py "ForKind\.PARALLEL" -B 3 -A 3
# Look for T.serial usage patterns
rg -n --type=py "T\.serial|ForKind\.SERIAL" -B 2 -A 2
# Check for tests or examples of this transformation
fd -e py test | xargs rg -l "AddWrapperForSingleBufStore"
Length of output: 75573
🌐 Web query:
What is the purpose of wrapping a single statement with a For loop of extent 1 and ForKind.PARALLEL in TVM TIR?
💡 Result:
Wrapping a single statement in a For(extent=1, kind=ForKind.PARALLEL) is not for changing correctness but to give TIR a parallel loop node so downstream passes / codegen can treat that statement as parallelizable. Concretely it:
- creates an explicit loop AST node (so scheduling passes can target/annotate/transform it),
- marks the region as parallel (hints to scheduling, vectorize/parallel transforms and codegen/thread-mapping), and
- preserves semantics (extent=1 keeps identical execution) while enabling passes that require a For node with ForKind.PARALLEL. (tvm.apache.org)
If you want examples or how a specific pass (e.g., sch.parallel / ConvertForLoopsToSerial / loop-partition) interacts with such loops, tell me which pass and I’ll show the IR changes.
🏁 Script executed:
#!/bin/bash
set -euo pipefail
nl -ba tilelang/transform/add_bufstore_wrapper.py | sed -n '1,240p'
Length of output: 40
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== add_bufstore_wrapper.py (first 300 lines) ==="
sed -n '1,300p' tilelang/transform/add_bufstore_wrapper.py || { echo "file read failed"; exit 0; }
echo
echo "=== Matches in repo: fragment[0], ForKind.PARALLEL, alloc_local ==="
rg -n --hidden -S "fragment\[0\]|ForKind\.PARALLEL|alloc_local\(" || true
echo
echo "=== Show surrounding function for the wrapper (if found) ==="
rg -n --hidden -S "Wrap fragment|fragment\\[0\\]" tilelang/transform/add_bufstore_wrapper.py -n -C3 || true
Length of output: 31559
Fix parallel-wrapper semantics in add_bufstore_wrapper.py
- The single-iteration For(…, ForKind.PARALLEL, …) is fine to expose a parallel node, but there are correctness issues that must be fixed before merge.
- Index validation is incomplete — tilelang/transform/add_bufstore_wrapper.py:138-141 only rejects non-zero IntImm but will accept non-IntImm (dynamic) indices; require every fragment index to be IntImm(0) and raise otherwise.
- Wrapping only the BufferStore (tilelang/transform/add_bufstore_wrapper.py:143-144) can leave the fragment's alloc_local outside the parallel region; ensure fragment allocations and related loads/uses are placed inside the created parallel For (or wrap allocs/loads as well) so per-thread/local semantics are preserved.
- Verify thread-binding detection: pre_visit uses statement.node.var for AttrStmt('thread_extent') (tilelang/transform/add_bufstore_wrapper.py:100); confirm this reliably yields the bound Var or extract the correct Var to avoid missed detections or attribute errors.
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 100 and 138-144, fix
three issues: (1) strengthen index validation by requiring every fragment index
to be an IntImm(0) and raise an explicit error if any index is not an IntImm(0)
(don’t accept dynamic indices), (2) expand the parallel wrapper so that the
created For(…, ForKind.PARALLEL, …) encloses not only the BufferStore but also
the fragment's alloc_local and any loads/uses tied to that fragment so
allocations and per-thread/local accesses live inside the parallel region, and
(3) make thread-binding detection robust in pre_visit by extracting the actual
bound Var from the AttrStmt('thread_extent') node reliably (don’t assume
statement.node.var blindly; extract the correct Var field or attribute to avoid
missed detections or attribute errors).
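For point (3), one hedged way to make the extraction defensive, assuming the standard tvm.tir.AttrStmt/IterVar structure:

```python
from tvm import tir


def extract_thread_var(stmt):
    """Return the Var bound by a 'thread_extent' AttrStmt, or None if absent."""
    if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "thread_extent":
        node = stmt.node
        if isinstance(node, tir.IterVar):  # typical case: node is an IterVar
            return node.var
    return None
```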
as title, thanks @huanqi .
/gemini summary.