@@ -566,22 +566,60 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
566566 toRecompute);
567567}
568568
569- static void reloadCacheAndRecompute (Location loc, RewriterBase &rewriter,
570- MLIRContext &ctx, IRMapping &mapping,
571- Operation *splitBefore, Block *targetBlock,
572- Block *newTargetBlock,
573- SmallVector<Value> &allocs,
574- SetVector<Operation *> &toRecompute) {
575- for (unsigned i = 0 ; i < targetBlock->getNumArguments (); i++) {
576- auto originalArg = targetBlock->getArgument (i);
// Adds block arguments to `newTargetBlock` and records, in `irMapping`, which
// original value each new argument stands for. Arguments are appended in this
// order:
//   1. one per entry of `mapOperands`:
//        - indices below targetOp.getMapVars().size() mirror the argument at
//          the same index of `targetBlock` (same type and location);
//        - the remaining indices correspond to entries of `allocs`, and the
//          new argument's type is getPtrTypeForOmp(<alloc value type>);
//   2. one per entry of targetOp.getPrivateVars(), mirroring the argument of
//      `targetBlock` at index (map vars size + i).
// `loc` and `rewriter` are not used inside this function.
// NOTE(review): `allocs[i - originalMapVarsSize]` is not bounds-checked here;
// this presumably relies on mapOperands.size() <= map vars + allocs.size() —
// confirm against the callers that build `mapOperands`.
569+ static void createBlockArgsAndMap (Location loc, RewriterBase &rewriter,
570+ omp::TargetOp &targetOp, Block *targetBlock,
571+ Block *newTargetBlock,
572+ SmallVector<Value> &mapOperands,
573+ SmallVector<Value> &allocs,
574+ IRMapping &irMapping) {
575+ // Map `map_operands` to block arguments.
576+ unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
577+ for (unsigned i = 0 ; i < mapOperands.size (); ++i) {
578+ Value originalValue;
579+ BlockArgument newArg;
580+ // Map the new arguments from the original block.
581+ if (i < originalMapVarsSize) {
582+ originalValue = targetBlock->getArgument (i);
583+ newArg = newTargetBlock->addArgument (originalValue.getType (),
584+ originalValue.getLoc ());
585+ }
586+ // Map the new arguments from the `allocs`.
587+ else {
// Entries past the original map vars are taken from `allocs`, offset by the
// number of original map vars; the argument uses the getPtrTypeForOmp type
// rather than the alloc value's own type.
588+ originalValue = allocs[i - originalMapVarsSize];
589+ newArg = newTargetBlock->addArgument (
590+ getPtrTypeForOmp (originalValue.getType ()), originalValue.getLoc ());
591+ }
// In both branches the key is the original value (block argument or alloc)
// and the mapped value is the freshly added argument.
592+ irMapping.map (originalValue, newArg);
593+ }
594+ // Map `private_vars` to block arguments.
595+ unsigned originalPrivateVarsSize = targetOp.getPrivateVars ().size ();
596+ for (unsigned i = 0 ; i < originalPrivateVarsSize; ++i) {
// Private-var arguments are read from `targetBlock` starting right after the
// map-var arguments, but are appended to `newTargetBlock` after the
// alloc-backed arguments added by the loop above.
597+ auto originalArg = targetBlock->getArgument (originalMapVarsSize + i);
577598 auto newArg = newTargetBlock->addArgument (originalArg.getType (),
578599 originalArg.getLoc ());
579- mapping .map (originalArg, newArg);
600+ irMapping .map (originalArg, newArg);
580601 }
// NOTE(review): the `-` lines below appear to be pre-patch code removed by this
// diff (they reference `ctx`, which is not a parameter here); they are not part
// of createBlockArgsAndMap.
581- auto llvmPtrTy = LLVM::LLVMPointerType::get (&ctx);
582- for (auto original : allocs) {
583- Value newArg = newTargetBlock->addArgument (
584- getPtrTypeForOmp (original.getType ()), original.getLoc ());
// Explicit return at the end of a void function; it has no effect.
602+ return ;
603+ }
604+
605+ static void reloadCacheAndRecompute (
606+ Location loc, RewriterBase &rewriter, Operation *splitBefore,
607+ omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
608+ SmallVector<Value> &mapOperands, SmallVector<Value> &allocs,
609+ SetVector<Operation *> &toRecompute, IRMapping &irMapping) {
610+ createBlockArgsAndMap (loc, rewriter, targetOp, targetBlock, newTargetBlock,
611+ mapOperands, allocs, irMapping);
612+ // Handle the load operations for the allocs.
613+ rewriter.setInsertionPointToStart (newTargetBlock);
614+ auto llvmPtrTy = LLVM::LLVMPointerType::get (targetOp.getContext ());
615+
616+ unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
617+ // Create loads for allocs.
618+ for (unsigned i = 0 ; i < allocs.size (); ++i) {
619+ Value original = allocs[i];
620+ // Get the new block argument for this specific allocated value.
621+ Value newArg = newTargetBlock->getArgument (originalMapVarsSize + i);
622+
585623 Value restored;
586624 if (isPtr (original.getType ())) {
587625 restored = rewriter.create <LLVM::LoadOp>(loc, llvmPtrTy, newArg);
@@ -591,18 +629,18 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
591629 } else {
592630 restored = rewriter.create <fir::LoadOp>(loc, newArg);
593631 }
594- mapping .map (original, restored);
632+ irMapping .map (original, restored);
595633 }
634+
596635 for (auto it = targetBlock->begin (); it != splitBefore->getIterator (); it++) {
597636 if (toRecompute.contains (&*it))
598- rewriter.clone (*it, mapping );
637+ rewriter.clone (*it, irMapping );
599638 }
600639}
601640
602641static SplitResult isolateOp (Operation *splitBeforeOp, bool splitAfter,
603642 RewriterBase &rewriter) {
604643 auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp ());
605- MLIRContext &ctx = *targetOp.getContext ();
606644 assert (targetOp);
607645 auto loc = targetOp.getLoc ();
608646 auto *targetBlock = &targetOp.getRegion ().front ();
@@ -657,22 +695,29 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
657695 auto *preTargetBlock = rewriter.createBlock (
658696 &preTargetOp.getRegion (), preTargetOp.getRegion ().begin (), {}, {});
659697 IRMapping preMapping;
660- for (unsigned i = 0 ; i < targetBlock->getNumArguments (); i++) {
661- auto originalArg = targetBlock->getArgument (i);
662- auto newArg = preTargetBlock->addArgument (originalArg.getType (),
663- originalArg.getLoc ());
664- preMapping.map (originalArg, newArg);
665- }
666- for (auto it = targetBlock->begin (); it != splitBeforeOp->getIterator (); it++)
667- rewriter.clone (*it, preMapping);
668698
699+ createBlockArgsAndMap (loc, rewriter, targetOp, targetBlock, preTargetBlock,
700+ preMapOperands, allocs, preMapping);
701+
702+ // Handle the store operations for the allocs.
703+ rewriter.setInsertionPointToStart (preTargetBlock);
669704 auto llvmPtrTy = LLVM::LLVMPointerType::get (targetOp.getContext ());
670705
671- for (auto original : allocs) {
672- Value toStore = preMapping.lookup (original);
673- auto newArg = preTargetBlock->addArgument (
674- getPtrTypeForOmp (original.getType ()), original.getLoc ());
675- if (isPtr (original.getType ())) {
706+ // Clone the original operations.
707+ for (auto it = targetBlock->begin (); it != splitBeforeOp->getIterator ();
708+ it++) {
709+ rewriter.clone (*it, preMapping);
710+ }
711+
712+ unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
713+ // Create Stores for allocs.
714+ for (unsigned i = 0 ; i < allocs.size (); ++i) {
715+ Value originalResult = allocs[i];
716+ Value toStore = preMapping.lookup (originalResult);
717+ // Get the new block argument for this specific allocated value.
718+ Value newArg = preTargetBlock->getArgument (originalMapVarsSize + i);
719+
720+ if (isPtr (originalResult.getType ())) {
676721 if (!isa<LLVM::LLVMPointerType>(toStore.getType ()))
677722 toStore = rewriter.create <fir::ConvertOp>(loc, llvmPtrTy, toStore);
678723 rewriter.create <LLVM::StoreOp>(loc, toStore, newArg);
@@ -701,9 +746,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
701746 isolatedTargetOp.getRegion ().begin (), {}, {});
702747
703748 IRMapping isolatedMapping;
704- reloadCacheAndRecompute (loc, rewriter, ctx, isolatedMapping, splitBeforeOp ,
705- targetBlock, isolatedTargetBlock , allocs,
706- toRecompute);
749+ reloadCacheAndRecompute (loc, rewriter, splitBeforeOp, targetOp, targetBlock ,
750+ isolatedTargetBlock, postMapOperands , allocs,
751+ toRecompute, isolatedMapping );
707752 rewriter.clone (*splitBeforeOp, isolatedMapping);
708753 rewriter.create <omp::TerminatorOp>(loc);
709754
@@ -725,8 +770,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
725770 auto *postTargetBlock = rewriter.createBlock (
726771 &postTargetOp.getRegion (), postTargetOp.getRegion ().begin (), {}, {});
727772 IRMapping postMapping;
728- reloadCacheAndRecompute (loc, rewriter, ctx, postMapping, splitBeforeOp,
729- targetBlock, postTargetBlock, allocs, toRecompute);
773+ reloadCacheAndRecompute (loc, rewriter, splitBeforeOp, targetOp, targetBlock,
774+ postTargetBlock, postMapOperands, allocs,
775+ toRecompute, postMapping);
730776
731777 assert (splitBeforeOp->getNumResults () == 0 ||
732778 llvm::all_of (splitBeforeOp->getResults (),
@@ -755,15 +801,24 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
755801 Block *targetBlock = &targetOp.getRegion ().front ();
756802 assert (targetBlock == &targetOp.getRegion ().back ());
757803 IRMapping mapping;
758- for (auto map :
759- zip_equal (targetOp.getMapVars (), targetBlock->getArguments ())) {
760- Value mapInfo = std::get<0 >(map);
761- BlockArgument arg = std::get<1 >(map);
804+ for (unsigned i = 0 ; i < targetOp.getMapVars ().size (); ++i) {
805+ Value mapInfo = targetOp.getMapVars ()[i];
806+ BlockArgument arg = targetBlock->getArguments ()[i];
762807 Operation *op = mapInfo.getDefiningOp ();
763808 assert (op);
764809 auto mapInfoOp = cast<omp::MapInfoOp>(op);
810+ // map the block argument to the host-side variable pointer
765811 mapping.map (arg, mapInfoOp.getVarPtr ());
766812 }
813+ unsigned mapSize = targetOp.getMapVars ().size ();
814+ for (unsigned i = 0 ; i < targetOp.getPrivateVars ().size (); ++i) {
815+ Value privateVar = targetOp.getPrivateVars ()[i];
816+ // The mapping should link the device-side variable to the host-side one.
817+ BlockArgument arg = targetBlock->getArguments ()[mapSize + i];
818+ // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
819+ mapping.map (arg, privateVar);
820+ }
821+
767822 rewriter.setInsertionPoint (targetOp);
768823 SmallVector<Operation *> opsToReplace;
769824 Value device = targetOp.getDevice ();
@@ -813,6 +868,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
813868 // fir.declare changes its type when hoisting it out of omp.target to
814869 // omp.target_data. Introduce a load if the original declareOp input is not of
815870 // reference type, but the cloned declareOp input is of reference type.
871+
816872 if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
817873 auto originalDeclareOp = cast<fir::DeclareOp>(op);
818874 Type originalInType = originalDeclareOp.getMemref ().getType ();
@@ -833,6 +889,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
833889 opsToReplace.push_back (clonedOp);
834890 }
835891 }
892+
836893 for (Operation *op : opsToReplace) {
837894 if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
838895 rewriter.setInsertionPoint (allocOp);
0 commit comments