Skip to content

Commit 450f828

Browse files
committed
[MLIR][mlir-link] Add partial support for Comdat
1 parent 1251463 commit 450f828

File tree

6 files changed

+140
-28
lines changed

6 files changed

+140
-28
lines changed

clang/lib/CIR/Dialect/IR/CIRLinkerInterface.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@ class CIRSymbolLinkerInterface
7373
llvm_unreachable("unexpected operation");
7474
}
7575

76+
static bool isComdat(Operation *op) {
77+
// TODO(frabert): Extracting comdat info from CIR is not implemented yet
78+
return false;
79+
}
80+
81+
static std::optional<mlir::link::ComdatSelector> getComdatSelector(Operation *op) {
82+
// TODO(frabert): Extracting comdat info from CIR is not implemented yet
83+
return std::nullopt;
84+
}
85+
7686
// TODO: expose lowerCIRVisibilityToLLVMVisibility from LowerToLLVM.cpp
7787
static Visibility toLLVMVisibility(cir::VisibilityAttr visibility) {
7888
return toLLVMVisibility(visibility.getValue());

mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ class LLVMSymbolLinkerInterface
1515
static Linkage getLinkage(Operation *op);
1616
static Visibility getVisibility(Operation *op);
1717
static void setVisibility(Operation *op, Visibility visibility);
18+
static bool isComdat(Operation *op);
19+
static std::optional<link::ComdatSelector> getComdatSelector(Operation *op);
1820
static bool isDeclaration(Operation *op);
1921
static unsigned getBitWidth(Operation *op);
2022
static UnnamedAddr getUnnamedAddr(Operation *op);

mlir/include/mlir/Linker/LLVMLinkerMixin.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,17 @@ static UnnamedAddr getMinUnnamedAddr(UnnamedAddr lhs, UnnamedAddr rhs) {
133133
return UnnamedAddr::Global;
134134
}
135135

136+
//===----------------------------------------------------------------------===//
137+
// Comdat helpers
138+
//===----------------------------------------------------------------------===//
139+
140+
using ComdatKind = LLVM::comdat::Comdat;
141+
142+
struct ComdatSelector {
143+
StringRef name;
144+
ComdatKind kind;
145+
};
146+
136147
//===----------------------------------------------------------------------===//
137148
// LLVMLinkerMixin
138149
//===----------------------------------------------------------------------===//
@@ -158,6 +169,9 @@ class LLVMLinkerMixin {
158169
if (pair.src == pair.dst)
159170
return false;
160171

172+
if (derived.isComdat(pair.src))
173+
return true;
174+
161175
Linkage srcLinkage = derived.getLinkage(pair.src);
162176

163177
// Always import variables with appending linkage.
@@ -318,6 +332,29 @@ class LLVMLinkerMixin {
318332
return ConflictResolution::LinkFromSrc;
319333
}
320334

335+
std::optional<ComdatSelector> srcComdatSel = derived.getComdatSelector(pair.src);
336+
std::optional<ComdatSelector> dstComdatSel = derived.getComdatSelector(pair.dst);
337+
if (srcComdatSel.has_value() && dstComdatSel.has_value()) {
338+
auto srcComdatName = srcComdatSel->name;
339+
auto dstComdatName = dstComdatSel->name;
340+
auto srcComdat = srcComdatSel->kind;
341+
auto dstComdat = dstComdatSel->kind;
342+
if (srcComdatName != dstComdatName) {
343+
llvm_unreachable("Comdat selector names don't match");
344+
}
345+
if (srcComdat != dstComdat) {
346+
llvm_unreachable("Comdat selector kinds don't match");
347+
}
348+
349+
if (srcComdat == mlir::LLVM::comdat::Comdat::Any) {
350+
return ConflictResolution::LinkFromDst;
351+
}
352+
if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) {
353+
return ConflictResolution::Failure;
354+
}
355+
llvm_unreachable("unimplemented comdat kind");
356+
}
357+
321358
llvm_unreachable("unimplemented conflict resolution");
322359
}
323360
};

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ enum class ConflictResolution {
152152
LinkFromDst,
153153
LinkFromBothAndRenameDst,
154154
LinkFromBothAndRenameSrc,
155+
Failure,
155156
};
156157

157158
class SymbolAttrLinkerInterface : public SymbolLinkerInterface {

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ LLVM::LLVMSymbolLinkerInterface::LLVMSymbolLinkerInterface(Dialect *dialect)
2727

2828
bool LLVM::LLVMSymbolLinkerInterface::canBeLinked(Operation *op) const {
2929
return isa<LLVM::GlobalOp, LLVM::LLVMFuncOp, LLVM::GlobalCtorsOp,
30-
LLVM::GlobalDtorsOp>(op);
30+
LLVM::GlobalDtorsOp, LLVM::ComdatOp>(op);
3131
}
3232

3333
//===--------------------------------------------------------------------===//
@@ -39,7 +39,7 @@ Linkage LLVM::LLVMSymbolLinkerInterface::getLinkage(Operation *op) {
3939
return gv.getLinkage();
4040
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
4141
return fn.getLinkage();
42-
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp>(op))
42+
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(op))
4343
return Linkage::Appending;
4444
llvm_unreachable("unexpected operation");
4545
}
@@ -49,7 +49,8 @@ Visibility LLVM::LLVMSymbolLinkerInterface::getVisibility(Operation *op) {
4949
return gv.getVisibility_();
5050
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
5151
return fn.getVisibility_();
52-
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp>(op))
52+
53+
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(op))
5354
return Visibility::Default;
5455
llvm_unreachable("unexpected operation");
5556
}
@@ -62,19 +63,50 @@ void LLVM::LLVMSymbolLinkerInterface::setVisibility(Operation *op,
6263
return fn.setVisibility_(visibility);
6364
// GlobalCotrs and Dtors are defined as operations in mlir
6465
// but as globals in LLVM IR
65-
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp>(op))
66+
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(op))
6667
return;
6768
llvm_unreachable("unexpected operation");
6869
}
6970

71+
static bool hasComdat(Operation *op) {
72+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
73+
return gv.getComdat().has_value();
74+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
75+
return fn.getComdat().has_value();
76+
llvm_unreachable("unexpected operation");
77+
}
78+
79+
static SymbolRefAttr getComdatSymbol(Operation *op) {
80+
assert(hasComdat(op) && "Operation with Comdat expected");
81+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
82+
return gv.getComdat().value();
83+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
84+
return fn.getComdat().value();
85+
llvm_unreachable("unexpected operation");
86+
}
87+
88+
bool LLVM::LLVMSymbolLinkerInterface::isComdat(Operation *op) {
89+
return isa<LLVM::ComdatOp>(op);
90+
}
91+
92+
std::optional<mlir::link::ComdatSelector> LLVM::LLVMSymbolLinkerInterface::getComdatSelector(Operation *op) {
93+
if (!hasComdat(op))
94+
return std::nullopt;
95+
96+
auto symbol = getComdatSymbol(op);
97+
auto *symTabOp = SymbolTable::getNearestSymbolTable(op);
98+
auto comdatSelector = cast<mlir::LLVM::ComdatSelectorOp>(SymbolTable::lookupSymbolIn(symTabOp, symbol));
99+
return {{comdatSelector.getSymName(), comdatSelector.getComdat()}};
100+
}
101+
70102
// Return true if the primary definition of this global value is outside of
71103
// the current translation unit.
72104
bool LLVM::LLVMSymbolLinkerInterface::isDeclaration(Operation *op) {
73105
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
74106
return gv.getInitializerRegion().empty() && !gv.getValue();
75107
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
76108
return fn.getBody().empty();
77-
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp>(op))
109+
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(op))
78110
return false;
79111
llvm_unreachable("unexpected operation");
80112
}
@@ -97,7 +129,7 @@ UnnamedAddr LLVM::LLVMSymbolLinkerInterface::getUnnamedAddr(Operation *op) {
97129
auto addr = fn.getUnnamedAddr();
98130
return addr ? *addr : UnnamedAddr::Global;
99131
}
100-
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp>(op))
132+
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(op))
101133
return UnnamedAddr::Global;
102134
llvm_unreachable("unexpected operation");
103135
}
@@ -110,7 +142,7 @@ void LLVM::LLVMSymbolLinkerInterface::setUnnamedAddr(Operation *op,
110142
return fn.setUnnamedAddr(val);
111143
// GlobalCotrs and Dtors are defined as operations in mlir
112144
// but as globals in LLVM IR
113-
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp>(op))
145+
if (isa<LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(op))
114146
return;
115147
llvm_unreachable("unexpected operation");
116148
}
@@ -360,39 +392,67 @@ getAppendedOpWithInitRegion(llvm::ArrayRef<mlir::Operation *> globs,
360392
return targetGV;
361393
}
362394

363-
Operation *LLVM::LLVMSymbolLinkerInterface::appendGlobals(llvm::StringRef glob,
364-
LinkState &state) {
365-
if (glob == "llvm.global_ctors")
366-
return appendGlobalStructors<LLVM::GlobalCtorsOp>(state);
367-
if (glob == "llvm.global_dtors")
368-
return appendGlobalStructors<LLVM::GlobalDtorsOp>(state);
369-
370-
const auto &globs = append.lookup(glob);
371-
auto lastGV = dyn_cast<LLVM::GlobalOp>(globs.back());
372-
if (!lastGV)
373-
llvm_unreachable("unexpected operation");
374-
395+
static Operation *appendGlobalOps(ArrayRef<Operation *> globs, LLVM::GlobalOp lastGV, LinkState &state) {
375396
// Src ops that are declarations are ignored in favour of dst operation
376397
// This mimics the behaviour of linkAppendingVarProto in llvm-link
377398
if (globs.size() == 1)
378-
return state.clone(globs.front());
399+
return state.clone(globs.front());
379400

380401
if (!lastGV.getInitializer().empty()) {
381-
return getAppendedOpWithInitRegion(globs, state);
402+
return getAppendedOpWithInitRegion(globs, state);
382403
} else {
383-
auto [value, type] = getAppendedAttr(globs, state);
404+
auto [value, type] = getAppendedAttr(globs, state);
384405

385-
auto valueAttrName = lastGV.getValueAttrName();
386-
auto typeAttrName = lastGV.getGlobalTypeAttrName();
406+
auto valueAttrName = lastGV.getValueAttrName();
407+
auto typeAttrName = lastGV.getGlobalTypeAttrName();
387408

388-
auto cloned = state.clone(globs.back());
389-
cloned->setAttr(valueAttrName, value);
390-
cloned->setAttr(typeAttrName, TypeAttr::get(type));
391-
return cloned;
409+
auto cloned = state.clone(globs.back());
410+
cloned->setAttr(valueAttrName, value);
411+
cloned->setAttr(typeAttrName, TypeAttr::get(type));
412+
return cloned;
392413
}
393414
llvm_unreachable("unknown value attribute type");
394415
}
395416

417+
static Operation *appendComdatOps(ArrayRef<Operation *> globs, LLVM::ComdatOp comdat, LinkState &state) {
418+
auto result = cast<LLVM::ComdatOp>(state.clone(comdat));
419+
llvm::StringMap<Operation *> selectors;
420+
421+
for (auto selector : result.getOps<LLVM::ComdatSelectorOp>()) {
422+
selectors[selector.getSymName()] = selector;
423+
}
424+
425+
for (auto *glob : globs) {
426+
comdat = dyn_cast<LLVM::ComdatOp>(glob);
427+
for (auto &op : comdat.getBody().getOps()) {
428+
auto selector = cast<LLVM::ComdatSelectorOp>(op);
429+
auto selectorName = selector.getSymName();
430+
if (selectors.contains(selectorName)) {
431+
continue;
432+
}
433+
auto *cloned = state.clone(selector);
434+
cloned->moveBefore(&result.getBody().front().back());
435+
selectors[selectorName] = cloned;
436+
}
437+
}
438+
return result;
439+
}
440+
441+
Operation *LLVM::LLVMSymbolLinkerInterface::appendGlobals(llvm::StringRef glob,
442+
LinkState &state) {
443+
if (glob == "llvm.global_ctors")
444+
return appendGlobalStructors<LLVM::GlobalCtorsOp>(state);
445+
if (glob == "llvm.global_dtors")
446+
return appendGlobalStructors<LLVM::GlobalDtorsOp>(state);
447+
448+
const auto &globs = append.lookup(glob);
449+
if (auto lastGV = dyn_cast<LLVM::GlobalOp>(globs.back()))
450+
return appendGlobalOps(globs, lastGV, state);
451+
if (auto comdat = dyn_cast<LLVM::ComdatOp>(globs.back()))
452+
return appendComdatOps(globs, comdat, state);
453+
llvm_unreachable("unexpected operation");
454+
}
455+
396456
//===----------------------------------------------------------------------===//
397457
// registerLinkerInterface
398458
//===----------------------------------------------------------------------===//

mlir/lib/Linker/LinkerInterface.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ SymbolAttrLinkerInterface::resolveConflict(Conflict pair,
196196
case ConflictResolution::LinkFromBothAndRenameSrc:
197197
uniqued.insert(pair.src);
198198
return success();
199+
case ConflictResolution::Failure:
200+
return pair.src->emitError("Linker error");
199201
}
200202

201203
llvm_unreachable("unimplemented conflict resolution");

0 commit comments

Comments
 (0)