Skip to content

Commit 61b5674

Browse files
authored
[LAYOUTS] Implement generically getElemsPerThread (triton-lang#5841)
While doing so, we remove the SliceEncodingAttr hack!
1 parent 7da6d0b commit 61b5674

File tree

6 files changed

+43
-332
lines changed

6 files changed

+43
-332
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
7070
StringRef getName() final { return "<SharedMemory>"; }
7171
};
7272

73+
// Convert a distributed layout to a linear encoding
74+
LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape);
75+
7376
unsigned getTotalElemsPerThread(Type type);
7477

75-
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
76-
Type eltTy);
78+
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
7779

7880
SmallVector<unsigned> getElemsPerThread(Type type);
7981

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -501,13 +501,17 @@ We call each individual tile "rep".
501501
InterfaceMethod<"Return total element size per thread.",
502502
"unsigned",
503503
"getTotalElemsPerThread",
504-
(ins "ArrayRef<int64_t>":$tensorShape,
505-
"Type":$eltTy)>,
504+
(ins "ArrayRef<int64_t>":$shape),
505+
/*defaultImplementation=*/[{
506+
return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape);
507+
}]>,
506508
InterfaceMethod<"Return element size per thread in each dimension.",
507509
"SmallVector<unsigned>",
508510
"getElemsPerThread",
509-
(ins "ArrayRef<int64_t>":$tensorShape,
510-
"Type":$eltTy)>,
511+
(ins "ArrayRef<int64_t>":$shape),
512+
/*defaultImplementation=*/[{
513+
return toLinearEncoding($_self, shape).getElemsPerThread(shape);
514+
}]>,
511515
// Interface for the meta information about the multiple thread hierarchy.
512516
InterfaceMethod<"Get the shape of the warps per CTA.",
513517
"SmallVector<unsigned>",
@@ -577,8 +581,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
577581
}];
578582

579583
code extraDistributedDeclaration = extraBaseClassDeclaration # [{
580-
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
581-
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
584+
// Implemented in subclasses
582585
SmallVector<unsigned> getRepOrder() const;
583586
SmallVector<unsigned> getCTAsPerCGA() const;
584587
SmallVector<unsigned> getCTAOrder() const;
@@ -613,6 +616,10 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
613616
let parameters = (ins LinearLayoutParam:$linearLayout);
614617

615618
let extraClassDeclaration = extraDistributedDeclaration # [{
619+
// Generic distributed encoding methods
620+
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape) const;
621+
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape) const;
622+
616623
SmallVector<unsigned> getContigPerThread() const;
617624
SmallVector<unsigned> getOrder() const;
618625

@@ -965,7 +972,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
965972
return true;
966973
}
967974
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
968-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
969975
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
970976
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
971977
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
@@ -1095,7 +1101,6 @@ Row |
10951101
return true;
10961102
}
10971103
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1098-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10991104
SmallVector<int64_t> getElemsPerInstrForOperands() const;
11001105
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
11011106
Type elemType, int kWidth, int opIdx) const;

0 commit comments

Comments
 (0)