Skip to content

Commit 56a6eb2

Browse files
committed
update
1 parent a56ef5e commit 56a6eb2

File tree

6 files changed

+45
-39
lines changed

6 files changed

+45
-39
lines changed

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,6 @@ std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
200200
fir::RecordType recordType, llvm::StringRef component,
201201
mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);
202202

203-
// Convert FIR type to LLVM without turning fir.box<T> into memory
204-
// reference.
205-
mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
206-
mlir::Type firType);
207-
208203
/// Generate a LLVM constant value of type `ity`, using the provided offset.
209204
mlir::LLVM::ConstantOp
210205
genConstantIndex(mlir::Location loc, mlir::Type ity,

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ struct PrivateClauseOpConversion
127127
}
128128
};
129129

130+
static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
131+
mlir::Type firType) {
132+
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
133+
return converter.convertBoxTypeAsStruct(boxTy);
134+
return converter.convertType(firType);
135+
}
136+
130137
// FIR Op specific conversion for TargetAllocMemOp
131138
struct TargetAllocMemOpConversion
132139
: public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
@@ -139,7 +146,7 @@ struct TargetAllocMemOpConversion
139146
mlir::Location loc = allocmemOp.getLoc();
140147
auto ity = lowerTy().indexType();
141148
mlir::Type dataTy = fir::unwrapRefType(heapTy);
142-
mlir::Type llvmObjectTy = fir::convertObjectType(lowerTy(), dataTy);
149+
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
143150
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
144151
TODO(loc, "omp.target_allocmem codegen of derived type with length "
145152
"parameters");

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -762,37 +762,19 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
762762
mapping.map(arg, mapInfoOp.getVarPtr());
763763
}
764764
rewriter.setInsertionPoint(targetOp);
765-
SmallVector<Operation *> opsToMove;
765+
SmallVector<Operation *> opsToReplace;
766+
Value device = targetOp.getDevice();
767+
/*
768+
if (!device) {
769+
device = genI32Constant(targetOp.getLoc(), rewriter, 0);
770+
}
766771
for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
767772
it != end; ++it) {
768773
auto *op = &*it;
769-
auto allocOp = dyn_cast<fir::AllocMemOp>(op);
770-
auto freeOp = dyn_cast<fir::FreeMemOp>(op);
771774
fir::CallOp runtimeCall = nullptr;
772775
if (isRuntimeCall(op))
773776
runtimeCall = cast<fir::CallOp>(op);
774-
775-
if (allocOp || freeOp || runtimeCall) {
776-
Value device = targetOp.getDevice();
777-
if (!device) {
778-
device = genI32Constant(it->getLoc(), rewriter, 0);
779-
}
780-
if (allocOp) {
781-
auto tmpAllocOp = rewriter.create<fir::OmpTargetAllocMemOp>(
782-
allocOp.getLoc(), allocOp.getType(), device,
783-
allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
784-
allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
785-
allocOp.getShape());
786-
auto newAllocOp = cast<fir::OmpTargetAllocMemOp>(
787-
rewriter.clone(*tmpAllocOp.getOperation(), mapping));
788-
mapping.map(allocOp.getResult(), newAllocOp.getResult());
789-
rewriter.eraseOp(tmpAllocOp);
790-
} else if (freeOp) {
791-
auto tmpFreeOp = rewriter.create<fir::OmpTargetFreeMemOp>(
792-
freeOp.getLoc(), device, freeOp.getHeapref());
793-
rewriter.clone(*tmpFreeOp.getOperation(), mapping);
794-
rewriter.eraseOp(tmpFreeOp);
795-
} else if (runtimeCall) {
777+
if (runtimeCall) {
796778
auto module = runtimeCall->getParentOfType<ModuleOp>();
797779
auto callee = cast<func::FuncOp>(
798780
module.lookupSymbol(runtimeCall.getCalleeAttr()));
@@ -824,15 +806,42 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
824806
Operation *newCall = rewriter.clone(*tmpCall, mapping);
825807
mapping.map(&*it, newCall);
826808
rewriter.eraseOp(tmpCall);
827-
}
828809
} else {
829810
Operation *clonedOp = rewriter.clone(*op, mapping);
811+
auto allocOp = dyn_cast<fir::AllocMemOp>(clonedOp);
812+
auto freeOp = dyn_cast<fir::FreeMemOp>(clonedOp);
813+
if (allocOp || freeOp)
814+
opsToReplace.push_back(clonedOp);
830815
for (unsigned i = 0; i < op->getNumResults(); ++i) {
831816
mapping.map(op->getResult(i), clonedOp->getResult(i));
832817
}
833818
}
834819
}
820+
for (Operation* op : opsToReplace) {
821+
if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
822+
rewriter.setInsertionPoint(allocOp);
823+
auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
824+
allocOp.getLoc(), rewriter.getI64Type(), device,
825+
allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
826+
allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
827+
allocOp.getShape());
828+
auto firConvertOp = rewriter.create<fir::ConvertOp>(allocOp.getLoc(),
829+
allocOp.getResult().getType(), ompAllocmemOp.getResult());
830+
rewriter.replaceOp(allocOp, firConvertOp.getResult());
831+
}
832+
else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
833+
rewriter.setInsertionPoint(freeOp);
834+
auto firConvertOp = rewriter.create<fir::ConvertOp>(
835+
freeOp.getLoc(),
836+
rewriter.getI64Type(),
837+
freeOp.getHeapref());
838+
rewriter.create<omp::TargetFreeMemOp>(
839+
freeOp.getLoc(), device, firConvertOp.getResult());
840+
rewriter.eraseOp(freeOp);
841+
}
842+
}
835843
rewriter.eraseOp(targetOp);
844+
*/
836845
}
837846

838847
void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {

flang/lib/Optimizer/Support/Utils.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,6 @@ std::optional<llvm::ArrayRef<int64_t>> fir::getComponentLowerBoundsIfNonDefault(
5151
return std::nullopt;
5252
}
5353

54-
mlir::Type fir::convertObjectType(const fir::LLVMTypeConverter &converter,
55-
mlir::Type firType) {
56-
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
57-
return converter.convertBoxTypeAsStruct(boxTy);
58-
return converter.convertType(firType);
59-
}
60-
6154
mlir::LLVM::ConstantOp
6255
fir::genConstantIndex(mlir::Location loc, mlir::Type ity,
6356
mlir::ConversionPatternRewriter &rewriter,

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,6 +2205,7 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
22052205
Arg<I64, "", [MemFree]>:$heapref
22062206
);
22072207
let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
2208+
}
22082209

22092210
//===----------------------------------------------------------------------===//
22102211
// workdistribute Construct

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3974,6 +3974,7 @@ llvm::LogicalResult omp::TargetAllocMemOp::verify() {
39743974
if (!mlir::dyn_cast<IntegerType>(outType))
39753975
return emitOpError("must be a integer type");
39763976
return mlir::success();
3977+
}
39773978

39783979
//===----------------------------------------------------------------------===//
39793980
// WorkdistributeOp

0 commit comments

Comments
 (0)