@@ -210,6 +210,115 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType,
210
210
result.push_back ({req, loc, /* wasInferred=*/ false });
211
211
}
212
212
213
+ namespace {
214
+
215
+ // / AST walker that infers requirements from type representations.
216
+ struct InferRequirementsWalker : public TypeWalker {
217
+ ModuleDecl *module ;
218
+ SmallVector<Requirement, 2 > reqs;
219
+
220
+ explicit InferRequirementsWalker (ModuleDecl *module ) : module(module ) {}
221
+
222
+ Action walkToTypePre (Type ty) override {
223
+ // Unbound generic types are the result of recovered-but-invalid code, and
224
+ // don't have enough info to do any useful substitutions.
225
+ if (ty->is <UnboundGenericType>())
226
+ return Action::Stop;
227
+
228
+ return Action::Continue;
229
+ }
230
+
231
+ Action walkToTypePost (Type ty) override {
232
+ // Infer from generic typealiases.
233
+ if (auto typeAlias = dyn_cast<TypeAliasType>(ty.getPointer ())) {
234
+ auto decl = typeAlias->getDecl ();
235
+ auto subMap = typeAlias->getSubstitutionMap ();
236
+ for (const auto &rawReq : decl->getGenericSignature ().getRequirements ()) {
237
+ if (auto req = rawReq.subst (subMap))
238
+ desugarRequirement (*req, reqs);
239
+ }
240
+
241
+ return Action::Continue;
242
+ }
243
+
244
+ // Infer requirements from `@differentiable` function types.
245
+ // For all non-`@noDerivative` parameter and result types:
246
+ // - `@differentiable`, `@differentiable(_forward)`, or
247
+ // `@differentiable(reverse)`: add `T: Differentiable` requirement.
248
+ // - `@differentiable(_linear)`: add
249
+ // `T: Differentiable`, `T == T.TangentVector` requirements.
250
+ if (auto *fnTy = ty->getAs <AnyFunctionType>()) {
251
+ auto &ctx = module ->getASTContext ();
252
+ auto *differentiableProtocol =
253
+ ctx.getProtocol (KnownProtocolKind::Differentiable);
254
+ if (differentiableProtocol && fnTy->isDifferentiable ()) {
255
+ auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
256
+ Requirement req (RequirementKind::Conformance, type,
257
+ protocol->getDeclaredInterfaceType ());
258
+ desugarRequirement (req, reqs);
259
+ };
260
+ auto addSameTypeConstraint = [&](Type firstType,
261
+ AssociatedTypeDecl *assocType) {
262
+ auto *protocol = assocType->getProtocol ();
263
+ auto *module = protocol->getParentModule ();
264
+ auto conf = module ->lookupConformance (firstType, protocol);
265
+ auto secondType = conf.getAssociatedType (
266
+ firstType, assocType->getDeclaredInterfaceType ());
267
+ Requirement req (RequirementKind::SameType, firstType, secondType);
268
+ desugarRequirement (req, reqs);
269
+ };
270
+ auto *tangentVectorAssocType =
271
+ differentiableProtocol->getAssociatedType (ctx.Id_TangentVector );
272
+ auto addRequirements = [&](Type type, bool isLinear) {
273
+ addConformanceConstraint (type, differentiableProtocol);
274
+ if (isLinear)
275
+ addSameTypeConstraint (type, tangentVectorAssocType);
276
+ };
277
+ auto constrainParametersAndResult = [&](bool isLinear) {
278
+ for (auto ¶m : fnTy->getParams ())
279
+ if (!param.isNoDerivative ())
280
+ addRequirements (param.getPlainType (), isLinear);
281
+ addRequirements (fnTy->getResult (), isLinear);
282
+ };
283
+ // Add requirements.
284
+ constrainParametersAndResult (fnTy->getDifferentiabilityKind () ==
285
+ DifferentiabilityKind::Linear);
286
+ }
287
+ }
288
+
289
+ if (!ty->isSpecialized ())
290
+ return Action::Continue;
291
+
292
+ // Infer from generic nominal types.
293
+ auto decl = ty->getAnyNominal ();
294
+ if (!decl) return Action::Continue;
295
+
296
+ // FIXME: The GSB and the request evaluator both detect a cycle here if we
297
+ // force a recursive generic signature. We should look into moving cycle
298
+ // detection into the generic signature request(s) - see rdar://55263708
299
+ if (!decl->hasComputedGenericSignature ())
300
+ return Action::Continue;
301
+
302
+ auto genericSig = decl->getGenericSignature ();
303
+ if (!genericSig)
304
+ return Action::Continue;
305
+
306
+ // / Retrieve the substitution.
307
+ auto subMap = ty->getContextSubstitutionMap (module , decl);
308
+
309
+ // Handle the requirements.
310
+ // FIXME: Inaccurate TypeReprs.
311
+ for (const auto &rawReq : genericSig.getRequirements ()) {
312
+ if (auto req = rawReq.subst (subMap))
313
+ desugarRequirement (*req, reqs);
314
+ }
315
+
316
+ return Action::Continue;
317
+ }
318
+ };
319
+
320
+ }
321
+
213
322
// / Infer requirements from applications of BoundGenericTypes to type
214
323
// / parameters. For example, given a function declaration
215
324
// /
@@ -220,7 +329,14 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType,
220
329
void swift::rewriting::inferRequirements (
221
330
Type type, SourceLoc loc, ModuleDecl *module ,
222
331
SmallVectorImpl<StructuralRequirement> &result) {
223
- // FIXME: Implement
332
+ if (!type)
333
+ return ;
334
+
335
+ InferRequirementsWalker walker (module );
336
+ type.walk (walker);
337
+
338
+ for (const auto &req : walker.reqs )
339
+ result.push_back ({req, loc, /* wasInferred=*/ true });
224
340
}
225
341
226
342
// / Desugar a requirement and perform requirement inference if requested
0 commit comments