Skip to content

Commit 9fe0ed2

Browse files
authored
[AutoDiff] Fix class differentiation. (swiftlang#33151)
Fix pullback generation for `unchecked_ref_cast` and `upcast` instructions for class types with an address tangent value category. Upstream end-to-end class differentiation tests.
1 parent deb0254 commit 9fe0ed2

File tree

2 files changed

+553
-4
lines changed

2 files changed

+553
-4
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,8 +1459,21 @@ class PullbackCloner::Implementation final
14591459
assert(getRemappedTangentType(urci->getOperand()->getType()) ==
14601460
getRemappedTangentType(urci->getType()) &&
14611461
"Operand/result must have the same `TangentVector` type");
1462-
auto adj = getAdjointValue(bb, urci);
1463-
addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
1462+
switch (getTangentValueCategory(urci)) {
1463+
case SILValueCategory::Object: {
1464+
auto adj = getAdjointValue(bb, urci);
1465+
addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
1466+
break;
1467+
}
1468+
case SILValueCategory::Address: {
1469+
auto &adjDest = getAdjointBuffer(bb, urci);
1470+
auto destType = remapType(adjDest->getType());
1471+
addToAdjointBuffer(bb, urci->getOperand(), adjDest, urci->getLoc());
1472+
builder.emitDestroyAddrAndFold(urci->getLoc(), adjDest);
1473+
emitZeroIndirect(destType.getASTType(), adjDest, urci->getLoc());
1474+
break;
1475+
}
1476+
}
14641477
}
14651478

14661479
/// Handle `upcast` instruction.
@@ -1473,8 +1486,21 @@ class PullbackCloner::Implementation final
14731486
assert(getRemappedTangentType(ui->getOperand()->getType()) ==
14741487
getRemappedTangentType(ui->getType()) &&
14751488
"Operand/result must have the same `TangentVector` type");
1476-
auto adj = getAdjointValue(bb, ui);
1477-
addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
1489+
switch (getTangentValueCategory(ui)) {
1490+
case SILValueCategory::Object: {
1491+
auto adj = getAdjointValue(bb, ui);
1492+
addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
1493+
break;
1494+
}
1495+
case SILValueCategory::Address: {
1496+
auto &adjDest = getAdjointBuffer(bb, ui);
1497+
auto destType = remapType(adjDest->getType());
1498+
addToAdjointBuffer(bb, ui->getOperand(), adjDest, ui->getLoc());
1499+
builder.emitDestroyAddrAndFold(ui->getLoc(), adjDest);
1500+
emitZeroIndirect(destType.getASTType(), adjDest, ui->getLoc());
1501+
break;
1502+
}
1503+
}
14781504
}
14791505

14801506
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);

0 commit comments

Comments
 (0)