Skip to content

Commit 5b77c4c

Browse files
frabert2over12tnytown
authored
[CIR-link] move link interface implementation into dialect to prevent circular dependency (#37)
* [CIR-link] move link interface implementation into dialect to prevent circular dependency * expose interfaces * expose conlict resolution * [CIR] Initialize flag for target dialect (#38) --------- Co-authored-by: 2over12 <[email protected]> Co-authored-by: Andrew Pan <[email protected]>
1 parent f4752e8 commit 5b77c4c

File tree

11 files changed

+162
-94
lines changed

11 files changed

+162
-94
lines changed

clang/include/clang/Frontend/FrontendOptions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ class FrontendOptions {
453453
std::string ClangIRIdiomRecognizerOpts;
454454
std::string ClangIRLibOptOpts;
455455

456-
frontend::MLIRDialectKind MLIRTargetDialect;
456+
frontend::MLIRDialectKind MLIRTargetDialect = frontend::MLIR_CORE;
457457

458458
/// The input kind, either specified via -x argument or deduced from the input
459459
/// file name.

clang/lib/CIR/Dialect/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_clang_library(MLIRCIR
66
CIRMemorySlot.cpp
77
CIRTypes.cpp
88
FPEnv.cpp
9+
CIRLinkerInterface.cpp
910

1011
DEPENDS
1112
MLIRBuiltinLocationAttributesIncGen

clang/lib/CIR/Interfaces/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
add_clang_library(MLIRCIRInterfaces
22
ASTAttrInterfaces.cpp
33
CIROpInterfaces.cpp
4-
CIRLinkerInterface.cpp
54
CIRLoopOpInterface.cpp
65
CIRFPTypeInterface.cpp
76

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ SmallVector<IntT> convertArrayToIndices(ArrayAttr attrs) {
241241
return convertArrayToIndices<IntT>(attrs.getValue());
242242
}
243243

244+
244245
/// Register the `LLVMLinkerInterface` implementation of `LinkerInterface`
245246
/// within the LLVM dialect.
246247
void registerLinkerInterface(DialectRegistry &registry);

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#define MLIR_DIALECT_LLVMIR_LLVMINTERFACES_H_
1515

1616
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
17-
1817
namespace mlir {
1918
namespace LLVM {
2019
namespace detail {
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef MLIR_DIALECT_LLVMIR_LLVMLINKERINTERFACE_H
2+
#define MLIR_DIALECT_LLVMIR_LLVMLINKERINTERFACE_H
3+
4+
#include "mlir/Linker/LLVMLinkerMixin.h"
5+
namespace mlir {
6+
namespace LLVM {
7+
8+
class LLVMSymbolLinkerInterface
9+
: public link::SymbolAttrLLVMLinkerInterface<LLVMSymbolLinkerInterface> {
10+
public:
11+
LLVMSymbolLinkerInterface(Dialect *dialect);
12+
13+
bool canBeLinked(Operation *op) const override;
14+
static Linkage getLinkage(Operation *op);
15+
static Visibility getVisibility(Operation *op);
16+
static void setVisibility(Operation *op, Visibility visibility);
17+
static bool isDeclaration(Operation *op);
18+
static unsigned getBitWidth(Operation *op);
19+
static UnnamedAddr getUnnamedAddr(Operation *op);
20+
static void setUnnamedAddr(Operation *op, UnnamedAddr val);
21+
};
22+
23+
} // namespace LLVM
24+
} // namespace mlir
25+
26+
#endif // MLIR_DIALECT_LLVMIR_LLVMLINKERINTERFACE_H

mlir/include/mlir/Linker/LLVMLinkerMixin.h

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,6 @@ static UnnamedAddr getMinUnnamedAddr(UnnamedAddr lhs, UnnamedAddr rhs) {
137137
// LLVMLinkerMixin
138138
//===----------------------------------------------------------------------===//
139139

140-
enum class ConflictResolution {
141-
LinkFromSrc,
142-
LinkFromDst,
143-
LinkFromBothAndRenameDst,
144-
LinkFromBothAndRenameSrc,
145-
};
146-
147140
template <typename DerivedLinkerInterface>
148141
class LLVMLinkerMixin {
149142
const DerivedLinkerInterface &getDerived() const {
@@ -198,7 +191,7 @@ class LLVMLinkerMixin {
198191
isAvailableExternallyLinkage(srcLinkage));
199192
}
200193

201-
LogicalResult verifyLinkageCompatibility(Conflict pair) {
194+
LogicalResult verifyLinkageCompatibility(Conflict pair) const{
202195
const DerivedLinkerInterface &derived = getDerived();
203196
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
204197
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
@@ -221,7 +214,7 @@ class LLVMLinkerMixin {
221214
return success();
222215
}
223216

224-
ConflictResolution resolveConflict(Conflict pair) {
217+
ConflictResolution getConflictResolution(Conflict pair) const {
225218
const DerivedLinkerInterface &derived = getDerived();
226219
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
227220
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
@@ -318,28 +311,14 @@ class SymbolAttrLLVMLinkerInterface
318311
return LinkerMixin::isLinkNeeded(pair, forDependency);
319312
}
320313

321-
LogicalResult resolveConflict(Conflict pair) override {
322-
if (failed(LinkerMixin::verifyLinkageCompatibility(pair)))
323-
return failure();
324-
ConflictResolution resolution = LinkerMixin::resolveConflict(pair);
325-
326-
switch (resolution) {
327-
case ConflictResolution::LinkFromSrc:
328-
registerForLink(pair.src);
329-
return success();
330-
case ConflictResolution::LinkFromDst:
331-
return success();
332-
case ConflictResolution::LinkFromBothAndRenameDst:
333-
uniqued.insert(pair.dst);
334-
registerForLink(pair.src);
335-
return success();
336-
case ConflictResolution::LinkFromBothAndRenameSrc:
337-
uniqued.insert(pair.src);
338-
return success();
339-
}
314+
LogicalResult verifyLinkageCompatibility(Conflict pair) const override {
315+
return LinkerMixin::verifyLinkageCompatibility(pair);
316+
}
340317

341-
llvm_unreachable("unimplemented conflict resolution");
318+
ConflictResolution getConflictResolution(Conflict pair) const override {
319+
return LinkerMixin::getConflictResolution(pair);
342320
}
321+
343322
};
344323

345324
} // namespace mlir::link

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/DialectInterface.h"
2020
#include "mlir/IR/IRMapping.h"
2121
#include "llvm/ADT/DenseMap.h"
22+
#include "llvm/Support/Error.h"
2223

2324
namespace mlir::link {
2425

@@ -43,6 +44,10 @@ class LinkState {
4344

4445
Operation *remapped(Operation *src) const;
4546

47+
LinkState nest(ModuleOp submod) const;
48+
49+
void updateState(const LinkState &substate);
50+
4651
private:
4752
IRMapping mapping;
4853
OpBuilder builder;
@@ -136,6 +141,12 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
136141
//===----------------------------------------------------------------------===//
137142
// SymbolAttrLinkerInterface
138143
//===----------------------------------------------------------------------===//
144+
enum class ConflictResolution {
145+
LinkFromSrc,
146+
LinkFromDst,
147+
LinkFromBothAndRenameDst,
148+
LinkFromBothAndRenameSrc,
149+
};
139150

140151
class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
141152
public:
@@ -153,6 +164,16 @@ class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
153164
/// Records a non-conflicting operation for linking.
154165
void registerForLink(Operation *op) override;
155166

167+
/// Resolves a conflict between an existing operation and a new one.
168+
LogicalResult resolveConflict(Conflict pair) override;
169+
170+
virtual LogicalResult resolveConflict(Conflict pair, ConflictResolution resolution);
171+
172+
/// Gets the conflict resolution for a given conflict
173+
virtual ConflictResolution getConflictResolution(Conflict pair) const = 0;
174+
175+
virtual LogicalResult verifyLinkageCompatibility(Conflict pair) const = 0;
176+
156177
/// Dependencies of the given operation required to be linked.
157178
SmallVector<Operation *> dependencies(Operation *op) const override;
158179

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 62 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,88 +13,88 @@
1313
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1414
#include "mlir/Linker/LLVMLinkerMixin.h"
1515
#include "mlir/Linker/LinkerInterface.h"
16-
16+
#include "mlir/Dialect/LLVMIR/LLVMLinkerInterface.h"
1717
using namespace mlir;
1818
using namespace mlir::link;
1919

2020
//===----------------------------------------------------------------------===//
2121
// LLVMSymbolLinkerInterface
2222
//===----------------------------------------------------------------------===//
2323

24-
class LLVMSymbolLinkerInterface
25-
: public SymbolAttrLLVMLinkerInterface<LLVMSymbolLinkerInterface> {
26-
public:
27-
LLVMSymbolLinkerInterface(Dialect *dialect)
28-
: SymbolAttrLLVMLinkerInterface(dialect) {}
2924

30-
bool canBeLinked(Operation *op) const override {
31-
return isa<LLVM::GlobalOp>(op) || isa<LLVM::LLVMFuncOp>(op);
32-
}
25+
26+
27+
mlir::LLVM::LLVMSymbolLinkerInterface::LLVMSymbolLinkerInterface(Dialect *dialect)
28+
: SymbolAttrLLVMLinkerInterface(dialect) {}
29+
30+
bool mlir::LLVM::LLVMSymbolLinkerInterface::canBeLinked(Operation *op) const {
31+
return isa<LLVM::GlobalOp>(op) || isa<LLVM::LLVMFuncOp>(op);
32+
}
3333

3434
//===--------------------------------------------------------------------===//
3535
// LLVMLinkerMixin required methods from derived linker interface
3636
//===--------------------------------------------------------------------===//
3737

38-
static Linkage getLinkage(Operation *op) {
39-
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
40-
return gv.getLinkage();
41-
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
42-
return fn.getLinkage();
43-
llvm_unreachable("unexpected operation");
44-
}
38+
Linkage mlir::LLVM::LLVMSymbolLinkerInterface::getLinkage(Operation *op) {
39+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
40+
return gv.getLinkage();
41+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
42+
return fn.getLinkage();
43+
llvm_unreachable("unexpected operation");
44+
}
4545

46-
static Visibility getVisibility(Operation *op) {
47-
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
48-
return gv.getVisibility_();
49-
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
50-
return fn.getVisibility_();
51-
llvm_unreachable("unexpected operation");
52-
}
46+
Visibility mlir::LLVM::LLVMSymbolLinkerInterface::getVisibility(Operation *op) {
47+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
48+
return gv.getVisibility_();
49+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
50+
return fn.getVisibility_();
51+
llvm_unreachable("unexpected operation");
52+
}
5353

54-
static void setVisibility(Operation *op, Visibility visibility) {
55-
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
56-
return gv.setVisibility_(visibility);
57-
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
58-
return fn.setVisibility_(visibility);
59-
llvm_unreachable("unexpected operation");
60-
}
54+
void mlir::LLVM::LLVMSymbolLinkerInterface::setVisibility(Operation *op, Visibility visibility) {
55+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
56+
return gv.setVisibility_(visibility);
57+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
58+
return fn.setVisibility_(visibility);
59+
llvm_unreachable("unexpected operation");
60+
}
6161

62-
// Return true if the primary definition of this global value is outside of
63-
// the current translation unit.
64-
static bool isDeclaration(Operation *op) {
65-
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
66-
return gv.getInitializerRegion().empty() && !gv.getValue();
67-
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
68-
return fn.getBody().empty();
69-
llvm_unreachable("unexpected operation");
70-
}
62+
// Return true if the primary definition of this global value is outside of
63+
// the current translation unit.
64+
bool mlir::LLVM::LLVMSymbolLinkerInterface::isDeclaration(Operation *op) {
65+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
66+
return gv.getInitializerRegion().empty() && !gv.getValue();
67+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
68+
return fn.getBody().empty();
69+
llvm_unreachable("unexpected operation");
70+
}
7171

72-
static unsigned getBitWidth(Operation *op) {
73-
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
74-
return gv.getType().getIntOrFloatBitWidth();
75-
llvm_unreachable("unexpected operation");
76-
}
72+
unsigned mlir::LLVM::LLVMSymbolLinkerInterface::getBitWidth(Operation *op) {
73+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
74+
return gv.getType().getIntOrFloatBitWidth();
75+
llvm_unreachable("unexpected operation");
76+
}
7777

78-
static UnnamedAddr getUnnamedAddr(Operation *op) {
79-
if (auto gv = dyn_cast<LLVM::GlobalOp>(op)) {
80-
auto addr = gv.getUnnamedAddr();
81-
return addr ? *addr : UnnamedAddr::Global;
82-
}
83-
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op)) {
84-
auto addr = fn.getUnnamedAddr();
85-
return addr ? *addr : UnnamedAddr::Global;
86-
}
87-
llvm_unreachable("unexpected operation");
78+
UnnamedAddr mlir::LLVM::LLVMSymbolLinkerInterface::getUnnamedAddr(Operation *op) {
79+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op)) {
80+
auto addr = gv.getUnnamedAddr();
81+
return addr ? *addr : UnnamedAddr::Global;
8882
}
89-
90-
static void setUnnamedAddr(Operation *op, UnnamedAddr val) {
91-
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
92-
return gv.setUnnamedAddr(val);
93-
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
94-
return fn.setUnnamedAddr(val);
95-
llvm_unreachable("unexpected operation");
83+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op)) {
84+
auto addr = fn.getUnnamedAddr();
85+
return addr ? *addr : UnnamedAddr::Global;
9686
}
97-
};
87+
llvm_unreachable("unexpected operation");
88+
}
89+
90+
void mlir::LLVM::LLVMSymbolLinkerInterface::setUnnamedAddr(Operation *op, UnnamedAddr val) {
91+
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
92+
return gv.setUnnamedAddr(val);
93+
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
94+
return fn.setUnnamedAddr(val);
95+
llvm_unreachable("unexpected operation");
96+
}
97+
9898

9999
//===----------------------------------------------------------------------===//
100100
// registerLinkerInterface

0 commit comments

Comments
 (0)