Skip to content

Commit 1ab417c

Browse files
committed
[Constraint system] Compute and use a common type among overloads.
Given an overload set, attempt to compute a "common type" that abstracts over all entries in the overload set, providing more structure for the constraint solver.
1 parent d31ef61 commit 1ab417c

File tree

4 files changed

+382
-7
lines changed

4 files changed

+382
-7
lines changed

lib/Sema/CSGen.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,10 +2523,8 @@ namespace {
25232523
return resultOfTypeOperation(typeOperation, expr->getArg());
25242524
}
25252525

2526-
if (isa<DeclRefExpr>(fnExpr)) {
2527-
if (auto fnType = CS.getType(fnExpr)->getAs<AnyFunctionType>()) {
2528-
outputTy = fnType->getResult();
2529-
}
2526+
if (auto fnType = CS.getType(fnExpr)->getAs<AnyFunctionType>()) {
2527+
outputTy = fnType->getResult();
25302528
} else if (auto OSR = dyn_cast<OverloadedDeclRefExpr>(fnExpr)) {
25312529
// Determine if the overloads are all functions that share a common
25322530
// return type.

lib/Sema/ConstraintSystem.cpp

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "TypeCheckType.h"
2222
#include "swift/AST/GenericEnvironment.h"
2323
#include "swift/AST/ParameterList.h"
24+
#include "swift/AST/TypeVisitor.h"
2425
#include "swift/Basic/Statistic.h"
2526
#include "llvm/ADT/SetVector.h"
2627
#include "llvm/ADT/SmallString.h"
@@ -1464,6 +1465,364 @@ static ArrayRef<OverloadChoice> partitionSIMDOperators(
14641465
return scratch;
14651466
}
14661467

1468+
/// Retrieve the type that will be used when matching the given overload.
1469+
static Type getEffectiveOverloadType(const OverloadChoice &overload) {
1470+
switch (overload.getKind()) {
1471+
case OverloadChoiceKind::Decl:
1472+
// Declaration choices are handled below.
1473+
break;
1474+
1475+
case OverloadChoiceKind::BaseType:
1476+
case OverloadChoiceKind::DeclViaBridge:
1477+
case OverloadChoiceKind::DeclViaDynamic:
1478+
case OverloadChoiceKind::DeclViaUnwrappedOptional:
1479+
case OverloadChoiceKind::DynamicMemberLookup:
1480+
case OverloadChoiceKind::KeyPathApplication:
1481+
case OverloadChoiceKind::TupleIndex:
1482+
return Type();
1483+
}
1484+
1485+
auto decl = overload.getDecl();
1486+
1487+
// Retrieve the interface type.
1488+
auto type = decl->getInterfaceType();
1489+
if (!type) {
1490+
decl->getASTContext().getLazyResolver()->resolveDeclSignature(decl);
1491+
type = decl->getInterfaceType();
1492+
if (!type) {
1493+
return Type();
1494+
}
1495+
}
1496+
1497+
// If we have a generic function type, drop the generic signature; we don't
1498+
// need it for this comparison.
1499+
if (auto genericFn = type->getAs<GenericFunctionType>()) {
1500+
type = FunctionType::get(genericFn->getParams(),
1501+
genericFn->getResult(),
1502+
genericFn->getExtInfo());
1503+
}
1504+
1505+
// If this declaration is within a type context, bail out.
1506+
if (decl->getDeclContext()->isTypeContext()) {
1507+
return Type();
1508+
}
1509+
1510+
return type;
1511+
}
1512+
1513+
namespace {
1514+
/// Type visitor that extracts the common type between two types, when
1515+
/// possible.
1516+
class CommonTypeVisitor : public TypeVisitor<CommonTypeVisitor, Type, Type> {
1517+
/// Perform a "leaf" match for types, which does not consider the children.
1518+
Type handleLeafMatch(Type type1, Type type2) {
1519+
if (type1->isEqual(type2))
1520+
return type1;
1521+
1522+
return handleMismatch(type1, type2);
1523+
}
1524+
1525+
/// Handle a mismatch between two types.
1526+
Type handleMismatch(Type type1, Type type2) {
1527+
return Type();
1528+
}
1529+
1530+
public:
1531+
Type visitTupleType(TupleType *tuple1, Type type2) {
1532+
if (tuple1->isEqual(type2))
1533+
return Type(tuple1);
1534+
1535+
auto tuple2 = type2->getAs<TupleType>();
1536+
if (!tuple2) {
1537+
return handleMismatch(Type(tuple1), type2);
1538+
}
1539+
1540+
// Check for structural similarity between the two tuple types.
1541+
auto elements1 = tuple1->getElements();
1542+
auto elements2 = tuple2->getElements();
1543+
if (elements1.size() != elements2.size()) {
1544+
return handleMismatch(Type(tuple1), type2);
1545+
}
1546+
1547+
for (unsigned i : indices(elements1)) {
1548+
const auto &elt1 = elements1[i];
1549+
const auto &elt2 = elements2[i];
1550+
if (elt1.getName() != elt2.getName() ||
1551+
elt1.getParameterFlags() != elt2.getParameterFlags()) {
1552+
return handleMismatch(Type(tuple1), type2);
1553+
}
1554+
}
1555+
1556+
// Recurse on the element types.
1557+
SmallVector<TupleTypeElt, 4> newElements;
1558+
newElements.reserve(elements1.size());
1559+
for (unsigned i : indices(elements1)) {
1560+
const auto &elt1 = elements1[i];
1561+
const auto &elt2 = elements2[i];
1562+
Type elementType = visit(elt1.getRawType(), elt2.getRawType());
1563+
if (!elementType) {
1564+
return handleMismatch(Type(tuple1), type2);
1565+
}
1566+
1567+
newElements.push_back(elt1.getWithType(elementType));
1568+
}
1569+
return TupleType::get(newElements, tuple1->getASTContext());
1570+
}
1571+
1572+
Type visitReferenceStorageType(ReferenceStorageType *refStorage1,
1573+
Type type2) {
1574+
if (refStorage1->isEqual(type2))
1575+
return Type(refStorage1);
1576+
1577+
auto refStorage2 = type2->getAs<ReferenceStorageType>();
1578+
if (!refStorage2 ||
1579+
refStorage1->getOwnership() != refStorage2->getOwnership()) {
1580+
return handleMismatch(Type(refStorage1), type2);
1581+
}
1582+
1583+
Type newReferentType = visit(refStorage1->getReferentType(),
1584+
refStorage2->getReferentType());
1585+
if (!newReferentType) {
1586+
return handleMismatch(Type(refStorage1), type2);
1587+
}
1588+
1589+
return ReferenceStorageType::get(newReferentType,
1590+
refStorage1->getOwnership(),
1591+
refStorage1->getASTContext());
1592+
}
1593+
1594+
Type visitAnyMetatypeType(AnyMetatypeType *metatype1, Type type2) {
1595+
if (metatype1->isEqual(type2))
1596+
return Type(metatype1);
1597+
1598+
1599+
auto metatype2 = type2->getAs<AnyMetatypeType>();
1600+
if (!metatype2) {
1601+
return handleMismatch(Type(metatype1), type2);
1602+
}
1603+
1604+
if (metatype1->getKind() != metatype2->getKind() ||
1605+
metatype1->hasRepresentation() != metatype2->hasRepresentation() ||
1606+
(metatype1->hasRepresentation() &&
1607+
metatype2->getRepresentation() != metatype2->getRepresentation())) {
1608+
return handleMismatch(Type(metatype1), type2);
1609+
}
1610+
1611+
Type newInstanceType = visit(metatype1->getInstanceType(),
1612+
metatype2->getInstanceType());
1613+
if (!newInstanceType) {
1614+
return handleMismatch(Type(metatype1), type2);
1615+
}
1616+
1617+
Optional<MetatypeRepresentation> representation;
1618+
if (metatype1->hasRepresentation())
1619+
representation = metatype1->getRepresentation();
1620+
1621+
if (metatype1->getKind() == TypeKind::Metatype)
1622+
return MetatypeType::get(newInstanceType, representation);
1623+
1624+
assert(metatype1->getKind() == TypeKind::ExistentialMetatype);
1625+
return ExistentialMetatypeType::get(newInstanceType, representation);
1626+
}
1627+
1628+
Type visitFunctionType(FunctionType *function1, Type type2) {
1629+
if (function1->isEqual(type2))
1630+
return Type(function1);
1631+
1632+
auto function2 = type2->getAs<FunctionType>();
1633+
if (!function2 ||
1634+
function1->getExtInfo() != function2->getExtInfo() ||
1635+
function1->getNumParams() != function2->getNumParams()) {
1636+
return handleMismatch(Type(function1), type2);
1637+
}
1638+
1639+
// Check for a structural match between the parameters.
1640+
auto params1 = function1->getParams();
1641+
auto params2 = function2->getParams();
1642+
for (unsigned i : indices(params1)) {
1643+
const auto &param1 = params1[i];
1644+
const auto &param2 = params2[i];
1645+
if (param1.getLabel() != param2.getLabel() ||
1646+
param1.getParameterFlags() != param2.getParameterFlags()) {
1647+
return handleMismatch(Type(function1), type2);
1648+
}
1649+
}
1650+
1651+
Type newResultType = visit(function1->getResult(), function2->getResult());
1652+
if (!newResultType) {
1653+
return handleMismatch(Type(function1), type2);
1654+
}
1655+
1656+
SmallVector<AnyFunctionType::Param, 4> newParams;
1657+
newParams.reserve(params1.size());
1658+
for (unsigned i : indices(params1)) {
1659+
const auto &param1 = params1[i];
1660+
const auto &param2 = params2[i];
1661+
Type newParamType = visit(param1.getPlainType(), param2.getPlainType());
1662+
if (!newParamType) {
1663+
return handleMismatch(Type(function1), type2);
1664+
}
1665+
1666+
newParams.push_back(AnyFunctionType::Param(newParamType,
1667+
param1.getLabel(),
1668+
param1.getParameterFlags()));
1669+
}
1670+
1671+
return FunctionType::get(newParams, newResultType, function1->getExtInfo());
1672+
}
1673+
1674+
Type visitGenericFunctionType(GenericFunctionType *function1, Type type2) {
1675+
llvm_unreachable("Caller should have eliminated these");
1676+
}
1677+
1678+
Type visitLValueType(LValueType *lvalue1, Type type2) {
1679+
if (lvalue1->isEqual(type2))
1680+
return Type(lvalue1);
1681+
1682+
auto lvalue2 = type2->getAs<LValueType>();
1683+
if (!lvalue2) {
1684+
return handleMismatch(Type(lvalue1), type2);
1685+
}
1686+
1687+
Type newObjectType =
1688+
visit(lvalue1->getObjectType(), lvalue2->getObjectType());
1689+
if (!newObjectType) {
1690+
return handleMismatch(Type(lvalue1), type2);
1691+
}
1692+
1693+
return LValueType::get(newObjectType);
1694+
}
1695+
1696+
Type visitInOutType(InOutType *inout1, Type type2) {
1697+
if (inout1->isEqual(type2))
1698+
return Type(inout1);
1699+
1700+
auto inout2 = type2->getAs<InOutType>();
1701+
if (!inout2) {
1702+
return handleMismatch(Type(inout1), type2);
1703+
}
1704+
1705+
Type newObjectType =
1706+
visit(inout1->getObjectType(), inout2->getObjectType());
1707+
if (!newObjectType) {
1708+
return handleMismatch(Type(inout1), type2);
1709+
}
1710+
1711+
return LValueType::get(newObjectType);
1712+
}
1713+
1714+
Type visitSugarType(SugarType *sugar1, Type type2) {
1715+
if (sugar1->isEqual(type2))
1716+
return Type(sugar1);
1717+
1718+
// FIXME: Reconstitute sugar.
1719+
return visit(Type(sugar1->getSinglyDesugaredType()), type2);
1720+
}
1721+
1722+
#define FAILURE_CASE(Class) \
1723+
Type visit##Class##Type(Class##Type *type1, Type type2) { \
1724+
return Type(); \
1725+
}
1726+
1727+
#define LEAF_CASE(Class) \
1728+
Type visit##Class##Type(Class##Type *type1, Type type2) { \
1729+
return handleLeafMatch(Type(type1), type2); \
1730+
}
1731+
1732+
FAILURE_CASE(Error)
1733+
FAILURE_CASE(Unresolved)
1734+
LEAF_CASE(Builtin)
1735+
LEAF_CASE(Nominal) // FIXME: We can do a more specific match here.
1736+
LEAF_CASE(BoundGeneric) // FIXME: We can do a more specific match here.
1737+
FAILURE_CASE(UnboundGeneric)
1738+
LEAF_CASE(Module)
1739+
LEAF_CASE(DynamicSelf) // FIXME: Can we do better here?
1740+
LEAF_CASE(Substitutable)
1741+
LEAF_CASE(DependentMember)
1742+
LEAF_CASE(SILFunction)
1743+
LEAF_CASE(SILBlockStorage)
1744+
LEAF_CASE(SILBox)
1745+
LEAF_CASE(SILToken)
1746+
LEAF_CASE(ProtocolComposition)
1747+
LEAF_CASE(TypeVariable) // FIXME: Could do better here when we create vars
1748+
1749+
#undef LEAF_CASE
1750+
#undef FAILURE_CASE
1751+
};
1752+
1753+
}
1754+
1755+
Type ConstraintSystem::findCommonOverloadType(
1756+
ArrayRef<OverloadChoice> choices,
1757+
ArrayRef<OverloadChoice> outerAlternatives,
1758+
ConstraintLocator *locator) {
1759+
// Local function to consider this s new overload choice, updating the
1760+
// "common type". Returns true if this overload cannot be integrated into
1761+
// the common type, at which point there is no "common type".
1762+
Type commonType;
1763+
auto considerOverload = [&](const OverloadChoice &overload) -> bool {
1764+
// If we can't even get a type for the overload, there's nothing more to
1765+
// do.
1766+
Type overloadType = getEffectiveOverloadType(overload);
1767+
if (!overloadType) {
1768+
return true;
1769+
}
1770+
1771+
// If this is the first overload, record it's type as the common type.
1772+
if (!commonType) {
1773+
commonType = overloadType;
1774+
return false;
1775+
}
1776+
1777+
// Find the common type between the current common type and the new
1778+
// overload's type.
1779+
commonType = CommonTypeVisitor().visit(commonType, overloadType);
1780+
if (!commonType) {
1781+
return true;
1782+
}
1783+
1784+
return false;
1785+
};
1786+
1787+
// Consider all of the choices and outer alternatives.
1788+
for (const auto &choice : choices) {
1789+
if (considerOverload(choice))
1790+
return Type();
1791+
}
1792+
for (const auto &choice : outerAlternatives) {
1793+
if (considerOverload(choice))
1794+
return Type();
1795+
}
1796+
1797+
assert(commonType && "We can't get here without having a common type");
1798+
1799+
// If our common type contains any generic parameters, open them up into
1800+
// type variables.
1801+
if (commonType->hasTypeParameter()) {
1802+
llvm::SmallDenseMap<const GenericTypeParamType *, TypeVariableType *>
1803+
openedGenericParams;
1804+
commonType = commonType.transformRec([&](TypeBase *type) -> Optional<Type> {
1805+
if (auto genericParam = dyn_cast<GenericTypeParamType>(type)) {
1806+
auto canGenericParam = GenericTypeParamType::get(
1807+
genericParam->getDepth(),
1808+
genericParam->getIndex(),
1809+
type->getASTContext());
1810+
auto knownTypeVar = openedGenericParams.find(canGenericParam);
1811+
if (knownTypeVar != openedGenericParams.end())
1812+
return Type(knownTypeVar->second);
1813+
1814+
auto typeVar = createTypeVariable(locator);
1815+
openedGenericParams[canGenericParam] = typeVar;
1816+
return Type(typeVar);
1817+
}
1818+
1819+
return None;
1820+
});
1821+
}
1822+
1823+
return commonType;
1824+
}
1825+
14671826
void ConstraintSystem::addOverloadSet(Type boundType,
14681827
ArrayRef<OverloadChoice> choices,
14691828
DeclContext *useDC,
@@ -1478,6 +1837,13 @@ void ConstraintSystem::addOverloadSet(Type boundType,
14781837
return;
14791838
}
14801839

1840+
// If we can compute a common type for the overload set, bind that type.
1841+
if (Type commonType = findCommonOverloadType(choices, outerAlternatives,
1842+
locator)) {
1843+
addConstraint(ConstraintKind::Bind, boundType, commonType, locator);
1844+
boundType = commonType;
1845+
}
1846+
14811847
tryOptimizeGenericDisjunction(*this, choices, favoredChoice);
14821848

14831849
SmallVector<OverloadChoice, 4> scratchChoices;

0 commit comments

Comments
 (0)