Skip to content

Commit 7da6d0b

Browse files
authored
[NFC] Remove custom backend callback and move it to TargetInfo (triton-lang#5854)
1 parent 6e41010 commit 7da6d0b

File tree

8 files changed

+28
-41
lines changed

8 files changed

+28
-41
lines changed

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,6 @@ constexpr int patternBenefitClampOptimizedPattern = 20;
2828
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;
2929
constexpr int patternBenefitNvidiaTensorCoreSubviewPattern = 20;
3030

31-
struct BackendCallbacks {
32-
/**
33-
* A backend-specific callback for appending auxiliary data during
34-
* `LocalStoreOp` conversion.
35-
*
36-
* @param[in] op The reference to the re-written `LocalStoreOp`.
37-
* @param[in] count The number of issued LLVM instructions.
38-
* @param[in] type The input type of issued LLVM instructions.
39-
*/
40-
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
41-
Type llvmOpType)>
42-
localStoreOpConversion = nullptr;
43-
};
44-
4531
void populateElementwiseOpToLLVMPatterns(
4632
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
4733
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
@@ -51,10 +37,10 @@ void populateElementwiseOpToLLVMPatterns(
5137
// callback receives 1) the current source op, 2) the number of issued LLVM
5238
// instructions and 3) their input types. Each MLIR backend can provide a
5339
// callback and, thus, handle backend-specific behaviors.
54-
void populateMemoryOpToLLVMPatterns(
55-
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
56-
RewritePatternSet &patterns, PatternBenefit benefit,
57-
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
40+
void populateMemoryOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
41+
const TargetInfoBase &targetInfo,
42+
RewritePatternSet &patterns,
43+
PatternBenefit benefit);
5844

5945
void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
6046
RewritePatternSet &patterns,

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ class TargetInfoBase {
9191

9292
virtual bool supportVectorizedAtomics() const = 0;
9393

94+
// Helper used by targets to annotate store operations during lowering to
95+
// llvm.
96+
virtual void storeOpAnnotation(triton::gpu::LocalStoreOp op,
97+
size_t localStoreOpCount, Type type) const {}
98+
9499
virtual ~TargetInfoBase() {}
95100
};
96101
} // namespace mlir::triton

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,12 @@ struct LocalStoreOpConversion
154154
public:
155155
using ConvertOpToLLVMPattern<
156156
triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
157-
using BackendCallbackType =
158-
decltype(BackendCallbacks::localStoreOpConversion);
159157

160158
LocalStoreOpConversion(const LLVMTypeConverter &converter,
161159
const TargetInfoBase &targetInfo,
162-
BackendCallbackType backendCallback,
163160
PatternBenefit benefit = 1)
164161
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
165-
targetInfo(targetInfo), backendCallback(backendCallback) {}
162+
targetInfo(targetInfo) {}
166163

167164
LogicalResult
168165
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -178,31 +175,24 @@ struct LocalStoreOpConversion
178175
adaptor.getSrc(), smemObj, getTypeConverter(),
179176
rewriter, targetInfo, &llvmOpCount);
180177

181-
if (backendCallback)
182-
(backendCallback)(op, llvmOpCount.first, llvmOpCount.second);
178+
targetInfo.storeOpAnnotation(op, llvmOpCount.first, llvmOpCount.second);
183179

184180
rewriter.eraseOp(op);
185181
return success();
186182
}
187183

188184
private:
189185
const TargetInfoBase &targetInfo;
190-
BackendCallbackType backendCallback;
191186
};
192187

193188
} // namespace
194189

195190
void mlir::triton::populateMemoryOpToLLVMPatterns(
196191
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
197-
RewritePatternSet &patterns, PatternBenefit benefit,
198-
std::optional<BackendCallbacks> backendCallbacks) {
192+
RewritePatternSet &patterns, PatternBenefit benefit) {
199193
patterns.add<GlobalScratchAllocOpConversion>(typeConverter, benefit);
200194
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
201195
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
202196
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
203-
204-
auto backendCall =
205-
backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
206-
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
207-
benefit);
197+
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
208198
}

third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount,
8686
});
8787
}
8888

89-
void storeOpConversionCallback(triton::gpu::LocalStoreOp op,
90-
size_t localStoreOpCount, Type type) {
89+
void storeOpSchedAnnotations(triton::gpu::LocalStoreOp op,
90+
size_t localStoreOpCount, Type type) {
9191
MLIRContext *ctx = op->getContext();
9292
auto counterAttr =
9393
triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type);

third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount,
1818
Type type);
1919
void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount,
2020
Type type);
21-
void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
22-
Type type);
21+
void storeOpSchedAnnotations(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
22+
Type type);
2323
triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp);
2424
} // namespace mlir::triton
2525

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "TargetInfo.h"
2+
#include "SchedInstructions.h"
23
#include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
34
#include "Utility.h"
45
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -442,4 +443,9 @@ bool TargetInfo::supportVectorizedAtomics() const {
442443
return true;
443444
}
444445

446+
void TargetInfo::storeOpAnnotation(triton::gpu::LocalStoreOp op,
447+
size_t localStoreOpCount, Type type) const {
448+
storeOpSchedAnnotations(op, localStoreOpCount, type);
449+
}
450+
445451
} // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
6868

6969
bool supportVectorizedAtomics() const override;
7070

71+
void storeOpAnnotation(triton::gpu::LocalStoreOp op, size_t localStoreOpCount,
72+
Type type) const override;
73+
7174
private:
7275
void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args,
7376
RewriterBase &rewriter, bool useStdErr) const;

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,9 @@ struct ConvertTritonAMDGPUToLLVM
195195
populatePatterns7(mlir::triton::populateGatherOpToLLVMPatterns,
196196
commonBenefit);
197197

198-
mlir::triton::BackendCallbacks callbacks;
199-
callbacks.localStoreOpConversion = storeOpConversionCallback;
200-
201198
AMD::populateMemoryOpToLLVMPatterns(typeConverter, patterns, AMDBenefit);
202-
mlir::triton::populateMemoryOpToLLVMPatterns(
203-
typeConverter, targetInfo, patterns, commonBenefit, callbacks);
199+
mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo,
200+
patterns, commonBenefit);
204201
mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo,
205202
patterns, commonBenefit);
206203
mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns,

0 commit comments

Comments
 (0)