1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -280,6 +280,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
${PYTHON_SRC_PATH}/ir.cc
${PYTHON_SRC_PATH}/gluon_ir.cc
${PYTHON_SRC_PATH}/linear_layout.cc
${PYTHON_SRC_PATH}/passes.cc
${PYTHON_SRC_PATH}/interpreter.cc
${PYTHON_SRC_PATH}/llvm.cc
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -266,6 +266,12 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

// Return a string representation of the shared layout of the tensor.
std::string getSharedLayoutStr(LinearLayout &ll, bool useHWPointOfView);

// Return a string representation of the distributed layout of the tensor.
std::string getDistributedLayoutStr(LinearLayout &ll, bool useHWPointOfView);

template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);

2 changes: 2 additions & 0 deletions include/triton/Tools/LinearLayout.h
@@ -869,6 +869,8 @@ inline std::ostream &operator<<(std::ostream &os, const ColumnAction &action) {
return os;
}

std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout);

} // namespace mlir::triton

#endif // TRITON_TOOLS_LINEARLAYOUT_H
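
A minimal usage sketch of the newly exported helper (the dimension names and the driver function below are illustrative, not taken from the diff): the 8-element identity layout has three input bits and three output bits, so getMatrix yields three rows, and for the identity layout each row is a unit vector.

#include <cstdint>
#include <memory>
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"
#include "triton/Tools/LinearLayout.h"

// Sketch: build a tiny layout and inspect the matrix getMatrix produces for
// f2reduce. Bit c of row r is bit r of the c-th basis vector.
void dumpIdentityMatrix() {
  mlir::MLIRContext ctx;
  mlir::StringAttr kIn = mlir::StringAttr::get(&ctx, "register");
  mlir::StringAttr kOut = mlir::StringAttr::get(&ctx, "dim0");
  auto ll = mlir::triton::LinearLayout::identity1D(8, kIn, kOut);
  std::unique_ptr<uint64_t[]> m = mlir::triton::getMatrix(ll);
  for (int r = 0; r < ll.getTotalOutDimSizeLog2(); ++r)
    llvm::errs() << m[r] << "\n"; // the identity layout prints 1, 2, 4
}
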
67 changes: 31 additions & 36 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -3307,20 +3307,17 @@ static std::string paddedString(int value, int max) {
return str;
}

std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
if (!type)
return "";

std::string mlir::triton::gpu::getSharedLayoutStr(LinearLayout &ll,
bool useHWPointOfView) {
// This RankedTensorType is a MemDescType (?!)
auto shape = type.getShape();
auto layout = type.getEncoding();
LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
auto outDimNames = llvm::to_vector(ll.getOutDimNames());
auto shape = convertType<int64_t>(llvm::to_vector(ll.getOutDimSizes()));
auto *ctx = outDimNames[0].getContext();

StringAttr kOffset = StringAttr::get(type.getContext(), "offset");
StringAttr kBlock = StringAttr::get(type.getContext(), "block");
int64_t tensorSize = product(type.getShape());
auto enc = type.getEncoding();
unsigned numBlocks = getNumCTAs(enc);
StringAttr kOffset = StringAttr::get(ctx, "offset");
StringAttr kBlock = StringAttr::get(ctx, "block");
int64_t tensorSize = product(shape);
unsigned numBlocks = ll.getInDimSize(kBlock);
int32_t blockSize = tensorSize / numBlocks;

// elementMapping is for the non-hw layout, offsetMapping for hw-layout
@@ -3374,7 +3371,7 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
std::string layoutStr;

if (!useHWPointOfView) {
int rank = type.getRank();
int rank = shape.size();
bool newLine = true;
for (int i = 0; i < tensorSize; i++) {
auto indices = delinearizeIndex(i, shape);
@@ -3422,21 +3419,19 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
return layoutStr;
}

std::string getDistributedLayoutStr(RankedTensorType tensorType,
bool useHWPointOfView) {
auto layout = tensorType.getEncoding();
if (!layout)
return "";

StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block");
std::string mlir::triton::gpu::getDistributedLayoutStr(LinearLayout &ll,
bool useHWPointOfView) {
auto inDimNames = llvm::to_vector(ll.getInDimNames());
auto *ctx = inDimNames[0].getContext();
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");

LinearLayout ll = toLinearLayout(tensorType);
int64_t tensorSize = product(tensorType.getShape());
int64_t tensorSize = ll.getTotalOutDimSize();
std::vector<std::string> elementMapping(tensorSize);
std::vector<std::string> threadMapping;
auto shape = convertType<int64_t>(llvm::to_vector(ll.getOutDimSizes()));
unsigned threadsPerWarp = ll.getInDimSize(kLane);
unsigned numWarpsPerCTA = ll.getInDimSize(kWarp);
unsigned numBlocks = ll.getInDimSize(kBlock);
@@ -3456,7 +3451,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
int stride = 1;
for (int i = outputs.size() - 1; i >= 0; i--) {
linearizedIdx += outputs[i].second * stride;
stride *= tensorType.getDimSize(i);
stride *= shape[i];
}
std::string &value = elementMapping[linearizedIdx];
if (!value.empty())
@@ -3476,8 +3471,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
for (int i = 0; i < outputs.size(); i++) {
if (i > 0)
threadInfo += ",";
threadInfo +=
paddedString(outputs[i].second, tensorType.getDimSize(i));
threadInfo += paddedString(outputs[i].second, shape[i]);
}
threadInfo += ")";
threadMapping.push_back(threadInfo);
@@ -3488,13 +3482,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
std::string layoutStr;
if (!useHWPointOfView) {
// Printing the threads containing each elements of the tensor.
int rank = tensorType.getRank();
int rank = ll.getNumOutDims();
bool newLine = true;
for (int i = 0; i < tensorSize; i++) {
auto indices = delinearizeIndex(i, tensorType.getShape());
auto indices = delinearizeIndex(i, shape);
int numOpenBracket = 0;
for (int j = rank - 1; j >= 0; j--) {
if (indices[j] % tensorType.getDimSize(j) != 0)
if (indices[j] % shape[j] != 0)
break;
layoutStr += "[";
numOpenBracket++;
@@ -3506,13 +3500,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
}

layoutStr += elementMapping[i];
auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
auto nextIndices = delinearizeIndex(i + 1, shape);
for (int j = rank - 1; j >= 0; j--) {
if (nextIndices[j] % tensorType.getDimSize(j) != 0)
if (nextIndices[j] % shape[j] != 0)
break;
layoutStr += "]";
}
if (nextIndices.back() % tensorType.getShape().back() == 0) {
if (nextIndices.back() % shape.back() == 0) {
layoutStr += "\n";
newLine = true;
} else {
@@ -3578,15 +3572,16 @@ mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
bool useHWPointOfView) {
auto layout = tensorType.getEncoding();
LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout);

// tensorType is needed later on (e.g., getDimSize(j)), so we still have to
// pass it as a param
// TODO: Pass TensorOrMemDesc instead of RankedTensorType in
// triton-tensor-layout.cpp
if (mlir::isa<SharedEncodingTrait>(layout)) {
return getSharedLayoutStr(tensorType, useHWPointOfView);
return getSharedLayoutStr(ll, useHWPointOfView);
} else if (mlir::isa<DistributedEncodingTrait>(layout)) {
return getDistributedLayoutStr(tensorType, useHWPointOfView);
return getDistributedLayoutStr(ll, useHWPointOfView);
}

// else unimplemented, return error
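
As a hedged illustration of where the LinearLayout-based signatures could lead (the TODO above mentions passing TensorOrMemDesc): because the printers no longer require a RankedTensorType, a shared-memory descriptor could be printed through the same path. The helper below is hypothetical and not part of this change.

// Hypothetical follow-up sketch, not in this diff: print the layout of a
// shared-memory MemDescType by lowering its encoding to a LinearLayout first.
std::string printMemDescLayoutStr(mlir::triton::gpu::MemDescType memDesc,
                                  bool useHWPointOfView) {
  mlir::triton::LinearLayout ll = mlir::triton::gpu::toLinearLayout(
      memDesc.getShape(), memDesc.getEncoding());
  return mlir::triton::gpu::getSharedLayoutStr(ll, useHWPointOfView);
}
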
104 changes: 52 additions & 52 deletions lib/Tools/LinearLayout.cpp
@@ -65,56 +65,6 @@ void dumpMatrix(uint64_t *m, int numRows, int numCols) {
}
}

// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing
// the bases of the given layout. This can then be used by f2reduce.
//
// This function is called from the constructor of LinearLayout, so be careful
// not to use any functions that create LLs in here.
std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
int numRows = layout.getTotalOutDimSizeLog2();
int numCols = layout.getTotalInDimSizeLog2();

// Don't handle giant LLs. This makes some things easier; for example, each
// row can be a single uint64_t.
assert(numCols <= 64 && "LinearLayout too large");
assert(numRows <= 64 && "LinearLayout too large");

// Suppose we have a layout specified by the following values.
//
// L(0,1) = (0b01, 0b1)
// L(0,2) = (0b10, 0b0)
// L(1,0) = (0b10, 0b0)
// L(2,0) = (0b11, 0b0)
//
// We will create one column per entry above. The max bit width of the
// codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix
// will be
//
// | L(0,1)[0] L(0,2)[0] L(1,0)[0] L(2,0)[0] | | 0b1001 |
// | ↓ ↓ ↓ ↓ | | 0b0111 |
// | L(0,1)[1] L(0,2)[1] L(1,0)[1] L(2,0)[1] | = | 0b1000 |
// | ↓ ↓ ↓ ↓ |
//
// Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not.
std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]());
int r = 0;
for (StringAttr outDim : layout.getOutDimNames()) {
int c = 0;
for (StringAttr inDim : layout.getInDimNames()) {
for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) {
uint64_t basis = layout.getBasis(inDim, i, outDim);
for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) {
m[r + j] |= ((basis >> j) & 1) << c;
}
c++;
}
}
r += layout.getOutDimSizeLog2(outDim);
}

return m;
}

// Compute the rank of the matrix formed by taking the bases for the given
// outDim as columns. In other words, finds the number of linearly-independent
// bases for this output dimension.
@@ -340,7 +290,7 @@ int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const {

int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const {
auto it = bases.find(inDim);
assert(it != bases.end());
assert(it != bases.end() && "inDim not found in layout");
return it->second.size();
}

@@ -353,7 +303,7 @@ int32_t LinearLayout::getTotalInDimSizeLog2() const {

int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const {
auto it = outDims.find(outDim);
assert(it != outDims.end());
assert(it != outDims.end() && "outDim not found in layout");
return llvm::Log2_32(it->second);
}

@@ -1370,4 +1320,54 @@ std::string ColumnAction::toString() const {
return ret;
}

// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing
// the bases of the given layout. This can then be used by f2reduce.
//
// This function is called from the constructor of LinearLayout, so be careful
// not to use any functions that create LLs in here.
std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
int numRows = layout.getTotalOutDimSizeLog2();
int numCols = layout.getTotalInDimSizeLog2();

// Don't handle giant LLs. This makes some things easier; for example, each
// row can be a single uint64_t.
assert(numCols <= 64 && "LinearLayout too large");
assert(numRows <= 64 && "LinearLayout too large");

// Suppose we have a layout specified by the following values.
//
// L(0,1) = (0b01, 0b1)
// L(0,2) = (0b10, 0b0)
// L(1,0) = (0b10, 0b0)
// L(2,0) = (0b11, 0b0)
//
// We will create one column per entry above. The max bit width of the
// codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix
// will be
//
// | L(0,1)[0] L(0,2)[0] L(1,0)[0] L(2,0)[0] | | 0b1001 |
// | ↓ ↓ ↓ ↓ | | 0b0111 |
// | L(0,1)[1] L(0,2)[1] L(1,0)[1] L(2,0)[1] | = | 0b1000 |
// | ↓ ↓ ↓ ↓ |
//
// Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not.
std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]());
int r = 0;
for (StringAttr outDim : layout.getOutDimNames()) {
int c = 0;
for (StringAttr inDim : layout.getInDimNames()) {
for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) {
uint64_t basis = layout.getBasis(inDim, i, outDim);
for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) {
m[r + j] |= ((basis >> j) & 1) << c;
}
c++;
}
}
r += layout.getOutDimSizeLog2(outDim);
}

return m;
}

} // namespace mlir::triton
9 changes: 8 additions & 1 deletion python/src/gluon_ir.cc
@@ -328,14 +328,21 @@ void init_gluon_ir(py::module &&m) {
/*requiresSurjective=*/true);
return ttg::LinearEncodingAttr::get(ctx, ll);
})
.def("to_linear_layout",
.def("to_linear",
[](GluonOpBuilder &self, Attribute layout,
std::vector<int64_t> &shape) -> py::object {
auto ctx = self.getContext();
auto linearLayout = ttg::toLinearLayout(shape, layout);
auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
return layoutToGluon(attr);
})
.def("to_linear_layout",
[](GluonOpBuilder &self, Attribute layout,
std::vector<int64_t> &shape) -> py::object {
auto ctx = self.getContext();
auto linearLayout = ttg::toLinearLayout(shape, layout);
return py::cast(linearLayout);
})
.def("get_dot_operand_layout",
[](GluonOpBuilder &self, unsigned opIdx, Attribute parent,
unsigned kWidth) -> Attribute {
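
For py::cast(linearLayout) above to produce a Python object, LinearLayout must be registered with pybind11 somewhere; the new python/src/linear_layout.cc added in CMakeLists.txt presumably does that. A rough, assumed sketch of such a registration (the init function name and exposed method are guesses, not taken from the diff):

#include <pybind11/pybind11.h>
#include "triton/Tools/LinearLayout.h"

namespace py = pybind11;

// Assumed registration so that LinearLayout values can cross into Python.
void init_triton_linear_layout(py::module &&m) {
  py::class_<mlir::triton::LinearLayout>(m, "LinearLayout")
      .def("__repr__", [](const mlir::triton::LinearLayout &ll) {
        return ll.toString();
      });
}
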