Skip to content

Commit c682836

Browse files
AWoloszynbenvanik
andauthored
Fix an issue in ReferencePatitioning. (iree-org#21343)
When cloning operations into partitions, we intentionally only check the the users inside those partitions for dependencies. However if we cannot clone the operation into that particular partition for any other reason (hazard) then we have to rely on a dependency from another partition, we were not checking for hazards in this case. This solution is somewhat crude, although should be correct in all cases. If we are not able to clone an operation into ALL consumers, then we make sure for any partition it IS going to be cloned into, there will not be a data-hazard. A more elegant solution would be to leave it in all possible cloned partitions, but then for any other consunmers make sure they have at least one clone that would not make a cycle otherwise put it into a new partition. However that would require a re-work of the sorting logic which treats all cloned ops as interchangeable. --------- Signed-off-by: Andrew Woloszyn <[email protected]> Co-authored-by: Ben Vanik <[email protected]>
1 parent 369f992 commit c682836

File tree

2 files changed

+99
-44
lines changed

2 files changed

+99
-44
lines changed

compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ void dumpPartition(Partition &partition, AsmState &asmState) {
2323
partition.affinity.dump();
2424
llvm::dbgs() << "\n";
2525
}
26-
llvm::dbgs() << " INS:\n ";
27-
llvm::interleaveComma(partition.ins, llvm::dbgs(), [&](Value in) {
28-
in.printAsOperand(llvm::dbgs(), asmState);
29-
});
30-
llvm::dbgs() << "\n OUTS:\n ";
31-
llvm::interleaveComma(partition.outs, llvm::dbgs(), [&](Value out) {
32-
out.printAsOperand(llvm::dbgs(), asmState);
33-
});
26+
llvm::dbgs() << " INS:\n ";
27+
llvm::interleave(
28+
partition.ins, llvm::dbgs(),
29+
[&](Value in) { in.print(llvm::dbgs(), asmState); }, "\n ");
30+
llvm::dbgs() << "\n OUTS:\n ";
31+
llvm::interleave(
32+
partition.outs, llvm::dbgs(),
33+
[&](Value out) { out.print(llvm::dbgs(), asmState); }, "\n ");
3434
llvm::dbgs() << "\n OPS:\n";
3535
for (auto *op : llvm::reverse(partition.ops)) {
3636
llvm::dbgs() << " ";

compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp

Lines changed: 91 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,36 @@ static std::unique_ptr<AsmState> getRootAsmState(Block *block) {
3737
return nullptr;
3838
}
3939

40+
struct OpInfo {
41+
// Which partitions the op is contained within.
42+
llvm::BitVector membership;
43+
// Which partitions transitively depend on this operation.
44+
llvm::BitVector hazards;
45+
};
46+
47+
struct PartitionBuilder {
48+
unsigned ordinal;
49+
// Affinity of the partition.
50+
IREE::Stream::AffinityAttr affinity;
51+
// Ops present in the partition; ops may be present in multiple partitions.
52+
SetVector<Operation *> ops;
53+
// Ops that were cloned and are known not to have their values escape.
54+
DenseSet<Operation *> clonedOps;
55+
// Which partitions transitively depend on this partition.
56+
llvm::BitVector hazards;
57+
void insert(Operation *op, OpInfo &opInfo) {
58+
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
59+
affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
60+
: affinityOp.getAffinityAttr();
61+
}
62+
opInfo.membership.set(ordinal);
63+
if (opInfo.hazards.size() > ordinal)
64+
opInfo.hazards.reset(ordinal);
65+
ops.insert(op);
66+
hazards |= opInfo.hazards;
67+
}
68+
};
69+
4070
// This is terrible. See Stream/Analysis/Partition.h for a description of what
4171
// a real implementation would do. We want cost modeling for tie breakers when
4272
// an op could be in multiple partitions, cloning for ops that are not worth
@@ -46,36 +76,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
4676
Block *block) {
4777
PartitionSet partitionSet;
4878

49-
struct OpInfo {
50-
// Which partitions the op is contained within.
51-
llvm::BitVector membership;
52-
// Which partitions transitively depend on this operation.
53-
llvm::BitVector hazards;
54-
};
5579
DenseMap<Operation *, OpInfo> opInfos;
5680

57-
struct PartitionBuilder {
58-
unsigned ordinal;
59-
// Affinity of the partition.
60-
IREE::Stream::AffinityAttr affinity;
61-
// Ops present in the partition; ops may be present in multiple partitions.
62-
SetVector<Operation *> ops;
63-
// Ops that were cloned and are known not to have their values escape.
64-
DenseSet<Operation *> clonedOps;
65-
// Which partitions transitively depend on this partition.
66-
llvm::BitVector hazards;
67-
void insert(Operation *op, OpInfo &opInfo) {
68-
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
69-
affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
70-
: affinityOp.getAffinityAttr();
71-
}
72-
opInfo.membership.set(ordinal);
73-
if (opInfo.hazards.size() > ordinal)
74-
opInfo.hazards.reset(ordinal);
75-
ops.insert(op);
76-
hazards |= opInfo.hazards;
77-
}
78-
};
7981
SmallVector<std::unique_ptr<PartitionBuilder>> builders;
8082
llvm::BitVector usableBuilders;
8183

@@ -93,7 +95,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
9395
};
9496

9597
auto canAddOpToPartition = [&](Operation &op, OpInfo &opInfo,
96-
unsigned partitionOrdinal) {
98+
unsigned partitionOrdinal,
99+
bool check_for_clones = true) {
97100
auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
98101
if (!streamableOp) {
99102
return false;
@@ -111,7 +114,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
111114
return false;
112115
}
113116

114-
bool preferCloneToConsumers = streamableOp.preferCloneToConsumers();
117+
bool preferCloneToConsumers =
118+
check_for_clones && streamableOp.preferCloneToConsumers();
115119
llvm::BitVector *opHazards = nullptr;
116120
llvm::BitVector opHazardsInCandidatePartition;
117121
if (preferCloneToConsumers) {
@@ -154,6 +158,13 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
154158
llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> syncOps;
155159

156160
for (auto &op : llvm::reverse(*block)) {
161+
162+
LLVM_DEBUG({
163+
llvm::dbgs() << "====\nPartitioning op:\n";
164+
op.print(llvm::dbgs(), *asmState);
165+
llvm::dbgs() << "\n";
166+
});
167+
157168
// Skip constants; they just add noise (and since they are heavily CSE'd
158169
// they have lots of users to test).
159170
if (op.hasTrait<OpTrait::ConstantLike>()) {
@@ -190,6 +201,11 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
190201
syncOps[producer] = llvm::SmallVector<Operation *>();
191202
}
192203
syncOps[producer].push_back(&op);
204+
LLVM_DEBUG({
205+
llvm::dbgs() << "Skipping sync op for now \n";
206+
op.print(llvm::dbgs(), *asmState);
207+
llvm::dbgs() << "\n";
208+
});
193209
continue;
194210
}
195211
}
@@ -203,12 +219,6 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
203219
opInfo.hazards.reserve(builders.size() + 1);
204220
opInfo.hazards.resize(builders.size(), /*t=*/false);
205221

206-
LLVM_DEBUG({
207-
llvm::dbgs() << "====\nPartitioning op:\n";
208-
op.print(llvm::dbgs(), *asmState);
209-
llvm::dbgs() << "\n";
210-
});
211-
212222
// Set bits for each partition this op may be able to be placed into.
213223
// We prune the set based on whether the users are part of a transitive
214224
// dependency chain down the use-def chain to a partition.
@@ -242,6 +252,19 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
242252
continue;
243253
}
244254
auto &userInfo = userInfoIt->second;
255+
256+
LLVM_DEBUG({
257+
llvm::dbgs() << "Testing sync user:\n";
258+
user->print(llvm::dbgs(), *asmState);
259+
llvm::dbgs() << "\n";
260+
for (auto membershipOrdinal : userInfo.membership.set_bits()) {
261+
llvm::dbgs() << " member of partition " << membershipOrdinal
262+
<< "\n";
263+
}
264+
for (auto hazardOrdinal : userInfo.hazards.set_bits()) {
265+
llvm::dbgs() << " hazard w/ partition " << hazardOrdinal << "\n";
266+
}
267+
});
245268
opInfo.hazards |= userInfo.membership;
246269
opInfo.hazards |= userInfo.hazards;
247270
consumers.reset();
@@ -282,6 +305,34 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
282305
continue;
283306
}
284307

308+
// If we prefer to clone to our consumers, but we are
309+
// only cloning to a subset, we have to re-check our
310+
// partitions as they may generate cycles.
311+
if (streamableOp.preferCloneToConsumers()) {
312+
auto tempCandidates = candidates;
313+
tempCandidates &= consumers;
314+
if (tempCandidates.count() != consumers.count()) {
315+
// Prune candidates that do not have a compatible affinity.
316+
for (auto ordinal : candidates.set_bits()) {
317+
if (!canAddOpToPartition(op, opInfo, ordinal, false)) {
318+
LLVM_DEBUG(llvm::dbgs() << "Candidate partition " << ordinal
319+
<< " incompatible for clone\n");
320+
candidates.reset(ordinal);
321+
}
322+
}
323+
324+
for (auto syncOp : syncOps[&op]) {
325+
for (auto ordinal : candidates.set_bits()) {
326+
if (!canAddOpToPartition(*syncOp, opInfo, ordinal, false)) {
327+
LLVM_DEBUG(llvm::dbgs() << "Candidate partition " << ordinal
328+
<< " incompatible for clone\n");
329+
candidates.reset(ordinal);
330+
}
331+
}
332+
}
333+
}
334+
}
335+
285336
// First see which partitions are consuming this that we can also safely
286337
// move in to.
287338
consumers &= candidates;
@@ -321,6 +372,10 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
321372

322373
// If we have synchronization operations we can place in the last block:
323374
for (auto syncOp : syncOps[&op]) {
375+
LLVM_DEBUG(llvm::dbgs() << "Moving sync to candidate partition "
376+
<< firstCandidateOrdinal << ":\n ");
377+
LLVM_DEBUG(syncOp->print(llvm::dbgs(), *asmState));
378+
LLVM_DEBUG(llvm::dbgs() << "\n");
324379
builder->insert(syncOp, opInfo);
325380
}
326381

0 commit comments

Comments
 (0)