@@ -3709,6 +3709,160 @@ static bool diagnoseAmbiguity(
37093709 return diagnosed;
37103710}
37113711
3712+ using FixInContext = std::pair<const Solution *, const ConstraintFix *>;
3713+
3714+ // Attempts to diagnose function call ambiguities of types inferred for a result
3715+ // generic parameter from contextual type and a closure argument that
3716+ // conflicting infer a different type for the same argument. Example:
3717+ // func callit<T>(_ f: () -> T) -> T {
3718+ // f()
3719+ // }
3720+ //
3721+ // func context() -> Int {
3722+ // callit {
3723+ // print("hello")
3724+ // }
3725+ // }
3726+ // Where generic argument `T` can be inferred both as `Int` from contextual
3727+ // result and `Void` from the closure argument result.
3728+ static bool diagnoseContextualFunctionCallGenericAmbiguity (
3729+ ConstraintSystem &cs, ArrayRef<FixInContext> contextualFixes,
3730+ ArrayRef<FixInContext> allFixes) {
3731+
3732+ if (contextualFixes.empty ())
3733+ return false ;
3734+
3735+ auto contextualFix = contextualFixes.front ();
3736+ if (!std::all_of (contextualFixes.begin () + 1 , contextualFixes.end (),
3737+ [&contextualFix](FixInContext fix) {
3738+ return fix.second ->getLocator () ==
3739+ contextualFix.second ->getLocator ();
3740+ }))
3741+ return false ;
3742+
3743+ auto fixLocator = contextualFix.second ->getLocator ();
3744+ auto contextualAnchor = fixLocator->getAnchor ();
3745+ auto *AE = getAsExpr<ApplyExpr>(contextualAnchor);
3746+ // All contextual failures anchored on the same function call.
3747+ if (!AE)
3748+ return false ;
3749+
3750+ auto fnLocator = cs.getConstraintLocator (AE->getSemanticFn ());
3751+ auto overload = contextualFix.first ->getOverloadChoiceIfAvailable (fnLocator);
3752+ if (!overload)
3753+ return false ;
3754+
3755+ auto applyFnType = overload->openedType ->castTo <FunctionType>();
3756+ auto resultTypeVar = applyFnType->getResult ()->getAs <TypeVariableType>();
3757+ if (!resultTypeVar)
3758+ return false ;
3759+
3760+ auto *GP = resultTypeVar->getImpl ().getGenericParameter ();
3761+ if (!GP)
3762+ return false ;
3763+
3764+ auto applyLoc =
3765+ cs.getConstraintLocator (AE, {LocatorPathElt::ApplyArgument ()});
3766+ auto argMatching =
3767+ contextualFix.first ->argumentMatchingChoices .find (applyLoc);
3768+ if (argMatching == contextualFix.first ->argumentMatchingChoices .end ()) {
3769+ return false ;
3770+ }
3771+
3772+ auto typeParamResultInvolvesTypeVar = [&cs, &applyFnType, &argMatching](
3773+ unsigned argIdx,
3774+ TypeVariableType *typeVar) {
3775+ auto argParamMatch = argMatching->second .parameterBindings [argIdx];
3776+ auto param = applyFnType->getParams ()[argParamMatch.front ()];
3777+ if (param.isVariadic ()) {
3778+ auto paramType = param.getParameterType ();
3779+ // Variadic parameter is constructed as an ArraySliceType(which is
3780+ // just sugared type for a bound generic) with the closure type as
3781+ // element.
3782+ auto baseType = paramType->getDesugaredType ()->castTo <BoundGenericType>();
3783+ auto paramFnType = baseType->getGenericArgs ()[0 ]->castTo <FunctionType>();
3784+ return cs.typeVarOccursInType (typeVar, paramFnType->getResult ());
3785+ }
3786+ auto paramFnType = param.getParameterType ()->castTo <FunctionType>();
3787+ return cs.typeVarOccursInType (typeVar, paramFnType->getResult ());
3788+ };
3789+
3790+ llvm::SmallVector<ClosureExpr *, 2 > closureArguments;
3791+ // A single closure argument.
3792+ if (auto *closure =
3793+ getAsExpr<ClosureExpr>(AE->getArg ()->getSemanticsProvidingExpr ())) {
3794+ if (typeParamResultInvolvesTypeVar (/* paramIdx=*/ 0 , resultTypeVar))
3795+ closureArguments.push_back (closure);
3796+ } else if (auto *argTuple = getAsExpr<TupleExpr>(AE->getArg ())) {
3797+ for (auto i : indices (argTuple->getElements ())) {
3798+ auto arg = argTuple->getElements ()[i];
3799+ auto *closure = getAsExpr<ClosureExpr>(arg);
3800+ if (closure &&
3801+ typeParamResultInvolvesTypeVar (/* paramIdx=*/ i, resultTypeVar)) {
3802+ closureArguments.push_back (closure);
3803+ }
3804+ }
3805+ }
3806+
3807+ // If no closure result's involves the generic parameter, just bail because we
3808+ // won't find a conflict.
3809+ if (closureArguments.empty ())
3810+ return false ;
3811+
3812+ // At least one closure where result type involves the generic parameter.
3813+ // So let's try to collect the set of fixed types for the generic parameter
3814+ // from all the closure contextual fix/solutions and if there are more than
3815+ // one fixed type diagnose it.
3816+ llvm::SmallSetVector<Type, 4 > genericParamInferredTypes;
3817+ for (auto &fix : contextualFixes)
3818+ genericParamInferredTypes.insert (fix.first ->getFixedType (resultTypeVar));
3819+
3820+ if (llvm::all_of (allFixes, [&](FixInContext fix) {
3821+ auto fixLocator = fix.second ->getLocator ();
3822+ if (fixLocator->isForContextualType ())
3823+ return true ;
3824+
3825+ if (!(fix.second ->getKind () == FixKind::ContextualMismatch ||
3826+ fix.second ->getKind () == FixKind::AllowTupleTypeMismatch))
3827+ return false ;
3828+
3829+ auto anchor = fixLocator->getAnchor ();
3830+ if (!(anchor == contextualAnchor ||
3831+ fixLocator->isLastElement <LocatorPathElt::ClosureResult>() ||
3832+ fixLocator->isLastElement <LocatorPathElt::ClosureBody>()))
3833+ return false ;
3834+
3835+ genericParamInferredTypes.insert (
3836+ fix.first ->getFixedType (resultTypeVar));
3837+ return true ;
3838+ })) {
3839+
3840+ if (genericParamInferredTypes.size () != 2 )
3841+ return false ;
3842+
3843+ auto &DE = cs.getASTContext ().Diags ;
3844+ llvm::SmallString<64 > arguments;
3845+ llvm::raw_svector_ostream OS (arguments);
3846+ interleave (
3847+ genericParamInferredTypes,
3848+ [&](Type argType) { OS << " '" << argType << " '" ; },
3849+ [&OS] { OS << " vs. " ; });
3850+
3851+ DE.diagnose (AE->getLoc (), diag::conflicting_arguments_for_generic_parameter,
3852+ GP, OS.str ());
3853+
3854+ DE.diagnose (AE->getLoc (),
3855+ diag::generic_parameter_inferred_from_result_context, GP,
3856+ genericParamInferredTypes.back ());
3857+ DE.diagnose (closureArguments.front ()->getStartLoc (),
3858+ diag::generic_parameter_inferred_from_closure, GP,
3859+ genericParamInferredTypes.front ());
3860+
3861+ return true ;
3862+ }
3863+ return false ;
3864+ }
3865+
37123866bool ConstraintSystem::diagnoseAmbiguityWithFixes (
37133867 SmallVectorImpl<Solution> &solutions) {
37143868 if (solutions.empty ())
@@ -3761,16 +3915,15 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
37613915 // d. Diagnose remaining (uniqued based on kind + locator) fixes
37623916 // iff they appear in all of the solutions.
37633917
3764- using Fix = std::pair<const Solution *, const ConstraintFix *>;
3765-
3766- llvm::SmallSetVector<Fix, 4 > fixes;
3918+ llvm::SmallSetVector<FixInContext, 4 > fixes;
37673919 for (auto &solution : solutions) {
37683920 for (auto *fix : solution.Fixes )
37693921 fixes.insert ({&solution, fix});
37703922 }
37713923
3772- llvm::MapVector<ConstraintLocator *, SmallVector<Fix, 4 >> fixesByCallee;
3773- llvm::SmallVector<Fix, 4 > contextualFixes;
3924+ llvm::MapVector<ConstraintLocator *, SmallVector<FixInContext, 4 >>
3925+ fixesByCallee;
3926+ llvm::SmallVector<FixInContext, 4 > contextualFixes;
37743927
37753928 for (const auto &entry : fixes) {
37763929 const auto &solution = *entry.first ;
@@ -3790,7 +3943,7 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
37903943 bool diagnosed = false ;
37913944
37923945 // All of the fixes which have been considered already.
3793- llvm::SmallSetVector<Fix , 4 > consideredFixes;
3946+ llvm::SmallSetVector<FixInContext , 4 > consideredFixes;
37943947
37953948 for (const auto &ambiguity : solutionDiff.overloads ) {
37963949 auto fixes = fixesByCallee.find (ambiguity.locator );
@@ -3813,7 +3966,8 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
38133966 // overload choices.
38143967 fixes.set_subtract (consideredFixes);
38153968
3816- llvm::MapVector<std::pair<FixKind, ConstraintLocator *>, SmallVector<Fix, 4 >>
3969+ llvm::MapVector<std::pair<FixKind, ConstraintLocator *>,
3970+ SmallVector<FixInContext, 4 >>
38173971 fixesByKind;
38183972
38193973 for (const auto &entry : fixes) {
@@ -3837,6 +3991,10 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
38373991 }
38383992 }
38393993
3994+ if (!diagnosed && diagnoseContextualFunctionCallGenericAmbiguity (
3995+ *this , contextualFixes, fixes.getArrayRef ()))
3996+ return true ;
3997+
38403998 return diagnosed;
38413999}
38424000
0 commit comments