Skip to content

Commit 4eec3f6

Browse files
committed
[CSOptimizer] Introduce a drop of inference to preserved unary argument behavior
Thanks to `LinkedExprAnalyzer` unary argument hack was able to infer matching based on literals and arithmetic operator chains, let's preserve that behavior in a more principled manner.
1 parent 0525818 commit 4eec3f6

File tree

2 files changed

+117
-10
lines changed

2 files changed

+117
-10
lines changed

lib/Sema/CSOptimizer.cpp

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,79 @@ inferTypeFromInitializerResultType(ConstraintSystem &cs,
262262
return {instanceTy, hasFailable};
263263
}
264264

265+
/// If the given expression represents a chain of operators that only have
266+
/// only literals as arguments, attempt to deduce a potential type of the
267+
/// chain. For example if chain has only integral literals it's going to
268+
/// be `Int`, if there are some floating-point literals mixed in - it's going
269+
/// to be `Double`.
270+
static Type inferTypeOfArithmeticOperatorChain(DeclContext *dc, ASTNode node) {
271+
auto binaryOp = getAsExpr<BinaryExpr>(node);
272+
if (!binaryOp)
273+
return Type();
274+
275+
class OperatorChainAnalyzer : public ASTWalker {
276+
ASTContext &C;
277+
DeclContext *DC;
278+
279+
llvm::SmallPtrSet<Type, 2> literals;
280+
281+
bool unsupported = false;
282+
283+
PreWalkResult<Expr *> walkToExprPre(Expr *expr) override {
284+
if (isa<BinaryExpr>(expr))
285+
return Action::Continue(expr);
286+
287+
if (isa<ParenExpr>(expr))
288+
return Action::Continue(expr);
289+
290+
// This inference works only with arithmetic operators
291+
// because we know the structure of their overloads.
292+
if (auto *ODRE = dyn_cast<OverloadedDeclRefExpr>(expr)) {
293+
if (auto *choice = ODRE->getDecls().front()) {
294+
if (choice->getBaseIdentifier().isArithmeticOperator())
295+
return Action::Continue(expr);
296+
}
297+
}
298+
299+
if (auto *LE = dyn_cast<LiteralExpr>(expr)) {
300+
if (auto *P = TypeChecker::getLiteralProtocol(C, LE)) {
301+
if (auto defaultTy = TypeChecker::getDefaultType(P, DC)) {
302+
if (defaultTy->isInt()) {
303+
// Don't add `Int` if `Double` is already in the list.
304+
if (literals.contains(C.getDoubleType()))
305+
return Action::Continue(expr);
306+
} else if (defaultTy->isDouble()) {
307+
// A single use of a floating-point literal flips the
308+
// type of the entire chain to `Double`.
309+
(void)literals.erase(C.getIntType());
310+
}
311+
312+
literals.insert(defaultTy);
313+
return Action::Continue(expr);
314+
}
315+
}
316+
}
317+
318+
unsupported = true;
319+
return Action::Stop();
320+
}
321+
322+
public:
323+
OperatorChainAnalyzer(DeclContext *DC) : C(DC->getASTContext()), DC(DC) {}
324+
325+
Type chainType() const {
326+
if (unsupported)
327+
return Type();
328+
return literals.size() != 1 ? Type() : *literals.begin();
329+
}
330+
};
331+
332+
OperatorChainAnalyzer analyzer(dc);
333+
binaryOp->walk(analyzer);
334+
335+
return analyzer.chainType();
336+
}
337+
265338
NullablePtr<Constraint> getApplicableFnConstraint(ConstraintGraph &CG,
266339
Constraint *disjunction) {
267340
auto *boundVar = disjunction->getNestedConstraints()[0]
@@ -418,25 +491,62 @@ static std::optional<DisjunctionInfo> preserveFavoringOfUnlabeledUnaryArgument(
418491
// The hack operated on "favored" types and only declaration references,
419492
// applications, and (dynamic) subscripts had them if they managed to
420493
// get an overload choice selected during constraint generation.
494+
// It's sometimes possible to infer a type of a literal and an operator
495+
// chain, so it should be allowed as well.
421496
if (!(isExpr<DeclRefExpr>(argument) || isExpr<ApplyExpr>(argument) ||
422497
isExpr<SubscriptExpr>(argument) ||
423-
isExpr<DynamicSubscriptExpr>(argument)))
498+
isExpr<DynamicSubscriptExpr>(argument) ||
499+
isExpr<LiteralExpr>(argument) || isExpr<BinaryExpr>(argument)))
424500
return {/*score=*/0};
425501

426-
auto argumentType = cs.getType(argument);
502+
auto argumentType = cs.getType(argument)->getRValueType();
503+
504+
// For chains like `1 + 2 * 3` it's easy to deduce the type because
505+
// we know what literal types are preferred.
506+
if (isa<BinaryExpr>(argument)) {
507+
auto chainTy = inferTypeOfArithmeticOperatorChain(cs.DC, argument);
508+
if (!chainTy)
509+
return {/*score=*/0};
510+
511+
argumentType = chainTy;
512+
}
513+
514+
// Use default type of a literal (when available) to make a guess.
515+
// This is what old hack used to do as well.
516+
if (auto *LE = dyn_cast<LiteralExpr>(argument)) {
517+
auto *P = TypeChecker::getLiteralProtocol(cs.getASTContext(), LE);
518+
if (!P)
519+
return {/*score=*/0};
520+
521+
auto defaultTy = TypeChecker::getDefaultType(P, cs.DC);
522+
if (!defaultTy)
523+
return {/*score=*/0};
524+
525+
argumentType = defaultTy;
526+
}
527+
528+
ASSERT(argumentType);
529+
427530
if (argumentType->hasTypeVariable() || argumentType->hasDependentMember())
428531
return {/*score=*/0};
429532

430533
SmallVector<Constraint *, 2> favoredChoices;
431534
forEachDisjunctionChoice(
432535
cs, disjunction,
433-
[&argumentType, &favoredChoices](Constraint *choice, ValueDecl *decl,
434-
FunctionType *overloadType) {
536+
[&argumentType, &favoredChoices, &argument](
537+
Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
435538
if (overloadType->getNumParams() != 1)
436539
return;
437540

438-
auto paramType = overloadType->getParams()[0].getPlainType();
439-
if (paramType->isEqual(argumentType))
541+
auto &param = overloadType->getParams()[0];
542+
543+
// Literals are speculative, let's not attempt to apply them too
544+
// eagerly.
545+
if (!param.getParameterFlags().isNone() &&
546+
(isa<LiteralExpr>(argument) || isa<BinaryExpr>(argument)))
547+
return;
548+
549+
if (argumentType->isEqual(param.getPlainType()))
440550
favoredChoices.push_back(choice);
441551
});
442552

test/Constraints/old_hack_related_ambiguities.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,11 @@ func entity(_: Int) -> Int {
66

77
struct Test {
88
func test(_ v: Int) -> Int { v }
9-
// expected-note@-1 {{found this candidate}}
109
func test(_ v: Int?) -> Int? { v }
11-
// expected-note@-1 {{found this candidate}}
1210
}
1311

1412
func test_ternary_literal(v: Test) -> Int? {
15-
// Literals don't have a favored type
16-
true ? v.test(0) : nil // expected-error {{ambiguous use of 'test'}}
13+
true ? v.test(0) : nil // Ok
1714
}
1815

1916
func test_ternary(v: Test) -> Int? {

0 commit comments

Comments
 (0)