@@ -181,33 +181,53 @@ enum class ImplFunctionRepresentation {
181
181
Closure
182
182
};
183
183
184
+ enum class ImplFunctionDifferentiabilityKind {
185
+ NonDifferentiable,
186
+ Normal,
187
+ Linear
188
+ };
189
+
184
190
class ImplFunctionTypeFlags {
185
191
unsigned Rep : 3 ;
186
192
unsigned Pseudogeneric : 1 ;
187
193
unsigned Escaping : 1 ;
194
+ unsigned DifferentiabilityKind : 2 ;
188
195
189
196
public:
190
- ImplFunctionTypeFlags () : Rep(0 ), Pseudogeneric(0 ), Escaping(0 ) {}
197
+ ImplFunctionTypeFlags ()
198
+ : Rep(0 ), Pseudogeneric(0 ), Escaping(0 ), DifferentiabilityKind(0 ) {}
191
199
192
- ImplFunctionTypeFlags (ImplFunctionRepresentation rep,
193
- bool pseudogeneric, bool noescape)
194
- : Rep(unsigned (rep)), Pseudogeneric(pseudogeneric), Escaping(noescape) {}
200
+ ImplFunctionTypeFlags (ImplFunctionRepresentation rep, bool pseudogeneric,
201
+ bool noescape,
202
+ ImplFunctionDifferentiabilityKind diffKind)
203
+ : Rep(unsigned (rep)), Pseudogeneric(pseudogeneric), Escaping(noescape),
204
+ DifferentiabilityKind (unsigned (diffKind)) {}
195
205
196
206
ImplFunctionTypeFlags
197
207
withRepresentation (ImplFunctionRepresentation rep) const {
198
- return ImplFunctionTypeFlags (rep, Pseudogeneric, Escaping);
208
+ return ImplFunctionTypeFlags (
209
+ rep, Pseudogeneric, Escaping,
210
+ ImplFunctionDifferentiabilityKind (DifferentiabilityKind));
199
211
}
200
212
201
213
ImplFunctionTypeFlags
202
214
withEscaping () const {
203
- return ImplFunctionTypeFlags (ImplFunctionRepresentation (Rep),
204
- Pseudogeneric, true );
215
+ return ImplFunctionTypeFlags (
216
+ ImplFunctionRepresentation (Rep), Pseudogeneric, true ,
217
+ ImplFunctionDifferentiabilityKind (DifferentiabilityKind));
205
218
}
206
219
207
220
ImplFunctionTypeFlags
208
221
withPseudogeneric () const {
209
- return ImplFunctionTypeFlags (ImplFunctionRepresentation (Rep),
210
- true , Escaping);
222
+ return ImplFunctionTypeFlags (
223
+ ImplFunctionRepresentation (Rep), true , Escaping,
224
+ ImplFunctionDifferentiabilityKind (DifferentiabilityKind));
225
+ }
226
+
227
+ ImplFunctionTypeFlags
228
+ withDifferentiabilityKind (ImplFunctionDifferentiabilityKind diffKind) const {
229
+ return ImplFunctionTypeFlags (ImplFunctionRepresentation (Rep), Pseudogeneric,
230
+ Escaping, diffKind);
211
231
}
212
232
213
233
ImplFunctionRepresentation getRepresentation () const {
@@ -217,6 +237,10 @@ class ImplFunctionTypeFlags {
217
237
bool isEscaping () const { return Escaping; }
218
238
219
239
bool isPseudogeneric () const { return Pseudogeneric; }
240
+
241
+ ImplFunctionDifferentiabilityKind getDifferentiabilityKind () const {
242
+ return ImplFunctionDifferentiabilityKind (DifferentiabilityKind);
243
+ }
220
244
};
221
245
222
246
#if SWIFT_OBJC_INTEROP
@@ -582,6 +606,14 @@ class TypeDecoder {
582
606
flags =
583
607
flags.withRepresentation (ImplFunctionRepresentation::Block);
584
608
}
609
+ } else if (child->getKind () == NodeKind::ImplDifferentiable) {
610
+ flags = flags.withDifferentiabilityKind (
611
+ ImplFunctionDifferentiabilityKind::Normal);
612
+ } else if (child->getKind () == NodeKind::ImplLinear) {
613
+ flags = flags.withDifferentiabilityKind (
614
+ ImplFunctionDifferentiabilityKind::Linear);
615
+ } else if (child->getKind () == NodeKind::ImplEscaping) {
616
+ flags = flags.withEscaping ();
585
617
} else if (child->getKind () == NodeKind::ImplEscaping) {
586
618
flags = flags.withEscaping ();
587
619
} else if (child->getKind () == NodeKind::ImplParameter) {
0 commit comments