Skip to content

Commit ca7fb37

Browse files
author
Nathan Hawes
committed
[CodeCompletion][Sema][Parse] Migrate unresolved member completion to the solver-based completion implementation
Following on from updating regular member completion, this hooks up unresolved member completion (i.e. .<complete here>) to the typeCheckForCodeCompletion API to generate completions from all solutions the constraint solver produces (even those requiring fixes), rather than relying on a single solution being applied to the AST (if any). This lets us produce unresolved member completions even when the contextual type is ambiguous or involves errors. Whenever typeCheckExpression is called on an expression containing a code completion expression and a CompletionCallback has been set, each solution formed is passed to the callback so the type of the completion expression can be extracted and used to lookup up the members to return.
1 parent 7abf272 commit ca7fb37

19 files changed

+450
-255
lines changed

include/swift/Sema/CodeCompletionTypeChecking.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
#ifndef SWIFT_SEMA_CODECOMPLETIONTYPECHECKING_H
1919
#define SWIFT_SEMA_CODECOMPLETIONTYPECHECKING_H
2020

21+
#include "swift/Basic/LLVM.h"
22+
#include "swift/AST/Type.h"
23+
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/SmallVector.h"
25+
2126
namespace swift {
2227
class Decl;
2328
class DeclContext;
@@ -32,7 +37,7 @@ namespace swift {
3237
class TypeCheckCompletionCallback {
3338
public:
3439
/// Called for each solution produced while type-checking an expression
35-
/// containing a code completion expression.
40+
/// that the code completion expression participates in.
3641
virtual void sawSolution(const constraints::Solution &solution) = 0;
3742
virtual ~TypeCheckCompletionCallback() {}
3843
};
@@ -78,6 +83,39 @@ namespace swift {
7883

7984
void sawSolution(const constraints::Solution &solution) override;
8085
};
86+
87+
/// Used to collect and store information needed to perform unresolved member
88+
/// completion (\c CompletionKind::UnresolvedMember ) from the solutions
89+
/// formed during expression type-checking.
90+
class UnresolvedMemberTypeCheckCompletionCallback: public TypeCheckCompletionCallback {
91+
public:
92+
struct Result {
93+
Type ExpectedTy;
94+
bool IsSingleExpressionBody;
95+
};
96+
97+
private:
98+
CodeCompletionExpr *CompletionExpr;
99+
SmallVector<Result, 4> Results;
100+
bool GotCallback = false;
101+
102+
public:
103+
UnresolvedMemberTypeCheckCompletionCallback(CodeCompletionExpr *CompletionExpr)
104+
: CompletionExpr(CompletionExpr) {}
105+
106+
ArrayRef<Result> getResults() const { return Results; }
107+
108+
/// True if at least one solution was passed via the \c sawSolution
109+
/// callback.
110+
bool gotCallback() const { return GotCallback; }
111+
112+
/// Typecheck the code completion expression in its outermost expression
113+
/// context, calling \c sawSolution for each solution formed.
114+
void fallbackTypeCheck(DeclContext *DC);
115+
116+
void sawSolution(const constraints::Solution &solution) override;
117+
};
118+
81119
}
82120

83121
#endif

lib/IDE/CodeCompletion.cpp

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4340,6 +4340,16 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
43404340
if (!T->mayHaveMembers())
43414341
return;
43424342

4343+
if (auto objT = T->getOptionalObjectType()) {
4344+
// Add 'nil' keyword with erasing '.' instruction.
4345+
unsigned bytesToErase = 0;
4346+
auto &SM = CurrDeclContext->getASTContext().SourceMgr;
4347+
if (DotLoc.isValid())
4348+
bytesToErase = SM.getByteDistance(DotLoc, SM.getCodeCompletionLoc());
4349+
addKeyword("nil", T, SemanticContextKind::None,
4350+
CodeCompletionKeywordKind::kw_nil, bytesToErase);
4351+
}
4352+
43434353
// We can only say .foo where foo is a static member of the contextual
43444354
// type and has the same type (or if the member is a function, then the
43454355
// same result type) as the contextual type.
@@ -4378,14 +4388,6 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
43784388
objT = objT->lookThroughAllOptionalTypes();
43794389
if (seenTypes.insert(objT->getCanonicalType()).second)
43804390
getUnresolvedMemberCompletions(objT);
4381-
4382-
// Add 'nil' keyword with erasing '.' instruction.
4383-
unsigned bytesToErase = 0;
4384-
auto &SM = CurrDeclContext->getASTContext().SourceMgr;
4385-
if (DotLoc.isValid())
4386-
bytesToErase = SM.getByteDistance(DotLoc, SM.getCodeCompletionLoc());
4387-
addKeyword("nil", T, SemanticContextKind::None,
4388-
CodeCompletionKeywordKind::kw_nil, bytesToErase);
43894391
}
43904392
getUnresolvedMemberCompletions(T);
43914393
}
@@ -6083,6 +6085,45 @@ static void deliverCompletionResults(CodeCompletionContext &CompletionContext,
60836085
DCForModules);
60846086
}
60856087

6088+
void deliverUnresolvedMemberResults(
6089+
ArrayRef<UnresolvedMemberTypeCheckCompletionCallback::Result> Results,
6090+
DeclContext *DC, SourceLoc DotLoc,
6091+
ide::CodeCompletionContext &CompletionCtx,
6092+
CodeCompletionConsumer &Consumer) {
6093+
ASTContext &Ctx = DC->getASTContext();
6094+
CompletionLookup Lookup(CompletionCtx.getResultSink(), Ctx, DC,
6095+
&CompletionCtx);
6096+
6097+
assert(DotLoc.isValid());
6098+
Lookup.setHaveDot(DotLoc);
6099+
Lookup.shouldCheckForDuplicates(Results.size() > 1);
6100+
6101+
// Get the canonical versions of the top-level types
6102+
SmallPtrSet<CanType, 4> originalTypes;
6103+
for (auto &Result: Results)
6104+
originalTypes.insert(Result.ExpectedTy->getCanonicalType());
6105+
6106+
for (auto &Result: Results) {
6107+
Lookup.setExpectedTypes({Result.ExpectedTy},
6108+
Result.IsSingleExpressionBody,
6109+
/*expectsNonVoid*/true);
6110+
Lookup.setIdealExpectedType(Result.ExpectedTy);
6111+
6112+
// For optional types, also get members of the unwrapped type if it's not
6113+
// already equivalent to one of the top-level types. Handling it via the top
6114+
// level type and not here ensures we give the correct type relation
6115+
// (identical, rather than convertible).
6116+
if (Result.ExpectedTy->getOptionalObjectType()) {
6117+
Type Unwrapped = Result.ExpectedTy->lookThroughAllOptionalTypes();
6118+
if (originalTypes.insert(Unwrapped->getCanonicalType()).second)
6119+
Lookup.getUnresolvedMemberCompletions(Unwrapped);
6120+
}
6121+
Lookup.getUnresolvedMemberCompletions(Result.ExpectedTy);
6122+
}
6123+
SourceFile *SF = DC->getParentSourceFile();
6124+
deliverCompletionResults(CompletionCtx, Lookup, *SF, Consumer);
6125+
}
6126+
60866127
void deliverDotExprResults(
60876128
ArrayRef<DotExprTypeCheckCompletionCallback::Result> Results,
60886129
Expr *BaseExpr, DeclContext *DC, SourceLoc DotLoc, bool IsInSelector,
@@ -6158,6 +6199,23 @@ bool CodeCompletionCallbacksImpl::trySolverCompletion(bool MaybeFuncBody) {
61586199
Consumer);
61596200
return true;
61606201
}
6202+
case CompletionKind::UnresolvedMember: {
6203+
assert(CodeCompleteTokenExpr);
6204+
assert(CurDeclContext);
6205+
6206+
UnresolvedMemberTypeCheckCompletionCallback Lookup(CodeCompleteTokenExpr);
6207+
llvm::SaveAndRestore<TypeCheckCompletionCallback*>
6208+
CompletionCollector(Context.CompletionCallback, &Lookup);
6209+
typeCheckContextAt(CurDeclContext, CompletionLoc);
6210+
6211+
if (!Lookup.gotCallback())
6212+
Lookup.fallbackTypeCheck(CurDeclContext);
6213+
6214+
addKeywords(CompletionContext.getResultSink(), MaybeFuncBody);
6215+
deliverUnresolvedMemberResults(Lookup.getResults(), CurDeclContext, DotLoc,
6216+
CompletionContext, Consumer);
6217+
return true;
6218+
}
61616219
default:
61626220
return false;
61636221
}
@@ -6277,6 +6335,7 @@ void CodeCompletionCallbacksImpl::doneParsing() {
62776335
switch (Kind) {
62786336
case CompletionKind::None:
62796337
case CompletionKind::DotExpr:
6338+
case CompletionKind::UnresolvedMember:
62806339
llvm_unreachable("should be already handled");
62816340
return;
62826341

@@ -6478,15 +6537,6 @@ void CodeCompletionCallbacksImpl::doneParsing() {
64786537
Lookup.addImportModuleNames();
64796538
break;
64806539
}
6481-
case CompletionKind::UnresolvedMember: {
6482-
Lookup.setHaveDot(DotLoc);
6483-
ExprContextInfo ContextInfo(CurDeclContext, CodeCompleteTokenExpr);
6484-
Lookup.setExpectedTypes(ContextInfo.getPossibleTypes(),
6485-
ContextInfo.isSingleExpressionBody());
6486-
Lookup.setIdealExpectedType(CodeCompleteTokenExpr->getType());
6487-
Lookup.getUnresolvedMemberCompletions(ContextInfo.getPossibleTypes());
6488-
break;
6489-
}
64906540
case CompletionKind::CallArg: {
64916541
ExprContextInfo ContextInfo(CurDeclContext, CodeCompleteTokenExpr);
64926542

lib/Parse/ParseExpr.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3581,26 +3581,22 @@ Parser::parseExprCollectionElement(Optional<bool> &isDictionary) {
35813581
return Element;
35823582

35833583
// Parse the ':'.
3584-
if (!consumeIf(tok::colon)) {
3585-
if (Element.hasCodeCompletion()) {
3586-
// Return the completion expression itself so we can analyze the type
3587-
// later.
3588-
return Element;
3584+
ParserResult<Expr> Value;
3585+
if (consumeIf(tok::colon)) {
3586+
// Parse the value.
3587+
Value = parseExpr(diag::expected_value_in_dictionary_literal);
3588+
if (Value.isNull()) {
3589+
if (!Element.hasCodeCompletion()) {
3590+
Value = makeParserResult(Value, new (Context) ErrorExpr(PreviousLoc));
3591+
} else {
3592+
Value = makeParserResult(Value,
3593+
new (Context) CodeCompletionExpr(PreviousLoc));
3594+
}
35893595
}
3596+
} else {
35903597
diagnose(Tok, diag::expected_colon_in_dictionary_literal);
3591-
return ParserStatus(Element) | makeParserError();
3592-
}
3593-
3594-
// Parse the value.
3595-
auto Value = parseExpr(diag::expected_value_in_dictionary_literal);
3596-
3597-
if (Value.isNull()) {
3598-
if (!Element.hasCodeCompletion()) {
3599-
Value = makeParserResult(Value, new (Context) ErrorExpr(PreviousLoc));
3600-
} else {
3601-
Value = makeParserResult(Value,
3602-
new (Context) CodeCompletionExpr(PreviousLoc));
3603-
}
3598+
Value = makeParserResult(makeParserError(),
3599+
new (Context) ErrorExpr(SourceRange()));
36043600
}
36053601

36063602
// Make a tuple of Key Value pair.

lib/Sema/CSBindings.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1384,7 +1384,8 @@ TypeVariableBinding::fixForHole(ConstraintSystem &cs) const {
13841384
// under-constrained due to e.g. lack of expressions on the
13851385
// right-hand side of the token, which are required for a
13861386
// regular type-check.
1387-
if (dstLocator->directlyAt<CodeCompletionExpr>())
1387+
if (dstLocator->directlyAt<CodeCompletionExpr>() ||
1388+
srcLocator->directlyAt<CodeCompletionExpr>())
13881389
return None;
13891390
}
13901391

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ namespace {
10201020
// independent of the wider expression containing the ErrorExpr, so
10211021
// there's no point attempting to produce a solution for it.
10221022
SourceRange range = E->getSourceRange();
1023-
if (range.isInvalid() ||
1023+
if (range.isValid() &&
10241024
CS.getASTContext().SourceMgr.rangeContainsCodeCompletionLoc(range))
10251025
return nullptr;
10261026

lib/Sema/CSRanking.cpp

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -32,74 +32,83 @@ using namespace constraints;
3232
#define DEBUG_TYPE "Constraint solver overall"
3333
STATISTIC(NumDiscardedSolutions, "Number of solutions discarded");
3434

35-
void ConstraintSystem::increaseScore(ScoreKind kind, unsigned value) {
36-
unsigned index = static_cast<unsigned>(kind);
37-
CurrentScore.Data[index] += value;
35+
static StringRef getScoreKindName(ScoreKind kind) {
36+
switch (kind) {
37+
case SK_Hole:
38+
return "hole in the constraint system";
3839

39-
if (isDebugMode() && value > 0) {
40-
if (solverState)
41-
llvm::errs().indent(solverState->depth * 2);
42-
llvm::errs() << "(increasing score due to ";
43-
switch (kind) {
44-
case SK_Hole:
45-
llvm::errs() << "hole in the constraint system";
46-
break;
40+
case SK_Unavailable:
41+
return "use of an unavailable declaration";
4742

48-
case SK_Unavailable:
49-
llvm::errs() << "use of an unavailable declaration";
50-
break;
43+
case SK_AsyncSyncMismatch:
44+
return "async/synchronous mismatch";
5145

52-
case SK_AsyncSyncMismatch:
53-
llvm::errs() << "async/synchronous mismatch";
54-
break;
46+
case SK_ForwardTrailingClosure:
47+
return "forward scan when matching a trailing closure";
5548

56-
case SK_ForwardTrailingClosure:
57-
llvm::errs() << "forward scan when matching a trailing closure";
58-
break;
49+
case SK_Fix:
50+
return "attempting to fix the source";
5951

60-
case SK_Fix:
61-
llvm::errs() << "attempting to fix the source";
62-
break;
52+
case SK_DisfavoredOverload:
53+
return "disfavored overload";
6354

64-
case SK_DisfavoredOverload:
65-
llvm::errs() << "disfavored overload";
66-
break;
55+
case SK_ForceUnchecked:
56+
return "force of an implicitly unwrapped optional";
6757

68-
case SK_ForceUnchecked:
69-
llvm::errs() << "force of an implicitly unwrapped optional";
70-
break;
58+
case SK_UserConversion:
59+
return "user conversion";
7160

72-
case SK_UserConversion:
73-
llvm::errs() << "user conversion";
74-
break;
61+
case SK_FunctionConversion:
62+
return "function conversion";
7563

76-
case SK_FunctionConversion:
77-
llvm::errs() << "function conversion";
78-
break;
64+
case SK_NonDefaultLiteral:
65+
return "non-default literal";
66+
67+
case SK_CollectionUpcastConversion:
68+
return "collection upcast conversion";
69+
70+
case SK_ValueToOptional:
71+
return "value to optional";
72+
73+
case SK_EmptyExistentialConversion:
74+
return "empty-existential conversion";
75+
76+
case SK_KeyPathSubscript:
77+
return "key path subscript";
7978

79+
case SK_ValueToPointerConversion:
80+
return "value-to-pointer conversion";
81+
}
82+
}
83+
84+
void ConstraintSystem::increaseScore(ScoreKind kind, unsigned value) {
85+
if (isForCodeCompletion()) {
86+
switch (kind) {
8087
case SK_NonDefaultLiteral:
81-
llvm::errs() << "non-default literal";
82-
break;
83-
84-
case SK_CollectionUpcastConversion:
85-
llvm::errs() << "collection upcast conversion";
86-
break;
87-
88-
case SK_ValueToOptional:
89-
llvm::errs() << "value to optional";
90-
break;
91-
case SK_EmptyExistentialConversion:
92-
llvm::errs() << "empty-existential conversion";
93-
break;
94-
case SK_KeyPathSubscript:
95-
llvm::errs() << "key path subscript";
96-
break;
97-
case SK_ValueToPointerConversion:
98-
llvm::errs() << "value-to-pointer conversion";
88+
// Don't increase score for non-default literals in expressions involving
89+
// a code completion. In the below example, members of EnumA and EnumB
90+
// should be ranked equally:
91+
// func overloaded(_ x: Float, _ y: EnumA) {}
92+
// func overloaded(_ x: Int, _ y: EnumB) {}
93+
// func overloaded(_ x: Float) -> EnumA {}
94+
// func overloaded(_ x: Int) -> EnumB {}
95+
//
96+
// overloaded(1, .<complete>) {}
97+
// overloaded(1).<complete>
98+
return;
99+
default:
99100
break;
100101
}
101-
llvm::errs() << ")\n";
102102
}
103+
104+
if (isDebugMode() && value > 0) {
105+
if (solverState)
106+
llvm::errs().indent(solverState->depth * 2);
107+
llvm::errs() << "(increasing score due to " << getScoreKindName(kind) << ")\n";
108+
}
109+
110+
unsigned index = static_cast<unsigned>(kind);
111+
CurrentScore.Data[index] += value;
103112
}
104113

105114
bool ConstraintSystem::worseThanBestSolution() const {

0 commit comments

Comments
 (0)