Skip to content

Commit 06a7950

Browse files
committed
[ConstraintSystem] Treat arithmetic SIMD operators like other generic
operators when partitioning an overload set.
1 parent ab594dd commit 06a7950

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ PROTOCOL(Error)
6969
PROTOCOL_(ErrorCodeProtocol)
7070
PROTOCOL(OptionSet)
7171
PROTOCOL(CaseIterable)
72+
PROTOCOL(SIMD)
7273
PROTOCOL(SIMDScalar)
7374
PROTOCOL(BinaryInteger)
7475
PROTOCOL(RangeReplaceableCollection)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5128,6 +5128,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
51285128
case KnownProtocolKind::Hashable:
51295129
case KnownProtocolKind::CaseIterable:
51305130
case KnownProtocolKind::Comparable:
5131+
case KnownProtocolKind::SIMD:
51315132
case KnownProtocolKind::SIMDScalar:
51325133
case KnownProtocolKind::BinaryInteger:
51335134
case KnownProtocolKind::ObjectiveCBridgeable:

lib/Sema/CSSolver.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,6 +1877,7 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
18771877
SmallVector<unsigned, 4> concreteOverloads;
18781878
SmallVector<unsigned, 4> numericOverloads;
18791879
SmallVector<unsigned, 4> sequenceOverloads;
1880+
SmallVector<unsigned, 4> simdOverloads;
18801881
SmallVector<unsigned, 4> otherGenericOverloads;
18811882

18821883
auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool {
@@ -1898,7 +1899,10 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
18981899
unsigned index = *iter;
18991900
auto *decl = Choices[index]->getOverloadChoice().getDecl();
19001901
auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl();
1901-
if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
1902+
1903+
if (isSIMDOperator(decl)) {
1904+
simdOverloads.push_back(index);
1905+
} else if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
19021906
concreteOverloads.push_back(index);
19031907
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::AdditiveArithmetic)) {
19041908
numericOverloads.push_back(index);
@@ -1955,11 +1959,20 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
19551959
sequenceOverloads.clear();
19561960
break;
19571961
}
1962+
1963+
if (TypeChecker::conformsToKnownProtocol(
1964+
argType, KnownProtocolKind::SIMD,
1965+
CS.DC->getParentModule())) {
1966+
first = std::copy(simdOverloads.begin(), simdOverloads.end(), first);
1967+
simdOverloads.clear();
1968+
break;
1969+
}
19581970
}
19591971

19601972
first = std::copy(otherGenericOverloads.begin(), otherGenericOverloads.end(), first);
19611973
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
19621974
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
1975+
first = std::copy(simdOverloads.begin(), simdOverloads.end(), first);
19631976
}
19641977

19651978
void DisjunctionChoiceProducer::partitionDisjunction(
@@ -2044,7 +2057,8 @@ void DisjunctionChoiceProducer::partitionDisjunction(
20442057
}
20452058

20462059
// Partition SIMD operators.
2047-
if (isOperatorDisjunction(Disjunction)) {
2060+
if (isOperatorDisjunction(Disjunction) &&
2061+
!Choices[0]->getOverloadChoice().getName().getBaseIdentifier().isArithmeticOperator()) {
20482062
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
20492063
if (isSIMDOperator(constraint->getOverloadChoice().getDecl())) {
20502064
simdOperators.push_back(index);

0 commit comments

Comments
 (0)