From 13f35057e22977ca41dae9f79c46213749629096 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Mon, 27 Jan 2025 15:55:48 -0500
Subject: [PATCH 1/3] [NFC] replace include/triton/Conversion/TritonGPUToLLVM/Utility.h macros with TritonImplicitLocOpBuilder

---
 .../Conversion/TritonGPUToLLVM/Utility.h | 513 ++++++++-----
 .../TritonGPUToLLVM/AssertOpToLLVM.cpp | 13 +-
 .../TritonGPUToLLVM/ControlFlowOpToLLVM.cpp | 10 +-
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 94 +--
 .../SharedToDotOperandFMA.cpp | 79 +-
 .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 23 +-
 .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 23 +-
 .../TritonGPUToLLVM/HistogramOpToLLVM.cpp | 58 +-
 .../TritonGPUToLLVM/MakeRangeOpToLLVM.cpp | 3 +-
 .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 3 +-
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 48 +-
 .../TritonGPUToLLVM/ReduceScanCommon.h | 5 +-
 .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 88 ++-
 lib/Conversion/TritonGPUToLLVM/Utility.cpp | 180 +++--
 .../TritonGPUToLLVM/ViewOpToLLVM.cpp | 14 +-
 .../TritonAMDGPUToLLVM/BufferOpsEmitter.cpp | 48 +-
 .../ConvertLayoutOpToLLVM.cpp | 103 +--
 .../SharedToDotOperandHelper.cpp | 47 +-
 .../SharedToDotOperandMFMA.cpp | 79 +-
 .../SharedToDotOperandWMMA.cpp | 64 +-
 .../TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp | 63 +-
 .../TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp | 40 +-
 .../ElementwiseOpToLLVM.cpp | 722 +++++++++---------
 .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 256 ++++---
 .../lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 3 +-
 .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 42 +-
 .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 117 +--
 .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 80 +-
 .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 13 +-
 .../TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp | 11 +-
 .../ConvertLayoutOpToLLVM.cpp | 91 +--
 .../SharedToDotOperandMMAv2OrV3.cpp | 173 +++--
 .../DotOpToLLVM/MMAv2.cpp | 21 +-
 .../DotOpToLLVM/WGMMA.cpp | 85 ++-
 .../ElementwiseOpToLLVM.cpp | 22 +-
 .../LoadStoreOpToLLVM.cpp | 104 +--
 .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 23 +-
 .../lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp | 23 +-
 .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 88 ++-
 .../TensorPtrOpsToLLVM.cpp | 3 +-
 .../UpcastMXFPToLLVM.cpp | 49 +-
 .../lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 46 +-
 42 files changed, 1996 insertions(+), 1574 deletions(-)

diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index 926c361c45e9..b41de1ae9a73 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -6,6 +6,7 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
@@ -29,89 +30,249 @@ using namespace mlir;
 using namespace mlir::triton;
 
-// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
-// Operators
-#define inttofloat(...) rewriter.create<LLVM::SIToFPOp>(loc, __VA_ARGS__)
-#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
-#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
-#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
-#define sext(...) rewriter.create<LLVM::SExtOp>(loc, __VA_ARGS__)
-#define fpext(...) rewriter.create<LLVM::FPExtOp>(loc, __VA_ARGS__)
-#define fptrunc(...) rewriter.create<LLVM::FPTruncOp>(loc, __VA_ARGS__)
-#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
-#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
-#define sdiv(...) rewriter.create<LLVM::SDivOp>(loc, __VA_ARGS__)
-#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
-#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
-#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
-#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
-#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
-#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
-#define fma(...) rewriter.create<LLVM::FMAOp>(loc, __VA_ARGS__)
-#define neg(...) rewriter.create<LLVM::FNegOp>(loc, __VA_ARGS__)
-#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
-#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
-#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
-#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
-#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
-#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
-#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
-#define lshr(...) rewriter.create<LLVM::LShrOp>(loc, __VA_ARGS__)
-#define ashr(...) rewriter.create<LLVM::AShrOp>(loc, __VA_ARGS__)
-#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
-#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
-#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
-#define bitcast(val__, type__)                                                 \
-  rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
-#define addrspacecast(...)                                                     \
-  rewriter.create<LLVM::AddrSpaceCastOp>(loc, __VA_ARGS__)
-#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
-#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
-#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
-#define extract_val(...) rewriter.create<LLVM::ExtractValueOp>(loc, __VA_ARGS__)
-#define insert_element(...)                                                    \
-  rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
-#define extract_element(...)                                                   \
-  rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
-#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
-#define store(...) rewriter.create<LLVM::StoreOp>(loc, __VA_ARGS__)
-#define fcmp_ogt(lhs, rhs)                                                     \
-  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
-                                LLVM::FCmpPredicate::ogt, lhs, rhs)
-#define fcmp_olt(lhs, rhs)                                                     \
-  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
-                                LLVM::FCmpPredicate::olt, lhs, rhs)
-#define fcmp_eq(lhs, rhs)                                                      \
-  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
-                                LLVM::FCmpPredicate::oeq, lhs, rhs)
-#define icmp_eq(...)                                                           \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
-#define icmp_ne(...)                                                           \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
-#define icmp_slt(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
-#define icmp_sle(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
-#define icmp_sgt(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
-#define icmp_sge(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
-#define icmp_ult(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
-#define icmp_ule(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
-#define icmp_ugt(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
-#define icmp_uge(...)                                                          \
-  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
-#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
-#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
-#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
-#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
-#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
-#define call(...) LLVM::createLLVMCallOp(rewriter, loc, __VA_ARGS__)
+namespace mlir::triton {
+struct TritonLLVMOpBuilder {
+  TritonLLVMOpBuilder(const Location &loc, RewriterBase &builder)
+      : loc(loc), builder(builder) {}
+  // Shortcuts for some commonly used LLVM ops to keep code simple and
+  // intuitive
+  // Operators
+  template <typename... Args> LLVM::SIToFPOp inttofloat(Args &&...args) {
+    return builder.create<LLVM::SIToFPOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::IntToPtrOp inttoptr(Args &&...args) {
+    return builder.create<LLVM::IntToPtrOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::PtrToIntOp ptrtoint(Args &&...args) {
+    return builder.create<LLVM::PtrToIntOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ZExtOp zext(Args &&...args) {
+    return builder.create<LLVM::ZExtOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::SExtOp sext(Args &&...args) {
+    return builder.create<LLVM::SExtOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::FPExtOp fpext(Args &&...args) {
+    return builder.create<LLVM::FPExtOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::FPTruncOp fptrunc(Args &&...args) {
+    return builder.create<LLVM::FPTruncOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::TruncOp trunc(Args &&...args) {
+    return builder.create<LLVM::TruncOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::UDivOp udiv(Args &&...args) {
+    return builder.create<LLVM::UDivOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::SDivOp sdiv(Args &&...args) {
+    return builder.create<LLVM::SDivOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::URemOp urem(Args &&...args) {
+    return builder.create<LLVM::URemOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::AddOp add(Args &&...args) {
+    return builder.create<LLVM::AddOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::SubOp sub(Args &&...args) {
+    return builder.create<LLVM::SubOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::FAddOp fadd(Args &&...args) {
+    return builder.create<LLVM::FAddOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::MulOp mul(Args &&...args) {
+    return builder.create<LLVM::MulOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::FMulOp fmul(Args &&...args) {
+    return builder.create<LLVM::FMulOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::FMAOp fma(Args &&...args) {
+    return builder.create<LLVM::FMAOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::FNegOp neg(Args &&...args) {
+    return builder.create<LLVM::FNegOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::SMaxOp smax(Args &&...args) {
+    return builder.create<LLVM::SMaxOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::UMaxOp umax(Args &&...args) {
+    return builder.create<LLVM::UMaxOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::MaxNumOp fmax(Args &&...args) {
+    return builder.create<LLVM::MaxNumOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::SMinOp smin(Args &&...args) {
+    return builder.create<LLVM::SMinOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::UMinOp umin(Args &&...args) {
+    return builder.create<LLVM::UMinOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::MinNumOp fmin(Args &&...args) {
+    return builder.create<LLVM::MinNumOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ShlOp shl(Args &&...args) {
+    return builder.create<LLVM::ShlOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::LShrOp lshr(Args &&...args) {
+    return builder.create<LLVM::LShrOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::AShrOp ashr(Args &&...args) {
+    return builder.create<LLVM::AShrOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::AndOp and_(Args &&...args) {
+    return builder.create<LLVM::AndOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::XOrOp xor_(Args &&...args) {
+    return builder.create<LLVM::XOrOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::OrOp or_(Args &&...args) {
+    return builder.create<LLVM::OrOp>(loc, std::forward<Args>(args)...);
+  }
+  LLVM::BitcastOp bitcast(Value val, Type type) {
+    return builder.create<LLVM::BitcastOp>(loc, type, val);
+  }
+  template <typename... Args>
+  LLVM::AddrSpaceCastOp addrspacecast(Args &&...args) {
+    return builder.create<LLVM::AddrSpaceCastOp>(loc,
+                                                 std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::GEPOp gep(Args &&...args) {
+    return builder.create<LLVM::GEPOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::InsertValueOp insert_val(Args &&...args) {
+    return builder.create<LLVM::InsertValueOp>(loc,
+                                               std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ExtractValueOp extract_val(Args &&...args) {
+    return builder.create<LLVM::ExtractValueOp>(loc,
+                                                std::forward<Args>(args)...);
+  }
+  template <typename... Args>
+  LLVM::InsertElementOp insert_element(Args &&...args) {
+    return builder.create<LLVM::InsertElementOp>(loc,
+                                                 std::forward<Args>(args)...);
+  }
+  template <typename... Args>
+  LLVM::ExtractElementOp extract_element(Args &&...args) {
+    return builder.create<LLVM::ExtractElementOp>(loc,
+                                                  std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::LoadOp load(Args &&...args) {
+    return builder.create<LLVM::LoadOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::StoreOp store(Args &&...args) {
+    return builder.create<LLVM::StoreOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) {
+    return builder.create<LLVM::FCmpOp>(loc, builder.getI1Type(),
+                                        LLVM::FCmpPredicate::ogt, lhs, rhs);
+  }
+  template <typename... Args> LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) {
+    return builder.create<LLVM::FCmpOp>(loc, builder.getI1Type(),
+                                        LLVM::FCmpPredicate::olt, lhs, rhs);
+  }
+  template <typename... Args> LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) {
+    return builder.create<LLVM::FCmpOp>(loc, builder.getI1Type(),
+                                        LLVM::FCmpPredicate::oeq, lhs, rhs);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_eq(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_ne(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_slt(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_sle(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_sgt(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_sge(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_ult(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_ule(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_ugt(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ICmpOp icmp_uge(Args &&...args) {
+    return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge,
+                                        std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::SelectOp select(Args &&...args) {
+    return builder.create<LLVM::SelectOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::AddressOfOp address_of(Args &&...args) {
+    return builder.create<LLVM::AddressOfOp>(loc, std::forward<Args>(args)...);
+  }
+  mlir::gpu::BarrierOp barrier() {
+    return builder.create<mlir::gpu::BarrierOp>(loc);
+  }
+  template <typename... Args> LLVM::UndefOp undef(Args &&...args) {
+    return builder.create<LLVM::UndefOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::ZeroOp null(Args &&...args) {
+    return builder.create<LLVM::ZeroOp>(loc, std::forward<Args>(args)...);
+  }
+  template <typename... Args> LLVM::CallOp call(Args &&...args) {
+    return builder.create<LLVM::CallOp>(loc, std::forward<Args>(args)...);
+  }
+  // Constants
+  Value int_val(short bitwidth, int64_t val) {
+    Type ty = builder.getIntegerType(bitwidth);
+    return builder.create<LLVM::ConstantOp>(loc, ty,
+                                            builder.getIntegerAttr(ty, val));
+  }
+  Value i1_val(int64_t val) { return int_val(1, val); }
+  Value true_val() { return int_val(1, true); }
+  Value false_val() { return int_val(1, false); }
+  Value f16_val(float v) {
+    auto type = type::f16Ty(builder.getContext());
+    return builder.create<LLVM::ConstantOp>(loc, type,
+                                            builder.getF16FloatAttr(v));
+  }
+  Value f32_val(float v) {
+    auto type = type::f32Ty(builder.getContext());
+    return builder.create<LLVM::ConstantOp>(loc, type,
+                                            builder.getF32FloatAttr(v));
+  }
+  Value f64_val(double v) {
+    auto type = type::f64Ty(builder.getContext());
+    return builder.create<LLVM::ConstantOp>(loc, type,
+                                            builder.getF64FloatAttr(v));
+  }
+  Value i8_val(int64_t val) { return int_val(8, val); }
+  Value i16_val(int64_t val) { return int_val(16, val); }
+  Value i32_val(int64_t val) { return int_val(32, val); }
+  Value i64_val(int64_t val) { return int_val(64, val); }
+  Value tid_val() {
+    Value tid =
+        builder.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
+    Type i32_ty = builder.getIntegerType(32);
+    return builder.create<arith::IndexCastOp>(loc, i32_ty, tid);
+  }
+
+  Location loc;
+  RewriterBase &builder;
+};
+} // namespace mlir::triton
 
 // Types
+#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
 #define int_ty(width) rewriter.getIntegerType(width)
 #define i64_ty rewriter.getIntegerType(64)
 #define i32_ty rewriter.getIntegerType(32)
@@ -131,21 +292,6 @@ using namespace mlir::triton;
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
 #define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
 
-// Constants
-#define int_val(bitwidth, val)                                                 \
-  LLVM::createLLVMIntegerConstant(rewriter, loc, bitwidth, val)
-#define i1_val(val) LLVM::createConstantI1(loc, rewriter, val)
-#define true_val() i1_val(true)
-#define false_val() i1_val(false)
-#define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__)
-#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
-#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
-#define i8_val(val) int_val(8, val)
-#define i16_val(val) int_val(16, val)
-#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
-#define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__)
-#define tid_val() getThreadId(rewriter, loc)
-
 // Attributes
 #define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__})
 #define i64_arr_attr(...)
rewriter.getI64ArrayAttr({__VA_ARGS__}) @@ -256,7 +402,8 @@ class SharedMemoryObject { SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc, RewriterBase &rewriter) : base(base), baseElemType(baseElemType) { - offsets.append(rank, i32_val(0)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + offsets.append(rank, b.i32_val(0)); } SmallVector getOffsets() const { return offsets; } @@ -298,10 +445,11 @@ class SharedMemoryObject { // TODO(Keren): deprecate the method once AMD backend has cleaned up Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); Value cSwizzleOffset = getCSwizzleOffset(dim); - Value offset = sub(i32_val(0), cSwizzleOffset); + Value offset = b.sub(b.i32_val(0), cSwizzleOffset); Type type = base.getType(); - return gep(type, baseElemType, base, offset); + return b.gep(type, baseElemType, base, offset); } private: @@ -332,8 +480,9 @@ class SharedMemoryObject { SmallVector strides(shape.size()); auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder); int64_t stride = 1; + auto b = TritonLLVMOpBuilder(loc, rewriter); for (auto idx : order) { - strides[idx] = i32_val(stride); + strides[idx] = b.i32_val(stride); stride *= shape[idx]; } return strides; @@ -439,7 +588,8 @@ inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, } auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1); - return gep(ptrTy, i8_ty, gmemBase, allocOffset); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return b.gep(ptrTy, i8_ty, gmemBase, allocOffset); } // Base for entire kernel @@ -461,21 +611,22 @@ inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, gridDim[k] = rewriter.create(loc, k); } + auto b = TritonLLVMOpBuilder(loc, rewriter); Value linearId = gridIdx[2]; for (int k = 0; k < 2; ++k) { - linearId = add(gridIdx[1 - k], mul(linearId, gridDim[1 - k])); + linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k])); } auto allocSize = allocSizeAttr.getValue().getZExtValue(); - Value offset = mul(linearId, i32_val(allocSize)); + Value offset = b.mul(linearId, b.i32_val(allocSize)); if (allocOffset) { - offset = add(offset, allocOffset); + offset = b.add(offset, allocOffset); } auto *ctx = rewriter.getContext(); auto res = - gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset); + b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset); return res; } @@ -489,8 +640,10 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, size_t offset = cast(op->getAttr("allocation.offset")) .getValue() .getZExtValue(); - Value offVal = i32_val(offset); - Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value offVal = b.i32_val(offset); + Value base = + b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); return base; } @@ -530,9 +683,10 @@ using ::mlir::triton::gpu::SliceEncodingAttr; inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, ArrayRef strides) { assert(offsets.size() == strides.size()); - Value ret = i32_val(0); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value ret = b.i32_val(0); for (auto [offset, stride] : llvm::zip(offsets, strides)) { - ret = add(ret, mul(offset, stride)); + ret = b.add(ret, b.mul(offset, stride)); } return ret; } @@ -576,9 +730,10 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, MLIRContext *ctx = 
rewriter.getContext(); auto shape = type.getShape(); Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout)); - Value laneId = urem(threadId, warpSize); - Value warpId = udiv(threadId, warpSize); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value warpSize = b.i32_val(triton::gpu::getWarpSize(blockedLayout)); + Value laneId = b.urem(threadId, warpSize); + Value warpId = b.udiv(threadId, warpSize); auto sizePerThread = blockedLayout.getSizePerThread(); auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); @@ -600,16 +755,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, auto maxWarps = ceil(shapePerCTA[k], sizePerThread[k] * threadsPerWarp[k]); auto maxThreads = ceil(shapePerCTA[k], sizePerThread[k]); - multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps)); - multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads)); + multiDimWarpId[k] = b.urem(multiDimWarpId[k], b.i32_val(maxWarps)); + multiDimThreadId[k] = b.urem(multiDimThreadId[k], b.i32_val(maxThreads)); // multiDimBase[k] = (multiDimThreadId[k] + // multiDimWarpId[k] * threadsPerWarp[k]) * // sizePerThread[k]; - Value threadsPerWarpK = i32_val(threadsPerWarp[k]); - Value sizePerThreadK = i32_val(sizePerThread[k]); + Value threadsPerWarpK = b.i32_val(threadsPerWarp[k]); + Value sizePerThreadK = b.i32_val(sizePerThread[k]); multiDimBase[k] = - mul(sizePerThreadK, - add(multiDimThreadId[k], mul(multiDimWarpId[k], threadsPerWarpK))); + b.mul(sizePerThreadK, b.add(multiDimThreadId[k], + b.mul(multiDimWarpId[k], threadsPerWarpK))); } return multiDimBase; @@ -632,14 +787,15 @@ emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter, auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); ArrayRef instrShape = mmaLayout.getInstrShape(); SmallVector warpsPerCTA; + auto b = TritonLLVMOpBuilder(loc, rewriter); for (unsigned i = 0; i < rank; ++i) - warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + warpsPerCTA.push_back(b.i32_val(_warpsPerCTA[i])); auto shapePerCTA = getShapePerCTA(mmaLayout, shape); Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(32); - Value laneId = urem(threadId, warpSize); - Value warpId = udiv(threadId, warpSize); + Value warpSize = b.i32_val(32); + Value laneId = b.urem(threadId, warpSize); + Value warpId = b.udiv(threadId, warpSize); uint32_t repM = (_warpsPerCTA[rank - 2] * instrShape[rank - 2]) / shapePerCTA[rank - 2]; @@ -660,11 +816,11 @@ emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter, SmallVector multiDimWarpId(rank); multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, warpOrder); - Value warpIdM = urem(multiDimWarpId[rank - 2], i32_val(warpsM)); - Value warpIdN = urem(multiDimWarpId[rank - 1], i32_val(warpsN)); + Value warpIdM = b.urem(multiDimWarpId[rank - 2], b.i32_val(warpsM)); + Value warpIdN = b.urem(multiDimWarpId[rank - 1], b.i32_val(warpsN)); - Value offWarpM = mul(warpIdM, i32_val(instrShape[rank - 2])); - Value offWarpN = mul(warpIdN, i32_val(instrShape[rank - 1])); + Value offWarpM = b.mul(warpIdM, b.i32_val(instrShape[rank - 2])); + Value offWarpN = b.mul(warpIdN, b.i32_val(instrShape[rank - 1])); SmallVector multiDimBase(rank); if (rank == 3) @@ -676,10 +832,10 @@ emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter, // we rely on the caller to check. Worst case we crash, which is better than // silently producing bad code. 
if (warpsM != 0) - multiDimBase[rank - 2] = add(udiv(laneId, i32_val(4)), offWarpM); + multiDimBase[rank - 2] = b.add(b.udiv(laneId, b.i32_val(4)), offWarpM); if (warpsN != 0) multiDimBase[rank - 1] = - add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarpN); + b.add(b.mul(b.i32_val(2), b.urem(laneId, b.i32_val(4))), offWarpN); return multiDimBase; } @@ -693,55 +849,56 @@ emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter, assert(rank == 2 || rank == 3); auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA(); SmallVector warpsPerCTA; + auto b = TritonLLVMOpBuilder(loc, rewriter); for (unsigned i = 0; i < rank; ++i) - warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + warpsPerCTA.push_back(b.i32_val(_warpsPerCTA[i])); unsigned mDim = mfmaLayout.getMDim(); unsigned nDim = mfmaLayout.getNDim(); assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout)); + Value warpSize = b.i32_val(triton::gpu::getWarpSize(mfmaLayout)); Value effectiveWarpSize = warpSize; if (mDim == 4 && nDim == 4) { const int uniqueValuesPerWarp = 4; - effectiveWarpSize = i32_val(uniqueValuesPerWarp); + effectiveWarpSize = b.i32_val(uniqueValuesPerWarp); } - Value laneId = urem(threadId, effectiveWarpSize); - Value warpId = udiv(threadId, warpSize); + Value laneId = b.urem(threadId, effectiveWarpSize); + Value warpId = b.udiv(threadId, warpSize); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, triton::gpu::getWarpOrder(mfmaLayout)); if (shape[rank - 2] >= mDim) { assert(shape[rank - 2] % mDim == 0); multiDimWarpId[rank - 2] = - urem(multiDimWarpId[rank - 2], - i32_val(ceil(shape[rank - 2], mDim))); + b.urem(multiDimWarpId[rank - 2], + b.i32_val(ceil(shape[rank - 2], mDim))); } if (shape[rank - 1] >= nDim) { assert(shape[rank - 1] % nDim == 0); multiDimWarpId[rank - 1] = - urem(multiDimWarpId[rank - 1], - i32_val(ceil(shape[rank - 1], nDim))); + b.urem(multiDimWarpId[rank - 1], + b.i32_val(ceil(shape[rank - 1], nDim))); } - Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mDim)); - Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(nDim)); + Value offWarp0 = b.mul(multiDimWarpId[rank - 2], b.i32_val(mDim)); + Value offWarp1 = b.mul(multiDimWarpId[rank - 1], b.i32_val(nDim)); SmallVector multiDimBase(rank); if (mfmaLayout.getIsTransposed()) { multiDimBase[rank - 1] = - add(mul(i32_val(4), udiv(laneId, i32_val(mDim))), offWarp1); - multiDimBase[rank - 2] = add(urem(laneId, i32_val(mDim)), offWarp0); + b.add(b.mul(b.i32_val(4), b.udiv(laneId, b.i32_val(mDim))), offWarp1); + multiDimBase[rank - 2] = b.add(b.urem(laneId, b.i32_val(mDim)), offWarp0); } else { multiDimBase[rank - 2] = - add(mul(i32_val(4), udiv(laneId, i32_val(nDim))), offWarp0); - multiDimBase[rank - 1] = add(urem(laneId, i32_val(nDim)), offWarp1); + b.add(b.mul(b.i32_val(4), b.udiv(laneId, b.i32_val(nDim))), offWarp0); + multiDimBase[rank - 1] = b.add(b.urem(laneId, b.i32_val(nDim)), offWarp1); } // TODO(Lixun): It is assumed when rank = 3, warpsPerCTA is set to // {numWarps, 1, 1}. We need to generalize the offset computation. 
if (rank == 3) { assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); - multiDimBase[0] = urem(warpId, i32_val(shape[0])); + multiDimBase[0] = b.urem(warpId, b.i32_val(shape[0])); } return multiDimBase; } @@ -821,64 +978,65 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, auto rank = _warpsPerCTA.size(); assert(rank == 2 || rank == 3); SmallVector warpsPerCTA; + auto b = TritonLLVMOpBuilder(loc, rewriter); for (unsigned i = 0; i < rank; ++i) - warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + warpsPerCTA.push_back(b.i32_val(_warpsPerCTA[i])); auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerInstr(); Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(triton::gpu::getWarpSize(wmmaLayout)); + Value warpSize = b.i32_val(triton::gpu::getWarpSize(wmmaLayout)); Value laneId = - urem(threadId, i32_val(triton::gpu::getWarpSize(wmmaLayout) / 2)); - Value threadIdPerWarp = urem(threadId, warpSize); + b.urem(threadId, b.i32_val(triton::gpu::getWarpSize(wmmaLayout) / 2)); + Value threadIdPerWarp = b.urem(threadId, warpSize); - Value warpId = udiv(threadId, warpSize); + Value warpId = b.udiv(threadId, warpSize); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, triton::gpu::getWarpOrder(wmmaLayout)); if (shape[rank - 2] >= mnkDim[0]) { assert(shape[rank - 2] % mnkDim[0] == 0); multiDimWarpId[rank - 2] = - urem(multiDimWarpId[rank - 2], - i32_val(ceil(shape[rank - 2], mnkDim[0]))); + b.urem(multiDimWarpId[rank - 2], + b.i32_val(ceil(shape[rank - 2], mnkDim[0]))); } if (shape[rank - 1] >= mnkDim[1]) { assert(shape[rank - 1] % mnkDim[1] == 0); multiDimWarpId[rank - 1] = - urem(multiDimWarpId[rank - 1], - i32_val(ceil(shape[rank - 1], mnkDim[1]))); + b.urem(multiDimWarpId[rank - 1], + b.i32_val(ceil(shape[rank - 1], mnkDim[1]))); } - Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0])); - Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(mnkDim[1])); + Value offWarp0 = b.mul(multiDimWarpId[rank - 2], b.i32_val(mnkDim[0])); + Value offWarp1 = b.mul(multiDimWarpId[rank - 1], b.i32_val(mnkDim[1])); SmallVector multiDimBase(rank); auto ver = wmmaLayout.getVersion(); if (ver == 1) { multiDimBase[rank - 2] = - add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + b.add(b.udiv(threadIdPerWarp, b.i32_val(mnkDim[2])), offWarp0); } else { assert(ver == 2); if (wmmaLayout.getIsTransposed()) { multiDimBase[rank - 1] = - add(mul(udiv(threadIdPerWarp, i32_val(16)), - i32_val(wmmaLayout.getSizePerThread()[rank - 1])), - offWarp1); - multiDimBase[rank - 2] = add(laneId, offWarp0); + b.add(b.mul(b.udiv(threadIdPerWarp, b.i32_val(16)), + b.i32_val(wmmaLayout.getSizePerThread()[rank - 1])), + offWarp1); + multiDimBase[rank - 2] = b.add(laneId, offWarp0); } else { multiDimBase[rank - 2] = - add(mul(udiv(threadIdPerWarp, i32_val(16)), - i32_val(wmmaLayout.getSizePerThread()[rank - 2])), - offWarp0); - multiDimBase[rank - 1] = add(laneId, offWarp1); + b.add(b.mul(b.udiv(threadIdPerWarp, b.i32_val(16)), + b.i32_val(wmmaLayout.getSizePerThread()[rank - 2])), + offWarp0); + multiDimBase[rank - 1] = b.add(laneId, offWarp1); } } - multiDimBase[rank - 1] = add(laneId, offWarp1); + multiDimBase[rank - 1] = b.add(laneId, offWarp1); // TODO: It is assumed when rank = 3, warpsPerCTA is set to // {numWarps, 1, 1}. We need to generalize the offset computation. 
if (rank == 3) { assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); - multiDimBase[0] = urem(warpId, i32_val(shape[0])); + multiDimBase[0] = b.urem(warpId, b.i32_val(shape[0])); } return multiDimBase; } @@ -907,16 +1065,18 @@ inline SmallVector emitCTAOffsetForLayout(Location loc, SmallVector multiDimClusterCTAId = delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + auto b = TritonLLVMOpBuilder(loc, rewriter); // CTA Wrapping for (unsigned i = 0; i < rank; ++i) { // This wrapping rule must be consistent with getShapePerCTA unsigned splitNum = std::min(shape[i], CTASplitNum[i]); - multiDimClusterCTAId[i] = urem(multiDimClusterCTAId[i], i32_val(splitNum)); + multiDimClusterCTAId[i] = + b.urem(multiDimClusterCTAId[i], b.i32_val(splitNum)); } SmallVector CTAOffset(rank); for (unsigned i = 0; i < rank; ++i) - CTAOffset[i] = mul(multiDimClusterCTAId[i], i32_val(shapePerCTA[i])); + CTAOffset[i] = b.mul(multiDimClusterCTAId[i], b.i32_val(shapePerCTA[i])); return CTAOffset; } @@ -954,6 +1114,7 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, } else { llvm_unreachable("unsupported emitBaseIndexForLayout"); } + auto b = TritonLLVMOpBuilder(loc, rewriter); if (withCTAOffset) { auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, target, layout, shape); @@ -964,7 +1125,7 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, // off. if (!result[k]) continue; - result[k] = add(result[k], CTAOffset[k]); + result[k] = b.add(result[k], CTAOffset[k]); } } return result; @@ -1058,9 +1219,10 @@ inline SmallVector unpackLLElements(Location loc, Value llvmStruct, ArrayRef types = cast(llvmStruct.getType()).getBody(); SmallVector results(types.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); for (unsigned i = 0; i < types.size(); ++i) { Type type = types[i]; - results[i] = extract_val(type, llvmStruct, i); + results[i] = b.extract_val(type, llvmStruct, i); } return results; } @@ -1083,6 +1245,7 @@ inline Value packLLElements(Location loc, << resultVals.size(); } Value llvmStruct = rewriter.create(loc, structType); + auto b = TritonLLVMOpBuilder(loc, rewriter); for (const auto &v : llvm::enumerate(resultVals)) { if (!v.value()) { emitError(loc) @@ -1096,7 +1259,7 @@ inline Value packLLElements(Location loc, << elementTypes[v.index()] << " but got " << v.value().getType(); } - llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index()); + llvmStruct = b.insert_val(structType, llvmStruct, v.value(), v.index()); } return llvmStruct; } @@ -1109,10 +1272,11 @@ inline SmallVector unpackLLVector(Location loc, Value llvmVec, isa(llvmVec.getType())) return {llvmVec}; + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector results; for (int i = 0; i < cast(llvmVec.getType()).getNumElements(); i++) { - results.push_back(extract_element(llvmVec, i32_val(i))); + results.push_back(b.extract_element(llvmVec, b.i32_val(i))); } return results; } @@ -1121,9 +1285,10 @@ inline Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter) { assert(vals.size() > 0); auto vecType = vec_ty(vals[0].getType(), vals.size()); - Value vec = undef(vecType); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value vec = b.undef(vecType); for (int i = 0; i < vals.size(); i++) { - vec = insert_element(vec, vals[i], i32_val(i)); + vec = b.insert_element(vec, vals[i], b.i32_val(i)); } return vec; } diff --git a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp index 7a3c8ce27abd..1a5e0809b427 100644 --- 
a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -18,17 +18,18 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto ctx = rewriter.getContext(); auto typeConverter = getTypeConverter(); auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter); auto elemTy = elems[0].getType(); - Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0); + Value condition = b.int_val(elemTy.getIntOrFloatBitWidth(), 0); for (auto elem : elems) { if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) { - condition = - or_(condition, - icmp_eq(elem, rewriter.create( - loc, elemTy, rewriter.getZeroAttr(elemTy)))); + condition = b.or_( + condition, + b.icmp_eq(elem, rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)))); } else { assert(false && "Unsupported type for assert"); return failure(); @@ -41,7 +42,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { // tensor in those two operations may have different layout we need to // make sure all the threads are done executing the assert before going to // the next op. - barrier(); + b.barrier(); } rewriter.eraseOp(op); return success(); diff --git a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp index 06e19029ebb8..f2d5351d3c20 100644 --- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -13,6 +13,8 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern { matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto funcOp = op->getParentOfType(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); if (funcOp->hasAttr("nvvm.kernel")) { // A GPU kernel if (op.getNumOperands() > 0) { @@ -34,10 +36,9 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern { funcOp.getResultTypes()); Value packedResults = rewriter.create(op.getLoc(), packedResultsTy); - auto loc = op.getLoc(); for (auto it : llvm::enumerate(adaptor.getOperands())) { - packedResults = insert_val(packedResultsTy, packedResults, it.value(), - it.index()); + packedResults = b.insert_val(packedResultsTy, packedResults, + it.value(), it.index()); } newOp = rewriter.create(op.getLoc(), packedResults); } @@ -78,6 +79,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { // Get the last argument of the caller, which is the current stack pointer // of shared memory and append it to the operands of the callOp. 
auto loc = callOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto caller = callOp->getParentOfType(); auto promotedOperands = this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), @@ -95,7 +97,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { Value opOffsetVal; if (opOffsetAttr) { auto opOffset = opOffsetAttr.getValue().getZExtValue(); - opOffsetVal = i32_val(opOffset); + opOffsetVal = b.i32_val(opOffset); } promotedOperands.push_back( diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index f277561a0deb..c816dbadd858 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -62,6 +62,7 @@ struct ConvertLayoutOpConversion ArrayRef origRepShape, ArrayRef outOrd, SmallVector &vals, Value smemBase) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto accumNumCTAsEachRep = product(numCTAsEachRep); auto layout = type.getEncoding(); auto rank = type.getRank(); @@ -110,29 +111,29 @@ struct ConvertLayoutOpConversion Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped, paddedRepShape, outOrd); auto elemPtrTy = smemBase.getType(); - Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); + Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); if (stNotRd) { - Value valVec = undef(vecTy); + Value valVec = b.undef(vecTy); for (unsigned v = 0; v < vec; ++v) { auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v]; if (isInt1) - currVal = zext(llvmElemTy, currVal); + currVal = b.zext(llvmElemTy, currVal); else if (isPtr) - currVal = ptrtoint(llvmElemTy, currVal); - valVec = insert_element(vecTy, valVec, currVal, i32_val(v)); + currVal = b.ptrtoint(llvmElemTy, currVal); + valVec = b.insert_element(vecTy, valVec, currVal, b.i32_val(v)); } - store(valVec, ptr); + b.store(valVec, ptr); } else { - Value valVec = load(vecTy, ptr); + Value valVec = b.load(vecTy, ptr); for (unsigned v = 0; v < vec; ++v) { - Value currVal = extract_element(llvmElemTy, valVec, i32_val(v)); + Value currVal = b.extract_element(llvmElemTy, valVec, b.i32_val(v)); if (isInt1) - currVal = icmp_ne(currVal, - rewriter.create( - loc, i8_ty, rewriter.getI8IntegerAttr(0))); + currVal = b.icmp_ne( + currVal, rewriter.create( + loc, i8_ty, rewriter.getI8IntegerAttr(0))); else if (isPtr) - currVal = inttoptr(llvmElemTyOrig, currVal); + currVal = b.inttoptr(llvmElemTyOrig, currVal); vals[elemId + linearCTAId * accumSizePerThread + v] = currVal; } } @@ -146,6 +147,7 @@ struct ConvertLayoutOpConversion ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) const { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto typeConverter = getTypeConverter(); RankedTensorType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); @@ -205,12 +207,12 @@ struct ConvertLayoutOpConversion auto multiDimRepId = getMultiDimIndex(repId, numReplicates, outOrd); if (repId != 0) { - barrier(); + b.barrier(); } processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, vals, smemBase); - barrier(); + b.barrier(); processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape, origRepShape, outOrd, outVals, smemBase); @@ -355,6 +357,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion 
ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); @@ -399,9 +402,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Munge input values for (const auto &it : llvm::enumerate(inVals)) { if (isSubByteInt) { - inVals[it.index()] = zext(llvmElemTy, it.value()); + inVals[it.index()] = b.zext(llvmElemTy, it.value()); } else if (isPtr) { - inVals[it.index()] = ptrtoint(llvmElemTy, it.value()); + inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value()); } } @@ -417,9 +420,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Unmunge output values for (const auto &it : llvm::enumerate(outVals)) { if (isSubByteInt) { - outVals[it.index()] = trunc(llvmElemTyOrig, it.value()); + outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value()); } else if (isPtr) { - outVals[it.index()] = inttoptr(llvmElemTyOrig, it.value()); + outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value()); } } @@ -443,6 +446,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); StringAttr kRegister = str_attr("register"); StringAttr kLane = str_attr("lane"); @@ -452,9 +456,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion StringAttr kIteration = str_attr("iteration"); Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(srcLayout.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); + Value threadsPerWarp = b.i32_val(srcLayout.getInDimSize(kLane)); + Value laneId = b.urem(threadId, threadsPerWarp); + Value warpId = b.udiv(threadId, threadsPerWarp); auto scratchConfig = getScratchConfigForCvt(op.getSrc().getType(), op.getType()); @@ -541,37 +545,38 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion {kWarp, 0}, {kBlock, 0}})[0] .second; - Value offset = xor_(regBase, i32_val(regIdx)); + Value offset = b.xor_(regBase, b.i32_val(regIdx)); if (paddedSize > 0) { assert(llvm::isPowerOf2_32(paddedStride)); assert(llvm::isPowerOf2_32(paddedSize)); auto rshiftVal = llvm::Log2_32(paddedStride); auto lshiftVal = llvm::Log2_32(paddedSize); - offset = add(shl(lshr(offset, i32_val(rshiftVal)), i32_val(lshiftVal)), - offset); + offset = b.add( + b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)), + offset); } - auto vecAddr = gep(sharedPtrTy, elemTy, smemBase, offset); + auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset); vecAddr.setInbounds(true); return vecAddr; }; auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout, - {{kRegister, i32_val(0)}, + {{kRegister, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, - {kBlock, i32_val(0)}})[0] + {kBlock, b.i32_val(0)}})[0] .second; auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout, - {{kRegister, i32_val(0)}, + {{kRegister, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, - {kBlock, i32_val(0)}})[0] + {kBlock, b.i32_val(0)}})[0] .second; // register idx -> Value llvm::MapVector outVals; for (int i = 0; i < iterations; i++) { if (i != 0) - barrier(); + b.barrier(); auto &inRegs = inRegsForIter[i]; auto &outRegs = outRegsForIter[i]; @@ -591,11 +596,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); } else { targetInfo.storeDShared(rewriter, loc, vecAddr, 
std::nullopt, valsVec, - /*pred=*/true_val()); + /*pred=*/b.true_val()); } } - barrier(); + b.barrier(); for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) { auto outRegSlice = outRegs[j]; @@ -603,7 +608,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion Value valsVec = targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt, vec_ty(elemTy, scratchConfig.outVec), - /*pred=*/true_val()); + /*pred=*/b.true_val()); for (Value v : unpackLLVector(loc, valsVec, rewriter)) outVals[outRegSlice++] = v; } @@ -646,6 +651,7 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp( ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); StringAttr kRegister = str_attr("register"); StringAttr kLane = str_attr("lane"); assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); @@ -657,8 +663,8 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp( SmallVector shflOuts(Cp.getInDimSize(kRegister)); Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(Cp.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); + Value threadsPerWarp = b.i32_val(Cp.getInDimSize(kLane)); + Value laneId = b.urem(threadId, threadsPerWarp); // Emit one shuffle per destination register. for (int i : llvm::seq(shflOuts.size())) { @@ -667,22 +673,22 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp( // At the same time, for each register, P1 returns the source value index // to provide as the shuffle value. auto out = applyLinearLayout(loc, rewriter, P1, - {{kLane, laneId}, {kRegister, i32_val(i)}}); + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); assert(out.size() == 1); Value srcRegIdx = out.front().second; // The size of the input lane dimension is the number of selects to emit. // TODO(jeff): For dtypes smaller than i32, we can use byte permutes and // shuffle multiple values at a time. - Value shflSrc = undef(srcValues.front().getType()); + Value shflSrc = b.undef(srcValues.front().getType()); for (int j : llvm::seq(reducedP1.getInDimSize(kLane))) { int32_t check = reducedP1.apply({{kLane, j}, {kRegister, i}}).front().second; - shflSrc = - select(icmp_eq(srcRegIdx, i32_val(check)), srcValues[check], shflSrc); + shflSrc = b.select(b.icmp_eq(srcRegIdx, b.i32_val(check)), + srcValues[check], shflSrc); } out = applyLinearLayout(loc, rewriter, Cp, - {{kLane, laneId}, {kRegister, i32_val(i)}}); + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); assert(out.size() == 1); Value shflIdx = out.front().second; shflOuts[i] = targetInfo.shuffleIdx(rewriter, loc, shflSrc, shflIdx); @@ -693,16 +699,16 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp( // selects. 
SmallVector results(shflOuts.size()); for (int i : llvm::seq(results.size())) { - Value result = undef(srcValues.front().getType()); + Value result = b.undef(srcValues.front().getType()); auto out = applyLinearLayout(loc, rewriter, P2inv, - {{kLane, laneId}, {kRegister, i32_val(i)}}); + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); Value resultIdx = out.front().second; for (int j : llvm::seq(reducedP2.getInDimSize(kLane))) { int32_t check = reducedP2.apply({{kLane, j}, {kRegister, i}}).front().second; - result = - select(icmp_eq(resultIdx, i32_val(check)), shflOuts[check], result); + result = b.select(b.icmp_eq(resultIdx, b.i32_val(check)), shflOuts[check], + result); } results[i] = result; } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index 1c0322b75e73..6dbbe0a110ef 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -37,15 +37,16 @@ bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; } SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, Location loc, SmallVector rawIndices, SharedEncodingAttr layout) { + auto b = TritonLLVMOpBuilder(loc, rewriter); const auto &order = layout.getOrder(); auto rank = order.size(); if (!isSwizzled(layout)) return rawIndices; - auto vec = i32_val(layout.getVec()); - auto perPhase = i32_val(layout.getPerPhase()); - auto maxPhase = i32_val(layout.getMaxPhase()); + auto vec = b.i32_val(layout.getVec()); + auto perPhase = b.i32_val(layout.getPerPhase()); + auto maxPhase = b.i32_val(layout.getMaxPhase()); auto fastIdx = rawIndices[order[0]]; auto secondIdx = rawIndices[order[1]]; @@ -53,10 +54,10 @@ SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, // swizzledGroup = ((fastIdx // vec) ^ phase) * vec // groupRemainder = fastIdx % vec // colOff = swizzledGroup + groupRemainder - auto phase = urem(udiv(secondIdx, perPhase), maxPhase); - auto swizzledGroup = mul(xor_(udiv(fastIdx, vec), phase), vec); - auto groupRemainder = urem(fastIdx, vec); - auto colOff = add(swizzledGroup, groupRemainder); + auto phase = b.urem(b.udiv(secondIdx, perPhase), maxPhase); + auto swizzledGroup = b.mul(b.xor_(b.udiv(fastIdx, vec), phase), vec); + auto groupRemainder = b.urem(fastIdx, vec); + auto colOff = b.add(swizzledGroup, groupRemainder); SmallVector swizzledIndices = rawIndices; swizzledIndices[order[0]] = colOff; @@ -80,6 +81,7 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc, unsigned kIdx, unsigned nonKIdx, unsigned bIdx, const DimIdx &dim, int vecDim, ArrayRef opOrder) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto vecTy = cast(vec.getType()); auto vectorSize = vecTy.getNumElements(); auto elemTy = vecTy.getElementType(); @@ -91,7 +93,7 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc, spatialIdx[vecDim] += elem; unsigned linearIdx = linearize(spatialIdx, perThreadTileShape, opOrder); - opValues[linearIdx] = extract_element(elemTy, vec, i32_val(elem)); + opValues[linearIdx] = b.extract_element(elemTy, vec, b.i32_val(elem)); } } @@ -116,9 +118,10 @@ Value getUnswizzledFirstElemOffset(ConversionPatternRewriter &rewriter, Location loc, unsigned B, unsigned NonK, Value bTileOffset, Value nonKTileOffset, Value bStride, Value nonKStride) { - auto bOffset = mul(urem(bTileOffset, i32_val(B)), bStride); - auto nonKOffset = 
mul(urem(nonKTileOffset, i32_val(NonK)), nonKStride); - Value threadIdDependantOffset = add(bOffset, nonKOffset); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto bOffset = b.mul(b.urem(bTileOffset, b.i32_val(B)), bStride); + auto nonKOffset = b.mul(b.urem(nonKTileOffset, b.i32_val(NonK)), nonKStride); + Value threadIdDependantOffset = b.add(bOffset, nonKOffset); return threadIdDependantOffset; } @@ -154,14 +157,15 @@ Value computeSwizzledOffset(ConversionPatternRewriter &rewriter, Location loc, SharedEncodingAttr sharedLayout, ArrayRef opTensorShape, ArrayRef strides) { - Value offset = i32_val(0); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value offset = b.i32_val(0); // Compute unswizzled multi dim coordinates in shared memory object SmallVector elemMultiDimIndices(3); elemMultiDimIndices[dim.batch] = - add(bTileOffset, i32_val(i.bTile * shapePerCTABTile + i.b)); - elemMultiDimIndices[dim.nonK] = - add(nonKTileOffset, i32_val(i.nonKTile * shapePerCTANonKTile + i.nonK)); - elemMultiDimIndices[dim.k] = i32_val(i.k); + b.add(bTileOffset, b.i32_val(i.bTile * shapePerCTABTile + i.b)); + elemMultiDimIndices[dim.nonK] = b.add( + nonKTileOffset, b.i32_val(i.nonKTile * shapePerCTANonKTile + i.nonK)); + elemMultiDimIndices[dim.k] = b.i32_val(i.k); // Apply swizzling pattern to fastest dimension SmallVector swizzledIndices = @@ -170,9 +174,10 @@ Value computeSwizzledOffset(ConversionPatternRewriter &rewriter, Location loc, // Linearize shared mem object dimensions into flat offset for (int d = 0; d < 3; ++d) { // wrap index if it is larger than tensor - auto wrappedDimIndex = urem(swizzledIndices[d], i32_val(opTensorShape[d])); - auto dimOffset = mul(wrappedDimIndex, strides[d]); - offset = add(offset, dimOffset); + auto wrappedDimIndex = + b.urem(swizzledIndices[d], b.i32_val(opTensorShape[d])); + auto dimOffset = b.mul(wrappedDimIndex, strides[d]); + offset = b.add(offset, dimOffset); } return offset; } @@ -185,16 +190,17 @@ Value computeNonSwizzledOffset(ConversionPatternRewriter &rewriter, unsigned shapePerCTABTile, unsigned shapePerCTANonKTile, ArrayRef strides) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector offsetIndices(3); offsetIndices[dim.batch] = - i32_val((i.bTile * shapePerCTABTile + i.b) % tensorShape[dim.batch]); - offsetIndices[dim.nonK] = i32_val( + b.i32_val((i.bTile * shapePerCTABTile + i.b) % tensorShape[dim.batch]); + offsetIndices[dim.nonK] = b.i32_val( (i.nonKTile * shapePerCTANonKTile + i.nonK) % tensorShape[dim.nonK]); - offsetIndices[dim.k] = i32_val(i.k); + offsetIndices[dim.k] = b.i32_val(i.k); - Value offset = i32_val(0); + Value offset = b.i32_val(0); for (int d = 0; d < 3; ++d) - offset = add(offset, mul(offsetIndices[d], strides[d])); + offset = b.add(offset, b.mul(offsetIndices[d], strides[d])); return offset; } @@ -213,6 +219,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, Value thread, Location loc, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, const int dotOpNo) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); if (!verifyCTALayout(dLayout.getCTALayout())) return Value(); @@ -247,9 +254,9 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, auto warpsPerCTA = expandMatrixShapeWithBatch(ArrayRef(dLayout.getWarpsPerCTA())); - auto warpSize = i32_val(triton::gpu::getWarpSize(dLayout)); - auto laneId = urem(thread, warpSize); - auto warpId = udiv(thread, warpSize); + auto warpSize = tb.i32_val(triton::gpu::getWarpSize(dLayout)); + auto laneId = tb.urem(thread, 
warpSize); + auto warpId = tb.udiv(thread, warpSize); auto laneIds = mlir::LLVM::delinearize(rewriter, loc, laneId, threadsPerWarp, opOrder); auto warpIds = @@ -258,13 +265,13 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, auto sizePerWarpNonK = sizePerThread[dim.nonK] * threadsPerWarp[dim.nonK]; Value bTileOffset = - mul(laneIds[dim.batch], i32_val(sizePerThread[dim.batch])); + tb.mul(laneIds[dim.batch], tb.i32_val(sizePerThread[dim.batch])); bTileOffset = - add(bTileOffset, mul(warpIds[dim.batch], i32_val(sizePerWarpB))); + tb.add(bTileOffset, tb.mul(warpIds[dim.batch], tb.i32_val(sizePerWarpB))); Value nonKTileOffset = - mul(laneIds[dim.nonK], i32_val(sizePerThread[dim.nonK])); - nonKTileOffset = - add(nonKTileOffset, mul(warpIds[dim.nonK], i32_val(sizePerWarpNonK))); + tb.mul(laneIds[dim.nonK], tb.i32_val(sizePerThread[dim.nonK])); + nonKTileOffset = tb.add( + nonKTileOffset, tb.mul(warpIds[dim.nonK], tb.i32_val(sizePerWarpNonK))); auto elemTy = typeConverter->convertType(opTensorTy.getElementType()); Type ptrTy = smem.getBase().getType(); @@ -320,7 +327,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, auto laneOffset = getUnswizzledFirstElemOffset( rewriter, loc, B, NonK, bTileOffset, nonKTileOffset, smemStrides[dim.batch], smemStrides[dim.nonK]); - basePtr = gep(ptrTy, elemTy, smem.getBase(), laneOffset); + basePtr = tb.gep(ptrTy, elemTy, smem.getBase(), laneOffset); } // This loop nest iterates over all values loaded in one thread across batch, @@ -335,7 +342,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) for (unsigned nonK = 0; nonK < sizeNonKPerThread; nonK += dimStep[dim.nonK]) { - Value offset = i32_val(0); + Value offset = tb.i32_val(0); Indexes idx = {bTile, b, k, nonKTile, nonK}; // swizzled variant is more general, but it limits optimization of @@ -351,8 +358,8 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, shapePerCTANonKTile, smemStrides); } - Value elemAddr = gep(ptrTy, elemTy, basePtr, offset); - Value vec = load(vecTy, elemAddr); + Value elemAddr = tb.gep(ptrTy, elemTy, basePtr, offset); + Value vec = tb.load(vecTy, elemAddr); storeValuesInLinearVector( rewriter, loc, opValues, vec, perThreadShape, /*kIdx*/ k, /*nonKIdx*/ nonKTile * sizeNonKPerThread + nonK, diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 335912c778fc..349ccfdd41be 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -40,6 +40,7 @@ struct AddPtrOpConversion : public ConvertOpToLLVMPattern { matchAndRewrite(AddPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto resultTy = op.getType(); auto typeConverter = getTypeConverter(); auto resultTensorTy = dyn_cast(resultTy); @@ -52,7 +53,7 @@ struct AddPtrOpConversion : public ConvertOpToLLVMPattern { auto offsets = unpackLLElements(loc, adaptor.getOffset(), rewriter); SmallVector resultVals(elems); for (unsigned i = 0; i < elems; ++i) { - resultVals[i] = gep(ptrTy, elemTy, ptrs[i], offsets[i]); + resultVals[i] = b.gep(ptrTy, elemTy, ptrs[i], offsets[i]); } Value view = packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); @@ -62,8 +63,8 @@ struct AddPtrOpConversion : public ConvertOpToLLVMPattern { auto 
resultPtrTy = typeConverter->convertType(resultTy); auto resultElemTy = typeConverter->convertType( cast(resultTy).getPointeeType()); - Value result = - gep(resultPtrTy, resultElemTy, adaptor.getPtr(), adaptor.getOffset()); + Value result = b.gep(resultPtrTy, resultElemTy, adaptor.getPtr(), + adaptor.getOffset()); rewriter.replaceOp(op, result); } return success(); @@ -247,6 +248,7 @@ struct ElementwiseInlineAsmOpConversion MultipleOperandsRange operands, ConversionPatternRewriter &rewriter, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector packedOperands; unsigned numPackedElements = op.getPackedElement(); for (int i = 0, e = op.getNumOperands(); i < e; i++) { @@ -262,9 +264,9 @@ struct ElementwiseInlineAsmOpConversion } Type t = vec_ty(getTypeConverter()->convertType(elemTy), numElementPerReg); - Value packed = undef(t); + Value packed = b.undef(t); for (int k = 0; k < numElementPerReg; k++) { - packed = insert_element(packed, operands[j + k][i], i32_val(k)); + packed = b.insert_element(packed, operands[j + k][i], b.i32_val(k)); } packedOperands.push_back(packed); } @@ -277,6 +279,7 @@ struct ElementwiseInlineAsmOpConversion ConversionPatternRewriter &rewriter, MultipleOperandsRange operands, Location loc) const { auto ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); if (operands.size() % op.getPackedElement() != 0) llvm::report_fatal_error("Inline asm op has more packed elements than " @@ -330,13 +333,13 @@ struct ElementwiseInlineAsmOpConversion for (int j = 0; j < op.getPackedElement(); j++) { Value val; if (asmRetTypes.size() > 1) { - val = extract_val(asmResults, structIdx++); + val = b.extract_val(asmResults, structIdx++); } else { val = asmResults; } if (auto vectorTy = dyn_cast(val.getType())) { for (int k = 0; k < vectorTy.getNumElements(); k++) { - ret[i].push_back(extract_element(val, i32_val(k))); + ret[i].push_back(b.extract_element(val, b.i32_val(k))); } j += vectorTy.getNumElements() - 1; } else { @@ -351,6 +354,7 @@ struct ElementwiseInlineAsmOpConversion matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); // Layout is unpackedOperands[operand][elem]. 
SmallVector> unpackedOperands; @@ -375,7 +379,7 @@ struct ElementwiseInlineAsmOpConversion op.getPackedElement() - numElemsPerThread % op.getPackedElement(); for (auto &operands : unpackedOperands) { for (int i = 0; i < numPaddedValue; i++) { - operands.push_back(undef(operands[0].getType())); + operands.push_back(b.undef(operands[0].getType())); } } } @@ -444,6 +448,7 @@ struct AbsFOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); if (llvm::isa(elemTy)) { // Mask out the sign bit auto num_bits = @@ -452,7 +457,7 @@ struct AbsFOpConversion auto mask = (1u << (num_bits - 1u)) - 1u; auto maskAttr = rewriter.getIntegerAttr(elemTy, mask); auto maskConst = rewriter.create(loc, maskAttr); - return {and_(operands[0][0], maskConst)}; + return {b.and_(operands[0][0], maskConst)}; } return {rewriter.create(loc, elemTy, operands[0][0])}; diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 15f38930795e..9507c6478ea9 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -50,14 +50,15 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, static Value convertIndexToI32(Location loc, Value index, ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned idxWidth = index.getType().getIntOrFloatBitWidth(); // The LL index computations are performed with 32 bit integers. If the // indices are something else, cast them to i32. if (idxWidth > 32) { - index = trunc(i32_ty, index); + index = b.trunc(i32_ty, index); } else if (idxWidth < 32) { // Negative indices don't make sense, so zero-extend. - index = zext(i32_ty, index); + index = b.zext(i32_ty, index); } return index; } @@ -65,6 +66,7 @@ static Value convertIndexToI32(Location loc, Value index, void GatherOpConversion::emitGatherInShared( GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); RankedTensorType srcType = op.getSrc().getType(); // Compute the src subtensor shape owned by this CTA. @@ -95,12 +97,12 @@ void GatherOpConversion::emitGatherInShared( // tensor. Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); // Emit the offset into the shared memory and then store the value. - Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); - store(value, ptr); + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + b.store(value, ptr); } // Synchronize the whole CTA. - barrier(); + b.barrier(); // Grab the index values owned by this thread. 
SmallVector idxValues = @@ -124,8 +126,8 @@ void GatherOpConversion::emitGatherInShared( for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { indices[axis] = convertIndexToI32(loc, idx, rewriter); Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); - Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); - results[i] = load(elemType, ptr); + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + results[i] = b.load(elemType, ptr); } Value packed = @@ -188,6 +190,7 @@ void GatherOpConversion::emitWarpLocalGather( GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); RankedTensorType srcType = op.getSrc().getType(); RankedTensorType idxType = op.getIndices().getType(); @@ -299,7 +302,7 @@ void GatherOpConversion::emitWarpLocalGather( {{kBlock, blockId}, {kWarp, warpId}, {kLane, laneId}, - {kRegister, i32_val(idxReg)}}); + {kRegister, b.i32_val(idxReg)}}); assert(column.size() == otherDims.size()); // Combine the computed column with the data-dependent gather index. @@ -320,7 +323,7 @@ void GatherOpConversion::emitWarpLocalGather( int32_t srcBase = invertSrcRegMapColPart.apply(normalizedColumn).front().second; - Value result = undef(srcValues.front().getType()); + Value result = b.undef(srcValues.front().getType()); for (unsigned i = 0; i != numRegsPerColumn; ++i) { int32_t rest = invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second; @@ -328,7 +331,7 @@ void GatherOpConversion::emitWarpLocalGather( Value value = targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane); - result = select(icmp_eq(i32_val(srcRegIdx), srcReg), value, result); + result = b.select(b.icmp_eq(b.i32_val(srcRegIdx), srcReg), value, result); } results.push_back(result); diff --git a/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp index ed4837fc190d..ac350a149d8c 100644 --- a/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -18,9 +18,10 @@ static SmallVector computeWarpLevelHistogram( Location loc, RankedTensorType srcType, SmallVector &srcValues, int numBins, int numThreadPerWarp, Value threadId, ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(numBins % numThreadPerWarp == 0 && "numBins must be divisible by numThreadPerWarp"); - Value zero = i32_val(0); + Value zero = b.i32_val(0); int numBits = log2Int(numBins); int numBitsLaneId = log2Int(numThreadPerWarp); unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread(srcType); @@ -34,25 +35,26 @@ static SmallVector computeWarpLevelHistogram( Value value = srcValues[i]; SmallVector ballotBits; for (int j = 0; j < numBits; ++j) { - Value bitSet = and_(value, i32_val(1 << j)); - Value cmp = icmp_ne(bitSet, zero); + Value bitSet = b.and_(value, b.i32_val(1 << j)); + Value cmp = b.icmp_ne(bitSet, zero); Value bit = targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp); ballotBits.push_back(bit); } uint64_t fullMaskValue = numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF; - Value fullMask = int_val(numThreadPerWarp, fullMaskValue); + Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue); Value mask = fullMask; // If not all threads have unique data, mask out the redundant ones. 
if (numThreadWithUniqueData < numThreadPerWarp) { - mask = int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1); + mask = b.int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1); } for (int i = 0; i < numBitsLaneId; i++) { - Value updateMask = select(icmp_ne(and_(threadId, i32_val(1 << i)), zero), - int_val(numThreadPerWarp, 0), fullMask); - mask = - and_(mask, xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); + Value updateMask = + b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero), + b.int_val(numThreadPerWarp, 0), fullMask); + mask = b.and_( + mask, b.xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); } // at this point, 'mask' tells you which elements are in a bin owned by this // thread. @@ -60,16 +62,16 @@ static SmallVector computeWarpLevelHistogram( Value binMask = mask; for (int j = 0; j < numBits - numBitsLaneId; j++) { Value updateMask = - int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); - binMask = and_(binMask, xor_(ballotBits[j], updateMask)); + b.int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); + binMask = b.and_(binMask, b.xor_(ballotBits[j], updateMask)); } // at this point, 'bin_mask' tells you which elements are in the kth bin // owned by this thread. Value bitCount = rewriter.create( loc, int_ty(numThreadPerWarp), binMask); if (numThreadPerWarp > 32) - bitCount = trunc(i32_ty, bitCount); - warpLevelHistogram[k] = add(warpLevelHistogram[k], bitCount); + bitCount = b.trunc(i32_ty, bitCount); + warpLevelHistogram[k] = b.add(warpLevelHistogram[k], bitCount); } } return warpLevelHistogram; @@ -86,22 +88,24 @@ static SmallVector computeCrossWarpHistogram( Value baseSharedMemPtr, const SmallVector &warpLevelHistogram, int numBins, int numThreadPerWarp, const SmallVector &indices, Value threadId, int numWarps) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector histogramValues; unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(), srcType.getShape())[0]; - Value laneId = and_(threadId, i32_val(numThreadPerWarp - 1)); + Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1)); // Initialize the shared memory with zeros. int64_t numElementPerThread = ceil(numBins, numThreadPerWarp * numWarps); for (int i = 0; i < numElementPerThread; ++i) { - Value offset = add(threadId, i32_val((i * numWarps * numThreadPerWarp))); - offset = urem(offset, i32_val(numBins)); + Value offset = + b.add(threadId, b.i32_val((i * numWarps * numThreadPerWarp))); + offset = b.urem(offset, b.i32_val(numBins)); Value sharedMemPtr = - gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); - store(i32_val(0), sharedMemPtr); + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + b.store(b.i32_val(0), sharedMemPtr); } - barrier(); + b.barrier(); Block *afterAtomics = nullptr; // If some warps have replicated data we need to skip those warps when // accumulating. 
@@ -111,30 +115,30 @@ static SmallVector computeCrossWarpHistogram( rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); Block *atomicBlock = rewriter.createBlock(afterAtomics); rewriter.setInsertionPointToEnd(currentBlock); - Value cond = - icmp_ult(threadId, i32_val(numWarpsWithUniqueData * numThreadPerWarp)); + Value cond = b.icmp_ult( + threadId, b.i32_val(numWarpsWithUniqueData * numThreadPerWarp)); rewriter.create(loc, cond, atomicBlock, afterAtomics); rewriter.setInsertionPointToStart(atomicBlock); } // Apply atomic add to update the histogram in shared memory. for (int i = 0; i < warpLevelHistogram.size(); ++i) { Value warpLevelHistogramValue = warpLevelHistogram[i]; - Value offset = - add(mul(laneId, i32_val(warpLevelHistogram.size())), i32_val(i)); + Value offset = b.add(b.mul(laneId, b.i32_val(warpLevelHistogram.size())), + b.i32_val(i)); Value sharedMemPtr = - gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter); } if (afterAtomics) { rewriter.create(loc, afterAtomics); rewriter.setInsertionPointToStart(afterAtomics); } - barrier(); + b.barrier(); // load the histogram to register with the right layout. for (Value index : indices) { Value sharedMemPtr = - gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); - Value val = load(i32_ty, sharedMemPtr); + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); + Value val = b.load(i32_ty, sharedMemPtr); histogramValues.push_back(val); } return histogramValues; diff --git a/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp index 43120c7913a5..8060b443121b 100644 --- a/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -18,6 +18,7 @@ struct MakeRangeOpConversion matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); RankedTensorType ty = op.getType(); auto shape = ty.getShape(); auto layout = ty.getEncoding(); @@ -32,7 +33,7 @@ struct MakeRangeOpConversion // expand dims + broadcast. very weird behavior otherwise potentially. 
for (const auto &multiDim : llvm::enumerate(idxs)) { assert(multiDim.value().size() == 1); - retVals[multiDim.index()] = add(multiDim.value()[0], start); + retVals[multiDim.index()] = b.add(multiDim.value()[0], start); } auto typeConverter = getTypeConverter(); Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 03f6e0bb802e..4f45cfaccac8 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -39,6 +39,7 @@ struct GlobalScratchAllocOpConversion matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto opOffsetAttr = op->getAttrOfType( "ttg.global_scratch_memory_offset"); @@ -50,7 +51,7 @@ struct GlobalScratchAllocOpConversion return failure(); } Value ptr = - LLVM::getGlobalScratchPtr(loc, rewriter, funcOp, i32_val(opOffset)); + LLVM::getGlobalScratchPtr(loc, rewriter, funcOp, b.i32_val(opOffset)); rewriter.replaceOp(op, ptr); return success(); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 3c79e49e6f3a..f63b86b99b55 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -111,7 +111,8 @@ struct ReduceOpConversion void sync(ConversionPatternRewriter &rewriter, Location loc, triton::ReduceOp op) const { - barrier(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + b.barrier(); } // Reduce along op axis for elements that are in the same thread. The @@ -224,14 +225,15 @@ struct ReduceOpConversion ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Value threadId = getThreadId(rewriter, loc); auto srcLayout = mlir::cast(helper.getSrcLayout()); auto mod = op.getOperation()->getParentOfType(); Value warpSize = - i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); - Value warpId = udiv(threadId, warpSize); - Value laneId = urem(threadId, warpSize); + b.i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + Value warpId = b.udiv(threadId, warpSize); + Value laneId = b.urem(threadId, warpSize); unsigned axis = op.getAxis(); auto smemShape = helper.getScratchRepShape(); @@ -245,9 +247,9 @@ struct ReduceOpConversion delinearize(rewriter, loc, srcLayout, srcShape, kWarp, warpId); Value laneIdAxis = multiDimLaneId[axis]; - Value laneZero = icmp_eq(laneIdAxis, i32_val(0)); + Value laneZero = b.icmp_eq(laneIdAxis, b.i32_val(0)); Value write = - and_(and_(isRepresentativeLane, isRepresentativeWarp), laneZero); + b.and_(b.and_(isRepresentativeLane, isRepresentativeWarp), laneZero); Value warpIdAxis = multiDimWarpId[axis]; @@ -263,7 +265,7 @@ struct ReduceOpConversion for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); Value writePtr = - gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); targetInfo.storeShared(rewriter, loc, writePtr, acc[i], write); } } @@ -279,6 +281,7 @@ struct ReduceOpConversion unsigned elems = product(smemShape); unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, 
rewriter); auto mod = op.getOperation()->getParentOfType(); int numLanes = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); @@ -286,19 +289,19 @@ struct ReduceOpConversion int numThreads = numLanes * numWarps; Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(numLanes); - Value laneId = urem(threadId, warpSize); - Value zero = i32_val(0); + Value warpSize = b.i32_val(numLanes); + Value laneId = b.urem(threadId, warpSize); + Value zero = b.i32_val(0); unsigned elemsPerThread = std::max(elems / numThreads, 1); - Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value threadIsNeeded = b.icmp_slt(threadId, b.i32_val(elems)); Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { SmallVector acc(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); Value readPtr = - gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, threadIsNeeded); } @@ -310,20 +313,20 @@ struct ReduceOpConversion for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); writePtrs[i] = - gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); } - Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarps = b.urem(laneId, b.i32_val(sizeInterWarps)); Value laneIdModSizeInterWarpsIsZero = - icmp_eq(laneIdModSizeInterWarps, zero); - Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + b.icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = b.and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); for (unsigned i = 0; i < op.getNumOperands(); ++i) { targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred); } if (round != elemsPerThread - 1) { - readOffset = add(readOffset, i32_val(numThreads)); + readOffset = b.add(readOffset, b.i32_val(numThreads)); } } } @@ -336,6 +339,7 @@ struct ReduceOpConversion ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcLayout = helper.getSrcLayout(); auto axis = op.getAxis(); auto smemOrder = helper.getOrderWithAxisAtBeginning(); @@ -356,7 +360,7 @@ struct ReduceOpConversion SmallVector resultVals(resultElems); for (size_t j = 0; j < resultElems; ++j) { SmallVector readIdx = resultIndices[j]; - readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); + readIdx.insert(readIdx.begin() + op.getAxis(), b.i32_val(0)); for (size_t resultIdx = 0, resultDim = resultShape.size(); resultIdx < resultDim; ++resultIdx) { auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1; @@ -366,21 +370,21 @@ struct ReduceOpConversion // elements is accumulated in smem. Modulo smemShape effectively // replicates srcShape elements to src sizePerThread. 
readIdx[smemIdx] = - urem(readIdx[smemIdx], i32_val(smemShape[smemIdx])); + b.urem(readIdx[smemIdx], b.i32_val(smemShape[smemIdx])); } } Value readOffset = linearize(rewriter, loc, readIdx, smemShape, smemOrder); Value readPtr = - gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); - resultVals[j] = load(elemTy, readPtr); + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + resultVals[j] = b.load(elemTy, readPtr); } results[i] = packLLElements(loc, getTypeConverter(), resultVals, rewriter, resultTy); } else { // 0d-tensor -> scalar - results[i] = load(elemTy, smemBases[i]); + results[i] = b.load(elemTy, smemBases[i]); } } rewriter.replaceOp(op, results); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h index 959d35968b83..e3012d29d083 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -131,6 +131,7 @@ class ConvertTritonGPUReduceScanToLLVMPattern ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) const { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); // indices will store the index of the op operands in descending order // of their bitwidths std::vector indices(op.getNumOperands()); @@ -147,8 +148,8 @@ class ConvertTritonGPUReduceScanToLLVMPattern indexToBase[indices[0]] = basePtr; for (unsigned i = 1; i < op.getNumOperands(); ++i) { indexToBase[indices[i]] = - gep(basePtr.getType(), getElementType(op, indices[i - 1]), - indexToBase[indices[i - 1]], i32_val(elems)); + b.gep(basePtr.getType(), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], b.i32_val(elems)); } // smemBases[k] is the base pointer for the k-th operand SmallVector smemBases(op.getNumOperands()); diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index fdf17c36b0a2..27cf5e8a4205 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -53,6 +53,7 @@ static void warpScan(SmallVector> &srcValues, const TargetInfoBase &targetInfo, ScanLoweringHelper &helper, Value laneIdAxis) { Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); @@ -69,11 +70,11 @@ static void warpScan(SmallVector> &srcValues, for (unsigned j = 0; j < acc.size(); ++j) { shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); } - Value mask = icmp_sge(laneIdAxis, i32_val(i)); + Value mask = b.icmp_sge(laneIdAxis, b.i32_val(i)); SmallVector tempAcc = accumulate(helper, rewriter, shfl, acc, mask); for (unsigned j = 0; j < acc.size(); ++j) { - acc[j] = select(mask, tempAcc[j], acc[j]); + acc[j] = b.select(mask, tempAcc[j], acc[j]); } } srcValues[srcIndex] = std::move(acc); @@ -94,6 +95,7 @@ static void storeWarpAccumulator(SmallVector> &srcValues, Value parallelLaneId, const TargetInfoBase &targetInfo) { Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); @@ -107,12 +109,13 @@ static void storeWarpAccumulator(SmallVector> &srcValues, if (elementIdx != scanElementsPerThreads - 
1) continue; auto lastElement = srcValues[srcIndex]; - Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); - Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); - index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); + Value mask = b.icmp_eq(laneId, b.i32_val(scanDim - 1)); + Value index = + b.add(parallelLaneId, b.mul(warpId, b.i32_val(numParallelLane))); + index = b.add(index, b.i32_val(chunkId * numParallelLane * axisNumWarps)); for (unsigned i = 0; i < lastElement.size(); ++i) { Value writePtr = - gep(smemBases[i].getType(), smemTypes[i], smemBases[i], index); + b.gep(smemBases[i].getType(), smemTypes[i], smemBases[i], index); targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); } chunkId++; @@ -132,15 +135,16 @@ static void AddPartialReduce(SmallVector> &srcValues, ArrayRef smemTypes, Value warpId, Value laneIdAxis, Value parallelLaneId) { Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); - Value maskNotFirstWarp = icmp_ne(warpId, i32_val(0)); - Value maskNotFirstLane = icmp_ne(laneIdAxis, i32_val(0)); - Value maskNotFirstThread = or_(maskNotFirstWarp, maskNotFirstLane); + Value maskNotFirstWarp = b.icmp_ne(warpId, b.i32_val(0)); + Value maskNotFirstLane = b.icmp_ne(laneIdAxis, b.i32_val(0)); + Value maskNotFirstThread = b.or_(maskNotFirstWarp, maskNotFirstLane); struct Accumulator { SmallVector acc; SmallVector maskedAcc; @@ -171,13 +175,14 @@ static void AddPartialReduce(SmallVector> &srcValues, Accumulator &accumulator = accumulators[accumulatorIndex]; unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; for (unsigned i = 0; i < axisNumWarps; ++i) { - Value index = add(parallelLaneId, i32_val(numParallelLane * - (i + chunkId * axisNumWarps))); + Value index = + b.add(parallelLaneId, + b.i32_val(numParallelLane * (i + chunkId * axisNumWarps))); SmallVector partialReduce(helper.getNumOperands()); for (unsigned j = 0; j < helper.getNumOperands(); ++j) { auto elemTy = smemTypes[j]; - Value ptr = gep(smemBases[j].getType(), elemTy, smemBases[j], index); - partialReduce[j] = load(elemTy, ptr); + Value ptr = b.gep(smemBases[j].getType(), elemTy, smemBases[j], index); + partialReduce[j] = b.load(elemTy, ptr); } if (accumulator.acc.size() == 0) { @@ -185,12 +190,12 @@ static void AddPartialReduce(SmallVector> &srcValues, accumulator.maskedAcc = partialReduce; continue; } - Value mask = icmp_sge(warpId, i32_val(i + 1)); + Value mask = b.icmp_sge(warpId, b.i32_val(i + 1)); accumulator.acc = accumulate(helper, rewriter, accumulator.acc, partialReduce); for (unsigned j = 0; j < helper.getNumOperands(); ++j) { accumulator.maskedAcc[j] = - select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); + b.select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); } } @@ -202,7 +207,7 @@ static void AddPartialReduce(SmallVector> &srcValues, // accumulate. 
auto val = srcValues[srcIndex]; for (unsigned i = 0; i < helper.getNumOperands(); ++i) { - temp[i] = select(maskNotFirstWarp, temp[i], val[i]); + temp[i] = b.select(maskNotFirstWarp, temp[i], val[i]); } } srcValues[srcIndex] = temp; @@ -210,7 +215,8 @@ static void AddPartialReduce(SmallVector> &srcValues, SmallVector lastElement(helper.getNumOperands()); for (unsigned i = 0; i < helper.getNumOperands(); ++i) { auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); - lastElement[i] = select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); + lastElement[i] = + b.select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { pred = axisBlockId == 0 ? maskNotFirstThread : Value{}; @@ -220,8 +226,8 @@ static void AddPartialReduce(SmallVector> &srcValues, // For the first warp and first chunk we don't have anything to // accumulate. for (unsigned j = 0; j < helper.getNumOperands(); ++j) { - laneValue[j] = select(maskNotFirstThread, laneValue[j], - srcValues[srcIndex - i * elementStride][j]); + laneValue[j] = b.select(maskNotFirstThread, laneValue[j], + srcValues[srcIndex - i * elementStride][j]); } } srcValues[srcIndex - i * elementStride] = std::move(laneValue); @@ -239,6 +245,7 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, ScanLoweringHelper &helper, Value warpId, Value laneIdAxis, Value laneIdLast) { Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); @@ -246,9 +253,9 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); - Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); - Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); - Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + Value maskFirstWarp = b.icmp_eq(warpId, b.i32_val(0)); + Value maskFirstLane = b.icmp_eq(laneIdAxis, b.i32_val(0)); + Value maskFirstThread = b.and_(maskFirstWarp, maskFirstLane); unsigned numScanBlocks = helper.getAxisNumBlocks(); unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * @@ -282,7 +289,8 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, for (unsigned i = 0; i < helper.getNumOperands(); ++i) { lastElement[i] = targetInfo.shuffleUp( rewriter, loc, srcValues[srcIndex][i], threadStride); - lastElement[i] = select(maskFirstLane, accumulator[i], lastElement[i]); + lastElement[i] = + b.select(maskFirstLane, accumulator[i], lastElement[i]); if (numScanBlocks > 1) // Update accumulator with the value from the last lane. accumulator[i] = targetInfo.shuffleIdx( @@ -298,9 +306,9 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, for (unsigned j = 0; j < helper.getNumOperands(); ++j) { // For the first warp and first chunk we don't have anything to // accumulate. 
- laneValue[j] = - select(maskFirstThread, - srcValues[srcIndex - i * elementStride][j], laneValue[j]); + laneValue[j] = b.select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], + laneValue[j]); } } srcValues[srcIndex - i * elementStride] = std::move(laneValue); @@ -384,6 +392,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, ScanLoweringHelper &helper, Value laneId, Value warpId) const { auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned axis = helper.getAxis(); auto srcEncoding = helper.getEncoding(); @@ -399,17 +408,17 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, Value laneIdAxis = multiDimLaneId[axis]; Value warpIdAxis = multiDimWarpId[axis]; - multiDimLaneId[axis] = i32_val(0); + multiDimLaneId[axis] = b.i32_val(0); threadsPerWarp[axis] = 1; Value laneIdParallel = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder); - multiDimWarpId[axis] = i32_val(0); + multiDimWarpId[axis] = b.i32_val(0); warpsPerCTA[axis] = 1; Value warpIdParallel = linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, warpOrder); - Value flatIdParallel = - add(laneIdParallel, - mul(warpIdParallel, i32_val(helper.getNonAxisNumThreadsPerWarp()))); + Value flatIdParallel = b.add( + laneIdParallel, + b.mul(warpIdParallel, b.i32_val(helper.getNonAxisNumThreadsPerWarp()))); return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel); } @@ -460,15 +469,16 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, const TargetInfoBase &targetInfo) const { ScanLoweringHelper helper(op); auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); if (!helper.isSupported()) return op.emitError("TODO: unsupported scan layout"); Value threadId = getThreadId(rewriter, loc); auto mod = op->getParentOfType(); unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - Value warpSize = i32_val(iWarpSize); - Value warpId = udiv(threadId, warpSize); - Value laneId = urem(threadId, warpSize); + Value warpSize = b.i32_val(iWarpSize); + Value warpId = b.udiv(threadId, warpSize); + Value laneId = b.urem(threadId, warpSize); // Clamp the lane ID to just threads with unique data within a warp. LinearLayout layout = @@ -476,12 +486,12 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, StringAttr kLane = rewriter.getStringAttr("lane"); int32_t laneMask = layout.getFreeVariableMasks()[kLane]; laneMask = (layout.getInDimSize(kLane) - 1) & ~laneMask; - laneId = and_(laneId, i32_val(laneMask)); + laneId = b.and_(laneId, b.i32_val(laneMask)); auto [laneIdAxis, warpIdAxis, flatIdParallel] = getDelinearizedIds(rewriter, helper, laneId, warpId); auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); - warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); + warpIdAxis = b.urem(warpIdAxis, b.i32_val(axisNumWarps)); auto srcValues = unpackInputs(loc, op, adaptor, rewriter, *getTypeConverter()); @@ -493,7 +503,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // having to add a lot of the complex cross warp code (if rev switch // first/last etc). Reverse first seems more maintainable.) 
if (op.getReverse()) { - warpIdAxis = sub(i32_val(axisNumWarps - 1), warpIdAxis); + warpIdAxis = b.sub(b.i32_val(axisNumWarps - 1), warpIdAxis); srcValues = flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); } @@ -518,7 +528,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // Store the partial reducing for each warp into shared memory. storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, smemBases, smemTypes, flatIdParallel, targetInfo); - barrier(); + b.barrier(); // Read back the partial reduction of each warp and accumulate them based on // warpId. Then update each chunk of contiguous elements by adding the // accumulated value from the previous lane. @@ -529,7 +539,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // the axis. unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId); - multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1); + multiDimLaneId[helper.getAxis()] = b.i32_val(scanDim - 1); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding()); auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, triton::gpu::getOrder(helper.getEncoding())); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 4d25656199dd..feba52e0bed3 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -62,6 +62,7 @@ SmallVector> applyLinearLayout(Location loc, RewriterBase &rewriter, const LinearLayout &layout, ArrayRef> indices) { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(layout.getNumInDims() == indices.size()); for (auto [inDimName, idx] : indices) { assert(layout.hasInDim(inDimName) && "Invalid inDimName"); @@ -88,13 +89,13 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, SmallVector constantComponent = llvm::to_vector(llvm::make_second_range(layout.apply(constantIns))); - Value zero = i32_val(0); + Value zero = b.i32_val(0); SmallVector> outIndices; for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) { if (constantComponent[i] == 0) outIndices.push_back({outDimName, zero}); else - outIndices.push_back({outDimName, i32_val(constantComponent[i])}); + outIndices.push_back({outDimName, b.i32_val(constantComponent[i])}); } for (auto [inDimName, idx] : indices) { @@ -104,13 +105,13 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, int nBits = layout.getInDimSizeLog2(inDimName); for (int i = 0; i < nBits; i++) { - Value bit = and_(idx, i32_val(1 << i)); - Value bit_is_zero = icmp_eq(bit, zero); + Value bit = b.and_(idx, b.i32_val(1 << i)); + Value bit_is_zero = b.icmp_eq(bit, zero); for (auto &[outDimName, outIdx] : outIndices) { int32_t basis = layout.getBasis(inDimName, i, outDimName); if (basis == 0) continue; - outIdx = xor_(outIdx, select(bit_is_zero, zero, i32_val(basis))); + outIdx = b.xor_(outIdx, b.select(bit_is_zero, zero, b.i32_val(basis))); } } } @@ -123,18 +124,20 @@ std::tuple emitHardwareTuple(Location loc, const TargetInfoBase &target, bool withCTAOffset, unsigned threadsPerWarpCst) { + auto b = TritonLLVMOpBuilder(loc, rewriter); Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(threadsPerWarpCst); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); + Value threadsPerWarp = b.i32_val(threadsPerWarpCst); + Value laneId = 
b.urem(threadId, threadsPerWarp); + Value warpId = b.udiv(threadId, threadsPerWarp); Value blockId = - withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); return {laneId, warpId, blockId}; } SmallVector> emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Attribute layout, RankedTensorType type, bool withCTAOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto shape = type.getShape(); @@ -161,7 +164,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, // This approach produces code with lower register pressure and // less computations, compared to fused L(r,t,w,b) method. auto idxsBase = applyLinearLayout(loc, rewriter, ll, - {{kRegister, i32_val(0)}, + {{kRegister, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, {kBlock, blockId}}); @@ -173,7 +176,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, auto dimName = idxBase.first; assert(dimName == idxReg.first && "dim names of block+warp+thread and register idx should be equal"); - auto idx = xor_(idxBase.second, i32_val(idxReg.second)); + auto idx = b.xor_(idxBase.second, b.i32_val(idxReg.second)); idxs.emplace_back(dimName, idx); } assert(idxs.size() == rank); @@ -195,6 +198,7 @@ Value getSmemVecAddr(const LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, Value regId, Value laneId, Value warpId, Value blockId, Location loc, RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); StringAttr kBlock = str_attr("block"); StringAttr kRegister = str_attr("register"); @@ -283,16 +287,16 @@ Value getSmemVecAddr(const LinearLayout ®Layout, {kBlock, blockId}})); for (auto i = 0; i < rank; i++) { multiDimTensorOffsets[i].second = - add(multiDimTensorOffsets[i].second, smemOffsets[i]); + b.add(multiDimTensorOffsets[i].second, smemOffsets[i]); } smemOffset = applyLinearLayout(loc, rewriter, invertAllocSharedLayout, multiDimTensorOffsets)[0] .second; Value baseToAllocBaseDist = dot(rewriter, loc, smemOffsets, smemStrides); - smemOffset = sub(smemOffset, baseToAllocBaseDist); + smemOffset = b.sub(smemOffset, baseToAllocBaseDist); } auto ptrTy = smemBase.getType(); - auto vecAddr = gep(ptrTy, elemLlvmTy, smemBase, smemOffset); + auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset); vecAddr.setInbounds(true); return vecAddr; } @@ -305,6 +309,7 @@ bool emitTransferBetweenRegistersAndShared( Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback) { MLIRContext *ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); StringAttr kBlock = str_attr("block"); StringAttr kRegister = str_attr("register"); @@ -366,7 +371,7 @@ bool emitTransferBetweenRegistersAndShared( auto vecTy = vec_ty(elemLlvmTy, vecElems); SmallVector ret; for (int i = 0; i < numElems / vecElems; i++) { - auto regId = i32_val(i * vecElems); + auto regId = b.i32_val(i * vecElems); auto vecAddr = getSmemVecAddr( regLayout, regToSharedLayout, invertAllocSharedLayout, smemObj, sharedTy, elemLlvmTy, regId, laneId, warpId, blockId, loc, rewriter); @@ -395,16 +400,17 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector ret; bool success = 
emitTransferBetweenRegistersAndShared( dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { - auto vecVal = load(vecTy, vecAddr); + auto vecVal = b.load(vecTy, vecAddr); vecVal.setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); for (int v = 0; v < vecTy.getNumElements(); v++) { - ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v))); + ret.push_back(b.extract_element(elemLlvmTy, vecVal, b.i32_val(v))); } }); if (!success) @@ -420,17 +426,18 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy, RewriterBase &rewriter, const TargetInfoBase &target, std::pair *const llvmOpCount) { + auto b = TritonLLVMOpBuilder(loc, rewriter); bool success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); srcVals = srcVals.drop_front(vecTy.getNumElements()); - Value vec = undef(vecTy); + Value vec = b.undef(vecTy); for (int i = 0; i < vals.size(); i++) { - vec = insert_element(vec, vals[i], i32_val(i)); + vec = b.insert_element(vec, vals[i], b.i32_val(i)); } - store(vec, vecAddr) + b.store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); if (llvmOpCount) { @@ -567,6 +574,7 @@ bool isConstantZero(Value v) { Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto elems = smemObj.getElems(); auto types = smemObj.getTypes(); auto structTy = @@ -575,7 +583,7 @@ Value getStructFromSharedMemoryObject(Location loc, Value llvmStruct = rewriter.create(loc, structTy); for (const auto &v : llvm::enumerate(elems)) { assert(v.value() && "can not insert null values"); - llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index()); + llvmStruct = b.insert_val(structTy, llvmStruct, v.value(), v.index()); } return llvmStruct; } @@ -584,12 +592,13 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); ArrayRef types = cast(llvmStruct.getType()).getBody(); SmallVector elems(types.size()); for (unsigned i = 0; i < types.size(); ++i) { Type type = types[i]; - elems[i] = extract_val(type, llvmStruct, i); + elems[i] = b.extract_val(type, llvmStruct, i); } return {/*base=*/elems[0], /*baseElemType=*/elemTy, @@ -598,6 +607,7 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, // Extract the bits of `a` that are set in `mask` Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(a.getType() == i32_ty && "a must be i32"); // Handle width = 32 to avoid doing 1 << 32 if (mask == 0xFFFFFFFF) @@ -607,7 +617,7 @@ Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) { // https://forums.developer.nvidia.com/t/pdep-and-pext-functionality-for-cuda/270973 uint32_t mskConst = mask; uint32_t extcnt = 0; - Value result = i32_val(0); + Value result = b.i32_val(0); while (mskConst) { uint32_t oldmsk = mskConst; uint32_t bitgrplsb = mskConst & (-mskConst); @@ -618,7 +628,8 @@ Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) { uint32_t grplen = __builtin_ctz(~(bitgrp >> lsbpos)); uint32_t shift = lsbpos - extcnt; extcnt += grplen; - result = 
or_(result, lshr(and_(i32_val(bitgrp), a), i32_val(shift))); + result = + b.or_(result, b.lshr(b.and_(b.i32_val(bitgrp), a), b.i32_val(shift))); } return result; } @@ -627,14 +638,16 @@ std::tuple, Value> delinearize(RewriterBase &rewriter, Location loc, triton::gpu::DistributedEncodingTrait layout, ArrayRef shape, StringAttr dimName, Value linear) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto ll = triton::gpu::toLinearLayout(shape, layout); auto linearLayout = triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll); assert(ll.hasInDim(dimName)); int32_t freeVarMask = ll.getFreeVariableMasks()[dimName]; - auto isRepresentative = true_val(); + auto isRepresentative = b.true_val(); if (freeVarMask != 0) { - isRepresentative = icmp_eq(and_(i32_val(freeVarMask), linear), i32_val(0)); + isRepresentative = + b.icmp_eq(b.and_(b.i32_val(freeVarMask), linear), b.i32_val(0)); // We remove the bits of linear that are set to one in freeVarMask int32_t nonFreeVarMask = ~freeVarMask & (ll.getInDimSize(dimName) - 1); linear = pext_i32(rewriter, loc, linear, nonFreeVarMask); @@ -673,13 +686,14 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, SmallVector delinearize(RewriterBase &rewriter, Location loc, unsigned linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned rank = shape.size(); assert(rank > 0); SmallVector multiDim(rank); unsigned remained = linear; for (auto &&en : llvm::enumerate(shape)) { unsigned dimSize = en.value(); - multiDim[en.index()] = i32_val(remained % dimSize); + multiDim[en.index()] = b.i32_val(remained % dimSize); remained = remained / dimSize; } return multiDim; @@ -687,14 +701,15 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, SmallVector delinearize(RewriterBase &rewriter, Location loc, Value linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned rank = shape.size(); assert(rank > 0); SmallVector multiDim(rank); Value remained = linear; for (auto &&en : llvm::enumerate(shape)) { - Value dimSize = i32_val(en.value()); - multiDim[en.index()] = urem(remained, dimSize); - remained = udiv(remained, dimSize); + Value dimSize = b.i32_val(en.value()); + multiDim[en.index()] = b.urem(remained, dimSize); + remained = b.udiv(remained, dimSize); } return multiDim; } @@ -720,14 +735,15 @@ Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto rank = multiDim.size(); - Value linear = i32_val(0); + Value linear = b.i32_val(0); if (rank > 0) { linear = multiDim.back(); for (auto [dim, dimShape] : llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { - Value dimSize = i32_val(dimShape); - linear = add(mul(linear, dimSize), dim); + Value dimSize = b.i32_val(dimShape); + linear = b.add(b.mul(linear, dimSize), dim); } } return linear; @@ -743,6 +759,7 @@ size_t linearize(ArrayRef multiDim, ArrayRef shape, Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, StringRef content) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto ctx = moduleOp.getContext(); unsigned stringNumber = 0; @@ -766,12 +783,12 @@ Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, rewriter.getStringAttr(contentStr)); } - Value zero = i32_val(0); + Value zero = b.i32_val(0); Type globalPtrType = LLVM::LLVMPointerType::get(ctx, 
global.getAddrSpace()); Value globalPtr = rewriter.create( UnknownLoc::get(ctx), globalPtrType, global.getSymName()); Value stringStart = - gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); + b.gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); return stringStart; } @@ -781,6 +798,7 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, ArrayRef shapePerCTATile) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto shape = type.getShape(); unsigned rank = shape.size(); if (auto blockedLayout = dyn_cast(layout)) { @@ -791,9 +809,9 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, elemId, getSizePerThread(layout), getOrder(layout)); for (unsigned d = 0; d < rank; ++d) { multiDimOffset[d] = - add(multiDimOffsetFirstElem[d], - i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] + - multiDimElemId[d])); + b.add(multiDimOffsetFirstElem[d], + b.i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] + + multiDimElemId[d])); } return multiDimOffset; } @@ -839,32 +857,34 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); - Value _1 = i32_val(1); - Value _2 = i32_val(2); - Value _4 = i32_val(4); - Value _8 = i32_val(8); - Value _16 = i32_val(16); + Value _1 = b.i32_val(1); + Value _2 = b.i32_val(2); + Value _4 = b.i32_val(4); + Value _8 = b.i32_val(8); + Value _16 = b.i32_val(16); if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { - multiDimWarpId[rank - 1] = urem( - multiDimWarpId[rank - 1], - i32_val(ceil(shapePerCTA[rank - 1], instrShape[rank - 1]))); - multiDimWarpId[rank - 2] = urem( - multiDimWarpId[rank - 2], - i32_val(ceil(shapePerCTA[rank - 2], instrShape[rank - 2]))); - - Value mmaGrpId = udiv(laneId, _4); - Value mmaGrpIdP8 = add(mmaGrpId, _8); - Value mmaThreadIdInGrp = urem(laneId, _4); - Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2); - Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1); + multiDimWarpId[rank - 1] = + b.urem(multiDimWarpId[rank - 1], + b.i32_val(ceil(shapePerCTA[rank - 1], + instrShape[rank - 1]))); + multiDimWarpId[rank - 2] = + b.urem(multiDimWarpId[rank - 2], + b.i32_val(ceil(shapePerCTA[rank - 2], + instrShape[rank - 2]))); + + Value mmaGrpId = b.udiv(laneId, _4); + Value mmaGrpIdP8 = b.add(mmaGrpId, _8); + Value mmaThreadIdInGrp = b.urem(laneId, _4); + Value mmaThreadIdInGrpM2 = b.mul(mmaThreadIdInGrp, _2); + Value mmaThreadIdInGrpM2P1 = b.add(mmaThreadIdInGrpM2, _1); Value rowWarpOffset = - mul(multiDimWarpId[rank - 2], i32_val(instrShape[rank - 2])); - mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset); - mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset); + b.mul(multiDimWarpId[rank - 2], b.i32_val(instrShape[rank - 2])); + mmaRowIdx[0] = b.add(mmaGrpId, rowWarpOffset); + mmaRowIdx[1] = b.add(mmaGrpIdP8, rowWarpOffset); Value colWarpOffset = - mul(multiDimWarpId[rank - 1], i32_val(instrShape[rank - 1])); - mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset); - mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset); + b.mul(multiDimWarpId[rank - 1], b.i32_val(instrShape[rank - 1])); + mmaColIdx[0] = b.add(mmaThreadIdInGrpM2, colWarpOffset); + mmaColIdx[1] = b.add(mmaThreadIdInGrpM2P1, colWarpOffset); } else { llvm_unreachable("Unexpected MMALayout version"); } @@ -875,24 +895,26 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, unsigned nGrpId = elemId / 
4; multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; - multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId)); - multiDimOffset[0] = add(multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * - shapePerCTATile[0])); - multiDimOffset[1] = add(multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * - shapePerCTATile[1])); + multiDimOffset[1] = b.add(multiDimOffset[1], b.i32_val(8 * nGrpId)); + multiDimOffset[0] = + b.add(multiDimOffset[0], + b.i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); + multiDimOffset[1] = + b.add(multiDimOffset[1], + b.i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1])); } else if (mmaLayout.isAmpere()) { if (rank == 3) multiDimOffset[0] = - add(multiDimWarpId[0], - i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); + b.add(multiDimWarpId[0], + b.i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); multiDimOffset[rank - 2] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; multiDimOffset[rank - 1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; - multiDimOffset[rank - 2] = - add(multiDimOffset[rank - 2], i32_val(multiDimCTAInRepId[rank - 2] * - shapePerCTATile[rank - 2])); - multiDimOffset[rank - 1] = - add(multiDimOffset[rank - 1], i32_val(multiDimCTAInRepId[rank - 1] * - shapePerCTATile[rank - 1])); + multiDimOffset[rank - 2] = b.add( + multiDimOffset[rank - 2], + b.i32_val(multiDimCTAInRepId[rank - 2] * shapePerCTATile[rank - 2])); + multiDimOffset[rank - 1] = b.add( + multiDimOffset[rank - 1], + b.i32_val(multiDimCTAInRepId[rank - 1] * shapePerCTATile[rank - 1])); } else { llvm_unreachable("Unexpected MMALayout version"); } @@ -911,8 +933,8 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); } - multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0])); - multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1])); + multiDimOffset[0] = b.add(multiDimBase[0], b.i32_val(offsets[elemId][0])); + multiDimOffset[1] = b.add(multiDimBase[1], b.i32_val(offsets[elemId][1])); return multiDimOffset; } llvm_unreachable("unexpected layout in getMultiDimOffset"); @@ -922,11 +944,12 @@ SmallVector getWrappedMultiDimOffset( RewriterBase &rewriter, Location loc, ArrayRef multiDimOffset, ArrayRef shape, SmallVector shapePerCTATile, SmallVector shapePerCTA) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned rank = shape.size(); SmallVector multiDimOffsetWrapped(rank); for (unsigned d = 0; d < rank; ++d) { if (shapePerCTATile[d] > shapePerCTA[d]) - multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d])); + multiDimOffsetWrapped[d] = b.urem(multiDimOffset[d], b.i32_val(shape[d])); else multiDimOffsetWrapped[d] = multiDimOffset[d]; } @@ -939,13 +962,14 @@ SharedMemoryObject getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, SharedMemoryObject smemObj, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(shape.size() == 2 || shape.size() == 3); auto offsets = smemObj.getOffsets(); auto rank = offsets.size(); assert(rank == shape.size()); if (rank == 3) return smemObj; - offsets.insert(offsets.begin(), i32_val(0)); + offsets.insert(offsets.begin(), b.i32_val(0)); auto expandedSmemObj = SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(), offsets); return expandedSmemObj; diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 
160385458656..5dfd10077b5d 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -21,6 +21,7 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern { const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto tensorTy = cast(resType); // Check the converted type for the tensor as depending on the encoding the // converter may pick different element types. @@ -36,13 +37,13 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern { unsigned ratio = srcBitWidth / cstBitWidth; Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); VectorType vecType = VectorType::get(ratio, intTy); - Value intCst = bitcast(constVal, intTy); - Value vec = undef(vecType); + Value intCst = b.bitcast(constVal, intTy); + Value vec = b.undef(vecType); for (unsigned i = 0; i < ratio; ++i) - vec = insert_element(vecType, vec, intCst, int_val(32, i)); + vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i)); constVal = vec; } - auto llSrc = bitcast(constVal, srcType); + auto llSrc = b.bitcast(constVal, srcType); size_t elemsPerThread = getTotalElemsPerThread(tensorTy); llvm::SmallVector elems(elemsPerThread, llSrc); return packLLElements(loc, typeConverter, elems, rewriter, resType); @@ -366,6 +367,7 @@ struct MemDescSubviewOpConversion matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcTy = op.getSrc().getType(); auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); auto layoutOrder = getOrder(srcTy.getEncoding()); @@ -381,14 +383,14 @@ struct MemDescSubviewOpConversion auto destRank = op.getResult().getType().getRank(); auto rankReduced = srcTy.getRank() - destRank; for (int i = rankReduced; i < opOffsetVals.size(); i++) { - offsetVals.push_back(add(opOffsetVals[i], smemObj.getOffsets()[i])); + offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i])); } // Compute the offset based on the original strides of the shared memory // object auto offset = dot(rewriter, loc, opOffsetVals, opSmemStrides); auto elemPtrTy = smemObj.getBase().getType(); smemObj = SharedMemoryObject( - gep(elemPtrTy, llvmElemTy, smemObj.getBase(), offset), llvmElemTy, + b.gep(elemPtrTy, llvmElemTy, smemObj.getBase(), offset), llvmElemTy, offsetVals); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); rewriter.replaceOp(op, retVal); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp index 6d79bd7aae08..c333de6162f4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -39,6 +39,7 @@ BufferEmitter::BufferEmitter(RewriterBase &rw, Location loc, TargetInfo ti) Value BufferEmitter::createResourceDescriptor(Value basePtr, Value blockStride) { + auto b = TritonLLVMOpBuilder(loc, rewriter); // 1. 
Create the resource descriptor // bits 0-11: dst sel, ignored by these intrinsics // bits 12-14: data format (ignored, must be nonzero, 7=float) @@ -64,11 +65,11 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr, flags |= (oob << 28); } - Value stride = int_val(16, 0); + Value stride = b.int_val(16, 0); if (targetInfo.getISAFamily() == ISAFamily::CDNA3) { if (blockStride) { // TODO: BufferAtomicRMWOp is unsupported - Value enableSwizzle = int_val(16, 16384); - Value mask14b = int_val(16, 16383); + Value enableSwizzle = b.int_val(16, 16384); + Value mask14b = b.int_val(16, 16383); // Cache swizzle supports only upto 8k stride. Also simply swizzling the // largest available stride (8k) doesn't help those unsupported large // stride. Especially better to avoid using the stride which is 2^N when @@ -82,9 +83,9 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr, } } - Value flagsConst = int_val(32, flags); + Value flagsConst = b.int_val(32, flags); Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); - Value numRecordsByte = int_val(32, std::numeric_limits::max() - 1); + Value numRecordsByte = b.int_val(32, std::numeric_limits::max() - 1); Value resource = rewriter.createOrFold( loc, rsrcType, basePtr, stride, numRecordsByte, flagsConst); @@ -94,24 +95,26 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr, Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, Value pred, Value falseVal, triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector args; fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true, args); Type bufferType = getBufferOpType(type, false); Value data = rewriter.create( loc, bufferType, args, ArrayRef()); - data = bitcast(data, type); + data = b.bitcast(data, type); if (!isZero(falseVal)) - data = select(pred, data, falseVal); + data = b.select(pred, data, falseVal); return data; } Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, Value offset, Value data, Value pred, bool hasUsers) { + auto b = TritonLLVMOpBuilder(loc, rewriter); VectorType vecTy = cast(data.getType()); Type bufferType = getBufferOpType(type, true); if (vecTy != bufferType) - data = bitcast(data, bufferType); + data = b.bitcast(data, bufferType); SmallVector args{data}; fillCommonArgsAtomics(type, rsrcDesc, offset, pred, hasUsers, args); @@ -126,15 +129,16 @@ Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, auto bufferAtomicRMW = LLVM::createLLVMIntrinsicCallOp( rewriter, loc, instrinsic, bufferType, args); - return bitcast(bufferAtomicRMW.getResult(0), type); + return b.bitcast(bufferAtomicRMW.getResult(0), type); } void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, Value pred, triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); VectorType vecTy = cast(data.getType()); Type bufferType = getBufferOpType(vecTy, false); if (vecTy != bufferType) - data = bitcast(data, bufferType); + data = b.bitcast(data, bufferType); SmallVector args{data}; fillCommonArgs(vecTy, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/false, args); @@ -192,7 +196,7 @@ void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc, Value vOffsetElems, Value pred, triton::CacheModifier cm, bool isBufferLoad, SmallVector &args) { - + auto b = TritonLLVMOpBuilder(loc, rewriter); // 1. 
Create the (masked) offset Type elementType = getElementTypeOrSelf(type); const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); @@ -200,18 +204,18 @@ void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc, // Please note: the index passed is not in bytes, but in number of elements // In order to pass the index to the buffer operation, we need to convert in // bytes (i.e., we need to multiply by `elementByteWidth`) - Value vOffsetOutOfBunds = int_val( + Value vOffsetOutOfBunds = b.int_val( 32, static_cast(std::numeric_limits::max() + int64_t(1))); - Value vOffsetBytes = mul(int_val(32, elementByteWidth), vOffsetElems); - Value maskedOffsetBytes = select(pred, vOffsetBytes, vOffsetOutOfBunds); + Value vOffsetBytes = b.mul(b.int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = b.select(pred, vOffsetBytes, vOffsetOutOfBunds); // 2. Set the sgprOffset to 0 - Value sgprOffset = int_val(32, 0); + Value sgprOffset = b.int_val(32, 0); // 3. Create the cache modifiers word int32_t aux = getCtrlBitsForCacheModifierOnTarget(cm, isBufferLoad, targetInfo); - Value cacheModifiers = int_val(32, aux); + Value cacheModifiers = b.int_val(32, aux); // 4. Add the arguments args.push_back(rsrcDesc); @@ -224,7 +228,7 @@ void BufferEmitter::fillCommonArgsAtomics(Type type, Value rsrcDesc, Value vOffsetElems, Value pred, bool hasUsers, SmallVector &args) { - + auto b = TritonLLVMOpBuilder(loc, rewriter); // 1. Create the (masked) offset Type elementType = getElementTypeOrSelf(type); const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); @@ -232,13 +236,13 @@ void BufferEmitter::fillCommonArgsAtomics(Type type, Value rsrcDesc, // Please note: the index passed is not in bytes, but in number of elements // In order to pass the index to the buffer operation, we need to convert in // bytes (i.e., we need to multiply by `elementByteWidth`) - Value vOffsetOutOfBunds = int_val( + Value vOffsetOutOfBunds = b.int_val( 32, static_cast(std::numeric_limits::max() + int64_t(1))); - Value vOffsetBytes = mul(int_val(32, elementByteWidth), vOffsetElems); - Value maskedOffsetBytes = select(pred, vOffsetBytes, vOffsetOutOfBunds); + Value vOffsetBytes = b.mul(b.int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = b.select(pred, vOffsetBytes, vOffsetOutOfBunds); // 2. Set the sgprOffset to 0 - Value sgprOffset = int_val(32, 0); + Value sgprOffset = b.int_val(32, 0); // 3. Create the cache modifiers word int32_t aux = 0; @@ -249,7 +253,7 @@ void BufferEmitter::fillCommonArgsAtomics(Type type, Value rsrcDesc, aux = getCtrlBitsForBufferAtomicsOnGFX942( /*setSC0*/ false, /*setSC1*/ false, /*setNT*/ false); - Value cacheModifiers = int_val(32, aux); + Value cacheModifiers = b.int_val(32, aux); // 4. 
Add the arguments args.push_back(rsrcDesc); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 3d88ea981b14..59a3f25648ba 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -48,6 +48,7 @@ struct ConvertLayoutOpMFMAToDotOpConversion return failure(); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); @@ -63,32 +64,32 @@ struct ConvertLayoutOpMFMAToDotOpConversion auto elemTy = int_ty(8); auto vecTy = vec_ty(elemTy, 4); - Value c16 = i32_val(16); - Value c32 = i32_val(32); - Value c48 = i32_val(48); - Value c64 = i32_val(64); + Value c16 = b.i32_val(16); + Value c32 = b.i32_val(32); + Value c48 = b.i32_val(48); + Value c64 = b.i32_val(64); - Value threadId = tid_val(); - Value laneId = urem(threadId, c64); + Value threadId = b.tid_val(); + Value laneId = b.urem(threadId, c64); - Value mask0 = icmp_slt(laneId, c32); - Value mask1 = icmp_slt(urem(laneId, c32), c16); + Value mask0 = b.icmp_slt(laneId, c32); + Value mask1 = b.icmp_slt(b.urem(laneId, c32), c16); - Value addrShift16 = urem(add(laneId, c16), c64); - Value addrShift32 = urem(add(laneId, c32), c64); - Value addrShift48 = urem(add(laneId, c48), c64); + Value addrShift16 = b.urem(b.add(laneId, c16), c64); + Value addrShift32 = b.urem(b.add(laneId, c32), c64); + Value addrShift48 = b.urem(b.add(laneId, c48), c64); SmallVector outVals; for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { - Value vec0 = undef(vecTy); + Value vec0 = b.undef(vecTy); for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - vec0 = - insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx)); + vec0 = b.insert_element(vecTy, vec0, inVals[startIdx + vIdx], + b.i32_val(vIdx)); } - Value vec1 = undef(vecTy); + Value vec1 = b.undef(vecTy); for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], - i32_val(vIdx)); + vec1 = b.insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], + b.i32_val(vIdx)); } Value resVec0, resVec1; @@ -110,17 +111,17 @@ struct ConvertLayoutOpMFMAToDotOpConversion |____________________________________________________________||___ */ - Value shflVec0 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), - vecTy); - Value shflVec1 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), - vecTy); - - resVec0 = select(mask0, vec0, shflVec1); - resVec1 = select(mask0, shflVec0, vec1); + Value shflVec0 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)), + addrShift32), + vecTy); + Value shflVec1 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)), + addrShift32), + vecTy); + + resVec0 = b.select(mask0, vec0, shflVec1); + resVec1 = b.select(mask0, shflVec0, vec1); } else if (mfmaLayout.getMDim() == 16) { /* 16x16x32 case: @@ -139,34 +140,34 @@ struct ConvertLayoutOpMFMAToDotOpConversion |________________________________________________________________| */ - Value shflVec0_16 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec0, int_ty(32)), addrShift16), - vecTy); - Value shflVec0_32 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), - vecTy); - Value shflVec1_32 = - bitcast(targetInfo.shuffleIdx( - rewriter, 
loc, bitcast(vec1, int_ty(32)), addrShift32), - vecTy); - Value shflVec1_48 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec1, int_ty(32)), addrShift48), - vecTy); - - resVec0 = select(mask0, select(mask1, vec0, shflVec0_16), - select(mask1, shflVec1_32, shflVec1_48)); - resVec1 = select(mask0, select(mask1, shflVec0_16, shflVec0_32), - select(mask1, shflVec1_48, vec1)); + Value shflVec0_16 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)), + addrShift16), + vecTy); + Value shflVec0_32 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)), + addrShift32), + vecTy); + Value shflVec1_32 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)), + addrShift32), + vecTy); + Value shflVec1_48 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)), + addrShift48), + vecTy); + + resVec0 = b.select(mask0, b.select(mask1, vec0, shflVec0_16), + b.select(mask1, shflVec1_32, shflVec1_48)); + resVec1 = b.select(mask0, b.select(mask1, shflVec0_16, shflVec0_32), + b.select(mask1, shflVec1_48, vec1)); } for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - outVals.push_back(extract_element(elemTy, resVec0, i32_val(vIdx))); + outVals.push_back(b.extract_element(elemTy, resVec0, b.i32_val(vIdx))); } for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - outVals.push_back(extract_element(elemTy, resVec1, i32_val(vIdx))); + outVals.push_back(b.extract_element(elemTy, resVec1, b.i32_val(vIdx))); } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp index 4206705c5175..0cc945368633 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -9,11 +9,12 @@ Value getWarpIdInBlock(ConversionPatternRewriter &rewriter, Location loc, Value warpId, const ArrayRef &wpt, int elemPerInstrNonK, int tensorSizeNonK, int nonKIdx, const ArrayRef &order) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, wpt, order); - return urem(multiDimWarpId[nonKIdx], - i32_val(tensorSizeNonK / elemPerInstrNonK)); + return b.urem(multiDimWarpId[nonKIdx], + b.i32_val(tensorSizeNonK / elemPerInstrNonK)); } bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; } @@ -21,6 +22,7 @@ bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; } std::pair swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, Value col, SharedMemoryObject smemObj, SharedEncodingAttr attr) { + auto b = TritonLLVMOpBuilder(loc, rewriter); (void)smemObj; // unused in current pattern const auto &order = attr.getOrder(); auto rank = order.size(); @@ -29,18 +31,18 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, // tensor is column-wise, so swapping col and row in computations std::swap(row, col); } - auto vec = i32_val(attr.getVec()); - auto perPhase = i32_val(attr.getPerPhase()); - auto maxPhase = i32_val(attr.getMaxPhase()); + auto vec = b.i32_val(attr.getVec()); + auto perPhase = b.i32_val(attr.getPerPhase()); + auto maxPhase = b.i32_val(attr.getMaxPhase()); // phase = (row // perPhase) % maxPhase // colOffSwizzled = ((col // vec) ^ phase) * vec // colOffOrdered = col % vec // colOff = colOffSwizzled + colOffOrdered - auto phase = 
urem(udiv(row, perPhase), maxPhase); - auto colOffSwizzled = mul(xor_(udiv(col, vec), phase), vec); - auto colOffOrdered = urem(col, vec); - auto colOff = add(colOffSwizzled, colOffOrdered); + auto phase = b.urem(b.udiv(row, perPhase), maxPhase); + auto colOffSwizzled = b.mul(b.xor_(b.udiv(col, vec), phase), vec); + auto colOffOrdered = b.urem(col, vec); + auto colOff = b.add(colOffSwizzled, colOffOrdered); if (transposed) return {colOff, row}; @@ -51,25 +53,27 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, Value computeOffset(ConversionPatternRewriter &rewriter, Location loc, Value row, Value col, SharedMemoryObject smemObj, ArrayRef smemStrides, SharedEncodingAttr srcLayout) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto [swizzledRow, swizzledCol] = swizzleIndexes(rewriter, loc, row, col, smemObj, srcLayout); auto rank = smemStrides.size(); assert(rank == 2 || rank == 3); - Value rowOffset = mul(swizzledRow, smemStrides[rank - 2]); - Value colOffset = mul(swizzledCol, smemStrides[rank - 1]); - return add(rowOffset, colOffset); + Value rowOffset = b.mul(swizzledRow, smemStrides[rank - 2]); + Value colOffset = b.mul(swizzledCol, smemStrides[rank - 1]); + return b.add(rowOffset, colOffset); } Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, const SharedMemoryObject &smemObj, ArrayRef smemStrides) { + auto b = TritonLLVMOpBuilder(loc, rewriter); Value base = smemObj.getBase(); Type type = base.getType(); Type elemType = smemObj.getBaseElemType(); for (int i = 0; i < smemStrides.size(); ++i) { Value offset = - sub(i32_val(0), mul(smemObj.getOffsets()[i], smemStrides[i])); - base = gep(type, elemType, base, offset); + b.sub(b.i32_val(0), b.mul(smemObj.getOffsets()[i], smemStrides[i])); + base = b.gep(type, elemType, base, offset); } return base; } @@ -120,6 +124,7 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc, ArrayRef reps, SharedMemoryObject smemObj, ArrayRef smemStrides, SharedEncodingAttr srcLayout, unsigned nonKDim, unsigned kDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector offsets = smemObj.getOffsets(); auto order = srcLayout.getOrder(); auto rank = offsets.size(); @@ -143,7 +148,7 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc, for (int block = 0; block < numBlocks; ++block) { int blockNonKOffset = block * nonKDim * warpsPerBlock; for (int i = 0; i < blockSize; ++i) { - Value row = add(mapping[i][0], i32_val(blockNonKOffset)); + Value row = b.add(mapping[i][0], b.i32_val(blockNonKOffset)); Value col = mapping[i][1]; aOffsets[block * blockSize + i] = computeOffset( rewriter, loc, row, col, smemObj, smemStrides, srcLayout); @@ -160,9 +165,10 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc, } for (int block = 0; block < numBlocks; ++block) { int blockNonKOffset = block * nonKDim * warpsPerBlock; - Value offAdjust = mul(i32_val(blockNonKOffset), smemStrides[rank - 2]); + Value offAdjust = + b.mul(b.i32_val(blockNonKOffset), smemStrides[rank - 2]); for (int i = 0; i < blockSize; ++i) - aOffsets[block * blockSize + i] = add(offAdjust, inblockOffset[i]); + aOffsets[block * blockSize + i] = b.add(offAdjust, inblockOffset[i]); } } return aOffsets; @@ -187,6 +193,7 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc, ArrayRef reps, SharedMemoryObject smemObj, ArrayRef smemStrides, SharedEncodingAttr srcLayout, unsigned nonKDim, unsigned kDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); // transpose reps and 
offsets, because operand B has layout equal to // transposed operand A layout // this unifies axis order, so non-K dim is 0, k dim is 1 @@ -219,7 +226,7 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc, // swap row and col, because operand B layout is // a transposed operand A layout Value row = mapping[i][1]; - Value col = add(mapping[i][0], i32_val(blockNonKOffset)); + Value col = b.add(mapping[i][0], b.i32_val(blockNonKOffset)); bOffsets[block * blockSize + i] = computeOffset( rewriter, loc, row, col, smemObj, smemStrides, srcLayout); } @@ -237,9 +244,9 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc, } for (int block = 0; block < numBlocks; ++block) { int blockNonKOffset = block * nonKDim * warpsPerBlock; - Value offAdjust = mul(i32_val(blockNonKOffset), tStrides[rank - 2]); + Value offAdjust = b.mul(b.i32_val(blockNonKOffset), tStrides[rank - 2]); for (int i = 0; i < mapping.size(); ++i) - bOffsets[block * blockSize + i] = add(offAdjust, inblockOffset[i]); + bOffsets[block * blockSize + i] = b.add(offAdjust, inblockOffset[i]); } } return bOffsets; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index e4a683d870f8..7b885a75a111 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -73,47 +73,49 @@ llvm::SmallVector> computeTensorElemMappingInBlock( const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, int loadVecSize, unsigned iNonKDim, unsigned iKDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto numM = reps[1]; auto numK = reps[2]; const int loadsPerThread = numOfElems / loadVecSize; llvm::SmallVector> mapping(numK * loadsPerThread); - Value _0 = i32_val(0); - Value _32 = i32_val(32); - Value nonKDim = i32_val(iNonKDim); - Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0])); + Value _0 = b.i32_val(0); + Value _32 = b.i32_val(32); + Value nonKDim = b.i32_val(iNonKDim); + Value warpVOffset = b.mul(warpId, b.i32_val(elemsPerInstr[0])); auto rank = smemOffsets.size(); for (int tile = 0; tile < numK; ++tile) { Value tileVOffset = _0; - Value tileHOffset = i32_val(tile * elemsPerInstr[1]); + Value tileHOffset = b.i32_val(tile * elemsPerInstr[1]); - Value laneVOffset = urem(laneId, nonKDim); + Value laneVOffset = b.urem(laneId, nonKDim); Value laneHOffset; if (iNonKDim == 32) { - laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); + laneHOffset = + b.select(b.icmp_uge(laneId, _32), b.i32_val(numOfElems), _0); } else { // In this configuration warp contains 16 copies of same data if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) { - laneHOffset = i32_val(0); + laneHOffset = b.i32_val(0); } else { assert(iKDim * iNonKDim / numOfElems == 64 && "seems no all threads in warp contain unique elements"); - laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems)); + laneHOffset = b.mul(b.udiv(laneId, nonKDim), b.i32_val(numOfElems)); } } for (int loadId = 0; loadId < loadsPerThread; ++loadId) { Value elemVOffset = _0; - Value elemHOffset = i32_val(loadId * loadVecSize); + Value elemHOffset = b.i32_val(loadId * loadVecSize); - Value sliceVOffset = - add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); - Value sliceHOffset = add(add(tileHOffset, laneHOffset), 
elemHOffset); + Value sliceVOffset = b.add( + b.add(b.add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); + Value sliceHOffset = b.add(b.add(tileHOffset, laneHOffset), elemHOffset); - Value row = add(sliceVOffset, smemOffsets[rank - 2]); - Value col = add(sliceHOffset, smemOffsets[rank - 1]); + Value row = b.add(sliceVOffset, smemOffsets[rank - 2]); + Value col = b.add(sliceHOffset, smemOffsets[rank - 1]); mapping[loadsPerThread * tile + loadId] = {row, col}; } @@ -142,6 +144,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int warpsPerBlock, int numOfElems, ArrayRef reps, Value cSwizzleOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto numK = reps[1]; auto numN = reps[2]; SmallVector offsets(numK * numN * numOfElems); @@ -149,14 +152,14 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, auto iKDim = elemsPerInstr[0]; auto iNonKDim = elemsPerInstr[1]; int lineSize = warpsPerBlock * iNonKDim * numN; - Value _nonKDim = i32_val(iNonKDim); - Value warpOffset = mul(warpId, i32_val(iNonKDim)); - Value colOffset = urem(laneId, _nonKDim); + Value _nonKDim = b.i32_val(iNonKDim); + Value warpOffset = b.mul(warpId, b.i32_val(iNonKDim)); + Value colOffset = b.urem(laneId, _nonKDim); for (int block = 0; block < numN; ++block) { - Value blockOffset = i32_val(block * iNonKDim * warpsPerBlock); + Value blockOffset = b.i32_val(block * iNonKDim * warpsPerBlock); for (int tile = 0; tile < numK; ++tile) { - Value tileOffset = i32_val(tile * iKDim * lineSize); + Value tileOffset = b.i32_val(tile * iKDim * lineSize); for (int elem = 0; elem < numOfElems; ++elem) { // halfOffset is an offset related to wrapping of warp in the tile. // for example, mfma 32 case (mapping of tensor elements to lane ids in @@ -172,14 +175,14 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, // 32 33 34 35 ... 
63 Value halfOffset; if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) - halfOffset = i32_val(0); + halfOffset = b.i32_val(0); else halfOffset = - mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize)); - Value rowOffset = add(i32_val(elem * lineSize), halfOffset); - Value elemOffset = add(rowOffset, colOffset); - Value offset = - add(add(add(warpOffset, blockOffset), tileOffset), elemOffset); + b.mul(b.udiv(laneId, _nonKDim), b.i32_val(numOfElems * lineSize)); + Value rowOffset = b.add(b.i32_val(elem * lineSize), halfOffset); + Value elemOffset = b.add(rowOffset, colOffset); + Value offset = b.add(b.add(b.add(warpOffset, blockOffset), tileOffset), + elemOffset); offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset; } } @@ -196,6 +199,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx"); auto aTensorTy = cast(tensor.getType()); ArrayRef shape = aTensorTy.getShape(); @@ -250,9 +254,9 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, unsigned iWarpSize = triton::gpu::getWarpSize(mfmaLayout); assert(iWarpSize == 64); - Value warpSize = i32_val(iWarpSize); - Value linearWarpId = udiv(thread, warpSize); - Value lane = urem(thread, warpSize); + Value warpSize = tb.i32_val(iWarpSize); + Value linearWarpId = tb.udiv(thread, warpSize); + Value lane = tb.urem(thread, warpSize); Value spatialWarpId = AMD::getWarpIdInBlock( rewriter, loc, linearWarpId, warpsPerCTA, mfmaInstrNonK, @@ -271,7 +275,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); int warpsPerBatch = rank == 3 ? 
std::min(shape[0], warpsPerCTA[0]) : 1; - Value warpIdInBatch = urem(linearWarpId, i32_val(warpsPerBatch)); + Value warpIdInBatch = tb.urem(linearWarpId, tb.i32_val(warpsPerBatch)); elemTy = typeConverter->convertType(elemTy); SmallVector loadedValues; @@ -340,23 +344,24 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; - Value batchOffset = mul(i32_val(operandSize), - add(warpIdInBatch, i32_val(b * warpsPerBatch))); + Value batchOffset = + tb.mul(tb.i32_val(operandSize), + tb.add(warpIdInBatch, tb.i32_val(b * warpsPerBatch))); for (int nonK = 0; nonK < numRepNonK; ++nonK) { int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerBlockNonK; - Value warpBlockOffAdjust = i32_val(blockNonKOffset * shape[order[0]]); + Value warpBlockOffAdjust = tb.i32_val(blockNonKOffset * shape[order[0]]); for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { Value loadOffset; loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; - loadOffset = add(loadOffset, batchOffset); - Value loadAddress = gep(smemPtrTy, elemTy, smemBase, loadOffset); - Value loadedValue = load(loadVecTy, loadAddress); + loadOffset = tb.add(loadOffset, batchOffset); + Value loadAddress = tb.gep(smemPtrTy, elemTy, smemBase, loadOffset); + Value loadedValue = tb.load(loadVecTy, loadAddress); for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { Value elemVal = - extract_element(elemTy, loadedValue, i32_val(elemId)); + tb.extract_element(elemTy, loadedValue, tb.i32_val(elemId)); loadedValues.push_back(elemVal); } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index d147d8627d05..6aaf8283f157 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -70,26 +70,27 @@ computeTensorElemMappingInBlockWmma1( const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(reps.size() == 3); assert(elemsPerInstr.size() == 2); auto numK = reps[2]; const int loadsPerThread = numOfElems / loadVecSize; llvm::SmallVector> mapping(numK * loadsPerThread); - Value elemsPerInstrV = i32_val(elemsPerInstr[0]); - Value warpVOffset = mul(warpId, elemsPerInstrV); - Value sliceVOffset = add(urem(laneId, elemsPerInstrV), warpVOffset); + Value elemsPerInstrV = b.i32_val(elemsPerInstr[0]); + Value warpVOffset = b.mul(warpId, elemsPerInstrV); + Value sliceVOffset = b.add(b.urem(laneId, elemsPerInstrV), warpVOffset); auto rank = smemOffsets.size(); - Value row = add(sliceVOffset, smemOffsets[rank - 2]); + Value row = b.add(sliceVOffset, smemOffsets[rank - 2]); for (int tile = 0; tile < numK; ++tile) { - Value tileHOffset = i32_val(tile * elemsPerInstr[1]); + Value tileHOffset = b.i32_val(tile * elemsPerInstr[1]); for (int loadId = 0; loadId < loadsPerThread; ++loadId) { - Value elemHOffset = i32_val(loadId * loadVecSize); - Value sliceHOffset = add(tileHOffset, elemHOffset); + Value elemHOffset = b.i32_val(loadId * loadVecSize); + Value 
sliceHOffset = b.add(tileHOffset, elemHOffset); - Value col = add(sliceHOffset, smemOffsets[rank - 1]); + Value col = b.add(sliceHOffset, smemOffsets[rank - 1]); mapping[loadsPerThread * tile + loadId] = {row, col}; } } @@ -103,29 +104,30 @@ computeTensorElemMappingInBlockWmma2( const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(reps.size() == 3); assert(elemsPerInstr.size() == 2); auto numK = reps[2]; const int loadsPerThread = numOfElems / loadVecSize; llvm::SmallVector> mapping(numK * loadsPerThread); - Value rowsPerInstr = i32_val(elemsPerInstr[0]); - Value colsPerInstr = i32_val(elemsPerInstr[1]); - Value elemsPerThread = i32_val(elemsPerInstr[1] / 2); - Value warpVOffset = mul(warpId, rowsPerInstr); - Value sliceVOffset = add(urem(laneId, rowsPerInstr), warpVOffset); + Value rowsPerInstr = b.i32_val(elemsPerInstr[0]); + Value colsPerInstr = b.i32_val(elemsPerInstr[1]); + Value elemsPerThread = b.i32_val(elemsPerInstr[1] / 2); + Value warpVOffset = b.mul(warpId, rowsPerInstr); + Value sliceVOffset = b.add(b.urem(laneId, rowsPerInstr), warpVOffset); auto rank = smemOffsets.size(); - Value row = add(sliceVOffset, smemOffsets[rank - 2]); - Value laneHOffset = mul(udiv(laneId, colsPerInstr), elemsPerThread); + Value row = b.add(sliceVOffset, smemOffsets[rank - 2]); + Value laneHOffset = b.mul(b.udiv(laneId, colsPerInstr), elemsPerThread); for (int tile = 0; tile < numK; ++tile) { - Value tileHOffset = add(laneHOffset, i32_val(tile * elemsPerInstr[1])); + Value tileHOffset = b.add(laneHOffset, b.i32_val(tile * elemsPerInstr[1])); for (int loadId = 0; loadId < loadsPerThread; ++loadId) { - Value elemHOffset = i32_val(loadId * loadVecSize); - Value sliceHOffset = add(tileHOffset, elemHOffset); + Value elemHOffset = b.i32_val(loadId * loadVecSize); + Value sliceHOffset = b.add(tileHOffset, elemHOffset); - Value col = add(sliceHOffset, smemOffsets[rank - 1]); + Value col = b.add(sliceHOffset, smemOffsets[rank - 1]); mapping[loadsPerThread * tile + loadId] = {row, col}; } @@ -138,6 +140,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx"); auto rank = smemObj.getOffsets().size(); int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; @@ -171,18 +174,18 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout); assert(iWaveSize == 32); - Value waveSize = i32_val(iWaveSize); - Value linearWaveId = udiv(thread, waveSize); + Value waveSize = tb.i32_val(iWaveSize); + Value linearWaveId = tb.udiv(thread, waveSize); unsigned numElemsPerThreadPerRep = wmmaLayout.getSizePerThreadForOperand(kWidth, opIdx)[kDimIdx]; - Value lane = urem(thread, waveSize); + Value lane = tb.urem(thread, waveSize); unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK; int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); int warpsPerBatch = rank == 3 ? 
std::min(shape[0], warpsPerCTA[0]) : 1; - Value waveIdInBatch = urem(linearWaveId, i32_val(warpsPerBatch)); + Value waveIdInBatch = tb.urem(linearWaveId, tb.i32_val(warpsPerBatch)); elemTy = typeConverter->convertType(elemTy); SmallVector loadedValues; @@ -215,21 +218,22 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; - Value batchOffset = mul(i32_val(operandSize), - add(waveIdInBatch, i32_val(b * warpsPerBatch))); + Value batchOffset = + tb.mul(tb.i32_val(operandSize), + tb.add(waveIdInBatch, tb.i32_val(b * warpsPerBatch))); for (int nonK = 0; nonK < numRepNonK; ++nonK) { for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); - Value valVec = undef(vecTy); + Value valVec = tb.undef(vecTy); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { Value loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; - loadOffset = add(loadOffset, batchOffset); - Value loadAddress = gep(smemPtrTy, elemTy, smemBase, loadOffset); - Value loadedValue = load(loadVecTy, loadAddress); + loadOffset = tb.add(loadOffset, batchOffset); + Value loadAddress = tb.gep(smemPtrTy, elemTy, smemBase, loadOffset); + Value loadedValue = tb.load(loadVecTy, loadAddress); for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { Value elemVal = - extract_element(elemTy, loadedValue, i32_val(elemId)); + tb.extract_element(elemTy, loadedValue, tb.i32_val(elemId)); loadedValues.push_back(elemVal); } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index aa60c1e58cc9..96035f92ee7d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -62,8 +62,9 @@ struct DotOpMFMAConversionHelper { Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB, Value valC) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto resType = valC.getType(); - Value zeroFlag = i32_val(0); + Value zeroFlag = b.i32_val(0); OperationState loweredOp(loc, mfmaInsnName); loweredOp.addTypes(resType); loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); @@ -94,6 +95,7 @@ struct DotOpMFMAConversionHelper { Value processSubBlocks(int numSubBlocks, Value acc, bool reduceSubBlocks, bool zeroSubBlocks) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert((numSubBlocks & (numSubBlocks - 1)) == 0 && "numSubBlocks in not pow 2!"); if (numSubBlocks == 1) @@ -101,14 +103,14 @@ struct DotOpMFMAConversionHelper { constexpr int warpSize = 64; int subBlockSize = warpSize / numSubBlocks; Value laneId = getThreadId(); - laneId = and_(laneId, i32_val(warpSize - 1)); + laneId = b.and_(laneId, b.i32_val(warpSize - 1)); auto vecTy = dyn_cast(acc.getType()); auto elemType = vecTy.getElementType(); assert(elemType.getIntOrFloatBitWidth() == 32); int numScalars = vecTy.getNumElements(); std::vector accScalar(numScalars); for (int i = 0; i < numScalars; ++i) - accScalar[i] = extract_element(elemType, acc, i32_val(i)); + accScalar[i] = b.extract_element(elemType, acc, b.i32_val(i)); if (reduceSubBlocks) { while (subBlockSize < warpSize) { @@ -116,9 +118,9 @@ struct DotOpMFMAConversionHelper { Value other_acc = shuffleXor(loc, rewriter, accScalar[i], subBlockSize); if (elemType.isInteger(32)) - accScalar[i] = add(accScalar[i], other_acc); + 
accScalar[i] = b.add(accScalar[i], other_acc); else - accScalar[i] = fadd(accScalar[i], other_acc); + accScalar[i] = b.fadd(accScalar[i], other_acc); } subBlockSize *= 2; } @@ -126,17 +128,18 @@ struct DotOpMFMAConversionHelper { if (zeroSubBlocks) { Value zero; if (elemType.isInteger(32)) - zero = i32_val(0); + zero = b.i32_val(0); else - zero = f32_val(0.0); - auto cond = icmp_ult(laneId, i32_val(subBlockSize)); + zero = b.f32_val(0.0); + auto cond = b.icmp_ult(laneId, b.i32_val(subBlockSize)); for (int i = 0; i < numScalars; ++i) - accScalar[i] = select(cond, accScalar[i], zero); + accScalar[i] = b.select(cond, accScalar[i], zero); } - Value reducedAcc = undef(vecTy); + Value reducedAcc = b.undef(vecTy); for (int i = 0; i < numScalars; ++i) - reducedAcc = insert_element(vecTy, reducedAcc, accScalar[i], i32_val(i)); + reducedAcc = + b.insert_element(vecTy, reducedAcc, accScalar[i], b.i32_val(i)); return reducedAcc; } @@ -164,6 +167,7 @@ struct DotOpMFMAConversionHelper { // Conduct the Dot conversion. LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const { + auto tb = TritonLLVMOpBuilder(loc, rewriter); // Check if this dot has come with priority set by setprio. auto setPrioOp = dyn_cast_or_null(op->getPrevNode()); @@ -248,13 +252,13 @@ struct DotOpMFMAConversionHelper { for (int b = 0; b < numRepB; ++b) { for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { - Value acc = undef(vecTy); + Value acc = tb.undef(vecTy); for (unsigned v = 0; v < elemsPerVec; ++v) { - acc = insert_element( + acc = tb.insert_element( vecTy, acc, fc[b * numRepM * numRepN * elemsPerVec + m * numRepN * elemsPerVec + n * elemsPerVec + v], - i32_val(v)); + tb.i32_val(v)); } acc = zeroAuxiliarBlocks(subBlocks, acc); for (int k = 0; k < numRepK; k++) { @@ -270,7 +274,7 @@ struct DotOpMFMAConversionHelper { } acc = reduceSubBlocks(subBlocks, acc); for (unsigned v = 0; v < elemsPerVec; ++v) { - Value accElem = extract_element(dstElemTy, acc, i32_val(v)); + Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v)); // Dot operand layout minimal tile is kDimInstrSize elements across // K dimension. If dot operand K dimension is smaller, layout // assigns tensor elements to multiple different hardware locations. 
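To make the scaling in the next hunk concrete: when the dot-operand K dimension is smaller than the MFMA instruction tile, every element ends up accumulated duplicationRate times, so the result is divided back down, via an arithmetic shift for signed i32 accumulators and a multiply by 1/duplicationRate for floats. A standalone restatement of that compensation in plain C++ (illustrative names only, not part of the patch):

#include <cassert>
#include <cstdint>

// Mirrors the compensation in the hunk that follows; duplicationRate is a
// power of two in the MFMA layout, so the integer case reduces to a shift.
int32_t compensateIntAcc(int32_t acc, unsigned duplicationRate) {
  assert((duplicationRate & (duplicationRate - 1)) == 0); // power of two
  unsigned shift = 0;
  while ((1u << shift) < duplicationRate)
    ++shift;             // shift == log2(duplicationRate)
  return acc >> shift;   // arithmetic shift: the MFMA accumulator is signed
}

float compensateFloatAcc(float acc, unsigned duplicationRate) {
  return acc * (1.0f / duplicationRate);
}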
@@ -295,13 +299,13 @@ struct DotOpMFMAConversionHelper { auto shiftSize = llvm::Log2_32(duplicationRate); assert(!accElem.getType().isUnsignedInteger() && "MFMA uses signed accumulator"); - accElem = ashr(accElem, i32_val(shiftSize)); + accElem = tb.ashr(accElem, tb.i32_val(shiftSize)); } else { auto multiplierAttr = rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate); auto multiplierVal = rewriter.create( loc, dstElemTy, multiplierAttr); - accElem = fmul(accElem, multiplierVal); + accElem = tb.fmul(accElem, multiplierVal); } } auto linearIdx = b * numRepM * numRepN * elemsPerVec + @@ -343,29 +347,31 @@ struct DotOpMFMAConversionHelper { /// kBase elements for each mfma instruction SmallVector extractOperands(Value rawElems, int kWidth, int kBase, Type type) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); int kpack = kWidth / kBase; SmallVector results; auto vecTy = vec_ty(type, kBase); if (type.isBF16()) vecTy = vec_ty(i16_ty, kBase); for (int k = 0; k < kpack; ++k) { - Value vec = undef(vecTy); + Value vec = b.undef(vecTy); for (int elemId = 0; elemId < kBase; ++elemId) { - auto val = extract_element(type, rawElems, i32_val(elemId + k * kBase)); + auto val = + b.extract_element(type, rawElems, b.i32_val(elemId + k * kBase)); if (type.isBF16()) { // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type - auto cast = bitcast(val, i16_ty); - vec = insert_element(vecTy, vec, cast, i32_val(elemId)); + auto cast = b.bitcast(val, i16_ty); + vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId)); } else { - vec = insert_element(vecTy, vec, val, i32_val(elemId)); + vec = b.insert_element(vecTy, vec, val, b.i32_val(elemId)); } } if (type.getIntOrFloatBitWidth() == 8) { if (4 == kBase) // This is for int8 on pre- MI300 GPUs - results.push_back(bitcast(vec, i32_ty)); + results.push_back(b.bitcast(vec, i32_ty)); if (8 == kBase) - results.push_back(bitcast(vec, i64_ty)); + results.push_back(b.bitcast(vec, i64_ty)); } else { results.push_back(vec); } @@ -379,6 +385,7 @@ struct DotOpMFMAConversionHelper { getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0, int n1, int kWidth, int kBase, Type type, bool allowXF32) const { + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto elems = unpackLLElements(loc, value, rewriter); int kpack = kWidth / kBase; SmallVector dotOpVals(kpack); @@ -387,19 +394,19 @@ struct DotOpMFMAConversionHelper { for (int j = 0; j < n1; j++) { Type elemTy = typeConverter->convertType(type); Type ty = vec_ty(elemTy, kWidth); - Value rawElems = undef(ty); + Value rawElems = tb.undef(ty); for (int k = 0; k < kWidth; ++k) { - rawElems = insert_element( + rawElems = tb.insert_element( ty, rawElems, elems[kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k], - i32_val(k)); + tb.i32_val(k)); } Value convertedElems; if (type.isF32() && !allowXF32) { for (int k = 0; k < kpack; ++k) dotOpVals[k][{b, i, j}] = - extract_element(type, rawElems, i32_val(k)); + tb.extract_element(type, rawElems, tb.i32_val(k)); } else { SmallVector vals; if (type.isF32() && allowXF32) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index ef523d3c6805..fa44931013a0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -54,6 +54,7 @@ getValuesFromDotOperandLayoutStruct(ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, Value value, int batch, int n0, int n1, int kWidth, Type 
type, Location loc) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto elems = unpackLLElements(loc, value, rewriter); ValueTable vals; for (int b = 0; b < batch; b++) { @@ -61,21 +62,21 @@ getValuesFromDotOperandLayoutStruct(ConversionPatternRewriter &rewriter, for (int j = 0; j < n1; j++) { Type elemTy = typeConverter->convertType(type); Type ty = vec_ty(elemTy, kWidth); - Value rawElems = undef(ty); + Value rawElems = tb.undef(ty); for (int k = 0; k < kWidth; ++k) { - rawElems = insert_element( + rawElems = tb.insert_element( ty, rawElems, elems[n0 * n1 * kWidth * b + kWidth * (n1 * i + j) + k], - i32_val(k)); + tb.i32_val(k)); } Value convertedElems; if (type.isF16()) { convertedElems = rawElems; } else if (type.isBF16()) { - convertedElems = bitcast(rawElems, vec_ty(i16_ty, kWidth)); + convertedElems = tb.bitcast(rawElems, vec_ty(i16_ty, kWidth)); } else { - convertedElems = bitcast( + convertedElems = tb.bitcast( rawElems, vec_ty(i32_ty, kWidth * type.getIntOrFloatBitWidth() / i32_ty.getIntOrFloatBitWidth())); } @@ -121,8 +122,9 @@ static WMMAInstrType getWMMAInstrTypeFromDot(DotOp op) { Value generateROCDLOp(ConversionPatternRewriter &rewriter, Location loc, WMMAInstrType wmmaType, Value valA, Value valB, Value valC, Type aElType, Type bElType) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto resType = valC.getType(); - Value falseFlag = int_val(1, false); + Value falseFlag = b.int_val(1, false); switch (wmmaType) { case WMMAInstrType::FP32_FP16: return rewriter.create( @@ -139,14 +141,14 @@ Value generateROCDLOp(ConversionPatternRewriter &rewriter, Location loc, case WMMAInstrType::I32_I8: return rewriter.create( loc, TypeRange{resType}, - ValueRange{int_val(1, !aElType.isUnsignedInteger()), valA, - int_val(1, !bElType.isUnsignedInteger()), valB, valC, + ValueRange{b.int_val(1, !aElType.isUnsignedInteger()), valA, + b.int_val(1, !bElType.isUnsignedInteger()), valB, valC, falseFlag}); case WMMAInstrType::I32_I4: return rewriter.create( loc, TypeRange{resType}, - ValueRange{int_val(1, !aElType.isUnsignedInteger()), valA, - int_val(1, !bElType.isUnsignedInteger()), valB, valC, + ValueRange{b.int_val(1, !aElType.isUnsignedInteger()), valA, + b.int_val(1, !bElType.isUnsignedInteger()), valB, valC, falseFlag}); default: llvm::report_fatal_error("WMMA data type not supported"); @@ -205,21 +207,22 @@ Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc, WMMAInstrType wmmaType, Value valA, Value valB, Value valC, Type aElType, Type bElType, Type dElType) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto name = getWmmaIntrinsicName(aElType, bElType, dElType, valA.getType(), valC.getType()); LLVM::FastmathFlagsAttr defaultFlags{}; SmallVector operands; if (aElType.isInteger()) - operands.push_back(int_val(1, !aElType.isUnsignedInteger())); + operands.push_back(b.int_val(1, !aElType.isUnsignedInteger())); operands.push_back(valA); if (bElType.isInteger()) - operands.push_back(int_val(1, !bElType.isUnsignedInteger())); + operands.push_back(b.int_val(1, !bElType.isUnsignedInteger())); operands.push_back(valB); operands.push_back(valC); // Flag for using low bits in registers. Result could be already packed to // int32. Set low bits by default for now. 
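As the hunks in this file show, the conversion is mechanical: each lowering routine constructs one TritonLLVMOpBuilder from its Location and rewriter, and the former macro helpers become member calls on that builder. A minimal sketch of the resulting style, written as a hypothetical helper (not part of the patch) that uses only methods visible in these hunks:

// Hypothetical example in the post-patch style: clears the sign bit of an
// f16 value by round-tripping through its i16 bit pattern.
static Value absF16(Location loc, ConversionPatternRewriter &rewriter, Value v) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value bits = b.bitcast(v, i16_ty);                 // reinterpret f16 as i16
  Value magnitude = b.and_(bits, b.i16_val(0x7FFF)); // drop the sign bit
  return b.bitcast(magnitude, f16_ty);
}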
if (32 / dElType.getIntOrFloatBitWidth() > 1 || dElType.isInteger(32)) { - operands.push_back(int_val(1, false)); + operands.push_back(b.int_val(1, false)); } auto wmmaIntrinsic = LLVM::createLLVMIntrinsicCallOp( rewriter, loc, name, valC.getType(), operands); @@ -251,6 +254,7 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, auto wmmaInstrType = getWMMAInstrTypeFromDot(op); auto loc = op.getLoc(); + auto tb = TritonLLVMOpBuilder(loc, rewriter); Value a = op.getA(); Value b = op.getB(); Value d = op.getD(); @@ -304,10 +308,10 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, auto nRepOffId = n * dElemsToStorePerThread; auto fcThreadOffIdx = batchOffIdx + mRepOffId + nRepOffId; - Value acc = undef(vecTy); + Value acc = tb.undef(vecTy); for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { - acc = insert_element(vecTy, acc, fc[fcThreadOffIdx + v], - i32_val(v * paddedOutputElemSize)); + acc = tb.insert_element(vecTy, acc, fc[fcThreadOffIdx + v], + tb.i32_val(v * paddedOutputElemSize)); } for (size_t k = 0; k < numRepK; k++) { acc = wmmaLayout.getIsTransposed() @@ -321,8 +325,8 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, bTensorTy.getElementType(), dstElemTy, wmmaVer); } for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { - fc[fcThreadOffIdx + v] = extract_element( - dstElemTy, acc, i32_val(v * paddedOutputElemSize)); + fc[fcThreadOffIdx + v] = tb.extract_element( + dstElemTy, acc, tb.i32_val(v * paddedOutputElemSize)); } } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 35a2e1a34bcd..994a81d58ca1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -38,55 +38,57 @@ namespace { static SmallVector Fp16_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto fp16x2VecTy = vec_ty(f16_ty, 2); - Value fp16x2Vec0 = undef(fp16x2VecTy); - Value fp16x2Vec1 = undef(fp16x2VecTy); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0)); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1)); + Value fp16x2Vec0 = b.undef(fp16x2VecTy); + Value fp16x2Vec1 = b.undef(fp16x2VecTy); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[0], b.i32_val(0)); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[1], b.i32_val(1)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[2], b.i32_val(0)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[3], b.i32_val(1)); - Value a0 = bitcast(fp16x2Vec0, i32_ty); - Value a1 = bitcast(fp16x2Vec1, i32_ty); + Value a0 = b.bitcast(fp16x2Vec0, i32_ty); + Value a1 = b.bitcast(fp16x2Vec1, i32_ty); - a0 = and_(i32_ty, a0, i32_val(0xfffefffe)); - a1 = and_(i32_ty, a1, i32_val(0xfffefffe)); + a0 = b.and_(i32_ty, a0, b.i32_val(0xfffefffe)); + a1 = b.and_(i32_ty, a1, b.i32_val(0xfffefffe)); - a0 = add(i32_ty, a0, i32_val(0x00800080)); - a1 = add(i32_ty, a1, i32_val(0x00800080)); + a0 = b.add(i32_ty, a0, b.i32_val(0x00800080)); + a1 = b.add(i32_ty, a1, b.i32_val(0x00800080)); auto fp8x4VecTy = vec_ty(i8_ty, 4); - a0 = bitcast(a0, fp8x4VecTy); - a1 = bitcast(a1, fp8x4VecTy); + a0 = b.bitcast(a0, fp8x4VecTy); + a1 = b.bitcast(a1, fp8x4VecTy); - return 
{extract_element(i8_ty, a0, i32_val(1)), - extract_element(i8_ty, a0, i32_val(3)), - extract_element(i8_ty, a1, i32_val(1)), - extract_element(i8_ty, a1, i32_val(3))}; + return {b.extract_element(i8_ty, a0, b.i32_val(1)), + b.extract_element(i8_ty, a0, b.i32_val(3)), + b.extract_element(i8_ty, a1, b.i32_val(1)), + b.extract_element(i8_ty, a1, b.i32_val(3))}; } static SmallVector Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto fp16x2VecTy = vec_ty(f16_ty, 2); - Value fp16x2Vec0 = undef(fp16x2VecTy); - Value fp16x2Vec1 = undef(fp16x2VecTy); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0)); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1)); + Value fp16x2Vec0 = b.undef(fp16x2VecTy); + Value fp16x2Vec1 = b.undef(fp16x2VecTy); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[0], b.i32_val(0)); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[1], b.i32_val(1)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[2], b.i32_val(0)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[3], b.i32_val(1)); - Value a0 = bitcast(fp16x2Vec0, i32_ty); - Value a1 = bitcast(fp16x2Vec1, i32_ty); + Value a0 = b.bitcast(fp16x2Vec0, i32_ty); + Value a1 = b.bitcast(fp16x2Vec1, i32_ty); auto fp8x4VecTy = vec_ty(i8_ty, 4); - a0 = bitcast(a0, fp8x4VecTy); - a1 = bitcast(a1, fp8x4VecTy); + a0 = b.bitcast(a0, fp8x4VecTy); + a1 = b.bitcast(a1, fp8x4VecTy); - return {extract_element(i8_ty, a0, i32_val(1)), - extract_element(i8_ty, a0, i32_val(3)), - extract_element(i8_ty, a1, i32_val(1)), - extract_element(i8_ty, a1, i32_val(3))}; + return {b.extract_element(i8_ty, a0, b.i32_val(1)), + b.extract_element(i8_ty, a0, b.i32_val(3)), + b.extract_element(i8_ty, a1, b.i32_val(1)), + b.extract_element(i8_ty, a1, b.i32_val(3))}; } //===----------------===// @@ -101,37 +103,40 @@ Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter, static Value Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); StringRef funcName = "llvm.is.fpclass"; Value isNaN = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, funcName, i1_ty, - {v, i32_val(0x3)}) + {v, b.i32_val(0x3)}) ->getResult(0); // Get sign and absolute value - Value vi16 = bitcast(v, i16_ty); - Value sign = trunc(i8_ty, lshr(and_(vi16, i16_val(0x8000)), i16_val(8))); - vi16 = and_(vi16, i16_val(0x7FFF)); + Value vi16 = b.bitcast(v, i16_ty); + Value sign = + b.trunc(i8_ty, b.lshr(b.and_(vi16, b.i16_val(0x8000)), b.i16_val(8))); + vi16 = b.and_(vi16, b.i16_val(0x7FFF)); // Rounding to nearest even constexpr uint16_t baseRoundingBias = 0x003F; // 1 << (10 - 3 - 1) - 1 // S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M - Value remainingMantissaLSB = lshr(and_(vi16, i16_val(0x0080)), i16_val(7)); - Value roundingBias = add(remainingMantissaLSB, i16_val(baseRoundingBias)); - Value vFp8 = add(vi16, roundingBias); + Value remainingMantissaLSB = + b.lshr(b.and_(vi16, b.i16_val(0x0080)), b.i16_val(7)); + Value roundingBias = b.add(remainingMantissaLSB, b.i16_val(baseRoundingBias)); + Value vFp8 = b.add(vi16, roundingBias); // Reduce mantissa to 3 bits - vFp8 = and_(vFp8, i16_val(0xFF80)); // 0xFF80 == 1.11111.1110000000 + vFp8 = b.and_(vFp8, 
b.i16_val(0xFF80)); // 0xFF80 == 1.11111.1110000000 // 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal // number in FP8E4M3FN. We round numbers smaller than that to 0x2400 to make // it easier to handle subnormals - vFp8 = umax(vFp8, i16_val(0x2400)); + vFp8 = b.umax(vFp8, b.i16_val(0x2400)); // Adjust exponent bias - vFp8 = sub(vFp8, i16_val(0x2000)); // (15 - 7) << 10 + vFp8 = b.sub(vFp8, b.i16_val(0x2000)); // (15 - 7) << 10 // Shift right and truncate - vFp8 = trunc(i8_ty, lshr(vFp8, i16_val(7))); // 10 - 3 + vFp8 = b.trunc(i8_ty, b.lshr(vFp8, b.i16_val(7))); // 10 - 3 // 0x5F7F == 0.10111.1101111111 is the largest possible normal // number(including infinity) after rounding in FP8 @@ -139,8 +144,8 @@ Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc, // In saturation mode, numbers larger than the max normal number(including // infinity) in FP8 after rounding will be replaced with max_E4M3, i.e. 0x7E // === 0.1111.110 - Value isOverflowOrInf = icmp_ugt(vi16, i16_val(0x5F7F)); - vFp8 = select(isOverflowOrInf, i8_val(0x7E), vFp8); + Value isOverflowOrInf = b.icmp_ugt(vi16, b.i16_val(0x5F7F)); + vFp8 = b.select(isOverflowOrInf, b.i8_val(0x7E), vFp8); // Round subnormals to nearest even. Ref: // https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272 @@ -151,19 +156,19 @@ Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc, for (int i = lutSize - 1; i >= 0; i--) { Value cmp; if (i % 2 == 0) { - cmp = icmp_ule(vi16, i16_val(halfwayPointsLUT[i])); + cmp = b.icmp_ule(vi16, b.i16_val(halfwayPointsLUT[i])); } else { - cmp = icmp_ult(vi16, i16_val(halfwayPointsLUT[i])); + cmp = b.icmp_ult(vi16, b.i16_val(halfwayPointsLUT[i])); } - vFp8 = select(cmp, i8_val(i), vFp8); + vFp8 = b.select(cmp, b.i8_val(i), vFp8); } // NaN remains NaN after conversion - vFp8 = select(isNaN, i8_val(0x7F), vFp8); + vFp8 = b.select(isNaN, b.i8_val(0x7F), vFp8); // Set sign bit - vFp8 = or_(vFp8, sign); + vFp8 = b.or_(vFp8, sign); return vFp8; } @@ -192,14 +197,15 @@ static SmallVector cvtFp8ToFp32(Location loc, ConversionPatternRewriter &rewriter, Value v0, Value v1, const std::string &fp8_format) { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(fp8_format == "fp8" || fp8_format == "bf8"); std::string ins_str = "v_cvt_pk_f32_" + fp8_format; auto fp8x4VecTy = vec_ty(i8_ty, 4); - Value fp8x4Vec = undef(fp8x4VecTy); - fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0)); - fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1)); - auto i32v = bitcast(fp8x4Vec, i32_ty); + Value fp8x4Vec = b.undef(fp8x4VecTy); + fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v0, b.i32_val(0)); + fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, b.i32_val(1)); + auto i32v = b.bitcast(fp8x4Vec, i32_ty); GCNBuilder builder1; auto &cvt = *builder1.create(ins_str); @@ -208,11 +214,11 @@ static SmallVector cvtFp8ToFp32(Location loc, cvt(res, operand); auto i64v = builder1.launch(rewriter, loc, i64_ty, false); auto fp32x2VecTy = vec_ty(f32_ty, 2); - auto fp32x2Vec = bitcast(i64v, fp32x2VecTy); + auto fp32x2Vec = b.bitcast(i64v, fp32x2VecTy); SmallVector ret(2); - ret[0] = extract_element(f32_ty, fp32x2Vec, i32_val(0)); - ret[1] = extract_element(f32_ty, fp32x2Vec, i32_val(1)); + ret[0] = b.extract_element(f32_ty, fp32x2Vec, b.i32_val(0)); + ret[1] = b.extract_element(f32_ty, fp32x2Vec, b.i32_val(1)); return ret; } @@ -222,6 +228,7 @@ static SmallVector cvtFp32ToFp8(Location loc, ConversionPatternRewriter &rewriter, Value v0, Value v1, const std::string &fp8_format) { + 
auto b = TritonLLVMOpBuilder(loc, rewriter); assert(fp8_format == "fp8" || fp8_format == "bf8"); std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32"; @@ -234,11 +241,11 @@ static SmallVector cvtFp32ToFp8(Location loc, auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false); auto fp8x4VecTy = vec_ty(i8_ty, 4); - auto a1 = bitcast(fp8x4Vec, fp8x4VecTy); + auto a1 = b.bitcast(fp8x4Vec, fp8x4VecTy); SmallVector ret(2); - ret[0] = extract_element(i8_ty, a1, i32_val(0)); - ret[1] = extract_element(i8_ty, a1, i32_val(1)); + ret[0] = b.extract_element(i8_ty, a1, b.i32_val(0)); + ret[1] = b.extract_element(i8_ty, a1, b.i32_val(1)); return ret; } @@ -302,29 +309,30 @@ Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, static Value Fp16_to_Fp8E5M2FNUZ_oneValue(Location loc, ConversionPatternRewriter &rewriter, Value v) { - auto vi16 = bitcast(v, i16_ty); - auto e = and_(i16_ty, vi16, int_val(16, 0x7C00)); - auto sign = and_(i16_ty, vi16, int_val(16, 0x8000)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto vi16 = b.bitcast(v, i16_ty); + auto e = b.and_(i16_ty, vi16, b.int_val(16, 0x7C00)); + auto sign = b.and_(i16_ty, vi16, b.int_val(16, 0x8000)); // normal value - auto a = and_(i16_ty, vi16, int_val(16, 0x7FFFF)); - auto a1 = add(i16_ty, a, int_val(16, 0x0400)); - auto o1 = or_(i16_ty, a1, sign); + auto a = b.and_(i16_ty, vi16, b.int_val(16, 0x7FFFF)); + auto a1 = b.add(i16_ty, a, b.int_val(16, 0x0400)); + auto o1 = b.or_(i16_ty, a1, sign); // subnormal value, e is 0 - auto m = and_(i16_ty, vi16, int_val(16, 0x03FF)); - auto m2 = shl(m, int_val(16, 1)); - auto o2 = or_(i16_ty, sign, or_(i16_ty, int_val(16, 1), m2)); + auto m = b.and_(i16_ty, vi16, b.int_val(16, 0x03FF)); + auto m2 = b.shl(m, b.int_val(16, 1)); + auto o2 = b.or_(i16_ty, sign, b.or_(i16_ty, b.int_val(16, 1), m2)); - auto e_is_zero = icmp_eq(e, int_val(16, 0)); - auto e_is_all1 = icmp_eq(e, int_val(16, 0x7C00)); + auto e_is_zero = b.icmp_eq(e, b.int_val(16, 0)); + auto e_is_all1 = b.icmp_eq(e, b.int_val(16, 0x7C00)); - auto ot = select(e_is_zero, o2, o1); - auto o = select(e_is_all1, vi16, ot); + auto ot = b.select(e_is_zero, o2, o1); + auto o = b.select(e_is_all1, vi16, ot); auto fp8x2VecTy = vec_ty(i8_ty, 2); - auto res = bitcast(o, fp8x2VecTy); + auto res = b.bitcast(o, fp8x2VecTy); - return extract_element(i8_ty, res, i32_val(1)); + return b.extract_element(i8_ty, res, b.i32_val(1)); } static SmallVector @@ -350,25 +358,26 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(AMD::ISAFamily isaFamily) { static Value Fp8E4M3FN_to_Fp16_oneValue(Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto fp8x2VecTy = vec_ty(i8_ty, 2); - Value a = undef(fp8x2VecTy); - a = insert_element(fp8x2VecTy, a, i8_val(0), i32_val(0)); - a = insert_element(fp8x2VecTy, a, v, i32_val(1)); - a = bitcast(a, i16_ty); + Value a = b.undef(fp8x2VecTy); + a = b.insert_element(fp8x2VecTy, a, b.i8_val(0), b.i32_val(0)); + a = b.insert_element(fp8x2VecTy, a, v, b.i32_val(1)); + a = b.bitcast(a, i16_ty); // Get sign and absolute value - Value sign = and_(a, i16_val(0x8000)); - a = and_(a, i16_val(0x7FFF)); + Value sign = b.and_(a, b.i16_val(0x8000)); + a = b.and_(a, b.i16_val(0x7FFF)); // Right shift 1 bit to adjust the positions of exponent and mantissa - a = lshr(a, i16_val(1)); + a = b.lshr(a, b.i16_val(1)); // Adjust exponent, (15 - 7) << 10 === 0x2000 - a = add(a, i16_val(0x2000)); + a = b.add(a, b.i16_val(0x2000)); // Check NaN - Value vAbs = and_(bitcast(v, i8_ty), i8_val(0x7F)); - a 
= select(icmp_eq(vAbs, i8_val(0x7F)), i16_val(0x7E00), a); + Value vAbs = b.and_(b.bitcast(v, i8_ty), b.i8_val(0x7F)); + a = b.select(b.icmp_eq(vAbs, b.i8_val(0x7F)), b.i16_val(0x7E00), a); // Check denorms and zero // Here we use a LUT to map S.0000.000 ~ S.0000.111 to its corresponding fp16 @@ -378,12 +387,13 @@ static Value Fp8E4M3FN_to_Fp16_oneValue(Location loc, 0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300}; for (int i = 0; i < lutSize; i++) { - a = select(icmp_eq(vAbs, i8_val(i)), i16_val(denormsAndZeroLut[i]), a); + a = b.select(b.icmp_eq(vAbs, b.i8_val(i)), b.i16_val(denormsAndZeroLut[i]), + a); } // Set sign - a = or_(a, sign); - a = bitcast(a, f16_ty); + a = b.or_(a, sign); + a = b.bitcast(a, f16_ty); return a; } @@ -400,37 +410,39 @@ static SmallVector Fp8E4M3FN_to_Fp16(Location loc, static SmallVector Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto fp8x4VecTy = vec_ty(i8_ty, 4); - Value a0 = undef(fp8x4VecTy); - a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); - a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); - a0 = bitcast(a0, i32_ty); - Value a1 = undef(fp8x4VecTy); - a1 = insert_element(fp8x4VecTy, a1, int_val(8, 0), i32_val(0)); - a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1)); - a1 = insert_element(fp8x4VecTy, a1, int_val(8, 0), i32_val(2)); - a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3)); - a1 = bitcast(a1, i32_ty); + Value a0 = b.undef(fp8x4VecTy); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(0)); + a0 = b.insert_element(fp8x4VecTy, a0, v[0], b.i32_val(1)); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(2)); + a0 = b.insert_element(fp8x4VecTy, a0, v[1], b.i32_val(3)); + a0 = b.bitcast(a0, i32_ty); + Value a1 = b.undef(fp8x4VecTy); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(0)); + a1 = b.insert_element(fp8x4VecTy, a1, v[2], b.i32_val(1)); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(2)); + a1 = b.insert_element(fp8x4VecTy, a1, v[3], b.i32_val(3)); + a1 = b.bitcast(a1, i32_ty); auto fp16x2VecTy = vec_ty(f16_ty, 2); - auto fp16x2Vec0 = bitcast(a0, fp16x2VecTy); - auto fp16x2Vec1 = bitcast(a1, fp16x2VecTy); + auto fp16x2Vec0 = b.bitcast(a0, fp16x2VecTy); + auto fp16x2Vec1 = b.bitcast(a1, fp16x2VecTy); - return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)), - extract_element(f16_ty, fp16x2Vec0, i32_val(1)), - extract_element(f16_ty, fp16x2Vec1, i32_val(0)), - extract_element(f16_ty, fp16x2Vec1, i32_val(1))}; + return {b.extract_element(f16_ty, fp16x2Vec0, b.i32_val(0)), + b.extract_element(f16_ty, fp16x2Vec0, b.i32_val(1)), + b.extract_element(f16_ty, fp16x2Vec1, b.i32_val(0)), + b.extract_element(f16_ty, fp16x2Vec1, b.i32_val(1))}; } static Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter, const Value &v) { - auto as_int16 = bitcast(v, i16_ty); - auto as_int32 = zext(i32_ty, as_int16); - auto shifted = shl(i32_ty, as_int32, i32_val(16)); - return bitcast(shifted, f32_ty); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto as_int16 = b.bitcast(v, i16_ty); + auto as_int32 = b.zext(i32_ty, as_int16); + auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16)); + return b.bitcast(shifted, f32_ty); } static Value buildGCNInstruction(Location loc, RewriterBase &rewritter, @@ -459,11 +471,12 @@ static 
Value buildGCNInstruction(Location loc, RewriterBase &rewritter, static Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter, const Value &v, const RoundingMode rounding) { - auto as_int32 = bitcast(v, i32_ty); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto as_int32 = b.bitcast(v, i32_ty); if (rounding == RoundingMode::RTZ) { - auto shifted = lshr(i32_ty, as_int32, i32_val(16)); - auto truncated = trunc(i16_ty, shifted); - return bitcast(truncated, bf16_ty); + auto shifted = b.lshr(i32_ty, as_int32, b.i32_val(16)); + auto truncated = b.trunc(i16_ty, shifted); + return b.bitcast(truncated, bf16_ty); } // This implementation is a faster version for fp32 to bf16 type conversion @@ -476,63 +489,64 @@ static Value convertFp32ToBf16(Location loc, Value isNan = buildGCNInstruction(loc, rewriter, "v_cmp_u_f32", constraints0, vals0, i64_ty); - Value v16 = i32_val(16); - Value v1 = i32_val(1); + Value v16 = b.i32_val(16); + Value v1 = b.i32_val(1); SmallVector constraints1 = {"=v", "v", "v", "v"}; SmallVector vals1 = {v, v16, v1}; Value tmp = buildGCNInstruction(loc, rewriter, "v_bfe_u32", constraints1, vals1, i32_ty); SmallVector constraints2 = {"=v", "v", "v", "v"}; - Value v7FFF = i32_val(0x7FFF); + Value v7FFF = b.i32_val(0x7FFF); SmallVector vals2 = {v, tmp, v7FFF}; Value tmp1 = buildGCNInstruction(loc, rewriter, "v_add3_u32", constraints2, vals2, i32_ty); SmallVector constraints3 = {"=v", "v", "v", "s"}; - Value vNan = i32_val(0x7FFF0000); + Value vNan = b.i32_val(0x7FFF0000); SmallVector vals3 = {tmp1, vNan, isNan}; Value cndMask = buildGCNInstruction(loc, rewriter, "v_cndmask_b32", constraints3, vals3, i32_ty); - Value shifted = lshr(i32_ty, cndMask, v16); - Value truncated = trunc(i16_ty, shifted); - return bitcast(truncated, bf16_ty); + Value shifted = b.lshr(i32_ty, cndMask, v16); + Value truncated = b.trunc(i16_ty, shifted); + return b.bitcast(truncated, bf16_ty); } static Value Fp8E5M2FNUZ_to_Fp16_oneValue(Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto fp8x2VecTy = vec_ty(i8_ty, 2); - Value a = undef(fp8x2VecTy); - a = insert_element(fp8x2VecTy, a, int_val(8, 0), i32_val(0)); - a = insert_element(fp8x2VecTy, a, v, i32_val(1)); - a = bitcast(a, i16_ty); + Value a = b.undef(fp8x2VecTy); + a = b.insert_element(fp8x2VecTy, a, b.int_val(8, 0), b.i32_val(0)); + a = b.insert_element(fp8x2VecTy, a, v, b.i32_val(1)); + a = b.bitcast(a, i16_ty); - auto e = and_(i16_ty, a, int_val(16, 0x7C00)); - auto m = and_(i16_ty, a, int_val(16, 0x0300)); - auto sign = and_(i16_ty, a, int_val(16, 0x8000)); + auto e = b.and_(i16_ty, a, b.int_val(16, 0x7C00)); + auto m = b.and_(i16_ty, a, b.int_val(16, 0x0300)); + auto sign = b.and_(i16_ty, a, b.int_val(16, 0x8000)); // check whether all exponents are zeros - auto e_is_zero = icmp_eq(e, int_val(16, 0x0)); + auto e_is_zero = b.icmp_eq(e, b.int_val(16, 0x0)); // case 1, e is zero, need to move m right by 1 bit - auto m1 = lshr(i16_ty, m, int_val(16, 1)); - auto o0 = or_(i16_ty, sign, m1); + auto m1 = b.lshr(i16_ty, m, b.int_val(16, 1)); + auto o0 = b.or_(i16_ty, sign, m1); // case 2, e is nonzero, sub exponent by 1 - auto e1 = sub(i16_ty, e, int_val(16, 0x0400)); + auto e1 = b.sub(i16_ty, e, b.int_val(16, 0x0400)); - auto e_is_one = icmp_eq(e, int_val(16, 0x0400)); - auto m2 = add(i16_ty, m1, int_val(16, 0x0200)); + auto e_is_one = b.icmp_eq(e, b.int_val(16, 0x0400)); + auto m2 = b.add(i16_ty, m1, b.int_val(16, 0x0200)); - auto o1 = or_(i16_ty, sign, or_(i16_ty, m, e1)); 
- auto o2 = or_(i16_ty, sign, m2); + auto o1 = b.or_(i16_ty, sign, b.or_(i16_ty, m, e1)); + auto o2 = b.or_(i16_ty, sign, m2); - auto o12 = select(e_is_one, o2, o1); - auto o = select(e_is_zero, o0, o12); + auto o12 = b.select(e_is_one, o2, o1); + auto o = b.select(e_is_zero, o0, o12); - return bitcast(o, f16_ty); + return b.bitcast(o, f16_ty); } static SmallVector @@ -558,143 +572,149 @@ ConverterT Fp8E5M2FNUZ_to_Fp16(AMD::ISAFamily isaFamily) { static SmallVector Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto fp8x4VecTy = vec_ty(i8_ty, 4); - Value a0 = undef(fp8x4VecTy); - a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); - a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); - a0 = bitcast(a0, i32_ty); - - Value a1 = undef(fp8x4VecTy); - a1 = insert_element(fp8x4VecTy, a1, int_val(8, 0), i32_val(0)); - a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1)); - a1 = insert_element(fp8x4VecTy, a1, int_val(8, 0), i32_val(2)); - a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3)); - a1 = bitcast(a1, i32_ty); - - Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); - Value b1 = and_(i32_ty, a1, i32_val(0x7fff7fff)); - b0 = lshr(i32_ty, b0, i32_val(3)); - b1 = lshr(i32_ty, b1, i32_val(3)); - - Value c0 = shl(i32_ty, b0, i32_val(16)); - Value c1 = and_(i32_ty, b0, i32_val(0xFFFF0000)); - Value c2 = shl(i32_ty, b1, i32_val(16)); - Value c3 = and_(i32_ty, b1, i32_val(0xFFFF0000)); - - c0 = bitcast(c0, f32_ty); - c1 = bitcast(c1, f32_ty); - c2 = bitcast(c2, f32_ty); - c3 = bitcast(c3, f32_ty); - - Value d0 = fmul(f32_ty, c0, f32_val(0x1p+112)); - Value d1 = fmul(f32_ty, c1, f32_val(0x1p+112)); - Value d2 = fmul(f32_ty, c2, f32_val(0x1p+112)); - Value d3 = fmul(f32_ty, c3, f32_val(0x1p+112)); - - d0 = bitcast(d0, i32_ty); - d1 = bitcast(d1, i32_ty); - d2 = bitcast(d2, i32_ty); - d3 = bitcast(d3, i32_ty); - - Value out0 = or_(i32_ty, lshr(i32_ty, d0, i32_val(16)), d1); - Value out1 = or_(i32_ty, lshr(i32_ty, d2, i32_val(16)), d3); - - Value sign0 = and_(i32_ty, a0, i32_val(0x80008000)); - Value sign1 = and_(i32_ty, a1, i32_val(0x80008000)); - - out0 = or_(i32_ty, out0, sign0); - out1 = or_(i32_ty, out1, sign1); + Value a0 = b.undef(fp8x4VecTy); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(0)); + a0 = b.insert_element(fp8x4VecTy, a0, v[0], b.i32_val(1)); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(2)); + a0 = b.insert_element(fp8x4VecTy, a0, v[1], b.i32_val(3)); + a0 = b.bitcast(a0, i32_ty); + + Value a1 = b.undef(fp8x4VecTy); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(0)); + a1 = b.insert_element(fp8x4VecTy, a1, v[2], b.i32_val(1)); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(2)); + a1 = b.insert_element(fp8x4VecTy, a1, v[3], b.i32_val(3)); + a1 = b.bitcast(a1, i32_ty); + + Value b0 = b.and_(i32_ty, a0, b.i32_val(0x7fff7fff)); + Value b1 = b.and_(i32_ty, a1, b.i32_val(0x7fff7fff)); + b0 = b.lshr(i32_ty, b0, b.i32_val(3)); + b1 = b.lshr(i32_ty, b1, b.i32_val(3)); + + Value c0 = b.shl(i32_ty, b0, b.i32_val(16)); + Value c1 = b.and_(i32_ty, b0, b.i32_val(0xFFFF0000)); + Value c2 = b.shl(i32_ty, b1, b.i32_val(16)); + Value c3 = b.and_(i32_ty, b1, b.i32_val(0xFFFF0000)); + + c0 = b.bitcast(c0, f32_ty); + c1 = b.bitcast(c1, f32_ty); + c2 = b.bitcast(c2, f32_ty); + c3 = b.bitcast(c3, 
f32_ty); + + Value d0 = b.fmul(f32_ty, c0, b.f32_val(0x1p+112)); + Value d1 = b.fmul(f32_ty, c1, b.f32_val(0x1p+112)); + Value d2 = b.fmul(f32_ty, c2, b.f32_val(0x1p+112)); + Value d3 = b.fmul(f32_ty, c3, b.f32_val(0x1p+112)); + + d0 = b.bitcast(d0, i32_ty); + d1 = b.bitcast(d1, i32_ty); + d2 = b.bitcast(d2, i32_ty); + d3 = b.bitcast(d3, i32_ty); + + Value out0 = b.or_(i32_ty, b.lshr(i32_ty, d0, b.i32_val(16)), d1); + Value out1 = b.or_(i32_ty, b.lshr(i32_ty, d2, b.i32_val(16)), d3); + + Value sign0 = b.and_(i32_ty, a0, b.i32_val(0x80008000)); + Value sign1 = b.and_(i32_ty, a1, b.i32_val(0x80008000)); + + out0 = b.or_(i32_ty, out0, sign0); + out1 = b.or_(i32_ty, out1, sign1); auto bf16x2VecTy = vec_ty(bf16_ty, 2); - out0 = bitcast(out0, bf16x2VecTy); - out1 = bitcast(out1, bf16x2VecTy); + out0 = b.bitcast(out0, bf16x2VecTy); + out1 = b.bitcast(out1, bf16x2VecTy); - return {extract_element(bf16_ty, out0, i32_val(0)), - extract_element(bf16_ty, out0, i32_val(1)), - extract_element(bf16_ty, out1, i32_val(0)), - extract_element(bf16_ty, out1, i32_val(1))}; + return {b.extract_element(bf16_ty, out0, b.i32_val(0)), + b.extract_element(bf16_ty, out0, b.i32_val(1)), + b.extract_element(bf16_ty, out1, b.i32_val(0)), + b.extract_element(bf16_ty, out1, b.i32_val(1))}; } static SmallVector Bf16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto bf16x2VecTy = vec_ty(bf16_ty, 2); - Value bf16x2Vec0 = undef(bf16x2VecTy); - Value bf16x2Vec1 = undef(bf16x2VecTy); - bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0)); - bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[1], i32_val(1)); - bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[2], i32_val(0)); - bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[3], i32_val(1)); - bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty); - bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty); - - Value sign0 = and_(i32_ty, bf16x2Vec0, i32_val(0x80008000)); - Value sign1 = and_(i32_ty, bf16x2Vec1, i32_val(0x80008000)); + Value bf16x2Vec0 = b.undef(bf16x2VecTy); + Value bf16x2Vec1 = b.undef(bf16x2VecTy); + bf16x2Vec0 = b.insert_element(bf16x2VecTy, bf16x2Vec0, v[0], b.i32_val(0)); + bf16x2Vec0 = b.insert_element(bf16x2VecTy, bf16x2Vec0, v[1], b.i32_val(1)); + bf16x2Vec1 = b.insert_element(bf16x2VecTy, bf16x2Vec1, v[2], b.i32_val(0)); + bf16x2Vec1 = b.insert_element(bf16x2VecTy, bf16x2Vec1, v[3], b.i32_val(1)); + bf16x2Vec0 = b.bitcast(bf16x2Vec0, i32_ty); + bf16x2Vec1 = b.bitcast(bf16x2Vec1, i32_ty); + + Value sign0 = b.and_(i32_ty, bf16x2Vec0, b.i32_val(0x80008000)); + Value sign1 = b.and_(i32_ty, bf16x2Vec1, b.i32_val(0x80008000)); auto fp8x4VecTy = vec_ty(i8_ty, 4); - Value sign = undef(fp8x4VecTy); - sign0 = bitcast(sign0, fp8x4VecTy); - sign1 = bitcast(sign1, fp8x4VecTy); - sign = insert_element(fp8x4VecTy, sign, - extract_element(i8_ty, sign0, i32_val(1)), i32_val(0)); - sign = insert_element(fp8x4VecTy, sign, - extract_element(i8_ty, sign0, i32_val(3)), i32_val(1)); - sign = insert_element(fp8x4VecTy, sign, - extract_element(i8_ty, sign1, i32_val(1)), i32_val(2)); - sign = insert_element(fp8x4VecTy, sign, - extract_element(i8_ty, sign1, i32_val(3)), i32_val(3)); - sign = bitcast(sign, i32_ty); - - Value nosign0 = and_(i32_ty, bf16x2Vec0, i32_val(0x7fff7fff)); - Value nosign1 = and_(i32_ty, bf16x2Vec1, i32_val(0x7fff7fff)); - - Value nosign_0_0 = and_(i32_ty, nosign0, i32_val(0xffff0000)); - nosign_0_0 = umax(i32_ty, nosign_0_0, i32_val(0x38000000)); - nosign_0_0 = 
umin(i32_ty, nosign_0_0, i32_val(0x57e00000)); - Value nosign_0_1 = and_(i32_ty, nosign0, i32_val(0x0000ffff)); - nosign_0_1 = umax(i32_ty, nosign_0_1, i32_val(0x3800)); - nosign_0_1 = umin(i32_ty, nosign_0_1, i32_val(0x57e0)); - nosign0 = or_(i32_ty, nosign_0_0, nosign_0_1); - - Value nosign_1_0 = and_(i32_ty, nosign1, i32_val(0xffff0000)); - nosign_1_0 = umax(i32_ty, nosign_1_0, i32_val(0x38000000)); - nosign_1_0 = umin(i32_ty, nosign_1_0, i32_val(0x57e00000)); - Value nosign_1_1 = and_(i32_ty, nosign1, i32_val(0x0000ffff)); - nosign_1_1 = umax(i32_ty, nosign_1_1, i32_val(0x3800)); - nosign_1_1 = umin(i32_ty, nosign_1_1, i32_val(0x57e0)); - nosign1 = or_(i32_ty, nosign_1_0, nosign_1_1); - - nosign0 = add(i32_ty, nosign0, i32_val(0x00100010)); - nosign1 = add(i32_ty, nosign1, i32_val(0x00100010)); - nosign0 = sub(i32_ty, nosign0, i32_val(0x38003800)); - nosign1 = sub(i32_ty, nosign1, i32_val(0x38003800)); - nosign0 = shl(i32_ty, nosign0, i32_val(3)); - nosign1 = shl(i32_ty, nosign1, i32_val(3)); - - nosign0 = bitcast(nosign0, fp8x4VecTy); - nosign1 = bitcast(nosign1, fp8x4VecTy); - Value nosign = undef(fp8x4VecTy); - nosign = - insert_element(fp8x4VecTy, nosign, - extract_element(i8_ty, nosign0, i32_val(1)), i32_val(0)); - nosign = - insert_element(fp8x4VecTy, nosign, - extract_element(i8_ty, nosign0, i32_val(3)), i32_val(1)); - nosign = - insert_element(fp8x4VecTy, nosign, - extract_element(i8_ty, nosign1, i32_val(1)), i32_val(2)); - nosign = - insert_element(fp8x4VecTy, nosign, - extract_element(i8_ty, nosign1, i32_val(3)), i32_val(3)); - nosign = bitcast(nosign, i32_ty); - - Value fp8x4Vec = or_(i32_ty, nosign, sign); - fp8x4Vec = bitcast(fp8x4Vec, fp8x4VecTy); - return {extract_element(i8_ty, fp8x4Vec, i32_val(0)), - extract_element(i8_ty, fp8x4Vec, i32_val(1)), - extract_element(i8_ty, fp8x4Vec, i32_val(2)), - extract_element(i8_ty, fp8x4Vec, i32_val(3))}; + Value sign = b.undef(fp8x4VecTy); + sign0 = b.bitcast(sign0, fp8x4VecTy); + sign1 = b.bitcast(sign1, fp8x4VecTy); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign0, b.i32_val(1)), + b.i32_val(0)); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign0, b.i32_val(3)), + b.i32_val(1)); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign1, b.i32_val(1)), + b.i32_val(2)); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign1, b.i32_val(3)), + b.i32_val(3)); + sign = b.bitcast(sign, i32_ty); + + Value nosign0 = b.and_(i32_ty, bf16x2Vec0, b.i32_val(0x7fff7fff)); + Value nosign1 = b.and_(i32_ty, bf16x2Vec1, b.i32_val(0x7fff7fff)); + + Value nosign_0_0 = b.and_(i32_ty, nosign0, b.i32_val(0xffff0000)); + nosign_0_0 = b.umax(i32_ty, nosign_0_0, b.i32_val(0x38000000)); + nosign_0_0 = b.umin(i32_ty, nosign_0_0, b.i32_val(0x57e00000)); + Value nosign_0_1 = b.and_(i32_ty, nosign0, b.i32_val(0x0000ffff)); + nosign_0_1 = b.umax(i32_ty, nosign_0_1, b.i32_val(0x3800)); + nosign_0_1 = b.umin(i32_ty, nosign_0_1, b.i32_val(0x57e0)); + nosign0 = b.or_(i32_ty, nosign_0_0, nosign_0_1); + + Value nosign_1_0 = b.and_(i32_ty, nosign1, b.i32_val(0xffff0000)); + nosign_1_0 = b.umax(i32_ty, nosign_1_0, b.i32_val(0x38000000)); + nosign_1_0 = b.umin(i32_ty, nosign_1_0, b.i32_val(0x57e00000)); + Value nosign_1_1 = b.and_(i32_ty, nosign1, b.i32_val(0x0000ffff)); + nosign_1_1 = b.umax(i32_ty, nosign_1_1, b.i32_val(0x3800)); + nosign_1_1 = b.umin(i32_ty, nosign_1_1, b.i32_val(0x57e0)); + nosign1 = b.or_(i32_ty, nosign_1_0, nosign_1_1); + + nosign0 = b.add(i32_ty, 
nosign0, b.i32_val(0x00100010)); + nosign1 = b.add(i32_ty, nosign1, b.i32_val(0x00100010)); + nosign0 = b.sub(i32_ty, nosign0, b.i32_val(0x38003800)); + nosign1 = b.sub(i32_ty, nosign1, b.i32_val(0x38003800)); + nosign0 = b.shl(i32_ty, nosign0, b.i32_val(3)); + nosign1 = b.shl(i32_ty, nosign1, b.i32_val(3)); + + nosign0 = b.bitcast(nosign0, fp8x4VecTy); + nosign1 = b.bitcast(nosign1, fp8x4VecTy); + Value nosign = b.undef(fp8x4VecTy); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign0, b.i32_val(1)), + b.i32_val(0)); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign0, b.i32_val(3)), + b.i32_val(1)); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign1, b.i32_val(1)), + b.i32_val(2)); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign1, b.i32_val(3)), + b.i32_val(3)); + nosign = b.bitcast(nosign, i32_ty); + + Value fp8x4Vec = b.or_(i32_ty, nosign, sign); + fp8x4Vec = b.bitcast(fp8x4Vec, fp8x4VecTy); + return {b.extract_element(i8_ty, fp8x4Vec, b.i32_val(0)), + b.extract_element(i8_ty, fp8x4Vec, b.i32_val(1)), + b.extract_element(i8_ty, fp8x4Vec, b.i32_val(2)), + b.extract_element(i8_ty, fp8x4Vec, b.i32_val(3))}; } //===-----------------------------------------===// @@ -705,35 +725,36 @@ static SmallVector Bf16_to_Fp8E5M2(Location loc, static SmallVector Fp8E4M3FN_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto fp8x4VecTy = vec_ty(i8_ty, 4); - Value a0 = undef(fp8x4VecTy); - a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); - a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); - a0 = bitcast(a0, i32_ty); - - Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); - b0 = lshr(i32_ty, b0, i32_val(4)); - - Value c0 = shl(i32_ty, b0, i32_val(16)); - Value c1 = and_(i32_ty, b0, i32_val(0xFFFF0000)); - c0 = bitcast(c0, f32_ty); - c1 = bitcast(c1, f32_ty); - - Value d0 = fmul(f32_ty, c0, f32_val(0x1p+120)); // bias 2**(127-7) - Value d1 = fmul(f32_ty, c1, f32_val(0x1p+120)); - d0 = bitcast(d0, i32_ty); - d1 = bitcast(d1, i32_ty); - - Value out0 = or_(i32_ty, lshr(i32_ty, d0, i32_val(16)), d1); - Value sign0 = and_(i32_ty, a0, i32_val(0x80008000)); - out0 = or_(i32_ty, out0, sign0); + Value a0 = b.undef(fp8x4VecTy); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(0)); + a0 = b.insert_element(fp8x4VecTy, a0, v[0], b.i32_val(1)); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(2)); + a0 = b.insert_element(fp8x4VecTy, a0, v[1], b.i32_val(3)); + a0 = b.bitcast(a0, i32_ty); + + Value b0 = b.and_(i32_ty, a0, b.i32_val(0x7fff7fff)); + b0 = b.lshr(i32_ty, b0, b.i32_val(4)); + + Value c0 = b.shl(i32_ty, b0, b.i32_val(16)); + Value c1 = b.and_(i32_ty, b0, b.i32_val(0xFFFF0000)); + c0 = b.bitcast(c0, f32_ty); + c1 = b.bitcast(c1, f32_ty); + + Value d0 = b.fmul(f32_ty, c0, b.f32_val(0x1p+120)); // bias 2**(127-7) + Value d1 = b.fmul(f32_ty, c1, b.f32_val(0x1p+120)); + d0 = b.bitcast(d0, i32_ty); + d1 = b.bitcast(d1, i32_ty); + + Value out0 = b.or_(i32_ty, b.lshr(i32_ty, d0, b.i32_val(16)), d1); + Value sign0 = b.and_(i32_ty, a0, b.i32_val(0x80008000)); + out0 = b.or_(i32_ty, out0, sign0); auto bf16x2VecTy = vec_ty(bf16_ty, 2); - out0 = bitcast(out0, bf16x2VecTy); - return {extract_element(bf16_ty, out0, i32_val(0)), - 
extract_element(bf16_ty, out0, i32_val(1))}; + out0 = b.bitcast(out0, bf16x2VecTy); + return {b.extract_element(bf16_ty, out0, b.i32_val(0)), + b.extract_element(bf16_ty, out0, b.i32_val(1))}; } /****************************************************************************/ @@ -783,33 +804,34 @@ Bf16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter, static Value Fp8E4M3FNUZ_to_Fp16_oneValue(Location loc, ConversionPatternRewriter &rewriter, Value v) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto fp8x2VecTy = vec_ty(i8_ty, 2); - Value a = undef(fp8x2VecTy); - a = insert_element(fp8x2VecTy, a, int_val(8, 0), i32_val(0)); - a = insert_element(fp8x2VecTy, a, v, i32_val(1)); - a = bitcast(a, i16_ty); + Value a = tb.undef(fp8x2VecTy); + a = tb.insert_element(fp8x2VecTy, a, tb.int_val(8, 0), tb.i32_val(0)); + a = tb.insert_element(fp8x2VecTy, a, v, tb.i32_val(1)); + a = tb.bitcast(a, i16_ty); - auto e_mask = int_val(16, 0x7A00); - auto e = and_(i16_ty, a, e_mask); + auto e_mask = tb.int_val(16, 0x7A00); + auto e = tb.and_(i16_ty, a, e_mask); - auto m = and_(i16_ty, a, int_val(16, 0x0700)); - auto sign = and_(i16_ty, a, int_val(16, 0x8000)); + auto m = tb.and_(i16_ty, a, tb.int_val(16, 0x0700)); + auto sign = tb.and_(i16_ty, a, tb.int_val(16, 0x8000)); // check whether all exponents are zeros - auto e_is_zero = icmp_eq(e, int_val(16, 0x0)); - auto b = and_(i16_ty, a, int_val(16, 0x7FFF)); - auto b1 = lshr(i16_ty, b, int_val(16, 1)); + auto e_is_zero = tb.icmp_eq(e, tb.int_val(16, 0x0)); + auto b = tb.and_(i16_ty, a, tb.int_val(16, 0x7FFF)); + auto b1 = tb.lshr(i16_ty, b, tb.int_val(16, 1)); // case 1, e is nonzero, add exponent by 6 - auto o0v = add(i16_ty, b1, int_val(16, 0x0C00)); - auto o0 = or_(i16_ty, o0v, sign); + auto o0v = tb.add(i16_ty, b1, tb.int_val(16, 0x0C00)); + auto o0 = tb.or_(i16_ty, o0v, sign); // case 2, e is nonzero, add exponent by 7 - auto o1v = add(i16_ty, b1, int_val(16, 0x1C00)); - auto o1 = or_(i16_ty, o1v, sign); + auto o1v = tb.add(i16_ty, b1, tb.int_val(16, 0x1C00)); + auto o1 = tb.or_(i16_ty, o1v, sign); - auto io = select(e_is_zero, o0, o1); - return bitcast(io, f16_ty); + auto io = tb.select(e_is_zero, o0, o1); + return tb.bitcast(io, f16_ty); } static SmallVector @@ -836,34 +858,35 @@ static ConverterT Fp8E4M3FNUZ_to_Fp16(AMD::ISAFamily isaFamily) { static Value Fp16_to_Fp8E4M3FNUZ_oneValue(Location loc, ConversionPatternRewriter &rewriter, Value v) { - auto vi16 = bitcast(v, i16_ty); - auto e10 = and_(vi16, int_val(16, 0x7C00)); - auto e = lshr(i16_ty, e10, int_val(16, 10)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto vi16 = b.bitcast(v, i16_ty); + auto e10 = b.and_(vi16, b.int_val(16, 0x7C00)); + auto e = b.lshr(i16_ty, e10, b.int_val(16, 10)); - auto s = and_(i16_ty, vi16, int_val(16, 0x8000)); + auto s = b.and_(i16_ty, vi16, b.int_val(16, 0x8000)); - auto m7 = and_(i16_ty, vi16, int_val(16, 0x0380)); - auto m = shl(i16_ty, m7, int_val(16, 1)); + auto m7 = b.and_(i16_ty, vi16, b.int_val(16, 0x0380)); + auto m = b.shl(i16_ty, m7, b.int_val(16, 1)); // three cases: // 1) e > 21 --> e = 1111, // 2) e <= 7 ---> e = 0, // 3) others, normal conversion - auto e1 = int_val(16, 0x7800); - auto e2 = int_val(16, 0x0); - auto e31 = sub(i16_ty, e10, int_val(16, 0x1C00)); - auto e3 = shl(i16_ty, e31, int_val(16, 1)); + auto e1 = b.int_val(16, 0x7800); + auto e2 = b.int_val(16, 0x0); + auto e31 = b.sub(i16_ty, e10, b.int_val(16, 0x1C00)); + auto e3 = b.shl(i16_ty, e31, b.int_val(16, 1)); - auto c13 = icmp_sgt(e, int_val(16, 21)); - auto e13 = 
select(c13, e1, e3); - auto c23 = icmp_sle(e, int_val(16, 7)); - auto re = select(c23, e2, e13); + auto c13 = b.icmp_sgt(e, b.int_val(16, 21)); + auto e13 = b.select(c13, e1, e3); + auto c23 = b.icmp_sle(e, b.int_val(16, 7)); + auto re = b.select(c23, e2, e13); - auto r = or_(i16_ty, s, or_(i16_ty, re, m)); + auto r = b.or_(i16_ty, s, b.or_(i16_ty, re, m)); auto fp8x2VecTy = vec_ty(i8_ty, 2); - auto res = bitcast(r, fp8x2VecTy); + auto res = b.bitcast(r, fp8x2VecTy); - return extract_element(i8_ty, res, i32_val(1)); + return b.extract_element(i8_ty, res, b.i32_val(1)); } static SmallVector @@ -991,6 +1014,7 @@ struct FpToFpOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcElementType = getElementType(op.getSrc()); auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); @@ -1055,7 +1079,7 @@ struct FpToFpOpConversion for (Value &v : inVals) v = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, v, roundingMode.value_or(RoundingMode::RTNE)); - inVals.resize(numElements, undef(typeConverter->convertType(srcType))); + inVals.resize(numElements, b.undef(typeConverter->convertType(srcType))); SmallVector outVals; if (srcType != dstType) { auto getCvtFunc = getConversionFunc(srcType, dstType, roundingMode); @@ -1184,10 +1208,11 @@ struct FSubOpConversion static SmallVector S8_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector inValues = {v[0], v[1], v[2], v[3]}; SmallVector outValues = {}; for (Value inVal : inValues) { - Value i32Val = sext(i32_ty, inVal); + Value i32Val = b.sext(i32_ty, inVal); GCNBuilder builder; auto &cvt = *builder.create("v_cvt_f32_i32"); @@ -1196,10 +1221,10 @@ static SmallVector S8_to_Bf16(Location loc, cvt(res, operand); auto f32Val = builder.launch(rewriter, loc, f32_ty, false); - f32Val = bitcast(f32Val, i32_ty); - auto shifted = lshr(i32_ty, f32Val, i32_val(16)); - auto truncated = trunc(i16_ty, shifted); - outValues.push_back(bitcast(truncated, bf16_ty)); + f32Val = b.bitcast(f32Val, i32_ty); + auto shifted = b.lshr(i32_ty, f32Val, b.i32_val(16)); + auto truncated = b.trunc(i16_ty, shifted); + outValues.push_back(b.bitcast(truncated, bf16_ty)); } return outValues; } @@ -1310,12 +1335,13 @@ struct ExpOpConversionApprox ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); // For non-FP32 input, call __ocml_exp_f64 for higher-precision calculation if (elemTy.getIntOrFloatBitWidth() != 32) return {}; const double log2e = 1.4426950408889634; - Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e)); + Value prod = b.fmul(f32_ty, operands[0][0], b.f32_val(log2e)); // Here we use llvm.exp2.f32 instead of math::Exp2Op. 
The latter // flushes denorms by default, but we want to preserve denorms by default @@ -1409,17 +1435,19 @@ struct RsqrtOpConversion static inline std::pair scaleUpIfDenorm(ConversionPatternRewriter &rewriter, Location loc, const Value &src, float scaleThreshold, float scaleFactor) { - Value needScale = fcmp_ogt(f32_val(scaleThreshold), src); - Value scaledSrc = fmul(f32_ty, src, f32_val(scaleFactor)); - Value selectedSrc = select(needScale, scaledSrc, src); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value needScale = b.fcmp_ogt(b.f32_val(scaleThreshold), src); + Value scaledSrc = b.fmul(f32_ty, src, b.f32_val(scaleFactor)); + Value selectedSrc = b.select(needScale, scaledSrc, src); return {needScale, selectedSrc}; } static inline Value scaleDownIfDenorm(ConversionPatternRewriter &rewriter, Location loc, const Value &src, Value needScale, float scaleFactor) { - Value scaledSrc = fmul(f32_ty, src, f32_val(scaleFactor)); - return select(needScale, scaledSrc, src); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value scaledSrc = b.fmul(f32_ty, src, b.f32_val(scaleFactor)); + return b.select(needScale, scaledSrc, src); } struct SqrtOpConversion @@ -1437,6 +1465,7 @@ struct SqrtOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); // This function only handles FP32 inputs. Other data types are lowered to // LLVM::SqrtOp by MLIR. // @@ -1453,7 +1482,7 @@ struct SqrtOpConversion if (elemTy.getIntOrFloatBitWidth() != 32) return {}; - Value needScale = false_val(); + Value needScale = b.false_val(); Value scaledSrc = operands[0][0]; if (!ftz) { // For non-ftz cases, if the input value is below 2^{-96}, it needs to be @@ -1509,6 +1538,7 @@ struct PreciseSqrtOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); // If the op is neither FP32 nor denorm flushing(ftz), it's directly lowered // to LLVM::SqrtOp. 
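// A standalone float sketch of the Newton refinement this pattern emits for
// FP32 inputs further below; `refineSqrt` is a hypothetical helper name, and
// `r` stands in for the reciprocal-square-root seed (sqrtR) returned by the
// library call, as the arithmetic implies.
#include <cmath>
static float refineSqrt(float x, float r /* ~= 1/sqrt(x) seed */) {
  float s = x * r;                  // initial sqrt estimate
  float h = 0.5f * r;               // half of the rsqrt estimate
  float e = std::fma(-h, s, 0.5f);  // residual 0.5 - h*s, ~0 for a good seed
  h = std::fma(h, e, h);            // refined half-rsqrt
  s = std::fma(s, e, s);            // refined sqrt
  float d = std::fma(-s, s, x);     // remaining error x - s*s
  return std::fma(d, h, s);         // corrected sqrt (before the 0/+inf fixup)
}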
if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz) { @@ -1535,15 +1565,15 @@ struct PreciseSqrtOpConversion LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult(); Value sqrtX = operands[0][0]; - Value sqrtS = fmul(f32_ty, sqrtX, sqrtR); + Value sqrtS = b.fmul(f32_ty, sqrtX, sqrtR); // Refine the approximation with Newton iteration - Value sqrtH = fmul(f32_ty, sqrtR, f32_val(0.5f)); - Value sqrtE = fma(neg(f32_ty, sqrtH), sqrtS, f32_val(0.5f)); - sqrtH = fma(sqrtH, sqrtE, sqrtH); - sqrtS = fma(sqrtS, sqrtE, sqrtS); - Value sqrtD = fma(neg(f32_ty, sqrtS), sqrtS, sqrtX); - sqrtS = fma(sqrtD, sqrtH, sqrtS); + Value sqrtH = b.fmul(f32_ty, sqrtR, b.f32_val(0.5f)); + Value sqrtE = b.fma(b.neg(f32_ty, sqrtH), sqrtS, b.f32_val(0.5f)); + sqrtH = b.fma(sqrtH, sqrtE, sqrtH); + sqrtS = b.fma(sqrtS, sqrtE, sqrtS); + Value sqrtD = b.fma(b.neg(f32_ty, sqrtS), sqrtS, sqrtX); + sqrtS = b.fma(sqrtD, sqrtH, sqrtS); // Handle +0/-0/+inf // These flags come from @@ -1555,7 +1585,7 @@ struct PreciseSqrtOpConversion Value isZeroOrPosInf = rewriter.create(loc, i1_ty, sqrtX, fcPosInf | fcZero); - return {select(isZeroOrPosInf, sqrtX, sqrtS)}; + return {b.select(isZeroOrPosInf, sqrtX, sqrtS)}; } private: diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 5ee6fdc4530d..18f5cfc68abe 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -30,9 +30,10 @@ namespace { // Used to mask out the redundant data accessed by threads. Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, Location loc, const AMD::TargetInfo &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto tensorTy = dyn_cast(valueTy); - Value mask = int_val(1, 1); - auto tid = tid_val(); + Value mask = b.int_val(1, 1); + auto tid = b.tid_val(); auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc); if (tensorTy) { auto layout = tensorTy.getEncoding(); @@ -50,9 +51,9 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, warpOrder = triton::gpu::getWarpOrder(layout); } auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout); - Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); - Value laneId = urem(tid, warpSize); - Value warpId = udiv(tid, warpSize); + Value warpSize = b.i32_val(triton::gpu::getWarpSize(layout)); + Value laneId = b.urem(tid, warpSize); + Value warpId = b.udiv(tid, warpSize); // TODO: [DOT LL] // The delinearize function is not entirely correct for certain layouts, // such as wgmma. The correct approach is to convert a legacy layout to its @@ -69,14 +70,15 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, // Otherwise, we need to mask threads that will replicate data on this // dimension. 
Calculate the thread index on this dimension for the CTA Value threadDim = - add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])), - multiDimThreadId[dim]); - mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])), - i32_val(shape[dim]))); + b.add(b.mul(multiDimWarpId[dim], b.i32_val(threadsPerWarp[dim])), + multiDimThreadId[dim]); + mask = b.and_(mask, + b.icmp_slt(b.mul(threadDim, b.i32_val(sizePerThread[dim])), + b.i32_val(shape[dim]))); } // Do not write duplicated data when multicast is enabled if (triton::gpu::getNumCTAs(layout) > 1) { - auto _0 = i32_val(0); + auto _0 = b.i32_val(0); auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); auto CTASplitNum = triton::gpu::getCTASplitNum(layout); auto CTAOrder = triton::gpu::getCTAOrder(layout); @@ -90,7 +92,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, continue; // This wrapping rule must be consistent with emitCTAOffsetForLayout unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); - Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum)); + Value repId = b.udiv(multiDimClusterCTAId[dim], b.i32_val(splitNum)); // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]: // CTA0 and CTA2 holds data of block0, // CTA1 and CTA3 holds data of block1. @@ -100,14 +102,14 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, // Actually in all existing cases of multicast, splitNum is always 1. // The mask is equivalent to: // multiDimClusterCTAId[dim] == 0 - mask = and_(mask, icmp_eq(repId, _0)); + mask = b.and_(mask, b.icmp_eq(repId, _0)); } } } else { // If the tensor is not ranked, then it is a scalar and only thread 0 of // CTA0 can write - mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0))); - mask = and_(mask, icmp_eq(tid, i32_val(0))); + mask = b.and_(mask, b.icmp_eq(clusterCTAId, b.i32_val(0))); + mask = b.and_(mask, b.icmp_eq(tid, b.i32_val(0))); } return mask; } @@ -135,14 +137,15 @@ struct LoadStoreConversionBase { const LLVMTypeConverter *typeConverter, Location loc, VectorType vecTy, ArrayRef elems, int64_t start) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); int64_t vec = vecTy.getNumElements(); // If we need to mask the loaded value with other elements - Value v = undef(vecTy); + Value v = b.undef(vecTy); for (size_t s = 0; s < vec; ++s) { Value otherElem = elems[start + s]; Value indexVal = LLVM::createIndexConstant(rewriter, loc, typeConverter, s); - v = insert_element(vecTy, v, otherElem, indexVal); + v = b.insert_element(vecTy, v, otherElem, indexVal); } return v; } @@ -220,6 +223,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); // original values Value ptr = op.getPtr(); @@ -272,7 +276,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, const size_t movWidth = width < 16 ? 16 : width; assert(wordNElems * nWords * numVecs == numElems); - Value pred = mask ? maskElems[vecStart] : int_val(1, 1); + Value pred = mask ? 
maskElems[vecStart] : b.int_val(1, 1); Value ptr = ptrElems[vecStart]; Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); @@ -287,7 +291,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); - Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + Value loaded = b.extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } } // end vec @@ -321,6 +325,7 @@ struct BufferLoadOpConversion matchAndRewrite(triton::amdgpu::BufferLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); // original values @@ -363,7 +368,7 @@ struct BufferLoadOpConversion SmallVector loadedVals; Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { - Value pred = mask ? maskElems[vecStart] : int_val(1, 1); + Value pred = mask ? maskElems[vecStart] : b.int_val(1, 1); Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); if (otherElems.size() != 0) falseVal = packElementRangeIntoVector( @@ -374,7 +379,7 @@ struct BufferLoadOpConversion for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); - Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + Value loaded = b.extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } } // end vec @@ -414,6 +419,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, Value llValue = adaptor.getValue(); auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto valueTy = value.getType(); @@ -440,7 +446,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, const int numVecs = elemsPerThread / vec; Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { - Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask; + Value pred = mask ? b.and_(maskElems[vecStart], rDataMask) : rDataMask; auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); const size_t maxWordWidth = std::max(32, valueElemNBits); @@ -498,6 +504,7 @@ struct BufferAtomicRMWOpConversion matchAndRewrite(triton::amdgpu::BufferAtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); // original values @@ -672,7 +679,7 @@ struct BufferAtomicRMWOpConversion for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); - Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask; + Value pred = mask ? 
b.and_(maskElems[vecStart], rDataMask) : rDataMask; Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); // Create the store val Value storeVal = packElementRangeIntoVector( @@ -696,7 +703,7 @@ struct BufferAtomicRMWOpConversion for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); - Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + Value loaded = b.extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } } // end vec @@ -733,6 +740,7 @@ struct BufferStoreOpConversion matchAndRewrite(triton::amdgpu::BufferStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); // original values @@ -769,7 +777,7 @@ struct BufferStoreOpConversion Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); - Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask; + Value pred = mask ? b.and_(maskElems[vecStart], rDataMask) : rDataMask; // Create the store val Value storeVal = packElementRangeIntoVector( rewriter, this->getTypeConverter(), loc, cast(vecTy), @@ -800,6 +808,7 @@ struct AtomicCASOpConversion ConversionPatternRewriter &rewriter) const override { // extract relevant info from Module auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); Value ptr = op.getPtr(); @@ -841,11 +850,11 @@ struct AtomicCASOpConversion // atomic ops for (size_t i = 0; i < elemsPerThread; i += vec) { - Value casVal = undef(vecTy); + Value casVal = b.undef(vecTy); for (int ii = 0; ii < vec; ++ii) { Value iiVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); - casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal); + casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal); } Value casPtr = ptrElements[i]; @@ -863,11 +872,12 @@ struct AtomicCASOpConversion StringRef(scopeStr.value())); // Extract the new_loaded value from the pair. - Value ret = extract_val(valueElemTy, cmpxchg, i); + Value ret = b.extract_val(valueElemTy, cmpxchg, i); for (int ii = 0; ii < vec; ++ii) { resultVals[i + ii] = - vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii)); + vec == 1 ? ret + : b.extract_element(valueElemTy, ret, b.i32_val(ii)); } } else { // for scalar // Build blocks to bypass the atomic instruction for ~rmwMask. @@ -878,8 +888,8 @@ struct AtomicCASOpConversion // Fill entry block with global memory barrier and conditional branch. rewriter.setInsertionPointToEnd(curBlock); - auto tid = tid_val(); - Value pred = icmp_eq(tid, i32_val(i)); + auto tid = b.tid_val(); + Value pred = b.icmp_eq(tid, b.i32_val(i)); rewriter.create(loc, pred, atomicBlock, endBlock); // Build main block with atomic_cmpxchg. @@ -893,10 +903,10 @@ struct AtomicCASOpConversion if (atomicNeedsSharedMemory(op.getResult())) { // Extract the new_loaded value from the pair. 
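// A minimal C++ analogue of the cmpxchg pair handling used in this pattern:
// the atomic returns a {loaded value, success flag} pair and only the loaded
// value is extracted. The helper name `casLoad` is hypothetical.
#include <atomic>
static int casLoad(std::atomic<int> &slot, int expected, int desired) {
  // On failure compare_exchange_strong writes the observed value back into
  // `expected`; on success `expected` already held it, so returning it mirrors
  // extracting element 0 of the LLVM cmpxchg result.
  slot.compare_exchange_strong(expected, desired);
  return expected;
}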
- Value newLoaded = extract_val(valueElemTy, cmpxchg, 0); + Value newLoaded = b.extract_val(valueElemTy, cmpxchg, 0); Value atomPtr = getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - store(newLoaded, atomPtr); + b.store(newLoaded, atomPtr); } rewriter.create(loc, ValueRange(), endBlock); @@ -912,10 +922,10 @@ struct AtomicCASOpConversion GCNBuilder BuilderMemfenceLDS; BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()(); BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx)); - barrier(); + b.barrier(); Value atomPtr = getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - Value ret = load(valueElemTy, atomPtr); + Value ret = b.load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret}); } } @@ -940,7 +950,8 @@ bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) { Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) { assert(val.getType().isInteger(32)); auto loc = val.getLoc(); - Value old = i32_val(0); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value old = b.i32_val(0); int rowMask = 0b1111; // enable all rows int bankMask = 0b1111; // enable all banks bool boundCtrl = false; @@ -959,19 +970,20 @@ Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) { Value generatePopcount64(PatternRewriter &rewriter, Value val) { auto loc = val.getLoc(); - Value m1 = i64_val(0x5555555555555555); // binary: 0101 0101.. - Value m2 = i64_val(0x3333333333333333); // binary: 0011 0011.. - Value m4 = i64_val(0x0f0f0f0f0f0f0f0f); // binary: 0000 1111.. + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value m1 = b.i64_val(0x5555555555555555); // binary: 0101 0101.. + Value m2 = b.i64_val(0x3333333333333333); // binary: 0011 0011.. + Value m4 = b.i64_val(0x0f0f0f0f0f0f0f0f); // binary: 0000 1111.. // binary: 0000 0001 0000 0001.. - Value h01 = i64_val(0x0101010101010101); + Value h01 = b.i64_val(0x0101010101010101); // put count of each 2 bits into those 2 bits - val = sub(val, and_(m1, lshr(val, i64_val(1)))); + val = b.sub(val, b.and_(m1, b.lshr(val, b.i64_val(1)))); // put count of each 4 bits into those 4 bits - val = add(and_(val, m2), and_(lshr(val, i64_val(2)), m2)); + val = b.add(b.and_(val, m2), b.and_(b.lshr(val, b.i64_val(2)), m2)); // put count of each 8 bits into those 8 bits - val = and_(add(val, lshr(val, i64_val(4))), m4); + val = b.and_(b.add(val, b.lshr(val, b.i64_val(4))), m4); // left 8 bits of x + (x<<8) + (x<<16) + (x<<24) + ... - return lshr(mul(val, h01), i64_val(56)); + return b.lshr(b.mul(val, h01), b.i64_val(56)); } Value genReadFirstLane(PatternRewriter &rewriter, Value v) { @@ -1001,6 +1013,7 @@ template Value genI32TiledOp(PatternRewriter &rewriter, Generator genCall, Value argToSplit, Values... 
args) { auto loc = argToSplit.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Type ty = argToSplit.getType(); size_t tySize = ty.getIntOrFloatBitWidth(); size_t i32Size = i32_ty.getIntOrFloatBitWidth(); @@ -1008,14 +1021,14 @@ Value genI32TiledOp(PatternRewriter &rewriter, Generator genCall, assert(tySize % i32Size == 0 && count > 0 && "Unalligned types are not supported yet."); Type i32VecValTy = vec_ty(i32_ty, count); - Value vec = undef(i32VecValTy); - Value valCasted = bitcast(argToSplit, i32VecValTy); + Value vec = b.undef(i32VecValTy); + Value valCasted = b.bitcast(argToSplit, i32VecValTy); for (int i = 0; i < count; i++) { - Value subVal = extract_element(i32_ty, valCasted, i32_val(i)); + Value subVal = b.extract_element(i32_ty, valCasted, b.i32_val(i)); Value result = genCall(rewriter, subVal, args...); - vec = insert_element(i32VecValTy, vec, result, i32_val(i)); + vec = b.insert_element(i32VecValTy, vec, result, b.i32_val(i)); } - return bitcast(vec, ty); + return b.bitcast(vec, ty); } struct AtomicRMWOpConversion @@ -1063,6 +1076,7 @@ struct AtomicRMWOpConversion matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto atomicRmwAttr = op.getAtomicRmwOp(); @@ -1132,12 +1146,12 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } - Value mask = int_val(1, 1); - auto tid = tid_val(); - mask = and_(mask, - icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); + Value mask = b.int_val(1, 1); + auto tid = b.tid_val(); + mask = b.and_(mask, b.icmp_slt(b.mul(tid, b.i32_val(elemsPerThread)), + b.i32_val(numElems))); if (useDppForPackedF16) - mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0))); + mask = b.and_(mask, b.icmp_eq(b.urem(tid, b.i32_val(2)), b.i32_val(0))); auto memOrdering = op.getSem(); auto scope = op.getScope(); @@ -1155,36 +1169,36 @@ struct AtomicRMWOpConversion Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. - Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; + Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask; Value operand; if (useDppForPackedF16) { // Move %val to left neighbour to proceed packed atomic further. 
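// Bit-level sketch of the packed-f16 exchange set up below: the odd lane's
// value is DPP-moved onto its even neighbour, only even lanes issue the
// packed atomic (the tid % 2 == 0 mask above), and each lane then keeps its
// half of the packed result (the return DPP move is elided here). Helper
// names are hypothetical; element 0 corresponds to the low 16 bits on
// little-endian targets.
#include <cstdint>
static uint32_t packHalves(uint16_t evenLaneVal, uint16_t oddLaneVal) {
  return uint32_t(evenLaneVal) | (uint32_t(oddLaneVal) << 16);
}
static uint16_t pickHalf(uint32_t packedResult, uint32_t tid) {
  return (tid % 2 == 0) ? uint16_t(packedResult & 0xffffu)  // even lane: element 0
                        : uint16_t(packedResult >> 16);     // odd lane: element 1
}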
- Value packedVal = null(packF16Ty); - packedVal = - insert_element(packF16Ty, packedVal, valElements[i], i32_val(0)); + Value packedVal = b.null(packF16Ty); + packedVal = b.insert_element(packF16Ty, packedVal, valElements[i], + b.i32_val(0)); // Pack to i32 type to simplify transaction - packedVal = bitcast(packedVal, i32_ty); + packedVal = b.bitcast(packedVal, i32_ty); Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal); // Unpack results back - Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty); - operand = undef(packF16Ty); + Value unpackedDppRes = b.bitcast(dppMoveRes, packF16Ty); + operand = b.undef(packF16Ty); operand = - insert_element(packF16Ty, operand, valElements[i], i32_val(0)); - operand = insert_element( + b.insert_element(packF16Ty, operand, valElements[i], b.i32_val(0)); + operand = b.insert_element( packF16Ty, operand, - extract_element(valueElemTy, unpackedDppRes, i32_val(0)), - i32_val(1)); + b.extract_element(valueElemTy, unpackedDppRes, b.i32_val(0)), + b.i32_val(1)); } else if (vec == 1) { operand = valElements[i]; } else { - operand = undef(vecTy); + operand = b.undef(vecTy); for (size_t ii = 0; ii < vec; ++ii) - operand = - insert_element(vecTy, operand, valElements[i + ii], i32_val(ii)); + operand = b.insert_element(vecTy, operand, valElements[i + ii], + b.i32_val(ii)); } - Value undefVal = undef(retType); + Value undefVal = b.undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); @@ -1214,7 +1228,7 @@ struct AtomicRMWOpConversion if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - store(atom, atomPtr); + b.store(atom, atomPtr); } } rewriter.create(loc, atom, endBlock); @@ -1225,21 +1239,22 @@ struct AtomicRMWOpConversion if (useDppForPackedF16) { // Return packed to i32 result after atomic operation back from master // lane. - auto packedRet = bitcast(retVal, i32_ty); + auto packedRet = b.bitcast(retVal, i32_ty); Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet); // Unpack results back - Value unpackedDppRes = bitcast(dppMovRes, packF16Ty); - retVal = insert_element( + Value unpackedDppRes = b.bitcast(dppMovRes, packF16Ty); + retVal = b.insert_element( packF16Ty, retVal, - extract_element(valueElemTy, unpackedDppRes, i32_val(1)), - i32_val(1)); + b.extract_element(valueElemTy, unpackedDppRes, b.i32_val(1)), + b.i32_val(1)); resultVals[i] = - extract_element(valueElemTy, retVal, urem(tid, i32_val(2))); + b.extract_element(valueElemTy, retVal, b.urem(tid, b.i32_val(2))); } else { for (int ii = 0; ii < vec; ++ii) { resultVals[i + ii] = - vec == 1 ? retVal - : extract_element(valueElemTy, retVal, i32_val(ii)); + vec == 1 + ? retVal + : b.extract_element(valueElemTy, retVal, b.i32_val(ii)); } } } else { @@ -1249,8 +1264,8 @@ struct AtomicRMWOpConversion } Value atomPtr = getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - barrier(); - Value ret = load(valueElemTy, atomPtr); + b.barrier(); + Value ret = b.load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret}); } } @@ -1286,10 +1301,11 @@ struct AtomicRMWOpConversion // to consolidate group data into leader threads. // 5. Perform global atomic operations by leader threads. 
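// Host-side model of the per-wavefront grouping the loop below computes with
// readfirstlane/ballot/popcount/mbcnt; a simplified sketch (the lowered code
// also carries a running base offset and rewrites cnt as cnt - idx to drive
// the bpermute tree reduction). `groupByAddress` is a hypothetical name.
#include <bit>      // C++20: std::popcount, std::countr_zero
#include <cstdint>
static void groupByAddress(const uint64_t addr[64], uint32_t idx[64],
                           uint32_t cnt[64]) {
  uint64_t remaining = ~0ull;                       // lanes not yet grouped
  while (remaining) {
    // "readfirstlane": take the address held by the first still-active lane.
    uint64_t chosen = addr[std::countr_zero(remaining)];
    uint64_t mask = 0;                              // "ballot" of matching lanes
    for (int l = 0; l < 64; ++l)
      if (addr[l] == chosen)
        mask |= 1ull << l;
    for (int l = 0; l < 64; ++l)
      if ((mask >> l) & 1) {
        idx[l] = std::popcount(mask & ((1ull << l) - 1)); // rank in group (mbcnt)
        cnt[l] = std::popcount(mask);                     // group size (popcount64)
      }
    remaining &= ~mask;
  }
}
// Per group, the lane with idx == 0 becomes the leader that issues the single
// global atomic after the in-group reduction.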
auto loc = operand.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Type operandElemType = operand.getType(); Type origPtrType = rmwPtr.getType(); - rmwPtr = ptrtoint(i64_ty, rmwPtr); + rmwPtr = b.ptrtoint(i64_ty, rmwPtr); auto *curBlock = rewriter.getInsertionBlock(); auto *afterLoopBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); @@ -1300,7 +1316,7 @@ struct AtomicRMWOpConversion curBlock->getParent(), std::next(Region::iterator(curBlock))); loopBody->addArgument(i32_ty, loc); // base rewriter.setInsertionPointToEnd(curBlock); - rewriter.create(loc, i32_val(0), loopBody); + rewriter.create(loc, b.i32_val(0), loopBody); // Greed search of same addr within wavefront. Also collect auxiliary // information about relative position: @@ -1313,20 +1329,20 @@ struct AtomicRMWOpConversion Value chosen = genI32TiledOp(rewriter, genReadFirstLane, rmwPtr); // this flag is required to disable thread if we have already checked its // pointer - Value done = icmp_eq(chosen, rmwPtr); + Value done = b.icmp_eq(chosen, rmwPtr); Value mask = targetInfo.ballot(rewriter, loc, i64_ty, done); Value start = loopBody->getArgument(0); - Value cnt = trunc(i32_ty, generatePopcount64(rewriter, mask)); + Value cnt = b.trunc(i32_ty, generatePopcount64(rewriter, mask)); Value mbcntLoRes = rewriter .create( - loc, i32_ty, trunc(i32_ty, mask), i32_val(0)) + loc, i32_ty, b.trunc(i32_ty, mask), b.i32_val(0)) ->getResult(0); Value idx = rewriter.create( - loc, i32_ty, trunc(i32_ty, lshr(mask, i64_val(32))), mbcntLoRes); - Value base = add(start, cnt); - Value leader = icmp_eq(idx, i32_val(0)); - cnt = sub(cnt, idx); - idx = add(idx, start); + loc, i32_ty, b.trunc(i32_ty, b.lshr(mask, b.i64_val(32))), mbcntLoRes); + Value base = b.add(start, cnt); + Value leader = b.icmp_eq(idx, b.i32_val(0)); + cnt = b.sub(cnt, idx); + idx = b.add(idx, start); rewriter.create(loc, done, afterLoopBlock, ValueRange({idx, cnt, leader}), loopBody, ValueRange({base})); @@ -1336,20 +1352,20 @@ struct AtomicRMWOpConversion Value idxRes = afterLoopBlock->getArgument(0); Value cntRes = afterLoopBlock->getArgument(1); Value leaderRes = afterLoopBlock->getArgument(2); - Value idxScaledForPermute = mul(idxRes, i32_val(4)); + Value idxScaledForPermute = b.mul(idxRes, b.i32_val(4)); // Make groups continuous rmwPtr = genI32TiledOp(rewriter, genPermute, rmwPtr, idxScaledForPermute); operand = genI32TiledOp(rewriter, genPermute, operand, idxScaledForPermute); // Actualize auxiliary info as well - Value packedRoleInfo = - genI32TiledOp(rewriter, genPermute, - or_(zext(i32_ty, leaderRes), - or_(idxScaledForPermute, shl(cntRes, i32_val(8)))), - idxScaledForPermute); + Value packedRoleInfo = genI32TiledOp( + rewriter, genPermute, + b.or_(b.zext(i32_ty, leaderRes), + b.or_(idxScaledForPermute, b.shl(cntRes, b.i32_val(8)))), + idxScaledForPermute); idxScaledForPermute = packedRoleInfo; - cntRes = and_(lshr(packedRoleInfo, i32_val(8)), i32_val(0xff)); - leaderRes = icmp_ne(and_(packedRoleInfo, i32_val(1)), i32_val(0)); + cntRes = b.and_(b.lshr(packedRoleInfo, b.i32_val(8)), b.i32_val(0xff)); + leaderRes = b.icmp_ne(b.and_(packedRoleInfo, b.i32_val(1)), b.i32_val(0)); auto *afterRedBlock = afterLoopBlock->splitBlock(rewriter.getInsertionPoint()); @@ -1358,9 +1374,10 @@ struct AtomicRMWOpConversion rewriter.createBlock(afterLoopBlock->getParent(), std::next(Region::iterator(afterLoopBlock))); rewriter.setInsertionPointToEnd(afterLoopBlock); - Value reductionCond = icmp_ne( - targetInfo.ballot(rewriter, loc, i64_ty, icmp_ne(cntRes, i32_val(1))), - 
i64_val(0)); + Value reductionCond = + b.icmp_ne(targetInfo.ballot(rewriter, loc, i64_ty, + b.icmp_ne(cntRes, b.i32_val(1))), + b.i64_val(0)); rewriter.create(loc, reductionCond, partialReductionBlock, afterRedBlock, operand); rewriter.setInsertionPointToEnd(partialReductionBlock); @@ -1368,39 +1385,42 @@ struct AtomicRMWOpConversion auto performOpIfCond = [&](Value res, Value v, Value cond) -> Value { Type ty = v.getType(); assert(ty == res.getType()); - Value notCond = icmp_eq(cond, false_val()); + Value notCond = b.icmp_eq(cond, b.false_val()); switch (opKind) { case LLVM::AtomicBinOp::_and: // res &= cond ? v : 1111.. - return and_(res, or_(v, sub(int_val(ty.getIntOrFloatBitWidth(), 0), - zext(ty, notCond)))); + return b.and_(res, + b.or_(v, b.sub(b.int_val(ty.getIntOrFloatBitWidth(), 0), + b.zext(ty, notCond)))); case LLVM::AtomicBinOp::_or: // res |= cond ? v : 0 - return or_(res, mul(v, zext(ty, cond))); + return b.or_(res, b.mul(v, b.zext(ty, cond))); case LLVM::AtomicBinOp::_xor: // res ^= cond ? v : 0 - return xor_(res, mul(v, zext(ty, cond))); + return b.xor_(res, b.mul(v, b.zext(ty, cond))); case LLVM::AtomicBinOp::add: // res += cond ? v : 0 - return add(res, mul(v, zext(ty, cond))); + return b.add(res, b.mul(v, b.zext(ty, cond))); case LLVM::AtomicBinOp::fadd: // res += cond ? v : 0 - return fadd( - res, fmul(v, inttofloat(ty, zext(int_ty(ty.getIntOrFloatBitWidth()), - cond)))); + return b.fadd( + res, b.fmul(v, b.inttofloat( + ty, b.zext(int_ty(ty.getIntOrFloatBitWidth()), + cond)))); case LLVM::AtomicBinOp::max: case LLVM::AtomicBinOp::umax: // res = cond ? umax(v, res) : res - return or_(mul(res, zext(ty, notCond)), - mul(umax(v, res), zext(ty, cond))); + return b.or_(b.mul(res, b.zext(ty, notCond)), + b.mul(b.umax(v, res), b.zext(ty, cond))); case LLVM::AtomicBinOp::min: case LLVM::AtomicBinOp::umin: // res = cond ? umin(v, res) : res - return or_(mul(res, zext(ty, notCond)), - mul(umin(v, res), zext(ty, cond))); + return b.or_(b.mul(res, b.zext(ty, notCond)), + b.mul(b.umin(v, res), b.zext(ty, cond))); case LLVM::AtomicBinOp::xchg: // res = cond ? 
v : res - return or_(mul(res, zext(ty, notCond)), mul(v, zext(ty, cond))); + return b.or_(b.mul(res, b.zext(ty, notCond)), + b.mul(v, b.zext(ty, cond))); default: llvm_unreachable("Unsupported atomic binary operation."); } @@ -1409,8 +1429,8 @@ struct AtomicRMWOpConversion // Reduce to leader thread for (int i = 32; i != 0; i /= 2) { Value tmp = genI32TiledOp(rewriter, genBPermute, acc, - add(idxScaledForPermute, i32_val(i * 4))); - acc = performOpIfCond(acc, tmp, icmp_ult(i32_val(i), cntRes)); + b.add(idxScaledForPermute, b.i32_val(i * 4))); + acc = performOpIfCond(acc, tmp, b.icmp_ult(b.i32_val(i), cntRes)); } rewriter.create(loc, acc, afterRedBlock); @@ -1422,12 +1442,12 @@ struct AtomicRMWOpConversion afterRedBlock->getParent(), std::next(Region::iterator(afterRedBlock))); rewriter.setInsertionPointToEnd(afterRedBlock); Value leaderCond = leaderRes; - Value defaultRes = undef(operandElemType); + Value defaultRes = b.undef(operandElemType); rewriter.create(loc, leaderCond, leaderBlock, endBlock, defaultRes); rewriter.setInsertionPointToEnd(leaderBlock); // Utilize global atomic only by leader threads - rmwPtr = inttoptr(origPtrType, rmwPtr); + rmwPtr = b.inttoptr(origPtrType, rmwPtr); Value atom = rewriter .create(loc, opKind, rmwPtr, afterRedBlock->getArgument(0), diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp index c9602c969423..3b212f9d9660 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -60,6 +60,7 @@ struct LocalLoadOpConversion ConversionPatternRewriter &rewriter, const DotOperandEncodingAttr &dotOperandLayout) const { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Value src = op.getSrc(); Value dst = op.getResult(); auto llvmElemTy = typeConverter->convertType( @@ -75,7 +76,7 @@ struct LocalLoadOpConversion : SharedToDotOperandWMMA::convertLayout; res = sharedToDotConvert(dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, - tid_val()); + b.tid_val()); } else { assert(false && "unsupported layout found"); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 0835e8fc514a..59d05a82a3f7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -27,11 +27,12 @@ LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, Value printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto type = value.getType(); if (isa(type)) { // The llvm.ptrtoint op requires signless integer types. 
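// Host-side analogue of the promotion rules printfPromoteValue applies below:
// every argument is widened into a 64-bit slot (floats through double bits,
// narrower integers through sign or zero extension, pointers through
// ptrtoint). `promoteArg` is a hypothetical helper used only to illustrate.
#include <cstdint>
#include <cstring>
static uint64_t promoteArg(double f) {   // any float type, extended to f64
  uint64_t bits;
  std::memcpy(&bits, &f, sizeof bits);   // bitcast of the f64 value
  return bits;
}
static uint64_t promoteArg(int32_t i) { return uint64_t(int64_t(i)); }  // sext
static uint64_t promoteArg(uint32_t u) { return uint64_t(u); }          // zext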
- return ptrtoint(i64_ty, value); + return b.ptrtoint(i64_ty, value); } assert(type.getIntOrFloatBitWidth() <= 64); @@ -39,18 +40,18 @@ Value printfPromoteValue(RewriterBase &rewriter, Value value) { if (auto floatType = dyn_cast(type)) { Value newValue = value; if (!floatType.isF64()) - newValue = fpext(f64_ty, newValue); - return bitcast(newValue, i64_ty); + newValue = b.fpext(f64_ty, newValue); + return b.bitcast(newValue, i64_ty); } assert(type.isIntOrIndex()); if (type.getIntOrFloatBitWidth() < 64) { if (type.isUnsignedInteger()) - return zext(ui64_ty, value); + return b.zext(ui64_ty, value); if (type.isSignedInteger()) - return sext(i64_ty, value); + return b.sext(i64_ty, value); // Signless integers are printed using unsigned integer formats. - return zext(i64_ty, value); + return b.zext(i64_ty, value); } return value; @@ -145,16 +146,17 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc, static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, Value &val, Type fromType, unsigned toBits) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned originalBits = fromType.getIntOrFloatBitWidth(); Type toType = fromType; if (!fromType.isIntOrIndex()) { - val = bitcast(val, int_ty(originalBits)); + val = b.bitcast(val, int_ty(originalBits)); toType = int_ty(originalBits); } if (originalBits < toBits) { - val = sext(int_ty(toBits), val); + val = b.sext(int_ty(toBits), val); toType = int_ty(toBits); } @@ -167,15 +169,16 @@ static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc, Value val, Type valType, unsigned fromBits) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned originalBits = valType.getIntOrFloatBitWidth(); Value toVal = val; if (originalBits < fromBits) { - toVal = trunc(int_ty(originalBits), toVal); + toVal = b.trunc(int_ty(originalBits), toVal); } if (!valType.isIntOrIndex()) { - toVal = bitcast(toVal, valType); + toVal = b.bitcast(toVal, valType); } return toVal; @@ -185,6 +188,7 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); if (numLaneToReduce != 64) return false; @@ -307,7 +311,7 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, std::string intrinsic = "llvm.amdgcn.readlane"; Value result = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, - ValueRange{buf, i32_val(63)}) + ValueRange{buf, b.i32_val(63)}) ->getResult(0); result = truncAndCastFromInt(rewriter, loc, result, valType, 16); @@ -324,6 +328,7 @@ void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto *ctx = rewriter.getContext(); mlir::Location loc = UnknownLoc::get(ctx); + auto b = TritonLLVMOpBuilder(loc, rewriter); // See // https://github.com/ROCm/ROCm-Device-Libs/blob/rocm-6.0.x/ockl/src/services.cl#L263-L361 @@ -349,16 +354,16 @@ void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, // Emit the intrinsic function call to begin the printf. Value zeroI64 = rewriter.create(loc, i64_ty, 0); Value message = - call(printBeginFn, useStdErr ? ValueRange() : zeroI64).getResult(); + b.call(printBeginFn, useStdErr ? ValueRange() : zeroI64).getResult(); // Emit the intrinsic function call to handle the printf format string. 
- Value oneI32 = i32_val(1); - Value zeroI32 = i32_val(0); + Value oneI32 = b.i32_val(1); + Value zeroI32 = b.i32_val(0); Value formatStrLen = rewriter.create(loc, i64_ty, formatStrByteCount); SmallVector arguments = {message, formatStrStart, formatStrLen, args.empty() ? oneI32 : zeroI32}; - message = call(printStrFn, arguments).getResult(); + message = b.call(printStrFn, arguments).getResult(); // Emit the intrinsic function call to handle arguments iteratively. // We can only handle at most 7 values each time. @@ -369,7 +374,7 @@ void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, SmallVector arguments; arguments.push_back(message); - arguments.push_back(i32_val(numArgs)); + arguments.push_back(b.i32_val(numArgs)); for (size_t i = group; i < bound; ++i) { arguments.push_back(printfPromoteValue(rewriter, args[i])); } @@ -380,7 +385,7 @@ void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, Value isLast = (bound == args.size()) ? oneI32 : zeroI32; arguments.push_back(isLast); - message = call(printArgsFn, arguments).getResult(); + message = b.call(printArgsFn, arguments).getResult(); } } @@ -411,6 +416,7 @@ void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); // Compose and print an assert message. llvm::SmallString<256> msgBuffer; llvm::Twine("device assertion failed: '" + message + "', in " + func + @@ -423,7 +429,7 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, // Set block barrrier before aborting kernel, give a chance for all // the threads in a block to check/print the assert failure. - barrier(); + b.barrier(); // Perform the trap to abort the kernel. rewriter.create(loc); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index dc6fd4b28644..44241f7e1c86 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -25,6 +25,7 @@ SmallVector upcast8xMxfp4(RewriterBase &rewriter, UpcastMXFPOp upcastOp, bool tofp16, Value packedVec) { Location loc = upcastOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); // MXFP4 has 4 bits, S.EE.M, for Sign, Exponent, and Mantissa respectively. // For a specific S, we have a total of 8 bit patterns. We can encode all @@ -49,11 +50,11 @@ SmallVector upcast8xMxfp4(RewriterBase &rewriter, // 0.11.1 | 0x40c0 | 0x4600 | + 6.0 // // Encode Byte #0 (M) for BF16/FP16 in a LUT. - Value resB0LutLo = tofp16 ? i32_val(0) : i32_val(0xc0800000); - Value resB0LutHi = tofp16 ? i32_val(0) : i32_val(0xc0804000); + Value resB0LutLo = tofp16 ? b.i32_val(0) : b.i32_val(0xc0800000); + Value resB0LutHi = tofp16 ? b.i32_val(0) : b.i32_val(0xc0804000); // Encode Byte #1 (EM, non-S part) for BF16/FP16 in a LUT. - Value resB1LutLoNoS = tofp16 ? i32_val(0x3e3c3800) : i32_val(0x3f3f3f00); - Value resB1LutHiNoS = tofp16 ? i32_val(0x46444240) : i32_val(0x40404040); + Value resB1LutLoNoS = tofp16 ? b.i32_val(0x3e3c3800) : b.i32_val(0x3f3f3f00); + Value resB1LutHiNoS = tofp16 ? 
b.i32_val(0x46444240) : b.i32_val(0x40404040); Type i32Ty = rewriter.getI32Type(); auto permU32FnTy = LLVM::LLVMFunctionType::get(i32Ty, {i32Ty, i32Ty, i32Ty}); @@ -62,28 +63,28 @@ SmallVector upcast8xMxfp4(RewriterBase &rewriter, // Start with 8 mxfp4 elements in a single i32 register // | e7e6 | e5e4 | e3e2 | e1e0 | - Value input = bitcast(packedVec, i32Ty); + Value input = b.bitcast(packedVec, i32Ty); // Step 1: extract EM bits for elements 0,2,4,6 and 1,3,5,7 respectively. // e2m1_6420_idx = | 0[0e6EM] | 0[0e4EM] | 0[0e2EM] | 0[0e0EM] | - Value e2m1_6420_idx = and_(input, i32_val(0x07070707)); + Value e2m1_6420_idx = b.and_(input, b.i32_val(0x07070707)); // e2m1_7531_idx = | [0e7EM]0 | [0e5EM]0 | [0e3EM]0 | [0e1EM]0 | - Value e2m1_7531_idx = and_(input, i32_val(0x70707070)); + Value e2m1_7531_idx = b.and_(input, b.i32_val(0x70707070)); // e2m1_7531_idx = | 0[0e7EM] | 0[0e5EM] | 0[0e3EM] | 0[0e1EM] | - e2m1_7531_idx = lshr(e2m1_7531_idx, i32_val(4)); + e2m1_7531_idx = b.lshr(e2m1_7531_idx, b.i32_val(4)); // Step 2: extract S bit for elements 0,2,4,6 and 1,3,5,7 // s_6420 = | 0[e6S000] | 0[e4S000] | 0[e2S000] | 0[e0S000] | - Value s_6420 = and_(input, i32_val(0x08080808)); + Value s_6420 = b.and_(input, b.i32_val(0x08080808)); // s_6420 = | [e6S000]0 | [e4S000]0 | [e2S000]0 | [e0S000]0 | - s_6420 = shl(s_6420, i32_val(4)); + s_6420 = b.shl(s_6420, b.i32_val(4)); // s_7531 = | [e7S000]0 | [e5S000]0 | [e3S000]0 | [e1S000]0 | - Value s_7531 = and_(input, i32_val(0x80808080)); + Value s_7531 = b.and_(input, b.i32_val(0x80808080)); // Step 3: Upcast elements 0,2,4,6 to 4 16-bit elements // Select Byte #0. It's always 0 if upcasting to fp16. // resB0_6420 = | e6B0 | e4B0 | e2B0 | e0B0 | - Value resB0_6420 = i32_val(0); + Value resB0_6420 = b.i32_val(0); if (!tofp16) { resB0_6420 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, {resB0LutHi, resB0LutLo, e2m1_6420_idx}) @@ -95,25 +96,25 @@ SmallVector upcast8xMxfp4(RewriterBase &rewriter, {resB1LutHiNoS, resB1LutLoNoS, e2m1_6420_idx}) .getResult(); // resB1_6420 = | e6B1 | e4B1 | e2B1 | e0B1 | - Value resB1_6420 = or_(resB1NoS_6420, s_6420); + Value resB1_6420 = b.or_(resB1NoS_6420, s_6420); // Construct 16-bit values of e0 and e2 // res_20 = | e2B1 | e2B0 | e0B1 | e0B0 | = | e2_f16 | e0_f16 | Value res_20 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {resB1_6420, resB0_6420, i32_val(0x05010400)}) + {resB1_6420, resB0_6420, b.i32_val(0x05010400)}) .getResult(); // Construct 16-bit values of e4 and e6 // res_64 = | e6B1 | e6B0 | e4B1 | e4B0 | = | e6_f16 | e4_f16 | Value res_64 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {resB1_6420, resB0_6420, i32_val(0x07030602)}) + {resB1_6420, resB0_6420, b.i32_val(0x07030602)}) .getResult(); // Step 4: Upcast elements 1,3,5,7 to 4 16-bit elements // This is a copy of step 3 on different group of elements // Select Byte #0. It's always 0 if upcasting to fp16. 
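// Editorial sketch: the mask constants in the hunk above split a packed i32 of
// eight 4-bit values (two per byte, even-indexed element in the low nibble)
// into the exponent/mantissa bits and the sign bits of the even and odd
// elements before the byte-permute LUT lookups. The same arithmetic on the
// host, with an arbitrary example value:

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t input = 0x9ABC1234;                  // | e7e6 | e5e4 | e3e2 | e1e0 |
  uint32_t emEven = input & 0x07070707;         // EM bits of e0,e2,e4,e6
  uint32_t emOdd  = (input & 0x70707070) >> 4;  // EM bits of e1,e3,e5,e7
  uint32_t sEven  = (input & 0x08080808) << 4;  // sign of e0,e2,e4,e6 at bit 7
  uint32_t sOdd   =  input & 0x80808080;        // sign of e1,e3,e5,e7 at bit 7
  std::printf("%08x %08x %08x %08x\n", emEven, emOdd, sEven, sOdd);
  return 0;
}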
// resB0_7531 = | e7B0 | e5B0 | e3B0 | e1B0 | - Value resB0_7531 = i32_val(0); + Value resB0_7531 = b.i32_val(0); if (!tofp16) { resB0_7531 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, {resB0LutHi, resB0LutLo, e2m1_7531_idx}) @@ -125,36 +126,36 @@ SmallVector upcast8xMxfp4(RewriterBase &rewriter, {resB1LutHiNoS, resB1LutLoNoS, e2m1_7531_idx}) .getResult(); // resB1_7531 = | e7B1 | e5B1 | e3B1 | e1B1 | - Value resB1_7531 = or_(resB1NoS_7531, s_7531); + Value resB1_7531 = b.or_(resB1NoS_7531, s_7531); // Construct 16-bit values of e1 and e3 // res_31 = | e3B1 | e3B0 | e1B1 | e1B0 | = | e3_f16 | e1_f16 | Value res_31 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {resB1_7531, resB0_7531, i32_val(0x05010400)}) + {resB1_7531, resB0_7531, b.i32_val(0x05010400)}) .getResult(); // Construct 16-bit values of e5 and e7 // res_75 = | e7B1 | e7B0 | e5B1 | e5B0 | = | e7_f16 | e5_f16 | Value res_75 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {resB1_7531, resB0_7531, i32_val(0x07030602)}) + {resB1_7531, resB0_7531, b.i32_val(0x07030602)}) .getResult(); // Step 5: Reorder 16-bit elements to be 0,1,2,3,4,5,6,7 // res_10 = | e1_f16 | e0_f16 | Value res_10 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {res_31, res_20, i32_val(0x05040100)}) + {res_31, res_20, b.i32_val(0x05040100)}) .getResult(); // res_32 = | e3_f16 | e2_f16 | Value res_32 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {res_31, res_20, i32_val(0x07060302)}) + {res_31, res_20, b.i32_val(0x07060302)}) .getResult(); // res_54 = | e5_f16 | e4_f16 | Value res_54 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {res_75, res_64, i32_val(0x05040100)}) + {res_75, res_64, b.i32_val(0x05040100)}) .getResult(); // res_76 = | e7_f16 | e6_f16 | Value res_76 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, - {res_75, res_64, i32_val(0x07060302)}) + {res_75, res_64, b.i32_val(0x07060302)}) .getResult(); return {res_10, res_32, res_54, res_76}; @@ -164,6 +165,7 @@ SmallVector upcastMxfp4(RewriterBase &rewriter, UpcastMXFPOp upcastOp, bool toFp16, ArrayRef values) { assert(values.size() % 4 == 0); Location loc = upcastOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector results; results.reserve(values.size() * 2); @@ -173,17 +175,17 @@ SmallVector upcastMxfp4(RewriterBase &rewriter, UpcastMXFPOp upcastOp, Value v1 = values[i + 1]; Value v2 = values[i + 2]; Value v3 = values[i + 3]; - Value packedVec = undef(vec_ty(i8_ty, 4)); - packedVec = insert_element(packedVec, v0, i32_val(0)); - packedVec = insert_element(packedVec, v1, i32_val(1)); - packedVec = insert_element(packedVec, v2, i32_val(2)); - packedVec = insert_element(packedVec, v3, i32_val(3)); + Value packedVec = b.undef(vec_ty(i8_ty, 4)); + packedVec = b.insert_element(packedVec, v0, b.i32_val(0)); + packedVec = b.insert_element(packedVec, v1, b.i32_val(1)); + packedVec = b.insert_element(packedVec, v2, b.i32_val(2)); + packedVec = b.insert_element(packedVec, v3, b.i32_val(3)); SmallVector v4i32 = upcast8xMxfp4(rewriter, upcastOp, toFp16, packedVec); for (int j = 0; j < 4; j++) { - Value elements = bitcast(v4i32[j], vec_ty(elemType, 2)); - results.push_back(extract_element(elements, i32_val(0))); - results.push_back(extract_element(elements, i32_val(1))); + Value elements = b.bitcast(v4i32[j], vec_ty(elemType, 2)); + results.push_back(b.extract_element(elements, b.i32_val(0))); + results.push_back(b.extract_element(elements, b.i32_val(1))); } } return results; @@ -191,16 +193,18 @@ SmallVector upcastMxfp4(RewriterBase &rewriter, UpcastMXFPOp upcastOp, Value 
mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, Value scale, bool fastMath) { - Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value scaleF32 = + b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty); Value scaleF16 = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, scaleF32, RoundingMode::RTNE); - Value mulF16 = fmul(v, scaleF16); + Value mulF16 = b.fmul(v, scaleF16); if (fastMath) return mulF16; // Account for NaN in the scale as per the mxfp specification. - Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); - Value nanF16 = bitcast(i16_val(0x7c01), f16_ty); - return select(scaleIsNan, nanF16, bitcast(mulF16, f16_ty)); + Value scaleIsNan = b.icmp_eq(scale, b.i8_val(0xff)); + Value nanF16 = b.bitcast(b.i16_val(0x7c01), f16_ty); + return b.select(scaleIsNan, nanF16, b.bitcast(mulF16, f16_ty)); }; // Scales the given bf16 v using the given scale factor without relying on bf16 @@ -211,18 +215,21 @@ Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, Value scale, // for us, just with unnecessary overheads. Value mxfpScaleBf16ViaF32(RewriterBase &rewriter, Location loc, Value v, Value scale, bool fastMath) { - Value c16 = i32_val(16); - Value vF32 = bitcast(shl(zext(i32_ty, bitcast(v, i16_ty)), c16), f32_ty); - Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty); - Value mulF32 = fmul(vF32, scaleF32); - Value mulI16 = trunc(i16_ty, lshr(bitcast(mulF32, i32_ty), c16)); - Value mulBf16 = bitcast(mulI16, bf16_ty); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value c16 = b.i32_val(16); + Value vF32 = + b.bitcast(b.shl(b.zext(i32_ty, b.bitcast(v, i16_ty)), c16), f32_ty); + Value scaleF32 = + b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty); + Value mulF32 = b.fmul(vF32, scaleF32); + Value mulI16 = b.trunc(i16_ty, b.lshr(b.bitcast(mulF32, i32_ty), c16)); + Value mulBf16 = b.bitcast(mulI16, bf16_ty); if (fastMath) return mulBf16; // Account for NaN in the scale as per the mxfp specification. 
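// Editorial sketch: mxfpScaleFp16 and mxfpScaleBf16ViaF32 above treat the
// 8-bit scale as a pure exponent: shifting it into bits [30:23] of an i32 and
// bitcasting yields the f32 power of two it encodes. 0xFF would bitcast to
// +inf, which is why the non-fastMath path replaces that case with an explicit
// NaN, as the comments in these hunks note. A host-side check of the bit
// trick:

#include <cstdint>
#include <cstring>
#include <cstdio>

static float scaleToF32(uint8_t scale) {
  uint32_t bits = uint32_t(scale) << 23; // land in the f32 exponent field
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  std::printf("%g %g %g\n", scaleToF32(126), scaleToF32(127), scaleToF32(128));
  // prints: 0.5 1 2
  return 0;
}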
- Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); - Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); - return select(scaleIsNan, nanBf16, mulBf16); + Value scaleIsNan = b.icmp_eq(scale, b.i8_val(0xff)); + Value nanBf16 = b.bitcast(b.i16_val(0x7fff), bf16_ty); + return b.select(scaleIsNan, nanBf16, mulBf16); }; class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { @@ -245,6 +252,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "NYI: non-mxfp4/mxfp8 cases"); Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter); LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); @@ -270,10 +278,10 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp( op->getParentOfType()); - Value warpSize = i32_val(numThreads); - Value tid = tid_val(); - Value warpId = udiv(tid, warpSize); - Value laneId = urem(tid, warpSize); + Value warpSize = b.i32_val(numThreads); + Value tid = b.tid_val(); + Value warpId = b.udiv(tid, warpSize); + Value laneId = b.urem(tid, warpSize); bool useFp16 = op.getType().getElementType().isF16(); if (isPacked) { @@ -285,7 +293,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { // blocked layout for the A scale tensor, we made sure that it has a // threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values // for the current thread starts at ((tid % mDim) * (64 / mDim)). - Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim)); + Value offset = + b.mul(b.urem(laneId, b.i32_val(mDim)), b.i32_val(numThreads / mDim)); if (mDim == 32) { // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we @@ -295,7 +304,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { // first 4 1x4 vectors it holds shares the first scale value at row (tid % // mDim). the second 4 1x4 vectors shares the second scale value at row // (tid % mDim); and so forth. - std::array scaleThreads = {offset, add(offset, i32_val(1))}; + std::array scaleThreads = {offset, b.add(offset, b.i32_val(1))}; for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { std::array si = { @@ -317,9 +326,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we // need to tile the warp 2 times to cover 32 valeus. So for a thread, the // first 2 1x4 vectors shares the first scale value at row (tid % mDim). 
- std::array scaleThreads = {offset, add(offset, i32_val(1)), - add(offset, i32_val(2)), - add(offset, i32_val(3))}; + std::array scaleThreads = {offset, b.add(offset, b.i32_val(1)), + b.add(offset, b.i32_val(2)), + b.add(offset, b.i32_val(3))}; for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { auto si = std::array{ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index ad7d892b3270..d4b8d7abe01f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -45,12 +45,13 @@ std::string mangleFunc(std::string name, Type type) { // the same `pred` value Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc, Value pred, int64_t vecSize) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vecSize); - Value maskVal = undef(vecMaskTy); + Value maskVal = b.undef(vecMaskTy); for (size_t s = 0; s < vecSize; ++s) { Value indexVal = rewriter.create(loc, rewriter.getI64IntegerAttr(s)); - maskVal = insert_element(vecMaskTy, maskVal, pred, indexVal); + maskVal = b.insert_element(vecMaskTy, maskVal, pred, indexVal); } return maskVal; } @@ -75,6 +76,7 @@ namespace mlir::LLVM::AMD { static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, ISAFamily isaFamily, Value val, Value i, int strideInt, ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -82,33 +84,33 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, auto valType = val.getType(); if (!valType.isInteger(32) && bits <= 32) { if (!valType.isIntOrIndex()) - val = bitcast(val, int_ty(bits)); + val = b.bitcast(val, int_ty(bits)); if (bits < 32) - val = sext(i32_ty, val); + val = b.sext(i32_ty, val); val = shuffleCommonImpl(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); if (bits < 32) - val = trunc(int_ty(bits), val); + val = b.trunc(int_ty(bits), val); if (!valType.isIntOrIndex()) - val = bitcast(val, valType); + val = b.bitcast(val, valType); return val; } if (bits == 64) { Type vecTy = vec_ty(f32_ty, 2); - Value vec = bitcast(val, vecTy); - Value val0 = extract_element(f32_ty, vec, i32_val(0)); - Value val1 = extract_element(f32_ty, vec, i32_val(1)); + Value vec = b.bitcast(val, vecTy); + Value val0 = b.extract_element(f32_ty, vec, b.i32_val(0)); + Value val1 = b.extract_element(f32_ty, vec, b.i32_val(1)); val0 = shuffleCommonImpl(loc, rewriter, isaFamily, val0, i, strideInt, mode, clamp); val1 = shuffleCommonImpl(loc, rewriter, isaFamily, val1, i, strideInt, mode, clamp); - vec = undef(vecTy); - vec = insert_element(vecTy, vec, val0, i32_val(0)); - vec = insert_element(vecTy, vec, val1, i32_val(1)); - return bitcast(vec, val.getType()); + vec = b.undef(vecTy); + vec = b.insert_element(vecTy, vec, val0, b.i32_val(0)); + vec = b.insert_element(vecTy, vec, val1, b.i32_val(1)); + return b.bitcast(vec, val.getType()); } auto mod = rewriter.getBlock()->getParent()->getParentOfType(); @@ -116,13 +118,13 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); threadId = rewriter.create(loc, i32_ty, threadId); unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - Value warpSize = i32_val(iWarpSize); - Value laneId = urem(threadId, warpSize); 
+ Value warpSize = b.i32_val(iWarpSize); + Value laneId = b.urem(threadId, warpSize); auto bpermute = [&](Value lane) { // Multiple lineId by 4. (More on permute instruction semantics: // https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180 - Value byteOffset = i32_val(2); - Value permuteAddr = shl(lane, byteOffset); + Value byteOffset = b.i32_val(2); + Value permuteAddr = b.shl(lane, byteOffset); return rewriter.create(loc, valType, permuteAddr, val); }; @@ -136,11 +138,11 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>( loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)}) .getResult(0); - Value stride = i32_val(32); - Value lineId = xor_(threadId, stride); + Value stride = b.i32_val(32); + Value lineId = b.xor_(threadId, stride); return bpermute(lineId); } else if (strideInt == 16) { - Value offset = i32_val(0x401F); + Value offset = b.i32_val(0x401F); return rewriter.create(loc, valType, val, offset); } else { if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) { @@ -151,7 +153,7 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, // than 16. The pattern stride is the key of the map. DenseMap masks{ {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; - Value offset = i32_val(masks[strideInt]); + Value offset = b.i32_val(masks[strideInt]); return rewriter.create(loc, valType, val, offset); } @@ -219,9 +221,9 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, } break; case ShflKind::up: { - Value mask = icmp_slt(laneId, i); - Value delta = sub(laneId, i); - Value index = select(mask, laneId, delta); + Value mask = b.icmp_slt(laneId, i); + Value delta = b.sub(laneId, i); + Value index = b.select(mask, laneId, delta); return bpermute(index); } case ShflKind::idx: @@ -236,38 +238,43 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, static Value shuffleCommon(Location loc, RewriterBase &rewriter, ISAFamily isaFamily, Value val, Value i, int strideInt, ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); // To shuffle pointers, convert them to i64. 
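// Editorial sketch: ds_bpermute/ds_swizzle move 32-bit lanes, so the shuffle
// helper above normalizes other widths first: sub-32-bit values are extended
// to i32 and truncated back, 64-bit values are split into two 32-bit halves
// that are shuffled independently (the lowering does this with a bitcast to a
// 2 x f32 vector), and pointers are routed through i64 in shuffleCommon below.
// ds_bpermute also takes a byte address rather than a lane index, hence the
// shift left by 2 (laneId * 4). A scalar analogue of the split and reassembly:

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t v = 0x1122334455667788ULL;
  uint32_t lo = uint32_t(v), hi = uint32_t(v >> 32); // the two 32-bit lanes
  uint64_t rejoined = (uint64_t(hi) << 32) | lo;     // after shuffling each half
  std::printf("%d %u\n", rejoined == v, 5u << 2);    // 1  20 (lane 5 -> byte 20)
  return 0;
}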
Type valTy = val.getType(); if (isa(valTy)) - val = ptrtoint(i64_ty, val); + val = b.ptrtoint(i64_ty, val); Value result = shuffleCommonImpl(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); if (isa(valTy)) - result = inttoptr(valTy, result); + result = b.inttoptr(valTy, result); return result; } Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, - ShflKind::bfly, i32_val(0x1f)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, isaFamily, val, b.i32_val(i), i, + ShflKind::bfly, b.i32_val(0x1f)); } Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, - ShflKind::up, i32_val(0x0)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, isaFamily, val, b.i32_val(i), i, + ShflKind::up, b.i32_val(0x0)); } Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, ISAFamily isaFamily) { - return shuffleIdx(loc, rewriter, val, i32_val(i), isaFamily); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleIdx(loc, rewriter, val, b.i32_val(i), isaFamily); } Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, ISAFamily isaFamily) { + auto b = TritonLLVMOpBuilder(loc, rewriter); return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx, - i32_val(0x1f)); + b.i32_val(0x1f)); } Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, @@ -285,7 +292,7 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, Value pred, Value falseVal, int64_t alignmentBytes, triton::CacheModifier cm) { - + auto b = TritonLLVMOpBuilder(loc, rewriter); // Try to emit llvm.intr.masked.load if we can. In theory the backend should // be happier because we emit less branchy code to optimize. The backend will // lower it down however it wants at some point. @@ -295,13 +302,13 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, // to bitcast to `vector<1xelemTy>` (and back) int64_t vecSize = getNumElements(elemTy); Type vecType = castToVectorType(elemTy); - falseVal = bitcast(falseVal, vecType); + falseVal = b.bitcast(falseVal, vecType); Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize); bool nt = (cm == triton::CacheModifier::CG); Value vecData = rewriter.create( loc, vecType, ptr, maskVal, falseVal, alignmentBytes, nt); // If it is not a vector, remember to bitcast back to a scalar - vecData = bitcast(vecData, elemTy); + vecData = b.bitcast(vecData, elemTy); return vecData; } @@ -332,6 +339,7 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, Value pred, int64_t alignmentBytes, triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); // Try to emit llvm.intr.masked.store if we can. In theory the backend should // be happier because we emit less branchy code to optimize. The backend will // lower it down however it wants at some point. 
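// Editorial sketch: llLoad/llStore above splat the single per-access predicate
// into a vector mask because llvm.intr.masked.load/store take one mask bit per
// element; a scalar element type is viewed as a 1-element vector for the same
// reason. The element-wise semantics the backend implements, as a host-side
// analogue:

#include <array>
#include <cstdio>

int main() {
  std::array<int, 4> mem = {10, 20, 30, 40}, other = {-1, -1, -1, -1}, out{};
  bool pred = true;              // one predicate for the whole access...
  std::array<bool, 4> mask;
  mask.fill(pred);               // ...splatted across the vector lanes
  for (int i = 0; i < 4; ++i)
    out[i] = mask[i] ? mem[i] : other[i]; // masked-off lanes keep the fallback
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);
  return 0;
}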
@@ -341,7 +349,7 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, Type elemTy = val.getType(); int64_t vecSize = getNumElements(elemTy); Type vecType = castToVectorType(elemTy); - val = bitcast(val, vecType); + val = b.bitcast(val, vecType); Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize); auto op = rewriter.create(loc, val, ptr, maskVal, alignmentBytes); diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 7a1f518c8b8f..7a8c03146489 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -70,17 +70,18 @@ Type getTypeFromConstraint(char constraint, PatternRewriter &rewriter) { // val to i32 using ptrtoint(i32_ty, val) Value convertToType(Value val, std::string constraint, Location loc, PatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto isConstraintNumber = isNumber(constraint); if (!isConstraintNumber) { auto ty = getTypeFromConstraint(constraint[0], rewriter); if (isa(val.getType())) { - return ptrtoint(ty, val); + return b.ptrtoint(ty, val); } else { assert(val.getType().getIntOrFloatBitWidth() <= ty.getIntOrFloatBitWidth() && "Cannot convert to a smaller type"); if (val.getType().getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth()) - return zext(ty, val); + return b.zext(ty, val); } } return val; @@ -101,6 +102,7 @@ OperandsAndConstraints unpackOperands(const OperandsAndConstraints &operandsAndConstraints, PTXBuilder &ptxBuilder, Location loc, PatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); OperandsAndConstraints unpackedOperands; for (const auto &[operand, constraint] : operandsAndConstraints) { auto llvmStruct = llvm::dyn_cast(operand.getType()); @@ -114,11 +116,11 @@ unpackOperands(const OperandsAndConstraints &operandsAndConstraints, if (isConstraintNumber) { auto constraintInt = std::stoi(constraint) + i; unpackedOperands.push_back( - {extract_val(llvmStruct.getBody()[i], operand, i), + {b.extract_val(llvmStruct.getBody()[i], operand, i), std::to_string(constraintInt)}); } else { unpackedOperands.push_back( - {extract_val(llvmStruct.getBody()[i], operand, i), constraint}); + {b.extract_val(llvmStruct.getBody()[i], operand, i), constraint}); } } } else { @@ -392,6 +394,7 @@ class LoadAcquireOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(ttn::LoadAcquireOp op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Type valueTy = op.getType(); const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth()); const size_t maxWordWidth = std::max(32, valueNBits); @@ -418,7 +421,7 @@ class LoadAcquireOpPattern : public OpRewritePattern { // Create inline ASM signature Type retTy = IntegerType::get(getContext(), width); Value ret = ptxBuilder.launch(rewriter, loc, retTy); - ret = bitcast(ret, op.getType()); + ret = b.bitcast(ret, op.getType()); rewriter.replaceOp(op, {ret}); return success(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index 7e727e3dee1d..510bc4d41692 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -56,13 +56,14 @@ struct InitBarrierOpConversion matchAndRewrite(triton::nvidia_gpu::InitBarrierOp op, OpAdaptor adaptor, 
ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getAlloc(), typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); auto id = getThreadId(rewriter, loc); - auto pred = icmp_eq(id, i32_val(0)); + auto pred = b.icmp_eq(id, b.i32_val(0)); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = "@$0 mbarrier.init.shared::cta.b64 [$1], " + std::to_string(op.getCount()) + ";"; @@ -85,13 +86,14 @@ struct InvalBarrierOpConversion matchAndRewrite(triton::nvidia_gpu::InvalBarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getAlloc(), typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); auto id = getThreadId(rewriter, loc); - Value pred = icmp_eq(id, i32_val(0)); + Value pred = b.icmp_eq(id, b.i32_val(0)); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = "@$0 mbarrier.inval.shared::cta.b64 [$1];"; auto &barSyncOp = *ptxBuilder.create<>(ptx); @@ -113,14 +115,15 @@ struct BarrierExpectConversion matchAndRewrite(triton::nvidia_gpu::BarrierExpectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getAlloc(), typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); auto id = getThreadId(rewriter, loc); - Value pred = icmp_eq(id, i32_val(0)); - pred = and_(pred, adaptor.getPred()); + Value pred = b.icmp_eq(id, b.i32_val(0)); + pred = b.and_(pred, adaptor.getPred()); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = "@$0 mbarrier.arrive.expect_tx.shared.b64 _, [$1], " + diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 30a740f8bc01..8abb7131ef9e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -54,6 +54,7 @@ struct ConvertLayoutOpConversion const TargetInfoBase &targetInfo) const { MLIRContext *ctx = rewriter.getContext(); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto typeConverter = getTypeConverter(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); @@ -69,7 +70,7 @@ struct ConvertLayoutOpConversion Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); + smemBase = b.bitcast(smemBase, elemPtrTy); auto smemShape = convertType(srcShapePerCTA); // Store to local shared memory @@ -83,8 +84,8 @@ struct ConvertLayoutOpConversion for (unsigned i = 0; i < inIndices.size(); ++i) { Value offset = LLVM::linearize(rewriter, loc, inIndices[i], smemShape); - Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); - store(inVals[i], ptr); + Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, offset); + b.store(inVals[i], ptr); } } @@ -96,7 +97,7 @@ struct ConvertLayoutOpConversion { SmallVector srcShapePerCTACache; for (unsigned i = 0; i < rank; ++i) - srcShapePerCTACache.push_back(i32_val(srcShapePerCTA[i])); + srcShapePerCTACache.push_back(b.i32_val(srcShapePerCTA[i])); 
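// Editorial sketch: the layout-conversion path below linearizes a multi-dim
// coordinate into a flat shared-memory offset, and splits a per-CTA coordinate
// into (CTA id, local coord) with udiv/urem pairs. A host-side analogue of
// that index arithmetic; the dimension order used here is plain row-major and
// is only an assumption about what LLVM::linearize does:

#include <cstdio>
#include <vector>

static int linearize(const std::vector<int> &coord, const std::vector<int> &shape) {
  int idx = 0;
  for (size_t d = 0; d < coord.size(); ++d)
    idx = idx * shape[d] + coord[d];
  return idx;
}

int main() {
  std::vector<int> shapePerCTA = {32, 64};
  std::vector<int> coord = {40, 70};
  std::vector<int> ctaId, local;
  for (size_t d = 0; d < coord.size(); ++d) {
    ctaId.push_back(coord[d] / shapePerCTA[d]); // udiv
    local.push_back(coord[d] % shapePerCTA[d]); // urem
  }
  std::printf("cta=(%d,%d) local=(%d,%d) smemOff=%d\n", ctaId[0], ctaId[1],
              local[0], local[1], linearize(local, shapePerCTA));
  return 0;
}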
SmallVector outVals; auto outIndices = emitIndices(loc, rewriter, targetInfo, dstLayout, dstTy, @@ -108,8 +109,8 @@ struct ConvertLayoutOpConversion SmallVector multiDimCTAId, localCoord; for (unsigned d = 0; d < rank; ++d) { - multiDimCTAId.push_back(udiv(coord[d], srcShapePerCTACache[d])); - localCoord.push_back(urem(coord[d], srcShapePerCTACache[d])); + multiDimCTAId.push_back(b.udiv(coord[d], srcShapePerCTACache[d])); + localCoord.push_back(b.urem(coord[d], srcShapePerCTACache[d])); } Value remoteCTAId = LLVM::linearize(rewriter, loc, multiDimCTAId, @@ -117,9 +118,10 @@ struct ConvertLayoutOpConversion Value localOffset = LLVM::linearize(rewriter, loc, localCoord, smemShape); - Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset); - outVals.push_back(targetInfo.loadDShared( - rewriter, loc, ptr, remoteCTAId, llvmElemTy, /*pred=*/true_val())); + Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, localOffset); + outVals.push_back(targetInfo.loadDShared(rewriter, loc, ptr, + remoteCTAId, llvmElemTy, + /*pred=*/b.true_val())); } Value result = @@ -142,51 +144,52 @@ struct ConvertLayoutOpConversion OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto dstTy = op.getType(); auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); SmallVector retVals; for (int i = 0; i < vals.size(); i += 8) { - Value upper = undef(vec_ty(i8_ty, 4)); + Value upper = b.undef(vec_ty(i8_ty, 4)); for (int j = 0; j < 4; j++) { - upper = - insert_element(vec_ty(i8_ty, 4), upper, vals[i + j], i32_val(j)); + upper = b.insert_element(vec_ty(i8_ty, 4), upper, vals[i + j], + b.i32_val(j)); } - upper = bitcast(upper, i32_ty); - Value lower = undef(vec_ty(i8_ty, 4)); + upper = b.bitcast(upper, i32_ty); + Value lower = b.undef(vec_ty(i8_ty, 4)); for (int j = 0; j < 4; j++) { - lower = insert_element(vec_ty(i8_ty, 4), lower, vals[i + 4 + j], - i32_val(j)); + lower = b.insert_element(vec_ty(i8_ty, 4), lower, vals[i + 4 + j], + b.i32_val(j)); } - lower = bitcast(lower, i32_ty); - - Value threadIdMod4 = urem(getThreadId(rewriter, loc), i32_val(4)); - Value cnd = or_(icmp_eq(threadIdMod4, i32_val(0)), - icmp_eq(threadIdMod4, i32_val(3))); - Value selectorEx0 = select(cnd, i32_val(0x3210), i32_val(0x7654)); - Value selectorEx1 = select(cnd, i32_val(0x7654), i32_val(0x3210)); - Value selectorEx4 = select(cnd, i32_val(0x5410), i32_val(0x1054)); - Value selectorEx5 = select(cnd, i32_val(0x7632), i32_val(0x3276)); - - Value isOne = icmp_eq(threadIdMod4, i32_val(1)); - Value isTwo = icmp_eq(threadIdMod4, i32_val(2)); - Value isThree = icmp_eq(threadIdMod4, i32_val(3)); - Value upperIdx = i32_val(0); - upperIdx = select(isOne, i32_val(3), upperIdx); - upperIdx = select(isTwo, i32_val(1), upperIdx); - upperIdx = select(isThree, i32_val(2), upperIdx); - - Value lowerIdx = i32_val(1); - lowerIdx = select(isOne, i32_val(2), lowerIdx); - lowerIdx = select(isTwo, i32_val(0), lowerIdx); - lowerIdx = select(isThree, i32_val(3), lowerIdx); + lower = b.bitcast(lower, i32_ty); + + Value threadIdMod4 = b.urem(getThreadId(rewriter, loc), b.i32_val(4)); + Value cnd = b.or_(b.icmp_eq(threadIdMod4, b.i32_val(0)), + b.icmp_eq(threadIdMod4, b.i32_val(3))); + Value selectorEx0 = b.select(cnd, b.i32_val(0x3210), b.i32_val(0x7654)); + Value selectorEx1 = b.select(cnd, b.i32_val(0x7654), b.i32_val(0x3210)); + Value selectorEx4 = b.select(cnd, b.i32_val(0x5410), b.i32_val(0x1054)); + Value selectorEx5 = b.select(cnd, b.i32_val(0x7632), b.i32_val(0x3276)); + 
+ Value isOne = b.icmp_eq(threadIdMod4, b.i32_val(1)); + Value isTwo = b.icmp_eq(threadIdMod4, b.i32_val(2)); + Value isThree = b.icmp_eq(threadIdMod4, b.i32_val(3)); + Value upperIdx = b.i32_val(0); + upperIdx = b.select(isOne, b.i32_val(3), upperIdx); + upperIdx = b.select(isTwo, b.i32_val(1), upperIdx); + upperIdx = b.select(isThree, b.i32_val(2), upperIdx); + + Value lowerIdx = b.i32_val(1); + lowerIdx = b.select(isOne, b.i32_val(2), lowerIdx); + lowerIdx = b.select(isTwo, b.i32_val(0), lowerIdx); + lowerIdx = b.select(isThree, b.i32_val(3), lowerIdx); Value upper0 = LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx0); Value lower0 = LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx1); - Value mask = i32_val(0xFFFFFFFF); + Value mask = b.i32_val(0xFFFFFFFF); // Set clamp tp shuffle only within 4 lanes. - Value clamp = i32_val(0x1C1F); + Value clamp = b.i32_val(0x1C1F); upper0 = rewriter.create(loc, i32_ty, mask, upper0, upperIdx, clamp, NVVM::ShflKind::idx, UnitAttr()); @@ -195,15 +198,15 @@ struct ConvertLayoutOpConversion clamp, NVVM::ShflKind::idx, UnitAttr()); Value upper1 = LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx4); - Value vecVal = bitcast(upper1, vec_ty(i8_ty, 4)); + Value vecVal = b.bitcast(upper1, vec_ty(i8_ty, 4)); for (int i = 0; i < 4; i++) { - retVals.push_back(extract_element(i8_ty, vecVal, i32_val(i))); + retVals.push_back(b.extract_element(i8_ty, vecVal, b.i32_val(i))); } Value lower1 = LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx5); - vecVal = bitcast(lower1, vec_ty(i8_ty, 4)); + vecVal = b.bitcast(lower1, vec_ty(i8_ty, 4)); for (int i = 0; i < 4; i++) { - retVals.push_back(extract_element(i8_ty, vecVal, i32_val(i))); + retVals.push_back(b.extract_element(i8_ty, vecVal, b.i32_val(i))); } } Value result = diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 14691634c790..04e489513908 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -103,17 +103,18 @@ class MMA16816SmemLoader { SmallVector MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); Value warpB = multiDimWarpId[0]; Value warpId = kOrder == 2 ? multiDimWarpId[1] : multiDimWarpId[2]; // 4x4 matrices - Value rowInMat = urem(lane, i32_val(8)); // row in the 8x8 matrix - Value matIndex = - udiv(lane, i32_val(8)); // linear index of the matrix in the 2x2 matrices + Value rowInMat = b.urem(lane, b.i32_val(8)); // row in the 8x8 matrix + Value matIndex = b.udiv( + lane, b.i32_val(8)); // linear index of the matrix in the 2x2 matrices // Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in a // warp - Value matIndexY = urem(matIndex, i32_val(2)); - Value matIndexX = udiv(matIndex, i32_val(2)); + Value matIndexY = b.urem(matIndex, b.i32_val(2)); + Value matIndexX = b.udiv(matIndex, b.i32_val(2)); // We use different orders for a and b for better performance. Value kMatArr = @@ -144,13 +145,13 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { // access will be out of bound. 
In the future we should change this case to // ldmatrix.x2 if (kOrder == 1 && nPerWarp == 8) { - matOff[nonKOrder] = mul(warpId, i32_val(warpMatOffset)); + matOff[nonKOrder] = b.mul(warpId, b.i32_val(warpMatOffset)); } else { - matOff[nonKOrder] = add( - mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=2) - mul(nkMatArr, - i32_val( - inWarpMatOffset))); // matrix offset inside a warp (kOrder=2) + matOff[nonKOrder] = b.add( + b.mul(warpId, b.i32_val(warpMatOffset)), // warp offset (kOrder=2) + b.mul(nkMatArr, + b.i32_val( + inWarpMatOffset))); // matrix offset inside a warp (kOrder=2) } matOff[kOrder] = kMatArr; @@ -159,39 +160,41 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { Value stridedMatIndex = matOff[order[1]]; // Add the offset of the slice Value contiguousSliceMatOffset = - udiv(cSwizzleOffset, i32_val(contiguousMatShape)); + b.udiv(cSwizzleOffset, b.i32_val(contiguousMatShape)); SmallVector offs(numPtrs); - Value phase = urem(udiv(rowInMat, i32_val(perPhase)), i32_val(maxPhase)); + Value phase = + b.urem(b.udiv(rowInMat, b.i32_val(perPhase)), b.i32_val(maxPhase)); // To prevent out-of-bound access of B when warpsPerTile * 16 > tile_size. // In such a case, we need to wrap around the offset of B. // |0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) | // |0 1 2 3 0 1 2 3| | 0(0) 1(1) 2(2) 3(3) | // ~~~~~~~ out-of-bound access - Value rowOffset = - urem(add(rowInMat, mul(stridedMatIndex, i32_val(stridedMatShape))), - i32_val(tileShape[order[1]])); + Value rowOffset = b.urem( + b.add(rowInMat, b.mul(stridedMatIndex, b.i32_val(stridedMatShape))), + b.i32_val(tileShape[order[1]])); auto contiguousTileNumMats = tileShape[order[0]] / matShape[order[0]]; for (int i = 0; i < numPtrs; ++i) { Value contiguousIndex = - add(contiguousMatIndex, i32_val(i * contiguousLoadMatOffset)); + b.add(contiguousMatIndex, b.i32_val(i * contiguousLoadMatOffset)); if (warpsPerCTA[order[0]] > contiguousTileNumMats || contiguousTileNumMats % warpsPerCTA[order[0]] != 0) - contiguousIndex = urem(contiguousIndex, i32_val(contiguousTileNumMats)); - contiguousIndex = add(contiguousIndex, contiguousSliceMatOffset); - Value contiguousIndexSwizzled = xor_(contiguousIndex, phase); + contiguousIndex = + b.urem(contiguousIndex, b.i32_val(contiguousTileNumMats)); + contiguousIndex = b.add(contiguousIndex, contiguousSliceMatOffset); + Value contiguousIndexSwizzled = b.xor_(contiguousIndex, phase); if (tileShape[0] != 1) { Value batchOffset = - mul(warpB, i32_val(tileShape[order[0]] * tileShape[order[1]])); - offs[i] = - add(batchOffset, - add(mul(contiguousIndexSwizzled, i32_val(contiguousMatShape)), - mul(rowOffset, stridedSmemOffset))); + b.mul(warpB, b.i32_val(tileShape[order[0]] * tileShape[order[1]])); + offs[i] = b.add(batchOffset, b.add(b.mul(contiguousIndexSwizzled, + b.i32_val(contiguousMatShape)), + b.mul(rowOffset, stridedSmemOffset))); } else { - offs[i] = add(mul(contiguousIndexSwizzled, i32_val(contiguousMatShape)), - mul(rowOffset, stridedSmemOffset)); + offs[i] = + b.add(b.mul(contiguousIndexSwizzled, b.i32_val(contiguousMatShape)), + b.mul(rowOffset, stridedSmemOffset)); } } @@ -224,6 +227,7 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value lane, Value cSwizzleOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); Value warpB = multiDimWarpId[0]; Value warpOff = kOrder == 2 ? 
multiDimWarpId[1] : multiDimWarpId[2]; @@ -236,40 +240,44 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value lane, int quadHeight = laneHeight; // outer index base - Value iBase = udiv(lane, i32_val(laneWidth)); + Value iBase = b.udiv(lane, b.i32_val(laneWidth)); for (int rep = 0; rep < numPtrs / (2 * kWidth); ++rep) for (int quadId = 0; quadId < 2; ++quadId) for (int elemId = 0; elemId < kWidth; ++elemId) { // inner index base - Value jBase = mul(urem(lane, i32_val(laneWidth)), i32_val(kWidth)); - jBase = add(jBase, i32_val(elemId)); + Value jBase = + b.mul(b.urem(lane, b.i32_val(laneWidth)), b.i32_val(kWidth)); + jBase = b.add(jBase, b.i32_val(elemId)); // inner index offset - Value jOff = i32_val(0); + Value jOff = b.i32_val(0); if (!needTrans) { - jOff = add(jOff, i32_val(quadId)); - jOff = add(jOff, i32_val(rep * contiguousLoadMatOffset)); + jOff = b.add(jOff, b.i32_val(quadId)); + jOff = b.add(jOff, b.i32_val(rep * contiguousLoadMatOffset)); } // outer index offset - Value iOff = mul(warpOff, i32_val(warpMatOffset)); + Value iOff = b.mul(warpOff, b.i32_val(warpMatOffset)); if (needTrans) { int pStride = kOrder == 2 ? 1 : 2; - iOff = add(iOff, i32_val(quadId * inWarpMatOffset)); - iOff = add(iOff, i32_val(rep * contiguousLoadMatOffset * pStride)); + iOff = b.add(iOff, b.i32_val(quadId * inWarpMatOffset)); + iOff = + b.add(iOff, b.i32_val(rep * contiguousLoadMatOffset * pStride)); } // swizzle if (!needTrans) { - Value phase = urem(udiv(iBase, i32_val(perPhase)), i32_val(maxPhase)); - jOff = add(jOff, udiv(cSwizzleOffset, i32_val(quadWidth))); - jOff = xor_(jOff, phase); + Value phase = + b.urem(b.udiv(iBase, b.i32_val(perPhase)), b.i32_val(maxPhase)); + jOff = b.add(jOff, b.udiv(cSwizzleOffset, b.i32_val(quadWidth))); + jOff = b.xor_(jOff, phase); } else { - Value phase = urem(udiv(jBase, i32_val(perPhase)), i32_val(maxPhase)); - iOff = add(iOff, udiv(cSwizzleOffset, i32_val(quadHeight))); - iOff = xor_(iOff, phase); + Value phase = + b.urem(b.udiv(jBase, b.i32_val(perPhase)), b.i32_val(maxPhase)); + iOff = b.add(iOff, b.udiv(cSwizzleOffset, b.i32_val(quadHeight))); + iOff = b.xor_(iOff, phase); } // To prevent out-of-bound access when tile is too small. 
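// Editorial sketch: the phase computations in the hunks above implement the
// usual XOR swizzle for shared memory: phase = (row / perPhase) % maxPhase,
// and the column index (at matrix/quad granularity) is XORed with that phase
// so successive rows map to different offsets. A toy version of the mapping:

#include <cstdio>

static int swizzle(int row, int col, int perPhase, int maxPhase) {
  int phase = (row / perPhase) % maxPhase;
  return col ^ phase;
}

int main() {
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col)
      std::printf("%d ", swizzle(row, col, /*perPhase=*/1, /*maxPhase=*/4));
    std::printf("\n");
  }
  return 0;
}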
- Value i = add(iBase, mul(iOff, i32_val(quadHeight))); - Value j = add(jBase, mul(jOff, i32_val(quadWidth))); + Value i = b.add(iBase, b.mul(iOff, b.i32_val(quadHeight))); + Value j = b.add(jBase, b.mul(jOff, b.i32_val(quadWidth))); // Compute id of this ptr int idx = rep * 2 * kWidth; if (needTrans) { @@ -282,14 +290,14 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value lane, } if (needTrans) { - offs[idx] = add(i, mul(j, stridedSmemOffset)); + offs[idx] = b.add(i, b.mul(j, stridedSmemOffset)); } else { - offs[idx] = add(mul(i, stridedSmemOffset), j); + offs[idx] = b.add(b.mul(i, stridedSmemOffset), j); } if (tileShape[0] != 1) { - Value batchOffset = - mul(warpB, i32_val(tileShape[order[0]] * tileShape[order[1]])); - offs[idx] = add(batchOffset, offs[idx]); + Value batchOffset = b.mul( + warpB, b.i32_val(tileShape[order[0]] * tileShape[order[1]])); + offs[idx] = b.add(batchOffset, offs[idx]); } } @@ -299,6 +307,7 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value lane, std::tuple MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, Type matTy, Type shemTy) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); int matIdx[3] = {0, mat0, mat1}; @@ -334,18 +343,19 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, } if (canUseLdmatrix) { - Value stridedOffset = - mul(i32_val(matIdx[order[1]] * stridedLoadMatOffset * stridedMatShape), - stridedSmemOffset); + Value stridedOffset = b.mul( + b.i32_val(matIdx[order[1]] * stridedLoadMatOffset * stridedMatShape), + stridedSmemOffset); if (batch != 0) - stridedOffset = add( - stridedOffset, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); - Value readPtr = gep(ptr_ty(ctx, 3), shemTy, ptr, stridedOffset); + stridedOffset = + b.add(stridedOffset, + b.mul(b.i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); + Value readPtr = b.gep(ptr_ty(ctx, 3), shemTy, ptr, stridedOffset); auto ldMatrixOp = rewriter.create(loc, resTy, readPtr, needTrans); auto resV4 = ldMatrixOp.getResult(); - return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1), - extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; + return {b.extract_val(elemTy, resV4, 0), b.extract_val(elemTy, resV4, 1), + b.extract_val(elemTy, resV4, 2), b.extract_val(elemTy, resV4, 3)}; } else { // base pointers // ptrs[k][...] holds `vec` pointers each for (quadK == k) @@ -362,11 +372,11 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, : stridedLoadMatOffset * stridedMatShape; else _i1 += (kOrder == 2 ? 
1 : stridedLoadMatOffset) * stridedMatShape; - Value i0 = mul(i32_val(_i0), stridedSmemOffset); - Value i1 = mul(i32_val(_i1), stridedSmemOffset); + Value i0 = b.mul(b.i32_val(_i0), stridedSmemOffset); + Value i1 = b.mul(b.i32_val(_i1), stridedSmemOffset); if (batch != 0) { - i0 = add(i0, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); - i1 = add(i1, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); + i0 = b.add(i0, b.mul(b.i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); + i1 = b.add(i1, b.mul(b.i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); } // ii[m] holds the offset for (quadM == m) std::array ii = {i0, i1}; @@ -377,7 +387,7 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, // i iterates the 2x2 quads, m-first for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) { - vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); + vptrs[i][j] = b.gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); } // row + trans and col + no-trans are equivalent bool isActualTrans = @@ -392,20 +402,20 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, // Hopper may not contain 32b contiguously along k-dimension int kBits = isHopper ? (8 * elemBytes * kWidth) : 32; int vecSize = kBits / canonBits; - retElems.fill(undef(vec_ty(canonInt, vecSize))); + retElems.fill(b.undef(vec_ty(canonInt, vecSize))); for (int r = 0; r < 2; ++r) { for (int em = 0; em < 2 * vecWidth; em += inc) { int e = em % vecWidth; int m = em / vecWidth; int idx = m * 2 + r; - Value ptr = bitcast(vptrs[idx][e], ptr_ty(ctx, 3)); - Value val = load(packedTy, ptr); - Value canonval = bitcast(val, vec_ty(canonInt, canonWidth)); + Value ptr = b.bitcast(vptrs[idx][e], ptr_ty(ctx, 3)); + Value val = b.load(packedTy, ptr); + Value canonval = b.bitcast(val, vec_ty(canonInt, canonWidth)); for (int w = 0; w < canonWidth; ++w) { int ridx = idx + w * kWidth / vecWidth; - retElems[ridx] = - insert_element(retElems[ridx], - extract_element(canonval, i32_val(w)), i32_val(e)); + retElems[ridx] = b.insert_element( + retElems[ridx], b.extract_element(canonval, b.i32_val(w)), + b.i32_val(e)); } } } @@ -414,8 +424,8 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, auto iTy = isHopper ? int_ty(kBits) : i32_ty; - return {bitcast(retElems[0], iTy), bitcast(retElems[1], iTy), - bitcast(retElems[2], iTy), bitcast(retElems[3], iTy)}; + return {b.bitcast(retElems[0], iTy), b.bitcast(retElems[1], iTy), + b.bitcast(retElems[2], iTy), b.bitcast(retElems[3], iTy)}; } } @@ -515,6 +525,7 @@ Value composeValuesToDotOperandLayoutStruct( const LLVMTypeConverter *typeConverter, Location loc, ConversionPatternRewriter &rewriter, Type eltTy, int kWidth, bool isHopper, bool isA) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto bitwidth = eltTy.getIntOrFloatBitWidth(); assert(32 >= bitwidth && "only support 32-bit or less"); auto numElemsPerVec = isHopper ? 
kWidth : 32 / bitwidth; @@ -534,9 +545,9 @@ Value composeValuesToDotOperandLayoutStruct( auto unpackVec = [&](int b, int m, int k) { for (int kIter = 0; kIter < kIters; ++kIter) { auto val = vals.at({b, m, (k + kIter) % kSize}); - auto vec = bitcast(val, vecTy); + auto vec = tb.bitcast(val, vecTy); for (auto i = 0; i < numElemsPerVec; ++i) { - elems.push_back(extract_element(eltTy, vec, i32_val(i))); + elems.push_back(tb.extract_element(eltTy, vec, tb.i32_val(i))); } } }; @@ -577,6 +588,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, Value lane, ValueTable &vals, bool isA, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto shapePerCTA = getShapePerCTA(descTy); Type eltTy = descTy.getElementType(); // We assumes that the input operand of Dot should be from shared layout. @@ -594,7 +606,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, int nPerWarp = std::max(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8); // (a, b) is the coordinate. - auto load = [=, &rewriter, &vals](int batch, int a, int b) { + auto load = [=, &rewriter, &vals, &tb](int batch, int a, int b) { MMA16816SmemLoader loader( nPerWarp, warpsPerTile, order, mmaLayout.getWarpsPerCTA(), kOrder, kWidth, strides, shapePerCTA /*tileShape*/, instrShape, matShape, @@ -610,7 +622,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, Type smemTy = getSharedMemTy(eltTy); for (int i = 0; i < numPtrs; ++i) ptrs[i] = - gep(ptr_ty(rewriter.getContext(), 3), smemTy, smemBase, offs[i]); + tb.gep(ptr_ty(rewriter.getContext(), 3), smemTy, smemBase, offs[i]); // actually load from shared memory auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(), SmallVector(4, i32_ty)); @@ -640,6 +652,7 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, MemDescType descTy, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread, bool isA) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto mmaLayout = mlir::cast(encoding.getParent()); bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); @@ -659,15 +672,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); auto warpOrder = mmaLayout.getWarpOrder(); - Value warp = udiv(thread, i32_val(32)); - Value lane = urem(thread, i32_val(32)); + Value warp = tb.udiv(thread, tb.i32_val(32)); + Value lane = tb.urem(thread, tb.i32_val(32)); SmallVector multiDimWarpId = delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder); - Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0])); + Value warpB = tb.urem(multiDimWarpId[0], tb.i32_val(shapePerCTA[0])); int warpsPerTile; - Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16)); - Value warpN = urem(multiDimWarpId[2], i32_val(shapePerCTA[2] / 8)); + Value warpM = tb.urem(multiDimWarpId[1], tb.i32_val(shapePerCTA[1] / 16)); + Value warpN = tb.urem(multiDimWarpId[2], tb.i32_val(shapePerCTA[2] / 8)); if (isA) warpsPerTile = std::min(warpsPerCTA[1], shapePerCTA[1] / 16); else diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index c5ec00097d93..24defdf1975e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ 
b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -17,6 +17,7 @@ using ValueTableV2 = std::map, Value>; Value loadC(Value tensor, Value llTensor, const LLVMTypeConverter *typeConverter, Location loc, ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = tensor.getContext(); auto tensorTy = cast(tensor.getType()); size_t fcSize = triton::gpu::getTotalElemsPerThread(tensor.getType()); @@ -42,8 +43,9 @@ Value loadC(Value tensor, Value llTensor, for (int i = 0; i < fcSize; i += numCPackedElem) { Value pack = rewriter.create(loc, cPackTy); for (int j = 0; j < numCPackedElem; ++j) { - pack = insert_element( - cPackTy, pack, extract_val(cElemTy, llTensor, i + j), i32_val(j)); + pack = b.insert_element(cPackTy, pack, + b.extract_val(cElemTy, llTensor, i + j), + b.i32_val(j)); } cPack.push_back(pack); } @@ -62,6 +64,7 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( const LLVMTypeConverter *typeConverter, Location loc, ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter, int repK, RankedTensorType type) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto elems = unpackLLElements(loc, value, rewriter); auto eltTy = typeConverter->convertType(type.getElementType()); int offset{}; @@ -71,11 +74,12 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( auto vecTy = vec_ty(eltTy, numElemsPerVec); auto packVec = [&](std::array dstIdx) { - Value vec = undef(vecTy); + Value vec = b.undef(vecTy); for (auto i = 0; i < numElemsPerVec; ++i) { - vec = insert_element(vec, bitcast(elems[offset + i], eltTy), i32_val(i)); + vec = b.insert_element(vec, b.bitcast(elems[offset + i], eltTy), + b.i32_val(i)); } - vals[dstIdx] = bitcast(vec, i32_ty); + vals[dstIdx] = b.bitcast(vec, i32_ty); offset += numElemsPerVec; }; @@ -469,6 +473,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, Value a, Value b, Value c, Value d, Value loadedA, Value loadedB, Value loadedC, DotOp op, DotOpAdaptor adaptor, bool isTuring) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = c.getContext(); auto aTensorTy = cast(a.getType()); auto bTensorTy = cast(b.getType()); @@ -551,7 +556,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, Type elemTy = cast(mmaOut.getType()).getBody()[0]; for (int i = 0; i < numMmaRets; ++i) { fc[(m * colsPerThread + 4 * n) / numCPackedElem + i + batchOffset * b] = - extract_val(elemTy, mmaOut, i); + tb.extract_val(elemTy, mmaOut, i); } }; @@ -571,8 +576,8 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, for (int j = 0; j < numCPackedElem; ++j) { results[i * numCPackedElem + j] = numCPackedElem > 1 - ? bitcast(extract_element(fc[i], i32_val(j)), resElemTy) - : bitcast(fc[i], resElemTy); + ? 
tb.bitcast(tb.extract_element(fc[i], tb.i32_val(j)), resElemTy) + : tb.bitcast(fc[i], resElemTy); } } Value res = packLLElements(loc, typeConverter, results, rewriter, structTy); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 7450bc3f4e1b..2291308b6d77 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -115,7 +115,7 @@ static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, } desc.strideDimensionBaseOffset = swizzling >> 1; desc.leadDimensionBaseOffset = (swizzling * stride) >> 4; - return int_val(64, desc.descriptor); + return b.int_val(64, desc.descriptor); } mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader( @@ -124,13 +124,14 @@ mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader( int64_t elementBitwidth, ConversionPatternRewriter &rewriter, Location loc) : base(base), shape(shape), warpId(warpId), dimWpt(dimWpt), trans(trans), instrShape(instrShape), elemBits(elementBitwidth) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto ty = cast(tensor.getType()); auto sharedLayout = cast(ty.getEncoding()); ord = sharedLayout.getOrder(); const int perPhase = sharedLayout.getPerPhase(); const int maxPhase = sharedLayout.getMaxPhase(); elemsPerSwizzlingRow = 128 * 8 / perPhase / elemBits; - elemsPerSwizzlingRowVal = i32_val(elemsPerSwizzlingRow); + elemsPerSwizzlingRowVal = b.i32_val(elemsPerSwizzlingRow); uint32_t widthInByte = shape[ord[0]] * elemBits / 8; int64_t swizzling = getSwizzlingFromLayout(sharedLayout, widthInByte); @@ -140,31 +141,35 @@ mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader( Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad( int a, int b, ConversionPatternRewriter &rewriter, Location loc) { - Value k = i32_val(b * instrShape[1]); - Value m = add(i32_val(a * dimWpt * instrShape[0]), - mul(warpId, i32_val(instrShape[0]))); - if (trans) { - std::swap(k, m); - } - Value leading_offset = mul(udiv(k, elemsPerSwizzlingRowVal), - i32_val(shape[ord[1]] * elemsPerSwizzlingRow)); - Value stride_offset = mul(m, elemsPerSwizzlingRowVal); + auto tb = TritonLLVMOpBuilder(loc, rewriter); + Value k = tb.i32_val(b * instrShape[1]); + Value m = tb.add(tb.i32_val(a * dimWpt * instrShape[0]), + tb.mul(warpId, tb.i32_val(instrShape[0]))); + if (trans) { + std::swap(k, m); + } + Value leading_offset = + tb.mul(tb.udiv(k, elemsPerSwizzlingRowVal), + tb.i32_val(shape[ord[1]] * elemsPerSwizzlingRow)); + Value stride_offset = tb.mul(m, elemsPerSwizzlingRowVal); Value offset = - add(add(leading_offset, stride_offset), urem(k, elemsPerSwizzlingRowVal)); + tb.add(tb.add(leading_offset, stride_offset), + tb.urem(k, elemsPerSwizzlingRowVal)); Value off1; // Avoid the runtime udiv if we know the elements are byte multiples if (elemBits % 8) { - off1 = udiv(mul(i32_val(elemBits), offset), i32_val(8)); + off1 = tb.udiv(tb.mul(tb.i32_val(elemBits), offset), tb.i32_val(8)); } else { - off1 = mul(i32_val(elemBits / 8), offset); + off1 = tb.mul(tb.i32_val(elemBits / 8), offset); } - Value off_ = zext(i64_ty, udiv(off1, i32_val(16))); + Value off_ = tb.zext(i64_ty, tb.udiv(off1, tb.i32_val(16))); - Value loadDesc = add(descriptor, off_); + Value loadDesc = tb.add(descriptor, off_); // Add the base at the end to make it easier to do loop invariant code // motion. 
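// Editorial sketch: if I am reading the shift arithmetic in the loadDesc
// update just below correctly, lshr(shl(addr, 46), 50) on an i64 keeps bits
// [17:4] of the shared-memory address, i.e. the 16-byte-aligned base field
// that gets folded into the matrix descriptor. Equivalent formulation, checked
// on the host:

#include <cstdint>
#include <cassert>

int main() {
  uint64_t addr = 0x123456789ABCULL;
  uint64_t viaShifts = (addr << 46) >> 50;
  uint64_t viaMask = (addr >> 4) & 0x3FFF; // bits [17:4]
  assert(viaShifts == viaMask);
  return 0;
}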
- loadDesc = add(loadDesc, lshr(shl(ptrtoint(i64_ty, base), int_val(64, 46)), - int_val(64, 50))); + loadDesc = tb.add( + loadDesc, tb.lshr(tb.shl(tb.ptrtoint(i64_ty, base), tb.int_val(64, 46)), + tb.int_val(64, 50))); return loadDesc; } @@ -172,6 +177,7 @@ DotOpMmaV3SmemLoader loadA(const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, const NvidiaMmaEncodingAttr &mmaEncoding, Value tensor, Value smemObjBase, Value thread) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto aTy = cast(tensor.getType()); auto aSharedLayout = dyn_cast(aTy.getEncoding()); assert(aSharedLayout && "only support load dot operand from shared."); @@ -183,13 +189,13 @@ DotOpMmaV3SmemLoader loadA(const LLVMTypeConverter *typeConverter, // The descriptor should be calculated based on the first warp of the // warpgroup. - Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC)); + Value warp = b.and_(b.udiv(thread, b.i32_val(32)), b.i32_val(0xFFFFFFFC)); // Workaround for a bug in ptxas 12.3 that cause a failure in // test_core.py::test_dot. The shuffle will force the compiler to treat the // value as uniform and prevent wrong optimizations. warp = mlir::LLVM::NVIDIA::shuffleIdx(loc, rewriter, warp, 0); - Value warpM = urem(warp, i32_val(wpt[0])); - Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0])); + Value warpM = b.urem(warp, b.i32_val(wpt[0])); + Value warpId = b.urem(warpM, b.i32_val(shapePerCTA[0] / instrShape[0])); return {tensor, smemObjBase, @@ -207,6 +213,7 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, NvidiaMmaEncodingAttr &mmaEncoding, Value tensor, Value base, Value thread) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto bTy = cast(tensor.getType()); auto bSharedLayout = cast(bTy.getEncoding()); assert(bSharedLayout && "only support load B from shared."); @@ -216,10 +223,10 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, bool transB = bOrd[0] == 1; auto shapePerCTA = triton::gpu::getShapePerCTA(bTy); - Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC)); - Value warpMN = udiv(warp, i32_val(wpt[0])); - Value warpN = urem(warpMN, i32_val(wpt[1])); - Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1])); + Value warp = b.and_(b.udiv(thread, b.i32_val(32)), b.i32_val(0xFFFFFFFC)); + Value warpMN = b.udiv(warp, b.i32_val(wpt[0])); + Value warpN = b.urem(warpMN, b.i32_val(wpt[1])); + Value warpId = b.urem(warpN, b.i32_val(shapePerCTA[1] / instrShape[1])); return {tensor, base, @@ -262,6 +269,7 @@ llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, const SmallVector &elements, int startIndex, int numElements, Operation *insertBefore) { + auto b = TritonLLVMOpBuilder(loc, rewriter); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(insertBefore); @@ -283,9 +291,9 @@ llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Value pack = rewriter.create(loc, packTy); for (int j = 0; j < numElemsPer32Bits; ++j) { Value element = elements[startIndex + i * numElemsPer32Bits + j]; - pack = insert_element(packTy, pack, element, i32_val(j)); + pack = b.insert_element(packTy, pack, element, b.i32_val(j)); } - pack = bitcast(pack, rewriter.getIntegerType(32)); + pack = b.bitcast(pack, rewriter.getIntegerType(32)); mmaOut[i] = pack; } return mmaOut; @@ -296,15 +304,18 @@ SmallVector unpackAccumulator(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &packed, RankedTensorType 
tensorTy) { + auto b = TritonLLVMOpBuilder(loc, rewriter); if (!tensorTy.getElementType().isF16()) return packed; // For fp16 the accumulator is pack into 32-bit integers so we need to unpack // it. SmallVector results; for (Value elem : packed) { - elem = bitcast(elem, vec_ty(rewriter.getF16Type(), 2)); - results.push_back(extract_element(rewriter.getF16Type(), elem, i32_val(0))); - results.push_back(extract_element(rewriter.getF16Type(), elem, i32_val(1))); + elem = b.bitcast(elem, vec_ty(rewriter.getF16Type(), 2)); + results.push_back( + b.extract_element(rewriter.getF16Type(), elem, b.i32_val(0))); + results.push_back( + b.extract_element(rewriter.getF16Type(), elem, b.i32_val(1))); } return results; } @@ -325,19 +336,20 @@ static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc, static SmallVector emitWait(ConversionPatternRewriter &rewriter, Location loc, SmallVector acc, int pendings) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector types(acc.size(), acc[0].getType()); auto structTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); Value llvmStruct = rewriter.create(loc, structTy); int i = 0; for (Value v : acc) { - llvmStruct = insert_val(structTy, llvmStruct, v, i++); + llvmStruct = b.insert_val(structTy, llvmStruct, v, i++); } Value res = rewriter.create(loc, llvmStruct, pendings); SmallVector results; for (int i = 0; i < acc.size(); ++i) { - results.push_back(extract_val(types[0], res, i)); + results.push_back(b.extract_val(types[0], res, i)); } return results; } @@ -349,6 +361,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, Value loadedC, bool allowTF32, bool needsPartialAccumulator, uint32_t maxNumImpreciseAcc, bool sync, Value thread) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto aTensorTy = cast(a.getType()); auto bTensorTy = cast(b.getType()); auto dTensorTy = cast(d.getType()); @@ -421,13 +434,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto accTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); Value d; - Value useC = i1_val(0); + Value useC = tb.i1_val(0); if (!zeroAcc) { d = packLLElements(loc, typeConverter, mmaOut, rewriter, accTy); - useC = i1_val(1); + useC = tb.i1_val(1); } if (useCOperand) - useC = and_(useC, useCOperand); + useC = tb.and_(useC, useCOperand); uint32_t numLowPrecisionAcc = 0; Value partialAcc; for (int k = 0; k < numRepK; ++k) { @@ -460,7 +473,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, mmaAcc = rewriter.create( loc, accTy, a, b, useC, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB, layoutA, layoutB); - useC = i1_val(1); + useC = tb.i1_val(1); if (needsPartialAccumulator) partialAcc = mmaAcc; else diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index d489d0a1b1f4..be110707d017 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -279,11 +279,11 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, Type outType, const int inVecWidthBits = 32, const int outVecWidthBits = 32) { - ConverterT converter = [ptxAsm, inType, outType, inVecWidthBits, outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) -> SmallVector { + auto b = TritonLLVMOpBuilder(loc, rewriter); int numElements = v.size(); assert(numElements == 4 || numElements == 2 
&& "invalid vector size"); @@ -293,12 +293,12 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, // first, we pack `v` into 32-bit ints int inVecWidth = inVecWidthBits / inBitwidth; auto inVecTy = vec_ty(inType, inVecWidth); - SmallVector inPacked(numElements / inVecWidth, undef(inVecTy)); + SmallVector inPacked(numElements / inVecWidth, b.undef(inVecTy)); for (size_t i = 0; i < numElements; i++) - inPacked[i / inVecWidth] = insert_element( - inVecTy, inPacked[i / inVecWidth], v[i], i32_val(i % inVecWidth)); + inPacked[i / inVecWidth] = b.insert_element( + inVecTy, inPacked[i / inVecWidth], v[i], b.i32_val(i % inVecWidth)); for (size_t i = 0; i < inPacked.size(); i++) - inPacked[i] = bitcast(inPacked[i], int_ty(inVecWidthBits)); + inPacked[i] = b.bitcast(inPacked[i], int_ty(inVecWidthBits)); // then, we run the provided inline PTX int outVecWidth = outVecWidthBits / outBitwidth; @@ -325,13 +325,13 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, auto outStructTy = struct_ty(SmallVector(outNums, outVecTy)); auto outStruct = builder.launch(rewriter, loc, outStructTy, false); for (int i = 0; i < outNums; i++) - outPacked.push_back(extract_val(outVecTy, outStruct, i)); + outPacked.push_back(b.extract_val(outVecTy, outStruct, i)); } // unpack the output SmallVector ret; for (size_t i = 0; i < numElements; i++) - ret.push_back(extract_element(outType, outPacked[i / outVecWidth], - i32_val(i % outVecWidth))); + ret.push_back(b.extract_element(outType, outPacked[i / outVecWidth], + b.i32_val(i % outVecWidth))); return ret; }; return converter; @@ -485,6 +485,7 @@ struct FpToFpOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcElementType = getElementType(op.getSrc()); auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); @@ -541,7 +542,7 @@ struct FpToFpOpConversion if (useFP16IntermediateSrc) for (Value &v : inVals) v = convertFp32ToFp16(loc, rewriter, v, RoundingMode::RTZ); - inVals.resize(numElements, undef(typeConverter->convertType(srcType))); + inVals.resize(numElements, b.undef(typeConverter->convertType(srcType))); SmallVector outVals = cvtFunc(loc, rewriter, inVals); assert(outVals.size() == inVals.size()); outVals.resize(std::min(numElements, operands.size())); @@ -787,12 +788,13 @@ struct ExpOpConversionApprox ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); // For non-FP32 input, call __nv_expf for higher-precision calculation if (elemTy.getIntOrFloatBitWidth() != 32) return {}; const double log2e = 1.4426950408889634; - Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e)); + Value prod = b.fmul(f32_ty, operands[0][0], b.f32_val(log2e)); PTXBuilder ptxBuilder; auto &exp2 = ptxBuilder.create("ex2")->o("approx").o("f32"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 3afd34a1fe49..889888e7e505 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -67,8 +67,9 @@ llvm::MapVector getFreeVariableMasks(Type type) { } Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); if (a && b) { - return 
and_(a, b); + return tb.and_(a, b); } return a ? a : b; } @@ -80,6 +81,7 @@ Value emitRedundantThreadPredicate( ModuleOp moduleOp, const llvm::MapVector &freeVarMasks, ConversionPatternRewriter &rewriter, Location loc, const NVIDIA::TargetInfo &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto ctx = rewriter.getContext(); auto kLane = str_attr("lane"); auto kWarp = str_attr("warp"); @@ -89,7 +91,7 @@ Value emitRedundantThreadPredicate( auto emitBlockId = freeVarMasks.lookup(kBlock) != 0; auto [laneId, warpId, blockId] = emitHardwareTuple(loc, rewriter, targetInfo, emitBlockId, warpSize); - auto zero = i32_val(0); + auto zero = b.i32_val(0); Value pred; auto dimNames = {kLane, kWarp, kBlock}; @@ -97,7 +99,7 @@ Value emitRedundantThreadPredicate( for (auto [dimName, dimId] : llvm::zip(dimNames, dimIds)) { int32_t mask = freeVarMasks.lookup(dimName); if (mask != 0) { - auto dimPred = icmp_eq(and_(dimId, i32_val(mask)), zero); + auto dimPred = b.icmp_eq(b.and_(dimId, b.i32_val(mask)), zero); pred = maybeAnd(rewriter, loc, pred, dimPred); } } @@ -177,6 +179,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, ConversionPatternRewriter &rewriter) const override { auto ctx = getContext(); auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto typeConverter = getTypeConverter(); // original values @@ -307,14 +310,14 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, size_t size = width / valueElemNBits; auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); - Value v = undef(vecTy); + Value v = b.undef(vecTy); for (size_t s = 0; s < size; ++s) { Value falseVal = otherElems[vecStart + ii * size + s]; Value sVal = createIndexAttrConstant( rewriter, loc, typeConverter->getIndexType(), s); - v = insert_element(vecTy, v, falseVal, sVal); + v = b.insert_element(vecTy, v, falseVal, sVal); } - v = bitcast(v, IntegerType::get(getContext(), width)); + v = b.bitcast(v, IntegerType::get(getContext(), width)); PTXInstr::Operand *opr{}; @@ -376,19 +379,19 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, for (unsigned int ii = 0; ii < nWords; ++ii) { Value curr; if (isa(retTy)) { - curr = extract_val(IntegerType::get(getContext(), width), ret, ii); + curr = b.extract_val(IntegerType::get(getContext(), width), ret, ii); } else { curr = ret; } - curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy, - width / valueElemNBits)); + curr = b.bitcast(curr, LLVM::getFixedVectorType( + valueElemTy, width / valueElemNBits)); rets.push_back(curr); } int tmp = width / valueElemNBits; for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, typeConverter->getIndexType(), ii % tmp); - Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx); + Value loaded = b.extract_element(valueElemTy, rets[ii / tmp], vecIdx); loadedVals.push_back(loaded); } } // end vec @@ -421,6 +424,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, Value llValue = adaptor.getValue(); auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto valueTy = value.getType(); @@ -489,19 +493,19 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, SmallVector> asmArgs; for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { // llWord is a width-len composition - Value llWord = undef(wordTy); + Value llWord = b.undef(wordTy); // Insert each value element to the composition for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { const size_t elemOffset = 
vecStart + wordIdx * wordNElems + elemIdx; assert(elemOffset < valueElems.size()); Value elem = valueElems[elemOffset]; if (elem.getType().isInteger(1)) - elem = sext(i8_ty, elem); - elem = bitcast(elem, valueElemTy); + elem = b.sext(i8_ty, elem); + elem = b.bitcast(elem, valueElemTy); - llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx)); + llWord = b.insert_element(wordTy, llWord, elem, b.i32_val(elemIdx)); } - llWord = bitcast(llWord, valArgTy); + llWord = b.bitcast(llWord, valArgTy); std::string constraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); asmArgs.emplace_back(llWord, constraint); @@ -545,8 +549,9 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, void createBarrier(ConversionPatternRewriter &rewriter, Location loc, int numCTAs) { + auto b = TritonLLVMOpBuilder(loc, rewriter); if (numCTAs == 1) { - barrier(); + b.barrier(); } else { rewriter.create(loc, false); rewriter.create(loc); @@ -567,6 +572,7 @@ struct AtomicCASOpConversion matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto moduleOp = op->getParentOfType(); @@ -620,11 +626,11 @@ struct AtomicCASOpConversion continue; } - Value casVal = undef(vecTy); + Value casVal = b.undef(vecTy); for (int ii = 0; ii < vec; ++ii) { Value iiVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), ii); - casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal); + casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal); } Value casPtr = ptrElements[i]; @@ -652,7 +658,8 @@ struct AtomicCASOpConversion auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType); for (int ii = 0; ii < vec; ++ii) { resultVals[i + ii] = - vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii)); + vec == 1 ? 
ret + : b.extract_element(valueElemTy, ret, b.i32_val(ii)); } } else { auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); @@ -662,7 +669,7 @@ struct AtomicCASOpConversion } Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); // Only threads with mask = True store the result PTXBuilder ptxBuilderStore; auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); @@ -673,7 +680,7 @@ struct AtomicCASOpConversion auto ASMReturnTy = void_ty(ctx); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); createBarrier(rewriter, loc, numCTAs); - Value ret = load(valueElemTy, atomPtr); + Value ret = b.load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret}); } } @@ -756,6 +763,7 @@ struct AtomicRMWOpConversion matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto moduleOp = op->getParentOfType(); @@ -865,11 +873,11 @@ struct AtomicRMWOpConversion } Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); // Only threads with rmwMask = True store the result targetInfo.storeShared(rewriter, loc, atomPtr, loadAcquireOp, pred); createBarrier(rewriter, loc, numCTAs); - Value ret = load(valueElemTy, atomPtr); + Value ret = b.load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret}); continue; } @@ -901,10 +909,10 @@ struct AtomicRMWOpConversion ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId)); } } else if (packed > 1) { - Value rmwVal = undef(packedTy); + Value rmwVal = b.undef(packedTy); for (int ii = 0; ii < packed; ++ii) { - rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii], - i32_val(ii)); + rmwVal = b.insert_element(packedTy, rmwVal, valElements[i + ii], + b.i32_val(ii)); } valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); } else { @@ -974,11 +982,12 @@ struct AtomicRMWOpConversion if (vec > 1) { for (unsigned ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = extract_val(valueElemTy, ret, ii); + resultVals[i + ii] = b.extract_val(valueElemTy, ret, ii); } } else if (packed > 1) { for (unsigned ii = 0; ii < packed; ++ii) { - resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii)); + resultVals[i + ii] = + b.extract_element(valueElemTy, ret, b.i32_val(ii)); } } else { resultVals[i] = ret; @@ -994,11 +1003,11 @@ struct AtomicRMWOpConversion } Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); // Only threads with rmwMask = True store the result targetInfo.storeShared(rewriter, loc, atomPtr, old, pred); createBarrier(rewriter, loc, numCTAs); - Value ret = load(valueElemTy, atomPtr); + Value ret = b.load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret}); } } @@ -1027,6 +1036,7 @@ struct AsyncCopyGlobalToLocalOpConversion ConversionPatternRewriter &rewriter) const override { auto ctx = getContext(); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Value res = op.getResult(); Value mask = op.getMask(); Value other = op.getOther(); @@ -1144,7 +1154,7 @@ struct AsyncCopyGlobalToLocalOpConversion // op.getMask() and redundantDataMask() into the same predicate, the // way it is 
done for LoadOp. auto selectOp = - select(maskElems[elemIdx], i32_val(wordBytes), i32_val(0)); + b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0)); srcSize = ptxBuilder.newOperand(selectOp, "r"); } @@ -1180,6 +1190,7 @@ struct AsyncTMACopyGlobalToLocalOpConversion return op.emitError("volatile not supported yet"); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Type llvmElemTy = typeConverter->convertType(op.getResult().getType().getElementType()); auto barrierMemObj = LLVM::getSharedMemoryObjectFromStruct( @@ -1198,7 +1209,7 @@ struct AsyncTMACopyGlobalToLocalOpConversion Value pred = adaptor.getPred(); // Select just one thread for the TMA copy. This also helps the compiler to // figure out that the op is uniform. - pred = and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter)); + pred = b.and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter)); int elementSizeInBytes = op.getResult().getType().getElementType().getIntOrFloatBitWidth() / 8; @@ -1220,16 +1231,16 @@ struct AsyncTMACopyGlobalToLocalOpConversion for (int copyIdx = 0; copyIdx < numCopies; copyIdx += numWarps) { int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); if (numWarpsToCopy == 1) - warpID = i32_val(0); + warpID = b.i32_val(0); Value boxPred = - and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize))); + b.and_(pred, b.icmp_ult(id, b.i32_val(numWarpsToCopy * warpSize))); ::mlir::triton::PTXBuilder ptxBuilderTMA; Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); - Value copyIdxVal = add(warpID, i32_val(copyIdx)); + Value copyIdxVal = b.add(warpID, b.i32_val(copyIdx)); Value shMemOffset = - mul(copyIdxVal, i32_val(totalNumElements / numCopies)); + b.mul(copyIdxVal, b.i32_val(totalNumElements / numCopies)); Value shMemPtr = - gep(elemPtrTy, llvmElemTy, dstMemObj.getBase(), shMemOffset); + b.gep(elemPtrTy, llvmElemTy, dstMemObj.getBase(), shMemOffset); SmallVector operands = { ptxBuilderTMA.newOperand(boxPred, "b"), ptxBuilderTMA.newOperand(shMemPtr, "r"), @@ -1241,8 +1252,8 @@ struct AsyncTMACopyGlobalToLocalOpConversion for (int i = 0; i < rank; i++) { Value coord = adaptor.getCoord()[rank - i - 1]; if (i == 0) { - Value offset = mul(copyIdxVal, i32_val(128 / elementSizeInBytes)); - coord = add(coord, offset); + Value offset = b.mul(copyIdxVal, b.i32_val(128 / elementSizeInBytes)); + coord = b.add(coord, offset); } operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); tmaInst += "$" + std::to_string(operandIdx++); @@ -1272,6 +1283,7 @@ struct AsyncTMACopyLocalToGlobalOpConversion OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Type llvmElemTy = typeConverter->convertType(op.getSrc().getType().getElementType()); auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( @@ -1305,16 +1317,16 @@ struct AsyncTMACopyLocalToGlobalOpConversion for (int copyIdx = 0; copyIdx < numCopies; copyIdx += numWarps) { int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); if (numWarpsToCopy == 1) - warpID = i32_val(0); + warpID = b.i32_val(0); Value boxPred = - and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize))); + b.and_(pred, b.icmp_ult(id, b.i32_val(numWarpsToCopy * warpSize))); ::mlir::triton::PTXBuilder ptxBuilderTMA; Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); - Value copyIdxVal = add(warpID, i32_val(copyIdx)); + Value copyIdxVal = b.add(warpID, b.i32_val(copyIdx)); Value shMemOffset = - mul(copyIdxVal, i32_val(totalNumElements / numCopies)); + 
b.mul(copyIdxVal, b.i32_val(totalNumElements / numCopies)); Value shMemPtr = - gep(elemPtrTy, llvmElemTy, dstMemObj.getBase(), shMemOffset); + b.gep(elemPtrTy, llvmElemTy, dstMemObj.getBase(), shMemOffset); SmallVector operands = { ptxBuilderTMA.newOperand(boxPred, "b"), ptxBuilderTMA.newOperand(adaptor.getDescPtr(), "l")}; @@ -1324,8 +1336,8 @@ struct AsyncTMACopyLocalToGlobalOpConversion for (int i = 0; i < rank; i++) { Value coord = adaptor.getCoord()[rank - i - 1]; if (i == 0) { - Value offset = mul(copyIdxVal, i32_val(128 / elementSizeInBytes)); - coord = add(coord, offset); + Value offset = b.mul(copyIdxVal, b.i32_val(128 / elementSizeInBytes)); + coord = b.add(coord, offset); } operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); tmaInst += "$" + std::to_string(operandIdx++); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 8fa340cb14d5..3166d4d525a7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -81,6 +81,7 @@ struct LocalLoadOpConversion ConversionPatternRewriter &rewriter) const { auto ctx = rewriter.getContext(); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto dstTy = cast(op.getType()); auto srcTy = cast(op.getSrc().getType()); auto dotEnc = cast(dstTy.getEncoding()); @@ -97,7 +98,6 @@ struct LocalLoadOpConversion chooseLdMatrixLayout(dotEnc, shape, needTrans, bitwidth); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); - // Emit ldmatrix load operations for values packed in i32s SmallVector elemsI32; // Typically we load 32x8 to use ldmatrix.x4, but the minimum tile size for @@ -117,7 +117,7 @@ struct LocalLoadOpConversion loc, matTy, vecAddr, /*needTrans=*/needTrans); auto res = ldMatrixOp.getResult(); for (auto i = 0; i < numElemsI32; ++i) { - elemsI32.push_back(extract_val(i32_ty, res, i)); + elemsI32.push_back(b.extract_val(i32_ty, res, i)); } }); assert(valid && "Failed to emit ldmatrix load operations"); @@ -127,9 +127,9 @@ struct LocalLoadOpConversion auto numElemsPerVec = 32 / bitwidth; auto vecTy = vec_ty(llvmElemTy, numElemsPerVec); for (int v = 0; v < static_cast(elemsI32.size()); ++v) { - auto vec = bitcast(elemsI32[v], vecTy); + auto vec = b.bitcast(elemsI32[v], vecTy); for (int i = 0; i < numElemsPerVec; ++i) - elems.push_back(extract_element(llvmElemTy, vec, i32_val(i))); + elems.push_back(b.extract_element(llvmElemTy, vec, b.i32_val(i))); } auto structTy = LLVM::LLVMStructType::getLiteral( @@ -148,6 +148,7 @@ LogicalResult lowerDistributedToSharedStmatrix( Value adaptorSrc, Value smemBase, const TypeConverter *typeConverter, ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, std::pair *const llvmOpCount = nullptr) { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto mmaEncoding = dyn_cast(src.getType().getEncoding()); if (!mmaEncoding) @@ -187,15 +188,15 @@ LogicalResult lowerDistributedToSharedStmatrix( auto kBlock = str_attr("block"); Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); + Value threadsPerWarp = b.i32_val(layout.getInDimSize(kLane)); + Value laneId = b.urem(threadId, threadsPerWarp); + Value warpId = b.udiv(threadId, threadsPerWarp); auto regBase = applyLinearLayout(loc, rewriter, layout, - 
{{kRegister, i32_val(0)}, + {{kRegister, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, - {kBlock, i32_val(0)}})[0] + {kBlock, b.i32_val(0)}})[0] .second; auto srcVals = unpackLLElements(loc, adaptorSrc, rewriter); auto srcVec = layout.getNumConsecutiveInOut(); @@ -203,8 +204,8 @@ LogicalResult lowerDistributedToSharedStmatrix( auto regIdx = layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] .second; - Value offset = xor_(regBase, i32_val(regIdx)); - auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); + Value offset = b.xor_(regBase, b.i32_val(regIdx)); + auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset); vecAddr.setInbounds(true); SmallVector inValsVec; for (int j = 0; j < srcVec; j++) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp index 826c6787242b..8d87993ffdbe 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp @@ -21,6 +21,7 @@ void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value outPtr, Value inPtr) { PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); // prepare asm operands auto *outAddrOpr = ptxBuilder.newAddrOperand(outPtr, "l"); @@ -35,7 +36,7 @@ void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, // Execute collectively on first warp in block constexpr int kWarpSize = 32; Value threadId = getThreadId(rewriter, loc); - Value pred = icmp_slt(threadId, i32_val(kWarpSize)); + Value pred = b.icmp_slt(threadId, b.i32_val(kWarpSize)); cp(outAddrOpr, inAddrOpr, sizeOpr).predicate(pred); ptxBuilder.launch(rewriter, loc, void_ty(ctx)); @@ -46,6 +47,7 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, std::string fieldName, Value descPtr, int32_t newVal) { PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); // prepare asm operands auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, "l"); @@ -59,7 +61,7 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, .o("b32"); Value threadId = getThreadId(rewriter, loc); - Value pred = icmp_eq(threadId, i32_val(0)); + Value pred = b.icmp_eq(threadId, b.i32_val(0)); replace(descAddrOpr, newValOpr).predicate(pred); ptxBuilder.launch(rewriter, loc, void_ty(ctx)); @@ -71,6 +73,7 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, Value newVal, std::optional ord = std::nullopt) { PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); auto newValTy = newVal.getType(); int width = 0; @@ -98,7 +101,7 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, .o("b64", width == 64); Value threadId = getThreadId(rewriter, loc); - Value pred = icmp_eq(threadId, i32_val(0)); + Value pred = b.icmp_eq(threadId, b.i32_val(0)); if (ord) { replace(descAddrOpr, ordOpr, newValOpr).predicate(pred); @@ -189,6 +192,7 @@ struct ExperimentalTensormapFenceproxyAcquireOpConversion auto loc = op.getLoc(); PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); // prepare asm operands auto *descAddrOpr = ptxBuilder.newAddrOperand(adaptor.getDescPtr(), "l"); @@ -197,7 +201,7 @@ struct ExperimentalTensormapFenceproxyAcquireOpConversion // Define the instruction opcode constexpr int kWarpSize = 32; Value threadId = getThreadId(rewriter, loc); - Value pred = icmp_slt(threadId, i32_val(kWarpSize)); + Value pred = b.icmp_slt(threadId, b.i32_val(kWarpSize)); auto &fence = 
*ptxBuilder.create<>("fence.proxy.tensormap::generic.acquire.gpu"); fence(descAddrOpr, sizeOpr).predicate(pred); @@ -217,13 +221,15 @@ struct ExperimentalTensormapFenceproxyAcquireOpConversion void zero_fill_tma(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, const NVIDIA::TargetInfo &targetInfo, Value descPtr) { + auto b = TritonLLVMOpBuilder(loc, rewriter); // Write out zeros constexpr int kWarpSize = 32; Value threadId = getThreadId(rewriter, loc); - Value pred = icmp_slt(threadId, i32_val(kWarpSize)); + Value pred = b.icmp_slt(threadId, b.i32_val(kWarpSize)); - auto fillVal = i32_val(0); - auto writeAddr = gep(descPtr.getType(), fillVal.getType(), descPtr, threadId); + auto fillVal = b.i32_val(0); + auto writeAddr = + b.gep(descPtr.getType(), fillVal.getType(), descPtr, threadId); targetInfo.storeShared(rewriter, loc, writeAddr, fillVal, pred); // Sync warp @@ -247,6 +253,7 @@ struct ExperimentalTensormapCreateOpConversion matchAndRewrite(triton::ExperimentalTensormapCreateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto ctx = getContext(); bool needsStrideWorkaround = targetInfo.getPtxVersion() <= 85; @@ -268,7 +275,7 @@ struct ExperimentalTensormapCreateOpConversion auto strideVal = op.getGlobalStride()[i]; if (needsStrideWorkaround) { // Workaround for a ptxas bug - strideVal = ashr(strideVal, i64_val(4)); + strideVal = b.ashr(strideVal, b.i64_val(4)); } tensormap_replace_global_stride(loc, ctx, rewriter, smemBase, i, strideVal); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 7c4a9e5b92df..e1437ee34570 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -40,19 +40,20 @@ std::pair printfPromoteValue(RewriterBase &rewriter, Value value) { Value newOp = value; Type newType = type; auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); bool isUnsigned = type.isUnsignedInteger(); if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { if (isUnsigned) { newType = ui32_ty; - newOp = zext(newType, value); + newOp = b.zext(newType, value); } else { newType = i32_ty; - newOp = sext(newType, value); + newOp = b.sext(newType, value); } } else if (type.isBF16() || type.isF16() || type.isF32()) { newType = f64_ty; - newOp = fpext(newType, value); + newOp = b.fpext(newType, value); } return {newType, newOp}; @@ -126,7 +127,8 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const { - Value threadMask = int_val(type.getIntOrFloatBitWidth(), -1); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadMask = b.int_val(type.getIntOrFloatBitWidth(), -1); return rewriter.create(loc, type, threadMask, cmp); } @@ -164,6 +166,7 @@ static bool isConstantTruePred(Value pred) { void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Value val, Value pred) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); @@ -185,7 +188,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, "don't know how to load/store vectors of 
sub-byte elems"); SmallVector vals = unpackLLVector(loc, val, rewriter); for (Value &v : vals) { - v = zext(int_ty(8), bitcast(v, int_ty(elemBitwidth))); + v = b.zext(int_ty(8), b.bitcast(v, int_ty(elemBitwidth))); } storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), pred); @@ -195,7 +198,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, if (!elemTy.isInteger()) { SmallVector vals = unpackLLVector(loc, val, rewriter); for (Value &v : vals) { - v = bitcast(v, int_ty(elemBitwidth)); + v = b.bitcast(v, int_ty(elemBitwidth)); } storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), pred); @@ -216,7 +219,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, Value v = packLLVector( loc, ArrayRef(oldVals).slice(i * elemsPerPack, elemsPerPack), rewriter); - newVals.push_back(bitcast(v, i32_ty)); + newVals.push_back(b.bitcast(v, i32_ty)); } storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, newVals, rewriter), pred); @@ -231,8 +234,8 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, auto newVecTy = vec_ty(elemTy, maxVec); SmallVector vals = unpackLLVector(loc, val, rewriter); for (int i = 0; i < vec / maxVec; i++) { - auto newPtr = gep(ptr.getType(), elemTy, ptr, i32_val(i * maxVec), - /*inbounds=*/true); + auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), + /*inbounds=*/true); storeDShared( rewriter, loc, newPtr, ctaId, packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), @@ -265,7 +268,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, if (vec > 1) { SmallVector> vecVals; for (int i = 0; i < vec; i++) { - vecVals.push_back({extract_element(val, i32_val(i)), constraint}); + vecVals.push_back({b.extract_element(val, b.i32_val(i)), constraint}); } valOpr = builder.newListOperand(vecVals); } else { @@ -278,6 +281,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Type loadTy, Value pred) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); @@ -302,7 +306,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, SmallVector vals = unpackLLVector( loc, loadDShared(rewriter, loc, ptr, ctaId, int_ty(8), pred), rewriter); assert(vals.size() == 1); - return bitcast(trunc(int_ty(elemBitwidth), vals[0]), elemTy); + return b.bitcast(b.trunc(int_ty(elemBitwidth), vals[0]), elemTy); } // We only know how to load integers. @@ -311,7 +315,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, SmallVector vals = unpackLLVector( loc, loadDShared(rewriter, loc, ptr, ctaId, newLoadTy, pred), rewriter); for (Value &v : vals) { - v = bitcast(v, elemTy); + v = b.bitcast(v, elemTy); } return packLLVector(loc, vals, rewriter); } @@ -328,7 +332,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, // Unpack the b32's into the original vector type. 
SmallVector vals; for (Value v : unpackLLVector(loc, res, rewriter)) { - Value vv = bitcast(v, vec_ty(elemTy, 32 / elemBitwidth)); + Value vv = b.bitcast(v, vec_ty(elemTy, 32 / elemBitwidth)); for (Value vvv : unpackLLVector(loc, vv, rewriter)) { vals.push_back(vvv); } @@ -343,8 +347,8 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, SmallVector vals; for (int i = 0; i < vec / maxVec; i++) { - auto newPtr = gep(ptr.getType(), elemTy, ptr, i32_val(i * maxVec), - /*inbounds=*/true); + auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), + /*inbounds=*/true); auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, vec_ty(elemTy, maxVec), pred); for (Value v : unpackLLVector(loc, newVal, rewriter)) { @@ -376,13 +380,13 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, if (isConstantTruePred(pred)) { Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) : Type(vec_ty(int_ty(elemBitwidth), vec)); - load = load(resultTy, ptr); + load = b.load(resultTy, ptr); if (vec > 1) { Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); - Value structValue = undef(structTy); + Value structValue = b.undef(structTy); for (int i = 0; i < vec; i++) { - structValue = insert_val(structTy, structValue, - extract_element(load, i32_val(i)), i); + structValue = b.insert_val(structTy, structValue, + b.extract_element(load, b.i32_val(i)), i); } load = structValue; } @@ -430,13 +434,14 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); if (auto kind = matchReduxKind(op, computeCapability)) { // Based on benchmarking on A100 redux op gives a speed up only when doing // a single reduction (not partitioned) and when the mask is static. // Therefore we currently only enable it to reduce across all the lanes. if (numLaneToReduce == 32) { assert(acc.size() == 1); - Value mask = i32_val(0xFFFFFFFF); + Value mask = b.i32_val(0xFFFFFFFF); // Even though we currently don't use redux for partitioned reduction // the code below supports it in case we want to tweak the heuristic. if (numLaneToReduce < 32) { @@ -444,22 +449,22 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, // each group of numLaneToReduce threads has the correct mask. 
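// Host-side sketch of the redux.sync membermask computed just below: each
// group of numLaneToReduce lanes gets a run of ones aligned to its group base.
// It mirrors the hunk's bitmask/laneId arithmetic; numLaneToReduce must be a
// power of two and, as in the code above, this path only applies when
// numLaneToReduce < 32 (the full-warp case uses the constant 0xFFFFFFFF).
#include <cstdint>
#include <cstdio>

uint32_t reduxMask(uint32_t laneId, uint32_t numLaneToReduce) {
  uint32_t bitmask = (1u << numLaneToReduce) - 1u;       // numLaneToReduce ones
  uint32_t groupBase = laneId & ~(numLaneToReduce - 1u); // align lane to group
  return bitmask << groupBase;
}

int main() {
  // Lane 5 reducing in groups of 4 -> mask 0xf0 (lanes 4..7 participate).
  std::printf("0x%x\n", reduxMask(5, 4));
}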
unsigned bitmask = (1 << numLaneToReduce) - 1; Value threadId = getThreadId(rewriter, loc); - Value laneId = urem(threadId, i32_val(32)); - mask = shl(i32_val(bitmask), - and_(laneId, i32_val(~(numLaneToReduce - 1)))); + Value laneId = b.urem(threadId, b.i32_val(32)); + mask = b.shl(b.i32_val(bitmask), + b.and_(laneId, b.i32_val(~(numLaneToReduce - 1)))); } for (unsigned i = 0; i < acc.size(); ++i) { unsigned bitwidth = cast(acc[i].getType()).getWidth(); if (bitwidth < 32) { if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX) - acc[i] = sext(i32_ty, acc[i]); + acc[i] = b.sext(i32_ty, acc[i]); else - acc[i] = zext(i32_ty, acc[i]); + acc[i] = b.zext(i32_ty, acc[i]); } acc[i] = rewriter.create(loc, acc[i].getType(), acc[0], *kind, mask); if (bitwidth < 32) - acc[i] = trunc(int_ty(bitwidth), acc[i]); + acc[i] = b.trunc(int_ty(bitwidth), acc[i]); } return true; } @@ -509,6 +514,7 @@ bool TargetInfo::canUseStMatrix(RankedTensorType tensorTy, void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto vals = unpackLLVector(loc, val, rewriter); // Ensure input consists of 4 vectors, each holding 2 elements of 16 bits assert(vals[0].getType().getIntOrFloatBitWidth() == 16 && @@ -518,11 +524,11 @@ void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, Type packedTy = vec_ty(vals[0].getType(), 2); SmallVector inputs; for (int i = 0; i < 4; i++) { - Value input = undef(packedTy); + Value input = b.undef(packedTy); for (int j = 0; j < 2; j++) { - input = insert_element(packedTy, input, vals[i * 2 + j], i32_val(j)); + input = b.insert_element(packedTy, input, vals[i * 2 + j], b.i32_val(j)); } - inputs.push_back(bitcast(input, i32_ty)); + inputs.push_back(b.bitcast(input, i32_ty)); } rewriter.create(loc, ptr, inputs); } @@ -540,11 +546,12 @@ void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto funcOp = getVprintfDeclaration(rewriter); auto loc = UnknownLoc::get(ctx); + auto b = TritonLLVMOpBuilder(loc, rewriter); - Value one = i32_val(1); - Value zero = i32_val(0); + Value one = b.i32_val(1); + Value zero = b.i32_val(0); - Value bufferPtr = null(ptr); + Value bufferPtr = b.null(ptr); SmallVector newArgs; if (args.size() >= 1) { @@ -563,16 +570,16 @@ void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, /*alignment=*/0); for (const auto &entry : llvm::enumerate(newArgs)) { - auto index = i32_val(entry.index()); + auto index = b.i32_val(entry.index()); auto fieldPtr = - gep(ptr_ty(ctx), structTy, allocated, ArrayRef{zero, index}); - store(entry.value(), fieldPtr); + b.gep(ptr_ty(ctx), structTy, allocated, ArrayRef{zero, index}); + b.store(entry.value(), fieldPtr); } - bufferPtr = bitcast(allocated, ptr); + bufferPtr = b.bitcast(allocated, ptr); } SmallVector operands{formatStrStart, bufferPtr}; - call(funcOp, operands); + b.call(funcOp, operands); } void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, @@ -590,6 +597,7 @@ void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); auto funcOp = getAssertfailDeclaration(rewriter); auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); llvm::SmallString<64> messageString(message), fileString(file), @@ -603,11 +611,11 @@ 
void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, LLVM::addStringToModule(loc, rewriter, "assertFile_", fileString); Value funcStringVal = LLVM::addStringToModule(loc, rewriter, "assertFunc_", funcString); - Value lineNumber = i32_val(line); - Value charSize = int_val(sizeof(size_t) * 8, sizeof(char)); + Value lineNumber = b.i32_val(line); + Value charSize = b.int_val(sizeof(size_t) * 8, sizeof(char)); SmallVector operands = {messageStringVal, fileStringVal, lineNumber, funcStringVal, charSize}; - call(funcOp, operands); + b.call(funcOp, operands); } int TargetInfo::getSharedAddressSpace() const { return 3; } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp index c117eb176431..97e830967093 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -70,6 +70,7 @@ struct AdvanceOpConversion : public ConvertOpToLLVMPattern { // struct { offset0, offset1, shape0, shape1, stride0, // stride1, base_ptr}; auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto ptrType = op.getPtr().getType(); auto tensorPtr = adaptor.getPtr(); @@ -79,7 +80,7 @@ struct AdvanceOpConversion : public ConvertOpToLLVMPattern { SmallVector newOffsets; for (auto [offset, oldOffset] : llvm::zip_first(offsets, elems)) { - newOffsets.push_back((add(offset, oldOffset))); + newOffsets.push_back((b.add(offset, oldOffset))); } for (size_t i = 0; i < newOffsets.size(); ++i) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 7b20e8527752..3dbfbc311fc3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -85,6 +85,7 @@ static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter, static SmallVector convertFP4x2To16x2(RewriterBase &rewriter, Location loc, Type targetTy, ArrayRef values) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector results; MLIRContext *ctx = rewriter.getContext(); bool isFP16 = targetTy == f16_ty; @@ -96,21 +97,21 @@ static SmallVector convertFP4x2To16x2(RewriterBase &rewriter, Value v1 = values[i + 1]; Value v2 = values[i + 2]; Value v3 = values[i + 3]; - Value packedVec = undef(vec_ty(i8_ty, 4)); - packedVec = insert_element(packedVec, v0, i32_val(0)); - packedVec = insert_element(packedVec, v1, i32_val(1)); - packedVec = insert_element(packedVec, v2, i32_val(2)); - packedVec = insert_element(packedVec, v3, i32_val(3)); + Value packedVec = b.undef(vec_ty(i8_ty, 4)); + packedVec = b.insert_element(packedVec, v0, b.i32_val(0)); + packedVec = b.insert_element(packedVec, v1, b.i32_val(1)); + packedVec = b.insert_element(packedVec, v2, b.i32_val(2)); + packedVec = b.insert_element(packedVec, v3, b.i32_val(3)); SmallVector rets(4, i32_ty); Type retType = struct_ty(rets); const char *upcastPtx = isFP16 ? 
FP4ToFP16Ptx : FP4ToBF16Ptx; Value ret = createInlineAsmUpcast(loc, rewriter, retType, packedVec, upcastPtx); for (int i = 0; i < 4; i++) { - Value extractI32 = extract_val(ret, i); - Value vecbf16 = bitcast(extractI32, vec_ty(targetTy, 2)); - results.push_back(extract_element(vecbf16, i32_val(0))); - results.push_back(extract_element(vecbf16, i32_val(1))); + Value extractI32 = b.extract_val(ret, i); + Value vecbf16 = b.bitcast(extractI32, vec_ty(targetTy, 2)); + results.push_back(b.extract_element(vecbf16, b.i32_val(0))); + results.push_back(b.extract_element(vecbf16, b.i32_val(1))); } } return results; @@ -118,21 +119,22 @@ static SmallVector convertFP4x2To16x2(RewriterBase &rewriter, Value mxfpScale(RewriterBase &rewriter, Location loc, Value v, Value scale, Type fp_ty, bool fastMath) { + auto b = TritonLLVMOpBuilder(loc, rewriter); Value scaleFP; if (fp_ty == bf16_ty) { - scaleFP = bitcast(shl(zext(i16_ty, scale), i16_val(7)), fp_ty); + scaleFP = b.bitcast(b.shl(b.zext(i16_ty, scale), b.i16_val(7)), fp_ty); } else { assert(fp_ty == f16_ty); - scaleFP = - bitcast(shl(zext(i32_ty, scale), i32_val(23)), rewriter.getF32Type()); - scaleFP = fptrunc(fp_ty, scaleFP); + scaleFP = b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), + rewriter.getF32Type()); + scaleFP = b.fptrunc(fp_ty, scaleFP); } - Value scaledV = fmul(bitcast(v, fp_ty), scaleFP); + Value scaledV = b.fmul(b.bitcast(v, fp_ty), scaleFP); if (fastMath) return scaledV; // Account for NaN in the scale as per the mxfp specification. - Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); - return select(scaleIsNan, bitcast(i16_val(0x7fff), fp_ty), scaledV); + Value scaleIsNan = b.icmp_eq(scale, b.i8_val(0xff)); + return b.select(scaleIsNan, b.bitcast(b.i16_val(0x7fff), fp_ty), scaledV); }; namespace { @@ -151,6 +153,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto tyX = cast(op->getOperandTypes()[0]); auto operands = adaptor.getOperands(); @@ -159,12 +162,12 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { auto fpType = op.getFpType(); auto outType = op.getType().getElementType(); - Value tid = tid_val(); + Value tid = b.tid_val(); auto mod = op->getParentOfType(); Value warpSize = - i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); - Value warpId = udiv(tid, warpSize); - Value laneId = urem(tid, warpSize); + b.i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + Value warpId = b.udiv(tid, warpSize); + Value laneId = b.urem(tid, warpSize); auto kWidth = cast(op.getType().getEncoding()).getKWidth(); @@ -176,9 +179,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { // Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2 // Then, we need elements c and c + 16 for the first two mxfp vectors // and elements c + 1 and c + 17 for the last two mxfp vectors - auto c = mul(udiv(laneId, i32_val(4)), i32_val(2)); - std::array ci = {c, add(c, i32_val(16)), add(c, i32_val(1)), - add(c, i32_val(17))}; + auto c = b.mul(b.udiv(laneId, b.i32_val(4)), b.i32_val(2)); + std::array ci = {c, b.add(c, b.i32_val(16)), + b.add(c, b.i32_val(1)), b.add(c, b.i32_val(17))}; // TODO Move this logic to using LinearLayouts // Each scale in a warp has to be replicated to cover a tile of shape mxk = diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 
044d420b1cb3..1dcc6e5370f2 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -10,66 +10,72 @@ using namespace mlir::triton; static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, Value val, Value i, NVVM::ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); unsigned bits = val.getType().getIntOrFloatBitWidth(); if (bits == 64) { Type vecTy = vec_ty(f32_ty, 2); - Value vec = bitcast(val, vecTy); - Value val0 = extract_element(f32_ty, vec, i32_val(0)); - Value val1 = extract_element(f32_ty, vec, i32_val(1)); + Value vec = b.bitcast(val, vecTy); + Value val0 = b.extract_element(f32_ty, vec, b.i32_val(0)); + Value val1 = b.extract_element(f32_ty, vec, b.i32_val(1)); val0 = shuffleCommonImpl(loc, rewriter, val0, i, mode, clamp); val1 = shuffleCommonImpl(loc, rewriter, val1, i, mode, clamp); - vec = undef(vecTy); - vec = insert_element(vecTy, vec, val0, i32_val(0)); - vec = insert_element(vecTy, vec, val1, i32_val(1)); - return bitcast(vec, val.getType()); + vec = b.undef(vecTy); + vec = b.insert_element(vecTy, vec, val0, b.i32_val(0)); + vec = b.insert_element(vecTy, vec, val1, b.i32_val(1)); + return b.bitcast(vec, val.getType()); } Type type = val.getType(); if (type != i32_ty) { - val = bitcast(val, int_ty(bits)); + val = b.bitcast(val, int_ty(bits)); if (bits < 32) - val = zext(i32_ty, val); + val = b.zext(i32_ty, val); } - Value mask = i32_val(0xFFFFFFFF); + Value mask = b.i32_val(0xFFFFFFFF); Value result = rewriter.create(loc, i32_ty, mask, val, i, clamp, mode, UnitAttr()); if (type != i32_ty) { if (bits < 32) - result = trunc(int_ty(bits), result); - result = bitcast(result, type); + result = b.trunc(int_ty(bits), result); + result = b.bitcast(result, type); } return result; } static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value i, NVVM::ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); // To shuffle pointers, convert them to i64. 
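// Plain C++ sketch of the strategy used by shuffleCommonImpl/shuffleCommon in
// this file: pointers are first converted to i64 (ptrtoint/inttoptr), and a
// 64-bit payload is split into two 32-bit halves, each half shuffled
// independently, then recombined. shuffleHalf is a hypothetical stand-in for
// the NVVM shfl.sync op the real code emits; memcpy stands in for bitcast.
#include <cstdint>
#include <cstdio>
#include <cstring>

template <typename ShuffleHalf>
uint64_t shuffle64(uint64_t val, ShuffleHalf shuffleHalf) {
  uint32_t half[2];
  std::memcpy(half, &val, sizeof(val)); // bitcast i64 -> <2 x i32>
  half[0] = shuffleHalf(half[0]);
  half[1] = shuffleHalf(half[1]);
  uint64_t out;
  std::memcpy(&out, half, sizeof(out)); // bitcast <2 x i32> -> i64
  return out;
}

int main() {
  // Identity "shuffle" just to show the split/recombine round trip.
  uint64_t v = shuffle64(0x1122334455667788ull, [](uint32_t h) { return h; });
  std::printf("0x%llx\n", (unsigned long long)v);
}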
Type valTy = val.getType(); if (isa(valTy)) - val = ptrtoint(i64_ty, val); + val = b.ptrtoint(i64_ty, val); Value result = shuffleCommonImpl(loc, rewriter, val, i, mode, clamp); if (isa(valTy)) - result = inttoptr(valTy, result); + result = b.inttoptr(valTy, result); return result; } Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::bfly, - i32_val(0x1f)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), NVVM::ShflKind::bfly, + b.i32_val(0x1f)); } Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::up, - i32_val(0x0)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), NVVM::ShflKind::up, + b.i32_val(0x0)); } Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleIdx(loc, rewriter, val, i32_val(i)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleIdx(loc, rewriter, val, b.i32_val(i)); } Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); return shuffleCommon(loc, rewriter, val, i, NVVM::ShflKind::idx, - i32_val(0x1f)); + b.i32_val(0x1f)); } Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, From debf809c0278fe997b720f696fb118d66d4e3c77 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 27 Jan 2025 19:12:40 -0500 Subject: [PATCH 2/3] comments --- .../Conversion/TritonGPUToLLVM/Utility.h | 213 +++++++++--------- 1 file changed, 104 insertions(+), 109 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index b41de1ae9a73..cc23f8eee856 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -6,7 +6,6 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" @@ -31,243 +30,247 @@ using namespace mlir; using namespace mlir::triton; namespace mlir::triton { + +// Returns CTA level thread idx +inline Value getThreadId(OpBuilder &rewriter, Location loc) { + Value tid = + rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); + Type i32_ty = rewriter.getIntegerType(32); + return rewriter.create(loc, i32_ty, tid); +} + struct TritonLLVMOpBuilder { - TritonLLVMOpBuilder(const Location &loc, RewriterBase &builder) - : loc(loc), builder(builder) {} + TritonLLVMOpBuilder(const Location &loc, OpBuilder &builder) + : loc(loc), builder(&builder) {} // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive // Operators template LLVM::SIToFPOp inttofloat(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::IntToPtrOp inttoptr(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::PtrToIntOp ptrtoint(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::ZExtOp zext(Args 
&&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::SExtOp sext(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::FPExtOp fpext(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::FPTruncOp fptrunc(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::TruncOp trunc(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::UDivOp udiv(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::SDivOp sdiv(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::URemOp urem(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::AddOp add(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::SubOp sub(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::FAddOp fadd(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::MulOp mul(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::FMulOp fmul(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::FMAOp fma(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::FNegOp neg(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::SMaxOp smax(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::UMaxOp umax(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::MaxNumOp fmax(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::SMinOp smin(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::UMinOp umin(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::MinNumOp fmin(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::ShlOp shl(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::LShrOp lshr(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::AShrOp ashr(Args 
&&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::AndOp and_(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::XOrOp xor_(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::OrOp or_(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } LLVM::BitcastOp bitcast(Value val, Type type) { - return builder.create(loc, type, val); + return builder->create(loc, type, val); } template LLVM::AddrSpaceCastOp addrspacecast(Args &&...args) { - return builder.create(loc, - std::forward(args)...); + return builder->create(loc, + std::forward(args)...); } template LLVM::GEPOp gep(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::InsertValueOp insert_val(Args &&...args) { - return builder.create(loc, - std::forward(args)...); + return builder->create(loc, + std::forward(args)...); } template LLVM::ExtractValueOp extract_val(Args &&...args) { - return builder.create(loc, - std::forward(args)...); + return builder->create(loc, + std::forward(args)...); } template LLVM::InsertElementOp insert_element(Args &&...args) { - return builder.create(loc, - std::forward(args)...); + return builder->create(loc, + std::forward(args)...); } template LLVM::ExtractElementOp extract_element(Args &&...args) { - return builder.create(loc, - std::forward(args)...); + return builder->create(loc, + std::forward(args)...); } template LLVM::LoadOp load(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::StoreOp store(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } - template LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) { - return builder.create(loc, builder.getI1Type(), - LLVM::FCmpPredicate::ogt, lhs, rhs); + LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) { + return builder->create(loc, builder->getI1Type(), + LLVM::FCmpPredicate::ogt, lhs, rhs); } - template LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) { - return builder.create(loc, builder.getI1Type(), - LLVM::FCmpPredicate::olt, lhs, rhs); + LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) { + return builder->create(loc, builder->getI1Type(), + LLVM::FCmpPredicate::olt, lhs, rhs); } - template LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) { - return builder.create(loc, builder.getI1Type(), - LLVM::FCmpPredicate::oeq, lhs, rhs); + LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) { + return builder->create(loc, builder->getI1Type(), + LLVM::FCmpPredicate::oeq, lhs, rhs); } template LLVM::ICmpOp icmp_eq(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::eq, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::eq, + std::forward(args)...); } template LLVM::ICmpOp icmp_ne(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::ne, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::ne, + std::forward(args)...); } template LLVM::ICmpOp icmp_slt(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::slt, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::slt, + std::forward(args)...); } template 
LLVM::ICmpOp icmp_sle(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::sle, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::sle, + std::forward(args)...); } template LLVM::ICmpOp icmp_sgt(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::sgt, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::sgt, + std::forward(args)...); } template LLVM::ICmpOp icmp_sge(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::sge, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::sge, + std::forward(args)...); } template LLVM::ICmpOp icmp_ult(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::ult, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::ult, + std::forward(args)...); } template LLVM::ICmpOp icmp_ule(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::ule, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::ule, + std::forward(args)...); } template LLVM::ICmpOp icmp_ugt(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::ugt, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::ugt, + std::forward(args)...); } template LLVM::ICmpOp icmp_uge(Args &&...args) { - return builder.create(loc, LLVM::ICmpPredicate::uge, - std::forward(args)...); + return builder->create(loc, LLVM::ICmpPredicate::uge, + std::forward(args)...); } template LLVM::SelectOp select(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::AddressOfOp address_of(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } mlir::gpu::BarrierOp barrier() { - return builder.create(loc); + return builder->create(loc); } template LLVM::UndefOp undef(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::ZeroOp null(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } template LLVM::CallOp call(Args &&...args) { - return builder.create(loc, std::forward(args)...); + return builder->create(loc, std::forward(args)...); } // Constants Value int_val(short bitwidth, int64_t val) { - Type ty = builder.getIntegerType(bitwidth); - return builder.create(loc, ty, - builder.getIntegerAttr(ty, val)); + Type ty = builder->getIntegerType(bitwidth); + return builder->create(loc, ty, + builder->getIntegerAttr(ty, val)); } Value i1_val(int64_t val) { return int_val(1, val); } Value true_val() { return int_val(1, true); } Value false_val() { return int_val(1, false); } Value f16_val(float v) { - auto type = type::f16Ty(builder.getContext()); - return builder.create(loc, type, - builder.getF16FloatAttr(v)); + auto type = type::f16Ty(builder->getContext()); + return builder->create(loc, type, + builder->getF16FloatAttr(v)); } Value f32_val(float v) { - auto type = type::f32Ty(builder.getContext()); - return builder.create(loc, type, - builder.getF32FloatAttr(v)); + auto type = type::f32Ty(builder->getContext()); + return builder->create(loc, type, + builder->getF32FloatAttr(v)); } Value f64_val(double v) { - auto type = type::f64Ty(builder.getContext()); - return builder.create(loc, type, - builder.getF64FloatAttr(v)); + auto type = type::f64Ty(builder->getContext()); + return builder->create(loc, 
type, + builder->getF64FloatAttr(v)); } Value i8_val(int64_t val) { return int_val(8, val); } Value i16_val(int64_t val) { return int_val(16, val); } Value i32_val(int64_t val) { return int_val(32, val); } Value i64_val(int64_t val) { return int_val(64, val); } - Value tid_val() { - Value tid = - builder.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); - Type i32_ty = builder.getIntegerType(32); - return builder.create(loc, i32_ty, tid); - } + Value tid_val() { return getThreadId(*this->builder, loc); } Location loc; - RewriterBase &builder; + OpBuilder *builder; }; } // namespace mlir::triton @@ -657,14 +660,6 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale, } // namespace LLVM -/* ------------------------------------ */ -// Returns CTA level thread idx -inline Value getThreadId(RewriterBase &rewriter, Location loc) { - Value tid = - rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); - return rewriter.create(loc, i32_ty, tid); -} - // ----------------------------------------------------------------------- // Shared memory utilities // ----------------------------------------------------------------------- From d0dafa55603dcf2cea96f2941097fdab8f14261a Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 28 Jan 2025 13:39:31 -0500 Subject: [PATCH 3/3] rebase --- .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 12 ++- .../DotOpToLLVM/MMAv5.cpp | 60 +++++++------ .../DotOpToLLVM/WGMMA.cpp | 24 ++--- .../LoadStoreOpToLLVM.cpp | 21 +++-- .../TensorMemoryToLLVM.cpp | 87 +++++++++++-------- .../lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 5 +- 6 files changed, 118 insertions(+), 91 deletions(-) diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 7a8c03146489..11a43cac5a49 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -231,8 +231,10 @@ class WarpIdOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(ttn::WarpIdOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadId = rewriter.create(loc, i32_ty); - Value warpId = udiv(threadId, i32_val(32)); + Value warpId = b.udiv(threadId, b.i32_val(32)); warpId = LLVM::NVIDIA::shuffleIdx(loc, rewriter, warpId, 0); rewriter.replaceOp(op, warpId); return success(); @@ -648,6 +650,7 @@ static Value createTMAlloc(IRRewriter &rewriter, LLVM::LLVMFuncOp func, size_t size, Value pred, bool twoCTAs) { PTXBuilder ptxBuilder; Location loc = func.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Value sharedMem = mlir::LLVM::getStackPointer(rewriter, func); std::string ptxString = "@$0 tcgen05.alloc.cta_group::" + std::to_string(twoCTAs ? 
2 : 1) + @@ -660,9 +663,9 @@ static Value createTMAlloc(IRRewriter &rewriter, LLVM::LLVMFuncOp func, auto voidTy = void_ty(func->getContext()); ptxBuilder.launch(rewriter, loc, void_ty(func->getContext())); rewriter.create(loc); - Value address = load(i32_ty, sharedMem); + Value address = b.load(i32_ty, sharedMem); rewriter.create(loc); - address = inttoptr(ptr_ty(func.getContext(), 6), address); + address = b.inttoptr(ptr_ty(func.getContext(), 6), address); return address; } @@ -709,6 +712,7 @@ static Value initTensorMemory(LLVM::LLVMFuncOp func) { rewriter.setInsertionPointToStart(&func.front()); auto ctx = mod.getContext(); auto loc = func.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); // A proper error will be raised by the frontend, but to allow compilation to // continue we emit a trap. if (size > 512) { @@ -721,7 +725,7 @@ static Value initTensorMemory(LLVM::LLVMFuncOp func) { // should be fine for now. bool useTwoCTAs = numCTAs == 2; Value threadId = rewriter.create(loc, i32_ty); - Value pred = icmp_ult(threadId, i32_val(32)); + Value pred = b.icmp_ult(threadId, b.i32_val(32)); Value alloc = createTMAlloc(rewriter, func, size, pred, useTwoCTAs); createRelinquishAlloc(rewriter, loc, pred, useTwoCTAs); // TODO: pred will have a long liverange, we need to check if this is a diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp index c347cad9880c..74aeb1062508 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -33,6 +33,7 @@ mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader( Value mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad( int a, int b, ConversionPatternRewriter &rewriter, Location loc) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); int numRows = 64; if (interleaved || instrShape[0] >= 128) numRows = 128; @@ -40,15 +41,15 @@ Value mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad( ((instrShape[0] * instrShape[1]) / numRows) / numElementsPer32b; Value address = base; int blockId = a + b * numRepM; - address = ptrtoint(i32_ty, address); + address = tb.ptrtoint(i32_ty, address); if (!interleaved) { - address = add(address, i32_val(numColPerBlock * blockId)); + address = tb.add(address, tb.i32_val(numColPerBlock * blockId)); } else { int blockIdIsOdd = blockId & 1; int blockIdPrevEven = blockId - blockIdIsOdd; - Value offset = - i32_val(numColPerBlock * blockIdPrevEven + ((16 * blockIdIsOdd) << 16)); - address = add(address, offset); + Value offset = tb.i32_val(numColPerBlock * blockIdPrevEven + + ((16 * blockIdIsOdd) << 16)); + address = tb.add(address, offset); } return address; } @@ -72,6 +73,7 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter, triton::nvidia_gpu::TCGen5MMAOp op, int M, int N, bool transposeA, bool transposeB) { Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); union TCGen5InstructionDescriptor { uint32_t descriptor; struct { @@ -119,7 +121,7 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter, Type dstElType = op.getD().getType().getElementType(); assert(dstElType.isF16() || dstElType.isF32()); desc.dType = dstElType.isF16() ? 
0 : 1; - return int_val(32, desc.descriptor); + return b.int_val(32, desc.descriptor); } static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter, @@ -129,6 +131,7 @@ static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter, int scaleFactorsubIdxB, mxfpKind mxfpInstKind) { Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); union TCGen5InstructionDescriptor { uint32_t descriptor; struct { @@ -209,7 +212,7 @@ static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter, } } - return int_val(32, desc.descriptor); + return b.int_val(32, desc.descriptor); } static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc, @@ -276,6 +279,7 @@ static void createScaledGen5MMA(ConversionPatternRewriter &rewriter, static void createMMACommit(ConversionPatternRewriter &rewriter, Location loc, Value barrier, Value pred, bool twoCTAs = false) { PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector ptxOperands; auto *predOperand = ptxBuilder.newOperand(pred, "b"); ptxOperands.push_back(predOperand); @@ -285,7 +289,7 @@ static void createMMACommit(ConversionPatternRewriter &rewriter, Location loc, if (twoCTAs) { // .multicast::cluster and mask 0x3 means the completion of UTCMMA.2CTA will // be boardcasted into CTAid 0 and 1 - auto *ctaMask = ptxBuilder.newOperand(int_val(16, 0x3), "h"); + auto *ctaMask = ptxBuilder.newOperand(b.int_val(16, 0x3), "h"); ptxOperands.push_back(ctaMask); opcode = "@$0 " "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::" @@ -303,12 +307,12 @@ void convertDot(const LLVMTypeConverter *typeConverter, triton::nvidia_gpu::TCGen5MMAOp op, Value a, Value b, Value d, Value loadedA, Value loadedB, Value loadedD, Value useDFlag, Value pred, Value barrier) { - + auto tb = TritonLLVMOpBuilder(loc, rewriter); bool twoCTAs = op.getTwoCtas().has_value(); // Only run mma on one thread. We currently use elect as ptxas is not able to // detect that tid.x == 0 is true only for 1 thread. Value warpId = rewriter.create(loc); - Value wapr0 = icmp_eq(warpId, i32_val(0)); + Value wapr0 = tb.icmp_eq(warpId, tb.i32_val(0)); if (twoCTAs) { // TODO: we have to sync the two CTAs because we currently don't use remove // barriers for the copies. @@ -316,10 +320,10 @@ void convertDot(const LLVMTypeConverter *typeConverter, rewriter.create(loc); Value clusterId = rewriter.create(loc); - Value cluster0 = icmp_eq(clusterId, i32_val(0)); - pred = and_(pred, cluster0); + Value cluster0 = tb.icmp_eq(clusterId, tb.i32_val(0)); + pred = tb.and_(pred, cluster0); } - pred = and_(pred, wapr0); + pred = tb.and_(pred, wapr0); // Wrap the whole mma code sequence within a IF block. auto *curBlock = rewriter.getInsertionBlock(); @@ -382,7 +386,7 @@ void convertDot(const LLVMTypeConverter *typeConverter, Value instDescriptor = createInstDescriptor(rewriter, op, twoCTAs ? 
mmaSizeM * 2 : mmaSizeM, mmaSizeN, transA, transB); - Value zero = i32_val(0); + Value zero = tb.i32_val(0); SmallVector shapeA(triton::gpu::getShapePerCTA(aTensorTy)); SmallVector shapeB(triton::gpu::getShapePerCTA(bTensorTy)); SmallVector aOperandShape = {(unsigned)mmaSizeM, @@ -411,7 +415,7 @@ void convertDot(const LLVMTypeConverter *typeConverter, b = bLoader.smemLoad(n, k, rewriter, loc); createGen5MMA(rewriter, loc, op, a, b, accAddress, pred, instDescriptor, useInitAcc, aInTmem, twoCTAs); - useInitAcc = i1_val(1); + useInitAcc = tb.i1_val(1); } } } @@ -475,6 +479,7 @@ struct TCGen5MMAScaledOpConversion "tensorcore op should have a barrier at this point."); auto typeConverter = getTypeConverter(); Location loc = op.getLoc(); + auto tb = TritonLLVMOpBuilder(loc, rewriter); auto aTensorTy = cast(op.getA().getType()); auto bTensorTy = cast(op.getB().getType()); auto dTensorTy = cast(op.getD().getType()); @@ -508,15 +513,15 @@ struct TCGen5MMAScaledOpConversion loc, adaptor.getD(), typeConverter->convertType(dTensorTy.getElementType()), rewriter) .getBase(); - baseD = ptrtoint(i32_ty, baseD); + baseD = tb.ptrtoint(i32_ty, baseD); Value baseScaleA = getSharedMemoryObjectFromStruct(loc, adaptor.getAScale(), i8_ty, rewriter) .getBase(); Value baseScaleB = getSharedMemoryObjectFromStruct(loc, adaptor.getBScale(), i8_ty, rewriter) .getBase(); - baseScaleA = ptrtoint(i32_ty, baseScaleA); - baseScaleB = ptrtoint(i32_ty, baseScaleB); + baseScaleA = tb.ptrtoint(i32_ty, baseScaleA); + baseScaleB = tb.ptrtoint(i32_ty, baseScaleB); unsigned int M = dTensorTy.getDimSize(0); unsigned int N = dTensorTy.getDimSize(1); @@ -537,7 +542,7 @@ struct TCGen5MMAScaledOpConversion int numRepK = ceil(K, mmaSizeK); bool interleaved = (mmaSizeM == 64 && (numRepM > 1 || numRepN > 1)); - Value zero = i32_val(0); + Value zero = tb.i32_val(0); SmallVector shapeA(aTensorTy.getShape()); SmallVector shapeB(bTensorTy.getShape()); if (opKindIsMXFP4) { @@ -561,11 +566,12 @@ struct TCGen5MMAScaledOpConversion numBitsPerElementB, rewriter, loc); // TODO: Support accumulator init optimization for scaled dot - Value useInitAcc = int_val(1, 1); + Value useInitAcc = tb.int_val(1, 1); // Only run mma on one thread. We currently use elect as ptxas is not able // to detect that tid.x == 0 is true only for 1 thread. - Value pred = and_(adaptor.getPred(), - LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter)); + Value pred = + tb.and_(adaptor.getPred(), + LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter)); int numRows = 128; int colSizeInBits = 32; int numColPerBlock = @@ -599,16 +605,16 @@ struct TCGen5MMAScaledOpConversion // Blocks are laid out along M first then N as described in // `TensorMemorySpace` definition. 
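// For illustration: blockId = m + n * numRepM below walks the M dimension
// fastest, so with numRepM == 2 and numRepN == 2 the (m, n) pairs map to
// blocks (0,0) -> 0, (1,0) -> 1, (0,1) -> 2, (1,1) -> 3, and each block
// starts numColPerBlock columns after the previous one
// (accAddress = baseD + numColPerBlock * blockId).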
int blockId = m + n * numRepM; - Value accAddress = add(baseD, i32_val(numColPerBlock * blockId)); + Value accAddress = tb.add(baseD, tb.i32_val(numColPerBlock * blockId)); for (int k = 0; k < numRepK; k++) { Value a = aLoader->memLoad(m, k, rewriter, loc); Value b = bLoader.smemLoad(n, k, rewriter, loc); int subWordIdx = k % (4 / scaleFactorColsPerSet); int wordIdx = k / (4 / scaleFactorColsPerSet); - Value scaleA = add(baseScaleA, i32_val((m + wordIdx * numRepM) * - numColPerScaleBlockA)); - Value scaleB = add(baseScaleB, i32_val((n + wordIdx * numRepN) * - numColPerScaleBlockB)); + Value scaleA = tb.add(baseScaleA, tb.i32_val((m + wordIdx * numRepM) * + numColPerScaleBlockA)); + Value scaleB = tb.add(baseScaleB, tb.i32_val((n + wordIdx * numRepN) * + numColPerScaleBlockB)); Value instDescriptor = createScaleInstDescriptor( rewriter, op, mmaSizeM, mmaSizeN, transA, transB, subWordIdx, subWordIdx, mxfpInstKind); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 2291308b6d77..66a8ff3069f3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -93,6 +93,7 @@ int64_t getSwizzlingFromLayout(const SharedEncodingAttr &layout, static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, int64_t swizzling, uint32_t stride) { + auto b = TritonLLVMOpBuilder(loc, rewriter); static_assert(sizeof(SMEMDescriptor) == 8, "Descriptor size should be 64 bits."); SMEMDescriptor desc; @@ -144,17 +145,16 @@ Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad( auto tb = TritonLLVMOpBuilder(loc, rewriter); Value k = tb.i32_val(b * instrShape[1]); Value m = tb.add(tb.i32_val(a * dimWpt * instrShape[0]), - tb.mul(warpId, tb.i32_val(instrShape[0]))); - if (trans) { - std::swap(k, m); - } - Value leading_offset = - tb.mul(tb.udiv(k, elemsPerSwizzlingRowVal), - tb.i32_val(shape[ord[1]] * elemsPerSwizzlingRow)); + tb.mul(warpId, tb.i32_val(instrShape[0]))); + if (trans) { + std::swap(k, m); + } + Value leading_offset = + tb.mul(tb.udiv(k, elemsPerSwizzlingRowVal), + tb.i32_val(shape[ord[1]] * elemsPerSwizzlingRow)); Value stride_offset = tb.mul(m, elemsPerSwizzlingRowVal); - Value offset = - tb.add(tb.add(leading_offset, stride_offset), - tb.urem(k, elemsPerSwizzlingRowVal)); + Value offset = tb.add(tb.add(leading_offset, stride_offset), + tb.urem(k, elemsPerSwizzlingRowVal)); Value off1; // Avoid the runtime udiv if we know the elements are byte multiples if (elemBits % 8) { @@ -168,8 +168,8 @@ Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad( // Add the base at the end to make it easier to do loop invariant code // motion. 
loadDesc = tb.add( - loadDesc, tb.lshr(tb.shl(tb.ptrtoint(i64_ty, base), tb.int_val(64, 46)), - tb.int_val(64, 50))); + loadDesc, tb.lshr(tb.shl(tb.ptrtoint(i64_ty, base), tb.int_val(64, 46)), + tb.int_val(64, 50))); return loadDesc; } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 889888e7e505..0a3a6a0e50b6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1384,6 +1384,7 @@ static LogicalResult iterateGatherScatterIndices( function_ref)> callback) { MLIRContext *ctx = op->getContext(); Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); StringAttr kDim0 = str_attr("dim0"); StringAttr kDim1 = str_attr("dim1"); @@ -1461,24 +1462,25 @@ static LogicalResult iterateGatherScatterIndices( Value warpId = rewriter.create(loc); // Each block has separate shared memory. Multiple CTAs don't work anyways. - Value blockId = i32_val(0); + Value blockId = b.i32_val(0); // Mask out warps with redundant x offsets. - pred = and_(pred, icmp_eq(i32_val(0), and_(warpId, i32_val(warpMask)))); + pred = b.and_(pred, + b.icmp_eq(b.i32_val(0), b.and_(warpId, b.i32_val(warpMask)))); // Select one thread in each warp to issue the gather4 messages. - pred = and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter)); + pred = b.and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter)); SmallVector xOffsets = unpackLLElements(loc, xOffsetsValue, rewriter); // Lane ID doesn't matter. - Value laneId = i32_val(0); + Value laneId = b.i32_val(0); for (auto regId : seq(0, xOffsets.size(), 4)) { // Skip redundant x offsets within a thread. if ((regMask & regId) != 0) continue; - Value regIdVal = i32_val(regId); + Value regIdVal = b.i32_val(regId); for (auto msgId : llvm::seq(numMessagesPerRow)) { - Value msgIdVal = i32_val(msgId); + Value msgIdVal = b.i32_val(msgId); auto result = applyLinearLayout(loc, rewriter, msgToShared, {{kMsg, msgIdVal}, @@ -1492,8 +1494,8 @@ static LogicalResult iterateGatherScatterIndices( // Because we checked that the memdesc's allocshape and shape match, we // can ignore the strides and directly index into the shmem object. 
Value shMemPtr = - gep(elemPtrTy, llvmElemTy, smemObj.getBase(), shMemOffset); - Value yOffset = add(yOffsetValue, i32_val(msgId * msgSize)); + b.gep(elemPtrTy, llvmElemTy, smemObj.getBase(), shMemOffset); + Value yOffset = b.add(yOffsetValue, b.i32_val(msgId * msgSize)); callback(pred, shMemPtr, yOffset, ArrayRef(xOffsets).slice(regId, 4)); }; @@ -1571,6 +1573,7 @@ LogicalResult AsyncTMAScatterOpConversion::matchAndRewrite( triton::nvidia_gpu::AsyncTMAScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = getContext(); LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); @@ -1601,7 +1604,7 @@ LogicalResult AsyncTMAScatterOpConversion::matchAndRewrite( if (failed(iterateGatherScatterIndices( op, rewriter, *getTypeConverter(), op.getXOffsets(), op.getSrc(), adaptor.getSrc(), adaptor.getXOffsets(), adaptor.getYOffset(), - /*pred=*/true_val(), callback))) + /*pred=*/b.true_val(), callback))) return failure(); // TODO: Separate the syncronizations operations into separate TTGIR ops to diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp index 22c4e4a9b302..ec59b9681521 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp @@ -17,21 +17,22 @@ namespace { SmallVector packToI32(const SmallVector &values, Location loc, ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector packedValues; Type elType = values[0].getType(); int numElementsPer32B = 32 / elType.getIntOrFloatBitWidth(); if (numElementsPer32B == 1) return values; - Value packed = undef(vec_ty(elType, numElementsPer32B)); + Value packed = b.undef(vec_ty(elType, numElementsPer32B)); for (int i = 0; i < values.size(); i++) { Value val = values[i]; - packed = insert_element(packed.getType(), packed, val, - i32_val(i % numElementsPer32B)); + packed = b.insert_element(packed.getType(), packed, val, + b.i32_val(i % numElementsPer32B)); if (i % numElementsPer32B == numElementsPer32B - 1 || i == values.size() - 1) { - packed = bitcast(packed, i32_ty); + packed = b.bitcast(packed, i32_ty); packedValues.push_back(packed); - packed = undef(vec_ty(elType, numElementsPer32B)); + packed = b.undef(vec_ty(elType, numElementsPer32B)); } } return packedValues; @@ -58,6 +59,7 @@ void calculateAddressAndEmitTmemMessage( int /*secondHalfColOffset*/, bool /*unpackedb16*/, int /*regsPerMessage*/, bool /*useStridedMessage*/)> &createMemoryOp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); const int numRowsPerWarp = 32; if (!nvidia_gpu::isDistributedLayoutTMemCompatible(mod, tensorType, @@ -103,9 +105,9 @@ void calculateAddressAndEmitTmemMessage( numColsPerBlock *= 2; Value warpId = rewriter.create(loc); - Value warpIdInGroup = urem(warpId, i32_val(4)); - Value warpGroupId = udiv(warpId, i32_val(4)); - Value rowId = mul(warpIdInGroup, i32_val(numRowsPerWarp)); + Value warpIdInGroup = b.urem(warpId, b.i32_val(4)); + Value warpGroupId = b.udiv(warpId, b.i32_val(4)); + Value rowId = b.mul(warpIdInGroup, b.i32_val(numRowsPerWarp)); int colsPerWarpGroup = numColsPerBlock / numWarpGroupsPerBlock; @@ -121,28 +123,32 @@ void calculateAddressAndEmitTmemMessage( } for (int block = 0; block < numBlocks; block += numWarpGroups) { - Value address = ptrtoint(i32_ty, baseAddress); + Value address = b.ptrtoint(i32_ty, 
baseAddress); Value blockId = - add(i32_val(block), udiv(warpGroupId, i32_val(numWarpGroupsPerBlock))); + b.add(b.i32_val(block), + b.udiv(warpGroupId, b.i32_val(numWarpGroupsPerBlock))); Value blockRowId = rowId; Value warpGroupIdInBlock = - urem(warpGroupId, i32_val(numWarpGroupsPerBlock)); - Value startColumnId = mul(warpGroupIdInBlock, i32_val(colsPerWarpGroup)); + b.urem(warpGroupId, b.i32_val(numWarpGroupsPerBlock)); + Value startColumnId = + b.mul(warpGroupIdInBlock, b.i32_val(colsPerWarpGroup)); if (blocksInterleaved) { - Value blockIdIsOdd = urem(blockId, i32_val(2)); - Value blockIdPrevEven = sub(blockId, blockIdIsOdd); - blockRowId = add(blockRowId, mul(blockIdIsOdd, i32_val(16))); - startColumnId = add(startColumnId, - mul(blockIdPrevEven, i32_val(numColsPerBlock / 2))); + Value blockIdIsOdd = b.urem(blockId, b.i32_val(2)); + Value blockIdPrevEven = b.sub(blockId, blockIdIsOdd); + blockRowId = b.add(blockRowId, b.mul(blockIdIsOdd, b.i32_val(16))); + startColumnId = + b.add(startColumnId, + b.mul(blockIdPrevEven, b.i32_val(numColsPerBlock / 2))); } else { startColumnId = - add(startColumnId, mul(blockId, i32_val(numColsPerBlock))); + b.add(startColumnId, b.mul(blockId, b.i32_val(numColsPerBlock))); } - address = add(add(address, shl(blockRowId, i32_val(16))), startColumnId); + address = + b.add(b.add(address, b.shl(blockRowId, b.i32_val(16))), startColumnId); for (int colStart = 0; colStart < numColsPerBlock; colStart += numColsPerMessage) { - Value startAddress = add(address, i32_val(colStart)); + Value startAddress = b.add(address, b.i32_val(colStart)); // Column offset of second half of the message in case of 16x32bx2 // message. @@ -218,6 +224,7 @@ static void lowerStoreToTensorMemory(Location loc, ModuleOp mod, Value src, Value dest, Value llSrc, Value pred, SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector srcValues = unpackLLElements(loc, llSrc, rewriter); srcValues = packToI32(srcValues, loc, rewriter); auto dstType = cast(dest.getType()); @@ -246,7 +253,7 @@ static void lowerStoreToTensorMemory(Location loc, ModuleOp mod, Value src, // Emit a barrier to ensure all threads have finished writing to tensor memory // before any use of the tensor memory. - barrier(); + b.barrier(); } struct TensorMemoryAllocOpConversion @@ -257,20 +264,21 @@ struct TensorMemoryAllocOpConversion matchAndRewrite(triton::nvidia_gpu::TMEMAllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto mod = op->getParentOfType(); Value base = rewriter.create(loc); - Value baseInt = ptrtoint(i32_ty, base); + Value baseInt = b.ptrtoint(i32_ty, base); int colOffset = cast(op->getAttr("tensor_memory_col_offset")) .getValue() .getZExtValue(); int rowOffset = cast(op->getAttr("tensor_memory_row_offset")) .getValue() .getZExtValue(); - Value allocAddress = add(baseInt, i32_val(colOffset | rowOffset << 16)); + Value allocAddress = b.add(baseInt, b.i32_val(colOffset | rowOffset << 16)); // Cast to address space 3 as the shared memory object uses 3. 
// TODO: clean this up and use either a int or ptr address space 6 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - Value ptr = inttoptr(ptrTy, allocAddress); + Value ptr = b.inttoptr(ptrTy, allocAddress); SmallVector order(op.getType().getRank()); std::iota(order.begin(), order.end(), 0); std::reverse(order.begin(), order.end()); @@ -280,7 +288,7 @@ struct TensorMemoryAllocOpConversion if (op.getSrc()) { lowerStoreToTensorMemory(loc, mod, op.getSrc(), op.getResult(), - adaptor.getSrc(), i1_val(true), smemObj, + adaptor.getSrc(), b.i1_val(true), smemObj, rewriter); } @@ -333,17 +341,18 @@ static Value createTensorMemoryLoad(Location loc, static SmallVector unpackResults(Value packedValues, Type elemTy, int numCols, Location loc, ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector resultVals; int numElementsPer32B = 32 / elemTy.getIntOrFloatBitWidth(); Type packedType = elemTy; if (numElementsPer32B > 1) packedType = vec_ty(elemTy, numElementsPer32B); for (int i = 0; i < numCols; i++) { - Value result = extract_val(i32_ty, packedValues, i); - result = bitcast(result, packedType); + Value result = b.extract_val(i32_ty, packedValues, i); + result = b.bitcast(result, packedType); if (numElementsPer32B > 1) { for (int j = 0; j < numElementsPer32B; j++) { - Value elem = extract_element(elemTy, result, i32_val(j)); + Value elem = b.extract_element(elemTy, result, b.i32_val(j)); resultVals.push_back(elem); } } else { @@ -425,6 +434,7 @@ struct TensorMemoryStoreOpConversion static Value createBlockedScalesSMEMDescriptor(ConversionPatternRewriter &rewriter, Location loc, Value baseSrc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); static_assert(sizeof(NVIDIA::SMEMDescriptor) == 8, "Descriptor size should be 64 bits."); NVIDIA::SMEMDescriptor desc; @@ -434,9 +444,9 @@ createBlockedScalesSMEMDescriptor(ConversionPatternRewriter &rewriter, desc.strideDimensionBaseOffset = 128 >> 4; // 8 x 16 bytes // See matrix-descriptor-encode(x) function in the ptx doc. 
// matrix-descriptor-encode(addr) = (addr & 0x3FFFF) >> 4 - auto smemAddr = ptrtoint(i64_ty, baseSrc); - return add(int_val(64, desc.descriptor), - lshr(shl(smemAddr, int_val(64, 46)), int_val(64, 50))); + auto smemAddr = b.ptrtoint(i64_ty, baseSrc); + return b.add(b.int_val(64, desc.descriptor), + b.lshr(b.shl(smemAddr, b.int_val(64, 46)), b.int_val(64, 50))); } static void createCommit(ConversionPatternRewriter &rewriter, Location loc, @@ -468,6 +478,7 @@ struct TensorMemoryCopyOpConversion matchAndRewrite(triton::nvidia_gpu::TMEMCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcTy = cast(op.getSrc().getType()); assert(isa(srcTy.getMemorySpace())); assert(isa(srcTy.getEncoding())); @@ -502,10 +513,11 @@ struct TensorMemoryCopyOpConversion for (int j = 0; j < repK; ++j) { // Multiple copies of 32x128b blocks are laid out along M/N first then // K - auto colOffset = int_val(32, (j * repMorN + i) * 4); - auto tmemAddr = add(ptrtoint(i32_ty, baseDst), colOffset); + auto colOffset = b.int_val(32, (j * repMorN + i) * 4); + auto tmemAddr = b.add(b.ptrtoint(i32_ty, baseDst), colOffset); createTcgen05Cp(rewriter, loc, tmemAddr, smemDesc, pred); - smemDesc = add(smemDesc, int_val(64, 512 >> 4)); // one chunk = 32x16B + smemDesc = + b.add(smemDesc, b.int_val(64, 512 >> 4)); // one chunk = 32x16B } } }; @@ -568,6 +580,7 @@ struct MemDescSubviewOpConversion matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcTy = op.getSrc().getType(); auto dstTy = op.getResult().getType(); auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); @@ -591,13 +604,13 @@ struct MemDescSubviewOpConversion triton::nvidia_gpu::TMemAllocation tmemAlloc = triton::nvidia_gpu::getTmemAllocSizes(cast(dstTy)); int numColOffset = tmemAlloc.numCols; - Value newBase = ptrtoint(rewriter.getI32Type(), smemObj.getBase()); + Value newBase = b.ptrtoint(rewriter.getI32Type(), smemObj.getBase()); newBase = rewriter.create( loc, newBase, rewriter.create(loc, opOffsetVals[0], - i32_val(numColOffset))); + b.i32_val(numColOffset))); auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemObj = SharedMemoryObject(inttoptr(elemPtrTy, newBase), llvmElemTy, + smemObj = SharedMemoryObject(b.inttoptr(elemPtrTy, newBase), llvmElemTy, offsetVals); rewriter.replaceOp(op, getStructFromSharedMemoryObject(loc, smemObj, rewriter)); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 1dcc6e5370f2..94c472a4314e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -129,9 +129,10 @@ Value createElectPredicate(Location loc, RewriterBase &rewriter) { } Value createElectPredicateWarp0(Location loc, RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); Value threadId = getThreadId(rewriter, loc); - Value warp0 = icmp_ult(threadId, i32_val(32)); - return and_(warp0, createElectPredicate(loc, rewriter)); + Value warp0 = b.icmp_ult(threadId, b.i32_val(32)); + return b.and_(warp0, createElectPredicate(loc, rewriter)); } } // namespace NVIDIA
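For reference, the migration pattern applied throughout these backends is sketched below. It relies only on the TritonLLVMOpBuilder struct and the getThreadId helper added to include/triton/Conversion/TritonGPUToLLVM/Utility.h above; the surrounding helper function (emitLaneAndWarpId) is invented purely for illustration and is not part of the series.

// Illustrative sketch only, not taken from the patch series.
#include <utility>

#include "mlir/IR/Builders.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir;
using namespace mlir::triton;

// Old macro style (implicitly required `loc` and `rewriter` to be in scope):
//   Value warpId = udiv(tid, i32_val(32));
// New builder style: construct the builder once from the (loc, rewriter)
// pair and call the equivalent member functions.
static std::pair<Value, Value> emitLaneAndWarpId(OpBuilder &rewriter,
                                                 Location loc) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value tid = getThreadId(rewriter, loc); // CTA-level thread id as i32
  Value warpSize = b.i32_val(32);
  Value laneId = b.urem(tid, warpSize); // tid % 32
  Value warpId = b.udiv(tid, warpSize); // tid / 32
  return {laneId, warpId};
}

Compared with the macro shortcuts, the builder makes the loc/rewriter dependency explicit at each call site instead of depending on identically named variables being in scope, and it keeps the helpers out of the global macro namespace.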