Skip to content

Commit d5d076d

Browse files
committed
[AutoDiff] Support differentiation of branching cast instructions.
Support differentiation of `is` and `as?` operators. These operators lower to branching cast SIL instructions, requiring control flow differentiation support. Resolves SR-12898.
1 parent 24de636 commit d5d076d

File tree

6 files changed

+201
-7
lines changed

6 files changed

+201
-7
lines changed

include/swift/SILOptimizer/Differentiation/VJPEmitter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ class VJPEmitter final
153153

154154
void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai);
155155

156+
void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi);
157+
158+
void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi);
159+
160+
void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi);
161+
156162
// If an `apply` has active results or active inout arguments, replace it
157163
// with an `apply` of its VJP.
158164
void visitApplyInst(ApplyInst *ai);

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,26 @@ void DifferentiableActivityInfo::propagateVaried(
193193
if (auto *destBBArg = cbi->getArgForOperand(operand))
194194
setVariedAndPropagateToUsers(destBBArg, i);
195195
}
196-
// Handle `switch_enum`.
197-
else if (auto *sei = dyn_cast<SwitchEnumInst>(inst)) {
198-
if (isVaried(sei->getOperand(), i))
199-
for (auto *succBB : sei->getSuccessorBlocks())
196+
// Handle `checked_cast_addr_br`.
197+
// Propagate variedness from source operand to destination operand, in
198+
// addition to all successor block arguments.
199+
else if (auto *ccabi = dyn_cast<CheckedCastAddrBranchInst>(inst)) {
200+
if (isVaried(ccabi->getSrc(), i)) {
201+
setVariedAndPropagateToUsers(ccabi->getDest(), i);
202+
for (auto *succBB : ccabi->getSuccessorBlocks())
200203
for (auto *arg : succBB->getArguments())
201204
setVariedAndPropagateToUsers(arg, i);
205+
}
206+
}
207+
// Handle all other terminators: if any operand is active, propagate
208+
// variedness to all successor block arguments. This logic may be incorrect
209+
// for some terminator instructions, so special cases must be defined above.
210+
else if (auto *termInst = dyn_cast<TermInst>(inst)) {
211+
for (auto &op : termInst->getAllOperands())
212+
if (isVaried(op.get(), i))
213+
for (auto *succBB : termInst->getSuccessorBlocks())
214+
for (auto *arg : succBB->getArguments())
215+
setVariedAndPropagateToUsers(arg, i);
202216
}
203217
// Handle everything else.
204218
else {

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,47 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
481481
visitSwitchEnumInstBase(seai);
482482
}
483483

484+
void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
485+
// Build pullback struct value for original block.
486+
auto *pbStructVal = buildPullbackValueStructValue(ccbi->getParent());
487+
// Create a new `checked_cast_branch` instruction.
488+
getBuilder().createCheckedCastBranch(
489+
ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()),
490+
getOpType(ccbi->getTargetLoweredType()),
491+
getOpASTType(ccbi->getTargetFormalType()),
492+
createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getSuccessBB()),
493+
createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getFailureBB()),
494+
ccbi->getTrueBBCount(), ccbi->getFalseBBCount());
495+
}
496+
497+
void VJPEmitter::visitCheckedCastValueBranchInst(
498+
CheckedCastValueBranchInst *ccvbi) {
499+
// Build pullback struct value for original block.
500+
auto *pbStructVal = buildPullbackValueStructValue(ccvbi->getParent());
501+
// Create a new `checked_cast_value_branch` instruction.
502+
getBuilder().createCheckedCastValueBranch(
503+
ccvbi->getLoc(), getOpValue(ccvbi->getOperand()),
504+
getOpASTType(ccvbi->getSourceFormalType()),
505+
getOpType(ccvbi->getTargetLoweredType()),
506+
getOpASTType(ccvbi->getTargetFormalType()),
507+
createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getSuccessBB()),
508+
createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB()));
509+
}
510+
511+
void VJPEmitter::visitCheckedCastAddrBranchInst(
512+
CheckedCastAddrBranchInst *ccabi) {
513+
// Build pullback struct value for original block.
514+
auto *pbStructVal = buildPullbackValueStructValue(ccabi->getParent());
515+
// Create a new `checked_cast_addr_branch` instruction.
516+
getBuilder().createCheckedCastAddrBranch(
517+
ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()),
518+
getOpASTType(ccabi->getSourceFormalType()), getOpValue(ccabi->getDest()),
519+
getOpASTType(ccabi->getTargetFormalType()),
520+
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()),
521+
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()),
522+
ccabi->getTrueBBCount(), ccabi->getFalseBBCount());
523+
}
524+
484525
void VJPEmitter::visitApplyInst(ApplyInst *ai) {
485526
// If callee should not be differentiated, do standard cloning.
486527
if (!pullbackInfo.shouldDifferentiateApplySite(ai)) {

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,12 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
152152
// Diagnose unsupported branching terminators.
153153
for (auto &bb : *original) {
154154
auto *term = bb.getTerminator();
155-
// Supported terminators are: `br`, `cond_br`, `switch_enum`,
156-
// `switch_enum_addr`.
155+
// Check supported branching terminators.
157156
if (isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
158-
isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term))
157+
isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term) ||
158+
isa<CheckedCastBranchInst>(term) ||
159+
isa<CheckedCastValueBranchInst>(term) ||
160+
isa<CheckedCastAddrBranchInst>(term))
159161
continue;
160162
// If terminator is an unsupported branching terminator, emit an error.
161163
if (term->isBranch()) {

test/AutoDiff/SILOptimizer/activity_analysis.swift

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,116 @@ func TF_954(_ x: Float) -> Float {
122122
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
123123
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float
124124

125+
//===----------------------------------------------------------------------===//
126+
// Branching cast instructions
127+
//===----------------------------------------------------------------------===//
128+
129+
@differentiable
130+
func checked_cast_branch(_ x: Float) -> Float {
131+
// expected-warning @+1 {{'is' test is always true}}
132+
if Int.self is Any.Type {
133+
return x + x
134+
}
135+
return x * x
136+
}
137+
138+
// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_branch{{.*}} at (source=0 parameters=(0))
139+
// CHECK: bb0:
140+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
141+
// CHECK: [NONE] %2 = metatype $@thin Int.Type
142+
// CHECK: [NONE] %3 = metatype $@thick Int.Type
143+
// CHECK: bb1:
144+
// CHECK: [NONE] %5 = argument of bb1 : $@thick Any.Type
145+
// CHECK: [NONE] %6 = integer_literal $Builtin.Int1, -1
146+
// CHECK: bb2:
147+
// CHECK: [NONE] %8 = argument of bb2 : $@thick Int.Type
148+
// CHECK: [NONE] %9 = integer_literal $Builtin.Int1, 0
149+
// CHECK: bb3:
150+
// CHECK: [NONE] %11 = argument of bb3 : $Builtin.Int1
151+
// CHECK: [NONE] %12 = metatype $@thin Bool.Type
152+
// CHECK: [NONE] // function_ref Bool.init(_builtinBooleanLiteral:)
153+
// CHECK: [NONE] %14 = apply %13(%11, %12) : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
154+
// CHECK: [NONE] %15 = struct_extract %14 : $Bool, #Bool._value
155+
// CHECK: bb4:
156+
// CHECK: [USEFUL] %17 = metatype $@thin Float.Type
157+
// CHECK: [NONE] // function_ref static Float.+ infix(_:_:)
158+
// CHECK: [ACTIVE] %19 = apply %18(%0, %0, %17) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
159+
// CHECK: bb5:
160+
// CHECK: [USEFUL] %21 = metatype $@thin Float.Type
161+
// CHECK: [NONE] // function_ref static Float.* infix(_:_:)
162+
// CHECK: [ACTIVE] %23 = apply %22(%0, %0, %21) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
163+
164+
// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_branch{{.*}} : $@convention(thin) (Float) -> Float {
165+
// CHECK: checked_cast_br %3 : $@thick Int.Type to Any.Type, bb1, bb2
166+
// CHECK: }
167+
168+
@differentiable
169+
func checked_cast_addr_nonactive_result<T: Differentiable>(_ x: T) -> T {
170+
if let _ = x as? Float {
171+
// Do nothing with `y: Float?` value.
172+
}
173+
return x
174+
}
175+
176+
// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_addr_nonactive_result{{.*}} at (source=0 parameters=(0))
177+
// CHECK: bb0:
178+
// CHECK: [ACTIVE] %0 = argument of bb0 : $*T
179+
// CHECK: [ACTIVE] %1 = argument of bb0 : $*T
180+
// CHECK: [VARIED] %3 = alloc_stack $T
181+
// CHECK: [VARIED] %5 = alloc_stack $Float
182+
// CHECK: bb1:
183+
// CHECK: [VARIED] %7 = load [trivial] %5 : $*Float
184+
// CHECK: [VARIED] %8 = enum $Optional<Float>, #Optional.some!enumelt, %7 : $Float
185+
// CHECK: bb2:
186+
// CHECK: [NONE] %11 = enum $Optional<Float>, #Optional.none!enumelt
187+
// CHECK: bb3:
188+
// CHECK: [VARIED] %14 = argument of bb3 : $Optional<Float>
189+
// CHECK: bb4:
190+
// CHECK: bb5:
191+
// CHECK: [VARIED] %18 = argument of bb5 : $Float
192+
// CHECK: bb6:
193+
// CHECK: [NONE] %22 = tuple ()
194+
195+
// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_addr_nonactive_result{{.*}} : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T) -> @out T {
196+
// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2
197+
// CHECK: }
198+
199+
// expected-error @+1 {{function is not differentiable}}
200+
@differentiable
201+
// expected-note @+1 {{when differentiating this function definition}}
202+
func checked_cast_addr_active_result<T: Differentiable>(x: T) -> T {
203+
// expected-note @+1 {{differentiating enum values is not yet supported}}
204+
if let y = x as? Float {
205+
// Use `y: Float?` value in an active way.
206+
return y as! T
207+
}
208+
return x
209+
}
210+
211+
// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_addr_active_result{{.*}} at (source=0 parameters=(0))
212+
// CHECK: bb0:
213+
// CHECK: [ACTIVE] %0 = argument of bb0 : $*T
214+
// CHECK: [ACTIVE] %1 = argument of bb0 : $*T
215+
// CHECK: [ACTIVE] %3 = alloc_stack $T
216+
// CHECK: [ACTIVE] %5 = alloc_stack $Float
217+
// CHECK: bb1:
218+
// CHECK: [ACTIVE] %7 = load [trivial] %5 : $*Float
219+
// CHECK: [ACTIVE] %8 = enum $Optional<Float>, #Optional.some!enumelt, %7 : $Float
220+
// CHECK: bb2:
221+
// CHECK: [USEFUL] %11 = enum $Optional<Float>, #Optional.none!enumelt
222+
// CHECK: bb3:
223+
// CHECK: [ACTIVE] %14 = argument of bb3 : $Optional<Float>
224+
// CHECK: bb4:
225+
// CHECK: [ACTIVE] %16 = argument of bb4 : $Float
226+
// CHECK: [ACTIVE] %19 = alloc_stack $Float
227+
// CHECK: bb5:
228+
// CHECK: bb6:
229+
// CHECK: [NONE] %27 = tuple ()
230+
231+
// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_addr_active_result{{.*}} : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T) -> @out T {
232+
// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2
233+
// CHECK: }
234+
125235
//===----------------------------------------------------------------------===//
126236
// Array literal differentiation
127237
//===----------------------------------------------------------------------===//

test/AutoDiff/validation-test/control_flow.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,4 +715,25 @@ ControlFlowTests.test("Loops") {
715715
expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
716716
}
717717

718+
ControlFlowTests.test("BranchingCastInstructions") {
719+
// checked_cast_br
720+
func typeCheckOperator<T>(_ x: Float, _ metatype: T.Type) -> Float {
721+
if metatype is Int.Type {
722+
return x + x
723+
}
724+
return x * x
725+
}
726+
expectEqual((6, 2), valueWithGradient(at: 3, in: { typeCheckOperator($0, Int.self) }))
727+
expectEqual((9, 6), valueWithGradient(at: 3, in: { typeCheckOperator($0, Float.self) }))
728+
729+
// checked_cast_addr_br
730+
func conditionalCast<T: Differentiable>(_ x: T) -> T {
731+
if let _ = x as? Float {
732+
// Do nothing with `y: Float?` value.
733+
}
734+
return x
735+
}
736+
expectEqual((3, 1), valueWithGradient(at: Float(3), in: conditionalCast))
737+
}
738+
718739
runAllTests()

0 commit comments

Comments
 (0)