Commit cdb5326

[AMD] Use AxisAnalysis for buffer op offsets to compute vector size (#6126)
This PR changes the lowering to run `AxisAnalysis` on the offsets to compute the correct vector size for buffer operations. Previously, the lowering of buffer operations did not analyze the contiguity of the offset values and only looked at the layout, which resulted in incorrect vectorization when the offsets are not contiguous per lane (e.g. strided loads). The PR also adds a lit test for strided buffer loads and stores and adjusts some existing buffer-operation lit tests to provide enough information for the `AxisAnalysis` to work.
1 parent 2a650c2 commit cdb5326
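
To make the intent concrete, here is a minimal sketch (not code from this commit) of how a buffer-op lowering can derive its vector size from the offset tensor and the scalar base pointer. It uses only the `ModuleAxisInfoAnalysis` calls that appear in the diff below (`getContiguity`, `getAxisInfo`, `getDivisibility`); the element bit width is passed in explicitly, and the usual mlir/triton namespaces are assumed to be in scope.

#include <algorithm>

#include "triton/Analysis/AxisInfo.h" // include path as used in this PR

// Sketch only: mirrors the logic added to getContiguity(ptr, offset) in
// third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp further down.
unsigned bufferOpVectorSize(Value basePtr, Value offsets,
                            ModuleAxisInfoAnalysis &axisAnalysis,
                            unsigned elemNumBits) {
  // How many consecutive elements each lane addresses; strided offsets
  // (e.g. 2 * make_range) yield 1 here, which disables vectorization.
  unsigned contiguity = axisAnalysis.getContiguity(offsets);

  // Alignment of the scalar base pointer in elements, derived from its
  // divisibility (tt.divisibility).
  auto *ptrInfo = axisAnalysis.getAxisInfo(basePtr);
  unsigned elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
  unsigned align =
      std::max<unsigned>(ptrInfo->getDivisibility(0) / elemNumBytes, 1);

  // The vector size is bounded by both quantities.
  return std::min(align, contiguity);
}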

File tree: 8 files changed (+99 -63 lines)

include/triton/Analysis/AxisInfo.h

Lines changed: 3 additions & 2 deletions
@@ -202,8 +202,9 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
     return &(it->second);
   }
 
-  unsigned getPtrContiguity(Value ptr);
-  unsigned getPtrAlignment(Value ptr);
+  unsigned getContiguity(Value value);
+  unsigned getAlignment(Value value);
+
   unsigned getMaskAlignment(Value mask);
 
 private:

lib/Analysis/AxisInfo.cpp

Lines changed: 15 additions & 10 deletions
@@ -1218,46 +1218,51 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
   return AxisInfo(contiguity, divisibility, constancy, constantValue);
 }
 
-unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
-  auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
+unsigned ModuleAxisInfoAnalysis::getContiguity(Value value) {
+  auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
   if (!tensorTy)
     return 1;
-
   // FIXME: This is not as good as it could be, as we don't need to restrict
   // the analysis to one dimension. We should determine contiguity on the
   // flattenOuts() layout
   auto linAttr =
       gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
   auto order = linAttr.getOrder();
-  unsigned align = getPtrAlignment(ptr);
+  unsigned align = getAlignment(value);
 
   auto uniqueContigPerThread = linAttr.getContigPerThread();
   assert(order[0] < uniqueContigPerThread.size() &&
          "Unexpected uniqueContigPerThread size");
   unsigned contiguity = uniqueContigPerThread[order[0]];
-  LDBG("getPtrContiguity uniqueContigPerThread = " << contiguity);
+  LDBG("getContiguity uniqueContigPerThread = " << contiguity);
   contiguity = std::min(align, contiguity);
 
   return contiguity;
 }
 
-unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
-  auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
+unsigned ModuleAxisInfoAnalysis::getAlignment(Value value) {
+  auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
   if (!tensorTy)
     return 1;
-  auto *axisInfo = getAxisInfo(ptr);
+  auto *axisInfo = getAxisInfo(value);
   if (!axisInfo)
     return 1;
   auto linAttr =
       gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
   auto order = linAttr.getOrder();
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
-  auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
+
+  auto elemTy = tensorTy.getElementType();
+  // Get the pointee type if we have a tensor of ptrs to compute contiguity for
+  if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
+    elemTy = ptrTy.getPointeeType();
+  }
+  auto elemNumBits = elemTy.getIntOrFloatBitWidth();
   auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
   auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
   unsigned alignment = std::min(maxMultiple, maxContig);
-  LDBG("getPtrAlignment order[0] "
+  LDBG("getAlignment order[0] "
        << order[0] << " maxMultipleBytes = " << maxMultipleBytes
        << " maxContig = " << maxContig << " elemNumBits = " << elemNumBits
        << " maxMultiple = " << maxMultiple << " alignment " << alignment);

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ bool isSmallLoad(tt::LoadOp loadOp,
   assert(!isLoadFromTensorPtr(loadOp) &&
          "Block ptr should have been lowered before this pass.");
   auto ptr = loadOp.getPtr();
-  unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
+  unsigned vec = axisInfoAnalysis.getContiguity(ptr);
   if (auto mask = loadOp.getMask())
     vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 19 additions & 0 deletions
@@ -239,3 +239,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: strided_buffer_load_and_store
+  tt.func public @strided_buffer_load_and_store(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<2> : tensor<1024xi32, #blocked>
+    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %1 = arith.muli %0, %cst : tensor<1024xi32, #blocked>
+    // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}}, {{.*}}, {{.*}}, {{.*}} : f32
+    // CHECK-NOT: rocdl.raw.ptr.buffer.load
+    %2 = amdgpu.buffer_load %arg0[%1] : tensor<1024xf32, #blocked>
+    // CHECK-COUNT-4: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : f32
+    // CHECK-NOT: rocdl.raw.ptr.buffer.store
+    amdgpu.buffer_store %2, %arg1[%1] : tensor<1024xf32, #blocked>
+    tt.return
+  }
+}
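
The new test pins down the case this PR fixes: the offsets are `2 * make_range(0, 1024)`, so even though `sizePerThread = [4]` the per-lane offsets are strided and the lowering must emit four scalar buffer accesses per thread rather than one vectorized access. The standalone snippet below (an illustration of the property being tested, not the actual `AxisInfo` propagation) shows why such offsets report a contiguity of 1:

#include <cstdint>
#include <vector>

// Length of the longest initial run of values increasing by exactly 1;
// this is the per-lane contiguity a vectorizer could exploit.
static unsigned leadingConsecutiveRun(const std::vector<int64_t> &offsets) {
  if (offsets.empty())
    return 0;
  unsigned run = 1;
  for (size_t i = 1; i < offsets.size(); ++i) {
    if (offsets[i] != offsets[i - 1] + 1)
      break;
    ++run;
  }
  return run;
}

int main() {
  std::vector<int64_t> contiguous, strided;
  for (int64_t i = 0; i < 4; ++i) {
    contiguous.push_back(i);  // 0, 1, 2, 3 -> run of 4, one vectorized access
    strided.push_back(i * 2); // 0, 2, 4, 6 -> run of 1, four scalar accesses
  }
  bool ok = leadingConsecutiveRun(contiguous) == 4 &&
            leadingConsecutiveRun(strided) == 1;
  return ok ? 0 : 1;
}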

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 40 additions & 35 deletions
@@ -22,42 +22,49 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 32], order = [0, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: buffer_load_to_local_vectorized_2xf16
-  tt.func public @buffer_load_to_local_vectorized_2xf16(
-      %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
-      %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>,
-      %arg3: i32) {
-    %1 = tt.splat %arg3: i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
-    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
-    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
-    // Each thread needs to load 8 elements and we load 2 (sizePerThread) per buffer load instruction
+  tt.func public @buffer_load_to_local_vectorized_2xf16(%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
+    %cst = arith.constant dense<64> : tensor<1x64xi32, #blocked>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %4 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %5 = arith.muli %4, %cst : tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %7 = arith.addi %3, %6 : tensor<64x64xi32, #blocked>
+
+    // Each thread needs to load 2 elements and we load 2 (sizePerThread) per buffer load instruction
     // COMMON: rocdl.make.buffer.rsrc
     // COMMON-NOT: rocdl.make.buffer.rsrc
-    // COMMON-COUNT-4: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
     // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
-    %65 = amdgpu.buffer_load_to_local %arg1[%3] into %arg2 : <f16>[tensor<32x64xi32, #blocked>] -> <32x64xf16, #shared, #smem, mutable>
+    %8 = amdgpu.buffer_load_to_local %arg1[%7] into %arg2 : <f16>[tensor<64x64xi32, #blocked>] -> <64x64xf16, #shared, #smem, mutable>
     tt.return
   }
 }
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 32], order = [0, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: buffer_load_to_local_vectorized_8xf16
-  tt.func public @buffer_load_to_local_vectorized_8xf16(
-      %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
-      %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>,
-      %arg3: i32) {
-    %1 = tt.splat %arg3: i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
-    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
-    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+  tt.func public @buffer_load_to_local_vectorized_8xf16(%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
+    %cst = arith.constant dense<64> : tensor<1x64xi32, #blocked>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %4 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %5 = arith.muli %4, %cst : tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %7 = arith.addi %3, %6 : tensor<64x64xi32, #blocked>
 
     // Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
     // GFX950: rocdl.make.buffer.rsrc
@@ -68,7 +75,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
     // GFX942 does not support vectorization > 4bytes so we cannot lower it
     // GFX942-NOT: rocdl.raw.ptr.buffer.load.lds
     // GFX942: amdgpu.buffer_load_to_local
-    %65 = amdgpu.buffer_load_to_local %arg1[%3] into %arg2 : <f16>[tensor<32x64xi32, #blocked>] -> <32x64xf16, #shared, #smem, mutable>
+    %8 = amdgpu.buffer_load_to_local %arg1[%7] into %arg2 : <f16>[tensor<64x64xi32, #blocked>] -> <64x64xf16, #shared, #smem, mutable>
     tt.return
   }
 }
@@ -129,30 +136,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: buffer_load_to_local_cache_mods
-  tt.func public @buffer_load_to_local_cache_mods(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
-                                                  %arg1: !tt.ptr<f16>,
-                                                  %arg2: tensor<32x32xi32, #blocked>,
-                                                  %arg3: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>) {
+  tt.func public @buffer_load_to_local_cache_mods(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                                  %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>) {
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
     // The first constant 0 skips the LDS offset which is also 0
     // COMMON: llvm.getelementptr
     // COMMON: llvm.mlir.constant(0 : i32) : i32
     // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
     // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
-    %1 = amdgpu.buffer_load_to_local %arg1[%arg2] cacheModifier = ca into %arg3: <f16>[tensor<32x32xi32, #blocked>] -> <32x32xf16, #shared, #smem, mutable>
+    %1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
     // COMMON: llvm.getelementptr
     // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32
     // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]]
-    %2 = amdgpu.buffer_load_to_local %arg1[%arg2] cacheModifier = cg into %arg3: <f16>[tensor<32x32xi32, #blocked>] -> <32x32xf16, #shared, #smem, mutable>
+    %2 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = cg into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
     // COMMON: llvm.getelementptr
     // COMMON: %[[aux_cv:.*]] = llvm.mlir.constant(17 : i32) : i32
    // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]]
-    %3 = amdgpu.buffer_load_to_local %arg1[%arg2] cacheModifier = cv into %arg3: <f16>[tensor<32x32xi32, #blocked>] -> <32x32xf16, #shared, #smem, mutable>
+    %3 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = cv into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
 
     tt.return
   }

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 19 additions & 13 deletions
@@ -528,14 +528,29 @@ unsigned getContiguity(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) {
   auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
   if (!tensorTy)
     return 1;
-  return axisAnalysisPass.getPtrContiguity(ptr);
+  return axisAnalysisPass.getContiguity(ptr);
 }
 
 unsigned getContiguity(Value ptr, Value offset,
                        ModuleAxisInfoAnalysis &axisAnalysisPass) {
-  // Get contiguity from the offset
+
   Type type = getPointerTypeWithShape(ptr, offset);
   RankedTensorType tensorTy = cast<RankedTensorType>(type);
+
+  // To compute the contiguity of the scalar/warp-uniform ptr and offset pair we
+  // need to look at the contiguity of the offsets and the alignment of the ptr
+  auto contiguity = axisAnalysisPass.getContiguity(offset);
+
+  // To get the alignment of the scalar ptr we need to look at the divisibility
+  auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr);
+  auto maxMultipleBytes = axisInfo->getDivisibility(0);
+  auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
+  auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
+  auto align = std::max<unsigned>(maxMultipleBytes / elemNumBytes, 1);
+
+  // FIXME (Alex): this should not be needed anymore because it's done inside
+  // getContiguity, but we have an order issues with LL, so we keep this
+  // until the LL order issue is fixed
   auto layout = tensorTy.getEncoding();
   auto linearLayout = triton::gpu::toLinearLayout(tensorTy.getShape(), layout);
   auto llAttr =
@@ -544,19 +559,10 @@ unsigned getContiguity(Value ptr, Value offset,
   auto contigPerThread = llAttr.getContigPerThread();
   assert(order[0] < contigPerThread.size() &&
          "Unexpected contigPerThread size");
-  unsigned contiguity = contigPerThread[order[0]];
-
-  // Get alignment from the pointer. Since this is a scalar pointer
-  // we should not take the pointer contiguity to consider alignment
-  auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr);
-  auto maxMultipleBytes = axisInfo->getDivisibility(0);
-  auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
-  auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
-  auto align = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
+  contiguity = std::min(contiguity, contigPerThread[order[0]]);
 
   // Final contiguity is a min of the offset contiguity and pointer alignment
-  contiguity = std::min<int64_t>(align, contiguity);
-  return contiguity;
+  return std::min(align, contiguity);
 }
 
 unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) {

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 1 addition & 1 deletion
@@ -502,7 +502,7 @@ void StreamPipeliner::assignMemoryLayouts() {
   assert(!isLoadFromTensorPtr(loadOp) &&
          "Block ptr should have been lowered before this pass.");
   auto ptr = loadOp.getPtr();
-  unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
+  unsigned vec = axisInfoAnalysis.getContiguity(ptr);
   if (auto mask = loadOp.getMask())
     vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ struct LoadStoreConversionBase {
     auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
     if (!tensorTy)
       return 1;
-    return axisAnalysisPass.getPtrContiguity(ptr);
+    return axisAnalysisPass.getContiguity(ptr);
   }
 
   unsigned getVectorSize(Value ptr) const {
