Skip to content

Commit 1f5818a

Browse files
committed
[AutoDiff] Support property wrapper differentiation.
Support differentiation of property wrapper wrapped value getters and setters. Create new pullback generation code path for "semantic member accessors". "Semantic member accessors" are attached to member properties that have a corresponding tangent stored property in the parent `TangentVector` type. These accessors have special-case pullback generation based on their semantic behavior. Currently, only getters and setters are supported. This special-case pullback generation is currently used for stored property accessors and property wrapper wrapped value accessors. In the future, it can also be used to support `@differentiable(useInTangentVector)` computed properties: SR-12636. User-defined accesors cannot use this code path because they may use custom logic that does not semantically perform a member access. Resolves SR-12639.
1 parent d96b73a commit 1f5818a

File tree

7 files changed

+553
-21
lines changed

7 files changed

+553
-21
lines changed

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class AdjointValue final {
165165
break;
166166
}
167167
}
168+
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
168169
};
169170

170171
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
100100
SILBuilder localAllocBuilder;
101101

102102
/// Stack buffers allocated for storing local adjoint values.
103-
SmallVector<SILValue, 64> functionLocalAllocations;
103+
SmallVector<AllocStackInst *, 64> functionLocalAllocations;
104104

105105
/// A set used to remember local allocations that were destroyed.
106106
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
@@ -316,6 +316,24 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
316316
/// if any error occurs.
317317
bool run();
318318

319+
/// Performs pullback generation on the empty pullback function, given that
320+
/// the original function is a "semantic member accessor".
321+
///
322+
/// "Semantic member accessors" are attached to member properties that have a
323+
/// corresponding tangent stored property in the parent `TangentVector` type.
324+
/// These accessors have special-case pullback generation based on their
325+
/// semantic behavior.
326+
///
327+
/// "Semantic member accessors" currently include:
328+
/// - Stored property accessors. These are implicitly generated.
329+
/// - Property wrapper wrapped value accessors. These are implicitly generated
330+
/// and internally call `var wrappedValue`.
331+
///
332+
/// Returns true if any error occurs.
333+
bool runForSemanticMemberAccessor();
334+
bool runForSemanticMemberGetter();
335+
bool runForSemanticMemberSetter();
336+
319337
/// If original result is non-varied, it will always have a zero derivative.
320338
/// Skip full pullback generation and simply emit zero derivatives for wrt
321339
/// parameters.

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 256 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "swift/SILOptimizer/Differentiation/Thunk.h"
2323
#include "swift/SILOptimizer/Differentiation/VJPEmitter.h"
2424

25+
#include "swift/AST/Expr.h"
26+
#include "swift/AST/PropertyWrappers.h"
2527
#include "swift/SIL/InstructionUtils.h"
2628
#include "swift/SIL/Projection.h"
2729
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
@@ -511,6 +513,242 @@ void PullbackEmitter::addToAdjointBuffer(SILBasicBlock *origBB,
511513
accumulateIndirect(adjointBuffer, rhsBufferAccess, loc);
512514
}
513515

516+
//--------------------------------------------------------------------------//
517+
// Member accessor pullback generation
518+
//--------------------------------------------------------------------------//
519+
520+
/// Returns true if the given original function is a "semantic member accessor".
521+
static bool isSemanticMemberAccessor(SILFunction *original) {
522+
auto *dc = original->getDeclContext();
523+
if (!dc)
524+
return false;
525+
auto *decl = dc->getAsDecl();
526+
if (!decl)
527+
return false;
528+
auto *accessor = dyn_cast<AccessorDecl>(decl);
529+
if (!accessor)
530+
return false;
531+
// Currently, only getters and setters are supported.
532+
// TODO(SR-12640): Support `modify` accessors.
533+
if (accessor->getAccessorKind() != AccessorKind::Get &&
534+
accessor->getAccessorKind() != AccessorKind::Set)
535+
return false;
536+
// Accessor must come from a `var` declaration.
537+
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());
538+
if (!varDecl)
539+
return false;
540+
// Return true for stored property accessors.
541+
if (varDecl->hasStorage())
542+
return true;
543+
// Return true for properties that have attached property wrappers.
544+
if (varDecl->hasAttachedPropertyWrapper())
545+
return true;
546+
// Otherwise, return false.
547+
// User-defined accessors can never be supported because they may use custom
548+
// logic that does not semantically perform a member access.
549+
// TODO(SR-12636): Support `@differentiable(useInTangentVector)` computed
550+
// properties.
551+
return false;
552+
}
553+
554+
bool PullbackEmitter::runForSemanticMemberAccessor() {
555+
auto &original = getOriginal();
556+
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
557+
switch (accessor->getAccessorKind()) {
558+
case AccessorKind::Get:
559+
return runForSemanticMemberGetter();
560+
case AccessorKind::Set:
561+
return runForSemanticMemberSetter();
562+
// TODO(SR-12640): Support `modify` accessors.
563+
default:
564+
llvm_unreachable("Unsupported accessor kind; inconsistent with "
565+
"`isSemanticMemberAccessor`?");
566+
}
567+
}
568+
569+
bool PullbackEmitter::runForSemanticMemberGetter() {
570+
auto &original = getOriginal();
571+
auto &pullback = getPullback();
572+
auto pbLoc = getPullback().getLocation();
573+
574+
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
575+
assert(accessor->getAccessorKind() == AccessorKind::Get);
576+
577+
auto origExitIt = original.findReturnBB();
578+
assert(origExitIt != original.end() &&
579+
"Functions without returns must have been diagnosed");
580+
auto *origExit = &*origExitIt;
581+
builder.setInsertionPoint(pullback.getEntryBlock());
582+
583+
// Get getter argument and result values.
584+
// Getter type: $(Self) -> Result
585+
// Pullback type: $(Result', PB_Struct) -> Self'
586+
assert(original.getLoweredFunctionType()->getNumParameters() == 1);
587+
assert(pullback.getLoweredFunctionType()->getNumParameters() == 2);
588+
assert(pullback.getLoweredFunctionType()->getNumResults() == 1);
589+
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
590+
591+
SmallVector<SILValue, 8> origFormalResults;
592+
collectAllFormalResultsInTypeOrder(original, origFormalResults);
593+
auto origResult = origFormalResults[getIndices().source];
594+
595+
// TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct.
596+
auto tangentVectorSILTy = pullback.getConventions().getSingleSILResultType();
597+
auto tangentVectorTy = tangentVectorSILTy.getASTType();
598+
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
599+
600+
// Look up the corresponding field in the tangent space.
601+
VarDecl *origField = cast<VarDecl>(accessor->getStorage());
602+
VarDecl *tanField = nullptr;
603+
auto tanFieldLookup = tangentVectorDecl->lookupDirect(origField->getName());
604+
if (tanFieldLookup.empty()) {
605+
getContext().emitNondifferentiabilityError(
606+
pbLoc.getSourceLoc(), getInvoker(),
607+
diag::autodiff_stored_property_no_corresponding_tangent,
608+
origSelf->getType().getASTType().getString(), origField->getNameStr());
609+
errorOccurred = true;
610+
return true;
611+
}
612+
tanField = cast<VarDecl>(tanFieldLookup.front());
613+
614+
// Switch based on the base tangent struct's value category.
615+
// TODO(TF-1255): Simplify using unified adjoint value data structure.
616+
switch (tangentVectorSILTy.getCategory()) {
617+
case SILValueCategory::Object: {
618+
auto adjResult = getAdjointValue(origExit, origResult);
619+
switch (adjResult.getKind()) {
620+
case AdjointValueKind::Zero:
621+
addAdjointValue(origExit, origSelf,
622+
makeZeroAdjointValue(tangentVectorSILTy), pbLoc);
623+
break;
624+
case AdjointValueKind::Concrete:
625+
case AdjointValueKind::Aggregate: {
626+
SmallVector<AdjointValue, 8> eltVals;
627+
for (auto *field : tangentVectorDecl->getStoredProperties()) {
628+
if (field == tanField) {
629+
eltVals.push_back(adjResult);
630+
} else {
631+
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
632+
field->getModuleContext(), field);
633+
auto fieldTy = field->getType().subst(substMap);
634+
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
635+
assert(fieldSILTy.isObject());
636+
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
637+
}
638+
}
639+
addAdjointValue(origExit, origSelf,
640+
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
641+
pbLoc);
642+
}
643+
}
644+
break;
645+
}
646+
case SILValueCategory::Address: {
647+
assert(pullback.getIndirectResults().size() == 1);
648+
auto pbIndRes = pullback.getIndirectResults().front();
649+
auto *adjSelf = createFunctionLocalAllocation(
650+
pbIndRes->getType().getObjectType(), pbLoc);
651+
setAdjointBuffer(origExit, origSelf, adjSelf);
652+
for (auto *field : tangentVectorDecl->getStoredProperties()) {
653+
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field);
654+
if (field == tanField) {
655+
// Switch based on the property's value category.
656+
// TODO(TF-1255): Simplify using unified adjoint value data structure.
657+
switch (origResult->getType().getCategory()) {
658+
case SILValueCategory::Object: {
659+
auto adjResult = getAdjointValue(origExit, origResult);
660+
auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc);
661+
builder.emitStoreValueOperation(pbLoc, adjResultValue, adjSelfElt,
662+
StoreOwnershipQualifier::Init);
663+
break;
664+
}
665+
case SILValueCategory::Address: {
666+
auto adjResult = getAdjointBuffer(origExit, origResult);
667+
builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake,
668+
IsInitialization);
669+
destroyedLocalAllocations.insert(adjResult);
670+
break;
671+
}
672+
}
673+
} else {
674+
auto fieldType = pullback.mapTypeIntoContext(field->getInterfaceType())
675+
->getCanonicalType();
676+
emitZeroIndirect(fieldType, adjSelfElt, pbLoc);
677+
}
678+
}
679+
break;
680+
}
681+
}
682+
return false;
683+
}
684+
685+
bool PullbackEmitter::runForSemanticMemberSetter() {
686+
auto &original = getOriginal();
687+
auto &pullback = getPullback();
688+
auto pbLoc = getPullback().getLocation();
689+
690+
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
691+
assert(accessor->getAccessorKind() == AccessorKind::Set);
692+
693+
auto origEntry = original.getEntryBlock();
694+
builder.setInsertionPoint(pullback.getEntryBlock());
695+
696+
// Get setter argument values.
697+
// Setter type: $(inout Self, Argument) -> ()
698+
// Pullback type (wrt self): $(inout Self', PB_Struct) -> ()
699+
// Pullback type (wrt both): $(inout Self', PB_Struct) -> Argument'
700+
assert(original.getLoweredFunctionType()->getNumParameters() == 2);
701+
assert(pullback.getLoweredFunctionType()->getNumParameters() == 2);
702+
assert(pullback.getLoweredFunctionType()->getNumResults() == 0 ||
703+
pullback.getLoweredFunctionType()->getNumResults() == 1);
704+
705+
SILValue origArg = original.getArgumentsWithoutIndirectResults()[0];
706+
SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1];
707+
708+
// TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct.
709+
auto tangentVectorSILTy = pullback.getLoweredFunctionType()
710+
->getParameters()[0]
711+
.getSILStorageInterfaceType();
712+
assert(tangentVectorSILTy.getCategory() == SILValueCategory::Address);
713+
auto tangentVectorTy = tangentVectorSILTy.getASTType();
714+
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
715+
716+
// Look up the corresponding field in the tangent space.
717+
VarDecl *origField = cast<VarDecl>(accessor->getStorage());
718+
VarDecl *tanField = nullptr;
719+
auto tanFieldLookup = tangentVectorDecl->lookupDirect(origField->getName());
720+
if (tanFieldLookup.empty()) {
721+
getContext().emitNondifferentiabilityError(
722+
pbLoc.getSourceLoc(), getInvoker(),
723+
diag::autodiff_stored_property_no_corresponding_tangent,
724+
origSelf->getType().getASTType().getString(), origField->getNameStr());
725+
errorOccurred = true;
726+
return true;
727+
}
728+
tanField = cast<VarDecl>(tanFieldLookup.front());
729+
730+
auto adjSelf = getAdjointBuffer(origEntry, origSelf);
731+
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
732+
// Switch based on the property's value category.
733+
// TODO(TF-1255): Simplify using unified adjoint value data structure.
734+
switch (origArg->getType().getCategory()) {
735+
case SILValueCategory::Object: {
736+
auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt,
737+
LoadOwnershipQualifier::Take);
738+
setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg));
739+
break;
740+
}
741+
case SILValueCategory::Address: {
742+
addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc);
743+
builder.emitDestroyOperation(pbLoc, adjSelfElt);
744+
break;
745+
}
746+
}
747+
emitZeroIndirect(adjSelfElt->getType().getASTType(), adjSelfElt, pbLoc);
748+
749+
return false;
750+
}
751+
514752
//--------------------------------------------------------------------------//
515753
// Entry point
516754
//--------------------------------------------------------------------------//
@@ -756,12 +994,21 @@ bool PullbackEmitter::run() {
756994
<< " as the adjoint of original result " << origResult);
757995
}
758996

997+
// If the original function is an accessor with special-case pullback
998+
// generation logic, do special-case generation.
999+
if (isSemanticMemberAccessor(&original)) {
1000+
if (runForSemanticMemberAccessor())
1001+
return true;
1002+
}
1003+
// Otherwise, perform standard pullback generation.
7591004
// Visit original blocks blocks in post-order and perform differentiation
7601005
// in corresponding pullback blocks. If errors occurred, back out.
761-
for (auto *bb : postOrderPostDomOrder) {
762-
visitSILBasicBlock(bb);
763-
if (errorOccurred)
764-
return true;
1006+
else {
1007+
for (auto *bb : postOrderPostDomOrder) {
1008+
visitSILBasicBlock(bb);
1009+
if (errorOccurred)
1010+
return true;
1011+
}
7651012
}
7661013

7671014
// Prepare and emit a `return` in the pullback exit block.
@@ -1343,8 +1590,7 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
13431590
auto adjStruct = materializeAdjointDirect(std::move(av), loc);
13441591
// Find the struct `TangentVector` type.
13451592
auto structTy = remapType(si->getType()).getASTType();
1346-
auto tangentVectorTy =
1347-
getTangentSpace(structTy)->getType()->getCanonicalType();
1593+
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
13481594
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
13491595
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
13501596
assert(tangentVectorDecl);
@@ -1406,8 +1652,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
14061652
"differentiated; activity analysis should not marked as varied");
14071653
auto *bb = sei->getParent();
14081654
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
1409-
auto tangentVectorTy =
1410-
getTangentSpace(structTy)->getType()->getCanonicalType();
1655+
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
14111656
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
14121657
auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy);
14131658
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
@@ -1449,8 +1694,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
14491694
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
14501695
field->getModuleContext(), field);
14511696
auto fieldTy = field->getType().subst(substMap);
1452-
auto fieldSILTy = getContext().getTypeConverter().getLoweredType(
1453-
fieldTy, TypeExpansionContext::minimal());
1697+
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
14541698
assert(fieldSILTy.isObject());
14551699
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
14561700
}
@@ -1466,8 +1710,7 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
14661710
auto *bb = reai->getParent();
14671711
auto adjBuf = getAdjointBuffer(bb, reai);
14681712
auto classTy = remapType(reai->getOperand()->getType()).getASTType();
1469-
auto tangentVectorTy =
1470-
getTangentSpace(classTy)->getType()->getCanonicalType();
1713+
auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType();
14711714
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
14721715
auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy);
14731716
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
@@ -1498,8 +1741,7 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
14981741
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
14991742
field->getModuleContext(), field);
15001743
auto fieldTy = field->getType().subst(substMap);
1501-
auto fieldSILTy = getContext().getTypeConverter().getLoweredType(
1502-
fieldTy, TypeExpansionContext::minimal());
1744+
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
15031745
assert(fieldSILTy.isObject());
15041746
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
15051747
}

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,11 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
4242
for (auto *vd : nominal->getStoredProperties()) {
4343
// Peer through property wrappers: use original wrapped properties instead.
4444
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
45-
// Skip property wrappers that do not define `wrappedValue.set`.
45+
// Skip wrapped properties that do not define a setter. They define a setter
46+
// only if property wrappers define a setter for `var wrappedValue`.
4647
// `mutating func move(along:)` cannot be synthesized to update these
4748
// properties.
48-
auto *wrapperDecl =
49-
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
50-
auto *wrappedValueDecl =
51-
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
52-
if (!wrappedValueDecl->getAccessor(AccessorKind::Set))
49+
if (!originalProperty->getAccessor(AccessorKind::Set))
5350
continue;
5451
// Use the original wrapped property.
5552
vd = originalProperty;

0 commit comments

Comments
 (0)