Skip to content

Commit c02f30f

Browse files
[Sema] Diagnose generic parameter contextual inference ambiguity between function call result and closure argument
1 parent 528764c commit c02f30f

File tree

2 files changed

+143
-2
lines changed

2 files changed

+143
-2
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,10 @@ ERROR(cannot_convert_argument_value_generic,none,
380380
ERROR(conflicting_arguments_for_generic_parameter,none,
381381
"conflicting arguments to generic parameter %0 (%1)",
382382
(Type, StringRef))
383+
ERROR(conflicting_inferred_generic_parameter_result_and_closure,none,
384+
"conflicting inferred types from call result and closure "
385+
"argument to generic parameter %0 (%1)",
386+
(Type, StringRef))
383387

384388
// @_nonEphemeral conversion diagnostics
385389
ERROR(cannot_pass_type_to_non_ephemeral,none,

lib/Sema/ConstraintSystem.cpp

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3709,6 +3709,141 @@ static bool diagnoseAmbiguity(
37093709
return diagnosed;
37103710
}
37113711

3712+
using Fix = std::pair<const Solution *, const ConstraintFix *>;
3713+
3714+
// Attempts to diagnose function call ambiguities of types inferred for a result
3715+
// generic parameter from contextual type and a closure argument that
3716+
// conflicting infer a different type for the same argument. Example:
3717+
// func callit<T>(_ f: () -> T) -> T {
3718+
// f()
3719+
// }
3720+
//
3721+
// func context() -> Int {
3722+
// callit {
3723+
// print("hello")
3724+
// }
3725+
// }
3726+
// Where generic argument `T` can be inferred both as `Int` from contextual
3727+
// result and `Void` from the closure argument result.
3728+
static bool
3729+
diagnoseContextualFunctionCallGenericAmbiguity(ConstraintSystem &cs,
3730+
ArrayRef<Fix> contextualFixes,
3731+
ArrayRef<Fix> allFixes) {
3732+
3733+
if (contextualFixes.empty())
3734+
return false;
3735+
3736+
auto contextualFix = contextualFixes.front();
3737+
if (!std::all_of(contextualFixes.begin() + 1, contextualFixes.end(),
3738+
[&contextualFix](Fix fix) {
3739+
return fix.second->getLocator() ==
3740+
contextualFix.second->getLocator();
3741+
}))
3742+
return false;
3743+
3744+
auto fixLocator = contextualFix.second->getLocator();
3745+
auto contextualAnchor = fixLocator->getAnchor();
3746+
auto *AE = getAsExpr<ApplyExpr>(contextualAnchor);
3747+
// All contextual failures anchored on the same function call.
3748+
if (!AE)
3749+
return false;
3750+
3751+
auto fnLocator = cs.getConstraintLocator(AE->getSemanticFn());
3752+
auto overload = contextualFix.first->getOverloadChoiceIfAvailable(fnLocator);
3753+
if (!overload)
3754+
return false;
3755+
3756+
auto applyFnType = overload->openedType->castTo<FunctionType>();
3757+
auto resultTypeVar = applyFnType->getResult()->getAs<TypeVariableType>();
3758+
if (!resultTypeVar)
3759+
return false;
3760+
3761+
auto *GP = resultTypeVar->getImpl().getGenericParameter();
3762+
if (!GP)
3763+
return false;
3764+
3765+
auto typeParamResultInvolvesTypeVar =
3766+
[&applyFnType](unsigned paramIdx, TypeVariableType *typeVar) {
3767+
auto param = applyFnType->getParams()[paramIdx];
3768+
auto paramType = param.getParameterType()->castTo<FunctionType>();
3769+
3770+
bool contains = false;
3771+
paramType->getResult().visit([&](Type ty) {
3772+
if (ty->isEqual(typeVar))
3773+
contains = true;
3774+
});
3775+
return contains;
3776+
};
3777+
3778+
llvm::SmallVector<ClosureExpr *, 4> closureArguments;
3779+
// A single closure argument.
3780+
if (auto *closure =
3781+
getAsExpr<ClosureExpr>(AE->getArg()->getSemanticsProvidingExpr())) {
3782+
if (typeParamResultInvolvesTypeVar(/*paramIdx=*/0, resultTypeVar))
3783+
closureArguments.push_back(closure);
3784+
} else if (auto *argTuple = getAsExpr<TupleExpr>(AE->getArg())) {
3785+
for (auto i : indices(argTuple->getElements())) {
3786+
auto arg = argTuple->getElements()[i];
3787+
auto *closure = getAsExpr<ClosureExpr>(arg);
3788+
if (closure &&
3789+
typeParamResultInvolvesTypeVar(/*paramIdx=*/i, resultTypeVar)) {
3790+
closureArguments.push_back(closure);
3791+
}
3792+
}
3793+
}
3794+
3795+
// If no closure result's involves the generic parameter, just bail because we
3796+
// won't find a conflict.
3797+
if (closureArguments.empty())
3798+
return false;
3799+
3800+
// At least one closure where result type involves the generic parameter.
3801+
// So let's try to collect the set of fixed types for the generic parameter
3802+
// from all the closure contextual fix/solutions and if there are more than
3803+
// one fixed type diagnose it.
3804+
llvm::SmallSetVector<Type, 4> genericParamInferredTypes;
3805+
for (auto &fix : contextualFixes)
3806+
genericParamInferredTypes.insert(fix.first->getFixedType(resultTypeVar));
3807+
3808+
if (llvm::all_of(allFixes, [&](Fix fix) {
3809+
auto fixLocator = fix.second->getLocator();
3810+
if (fixLocator->isForContextualType())
3811+
return true;
3812+
3813+
if (!(fix.second->getKind() == FixKind::ContextualMismatch ||
3814+
fix.second->getKind() == FixKind::AllowTupleTypeMismatch))
3815+
return false;
3816+
3817+
auto anchor = fixLocator->getAnchor();
3818+
if (!(anchor == contextualAnchor ||
3819+
fixLocator->isLastElement<LocatorPathElt::ClosureResult>() ||
3820+
fixLocator->isLastElement<LocatorPathElt::ClosureBody>()))
3821+
return false;
3822+
3823+
genericParamInferredTypes.insert(
3824+
fix.first->getFixedType(resultTypeVar));
3825+
return true;
3826+
})) {
3827+
3828+
if (genericParamInferredTypes.size() <= 1)
3829+
return false;
3830+
3831+
auto &DE = cs.getASTContext().Diags;
3832+
llvm::SmallString<64> arguments;
3833+
llvm::raw_svector_ostream OS(arguments);
3834+
interleave(
3835+
genericParamInferredTypes,
3836+
[&](Type argType) { OS << "'" << argType << "'"; },
3837+
[&OS] { OS << " vs. "; });
3838+
3839+
DE.diagnose(AE->getLoc(),
3840+
diag::conflicting_inferred_generic_parameter_result_and_closure,
3841+
GP, OS.str());
3842+
return true;
3843+
}
3844+
return false;
3845+
}
3846+
37123847
bool ConstraintSystem::diagnoseAmbiguityWithFixes(
37133848
SmallVectorImpl<Solution> &solutions) {
37143849
if (solutions.empty())
@@ -3761,8 +3896,6 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
37613896
// d. Diagnose remaining (uniqued based on kind + locator) fixes
37623897
// iff they appear in all of the solutions.
37633898

3764-
using Fix = std::pair<const Solution *, const ConstraintFix *>;
3765-
37663899
llvm::SmallSetVector<Fix, 4> fixes;
37673900
for (auto &solution : solutions) {
37683901
for (auto *fix : solution.Fixes)
@@ -3837,6 +3970,10 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
38373970
}
38383971
}
38393972

3973+
if (!diagnosed && diagnoseContextualFunctionCallGenericAmbiguity(
3974+
*this, contextualFixes, fixes.getArrayRef()))
3975+
return true;
3976+
38403977
return diagnosed;
38413978
}
38423979

0 commit comments

Comments
 (0)