Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,15 @@ Value bitcastToFloat(PatternRewriter &rewriter, Location loc, Value v,
return tt::BitcastOp::create(rewriter, loc, floatTy, v);
}

uint64_t stableStringHash(StringRef str) {
uint64_t h = 14695981039346656037ull;
for (uint8_t c : str.bytes()) {
h ^= c;
h *= 1099511628211ull;
}
return h;
}

uint64_t murmur64Mixer(uint64_t h) {
h ^= h >> 33;
h *= 0xff51afd7ed558ccd;
Expand Down Expand Up @@ -458,6 +467,58 @@ Value fpsanSRem(PatternRewriter &rewriter, Location loc, Value num, Value den) {
return bitcastToFloat(rewriter, loc, resI, num.getType());
}

bool isIntLike(Type ty) { return isa<IntegerType>(getElementType(ty)); }

bool isNumericLike(Type ty) {
Type elemTy = getElementType(ty);
return isa<FloatType>(elemTy) || isa<IntegerType>(elemTy);
}

bool externHasNumericOperands(tt::ExternElementwiseOp op) {
return llvm::all_of(op.getOperands(), [](Value operand) {
return isNumericLike(operand.getType());
});
}

bool externInvolvesFloatLike(tt::ExternElementwiseOp op) {
return isFloatLike(op.getType()) ||
llvm::any_of(op.getOperands(), [](Value operand) {
return isFloatLike(operand.getType());
});
}

Value castExternOperandToResultInt(PatternRewriter &rewriter, Location loc,
Value operand, Type resultIntTy) {
if (isFloatLike(operand.getType())) {
return castIntValueToType(
rewriter, loc, bitcastToInt(rewriter, loc, operand), resultIntTy);
}
if (isIntLike(operand.getType())) {
return castIntValueToType(rewriter, loc, operand, resultIntTy);
}
return Value();
}

Value fpsanVariadicExternTagged(PatternRewriter &rewriter, Location loc,
tt::ExternElementwiseOp op, uint64_t hash) {
Type resultTy = op.getType();
Type resultIntTy = getIntTypeLike(resultTy);

Value sumI = getIntConstantLike(rewriter, loc, resultIntTy, 0);
for (Value operand : op.getOperands()) {
Value operandI =
castExternOperandToResultInt(rewriter, loc, operand, resultIntTy);
if (!operandI)
return Value();
sumI = arith::AddIOp::create(rewriter, loc, sumI, operandI);
}

auto hashVal = getIntConstantLike(rewriter, loc, resultIntTy,
static_cast<int64_t>(hash));
auto outI = arith::XOrIOp::create(rewriter, loc, sumI, hashVal);
return bitcastToFloat(rewriter, loc, outI, resultTy);
}

std::optional<ScratchInfo>
createOperandScratch(PatternRewriter &rewriter, Location loc,
TmemScratchManager &scratch, Value memdesc,
Expand Down Expand Up @@ -1101,6 +1162,25 @@ template <typename OpTy> struct UnaryPattern : public OpRewritePattern<OpTy> {
UnaryOpId unaryOpId;
};

struct ExternElementwisePattern
: public OpRewritePattern<tt::ExternElementwiseOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tt::ExternElementwiseOp op,
PatternRewriter &rewriter) const override {
if (!op.getPure() || !isFloatLike(op.getType()) ||
op.getNumOperands() == 0 || !externHasNumericOperands(op))
return failure();
Comment on lines +1171 to +1173

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Instrument float-input externs with non-float results

Drop the isFloatLike(op.getType()) gate here. It prevents rewriting libdevice-style float→int externs (e.g., float2int_*, ilogb, isnan), and the later unsupported-extern walk then fails the pass for them. That breaks the commit intent to support arbitrary float extern_elementwise (including libdevice functions) and causes valid kernels to fail under fpsan.

Useful? React with 👍 / 👎.


uint64_t hash = stableStringHash(op.getSymbol());
Value result = fpsanVariadicExternTagged(rewriter, op.getLoc(), op, hash);
if (!result)
return failure();
rewriter.replaceOp(op, result);
return success();
}
};

class FpSanitizerPass
: public impl::TritonInstrumentFpSanitizerBase<FpSanitizerPass> {
public:
Expand Down Expand Up @@ -1139,6 +1219,7 @@ class FpSanitizerPass
patterns.add<UnaryPattern<math::CeilOp>>(&getContext(), UnaryOpId::Ceil);
patterns.add<UnaryPattern<tt::PreciseSqrtOp>>(&getContext(),
UnaryOpId::PreciseSqrt);
patterns.add<ExternElementwisePattern>(&getContext());
patterns.add<TMEMLoadPattern, TMEMStorePattern, TMEMCopyPattern,
TCGen5MMAPattern>(&getContext(), &scratch);
patterns.add<TCGen5CommitPattern>(&getContext());
Expand All @@ -1148,6 +1229,25 @@ class FpSanitizerPass
signalPassFailure();
}

getOperation()->walk([&](tt::ExternElementwiseOp op) {
if (!externInvolvesFloatLike(op))
return WalkResult::advance();

hasUnsupportedOperations = true;
llvm::errs()
<< "FpSanitizer error: Unsupported extern_elementwise: symbol="
<< op.getSymbol() << ", pure=" << op.getPure()
<< ", num_operands=" << op.getNumOperands() << ", result_ty=";
op.getType().print(llvm::errs());
llvm::errs() << ", operand_tys=(";
llvm::interleaveComma(op.getOperandTypes(), llvm::errs(),
[&](Type ty) { ty.print(llvm::errs()); });
llvm::errs() << ")\n";
return WalkResult::interrupt();
});
if (hasUnsupportedOperations)
signalPassFailure();

// TODO: Remove unused tmem usages. This requires unwiring them from the
// warp specialize partitions.
}
Expand Down
Loading
Loading