Skip to content

Commit c63153d

Browse files
authored
Merge pull request swiftlang#32069 from dan-zheng/autodiff-branching-casts
[AutoDiff] Support differentiation of branching cast instructions.
2 parents 4bf0ddb + d5d076d commit c63153d

File tree

6 files changed

+261
-82
lines changed

6 files changed

+261
-82
lines changed

include/swift/SILOptimizer/Differentiation/VJPEmitter.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,21 @@ class VJPEmitter final
117117
/// Get the lowered SIL type of the given nominal type declaration.
118118
SILType getNominalDeclLoweredType(NominalTypeDecl *nominal);
119119

120-
/// Build a pullback struct value for the original block corresponding to the
121-
/// given terminator.
122-
StructInst *buildPullbackValueStructValue(TermInst *termInst);
120+
// Creates a trampoline block for given original terminator instruction, the
121+
// pullback struct value for its parent block, and a successor basic block.
122+
//
123+
// The trampoline block has the same arguments as and branches to the remapped
124+
// successor block, but drops the last predecessor enum argument.
125+
//
126+
// Used for cloning branching terminator instructions with specific
127+
// requirements on successor block arguments, where an additional predecessor
128+
// enum argument is not acceptable.
129+
SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst,
130+
StructInst *pbStructVal,
131+
SILBasicBlock *succBB);
132+
133+
/// Build a pullback struct value for the given original block.
134+
StructInst *buildPullbackValueStructValue(SILBasicBlock *bb);
123135

124136
/// Build a predecessor enum instance using the given builder for the given
125137
/// original predecessor/successor blocks and pullback struct value.
@@ -141,6 +153,12 @@ class VJPEmitter final
141153

142154
void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai);
143155

156+
void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi);
157+
158+
void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi);
159+
160+
void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi);
161+
144162
// If an `apply` has active results or active inout arguments, replace it
145163
// with an `apply` of its VJP.
146164
void visitApplyInst(ApplyInst *ai);

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,26 @@ void DifferentiableActivityInfo::propagateVaried(
193193
if (auto *destBBArg = cbi->getArgForOperand(operand))
194194
setVariedAndPropagateToUsers(destBBArg, i);
195195
}
196-
// Handle `switch_enum`.
197-
else if (auto *sei = dyn_cast<SwitchEnumInst>(inst)) {
198-
if (isVaried(sei->getOperand(), i))
199-
for (auto *succBB : sei->getSuccessorBlocks())
196+
// Handle `checked_cast_addr_br`.
197+
// Propagate variedness from source operand to destination operand, in
198+
// addition to all successor block arguments.
199+
else if (auto *ccabi = dyn_cast<CheckedCastAddrBranchInst>(inst)) {
200+
if (isVaried(ccabi->getSrc(), i)) {
201+
setVariedAndPropagateToUsers(ccabi->getDest(), i);
202+
for (auto *succBB : ccabi->getSuccessorBlocks())
200203
for (auto *arg : succBB->getArguments())
201204
setVariedAndPropagateToUsers(arg, i);
205+
}
206+
}
207+
// Handle all other terminators: if any operand is active, propagate
208+
// variedness to all successor block arguments. This logic may be incorrect
209+
// for some terminator instructions, so special cases must be defined above.
210+
else if (auto *termInst = dyn_cast<TermInst>(inst)) {
211+
for (auto &op : termInst->getAllOperands())
212+
if (isVaried(op.get(), i))
213+
for (auto *succBB : termInst->getSuccessorBlocks())
214+
for (auto *arg : succBB->getArguments())
215+
setVariedAndPropagateToUsers(arg, i);
202216
}
203217
// Handle everything else.
204218
else {

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 83 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,31 @@ SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) {
265265
return vjpBB;
266266
}
267267

268+
SILBasicBlock *VJPEmitter::createTrampolineBasicBlock(TermInst *termInst,
269+
StructInst *pbStructVal,
270+
SILBasicBlock *succBB) {
271+
assert(llvm::find(termInst->getSuccessorBlocks(), succBB) !=
272+
termInst->getSuccessorBlocks().end() &&
273+
"Basic block is not a successor of terminator instruction");
274+
// Create the trampoline block.
275+
auto *vjpSuccBB = getOpBasicBlock(succBB);
276+
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
277+
for (auto *arg : vjpSuccBB->getArguments().drop_back())
278+
trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind());
279+
// In the trampoline block, build predecessor enum value for VJP successor
280+
// block and branch to it.
281+
SILBuilder trampolineBuilder(trampolineBB);
282+
auto *origBB = termInst->getParent();
283+
auto *succEnumVal =
284+
buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbStructVal);
285+
SmallVector<SILValue, 4> forwardedArguments(
286+
trampolineBB->getArguments().begin(), trampolineBB->getArguments().end());
287+
forwardedArguments.push_back(succEnumVal);
288+
trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB,
289+
forwardedArguments);
290+
return trampolineBB;
291+
}
292+
268293
void VJPEmitter::visit(SILInstruction *inst) {
269294
if (errorOccurred)
270295
return;
@@ -290,10 +315,9 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) {
290315
return getLoweredType(nominalType);
291316
}
292317

293-
StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) {
294-
assert(termInst->getFunction() == original);
295-
auto loc = termInst->getFunction()->getLocation();
296-
auto *origBB = termInst->getParent();
318+
StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) {
319+
assert(origBB->getParent() == original);
320+
auto loc = origBB->getParent()->getLocation();
297321
auto *vjpBB = BBMap[origBB];
298322
auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB);
299323
auto structLoweredTy = getNominalDeclLoweredType(pbStruct);
@@ -333,9 +357,11 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder,
333357

334358
void VJPEmitter::visitReturnInst(ReturnInst *ri) {
335359
auto loc = ri->getOperand().getLoc();
336-
auto *origExit = ri->getParent();
337360
auto &builder = getBuilder();
338-
auto *pbStructVal = buildPullbackValueStructValue(ri);
361+
362+
// Build pullback struct value for original block.
363+
auto *origExit = ri->getParent();
364+
auto *pbStructVal = buildPullbackValueStructValue(origExit);
339365

340366
// Get the value in the VJP corresponding to the original result.
341367
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator());
@@ -390,7 +416,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
390416
// Build pullback struct value for original block.
391417
// Build predecessor enum value for destination block.
392418
auto *origBB = bi->getParent();
393-
auto *pbStructVal = buildPullbackValueStructValue(bi);
419+
auto *pbStructVal = buildPullbackValueStructValue(origBB);
394420
auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB,
395421
bi->getDestBB(), pbStructVal);
396422

@@ -407,85 +433,30 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
407433

408434
void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) {
409435
// Build pullback struct value for original block.
410-
// Build predecessor enum values for true/false blocks.
411-
auto *origBB = cbi->getParent();
412-
auto *pbStructVal = buildPullbackValueStructValue(cbi);
413-
414-
// Creates a trampoline block for given original successor block. The
415-
// trampoline block has the same arguments as the VJP successor block but
416-
// drops the last predecessor enum argument. The generated `switch_enum`
417-
// instruction branches to the trampoline block, and the trampoline block
418-
// constructs a predecessor enum value and branches to the VJP successor
419-
// block.
420-
auto createTrampolineBasicBlock =
421-
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
422-
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
423-
// Create the trampoline block.
424-
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
425-
for (auto *arg : vjpSuccBB->getArguments().drop_back())
426-
trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind());
427-
// Build predecessor enum value for successor block and branch to it.
428-
SILBuilder trampolineBuilder(trampolineBB);
429-
auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB,
430-
origSuccBB, pbStructVal);
431-
SmallVector<SILValue, 4> forwardedArguments(
432-
trampolineBB->getArguments().begin(),
433-
trampolineBB->getArguments().end());
434-
forwardedArguments.push_back(succEnumVal);
435-
trampolineBuilder.createBranch(cbi->getLoc(), vjpSuccBB,
436-
forwardedArguments);
437-
return trampolineBB;
438-
};
439-
436+
auto *pbStructVal = buildPullbackValueStructValue(cbi->getParent());
440437
// Create a new `cond_br` instruction.
441-
getBuilder().createCondBranch(cbi->getLoc(), getOpValue(cbi->getCondition()),
442-
createTrampolineBasicBlock(cbi->getTrueBB()),
443-
createTrampolineBasicBlock(cbi->getFalseBB()));
438+
getBuilder().createCondBranch(
439+
cbi->getLoc(), getOpValue(cbi->getCondition()),
440+
createTrampolineBasicBlock(cbi, pbStructVal, cbi->getTrueBB()),
441+
createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB()));
444442
}
445443

446444
void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) {
447445
// Build pullback struct value for original block.
448-
auto *origBB = sei->getParent();
449-
auto *pbStructVal = buildPullbackValueStructValue(sei);
450-
451-
// Creates a trampoline block for given original successor block. The
452-
// trampoline block has the same arguments as the VJP successor block but
453-
// drops the last predecessor enum argument. The generated `switch_enum`
454-
// instruction branches to the trampoline block, and the trampoline block
455-
// constructs a predecessor enum value and branches to the VJP successor
456-
// block.
457-
auto createTrampolineBasicBlock =
458-
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
459-
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
460-
// Create the trampoline block.
461-
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
462-
for (auto *destArg : vjpSuccBB->getArguments().drop_back())
463-
trampolineBB->createPhiArgument(destArg->getType(),
464-
destArg->getOwnershipKind());
465-
// Build predecessor enum value for successor block and branch to it.
466-
SILBuilder trampolineBuilder(trampolineBB);
467-
auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB,
468-
origSuccBB, pbStructVal);
469-
SmallVector<SILValue, 4> forwardedArguments(
470-
trampolineBB->getArguments().begin(),
471-
trampolineBB->getArguments().end());
472-
forwardedArguments.push_back(succEnumVal);
473-
trampolineBuilder.createBranch(sei->getLoc(), vjpSuccBB,
474-
forwardedArguments);
475-
return trampolineBB;
476-
};
446+
auto *pbStructVal = buildPullbackValueStructValue(sei->getParent());
477447

478448
// Create trampoline successor basic blocks.
479449
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
480450
for (unsigned i : range(sei->getNumCases())) {
481451
auto caseBB = sei->getCase(i);
482-
auto *trampolineBB = createTrampolineBasicBlock(caseBB.second);
452+
auto *trampolineBB =
453+
createTrampolineBasicBlock(sei, pbStructVal, caseBB.second);
483454
caseBBs.push_back({caseBB.first, trampolineBB});
484455
}
485456
// Create trampoline default basic block.
486457
SILBasicBlock *newDefaultBB = nullptr;
487458
if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull())
488-
newDefaultBB = createTrampolineBasicBlock(defaultBB);
459+
newDefaultBB = createTrampolineBasicBlock(sei, pbStructVal, defaultBB);
489460

490461
// Create a new `switch_enum` instruction.
491462
switch (sei->getKind()) {
@@ -510,6 +481,47 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
510481
visitSwitchEnumInstBase(seai);
511482
}
512483

484+
void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
485+
// Build pullback struct value for original block.
486+
auto *pbStructVal = buildPullbackValueStructValue(ccbi->getParent());
487+
// Create a new `checked_cast_branch` instruction.
488+
getBuilder().createCheckedCastBranch(
489+
ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()),
490+
getOpType(ccbi->getTargetLoweredType()),
491+
getOpASTType(ccbi->getTargetFormalType()),
492+
createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getSuccessBB()),
493+
createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getFailureBB()),
494+
ccbi->getTrueBBCount(), ccbi->getFalseBBCount());
495+
}
496+
497+
void VJPEmitter::visitCheckedCastValueBranchInst(
498+
CheckedCastValueBranchInst *ccvbi) {
499+
// Build pullback struct value for original block.
500+
auto *pbStructVal = buildPullbackValueStructValue(ccvbi->getParent());
501+
// Create a new `checked_cast_value_branch` instruction.
502+
getBuilder().createCheckedCastValueBranch(
503+
ccvbi->getLoc(), getOpValue(ccvbi->getOperand()),
504+
getOpASTType(ccvbi->getSourceFormalType()),
505+
getOpType(ccvbi->getTargetLoweredType()),
506+
getOpASTType(ccvbi->getTargetFormalType()),
507+
createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getSuccessBB()),
508+
createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB()));
509+
}
510+
511+
void VJPEmitter::visitCheckedCastAddrBranchInst(
512+
CheckedCastAddrBranchInst *ccabi) {
513+
// Build pullback struct value for original block.
514+
auto *pbStructVal = buildPullbackValueStructValue(ccabi->getParent());
515+
// Create a new `checked_cast_addr_branch` instruction.
516+
getBuilder().createCheckedCastAddrBranch(
517+
ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()),
518+
getOpASTType(ccabi->getSourceFormalType()), getOpValue(ccabi->getDest()),
519+
getOpASTType(ccabi->getTargetFormalType()),
520+
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()),
521+
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()),
522+
ccabi->getTrueBBCount(), ccabi->getFalseBBCount());
523+
}
524+
513525
void VJPEmitter::visitApplyInst(ApplyInst *ai) {
514526
// If callee should not be differentiated, do standard cloning.
515527
if (!pullbackInfo.shouldDifferentiateApplySite(ai)) {

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,12 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
152152
// Diagnose unsupported branching terminators.
153153
for (auto &bb : *original) {
154154
auto *term = bb.getTerminator();
155-
// Supported terminators are: `br`, `cond_br`, `switch_enum`,
156-
// `switch_enum_addr`.
155+
// Check supported branching terminators.
157156
if (isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
158-
isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term))
157+
isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term) ||
158+
isa<CheckedCastBranchInst>(term) ||
159+
isa<CheckedCastValueBranchInst>(term) ||
160+
isa<CheckedCastAddrBranchInst>(term))
159161
continue;
160162
// If terminator is an unsupported branching terminator, emit an error.
161163
if (term->isBranch()) {

0 commit comments

Comments
 (0)