@@ -4844,26 +4844,21 @@ static bool isPassThroughTypealias(TypeAliasDecl *typealias) {
4844
4844
4845
4845
// / Form the interface type of an extension from the raw type and the
4846
4846
// / extension's list of generic parameters.
4847
- static Type formExtensionInterfaceType (TypeChecker &tc, ExtensionDecl *ext,
4848
- Type type,
4849
- GenericParamList *genericParams,
4850
- bool &mustInferRequirements) {
4847
+ static Type formExtensionInterfaceType (
4848
+ TypeChecker &tc, ExtensionDecl *ext,
4849
+ Type type,
4850
+ GenericParamList *genericParams,
4851
+ SmallVectorImpl<std::pair<Type, Type>> &sameTypeReqs,
4852
+ bool &mustInferRequirements) {
4851
4853
if (type->is <ErrorType>())
4852
4854
return type;
4853
4855
4854
4856
// Find the nominal type declaration and its parent type.
4855
- Type parentType;
4856
- GenericTypeDecl *genericDecl;
4857
- if (auto unbound = type->getAs <UnboundGenericType>()) {
4858
- parentType = unbound->getParent ();
4859
- genericDecl = unbound->getDecl ();
4860
- } else {
4861
- if (type->is <ProtocolCompositionType>())
4862
- type = type->getCanonicalType ();
4863
- auto nominalType = type->castTo <NominalType>();
4864
- parentType = nominalType->getParent ();
4865
- genericDecl = nominalType->getDecl ();
4866
- }
4857
+ if (type->is <ProtocolCompositionType>())
4858
+ type = type->getCanonicalType ();
4859
+
4860
+ Type parentType = type->getNominalParent ();
4861
+ GenericTypeDecl *genericDecl = type->getAnyGeneric ();
4867
4862
4868
4863
// Reconstruct the parent, if there is one.
4869
4864
if (parentType) {
@@ -4873,7 +4868,7 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
4873
4868
: genericParams;
4874
4869
parentType =
4875
4870
formExtensionInterfaceType (tc, ext, parentType, parentGenericParams,
4876
- mustInferRequirements);
4871
+ sameTypeReqs, mustInferRequirements);
4877
4872
}
4878
4873
4879
4874
// Find the nominal type.
@@ -4891,9 +4886,20 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
4891
4886
resultType = NominalType::get (nominal, parentType,
4892
4887
nominal->getASTContext ());
4893
4888
} else {
4889
+ auto currentBoundType = type->getAs <BoundGenericType>();
4890
+
4894
4891
// Form the bound generic type with the type parameters provided.
4892
+ unsigned gpIndex = 0 ;
4895
4893
for (auto gp : *genericParams) {
4896
- genericArgs.push_back (gp->getDeclaredInterfaceType ());
4894
+ SWIFT_DEFER { ++gpIndex; };
4895
+
4896
+ auto gpType = gp->getDeclaredInterfaceType ();
4897
+ genericArgs.push_back (gpType);
4898
+
4899
+ if (currentBoundType) {
4900
+ sameTypeReqs.push_back ({gpType,
4901
+ currentBoundType->getGenericArgs ()[gpIndex]});
4902
+ }
4897
4903
}
4898
4904
4899
4905
resultType = BoundGenericType::get (nominal, parentType, genericArgs);
@@ -4930,8 +4936,9 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
4930
4936
4931
4937
// Form the interface type of the extension.
4932
4938
bool mustInferRequirements = false ;
4939
+ SmallVector<std::pair<Type, Type>, 4 > sameTypeReqs;
4933
4940
Type extInterfaceType =
4934
- formExtensionInterfaceType (tc, ext, type, genericParams,
4941
+ formExtensionInterfaceType (tc, ext, type, genericParams, sameTypeReqs,
4935
4942
mustInferRequirements);
4936
4943
4937
4944
// Local function used to infer requirements from the extended type.
@@ -4943,18 +4950,34 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
4943
4950
extInterfaceType,
4944
4951
nullptr ,
4945
4952
source);
4953
+
4954
+ for (const auto &sameTypeReq : sameTypeReqs) {
4955
+ builder.addRequirement (
4956
+ Requirement (RequirementKind::SameType, sameTypeReq.first ,
4957
+ sameTypeReq.second ),
4958
+ source, ext->getModuleContext ());
4959
+ }
4946
4960
};
4947
4961
4948
4962
// Validate the generic type signature.
4949
4963
auto *env = tc.checkGenericEnvironment (genericParams,
4950
4964
ext->getDeclContext (), nullptr ,
4951
4965
/* allowConcreteGenericParams=*/ true ,
4952
4966
ext, inferExtendedTypeReqs,
4953
- mustInferRequirements);
4967
+ (mustInferRequirements ||
4968
+ !sameTypeReqs.empty ()));
4954
4969
4955
4970
return { env, extInterfaceType };
4956
4971
}
4957
4972
4973
+ static bool isNonGenericTypeAliasType (Type type) {
4974
+ // A non-generic typealias can extend a specialized type.
4975
+ if (auto *aliasType = dyn_cast<NameAliasType>(type.getPointer ()))
4976
+ return aliasType->getDecl ()->getGenericParamsOfContext () == nullptr ;
4977
+
4978
+ return false ;
4979
+ }
4980
+
4958
4981
static void validateExtendedType (ExtensionDecl *ext, TypeChecker &tc) {
4959
4982
// If we didn't parse a type, fill in an error type and bail out.
4960
4983
if (!ext->getExtendedTypeLoc ().getTypeRepr ()) {
@@ -4998,20 +5021,22 @@ static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
4998
5021
return ;
4999
5022
}
5000
5023
5001
- // Cannot extend a bound generic type.
5002
- if (extendedType->isSpecialized ()) {
5003
- tc.diagnose (ext->getLoc (), diag::extension_specialization,
5004
- extendedType->getAnyNominal ()->getName ())
5024
+ // Cannot extend function types, tuple types, etc.
5025
+ if (!extendedType->getAnyNominal ()) {
5026
+ tc.diagnose (ext->getLoc (), diag::non_nominal_extension, extendedType)
5005
5027
.highlight (ext->getExtendedTypeLoc ().getSourceRange ());
5006
5028
ext->setInvalid ();
5007
5029
ext->getExtendedTypeLoc ().setInvalidType (tc.Context );
5008
5030
return ;
5009
5031
}
5010
5032
5011
- // Cannot extend function types, tuple types, etc.
5012
- if (!extendedType->getAnyNominal ()) {
5013
- tc.diagnose (ext->getLoc (), diag::non_nominal_extension, extendedType)
5014
- .highlight (ext->getExtendedTypeLoc ().getSourceRange ());
5033
+ // Cannot extend a bound generic type, unless it's referenced via a
5034
+ // non-generic typealias type.
5035
+ if (extendedType->isSpecialized () &&
5036
+ !isNonGenericTypeAliasType (extendedType)) {
5037
+ tc.diagnose (ext->getLoc (), diag::extension_specialization,
5038
+ extendedType->getAnyNominal ()->getName ())
5039
+ .highlight (ext->getExtendedTypeLoc ().getSourceRange ());
5015
5040
ext->setInvalid ();
5016
5041
ext->getExtendedTypeLoc ().setInvalidType (tc.Context );
5017
5042
return ;
0 commit comments