Skip to content

Commit 3c3093d

Browse files
committed
[AutoDiff] Clean up VJP basic block utilities.
Add a common helper function `VJPEmitter::createTrampolineBasicBlock`. Change `VJPEmitter::buildPullbackValueStructValue` to take an original basic block instead of a terminator instruction.
1 parent a676a37 commit 3c3093d

File tree

2 files changed

+57
-74
lines changed

2 files changed

+57
-74
lines changed

include/swift/SILOptimizer/Differentiation/VJPEmitter.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,21 @@ class VJPEmitter final
117117
/// Get the lowered SIL type of the given nominal type declaration.
118118
SILType getNominalDeclLoweredType(NominalTypeDecl *nominal);
119119

120-
/// Build a pullback struct value for the original block corresponding to the
121-
/// given terminator.
122-
StructInst *buildPullbackValueStructValue(TermInst *termInst);
120+
// Creates a trampoline block for given original terminator instruction, the
121+
// pullback struct value for its parent block, and a successor basic block.
122+
//
123+
// The trampoline block has the same arguments as and branches to the remapped
124+
// successor block, but drops the last predecessor enum argument.
125+
//
126+
// Used for cloning branching terminator instructions with specific
127+
// requirements on successor block arguments, where an additional predecessor
128+
// enum argument is not acceptable.
129+
SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst,
130+
StructInst *pbStructVal,
131+
SILBasicBlock *succBB);
132+
133+
/// Build a pullback struct value for the given original block.
134+
StructInst *buildPullbackValueStructValue(SILBasicBlock *bb);
123135

124136
/// Build a predecessor enum instance using the given builder for the given
125137
/// original predecessor/successor blocks and pullback struct value.

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 42 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,31 @@ SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) {
265265
return vjpBB;
266266
}
267267

268+
SILBasicBlock *VJPEmitter::createTrampolineBasicBlock(TermInst *termInst,
269+
StructInst *pbStructVal,
270+
SILBasicBlock *succBB) {
271+
assert(llvm::find(termInst->getSuccessorBlocks(), succBB) !=
272+
termInst->getSuccessorBlocks().end() &&
273+
"Basic block is not a successor of terminator instruction");
274+
// Create the trampoline block.
275+
auto *vjpSuccBB = getOpBasicBlock(succBB);
276+
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
277+
for (auto *arg : vjpSuccBB->getArguments().drop_back())
278+
trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind());
279+
// In the trampoline block, build predecessor enum value for VJP successor
280+
// block and branch to it.
281+
SILBuilder trampolineBuilder(trampolineBB);
282+
auto *origBB = termInst->getParent();
283+
auto *succEnumVal =
284+
buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbStructVal);
285+
SmallVector<SILValue, 4> forwardedArguments(
286+
trampolineBB->getArguments().begin(), trampolineBB->getArguments().end());
287+
forwardedArguments.push_back(succEnumVal);
288+
trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB,
289+
forwardedArguments);
290+
return trampolineBB;
291+
}
292+
268293
void VJPEmitter::visit(SILInstruction *inst) {
269294
if (errorOccurred)
270295
return;
@@ -290,10 +315,9 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) {
290315
return getLoweredType(nominalType);
291316
}
292317

293-
StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) {
294-
assert(termInst->getFunction() == original);
295-
auto loc = termInst->getFunction()->getLocation();
296-
auto *origBB = termInst->getParent();
318+
StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) {
319+
assert(origBB->getParent() == original);
320+
auto loc = origBB->getParent()->getLocation();
297321
auto *vjpBB = BBMap[origBB];
298322
auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB);
299323
auto structLoweredTy = getNominalDeclLoweredType(pbStruct);
@@ -333,9 +357,11 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder,
333357

334358
void VJPEmitter::visitReturnInst(ReturnInst *ri) {
335359
auto loc = ri->getOperand().getLoc();
336-
auto *origExit = ri->getParent();
337360
auto &builder = getBuilder();
338-
auto *pbStructVal = buildPullbackValueStructValue(ri);
361+
362+
// Build pullback struct value for original block.
363+
auto *origExit = ri->getParent();
364+
auto *pbStructVal = buildPullbackValueStructValue(origExit);
339365

340366
// Get the value in the VJP corresponding to the original result.
341367
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator());
@@ -390,7 +416,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
390416
// Build pullback struct value for original block.
391417
// Build predecessor enum value for destination block.
392418
auto *origBB = bi->getParent();
393-
auto *pbStructVal = buildPullbackValueStructValue(bi);
419+
auto *pbStructVal = buildPullbackValueStructValue(origBB);
394420
auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB,
395421
bi->getDestBB(), pbStructVal);
396422

@@ -407,85 +433,30 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
407433

408434
void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) {
409435
// Build pullback struct value for original block.
410-
// Build predecessor enum values for true/false blocks.
411-
auto *origBB = cbi->getParent();
412-
auto *pbStructVal = buildPullbackValueStructValue(cbi);
413-
414-
// Creates a trampoline block for given original successor block. The
415-
// trampoline block has the same arguments as the VJP successor block but
416-
// drops the last predecessor enum argument. The generated `switch_enum`
417-
// instruction branches to the trampoline block, and the trampoline block
418-
// constructs a predecessor enum value and branches to the VJP successor
419-
// block.
420-
auto createTrampolineBasicBlock =
421-
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
422-
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
423-
// Create the trampoline block.
424-
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
425-
for (auto *arg : vjpSuccBB->getArguments().drop_back())
426-
trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind());
427-
// Build predecessor enum value for successor block and branch to it.
428-
SILBuilder trampolineBuilder(trampolineBB);
429-
auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB,
430-
origSuccBB, pbStructVal);
431-
SmallVector<SILValue, 4> forwardedArguments(
432-
trampolineBB->getArguments().begin(),
433-
trampolineBB->getArguments().end());
434-
forwardedArguments.push_back(succEnumVal);
435-
trampolineBuilder.createBranch(cbi->getLoc(), vjpSuccBB,
436-
forwardedArguments);
437-
return trampolineBB;
438-
};
439-
436+
auto *pbStructVal = buildPullbackValueStructValue(cbi->getParent());
440437
// Create a new `cond_br` instruction.
441-
getBuilder().createCondBranch(cbi->getLoc(), getOpValue(cbi->getCondition()),
442-
createTrampolineBasicBlock(cbi->getTrueBB()),
443-
createTrampolineBasicBlock(cbi->getFalseBB()));
438+
getBuilder().createCondBranch(
439+
cbi->getLoc(), getOpValue(cbi->getCondition()),
440+
createTrampolineBasicBlock(cbi, pbStructVal, cbi->getTrueBB()),
441+
createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB()));
444442
}
445443

446444
void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) {
447445
// Build pullback struct value for original block.
448-
auto *origBB = sei->getParent();
449-
auto *pbStructVal = buildPullbackValueStructValue(sei);
450-
451-
// Creates a trampoline block for given original successor block. The
452-
// trampoline block has the same arguments as the VJP successor block but
453-
// drops the last predecessor enum argument. The generated `switch_enum`
454-
// instruction branches to the trampoline block, and the trampoline block
455-
// constructs a predecessor enum value and branches to the VJP successor
456-
// block.
457-
auto createTrampolineBasicBlock =
458-
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
459-
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
460-
// Create the trampoline block.
461-
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
462-
for (auto *destArg : vjpSuccBB->getArguments().drop_back())
463-
trampolineBB->createPhiArgument(destArg->getType(),
464-
destArg->getOwnershipKind());
465-
// Build predecessor enum value for successor block and branch to it.
466-
SILBuilder trampolineBuilder(trampolineBB);
467-
auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB,
468-
origSuccBB, pbStructVal);
469-
SmallVector<SILValue, 4> forwardedArguments(
470-
trampolineBB->getArguments().begin(),
471-
trampolineBB->getArguments().end());
472-
forwardedArguments.push_back(succEnumVal);
473-
trampolineBuilder.createBranch(sei->getLoc(), vjpSuccBB,
474-
forwardedArguments);
475-
return trampolineBB;
476-
};
446+
auto *pbStructVal = buildPullbackValueStructValue(sei->getParent());
477447

478448
// Create trampoline successor basic blocks.
479449
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
480450
for (unsigned i : range(sei->getNumCases())) {
481451
auto caseBB = sei->getCase(i);
482-
auto *trampolineBB = createTrampolineBasicBlock(caseBB.second);
452+
auto *trampolineBB =
453+
createTrampolineBasicBlock(sei, pbStructVal, caseBB.second);
483454
caseBBs.push_back({caseBB.first, trampolineBB});
484455
}
485456
// Create trampoline default basic block.
486457
SILBasicBlock *newDefaultBB = nullptr;
487458
if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull())
488-
newDefaultBB = createTrampolineBasicBlock(defaultBB);
459+
newDefaultBB = createTrampolineBasicBlock(sei, pbStructVal, defaultBB);
489460

490461
// Create a new `switch_enum` instruction.
491462
switch (sei->getKind()) {

0 commit comments

Comments
 (0)