Commit a930387

committed
Add comments/description for functions.
1 parent 8c3785a commit a930387

File tree

1 file changed (+121, -76 lines changed)


flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 121 additions & 76 deletions
@@ -1,4 +1,5 @@
-//===- LowerWorkshare.cpp - special cases for bufferization -------===//
+//===- LowerWorkdistribute.cpp
+//-------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,6 +9,16 @@
 //
 // This file implements the lowering and optimisations of omp.workdistribute.
 //
+// Fortran array statements are lowered to FIR as fir.do_loop unordered.
+// The lower-workdistribute pass mainly identifies a fir.do_loop unordered
+// nested in target{teams{workdistribute{fir.do_loop unordered}}} and
+// lowers it to target{teams{parallel{wsloop{loop_nest}}}}.
+// It hoists all other ops outside the target region.
+// It replaces heap allocation on the target with omp.target_allocmem and
+// deallocation with omp.target_freemem from the host. It also replaces the
+// runtime function "Assign" with an equivalent omp function; e.g. @_FortranAAssign
+// on the target, once hoisted outside the target, is replaced with @_FortranAAssign_omp.
+//
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -49,6 +60,8 @@ using namespace mlir;
 
 namespace {
 
+// The isRuntimeCall function is a utility designed to determine
+// if a given operation is a call to a Fortran-specific runtime function.
 static bool isRuntimeCall(Operation *op) {
   if (auto callOp = dyn_cast<fir::CallOp>(op)) {
     auto callee = callOp.getCallee();
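
For context, a minimal self-contained sketch of the kind of check this comment describes, assuming the convention that Fortran runtime symbols start with the "_Fortran" prefix (the helper name looksLikeFortranRuntimeCall is hypothetical and not part of this commit):

static bool looksLikeFortranRuntimeCall(mlir::Operation *op) {
  if (auto callOp = mlir::dyn_cast<fir::CallOp>(op))
    if (auto callee = callOp.getCallee())
      // Fortran runtime entry points are mangled with a common prefix.
      return callee->getRootReference().getValue().starts_with("_Fortran");
  return false;
}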
@@ -61,8 +74,8 @@ static bool isRuntimeCall(Operation *op) {
   return false;
 }
 
-/// This is the single source of truth about whether we should parallelize an
-/// operation nested in an omp.execute region.
+// This is the single source of truth about whether we should parallelize an
+// operation nested in an omp.execute region.
 static bool shouldParallelize(Operation *op) {
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
@@ -74,13 +87,16 @@ static bool shouldParallelize(Operation *op) {
       return false;
     return *unordered;
   }
-  if (isRuntimeCall(op)) {
+  if (isRuntimeCall(op) &&
+      (op->getName().getStringRef() == "_FortranAAssign")) {
     return true;
   }
-  // We cannot parallise anything else
+  // We cannot parallelise anything else.
   return false;
 }
 
+// The getPerfectlyNested function is a generic utility for finding
+// a single, "perfectly nested" operation within a parent operation.
 template <typename T>
 static T getPerfectlyNested(Operation *op) {
   if (op->getNumRegions() != 1)
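
A minimal sketch of the usual shape of such a helper, assuming "perfectly nested" means the parent has exactly one region with one block whose only operations are the candidate op and the block terminator (the name getPerfectlyNestedSketch is hypothetical; the function in this file is the authoritative version):

template <typename T>
static T getPerfectlyNestedSketch(mlir::Operation *op) {
  if (op->getNumRegions() != 1)
    return nullptr;
  mlir::Region &region = op->getRegion(0);
  if (!region.hasOneBlock())
    return nullptr;
  mlir::Block &block = region.front();
  // Exactly two ops are expected: the nested candidate and the terminator.
  if (std::distance(block.begin(), block.end()) != 2)
    return nullptr;
  return mlir::dyn_cast<T>(&block.front());
}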
@@ -96,33 +112,37 @@ static T getPerfectlyNested(Operation *op) {
   return nullptr;
 }
 
-/// If B() and D() are parallelizable,
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///     C()
-///     D()
-///     E()
-///   }
-/// }
-///
-/// becomes
-///
-/// A()
-/// omp.teams {
-///   omp.workdistribute {
-///     B()
-///   }
-/// }
-/// C()
-/// omp.teams {
-///   omp.workdistribute {
-///     D()
-///   }
-/// }
-/// E()
+// FissionWorkdistribute finds the parallelizable ops within a
+// teams {workdistribute} region and moves each of them into its
+// own teams {workdistribute} region.
+//
+// If B() and D() are parallelizable,
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//     C()
+//     D()
+//     E()
+//   }
+// }
+//
+// becomes
+//
+// A()
+// omp.teams {
+//   omp.workdistribute {
+//     B()
+//   }
+// }
+// C()
+// omp.teams {
+//   omp.workdistribute {
+//     D()
+//   }
+// }
+// E()
 
 static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
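
A hypothetical sketch of the selection step this routine performs, assuming the workdistribute region has a single block and using the shouldParallelize predicate defined above (the cloning of the surrounding teams op and the actual splitting are omitted):

  // Collect the ops that must each end up in their own teams{workdistribute}
  // nest; the remaining ops are hoisted in front of those nests.
  llvm::SmallVector<mlir::Operation *> splitPoints;
  mlir::Block &body = workdistribute.getRegion().front();
  for (mlir::Operation &op : body.without_terminator())
    if (shouldParallelize(&op))
      splitPoints.push_back(&op);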
@@ -215,29 +235,6 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   return changed;
 }
 
-/// If fir.do_loop is present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     fir.do_loop unoredered {
-///       ...
-///     }
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// omp.teams {
-///   omp.parallel {
-///     omp.distribute {
-///       omp.wsloop {
-///         omp.loop_nest
-///         ...
-///       }
-///     }
-///   }
-/// }
-
 static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
   parallelOp.setComposite(composite);
@@ -295,6 +292,33 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
   return;
 }
 
+// WorkdistributeDoLower finds an unordered fir.do_loop perfectly
+// nested in teams {workdistribute {fir.do_loop unordered}} and
+// lowers it to teams {parallel {distribute {wsloop {loop_nest}}}}.
+//
+// If a fir.do_loop is present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     fir.do_loop unordered {
+//       ...
+//     }
+//   }
+// }
+//
+// then it is lowered to
+//
+// omp.teams {
+//   omp.parallel {
+//     omp.distribute {
+//       omp.wsloop {
+//         omp.loop_nest
+//         ...
+//       }
+//     }
+//   }
+// }
+
 static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
   auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
@@ -312,20 +336,23 @@ static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   return false;
 }
 
-/// If A() and B () are present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// A()
-/// B()
-///
+// TeamsWorkdistributeToSingleOp hoists all the ops inside
+// teams {workdistribute {}} to just before the teams op.
+//
+// If A() and B() are present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//   }
+// }
+//
+// then it is lowered to
+//
+// A()
+// B()
+//
 
 static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
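
A hypothetical sketch of the hoist described above, assuming the workdistribute body is a single block (the real routine also has to deal with the terminator and with ops that produce results):

  Block *wdBlock = &workdistributeOp.getRegion().front();
  for (Operation &op : llvm::make_early_inc_range(wdBlock->without_terminator()))
    op.moveBefore(teamsOp); // hoist each op to just before the omp.teams op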
@@ -358,11 +385,11 @@ struct SplitTargetResult {
   omp::TargetDataOp dataOp;
 };
 
-/// If multiple workdistribute are nested in a target regions, we will need to
-/// split the target region, but we want to preserve the data semantics of the
-/// original data region and avoid unnecessary data movement at each of the
-/// subkernels - we split the target region into a target_data{target}
-/// nest where only the outer one moves the data
+// If multiple workdistribute ops are nested in a target region, we will need to
+// split the target region, but we want to preserve the data semantics of the
+// original data region and avoid unnecessary data movement at each of the
+// subkernels - we split the target region into a target_data{target}
+// nest where only the outer one moves the data.
 std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
                                                  RewriterBase &rewriter) {
   auto loc = targetOp->getLoc();
@@ -438,6 +465,10 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
 }
 
+// The getNestedOpToIsolate function identifies a specific teams/parallel op
+// within the body of an omp::TargetOp that should be "isolated". It returns
+// a tuple of that op plus two flags indicating whether the op is the first
+// op and/or the last op in the target block.
 static std::optional<std::tuple<Operation *, bool, bool>>
 getNestedOpToIsolate(omp::TargetOp targetOp) {
   if (targetOp.getRegion().empty())
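
A hypothetical sketch of such a scan, assuming the ops worth isolating are omp.teams ops found directly in the target block (the real predicate may be broader):

  Block &body = targetOp.getRegion().front();
  for (Operation &op : body.without_terminator()) {
    if (isa<omp::TeamsOp>(op)) {
      bool isFirst = &op == &*body.begin();
      bool isLast = op.getNextNode() == body.getTerminator();
      return std::make_tuple(&op, isFirst, isLast);
    }
  }
  return std::nullopt;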
@@ -638,6 +669,15 @@ static void reloadCacheAndRecompute(
   }
 }
 
+// The isolateOp method rewrites an omp.target_data { omp.target } nest into
+// omp.target_data {
+//   // preTargetOp region contains the ops before splitBeforeOp.
+//   omp.target {}
+//   // isolatedTargetOp region contains splitBeforeOp.
+//   omp.target {}
+//   // postTargetOp region contains the ops after splitBeforeOp.
+//   omp.target {}
+// }
 static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                              RewriterBase &rewriter) {
   auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
@@ -796,6 +836,10 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
 
 static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }
 
+// The moveToHost method clones all the ops from the target region to outside of it.
+// It hoists runtime functions and replaces them with their omp versions.
+// It also hoists and replaces fir.allocmem with omp.target_allocmem and
+// fir.freemem with omp.target_freemem.
 static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
   Block *targetBlock = &targetOp.getRegion().front();
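
A hypothetical sketch of the cloning loop this comment refers to: block arguments of the target region are mapped to host-side values first, and each op is then cloned to the host through that mapping (the real routine additionally rewrites fir.declare, allocation/deallocation, and runtime calls, as described below):

  mlir::IRMapping mapping;
  // ... map each target block argument to the matching host-side value ...
  rewriter.setInsertionPoint(targetOp); // host-side insertion point is assumed
  for (mlir::Operation &op : targetBlock->without_terminator())
    rewriter.clone(op, mapping);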
@@ -815,7 +859,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     Value privateVar = targetOp.getPrivateVars()[i];
     // The mapping should link the device-side variable to the host-side one.
     BlockArgument arg = targetBlock->getArguments()[mapSize + i];
-    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    // Map the device-side copy (arg) to the host-side value (privateVar).
     mapping.map(arg, privateVar);
   }
 
@@ -868,7 +912,6 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     // fir.declare changes its type when hoisting it out of omp.target to
     // omp.target_data Introduce a load, if original declareOp input is not of
     // reference type, but cloned delcareOp input is reference type.
-
     if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
       auto originalDeclareOp = cast<fir::DeclareOp>(op);
       Type originalInType = originalDeclareOp.getMemref().getType();
@@ -890,6 +933,8 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     }
   }
 
+  // Replace fir.allocmem with omp.target_allocmem,
+  // fir.freemem with omp.target_freemem.
   for (Operation *op : opsToReplace) {
     if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
       rewriter.setInsertionPoint(allocOp);
