Skip to content

Commit d2e022d

Browse files
authored
Remove linear map structs and use plain tuples instead. (swiftlang#63444)
The changes are intentionally were made close to the original implementation w/o possible simplifications to ease the review Fixes swiftlang#63207, supersedes swiftlang#63379 (and fixes swiftlang#63234)
1 parent acd72ef commit d2e022d

File tree

12 files changed

+416
-484
lines changed

12 files changed

+416
-484
lines changed

include/swift/SILOptimizer/Differentiation/LinearMapInfo.h

Lines changed: 28 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -69,26 +69,23 @@ class LinearMapInfo {
6969
/// Differentiation indices of the function.
7070
const AutoDiffConfig config;
7171

72-
/// Mapping from original basic blocks to linear map structs.
73-
llvm::DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs;
72+
/// Mapping from original basic blocks to linear map tuple types.
73+
llvm::DenseMap<SILBasicBlock *, TupleType *> linearMapTuples;
7474

7575
/// Mapping from original basic blocks to branching trace enums.
7676
/// For pullbacks: these are predecessor enums.
7777
/// For differentials: these are successor enums.
7878
llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
7979

8080
/// Mapping from `apply` instructions in the original function to the
81-
/// corresponding linear map field declaration in the linear map struct.
82-
llvm::DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;
81+
/// corresponding linear map tuple type index.
82+
llvm::DenseMap<ApplyInst *, unsigned> linearMapIndexMap;
8383

8484
/// Mapping from predecessor-successor basic block pairs in the original
8585
/// function to the corresponding branching trace enum case.
8686
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
8787
branchingTraceEnumCases;
8888

89-
/// Mapping from linear map structs to their branching trace enum fields.
90-
llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields;
91-
9289
/// Blocks in a loop.
9390
llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop;
9491

@@ -102,37 +99,21 @@ class LinearMapInfo {
10299
/// Remaps the given type into the derivative function's context.
103100
SILType remapTypeInDerivative(SILType ty);
104101

105-
/// Adds a `VarDecl` member with the given name and type to the given nominal
106-
/// declaration.
107-
VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type);
108-
109102
/// Retrieves the file unit that contains implicit declarations in the
110103
/// current Swift module.
111104
SynthesizedFileUnit &getSynthesizedFile() { return synthesizedFile; }
112105

113-
/// Computes and sets the access level for the given nominal type, given the
114-
/// original function linkage.
115-
void computeAccessLevel(NominalTypeDecl *nominal, SILLinkage originalLinkage);
116-
117106
/// Creates an enum declaration with the given JVP/VJP generic signature,
118107
/// whose cases represent the predecessors/successors of the given original
119108
/// block.
120109
EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
121-
CanGenericSignature genericSig,
122-
SILLoopInfo *loopInfo);
110+
CanGenericSignature genericSig);
111+
void populateBranchingTraceDecl(SILBasicBlock *originalBB,
112+
SILLoopInfo *loopInfo);
123113

124-
/// Creates a struct declaration with the given JVP/VJP generic signature, for
125-
/// storing the linear map values and predecessor/successor basic block of the
126-
/// given original block.
127-
StructDecl *createLinearMapStruct(SILBasicBlock *originalBB,
128-
CanGenericSignature genericSig);
129-
130-
/// Adds a linear map field to the linear map struct.
131-
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType);
132-
133-
/// Given an `apply` instruction, conditionally adds a linear map struct field
134-
/// for its linear map function if it is active.
135-
void addLinearMapToStruct(ADContext &context, ApplyInst *ai);
114+
/// Given an `apply` instruction, conditionally gets a linear map tuple field
115+
/// AST type for its linear map function if it is active.
116+
Type getLinearMapType(ADContext &context, ApplyInst *ai);
136117

137118
/// Generates linear map struct and branching enum declarations for the given
138119
/// function. Linear map structs are populated with linear map fields and a
@@ -153,22 +134,20 @@ class LinearMapInfo {
153134
const DifferentiableActivityInfo &activityInfo,
154135
SILLoopInfo *loopInfo);
155136

156-
/// Returns the linear map struct associated with the given original block.
157-
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
158-
return linearMapStructs.lookup(origBB);
137+
/// Returns the linear map tuple associated with the given original block.
138+
TupleType *getLinearMapTupleType(SILBasicBlock *origBB) const {
139+
return linearMapTuples.lookup(origBB);
159140
}
160141

161-
/// Returns the lowered SIL type of the linear map struct associated with the
142+
/// Returns the lowered SIL type of the linear map tuple associated with the
162143
/// given original block.
163-
SILType getLinearMapStructLoweredType(SILBasicBlock *origBB) const {
144+
SILType getLinearMapTupleLoweredType(SILBasicBlock *origBB) const {
164145
auto derivativeGenSig =
165146
derivative->getLoweredFunctionType()->getSubstGenericSignature();
166-
auto *linMapStruct = getLinearMapStruct(origBB);
167-
auto linMapStructType =
168-
linMapStruct->getDeclaredInterfaceType()->getReducedType(
169-
derivativeGenSig);
170-
Lowering::AbstractionPattern pattern(derivativeGenSig, linMapStructType);
171-
return typeConverter.getLoweredType(pattern, linMapStructType,
147+
auto linMapTupleType =
148+
getLinearMapTupleType(origBB)->getReducedType(derivativeGenSig);
149+
Lowering::AbstractionPattern pattern(derivativeGenSig, linMapTupleType);
150+
return typeConverter.getLoweredType(pattern, linMapTupleType,
172151
TypeExpansionContext::minimal());
173152
}
174153

@@ -199,29 +178,21 @@ class LinearMapInfo {
199178
return branchingTraceEnumCases.lookup({origPredBB, origSuccBB});
200179
}
201180

202-
/// Returns the mapping from linear map structs to their branching trace enum
203-
/// fields.
204-
llvm::DenseMap<StructDecl *, VarDecl *> &getLinearMapStructEnumFields() {
205-
return linearMapStructEnumFields;
206-
}
207-
208-
/// Returns the branching trace enum field for the linear map struct of the
209-
/// given original block.
210-
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) const {
211-
auto *linearMapStruct = getLinearMapStruct(origBB);
212-
return linearMapStructEnumFields.lookup(linearMapStruct);
213-
}
214-
215-
/// Finds the linear map declaration in the pullback struct for the given
181+
/// Finds the linear map index in the pullback tuple for the given
216182
/// `apply` instruction in the original function.
217-
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) const {
183+
unsigned lookUpLinearMapIndex(ApplyInst *ai) const {
218184
assert(ai->getFunction() == original);
219-
auto lookup = linearMapFieldMap.find(ai);
220-
assert(lookup != linearMapFieldMap.end() &&
185+
auto lookup = linearMapIndexMap.find(ai);
186+
assert(lookup != linearMapIndexMap.end() &&
221187
"No linear map field corresponding to the given `apply`");
222188
return lookup->getSecond();
223189
}
224190

191+
Type lookUpLinearMapType(ApplyInst *ai) const {
192+
unsigned idx = lookUpLinearMapIndex(ai);
193+
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
194+
}
195+
225196
bool hasLoops() const {
226197
return !blocksInLoop.empty();
227198
}

lib/SIL/IR/SILPrinter.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,6 +3449,36 @@ static void printSILDifferentiabilityWitnesses(
34493449
dw->print(Ctx.OS(), Ctx.printVerbose());
34503450
}
34513451

3452+
static void printSILLinearMapTypes(SILPrintContext &Ctx,
3453+
const ModuleDecl *M) {
3454+
auto &OS = Ctx.OS();
3455+
3456+
PrintOptions Options = PrintOptions::printSIL();
3457+
Options.TypeDefinitions = true;
3458+
Options.VarInitializers = true;
3459+
Options.ExplodePatternBindingDecls = true;
3460+
Options.SkipImplicit = false;
3461+
Options.PrintGetSetOnRWProperties = true;
3462+
Options.PrintInSILBody = false;
3463+
3464+
SmallVector<Decl *, 32> topLevelDecls;
3465+
M->getTopLevelDecls(topLevelDecls);
3466+
for (const Decl *D : topLevelDecls) {
3467+
if (D->getDeclContext() == M)
3468+
continue;
3469+
3470+
if (!isa<StructDecl>(D) && !isa<EnumDecl>(D))
3471+
continue;
3472+
3473+
StringRef Name = cast<TypeDecl>(D)->getNameStr();
3474+
if (!Name.startswith("_AD__"))
3475+
continue;
3476+
3477+
D->print(OS, Options);
3478+
OS << "\n\n";
3479+
}
3480+
}
3481+
34523482
static void
34533483
printSILCoverageMaps(SILPrintContext &Ctx,
34543484
const SILModule::CoverageMapCollectionType &CoverageMaps) {
@@ -3624,6 +3654,7 @@ void SILModule::print(SILPrintContext &PrintCtx, ModuleDecl *M,
36243654
printSILGlobals(PrintCtx, getSILGlobalList());
36253655
printSILDifferentiabilityWitnesses(PrintCtx,
36263656
getDifferentiabilityWitnessList());
3657+
printSILLinearMapTypes(PrintCtx, getSwiftModule());
36273658
printSILFunctions(PrintCtx, getFunctionList());
36283659
printSILVTables(PrintCtx, getVTables());
36293660
printSILWitnessTables(PrintCtx, getWitnessTableList());

lib/SIL/Parser/ParseSIL.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2416,7 +2416,9 @@ static bool parseSILDifferentiabilityWitnessConfigAndFunction(
24162416
auto origFnType = resultOrigFn->getLoweredFunctionType();
24172417
auto *parameterIndices = IndexSubset::get(
24182418
P.Context, origFnType->getNumParameters(), rawParameterIndices);
2419-
auto *resultIndices = IndexSubset::get(P.Context, origFnType->getNumResults(),
2419+
auto *resultIndices = IndexSubset::get(P.Context,
2420+
origFnType->getNumResults() +
2421+
origFnType->getNumIndirectMutatingParameters(),
24202422
rawResultIndices);
24212423
resultConfig = AutoDiffConfig(parameterIndices, resultIndices, witnessGenSig);
24222424
return false;

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 39 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class JVPCloner::Implementation final
9797
/// elements destructured from the linear map basic block argument. In the
9898
/// beginning of each differential basic block, the block's differential
9999
/// struct is destructured into the individual elements stored here.
100-
llvm::DenseMap<VarDecl *, SILValue> differentialStructElements;
100+
llvm::DenseMap<SILBasicBlock *, SILInstructionResultArray> differentialTupleElements;
101101

102102
/// An auxiliary differential local allocation builder.
103103
TangentBuilder diffLocalAllocBuilder;
@@ -119,23 +119,17 @@ class JVPCloner::Implementation final
119119
TangentBuilder &getDifferentialBuilder() { return differentialBuilder; }
120120
SILFunction &getDifferential() { return differentialBuilder.getFunction(); }
121121
SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) {
122-
#ifndef NDEBUG
123-
auto *diffStruct = differentialStructArguments[origBB]
124-
->getType()
125-
.getStructOrBoundGenericStruct();
126-
assert(diffStruct == differentialInfo.getLinearMapStruct(origBB));
127-
#endif
128122
return differentialStructArguments[origBB];
129123
}
130124

131125
//--------------------------------------------------------------------------//
132-
// Differential struct mapping
126+
// Differential tuple mapping
133127
//--------------------------------------------------------------------------//
134128

135-
void initializeDifferentialStructElements(SILBasicBlock *origBB,
136-
SILInstructionResultArray values);
129+
void initializeDifferentialTupleElements(SILBasicBlock *origBB,
130+
SILInstructionResultArray values);
137131

138-
SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field);
132+
SILValue getDifferentialTupleElement(ApplyInst *ai);
139133

140134
//--------------------------------------------------------------------------//
141135
// General utilities
@@ -158,22 +152,21 @@ class JVPCloner::Implementation final
158152

159153
/// Build a differential struct value for the original block corresponding to
160154
/// the given terminator.
161-
StructInst *buildDifferentialValueStructValue(TermInst *termInst) {
155+
TupleInst *buildDifferentialValueStructValue(TermInst *termInst) {
162156
assert(termInst->getFunction() == original);
163157
auto loc = termInst->getFunction()->getLocation();
164158
auto *origBB = termInst->getParent();
165159
auto *jvpBB = BBMap[origBB];
166160
assert(jvpBB && "Basic block mapping should exist");
167-
auto *diffStruct = differentialInfo.getLinearMapStruct(origBB);
168-
assert(diffStruct && "The differential struct should have been declared");
169-
auto structLoweredTy = getNominalDeclLoweredType(diffStruct);
161+
auto tupleLoweredTy =
162+
remapType(differentialInfo.getLinearMapTupleLoweredType(origBB));
170163
auto bbDifferentialValues = differentialValues[origBB];
171164
if (!origBB->isEntry()) {
172165
auto *enumArg = jvpBB->getArguments().back();
173166
bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg);
174167
}
175-
return getBuilder().createStruct(loc, structLoweredTy,
176-
bbDifferentialValues);
168+
return getBuilder().createTuple(loc, tupleLoweredTy,
169+
bbDifferentialValues);
177170
}
178171

179172
//--------------------------------------------------------------------------//
@@ -438,8 +431,8 @@ class JVPCloner::Implementation final
438431
auto *mainDifferentialStruct = diffBB->getArguments().back();
439432
diffBuilder.setInsertionPoint(diffBB);
440433
auto *dsi =
441-
diffBuilder.createDestructureStruct(diffLoc, mainDifferentialStruct);
442-
initializeDifferentialStructElements(bb, dsi->getResults());
434+
diffBuilder.createDestructureTuple(diffLoc, mainDifferentialStruct);
435+
initializeDifferentialTupleElements(bb, dsi->getResults());
443436
TypeSubstCloner::visitInstructionsInBlock(bb);
444437
}
445438

@@ -667,12 +660,11 @@ class JVPCloner::Implementation final
667660
// Add the differential function for when we create the struct we partially
668661
// apply to the differential we are generating.
669662
auto differential = jvpDirectResults.back();
670-
auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai);
663+
auto differentialType = differentialInfo.lookUpLinearMapType(ai);
671664
auto originalDifferentialType =
672665
getOpType(differential->getType()).getAs<SILFunctionType>();
673666
auto loweredDifferentialType =
674-
getOpType(getLoweredType(differentialDecl->getInterfaceType()))
675-
.castTo<SILFunctionType>();
667+
getOpType(getLoweredType(differentialType)).castTo<SILFunctionType>();
676668
// If actual differential type does not match lowered differential type,
677669
// reabstract the differential using a thunk.
678670
if (!loweredDifferentialType->isEqual(originalDifferentialType)) {
@@ -1218,9 +1210,7 @@ class JVPCloner::Implementation final
12181210
auto &diffBuilder = getDifferentialBuilder();
12191211

12201212
// Get the differential value.
1221-
auto *field = differentialInfo.lookUpLinearMapDecl(ai);
1222-
assert(field);
1223-
SILValue differential = getDifferentialStructElement(bb, field);
1213+
SILValue differential = getDifferentialTupleElement(ai);
12241214
auto differentialType = remapSILTypeInDifferential(differential->getType())
12251215
.castTo<SILFunctionType>();
12261216

@@ -1432,31 +1422,27 @@ JVPCloner::~JVPCloner() { delete &impl; }
14321422
// Differential struct mapping
14331423
//--------------------------------------------------------------------------//
14341424

1435-
void JVPCloner::Implementation::initializeDifferentialStructElements(
1436-
SILBasicBlock *origBB, SILInstructionResultArray values) {
1437-
auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB);
1438-
assert(diffStructDecl->getStoredProperties().size() == values.size() &&
1439-
"The number of differential struct fields must equal the number of "
1425+
void JVPCloner::Implementation::initializeDifferentialTupleElements(
1426+
SILBasicBlock *origBB, SILInstructionResultArray values) {
1427+
auto *diffTupleTyple = differentialInfo.getLinearMapTupleType(origBB);
1428+
assert(diffTupleTyple->getNumElements() == values.size() &&
1429+
"The number of differential tuple fields must equal the number of "
14401430
"differential struct element values");
1441-
for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) {
1442-
assert(std::get<1>(pair)->getOwnershipKind() != OwnershipKind::Guaranteed &&
1443-
"Differential struct elements must be @owned");
1444-
auto insertion = differentialStructElements.insert(
1445-
{std::get<0>(pair), std::get<1>(pair)});
1446-
(void)insertion;
1447-
assert(insertion.second &&
1448-
"A differential struct element mapping already exists!");
1449-
}
1431+
auto res = differentialTupleElements.insert({origBB, values});
1432+
(void)res;
1433+
assert(res.second && "A pullback struct element already exists!");
14501434
}
14511435

1452-
SILValue
1453-
JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB,
1454-
VarDecl *field) {
1455-
assert(differentialInfo.getLinearMapStruct(origBB) ==
1456-
cast<StructDecl>(field->getDeclContext()));
1457-
assert(differentialStructElements.count(field) &&
1458-
"Differential struct element for this field does not exist!");
1459-
return differentialStructElements.lookup(field);
1436+
/// Returns the differential tuple element value corresponding to the given
1437+
/// original block and apply inst.
1438+
SILValue JVPCloner::Implementation::getDifferentialTupleElement(ApplyInst *ai) {
1439+
unsigned idx = differentialInfo.lookUpLinearMapIndex(ai);
1440+
assert((idx > 0 || (idx == 0 && ai->getParentBlock()->isEntry())) &&
1441+
"impossible linear map index");
1442+
auto values = differentialTupleElements.lookup(ai->getParentBlock());
1443+
assert(idx < values.size() &&
1444+
"differential tuple element for this apply does not exist!");
1445+
return values[idx];
14601446
}
14611447

14621448
//--------------------------------------------------------------------------//
@@ -1481,9 +1467,9 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
14811467
createEntryArguments(&differential);
14821468
auto *lastArg = diffBB->getArguments().back();
14831469
#ifndef NDEBUG
1484-
auto diffStructLoweredType = remapSILTypeInDifferential(
1485-
differentialInfo.getLinearMapStructLoweredType(&origBB));
1486-
assert(lastArg->getType() == diffStructLoweredType);
1470+
auto diffTupleLoweredType = remapSILTypeInDifferential(
1471+
differentialInfo.getLinearMapTupleLoweredType(&origBB));
1472+
assert(lastArg->getType() == diffTupleLoweredType);
14871473
#endif
14881474
differentialStructArguments[&origBB] = lastArg;
14891475
}
@@ -1671,10 +1657,9 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
16711657
// Accept a differential struct in the differential parameter list. This is
16721658
// the returned differential's closure context.
16731659
auto *origEntry = original->getEntryBlock();
1674-
auto *dfStruct = linearMapInfo->getLinearMapStruct(origEntry);
1675-
auto dfStructType =
1676-
dfStruct->getDeclaredInterfaceType()->getReducedType(witnessCanGenSig);
1677-
dfParams.push_back({dfStructType, ParameterConvention::Direct_Owned});
1660+
auto dfTupleType =
1661+
linearMapInfo->getLinearMapTupleLoweredType(origEntry).getASTType();
1662+
dfParams.push_back({dfTupleType, ParameterConvention::Direct_Owned});
16781663

16791664
Mangle::DifferentiationMangler mangler;
16801665
auto diffName = mangler.mangleLinearMap(

0 commit comments

Comments
 (0)