
Commit b4f1e0e

[Flang][OpenMP] Implicitly map nested allocatable components in derived types (llvm#160116)
This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.

- Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived-type object is mapped via declare mapper.
- Traverse HLFIR designates rooted at the target block argument and build full coordinate_of chains; append members with the correct membersIndex.

This fixes llvm#156461.
1 parent d2ac21d commit b4f1e0e
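
For illustration, a minimal Fortran sketch of the pattern this enables (hypothetical names, distilled from the omp-declare-mapper-6.f90 test added below): the declare mapper maps the derived-type object, and the target region reads an allocatable component that is only reachable through an intermediate component, so the pass must emit an implicit child map for it.

  subroutine nested_alloc_sketch
    type :: inner_t
      real, allocatable :: deep_arr(:)
    end type inner_t
    type :: outer_t
      type(inner_t) :: inner
      real, allocatable :: top_arr(:)
    end type outer_t

    !$omp declare mapper(m : outer_t :: t) map(tofrom: t%top_arr)

    type(outer_t) :: o
    allocate(o%top_arr(10), o%inner%deep_arr(10))
    o%top_arr = 1.0
    o%inner%deep_arr = 2.0

    ! o%inner%deep_arr is not listed in the mapper, but it is referenced in
    ! the region; MapInfoFinalization now adds an implicit member map for it
    ! (previously only top-level allocatable fields were handled).
    !$omp target map(mapper(m), tofrom: o)
      o%top_arr(1) = o%inner%deep_arr(1)
    !$omp end target
  end subroutine nested_alloc_sketch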

3 files changed: +206 -55 lines

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 125 additions & 55 deletions
@@ -701,105 +701,175 @@ class MapInfoFinalizationPass
 
         auto recordType = mlir::cast<fir::RecordType>(underlyingType);
         llvm::SmallVector<mlir::Value> newMapOpsForFields;
-        llvm::SmallVector<int64_t> fieldIndicies;
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
 
-        for (auto fieldMemTyPair : recordType.getTypeList()) {
-          auto &field = fieldMemTyPair.first;
-          auto memTy = fieldMemTyPair.second;
-
-          bool shouldMapField =
-              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
-                if (!fir::isAllocatableType(memTy))
-                  return false;
-
-                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
-                if (!designateOp)
-                  return false;
-
-                return designateOp.getComponent() &&
-                       designateOp.getComponent()->strref() == field;
-              }) != mapVarForwardSlice.end();
-
-          // TODO Handle recursive record types. Adapting
-          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
-          // entities might be helpful here.
-
-          if (!shouldMapField)
-            continue;
-
-          int32_t fieldIdx = recordType.getFieldIndex(field);
+        auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
+                                   mlir::Type memTy,
+                                   llvm::ArrayRef<int64_t> indexPath,
+                                   llvm::StringRef memberName) {
+          // Check if already mapped (index path equality).
           bool alreadyMapped = [&]() {
             if (op.getMembersIndexAttr())
               for (auto indexList : op.getMembersIndexAttr()) {
                 auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
-                if (indexListAttr.size() == 1 &&
-                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
-                        fieldIdx)
+                if (indexListAttr.size() != indexPath.size())
+                  continue;
+                bool allEq = true;
+                for (auto [i, attr] : llvm::enumerate(indexListAttr)) {
+                  if (mlir::cast<mlir::IntegerAttr>(attr).getInt() !=
+                      indexPath[i]) {
+                    allEq = false;
+                    break;
+                  }
+                }
+                if (allEq)
                   return true;
               }
 
             return false;
           }();
 
           if (alreadyMapped)
-            continue;
+            return;
 
           builder.setInsertionPoint(op);
-          fir::IntOrValue idxConst =
-              mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
-          auto fieldCoord = fir::CoordinateOp::create(
-              builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
-              llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
           fir::factory::AddrAndBoundsInfo info =
-              fir::factory::getDataOperandBaseAddr(
-                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+              fir::factory::getDataOperandBaseAddr(builder, coordRef,
+                                                   /*isOptional=*/false, loc);
          llvm::SmallVector<mlir::Value> bounds =
              fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
                                                 mlir::omp::MapBoundsType>(
                  builder, info,
-                  hlfir::translateToExtendedValue(op.getLoc(), builder,
-                                                  hlfir::Entity{fieldCoord})
+                  hlfir::translateToExtendedValue(loc, builder,
+                                                  hlfir::Entity{coordRef})
                      .first,
-                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+                  /*dataExvIsAssumedSize=*/false, loc);
 
          mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
-              builder, op.getLoc(), fieldCoord.getResult().getType(),
-              fieldCoord.getResult(),
-              mlir::TypeAttr::get(
-                  fir::unwrapRefType(fieldCoord.getResult().getType())),
+              builder, loc, coordRef.getType(), coordRef,
+              mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
              op.getMapTypeAttr(),
              builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
                  mlir::omp::VariableCaptureKind::ByRef),
              /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
              /*members_index=*/mlir::ArrayAttr{}, bounds,
              /*mapperId=*/mlir::FlatSymbolRefAttr(),
-              builder.getStringAttr(op.getNameAttr().strref() + "." + field +
-                                    ".implicit_map"),
+              builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                    memberName + ".implicit_map"),
              /*partial_map=*/builder.getBoolAttr(false));
          newMapOpsForFields.emplace_back(fieldMapOp);
-          fieldIndicies.emplace_back(fieldIdx);
+          newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+        };
+
+        // 1) Handle direct top-level allocatable fields (existing behavior).
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          if (!fir::isAllocatableType(memTy))
+            continue;
+
+          bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
+            auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
+            return designateOp && designateOp.getComponent() &&
+                   designateOp.getComponent()->strref() == field;
+          });
+          if (!referenced)
+            continue;
+
+          int32_t fieldIdx = recordType.getFieldIndex(field);
+          builder.setInsertionPoint(op);
+          fir::IntOrValue idxConst =
+              mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
+          auto fieldCoord = fir::CoordinateOp::create(
+              builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+          appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field);
+        }
+
+        // Handle nested allocatable fields along any component chain
+        // referenced in the region via HLFIR designates.
+        for (mlir::Operation *sliceOp : mapVarForwardSlice) {
+          auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+          if (!designateOp || !designateOp.getComponent())
+            continue;
+          llvm::SmallVector<llvm::StringRef> compPathReversed;
+          compPathReversed.push_back(designateOp.getComponent()->strref());
+          mlir::Value curBase = designateOp.getMemref();
+          bool rootedAtMapArg = false;
+          while (true) {
+            if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
+              if (!parentDes.getComponent())
+                break;
+              compPathReversed.push_back(parentDes.getComponent()->strref());
+              curBase = parentDes.getMemref();
+              continue;
+            }
+            if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
+              if (auto barg =
+                      mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
+                rootedAtMapArg = (barg == opBlockArg);
+            } else if (auto blockArg =
+                           mlir::dyn_cast_or_null<mlir::BlockArgument>(
+                               curBase)) {
+              rootedAtMapArg = (blockArg == opBlockArg);
+            }
+            break;
+          }
+          if (!rootedAtMapArg || compPathReversed.size() < 2)
+            continue;
+          builder.setInsertionPoint(op);
+          llvm::SmallVector<int64_t> indexPath;
+          mlir::Type curTy = underlyingType;
+          mlir::Value coordRef = op.getVarPtr();
+          bool validPath = true;
+          for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
+            auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
+            if (!recTy) {
+              validPath = false;
+              break;
+            }
+            int32_t idx = recTy.getFieldIndex(compName);
+            if (idx < 0) {
+              validPath = false;
+              break;
+            }
+            indexPath.push_back(idx);
+            mlir::Type memTy = recTy.getType(idx);
+            fir::IntOrValue idxConst =
+                mlir::IntegerAttr::get(builder.getI32Type(), idx);
+            coordRef = fir::CoordinateOp::create(
+                builder, op.getLoc(), builder.getRefType(memTy), coordRef,
+                llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+            curTy = memTy;
+          }
+          if (!validPath)
+            continue;
+          if (auto finalRefTy =
+                  mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
+            mlir::Type eleTy = finalRefTy.getElementType();
+            if (fir::isAllocatableType(eleTy))
+              appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath,
+                              compPathReversed.front());
+          }
         }
 
         if (newMapOpsForFields.empty())
           return mlir::WalkResult::advance();
 
         op.getMembersMutable().append(newMapOpsForFields);
         llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
-        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
-
-        if (oldMembersIdxAttr)
-          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+        if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
+          for (mlir::Attribute indexList : oldAttr) {
            llvm::SmallVector<int64_t> listVec;
 
            for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
              listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
 
            newMemberIndices.emplace_back(std::move(listVec));
          }
-
-        for (int64_t newFieldIdx : fieldIndicies)
-          newMemberIndices.emplace_back(
-              llvm::SmallVector<int64_t>(1, newFieldIdx));
+        for (auto &path : newMemberIndexPaths)
+          newMemberIndices.emplace_back(path);
 
         op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
         op.setPartialMap(true);

flang/test/Lower/OpenMP/declare-mapper.f90

Lines changed: 38 additions & 0 deletions
@@ -6,6 +6,7 @@
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
 
 !--- omp-declare-mapper-1.f90
 subroutine declare_mapper_1
@@ -262,3 +263,40 @@ subroutine use_inner()
   !$omp end target
 end subroutine
 end program declare_mapper_5
+
+!--- omp-declare-mapper-6.f90
+subroutine declare_mapper_nested_parent
+  type :: inner_t
+    real, allocatable :: deep_arr(:)
+  end type inner_t
+
+  type, abstract :: base_t
+    real, allocatable :: base_arr(:)
+    type(inner_t) :: inner
+  end type base_t
+
+  type, extends(base_t) :: real_t
+    real, allocatable :: real_arr(:)
+  end type real_t
+
+  !$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr)
+
+  type(real_t) :: r
+
+  allocate(r%base_arr(10))
+  allocate(r%inner%deep_arr(10))
+  allocate(r%real_arr(10))
+  r%base_arr = 1.0
+  r%inner%deep_arr = 4.0
+  r%real_arr = 0.0
+
+  ! CHECK: omp.target
+  ! Check implicit maps for nested parent and deep nested allocatable payloads
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"}
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"}
+  ! The declared mapper's own allocatable is still mapped implicitly
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"}
+  !$omp target map(mapper(custommapper), tofrom: r)
+    r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
+  !$omp end target
+end subroutine declare_mapper_nested_parent
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+! This test validates that declare mapper for a derived type that extends
+! a parent type with an allocatable component correctly maps the nested
+! allocatable payload via the mapper when the whole object is mapped on
+! target.
+
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-run-and-check-generic
+
+program target_declare_mapper_parent_allocatable
+  implicit none
+
+  type, abstract :: base_t
+    real, allocatable :: base_arr(:)
+  end type base_t
+
+  type, extends(base_t) :: real_t
+    real, allocatable :: real_arr(:)
+  end type real_t
+  !$omp declare mapper(custommapper: real_t :: t) map(t%base_arr, t%real_arr)
+
+  type(real_t) :: r
+  integer :: i
+  allocate(r%base_arr(10), source=1.0)
+  allocate(r%real_arr(10), source=1.0)
+
+  !$omp target map(tofrom: r)
+  do i = 1, size(r%base_arr)
+    r%base_arr(i) = 2.0
+    r%real_arr(i) = 3.0
+    r%real_arr(i) = r%base_arr(1)
+  end do
+  !$omp end target
+
+
+  !CHECK: base_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+  print*, "base_arr: ", r%base_arr
+  !CHECK: real_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+  print*, "real_arr: ", r%real_arr
+
+  deallocate(r%real_arr)
+  deallocate(r%base_arr)
+end program target_declare_mapper_parent_allocatable
