Skip to content

Commit fa405e6

Browse files
committed
[AutoDiff upstream] Add LinearMapInfo.
`LinearMapInfo` contains information about linear map structs and branching trace enums, which are auxiliary data structures created by the differentiation transform. These data structures are constructed in JVP/VJP functions and consumed in differential/pullback functions.
1 parent 9e28e0a commit fa405e6

File tree

3 files changed

+793
-0
lines changed

3 files changed

+793
-0
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
//===--- LinearMapInfo.h --------------------------------------*- C++ -*---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// Linear map struct and branching trace enum information for differentation.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H
18+
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H
19+
20+
#include "swift/AST/AutoDiff.h"
21+
#include "swift/SIL/ApplySite.h"
22+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
23+
#include "llvm/ADT/DenseMap.h"
24+
25+
namespace swift {
26+
27+
class SILFunction;
28+
class SILLoopInfo;
29+
30+
namespace autodiff {
31+
32+
class ADContext;
33+
34+
/// Linear map struct and branching trace enum information for an original
35+
/// function and and derivative function (JVP or VJP).
36+
///
37+
/// Linear map structs contain all callee linear maps produced in a JVP/VJP
38+
/// basic block. A linear map struct is created for each basic block in the
39+
/// original function, and a linear map struct field is created for every active
40+
/// `apply` in the original basic block.
41+
///
42+
/// Branching trace enums model the control flow graph of the original function.
43+
/// A branching trace enum is created for each basic block in the original
44+
/// function, and a branching trace enum case is created for every basic block
45+
/// predecessor/successor. This supports control flow differentiation: JVP/VJP
46+
/// functions build branching trace enums to record an execution trace. Indirect
47+
/// branching trace enums are created for basic blocks that are in loops.
48+
///
49+
/// Linear map struct values and branching trace enum values are constructed in
50+
/// JVP/VJP functions and consumed in pullback/differential functions.
51+
class LinearMapInfo {
52+
private:
53+
/// The linear map kind.
54+
AutoDiffLinearMapKind kind;
55+
56+
/// The original function.
57+
SILFunction *const original;
58+
59+
/// The derivative function.
60+
SILFunction *const derivative;
61+
62+
/// Activity info of the original function.
63+
const DifferentiableActivityInfo &activityInfo;
64+
65+
/// Differentiation indices of the function.
66+
const SILAutoDiffIndices indices;
67+
68+
/// Mapping from original basic blocks to linear map structs.
69+
llvm::DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs;
70+
71+
/// Mapping from original basic blocks to branching trace enums.
72+
/// For pullbacks: these are predecessor enums.
73+
/// For differentials: these are successor enums.
74+
llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
75+
76+
/// Mapping from `apply` instructions in the original function to the
77+
/// corresponding linear map field declaration in the linear map struct.
78+
llvm::DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;
79+
80+
/// Mapping from predecessor-succcessor basic block pairs in the original
81+
/// function to the corresponding branching trace enum case.
82+
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
83+
branchingTraceEnumCases;
84+
85+
/// Mapping from linear map structs to their branching trace enum fields.
86+
llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields;
87+
88+
/// A type converter, used to compute struct/enum SIL types.
89+
Lowering::TypeConverter &typeConverter;
90+
91+
private:
92+
/// Remaps the given type into the derivative function's context.
93+
SILType remapTypeInDerivative(SILType ty);
94+
95+
/// Adds a `VarDecl` member with the given name and type to the given nominal
96+
/// declaration.
97+
VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type);
98+
99+
/// Retrieves the file unit that contains implicit declarations in the
100+
/// current Swift module. If it does not exist, create one.
101+
///
102+
// FIXME: Currently it defaults to the file containing `original`, if it can
103+
// be determined. Otherwise, it defaults to any file unit in the module. To
104+
// handle this more properly, we could revive the DerivedFileUnit class to
105+
// contain all synthesized implicit type declarations.
106+
SourceFile &getDeclarationFileUnit();
107+
108+
/// Computes and sets the access level for the given nominal type, given the
109+
/// original function linkage.
110+
void computeAccessLevel(NominalTypeDecl *nominal, SILLinkage originalLinkage);
111+
112+
/// Creates an enum declaration with the given JVP/VJP generic signature,
113+
/// whose cases represent the predecessors/successors of the given original
114+
/// block.
115+
EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
116+
SILAutoDiffIndices indices,
117+
CanGenericSignature genericSig,
118+
SILLoopInfo *loopInfo);
119+
120+
/// Creates a struct declaration with the given JVP/VJP generic signature, for
121+
/// storing the linear map values and predecessor/successor basic block of the
122+
/// given original block.
123+
StructDecl *createLinearMapStruct(SILBasicBlock *originalBB,
124+
SILAutoDiffIndices indices,
125+
CanGenericSignature genericSig);
126+
127+
/// Adds a linear map field to the linear map struct.
128+
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType);
129+
130+
/// Given an `apply` instruction, conditionally adds a linear map struct field
131+
/// for its linear map function if it is active.
132+
void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
133+
SILAutoDiffIndices indices);
134+
135+
/// Generates linear map struct and branching enum declarations for the given
136+
/// function. Linear map structs are populated with linear map fields and a
137+
/// branching enum field.
138+
void generateDifferentiationDataStructures(ADContext &context,
139+
SILAutoDiffIndices indices,
140+
SILFunction *derivative);
141+
142+
public:
143+
bool shouldDifferentiateApplySite(FullApplySite applySite);
144+
bool shouldDifferentiateInstruction(SILInstruction *inst);
145+
146+
LinearMapInfo(const LinearMapInfo &) = delete;
147+
LinearMapInfo &operator=(const LinearMapInfo &) = delete;
148+
149+
explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
150+
SILFunction *original, SILFunction *derivative,
151+
SILAutoDiffIndices indices,
152+
const DifferentiableActivityInfo &activityInfo);
153+
154+
/// Returns the linear map struct associated with the given original block.
155+
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
156+
return linearMapStructs.lookup(origBB);
157+
}
158+
159+
/// Returns the lowered SIL type of the linear map struct associated with the
160+
/// given original block.
161+
SILType getLinearMapStructLoweredType(SILBasicBlock *origBB) const {
162+
auto derivativeGenSig =
163+
derivative->getLoweredFunctionType()->getSubstGenericSignature();
164+
auto *linMapStruct = getLinearMapStruct(origBB);
165+
auto linMapStructType =
166+
linMapStruct->getDeclaredInterfaceType()->getCanonicalType(
167+
derivativeGenSig);
168+
Lowering::AbstractionPattern pattern(derivativeGenSig, linMapStructType);
169+
return typeConverter.getLoweredType(pattern, linMapStructType,
170+
TypeExpansionContext::minimal());
171+
}
172+
173+
/// Returns the branching trace enum associated with the given original block.
174+
EnumDecl *getBranchingTraceDecl(SILBasicBlock *origBB) const {
175+
return branchingTraceDecls.lookup(origBB);
176+
}
177+
178+
/// Returns the lowered SIL type of the branching trace enum associated with
179+
/// the given original block.
180+
SILType getBranchingTraceEnumLoweredType(SILBasicBlock *origBB) const {
181+
auto *traceDecl = getBranchingTraceDecl(origBB);
182+
auto traceDeclType =
183+
traceDecl->getDeclaredInterfaceType()->getCanonicalType();
184+
Lowering::AbstractionPattern pattern(
185+
derivative->getLoweredFunctionType()->getSubstGenericSignature(),
186+
traceDeclType);
187+
return typeConverter.getLoweredType(pattern, traceDeclType,
188+
TypeExpansionContext::minimal());
189+
}
190+
191+
/// Returns the enum element in the given successor block's branching trace
192+
/// enum corresponding to the given predecessor block.
193+
EnumElementDecl *
194+
lookUpBranchingTraceEnumElement(SILBasicBlock *origPredBB,
195+
SILBasicBlock *origSuccBB) const {
196+
assert(origPredBB->getParent() == original);
197+
return branchingTraceEnumCases.lookup({origPredBB, origSuccBB});
198+
}
199+
200+
/// Returns the mapping from linear map structs to their branching trace enum
201+
/// fields.
202+
llvm::DenseMap<StructDecl *, VarDecl *> &getLinearMapStructEnumFields() {
203+
return linearMapStructEnumFields;
204+
}
205+
206+
/// Returns the branching trace enum field for the linear map struct of the
207+
/// given original block.
208+
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) {
209+
auto *linearMapStruct = getLinearMapStruct(origBB);
210+
return linearMapStructEnumFields.lookup(linearMapStruct);
211+
}
212+
213+
/// Finds the linear map declaration in the pullback struct for the given
214+
/// `apply` instruction in the original function.
215+
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) {
216+
assert(ai->getFunction() == original);
217+
auto lookup = linearMapFieldMap.find(ai);
218+
assert(lookup != linearMapFieldMap.end() &&
219+
"No linear map field corresponding to the given `apply`");
220+
return lookup->getSecond();
221+
}
222+
};
223+
224+
} // end namespace autodiff
225+
} // end namespace swift
226+
227+
#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H

lib/SILOptimizer/Utils/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ silopt_register_sources(
22
ADContext.cpp
33
Common.cpp
44
DifferentiationInvoker.cpp
5+
LinearMapInfo.cpp
56
Thunk.cpp
67
)

0 commit comments

Comments
 (0)