Skip to content

Commit 74166a4

Browse files
committed
[Autodiff] Moves bridging code accesses in closure-spec opt behind APIs
Addresses some other surfacial feedback as well.
1 parent ab751d5 commit 74166a4

File tree

8 files changed

+96
-201
lines changed

8 files changed

+96
-201
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 22 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -185,27 +185,17 @@ private func getOrCreateSpecializedFunction(basedOn callSite: CallSite, _ contex
185185
let applySiteCallee = callSite.applyCallee
186186
let specializedParameters = applySiteCallee.convention.getSpecializedParameters(basedOn: callSite)
187187

188-
let createFn = { (functionPassContext: FunctionPassContext) in
189-
specializedFunctionName._withBridgedStringRef { nameRef in
190-
let bridgedParamInfos = specializedParameters.map { $0._bridged }
191-
192-
return bridgedParamInfos.withUnsafeBufferPointer { paramBuf in
193-
functionPassContext
194-
._bridged
195-
.ClosureSpecializer_createEmptyFunctionWithSpecializedSignature(nameRef, paramBuf.baseAddress, paramBuf.count,
196-
applySiteCallee.bridged,
197-
applySiteCallee.isSerialized)
198-
.function
199-
}
200-
}
201-
}
188+
let specializedFunction =
189+
context.createFunctionForClosureSpecialization(from: applySiteCallee, withName: specializedFunctionName,
190+
withParams: specializedParameters,
191+
withSerialization: applySiteCallee.isSerialized)
202192

203-
let buildFn = { (emptySpecializedFunction, functionPassContext) in
204-
let closureSpecCloner = SpecializationCloner(emptySpecializedFunction: emptySpecializedFunction, functionPassContext)
205-
closureSpecCloner.cloneAndSpecializeFunctionBody(using: callSite)
206-
}
193+
context.buildSpecializedFunction(specializedFunction: specializedFunction,
194+
buildFn: { (emptySpecializedFunction, functionPassContext) in
195+
let closureSpecCloner = SpecializationCloner(emptySpecializedFunction: emptySpecializedFunction, functionPassContext)
196+
closureSpecCloner.cloneAndSpecializeFunctionBody(using: callSite)
197+
})
207198

208-
let specializedFunction = context.createAndBuildSpecializedFunction(createFn: createFn, buildFn: buildFn)
209199
return (specializedFunction, false)
210200
}
211201

@@ -479,7 +469,6 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
479469
continue
480470
}
481471

482-
// Mark the converted/reabstracted closures as used.
483472
if haveUsedReabstraction {
484473
markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, convertedAndReabstractedClosure: use.value,
485474
convertedAndReabstractedClosures: &convertedAndReabstractedClosures)
@@ -589,7 +578,7 @@ private extension SpecializationCloner {
589578

590579
let (allSpecializedEntryBlockArgs, closureArgIndexToAllClonedReleasableClosures) = cloneAllClosures(at: callSite)
591580

592-
self.cloneFunctionBody(from: callSite.applyCallee, entryBlockArgs: allSpecializedEntryBlockArgs)
581+
self.cloneFunctionBody(from: callSite.applyCallee, entryBlockArguments: allSpecializedEntryBlockArgs)
593582

594583
self.insertCleanupCodeForClonedReleasableClosures(
595584
from: callSite, closureArgIndexToAllClonedReleasableClosures: closureArgIndexToAllClonedReleasableClosures)
@@ -672,9 +661,7 @@ private extension SpecializationCloner {
672661

673662
let (finalClonedReabstractedClosure, releasableClonedReabstractedClosures) =
674663
builder.cloneRootClosureReabstractions(rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure,
675-
reabstractedClosure: callSite
676-
.appliedArgForClosure(
677-
at: closureArgDesc.closureArgIndex)!,
664+
reabstractedClosure: callSite.appliedArgForClosure(at: closureArgDesc.closureArgIndex)!,
678665
origToClonedValueMap: origToClonedValueMap,
679666
self.context)
680667

@@ -715,7 +702,7 @@ private extension SpecializationCloner {
715702
// Insert a `destroy_value`, for all releasable closures, in all reachable exit BBs if the closure was passed as a
716703
// guaranteed parameter or its type was noescape+thick. This is b/c the closure was passed at +0 originally and we
717704
// need to balance the initial increment of the newly created closure(s).
718-
if closureArgDesc.isClosureGuaranteed || closureArgDesc.isClosureTrivialNoEscape,
705+
if closureArgDesc.isClosureGuaranteed || closureArgDesc.parameterInfo.isTrivialNoescapeClosure,
719706
!allClonedReleasableClosures.isEmpty
720707
{
721708
for exitBlock in closureArgDesc.reachableExitBBs {
@@ -745,10 +732,10 @@ private extension SpecializationCloner {
745732
private extension [HashableValue: Value] {
746733
subscript(key: Value) -> Value? {
747734
get {
748-
self[HashableValue(key)]
735+
self[key.hashable]
749736
}
750737
set {
751-
self[HashableValue(key)] = newValue
738+
self[key.hashable] = newValue
752739
}
753740
}
754741
}
@@ -760,8 +747,8 @@ private extension Builder {
760747
let function = self.createFunctionRef(closureArgDesc.callee)
761748

762749
if let pai = closureArgDesc.closure as? PartialApplyInst {
763-
return self.createPartialApply(forFunction: function, substitutionMap: SubstitutionMap(),
764-
capturedArgs: capturedArgs, calleeConvention: pai.calleeConvention,
750+
return self.createPartialApply(function: function, substitutionMap: SubstitutionMap(),
751+
capturedArguments: capturedArgs, calleeConvention: pai.calleeConvention,
765752
hasUnknownResultIsolation: pai.hasUnknownResultIsolation,
766753
isOnStack: pai.isOnStack)
767754
} else {
@@ -806,8 +793,8 @@ private extension Builder {
806793
}
807794

808795
let fri = self.createFunctionRef(function)
809-
let reabstracted = self.createPartialApply(forFunction: fri, substitutionMap: SubstitutionMap(),
810-
capturedArgs: [toBeReabstracted],
796+
let reabstracted = self.createPartialApply(function: fri, substitutionMap: SubstitutionMap(),
797+
capturedArguments: [toBeReabstracted],
811798
calleeConvention: pai.calleeConvention,
812799
hasUnknownResultIsolation: pai.hasUnknownResultIsolation,
813800
isOnStack: pai.isOnStack)
@@ -895,7 +882,7 @@ private extension FunctionConvention {
895882
private extension ParameterInfo {
896883
func withSpecializedConvention(isArgTypeTrivial: Bool) -> Self {
897884
let specializedParamConvention =
898-
if self.hasAllowedIndirectConvForClosureSpec {
885+
if self.convention.isAllowedIndirectConvForClosureSpec {
899886
self.convention
900887
} else {
901888
isArgTypeTrivial ? ArgumentConvention.directUnowned : ArgumentConvention.directOwned
@@ -905,13 +892,8 @@ private extension ParameterInfo {
905892
hasLoweredAddresses: self.hasLoweredAddresses)
906893
}
907894

908-
var hasAllowedIndirectConvForClosureSpec: Bool {
909-
switch convention {
910-
case .indirectInout, .indirectInoutAliasable:
911-
return true
912-
default:
913-
return false
914-
}
895+
var isTrivialNoescapeClosure: Bool {
896+
self.type.SILFunctionType_isTrivialNoescape()
915897
}
916898
}
917899

@@ -995,21 +977,6 @@ private extension Function {
995977
}
996978

997979
// ===================== Utility Types ===================== //
998-
private enum HashableValue: Hashable {
999-
case Argument(FunctionArgument)
1000-
case Instruction(SingleValueInstruction)
1001-
1002-
init(_ value: Value) {
1003-
if let instr = value as? SingleValueInstruction {
1004-
self = .Instruction(instr)
1005-
} else if let arg = value as? FunctionArgument {
1006-
self = .Argument(arg)
1007-
} else {
1008-
fatalError("Invalid hashable value: \(value)")
1009-
}
1010-
}
1011-
}
1012-
1013980
private struct OrderedDict<Key: Hashable, Value> {
1014981
private var valueIndexDict: [Key: Int] = [:]
1015982
private var entryList: [(Key, Value)] = []
@@ -1077,10 +1044,6 @@ private struct ClosureArgDescriptor {
10771044
closureInfo.closure
10781045
}
10791046

1080-
var isPartialApply: Bool {
1081-
closure is PartialApplyInst
1082-
}
1083-
10841047
var isPartialApplyOnStack: Bool {
10851048
if let pai = closure as? PartialApplyInst {
10861049
return pai.isOnStack
@@ -1116,7 +1079,7 @@ private struct ClosureArgDescriptor {
11161079
}
11171080
}
11181081

1119-
var arguments: (some Sequence<Value>)? {
1082+
var arguments: LazyMapSequence<OperandArray, Value>? {
11201083
if let pai = closure as? PartialApplyInst {
11211084
return pai.arguments
11221085
}
@@ -1131,14 +1094,6 @@ private struct ClosureArgDescriptor {
11311094
closureParamInfo.convention.isConsumed
11321095
}
11331096

1134-
var isClosureTrivialNoEscape: Bool {
1135-
closureParamInfo.type.SILFunctionType_isTrivialNoescape()
1136-
}
1137-
1138-
var parentFunction: Function {
1139-
closure.parentFunction
1140-
}
1141-
11421097
var reachableExitBBs: [BasicBlock] {
11431098
closure.parentFunction.blocks.filter { $0.isReachableExitBlock }
11441099
}
@@ -1161,14 +1116,6 @@ private struct CallSite {
11611116
applySite.referencedFunction!
11621117
}
11631118

1164-
var isCalleeSerialized: Bool {
1165-
applyCallee.isSerialized
1166-
}
1167-
1168-
var firstClosureArgDesc: ClosureArgDescriptor? {
1169-
closureArgDescriptors.first
1170-
}
1171-
11721119
func hasClosureArg(at index: Int) -> Bool {
11731120
closureArgDescriptors.contains { $0.closureArgumentIndex == index }
11741121
}
@@ -1177,10 +1124,6 @@ private struct CallSite {
11771124
closureArgDescriptors.first { $0.closureArgumentIndex == index }
11781125
}
11791126

1180-
func closureArg(at index: Int) -> SingleValueInstruction? {
1181-
closureArgDesc(at: index)?.closure
1182-
}
1183-
11841127
func appliedArgForClosure(at index: Int) -> Value? {
11851128
if let closureArgDesc = closureArgDesc(at: index) {
11861129
return applySite.arguments[closureArgDesc.closureArgIndex - applySite.unappliedArgumentCount]
@@ -1189,14 +1132,6 @@ private struct CallSite {
11891132
return nil
11901133
}
11911134

1192-
func closureCallee(at index: Int) -> Function? {
1193-
closureArgDesc(at: index)?.callee
1194-
}
1195-
1196-
func closureLoc(at index: Int) -> Location? {
1197-
closureArgDesc(at: index)?.location
1198-
}
1199-
12001135
func specializedCalleeName(_ context: FunctionPassContext) -> String {
12011136
let closureArgs = Array(self.closureArgDescriptors.map { $0.closure })
12021137
let closureIndices = Array(self.closureArgDescriptors.map { $0.closureArgIndex })

SwiftCompilerSources/Sources/Optimizer/PassManager/Context.swift

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -372,21 +372,29 @@ struct FunctionPassContext : MutatingContext {
372372
return gv.globalVar
373373
}
374374

375-
/// Utility function that should be used by optimizations that generate new functions or specialized versions of
376-
/// existing functions.
377-
func createAndBuildSpecializedFunction(createFn: (FunctionPassContext) -> Function,
378-
buildFn: (Function, FunctionPassContext) -> ()) -> Function
375+
func createFunctionForClosureSpecialization(from applySiteCallee: Function, withName specializedFunctionName: String,
376+
withParams specializedParameters: [ParameterInfo],
377+
withSerialization isSerialized: Bool) -> Function
379378
{
380-
let specializedFunction = createFn(self)
381-
379+
return specializedFunctionName._withBridgedStringRef { nameRef in
380+
let bridgedParamInfos = specializedParameters.map { $0._bridged }
381+
382+
return bridgedParamInfos.withUnsafeBufferPointer { paramBuf in
383+
_bridged.ClosureSpecializer_createEmptyFunctionWithSpecializedSignature(nameRef, paramBuf.baseAddress,
384+
paramBuf.count,
385+
applySiteCallee.bridged,
386+
isSerialized).function
387+
}
388+
}
389+
}
390+
391+
func buildSpecializedFunction(specializedFunction: Function, buildFn: (Function, FunctionPassContext) -> ()) {
382392
let nestedFunctionPassContext =
383393
FunctionPassContext(_bridged: _bridged.initializeNestedPassContext(specializedFunction.bridged))
384394

385-
defer { _bridged.deinitializedNestedPassContext() }
386-
387-
buildFn(specializedFunction, nestedFunctionPassContext)
395+
defer { _bridged.deinitializedNestedPassContext() }
388396

389-
return specializedFunction
397+
buildFn(specializedFunction, nestedFunctionPassContext)
390398
}
391399
}
392400

@@ -477,8 +485,8 @@ extension Builder {
477485
/// Creates a builder which inserts instructions into an empty function, using the location of the function itself.
478486
init(atStartOf function: Function, _ context: some MutatingContext) {
479487
context.verifyIsTransforming(function: function)
480-
self.init(insertAt: .atStartOf(function), context.notifyInstructionChanged,
481-
context._bridged.asNotificationHandler())
488+
self.init(insertAt: .atStartOf(function), location: function.location,
489+
context.notifyInstructionChanged, context._bridged.asNotificationHandler())
482490
}
483491

484492
init(staticInitializerOf global: GlobalVariable, _ context: some MutatingContext) {

SwiftCompilerSources/Sources/Optimizer/Utilities/SpecializationCloner.swift

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -16,40 +16,32 @@ import SIL
1616
/// Utility cloner type that can be used by optimizations that generate new functions or specialized versions of
1717
/// existing functions.
1818
struct SpecializationCloner {
19-
private var _context: FunctionPassContext
20-
private var _bridged: BridgedSpecializationCloner
19+
private let bridged: BridgedSpecializationCloner
20+
let context: FunctionPassContext
2121

2222
init(emptySpecializedFunction: Function, _ context: FunctionPassContext) {
23-
self._context = context
24-
self._bridged = BridgedSpecializationCloner(emptySpecializedFunction.bridged)
25-
}
26-
27-
public var context: FunctionPassContext {
28-
self._context
29-
}
30-
31-
public var bridged: BridgedSpecializationCloner {
32-
self._bridged
23+
self.bridged = BridgedSpecializationCloner(emptySpecializedFunction.bridged)
24+
self.context = context
3325
}
3426

35-
public var cloned: Function {
27+
var cloned: Function {
3628
bridged.getCloned().function
3729
}
3830

39-
public var entryBlock: BasicBlock {
31+
var entryBlock: BasicBlock {
4032
if cloned.blocks.isEmpty {
4133
cloned.appendNewBlock(context)
4234
} else {
4335
cloned.entryBlock
4436
}
4537
}
4638

47-
public func getClonedBlock(for originalBlock: BasicBlock) -> BasicBlock {
39+
func getClonedBlock(for originalBlock: BasicBlock) -> BasicBlock {
4840
bridged.getClonedBasicBlock(originalBlock.bridged).block
4941
}
5042

51-
public func cloneFunctionBody(from originalFunction: Function, entryBlockArgs: [Value]) {
52-
entryBlockArgs.withBridgedValues { bridgedEntryBlockArgs in
43+
func cloneFunctionBody(from originalFunction: Function, entryBlockArguments: [Value]) {
44+
entryBlockArguments.withBridgedValues { bridgedEntryBlockArgs in
5345
bridged.cloneFunctionBody(originalFunction.bridged, self.entryBlock.bridged, bridgedEntryBlockArgs)
5446
}
5547
}

0 commit comments

Comments
 (0)