
Commit 67e72f8

Mogball authored and zwu-2025 committed

[Pipeliner] Refactor partition and stage builders (NFC) (triton-lang#6906)

1 parent f3fb638 · commit 67e72f8

14 files changed: +261 -330 lines changed

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 0 additions & 5 deletions

```diff
@@ -35,9 +35,6 @@ class Partition {
   int getStage() const { return stage; }
   ArrayRef<Operation *> getOps() const { return ops; }
 
-  void insert(Operation *op) { ops.push_back(op); }
-  void remove(Operation *op) { ops.erase(llvm::find(ops, op)); }
-
 private:
   void setIndex(int idx) { this->idx = idx; }
   friend class WarpSchedule;
@@ -59,8 +56,6 @@ class WarpSchedule {
 public:
   // Create a new partition with a stage.
   Partition *addPartition(unsigned stage);
-  // Update the op to partition mapping.
-  void updatePartitions();
 
   // Get the partition the op belongs to.
   Partition *getPartition(Operation *op);
```
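With `insert` and `remove` gone from the public surface, `Partition` is read-only outside of `WarpSchedule`. A minimal usage sketch against only what this hunk leaves visible; the surrounding pass code is assumed, not shown in the diff:

```cpp
// Sketch only: addPartition/getStage/getOps/getPartition are the accessors
// visible in this header; how ops get attributed to a partition is now an
// internal detail of WarpSchedule.
WarpSchedule schedule;
Partition *load = schedule.addPartition(/*stage=*/0);
Partition *mma = schedule.addPartition(/*stage=*/2);
assert(load->getStage() == 0 && mma->getStage() == 2);

// Consumers can still walk a partition's ops and map ops back to partitions.
for (Operation *op : load->getOps())
  assert(schedule.getPartition(op) == load);
```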

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 4 additions & 34 deletions

```diff
@@ -131,42 +131,12 @@ int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages);
 
 // Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
 // single buffer slice (leading dimension equal to 1), at the given index.
-template <typename TBuilder>
 TypedValue<triton::gpu::MemDescType>
-createSingleBufferView(TBuilder &builder, Value alloc, Value idx) {
-  assert(isa<triton::gpu::MemDescType>(alloc.getType()) &&
-         "Expected MemDescType");
-  auto allocDescType = cast<triton::gpu::MemDescType>(alloc.getType());
-  SmallVector<int64_t> shape;
-  if (allocDescType.getShape().size() > 1) {
-    shape.insert(shape.end(), allocDescType.getShape().begin() + 1,
-                 allocDescType.getShape().end());
-  } else {
-    shape.push_back(1);
-  }
-  auto viewDescType = triton::gpu::MemDescType::get(
-      shape, allocDescType.getElementType(), allocDescType.getEncoding(),
-      allocDescType.getMemorySpace(), allocDescType.getMutableMemory(),
-      /*allocShape=*/allocDescType.getAllocShape());
-  SmallVector<Value> idxs = {idx};
-  if (allocDescType.getShape().size() > 1) {
-    Value zero =
-        builder.template create<arith::ConstantIntOp>(alloc.getLoc(), 0, 32);
-    for (unsigned i = 1; i < allocDescType.getShape().size(); i++) {
-      idxs.push_back(zero);
-    }
-  }
-  return builder.template create<triton::gpu::MemDescSubviewOp>(
-      alloc.getLoc(), viewDescType, alloc, idxs);
-}
-
-template <typename TBuilder>
+createSingleBufferView(OpBuilder &builder, Value alloc, Value idx);
+
+// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
+// single buffer slice (leading dimension equal to 1), at the given index.
 TypedValue<triton::gpu::MemDescType>
-createSingleBufferView(TBuilder &builder, Value alloc, int idx) {
-  return createSingleBufferView(
-      builder, alloc,
-      builder.template create<arith::ConstantIntOp>(alloc.getLoc(), idx, 32));
-}
+createSingleBufferView(OpBuilder &builder, Value alloc, int idx);
 
 } // namespace triton
 } // namespace mlir
```
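The definitions presumably move to a .cpp essentially unchanged; with a concrete `OpBuilder`, the `builder.template create<...>` disambiguation collapses to `builder.create<...>`. A sketch of the `int` overload under that assumption, derived from the deleted inline body (the actual .cpp is not part of this hunk):

```cpp
// Assumed shape of the out-of-line definition (not shown in this diff).
TypedValue<triton::gpu::MemDescType>
mlir::triton::createSingleBufferView(OpBuilder &builder, Value alloc,
                                     int idx) {
  // Wrap the constant index and defer to the Value overload above.
  return createSingleBufferView(
      builder, alloc,
      builder.create<arith::ConstantIntOp>(alloc.getLoc(), idx, 32));
}
```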

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 6 additions & 31 deletions

```diff
@@ -250,44 +250,19 @@ void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
 } // namespace mlir
 
 namespace mlir::triton {
-
 /// Replace all uses of `oldUse` with `val` and propagate the type if needed.
 /// This is useful when we need to change a memory descriptor from immutable to
 /// mutable.
 void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
                                  Value val);
 
-template <typename BuilderT>
+/// Replace all uses of `old` with a local load from `alloc` unless the use is a
+/// `ttg.local_alloc` with a matching shared encoding, in which case the shared
+/// memory is forwarded directly into the use.
 void replaceUsesWithLocalLoad(
-    BuilderT &builder, OpResult old, TypedValue<triton::gpu::MemDescType> alloc,
-    TypedValue<triton::gpu::AsyncTokenType> token = {}) {
-  // Remove redundant local_load -> local_alloc
-  namespace ttg = triton::gpu;
-  using triton::gpu::LocalAllocOp;
-  auto allocTy = alloc.getType();
-  SmallVector<LocalAllocOp> allocsToErase;
-  for (Operation *user : old.getUsers()) {
-    if (auto userAlloc = dyn_cast<LocalAllocOp>(user)) {
-      if (allocTy.getEncoding() == userAlloc.getType().getEncoding()) {
-        replaceUsesAndPropagateType(builder, userAlloc, alloc);
-        allocsToErase.push_back(userAlloc);
-      }
-    }
-  }
-
-  // If there are some uses that were not local_allocs, we need to create a
-  // local_load for them.
-  if (std::distance(old.getUsers().begin(), old.getUsers().end()) >
-      allocsToErase.size()) {
-    auto loc = old.getOwner()->getLoc();
-    auto sharedLoad = builder.template create<ttg::LocalLoadOp>(
-        loc, old.getType(), alloc, token);
-    old.replaceAllUsesWith(sharedLoad.getResult());
-  }
-  for (auto alloc : allocsToErase) {
-    alloc.erase();
-  }
-}
+    OpBuilder &builder, OpResult old,
+    TypedValue<triton::gpu::MemDescType> alloc,
+    TypedValue<triton::gpu::AsyncTokenType> token = {});
 } // namespace mlir::triton
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
```
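Same pattern here: the deleted body can move into the .cpp nearly verbatim, with only the dependent-name syntax dropped. The one call that changes shape, assuming a straight move:

```cpp
// Header template (before): `builder` has a dependent type, so the nested
// template-id must be disambiguated with the `template` keyword.
auto sharedLoad = builder.template create<ttg::LocalLoadOp>(
    loc, old.getType(), alloc, token);

// .cpp definition (after, assumed): concrete OpBuilder, plain call syntax.
auto sharedLoad =
    builder.create<ttg::LocalLoadOp>(loc, old.getType(), alloc, token);
```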

include/triton/Dialect/TritonGPU/Transforms/WarpSpecialization.h

Lines changed: 0 additions & 3 deletions

```diff
@@ -8,9 +8,6 @@ namespace scf {
 class ForOp;
 } // namespace scf
 namespace triton::gpu {
-// Identify load-mma dependencies and specialize them to different partitions.
-LogicalResult specializeLoadMMADependencies(scf::ForOp &loop,
-                                            int defaultNumStages);
 // This is the final step to prepare a loop for warp specialization. This takes
 // a loop with a partition schedule and rewrites the loop such that all SSA
 // dependencies between partitions are passed through shared memory and
```

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 5 additions & 101 deletions

```diff
@@ -16,17 +16,9 @@ inline bool isFp4Padded(Attribute encoding) {
   return mmaEnc && mmaEnc.getFp4Padded();
 }
 
-template <typename BuilderT>
-inline SmallVector<Value> translateTMAIndices(BuilderT &builder, Location loc,
-                                              Attribute encoding,
-                                              SmallVector<Value> indices) {
-  if (isFp4Padded(encoding)) {
-    auto two = builder.template create<arith::ConstantIntOp>(loc, 2, 32);
-    indices.back() =
-        builder.template create<arith::MulIOp>(loc, indices.back(), two);
-  }
-  return indices;
-}
+SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
+                                       Attribute encoding,
+                                       SmallVector<Value> indices);
 
 gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout,
                                            ArrayRef<int64_t> shape);
```
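A presumed out-of-line definition for `translateTMAIndices`, reconstructed from the deleted inline body (the actual .cpp is not part of this hunk). Fp4-padded layouts pack two mxfp4 elements per byte, so the innermost TMA coordinate is doubled:

```cpp
SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
                                       Attribute encoding,
                                       SmallVector<Value> indices) {
  if (isFp4Padded(encoding)) {
    // Two mxfp4 elements per byte: scale the contiguous-dim index by 2.
    auto two = builder.create<arith::ConstantIntOp>(loc, 2, 32);
    indices.back() = builder.create<arith::MulIOp>(loc, indices.back(), two);
  }
  return indices;
}
```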
```diff
@@ -69,95 +61,7 @@ std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty);
 
 std::optional<int> getTMAElementType(Operation *op, TensorDescType ty);
 
-template <typename BuilderT>
-mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
-                                  mlir::triton::MakeTensorDescOp op,
-                                  BuilderT &builder) {
-  using namespace mlir;
-  MLIRContext *ctx = op.getContext();
-  auto loc = op.getLoc();
-  auto mkI32Constant = [&](int32_t val) {
-    return builder.template create<arith::ConstantOp>(
-        loc, builder.getI32Type(), builder.getI32IntegerAttr(val));
-  };
-
-  auto elemType = op.getBase().getType().getPointeeType();
-  auto elemSize = elemType.getIntOrFloatBitWidth() / 8;
-  auto encoding = op.getType().getBlockType().getEncoding();
-  auto mmaEncoding =
-      llvm::dyn_cast_or_null<gpu::NVMMASharedEncodingAttr>(encoding);
-  bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded();
-
-  int paddingScale = fp4Padded ? 2 : 1;
-  auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getTensorShape());
-  auto blockShape =
-      getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false);
-  auto contigDimSize = blockShape.back();
-
-  llvm::SmallVector<Value> boxDim;
-  if (fp4Padded && contigDimSize != 128) {
-    return op->emitError(
-        "FP4 padded loads require 128 elements or more in the last dim");
-  }
-  boxDim.push_back(mkI32Constant(contigDimSize));
-  for (int k = shapePerCTA.size() - 2; k >= 0; --k)
-    boxDim.push_back(mkI32Constant(blockShape[k]));
-
-  unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
-  if (!mmaEncoding) {
-    auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(
-        op.getType().getBlockType().getEncoding());
-    if (!swizzledEnc || swizzledEnc.getVec() != 1 ||
-        swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) {
-      op->emitError() << "Unhandled encoding type";
-      return failure();
-    }
-  }
-
-  auto maybeSwizzleMode = getTMASwizzleMode(op, op.getType());
-  if (!maybeSwizzleMode)
-    return failure();
-  auto swizzleMode = *maybeSwizzleMode;
-
-  Value elemSizeVal = builder.template create<arith::ConstantOp>(
-      loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
-
-  SmallVector<Value> globalDim(llvm::reverse(op.getShape()));
-  SmallVector<Value> globalStride;
-  for (int k = op.getStrides().size() - 2; k >= 0; --k) {
-    globalStride.push_back(op.getStrides()[k]);
-  }
-
-  if (fp4Padded) {
-    // Convert number of bytes to number of mxfp4 elements
-    globalDim[0] = builder.template create<arith::MulIOp>(loc, globalDim[0],
-                                                          mkI32Constant(2));
-  }
-
-  SmallVector<Value> elementStride(globalDim.size(), mkI32Constant(1));
-
-  for (int i = 0; i < globalStride.size(); ++i)
-    globalStride[i] = builder.template create<arith::MulIOp>(
-        loc, globalStride[i], elemSizeVal);
-
-  auto elemTypeEnum = getTMAElementType(op, op.getType());
-  if (!elemTypeEnum) {
-    return failure();
-  }
-
-  builder.template create<triton::ExperimentalTensormapCreateOp>(
-      loc,
-      /*desc_ptr=*/tmaPtr,
-      /*global_address=*/op.getBase(),
-      /*box_dim=*/boxDim,
-      /*global_dim=*/globalDim,
-      /*global_stride=*/globalStride,
-      /*element_strides=*/elementStride,
-      /*elem_type*/ builder.getI32IntegerAttr(*elemTypeEnum),
-      /*interleave_layout*/ builder.getI32IntegerAttr(0),
-      /*swizzle_mode=*/builder.getI32IntegerAttr(swizzleMode),
-      /*fill_mode=*/builder.getI32IntegerAttr(0));
-  return success();
-}
+LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
+                            OpBuilder &builder);
 
 } // namespace mlir::triton::nvidia_gpu
```
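Taken together, the hunks in this commit follow one mechanical recipe: each header-only `template <typename BuilderT>` helper becomes a declaration against the concrete `OpBuilder`, with the body moved into a single .cpp. Besides shrinking the headers, this removes the `builder.template create<...>` disambiguation that dependent types force. A self-contained toy illustration of that C++ rule (toy types, not Triton code):

```cpp
#include <iostream>

struct Builder {
  template <typename OpT> OpT create(int x) { return OpT{x}; }
};
struct ConstOp { int v; };

// In a template, `b` has a dependent type, so the nested template-id needs
// the `template` keyword to parse...
template <typename BuilderT> ConstOp makeGeneric(BuilderT &b) {
  return b.template create<ConstOp>(42);
}

// ...while a concrete parameter type needs no disambiguation and lets the
// definition live out-of-line in a single translation unit.
ConstOp makeConcrete(Builder &b) { return b.create<ConstOp>(42); }

int main() {
  Builder b;
  std::cout << makeGeneric(b).v + makeConcrete(b).v << "\n"; // prints 84
}
```

The trade-off is the usual one for de-templatizing: one out-of-line definition instead of a per-TU instantiation, at the cost of requiring every caller to go through `OpBuilder`, which all of these call sites already did.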
