@@ -173,80 +173,59 @@ enum class AutoDiffGeneratedDeclarationKind : uint8_t {
173173 BranchingTraceEnum
174174};
175175
176- // / SIL-level automatic differentiation indices. Consists of:
177- // / - The differentiability parameter indices.
178- // / - The differentiability result indices.
179- // TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
180- // `AutoDiffConfig` additionally stores a derivative generic signature.
181- struct SILAutoDiffIndices {
182- // / The indices of independent parameters to differentiate with respect to.
183- IndexSubset *parameters;
184- // / The indices of dependent results to differentiate from.
185- IndexSubset *results;
186-
187- /* implicit*/ SILAutoDiffIndices(IndexSubset *parameters, IndexSubset *results)
188- : parameters(parameters), results(results) {
189- assert (parameters && " Parameter indices must be non-null" );
190- assert (results && " Result indices must be non-null" );
191- }
192-
193- bool operator ==(const SILAutoDiffIndices &other) const ;
194-
195- bool operator !=(const SILAutoDiffIndices &other) const {
196- return !(*this == other);
197- };
176+ // / Identifies an autodiff derivative function configuration:
177+ // / - Parameter indices.
178+ // / - Result indices.
179+ // / - Derivative generic signature (optional).
180+ struct AutoDiffConfig {
181+ IndexSubset *parameterIndices;
182+ IndexSubset *resultIndices;
183+ GenericSignature derivativeGenericSignature;
184+
185+ /* implicit*/ AutoDiffConfig(
186+ IndexSubset *parameterIndices, IndexSubset *resultIndices,
187+ GenericSignature derivativeGenericSignature = GenericSignature())
188+ : parameterIndices(parameterIndices), resultIndices(resultIndices),
189+ derivativeGenericSignature (derivativeGenericSignature) {}
198190
199191 // / Returns true if `parameterIndex` is a differentiability parameter index.
200192 bool isWrtParameter (unsigned parameterIndex) const {
201- return parameterIndex < parameters ->getCapacity () &&
202- parameters ->contains (parameterIndex);
193+ return parameterIndex < parameterIndices ->getCapacity () &&
194+ parameterIndices ->contains (parameterIndex);
203195 }
204196
205- void print (llvm::raw_ostream &s = llvm::outs()) const ;
206- SWIFT_DEBUG_DUMP;
197+ // / Returns true if `resultIndex` is a differentiability result index.
198+ bool isWrtResult (unsigned resultIndex) const {
199+ return resultIndex < resultIndices->getCapacity () &&
200+ resultIndices->contains (resultIndex);
201+ }
207202
203+ AutoDiffConfig withGenericSignature (GenericSignature signature) const {
204+ return AutoDiffConfig (parameterIndices, resultIndices, signature);
205+ }
206+
207+ // TODO(SR-13506): Use principled mangling for AD-generated symbols.
208208 std::string mangle () const {
209209 std::string result = " src_" ;
210210 interleave (
211- results ->getIndices (),
211+ resultIndices ->getIndices (),
212212 [&](unsigned idx) { result += llvm::utostr (idx); },
213213 [&] { result += ' _' ; });
214214 result += " _wrt_" ;
215215 llvm::interleave (
216- parameters ->getIndices (),
216+ parameterIndices ->getIndices (),
217217 [&](unsigned idx) { result += llvm::utostr (idx); },
218218 [&] { result += ' _' ; });
219219 return result;
220220 }
221- };
222-
223- // / Identifies an autodiff derivative function configuration:
224- // / - Parameter indices.
225- // / - Result indices.
226- // / - Derivative generic signature (optional).
227- struct AutoDiffConfig {
228- IndexSubset *parameterIndices;
229- IndexSubset *resultIndices;
230- GenericSignature derivativeGenericSignature;
231-
232- /* implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
233- IndexSubset *resultIndices,
234- GenericSignature derivativeGenericSignature)
235- : parameterIndices(parameterIndices), resultIndices(resultIndices),
236- derivativeGenericSignature (derivativeGenericSignature) {}
237-
238- // / Returns the `SILAutoDiffIndices` corresponding to this config's indices.
239- // TODO(TF-913): This is a temporary shim for incremental removal of
240- // `SILAutoDiffIndices`. Eventually remove this.
241- SILAutoDiffIndices getSILAutoDiffIndices () const ;
242221
243222 void print (llvm::raw_ostream &s = llvm::outs()) const ;
244223 SWIFT_DEBUG_DUMP;
245224};
246225
247226inline llvm::raw_ostream &operator <<(llvm::raw_ostream &s,
248- const SILAutoDiffIndices &indices ) {
249- indices .print (s);
227+ const AutoDiffConfig &config ) {
228+ config .print (s);
250229 return s;
251230}
252231
0 commit comments