@@ -1066,6 +1066,16 @@ class AssociatedTypeInference {
1066
1066
bool isBetterSolution (const InferredTypeWitnessesSolution &first,
1067
1067
const InferredTypeWitnessesSolution &second);
1068
1068
1069
+ // / Find the best solution.
1070
+ // /
1071
+ // / \param solutions All of the solutions to consider. On success,
1072
+ // / this will contain only the best solution.
1073
+ // /
1074
+ // / \returns \c false if there was a single best solution,
1075
+ // / \c true if no single best solution exists.
1076
+ bool findBestSolution (
1077
+ SmallVectorImpl<InferredTypeWitnessesSolution> &solutions);
1078
+
1069
1079
// / Emit a diagnostic for the case where there are no solutions at all
1070
1080
// / to consider.
1071
1081
// /
@@ -1902,19 +1912,20 @@ AssociatedTypeInference::inferTypeWitnessesViaAssociatedType(
1902
1912
else
1903
1913
continue ;
1904
1914
1915
+ if (result.empty ()) {
1916
+ // If we found at least one default candidate, we must allow for the
1917
+ // possibility that no default is chosen by adding a tautological witness
1918
+ // to our disjunction.
1919
+ result.push_back (InferredAssociatedTypesByWitness ());
1920
+ }
1921
+
1905
1922
// Add this result.
1906
1923
InferredAssociatedTypesByWitness inferred;
1907
1924
inferred.Witness = typeDecl;
1908
1925
inferred.Inferred .push_back ({assocType, witnessType});
1909
1926
result.push_back (std::move (inferred));
1910
1927
}
1911
1928
1912
- if (!result.empty ()) {
1913
- // If we found at least one default candidate, we must allow for the
1914
- // possibility that no default is chosen by adding a tautological witness
1915
- // to our disjunction.
1916
- result.push_back (InferredAssociatedTypesByWitness ());
1917
- }
1918
1929
return result;
1919
1930
}
1920
1931
@@ -3130,6 +3141,35 @@ void AssociatedTypeInference::findSolutionsRec(
3130
3141
known->first = replaced;
3131
3142
}
3132
3143
3144
+ if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
3145
+ // Check whether our current solution matches the given solution.
3146
+ auto matchesSolution =
3147
+ [&](const InferredTypeWitnessesSolution &solution) {
3148
+ for (const auto &existingTypeWitness : solution.TypeWitnesses ) {
3149
+ auto typeWitness = typeWitnesses.begin (existingTypeWitness.first );
3150
+ if (!typeWitness->first ->isEqual (existingTypeWitness.second .first ))
3151
+ return false ;
3152
+ }
3153
+
3154
+ return true ;
3155
+ };
3156
+
3157
+ // If we've seen this solution already, bail out; there's no point in
3158
+ // checking further.
3159
+ if (llvm::any_of (solutions, matchesSolution)) {
3160
+ LLVM_DEBUG (llvm::dbgs () << std::string (valueWitnesses.size (), ' +' )
3161
+ << " + Duplicate valid solution found\n " ;);
3162
+ ++NumDuplicateSolutionStates;
3163
+ return ;
3164
+ }
3165
+ if (llvm::any_of (nonViableSolutions, matchesSolution)) {
3166
+ LLVM_DEBUG (llvm::dbgs () << std::string (valueWitnesses.size (), ' +' )
3167
+ << " + Duplicate invalid solution found\n " ;);
3168
+ ++NumDuplicateSolutionStates;
3169
+ return ;
3170
+ }
3171
+ }
3172
+
3133
3173
// / Check the current set of type witnesses.
3134
3174
bool invalid = checkCurrentTypeWitnesses (valueWitnesses);
3135
3175
@@ -3156,6 +3196,8 @@ void AssociatedTypeInference::findSolutionsRec(
3156
3196
= numValueWitnessesInProtocolExtensions;
3157
3197
3158
3198
// We fold away non-viable solutions that have the same type witnesses.
3199
+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
3200
+
3159
3201
if (invalid) {
3160
3202
if (llvm::find (nonViableSolutions, solution) != nonViableSolutions.end ()) {
3161
3203
LLVM_DEBUG (llvm::dbgs () << std::string (valueWitnesses.size (), ' +' )
@@ -3168,6 +3210,22 @@ void AssociatedTypeInference::findSolutionsRec(
3168
3210
return ;
3169
3211
}
3170
3212
3213
+ }
3214
+
3215
+ if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
3216
+
3217
+ auto &solutionList = invalid ? nonViableSolutions : solutions;
3218
+ solutionList.push_back (solution);
3219
+
3220
+ // If this solution was clearly better than the previous best solution,
3221
+ // swap them.
3222
+ if (solutionList.back ().NumValueWitnessesInProtocolExtensions
3223
+ < solutionList.front ().NumValueWitnessesInProtocolExtensions ) {
3224
+ std::swap (solutionList.front (), solutionList.back ());
3225
+ }
3226
+
3227
+ } else {
3228
+
3171
3229
// For valid solutions, we want to find the best solution if one exists.
3172
3230
// We maintain the invariant that no viable solution is clearly worse than
3173
3231
// any other viable solution. If multiple viable solutions remain after
@@ -3197,6 +3255,8 @@ void AssociatedTypeInference::findSolutionsRec(
3197
3255
});
3198
3256
3199
3257
solutions.push_back (std::move (solution));
3258
+
3259
+ }
3200
3260
return ;
3201
3261
}
3202
3262
@@ -3565,6 +3625,58 @@ bool AssociatedTypeInference::isBetterSolution(
3565
3625
return firstBetter;
3566
3626
}
3567
3627
3628
+ bool AssociatedTypeInference::findBestSolution (
3629
+ SmallVectorImpl<InferredTypeWitnessesSolution> &solutions) {
3630
+ if (solutions.empty ()) return true ;
3631
+ if (solutions.size () == 1 ) return false ;
3632
+
3633
+ // The solution at the front has the smallest number of value witnesses found
3634
+ // in protocol extensions, by construction.
3635
+ unsigned bestNumValueWitnessesInProtocolExtensions
3636
+ = solutions.front ().NumValueWitnessesInProtocolExtensions ;
3637
+
3638
+ // Erase any solutions with more value witnesses in protocol
3639
+ // extensions than the best.
3640
+ solutions.erase (
3641
+ std::remove_if (solutions.begin (), solutions.end (),
3642
+ [&](const InferredTypeWitnessesSolution &solution) {
3643
+ return solution.NumValueWitnessesInProtocolExtensions >
3644
+ bestNumValueWitnessesInProtocolExtensions;
3645
+ }),
3646
+ solutions.end ());
3647
+
3648
+ // If we're down to one solution, success!
3649
+ if (solutions.size () == 1 ) return false ;
3650
+
3651
+ // Find a solution that's at least as good as the solutions that follow it.
3652
+ unsigned bestIdx = 0 ;
3653
+ for (unsigned i = 1 , n = solutions.size (); i != n; ++i) {
3654
+ if (isBetterSolution (solutions[i], solutions[bestIdx]))
3655
+ bestIdx = i;
3656
+ }
3657
+
3658
+ // Make sure that solution is better than any of the other solutions.
3659
+ bool ambiguous = false ;
3660
+ for (unsigned i = 1 , n = solutions.size (); i != n; ++i) {
3661
+ if (i != bestIdx && !isBetterSolution (solutions[bestIdx], solutions[i])) {
3662
+ ambiguous = true ;
3663
+ break ;
3664
+ }
3665
+ }
3666
+
3667
+ // If the result was ambiguous, fail.
3668
+ if (ambiguous) {
3669
+ assert (solutions.size () != 1 && " should have succeeded somewhere above?" );
3670
+ return true ;
3671
+
3672
+ }
3673
+ // Keep the best solution, erasing all others.
3674
+ if (bestIdx != 0 )
3675
+ solutions[0 ] = std::move (solutions[bestIdx]);
3676
+ solutions.erase (solutions.begin () + 1 , solutions.end ());
3677
+ return false ;
3678
+ }
3679
+
3568
3680
namespace {
3569
3681
// / A failed type witness binding.
3570
3682
struct FailedTypeWitness {
@@ -3971,7 +4083,9 @@ auto AssociatedTypeInference::solve()
3971
4083
}
3972
4084
3973
4085
// Happy case: we found exactly one unique viable solution.
3974
- if (solutions.size () == 1 ) {
4086
+ if (!findBestSolution (solutions)) {
4087
+ assert (solutions.size () == 1 && " Not a unique best solution?" );
4088
+
3975
4089
// Form the resulting solution.
3976
4090
auto &typeWitnesses = solutions.front ().TypeWitnesses ;
3977
4091
for (auto assocType : unresolvedAssocTypes) {
0 commit comments