@@ -2052,6 +2052,8 @@ bool PullbackCloner::Implementation::run() {
20522052 SmallVector<SILValue, 8 > retElts;
20532053 // This vector will contain all indirect parameter adjoint buffers.
20542054 SmallVector<SILValue, 4 > indParamAdjoints;
2055+ // This vector will identify the locations where initialization is needed.
2056+ SmallBitVector outputsToInitialize;
20552057
20562058 auto conv = getOriginal ().getConventions ();
20572059 auto origParams = getOriginal ().getArgumentsWithoutIndirectResults ();
@@ -2071,25 +2073,62 @@ bool PullbackCloner::Implementation::run() {
20712073 case SILValueCategory::Address: {
20722074 auto adjBuf = getAdjointBuffer (origEntry, origParam);
20732075 indParamAdjoints.push_back (adjBuf);
2076+ outputsToInitialize.push_back (
2077+ !conv.getParameters ()[parameterIndex].isIndirectMutating ());
20742078 break ;
20752079 }
20762080 }
20772081 };
2082+ SmallVector<SILArgument *, 4 > pullbackIndirectResults (
2083+ getPullback ().getIndirectResults ().begin (),
2084+ getPullback ().getIndirectResults ().end ());
2085+
20782086 // Collect differentiation parameter adjoints.
2087+ // Do a first pass to collect non-inout values.
2088+ unsigned pullbackInoutArgumentIndex = 0 ;
2089+ for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2090+ auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2091+ if (!isParameterInout) {
2092+ addRetElt (i);
2093+ }
2094+ }
2095+
2096+ // Do a second pass for all inout parameters.
20792097 for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2080- // Skip `inout` parameters.
2081- if (conv.getParameters ()[i].isIndirectMutating ())
2098+ // Skip non-inout parameters.
2099+ auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2100+ if (!isParameterInout)
20822101 continue ;
2102+
2103+ // Skip `inout` parameters for functions with a single basic block:
2104+ // adjoint accumulation for those parameters is already done by
2105+ // per-instruction visitors.
2106+ if (getOriginal ().size () == 1 )
2107+ continue ;
2108+
2109+ // For functions with multiple basic blocks, accumulation is needed
2110+ // for `inout` parameters because pullback basic blocks have different
2111+ // adjoint buffers.
2112+ auto pullbackInoutArgument =
2113+ getPullback ()
2114+ .getArgumentsWithoutIndirectResults ()[pullbackInoutArgumentIndex++];
2115+ pullbackIndirectResults.push_back (pullbackInoutArgument);
20832116 addRetElt (i);
20842117 }
20852118
20862119 // Copy them to adjoint indirect results.
2087- assert (indParamAdjoints.size () == getPullback (). getIndirectResults () .size () &&
2120+ assert (indParamAdjoints.size () == pullbackIndirectResults .size () &&
20882121 " Indirect parameter adjoint count mismatch" );
2089- for (auto pair : zip (indParamAdjoints, getPullback ().getIndirectResults ())) {
2122+ unsigned currentIndex = 0 ;
2123+ for (auto pair : zip (indParamAdjoints, pullbackIndirectResults)) {
20902124 auto source = std::get<0 >(pair);
20912125 auto *dest = std::get<1 >(pair);
2092- builder.createCopyAddr (pbLoc, source, dest, IsTake, IsInitialization);
2126+ if (outputsToInitialize[currentIndex]) {
2127+ builder.createCopyAddr (pbLoc, source, dest, IsTake, IsInitialization);
2128+ } else {
2129+ builder.createCopyAddr (pbLoc, source, dest, IsTake, IsNotInitialization);
2130+ }
2131+ currentIndex++;
20932132 // Prevent source buffer from being deallocated, since the underlying
20942133 // value is moved.
20952134 destroyedLocalAllocations.insert (source);
0 commit comments