Skip to content

Commit 90e5957

Browse files
committed
[Constraint solver] Use type variable constraints to find literal types.
Rather than looking directly at the subexpressions of a function application to determine whether they are literal expressions, look instead at the type variables for the argument types: if they have a literal-conforms-to constraint on them, use the named literal protocol to determine favored declarations. This refactoring is intended to broaden the applicability of the existing "favored declaration" machinery to also consider the literal constraints of other type variables that are equivalent to the type variable named by the argument.
1 parent 1ab417c commit 90e5957

File tree

1 file changed

+36
-33
lines changed

1 file changed

+36
-33
lines changed

lib/Sema/CSGen.cpp

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,6 @@ namespace {
540540
/// "favored" because they match exactly.
541541
bool isFavoredParamAndArg(ConstraintSystem &CS,
542542
Type paramTy,
543-
Expr *arg,
544543
Type argTy,
545544
Type otherArgTy = Type()) {
546545
// Determine the argument type.
@@ -550,32 +549,45 @@ namespace {
550549
if (paramTy->isEqual(argTy))
551550
return true;
552551

553-
// If the argument is a literal, this is a favored param/arg pair if
554-
// the parameter is of that default type.
555-
auto &tc = CS.getTypeChecker();
556-
auto literalProto = tc.getLiteralProtocol(arg->getSemanticsProvidingExpr());
557-
if (!literalProto) return false;
552+
llvm::SmallSetVector<ProtocolDecl *, 2> literalProtos;
553+
if (auto argTypeVar = argTy->getAs<TypeVariableType>()) {
554+
llvm::SetVector<Constraint *> constraints;
555+
CS.getConstraintGraph().gatherConstraints(
556+
argTypeVar, constraints,
557+
ConstraintGraph::GatheringKind::EquivalenceClass,
558+
[](Constraint *constraint) {
559+
return constraint->getKind() == ConstraintKind::LiteralConformsTo;
560+
});
561+
562+
for (auto constraint : constraints) {
563+
literalProtos.insert(constraint->getProtocol());
564+
}
565+
}
558566

559567
// Dig out the second argument type.
560568
if (otherArgTy)
561569
otherArgTy = otherArgTy->getWithoutSpecifierType();
562570

563-
// If there is another, concrete argument, check whether it's type
564-
// conforms to the literal protocol and test against it directly.
565-
// This helps to avoid 'widening' the favored type to the default type for
566-
// the literal.
567-
if (otherArgTy && otherArgTy->getAnyNominal()) {
568-
return otherArgTy->isEqual(paramTy) &&
569-
tc.conformsToProtocol(otherArgTy, literalProto, CS.DC,
570-
ConformanceCheckFlags::InExpression);
571+
auto &tc = CS.getTypeChecker();
572+
for (auto literalProto : literalProtos) {
573+
// If there is another, concrete argument, check whether it's type
574+
// conforms to the literal protocol and test against it directly.
575+
// This helps to avoid 'widening' the favored type to the default type for
576+
// the literal.
577+
if (otherArgTy && otherArgTy->getAnyNominal()) {
578+
if (otherArgTy->isEqual(paramTy) &&
579+
tc.conformsToProtocol(otherArgTy, literalProto, CS.DC,
580+
ConformanceCheckFlags::InExpression))
581+
return true;
582+
} else if (Type defaultType = tc.getDefaultType(literalProto, CS.DC)) {
583+
// If there is a default type for the literal protocol, check whether
584+
// it is the same as the parameter type.
585+
// Check whether there is a default type to compare against.
586+
if (paramTy->isEqual(defaultType))
587+
return true;
588+
}
571589
}
572590

573-
// If there is a default type for the literal protocol, check whether
574-
// it is the same as the parameter type.
575-
// Check whether there is a default type to compare against.
576-
if (Type defaultType = tc.getDefaultType(literalProto, CS.DC))
577-
return paramTy->isEqual(defaultType);
578-
579591
return false;
580592
}
581593

@@ -744,7 +756,7 @@ namespace {
744756
auto contextualTy = CS.getContextualType(expr);
745757

746758
return isFavoredParamAndArg(
747-
CS, paramTy, expr->getArg(),
759+
CS, paramTy,
748760
CS.getType(expr->getArg())->getWithoutParens()) &&
749761
(!contextualTy || contextualTy->isEqual(resultTy));
750762
};
@@ -899,14 +911,6 @@ namespace {
899911
CS.setFavoredType(argTupleExpr->getElement(1), favoredExprTy);
900912
secondFavoredTy = favoredExprTy;
901913
}
902-
903-
if (firstFavoredTy && firstArgTy->is<TypeVariableType>()) {
904-
firstArgTy = firstFavoredTy;
905-
}
906-
907-
if (secondFavoredTy && secondArgTy->is<TypeVariableType>()) {
908-
secondArgTy = secondFavoredTy;
909-
}
910914
}
911915

912916
// Figure out the parameter type.
@@ -924,9 +928,8 @@ namespace {
924928
auto resultTy = fnTy->getResult();
925929
auto contextualTy = CS.getContextualType(expr);
926930

927-
return (isFavoredParamAndArg(CS, firstParamTy, firstArg, firstArgTy,
928-
secondArgTy) ||
929-
isFavoredParamAndArg(CS, secondParamTy, secondArg, secondArgTy,
931+
return (isFavoredParamAndArg(CS, firstParamTy, firstArgTy, secondArgTy) ||
932+
isFavoredParamAndArg(CS, secondParamTy, secondArgTy,
930933
firstArgTy)) &&
931934
firstParamTy->isEqual(secondParamTy) &&
932935
!isPotentialForcingOpportunity(firstArgTy, secondArgTy) &&
@@ -1109,7 +1112,7 @@ namespace {
11091112
auto keyTy = dictTy->first;
11101113
auto valueTy = dictTy->second;
11111114

1112-
if (isFavoredParamAndArg(CS, keyTy, index, CS.getType(index))) {
1115+
if (isFavoredParamAndArg(CS, keyTy, CS.getType(index))) {
11131116
outputTy = OptionalType::get(valueTy);
11141117

11151118
if (isLValueBase)

0 commit comments

Comments
 (0)