@@ -701,105 +701,175 @@ class MapInfoFinalizationPass
701
701
702
702
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
703
703
llvm::SmallVector<mlir::Value> newMapOpsForFields;
704
- llvm::SmallVector<int64_t > fieldIndicies ;
704
+ llvm::SmallVector<llvm::SmallVector< int64_t >> newMemberIndexPaths ;
705
705
706
- for (auto fieldMemTyPair : recordType.getTypeList ()) {
707
- auto &field = fieldMemTyPair.first ;
708
- auto memTy = fieldMemTyPair.second ;
709
-
710
- bool shouldMapField =
711
- llvm::find_if (mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
712
- if (!fir::isAllocatableType (memTy))
713
- return false ;
714
-
715
- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
716
- if (!designateOp)
717
- return false ;
718
-
719
- return designateOp.getComponent () &&
720
- designateOp.getComponent ()->strref () == field;
721
- }) != mapVarForwardSlice.end ();
722
-
723
- // TODO Handle recursive record types. Adapting
724
- // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
725
- // entities might be helpful here.
726
-
727
- if (!shouldMapField)
728
- continue ;
729
-
730
- int32_t fieldIdx = recordType.getFieldIndex (field);
706
+ auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
707
+ mlir::Type memTy,
708
+ llvm::ArrayRef<int64_t > indexPath,
709
+ llvm::StringRef memberName) {
710
+ // Check if already mapped (index path equality).
731
711
bool alreadyMapped = [&]() {
732
712
if (op.getMembersIndexAttr ())
733
713
for (auto indexList : op.getMembersIndexAttr ()) {
734
714
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
735
- if (indexListAttr.size () == 1 &&
736
- mlir::cast<mlir::IntegerAttr>(indexListAttr[0 ]).getInt () ==
737
- fieldIdx)
715
+ if (indexListAttr.size () != indexPath.size ())
716
+ continue ;
717
+ bool allEq = true ;
718
+ for (auto [i, attr] : llvm::enumerate (indexListAttr)) {
719
+ if (mlir::cast<mlir::IntegerAttr>(attr).getInt () !=
720
+ indexPath[i]) {
721
+ allEq = false ;
722
+ break ;
723
+ }
724
+ }
725
+ if (allEq)
738
726
return true ;
739
727
}
740
728
741
729
return false ;
742
730
}();
743
731
744
732
if (alreadyMapped)
745
- continue ;
733
+ return ;
746
734
747
735
builder.setInsertionPoint (op);
748
- fir::IntOrValue idxConst =
749
- mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
750
- auto fieldCoord = fir::CoordinateOp::create (
751
- builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
752
- llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
753
736
fir::factory::AddrAndBoundsInfo info =
754
- fir::factory::getDataOperandBaseAddr (
755
- builder, fieldCoord, /* isOptional=*/ false , op. getLoc () );
737
+ fir::factory::getDataOperandBaseAddr (builder, coordRef,
738
+ /* isOptional=*/ false , loc );
756
739
llvm::SmallVector<mlir::Value> bounds =
757
740
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
758
741
mlir::omp::MapBoundsType>(
759
742
builder, info,
760
- hlfir::translateToExtendedValue (op. getLoc () , builder,
761
- hlfir::Entity{fieldCoord })
743
+ hlfir::translateToExtendedValue (loc , builder,
744
+ hlfir::Entity{coordRef })
762
745
.first ,
763
- /* dataExvIsAssumedSize=*/ false , op. getLoc () );
746
+ /* dataExvIsAssumedSize=*/ false , loc );
764
747
765
748
mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create (
766
- builder, op.getLoc (), fieldCoord.getResult ().getType (),
767
- fieldCoord.getResult (),
768
- mlir::TypeAttr::get (
769
- fir::unwrapRefType (fieldCoord.getResult ().getType ())),
749
+ builder, loc, coordRef.getType (), coordRef,
750
+ mlir::TypeAttr::get (fir::unwrapRefType (coordRef.getType ())),
770
751
op.getMapTypeAttr (),
771
752
builder.getAttr <mlir::omp::VariableCaptureKindAttr>(
772
753
mlir::omp::VariableCaptureKind::ByRef),
773
754
/* varPtrPtr=*/ mlir::Value{}, /* members=*/ mlir::ValueRange{},
774
755
/* members_index=*/ mlir::ArrayAttr{}, bounds,
775
756
/* mapperId=*/ mlir::FlatSymbolRefAttr (),
776
- builder.getStringAttr (op.getNameAttr ().strref () + " ." + field +
777
- " .implicit_map" ),
757
+ builder.getStringAttr (op.getNameAttr ().strref () + " ." +
758
+ memberName + " .implicit_map" ),
778
759
/* partial_map=*/ builder.getBoolAttr (false ));
779
760
newMapOpsForFields.emplace_back (fieldMapOp);
780
- fieldIndicies.emplace_back (fieldIdx);
761
+ newMemberIndexPaths.emplace_back (indexPath.begin (), indexPath.end ());
762
+ };
763
+
764
+ // 1) Handle direct top-level allocatable fields (existing behavior).
765
+ for (auto fieldMemTyPair : recordType.getTypeList ()) {
766
+ auto &field = fieldMemTyPair.first ;
767
+ auto memTy = fieldMemTyPair.second ;
768
+
769
+ if (!fir::isAllocatableType (memTy))
770
+ continue ;
771
+
772
+ bool referenced = llvm::any_of (mapVarForwardSlice, [&](auto *opv) {
773
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
774
+ return designateOp && designateOp.getComponent () &&
775
+ designateOp.getComponent ()->strref () == field;
776
+ });
777
+ if (!referenced)
778
+ continue ;
779
+
780
+ int32_t fieldIdx = recordType.getFieldIndex (field);
781
+ builder.setInsertionPoint (op);
782
+ fir::IntOrValue idxConst =
783
+ mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
784
+ auto fieldCoord = fir::CoordinateOp::create (
785
+ builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
786
+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
787
+ appendMemberMap (op.getLoc (), fieldCoord, memTy, {fieldIdx}, field);
788
+ }
789
+
790
+ // Handle nested allocatable fields along any component chain
791
+ // referenced in the region via HLFIR designates.
792
+ for (mlir::Operation *sliceOp : mapVarForwardSlice) {
793
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
794
+ if (!designateOp || !designateOp.getComponent ())
795
+ continue ;
796
+ llvm::SmallVector<llvm::StringRef> compPathReversed;
797
+ compPathReversed.push_back (designateOp.getComponent ()->strref ());
798
+ mlir::Value curBase = designateOp.getMemref ();
799
+ bool rootedAtMapArg = false ;
800
+ while (true ) {
801
+ if (auto parentDes = curBase.getDefiningOp <hlfir::DesignateOp>()) {
802
+ if (!parentDes.getComponent ())
803
+ break ;
804
+ compPathReversed.push_back (parentDes.getComponent ()->strref ());
805
+ curBase = parentDes.getMemref ();
806
+ continue ;
807
+ }
808
+ if (auto decl = curBase.getDefiningOp <hlfir::DeclareOp>()) {
809
+ if (auto barg =
810
+ mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref ()))
811
+ rootedAtMapArg = (barg == opBlockArg);
812
+ } else if (auto blockArg =
813
+ mlir::dyn_cast_or_null<mlir::BlockArgument>(
814
+ curBase)) {
815
+ rootedAtMapArg = (blockArg == opBlockArg);
816
+ }
817
+ break ;
818
+ }
819
+ if (!rootedAtMapArg || compPathReversed.size () < 2 )
820
+ continue ;
821
+ builder.setInsertionPoint (op);
822
+ llvm::SmallVector<int64_t > indexPath;
823
+ mlir::Type curTy = underlyingType;
824
+ mlir::Value coordRef = op.getVarPtr ();
825
+ bool validPath = true ;
826
+ for (llvm::StringRef compName : llvm::reverse (compPathReversed)) {
827
+ auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
828
+ if (!recTy) {
829
+ validPath = false ;
830
+ break ;
831
+ }
832
+ int32_t idx = recTy.getFieldIndex (compName);
833
+ if (idx < 0 ) {
834
+ validPath = false ;
835
+ break ;
836
+ }
837
+ indexPath.push_back (idx);
838
+ mlir::Type memTy = recTy.getType (idx);
839
+ fir::IntOrValue idxConst =
840
+ mlir::IntegerAttr::get (builder.getI32Type (), idx);
841
+ coordRef = fir::CoordinateOp::create (
842
+ builder, op.getLoc (), builder.getRefType (memTy), coordRef,
843
+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
844
+ curTy = memTy;
845
+ }
846
+ if (!validPath)
847
+ continue ;
848
+ if (auto finalRefTy =
849
+ mlir::dyn_cast<fir::ReferenceType>(coordRef.getType ())) {
850
+ mlir::Type eleTy = finalRefTy.getElementType ();
851
+ if (fir::isAllocatableType (eleTy))
852
+ appendMemberMap (op.getLoc (), coordRef, eleTy, indexPath,
853
+ compPathReversed.front ());
854
+ }
781
855
}
782
856
783
857
if (newMapOpsForFields.empty ())
784
858
return mlir::WalkResult::advance ();
785
859
786
860
op.getMembersMutable ().append (newMapOpsForFields);
787
861
llvm::SmallVector<llvm::SmallVector<int64_t >> newMemberIndices;
788
- mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr ();
789
-
790
- if (oldMembersIdxAttr)
791
- for (mlir::Attribute indexList : oldMembersIdxAttr) {
862
+ if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr ())
863
+ for (mlir::Attribute indexList : oldAttr) {
792
864
llvm::SmallVector<int64_t > listVec;
793
865
794
866
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
795
867
listVec.push_back (mlir::cast<mlir::IntegerAttr>(index).getInt ());
796
868
797
869
newMemberIndices.emplace_back (std::move (listVec));
798
870
}
799
-
800
- for (int64_t newFieldIdx : fieldIndicies)
801
- newMemberIndices.emplace_back (
802
- llvm::SmallVector<int64_t >(1 , newFieldIdx));
871
+ for (auto &path : newMemberIndexPaths)
872
+ newMemberIndices.emplace_back (path);
803
873
804
874
op.setMembersIndexAttr (builder.create2DI64ArrayAttr (newMemberIndices));
805
875
op.setPartialMap (true );
0 commit comments