Skip to content

Commit 6a7aacd

Browse files
committed
[ASTGen] Fix '@differentiable' attribute
* Typo: '_liner' -> '_linear' * Accept '@differentiable(_linear)' type attribute
1 parent d9f5001 commit 6a7aacd

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

lib/ASTGen/Sources/ASTGen/DeclAttrs.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ extension ASTGenVisitor {
603603
switch text {
604604
case "reverse": return .reverse
605605
case "wrt", "withRespectTo": return .normal
606-
case "_liner": return .linear
606+
case "_linear": return .linear
607607
case "_forward": return .forward
608608
default: return .nonDifferentiable
609609
}

lib/ASTGen/Sources/ASTGen/TypeAttrs.swift

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,17 @@ extension ASTGenVisitor {
206206
differentiabilityLoc = nil
207207
}
208208

209-
// Only 'reverse' is supported today.
210-
guard differentiability == .reverse else {
209+
// Only 'reverse' is formally supported today. '_linear' works for testing
210+
// purposes. '_forward' is rejected.
211+
switch differentiability {
212+
case .normal, .nonDifferentiable:
211213
// TODO: Diagnose
212214
fatalError("Only @differentiable(reverse) is supported")
215+
case .forward:
216+
// TODO: Diagnose
217+
fatalError("Only @differentiable(reverse) is supported")
218+
case .reverse, .linear:
219+
break
213220
}
214221

215222
return .createParsed(

test/ASTGen/autodiff.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ func testDifferentiableTypeAttr(_ fn: @escaping @differentiable(reverse) (Float)
1717
-> @differentiable(reverse) (Float) -> Float {
1818
return fn
1919
}
20+
func testDifferentiableTypeAttrLinear(_ fn: @escaping @differentiable(_linear) (Float) -> Float)
21+
-> @differentiable(_linear) (Float) -> Float {
22+
return fn
23+
}
2024

2125
@differentiable(reverse)
2226
func testDifferentiableSimple(_ x: Float) -> Float { return x * x }

0 commit comments

Comments
 (0)