21
21
#include " swift/AST/GenericEnvironment.h"
22
22
#include " swift/AST/Module.h"
23
23
#include " swift/AST/NameLookup.h"
24
+ #include " swift/AST/ASTDemangler.h"
24
25
#include " swift/AST/ProtocolConformance.h"
25
26
#include " swift/Sema/IDETypeChecking.h"
26
27
#include " swift/IDE/SourceEntityWalker.h"
@@ -582,6 +583,7 @@ collectDefaultImplementationForProtocolMembers(ProtocolDecl *PD,
582
583
583
584
// / This walker will traverse the AST and report types for every expression.
584
585
class ExpressionTypeCollector : public SourceEntityWalker {
586
+ ModuleDecl &Module;
585
587
SourceManager &SM;
586
588
unsigned int BufferId;
587
589
std::vector<ExpressionTypeInfo> &Results;
@@ -596,7 +598,13 @@ class ExpressionTypeCollector: public SourceEntityWalker {
596
598
// [offset, length].
597
599
llvm::DenseMap<unsigned , llvm::DenseSet<unsigned >> AllPrintedTypes;
598
600
599
- bool shouldReport (unsigned Offset, unsigned Length, Expr *E) {
601
+ // When non empty, we only print expression types that conform to any of
602
+ // these protocols.
603
+ llvm::MapVector<ProtocolDecl*, StringRef> &InterestedProtocols;
604
+
605
+ bool shouldReport (unsigned Offset, unsigned Length, Expr *E,
606
+ std::vector<StringRef> &Conformances) {
607
+ assert (Conformances.empty ());
600
608
// We shouldn't report null types.
601
609
if (E->getType ().isNull ())
602
610
return false ;
@@ -605,58 +613,116 @@ class ExpressionTypeCollector: public SourceEntityWalker {
605
613
// report again. This makes sure we always report the outtermost type of
606
614
// several overlapping expressions.
607
615
auto &Bucket = AllPrintedTypes[Offset];
608
- return Bucket.find (Length) == Bucket.end ();
616
+ if (Bucket.find (Length) != Bucket.end ())
617
+ return false ;
618
+
619
+ // We print every expression if the interested protocols are empty.
620
+ if (InterestedProtocols.empty ())
621
+ return true ;
622
+
623
+ // Collecting protocols conformed by this expressions that are in the list.
624
+ for (auto Proto: InterestedProtocols) {
625
+ if (Module.conformsToProtocol (E->getType (), Proto.first )) {
626
+ Conformances.push_back (Proto.second );
627
+ }
628
+ }
629
+
630
+ // We only print the type of the expression if it conforms to any of the
631
+ // interested protocols.
632
+ return !Conformances.empty ();
609
633
}
610
634
611
635
// Find an existing offset in the type buffer otherwise print the type to
612
636
// the buffer.
613
- uint32_t getTypeOffsets (StringRef PrintedType) {
637
+ std::pair< uint32_t , uint32_t > getTypeOffsets (StringRef PrintedType) {
614
638
auto It = TypeOffsets.find (PrintedType);
615
639
if (It == TypeOffsets.end ()) {
616
640
TypeOffsets[PrintedType] = OS.tell ();
617
- OS << PrintedType;
641
+ OS << PrintedType << ' \0 ' ;
618
642
}
619
- return TypeOffsets[PrintedType];
643
+ return { TypeOffsets[PrintedType], PrintedType. size ()} ;
620
644
}
621
645
646
+
622
647
public:
623
- ExpressionTypeCollector (SourceFile &SF, std::vector<ExpressionTypeInfo> &Results,
624
- llvm::raw_ostream &OS): SM(SF.getASTContext().SourceMgr),
648
+ ExpressionTypeCollector (SourceFile &SF,
649
+ llvm::MapVector<ProtocolDecl*, StringRef> &InterestedProtocols,
650
+ std::vector<ExpressionTypeInfo> &Results,
651
+ llvm::raw_ostream &OS): Module(*SF.getParentModule()),
652
+ SM (SF.getASTContext().SourceMgr),
625
653
BufferId(*SF.getBufferID()),
626
- Results(Results), OS(OS) {}
654
+ Results(Results), OS(OS),
655
+ InterestedProtocols(InterestedProtocols) {}
627
656
bool walkToExprPre (Expr *E) override {
628
657
if (E->getSourceRange ().isInvalid ())
629
658
return true ;
630
659
CharSourceRange Range =
631
660
Lexer::getCharSourceRangeFromSourceRange (SM, E->getSourceRange ());
632
661
unsigned Offset = SM.getLocOffsetInBuffer (Range.getStart (), BufferId);
633
662
unsigned Length = Range.getByteLength ();
634
- if (!shouldReport (Offset, Length, E))
663
+ std::vector<StringRef> Conformances;
664
+ if (!shouldReport (Offset, Length, E, Conformances))
635
665
return true ;
636
666
// Print the type to a temporary buffer.
637
667
SmallString<64 > Buffer;
638
668
{
639
669
llvm::raw_svector_ostream OS (Buffer);
640
670
E->getType ()->getRValueType ()->reconstituteSugar (true )->print (OS);
641
- // Ensure the end user can directly use the char*
642
- OS << ' \0 ' ;
643
671
}
644
-
672
+ auto Ty = getTypeOffsets (Buffer. str ());
645
673
// Add the type information to the result list.
646
- Results.push_back ({Offset, Length, getTypeOffsets (Buffer.str ()),
647
- static_cast <uint32_t >(Buffer.size ()) - 1 });
674
+ Results.push_back ({Offset, Length, Ty.first , Ty.second , {}});
675
+
676
+ // Adding all protocol names to the result.
677
+ for (auto Con: Conformances) {
678
+ auto Ty = getTypeOffsets (Con);
679
+ Results.back ().protocols .push_back ({Ty.first , Ty.second });
680
+ }
648
681
649
682
// Keep track of that we have a type reported for this range.
650
683
AllPrintedTypes[Offset].insert (Length);
651
684
return true ;
652
685
}
653
686
};
654
687
688
+ bool swift::resolveProtocolNames (DeclContext *DC,
689
+ ArrayRef<const char *> names,
690
+ llvm::MapVector<ProtocolDecl*, StringRef> &result) {
691
+ assert (result.empty ());
692
+ auto &ctx = DC->getASTContext ();
693
+ for (auto name : names) {
694
+ // First try to solve by usr
695
+ ProtocolDecl *pd = dyn_cast_or_null<ProtocolDecl>(Demangle::
696
+ getTypeDeclForUSR (ctx, name));
697
+ if (!pd) {
698
+ // Second try to solve by mangled symbol name
699
+ pd = dyn_cast_or_null<ProtocolDecl>(Demangle::getTypeDeclForMangling (ctx, name));
700
+ }
701
+ if (!pd) {
702
+ // Thirdly try to solve by mangled type name
703
+ if (auto ty = Demangle::getTypeForMangling (ctx, name)) {
704
+ pd = dyn_cast_or_null<ProtocolDecl>(ty->getAnyGeneric ());
705
+ }
706
+ }
707
+ if (pd) {
708
+ result.insert ({pd, name});
709
+ }
710
+ }
711
+ if (names.size () == result.size ())
712
+ return false ;
713
+ // If we resolved none but the given names are not empty, return true for failure.
714
+ return result.size () == 0 ;
715
+ }
716
+
655
717
ArrayRef<ExpressionTypeInfo>
656
718
swift::collectExpressionType (SourceFile &SF,
719
+ ArrayRef<const char *> ExpectedProtocols,
657
720
std::vector<ExpressionTypeInfo> &Scratch,
658
721
llvm::raw_ostream &OS) {
659
- ExpressionTypeCollector Walker (SF, Scratch, OS);
722
+ llvm::MapVector<ProtocolDecl*, StringRef> InterestedProtocols;
723
+ if (resolveProtocolNames (&SF, ExpectedProtocols, InterestedProtocols))
724
+ return {};
725
+ ExpressionTypeCollector Walker (SF, InterestedProtocols, Scratch, OS);
660
726
Walker.walk (SF);
661
727
return Scratch;
662
728
}
0 commit comments