- //===- LowerWorkshare.cpp - special cases for bufferization -------===//
+ //===- LowerWorkdistribute.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
// This file implements the lowering and optimisations of omp.workdistribute.
//
+ // Fortran array statements are lowered to FIR as unordered fir.do_loop ops.
+ // The lower-workdistribute pass works mainly on identifying an unordered
+ // fir.do_loop nested in target{teams{workdistribute{fir.do_loop unordered}}}
+ // and lowering it to target{teams{parallel{wsloop{loop_nest}}}}.
+ // It hoists all the other ops outside the target region.
+ // It replaces heap allocations on the target with omp.target_allocmem and
+ // deallocations with omp.target_freemem issued from the host. It also
+ // replaces the runtime function "Assign" with an equivalent omp function:
+ // e.g. @_FortranAAssign on the target, once hoisted outside the target
+ // region, is replaced with @_FortranAAssign_omp.
+ //
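+ // As an illustrative sketch (call signature simplified and operand details
+ // omitted), a runtime call that originally sat inside the target region,
+ //
+ //   fir.call @_FortranAAssign(%dst, %src, ...)
+ //
+ // is, after being hoisted onto the host, rewritten to call the omp variant:
+ //
+ //   fir.call @_FortranAAssign_omp(%dst, %src, ...)
+ //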
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -49,6 +60,8 @@ using namespace mlir;

namespace {

+ // The isRuntimeCall function is a utility that determines whether a given
+ // operation is a call to a Fortran-specific runtime function.
static bool isRuntimeCall(Operation *op) {
  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
    auto callee = callOp.getCallee();
@@ -61,8 +74,8 @@ static bool isRuntimeCall(Operation *op) {
  return false;
}

- /// This is the single source of truth about whether we should parallelize an
- /// operation nested in an omp.execute region.
+ // This is the single source of truth about whether we should parallelize an
+ // operation nested in an omp.workdistribute region.
static bool shouldParallelize(Operation *op) {
  if (llvm::any_of(op->getResults(),
                   [](OpResult v) -> bool { return !v.use_empty(); }))
@@ -74,13 +87,16 @@ static bool shouldParallelize(Operation *op) {
      return false;
    return *unordered;
  }
-   if (isRuntimeCall(op)) {
+   if (isRuntimeCall(op) &&
+       cast<fir::CallOp>(op).getCallee()->getLeafReference().getValue() ==
+           "_FortranAAssign") {
    return true;
  }
-   // We cannot parallise anything else
+   // We cannot parallelize anything else.
  return false;
}

+ // The getPerfectlyNested function is a generic utility for finding
+ // a single, "perfectly nested" operation within a parent operation.
template <typename T>
static T getPerfectlyNested(Operation *op) {
  if (op->getNumRegions() != 1)
@@ -96,33 +112,37 @@ static T getPerfectlyNested(Operation *op) {
  return nullptr;
}

- /// If B() and D() are parallelizable,
- ///
- /// omp.teams {
- ///   omp.workdistribute {
- ///     A()
- ///     B()
- ///     C()
- ///     D()
- ///     E()
- ///   }
- /// }
- ///
- /// becomes
- ///
- /// A()
- /// omp.teams {
- ///   omp.workdistribute {
- ///     B()
- ///   }
- /// }
- /// C()
- /// omp.teams {
- ///   omp.workdistribute {
- ///     D()
- ///   }
- /// }
- /// E()
+ // FissionWorkdistribute finds the parallelizable ops within a
+ // teams {workdistribute} region and moves each of them into its own
+ // teams {workdistribute} region.
+ //
+ // If B() and D() are parallelizable,
+ //
+ // omp.teams {
+ //   omp.workdistribute {
+ //     A()
+ //     B()
+ //     C()
+ //     D()
+ //     E()
+ //   }
+ // }
+ //
+ // becomes
+ //
+ // A()
+ // omp.teams {
+ //   omp.workdistribute {
+ //     B()
+ //   }
+ // }
+ // C()
+ // omp.teams {
+ //   omp.workdistribute {
+ //     D()
+ //   }
+ // }
+ // E()

static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
  OpBuilder rewriter(workdistribute);
@@ -215,29 +235,6 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
  return changed;
}

- /// If fir.do_loop is present inside teams workdistribute
- ///
- /// omp.teams {
- ///   omp.workdistribute {
- ///     fir.do_loop unoredered {
- ///       ...
- ///     }
- ///   }
- /// }
- ///
- /// Then, its lowered to
- ///
- /// omp.teams {
- ///   omp.parallel {
- ///     omp.distribute {
- ///       omp.wsloop {
- ///         omp.loop_nest
- ///         ...
- ///       }
- ///     }
- ///   }
- /// }
-
static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
  parallelOp.setComposite(composite);
@@ -295,6 +292,33 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
  return;
}

+ // WorkdistributeDoLower finds the unordered fir.do_loop nested in
+ // teams {workdistribute {fir.do_loop unordered}} and lowers it to
+ // teams {parallel {distribute {wsloop {loop_nest}}}}.
+ //
+ // If fir.do_loop is present inside teams workdistribute
+ //
+ // omp.teams {
+ //   omp.workdistribute {
+ //     fir.do_loop unordered {
+ //       ...
+ //     }
+ //   }
+ // }
+ //
+ // then it is lowered to
+ //
+ // omp.teams {
+ //   omp.parallel {
+ //     omp.distribute {
+ //       omp.wsloop {
+ //         omp.loop_nest
+ //         ...
+ //       }
+ //     }
+ //   }
+ // }
+
static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
  OpBuilder rewriter(workdistribute);
  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
@@ -312,20 +336,23 @@ static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
  return false;
}

- /// If A() and B () are present inside teams workdistribute
- ///
- /// omp.teams {
- ///   omp.workdistribute {
- ///     A()
- ///     B()
- ///   }
- /// }
- ///
- /// Then, its lowered to
- ///
- /// A()
- /// B()
- ///
+ // TeamsWorkdistributeToSingleOp hoists all the ops inside
+ // teams {workdistribute {}} to before the teams op.
+ //
+ // If A() and B() are present inside teams workdistribute
+ //
+ // omp.teams {
+ //   omp.workdistribute {
+ //     A()
+ //     B()
+ //   }
+ // }
+ //
+ // then it is lowered to
+ //
+ // A()
+ // B()
+ //

static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
@@ -358,11 +385,11 @@ struct SplitTargetResult {
  omp::TargetDataOp dataOp;
};

- /// If multiple workdistribute are nested in a target regions, we will need to
- /// split the target region, but we want to preserve the data semantics of the
- /// original data region and avoid unnecessary data movement at each of the
- /// subkernels - we split the target region into a target_data{target}
- /// nest where only the outer one moves the data
+ // If multiple workdistribute are nested in a target region, we will need to
+ // split the target region, but we want to preserve the data semantics of the
+ // original data region and avoid unnecessary data movement at each of the
+ // subkernels. We split the target region into a target_data{target}
+ // nest where only the outer one moves the data.
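+ //
+ // As an illustrative sketch (simplified, not exact IR syntax), assuming a
+ // target op with some map clauses,
+ //
+ //   omp.target map_entries(%m -> %arg0 : ...) {
+ //     ... multiple workdistribute nests ...
+ //   }
+ //
+ // is rewritten into
+ //
+ //   omp.target_data map_entries(%m -> %arg0 : ...) {
+ //     omp.target {
+ //       ... multiple workdistribute nests ...
+ //     }
+ //   }
+ //
+ // so that later splits of the inner omp.target do not each re-map the data.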
std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
                                                 RewriterBase &rewriter) {
  auto loc = targetOp->getLoc();
@@ -438,6 +465,10 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
  return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
}

+ // The getNestedOpToIsolate function identifies a specific op, such as a teams
+ // or parallel op, within the body of an omp::TargetOp that should be
+ // "isolated". It returns a tuple of that op together with two flags indicating
+ // whether the op is the first op and whether it is the last op in the target
+ // block.
static std::optional<std::tuple<Operation *, bool, bool>>
getNestedOpToIsolate(omp::TargetOp targetOp) {
  if (targetOp.getRegion().empty())
@@ -638,6 +669,15 @@ static void reloadCacheAndRecompute(
  }
}

+ // The isolateOp method rewrites an omp.target_data { omp.target } nest into
+ // omp.target_data {
+ //   // preTargetOp region contains the ops before splitBeforeOp.
+ //   omp.target {}
+ //   // isolatedTargetOp region contains splitBeforeOp.
+ //   omp.target {}
+ //   // postTargetOp region contains the ops after splitBeforeOp.
+ //   omp.target {}
+ // }
static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                             RewriterBase &rewriter) {
  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
@@ -796,6 +836,10 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {

static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }

+ // The moveToHost method clones all the ops from the target region to outside
+ // of it. It hoists runtime calls and replaces them with their omp versions.
+ // It also hoists and replaces fir.allocmem with omp.target_allocmem and
+ // fir.freemem with omp.target_freemem.
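+ //
+ // As a rough, illustrative sketch (operand lists simplified, not exact IR
+ // syntax), a heap allocation that was inside the target region,
+ //
+ //   %mem = fir.allocmem !fir.array<?xi32>, %n
+ //   ...
+ //   fir.freemem %mem
+ //
+ // is recreated on the host as
+ //
+ //   %mem = omp.target_allocmem %device, !fir.array<?xi32>, %n
+ //   ...
+ //   omp.target_freemem %device, %mem
+ //
+ // where %device stands for the device id of the enclosing target op.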
static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Block *targetBlock = &targetOp.getRegion().front();
@@ -815,7 +859,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
    Value privateVar = targetOp.getPrivateVars()[i];
    // The mapping should link the device-side variable to the host-side one.
    BlockArgument arg = targetBlock->getArguments()[mapSize + i];
-     // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+     // Map the device-side copy (arg) to the host-side value (privateVar).
    mapping.map(arg, privateVar);
  }
@@ -868,7 +912,6 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
    // fir.declare changes its type when hoisting it out of omp.target to
    // omp.target_data. Introduce a load if the original declareOp input is not
    // of reference type, but the cloned declareOp input is of reference type.
-
    if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
      auto originalDeclareOp = cast<fir::DeclareOp>(op);
      Type originalInType = originalDeclareOp.getMemref().getType();
@@ -890,6 +933,8 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
    }
  }

+   // Replace fir.allocmem with omp.target_allocmem and
+   // fir.freemem with omp.target_freemem.
  for (Operation *op : opsToReplace) {
    if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
      rewriter.setInsertionPoint(allocOp);