|
| 1 | +//===--- LinearMapInfo.h --------------------------------------*- C++ -*---===// |
| 2 | +// |
| 3 | +// This source file is part of the Swift.org open source project |
| 4 | +// |
| 5 | +// Copyright (c) 2019 Apple Inc. and the Swift project authors |
| 6 | +// Licensed under Apache License v2.0 with Runtime Library Exception |
| 7 | +// |
| 8 | +// See https://swift.org/LICENSE.txt for license information |
| 9 | +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | +// |
| 13 | +// Linear map struct and branching trace enum information for differentation. |
| 14 | +// |
| 15 | +//===----------------------------------------------------------------------===// |
| 16 | + |
| 17 | +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H |
| 18 | +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H |
| 19 | + |
| 20 | +#include "swift/AST/AutoDiff.h" |
| 21 | +#include "swift/SIL/ApplySite.h" |
| 22 | +#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
| 23 | +#include "llvm/ADT/DenseMap.h" |
| 24 | + |
| 25 | +namespace swift { |
| 26 | + |
| 27 | +class SILFunction; |
| 28 | +class SILLoopInfo; |
| 29 | + |
| 30 | +namespace autodiff { |
| 31 | + |
| 32 | +class ADContext; |
| 33 | + |
| 34 | +/// Linear map struct and branching trace enum information for an original |
| 35 | +/// function and and derivative function (JVP or VJP). |
| 36 | +/// |
| 37 | +/// Linear map structs contain all callee linear maps produced in a JVP/VJP |
| 38 | +/// basic block. A linear map struct is created for each basic block in the |
| 39 | +/// original function, and a linear map struct field is created for every active |
| 40 | +/// `apply` in the original basic block. |
| 41 | +/// |
| 42 | +/// Branching trace enums model the control flow graph of the original function. |
| 43 | +/// A branching trace enum is created for each basic block in the original |
| 44 | +/// function, and a branching trace enum case is created for every basic block |
| 45 | +/// predecessor/successor. This supports control flow differentiation: JVP/VJP |
| 46 | +/// functions build branching trace enums to record an execution trace. Indirect |
| 47 | +/// branching trace enums are created for basic blocks that are in loops. |
| 48 | +/// |
| 49 | +/// Linear map struct values and branching trace enum values are constructed in |
| 50 | +/// JVP/VJP functions and consumed in pullback/differential functions. |
| 51 | +class LinearMapInfo { |
| 52 | +private: |
| 53 | + /// The linear map kind. |
| 54 | + AutoDiffLinearMapKind kind; |
| 55 | + |
| 56 | + /// The original function. |
| 57 | + SILFunction *const original; |
| 58 | + |
| 59 | + /// The derivative function. |
| 60 | + SILFunction *const derivative; |
| 61 | + |
| 62 | + /// Activity info of the original function. |
| 63 | + const DifferentiableActivityInfo &activityInfo; |
| 64 | + |
| 65 | + /// Differentiation indices of the function. |
| 66 | + const SILAutoDiffIndices indices; |
| 67 | + |
| 68 | + /// Mapping from original basic blocks to linear map structs. |
| 69 | + llvm::DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs; |
| 70 | + |
| 71 | + /// Mapping from original basic blocks to branching trace enums. |
| 72 | + /// For pullbacks: these are predecessor enums. |
| 73 | + /// For differentials: these are successor enums. |
| 74 | + llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls; |
| 75 | + |
| 76 | + /// Mapping from `apply` instructions in the original function to the |
| 77 | + /// corresponding linear map field declaration in the linear map struct. |
| 78 | + llvm::DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap; |
| 79 | + |
| 80 | + /// Mapping from predecessor-succcessor basic block pairs in the original |
| 81 | + /// function to the corresponding branching trace enum case. |
| 82 | + llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *> |
| 83 | + branchingTraceEnumCases; |
| 84 | + |
| 85 | + /// Mapping from linear map structs to their branching trace enum fields. |
| 86 | + llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields; |
| 87 | + |
| 88 | + /// A type converter, used to compute struct/enum SIL types. |
| 89 | + Lowering::TypeConverter &typeConverter; |
| 90 | + |
| 91 | +private: |
| 92 | + /// Remaps the given type into the derivative function's context. |
| 93 | + SILType remapTypeInDerivative(SILType ty); |
| 94 | + |
| 95 | + /// Adds a `VarDecl` member with the given name and type to the given nominal |
| 96 | + /// declaration. |
| 97 | + VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type); |
| 98 | + |
| 99 | + /// Retrieves the file unit that contains implicit declarations in the |
| 100 | + /// current Swift module. If it does not exist, create one. |
| 101 | + /// |
| 102 | + // FIXME: Currently it defaults to the file containing `original`, if it can |
| 103 | + // be determined. Otherwise, it defaults to any file unit in the module. To |
| 104 | + // handle this more properly, we could revive the DerivedFileUnit class to |
| 105 | + // contain all synthesized implicit type declarations. |
| 106 | + SourceFile &getDeclarationFileUnit(); |
| 107 | + |
| 108 | + /// Computes and sets the access level for the given nominal type, given the |
| 109 | + /// original function linkage. |
| 110 | + void computeAccessLevel(NominalTypeDecl *nominal, SILLinkage originalLinkage); |
| 111 | + |
| 112 | + /// Creates an enum declaration with the given JVP/VJP generic signature, |
| 113 | + /// whose cases represent the predecessors/successors of the given original |
| 114 | + /// block. |
| 115 | + EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB, |
| 116 | + SILAutoDiffIndices indices, |
| 117 | + CanGenericSignature genericSig, |
| 118 | + SILLoopInfo *loopInfo); |
| 119 | + |
| 120 | + /// Creates a struct declaration with the given JVP/VJP generic signature, for |
| 121 | + /// storing the linear map values and predecessor/successor basic block of the |
| 122 | + /// given original block. |
| 123 | + StructDecl *createLinearMapStruct(SILBasicBlock *originalBB, |
| 124 | + SILAutoDiffIndices indices, |
| 125 | + CanGenericSignature genericSig); |
| 126 | + |
| 127 | + /// Adds a linear map field to the linear map struct. |
| 128 | + VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType); |
| 129 | + |
| 130 | + /// Given an `apply` instruction, conditionally adds a linear map struct field |
| 131 | + /// for its linear map function if it is active. |
| 132 | + void addLinearMapToStruct(ADContext &context, ApplyInst *ai, |
| 133 | + SILAutoDiffIndices indices); |
| 134 | + |
| 135 | + /// Generates linear map struct and branching enum declarations for the given |
| 136 | + /// function. Linear map structs are populated with linear map fields and a |
| 137 | + /// branching enum field. |
| 138 | + void generateDifferentiationDataStructures(ADContext &context, |
| 139 | + SILAutoDiffIndices indices, |
| 140 | + SILFunction *derivative); |
| 141 | + |
| 142 | +public: |
| 143 | + bool shouldDifferentiateApplySite(FullApplySite applySite); |
| 144 | + bool shouldDifferentiateInstruction(SILInstruction *inst); |
| 145 | + |
| 146 | + LinearMapInfo(const LinearMapInfo &) = delete; |
| 147 | + LinearMapInfo &operator=(const LinearMapInfo &) = delete; |
| 148 | + |
| 149 | + explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, |
| 150 | + SILFunction *original, SILFunction *derivative, |
| 151 | + SILAutoDiffIndices indices, |
| 152 | + const DifferentiableActivityInfo &activityInfo); |
| 153 | + |
| 154 | + /// Returns the linear map struct associated with the given original block. |
| 155 | + StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const { |
| 156 | + return linearMapStructs.lookup(origBB); |
| 157 | + } |
| 158 | + |
| 159 | + /// Returns the lowered SIL type of the linear map struct associated with the |
| 160 | + /// given original block. |
| 161 | + SILType getLinearMapStructLoweredType(SILBasicBlock *origBB) const { |
| 162 | + auto derivativeGenSig = |
| 163 | + derivative->getLoweredFunctionType()->getSubstGenericSignature(); |
| 164 | + auto *linMapStruct = getLinearMapStruct(origBB); |
| 165 | + auto linMapStructType = |
| 166 | + linMapStruct->getDeclaredInterfaceType()->getCanonicalType( |
| 167 | + derivativeGenSig); |
| 168 | + Lowering::AbstractionPattern pattern(derivativeGenSig, linMapStructType); |
| 169 | + return typeConverter.getLoweredType(pattern, linMapStructType, |
| 170 | + TypeExpansionContext::minimal()); |
| 171 | + } |
| 172 | + |
| 173 | + /// Returns the branching trace enum associated with the given original block. |
| 174 | + EnumDecl *getBranchingTraceDecl(SILBasicBlock *origBB) const { |
| 175 | + return branchingTraceDecls.lookup(origBB); |
| 176 | + } |
| 177 | + |
| 178 | + /// Returns the lowered SIL type of the branching trace enum associated with |
| 179 | + /// the given original block. |
| 180 | + SILType getBranchingTraceEnumLoweredType(SILBasicBlock *origBB) const { |
| 181 | + auto *traceDecl = getBranchingTraceDecl(origBB); |
| 182 | + auto traceDeclType = |
| 183 | + traceDecl->getDeclaredInterfaceType()->getCanonicalType(); |
| 184 | + Lowering::AbstractionPattern pattern( |
| 185 | + derivative->getLoweredFunctionType()->getSubstGenericSignature(), |
| 186 | + traceDeclType); |
| 187 | + return typeConverter.getLoweredType(pattern, traceDeclType, |
| 188 | + TypeExpansionContext::minimal()); |
| 189 | + } |
| 190 | + |
| 191 | + /// Returns the enum element in the given successor block's branching trace |
| 192 | + /// enum corresponding to the given predecessor block. |
| 193 | + EnumElementDecl * |
| 194 | + lookUpBranchingTraceEnumElement(SILBasicBlock *origPredBB, |
| 195 | + SILBasicBlock *origSuccBB) const { |
| 196 | + assert(origPredBB->getParent() == original); |
| 197 | + return branchingTraceEnumCases.lookup({origPredBB, origSuccBB}); |
| 198 | + } |
| 199 | + |
| 200 | + /// Returns the mapping from linear map structs to their branching trace enum |
| 201 | + /// fields. |
| 202 | + llvm::DenseMap<StructDecl *, VarDecl *> &getLinearMapStructEnumFields() { |
| 203 | + return linearMapStructEnumFields; |
| 204 | + } |
| 205 | + |
| 206 | + /// Returns the branching trace enum field for the linear map struct of the |
| 207 | + /// given original block. |
| 208 | + VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) { |
| 209 | + auto *linearMapStruct = getLinearMapStruct(origBB); |
| 210 | + return linearMapStructEnumFields.lookup(linearMapStruct); |
| 211 | + } |
| 212 | + |
| 213 | + /// Finds the linear map declaration in the pullback struct for the given |
| 214 | + /// `apply` instruction in the original function. |
| 215 | + VarDecl *lookUpLinearMapDecl(ApplyInst *ai) { |
| 216 | + assert(ai->getFunction() == original); |
| 217 | + auto lookup = linearMapFieldMap.find(ai); |
| 218 | + assert(lookup != linearMapFieldMap.end() && |
| 219 | + "No linear map field corresponding to the given `apply`"); |
| 220 | + return lookup->getSecond(); |
| 221 | + } |
| 222 | +}; |
| 223 | + |
| 224 | +} // end namespace autodiff |
| 225 | +} // end namespace swift |
| 226 | + |
| 227 | +#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H |
0 commit comments