Skip to content

Commit df830bc

Browse files
committed
[AutoDiff] NFC: Reimplement VJPCloner using pimpl pattern.
`VJPCloner.h` is now tiny: `VJPCloner` exposes only a `bool run()` entry point. All of the implementation is moved to `VJPCloner::Implementation` in `VJPCloner.cpp`. Methods can be defined directly in `VJPCloner.cpp` without separate declarations.
1 parent cd3f46f commit df830bc

File tree

4 files changed

+704
-685
lines changed

4 files changed

+704
-685
lines changed

include/swift/SILOptimizer/Differentiation/PullbackCloner.h

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,14 @@
1818
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKCLONER_H
1919
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKCLONER_H
2020

21-
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
22-
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
23-
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
24-
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
25-
26-
#include "swift/SIL/TypeSubstCloner.h"
27-
#include "llvm/ADT/DenseMap.h"
28-
2921
namespace swift {
30-
31-
class SILDifferentiabilityWitness;
32-
class SILBasicBlock;
33-
class SILFunction;
34-
class SILInstruction;
35-
3622
namespace autodiff {
3723

38-
class ADContext;
3924
class VJPCloner;
4025

4126
/// A helper class for generating pullback functions.
4227
class PullbackCloner final {
43-
struct Implementation;
28+
class Implementation;
4429
Implementation &impl;
4530

4631
public:

include/swift/SILOptimizer/Differentiation/VJPCloner.h

Lines changed: 23 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -22,149 +22,41 @@
2222
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2323
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
2424

25-
#include "swift/SIL/TypeSubstCloner.h"
26-
#include "llvm/ADT/DenseMap.h"
27-
2825
namespace swift {
29-
30-
class SILDifferentiabilityWitness;
31-
class SILBasicBlock;
32-
class SILFunction;
33-
class SILInstruction;
34-
3526
namespace autodiff {
3627

3728
class ADContext;
3829
class PullbackCloner;
3930

40-
class VJPCloner final
41-
: public TypeSubstCloner<VJPCloner, SILOptFunctionBuilder> {
42-
friend class PullbackCloner;
43-
44-
private:
45-
/// The global context.
46-
ADContext &context;
47-
48-
/// The original function.
49-
SILFunction *const original;
50-
51-
/// The differentiability witness.
52-
SILDifferentiabilityWitness *const witness;
53-
54-
/// The VJP function.
55-
SILFunction *const vjp;
56-
57-
/// The pullback function.
58-
SILFunction *pullback;
59-
60-
/// The differentiation invoker.
61-
DifferentiationInvoker invoker;
62-
63-
/// Info from activity analysis on the original function.
64-
const DifferentiableActivityInfo &activityInfo;
65-
66-
/// The linear map info.
67-
LinearMapInfo pullbackInfo;
68-
69-
/// Caches basic blocks whose phi arguments have been remapped (adding a
70-
/// predecessor enum argument).
71-
SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks;
72-
73-
bool errorOccurred = false;
74-
75-
/// Mapping from original blocks to pullback values. Used to build pullback
76-
/// struct instances.
77-
llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> pullbackValues;
78-
79-
ASTContext &getASTContext() const { return vjp->getASTContext(); }
80-
SILModule &getModule() const { return vjp->getModule(); }
81-
const SILAutoDiffIndices getIndices() const {
82-
return witness->getSILAutoDiffIndices();
83-
}
84-
85-
static SubstitutionMap getSubstitutionMap(SILFunction *original,
86-
SILFunction *vjp);
87-
88-
static const DifferentiableActivityInfo &
89-
getActivityInfo(ADContext &context, SILFunction *original,
90-
SILAutoDiffIndices indices, SILFunction *vjp);
31+
/// A helper class for generating VJP functions.
32+
class VJPCloner final {
33+
class Implementation;
34+
Implementation &impl;
9135

9236
public:
37+
/// Creates a VJP cloner.
38+
///
39+
/// The parent VJP cloner stores the original function and an empty
40+
/// to-be-generated pullback function.
9341
explicit VJPCloner(ADContext &context, SILFunction *original,
9442
SILDifferentiabilityWitness *witness, SILFunction *vjp,
9543
DifferentiationInvoker invoker);
96-
97-
SILFunction *createEmptyPullback();
98-
99-
/// Run VJP generation. Returns true on error.
44+
~VJPCloner();
45+
46+
ADContext &getContext() const;
47+
SILModule &getModule() const;
48+
SILFunction &getOriginal() const;
49+
SILFunction &getVJP() const;
50+
SILFunction &getPullback() const;
51+
SILDifferentiabilityWitness *getWitness() const;
52+
const SILAutoDiffIndices getIndices() const;
53+
DifferentiationInvoker getInvoker() const;
54+
LinearMapInfo &getPullbackInfo() const;
55+
const DifferentiableActivityInfo &getActivityInfo() const;
56+
57+
/// Performs VJP generation on the empty VJP function. Returns true if any
58+
/// error occurs.
10059
bool run();
101-
102-
void postProcess(SILInstruction *orig, SILInstruction *cloned);
103-
104-
/// Remap original basic blocks, adding predecessor enum arguments.
105-
SILBasicBlock *remapBasicBlock(SILBasicBlock *bb);
106-
107-
/// General visitor for all instructions. If any error is emitted by previous
108-
/// visits, bail out.
109-
void visit(SILInstruction *inst);
110-
111-
void visitSILInstruction(SILInstruction *inst);
112-
113-
private:
114-
/// Get the lowered SIL type of the given AST type.
115-
SILType getLoweredType(Type type);
116-
117-
/// Get the lowered SIL type of the given nominal type declaration.
118-
SILType getNominalDeclLoweredType(NominalTypeDecl *nominal);
119-
120-
// Creates a trampoline block for given original terminator instruction, the
121-
// pullback struct value for its parent block, and a successor basic block.
122-
//
123-
// The trampoline block has the same arguments as and branches to the remapped
124-
// successor block, but drops the last predecessor enum argument.
125-
//
126-
// Used for cloning branching terminator instructions with specific
127-
// requirements on successor block arguments, where an additional predecessor
128-
// enum argument is not acceptable.
129-
SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst,
130-
StructInst *pbStructVal,
131-
SILBasicBlock *succBB);
132-
133-
/// Build a pullback struct value for the given original terminator
134-
/// instruction.
135-
StructInst *buildPullbackValueStructValue(TermInst *termInst);
136-
137-
/// Build a predecessor enum instance using the given builder for the given
138-
/// original predecessor/successor blocks and pullback struct value.
139-
EnumInst *buildPredecessorEnumValue(SILBuilder &builder,
140-
SILBasicBlock *predBB,
141-
SILBasicBlock *succBB,
142-
SILValue pbStructVal);
143-
144-
public:
145-
void visitReturnInst(ReturnInst *ri);
146-
147-
void visitBranchInst(BranchInst *bi);
148-
149-
void visitCondBranchInst(CondBranchInst *cbi);
150-
151-
void visitSwitchEnumInstBase(SwitchEnumInstBase *inst);
152-
153-
void visitSwitchEnumInst(SwitchEnumInst *sei);
154-
155-
void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai);
156-
157-
void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi);
158-
159-
void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi);
160-
161-
void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi);
162-
163-
// If an `apply` has active results or active inout arguments, replace it
164-
// with an `apply` of its VJP.
165-
void visitApplyInst(ApplyInst *ai);
166-
167-
void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi);
16860
};
16961

17062
} // end namespace autodiff

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
#define DEBUG_TYPE "differentiation"
1919

2020
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
21+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2122
#include "swift/SILOptimizer/Differentiation/ADContext.h"
23+
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
24+
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
25+
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
2226
#include "swift/SILOptimizer/Differentiation/Thunk.h"
2327
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
2428

@@ -27,8 +31,10 @@
2731
#include "swift/AST/TypeCheckRequests.h"
2832
#include "swift/SIL/InstructionUtils.h"
2933
#include "swift/SIL/Projection.h"
34+
#include "swift/SIL/TypeSubstCloner.h"
3035
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
3136
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
37+
#include "llvm/ADT/DenseMap.h"
3238

3339
namespace swift {
3440

@@ -49,8 +55,12 @@ class VJPCloner;
4955
/// instructions per basic block in reverse order. This visitation order is
5056
/// necessary for generating pullback functions, whose control flow graph is
5157
/// ~a transposed version of the original function's control flow graph.
52-
struct PullbackCloner::Implementation final
58+
class PullbackCloner::Implementation final
5359
: public SILInstructionVisitor<PullbackCloner::Implementation> {
60+
61+
public:
62+
explicit Implementation(VJPCloner &vjpCloner);
63+
5464
private:
5565
/// The parent VJP cloner.
5666
VJPCloner &vjpCloner;
@@ -124,23 +134,21 @@ struct PullbackCloner::Implementation final
124134

125135
bool errorOccurred = false;
126136

127-
ADContext &getContext() const { return vjpCloner.context; }
137+
ADContext &getContext() const { return vjpCloner.getContext(); }
128138
SILModule &getModule() const { return getContext().getModule(); }
129139
ASTContext &getASTContext() const { return getPullback().getASTContext(); }
130-
SILFunction &getOriginal() const { return *vjpCloner.original; }
131-
SILFunction &getPullback() const { return *vjpCloner.pullback; }
132-
SILDifferentiabilityWitness *getWitness() const { return vjpCloner.witness; }
133-
DifferentiationInvoker getInvoker() const { return vjpCloner.invoker; }
134-
LinearMapInfo &getPullbackInfo() { return vjpCloner.pullbackInfo; }
140+
SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
141+
SILFunction &getPullback() const { return vjpCloner.getPullback(); }
142+
SILDifferentiabilityWitness *getWitness() const {
143+
return vjpCloner.getWitness();
144+
}
145+
DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); }
146+
LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); }
135147
const SILAutoDiffIndices getIndices() const { return vjpCloner.getIndices(); }
136148
const DifferentiableActivityInfo &getActivityInfo() const {
137-
return vjpCloner.activityInfo;
149+
return vjpCloner.getActivityInfo();
138150
}
139151

140-
public:
141-
explicit Implementation(VJPCloner &vjpCloner);
142-
143-
private:
144152
//--------------------------------------------------------------------------//
145153
// Pullback struct mapping
146154
//--------------------------------------------------------------------------//
@@ -1520,9 +1528,10 @@ PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner)
15201528
auto *domAnalysis = passManager.getAnalysis<DominanceAnalysis>();
15211529
auto *postDomAnalysis = passManager.getAnalysis<PostDominanceAnalysis>();
15221530
auto *postOrderAnalysis = passManager.getAnalysis<PostOrderAnalysis>();
1523-
domInfo = domAnalysis->get(vjpCloner.original);
1524-
postDomInfo = postDomAnalysis->get(vjpCloner.original);
1525-
postOrderInfo = postOrderAnalysis->get(vjpCloner.original);
1531+
auto *original = &vjpCloner.getOriginal();
1532+
domInfo = domAnalysis->get(original);
1533+
postDomInfo = postDomAnalysis->get(original);
1534+
postOrderInfo = postOrderAnalysis->get(original);
15261535
}
15271536

15281537
PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)

0 commit comments

Comments
 (0)