Skip to content

Commit 535025d

Browse files
authored
[CIR-link] Refactor nest/updateState (#35)
* [CIR-link] Refactor nest/updateState * Formatting
1 parent b795889 commit 535025d

File tree

3 files changed

+23
-23
lines changed

3 files changed

+23
-23
lines changed

mlir/include/mlir/Linker/LLVMLinkerMixin.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,13 @@ class LLVMLinkerMixin {
191191
isAvailableExternallyLinkage(srcLinkage));
192192
}
193193

194-
LogicalResult verifyLinkageCompatibility(Conflict pair) const{
194+
LogicalResult verifyLinkageCompatibility(Conflict pair) const {
195195
const DerivedLinkerInterface &derived = getDerived();
196196
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
197197
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
198198

199199
auto linkError = [&](const Twine &error) -> LogicalResult {
200-
return pair.src->emitError(error) << " dst: " << pair.dst->getLoc();
200+
return pair.src->emitError(error) << " dst: " << pair.dst->getLoc();
201201
};
202202

203203
Linkage srcLinkage = derived.getLinkage(pair.src);
@@ -208,13 +208,14 @@ class LLVMLinkerMixin {
208208

209209
if (isAppendingLinkage(srcLinkage) && isAppendingLinkage(dstLinkage)) {
210210
if (srcUnnamedAddr != dstUnnamedAddr) {
211-
return linkError("Appending variables with different unnamed_addr need to be linked");
211+
return linkError("Appending variables with different unnamed_addr need "
212+
"to be linked");
212213
}
213214
}
214215
return success();
215216
}
216217

217-
ConflictResolution getConflictResolution(Conflict pair) const {
218+
ConflictResolution getConflictResolution(Conflict pair) const {
218219
const DerivedLinkerInterface &derived = getDerived();
219220
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
220221
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
@@ -318,7 +319,6 @@ class SymbolAttrLLVMLinkerInterface
318319
ConflictResolution getConflictResolution(Conflict pair) const override {
319320
return LinkerMixin::getConflictResolution(pair);
320321
}
321-
322322
};
323323

324324
} // namespace mlir::link

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/IRMapping.h"
2121
#include "llvm/ADT/DenseMap.h"
2222
#include "llvm/Support/Error.h"
23+
#include <memory>
2324

2425
namespace mlir::link {
2526

@@ -35,7 +36,8 @@ enum LinkerFlags {
3536

3637
class LinkState {
3738
public:
38-
LinkState(ModuleOp dst) : builder(dst.getBodyRegion()) {}
39+
LinkState(ModuleOp dst)
40+
: mapping(std::make_shared<IRMapping>()), builder(dst.getBodyRegion()) {}
3941

4042
Operation *clone(Operation *src);
4143
Operation *cloneWithoutRegions(Operation *src);
@@ -46,10 +48,12 @@ class LinkState {
4648

4749
LinkState nest(ModuleOp submod) const;
4850

49-
void updateState(const LinkState &substate);
50-
5151
private:
52-
IRMapping mapping;
52+
// Private constructor used by nest()
53+
LinkState(ModuleOp dst, std::shared_ptr<IRMapping> mapping)
54+
: mapping(mapping), builder(dst.getBodyRegion()) {}
55+
56+
std::shared_ptr<IRMapping> mapping;
5357
OpBuilder builder;
5458
};
5559

@@ -167,7 +171,8 @@ class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
167171
/// Resolves a conflict between an existing operation and a new one.
168172
LogicalResult resolveConflict(Conflict pair) override;
169173

170-
virtual LogicalResult resolveConflict(Conflict pair, ConflictResolution resolution);
174+
virtual LogicalResult resolveConflict(Conflict pair,
175+
ConflictResolution resolution);
171176

172177
/// Gets the conflict resolution for a given conflict
173178
virtual ConflictResolution getConflictResolution(Conflict pair) const = 0;

mlir/lib/Linker/LinkerInterface.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,23 @@ using namespace mlir::link;
1818
//===----------------------------------------------------------------------===//
1919

2020
template <typename CloneFunc>
21-
Operation *cloneImpl(Operation *src, IRMapping &mapping, CloneFunc cloneFunc) {
22-
assert(!mapping.contains(src));
21+
Operation *cloneImpl(Operation *src, std::shared_ptr<IRMapping> &mapping,
22+
CloneFunc cloneFunc) {
23+
assert(!mapping->contains(src));
2324
Operation *dst = cloneFunc(src);
24-
mapping.map(src, dst);
25+
mapping->map(src, dst);
2526
return dst;
2627
}
2728

2829
Operation *LinkState::clone(Operation *src) {
2930
return cloneImpl(src, mapping, [this](Operation *op) {
30-
return builder.clone(*op, mapping);
31+
return builder.clone(*op, *mapping);
3132
});
3233
}
3334

3435
Operation *LinkState::cloneWithoutRegions(Operation *src) {
3536
return cloneImpl(src, mapping, [this](Operation *op) {
36-
return builder.cloneWithoutRegions(*op, mapping);
37+
return builder.cloneWithoutRegions(*op, *mapping);
3738
});
3839
}
3940

@@ -42,20 +43,14 @@ Operation *LinkState::getDestinationOp() const {
4243
}
4344

4445
Operation *LinkState::remapped(Operation *src) const {
45-
return mapping.lookupOrNull(src);
46+
return mapping->lookupOrNull(src);
4647
}
4748

4849
LinkState LinkState::nest(ModuleOp submod) const {
4950
assert(submod->getParentOfType<mlir::ModuleOp>().getOperation() ==
5051
getDestinationOp() &&
5152
"Submodule should be directly nested in the current state");
52-
LinkState submodState(submod);
53-
submodState.mapping = mapping;
54-
return submodState;
55-
}
56-
57-
void LinkState::updateState(const LinkState &substate) {
58-
mapping = substate.mapping;
53+
return LinkState(submod, mapping);
5954
}
6055

6156
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)