@@ -434,7 +434,7 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
434434 rewriter.inlineRegionBefore (targetOp.getRegion (), newTargetOp.getRegion (),
435435 newTargetOp.getRegion ().begin ());
436436
437- rewriter.replaceOp (targetOp, newTargetOp );
437+ rewriter.replaceOp (targetOp, targetDataOp );
438438 return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
439439}
440440
@@ -807,11 +807,30 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
807807 rewriter.eraseOp (tmpCall);
808808 } else {
809809 Operation *clonedOp = rewriter.clone (*op, mapping);
810- if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
811- opsToReplace.push_back (clonedOp);
812810 for (unsigned i = 0 ; i < op->getNumResults (); ++i) {
813811 mapping.map (op->getResult (i), clonedOp->getResult (i));
814812 }
813+ // fir.declare changes its type when hoisting it out of omp.target to
814+ // omp.target_data Introduce a load, if original declareOp input is not of
815+ // reference type, but cloned delcareOp input is reference type.
816+ if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
817+ auto originalDeclareOp = cast<fir::DeclareOp>(op);
818+ Type originalInType = originalDeclareOp.getMemref ().getType ();
819+ Type clonedInType = clonedDeclareOp.getMemref ().getType ();
820+
821+ fir::ReferenceType originalRefType =
822+ dyn_cast<fir::ReferenceType>(originalInType);
823+ fir::ReferenceType clonedRefType =
824+ dyn_cast<fir::ReferenceType>(clonedInType);
825+ if (!originalRefType && clonedRefType) {
826+ Type clonedEleTy = clonedRefType.getElementType ();
827+ if (clonedEleTy == originalDeclareOp.getType ()) {
828+ opsToReplace.push_back (clonedOp);
829+ }
830+ }
831+ }
832+ if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
833+ opsToReplace.push_back (clonedOp);
815834 }
816835 }
817836 for (Operation *op : opsToReplace) {
@@ -833,6 +852,15 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
833852 rewriter.create <omp::TargetFreeMemOp>(freeOp.getLoc (), device,
834853 firConvertOp.getResult ());
835854 rewriter.eraseOp (freeOp);
855+ } else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
856+ Type clonedInType = clonedDeclareOp.getMemref ().getType ();
857+ fir::ReferenceType clonedRefType =
858+ dyn_cast<fir::ReferenceType>(clonedInType);
859+ Type clonedEleTy = clonedRefType.getElementType ();
860+ rewriter.setInsertionPoint (op);
861+ Value loadedValue = rewriter.create <fir::LoadOp>(
862+ clonedDeclareOp.getLoc (), clonedEleTy, clonedDeclareOp.getMemref ());
863+ clonedDeclareOp.getResult ().replaceAllUsesWith (loadedValue);
836864 }
837865 }
838866 rewriter.eraseOp (targetOp);
0 commit comments