Skip to content

Commit 86906cf

Browse files
authored
[AMD] Fix f16/bf16 buffer atomic operations properly (#6139)
Reverted the workaround that disabled f16/bf16 buffer atomic operations: #6090. Added an additional check for vector size, ensuring applicability of packed instructions for f16/bf16. See also: * #6090 * #6126 - rootcause fix Signed-off-by: Ilya Veselov <[email protected]>
1 parent 3784fda commit 86906cf

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,18 +313,23 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
313313
// 4. Buffer atomic RMW does not support FP8 ops
314314
// easier to just check what we support
315315
auto checkType = getElementTypeOrSelf(op.getVal());
316-
// TODO: F16 and BF16 data types are supported by intrinsics with packed
317-
// arithmetic on adjacent addresses, requiring the leading address to be
318-
// 4-byte aligned. A runtime check should be implemented to enforce this
319-
// requirement and ensure fallback to regular atomic operations when
320-
// alignment is not met.
321-
bool isSupportedType = checkType.isF32() || checkType.isF64() ||
316+
bool isSupportedType = checkType.isF16() || checkType.isBF16() ||
317+
checkType.isF32() || checkType.isF64() ||
322318
checkType.isInteger(32) || checkType.isInteger(64);
323319
if (!isSupportedType) {
324320
return rewriter.notifyMatchFailure(op, "RMW with unsupported type");
325321
}
326322
LDBG("RMW supported type");
327323

324+
auto vecSize = getVectorSize(ptr, axisAnalysisPass);
325+
// f16/bf16 dtypes could only be efficiently calculated using instructions
326+
// that pack 2 elements (e.g. @llvm.amdgcn.raw.buffer.atomic.fadd.v2f16)
327+
if (vecSize % 2 != 0 && (checkType.isF16() || checkType.isBF16())) {
328+
return rewriter.notifyMatchFailure(
329+
op, "RMW float 16 dtypes must be aligned by 2");
330+
}
331+
LDBG("RMW passed alignment check");
332+
328333
// 5. Check if the RMWOp is supported
329334
switch (atomicRmwOp) {
330335
case RMWOp::AND:
@@ -355,8 +360,7 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
355360
// are contiguous we can emit the buffer op. Otherwise, the buffer ops
356361
// lowering will try to emit individual (unsupported) f16/bf16 ops.
357362
auto elemBitWidth = tensorType.getElementTypeBitWidth();
358-
opBitWidth =
359-
getVectorSize(basePtr, tensorOffset, axisAnalysisPass) * elemBitWidth;
363+
opBitWidth = vecSize * elemBitWidth;
360364
} else {
361365
opBitWidth = opValueType.getIntOrFloatBitWidth();
362366
}

0 commit comments

Comments
 (0)