Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1980,16 +1980,6 @@ struct LoadOpToBlockIOConversion
if (totalBytesPerRowPerMatrix > MAX_WIDTH)
return failure();

// Load multiple dot operands by enlarging the vBlocks.
vBlocks = std::min(vBlocks,
static_cast<int>(MAX_WIDTH / totalBytesPerRowPerMatrix));
// vBlocks has HW limitation of 4.
vBlocks = std::min(vBlocks, 4);
// Limit vBlocks to 1 if block size is smaller than GRF size.
const unsigned GRF_SIZE = 64;
if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
vBlocks = 1;

Location loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
MLIRContext *ctx = rewriter.getContext();
Expand Down Expand Up @@ -2361,14 +2351,6 @@ struct DescriptorLoadOpToBlockIOConversion
if (totalBytesPerRowPerMatrix > MAX_WIDTH)
return failure();

// Load multiple dot operands by enlarging the vBlocks.
vBlocks = std::min(vBlocks,
static_cast<int>(MAX_WIDTH / totalBytesPerRowPerMatrix));
vBlocks = std::min(vBlocks, 4);
const unsigned GRF_SIZE = 64;
if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
vBlocks = 1;

Location loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
MLIRContext *ctx = rewriter.getContext();
Expand Down
28 changes: 23 additions & 5 deletions third_party/intel/lib/TritonIntelGPUTransforms/BlockIOUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,29 @@ getBlockIOTileSize(const LinearLayout &ll, unsigned memContiguousDim,
// insert the remaining register base.
regPackBases.insert(1 << regBaseIter);
}
return BlockIOTileSizeInfo(tileShape[transpose ? fastChangeDim : rowDim],
tileShape[transpose ? rowDim : fastChangeDim] /
numElemPerPackedVal,
numElemPerPackedVal, vBlocks, rowDim,
fastChangeDim, transpose, std::move(regPackBases));
int tileHeight = tileShape[transpose ? fastChangeDim : rowDim];
int tileWidth =
tileShape[transpose ? rowDim : fastChangeDim] / numElemPerPackedVal;

// Cap vBlocks for loads based on HW constraints.
if constexpr (isLoad) {
constexpr int MAX_WIDTH_BYTES = 64;
unsigned packedElemSizeInBits = elemSizeInBits * numElemPerPackedVal;
unsigned totalBytesPerRowPerMatrix = tileWidth * packedElemSizeInBits / 8;
if (totalBytesPerRowPerMatrix > 0) {
vBlocks =
std::min(vBlocks, static_cast<unsigned>(MAX_WIDTH_BYTES /
totalBytesPerRowPerMatrix));
}
vBlocks = std::min(vBlocks, 4u);
constexpr unsigned GRF_SIZE = 64;
if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
vBlocks = 1;
}

return BlockIOTileSizeInfo(tileHeight, tileWidth, numElemPerPackedVal,
vBlocks, rowDim, fastChangeDim, transpose,
std::move(regPackBases));
}

// Explicit instantiations.
Expand Down