Skip to content

Commit 40c4485

Browse files
authored
[AutoDiff] Support differentiation of non-active try_apply. (swiftlang#33483)
Add differentiation support for non-active `try_apply` SIL instructions. Notable pullback generation changes: * Original basic blocks are now visited in a different order: * starting from the original basic block, all its predecessors * are visited in a breadth-first search order. This ensures that * all successors of any block are visited before the block itself. Resolves TF-433.
1 parent fbc0463 commit 40c4485

File tree

10 files changed

+181
-114
lines changed

10 files changed

+181
-114
lines changed

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -271,59 +271,6 @@ inline void createEntryArguments(SILFunction *f) {
271271
}
272272
}
273273

274-
/// Helper class for visiting basic blocks in post-order post-dominance order,
275-
/// based on a worklist algorithm.
276-
class PostOrderPostDominanceOrder {
277-
SmallVector<DominanceInfoNode *, 16> buffer;
278-
PostOrderFunctionInfo *postOrderInfo;
279-
size_t srcIdx = 0;
280-
281-
public:
282-
/// Constructor.
283-
/// \p root The root of the post-dominator tree.
284-
/// \p postOrderInfo The post-order info of the function.
285-
/// \p capacity Should be the number of basic blocks in the dominator tree to
286-
/// reduce memory allocation.
287-
PostOrderPostDominanceOrder(DominanceInfoNode *root,
288-
PostOrderFunctionInfo *postOrderInfo,
289-
int capacity = 0)
290-
: postOrderInfo(postOrderInfo) {
291-
buffer.reserve(capacity);
292-
buffer.push_back(root);
293-
}
294-
295-
/// Get the next block from the worklist.
296-
DominanceInfoNode *getNext() {
297-
if (srcIdx == buffer.size())
298-
return nullptr;
299-
return buffer[srcIdx++];
300-
}
301-
302-
/// Pushes the dominator children of a block onto the worklist in post-order.
303-
void pushChildren(DominanceInfoNode *node) {
304-
pushChildrenIf(node, [](SILBasicBlock *) { return true; });
305-
}
306-
307-
/// Conditionally pushes the dominator children of a block onto the worklist
308-
/// in post-order.
309-
template <typename Pred>
310-
void pushChildrenIf(DominanceInfoNode *node, Pred pred) {
311-
SmallVector<DominanceInfoNode *, 4> children;
312-
for (auto *child : *node)
313-
children.push_back(child);
314-
llvm::sort(children.begin(), children.end(),
315-
[&](DominanceInfoNode *n1, DominanceInfoNode *n2) {
316-
return postOrderInfo->getPONumber(n1->getBlock()) <
317-
postOrderInfo->getPONumber(n2->getBlock());
318-
});
319-
for (auto *child : children) {
320-
SILBasicBlock *childBB = child->getBlock();
321-
if (pred(childBB))
322-
buffer.push_back(child);
323-
}
324-
}
325-
};
326-
327274
/// Cloner that remaps types using the target function's generic environment.
328275
class BasicTypeSubstCloner final
329276
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,22 +1721,36 @@ bool PullbackCloner::Implementation::run() {
17211721
domOrder.pushChildren(bb);
17221722
}
17231723

1724-
// Create pullback blocks and arguments, visiting original blocks in
1725-
// post-order post-dominance order.
1726-
SmallVector<SILBasicBlock *, 8> postOrderPostDomOrder;
1727-
// Start from the root node, which may have a marker `nullptr` block if
1728-
// there are multiple roots.
1729-
PostOrderPostDominanceOrder postDomOrder(postDomInfo->getRootNode(),
1730-
postOrderInfo, original.size());
1731-
while (auto *origNode = postDomOrder.getNext()) {
1732-
auto *origBB = origNode->getBlock();
1733-
postDomOrder.pushChildren(origNode);
1734-
// If node is the `nullptr` marker basic block, do not push it.
1735-
if (!origBB)
1736-
continue;
1737-
postOrderPostDomOrder.push_back(origBB);
1724+
// Create pullback blocks and arguments, visiting original blocks using BFS
1725+
// starting from the original exit block. Unvisited original basic blocks
1726+
// (e.g unreachable blocks) are not relevant for pullback generation and thus
1727+
// ignored.
1728+
// The original blocks in traversal order for pullback generation.
1729+
SmallVector<SILBasicBlock *, 8> originalBlocks;
1730+
// The set of visited original blocks.
1731+
SmallDenseSet<SILBasicBlock *, 8> visitedBlocks;
1732+
1733+
// Perform BFS from the original exit block.
1734+
{
1735+
std::deque<SILBasicBlock *> worklist = {};
1736+
worklist.push_back(origExit);
1737+
visitedBlocks.insert(origExit);
1738+
while (!worklist.empty()) {
1739+
auto *BB = worklist.front();
1740+
worklist.pop_front();
1741+
1742+
originalBlocks.push_back(BB);
1743+
1744+
for (auto *nextBB : BB->getPredecessorBlocks()) {
1745+
if (!visitedBlocks.count(nextBB)) {
1746+
worklist.push_back(nextBB);
1747+
visitedBlocks.insert(nextBB);
1748+
}
1749+
}
1750+
}
17381751
}
1739-
for (auto *origBB : postOrderPostDomOrder) {
1752+
1753+
for (auto *origBB : originalBlocks) {
17401754
auto *pullbackBB = pullback.createBasicBlock();
17411755
pullbackBBMap.insert({origBB, pullbackBB});
17421756
auto pbStructLoweredType =
@@ -1801,6 +1815,9 @@ bool PullbackCloner::Implementation::run() {
18011815
// struct argument. They branch from a pullback successor block to the
18021816
// pullback original block, passing adjoint values of active values.
18031817
for (auto *succBB : origBB->getSuccessorBlocks()) {
1818+
// Skip generating pullback block for original unreachable blocks.
1819+
if (!visitedBlocks.count(succBB))
1820+
continue;
18041821
auto *pullbackTrampolineBB = pullback.createBasicBlockBefore(pullbackBB);
18051822
pullbackTrampolineBBMap.insert({{origBB, succBB}, pullbackTrampolineBB});
18061823
// Get the enum element type (i.e. the pullback struct type). The enum
@@ -1870,7 +1887,7 @@ bool PullbackCloner::Implementation::run() {
18701887
// Visit original blocks blocks in post-order and perform differentiation
18711888
// in corresponding pullback blocks. If errors occurred, back out.
18721889
else {
1873-
for (auto *bb : postOrderPostDomOrder) {
1890+
for (auto *bb : originalBlocks) {
18741891
visitSILBasicBlock(bb);
18751892
if (errorOccurred)
18761893
return true;

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,18 @@ class VJPCloner::Implementation final
618618
getOpValue(origCallee)->getDefiningInstruction());
619619
}
620620

621+
void visitTryApplyInst(TryApplyInst *tai) {
622+
// Build pullback struct value for original block.
623+
auto *pbStructVal = buildPullbackValueStructValue(tai);
624+
// Create a new `try_apply` instruction.
625+
auto args = getOpValueArray<8>(tai->getArguments());
626+
getBuilder().createTryApply(
627+
tai->getLoc(), getOpValue(tai->getCallee()),
628+
getOpSubstitutionMap(tai->getSubstitutionMap()), args,
629+
createTrampolineBasicBlock(tai, pbStructVal, tai->getNormalBB()),
630+
createTrampolineBasicBlock(tai, pbStructVal, tai->getErrorBB()));
631+
}
632+
621633
void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
622634
// Clone `differentiable_function` from original to VJP, then add the cloned
623635
// instruction to the `differentiable_function` worklist.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ static bool diagnoseNoReturn(ADContext &context, SILFunction *original,
158158
/// flow unsupported" error at appropriate source locations. Returns true if
159159
/// error is emitted.
160160
///
161-
/// Update as control flow support is added. Currently, branching terminators
162-
/// other than `br`, `cond_br`, `switch_enum` are not supported.
161+
/// Update as control flow support is added.
163162
static bool diagnoseUnsupportedControlFlow(ADContext &context,
164163
SILFunction *original,
165164
DifferentiationInvoker invoker) {
@@ -173,7 +172,7 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
173172
isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term) ||
174173
isa<CheckedCastBranchInst>(term) ||
175174
isa<CheckedCastValueBranchInst>(term) ||
176-
isa<CheckedCastAddrBranchInst>(term))
175+
isa<CheckedCastAddrBranchInst>(term) || isa<TryApplyInst>(term))
177176
continue;
178177
// If terminator is an unsupported branching terminator, emit an error.
179178
if (term->isBranch()) {

test/AutoDiff/SILOptimizer/activity_analysis.swift

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,17 +543,25 @@ func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
543543

544544
func rethrowing(_ x: () throws -> Void) rethrows -> Void {}
545545

546-
// expected-error @+1 {{function is not differentiable}}
547546
@differentiable
548-
// expected-note @+1 {{when differentiating this function definition}}
549547
func testTryApply(_ x: Float) -> Float {
550-
// expected-note @+1 {{cannot differentiate unsupported control flow}}
551548
rethrowing({})
552549
return x
553550
}
554551

555552
// TF-433: differentiation diagnoses `try_apply` before activity info is printed.
556-
// CHECK-NOT: [AD] Activity info for ${{.*}}testTryApply{{.*}} at (parameters=(0) results=(0))
553+
// CHECK-LABEL: [AD] Activity info for ${{.*}}testTryApply{{.*}} at (parameters=(0) results=(0))
554+
// CHECK: bb0:
555+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
556+
// CHECK: [NONE] // function_ref closure #1 in testTryApply(_:)
557+
// CHECK: [NONE] %3 = convert_function %2 : $@convention(thin) () -> () to $@convention(thin) @noescape () -> ()
558+
// CHECK: [NONE] %4 = thin_to_thick_function %3 : $@convention(thin) @noescape () -> () to $@noescape @callee_guaranteed () -> ()
559+
// CHECK: [NONE] %5 = convert_function %4 : $@noescape @callee_guaranteed () -> () to $@noescape @callee_guaranteed () -> @error Error
560+
// CHECK: [NONE] // function_ref rethrowing(_:)
561+
// CHECK: bb1:
562+
// CHECK: [NONE] %8 = argument of bb1 : $()
563+
// CHECK: bb2:
564+
// CHECK: [NONE] %10 = argument of bb2 : $Error
557565

558566
//===----------------------------------------------------------------------===//
559567
// Coroutine differentiation (`begin_apply`)

test/AutoDiff/SILOptimizer/differentiation_control_flow_diagnostics.swift

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,8 @@ func nested_loop(_ x: Float) -> Float {
7878

7979
func rethrowing(_ x: () throws -> Void) rethrows -> Void {}
8080

81-
// expected-error @+1 {{function is not differentiable}}
8281
@differentiable
83-
// expected-note @+1 {{when differentiating this function definition}}
8482
func testTryApply(_ x: Float) -> Float {
85-
// expected-note @+1 {{cannot differentiate unsupported control flow}}
8683
rethrowing({})
8784
return x
8885
}
@@ -93,10 +90,19 @@ func testTryApply(_ x: Float) -> Float {
9390
func withoutDerivative<T : Differentiable, R: Differentiable>(
9491
at x: T, in body: (T) throws -> R
9592
) rethrows -> R {
96-
// expected-note @+1 {{cannot differentiate unsupported control flow}}
93+
// expected-note @+1 {{expression is not differentiable}}
9794
try body(x)
9895
}
9996

97+
// Tests active `try_apply`.
98+
// expected-error @+1 {{function is not differentiable}}
99+
@differentiable
100+
// expected-note @+1 {{when differentiating this function definition}}
101+
func testNilCoalescing(_ maybeX: Float?) -> Float {
102+
// expected-note @+1 {{expression is not differentiable}}
103+
return maybeX ?? 10
104+
}
105+
100106
// Test unsupported differentiation of active enum values.
101107

102108
// expected-error @+1 {{function is not differentiable}}

test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,29 +68,30 @@ func cond(_ x: Float) -> Float {
6868
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]])
6969
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
7070
// CHECK-SIL: return [[VJP_RESULT]]
71+
// CHECK-SIL-LABEL: } // end sil function 'AD__cond__vjp_src_0_wrt_0'
7172

7273

7374
// CHECK-SIL-LABEL: sil private [ossa] @AD__cond__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__cond_bb3__PB__src_0_wrt_0) -> Float {
7475
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PB_STRUCT:%.*]] : @owned $_AD__cond_bb3__PB__src_0_wrt_0):
7576
// CHECK-SIL: [[BB3_PRED:%.*]] = destructure_struct [[BB3_PB_STRUCT]] : $_AD__cond_bb3__PB__src_0_wrt_0
76-
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb1
77+
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3
7778

78-
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : @owned $_AD__cond_bb1__PB__src_0_wrt_0):
79-
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0)
79+
// CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $_AD__cond_bb2__PB__src_0_wrt_0):
80+
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0)
8081

81-
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : @owned $_AD__cond_bb1__PB__src_0_wrt_0):
82-
// CHECK-SIL: ([[BB1_PRED:%.*]], [[BB1_PB:%.*]]) = destructure_struct [[BB1_PB_STRUCT]]
83-
// CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
84-
// CHECK-SIL: switch_enum [[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0, case #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt: bb5
85-
86-
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $_AD__cond_bb2__PB__src_0_wrt_0):
87-
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0)
88-
89-
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : @owned $_AD__cond_bb2__PB__src_0_wrt_0):
82+
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : @owned $_AD__cond_bb2__PB__src_0_wrt_0):
9083
// CHECK-SIL: ([[BB2_PRED:%.*]], [[BB2_PB:%.*]]) = destructure_struct [[BB2_PB_STRUCT]]
9184
// CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
9285
// CHECK-SIL: switch_enum [[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_0, case #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt: bb6
9386

87+
// CHECK-SIL: bb3([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : @owned $_AD__cond_bb1__PB__src_0_wrt_0):
88+
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0)
89+
90+
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : @owned $_AD__cond_bb1__PB__src_0_wrt_0):
91+
// CHECK-SIL: ([[BB1_PRED:%.*]], [[BB1_PB:%.*]]) = destructure_struct [[BB1_PB_STRUCT]]
92+
// CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
93+
// CHECK-SIL: switch_enum [[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0, case #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt: bb5
94+
9495
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
9596
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0)
9697

@@ -99,6 +100,7 @@ func cond(_ x: Float) -> Float {
99100

100101
// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0):
101102
// CHECK-SIL: return {{%.*}} : $Float
103+
// CHECK-SIL-LABEL: } // end sil function 'AD__cond__pullback_src_0_wrt_0'
102104

103105
@differentiable
104106
@_silgen_name("nested_cond")
@@ -178,7 +180,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float {
178180
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]])
179181
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
180182
// CHECK-SIL: return [[VJP_RESULT]]
181-
// CHECK-SIL: }
183+
// CHECK-SIL-LABEL: } // end sil function 'AD__enum_notactive__vjp_src_0_wrt_1'
182184

183185
// Test `switch_enum_addr`.
184186

@@ -227,7 +229,7 @@ func enum_addr_notactive<T>(_ e: AddressOnlyEnum<T>, _ x: Float) -> Float {
227229
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PB_FNREF]]<τ_0_0>([[BB3_PB_STRUCT]]) : $@convention(thin) <τ_0_0> (Float, @owned _AD__enum_addr_notactive_bb3__PB__src_0_wrt_1_l<τ_0_0>) -> Float
228230
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[X_ARG]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
229231
// CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float)
230-
// CHECK-SIL: }
232+
// CHECK-SIL-LABEL: } // end sil function 'AD__enum_addr_notactive__vjp_src_0_wrt_1_l'
231233

232234
// Test control flow + tuple buffer.
233235
// Verify that pullback buffers are not allocated for address projections.
@@ -248,25 +250,25 @@ func cond_tuple_var(_ x: Float) -> Float {
248250
// CHECK-SIL: [[BB3_PRED:%.*]] = destructure_struct [[BB3_PB_STRUCT]] : $_AD__cond_tuple_var_bb3__PB__src_0_wrt_0
249251
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
250252
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
251-
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb1
253+
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3
252254

253-
// CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
254-
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0)
255+
// CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
256+
// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0)
255257

256-
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
257-
// CHECK-SIL: [[BB1_PRED:%.*]] = destructure_struct [[BB1_PB_STRUCT]]
258+
// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
259+
// CHECK-SIL: [[BB2_PRED:%.*]] = destructure_struct [[BB2_PB_STRUCT]]
258260
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
259261
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
260-
// CHECK-SIL: switch_enum [[BB1_PRED]] : $_AD__cond_tuple_var_bb1__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb1__Pred__src_0_wrt_0.bb0!enumelt: bb5
262+
// CHECK-SIL: switch_enum [[BB2_PRED]] : $_AD__cond_tuple_var_bb2__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb2__Pred__src_0_wrt_0.bb0!enumelt: bb6
261263

262-
// CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
263-
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0)
264+
// CHECK-SIL: bb3([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
265+
// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0)
264266

265-
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0):
266-
// CHECK-SIL: [[BB2_PRED:%.*]] = destructure_struct [[BB2_PB_STRUCT]]
267+
// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0):
268+
// CHECK-SIL: [[BB1_PRED:%.*]] = destructure_struct [[BB1_PB_STRUCT]]
267269
// CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float)
268270
// CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float
269-
// CHECK-SIL: switch_enum [[BB2_PRED]] : $_AD__cond_tuple_var_bb2__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb2__Pred__src_0_wrt_0.bb0!enumelt: bb6
271+
// CHECK-SIL: switch_enum [[BB1_PRED]] : $_AD__cond_tuple_var_bb1__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb1__Pred__src_0_wrt_0.bb0!enumelt: bb5
270272

271273
// CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
272274
// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0)
@@ -276,3 +278,4 @@ func cond_tuple_var(_ x: Float) -> Float {
276278

277279
// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0):
278280
// CHECK-SIL: return {{%.*}} : $Float
281+
// CHECK-SIL-LABEL: } // end sil function 'AD__cond_tuple_var__pullback_src_0_wrt_0'

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,16 @@ func conditional(_ x: Float, _ flag: Bool) -> Float {
3232

3333
func throwing() throws -> Void {}
3434

35-
// expected-error @+2 {{function is not differentiable}}
36-
// expected-note @+2 {{when differentiating this function definition}}
3735
@differentiable
3836
func try_apply(_ x: Float) -> Float {
39-
// expected-note @+1 {{cannot differentiate unsupported control flow}}
4037
try! throwing()
4138
return x
4239
}
4340

4441
func rethrowing(_ x: () throws -> Void) rethrows -> Void {}
4542

46-
// expected-error @+2 {{function is not differentiable}}
47-
// expected-note @+2 {{when differentiating this function definition}}
4843
@differentiable
4944
func try_apply_rethrows(_ x: Float) -> Float {
50-
// expected-note @+1 {{cannot differentiate unsupported control flow}}
5145
rethrowing({})
5246
return x
5347
}

0 commit comments

Comments
 (0)