Skip to content

Commit 7ce6f37

Browse files
committed
RequirementMachine: Correct concrete type unification with pack expansion on both sides
1 parent f07495d commit 7ce6f37

File tree

4 files changed

+87
-9
lines changed

4 files changed

+87
-9
lines changed

lib/AST/RequirementMachine/InterfaceType.cpp

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ Type PropertyMap::getTypeForTerm(const MutableTerm &term,
403403
/// Concrete type terms are written in terms of generic parameter types that
404404
/// have a depth of 0, and an index into an array of substitution terms.
405405
///
406-
/// See RewriteSystemBuilder::getConcreteSubstitutionSchema().
406+
/// See RewriteSystemBuilder::getSubstitutionSchemaFromType().
407407
unsigned RewriteContext::getGenericParamIndex(Type type) {
408408
auto *paramTy = type->castTo<GenericTypeParamType>();
409409
assert(paramTy->getDepth() == 0);
@@ -429,6 +429,7 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
429429
// Get the substitution S corresponding to τ_0_n.
430430
unsigned index = getGenericParamIndex(typeWitness->getRootGenericParam());
431431
result = MutableTerm(substitutions[index]);
432+
assert(result.back().getKind() != Symbol::Kind::Shape);
432433

433434
// If the substitution is a term consisting of a single protocol symbol
434435
// [P], save P for later.
@@ -471,7 +472,7 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
471472
}
472473

473474
/// Reverses the transformation performed by
474-
/// RewriteSystemBuilder::getConcreteSubstitutionSchema().
475+
/// RewriteSystemBuilder::getSubstitutionSchemaFromType().
475476
Type PropertyMap::getTypeFromSubstitutionSchema(
476477
Type schema, ArrayRef<Term> substitutions,
477478
ArrayRef<GenericTypeParamType *> genericParams,
@@ -481,11 +482,38 @@ Type PropertyMap::getTypeFromSubstitutionSchema(
481482
if (!schema->hasTypeParameter())
482483
return schema;
483484

484-
return schema.transformRec([&](Type t) -> llvm::Optional<Type> {
485+
return schema.transformWithPosition(
486+
TypePosition::Invariant,
487+
[&](Type t, TypePosition pos) -> llvm::Optional<Type> {
485488
if (t->is<GenericTypeParamType>()) {
486489
auto index = RewriteContext::getGenericParamIndex(t);
487490
auto substitution = substitutions[index];
488491

492+
bool isShapePosition = (pos == TypePosition::Shape);
493+
bool isShapeTerm = (substitution.back() == Symbol::forShape(Context));
494+
if (isShapePosition != isShapeTerm) {
495+
llvm::errs() << "Shape vs. type mixup\n\n";
496+
schema.dump(llvm::errs());
497+
llvm::errs() << "Substitutions:\n";
498+
for (auto otherSubst : substitutions) {
499+
llvm::errs() << "- ";
500+
otherSubst.dump(llvm::errs());
501+
llvm::errs() << "\n";
502+
}
503+
llvm::errs() << "\n";
504+
dump(llvm::errs());
505+
506+
abort();
507+
}
508+
509+
// Undo the thing where the count type of a PackExpansionType
510+
// becomes a shape term.
511+
if (isShapeTerm) {
512+
MutableTerm mutTerm(substitution.begin(),
513+
substitution.end() - 1);
514+
substitution = Term::get(mutTerm, Context);
515+
}
516+
489517
// Prepend the prefix of the lookup key to the substitution.
490518
if (prefix.empty()) {
491519
// Skip creation of a new MutableTerm in the case where the
@@ -535,17 +563,31 @@ RewriteContext::getRelativeSubstitutionSchemaFromType(
535563
if (!concreteType->hasTypeParameter())
536564
return concreteType;
537565

538-
return CanType(concreteType.transformRec([&](Type t) -> llvm::Optional<Type> {
566+
return CanType(concreteType.transformWithPosition(
567+
TypePosition::Invariant,
568+
[&](Type t, TypePosition pos) -> llvm::Optional<Type> {
569+
539570
if (!t->isTypeParameter())
540571
return llvm::None;
541572

542573
auto term = getRelativeTermForType(CanType(t), substitutions);
543574

544-
unsigned newIndex = result.size();
575+
// PackExpansionType(pattern=T, count=U) becomes
576+
// PackExpansionType(pattern=τ_0_0, count=τ_0_1) with
577+
//
578+
// τ_0_0 := T
579+
// τ_0_1 := U.[shape]
580+
if (pos == TypePosition::Shape) {
581+
assert(false);
582+
term.add(Symbol::forShape(*this));
583+
}
584+
585+
unsigned index = result.size();
586+
545587
result.push_back(Term::get(term, *this));
546588

547589
return CanGenericTypeParamType::get(/*isParameterPack=*/ false,
548-
/*depth=*/ 0, newIndex,
590+
/*depth=*/ 0, index,
549591
Context);
550592
}));
551593
}
@@ -566,12 +608,26 @@ RewriteContext::getSubstitutionSchemaFromType(CanType concreteType,
566608
if (!concreteType->hasTypeParameter())
567609
return concreteType;
568610

569-
return CanType(concreteType.transformRec([&](Type t) -> llvm::Optional<Type> {
611+
return CanType(concreteType.transformWithPosition(
612+
TypePosition::Invariant,
613+
[&](Type t, TypePosition pos)
614+
-> llvm::Optional<Type> {
615+
570616
if (!t->isTypeParameter())
571617
return llvm::None;
572618

619+
// PackExpansionType(pattern=T, count=U) becomes
620+
// PackExpansionType(pattern=τ_0_0, count=τ_0_1) with
621+
//
622+
// τ_0_0 := T
623+
// τ_0_1 := U.[shape]
624+
MutableTerm term = getMutableTermForType(CanType(t), proto);
625+
if (pos == TypePosition::Shape)
626+
term.add(Symbol::forShape(*this));
627+
573628
unsigned index = result.size();
574-
result.push_back(getTermForType(CanType(t), proto));
629+
630+
result.push_back(Term::get(term, *this));
575631

576632
return CanGenericTypeParamType::get(/*isParameterPack=*/ false,
577633
/*depth=*/0, index,

lib/AST/RequirementMachine/Symbol.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class Term;
5757
/// This transformation allows DependentMemberTypes to be manipulated as
5858
/// terms, with the actual concrete type structure remaining opaque to
5959
/// the requirement machine. This transformation is implemented in
60-
/// RewriteContext::getConcreteSubstitutionSchema().
60+
/// RewriteContext::getSubstitutionSchemaFromType().
6161
///
6262
/// For example, the superclass requirement
6363
/// "T : MyClass<U.X, (Int) -> V.A.B>" is denoted with a symbol

lib/AST/RequirementMachine/TypeDifference.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ namespace {
159159
bool rhsAbstract = rhsType->isTypeParameter();
160160

161161
if (lhsAbstract && rhsAbstract) {
162+
// FIXME: same-element requirements
163+
assert(lhsType->isParameterPack() == rhsType->isParameterPack());
164+
162165
unsigned lhsIndex = RewriteContext::getGenericParamIndex(lhsType);
163166
unsigned rhsIndex = RewriteContext::getGenericParamIndex(rhsType);
164167

test/Generics/pack-shape-requirements.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,22 @@ func sameTypeDesugar1<each T, each U>(t: repeat each T, u: repeat each U)
107107
func sameTypeDesugar2<each T: P, each U: P>(t: repeat each T, u: repeat each U)
108108
where Shape<repeat (each T).A> == Shape<repeat (each U).A> {}
109109

110+
/// More complex example involving concrete type matching in
111+
/// property map construction
112+
113+
protocol PP {
114+
associatedtype A
115+
}
116+
117+
struct G<each T> {}
118+
119+
// CHECK-LABEL: sameTypeMatch1
120+
// CHECK-NEXT: <T, each U, each V where T : PP, repeat each U : PP, repeat each V : PP, T.[PP]A == G<repeat (each U).[PP]A>, repeat (each U).[PP]A == (each V).[PP]A>
121+
func sameTypeMatch1<T: PP, each U: PP, each V: PP>(t: T, u: repeat each U, v: repeat each V)
122+
where T.A == G<repeat (each U).A>, T.A == G<repeat (each V).A>,
123+
(repeat (each U, each V)) : Any {}
124+
125+
// CHECK-LABEL: sameTypeMatch2
126+
// CHECK-NEXT: <T, each U, each V where T : PP, repeat each U : PP, (repeat (each U, each V)) : Any, repeat each V : PP, T.[PP]A == (/* shape: each U */ repeat ())>
127+
func sameTypeMatch2<T: PP, each U: PP, each V: PP>(t: T, u: repeat each U, v: repeat each V)
128+
where T.A == Shape<repeat each U>, T.A == Shape<repeat each V> {}

0 commit comments

Comments
 (0)