19 changes: 19 additions & 0 deletions test/TritonGPU/amd/optimize-lds-usage.mlir
@@ -118,3 +118,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
tt.return
}
}

// -----

// Checks that the optimization does not crash on a 1D tensor
// CHECK-LABEL: convert_1d
// CHECK: triton_gpu.local_alloc
// CHECK-NEXT: triton_gpu.convert_layout
// CHECK-NEXT: triton_gpu.local_load
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} {
%alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory>
%1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked>
%load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
tt.return
}
}
16 changes: 13 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
@@ -96,7 +96,8 @@ class OptimizeAMDLDSUsage
auto dstEnc = dstType.getEncoding();

auto ctx = srcEnc.getContext();
auto rank = srcType.getShape().size();
auto rank = srcType.getRank();

unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc);
auto warpSize = triton::gpu::getWarpSize(srcEnc);

@@ -109,11 +110,20 @@
// Create a list of temporary layouts
SmallVector<unsigned> elemsPerThread(rank, 1);
SmallVector<unsigned> threadsPerWarp(rank, 1);
threadsPerWarp[rank - 1] = warpSize / 8;
threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];

// Special case for rank == 1
if (rank == 1) {
threadsPerWarp[0] = warpSize;
} else {
assert(rank > 1);
threadsPerWarp[rank - 1] = warpSize / 8;
threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
}

auto layoutCTA = triton::gpu::getCTALayout(srcEnc);
auto order = triton::gpu::getOrder(srcEnc);
SmallVector<unsigned> dummyWarpsPerCTA(rank, 1);

auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get(
ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order,
layoutCTA);
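The hunk above replaces the unconditional write to threadsPerWarp[rank - 2], which indexes out of bounds when rank == 1, with a rank-1 special case. Below is a minimal standalone sketch of that selection logic, not Triton source: the helper name pickThreadsPerWarp is hypothetical and warpSize = 64 is assumed (the AMD wave size used by the test above).

#include <cassert>
#include <vector>

// Choose the threadsPerWarp vector for the temporary fallback blocked layout.
std::vector<unsigned> pickThreadsPerWarp(unsigned rank, unsigned warpSize) {
  std::vector<unsigned> threadsPerWarp(rank, 1);
  if (rank == 1) {
    // 1D case: the whole warp maps to the only dimension, so the
    // threadsPerWarp[rank - 2] access is never reached.
    threadsPerWarp[0] = warpSize;
  } else {
    assert(rank > 1);
    threadsPerWarp[rank - 1] = warpSize / 8;
    threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
  }
  return threadsPerWarp; // rank 1 -> {64}; rank 2 with warpSize 64 -> {8, 8}
}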
@@ -68,9 +68,13 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA),
src.getKWidth());
}
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout))
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
// TODO: think of a way to construct slice layouts based on warpsPerCTA
// argument
auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
return triton::gpu::SliceEncodingAttr::get(
ctx, src.getDim(), createTmpLayout(src.getParent(), warpsPerCTA));
ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
}
assert("Encountered unsupported layout");
return Attribute();
}
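The slice branch now forwards the parent's own warpsPerCTA because a slice layout has one dimension fewer than its parent, so the caller's rank-sized warpsPerCTA cannot be passed to the parent directly. As a purely hypothetical illustration of the TODO above, not part of this change, one option would be to expand the sliced warpsPerCTA back to the parent's rank by reinserting a 1 at the sliced dimension:

#include <vector>

// Hypothetical helper (not in this PR): map a rank-reduced warpsPerCTA back
// to the parent's rank before recursing into createTmpLayout.
std::vector<unsigned> expandWarpsPerCTAForParent(std::vector<unsigned> warpsPerCTA,
                                                 unsigned slicedDim) {
  // Example: {4} sliced at dim 0 becomes {1, 4} for a 2D parent.
  warpsPerCTA.insert(warpsPerCTA.begin() + slicedDim, 1);
  return warpsPerCTA;
}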
9 changes: 9 additions & 0 deletions third_party/nvidia/backend/driver.py
@@ -208,6 +208,15 @@ def format_of(ty):
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
if (gridX*gridY*gridZ > 0) {{
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {{
// Ensure device context.
CUdevice device;
CUDA_CHECK(cuDeviceGet(&device, 0));
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}}
if (num_ctas == 1) {{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}} else {{
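The launcher change above makes sure the calling thread has a current CUDA context before the kernel launch: if none is bound, it retains device 0's primary context and makes it current. The same pattern as a standalone helper against the CUDA driver API; CHECK and ensureCurrentContext are illustrative names, and the sketch assumes the driver is already initialized (cuInit has been called), as it has by the time the generated launcher runs.

#include <cuda.h>
#include <cstdio>
#include <cstdlib>

#define CHECK(call)                                                           \
  do {                                                                        \
    CUresult err_ = (call);                                                   \
    if (err_ != CUDA_SUCCESS) {                                               \
      std::fprintf(stderr, "CUDA driver error %d\n", (int)err_);              \
      std::exit(1);                                                           \
    }                                                                         \
  } while (0)

void ensureCurrentContext() {
  CUcontext ctx = nullptr;
  CHECK(cuCtxGetCurrent(&ctx));
  if (!ctx) {
    // No context bound to this thread: retain the primary context of
    // device 0 and make it current, mirroring the generated launcher code.
    CUdevice device;
    CHECK(cuDeviceGet(&device, 0));
    CHECK(cuDevicePrimaryCtxRetain(&ctx, device));
    CHECK(cuCtxSetCurrent(ctx));
  }
}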