Skip to content

Commit b169760

Browse files
authored
Merge pull request swiftlang#31700 from dan-zheng/wrapped-property-differentiation
[AutoDiff] Fix semantic member accessor pullback ownership errors.
2 parents 7a238c6 + cde1d18 commit b169760

File tree

5 files changed

+161
-16
lines changed

5 files changed

+161
-16
lines changed

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -540,11 +540,9 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
540540
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
541541
assert(accessor->getAccessorKind() == AccessorKind::Get);
542542

543-
auto origExitIt = original.findReturnBB();
544-
assert(origExitIt != original.end() &&
545-
"Functions without returns must have been diagnosed");
546-
auto *origExit = &*origExitIt;
547-
builder.setInsertionPoint(pullback.getEntryBlock());
543+
auto *origEntry = original.getEntryBlock();
544+
auto *pbEntry = pullback.getEntryBlock();
545+
builder.setInsertionPoint(pbEntry);
548546

549547
// Get getter argument and result values.
550548
// Getter type: $(Self) -> Result
@@ -582,10 +580,10 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
582580
// TODO(TF-1255): Simplify using unified adjoint value data structure.
583581
switch (tangentVectorSILTy.getCategory()) {
584582
case SILValueCategory::Object: {
585-
auto adjResult = getAdjointValue(origExit, origResult);
583+
auto adjResult = getAdjointValue(origEntry, origResult);
586584
switch (adjResult.getKind()) {
587585
case AdjointValueKind::Zero:
588-
addAdjointValue(origExit, origSelf,
586+
addAdjointValue(origEntry, origSelf,
589587
makeZeroAdjointValue(tangentVectorSILTy), pbLoc);
590588
break;
591589
case AdjointValueKind::Concrete:
@@ -603,7 +601,7 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
603601
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
604602
}
605603
}
606-
addAdjointValue(origExit, origSelf,
604+
addAdjointValue(origEntry, origSelf,
607605
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
608606
pbLoc);
609607
}
@@ -615,22 +613,24 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
615613
auto pbIndRes = pullback.getIndirectResults().front();
616614
auto *adjSelf = createFunctionLocalAllocation(
617615
pbIndRes->getType().getObjectType(), pbLoc);
618-
setAdjointBuffer(origExit, origSelf, adjSelf);
616+
setAdjointBuffer(origEntry, origSelf, adjSelf);
619617
for (auto *field : tangentVectorDecl->getStoredProperties()) {
620618
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field);
621619
if (field == tanField) {
622620
// Switch based on the property's value category.
623621
// TODO(TF-1255): Simplify using unified adjoint value data structure.
624622
switch (origResult->getType().getCategory()) {
625623
case SILValueCategory::Object: {
626-
auto adjResult = getAdjointValue(origExit, origResult);
624+
auto adjResult = getAdjointValue(origEntry, origResult);
627625
auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc);
628-
builder.emitStoreValueOperation(pbLoc, adjResultValue, adjSelfElt,
626+
auto adjResultValueCopy =
627+
builder.emitCopyValueOperation(pbLoc, adjResultValue);
628+
builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt,
629629
StoreOwnershipQualifier::Init);
630630
break;
631631
}
632632
case SILValueCategory::Address: {
633-
auto adjResult = getAdjointBuffer(origExit, origResult);
633+
auto adjResult = getAdjointBuffer(origEntry, origResult);
634634
builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake,
635635
IsInitialization);
636636
destroyedLocalAllocations.insert(adjResult);
@@ -657,8 +657,9 @@ bool PullbackEmitter::runForSemanticMemberSetter() {
657657
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
658658
assert(accessor->getAccessorKind() == AccessorKind::Set);
659659

660-
auto origEntry = original.getEntryBlock();
661-
builder.setInsertionPoint(pullback.getEntryBlock());
660+
auto *origEntry = original.getEntryBlock();
661+
auto *pbEntry = pullback.getEntryBlock();
662+
builder.setInsertionPoint(pbEntry);
662663

663664
// Get setter argument values.
664665
// Setter type: $(inout Self, Argument) -> ()
@@ -703,6 +704,7 @@ bool PullbackEmitter::runForSemanticMemberSetter() {
703704
auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt,
704705
LoadOwnershipQualifier::Take);
705706
setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg));
707+
blockTemporaries[pbEntry].insert(adjArg);
706708
break;
707709
}
708710
case SILValueCategory::Address: {
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import _Differentiation
2+
3+
/// A non-trivial, loadable type.
4+
///
5+
/// Used to test differentiation transform coverage.
6+
struct NontrivialLoadable<T> {
7+
fileprivate class Box {
8+
fileprivate var value: T
9+
init(_ value: T) {
10+
self.value = value
11+
}
12+
}
13+
private var handle: Box
14+
15+
init(_ value: T) {
16+
self.handle = Box(value)
17+
}
18+
19+
var value: T {
20+
get { handle.value }
21+
set { handle.value = newValue }
22+
}
23+
}
24+
25+
extension NontrivialLoadable: ExpressibleByFloatLiteral
26+
where T: ExpressibleByFloatLiteral {
27+
init(floatLiteral value: T.FloatLiteralType) {
28+
self.handle = Box(T(floatLiteral: value))
29+
}
30+
}
31+
32+
extension NontrivialLoadable: ExpressibleByIntegerLiteral
33+
where T: ExpressibleByIntegerLiteral {
34+
init(integerLiteral value: T.IntegerLiteralType) {
35+
self.handle = Box(T(integerLiteral: value))
36+
}
37+
}
38+
39+
extension NontrivialLoadable: Equatable where T: Equatable {
40+
static func == (lhs: NontrivialLoadable, rhs: NontrivialLoadable) -> Bool {
41+
return lhs.value == rhs.value
42+
}
43+
}
44+
45+
extension NontrivialLoadable: AdditiveArithmetic where T: AdditiveArithmetic {
46+
static var zero: NontrivialLoadable { return NontrivialLoadable(T.zero) }
47+
static func + (lhs: NontrivialLoadable, rhs: NontrivialLoadable)
48+
-> NontrivialLoadable
49+
{
50+
return NontrivialLoadable(lhs.value + rhs.value)
51+
}
52+
static func - (lhs: NontrivialLoadable, rhs: NontrivialLoadable)
53+
-> NontrivialLoadable
54+
{
55+
return NontrivialLoadable(lhs.value - rhs.value)
56+
}
57+
}
58+
59+
extension NontrivialLoadable: Differentiable
60+
where T: Differentiable, T == T.TangentVector {
61+
typealias TangentVector = NontrivialLoadable<T.TangentVector>
62+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s %S/Inputs/nontrivial_loadable_type.swift
2+
3+
// Test property wrapper differentiation coverage for a variety of property
4+
// types: trivial, non-trivial loadable, and address-only.
5+
6+
import DifferentiationUnittest
7+
8+
// MARK: Property wrappers
9+
10+
@propertyWrapper
11+
struct SimpleWrapper<Value> {
12+
var wrappedValue: Value // stored property
13+
}
14+
15+
@propertyWrapper
16+
struct Wrapper<Value> {
17+
private var value: Value
18+
var wrappedValue: Value { // computed property
19+
get { value }
20+
set { value = newValue }
21+
}
22+
23+
init(wrappedValue: Value) {
24+
self.value = wrappedValue
25+
}
26+
}
27+
28+
// MARK: Types with wrapped properties
29+
30+
struct Struct: Differentiable {
31+
@Wrapper @SimpleWrapper var trivial: Float = 10
32+
@Wrapper @SimpleWrapper var tracked: Tracked<Float> = 20
33+
@Wrapper @SimpleWrapper var nontrivial: NontrivialLoadable<Float> = 30
34+
35+
static func testGetters() {
36+
let _: @differentiable (Self) -> Float = { $0.trivial }
37+
let _: @differentiable (Self) -> Tracked<Float> = { $0.tracked }
38+
let _: @differentiable (Self) -> NontrivialLoadable<Float> = { $0.nontrivial }
39+
}
40+
41+
static func testSetters() {
42+
let _: @differentiable (inout Self, Float) -> Void =
43+
{ $0.trivial = $1 }
44+
let _: @differentiable (inout Self, Tracked<Float>) -> Void =
45+
{ $0.tracked = $1 }
46+
let _: @differentiable (inout Self, NontrivialLoadable<Float>) -> Void =
47+
{ $0.nontrivial = $1 }
48+
}
49+
}
50+
51+
struct GenericStruct<T: Differentiable>: Differentiable {
52+
@Wrapper @SimpleWrapper var trivial: Float = 10
53+
@Wrapper @SimpleWrapper var tracked: Tracked<Float> = 20
54+
@Wrapper @SimpleWrapper var nontrivial: NontrivialLoadable<Float> = 30
55+
@Wrapper @SimpleWrapper var addressOnly: T
56+
57+
// SR-12778: Test getter pullback for non-trivial loadable property.
58+
static func testGetters() {
59+
let _: @differentiable (Self) -> Float = { $0.trivial }
60+
let _: @differentiable (Self) -> Tracked<Float> = { $0.tracked }
61+
let _: @differentiable (Self) -> NontrivialLoadable<Float> = { $0.nontrivial }
62+
let _: @differentiable (Self) -> T = { $0.addressOnly }
63+
}
64+
65+
// SR-12779: Test setter pullback for non-trivial loadable property.
66+
static func testSetters() {
67+
let _: @differentiable (inout Self, Float) -> Void =
68+
{ $0.trivial = $1 }
69+
let _: @differentiable (inout Self, Tracked<Float>) -> Void =
70+
{ $0.tracked = $1 }
71+
let _: @differentiable (inout Self, NontrivialLoadable<Float>) -> Void =
72+
{ $0.nontrivial = $1 }
73+
let _: @differentiable (inout Self, T) -> Void =
74+
{ $0.addressOnly = $1 }
75+
}
76+
}

test/AutoDiff/SILOptimizer/semantic_member_accessors_sil.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func trigger<T: Differentiable>(_ x: T.Type) {
4444
// CHECK-LABEL: // differentiability witness for Struct.x.getter
4545
// CHECK-NEXT: sil_differentiability_witness private [parameters 0] [results 0] @$s4null6StructV1xSfvg : $@convention(method) (Struct) -> Float {
4646

47-
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvs__pullback_src_0_wrt_0_1_16_Differentiation14DifferentiableRzl : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@inout Generic<τ_0_0>.TangentVector, @owned {{.*}}) -> @out τ_0_0.TangentVector {
47+
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvs__pullback_src_0_wrt_0_1_{{16_Differentiation|s}}14DifferentiableRzl : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@inout Generic<τ_0_0>.TangentVector, @owned {{.*}}) -> @out τ_0_0.TangentVector {
4848
// CHECK: bb0([[ADJ_X_RESULT:%.*]] : $*τ_0_0.TangentVector, [[ADJ_SELF:%.*]] : $*Generic<τ_0_0>.TangentVector, {{.*}} : {{.*}}):
4949
// CHECK: [[ADJ_X_TMP:%.*]] = alloc_stack $τ_0_0.TangentVector
5050
// CHECK: [[ZERO_FN:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
@@ -60,7 +60,7 @@ func trigger<T: Differentiable>(_ x: T.Type) {
6060
// CHECK: return {{.*}} : $()
6161
// CHECK: }
6262

63-
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvg__pullback_src_0_wrt_0_16_Differentiation14DifferentiableRzl : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned {{.*}}) -> @out Generic<τ_0_0>.TangentVector {
63+
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvg__pullback_src_0_wrt_0_{{16_Differentiation|s}}14DifferentiableRzl : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned {{.*}}) -> @out Generic<τ_0_0>.TangentVector {
6464
// CHECK: bb0([[ADJ_SELF_RESULT:%.*]] : $*Generic<τ_0_0>.TangentVector, [[SEED:%.*]] : $*τ_0_0.TangentVector, {{.*}} : ${{.*}}):
6565
// CHECK: [[ADJ_SELF_TMP:%.*]] = alloc_stack $Generic<τ_0_0>.TangentVector
6666
// CHECK: [[SEED_COPY:%.*]] = alloc_stack $τ_0_0.TangentVector

test/AutoDiff/validation-test/property_wrappers.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ PropertyWrapperTests.test("SimpleClass") {
141141
*/
142142

143143
// From: https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md#proposed-solution
144+
// Tests the following functionality:
145+
// - Enum property wrapper.
144146
@propertyWrapper
145147
enum Lazy<Value> {
146148
case uninitialized(() -> Value)
@@ -151,10 +153,13 @@ enum Lazy<Value> {
151153
}
152154

153155
var wrappedValue: Value {
156+
// TODO(TF-1250): Replace with actual mutating getter implementation.
157+
// Requires differentiation to support functions with multiple results.
154158
get {
155159
switch self {
156160
case .uninitialized(let initializer):
157161
let value = initializer()
162+
// NOTE: Actual implementation assigns to `self` here.
158163
return value
159164
case .initialized(let value):
160165
return value

0 commit comments

Comments
 (0)