@@ -265,6 +265,31 @@ SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) {
265
265
return vjpBB;
266
266
}
267
267
268
+ SILBasicBlock *VJPEmitter::createTrampolineBasicBlock (TermInst *termInst,
269
+ StructInst *pbStructVal,
270
+ SILBasicBlock *succBB) {
271
+ assert (llvm::find (termInst->getSuccessorBlocks (), succBB) !=
272
+ termInst->getSuccessorBlocks ().end () &&
273
+ " Basic block is not a successor of terminator instruction" );
274
+ // Create the trampoline block.
275
+ auto *vjpSuccBB = getOpBasicBlock (succBB);
276
+ auto *trampolineBB = vjp->createBasicBlockBefore (vjpSuccBB);
277
+ for (auto *arg : vjpSuccBB->getArguments ().drop_back ())
278
+ trampolineBB->createPhiArgument (arg->getType (), arg->getOwnershipKind ());
279
+ // In the trampoline block, build predecessor enum value for VJP successor
280
+ // block and branch to it.
281
+ SILBuilder trampolineBuilder (trampolineBB);
282
+ auto *origBB = termInst->getParent ();
283
+ auto *succEnumVal =
284
+ buildPredecessorEnumValue (trampolineBuilder, origBB, succBB, pbStructVal);
285
+ SmallVector<SILValue, 4 > forwardedArguments (
286
+ trampolineBB->getArguments ().begin (), trampolineBB->getArguments ().end ());
287
+ forwardedArguments.push_back (succEnumVal);
288
+ trampolineBuilder.createBranch (termInst->getLoc (), vjpSuccBB,
289
+ forwardedArguments);
290
+ return trampolineBB;
291
+ }
292
+
268
293
void VJPEmitter::visit (SILInstruction *inst) {
269
294
if (errorOccurred)
270
295
return ;
@@ -290,10 +315,9 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) {
290
315
return getLoweredType (nominalType);
291
316
}
292
317
293
- StructInst *VJPEmitter::buildPullbackValueStructValue (TermInst *termInst) {
294
- assert (termInst->getFunction () == original);
295
- auto loc = termInst->getFunction ()->getLocation ();
296
- auto *origBB = termInst->getParent ();
318
+ StructInst *VJPEmitter::buildPullbackValueStructValue (SILBasicBlock *origBB) {
319
+ assert (origBB->getParent () == original);
320
+ auto loc = origBB->getParent ()->getLocation ();
297
321
auto *vjpBB = BBMap[origBB];
298
322
auto *pbStruct = pullbackInfo.getLinearMapStruct (origBB);
299
323
auto structLoweredTy = getNominalDeclLoweredType (pbStruct);
@@ -333,9 +357,11 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder,
333
357
334
358
void VJPEmitter::visitReturnInst (ReturnInst *ri) {
335
359
auto loc = ri->getOperand ().getLoc ();
336
- auto *origExit = ri->getParent ();
337
360
auto &builder = getBuilder ();
338
- auto *pbStructVal = buildPullbackValueStructValue (ri);
361
+
362
+ // Build pullback struct value for original block.
363
+ auto *origExit = ri->getParent ();
364
+ auto *pbStructVal = buildPullbackValueStructValue (origExit);
339
365
340
366
// Get the value in the VJP corresponding to the original result.
341
367
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator ());
@@ -390,7 +416,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
390
416
// Build pullback struct value for original block.
391
417
// Build predecessor enum value for destination block.
392
418
auto *origBB = bi->getParent ();
393
- auto *pbStructVal = buildPullbackValueStructValue (bi );
419
+ auto *pbStructVal = buildPullbackValueStructValue (origBB );
394
420
auto *enumVal = buildPredecessorEnumValue (getBuilder (), origBB,
395
421
bi->getDestBB (), pbStructVal);
396
422
@@ -407,85 +433,30 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
407
433
408
434
void VJPEmitter::visitCondBranchInst (CondBranchInst *cbi) {
409
435
// Build pullback struct value for original block.
410
- // Build predecessor enum values for true/false blocks.
411
- auto *origBB = cbi->getParent ();
412
- auto *pbStructVal = buildPullbackValueStructValue (cbi);
413
-
414
- // Creates a trampoline block for given original successor block. The
415
- // trampoline block has the same arguments as the VJP successor block but
416
- // drops the last predecessor enum argument. The generated `switch_enum`
417
- // instruction branches to the trampoline block, and the trampoline block
418
- // constructs a predecessor enum value and branches to the VJP successor
419
- // block.
420
- auto createTrampolineBasicBlock =
421
- [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
422
- auto *vjpSuccBB = getOpBasicBlock (origSuccBB);
423
- // Create the trampoline block.
424
- auto *trampolineBB = vjp->createBasicBlockBefore (vjpSuccBB);
425
- for (auto *arg : vjpSuccBB->getArguments ().drop_back ())
426
- trampolineBB->createPhiArgument (arg->getType (), arg->getOwnershipKind ());
427
- // Build predecessor enum value for successor block and branch to it.
428
- SILBuilder trampolineBuilder (trampolineBB);
429
- auto *succEnumVal = buildPredecessorEnumValue (trampolineBuilder, origBB,
430
- origSuccBB, pbStructVal);
431
- SmallVector<SILValue, 4 > forwardedArguments (
432
- trampolineBB->getArguments ().begin (),
433
- trampolineBB->getArguments ().end ());
434
- forwardedArguments.push_back (succEnumVal);
435
- trampolineBuilder.createBranch (cbi->getLoc (), vjpSuccBB,
436
- forwardedArguments);
437
- return trampolineBB;
438
- };
439
-
436
+ auto *pbStructVal = buildPullbackValueStructValue (cbi->getParent ());
440
437
// Create a new `cond_br` instruction.
441
- getBuilder ().createCondBranch (cbi->getLoc (), getOpValue (cbi->getCondition ()),
442
- createTrampolineBasicBlock (cbi->getTrueBB ()),
443
- createTrampolineBasicBlock (cbi->getFalseBB ()));
438
+ getBuilder ().createCondBranch (
439
+ cbi->getLoc (), getOpValue (cbi->getCondition ()),
440
+ createTrampolineBasicBlock (cbi, pbStructVal, cbi->getTrueBB ()),
441
+ createTrampolineBasicBlock (cbi, pbStructVal, cbi->getFalseBB ()));
444
442
}
445
443
446
444
void VJPEmitter::visitSwitchEnumInstBase (SwitchEnumInstBase *sei) {
447
445
// Build pullback struct value for original block.
448
- auto *origBB = sei->getParent ();
449
- auto *pbStructVal = buildPullbackValueStructValue (sei);
450
-
451
- // Creates a trampoline block for given original successor block. The
452
- // trampoline block has the same arguments as the VJP successor block but
453
- // drops the last predecessor enum argument. The generated `switch_enum`
454
- // instruction branches to the trampoline block, and the trampoline block
455
- // constructs a predecessor enum value and branches to the VJP successor
456
- // block.
457
- auto createTrampolineBasicBlock =
458
- [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
459
- auto *vjpSuccBB = getOpBasicBlock (origSuccBB);
460
- // Create the trampoline block.
461
- auto *trampolineBB = vjp->createBasicBlockBefore (vjpSuccBB);
462
- for (auto *destArg : vjpSuccBB->getArguments ().drop_back ())
463
- trampolineBB->createPhiArgument (destArg->getType (),
464
- destArg->getOwnershipKind ());
465
- // Build predecessor enum value for successor block and branch to it.
466
- SILBuilder trampolineBuilder (trampolineBB);
467
- auto *succEnumVal = buildPredecessorEnumValue (trampolineBuilder, origBB,
468
- origSuccBB, pbStructVal);
469
- SmallVector<SILValue, 4 > forwardedArguments (
470
- trampolineBB->getArguments ().begin (),
471
- trampolineBB->getArguments ().end ());
472
- forwardedArguments.push_back (succEnumVal);
473
- trampolineBuilder.createBranch (sei->getLoc (), vjpSuccBB,
474
- forwardedArguments);
475
- return trampolineBB;
476
- };
446
+ auto *pbStructVal = buildPullbackValueStructValue (sei->getParent ());
477
447
478
448
// Create trampoline successor basic blocks.
479
449
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4 > caseBBs;
480
450
for (unsigned i : range (sei->getNumCases ())) {
481
451
auto caseBB = sei->getCase (i);
482
- auto *trampolineBB = createTrampolineBasicBlock (caseBB.second );
452
+ auto *trampolineBB =
453
+ createTrampolineBasicBlock (sei, pbStructVal, caseBB.second );
483
454
caseBBs.push_back ({caseBB.first , trampolineBB});
484
455
}
485
456
// Create trampoline default basic block.
486
457
SILBasicBlock *newDefaultBB = nullptr ;
487
458
if (auto *defaultBB = sei->getDefaultBBOrNull ().getPtrOrNull ())
488
- newDefaultBB = createTrampolineBasicBlock (defaultBB);
459
+ newDefaultBB = createTrampolineBasicBlock (sei, pbStructVal, defaultBB);
489
460
490
461
// Create a new `switch_enum` instruction.
491
462
switch (sei->getKind ()) {
@@ -510,6 +481,47 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
510
481
visitSwitchEnumInstBase (seai);
511
482
}
512
483
484
+ void VJPEmitter::visitCheckedCastBranchInst (CheckedCastBranchInst *ccbi) {
485
+ // Build pullback struct value for original block.
486
+ auto *pbStructVal = buildPullbackValueStructValue (ccbi->getParent ());
487
+ // Create a new `checked_cast_branch` instruction.
488
+ getBuilder ().createCheckedCastBranch (
489
+ ccbi->getLoc (), ccbi->isExact (), getOpValue (ccbi->getOperand ()),
490
+ getOpType (ccbi->getTargetLoweredType ()),
491
+ getOpASTType (ccbi->getTargetFormalType ()),
492
+ createTrampolineBasicBlock (ccbi, pbStructVal, ccbi->getSuccessBB ()),
493
+ createTrampolineBasicBlock (ccbi, pbStructVal, ccbi->getFailureBB ()),
494
+ ccbi->getTrueBBCount (), ccbi->getFalseBBCount ());
495
+ }
496
+
497
+ void VJPEmitter::visitCheckedCastValueBranchInst (
498
+ CheckedCastValueBranchInst *ccvbi) {
499
+ // Build pullback struct value for original block.
500
+ auto *pbStructVal = buildPullbackValueStructValue (ccvbi->getParent ());
501
+ // Create a new `checked_cast_value_branch` instruction.
502
+ getBuilder ().createCheckedCastValueBranch (
503
+ ccvbi->getLoc (), getOpValue (ccvbi->getOperand ()),
504
+ getOpASTType (ccvbi->getSourceFormalType ()),
505
+ getOpType (ccvbi->getTargetLoweredType ()),
506
+ getOpASTType (ccvbi->getTargetFormalType ()),
507
+ createTrampolineBasicBlock (ccvbi, pbStructVal, ccvbi->getSuccessBB ()),
508
+ createTrampolineBasicBlock (ccvbi, pbStructVal, ccvbi->getFailureBB ()));
509
+ }
510
+
511
+ void VJPEmitter::visitCheckedCastAddrBranchInst (
512
+ CheckedCastAddrBranchInst *ccabi) {
513
+ // Build pullback struct value for original block.
514
+ auto *pbStructVal = buildPullbackValueStructValue (ccabi->getParent ());
515
+ // Create a new `checked_cast_addr_branch` instruction.
516
+ getBuilder ().createCheckedCastAddrBranch (
517
+ ccabi->getLoc (), ccabi->getConsumptionKind (), getOpValue (ccabi->getSrc ()),
518
+ getOpASTType (ccabi->getSourceFormalType ()), getOpValue (ccabi->getDest ()),
519
+ getOpASTType (ccabi->getTargetFormalType ()),
520
+ createTrampolineBasicBlock (ccabi, pbStructVal, ccabi->getSuccessBB ()),
521
+ createTrampolineBasicBlock (ccabi, pbStructVal, ccabi->getFailureBB ()),
522
+ ccabi->getTrueBBCount (), ccabi->getFalseBBCount ());
523
+ }
524
+
513
525
void VJPEmitter::visitApplyInst (ApplyInst *ai) {
514
526
// If callee should not be differentiated, do standard cloning.
515
527
if (!pullbackInfo.shouldDifferentiateApplySite (ai)) {
0 commit comments