@@ -434,7 +434,7 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
434
434
rewriter.inlineRegionBefore (targetOp.getRegion (), newTargetOp.getRegion (),
435
435
newTargetOp.getRegion ().begin ());
436
436
437
- rewriter.replaceOp (targetOp, newTargetOp );
437
+ rewriter.replaceOp (targetOp, targetDataOp );
438
438
return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
439
439
}
440
440
@@ -807,11 +807,30 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
807
807
rewriter.eraseOp (tmpCall);
808
808
} else {
809
809
Operation *clonedOp = rewriter.clone (*op, mapping);
810
- if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
811
- opsToReplace.push_back (clonedOp);
812
810
for (unsigned i = 0 ; i < op->getNumResults (); ++i) {
813
811
mapping.map (op->getResult (i), clonedOp->getResult (i));
814
812
}
813
+ // fir.declare changes its type when hoisting it out of omp.target to
814
+ // omp.target_data Introduce a load, if original declareOp input is not of
815
+ // reference type, but cloned delcareOp input is reference type.
816
+ if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
817
+ auto originalDeclareOp = cast<fir::DeclareOp>(op);
818
+ Type originalInType = originalDeclareOp.getMemref ().getType ();
819
+ Type clonedInType = clonedDeclareOp.getMemref ().getType ();
820
+
821
+ fir::ReferenceType originalRefType =
822
+ dyn_cast<fir::ReferenceType>(originalInType);
823
+ fir::ReferenceType clonedRefType =
824
+ dyn_cast<fir::ReferenceType>(clonedInType);
825
+ if (!originalRefType && clonedRefType) {
826
+ Type clonedEleTy = clonedRefType.getElementType ();
827
+ if (clonedEleTy == originalDeclareOp.getType ()) {
828
+ opsToReplace.push_back (clonedOp);
829
+ }
830
+ }
831
+ }
832
+ if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
833
+ opsToReplace.push_back (clonedOp);
815
834
}
816
835
}
817
836
for (Operation *op : opsToReplace) {
@@ -833,6 +852,15 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
833
852
rewriter.create <omp::TargetFreeMemOp>(freeOp.getLoc (), device,
834
853
firConvertOp.getResult ());
835
854
rewriter.eraseOp (freeOp);
855
+ } else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
856
+ Type clonedInType = clonedDeclareOp.getMemref ().getType ();
857
+ fir::ReferenceType clonedRefType =
858
+ dyn_cast<fir::ReferenceType>(clonedInType);
859
+ Type clonedEleTy = clonedRefType.getElementType ();
860
+ rewriter.setInsertionPoint (op);
861
+ Value loadedValue = rewriter.create <fir::LoadOp>(
862
+ clonedDeclareOp.getLoc (), clonedEleTy, clonedDeclareOp.getMemref ());
863
+ clonedDeclareOp.getResult ().replaceAllUsesWith (loadedValue);
836
864
}
837
865
}
838
866
rewriter.eraseOp (targetOp);
0 commit comments