@@ -90,7 +90,7 @@ class PullbackCloner::Implementation final
90
90
// / elements destructured from the linear map basic block argument. In the
91
91
// / beginning of each pullback basic block, the block's pullback struct is
92
92
// / destructured into individual elements stored here.
93
- llvm::DenseMap<SILBasicBlock*, SILInstructionResultArray > pullbackTupleElements;
93
+ llvm::DenseMap<SILBasicBlock*, SmallVector<SILValue, 4 > > pullbackTupleElements;
94
94
95
95
// / Mapping from original basic blocks and successor basic blocks to
96
96
// / corresponding pullback trampoline basic blocks. Trampoline basic blocks
@@ -163,8 +163,19 @@ class PullbackCloner::Implementation final
163
163
auto *pbTupleTyple = getPullbackInfo ().getLinearMapTupleType (origBB);
164
164
assert (pbTupleTyple->getNumElements () == values.size () &&
165
165
" The number of pullback tuple fields must equal the number of "
166
- " pullback struct element values" );
167
- auto res = pullbackTupleElements.insert ({origBB, values});
166
+ " pullback tuple element values" );
167
+ auto res = pullbackTupleElements.insert ({origBB, { values.begin (), values.end () }});
168
+ (void )res;
169
+ assert (res.second && " A pullback tuple element already exists!" );
170
+ }
171
+
172
+ void initializePullbackTupleElements (SILBasicBlock *origBB,
173
+ const llvm::ArrayRef<SILArgument *> &values) {
174
+ auto *pbTupleTyple = getPullbackInfo ().getLinearMapTupleType (origBB);
175
+ assert (pbTupleTyple->getNumElements () == values.size () &&
176
+ " The number of pullback tuple fields must equal the number of "
177
+ " pullback tuple element values" );
178
+ auto res = pullbackTupleElements.insert ({origBB, { values.begin (), values.end () }});
168
179
(void )res;
169
180
assert (res.second && " A pullback struct element already exists!" );
170
181
}
@@ -1917,34 +1928,36 @@ bool PullbackCloner::Implementation::run() {
1917
1928
builder.setInsertionPoint (pullbackBB);
1918
1929
// Obtain the context object, if any, and the top-level subcontext, i.e.
1919
1930
// the main pullback struct.
1920
- SILValue mainPullbackTuple;
1921
1931
if (getPullbackInfo ().hasLoops ()) {
1922
1932
// The last argument is the context object (`Builtin.NativeObject`).
1923
1933
contextValue = pullbackBB->getArguments ().back ();
1924
1934
assert (contextValue->getType () ==
1925
1935
SILType::getNativeObjectType (getASTContext ()));
1926
- // Load the pullback struct .
1936
+ // Load the pullback context .
1927
1937
auto subcontextAddr = emitProjectTopLevelSubcontext (
1928
1938
builder, pbLoc, contextValue, pbTupleLoweredType);
1929
- mainPullbackTuple = builder.createLoad (
1939
+ SILValue mainPullbackTuple = builder.createLoad (
1930
1940
pbLoc, subcontextAddr,
1931
1941
pbTupleLoweredType.isTrivial (getPullback ()) ?
1932
1942
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take);
1943
+ auto *dsi = builder.createDestructureTuple (pbLoc, mainPullbackTuple);
1944
+ initializePullbackTupleElements (origBB, dsi->getAllResults ());
1933
1945
} else {
1934
1946
// Obtain and destructure pullback struct elements.
1935
- mainPullbackTuple = pullbackBB->getArguments ().back ();
1936
- assert (mainPullbackTuple->getType () == pbTupleLoweredType);
1947
+ unsigned numVals = pbTupleLoweredType.getAs <TupleType>()->getNumElements ();
1948
+ initializePullbackTupleElements (origBB,
1949
+ pullbackBB->getArguments ().take_back (numVals));
1937
1950
}
1938
1951
1939
- auto *dsi = builder.createDestructureTuple (pbLoc, mainPullbackTuple);
1940
- initializePullbackTupleElements (origBB, dsi->getResults ());
1941
1952
continue ;
1942
1953
}
1954
+
1943
1955
// Get all active values in the original block.
1944
1956
// If the original block has no active values, continue.
1945
1957
auto &bbActiveValues = activeValues[origBB];
1946
1958
if (bbActiveValues.empty ())
1947
1959
continue ;
1960
+
1948
1961
// Otherwise, if the original block has active values:
1949
1962
// - For each active buffer in the original block, allocate a new local
1950
1963
// buffer in the pullback entry. (All adjoint buffers are allocated in
@@ -2008,11 +2021,17 @@ bool PullbackCloner::Implementation::run() {
2008
2021
}
2009
2022
2010
2023
auto *pullbackEntry = pullback.getEntryBlock ();
2024
+ auto pbTupleLoweredType =
2025
+ remapType (getPullbackInfo ().getLinearMapTupleLoweredType (originalExitBlock));
2026
+ unsigned numVals = (getPullbackInfo ().hasLoops () ?
2027
+ 1 : pbTupleLoweredType.getAs <TupleType>()->getNumElements ());
2028
+ (void )numVals;
2029
+
2011
2030
// The pullback function has type:
2012
- // `(seed0, seed1, ..., exit_pb_struct |context_obj) -> (d_arg0, ..., d_argn)`.
2031
+ // `(seed0, seed1, ..., (exit_pb_tuple_el0, ..., ) |context_obj) -> (d_arg0, ..., d_argn)`.
2013
2032
auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults ();
2014
- assert (getConfig ().resultIndices ->getNumIndices () == pbParamArgs.size () - 1 &&
2015
- pbParamArgs.size () >= 2 );
2033
+ assert (getConfig ().resultIndices ->getNumIndices () == pbParamArgs.size () - numVals &&
2034
+ pbParamArgs.size () >= 1 );
2016
2035
// Assign adjoints for original result.
2017
2036
builder.setCurrentDebugScope (
2018
2037
remapScope (originalExitBlock->getTerminator ()->getDebugScope ()));
@@ -2637,9 +2656,9 @@ bool PullbackCloner::Implementation::runForSemanticMemberGetter() {
2637
2656
2638
2657
// Get getter argument and result values.
2639
2658
// Getter type: $(Self) -> Result
2640
- // Pullback type: $(Result', PB_Struct|Context ) -> Self'
2659
+ // Pullback type: $(Result') -> Self'
2641
2660
assert (original.getLoweredFunctionType ()->getNumParameters () == 1 );
2642
- assert (pullback.getLoweredFunctionType ()->getNumParameters () == 2 );
2661
+ assert (pullback.getLoweredFunctionType ()->getNumParameters () == 1 );
2643
2662
assert (pullback.getLoweredFunctionType ()->getNumResults () == 1 );
2644
2663
SILValue origSelf = original.getArgumentsWithoutIndirectResults ().front ();
2645
2664
@@ -2752,10 +2771,10 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
2752
2771
2753
2772
// Get setter argument values.
2754
2773
// Setter type: $(inout Self, Argument) -> ()
2755
- // Pullback type (wrt self): $(inout Self', PB_Struct ) -> ()
2756
- // Pullback type (wrt both): $(inout Self', PB_Struct ) -> Argument'
2774
+ // Pullback type (wrt self): $(inout Self') -> ()
2775
+ // Pullback type (wrt both): $(inout Self') -> Argument'
2757
2776
assert (original.getLoweredFunctionType ()->getNumParameters () == 2 );
2758
- assert (pullback.getLoweredFunctionType ()->getNumParameters () == 2 );
2777
+ assert (pullback.getLoweredFunctionType ()->getNumParameters () == 1 );
2759
2778
assert (pullback.getLoweredFunctionType ()->getNumResults () == 0 ||
2760
2779
pullback.getLoweredFunctionType ()->getNumResults () == 1 );
2761
2780
0 commit comments