Skip to content
Open
34 changes: 17 additions & 17 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
ICHECK(block_m % warp_m == 0);
// ICHECK(block_n == warp_n);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;

auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
false); // 16 x N (1 warp)
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
Expand Down Expand Up @@ -530,8 +530,8 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
}

Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor) {
if (kfactor == 2)
bool k_inner) {
if (k_inner)
return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0)
return MakeGemmVoltaALayoutCongruous(stride, continuous);
Expand All @@ -558,29 +558,29 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
* select specific swizzling strategies. It might be the same as mat_continuous
* or different based on tiling or hardware details.
* \param element_size The size of each element in the matrix, in bits (e.g., 8,
* 16, 32, 64). \param kfactor An integer factor that influences layout
* 16, 32, 64). \param k_inner Whether the K dimension is in the inner loop.
* selection, particularly for fp64 and int8 types. It often relates to how the
* K dimension of the GEMM (M x K * K x N) is handled or tiled.
* - For fp64 (element_size == 64):
* - kfactor == 1 often implies K is in the "outer" loop (e.g.,
* KxN matrix).
* - kfactor == 2 often implies K is in the "inner" loop (e.g.,
* NxK matrix).
* - k_inner == false often implies K is in the "outer" loop
* (e.g., KxN matrix).
* - k_inner == true often implies K is in the "inner" loop
* (e.g., NxK matrix).
* - For int8 (element_size == 8):
* - kfactor == 1 uses a padded layout.
* - k_inner == false uses a padded layout.
* \return A Layout object representing the chosen memory layout.
*/
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor) {
int element_size, bool k_inner) {
if (element_size == 64) {
if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
if (!k_inner && continuity % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
if (k_inner && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (!k_inner && element_size == 8) // int8 KxN
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
Expand All @@ -592,17 +592,17 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
}

Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor) {
int continuity, int element_size, bool k_inner) {
if (element_size == 64) {
if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
if (!k_inner && continuity % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
if (k_inner && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (!k_inner && element_size == 8) // int8 KxN
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 8) == 0)
Expand Down
41 changes: 38 additions & 3 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void LayoutNode::RegisterReflection() {
}

void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
for (const auto &[var, dom] : getVarMap()) {
for (const auto &[var, dom] : LayoutNode::getVarMap()) {
analyzer->Bind(var, dom);
}
}
Expand Down Expand Up @@ -458,6 +458,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
[](Layout layout) { return layout->GetForwardIndex(); })
.def("tl.Layout_forward_vars",
[](Layout layout) { return layout->GetForwardVars(); })
.def("tl.Layout_is_equal",
[](Layout layout, Layout other) {
const LayoutNode *other_node = other.as<LayoutNode>();
return layout->IsEqual(other_node);
})
.def_packed("tl.Fragment",
[](PackedArgs args, Any *rv) {
*rv = Fragment(
Expand All @@ -466,6 +471,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
/*forward_thread=*/args[2].cast<PrimExpr>(),
/*thread_replicate=*/args[3].cast<IterVar>());
})
.def("tl.Fragment_is_equal",
[](Fragment fragment, Fragment other) {
const FragmentNode *other_node = other.as<FragmentNode>();
return fragment->IsEqual(other_node);
})
.def("tl.Fragment_thread_size",
[](Fragment fragment) { return fragment->ThreadExtent(); })
.def("tl.Fragment_thread",
Expand All @@ -483,9 +493,34 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.Fragment_condense_rep_var",
[](Fragment fragment) { return fragment->CondenseReplicateVar(); })
.def("tl.make_swizzled_layout",
[](int stride, int continuous, int element_size, bool k_inner,
bool allow_pad = true) {
if (allow_pad) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, k_inner);
} else {
return makeGemmABLayoutHopper(stride, continuous, continuous,
element_size, k_inner);
}
})
.def("tl.make_wgmma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size);
})
.def("tl.make_half_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeHalfBankSwizzleLayout(stride, continuous, element_size);
})
.def("tl.make_quarter_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, 0);
return makeQuarterBankSwizzleLayout(stride, continuous,
element_size);
});
});

Expand Down
9 changes: 5 additions & 4 deletions src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,12 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
int element_size, bool k_inner = true);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
int continuity, int element_size,
bool k_inner = true);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
Comment on lines +163 to 167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

🧩 Analysis chain

Signature change to bool k_inner: audit all call sites

Converting int kfactorbool k_inner changes semantics; old int literals (e.g., 2) now coerce to true. Please verify all callers pass a boolean with the intended meaning.

Run:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Show call sites with context
rg -nPU -C3 --type=cpp --type=cc --type=cxx --type=h --type=cu '\bmakeGemmABLayout\s*\('
rg -nPU -C3 --type=cpp --type=cc --type=cxx --type=h --type=cu '\bmakeGemmABLayoutHopper\s*\('

# Heuristic: calls with 5+ args where last arg is not clearly true/false or named
rg -nPU --type=cpp --type=cc --type=cxx --type=cu '\bmakeGemm(ABLayout|ABLayoutHopper)\s*\((?:[^,]*,){4}\s*([^,)]+)\)' -r '$1' | grep -v -E '^\s*(true|false)\s*$' || true

Length of output: 26


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "1) Listing raw matches (context):"
if command -v rg >/dev/null 2>&1; then
  rg -n -U -P '\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(' -C3 --hidden || true
else
  grep -nR --line-number -E '\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(' . || true
fi

echo
echo "2) Extracting full call expression and printing the last argument (best-effort):"
python - <<'PY'
import re,subprocess
pat=re.compile(r'\b(makeGemmABLayout(?:Hopper|CDNA)?)\s*\(')
# try rg first, fallback to grep
try:
    out=subprocess.check_output(['rg','-n','-U','-P',r'\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\('],stderr=subprocess.DEVNULL).decode()
except Exception:
    try:
        out=subprocess.check_output(['grep','-nR','-E',r'\bmakeGemmABLayout(?:Hopper|CDNA)?\s*\(','./'],stderr=subprocess.DEVNULL).decode()
    except Exception:
        out=''

for line in out.splitlines():
    parts=line.split(':',2)
    if len(parts)<3:
        continue
    path,lineno,content=parts[0],int(parts[1]),parts[2]
    try:
        with open(path,'r',encoding='utf-8') as f:
            data=f.read()
    except Exception:
        continue
    lines=data.splitlines()
    idx=max(0,lineno-1)
    window = '\n'.join(lines[idx:idx+200])
    m=pat.search(window)
    if not m:
        continue
    start_idx=m.end()-1
    depth=0
    end_pos=None
    for i,ch in enumerate(window[start_idx:], start=start_idx):
        if ch=='(':
            depth+=1
        elif ch==')':
            depth-=1
            if depth==0:
                end_pos=i
                break
    call_text = window[m.start(): end_pos+1] if end_pos is not None else window[m.start():m.start()+400]
    inside = call_text[call_text.find('(')+1: (call_text.rfind(')') if call_text.rfind(')')>call_text.find('(') else None )]
    inside = re.sub(r'/\*.*?\*/|//.*?$','',inside,flags=re.S|re.M)
    # split args by commas not inside nested constructs
    args=[]
    cur=''
    depth=0
    for ch in inside:
        if ch in '([{<':
            depth+=1
        elif ch in ')]}>':
            depth-=1
        if ch==',' and depth==0:
            args.append(cur.strip()); cur=''
        else:
            cur+=ch
    if cur.strip():
        args.append(cur.strip())
    last = args[-1] if args else ''
    # simplify last token
    last_token = last.strip().split()[-1] if last.strip() else ''
    print(f"{path}:{lineno}: func={m.group(1)} last_arg_raw='{last}' last_arg_token='{last_token}'")
PY

Length of output: 9413


Critical: callers passing integer k-factor (1/2) must be fixed — makeGemmABLayout/Hopper now takes bool k_inner

src/op/gemm_sp.cc passes integer ternaries (trans_A ? 1 : 2 and trans_B ? 2 : 1) at src/op/gemm_sp.cc:196-198 and src/op/gemm_sp.cc:209-211; both non-zero values now convert to true and break intended semantics. Change these calls to pass an explicit boolean with the intended meaning (e.g., use trans_A/trans_B or (expr)==1), or restore an int overload if 1/2 semantics are required. Other callers (src/op/gemm.cc, src/layout/*.cc/h) were checked and pass booleans.

int kfactor);
int kPack);

Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
Expand All @@ -173,7 +174,7 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor);
bool k_inner = true);

Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
Expand Down
20 changes: 20 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_wgmma_ss)
.set_num_inputs(15)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_num_inputs(15)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -161,5 +171,15 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
40 changes: 40 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,28 @@ TVM_DLL const Op &mbarrier_wait_parity();
*/
TVM_DLL const Op &mbarrier_expect_tx();

/*!
* \brief tvm intrinsic for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_ss(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
* scale_in_a, bool scale_in_b);
*/
TVM_DLL const Op &ptx_wgmma_ss();

/*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
* scale_in_a, bool scale_in_b);
*/
TVM_DLL const Op &ptx_wgmma_rs();

/*!
* \brief tvm intrinsics for ldmatrix
*
Expand Down Expand Up @@ -319,6 +341,24 @@ TVM_DLL const Op &tl_gemm_sp();
*/
TVM_DLL const Op &tl_shuffle_elect();

/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* wgmma/utcmma.
*
* This op is used to represent a descriptor initialization operation in
* tilelang.
*/
TVM_DLL const Op &initialize_descriptor();

/*!
* \brief tilelang intrinsic for setting the start address of a descriptor
* buffer for wgmma/utcmma.
*
* This op is used to represent a descriptor start address setting operation in
* tilelang.
*/
TVM_DLL const Op &increase_descriptor_offset();

} // namespace tl
} // namespace tvm

Expand Down
18 changes: 9 additions & 9 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using namespace tir;
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* @note If `kPack` is provided it must be 1; otherwise the constructor
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Gemm Constructor Documentation Mismatch

The Gemm constructor's documentation for kPack is incomplete and inconsistent with its validation logic. The comment states kPack must be 1, but the underlying validation permits 1 or 2.

Fix in Cursor Fix in Web

* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
Expand Down Expand Up @@ -478,7 +478,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
int dim_A = A->shape.size();
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
*as_const_int(A->shape[dim_A - 1]),
true, trans_A ? 1 : 2));
true, !trans_A));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
Expand All @@ -491,7 +491,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
int dim_B = B->shape.size();
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
*as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1));
false, trans_B));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
TargetIsSM120(T.target)) {
auto fragment =
Expand All @@ -504,7 +504,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2));
A->dtype.bits(), !trans_A));
} else if (A.scope() == "local.fragment") {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
Expand All @@ -518,7 +518,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1));
B->dtype.bits(), trans_B));
} else if (B.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
Expand All @@ -542,9 +542,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2)
A->dtype.bits(), !trans_A)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2);
A->dtype.bits(), !trans_A);
results.Set(A, ABLayout);
} else {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
Expand All @@ -560,9 +560,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1)
B->dtype.bits(), trans_B)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1);
B->dtype.bits(), trans_B);
results.Set(B, ABLayout);
} else {
auto fragment =
Expand Down
18 changes: 16 additions & 2 deletions src/op/gemm_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size,
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
return GemmInst::kMMA; // This line will never be reached due to ICHECK, but
// satisfies compiler
}
}

Expand Down Expand Up @@ -225,8 +227,9 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func = Downcast<PrimFunc>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
auto prim_func =
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
Comment on lines +231 to 233
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

🧩 Analysis chain

FFI lower signature now includes layout_map; ensure all providers updated.

The C++ call passes T.layout_map. Verify the Python FFI target “tl.gemm_py.lower” accepts (gemm_py, layout_map, target, thread_bounds, thread_var) everywhere to avoid runtime errors.


🏁 Script executed:

#!/bin/bash
# Confirm the new signature at Python side
rg -nP 'def\s+lower\((self|gemm_py)[^)]*layout_map' -C2

# Find the FFI registration or dispatcher for "tl.gemm_py.lower"
rg -n 'gemm_py\.lower' -S

# Check base class signatures to avoid LSP mismatches
rg -nP 'class\s+GemmBase\b.*?\n' -n tilelang/tileop/gemm -C2
rg -nP 'def\s+lower\(' tilelang/tileop/gemm -S

Length of output: 2364


Add layout_map to GemmBase.lower

Change tilelang/tileop/gemm/gemm_base.py:18 from
def lower(self, target, thread_nums, thread_var)
to
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var)
so the base signature matches the FFI tl.gemm_py.lower and existing subclass implementations.

auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.defined());
Expand All @@ -249,6 +252,8 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
/*name_hint=*/global_symbol.value(), prim_func->body));
} else {
LOG(FATAL) << "No lower function found for gemm_py";
return Stmt(); // This line will never be reached due to LOG(FATAL), but
// satisfies compiler
}
}

Expand All @@ -275,5 +280,14 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });

TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmPyGemmInst",
[](GemmPy gemm_py, int block_size, Target target) {
return gemm_py->GetGemmInst(block_size, target);
});
});

} // namespace tl
} // namespace tvm
2 changes: 1 addition & 1 deletion src/op/gemm_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ class GemmPyNode : public TileOperatorNode {

TileOperator Clone() const;

private:
// Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;

private:
mutable bool completed_ = false;
};

Expand Down
Loading