Skip to content

Commit 6f098ab

Browse files
committed
[TypeJoin] Implement Type::join for protocols and protocol compositions.
This is mostly complete, but we still need to add support for joins between these two types and other types.
1 parent dca7035 commit 6f098ab

File tree

2 files changed

+167
-14
lines changed

2 files changed

+167
-14
lines changed

lib/AST/TypeJoinMeet.cpp

Lines changed: 121 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ struct TypeJoin : CanTypeVisitor<TypeJoin, CanType> {
5151
}
5252

5353
static CanType getSuperclassJoin(CanType first, CanType second);
54+
CanType computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
55+
ArrayRef<Type> secondMembers);
56+
5457

5558
CanType visitErrorType(CanType second);
5659
CanType visitTupleType(CanType second);
@@ -105,10 +108,10 @@ struct TypeJoin : CanTypeVisitor<TypeJoin, CanType> {
105108

106109
// Likewise, rather than making every visitor deal with Any,
107110
// always dispatch to the protocol composition side of the join.
108-
if (first->isAny())
111+
if (first->is<ProtocolCompositionType>())
109112
return TypeJoin(second).visit(first);
110113

111-
if (second->isAny())
114+
if (second->is<ProtocolCompositionType>())
112115
return TypeJoin(first).visit(second);
113116

114117
// Otherwise the first type might be an optional (or not), so
@@ -184,16 +187,6 @@ CanType TypeJoin::visitClassType(CanType second) {
184187
return getSuperclassJoin(First, second);
185188
}
186189

187-
CanType TypeJoin::visitProtocolType(CanType second) {
188-
assert(First != second);
189-
190-
// FIXME: We should compute a tighter bound and/or return nullptr if
191-
// we cannot. We do this now because existing tests rely on
192-
// producing Any for the join of protocols that have a common
193-
// supertype.
194-
return TheAnyType;
195-
}
196-
197190
CanType TypeJoin::visitBoundGenericClassType(CanType second) {
198191
return getSuperclassJoin(First, second);
199192
}
@@ -352,16 +345,130 @@ CanType TypeJoin::visitGenericFunctionType(CanType second) {
352345
return Unimplemented;
353346
}
354347

348+
// Use the distributive law to compute the join of the protocol
349+
// compositions.
350+
//
351+
// (A ^ B) v (C ^ D)
352+
// = (A v C) ^ (A v D) ^ (B v C) ^ (B v D)
353+
CanType TypeJoin::computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
354+
ArrayRef<Type> secondMembers) {
355+
SmallVector<Type, 8> result;
356+
for (auto first : firstMembers) {
357+
for (auto second : secondMembers) {
358+
auto joined = Type::join(first, second);
359+
if (!joined)
360+
return Unimplemented;
361+
362+
if ((*joined)->isAny())
363+
continue;
364+
365+
result.push_back(*joined);
366+
}
367+
}
368+
369+
if (result.empty())
370+
return TheAnyType;
371+
372+
auto &ctx = result[0]->getASTContext();
373+
return ProtocolCompositionType::get(ctx, result, false)->getCanonicalType();
374+
}
375+
355376
CanType TypeJoin::visitProtocolCompositionType(CanType second) {
377+
// The join of Any and a no-escape function doesn't exist; it isn't
378+
// Any. If it were Any, it would mean we would allow these functions
379+
// to escape through Any.
356380
if (second->isAny()) {
357381
auto *fnTy = First->getAs<AnyFunctionType>();
358382
if (fnTy && fnTy->getExtInfo().isNoEscape())
359383
return Nonexistent;
360384

361-
return second;
385+
return TheAnyType;
362386
}
363387

364-
return Unimplemented;
388+
assert(First != second);
389+
390+
// FIXME: Handle other types here.
391+
if (First->getKind() != TypeKind::Protocol &&
392+
First->getKind() != TypeKind::ProtocolComposition)
393+
return TheAnyType;
394+
395+
SmallVector<Type, 1> protocolType;
396+
ArrayRef<Type> firstMembers;
397+
if (First->getKind() == TypeKind::Protocol) {
398+
protocolType.push_back(First);
399+
firstMembers = protocolType;
400+
} else {
401+
firstMembers = cast<ProtocolCompositionType>(First)->getMembers();
402+
}
403+
auto secondMembers = cast<ProtocolCompositionType>(second)->getMembers();
404+
405+
return computeProtocolCompositionJoin(firstMembers, secondMembers);
406+
}
407+
408+
// Return true if the first ProtocolDecl is a supertype of the second.
409+
static bool isSupertypeOf(ProtocolDecl *super, ProtocolDecl *sub) {
410+
if (super == sub)
411+
return true;
412+
413+
SmallVector<ProtocolDecl *, 4> worklist;
414+
for (auto *decl : sub->getInheritedProtocols())
415+
worklist.push_back(decl);
416+
417+
llvm::SmallPtrSet<ProtocolDecl *, 4> visited;
418+
while (!worklist.empty()) {
419+
auto *entry = worklist.pop_back_val();
420+
if (visited.count(entry))
421+
continue;
422+
visited.insert(entry);
423+
424+
if (entry == super)
425+
return true;
426+
427+
for (auto *decl : entry->getInheritedProtocols())
428+
worklist.push_back(decl);
429+
}
430+
431+
return false;
432+
}
433+
434+
CanType TypeJoin::visitProtocolType(CanType second) {
435+
assert(First != second);
436+
437+
assert(First->getKind() != TypeKind::ProtocolComposition &&
438+
second->getKind() != TypeKind::ProtocolComposition);
439+
440+
// FIXME: Handle other types here.
441+
if (First->getKind() != second->getKind())
442+
return TheAnyType;
443+
444+
auto *firstDecl =
445+
cast<ProtocolDecl>(First->getNominalOrBoundGenericNominal());
446+
447+
auto *secondDecl =
448+
cast<ProtocolDecl>(second->getNominalOrBoundGenericNominal());
449+
450+
if (firstDecl->getInheritedProtocols().empty() &&
451+
secondDecl->getInheritedProtocols().empty())
452+
return TheAnyType;
453+
454+
if (isSupertypeOf(firstDecl, secondDecl))
455+
return First;
456+
457+
if (isSupertypeOf(secondDecl, firstDecl))
458+
return second;
459+
460+
// One isn't the supertype of the other, so instead, treat each as
461+
// if it's a protocol composition of its inherited members, and join
462+
// those.
463+
SmallVector<Type, 4> firstMembers;
464+
for (auto *decl : firstDecl->getInheritedProtocols())
465+
firstMembers.push_back(decl->getDeclaredInterfaceType());
466+
467+
SmallVector<Type, 4> secondMembers;
468+
for (auto *decl : secondDecl->getInheritedProtocols())
469+
secondMembers.push_back(decl->getDeclaredInterfaceType());
470+
471+
return computeProtocolCompositionJoin(firstMembers, secondMembers);
365472
}
366473

367474
CanType TypeJoin::visitLValueType(CanType second) { return Unimplemented; }

test/Sema/type_join.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,31 @@ import Swift
55
class C {}
66
class D : C {}
77

8+
protocol L {}
9+
protocol M : L {}
10+
protocol N : L {}
11+
protocol P : M {}
12+
protocol Q : M {}
13+
protocol R : L {}
14+
protocol Y {}
15+
16+
protocol FakeEquatable {}
17+
protocol FakeHashable : FakeEquatable {}
18+
protocol FakeExpressibleByIntegerLiteral {}
19+
protocol FakeNumeric : FakeEquatable, FakeExpressibleByIntegerLiteral {}
20+
protocol FakeSignedNumeric : FakeNumeric {}
21+
protocol FakeComparable : FakeEquatable {}
22+
protocol FakeStrideable : FakeComparable {}
23+
protocol FakeCustomStringConvertible {}
24+
protocol FakeBinaryInteger : FakeHashable, FakeNumeric, FakeCustomStringConvertible, FakeStrideable {}
25+
protocol FakeLosslessStringConvertible {}
26+
protocol FakeFixedWidthInteger : FakeBinaryInteger, FakeLosslessStringConvertible {}
27+
protocol FakeUnsignedInteger : FakeBinaryInteger {}
28+
protocol FakeSignedInteger : FakeBinaryInteger, FakeSignedNumeric {}
29+
protocol FakeFloatingPoint : FakeSignedNumeric, FakeStrideable, FakeHashable {}
30+
protocol FakeExpressibleByFloatLiteral {}
31+
protocol FakeBinaryFloatingPoint : FakeFloatingPoint, FakeExpressibleByFloatLiteral {}
32+
833
func expectEqualType<T>(_: T.Type, _: T.Type) {}
934
func commonSupertype<T>(_: T, _: T) -> T {}
1035

@@ -38,6 +63,27 @@ expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int1.self), Builtin
3863
expectEqualType(Builtin.type_join(Builtin.Int32.self, Builtin.Int1.self), Any.self)
3964
expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int32.self), Any.self)
4065

66+
expectEqualType(Builtin.type_join(L.self, L.self), L.self)
67+
expectEqualType(Builtin.type_join(L.self, M.self), L.self)
68+
expectEqualType(Builtin.type_join(L.self, P.self), L.self)
69+
expectEqualType(Builtin.type_join(L.self, Y.self), Any.self)
70+
expectEqualType(Builtin.type_join(N.self, P.self), L.self)
71+
expectEqualType(Builtin.type_join(Q.self, P.self), M.self)
72+
expectEqualType(Builtin.type_join((N & P).self, (Q & R).self), M.self)
73+
expectEqualType(Builtin.type_join((Q & P).self, (Y & R).self), L.self)
74+
expectEqualType(Builtin.type_join(FakeEquatable.self, FakeEquatable.self), FakeEquatable.self)
75+
expectEqualType(Builtin.type_join(FakeHashable.self, FakeEquatable.self), FakeEquatable.self)
76+
expectEqualType(Builtin.type_join(FakeEquatable.self, FakeHashable.self), FakeEquatable.self)
77+
expectEqualType(Builtin.type_join(FakeNumeric.self, FakeHashable.self), FakeEquatable.self)
78+
expectEqualType(Builtin.type_join((FakeHashable & Strideable).self, (FakeHashable & FakeNumeric).self),
79+
FakeHashable.self)
80+
expectEqualType(Builtin.type_join((FakeNumeric & Strideable).self,
81+
(FakeHashable & FakeNumeric).self), FakeNumeric.self)
82+
expectEqualType(Builtin.type_join(FakeBinaryInteger.self, FakeFloatingPoint.self),
83+
(FakeHashable & FakeNumeric & FakeStrideable).self)
84+
expectEqualType(Builtin.type_join(FakeFloatingPoint.self, FakeBinaryInteger.self),
85+
(FakeHashable & FakeNumeric & FakeStrideable).self)
86+
4187
func joinFunctions(
4288
_ escaping: @escaping () -> (),
4389
_ nonescaping: () -> ()

0 commit comments

Comments
 (0)