Skip to content

Commit 93dac8f

Browse files
committed
[Type checker] Be more rigorous about extracting argument labels from calls.
Whenever we have a call, retrieve the argument labels from the argument structurally and associate them with the callee. We were previously doing this as a separate AST walk (which was unnecessary), so fold that into constraint generation for a CallExpr. This is a slightly-pared-back version of 3753d77 that isn't so rigid in its interpretation of ASTs. I'll tighten up the semantics over time.
1 parent 4a60b6c commit 93dac8f

File tree

7 files changed

+110
-91
lines changed

7 files changed

+110
-91
lines changed

include/swift/AST/Expr.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,9 @@ class ParenExpr : public IdentityExpr {
16211621
/// \brief Whether this expression has a trailing closure as its argument.
16221622
bool hasTrailingClosure() const { return HasTrailingClosure; }
16231623

1624+
/// Create a new, implicit parenthesized expression.
1625+
static ParenExpr *createImplicit(ASTContext &ctx, Expr *expr);
1626+
16241627
static bool classof(const Expr *E) { return E->getKind() == ExprKind::Paren; }
16251628
};
16261629

@@ -3195,7 +3198,15 @@ class CallExpr : public ApplyExpr {
31953198
SourceLoc FnLoc = getFn()->getLoc();
31963199
return FnLoc.isValid() ? FnLoc : getArg()->getLoc();
31973200
}
3198-
3201+
3202+
/// Retrieve the expression that direct represents the callee.
3203+
///
3204+
/// The "direct" callee is the expression representing the callee
3205+
/// after looking through top-level constructs that don't affect the
3206+
/// identity of the callee, e.g., extra parentheses, optional
3207+
/// unwrapping (?)/forcing (!), etc.
3208+
Expr *getDirectCallee() const;
3209+
31993210
static bool classof(const Expr *E) { return E->getKind() == ExprKind::Call; }
32003211
};
32013212

lib/AST/Expr.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,14 @@ SequenceExpr *SequenceExpr::create(ASTContext &ctx, ArrayRef<Expr*> elements) {
922922
return ::new(Buffer) SequenceExpr(elements);
923923
}
924924

925+
ParenExpr *ParenExpr::createImplicit(ASTContext &ctx, Expr *expr) {
926+
ParenExpr *result = new (ctx) ParenExpr(SourceLoc(), expr, SourceLoc(),
927+
/*hasTrailingClosure=*/false,
928+
expr->getType());
929+
result->setImplicit();
930+
return result;
931+
}
932+
925933
SourceLoc TupleExpr::getStartLoc() const {
926934
if (LParenLoc.isValid()) return LParenLoc;
927935
if (getNumElements() == 0) return SourceLoc();
@@ -1033,6 +1041,25 @@ ValueDecl *ApplyExpr::getCalledValue() const {
10331041
return ::getCalledValue(Fn);
10341042
}
10351043

1044+
Expr *CallExpr::getDirectCallee() const {
1045+
auto fn = getFn();
1046+
while (true) {
1047+
fn = fn->getSemanticsProvidingExpr();
1048+
1049+
if (auto force = dyn_cast<ForceValueExpr>(fn)) {
1050+
fn = force->getSubExpr();
1051+
continue;
1052+
}
1053+
1054+
if (auto bind = dyn_cast<BindOptionalExpr>(fn)) {
1055+
fn = bind->getSubExpr();
1056+
continue;
1057+
}
1058+
1059+
return fn;
1060+
}
1061+
}
1062+
10361063
RebindSelfInConstructorExpr::RebindSelfInConstructorExpr(Expr *SubExpr,
10371064
VarDecl *Self)
10381065
: Expr(ExprKind::RebindSelfInConstructor, /*Implicit=*/true,

lib/Sema/CSGen.cpp

Lines changed: 49 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,12 @@ namespace {
15971597

15981598
// If there is an argument, apply it.
15991599
if (auto arg = expr->getArgument()) {
1600+
// Identify and record the argument labels for this call.
1601+
if (auto argumentLabels = getArgumentLabelsForCallArgument(arg)) {
1602+
CS.CalleeArgumentLabels[CS.getConstraintLocator(expr)] =
1603+
*argumentLabels;
1604+
}
1605+
16001606
// The result type of the function must be convertible to the base type.
16011607
// TODO: we definitely want this to include ImplicitlyUnwrappedOptional; does it
16021608
// need to include everything else in the world?
@@ -2429,11 +2435,52 @@ namespace {
24292435
return expr->getType();
24302436
}
24312437

2438+
using ArgumentLabelState = ConstraintSystem::ArgumentLabelState;
2439+
2440+
/// Extract argument labels from the given call argument.
2441+
Optional<ArgumentLabelState> getArgumentLabelsForCallArgument(Expr *arg) {
2442+
// Parentheses.
2443+
if (auto paren = dyn_cast<ParenExpr>(arg)) {
2444+
return ArgumentLabelState{
2445+
CS.allocateCopy(llvm::makeArrayRef(Identifier())),
2446+
paren->hasTrailingClosure() };
2447+
}
2448+
2449+
// FIXME: Hack because not all CallExprs come in with
2450+
// ParenExpr/TupleExpr arguments.
2451+
if (!isa<TupleExpr>(arg)) return None;
2452+
2453+
// Tuples.
2454+
auto tuple = cast<TupleExpr>(arg);
2455+
2456+
// If we have element names, use 'em.
2457+
if (tuple->hasElementNames()) {
2458+
return ArgumentLabelState{tuple->getElementNames(),
2459+
tuple->hasTrailingClosure()};
2460+
}
2461+
2462+
// Otherwise, there are no argument labels.
2463+
llvm::SmallVector<Identifier, 4> names(tuple->getNumElements(),
2464+
Identifier());
2465+
return ArgumentLabelState{CS.allocateCopy(names),
2466+
tuple->hasTrailingClosure()};
2467+
}
2468+
24322469
Type visitApplyExpr(ApplyExpr *expr) {
24332470
Type outputTy;
24342471

24352472
auto fnExpr = expr->getFn();
2436-
2473+
2474+
// Identify and record the argument labels for a call.
2475+
if (auto call = dyn_cast<CallExpr>(expr)) {
2476+
if (auto argumentLabels =
2477+
getArgumentLabelsForCallArgument(expr->getArg())) {
2478+
auto callee = call->getDirectCallee();
2479+
CS.CalleeArgumentLabels[CS.getConstraintLocator(callee)] =
2480+
*argumentLabels;
2481+
}
2482+
}
2483+
24372484
if (isa<DeclRefExpr>(fnExpr)) {
24382485
if (auto fnType = fnExpr->getType()->getAs<AnyFunctionType>()) {
24392486
outputTy = fnType->getResult();
@@ -2536,7 +2583,7 @@ namespace {
25362583
FunctionType::ExtInfo extInfo;
25372584
if (isa<ClosureExpr>(fnExpr->getSemanticsProvidingExpr()))
25382585
extInfo = extInfo.withNoEscape();
2539-
2586+
25402587
auto funcTy = FunctionType::get(expr->getArg()->getType(), outputTy,
25412588
extInfo);
25422589

@@ -2999,87 +3046,12 @@ namespace {
29993046
/// \brief Ignore declarations.
30003047
bool walkToDeclPre(Decl *decl) override { return false; }
30013048
};
3002-
3003-
/// AST walker that records the keyword arguments provided at each
3004-
/// call site.
3005-
class ArgumentLabelWalker : public ASTWalker {
3006-
ConstraintSystem &CS;
3007-
llvm::DenseMap<Expr *, Expr *> ParentMap;
3008-
3009-
public:
3010-
ArgumentLabelWalker(ConstraintSystem &cs, Expr *expr)
3011-
: CS(cs), ParentMap(expr->getParentMap()) { }
3012-
3013-
using State = ConstraintSystem::ArgumentLabelState;
3014-
3015-
void associateArgumentLabels(Expr *arg, State labels,
3016-
bool labelsArePermanent) {
3017-
// Our parent must be a call.
3018-
auto call = dyn_cast_or_null<CallExpr>(ParentMap[arg]);
3019-
if (!call)
3020-
return;
3021-
3022-
// We must have originated at the call argument.
3023-
if (arg != call->getArg())
3024-
return;
3025-
3026-
// Dig out the function, looking through, parentheses, ?, and !.
3027-
auto fn = call->getFn();
3028-
do {
3029-
fn = fn->getSemanticsProvidingExpr();
3030-
3031-
if (auto force = dyn_cast<ForceValueExpr>(fn)) {
3032-
fn = force->getSubExpr();
3033-
continue;
3034-
}
3035-
3036-
if (auto bind = dyn_cast<BindOptionalExpr>(fn)) {
3037-
fn = bind->getSubExpr();
3038-
continue;
3039-
}
3040-
3041-
break;
3042-
} while (true);
3043-
3044-
// Record the labels.
3045-
if (!labelsArePermanent)
3046-
labels.Labels = CS.allocateCopy(labels.Labels);
3047-
CS.ArgumentLabels[CS.getConstraintLocator(fn)] = labels;
3048-
}
3049-
3050-
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
3051-
if (auto tuple = dyn_cast<TupleExpr>(expr)) {
3052-
if (tuple->hasElementNames())
3053-
associateArgumentLabels(expr,
3054-
{ tuple->getElementNames(),
3055-
tuple->hasTrailingClosure() },
3056-
/*labelsArePermanent*/ true);
3057-
else {
3058-
llvm::SmallVector<Identifier, 4> names(tuple->getNumElements(),
3059-
Identifier());
3060-
associateArgumentLabels(expr, { names, tuple->hasTrailingClosure() },
3061-
/*labelsArePermanent*/ false);
3062-
}
3063-
} else if (auto paren = dyn_cast<ParenExpr>(expr)) {
3064-
associateArgumentLabels(paren,
3065-
{ { Identifier() },
3066-
paren->hasTrailingClosure() },
3067-
/*labelsArePermanent*/ false);
3068-
}
3069-
3070-
return { true, expr };
3071-
}
3072-
};
3073-
30743049
} // end anonymous namespace
30753050

30763051
Expr *ConstraintSystem::generateConstraints(Expr *expr) {
30773052
// Remove implicit conversions from the expression.
30783053
expr = expr->walk(SanitizeExpr(getTypeChecker()));
30793054

3080-
// Walk the expression to associate labeled arguments.
3081-
expr->walk(ArgumentLabelWalker(*this, expr));
3082-
30833055
// Walk the expression, generating constraints.
30843056
ConstraintGenerator cg(*this);
30853057
ConstraintWalker cw(cg);

lib/Sema/CSSimplify.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,8 +2612,8 @@ getArgumentLabels(ConstraintSystem &cs, ConstraintLocatorBuilder locator) {
26122612
if (!parts.empty())
26132613
return None;
26142614

2615-
auto known = cs.ArgumentLabels.find(cs.getConstraintLocator(anchor));
2616-
if (known == cs.ArgumentLabels.end())
2615+
auto known = cs.CalleeArgumentLabels.find(cs.getConstraintLocator(anchor));
2616+
if (known == cs.CalleeArgumentLabels.end())
26172617
return None;
26182618

26192619
return known->second;

lib/Sema/ConstraintSystem.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,16 +1042,25 @@ class ConstraintSystem {
10421042
/// we're exploring.
10431043
SolverState *solverState = nullptr;
10441044

1045+
/// Describes argument labels as they occur in a call.
10451046
struct ArgumentLabelState {
1047+
/// The actual argument labels provided at the call site.
1048+
///
1049+
/// The contents of this array are only guaranteed to live until
1050+
/// the constraint system is destroyed.
10461051
ArrayRef<Identifier> Labels;
1052+
1053+
/// Whether the last of the arguments was written as a trailing
1054+
/// closure. It will have an empty Identifier().
10471055
bool HasTrailingClosure;
10481056
};
10491057

1050-
/// A mapping from the constraint locators for references to various
1051-
/// names (e.g., member references, normal name references, possible
1052-
/// constructions) to the argument labels provided in the call to
1053-
/// that locator.
1054-
llvm::DenseMap<ConstraintLocator *, ArgumentLabelState> ArgumentLabels;
1058+
/// The argument labels that are associated with a given callee.
1059+
///
1060+
/// The argument labels for a particular callee are recorded at the
1061+
/// time of constraint generation based on the syntactic call
1062+
/// argument.
1063+
llvm::DenseMap<ConstraintLocator *, ArgumentLabelState> CalleeArgumentLabels;
10551064

10561065
/// FIXME: This is a workaround for the way we perform protocol
10571066
/// conformance checking for generic requirements, where we re-use

lib/Sema/MiscDiagnostics.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2830,14 +2830,16 @@ class ObjCSelectorWalker : public ASTWalker {
28302830
bool hadParens = false;
28312831
auto lookThroughParens = [&](Expr *arg, bool outermost) -> Expr * {
28322832
if (auto parenExpr = dyn_cast<ParenExpr>(arg)) {
2833-
if (!outermost) {
2833+
if (!outermost && !parenExpr->isImplicit()) {
28342834
hadParens = true;
28352835
return parenExpr->getSubExpr()->getSemanticsProvidingExpr();
28362836
}
28372837

28382838
arg = parenExpr->getSubExpr();
28392839
if (auto innerParenExpr = dyn_cast<ParenExpr>(arg)) {
2840-
hadParens = true;
2840+
if (!innerParenExpr->isImplicit())
2841+
hadParens = true;
2842+
28412843
arg = innerParenExpr->getSubExpr();
28422844
}
28432845
}

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -792,10 +792,8 @@ namespace {
792792
continue;
793793
}
794794

795-
// Look through identity, force-value, and 'try' expressions.
796-
if (isa<IdentityExpr>(ancestor) ||
797-
isa<ForceValueExpr>(ancestor) ||
798-
isa<AnyTryExpr>(ancestor)) {
795+
// Look through force-value, and 'try' expressions.
796+
if (isa<ForceValueExpr>(ancestor) || isa<AnyTryExpr>(ancestor)) {
799797
if (target)
800798
target = ancestor;
801799
continue;

0 commit comments

Comments
 (0)