Skip to content

Commit a77f9c9

Browse files
committed
[MLIR][mlir-link] Make createCompositeModule part of LinkerInterface
1 parent 455d247 commit a77f9c9

File tree

6 files changed

+47
-66
lines changed

6 files changed

+47
-66
lines changed

mlir/include/mlir/Linker/IRMover.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ namespace mlir::link {
1919

2020
class IRMover {
2121
public:
22-
IRMover(Operation *composite);
22+
IRMover(mlir::ModuleOp composite) : composite(composite) {}
2323

2424
ModuleOp getComposite() { return composite; }
2525
MLIRContext *getContext() { return composite->getContext(); }
2626

2727
Error move(OwningOpRef<Operation *> src, ArrayRef<GlobalValue> valuesToLink);
2828

2929
private:
30-
ModuleOp composite;
30+
mlir::ModuleOp composite;
3131
};
3232

3333
} // namespace mlir::link

mlir/include/mlir/Linker/Linker.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,12 @@ class Linker {
7575
bool internalize = false;
7676
};
7777

78-
Linker(ModuleOp composite, const LinkerConfig &config);
78+
Linker(const LinkerConfig &config, MLIRContext *context)
79+
: config(config), context(context) {}
7980

80-
MLIRContext *getContext() { return mover.getContext(); }
81+
MLIRContext *getContext() { return context; }
8182

82-
LogicalResult linkInModule(OwningOpRef<Operation *> src,
83-
unsigned flags = None,
84-
InternalizeCallbackFn internalizeCallback = {});
83+
LogicalResult linkInModule(OwningOpRef<ModuleOp> src, unsigned flags = None);
8584

8685
unsigned getFlags() const;
8786

@@ -92,9 +91,20 @@ class Linker {
9291
/// OverrideFromSrc flag set
9392
LinkFileConfig firstFileConfig(unsigned fileFlags = None) const;
9493

94+
OwningOpRef<ModuleOp> takeModule() { return std::move(composite); }
95+
96+
LogicalResult emitFileError(const Twine &fileName, const Twine &message) {
97+
return emitError("Error processing file '" + fileName + "': " + message);
98+
}
99+
100+
LogicalResult emitError(const Twine &message) {
101+
return mlir::emitError(UnknownLoc::get(context), message);
102+
}
103+
95104
private:
96105
const LinkerConfig &config;
97-
IRMover mover;
106+
MLIRContext *context;
107+
OwningOpRef<ModuleOp> composite;
98108
};
99109

100110
} // namespace mlir::link

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ class LinkerInterface : public DialectInterface::Base<LinkerInterface> {
7474
const LinkerSummaryState *state) const {
7575
return failure();
7676
}
77+
78+
virtual OwningOpRef<ModuleOp> createCompositeModule(ModuleOp src) const {
79+
return ModuleOp::create(
80+
FileLineColLoc::get(src.getContext(), "composite", 0, 0));
81+
}
7782
};
7883

7984
struct LinkableOp {

mlir/lib/Linker/IRMover.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -617,11 +617,8 @@ Error MLIRLinker::run() {
617617
}
618618
} // namespace
619619

620-
IRMover::IRMover(Operation *composite) : composite(composite) {}
621-
622620
Error IRMover::move(OwningOpRef<Operation *> src,
623621
ArrayRef<GlobalValue> valuesToLink) {
624-
625622
MLIRLinker linker(composite, std::move(src), valuesToLink);
626623
return linker.run();
627624
}

mlir/lib/Linker/Linker.cpp

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,6 @@ class ModuleLinker {
3333
/// For symbol clashes, prefer those from src.
3434
unsigned flags;
3535

36-
/// List of global value names that should be internalized.
37-
StringSet<> internalize;
38-
39-
/// Function that will perform the actual internalization. The reason for a
40-
/// callback is that the linker cannot call internalizeModule without
41-
/// creating a circular dependency between IPO and the linker.
42-
InternalizeCallbackFn internalizeCallback;
43-
4436
ModuleOp getSourceModule() { return cast<ModuleOp>(src.get()); }
4537

4638
bool shouldOverrideFromSrc() const { return flags & Linker::OverrideFromSrc; }
@@ -85,8 +77,7 @@ class ModuleLinker {
8577
public:
8678
ModuleLinker(IRMover &mover, OwningOpRef<Operation *> src, unsigned flags,
8779
InternalizeCallbackFn internalizeCallback = {})
88-
: mover(mover), src(std::move(src)), flags(flags),
89-
internalizeCallback(std::move(internalizeCallback)) {}
80+
: mover(mover), src(std::move(src)), flags(flags) {}
9081
LogicalResult run();
9182
};
9283

@@ -235,7 +226,6 @@ bool ModuleLinker::linkIfNeeded(GlobalValue gv,
235226

236227
LogicalResult ModuleLinker::run() {
237228
LLVM_DEBUG(llvm::dbgs() << "ModuleLinker::run" << "\n");
238-
auto dst = mover.getComposite();
239229
auto src = getSourceModule();
240230

241231
std::vector<Operation *> gvToClone;
@@ -263,14 +253,6 @@ LogicalResult ModuleLinker::run() {
263253
// llvm_unreachable("unimplemented");
264254
// }
265255

266-
if (internalizeCallback) {
267-
for (GlobalValue gvl : valuesToLink) {
268-
StringRef name = gvl.getLinkedName();
269-
LLVM_DEBUG(llvm::dbgs() << "Internalizing: " << name << "\n");
270-
internalize.insert(name);
271-
}
272-
}
273-
274256
bool hasErrors = false;
275257
// TODO: We are moving whatever the local src points to here (this->src), so
276258
// it can't be touched past this point.
@@ -284,20 +266,20 @@ LogicalResult ModuleLinker::run() {
284266
if (hasErrors)
285267
return failure();
286268

287-
if (internalizeCallback) {
288-
internalizeCallback(dst, internalize);
289-
}
290-
291269
return success();
292270
}
293271

294-
Linker::Linker(ModuleOp composite, const LinkerConfig &cfg)
295-
: config(cfg), mover(composite) {}
272+
LogicalResult Linker::linkInModule(OwningOpRef<ModuleOp> src, unsigned flags) {
273+
if (!composite) {
274+
auto interface =
275+
dyn_cast_or_null<LinkerInterface>(src->getOperation()->getDialect());
276+
if (!interface)
277+
return emitError("Module does not have a linker interface");
278+
composite = interface->createCompositeModule(src.get());
279+
}
296280

297-
LogicalResult Linker::linkInModule(OwningOpRef<Operation *> src, unsigned flags,
298-
InternalizeCallbackFn internalizeCallback) {
299-
ModuleLinker modLinker(mover, std::move(src), flags,
300-
std::move(internalizeCallback));
281+
IRMover mover(composite.get());
282+
ModuleLinker modLinker(mover, std::move(src), flags);
301283
return modLinker.run();
302284
}
303285

mlir/lib/Tools/mlir-link/MlirLinkMain.cpp

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,11 @@ class FileProcessor {
189189
std::string errorMessage;
190190
auto input = openInputFile(fileName, &errorMessage);
191191
if (!input)
192-
return emitFileError(fileName, errorMessage);
192+
return linker.emitFileError(fileName, errorMessage);
193193

194194
// Process each file chunk
195195
if (failed(processFile(std::move(input), config)))
196-
return emitFileError(fileName, "Failed to process input file");
196+
return linker.emitFileError(fileName, "Failed to process input file");
197197

198198
return success();
199199
}
@@ -220,26 +220,20 @@ class FileProcessor {
220220
// Parse the source file
221221
OwningOpRef<Operation *> op =
222222
parseSourceFileForTool(sourceMgr, ctx, true /*insertImplicitModule*/);
223-
224223
ctx->enableMultithreading(wasThreadingEnabled);
225224

226225
if (!op)
227-
return emitError("Failed to parse source file");
226+
return linker.emitError("Failed to parse source file");
227+
228+
if (!isa<ModuleOp>(op.get()))
229+
return op->emitError("Expected a ModuleOp");
228230

229231
// TBD: symbol promotion
230232

231233
// TBD: internalization
232-
234+
OwningOpRef<ModuleOp> mod = cast<ModuleOp>(op.release());
233235
// Link the parsed operation
234-
return linker.linkInModule(std::move(op), config.flags);
235-
}
236-
237-
LogicalResult emitFileError(const Twine &fileName, const Twine &message) {
238-
return emitError("Error processing file '" + fileName + "': " + message);
239-
}
240-
241-
LogicalResult emitError(const Twine &message) {
242-
return mlir::emitError(mlir::UnknownLoc::get(linker.getContext()), message);
236+
return linker.linkInModule(std::move(mod), config.flags);
243237
}
244238

245239
Linker &linker;
@@ -269,22 +263,14 @@ LogicalResult mlir::MlirLinkMain(int argc, char **argv,
269263
MLIRContext context(registry);
270264
context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
271265

272-
// Create composite module
273-
OwningOpRef<ModuleOp> composite = [&context]() {
274-
OpBuilder builder(&context);
275-
return OwningOpRef<ModuleOp>(builder.create<ModuleOp>(
276-
FileLineColLoc::get(&context, "mlir-link", 0, 0)));
277-
}();
278-
279-
Linker linker(composite.get(), config);
266+
Linker linker(config, &context);
280267

281268
// Prepare output file
282269
std::string errorMessage;
283270
auto out = openOutputFile(config.outputFile, &errorMessage);
284271

285272
if (!out) {
286-
errs() << errorMessage;
287-
return failure();
273+
return linker.emitError("Failed to open output file: " + errorMessage);
288274
}
289275

290276
StringRef inMarker = config.inputSplitMarker();
@@ -296,10 +282,11 @@ LogicalResult mlir::MlirLinkMain(int argc, char **argv,
296282
if (failed(proc.linkFiles(config.inputFiles)))
297283
return failure();
298284

299-
// TODO: Remove
285+
OwningOpRef<ModuleOp> composite = linker.takeModule();
300286
if (failed(verify(composite.get(), true))) {
301-
llvm::outs() << "Verify failed\n";
287+
return composite->emitError("verification after linking failed");
302288
}
289+
303290
composite->print(out->os());
304291
out->keep();
305292

0 commit comments

Comments
 (0)