@@ -501,13 +501,17 @@ We call each individual tile "rep".
501501 InterfaceMethod<"Return total element size per thread.",
502502 "unsigned",
503503 "getTotalElemsPerThread",
504- (ins "ArrayRef<int64_t>":$tensorShape,
505- "Type":$eltTy)>,
504+ (ins "ArrayRef<int64_t>":$shape),
505+ /*defaultImplementation=*/[{
506+ return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape);
507+ }]>,
506508 InterfaceMethod<"Return element size per thread in each dimension.",
507509 "SmallVector<unsigned>",
508510 "getElemsPerThread",
509- (ins "ArrayRef<int64_t>":$tensorShape,
510- "Type":$eltTy)>,
511+ (ins "ArrayRef<int64_t>":$shape),
512+ /*defaultImplementation=*/[{
513+ return toLinearEncoding($_self, shape).getElemsPerThread(shape);
514+ }]>,
511515 // Interface for the meta information about the multiple thread hierarchy.
512516 InterfaceMethod<"Get the shape of the warps per CTA.",
513517 "SmallVector<unsigned>",
@@ -577,8 +581,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
577581 }];
578582
579583 code extraDistributedDeclaration = extraBaseClassDeclaration # [{
580- unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
581- SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
584+ // The methods declared below are implemented by each subclass.
582585 SmallVector<unsigned> getRepOrder() const;
583586 SmallVector<unsigned> getCTAsPerCGA() const;
584587 SmallVector<unsigned> getCTAOrder() const;
@@ -613,6 +616,10 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
613616 let parameters = (ins LinearLayoutParam:$linearLayout);
614617
615618 let extraClassDeclaration = extraDistributedDeclaration # [{
619+ // Generic distributed encoding methods
620+ unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape) const;
621+ SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape) const;
622+
616623 SmallVector<unsigned> getContigPerThread() const;
617624 SmallVector<unsigned> getOrder() const;
618625
@@ -965,7 +972,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
965972 return true;
966973 }
967974 SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
968- unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
969975 SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
970976 SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
971977 SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
@@ -1095,7 +1101,6 @@ Row |
10951101 return true;
10961102 }
10971103 SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1098- unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10991104 SmallVector<int64_t> getElemsPerInstrForOperands() const;
11001105 SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
11011106 Type elemType, int kWidth, int opIdx) const;
0 commit comments