@@ -566,22 +566,60 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
566
566
toRecompute);
567
567
}
568
568
569
- static void reloadCacheAndRecompute (Location loc, RewriterBase &rewriter,
570
- MLIRContext &ctx, IRMapping &mapping,
571
- Operation *splitBefore, Block *targetBlock,
572
- Block *newTargetBlock,
573
- SmallVector<Value> &allocs,
574
- SetVector<Operation *> &toRecompute) {
575
- for (unsigned i = 0 ; i < targetBlock->getNumArguments (); i++) {
576
- auto originalArg = targetBlock->getArgument (i);
569
+ static void createBlockArgsAndMap (Location loc, RewriterBase &rewriter,
570
+ omp::TargetOp &targetOp, Block *targetBlock,
571
+ Block *newTargetBlock,
572
+ SmallVector<Value> &mapOperands,
573
+ SmallVector<Value> &allocs,
574
+ IRMapping &irMapping) {
575
+ // Map `map_operands` to block arguments.
576
+ unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
577
+ for (unsigned i = 0 ; i < mapOperands.size (); ++i) {
578
+ Value originalValue;
579
+ BlockArgument newArg;
580
+ // Map the new arguments from the original block.
581
+ if (i < originalMapVarsSize) {
582
+ originalValue = targetBlock->getArgument (i);
583
+ newArg = newTargetBlock->addArgument (originalValue.getType (),
584
+ originalValue.getLoc ());
585
+ }
586
+ // Map the new arguments from the `allocs`.
587
+ else {
588
+ originalValue = allocs[i - originalMapVarsSize];
589
+ newArg = newTargetBlock->addArgument (
590
+ getPtrTypeForOmp (originalValue.getType ()), originalValue.getLoc ());
591
+ }
592
+ irMapping.map (originalValue, newArg);
593
+ }
594
+ // Map `private_vars` to block arguments.
595
+ unsigned originalPrivateVarsSize = targetOp.getPrivateVars ().size ();
596
+ for (unsigned i = 0 ; i < originalPrivateVarsSize; ++i) {
597
+ auto originalArg = targetBlock->getArgument (originalMapVarsSize + i);
577
598
auto newArg = newTargetBlock->addArgument (originalArg.getType (),
578
599
originalArg.getLoc ());
579
- mapping .map (originalArg, newArg);
600
+ irMapping .map (originalArg, newArg);
580
601
}
581
- auto llvmPtrTy = LLVM::LLVMPointerType::get (&ctx);
582
- for (auto original : allocs) {
583
- Value newArg = newTargetBlock->addArgument (
584
- getPtrTypeForOmp (original.getType ()), original.getLoc ());
602
+ return ;
603
+ }
604
+
605
+ static void reloadCacheAndRecompute (
606
+ Location loc, RewriterBase &rewriter, Operation *splitBefore,
607
+ omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
608
+ SmallVector<Value> &mapOperands, SmallVector<Value> &allocs,
609
+ SetVector<Operation *> &toRecompute, IRMapping &irMapping) {
610
+ createBlockArgsAndMap (loc, rewriter, targetOp, targetBlock, newTargetBlock,
611
+ mapOperands, allocs, irMapping);
612
+ // Handle the load operations for the allocs.
613
+ rewriter.setInsertionPointToStart (newTargetBlock);
614
+ auto llvmPtrTy = LLVM::LLVMPointerType::get (targetOp.getContext ());
615
+
616
+ unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
617
+ // Create loads for allocs.
618
+ for (unsigned i = 0 ; i < allocs.size (); ++i) {
619
+ Value original = allocs[i];
620
+ // Get the new block argument for this specific allocated value.
621
+ Value newArg = newTargetBlock->getArgument (originalMapVarsSize + i);
622
+
585
623
Value restored;
586
624
if (isPtr (original.getType ())) {
587
625
restored = rewriter.create <LLVM::LoadOp>(loc, llvmPtrTy, newArg);
@@ -591,18 +629,18 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
591
629
} else {
592
630
restored = rewriter.create <fir::LoadOp>(loc, newArg);
593
631
}
594
- mapping .map (original, restored);
632
+ irMapping .map (original, restored);
595
633
}
634
+
596
635
for (auto it = targetBlock->begin (); it != splitBefore->getIterator (); it++) {
597
636
if (toRecompute.contains (&*it))
598
- rewriter.clone (*it, mapping );
637
+ rewriter.clone (*it, irMapping );
599
638
}
600
639
}
601
640
602
641
static SplitResult isolateOp (Operation *splitBeforeOp, bool splitAfter,
603
642
RewriterBase &rewriter) {
604
643
auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp ());
605
- MLIRContext &ctx = *targetOp.getContext ();
606
644
assert (targetOp);
607
645
auto loc = targetOp.getLoc ();
608
646
auto *targetBlock = &targetOp.getRegion ().front ();
@@ -657,22 +695,29 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
657
695
auto *preTargetBlock = rewriter.createBlock (
658
696
&preTargetOp.getRegion (), preTargetOp.getRegion ().begin (), {}, {});
659
697
IRMapping preMapping;
660
- for (unsigned i = 0 ; i < targetBlock->getNumArguments (); i++) {
661
- auto originalArg = targetBlock->getArgument (i);
662
- auto newArg = preTargetBlock->addArgument (originalArg.getType (),
663
- originalArg.getLoc ());
664
- preMapping.map (originalArg, newArg);
665
- }
666
- for (auto it = targetBlock->begin (); it != splitBeforeOp->getIterator (); it++)
667
- rewriter.clone (*it, preMapping);
668
698
699
+ createBlockArgsAndMap (loc, rewriter, targetOp, targetBlock, preTargetBlock,
700
+ preMapOperands, allocs, preMapping);
701
+
702
+ // Handle the store operations for the allocs.
703
+ rewriter.setInsertionPointToStart (preTargetBlock);
669
704
auto llvmPtrTy = LLVM::LLVMPointerType::get (targetOp.getContext ());
670
705
671
- for (auto original : allocs) {
672
- Value toStore = preMapping.lookup (original);
673
- auto newArg = preTargetBlock->addArgument (
674
- getPtrTypeForOmp (original.getType ()), original.getLoc ());
675
- if (isPtr (original.getType ())) {
706
+ // Clone the original operations.
707
+ for (auto it = targetBlock->begin (); it != splitBeforeOp->getIterator ();
708
+ it++) {
709
+ rewriter.clone (*it, preMapping);
710
+ }
711
+
712
+ unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
713
+ // Create stores for allocs.
714
+ for (unsigned i = 0 ; i < allocs.size (); ++i) {
715
+ Value originalResult = allocs[i];
716
+ Value toStore = preMapping.lookup (originalResult);
717
+ // Get the new block argument for this specific allocated value.
718
+ Value newArg = preTargetBlock->getArgument (originalMapVarsSize + i);
719
+
720
+ if (isPtr (originalResult.getType ())) {
676
721
if (!isa<LLVM::LLVMPointerType>(toStore.getType ()))
677
722
toStore = rewriter.create <fir::ConvertOp>(loc, llvmPtrTy, toStore);
678
723
rewriter.create <LLVM::StoreOp>(loc, toStore, newArg);
@@ -701,9 +746,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
701
746
isolatedTargetOp.getRegion ().begin (), {}, {});
702
747
703
748
IRMapping isolatedMapping;
704
- reloadCacheAndRecompute (loc, rewriter, ctx, isolatedMapping, splitBeforeOp ,
705
- targetBlock, isolatedTargetBlock , allocs,
706
- toRecompute);
749
+ reloadCacheAndRecompute (loc, rewriter, splitBeforeOp, targetOp, targetBlock ,
750
+ isolatedTargetBlock, postMapOperands , allocs,
751
+ toRecompute, isolatedMapping );
707
752
rewriter.clone (*splitBeforeOp, isolatedMapping);
708
753
rewriter.create <omp::TerminatorOp>(loc);
709
754
@@ -725,8 +770,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
725
770
auto *postTargetBlock = rewriter.createBlock (
726
771
&postTargetOp.getRegion (), postTargetOp.getRegion ().begin (), {}, {});
727
772
IRMapping postMapping;
728
- reloadCacheAndRecompute (loc, rewriter, ctx, postMapping, splitBeforeOp,
729
- targetBlock, postTargetBlock, allocs, toRecompute);
773
+ reloadCacheAndRecompute (loc, rewriter, splitBeforeOp, targetOp, targetBlock,
774
+ postTargetBlock, postMapOperands, allocs,
775
+ toRecompute, postMapping);
730
776
731
777
assert (splitBeforeOp->getNumResults () == 0 ||
732
778
llvm::all_of (splitBeforeOp->getResults (),
@@ -755,15 +801,24 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
755
801
Block *targetBlock = &targetOp.getRegion ().front ();
756
802
assert (targetBlock == &targetOp.getRegion ().back ());
757
803
IRMapping mapping;
758
- for (auto map :
759
- zip_equal (targetOp.getMapVars (), targetBlock->getArguments ())) {
760
- Value mapInfo = std::get<0 >(map);
761
- BlockArgument arg = std::get<1 >(map);
804
+ for (unsigned i = 0 ; i < targetOp.getMapVars ().size (); ++i) {
805
+ Value mapInfo = targetOp.getMapVars ()[i];
806
+ BlockArgument arg = targetBlock->getArguments ()[i];
762
807
Operation *op = mapInfo.getDefiningOp ();
763
808
assert (op);
764
809
auto mapInfoOp = cast<omp::MapInfoOp>(op);
810
+ // Map the block argument to the host-side variable pointer.
765
811
mapping.map (arg, mapInfoOp.getVarPtr ());
766
812
}
813
+ unsigned mapSize = targetOp.getMapVars ().size ();
814
+ for (unsigned i = 0 ; i < targetOp.getPrivateVars ().size (); ++i) {
815
+ Value privateVar = targetOp.getPrivateVars ()[i];
816
+ // The mapping should link the device-side variable to the host-side one.
817
+ BlockArgument arg = targetBlock->getArguments ()[mapSize + i];
818
+ // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
819
+ mapping.map (arg, privateVar);
820
+ }
821
+
767
822
rewriter.setInsertionPoint (targetOp);
768
823
SmallVector<Operation *> opsToReplace;
769
824
Value device = targetOp.getDevice ();
@@ -813,6 +868,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
813
868
// fir.declare changes its type when hoisting it out of omp.target to
814
869
// omp.target_data. Introduce a load, if original declareOp input is not of
815
870
// reference type, but cloned declareOp input is reference type.
871
+
816
872
if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
817
873
auto originalDeclareOp = cast<fir::DeclareOp>(op);
818
874
Type originalInType = originalDeclareOp.getMemref ().getType ();
@@ -833,6 +889,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
833
889
opsToReplace.push_back (clonedOp);
834
890
}
835
891
}
892
+
836
893
for (Operation *op : opsToReplace) {
837
894
if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
838
895
rewriter.setInsertionPoint (allocOp);
0 commit comments