
Commit cb3da9f

Merge branch 'main' into backwardslice-fix
2 parents: 3569317 + 82fec37

27 files changed: +973 −101 lines


bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
   mlir::registerTritonAMDGPUReorderInstructions();
+  mlir::registerTritonAMDGPUBlockPingpong();
   mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUCanonicalizePointers();
   mlir::registerTritonAMDGPUConvertToBufferOps();

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 1 deletion
@@ -163,7 +163,8 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
 LogicalResult getConvertBackwardSlice(
     Value root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr);
+    std::function<bool(Operation *)> stopPropagation = nullptr,
+    std::function<Value(Value, Attribute)> getExistingConversion = nullptr);

 // Populate pattern to remove dead cycles in ForOp.
 void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);

include/triton/Tools/LinearLayout.h

Lines changed: 4 additions & 0 deletions
@@ -679,6 +679,10 @@ class LinearLayout {

   // Get the layout that is the inverse of this layout.
   [[nodiscard]] LinearLayout invert() const;
+  // Compute and return a pseudoinverse of this layout. This is a layout such
+  // that `B = A.pseudoinvert()` implies that `A(B(x)) = x`. If `A` is
+  // invertible, then this returns `A^-1`.
+  [[nodiscard]] LinearLayout pseudoinvert() const;

   // For each in-dim, returns a bitmask of the "free variables" in the layout
   // function.
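For intuition, a minimal sketch of the contract documented above (illustrative only, not code from this commit; the helper function is a hypothetical wrapper):

// Sketch: invert() requires the layout to be square and injective and asserts
// otherwise, while pseudoinvert() only needs surjectivity. For every x in the
// image of A, A(pseudoinvert(A)(x)) == x, even when several inputs of A
// collapse onto the same output.
void demoPseudoinvert(const LinearLayout &A) {
  LinearLayout B = A.pseudoinvert();   // fine even if !A.isInvertible()
  // LinearLayout C = A.invert();      // would assert unless A.isInvertible()
  (void)B;
}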

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_STREAM_PREFETCH",
+    "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_LLVM_DEBUG_ONLY",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 7 deletions
@@ -491,13 +491,8 @@ bool GatherLoweringHelper::isWarpLocal() {
   // in the index and source tensors are the same. This means we don't need to
   // xor shuffle across threads before emitting index shuffles; we push warp
   // shuffling to layout conversions.
-  if (srcLayout->sublayout(kLane, otherDims) !=
-      idxLayout->sublayout(kLane, otherDims))
-    return false;
-
-  // Otherwise, the source layout has to be invertible. This primarily means
-  // the codegen path doesn't support broadcasted source layouts.
-  return srcLayout->isInvertible();
+  return srcLayout->sublayout(kLane, otherDims) ==
+         idxLayout->sublayout(kLane, otherDims);
 }

 unsigned getNumScratchElements(ArrayRef<unsigned> shape) {

lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp

Lines changed: 4 additions & 3 deletions
@@ -240,9 +240,10 @@ void GatherOpConversion::emitWarpLocalGather(
   // `llvm.select` using `src_reg` to get the right one. `K` is the number of
   // elements per column owned by a thread.

-  // Fully invert the source layout. We know it is invertible because
-  // `isWarpLocal` checked this.
-  LinearLayout invSrcLayout = srcLayout.invert();
+  // Invert the source layout. It doesn't matter whether it is fully invertible
+  // with respect to anything except the register input dimension, since we know
+  // those don't vary in ways that matter for codegen.
+  LinearLayout invSrcLayout = srcLayout.pseudoinvert();

   // Sanity check: the warp must be invariant to the index because otherwise the
   // gather would need to read across warps!
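A hedged illustration of the reasoning in the new comment (the broadcasting scenario and the tiny wrapper are assumptions for exposition, not code from this commit):

// Sketch: if the source layout broadcasts one gathered element into several
// registers of the same thread, it is surjective but not injective, so
// invert() would assert. pseudoinvert() chooses a single preimage register
// per element; because every aliasing register holds the same value, reading
// through that one register still produces a correct gather.
LinearLayout invertSourceForGather(const LinearLayout &srcLayout) {
  return srcLayout.pseudoinvert();   // replaces the old srcLayout.invert()
}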

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 46 additions & 20 deletions
@@ -116,17 +116,13 @@ class LayoutPropagation {
 class LayoutRematerialization {
 public:
   LayoutRematerialization(FuncOp F) : funcOp(F) {}
+
   // Map the original value to the remat'ed one.
   void addRematValue(Value old, Attribute encoding, Value newV);
-  bool hasRematValue(Value value, Attribute encoding) {
-    return rematMapping.contains({value, encoding});
-  }
-  // Return the remat'ed value in the given encoding.
-  Value getRematValue(Value value, Attribute encoding) {
-    auto it = rematMapping.find({value, encoding});
-    assert(it != rematMapping.end());
-    return it->second;
-  }
+  // Get the remat'ed value in the given encoding, if one already exists and
+  // is different than the layout conversion root.
+  Value getRematValue(Value value, Attribute encoding, Value root) const;
+
   void cleanup();
   void backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -137,6 +133,11 @@ class LayoutRematerialization {
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp);

+  LogicalResult getRematerializableSlice(
+      Value root, Attribute rootEncoding, SetVector<Value> &slice,
+      DenseMap<Value, Attribute> &layout,
+      std::function<bool(Operation *)> stopPropagation = nullptr);
+
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -157,6 +158,21 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
   mappedValues[old] = encoding;
 }

+Value LayoutRematerialization::getRematValue(Value value, Attribute encoding,
+                                             Value root) const {
+  Value remat = rematMapping.lookup({value, encoding});
+  if (!remat)
+    return {};
+  // If the remat'ed value is a conversion result, make sure it is different
+  // than the root of the one we're looking at.
+  if (auto cvt = remat.getDefiningOp<ConvertLayoutOp>()) {
+    if (cvt.getSrc() == root)
+      return {};
+  }
+  // This remat'ed value can be reused.
+  return remat;
+}
+
 // Remove unneeded values now that we are done with the rematMapping.
 void LayoutRematerialization::cleanup() {
   for (Operation *op : llvm::reverse(opToDelete))
@@ -766,8 +782,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
     auto layoutIt = layout.find(v);
     assert(layoutIt != layout.end());
     // If we already have a remat value for this value, use it.
-    if (hasRematValue(v, layoutIt->second)) {
-      mapping.map(v, getRematValue(v, layoutIt->second));
+    if (Value remat = getRematValue(v, layoutIt->second, convertOp.getSrc())) {
+      mapping.map(v, remat);
       valuesWithExistingRemat.insert(v);
       continue;
     }
@@ -928,12 +944,17 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   rewriteSlice(slice, layout, convertOp, mapping);
 }

-LogicalResult getRematerializableSlice(
+LogicalResult LayoutRematerialization::getRematerializableSlice(
     Value root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr) {
-  LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
-                                                 layout, stopPropagation);
+    std::function<bool(Operation *)> stopPropagation) {
+  // Allow re-using existing conversions for a value.
+  auto getExistingConversion = [&](Value value, Attribute encoding) -> Value {
+    return getRematValue(value, encoding, root);
+  };
+  LogicalResult result =
+      getConvertBackwardSlice(root, slice, rootEncoding, layout,
+                              stopPropagation, getExistingConversion);
   if (result.failed() || slice.empty())
     return failure();

@@ -950,8 +971,14 @@ LogicalResult getRematerializableSlice(
 void LayoutRematerialization::backwardRematerialization() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
-  funcOp.walk(
-      [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
+  funcOp.walk([&](ConvertLayoutOp convertOp) {
+    convertOps.push_back(convertOp);
+    // Add existing layout conversions as rematerializations of themselves. This
+    // enables rematerialization of other conversions to re-use existing
+    // conversions. Importantly, don't add them to `mappedValues`.
+    rematMapping.insert(
+        {{convertOp.getSrc(), convertOp.getType().getEncoding()}, convertOp});
+  });
   for (ConvertLayoutOp convertOp : convertOps) {
     backwardRematerialization(convertOp);
   }
@@ -976,14 +1003,13 @@ void LayoutRematerialization::backwardRematerialization(
   // careful with the heuristics for both correctness and perf
   if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
     return;
-  Value oldV = convertOp->getOperand(0);
+  Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                             << targetType.getEncoding());
   // Check to see if there are existing remat'ed values for the pair of oldValue
   // and encoding.
-  if (hasRematValue(oldV, targetType.getEncoding())) {
+  if (Value newV = getRematValue(oldV, targetType.getEncoding(), oldV)) {
     // Replace it with the remat'ed value.
-    Value newV = getRematValue(oldV, targetType.getEncoding());
     convertOp.replaceAllUsesWith(newV);
     opToDelete.insert(convertOp);
     LDBG("found remat'ed value" << newV);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 11 additions & 5 deletions
@@ -757,11 +757,11 @@ static bool isFreeConvert(Operation *op) {
                               convertOp.getType());
 }

-LogicalResult
-getConvertBackwardSlice(Value root, SetVector<Value> &slice,
-                        Attribute rootEncoding,
-                        DenseMap<Value, Attribute> &layout,
-                        std::function<bool(Operation *)> stopPropagation) {
+LogicalResult getConvertBackwardSlice(
+    Value root, SetVector<Value> &slice, Attribute rootEncoding,
+    DenseMap<Value, Attribute> &layout,
+    std::function<bool(Operation *)> stopPropagation,
+    std::function<Value(Value, Attribute)> getExistingConversion) {
   DenseSet<std::pair<Value, Attribute>> seen;
   SmallVector<std::pair<Value, Attribute>> queue;

@@ -802,6 +802,12 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,

       continue;
     }
+    Value existing;
+    if (getExistingConversion &&
+        (existing = getExistingConversion(currentValue, encoding))) {
+      enqueue(existing, encoding);
+      continue;
+    }
     if (auto *definingOp = currentValue.getDefiningOp()) {
       // If the op has multiple results we need to update all results layout.
       for (Value result : definingOp->getResults()) {
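A hedged sketch of how a caller might use the new optional parameter (the wrapper function and the `existing` map are illustrative assumptions, not part of this commit; RemoveLayoutConversions.cpp instead wires the callback to its remat map, as shown above):

// Sketch: let the backward-slice walk reuse already-materialized conversions
// instead of always scheduling new ones.
LogicalResult buildSliceReusingConversions(
    Value root, Attribute rootEncoding, SetVector<Value> &slice,
    DenseMap<Value, Attribute> &layout,
    const DenseMap<std::pair<Value, Attribute>, Value> &existing) {
  auto getExistingConversion = [&](Value v, Attribute enc) -> Value {
    return existing.lookup({v, enc});  // null Value when nothing to reuse
  };
  return getConvertBackwardSlice(root, slice, rootEncoding, layout,
                                 /*stopPropagation=*/nullptr,
                                 getExistingConversion);
}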

lib/Tools/LinearLayout.cpp

Lines changed: 5 additions & 1 deletion
@@ -957,9 +957,13 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const {
 }

 LinearLayout LinearLayout::invert() const {
-  // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A)
   assert(isInvertible() &&
          "A linear layout must be surjective and square to be invertible");
+  return pseudoinvert();
+}
+
+LinearLayout LinearLayout::pseudoinvert() const {
+  // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A)
   LinearLayout identity = LinearLayout::empty();
   for (auto outDim : getOutDimNames()) {
     identity *= LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim);
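A brief hedged sanity check on the refactor (illustrative only; the dimension name "register" and the helper function are assumptions):

// Sketch: for an already-invertible layout, invert() and pseudoinvert() agree,
// so existing callers of invert() keep their behavior; only the assertion
// stays in invert().
void checkInvertMatchesPseudoinvert(MLIRContext *ctx) {
  auto dim = StringAttr::get(ctx, "register");   // assumed dimension name
  LinearLayout id = LinearLayout::identity1D(8, dim, dim);
  assert(id.isInvertible());
  assert(id.invert() == id.pseudoinvert());
}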

python/test/unit/runtime/test_cache.py

Lines changed: 20 additions & 20 deletions
@@ -199,7 +199,7 @@ def kernel(X, i: tl.int32):
     kernel[(1, )](x, 8)
     kernel[(1, )](x, 16)
     kernel[(1, )](x, 17)
-    assert len(kernel.cache[device]) == 3
+    assert len(kernel.device_caches[device][0]) == 3


 GLOBAL_DEFAULT_ARG = 1
@@ -223,7 +223,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
     assert x == torch.ones_like(x)

     device = getattr(torch, device).current_device()
-    assert len(kernel.cache[device]) == 1
+    assert len(kernel.device_caches[device][0]) == 1


 GLOBAL_VAR: tl.constexpr = 1
@@ -416,13 +416,13 @@ def kernel_add(a, b, o, N: tl.constexpr):
         32,
     ]
     device = getattr(torch, device).current_device()
-    assert len(kernel_add.cache[device]) == 0
+    assert len(kernel_add.device_caches[device][0]) == 0
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     kernel_add.warmup(*args, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     kernel_add.warmup(*args, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1


 def test_jit_debug(device) -> None:
@@ -433,12 +433,12 @@ def kernel(tmp):

     device = getattr(torch, device).current_device()
     tmp = torch.tensor([1], dtype=torch.int32, device=device)
-    assert len(kernel.cache[device]) == 0
+    assert len(kernel.device_caches[device][0]) == 0
     kernel[(1, )](tmp, debug=False)
-    assert len(kernel.cache[device]) == 1
+    assert len(kernel.device_caches[device][0]) == 1
     kernel[(1, )](tmp, debug=True)
-    assert len(kernel.cache[device]) == 2
-    bins = list(kernel.cache[device].values())
+    assert len(kernel.device_caches[device][0]) == 2
+    bins = list(kernel.device_caches[device][0].values())
     assert bins[0].asm['ttir'] != bins[1].asm['ttir']


@@ -455,18 +455,18 @@ def kernel_add_device(a, b, o, N: tl.constexpr):
         add_fn(a, b, o, N)

     device = getattr(torch, device).current_device()
-    assert len(kernel_add_device.cache[device]) == 0
+    assert len(kernel_add_device.device_caches[device][0]) == 0
     kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add_device.cache[device]) == 1
-    bins = list(kernel_add_device.cache[device].values())
+    assert len(kernel_add_device.device_caches[device][0]) == 1
+    bins = list(kernel_add_device.device_caches[device][0].values())
     inline_ttir = bins[0].asm['ttir']
     add_fn.noinline = True
     add_fn.hash = None
     kernel_add_device.hash = None
-    kernel_add_device.cache[device].clear()
+    kernel_add_device.device_caches[device][0].clear()
     kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add_device.cache[device]) == 1
-    bins = list(kernel_add_device.cache[device].values())
+    assert len(kernel_add_device.device_caches[device][0]) == 1
+    bins = list(kernel_add_device.device_caches[device][0].values())
     noinline_ttir = bins[0].asm['ttir']
     assert inline_ttir != noinline_ttir

@@ -514,12 +514,12 @@ def cache_hook(*args, **kwargs):

     # clear the cache
     shutil.rmtree(fresh_triton_cache)
-    kernel_add.cache[device].clear()
+    kernel_add.device_caches[device][0].clear()

     # preload the kernel
     kernel_preload = kernel_add.preload(specialization_data)
     assert kernel_preload.hash == hash
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1

     # we should hit the cache and not compile anything
     counter = 0
@@ -532,7 +532,7 @@ def inc_counter(*args, **kwargs):
     final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))
     JITFunction.cache_hook = None
     assert counter == 0
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     assert final_kernel.hash == hash

     # test that we can't preload a mismatched kernel
@@ -572,7 +572,7 @@ def compiled_hook(*args, **kwargs):
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))
     assert specialization_data is not None and specialization_data_compiled == specialization_data
     assert is_warmup is True
-    assert key in kernel_add.cache[getattr(torch, device).current_device()]
+    assert key in kernel_add.device_caches[getattr(torch, device).current_device()][0]


 @pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())
