Skip to content

Commit b9341c8

Browse files
authored
[AutoDiff] Fix protocol witness SILGen for @differentiable class methods. (swiftlang#32639)
During protocol witness SILGen for `@differentiable` class methods, replace the `AutoDiffDerivativeFunctionIdentifier` generic signature with the witness thunk substitution map's generic signature. Resolves TF-1180: vtable SIL verification error.
1 parent 9e85a39 commit b9341c8

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

include/swift/SIL/SILDeclRef.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ struct SILDeclRef {
309309
/// function.
310310
SILDeclRef asAutoDiffDerivativeFunction(
311311
AutoDiffDerivativeFunctionIdentifier *derivativeId) const {
312-
assert(!derivativeFunctionIdentifier);
312+
assert(derivativeId);
313313
SILDeclRef declRef = *this;
314314
declRef.derivativeFunctionIdentifier = derivativeId;
315315
return declRef;

lib/SILGen/SILGenPoly.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4604,6 +4604,17 @@ getWitnessFunctionRef(SILGenFunction &SGF,
46044604
}
46054605
case WitnessDispatchKind::Class: {
46064606
SILValue selfPtr = witnessParams.back().getValue();
4607+
// If `witness` is a derivative function `SILDeclRef`, replace the
4608+
// derivative function identifier's generic signature with the witness thunk
4609+
// substitution map's generic signature.
4610+
if (auto *derivativeId = witness.derivativeFunctionIdentifier) {
4611+
auto *newDerivativeId = AutoDiffDerivativeFunctionIdentifier::get(
4612+
derivativeId->getKind(), derivativeId->getParameterIndices(),
4613+
witnessSubs.getGenericSignature(), SGF.getASTContext());
4614+
return SGF.emitClassMethodRef(
4615+
loc, selfPtr, witness.asAutoDiffDerivativeFunction(newDerivativeId),
4616+
witnessFTy);
4617+
}
46074618
return SGF.emitClassMethodRef(loc, selfPtr, witness, witnessFTy);
46084619
}
46094620
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s
2+
3+
// TF-1180: SIL verification error regarding `@differentiable` class method
4+
// witnesses for `@differentiable` protocol requirements.
5+
6+
import _Differentiation
7+
8+
protocol Protocol {
9+
@differentiable
10+
func method(_ x: Float) -> Float
11+
}
12+
13+
class Class: Protocol {
14+
@differentiable
15+
public func method(_ x: Float) -> Float { x }
16+
}
17+
18+
// Original error:
19+
// SIL verification failed: method does not appear in the class's vtable: VerifyClassMethodVisitor(member).Seen
20+
// Verifying instruction:
21+
// %2 = load_borrow %1 : $*Class // users: %8, %4, %3
22+
// -> %3 = class_method %2 : $Class, #Class.method!jvp.SU.<Self where Self : Protocol> : (Class) -> (Float) -> Float, $@convention(method) (Float, @guaranteed Class) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %4
23+
// %4 = apply %3(%0, %2) : $@convention(method) (Float, @guaranteed Class) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %5
24+
// In function:
25+
// // AD__$s4main5ClassCAA8ProtocolA2aDP6methodyS2fFTW_jvp_SU
26+
// sil private [transparent] [thunk] [ossa] @AD__$s4main5ClassCAA8ProtocolA2aDP6methodyS2fFTW_jvp_SU : $@convention(witness_method: Protocol) (Float, @in_guaranteed Class) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
27+
// // %0 // user: %4
28+
// // %1 // user: %2
29+
// bb0(%0 : $Float, %1 : $*Class):
30+
// %2 = load_borrow %1 : $*Class // users: %8, %4, %3
31+
// %3 = class_method %2 : $Class, #Class.method!jvp.SU.<Self where Self : Protocol> : (Class) -> (Float) -> Float, $@convention(method) (Float, @guaranteed Class) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %4
32+
// %4 = apply %3(%0, %2) : $@convention(method) (Float, @guaranteed Class) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %5
33+
// (%5, %6) = destructure_tuple %4 : $(Float, @callee_guaranteed (Float) -> Float) // users: %7, %7
34+
// %7 = tuple (%5 : $Float, %6 : $@callee_guaranteed (Float) -> Float) // user: %9
35+
// end_borrow %2 : $Class // id: %8
36+
// return %7 : $(Float, @callee_guaranteed (Float) -> Float) // id: %9
37+
// } // end sil function 'AD__$s4main5ClassCAA8ProtocolA2aDP6methodyS2fFTW_jvp_SU'

0 commit comments

Comments
 (0)