Skip to content

Commit 8c3785a

Browse files
committed
Handle case when private maps are present in omp.target
1 parent 408eca8 commit 8c3785a

File tree

1 file changed

+95
-38
lines changed

1 file changed

+95
-38
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 95 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -566,22 +566,60 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
566566
toRecompute);
567567
}
568568

569-
static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
570-
MLIRContext &ctx, IRMapping &mapping,
571-
Operation *splitBefore, Block *targetBlock,
572-
Block *newTargetBlock,
573-
SmallVector<Value> &allocs,
574-
SetVector<Operation *> &toRecompute) {
575-
for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
576-
auto originalArg = targetBlock->getArgument(i);
569+
static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
570+
omp::TargetOp &targetOp, Block *targetBlock,
571+
Block *newTargetBlock,
572+
SmallVector<Value> &mapOperands,
573+
SmallVector<Value> &allocs,
574+
IRMapping &irMapping) {
575+
// Map `map_operands` to block arguments.
576+
unsigned originalMapVarsSize = targetOp.getMapVars().size();
577+
for (unsigned i = 0; i < mapOperands.size(); ++i) {
578+
Value originalValue;
579+
BlockArgument newArg;
580+
// Map the new arguments from the original block.
581+
if (i < originalMapVarsSize) {
582+
originalValue = targetBlock->getArgument(i);
583+
newArg = newTargetBlock->addArgument(originalValue.getType(),
584+
originalValue.getLoc());
585+
}
586+
// Map the new arguments from the `allocs`.
587+
else {
588+
originalValue = allocs[i - originalMapVarsSize];
589+
newArg = newTargetBlock->addArgument(
590+
getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc());
591+
}
592+
irMapping.map(originalValue, newArg);
593+
}
594+
// Map `private_vars` to block arguments.
595+
unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size();
596+
for (unsigned i = 0; i < originalPrivateVarsSize; ++i) {
597+
auto originalArg = targetBlock->getArgument(originalMapVarsSize + i);
577598
auto newArg = newTargetBlock->addArgument(originalArg.getType(),
578599
originalArg.getLoc());
579-
mapping.map(originalArg, newArg);
600+
irMapping.map(originalArg, newArg);
580601
}
581-
auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
582-
for (auto original : allocs) {
583-
Value newArg = newTargetBlock->addArgument(
584-
getPtrTypeForOmp(original.getType()), original.getLoc());
602+
return;
603+
}
604+
605+
static void reloadCacheAndRecompute(
606+
Location loc, RewriterBase &rewriter, Operation *splitBefore,
607+
omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
608+
SmallVector<Value> &mapOperands, SmallVector<Value> &allocs,
609+
SetVector<Operation *> &toRecompute, IRMapping &irMapping) {
610+
createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, newTargetBlock,
611+
mapOperands, allocs, irMapping);
612+
// Handle the load operations for the allocs.
613+
rewriter.setInsertionPointToStart(newTargetBlock);
614+
auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
615+
616+
unsigned originalMapVarsSize = targetOp.getMapVars().size();
617+
// Create Stores for allocs.
618+
for (unsigned i = 0; i < allocs.size(); ++i) {
619+
Value original = allocs[i];
620+
// Get the new block argument for this specific allocated value.
621+
Value newArg = newTargetBlock->getArgument(originalMapVarsSize + i);
622+
585623
Value restored;
586624
if (isPtr(original.getType())) {
587625
restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
@@ -591,18 +629,18 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
591629
} else {
592630
restored = rewriter.create<fir::LoadOp>(loc, newArg);
593631
}
594-
mapping.map(original, restored);
632+
irMapping.map(original, restored);
595633
}
634+
596635
for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
597636
if (toRecompute.contains(&*it))
598-
rewriter.clone(*it, mapping);
637+
rewriter.clone(*it, irMapping);
599638
}
600639
}
601640

602641
static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
603642
RewriterBase &rewriter) {
604643
auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
605-
MLIRContext &ctx = *targetOp.getContext();
606644
assert(targetOp);
607645
auto loc = targetOp.getLoc();
608646
auto *targetBlock = &targetOp.getRegion().front();
@@ -657,22 +695,29 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
657695
auto *preTargetBlock = rewriter.createBlock(
658696
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
659697
IRMapping preMapping;
660-
for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
661-
auto originalArg = targetBlock->getArgument(i);
662-
auto newArg = preTargetBlock->addArgument(originalArg.getType(),
663-
originalArg.getLoc());
664-
preMapping.map(originalArg, newArg);
665-
}
666-
for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++)
667-
rewriter.clone(*it, preMapping);
668698

699+
createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
700+
preMapOperands, allocs, preMapping);
701+
702+
// Handle the store operations for the allocs.
703+
rewriter.setInsertionPointToStart(preTargetBlock);
669704
auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
670705

671-
for (auto original : allocs) {
672-
Value toStore = preMapping.lookup(original);
673-
auto newArg = preTargetBlock->addArgument(
674-
getPtrTypeForOmp(original.getType()), original.getLoc());
675-
if (isPtr(original.getType())) {
706+
// Clone the original operations.
707+
for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
708+
it++) {
709+
rewriter.clone(*it, preMapping);
710+
}
711+
712+
unsigned originalMapVarsSize = targetOp.getMapVars().size();
713+
// Create Stores for allocs.
714+
for (unsigned i = 0; i < allocs.size(); ++i) {
715+
Value originalResult = allocs[i];
716+
Value toStore = preMapping.lookup(originalResult);
717+
// Get the new block argument for this specific allocated value.
718+
Value newArg = preTargetBlock->getArgument(originalMapVarsSize + i);
719+
720+
if (isPtr(originalResult.getType())) {
676721
if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
677722
toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
678723
rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
@@ -701,9 +746,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
701746
isolatedTargetOp.getRegion().begin(), {}, {});
702747

703748
IRMapping isolatedMapping;
704-
reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
705-
targetBlock, isolatedTargetBlock, allocs,
706-
toRecompute);
749+
reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
750+
isolatedTargetBlock, postMapOperands, allocs,
751+
toRecompute, isolatedMapping);
707752
rewriter.clone(*splitBeforeOp, isolatedMapping);
708753
rewriter.create<omp::TerminatorOp>(loc);
709754

@@ -725,8 +770,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
725770
auto *postTargetBlock = rewriter.createBlock(
726771
&postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
727772
IRMapping postMapping;
728-
reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
729-
targetBlock, postTargetBlock, allocs, toRecompute);
773+
reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
774+
postTargetBlock, postMapOperands, allocs,
775+
toRecompute, postMapping);
730776

731777
assert(splitBeforeOp->getNumResults() == 0 ||
732778
llvm::all_of(splitBeforeOp->getResults(),
@@ -755,15 +801,24 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
755801
Block *targetBlock = &targetOp.getRegion().front();
756802
assert(targetBlock == &targetOp.getRegion().back());
757803
IRMapping mapping;
758-
for (auto map :
759-
zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) {
760-
Value mapInfo = std::get<0>(map);
761-
BlockArgument arg = std::get<1>(map);
804+
for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
805+
Value mapInfo = targetOp.getMapVars()[i];
806+
BlockArgument arg = targetBlock->getArguments()[i];
762807
Operation *op = mapInfo.getDefiningOp();
763808
assert(op);
764809
auto mapInfoOp = cast<omp::MapInfoOp>(op);
810+
// map the block argument to the host-side variable pointer
765811
mapping.map(arg, mapInfoOp.getVarPtr());
766812
}
813+
unsigned mapSize = targetOp.getMapVars().size();
814+
for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
815+
Value privateVar = targetOp.getPrivateVars()[i];
816+
// The mapping should link the device-side variable to the host-side one.
817+
BlockArgument arg = targetBlock->getArguments()[mapSize + i];
818+
// Map the device-side copy (`arg`) to the host-side value (`privateVar`).
819+
mapping.map(arg, privateVar);
820+
}
821+
767822
rewriter.setInsertionPoint(targetOp);
768823
SmallVector<Operation *> opsToReplace;
769824
Value device = targetOp.getDevice();
@@ -813,6 +868,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
813868
// fir.declare changes its type when hoisting it out of omp.target to
814869
// omp.target_data Introduce a load, if original declareOp input is not of
815870
// reference type, but cloned delcareOp input is reference type.
871+
816872
if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
817873
auto originalDeclareOp = cast<fir::DeclareOp>(op);
818874
Type originalInType = originalDeclareOp.getMemref().getType();
@@ -833,6 +889,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
833889
opsToReplace.push_back(clonedOp);
834890
}
835891
}
892+
836893
for (Operation *op : opsToReplace) {
837894
if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
838895
rewriter.setInsertionPoint(allocOp);

0 commit comments

Comments
 (0)