@@ -744,14 +744,23 @@ class PullbackCloner::Implementation final
744744 // Optional differentiation
745745 // --------------------------------------------------------------------------//
746746
747- // / Given a `wrappedAdjoint` value of type `T.TangentVector`, creates an
748- // / `Optional<T>.TangentVector` value from it and adds it to the adjoint value
749- // / of `optionalValue`.
747+ // / Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>`
748+ // / type, creates an `Optional<T>.TangentVector` buffer from it.
750749 // /
751750 // / `wrappedAdjoint` may be an object or address value, both cases are
752751 // / handled.
753- void accumulateAdjointForOptional (SILBasicBlock *bb, SILValue optionalValue,
754- SILValue wrappedAdjoint);
752+ AllocStackInst *createOptionalAdjoint (SILBasicBlock *bb,
753+ SILValue wrappedAdjoint,
754+ SILType optionalTy);
755+
756+ // / Accumulate optional buffer from `wrappedAdjoint`.
757+ void accumulateAdjointForOptionalBuffer (SILBasicBlock *bb,
758+ SILValue optionalBuffer,
759+ SILValue wrappedAdjoint);
760+
761+ // / Set optional value from `wrappedAdjoint`.
762+ void setAdjointValueForOptional (SILBasicBlock *bb, SILValue optionalValue,
763+ SILValue wrappedAdjoint);
755764
756765 // --------------------------------------------------------------------------//
757766 // Array literal initialization differentiation
@@ -1687,6 +1696,104 @@ class PullbackCloner::Implementation final
16871696 builder.emitZeroIntoBuffer (uccai->getLoc (), adjDest, IsInitialization);
16881697 }
16891698
1699+ // / Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
1700+ // / instructions.
1701+ // /
1702+ // / Original: y = init_enum_data_addr x
1703+ // / inject_enum_addr y
1704+ // /
1705+ // / Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
1706+ void visitInjectEnumAddrInst (InjectEnumAddrInst *inject) {
1707+ SILBasicBlock *bb = inject->getParent ();
1708+ SILValue origEnum = inject->getOperand ();
1709+
1710+ // Only `Optional`-typed operands are supported for now. Diagnose all other
1711+ // enum operand types.
1712+ auto *optionalEnumDecl = getASTContext ().getOptionalDecl ();
1713+ if (origEnum->getType ().getEnumOrBoundGenericEnum () != optionalEnumDecl) {
1714+ LLVM_DEBUG (getADDebugStream ()
1715+ << " Unsupported enum type in PullbackCloner: " << *inject);
1716+ getContext ().emitNondifferentiabilityError (
1717+ inject, getInvoker (),
1718+ diag::autodiff_expression_not_differentiable_note);
1719+ errorOccurred = true ;
1720+ return ;
1721+ }
1722+
1723+ InitEnumDataAddrInst *origData = nullptr ;
1724+ for (auto use : origEnum->getUses ()) {
1725+ if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser ())) {
1726+ // We need a more complicated analysis when init_enum_data_addr and
1727+ // inject_enum_addr are in different blocks, or there is more than one
1728+ // such instruction. Bail out for now.
1729+ if (origData || init->getParent () != bb) {
1730+ LLVM_DEBUG (getADDebugStream ()
1731+ << " Could not find a matching init_enum_data_addr for: "
1732+ << *inject);
1733+ getContext ().emitNondifferentiabilityError (
1734+ inject, getInvoker (),
1735+ diag::autodiff_expression_not_differentiable_note);
1736+ errorOccurred = true ;
1737+ return ;
1738+ }
1739+
1740+ origData = init;
1741+ }
1742+ }
1743+
1744+ SILValue adjStruct = getAdjointBuffer (bb, origEnum);
1745+ StructDecl *adjStructDecl =
1746+ adjStruct->getType ().getStructOrBoundGenericStruct ();
1747+
1748+ VarDecl *adjOptVar = nullptr ;
1749+ if (adjStructDecl) {
1750+ ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties ();
1751+ adjOptVar = properties.size () == 1 ? properties[0 ] : nullptr ;
1752+ }
1753+
1754+ EnumDecl *adjOptDecl =
1755+ adjOptVar ? adjOptVar->getTypeInContext ()->getEnumOrBoundGenericEnum ()
1756+ : nullptr ;
1757+
1758+ // Optional<T>.TangentVector should be a struct with a single
1759+ // Optional<T.TangentVector> property. This is an implementation detail of
1760+ // OptionalDifferentiation.swift
1761+ if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1762+ llvm_unreachable (" Unexpected type of Optional.TangentVector" );
1763+
1764+ SILLocation loc = origData->getLoc ();
1765+ StructElementAddrInst *adjOpt =
1766+ builder.createStructElementAddr (loc, adjStruct, adjOptVar);
1767+
1768+ // unchecked_take_enum_data_addr is destructive, so copy
1769+ // Optional<T.TangentVector> to a new alloca.
1770+ AllocStackInst *adjOptCopy =
1771+ createFunctionLocalAllocation (adjOpt->getType (), loc);
1772+ builder.createCopyAddr (loc, adjOpt, adjOptCopy, IsNotTake,
1773+ IsInitialization);
1774+
1775+ EnumElementDecl *someElemDecl = getASTContext ().getOptionalSomeDecl ();
1776+ UncheckedTakeEnumDataAddrInst *adjData =
1777+ builder.createUncheckedTakeEnumDataAddr (loc, adjOptCopy, someElemDecl);
1778+
1779+ setAdjointBuffer (bb, origData, adjData);
1780+
1781+ // The Optional copy is invalidated, do not attempt to destroy it at the end
1782+ // of the pullback. The value returned from unchecked_take_enum_data_addr is
1783+ // destroyed in visitInitEnumDataAddrInst.
1784+ destroyedLocalAllocations.insert (adjOptCopy);
1785+ }
1786+
1787+ // / Handle `init_enum_data_addr` instruction.
1788+ // / Destroy the value returned from `unchecked_take_enum_data_addr`.
1789+ void visitInitEnumDataAddrInst (InitEnumDataAddrInst *init) {
1790+ auto bufIt = bufferMap.find ({init->getParent (), SILValue (init)});
1791+ if (bufIt == bufferMap.end ())
1792+ return ;
1793+ SILValue adjData = bufIt->second ;
1794+ builder.emitDestroyAddr (init->getLoc (), adjData);
1795+ }
1796+
16901797 // / Handle `unchecked_ref_cast` instruction.
16911798 // / Original: y = unchecked_ref_cast x
16921799 // / Adjoint: adj[x] += adj[y]
@@ -1758,7 +1865,7 @@ class PullbackCloner::Implementation final
17581865 errorOccurred = true ;
17591866 return ;
17601867 }
1761- accumulateAdjointForOptional (bb, utedai->getOperand (), adjDest);
1868+ accumulateAdjointForOptionalBuffer (bb, utedai->getOperand (), adjDest);
17621869 builder.emitZeroIntoBuffer (utedai->getLoc (), adjDest, IsNotInitialization);
17631870 }
17641871
@@ -2342,12 +2449,11 @@ void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
23422449 << pullback);
23432450}
23442451
2345- void PullbackCloner::Implementation::accumulateAdjointForOptional (
2346- SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint ) {
2452+ AllocStackInst * PullbackCloner::Implementation::createOptionalAdjoint (
2453+ SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy ) {
23472454 auto pbLoc = getPullback ().getLocation ();
2348- // Handle `switch_enum` on `Optional`.
23492455 // `Optional<T>`
2350- auto optionalTy = remapType (optionalValue-> getType () );
2456+ optionalTy = remapType (optionalTy );
23512457 assert (optionalTy.getASTType ()->isOptional ());
23522458 // `T`
23532459 auto wrappedType = optionalTy.getOptionalObjectType ();
@@ -2429,13 +2535,45 @@ void PullbackCloner::Implementation::accumulateAdjointForOptional(
24292535 builder.createApply (pbLoc, initFnRef, subMap,
24302536 {optTanAdjBuf, optArgBuf, metatype});
24312537 builder.createDeallocStack (pbLoc, optArgBuf);
2538+ return optTanAdjBuf;
2539+ }
2540+
2541+ // Accumulate adjoint for the incoming `Optional` buffer.
2542+ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer (
2543+ SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) {
2544+ assert (getTangentValueCategory (optionalBuffer) == SILValueCategory::Address);
2545+ auto pbLoc = getPullback ().getLocation ();
24322546
2433- // Accumulate adjoint for the incoming `Optional` value.
2434- addToAdjointBuffer (bb, optionalValue, optTanAdjBuf, pbLoc);
2547+ // Allocate and initialize Optional<Wrapped>.TangentVector from
2548+ // Wrapped.TangentVector
2549+ AllocStackInst *optTanAdjBuf =
2550+ createOptionalAdjoint (bb, wrappedAdjoint, optionalBuffer->getType ());
2551+
2552+ // Accumulate into optionalBuffer
2553+ addToAdjointBuffer (bb, optionalBuffer, optTanAdjBuf, pbLoc);
24352554 builder.emitDestroyAddr (pbLoc, optTanAdjBuf);
24362555 builder.createDeallocStack (pbLoc, optTanAdjBuf);
24372556}
24382557
2558+ // Set the adjoint value for the incoming `Optional` value.
2559+ void PullbackCloner::Implementation::setAdjointValueForOptional (
2560+ SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2561+ assert (getTangentValueCategory (optionalValue) == SILValueCategory::Object);
2562+ auto pbLoc = getPullback ().getLocation ();
2563+
2564+ // Allocate and initialize Optional<Wrapped>.TangentVector from
2565+ // Wrapped.TangentVector
2566+ AllocStackInst *optTanAdjBuf =
2567+ createOptionalAdjoint (bb, wrappedAdjoint, optionalValue->getType ());
2568+
2569+ auto optTanAdjVal = builder.emitLoadValueOperation (
2570+ pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2571+ recordTemporary (optTanAdjVal);
2572+ builder.createDeallocStack (pbLoc, optTanAdjBuf);
2573+
2574+ setAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal));
2575+ }
2576+
24392577SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor (
24402578 SILBasicBlock *origBB, SILBasicBlock *origPredBB,
24412579 SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
@@ -2623,7 +2761,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
26232761 // Handle `switch_enum` on `Optional`.
26242762 auto termInst = bbArg->getSingleTerminator ();
26252763 if (isSwitchEnumInstOnOptional (termInst)) {
2626- accumulateAdjointForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2764+ setAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
26272765 } else {
26282766 blockTemporaries[getPullbackBlock (predBB)].insert (
26292767 concreteBBArgAdjCopy);
@@ -2643,7 +2781,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
26432781 // Handle `switch_enum` on `Optional`.
26442782 auto termInst = bbArg->getSingleTerminator ();
26452783 if (isSwitchEnumInstOnOptional (termInst))
2646- accumulateAdjointForOptional (bb, incomingValue, bbArgAdjBuf);
2784+ accumulateAdjointForOptionalBuffer (bb, incomingValue, bbArgAdjBuf);
26472785 else
26482786 addToAdjointBuffer (bb, incomingValue, bbArgAdjBuf, pbLoc);
26492787 }
0 commit comments