Skip to content

Commit 30a1211

Browse files
committed
AST: Move getReducedShape() from CSSimplify.cpp to a method on TypeBase
1 parent 6d75fac commit 30a1211

File tree

5 files changed

+76
-45
lines changed

5 files changed

+76
-45
lines changed

include/swift/AST/Types.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,18 @@ class alignas(1 << TypeAlignInBits) TypeBase
12901290
/// Error.
12911291
bool isExistentialWithError();
12921292

1293+
/// Returns the reduced shape of the type, which represents an equivalence
1294+
/// class for the same-shape generic requirement:
1295+
///
1296+
/// - The shape of a scalar type is always the empty tuple type ().
1297+
/// - The shape of a pack archetype is computed from the generic signature
1298+
/// using same-shape requirements.
1299+
/// - The shape of a pack type is computed recursively from its elements.
1300+
///
1301+
/// Two types satisfy a same-shape requirement if their reduced shapes are
1302+
/// equal as canonical types.
1303+
CanType getReducedShape();
1304+
12931305
SWIFT_DEBUG_DUMP;
12941306
void dump(raw_ostream &os, unsigned indent = 0) const;
12951307

@@ -6028,7 +6040,7 @@ class PackArchetypeType final
60286040
LayoutConstraint Layout);
60296041

60306042
// Returns the reduced shape type for this pack archetype.
6031-
Type getShape() const;
6043+
CanType getReducedShape() const;
60326044

60336045
static bool classof(const TypeBase *T) {
60346046
return T->getKind() == TypeKind::PackArchetype;
@@ -6481,6 +6493,8 @@ class PackType final : public TypeBase, public llvm::FoldingSetNode,
64816493

64826494
PackType *flattenPackTypes();
64836495

6496+
CanTypeWrapper<PackType> getReducedShape();
6497+
64846498
public:
64856499
void Profile(llvm::FoldingSetNodeID &ID) const {
64866500
Profile(ID, getElementTypes());
@@ -6553,6 +6567,8 @@ class PackExpansionType : public TypeBase, public llvm::FoldingSetNode {
65536567

65546568
PackExpansionType *expand();
65556569

6570+
CanType getReducedShape();
6571+
65566572
public:
65576573
void Profile(llvm::FoldingSetNodeID &ID) {
65586574
Profile(ID, getPatternType(), getCountType());

lib/AST/ParameterPack.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
//
1616
//===----------------------------------------------------------------------===//
1717

18-
#include "swift/AST/Types.h"
18+
#include "swift/AST/ASTContext.h"
1919
#include "swift/AST/Decl.h"
2020
#include "swift/AST/ParameterList.h"
2121
#include "swift/AST/Type.h"
22+
#include "swift/AST/Types.h"
2223
#include "llvm/ADT/SmallVector.h"
2324

2425
using namespace swift;
@@ -123,6 +124,19 @@ PackExpansionType *PackExpansionType::expand() {
123124
return PackExpansionType::get(packType, countType);
124125
}
125126

127+
CanType PackExpansionType::getReducedShape() {
128+
if (auto *archetypeType = countType->getAs<PackArchetypeType>()) {
129+
auto shape = archetypeType->getReducedShape();
130+
return CanType(PackExpansionType::get(shape, shape));
131+
} else if (auto *packType = countType->getAs<PackType>()) {
132+
auto shape = packType->getReducedShape();
133+
return CanType(PackExpansionType::get(shape, shape));
134+
}
135+
136+
assert(countType->is<PlaceholderType>());
137+
return getASTContext().TheEmptyTupleType;
138+
}
139+
126140
bool TupleType::containsPackExpansionType() const {
127141
for (auto elt : getElements()) {
128142
if (elt.getType()->is<PackExpansionType>())
@@ -266,6 +280,44 @@ PackType *PackType::flattenPackTypes() {
266280
return PackType::get(getASTContext(), elts);
267281
}
268282

283+
CanPackType PackType::getReducedShape() {
284+
SmallVector<Type, 4> elts;
285+
286+
auto &ctx = getASTContext();
287+
288+
for (auto elt : getElementTypes()) {
289+
// T... => shape(T)...
290+
if (auto *packExpansionType = elt->getAs<PackExpansionType>()) {
291+
elts.push_back(packExpansionType->getReducedShape());
292+
continue;
293+
}
294+
295+
// Use () as a placeholder for scalar shape.
296+
assert(!elt->is<PackArchetypeType>() &&
297+
"Pack archetype outside of a pack expansion");
298+
elts.push_back(ctx.TheEmptyTupleType);
299+
}
300+
301+
return CanPackType(PackType::get(ctx, elts));
302+
}
303+
304+
CanType TypeBase::getReducedShape() {
305+
if (auto *packArchetype = getAs<PackArchetypeType>())
306+
return packArchetype->getReducedShape();
307+
308+
if (auto *packType = getAs<PackType>())
309+
return packType->getReducedShape();
310+
311+
if (auto *expansionType = getAs<PackExpansionType>())
312+
return expansionType->getReducedShape();
313+
314+
assert(!isTypeVariableOrMember());
315+
assert(!hasTypeParameter());
316+
317+
// Use () as a placeholder for scalar shape.
318+
return getASTContext().TheEmptyTupleType;
319+
}
320+
269321
unsigned ParameterList::getOrigParamIndex(SubstitutionMap subMap,
270322
unsigned substIndex) const {
271323
unsigned remappedIndex = substIndex;

lib/AST/Type.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4013,9 +4013,9 @@ PackArchetypeType::get(const ASTContext &Ctx,
40134013
{ShapeType}));
40144014
}
40154015

4016-
Type PackArchetypeType::getShape() const {
4016+
CanType PackArchetypeType::getReducedShape() const {
40174017
auto shapeType = getTrailingObjects<PackShape>()->shapeType;
4018-
return getGenericEnvironment()->mapTypeIntoContext(shapeType);
4018+
return getGenericEnvironment()->mapTypeIntoContext(shapeType)->getCanonicalType();
40194019
}
40204020

40214021
ElementArchetypeType::ElementArchetypeType(

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2927,7 +2927,7 @@ namespace {
29272927
return;
29282928

29292929
if (auto archetype = type->getAs<PackArchetypeType>()) {
2930-
shapeType = archetype->getShape();
2930+
shapeType = archetype->getReducedShape();
29312931
}
29322932
});
29332933

lib/Sema/CSSimplify.cpp

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12603,40 +12603,6 @@ ConstraintSystem::simplifyDynamicCallableApplicableFnConstraint(
1260312603
return SolutionKind::Solved;
1260412604
}
1260512605

12606-
/// FIXME: Move this elsewhere if it is broadly useful
12607-
static Type getReducedShape(Type type, ASTContext &ctx) {
12608-
// Pack archetypes know their reduced shape
12609-
if (auto *packArchetype = type->getAs<PackArchetypeType>())
12610-
return packArchetype->getShape();
12611-
12612-
// Reduced shape of pack is computed recursively
12613-
if (auto *packType = type->getAs<PackType>()) {
12614-
SmallVector<Type, 2> elts;
12615-
12616-
for (auto elt : packType->getElementTypes()) {
12617-
// T... => shape(T)...
12618-
if (auto *packExpansionType = elt->getAs<PackExpansionType>()) {
12619-
if (packExpansionType->getCountType()->is<PlaceholderType>()) {
12620-
elts.push_back(ctx.TheEmptyTupleType);
12621-
continue;
12622-
}
12623-
auto shapeType = getReducedShape(packExpansionType->getCountType(), ctx);
12624-
elts.push_back(PackExpansionType::get(shapeType, shapeType));
12625-
}
12626-
12627-
// Use () as a placeholder for scalar shape.
12628-
elts.push_back(ctx.TheEmptyTupleType);
12629-
}
12630-
12631-
return PackType::get(ctx, elts);
12632-
}
12633-
12634-
assert(!type->isTypeVariableOrMember());
12635-
12636-
// Use () as a placeholder for scalar shape.
12637-
return ctx.TheEmptyTupleType;
12638-
}
12639-
1264012606
ConstraintSystem::SolutionKind ConstraintSystem::simplifyShapeOfConstraint(
1264112607
Type type1, Type type2, TypeMatchOptions flags,
1264212608
ConstraintLocatorBuilder locator) {
@@ -12667,12 +12633,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyShapeOfConstraint(
1266712633
return formUnsolved();
1266812634
}
1266912635

12670-
if (Type shape = getReducedShape(type1, getASTContext())) {
12671-
addConstraint(ConstraintKind::Bind, shape, type2, locator);
12672-
return SolutionKind::Solved;
12673-
}
12674-
12675-
return SolutionKind::Error;
12636+
auto shape = type1->getReducedShape();
12637+
addConstraint(ConstraintKind::Bind, shape, type2, locator);
12638+
return SolutionKind::Solved;
1267612639
}
1267712640

1267812641
static llvm::PointerIntPair<Type, 3, unsigned>

0 commit comments

Comments
 (0)