diff --git a/test/Conversion/amd/wmma-v2-shortcut.mlir b/test/Conversion/amd/wmma-v2-shortcut.mlir index 3ccd3ce02c73..b8711c81b00b 100644 --- a/test/Conversion/amd/wmma-v2-shortcut.mlir +++ b/test/Conversion/amd/wmma-v2-shortcut.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx1200" -split-input-file | FileCheck %s +// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx1200" -reconcile-unrealized-casts -split-input-file | FileCheck %s #wmmaTv2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 1], isTranspose = true}> #dotop0v2 = #ttg.dot_op<{opIdx = 0, parent = #wmmaTv2, kWidth=8}> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 2bed4b1dd503..5b6b27bccf93 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv --convert-triton-gpu-to-llvm 2>/dev/null | FileCheck %s --dump-input-context 20 +// RUN: triton-opt %s -split-input-file --allocate-shared-memory-nv --convert-triton-gpu-to-llvm -reconcile-unrealized-casts 2>/dev/null | FileCheck %s --dump-input-context 20 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) @@ -2451,7 +2451,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: @reinterpret_tensor_descriptor tt.func private @reinterpret_tensor_descriptor(%arg0: !tt.ptr) -> !tt.tensordesc> { - // CHECK: builtin.unrealized_conversion_cast // CHECK-NEXT: llvm.addrspacecast %arg0 : !llvm.ptr to !llvm.ptr %0 = ttng.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> tt.return %0 : !tt.tensordesc> diff --git 
a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 8ba489dddeb8..22083207b924 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 | FileCheck %s +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 -reconcile-unrealized-casts | FileCheck %s #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #smem = #ttg.shared_memory @@ -45,7 +45,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: arrive_barrier tt.func @arrive_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) { - // CHECK: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x + // CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x // CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32) // CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]] // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32) @@ -57,7 +57,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: arrive_barrier_pred tt.func @arrive_barrier_pred(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { - // CHECK: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x + // CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x // CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32) // CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]] // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32) diff --git a/test/LLVMIR/convert-to-llvmir-with-dbg-info.mlir b/test/LLVMIR/convert-to-llvmir-with-dbg-info.mlir index 5379099ae01b..033c3dc4c3cb 100644 --- a/test/LLVMIR/convert-to-llvmir-with-dbg-info.mlir +++ b/test/LLVMIR/convert-to-llvmir-with-dbg-info.mlir @@ -29,6 +29,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %arg2: 
!llvm.ptr<1>, %arg3: i32, %arg4: !llvm.ptr<1>) { %constant_i32 = llvm.mlir.constant(9 : i32) : i32 %constant_i16 = llvm.mlir.constant(0 : i16) : i16 + %constant_i64 = llvm.mlir.constant(9 : i64) : i64 // CHECK: !DILocalVariable(name: "pid", scope: %pid = rocdl.workgroup.id.x : i32 loc(#loc14) @@ -49,14 +50,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: !DILocalVariable(name: "x", scope: %x_ptr = llvm.getelementptr %arg0[%block_start] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 - %x_buffer_ptr = rocdl.make.buffer.rsrc %x_ptr, %constant_i16, %constant_i32, %constant_i32 : <1> to <8> loc(#loc18) + %x_buffer_ptr = rocdl.make.buffer.rsrc %x_ptr, %constant_i16, %constant_i64, %constant_i32 : <1> to <8> loc(#loc18) llvm.intr.dbg.value #di_local_variable4 = %x_buffer_ptr : !llvm.ptr<8> loc(#loc8) %x_val = rocdl.raw.ptr.buffer.load %x_buffer_ptr, %mask_i1, %constant_i32, %constant_i32 : vector<4xf32> loc(#loc18) %x_scalar = llvm.extractelement %x_val[%constant_i32 : i32] : vector<4xf32> loc(#loc18) // CHECK: !DILocalVariable(name: "y", scope: %y_ptr = llvm.getelementptr %arg1[%block_start] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 - %y_buffer_ptr = rocdl.make.buffer.rsrc %y_ptr, %constant_i16, %constant_i32, %constant_i32 : <1> to <8> loc(#loc19) + %y_buffer_ptr = rocdl.make.buffer.rsrc %y_ptr, %constant_i16, %constant_i64, %constant_i32 : <1> to <8> loc(#loc19) llvm.intr.dbg.value #di_local_variable5 = %y_buffer_ptr : !llvm.ptr<8> loc(#loc10) %y_val = rocdl.raw.ptr.buffer.load %y_buffer_ptr, %mask_i1, %constant_i32, %constant_i32 : vector<4xf32> loc(#loc19) %y_scalar = llvm.extractelement %y_val[%constant_i32 : i32] : vector<4xf32> loc(#loc19) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp index d520edd207bb..482aeb857947 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +++ 
b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -86,7 +86,7 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr, Value flagsConst = b.int_val(32, flags); Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); - Value numRecordsByte = b.int_val(32, std::numeric_limits<int>::max() - 1); + Value numRecordsByte = b.int_val(64, std::numeric_limits<int>::max() - 1); Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>( loc, rsrcType, basePtr, stride, numRecordsByte, flagsConst); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index 3b6b762805e5..cba4d59b2001 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -425,21 +425,9 @@ struct DotOpMFMAConversionHelper { // Now we have a vector of kBase elements of desired type. // Then we need to prepare vec for results. if (type.getIntOrFloatBitWidth() == 8) { - if (1 == kBase) { + if (1 == kBase) // This is only for the scale operands of scaled mfma on CDNA4 - if (isConstantScale) { - // If the scale is constant(created by arith::ConstantOp), it will - // be put in a sgpr instead of vgpr. In that case, instead of - // vgpr[7:0], the instruction reads sgpr[30:23] as the scale value. - // So we need to manually left shift the scale by 23 bits to meet - // the requirement. - results = b.shl(i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)), - b.i32_val(23)); - } else { - results = b.zext(i32_ty, b.bitcast(vec, i8_ty)); - } - } - + results = b.zext(i32_ty, b.bitcast(vec, i8_ty)); if (2 == kBase) // This case can occur during scale tensor packing when there aren't // enough elements to fill all 4 opSel slots. For example, with an A