Skip to content

Commit 39ec990

Browse files
committed
RequirementMachine: Untangle requirement desugaring from requirement inference
Refactor the code to match what's written up in generics.tex. It's easier to understand what's going on if requirement inference first introduces a bunch of requirements that might be trivial, and then all user-written and inferred requirements are desugared at the end in a separate pass.
1 parent 75c1651 commit 39ec990

File tree

3 files changed

+70
-71
lines changed

3 files changed

+70
-71
lines changed

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,25 @@ swift::rewriting::desugarRequirement(Requirement req, SourceLoc loc,
458458
}
459459
}
460460

461+
void swift::rewriting::desugarRequirements(SmallVector<StructuralRequirement, 2> &reqs,
462+
SmallVectorImpl<RequirementError> &errors) {
463+
SmallVector<StructuralRequirement, 2> result;
464+
for (auto req : reqs) {
465+
SmallVector<Requirement, 2> desugaredReqs;
466+
SmallVector<RequirementError, 2> ignoredErrors;
467+
468+
if (req.inferred)
469+
desugarRequirement(req.req, SourceLoc(), desugaredReqs, ignoredErrors);
470+
else
471+
desugarRequirement(req.req, req.loc, desugaredReqs, errors);
472+
473+
for (auto desugaredReq : desugaredReqs)
474+
result.push_back({desugaredReq, req.loc, req.inferred});
475+
}
476+
477+
std::swap(reqs, result);
478+
}
479+
461480
//
462481
// Requirement realization and inference.
463482
//
@@ -467,8 +486,6 @@ static void realizeTypeRequirement(DeclContext *dc,
467486
SourceLoc loc,
468487
SmallVectorImpl<StructuralRequirement> &result,
469488
SmallVectorImpl<RequirementError> &errors) {
470-
SmallVector<Requirement, 2> reqs;
471-
472489
// The GenericSignatureBuilder allowed the right hand side of a
473490
// conformance or superclass requirement to reference a protocol
474491
// typealias whose underlying type was a protocol or class.
@@ -497,22 +514,19 @@ static void realizeTypeRequirement(DeclContext *dc,
497514
}
498515

499516
if (constraintType->isConstraintType()) {
500-
Requirement req(RequirementKind::Conformance, subjectType, constraintType);
501-
desugarRequirement(req, loc, reqs, errors);
517+
result.push_back({Requirement(RequirementKind::Conformance,
518+
subjectType, constraintType),
519+
loc, /*wasInferred=*/false});
502520
} else if (constraintType->getClassOrBoundGenericClass()) {
503-
Requirement req(RequirementKind::Superclass, subjectType, constraintType);
504-
desugarRequirement(req, loc, reqs, errors);
521+
result.push_back({Requirement(RequirementKind::Superclass,
522+
subjectType, constraintType),
523+
loc, /*wasInferred=*/false});
505524
} else {
506525
errors.push_back(
507526
RequirementError::forInvalidTypeRequirement(subjectType,
508527
constraintType,
509528
loc));
510-
return;
511529
}
512-
513-
// Add source location information.
514-
for (auto req : reqs)
515-
result.push_back({req, loc, /*wasInferred=*/false});
516530
}
517531

518532
namespace {
@@ -521,11 +535,11 @@ namespace {
521535
struct InferRequirementsWalker : public TypeWalker {
522536
ModuleDecl *module;
523537
DeclContext *dc;
524-
SmallVector<Requirement, 2> reqs;
525-
SmallVector<RequirementError, 2> errors;
538+
SmallVectorImpl<StructuralRequirement> &reqs;
526539

527-
explicit InferRequirementsWalker(ModuleDecl *module, DeclContext *dc)
528-
: module(module), dc(dc) {}
540+
explicit InferRequirementsWalker(ModuleDecl *module, DeclContext *dc,
541+
SmallVectorImpl<StructuralRequirement> &reqs)
542+
: module(module), dc(dc), reqs(reqs) {}
529543

530544
Action walkToTypePre(Type ty) override {
531545
// Unbound generic types are the result of recovered-but-invalid code, and
@@ -555,8 +569,7 @@ struct InferRequirementsWalker : public TypeWalker {
555569
return false;
556570

557571
return (req.getKind() == RequirementKind::Conformance &&
558-
req.getSecondType()->castTo<ProtocolType>()->getDecl()
559-
->isSpecificProtocol(KnownProtocolKind::Sendable));
572+
req.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Sendable));
560573
};
561574

562575
// Infer from generic typealiases.
@@ -567,7 +580,7 @@ struct InferRequirementsWalker : public TypeWalker {
567580
if (skipRequirement(rawReq, decl))
568581
continue;
569582

570-
desugarRequirement(rawReq.subst(subMap), SourceLoc(), reqs, errors);
583+
reqs.push_back({rawReq.subst(subMap), SourceLoc(), /*inferred=*/true});
571584
}
572585

573586
return Action::Continue;
@@ -581,10 +594,9 @@ struct InferRequirementsWalker : public TypeWalker {
581594
packExpansion->getPatternType()->getTypeParameterPacks(packReferences);
582595

583596
auto countType = packExpansion->getCountType();
584-
for (auto pack : packReferences) {
585-
Requirement req(RequirementKind::SameShape, countType, pack);
586-
desugarRequirement(req, SourceLoc(), reqs, errors);
587-
}
597+
for (auto pack : packReferences)
598+
reqs.push_back({Requirement(RequirementKind::SameShape, countType, pack),
599+
SourceLoc(), /*inferred=*/true});
588600
}
589601

590602
// Infer requirements from `@differentiable` function types.
@@ -596,9 +608,9 @@ struct InferRequirementsWalker : public TypeWalker {
596608
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
597609
// Add a new conformance constraint for a fixed protocol.
598610
auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
599-
Requirement req(RequirementKind::Conformance, type,
600-
protocol->getDeclaredInterfaceType());
601-
desugarRequirement(req, SourceLoc(), reqs, errors);
611+
reqs.push_back({Requirement(RequirementKind::Conformance, type,
612+
protocol->getDeclaredInterfaceType()),
613+
SourceLoc(), /*inferred=*/true});
602614
};
603615

604616
auto &ctx = module->getASTContext();
@@ -610,8 +622,9 @@ struct InferRequirementsWalker : public TypeWalker {
610622
auto secondType = assocType->getDeclaredInterfaceType()
611623
->castTo<DependentMemberType>()
612624
->substBaseType(module, firstType);
613-
Requirement req(RequirementKind::SameType, firstType, secondType);
614-
desugarRequirement(req, SourceLoc(), reqs, errors);
625+
reqs.push_back({Requirement(RequirementKind::SameType,
626+
firstType, secondType),
627+
SourceLoc(), /*inferred=*/true});
615628
};
616629
auto *tangentVectorAssocType =
617630
differentiableProtocol->getAssociatedType(ctx.Id_TangentVector);
@@ -659,8 +672,7 @@ struct InferRequirementsWalker : public TypeWalker {
659672
if (skipRequirement(rawReq, decl))
660673
continue;
661674

662-
auto req = rawReq.subst(subMap);
663-
desugarRequirement(req, SourceLoc(), reqs, errors);
675+
reqs.push_back({rawReq.subst(subMap), SourceLoc(), /*inferred=*/true});
664676
}
665677

666678
return Action::Continue;
@@ -683,15 +695,12 @@ void swift::rewriting::inferRequirements(
683695
if (!type)
684696
return;
685697

686-
InferRequirementsWalker walker(module, dc);
698+
InferRequirementsWalker walker(module, dc, result);
687699
type.walk(walker);
688-
689-
for (const auto &req : walker.reqs)
690-
result.push_back({req, loc, /*wasInferred=*/true});
691700
}
692701

693-
/// Desugar a requirement and perform requirement inference if requested
694-
/// to obtain zero or more structural requirements.
702+
/// Perform requirement inference from the type representations in the
703+
/// requirement itself (eg, `T == Set<U>` infers `U: Hashable`).
695704
void swift::rewriting::realizeRequirement(
696705
DeclContext *dc,
697706
Requirement req, RequirementRepr *reqRepr,
@@ -732,12 +741,7 @@ void swift::rewriting::realizeRequirement(
732741
inferRequirements(firstType, firstLoc, moduleForInference, dc, result);
733742
}
734743

735-
SmallVector<Requirement, 2> reqs;
736-
desugarRequirement(req, loc, reqs, errors);
737-
738-
for (auto req : reqs)
739-
result.push_back({req, loc, /*wasInferred=*/false});
740-
744+
result.push_back({req, loc, /*wasInferred=*/false});
741745
break;
742746
}
743747

@@ -754,11 +758,7 @@ void swift::rewriting::realizeRequirement(
754758
inferRequirements(secondType, secondLoc, moduleForInference, dc, result);
755759
}
756760

757-
SmallVector<Requirement, 2> reqs;
758-
desugarRequirement(req, loc, reqs, errors);
759-
760-
for (auto req : reqs)
761-
result.push_back({req, loc, /*wasInferred=*/false});
761+
result.push_back({req, loc, /*wasInferred=*/false});
762762
break;
763763
}
764764
}
@@ -903,13 +903,13 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
903903
ProtocolDecl *proto) const {
904904
assert(!proto->hasLazyRequirementSignature());
905905

906-
SmallVector<StructuralRequirement, 4> result;
907-
SmallVector<RequirementError, 4> errors;
906+
SmallVector<StructuralRequirement, 2> result;
907+
SmallVector<RequirementError, 2> errors;
908908

909909
auto &ctx = proto->getASTContext();
910910
auto selfTy = proto->getSelfInterfaceType();
911911

912-
SmallVector<Type, 4> needsDefaultReqirements({selfTy});
912+
SmallVector<Type, 4> needsDefaultRequirements({selfTy});
913913

914914
unsigned errorCount = errors.size();
915915
realizeInheritedRequirements(proto, selfTy,
@@ -950,7 +950,8 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
950950
result.push_back({Requirement(RequirementKind::Layout, selfTy, layout),
951951
proto->getLoc(), /*inferred=*/true});
952952

953-
expandDefaultRequirements(ctx, needsDefaultReqirements, result, errors);
953+
desugarRequirements(result, errors);
954+
expandDefaultRequirements(ctx, needsDefaultRequirements, result, errors);
954955
return ctx.AllocateCopy(result);
955956
}
956957

@@ -976,7 +977,7 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
976977
return false;
977978
});
978979

979-
needsDefaultReqirements.push_back(assocType);
980+
needsDefaultRequirements.push_back(assocType);
980981
}
981982

982983
// Add requirements for each typealias.
@@ -1014,7 +1015,8 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
10141015
}
10151016
}
10161017

1017-
expandDefaultRequirements(ctx, needsDefaultReqirements, result, errors);
1018+
desugarRequirements(result, errors);
1019+
expandDefaultRequirements(ctx, needsDefaultRequirements, result, errors);
10181020

10191021
diagnoseRequirementErrors(ctx, errors,
10201022
AllowConcreteTypePolicy::NestedAssocTypes);

lib/AST/RequirementMachine/RequirementLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ namespace rewriting {
3838
// documentation
3939
// comments.
4040

41+
void desugarRequirements(SmallVector<StructuralRequirement, 2> &result,
42+
SmallVectorImpl<RequirementError> &errors);
43+
4144
void desugarRequirement(Requirement req, SourceLoc loc,
4245
SmallVectorImpl<Requirement> &result,
4346
SmallVectorImpl<RequirementError> &errors);

lib/AST/RequirementMachine/RequirementMachineRequests.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -632,14 +632,13 @@ AbstractGenericSignatureRequest::evaluate(
632632

633633
// Convert the input Requirements into StructuralRequirements by adding
634634
// empty source locations.
635-
SmallVector<StructuralRequirement, 4> requirements;
635+
SmallVector<StructuralRequirement, 2> requirements;
636636
for (auto req : baseSignature.getRequirements())
637637
requirements.push_back({req, SourceLoc(), /*wasInferred=*/false});
638638

639-
// We need to create this errors vector to pass to
640-
// desugarRequirement, but this request should never
641-
// diagnose errors.
642-
SmallVector<RequirementError, 4> errors;
639+
// Add the new requirements.
640+
for (auto req : addedRequirements)
641+
requirements.push_back({req, SourceLoc(), /*wasInferred=*/false});
643642

644643
// The requirements passed to this request may have been substituted,
645644
// meaning the subject type might be a concrete type and not a type
@@ -651,12 +650,8 @@ AbstractGenericSignatureRequest::evaluate(
651650
// Desugaring converts these kinds of requirements into "proper"
652651
// requirements where the subject type is always a type parameter,
653652
// which is what the RuleBuilder expects.
654-
for (auto req : addedRequirements) {
655-
SmallVector<Requirement, 2> reqs;
656-
desugarRequirement(req, SourceLoc(), reqs, errors);
657-
for (auto req : reqs)
658-
requirements.push_back({req, SourceLoc(), /*wasInferred=*/false});
659-
}
653+
SmallVector<RequirementError, 2> errors;
654+
desugarRequirements(requirements, errors);
660655

661656
auto &rewriteCtx = ctx.getRewriteContext();
662657

@@ -747,8 +742,8 @@ InferredGenericSignatureRequest::evaluate(
747742
parentSig.getGenericParams().begin(),
748743
parentSig.getGenericParams().end());
749744

750-
SmallVector<StructuralRequirement, 4> requirements;
751-
SmallVector<RequirementError, 4> errors;
745+
SmallVector<StructuralRequirement, 2> requirements;
746+
SmallVector<RequirementError, 2> errors;
752747

753748
SourceLoc loc = [&]() {
754749
if (genericParamList) {
@@ -844,9 +839,6 @@ InferredGenericSignatureRequest::evaluate(
844839
for (auto *gtpd : genericParamList->getParams())
845840
localGPs.push_back(gtpd->getDeclaredInterfaceType());
846841

847-
// Expand defaults and eliminate all inverse-conformance requirements.
848-
expandDefaultRequirements(ctx, localGPs, requirements, errors);
849-
850842
// Perform requirement inference from function parameter and result
851843
// types and such.
852844
for (auto sourcePair : inferenceSources) {
@@ -860,12 +852,14 @@ InferredGenericSignatureRequest::evaluate(
860852
// Finish by adding any remaining requirements. This is used to introduce
861853
// inferred same-type requirements when building the generic signature of
862854
// an extension whose extended type is a generic typealias.
863-
SmallVector<Requirement, 4> rawAddedRequirements;
864855
for (const auto &req : addedRequirements)
865-
desugarRequirement(req, SourceLoc(), rawAddedRequirements, errors);
866-
for (const auto &req : rawAddedRequirements)
867856
requirements.push_back({req, SourceLoc(), /*inferred=*/true});
868857

858+
desugarRequirements(requirements, errors);
859+
860+
// Expand defaults and eliminate all inverse-conformance requirements.
861+
expandDefaultRequirements(ctx, localGPs, requirements, errors);
862+
869863
// Re-order requirements so that inferred requirements appear last. This
870864
// ensures that if an inferred requirement is redundant with some other
871865
// requirement, it is the inferred requirement that becomes redundant,

0 commit comments

Comments
 (0)