Skip to content

Commit d96b73a

Browse files
committed
[AutoDiff] Make Differentiable derivation support property wrappers.
Differentiable conformance derivation now "peers through" property wrappers. Synthesized TangentVector structs contain wrapped properties' TangentVectors as stored properties, not wrappers' TangentVectors. Property wrapper types are not required to conform to `Differentiable`. Property wrapper types are required to provide `wrappedValue.set`, which is needed to synthesize `mutating func move(along:)`. ``` import _Differentiation @propertyWrapper struct Wrapper<Value> { var wrappedValue: Value } struct Struct: Differentiable { @wrapper var x: Float = 0 // Compiler now synthesizes: // struct TangentVector: Differentiable & AdditiveArithmetic { // var x: Float // ... // } } ``` Resolves SR-12638.
1 parent 37657d0 commit d96b73a

File tree

5 files changed

+158
-54
lines changed

5 files changed

+158
-54
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,13 +2734,21 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
27342734
"stored property %0 has no derivative because %1 does not conform to "
27352735
"'Differentiable'; add an explicit '@noDerivative' attribute"
27362736
"%select{|, or conform %2 to 'AdditiveArithmetic'}3",
2737-
(Identifier, Type, Identifier, bool))
2738-
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
2737+
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
2738+
/*nominalCanDeriveAdditiveArithmetic*/ bool))
2739+
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
27392740
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
2740-
"requires all stored properties to be mutable; use 'var' instead, or add "
2741-
"an explicit '@noDerivative' attribute"
2741+
"requires all stored properties not marked with `@noDerivative` to be "
2742+
"mutable; add an explicit '@noDerivative' attribute"
27422743
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
2743-
(Identifier, Identifier, bool))
2744+
(/*wrapperType*/ StringRef, /*nominalName*/ Identifier,
2745+
/*nominalCanDeriveAdditiveArithmetic*/ bool))
2746+
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
2747+
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
2748+
"requires all stored properties not marked with `@noDerivative` to be "
2749+
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
2750+
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
2751+
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
27442752

27452753
NOTE(codable_extraneous_codingkey_case_here,none,
27462754
"CodingKey case %0 does not match any stored properties", (Identifier))

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "swift/AST/ParameterList.h"
2626
#include "swift/AST/Pattern.h"
2727
#include "swift/AST/ProtocolConformance.h"
28+
#include "swift/AST/PropertyWrappers.h"
2829
#include "swift/AST/Stmt.h"
2930
#include "swift/AST/Types.h"
3031
#include "DerivedConformances.h"
@@ -39,14 +40,23 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
3940
auto &C = nominal->getASTContext();
4041
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
4142
for (auto *vd : nominal->getStoredProperties()) {
43+
// Peer through property wrappers: use original wrapped properties instead.
44+
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
45+
// Skip property wrappers that do not define `wrappedValue.set`.
46+
// `mutating func move(along:)` cannot be synthesized to update these
47+
// properties.
48+
auto *wrapperDecl =
49+
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
50+
auto *wrappedValueDecl =
51+
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
52+
if (!wrappedValueDecl->getAccessor(AccessorKind::Set))
53+
continue;
54+
// Use the original wrapped property.
55+
vd = originalProperty;
56+
}
4257
// Skip stored properties with `@noDerivative` attribute.
4358
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
4459
continue;
45-
// For property wrapper backing storage properties, skip if original
46-
// property has `@noDerivative` attribute.
47-
if (auto *originalProperty = vd->getOriginalWrappedProperty())
48-
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
49-
continue;
5060
// Skip `let` stored properties. `mutating func move(along:)` cannot be
5161
// synthesized to update these properties.
5262
if (vd->isLet())
@@ -224,15 +234,15 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
224234
if (confRef.isConcrete())
225235
memberMethodDecl = confRef.getConcrete()->getWitnessDecl(methodReq);
226236
assert(memberMethodDecl && "Member method declaration must exist");
227-
auto memberMethodDRE =
237+
auto *memberMethodDRE =
228238
new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true);
229239
memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
230240

231241
// Create reference to member method: `x.move(along:)`.
232-
auto memberExpr =
242+
Expr *memberExpr =
233243
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
234244
/*Implicit*/ true);
235-
auto memberMethodExpr =
245+
auto *memberMethodExpr =
236246
new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr);
237247

238248
// Create reference to parameter member: `direction.x`.
@@ -483,20 +493,52 @@ static void addAssociatedTypeAliasDecl(Identifier name, DeclContext *sourceDC,
483493
static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
484494
NominalTypeDecl *nominal,
485495
DeclContext *DC) {
486-
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
496+
// If nominal type can conform to `AdditiveArithmetic`, suggest adding a
497+
// conformance to `AdditiveArithmetic` in fix-its.
498+
// `Differentiable` protocol requirements all have default implementations
499+
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
500+
// derived conformances will no longer be necessary.
487501
bool nominalCanDeriveAdditiveArithmetic =
488502
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
503+
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
504+
// Check all stored properties.
489505
for (auto *vd : nominal->getStoredProperties()) {
506+
// Peer through property wrappers: use original wrapped properties.
507+
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
508+
// Skip wrapped properties with `@noDerivative` attribute.
509+
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
510+
continue;
511+
// Diagnose wrapped properties whose property wrappers do not define
512+
// `wrappedValue.set`. `mutating func move(along:)` cannot be synthesized
513+
// to update these properties.
514+
auto *wrapperDecl =
515+
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
516+
auto *wrappedValueDecl =
517+
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
518+
if (!wrappedValueDecl->getAccessor(AccessorKind::Set)) {
519+
auto loc =
520+
originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
521+
Context.Diags
522+
.diagnose(
523+
loc,
524+
diag::
525+
differentiable_immutable_wrapper_implicit_noderivative_fixit,
526+
wrapperDecl->getNameStr(), nominal->getName(),
527+
nominalCanDeriveAdditiveArithmetic)
528+
.fixItInsert(loc, "@noDerivative ");
529+
// Add an implicit `@noDerivative` attribute.
530+
originalProperty->getAttrs().add(
531+
new (Context) NoDerivativeAttr(/*Implicit*/ true));
532+
continue;
533+
}
534+
// Use the original wrapped property.
535+
vd = originalProperty;
536+
}
490537
if (vd->getInterfaceType()->hasError())
491538
continue;
492539
// Skip stored properties with `@noDerivative` attribute.
493540
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
494541
continue;
495-
// For property wrapper backing storage properties, skip if original
496-
// property has `@noDerivative` attribute.
497-
if (auto *originalProperty = vd->getOriginalWrappedProperty())
498-
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
499-
continue;
500542
// Check whether to diagnose stored property.
501543
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
502544
bool conformsToDifferentiable =
@@ -508,14 +550,8 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
508550
// Otherwise, add an implicit `@noDerivative` attribute.
509551
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
510552
auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
511-
if (auto *originalProperty = vd->getOriginalWrappedProperty())
512-
loc = originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
513553
assert(loc.isValid() && "Expected valid source location");
514-
// If nominal type can conform to `AdditiveArithmetic`, suggest conforming
515-
// adding a conformance to `AdditiveArithmetic`.
516-
// `Differentiable` protocol requirements all have default implementations
517-
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
518-
// derived conformances will no longer be necessary.
554+
// Diagnose properties that do not conform to `Differentiable`.
519555
if (!conformsToDifferentiable) {
520556
Context.Diags
521557
.diagnose(
@@ -526,11 +562,11 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
526562
.fixItInsert(loc, "@noDerivative ");
527563
continue;
528564
}
565+
// Otherwise, diagnose `let` property.
529566
Context.Diags
530567
.diagnose(loc,
531568
diag::differentiable_let_property_implicit_noderivative_fixit,
532-
vd->getName(), nominal->getName(),
533-
nominalCanDeriveAdditiveArithmetic)
569+
nominal->getName(), nominalCanDeriveAdditiveArithmetic)
534570
.fixItInsert(loc, "@noDerivative ");
535571
}
536572
}

test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ImmutableStoredProperties: Differentiable {
3838
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
3939
let nondiff: Int
4040

41-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
41+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
4242
let diff: Float
4343

4444
init() {
@@ -56,7 +56,8 @@ class MutableStoredPropertiesWithInitialValue: Differentiable {
5656
}
5757
// Test class with both an empty constructor and memberwise initializer.
5858
class AllMixedStoredPropertiesHaveInitialValue: Differentiable {
59-
let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
59+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
60+
let x = Float(1)
6061
var y = Float(1)
6162
// Memberwise initializer should be `init(y:)` since `x` is immutable.
6263
static func testMemberwiseInitializer() {
@@ -506,26 +507,35 @@ where T: AdditiveArithmetic {}
506507
extension NoMemberwiseInitializerExtended: Differentiable
507508
where T: Differentiable & AdditiveArithmetic {}
508509

510+
// Test property wrappers.
509511
// TF-1190: Test `@noDerivative` warning for property wrapper backing storage properties.
510512

511513
@propertyWrapper
512-
struct Wrapper<Value> {
514+
struct ImmutableWrapper<Value> {
513515
private var value: Value
514-
var wrappedValue: Value {
515-
get { value }
516-
set { value = newValue }
516+
var wrappedValue: Value { value }
517+
init(wrappedValue: Value) {
518+
self.value = wrappedValue
517519
}
518520
}
519-
struct TF_1190<T> {}
520-
class TF_1190_Outer: Differentiable {
521-
// expected-warning @+1 {{stored property '_x' has no derivative because 'Wrapper<TF_1190<Float>>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
522-
@Wrapper var x: TF_1190<Float>
523-
@noDerivative @Wrapper var y: TF_1190<Float>
524521

525-
init(x: TF_1190<Float>, y: TF_1190<Float>) {
526-
self.x = x
527-
self.y = y
528-
}
522+
@propertyWrapper
523+
struct Wrapper<Value> {
524+
var wrappedValue: Value
525+
}
526+
527+
struct Generic<T> {}
528+
extension Generic: Differentiable where T: Differentiable {}
529+
530+
class WrappedProperties: Differentiable {
531+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires all stored properties not marked with `@noDerivative` to be mutable; add an explicit '@noDerivative' attribute}}
532+
@ImmutableWrapper var immutableInt: Generic<Int> = Generic()
533+
534+
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
535+
@Wrapper var mutableInt: Generic<Int> = Generic()
536+
537+
@Wrapper var float: Generic<Float> = Generic()
538+
@noDerivative @ImmutableWrapper var nondiff: Generic<Int> = Generic()
529539
}
530540

531541
// Test derived conformances in disallowed contexts.

test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,40 @@ struct UsableFromInlineStruct: Differentiable {}
8080
// CHECK-AST: internal init()
8181
// CHECK-AST: @usableFromInline
8282
// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic {
83+
84+
// Test property wrappers.
85+
86+
@propertyWrapper
87+
struct Wrapper<Value> {
88+
var wrappedValue: Value
89+
}
90+
91+
struct WrappedPropertiesStruct: Differentiable {
92+
@Wrapper @Wrapper var x: Float
93+
@Wrapper var y: Float
94+
var z: Float
95+
@noDerivative @Wrapper var nondiff: Float
96+
}
97+
98+
// CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable {
99+
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
100+
// CHECK-AST: internal var x: Float.TangentVector
101+
// CHECK-AST: internal var y: Float.TangentVector
102+
// CHECK-AST: internal var z: Float.TangentVector
103+
// CHECK-AST: }
104+
// CHECK-AST: }
105+
106+
class WrappedPropertiesClass: Differentiable {
107+
@Wrapper @Wrapper var x: Float = 1
108+
@Wrapper var y: Float = 2
109+
var z: Float = 3
110+
@noDerivative @Wrapper var noDeriv: Float = 4
111+
}
112+
113+
// CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable {
114+
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
115+
// CHECK-AST: internal var x: Float.TangentVector
116+
// CHECK-AST: internal var y: Float.TangentVector
117+
// CHECK-AST: internal var z: Float.TangentVector
118+
// CHECK-AST: }
119+
// CHECK-AST: }

test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct ImmutableStoredProperties: Differentiable {
2424
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
2525
let nondiff: Int
2626

27-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
27+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic}} {{3-3=@noDerivative }}
2828
let diff: Float
2929
}
3030
func testImmutableStoredProperties() {
@@ -36,7 +36,8 @@ struct MutableStoredPropertiesWithInitialValue: Differentiable {
3636
}
3737
// Test struct with both an empty constructor and memberwise initializer.
3838
struct AllMixedStoredPropertiesHaveInitialValue: Differentiable {
39-
let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
39+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
40+
let x = Float(1)
4041
var y = Float(1)
4142
// Memberwise initializer should be `init(y:)` since `x` is immutable.
4243
static func testMemberwiseInitializer() {
@@ -323,20 +324,32 @@ where T: AdditiveArithmetic {}
323324
extension NoMemberwiseInitializerExtended: Differentiable
324325
where T: Differentiable & AdditiveArithmetic {}
325326

327+
// Test property wrappers.
326328
// TF-1190: Test `@noDerivative` warning for property wrapper backing storage properties.
327329

328330
@propertyWrapper
329-
struct Wrapper<Value> {
331+
struct ImmutableWrapper<Value> {
330332
private var value: Value
331-
var wrappedValue: Value {
332-
value
333-
}
333+
var wrappedValue: Value { value }
334+
}
335+
336+
@propertyWrapper
337+
struct Wrapper<Value> {
338+
var wrappedValue: Value
334339
}
335-
struct TF_1190<T> {}
336-
struct TF_1190_Outer: Differentiable {
337-
// expected-warning @+1 {{stored property '_x' has no derivative because 'Wrapper<TF_1190<Float>>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
338-
@Wrapper var x: TF_1190<Float>
339-
@noDerivative @Wrapper var y: TF_1190<Float>
340+
341+
struct Generic<T> {}
342+
extension Generic: Differentiable where T: Differentiable {}
343+
344+
struct WrappedProperties: Differentiable {
345+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires all stored properties not marked with `@noDerivative` to be mutable; add an explicit '@noDerivative' attribute}}
346+
@ImmutableWrapper var immutableInt: Generic<Int>
347+
348+
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
349+
@Wrapper var mutableInt: Generic<Int>
350+
351+
@Wrapper var float: Generic<Float>
352+
@noDerivative @ImmutableWrapper var nondiff: Generic<Int>
340353
}
341354

342355
// Verify that cross-file derived conformances are disallowed.

0 commit comments

Comments
 (0)