Skip to content

Commit 2cf7d63

Browse files
authored
[AutoDiff] Fix ownership error in VJPCloner::visitApplyInst. (swiftlang#35003)
Change the following code pattern: ``` %x1 = convert_function %x0 : $@differentiable (...) -> ... %x2 = begin_borrow %x1 ... // use %x2 %x3 = end_borrow %x2 destroy_value %x0 ``` To the following: ``` %x1 = begin_borrow %x0 : $@differentiable (...) -> ... %x2 = convert_function %x0 ... // use %x2 %x3 = end_borrow %x1 destroy_value %x0 ``` Resolves SR-13933: "multiple consuming users" ownership error caused by `VJPCloner::visitApply` related to `@differentiable`-function-typed callees. Also upstream test/AutoDiff/SILOptimizer/generics.swift from tensorflow branch.
1 parent 89c617a commit 2cf7d63

File tree

6 files changed

+433
-9
lines changed

6 files changed

+433
-9
lines changed

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -488,15 +488,17 @@ class VJPCloner::Implementation final
488488
return;
489489
}
490490
}
491-
auto origFnType = origCallee->getType().castTo<SILFunctionType>();
492-
auto origFnUnsubstType = origFnType->getUnsubstitutedType(getModule());
493-
if (origFnType != origFnUnsubstType) {
494-
origCallee = builder.createConvertFunction(
495-
loc, origCallee, SILType::getPrimitiveObjectType(origFnUnsubstType),
496-
/*withoutActuallyEscaping*/ false);
497-
}
498491
builder.emitScopedBorrowOperation(
499492
loc, origCallee, [&](SILValue borrowedDiffFunc) {
493+
auto origFnType = origCallee->getType().castTo<SILFunctionType>();
494+
auto origFnUnsubstType =
495+
origFnType->getUnsubstitutedType(getModule());
496+
if (origFnType != origFnUnsubstType) {
497+
borrowedDiffFunc = builder.createConvertFunction(
498+
loc, borrowedDiffFunc,
499+
SILType::getPrimitiveObjectType(origFnUnsubstType),
500+
/*withoutActuallyEscaping*/ false);
501+
}
500502
vjpValue = builder.createDifferentiableFunctionExtract(
501503
loc, NormalDifferentiableFunctionTypeComponent::VJP,
502504
borrowedDiffFunc);

0 commit comments

Comments
 (0)