Skip to content

Commit 9ab4dc4

Browse files
committed
[NFC] Add better APIs for parallel destructuring of orig+subst types
As I've been iterating on this work, I've been gradually mulling these over, and I think this is the way to go for now. These should make it a lot less cumbersome to write these kinds of traversals correctly. The intent is to the sunset the existing expanded-components stuff after I do a similar pass for function parameters.
1 parent 7a8d8b4 commit 9ab4dc4

File tree

8 files changed

+336
-260
lines changed

8 files changed

+336
-260
lines changed

include/swift/SIL/AbstractionPattern.h

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,9 @@ class AbstractionPattern {
913913
bool hasCachingKey() const {
914914
// Only the simplest Kind::Type pattern has a caching key; we
915915
// don't want to try to unique by Clang node.
916+
//
917+
// Even if we support Clang nodes someday, we *cannot* cache
918+
// by the open-coded patterns like Tuple and PackExpansion.
916919
return getKind() == Kind::Type || getKind() == Kind::Opaque
917920
|| getKind() == Kind::Discard;
918921
}
@@ -1216,9 +1219,7 @@ class AbstractionPattern {
12161219
case Kind::Invalid:
12171220
llvm_unreachable("querying invalid abstraction pattern!");
12181221
case Kind::Opaque:
1219-
return typename CanTypeWrapperTraits<TYPE>::type();
12201222
case Kind::Tuple:
1221-
return typename CanTypeWrapperTraits<TYPE>::type();
12221223
case Kind::OpaqueFunction:
12231224
case Kind::OpaqueDerivativeFunction:
12241225
return typename CanTypeWrapperTraits<TYPE>::type();
@@ -1275,7 +1276,7 @@ class AbstractionPattern {
12751276

12761277
/// Is the given tuple type a valid substitution of this abstraction
12771278
/// pattern?
1278-
bool matchesTuple(CanTupleType substType);
1279+
bool matchesTuple(CanTupleType substType) const;
12791280

12801281
bool isTuple() const {
12811282
switch (getKind()) {
@@ -1346,6 +1347,40 @@ class AbstractionPattern {
13461347
return { { this, 0 }, { this, getNumTupleElements() } };
13471348
}
13481349

1350+
/// Perform a parallel visitation of the elements of a tuple type,
1351+
/// preserving structure about where pack expansions appear in the
1352+
/// original type and how many elements of the substituted type they
1353+
/// expand to.
1354+
///
1355+
/// This pattern must be a tuple pattern.
1356+
///
1357+
/// Calls handleScalar or handleExpansion as appropriate for each
1358+
/// element of the original tuple, in order.
1359+
void forEachTupleElement(CanTupleType substType,
1360+
llvm::function_ref<void(unsigned origEltIndex,
1361+
unsigned substEltIndex,
1362+
AbstractionPattern origEltType,
1363+
CanType substEltType)>
1364+
handleScalar,
1365+
llvm::function_ref<void(unsigned origEltIndex,
1366+
unsigned substEltIndex,
1367+
AbstractionPattern origExpansionType,
1368+
CanTupleEltTypeArrayRef substEltTypes)>
1369+
handleExpansion) const;
1370+
1371+
/// Perform a parallel visitation of the elements of a tuple type,
1372+
/// expanding the elements of the type. This preserves the structure
1373+
/// of the *substituted* tuple type: it will be called once per element
1374+
/// of the substituted type, in order. The original element trappings
1375+
/// are also provided for convenience.
1376+
///
1377+
/// This pattern must match the substituted type, but it may be an
1378+
/// opaque pattern.
1379+
void forEachExpandedTupleElement(CanTupleType substType,
1380+
llvm::function_ref<void(AbstractionPattern origEltType,
1381+
CanType substEltType,
1382+
const TupleTypeElt &elt)> handleElement) const;
1383+
13491384
/// Is the given pack type a valid substitution of this abstraction
13501385
/// pattern?
13511386
bool matchesPack(CanPackType substType);
@@ -1420,13 +1455,20 @@ class AbstractionPattern {
14201455
/// the abstraction pattern for an element type.
14211456
AbstractionPattern getPackElementType(unsigned index) const;
14221457

1423-
/// Give that the value being abstracted is a pack expansion type, return the
1424-
/// underlying pattern type.
1458+
/// Given that the value being abstracted is a pack expansion type,
1459+
/// return the underlying pattern type.
1460+
///
1461+
/// If you're looking for getPackExpansionCountType(), it deliberately
1462+
/// does not exist. Count types are not lowered types, and the original
1463+
/// count types are not relevant to lowering. Only the substituted
1464+
/// components and expansion counts are significant.
14251465
AbstractionPattern getPackExpansionPatternType() const;
14261466

1427-
/// Give that the value being abstracted is a pack expansion type, return the
1428-
/// underlying count type.
1429-
AbstractionPattern getPackExpansionCountType() const;
1467+
/// Given that the value being abstracted is a pack expansion type,
1468+
/// return the appropriate pattern type for the given expansion
1469+
/// component.
1470+
AbstractionPattern getPackExpansionComponentType(CanType substType) const;
1471+
AbstractionPattern getPackExpansionComponentType(bool isExpansion) const;
14301472

14311473
/// Given that the value being abstracted is a function, return the
14321474
/// abstraction pattern for its result type.
@@ -1486,6 +1528,8 @@ class AbstractionPattern {
14861528
void forEachPackExpandedComponent(
14871529
llvm::function_ref<void(AbstractionPattern pattern)> fn) const;
14881530

1531+
size_t getNumPackExpandedComponents() const;
1532+
14891533
SmallVector<AbstractionPattern, 4> getPackExpandedComponents() const;
14901534

14911535
/// If this pattern refers to a foreign ObjC method that was imported as

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 130 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ LayoutConstraint AbstractionPattern::getLayoutConstraint() const {
282282
}
283283
}
284284

285-
bool AbstractionPattern::matchesTuple(CanTupleType substType) {
285+
bool AbstractionPattern::matchesTuple(CanTupleType substType) const {
286286
switch (getKind()) {
287287
case Kind::Invalid:
288288
llvm_unreachable("querying invalid abstraction pattern!");
@@ -311,26 +311,25 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) {
311311
LLVM_FALLTHROUGH;
312312
case Kind::Tuple: {
313313
size_t nextSubstIndex = 0;
314-
auto nextComponentIsAcceptable =
315-
[&](AbstractionPattern origComponentType) -> bool {
314+
auto nextComponentIsAcceptable = [&](bool isPackExpansion) -> bool {
316315
if (nextSubstIndex == substType->getNumElements())
317316
return false;
318317
auto substComponentType = substType.getElementType(nextSubstIndex++);
319-
return (origComponentType.isPackExpansion() ==
320-
isa<PackExpansionType>(substComponentType));
318+
return (isPackExpansion == isa<PackExpansionType>(substComponentType));
321319
};
322-
for (size_t i = 0, n = getNumTupleElements(); i != n; ++i) {
323-
auto elt = getTupleElementType(i);
324-
if (elt.isPackExpansion()) {
325-
bool fail = false;
326-
elt.forEachPackExpandedComponent([&](AbstractionPattern component) {
327-
if (!nextComponentIsAcceptable(component))
328-
fail = true;
329-
});
330-
if (fail) return false;
331-
} else {
332-
if (!nextComponentIsAcceptable(elt))
333-
return false;
320+
for (auto elt : getTupleElementTypes()) {
321+
bool isPackExpansion = elt.isPackExpansion();
322+
if (isPackExpansion && elt.GenericSubs) {
323+
auto origExpansion = cast<PackExpansionType>(elt.getType());
324+
auto substShape = cast<PackType>(
325+
origExpansion.getCountType().subst(elt.GenericSubs)
326+
->getCanonicalType());
327+
for (auto shapeElt : substShape.getElementTypes()) {
328+
if (!nextComponentIsAcceptable(isa<PackExpansionType>(shapeElt)))
329+
return false;
330+
}
331+
} else if (!nextComponentIsAcceptable(isPackExpansion)) {
332+
return false;
334333
}
335334
}
336335
return nextSubstIndex == substType->getNumElements();
@@ -469,6 +468,87 @@ bool AbstractionPattern::doesTupleContainPackExpansionType() const {
469468
llvm_unreachable("bad kind");
470469
}
471470

471+
void AbstractionPattern::forEachTupleElement(CanTupleType substType,
472+
llvm::function_ref<void(unsigned origEltIndex,
473+
unsigned substEltIndex,
474+
AbstractionPattern origEltType,
475+
CanType substEltType)>
476+
handleScalar,
477+
llvm::function_ref<void(unsigned origEltIndex,
478+
unsigned substEltIndex,
479+
AbstractionPattern origExpansionType,
480+
CanTupleEltTypeArrayRef substEltTypes)>
481+
handleExpansion) const {
482+
assert(isTuple() && "can only call on a tuple expansion");
483+
assert(matchesTuple(substType));
484+
485+
size_t substEltIndex = 0;
486+
auto substEltTypes = substType.getElementTypes();
487+
for (size_t origEltIndex : range(getNumTupleElements())) {
488+
auto origEltType = getTupleElementType(origEltIndex);
489+
if (!origEltType.isPackExpansion()) {
490+
handleScalar(origEltIndex, substEltIndex,
491+
origEltType, substEltTypes[substEltIndex]);
492+
substEltIndex++;
493+
} else {
494+
auto numComponents = origEltType.getNumPackExpandedComponents();
495+
handleExpansion(origEltIndex, substEltIndex, origEltType,
496+
substEltTypes.slice(substEltIndex, numComponents));
497+
substEltIndex += numComponents;
498+
}
499+
}
500+
assert(substEltIndex == substEltTypes.size());
501+
}
502+
503+
void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType,
504+
llvm::function_ref<void(AbstractionPattern origEltType,
505+
CanType substEltType,
506+
const TupleTypeElt &elt)>
507+
handleElement) const {
508+
assert(matchesTuple(substType));
509+
510+
auto substEltTypes = substType.getElementTypes();
511+
512+
// Handle opaque patterns by just iterating the substituted components.
513+
if (!isTuple()) {
514+
for (auto i : indices(substEltTypes)) {
515+
handleElement(getTupleElementType(i), substEltTypes[i],
516+
substType->getElement(i));
517+
}
518+
return;
519+
}
520+
521+
// For non-opaque patterns, we have to iterate the original components
522+
// in order to match things up properly, but we'll still end up calling
523+
// once per substituted element.
524+
size_t substEltIndex = 0;
525+
for (size_t origEltIndex : range(getNumTupleElements())) {
526+
auto origEltType = getTupleElementType(origEltIndex);
527+
if (!origEltType.isPackExpansion()) {
528+
handleElement(origEltType, substEltTypes[substEltIndex],
529+
substType->getElement(substEltIndex));
530+
substEltIndex++;
531+
} else {
532+
auto origPatternType = origEltType.getPackExpansionPatternType();
533+
for (auto i : range(origEltType.getNumPackExpandedComponents())) {
534+
(void) i;
535+
auto substEltType = substEltTypes[substEltIndex];
536+
// When the substituted type is a pack expansion, pass down
537+
// the original element type so that it's *also* a pack expansion.
538+
// Clients expect to look through this structure in parallel on
539+
// both types. The count is misleading, but normal usage won't
540+
// access it, and there's nothing we could provide that *wouldn't*
541+
// be misleading in one way or another.
542+
handleElement(isa<PackExpansionType>(substEltType)
543+
? origEltType : origPatternType,
544+
substEltType, substType->getElement(substEltIndex));
545+
substEltIndex++;
546+
}
547+
}
548+
}
549+
assert(substEltIndex == substEltTypes.size());
550+
}
551+
472552
static CanType getCanPackElementType(CanType type, unsigned index) {
473553
return cast<PackType>(type).getElementType(index);
474554
}
@@ -541,6 +621,17 @@ bool AbstractionPattern::matchesPack(CanPackType substType) {
541621
llvm_unreachable("bad kind");
542622
}
543623

624+
AbstractionPattern
625+
AbstractionPattern::getPackExpansionComponentType(CanType substType) const {
626+
return getPackExpansionComponentType(isa<PackExpansionType>(substType));
627+
}
628+
629+
AbstractionPattern
630+
AbstractionPattern::getPackExpansionComponentType(bool isExpansion) const {
631+
assert(isPackExpansion());
632+
return isExpansion ? *this : getPackExpansionPatternType();
633+
}
634+
544635
static CanType getPackExpansionPatternType(CanType type) {
545636
return cast<PackExpansionType>(type).getPatternType();
546637
}
@@ -584,49 +675,6 @@ AbstractionPattern AbstractionPattern::getPackExpansionPatternType() const {
584675
llvm_unreachable("bad kind");
585676
}
586677

587-
static CanType getPackExpansionCountType(CanType type) {
588-
return cast<PackExpansionType>(type).getCountType();
589-
}
590-
591-
AbstractionPattern AbstractionPattern::getPackExpansionCountType() const {
592-
switch (getKind()) {
593-
case Kind::Invalid:
594-
llvm_unreachable("querying invalid abstraction pattern!");
595-
case Kind::ObjCMethodType:
596-
case Kind::CurriedObjCMethodType:
597-
case Kind::PartialCurriedObjCMethodType:
598-
case Kind::CFunctionAsMethodType:
599-
case Kind::CurriedCFunctionAsMethodType:
600-
case Kind::PartialCurriedCFunctionAsMethodType:
601-
case Kind::CXXMethodType:
602-
case Kind::CurriedCXXMethodType:
603-
case Kind::PartialCurriedCXXMethodType:
604-
case Kind::Tuple:
605-
case Kind::OpaqueFunction:
606-
case Kind::OpaqueDerivativeFunction:
607-
case Kind::ObjCCompletionHandlerArgumentsType:
608-
case Kind::ClangType:
609-
llvm_unreachable("pattern for function or tuple cannot be for "
610-
"pack expansion type");
611-
612-
case Kind::Opaque:
613-
return *this;
614-
615-
case Kind::Type:
616-
if (isTypeParameterOrOpaqueArchetype())
617-
return AbstractionPattern::getOpaque();
618-
return AbstractionPattern(getGenericSubstitutions(),
619-
getGenericSignature(),
620-
::getPackExpansionCountType(getType()));
621-
622-
case Kind::Discard:
623-
return AbstractionPattern::getDiscard(
624-
getGenericSubstitutions(), getGenericSignature(),
625-
::getPackExpansionCountType(getType()));
626-
}
627-
llvm_unreachable("bad kind");
628-
}
629-
630678
SmallVector<AbstractionPattern, 4>
631679
AbstractionPattern::getPackExpandedComponents() const {
632680
SmallVector<AbstractionPattern, 4> result;
@@ -636,6 +684,21 @@ AbstractionPattern::getPackExpandedComponents() const {
636684
return result;
637685
}
638686

687+
size_t AbstractionPattern::getNumPackExpandedComponents() const {
688+
assert(isPackExpansion());
689+
assert(getKind() == Kind::Type || getKind() == Kind::Discard);
690+
691+
// If we don't have substitutions, we should be walking parallel
692+
// structure; take a single element.
693+
if (!GenericSubs) return 1;
694+
695+
// Otherwise, substitute the expansion shape.
696+
auto origExpansion = cast<PackExpansionType>(getType());
697+
auto substShape = cast<PackType>(
698+
origExpansion.getCountType().subst(GenericSubs)->getCanonicalType());
699+
return substShape->getNumElements();
700+
}
701+
639702
void AbstractionPattern::forEachPackExpandedComponent(
640703
llvm::function_ref<void (AbstractionPattern)> fn) const {
641704
assert(isPackExpansion());
@@ -665,7 +728,7 @@ void AbstractionPattern::forEachPackExpandedComponent(
665728
}
666729

667730
default:
668-
return fn(*this);
731+
llvm_unreachable("not a pack expansion");
669732
}
670733
llvm_unreachable("bad kind");
671734
}
@@ -692,8 +755,9 @@ AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const {
692755
llvm_unreachable("not handled yet");
693756
case Kind::Discard:
694757
llvm_unreachable("operation not needed on discarded abstractions yet");
695-
case Kind::Opaque:
696758
case Kind::Tuple:
759+
llvm_unreachable("cannot apply move-only wrappers to open-coded patterns");
760+
case Kind::Opaque:
697761
case Kind::Type:
698762
if (auto mvi = dyn_cast<SILMoveOnlyWrappedType>(getType())) {
699763
return AbstractionPattern(getGenericSubstitutions(),
@@ -727,8 +791,9 @@ AbstractionPattern AbstractionPattern::addingMoveOnlyWrapper() const {
727791
llvm_unreachable("not handled yet");
728792
case Kind::Discard:
729793
llvm_unreachable("operation not needed on discarded abstractions yet");
730-
case Kind::Opaque:
731794
case Kind::Tuple:
795+
llvm_unreachable("cannot add move only wrapper to open-coded pattern");
796+
case Kind::Opaque:
732797
case Kind::Type:
733798
if (isa<SILMoveOnlyWrappedType>(getType()))
734799
return *this;
@@ -1686,7 +1751,7 @@ AbstractionPattern::operator==(const AbstractionPattern &other) const {
16861751
}
16871752
}
16881753
return true;
1689-
1754+
16901755
case Kind::Type:
16911756
case Kind::Discard:
16921757
return OrigType == other.OrigType
@@ -1996,7 +2061,7 @@ class SubstFunctionTypePatternVisitor
19962061
auto substPatternType = visit(pack->getPatternType(),
19972062
pattern.getPackExpansionPatternType());
19982063
auto substCountType = visit(pack->getCountType(),
1999-
pattern.getPackExpansionCountType());
2064+
AbstractionPattern::getOpaque());
20002065

20012066
SmallVector<Type> rootParameterPacks;
20022067
substPatternType->getTypeParameterPacks(rootParameterPacks);

0 commit comments

Comments
 (0)