@@ -3157,20 +3157,43 @@ DynamicallyReplacedDeclRequest::evaluate(Evaluator &evaluator,
3157
3157
return nullptr ;
3158
3158
}
3159
3159
3160
- // / Returns true if the given type conforms to `Differentiable` in the given
3161
- // / module.
3162
- static bool conformsToDifferentiable (Type type, DeclContext *DC) {
3160
+ // / If the given type conforms to `Differentiable` in the given context, returns
3161
+ // / the `ProtocolConformanceRef`. Otherwise, returns an invalid
3162
+ // / `ProtocolConformanceRef`.
3163
+ // /
3164
+ // / This helper verifies that the `TangentVector` type witness is valid, in case
3165
+ // / the conformance has not been fully checked and the type witness cannot be
3166
+ // / resolved.
3167
+ static ProtocolConformanceRef getDifferentiableConformance (Type type,
3168
+ DeclContext *DC) {
3163
3169
auto &ctx = type->getASTContext ();
3164
3170
auto *differentiableProto =
3165
3171
ctx.getProtocol (KnownProtocolKind::Differentiable);
3166
- auto conf = TypeChecker::conformsToProtocol (
3167
- type, differentiableProto, DC, ConformanceCheckFlags::InExpression );
3172
+ auto conf =
3173
+ TypeChecker::conformsToProtocol ( type, differentiableProto, DC, None );
3168
3174
if (!conf)
3169
- return false ;
3175
+ return ProtocolConformanceRef () ;
3170
3176
// Try to get the `TangentVector` type witness, in case the conformance has
3171
- // not been fully checked and the type witness cannot be resolved.
3177
+ // not been fully checked.
3178
+ Type tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
3179
+ if (tanType.isNull () || tanType->hasError ())
3180
+ return ProtocolConformanceRef ();
3181
+ return conf;
3182
+ };
3183
+
3184
+ // / Returns true if the given type conforms to `Differentiable` in the given
3185
+ // / contxt. If `tangentVectorEqualsSelf` is true, also check whether the given
3186
+ // / type satisfies `TangentVector == Self`.
3187
+ static bool conformsToDifferentiable (Type type, DeclContext *DC,
3188
+ bool tangentVectorEqualsSelf = false ) {
3189
+ auto conf = getDifferentiableConformance (type, DC);
3190
+ if (conf.isInvalid ())
3191
+ return false ;
3192
+ if (!tangentVectorEqualsSelf)
3193
+ return true ;
3194
+ auto &ctx = type->getASTContext ();
3172
3195
Type tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
3173
- return !tanType. isNull () && !tanType-> hasError ( );
3196
+ return type-> isEqual (tanType );
3174
3197
};
3175
3198
3176
3199
IndexSubset *TypeChecker::inferDifferentiabilityParameters (
@@ -4364,9 +4387,8 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4364
4387
auto originalResult = originalResults.front ();
4365
4388
auto originalResultType = originalResult.type ;
4366
4389
// Check that the original semantic result conforms to `Differentiable`.
4367
- auto *diffableProto = Ctx.getProtocol (KnownProtocolKind::Differentiable);
4368
- auto valueResultConf = TypeChecker::conformsToProtocol (
4369
- originalResultType, diffableProto, derivative->getDeclContext (), None);
4390
+ auto valueResultConf = getDifferentiableConformance (
4391
+ originalResultType, derivative->getDeclContext ());
4370
4392
if (!valueResultConf) {
4371
4393
diags.diagnose (attr->getLocation (),
4372
4394
diag::derivative_attr_result_value_not_differentiable,
@@ -4467,21 +4489,6 @@ DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator,
4467
4489
return nullptr ;
4468
4490
}
4469
4491
4470
- // / Returns true if the given type's `TangentVector` is equal to itself in the
4471
- // / given module.
4472
- static bool tangentVectorEqualsSelf (Type type, DeclContext *DC) {
4473
- assert (conformsToDifferentiable (type, DC));
4474
- auto &ctx = type->getASTContext ();
4475
- auto *differentiableProto =
4476
- ctx.getProtocol (KnownProtocolKind::Differentiable);
4477
- auto conf = TypeChecker::conformsToProtocol (
4478
- type, differentiableProto, DC,
4479
- ConformanceCheckFlags::InExpression);
4480
- auto tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
4481
- return type->getCanonicalType () == tanType->getCanonicalType ();
4482
- };
4483
-
4484
-
4485
4492
// Computes the linearity parameter indices from the given parsed linearity
4486
4493
// parameters for the given transpose function. On error, emits diagnostics and
4487
4494
// returns `nullptr`.
@@ -4600,8 +4607,8 @@ static bool checkLinearityParameters(
4600
4607
parsedLinearParams.empty () ? attrLoc : parsedLinearParams[i].getLoc ();
4601
4608
// Parameter must conform to `Differentiable` and satisfy
4602
4609
// `Self == Self.TangentVector`.
4603
- if (!conformsToDifferentiable (linearParamType, originalAFD) ||
4604
- ! tangentVectorEqualsSelf (linearParamType, originalAFD )) {
4610
+ if (!conformsToDifferentiable (linearParamType, originalAFD,
4611
+ /* tangentVectorEqualsSelf*/ true )) {
4605
4612
diags.diagnose (loc,
4606
4613
diag::transpose_attr_invalid_linearity_parameter_or_result,
4607
4614
linearParamType.getString (), /* isParameter*/ true );
@@ -4713,8 +4720,8 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
4713
4720
if (expectedOriginalResultType->hasTypeParameter ())
4714
4721
expectedOriginalResultType = transpose->mapTypeIntoContext (
4715
4722
expectedOriginalResultType);
4716
- if (!conformsToDifferentiable (expectedOriginalResultType, transpose) ||
4717
- ! tangentVectorEqualsSelf (expectedOriginalResultType, transpose )) {
4723
+ if (!conformsToDifferentiable (expectedOriginalResultType, transpose,
4724
+ /* tangentVectorEqualsSelf*/ true )) {
4718
4725
diagnoseAndRemoveAttr (
4719
4726
attr, diag::transpose_attr_invalid_linearity_parameter_or_result,
4720
4727
expectedOriginalResultType.getString (), /* isParameter*/ false );
0 commit comments