diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h index 536abe03c6f76..09ef926dbed76 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h @@ -17,6 +17,7 @@ class LLVMSymbolLinkerInterface static void setVisibility(Operation *op, Visibility visibility); static bool isComdat(Operation *op); static std::optional getComdatSelector(Operation *op); + static LLVM::comdat::Comdat getComdatSelectionKind(Operation *op); static bool isDeclaration(Operation *op); static unsigned getBitWidth(Operation *op); static UnnamedAddr getUnnamedAddr(Operation *op); @@ -35,6 +36,9 @@ class LLVMSymbolLinkerInterface LogicalResult initialize(ModuleOp src) override; LogicalResult finalize(ModuleOp dst) const override; Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state); + LogicalResult resolveComdats(ModuleOp srcMod, + SymbolTableCollection &collection); + std::optional getComdatResolution(Operation *); template Operation *appendGlobalStructors(link::LinkState &state) { @@ -125,6 +129,8 @@ class LLVMSymbolLinkerInterface private: DataLayoutSpecInterface dtla = {}; TargetSystemSpecInterface targetSys = {}; + llvm::StringMap> + comdatResolution; }; } // namespace LLVM diff --git a/mlir/include/mlir/Linker/LLVMLinkerMixin.h b/mlir/include/mlir/Linker/LLVMLinkerMixin.h index 9bf6812d96c8f..61ae2a725ad21 100644 --- a/mlir/include/mlir/Linker/LLVMLinkerMixin.h +++ b/mlir/include/mlir/Linker/LLVMLinkerMixin.h @@ -144,6 +144,12 @@ struct ComdatSelector { ComdatKind kind; }; +struct Comdat { + ComdatKind kind; + Operation *selectorOp; + llvm::SmallPtrSet< Operation *, 2> users; +}; + //===----------------------------------------------------------------------===// // LLVMLinkerMixin //===----------------------------------------------------------------------===// @@ -172,6 +178,12 @@ class LLVMLinkerMixin { if (derived.isComdat(pair.src)) return true; + if (std::optional res = + derived.getComdatResolution(pair.src)) { + // Comdats are either used or dropped as a group + return res.value() == ConflictResolution::LinkFromSrc; + } + Linkage srcLinkage = derived.getLinkage(pair.src); // Always import variables with appending linkage. @@ -218,6 +230,10 @@ class LLVMLinkerMixin { return pair.src->emitError(error) << " dst: " << pair.dst->getLoc(); }; + if (derived.isComdat(pair.src) != derived.isComdat(pair.dst)) { + return linkError("Linking ComdatOp with non-comdat op"); + } + Linkage srcLinkage = derived.getLinkage(pair.src); Linkage dstLinkage = derived.getLinkage(pair.dst); @@ -259,6 +275,11 @@ class LLVMLinkerMixin { assert(derived.canBeLinked(pair.src) && "expected linkable operation"); assert(derived.canBeLinked(pair.dst) && "expected linkable operation"); + // We insert the computed comdat information into the dst module comdat op + // Make sure it is linked in as we want + if (derived.isComdat(pair.src)) + return ConflictResolution::LinkFromDst; + Linkage srcLinkage = derived.getLinkage(pair.src); Linkage dstLinkage = derived.getLinkage(pair.dst); @@ -352,37 +373,6 @@ class LLVMLinkerMixin { return ConflictResolution::LinkFromSrc; } - std::optional srcComdatSel = - derived.getComdatSelector(pair.src); - std::optional dstComdatSel = - derived.getComdatSelector(pair.dst); - if (srcComdatSel.has_value() && dstComdatSel.has_value()) { - auto srcComdatName = srcComdatSel->name; - auto dstComdatName = dstComdatSel->name; - auto srcComdat = srcComdatSel->kind; - auto dstComdat = dstComdatSel->kind; - if (srcComdatName != dstComdatName) { - llvm_unreachable("Comdat selector names don't match"); - } - if (srcComdat != dstComdat) { - llvm_unreachable("Comdat selector kinds don't match"); - } - - if (srcComdat == mlir::LLVM::comdat::Comdat::Any) { - return ConflictResolution::LinkFromDst; - } - if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) { - return ConflictResolution::Failure; - } - if (srcComdat == mlir::LLVM::comdat::Comdat::ExactMatch) { - return ConflictResolution::LinkFromDst; - } - if (srcComdat == mlir::LLVM::comdat::Comdat::SameSize) { - return ConflictResolution::LinkFromDst; - } - llvm_unreachable("unimplemented comdat kind"); - } - // If we reach here, we have two external definitions that can't be resolved // This is typically an error case in LLVM linking if (isExternalLinkage(srcLinkage) && isExternalLinkage(dstLinkage) && diff --git a/mlir/include/mlir/Linker/LinkerInterface.h b/mlir/include/mlir/Linker/LinkerInterface.h index 09c67490b6fe6..eea014303057b 100644 --- a/mlir/include/mlir/Linker/LinkerInterface.h +++ b/mlir/include/mlir/Linker/LinkerInterface.h @@ -143,6 +143,12 @@ class SymbolLinkerInterface : public LinkerInterface { return state.clone(src); } + /// Perform tasks that need to be computed on whole-module basis before actual summary. + /// E.g. Pre-compute COMDAT resolution before actually linking the modules. + virtual LogicalResult moduleOpSummary(ModuleOp module) { + return success(); + } + /// Dependencies of the given operation required to be linked. virtual SmallVector dependencies(Operation *op, SymbolTableCollection &collection) const = 0; @@ -276,6 +282,14 @@ class SymbolLinkerInterfaces { return Conflict::noConflict(src); } + LogicalResult moduleOpSummary(ModuleOp src) { + for (SymbolLinkerInterface *linker : interfaces) { + if (failed(linker->moduleOpSummary(src))) + return failure(); + } + return success(); + } + private: SetVector interfaces; }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp index 009d2f56d0de3..4b72eb26e419f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp @@ -108,6 +108,13 @@ LLVM::LLVMSymbolLinkerInterface::getComdatSelector(Operation *op) { return {{comdatSelector.getSymName(), comdatSelector.getComdat()}}; } +LLVM::comdat::Comdat +LLVM::LLVMSymbolLinkerInterface::getComdatSelectionKind(Operation *op) { + if (auto selector = dyn_cast(op)) + return selector.getComdat(); + llvm_unreachable("expected selector op"); +} + // Return true if the primary definition of this global value is outside of // the current translation unit. bool LLVM::LLVMSymbolLinkerInterface::isDeclaration(Operation *op) { @@ -547,6 +554,80 @@ Operation *LLVM::LLVMSymbolLinkerInterface::appendGlobals(llvm::StringRef glob, llvm_unreachable("unexpected operation"); } +static ConflictResolution computeComdatResolution() {} + +LogicalResult LLVM::LLVMSymbolLinkerInterface::resolveComdats( + ModuleOp srcMod, SymbolTableCollection &collection) { + LLVM::ComdatOp srcComdatOp; + LLVM::ComdatOp dstComdatOp; + for (auto &op : srcMod) { + if (auto comdatOp = dyn_cast(op)) { + srcComdatOp = comdatOp; + break; + } + } + + SymbolUserMap srcSymbolUsers(collection, + srcComdatOp->getParentOfType()); + // Get current resolved ComdatOp or insert srcComdatOp into summary + // TODO: use comdat summary to find conflict + if (auto it = summary.find(getSymbol(srcComdatOp)); it != summary.end()) { + dstComdatOp = cast(*it); + } else { + summary[getSymbol(srcComdatOp)] = srcComdatOp; + for (Operation &op : srcComdatOp.getBody().front()) { + ArrayRef users = srcSymbolUsers.getUsers(&op); + comdatResolution.try_emplace( + getSymbol(&op), + std::make_pair(link::Comdat{getComdatSelectionKind(&op), + &op, + {users.begin(), users.end()}}, + ConflictResolution::LinkFromSrc)); + } + return success(); + } + + for (Operation &op : srcComdatOp.getBody().front()) { + auto srcSelector = cast(op); + // TODO: use custom enum for comdat? + // If no conflict choose src + auto res = ConflictResolution::LinkFromSrc; + if (auto dstComdatIt = comdatResolution.find(getSymbol(&op)); + dstComdatIt != comdatResolution.end()) { + res = computeComdatResolution(/*TODO*/); + // remove dst ops from summary if src selected + } + switch (res) { + case ConflictResolution::LinkFromSrc: { + ArrayRef users = srcSymbolUsers.getUsers(&op); + comdatResolution.try_emplace( + getSymbol(srcSelector), + std::make_pair(link::Comdat{getComdatSelectionKind(srcSelector), + srcSelector, + {users.begin(), users.end()}}, + ConflictResolution::LinkFromSrc)); + break; + } + case ConflictResolution::LinkFromDst: + case ConflictResolution::LinkFromBothAndRenameDst: + case ConflictResolution::LinkFromBothAndRenameSrc: + case ConflictResolution::Failure: + return failure(); + } + } + return success(); +} + +std::optional +LLVM::LLVMSymbolLinkerInterface::getComdatResolution(Operation *op) { + if (hasComdat(op)) { + if (auto resIt = comdatResolution.find(getSymbol(op)); + resIt != comdatResolution.end()) + return resIt->second.second; + } + return {}; +} + //===----------------------------------------------------------------------===// // registerLinkerInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinLinkerInterface.cpp b/mlir/lib/IR/BuiltinLinkerInterface.cpp index 0ffd211c0179b..851ac7022426d 100644 --- a/mlir/lib/IR/BuiltinLinkerInterface.cpp +++ b/mlir/lib/IR/BuiltinLinkerInterface.cpp @@ -36,8 +36,11 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface { LogicalResult summarize(ModuleOp src, unsigned flags) override { WalkResult result = src.walk([&](Operation *op) { - if (op == src) + if (op == src) { + if (symbolLinkers.moduleOpSummary(src).failed()) + return WalkResult::interrupt(); return WalkResult::advance(); + } if (summarize(op, flags, /*forDependency=*/false).failed()) return WalkResult::interrupt();