21
21
#include " TypeCheckType.h"
22
22
#include " swift/AST/GenericEnvironment.h"
23
23
#include " swift/AST/ParameterList.h"
24
+ #include " swift/AST/TypeVisitor.h"
24
25
#include " swift/Basic/Statistic.h"
25
26
#include " llvm/ADT/SetVector.h"
26
27
#include " llvm/ADT/SmallString.h"
@@ -1464,6 +1465,364 @@ static ArrayRef<OverloadChoice> partitionSIMDOperators(
1464
1465
return scratch;
1465
1466
}
1466
1467
1468
+ // / Retrieve the type that will be used when matching the given overload.
1469
+ static Type getEffectiveOverloadType (const OverloadChoice &overload) {
1470
+ switch (overload.getKind ()) {
1471
+ case OverloadChoiceKind::Decl:
1472
+ // Declaration choices are handled below.
1473
+ break ;
1474
+
1475
+ case OverloadChoiceKind::BaseType:
1476
+ case OverloadChoiceKind::DeclViaBridge:
1477
+ case OverloadChoiceKind::DeclViaDynamic:
1478
+ case OverloadChoiceKind::DeclViaUnwrappedOptional:
1479
+ case OverloadChoiceKind::DynamicMemberLookup:
1480
+ case OverloadChoiceKind::KeyPathApplication:
1481
+ case OverloadChoiceKind::TupleIndex:
1482
+ return Type ();
1483
+ }
1484
+
1485
+ auto decl = overload.getDecl ();
1486
+
1487
+ // Retrieve the interface type.
1488
+ auto type = decl->getInterfaceType ();
1489
+ if (!type) {
1490
+ decl->getASTContext ().getLazyResolver ()->resolveDeclSignature (decl);
1491
+ type = decl->getInterfaceType ();
1492
+ if (!type) {
1493
+ return Type ();
1494
+ }
1495
+ }
1496
+
1497
+ // If we have a generic function type, drop the generic signature; we don't
1498
+ // need it for this comparison.
1499
+ if (auto genericFn = type->getAs <GenericFunctionType>()) {
1500
+ type = FunctionType::get (genericFn->getParams (),
1501
+ genericFn->getResult (),
1502
+ genericFn->getExtInfo ());
1503
+ }
1504
+
1505
+ // If this declaration is within a type context, bail out.
1506
+ if (decl->getDeclContext ()->isTypeContext ()) {
1507
+ return Type ();
1508
+ }
1509
+
1510
+ return type;
1511
+ }
1512
+
1513
+ namespace {
1514
+ // / Type visitor that extracts the common type between two types, when
1515
+ // / possible.
1516
+ class CommonTypeVisitor : public TypeVisitor <CommonTypeVisitor, Type, Type> {
1517
+ // / Perform a "leaf" match for types, which does not consider the children.
1518
+ Type handleLeafMatch (Type type1, Type type2) {
1519
+ if (type1->isEqual (type2))
1520
+ return type1;
1521
+
1522
+ return handleMismatch (type1, type2);
1523
+ }
1524
+
1525
+ // / Handle a mismatch between two types.
1526
+ Type handleMismatch (Type type1, Type type2) {
1527
+ return Type ();
1528
+ }
1529
+
1530
+ public:
1531
+ Type visitTupleType (TupleType *tuple1, Type type2) {
1532
+ if (tuple1->isEqual (type2))
1533
+ return Type (tuple1);
1534
+
1535
+ auto tuple2 = type2->getAs <TupleType>();
1536
+ if (!tuple2) {
1537
+ return handleMismatch (Type (tuple1), type2);
1538
+ }
1539
+
1540
+ // Check for structural similarity between the two tuple types.
1541
+ auto elements1 = tuple1->getElements ();
1542
+ auto elements2 = tuple2->getElements ();
1543
+ if (elements1.size () != elements2.size ()) {
1544
+ return handleMismatch (Type (tuple1), type2);
1545
+ }
1546
+
1547
+ for (unsigned i : indices (elements1)) {
1548
+ const auto &elt1 = elements1[i];
1549
+ const auto &elt2 = elements2[i];
1550
+ if (elt1.getName () != elt2.getName () ||
1551
+ elt1.getParameterFlags () != elt2.getParameterFlags ()) {
1552
+ return handleMismatch (Type (tuple1), type2);
1553
+ }
1554
+ }
1555
+
1556
+ // Recurse on the element types.
1557
+ SmallVector<TupleTypeElt, 4 > newElements;
1558
+ newElements.reserve (elements1.size ());
1559
+ for (unsigned i : indices (elements1)) {
1560
+ const auto &elt1 = elements1[i];
1561
+ const auto &elt2 = elements2[i];
1562
+ Type elementType = visit (elt1.getRawType (), elt2.getRawType ());
1563
+ if (!elementType) {
1564
+ return handleMismatch (Type (tuple1), type2);
1565
+ }
1566
+
1567
+ newElements.push_back (elt1.getWithType (elementType));
1568
+ }
1569
+ return TupleType::get (newElements, tuple1->getASTContext ());
1570
+ }
1571
+
1572
+ Type visitReferenceStorageType (ReferenceStorageType *refStorage1,
1573
+ Type type2) {
1574
+ if (refStorage1->isEqual (type2))
1575
+ return Type (refStorage1);
1576
+
1577
+ auto refStorage2 = type2->getAs <ReferenceStorageType>();
1578
+ if (!refStorage2 ||
1579
+ refStorage1->getOwnership () != refStorage2->getOwnership ()) {
1580
+ return handleMismatch (Type (refStorage1), type2);
1581
+ }
1582
+
1583
+ Type newReferentType = visit (refStorage1->getReferentType (),
1584
+ refStorage2->getReferentType ());
1585
+ if (!newReferentType) {
1586
+ return handleMismatch (Type (refStorage1), type2);
1587
+ }
1588
+
1589
+ return ReferenceStorageType::get (newReferentType,
1590
+ refStorage1->getOwnership (),
1591
+ refStorage1->getASTContext ());
1592
+ }
1593
+
1594
+ Type visitAnyMetatypeType (AnyMetatypeType *metatype1, Type type2) {
1595
+ if (metatype1->isEqual (type2))
1596
+ return Type (metatype1);
1597
+
1598
+
1599
+ auto metatype2 = type2->getAs <AnyMetatypeType>();
1600
+ if (!metatype2) {
1601
+ return handleMismatch (Type (metatype1), type2);
1602
+ }
1603
+
1604
+ if (metatype1->getKind () != metatype2->getKind () ||
1605
+ metatype1->hasRepresentation () != metatype2->hasRepresentation () ||
1606
+ (metatype1->hasRepresentation () &&
1607
+ metatype2->getRepresentation () != metatype2->getRepresentation ())) {
1608
+ return handleMismatch (Type (metatype1), type2);
1609
+ }
1610
+
1611
+ Type newInstanceType = visit (metatype1->getInstanceType (),
1612
+ metatype2->getInstanceType ());
1613
+ if (!newInstanceType) {
1614
+ return handleMismatch (Type (metatype1), type2);
1615
+ }
1616
+
1617
+ Optional<MetatypeRepresentation> representation;
1618
+ if (metatype1->hasRepresentation ())
1619
+ representation = metatype1->getRepresentation ();
1620
+
1621
+ if (metatype1->getKind () == TypeKind::Metatype)
1622
+ return MetatypeType::get (newInstanceType, representation);
1623
+
1624
+ assert (metatype1->getKind () == TypeKind::ExistentialMetatype);
1625
+ return ExistentialMetatypeType::get (newInstanceType, representation);
1626
+ }
1627
+
1628
+ Type visitFunctionType (FunctionType *function1, Type type2) {
1629
+ if (function1->isEqual (type2))
1630
+ return Type (function1);
1631
+
1632
+ auto function2 = type2->getAs <FunctionType>();
1633
+ if (!function2 ||
1634
+ function1->getExtInfo () != function2->getExtInfo () ||
1635
+ function1->getNumParams () != function2->getNumParams ()) {
1636
+ return handleMismatch (Type (function1), type2);
1637
+ }
1638
+
1639
+ // Check for a structural match between the parameters.
1640
+ auto params1 = function1->getParams ();
1641
+ auto params2 = function2->getParams ();
1642
+ for (unsigned i : indices (params1)) {
1643
+ const auto ¶m1 = params1[i];
1644
+ const auto ¶m2 = params2[i];
1645
+ if (param1.getLabel () != param2.getLabel () ||
1646
+ param1.getParameterFlags () != param2.getParameterFlags ()) {
1647
+ return handleMismatch (Type (function1), type2);
1648
+ }
1649
+ }
1650
+
1651
+ Type newResultType = visit (function1->getResult (), function2->getResult ());
1652
+ if (!newResultType) {
1653
+ return handleMismatch (Type (function1), type2);
1654
+ }
1655
+
1656
+ SmallVector<AnyFunctionType::Param, 4 > newParams;
1657
+ newParams.reserve (params1.size ());
1658
+ for (unsigned i : indices (params1)) {
1659
+ const auto ¶m1 = params1[i];
1660
+ const auto ¶m2 = params2[i];
1661
+ Type newParamType = visit (param1.getPlainType (), param2.getPlainType ());
1662
+ if (!newParamType) {
1663
+ return handleMismatch (Type (function1), type2);
1664
+ }
1665
+
1666
+ newParams.push_back (AnyFunctionType::Param (newParamType,
1667
+ param1.getLabel (),
1668
+ param1.getParameterFlags ()));
1669
+ }
1670
+
1671
+ return FunctionType::get (newParams, newResultType, function1->getExtInfo ());
1672
+ }
1673
+
1674
+ Type visitGenericFunctionType (GenericFunctionType *function1, Type type2) {
1675
+ llvm_unreachable (" Caller should have eliminated these" );
1676
+ }
1677
+
1678
+ Type visitLValueType (LValueType *lvalue1, Type type2) {
1679
+ if (lvalue1->isEqual (type2))
1680
+ return Type (lvalue1);
1681
+
1682
+ auto lvalue2 = type2->getAs <LValueType>();
1683
+ if (!lvalue2) {
1684
+ return handleMismatch (Type (lvalue1), type2);
1685
+ }
1686
+
1687
+ Type newObjectType =
1688
+ visit (lvalue1->getObjectType (), lvalue2->getObjectType ());
1689
+ if (!newObjectType) {
1690
+ return handleMismatch (Type (lvalue1), type2);
1691
+ }
1692
+
1693
+ return LValueType::get (newObjectType);
1694
+ }
1695
+
1696
+ Type visitInOutType (InOutType *inout1, Type type2) {
1697
+ if (inout1->isEqual (type2))
1698
+ return Type (inout1);
1699
+
1700
+ auto inout2 = type2->getAs <InOutType>();
1701
+ if (!inout2) {
1702
+ return handleMismatch (Type (inout1), type2);
1703
+ }
1704
+
1705
+ Type newObjectType =
1706
+ visit (inout1->getObjectType (), inout2->getObjectType ());
1707
+ if (!newObjectType) {
1708
+ return handleMismatch (Type (inout1), type2);
1709
+ }
1710
+
1711
+ return LValueType::get (newObjectType);
1712
+ }
1713
+
1714
+ Type visitSugarType (SugarType *sugar1, Type type2) {
1715
+ if (sugar1->isEqual (type2))
1716
+ return Type (sugar1);
1717
+
1718
+ // FIXME: Reconstitute sugar.
1719
+ return visit (Type (sugar1->getSinglyDesugaredType ()), type2);
1720
+ }
1721
+
1722
+ #define FAILURE_CASE (Class ) \
1723
+ Type visit##Class##Type(Class##Type *type1, Type type2) { \
1724
+ return Type (); \
1725
+ }
1726
+
1727
+ #define LEAF_CASE (Class ) \
1728
+ Type visit##Class##Type(Class##Type *type1, Type type2) { \
1729
+ return handleLeafMatch (Type (type1), type2); \
1730
+ }
1731
+
1732
+ FAILURE_CASE (Error)
1733
+ FAILURE_CASE (Unresolved)
1734
+ LEAF_CASE (Builtin)
1735
+ LEAF_CASE (Nominal) // FIXME: We can do a more specific match here.
1736
+ LEAF_CASE (BoundGeneric) // FIXME: We can do a more specific match here.
1737
+ FAILURE_CASE (UnboundGeneric)
1738
+ LEAF_CASE (Module)
1739
+ LEAF_CASE (DynamicSelf) // FIXME: Can we do better here?
1740
+ LEAF_CASE (Substitutable)
1741
+ LEAF_CASE (DependentMember)
1742
+ LEAF_CASE (SILFunction)
1743
+ LEAF_CASE (SILBlockStorage)
1744
+ LEAF_CASE (SILBox)
1745
+ LEAF_CASE (SILToken)
1746
+ LEAF_CASE (ProtocolComposition)
1747
+ LEAF_CASE (TypeVariable) // FIXME: Could do better here when we create vars
1748
+
1749
+ #undef LEAF_CASE
1750
+ #undef FAILURE_CASE
1751
+ };
1752
+
1753
+ }
1754
+
1755
+ Type ConstraintSystem::findCommonOverloadType (
1756
+ ArrayRef<OverloadChoice> choices,
1757
+ ArrayRef<OverloadChoice> outerAlternatives,
1758
+ ConstraintLocator *locator) {
1759
+ // Local function to consider this s new overload choice, updating the
1760
+ // "common type". Returns true if this overload cannot be integrated into
1761
+ // the common type, at which point there is no "common type".
1762
+ Type commonType;
1763
+ auto considerOverload = [&](const OverloadChoice &overload) -> bool {
1764
+ // If we can't even get a type for the overload, there's nothing more to
1765
+ // do.
1766
+ Type overloadType = getEffectiveOverloadType (overload);
1767
+ if (!overloadType) {
1768
+ return true ;
1769
+ }
1770
+
1771
+ // If this is the first overload, record it's type as the common type.
1772
+ if (!commonType) {
1773
+ commonType = overloadType;
1774
+ return false ;
1775
+ }
1776
+
1777
+ // Find the common type between the current common type and the new
1778
+ // overload's type.
1779
+ commonType = CommonTypeVisitor ().visit (commonType, overloadType);
1780
+ if (!commonType) {
1781
+ return true ;
1782
+ }
1783
+
1784
+ return false ;
1785
+ };
1786
+
1787
+ // Consider all of the choices and outer alternatives.
1788
+ for (const auto &choice : choices) {
1789
+ if (considerOverload (choice))
1790
+ return Type ();
1791
+ }
1792
+ for (const auto &choice : outerAlternatives) {
1793
+ if (considerOverload (choice))
1794
+ return Type ();
1795
+ }
1796
+
1797
+ assert (commonType && " We can't get here without having a common type" );
1798
+
1799
+ // If our common type contains any generic parameters, open them up into
1800
+ // type variables.
1801
+ if (commonType->hasTypeParameter ()) {
1802
+ llvm::SmallDenseMap<const GenericTypeParamType *, TypeVariableType *>
1803
+ openedGenericParams;
1804
+ commonType = commonType.transformRec ([&](TypeBase *type) -> Optional<Type> {
1805
+ if (auto genericParam = dyn_cast<GenericTypeParamType>(type)) {
1806
+ auto canGenericParam = GenericTypeParamType::get (
1807
+ genericParam->getDepth (),
1808
+ genericParam->getIndex (),
1809
+ type->getASTContext ());
1810
+ auto knownTypeVar = openedGenericParams.find (canGenericParam);
1811
+ if (knownTypeVar != openedGenericParams.end ())
1812
+ return Type (knownTypeVar->second );
1813
+
1814
+ auto typeVar = createTypeVariable (locator);
1815
+ openedGenericParams[canGenericParam] = typeVar;
1816
+ return Type (typeVar);
1817
+ }
1818
+
1819
+ return None;
1820
+ });
1821
+ }
1822
+
1823
+ return commonType;
1824
+ }
1825
+
1467
1826
void ConstraintSystem::addOverloadSet (Type boundType,
1468
1827
ArrayRef<OverloadChoice> choices,
1469
1828
DeclContext *useDC,
@@ -1478,6 +1837,13 @@ void ConstraintSystem::addOverloadSet(Type boundType,
1478
1837
return ;
1479
1838
}
1480
1839
1840
+ // If we can compute a common type for the overload set, bind that type.
1841
+ if (Type commonType = findCommonOverloadType (choices, outerAlternatives,
1842
+ locator)) {
1843
+ addConstraint (ConstraintKind::Bind, boundType, commonType, locator);
1844
+ boundType = commonType;
1845
+ }
1846
+
1481
1847
tryOptimizeGenericDisjunction (*this , choices, favoredChoice);
1482
1848
1483
1849
SmallVector<OverloadChoice, 4 > scratchChoices;
0 commit comments