Skip to content

Commit 77d4a20

Browse files
committed
RequirementMachine: Implement requirement inference
1 parent 291ddd7 commit 77d4a20

File tree

1 file changed

+117
-1
lines changed

1 file changed

+117
-1
lines changed

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,115 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType,
210210
result.push_back({req, loc, /*wasInferred=*/false});
211211
}
212212

213+
namespace {
214+
215+
/// AST walker that infers requirements from type representations.
216+
struct InferRequirementsWalker : public TypeWalker {
217+
ModuleDecl *module;
218+
SmallVector<Requirement, 2> reqs;
219+
220+
explicit InferRequirementsWalker(ModuleDecl *module) : module(module) {}
221+
222+
Action walkToTypePre(Type ty) override {
223+
// Unbound generic types are the result of recovered-but-invalid code, and
224+
// don't have enough info to do any useful substitutions.
225+
if (ty->is<UnboundGenericType>())
226+
return Action::Stop;
227+
228+
return Action::Continue;
229+
}
230+
231+
Action walkToTypePost(Type ty) override {
232+
// Infer from generic typealiases.
233+
if (auto typeAlias = dyn_cast<TypeAliasType>(ty.getPointer())) {
234+
auto decl = typeAlias->getDecl();
235+
auto subMap = typeAlias->getSubstitutionMap();
236+
for (const auto &rawReq : decl->getGenericSignature().getRequirements()) {
237+
if (auto req = rawReq.subst(subMap))
238+
desugarRequirement(*req, reqs);
239+
}
240+
241+
return Action::Continue;
242+
}
243+
244+
// Infer requirements from `@differentiable` function types.
245+
// For all non-`@noDerivative` parameter and result types:
246+
// - `@differentiable`, `@differentiable(_forward)`, or
247+
// `@differentiable(reverse)`: add `T: Differentiable` requirement.
248+
// - `@differentiable(_linear)`: add
249+
// `T: Differentiable`, `T == T.TangentVector` requirements.
250+
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
251+
auto &ctx = module->getASTContext();
252+
auto *differentiableProtocol =
253+
ctx.getProtocol(KnownProtocolKind::Differentiable);
254+
if (differentiableProtocol && fnTy->isDifferentiable()) {
255+
auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
256+
Requirement req(RequirementKind::Conformance, type,
257+
protocol->getDeclaredInterfaceType());
258+
desugarRequirement(req, reqs);
259+
};
260+
auto addSameTypeConstraint = [&](Type firstType,
261+
AssociatedTypeDecl *assocType) {
262+
auto *protocol = assocType->getProtocol();
263+
auto *module = protocol->getParentModule();
264+
auto conf = module->lookupConformance(firstType, protocol);
265+
auto secondType = conf.getAssociatedType(
266+
firstType, assocType->getDeclaredInterfaceType());
267+
Requirement req(RequirementKind::SameType, firstType, secondType);
268+
desugarRequirement(req, reqs);
269+
};
270+
auto *tangentVectorAssocType =
271+
differentiableProtocol->getAssociatedType(ctx.Id_TangentVector);
272+
auto addRequirements = [&](Type type, bool isLinear) {
273+
addConformanceConstraint(type, differentiableProtocol);
274+
if (isLinear)
275+
addSameTypeConstraint(type, tangentVectorAssocType);
276+
};
277+
auto constrainParametersAndResult = [&](bool isLinear) {
278+
for (auto &param : fnTy->getParams())
279+
if (!param.isNoDerivative())
280+
addRequirements(param.getPlainType(), isLinear);
281+
addRequirements(fnTy->getResult(), isLinear);
282+
};
283+
// Add requirements.
284+
constrainParametersAndResult(fnTy->getDifferentiabilityKind() ==
285+
DifferentiabilityKind::Linear);
286+
}
287+
}
288+
289+
if (!ty->isSpecialized())
290+
return Action::Continue;
291+
292+
// Infer from generic nominal types.
293+
auto decl = ty->getAnyNominal();
294+
if (!decl) return Action::Continue;
295+
296+
// FIXME: The GSB and the request evaluator both detect a cycle here if we
297+
// force a recursive generic signature. We should look into moving cycle
298+
// detection into the generic signature request(s) - see rdar://55263708
299+
if (!decl->hasComputedGenericSignature())
300+
return Action::Continue;
301+
302+
auto genericSig = decl->getGenericSignature();
303+
if (!genericSig)
304+
return Action::Continue;
305+
306+
/// Retrieve the substitution.
307+
auto subMap = ty->getContextSubstitutionMap(module, decl);
308+
309+
// Handle the requirements.
310+
// FIXME: Inaccurate TypeReprs.
311+
for (const auto &rawReq : genericSig.getRequirements()) {
312+
if (auto req = rawReq.subst(subMap))
313+
desugarRequirement(*req, reqs);
314+
}
315+
316+
return Action::Continue;
317+
}
318+
};
319+
320+
}
321+
213322
/// Infer requirements from applications of BoundGenericTypes to type
214323
/// parameters. For example, given a function declaration
215324
///
@@ -220,7 +329,14 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType,
220329
void swift::rewriting::inferRequirements(
221330
Type type, SourceLoc loc, ModuleDecl *module,
222331
SmallVectorImpl<StructuralRequirement> &result) {
223-
// FIXME: Implement
332+
if (!type)
333+
return;
334+
335+
InferRequirementsWalker walker(module);
336+
type.walk(walker);
337+
338+
for (const auto &req : walker.reqs)
339+
result.push_back({req, loc, /*wasInferred=*/true});
224340
}
225341

226342
/// Desugar a requirement and perform requirement inference if requested

0 commit comments

Comments
 (0)