Skip to content

Commit dbb38f1

Browse files
authored
Merge pull request #67435 from xedin/rdar-112029630
[CSRanking] Augment overload ranking to account for variadic generics
2 parents dc39f04 + b98cd11 commit dbb38f1

File tree

6 files changed

+256
-66
lines changed

6 files changed

+256
-66
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6268,11 +6268,16 @@ Type isPlaceholderVar(PatternBindingDecl *PB);
62686268
/// Dump an anchor node for a constraint locator or contextual type.
62696269
void dumpAnchor(ASTNode anchor, SourceManager *SM, raw_ostream &out);
62706270

6271+
bool isPackExpansionType(Type type);
6272+
62716273
/// Check whether the type is a tuple consisting of a single unlabeled element
62726274
/// of \c PackExpansionType or a type variable that represents a pack expansion
62736275
/// type.
62746276
bool isSingleUnlabeledPackExpansionTuple(Type type);
62756277

6278+
bool containsPackExpansionType(ArrayRef<AnyFunctionType::Param> params);
6279+
bool containsPackExpansionType(TupleType *tuple);
6280+
62766281
/// \returns null if \c type is not a single unlabeled pack expansion tuple.
62776282
Type getPatternTypeOfSingleUnlabeledPackExpansionTuple(Type type);
62786283

lib/Sema/CSGen.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,25 @@ namespace {
592592
// Find the argument type.
593593
size_t nArgs = expr->getArgs()->size();
594594
auto fnExpr = expr->getFn();
595-
595+
596+
auto mustConsiderVariadicGenericOverloads = [&](ValueDecl *overload) {
597+
if (overload->getAttrs().hasAttribute<DisfavoredOverloadAttr>())
598+
return false;
599+
600+
auto genericContext = overload->getAsGenericContext();
601+
if (!genericContext)
602+
return false;
603+
604+
auto *GPL = genericContext->getGenericParams();
605+
if (!GPL)
606+
return false;
607+
608+
return llvm::any_of(GPL->getParams(),
609+
[&](const GenericTypeParamDecl *GP) {
610+
return GP->isParameterPack();
611+
});
612+
};
613+
596614
// Check to ensure that we have an OverloadedDeclRef, and that we're not
597615
// favoring multiple overload constraints. (Otherwise, in this case
598616
// favoring is useless.
@@ -630,8 +648,9 @@ namespace {
630648
return nArgs == paramCount.first ||
631649
nArgs == paramCount.second;
632650
};
633-
634-
favorCallOverloads(expr, CS, isFavoredDecl);
651+
652+
favorCallOverloads(expr, CS, isFavoredDecl,
653+
mustConsiderVariadicGenericOverloads);
635654
}
636655

637656
// We only currently perform favoring for unary args.
@@ -655,7 +674,8 @@ namespace {
655674
// inside an extension context, since any archetypes in the parameter
656675
// list could match exactly.
657676
auto mustConsider = [&](ValueDecl *value) -> bool {
658-
return isa<ProtocolDecl>(value->getDeclContext());
677+
return isa<ProtocolDecl>(value->getDeclContext()) ||
678+
mustConsiderVariadicGenericOverloads(value);
659679
};
660680

661681
favorCallOverloads(expr, CS, isFavoredDecl, mustConsider);

lib/Sema/CSRanking.cpp

Lines changed: 110 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -633,75 +633,126 @@ bool CompareDeclSpecializationRequest::evaluate(
633633
auto params1 = funcTy1->getParams();
634634
auto params2 = funcTy2->getParams();
635635

636-
unsigned numParams1 = params1.size();
637-
unsigned numParams2 = params2.size();
638-
if (numParams1 > numParams2)
639-
return completeResult(false);
640-
641-
// If they both have trailing closures, compare those separately.
642-
bool compareTrailingClosureParamsSeparately = false;
643-
if (numParams1 > 0 && numParams2 > 0 &&
644-
params1.back().getParameterType()->is<AnyFunctionType>() &&
645-
params2.back().getParameterType()->is<AnyFunctionType>()) {
646-
compareTrailingClosureParamsSeparately = true;
647-
}
636+
// TODO: We should consider merging these two branches together in
637+
// the future instead of re-implementing `matchCallArguments`.
638+
if (containsPackExpansionType(params1) ||
639+
containsPackExpansionType(params2)) {
640+
ParameterListInfo paramListInfo(params2, decl2, decl2->hasCurriedSelf());
641+
642+
MatchCallArgumentListener listener;
643+
SmallVector<AnyFunctionType::Param> args(params1);
644+
auto matching = matchCallArguments(
645+
args, params2, paramListInfo, llvm::None,
646+
/*allowFixes=*/false, listener, TrailingClosureMatching::Forward);
647+
648+
if (!matching)
649+
return completeResult(false);
650+
651+
for (unsigned paramIdx = 0,
652+
numParams = matching->parameterBindings.size();
653+
paramIdx != numParams; ++paramIdx) {
654+
const auto &param = params2[paramIdx];
655+
auto paramTy = param.getOldType();
656+
657+
if (paramListInfo.isVariadicGenericParameter(paramIdx) &&
658+
isPackExpansionType(paramTy)) {
659+
SmallVector<Type, 2> argTypes;
660+
for (auto argIdx : matching->parameterBindings[paramIdx]) {
661+
// Don't prefer `T...` over `repeat each T`.
662+
if (args[argIdx].isVariadic())
663+
return completeResult(false);
664+
argTypes.push_back(args[argIdx].getPlainType());
665+
}
648666

649-
auto maybeAddSubtypeConstraint =
650-
[&](const AnyFunctionType::Param &param1,
651-
const AnyFunctionType::Param &param2) -> bool {
652-
// If one parameter is variadic and the other is not...
653-
if (param1.isVariadic() != param2.isVariadic()) {
654-
// If the first parameter is the variadic one, it's not
655-
// more specialized.
656-
if (param1.isVariadic())
657-
return false;
667+
auto *argPack = PackType::get(cs.getASTContext(), argTypes);
668+
cs.addConstraint(ConstraintKind::Subtype,
669+
PackExpansionType::get(argPack, argPack), paramTy,
670+
locator);
671+
continue;
672+
}
673+
674+
for (auto argIdx : matching->parameterBindings[paramIdx]) {
675+
const auto &arg = args[argIdx];
676+
// Always prefer non-variadic version when possible.
677+
if (arg.isVariadic())
678+
return completeResult(false);
658679

659-
fewerEffectiveParameters = true;
680+
cs.addConstraint(ConstraintKind::Subtype, arg.getOldType(),
681+
paramTy, locator);
682+
}
683+
}
684+
} else {
685+
unsigned numParams1 = params1.size();
686+
unsigned numParams2 = params2.size();
687+
688+
if (numParams1 > numParams2)
689+
return completeResult(false);
690+
691+
// If they both have trailing closures, compare those separately.
692+
bool compareTrailingClosureParamsSeparately = false;
693+
if (numParams1 > 0 && numParams2 > 0 &&
694+
params1.back().getParameterType()->is<AnyFunctionType>() &&
695+
params2.back().getParameterType()->is<AnyFunctionType>()) {
696+
compareTrailingClosureParamsSeparately = true;
660697
}
661698

662-
Type paramType1 = getAdjustedParamType(param1);
663-
Type paramType2 = getAdjustedParamType(param2);
699+
auto maybeAddSubtypeConstraint =
700+
[&](const AnyFunctionType::Param &param1,
701+
const AnyFunctionType::Param &param2) -> bool {
702+
// If one parameter is variadic and the other is not...
703+
if (param1.isVariadic() != param2.isVariadic()) {
704+
// If the first parameter is the variadic one, it's not
705+
// more specialized.
706+
if (param1.isVariadic())
707+
return false;
708+
709+
fewerEffectiveParameters = true;
710+
}
664711

665-
// Check whether the first parameter is a subtype of the second.
666-
cs.addConstraint(ConstraintKind::Subtype, paramType1, paramType2,
667-
locator);
668-
return true;
669-
};
670-
671-
auto pairMatcher = [&](unsigned idx1, unsigned idx2) -> bool {
672-
// Emulate behavior from when IUO was a type, where IUOs
673-
// were considered subtypes of plain optionals, but not
674-
// vice-versa. This wouldn't normally happen, but there are
675-
// cases where we can rename imported APIs so that we have a
676-
// name collision, and where the parameter type(s) are the
677-
// same except for details of the kind of optional declared.
678-
auto param1IsIUO = paramIsIUO(decl1, idx1);
679-
auto param2IsIUO = paramIsIUO(decl2, idx2);
680-
if (param2IsIUO && !param1IsIUO)
681-
return false;
682-
683-
if (!maybeAddSubtypeConstraint(params1[idx1], params2[idx2]))
684-
return false;
712+
Type paramType1 = getAdjustedParamType(param1);
713+
Type paramType2 = getAdjustedParamType(param2);
685714

686-
return true;
687-
};
715+
// Check whether the first parameter is a subtype of the second.
716+
cs.addConstraint(ConstraintKind::Subtype, paramType1, paramType2,
717+
locator);
718+
return true;
719+
};
688720

689-
ParameterListInfo paramInfo(params2, decl2, decl2->hasCurriedSelf());
690-
auto params2ForMatching = params2;
691-
if (compareTrailingClosureParamsSeparately) {
692-
--numParams1;
693-
params2ForMatching = params2.drop_back();
694-
}
721+
auto pairMatcher = [&](unsigned idx1, unsigned idx2) -> bool {
722+
// Emulate behavior from when IUO was a type, where IUOs
723+
// were considered subtypes of plain optionals, but not
724+
// vice-versa. This wouldn't normally happen, but there are
725+
// cases where we can rename imported APIs so that we have a
726+
// name collision, and where the parameter type(s) are the
727+
// same except for details of the kind of optional declared.
728+
auto param1IsIUO = paramIsIUO(decl1, idx1);
729+
auto param2IsIUO = paramIsIUO(decl2, idx2);
730+
if (param2IsIUO && !param1IsIUO)
731+
return false;
732+
733+
if (!maybeAddSubtypeConstraint(params1[idx1], params2[idx2]))
734+
return false;
695735

696-
InputMatcher IM(params2ForMatching, paramInfo);
697-
if (IM.match(numParams1, pairMatcher) != InputMatcher::IM_Succeeded)
698-
return completeResult(false);
736+
return true;
737+
};
738+
739+
ParameterListInfo paramInfo(params2, decl2, decl2->hasCurriedSelf());
740+
auto params2ForMatching = params2;
741+
if (compareTrailingClosureParamsSeparately) {
742+
--numParams1;
743+
params2ForMatching = params2.drop_back();
744+
}
745+
746+
InputMatcher IM(params2ForMatching, paramInfo);
747+
if (IM.match(numParams1, pairMatcher) != InputMatcher::IM_Succeeded)
748+
return completeResult(false);
699749

700-
fewerEffectiveParameters |= (IM.getNumSkippedParameters() != 0);
750+
fewerEffectiveParameters |= (IM.getNumSkippedParameters() != 0);
701751

702-
if (compareTrailingClosureParamsSeparately)
703-
if (!maybeAddSubtypeConstraint(params1.back(), params2.back()))
704-
knownNonSubtype = true;
752+
if (compareTrailingClosureParamsSeparately)
753+
if (!maybeAddSubtypeConstraint(params1.back(), params2.back()))
754+
knownNonSubtype = true;
755+
}
705756
}
706757

707758
if (!knownNonSubtype) {

lib/Sema/CSSimplify.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ static llvm::Optional<unsigned> scoreParamAndArgNameTypo(StringRef paramName,
117117
return dist;
118118
}
119119

120-
static bool isPackExpansionType(Type type) {
120+
bool constraints::isPackExpansionType(Type type) {
121121
if (type->is<PackExpansionType>())
122122
return true;
123123

@@ -152,13 +152,13 @@ Type constraints::getPatternTypeOfSingleUnlabeledPackExpansionTuple(Type type) {
152152
return {};
153153
}
154154

155-
static bool containsPackExpansionType(ArrayRef<AnyFunctionType::Param> params) {
155+
bool constraints::containsPackExpansionType(ArrayRef<AnyFunctionType::Param> params) {
156156
return llvm::any_of(params, [&](const auto &param) {
157157
return isPackExpansionType(param.getPlainType());
158158
});
159159
}
160160

161-
static bool containsPackExpansionType(TupleType *tuple) {
161+
bool constraints::containsPackExpansionType(TupleType *tuple) {
162162
return llvm::any_of(tuple->getElements(), [&](const auto &elt) {
163163
return isPackExpansionType(elt.getType());
164164
});

test/Constraints/pack-expansion-expressions.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,3 +680,30 @@ do {
680680
}
681681
}
682682
}
683+
684+
// rdar://112029630 - incorrect variadic generic overload ranking
685+
do {
686+
func test1<T>(_: T...) {}
687+
// expected-note@-1 {{found this candidate}}
688+
func test1<each T>(_: repeat each T) {}
689+
// expected-note@-1 {{found this candidate}}
690+
691+
test1(1, 2, 3) // expected-error {{ambiguous use of 'test1'}}
692+
test1(1, "a") // Ok
693+
694+
func test2<each T>(_: repeat each T) {}
695+
// expected-note@-1 {{found this candidate}}
696+
func test2<each T>(vals: repeat each T) {}
697+
// expected-note@-1 {{found this candidate}}
698+
699+
test2() // expected-error {{ambiguous use of 'test2'}}
700+
701+
func test_different_requirements<A: BinaryInteger & StringProtocol>(_ a: A) {
702+
func test3<each T: BinaryInteger>(str: String, _: repeat each T) {}
703+
// expected-note@-1 {{found this candidate}}
704+
func test3<each U: StringProtocol>(str: repeat each U) {}
705+
// expected-note@-1 {{found this candidate}}
706+
707+
test3(str: "", a, a) // expected-error {{ambiguous use of 'test3'}}
708+
}
709+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// RUN: %target-swift-emit-silgen %s -verify -swift-version 5 -disable-availability-checking | %FileCheck %s
2+
3+
// CHECK-LABEL: sil hidden [ossa] @$s33variadic_generic_overload_ranking05test_d15_concrete_over_A0yyF
4+
func test_ranking_concrete_over_variadic() {
5+
func test() {}
6+
func test<T>(_: T) {}
7+
func test<each T>(_: repeat each T) {}
8+
9+
// CHECK: // function_ref test #1 () in test_ranking_concrete_over_variadic()
10+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_d15_concrete_over_A0yyF0E0L_yyF : $@convention(thin) () -> ()
11+
test()
12+
// CHECK: // function_ref test #2 <A>(_:) in test_ranking_concrete_over_variadic()
13+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_d15_concrete_over_A0yyF0E0L0_yyxlF : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> ()
14+
test(1)
15+
// CHECK: // function_ref test #3 <each A>(_:) in test_ranking_concrete_over_variadic()
16+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_d15_concrete_over_A0yyF0E0L1_yyxxQpRvzlF : $@convention(thin) <each τ_0_0> (@pack_guaranteed Pack{repeat each τ_0_0}) -> ()
17+
test(1, "")
18+
}
19+
20+
// CHECK-LABEL: sil hidden [ossa] @$s33variadic_generic_overload_ranking05test_d1_A31_over_concrete_with_conversionsyyF
21+
func test_ranking_variadic_over_concrete_with_conversions() {
22+
func test<T>(_: T, _: Any) {}
23+
func test<each T>(_: repeat each T) {}
24+
25+
// CHECK: // function_ref test #2 <each A>(_:) in test_ranking_variadic_over_concrete_with_conversions()
26+
// CHECK-LABEL: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_d1_A31_over_concrete_with_conversionsyyF0E0L0_yyxxQpRvzlF : $@convention(thin) <each τ_0_0> (@pack_guaranteed Pack{repeat each τ_0_0}) -> ()
27+
test(1, "")
28+
29+
func test_disfavored<T>(_: T, _: Any) {}
30+
@_disfavoredOverload
31+
func test_disfavored<each T>(_: repeat each T) {}
32+
33+
// CHECK: // function_ref test_disfavored #1 <A>(_:_:) in test_ranking_variadic_over_concrete_with_conversions()
34+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_d1_A31_over_concrete_with_conversionsyyF0E11_disfavoredL_yyx_yptlF : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, @in_guaranteed Any) -> ()
35+
test_disfavored(2, "a")
36+
}
37+
38+
// CHECK-LABEL: sil hidden [ossa] @$s33variadic_generic_overload_ranking05test_d1_a1_B13_over_regularyyF : $@convention(thin) () -> ()
39+
func test_ranking_variadic_generic_over_regular() {
40+
func test1<T>(_: T...) {}
41+
func test1<each T>(_: repeat each T) {}
42+
43+
// CHECK: // function_ref test1 #2 <each A>(_:) in test_ranking_variadic_generic_over_regular()
44+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_d1_a1_B13_over_regularyyF5test1L0_yyxxQpRvzlF : $@convention(thin) <each τ_0_0> (@pack_guaranteed Pack{repeat each τ_0_0}) -> ()
45+
test1(1, "a")
46+
}
47+
48+
protocol P {
49+
}
50+
51+
// CHECK-LABEL: sil hidden [ossa] @$s33variadic_generic_overload_ranking05test_D25_with_multiple_expansionsyyF
52+
func test_ranking_with_multiple_expansions() {
53+
struct Empty : P {}
54+
struct Tuple<T> : P {
55+
init(_: T) {}
56+
}
57+
58+
struct Builder {
59+
static func build() -> Empty { Empty() }
60+
static func build<T: P>(_ a: T) -> T { a }
61+
static func build<T: P>(_ a: T, _ b: T) -> Tuple<(T, T)> { Tuple((a, b)) }
62+
static func build<each T: P>(_ v: repeat each T) -> Tuple<(repeat each T)> { Tuple((repeat each v)) }
63+
64+
static func otherBuild<T: P, U: P>(a: T, b: U) {}
65+
static func otherBuild<each T: P, each U: P>(a: repeat each T, b: repeat each U) {}
66+
}
67+
68+
// CHECK: // function_ref static otherBuild<A, B>(a:b:) in Builder #1 in test_ranking_with_multiple_expansions()
69+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_D25_with_multiple_expansionsyyF7BuilderL_V10otherBuild1a1byx_q_tAA1PRzAaHR_r0_lFZ : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 : P, τ_0_1 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1, @thin Builder.Type) -> ()
70+
Builder.otherBuild(a: Empty(), b: Empty())
71+
// CHECK: // function_ref static otherBuild<A, B>(a:b:) in Builder #1 in test_ranking_with_multiple_expansions()
72+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_D25_with_multiple_expansionsyyF7BuilderL_V10otherBuild1a1byx_q_tAA1PRzAaHR_r0_lFZ : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 : P, τ_0_1 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1, @thin Builder.Type) -> ()
73+
Builder.otherBuild(a: Empty(), b: Tuple<Int>(42))
74+
75+
// CHECK: // function_ref static build() in Builder #1 in test_ranking_with_multiple_expansions()
76+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_D25_with_multiple_expansionsyyF7BuilderL_V5buildAaByyF5EmptyL_VyFZ : $@convention(method) (@thin Builder.Type) -> Empty
77+
_ = Builder.build()
78+
// CHECK: // function_ref static build<A>(_:) in Builder #1 in test_ranking_with_multiple_expansions()
79+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_D25_with_multiple_expansionsyyF7BuilderL_V5buildyxxAA1PRzlFZ : $@convention(method) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @thin Builder.Type) -> @out τ_0_0
80+
_ = Builder.build(Empty())
81+
// CHECK: // function_ref static build<A>(_:_:) in Builder #1 in test_ranking_with_multiple_expansions()
82+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_D25_with_multiple_expansionsyyF7BuilderL_V5buildyAaByyF5TupleL_Vyx_xtGx_xtAA1PRzlFZ : $@convention(method) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thin Builder.Type) -> Tuple<(τ_0_0, τ_0_0)>
83+
_ = Builder.build(Empty(), Empty())
84+
// CHECK: // function_ref static build<each A>(_:) in Builder #1 in test_ranking_with_multiple_expansions()
85+
// CHECK-NEXT: {{.*}} = function_ref @$s33variadic_generic_overload_ranking05test_D25_with_multiple_expansionsyyF7BuilderL_V5buildyAaByyF5TupleL_VyxxQp_tGxxQpRvzAA1PRzlFZ : $@convention(method) <each τ_0_0 where repeat each τ_0_0 : P> (@pack_guaranteed Pack{repeat each τ_0_0}, @thin Builder.Type) -> Tuple<(repeat each τ_0_0)>
86+
_ = Builder.build(Empty(), Tuple<(Int, String)>((42, "")))
87+
}

0 commit comments

Comments
 (0)