Skip to content

Commit 487648a

Browse files
committed
[Autodiff] Fixes bugs in closure-spec opt that were causing "optimized" test failures on Linux builds
1 parent 993f7c3 commit 487648a

File tree

2 files changed

+281
-198
lines changed

2 files changed

+281
-198
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction,
428428
possibleMarkDependenceBases.insert(pai)
429429
rootClosurePossibleLiveRange.insert(use.instruction)
430430
haveUsedReabstraction = true
431-
} else {
431+
} else if pai.isPullbackInResultOfAutodiffVJP {
432432
rootClosureApplies.pushIfNotVisited(use)
433433
}
434434

@@ -677,6 +677,9 @@ private func markConvertedAndReabstractedClosuresAsUsed(rootClosure: Value, conv
677677
markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, convertedAndReabstractedClosure: mdi.value,
678678
convertedAndReabstractedClosures: &convertedAndReabstractedClosures)
679679
default:
680+
log("Parent function of callSite: \(rootClosure.parentFunction)")
681+
log("Root closure: \(rootClosure)")
682+
log("Converted/reabstracted closure: \(convertedAndReabstractedClosure)")
680683
fatalError("While marking converted/reabstracted closures as used, found unexpected instruction: \(convertedAndReabstractedClosure)")
681684
}
682685
}
@@ -814,7 +817,7 @@ private extension SpecializationCloner {
814817
if closureArgDesc.isClosureGuaranteed || closureArgDesc.parameterInfo.isTrivialNoescapeClosure,
815818
!allClonedReleasableClosures.isEmpty
816819
{
817-
for exitBlock in closureArgDesc.reachableExitBBs {
820+
for exitBlock in callSite.reachableExitBBsInCallee {
818821
let clonedExitBlock = self.getClonedBlock(for: exitBlock)
819822

820823
let terminator = clonedExitBlock.terminator is UnreachableInst
@@ -824,12 +827,8 @@ private extension SpecializationCloner {
824827
let builder = Builder(before: terminator, self.context)
825828

826829
for closure in allClonedReleasableClosures {
827-
if let pai = closure as? PartialApplyInst,
828-
pai.isOnStack
829-
{
830-
builder.destroyPartialApplyOnStack(paiOnStack: pai, self.context)
831-
} else {
832-
builder.createReleaseValue(operand: closure)
830+
if let pai = closure as? PartialApplyInst {
831+
builder.destroyPartialApply(pai: pai, self.context)
833832
}
834833
}
835834
}
@@ -960,6 +959,9 @@ private extension Builder {
960959
&releasableClonedReabstractedClosures, &origToClonedValueMap)
961960

962961
guard let function = pai.referencedFunction else {
962+
log("Parent function of callSite: \(rootClosure.parentFunction)")
963+
log("Root closure: \(rootClosure)")
964+
log("Unsupported reabstraction closure: \(pai)")
963965
fatalError("Encountered unsupported reabstraction (via partial_apply) of root closure!")
964966
}
965967

@@ -982,6 +984,9 @@ private extension Builder {
982984
return reabstracted
983985

984986
default:
987+
log("Parent function of callSite: \(rootClosure.parentFunction)")
988+
log("Root closure: \(rootClosure)")
989+
log("Converted/reabstracted closure: \(reabstractedClosure)")
985990
fatalError("Encountered unsupported reabstraction of root closure: \(reabstractedClosure)")
986991
}
987992
}
@@ -993,27 +998,32 @@ private extension Builder {
993998
return (finalClonedReabstractedClosure as! SingleValueInstruction, releasableClonedReabstractedClosures)
994999
}
9951000

996-
func destroyPartialApplyOnStack(paiOnStack: PartialApplyInst, _ context: FunctionPassContext){
997-
precondition(paiOnStack.isOnStack, "Function must only be called for `partial_apply`s on stack!")
998-
1001+
func destroyPartialApply(pai: PartialApplyInst, _ context: FunctionPassContext){
9991002
// TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization
10001003
// passes.
1001-
//
1002-
// for arg in paiOnStack.arguments {
1003-
// self.createDestroyValue(operand: arg)
1004-
// }
10051004

1006-
// self.createDestroyValue(operand: paiOnStack)
1005+
if pai.isOnStack {
1006+
// for arg in pai.arguments {
1007+
// self.createDestroyValue(operand: arg)
1008+
// }
1009+
// self.createDestroyValue(operand: pai)
10071010

1008-
if paiOnStack.parentFunction.hasOwnership {
1011+
if pai.parentFunction.hasOwnership {
10091012
// Under OSSA, the closure acts as an owned value whose lifetime is a borrow scope for the captures, so we need to
10101013
// end the borrow scope before ending the lifetimes of the captures themselves.
1011-
self.createDestroyValue(operand: paiOnStack)
1012-
self.destroyCapturedArgs(for: paiOnStack)
1014+
self.createDestroyValue(operand: pai)
1015+
self.destroyCapturedArgs(for: pai)
1016+
} else {
1017+
self.destroyCapturedArgs(for: pai)
1018+
self.createDeallocStack(pai)
1019+
context.notifyInvalidatedStackNesting()
1020+
}
10131021
} else {
1014-
self.destroyCapturedArgs(for: paiOnStack)
1015-
self.createDeallocStack(paiOnStack)
1016-
context.notifyInvalidatedStackNesting()
1022+
if pai.parentFunction.hasOwnership {
1023+
self.createDestroyValue(operand: pai)
1024+
} else {
1025+
self.createReleaseValue(operand: pai)
1026+
}
10171027
}
10181028
}
10191029
}
@@ -1107,9 +1117,11 @@ private extension PartialApplyInst {
11071117
}
11081118

11091119
var isPartialApplyOfThunk: Bool {
1110-
if self.numArguments == 1 || self.numArguments == 2,
1120+
if self.numArguments == 1,
11111121
let fun = self.referencedFunction,
1112-
fun.thunkKind == .reabstractionThunk || fun.thunkKind == .thunk
1122+
fun.thunkKind == .reabstractionThunk || fun.thunkKind == .thunk,
1123+
self.arguments[0].type.isFunction,
1124+
self.arguments[0].type.isReferenceCounted(in: self.parentFunction) || self.callee.type.isThickFunction
11131125
{
11141126
return true
11151127
}
@@ -1279,10 +1291,6 @@ private struct ClosureArgDescriptor {
12791291
var isClosureConsumed: Bool {
12801292
closureParamInfo.convention.isConsumed
12811293
}
1282-
1283-
var reachableExitBBs: [BasicBlock] {
1284-
closure.parentFunction.blocks.filter { $0.isReachableExitBlock }
1285-
}
12861294
}
12871295

12881296
/// Represents a callsite containing one or more closure arguments.
@@ -1302,6 +1310,10 @@ private struct CallSite {
13021310
applySite.referencedFunction!
13031311
}
13041312

1313+
var reachableExitBBsInCallee: [BasicBlock] {
1314+
applyCallee.blocks.filter { $0.isReachableExitBlock }
1315+
}
1316+
13051317
func hasClosureArg(at index: Int) -> Bool {
13061318
closureArgDescriptors.contains { $0.closureArgumentIndex == index }
13071319
}

0 commit comments

Comments
 (0)