@@ -455,18 +455,6 @@ class JVPCloner::Implementation final
455
455
return ;
456
456
}
457
457
458
- // Diagnose functions with active inout arguments.
459
- // TODO(TF-129): Support `inout` argument differentiation.
460
- for (auto inoutArg : ai->getInoutArguments ()) {
461
- if (activityInfo.isActive (inoutArg, getIndices ())) {
462
- context.emitNondifferentiabilityError (
463
- ai, invoker,
464
- diag::autodiff_cannot_differentiate_through_inout_arguments);
465
- errorOccurred = true ;
466
- return ;
467
- }
468
- }
469
-
470
458
auto loc = ai->getLoc ();
471
459
auto &builder = getBuilder ();
472
460
auto origCallee = getOpValue (ai->getCallee ());
@@ -1241,6 +1229,10 @@ class JVPCloner::Implementation final
1241
1229
SmallVector<SILValue, 8 > differentialAllResults;
1242
1230
collectAllActualResultsInTypeOrder (
1243
1231
differentialCall, differentialDirectResults, differentialAllResults);
1232
+ for (auto inoutArg : ai->getInoutArguments ())
1233
+ origAllResults.push_back (inoutArg);
1234
+ for (auto inoutArg : differentialCall->getInoutArguments ())
1235
+ differentialAllResults.push_back (inoutArg);
1244
1236
assert (applyIndices.results ->getNumIndices () ==
1245
1237
differentialAllResults.size ());
1246
1238
@@ -1484,11 +1476,14 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1484
1476
auto origIndResults = original->getIndirectResults ();
1485
1477
auto diffIndResults = differential.getIndirectResults ();
1486
1478
#ifndef NDEBUG
1487
- unsigned numInoutParameters = llvm::count_if (
1488
- original->getLoweredFunctionType ()->getParameters (),
1489
- [](SILParameterInfo paramInfo) { return paramInfo.isIndirectInOut (); });
1490
- assert (origIndResults.size () + numInoutParameters == diffIndResults.size ());
1479
+ unsigned numNonWrtInoutParameters = llvm::count_if (
1480
+ range (original->getLoweredFunctionType ()->getNumParameters ()),
1481
+ [&] (unsigned i) {
1482
+ auto ¶mInfo = original->getLoweredFunctionType ()->getParameters ()[i];
1483
+ return paramInfo.isIndirectInOut () && !getIndices ().parameters ->contains (i);
1484
+ });
1491
1485
#endif
1486
+ assert (origIndResults.size () + numNonWrtInoutParameters == diffIndResults.size ());
1492
1487
for (auto &origBB : *original)
1493
1488
for (auto i : indices (origIndResults))
1494
1489
setTangentBuffer (&origBB, origIndResults[i], diffIndResults[i]);
@@ -1521,23 +1516,10 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1521
1516
auto origParams = origTy->getParameters ();
1522
1517
auto indices = witness->getSILAutoDiffIndices ();
1523
1518
1524
- // Add differential results.
1525
- Optional<SILParameterInfo> inoutDiffParam = None;
1526
- for (auto origParam : origTy->getParameters ()) {
1527
- if (!origParam.isIndirectInOut ())
1528
- continue ;
1529
- inoutDiffParam = origParam;
1530
- }
1531
-
1532
- if (inoutDiffParam) {
1533
- dfResults.push_back (
1534
- SILResultInfo (inoutDiffParam->getInterfaceType ()
1535
- ->getAutoDiffTangentSpace (lookupConformance)
1536
- ->getType ()
1537
- ->getCanonicalType (witnessCanGenSig),
1538
- ResultConvention::Indirect));
1539
- } else {
1540
- for (auto resultIndex : indices.results ->getIndices ()) {
1519
+
1520
+ for (auto resultIndex : indices.results ->getIndices ()) {
1521
+ if (resultIndex < origTy->getNumResults ()) {
1522
+ // Handle formal original result.
1541
1523
auto origResult = origTy->getResults ()[resultIndex];
1542
1524
origResult = origResult.getWithInterfaceType (
1543
1525
origResult.getInterfaceType ()->getCanonicalType (witnessCanGenSig));
@@ -1548,6 +1530,25 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1548
1530
->getCanonicalType (witnessCanGenSig),
1549
1531
origResult.getConvention ()));
1550
1532
}
1533
+ else {
1534
+ // Handle original `inout` parameter.
1535
+ auto inoutParamIndex = resultIndex - origTy->getNumResults ();
1536
+ auto inoutParamIt = std::next (
1537
+ origTy->getIndirectMutatingParameters ().begin (), inoutParamIndex);
1538
+ auto paramIndex =
1539
+ std::distance (origTy->getParameters ().begin (), &*inoutParamIt);
1540
+ // If the original `inout` parameter is a differentiability parameter, then
1541
+ // it already has a corresponding differential parameter. Skip adding a
1542
+ // corresponding differential result.
1543
+ if (indices.parameters ->contains (paramIndex))
1544
+ continue ;
1545
+ auto inoutParam = origTy->getParameters ()[paramIndex];
1546
+ auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
1547
+ lookupConformance);
1548
+ assert (paramTan && " Parameter type does not have a tangent space?" );
1549
+ dfResults.push_back (
1550
+ {paramTan->getCanonicalType (), ResultConvention::Indirect});
1551
+ }
1551
1552
}
1552
1553
1553
1554
// Add differential parameters for the requested wrt parameters.
0 commit comments