Skip to content

Commit a6e84b3

Browse files
committed
[ConstraintSystem] Add call site support for type inference from default expressions
1 parent 23297c9 commit a6e84b3

File tree

3 files changed

+263
-7
lines changed

3 files changed

+263
-7
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6238,5 +6238,24 @@ ERROR(type_sequence_on_non_generic_param, none,
62386238
"'@_typeSequence' must appear on a generic parameter",
62396239
())
62406240

6241+
//------------------------------------------------------------------------------
6242+
// MARK: Type inference from default expressions
6243+
//------------------------------------------------------------------------------
6244+
6245+
ERROR(cannot_default_generic_parameter_inferrable_from_result, none,
6246+
"cannot use default expression for inference of %0 because it "
6247+
"is inferrable from result type",
6248+
(Type))
6249+
6250+
ERROR(cannot_default_generic_parameter_inferrable_from_another_parameter, none,
6251+
"cannot use default expression for inference of %0 because it "
6252+
"is inferrable from parameters %1",
6253+
(Type, StringRef))
6254+
6255+
ERROR(cannot_default_generic_parameter_inferrable_through_same_type, none,
6256+
"cannot use default expression for inference of %0 because it "
6257+
"is inferrable through same-type requirement: %1",
6258+
(Type, StringRef))
6259+
62416260
#define UNDEFINE_DIAGNOSTIC_MACROS
62426261
#include "DefineDiagnosticMacros.h"

lib/Sema/CSSimplify.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,10 +1626,35 @@ static ConstraintSystem::TypeMatchResult matchCallArguments(
16261626
continue;
16271627
}
16281628

1629-
// Skip unfulfilled parameters. There's nothing to do for them.
1630-
if (parameterBindings[paramIdx].empty())
1629+
// If type inference from default arguments is enabled, let's
1630+
// add a constraint from the parameter if necessary, otherwise
1631+
// there is nothing to do but move to the next parameter.
1632+
if (parameterBindings[paramIdx].empty()) {
1633+
auto &ctx = cs.getASTContext();
1634+
1635+
if (paramTy->isTypeVariableOrMember() &&
1636+
ctx.TypeCheckerOpts.EnableTypeInferenceFromDefaultArguments) {
1637+
auto *paramList = getParameterList(callee);
1638+
auto defaultExprType = paramList->get(paramIdx)->getTypeOfDefaultExpr();
1639+
1640+
// A caller side default.
1641+
if (!defaultExprType)
1642+
continue;
1643+
1644+
// If this is just a regular default type that works
1645+
// for any generic parameter type, let's continue.
1646+
if (defaultExprType->hasArchetype())
1647+
continue;
1648+
1649+
cs.addConstraint(
1650+
ConstraintKind::ArgumentConversion, paramTy, defaultExprType,
1651+
locator.withPathElement(LocatorPathElt::ApplyArgToParam(
1652+
paramIdx, paramIdx, param.getParameterFlags())));
1653+
}
1654+
16311655
continue;
1632-
1656+
}
1657+
16331658
// Compare each of the bound arguments for this parameter.
16341659
for (auto argIdx : parameterBindings[paramIdx]) {
16351660
auto loc = locator.withPathElement(LocatorPathElt::ApplyArgToParam(

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 216 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,222 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
431431
DeclContext *DC, Type paramType,
432432
bool isAutoClosure) {
433433
assert(paramType && !paramType->hasError());
434-
return typeCheckExpression(defaultValue, DC, /*contextualInfo=*/
435-
{paramType, isAutoClosure
436-
? CTP_AutoclosureDefaultParameter
437-
: CTP_DefaultParameter});
434+
435+
auto &ctx = DC->getASTContext();
436+
437+
// First, let's try to type-check default expression using interface
438+
// type of the parameter, if that succeeds - we are done.
439+
SolutionApplicationTarget defaultExprTarget(
440+
defaultValue, DC,
441+
isAutoClosure ? CTP_AutoclosureDefaultParameter : CTP_DefaultParameter,
442+
paramType, /*isDiscarded=*/false);
443+
444+
{
445+
// Buffer all of the diagnostics produced by \c typeCheckExpression
446+
// since in some cases we need to try type-checking again with a
447+
// different contextual type, see below.
448+
DiagnosticTransaction diagnostics(ctx.Diags);
449+
450+
// First, let's try to type-check default expression using
451+
// archetypes, which guarantees that it would work for any
452+
// substitution of the generic parameter (if they are involved).
453+
if (auto result = typeCheckExpression(defaultExprTarget)) {
454+
defaultValue = result->getAsExpr();
455+
return defaultValue->getType();
456+
}
457+
458+
// If inference is disabled, fail.
459+
if (!ctx.TypeCheckerOpts.EnableTypeInferenceFromDefaultArguments)
460+
return Type();
461+
462+
// Ignore any diagnostics emitted by the original type-check.
463+
diagnostics.abort();
464+
}
465+
466+
// Let's see whether it would be possible to use default expression
467+
// for generic parameter inference.
468+
//
469+
// First, let's check whether:
470+
// - Parameter type is a generic parameter; and
471+
// - It's only used in the current position in the parameter list
472+
// or result. This check makes sure that that generic argument
473+
// could only come from an explicit argument or this expression.
474+
//
475+
// If both of aforementioned conditions are true, let's attempt
476+
// to open generic parameter and infer the type of this default
477+
// expression.
478+
auto interfaceType = paramType->mapTypeOutOfContext();
479+
if (!interfaceType->isTypeParameter())
480+
return Type();
481+
482+
auto containsType = [&](Type type, Type contained) {
483+
return type.findIf(
484+
[&contained](Type nested) { return nested->isEqual(contained); });
485+
};
486+
487+
// Anchor of this default expression.
488+
auto *anchor = cast<ValueDecl>(DC->getParent()->getAsDecl());
489+
490+
// Check whether generic parameter is only mentioned once in
491+
// the anchor's signature.
492+
{
493+
auto anchorTy = anchor->getInterfaceType()->castTo<GenericFunctionType>();
494+
495+
// Reject if generic parameter could be inferred from result type.
496+
if (containsType(anchorTy->getResult(), interfaceType)) {
497+
ctx.Diags.diagnose(
498+
defaultValue->getLoc(),
499+
diag::cannot_default_generic_parameter_inferrable_from_result,
500+
interfaceType);
501+
return Type();
502+
}
503+
504+
// Reject if generic parameter is used in multiple different positions
505+
// in the parameter list.
506+
507+
llvm::SmallVector<unsigned, 2> affectedParams;
508+
for (unsigned i : indices(anchorTy->getParams())) {
509+
const auto &param = anchorTy->getParams()[i];
510+
511+
if (containsType(param.getPlainType(), interfaceType))
512+
affectedParams.push_back(i);
513+
}
514+
515+
if (affectedParams.size() > 1) {
516+
SmallString<32> paramBuf;
517+
llvm::raw_svector_ostream params(paramBuf);
518+
519+
interleave(
520+
affectedParams, [&](const unsigned index) { params << "#" << index; },
521+
[&] { params << ", "; });
522+
523+
ctx.Diags.diagnose(
524+
defaultValue->getLoc(),
525+
diag::
526+
cannot_default_generic_parameter_inferrable_from_another_parameter,
527+
interfaceType, params.str());
528+
return Type();
529+
}
530+
}
531+
532+
auto signature = DC->getGenericSignatureOfContext();
533+
assert(signature && "generic parameter without signature?");
534+
535+
ConstraintSystemOptions options;
536+
options |= ConstraintSystemFlags::AllowFixes;
537+
538+
ConstraintSystem cs(DC, options);
539+
540+
auto *locator = cs.getConstraintLocator(
541+
defaultValue, LocatorPathElt::ContextualType(
542+
defaultExprTarget.getExprContextualTypePurpose()));
543+
544+
// A replacement for generic parameter type to associate any generic
545+
// requirements with.
546+
auto *contextualTy = cs.createTypeVariable(locator, /*flags=*/0);
547+
548+
auto *requirementBaseLocator = cs.getConstraintLocator(
549+
locator, LocatorPathElt::OpenedGeneric(signature));
550+
551+
// Let's check all of the requirements this parameter is invoved in,
552+
// If it's connected to any other generic types (directly or through
553+
// a dependent member type), that means it could be inferred through
554+
// them e.g. `T: X.Y` or `T == U`.
555+
{
556+
auto isViable = [](Type type) {
557+
return !(type->hasTypeParameter() && type->hasDependentMember());
558+
};
559+
560+
auto recordRequirement = [&](unsigned index, Requirement requirement,
561+
ConstraintLocator *locator) {
562+
cs.openGenericRequirement(DC->getParent(), index, requirement,
563+
/*skipSelfProtocolConstraint=*/false, locator,
564+
[](Type type) -> Type { return type; });
565+
};
566+
567+
auto requirements = signature.getRequirements();
568+
for (unsigned reqIdx = 0; reqIdx != requirements.size(); ++reqIdx) {
569+
auto &requirement = requirements[reqIdx];
570+
571+
switch (requirement.getKind()) {
572+
case RequirementKind::Conformance: {
573+
if (!requirement.getFirstType()->isEqual(interfaceType))
574+
continue;
575+
576+
recordRequirement(reqIdx,
577+
{RequirementKind::Conformance, contextualTy,
578+
requirement.getSecondType()},
579+
requirementBaseLocator);
580+
break;
581+
}
582+
583+
case RequirementKind::Superclass: {
584+
auto subclassTy = requirement.getFirstType();
585+
auto superclassTy = requirement.getSecondType();
586+
587+
if (subclassTy->isEqual(interfaceType) && isViable(superclassTy)) {
588+
recordRequirement(
589+
reqIdx, {RequirementKind::Superclass, contextualTy, superclassTy},
590+
requirementBaseLocator);
591+
}
592+
593+
break;
594+
}
595+
596+
case RequirementKind::SameType: {
597+
// If there is a same-type constraint that involves our parameter
598+
// type, fail the type-check since the type could be inferred
599+
// through other positions.
600+
if (containsType(requirement.getFirstType(), interfaceType) ||
601+
containsType(requirement.getSecondType(), interfaceType)) {
602+
SmallString<32> reqBuf;
603+
llvm::raw_svector_ostream req(reqBuf);
604+
605+
requirement.print(req, PrintOptions());
606+
607+
ctx.Diags.diagnose(
608+
defaultValue->getLoc(),
609+
diag::
610+
cannot_default_generic_parameter_inferrable_through_same_type,
611+
interfaceType, req.str());
612+
return Type();
613+
}
614+
615+
continue;
616+
}
617+
618+
case RequirementKind::Layout:
619+
if (!requirement.getFirstType()->isEqual(interfaceType))
620+
continue;
621+
622+
recordRequirement(reqIdx,
623+
{RequirementKind::Layout, contextualTy,
624+
requirement.getLayoutConstraint()},
625+
requirementBaseLocator);
626+
break;
627+
}
628+
}
629+
}
630+
631+
defaultExprTarget.setExprConversionType(contextualTy);
632+
cs.setContextualType(defaultValue,
633+
defaultExprTarget.getExprContextualTypeLoc(),
634+
defaultExprTarget.getExprContextualTypePurpose());
635+
636+
auto viable = cs.solve(defaultExprTarget, FreeTypeVariableBinding::Disallow);
637+
if (!viable)
638+
return Type();
639+
640+
auto &solution = (*viable)[0];
641+
642+
cs.applySolution(solution);
643+
644+
if (auto result = cs.applySolution(solution, defaultExprTarget)) {
645+
defaultValue = result->getAsExpr();
646+
return defaultValue->getType();
647+
}
648+
649+
return Type();
438650
}
439651

440652
bool TypeChecker::typeCheckBinding(Pattern *&pattern, Expr *&initializer,

0 commit comments

Comments
 (0)