@@ -1459,8 +1459,21 @@ class PullbackCloner::Implementation final
1459
1459
assert (getRemappedTangentType (urci->getOperand ()->getType ()) ==
1460
1460
getRemappedTangentType (urci->getType ()) &&
1461
1461
" Operand/result must have the same `TangentVector` type" );
1462
- auto adj = getAdjointValue (bb, urci);
1463
- addAdjointValue (bb, urci->getOperand (), adj, urci->getLoc ());
1462
+ switch (getTangentValueCategory (urci)) {
1463
+ case SILValueCategory::Object: {
1464
+ auto adj = getAdjointValue (bb, urci);
1465
+ addAdjointValue (bb, urci->getOperand (), adj, urci->getLoc ());
1466
+ break ;
1467
+ }
1468
+ case SILValueCategory::Address: {
1469
+ auto &adjDest = getAdjointBuffer (bb, urci);
1470
+ auto destType = remapType (adjDest->getType ());
1471
+ addToAdjointBuffer (bb, urci->getOperand (), adjDest, urci->getLoc ());
1472
+ builder.emitDestroyAddrAndFold (urci->getLoc (), adjDest);
1473
+ emitZeroIndirect (destType.getASTType (), adjDest, urci->getLoc ());
1474
+ break ;
1475
+ }
1476
+ }
1464
1477
}
1465
1478
1466
1479
// / Handle `upcast` instruction.
@@ -1473,8 +1486,21 @@ class PullbackCloner::Implementation final
1473
1486
assert (getRemappedTangentType (ui->getOperand ()->getType ()) ==
1474
1487
getRemappedTangentType (ui->getType ()) &&
1475
1488
" Operand/result must have the same `TangentVector` type" );
1476
- auto adj = getAdjointValue (bb, ui);
1477
- addAdjointValue (bb, ui->getOperand (), adj, ui->getLoc ());
1489
+ switch (getTangentValueCategory (ui)) {
1490
+ case SILValueCategory::Object: {
1491
+ auto adj = getAdjointValue (bb, ui);
1492
+ addAdjointValue (bb, ui->getOperand (), adj, ui->getLoc ());
1493
+ break ;
1494
+ }
1495
+ case SILValueCategory::Address: {
1496
+ auto &adjDest = getAdjointBuffer (bb, ui);
1497
+ auto destType = remapType (adjDest->getType ());
1498
+ addToAdjointBuffer (bb, ui->getOperand (), adjDest, ui->getLoc ());
1499
+ builder.emitDestroyAddrAndFold (ui->getLoc (), adjDest);
1500
+ emitZeroIndirect (destType.getASTType (), adjDest, ui->getLoc ());
1501
+ break ;
1502
+ }
1503
+ }
1478
1504
}
1479
1505
1480
1506
#define NOT_DIFFERENTIABLE (INST, DIAG ) void visit##INST##Inst(INST##Inst *inst);
0 commit comments