Skip to content

Commit d0b0e12

Browse files
committed
[MLIR][mlir-link] Fix conflict resolution
1 parent 436aded commit d0b0e12

File tree

5 files changed

+28
-47
lines changed

5 files changed

+28
-47
lines changed

mlir/include/mlir/Linker/IRMover.h

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct ConflictPair {
2525
static ConflictPair noConflict(Operation *src) { return {nullptr, src}; }
2626
};
2727

28+
using Summary = llvm::StringMap<ConflictPair>;
29+
2830
struct IRMover {
2931

3032
ModuleOp composite;
@@ -36,7 +38,7 @@ struct IRMover {
3638

3739
explicit IRMover(ModuleOp composite) : composite(composite) {}
3840

39-
LogicalResult move(ArrayRef<ConflictPair> valuesToLink);
41+
LogicalResult move(const Summary &summary);
4042

4143
private:
4244
Operation * remap(ConflictPair pair);
@@ -46,26 +48,4 @@ struct IRMover {
4648

4749
} // namespace mlir::link
4850

49-
namespace llvm {
50-
51-
template <>
52-
struct DenseMapInfo<mlir::link::ConflictPair> {
53-
static mlir::link::ConflictPair getEmptyKey() {
54-
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
55-
return {{}, static_cast<mlir::Operation*>(pointer)};
56-
}
57-
static mlir::link::ConflictPair getTombstoneKey() {
58-
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
59-
return {{}, static_cast<mlir::Operation *>(pointer)};
60-
}
61-
static unsigned getHashValue(mlir::link::ConflictPair val) {
62-
return DenseMapInfo<const mlir::Operation *>::getHashValue(val.src);
63-
}
64-
static bool isEqual(mlir::link::ConflictPair lhs, mlir::link::ConflictPair rhs) {
65-
return lhs.src == rhs.src;
66-
}
67-
};
68-
69-
} // namespace llvm
70-
7151
#endif // MLIR_LINKER_IRMOVER_H

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -167,25 +167,30 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
167167

168168
StringRef getSymbol(Operation *op) const override { return symbol(op); }
169169

170+
ConflictPair join(ConflictPair existing, Operation *src) const {
171+
assert(!existing.hasConflict() && "expected non-conflicting pair");
172+
if (isLocalLinkage(getLinkage(existing.src)))
173+
return ConflictPair::noConflict(existing.src);
174+
return {existing.src, src};
175+
}
176+
170177
ConflictPair findConflict(Operation *src) const override {
171178
assert(canBeLinked(src) && "expected linkable operation");
172179

173180
if (isLocalLinkage(getLinkage(src)))
174181
return ConflictPair::noConflict(src);
175182

176-
// TODO: make lookup through module state
177183
if (auto it = summary.find(getSymbol(src)); it != summary.end()) {
178-
Operation *dst = it->second;
179-
if (dst != src && !isLocalLinkage(getLinkage(dst))) {
180-
return {dst, src};
181-
}
184+
return join(it->second, src);
182185
}
183186

184187
return ConflictPair::noConflict(src);
185188
}
186189

187190
bool isLinkNeeded(ConflictPair pair, bool forDependency) const override {
188191
assert(canBeLinked(pair.src) && "expected linkable operation");
192+
if (pair.src == pair.dst)
193+
return false;
189194

190195
LLVM::Linkage srcLinkage = getLinkage(pair.src);
191196

@@ -320,21 +325,13 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
320325
valuesToClone.insert(*linkFromSrc ? pair.dst : pair.src);
321326

322327
if (*linkFromSrc)
323-
registerForLink(pair);
328+
summary[getSymbol(pair.src)] = pair;
324329
return success();
325330
}
326331

327332
void registerForLink(Operation *op) override {
328333
assert(canBeLinked(op) && "expected linkable operation");
329-
registerForLink(ConflictPair::noConflict(op));
330-
}
331-
332-
void registerForLink(ConflictPair pair) {
333-
StringRef sym = getSymbol(pair.src);
334-
assert(!summary.contains(sym) && "expected unique symbol");
335-
summary[sym] = pair.src;
336-
337-
valuesToLink.insert(pair);
334+
summary[getSymbol(op)] = ConflictPair::noConflict(op);
338335
}
339336

340337
LogicalResult initialize(ModuleOp src) override {
@@ -354,7 +351,7 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
354351
// }
355352

356353
IRMover mover(dst);
357-
return mover.move(valuesToLink.getArrayRef());
354+
return mover.move(summary);
358355
}
359356

360357
Operation *prototype(ConflictPair pair) const {
@@ -374,10 +371,10 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
374371
Operation *materialize(ConflictPair pair, ModuleOp dst) const override {
375372
// Make definition if destination does not have one or has only declaration
376373
bool forDefinition = !pair.dst || isDeclaration(pair.dst);
377-
if (!forDefinition)
374+
if (!forDefinition || isDeclaration(pair.src))
378375
return prototype(pair);
379376

380-
// Definition laready exists
377+
// Definition already exists
381378
if (pair.dst && !isDeclaration(pair.dst))
382379
return pair.dst;
383380

@@ -411,13 +408,11 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
411408
}
412409

413410
private:
414-
llvm::StringMap<Operation *> summary;
415-
416411
std::unique_ptr<SymbolTable> symbolTable;
417412

418413
SetVector<Operation *> valuesToClone;
419414

420-
SetVector<ConflictPair> valuesToLink;
415+
llvm::StringMap<ConflictPair> summary;
421416
};
422417

423418
//===----------------------------------------------------------------------===//

mlir/lib/Linker/IRMover.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@
1414
using namespace mlir;
1515
using namespace mlir::link;
1616

17-
LogicalResult IRMover::move(ArrayRef<ConflictPair> valuesToLink) {
18-
worklist.insert(worklist.end(), valuesToLink.rbegin(), valuesToLink.rend());
17+
LogicalResult IRMover::move(const Summary &summary) {
18+
worklist.reserve(summary.size());
19+
for (const auto &[_, pair] : summary) {
20+
worklist.push_back(pair);
21+
}
1922

2023
while (!worklist.empty()) {
2124
ConflictPair pair = worklist.back();
2225
worklist.pop_back();
2326

24-
assert(!mapping.contains(pair.src) && "expected no mapping for source");
27+
if (mapping.contains(pair.src))
28+
continue;
2529

2630
if (!remap(pair))
2731
return failure();

mlir/test/mlir-link/linkage-b.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-link -sort-symbols -split-input-file %s | FileCheck %s
2+
// REQUIRES: mlir-link-common, mlir-link-weak
23

34
// CHECK: llvm.mlir.global common @test1_a(0 : i8) {addr_space = 0 : i32} : i8
45
// CHECK: llvm.mlir.global external @test2_a(0 : i8) {addr_space = 0 : i32} : i8

mlir/test/mlir-link/linkage-c.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-link -sort-symbols -split-input-file %s | FileCheck %s
2+
// REQUIRES: mlir-link-common, mlir-link-weak
23

34
// CHECK: llvm.mlir.global common @test1_a(0 : i8) {addr_space = 0 : i32} : i8
45
// CHECK: llvm.mlir.global external @test2_a(0 : i8) {addr_space = 0 : i32} : i8

0 commit comments

Comments
 (0)