Skip to content

Commit 486e667

Browse files
authored
[AutoDiff] Support direct init reference differentiation. (swiftlang#30946)
Support `@differentiable` function conversion for `init` references, in addition to `func` references and literal closures. Minor usability improvement. Resolves SR-12562.
2 parents a864a57 + b4fa7e0 commit 486e667

File tree

4 files changed

+22
-11
lines changed

4 files changed

+22
-11
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ ERROR(invalid_autoclosure_forwarding,none,
12211221
"add () to forward @autoclosure parameter", ())
12221222
ERROR(invalid_differentiable_function_conversion_expr,none,
12231223
"a '@differentiable%select{|(linear)}0' function can only be formed from "
1224-
"a reference to a 'func' or a literal closure", (bool))
1224+
"a reference to a 'func' or 'init' or a literal closure", (bool))
12251225
NOTE(invalid_differentiable_function_conversion_parameter,none,
12261226
"did you mean to take a '%0' closure?", (StringRef))
12271227
ERROR(invalid_autoclosure_pointer_conversion,none,

lib/Sema/CSApply.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5971,7 +5971,7 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
59715971
semanticExpr = capture->getClosureBody();
59725972
if (isa<ClosureExpr>(semanticExpr)) return;
59735973
if (auto *declRef = dyn_cast<DeclRefExpr>(semanticExpr)) {
5974-
if (isa<FuncDecl>(declRef->getDecl())) return;
5974+
if (isa<AbstractFunctionDecl>(declRef->getDecl())) return;
59755975
// If the referenced decl is a function parameter, the user may want
59765976
// to change the declaration to be a '@differentiable' closure. Emit a
59775977
// note with a fix-it.
@@ -6001,13 +6001,16 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
60016001
if (isa<FuncDecl>(memberRef->getMember().getDecl())) return;
60026002
} else if (auto *dotSyntaxCall =
60036003
dyn_cast<DotSyntaxCallExpr>(semanticExpr)) {
6004-
Expr *fnExpr = dotSyntaxCall->getFn()->getSemanticsProvidingExpr();
6005-
while (auto *autoclosureExpr = dyn_cast<AutoClosureExpr>(fnExpr))
6006-
if (auto *unwrappedFnExpr = autoclosureExpr->getUnwrappedCurryThunkExpr())
6007-
fnExpr = unwrappedFnExpr;
60086004
// Recurse on the function expression.
6005+
auto *fnExpr = dotSyntaxCall->getFn()->getSemanticsProvidingExpr();
60096006
maybeDiagnoseFunctionRef(fnExpr);
60106007
return;
6008+
} else if (auto *autoclosureExpr = dyn_cast<AutoClosureExpr>(semanticExpr)) {
6009+
// Peer through curry thunks.
6010+
if (auto *unwrappedFnExpr = autoclosureExpr->getUnwrappedCurryThunkExpr()) {
6011+
maybeDiagnoseFunctionRef(unwrappedFnExpr);
6012+
return;
6013+
}
60116014
}
60126015
ctx.Diags.diagnose(expr->getLoc(),
60136016
diag::invalid_differentiable_function_conversion_expr,

test/AutoDiff/Sema/differentiable_func_type.swift

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,19 @@ func fakeGradient<T, U: FloatingPoint>(of f: @differentiable (T) -> U) {}
6060

6161
func takesOpaqueClosure(f: @escaping (Float) -> Float) {
6262
// expected-note @-1 {{did you mean to take a '@differentiable' closure?}} {{38-38=@differentiable }}
63-
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
63+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
6464
fakeGradient(of: f)
6565
}
6666

6767
let globalAddOne: (Float) -> Float = { $0 + 1 }
68-
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
68+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
6969
fakeGradient(of: globalAddOne)
7070

7171
func someScope() {
7272
let localAddOne: (Float) -> Float = { $0 + 1 }
73-
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
73+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
7474
fakeGradient(of: globalAddOne)
75-
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
75+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
7676
fakeGradient(of: localAddOne)
7777
// The following case is okay during type checking, but will fail in the AD transform.
7878
fakeGradient { localAddOne($0) }
@@ -95,10 +95,15 @@ func linearToDifferentiable(_ f: @escaping @differentiable(linear) (Float) -> Fl
9595
}
9696

9797
func differentiableToLinear(_ f: @escaping @differentiable (Float) -> Float) {
98-
// expected-error @+1 {{a '@differentiable(linear)' function can only be formed from a reference to a 'func' or a literal closure}}
98+
// expected-error @+1 {{a '@differentiable(linear)' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
9999
_ = f as @differentiable(linear) (Float) -> Float
100100
}
101101

102+
struct Struct: Differentiable {
103+
var x: Float
104+
}
105+
let _: @differentiable (Float) -> Struct = Struct.init
106+
102107
//===----------------------------------------------------------------------===//
103108
// Parameter selection (@noDerivative)
104109
//===----------------------------------------------------------------------===//

test/AutoDiff/validation-test/simple_math.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ SimpleMathTests.test("StructMemberwiseInitializer") {
202202
}
203203
}
204204

205+
// Test direct `init` reference.
206+
expectEqual(10, pullback(at: 4, in: Foo.init)(.init(stored: 10)))
207+
205208
let 𝛁foo = pullback(at: Float(4), in: { input -> Foo in
206209
let foo = Foo(stored: input)
207210
let foo2 = foo + foo

0 commit comments

Comments
 (0)