-//===- LowerWorkshare.cpp - special cases for bufferization -------===//
+//===- LowerWorkdistribute.cpp
+//-------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 //
 // This file implements the lowering and optimisations of omp.workdistribute.
 //
+// Fortran array statements are lowered to fir as fir.do_loop unordered.
+// The lower-workdistribute pass works mainly by identifying a fir.do_loop
+// unordered nested in target {teams {workdistribute {fir.do_loop unordered}}}
+// and lowering it to target {teams {parallel {wsloop {loop_nest}}}}.
+// It hoists all other ops outside the target region.
+// It replaces heap allocation on the target with omp.target_allocmem and
+// deallocation with omp.target_freemem from the host. It also replaces the
+// runtime function "Assign" with an equivalent omp function, e.g. a call to
+// @_FortranAAssign on the target, once hoisted outside the target region, is
+// replaced with @_FortranAAssign_omp.
+//
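+// For example (a rough sketch of the intended end-to-end rewrite; clauses
+// and operands elided):
+//
+//   omp.target {omp.teams {omp.workdistribute {fir.do_loop unordered}}}
+//
+// becomes
+//
+//   omp.target {omp.teams {omp.parallel {omp.wsloop {omp.loop_nest}}}}
+//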
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -49,6 +60,8 @@ using namespace mlir;
 
 namespace {
 
+// isRuntimeCall is a utility that determines whether a given operation is a
+// call to a Fortran runtime function.
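+// For example (a rough sketch; the operand names are illustrative), a call
+// lowered from a Fortran intrinsic assignment qualifies:
+//
+//   fir.call @_FortranAAssign(%dst, %src, %sourceFile, %sourceLine)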
 static bool isRuntimeCall(Operation *op) {
   if (auto callOp = dyn_cast<fir::CallOp>(op)) {
     auto callee = callOp.getCallee();
@@ -61,8 +74,8 @@ static bool isRuntimeCall(Operation *op) {
   return false;
 }
 
-/// This is the single source of truth about whether we should parallelize an
-/// operation nested in an omp.execute region.
+// This is the single source of truth about whether we should parallelize an
+// operation nested in an omp.execute region.
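+// For example (a rough sketch), a loop such as
+//
+//   fir.do_loop %i = %lb to %ub step %c1 unordered { ... }
+//
+// is parallelizable, while a loop without the unordered attribute, or any
+// op whose results have uses, is not.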
 static bool shouldParallelize(Operation *op) {
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
@@ -74,13 +87,16 @@ static bool shouldParallelize(Operation *op) {
       return false;
     return *unordered;
   }
-  if (isRuntimeCall(op)) {
+  if (isRuntimeCall(op) &&
+      (op->getName().getStringRef() == "_FortranAAssign")) {
     return true;
   }
-  // We cannot parallise anything else
+  // We cannot parallelize anything else.
   return false;
 }
 
+// getPerfectlyNested is a generic utility that finds a single, "perfectly
+// nested" operation of type T within a parent operation.
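+// For example, getPerfectlyNested<omp::WorkdistributeOp>(teamsOp) returns
+// the workdistribute op when it is the only op in the teams region (besides
+// the terminator), and nullptr otherwise.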
 template <typename T>
 static T getPerfectlyNested(Operation *op) {
   if (op->getNumRegions() != 1)
@@ -96,33 +112,37 @@ static T getPerfectlyNested(Operation *op) {
   return nullptr;
 }
 
-/// If B() and D() are parallelizable,
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///     C()
-///     D()
-///     E()
-///   }
-/// }
-///
-/// becomes
-///
-/// A()
-/// omp.teams {
-///   omp.workdistribute {
-///     B()
-///   }
-/// }
-/// C()
-/// omp.teams {
-///   omp.workdistribute {
-///     D()
-///   }
-/// }
-/// E()
+// FissionWorkdistribute finds the parallelizable ops within a
+// teams {workdistribute} region and moves each of them into its
+// own teams {workdistribute} region.
+//
+// If B() and D() are parallelizable,
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//     C()
+//     D()
+//     E()
+//   }
+// }
+//
+// becomes
+//
+// A()
+// omp.teams {
+//   omp.workdistribute {
+//     B()
+//   }
+// }
+// C()
+// omp.teams {
+//   omp.workdistribute {
+//     D()
+//   }
+// }
+// E()
 
 static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
@@ -215,29 +235,6 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   return changed;
 }
 
-/// If fir.do_loop is present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     fir.do_loop unoredered {
-///       ...
-///     }
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// omp.teams {
-///   omp.parallel {
-///     omp.distribute {
-///       omp.wsloop {
-///         omp.loop_nest
-///         ...
-///       }
-///     }
-///   }
-/// }
-
 static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
   parallelOp.setComposite(composite);
@@ -295,6 +292,33 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
   return;
 }
 
+// WorkdistributeDoLower finds the fir.do_loop unordered nested in
+// teams {workdistribute {fir.do_loop unordered}} and lowers it to
+// teams {parallel {distribute {wsloop {loop_nest}}}}.
+//
+// If fir.do_loop is present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     fir.do_loop unordered {
+//       ...
+//     }
+//   }
+// }
+//
+// Then, it is lowered to
+//
+// omp.teams {
+//   omp.parallel {
+//     omp.distribute {
+//       omp.wsloop {
+//         omp.loop_nest
+//         ...
+//       }
+//     }
+//   }
+// }
+
 static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
   auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
@@ -312,20 +336,23 @@ static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   return false;
 }
 
-/// If A() and B() are present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// A()
-/// B()
-///
+// TeamsWorkdistributeToSingleOp hoists all the ops inside
+// teams {workdistribute {}} to before the teams op.
+//
+// If A() and B() are present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//   }
+// }
+//
+// Then, it is lowered to
+//
+// A()
+// B()
+//
 
 static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
@@ -358,11 +385,11 @@ struct SplitTargetResult {
   omp::TargetDataOp dataOp;
 };
 
-/// If multiple workdistribute are nested in a target regions, we will need to
-/// split the target region, but we want to preserve the data semantics of the
-/// original data region and avoid unnecessary data movement at each of the
-/// subkernels - we split the target region into a target_data{target}
-/// nest where only the outer one moves the data
+// If multiple workdistribute ops are nested in a target region, we will need
+// to split the target region, but we want to preserve the data semantics of
+// the original data region and avoid unnecessary data movement at each of the
+// subkernels - we split the target region into a target_data {target} nest
+// where only the outer one moves the data.
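+// A rough sketch (map clauses elided):
+//
+//   omp.target map_entries(...) { ...multiple subkernels... }
+//
+// becomes
+//
+//   omp.target_data map_entries(...) {
+//     omp.target { ...subkernel... }
+//     omp.target { ...subkernel... }
+//   }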
 std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
                                                  RewriterBase &rewriter) {
   auto loc = targetOp->getLoc();
@@ -438,6 +465,10 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
 }
 
+// getNestedOpToIsolate identifies a specific op, such as a teams or
+// parallel op, within the body of an omp::TargetOp that should be
+// "isolated". It returns a tuple of the op and two flags indicating
+// whether the op is the first op or the last op in the target block.
 static std::optional<std::tuple<Operation *, bool, bool>>
 getNestedOpToIsolate(omp::TargetOp targetOp) {
   if (targetOp.getRegion().empty())
@@ -638,6 +669,15 @@ static void reloadCacheAndRecompute(
   }
 }
 
+// isolateOp rewrites an omp.target_data {omp.target} into
+// omp.target_data {
+//   // preTargetOp region contains the ops before splitBeforeOp.
+//   omp.target {}
+//   // isolatedTargetOp region contains splitBeforeOp.
+//   omp.target {}
+//   // postTargetOp region contains the ops after splitBeforeOp.
+//   omp.target {}
+// }
 static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                              RewriterBase &rewriter) {
   auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
@@ -796,6 +836,10 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
 
 static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }
 
+// moveToHost clones all the ops from the target region to outside of it.
+// It hoists runtime calls and replaces them with their omp versions. It
+// also hoists fir.allocmem and fir.freemem, replacing them with
+// omp.target_allocmem and omp.target_freemem respectively.
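+// A rough sketch of the allocation rewrite (types and operands are
+// illustrative, and the device operand is assumed to come from the i32
+// device helpers below):
+//
+//   %mem = fir.allocmem !fir.array<?xf32>, %n      // inside omp.target
+//   fir.freemem %mem
+//
+// becomes, on the host,
+//
+//   %mem = omp.target_allocmem %device : i32, !fir.array<?xf32>, %n
+//   omp.target_freemem %device, %mem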
 static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
   Block *targetBlock = &targetOp.getRegion().front();
@@ -815,7 +859,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     Value privateVar = targetOp.getPrivateVars()[i];
     // The mapping should link the device-side variable to the host-side one.
     BlockArgument arg = targetBlock->getArguments()[mapSize + i];
-    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    // Map the device-side copy (arg) to the host-side value (privateVar).
     mapping.map(arg, privateVar);
   }
 
@@ -868,7 +912,6 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     // fir.declare changes its type when hoisting it out of omp.target to
     // omp.target_data. Introduce a load, if the original declareOp input is
     // not of reference type, but the cloned declareOp input is reference type.
-
     if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
       auto originalDeclareOp = cast<fir::DeclareOp>(op);
       Type originalInType = originalDeclareOp.getMemref().getType();
@@ -890,6 +933,8 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     }
   }
 
+  // Replace fir.allocmem with omp.target_allocmem and
+  // fir.freemem with omp.target_freemem.
   for (Operation *op : opsToReplace) {
     if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
       rewriter.setInsertionPoint(allocOp);