forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPipeline.cpp
More file actions
595 lines (523 loc) · 22 KB
/
Pipeline.cpp
File metadata and controls
595 lines (523 loc) · 22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
//
// This file implements loop software pipelining
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// and SCF's LoopPipelining.
//
//===----------------------------------------------------------------------===//
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
class LoopPipeliner {
  /// Number of pipeline stages:
  ///   [0, numStages-1) are emitted in the prologue (before the loop);
  ///   numStages-1 is appended after the loop body.
  int numStages;
  /// cache forOp we are working on
  scf::ForOp forOp;
  /// cache YieldOp (the body terminator) of this forOp
  scf::YieldOp yieldOp;
  /// loads to be pipelined
  SetVector<Value> loads;
  /// load => the value it is mapped to after layout conversion (its single
  /// convert_layout user)
  DenseMap<Value, Value> loadsMapping;
  /// load => multi-stage buffer allocated for it
  DenseMap<Value, Value> loadsBuffer;
  /// load => buffer value after each prologue stage (index 0 is the freshly
  /// allocated buffer, later entries are insert_slice_async results)
  DenseMap<Value, SmallVector<Value>> loadStageBuffer;
  /// load => slice extracted from its buffer before entering the loop
  DenseMap<Value, Value> loadsExtract;
  /// i32 index of the buffer slot the next insert_slice_async writes into
  Value pipelineIterIdx;
  /// i32 index of the buffer slot the next extract_slice reads from
  Value loopIterIdx;
  /// value (in loop) => value at stage N
  DenseMap<Value, SmallVector<Value>> valueMapping;
  /// Block arguments that loads depend on
  DenseSet<BlockArgument> depArgs;
  /// Operations (inside the loop body) that loads depend on
  DenseSet<Operation *> depOps;

  /// collect values that v depends on and are defined inside the loop,
  /// following loop-carried values back at most `stages` iterations
  void collectDeps(Value v, int stages, DenseSet<Value> &deps);
  /// record `newValue` as the stage-`stage` copy of `origin`
  void setValueMapping(Value origin, Value newValue, int stage);
  /// stage-`stage` copy of `origin`, or `origin` itself if it was never mapped
  Value lookupOrDefault(Value origin, int stage);
  /// return true if this op uses the layout-converted result of any `loads`
  bool isDirectUserOfAsyncLoad(Operation &op);
  /// returns a empty buffer of size <numStages, ...>
  triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
                                                 OpBuilder &builder);

public:
  /// The initializer list follows member declaration order (numStages is
  /// declared before forOp); members are always initialized in that order.
  LoopPipeliner(scf::ForOp forOp, int numStages)
      : numStages(numStages), forOp(forOp) {
    // cache yieldOp
    yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  }

  /// Collect loads to pipeline. Return success if we can pipeline this loop
  LogicalResult initialize();

  /// emit pipelined loads (before loop body)
  void emitPrologue();

  /// create the new ForOp (add new args & insert prefetched ops)
  scf::ForOp createNewForOp();

  friend class PipelinePass;
};
// helpers
/// Record `newValue` as the copy of `origin` for pipeline stage `stage`.
/// The per-value slot vector (one entry per stage, initially null) is created
/// the first time `origin` is mapped.
void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
  auto slotsIt = valueMapping.find(origin);
  if (slotsIt == valueMapping.end())
    slotsIt = valueMapping.insert({origin, SmallVector<Value>(numStages)}).first;
  slotsIt->second[stage] = newValue;
}
/// Return the stage-`stage` copy of `origin`, or `origin` itself when it has
/// never been mapped. A mapped value whose slot for `stage` was never filled
/// yields a null Value.
Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
  auto found = valueMapping.find(origin);
  return found == valueMapping.end() ? origin : found->second[stage];
}
/// Recursively collect into `deps` every value defined directly in the loop
/// body that `v` depends on. Loop-carried dependencies (region iter args) are
/// followed back through the yield for at most `stages` more iterations.
///
/// The induction variable (body block argument #0) is deliberately skipped:
/// it has no corresponding yield operand, and the prologue / new loop
/// rematerialize it explicitly. Treating it like an iter arg would access
/// yield operand `0 - 1`, i.e. an out-of-range index.
void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
  // Loop-invariant value (or defined in a nested region): skip.
  if (v.getParentRegion() != &forOp.getLoopBody())
    return;

  // Since we only need to peel the loop numStages-1 times, don't worry about
  // dependencies that are too far away.
  if (stages < 0)
    return;

  if (auto arg = v.dyn_cast<BlockArgument>()) {
    // Argument 0 is the induction variable; it is not carried by the yield.
    if (arg.getArgNumber() == 0)
      return;
    deps.insert(v);
    // Iter arg #k (k >= 1) is produced by yield operand #k-1 of the previous
    // iteration.
    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
    return;
  }

  // v might already be in deps, but we still need to visit it: it may depend
  // on values from previous iterations, which must be followed with the
  // current stage budget.
  deps.insert(v);
  for (Value operand : v.getDefiningOp()->getOperands())
    collectDeps(operand, stages, deps);
}
/// Return true if `op` has an operand that is the result of the single user
/// (expected to be a layout conversion) of any pipelined load.
bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
  for (Value load : loads) {
    assert(load.hasOneUse() &&
           "load should only have one use (ConvertLayout)");
    Operation *converter = *load.getUsers().begin();
    Value converted = converter->getResult(0);
    for (Value operand : op.getOperands()) {
      if (operand == converted)
        return true;
    }
  }
  return false;
}
/// Allocate the multi-stage buffer for the pipelined load `op`.
/// The buffer type is the load's layout-converted tensor type with a new
/// leading dimension of size numStages, e.g. (numStages == 4):
/// <32x64xbf16> -> <4x32x64xbf16>. Element type and encoding are kept.
triton::gpu::AllocTensorOp
LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
  Value convertLayout = loadsMapping[op->getResult(0)];
  auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    llvm_unreachable("Async copy's return should be of RankedTensorType");

  SmallVector<int64_t> bufferShape;
  bufferShape.push_back(numStages);
  bufferShape.append(tensorType.getShape().begin(),
                     tensorType.getShape().end());

  auto bufferType = RankedTensorType::get(
      bufferShape, tensorType.getElementType(), tensorType.getEncoding());
  return builder.create<triton::gpu::AllocTensorOp>(convertLayout.getLoc(),
                                                    bufferType);
}
/// Select the loads to pipeline. A load in the loop body is selected if:
/// - it doesn't depend on any other load in the body (looking back at most
///   numStages-1 iterations), and
/// - its result has exactly one user, a convert_layout whose result is a
///   ranked tensor with a SharedEncodingAttr.
/// (?) this load is not a loop-invariant value (we should run LICM before
/// this pass?)
/// On success, also records the block arguments (depArgs) and in-loop ops
/// (depOps) that the selected loads depend on. Fails if nothing is selected.
LogicalResult LoopPipeliner::initialize() {
  // Gather every tt.load sitting directly in the loop body.
  SmallVector<triton::LoadOp, 2> allLoads;
  for (Operation &op : *forOp.getBody()) {
    if (auto load = dyn_cast<triton::LoadOp>(&op))
      allLoads.push_back(load);
  }
  // Early stop: nothing to pipeline without loads.
  if (allLoads.empty())
    return failure();

  // load => in-loop values it (transitively) depends on.
  DenseMap<Value, DenseSet<Value>> loadDeps;
  for (triton::LoadOp load : allLoads) {
    DenseSet<Value> deps;
    for (Value operand : load->getOperands())
      collectDeps(operand, numStages - 1, deps);
    loadDeps[load] = deps;
  }

  for (triton::LoadOp load : allLoads) {
    // A load depending on another load would have to wait on it in the
    // prologue, which defeats the purpose of the pipeline pass.
    const DenseSet<Value> &deps = loadDeps[load];
    bool dependsOnLoad = false;
    for (triton::LoadOp other : allLoads) {
      if (deps.contains(other)) {
        dependsOnLoad = true;
        break;
      }
    }
    if (dependsOnLoad)
      continue;

    // For now, only pipeline loads with a single convert_layout (to smem) use.
    // TODO: lift this constraint in the future
    if (!load.getResult().hasOneUse())
      continue;
    auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(
        *load.getResult().getUsers().begin());
    if (!convertLayout)
      continue;
    auto convertedType =
        convertLayout.getResult().getType().dyn_cast<RankedTensorType>();
    if (!convertedType)
      continue;
    if (!convertedType.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
      continue;

    loadsMapping[load] = convertLayout;
    loads.insert(load);
  }

  if (loads.empty())
    return failure();

  // Split the dependencies of the selected loads into block args and ops.
  // TODO: we should record the stage that the value is depended on
  for (Value load : loads) {
    for (Value dep : loadDeps[load]) {
      if (auto arg = dep.dyn_cast<BlockArgument>())
        depArgs.insert(arg);
      else
        depOps.insert(dep.getDefiningOp());
    }
  }
  return success();
}
void LoopPipeliner::emitPrologue() {
// llvm::errs() << "loads to pipeline...:\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";
OpBuilder builder(forOp);
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
setValueMapping(arg, operand.get(), 0);
}
// prologue from [0, numStage-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (int stage = 0; stage < numStages - 1; ++stage) {
// special handling for induction variable as the increment is implicit
if (stage != 0)
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
setValueMapping(forOp.getInductionVar(), iv, stage);
// special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// rematerialize peeled values
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
for (Operation *op : orderedDeps) {
Operation *newOp = nullptr;
if (loads.contains(op->getResult(0))) {
// Allocate empty buffer
if (stage == 0) {
loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
}
// load => copy async
// TODO: check if the hardware supports async copy
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
lookupOrDefault(loadOp.ptr(), stage),
loadStageBuffer[loadOp][stage], pipelineIterIdx,
lookupOrDefault(loadOp.mask(), stage),
lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
} else
llvm_unreachable("This should be LoadOp");
} else {
newOp = builder.clone(*op);
// Update loop-carried uses
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
auto it = valueMapping.find(op->getOperand(opIdx));
if (it != valueMapping.end()) {
Value v = it->second[stage];
assert(v);
newOp->setOperand(opIdx, v);
} // else, op at opIdx is a loop-invariant value
}
}
// If this is a load/async_copy, we need to update the mask
if (Value mask = [&]() {
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
return loadOp.mask();
} else if (auto insertSliceAsyncOp =
llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
newOp)) {
return insertSliceAsyncOp.mask();
} else {
return mlir::Value();
}
}()) {
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
// TODO: move this out of the loop
builder.setInsertionPoint(newOp);
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), loopCond);
Value newMask =
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
// TODO: better way to do this?
if (llvm::isa<triton::LoadOp>(newOp))
newOp->setOperand(1, newMask);
else // InsertSliceAsyncOp
newOp->setOperand(3, newMask);
}
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
// TODO: load should no be used in the preheader?
if (loads.contains(originalResult)) {
break;
// originalResult = loadsMapping[originalResult];
}
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
setValueMapping(
forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
}
}
}
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
// async.wait & extract_slice
Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
loadOp.getLoc(), loadsMapping[loadOp].getType(),
loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
loadsExtract[loadOp] = extractSlice;
}
}
/// Build the pipelined replacement for `forOp`, inserted right before it.
/// The new loop clones the original body, reads the pipelined loads from
/// extra iter args (slices extracted in the previous iteration / prologue),
/// and prefetches the loads of the iteration numStages-1 ahead into the
/// multi-stage buffers. Results [0, forOp.getNumResults()) of the new loop
/// correspond to the original loop's results.
scf::ForOp LoopPipeliner::createNewForOp() {
  OpBuilder builder(forOp);
  // order of new args:
  //   (original args),
  //   (insertSliceAsync buffer at stage numStages - 1) for each load
  //   (extracted tensor) for each load
  //   (depArgs at stage numStages-1)
  //   (iv at stage numStages-2; `step` is added in the body to get the IV of
  //    the iteration being prefetched)
  //   (pipeline iteration index)
  //   (loop iteration index)
  SmallVector<Value> newLoopArgs;
  // We need this to update operands for yield
  // original block arg => new arg's idx
  DenseMap<BlockArgument, size_t> depArgsIdx;
  for (auto v : forOp.getIterOperands())
    newLoopArgs.push_back(v);

  size_t bufferIdx = newLoopArgs.size();
  for (Value loadOp : loads)
    newLoopArgs.push_back(loadStageBuffer[loadOp].back());

  size_t loadIdx = newLoopArgs.size();
  for (Value loadOp : loads)
    newLoopArgs.push_back(loadsExtract[loadOp]);

  size_t depArgsBeginIdx = newLoopArgs.size();
  for (BlockArgument depArg : depArgs) {
    depArgsIdx[depArg] = newLoopArgs.size();
    newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
  }

  size_t nextIVIdx = newLoopArgs.size();
  newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
  newLoopArgs.push_back(pipelineIterIdx);
  newLoopArgs.push_back(loopIterIdx);

  for (size_t i = 0; i < newLoopArgs.size(); ++i)
    assert(newLoopArgs[i]);

  // 1. signature of the new ForOp
  auto newForOp = builder.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newLoopArgs);

  // 2. body of the new ForOp
  builder.setInsertionPointToStart(newForOp.getBody());
  BlockAndValueMapping mapping;
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  // The induction variable is a body block argument too. Without this entry
  // the cloned body would keep using the original loop's IV, which is no
  // longer valid once the original loop is erased.
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

  // 2.1 clone the loop body, replace original args with args of the new ForOp
  for (Operation &op : forOp.getBody()->without_terminator()) {
    Operation *newOp = builder.clone(op, mapping);
    // update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
      mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
  }

  // 3. replace loads with block args (from prologue)
  for (size_t idx = 0; idx < loads.size(); ++idx) {
    Value load = loads[idx];
    assert(load.hasOneUse() &&
           "we assume that this load has one use (ConvertLayout)");
    Value loadUse = load.getUsers().begin()->getResult(0);
    Value extracted = newForOp.getRegionIterArgs()[loadIdx + idx];
    mapping.lookup(loadUse).replaceAllUsesWith(extracted);
    // delete old load and layout conversion
    mapping.lookup(loadUse).getDefiningOp()->erase();
    mapping.lookup(load).getDefiningOp()->erase();
    // Keep `mapping` pointing at a live value: the conversion has been
    // replaced by the extracted slice (matters if it is also yielded).
    mapping.map(loadUse, extracted);
  }

  // 4. prefetch the next iteration
  // Result-less ops (e.g. the scf.yield terminator) cannot be pipelined
  // loads, so check getNumResults() before asking for result #0.
  SmallVector<Operation *> orderedDeps;
  for (Operation &op : forOp.getLoopBody().front()) {
    if (depOps.contains(&op))
      orderedDeps.push_back(&op);
    else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
      orderedDeps.push_back(&op);
  }
  assert(depOps.size() + loads.size() == orderedDeps.size() &&
         "depOps contains invalid values");

  BlockAndValueMapping nextMapping;
  DenseMap<BlockArgument, Value> depArgsMapping;
  size_t argIdx = 0;
  for (BlockArgument arg : depArgs) {
    nextMapping.map(arg,
                    newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
    ++argIdx;
  }

  // special handling for iv & loop condition
  Value nextIV = builder.create<arith::AddIOp>(
      newForOp.getInductionVar().getLoc(),
      newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
  Value nextLoopCond =
      builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                    nextIV, newForOp.getUpperBound());
  // Ops rematerialized for the prefetched iteration must see that
  // iteration's induction variable.
  nextMapping.map(forOp.getInductionVar(), nextIV);

  // slice index
  SmallVector<Value> nextBuffers;
  SmallVector<Value> extractSlices;

  pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
  Value insertSliceIndex = builder.create<arith::RemSIOp>(
      nextIV.getLoc(), pipelineIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
  loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
  Value extractSliceIndex = builder.create<arith::RemSIOp>(
      nextIV.getLoc(), loopIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));

  for (Operation *op : orderedDeps) {
    Operation *nextOp = nullptr;
    // TODO(da): does this work if loadOp has no mask?
    // update loading mask
    if (loads.contains(op->getResult(0))) {
      auto loadOp = llvm::cast<triton::LoadOp>(op);
      Value mask = loadOp.mask();
      if (mask) {
        Value splatCond = builder.create<triton::SplatOp>(
            mask.getLoc(), mask.getType(), nextLoopCond);
        Value newMask = builder.create<arith::AndIOp>(
            mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
        // if mask is defined outside the loop, don't update the map more than
        // once
        if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
          nextMapping.map(mask, newMask);
      }
      Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
          op->getLoc(), loadsBuffer[loadOp].getType(),
          nextMapping.lookupOrDefault(loadOp.ptr()),
          newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
          insertSliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
          nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
          loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
      nextBuffers.push_back(insertAsyncOp);
      nextOp = builder.create<triton::gpu::ExtractSliceOp>(
          op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
          extractSliceIndex, /*axis*/ 0);
      extractSlices.push_back(nextOp->getResult(0));
    } else
      nextOp = builder.clone(*op, nextMapping);

    // update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
      nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
      // if this is a loop-carried value, update the mapping for yield
      for (OpOperand &operand : yieldOp->getOpOperands()) {
        if (operand.get() != op->getResult(dstIdx))
          continue;
        BlockArgument originArg =
            forOp.getRegionIterArgs()[operand.getOperandNumber()];
        // Only depArgs have a dedicated prefetch iter arg; look them up
        // without inserting entries for other iter args.
        auto depIt = depArgsIdx.find(originArg);
        if (depIt == depArgsIdx.end())
          continue;
        BlockArgument newArg = newForOp.getRegionIterArgs()[depIt->second];
        depArgsMapping[newArg] = nextOp->getResult(dstIdx);
      }
    }
  }

  // async.wait & extract_slice
  Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
      loads[0].getLoc(), loads.size() * (numStages - 2));
  for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
    // move extract_slice after asyncWait
    it->getDefiningOp()->moveAfter(asyncWait);
  }

  // bump iteration count
  pipelineIterIdx = builder.create<arith::AddIOp>(
      nextIV.getLoc(), pipelineIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
  loopIterIdx = builder.create<arith::AddIOp>(
      nextIV.getLoc(), loopIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));

  // Finally, the YieldOp, need to sync with the order of newLoopArgs.
  // Values defined outside the loop are not in `mapping` and are yielded
  // as-is (they dominate the new loop).
  SmallVector<Value> yieldValues;
  for (Value v : forOp.getBody()->getTerminator()->getOperands())
    yieldValues.push_back(mapping.lookupOrDefault(v));
  for (Value nextBuffer : nextBuffers)
    yieldValues.push_back(nextBuffer);
  for (Value nextSlice : extractSlices)
    yieldValues.push_back(nextSlice);

  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
    yieldValues.push_back(
        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
  yieldValues.push_back(nextIV);
  yieldValues.push_back(pipelineIterIdx);
  yieldValues.push_back(loopIterIdx);

  builder.setInsertionPointToEnd(newForOp.getBody());
  builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
                               yieldValues);
  return newForOp;
}
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
/// Pass driver: software-pipelines every scf.for nested in the operation
/// that LoopPipeliner accepts, replacing the original loop.
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
  PipelinePass() = default;
  PipelinePass(int numStages) { this->numStages = numStages; }

  void runOnOperation() override {
    const int stages = this->numStages;
    // With at most one stage there is nothing to overlap.
    if (stages <= 1)
      return;

    // walk() defaults to post-order, where erasing the visited op from the
    // callback is permitted.
    getOperation()->walk([&](scf::ForOp loop) -> void {
      LoopPipeliner pipeliner(loop, stages);
      if (pipeliner.initialize().failed())
        return;

      pipeliner.emitPrologue();
      scf::ForOp pipelined = pipeliner.createNewForOp();

      // The new loop appends extra iter_args after the original ones, so
      // only its leading results stand in for the old loop's results.
      const unsigned numOldResults = loop->getNumResults();
      for (unsigned resultIdx = 0; resultIdx != numOldResults; ++resultIdx)
        loop->getResult(resultIdx).replaceAllUsesWith(
            pipelined->getResult(resultIdx));
      loop->erase();
    });
  }
};
} // anonymous namespace
/// Create the TritonGPU software-pipelining pass using `numStages` stages.
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
  std::unique_ptr<Pass> pass = std::make_unique<PipelinePass>(numStages);
  return pass;
}