@@ -37,6 +37,36 @@ static std::unique_ptr<AsmState> getRootAsmState(Block *block) {
3737 return nullptr ;
3838}
3939
40+ struct OpInfo {
41+ // Which partitions the op is contained within.
42+ llvm::BitVector membership;
43+ // Which partitions transitively depend on this operation.
44+ llvm::BitVector hazards;
45+ };
46+
47+ struct PartitionBuilder {
48+ unsigned ordinal;
49+ // Affinity of the partition.
50+ IREE::Stream::AffinityAttr affinity;
51+ // Ops present in the partition; ops may be present in multiple partitions.
52+ SetVector<Operation *> ops;
53+ // Ops that were cloned and are known not to have their values escape.
54+ DenseSet<Operation *> clonedOps;
55+ // Which partitions transitively depend on this partition.
56+ llvm::BitVector hazards;
57+ void insert (Operation *op, OpInfo &opInfo) {
58+ if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
59+ affinity = affinity ? affinity.joinAND (affinityOp.getAffinityAttr ())
60+ : affinityOp.getAffinityAttr ();
61+ }
62+ opInfo.membership .set (ordinal);
63+ if (opInfo.hazards .size () > ordinal)
64+ opInfo.hazards .reset (ordinal);
65+ ops.insert (op);
66+ hazards |= opInfo.hazards ;
67+ }
68+ };
69+
4070// This is terrible. See Stream/Analysis/Partition.h for a description of what
4171// a real implementation would do. We want cost modeling for tie breakers when
4272// an op could be in multiple partitions, cloning for ops that are not worth
@@ -46,36 +76,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
4676 Block *block) {
4777 PartitionSet partitionSet;
4878
49- struct OpInfo {
50- // Which partitions the op is contained within.
51- llvm::BitVector membership;
52- // Which partitions transitively depend on this operation.
53- llvm::BitVector hazards;
54- };
5579 DenseMap<Operation *, OpInfo> opInfos;
5680
57- struct PartitionBuilder {
58- unsigned ordinal;
59- // Affinity of the partition.
60- IREE::Stream::AffinityAttr affinity;
61- // Ops present in the partition; ops may be present in multiple partitions.
62- SetVector<Operation *> ops;
63- // Ops that were cloned and are known not to have their values escape.
64- DenseSet<Operation *> clonedOps;
65- // Which partitions transitively depend on this partition.
66- llvm::BitVector hazards;
67- void insert (Operation *op, OpInfo &opInfo) {
68- if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
69- affinity = affinity ? affinity.joinAND (affinityOp.getAffinityAttr ())
70- : affinityOp.getAffinityAttr ();
71- }
72- opInfo.membership .set (ordinal);
73- if (opInfo.hazards .size () > ordinal)
74- opInfo.hazards .reset (ordinal);
75- ops.insert (op);
76- hazards |= opInfo.hazards ;
77- }
78- };
7981 SmallVector<std::unique_ptr<PartitionBuilder>> builders;
8082 llvm::BitVector usableBuilders;
8183
@@ -93,7 +95,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
9395 };
9496
9597 auto canAddOpToPartition = [&](Operation &op, OpInfo &opInfo,
96- unsigned partitionOrdinal) {
98+ unsigned partitionOrdinal,
99+ bool check_for_clones = true ) {
97100 auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
98101 if (!streamableOp) {
99102 return false ;
@@ -111,7 +114,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
111114 return false ;
112115 }
113116
114- bool preferCloneToConsumers = streamableOp.preferCloneToConsumers ();
117+ bool preferCloneToConsumers =
118+ check_for_clones && streamableOp.preferCloneToConsumers ();
115119 llvm::BitVector *opHazards = nullptr ;
116120 llvm::BitVector opHazardsInCandidatePartition;
117121 if (preferCloneToConsumers) {
@@ -154,6 +158,13 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
154158 llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> syncOps;
155159
156160 for (auto &op : llvm::reverse (*block)) {
161+
162+ LLVM_DEBUG ({
163+ llvm::dbgs () << " ====\n Partitioning op:\n " ;
164+ op.print (llvm::dbgs (), *asmState);
165+ llvm::dbgs () << " \n " ;
166+ });
167+
157168 // Skip constants; they just add noise (and since they are heavily CSE'd
158169 // they have lots of users to test).
159170 if (op.hasTrait <OpTrait::ConstantLike>()) {
@@ -190,6 +201,11 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
190201 syncOps[producer] = llvm::SmallVector<Operation *>();
191202 }
192203 syncOps[producer].push_back (&op);
204+ LLVM_DEBUG ({
205+ llvm::dbgs () << " Skipping sync op for now \n " ;
206+ op.print (llvm::dbgs (), *asmState);
207+ llvm::dbgs () << " \n " ;
208+ });
193209 continue ;
194210 }
195211 }
@@ -203,12 +219,6 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
203219 opInfo.hazards .reserve (builders.size () + 1 );
204220 opInfo.hazards .resize (builders.size (), /* t=*/ false );
205221
206- LLVM_DEBUG ({
207- llvm::dbgs () << " ====\n Partitioning op:\n " ;
208- op.print (llvm::dbgs (), *asmState);
209- llvm::dbgs () << " \n " ;
210- });
211-
212222 // Set bits for each partition this op may be able to be placed into.
213223 // We prune the set based on whether the users are part of a transitive
214224 // dependency chain down the use-def chain to a partition.
@@ -242,6 +252,19 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
242252 continue ;
243253 }
244254 auto &userInfo = userInfoIt->second ;
255+
256+ LLVM_DEBUG ({
257+ llvm::dbgs () << " Testing sync user:\n " ;
258+ user->print (llvm::dbgs (), *asmState);
259+ llvm::dbgs () << " \n " ;
260+ for (auto membershipOrdinal : userInfo.membership .set_bits ()) {
261+ llvm::dbgs () << " member of partition " << membershipOrdinal
262+ << " \n " ;
263+ }
264+ for (auto hazardOrdinal : userInfo.hazards .set_bits ()) {
265+ llvm::dbgs () << " hazard w/ partition " << hazardOrdinal << " \n " ;
266+ }
267+ });
245268 opInfo.hazards |= userInfo.membership ;
246269 opInfo.hazards |= userInfo.hazards ;
247270 consumers.reset ();
@@ -282,6 +305,34 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
282305 continue ;
283306 }
284307
308+ // If we prefer to clone to our consumers, but we are
309+ // only cloning to a subset, we have to re-check our
310+ // partitions as they may generate cycles.
311+ if (streamableOp.preferCloneToConsumers ()) {
312+ auto tempCandidates = candidates;
313+ tempCandidates &= consumers;
314+ if (tempCandidates.count () != consumers.count ()) {
315+ // Prune candidates that do not have a compatible affinity.
316+ for (auto ordinal : candidates.set_bits ()) {
317+ if (!canAddOpToPartition (op, opInfo, ordinal, false )) {
318+ LLVM_DEBUG (llvm::dbgs () << " Candidate partition " << ordinal
319+ << " incompatible for clone\n " );
320+ candidates.reset (ordinal);
321+ }
322+ }
323+
324+ for (auto syncOp : syncOps[&op]) {
325+ for (auto ordinal : candidates.set_bits ()) {
326+ if (!canAddOpToPartition (*syncOp, opInfo, ordinal, false )) {
327+ LLVM_DEBUG (llvm::dbgs () << " Candidate partition " << ordinal
328+ << " incompatible for clone\n " );
329+ candidates.reset (ordinal);
330+ }
331+ }
332+ }
333+ }
334+ }
335+
285336 // First see which partitions are consuming this that we can also safely
286337 // move in to.
287338 consumers &= candidates;
@@ -321,6 +372,10 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
321372
322373 // If we have synchronization operations we can place in the last block:
323374 for (auto syncOp : syncOps[&op]) {
375+ LLVM_DEBUG (llvm::dbgs () << " Moving sync to candidate partition "
376+ << firstCandidateOrdinal << " :\n " );
377+ LLVM_DEBUG (syncOp->print (llvm::dbgs (), *asmState));
378+ LLVM_DEBUG (llvm::dbgs () << " \n " );
324379 builder->insert (syncOp, opInfo);
325380 }
326381
0 commit comments