117
117
#include " mlir/IR/TypeUtilities.h"
118
118
#include " llvm/ADT/DenseSet.h"
119
119
#include " llvm/ADT/SetVector.h"
120
- #include " llvm/Support/Debug.h"
121
- #include " llvm/Support/FormatVariadic.h"
122
-
123
- #define DEBUG_TYPE " comprehensive-module-bufferize"
124
120
125
121
using namespace mlir ;
126
122
using namespace linalg ;
127
123
using namespace tensor ;
128
124
using namespace comprehensive_bufferize ;
129
125
130
- #define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
131
- #define LDBG (X ) LLVM_DEBUG(DBGS() << X)
132
-
133
- // Forward declarations.
134
- #ifndef NDEBUG
135
- static std::string printOperationInfo (Operation *, bool prefix = true );
136
- static std::string printValueInfo (Value, bool prefix = true );
137
- #endif
138
-
139
126
static bool isaTensor (Type t) { return t.isa <TensorType>(); }
140
127
141
128
// ===----------------------------------------------------------------------===//
@@ -164,64 +151,11 @@ static void setInPlaceOpResult(OpResult opResult, bool inPlace) {
164
151
attr ? SmallVector<StringRef>(
165
152
llvm::to_vector<4 >(attr.getAsValueRange <StringAttr>()))
166
153
: SmallVector<StringRef>(op->getNumResults (), " false" );
167
- LDBG (" ->set inPlace=" << inPlace << " <- #" << opResult.getResultNumber ()
168
- << " : " << printOperationInfo (op) << " \n " );
169
154
inPlaceVector[opResult.getResultNumber ()] = inPlace ? " true" : " false" ;
170
155
op->setAttr (kInPlaceResultsAttrName ,
171
156
OpBuilder (op).getStrArrayAttr (inPlaceVector));
172
157
}
173
158
174
- // ===----------------------------------------------------------------------===//
175
- // Printing helpers.
176
- // ===----------------------------------------------------------------------===//
177
-
178
- #ifndef NDEBUG
179
- // / Helper method printing the bufferization information of a buffer / tensor.
180
- static void printTensorOrBufferInfo (std::string prefix, Value value,
181
- AsmState &state, llvm::raw_ostream &os) {
182
- if (!value.getType ().isa <ShapedType>())
183
- return ;
184
- os << prefix;
185
- value.printAsOperand (os, state);
186
- os << " : " << value.getType ();
187
- }
188
-
189
- // / Print the operation name and bufferization information.
190
- static std::string printOperationInfo (Operation *op, bool prefix) {
191
- std::string result;
192
- llvm::raw_string_ostream os (result);
193
- AsmState state (op->getParentOfType <mlir::FuncOp>());
194
- StringRef tab = prefix ? " \n [" DEBUG_TYPE " ]\t " : " " ;
195
- os << tab << op->getName ();
196
- SmallVector<Value> shapedOperands;
197
- for (OpOperand &opOperand : op->getOpOperands ()) {
198
- std::string prefix =
199
- llvm::formatv (" {0} -> #{1} " , tab, opOperand.getOperandNumber ());
200
- printTensorOrBufferInfo (prefix, opOperand.get (), state, os);
201
- }
202
- for (OpResult opResult : op->getOpResults ()) {
203
- std::string prefix =
204
- llvm::formatv (" {0} <- #{1} " , tab, opResult.getResultNumber ());
205
- printTensorOrBufferInfo (prefix, opResult, state, os);
206
- }
207
- return result;
208
- }
209
-
210
- // / Print the bufferization information for the defining op or block argument.
211
- static std::string printValueInfo (Value value, bool prefix) {
212
- auto *op = value.getDefiningOp ();
213
- if (op)
214
- return printOperationInfo (op, prefix);
215
- // Print the block argument bufferization information.
216
- std::string result;
217
- llvm::raw_string_ostream os (result);
218
- AsmState state (value.getParentRegion ()->getParentOfType <mlir::FuncOp>());
219
- os << value;
220
- printTensorOrBufferInfo (" \n\t - " , value, state, os);
221
- return result;
222
- }
223
- #endif
224
-
225
159
// ===----------------------------------------------------------------------===//
226
160
// Bufferization-specific alias analysis.
227
161
// ===----------------------------------------------------------------------===//
@@ -251,7 +185,6 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
251
185
static bool aliasesNonWritableBuffer (Value value,
252
186
const BufferizationAliasInfo &aliasInfo,
253
187
BufferizationState &state) {
254
- LDBG (" WRITABILITY ANALYSIS FOR " << printValueInfo (value) << " \n " );
255
188
bool foundNonWritableBuffer = false ;
256
189
aliasInfo.applyOnAliases (value, [&](Value v) {
257
190
// Query BufferizableOpInterface to see if the OpResult is writable.
@@ -270,35 +203,22 @@ static bool aliasesNonWritableBuffer(Value value,
270
203
foundNonWritableBuffer = true ;
271
204
});
272
205
273
- if (foundNonWritableBuffer)
274
- LDBG (" --> NON WRITABLE\n " );
275
- else
276
- LDBG (" --> WRITABLE\n " );
277
-
278
206
return foundNonWritableBuffer;
279
207
}
280
208
281
209
// / Return true if the buffer to which `operand` would bufferize is equivalent
282
210
// / to some buffer write.
283
211
static bool aliasesInPlaceWrite (Value value,
284
212
const BufferizationAliasInfo &aliasInfo) {
285
- LDBG (" ----Start aliasesInPlaceWrite\n " );
286
- LDBG (" -------for : " << printValueInfo (value) << ' \n ' );
287
213
bool foundInplaceWrite = false ;
288
214
aliasInfo.applyOnAliases (value, [&](Value v) {
289
215
for (auto &use : v.getUses ()) {
290
216
if (isInplaceMemoryWrite (use, aliasInfo)) {
291
- LDBG (" -----------wants to bufferize to inPlace write: "
292
- << printOperationInfo (use.getOwner ()) << ' \n ' );
293
217
foundInplaceWrite = true ;
294
218
return ;
295
219
}
296
220
}
297
221
});
298
-
299
- if (!foundInplaceWrite)
300
- LDBG (" ----------->does not alias an inplace write\n " );
301
-
302
222
return foundInplaceWrite;
303
223
}
304
224
@@ -317,6 +237,39 @@ static bool happensBefore(Operation *a, Operation *b,
317
237
return false ;
318
238
}
319
239
240
+ // / Annotate IR with details about the detected RaW conflict.
241
+ static void annotateConflict (OpOperand *uRead, OpOperand *uConflictingWrite,
242
+ Value lastWrite) {
243
+ static uint64_t counter = 0 ;
244
+ Operation *readingOp = uRead->getOwner ();
245
+ Operation *conflictingWritingOp = uConflictingWrite->getOwner ();
246
+
247
+ OpBuilder b (conflictingWritingOp->getContext ());
248
+ std::string id = " C_" + std::to_string (counter++);
249
+
250
+ std::string conflictingWriteAttr =
251
+ id +
252
+ " [CONFL-WRITE: " + std::to_string (uConflictingWrite->getOperandNumber ()) +
253
+ " ]" ;
254
+ conflictingWritingOp->setAttr (conflictingWriteAttr, b.getUnitAttr ());
255
+
256
+ std::string readAttr =
257
+ id + " [READ: " + std::to_string (uRead->getOperandNumber ()) + " ]" ;
258
+ readingOp->setAttr (readAttr, b.getUnitAttr ());
259
+
260
+ if (auto opResult = lastWrite.dyn_cast <OpResult>()) {
261
+ std::string lastWriteAttr = id + " [LAST-WRITE: result " +
262
+ std::to_string (opResult.getResultNumber ()) +
263
+ " ]" ;
264
+ opResult.getDefiningOp ()->setAttr (lastWriteAttr, b.getUnitAttr ());
265
+ } else {
266
+ auto bbArg = lastWrite.cast <BlockArgument>();
267
+ std::string lastWriteAttr =
268
+ id + " [LAST-WRITE: bbArg " + std::to_string (bbArg.getArgNumber ()) + " ]" ;
269
+ bbArg.getOwner ()->getParentOp ()->setAttr (lastWriteAttr, b.getUnitAttr ());
270
+ }
271
+ }
272
+
320
273
// / Given sets of uses and writes, return true if there is a RaW conflict under
321
274
// / the assumption that all given reads/writes alias the same buffer and that
322
275
// / all given writes bufferize inplace.
@@ -351,14 +304,6 @@ static bool hasReadAfterWriteInterference(
351
304
// met for uConflictingWrite to be an actual conflict.
352
305
Operation *conflictingWritingOp = uConflictingWrite->getOwner ();
353
306
354
- // Print some debug info.
355
- LDBG (" Found potential conflict:\n " );
356
- LDBG (" READ = #" << uRead->getOperandNumber () << " of "
357
- << printOperationInfo (readingOp) << " \n " );
358
- LDBG (" CONFLICTING WRITE = #"
359
- << uConflictingWrite->getOperandNumber () << " of "
360
- << printOperationInfo (conflictingWritingOp) << " \n " );
361
-
362
307
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
363
308
// write is not visible when reading.
364
309
if (happensBefore (readingOp, conflictingWritingOp, domInfo))
@@ -387,8 +332,6 @@ static bool hasReadAfterWriteInterference(
387
332
if (insideMutuallyExclusiveRegions (readingOp, conflictingWritingOp))
388
333
continue ;
389
334
390
- LDBG (" WRITE = #" << printValueInfo (lastWrite) << " \n " );
391
-
392
335
// No conflict if the conflicting write happens before the last
393
336
// write.
394
337
if (Operation *writingOp = lastWrite.getDefiningOp ()) {
@@ -413,12 +356,14 @@ static bool hasReadAfterWriteInterference(
413
356
continue ;
414
357
415
358
// All requirements are met. Conflict found!
416
- LDBG (" CONFLICT CONFIRMED!\n\n " );
359
+
360
+ if (options.printConflicts )
361
+ annotateConflict (uRead, uConflictingWrite, lastWrite);
362
+
417
363
return true ;
418
364
}
419
365
}
420
366
421
- LDBG (" NOT A CONFLICT!\n\n " );
422
367
return false ;
423
368
}
424
369
@@ -530,7 +475,6 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
530
475
if (!hasWrite)
531
476
return false ;
532
477
533
- LDBG (" ->the corresponding buffer is not writeable\n " );
534
478
return true ;
535
479
}
536
480
@@ -548,13 +492,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
548
492
" operand and result do not match" );
549
493
#endif // NDEBUG
550
494
551
- int64_t resultNumber = result.getResultNumber ();
552
- (void )resultNumber;
553
- LDBG (' \n ' );
554
- LDBG (" Inplace analysis for <- #" << resultNumber << " -> #"
555
- << operand.getOperandNumber () << " in "
556
- << printValueInfo (result) << ' \n ' );
557
-
558
495
bool foundInterference =
559
496
wouldCreateWriteToNonWritableBuffer (operand, result, aliasInfo, state) ||
560
497
wouldCreateReadAfterWriteInterference (operand, result, domInfo, state,
@@ -565,8 +502,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
565
502
else
566
503
aliasInfo.bufferizeInPlace (result, operand);
567
504
568
- LDBG (" Done inplace analysis for result #" << resultNumber << ' \n ' );
569
-
570
505
return success ();
571
506
}
572
507
0 commit comments