Skip to content

Commit d067e7e

Browse files
author
marcrasi
authored
[AutoDiff upstream] more IRGen for @differentiable functions (swiftlang#30688)
1 parent b6bcd85 commit d067e7e

File tree

2 files changed

+150
-7
lines changed

2 files changed

+150
-7
lines changed

lib/IRGen/IRGenSIL.cpp

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2350,11 +2350,33 @@ void IRGenSILFunction::visitTryApplyInst(swift::TryApplyInst *i) {
23502350
}
23512351

23522352
void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
2353-
const LoweredValue &calleeLV = getLoweredValue(site.getCallee());
2354-
23552353
auto origCalleeType = site.getOrigCalleeType();
23562354
auto substCalleeType = site.getSubstCalleeType();
2357-
2355+
if (site.getOrigCalleeType()->isDifferentiable()) {
2356+
origCalleeType = origCalleeType->getWithoutDifferentiability();
2357+
substCalleeType = substCalleeType->getWithoutDifferentiability();
2358+
}
2359+
2360+
// If the callee is a differentiable function, we extract the original
2361+
// function because we want to call the original function.
2362+
Optional<LoweredValue> diffCalleeOrigFnLV;
2363+
if (site.getOrigCalleeType()->isDifferentiable()) {
2364+
auto diffFnExplosion = getLoweredExplosion(site.getCallee());
2365+
Explosion origFnExplosion;
2366+
unsigned fieldSize = 1;
2367+
if (origCalleeType->getRepresentation() ==
2368+
SILFunctionTypeRepresentation::Thick) {
2369+
fieldSize = 2;
2370+
}
2371+
origFnExplosion.add(diffFnExplosion.getRange(0, 0 + fieldSize));
2372+
(void)diffFnExplosion.claimAll();
2373+
diffCalleeOrigFnLV = LoweredValue(origFnExplosion);
2374+
}
2375+
2376+
const LoweredValue &calleeLV =
2377+
diffCalleeOrigFnLV ? *diffCalleeOrigFnLV :
2378+
getLoweredValue(site.getCallee());
2379+
23582380
auto args = site.getArguments();
23592381
SILFunctionConventions origConv(origCalleeType, getSILModule());
23602382
assert(origConv.getNumSILArguments() == args.size());
@@ -4542,11 +4564,15 @@ void IRGenSILFunction::visitConvertEscapeToNoEscapeInst(
45424564
swift::ConvertEscapeToNoEscapeInst *i) {
45434565
// This instruction makes the context trivial.
45444566
Explosion in = getLoweredExplosion(i->getOperand());
4545-
llvm::Value *fn = in.claimNext();
4546-
llvm::Value *ctx = in.claimNext();
45474567
Explosion out;
4548-
out.add(fn);
4549-
out.add(Builder.CreateBitCast(ctx, IGM.OpaquePtrTy));
4568+
// Differentiable functions contain multiple pairs of fn and ctx pointer.
4569+
for (unsigned index : range(in.size() / 2)) {
4570+
(void)index;
4571+
llvm::Value *fn = in.claimNext();
4572+
llvm::Value *ctx = in.claimNext();
4573+
out.add(fn);
4574+
out.add(Builder.CreateBitCast(ctx, IGM.OpaquePtrTy));
4575+
}
45504576
setLoweredExplosion(i, out);
45514577
}
45524578

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-ir %s | %FileCheck %s
2+
3+
sil_stage raw
4+
5+
import Swift
6+
import Builtin
7+
import _Differentiation
8+
9+
sil @f : $@convention(thin) (Float) -> Float {
10+
bb0(%0 : $Float):
11+
return undef : $Float
12+
}
13+
14+
sil @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
15+
bb0(%0 : $Float):
16+
return undef : $(Float, @callee_guaranteed (Float) -> Float)
17+
}
18+
19+
sil @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
20+
bb0(%0 : $Float):
21+
return undef : $(Float, @callee_guaranteed (Float) -> Float)
22+
}
23+
24+
sil @test_form_diff_func : $@convention(thin) () -> @owned @differentiable @callee_guaranteed (Float) -> Float {
25+
bb0:
26+
%orig = function_ref @f : $@convention(thin) (Float) -> Float
27+
%origThick = thin_to_thick_function %orig : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
28+
%jvp = function_ref @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
29+
%jvpThick = thin_to_thick_function %jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
30+
%vjp = function_ref @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
31+
%vjpThick = thin_to_thick_function %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
32+
%result = differentiable_function [parameters 0] %origThick : $@callee_guaranteed (Float) -> Float with_derivative {%jvpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
33+
return %result : $@differentiable @callee_guaranteed (Float) -> Float
34+
}
35+
36+
// CHECK-LABEL: define{{.*}}test_form_diff_func(<{ %swift.function, %swift.function, %swift.function }>*
37+
// CHECK-SAME: [[OUT:%.*]])
38+
// CHECK: [[OUT_ORIG:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0
39+
// CHECK: [[OUT_ORIG_FN:%.*]] = getelementptr{{.*}}[[OUT_ORIG]], i32 0, i32 0
40+
// CHECK: store{{.*}}@f{{.*}}[[OUT_ORIG_FN]]
41+
// CHECK: [[OUT_ORIG_DATA:%.*]] = getelementptr{{.*}} [[OUT_ORIG]], i32 0, i32 1
42+
// CHECK: store{{.*}}null{{.*}} [[OUT_ORIG_DATA]]
43+
// CHECK: [[OUT_JVP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1
44+
// CHECK: [[OUT_JVP_FN:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 0
45+
// CHECK: store{{.*}}@f_jvp{{.*}}[[OUT_JVP_FN]]
46+
// CHECK: [[OUT_JVP_DATA:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 1
47+
// CHECK: store{{.*}} null{{.*}}[[OUT_JVP_DATA]]
48+
// CHECK: [[OUT_VJP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2
49+
// CHECK: [[OUT_VJP_FN:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 0
50+
// CHECK: store{{.*}}@f_vjp{{.*}}[[OUT_VJP_FN]]
51+
// CHECK: [[OUT_VJP_DATA:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 1
52+
// CHECK: store{{.*}}null{{.*}}[[OUT_VJP_DATA]]
53+
54+
sil @test_extract_components : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> (@owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
55+
bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float):
56+
%orig = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (Float) -> Float
57+
%jvp = differentiable_function_extract [jvp] %0 : $@differentiable @callee_guaranteed (Float) -> Float
58+
%vjp = differentiable_function_extract [vjp] %0 : $@differentiable @callee_guaranteed (Float) -> Float
59+
%result = tuple (%orig : $@callee_guaranteed (Float) -> Float, %jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
60+
return %result : $(@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
61+
}
62+
63+
// CHECK-LABEL: define{{.*}}@test_extract_components(<{ %swift.function, %swift.function, %swift.function }>*
64+
// CHECK-SAME: [[OUT:%.*]], <{ %swift.function, %swift.function, %swift.function }>*{{.*}}[[IN:%.*]])
65+
// CHECK: [[ORIG:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 0
66+
// CHECK: [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0
67+
// CHECK: [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]]
68+
// CHECK: [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1
69+
// CHECK: [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]]
70+
// CHECK: [[JVP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 1
71+
// CHECK: [[JVP_FN_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 0
72+
// CHECK: [[JVP_FN:%.*]] = load{{.*}}[[JVP_FN_ADDR]]
73+
// CHECK: [[JVP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 1
74+
// CHECK: [[JVP_DATA:%.*]] = load{{.*}}[[JVP_DATA_ADDR]]
75+
// CHECK: [[VJP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 2
76+
// CHECK: [[VJP_FN_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 0
77+
// CHECK: [[VJP_FN:%.*]] = load{{.*}}[[VJP_FN_ADDR]]
78+
// CHECK: [[VJP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 1
79+
// CHECK: [[VJP_DATA:%.*]] = load{{.*}}[[VJP_DATA_ADDR]]
80+
// CHECK: [[OUT_0:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0
81+
// CHECK: [[OUT_0_FN:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 0
82+
// CHECK: store{{.*}}[[ORIG_FN]]{{.*}}[[OUT_0_FN]]
83+
// CHECK: [[OUT_0_DATA:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 1
84+
// CHECK: store{{.*}}[[ORIG_DATA]]{{.*}}[[OUT_0_DATA]]
85+
// CHECK: [[OUT_1:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1
86+
// CHECK: [[OUT_1_FN:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 0
87+
// CHECK: store{{.*}}[[JVP_FN]]{{.*}}[[OUT_1_FN]]
88+
// CHECK: [[OUT_1_DATA:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 1
89+
// CHECK: store{{.*}}[[JVP_DATA]]{{.*}}[[OUT_1_DATA]]
90+
// CHECK: [[OUT_2:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2
91+
// CHECK: [[OUT_2_FN:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 0
92+
// CHECK: store{{.*}}[[VJP_FN]]{{.*}}[[OUT_2_FN]]
93+
// CHECK: [[OUT_2_DATA:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 1
94+
// CHECK: store{{.*}}[[VJP_DATA]]{{.*}}[[OUT_2_DATA]]
95+
96+
sil @test_call_diff_fn : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float, Float) -> Float {
97+
bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float, %1 : $Float):
98+
%2 = apply %0(%1) : $@differentiable @callee_guaranteed (Float) -> Float
99+
return %2 : $Float
100+
}
101+
102+
// CHECK-LABEL: define{{.*}}@test_call_diff_fn(<{ %swift.function, %swift.function, %swift.function }>*
103+
// CHECK-SAME: [[DIFF_FN:%.*]], float [[INPUT_FLOAT:%.*]])
104+
// CHECK: [[ORIG:%.*]] = getelementptr{{.*}}[[DIFF_FN]], i32 0, i32 0
105+
// CHECK: [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0
106+
// CHECK: [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]]
107+
// CHECK: [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1
108+
// CHECK: [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]]
109+
// CHECK: [[ORIG_FN_CAST:%.*]] = bitcast{{.*}}[[ORIG_FN]]
110+
// CHECK: [[RESULT:%.*]] = call swiftcc float [[ORIG_FN_CAST]](float [[INPUT_FLOAT]], %swift.refcounted* swiftself [[ORIG_DATA]])
111+
// CHECK: ret float [[RESULT]]
112+
113+
sil @test_convert_escape_to_noescape : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> () {
114+
bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float):
115+
%1 = convert_escape_to_noescape %0 : $@differentiable @callee_guaranteed (Float) -> Float to $@noescape @differentiable @callee_guaranteed (Float) -> Float
116+
return undef : $()
117+
}

0 commit comments

Comments
 (0)