@@ -834,13 +834,140 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
834
834
return rewriter.create <mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
835
835
}
836
836
837
- static Type getOmpDeviceType (MLIRContext *c) { return IntegerType::get (c, 32 ); }
837
+ static mlir::LLVM::ConstantOp
838
+ genI64Constant (mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
839
+ mlir::Type i64Ty = rewriter.getI64Type ();
840
+ mlir::IntegerAttr attr = rewriter.getI64IntegerAttr (value);
841
+ return rewriter.create <mlir::LLVM::ConstantOp>(loc, i64Ty, attr);
842
+ }
843
+
844
+ static Value genDescriptorGetBaseAddress (fir::FirOpBuilder &builder,
845
+ Location loc, Value boxDesc) {
846
+ Value box = boxDesc;
847
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
848
+ box = fir::LoadOp::create (builder, loc, boxDesc);
849
+ }
850
+ assert (isa<fir::BoxType>(box.getType ()) &&
851
+ " Unknown type passed to genDescriptorGetBaseAddress" );
852
+ auto i8Type = builder.getI8Type ();
853
+ auto unknownArrayType =
854
+ fir::SequenceType::get ({fir::SequenceType::getUnknownExtent ()}, i8Type);
855
+ auto i8BoxType = fir::BoxType::get (unknownArrayType);
856
+ auto typedBox = fir::ConvertOp::create (builder, loc, i8BoxType, box);
857
+ auto rawAddr = fir::BoxAddrOp::create (builder, loc, typedBox);
858
+ return rawAddr;
859
+ }
860
+
861
+ static Value genDescriptorGetTotalElements (fir::FirOpBuilder &builder,
862
+ Location loc, Value boxDesc) {
863
+ Value box = boxDesc;
864
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
865
+ box = fir::LoadOp::create (builder, loc, boxDesc);
866
+ }
867
+ assert (isa<fir::BoxType>(box.getType ()) &&
868
+ " Unknown type passed to genDescriptorGetTotalElements" );
869
+ auto i64Type = builder.getI64Type ();
870
+ return fir::BoxTotalElementsOp::create (builder, loc, i64Type, box);
871
+ }
872
+
873
+ static Value genDescriptorGetEleSize (fir::FirOpBuilder &builder, Location loc,
874
+ Value boxDesc) {
875
+ Value box = boxDesc;
876
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
877
+ box = fir::LoadOp::create (builder, loc, boxDesc);
878
+ }
879
+ assert (isa<fir::BoxType>(box.getType ()) &&
880
+ " Unknown type passed to genDescriptorGetElementSize" );
881
+ auto i64Type = builder.getI64Type ();
882
+ return fir::BoxEleSizeOp::create (builder, loc, i64Type, box);
883
+ }
884
+
885
+ static Value genDescriptorGetDataSizeInBytes (fir::FirOpBuilder &builder,
886
+ Location loc, Value boxDesc) {
887
+ Value box = boxDesc;
888
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
889
+ box = fir::LoadOp::create (builder, loc, boxDesc);
890
+ }
891
+ assert (isa<fir::BoxType>(box.getType ()) &&
892
+ " Unknown type passed to genDescriptorGetElementSize" );
893
+ Value eleSize = genDescriptorGetEleSize (builder, loc, box);
894
+ Value totalElements = genDescriptorGetTotalElements (builder, loc, box);
895
+ return mlir::arith::MulIOp::create (builder, loc, totalElements, eleSize);
896
+ }
897
+
898
+ static mlir::Value genOmpGetMappedPtrIfPresent (fir::FirOpBuilder &builder,
899
+ mlir::Location loc,
900
+ mlir::Value hostPtr,
901
+ mlir::Value deviceNum,
902
+ mlir::ModuleOp module ) {
903
+ auto *context = builder.getContext ();
904
+ auto voidPtrType = fir::LLVMPointerType::get (context, builder.getI8Type ());
905
+ auto i32Type = builder.getI32Type ();
906
+ auto funcName = " omp_get_mapped_ptr" ;
907
+ auto funcOp = module .lookupSymbol <mlir::func::FuncOp>(funcName);
908
+
909
+ if (!funcOp) {
910
+ auto funcType =
911
+ mlir::FunctionType::get (context, {voidPtrType, i32Type}, {voidPtrType});
912
+
913
+ mlir::OpBuilder::InsertionGuard guard (builder);
914
+ builder.setInsertionPointToStart (module .getBody ());
915
+
916
+ funcOp = mlir::func::FuncOp::create (builder, loc, funcName, funcType);
917
+ funcOp.setPrivate ();
918
+ }
919
+
920
+ llvm::SmallVector<mlir::Value> args;
921
+ args.push_back (fir::ConvertOp::create (builder, loc, voidPtrType, hostPtr));
922
+ args.push_back (fir::ConvertOp::create (builder, loc, i32Type, deviceNum));
923
+ auto callOp = fir::CallOp::create (builder, loc, funcOp, args);
924
+ auto mappedPtr = callOp.getResult (0 );
925
+ auto isNull = builder.genIsNullAddr (loc, mappedPtr);
926
+ auto convertedHostPtr =
927
+ fir::ConvertOp::create (builder, loc, voidPtrType, hostPtr);
928
+ auto result = arith::SelectOp::create (builder, loc, isNull, convertedHostPtr,
929
+ mappedPtr);
930
+ return result;
931
+ }
932
+
933
+ static void genOmpTargetMemcpyCall (fir::FirOpBuilder &builder,
934
+ mlir::Location loc, mlir::Value dst,
935
+ mlir::Value src, mlir::Value length,
936
+ mlir::Value dstOffset, mlir::Value srcOffset,
937
+ mlir::Value device, mlir::ModuleOp module ) {
938
+ auto *context = builder.getContext ();
939
+ // int omp_target_memcpy(void *dst, const void *src, size_t length,
940
+ // size_t dst_offset, size_t src_offset,
941
+ // int dst_device, int src_device)
942
+ auto funcName = " omp_target_memcpy" ;
943
+ auto voidPtrType = fir::LLVMPointerType::get (context, builder.getI8Type ());
944
+ auto sizeTType = builder.getI64Type (); // assuming size_t is 64-bit
945
+ auto i32Type = builder.getI32Type ();
946
+ auto funcOp = module .lookupSymbol <mlir::func::FuncOp>(funcName);
947
+
948
+ if (!funcOp) {
949
+ mlir::OpBuilder::InsertionGuard guard (builder);
950
+ builder.setInsertionPointToStart (module .getBody ());
951
+ llvm::SmallVector<mlir::Type> argTypes = {
952
+ voidPtrType, voidPtrType, sizeTType, sizeTType,
953
+ sizeTType, i32Type, i32Type};
954
+ auto funcType = mlir::FunctionType::get (context, argTypes, {i32Type});
955
+ funcOp = mlir::func::FuncOp::create (builder, loc, funcName, funcType);
956
+ funcOp.setPrivate ();
957
+ }
958
+
959
+ llvm::SmallVector<mlir::Value> args{dst, src, length, dstOffset,
960
+ srcOffset, device, device};
961
+ fir::CallOp::create (builder, loc, funcOp, args);
962
+ return ;
963
+ }
838
964
839
965
// moveToHost method clones all the ops from target region outside of it.
840
966
// It hoists runtime functions and replaces them with omp vesions.
841
967
// Also hoists and replaces fir.allocmem with omp.target_allocmem and
842
968
// fir.freemem with omp.target_freemem
843
- static void moveToHost (omp::TargetOp targetOp, RewriterBase &rewriter) {
969
+ static void moveToHost (omp::TargetOp targetOp, RewriterBase &rewriter,
970
+ mlir::ModuleOp module ) {
844
971
OpBuilder::InsertionGuard guard (rewriter);
845
972
Block *targetBlock = &targetOp.getRegion ().front ();
846
973
assert (targetBlock == &targetOp.getRegion ().back ());
@@ -859,7 +986,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
859
986
Value privateVar = targetOp.getPrivateVars ()[i];
860
987
// The mapping should link the device-side variable to the host-side one.
861
988
BlockArgument arg = targetBlock->getArguments ()[mapSize + i];
862
- // Map the device-side copy (arg) to the host-side value (privateVar).
989
+ // Map the device-side copy (` arg` ) to the host-side value (` privateVar` ).
863
990
mapping.map (arg, privateVar);
864
991
}
865
992
@@ -872,69 +999,43 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
872
999
for (auto it = targetBlock->begin (), end = std::prev (targetBlock->end ());
873
1000
it != end; ++it) {
874
1001
auto *op = &*it;
875
- if (isRuntimeCall (op)) {
876
- fir::CallOp runtimeCall = cast<fir::CallOp>(op);
877
- auto module = runtimeCall->getParentOfType <ModuleOp>();
878
- auto callee =
879
- cast<func::FuncOp>(module .lookupSymbol (runtimeCall.getCalleeAttr ()));
880
- std::string newCalleeName = (callee.getName () + " _omp" ).str ();
881
- mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
882
- func::FuncOp newCallee =
883
- cast_or_null<func::FuncOp>(module .lookupSymbol (newCalleeName));
884
- if (!newCallee) {
885
- SmallVector<Type> argTypes (callee.getFunctionType ().getInputs ());
886
- argTypes.push_back (getOmpDeviceType (rewriter.getContext ()));
887
- newCallee = moduleBuilder.create <func::FuncOp>(
888
- callee->getLoc (), newCalleeName,
889
- FunctionType::get (rewriter.getContext (), argTypes,
890
- callee.getFunctionType ().getResults ()));
891
- if (callee.getArgAttrs ())
892
- newCallee.setArgAttrsAttr (*callee.getArgAttrs ());
893
- if (callee.getResAttrs ())
894
- newCallee.setResAttrsAttr (*callee.getResAttrs ());
895
- newCallee.setSymVisibility (callee.getSymVisibility ());
896
- newCallee->setDiscardableAttrs (callee->getDiscardableAttrDictionary ());
897
- }
898
- SmallVector<Value> operands = runtimeCall.getOperands ();
899
- operands.push_back (device);
900
- auto tmpCall = rewriter.create <fir::CallOp>(
901
- runtimeCall.getLoc (), runtimeCall.getResultTypes (),
902
- SymbolRefAttr::get (newCallee), operands, nullptr , nullptr , nullptr ,
903
- runtimeCall.getFastmathAttr ());
904
- Operation *newCall = rewriter.clone (*tmpCall, mapping);
905
- mapping.map (&*it, newCall);
906
- rewriter.eraseOp (tmpCall);
907
- } else {
908
- Operation *clonedOp = rewriter.clone (*op, mapping);
909
- for (unsigned i = 0 ; i < op->getNumResults (); ++i) {
910
- mapping.map (op->getResult (i), clonedOp->getResult (i));
911
- }
912
- // fir.declare changes its type when hoisting it out of omp.target to
913
- // omp.target_data Introduce a load, if original declareOp input is not of
914
- // reference type, but cloned delcareOp input is reference type.
915
- if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
916
- auto originalDeclareOp = cast<fir::DeclareOp>(op);
917
- Type originalInType = originalDeclareOp.getMemref ().getType ();
918
- Type clonedInType = clonedDeclareOp.getMemref ().getType ();
919
-
920
- fir::ReferenceType originalRefType =
921
- dyn_cast<fir::ReferenceType>(originalInType);
922
- fir::ReferenceType clonedRefType =
923
- dyn_cast<fir::ReferenceType>(clonedInType);
924
- if (!originalRefType && clonedRefType) {
925
- Type clonedEleTy = clonedRefType.getElementType ();
926
- if (clonedEleTy == originalDeclareOp.getType ()) {
927
- opsToReplace.push_back (clonedOp);
928
- }
1002
+ Operation *clonedOp = rewriter.clone (*op, mapping);
1003
+ for (unsigned i = 0 ; i < op->getNumResults (); ++i) {
1004
+ mapping.map (op->getResult (i), clonedOp->getResult (i));
1005
+ }
1006
+ // fir.declare changes its type when hoisting it out of omp.target to
1007
+ // omp.target_data Introduce a load, if original declareOp input is not of
1008
+ // reference type, but cloned delcareOp input is reference type.
1009
+
1010
+ if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
1011
+ auto originalDeclareOp = cast<fir::DeclareOp>(op);
1012
+ Type originalInType = originalDeclareOp.getMemref ().getType ();
1013
+ Type clonedInType = clonedDeclareOp.getMemref ().getType ();
1014
+
1015
+ fir::ReferenceType originalRefType =
1016
+ dyn_cast<fir::ReferenceType>(originalInType);
1017
+ fir::ReferenceType clonedRefType =
1018
+ dyn_cast<fir::ReferenceType>(clonedInType);
1019
+ if (!originalRefType && clonedRefType) {
1020
+ Type clonedEleTy = clonedRefType.getElementType ();
1021
+ if (clonedEleTy == originalDeclareOp.getType ()) {
1022
+ opsToReplace.push_back (clonedOp);
929
1023
}
930
1024
}
1025
+ }
931
1026
if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
932
1027
opsToReplace.push_back (clonedOp);
933
- }
1028
+ if (isRuntimeCall (clonedOp)) {
1029
+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
1030
+ if ((*runtimeCall.getCallee ()).getRootReference ().getValue () ==
1031
+ " _FortranAAssign" ) {
1032
+ opsToReplace.push_back (clonedOp);
1033
+ } else {
1034
+ llvm_unreachable (" Unhandled runtime call hoisting." );
1035
+ }
1036
+ }
934
1037
}
935
1038
936
- // Replace fir.allocmem with omp.target_allocmem,
937
- // fir.freemem with omp.target_freemem.
938
1039
for (Operation *op : opsToReplace) {
939
1040
if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
940
1041
rewriter.setInsertionPoint (allocOp);
@@ -963,16 +1064,40 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
963
1064
Value loadedValue = rewriter.create <fir::LoadOp>(
964
1065
clonedDeclareOp.getLoc (), clonedEleTy, clonedDeclareOp.getMemref ());
965
1066
clonedDeclareOp.getResult ().replaceAllUsesWith (loadedValue);
1067
+ } else if (isRuntimeCall (op)) {
1068
+ rewriter.setInsertionPoint (op);
1069
+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
1070
+ SmallVector<Value> operands = runtimeCall.getOperands ();
1071
+ mlir::Location loc = runtimeCall.getLoc ();
1072
+ fir::FirOpBuilder builder{rewriter, op};
1073
+ assert (operands.size () == 4 );
1074
+ Value sourceFile{operands[2 ]}, sourceLine{operands[3 ]};
1075
+
1076
+ auto fromBaseAddr =
1077
+ genDescriptorGetBaseAddress (builder, loc, operands[1 ]);
1078
+ auto toBaseAddr = genDescriptorGetBaseAddress (builder, loc, operands[0 ]);
1079
+ auto dataSizeInBytes =
1080
+ genDescriptorGetDataSizeInBytes (builder, loc, operands[1 ]);
1081
+
1082
+ Value toPtr =
1083
+ genOmpGetMappedPtrIfPresent (builder, loc, toBaseAddr, device, module );
1084
+ Value fromPtr = genOmpGetMappedPtrIfPresent (builder, loc, fromBaseAddr,
1085
+ device, module );
1086
+ Value zero = genI64Constant (loc, rewriter, 0 );
1087
+ genOmpTargetMemcpyCall (builder, loc, toPtr, fromPtr, dataSizeInBytes,
1088
+ zero, zero, device, module );
1089
+ rewriter.eraseOp (op);
966
1090
}
967
1091
}
968
1092
rewriter.eraseOp (targetOp);
969
1093
}
970
1094
971
- void fissionTarget (omp::TargetOp targetOp, RewriterBase &rewriter) {
1095
+ void fissionTarget (omp::TargetOp targetOp, RewriterBase &rewriter,
1096
+ mlir::ModuleOp module ) {
972
1097
auto tuple = getNestedOpToIsolate (targetOp);
973
1098
if (!tuple) {
974
1099
LLVM_DEBUG (llvm::dbgs () << " No op to isolate\n " );
975
- moveToHost (targetOp, rewriter);
1100
+ moveToHost (targetOp, rewriter, module );
976
1101
return ;
977
1102
}
978
1103
@@ -982,18 +1107,18 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
982
1107
983
1108
if (splitBefore && splitAfter) {
984
1109
auto res = isolateOp (toIsolate, splitAfter, rewriter);
985
- moveToHost (res.preTargetOp , rewriter);
986
- fissionTarget (res.postTargetOp , rewriter);
1110
+ moveToHost (res.preTargetOp , rewriter, module );
1111
+ fissionTarget (res.postTargetOp , rewriter, module );
987
1112
return ;
988
1113
}
989
1114
if (splitBefore) {
990
1115
auto res = isolateOp (toIsolate, splitAfter, rewriter);
991
- moveToHost (res.preTargetOp , rewriter);
1116
+ moveToHost (res.preTargetOp , rewriter, module );
992
1117
return ;
993
1118
}
994
1119
if (splitAfter) {
995
1120
auto res = isolateOp (toIsolate->getNextNode (), splitAfter, rewriter);
996
- fissionTarget (res.postTargetOp , rewriter);
1121
+ fissionTarget (res.postTargetOp , rewriter, module );
997
1122
return ;
998
1123
}
999
1124
}
@@ -1023,7 +1148,7 @@ class LowerWorkdistributePass
1023
1148
for (auto targetOp : targetOps) {
1024
1149
auto res = splitTargetData (targetOp, rewriter);
1025
1150
if (res)
1026
- fissionTarget (res->targetOp , rewriter);
1151
+ fissionTarget (res->targetOp , rewriter, moduleOp );
1027
1152
}
1028
1153
}
1029
1154
}
0 commit comments