Skip to content

Commit d4b8f1d

Browse files
committed
Implement TypeBase::isTypeSequenceParameter
Behaves much like isTypeParameter but specifically checks for the type sequence bits. Also, add the type sequence bits as a recursive type property.
1 parent 05fe333 commit d4b8f1d

File tree

6 files changed

+227
-18
lines changed

6 files changed

+227
-18
lines changed

include/swift/AST/Types.h

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ class RecursiveTypeProperties {
159159
/// This type contains a type placeholder.
160160
HasPlaceholder = 0x800,
161161

162-
Last_Property = HasPlaceholder
162+
/// This type contains a generic type parameter that is declared as a
163+
/// type sequence
164+
HasTypeSequence = 0x1000,
165+
166+
Last_Property = HasTypeSequence
163167
};
164168
enum { BitWidth = countBitsUsed(Property::Last_Property) };
165169

@@ -217,6 +221,8 @@ class RecursiveTypeProperties {
217221
/// Does a type with these properties structurally contain a placeholder?
218222
bool hasPlaceholder() const { return Bits & HasPlaceholder; }
219223

224+
bool hasTypeSequence() const { return Bits & HasTypeSequence; }
225+
220226
/// Returns the set of properties present in either set.
221227
friend RecursiveTypeProperties operator|(Property lhs, Property rhs) {
222228
return RecursiveTypeProperties(unsigned(lhs) | unsigned(rhs));
@@ -605,6 +611,10 @@ class alignas(1 << TypeAlignInBits) TypeBase
605611
return getRecursiveProperties().hasOpenedExistential();
606612
}
607613

614+
bool hasTypeSequence() const {
615+
return getRecursiveProperties().hasTypeSequence();
616+
}
617+
608618
/// Determine whether the type involves the given opened existential
609619
/// archetype.
610620
bool hasOpenedExistential(OpenedArchetypeType *opened);
@@ -655,6 +665,13 @@ class alignas(1 << TypeAlignInBits) TypeBase
655665
/// whether a type parameter exists at any position.
656666
bool isTypeParameter();
657667

668+
/// Determine whether this type is a type sequence parameter, which is
669+
/// either a GenericTypeParamType or a DependentMemberType.
670+
///
671+
/// Like \c isTypeParameter, this routine will return \c false for types that
672+
/// include type parameters in nested positions e.g. \c X<T...>.
673+
bool isTypeSequenceParameter();
674+
658675
/// Determine whether this type can dynamically be an optional type.
659676
///
660677
/// \param includeExistential Whether an existential type should be considered
@@ -3431,6 +3448,7 @@ struct ParameterListInfo {
34313448
SmallBitVector propertyWrappers;
34323449
SmallBitVector implicitSelfCapture;
34333450
SmallBitVector inheritActorContext;
3451+
SmallBitVector variadicGenerics;
34343452

34353453
public:
34363454
ParameterListInfo() { }
@@ -3460,6 +3478,8 @@ struct ParameterListInfo {
34603478
/// Whether there is any contextual information set on this parameter list.
34613479
bool anyContextualInfo() const;
34623480

3481+
bool isVariadicGenericParameter(unsigned paramIdx) const;
3482+
34633483
/// Retrieve the number of non-defaulted parameters.
34643484
unsigned numNonDefaultedParameters() const {
34653485
return defaultArguments.count();
@@ -5883,15 +5903,15 @@ class GenericTypeParamType : public SubstitutableType {
58835903
private:
58845904
friend class GenericTypeParamDecl;
58855905

5886-
explicit GenericTypeParamType(GenericTypeParamDecl *param)
5887-
: SubstitutableType(TypeKind::GenericTypeParam, nullptr,
5888-
RecursiveTypeProperties::HasTypeParameter),
5906+
explicit GenericTypeParamType(GenericTypeParamDecl *param,
5907+
RecursiveTypeProperties props)
5908+
: SubstitutableType(TypeKind::GenericTypeParam, nullptr, props),
58895909
ParamOrDepthIndex(param) { }
58905910

58915911
explicit GenericTypeParamType(bool isTypeSequence, unsigned depth,
5892-
unsigned index, const ASTContext &ctx)
5893-
: SubstitutableType(TypeKind::GenericTypeParam, &ctx,
5894-
RecursiveTypeProperties::HasTypeParameter),
5912+
unsigned index, RecursiveTypeProperties props,
5913+
const ASTContext &ctx)
5914+
: SubstitutableType(TypeKind::GenericTypeParam, &ctx, props),
58955915
ParamOrDepthIndex(depth << 16 | index |
58965916
((isTypeSequence ? 1 : 0) << 30)) {}
58975917
};
@@ -6300,7 +6320,14 @@ inline bool TypeBase::isTypeParameter() {
63006320
return t->is<GenericTypeParamType>();
63016321
}
63026322

6303-
return false;
6323+
inline bool TypeBase::isTypeSequenceParameter() {
6324+
Type t(this);
6325+
6326+
while (auto *memberTy = t->getAs<DependentMemberType>())
6327+
t = memberTy->getBase();
6328+
6329+
return t->is<GenericTypeParamType>() &&
6330+
t->castTo<GenericTypeParamType>()->isTypeSequence();
63046331
}
63056332

63066333
inline bool TypeBase::isMaterializable() {

include/swift/Sema/ConstraintSystem.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ class TypeVariableType::Implementation {
363363
/// Determine whether this type variable represents a subscript result type.
364364
bool isSubscriptResultType() const;
365365

366+
bool isTypeSequence() const;
367+
366368
/// Retrieve the representative of the equivalence class to which this
367369
/// type variable belongs.
368370
///

lib/AST/ASTContext.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3495,7 +3495,7 @@ isAnyFunctionTypeCanonical(ArrayRef<AnyFunctionType::Param> params,
34953495
static RecursiveTypeProperties
34963496
getGenericFunctionRecursiveProperties(ArrayRef<AnyFunctionType::Param> params,
34973497
Type result) {
3498-
static_assert(RecursiveTypeProperties::BitWidth == 12,
3498+
static_assert(RecursiveTypeProperties::BitWidth == 13,
34993499
"revisit this if you add new recursive type properties");
35003500
RecursiveTypeProperties properties;
35013501

@@ -3548,6 +3548,8 @@ Type AnyFunctionType::Param::getParameterType(bool forCanonical,
35483548
auto arrayDecl = ctx->getArrayDecl();
35493549
if (!arrayDecl)
35503550
type = ErrorType::get(*ctx);
3551+
else if (type->is<PackType>())
3552+
return type;
35513553
else if (forCanonical)
35523554
type = BoundGenericType::get(arrayDecl, Type(), {type});
35533555
else
@@ -3791,8 +3793,12 @@ GenericTypeParamType *GenericTypeParamType::get(bool isTypeSequence,
37913793
if (known != ctx.getImpl().GenericParamTypes.end())
37923794
return known->second;
37933795

3796+
RecursiveTypeProperties props = RecursiveTypeProperties::HasTypeParameter;
3797+
if (isTypeSequence)
3798+
props |= RecursiveTypeProperties::HasTypeSequence;
3799+
37943800
auto result = new (ctx, AllocationArena::Permanent)
3795-
GenericTypeParamType(isTypeSequence, depth, index, ctx);
3801+
GenericTypeParamType(isTypeSequence, depth, index, props, ctx);
37963802
ctx.getImpl().GenericParamTypes[{depthKey, index}] = result;
37973803
return result;
37983804
}
@@ -4078,7 +4084,7 @@ CanSILFunctionType SILFunctionType::get(
40784084
void *mem = ctx.Allocate(bytes, alignof(SILFunctionType));
40794085

40804086
RecursiveTypeProperties properties;
4081-
static_assert(RecursiveTypeProperties::BitWidth == 12,
4087+
static_assert(RecursiveTypeProperties::BitWidth == 13,
40824088
"revisit this if you add new recursive type properties");
40834089
for (auto &param : params)
40844090
properties |= param.getInterfaceType()->getRecursiveProperties();

lib/AST/Decl.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4107,7 +4107,10 @@ GenericTypeParamDecl::GenericTypeParamDecl(DeclContext *dc, Identifier name,
41074107
assert(Bits.GenericTypeParamDecl.Index == index && "Truncation");
41084108
Bits.GenericTypeParamDecl.TypeSequence = isTypeSequence;
41094109
auto &ctx = dc->getASTContext();
4110-
auto type = new (ctx, AllocationArena::Permanent) GenericTypeParamType(this);
4110+
RecursiveTypeProperties props = RecursiveTypeProperties::HasTypeParameter;
4111+
if (this->isTypeSequence())
4112+
props |= RecursiveTypeProperties::HasTypeSequence;
4113+
auto type = new (ctx, AllocationArena::Permanent) GenericTypeParamType(this, props);
41114114
setInterfaceType(MetatypeType::get(type, ctx));
41124115
}
41134116

lib/AST/Type.cpp

Lines changed: 171 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,7 @@ ParameterListInfo::ParameterListInfo(
10371037
propertyWrappers.resize(params.size());
10381038
implicitSelfCapture.resize(params.size());
10391039
inheritActorContext.resize(params.size());
1040+
variadicGenerics.resize(params.size());
10401041

10411042
// No parameter owner means no parameter list means no default arguments
10421043
// - hand back the zeroed bitvector.
@@ -1096,6 +1097,11 @@ ParameterListInfo::ParameterListInfo(
10961097
if (param->getAttrs().hasAttribute<InheritActorContextAttr>()) {
10971098
inheritActorContext.set(i);
10981099
}
1100+
1101+
if (param->isVariadic() &&
1102+
param->getVarargBaseTy()->hasTypeSequence()) {
1103+
variadicGenerics.set(i);
1104+
}
10991105
}
11001106
}
11011107

@@ -1130,6 +1136,13 @@ bool ParameterListInfo::anyContextualInfo() const {
11301136
return implicitSelfCapture.any() || inheritActorContext.any();
11311137
}
11321138

1139+
bool ParameterListInfo::isVariadicGenericParameter(unsigned paramIdx) const {
1140+
return paramIdx < variadicGenerics.size()
1141+
? variadicGenerics[paramIdx]
1142+
: false;
1143+
}
1144+
1145+
11331146
/// Turn a param list into a symbolic and printable representation that does not
11341147
/// include the types, something like (_:, b:, c:)
11351148
std::string swift::getParamListAsString(ArrayRef<AnyFunctionType::Param> params) {
@@ -1360,6 +1373,13 @@ CanType TypeBase::computeCanonicalType() {
13601373
break;
13611374
}
13621375

1376+
case TypeKind::PackExpansion: {
1377+
auto *expansion = cast<PackExpansionType>(this);
1378+
auto pattern = expansion->getPatternType()->getCanonicalType();
1379+
Result = PackExpansionType::get(pattern);
1380+
break;
1381+
}
1382+
13631383
case TypeKind::Tuple: {
13641384
TupleType *TT = cast<TupleType>(this);
13651385
assert(TT->getNumElements() != 0 && "Empty tuples are always canonical");
@@ -5061,18 +5081,152 @@ case TypeKind::Id:
50615081
return ParenType::get(Ptr->getASTContext(), underlying->getInOutObjectType(), otherFlags);
50625082
}
50635083

5084+
case TypeKind::Pack: {
5085+
auto pack = cast<PackType>(base);
5086+
bool anyChanged = false;
5087+
SmallVector<Type, 4> elements;
5088+
unsigned Index = 0;
5089+
for (Type eltTy : pack->getElementTypes()) {
5090+
Type transformedEltTy = eltTy.transformRec(fn);
5091+
if (!transformedEltTy)
5092+
return Type();
5093+
5094+
// If nothing has changed, just keep going.
5095+
if (!anyChanged &&
5096+
transformedEltTy.getPointer() == eltTy.getPointer()) {
5097+
++Index;
5098+
continue;
5099+
}
5100+
5101+
// If this is the first change we've seen, copy all of the previous
5102+
// elements.
5103+
if (!anyChanged) {
5104+
// Copy all of the previous elements.
5105+
elements.append(pack->getElementTypes().begin(),
5106+
pack->getElementTypes().begin() + Index);
5107+
anyChanged = true;
5108+
}
5109+
5110+
elements.push_back(transformedEltTy);
5111+
++Index;
5112+
}
5113+
5114+
if (!anyChanged)
5115+
return *this;
5116+
5117+
return PackType::get(Ptr->getASTContext(), elements);
5118+
}
5119+
5120+
case TypeKind::PackExpansion: {
5121+
auto expand = cast<PackExpansionType>(base);
5122+
struct ExpansionGatherer {
5123+
llvm::function_ref<Optional<Type>(TypeBase *)> baselineFn;
5124+
llvm::DenseMap<TypeBase *, PackType *> cache;
5125+
unsigned maxArity;
5126+
5127+
public:
5128+
ExpansionGatherer(
5129+
llvm::function_ref<Optional<Type>(TypeBase *)> baselineFn)
5130+
: baselineFn(baselineFn), maxArity(0) {}
5131+
5132+
Optional<Type> operator()(TypeBase *input) {
5133+
auto remap = baselineFn(input);
5134+
if (!remap) {
5135+
return remap;
5136+
}
5137+
5138+
if (input->is<TypeVariableType>()) {
5139+
if (auto *PT = (*remap)->getAs<PackType>()) {
5140+
maxArity = std::max(maxArity, PT->getNumElements());
5141+
cache.insert({input, PT});
5142+
}
5143+
} else if (input->isTypeSequenceParameter()) {
5144+
if (auto *PT = (*remap)->getAs<PackType>()) {
5145+
maxArity = std::max(maxArity, PT->getNumElements());
5146+
cache.insert({input, PT});
5147+
}
5148+
}
5149+
return remap;
5150+
}
5151+
5152+
std::pair<llvm::DenseMap<TypeBase *, PackType *>, unsigned>
5153+
intoExpansions() && {
5154+
return std::make_pair(cache, maxArity);
5155+
}
5156+
};
5157+
5158+
// First, substitute down the pattern type to gather the mapping from
5159+
// contained substitutable types to packs.
5160+
auto gather = ExpansionGatherer{fn};
5161+
Type transformedPat = expand->getPatternType().transformRec(gather);
5162+
if (!transformedPat)
5163+
return Type();
5164+
5165+
if (transformedPat.getPointer() == expand->getPatternType().getPointer())
5166+
return *this;
5167+
5168+
llvm::DenseMap<TypeBase *, PackType *> expansions;
5169+
unsigned arity;
5170+
std::tie(expansions, arity) = std::move(gather).intoExpansions();
5171+
if (expansions.empty()) {
5172+
// If we didn't find any expansions, either the caller wasn't interested
5173+
// in expanding this pack, or something has gone wrong. Leave off the
5174+
// expansion and return the transformed type.
5175+
return PackExpansionType::get(transformedPat);
5176+
}
5177+
5178+
SmallVector<Type, 8> elts;
5179+
elts.reserve(arity);
5180+
// Perform the expansion element-wise according to the maximum arity we
5181+
// picked up during the gather step above.
5182+
//
5183+
// For a pack expansion (F<... T..., U..., ...>) and mapping
5184+
//
5185+
// T... -> <X, Y, Z>
5186+
// U... -> <A, B, C>
5187+
//
5188+
// The expected expansion is
5189+
//
5190+
// <F<... X, A, ...>, F<... Y, B, ...>, F<... Z, C, ...> ...>
5191+
for (unsigned i = 0; i < arity; ++i) {
5192+
struct ElementExpander {
5193+
const llvm::DenseMap<TypeBase *, PackType *> &expansions;
5194+
llvm::function_ref<Optional<Type>(TypeBase *)> outerFn;
5195+
unsigned index;
5196+
5197+
public:
5198+
Optional<Type> operator()(TypeBase *input) {
5199+
// FIXME: Does this need to do bounds checking?
5200+
if (PackType *element = expansions.lookup(input))
5201+
return element->getElementType(index);
5202+
return outerFn(input);
5203+
}
5204+
};
5205+
5206+
auto expandedElt = expand->getPatternType().transformRec(
5207+
ElementExpander{expansions, fn, i});
5208+
if (!expandedElt)
5209+
return Type();
5210+
5211+
elts.push_back(expandedElt);
5212+
}
5213+
return PackType::get(base->getASTContext(), elts);
5214+
}
5215+
50645216
case TypeKind::Tuple: {
50655217
auto tuple = cast<TupleType>(base);
50665218
bool anyChanged = false;
50675219
SmallVector<TupleTypeElt, 4> elements;
50685220
unsigned Index = 0;
50695221
for (const auto &elt : tuple->getElements()) {
5070-
Type eltTy = elt.getType().transformRec(fn);
5071-
if (!eltTy)
5222+
Type eltTy = elt.getType();
5223+
Type transformedEltTy = eltTy.transformRec(fn);
5224+
if (!transformedEltTy)
50725225
return Type();
50735226

50745227
// If nothing has changed, just keep going.
5075-
if (!anyChanged && eltTy.getPointer() == elt.getType().getPointer()) {
5228+
if (!anyChanged &&
5229+
transformedEltTy.getPointer() == elt.getType().getPointer()) {
50765230
++Index;
50775231
continue;
50785232
}
@@ -5086,9 +5240,18 @@ case TypeKind::Id:
50865240
anyChanged = true;
50875241
}
50885242

5089-
// Add the new tuple element, with the new type, no initializer,
5090-
elements.push_back(elt.getWithType(eltTy));
5091-
++Index;
5243+
if (eltTy->isTypeSequenceParameter() &&
5244+
transformedEltTy->is<PackType>()) {
5245+
assert(anyChanged);
5246+
// Splat the tuple in by copying in all of the transformed elements.
5247+
auto tuple = dyn_cast<PackType>(transformedEltTy.getPointer());
5248+
elements.append(tuple->getElementTypes().begin(),
5249+
tuple->getElementTypes().end());
5250+
} else {
5251+
// Add the new tuple element, with the transformed type.
5252+
elements.push_back(elt.getWithType(transformedEltTy));
5253+
++Index;
5254+
}
50925255
}
50935256

50945257
if (!anyChanged)
@@ -5505,6 +5668,8 @@ ReferenceCounting TypeBase::getReferenceCounting() {
55055668
case TypeKind::SILToken:
55065669
case TypeKind::GenericTypeParam:
55075670
case TypeKind::DependentMember:
5671+
case TypeKind::Pack:
5672+
case TypeKind::PackExpansion:
55085673
#define REF_STORAGE(Name, ...) \
55095674
case TypeKind::Name##Storage:
55105675
#include "swift/AST/ReferenceStorage.def"

0 commit comments

Comments
 (0)