Skip to content

Commit 12e873d

Browse files
authored
Merge pull request #21332 from DougGregor/ext-typealias-of-specialized-5.0
[5.0] [Type checker] Allow extensions of typealiases naming generic specializations
2 parents f9f9781 + 3ff5339 commit 12e873d

File tree

3 files changed

+134
-30
lines changed

3 files changed

+134
-30
lines changed

lib/Sema/TypeCheckDecl.cpp

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4844,26 +4844,21 @@ static bool isPassThroughTypealias(TypeAliasDecl *typealias) {
48444844

48454845
/// Form the interface type of an extension from the raw type and the
48464846
/// extension's list of generic parameters.
4847-
static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
4848-
Type type,
4849-
GenericParamList *genericParams,
4850-
bool &mustInferRequirements) {
4847+
static Type formExtensionInterfaceType(
4848+
TypeChecker &tc, ExtensionDecl *ext,
4849+
Type type,
4850+
GenericParamList *genericParams,
4851+
SmallVectorImpl<std::pair<Type, Type>> &sameTypeReqs,
4852+
bool &mustInferRequirements) {
48514853
if (type->is<ErrorType>())
48524854
return type;
48534855

48544856
// Find the nominal type declaration and its parent type.
4855-
Type parentType;
4856-
GenericTypeDecl *genericDecl;
4857-
if (auto unbound = type->getAs<UnboundGenericType>()) {
4858-
parentType = unbound->getParent();
4859-
genericDecl = unbound->getDecl();
4860-
} else {
4861-
if (type->is<ProtocolCompositionType>())
4862-
type = type->getCanonicalType();
4863-
auto nominalType = type->castTo<NominalType>();
4864-
parentType = nominalType->getParent();
4865-
genericDecl = nominalType->getDecl();
4866-
}
4857+
if (type->is<ProtocolCompositionType>())
4858+
type = type->getCanonicalType();
4859+
4860+
Type parentType = type->getNominalParent();
4861+
GenericTypeDecl *genericDecl = type->getAnyGeneric();
48674862

48684863
// Reconstruct the parent, if there is one.
48694864
if (parentType) {
@@ -4873,7 +4868,7 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
48734868
: genericParams;
48744869
parentType =
48754870
formExtensionInterfaceType(tc, ext, parentType, parentGenericParams,
4876-
mustInferRequirements);
4871+
sameTypeReqs, mustInferRequirements);
48774872
}
48784873

48794874
// Find the nominal type.
@@ -4891,9 +4886,20 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
48914886
resultType = NominalType::get(nominal, parentType,
48924887
nominal->getASTContext());
48934888
} else {
4889+
auto currentBoundType = type->getAs<BoundGenericType>();
4890+
48944891
// Form the bound generic type with the type parameters provided.
4892+
unsigned gpIndex = 0;
48954893
for (auto gp : *genericParams) {
4896-
genericArgs.push_back(gp->getDeclaredInterfaceType());
4894+
SWIFT_DEFER { ++gpIndex; };
4895+
4896+
auto gpType = gp->getDeclaredInterfaceType();
4897+
genericArgs.push_back(gpType);
4898+
4899+
if (currentBoundType) {
4900+
sameTypeReqs.push_back({gpType,
4901+
currentBoundType->getGenericArgs()[gpIndex]});
4902+
}
48974903
}
48984904

48994905
resultType = BoundGenericType::get(nominal, parentType, genericArgs);
@@ -4930,8 +4936,9 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
49304936

49314937
// Form the interface type of the extension.
49324938
bool mustInferRequirements = false;
4939+
SmallVector<std::pair<Type, Type>, 4> sameTypeReqs;
49334940
Type extInterfaceType =
4934-
formExtensionInterfaceType(tc, ext, type, genericParams,
4941+
formExtensionInterfaceType(tc, ext, type, genericParams, sameTypeReqs,
49354942
mustInferRequirements);
49364943

49374944
// Local function used to infer requirements from the extended type.
@@ -4943,18 +4950,34 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
49434950
extInterfaceType,
49444951
nullptr,
49454952
source);
4953+
4954+
for (const auto &sameTypeReq : sameTypeReqs) {
4955+
builder.addRequirement(
4956+
Requirement(RequirementKind::SameType, sameTypeReq.first,
4957+
sameTypeReq.second),
4958+
source, ext->getModuleContext());
4959+
}
49464960
};
49474961

49484962
// Validate the generic type signature.
49494963
auto *env = tc.checkGenericEnvironment(genericParams,
49504964
ext->getDeclContext(), nullptr,
49514965
/*allowConcreteGenericParams=*/true,
49524966
ext, inferExtendedTypeReqs,
4953-
mustInferRequirements);
4967+
(mustInferRequirements ||
4968+
!sameTypeReqs.empty()));
49544969

49554970
return { env, extInterfaceType };
49564971
}
49574972

4973+
static bool isNonGenericTypeAliasType(Type type) {
4974+
// A non-generic typealias can extend a specialized type.
4975+
if (auto *aliasType = dyn_cast<NameAliasType>(type.getPointer()))
4976+
return aliasType->getDecl()->getGenericParamsOfContext() == nullptr;
4977+
4978+
return false;
4979+
}
4980+
49584981
static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
49594982
// If we didn't parse a type, fill in an error type and bail out.
49604983
if (!ext->getExtendedTypeLoc().getTypeRepr()) {
@@ -4998,20 +5021,22 @@ static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
49985021
return;
49995022
}
50005023

5001-
// Cannot extend a bound generic type.
5002-
if (extendedType->isSpecialized()) {
5003-
tc.diagnose(ext->getLoc(), diag::extension_specialization,
5004-
extendedType->getAnyNominal()->getName())
5024+
// Cannot extend function types, tuple types, etc.
5025+
if (!extendedType->getAnyNominal()) {
5026+
tc.diagnose(ext->getLoc(), diag::non_nominal_extension, extendedType)
50055027
.highlight(ext->getExtendedTypeLoc().getSourceRange());
50065028
ext->setInvalid();
50075029
ext->getExtendedTypeLoc().setInvalidType(tc.Context);
50085030
return;
50095031
}
50105032

5011-
// Cannot extend function types, tuple types, etc.
5012-
if (!extendedType->getAnyNominal()) {
5013-
tc.diagnose(ext->getLoc(), diag::non_nominal_extension, extendedType)
5014-
.highlight(ext->getExtendedTypeLoc().getSourceRange());
5033+
// Cannot extend a bound generic type, unless it's referenced via a
5034+
// non-generic typealias type.
5035+
if (extendedType->isSpecialized() &&
5036+
!isNonGenericTypeAliasType(extendedType)) {
5037+
tc.diagnose(ext->getLoc(), diag::extension_specialization,
5038+
extendedType->getAnyNominal()->getName())
5039+
.highlight(ext->getExtendedTypeLoc().getSourceRange());
50155040
ext->setInvalid();
50165041
ext->getExtendedTypeLoc().setInvalidType(tc.Context);
50175042
return;

test/decl/ext/generic.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ extension X<Int, Double, String> {
3030

3131
typealias GGG = X<Int, Double, String>
3232

33-
extension GGG { } // expected-error{{constrained extension must be declared on the unspecialized generic type 'X' with constraints specified by a 'where' clause}}
33+
extension GGG { } // okay through a typealias
3434

3535
// Lvalue check when the archetypes are not the same.
3636
struct LValueCheck<T> {
@@ -209,4 +209,3 @@ extension A.B {
209209
extension A.B.D {
210210
func g() { }
211211
}
212-

test/decl/ext/typealias.swift

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
struct Foo<T> {
4+
var maybeT: T? { return nil }
5+
}
6+
7+
extension Foo {
8+
struct Bar<U, V> {
9+
var maybeT: T? { return nil }
10+
var maybeU: U? { return nil }
11+
var maybeV: V? { return nil }
12+
13+
struct Inner {
14+
var maybeT: T? { return nil }
15+
var maybeU: U? { return nil }
16+
var maybeV: V? { return nil }
17+
}
18+
}
19+
}
20+
21+
typealias FooInt = Foo<Int>
22+
23+
extension FooInt {
24+
func goodT() -> Int {
25+
return maybeT!
26+
}
27+
28+
func badT() -> Float {
29+
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
30+
}
31+
}
32+
33+
typealias FooIntBarFloatDouble = Foo<Int>.Bar<Float, Double>
34+
35+
extension FooIntBarFloatDouble {
36+
func goodT() -> Int {
37+
return maybeT!
38+
}
39+
func goodU() -> Float {
40+
return maybeU!
41+
}
42+
func goodV() -> Double {
43+
return maybeV!
44+
}
45+
46+
func badT() -> Float {
47+
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
48+
}
49+
func badU() -> Int {
50+
return maybeU! // expected-error{{cannot convert return expression of type 'Float' to return type 'Int'}}
51+
}
52+
func badV() -> Int {
53+
return maybeV! // expected-error{{cannot convert return expression of type 'Double' to return type 'Int'}}
54+
}
55+
}
56+
57+
typealias FooIntBarFloatDoubleInner = Foo<Int>.Bar<Float, Double>.Inner
58+
59+
extension FooIntBarFloatDoubleInner {
60+
func goodT() -> Int {
61+
return maybeT!
62+
}
63+
func goodU() -> Float {
64+
return maybeU!
65+
}
66+
func goodV() -> Double {
67+
return maybeV!
68+
}
69+
70+
func badT() -> Float {
71+
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
72+
}
73+
func badU() -> Int {
74+
return maybeU! // expected-error{{cannot convert return expression of type 'Float' to return type 'Int'}}
75+
}
76+
func badV() -> Int {
77+
return maybeV! // expected-error{{cannot convert return expression of type 'Double' to return type 'Int'}}
78+
}
79+
}
80+

0 commit comments

Comments
 (0)