@@ -762,37 +762,19 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
762762 mapping.map (arg, mapInfoOp.getVarPtr ());
763763 }
764764 rewriter.setInsertionPoint (targetOp);
765- SmallVector<Operation *> opsToMove;
765+ SmallVector<Operation *> opsToReplace;
766+ Value device = targetOp.getDevice ();
767+ /*
768+ if (!device) {
769+ device = genI32Constant(targetOp.getLoc(), rewriter, 0);
770+ }
766771 for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
767772 it != end; ++it) {
768773 auto *op = &*it;
769- auto allocOp = dyn_cast<fir::AllocMemOp>(op);
770- auto freeOp = dyn_cast<fir::FreeMemOp>(op);
771774 fir::CallOp runtimeCall = nullptr;
772775 if (isRuntimeCall(op))
773776 runtimeCall = cast<fir::CallOp>(op);
774-
775- if (allocOp || freeOp || runtimeCall) {
776- Value device = targetOp.getDevice ();
777- if (!device) {
778- device = genI32Constant (it->getLoc (), rewriter, 0 );
779- }
780- if (allocOp) {
781- auto tmpAllocOp = rewriter.create <fir::OmpTargetAllocMemOp>(
782- allocOp.getLoc (), allocOp.getType (), device,
783- allocOp.getInTypeAttr (), allocOp.getUniqNameAttr (),
784- allocOp.getBindcNameAttr (), allocOp.getTypeparams (),
785- allocOp.getShape ());
786- auto newAllocOp = cast<fir::OmpTargetAllocMemOp>(
787- rewriter.clone (*tmpAllocOp.getOperation (), mapping));
788- mapping.map (allocOp.getResult (), newAllocOp.getResult ());
789- rewriter.eraseOp (tmpAllocOp);
790- } else if (freeOp) {
791- auto tmpFreeOp = rewriter.create <fir::OmpTargetFreeMemOp>(
792- freeOp.getLoc (), device, freeOp.getHeapref ());
793- rewriter.clone (*tmpFreeOp.getOperation (), mapping);
794- rewriter.eraseOp (tmpFreeOp);
795- } else if (runtimeCall) {
777+ if (runtimeCall) {
796778 auto module = runtimeCall->getParentOfType<ModuleOp>();
797779 auto callee = cast<func::FuncOp>(
798780 module.lookupSymbol(runtimeCall.getCalleeAttr()));
@@ -824,15 +806,42 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
824806 Operation *newCall = rewriter.clone(*tmpCall, mapping);
825807 mapping.map(&*it, newCall);
826808 rewriter.eraseOp(tmpCall);
827- }
828809 } else {
829810 Operation *clonedOp = rewriter.clone(*op, mapping);
811+ auto allocOp = dyn_cast<fir::AllocMemOp>(clonedOp);
812+ auto freeOp = dyn_cast<fir::FreeMemOp>(clonedOp);
813+ if (allocOp || freeOp)
814+ opsToReplace.push_back(clonedOp);
830815 for (unsigned i = 0; i < op->getNumResults(); ++i) {
831816 mapping.map(op->getResult(i), clonedOp->getResult(i));
832817 }
833818 }
834819 }
820+ for (Operation* op : opsToReplace) {
821+ if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
822+ rewriter.setInsertionPoint(allocOp);
823+ auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
824+ allocOp.getLoc(), rewriter.getI64Type(), device,
825+ allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
826+ allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
827+ allocOp.getShape());
828+ auto firConvertOp = rewriter.create<fir::ConvertOp>(allocOp.getLoc(),
829+ allocOp.getResult().getType(), ompAllocmemOp.getResult());
830+ rewriter.replaceOp(allocOp, firConvertOp.getResult());
831+ }
832+ else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
833+ rewriter.setInsertionPoint(freeOp);
834+ auto firConvertOp = rewriter.create<fir::ConvertOp>(
835+ freeOp.getLoc(),
836+ rewriter.getI64Type(),
837+ freeOp.getHeapref());
838+ rewriter.create<omp::TargetFreeMemOp>(
839+ freeOp.getLoc(), device, firConvertOp.getResult());
840+ rewriter.eraseOp(freeOp);
841+ }
842+ }
835843 rewriter.eraseOp(targetOp);
844+ */
836845}
837846
838847void fissionTarget (omp::TargetOp targetOp, RewriterBase &rewriter) {
0 commit comments