@@ -97,7 +97,7 @@ class JVPCloner::Implementation final
97
97
// / elements destructured from the linear map basic block argument. In the
98
98
// / beginning of each differential basic block, the block's differential
99
99
// / struct is destructured into the individual elements stored here.
100
- llvm::DenseMap<VarDecl *, SILValue> differentialStructElements ;
100
+ llvm::DenseMap<SILBasicBlock *, SILInstructionResultArray> differentialTupleElements ;
101
101
102
102
// / An auxiliary differential local allocation builder.
103
103
TangentBuilder diffLocalAllocBuilder;
@@ -119,23 +119,17 @@ class JVPCloner::Implementation final
119
119
TangentBuilder &getDifferentialBuilder () { return differentialBuilder; }
120
120
SILFunction &getDifferential () { return differentialBuilder.getFunction (); }
121
121
SILArgument *getDifferentialStructArgument (SILBasicBlock *origBB) {
122
- #ifndef NDEBUG
123
- auto *diffStruct = differentialStructArguments[origBB]
124
- ->getType ()
125
- .getStructOrBoundGenericStruct ();
126
- assert (diffStruct == differentialInfo.getLinearMapStruct (origBB));
127
- #endif
128
122
return differentialStructArguments[origBB];
129
123
}
130
124
131
125
// --------------------------------------------------------------------------//
132
- // Differential struct mapping
126
+ // Differential tuple mapping
133
127
// --------------------------------------------------------------------------//
134
128
135
- void initializeDifferentialStructElements (SILBasicBlock *origBB,
136
- SILInstructionResultArray values);
129
+ void initializeDifferentialTupleElements (SILBasicBlock *origBB,
130
+ SILInstructionResultArray values);
137
131
138
- SILValue getDifferentialStructElement (SILBasicBlock *origBB, VarDecl *field );
132
+ SILValue getDifferentialTupleElement (ApplyInst *ai );
139
133
140
134
// --------------------------------------------------------------------------//
141
135
// General utilities
@@ -158,22 +152,21 @@ class JVPCloner::Implementation final
158
152
159
153
// / Build a differential struct value for the original block corresponding to
160
154
// / the given terminator.
161
- StructInst *buildDifferentialValueStructValue (TermInst *termInst) {
155
+ TupleInst *buildDifferentialValueStructValue (TermInst *termInst) {
162
156
assert (termInst->getFunction () == original);
163
157
auto loc = termInst->getFunction ()->getLocation ();
164
158
auto *origBB = termInst->getParent ();
165
159
auto *jvpBB = BBMap[origBB];
166
160
assert (jvpBB && " Basic block mapping should exist" );
167
- auto *diffStruct = differentialInfo.getLinearMapStruct (origBB);
168
- assert (diffStruct && " The differential struct should have been declared" );
169
- auto structLoweredTy = getNominalDeclLoweredType (diffStruct);
161
+ auto tupleLoweredTy =
162
+ remapType (differentialInfo.getLinearMapTupleLoweredType (origBB));
170
163
auto bbDifferentialValues = differentialValues[origBB];
171
164
if (!origBB->isEntry ()) {
172
165
auto *enumArg = jvpBB->getArguments ().back ();
173
166
bbDifferentialValues.insert (bbDifferentialValues.begin (), enumArg);
174
167
}
175
- return getBuilder ().createStruct (loc, structLoweredTy ,
176
- bbDifferentialValues);
168
+ return getBuilder ().createTuple (loc, tupleLoweredTy ,
169
+ bbDifferentialValues);
177
170
}
178
171
179
172
// --------------------------------------------------------------------------//
@@ -438,8 +431,8 @@ class JVPCloner::Implementation final
438
431
auto *mainDifferentialStruct = diffBB->getArguments ().back ();
439
432
diffBuilder.setInsertionPoint (diffBB);
440
433
auto *dsi =
441
- diffBuilder.createDestructureStruct (diffLoc, mainDifferentialStruct);
442
- initializeDifferentialStructElements (bb, dsi->getResults ());
434
+ diffBuilder.createDestructureTuple (diffLoc, mainDifferentialStruct);
435
+ initializeDifferentialTupleElements (bb, dsi->getResults ());
443
436
TypeSubstCloner::visitInstructionsInBlock (bb);
444
437
}
445
438
@@ -667,12 +660,11 @@ class JVPCloner::Implementation final
667
660
// Add the differential function for when we create the struct we partially
668
661
// apply to the differential we are generating.
669
662
auto differential = jvpDirectResults.back ();
670
- auto *differentialDecl = differentialInfo.lookUpLinearMapDecl (ai);
663
+ auto differentialType = differentialInfo.lookUpLinearMapType (ai);
671
664
auto originalDifferentialType =
672
665
getOpType (differential->getType ()).getAs <SILFunctionType>();
673
666
auto loweredDifferentialType =
674
- getOpType (getLoweredType (differentialDecl->getInterfaceType ()))
675
- .castTo <SILFunctionType>();
667
+ getOpType (getLoweredType (differentialType)).castTo <SILFunctionType>();
676
668
// If actual differential type does not match lowered differential type,
677
669
// reabstract the differential using a thunk.
678
670
if (!loweredDifferentialType->isEqual (originalDifferentialType)) {
@@ -1218,9 +1210,7 @@ class JVPCloner::Implementation final
1218
1210
auto &diffBuilder = getDifferentialBuilder ();
1219
1211
1220
1212
// Get the differential value.
1221
- auto *field = differentialInfo.lookUpLinearMapDecl (ai);
1222
- assert (field);
1223
- SILValue differential = getDifferentialStructElement (bb, field);
1213
+ SILValue differential = getDifferentialTupleElement (ai);
1224
1214
auto differentialType = remapSILTypeInDifferential (differential->getType ())
1225
1215
.castTo <SILFunctionType>();
1226
1216
@@ -1432,31 +1422,27 @@ JVPCloner::~JVPCloner() { delete &impl; }
1432
1422
// Differential struct mapping
1433
1423
// --------------------------------------------------------------------------//
1434
1424
1435
- void JVPCloner::Implementation::initializeDifferentialStructElements (
1436
- SILBasicBlock *origBB, SILInstructionResultArray values) {
1437
- auto *diffStructDecl = differentialInfo.getLinearMapStruct (origBB);
1438
- assert (diffStructDecl-> getStoredProperties (). size () == values.size () &&
1439
- " The number of differential struct fields must equal the number of "
1425
+ void JVPCloner::Implementation::initializeDifferentialTupleElements (
1426
+ SILBasicBlock *origBB, SILInstructionResultArray values) {
1427
+ auto *diffTupleTyple = differentialInfo.getLinearMapTupleType (origBB);
1428
+ assert (diffTupleTyple-> getNumElements () == values.size () &&
1429
+ " The number of differential tuple fields must equal the number of "
1440
1430
" differential struct element values" );
1441
- for (auto pair : llvm::zip (diffStructDecl->getStoredProperties (), values)) {
1442
- assert (std::get<1 >(pair)->getOwnershipKind () != OwnershipKind::Guaranteed &&
1443
- " Differential struct elements must be @owned" );
1444
- auto insertion = differentialStructElements.insert (
1445
- {std::get<0 >(pair), std::get<1 >(pair)});
1446
- (void )insertion;
1447
- assert (insertion.second &&
1448
- " A differential struct element mapping already exists!" );
1449
- }
1431
+ auto res = differentialTupleElements.insert ({origBB, values});
1432
+ (void )res;
1433
+ assert (res.second && " A pullback struct element already exists!" );
1450
1434
}
1451
1435
1452
- SILValue
1453
- JVPCloner::Implementation::getDifferentialStructElement (SILBasicBlock *origBB,
1454
- VarDecl *field) {
1455
- assert (differentialInfo.getLinearMapStruct (origBB) ==
1456
- cast<StructDecl>(field->getDeclContext ()));
1457
- assert (differentialStructElements.count (field) &&
1458
- " Differential struct element for this field does not exist!" );
1459
- return differentialStructElements.lookup (field);
1436
+ // / Returns the differential tuple element value corresponding to the given
1437
+ // / original block and apply inst.
1438
+ SILValue JVPCloner::Implementation::getDifferentialTupleElement (ApplyInst *ai) {
1439
+ unsigned idx = differentialInfo.lookUpLinearMapIndex (ai);
1440
+ assert ((idx > 0 || (idx == 0 && ai->getParentBlock ()->isEntry ())) &&
1441
+ " impossible linear map index" );
1442
+ auto values = differentialTupleElements.lookup (ai->getParentBlock ());
1443
+ assert (idx < values.size () &&
1444
+ " differential tuple element for this apply does not exist!" );
1445
+ return values[idx];
1460
1446
}
1461
1447
1462
1448
// --------------------------------------------------------------------------//
@@ -1481,9 +1467,9 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1481
1467
createEntryArguments (&differential);
1482
1468
auto *lastArg = diffBB->getArguments ().back ();
1483
1469
#ifndef NDEBUG
1484
- auto diffStructLoweredType = remapSILTypeInDifferential (
1485
- differentialInfo.getLinearMapStructLoweredType (&origBB));
1486
- assert (lastArg->getType () == diffStructLoweredType );
1470
+ auto diffTupleLoweredType = remapSILTypeInDifferential (
1471
+ differentialInfo.getLinearMapTupleLoweredType (&origBB));
1472
+ assert (lastArg->getType () == diffTupleLoweredType );
1487
1473
#endif
1488
1474
differentialStructArguments[&origBB] = lastArg;
1489
1475
}
@@ -1671,10 +1657,9 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1671
1657
// Accept a differential struct in the differential parameter list. This is
1672
1658
// the returned differential's closure context.
1673
1659
auto *origEntry = original->getEntryBlock ();
1674
- auto *dfStruct = linearMapInfo->getLinearMapStruct (origEntry);
1675
- auto dfStructType =
1676
- dfStruct->getDeclaredInterfaceType ()->getReducedType (witnessCanGenSig);
1677
- dfParams.push_back ({dfStructType, ParameterConvention::Direct_Owned});
1660
+ auto dfTupleType =
1661
+ linearMapInfo->getLinearMapTupleLoweredType (origEntry).getASTType ();
1662
+ dfParams.push_back ({dfTupleType, ParameterConvention::Direct_Owned});
1678
1663
1679
1664
Mangle::DifferentiationMangler mangler;
1680
1665
auto diffName = mangler.mangleLinearMap (
0 commit comments