@@ -56,7 +56,10 @@ class GenericLoopConversionPattern
5656 " not yet implemented: Combined `parallel loop` directive" );
5757 break ;
5858 case GenericLoopCombinedInfo::TeamsLoop:
59- rewriteToDistributeParallelDo (loopOp, rewriter);
59+ if (teamsLoopCanBeParallelFor (loopOp))
60+ rewriteToDistributeParallelDo (loopOp, rewriter);
61+ else
62+ rewriteToDistrbute (loopOp, rewriter);
6063 break ;
6164 }
6265
@@ -97,8 +100,6 @@ class GenericLoopConversionPattern
97100 if (!loopOp.getReductionVars ().empty ())
98101 return todo (" reduction" );
99102
100- // TODO For `teams loop`, check similar constrains to what is checked
101- // by `TeamsLoopChecker` in SemaOpenMP.cpp.
102103 return mlir::success ();
103104 }
104105
@@ -118,6 +119,62 @@ class GenericLoopConversionPattern
118119 return result;
119120 }
120121
122+ // / Checks whether a `teams loop` construct can be rewriten to `teams
123+ // / distribute parallel do` or it has to be converted to `teams distribute`.
124+ // /
125+ // / This checks similar constrains to what is checked by `TeamsLoopChecker` in
126+ // / SemaOpenMP.cpp in clang.
127+ static bool teamsLoopCanBeParallelFor (mlir::omp::LoopOp loopOp) {
128+ bool canBeParallelFor =
129+ !loopOp
130+ .walk <mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
131+ if (nestedOp == loopOp)
132+ return mlir::WalkResult::advance ();
133+
134+ if (auto nestedLoopOp =
135+ mlir::dyn_cast<mlir::omp::LoopOp>(nestedOp)) {
136+ GenericLoopCombinedInfo combinedInfo =
137+ findGenericLoopCombineInfo (nestedLoopOp);
138+
139+ // Worksharing loops cannot be nested inside each other.
140+ // Therefore, if the current `loop` directive nests another
141+ // `loop` whose `bind` modifier is `parallel`, this `loop`
142+ // directive cannot be mapped to `distribute parallel for`
143+ // but rather only to `distribute`.
144+ if (combinedInfo == GenericLoopCombinedInfo::Standalone &&
145+ nestedLoopOp.getBindKind () &&
146+ *nestedLoopOp.getBindKind () ==
147+ mlir::omp::ClauseBindKind::Parallel)
148+ return mlir::WalkResult::interrupt ();
149+
150+ // TODO check for combined `parallel loop` when we support
151+ // it.
152+ } else if (auto callOp =
153+ mlir::dyn_cast<mlir::CallOpInterface>(nestedOp)) {
154+ // Calls to non-OpenMP API runtime functions inhibits
155+ // transformation to `teams distribute parallel do` since the
156+ // called functions might have nested parallelism themselves.
157+ bool isOpenMPAPI = false ;
158+ mlir::CallInterfaceCallable callable =
159+ callOp.getCallableForCallee ();
160+
161+ if (auto callableSymRef =
162+ mlir::dyn_cast<mlir::SymbolRefAttr>(callable))
163+ isOpenMPAPI =
164+ callableSymRef.getRootReference ().strref ().starts_with (
165+ " omp_" );
166+
167+ if (!isOpenMPAPI)
168+ return mlir::WalkResult::interrupt ();
169+ }
170+
171+ return mlir::WalkResult::advance ();
172+ })
173+ .wasInterrupted ();
174+
175+ return canBeParallelFor;
176+ }
177+
121178 void rewriteStandaloneLoop (mlir::omp::LoopOp loopOp,
122179 mlir::ConversionPatternRewriter &rewriter) const {
123180 using namespace mlir ::omp;
0 commit comments