@@ -428,7 +428,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction,
428
428
possibleMarkDependenceBases. insert ( pai)
429
429
rootClosurePossibleLiveRange. insert ( use. instruction)
430
430
haveUsedReabstraction = true
431
- } else {
431
+ } else if pai . isPullbackInResultOfAutodiffVJP {
432
432
rootClosureApplies. pushIfNotVisited ( use)
433
433
}
434
434
@@ -677,6 +677,9 @@ private func markConvertedAndReabstractedClosuresAsUsed(rootClosure: Value, conv
677
677
markConvertedAndReabstractedClosuresAsUsed ( rootClosure: rootClosure, convertedAndReabstractedClosure: mdi. value,
678
678
convertedAndReabstractedClosures: & convertedAndReabstractedClosures)
679
679
default :
680
+ log ( " Parent function of callSite: \( rootClosure. parentFunction) " )
681
+ log ( " Root closure: \( rootClosure) " )
682
+ log ( " Converted/reabstracted closure: \( convertedAndReabstractedClosure) " )
680
683
fatalError ( " While marking converted/reabstracted closures as used, found unexpected instruction: \( convertedAndReabstractedClosure) " )
681
684
}
682
685
}
@@ -814,7 +817,7 @@ private extension SpecializationCloner {
814
817
if closureArgDesc. isClosureGuaranteed || closureArgDesc. parameterInfo. isTrivialNoescapeClosure,
815
818
!allClonedReleasableClosures. isEmpty
816
819
{
817
- for exitBlock in closureArgDesc . reachableExitBBs {
820
+ for exitBlock in callSite . reachableExitBBsInCallee {
818
821
let clonedExitBlock = self . getClonedBlock ( for: exitBlock)
819
822
820
823
let terminator = clonedExitBlock. terminator is UnreachableInst
@@ -824,12 +827,8 @@ private extension SpecializationCloner {
824
827
let builder = Builder ( before: terminator, self . context)
825
828
826
829
for closure in allClonedReleasableClosures {
827
- if let pai = closure as? PartialApplyInst ,
828
- pai. isOnStack
829
- {
830
- builder. destroyPartialApplyOnStack ( paiOnStack: pai, self . context)
831
- } else {
832
- builder. createReleaseValue ( operand: closure)
830
+ if let pai = closure as? PartialApplyInst {
831
+ builder. destroyPartialApply ( pai: pai, self . context)
833
832
}
834
833
}
835
834
}
@@ -960,6 +959,9 @@ private extension Builder {
960
959
& releasableClonedReabstractedClosures, & origToClonedValueMap)
961
960
962
961
guard let function = pai. referencedFunction else {
962
+ log ( " Parent function of callSite: \( rootClosure. parentFunction) " )
963
+ log ( " Root closure: \( rootClosure) " )
964
+ log ( " Unsupported reabstraction closure: \( pai) " )
963
965
fatalError ( " Encountered unsupported reabstraction (via partial_apply) of root closure! " )
964
966
}
965
967
@@ -982,6 +984,9 @@ private extension Builder {
982
984
return reabstracted
983
985
984
986
default :
987
+ log ( " Parent function of callSite: \( rootClosure. parentFunction) " )
988
+ log ( " Root closure: \( rootClosure) " )
989
+ log ( " Converted/reabstracted closure: \( reabstractedClosure) " )
985
990
fatalError ( " Encountered unsupported reabstraction of root closure: \( reabstractedClosure) " )
986
991
}
987
992
}
@@ -993,27 +998,32 @@ private extension Builder {
993
998
return ( finalClonedReabstractedClosure as! SingleValueInstruction , releasableClonedReabstractedClosures)
994
999
}
995
1000
996
- func destroyPartialApplyOnStack( paiOnStack: PartialApplyInst , _ context: FunctionPassContext ) {
997
- precondition ( paiOnStack. isOnStack, " Function must only be called for `partial_apply`s on stack! " )
998
-
1001
+ func destroyPartialApply( pai: PartialApplyInst , _ context: FunctionPassContext ) {
999
1002
// TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization
1000
1003
// passes.
1001
- //
1002
- // for arg in paiOnStack.arguments {
1003
- // self.createDestroyValue(operand: arg)
1004
- // }
1005
1004
1006
- // self.createDestroyValue(operand: paiOnStack)
1005
+ if pai. isOnStack {
1006
+ // for arg in pai.arguments {
1007
+ // self.createDestroyValue(operand: arg)
1008
+ // }
1009
+ // self.createDestroyValue(operand: pai)
1007
1010
1008
- if paiOnStack . parentFunction. hasOwnership {
1011
+ if pai . parentFunction. hasOwnership {
1009
1012
// Under OSSA, the closure acts as an owned value whose lifetime is a borrow scope for the captures, so we need to
1010
1013
// end the borrow scope before ending the lifetimes of the captures themselves.
1011
- self . createDestroyValue ( operand: paiOnStack)
1012
- self . destroyCapturedArgs ( for: paiOnStack)
1014
+ self . createDestroyValue ( operand: pai)
1015
+ self . destroyCapturedArgs ( for: pai)
1016
+ } else {
1017
+ self . destroyCapturedArgs ( for: pai)
1018
+ self . createDeallocStack ( pai)
1019
+ context. notifyInvalidatedStackNesting ( )
1020
+ }
1013
1021
} else {
1014
- self . destroyCapturedArgs ( for: paiOnStack)
1015
- self . createDeallocStack ( paiOnStack)
1016
- context. notifyInvalidatedStackNesting ( )
1022
+ if pai. parentFunction. hasOwnership {
1023
+ self . createDestroyValue ( operand: pai)
1024
+ } else {
1025
+ self . createReleaseValue ( operand: pai)
1026
+ }
1017
1027
}
1018
1028
}
1019
1029
}
@@ -1107,9 +1117,11 @@ private extension PartialApplyInst {
1107
1117
}
1108
1118
1109
1119
var isPartialApplyOfThunk : Bool {
1110
- if self . numArguments == 1 || self . numArguments == 2 ,
1120
+ if self . numArguments == 1 ,
1111
1121
let fun = self . referencedFunction,
1112
- fun. thunkKind == . reabstractionThunk || fun. thunkKind == . thunk
1122
+ fun. thunkKind == . reabstractionThunk || fun. thunkKind == . thunk,
1123
+ self . arguments [ 0 ] . type. isFunction,
1124
+ self . arguments [ 0 ] . type. isReferenceCounted ( in: self . parentFunction) || self . callee. type. isThickFunction
1113
1125
{
1114
1126
return true
1115
1127
}
@@ -1279,10 +1291,6 @@ private struct ClosureArgDescriptor {
1279
1291
var isClosureConsumed : Bool {
1280
1292
closureParamInfo. convention. isConsumed
1281
1293
}
1282
-
1283
- var reachableExitBBs : [ BasicBlock ] {
1284
- closure. parentFunction. blocks. filter { $0. isReachableExitBlock }
1285
- }
1286
1294
}
1287
1295
1288
1296
/// Represents a callsite containing one or more closure arguments.
@@ -1302,6 +1310,10 @@ private struct CallSite {
1302
1310
applySite. referencedFunction!
1303
1311
}
1304
1312
1313
+ var reachableExitBBsInCallee : [ BasicBlock ] {
1314
+ applyCallee. blocks. filter { $0. isReachableExitBlock }
1315
+ }
1316
+
1305
1317
func hasClosureArg( at index: Int ) -> Bool {
1306
1318
closureArgDescriptors. contains { $0. closureArgumentIndex == index }
1307
1319
}
0 commit comments