Skip to content

Commit 7668666

Browse files
authored
Support differentiation of wrapped value modify accessors (#78794)
Some fixes for coroutines with normal results and `partial_apply` of coroutines were required. Fixes #55084
1 parent 401b705 commit 7668666

File tree

11 files changed

+159
-31
lines changed

11 files changed

+159
-31
lines changed

lib/IRGen/GenCall.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4982,7 +4982,9 @@ static void emitRetconCoroutineEntry(
49824982
ArrayRef<llvm::Value *> extraArguments, llvm::Constant *allocFn,
49834983
llvm::Constant *deallocFn, ArrayRef<llvm::Value *> finalArguments) {
49844984
auto prototype =
4985-
IGF.IGM.getOpaquePtr(IGF.IGM.getAddrOfContinuationPrototype(fnType));
4985+
IGF.IGM.getOpaquePtr(
4986+
IGF.IGM.getAddrOfContinuationPrototype(fnType,
4987+
fnType->getInvocationGenericSignature()));
49864988
// Call the right 'llvm.coro.id.retcon' variant.
49874989
SmallVector<llvm::Value *, 8> arguments;
49884990
arguments.push_back(

lib/IRGen/GenDecl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6344,13 +6344,14 @@ IRGenModule::getAddrOfDefaultAssociatedConformanceAccessor(
63446344
}
63456345

63466346
llvm::Function *
6347-
IRGenModule::getAddrOfContinuationPrototype(CanSILFunctionType fnType) {
6347+
IRGenModule::getAddrOfContinuationPrototype(CanSILFunctionType fnType,
6348+
CanGenericSignature sig) {
63486349
LinkEntity entity = LinkEntity::forCoroutineContinuationPrototype(fnType);
63496350

63506351
llvm::Function *&entry = GlobalFuncs[entity];
63516352
if (entry) return entry;
63526353

6353-
GenericContextScope scope(*this, fnType->getInvocationGenericSignature());
6354+
GenericContextScope scope(*this, sig);
63546355
auto signature = Signature::forCoroutineContinuation(*this, fnType);
63556356
LinkInfo link = LinkInfo::get(*this, entity, NotForDefinition);
63566357
entry = createFunction(*this, link, signature);

lib/IRGen/GenFunc.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1486,7 +1486,8 @@ class CoroPartialApplicationForwarderEmission
14861486
auto prototype = subIGF.IGM.getOpaquePtr(
14871487
subIGF.IGM.getAddrOfContinuationPrototype(
14881488
cast<SILFunctionType>(
1489-
unsubstType->mapTypeOutOfContext()->getCanonicalType())));
1489+
unsubstType->mapTypeOutOfContext()->getCanonicalType()),
1490+
origType->getInvocationGenericSignature()));
14901491

14911492

14921493
// Use free as our allocator.

lib/IRGen/IRGenModule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1899,7 +1899,8 @@ private: \
18991899

19001900
void emitDynamicReplacementOriginalFunctionThunk(SILFunction *f);
19011901

1902-
llvm::Function *getAddrOfContinuationPrototype(CanSILFunctionType fnType);
1902+
llvm::Function *getAddrOfContinuationPrototype(CanSILFunctionType fnType,
1903+
CanGenericSignature sig);
19031904
Address getAddrOfSILGlobalVariable(SILGlobalVariable *var,
19041905
const TypeInfo &ti,
19051906
ForDefinition_t forDefinition);

lib/IRGen/IRGenSIL.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4762,14 +4762,18 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) {
47624762
const auto &coroutine = getLoweredCoroutine(i->getTokenResult());
47634763
bool isAbort = ei == nullptr;
47644764

4765-
auto sig = Signature::forCoroutineContinuation(IGM, i->getOrigCalleeType());
4765+
// Lower the return value in the callee's generic context.
4766+
auto origCalleeType = i->getOrigCalleeType();
4767+
GenericContextScope scope(IGM, origCalleeType->getInvocationGenericSignature());
4768+
4769+
auto sig = Signature::forCoroutineContinuation(IGM, origCalleeType);
47664770

47674771
// Cast the continuation pointer to the right function pointer type.
47684772
auto continuation = coroutine.Continuation;
47694773
continuation = Builder.CreateBitCast(continuation, IGM.PtrTy);
47704774

47714775
auto schemaAndEntity =
4772-
getCoroutineResumeFunctionPointerAuth(IGM, i->getOrigCalleeType());
4776+
getCoroutineResumeFunctionPointerAuth(IGM, origCalleeType);
47734777
auto pointerAuth = PointerAuthInfo::emit(*this, schemaAndEntity.first,
47744778
coroutine.getBuffer().getAddress(),
47754779
schemaAndEntity.second);
@@ -4798,16 +4802,15 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) {
47984802

47994803
if (!isAbort) {
48004804
auto resultType = call->getType();
4805+
Explosion e;
48014806
if (!resultType->isVoidTy()) {
4802-
Explosion e;
48034807
// FIXME: Do we need to handle ABI-related conversions here?
48044808
// It seems we cannot have C function convention for coroutines, etc.
48054809
extractScalarResults(*this, resultType, call, e);
4806-
4807-
// NOTE: This inserts a new entry into the LoweredValues DenseMap,
4808-
// invalidating the reference held by `coroutine`.
4809-
setLoweredExplosion(ei, e);
48104810
}
4811+
// NOTE: This inserts a new entry into the LoweredValues DenseMap,
4812+
// invalidating the reference held by `coroutine`.
4813+
setLoweredExplosion(ei, e);
48114814
}
48124815
}
48134816

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ bool isSemanticMemberAccessor(SILFunction *original) {
6161
auto *accessor = dyn_cast<AccessorDecl>(decl);
6262
if (!accessor)
6363
return false;
64-
// Currently, only getters and setters are supported.
65-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
64+
// Currently, only getters, setters and _modify accessors are supported.
6665
if (accessor->getAccessorKind() != AccessorKind::Get &&
67-
accessor->getAccessorKind() != AccessorKind::Set)
66+
accessor->getAccessorKind() != AccessorKind::Set &&
67+
accessor->getAccessorKind() != AccessorKind::Modify)
6868
return false;
6969
// Accessor must come from a `var` declaration.
7070
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905905
bool runForSemanticMemberAccessor();
906906
bool runForSemanticMemberGetter();
907907
bool runForSemanticMemberSetter();
908+
bool runForSemanticMemberModify();
908909

909910
/// If original result is non-varied, it will always have a zero derivative.
910911
/// Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2452,7 +2453,8 @@ bool PullbackCloner::Implementation::run() {
24522453

24532454
// If the original function is an accessor with special-case pullback
24542455
// generation logic, do special-case generation.
2455-
if (isSemanticMemberAccessor(&original)) {
2456+
bool isSemanticMemberAcc = isSemanticMemberAccessor(&original);
2457+
if (isSemanticMemberAcc) {
24562458
if (runForSemanticMemberAccessor())
24572459
return true;
24582460
}
@@ -2730,7 +2732,8 @@ bool PullbackCloner::Implementation::run() {
27302732
#endif
27312733

27322734
LLVM_DEBUG(getADDebugStream()
2733-
<< "Generated pullback for " << original.getName() << ":\n"
2735+
<< "Generated " << (isSemanticMemberAcc ? "semantic member accessor" : "normal")
2736+
<< " pullback for " << original.getName() << ":\n"
27342737
<< pullback);
27352738
return errorOccurred;
27362739
}
@@ -3205,7 +3208,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
32053208
return runForSemanticMemberGetter();
32063209
case AccessorKind::Set:
32073210
return runForSemanticMemberSetter();
3208-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3211+
case AccessorKind::Modify:
3212+
return runForSemanticMemberModify();
32093213
default:
32103214
llvm_unreachable("Unsupported accessor kind; inconsistent with "
32113215
"`isSemanticMemberAccessor`?");
@@ -3389,6 +3393,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
33893393
return false;
33903394
}
33913395

3396+
bool PullbackCloner::Implementation::runForSemanticMemberModify() {
3397+
auto &original = getOriginal();
3398+
auto &pullback = getPullback();
3399+
auto pbLoc = getPullback().getLocation();
3400+
3401+
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
3402+
assert(accessor->getAccessorKind() == AccessorKind::Modify);
3403+
3404+
auto *origEntry = original.getEntryBlock();
3405+
// We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3406+
// plus resume and unwind BBs
3407+
auto *yi = cast<YieldInst>(origEntry->getTerminator());
3408+
auto *origResumeBB = yi->getResumeBB();
3409+
3410+
auto *pbEntry = pullback.getEntryBlock();
3411+
builder.setCurrentDebugScope(
3412+
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
3413+
builder.setInsertionPoint(pbEntry);
3414+
3415+
// Get _modify accessor argument values.
3416+
// Accessor type : $(inout Self) -> @yields @inout Argument
3417+
// Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3418+
// Normally pullbacks for semantic member accessors are single BB and
3419+
// therefore have empty linear map tuple, however, coroutines have a branching
3420+
// control flow due to possible coroutine abort, so we need to accommodate for
3421+
// this. We keep branch tracing enums in order not to special case in many
3422+
// other places. As there is no way to return to coroutine via abort exit, we
3423+
// essentially "linearize" a coroutine.
3424+
auto loweredFnTy = original.getLoweredFunctionType();
3425+
auto pullbackLoweredFnTy = pullback.getLoweredFunctionType();
3426+
3427+
assert(loweredFnTy->getNumParameters() == 1 &&
3428+
loweredFnTy->getNumYields() == 1);
3429+
assert(pullbackLoweredFnTy->getNumParameters() == 2);
3430+
assert(pullbackLoweredFnTy->getNumYields() == 1);
3431+
3432+
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
3433+
3434+
SmallVector<SILValue, 8> origFormalResults;
3435+
collectAllFormalResultsInTypeOrder(original, origFormalResults);
3436+
3437+
assert(getConfig().resultIndices->getNumIndices() == 2 &&
3438+
"Modify accessor should have two semantic results");
3439+
3440+
auto origYield = origFormalResults[*std::next(getConfig().resultIndices->begin())];
3441+
3442+
// Look up the corresponding field in the tangent space.
3443+
auto *origField = cast<VarDecl>(accessor->getStorage());
3444+
auto baseType = remapType(origSelf->getType()).getASTType();
3445+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
3446+
pbLoc, getInvoker());
3447+
if (!tanField) {
3448+
errorOccurred = true;
3449+
return true;
3450+
}
3451+
3452+
auto adjSelf = getAdjointBuffer(origResumeBB, origSelf);
3453+
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
3454+
// Modify accessors have inout yields and therefore should yield addresses.
3455+
assert(getTangentValueCategory(origYield) == SILValueCategory::Address &&
3456+
"Modify accessors should yield indirect");
3457+
3458+
// Yield the adjoint buffer and do everything else in the resume
3459+
// destination. Unwind destination is unreachable as the coroutine can never
3460+
// be aborted.
3461+
auto *unwindBB = getPullback().createBasicBlock();
3462+
auto *resumeBB = getPullbackBlock(origEntry);
3463+
builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB);
3464+
builder.setInsertionPoint(unwindBB);
3465+
builder.createUnreachable(SILLocation::invalid());
3466+
3467+
builder.setInsertionPoint(resumeBB);
3468+
addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc);
3469+
3470+
return false;
3471+
}
3472+
33923473
//--------------------------------------------------------------------------//
33933474
// Adjoint buffer mapping
33943475
//--------------------------------------------------------------------------//

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,16 @@ class VJPCloner::Implementation final
460460
TypeSubstCloner::visitEndApplyInst(eai);
461461
return;
462462
}
463+
// If the original function is a semantic member accessor, do standard
464+
// cloning. Semantic member accessors have special pullback generation
465+
// logic, so all `end_apply` instructions can be directly cloned to the VJP.
466+
if (isSemanticMemberAccessor(original)) {
467+
LLVM_DEBUG(getADDebugStream()
468+
<< "Cloning `end_apply` in semantic member accessor:\n"
469+
<< *eai << '\n');
470+
TypeSubstCloner::visitEndApplyInst(eai);
471+
return;
472+
}
463473

464474
Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope()));
465475
auto loc = eai->getLoc();
@@ -607,6 +617,16 @@ class VJPCloner::Implementation final
607617
TypeSubstCloner::visitBeginApplyInst(bai);
608618
return;
609619
}
620+
// If the original function is a semantic member accessor, do standard
621+
// cloning. Semantic member accessors have special pullback generation
622+
// logic, so all `begin_apply` instructions can be directly cloned to the VJP.
623+
if (isSemanticMemberAccessor(original)) {
624+
LLVM_DEBUG(getADDebugStream()
625+
<< "Cloning `begin_apply` in semantic member accessor:\n"
626+
<< *bai << '\n');
627+
TypeSubstCloner::visitBeginApplyInst(bai);
628+
return;
629+
}
610630

611631
Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope()));
612632
auto loc = bai->getLoc();

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,29 @@ where
340340
{
341341
return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs })
342342
}
343+
344+
@usableFromInline
345+
@derivative(of: *=)
346+
static func _vjpMultiplyAssign(_ lhs: inout Self, _ rhs: Self) -> (
347+
value: Void, pullback: (inout Self) -> Self)
348+
{
349+
defer { lhs *= rhs }
350+
return ((), { [lhs = lhs] v in
351+
let drhs = lhs * v
352+
v *= rhs
353+
return drhs
354+
})
355+
}
356+
357+
@usableFromInline
358+
@derivative(of: *=)
359+
static func _jvpMultiplyAssign(_ lhs: inout Self, _ rhs: Self) -> (
360+
value: Void, differential: (inout Self, Self) -> Void)
361+
{
362+
let oldLhs = lhs
363+
lhs *= rhs
364+
return ((), { $0 = $0 * rhs + oldLhs * $1 })
365+
}
343366
}
344367

345368
extension ${Self}

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,9 +680,7 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {}
680680
// accesses.
681681

682682
struct Struct: Differentiable {
683-
// expected-error @+4 {{expression is not differentiable}}
684-
// expected-error @+3 {{expression is not differentiable}}
685-
// expected-note @+2 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
683+
// expected-error @+2 {{expression is not differentiable}}
686684
// expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
687685
@DifferentiableWrapper @DifferentiableWrapper var x: Float = 10
688686

0 commit comments

Comments
 (0)