@@ -134,19 +134,6 @@ class JVPCloner::Implementation final
134134 // General utilities
135135 // --------------------------------------------------------------------------//
136136
137- SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint () {
138- // If there are no local allocations, insert at the beginning of the tangent
139- // entry.
140- if (differentialLocalAllocations.empty ())
141- return getDifferential ().getEntryBlock ()->begin ();
142- // Otherwise, insert before the last local allocation. Inserting before
143- // rather than after ensures that allocation and zero initialization
144- // instructions are grouped together.
145- auto lastLocalAlloc = differentialLocalAllocations.back ();
146- auto it = lastLocalAlloc->getDefiningInstruction ()->getIterator ();
147- return it;
148- }
149-
150137 // / Get the lowered SIL type of the given AST type.
151138 SILType getLoweredType (Type type) {
152139 auto jvpGenSig = jvp->getLoweredFunctionType ()->getSubstGenericSignature ();
@@ -309,6 +296,8 @@ class JVPCloner::Implementation final
309296 // Tangent buffer mapping
310297 // --------------------------------------------------------------------------//
311298
299+ // / Sets the tangent buffer for the original buffer. Asserts that the
300+ // / original buffer does not already have a tangent buffer.
312301 void setTangentBuffer (SILBasicBlock *origBB, SILValue originalBuffer,
313302 SILValue tangentBuffer) {
314303 assert (originalBuffer->getType ().isAddress ());
@@ -318,13 +307,14 @@ class JVPCloner::Implementation final
318307 (void )insertion;
319308 }
320309
310+ // / Returns the tangent buffer for the original buffer. Asserts that the
311+ // / original buffer has a tangent buffer.
321312 SILValue &getTangentBuffer (SILBasicBlock *origBB, SILValue originalBuffer) {
322313 assert (originalBuffer->getType ().isAddress ());
323314 assert (originalBuffer->getFunction () == original);
324- auto insertion =
325- bufferMap.try_emplace ({origBB, originalBuffer}, SILValue ());
326- assert (!insertion.second && " Tangent buffer should already exist" );
327- return insertion.first ->getSecond ();
315+ auto it = bufferMap.find ({origBB, originalBuffer});
316+ assert (it != bufferMap.end () && " Tangent buffer should already exist" );
317+ return it->getSecond ();
328318 }
329319
330320 // --------------------------------------------------------------------------//
@@ -446,9 +436,21 @@ class JVPCloner::Implementation final
446436 // If an `apply` has active results or active inout parameters, replace it
447437 // with an `apply` of its JVP.
448438 void visitApplyInst (ApplyInst *ai) {
439+ bool shouldDifferentiate =
440+ differentialInfo.shouldDifferentiateApplySite (ai);
441+ // If the function has no active arguments or results, zero-initialize the
442+ // tangent buffers of the active indirect results.
443+ if (!shouldDifferentiate) {
444+ for (auto indResult : ai->getIndirectSILResults ())
445+ if (activityInfo.isActive (indResult, getIndices ())) {
446+ auto &tanBuf = getTangentBuffer (ai->getParent (), indResult);
447+ emitZeroIndirect (tanBuf->getType ().getASTType (), tanBuf,
448+ tanBuf.getLoc ());
449+ }
450+ }
449451 // If the function should not be differentiated or its the array literal
450452 // initialization intrinsic, just do standard cloning.
451- if (!differentialInfo. shouldDifferentiateApplySite (ai) ||
453+ if (!shouldDifferentiate ||
452454 ArraySemanticsCall (ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) {
453455 LLVM_DEBUG (getADDebugStream () << " No active results:\n " << *ai << ' \n ' );
454456 TypeSubstCloner::visitApplyInst (ai);
@@ -789,7 +791,7 @@ class JVPCloner::Implementation final
789791 auto &diffBuilder = getDifferentialBuilder ();
790792 auto loc = dvi->getLoc ();
791793 auto tanVal = materializeTangent (getTangentValue (dvi->getOperand ()), loc);
792- diffBuilder.emitDestroyValue (loc, tanVal);
794+ diffBuilder.emitDestroyValueOperation (loc, tanVal);
793795 }
794796
795797 CLONE_AND_EMIT_TANGENT (CopyValue, cvi) {
@@ -804,7 +806,20 @@ class JVPCloner::Implementation final
804806 // / Handle `load` instruction.
805807 // / Original: y = load x
806808 // / Tangent: tan[y] = load tan[x]
807- CLONE_AND_EMIT_TANGENT (Load, li) {
809+ void visitLoadInst (LoadInst *li) {
810+ TypeSubstCloner::visitLoadInst (li);
811+ // If an active buffer is loaded with take to a non-active value, destroy
812+ // the active buffer's tangent buffer.
813+ if (!differentialInfo.shouldDifferentiateInstruction (li)) {
814+ auto isTake =
815+ (li->getOwnershipQualifier () == LoadOwnershipQualifier::Take);
816+ if (isTake && activityInfo.isActive (li->getOperand (), getIndices ())) {
817+ auto &tanBuf = getTangentBuffer (li->getParent (), li->getOperand ());
818+ getDifferentialBuilder ().emitDestroyOperation (tanBuf.getLoc (), tanBuf);
819+ }
820+ return ;
821+ }
822+ // Otherwise, do standard differential cloning.
808823 auto &diffBuilder = getDifferentialBuilder ();
809824 auto *bb = li->getParent ();
810825 auto loc = li->getLoc ();
@@ -829,7 +844,19 @@ class JVPCloner::Implementation final
829844 // / Handle `store` instruction in the differential.
830845 // / Original: store x to y
831846 // / Tangent: store tan[x] to tan[y]
832- CLONE_AND_EMIT_TANGENT (Store, si) {
847+ void visitStoreInst (StoreInst *si) {
848+ TypeSubstCloner::visitStoreInst (si);
849+ // If a non-active value is stored into an active buffer, zero-initialize
850+ // the active buffer's tangent buffer.
851+ if (!differentialInfo.shouldDifferentiateInstruction (si)) {
852+ if (activityInfo.isActive (si->getDest (), getIndices ())) {
853+ auto &tanBufDest = getTangentBuffer (si->getParent (), si->getDest ());
854+ emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
855+ tanBufDest.getLoc ());
856+ }
857+ return ;
858+ }
859+ // Otherwise, do standard differential cloning.
833860 auto &diffBuilder = getDifferentialBuilder ();
834861 auto loc = si->getLoc ();
835862 auto tanValSrc = materializeTangent (getTangentValue (si->getSrc ()), loc);
@@ -841,7 +868,19 @@ class JVPCloner::Implementation final
841868 // / Handle `store_borrow` instruction in the differential.
842869 // / Original: store_borrow x to y
843870 // / Tangent: store_borrow tan[x] to tan[y]
844- CLONE_AND_EMIT_TANGENT (StoreBorrow, sbi) {
871+ void visitStoreBorrowInst (StoreBorrowInst *sbi) {
872+ TypeSubstCloner::visitStoreBorrowInst (sbi);
873+ // If a non-active value is stored into an active buffer, zero-initialize
874+ // the active buffer's tangent buffer.
875+ if (!differentialInfo.shouldDifferentiateInstruction (sbi)) {
876+ if (activityInfo.isActive (sbi->getDest (), getIndices ())) {
877+ auto &tanBufDest = getTangentBuffer (sbi->getParent (), sbi->getDest ());
878+ emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
879+ tanBufDest.getLoc ());
880+ }
881+ return ;
882+ }
883+ // Otherwise, do standard differential cloning.
845884 auto &diffBuilder = getDifferentialBuilder ();
846885 auto loc = sbi->getLoc ();
847886 auto tanValSrc = materializeTangent (getTangentValue (sbi->getSrc ()), loc);
@@ -852,13 +891,32 @@ class JVPCloner::Implementation final
852891 // / Handle `copy_addr` instruction.
853892 // / Original: copy_addr x to y
854893 // / Tangent: copy_addr tan[x] to tan[y]
855- CLONE_AND_EMIT_TANGENT (CopyAddr, cai) {
894+ void visitCopyAddrInst (CopyAddrInst *cai) {
895+ TypeSubstCloner::visitCopyAddrInst (cai);
896+ // If a non-active buffer is copied into an active buffer, zero-initialize
897+ // the destination buffer's tangent buffer.
898+ // If an active buffer is copied with take into a non-active buffer, destroy
899+ // the source buffer's tangent buffer.
900+ if (!differentialInfo.shouldDifferentiateInstruction (cai)) {
901+ if (activityInfo.isActive (cai->getDest (), getIndices ())) {
902+ auto &tanBufDest = getTangentBuffer (cai->getParent (), cai->getDest ());
903+ emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
904+ tanBufDest.getLoc ());
905+ }
906+ if (cai->isTakeOfSrc () &&
907+ activityInfo.isActive (cai->getSrc (), getIndices ())) {
908+ auto &tanBufSrc = getTangentBuffer (cai->getParent (), cai->getSrc ());
909+ getDifferentialBuilder ().emitDestroyOperation (tanBufSrc.getLoc (),
910+ tanBufSrc);
911+ }
912+ return ;
913+ }
914+ // Otherwise, do standard differential cloning.
856915 auto diffBuilder = getDifferentialBuilder ();
857916 auto loc = cai->getLoc ();
858917 auto *bb = cai->getParent ();
859918 auto &tanSrc = getTangentBuffer (bb, cai->getSrc ());
860919 auto tanDest = getTangentBuffer (bb, cai->getDest ());
861-
862920 diffBuilder.createCopyAddr (loc, tanSrc, tanDest, cai->isTakeOfSrc (),
863921 cai->isInitializationOfDest ());
864922 }
@@ -918,8 +976,8 @@ class JVPCloner::Implementation final
918976 auto &diffBuilder = getDifferentialBuilder ();
919977 auto *bb = eai->getParent ();
920978 auto loc = eai->getLoc ();
921- auto tanSrc = getTangentBuffer (bb, eai->getOperand ());
922- diffBuilder.createEndAccess (loc, tanSrc , eai->isAborting ());
979+ auto tanOperand = getTangentBuffer (bb, eai->getOperand ());
980+ diffBuilder.createEndAccess (loc, tanOperand , eai->isAborting ());
923981 }
924982
925983 // / Handle `alloc_stack` instruction.
@@ -930,7 +988,7 @@ class JVPCloner::Implementation final
930988 auto *mappedAllocStackInst = diffBuilder.createAllocStack (
931989 asi->getLoc (), getRemappedTangentType (asi->getElementType ()),
932990 asi->getVarInfo ());
933- bufferMap. try_emplace ({ asi->getParent (), asi} , mappedAllocStackInst);
991+ setTangentBuffer ( asi->getParent (), asi, mappedAllocStackInst);
934992 }
935993
936994 // / Handle `dealloc_stack` instruction.
@@ -1062,16 +1120,15 @@ class JVPCloner::Implementation final
10621120 auto tanType = getRemappedTangentType (tei->getType ());
10631121 auto tanSource =
10641122 materializeTangent (getTangentValue (tei->getOperand ()), loc);
1065- SILValue tanBuf;
1066- // If the tangent buffer of the source does not have a tuple type, then
1123+ // If the tangent value of the source does not have a tuple type, then
10671124 // it must represent a "single element tuple type". Use it directly.
10681125 if (!tanSource->getType ().is <TupleType>()) {
10691126 setTangentValue (tei->getParent (), tei,
10701127 makeConcreteTangentValue (tanSource));
10711128 } else {
1072- tanBuf =
1129+ auto tanElt =
10731130 diffBuilder.createTupleExtract (loc, tanSource, tanIndex, tanType);
1074- bufferMap. try_emplace ({ tei->getParent (), tei}, tanBuf );
1131+ setTangentValue ( tei->getParent (), tei, makeConcreteTangentValue (tanElt) );
10751132 }
10761133 }
10771134
@@ -1100,7 +1157,7 @@ class JVPCloner::Implementation final
11001157 tanBuf = diffBuilder.createTupleElementAddr (teai->getLoc (), tanSource,
11011158 tanIndex, tanType);
11021159 }
1103- bufferMap. try_emplace ({ teai->getParent (), teai} , tanBuf);
1160+ setTangentBuffer ( teai->getParent (), teai, tanBuf);
11041161 }
11051162
11061163 // / Handle `destructure_tuple` instruction.
@@ -1282,9 +1339,8 @@ class JVPCloner::Implementation final
12821339 // Collect original results.
12831340 SmallVector<SILValue, 2 > originalResults;
12841341 collectAllDirectResultsInTypeOrder (*original, originalResults);
1285- // Collect differential return elements .
1342+ // Collect differential direct results .
12861343 SmallVector<SILValue, 8 > retElts;
1287- // for (auto origResult : originalResults) {
12881344 for (auto i : range (originalResults.size ())) {
12891345 auto origResult = originalResults[i];
12901346 if (!getIndices ().results ->contains (i))
@@ -1401,7 +1457,10 @@ JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB,
14011457void JVPCloner::Implementation::prepareForDifferentialGeneration () {
14021458 // Create differential blocks and arguments.
14031459 auto &differential = getDifferential ();
1460+ auto diffLoc = differential.getLocation ();
14041461 auto *origEntry = original->getEntryBlock ();
1462+ auto origFnTy = original->getLoweredFunctionType ();
1463+
14051464 for (auto &origBB : *original) {
14061465 auto *diffBB = differential.createBasicBlock ();
14071466 diffBBMap.insert ({&origBB, diffBB});
@@ -1482,21 +1541,51 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
14821541 << " as the tangent of original result " << *origArg);
14831542 }
14841543
1485- // Initialize tangent mapping for indirect results.
1486- auto origIndResults = original->getIndirectResults ();
1544+ // Initialize tangent mapping for original indirect results and non-wrt
1545+ // `inout` parameters. The tangent buffers of these address values are
1546+ // differential indirect results.
1547+
1548+ // Collect original results.
1549+ SmallVector<SILValue, 2 > originalResults;
1550+ collectAllFormalResultsInTypeOrder (*original, originalResults);
1551+
1552+ // Iterate over differentiability results.
1553+ differentialBuilder.setInsertionPoint (differential.getEntryBlock ());
14871554 auto diffIndResults = differential.getIndirectResults ();
1488- #ifndef NDEBUG
1489- unsigned numNonWrtInoutParameters = llvm::count_if (
1490- range (original->getLoweredFunctionType ()->getNumParameters ()),
1491- [&] (unsigned i) {
1492- auto ¶mInfo = original->getLoweredFunctionType ()->getParameters ()[i];
1493- return paramInfo.isIndirectInOut () && !getIndices ().parameters ->contains (i);
1494- });
1495- #endif
1496- assert (origIndResults.size () + numNonWrtInoutParameters == diffIndResults.size ());
1497- for (auto &origBB : *original)
1498- for (auto i : indices (origIndResults))
1499- setTangentBuffer (&origBB, origIndResults[i], diffIndResults[i]);
1555+ unsigned differentialIndirectResultIndex = 0 ;
1556+ for (auto resultIndex : getIndices ().results ->getIndices ()) {
1557+ auto origResult = originalResults[resultIndex];
1558+ // Handle original formal indirect result.
1559+ if (resultIndex < origFnTy->getNumResults ()) {
1560+ // Skip original direct results.
1561+ if (origResult->getType ().isObject ())
1562+ continue ;
1563+ auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1564+ setTangentBuffer (origEntry, origResult, diffIndResult);
1565+ // If original indirect result is non-varied, zero-initialize its tangent
1566+ // buffer.
1567+ if (!activityInfo.isVaried (origResult, getIndices ().parameters ))
1568+ emitZeroIndirect (diffIndResult->getType ().getASTType (),
1569+ diffIndResult, diffLoc);
1570+ continue ;
1571+ }
1572+ // Handle original non-wrt `inout` parameter.
1573+ // Only original *non-wrt* `inout` parameters have corresponding
1574+ // differential indirect results.
1575+ auto inoutParamIndex = resultIndex - origFnTy->getNumResults ();
1576+ auto inoutParamIt = std::next (
1577+ origFnTy->getIndirectMutatingParameters ().begin (), inoutParamIndex);
1578+ auto paramIndex =
1579+ std::distance (origFnTy->getParameters ().begin (), &*inoutParamIt);
1580+ if (getIndices ().parameters ->contains (paramIndex))
1581+ continue ;
1582+ auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1583+ setTangentBuffer (origEntry, origResult, diffIndResult);
1584+ // Original `inout` parameters are initialized, so their tangent buffers
1585+ // must also be initialized.
1586+ emitZeroIndirect (diffIndResult->getType ().getASTType (),
1587+ diffIndResult, diffLoc);
1588+ }
15001589}
15011590
15021591/* static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential (
@@ -1526,7 +1615,6 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
15261615 auto origParams = origTy->getParameters ();
15271616 auto indices = witness->getSILAutoDiffIndices ();
15281617
1529-
15301618 for (auto resultIndex : indices.results ->getIndices ()) {
15311619 if (resultIndex < origTy->getNumResults ()) {
15321620 // Handle formal original result.
@@ -1539,17 +1627,16 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
15391627 ->getType ()
15401628 ->getCanonicalType (witnessCanGenSig),
15411629 origResult.getConvention ()));
1542- }
1543- else {
1630+ } else {
15441631 // Handle original `inout` parameter.
15451632 auto inoutParamIndex = resultIndex - origTy->getNumResults ();
15461633 auto inoutParamIt = std::next (
15471634 origTy->getIndirectMutatingParameters ().begin (), inoutParamIndex);
15481635 auto paramIndex =
15491636 std::distance (origTy->getParameters ().begin (), &*inoutParamIt);
1550- // If the original `inout` parameter is a differentiability parameter, then
1551- // it already has a corresponding differential parameter. Skip adding a
1552- // corresponding differential result.
1637+ // If the original `inout` parameter is a differentiability parameter,
1638+ // then it already has a corresponding differential parameter. Do not add
1639+ // a corresponding differential result.
15531640 if (indices.parameters ->contains (paramIndex))
15541641 continue ;
15551642 auto inoutParam = origTy->getParameters ()[paramIndex];
0 commit comments