@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905
905
bool runForSemanticMemberAccessor ();
906
906
bool runForSemanticMemberGetter ();
907
907
bool runForSemanticMemberSetter ();
908
+ bool runForSemanticMemberModify ();
908
909
909
910
// / If original result is non-varied, it will always have a zero derivative.
910
911
// / Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2452,7 +2453,8 @@ bool PullbackCloner::Implementation::run() {
2452
2453
2453
2454
// If the original function is an accessor with special-case pullback
2454
2455
// generation logic, do special-case generation.
2455
- if (isSemanticMemberAccessor (&original)) {
2456
+ bool isSemanticMemberAcc = isSemanticMemberAccessor (&original);
2457
+ if (isSemanticMemberAcc) {
2456
2458
if (runForSemanticMemberAccessor ())
2457
2459
return true ;
2458
2460
}
@@ -2730,7 +2732,8 @@ bool PullbackCloner::Implementation::run() {
2730
2732
#endif
2731
2733
2732
2734
LLVM_DEBUG (getADDebugStream ()
2733
- << " Generated pullback for " << original.getName () << " :\n "
2735
+ << " Generated " << (isSemanticMemberAcc ? " semantic member accessor" : " normal" )
2736
+ << " pullback for " << original.getName () << " :\n "
2734
2737
<< pullback);
2735
2738
return errorOccurred;
2736
2739
}
@@ -3205,7 +3208,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
3205
3208
return runForSemanticMemberGetter ();
3206
3209
case AccessorKind::Set:
3207
3210
return runForSemanticMemberSetter ();
3208
- // TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3211
+ case AccessorKind::Modify:
3212
+ return runForSemanticMemberModify ();
3209
3213
default :
3210
3214
llvm_unreachable (" Unsupported accessor kind; inconsistent with "
3211
3215
" `isSemanticMemberAccessor`?" );
@@ -3389,6 +3393,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
3389
3393
return false ;
3390
3394
}
3391
3395
3396
+ bool PullbackCloner::Implementation::runForSemanticMemberModify () {
3397
+ auto &original = getOriginal ();
3398
+ auto &pullback = getPullback ();
3399
+ auto pbLoc = getPullback ().getLocation ();
3400
+
3401
+ auto *accessor = cast<AccessorDecl>(original.getDeclContext ()->getAsDecl ());
3402
+ assert (accessor->getAccessorKind () == AccessorKind::Modify);
3403
+
3404
+ auto *origEntry = original.getEntryBlock ();
3405
+ // We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3406
+ // plus resume and unwind BBs
3407
+ auto *yi = cast<YieldInst>(origEntry->getTerminator ());
3408
+ auto *origResumeBB = yi->getResumeBB ();
3409
+
3410
+ auto *pbEntry = pullback.getEntryBlock ();
3411
+ builder.setCurrentDebugScope (
3412
+ remapScope (origEntry->getScopeOfFirstNonMetaInstruction ()));
3413
+ builder.setInsertionPoint (pbEntry);
3414
+
3415
+ // Get _modify accessor argument values.
3416
+ // Accessor type : $(inout Self) -> @yields @inout Argument
3417
+ // Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3418
+ // Normally pullbacks for semantic member accessors are single BB and
3419
+ // therefore have empty linear map tuple, however, coroutines have a branching
3420
+ // control flow due to possible coroutine abort, so we need to accommodate for
3421
+ // this. We keep branch tracing enums in order not to special case in many
3422
+ // other places. As there is no way to return to coroutine via abort exit, we
3423
+ // essentially "linearize" a coroutine.
3424
+ auto loweredFnTy = original.getLoweredFunctionType ();
3425
+ auto pullbackLoweredFnTy = pullback.getLoweredFunctionType ();
3426
+
3427
+ assert (loweredFnTy->getNumParameters () == 1 &&
3428
+ loweredFnTy->getNumYields () == 1 );
3429
+ assert (pullbackLoweredFnTy->getNumParameters () == 2 );
3430
+ assert (pullbackLoweredFnTy->getNumYields () == 1 );
3431
+
3432
+ SILValue origSelf = original.getArgumentsWithoutIndirectResults ().front ();
3433
+
3434
+ SmallVector<SILValue, 8 > origFormalResults;
3435
+ collectAllFormalResultsInTypeOrder (original, origFormalResults);
3436
+
3437
+ assert (getConfig ().resultIndices ->getNumIndices () == 2 &&
3438
+ " Modify accessor should have two semantic results" );
3439
+
3440
+ auto origYield = origFormalResults[*std::next (getConfig ().resultIndices ->begin ())];
3441
+
3442
+ // Look up the corresponding field in the tangent space.
3443
+ auto *origField = cast<VarDecl>(accessor->getStorage ());
3444
+ auto baseType = remapType (origSelf->getType ()).getASTType ();
3445
+ auto *tanField = getTangentStoredProperty (getContext (), origField, baseType,
3446
+ pbLoc, getInvoker ());
3447
+ if (!tanField) {
3448
+ errorOccurred = true ;
3449
+ return true ;
3450
+ }
3451
+
3452
+ auto adjSelf = getAdjointBuffer (origResumeBB, origSelf);
3453
+ auto *adjSelfElt = builder.createStructElementAddr (pbLoc, adjSelf, tanField);
3454
+ // Modify accessors have inout yields and therefore should yield addresses.
3455
+ assert (getTangentValueCategory (origYield) == SILValueCategory::Address &&
3456
+ " Modify accessors should yield indirect" );
3457
+
3458
+ // Yield the adjoint buffer and do everything else in the resume
3459
+ // destination. Unwind destination is unreachable as the coroutine can never
3460
+ // be aborted.
3461
+ auto *unwindBB = getPullback ().createBasicBlock ();
3462
+ auto *resumeBB = getPullbackBlock (origEntry);
3463
+ builder.createYield (yi->getLoc (), {adjSelfElt}, resumeBB, unwindBB);
3464
+ builder.setInsertionPoint (unwindBB);
3465
+ builder.createUnreachable (SILLocation::invalid ());
3466
+
3467
+ builder.setInsertionPoint (resumeBB);
3468
+ addToAdjointBuffer (origEntry, origSelf, adjSelf, pbLoc);
3469
+
3470
+ return false ;
3471
+ }
3472
+
3392
3473
// --------------------------------------------------------------------------//
3393
3474
// Adjoint buffer mapping
3394
3475
// --------------------------------------------------------------------------//
0 commit comments