Skip to content

Commit 575bfbc

Browse files
committed
[WIP][MLIR][mlir-link] Add COMDAT resolution to LLVM dialect linker.
1 parent 0d4dab7 commit 575bfbc

File tree

3 files changed

+81
-31
lines changed

3 files changed

+81
-31
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class LLVMSymbolLinkerInterface
3535
LogicalResult initialize(ModuleOp src) override;
3636
LogicalResult finalize(ModuleOp dst) const override;
3737
Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state);
38+
LogicalResult resolveComdats(ModuleOp srcMod,
39+
SymbolTableCollection &collection);
40+
std::optional<link::ConflictResolution> getComdatResolution(Operation *);
3841

3942
template <typename structor_t>
4043
Operation *appendGlobalStructors(link::LinkState &state) {
@@ -125,6 +128,7 @@ class LLVMSymbolLinkerInterface
125128
private:
126129
DataLayoutSpecInterface dtla = {};
127130
TargetSystemSpecInterface targetSys = {};
131+
llvm::StringMap<link::ConflictResolution> comdatResolution;
128132
};
129133

130134
} // namespace LLVM

mlir/include/mlir/Linker/LLVMLinkerMixin.h

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ struct ComdatSelector {
144144
ComdatKind kind;
145145
};
146146

147+
struct Comdat {
148+
ComdatKind kind;
149+
StringRef name;
150+
llvm::SmallPtrSet< Operation *, 2> users;
151+
};
152+
147153
//===----------------------------------------------------------------------===//
148154
// LLVMLinkerMixin
149155
//===----------------------------------------------------------------------===//
@@ -172,6 +178,12 @@ class LLVMLinkerMixin {
172178
if (derived.isComdat(pair.src))
173179
return true;
174180

181+
if (std::optional<link::ConflictResolution> res =
182+
derived.getComdatResolution(pair.src)) {
183+
// Comdats are either used or dropped as a group
184+
return res.value() == ConflictResolution::LinkFromSrc;
185+
}
186+
175187
Linkage srcLinkage = derived.getLinkage(pair.src);
176188

177189
// Always import variables with appending linkage.
@@ -218,6 +230,10 @@ class LLVMLinkerMixin {
218230
return pair.src->emitError(error) << " dst: " << pair.dst->getLoc();
219231
};
220232

233+
if (derived.isComdat(pair.src) != derived.isComdat(pair.dst)) {
234+
return linkError("Linking ComdatOp with non-comdat op");
235+
}
236+
221237
Linkage srcLinkage = derived.getLinkage(pair.src);
222238
Linkage dstLinkage = derived.getLinkage(pair.dst);
223239

@@ -259,6 +275,11 @@ class LLVMLinkerMixin {
259275
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
260276
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
261277

278+
// We insert the computed comdat information into the dst module comdat op
279+
// Make sure it is linked in as we want
280+
if (derived.isComdat(pair.src))
281+
return ConflictResolution::LinkFromDst;
282+
262283
Linkage srcLinkage = derived.getLinkage(pair.src);
263284
Linkage dstLinkage = derived.getLinkage(pair.dst);
264285

@@ -352,37 +373,6 @@ class LLVMLinkerMixin {
352373
return ConflictResolution::LinkFromSrc;
353374
}
354375

355-
std::optional<ComdatSelector> srcComdatSel =
356-
derived.getComdatSelector(pair.src);
357-
std::optional<ComdatSelector> dstComdatSel =
358-
derived.getComdatSelector(pair.dst);
359-
if (srcComdatSel.has_value() && dstComdatSel.has_value()) {
360-
auto srcComdatName = srcComdatSel->name;
361-
auto dstComdatName = dstComdatSel->name;
362-
auto srcComdat = srcComdatSel->kind;
363-
auto dstComdat = dstComdatSel->kind;
364-
if (srcComdatName != dstComdatName) {
365-
llvm_unreachable("Comdat selector names don't match");
366-
}
367-
if (srcComdat != dstComdat) {
368-
llvm_unreachable("Comdat selector kinds don't match");
369-
}
370-
371-
if (srcComdat == mlir::LLVM::comdat::Comdat::Any) {
372-
return ConflictResolution::LinkFromDst;
373-
}
374-
if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) {
375-
return ConflictResolution::Failure;
376-
}
377-
if (srcComdat == mlir::LLVM::comdat::Comdat::ExactMatch) {
378-
return ConflictResolution::LinkFromDst;
379-
}
380-
if (srcComdat == mlir::LLVM::comdat::Comdat::SameSize) {
381-
return ConflictResolution::LinkFromDst;
382-
}
383-
llvm_unreachable("unimplemented comdat kind");
384-
}
385-
386376
// If we reach here, we have two external definitions that can't be resolved
387377
// This is typically an error case in LLVM linking
388378
if (isExternalLinkage(srcLinkage) && isExternalLinkage(dstLinkage) &&

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,62 @@ Operation *LLVM::LLVMSymbolLinkerInterface::appendGlobals(llvm::StringRef glob,
547547
llvm_unreachable("unexpected operation");
548548
}
549549

550+
LogicalResult LLVM::LLVMSymbolLinkerInterface::resolveComdats(
551+
ModuleOp srcMod, SymbolTableCollection &collection) {
552+
LLVM::ComdatOp srcComdatOp;
553+
LLVM::ComdatOp dstComdatOp;
554+
for (auto &op : srcMod) {
555+
if (auto comdatOp = dyn_cast<LLVM::ComdatOp>(op)) {
556+
srcComdatOp = comdatOp;
557+
break;
558+
}
559+
}
560+
comdatResolution.clear();
561+
562+
// Get current resolved ComdatOp or insert srcComdatOp into summary
563+
if (auto it = summary.find(getSymbol(srcComdatOp)); it != summary.end()) {
564+
dstComdatOp = cast<LLVM::ComdatOp>(*it);
565+
} else {
566+
summary[getSymbol(srcComdatOp)] = srcComdatOp;
567+
for (Operation &op : srcComdatOp.getBody().front())
568+
comdatResolution.try_emplace(getSymbol(&op),
569+
ConflictResolution::LinkFromSrc);
570+
return success();
571+
}
572+
573+
auto dstMod = dstComdatOp->getParentOfType<ModuleOp>();
574+
SymbolTable &dstSymbolTab = collection.getSymbolTable(dstMod);
575+
// SymbolTable &srcSymbolTab = collection.getSymbolTable(srcMod);
576+
577+
for (Operation &op : srcComdatOp.getBody().front()) {
578+
auto srcSelector = cast<LLVM::ComdatSelectorOp>(op);
579+
if (Operation *dstSelector =
580+
dstSymbolTab.lookup(srcSelector.getSymName())) {
581+
// compute resolution
582+
// insert into map
583+
// remove dst ops from summary if src selected
584+
} else {
585+
// If no conflict, choose src
586+
comdatResolution.try_emplace(getSymbol(srcSelector),
587+
ConflictResolution::LinkFromSrc);
588+
// insert into srcComdatOp body
589+
}
590+
}
591+
// insert resolved comdat into dst - update sym tab (insert ops)
592+
assert(false);
593+
return success();
594+
}
595+
596+
std::optional<link::ConflictResolution>
597+
LLVM::LLVMSymbolLinkerInterface::getComdatResolution(Operation *op) {
598+
if (hasComdat(op)) {
599+
if (auto resIt = comdatResolution.find(getSymbol(op));
600+
resIt != comdatResolution.end())
601+
return resIt->second;
602+
}
603+
return {};
604+
}
605+
550606
//===----------------------------------------------------------------------===//
551607
// registerLinkerInterface
552608
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)