Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class LLVMSymbolLinkerInterface
LogicalResult initialize(ModuleOp src) override;
LogicalResult finalize(ModuleOp dst) const override;
Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state);
LogicalResult resolveComdats(ModuleOp srcMod,
SymbolTableCollection &collection);
std::optional<link::ConflictResolution> getComdatResolution(Operation *);

template <typename structor_t>
Operation *appendGlobalStructors(link::LinkState &state) {
Expand Down Expand Up @@ -125,6 +128,7 @@ class LLVMSymbolLinkerInterface
private:
DataLayoutSpecInterface dtla = {};
TargetSystemSpecInterface targetSys = {};
llvm::StringMap<link::ConflictResolution> comdatResolution;
};

} // namespace LLVM
Expand Down
52 changes: 21 additions & 31 deletions mlir/include/mlir/Linker/LLVMLinkerMixin.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ struct ComdatSelector {
ComdatKind kind;
};

struct Comdat {
ComdatKind kind;
StringRef name;
llvm::SmallPtrSet< Operation *, 2> users;
};

//===----------------------------------------------------------------------===//
// LLVMLinkerMixin
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -172,6 +178,12 @@ class LLVMLinkerMixin {
if (derived.isComdat(pair.src))
return true;

if (std::optional<link::ConflictResolution> res =
derived.getComdatResolution(pair.src)) {
// Comdats are either used or dropped as a group
return res.value() == ConflictResolution::LinkFromSrc;
}

Linkage srcLinkage = derived.getLinkage(pair.src);

// Always import variables with appending linkage.
Expand Down Expand Up @@ -218,6 +230,10 @@ class LLVMLinkerMixin {
return pair.src->emitError(error) << " dst: " << pair.dst->getLoc();
};

if (derived.isComdat(pair.src) != derived.isComdat(pair.dst)) {
return linkError("Linking ComdatOp with non-comdat op");
}

Linkage srcLinkage = derived.getLinkage(pair.src);
Linkage dstLinkage = derived.getLinkage(pair.dst);

Expand Down Expand Up @@ -259,6 +275,11 @@ class LLVMLinkerMixin {
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");

// We insert the computed comdat information into the dst module comdat op
// Make sure it is linked in as we want
if (derived.isComdat(pair.src))
return ConflictResolution::LinkFromDst;

Linkage srcLinkage = derived.getLinkage(pair.src);
Linkage dstLinkage = derived.getLinkage(pair.dst);

Expand Down Expand Up @@ -352,37 +373,6 @@ class LLVMLinkerMixin {
return ConflictResolution::LinkFromSrc;
}

std::optional<ComdatSelector> srcComdatSel =
derived.getComdatSelector(pair.src);
std::optional<ComdatSelector> dstComdatSel =
derived.getComdatSelector(pair.dst);
if (srcComdatSel.has_value() && dstComdatSel.has_value()) {
auto srcComdatName = srcComdatSel->name;
auto dstComdatName = dstComdatSel->name;
auto srcComdat = srcComdatSel->kind;
auto dstComdat = dstComdatSel->kind;
if (srcComdatName != dstComdatName) {
llvm_unreachable("Comdat selector names don't match");
}
if (srcComdat != dstComdat) {
llvm_unreachable("Comdat selector kinds don't match");
}

if (srcComdat == mlir::LLVM::comdat::Comdat::Any) {
return ConflictResolution::LinkFromDst;
}
if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) {
return ConflictResolution::Failure;
}
if (srcComdat == mlir::LLVM::comdat::Comdat::ExactMatch) {
return ConflictResolution::LinkFromDst;
}
if (srcComdat == mlir::LLVM::comdat::Comdat::SameSize) {
return ConflictResolution::LinkFromDst;
}
llvm_unreachable("unimplemented comdat kind");
}

// If we reach here, we have two external definitions that can't be resolved
// This is typically an error case in LLVM linking
if (isExternalLinkage(srcLinkage) && isExternalLinkage(dstLinkage) &&
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/Linker/LinkerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
return state.clone(src);
}

/// Perform tasks that need to be computed on whole-module basis before actual summary.
/// E.g. Pre-compute COMDAT resolution before actually linking the modules.
virtual LogicalResult moduleOpSummary(ModuleOp module) {
return success();
}

/// Dependencies of the given operation required to be linked.
virtual SmallVector<Operation *>
dependencies(Operation *op, SymbolTableCollection &collection) const = 0;
Expand Down Expand Up @@ -276,6 +282,14 @@ class SymbolLinkerInterfaces {
return Conflict::noConflict(src);
}

LogicalResult moduleOpSummary(ModuleOp src) {
for (SymbolLinkerInterface *linker : interfaces) {
if (failed(linker->moduleOpSummary(src)))
return failure();
}
return success();
}

private:
SetVector<SymbolLinkerInterface *> interfaces;
};
Expand Down
56 changes: 56 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,62 @@ Operation *LLVM::LLVMSymbolLinkerInterface::appendGlobals(llvm::StringRef glob,
llvm_unreachable("unexpected operation");
}

LogicalResult LLVM::LLVMSymbolLinkerInterface::resolveComdats(
ModuleOp srcMod, SymbolTableCollection &collection) {
LLVM::ComdatOp srcComdatOp;
LLVM::ComdatOp dstComdatOp;
for (auto &op : srcMod) {
if (auto comdatOp = dyn_cast<LLVM::ComdatOp>(op)) {
srcComdatOp = comdatOp;
break;
}
}
comdatResolution.clear();

// Get current resolved ComdatOp or insert srcComdatOp into summary
if (auto it = summary.find(getSymbol(srcComdatOp)); it != summary.end()) {
dstComdatOp = cast<LLVM::ComdatOp>(*it);
} else {
summary[getSymbol(srcComdatOp)] = srcComdatOp;
for (Operation &op : srcComdatOp.getBody().front())
comdatResolution.try_emplace(getSymbol(&op),
ConflictResolution::LinkFromSrc);
return success();
}

auto dstMod = dstComdatOp->getParentOfType<ModuleOp>();
SymbolTable &dstSymbolTab = collection.getSymbolTable(dstMod);
// SymbolTable &srcSymbolTab = collection.getSymbolTable(srcMod);

for (Operation &op : srcComdatOp.getBody().front()) {
auto srcSelector = cast<LLVM::ComdatSelectorOp>(op);
if (Operation *dstSelector =
dstSymbolTab.lookup(srcSelector.getSymName())) {
// compute resolution
// insert into map
// remove dst ops from summary if src selected
} else {
// If no conflict, choose src
comdatResolution.try_emplace(getSymbol(srcSelector),
ConflictResolution::LinkFromSrc);
// insert into srcComdatOp body
}
}
// insert resolved comdat into dst - update sym tab (insert ops)
assert(false);
return success();
}

std::optional<link::ConflictResolution>
LLVM::LLVMSymbolLinkerInterface::getComdatResolution(Operation *op) {
if (hasComdat(op)) {
if (auto resIt = comdatResolution.find(getSymbol(op));
resIt != comdatResolution.end())
return resIt->second;
}
return {};
}

//===----------------------------------------------------------------------===//
// registerLinkerInterface
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/IR/BuiltinLinkerInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface {

LogicalResult summarize(ModuleOp src, unsigned flags) override {
WalkResult result = src.walk([&](Operation *op) {
if (op == src)
if (op == src) {
if (symbolLinkers.moduleOpSummary(src).failed())
return WalkResult::interrupt();
return WalkResult::advance();
}

if (summarize(op, flags, /*forDependency=*/false).failed())
return WalkResult::interrupt();
Expand Down