 
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/InterleavedRange.h"
@@ -48,20 +49,14 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
   return os;
 }
 
-// Shortened helper to compute the product of `values`.
-static int64_t prod(ArrayRef<int64_t> values) {
-  return ShapedType::getNumElements(values);
-}
-
 static int64_t calculateOperandsSharedMemoryUsedInBytes(
     const GPUMMASchedule &schedule, int64_t lhsBitwidth, int64_t rhsBitwidth,
     int64_t numRhs = 1) {
-
-  int64_t tileM = schedule.mSize * prod(schedule.mTileSizes) *
-                  prod(schedule.mSubgroupCounts);
-  int64_t tileN = schedule.nSize * prod(schedule.nTileSizes) *
-                  prod(schedule.nSubgroupCounts);
-  int64_t tileK = schedule.kSize * prod(schedule.kTileSizes);
+  int64_t tileM = schedule.mSize * llvm::product_of(schedule.mTileSizes) *
+                  llvm::product_of(schedule.mSubgroupCounts);
+  int64_t tileN = schedule.nSize * llvm::product_of(schedule.nTileSizes) *
+                  llvm::product_of(schedule.nSubgroupCounts);
+  int64_t tileK = schedule.kSize * llvm::product_of(schedule.kTileSizes);
   return (tileM * tileK * lhsBitwidth + numRhs * tileN * tileK * rhsBitwidth) /
          8;
 }
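
For context on the replacement: `llvm::product_of` (from `llvm/ADT/STLExtras.h`, newly included above) folds a range of values with multiplication, which is what the removed `prod` helper obtained indirectly through `ShapedType::getNumElements`. A minimal, self-contained sketch of the expected semantics — `productOf` here is a hypothetical stand-in, not the LLVM implementation:

```cpp
// Hypothetical stand-in illustrating the reduction llvm::product_of performs.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

static int64_t productOf(const std::vector<int64_t> &values) {
  // Multiplicative fold with identity 1, so an empty range yields 1 --
  // matching ShapedType::getNumElements on an empty dimension list.
  return std::accumulate(values.begin(), values.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  assert(productOf({2, 4, 8}) == 64); // e.g. a [2, 4, 8] tile-size list
  assert(productOf({}) == 1);         // empty list contributes a neutral factor
  return 0;
}
```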
@@ -70,11 +65,10 @@ static int64_t
 calculateResultSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
                                        int64_t resultBitwidth,
                                        int64_t numRes = 1) {
-
-  int64_t tileM = schedule.mSize * prod(schedule.mTileSizes) *
-                  prod(schedule.mSubgroupCounts);
-  int64_t tileN = schedule.nSize * prod(schedule.nTileSizes) *
-                  prod(schedule.nSubgroupCounts);
+  int64_t tileM = schedule.mSize * llvm::product_of(schedule.mTileSizes) *
+                  llvm::product_of(schedule.mSubgroupCounts);
+  int64_t tileN = schedule.nSize * llvm::product_of(schedule.nTileSizes) *
+                  llvm::product_of(schedule.nSubgroupCounts);
   return (numRes * tileM * tileN * resultBitwidth) / 8;
 }
 
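Both helpers accumulate bit counts and divide by 8 at the end to report bytes. A worked example with made-up tile sizes (none of these numbers appear in the patch):

```cpp
// Worked example with hypothetical tile sizes: operands at f16 (16-bit),
// result at f32 (32-bit), one RHS and one result buffer.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t tileM = 128, tileN = 128, tileK = 32;
  // LHS tile (M x K) plus one RHS tile (K x N), converted bits -> bytes:
  int64_t operandBytes = (tileM * tileK * 16 + tileN * tileK * 16) / 8;
  // One result tile (M x N), converted bits -> bytes:
  int64_t resultBytes = (tileM * tileN * 32) / 8;
  printf("operands: %lld B, result: %lld B\n", (long long)operandBytes,
         (long long)resultBytes);
  // Prints: operands: 16384 B, result: 65536 B
  return 0;
}
```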
@@ -150,13 +144,14 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem,
   const int64_t kMaxVectorLoadBitWidth = 128;
   int64_t elemsPerThread =
       kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth();
-  int64_t wgThreads = subgroupSize * prod(schedule.mSubgroupCounts) *
-                      prod(schedule.nSubgroupCounts);
-  int64_t mWgSize = schedule.mSize * prod(schedule.mTileSizes) *
-                    prod(schedule.mSubgroupCounts);
-  int64_t nWgSize = schedule.nSize * prod(schedule.nTileSizes) *
-                    prod(schedule.nSubgroupCounts);
-  int64_t kWgSize = schedule.kSize * prod(schedule.kTileSizes);
+  int64_t wgThreads = subgroupSize *
+                      llvm::product_of(schedule.mSubgroupCounts) *
+                      llvm::product_of(schedule.nSubgroupCounts);
+  int64_t mWgSize = schedule.mSize * llvm::product_of(schedule.mTileSizes) *
+                    llvm::product_of(schedule.mSubgroupCounts);
+  int64_t nWgSize = schedule.nSize * llvm::product_of(schedule.nTileSizes) *
+                    llvm::product_of(schedule.nSubgroupCounts);
+  int64_t kWgSize = schedule.kSize * llvm::product_of(schedule.kTileSizes);
   int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize;
   int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize;
 
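The `kMaxVectorLoadBitWidth` cap above bounds how many B-operand elements one thread may fetch in a single 128-bit vectorized load. A quick sketch of that arithmetic for a few common bitwidths (illustrative only):

```cpp
// Hypothetical arithmetic for the 128-bit vector-load cap:
// each thread loads at most 128 bits per vectorized access.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t kMaxVectorLoadBitWidth = 128;
  assert(kMaxVectorLoadBitWidth / 16 == 8);  // f16 B operand: 8 elems/thread
  assert(kMaxVectorLoadBitWidth / 8 == 16);  // i8 B operand: 16 elems/thread
  assert(kMaxVectorLoadBitWidth / 32 == 4);  // f32 B operand: 4 elems/thread
  return 0;
}
```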
@@ -263,12 +258,8 @@ static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem,
   // established after we sweep the different tile sizes for a problem config.
   // Once a precise threshold is established, replace 4 with the threshold and
   // remove this todo.
-  const int64_t mSize =
-      std::accumulate(problem.mSizes.begin(), problem.mSizes.end(), 1,
-                      std::multiplies<int64_t>());
-  const int64_t nSize =
-      std::accumulate(problem.nSizes.begin(), problem.nSizes.end(), 1,
-                      std::multiplies<int64_t>());
+  const int64_t mSize = llvm::product_of(problem.mSizes);
+  const int64_t nSize = llvm::product_of(problem.nSizes);
   // TODO(jornt): Remove this check as batch size doesn't make a computation
   // more compute bound, so it shouldn't be considered.
   if (!problem.batchSizes.empty()) {
@@ -383,7 +374,6 @@ static void distributeGCDForDim(bool isMDim, int64_t &mTotalTileToDistribute,
                                 int64_t &nTileSizeDistributed,
                                 int64_t &remainingSubgroups,
                                 int64_t &remainingTiles) {
-
   int64_t &totalTilesToDistribute =
       isMDim ? mTotalTileToDistribute : nTotalTileToDistribute;
   int64_t &subgroupDistributed =
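The function above mutates either the M or the N accounting depending on `isMDim`, binding a reference through a ternary so the rest of the body is dimension-agnostic. A minimal sketch of that idiom, with hypothetical names:

```cpp
// Hypothetical sketch of the select-by-reference idiom used above.
#include <cassert>
#include <cstdint>

static void consumeFromDim(bool isMDim, int64_t &mBudget, int64_t &nBudget) {
  // Both ternary operands are lvalues of the same type, so the result is an
  // lvalue and the reference binds to whichever budget should be mutated.
  int64_t &budget = isMDim ? mBudget : nBudget;
  budget /= 2;
}

int main() {
  int64_t m = 8, n = 8;
  consumeFromDim(/*isMDim=*/true, m, n);
  assert(m == 4 && n == 8); // only the selected dimension was updated
  return 0;
}
```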
@@ -418,8 +408,8 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
       llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]);
   nTotalTileCounts.back() =
       llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]);
-  int64_t mTotalTileToDistribute = prod(mTotalTileCounts);
-  int64_t nTotalTileToDistribute = prod(nTotalTileCounts);
+  int64_t mTotalTileToDistribute = llvm::product_of(mTotalTileCounts);
+  int64_t nTotalTileToDistribute = llvm::product_of(nTotalTileCounts);
 
   int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup;
   int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
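`llvm::divideCeil` (from `llvm/Support/MathExtras.h`) rounds the quotient up, so a problem dimension that is not an exact multiple of the intrinsic size still receives a tile for the remainder. A small sketch with made-up sizes; `divideCeil` below reimplements the helper's semantics rather than calling LLVM:

```cpp
// Hypothetical numbers showing why the tile count is rounded up.
#include <cassert>
#include <cstdint>

static int64_t divideCeil(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  // A problem M dim of 100 with a 16-wide intrinsic needs 7 tiles, not 6:
  // 6 tiles cover only 96 rows and would drop the last 4.
  assert(divideCeil(100, 16) == 7);
  assert(divideCeil(96, 16) == 6); // exact multiples are unaffected
  return 0;
}
```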
@@ -529,9 +519,9 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
       getBestKTileSizes(problem, intrinsic, seeds);
 
   return GPUMMASchedule{intrinsic.mmaKind,
-                        prod(intrinsic.mSizes),
-                        prod(intrinsic.nSizes),
-                        prod(intrinsic.kSizes),
+                        llvm::product_of(intrinsic.mSizes),
+                        llvm::product_of(intrinsic.nSizes),
+                        llvm::product_of(intrinsic.kSizes),
                         mSubgroupCounts,
                         nSubgroupCounts,
                         mTileSizes,
@@ -741,8 +731,8 @@ getOptimalAttentionPVSchedule(const GPUMatmulShapeType &problem,
   // subgroups on N leaves room to distribute subgroups on K1 and how that
   // effects the softmax computation hasn't been experimented with yet.
   //
-  // Distribute tile sizes on N as much as we can as it's completly unrolled and
-  // then distribute remaining tiles and subgroups on M.
+  // Distribute tile sizes on N as much as we can as it's completely unrolled
+  // and then distribute remaining tiles and subgroups on M.
   for (int nDim = problem.nSizes.size() - 1; nDim >= 0; --nDim) {
     // Do not distribute N on subgroups.
     nSubgroupCounts[nDim] = 1;
@@ -793,7 +783,6 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
     const GPUMMAHeuristicSeeds &pvMatmulSeeds, int64_t sharedMemLimitInBytes,
     int64_t subgroupSize, bool transposedQ, bool transposedK, bool transposedV,
     bool canUpcastAcc, bool mustBeAligned) {
-
   SmallVector<uint64_t> qkViableIntrinsicIndices;
   SmallVector<uint64_t> pvViableIntrinsicIndices;
   for (const auto &[index, intrinsic] : llvm::enumerate(intrinsics)) {