Skip to content

Commit 28c8fb8

Browse files
authored
Merge pull request swiftlang#12855 from slavapestov/fix-extension-binding
More efficient extension binding
2 parents 2ec6eb5 + 3f51dbc commit 28c8fb8

File tree

7 files changed

+52
-28
lines changed

7 files changed

+52
-28
lines changed

lib/Sema/TypeCheckDecl.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8084,6 +8084,8 @@ static Type formExtensionInterfaceType(Type type,
80848084
parentType = unbound->getParent();
80858085
nominal = cast<NominalTypeDecl>(unbound->getDecl());
80868086
} else {
8087+
if (type->is<ProtocolCompositionType>())
8088+
type = type->getCanonicalType();
80878089
auto nominalType = type->castTo<NominalType>();
80888090
parentType = nominalType->getParent();
80898091
nominal = nominalType->getDecl();
@@ -8100,11 +8102,8 @@ static Type formExtensionInterfaceType(Type type,
81008102

81018103
// If we don't have generic parameters at this level, just build the result.
81028104
if (!nominal->getGenericParams() || isa<ProtocolDecl>(nominal)) {
8103-
Type resultType = NominalType::get(nominal, parentType,
8104-
nominal->getASTContext());
8105-
8106-
// If the parent was unchanged, return the original pointer.
8107-
return resultType->isEqual(type) ? type : resultType;
8105+
return NominalType::get(nominal, parentType,
8106+
nominal->getASTContext());
81088107
}
81098108

81108109
// Form the bound generic type with the type parameters provided.
@@ -8113,8 +8112,7 @@ static Type formExtensionInterfaceType(Type type,
81138112
genericArgs.push_back(gp->getDeclaredInterfaceType());
81148113
}
81158114

8116-
Type resultType = BoundGenericType::get(nominal, parentType, genericArgs);
8117-
return resultType->isEqual(type) ? type : resultType;
8115+
return BoundGenericType::get(nominal, parentType, genericArgs);
81188116
}
81198117

81208118
/// Visit the given generic parameter lists from the outermost to the innermost,
@@ -8207,7 +8205,7 @@ void TypeChecker::validateExtension(ExtensionDecl *ext) {
82078205
// Check generic parameters.
82088206
GenericEnvironment *env;
82098207
std::tie(env, extendedType) = checkExtensionGenericParams(
8210-
*this, ext, ext->getExtendedType()->getCanonicalType(),
8208+
*this, ext, ext->getExtendedType(),
82118209
genericParams);
82128210

82138211
ext->getExtendedTypeLoc().setType(extendedType);

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2198,6 +2198,8 @@ namespace {
21982198
conformance->setState(ProtocolConformanceState::Checking);
21992199
SWIFT_DEFER { conformance->setState(ProtocolConformanceState::Complete); };
22002200

2201+
TC.validateDecl(Proto);
2202+
22012203
// If the protocol itself is invalid, there's nothing we can do.
22022204
if (Proto->isInvalid()) {
22032205
conformance->setInvalid();

lib/Sema/TypeCheckType.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,14 @@ Type TypeChecker::applyGenericArguments(Type type, TypeDecl *decl,
458458
}
459459
}
460460

461+
// Cannot extend a bound generic type.
462+
if (options.contains(TR_ExtensionBinding)) {
463+
diagnose(loc, diag::extension_specialization,
464+
genericDecl->getName())
465+
.highlight(generic->getSourceRange());
466+
return ErrorType::get(Context);
467+
}
468+
461469
// FIXME: More principled handling of circularity.
462470
if (!genericDecl->hasValidSignature()) {
463471
diagnose(loc, diag::recursive_type_reference,
@@ -676,22 +684,26 @@ static Type resolveTypeDecl(TypeChecker &TC, TypeDecl *typeDecl, SourceLoc loc,
676684
UnsatisfiedDependency *unsatisfiedDependency) {
677685
assert(fromDC && "No declaration context for type resolution?");
678686

679-
// If we have a callback to report dependencies, do so.
680-
if (unsatisfiedDependency) {
681-
if ((*unsatisfiedDependency)(requestResolveTypeDecl(typeDecl)))
682-
return nullptr;
683-
} else {
684-
// Validate the declaration.
685-
TC.validateDeclForNameLookup(typeDecl);
686-
}
687+
// Don't validate nominal type declarations during extension binding.
688+
if (!options.contains(TR_ExtensionBinding) ||
689+
!isa<NominalTypeDecl>(typeDecl)) {
690+
// If we have a callback to report dependencies, do so.
691+
if (unsatisfiedDependency) {
692+
if ((*unsatisfiedDependency)(requestResolveTypeDecl(typeDecl)))
693+
return nullptr;
694+
} else {
695+
// Validate the declaration.
696+
TC.validateDeclForNameLookup(typeDecl);
697+
}
687698

688-
// If we didn't bail out with an unsatisfiedDependency,
689-
// and were not able to validate recursively, bail out.
690-
if (!typeDecl->hasInterfaceType()) {
691-
TC.diagnose(loc, diag::recursive_type_reference,
692-
typeDecl->getDescriptiveKind(), typeDecl->getName());
693-
TC.diagnose(typeDecl->getLoc(), diag::type_declared_here);
694-
return ErrorType::get(TC.Context);
699+
// If we didn't bail out with an unsatisfiedDependency,
700+
// and were not able to validate recursively, bail out.
701+
if (!typeDecl->hasInterfaceType()) {
702+
TC.diagnose(loc, diag::recursive_type_reference,
703+
typeDecl->getDescriptiveKind(), typeDecl->getName());
704+
TC.diagnose(typeDecl->getLoc(), diag::type_declared_here);
705+
return ErrorType::get(TC.Context);
706+
}
695707
}
696708

697709
// Resolve the type declaration to a specific type. How this occurs

lib/Sema/TypeChecker.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ static void bindExtensionDecl(ExtensionDecl *ED, TypeChecker &TC) {
361361

362362
// If the extended type is generic or is a protocol. Clone or create
363363
// the generic parameters.
364-
if (extendedNominal->isGenericContext()) {
364+
if (extendedNominal->getGenericParamsOfContext() ||
365+
isa<ProtocolDecl>(extendedNominal)) {
365366
if (auto proto = dyn_cast<ProtocolDecl>(extendedNominal)) {
366367
// For a protocol extension, build the generic parameter list.
367368
ED->setGenericParams(proto->createGenericParams(ED));
@@ -376,7 +377,8 @@ static void bindExtensionDecl(ExtensionDecl *ED, TypeChecker &TC) {
376377
// If we have a trailing where clause, deal with it now.
377378
// For now, trailing where clauses are only permitted on protocol extensions.
378379
if (auto trailingWhereClause = ED->getTrailingWhereClause()) {
379-
if (!extendedNominal->isGenericContext()) {
380+
if (!(extendedNominal->getGenericParamsOfContext() ||
381+
isa<ProtocolDecl>(extendedNominal))) {
380382
// Only generic and protocol types are permitted to have
381383
// trailing where clauses.
382384
TC.diagnose(ED, diag::extension_nongeneric_trailing_where, extendedType)

test/decl/ext/generic.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ extension Double : P2 {
2020

2121
extension X<Int, Double, String> { } // expected-error{{constrained extension must be declared on the unspecialized generic type 'X' with constraints specified by a 'where' clause}}
2222

23+
typealias GGG = X<Int, Double, String>
24+
25+
extension GGG { } // expected-error{{constrained extension must be declared on the unspecialized generic type 'X' with constraints specified by a 'where' clause}}
26+
2327
// Lvalue check when the archetypes are not the same.
2428
struct LValueCheck<T> {
2529
let x = 0

test/decl/typealias/generic.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,9 @@ func takesSugaredType2(m: GenericClass<Int>.TA<Float>) {
295295
extension A {}
296296

297297
extension A<T> {} // expected-error {{generic type 'A' specialized with too few type parameters (got 1, but expected 2)}}
298-
extension A<Float,Int> {} // expected-error {{constrained extension must be declared on the unspecialized generic type 'MyType' with constraints specified by a 'where' clause}}
299-
extension C<T> {} // expected-error {{use of undeclared type 'T'}}
300-
extension C<Int> {} // expected-error {{constrained extension must be declared on the unspecialized generic type 'MyType' with constraints specified by a 'where' clause}}
298+
extension A<Float,Int> {} // expected-error {{constrained extension must be declared on the unspecialized generic type 'A' with constraints specified by a 'where' clause}}
299+
extension C<T> {} // expected-error {{constrained extension must be declared on the unspecialized generic type 'C' with constraints specified by a 'where' clause}}
300+
extension C<Int> {} // expected-error {{constrained extension must be declared on the unspecialized generic type 'C' with constraints specified by a 'where' clause}}
301301

302302

303303
protocol ErrorQ {
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// RUN: %scale-test --sum-multi --typecheck --begin 5 --end 16 --step 5 --select validateDecl %s
2+
// REQUIRES: OS=macosx
3+
// REQUIRES: asserts
4+
5+
struct Struct${N} {}
6+
extension Struct${N} {}

0 commit comments

Comments
 (0)