
Commit 29009f1

[GLUON] Allow TensorMemory layouts in to_linear_layout in the context of printing. (#8682)
1 parent 4b184cc commit 29009f1

File tree (5 files changed, +100 −47 lines):

  python/src/gluon_ir.cc
  python/test/gluon/test_frontend.py
  python/triton/experimental/gluon/language/_layouts.py
  python/triton/experimental/gluon/language/_semantic.py
  python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py
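In effect, the commit lets ttgl.to_linear_layout accept Blackwell TensorMemory layouts so their linear form can be inspected with ttgl.static_print. Below is a minimal sketch of the new usage, assuming a Triton build that includes this change; the kernel body is illustrative (modeled on the test added below), not code from the commit:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.blackwell import TensorMemoryLayout

@gluon.jit
def kernel():
    # TMEM layouts were previously rejected by to_linear_layout; they are now
    # converted and returned as a print-only linear-layout object.
    layout: ttgl.constexpr = TensorMemoryLayout((64, 64), col_stride=2)
    ll: ttgl.constexpr = ttgl.to_linear_layout(layout, (64, 64))
    ttgl.static_print(ll)  # shows rows=..., cols=..., shape=...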

python/src/gluon_ir.cc

Lines changed: 41 additions & 2 deletions
@@ -375,8 +375,47 @@ void init_gluon_ir(py::module &&m) {
              std::vector<int64_t> &shape) -> py::object {
             auto ctx = self.getContext();
             auto linearLayout = ttg::toLinearLayout(shape, layout);
-            auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
-            return layoutToGluon(attr);
+
+            if (isa<ttg::DistributedEncodingTrait>(layout)) {
+              auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
+              return layoutToGluon(attr);
+            }
+            if (isa<ttg::SharedEncodingTrait>(layout)) {
+              auto alignment =
+                  cast<ttg::SharedEncodingTrait>(layout).getAlignment();
+              auto attr = ttg::SharedLinearEncodingAttr::get(ctx, linearLayout,
+                                                             alignment);
+              return layoutToGluon(attr);
+            }
+
+            // TensorMemory encodings: keep the LinearLayout but wrap as
+            // print-only Python object carrying row/col bases -> dim0/dim1.
+            auto inNamesRange = linearLayout.getInDimNames();
+            auto inNames = llvm::to_vector(inNamesRange);
+            bool isTmemLayout =
+                (inNames.size() == 2 && inNames[0].str() == "row" &&
+                 inNames[1].str() == "col");
+            if (!isTmemLayout)
+              throw std::invalid_argument(
+                  "Unsupported layout in to_linear_layout");
+
+            // Build Py _TensorMemoryLinearLayout(row_bases, col_bases, shape,
+            // repr)
+            py::object tmemCls =
+                py::module::import(
+                    "triton.experimental.gluon.language.nvidia.blackwell")
+                    .attr("_TensorMemoryLinearLayout");
+            auto bases = linearLayout.getBases();
+            auto rowBases = bases[mlir::StringAttr::get(ctx, "row")];
+            auto colBases = bases[mlir::StringAttr::get(ctx, "col")];
+            auto outDims = linearLayout.getOutDims();
+            std::vector<int> shapeVec;
+            for (auto &od : outDims)
+              shapeVec.push_back(od.second);
+
+            py::object pyObj = tmemCls(py::cast(rowBases), py::cast(colBases),
+                                       py::cast(shapeVec));
+            return pyObj;
           })
       .def("get_dot_operand_layout",
            [](GluonOpBuilder &self, unsigned opIdx, Attribute parent,
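For reference, Triton's LinearLayout is linear over GF(2): an input index contributes the XOR of the bases at its set bits. The standalone sketch below (plain Python, not Triton API; the helper name is invented for illustration) shows how the row/col bases extracted above map a (row, col) TMEM coordinate to (dim0, dim1):

def apply_row_col_layout(row_bases, col_bases, row, col):
    """XOR together the bases selected by the set bits of `row` and `col`."""
    out = [0, 0]  # (dim0, dim1)
    for index, bases in ((row, row_bases), (col, col_bases)):
        for bit, basis in enumerate(bases):
            if (index >> bit) & 1:
                out = [o ^ b for o, b in zip(out, basis)]
    return out

# Identity-like bases for a 4x4 tile: row bit k -> dim0 bit k, col bit k -> dim1 bit k.
row_bases = [[1, 0], [2, 0]]
col_bases = [[0, 1], [0, 2]]
assert apply_row_col_layout(row_bases, col_bases, 3, 2) == [3, 2]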

python/test/gluon/test_frontend.py

Lines changed: 17 additions & 36 deletions
@@ -1461,48 +1461,29 @@ def kernel(reg_type: ttgl.constexpr, shared_type: ttgl.constexpr, ref_conflicts:
 
 
 @pytest.mark.parametrize(
-    "layout, expected",
+    "layout, shape",
     [
-        (
-            ttgl.BlockedLayout([1], [4], [4], [0]),
-            ttgl.DistributedLinearLayout(
-                reg_bases=[],
-                lane_bases=[[1], [2]],
-                warp_bases=[[4], [8]],
-                block_bases=[],
-                shape=[16],
-            ),
-        ),
-        (
-            ttgl.BlockedLayout([1], [4], [4], [0], [[1], [0]]),
-            ttgl.DistributedLinearLayout(
-                reg_bases=[],
-                lane_bases=[[1], [2]],
-                warp_bases=[[4], [8]],
-                block_bases=[[16], [0]],
-                shape=[32],
-            ),
-        ),
-        (
-            ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [[0, 1]]),
-            ttgl.DistributedLinearLayout(
-                reg_bases=[[1, 0], [2, 0], [4, 0], [0, 16], [0, 32]],
-                lane_bases=[[8, 0], [16, 0], [32, 0], [0, 1], [0, 2]],
-                warp_bases=[[0, 4], [0, 8]],
-                block_bases=[[0, 64]],
-                shape=[64, 128],
-            ),
-        ),
+        (ttgl.BlockedLayout([1], [4], [4], [0]), [16]),
+        (ttgl.BlockedLayout([1], [4], [4], [0], [[1], [0]]), [32]),
+        (ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [[0, 1]]), [64, 128]),
+        (ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2), [64, 64]),
+        (TensorMemoryLayout((64, 64), col_stride=2), [64, 64]),
     ],
 )
-def test_to_linear_layout(layout, expected):
+def test_to_linear_layout(layout, shape, capsys):
 
     @gluon.jit
-    def kernel(layout: ttgl.constexpr, expected: ttgl.constexpr, shape: ttgl.constexpr):
+    def kernel(layout: ttgl.constexpr, shape: ttgl.constexpr):
         computed: ttgl.constexpr = ttgl.to_linear_layout(layout, shape)
-        ttgl.static_assert(computed == expected)
-
-    run_parser(kernel, args=(layout, expected, tuple(expected.shape)), target=AMPERE_TARGET)
+        ttgl.static_print(computed)
+
+    run_parser(kernel, args=(layout, tuple(shape)), target=AMPERE_TARGET)
+    out = capsys.readouterr().out
+    if isinstance(layout, TensorMemoryLayout):
+        assert "rows=" in out
+        assert "cols=" in out
+    else:
+        assert "DistributedLinearLayout" in out or "SharedLinearLayout" in out
 
 
 @filecheck_test

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 10 additions & 0 deletions
@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+import itertools
 from typing import List
 
 from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type
@@ -636,6 +637,15 @@ def _to_ir(self, builder):
     def mangle(self) -> str:
         return f"SharedLinear_{self.offset_bases}_{self.block_bases}_{self.alignment}_SharedLinear"
 
+    @property
+    def shape(self):
+        rank = len(self.offset_bases[0])
+        max_stride = [1] * rank
+        for b in itertools.chain(self.offset_bases, self.block_bases):
+            for i, bi in enumerate(b):
+                max_stride[i] = max(max_stride[i], bi)
+        return [2 * s for s in max_stride]
+
     def __hash__(self):
         return hash((
             tuple(map(tuple, self.offset_bases)),
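As a worked example of the new shape property: for the usual power-of-two bases, the largest basis component seen in a dimension is half that dimension's extent, so doubling it recovers the shape. A standalone sketch with the same logic as the diff above:

import itertools

def infer_shape(offset_bases, block_bases):
    # Mirror of SharedLinearLayout.shape: largest stride per dimension
    # over all bases, then doubled.
    rank = len(offset_bases[0])
    max_stride = [1] * rank
    for b in itertools.chain(offset_bases, block_bases):
        for i, bi in enumerate(b):
            max_stride[i] = max(max_stride[i], bi)
    return [2 * s for s in max_stride]

# Bases reaching offsets 1, 2, 4 along dim0 and 1, 2 along dim1:
assert infer_shape([[1, 0], [2, 0], [4, 0], [0, 1], [0, 2]], []) == [8, 4]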

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 10 additions & 9 deletions
@@ -2,7 +2,7 @@
 import math
 from triton.language.semantic import TritonSemantic
 from . import _core as ttgl
-from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout
+from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout, SharedLinearLayout
 from triton._C.libtriton.gluon_ir import GluonOpBuilder, compute_tmem_reg_layout
 from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values
 
@@ -301,15 +301,16 @@ def bank_conflicts(self, distr_ty, shared_ty):
                distr_ty.element_ty.primitive_bitwidth)
 
     def to_linear_layout(self, layout, shape):
-        _check(isinstance(layout, (DistributedLayout, SharedLayout)),
-               lambda: f"Expected a DistributedLayout or SharedLayout, got {type(layout)}")
-
-        if not isinstance(shape, list):
-            shape = list(shape)
-
-        layout = ttgl._unwrap_if_constexpr(layout)
+        from triton.experimental.gluon.language.nvidia.blackwell import (
+            TensorMemoryLayout,
+            TensorMemoryScalesLayout,
+        )
+        _check(
+            isinstance(layout, (DistributedLayout, SharedLayout, TensorMemoryLayout, TensorMemoryScalesLayout)), lambda:
+            f"Expected a DistributedLayout, SharedLayout, or TensorMemoryLayout or TensorMemoryScalesLayout, got {type(layout)}"
+        )
 
-        if isinstance(layout, (AutoLayout, DistributedLinearLayout)):
+        if isinstance(layout, (AutoLayout, DistributedLinearLayout, SharedLinearLayout)):
            return ttgl.constexpr(layout)
 
         return ttgl.constexpr(self.builder.to_linear_layout(layout._to_ir(self.builder), shape))

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 22 additions & 0 deletions
@@ -2,6 +2,7 @@
 from typing import Optional, Tuple, List, TYPE_CHECKING
 
 from dataclasses import dataclass
+import itertools
 from triton.runtime.jit import constexpr_function
 from triton.experimental.gluon.language import _core as ttgl
 from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
@@ -26,7 +27,9 @@
     "mma_v2",
     "tensor_memory_descriptor",
     "TensorMemoryLayout",
+    "TensorMemoryScalesLayout",
     "tma",
+    "_TensorMemoryLinearLayout",
 ]
 
 
@@ -104,6 +107,25 @@ def __hash__(self):
         return hash(self.cta_split_num)
 
 
+@dataclass(frozen=True)
+class _TensorMemoryLinearLayout:
+    """
+    Print-only linear layout for TMEM (row/col -> dim0/dim1).
+    """
+    rows: List[List[int]]
+    cols: List[List[int]]
+    shape: List[int]
+
+    def _to_ir(self, builder):
+        raise RuntimeError("TensorMemoryLinearLayout is print-only; IR materialization is unsupported")
+
+    def mangle(self):
+        return f"TMLL_{self.shape}_TMLL"
+
+    def __hash__(self):
+        return hash((tuple(map(tuple, self.rows)), tuple(map(tuple, self.cols)), tuple(self.shape)))
+
+
 @constexpr_function
 def get_tmem_reg_layout(
     element_ty,
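What static_print ends up showing for a TMEM layout is essentially this frozen dataclass's default repr, which is what the test above checks for via the "rows=" and "cols=" substrings. A quick sketch, assuming a Triton build that includes this change (the basis values are made up):

from triton.experimental.gluon.language.nvidia.blackwell import _TensorMemoryLinearLayout

ll = _TensorMemoryLinearLayout(rows=[[1, 0], [2, 0]], cols=[[0, 1], [0, 2]], shape=[4, 4])
print(ll)
# _TensorMemoryLinearLayout(rows=[[1, 0], [2, 0]], cols=[[0, 1], [0, 2]], shape=[4, 4])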
