Skip to content

Commit ccca08b

Browse files
authored
[AutoDiff] Add reverse-mode switch_enum support for Optional. (swiftlang#33218)
Pullback generation now supports `switch_enum` and `switch_enum_addr` instructions for `Optional`-typed operands. Currently, the logic is special-cased to `Optional`, but may be generalized in the future to support enums following general rules.
1 parent 9bbe6e7 commit ccca08b

File tree

4 files changed

+524
-20
lines changed

4 files changed

+524
-20
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 202 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,19 @@ class PullbackCloner::Implementation final
650650
return alloc;
651651
}
652652

653+
//--------------------------------------------------------------------------//
654+
// Optional differentiation
655+
//--------------------------------------------------------------------------//
656+
657+
/// Given a `wrappedAdjoint` value of type `T.TangentVector`, creates an
658+
/// `Optional<T>.TangentVector` value from it and adds it to the adjoint value
659+
/// of `optionalValue`.
660+
///
661+
/// `wrappedAdjoint` may be an object or address value, both cases are
662+
/// handled.
663+
void accumulateAdjointForOptional(SILBasicBlock *bb, SILValue optionalValue,
664+
SILValue wrappedAdjoint);
665+
653666
//--------------------------------------------------------------------------//
654667
// Array literal initialization differentiation
655668
//--------------------------------------------------------------------------//
@@ -1503,6 +1516,30 @@ class PullbackCloner::Implementation final
15031516
}
15041517
}
15051518

1519+
/// Handle `unchecked_take_enum_data_addr` instruction.
1520+
/// Currently, only `Optional`-typed operands are supported.
1521+
/// Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case
1522+
/// Adjoint: adj[x] += $Enum.TangentVector(adj[y])
1523+
void
1524+
visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) {
1525+
auto *bb = utedai->getParent();
1526+
auto adjBuf = getAdjointBuffer(bb, utedai);
1527+
auto enumTy = utedai->getOperand()->getType();
1528+
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
1529+
// Only `Optional`-typed operands are supported for now. Diagnose all other
1530+
// enum operand types.
1531+
if (enumTy.getASTType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
1532+
LLVM_DEBUG(getADDebugStream()
1533+
<< "Unhandled instruction in PullbackCloner: " << *utedai);
1534+
getContext().emitNondifferentiabilityError(
1535+
utedai, getInvoker(),
1536+
diag::autodiff_expression_not_differentiable_note);
1537+
errorOccurred = true;
1538+
return;
1539+
}
1540+
accumulateAdjointForOptional(bb, utedai->getOperand(), adjBuf);
1541+
}
1542+
15061543
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
15071544
#undef NOT_DIFFERENTIABLE
15081545

@@ -1639,11 +1676,16 @@ bool PullbackCloner::Implementation::run() {
16391676
// Diagnose active enum values. Differentiation of enum values requires
16401677
// special adjoint value handling and is not yet supported. Diagnose
16411678
// only the first active enum value to prevent too many diagnostics.
1642-
if (type.getEnumOrBoundGenericEnum()) {
1643-
getContext().emitNondifferentiabilityError(
1644-
v, getInvoker(), diag::autodiff_enums_unsupported);
1645-
errorOccurred = true;
1646-
return true;
1679+
//
1680+
// Do not diagnose `Optional`-typed values, which will have special-case
1681+
// differentiation support.
1682+
if (auto *enumDecl = type.getEnumOrBoundGenericEnum()) {
1683+
if (enumDecl != getContext().getASTContext().getOptionalDecl()) {
1684+
getContext().emitNondifferentiabilityError(
1685+
v, getInvoker(), diag::autodiff_enums_unsupported);
1686+
errorOccurred = true;
1687+
return true;
1688+
}
16471689
}
16481690
// Diagnose unsupported stored property projections.
16491691
if (auto *inst = dyn_cast<FieldIndexCacheBase>(v)) {
@@ -1972,6 +2014,103 @@ void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
19722014
<< pullback);
19732015
}
19742016

2017+
void PullbackCloner::Implementation::accumulateAdjointForOptional(
2018+
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2019+
auto pbLoc = getPullback().getLocation();
2020+
// Handle `switch_enum` on `Optional`.
2021+
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
2022+
auto optionalTy = optionalValue->getType();
2023+
assert(optionalTy.getASTType().getEnumOrBoundGenericEnum() ==
2024+
optionalEnumDecl);
2025+
// `Optional<T>`
2026+
optionalTy = remapType(optionalTy);
2027+
// `T`
2028+
auto wrappedType = optionalTy.getOptionalObjectType();
2029+
// `T.TangentVector`
2030+
auto wrappedTanType = remapType(wrappedAdjoint->getType());
2031+
// `Optional<T.TangentVector>`
2032+
auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType);
2033+
// `Optional<T>.TangentVector`
2034+
auto optionalTanTy = getRemappedTangentType(optionalTy);
2035+
auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal();
2036+
// Look up the `Optional<T>.TangentVector.init` declaration.
2037+
auto initLookup =
2038+
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
2039+
ConstructorDecl *constructorDecl = nullptr;
2040+
for (auto *candidate : initLookup) {
2041+
auto candidateModule = candidate->getModuleContext();
2042+
if (candidateModule->getName() ==
2043+
builder.getASTContext().Id_Differentiation ||
2044+
candidateModule->isStdlibModule()) {
2045+
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
2046+
constructorDecl = cast<ConstructorDecl>(candidate);
2047+
#ifdef NDEBUG
2048+
break;
2049+
#endif
2050+
}
2051+
}
2052+
assert(constructorDecl && "No `Optional.TangentVector.init`");
2053+
2054+
// Allocate a local buffer for the `Optional` adjoint value.
2055+
auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy);
2056+
// Find `Optional<T.TangentVector>.some` EnumElementDecl.
2057+
auto someEltDecl = builder.getASTContext().getOptionalSomeDecl();
2058+
2059+
// Initialize a `Optional<T.TangentVector>` buffer from `wrappedAdjoint`as the
2060+
// input for `Optional<T>.TangentVector.init`.
2061+
auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType);
2062+
if (optionalOfWrappedTanType.isLoadableOrOpaque(builder.getFunction())) {
2063+
// %enum = enum $Optional<T.TangentVector>, #Optional.some!enumelt,
2064+
// %wrappedAdjoint : $T
2065+
auto *enumInst = builder.createEnum(pbLoc, wrappedAdjoint, someEltDecl,
2066+
optionalOfWrappedTanType);
2067+
// store %enum to %optArgBuf
2068+
builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf,
2069+
StoreOwnershipQualifier::Trivial);
2070+
} else {
2071+
// %enumAddr = init_enum_data_addr %optArgBuf $Optional<T.TangentVector>,
2072+
// #Optional.some!enumelt
2073+
auto *enumAddr = builder.createInitEnumDataAddr(
2074+
pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType());
2075+
// copy_addr %wrappedAdjoint to [initialization] %enumAddr
2076+
builder.createCopyAddr(pbLoc, wrappedAdjoint, enumAddr, IsNotTake,
2077+
IsInitialization);
2078+
// inject_enum_addr %optArgBuf : $*Optional<T.TangentVector>,
2079+
// #Optional.some!enumelt
2080+
builder.createInjectEnumAddr(pbLoc, optArgBuf, someEltDecl);
2081+
}
2082+
2083+
// Apply `Optional<T>.TangentVector.init`.
2084+
SILOptFunctionBuilder fb(getContext().getTransform());
2085+
// %init_fn = function_ref @Optional<T>.TangentVector.init
2086+
auto *initFn = fb.getOrCreateFunction(pbLoc, SILDeclRef(constructorDecl),
2087+
NotForDefinition);
2088+
auto *initFnRef = builder.createFunctionRef(pbLoc, initFn);
2089+
auto *diffProto =
2090+
builder.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
2091+
auto *swiftModule = getModule().getSwiftModule();
2092+
auto diffConf =
2093+
swiftModule->lookupConformance(wrappedType.getASTType(), diffProto);
2094+
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
2095+
auto subMap = SubstitutionMap::get(
2096+
initFn->getLoweredFunctionType()->getSubstGenericSignature(),
2097+
ArrayRef<Type>(wrappedType.getASTType()), {diffConf});
2098+
// %metatype = metatype $Optional<T>.TangentVector.Type
2099+
auto metatypeType = CanMetatypeType::get(optionalTanTy.getASTType(),
2100+
MetatypeRepresentation::Thin);
2101+
auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
2102+
auto metatype = builder.createMetatype(pbLoc, metatypeSILType);
2103+
// apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype)
2104+
builder.createApply(pbLoc, initFnRef, subMap,
2105+
{optTanAdjBuf, optArgBuf, metatype});
2106+
builder.createDeallocStack(pbLoc, optArgBuf);
2107+
2108+
// Accumulate adjoint for the incoming `Optional` value.
2109+
addToAdjointBuffer(bb, optionalValue, optTanAdjBuf, pbLoc);
2110+
builder.emitDestroyAddr(pbLoc, optTanAdjBuf);
2111+
builder.createDeallocStack(pbLoc, optTanAdjBuf);
2112+
}
2113+
19752114
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
19762115
SILBasicBlock *origBB, SILBasicBlock *origPredBB,
19772116
SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
@@ -2110,18 +2249,64 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
21102249
// Get predecessor terminator operands.
21112250
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
21122251
bbArg->getSingleTerminatorOperands(incomingValues);
2113-
// Materialize adjoint value of active basic block argument, create a
2114-
// copy, and set copy as adjoint value of incoming values.
2115-
auto bbArgAdj = getAdjointValue(bb, bbArg);
2116-
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
2117-
auto concreteBBArgAdjCopy =
2118-
builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj);
2119-
for (auto pair : incomingValues) {
2120-
auto *predBB = std::get<0>(pair);
2121-
auto incomingValue = std::get<1>(pair);
2122-
blockTemporaries[getPullbackBlock(predBB)].insert(concreteBBArgAdjCopy);
2123-
setAdjointValue(predBB, incomingValue,
2124-
makeConcreteAdjointValue(concreteBBArgAdjCopy));
2252+
2253+
// Returns true if the given terminator instruction is a `switch_enum` on
2254+
// an `Optional`-typed value. `switch_enum` instructions require
2255+
// special-case adjoint value propagation for the operand.
2256+
auto isSwitchEnumInstOnOptional =
2257+
[&ctx = getASTContext()](TermInst *termInst) {
2258+
if (!termInst)
2259+
return false;
2260+
if (auto *sei = dyn_cast<SwitchEnumInst>(termInst)) {
2261+
auto *optionalEnumDecl = ctx.getOptionalDecl();
2262+
auto operandTy = sei->getOperand()->getType();
2263+
return operandTy.getASTType().getEnumOrBoundGenericEnum() ==
2264+
optionalEnumDecl;
2265+
}
2266+
return false;
2267+
};
2268+
2269+
// Check the tangent value category of the active basic block argument.
2270+
switch (getTangentValueCategory(bbArg)) {
2271+
// If argument has a loadable tangent value category: materialize adjoint
2272+
// value of the argument, create a copy, and set the copy as the adjoint
2273+
// value of incoming values.
2274+
case SILValueCategory::Object: {
2275+
auto bbArgAdj = getAdjointValue(bb, bbArg);
2276+
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
2277+
auto concreteBBArgAdjCopy =
2278+
builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj);
2279+
for (auto pair : incomingValues) {
2280+
auto *predBB = std::get<0>(pair);
2281+
auto incomingValue = std::get<1>(pair);
2282+
blockTemporaries[getPullbackBlock(predBB)].insert(concreteBBArgAdjCopy);
2283+
// Handle `switch_enum` on `Optional`.
2284+
auto termInst = bbArg->getSingleTerminator();
2285+
if (isSwitchEnumInstOnOptional(termInst))
2286+
accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy);
2287+
else
2288+
setAdjointValue(predBB, incomingValue,
2289+
makeConcreteAdjointValue(concreteBBArgAdjCopy));
2290+
}
2291+
break;
2292+
}
2293+
// If argument has an address tangent value category: materialize adjoint
2294+
// value of the argument, create a copy, and set the copy as the adjoint
2295+
// value of incoming values.
2296+
case SILValueCategory::Address: {
2297+
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
2298+
for (auto pair : incomingValues) {
2299+
auto *predBB = std::get<0>(pair);
2300+
auto incomingValue = std::get<1>(pair);
2301+
// Handle `switch_enum` on `Optional`.
2302+
auto termInst = bbArg->getSingleTerminator();
2303+
if (isSwitchEnumInstOnOptional(termInst))
2304+
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
2305+
else
2306+
addToAdjointBuffer(predBB, incomingValue, bbArgAdjBuf, pbLoc);
2307+
}
2308+
break;
2309+
}
21252310
}
21262311
}
21272312

test/AutoDiff/SILOptimizer/activity_analysis.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func checked_cast_addr_nonactive_result<T: Differentiable>(_ x: T) -> T {
200200
@differentiable
201201
// expected-note @+1 {{when differentiating this function definition}}
202202
func checked_cast_addr_active_result<T: Differentiable>(x: T) -> T {
203-
// expected-note @+1 {{differentiating enum values is not yet supported}}
203+
// expected-note @+1 {{expression is not differentiable}}
204204
if let y = x as? Float {
205205
// Use `y: Float?` value in an active way.
206206
return y as! T
@@ -744,8 +744,8 @@ func testClassModifyAccessor(_ c: inout C) {
744744
@differentiable
745745
// expected-note @+1 {{when differentiating this function definition}}
746746
func testActiveOptional(_ x: Float) -> Float {
747-
// expected-note @+1 {{differentiating enum values is not yet supported}}
748747
var maybe: Float? = 10
748+
// expected-note @+1 {{expression is not differentiable}}
749749
maybe = x
750750
return maybe!
751751
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ class C<T: Differentiable>: Differentiable {
157157
@differentiable
158158
// expected-note @+1 {{when differentiating this function definition}}
159159
func usesOptionals(_ x: Float) -> Float {
160-
// expected-note @+1 {{differentiating enum values is not yet supported}}
161160
var maybe: Float? = 10
161+
// expected-note @+1 {{expression is not differentiable}}
162162
maybe = x
163163
return maybe!
164164
}

0 commit comments

Comments
 (0)