Skip to content

Commit 5585047

Browse files
committed
Fix hoisting declare ops out of omp.target
1 parent b29f28e commit 5585047

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
434434
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
435435
newTargetOp.getRegion().begin());
436436

437-
rewriter.replaceOp(targetOp, newTargetOp);
437+
rewriter.replaceOp(targetOp, targetDataOp);
438438
return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
439439
}
440440

@@ -807,11 +807,30 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
807807
rewriter.eraseOp(tmpCall);
808808
} else {
809809
Operation *clonedOp = rewriter.clone(*op, mapping);
810-
if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
811-
opsToReplace.push_back(clonedOp);
812810
for (unsigned i = 0; i < op->getNumResults(); ++i) {
813811
mapping.map(op->getResult(i), clonedOp->getResult(i));
814812
}
813+
// fir.declare changes its type when hoisting it out of omp.target to
814+
// omp.target_data Introduce a load, if original declareOp input is not of
815+
// reference type, but cloned delcareOp input is reference type.
816+
if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
817+
auto originalDeclareOp = cast<fir::DeclareOp>(op);
818+
Type originalInType = originalDeclareOp.getMemref().getType();
819+
Type clonedInType = clonedDeclareOp.getMemref().getType();
820+
821+
fir::ReferenceType originalRefType =
822+
dyn_cast<fir::ReferenceType>(originalInType);
823+
fir::ReferenceType clonedRefType =
824+
dyn_cast<fir::ReferenceType>(clonedInType);
825+
if (!originalRefType && clonedRefType) {
826+
Type clonedEleTy = clonedRefType.getElementType();
827+
if (clonedEleTy == originalDeclareOp.getType()) {
828+
opsToReplace.push_back(clonedOp);
829+
}
830+
}
831+
}
832+
if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
833+
opsToReplace.push_back(clonedOp);
815834
}
816835
}
817836
for (Operation *op : opsToReplace) {
@@ -833,6 +852,15 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
833852
rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
834853
firConvertOp.getResult());
835854
rewriter.eraseOp(freeOp);
855+
} else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
856+
Type clonedInType = clonedDeclareOp.getMemref().getType();
857+
fir::ReferenceType clonedRefType =
858+
dyn_cast<fir::ReferenceType>(clonedInType);
859+
Type clonedEleTy = clonedRefType.getElementType();
860+
rewriter.setInsertionPoint(op);
861+
Value loadedValue = rewriter.create<fir::LoadOp>(
862+
clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
863+
clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
836864
}
837865
}
838866
rewriter.eraseOp(targetOp);

0 commit comments

Comments
 (0)