Skip to content

Commit f5d304c

Browse files
committed
[Docs] [AutoDiff] Add default derivative subsection and fix minor issues.
1 parent 05081ee commit f5d304c

File tree

1 file changed

+77
-7
lines changed

1 file changed

+77
-7
lines changed

docs/DifferentiableProgramming.md

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,8 +1452,11 @@ making the other function linear.
14521452

14531453
A protocol requirement or class method/property/subscript can be made
14541454
differentiable via a derivative function or transpose function defined in an
1455-
extension. A dispatched call to such a member can be differentiated even if the
1456-
concrete implementation is not differentiable.
1455+
extension. When a protocol requirement is not marked with `@differentiable` but
1456+
has been made differentiable by a `@derivative` or `@transpose` declaration in a
1457+
protocol extension, a dispatched call to such a member can be differentiated,
1458+
and the derivative or transpose is always the one provided in the protocol
1459+
extension.
14571460

14581461
#### Linear maps
14591462

@@ -1731,8 +1734,8 @@ public extension ElementaryFunctions where Self: Differentiable, Self == Self.Ta
17311734

17321735
@inlinable
17331736
@derivative(of: log)
1734-
func _(_ x: Self) -> (value: Self, differential: @differential(linear) (Self) -> Self) { dx in
1735-
(log(x), { 1 / x * dx })
1737+
func _(_ x: Self) -> (value: Self, differential: @differential(linear) (Self) -> Self) {
1738+
(log(x), { dx in 1 / x * dx })
17361739
}
17371740

17381741
@inlinable
@@ -1749,6 +1752,73 @@ public extension ElementaryFunctions where Self: Differentiable, Self == Self.Ta
17491752
}
17501753
```
17511754

1755+
#### Default derivatives
1756+
1757+
In a protocol extension, class definition, or class definition, providing a
1758+
derivative or transpose for a protocol extension or a non-final class member is
1759+
considered as providing a default derivative for that member. Types that conform
1760+
to the protocol or inherit from the class can inherit the default derivative.
1761+
1762+
If the original member does not have a `@differentiable` attribute, a default
1763+
derivative is implicitly add to all conforming/overriding implementations.
1764+
1765+
```swift
1766+
protocol P {
1767+
func foo(_ x: Float) -> Float
1768+
}
1769+
1770+
extension P {
1771+
@derivative(of: foo(x:))
1772+
func _(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
1773+
(value: foo(x), differential: { _ in 42 })
1774+
}
1775+
}
1776+
1777+
struct S: P {
1778+
func foo(_ x: Float) -> Float {
1779+
33
1780+
}
1781+
}
1782+
1783+
let s = S()
1784+
let d = derivative(at: 0) { x in
1785+
s.foo(x)
1786+
} // ==> 42
1787+
```
1788+
1789+
When a protocol requirement or class member is marked with `@differentiable`, it
1790+
is considered as a _differentiability customization point_. This means that all
1791+
conforming/overriding implementation must provide a corresponding
1792+
`@differentiable` attribute, which causes the implementation to be will be
1793+
differentiated. To inherit the default derivative without differentiating the
1794+
implementation, add `default` to the `@differentiable` attribute.
1795+
1796+
```swift
1797+
protocol P {
1798+
@differentiable
1799+
func foo(_ x: Float) -> Float
1800+
}
1801+
1802+
extension P {
1803+
@derivative(of: foo(x:))
1804+
func _(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
1805+
(value: foo(x), differential: { _ in 42 })
1806+
}
1807+
}
1808+
1809+
struct S: P {
1810+
@differentiable(default) // Inherits from `P.foo(_:)`.
1811+
func foo(_ x: Float) -> Float {
1812+
33
1813+
}
1814+
}
1815+
1816+
let s = S()
1817+
let d = derivative(at: 0) { x in
1818+
s.foo(x)
1819+
} // ==> 42
1820+
```
1821+
17521822
### Differentiable function types
17531823

17541824
Differentiability is a fundamental mathematical concept that applies not only to
@@ -2239,13 +2309,13 @@ whether the derivative is always zero and warns the user.
22392309

22402310
```swift
22412311
let grad = gradient(at: 1.0) { x in
2242-
3.squareRoot()
2312+
Double(3).squareRoot()
22432313
}
22442314
```
22452315

22462316
```console
2247-
test.swift:4:18: warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()' to make it explicit?
2248-
3.squareRoot()
2317+
test.swift:4:18: warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)' to make it explicit?
2318+
Double(3).squareRoot()
22492319
^
22502320
withoutDerivative(at:)
22512321
```

0 commit comments

Comments
 (0)