@@ -173,80 +173,59 @@ enum class AutoDiffGeneratedDeclarationKind : uint8_t {
173
173
BranchingTraceEnum
174
174
};
175
175
176
- // / SIL-level automatic differentiation indices. Consists of:
177
- // / - The differentiability parameter indices.
178
- // / - The differentiability result indices.
179
- // TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
180
- // `AutoDiffConfig` additionally stores a derivative generic signature.
181
- struct SILAutoDiffIndices {
182
- // / The indices of independent parameters to differentiate with respect to.
183
- IndexSubset *parameters;
184
- // / The indices of dependent results to differentiate from.
185
- IndexSubset *results;
186
-
187
- /* implicit*/ SILAutoDiffIndices(IndexSubset *parameters, IndexSubset *results)
188
- : parameters(parameters), results(results) {
189
- assert (parameters && " Parameter indices must be non-null" );
190
- assert (results && " Result indices must be non-null" );
191
- }
192
-
193
- bool operator ==(const SILAutoDiffIndices &other) const ;
194
-
195
- bool operator !=(const SILAutoDiffIndices &other) const {
196
- return !(*this == other);
197
- };
176
+ // / Identifies an autodiff derivative function configuration:
177
+ // / - Parameter indices.
178
+ // / - Result indices.
179
+ // / - Derivative generic signature (optional).
180
+ struct AutoDiffConfig {
181
+ IndexSubset *parameterIndices;
182
+ IndexSubset *resultIndices;
183
+ GenericSignature derivativeGenericSignature;
184
+
185
+ /* implicit*/ AutoDiffConfig(
186
+ IndexSubset *parameterIndices, IndexSubset *resultIndices,
187
+ GenericSignature derivativeGenericSignature = GenericSignature())
188
+ : parameterIndices(parameterIndices), resultIndices(resultIndices),
189
+ derivativeGenericSignature (derivativeGenericSignature) {}
198
190
199
191
// / Returns true if `parameterIndex` is a differentiability parameter index.
200
192
bool isWrtParameter (unsigned parameterIndex) const {
201
- return parameterIndex < parameters ->getCapacity () &&
202
- parameters ->contains (parameterIndex);
193
+ return parameterIndex < parameterIndices ->getCapacity () &&
194
+ parameterIndices ->contains (parameterIndex);
203
195
}
204
196
205
- void print (llvm::raw_ostream &s = llvm::outs()) const ;
206
- SWIFT_DEBUG_DUMP;
197
+ // / Returns true if `resultIndex` is a differentiability result index.
198
+ bool isWrtResult (unsigned resultIndex) const {
199
+ return resultIndex < resultIndices->getCapacity () &&
200
+ resultIndices->contains (resultIndex);
201
+ }
207
202
203
+ AutoDiffConfig withGenericSignature (GenericSignature signature) const {
204
+ return AutoDiffConfig (parameterIndices, resultIndices, signature);
205
+ }
206
+
207
+ // TODO(SR-13506): Use principled mangling for AD-generated symbols.
208
208
std::string mangle () const {
209
209
std::string result = " src_" ;
210
210
interleave (
211
- results ->getIndices (),
211
+ resultIndices ->getIndices (),
212
212
[&](unsigned idx) { result += llvm::utostr (idx); },
213
213
[&] { result += ' _' ; });
214
214
result += " _wrt_" ;
215
215
llvm::interleave (
216
- parameters ->getIndices (),
216
+ parameterIndices ->getIndices (),
217
217
[&](unsigned idx) { result += llvm::utostr (idx); },
218
218
[&] { result += ' _' ; });
219
219
return result;
220
220
}
221
- };
222
-
223
- // / Identifies an autodiff derivative function configuration:
224
- // / - Parameter indices.
225
- // / - Result indices.
226
- // / - Derivative generic signature (optional).
227
- struct AutoDiffConfig {
228
- IndexSubset *parameterIndices;
229
- IndexSubset *resultIndices;
230
- GenericSignature derivativeGenericSignature;
231
-
232
- /* implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
233
- IndexSubset *resultIndices,
234
- GenericSignature derivativeGenericSignature)
235
- : parameterIndices(parameterIndices), resultIndices(resultIndices),
236
- derivativeGenericSignature (derivativeGenericSignature) {}
237
-
238
- // / Returns the `SILAutoDiffIndices` corresponding to this config's indices.
239
- // TODO(TF-913): This is a temporary shim for incremental removal of
240
- // `SILAutoDiffIndices`. Eventually remove this.
241
- SILAutoDiffIndices getSILAutoDiffIndices () const ;
242
221
243
222
void print (llvm::raw_ostream &s = llvm::outs()) const ;
244
223
SWIFT_DEBUG_DUMP;
245
224
};
246
225
247
226
inline llvm::raw_ostream &operator <<(llvm::raw_ostream &s,
248
- const SILAutoDiffIndices &indices ) {
249
- indices .print (s);
227
+ const AutoDiffConfig &config ) {
228
+ config .print (s);
250
229
return s;
251
230
}
252
231
0 commit comments