Skip to content

Commit 1775e8a

Browse files
committed
[AutoDiff upstream] Add VJPEmitter.
`VJPEmitter` is a cloner that emits VJP functions. It implements reverse-mode automatic differentiation, along with `PullbackEmitter`. `VJPEmitter` clones an original function, replacing function applications with VJP function applications. In VJP functions, each basic block takes a pullback struct (containing callee pullbacks) and produces a predecessor enum: these data structures are consumed by pullback functions.
1 parent fa405e6 commit 1775e8a

File tree

7 files changed

+983
-6
lines changed

7 files changed

+983
-6
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,25 @@ NOTE(autodiff_member_subset_indices_not_differentiable,none,
510510
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
511511
"cannot convert a direct method reference to a '@differentiable' "
512512
"function; use an explicit closure instead", ())
513+
NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
514+
"cannot differentiate through multiple results", ())
515+
// TODO(TF-1149): Remove this diagnostic.
516+
NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none,
517+
"cannot yet differentiate value whose type %0 has a compile-time known "
518+
"size, but whose 'TangentVector' contains stored properties of unknown "
519+
"size; consider modifying %1 to use fewer generic parameters in stored "
520+
"properties", (Type, Type))
521+
NOTE(autodiff_enums_unsupported,none,
522+
"differentiating enum values is not yet supported", ())
523+
NOTE(autodiff_stored_property_no_corresponding_tangent,none,
524+
"property cannot be differentiated because '%0.TangentVector' does not "
525+
"have a member named '%1'", (StringRef, StringRef))
526+
NOTE(autodiff_coroutines_not_supported,none,
527+
"differentiation of coroutine calls is not yet supported", ())
528+
NOTE(autodiff_cannot_differentiate_writes_to_global_variables,none,
529+
"cannot differentiate writes to global variables", ())
530+
NOTE(autodiff_cannot_differentiate_writes_to_mutable_captures,none,
531+
"cannot differentiate writes to mutable captures", ())
513532

514533
ERROR(non_physical_addressof,none,
515534
"addressof only works with purely physical lvalues; "

include/swift/SIL/SILCloner.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ class SILCloner : protected SILInstructionVisitor<ImplClass> {
4747
TypeSubstitutionMap OpenedExistentialSubs;
4848
SILOpenedArchetypesTracker OpenedArchetypesTracker;
4949

50-
private:
51-
/// MARK: Private state hidden from CRTP extensions.
52-
5350
// The old-to-new value map.
5451
llvm::DenseMap<SILValue, SILValue> ValueMap;
5552

5653
/// The old-to-new block map. Some entries may be premapped with original
5754
/// blocks.
5855
llvm::DenseMap<SILBasicBlock*, SILBasicBlock*> BBMap;
5956

57+
private:
58+
/// MARK: Private state hidden from CRTP extensions.
59+
6060
// The original blocks in DFS preorder. All blocks in this list are mapped.
6161
// After cloning, this represents the entire cloned CFG.
6262
//
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//===--- VJPEmitter.h - VJP Generation in differentiation -----*- C++ -*---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines a helper class for generating VJP functions for automatic
14+
// differentiation.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H
19+
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H
20+
21+
#include "swift/SIL/TypeSubstCloner.h"
22+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
23+
#include "swift/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.h"
24+
#include "swift/SILOptimizer/Utils/Differentiation/LinearMapInfo.h"
25+
#include "llvm/ADT/DenseMap.h"
26+
27+
namespace swift {
28+
29+
class SILDifferentiabilityWitness;
30+
class SILBasicBlock;
31+
class SILFunction;
32+
class SILInstruction;
33+
34+
namespace autodiff {
35+
36+
class ADContext;
37+
class PullbackEmitter;
38+
39+
class VJPEmitter final
40+
: public TypeSubstCloner<VJPEmitter, SILOptFunctionBuilder> {
41+
friend class PullbackEmitter;
42+
43+
private:
44+
/// The global context.
45+
ADContext &context;
46+
47+
/// The original function.
48+
SILFunction *const original;
49+
50+
/// The differentiability witness.
51+
SILDifferentiabilityWitness *const witness;
52+
53+
/// The VJP function.
54+
SILFunction *const vjp;
55+
56+
/// The pullback function.
57+
SILFunction *pullback;
58+
59+
/// The differentiation invoker.
60+
DifferentiationInvoker invoker;
61+
62+
/// Info from activity analysis on the original function.
63+
const DifferentiableActivityInfo &activityInfo;
64+
65+
/// The linear map info.
66+
LinearMapInfo pullbackInfo;
67+
68+
/// Caches basic blocks whose phi arguments have been remapped (adding a
69+
/// predecessor enum argument).
70+
SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks;
71+
72+
bool errorOccurred = false;
73+
74+
/// Mapping from original blocks to pullback values. Used to build pullback
75+
/// struct instances.
76+
llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> pullbackValues;
77+
78+
ASTContext &getASTContext() const { return vjp->getASTContext(); }
79+
SILModule &getModule() const { return vjp->getModule(); }
80+
const SILAutoDiffIndices getIndices() const {
81+
return witness->getSILAutoDiffIndices();
82+
}
83+
84+
static SubstitutionMap getSubstitutionMap(SILFunction *original,
85+
SILFunction *vjp);
86+
87+
static const DifferentiableActivityInfo &
88+
getActivityInfo(ADContext &context, SILFunction *original,
89+
SILAutoDiffIndices indices, SILFunction *vjp);
90+
91+
public:
92+
explicit VJPEmitter(ADContext &context, SILFunction *original,
93+
SILDifferentiabilityWitness *witness, SILFunction *vjp,
94+
DifferentiationInvoker invoker);
95+
96+
SILFunction *createEmptyPullback();
97+
98+
/// Run VJP generation. Returns true on error.
99+
bool run();
100+
101+
void postProcess(SILInstruction *orig, SILInstruction *cloned);
102+
103+
/// Remap original basic blocks, adding predecessor enum arguments.
104+
SILBasicBlock *remapBasicBlock(SILBasicBlock *bb);
105+
106+
/// General visitor for all instructions. If any error is emitted by previous
107+
/// visits, bail out.
108+
void visit(SILInstruction *inst);
109+
110+
void visitSILInstruction(SILInstruction *inst);
111+
112+
private:
113+
/// Get the lowered SIL type of the given AST type.
114+
SILType getLoweredType(Type type);
115+
116+
/// Get the lowered SIL type of the given nominal type declaration.
117+
SILType getNominalDeclLoweredType(NominalTypeDecl *nominal);
118+
119+
/// Build a pullback struct value for the original block corresponding to the
120+
/// given terminator.
121+
StructInst *buildPullbackValueStructValue(TermInst *termInst);
122+
123+
/// Build a predecessor enum instance using the given builder for the given
124+
/// original predecessor/successor blocks and pullback struct value.
125+
EnumInst *buildPredecessorEnumValue(SILBuilder &builder,
126+
SILBasicBlock *predBB,
127+
SILBasicBlock *succBB,
128+
SILValue pbStructVal);
129+
130+
public:
131+
void visitReturnInst(ReturnInst *ri);
132+
133+
void visitBranchInst(BranchInst *bi);
134+
135+
void visitCondBranchInst(CondBranchInst *cbi);
136+
137+
void visitSwitchEnumInstBase(SwitchEnumInstBase *inst);
138+
139+
void visitSwitchEnumInst(SwitchEnumInst *sei);
140+
141+
void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai);
142+
143+
// If an `apply` has active results or active inout arguments, replace it
144+
// with an `apply` of its VJP.
145+
void visitApplyInst(ApplyInst *ai);
146+
147+
void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi);
148+
};
149+
150+
} // end namespace autodiff
151+
} // end namespace swift
152+
153+
#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "swift/SILOptimizer/PassManager/Transforms.h"
4141
#include "swift/SILOptimizer/Utils/Differentiation/ADContext.h"
4242
#include "swift/SILOptimizer/Utils/Differentiation/Thunk.h"
43+
#include "swift/SILOptimizer/Utils/Differentiation/VJPEmitter.h"
4344
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
4445
#include "llvm/ADT/APSInt.h"
4546
#include "llvm/ADT/BreadthFirstIterator.h"
@@ -944,9 +945,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
944945
auto *vjp = createEmptyVJP(context, original, witness, serializeFunctions);
945946
witness->setVJP(vjp);
946947
context.recordGeneratedFunction(vjp);
947-
// TODO(TF-1211): Upstream and use `VJPEmitter`. Fatal error with a nice
948-
// message for now.
949-
emitFatalError(context, vjp, "_fatalErrorVJPNotGenerated");
948+
VJPEmitter emitter(context, original, witness, vjp, invoker);
949+
return emitter.run();
950950
}
951951
return false;
952952
}

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P) {
9292
P.addAllocBoxToStack();
9393
P.addNoReturnFolding();
9494
addDefiniteInitialization(P);
95+
96+
// Automatic differentiation: canonicalize all differentiability witnesses
97+
// and `differentiable_function` instructions.
9598
P.addDifferentiation();
9699

97100
// Only run semantic arc opts if we are optimizing and if mandatory semantic

lib/SILOptimizer/Utils/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ silopt_register_sources(
44
DifferentiationInvoker.cpp
55
LinearMapInfo.cpp
66
Thunk.cpp
7+
VJPEmitter.cpp
78
)

0 commit comments

Comments
 (0)