Skip to content

Commit ae5ebba

Browse files
committed
[RequirementMachine] Add same-length requirement inference for pack
expansion types.
1 parent 6fb7bcc commit ae5ebba

File tree

6 files changed

+66
-3
lines changed

6 files changed

+66
-3
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,6 +2545,10 @@ ERROR(requires_not_suitable_archetype,none,
25452545
"generic parameter or associated type",
25462546
(Type))
25472547

2548+
ERROR(invalid_shape_requirement,none,
2549+
"invalid same-shape requirement between %0 and %1",
2550+
(Type, Type))
2551+
25482552
ERROR(requires_generic_params_made_equal,none,
25492553
"same-type requirement makes generic parameters %0 and %1 equivalent",
25502554
(Type, Type))

lib/AST/RequirementMachine/ConcreteContraction.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,8 @@ bool ConcreteContraction::performConcreteContraction(
554554
auto kind = req.req.getKind();
555555
switch (kind) {
556556
case RequirementKind::SameCount:
557-
llvm_unreachable("Same-count requirement not supported here");
557+
assert(req.req.getSecondType()->isTypeParameter());
558+
continue;
558559

559560
case RequirementKind::SameType: {
560561
auto constraintType = req.req.getSecondType();

lib/AST/RequirementMachine/Diagnostics.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ bool swift::rewriting::diagnoseRequirementErrors(
105105
break;
106106
}
107107

108+
case RequirementError::Kind::InvalidShapeRequirement: {
109+
if (error.requirement.hasError())
110+
break;
111+
112+
auto lhs = error.requirement.getFirstType();
113+
auto rhs = error.requirement.getSecondType();
114+
115+
// FIXME: Add tailored messages for specific issues.
116+
ctx.Diags.diagnose(loc, diag::invalid_shape_requirement,
117+
lhs, rhs);
118+
diagnosedError = true;
119+
break;
120+
}
121+
108122
case RequirementError::Kind::ConflictingRequirement: {
109123
auto requirement = error.requirement;
110124
auto conflict = error.conflictingRequirement;

lib/AST/RequirementMachine/Diagnostics.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ struct RequirementError {
3434
/// A type requirement on a trivially invalid subject type,
3535
/// e.g. Bool: Collection.
3636
InvalidRequirementSubject,
37+
/// An invalid shape requirement, e.g. length(T...) == length(Int)
38+
InvalidShapeRequirement,
3739
/// A pair of conflicting requirements, T == Int, T == String
3840
ConflictingRequirement,
3941
/// A recursive requirement, e.g. T == G<T.A>.
@@ -73,6 +75,11 @@ struct RequirementError {
7375
return {Kind::InvalidRequirementSubject, req, loc};
7476
}
7577

78+
static RequirementError forInvalidShapeRequirement(Requirement req,
79+
SourceLoc loc) {
80+
return {Kind::InvalidShapeRequirement, req, loc};
81+
}
82+
7683
static RequirementError forConflictingRequirement(Requirement req,
7784
SourceLoc loc) {
7885
return {Kind::ConflictingRequirement, req, loc};

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,23 @@ static void desugarConformanceRequirement(Type subjectType, Type constraintType,
350350
}
351351
}
352352

353+
/// Desugar same-shape requirements by equating the shapes of the
354+
/// root pack types, and diagnose shape requirements on non-pack
355+
/// types.
356+
static void desugarSameShapeRequirement(Type lhs, Type rhs, SourceLoc loc,
357+
SmallVectorImpl<Requirement> &result,
358+
SmallVectorImpl<RequirementError> &errors) {
359+
// For now, only allow shape requirements directly between pack types.
360+
if (!lhs->isTypeSequenceParameter() || !rhs->isTypeSequenceParameter()) {
361+
errors.push_back(RequirementError::forInvalidShapeRequirement(
362+
{RequirementKind::SameCount, lhs, rhs}, loc));
363+
}
364+
365+
result.emplace_back(RequirementKind::SameCount,
366+
lhs->getRootGenericParam(),
367+
rhs->getRootGenericParam());
368+
}
369+
353370
/// Convert a requirement where the subject type might not be a type parameter,
354371
/// or the constraint type in the conformance requirement might be a protocol
355372
/// composition, into zero or more "proper" requirements which can then be
@@ -362,7 +379,9 @@ swift::rewriting::desugarRequirement(Requirement req, SourceLoc loc,
362379

363380
switch (req.getKind()) {
364381
case RequirementKind::SameCount:
365-
llvm_unreachable("Same-count requirement not supported here");
382+
desugarSameShapeRequirement(firstType, req.getSecondType(),
383+
loc, result, errors);
384+
break;
366385

367386
case RequirementKind::Conformance:
368387
desugarConformanceRequirement(firstType, req.getSecondType(),
@@ -474,6 +493,23 @@ struct InferRequirementsWalker : public TypeWalker {
474493
return Action::Continue;
475494
}
476495

496+
// Infer same-length requirements between pack references that
497+
// are expanded in parallel.
498+
if (auto packExpansion = ty->getAs<PackExpansionType>()) {
499+
// Get all pack parameters referenced from the pattern.
500+
SmallVector<Type, 2> packReferences;
501+
packExpansion->getPatternType()->getTypeSequenceParameters(packReferences);
502+
503+
if (packReferences.size() > 1) {
504+
auto first = packReferences.begin();
505+
auto second = first + 1;
506+
while (second != packReferences.end()) {
507+
Requirement req(RequirementKind::SameCount, *first++, *second++);
508+
desugarRequirement(req, SourceLoc(), reqs, errors);
509+
}
510+
}
511+
}
512+
477513
// Infer requirements from `@differentiable` function types.
478514
// For all non-`@noDerivative` parameter and result types:
479515
// - `@differentiable`, `@differentiable(_forward)`, or

lib/AST/RequirementMachine/RuleBuilder.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ void RuleBuilder::addRequirement(const Requirement &req,
299299

300300
switch (req.getKind()) {
301301
case RequirementKind::SameCount:
302-
llvm_unreachable("Same-count requirement not supported here");
302+
// TODO
303+
return;
303304

304305
case RequirementKind::Conformance: {
305306
// A conformance requirement T : P becomes a rewrite rule

0 commit comments

Comments
 (0)