Skip to content

Commit 6e5958f

Browse files
committed
AutoDiff: Simplify closure discriminator assignment
Closure discriminators don't need to be unique across function bodies, so we can always set it to 0 here instead of using DiscriminatorFinder.
1 parent 43fd786 commit 6e5958f

File tree

2 files changed

+30
-36
lines changed

2 files changed

+30
-36
lines changed

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,7 @@ deriveBodyDifferentiable_zeroTangentVectorInitializer(
388388
DeclNameLoc(), /*Implicit*/ true);
389389

390390
// Create closure expression.
391-
DiscriminatorFinder DF;
392-
for (Decl *D : parentDC->getParentSourceFile()->getTopLevelDecls())
393-
D->walk(DF);
394-
auto discriminator = DF.getNextDiscriminator();
391+
unsigned discriminator = 0;
395392
auto resultTy = funcDecl->getMethodInterfaceType()
396393
->castTo<AnyFunctionType>()
397394
->getResult();
@@ -502,10 +499,7 @@ deriveBodyDifferentiable_zeroTangentVectorInitializer(
502499

503500
// Create closure expression:
504501
// `{ TangentVector(x: x_zeroTangentVectorInitializer(), ...) }`.
505-
DiscriminatorFinder DF;
506-
for (Decl *D : parentDC->getParentSourceFile()->getTopLevelDecls())
507-
D->walk(DF);
508-
auto discriminator = DF.getNextDiscriminator();
502+
unsigned discriminator = 0;
509503
auto resultTy = funcDecl->getMethodInterfaceType()
510504
->castTo<AnyFunctionType>()
511505
->getResult();

test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ enum CustomTangentVectorEnum<T: Differentiable>: Differentiable {
137137
// CHECK: [[Y_PROP:%.*]] = struct_extract [[SELF]] : $SelfTangentVectorStruct, #SelfTangentVectorStruct.y
138138
// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
139139
// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]])
140-
// CHECK: // function_ref closure #2 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter
141-
// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU0_
140+
// CHECK: // function_ref closure #1 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter
141+
// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU_
142142
// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]]
143143
// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]]
144144
// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]])
@@ -148,8 +148,8 @@ enum CustomTangentVectorEnum<T: Differentiable>: Differentiable {
148148
// CHECK-LABEL: // CustomTangentVectorStruct.zeroTangentVectorInitializer.getter
149149
// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0Qzycvg : $@convention(method) <T, U where T : Differentiable, U : Differentiable> (@in_guaranteed CustomTangentVectorStruct<T, U>) -> @owned @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for <T.TangentVector> {
150150
// CHECK: bb0([[SELF:%.*]] : $*CustomTangentVectorStruct<T, U>):
151-
// CHECK: // function_ref closure #3 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter
152-
// CHECK: function_ref @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU1_
151+
// CHECK: // function_ref closure #1 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter
152+
// CHECK: function_ref @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU_
153153
// CHECK: }
154154

155155
// CHECK-LABEL: // MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter
@@ -163,8 +163,8 @@ enum CustomTangentVectorEnum<T: Differentiable>: Differentiable {
163163
// CHECK: [[Y_PROP:%.*]] = apply [[Y_PROP_METHOD]]([[SELF]])
164164
// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double
165165
// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]])
166-
// CHECK: // function_ref closure #4 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter
167-
// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU2_
166+
// CHECK: // function_ref closure #1 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter
167+
// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU_
168168
// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]]
169169
// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]]
170170
// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]])
@@ -174,29 +174,29 @@ enum CustomTangentVectorEnum<T: Differentiable>: Differentiable {
174174
// CHECK-LABEL: // SelfTangentVectorClass.zeroTangentVectorInitializer.getter
175175
// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvg : $@convention(method) (@guaranteed SelfTangentVectorClass) -> @owned @callee_guaranteed () -> @owned SelfTangentVectorClass {
176176
// CHECK: bb0([[SELF:%.*]] : @guaranteed $SelfTangentVectorClass):
177-
// CHECK: // function_ref closure #5 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter
178-
// CHECK: function_ref @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU3_
177+
// CHECK: // function_ref closure #1 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter
178+
// CHECK: function_ref @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU_
179179
// CHECK: }
180180

181181
// CHECK-LABEL: // CustomTangentVectorClass.zeroTangentVectorInitializer.getter
182182
// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0Qzycvg : $@convention(method) <T, U where T : Differentiable, U : Differentiable> (@guaranteed CustomTangentVectorClass<T, U>) -> @owned @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for <T.TangentVector> {
183183
// CHECK: bb0(%0 : @guaranteed $CustomTangentVectorClass<T, U>):
184-
// CHECK: // function_ref closure #6 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter
185-
// CHECK: function_ref @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU4_
184+
// CHECK: // function_ref closure #1 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter
185+
// CHECK: function_ref @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU_
186186
// CHECK: }
187187

188188
// CHECK-LABEL: // SelfTangentVectorEnum.zeroTangentVectorInitializer.getter
189189
// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvg : $@convention(method) (@guaranteed SelfTangentVectorEnum) -> @owned @callee_guaranteed () -> @owned SelfTangentVectorEnum {
190190
// CHECK: bb0([[SELF:%.*]] : @guaranteed $SelfTangentVectorEnum):
191-
// CHECK: // function_ref closure #7 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter
192-
// CHECK: function_ref @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU5_
191+
// CHECK: // function_ref closure #1 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter
192+
// CHECK: function_ref @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU_
193193
// CHECK: }
194194

195195
// CHECK-LABEL: // CustomTangentVectorEnum.zeroTangentVectorInitializer.getter
196196
// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorEnumO0bgH11Initializer0gH0Qzycvg : $@convention(method) <T where T : Differentiable> (@in_guaranteed CustomTangentVectorEnum<T>) -> @owned @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for <T.TangentVector> {
197197
// CHECK: bb0([[SELF:%.*]] : $*CustomTangentVectorEnum<T>):
198-
// CHECK: // function_ref closure #8 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter
199-
// CHECK: function_ref @${{.*}}CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU6_
198+
// CHECK: // function_ref closure #1 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter
199+
// CHECK: function_ref @${{.*}}CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU_
200200
// CHECK: }
201201

202202
// CHECK-LABEL: // closure #1 in MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter
@@ -206,44 +206,44 @@ enum CustomTangentVectorEnum<T: Differentiable>: Differentiable {
206206
// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter
207207
// CHECK: }
208208

209-
// CHECK-LABEL: // closure #2 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter
210-
// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU0_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> SelfTangentVectorStruct {
209+
// CHECK-LABEL: // closure #1 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter
210+
// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> SelfTangentVectorStruct {
211211
// CHECK: // function_ref SelfTangentVectorStruct.init(x:y:)
212212
// CHECK-NOT: // function_ref static {{.*}}.zero.getter
213213
// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter
214214
// CHECK: }
215215

216-
// CHECK-LABEL: // closure #3 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter
217-
// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU1_ : $@convention(thin) <T, U where T : Differentiable, U : Differentiable> () -> @out T.TangentVector {
216+
// CHECK-LABEL: // closure #1 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter
217+
// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU_ : $@convention(thin) <T, U where T : Differentiable, U : Differentiable> () -> @out T.TangentVector {
218218
// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter
219219
// CHECK: }
220220

221-
// CHECK-LABEL: // closure #4 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter
222-
// CHECK-NEXT: sil private [ossa] @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU2_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> MemberwiseTangentVectorClass.TangentVector {
221+
// CHECK-LABEL: // closure #1 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter
222+
// CHECK-NEXT: sil private [ossa] @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> MemberwiseTangentVectorClass.TangentVector {
223223
// CHECK: // function_ref MemberwiseTangentVectorClass.TangentVector.init(x:y:)
224224
// CHECK-NOT: // function_ref static {{.*}}.zero.getter
225225
// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter
226226
// CHECK: }
227227

228-
// CHECK-LABEL: // closure #5 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter
229-
// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU3_ : $@convention(thin) () -> @owned SelfTangentVectorClass {
228+
// CHECK-LABEL: // closure #1 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter
229+
// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU_ : $@convention(thin) () -> @owned SelfTangentVectorClass {
230230
// CHECK: // function_ref static SelfTangentVectorClass.zero.getter
231231
// CHECK: function_ref @${{.*}}SelfTangentVectorClassC0B0ACXDvgZ : $@convention(method) (@thick SelfTangentVectorClass.Type) -> @owned SelfTangentVectorClass
232232
// CHECK: }
233233

234-
// CHECK-LABEL: // closure #6 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter
235-
// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU4_ : $@convention(thin) <T, U where T : Differentiable, U : Differentiable> () -> @out T.TangentVector {
234+
// CHECK-LABEL: // closure #1 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter
235+
// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU_ : $@convention(thin) <T, U where T : Differentiable, U : Differentiable> () -> @out T.TangentVector {
236236
// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter
237237
// CHECK: }
238238

239239
// TODO(TF-1012): Implement memberwise `zeroTangentVectorInitializer` synthesis for enums.
240-
// CHECK-LABEL: // closure #7 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter
241-
// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU5_ : $@convention(thin) () -> @owned SelfTangentVectorEnum {
240+
// CHECK-LABEL: // closure #1 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter
241+
// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU_ : $@convention(thin) () -> @owned SelfTangentVectorEnum {
242242
// CHECK: // function_ref static SelfTangentVectorEnum.zero.getter
243243
// CHECK: function_ref @${{.*}}SelfTangentVectorEnumO0B0ACvgZ : $@convention(method) (@thin SelfTangentVectorEnum.Type) -> @owned SelfTangentVectorEnum
244244
// CHECK: }
245245

246-
// CHECK-LABEL: // closure #8 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter
247-
// CHECK-NEXT: sil private [ossa] @$s39derived_zero_tangent_vector_initializer23CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU6_ : $@convention(thin) <T where T : Differentiable> () -> @out T.TangentVector {
246+
// CHECK-LABEL: // closure #1 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter
247+
// CHECK-NEXT: sil private [ossa] @$s39derived_zero_tangent_vector_initializer23CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU_ : $@convention(thin) <T where T : Differentiable> () -> @out T.TangentVector {
248248
// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter
249249
// CHECK: }

0 commit comments

Comments
 (0)