diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 54b254b1f3..5caf651fc6 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1980,16 +1980,6 @@ struct LoadOpToBlockIOConversion
     if (totalBytesPerRowPerMatrix > MAX_WIDTH)
       return failure();
 
-    // Load multiple dot operands by enlarging the vBlocks.
-    vBlocks = std::min(vBlocks,
-                       static_cast<int>(MAX_WIDTH / totalBytesPerRowPerMatrix));
-    // vBlocks has HW limitation of 4.
-    vBlocks = std::min(vBlocks, 4);
-    // Limit vBlocks to 1 if block size is smaller than GRF size.
-    const unsigned GRF_SIZE = 64;
-    if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
-      vBlocks = 1;
-
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     MLIRContext *ctx = rewriter.getContext();
@@ -2361,14 +2351,6 @@ struct DescriptorLoadOpToBlockIOConversion
     if (totalBytesPerRowPerMatrix > MAX_WIDTH)
       return failure();
 
-    // Load multiple dot operands by enlarging the vBlocks.
-    vBlocks = std::min(vBlocks,
-                       static_cast<int>(MAX_WIDTH / totalBytesPerRowPerMatrix));
-    vBlocks = std::min(vBlocks, 4);
-    const unsigned GRF_SIZE = 64;
-    if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
-      vBlocks = 1;
-
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     MLIRContext *ctx = rewriter.getContext();
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/BlockIOUtils.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/BlockIOUtils.cpp
index 53418d9613..03d0fadfdc 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/BlockIOUtils.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/BlockIOUtils.cpp
@@ -330,11 +330,32 @@ getBlockIOTileSize(const LinearLayout &ll, unsigned memContiguousDim,
     // insert the remaining register base.
     regPackBases.insert(1 << regBaseIter);
   }
-  return BlockIOTileSizeInfo(tileShape[transpose ? fastChangeDim : rowDim],
-                             tileShape[transpose ? rowDim : fastChangeDim] /
-                                 numElemPerPackedVal,
-                             numElemPerPackedVal, vBlocks, rowDim,
-                             fastChangeDim, transpose, std::move(regPackBases));
+  int tileHeight = tileShape[transpose ? fastChangeDim : rowDim];
+  int tileWidth =
+      tileShape[transpose ? rowDim : fastChangeDim] / numElemPerPackedVal;
+
+  // Cap vBlocks for loads based on HW constraints.
+  if constexpr (isLoad) {
+    constexpr int MAX_WIDTH_BYTES = 64;
+    unsigned packedElemSizeInBits = elemSizeInBits * numElemPerPackedVal;
+    unsigned totalBytesPerRowPerMatrix = tileWidth * packedElemSizeInBits / 8;
+    // A zero-byte row would divide by zero below, so skip the width cap then.
+    if (totalBytesPerRowPerMatrix > 0) {
+      vBlocks =
+          std::min(vBlocks, static_cast<unsigned>(MAX_WIDTH_BYTES /
+                                                  totalBytesPerRowPerMatrix));
+    }
+    // vBlocks has HW limitation of 4.
+    vBlocks = std::min(vBlocks, 4u);
+    // Limit vBlocks to 1 if block size is smaller than GRF size.
+    constexpr unsigned GRF_SIZE = 64;
+    if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
+      vBlocks = 1;
+  }
+
+  return BlockIOTileSizeInfo(tileHeight, tileWidth, numElemPerPackedVal,
+                             vBlocks, rowDim, fastChangeDim, transpose,
+                             std::move(regPackBases));
 }
 
 // Explicit instantiations.