Skip to content

Commit a8a95d0

Browse files
authored
[AutoDiff upstream] Enable @differentiable on setters. (swiftlang#35133)
Enable `@differentiable` attribute on setters of properties and subscripts in `Differentiable`-conforming types. Add automatically-differentiated `@differentiable` setter test. Resolves TF-1166.
1 parent 2c5abd8 commit a8a95d0

File tree

5 files changed

+52
-15
lines changed

5 files changed

+52
-15
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4325,11 +4325,10 @@ resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) {
43254325
}
43264326
}
43274327
// Non-`get` accessors are not yet supported: `set`, `read`, and `modify`.
4328-
// TODO(TF-129): Enable `set` when differentiation supports inout parameters.
43294328
// TODO(TF-1080): Enable `read` and `modify` when differentiation supports
43304329
// coroutines.
43314330
if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
4332-
if (!accessor->isGetter())
4331+
if (!accessor->isGetter() && !accessor->isSetter())
43334332
original = nullptr;
43344333
// Diagnose if original `AbstractFunctionDecl` could not be resolved.
43354334
if (!original) {

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,13 @@ struct SubscriptMethod: Differentiable {
171171
subscript(explicit x: Float) -> Float {
172172
@differentiable // ok
173173
get { return x }
174-
// expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
175174
@differentiable
176175
set {}
177176
}
178177

179178
subscript(x: Float, y: Float) -> Float {
180179
@differentiable // ok
181180
get { return x + y }
182-
// expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
183181
@differentiable
184182
set {}
185183
}
@@ -700,7 +698,6 @@ struct Accessors: Differentiable {
700698

701699
var stored: Float
702700
var computed: Float {
703-
// expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
704701
@differentiable
705702
set { stored = newValue }
706703

test/AutoDiff/TBD/derivative_symbols.swift

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ public func topLevelDerivative<T: Differentiable>(_ x: T) -> (
2222
public struct Struct: Differentiable {
2323
var stored: Float
2424

25-
// Test property.
26-
@differentiable
25+
// Test property: getter and setter.
2726
public var property: Float {
28-
stored
27+
@differentiable
28+
get { stored }
29+
@differentiable
30+
set { stored = newValue }
2931
}
3032

3133
// Test initializer.
@@ -35,24 +37,37 @@ public struct Struct: Differentiable {
3537
}
3638

3739
// Test method.
38-
public func method(x: Float, y: Float) -> Float { x }
40+
public func method(_ x: Float, _ y: Float) -> Float { x }
3941

4042
@derivative(of: method)
41-
public func jvpMethod(x: Float, y: Float) -> (
43+
public func jvpMethod(_ x: Float, _ y: Float) -> (
4244
value: Float, differential: (TangentVector, Float, Float) -> Float
4345
) {
4446
fatalError()
4547
}
4648

47-
// Test subscript.
48-
public subscript(x: Float) -> Float { x }
49+
// Test subscript: getter and setter.
50+
public subscript(_ x: Float) -> Float {
51+
@differentiable
52+
get { x }
53+
54+
@differentiable
55+
set { stored = newValue }
56+
}
4957

5058
@derivative(of: subscript)
51-
public func vjpSubscript(x: Float) -> (
59+
public func vjpSubscript(_ x: Float) -> (
5260
value: Float, pullback: (Float) -> (TangentVector, Float)
5361
) {
5462
fatalError()
5563
}
64+
65+
@derivative(of: subscript.set)
66+
public mutating func vjpSubscriptSetter(_ x: Float, _ newValue: Float) -> (
67+
value: (), pullback: (inout TangentVector) -> (Float, Float)
68+
) {
69+
fatalError()
70+
}
5671
}
5772

5873
extension Array where Element == Struct {

test/AutoDiff/validation-test/derivative_registration.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ DerivativeRegistrationTests.testWithLeakChecking("InstanceMethod") {
9494

9595
extension Wrapper {
9696
subscript(_ x: Tracked<Float>) -> Tracked<Float> {
97+
@differentiable
9798
@_semantics("autodiff.opaque")
9899
get { float * x }
100+
101+
@differentiable
99102
set {}
100103
}
101104

@@ -117,7 +120,10 @@ DerivativeRegistrationTests.testWithLeakChecking("SubscriptGetter") {
117120

118121
extension Wrapper {
119122
subscript() -> Tracked<Float> {
123+
@differentiable
120124
get { float }
125+
126+
@differentiable
121127
set { float = newValue }
122128
}
123129

test/AutoDiff/validation-test/inout_parameters.swift

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,15 @@ InoutParameterAutoDiffTests.test("SetAccessor") {
117117
get { x }
118118
set { x = newValue }
119119
}
120+
121+
// Computed property with explicit `@differentiable` accessors.
122+
var doubled: Float {
123+
@differentiable
124+
get { x + x }
125+
126+
@differentiable
127+
set { x = newValue / 2 }
128+
}
120129
}
121130

122131
// `squared` implemented using a `set` accessor.
@@ -126,8 +135,19 @@ InoutParameterAutoDiffTests.test("SetAccessor") {
126135
s.computed *= x
127136
return s.x
128137
}
129-
expectEqual(6, gradient(at: 3, in: squared))
130-
expectEqual(8, gradient(at: 4, in: squared))
138+
expectEqual((9, 6), valueWithGradient(at: 3, in: squared))
139+
expectEqual((16, 8), valueWithGradient(at: 4, in: squared))
140+
141+
// `quadrupled` implemented using a `set` accessor.
142+
func quadrupled(_ x: Float) -> Float {
143+
var s = S(x: 1)
144+
s.doubled *= 4 * x
145+
return s.x
146+
}
147+
print(valueWithGradient(at: 3, in: quadrupled))
148+
print(valueWithGradient(at: 4, in: quadrupled))
149+
expectEqual((12, 4), valueWithGradient(at: 3, in: quadrupled))
150+
expectEqual((16, 4), valueWithGradient(at: 4, in: quadrupled))
131151
}
132152

133153
// Test differentiation wrt `inout` parameters that have a class type.

0 commit comments

Comments
 (0)