
Commit 2e9733c

SavchenkoValeriy authored and vinay-deshmukh committed
[ValueTracking] Refine known bits for linear interpolation patterns (llvm#166378)
In this patch, we try to detect the lerp pattern: a * (b - c) + c * d, where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c. In that particular case, we can use the following chain of reasoning:

a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b, where a' = max(a, d)

Since that is true for arbitrary a, b, c and d within our constraints, we can conclude that:

max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U

Since any result of the lerp is less than or equal to U, it has at least as many leading 0s as U.

While quite specific, this situation is fairly common in computer graphics in the shape of alpha blending. In conjunction with llvm#165877, this increases the vectorization factor for lerp loops.
1 parent caf9bb9 commit 2e9733c
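To make the bound concrete, here is a standalone sketch (not part of the patch) for the common 8-bit alpha-blending shape, where a, c and d are at most 255 and b is the constant 255: U = max(max(a), max(d)) * max(b) = 255 * 255 = 65025 < 2^16, so a 32-bit result is known to have at least 16 leading zero bits. The values and the brute-force check below are assumptions chosen only for illustration.

#include <algorithm>
#include <bit>
#include <cassert>
#include <cstdint>

// Illustration of U = max(max(a), max(d)) * max(b) for 8-bit alpha blending.
// The step size 51 just keeps the brute-force loops short.
int main() {
  const uint32_t MaxA = 255, MaxB = 255, MaxC = 255, MaxD = 255;
  const uint32_t UpperBound = std::max(MaxA, MaxD) * MaxB; // 65025, no wrap
  assert(std::countl_zero(UpperBound) == 16); // at least 16 leading zeros in i32
  for (uint32_t A = 0; A <= MaxA; A += 51)
    for (uint32_t C = 0; C <= MaxC; C += 51)
      for (uint32_t D = 0; D <= MaxD; D += 51)
        assert(A * (MaxB - C) + C * D <= UpperBound);
  return 0;
}

This is also why the ugt-65535 clamp in the test_clamp function further down can be folded away: the sum can never exceed 65025.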

File tree

3 files changed: +332 -0 lines changed


llvm/include/llvm/IR/PatternMatch.h

Lines changed: 14 additions & 0 deletions
@@ -872,18 +872,32 @@ inline bind_and_match_ty<const Value, MatchTy> m_Value(const Value *&V,
 
 /// Match an instruction, capturing it if we match.
 inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; }
+inline bind_ty<const Instruction> m_Instruction(const Instruction *&I) {
+  return I;
+}
 
 /// Match against the nested pattern, and capture the instruction if we match.
 template <typename MatchTy>
 inline bind_and_match_ty<Instruction, MatchTy>
 m_Instruction(Instruction *&I, const MatchTy &Match) {
   return {I, Match};
 }
+template <typename MatchTy>
+inline bind_and_match_ty<const Instruction, MatchTy>
+m_Instruction(const Instruction *&I, const MatchTy &Match) {
+  return {I, Match};
+}
 
 /// Match a unary operator, capturing it if we match.
 inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
+inline bind_ty<const UnaryOperator> m_UnOp(const UnaryOperator *&I) {
+  return I;
+}
 /// Match a binary operator, capturing it if we match.
 inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; }
+inline bind_ty<const BinaryOperator> m_BinOp(const BinaryOperator *&I) {
+  return I;
+}
 /// Match a with overflow intrinsic, capturing it if we match.
 inline bind_ty<WithOverflowInst> m_WithOverflowInst(WithOverflowInst *&I) {
   return I;
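A hedged usage sketch for the new const overloads (illustrative only, not part of the diff): they make it possible to bind matched instructions and values through const pointers, which the const-qualified analysis code below relies on. The helper name and the matched pattern are made up for this example.

#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Hypothetical helper: returns true if V is a (commutative) multiply of two
// values, binding the instruction and operands through const pointers.
static bool isMulOfTwoValues(const Value *V) {
  const Instruction *I = nullptr;
  const Value *X = nullptr, *Y = nullptr;
  return match(V, m_Instruction(I, m_c_Mul(m_Value(X), m_Value(Y))));
}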

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 137 additions & 0 deletions
@@ -350,6 +350,139 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
   return V->getType()->getScalarSizeInBits() - SignBits + 1;
 }
 
+/// Try to detect the lerp pattern: a * (b - c) + c * d
+/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+///
+/// In that particular case, we can use the following chain of reasoning:
+///
+/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
+///
+/// Since that is true for arbitrary a, b, c and d within our constraints, we
+/// can conclude that:
+///
+/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
+///
+/// Since any result of the lerp is less than or equal to U, it has at least
+/// as many leading 0s as U.
+///
+/// While being quite a specific situation, it is fairly common in computer
+/// graphics in the shape of alpha blending.
+///
+/// Modifies the given KnownOut in-place with the inferred information.
+static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
+                                            const APInt &DemandedElts,
+                                            KnownBits &KnownOut,
+                                            const SimplifyQuery &Q,
+                                            unsigned Depth) {
+
+  Type *Ty = Op0->getType();
+  const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+  // Only handle scalar types for now.
+  if (Ty->isVectorTy())
+    return;
+
+  // Try to match: a * (b - c) + c * d.
+  // When a == 1 => A == nullptr, the same applies to d/D as well.
+  const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+  const Instruction *SubBC = nullptr;
+
+  const auto MatchSubBC = [&]() {
+    // (b - c) can have two forms that interest us:
+    //
+    //   1. sub nuw %b, %c
+    //   2. xor %c, %b
+    //
+    // For the first case, the nuw flag guarantees our requirement b >= c.
+    //
+    // The second case might happen when the analysis can infer that b is a
+    // mask for c and we can transform the sub operation into xor (that is
+    // usually true for constant b's). Even though xor is symmetric,
+    // canonicalization ensures that the constant will be the RHS. We have
+    // additional checks later on to ensure that this xor operation is
+    // equivalent to subtraction.
+    return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+                                            m_Xor(m_Value(C), m_Value(B))));
+  };
+
+  const auto MatchASubBC = [&]() {
+    // Cases:
+    //   - a * (b - c)
+    //   - (b - c) * a
+    //   - (b - c)      <- a implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
+  };
+
+  const auto MatchCD = [&]() {
+    // Cases:
+    //   - d * c
+    //   - c * d
+    //   - c            <- d implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
+  };
+
+  const auto Match = [&](const Value *LHS, const Value *RHS) {
+    // We use m_Specific(C) in MatchCD, so we have to make sure that C is
+    // bound to something: match(LHS, MatchASubBC()) absolutely has to
+    // evaluate first and return true.
+    //
+    // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
+    return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+  };
+
+  if (!Match(Op0, Op1) && !Match(Op1, Op0))
+    return;
+
+  const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+    // For some of the values we use the convention of leaving
+    // them nullptr to signify an implicit constant 1.
+    return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+             : KnownBits::makeConstant(APInt(BitWidth, 1));
+  };
+
+  // Check that all operands are non-negative.
+  const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+  if (!KnownA.isNonNegative())
+    return;
+
+  const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+  if (!KnownD.isNonNegative())
+    return;
+
+  const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+  if (!KnownB.isNonNegative())
+    return;
+
+  const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+  if (!KnownC.isNonNegative())
+    return;
+
+  // If we matched the subtraction as xor, we need to actually check that the
+  // xor is semantically equivalent to subtraction.
+  //
+  // For that to be true, b has to be a mask for c, i.e. b's known ones must
+  // cover all known and possible ones of c.
+  if (SubBC->getOpcode() == Instruction::Xor &&
+      !KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()))
+    return;
+
+  const APInt MaxA = KnownA.getMaxValue();
+  const APInt MaxD = KnownD.getMaxValue();
+  const APInt MaxAD = APIntOps::umax(MaxA, MaxD);
+  const APInt MaxB = KnownB.getMaxValue();
+
+  // We can't infer leading zeros info if the upper-bound estimate wraps.
+  bool Overflow;
+  const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
+
+  if (Overflow)
+    return;
+
+  // If we know that x <= y and both are positive, then x has at least the
+  // same number of leading zeros as y.
+  const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
+  KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
+}
+
 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
                                    bool NSW, bool NUW,
                                    const APInt &DemandedElts,
@@ -369,6 +502,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
       isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
           .value_or(false))
     KnownOut.makeNonNegative();
+
+  if (Add)
+    // Try to match the lerp pattern and combine results.
+    computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
 }
 
 static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
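The xor check in the function above rests on a bitwise identity worth spelling out: if every bit that can be set in c is also known to be set in b, then b - c never borrows and equals b ^ c, so the xor really behaves as a subtraction. A standalone sketch (not part of the patch) that brute-forces this for 8-bit values; it uses exact values rather than KnownBits ranges, which is a simplification:

#include <cassert>
#include <cstdint>

// If c's set bits are a subset of b's set bits, b - c never borrows, so a
// subtraction matched as `xor %c, %b` really is a subtraction: b - c == b ^ c.
// Mirrors KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()) above, but
// with exact 8-bit values instead of known-bits information.
int main() {
  for (uint32_t B = 0; B < 256; ++B)
    for (uint32_t C = 0; C < 256; ++C)
      if ((C & ~B) == 0) // c is covered by b
        assert(B - C == (B ^ C));
  return 0;
}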
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test known bits refinements for pattern: a * (b - c) + c * d
+; where a > 0, c > 0, b > 0, d > 0, and b > c.
+; This pattern is a generalization of lerp and it appears frequently in graphics operations.
+
+define i32 @test_clamp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i32 @test_clamp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub i32 255, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %cmp = icmp ugt i32 %add, 65535
+  %result = select i1 %cmp, i32 65535, i32 %add
+  ret i32 %result
+}
+
+define i1 @test_trunc_cmp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub i32 255, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_xor(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_xor(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = xor i32 255, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_arbitrary_b(i8 %a, i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_arbitrary_b(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %b32 = zext i8 %b to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub nsw nuw i32 %b32, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+
+define i1 @test_trunc_cmp_no_a(i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_a(
+; CHECK-SAME: i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[MUL1:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %b32 = zext i8 %b to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub nuw i32 %b32, %c32
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %sub, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_d(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[C32]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %b32 = zext i8 %b to i32
+  %c32 = zext i8 %c to i32
+  %sub = sub nsw nuw i32 %b32, %c32
+  %mul1 = mul i32 %a32, %sub
+  %add = add i32 %mul1, %c32
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_xor_negative(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_xor_negative(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 234
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = xor i32 234, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  ; We should keep the trunc in this case
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
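For context on the lerp loops mentioned in the commit message, here is a hypothetical C++ alpha-blending loop with exactly the a * (b - c) + c * d shape; the names and the /255 normalization are illustrative, not taken from the patch:

#include <cstddef>
#include <cstdint>

// dst = (dst * (255 - alpha) + alpha * src) / 255, i.e. a * (b - c) + c * d
// with b == 255. The intermediate sum fits in 16 bits, which is exactly the
// fact the refined known bits let the optimizer prove.
void blend(uint8_t *dst, const uint8_t *src, const uint8_t *alpha, size_t n) {
  for (size_t i = 0; i != n; ++i) {
    const uint32_t a = alpha[i], s = src[i], d = dst[i];
    const uint32_t mix = d * (255 - a) + a * s;
    dst[i] = static_cast<uint8_t>(mix / 255);
  }
}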
