Skip to content

Commit df1c3b7

Browse files
committed
Simplify select_enum forwarding instruction
Remove OwnershipForwardingSelectEnumInstBase, inherit SelectEnumInst from OwnershipForwardingSingleValueInstruction instead.
1 parent c8001d8 commit df1c3b7

File tree

4 files changed

+114
-204
lines changed

4 files changed

+114
-204
lines changed

include/swift/SIL/SILInstruction.h

Lines changed: 94 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -6479,52 +6479,11 @@ class UncheckedTakeEnumDataAddrInst
64796479
}
64806480
};
64816481

6482-
// Abstract base class of all select instructions like select_enum.
6483-
// The template parameter represents a type of case values
6484-
// to be compared with the operand of a select instruction.
6485-
//
6486-
// Subclasses must provide tail allocated storage.
6487-
// The first operand is the operand of select_xxx instruction. The rest of
6488-
// the operands are the case values and results of a select instruction.
6489-
template <class Derived, class T, class Base = SingleValueInstruction>
6490-
class SelectInstBase : public Base {
6491-
public:
6492-
template <typename... Args>
6493-
SelectInstBase(SILInstructionKind kind, SILDebugLocation Loc, SILType type,
6494-
Args &&... otherArgs)
6495-
: Base(kind, Loc, type, std::forward<Args>(otherArgs)...) {}
6496-
6497-
SILValue getOperand() const { return getAllOperands()[0].get(); }
6498-
6499-
ArrayRef<Operand> getAllOperands() const {
6500-
return static_cast<const Derived *>(this)->getAllOperands();
6501-
}
6502-
MutableArrayRef<Operand> getAllOperands() {
6503-
return static_cast<Derived *>(this)->getAllOperands();
6504-
}
6505-
6506-
std::pair<T, SILValue> getCase(unsigned i) const {
6507-
return static_cast<const Derived *>(this)->getCase(i);
6508-
}
6509-
6510-
unsigned getNumCases() const {
6511-
return static_cast<const Derived *>(this)->getNumCases();
6512-
}
6513-
6514-
bool hasDefault() const {
6515-
return static_cast<const Derived *>(this)->hasDefault();
6516-
}
6517-
6518-
SILValue getDefaultResult() const {
6519-
return static_cast<const Derived *>(this)->getDefaultResult();
6520-
}
6521-
};
6522-
65236482
/// Common base class for the select_enum and select_enum_addr instructions,
65246483
/// which select one of a set of possible results based on the case of an enum.
6525-
class SelectEnumInstBase
6526-
: public SelectInstBase<SelectEnumInstBase, EnumElementDecl *> {
6527-
USE_SHARED_UINT8;
6484+
template <typename DerivedTy, typename BaseTy>
6485+
class SelectEnumInstBase : public BaseTy {
6486+
TEMPLATE_USE_SHARED_UINT8(BaseTy);
65286487

65296488
// Tail-allocated after the operands is an array of `NumCases`
65306489
// EnumElementDecl* pointers, referencing the case discriminators for each
@@ -6536,20 +6495,22 @@ class SelectEnumInstBase
65366495
}
65376496

65386497
protected:
6498+
template <typename... Rest>
65396499
SelectEnumInstBase(SILInstructionKind kind, SILDebugLocation debugLoc,
65406500
SILType type, bool defaultValue,
65416501
llvm::Optional<ArrayRef<ProfileCounter>> CaseCounts,
6542-
ProfileCounter DefaultCount)
6543-
: SelectInstBase(kind, debugLoc, type) {
6502+
ProfileCounter DefaultCount, Rest &&...rest)
6503+
: BaseTy(kind, debugLoc, type, std::forward<Rest>(rest)...) {
65446504
sharedUInt8().SelectEnumInstBase.hasDefault = defaultValue;
65456505
}
6546-
template <typename SELECT_ENUM_INST>
6547-
static SELECT_ENUM_INST *createSelectEnum(
6548-
SILDebugLocation DebugLoc, SILValue Enum, SILType Type,
6549-
SILValue DefaultValue,
6550-
ArrayRef<std::pair<EnumElementDecl *, SILValue>> CaseValues, SILModule &M,
6551-
llvm::Optional<ArrayRef<ProfileCounter>> CaseCounts,
6552-
ProfileCounter DefaultCount, ValueOwnershipKind forwardingOwnershipKind);
6506+
template <typename... RestTys>
6507+
static DerivedTy *
6508+
createSelectEnum(SILDebugLocation DebugLoc, SILValue Enum, SILType Type,
6509+
SILValue DefaultValue,
6510+
ArrayRef<std::pair<EnumElementDecl *, SILValue>> CaseValues,
6511+
SILModule &M,
6512+
llvm::Optional<ArrayRef<ProfileCounter>> CaseCounts,
6513+
ProfileCounter DefaultCount, RestTys &&...restArgs);
65536514

65546515
public:
65556516
ArrayRef<Operand> getAllOperands() const;
@@ -6576,9 +6537,6 @@ class SelectEnumInstBase
65766537
// didn't find anything.
65776538
return getDefaultResult();
65786539
}
6579-
6580-
/// If the default refers to exactly one case decl, return it.
6581-
NullablePtr<EnumElementDecl> getUniqueCaseForDefault();
65826540

65836541
bool hasDefault() const {
65846542
return sharedUInt8().SelectEnumInstBase.hasDefault;
@@ -6593,58 +6551,82 @@ class SelectEnumInstBase
65936551
return getAllOperands().size() - 1 - hasDefault();
65946552
}
65956553

6554+
/// If the default refers to exactly one case decl, return it.
6555+
NullablePtr<EnumElementDecl> getUniqueCaseForDefault() {
6556+
assert(this->hasDefault() && "doesn't have a default");
6557+
auto enumValue = getEnumOperand();
6558+
SILType enumType = enumValue->getType();
6559+
6560+
EnumDecl *decl = enumType.getEnumOrBoundGenericEnum();
6561+
assert(decl && "switch_enum operand is not an enum");
6562+
6563+
if (!enumType.isEffectivelyExhaustiveEnumType(this->getFunction())) {
6564+
return nullptr;
6565+
}
6566+
6567+
llvm::SmallPtrSet<EnumElementDecl *, 4> unswitchedElts;
6568+
for (auto elt : decl->getAllElements())
6569+
unswitchedElts.insert(elt);
6570+
6571+
for (unsigned i = 0, e = this->getNumCases(); i != e; ++i) {
6572+
auto Entry = this->getCase(i);
6573+
unswitchedElts.erase(Entry.first);
6574+
}
6575+
6576+
if (unswitchedElts.size() == 1)
6577+
return *unswitchedElts.begin();
6578+
6579+
return nullptr;
6580+
}
6581+
65966582
/// If there is a single case that returns a literal "true" value (an
65976583
/// "integer_literal $Builtin.Int1, 1" value), return it.
65986584
///
65996585
/// FIXME: This is used to interoperate with passes that reasoned about the
66006586
/// old enum_is_tag insn. Ideally those passes would become general enough
66016587
/// not to need this.
6602-
NullablePtr<EnumElementDecl> getSingleTrueElement() const;
6603-
};
6604-
6605-
/// A select enum inst that produces a static OwnershipKind.
6606-
class OwnershipForwardingSelectEnumInstBase : public SelectEnumInstBase,
6607-
public ForwardingInstruction {
6608-
protected:
6609-
OwnershipForwardingSelectEnumInstBase(
6610-
SILInstructionKind kind, SILDebugLocation debugLoc, SILType type,
6611-
bool defaultValue, llvm::Optional<ArrayRef<ProfileCounter>> caseCounts,
6612-
ProfileCounter defaultCount, ValueOwnershipKind ownershipKind)
6613-
: SelectEnumInstBase(kind, debugLoc, type, defaultValue, caseCounts,
6614-
defaultCount),
6615-
ForwardingInstruction(kind, ownershipKind) {
6616-
assert(classof(kind) && "classof missing subclass");
6617-
}
6618-
6619-
public:
6620-
static bool classof(SILNodePointer node) {
6621-
if (auto *i = dyn_cast<SILInstruction>(node.get()))
6622-
return classof(i);
6623-
return false;
6624-
}
6625-
6626-
static bool classof(const SILInstruction *i) { return classof(i->getKind()); }
6588+
NullablePtr<EnumElementDecl> getSingleTrueElement() const {
6589+
auto SEIType = static_cast<const DerivedTy *>(this)
6590+
->getType()
6591+
.template getAs<BuiltinIntegerType>();
6592+
if (!SEIType)
6593+
return nullptr;
6594+
if (SEIType->getWidth() != BuiltinIntegerWidth::fixed(1))
6595+
return nullptr;
66276596

6628-
static bool classof(SILInstructionKind kind) {
6629-
switch (kind) {
6630-
case SILInstructionKind::SelectEnumInst:
6631-
return true;
6632-
default:
6633-
return false;
6597+
// Try to find a single literal "true" case.
6598+
llvm::Optional<EnumElementDecl *> TrueElement;
6599+
for (unsigned i = 0, e = getNumCases(); i < e; ++i) {
6600+
auto casePair = getCase(i);
6601+
if (auto intLit = dyn_cast<IntegerLiteralInst>(casePair.second)) {
6602+
if (intLit->getValue() == APInt(1, 1)) {
6603+
if (!TrueElement)
6604+
TrueElement = casePair.first;
6605+
else
6606+
// Use Optional(nullptr) to represent more than one.
6607+
TrueElement = llvm::Optional<EnumElementDecl *>(nullptr);
6608+
}
6609+
}
66346610
}
6611+
6612+
if (!TrueElement || !*TrueElement)
6613+
return nullptr;
6614+
return *TrueElement;
66356615
}
66366616
};
66376617

66386618
/// Select one of a set of values based on the case of an enum.
66396619
class SelectEnumInst final
66406620
: public InstructionBaseWithTrailingOperands<
66416621
SILInstructionKind::SelectEnumInst, SelectEnumInst,
6642-
OwnershipForwardingSelectEnumInstBase, EnumElementDecl *> {
6622+
SelectEnumInstBase<SelectEnumInst,
6623+
OwnershipForwardingSingleValueInstruction>,
6624+
EnumElementDecl *> {
66436625
friend SILBuilder;
6626+
friend SelectEnumInstBase<SelectEnumInst,
6627+
OwnershipForwardingSingleValueInstruction>;
66446628

6645-
private:
6646-
friend SelectEnumInstBase;
6647-
6629+
public:
66486630
SelectEnumInst(SILDebugLocation DebugLoc, SILValue Operand, SILType Type,
66496631
bool DefaultValue, ArrayRef<SILValue> CaseValues,
66506632
ArrayRef<EnumElementDecl *> CaseDecls,
@@ -6671,20 +6653,20 @@ class SelectEnumInst final
66716653
class SelectEnumAddrInst final
66726654
: public InstructionBaseWithTrailingOperands<
66736655
SILInstructionKind::SelectEnumAddrInst, SelectEnumAddrInst,
6674-
SelectEnumInstBase, EnumElementDecl *> {
6656+
SelectEnumInstBase<SelectEnumAddrInst, SingleValueInstruction>,
6657+
EnumElementDecl *> {
66756658
friend SILBuilder;
6676-
friend SelectEnumInstBase;
6659+
friend SelectEnumInstBase<SelectEnumAddrInst, SingleValueInstruction>;
66776660

6661+
public:
66786662
SelectEnumAddrInst(SILDebugLocation DebugLoc, SILValue Operand, SILType Type,
66796663
bool DefaultValue, ArrayRef<SILValue> CaseValues,
66806664
ArrayRef<EnumElementDecl *> CaseDecls,
66816665
llvm::Optional<ArrayRef<ProfileCounter>> CaseCounts,
6682-
ProfileCounter DefaultCount,
6683-
ValueOwnershipKind forwardingOwnershipKind)
6666+
ProfileCounter DefaultCount)
66846667
: InstructionBaseWithTrailingOperands(Operand, CaseValues, DebugLoc, Type,
66856668
bool(DefaultValue), CaseCounts,
66866669
DefaultCount) {
6687-
(void)forwardingOwnershipKind;
66886670
assert(CaseValues.size() - DefaultValue == CaseDecls.size());
66896671
std::uninitialized_copy(CaseDecls.begin(), CaseDecls.end(),
66906672
getTrailingObjects<EnumElementDecl *>());
@@ -10528,6 +10510,7 @@ OwnershipForwardingSingleValueInstruction::classof(SILInstructionKind kind) {
1052810510
case SILInstructionKind::BridgeObjectToRefInst:
1052910511
case SILInstructionKind::ThinToThickFunctionInst:
1053010512
case SILInstructionKind::UnconditionalCheckedCastInst:
10513+
case SILInstructionKind::SelectEnumInst:
1053110514
return true;
1053210515
default:
1053310516
return false;
@@ -10755,31 +10738,25 @@ inline MutableArrayRef<Operand> AllocRefInstBase::getAllOperands() {
1075510738
llvm_unreachable("Unhandled AllocRefInstBase subclass");
1075610739
}
1075710740

10758-
inline ArrayRef<Operand> SelectEnumInstBase::getAllOperands() const {
10759-
// If the size of the subclasses are equal, then all of this compiles away.
10760-
if (auto I = dyn_cast<SelectEnumInst>(this))
10761-
return I->getAllOperands();
10762-
if (auto I = dyn_cast<SelectEnumAddrInst>(this))
10763-
return I->getAllOperands();
10764-
llvm_unreachable("Unhandled SelectEnumInstBase subclass");
10741+
template <typename DerivedTy, typename BaseTy>
10742+
inline ArrayRef<Operand>
10743+
SelectEnumInstBase<DerivedTy, BaseTy>::getAllOperands() const {
10744+
const auto &I = static_cast<const DerivedTy &>(*this);
10745+
return I.getAllOperands();
1076510746
}
1076610747

10767-
inline MutableArrayRef<Operand> SelectEnumInstBase::getAllOperands() {
10768-
// If the size of the subclasses are equal, then all of this compiles away.
10769-
if (auto I = dyn_cast<SelectEnumInst>(this))
10770-
return I->getAllOperands();
10771-
if (auto I = dyn_cast<SelectEnumAddrInst>(this))
10772-
return I->getAllOperands();
10773-
llvm_unreachable("Unhandled SelectEnumInstBase subclass");
10748+
template <typename DerivedTy, typename BaseTy>
10749+
inline MutableArrayRef<Operand>
10750+
SelectEnumInstBase<DerivedTy, BaseTy>::getAllOperands() {
10751+
auto &I = static_cast<DerivedTy &>(*this);
10752+
return I.getAllOperands();
1077410753
}
1077510754

10776-
inline EnumElementDecl **SelectEnumInstBase::getEnumElementDeclStorage() {
10777-
// If the size of the subclasses are equal, then all of this compiles away.
10778-
if (auto I = dyn_cast<SelectEnumInst>(this))
10779-
return I->getTrailingObjects<EnumElementDecl*>();
10780-
if (auto I = dyn_cast<SelectEnumAddrInst>(this))
10781-
return I->getTrailingObjects<EnumElementDecl*>();
10782-
llvm_unreachable("Unhandled SelectEnumInstBase subclass");
10755+
template <typename DerivedTy, typename BaseTy>
10756+
inline EnumElementDecl **
10757+
SelectEnumInstBase<DerivedTy, BaseTy>::getEnumElementDeclStorage() {
10758+
auto &I = static_cast<DerivedTy &>(*this);
10759+
return I.template getTrailingObjects<EnumElementDecl *>();
1078310760
}
1078410761

1078510762
inline void SILSuccessor::pred_iterator::cacheBasicBlock() {
@@ -10799,7 +10776,6 @@ inline bool Operand::isTypeDependent() const {
1079910776
inline bool ForwardingInstruction::isa(SILInstructionKind kind) {
1080010777
return OwnershipForwardingSingleValueInstruction::classof(kind) ||
1080110778
OwnershipForwardingTermInst::classof(kind) ||
10802-
OwnershipForwardingSelectEnumInstBase::classof(kind) ||
1080310779
OwnershipForwardingMultipleValueInstruction::classof(kind);
1080410780
}
1080510781

@@ -10812,8 +10788,6 @@ inline ForwardingInstruction *ForwardingInstruction::get(SILInstruction *inst) {
1081210788
return result;
1081310789
if (auto *result = dyn_cast<OwnershipForwardingTermInst>(inst))
1081410790
return result;
10815-
if (auto *result = dyn_cast<OwnershipForwardingSelectEnumInstBase>(inst))
10816-
return result;
1081710791
if (auto *result =
1081810792
dyn_cast<OwnershipForwardingMultipleValueInstruction>(inst))
1081910793
return result;

include/swift/SIL/SILNode.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,13 @@ class alignas(8) SILNode :
170170
#define SHARED_TEMPLATE_FIELD(T, I, ...) \
171171
class { template <T> friend class I; __VA_ARGS__; } I
172172

173+
#define SHARED_TEMPLATE2_FIELD(T1, T2, I, ...) \
174+
class { \
175+
template <T1, T2> \
176+
friend class I; \
177+
__VA_ARGS__; \
178+
} I
179+
173180
/// Special case for `InstructionBaseWithTrailingOperands`.
174181
#define SHARED_TEMPLATE4_FIELD(T1, T2, T3, T4, I, ...) \
175182
class { template <T1, T2, T3, T4> friend class I; __VA_ARGS__; } I
@@ -179,6 +186,7 @@ class alignas(8) SILNode :
179186
uint8_t opaque;
180187

181188
SHARED_TEMPLATE_FIELD(typename, SwitchEnumInstBase, bool hasDefault);
189+
SHARED_TEMPLATE2_FIELD(typename, typename, SelectEnumInstBase, bool hasDefault);
182190
SHARED_TEMPLATE_FIELD(SILInstructionKind, LoadReferenceInstBase, bool isTake);
183191
SHARED_TEMPLATE_FIELD(SILInstructionKind, StoreReferenceInstBase, bool isInitializationOfDest);
184192
SHARED_FIELD(MultipleValueInstructionResult, uint8_t valueOwnershipKind);
@@ -189,7 +197,6 @@ class alignas(8) SILNode :
189197
SHARED_FIELD(AssignByWrapperInst, uint8_t mode);
190198
SHARED_FIELD(AssignOrInitInst, uint8_t mode);
191199
SHARED_FIELD(StringLiteralInst, uint8_t encoding);
192-
SHARED_FIELD(SelectEnumInstBase, bool hasDefault);
193200
SHARED_FIELD(SwitchValueInst, bool hasDefault);
194201
SHARED_FIELD(RefCountingInst, bool atomicity);
195202
SHARED_FIELD(EndAccessInst, bool aborting);

0 commit comments

Comments
 (0)