Skip to content

Commit 4bba194

Browse files
committed
[region-isolation] Add support for differentiability instructions.
1 parent 3473108 commit 4bba194

File tree

2 files changed

+171
-5
lines changed

2 files changed

+171
-5
lines changed

lib/SILOptimizer/Analysis/RegionAnalysis.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,13 @@ static bool isStaticallyLookThroughInst(SILInstruction *inst) {
196196
case SILInstructionKind::CopyableToMoveOnlyWrapperValueInst:
197197
case SILInstructionKind::DestructureStructInst:
198198
case SILInstructionKind::DestructureTupleInst:
199+
case SILInstructionKind::DifferentiableFunctionExtractInst:
199200
case SILInstructionKind::DropDeinitInst:
200201
case SILInstructionKind::EndCOWMutationInst:
201202
case SILInstructionKind::EndInitLetRefInst:
202203
case SILInstructionKind::ExplicitCopyValueInst:
203204
case SILInstructionKind::InitEnumDataAddrInst:
205+
case SILInstructionKind::LinearFunctionExtractInst:
204206
case SILInstructionKind::MarkDependenceInst:
205207
case SILInstructionKind::MarkUninitializedInst:
206208
case SILInstructionKind::MarkUnresolvedNonCopyableValueInst:
@@ -2465,11 +2467,16 @@ CONSTANT_TRANSLATION(InitExistentialValueInst, Unhandled)
24652467
CONSTANT_TRANSLATION(InitExistentialMetatypeInst, Unhandled)
24662468
CONSTANT_TRANSLATION(OpenExistentialMetatypeInst, Unhandled)
24672469
CONSTANT_TRANSLATION(OpenExistentialValueInst, Unhandled)
2468-
CONSTANT_TRANSLATION(DifferentiableFunctionInst, Unhandled)
2469-
CONSTANT_TRANSLATION(LinearFunctionInst, Unhandled)
2470-
CONSTANT_TRANSLATION(DifferentiableFunctionExtractInst, Unhandled)
2471-
CONSTANT_TRANSLATION(LinearFunctionExtractInst, Unhandled)
2472-
CONSTANT_TRANSLATION(DifferentiabilityWitnessFunctionInst, Unhandled)
2470+
2471+
//===---
2472+
// Differentiable
2473+
//
2474+
2475+
CONSTANT_TRANSLATION(DifferentiabilityWitnessFunctionInst, AssignFresh)
2476+
CONSTANT_TRANSLATION(DifferentiableFunctionExtractInst, LookThrough)
2477+
CONSTANT_TRANSLATION(LinearFunctionExtractInst, LookThrough)
2478+
CONSTANT_TRANSLATION(LinearFunctionInst, Assign)
2479+
CONSTANT_TRANSLATION(DifferentiableFunctionInst, Assign)
24732480

24742481
//===---
24752482
// Packs
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// RUN: %target-sil-opt -transfer-non-sendable -enable-experimental-feature RegionBasedIsolation -strict-concurrency=complete %s -verify -o /dev/null
2+
3+
// REQUIRES: concurrency
4+
// REQUIRES: asserts
5+
6+
sil_stage raw
7+
8+
import Swift
9+
import Builtin
10+
import _Differentiation
11+
12+
////////////////////////
13+
// MARK: Declarations //
14+
////////////////////////
15+
16+
class NonSendableKlass {}
17+
18+
sil @transferNonSendableKlass : $@convention(thin) @async (@guaranteed NonSendableKlass) -> ()
19+
sil @useNonSendableKlass : $@convention(thin) (@guaranteed NonSendableKlass) -> ()
20+
sil @constructNonSendableKlass : $@convention(thin) () -> @owned NonSendableKlass
21+
22+
final class SendableKlass : Sendable {}
23+
24+
sil @transferSendableKlass : $@convention(thin) @async (@guaranteed SendableKlass) -> ()
25+
sil @constructSendableKlass : $@convention(thin) () -> @owned SendableKlass
26+
27+
final class KlassContainingKlasses {
28+
let nsImmutable : NonSendableKlass
29+
var nsMutable : NonSendableKlass
30+
let sImmutable : SendableKlass
31+
var sMutable : SendableKlass
32+
}
33+
34+
sil @transferKlassContainingKlasses : $@convention(thin) @async (@guaranteed KlassContainingKlasses) -> ()
35+
sil @useKlassContainingKlasses : $@convention(thin) (@guaranteed KlassContainingKlasses) -> ()
36+
sil @constructKlassContainingKlasses : $@convention(thin) () -> @owned KlassContainingKlasses
37+
38+
@_moveOnly
39+
struct NonSendableMoveOnlyStruct {
40+
var ns: NonSendableKlass
41+
42+
deinit
43+
}
44+
45+
sil @constructMoveOnlyStruct : $@convention(thin) () -> @owned NonSendableMoveOnlyStruct
46+
sil @transferMoveOnlyStruct : $@convention(thin) @async (@guaranteed NonSendableMoveOnlyStruct) -> ()
47+
48+
struct NonSendableStruct {
49+
var ns: NonSendableKlass
50+
}
51+
52+
sil @constructStruct : $@convention(thin) () -> @owned NonSendableStruct
53+
sil @transferStruct : $@convention(thin) @async (@guaranteed NonSendableStruct) -> ()
54+
55+
sil @transferRawPointer : $@convention(thin) @async (Builtin.RawPointer) -> ()
56+
sil @useRawPointer : $@convention(thin) (Builtin.RawPointer) -> ()
57+
sil @initRawPointer : $@convention(thin) () -> Builtin.RawPointer
58+
59+
sil @transferIndirect : $@convention(thin) @async <τ_0_0> (@in_guaranteed τ_0_0) -> ()
60+
sil @useIndirect : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> ()
61+
sil @initIndirect : $@convention(thin) <T> () -> @out T
62+
63+
enum FakeOptional<T> {
64+
case none
65+
case some(T)
66+
}
67+
68+
sil @getLinearFunction : $@convention(thin) () -> @owned @differentiable(_linear) @callee_guaranteed (Float) -> Float
69+
sil @transferLinearFunction : $@async @convention(thin) (@guaranteed @differentiable(_linear) @callee_guaranteed (Float) -> Float) -> ()
70+
sil @getDifferentiableFunction : $@convention(thin) () -> @owned @differentiable(reverse) @callee_guaranteed (Float) -> Float
71+
sil @transferDifferentiableFunction : $@async @convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float) -> Float) -> ()
72+
sil @getFunction : $@convention(thin) () -> @owned @callee_guaranteed (Float) -> Float
73+
sil @useFunction : $@convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> ()
74+
sil @transferFunction : $@async @convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> ()
75+
76+
/////////////////
77+
// MARK: Tests //
78+
/////////////////
79+
80+
sil [ossa] @linear_function_extract_test : $@async @convention(thin) () -> () {
81+
bb0:
82+
%init = function_ref @getLinearFunction : $@convention(thin) () -> @owned @differentiable(_linear) @callee_guaranteed (Float) -> Float
83+
%0 = apply %init() : $@convention(thin) () -> @owned @differentiable(_linear) @callee_guaranteed (Float) -> Float
84+
85+
%transfer = function_ref @transferLinearFunction : $@async @convention(thin) (@guaranteed @differentiable(_linear) @callee_guaranteed (Float) -> Float) -> ()
86+
apply [caller_isolation=nonisolated] [callee_isolation=global_actor] %transfer(%0) : $@async @convention(thin) (@guaranteed @differentiable(_linear) @callee_guaranteed (Float) -> Float) -> () // expected-warning {{transferring value of non-Sendable type '@differentiable(_linear) @callee_guaranteed (Float) -> Float' from nonisolated context to global actor '<null>'-isolated context}}
87+
88+
// No error since this is just look through.
89+
%1 = begin_borrow %0 : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
90+
%2 = linear_function_extract [original] %1 : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
91+
%f = function_ref @useFunction : $@convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> ()
92+
apply %f(%2) : $@convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> () // expected-note {{access here could race}}
93+
end_borrow %1 : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
94+
destroy_value %0 : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
95+
96+
%9999 = tuple ()
97+
return %9999 : $()
98+
}
99+
100+
sil [ossa] @differentiable_function_extract_test : $@async @convention(thin) () -> () {
101+
bb0:
102+
%init = function_ref @getDifferentiableFunction : $@convention(thin) () -> @owned @differentiable(reverse) @callee_guaranteed (Float) -> Float
103+
%0 = apply %init() : $@convention(thin) () -> @owned @differentiable(reverse) @callee_guaranteed (Float) -> Float
104+
105+
%transfer = function_ref @transferDifferentiableFunction : $@async @convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float) -> Float) -> ()
106+
apply [caller_isolation=nonisolated] [callee_isolation=global_actor] %transfer(%0) : $@async @convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float) -> Float) -> () // expected-warning {{transferring value of non-Sendable type '@differentiable(reverse) @callee_guaranteed (Float) -> Float' from nonisolated context to global actor '<null>'-isolated context}}
107+
108+
// No error since this is just look through.
109+
%1 = begin_borrow %0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
110+
%2 = differentiable_function_extract [original] %1 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
111+
%f = function_ref @useFunction : $@convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> ()
112+
apply %f(%2) : $@convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> () // expected-note {{access here could race}}
113+
end_borrow %1 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
114+
destroy_value %0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
115+
116+
%9999 = tuple ()
117+
return %9999 : $()
118+
}
119+
120+
sil [ossa] @linear_function_test : $@async @convention(thin) () -> () {
121+
bb0:
122+
%0 = function_ref @getFunction : $@convention(thin) () -> @owned @callee_guaranteed (Float) -> Float
123+
%1 = apply %0() : $@convention(thin) () -> @owned @callee_guaranteed (Float) -> Float
124+
125+
%transfer = function_ref @transferFunction : $@async @convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> ()
126+
apply [caller_isolation=nonisolated] [callee_isolation=global_actor] %transfer(%1) : $@async @convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> () // expected-warning {{transferring value of non-Sendable type '@callee_guaranteed (Float) -> Float' from nonisolated context to global actor '<null>'-isolated context}}
127+
128+
%2 = begin_borrow %1 : $@callee_guaranteed (Float) -> Float
129+
%3 = linear_function [parameters 0] %2 : $@callee_guaranteed (Float) -> Float // expected-note {{access here could race}}
130+
131+
end_borrow %2 : $@callee_guaranteed (Float) -> Float
132+
destroy_value %1 : $@callee_guaranteed (Float) -> Float
133+
%9999 = tuple ()
134+
return %9999 : $()
135+
}
136+
137+
sil @derivative : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
138+
sil @vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
139+
140+
// differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[VJP_FN]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
141+
sil [ossa] @differentiable_function_test : $@async @convention(thin) () -> () {
142+
bb0:
143+
%0 = function_ref @getFunction : $@convention(thin) () -> @owned @callee_guaranteed (Float) -> Float
144+
%1 = apply %0() : $@convention(thin) () -> @owned @callee_guaranteed (Float) -> Float
145+
146+
%transfer = function_ref @transferFunction : $@async @convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> ()
147+
apply [caller_isolation=nonisolated] [callee_isolation=global_actor] %transfer(%1) : $@async @convention(thin) (@guaranteed @callee_guaranteed (Float) -> Float) -> () // expected-warning {{transferring value of non-Sendable type '@callee_guaranteed (Float) -> Float' from nonisolated context to global actor '<null>'-isolated context}}
148+
149+
%derivative = function_ref @derivative : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
150+
%vjp = function_ref @vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
151+
%2 = begin_borrow %1 : $@callee_guaranteed (Float) -> Float
152+
%3 = differentiable_function [parameters 0] [results 0] %2 : $@callee_guaranteed (Float) -> Float with_derivative {%derivative : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} // expected-note {{access here could race}}
153+
154+
end_borrow %2 : $@callee_guaranteed (Float) -> Float
155+
destroy_value %1 : $@callee_guaranteed (Float) -> Float
156+
157+
%9999 = tuple ()
158+
return %9999 : $()
159+
}

0 commit comments

Comments
 (0)