Skip to content

Commit 033f9c7

Browse files
committed
[Type checker] Introduce "contextual patterns".
Contextual pattern describes a particular pattern with enough contextual information to determine its type. Use this to simplify TypeChecker::typeCheckPattern()'s interface in a manner that will admit request'ification.
1 parent 9c8351b commit 033f9c7

File tree

7 files changed

+149
-31
lines changed

7 files changed

+149
-31
lines changed

include/swift/AST/Pattern.h

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,93 @@ inline Pattern *Pattern::getSemanticsProvidingPattern() {
756756
return vp->getSubPattern()->getSemanticsProvidingPattern();
757757
return this;
758758
}
759-
759+
760+
/// Describes a pattern and the context in which it occurs.
761+
class ContextualPattern {
762+
/// The pattern and whether this is the top-level pattern.
763+
llvm::PointerIntPair<Pattern *, 1, bool> patternAndTopLevel;
764+
765+
/// Either the declaration context or the enclosing pattern binding
766+
/// declaration.
767+
llvm::PointerUnion<PatternBindingDecl *, DeclContext *> declOrContext;
768+
769+
/// Index into the pattern binding declaration, when there is one.
770+
unsigned index = 0;
771+
772+
ContextualPattern(
773+
Pattern *pattern, bool topLevel,
774+
llvm::PointerUnion<PatternBindingDecl *, DeclContext *> declOrContext,
775+
unsigned index
776+
) : patternAndTopLevel(pattern, topLevel),
777+
declOrContext(declOrContext),
778+
index(index) { }
779+
780+
public:
781+
/// Produce a contextual pattern for a pattern binding declaration entry.
782+
static ContextualPattern forPatternBindingDecl(
783+
PatternBindingDecl *pbd, unsigned index);
784+
785+
/// Produce a contextual pattern for a raw pattern that always allows
786+
/// inference.
787+
static ContextualPattern forRawPattern(Pattern *pattern, DeclContext *dc) {
788+
return ContextualPattern(pattern, /*topLevel=*/true, dc, /*index=*/0);
789+
}
790+
791+
/// Retrieve a contextual pattern for the given subpattern.
792+
ContextualPattern forSubPattern(
793+
Pattern *subpattern, bool retainTopLevel) const {
794+
return ContextualPattern(
795+
subpattern, isTopLevel() && retainTopLevel, declOrContext, index);
796+
}
797+
798+
/// Retrieve the pattern.
799+
Pattern *getPattern() const {
800+
return patternAndTopLevel.getPointer();
801+
}
802+
803+
/// Whether this is the top-level pattern in this context.
804+
bool isTopLevel() const {
805+
return patternAndTopLevel.getInt();
806+
}
807+
808+
/// Retrieve the declaration context of the pattern.
809+
DeclContext *getDeclContext() const;
810+
811+
/// Retrieve the pattern binding declaration that owns this pattern, if
812+
/// there is one.
813+
PatternBindingDecl *getPatternBindingDecl() const;
814+
815+
/// Retrieve the index into the pattern binding declaration for the top-level
816+
/// pattern.
817+
unsigned getPatternBindingIndex() const {
818+
assert(getPatternBindingDecl() != nullptr);
819+
return index;
820+
}
821+
822+
/// Whether this pattern allows type inference, e.g., from an initializer
823+
/// expression.
824+
bool allowsInference() const;
825+
826+
friend llvm::hash_code hash_value(const ContextualPattern &pattern) {
827+
return llvm::hash_combine(pattern.getPattern(),
828+
pattern.isTopLevel(),
829+
pattern.declOrContext);
830+
}
831+
832+
friend bool operator==(const ContextualPattern &lhs,
833+
const ContextualPattern &rhs) {
834+
return lhs.patternAndTopLevel == rhs.patternAndTopLevel &&
835+
lhs.declOrContext == rhs.declOrContext;
836+
}
837+
838+
friend bool operator!=(const ContextualPattern &lhs,
839+
const ContextualPattern &rhs) {
840+
return !(lhs == rhs);
841+
}
842+
};
843+
844+
void simple_display(llvm::raw_ostream &out, const ContextualPattern &pattern);
845+
760846
} // end namespace swift
761847

762848
#endif

lib/AST/Pattern.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,33 @@ const UnifiedStatsReporter::TraceFormatter*
496496
FrontendStatsTracer::getTraceFormatter<const Pattern *>() {
497497
return &TF;
498498
}
499+
500+
501+
ContextualPattern ContextualPattern::forPatternBindingDecl(
502+
PatternBindingDecl *pbd, unsigned index) {
503+
return ContextualPattern(
504+
pbd->getPattern(index), /*isTopLevel=*/true, pbd, index);
505+
}
506+
507+
DeclContext *ContextualPattern::getDeclContext() const {
508+
if (auto pbd = getPatternBindingDecl())
509+
return pbd->getDeclContext();
510+
511+
return declOrContext.get<DeclContext *>();
512+
}
513+
514+
PatternBindingDecl *ContextualPattern::getPatternBindingDecl() const {
515+
return declOrContext.dyn_cast<PatternBindingDecl *>();
516+
}
517+
518+
bool ContextualPattern::allowsInference() const {
519+
if (auto pbd = getPatternBindingDecl())
520+
return pbd->isInitialized(index);
521+
522+
return true;
523+
}
524+
525+
void swift::simple_display(llvm::raw_ostream &out,
526+
const ContextualPattern &pattern) {
527+
out << "(pattern @ " << pattern.getPattern() << ")";
528+
}

lib/Sema/CSGen.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,9 +2159,9 @@ namespace {
21592159

21602160
case PatternKind::Typed: {
21612161
// FIXME: Need a better locator for a pattern as a base.
2162-
TypeResolutionOptions options(TypeResolverContext::InExpression);
2163-
options |= TypeResolutionFlags::AllowUnboundGenerics;
2164-
Type type = TypeChecker::typeCheckPattern(pattern, CurDC, options);
2162+
auto contextualPattern =
2163+
ContextualPattern::forRawPattern(pattern, CurDC);
2164+
Type type = TypeChecker::typeCheckPattern(contextualPattern);
21652165
Type openedType = CS.openUnboundGenericType(type, locator);
21662166

21672167
// For a typed pattern, simply return the opened type of the pattern.

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,10 +2864,8 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD,
28642864
if (pattern->hasType())
28652865
patternType = pattern->getType();
28662866
else {
2867-
TypeResolutionOptions options(TypeResolverContext::InExpression);
2868-
options |= TypeResolutionFlags::AllowUnspecifiedTypes;
2869-
options |= TypeResolutionFlags::AllowUnboundGenerics;
2870-
patternType = typeCheckPattern(pattern, DC, options);
2867+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
2868+
patternType = typeCheckPattern(contextualPattern);
28712869
}
28722870

28732871
if (patternType->hasError()) {
@@ -2995,11 +2993,9 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
29952993
return true;
29962994
}
29972995

2998-
TypeResolutionOptions options(TypeResolverContext::InExpression);
2999-
options |= TypeResolutionFlags::AllowUnspecifiedTypes;
3000-
options |= TypeResolutionFlags::AllowUnboundGenerics;
3001-
Type patternType = TypeChecker::typeCheckPattern(
3002-
Stmt->getPattern(), DC, options);
2996+
auto contextualPattern =
2997+
ContextualPattern::forRawPattern(Stmt->getPattern(), DC);
2998+
Type patternType = TypeChecker::typeCheckPattern(contextualPattern);
30032999
if (patternType->hasError()) {
30043000
// FIXME: Handle errors better.
30053001
Stmt->getPattern()->setType(ErrorType::get(ctx));
@@ -3221,10 +3217,8 @@ bool TypeChecker::typeCheckStmtCondition(StmtCondition &cond, DeclContext *dc,
32213217

32223218
// Check the pattern, it allows unspecified types because the pattern can
32233219
// provide type information.
3224-
TypeResolutionOptions options(TypeResolverContext::InExpression);
3225-
options |= TypeResolutionFlags::AllowUnspecifiedTypes;
3226-
options |= TypeResolutionFlags::AllowUnboundGenerics;
3227-
Type patternType = TypeChecker::typeCheckPattern(pattern, dc, options);
3220+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, dc);
3221+
Type patternType = TypeChecker::typeCheckPattern(contextualPattern);
32283222
if (patternType->hasError()) {
32293223
typeCheckPatternFailed();
32303224
continue;

lib/Sema/TypeCheckPattern.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,20 @@ static Type validateTypedPattern(TypeResolution resolution,
703703
return TL.getType();
704704
}
705705

706-
Type TypeChecker::typeCheckPattern(Pattern *P, DeclContext *dc,
707-
TypeResolutionOptions options) {
706+
Type TypeChecker::typeCheckPattern(ContextualPattern pattern) {
707+
Pattern *P = pattern.getPattern();
708+
DeclContext *dc = pattern.getDeclContext();
709+
710+
TypeResolutionOptions options(pattern.getPatternBindingDecl()
711+
? TypeResolverContext::PatternBindingDecl
712+
: TypeResolverContext::InExpression);
713+
if (pattern.allowsInference()) {
714+
options |= TypeResolutionFlags::AllowUnspecifiedTypes;
715+
options |= TypeResolutionFlags::AllowUnboundGenerics;
716+
}
717+
if (!pattern.isTopLevel())
718+
options = options.withoutContext();
719+
708720
auto &Context = dc->getASTContext();
709721
switch (P->getKind()) {
710722
// Type-check paren patterns by checking the sub-pattern and
@@ -716,7 +728,8 @@ Type TypeChecker::typeCheckPattern(Pattern *P, DeclContext *dc,
716728
SP = PP->getSubPattern();
717729
else
718730
SP = cast<VarPattern>(P)->getSubPattern();
719-
Type subType = TypeChecker::typeCheckPattern(SP, dc, options);
731+
Type subType = TypeChecker::typeCheckPattern(
732+
pattern.forSubPattern(SP, /*retainTopLevel=*/true));
720733
if (subType->hasError())
721734
return ErrorType::get(Context);
722735

@@ -757,11 +770,10 @@ Type TypeChecker::typeCheckPattern(Pattern *P, DeclContext *dc,
757770
bool hadError = false;
758771
SmallVector<TupleTypeElt, 8> typeElts;
759772

760-
const auto elementOptions = options.withoutContext();
761773
for (unsigned i = 0, e = tuplePat->getNumElements(); i != e; ++i) {
762774
TuplePatternElt &elt = tuplePat->getElement(i);
763-
Pattern *pattern = elt.getPattern();
764-
Type subType = TypeChecker::typeCheckPattern(pattern, dc, elementOptions);
775+
Type subType = TypeChecker::typeCheckPattern(
776+
pattern.forSubPattern(elt.getPattern(), /*retainTopLevel=*/false));
765777
if (subType->hasError())
766778
hadError = true;
767779

lib/Sema/TypeCheckStorage.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ PatternBindingEntryRequest::evaluate(Evaluator &eval,
206206
// In particular, it's /not/ correct to check the PBD's DeclContext because
207207
// top-level variables in a script file are accessible from other files,
208208
// even though the PBD is inside a TopLevelCodeDecl.
209+
auto contextualPattern =
210+
ContextualPattern::forPatternBindingDecl(binding, entryNumber);
209211
TypeResolutionOptions options(TypeResolverContext::PatternBindingDecl);
210212

211213
if (binding->isInitialized(entryNumber)) {
@@ -214,8 +216,7 @@ PatternBindingEntryRequest::evaluate(Evaluator &eval,
214216
options |= TypeResolutionFlags::AllowUnboundGenerics;
215217
}
216218

217-
Type patternType = TypeChecker::typeCheckPattern(
218-
pattern, binding->getDeclContext(), options);
219+
Type patternType = TypeChecker::typeCheckPattern(contextualPattern);
219220
if (patternType->hasError()) {
220221
swift::setBoundVarsTypeError(pattern, Context);
221222
binding->setInvalid();

lib/Sema/TypeChecker.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -999,16 +999,11 @@ class TypeChecker final {
999999

10001000
/// Type check the given pattern.
10011001
///
1002-
/// \param P The pattern to type check.
1003-
/// \param dc The context in which type checking occurs.
1004-
/// \param options Options that control type resolution.
1005-
///
10061002
/// \returns the type of the pattern, which may be an error type if an
10071003
/// unrecoverable error occurred. If the options permit it, the type may
10081004
/// involve \c UnresolvedType (for patterns with no type information) and
10091005
/// unbound generic types.
1010-
static Type typeCheckPattern(Pattern *P, DeclContext *dc,
1011-
TypeResolutionOptions options);
1006+
static Type typeCheckPattern(ContextualPattern pattern);
10121007

10131008
static bool typeCheckCatchPattern(CatchStmt *S, DeclContext *dc);
10141009

0 commit comments

Comments
 (0)