@@ -834,13 +834,140 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
834834 return rewriter.create <mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
835835}
836836
837- static Type getOmpDeviceType (MLIRContext *c) { return IntegerType::get (c, 32 ); }
837+ static mlir::LLVM::ConstantOp
838+ genI64Constant (mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
839+ mlir::Type i64Ty = rewriter.getI64Type ();
840+ mlir::IntegerAttr attr = rewriter.getI64IntegerAttr (value);
841+ return rewriter.create <mlir::LLVM::ConstantOp>(loc, i64Ty, attr);
842+ }
843+
844+ static Value genDescriptorGetBaseAddress (fir::FirOpBuilder &builder,
845+ Location loc, Value boxDesc) {
846+ Value box = boxDesc;
847+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
848+ box = fir::LoadOp::create (builder, loc, boxDesc);
849+ }
850+ assert (isa<fir::BoxType>(box.getType ()) &&
851+ " Unknown type passed to genDescriptorGetBaseAddress" );
852+ auto i8Type = builder.getI8Type ();
853+ auto unknownArrayType =
854+ fir::SequenceType::get ({fir::SequenceType::getUnknownExtent ()}, i8Type);
855+ auto i8BoxType = fir::BoxType::get (unknownArrayType);
856+ auto typedBox = fir::ConvertOp::create (builder, loc, i8BoxType, box);
857+ auto rawAddr = fir::BoxAddrOp::create (builder, loc, typedBox);
858+ return rawAddr;
859+ }
860+
861+ static Value genDescriptorGetTotalElements (fir::FirOpBuilder &builder,
862+ Location loc, Value boxDesc) {
863+ Value box = boxDesc;
864+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
865+ box = fir::LoadOp::create (builder, loc, boxDesc);
866+ }
867+ assert (isa<fir::BoxType>(box.getType ()) &&
868+ " Unknown type passed to genDescriptorGetTotalElements" );
869+ auto i64Type = builder.getI64Type ();
870+ return fir::BoxTotalElementsOp::create (builder, loc, i64Type, box);
871+ }
872+
873+ static Value genDescriptorGetEleSize (fir::FirOpBuilder &builder, Location loc,
874+ Value boxDesc) {
875+ Value box = boxDesc;
876+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
877+ box = fir::LoadOp::create (builder, loc, boxDesc);
878+ }
879+ assert (isa<fir::BoxType>(box.getType ()) &&
880+ " Unknown type passed to genDescriptorGetElementSize" );
881+ auto i64Type = builder.getI64Type ();
882+ return fir::BoxEleSizeOp::create (builder, loc, i64Type, box);
883+ }
884+
885+ static Value genDescriptorGetDataSizeInBytes (fir::FirOpBuilder &builder,
886+ Location loc, Value boxDesc) {
887+ Value box = boxDesc;
888+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType ())) {
889+ box = fir::LoadOp::create (builder, loc, boxDesc);
890+ }
891+ assert (isa<fir::BoxType>(box.getType ()) &&
892+ " Unknown type passed to genDescriptorGetElementSize" );
893+ Value eleSize = genDescriptorGetEleSize (builder, loc, box);
894+ Value totalElements = genDescriptorGetTotalElements (builder, loc, box);
895+ return mlir::arith::MulIOp::create (builder, loc, totalElements, eleSize);
896+ }
897+
898+ static mlir::Value genOmpGetMappedPtrIfPresent (fir::FirOpBuilder &builder,
899+ mlir::Location loc,
900+ mlir::Value hostPtr,
901+ mlir::Value deviceNum,
902+ mlir::ModuleOp module ) {
903+ auto *context = builder.getContext ();
904+ auto voidPtrType = fir::LLVMPointerType::get (context, builder.getI8Type ());
905+ auto i32Type = builder.getI32Type ();
906+ auto funcName = " omp_get_mapped_ptr" ;
907+ auto funcOp = module .lookupSymbol <mlir::func::FuncOp>(funcName);
908+
909+ if (!funcOp) {
910+ auto funcType =
911+ mlir::FunctionType::get (context, {voidPtrType, i32Type}, {voidPtrType});
912+
913+ mlir::OpBuilder::InsertionGuard guard (builder);
914+ builder.setInsertionPointToStart (module .getBody ());
915+
916+ funcOp = mlir::func::FuncOp::create (builder, loc, funcName, funcType);
917+ funcOp.setPrivate ();
918+ }
919+
920+ llvm::SmallVector<mlir::Value> args;
921+ args.push_back (fir::ConvertOp::create (builder, loc, voidPtrType, hostPtr));
922+ args.push_back (fir::ConvertOp::create (builder, loc, i32Type, deviceNum));
923+ auto callOp = fir::CallOp::create (builder, loc, funcOp, args);
924+ auto mappedPtr = callOp.getResult (0 );
925+ auto isNull = builder.genIsNullAddr (loc, mappedPtr);
926+ auto convertedHostPtr =
927+ fir::ConvertOp::create (builder, loc, voidPtrType, hostPtr);
928+ auto result = arith::SelectOp::create (builder, loc, isNull, convertedHostPtr,
929+ mappedPtr);
930+ return result;
931+ }
932+
933+ static void genOmpTargetMemcpyCall (fir::FirOpBuilder &builder,
934+ mlir::Location loc, mlir::Value dst,
935+ mlir::Value src, mlir::Value length,
936+ mlir::Value dstOffset, mlir::Value srcOffset,
937+ mlir::Value device, mlir::ModuleOp module ) {
938+ auto *context = builder.getContext ();
939+ // int omp_target_memcpy(void *dst, const void *src, size_t length,
940+ // size_t dst_offset, size_t src_offset,
941+ // int dst_device, int src_device)
942+ auto funcName = " omp_target_memcpy" ;
943+ auto voidPtrType = fir::LLVMPointerType::get (context, builder.getI8Type ());
944+ auto sizeTType = builder.getI64Type (); // assuming size_t is 64-bit
945+ auto i32Type = builder.getI32Type ();
946+ auto funcOp = module .lookupSymbol <mlir::func::FuncOp>(funcName);
947+
948+ if (!funcOp) {
949+ mlir::OpBuilder::InsertionGuard guard (builder);
950+ builder.setInsertionPointToStart (module .getBody ());
951+ llvm::SmallVector<mlir::Type> argTypes = {
952+ voidPtrType, voidPtrType, sizeTType, sizeTType,
953+ sizeTType, i32Type, i32Type};
954+ auto funcType = mlir::FunctionType::get (context, argTypes, {i32Type});
955+ funcOp = mlir::func::FuncOp::create (builder, loc, funcName, funcType);
956+ funcOp.setPrivate ();
957+ }
958+
959+ llvm::SmallVector<mlir::Value> args{dst, src, length, dstOffset,
960+ srcOffset, device, device};
961+ fir::CallOp::create (builder, loc, funcOp, args);
962+ return ;
963+ }
838964
839965// moveToHost method clones all the ops from target region outside of it.
840966// It hoists runtime functions and replaces them with omp vesions.
841967// Also hoists and replaces fir.allocmem with omp.target_allocmem and
842968// fir.freemem with omp.target_freemem
843- static void moveToHost (omp::TargetOp targetOp, RewriterBase &rewriter) {
969+ static void moveToHost (omp::TargetOp targetOp, RewriterBase &rewriter,
970+ mlir::ModuleOp module ) {
844971 OpBuilder::InsertionGuard guard (rewriter);
845972 Block *targetBlock = &targetOp.getRegion ().front ();
846973 assert (targetBlock == &targetOp.getRegion ().back ());
@@ -859,7 +986,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
859986 Value privateVar = targetOp.getPrivateVars ()[i];
860987 // The mapping should link the device-side variable to the host-side one.
861988 BlockArgument arg = targetBlock->getArguments ()[mapSize + i];
862- // Map the device-side copy (arg) to the host-side value (privateVar).
989+ // Map the device-side copy (` arg` ) to the host-side value (` privateVar` ).
863990 mapping.map (arg, privateVar);
864991 }
865992
@@ -872,69 +999,43 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
872999 for (auto it = targetBlock->begin (), end = std::prev (targetBlock->end ());
8731000 it != end; ++it) {
8741001 auto *op = &*it;
875- if (isRuntimeCall (op)) {
876- fir::CallOp runtimeCall = cast<fir::CallOp>(op);
877- auto module = runtimeCall->getParentOfType <ModuleOp>();
878- auto callee =
879- cast<func::FuncOp>(module .lookupSymbol (runtimeCall.getCalleeAttr ()));
880- std::string newCalleeName = (callee.getName () + " _omp" ).str ();
881- mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
882- func::FuncOp newCallee =
883- cast_or_null<func::FuncOp>(module .lookupSymbol (newCalleeName));
884- if (!newCallee) {
885- SmallVector<Type> argTypes (callee.getFunctionType ().getInputs ());
886- argTypes.push_back (getOmpDeviceType (rewriter.getContext ()));
887- newCallee = moduleBuilder.create <func::FuncOp>(
888- callee->getLoc (), newCalleeName,
889- FunctionType::get (rewriter.getContext (), argTypes,
890- callee.getFunctionType ().getResults ()));
891- if (callee.getArgAttrs ())
892- newCallee.setArgAttrsAttr (*callee.getArgAttrs ());
893- if (callee.getResAttrs ())
894- newCallee.setResAttrsAttr (*callee.getResAttrs ());
895- newCallee.setSymVisibility (callee.getSymVisibility ());
896- newCallee->setDiscardableAttrs (callee->getDiscardableAttrDictionary ());
897- }
898- SmallVector<Value> operands = runtimeCall.getOperands ();
899- operands.push_back (device);
900- auto tmpCall = rewriter.create <fir::CallOp>(
901- runtimeCall.getLoc (), runtimeCall.getResultTypes (),
902- SymbolRefAttr::get (newCallee), operands, nullptr , nullptr , nullptr ,
903- runtimeCall.getFastmathAttr ());
904- Operation *newCall = rewriter.clone (*tmpCall, mapping);
905- mapping.map (&*it, newCall);
906- rewriter.eraseOp (tmpCall);
907- } else {
908- Operation *clonedOp = rewriter.clone (*op, mapping);
909- for (unsigned i = 0 ; i < op->getNumResults (); ++i) {
910- mapping.map (op->getResult (i), clonedOp->getResult (i));
911- }
912- // fir.declare changes its type when hoisting it out of omp.target to
913- // omp.target_data Introduce a load, if original declareOp input is not of
914- // reference type, but cloned delcareOp input is reference type.
915- if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
916- auto originalDeclareOp = cast<fir::DeclareOp>(op);
917- Type originalInType = originalDeclareOp.getMemref ().getType ();
918- Type clonedInType = clonedDeclareOp.getMemref ().getType ();
919-
920- fir::ReferenceType originalRefType =
921- dyn_cast<fir::ReferenceType>(originalInType);
922- fir::ReferenceType clonedRefType =
923- dyn_cast<fir::ReferenceType>(clonedInType);
924- if (!originalRefType && clonedRefType) {
925- Type clonedEleTy = clonedRefType.getElementType ();
926- if (clonedEleTy == originalDeclareOp.getType ()) {
927- opsToReplace.push_back (clonedOp);
928- }
1002+ Operation *clonedOp = rewriter.clone (*op, mapping);
1003+ for (unsigned i = 0 ; i < op->getNumResults (); ++i) {
1004+ mapping.map (op->getResult (i), clonedOp->getResult (i));
1005+ }
1006+ // fir.declare changes its type when hoisting it out of omp.target to
1007+ // omp.target_data Introduce a load, if original declareOp input is not of
1008+ // reference type, but cloned delcareOp input is reference type.
1009+
1010+ if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
1011+ auto originalDeclareOp = cast<fir::DeclareOp>(op);
1012+ Type originalInType = originalDeclareOp.getMemref ().getType ();
1013+ Type clonedInType = clonedDeclareOp.getMemref ().getType ();
1014+
1015+ fir::ReferenceType originalRefType =
1016+ dyn_cast<fir::ReferenceType>(originalInType);
1017+ fir::ReferenceType clonedRefType =
1018+ dyn_cast<fir::ReferenceType>(clonedInType);
1019+ if (!originalRefType && clonedRefType) {
1020+ Type clonedEleTy = clonedRefType.getElementType ();
1021+ if (clonedEleTy == originalDeclareOp.getType ()) {
1022+ opsToReplace.push_back (clonedOp);
9291023 }
9301024 }
1025+ }
9311026 if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
9321027 opsToReplace.push_back (clonedOp);
933- }
1028+ if (isRuntimeCall (clonedOp)) {
1029+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
1030+ if ((*runtimeCall.getCallee ()).getRootReference ().getValue () ==
1031+ " _FortranAAssign" ) {
1032+ opsToReplace.push_back (clonedOp);
1033+ } else {
1034+ llvm_unreachable (" Unhandled runtime call hoisting." );
1035+ }
1036+ }
9341037 }
9351038
936- // Replace fir.allocmem with omp.target_allocmem,
937- // fir.freemem with omp.target_freemem.
9381039 for (Operation *op : opsToReplace) {
9391040 if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
9401041 rewriter.setInsertionPoint (allocOp);
@@ -963,16 +1064,40 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
9631064 Value loadedValue = rewriter.create <fir::LoadOp>(
9641065 clonedDeclareOp.getLoc (), clonedEleTy, clonedDeclareOp.getMemref ());
9651066 clonedDeclareOp.getResult ().replaceAllUsesWith (loadedValue);
1067+ } else if (isRuntimeCall (op)) {
1068+ rewriter.setInsertionPoint (op);
1069+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
1070+ SmallVector<Value> operands = runtimeCall.getOperands ();
1071+ mlir::Location loc = runtimeCall.getLoc ();
1072+ fir::FirOpBuilder builder{rewriter, op};
1073+ assert (operands.size () == 4 );
1074+ Value sourceFile{operands[2 ]}, sourceLine{operands[3 ]};
1075+
1076+ auto fromBaseAddr =
1077+ genDescriptorGetBaseAddress (builder, loc, operands[1 ]);
1078+ auto toBaseAddr = genDescriptorGetBaseAddress (builder, loc, operands[0 ]);
1079+ auto dataSizeInBytes =
1080+ genDescriptorGetDataSizeInBytes (builder, loc, operands[1 ]);
1081+
1082+ Value toPtr =
1083+ genOmpGetMappedPtrIfPresent (builder, loc, toBaseAddr, device, module );
1084+ Value fromPtr = genOmpGetMappedPtrIfPresent (builder, loc, fromBaseAddr,
1085+ device, module );
1086+ Value zero = genI64Constant (loc, rewriter, 0 );
1087+ genOmpTargetMemcpyCall (builder, loc, toPtr, fromPtr, dataSizeInBytes,
1088+ zero, zero, device, module );
1089+ rewriter.eraseOp (op);
9661090 }
9671091 }
9681092 rewriter.eraseOp (targetOp);
9691093}
9701094
971- void fissionTarget (omp::TargetOp targetOp, RewriterBase &rewriter) {
1095+ void fissionTarget (omp::TargetOp targetOp, RewriterBase &rewriter,
1096+ mlir::ModuleOp module ) {
9721097 auto tuple = getNestedOpToIsolate (targetOp);
9731098 if (!tuple) {
9741099 LLVM_DEBUG (llvm::dbgs () << " No op to isolate\n " );
975- moveToHost (targetOp, rewriter);
1100+ moveToHost (targetOp, rewriter, module );
9761101 return ;
9771102 }
9781103
@@ -982,18 +1107,18 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
9821107
9831108 if (splitBefore && splitAfter) {
9841109 auto res = isolateOp (toIsolate, splitAfter, rewriter);
985- moveToHost (res.preTargetOp , rewriter);
986- fissionTarget (res.postTargetOp , rewriter);
1110+ moveToHost (res.preTargetOp , rewriter, module );
1111+ fissionTarget (res.postTargetOp , rewriter, module );
9871112 return ;
9881113 }
9891114 if (splitBefore) {
9901115 auto res = isolateOp (toIsolate, splitAfter, rewriter);
991- moveToHost (res.preTargetOp , rewriter);
1116+ moveToHost (res.preTargetOp , rewriter, module );
9921117 return ;
9931118 }
9941119 if (splitAfter) {
9951120 auto res = isolateOp (toIsolate->getNextNode (), splitAfter, rewriter);
996- fissionTarget (res.postTargetOp , rewriter);
1121+ fissionTarget (res.postTargetOp , rewriter, module );
9971122 return ;
9981123 }
9991124}
@@ -1023,7 +1148,7 @@ class LowerWorkdistributePass
10231148 for (auto targetOp : targetOps) {
10241149 auto res = splitTargetData (targetOp, rewriter);
10251150 if (res)
1026- fissionTarget (res->targetOp , rewriter);
1151+ fissionTarget (res->targetOp , rewriter, moduleOp );
10271152 }
10281153 }
10291154 }
0 commit comments