Skip to content

Commit 35b17d7

Browse files
committed
SourceKit: allow expression type request to specify a list of protocol USRs for filtering
The client usually cares about a subset of all expressions. A way to differentiate them is by the protocols these expressions' types conform to. This patch allows the request to add a list of protocol USRs so that the response only includes those interested expressions that conform to any of the input protocols. We also add a field to the response for each expression type to indicate the conforming protocols names that were originally in the input list. When an empty list of protocol USRs are given, we report all expressions' types in the file like the old behavior. rdar://35199889
1 parent 70a0d4a commit 35b17d7

File tree

15 files changed

+436
-80
lines changed

15 files changed

+436
-80
lines changed

include/swift/Sema/IDETypeChecking.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#ifndef SWIFT_SEMA_IDETYPECHECKING_H
2020
#define SWIFT_SEMA_IDETYPECHECKING_H
2121

22+
#include "llvm/ADT/MapVector.h"
2223
#include "swift/Basic/SourceLoc.h"
2324
#include <memory>
2425

@@ -177,12 +178,21 @@ namespace swift {
177178

178179
/// The length of the printed type
179180
uint32_t typeLength;
181+
182+
/// The offsets and lengths of all protocols the type conforms to
183+
std::vector<std::pair<uint32_t, uint32_t>> protocols;
180184
};
181185

182186
/// Collect type information for every expression in \c SF; all types will
183187
/// be printed to \c OS.
184188
ArrayRef<ExpressionTypeInfo> collectExpressionType(SourceFile &SF,
189+
ArrayRef<const char *> ExpectedProtocols,
185190
std::vector<ExpressionTypeInfo> &scratch, llvm::raw_ostream &OS);
191+
192+
/// Resolve a list of mangled names to accessible protocol decls from
193+
/// the decl context.
194+
bool resolveProtocolNames(DeclContext *DC, ArrayRef<const char *> names,
195+
llvm::MapVector<ProtocolDecl*, StringRef> &result);
186196
}
187197

188198
#endif

lib/IDE/ConformingMethodList.cpp

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ class ConformingMethodListCallbacks : public CodeCompletionCallbacks {
3232
Expr *ParsedExpr = nullptr;
3333
DeclContext *CurDeclContext = nullptr;
3434

35-
void resolveExpectedTypes(ArrayRef<const char *> names, SourceLoc loc,
36-
SmallVectorImpl<ProtocolDecl *> &result);
37-
void getMatchingMethods(Type T, ArrayRef<ProtocolDecl *> expectedTypes,
35+
void getMatchingMethods(Type T,
36+
llvm::MapVector<ProtocolDecl*, StringRef> &expectedTypes,
3837
SmallVectorImpl<ValueDecl *> &result);
3938

4039
public:
@@ -89,9 +88,8 @@ void ConformingMethodListCallbacks::doneParsing() {
8988
if (!T || T->is<ErrorType>() || T->is<UnresolvedType>())
9089
return;
9190

92-
SmallVector<ProtocolDecl *, 4> expectedProtocols;
93-
resolveExpectedTypes(ExpectedTypeNames, ParsedExpr->getLoc(),
94-
expectedProtocols);
91+
llvm::MapVector<ProtocolDecl*, StringRef> expectedProtocols;
92+
resolveProtocolNames(CurDeclContext, ExpectedTypeNames, expectedProtocols);
9593

9694
// Collect the matching methods.
9795
ConformingMethodListResult result(CurDeclContext, T);
@@ -100,21 +98,8 @@ void ConformingMethodListCallbacks::doneParsing() {
10098
Consumer.handleResult(result);
10199
}
102100

103-
void ConformingMethodListCallbacks::resolveExpectedTypes(
104-
ArrayRef<const char *> names, SourceLoc loc,
105-
SmallVectorImpl<ProtocolDecl *> &result) {
106-
auto &ctx = CurDeclContext->getASTContext();
107-
108-
for (auto name : names) {
109-
if (auto ty = Demangle::getTypeForMangling(ctx, name)) {
110-
if (auto Proto = dyn_cast_or_null<ProtocolDecl>(ty->getAnyGeneric()))
111-
result.push_back(Proto);
112-
}
113-
}
114-
}
115-
116101
void ConformingMethodListCallbacks::getMatchingMethods(
117-
Type T, ArrayRef<ProtocolDecl *> expectedTypes,
102+
Type T, llvm::MapVector<ProtocolDecl*, StringRef> &expectedTypes,
118103
SmallVectorImpl<ValueDecl *> &result) {
119104
if (!T->mayHaveMembers())
120105
return;
@@ -126,7 +111,7 @@ void ConformingMethodListCallbacks::getMatchingMethods(
126111
Type T;
127112

128113
/// The list of expected types.
129-
ArrayRef<ProtocolDecl *> ExpectedTypes;
114+
llvm::MapVector<ProtocolDecl*, StringRef> &ExpectedTypes;
130115

131116
/// Result sink to populate.
132117
SmallVectorImpl<ValueDecl *> &Result;
@@ -149,7 +134,7 @@ void ConformingMethodListCallbacks::getMatchingMethods(
149134

150135
// The return type conforms to any of the requested protocols.
151136
for (auto Proto : ExpectedTypes) {
152-
if (CurModule->conformsToProtocol(declTy, Proto))
137+
if (CurModule->conformsToProtocol(declTy, Proto.first))
153138
return true;
154139
}
155140

@@ -158,7 +143,7 @@ void ConformingMethodListCallbacks::getMatchingMethods(
158143

159144
public:
160145
LocalConsumer(DeclContext *DC, Type T,
161-
ArrayRef<ProtocolDecl *> expectedTypes,
146+
llvm::MapVector<ProtocolDecl*, StringRef> &expectedTypes,
162147
SmallVectorImpl<ValueDecl *> &result)
163148
: CurModule(DC->getParentModule()), T(T), ExpectedTypes(expectedTypes),
164149
Result(result) {}

lib/IDE/IDETypeChecking.cpp

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "swift/AST/GenericEnvironment.h"
2222
#include "swift/AST/Module.h"
2323
#include "swift/AST/NameLookup.h"
24+
#include "swift/AST/ASTDemangler.h"
2425
#include "swift/AST/ProtocolConformance.h"
2526
#include "swift/Sema/IDETypeChecking.h"
2627
#include "swift/IDE/SourceEntityWalker.h"
@@ -582,6 +583,7 @@ collectDefaultImplementationForProtocolMembers(ProtocolDecl *PD,
582583

583584
/// This walker will traverse the AST and report types for every expression.
584585
class ExpressionTypeCollector: public SourceEntityWalker {
586+
ModuleDecl &Module;
585587
SourceManager &SM;
586588
unsigned int BufferId;
587589
std::vector<ExpressionTypeInfo> &Results;
@@ -596,7 +598,13 @@ class ExpressionTypeCollector: public SourceEntityWalker {
596598
// [offset, length].
597599
llvm::DenseMap<unsigned, llvm::DenseSet<unsigned>> AllPrintedTypes;
598600

599-
bool shouldReport(unsigned Offset, unsigned Length, Expr *E) {
601+
// When non empty, we only print expression types that conform to any of
602+
// these protocols.
603+
llvm::MapVector<ProtocolDecl*, StringRef> &InterestedProtocols;
604+
605+
bool shouldReport(unsigned Offset, unsigned Length, Expr *E,
606+
std::vector<StringRef> &Conformances) {
607+
assert(Conformances.empty());
600608
// We shouldn't report null types.
601609
if (E->getType().isNull())
602610
return false;
@@ -605,58 +613,116 @@ class ExpressionTypeCollector: public SourceEntityWalker {
605613
// report again. This makes sure we always report the outtermost type of
606614
// several overlapping expressions.
607615
auto &Bucket = AllPrintedTypes[Offset];
608-
return Bucket.find(Length) == Bucket.end();
616+
if (Bucket.find(Length) != Bucket.end())
617+
return false;
618+
619+
// We print every expression if the interested protocols are empty.
620+
if (InterestedProtocols.empty())
621+
return true;
622+
623+
// Collecting protocols conformed by this expressions that are in the list.
624+
for (auto Proto: InterestedProtocols) {
625+
if (Module.conformsToProtocol(E->getType(), Proto.first)) {
626+
Conformances.push_back(Proto.second);
627+
}
628+
}
629+
630+
// We only print the type of the expression if it conforms to any of the
631+
// interested protocols.
632+
return !Conformances.empty();
609633
}
610634

611635
// Find an existing offset in the type buffer otherwise print the type to
612636
// the buffer.
613-
uint32_t getTypeOffsets(StringRef PrintedType) {
637+
std::pair<uint32_t, uint32_t> getTypeOffsets(StringRef PrintedType) {
614638
auto It = TypeOffsets.find(PrintedType);
615639
if (It == TypeOffsets.end()) {
616640
TypeOffsets[PrintedType] = OS.tell();
617-
OS << PrintedType;
641+
OS << PrintedType << '\0';
618642
}
619-
return TypeOffsets[PrintedType];
643+
return {TypeOffsets[PrintedType], PrintedType.size()};
620644
}
621645

646+
622647
public:
623-
ExpressionTypeCollector(SourceFile &SF, std::vector<ExpressionTypeInfo> &Results,
624-
llvm::raw_ostream &OS): SM(SF.getASTContext().SourceMgr),
648+
ExpressionTypeCollector(SourceFile &SF,
649+
llvm::MapVector<ProtocolDecl*, StringRef> &InterestedProtocols,
650+
std::vector<ExpressionTypeInfo> &Results,
651+
llvm::raw_ostream &OS): Module(*SF.getParentModule()),
652+
SM(SF.getASTContext().SourceMgr),
625653
BufferId(*SF.getBufferID()),
626-
Results(Results), OS(OS) {}
654+
Results(Results), OS(OS),
655+
InterestedProtocols(InterestedProtocols) {}
627656
bool walkToExprPre(Expr *E) override {
628657
if (E->getSourceRange().isInvalid())
629658
return true;
630659
CharSourceRange Range =
631660
Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange());
632661
unsigned Offset = SM.getLocOffsetInBuffer(Range.getStart(), BufferId);
633662
unsigned Length = Range.getByteLength();
634-
if (!shouldReport(Offset, Length, E))
663+
std::vector<StringRef> Conformances;
664+
if (!shouldReport(Offset, Length, E, Conformances))
635665
return true;
636666
// Print the type to a temporary buffer.
637667
SmallString<64> Buffer;
638668
{
639669
llvm::raw_svector_ostream OS(Buffer);
640670
E->getType()->getRValueType()->reconstituteSugar(true)->print(OS);
641-
// Ensure the end user can directly use the char*
642-
OS << '\0';
643671
}
644-
672+
auto Ty = getTypeOffsets(Buffer.str());
645673
// Add the type information to the result list.
646-
Results.push_back({Offset, Length, getTypeOffsets(Buffer.str()),
647-
static_cast<uint32_t>(Buffer.size()) - 1});
674+
Results.push_back({Offset, Length, Ty.first, Ty.second, {}});
675+
676+
// Adding all protocol names to the result.
677+
for(auto Con: Conformances) {
678+
auto Ty = getTypeOffsets(Con);
679+
Results.back().protocols.push_back({Ty.first, Ty.second});
680+
}
648681

649682
// Keep track of that we have a type reported for this range.
650683
AllPrintedTypes[Offset].insert(Length);
651684
return true;
652685
}
653686
};
654687

688+
bool swift::resolveProtocolNames(DeclContext *DC,
689+
ArrayRef<const char *> names,
690+
llvm::MapVector<ProtocolDecl*, StringRef> &result) {
691+
assert(result.empty());
692+
auto &ctx = DC->getASTContext();
693+
for (auto name : names) {
694+
// First try to solve by usr
695+
ProtocolDecl *pd = dyn_cast_or_null<ProtocolDecl>(Demangle::
696+
getTypeDeclForUSR(ctx, name));
697+
if (!pd) {
698+
// Second try to solve by mangled symbol name
699+
pd = dyn_cast_or_null<ProtocolDecl>(Demangle::getTypeDeclForMangling(ctx, name));
700+
}
701+
if (!pd) {
702+
// Thirdly try to solve by mangled type name
703+
if (auto ty = Demangle::getTypeForMangling(ctx, name)) {
704+
pd = dyn_cast_or_null<ProtocolDecl>(ty->getAnyGeneric());
705+
}
706+
}
707+
if (pd) {
708+
result.insert({pd, name});
709+
}
710+
}
711+
if (names.size() == result.size())
712+
return false;
713+
// If we resolved none but the given names are not empty, return true for failure.
714+
return result.size() == 0;
715+
}
716+
655717
ArrayRef<ExpressionTypeInfo>
656718
swift::collectExpressionType(SourceFile &SF,
719+
ArrayRef<const char *> ExpectedProtocols,
657720
std::vector<ExpressionTypeInfo> &Scratch,
658721
llvm::raw_ostream &OS) {
659-
ExpressionTypeCollector Walker(SF, Scratch, OS);
722+
llvm::MapVector<ProtocolDecl*, StringRef> InterestedProtocols;
723+
if (resolveProtocolNames(&SF, ExpectedProtocols, InterestedProtocols))
724+
return {};
725+
ExpressionTypeCollector Walker(SF, InterestedProtocols, Scratch, OS);
660726
Walker.walk(SF);
661727
return Scratch;
662728
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
protocol ProtEmpty {}
2+
3+
protocol Prot {}
4+
5+
protocol Prot1 {}
6+
7+
class Clas: Prot {
8+
var value: Clas { return self }
9+
func getValue() -> Clas { return self }
10+
}
11+
12+
struct Stru: Prot, Prot1 {
13+
var value: Stru { return self }
14+
func getValue() -> Stru { return self }
15+
}
16+
17+
class C {}
18+
19+
func ArrayC(_ a: [C]) {
20+
_ = a.count
21+
_ = a.description.count.advanced(by: 1).description
22+
_ = a[0]
23+
}
24+
25+
func ArrayClas(_ a: [Clas]) {
26+
_ = a[0].value.getValue().value
27+
}
28+
29+
func ArrayClas(_ a: [Stru]) {
30+
_ = a[0].value.getValue().value
31+
}

test/IDE/expr_type_filtered.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %target-swift-ide-test -print-expr-type -source-filename %S/Inputs/ExprTypeFiltered.swift -swift-version 5 -module-name filtered -usr-filter 's:8filtered9ProtEmptyP' | %FileCheck %s -check-prefix=EMPTY
2+
// RUN: %target-swift-ide-test -print-expr-type -source-filename %S/Inputs/ExprTypeFiltered.swift -swift-version 5 -module-name filtered -usr-filter 's:8filtered4ProtP' | %FileCheck %s -check-prefix=PROTO
3+
// RUN: %target-swift-ide-test -print-expr-type -source-filename %S/Inputs/ExprTypeFiltered.swift -swift-version 5 -module-name filtered -usr-filter 's:8filtered5Prot1P' | %FileCheck %s -check-prefix=PROTO1
4+
5+
// EMPTY: class Clas: Prot {
6+
// EMPTY: var value: Clas { return self }
7+
// EMPTY: func getValue() -> Clas { return self }
8+
// EMPTY: }
9+
// EMPTY: struct Stru: Prot, Prot1 {
10+
// EMPTY: var value: Stru { return self }
11+
// EMPTY: func getValue() -> Stru { return self }
12+
// EMPTY: }
13+
14+
// PROTO: class Clas: Prot {
15+
// PROTO: var value: Clas { return <expr type:"Clas">self</expr> }
16+
// PROTO: func getValue() -> Clas { return <expr type:"Clas">self</expr> }
17+
// PROTO: }
18+
// PROTO: struct Stru: Prot, Prot1 {
19+
// PROTO: var value: Stru { return <expr type:"Stru">self</expr> }
20+
// PROTO: func getValue() -> Stru { return <expr type:"Stru">self</expr> }
21+
// PROTO: }
22+
23+
// PROTO1: class Clas: Prot {
24+
// PROTO1: var value: Clas { return self }
25+
// PROTO1: func getValue() -> Clas { return self }
26+
// PROTO1: }
27+
// PROTO1: struct Stru: Prot, Prot1 {
28+
// PROTO1: var value: Stru { return <expr type:"Stru">self</expr> }
29+
// PROTO1: func getValue() -> Stru { return <expr type:"Stru">self</expr> }
30+
// PROTO1: }

0 commit comments

Comments
 (0)