Skip to content

Commit b9877c2

Browse files
authored
Merge pull request swiftlang#23255 from nkcsgexi/proco-filter
SourceKit: allow expression type request to specify a list of protocol USRs for filtering
2 parents 4f3a9ac + 35b17d7 commit b9877c2

File tree

15 files changed

+436
-80
lines changed

15 files changed

+436
-80
lines changed

include/swift/Sema/IDETypeChecking.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#ifndef SWIFT_SEMA_IDETYPECHECKING_H
2020
#define SWIFT_SEMA_IDETYPECHECKING_H
2121

22+
#include "llvm/ADT/MapVector.h"
2223
#include "swift/Basic/SourceLoc.h"
2324
#include <memory>
2425

@@ -180,12 +181,21 @@ namespace swift {
180181

181182
/// The length of the printed type
182183
uint32_t typeLength;
184+
185+
/// The offsets and lengths of all protocols the type conforms to
186+
std::vector<std::pair<uint32_t, uint32_t>> protocols;
183187
};
184188

185189
/// Collect type information for every expression in \c SF; all types will
186190
/// be printed to \c OS.
187191
ArrayRef<ExpressionTypeInfo> collectExpressionType(SourceFile &SF,
192+
ArrayRef<const char *> ExpectedProtocols,
188193
std::vector<ExpressionTypeInfo> &scratch, llvm::raw_ostream &OS);
194+
195+
/// Resolve a list of mangled names to accessible protocol decls from
196+
/// the decl context.
197+
bool resolveProtocolNames(DeclContext *DC, ArrayRef<const char *> names,
198+
llvm::MapVector<ProtocolDecl*, StringRef> &result);
189199
}
190200

191201
#endif

lib/IDE/ConformingMethodList.cpp

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ class ConformingMethodListCallbacks : public CodeCompletionCallbacks {
3232
Expr *ParsedExpr = nullptr;
3333
DeclContext *CurDeclContext = nullptr;
3434

35-
void resolveExpectedTypes(ArrayRef<const char *> names, SourceLoc loc,
36-
SmallVectorImpl<ProtocolDecl *> &result);
37-
void getMatchingMethods(Type T, ArrayRef<ProtocolDecl *> expectedTypes,
35+
void getMatchingMethods(Type T,
36+
llvm::MapVector<ProtocolDecl*, StringRef> &expectedTypes,
3837
SmallVectorImpl<ValueDecl *> &result);
3938

4039
public:
@@ -89,9 +88,8 @@ void ConformingMethodListCallbacks::doneParsing() {
8988
if (!T || T->is<ErrorType>() || T->is<UnresolvedType>())
9089
return;
9190

92-
SmallVector<ProtocolDecl *, 4> expectedProtocols;
93-
resolveExpectedTypes(ExpectedTypeNames, ParsedExpr->getLoc(),
94-
expectedProtocols);
91+
llvm::MapVector<ProtocolDecl*, StringRef> expectedProtocols;
92+
resolveProtocolNames(CurDeclContext, ExpectedTypeNames, expectedProtocols);
9593

9694
// Collect the matching methods.
9795
ConformingMethodListResult result(CurDeclContext, T);
@@ -100,21 +98,8 @@ void ConformingMethodListCallbacks::doneParsing() {
10098
Consumer.handleResult(result);
10199
}
102100

103-
void ConformingMethodListCallbacks::resolveExpectedTypes(
104-
ArrayRef<const char *> names, SourceLoc loc,
105-
SmallVectorImpl<ProtocolDecl *> &result) {
106-
auto &ctx = CurDeclContext->getASTContext();
107-
108-
for (auto name : names) {
109-
if (auto ty = Demangle::getTypeForMangling(ctx, name)) {
110-
if (auto Proto = dyn_cast_or_null<ProtocolDecl>(ty->getAnyGeneric()))
111-
result.push_back(Proto);
112-
}
113-
}
114-
}
115-
116101
void ConformingMethodListCallbacks::getMatchingMethods(
117-
Type T, ArrayRef<ProtocolDecl *> expectedTypes,
102+
Type T, llvm::MapVector<ProtocolDecl*, StringRef> &expectedTypes,
118103
SmallVectorImpl<ValueDecl *> &result) {
119104
if (!T->mayHaveMembers())
120105
return;
@@ -126,7 +111,7 @@ void ConformingMethodListCallbacks::getMatchingMethods(
126111
Type T;
127112

128113
/// The list of expected types.
129-
ArrayRef<ProtocolDecl *> ExpectedTypes;
114+
llvm::MapVector<ProtocolDecl*, StringRef> &ExpectedTypes;
130115

131116
/// Result sink to populate.
132117
SmallVectorImpl<ValueDecl *> &Result;
@@ -149,7 +134,7 @@ void ConformingMethodListCallbacks::getMatchingMethods(
149134

150135
// The return type conforms to any of the requested protocols.
151136
for (auto Proto : ExpectedTypes) {
152-
if (CurModule->conformsToProtocol(declTy, Proto))
137+
if (CurModule->conformsToProtocol(declTy, Proto.first))
153138
return true;
154139
}
155140

@@ -158,7 +143,7 @@ void ConformingMethodListCallbacks::getMatchingMethods(
158143

159144
public:
160145
LocalConsumer(DeclContext *DC, Type T,
161-
ArrayRef<ProtocolDecl *> expectedTypes,
146+
llvm::MapVector<ProtocolDecl*, StringRef> &expectedTypes,
162147
SmallVectorImpl<ValueDecl *> &result)
163148
: CurModule(DC->getParentModule()), T(T), ExpectedTypes(expectedTypes),
164149
Result(result) {}

lib/IDE/IDETypeChecking.cpp

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "swift/AST/GenericEnvironment.h"
2222
#include "swift/AST/Module.h"
2323
#include "swift/AST/NameLookup.h"
24+
#include "swift/AST/ASTDemangler.h"
2425
#include "swift/AST/ProtocolConformance.h"
2526
#include "swift/Sema/IDETypeChecking.h"
2627
#include "swift/IDE/SourceEntityWalker.h"
@@ -582,6 +583,7 @@ collectDefaultImplementationForProtocolMembers(ProtocolDecl *PD,
582583

583584
/// This walker will traverse the AST and report types for every expression.
584585
class ExpressionTypeCollector: public SourceEntityWalker {
586+
ModuleDecl &Module;
585587
SourceManager &SM;
586588
unsigned int BufferId;
587589
std::vector<ExpressionTypeInfo> &Results;
@@ -596,7 +598,13 @@ class ExpressionTypeCollector: public SourceEntityWalker {
596598
// [offset, length].
597599
llvm::DenseMap<unsigned, llvm::DenseSet<unsigned>> AllPrintedTypes;
598600

599-
bool shouldReport(unsigned Offset, unsigned Length, Expr *E) {
601+
// When non empty, we only print expression types that conform to any of
602+
// these protocols.
603+
llvm::MapVector<ProtocolDecl*, StringRef> &InterestedProtocols;
604+
605+
bool shouldReport(unsigned Offset, unsigned Length, Expr *E,
606+
std::vector<StringRef> &Conformances) {
607+
assert(Conformances.empty());
600608
// We shouldn't report null types.
601609
if (E->getType().isNull())
602610
return false;
@@ -605,58 +613,116 @@ class ExpressionTypeCollector: public SourceEntityWalker {
605613
// report again. This makes sure we always report the outtermost type of
606614
// several overlapping expressions.
607615
auto &Bucket = AllPrintedTypes[Offset];
608-
return Bucket.find(Length) == Bucket.end();
616+
if (Bucket.find(Length) != Bucket.end())
617+
return false;
618+
619+
// We print every expression if the interested protocols are empty.
620+
if (InterestedProtocols.empty())
621+
return true;
622+
623+
// Collecting protocols conformed by this expressions that are in the list.
624+
for (auto Proto: InterestedProtocols) {
625+
if (Module.conformsToProtocol(E->getType(), Proto.first)) {
626+
Conformances.push_back(Proto.second);
627+
}
628+
}
629+
630+
// We only print the type of the expression if it conforms to any of the
631+
// interested protocols.
632+
return !Conformances.empty();
609633
}
610634

611635
// Find an existing offset in the type buffer otherwise print the type to
612636
// the buffer.
613-
uint32_t getTypeOffsets(StringRef PrintedType) {
637+
std::pair<uint32_t, uint32_t> getTypeOffsets(StringRef PrintedType) {
614638
auto It = TypeOffsets.find(PrintedType);
615639
if (It == TypeOffsets.end()) {
616640
TypeOffsets[PrintedType] = OS.tell();
617-
OS << PrintedType;
641+
OS << PrintedType << '\0';
618642
}
619-
return TypeOffsets[PrintedType];
643+
return {TypeOffsets[PrintedType], PrintedType.size()};
620644
}
621645

646+
622647
public:
623-
ExpressionTypeCollector(SourceFile &SF, std::vector<ExpressionTypeInfo> &Results,
624-
llvm::raw_ostream &OS): SM(SF.getASTContext().SourceMgr),
648+
ExpressionTypeCollector(SourceFile &SF,
649+
llvm::MapVector<ProtocolDecl*, StringRef> &InterestedProtocols,
650+
std::vector<ExpressionTypeInfo> &Results,
651+
llvm::raw_ostream &OS): Module(*SF.getParentModule()),
652+
SM(SF.getASTContext().SourceMgr),
625653
BufferId(*SF.getBufferID()),
626-
Results(Results), OS(OS) {}
654+
Results(Results), OS(OS),
655+
InterestedProtocols(InterestedProtocols) {}
627656
bool walkToExprPre(Expr *E) override {
628657
if (E->getSourceRange().isInvalid())
629658
return true;
630659
CharSourceRange Range =
631660
Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange());
632661
unsigned Offset = SM.getLocOffsetInBuffer(Range.getStart(), BufferId);
633662
unsigned Length = Range.getByteLength();
634-
if (!shouldReport(Offset, Length, E))
663+
std::vector<StringRef> Conformances;
664+
if (!shouldReport(Offset, Length, E, Conformances))
635665
return true;
636666
// Print the type to a temporary buffer.
637667
SmallString<64> Buffer;
638668
{
639669
llvm::raw_svector_ostream OS(Buffer);
640670
E->getType()->getRValueType()->reconstituteSugar(true)->print(OS);
641-
// Ensure the end user can directly use the char*
642-
OS << '\0';
643671
}
644-
672+
auto Ty = getTypeOffsets(Buffer.str());
645673
// Add the type information to the result list.
646-
Results.push_back({Offset, Length, getTypeOffsets(Buffer.str()),
647-
static_cast<uint32_t>(Buffer.size()) - 1});
674+
Results.push_back({Offset, Length, Ty.first, Ty.second, {}});
675+
676+
// Adding all protocol names to the result.
677+
for(auto Con: Conformances) {
678+
auto Ty = getTypeOffsets(Con);
679+
Results.back().protocols.push_back({Ty.first, Ty.second});
680+
}
648681

649682
// Keep track of that we have a type reported for this range.
650683
AllPrintedTypes[Offset].insert(Length);
651684
return true;
652685
}
653686
};
654687

688+
bool swift::resolveProtocolNames(DeclContext *DC,
689+
ArrayRef<const char *> names,
690+
llvm::MapVector<ProtocolDecl*, StringRef> &result) {
691+
assert(result.empty());
692+
auto &ctx = DC->getASTContext();
693+
for (auto name : names) {
694+
// First try to solve by usr
695+
ProtocolDecl *pd = dyn_cast_or_null<ProtocolDecl>(Demangle::
696+
getTypeDeclForUSR(ctx, name));
697+
if (!pd) {
698+
// Second try to solve by mangled symbol name
699+
pd = dyn_cast_or_null<ProtocolDecl>(Demangle::getTypeDeclForMangling(ctx, name));
700+
}
701+
if (!pd) {
702+
// Thirdly try to solve by mangled type name
703+
if (auto ty = Demangle::getTypeForMangling(ctx, name)) {
704+
pd = dyn_cast_or_null<ProtocolDecl>(ty->getAnyGeneric());
705+
}
706+
}
707+
if (pd) {
708+
result.insert({pd, name});
709+
}
710+
}
711+
if (names.size() == result.size())
712+
return false;
713+
// If we resolved none but the given names are not empty, return true for failure.
714+
return result.size() == 0;
715+
}
716+
655717
ArrayRef<ExpressionTypeInfo>
656718
swift::collectExpressionType(SourceFile &SF,
719+
ArrayRef<const char *> ExpectedProtocols,
657720
std::vector<ExpressionTypeInfo> &Scratch,
658721
llvm::raw_ostream &OS) {
659-
ExpressionTypeCollector Walker(SF, Scratch, OS);
722+
llvm::MapVector<ProtocolDecl*, StringRef> InterestedProtocols;
723+
if (resolveProtocolNames(&SF, ExpectedProtocols, InterestedProtocols))
724+
return {};
725+
ExpressionTypeCollector Walker(SF, InterestedProtocols, Scratch, OS);
660726
Walker.walk(SF);
661727
return Scratch;
662728
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
protocol ProtEmpty {}
2+
3+
protocol Prot {}
4+
5+
protocol Prot1 {}
6+
7+
class Clas: Prot {
8+
var value: Clas { return self }
9+
func getValue() -> Clas { return self }
10+
}
11+
12+
struct Stru: Prot, Prot1 {
13+
var value: Stru { return self }
14+
func getValue() -> Stru { return self }
15+
}
16+
17+
class C {}
18+
19+
func ArrayC(_ a: [C]) {
20+
_ = a.count
21+
_ = a.description.count.advanced(by: 1).description
22+
_ = a[0]
23+
}
24+
25+
func ArrayClas(_ a: [Clas]) {
26+
_ = a[0].value.getValue().value
27+
}
28+
29+
func ArrayClas(_ a: [Stru]) {
30+
_ = a[0].value.getValue().value
31+
}

test/IDE/expr_type_filtered.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %target-swift-ide-test -print-expr-type -source-filename %S/Inputs/ExprTypeFiltered.swift -swift-version 5 -module-name filtered -usr-filter 's:8filtered9ProtEmptyP' | %FileCheck %s -check-prefix=EMPTY
2+
// RUN: %target-swift-ide-test -print-expr-type -source-filename %S/Inputs/ExprTypeFiltered.swift -swift-version 5 -module-name filtered -usr-filter 's:8filtered4ProtP' | %FileCheck %s -check-prefix=PROTO
3+
// RUN: %target-swift-ide-test -print-expr-type -source-filename %S/Inputs/ExprTypeFiltered.swift -swift-version 5 -module-name filtered -usr-filter 's:8filtered5Prot1P' | %FileCheck %s -check-prefix=PROTO1
4+
5+
// EMPTY: class Clas: Prot {
6+
// EMPTY: var value: Clas { return self }
7+
// EMPTY: func getValue() -> Clas { return self }
8+
// EMPTY: }
9+
// EMPTY: struct Stru: Prot, Prot1 {
10+
// EMPTY: var value: Stru { return self }
11+
// EMPTY: func getValue() -> Stru { return self }
12+
// EMPTY: }
13+
14+
// PROTO: class Clas: Prot {
15+
// PROTO: var value: Clas { return <expr type:"Clas">self</expr> }
16+
// PROTO: func getValue() -> Clas { return <expr type:"Clas">self</expr> }
17+
// PROTO: }
18+
// PROTO: struct Stru: Prot, Prot1 {
19+
// PROTO: var value: Stru { return <expr type:"Stru">self</expr> }
20+
// PROTO: func getValue() -> Stru { return <expr type:"Stru">self</expr> }
21+
// PROTO: }
22+
23+
// PROTO1: class Clas: Prot {
24+
// PROTO1: var value: Clas { return self }
25+
// PROTO1: func getValue() -> Clas { return self }
26+
// PROTO1: }
27+
// PROTO1: struct Stru: Prot, Prot1 {
28+
// PROTO1: var value: Stru { return <expr type:"Stru">self</expr> }
29+
// PROTO1: func getValue() -> Stru { return <expr type:"Stru">self</expr> }
30+
// PROTO1: }

0 commit comments

Comments
 (0)