Skip to content

Commit 96f3f6f

Browse files
committed
[AutoDiff] Finish wrapped property differentiation support.
Add special-case VJP generation support for "semantic member accessors". This is necessary to avoid activity analysis related diagnostics and simplifies generated code. Fix "wrapped property mutability" check in `Differentiable` derived conformnances. This resolves SR-12642. Add e2e test using real world property wrappers (`@Lazy` and `@Clamping`).
1 parent ff9cd41 commit 96f3f6f

File tree

10 files changed

+283
-132
lines changed

10 files changed

+283
-132
lines changed

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,23 @@ ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
6161
/// tuple-typed and such a user exists.
6262
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
6363

64+
/// Returns true if the given original function is a "semantic member accessor".
65+
///
66+
/// "Semantic member accessors" are attached to member properties that have a
67+
/// corresponding tangent stored property in the parent `TangentVector` type.
68+
/// These accessors have special-case pullback generation based on their
69+
/// semantic behavior.
70+
///
71+
/// "Semantic member accessors" currently include:
72+
/// - Stored property accessors. These are implicitly generated.
73+
/// - Property wrapper wrapped value accessors. These are implicitly generated
74+
/// and internally call `var wrappedValue`.
75+
bool isSemanticMemberAccessor(SILFunction *original);
76+
77+
/// Returns true if the given apply site has a "semantic member accessor"
78+
/// callee.
79+
bool hasSemanticMemberAccessorCallee(ApplySite applySite);
80+
6481
/// Given a full apply site, apply the given callback to each of its
6582
/// "direct results".
6683
///

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
324324
/// These accessors have special-case pullback generation based on their
325325
/// semantic behavior.
326326
///
327-
/// "Semantic member accessors" currently include:
328-
/// - Stored property accessors. These are implicitly generated.
329-
/// - Property wrapper wrapped value accessors. These are implicitly generated
330-
/// and internally call `var wrappedValue`.
331-
///
332327
/// Returns true if any error occurs.
333328
bool runForSemanticMemberAccessor();
334329
bool runForSemanticMemberGetter();

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,44 @@ DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
7171
return result;
7272
}
7373

74+
bool isSemanticMemberAccessor(SILFunction *original) {
75+
auto *dc = original->getDeclContext();
76+
if (!dc)
77+
return false;
78+
auto *decl = dc->getAsDecl();
79+
if (!decl)
80+
return false;
81+
auto *accessor = dyn_cast<AccessorDecl>(decl);
82+
if (!accessor)
83+
return false;
84+
// Currently, only getters and setters are supported.
85+
// TODO(SR-12640): Support `modify` accessors.
86+
if (accessor->getAccessorKind() != AccessorKind::Get &&
87+
accessor->getAccessorKind() != AccessorKind::Set)
88+
return false;
89+
// Accessor must come from a `var` declaration.
90+
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());
91+
if (!varDecl)
92+
return false;
93+
// Return true for stored property accessors.
94+
if (varDecl->hasStorage() && varDecl->isInstanceMember())
95+
return true;
96+
// Return true for properties that have attached property wrappers.
97+
if (varDecl->hasAttachedPropertyWrapper())
98+
return true;
99+
// Otherwise, return false.
100+
// User-defined accessors can never be supported because they may use custom
101+
// logic that does not semantically perform a member access.
102+
return false;
103+
}
104+
105+
bool hasSemanticMemberAccessorCallee(ApplySite applySite) {
106+
if (auto *FRI = dyn_cast<FunctionRefBaseInst>(applySite.getCallee()))
107+
if (auto *F = FRI->getReferencedFunctionOrNull())
108+
return isSemanticMemberAccessor(F);
109+
return false;
110+
}
111+
74112
void forEachApplyDirectResult(
75113
FullApplySite applySite,
76114
llvm::function_ref<void(SILValue)> resultCallback) {

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -517,40 +517,6 @@ void PullbackEmitter::addToAdjointBuffer(SILBasicBlock *origBB,
517517
// Member accessor pullback generation
518518
//--------------------------------------------------------------------------//
519519

520-
/// Returns true if the given original function is a "semantic member accessor".
521-
static bool isSemanticMemberAccessor(SILFunction *original) {
522-
auto *dc = original->getDeclContext();
523-
if (!dc)
524-
return false;
525-
auto *decl = dc->getAsDecl();
526-
if (!decl)
527-
return false;
528-
auto *accessor = dyn_cast<AccessorDecl>(decl);
529-
if (!accessor)
530-
return false;
531-
// Currently, only getters and setters are supported.
532-
// TODO(SR-12640): Support `modify` accessors.
533-
if (accessor->getAccessorKind() != AccessorKind::Get &&
534-
accessor->getAccessorKind() != AccessorKind::Set)
535-
return false;
536-
// Accessor must come from a `var` declaration.
537-
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());
538-
if (!varDecl)
539-
return false;
540-
// Return true for stored property accessors.
541-
if (varDecl->hasStorage())
542-
return true;
543-
// Return true for properties that have attached property wrappers.
544-
if (varDecl->hasAttachedPropertyWrapper())
545-
return true;
546-
// Otherwise, return false.
547-
// User-defined accessors can never be supported because they may use custom
548-
// logic that does not semantically perform a member access.
549-
// TODO(SR-12636): Support `@differentiable(useInTangentVector)` computed
550-
// properties.
551-
return false;
552-
}
553-
554520
bool PullbackEmitter::runForSemanticMemberAccessor() {
555521
auto &original = getOriginal();
556522
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
@@ -593,7 +559,8 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
593559
auto origResult = origFormalResults[getIndices().source];
594560

595561
// TODO(TF-970): Emit diagnostic when `TangentVector` is not a struct.
596-
auto tangentVectorSILTy = pullback.getConventions().getSingleSILResultType();
562+
auto tangentVectorSILTy = pullback.getConventions().getSingleSILResultType(
563+
TypeExpansionContext::minimal());
597564
auto tangentVectorTy = tangentVectorSILTy.getASTType();
598565
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
599566

@@ -806,15 +773,6 @@ bool PullbackEmitter::run() {
806773
return;
807774
visited.insert(v);
808775
auto type = v->getType();
809-
// Diagnose active enum values. Differentiation of enum values requires
810-
// special adjoint value handling and is not yet supported. Diagnose
811-
// only the first active enum value to prevent too many diagnostics.
812-
if (!diagnosedActiveEnumValue && type.getEnumOrBoundGenericEnum()) {
813-
getContext().emitNondifferentiabilityError(
814-
v, getInvoker(), diag::autodiff_enums_unsupported);
815-
errorOccurred = true;
816-
diagnosedActiveEnumValue = true;
817-
}
818776
// Diagnose active values whose value category is incompatible with their
819777
// tangent types's value category.
820778
//
@@ -842,6 +800,19 @@ bool PullbackEmitter::run() {
842800
}
843801
}
844802
}
803+
// Do not emit remaining activity-related diagnostics for semantic member
804+
// accessors, which have special-case pullback generation.
805+
if (isSemanticMemberAccessor(&original))
806+
return;
807+
// Diagnose active enum values. Differentiation of enum values requires
808+
// special adjoint value handling and is not yet supported. Diagnose
809+
// only the first active enum value to prevent too many diagnostics.
810+
if (!diagnosedActiveEnumValue && type.getEnumOrBoundGenericEnum()) {
811+
getContext().emitNondifferentiabilityError(
812+
v, getInvoker(), diag::autodiff_enums_unsupported);
813+
errorOccurred = true;
814+
diagnosedActiveEnumValue = true;
815+
}
845816
// Skip address projections.
846817
// Address projections do not need their own adjoint buffers; they
847818
// become projections into their adjoint base buffer.

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -511,14 +511,35 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
511511
}
512512

513513
void VJPEmitter::visitApplyInst(ApplyInst *ai) {
514-
// If the function should not be differentiated or its the array literal
515-
// initialization intrinsic, just do standard cloning.
516-
if (!pullbackInfo.shouldDifferentiateApplySite(ai) ||
517-
isArrayLiteralIntrinsic(ai)) {
514+
// If callee should not be differentiated, do standard cloning.
515+
if (!pullbackInfo.shouldDifferentiateApplySite(ai)) {
518516
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
519517
TypeSubstCloner::visitApplyInst(ai);
520518
return;
521519
}
520+
// If callee is the array literal initialization intrinsic, do standard
521+
// cloning. Array literal differentiation is handled separately.
522+
if (isArrayLiteralIntrinsic(ai)) {
523+
LLVM_DEBUG(getADDebugStream() << "Cloning array literal intrinsic `apply`\n"
524+
<< *ai << '\n');
525+
TypeSubstCloner::visitApplyInst(ai);
526+
return;
527+
}
528+
// If the original function is a semantic member accessor, do standard
529+
// cloning. Semantic member accessors have special pullback generation logic,
530+
// so all `apply` instructions can be directly cloned to the VJP.
531+
if (isSemanticMemberAccessor(original)) {
532+
LLVM_DEBUG(getADDebugStream()
533+
<< "Cloning `apply` in semantic member accessor:\n"
534+
<< *ai << '\n');
535+
TypeSubstCloner::visitApplyInst(ai);
536+
return;
537+
}
538+
539+
auto loc = ai->getLoc();
540+
auto &builder = getBuilder();
541+
auto origCallee = getOpValue(ai->getCallee());
542+
auto originalFnTy = origCallee->getType().castTo<SILFunctionType>();
522543

523544
LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n');
524545

@@ -558,31 +579,27 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
558579
activeParamIndices));
559580

560581
// Emit the VJP.
561-
auto loc = ai->getLoc();
562-
auto &builder = getBuilder();
563-
auto original = getOpValue(ai->getCallee());
564582
SILValue vjpValue;
565583
// If functionSource is a `@differentiable` function, just extract it.
566-
auto originalFnTy = original->getType().castTo<SILFunctionType>();
567584
if (originalFnTy->isDifferentiable()) {
568585
auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices();
569586
for (auto i : indices.parameters->getIndices()) {
570587
if (!paramIndices->contains(i)) {
571588
context.emitNondifferentiabilityError(
572-
original, invoker,
589+
origCallee, invoker,
573590
diag::autodiff_function_noderivative_parameter_not_differentiable);
574591
errorOccurred = true;
575592
return;
576593
}
577594
}
578-
auto origFnType = original->getType().castTo<SILFunctionType>();
595+
auto origFnType = origCallee->getType().castTo<SILFunctionType>();
579596
auto origFnUnsubstType = origFnType->getUnsubstitutedType(getModule());
580597
if (origFnType != origFnUnsubstType) {
581-
original = builder.createConvertFunction(
582-
loc, original, SILType::getPrimitiveObjectType(origFnUnsubstType),
598+
origCallee = builder.createConvertFunction(
599+
loc, origCallee, SILType::getPrimitiveObjectType(origFnUnsubstType),
583600
/*withoutActuallyEscaping*/ false);
584601
}
585-
auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original);
602+
auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee);
586603
vjpValue = builder.createDifferentiableFunctionExtract(
587604
loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedDiffFunc);
588605
vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
@@ -624,7 +641,7 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
624641
}
625642
if (!remappedResultType.isDifferentiable(getModule())) {
626643
context.emitNondifferentiabilityError(
627-
original, invoker, diag::autodiff_nondifferentiable_result);
644+
origCallee, invoker, diag::autodiff_nondifferentiable_result);
628645
errorOccurred = true;
629646
return true;
630647
}
@@ -645,7 +662,7 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
645662
/*
646663
DifferentiationInvoker indirect(ai, attr);
647664
auto insertion =
648-
context.getInvokers().try_emplace({this->original, attr}, indirect);
665+
context.getInvokers().try_emplace({original, attr}, indirect);
649666
auto &invoker = insertion.first->getSecond();
650667
invoker = indirect;
651668
*/
@@ -656,21 +673,21 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
656673
// function operand is specialized with a remapped version of same
657674
// substitution map using an argument-less `partial_apply`.
658675
if (ai->getSubstitutionMap().empty()) {
659-
original = builder.emitCopyValueOperation(loc, original);
676+
origCallee = builder.emitCopyValueOperation(loc, origCallee);
660677
} else {
661678
auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap());
662679
auto vjpPartialApply = getBuilder().createPartialApply(
663-
ai->getLoc(), original, substMap, {},
680+
ai->getLoc(), origCallee, substMap, {},
664681
ParameterConvention::Direct_Guaranteed);
665-
original = vjpPartialApply;
666-
originalFnTy = original->getType().castTo<SILFunctionType>();
682+
origCallee = vjpPartialApply;
683+
originalFnTy = origCallee->getType().castTo<SILFunctionType>();
667684
// Diagnose if new original function type is non-differentiable.
668685
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
669686
return;
670687
}
671688

672689
auto *diffFuncInst = context.createDifferentiableFunction(
673-
getBuilder(), loc, indices.parameters, original);
690+
getBuilder(), loc, indices.parameters, origCallee);
674691

675692
// Record the `differentiable_function` instruction.
676693
context.addDifferentiableFunctionInstToWorklist(diffFuncInst);

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
4242
for (auto *vd : nominal->getStoredProperties()) {
4343
// Peer through property wrappers: use original wrapped properties instead.
4444
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
45-
// Skip wrapped properties that do not define a setter. They define a setter
46-
// only if property wrappers define a setter for `var wrappedValue`.
47-
// `mutating func move(along:)` cannot be synthesized to update these
48-
// properties.
49-
if (!originalProperty->getAccessor(AccessorKind::Set))
45+
// Skip immutable wrapped properties. `mutating func move(along:)` cannot
46+
// be synthesized to update these properties.
47+
auto mutability = originalProperty->getPropertyWrapperMutability();
48+
assert(mutability.hasValue() && "Expected wrapped property mutability");
49+
if (mutability->Setter != PropertyWrapperMutability::Value::Mutating)
5050
continue;
5151
// Use the original wrapped property.
5252
vd = originalProperty;
@@ -510,9 +510,9 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
510510
// to update these properties.
511511
auto *wrapperDecl =
512512
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
513-
auto *wrappedValueDecl =
514-
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
515-
if (!wrappedValueDecl->getAccessor(AccessorKind::Set)) {
513+
auto mutability = originalProperty->getPropertyWrapperMutability();
514+
assert(mutability.hasValue() && "Expected wrapped property mutability");
515+
if (mutability->Setter != PropertyWrapperMutability::Value::Mutating) {
516516
auto loc =
517517
originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
518518
Context.Diags

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,30 @@ let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }
177177

178178
@propertyWrapper
179179
struct Wrapper<Value> {
180-
var wrappedValue: Value
180+
private var value: Value
181+
var wrappedValue: Value {
182+
get { value }
183+
set { value = newValue }
184+
}
181185
var projectedValue: Self { self }
186+
187+
init(wrappedValue: Value) {
188+
self.value = wrappedValue
189+
}
182190
}
183191

184192
@propertyWrapper
185193
struct DifferentiableWrapper<Value> {
186-
var wrappedValue: Value
194+
private var value: Value
195+
var wrappedValue: Value {
196+
get { value }
197+
set { value = newValue }
198+
}
187199
var projectedValue: Self { self }
200+
201+
init(wrappedValue: Value) {
202+
self.value = wrappedValue
203+
}
188204
}
189205
extension DifferentiableWrapper: Differentiable where Value: Differentiable {}
190206
// Note: property wrapped value differentiation works even if wrapper types do

test/AutoDiff/compiler_crashers/sr12642-differentiable-derivation-redeclared-property.swift

Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %target-swift-frontend -typecheck -verify %s
2+
// REQUIRES: asserts
3+
4+
// SR-12642: Crash regarding `Differentiable` derived conformances and
5+
// redeclared properties. This crash surfaced only briefly during the
6+
// implementation of wrapped property differentiation (SR-12637).
7+
8+
import _Differentiation
9+
10+
@propertyWrapper
11+
struct Wrapper<Value> {
12+
var wrappedValue: Value
13+
}
14+
15+
struct Generic<T> {}
16+
extension Generic: Differentiable where T: Differentiable {}
17+
18+
struct WrappedProperties: Differentiable {
19+
// expected-note @+2 {{'int' previously declared here}}
20+
// expected-warning @+1 {{stored property 'int' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
21+
@Wrapper var int: Generic<Int>
22+
23+
// expected-error @+1 {{invalid redeclaration of 'int'}}
24+
@Wrapper var int: Generic<Int>
25+
}

0 commit comments

Comments
 (0)