Skip to content

Commit cdab6c1

Browse files
authored
[mlir][mlir-link] Copy in datalayout in llvm dialect linker (#77)
* [mlir][mlir-link] Copy in datalayout in llvm dialect linker
1 parent fba1ea8 commit cdab6c1

File tree

5 files changed

+44
-1
lines changed

5 files changed

+44
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class LLVMSymbolLinkerInterface
3131
StringRef getSymbol(Operation *op) const override;
3232
Operation *materialize(Operation *src, link::LinkState &state) const override;
3333
SmallVector<Operation *> dependencies(Operation *op) const override;
34+
LogicalResult initialize(ModuleOp src) override;
35+
LogicalResult finalize(ModuleOp dst) const override;
3436
Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state);
3537

3638
template <typename structor_t>
@@ -118,6 +120,10 @@ class LLVMSymbolLinkerInterface
118120
structor.setDataAttr(newDataAttr);
119121
return cloned;
120122
}
123+
124+
private:
125+
DataLayoutSpecInterface dtla = {};
126+
TargetSystemSpecInterface targetSys = {};
121127
};
122128

123129
} // namespace LLVM

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,12 @@ class LinkerInterface : public DialectInterface::Base<ConcreteType> {
8787
LinkerInterface(Dialect *dialect)
8888
: DialectInterface::Base<ConcreteType>(dialect) {}
8989

90-
/// Runs initialization of alinker before summarization for the given module
90+
/// Runs initialization of a linker before summarization for the given module
9191
virtual LogicalResult initialize(ModuleOp src) { return success(); }
9292

93+
/// Runs finalization of a linker after linking for the given module
94+
virtual LogicalResult finalize(ModuleOp dst) const { return success(); }
95+
9396
/// Link operations from current summary using state builder
9497
virtual LogicalResult link(LinkState &state) const = 0;
9598
};
@@ -255,6 +258,14 @@ class SymbolLinkerInterfaces {
255258
return success();
256259
}
257260

261+
LogicalResult finalize(ModuleOp dst) const {
262+
for (SymbolLinkerInterface *linker : interfaces) {
263+
if (failed(linker->finalize(dst)))
264+
return failure();
265+
}
266+
return success();
267+
}
268+
258269
Conflict findConflict(Operation *src) const {
259270
for (SymbolLinkerInterface *linker : interfaces) {
260271
if (auto pair = linker->findConflict(src); pair.hasConflict())

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Dialect/LLVMIR/LLVMLinkerInterface.h"
14+
#include "mlir/Dialect/DLTI/DLTI.h"
1415
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/Linker/LLVMLinkerMixin.h"
1617
#include "mlir/Linker/LinkerInterface.h"
@@ -309,6 +310,23 @@ LLVM::LLVMSymbolLinkerInterface::dependencies(Operation *op) const {
309310
return result;
310311
}
311312

313+
LogicalResult LLVM::LLVMSymbolLinkerInterface::initialize(ModuleOp src) {
314+
dtla = src.getDataLayoutSpec();
315+
targetSys = src.getTargetSystemSpec();
316+
return success();
317+
}
318+
319+
LogicalResult LLVM::LLVMSymbolLinkerInterface::finalize(ModuleOp dst) const {
320+
SmallVector<NamedAttribute, 2> newAttrs;
321+
// The names are currently hardcoded for dlti dialect
322+
// Nice solution would be preferable
323+
if (dtla)
324+
dst->setAttr(DataLayoutSpecAttr::name, dyn_cast<Attribute>(dtla));
325+
if (targetSys)
326+
dst->setAttr(TargetSystemSpecAttr::name, dyn_cast<Attribute>(targetSys));
327+
return success();
328+
}
329+
312330
static std::pair<Attribute, Type>
313331
getAppendedArrayAttr(llvm::ArrayRef<mlir::Operation *> globs,
314332
LinkState &state) {

mlir/lib/IR/BuiltinLinkerInterface.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface {
3030
return symbolLinkers.initialize(src);
3131
}
3232

33+
LogicalResult finalize(ModuleOp dst) const override {
34+
return symbolLinkers.finalize(dst);
35+
}
36+
3337
LogicalResult summarize(ModuleOp src, unsigned flags) override {
3438
WalkResult result = src.walk([&](Operation *op) {
3539
if (op == src)

mlir/lib/Linker/Linker.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ OwningOpRef<ModuleOp> Linker::link(bool sortSymbols) {
109109
symbol->moveBefore(&mod.front());
110110
}
111111
}
112+
ModuleLinkerInterface *iface = getModuleLinkerInterface(composite.get());
113+
114+
if (failed(iface->finalize(composite.get())))
115+
return nullptr;
112116

113117
return std::move(composite);
114118
}

0 commit comments

Comments
 (0)