Skip to content

Commit f4bfce6

Browse files
committed
[RequirementMachine] Map same-shape requriements to rewrite rules.
A same-shape requirement 'length(T...) == length(U...)' becomes a rewrite rule 'T.[shape] => U.[shape]'. Reduced shape rules will drop the [shape] term from each side of the rule, and create a same-shape requirement between the two type parameter packs.
1 parent 0f13eda commit f4bfce6

File tree

5 files changed

+85
-3
lines changed

5 files changed

+85
-3
lines changed

lib/AST/RequirementMachine/RequirementBuilder.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,23 @@ void RequirementBuilder::addRequirementRules(ArrayRef<unsigned> rules) {
250250
llvm_unreachable("Invalid symbol kind");
251251
}
252252

253+
if (rule.getLHS().back().getKind() == Symbol::Kind::Shape) {
254+
assert(rule.getRHS().back().getKind() == Symbol::Kind::Shape);
255+
256+
// Strip off the shape symbols from either side of the rule.
257+
MutableTerm lhsTerm(rule.getLHS().begin(),
258+
rule.getLHS().end() - 1);
259+
MutableTerm rhsTerm(rule.getRHS().begin(),
260+
rule.getRHS().end() - 1);
261+
262+
// Add a SameCount requirement between the two parameter packs.
263+
auto constraintType = Map.getTypeForTerm(lhsTerm, GenericParams);
264+
auto subjectType = Map.getTypeForTerm(rhsTerm, GenericParams);
265+
Reqs.emplace_back(RequirementKind::SameCount,
266+
subjectType, constraintType);
267+
return;
268+
}
269+
253270
assert(rule.getLHS().back().getKind() != Symbol::Kind::Protocol);
254271
auto constraintType = Map.getTypeForTerm(rule.getLHS(), GenericParams);
255272
auto subjectType = Map.getTypeForTerm(rule.getRHS(), GenericParams);

lib/AST/RequirementMachine/RuleBuilder.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,25 @@ void RuleBuilder::addRequirement(const Requirement &req,
298298
MutableTerm constraintTerm;
299299

300300
switch (req.getKind()) {
301-
case RequirementKind::SameCount:
302-
// TODO
303-
return;
301+
case RequirementKind::SameCount: {
302+
// A same-shape requirement length(T...) == length(U...)
303+
// becomes a rewrite rule:
304+
//
305+
// T.[shape] => U.[shape]
306+
auto otherType = CanType(req.getSecondType());
307+
assert(otherType->isTypeSequenceParameter());
308+
309+
constraintTerm = (substitutions
310+
? Context.getRelativeTermForType(
311+
otherType, *substitutions)
312+
: Context.getMutableTermForType(
313+
otherType, proto));
314+
315+
// Add the [shape] symbol to both sides.
316+
subjectTerm.add(Symbol::forShape(Context));
317+
constraintTerm.add(Symbol::forShape(Context));
318+
break;
319+
}
304320

305321
case RequirementKind::Conformance: {
306322
// A conformance requirement T : P becomes a rewrite rule

lib/AST/RequirementMachine/Symbol.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ struct Symbol::Storage final
7272
GenericParam = param;
7373
}
7474

75+
/// A dummy type for overload resolution of the
76+
/// 'shape' constructor for Storage.
77+
struct ForShape {};
78+
79+
explicit Storage(ForShape shape) {
80+
Kind = Kind::Shape;
81+
}
82+
7583
Storage(const ProtocolDecl *proto, Identifier name) {
7684
Kind = Symbol::Kind::AssociatedType;
7785
Proto = proto;
@@ -298,6 +306,30 @@ Symbol Symbol::forGenericParam(GenericTypeParamType *param,
298306
return symbol;
299307
}
300308

309+
Symbol Symbol::forShape(RewriteContext &ctx) {
310+
llvm::FoldingSetNodeID id;
311+
id.AddInteger(unsigned(Kind::Shape));
312+
313+
void *insertPos = nullptr;
314+
if (auto *symbol = ctx.Symbols.FindNodeOrInsertPos(id, insertPos))
315+
return symbol;
316+
317+
unsigned size = Storage::totalSizeToAlloc<unsigned, Term>(0, 0);
318+
void *mem = ctx.Allocator.Allocate(size, alignof(Storage));
319+
auto *symbol = new (mem) Storage(Storage::ForShape());
320+
321+
#ifndef NDEBUG
322+
llvm::FoldingSetNodeID newID;
323+
symbol->Profile(newID);
324+
assert(id == newID);
325+
#endif
326+
327+
ctx.Symbols.InsertNode(symbol, insertPos);
328+
ctx.SymbolHistogram.add(unsigned(Kind::Shape));
329+
330+
return symbol;
331+
}
332+
301333
/// Creates a layout symbol, representing a layout constraint.
302334
Symbol Symbol::forLayout(LayoutConstraint layout,
303335
RewriteContext &ctx) {

lib/AST/RequirementMachine/Symbol.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ class Symbol final {
204204
static Symbol forGenericParam(GenericTypeParamType *param,
205205
RewriteContext &ctx);
206206

207+
static Symbol forShape(RewriteContext &ctx);
208+
207209
static Symbol forLayout(LayoutConstraint layout,
208210
RewriteContext &ctx);
209211

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %target-swift-frontend -typecheck -enable-experimental-variadic-generics %s -debug-generic-signatures 2>&1 | %FileCheck %s
2+
3+
protocol P {
4+
associatedtype A
5+
}
6+
7+
// CHECK-LABEL: inferSameShape(ts:us:)
8+
// CHECK-NEXT: Generic signature: <@_typeSequence T, @_typeSequence U where T.count == U.count>
9+
func inferSameShape<@_typeSequence T, @_typeSequence U>(ts t: T..., us u: U...) where ((T, U)...): Any {
10+
}
11+
12+
// CHECK-LABEL: desugarSameShape(ts:us:)
13+
// CHECK-NEXT: Generic signature: <@_typeSequence T, @_typeSequence U where T : P, T.count == U.count, U : P>
14+
func desugarSameShape<@_typeSequence T, @_typeSequence U>(ts t: T..., us u: U...) where T: P, U: P, ((T.A, U.A)...): Any {
15+
}

0 commit comments

Comments
 (0)