@@ -650,6 +650,19 @@ class PullbackCloner::Implementation final
650
650
return alloc;
651
651
}
652
652
653
+ // --------------------------------------------------------------------------//
654
+ // Optional differentiation
655
+ // --------------------------------------------------------------------------//
656
+
657
+ // / Given a `wrappedAdjoint` value of type `T.TangentVector`, creates an
658
+ // / `Optional<T>.TangentVector` value from it and adds it to the adjoint value
659
+ // / of `optionalValue`.
660
+ // /
661
+ // / `wrappedAdjoint` may be an object or address value, both cases are
662
+ // / handled.
663
+ void accumulateAdjointForOptional (SILBasicBlock *bb, SILValue optionalValue,
664
+ SILValue wrappedAdjoint);
665
+
653
666
// --------------------------------------------------------------------------//
654
667
// Array literal initialization differentiation
655
668
// --------------------------------------------------------------------------//
@@ -1503,6 +1516,30 @@ class PullbackCloner::Implementation final
1503
1516
}
1504
1517
}
1505
1518
1519
+ // / Handle `unchecked_take_enum_data_addr` instruction.
1520
+ // / Currently, only `Optional`-typed operands are supported.
1521
+ // / Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case
1522
+ // / Adjoint: adj[x] += $Enum.TangentVector(adj[y])
1523
+ void
1524
+ visitUncheckedTakeEnumDataAddrInst (UncheckedTakeEnumDataAddrInst *utedai) {
1525
+ auto *bb = utedai->getParent ();
1526
+ auto adjBuf = getAdjointBuffer (bb, utedai);
1527
+ auto enumTy = utedai->getOperand ()->getType ();
1528
+ auto *optionalEnumDecl = getASTContext ().getOptionalDecl ();
1529
+ // Only `Optional`-typed operands are supported for now. Diagnose all other
1530
+ // enum operand types.
1531
+ if (enumTy.getASTType ().getEnumOrBoundGenericEnum () != optionalEnumDecl) {
1532
+ LLVM_DEBUG (getADDebugStream ()
1533
+ << " Unhandled instruction in PullbackCloner: " << *utedai);
1534
+ getContext ().emitNondifferentiabilityError (
1535
+ utedai, getInvoker (),
1536
+ diag::autodiff_expression_not_differentiable_note);
1537
+ errorOccurred = true ;
1538
+ return ;
1539
+ }
1540
+ accumulateAdjointForOptional (bb, utedai->getOperand (), adjBuf);
1541
+ }
1542
+
1506
1543
#define NOT_DIFFERENTIABLE (INST, DIAG ) void visit##INST##Inst(INST##Inst *inst);
1507
1544
#undef NOT_DIFFERENTIABLE
1508
1545
@@ -1639,11 +1676,16 @@ bool PullbackCloner::Implementation::run() {
1639
1676
// Diagnose active enum values. Differentiation of enum values requires
1640
1677
// special adjoint value handling and is not yet supported. Diagnose
1641
1678
// only the first active enum value to prevent too many diagnostics.
1642
- if (type.getEnumOrBoundGenericEnum ()) {
1643
- getContext ().emitNondifferentiabilityError (
1644
- v, getInvoker (), diag::autodiff_enums_unsupported);
1645
- errorOccurred = true ;
1646
- return true ;
1679
+ //
1680
+ // Do not diagnose `Optional`-typed values, which will have special-case
1681
+ // differentiation support.
1682
+ if (auto *enumDecl = type.getEnumOrBoundGenericEnum ()) {
1683
+ if (enumDecl != getContext ().getASTContext ().getOptionalDecl ()) {
1684
+ getContext ().emitNondifferentiabilityError (
1685
+ v, getInvoker (), diag::autodiff_enums_unsupported);
1686
+ errorOccurred = true ;
1687
+ return true ;
1688
+ }
1647
1689
}
1648
1690
// Diagnose unsupported stored property projections.
1649
1691
if (auto *inst = dyn_cast<FieldIndexCacheBase>(v)) {
@@ -1972,6 +2014,103 @@ void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
1972
2014
<< pullback);
1973
2015
}
1974
2016
2017
+ void PullbackCloner::Implementation::accumulateAdjointForOptional (
2018
+ SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2019
+ auto pbLoc = getPullback ().getLocation ();
2020
+ // Handle `switch_enum` on `Optional`.
2021
+ auto *optionalEnumDecl = getASTContext ().getOptionalDecl ();
2022
+ auto optionalTy = optionalValue->getType ();
2023
+ assert (optionalTy.getASTType ().getEnumOrBoundGenericEnum () ==
2024
+ optionalEnumDecl);
2025
+ // `Optional<T>`
2026
+ optionalTy = remapType (optionalTy);
2027
+ // `T`
2028
+ auto wrappedType = optionalTy.getOptionalObjectType ();
2029
+ // `T.TangentVector`
2030
+ auto wrappedTanType = remapType (wrappedAdjoint->getType ());
2031
+ // `Optional<T.TangentVector>`
2032
+ auto optionalOfWrappedTanType = SILType::getOptionalType (wrappedTanType);
2033
+ // `Optional<T>.TangentVector`
2034
+ auto optionalTanTy = getRemappedTangentType (optionalTy);
2035
+ auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal ();
2036
+ // Look up the `Optional<T>.TangentVector.init` declaration.
2037
+ auto initLookup =
2038
+ optionalTanDecl->lookupDirect (DeclBaseName::createConstructor ());
2039
+ ConstructorDecl *constructorDecl = nullptr ;
2040
+ for (auto *candidate : initLookup) {
2041
+ auto candidateModule = candidate->getModuleContext ();
2042
+ if (candidateModule->getName () ==
2043
+ builder.getASTContext ().Id_Differentiation ||
2044
+ candidateModule->isStdlibModule ()) {
2045
+ assert (!constructorDecl && " Multiple `Optional.TangentVector.init`s" );
2046
+ constructorDecl = cast<ConstructorDecl>(candidate);
2047
+ #ifdef NDEBUG
2048
+ break ;
2049
+ #endif
2050
+ }
2051
+ }
2052
+ assert (constructorDecl && " No `Optional.TangentVector.init`" );
2053
+
2054
+ // Allocate a local buffer for the `Optional` adjoint value.
2055
+ auto *optTanAdjBuf = builder.createAllocStack (pbLoc, optionalTanTy);
2056
+ // Find `Optional<T.TangentVector>.some` EnumElementDecl.
2057
+ auto someEltDecl = builder.getASTContext ().getOptionalSomeDecl ();
2058
+
2059
+ // Initialize a `Optional<T.TangentVector>` buffer from `wrappedAdjoint`as the
2060
+ // input for `Optional<T>.TangentVector.init`.
2061
+ auto *optArgBuf = builder.createAllocStack (pbLoc, optionalOfWrappedTanType);
2062
+ if (optionalOfWrappedTanType.isLoadableOrOpaque (builder.getFunction ())) {
2063
+ // %enum = enum $Optional<T.TangentVector>, #Optional.some!enumelt,
2064
+ // %wrappedAdjoint : $T
2065
+ auto *enumInst = builder.createEnum (pbLoc, wrappedAdjoint, someEltDecl,
2066
+ optionalOfWrappedTanType);
2067
+ // store %enum to %optArgBuf
2068
+ builder.emitStoreValueOperation (pbLoc, enumInst, optArgBuf,
2069
+ StoreOwnershipQualifier::Trivial);
2070
+ } else {
2071
+ // %enumAddr = init_enum_data_addr %optArgBuf $Optional<T.TangentVector>,
2072
+ // #Optional.some!enumelt
2073
+ auto *enumAddr = builder.createInitEnumDataAddr (
2074
+ pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType ());
2075
+ // copy_addr %wrappedAdjoint to [initialization] %enumAddr
2076
+ builder.createCopyAddr (pbLoc, wrappedAdjoint, enumAddr, IsNotTake,
2077
+ IsInitialization);
2078
+ // inject_enum_addr %optArgBuf : $*Optional<T.TangentVector>,
2079
+ // #Optional.some!enumelt
2080
+ builder.createInjectEnumAddr (pbLoc, optArgBuf, someEltDecl);
2081
+ }
2082
+
2083
+ // Apply `Optional<T>.TangentVector.init`.
2084
+ SILOptFunctionBuilder fb (getContext ().getTransform ());
2085
+ // %init_fn = function_ref @Optional<T>.TangentVector.init
2086
+ auto *initFn = fb.getOrCreateFunction (pbLoc, SILDeclRef (constructorDecl),
2087
+ NotForDefinition);
2088
+ auto *initFnRef = builder.createFunctionRef (pbLoc, initFn);
2089
+ auto *diffProto =
2090
+ builder.getASTContext ().getProtocol (KnownProtocolKind::Differentiable);
2091
+ auto *swiftModule = getModule ().getSwiftModule ();
2092
+ auto diffConf =
2093
+ swiftModule->lookupConformance (wrappedType.getASTType (), diffProto);
2094
+ assert (!diffConf.isInvalid () && " Missing conformance to `Differentiable`" );
2095
+ auto subMap = SubstitutionMap::get (
2096
+ initFn->getLoweredFunctionType ()->getSubstGenericSignature (),
2097
+ ArrayRef<Type>(wrappedType.getASTType ()), {diffConf});
2098
+ // %metatype = metatype $Optional<T>.TangentVector.Type
2099
+ auto metatypeType = CanMetatypeType::get (optionalTanTy.getASTType (),
2100
+ MetatypeRepresentation::Thin);
2101
+ auto metatypeSILType = SILType::getPrimitiveObjectType (metatypeType);
2102
+ auto metatype = builder.createMetatype (pbLoc, metatypeSILType);
2103
+ // apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype)
2104
+ builder.createApply (pbLoc, initFnRef, subMap,
2105
+ {optTanAdjBuf, optArgBuf, metatype});
2106
+ builder.createDeallocStack (pbLoc, optArgBuf);
2107
+
2108
+ // Accumulate adjoint for the incoming `Optional` value.
2109
+ addToAdjointBuffer (bb, optionalValue, optTanAdjBuf, pbLoc);
2110
+ builder.emitDestroyAddr (pbLoc, optTanAdjBuf);
2111
+ builder.createDeallocStack (pbLoc, optTanAdjBuf);
2112
+ }
2113
+
1975
2114
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor (
1976
2115
SILBasicBlock *origBB, SILBasicBlock *origPredBB,
1977
2116
SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
@@ -2110,18 +2249,64 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2110
2249
// Get predecessor terminator operands.
2111
2250
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4 > incomingValues;
2112
2251
bbArg->getSingleTerminatorOperands (incomingValues);
2113
- // Materialize adjoint value of active basic block argument, create a
2114
- // copy, and set copy as adjoint value of incoming values.
2115
- auto bbArgAdj = getAdjointValue (bb, bbArg);
2116
- auto concreteBBArgAdj = materializeAdjointDirect (bbArgAdj, pbLoc);
2117
- auto concreteBBArgAdjCopy =
2118
- builder.emitCopyValueOperation (pbLoc, concreteBBArgAdj);
2119
- for (auto pair : incomingValues) {
2120
- auto *predBB = std::get<0 >(pair);
2121
- auto incomingValue = std::get<1 >(pair);
2122
- blockTemporaries[getPullbackBlock (predBB)].insert (concreteBBArgAdjCopy);
2123
- setAdjointValue (predBB, incomingValue,
2124
- makeConcreteAdjointValue (concreteBBArgAdjCopy));
2252
+
2253
+ // Returns true if the given terminator instruction is a `switch_enum` on
2254
+ // an `Optional`-typed value. `switch_enum` instructions require
2255
+ // special-case adjoint value propagation for the operand.
2256
+ auto isSwitchEnumInstOnOptional =
2257
+ [&ctx = getASTContext ()](TermInst *termInst) {
2258
+ if (!termInst)
2259
+ return false ;
2260
+ if (auto *sei = dyn_cast<SwitchEnumInst>(termInst)) {
2261
+ auto *optionalEnumDecl = ctx.getOptionalDecl ();
2262
+ auto operandTy = sei->getOperand ()->getType ();
2263
+ return operandTy.getASTType ().getEnumOrBoundGenericEnum () ==
2264
+ optionalEnumDecl;
2265
+ }
2266
+ return false ;
2267
+ };
2268
+
2269
+ // Check the tangent value category of the active basic block argument.
2270
+ switch (getTangentValueCategory (bbArg)) {
2271
+ // If argument has a loadable tangent value category: materialize adjoint
2272
+ // value of the argument, create a copy, and set the copy as the adjoint
2273
+ // value of incoming values.
2274
+ case SILValueCategory::Object: {
2275
+ auto bbArgAdj = getAdjointValue (bb, bbArg);
2276
+ auto concreteBBArgAdj = materializeAdjointDirect (bbArgAdj, pbLoc);
2277
+ auto concreteBBArgAdjCopy =
2278
+ builder.emitCopyValueOperation (pbLoc, concreteBBArgAdj);
2279
+ for (auto pair : incomingValues) {
2280
+ auto *predBB = std::get<0 >(pair);
2281
+ auto incomingValue = std::get<1 >(pair);
2282
+ blockTemporaries[getPullbackBlock (predBB)].insert (concreteBBArgAdjCopy);
2283
+ // Handle `switch_enum` on `Optional`.
2284
+ auto termInst = bbArg->getSingleTerminator ();
2285
+ if (isSwitchEnumInstOnOptional (termInst))
2286
+ accumulateAdjointForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2287
+ else
2288
+ setAdjointValue (predBB, incomingValue,
2289
+ makeConcreteAdjointValue (concreteBBArgAdjCopy));
2290
+ }
2291
+ break ;
2292
+ }
2293
+ // If argument has an address tangent value category: materialize adjoint
2294
+ // value of the argument, create a copy, and set the copy as the adjoint
2295
+ // value of incoming values.
2296
+ case SILValueCategory::Address: {
2297
+ auto bbArgAdjBuf = getAdjointBuffer (bb, bbArg);
2298
+ for (auto pair : incomingValues) {
2299
+ auto *predBB = std::get<0 >(pair);
2300
+ auto incomingValue = std::get<1 >(pair);
2301
+ // Handle `switch_enum` on `Optional`.
2302
+ auto termInst = bbArg->getSingleTerminator ();
2303
+ if (isSwitchEnumInstOnOptional (termInst))
2304
+ accumulateAdjointForOptional (bb, incomingValue, bbArgAdjBuf);
2305
+ else
2306
+ addToAdjointBuffer (predBB, incomingValue, bbArgAdjBuf, pbLoc);
2307
+ }
2308
+ break ;
2309
+ }
2125
2310
}
2126
2311
}
2127
2312
0 commit comments