Skip to content

Commit 2b091f8

Browse files
authored
Merge pull request swiftlang#19962 from rudkx/refactor-disjunction-partitioning
[ConstraintSystem] Refactor disjunction partitioning
2 parents b1e66f8 + 5cfa61c commit 2b091f8

File tree

2 files changed

+97
-71
lines changed

2 files changed

+97
-71
lines changed

lib/Sema/CSSolver.cpp

Lines changed: 84 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,6 +1683,68 @@ getOperatorDesignatedNominalTypes(Constraint *bindOverload) {
16831683
return operatorDecl->getDesignatedNominalTypes();
16841684
}
16851685

1686+
void ConstraintSystem::partitionForDesignatedTypes(
1687+
ArrayRef<Constraint *> Choices, ConstraintMatchLoop forEachChoice,
1688+
PartitionAppendCallback appendPartition) {
1689+
1690+
auto designatedNominalTypes = getOperatorDesignatedNominalTypes(Choices[0]);
1691+
if (designatedNominalTypes.empty())
1692+
return;
1693+
1694+
SmallVector<SmallVector<unsigned, 4>, 4> definedInDesignatedType;
1695+
SmallVector<SmallVector<unsigned, 4>, 4> definedInExtensionOfDesignatedType;
1696+
1697+
auto examineConstraint =
1698+
[&](unsigned constraintIndex, Constraint *constraint) -> bool {
1699+
auto *decl = constraint->getOverloadChoice().getDecl();
1700+
auto *funcDecl = cast<FuncDecl>(decl);
1701+
1702+
auto *parentDC = funcDecl->getParent();
1703+
auto *parentDecl = parentDC->getAsDecl();
1704+
1705+
if (parentDC->isExtensionContext())
1706+
parentDecl = cast<ExtensionDecl>(parentDecl)->getExtendedNominal();
1707+
1708+
for (auto designatedTypeIndex : indices(designatedNominalTypes)) {
1709+
auto *designatedNominal =
1710+
designatedNominalTypes[designatedTypeIndex];
1711+
1712+
if (parentDecl != designatedNominal)
1713+
continue;
1714+
1715+
auto &constraints =
1716+
parentDC->isExtensionContext()
1717+
? definedInExtensionOfDesignatedType[designatedTypeIndex]
1718+
: definedInDesignatedType[designatedTypeIndex];
1719+
1720+
constraints.push_back(constraintIndex);
1721+
return true;
1722+
}
1723+
1724+
return false;
1725+
};
1726+
1727+
definedInDesignatedType.resize(designatedNominalTypes.size());
1728+
definedInExtensionOfDesignatedType.resize(designatedNominalTypes.size());
1729+
1730+
forEachChoice(Choices, examineConstraint);
1731+
1732+
// Now collect the overload choices that are defined within the type
1733+
// that was designated in the operator declaration.
1734+
// Add partitions for each of the overloads we found in types that
1735+
// were designated as part of the operator declaration.
1736+
for (auto designatedTypeIndex : indices(designatedNominalTypes)) {
1737+
if (designatedTypeIndex < definedInDesignatedType.size()) {
1738+
auto &primary = definedInDesignatedType[designatedTypeIndex];
1739+
appendPartition(primary);
1740+
}
1741+
if (designatedTypeIndex < definedInExtensionOfDesignatedType.size()) {
1742+
auto &secondary = definedInExtensionOfDesignatedType[designatedTypeIndex];
1743+
appendPartition(secondary);
1744+
}
1745+
}
1746+
}
1747+
16861748
void ConstraintSystem::partitionDisjunction(
16871749
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
16881750
SmallVectorImpl<unsigned> &PartitionBeginning) {
@@ -1702,20 +1764,14 @@ void ConstraintSystem::partitionDisjunction(
17021764
return;
17031765
}
17041766

1705-
SmallVector<unsigned, 4> disabled;
1706-
SmallVector<unsigned, 4> unavailable;
1707-
SmallVector<unsigned, 4> globalScope;
1708-
SmallVector<SmallVector<unsigned, 4>, 4> definedInDesignatedType;
1709-
SmallVector<SmallVector<unsigned, 4>, 4> definedInExtensionOfDesignatedType;
1710-
SmallVector<unsigned, 4> everythingElse;
17111767
SmallSet<Constraint *, 16> taken;
17121768

17131769
// Local function used to iterate over the untaken choices from the
17141770
// disjunction and use a higher-order function to determine if they
17151771
// should be part of a partition.
1716-
auto forEachChoice =
1772+
ConstraintMatchLoop forEachChoice =
17171773
[&](ArrayRef<Constraint *>,
1718-
llvm::function_ref<bool(unsigned index, Constraint *)> fn) {
1774+
std::function<bool(unsigned index, Constraint *)> fn) {
17191775
for (auto index : indices(Choices)) {
17201776
auto *constraint = Choices[index];
17211777
if (taken.count(constraint))
@@ -1729,6 +1785,13 @@ void ConstraintSystem::partitionDisjunction(
17291785
}
17301786
};
17311787

1788+
// First collect some things that we'll generally put near the end
1789+
// of the partitioning.
1790+
1791+
SmallVector<unsigned, 4> disabled;
1792+
SmallVector<unsigned, 4> unavailable;
1793+
SmallVector<unsigned, 4> globalScope;
1794+
17321795
// First collect disabled constraints.
17331796
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
17341797
if (!constraint->isDisabled())
@@ -1765,77 +1828,27 @@ void ConstraintSystem::partitionDisjunction(
17651828
return true;
17661829
});
17671830

1768-
// Now collect the overload choices that are defined within the type
1769-
// that was designated in the operator declaration.
1770-
auto designatedNominalTypes = getOperatorDesignatedNominalTypes(Choices[0]);
1771-
if (!designatedNominalTypes.empty()) {
1772-
forEachChoice(
1773-
Choices, [&](unsigned constraintIndex, Constraint *constraint) -> bool {
1774-
auto *decl = constraint->getOverloadChoice().getDecl();
1775-
auto *funcDecl = cast<FuncDecl>(decl);
1776-
1777-
auto *parentDecl = funcDecl->getParent()->getAsDecl();
1778-
for (auto designatedTypeIndex : indices(designatedNominalTypes)) {
1779-
auto *designatedNominal =
1780-
designatedNominalTypes[designatedTypeIndex];
1781-
if (parentDecl == designatedNominal) {
1782-
if (designatedTypeIndex >= definedInDesignatedType.size())
1783-
definedInDesignatedType.resize(designatedTypeIndex + 1);
1784-
auto &constraints = definedInDesignatedType[designatedTypeIndex];
1785-
constraints.push_back(constraintIndex);
1786-
return true;
1787-
}
1788-
1789-
if (auto *extensionDecl = dyn_cast<ExtensionDecl>(parentDecl)) {
1790-
parentDecl = extensionDecl->getExtendedNominal();
1791-
if (parentDecl == designatedNominal) {
1792-
if (designatedTypeIndex >=
1793-
definedInExtensionOfDesignatedType.size())
1794-
definedInExtensionOfDesignatedType.resize(
1795-
designatedTypeIndex + 1);
1796-
1797-
auto &constraints =
1798-
definedInExtensionOfDesignatedType[designatedTypeIndex];
1799-
constraints.push_back(constraintIndex);
1800-
return true;
1801-
}
1802-
}
1803-
}
1831+
// Local function to create the next partition based on the options
1832+
// passed in.
1833+
PartitionAppendCallback appendPartition =
1834+
[&](SmallVectorImpl<unsigned> &options) {
1835+
if (options.size()) {
1836+
PartitionBeginning.push_back(Ordering.size());
1837+
Ordering.insert(Ordering.end(), options.begin(), options.end());
1838+
}
1839+
};
18041840

1805-
return false;
1806-
});
1807-
}
1841+
partitionForDesignatedTypes(Choices, forEachChoice, appendPartition);
18081842

1843+
SmallVector<unsigned, 4> everythingElse;
18091844
// Gather the remaining options.
18101845
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
18111846
everythingElse.push_back(index);
18121847
return true;
18131848
});
1814-
1815-
// Local function to create the next partition based on the options
1816-
// passed in.
1817-
auto appendPartition = [&](SmallVectorImpl<unsigned> &options) {
1818-
if (options.size()) {
1819-
PartitionBeginning.push_back(Ordering.size());
1820-
Ordering.insert(Ordering.end(), options.begin(), options.end());
1821-
}
1822-
};
1823-
1824-
// Now create the partitioning based on what was collected.
1825-
1826-
// First we'll add partitions for each of the overloads we found in
1827-
// types that were designated as part of the operator declaration.
1828-
for (auto designatedTypeIndex : indices(designatedNominalTypes)) {
1829-
if (designatedTypeIndex < definedInDesignatedType.size()) {
1830-
auto &primary = definedInDesignatedType[designatedTypeIndex];
1831-
appendPartition(primary);
1832-
}
1833-
if (designatedTypeIndex < definedInExtensionOfDesignatedType.size()) {
1834-
auto &secondary = definedInExtensionOfDesignatedType[designatedTypeIndex];
1835-
appendPartition(secondary);
1836-
}
1837-
}
18381849
appendPartition(everythingElse);
1850+
1851+
// Now create the remaining partitions from what we previously collected.
18391852
appendPartition(globalScope);
18401853
appendPartition(unavailable);
18411854
appendPartition(disabled);

lib/Sema/ConstraintSystem.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3221,6 +3221,19 @@ class ConstraintSystem {
32213221

32223222
bool haveTypeInformationForAllArguments(FunctionType *fnType);
32233223

3224+
typedef std::function<bool(unsigned index, Constraint *)> ConstraintMatcher;
3225+
typedef std::function<void(ArrayRef<Constraint *>, ConstraintMatcher)>
3226+
ConstraintMatchLoop;
3227+
typedef std::function<void(SmallVectorImpl<unsigned> &options)>
3228+
PartitionAppendCallback;
3229+
3230+
// Partition the choices in a disjunction based on those that match
3231+
// the designated types for the operator that the disjunction was
3232+
// formed for.
3233+
void partitionForDesignatedTypes(ArrayRef<Constraint *> Choices,
3234+
ConstraintMatchLoop forEachChoice,
3235+
PartitionAppendCallback appendPartition);
3236+
32243237
// Partition the choices in the disjunction into groups that we will
32253238
// iterate over in an order appropriate to attempt to stop before we
32263239
// have to visit all of the options.

0 commit comments

Comments
 (0)