Skip to content

Commit c7080fb

Browse files
committed
update moveToHost implementation
1 parent 8e7147f commit c7080fb

File tree

2 files changed

+214
-87
lines changed

2 files changed

+214
-87
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 192 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -834,13 +834,140 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
834834
return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
835835
}
836836

837-
static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }
837+
/// Creates an `mlir::LLVM::ConstantOp` of type i64 holding `value`.
///
/// `value` is accepted as a 64-bit integer so that the whole i64 range can be
/// expressed; callers that pass plain `int` values keep working unchanged.
///
/// \param loc      location attached to the created constant.
/// \param rewriter rewriter used to create the operation.
/// \param value    integer value of the constant.
/// \returns the newly created constant operation.
static mlir::LLVM::ConstantOp
genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter,
               int64_t value) {
  mlir::Type i64Ty = rewriter.getI64Type();
  mlir::IntegerAttr attr = rewriter.getI64IntegerAttr(value);
  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i64Ty, attr);
}
843+
844+
/// Emits FIR that extracts the base address from the descriptor `boxDesc`.
///
/// `boxDesc` must be a `fir::BoxType` value or a `fir::ReferenceType` to one;
/// a reference is loaded first. The box is converted to a box of a rank-1,
/// unknown-extent i8 array before `fir::BoxAddrOp` is applied to it.
///
/// \returns the result of the emitted `fir::BoxAddrOp`.
static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
                                         Location loc, Value boxDesc) {
  // Descriptors may arrive by reference; load to obtain the box value itself.
  Value descriptor = boxDesc;
  if (isa<fir::ReferenceType>(boxDesc.getType()))
    descriptor = fir::LoadOp::create(builder, loc, boxDesc);
  assert(isa<fir::BoxType>(descriptor.getType()) &&
         "Unknown type passed to genDescriptorGetBaseAddress");

  // Re-type the descriptor as a box over an i8 array of unknown extent so the
  // address is taken independently of the original element type.
  auto byteArrayTy = fir::SequenceType::get(
      {fir::SequenceType::getUnknownExtent()}, builder.getI8Type());
  auto byteBoxTy = fir::BoxType::get(byteArrayTy);
  auto byteBox = fir::ConvertOp::create(builder, loc, byteBoxTy, descriptor);
  return fir::BoxAddrOp::create(builder, loc, byteBox);
}
860+
861+
/// Emits a `fir::BoxTotalElementsOp` on the descriptor `boxDesc`, with an i64
/// result type.
///
/// `boxDesc` must be a `fir::BoxType` value or a `fir::ReferenceType` to one;
/// a reference is loaded before the query is emitted.
///
/// \returns the i64 result of the emitted operation.
static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
                                           Location loc, Value boxDesc) {
  // Descriptors may arrive by reference; load to obtain the box value itself.
  Value descriptor = boxDesc;
  if (isa<fir::ReferenceType>(boxDesc.getType()))
    descriptor = fir::LoadOp::create(builder, loc, boxDesc);
  assert(isa<fir::BoxType>(descriptor.getType()) &&
         "Unknown type passed to genDescriptorGetTotalElements");
  return fir::BoxTotalElementsOp::create(builder, loc, builder.getI64Type(),
                                         descriptor);
}
872+
873+
/// Emits a `fir::BoxEleSizeOp` on the descriptor `boxDesc`, with an i64
/// result type.
///
/// `boxDesc` must be a `fir::BoxType` value or a `fir::ReferenceType` to one;
/// a reference is loaded before the query is emitted.
///
/// \returns the i64 result of the emitted operation.
static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
                                     Value boxDesc) {
  // Descriptors may arrive by reference; load to obtain the box value itself.
  Value box = boxDesc;
  if (isa<fir::ReferenceType>(boxDesc.getType()))
    box = fir::LoadOp::create(builder, loc, boxDesc);
  // The diagnostic names this function (it previously referred to a
  // non-existent `genDescriptorGetElementSize`).
  assert(isa<fir::BoxType>(box.getType()) &&
         "Unknown type passed to genDescriptorGetEleSize");
  auto i64Type = builder.getI64Type();
  return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
}
884+
885+
/// Emits FIR computing the data size described by `boxDesc` as
/// (total number of elements) * (element size).
///
/// `boxDesc` must be a `fir::BoxType` value or a `fir::ReferenceType` to one;
/// a reference is loaded once and the loaded box is shared by both queries.
///
/// \returns the result of the emitted `arith::MulIOp`.
static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
                                             Location loc, Value boxDesc) {
  // Descriptors may arrive by reference; load to obtain the box value itself.
  Value box = boxDesc;
  if (isa<fir::ReferenceType>(boxDesc.getType()))
    box = fir::LoadOp::create(builder, loc, boxDesc);
  // The diagnostic names this function (it was previously copy-pasted from
  // another helper and pointed at the wrong place).
  assert(isa<fir::BoxType>(box.getType()) &&
         "Unknown type passed to genDescriptorGetDataSizeInBytes");
  Value eleSize = genDescriptorGetEleSize(builder, loc, box);
  Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
  return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
}
897+
898+
/// Emits a call to `omp_get_mapped_ptr(hostPtr, deviceNum)` and selects
/// between its result and the host pointer.
///
/// If `module` has no `func.func` named `omp_get_mapped_ptr`, a private
/// declaration is inserted at the start of the module body. For the call,
/// `hostPtr` is converted to an i8 `fir::LLVMPointerType` and `deviceNum` to
/// i32.
///
/// \returns the call result, or the converted `hostPtr` when the call result
/// is a null address.
static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
                                               mlir::Location loc,
                                               mlir::Value hostPtr,
                                               mlir::Value deviceNum,
                                               mlir::ModuleOp module) {
  auto *ctx = builder.getContext();
  auto bytePtrTy = fir::LLVMPointerType::get(ctx, builder.getI8Type());
  auto i32Ty = builder.getI32Type();
  const char *calleeName = "omp_get_mapped_ptr";

  // Declare the runtime entry point on first use.
  auto callee = module.lookupSymbol<mlir::func::FuncOp>(calleeName);
  if (!callee) {
    // The guard restores the caller's insertion point once the declaration
    // has been placed at the top of the module.
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(module.getBody());
    auto calleeTy =
        mlir::FunctionType::get(ctx, {bytePtrTy, i32Ty}, {bytePtrTy});
    callee = mlir::func::FuncOp::create(builder, loc, calleeName, calleeTy);
    callee.setPrivate();
  }

  mlir::Value hostPtrArg =
      fir::ConvertOp::create(builder, loc, bytePtrTy, hostPtr);
  mlir::Value deviceArg =
      fir::ConvertOp::create(builder, loc, i32Ty, deviceNum);
  llvm::SmallVector<mlir::Value> callArgs{hostPtrArg, deviceArg};
  auto call = fir::CallOp::create(builder, loc, callee, callArgs);

  // Fall back to the host pointer when the call yields a null address.
  mlir::Value devicePtr = call.getResult(0);
  mlir::Value notMapped = builder.genIsNullAddr(loc, devicePtr);
  mlir::Value hostFallback =
      fir::ConvertOp::create(builder, loc, bytePtrTy, hostPtr);
  return arith::SelectOp::create(builder, loc, notMapped, hostFallback,
                                 devicePtr);
}
932+
933+
/// Emits a call to the OpenMP runtime routine `omp_target_memcpy`.
///
/// If `module` has no `func.func` named `omp_target_memcpy`, a private
/// declaration is inserted at the start of the module body. The same `device`
/// value is passed as both the destination and the source device argument;
/// the call's result is not used.
static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
                                   mlir::Location loc, mlir::Value dst,
                                   mlir::Value src, mlir::Value length,
                                   mlir::Value dstOffset, mlir::Value srcOffset,
                                   mlir::Value device, mlir::ModuleOp module) {
  // C prototype of the callee:
  //   int omp_target_memcpy(void *dst, const void *src, size_t length,
  //                         size_t dst_offset, size_t src_offset,
  //                         int dst_device, int src_device)
  const char *calleeName = "omp_target_memcpy";

  // Declare the runtime entry point on first use.
  auto callee = module.lookupSymbol<mlir::func::FuncOp>(calleeName);
  if (!callee) {
    auto *ctx = builder.getContext();
    auto bytePtrTy = fir::LLVMPointerType::get(ctx, builder.getI8Type());
    auto sizeTy = builder.getI64Type(); // assuming size_t is 64-bit
    auto i32Ty = builder.getI32Type();

    // The guard restores the caller's insertion point once the declaration
    // has been placed at the top of the module.
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(module.getBody());
    llvm::SmallVector<mlir::Type> paramTys = {
        bytePtrTy, bytePtrTy, sizeTy, sizeTy, sizeTy, i32Ty, i32Ty};
    auto calleeTy = mlir::FunctionType::get(ctx, paramTys, {i32Ty});
    callee = mlir::func::FuncOp::create(builder, loc, calleeName, calleeTy);
    callee.setPrivate();
  }

  llvm::SmallVector<mlir::Value> callArgs{dst,       src,    length, dstOffset,
                                          srcOffset, device, device};
  fir::CallOp::create(builder, loc, callee, callArgs);
}
838964

839965
// moveToHost method clones all the ops from target region outside of it.
840966
// It hoists runtime functions and replaces them with omp versions.
841967
// Also hoists and replaces fir.allocmem with omp.target_allocmem and
842968
// fir.freemem with omp.target_freemem
843-
static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
969+
static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
970+
mlir::ModuleOp module) {
844971
OpBuilder::InsertionGuard guard(rewriter);
845972
Block *targetBlock = &targetOp.getRegion().front();
846973
assert(targetBlock == &targetOp.getRegion().back());
@@ -859,7 +986,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
859986
Value privateVar = targetOp.getPrivateVars()[i];
860987
// The mapping should link the device-side variable to the host-side one.
861988
BlockArgument arg = targetBlock->getArguments()[mapSize + i];
862-
// Map the device-side copy (arg) to the host-side value (privateVar).
989+
// Map the device-side copy (`arg`) to the host-side value (`privateVar`).
863990
mapping.map(arg, privateVar);
864991
}
865992

@@ -872,69 +999,43 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
872999
for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
8731000
it != end; ++it) {
8741001
auto *op = &*it;
875-
if (isRuntimeCall(op)) {
876-
fir::CallOp runtimeCall = cast<fir::CallOp>(op);
877-
auto module = runtimeCall->getParentOfType<ModuleOp>();
878-
auto callee =
879-
cast<func::FuncOp>(module.lookupSymbol(runtimeCall.getCalleeAttr()));
880-
std::string newCalleeName = (callee.getName() + "_omp").str();
881-
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
882-
func::FuncOp newCallee =
883-
cast_or_null<func::FuncOp>(module.lookupSymbol(newCalleeName));
884-
if (!newCallee) {
885-
SmallVector<Type> argTypes(callee.getFunctionType().getInputs());
886-
argTypes.push_back(getOmpDeviceType(rewriter.getContext()));
887-
newCallee = moduleBuilder.create<func::FuncOp>(
888-
callee->getLoc(), newCalleeName,
889-
FunctionType::get(rewriter.getContext(), argTypes,
890-
callee.getFunctionType().getResults()));
891-
if (callee.getArgAttrs())
892-
newCallee.setArgAttrsAttr(*callee.getArgAttrs());
893-
if (callee.getResAttrs())
894-
newCallee.setResAttrsAttr(*callee.getResAttrs());
895-
newCallee.setSymVisibility(callee.getSymVisibility());
896-
newCallee->setDiscardableAttrs(callee->getDiscardableAttrDictionary());
897-
}
898-
SmallVector<Value> operands = runtimeCall.getOperands();
899-
operands.push_back(device);
900-
auto tmpCall = rewriter.create<fir::CallOp>(
901-
runtimeCall.getLoc(), runtimeCall.getResultTypes(),
902-
SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr,
903-
runtimeCall.getFastmathAttr());
904-
Operation *newCall = rewriter.clone(*tmpCall, mapping);
905-
mapping.map(&*it, newCall);
906-
rewriter.eraseOp(tmpCall);
907-
} else {
908-
Operation *clonedOp = rewriter.clone(*op, mapping);
909-
for (unsigned i = 0; i < op->getNumResults(); ++i) {
910-
mapping.map(op->getResult(i), clonedOp->getResult(i));
911-
}
912-
// fir.declare changes its type when hoisting it out of omp.target to
913-
// omp.target_data Introduce a load, if original declareOp input is not of
914-
// reference type, but cloned delcareOp input is reference type.
915-
if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
916-
auto originalDeclareOp = cast<fir::DeclareOp>(op);
917-
Type originalInType = originalDeclareOp.getMemref().getType();
918-
Type clonedInType = clonedDeclareOp.getMemref().getType();
919-
920-
fir::ReferenceType originalRefType =
921-
dyn_cast<fir::ReferenceType>(originalInType);
922-
fir::ReferenceType clonedRefType =
923-
dyn_cast<fir::ReferenceType>(clonedInType);
924-
if (!originalRefType && clonedRefType) {
925-
Type clonedEleTy = clonedRefType.getElementType();
926-
if (clonedEleTy == originalDeclareOp.getType()) {
927-
opsToReplace.push_back(clonedOp);
928-
}
1002+
Operation *clonedOp = rewriter.clone(*op, mapping);
1003+
for (unsigned i = 0; i < op->getNumResults(); ++i) {
1004+
mapping.map(op->getResult(i), clonedOp->getResult(i));
1005+
}
1006+
// fir.declare changes its type when hoisting it out of omp.target to
1007+
// omp.target_data. Introduce a load, if original declareOp input is not of
1008+
// reference type, but cloned declareOp input is reference type.
1009+
1010+
if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
1011+
auto originalDeclareOp = cast<fir::DeclareOp>(op);
1012+
Type originalInType = originalDeclareOp.getMemref().getType();
1013+
Type clonedInType = clonedDeclareOp.getMemref().getType();
1014+
1015+
fir::ReferenceType originalRefType =
1016+
dyn_cast<fir::ReferenceType>(originalInType);
1017+
fir::ReferenceType clonedRefType =
1018+
dyn_cast<fir::ReferenceType>(clonedInType);
1019+
if (!originalRefType && clonedRefType) {
1020+
Type clonedEleTy = clonedRefType.getElementType();
1021+
if (clonedEleTy == originalDeclareOp.getType()) {
1022+
opsToReplace.push_back(clonedOp);
9291023
}
9301024
}
1025+
}
9311026
if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
9321027
opsToReplace.push_back(clonedOp);
933-
}
1028+
if (isRuntimeCall(clonedOp)) {
1029+
fir::CallOp runtimeCall = cast<fir::CallOp>(op);
1030+
if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
1031+
"_FortranAAssign") {
1032+
opsToReplace.push_back(clonedOp);
1033+
} else {
1034+
llvm_unreachable("Unhandled runtime call hoisting.");
1035+
}
1036+
}
9341037
}
9351038

936-
// Replace fir.allocmem with omp.target_allocmem,
937-
// fir.freemem with omp.target_freemem.
9381039
for (Operation *op : opsToReplace) {
9391040
if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
9401041
rewriter.setInsertionPoint(allocOp);
@@ -963,16 +1064,40 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
9631064
Value loadedValue = rewriter.create<fir::LoadOp>(
9641065
clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
9651066
clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
1067+
} else if (isRuntimeCall(op)) {
1068+
rewriter.setInsertionPoint(op);
1069+
fir::CallOp runtimeCall = cast<fir::CallOp>(op);
1070+
SmallVector<Value> operands = runtimeCall.getOperands();
1071+
mlir::Location loc = runtimeCall.getLoc();
1072+
fir::FirOpBuilder builder{rewriter, op};
1073+
assert(operands.size() == 4);
1074+
Value sourceFile{operands[2]}, sourceLine{operands[3]};
1075+
1076+
auto fromBaseAddr =
1077+
genDescriptorGetBaseAddress(builder, loc, operands[1]);
1078+
auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]);
1079+
auto dataSizeInBytes =
1080+
genDescriptorGetDataSizeInBytes(builder, loc, operands[1]);
1081+
1082+
Value toPtr =
1083+
genOmpGetMappedPtrIfPresent(builder, loc, toBaseAddr, device, module);
1084+
Value fromPtr = genOmpGetMappedPtrIfPresent(builder, loc, fromBaseAddr,
1085+
device, module);
1086+
Value zero = genI64Constant(loc, rewriter, 0);
1087+
genOmpTargetMemcpyCall(builder, loc, toPtr, fromPtr, dataSizeInBytes,
1088+
zero, zero, device, module);
1089+
rewriter.eraseOp(op);
9661090
}
9671091
}
9681092
rewriter.eraseOp(targetOp);
9691093
}
9701094

971-
void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
1095+
void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
1096+
mlir::ModuleOp module) {
9721097
auto tuple = getNestedOpToIsolate(targetOp);
9731098
if (!tuple) {
9741099
LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
975-
moveToHost(targetOp, rewriter);
1100+
moveToHost(targetOp, rewriter, module);
9761101
return;
9771102
}
9781103

@@ -982,18 +1107,18 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
9821107

9831108
if (splitBefore && splitAfter) {
9841109
auto res = isolateOp(toIsolate, splitAfter, rewriter);
985-
moveToHost(res.preTargetOp, rewriter);
986-
fissionTarget(res.postTargetOp, rewriter);
1110+
moveToHost(res.preTargetOp, rewriter, module);
1111+
fissionTarget(res.postTargetOp, rewriter, module);
9871112
return;
9881113
}
9891114
if (splitBefore) {
9901115
auto res = isolateOp(toIsolate, splitAfter, rewriter);
991-
moveToHost(res.preTargetOp, rewriter);
1116+
moveToHost(res.preTargetOp, rewriter, module);
9921117
return;
9931118
}
9941119
if (splitAfter) {
9951120
auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter);
996-
fissionTarget(res.postTargetOp, rewriter);
1121+
fissionTarget(res.postTargetOp, rewriter, module);
9971122
return;
9981123
}
9991124
}
@@ -1023,7 +1148,7 @@ class LowerWorkdistributePass
10231148
for (auto targetOp : targetOps) {
10241149
auto res = splitTargetData(targetOp, rewriter);
10251150
if (res)
1026-
fissionTarget(res->targetOp, rewriter);
1151+
fissionTarget(res->targetOp, rewriter, moduleOp);
10271152
}
10281153
}
10291154
}

0 commit comments

Comments
 (0)