@@ -24,9 +24,126 @@ namespace autodiff {
24
24
raw_ostream &getADDebugStream () { return llvm::dbgs () << " [AD] " ; }
25
25
26
26
// ===----------------------------------------------------------------------===//
27
- // Code emission utilities
27
+ // Helpers
28
28
// ===----------------------------------------------------------------------===//
29
29
30
+ bool isArrayLiteralIntrinsic (FullApplySite applySite) {
31
+ return doesApplyCalleeHaveSemantics (applySite.getCalleeOrigin (),
32
+ " array.uninitialized_intrinsic" );
33
+ }
34
+
35
+ ApplyInst *getAllocateUninitializedArrayIntrinsic (SILValue v) {
36
+ if (auto *ai = dyn_cast<ApplyInst>(v))
37
+ if (isArrayLiteralIntrinsic (ai))
38
+ return ai;
39
+ return nullptr ;
40
+ }
41
+
42
+ ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress (SILValue v) {
43
+ // Find the `pointer_to_address` result, peering through `index_addr`.
44
+ auto *ptai = dyn_cast<PointerToAddressInst>(v);
45
+ if (auto *iai = dyn_cast<IndexAddrInst>(v))
46
+ ptai = dyn_cast<PointerToAddressInst>(iai->getOperand (0 ));
47
+ if (!ptai)
48
+ return nullptr ;
49
+ // Return the `array.uninitialized_intrinsic` application, if it exists.
50
+ if (auto *dti = dyn_cast<DestructureTupleInst>(
51
+ ptai->getOperand ()->getDefiningInstruction ())) {
52
+ if (auto *ai = getAllocateUninitializedArrayIntrinsic (dti->getOperand ()))
53
+ return ai;
54
+ }
55
+ return nullptr ;
56
+ }
57
+
58
+ DestructureTupleInst *getSingleDestructureTupleUser (SILValue value) {
59
+ bool foundDestructureTupleUser = false ;
60
+ if (!value->getType ().is <TupleType>())
61
+ return nullptr ;
62
+ DestructureTupleInst *result = nullptr ;
63
+ for (auto *use : value->getUses ()) {
64
+ if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser ())) {
65
+ assert (!foundDestructureTupleUser &&
66
+ " There should only be one `destructure_tuple` user of a tuple" );
67
+ foundDestructureTupleUser = true ;
68
+ result = dti;
69
+ }
70
+ }
71
+ return result;
72
+ }
73
+
74
+ void forEachApplyDirectResult (
75
+ FullApplySite applySite,
76
+ llvm::function_ref<void (SILValue)> resultCallback) {
77
+ switch (applySite.getKind ()) {
78
+ case FullApplySiteKind::ApplyInst: {
79
+ auto *ai = cast<ApplyInst>(applySite.getInstruction ());
80
+ if (!ai->getType ().is <TupleType>()) {
81
+ resultCallback (ai);
82
+ return ;
83
+ }
84
+ if (auto *dti = getSingleDestructureTupleUser (ai))
85
+ for (auto directResult : dti->getResults ())
86
+ resultCallback (directResult);
87
+ break ;
88
+ }
89
+ case FullApplySiteKind::BeginApplyInst: {
90
+ auto *bai = cast<BeginApplyInst>(applySite.getInstruction ());
91
+ for (auto directResult : bai->getResults ())
92
+ resultCallback (directResult);
93
+ break ;
94
+ }
95
+ case FullApplySiteKind::TryApplyInst: {
96
+ auto *tai = cast<TryApplyInst>(applySite.getInstruction ());
97
+ for (auto *succBB : tai->getSuccessorBlocks ())
98
+ for (auto *arg : succBB->getArguments ())
99
+ resultCallback (arg);
100
+ break ;
101
+ }
102
+ }
103
+ }
104
+
105
+ void collectAllFormalResultsInTypeOrder (SILFunction &function,
106
+ SmallVectorImpl<SILValue> &results) {
107
+ SILFunctionConventions convs (function.getLoweredFunctionType (),
108
+ function.getModule ());
109
+ auto indResults = function.getIndirectResults ();
110
+ auto *retInst = cast<ReturnInst>(function.findReturnBB ()->getTerminator ());
111
+ auto retVal = retInst->getOperand ();
112
+ SmallVector<SILValue, 8 > dirResults;
113
+ if (auto *tupleInst =
114
+ dyn_cast_or_null<TupleInst>(retVal->getDefiningInstruction ()))
115
+ dirResults.append (tupleInst->getElements ().begin (),
116
+ tupleInst->getElements ().end ());
117
+ else
118
+ dirResults.push_back (retVal);
119
+ unsigned indResIdx = 0 , dirResIdx = 0 ;
120
+ for (auto &resInfo : convs.getResults ())
121
+ results.push_back (resInfo.isFormalDirect () ? dirResults[dirResIdx++]
122
+ : indResults[indResIdx++]);
123
+ // Treat `inout` parameters as semantic results.
124
+ // Append `inout` parameters after formal results.
125
+ for (auto i : range (convs.getNumParameters ())) {
126
+ auto paramInfo = convs.getParameters ()[i];
127
+ if (!paramInfo.isIndirectMutating ())
128
+ continue ;
129
+ auto *argument = function.getArgumentsWithoutIndirectResults ()[i];
130
+ results.push_back (argument);
131
+ }
132
+ }
133
+
134
+ void collectAllDirectResultsInTypeOrder (SILFunction &function,
135
+ SmallVectorImpl<SILValue> &results) {
136
+ SILFunctionConventions convs (function.getLoweredFunctionType (),
137
+ function.getModule ());
138
+ auto *retInst = cast<ReturnInst>(function.findReturnBB ()->getTerminator ());
139
+ auto retVal = retInst->getOperand ();
140
+ if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
141
+ results.append (tupleInst->getElements ().begin (),
142
+ tupleInst->getElements ().end ());
143
+ else
144
+ results.push_back (retVal);
145
+ }
146
+
30
147
void collectAllActualResultsInTypeOrder (
31
148
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
32
149
SmallVectorImpl<SILValue> &results) {
@@ -39,6 +156,73 @@ void collectAllActualResultsInTypeOrder(
39
156
}
40
157
}
41
158
159
+ void collectMinimalIndicesForFunctionCall (
160
+ ApplyInst *ai, SILAutoDiffIndices parentIndices,
161
+ const DifferentiableActivityInfo &activityInfo,
162
+ SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned > ¶mIndices,
163
+ SmallVectorImpl<unsigned > &resultIndices) {
164
+ auto calleeFnTy = ai->getSubstCalleeType ();
165
+ auto calleeConvs = ai->getSubstCalleeConv ();
166
+ // Parameter indices are indices (in the callee type signature) of parameter
167
+ // arguments that are varied or are arguments.
168
+ // Record all parameter indices in type order.
169
+ unsigned currentParamIdx = 0 ;
170
+ for (auto applyArg : ai->getArgumentsWithoutIndirectResults ()) {
171
+ if (activityInfo.isActive (applyArg, parentIndices))
172
+ paramIndices.push_back (currentParamIdx);
173
+ ++currentParamIdx;
174
+ }
175
+ // Result indices are indices (in the callee type signature) of results that
176
+ // are useful.
177
+ SmallVector<SILValue, 8 > directResults;
178
+ forEachApplyDirectResult (ai, [&](SILValue directResult) {
179
+ directResults.push_back (directResult);
180
+ });
181
+ auto indirectResults = ai->getIndirectSILResults ();
182
+ // Record all results and result indices in type order.
183
+ results.reserve (calleeFnTy->getNumResults ());
184
+ unsigned dirResIdx = 0 ;
185
+ unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult ();
186
+ for (auto &resAndIdx : enumerate(calleeConvs.getResults ())) {
187
+ auto &res = resAndIdx.value ();
188
+ unsigned idx = resAndIdx.index ();
189
+ if (res.isFormalDirect ()) {
190
+ results.push_back (directResults[dirResIdx]);
191
+ if (auto dirRes = directResults[dirResIdx])
192
+ if (dirRes && activityInfo.isActive (dirRes, parentIndices))
193
+ resultIndices.push_back (idx);
194
+ ++dirResIdx;
195
+ } else {
196
+ results.push_back (indirectResults[indResIdx]);
197
+ if (activityInfo.isActive (indirectResults[indResIdx], parentIndices))
198
+ resultIndices.push_back (idx);
199
+ ++indResIdx;
200
+ }
201
+ }
202
+ // Record all `inout` parameters as results.
203
+ auto inoutParamResultIndex = calleeFnTy->getNumResults ();
204
+ for (auto ¶mAndIdx : enumerate(calleeConvs.getParameters ())) {
205
+ auto ¶m = paramAndIdx.value ();
206
+ if (!param.isIndirectMutating ())
207
+ continue ;
208
+ unsigned idx = paramAndIdx.index ();
209
+ auto inoutArg = ai->getArgument (idx);
210
+ results.push_back (inoutArg);
211
+ resultIndices.push_back (inoutParamResultIndex++);
212
+ }
213
+ // Make sure the function call has active results.
214
+ auto numResults = calleeFnTy->getNumResults () +
215
+ calleeFnTy->getNumIndirectMutatingParameters ();
216
+ assert (results.size () == numResults);
217
+ assert (llvm::any_of (results, [&](SILValue result) {
218
+ return activityInfo.isActive (result, parentIndices);
219
+ }));
220
+ }
221
+
222
+ // ===----------------------------------------------------------------------===//
223
+ // Code emission utilities
224
+ // ===----------------------------------------------------------------------===//
225
+
42
226
SILValue joinElements (ArrayRef<SILValue> elements, SILBuilder &builder,
43
227
SILLocation loc) {
44
228
if (elements.size () == 1 )
0 commit comments