Skip to content

Commit f1e602f

Browse files
authored
Merge pull request #41189 from xedin/trailing-closures-with-callAsFunction
[ConstraintSystem] Match trailing closures to implicit `.callAsFunction` when necessary
2 parents 6fe3ccc + dee174a commit f1e602f

File tree

8 files changed

+265
-23
lines changed

8 files changed

+265
-23
lines changed

include/swift/AST/ArgumentList.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,13 @@ class alignas(Argument) ArgumentList final
249249
static ArgumentList *
250250
createImplicit(ASTContext &ctx, SourceLoc lParenLoc, ArrayRef<Argument> args,
251251
SourceLoc rParenLoc,
252+
Optional<unsigned> firstTrailingClosureIndex = None,
252253
AllocationArena arena = AllocationArena::Permanent);
253254

254255
/// Create a new implicit ArgumentList with a set of \p args.
255256
static ArgumentList *
256257
createImplicit(ASTContext &ctx, ArrayRef<Argument> args,
258+
Optional<unsigned> firstTrailingClosureIndex = None,
257259
AllocationArena arena = AllocationArena::Permanent);
258260

259261
/// Create a new implicit ArgumentList with a single labeled argument

include/swift/Sema/ConstraintSystem.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,10 @@ class Solution {
12201220
/// constructions) to the argument lists for the call to that locator.
12211221
llvm::MapVector<ConstraintLocator *, ArgumentList *> argumentLists;
12221222

1223+
/// The set of implicitly generated `.callAsFunction` root expressions.
1224+
llvm::DenseMap<ConstraintLocator *, UnresolvedDotExpr *>
1225+
ImplicitCallAsFunctionRoots;
1226+
12231227
/// Record a new argument matching choice for given locator that maps a
12241228
/// single argument to a single parameter.
12251229
void recordSingleArgMatchingChoice(ConstraintLocator *locator);
@@ -2468,6 +2472,12 @@ class ConstraintSystem {
24682472
/// types.
24692473
llvm::DenseMap<CanType, DynamicCallableMethods> DynamicCallableCache;
24702474

2475+
/// A cache of implicitly generated dot-member expressions used as roots
2476+
/// for some `.callAsFunction` calls. The key here is "base" locator for
2477+
/// the `.callAsFunction` member reference.
2478+
llvm::SmallMapVector<ConstraintLocator *, UnresolvedDotExpr *, 2>
2479+
ImplicitCallAsFunctionRoots;
2480+
24712481
private:
24722482
/// Describe the candidate expression for partial solving.
24732483
/// This class used by shrink & solve methods which apply
@@ -2951,6 +2961,9 @@ class ConstraintSystem {
29512961
/// The length of \c ArgumentLists.
29522962
unsigned numArgumentLists;
29532963

2964+
/// The length of \c ImplicitCallAsFunctionRoots.
2965+
unsigned numImplicitCallAsFunctionRoots;
2966+
29542967
/// The previous score.
29552968
Score PreviousScore;
29562969

@@ -3475,6 +3488,11 @@ class ConstraintSystem {
34753488
void recordMatchCallArgumentResult(ConstraintLocator *locator,
34763489
MatchCallArgumentResult result);
34773490

3491+
/// Record implicitly generated `callAsFunction` with root at the
3492+
/// given expression, located at \c locator.
3493+
void recordCallAsFunction(UnresolvedDotExpr *root, ArgumentList *arguments,
3494+
ConstraintLocator *locator);
3495+
34783496
/// Walk a closure AST to determine its effects.
34793497
///
34803498
/// \returns a function's extended info describing the effects, as
@@ -5474,6 +5492,7 @@ matchCallArguments(
54745492
ConstraintSystem::TypeMatchResult
54755493
matchCallArguments(ConstraintSystem &cs,
54765494
FunctionType *contextualType,
5495+
ArgumentList *argumentList,
54775496
ArrayRef<AnyFunctionType::Param> args,
54785497
ArrayRef<AnyFunctionType::Param> params,
54795498
ConstraintKind subKind,

lib/AST/ArgumentList.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,22 @@ ArgumentList *ArgumentList::createTypeChecked(ASTContext &ctx,
108108
originalArgs->isImplicit(), originalArgs);
109109
}
110110

111-
ArgumentList *ArgumentList::createImplicit(ASTContext &ctx, SourceLoc lParenLoc,
112-
ArrayRef<Argument> args,
113-
SourceLoc rParenLoc,
114-
AllocationArena arena) {
115-
return create(ctx, lParenLoc, args, rParenLoc,
116-
/*firstTrailingClosureIdx*/ None, /*implicit*/ true,
111+
ArgumentList *
112+
ArgumentList::createImplicit(ASTContext &ctx, SourceLoc lParenLoc,
113+
ArrayRef<Argument> args, SourceLoc rParenLoc,
114+
Optional<unsigned> firstTrailingClosureIndex,
115+
AllocationArena arena) {
116+
return create(ctx, lParenLoc, args, rParenLoc, firstTrailingClosureIndex,
117+
/*implicit*/ true,
117118
/*originalArgs*/ nullptr, arena);
118119
}
119120

120-
ArgumentList *ArgumentList::createImplicit(ASTContext &ctx,
121-
ArrayRef<Argument> args,
122-
AllocationArena arena) {
123-
return createImplicit(ctx, SourceLoc(), args, SourceLoc(), arena);
121+
ArgumentList *
122+
ArgumentList::createImplicit(ASTContext &ctx, ArrayRef<Argument> args,
123+
Optional<unsigned> firstTrailingClosureIndex,
124+
AllocationArena arena) {
125+
return createImplicit(ctx, SourceLoc(), args, SourceLoc(),
126+
firstTrailingClosureIndex, arena);
124127
}
125128

126129
ArgumentList *ArgumentList::forImplicitSingle(ASTContext &ctx, Identifier label,

lib/Sema/CSApply.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7712,7 +7712,38 @@ Expr *ExprRewriter::finishApply(ApplyExpr *apply, Type openedType,
77127712
apply->setFn(declRef);
77137713

77147714
// Tail-recur to actually call the constructor.
7715-
return finishApply(apply, openedType, locator, ctorLocator);
7715+
auto *ctorCall = finishApply(apply, openedType, locator, ctorLocator);
7716+
7717+
// Check whether this is a situation like `T(...) { ... }` where `T` is
7718+
// a callable type and trailing closure(s) are associated with implicit
7719+
// `.callAsFunction` instead of constructor.
7720+
{
7721+
auto callAsFunction =
7722+
solution.ImplicitCallAsFunctionRoots.find(ctorLocator);
7723+
if (callAsFunction != solution.ImplicitCallAsFunctionRoots.end()) {
7724+
auto *dotExpr = callAsFunction->second;
7725+
auto resultTy = solution.getResolvedType(dotExpr);
7726+
7727+
auto *implicitCall = CallExpr::createImplicit(
7728+
cs.getASTContext(), ctorCall,
7729+
solution.getArgumentList(cs.getConstraintLocator(
7730+
dotExpr, ConstraintLocator::ApplyArgument)));
7731+
7732+
implicitCall->setType(resultTy);
7733+
cs.cacheType(implicitCall);
7734+
7735+
auto *memberCalleeLoc =
7736+
cs.getConstraintLocator(dotExpr,
7737+
{ConstraintLocator::ApplyFunction,
7738+
ConstraintLocator::ImplicitCallAsFunction},
7739+
/*summaryFlags=*/0);
7740+
7741+
return finishApply(implicitCall, resultTy, cs.getConstraintLocator(dotExpr),
7742+
memberCalleeLoc);
7743+
}
7744+
}
7745+
7746+
return ctorCall;
77167747
}
77177748

77187749
/// Determine whether this closure should be treated as Sendable.

lib/Sema/CSSimplify.cpp

Lines changed: 152 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,7 @@ class OpenTypeSequenceElements {
13211321
// Match the argument of a call to the parameter.
13221322
ConstraintSystem::TypeMatchResult constraints::matchCallArguments(
13231323
ConstraintSystem &cs, FunctionType *contextualType,
1324+
ArgumentList *argList,
13241325
ArrayRef<AnyFunctionType::Param> args,
13251326
ArrayRef<AnyFunctionType::Param> params, ConstraintKind subKind,
13261327
ConstraintLocatorBuilder locator,
@@ -1340,8 +1341,7 @@ ConstraintSystem::TypeMatchResult constraints::matchCallArguments(
13401341

13411342
ParameterListInfo paramInfo(params, callee, appliedSelf);
13421343

1343-
// Dig out the argument information.
1344-
auto *argList = cs.getArgumentList(loc);
1344+
// Make sure that argument list is available.
13451345
assert(argList);
13461346

13471347
// Apply labels to arguments.
@@ -10494,6 +10494,32 @@ bool ConstraintSystem::simplifyAppliedOverloads(
1049410494
numOptionalUnwraps, locator);
1049510495
}
1049610496

10497+
/// Create an implicit dot-member reference expression to be used
10498+
/// as a root for injected `.callAsFunction` call.
10499+
static UnresolvedDotExpr *
10500+
createImplicitRootForCallAsFunction(ConstraintSystem &cs, Type refType,
10501+
ArgumentList *arguments,
10502+
ConstraintLocator *calleeLocator) {
10503+
auto &ctx = cs.getASTContext();
10504+
auto *baseExpr = castToExpr(calleeLocator->getAnchor());
10505+
10506+
SmallVector<Identifier, 2> closureLabelsScratch;
10507+
// Create implicit `.callAsFunction` expression to use as an anchor
10508+
// for new argument list that only has trailing closures in it.
10509+
auto *implicitRef = UnresolvedDotExpr::createImplicit(
10510+
ctx, baseExpr, {ctx.Id_callAsFunction},
10511+
arguments->getArgumentLabels(closureLabelsScratch));
10512+
10513+
{
10514+
// Record a type of the new reference in the constraint system.
10515+
cs.setType(implicitRef, refType);
10516+
// Record new `.callAsFunction` in the constraint system.
10517+
cs.recordCallAsFunction(implicitRef, arguments, calleeLocator);
10518+
}
10519+
10520+
return implicitRef;
10521+
}
10522+
1049710523
ConstraintSystem::SolutionKind
1049810524
ConstraintSystem::simplifyApplicableFnConstraint(
1049910525
Type type1, Type type2,
@@ -10548,17 +10574,20 @@ ConstraintSystem::simplifyApplicableFnConstraint(
1054810574
};
1054910575

1055010576
// Local function to form an unsolved result.
10551-
auto formUnsolved = [&] {
10577+
auto formUnsolved = [&](bool activate = false) {
1055210578
if (flags.contains(TMF_GenerateConstraints)) {
10553-
addUnsolvedConstraint(
10554-
Constraint::createApplicableFunction(
10579+
auto *application = Constraint::createApplicableFunction(
1055510580
*this, type1, type2, trailingClosureMatching,
10556-
getConstraintLocator(locator)));
10581+
getConstraintLocator(locator));
10582+
10583+
addUnsolvedConstraint(application);
10584+
if (activate)
10585+
activateConstraint(application);
10586+
1055710587
return SolutionKind::Solved;
1055810588
}
1055910589

1056010590
return SolutionKind::Unsolved;
10561-
1056210591
};
1056310592

1056410593
// If right-hand side is a type variable, the constraint is unsolved.
@@ -10633,15 +10662,97 @@ ConstraintSystem::simplifyApplicableFnConstraint(
1063310662
? ConstraintKind::OperatorArgumentConversion
1063410663
: ConstraintKind::ArgumentConversion);
1063510664

10665+
auto *argumentsLoc = getConstraintLocator(
10666+
outerLocator.withPathElement(ConstraintLocator::ApplyArgument));
10667+
10668+
auto *argumentList = getArgumentList(argumentsLoc);
1063610669
// The argument type must be convertible to the input type.
1063710670
auto matchCallResult = ::matchCallArguments(
10638-
*this, func2, func1->getParams(), func2->getParams(), subKind,
10639-
outerLocator.withPathElement(ConstraintLocator::ApplyArgument),
10640-
trailingClosureMatching);
10671+
*this, func2, argumentList, func1->getParams(), func2->getParams(),
10672+
subKind, argumentsLoc, trailingClosureMatching);
1064110673

1064210674
switch (matchCallResult) {
10643-
case SolutionKind::Error:
10675+
case SolutionKind::Error: {
10676+
auto resultTy = func2->getResult();
10677+
10678+
// If this is a call that constructs a callable type with
10679+
// trailing closure(s), closure(s) might not belong to
10680+
// the constructor but rather to implicit `callAsFunction`,
10681+
// there is no way to determine that without trying.
10682+
if (resultTy->isCallableNominalType(DC) &&
10683+
argumentList->hasAnyTrailingClosures()) {
10684+
auto *calleeLoc = getCalleeLocator(argumentsLoc);
10685+
10686+
bool isInit = false;
10687+
if (auto overload = findSelectedOverloadFor(calleeLoc)) {
10688+
isInit = bool(dyn_cast_or_null<ConstructorDecl>(
10689+
overload->choice.getDeclOrNull()));
10690+
}
10691+
10692+
if (!isInit)
10693+
return SolutionKind::Error;
10694+
10695+
auto &ctx = getASTContext();
10696+
auto numTrailing = argumentList->getNumTrailingClosures();
10697+
10698+
SmallVector<Argument, 4> newArguments(
10699+
argumentList->getNonTrailingArgs());
10700+
SmallVector<Argument, 4> trailingClosures(
10701+
argumentList->getTrailingClosures());
10702+
10703+
// Original argument list with all the trailing closures removed.
10704+
auto *newArgumentList = ArgumentList::createParsed(
10705+
ctx, argumentList->getLParenLoc(), newArguments,
10706+
argumentList->getRParenLoc(),
10707+
/*firstTrailingClosureIndex=*/None);
10708+
10709+
auto trailingClosureTypes = func1->getParams().take_back(numTrailing);
10710+
// The original result type is going to become a result of
10711+
// implicit `.callAsFunction` instead since `.callAsFunction`
10712+
// is inserted between `.init` and trailing closures.
10713+
auto callAsFunctionResultTy = func1->getResult();
10714+
10715+
// The implicit replacement for original result type which
10716+
// represents a callable type produced by `.init` call.
10717+
auto callableType =
10718+
createTypeVariable(getConstraintLocator({}), /*flags=*/0);
10719+
10720+
// The original application type with all the trailing closures
10721+
// dropped from it and result replaced to the implicit variable.
10722+
func1 = FunctionType::get(func1->getParams().drop_back(numTrailing),
10723+
callableType, func1->getExtInfo());
10724+
10725+
auto matchCallResult = ::matchCallArguments(
10726+
*this, func2, newArgumentList, func1->getParams(),
10727+
func2->getParams(), subKind, argumentsLoc, trailingClosureMatching);
10728+
10729+
if (matchCallResult != SolutionKind::Solved)
10730+
return SolutionKind::Error;
10731+
10732+
auto *implicitCallArgumentList =
10733+
ArgumentList::createImplicit(ctx, trailingClosures,
10734+
/*firstTrailingClosureIndex=*/0);
10735+
10736+
auto *implicitRef = createImplicitRootForCallAsFunction(
10737+
*this, callAsFunctionResultTy, implicitCallArgumentList, calleeLoc);
10738+
10739+
auto callAsFunctionArguments =
10740+
FunctionType::get(trailingClosureTypes, callAsFunctionResultTy,
10741+
FunctionType::ExtInfo());
10742+
10743+
// Form an unsolved constraint to apply trailing closures to a
10744+
// callable type produced by `.init`. This constraint would become
10745+
// active when `callableType` is bound.
10746+
addUnsolvedConstraint(Constraint::create(
10747+
*this, ConstraintKind::ApplicableFunction, callAsFunctionArguments,
10748+
callableType,
10749+
getConstraintLocator(implicitRef,
10750+
ConstraintLocator::ApplyFunction)));
10751+
break;
10752+
}
10753+
1064410754
return SolutionKind::Error;
10755+
}
1064510756

1064610757
case SolutionKind::Unsolved: {
1064710758
// Only occurs when there is an ambiguity between forward scanning and
@@ -10691,6 +10802,26 @@ ConstraintSystem::simplifyApplicableFnConstraint(
1069110802
if (instance2->isTypeVariableOrMember())
1069210803
return formUnsolved();
1069310804

10805+
auto *argumentsLoc = getConstraintLocator(
10806+
outerLocator.withPathElement(ConstraintLocator::ApplyArgument));
10807+
10808+
auto *argumentList = getArgumentList(argumentsLoc);
10809+
assert(argumentList);
10810+
10811+
// Cannot simplify construction of callable types during constraint
10812+
// generation when trailing closures are present because such calls
10813+
// have special trailing closure matching semantics. It's unclear
10814+
// whether trailing arguments belong to `.init` or implicit
10815+
// `.callAsFunction` in this case.
10816+
//
10817+
// Note that the constraint has to be activate so that solver attempts
10818+
// once constraint generation is done.
10819+
if (getPhase() == ConstraintSystemPhase::ConstraintGeneration &&
10820+
argumentList->hasAnyTrailingClosures() &&
10821+
instance2->isCallableNominalType(DC)) {
10822+
return formUnsolved(/*activate=*/true);
10823+
}
10824+
1069410825
// Construct the instance from the input arguments.
1069510826
auto simplified = simplifyConstructionConstraint(instance2, func1, subflags,
1069610827
/*FIXME?*/ DC,
@@ -11579,6 +11710,7 @@ ConstraintSystem::simplifyRestrictedConstraintImpl(
1157911710
if (!ArgumentLists.count(memberLoc)) {
1158011711
auto *argList = ArgumentList::createImplicit(
1158111712
getASTContext(), {Argument(SourceLoc(), Identifier(), nullptr)},
11713+
/*firstTrailingClosureIndex=*/None,
1158211714
AllocationArena::ConstraintSolver);
1158311715
ArgumentLists.insert({memberLoc, argList});
1158411716
}
@@ -11867,6 +11999,15 @@ void ConstraintSystem::recordMatchCallArgumentResult(
1186711999
argumentMatchingChoices.insert({locator, result});
1186812000
}
1186912001

12002+
void ConstraintSystem::recordCallAsFunction(UnresolvedDotExpr *root,
12003+
ArgumentList *arguments,
12004+
ConstraintLocator *locator) {
12005+
ImplicitCallAsFunctionRoots.insert({locator, root});
12006+
12007+
associateArgumentList(
12008+
getConstraintLocator(root, ConstraintLocator::ApplyArgument), arguments);
12009+
}
12010+
1187012011
ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint(
1187112012
ConstraintFix *fix, Type type1, Type type2, ConstraintKind matchKind,
1187212013
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {

0 commit comments

Comments
 (0)