Skip to content

Commit 81eece7

Browse files
[mlir][linalg][bufferize] Debug output as IR attributes
Instead of printing analysis debug information to stderr, annotate the IR. This makes it easier to understand decisions made by the analysis, especially in larger input IR. Differential Revision: https://reviews.llvm.org/D115575
1 parent 6847379 commit 81eece7

File tree

5 files changed

+45
-113
lines changed

5 files changed

+45
-113
lines changed

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ struct BufferizationOptions {
140140
/// checking the results of the analysis) and post analysis steps.
141141
bool testAnalysisOnly = false;
142142

143+
/// If set to `true`, the IR is annotated with details about RaW conflicts.
144+
/// For debugging only. Should be used together with `testAnalysisOnly`.
145+
bool printConflicts = false;
146+
143147
/// Registered post analysis steps.
144148
PostAnalysisStepList postAnalysisSteps;
145149

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def LinalgComprehensiveModuleBufferize :
3939
Option<"testAnalysisOnly", "test-analysis-only", "bool",
4040
/*default=*/"false",
4141
"Only runs inplaceability analysis (for testing purposes only)">,
42+
Option<"printConflicts", "print-conflicts", "bool",
43+
/*default=*/"false",
44+
"Annotates IR with RaW conflicts. Requires test-analysis-only.">,
4245
Option<"allowReturnMemref", "allow-return-memref", "bool",
4346
/*default=*/"false",
4447
"Allows the return of memrefs (for testing purposes only)">,

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Lines changed: 37 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -117,25 +117,12 @@
117117
#include "mlir/IR/TypeUtilities.h"
118118
#include "llvm/ADT/DenseSet.h"
119119
#include "llvm/ADT/SetVector.h"
120-
#include "llvm/Support/Debug.h"
121-
#include "llvm/Support/FormatVariadic.h"
122-
123-
#define DEBUG_TYPE "comprehensive-module-bufferize"
124120

125121
using namespace mlir;
126122
using namespace linalg;
127123
using namespace tensor;
128124
using namespace comprehensive_bufferize;
129125

130-
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
131-
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
132-
133-
// Forward declarations.
134-
#ifndef NDEBUG
135-
static std::string printOperationInfo(Operation *, bool prefix = true);
136-
static std::string printValueInfo(Value, bool prefix = true);
137-
#endif
138-
139126
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
140127

141128
//===----------------------------------------------------------------------===//
@@ -164,64 +151,11 @@ static void setInPlaceOpResult(OpResult opResult, bool inPlace) {
164151
attr ? SmallVector<StringRef>(
165152
llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()))
166153
: SmallVector<StringRef>(op->getNumResults(), "false");
167-
LDBG("->set inPlace=" << inPlace << " <- #" << opResult.getResultNumber()
168-
<< ": " << printOperationInfo(op) << "\n");
169154
inPlaceVector[opResult.getResultNumber()] = inPlace ? "true" : "false";
170155
op->setAttr(kInPlaceResultsAttrName,
171156
OpBuilder(op).getStrArrayAttr(inPlaceVector));
172157
}
173158

174-
//===----------------------------------------------------------------------===//
175-
// Printing helpers.
176-
//===----------------------------------------------------------------------===//
177-
178-
#ifndef NDEBUG
179-
/// Helper method printing the bufferization information of a buffer / tensor.
180-
static void printTensorOrBufferInfo(std::string prefix, Value value,
181-
AsmState &state, llvm::raw_ostream &os) {
182-
if (!value.getType().isa<ShapedType>())
183-
return;
184-
os << prefix;
185-
value.printAsOperand(os, state);
186-
os << " : " << value.getType();
187-
}
188-
189-
/// Print the operation name and bufferization information.
190-
static std::string printOperationInfo(Operation *op, bool prefix) {
191-
std::string result;
192-
llvm::raw_string_ostream os(result);
193-
AsmState state(op->getParentOfType<mlir::FuncOp>());
194-
StringRef tab = prefix ? "\n[" DEBUG_TYPE "]\t" : "";
195-
os << tab << op->getName();
196-
SmallVector<Value> shapedOperands;
197-
for (OpOperand &opOperand : op->getOpOperands()) {
198-
std::string prefix =
199-
llvm::formatv("{0} -> #{1} ", tab, opOperand.getOperandNumber());
200-
printTensorOrBufferInfo(prefix, opOperand.get(), state, os);
201-
}
202-
for (OpResult opResult : op->getOpResults()) {
203-
std::string prefix =
204-
llvm::formatv("{0} <- #{1} ", tab, opResult.getResultNumber());
205-
printTensorOrBufferInfo(prefix, opResult, state, os);
206-
}
207-
return result;
208-
}
209-
210-
/// Print the bufferization information for the defining op or block argument.
211-
static std::string printValueInfo(Value value, bool prefix) {
212-
auto *op = value.getDefiningOp();
213-
if (op)
214-
return printOperationInfo(op, prefix);
215-
// Print the block argument bufferization information.
216-
std::string result;
217-
llvm::raw_string_ostream os(result);
218-
AsmState state(value.getParentRegion()->getParentOfType<mlir::FuncOp>());
219-
os << value;
220-
printTensorOrBufferInfo("\n\t - ", value, state, os);
221-
return result;
222-
}
223-
#endif
224-
225159
//===----------------------------------------------------------------------===//
226160
// Bufferization-specific alias analysis.
227161
//===----------------------------------------------------------------------===//
@@ -251,7 +185,6 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
251185
static bool aliasesNonWritableBuffer(Value value,
252186
const BufferizationAliasInfo &aliasInfo,
253187
BufferizationState &state) {
254-
LDBG("WRITABILITY ANALYSIS FOR " << printValueInfo(value) << "\n");
255188
bool foundNonWritableBuffer = false;
256189
aliasInfo.applyOnAliases(value, [&](Value v) {
257190
// Query BufferizableOpInterface to see if the OpResult is writable.
@@ -270,35 +203,22 @@ static bool aliasesNonWritableBuffer(Value value,
270203
foundNonWritableBuffer = true;
271204
});
272205

273-
if (foundNonWritableBuffer)
274-
LDBG("--> NON WRITABLE\n");
275-
else
276-
LDBG("--> WRITABLE\n");
277-
278206
return foundNonWritableBuffer;
279207
}
280208

281209
/// Return true if the buffer to which `operand` would bufferize is equivalent
282210
/// to some buffer write.
283211
static bool aliasesInPlaceWrite(Value value,
284212
const BufferizationAliasInfo &aliasInfo) {
285-
LDBG("----Start aliasesInPlaceWrite\n");
286-
LDBG("-------for : " << printValueInfo(value) << '\n');
287213
bool foundInplaceWrite = false;
288214
aliasInfo.applyOnAliases(value, [&](Value v) {
289215
for (auto &use : v.getUses()) {
290216
if (isInplaceMemoryWrite(use, aliasInfo)) {
291-
LDBG("-----------wants to bufferize to inPlace write: "
292-
<< printOperationInfo(use.getOwner()) << '\n');
293217
foundInplaceWrite = true;
294218
return;
295219
}
296220
}
297221
});
298-
299-
if (!foundInplaceWrite)
300-
LDBG("----------->does not alias an inplace write\n");
301-
302222
return foundInplaceWrite;
303223
}
304224

@@ -317,6 +237,39 @@ static bool happensBefore(Operation *a, Operation *b,
317237
return false;
318238
}
319239

240+
/// Annotate IR with details about the detected RaW conflict.
241+
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
242+
Value lastWrite) {
243+
static uint64_t counter = 0;
244+
Operation *readingOp = uRead->getOwner();
245+
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
246+
247+
OpBuilder b(conflictingWritingOp->getContext());
248+
std::string id = "C_" + std::to_string(counter++);
249+
250+
std::string conflictingWriteAttr =
251+
id +
252+
"[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
253+
"]";
254+
conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
255+
256+
std::string readAttr =
257+
id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
258+
readingOp->setAttr(readAttr, b.getUnitAttr());
259+
260+
if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
261+
std::string lastWriteAttr = id + "[LAST-WRITE: result " +
262+
std::to_string(opResult.getResultNumber()) +
263+
"]";
264+
opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
265+
} else {
266+
auto bbArg = lastWrite.cast<BlockArgument>();
267+
std::string lastWriteAttr =
268+
id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
269+
bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
270+
}
271+
}
272+
320273
/// Given sets of uses and writes, return true if there is a RaW conflict under
321274
/// the assumption that all given reads/writes alias the same buffer and that
322275
/// all given writes bufferize inplace.
@@ -351,14 +304,6 @@ static bool hasReadAfterWriteInterference(
351304
// met for uConflictingWrite to be an actual conflict.
352305
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
353306

354-
// Print some debug info.
355-
LDBG("Found potential conflict:\n");
356-
LDBG("READ = #" << uRead->getOperandNumber() << " of "
357-
<< printOperationInfo(readingOp) << "\n");
358-
LDBG("CONFLICTING WRITE = #"
359-
<< uConflictingWrite->getOperandNumber() << " of "
360-
<< printOperationInfo(conflictingWritingOp) << "\n");
361-
362307
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
363308
// write is not visible when reading.
364309
if (happensBefore(readingOp, conflictingWritingOp, domInfo))
@@ -387,8 +332,6 @@ static bool hasReadAfterWriteInterference(
387332
if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
388333
continue;
389334

390-
LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
391-
392335
// No conflict if the conflicting write happens before the last
393336
// write.
394337
if (Operation *writingOp = lastWrite.getDefiningOp()) {
@@ -413,12 +356,14 @@ static bool hasReadAfterWriteInterference(
413356
continue;
414357

415358
// All requirements are met. Conflict found!
416-
LDBG("CONFLICT CONFIRMED!\n\n");
359+
360+
if (options.printConflicts)
361+
annotateConflict(uRead, uConflictingWrite, lastWrite);
362+
417363
return true;
418364
}
419365
}
420366

421-
LDBG("NOT A CONFLICT!\n\n");
422367
return false;
423368
}
424369

@@ -530,7 +475,6 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
530475
if (!hasWrite)
531476
return false;
532477

533-
LDBG("->the corresponding buffer is not writeable\n");
534478
return true;
535479
}
536480

@@ -548,13 +492,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
548492
"operand and result do not match");
549493
#endif // NDEBUG
550494

551-
int64_t resultNumber = result.getResultNumber();
552-
(void)resultNumber;
553-
LDBG('\n');
554-
LDBG("Inplace analysis for <- #" << resultNumber << " -> #"
555-
<< operand.getOperandNumber() << " in "
556-
<< printValueInfo(result) << '\n');
557-
558495
bool foundInterference =
559496
wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) ||
560497
wouldCreateReadAfterWriteInterference(operand, result, domInfo, state,
@@ -565,8 +502,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
565502
else
566503
aliasInfo.bufferizeInPlace(result, operand);
567504

568-
LDBG("Done inplace analysis for result #" << resultNumber << '\n');
569-
570505
return success();
571506
}
572507

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1515
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1616
#include "mlir/IR/Operation.h"
17-
#include "llvm/Support/Debug.h"
18-
#include "llvm/Support/FormatVariadic.h"
19-
20-
#define DEBUG_TYPE "comprehensive-module-bufferize"
21-
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
22-
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
2317

2418
using namespace mlir;
2519
using namespace linalg;
@@ -181,7 +175,6 @@ static FunctionType getOrCreateBufferizedFunctionType(
181175
auto it2 = bufferizedFunctionTypes.try_emplace(
182176
funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
183177
resultTypes));
184-
LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
185178
return it2.first->second;
186179
}
187180

@@ -227,7 +220,6 @@ static void equivalenceAnalysis(FuncOp funcOp,
227220
/// future.
228221
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
229222
BufferizationState &state) {
230-
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
231223
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
232224

233225
// If nothing to do then we are done.
@@ -261,7 +253,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
261253
funcOp, funcOp.getType().getInputs(), TypeRange{},
262254
moduleState.bufferizedFunctionTypes);
263255
funcOp.setType(bufferizedFuncType);
264-
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
265256
return success();
266257
}
267258

@@ -341,8 +332,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
341332
// 4. Rewrite the FuncOp type to buffer form.
342333
funcOp.setType(bufferizedFuncType);
343334

344-
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
345-
346335
return success();
347336
}
348337

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
8989
options.allowUnknownOps = allowUnknownOps;
9090
options.analysisFuzzerSeed = analysisFuzzerSeed;
9191
options.testAnalysisOnly = testAnalysisOnly;
92+
options.printConflicts = printConflicts;
9293

9394
// Enable InitTensorOp elimination.
9495
options.addPostAnalysisStep<

0 commit comments

Comments
 (0)