Skip to content

Commit e9ed2d5

Browse files
authored
[AutoDiff] Fix @derivative attribute type-checking crash. (swiftlang#30936)
Fix `@derivative` attribute type-checking crash, so far reproducible only via `-parse-stdlib`. The crash occurs because it is not sufficient for type-checking to check for `Differentiable` conformances. We must also check for invalid `TangentVector` associated types. Resolves SR-12559.
1 parent 9a6ae6e commit e9ed2d5

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,20 +3157,43 @@ DynamicallyReplacedDeclRequest::evaluate(Evaluator &evaluator,
31573157
return nullptr;
31583158
}
31593159

3160-
/// Returns true if the given type conforms to `Differentiable` in the given
3161-
/// module.
3162-
static bool conformsToDifferentiable(Type type, DeclContext *DC) {
3160+
/// If the given type conforms to `Differentiable` in the given context, returns
3161+
/// the `ProtocolConformanceRef`. Otherwise, returns an invalid
3162+
/// `ProtocolConformanceRef`.
3163+
///
3164+
/// This helper verifies that the `TangentVector` type witness is valid, in case
3165+
/// the conformance has not been fully checked and the type witness cannot be
3166+
/// resolved.
3167+
static ProtocolConformanceRef getDifferentiableConformance(Type type,
3168+
DeclContext *DC) {
31633169
auto &ctx = type->getASTContext();
31643170
auto *differentiableProto =
31653171
ctx.getProtocol(KnownProtocolKind::Differentiable);
3166-
auto conf = TypeChecker::conformsToProtocol(
3167-
type, differentiableProto, DC, ConformanceCheckFlags::InExpression);
3172+
auto conf =
3173+
TypeChecker::conformsToProtocol(type, differentiableProto, DC, None);
31683174
if (!conf)
3169-
return false;
3175+
return ProtocolConformanceRef();
31703176
// Try to get the `TangentVector` type witness, in case the conformance has
3171-
// not been fully checked and the type witness cannot be resolved.
3177+
// not been fully checked.
3178+
Type tanType = conf.getTypeWitnessByName(type, ctx.Id_TangentVector);
3179+
if (tanType.isNull() || tanType->hasError())
3180+
return ProtocolConformanceRef();
3181+
return conf;
3182+
};
3183+
3184+
/// Returns true if the given type conforms to `Differentiable` in the given
3185+
/// contxt. If `tangentVectorEqualsSelf` is true, also check whether the given
3186+
/// type satisfies `TangentVector == Self`.
3187+
static bool conformsToDifferentiable(Type type, DeclContext *DC,
3188+
bool tangentVectorEqualsSelf = false) {
3189+
auto conf = getDifferentiableConformance(type, DC);
3190+
if (conf.isInvalid())
3191+
return false;
3192+
if (!tangentVectorEqualsSelf)
3193+
return true;
3194+
auto &ctx = type->getASTContext();
31723195
Type tanType = conf.getTypeWitnessByName(type, ctx.Id_TangentVector);
3173-
return !tanType.isNull() && !tanType->hasError();
3196+
return type->isEqual(tanType);
31743197
};
31753198

31763199
IndexSubset *TypeChecker::inferDifferentiabilityParameters(
@@ -4364,9 +4387,8 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
43644387
auto originalResult = originalResults.front();
43654388
auto originalResultType = originalResult.type;
43664389
// Check that the original semantic result conforms to `Differentiable`.
4367-
auto *diffableProto = Ctx.getProtocol(KnownProtocolKind::Differentiable);
4368-
auto valueResultConf = TypeChecker::conformsToProtocol(
4369-
originalResultType, diffableProto, derivative->getDeclContext(), None);
4390+
auto valueResultConf = getDifferentiableConformance(
4391+
originalResultType, derivative->getDeclContext());
43704392
if (!valueResultConf) {
43714393
diags.diagnose(attr->getLocation(),
43724394
diag::derivative_attr_result_value_not_differentiable,
@@ -4467,21 +4489,6 @@ DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator,
44674489
return nullptr;
44684490
}
44694491

4470-
/// Returns true if the given type's `TangentVector` is equal to itself in the
4471-
/// given module.
4472-
static bool tangentVectorEqualsSelf(Type type, DeclContext *DC) {
4473-
assert(conformsToDifferentiable(type, DC));
4474-
auto &ctx = type->getASTContext();
4475-
auto *differentiableProto =
4476-
ctx.getProtocol(KnownProtocolKind::Differentiable);
4477-
auto conf = TypeChecker::conformsToProtocol(
4478-
type, differentiableProto, DC,
4479-
ConformanceCheckFlags::InExpression);
4480-
auto tanType = conf.getTypeWitnessByName(type, ctx.Id_TangentVector);
4481-
return type->getCanonicalType() == tanType->getCanonicalType();
4482-
};
4483-
4484-
44854492
// Computes the linearity parameter indices from the given parsed linearity
44864493
// parameters for the given transpose function. On error, emits diagnostics and
44874494
// returns `nullptr`.
@@ -4600,8 +4607,8 @@ static bool checkLinearityParameters(
46004607
parsedLinearParams.empty() ? attrLoc : parsedLinearParams[i].getLoc();
46014608
// Parameter must conform to `Differentiable` and satisfy
46024609
// `Self == Self.TangentVector`.
4603-
if (!conformsToDifferentiable(linearParamType, originalAFD) ||
4604-
!tangentVectorEqualsSelf(linearParamType, originalAFD)) {
4610+
if (!conformsToDifferentiable(linearParamType, originalAFD,
4611+
/*tangentVectorEqualsSelf*/ true)) {
46054612
diags.diagnose(loc,
46064613
diag::transpose_attr_invalid_linearity_parameter_or_result,
46074614
linearParamType.getString(), /*isParameter*/ true);
@@ -4713,8 +4720,8 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
47134720
if (expectedOriginalResultType->hasTypeParameter())
47144721
expectedOriginalResultType = transpose->mapTypeIntoContext(
47154722
expectedOriginalResultType);
4716-
if (!conformsToDifferentiable(expectedOriginalResultType, transpose) ||
4717-
!tangentVectorEqualsSelf(expectedOriginalResultType, transpose)) {
4723+
if (!conformsToDifferentiable(expectedOriginalResultType, transpose,
4724+
/*tangentVectorEqualsSelf*/ true)) {
47184725
diagnoseAndRemoveAttr(
47194726
attr, diag::transpose_attr_invalid_linearity_parameter_or_result,
47204727
expectedOriginalResultType.getString(), /*isParameter*/ false);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: not %target-swift-frontend-typecheck -parse-stdlib %s
2+
3+
// SR-12559: `@derivative` attribute type-checking crash.
4+
// This program is not valid, but the compiler should not crash nonetheless.
5+
// Reproducible only with `-parse-stdlib`.
6+
7+
// The crash occurs because it is not sufficient for attribute type-checking to
8+
// check for `Differentiable` conformances. We must also check for invalid
9+
// associated types:
10+
//
11+
// (lldb) p valueResultConf.dump()
12+
// (normal_conformance type=AnyDerivative protocol=Differentiable
13+
// (assoc_type req=TangentVector type=<<error type>>))
14+
15+
import _Differentiation
16+
17+
struct AnyDerivative: Differentiable {
18+
init<T>(_ base: T) {}
19+
20+
@derivative(of: init)
21+
static func _vjpInit<T: Differentiable>(
22+
_ base: T
23+
) -> (value: AnyDerivative, pullback: (AnyDerivative) -> T.TangentVector) {
24+
fatalError()
25+
}
26+
27+
typealias TangentVector = AnyDerivative
28+
}
29+
30+
// Assertion failed: (resultTan && "Original result has no tangent space?"), function getAutoDiffDerivativeFunctionLinearMapType, file /Users/danielzheng/swift-merge/swift/lib/AST/Type.cpp, line 5190.

0 commit comments

Comments
 (0)