22
22
#include " swift/SILOptimizer/Differentiation/Thunk.h"
23
23
#include " swift/SILOptimizer/Differentiation/VJPEmitter.h"
24
24
25
+ #include " swift/AST/Expr.h"
26
+ #include " swift/AST/PropertyWrappers.h"
25
27
#include " swift/SIL/InstructionUtils.h"
26
28
#include " swift/SIL/Projection.h"
27
29
#include " swift/SILOptimizer/PassManager/PrettyStackTrace.h"
@@ -511,6 +513,242 @@ void PullbackEmitter::addToAdjointBuffer(SILBasicBlock *origBB,
511
513
accumulateIndirect (adjointBuffer, rhsBufferAccess, loc);
512
514
}
513
515
516
+ // --------------------------------------------------------------------------//
517
+ // Member accessor pullback generation
518
+ // --------------------------------------------------------------------------//
519
+
520
+ // / Returns true if the given original function is a "semantic member accessor".
521
+ static bool isSemanticMemberAccessor (SILFunction *original) {
522
+ auto *dc = original->getDeclContext ();
523
+ if (!dc)
524
+ return false ;
525
+ auto *decl = dc->getAsDecl ();
526
+ if (!decl)
527
+ return false ;
528
+ auto *accessor = dyn_cast<AccessorDecl>(decl);
529
+ if (!accessor)
530
+ return false ;
531
+ // Currently, only getters and setters are supported.
532
+ // TODO(SR-12640): Support `modify` accessors.
533
+ if (accessor->getAccessorKind () != AccessorKind::Get &&
534
+ accessor->getAccessorKind () != AccessorKind::Set)
535
+ return false ;
536
+ // Accessor must come from a `var` declaration.
537
+ auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage ());
538
+ if (!varDecl)
539
+ return false ;
540
+ // Return true for stored property accessors.
541
+ if (varDecl->hasStorage ())
542
+ return true ;
543
+ // Return true for properties that have attached property wrappers.
544
+ if (varDecl->hasAttachedPropertyWrapper ())
545
+ return true ;
546
+ // Otherwise, return false.
547
+ // User-defined accessors can never be supported because they may use custom
548
+ // logic that does not semantically perform a member access.
549
+ // TODO(SR-12636): Support `@differentiable(useInTangentVector)` computed
550
+ // properties.
551
+ return false ;
552
+ }
553
+
554
+ bool PullbackEmitter::runForSemanticMemberAccessor () {
555
+ auto &original = getOriginal ();
556
+ auto *accessor = cast<AccessorDecl>(original.getDeclContext ()->getAsDecl ());
557
+ switch (accessor->getAccessorKind ()) {
558
+ case AccessorKind::Get:
559
+ return runForSemanticMemberGetter ();
560
+ case AccessorKind::Set:
561
+ return runForSemanticMemberSetter ();
562
+ // TODO(SR-12640): Support `modify` accessors.
563
+ default :
564
+ llvm_unreachable (" Unsupported accessor kind; inconsistent with "
565
+ " `isSemanticMemberAccessor`?" );
566
+ }
567
+ }
568
+
569
+ bool PullbackEmitter::runForSemanticMemberGetter () {
570
+ auto &original = getOriginal ();
571
+ auto &pullback = getPullback ();
572
+ auto pbLoc = getPullback ().getLocation ();
573
+
574
+ auto *accessor = cast<AccessorDecl>(original.getDeclContext ()->getAsDecl ());
575
+ assert (accessor->getAccessorKind () == AccessorKind::Get);
576
+
577
+ auto origExitIt = original.findReturnBB ();
578
+ assert (origExitIt != original.end () &&
579
+ " Functions without returns must have been diagnosed" );
580
+ auto *origExit = &*origExitIt;
581
+ builder.setInsertionPoint (pullback.getEntryBlock ());
582
+
583
+ // Get getter argument and result values.
584
+ // Getter type: $(Self) -> Result
585
+ // Pullback type: $(Result', PB_Struct) -> Self'
586
+ assert (original.getLoweredFunctionType ()->getNumParameters () == 1 );
587
+ assert (pullback.getLoweredFunctionType ()->getNumParameters () == 2 );
588
+ assert (pullback.getLoweredFunctionType ()->getNumResults () == 1 );
589
+ SILValue origSelf = original.getArgumentsWithoutIndirectResults ().front ();
590
+
591
+ SmallVector<SILValue, 8 > origFormalResults;
592
+ collectAllFormalResultsInTypeOrder (original, origFormalResults);
593
+ auto origResult = origFormalResults[getIndices ().source ];
594
+
595
+ // TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct.
596
+ auto tangentVectorSILTy = pullback.getConventions ().getSingleSILResultType ();
597
+ auto tangentVectorTy = tangentVectorSILTy.getASTType ();
598
+ auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct ();
599
+
600
+ // Look up the corresponding field in the tangent space.
601
+ VarDecl *origField = cast<VarDecl>(accessor->getStorage ());
602
+ VarDecl *tanField = nullptr ;
603
+ auto tanFieldLookup = tangentVectorDecl->lookupDirect (origField->getName ());
604
+ if (tanFieldLookup.empty ()) {
605
+ getContext ().emitNondifferentiabilityError (
606
+ pbLoc.getSourceLoc (), getInvoker (),
607
+ diag::autodiff_stored_property_no_corresponding_tangent,
608
+ origSelf->getType ().getASTType ().getString (), origField->getNameStr ());
609
+ errorOccurred = true ;
610
+ return true ;
611
+ }
612
+ tanField = cast<VarDecl>(tanFieldLookup.front ());
613
+
614
+ // Switch based on the base tangent struct's value category.
615
+ // TODO(TF-1255): Simplify using unified adjoint value data structure.
616
+ switch (tangentVectorSILTy.getCategory ()) {
617
+ case SILValueCategory::Object: {
618
+ auto adjResult = getAdjointValue (origExit, origResult);
619
+ switch (adjResult.getKind ()) {
620
+ case AdjointValueKind::Zero:
621
+ addAdjointValue (origExit, origSelf,
622
+ makeZeroAdjointValue (tangentVectorSILTy), pbLoc);
623
+ break ;
624
+ case AdjointValueKind::Concrete:
625
+ case AdjointValueKind::Aggregate: {
626
+ SmallVector<AdjointValue, 8 > eltVals;
627
+ for (auto *field : tangentVectorDecl->getStoredProperties ()) {
628
+ if (field == tanField) {
629
+ eltVals.push_back (adjResult);
630
+ } else {
631
+ auto substMap = tangentVectorTy->getMemberSubstitutionMap (
632
+ field->getModuleContext (), field);
633
+ auto fieldTy = field->getType ().subst (substMap);
634
+ auto fieldSILTy = getTypeLowering (fieldTy).getLoweredType ();
635
+ assert (fieldSILTy.isObject ());
636
+ eltVals.push_back (makeZeroAdjointValue (fieldSILTy));
637
+ }
638
+ }
639
+ addAdjointValue (origExit, origSelf,
640
+ makeAggregateAdjointValue (tangentVectorSILTy, eltVals),
641
+ pbLoc);
642
+ }
643
+ }
644
+ break ;
645
+ }
646
+ case SILValueCategory::Address: {
647
+ assert (pullback.getIndirectResults ().size () == 1 );
648
+ auto pbIndRes = pullback.getIndirectResults ().front ();
649
+ auto *adjSelf = createFunctionLocalAllocation (
650
+ pbIndRes->getType ().getObjectType (), pbLoc);
651
+ setAdjointBuffer (origExit, origSelf, adjSelf);
652
+ for (auto *field : tangentVectorDecl->getStoredProperties ()) {
653
+ auto *adjSelfElt = builder.createStructElementAddr (pbLoc, adjSelf, field);
654
+ if (field == tanField) {
655
+ // Switch based on the property's value category.
656
+ // TODO(TF-1255): Simplify using unified adjoint value data structure.
657
+ switch (origResult->getType ().getCategory ()) {
658
+ case SILValueCategory::Object: {
659
+ auto adjResult = getAdjointValue (origExit, origResult);
660
+ auto adjResultValue = materializeAdjointDirect (adjResult, pbLoc);
661
+ builder.emitStoreValueOperation (pbLoc, adjResultValue, adjSelfElt,
662
+ StoreOwnershipQualifier::Init);
663
+ break ;
664
+ }
665
+ case SILValueCategory::Address: {
666
+ auto adjResult = getAdjointBuffer (origExit, origResult);
667
+ builder.createCopyAddr (pbLoc, adjResult, adjSelfElt, IsTake,
668
+ IsInitialization);
669
+ destroyedLocalAllocations.insert (adjResult);
670
+ break ;
671
+ }
672
+ }
673
+ } else {
674
+ auto fieldType = pullback.mapTypeIntoContext (field->getInterfaceType ())
675
+ ->getCanonicalType ();
676
+ emitZeroIndirect (fieldType, adjSelfElt, pbLoc);
677
+ }
678
+ }
679
+ break ;
680
+ }
681
+ }
682
+ return false ;
683
+ }
684
+
685
+ bool PullbackEmitter::runForSemanticMemberSetter () {
686
+ auto &original = getOriginal ();
687
+ auto &pullback = getPullback ();
688
+ auto pbLoc = getPullback ().getLocation ();
689
+
690
+ auto *accessor = cast<AccessorDecl>(original.getDeclContext ()->getAsDecl ());
691
+ assert (accessor->getAccessorKind () == AccessorKind::Set);
692
+
693
+ auto origEntry = original.getEntryBlock ();
694
+ builder.setInsertionPoint (pullback.getEntryBlock ());
695
+
696
+ // Get setter argument values.
697
+ // Setter type: $(inout Self, Argument) -> ()
698
+ // Pullback type (wrt self): $(inout Self', PB_Struct) -> ()
699
+ // Pullback type (wrt both): $(inout Self', PB_Struct) -> Argument'
700
+ assert (original.getLoweredFunctionType ()->getNumParameters () == 2 );
701
+ assert (pullback.getLoweredFunctionType ()->getNumParameters () == 2 );
702
+ assert (pullback.getLoweredFunctionType ()->getNumResults () == 0 ||
703
+ pullback.getLoweredFunctionType ()->getNumResults () == 1 );
704
+
705
+ SILValue origArg = original.getArgumentsWithoutIndirectResults ()[0 ];
706
+ SILValue origSelf = original.getArgumentsWithoutIndirectResults ()[1 ];
707
+
708
+ // TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct.
709
+ auto tangentVectorSILTy = pullback.getLoweredFunctionType ()
710
+ ->getParameters ()[0 ]
711
+ .getSILStorageInterfaceType ();
712
+ assert (tangentVectorSILTy.getCategory () == SILValueCategory::Address);
713
+ auto tangentVectorTy = tangentVectorSILTy.getASTType ();
714
+ auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct ();
715
+
716
+ // Look up the corresponding field in the tangent space.
717
+ VarDecl *origField = cast<VarDecl>(accessor->getStorage ());
718
+ VarDecl *tanField = nullptr ;
719
+ auto tanFieldLookup = tangentVectorDecl->lookupDirect (origField->getName ());
720
+ if (tanFieldLookup.empty ()) {
721
+ getContext ().emitNondifferentiabilityError (
722
+ pbLoc.getSourceLoc (), getInvoker (),
723
+ diag::autodiff_stored_property_no_corresponding_tangent,
724
+ origSelf->getType ().getASTType ().getString (), origField->getNameStr ());
725
+ errorOccurred = true ;
726
+ return true ;
727
+ }
728
+ tanField = cast<VarDecl>(tanFieldLookup.front ());
729
+
730
+ auto adjSelf = getAdjointBuffer (origEntry, origSelf);
731
+ auto *adjSelfElt = builder.createStructElementAddr (pbLoc, adjSelf, tanField);
732
+ // Switch based on the property's value category.
733
+ // TODO(TF-1255): Simplify using unified adjoint value data structure.
734
+ switch (origArg->getType ().getCategory ()) {
735
+ case SILValueCategory::Object: {
736
+ auto adjArg = builder.emitLoadValueOperation (pbLoc, adjSelfElt,
737
+ LoadOwnershipQualifier::Take);
738
+ setAdjointValue (origEntry, origArg, makeConcreteAdjointValue (adjArg));
739
+ break ;
740
+ }
741
+ case SILValueCategory::Address: {
742
+ addToAdjointBuffer (origEntry, origArg, adjSelfElt, pbLoc);
743
+ builder.emitDestroyOperation (pbLoc, adjSelfElt);
744
+ break ;
745
+ }
746
+ }
747
+ emitZeroIndirect (adjSelfElt->getType ().getASTType (), adjSelfElt, pbLoc);
748
+
749
+ return false ;
750
+ }
751
+
514
752
// --------------------------------------------------------------------------//
515
753
// Entry point
516
754
// --------------------------------------------------------------------------//
@@ -756,12 +994,21 @@ bool PullbackEmitter::run() {
756
994
<< " as the adjoint of original result " << origResult);
757
995
}
758
996
997
+ // If the original function is an accessor with special-case pullback
998
+ // generation logic, do special-case generation.
999
+ if (isSemanticMemberAccessor (&original)) {
1000
+ if (runForSemanticMemberAccessor ())
1001
+ return true ;
1002
+ }
1003
+ // Otherwise, perform standard pullback generation.
759
1004
// Visit original blocks blocks in post-order and perform differentiation
760
1005
// in corresponding pullback blocks. If errors occurred, back out.
761
- for (auto *bb : postOrderPostDomOrder) {
762
- visitSILBasicBlock (bb);
763
- if (errorOccurred)
764
- return true ;
1006
+ else {
1007
+ for (auto *bb : postOrderPostDomOrder) {
1008
+ visitSILBasicBlock (bb);
1009
+ if (errorOccurred)
1010
+ return true ;
1011
+ }
765
1012
}
766
1013
767
1014
// Prepare and emit a `return` in the pullback exit block.
@@ -1343,8 +1590,7 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
1343
1590
auto adjStruct = materializeAdjointDirect (std::move (av), loc);
1344
1591
// Find the struct `TangentVector` type.
1345
1592
auto structTy = remapType (si->getType ()).getASTType ();
1346
- auto tangentVectorTy =
1347
- getTangentSpace (structTy)->getType ()->getCanonicalType ();
1593
+ auto tangentVectorTy = getTangentSpace (structTy)->getCanonicalType ();
1348
1594
assert (!getTypeLowering (tangentVectorTy).isAddressOnly ());
1349
1595
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct ();
1350
1596
assert (tangentVectorDecl);
@@ -1406,8 +1652,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
1406
1652
" differentiated; activity analysis should not marked as varied" );
1407
1653
auto *bb = sei->getParent ();
1408
1654
auto structTy = remapType (sei->getOperand ()->getType ()).getASTType ();
1409
- auto tangentVectorTy =
1410
- getTangentSpace (structTy)->getType ()->getCanonicalType ();
1655
+ auto tangentVectorTy = getTangentSpace (structTy)->getCanonicalType ();
1411
1656
assert (!getTypeLowering (tangentVectorTy).isAddressOnly ());
1412
1657
auto tangentVectorSILTy = SILType::getPrimitiveObjectType (tangentVectorTy);
1413
1658
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct ();
@@ -1449,8 +1694,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
1449
1694
auto substMap = tangentVectorTy->getMemberSubstitutionMap (
1450
1695
field->getModuleContext (), field);
1451
1696
auto fieldTy = field->getType ().subst (substMap);
1452
- auto fieldSILTy = getContext ().getTypeConverter ().getLoweredType (
1453
- fieldTy, TypeExpansionContext::minimal ());
1697
+ auto fieldSILTy = getTypeLowering (fieldTy).getLoweredType ();
1454
1698
assert (fieldSILTy.isObject ());
1455
1699
eltVals.push_back (makeZeroAdjointValue (fieldSILTy));
1456
1700
}
@@ -1466,8 +1710,7 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
1466
1710
auto *bb = reai->getParent ();
1467
1711
auto adjBuf = getAdjointBuffer (bb, reai);
1468
1712
auto classTy = remapType (reai->getOperand ()->getType ()).getASTType ();
1469
- auto tangentVectorTy =
1470
- getTangentSpace (classTy)->getType ()->getCanonicalType ();
1713
+ auto tangentVectorTy = getTangentSpace (classTy)->getCanonicalType ();
1471
1714
assert (!getTypeLowering (tangentVectorTy).isAddressOnly ());
1472
1715
auto tangentVectorSILTy = SILType::getPrimitiveObjectType (tangentVectorTy);
1473
1716
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct ();
@@ -1498,8 +1741,7 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
1498
1741
auto substMap = tangentVectorTy->getMemberSubstitutionMap (
1499
1742
field->getModuleContext (), field);
1500
1743
auto fieldTy = field->getType ().subst (substMap);
1501
- auto fieldSILTy = getContext ().getTypeConverter ().getLoweredType (
1502
- fieldTy, TypeExpansionContext::minimal ());
1744
+ auto fieldSILTy = getTypeLowering (fieldTy).getLoweredType ();
1503
1745
assert (fieldSILTy.isObject ());
1504
1746
eltVals.push_back (makeZeroAdjointValue (fieldSILTy));
1505
1747
}
0 commit comments