-
Notifications
You must be signed in to change notification settings - Fork 156
[TileOp] Implement WGMMA for T.gemm_v2 #813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
ec26c23
2ff5cbf
0166a90
ce83ace
72e900d
22131e7
6632a70
eac5433
51fcf15
ce9f545
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,7 @@ using namespace tir; | |
* @param vmap Mapping from access pointer vars to Buffer objects used to | ||
* resolve the Buffer corresponding to each pointer argument. | ||
* | ||
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor | ||
* @note If `kPack` is provided it must be 1; otherwise the constructor | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
* fails with an ICHECK (runtime assertion). No other validation is | ||
* performed here. | ||
*/ | ||
|
@@ -478,7 +478,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
int dim_A = A->shape.size(); | ||
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]), | ||
*as_const_int(A->shape[dim_A - 1]), | ||
true, trans_A ? 1 : 2)); | ||
true, !trans_A)); | ||
} else if (A.scope() == "local.fragment") { | ||
ICHECK(trans_A == false); | ||
auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n); | ||
|
@@ -491,7 +491,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
int dim_B = B->shape.size(); | ||
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]), | ||
*as_const_int(B->shape[dim_B - 1]), | ||
false, trans_B ? 2 : 1)); | ||
false, trans_B)); | ||
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) || | ||
TargetIsSM120(T.target)) { | ||
auto fragment = | ||
|
@@ -504,7 +504,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]); | ||
results.Set(A, | ||
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
A->dtype.bits(), trans_A ? 1 : 2)); | ||
A->dtype.bits(), !trans_A)); | ||
} else if (A.scope() == "local.fragment") { | ||
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, | ||
A->dtype.bits(), trans_A); | ||
|
@@ -518,7 +518,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); | ||
results.Set(B, | ||
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
B->dtype.bits(), trans_B ? 2 : 1)); | ||
B->dtype.bits(), trans_B)); | ||
} else if (B.scope() == "local.fragment") { | ||
auto fragment = | ||
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); | ||
|
@@ -542,9 +542,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
auto ABLayout = | ||
gemm_inst == GemmInst::kWGMMA | ||
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, | ||
A->dtype.bits(), trans_A ? 1 : 2) | ||
A->dtype.bits(), !trans_A) | ||
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
A->dtype.bits(), trans_A ? 1 : 2); | ||
A->dtype.bits(), !trans_A); | ||
results.Set(A, ABLayout); | ||
} else { | ||
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, | ||
|
@@ -560,9 +560,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
auto ABLayout = | ||
gemm_inst == GemmInst::kWGMMA | ||
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, | ||
B->dtype.bits(), trans_B ? 2 : 1) | ||
B->dtype.bits(), trans_B) | ||
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
B->dtype.bits(), trans_B ? 2 : 1); | ||
B->dtype.bits(), trans_B); | ||
results.Set(B, ABLayout); | ||
} else { | ||
auto fragment = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,6 +106,8 @@ GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size, | |
return GemmInst::kMMA; | ||
} else { | ||
ICHECK(0) << "Unsupported target for gemm: " << target->str(); | ||
return GemmInst::kMMA; // This line will never be reached due to ICHECK, but | ||
// satisfies compiler | ||
} | ||
} | ||
|
||
|
@@ -225,8 +227,9 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | |
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); | ||
|
||
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { | ||
auto prim_func = Downcast<PrimFunc>( | ||
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var)); | ||
auto prim_func = | ||
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target, | ||
T.thread_bounds, T.thread_var)); | ||
ICHECK(prim_func->attrs.defined()); | ||
Comment on lines
+231
to
233
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
🧩 Analysis chain
FFI lower signature now includes layout_map; ensure all providers are updated.
The C++ call passes T.layout_map. Verify that the Python FFI target “tl.gemm_py.lower” accepts (gemm_py, layout_map, target, thread_bounds, thread_var) everywhere to avoid runtime errors.
🏁 Script executed:
#!/bin/bash
# Confirm the new signature at Python side
rg -nP 'def\s+lower\((self|gemm_py)[^)]*layout_map' -C2
# Find the FFI registration or dispatcher for "tl.gemm_py.lower"
rg -n 'gemm_py\.lower' -S
# Check base class signatures to avoid LSP mismatches
rg -nP 'class\s+GemmBase\b.*?\n' -n tilelang/tileop/gemm -C2
rg -nP 'def\s+lower\(' tilelang/tileop/gemm -S
Length of output: 2364
Add layout_map to GemmBase.lower
Change tilelang/tileop/gemm/gemm_base.py:18 from |
||
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol"); | ||
ICHECK(global_symbol.defined()); | ||
|
@@ -249,6 +252,8 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | |
/*name_hint=*/global_symbol.value(), prim_func->body)); | ||
} else { | ||
LOG(FATAL) << "No lower function found for gemm_py"; | ||
return Stmt(); // This line will never be reached due to LOG(FATAL), but | ||
// satisfies compiler | ||
} | ||
} | ||
|
||
|
@@ -275,5 +280,14 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py) | |
Integer(CallEffectKind::kOpaque)); | ||
|
||
TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); }); | ||
|
||
TVM_FFI_STATIC_INIT_BLOCK({ | ||
namespace refl = tvm::ffi::reflection; | ||
refl::GlobalDef().def("tl.GemmPyGemmInst", | ||
[](GemmPy gemm_py, int block_size, Target target) { | ||
return gemm_py->GetGemmInst(block_size, target); | ||
}); | ||
}); | ||
|
||
} // namespace tl | ||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Signature change to bool k_inner: audit all call sites
Converting `int kfactor` → `bool k_inner` changes semantics; old int literals (e.g., 2) now coerce to `true`. Please verify all callers pass a boolean with the intended meaning.
Run:
🏁 Script executed:
Length of output: 26
🏁 Script executed:
Length of output: 9413
Critical: callers passing integer k-factor (1/2) must be fixed — makeGemmABLayout/Hopper now takes bool k_inner
src/op/gemm_sp.cc passes integer ternaries (`trans_A ? 1 : 2` and `trans_B ? 2 : 1`) at src/op/gemm_sp.cc:196-198 and src/op/gemm_sp.cc:209-211; both non-zero values now convert to `true` and break the intended semantics. Change these calls to pass an explicit boolean with the intended meaning — the old k-factor 2 corresponds to `k_inner == true`, so use `!trans_A` for A and `trans_B` for B (equivalently `(expr) == 2`), mirroring the conversion applied in src/op/gemm.cc — or restore an int overload if the 1/2 semantics are required. Other callers (src/op/gemm.cc, src/layout/*.cc/h) were checked and pass booleans.