1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -280,6 +280,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
${PYTHON_SRC_PATH}/ir.cc
${PYTHON_SRC_PATH}/gluon_ir.cc
${PYTHON_SRC_PATH}/linear_layout.cc
${PYTHON_SRC_PATH}/passes.cc
${PYTHON_SRC_PATH}/interpreter.cc
${PYTHON_SRC_PATH}/llvm.cc
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -266,6 +266,12 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

// Return a string representation of the shared layout of the tensor.
std::string getSharedLayoutStr(LinearLayout &ll, bool useHWPointOfView);

// Return a string representation of the distributed layout of the tensor.
std::string getDistributedLayoutStr(LinearLayout &ll, bool useHWPointOfView);

template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);

2 changes: 2 additions & 0 deletions include/triton/Tools/LinearLayout.h
@@ -869,6 +869,8 @@ inline std::ostream &operator<<(std::ostream &os, const ColumnAction &action) {
return os;
}

std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout);

} // namespace mlir::triton

#endif // TRITON_TOOLS_LINEARLAYOUT_H
67 changes: 31 additions & 36 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -3307,20 +3307,17 @@ static std::string paddedString(int value, int max) {
return str;
}

std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
if (!type)
return "";

std::string mlir::triton::gpu::getSharedLayoutStr(LinearLayout &ll,
bool useHWPointOfView) {
// The shape and block count are recovered from the linear layout itself.
auto shape = type.getShape();
auto layout = type.getEncoding();
LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
auto outDimNames = llvm::to_vector(ll.getOutDimNames());
auto shape = convertType<int64_t>(llvm::to_vector(ll.getOutDimSizes()));
auto *ctx = outDimNames[0].getContext();

StringAttr kOffset = StringAttr::get(type.getContext(), "offset");
StringAttr kBlock = StringAttr::get(type.getContext(), "block");
int64_t tensorSize = product(type.getShape());
auto enc = type.getEncoding();
unsigned numBlocks = getNumCTAs(enc);
StringAttr kOffset = StringAttr::get(ctx, "offset");
StringAttr kBlock = StringAttr::get(ctx, "block");
int64_t tensorSize = product(shape);
unsigned numBlocks = ll.getInDimSize(kBlock);
int32_t blockSize = tensorSize / numBlocks;

// elementMapping is for the non-hw layout, offsetMapping for hw-layout
@@ -3374,7 +3371,7 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
std::string layoutStr;

if (!useHWPointOfView) {
int rank = type.getRank();
int rank = shape.size();
bool newLine = true;
for (int i = 0; i < tensorSize; i++) {
auto indices = delinearizeIndex(i, shape);
@@ -3422,21 +3419,19 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
return layoutStr;
}

std::string getDistributedLayoutStr(RankedTensorType tensorType,
bool useHWPointOfView) {
auto layout = tensorType.getEncoding();
if (!layout)
return "";

StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block");
std::string mlir::triton::gpu::getDistributedLayoutStr(LinearLayout &ll,
bool useHWPointOfView) {
auto inDimNames = llvm::to_vector(ll.getInDimNames());
auto *ctx = inDimNames[0].getContext();
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");

LinearLayout ll = toLinearLayout(tensorType);
int64_t tensorSize = product(tensorType.getShape());
int64_t tensorSize = ll.getTotalOutDimSize();
std::vector<std::string> elementMapping(tensorSize);
std::vector<std::string> threadMapping;
auto shape = convertType<int64_t>(llvm::to_vector(ll.getOutDimSizes()));
unsigned threadsPerWarp = ll.getInDimSize(kLane);
unsigned numWarpsPerCTA = ll.getInDimSize(kWarp);
unsigned numBlocks = ll.getInDimSize(kBlock);
@@ -3456,7 +3451,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
int stride = 1;
for (int i = outputs.size() - 1; i >= 0; i--) {
linearizedIdx += outputs[i].second * stride;
stride *= tensorType.getDimSize(i);
stride *= shape[i];
}
std::string &value = elementMapping[linearizedIdx];
if (!value.empty())
@@ -3476,8 +3471,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
for (int i = 0; i < outputs.size(); i++) {
if (i > 0)
threadInfo += ",";
threadInfo +=
paddedString(outputs[i].second, tensorType.getDimSize(i));
threadInfo += paddedString(outputs[i].second, shape[i]);
}
threadInfo += ")";
threadMapping.push_back(threadInfo);
@@ -3488,13 +3482,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
std::string layoutStr;
if (!useHWPointOfView) {
// Printing the threads containing each elements of the tensor.
int rank = tensorType.getRank();
int rank = ll.getNumOutDims();
bool newLine = true;
for (int i = 0; i < tensorSize; i++) {
auto indices = delinearizeIndex(i, tensorType.getShape());
auto indices = delinearizeIndex(i, shape);
int numOpenBracket = 0;
for (int j = rank - 1; j >= 0; j--) {
if (indices[j] % tensorType.getDimSize(j) != 0)
if (indices[j] % shape[j] != 0)
break;
layoutStr += "[";
numOpenBracket++;
@@ -3506,13 +3500,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
}

layoutStr += elementMapping[i];
auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
auto nextIndices = delinearizeIndex(i + 1, shape);
for (int j = rank - 1; j >= 0; j--) {
if (nextIndices[j] % tensorType.getDimSize(j) != 0)
if (nextIndices[j] % shape[j] != 0)
break;
layoutStr += "]";
}
if (nextIndices.back() % tensorType.getShape().back() == 0) {
if (nextIndices.back() % shape.back() == 0) {
layoutStr += "\n";
newLine = true;
} else {
@@ -3578,15 +3572,16 @@ mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
bool useHWPointOfView) {
auto layout = tensorType.getEncoding();
LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout);

// tensorType is needed later on (e.g., getDimSize(j)), so we still have to
// pass it as a param
// TODO: Pass TensorOrMemDesc instead of RankedTensorType in
// triton-tensor-layout.cpp
if (mlir::isa<SharedEncodingTrait>(layout)) {
return getSharedLayoutStr(tensorType, useHWPointOfView);
return getSharedLayoutStr(ll, useHWPointOfView);
} else if (mlir::isa<DistributedEncodingTrait>(layout)) {
return getDistributedLayoutStr(tensorType, useHWPointOfView);
return getDistributedLayoutStr(ll, useHWPointOfView);
}

// else unimplemented, return error
104 changes: 52 additions & 52 deletions lib/Tools/LinearLayout.cpp
@@ -65,56 +65,6 @@ void dumpMatrix(uint64_t *m, int numRows, int numCols) {
}
}

// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing
// the bases of the given layout. This can then be used by f2reduce.
//
// This function is called from the constructor of LinearLayout, so be careful
// not to use any functions that create LLs in here.
std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
int numRows = layout.getTotalOutDimSizeLog2();
int numCols = layout.getTotalInDimSizeLog2();

// Don't handle giant LLs. This makes some things easier; for example, each
// row can be a single uint64_t.
assert(numCols <= 64 && "LinearLayout too large");
assert(numRows <= 64 && "LinearLayout too large");

// Suppose we have a layout specified by the following values.
//
// L(0,1) = (0b01, 0b1)
// L(0,2) = (0b10, 0b0)
// L(1,0) = (0b10, 0b0)
// L(2,0) = (0b11, 0b0)
//
// We will create one column per entry above. The max bit width of the
// codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix
// will be
//
// | L(0,1)[0] L(0,2)[0] L(1,0)[0] L(2,0)[0] | | 0b1001 |
// | ↓ ↓ ↓ ↓ | | 0b0111 |
// | L(0,1)[1] L(0,2)[1] L(1,0)[1] L(2,0)[1] | = | 0b1000 |
// | ↓ ↓ ↓ ↓ |
//
// Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not.
std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]());
int r = 0;
for (StringAttr outDim : layout.getOutDimNames()) {
int c = 0;
for (StringAttr inDim : layout.getInDimNames()) {
for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) {
uint64_t basis = layout.getBasis(inDim, i, outDim);
for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) {
m[r + j] |= ((basis >> j) & 1) << c;
}
c++;
}
}
r += layout.getOutDimSizeLog2(outDim);
}

return m;
}

// Compute the rank of the matrix formed by taking the bases for the given
// outDim as columns. In other words, finds the number of linearly-independent
// bases for this output dimension.
@@ -340,7 +290,7 @@ int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const {

int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const {
auto it = bases.find(inDim);
assert(it != bases.end());
assert(it != bases.end() && "inDim not found in layout");
return it->second.size();
}

@@ -353,7 +303,7 @@ int32_t LinearLayout::getTotalInDimSizeLog2() const {

int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const {
auto it = outDims.find(outDim);
assert(it != outDims.end());
assert(it != outDims.end() && "outDim not found in layout");
return llvm::Log2_32(it->second);
}

@@ -1370,4 +1320,54 @@ std::string ColumnAction::toString() const {
return ret;
}

// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing
// the bases of the given layout. This can then be used by f2reduce.
//
// This function is called from the constructor of LinearLayout, so be careful
// not to use any functions that create LLs in here.
std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
int numRows = layout.getTotalOutDimSizeLog2();
int numCols = layout.getTotalInDimSizeLog2();

// Don't handle giant LLs. This makes some things easier; for example, each
// row can be a single uint64_t.
assert(numCols <= 64 && "LinearLayout too large");
assert(numRows <= 64 && "LinearLayout too large");

// Suppose we have a layout specified by the following values.
//
// L(0,1) = (0b01, 0b1)
// L(0,2) = (0b10, 0b0)
// L(1,0) = (0b10, 0b0)
// L(2,0) = (0b11, 0b0)
//
// We will create one column per entry above. The max bit width of the
// codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix
// will be
//
// | L(0,1)[0] L(0,2)[0] L(1,0)[0] L(2,0)[0] | | 0b1001 |
// | ↓ ↓ ↓ ↓ | | 0b0111 |
// | L(0,1)[1] L(0,2)[1] L(1,0)[1] L(2,0)[1] | = | 0b1000 |
// | ↓ ↓ ↓ ↓ |
//
// Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not.
std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]());
int r = 0;
for (StringAttr outDim : layout.getOutDimNames()) {
int c = 0;
for (StringAttr inDim : layout.getInDimNames()) {
for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) {
uint64_t basis = layout.getBasis(inDim, i, outDim);
for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) {
m[r + j] |= ((basis >> j) & 1) << c;
}
c++;
}
}
r += layout.getOutDimSizeLog2(outDim);
}

return m;
}

} // namespace mlir::triton
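
A minimal standalone sketch of the bit-packing that getMatrix performs, using the four basis vectors and output widths from the worked example in the comment above. It is not part of this diff; the hard-coded bases, widths, and names are illustrative assumptions. Note that bit c of each row stores column c, so the printed integers read the comment's drawn rows from right to left.

// Standalone illustration of the getMatrix() bit-packing (assumed example data).
#include <cstdint>
#include <cstdio>
#include <memory>
#include <utility>
#include <vector>

int main() {
  // Each basis maps to (bits in out-dim 0, bits in out-dim 1); out-dim 0 is
  // 2 bits wide and out-dim 1 is 1 bit wide, as in the example above.
  std::vector<std::pair<uint64_t, uint64_t>> bases = {
      {0b01, 0b1}, // L(0,1)
      {0b10, 0b0}, // L(0,2)
      {0b10, 0b0}, // L(1,0)
      {0b11, 0b0}, // L(2,0)
  };
  const int outDimBits[] = {2, 1};
  const int numRows = 3; // sum of output bit widths
  std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]()); // zero-initialized

  int r = 0;
  for (int d = 0; d < 2; ++d) {
    for (int c = 0; c < (int)bases.size(); ++c) {
      uint64_t basis = (d == 0) ? bases[c].first : bases[c].second;
      for (int j = 0; j < outDimBits[d]; ++j)
        m[r + j] |= ((basis >> j) & 1) << c; // column c lives in bit c of row r+j
    }
    r += outDimBits[d];
  }

  // Prints 9 (0b1001), 14 (0b1110), and 1 (0b0001): the three rows of the
  // matrix drawn in the comment, with column 0 in the least-significant bit.
  for (int i = 0; i < numRows; ++i)
    std::printf("row %d = %llu\n", i, (unsigned long long)m[i]);
  return 0;
}
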
9 changes: 8 additions & 1 deletion python/src/gluon_ir.cc
@@ -328,14 +328,21 @@ void init_gluon_ir(py::module &&m) {
/*requiresSurjective=*/true);
return ttg::LinearEncodingAttr::get(ctx, ll);
})
.def("to_linear_layout",
.def("to_linear",
[](GluonOpBuilder &self, Attribute layout,
std::vector<int64_t> &shape) -> py::object {
auto ctx = self.getContext();
auto linearLayout = ttg::toLinearLayout(shape, layout);
auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
return layoutToGluon(attr);
})
.def("to_linear_layout",
[](GluonOpBuilder &self, Attribute layout,
std::vector<int64_t> &shape) -> py::object {
auto ctx = self.getContext();
auto linearLayout = ttg::toLinearLayout(shape, layout);
return py::cast(linearLayout);
})
.def("get_dot_operand_layout",
[](GluonOpBuilder &self, unsigned opIdx, Attribute parent,
unsigned kWidth) -> Attribute {