diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir
index 0ec2ad5382ec..5cd34aab27d2 100644
--- a/test/TritonGPU/amd/optimize-lds-usage.mlir
+++ b/test/TritonGPU/amd/optimize-lds-usage.mlir
@@ -118,3 +118,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+// -----
+
+// Checks that the optimization does not crash on 1d tensors
+// CHECK-LABEL: convert_1d
+// CHECK: triton_gpu.local_alloc
+// CHECK-NEXT: triton_gpu.convert_layout
+// CHECK-NEXT: triton_gpu.local_load
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} {
+    %alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory>
+    %1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked>
+    %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
+    tt.return
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
index db3223f119da..4a0a7fed22b0 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
@@ -96,7 +96,8 @@ class OptimizeAMDLDSUsage
     auto dstEnc = dstType.getEncoding();
 
     auto ctx = srcEnc.getContext();
-    auto rank = srcType.getShape().size();
+    auto rank = srcType.getRank();
+
     unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc);
 
     auto warpSize = triton::gpu::getWarpSize(srcEnc);
@@ -109,11 +110,20 @@
     // Create a list of temporary layouts
     SmallVector<unsigned> elemsPerThread(rank, 1);
     SmallVector<unsigned> threadsPerWarp(rank, 1);
-    threadsPerWarp[rank - 1] = warpSize / 8;
-    threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
+
+    // Special case for rank == 1
+    if (rank == 1) {
+      threadsPerWarp[0] = warpSize;
+    } else {
+      assert(rank > 1);
+      threadsPerWarp[rank - 1] = warpSize / 8;
+      threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
+    }
+
     auto layoutCTA = triton::gpu::getCTALayout(srcEnc);
     auto order = triton::gpu::getOrder(srcEnc);
     SmallVector<unsigned> dummyWarpsPerCTA(rank, 1);
+
     auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get(
         ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order,
         layoutCTA);
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp
index dfa8e06e247c..fb0bfb656ef4 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp
@@ -68,9 +68,13 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
         ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA),
         src.getKWidth());
   }
-  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout))
+  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
+    // TODO: think of a way to construct slice layouts based on warpsPerCTA
+    // argument
+    auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
     return triton::gpu::SliceEncodingAttr::get(
-        ctx, src.getDim(), createTmpLayout(src.getParent(), warpsPerCTA));
+        ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
+  }
   assert("Encountered unsupported layout");
   return Attribute();
 }
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 38ce62b0c2a2..e0727a74e10d 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -208,6 +208,15 @@ def format_of(ty):
 static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
   if (gridX*gridY*gridZ > 0) {{
+    CUcontext pctx;
+    CUDA_CHECK(cuCtxGetCurrent(&pctx));
+    if (!pctx) {{
+      // Ensure device context.
+      CUdevice device;
+      CUDA_CHECK(cuDeviceGet(&device, 0));
+      CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
+      CUDA_CHECK(cuCtxSetCurrent(pctx));
+    }}
     if (num_ctas == 1) {{
       CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
     }} else {{