Skip to content

Commit bbe305c

Browse files
committed
[ConstraintSystem] Add same-shape constraint
The constraint takes two pack types and makes sure that their reduced shapes are equal. This helps with diagnostics because constraint has access to the original pack expansion pattern types.
1 parent cbfec20 commit bbe305c

File tree

8 files changed

+100
-17
lines changed

8 files changed

+100
-17
lines changed

include/swift/Sema/Constraint.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ enum class ConstraintKind : char {
233233
/// an overload. The second type is a PackType containing the explicit
234234
/// generic arguments.
235235
ExplicitGenericArguments,
236+
/// Both (first and second) pack types should have the same reduced shape.
237+
SameShape,
236238
};
237239

238240
/// Classification of the different kinds of constraints.
@@ -701,6 +703,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
701703
case ConstraintKind::DefaultClosureType:
702704
case ConstraintKind::UnresolvedMemberChainBase:
703705
case ConstraintKind::PackElementOf:
706+
case ConstraintKind::SameShape:
704707
return ConstraintClassification::Relational;
705708

706709
case ConstraintKind::ValueMember:

include/swift/Sema/ConstraintSystem.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4871,6 +4871,12 @@ class ConstraintSystem {
48714871
Type type1, Type type2, TypeMatchOptions flags,
48724872
ConstraintLocatorBuilder locator);
48734873

4874+
/// Simplify a same-shape constraint by comparing the reduced shape of the
4875+
/// left hand side to the right hand side.
4876+
SolutionKind simplifySameShapeConstraint(Type type1, Type type2,
4877+
TypeMatchOptions flags,
4878+
ConstraintLocatorBuilder locator);
4879+
48744880
public: // FIXME: Public for use by static functions.
48754881
/// Simplify a conversion constraint with a fix applied to it.
48764882
SolutionKind simplifyFixConstraint(ConstraintFix *fix, Type type1, Type type2,

lib/Sema/CSBindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,7 @@ void PotentialBindings::infer(Constraint *constraint) {
14801480
case ConstraintKind::ShapeOf:
14811481
case ConstraintKind::ExplicitGenericArguments:
14821482
case ConstraintKind::PackElementOf:
1483+
case ConstraintKind::SameShape:
14831484
// Constraints from which we can't do anything.
14841485
break;
14851486

lib/Sema/CSSimplify.cpp

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,6 +2276,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
22762276
case ConstraintKind::PackElementOf:
22772277
case ConstraintKind::ShapeOf:
22782278
case ConstraintKind::ExplicitGenericArguments:
2279+
case ConstraintKind::SameShape:
22792280
llvm_unreachable("Bad constraint kind in matchTupleTypes()");
22802281
}
22812282

@@ -2489,13 +2490,9 @@ ConstraintSystem::matchPackExpansionTypes(PackExpansionType *expansion1,
24892490
ConstraintKind kind, TypeMatchOptions flags,
24902491
ConstraintLocatorBuilder locator) {
24912492
// The count types of two pack expansion types must have the same shape.
2492-
auto *shapeLoc = getConstraintLocator(
2493-
locator.withPathElement(ConstraintLocator::PackShape));
2494-
auto *shapeTypeVar = createTypeVariable(shapeLoc, TVO_CanBindToPack);
2495-
addConstraint(ConstraintKind::ShapeOf,
2496-
expansion1->getCountType(), shapeTypeVar, shapeLoc);
2497-
addConstraint(ConstraintKind::ShapeOf,
2498-
expansion2->getCountType(), shapeTypeVar, shapeLoc);
2493+
addConstraint(ConstraintKind::SameShape, expansion1->getCountType(),
2494+
expansion2->getCountType(),
2495+
locator.withPathElement(ConstraintLocator::PackShape));
24992496

25002497
auto pattern1 = expansion1->getPatternType();
25012498
auto pattern2 = expansion2->getPatternType();
@@ -2655,6 +2652,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1,
26552652
case ConstraintKind::PackElementOf:
26562653
case ConstraintKind::ShapeOf:
26572654
case ConstraintKind::ExplicitGenericArguments:
2655+
case ConstraintKind::SameShape:
26582656
return true;
26592657
}
26602658

@@ -3162,6 +3160,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
31623160
case ConstraintKind::PackElementOf:
31633161
case ConstraintKind::ShapeOf:
31643162
case ConstraintKind::ExplicitGenericArguments:
3163+
case ConstraintKind::SameShape:
31653164
llvm_unreachable("Not a relational constraint");
31663165
}
31673166

@@ -6815,6 +6814,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
68156814
case ConstraintKind::PackElementOf:
68166815
case ConstraintKind::ShapeOf:
68176816
case ConstraintKind::ExplicitGenericArguments:
6817+
case ConstraintKind::SameShape:
68186818
llvm_unreachable("Not a relational constraint");
68196819
}
68206820
}
@@ -13190,6 +13190,18 @@ ConstraintSystem::simplifyDynamicCallableApplicableFnConstraint(
1319013190
return SolutionKind::Solved;
1319113191
}
1319213192

13193+
static bool hasUnresolvedPackVars(Type type) {
13194+
// We can't compute a reduced shape if the input type still
13195+
// contains type variables that might bind to pack archetypes
13196+
// or pack expansions.
13197+
SmallPtrSet<TypeVariableType *, 2> typeVars;
13198+
type->getTypeVariables(typeVars);
13199+
return llvm::any_of(typeVars, [](const TypeVariableType *typeVar) {
13200+
return typeVar->getImpl().canBindToPack() ||
13201+
typeVar->getImpl().isPackExpansion();
13202+
});
13203+
}
13204+
1319313205
ConstraintSystem::SolutionKind ConstraintSystem::simplifyShapeOfConstraint(
1319413206
Type type1, Type type2, TypeMatchOptions flags,
1319513207
ConstraintLocatorBuilder locator) {
@@ -13231,6 +13243,51 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyShapeOfConstraint(
1323113243
return SolutionKind::Solved;
1323213244
}
1323313245

13246+
ConstraintSystem::SolutionKind ConstraintSystem::simplifySameShapeConstraint(
13247+
Type type1, Type type2, TypeMatchOptions flags,
13248+
ConstraintLocatorBuilder locator) {
13249+
type1 = simplifyType(type1);
13250+
type2 = simplifyType(type2);
13251+
13252+
auto formUnsolved = [&]() {
13253+
// If we're supposed to generate constraints, do so.
13254+
if (flags.contains(TMF_GenerateConstraints)) {
13255+
auto *sameShape =
13256+
Constraint::create(*this, ConstraintKind::SameShape, type1, type2,
13257+
getConstraintLocator(locator));
13258+
13259+
addUnsolvedConstraint(sameShape);
13260+
return SolutionKind::Solved;
13261+
}
13262+
13263+
return SolutionKind::Unsolved;
13264+
};
13265+
13266+
if (hasUnresolvedPackVars(type1) || hasUnresolvedPackVars(type2))
13267+
return formUnsolved();
13268+
13269+
auto shape1 = type1->getReducedShape();
13270+
auto shape2 = type2->getReducedShape();
13271+
13272+
if (shape1->isEqual(shape2))
13273+
return SolutionKind::Solved;
13274+
13275+
if (shouldAttemptFixes()) {
13276+
if (type1->hasPlaceholder() || type2->hasPlaceholder())
13277+
return SolutionKind::Solved;
13278+
13279+
unsigned impact = 1;
13280+
if (locator.endsWith<LocatorPathElt::AnyRequirement>())
13281+
impact = assessRequirementFailureImpact(*this, shape1, locator);
13282+
13283+
auto *fix = SkipSameShapeRequirement::create(*this, type1, type2,
13284+
getConstraintLocator(locator));
13285+
return recordFix(fix, impact) ? SolutionKind::Error : SolutionKind::Solved;
13286+
}
13287+
13288+
return SolutionKind::Error;
13289+
}
13290+
1323413291
ConstraintSystem::SolutionKind
1323513292
ConstraintSystem::simplifyExplicitGenericArgumentsConstraint(
1323613293
Type type1, Type type2, TypeMatchOptions flags,
@@ -14718,6 +14775,9 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
1471814775
case ConstraintKind::ShapeOf:
1471914776
return simplifyShapeOfConstraint(first, second, subflags, locator);
1472014777

14778+
case ConstraintKind::SameShape:
14779+
return simplifySameShapeConstraint(first, second, subflags, locator);
14780+
1472114781
case ConstraintKind::ExplicitGenericArguments:
1472214782
return simplifyExplicitGenericArgumentsConstraint(
1472314783
first, second, subflags, locator);
@@ -14889,13 +14949,7 @@ void ConstraintSystem::addConstraint(Requirement req,
1488914949
auto type1 = req.getFirstType();
1489014950
auto type2 = req.getSecondType();
1489114951

14892-
auto *shapeLoc = getConstraintLocator(
14893-
locator.withPathElement(ConstraintLocator::PackShape));
14894-
auto typeVar = createTypeVariable(shapeLoc,
14895-
TVO_CanBindToPack);
14896-
14897-
addConstraint(ConstraintKind::ShapeOf, type1, typeVar, locator);
14898-
addConstraint(ConstraintKind::ShapeOf, type2, typeVar, locator);
14952+
addConstraint(ConstraintKind::SameShape, type1, type2, locator);
1489914953
return;
1490014954
}
1490114955

@@ -15319,6 +15373,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
1531915373
constraint.getFirstType(), constraint.getSecondType(), /*flags*/ None,
1532015374
constraint.getLocator());
1532115375

15376+
case ConstraintKind::SameShape:
15377+
return simplifySameShapeConstraint(constraint.getFirstType(),
15378+
constraint.getSecondType(),
15379+
/*flags*/ None, constraint.getLocator());
15380+
1532215381
case ConstraintKind::ExplicitGenericArguments:
1532315382
return simplifyExplicitGenericArgumentsConstraint(
1532415383
constraint.getFirstType(), constraint.getSecondType(),

lib/Sema/Constraint.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
8282
case ConstraintKind::PackElementOf:
8383
case ConstraintKind::ShapeOf:
8484
case ConstraintKind::ExplicitGenericArguments:
85+
case ConstraintKind::SameShape:
8586
assert(!First.isNull());
8687
assert(!Second.isNull());
8788
break;
@@ -171,6 +172,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third,
171172
case ConstraintKind::PackElementOf:
172173
case ConstraintKind::ShapeOf:
173174
case ConstraintKind::ExplicitGenericArguments:
175+
case ConstraintKind::SameShape:
174176
llvm_unreachable("Wrong constructor");
175177

176178
case ConstraintKind::KeyPath:
@@ -319,6 +321,7 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const {
319321
case ConstraintKind::PackElementOf:
320322
case ConstraintKind::ShapeOf:
321323
case ConstraintKind::ExplicitGenericArguments:
324+
case ConstraintKind::SameShape:
322325
return create(cs, getKind(), getFirstType(), getSecondType(), getLocator());
323326

324327
case ConstraintKind::ApplicableFunction:
@@ -568,6 +571,10 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm,
568571
Out << " shape of ";
569572
break;
570573

574+
case ConstraintKind::SameShape:
575+
Out << " same-shape ";
576+
break;
577+
571578
case ConstraintKind::ExplicitGenericArguments:
572579
Out << " explicit generic argument binding ";
573580
break;
@@ -740,6 +747,7 @@ gatherReferencedTypeVars(Constraint *constraint,
740747
case ConstraintKind::PackElementOf:
741748
case ConstraintKind::ShapeOf:
742749
case ConstraintKind::ExplicitGenericArguments:
750+
case ConstraintKind::SameShape:
743751
constraint->getFirstType()->getTypeVariables(typeVars);
744752
constraint->getSecondType()->getTypeVariables(typeVars);
745753
break;

test/Constraints/pack-expansion-expressions.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ func tupleExpansion<each T, each U>(
133133
_ = zip(repeat each tuple1.element, with: repeat each tuple1.element)
134134

135135
_ = zip(repeat each tuple1.element, with: repeat each tuple2.element)
136-
// expected-error@-1 {{global function 'zip(_:with:)' requires the type packs 'each U' and 'each T' have the same shape}}
136+
// expected-error@-1 {{global function 'zip(_:with:)' requires the type packs 'each T' and 'each U' have the same shape}}
137137
}
138138

139139
protocol Generatable {

test/Constraints/variadic_generic_constraints.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,5 @@ func goodCallToZip<each T, each U>(t: repeat each T, u: repeat each U) where (re
7373

7474
func badCallToZip<each T, each U>(t: repeat each T, u: repeat each U) {
7575
_ = zip(t: repeat each t, u: repeat each u)
76-
// expected-error@-1 {{global function 'zip(t:u:)' requires the type packs 'each U' and 'each T' have the same shape}}
76+
// expected-error@-1 {{global function 'zip(t:u:)' requires the type packs 'each T' and 'each U' have the same shape}}
7777
}

test/Constraints/variadic_generic_functions.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,15 @@ func call() {
3232

3333
let x: String = multipleParameters(xs: "", ys: "")
3434
let (one, two) = multipleParameters(xs: "", 5.0, ys: "", 5.0)
35-
multipleParameters(xs: "", 5.0, ys: 5.0, "") // expected-error {{type of expression is ambiguous without more context}}
35+
multipleParameters(xs: "", 5.0, ys: 5.0, "") // expected-error {{conflicting arguments to generic parameter 'each T' ('Pack{Double, String}' vs. 'Pack{String, String}' vs. 'Pack{String, Double}' vs. 'Pack{Double, Double}')}}
3636

3737
func multipleSequences<each T, each U>(xs: repeat each T, ys: repeat each U) -> (repeat each T) {
38+
return (repeat each ys)
39+
// expected-error@-1 {{pack expansion requires that 'each U' and 'each T' have the same shape}}
40+
// expected-error@-2 {{cannot convert return expression of type '(repeat each U)' to return type '(repeat each T)'}}
41+
}
42+
43+
func multipleSequencesWithSameShape<each T, each U>(xs: repeat each T, ys: repeat each U) -> (repeat each T) where (repeat (each T, each U)): Any {
3844
return (repeat each ys)
3945
// expected-error@-1 {{cannot convert return expression of type '(repeat each U)' to return type '(repeat each T)'}}
4046
}

0 commit comments

Comments
 (0)