Skip to content

Commit 76ad4ab

Browse files
authored
[AutoDiff upstream] Fill in derivative witness table/vtable thunks. (swiftlang#30634)
Generate `differentiable_function` and `differentiable_function_extract` in derivative witness table/vtable thunks. `differentiation_function` is later canonicalized by the differentiation transform. Add SIL FileCheck tests.
1 parent e5cb871 commit 76ad4ab

File tree

5 files changed

+81
-16
lines changed

5 files changed

+81
-16
lines changed

lib/SIL/OwnershipUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind kind) {
3030
case SILNodeKind::TupleInst:
3131
case SILNodeKind::StructInst:
3232
case SILNodeKind::EnumInst:
33+
case SILNodeKind::DifferentiableFunctionInst:
3334
case SILNodeKind::OpenExistentialRefInst:
3435
case SILNodeKind::UpcastInst:
3536
case SILNodeKind::UncheckedRefCastInst:
@@ -46,7 +47,6 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind kind) {
4647
case SILNodeKind::DestructureTupleInst:
4748
case SILNodeKind::MarkDependenceInst:
4849
case SILNodeKind::InitExistentialRefInst:
49-
case SILNodeKind::DifferentiableFunctionInst:
5050
return true;
5151
default:
5252
return false;
@@ -59,9 +59,9 @@ bool swift::isGuaranteedForwardingValueKind(SILNodeKind kind) {
5959
switch (kind) {
6060
case SILNodeKind::TupleExtractInst:
6161
case SILNodeKind::StructExtractInst:
62+
case SILNodeKind::DifferentiableFunctionExtractInst:
6263
case SILNodeKind::OpenExistentialValueInst:
6364
case SILNodeKind::OpenExistentialBoxValueInst:
64-
case SILNodeKind::DifferentiableFunctionExtractInst:
6565
return true;
6666
default:
6767
return isOwnershipForwardingValueKind(kind);

lib/SILGen/SILGenPoly.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4455,10 +4455,17 @@ getWitnessFunctionRef(SILGenFunction &SGF,
44554455
switch (witnessKind) {
44564456
case WitnessDispatchKind::Static:
44574457
if (auto *derivativeId = witness.derivativeFunctionIdentifier) {
4458-
// TODO(TF-1139, TF-1140): Replace `undef` with `differentiable_function`
4459-
// and `differentiable_function_extract`.
4460-
auto derivativeFnSilType = SILType::getPrimitiveObjectType(witnessFTy);
4461-
return SILUndef::get(derivativeFnSilType, SGF.F);
4458+
auto originalFn =
4459+
SGF.emitGlobalFunctionRef(loc, witness.asAutoDiffOriginalFunction());
4460+
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
4461+
derivativeId->getParameterIndices(),
4462+
witness.getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
4463+
auto diffFn = SGF.B.createDifferentiableFunction(loc, loweredParamIndices,
4464+
originalFn);
4465+
return SGF.B.createDifferentiableFunctionExtract(
4466+
loc,
4467+
NormalDifferentiableFunctionTypeComponent(derivativeId->getKind()),
4468+
diffFn);
44624469
}
44634470
return SGF.emitGlobalFunctionRef(loc, witness);
44644471
case WitnessDispatchKind::Dynamic:

lib/SILGen/SILGenThunk.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,14 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
194194
auto *derivativeFnDecl = derivativeFnDeclRef.getDecl();
195195

196196
SILGenFunctionBuilder builder(*this);
197-
auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction();
197+
auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction();
198198
// TODO(TF-685): Use principled thunk mangling.
199199
// Do not simply reuse reabstraction thunk mangling.
200200
auto name = derivativeFnDeclRef.mangle() + "_vtable_entry_thunk";
201201
auto *thunk = builder.getOrCreateFunction(
202-
derivativeFnDecl, name, originalFn.getLinkage(ForDefinition), constantTy,
203-
IsBare, IsTransparent, derivativeFnDeclRef.isSerialized(), IsNotDynamic,
204-
ProfileCounter(), IsThunk);
202+
derivativeFnDecl, name, originalFnDeclRef.getLinkage(ForDefinition),
203+
constantTy, IsBare, IsTransparent, derivativeFnDeclRef.isSerialized(),
204+
IsNotDynamic, ProfileCounter(), IsThunk);
205205
if (!thunk->empty())
206206
return thunk;
207207

@@ -212,14 +212,20 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
212212
auto loc = derivativeFnDeclRef.getAsRegularLocation();
213213
SGF.collectThunkParams(loc, params);
214214

215-
// TODO(TF-1139, TF-1140): Replace `undef` with `differentiable_function` and
216-
// `differentiable_function_extract`.
217-
auto derivativeSilTy = SILType::getPrimitiveObjectType(constantTy);
218-
auto derivativeFn = SILUndef::get(derivativeSilTy, *thunk);
215+
auto originalFn = SGF.emitGlobalFunctionRef(loc, originalFnDeclRef);
216+
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
217+
derivativeId->getParameterIndices(),
218+
derivativeFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
219+
auto diffFn =
220+
SGF.B.createDifferentiableFunction(loc, loweredParamIndices, originalFn);
221+
auto derivativeFn = SGF.B.createDifferentiableFunctionExtract(
222+
loc, NormalDifferentiableFunctionTypeComponent(derivativeId->getKind()),
223+
diffFn);
224+
auto derivativeFnSILTy = SILType::getPrimitiveObjectType(constantTy);
219225
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
220226
thunk->getArguments().end());
221227
auto apply =
222-
SGF.emitApplyWithRethrow(loc, derivativeFn, derivativeSilTy,
228+
SGF.emitApplyWithRethrow(loc, derivativeFn, derivativeFnSILTy,
223229
SGF.getForwardingSubstitutionMap(), args);
224230
SGF.B.createReturn(loc, apply);
225231

test/AutoDiff/SILGen/vtable.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ class Sub: Super {
9595

9696
class SubSub: Sub {}
9797

98+
// Check vtable entry thunks.
99+
100+
// CHECK-LABEL: sil hidden [transparent] [thunk] [ossa] @AD__${{.*}}5SuperC6methody{{.*}}jvp_src_0_wrt_0_vtable_entry_thunk : $@convention(method) (Float, Float, @guaranteed Super) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
101+
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : @guaranteed $Super):
102+
// CHECK: %3 = function_ref @$s6vtable5SuperC6methodyS2f_SftF : $@convention(method) (Float, Float, @guaranteed Super) -> Float
103+
// CHECK: %4 = differentiable_function [parameters 0] %3 : $@convention(method) (Float, Float, @guaranteed Super) -> Float
104+
// CHECK: %5 = differentiable_function_extract [jvp] %4 : $@differentiable @convention(method) (Float, @noDerivative Float, @noDerivative @guaranteed Super) -> Float
105+
// CHECK: %6 = apply %5(%0, %1, %2) : $@convention(method) (Float, Float, @guaranteed Super) -> (Float, @owned @callee_guaranteed (Float) -> Float)
106+
// CHECK: return %6 : $(Float, @callee_guaranteed (Float) -> Float)
107+
// CHECK: }
108+
109+
// Check vtable entries: new vs `[override]` vs `[inherited]` entries.
110+
98111
// CHECK-LABEL: sil_vtable Super {
99112
// CHECK: #Super.method: (Super) -> (Float, Float) -> Float : @$s6vtable5SuperC6methodyS2f_SftF
100113
// CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk

test/AutoDiff/SILGen/witness_table.swift

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,32 @@ struct Struct: Protocol {
3535
}
3636

3737
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_jvp_SUU : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
38+
// CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float
39+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]]
40+
// CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]]
41+
// CHECK: apply [[JVP_FN]]
42+
// CHECK: }
3843

3944
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_vjp_SUU : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
45+
// CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float
46+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]]
47+
// CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]]
48+
// CHECK: apply [[VJP_FN]]
49+
// CHECK: }
4050

4151
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_jvp_SSS : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, Double, @in_guaranteed τ_0_0) -> Float for <DummyTangentVector>) {
52+
// CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float
53+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 1 2] [[ORIG_FN]]
54+
// CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]]
55+
// CHECK: apply [[JVP_FN]]
56+
// CHECK: }
4257

4358
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_vjp_SSS : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, Double, @out τ_0_0) for <DummyTangentVector>) {
59+
// CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float
60+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 1 2] [[ORIG_FN]]
61+
// CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]]
62+
// CHECK: apply [[VJP_FN]]
63+
// CHECK: }
4464

4565
@differentiable
4666
var property: Float {
@@ -49,9 +69,18 @@ struct Struct: Protocol {
4969
}
5070

5171
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}property{{.*}}_jvp_S : $@convention(witness_method: Protocol) (@in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <DummyTangentVector>) {
72+
// CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}property{{.*}} : $@convention(method) (Struct) -> Float
73+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]]
74+
// CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]]
75+
// CHECK: apply [[JVP_FN]]
76+
// CHECK: }
5277

5378
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}property{{.*}}_vjp_S : $@convention(witness_method: Protocol) (@in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <DummyTangentVector>) {
54-
79+
// CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}property{{.*}} : $@convention(method) (Struct) -> Float
80+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]]
81+
// CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]]
82+
// CHECK: apply [[VJP_FN]]
83+
// CHECK: }
5584

5685
@differentiable(wrt: x)
5786
subscript(_ x: Float, _ y: Float) -> Float {
@@ -60,8 +89,18 @@ struct Struct: Protocol {
6089
}
6190

6291
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
92+
// CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float
93+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]]
94+
// CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]]
95+
// CHECK: apply [[JVP_FN]]
96+
// CHECK: }
6397

6498
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
99+
// CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float
100+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]]
101+
// CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]]
102+
// CHECK: apply [[VJP_FN]]
103+
// CHECK: }
65104
}
66105

67106
// CHECK-LABEL: sil_witness_table hidden Struct: Protocol module witness_table {

0 commit comments

Comments
 (0)