Skip to content

Commit a10dc8e

Browse files
committed
AST: New ParameterPack.cpp file with new algorithms for pack expansion substitution
1 parent 7e9995b commit a10dc8e

File tree

4 files changed

+263
-50
lines changed

4 files changed

+263
-50
lines changed

include/swift/AST/Types.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2325,6 +2325,8 @@ class TupleType final : public TypeBase, public llvm::FoldingSetNode,
23252325

23262326
bool containsPackExpansionType() const;
23272327

2328+
TupleType *flattenPackTypes();
2329+
23282330
private:
23292331
TupleType(ArrayRef<TupleTypeElt> elements, const ASTContext *CanCtx,
23302332
RecursiveTypeProperties properties)
@@ -3361,6 +3363,8 @@ class AnyFunctionType : public TypeBase {
33613363

33623364
static bool containsPackExpansionType(ArrayRef<Param> params);
33633365

3366+
AnyFunctionType *flattenPackTypes();
3367+
33643368
static void printParams(ArrayRef<Param> Params, raw_ostream &OS,
33653369
const PrintOptions &PO = PrintOptions());
33663370
static void printParams(ArrayRef<Param> Params, ASTPrinter &Printer,
@@ -6414,6 +6418,10 @@ class PackType final : public TypeBase, public llvm::FoldingSetNode,
64146418
return getTrailingObjects<Type>()[index];
64156419
}
64166420

6421+
bool containsPackExpansionType() const;
6422+
6423+
PackType *flattenPackTypes();
6424+
64176425
public:
64186426
void Profile(llvm::FoldingSetNodeID &ID) const {
64196427
Profile(ID, getElementTypes());
@@ -6484,6 +6492,8 @@ class PackExpansionType : public TypeBase, public llvm::FoldingSetNode {
64846492
/// Retrieves the count type of this pack expansion.
64856493
Type getCountType() const { return countType; }
64866494

6495+
PackExpansionType *expand();
6496+
64876497
public:
64886498
void Profile(llvm::FoldingSetNodeID &ID) {
64896499
Profile(ID, getPatternType(), getCountType());

lib/AST/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ add_swift_host_library(swiftAST STATIC
7272
OperatorNameLookup.cpp
7373
PackConformance.cpp
7474
PackExpansionMatcher.cpp
75+
ParameterPack.cpp
7576
Parameter.cpp
7677
Pattern.cpp
7778
PlatformKind.cpp

lib/AST/ParameterPack.cpp

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
//===--- ParameterPack.cpp - Utilities for variadic generics --------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2022 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file implements utilities for substituting type parameter packs
14+
// appearing in pack expansion types.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#include "swift/AST/Types.h"
19+
#include "swift/AST/Decl.h"
20+
#include "swift/AST/Type.h"
21+
#include "llvm/ADT/SmallVector.h"
22+
23+
using namespace swift;
24+
25+
void TypeBase::getTypeParameterPacks(
26+
SmallVectorImpl<Type> &rootParameterPacks) const {
27+
llvm::SmallDenseSet<CanType, 2> visited;
28+
29+
auto recordType = [&](Type t) {
30+
if (visited.insert(t->getCanonicalType()).second)
31+
rootParameterPacks.push_back(t);
32+
};
33+
34+
Type(const_cast<TypeBase *>(this)).visit([&](Type t) {
35+
if (auto *paramTy = t->getAs<GenericTypeParamType>()) {
36+
if (paramTy->isParameterPack()) {
37+
recordType(paramTy);
38+
}
39+
} else if (auto *archetypeTy = t->getAs<PackArchetypeType>()) {
40+
if (archetypeTy->isRoot()) {
41+
recordType(t);
42+
}
43+
}
44+
});
45+
}
46+
47+
bool GenericTypeParamType::isParameterPack() const {
48+
if (auto param = getDecl()) {
49+
return param->isParameterPack();
50+
}
51+
52+
auto fixedNum = ParamOrDepthIndex.get<DepthIndexTy>();
53+
return (fixedNum & GenericTypeParamType::TYPE_SEQUENCE_BIT) ==
54+
GenericTypeParamType::TYPE_SEQUENCE_BIT;
55+
}
56+
57+
/// G<{X1, ..., Xn}, {Y1, ..., Yn}>... => {G<X1, Y1>, ..., G<Xn, Yn>}...
58+
PackExpansionType *PackExpansionType::expand() {
59+
auto countType = getCountType();
60+
auto *countPack = countType->getAs<PackType>();
61+
if (countPack == nullptr)
62+
return this;
63+
64+
auto patternType = getPatternType();
65+
if (patternType->is<PackType>())
66+
return this;
67+
68+
unsigned j = 0;
69+
SmallVector<Type, 4> expandedTypes;
70+
for (auto type : countPack->getElementTypes()) {
71+
Type expandedCount;
72+
if (auto *expansion = type->getAs<PackExpansionType>())
73+
expandedCount = expansion->getCountType();
74+
75+
auto expandedPattern = patternType.transformRec(
76+
[&](Type t) -> Optional<Type> {
77+
if (t->is<PackExpansionType>())
78+
return t;
79+
80+
if (auto *nestedPack = t->getAs<PackType>()) {
81+
auto nestedPackElts = nestedPack->getElementTypes();
82+
if (j < nestedPackElts.size()) {
83+
if (expandedCount) {
84+
if (auto *expansion = nestedPackElts[j]->getAs<PackExpansionType>())
85+
return expansion->getPatternType();
86+
} else {
87+
return nestedPackElts[j];
88+
}
89+
}
90+
91+
return ErrorType::get(t->getASTContext());
92+
}
93+
94+
return None;
95+
});
96+
97+
if (expandedCount) {
98+
expandedTypes.push_back(PackExpansionType::get(expandedPattern,
99+
expandedCount));
100+
} else {
101+
expandedTypes.push_back(expandedPattern);
102+
}
103+
104+
++j;
105+
}
106+
107+
auto *packType = PackType::get(getASTContext(), expandedTypes);
108+
return PackExpansionType::get(packType, countType);
109+
}
110+
111+
bool TupleType::containsPackExpansionType() const {
112+
for (auto elt : getElements()) {
113+
if (elt.getType()->is<PackExpansionType>())
114+
return true;
115+
}
116+
117+
return false;
118+
}
119+
120+
/// (W, {X, Y}..., Z) => (W, X, Y, Z)
121+
TupleType *TupleType::flattenPackTypes() {
122+
bool anyChanged = false;
123+
SmallVector<TupleTypeElt, 4> elts;
124+
125+
for (unsigned i = 0, e = getNumElements(); i < e; ++i) {
126+
auto elt = getElement(i);
127+
128+
if (auto *expansionType = elt.getType()->getAs<PackExpansionType>()) {
129+
if (auto *packType = expansionType->getPatternType()->getAs<PackType>()) {
130+
if (!anyChanged) {
131+
elts.append(getElements().begin(), getElements().begin() + i);
132+
anyChanged = true;
133+
}
134+
135+
bool first = true;
136+
for (auto packElt : packType->getElementTypes()) {
137+
if (first) {
138+
elts.push_back(TupleTypeElt(packElt, elt.getName()));
139+
first = false;
140+
continue;
141+
}
142+
elts.push_back(TupleTypeElt(packElt));
143+
}
144+
145+
continue;
146+
}
147+
}
148+
149+
if (anyChanged)
150+
elts.push_back(elt);
151+
}
152+
153+
if (!anyChanged)
154+
return this;
155+
156+
return TupleType::get(elts, getASTContext());
157+
}
158+
159+
bool AnyFunctionType::containsPackExpansionType(ArrayRef<Param> params) {
160+
for (auto param : params) {
161+
if (param.getPlainType()->is<PackExpansionType>())
162+
return true;
163+
}
164+
165+
return false;
166+
}
167+
168+
/// (W, {X, Y}..., Z) -> T => (W, X, Y, Z) -> T
169+
AnyFunctionType *AnyFunctionType::flattenPackTypes() {
170+
bool anyChanged = false;
171+
SmallVector<AnyFunctionType::Param, 4> params;
172+
173+
for (unsigned i = 0, e = getParams().size(); i < e; ++i) {
174+
auto param = getParams()[i];
175+
176+
if (auto *expansionType = param.getPlainType()->getAs<PackExpansionType>()) {
177+
if (auto *packType = expansionType->getPatternType()->getAs<PackType>()) {
178+
if (!anyChanged) {
179+
params.append(getParams().begin(), getParams().begin() + i);
180+
anyChanged = true;
181+
}
182+
183+
bool first = true;
184+
for (auto packElt : packType->getElementTypes()) {
185+
if (first) {
186+
params.push_back(param.withType(packElt));
187+
first = false;
188+
continue;
189+
}
190+
params.push_back(param.withType(packElt).getWithoutLabels());
191+
}
192+
193+
continue;
194+
}
195+
}
196+
197+
if (anyChanged)
198+
params.push_back(param);
199+
}
200+
201+
if (!anyChanged)
202+
return this;
203+
204+
if (auto *genericFuncType = getAs<GenericFunctionType>()) {
205+
return GenericFunctionType::get(genericFuncType->getGenericSignature(),
206+
params, getResult(), getExtInfo());
207+
} else {
208+
return FunctionType::get(params, getResult(), getExtInfo());
209+
}
210+
}
211+
212+
bool PackType::containsPackExpansionType() const {
213+
for (auto type : getElementTypes()) {
214+
if (type->is<PackExpansionType>())
215+
return true;
216+
}
217+
218+
return false;
219+
}
220+
221+
/// {W, {X, Y}..., Z} => {W, X, Y, Z}
222+
PackType *PackType::flattenPackTypes() {
223+
bool anyChanged = false;
224+
SmallVector<Type, 4> elts;
225+
226+
for (unsigned i = 0, e = getNumElements(); i < e; ++i) {
227+
auto elt = getElementType(i);
228+
229+
if (auto *expansionType = elt->getAs<PackExpansionType>()) {
230+
if (auto *packType = expansionType->getPatternType()->getAs<PackType>()) {
231+
if (!anyChanged) {
232+
elts.append(getElementTypes().begin(), getElementTypes().begin() + i);
233+
anyChanged = true;
234+
}
235+
236+
for (auto packElt : packType->getElementTypes()) {
237+
elts.push_back(packElt);
238+
}
239+
240+
continue;
241+
}
242+
}
243+
244+
if (anyChanged)
245+
elts.push_back(elt);
246+
}
247+
248+
if (!anyChanged)
249+
return this;
250+
251+
return PackType::get(getASTContext(), elts);
252+
}

lib/AST/Type.cpp

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -582,28 +582,6 @@ Type TypeBase::typeEraseOpenedArchetypesWithRoot(
582582
return transformFn(type);
583583
}
584584

585-
void TypeBase::getTypeParameterPacks(
586-
SmallVectorImpl<Type> &rootParameterPacks) const {
587-
llvm::SmallDenseSet<CanType, 2> visited;
588-
589-
auto recordType = [&](Type t) {
590-
if (visited.insert(t->getCanonicalType()).second)
591-
rootParameterPacks.push_back(t);
592-
};
593-
594-
Type(const_cast<TypeBase *>(this)).visit([&](Type t) {
595-
if (auto *paramTy = t->getAs<GenericTypeParamType>()) {
596-
if (paramTy->isParameterPack()) {
597-
recordType(paramTy);
598-
}
599-
} else if (auto *archetypeTy = t->getAs<PackArchetypeType>()) {
600-
if (archetypeTy->isRoot()) {
601-
recordType(t);
602-
}
603-
}
604-
});
605-
}
606-
607585
Type TypeBase::addCurriedSelfType(const DeclContext *dc) {
608586
if (!dc->isTypeContext())
609587
return this;
@@ -1966,16 +1944,6 @@ unsigned GenericTypeParamType::getIndex() const {
19661944
return fixedNum & 0xFFFF;
19671945
}
19681946

1969-
bool GenericTypeParamType::isParameterPack() const {
1970-
if (auto param = getDecl()) {
1971-
return param->isParameterPack();
1972-
}
1973-
1974-
auto fixedNum = ParamOrDepthIndex.get<DepthIndexTy>();
1975-
return (fixedNum & GenericTypeParamType::TYPE_SEQUENCE_BIT) ==
1976-
GenericTypeParamType::TYPE_SEQUENCE_BIT;
1977-
}
1978-
19791947
Identifier GenericTypeParamType::getName() const {
19801948
// Use the declaration name if we still have that sugar.
19811949
if (auto decl = getDecl())
@@ -3459,15 +3427,6 @@ int TupleType::getNamedElementId(Identifier I) const {
34593427
return -1;
34603428
}
34613429

3462-
bool TupleType::containsPackExpansionType() const {
3463-
for (auto elt : getElements()) {
3464-
if (elt.getType()->is<PackExpansionType>())
3465-
return true;
3466-
}
3467-
3468-
return false;
3469-
}
3470-
34713430
ArchetypeType::ArchetypeType(TypeKind Kind,
34723431
const ASTContext &Ctx,
34733432
RecursiveTypeProperties properties,
@@ -4282,15 +4241,6 @@ bool AnyFunctionType::hasSameExtInfoAs(const AnyFunctionType *otherFn) {
42824241
return getExtInfo().isEqualTo(otherFn->getExtInfo(), useClangTypes(this));
42834242
}
42844243

4285-
bool AnyFunctionType::containsPackExpansionType(ArrayRef<Param> params) {
4286-
for (auto param : params) {
4287-
if (param.getPlainType()->is<PackExpansionType>())
4288-
return true;
4289-
}
4290-
4291-
return false;
4292-
}
4293-
42944244
ClangTypeInfo SILFunctionType::getClangTypeInfo() const {
42954245
if (!Bits.SILFunctionType.HasClangTypeInfo)
42964246
return ClangTypeInfo();

0 commit comments

Comments
 (0)