Skip to content

Commit cba9613

Browse files
xedinrudkx
authored andcommitted
[Type checker] Introduce directional path consistency algorithm
DPC algorithm tries to solve individual sub-expressions and combine resolved types as a way to reduce pre-existing OSR domains. Solving is done bottom-up so each consecutive sub-expression tightens possible solution domain even further. (cherry picked from commit 8e04a84)
1 parent 8db7b08 commit cba9613

File tree

6 files changed

+565
-43
lines changed

6 files changed

+565
-43
lines changed

include/swift/AST/Expr.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,11 @@ class OverloadSetRefExpr : public Expr {
14211421

14221422
public:
14231423
ArrayRef<ValueDecl*> getDecls() const { return Decls; }
1424-
1424+
1425+
void setDecls(ArrayRef<ValueDecl *> domain) {
1426+
Decls = domain;
1427+
}
1428+
14251429
/// getBaseType - Determine the type of the base object provided for the
14261430
/// given overload set, which is only non-null when dealing with an overloaded
14271431
/// member reference.

lib/Sema/CSSolver.cpp

Lines changed: 318 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include "llvm/Support/SaveAndRestore.h"
2121
#include <memory>
2222
#include <tuple>
23+
#include <stack>
24+
#include <queue>
2325
using namespace swift;
2426
using namespace constraints;
2527

@@ -1349,6 +1351,321 @@ ConstraintSystem::solveSingle(FreeTypeVariableBinding allowFreeTypeVariables) {
13491351
return std::move(solutions[0]);
13501352
}
13511353

1354+
bool ConstraintSystem::Candidate::solve() {
1355+
// Cleanup after constraint system generation/solving,
1356+
// because it would assign types to expressions, which
1357+
// might interfere with solving higher-level expressions.
1358+
ExprCleaner cleaner(E);
1359+
1360+
// Allocate new constraint system for sub-expression.
1361+
ConstraintSystem cs(TC, DC, None);
1362+
1363+
// Set contextual type if present. This is done before constraint generation
1364+
// to give a "hint" to that operation about possible optimizations.
1365+
if (!CT.isNull())
1366+
cs.setContextualType(E, CT, CTP);
1367+
1368+
// Generate constraints for the new system.
1369+
if (auto generatedExpr = cs.generateConstraints(E)) {
1370+
E = generatedExpr;
1371+
} else {
1372+
// Failure to generate constraint system for sub-expression means we can't
1373+
// continue solving sub-expressions.
1374+
return true;
1375+
}
1376+
1377+
// If there is contextual type present, add an explicit "conversion"
1378+
// constraint to the system.
1379+
if (!CT.isNull()) {
1380+
auto constraintKind = ConstraintKind::Conversion;
1381+
if (CTP == CTP_CallArgument)
1382+
constraintKind = ConstraintKind::ArgumentConversion;
1383+
1384+
cs.addConstraint(constraintKind, E->getType(), CT.getType(),
1385+
cs.getConstraintLocator(E), /*isFavored=*/true);
1386+
}
1387+
1388+
// Try to solve the system and record all available solutions.
1389+
llvm::SmallVector<Solution, 2> solutions;
1390+
{
1391+
SolverState state(cs);
1392+
cs.solverState = &state;
1393+
1394+
// Use solveRec() instead of solve() in here, because solve()
1395+
// would try to deduce the best solution, which we don't
1396+
// really want. Instead, we want the reduced set of domain choices.
1397+
cs.solveRec(solutions, FreeTypeVariableBinding::Allow);
1398+
1399+
cs.solverState = nullptr;
1400+
}
1401+
1402+
// No solutions for the sub-expression means that either main expression
1403+
// needs salvaging or it's inconsistent (read: doesn't have solutions).
1404+
if (solutions.empty())
1405+
return true;
1406+
1407+
// Record found solutions as suggestions.
1408+
this->applySolutions(solutions);
1409+
return false;
1410+
}
1411+
1412+
void ConstraintSystem::Candidate::applySolutions(
1413+
llvm::SmallVectorImpl<Solution> &solutions) const {
1414+
// A collection of OSRs with their newly reduced domains,
1415+
// it's domains are sets because multiple solutions can have the same
1416+
// choice for one of the type variables, and we want no duplication.
1417+
llvm::SmallDenseMap<OverloadSetRefExpr *, llvm::SmallSet<ValueDecl *, 2>>
1418+
domains;
1419+
for (auto &solution : solutions) {
1420+
for (auto choice : solution.overloadChoices) {
1421+
// Some of the choices might not have locators.
1422+
if (!choice.getFirst())
1423+
continue;
1424+
1425+
auto anchor = choice.getFirst()->getAnchor();
1426+
// Anchor is not available or expression is not an overload set.
1427+
if (!anchor || !isa<OverloadSetRefExpr>(anchor))
1428+
continue;
1429+
1430+
auto OSR = cast<OverloadSetRefExpr>(anchor);
1431+
auto overload = choice.getSecond().choice;
1432+
auto type = overload.getDecl()->getInterfaceType();
1433+
1434+
// One of the solutions has polymorphic type assigned with one of it's
1435+
// type variables. Such functions can only be properly resolved
1436+
// via complete expression, so we'll have to forget solutions
1437+
// we have already recorded. They might not include all viable overload
1438+
// choices.
1439+
if (type->is<GenericFunctionType>()) {
1440+
return;
1441+
}
1442+
1443+
domains[OSR].insert(overload.getDecl());
1444+
}
1445+
}
1446+
1447+
// Reduce the domains.
1448+
for (auto &domain : domains) {
1449+
auto OSR = domain.getFirst();
1450+
auto &choices = domain.getSecond();
1451+
1452+
// If the domain wasn't reduced, skip it.
1453+
if (OSR->getDecls().size() == choices.size()) continue;
1454+
1455+
// Update the expression with the reduced domain.
1456+
MutableArrayRef<ValueDecl *> decls
1457+
= TC.Context.AllocateUninitialized<ValueDecl *>(choices.size());
1458+
std::uninitialized_copy(choices.begin(), choices.end(), decls.begin());
1459+
OSR->setDecls(decls);
1460+
}
1461+
}
1462+
1463+
void ConstraintSystem::shrink(Expr *expr) {
1464+
typedef llvm::SmallDenseMap<Expr *, ArrayRef<ValueDecl *>> DomainMap;
1465+
1466+
// A collection of original domains of all of the expressions,
1467+
// so they can be restored in case of failure.
1468+
DomainMap domains;
1469+
1470+
struct ExprCollector : public ASTWalker {
1471+
// The primary constraint system.
1472+
ConstraintSystem &CS;
1473+
1474+
// All of the sub-expressions of certain type (binary/unary/calls) in
1475+
// depth-first order.
1476+
std::queue<Candidate> &SubExprs;
1477+
1478+
// Counts the number of overload sets present in the tree so far.
1479+
// Note that the traversal is depth-first.
1480+
std::stack<std::pair<ApplyExpr *, unsigned>,
1481+
llvm::SmallVector<std::pair<ApplyExpr *, unsigned>, 4>>
1482+
ApplyExprs;
1483+
1484+
// A collection of original domains of all of the expressions,
1485+
// so they can be restored in case of failure.
1486+
DomainMap &Domains;
1487+
1488+
ExprCollector(ConstraintSystem &cs,
1489+
std::queue<Candidate> &container,
1490+
DomainMap &domains)
1491+
: CS(cs), SubExprs(container), Domains(domains) { }
1492+
1493+
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
1494+
// A dictionary expression is just a set of tuples; try to solve ones
1495+
// that have overload sets.
1496+
if (auto dictionaryExpr = dyn_cast<DictionaryExpr>(expr)) {
1497+
for (auto element : dictionaryExpr->getElements()) {
1498+
unsigned numOverlaods = 0;
1499+
element->walk(OverloadSetCounter(numOverlaods));
1500+
1501+
// There are no overload sets in the element; skip it.
1502+
if (numOverlaods == 0)
1503+
continue;
1504+
1505+
// FIXME: Could we avoid creating a separate dictionary expression
1506+
// here by introducing a contextual type on the element?
1507+
auto dict = DictionaryExpr::create(CS.getASTContext(),
1508+
dictionaryExpr->getLBracketLoc(),
1509+
{ element },
1510+
dictionaryExpr->getRBracketLoc(),
1511+
dictionaryExpr->getType());
1512+
1513+
// Make each of the dictionary elements an independent dictionary,
1514+
// such makes it easy to type-check everything separately.
1515+
SubExprs.push(Candidate(CS, dict));
1516+
}
1517+
1518+
// Don't try to walk into the dictionary.
1519+
return { false, expr };
1520+
}
1521+
1522+
// Let's not attempt to type-check closures or default values,
1523+
// which has already been type checked anyway.
1524+
if (isa<ClosureExpr>(expr) || isa<DefaultValueExpr>(expr)) {
1525+
return { false, expr };
1526+
}
1527+
1528+
// Coerce to type expressions are only viable if they have
1529+
// a single child expression.
1530+
if (auto coerceExpr = dyn_cast<CoerceExpr>(expr)) {
1531+
if (!coerceExpr->getSubExpr()) {
1532+
return { false, expr };
1533+
}
1534+
}
1535+
1536+
if (auto OSR = dyn_cast<OverloadSetRefExpr>(expr)) {
1537+
Domains[OSR] = OSR->getDecls();
1538+
}
1539+
1540+
if (auto applyExpr = dyn_cast<ApplyExpr>(expr)) {
1541+
auto func = applyExpr->getFn();
1542+
// Let's record this function application for post-processing
1543+
// as well as if it contains overload set, see walkToExprPost.
1544+
ApplyExprs.push({ applyExpr, isa<OverloadSetRefExpr>(func) });
1545+
}
1546+
1547+
return { true, expr };
1548+
}
1549+
1550+
Expr *walkToExprPost(Expr *expr) override {
1551+
if (!isa<ApplyExpr>(expr))
1552+
return expr;
1553+
1554+
unsigned numOverloadSets = 0;
1555+
// Let's count how many overload sets do we have.
1556+
while (!ApplyExprs.empty()) {
1557+
auto application = ApplyExprs.top();
1558+
auto applyExpr = application.first;
1559+
1560+
// Add overload sets tracked by current expression.
1561+
numOverloadSets += application.second;
1562+
ApplyExprs.pop();
1563+
1564+
// We've found the current expression, so record the number of
1565+
// overloads.
1566+
if (expr == applyExpr) {
1567+
ApplyExprs.push({ applyExpr, numOverloadSets });
1568+
break;
1569+
}
1570+
}
1571+
1572+
// If there are fewer than two overloads in the chain
1573+
// there is no point of solving this expression,
1574+
// because we won't be able to reduce it's domain.
1575+
if (numOverloadSets > 1)
1576+
SubExprs.push(Candidate(CS, expr));
1577+
1578+
return expr;
1579+
}
1580+
};
1581+
1582+
std::queue<Candidate> expressions;
1583+
ExprCollector collector(*this, expressions, domains);
1584+
1585+
// Collect all of the binary/unary and call sub-expressions
1586+
// so we can start solving them separately.
1587+
expr->walk(collector);
1588+
1589+
while (!expressions.empty()) {
1590+
auto &candidate = expressions.front();
1591+
1592+
// If there are no results, let's forget everything we know about the
1593+
// system so far. This actually is ok, because some of the expressions
1594+
// might require manual salvaging.
1595+
if (candidate.solve()) {
1596+
// Let's restore all of the original OSR domains.
1597+
for (auto &domain : domains) {
1598+
if (auto OSR = dyn_cast<OverloadSetRefExpr>(domain.getFirst())) {
1599+
OSR->setDecls(domain.getSecond());
1600+
}
1601+
}
1602+
break;
1603+
}
1604+
1605+
expressions.pop();
1606+
}
1607+
}
1608+
1609+
ConstraintSystem::SolutionKind
1610+
ConstraintSystem::solve(Expr *&expr,
1611+
Type convertType,
1612+
ExprTypeCheckListener *listener,
1613+
SmallVectorImpl<Solution> &solutions,
1614+
FreeTypeVariableBinding allowFreeTypeVariables) {
1615+
assert(!solverState && "use solveRec for recursive calls");
1616+
1617+
// Try to shrink the system by reducing disjunction domains. This
1618+
// goes through every sub-expression and generate it's own sub-system, to
1619+
// try to reduce the domains of those subexpressions.
1620+
shrink(expr);
1621+
1622+
// Generate constraints for the main system.
1623+
if (auto generatedExpr = generateConstraints(expr))
1624+
expr = generatedExpr;
1625+
else {
1626+
return SolutionKind::Error;
1627+
}
1628+
1629+
// If there is a type that we're expected to convert to, add the conversion
1630+
// constraint.
1631+
if (convertType) {
1632+
auto constraintKind = ConstraintKind::Conversion;
1633+
if (getContextualTypePurpose() == CTP_CallArgument)
1634+
constraintKind = ConstraintKind::ArgumentConversion;
1635+
1636+
if (allowFreeTypeVariables == FreeTypeVariableBinding::UnresolvedType) {
1637+
convertType = convertType.transform([&](Type type) -> Type {
1638+
if (type->is<UnresolvedType>())
1639+
return createTypeVariable(getConstraintLocator(expr), 0);
1640+
return type;
1641+
});
1642+
}
1643+
1644+
addConstraint(constraintKind, expr->getType(), convertType,
1645+
getConstraintLocator(expr), /*isFavored*/ true);
1646+
}
1647+
1648+
// Notify the listener that we've built the constraint system.
1649+
if (listener && listener->builtConstraints(*this, expr)) {
1650+
return SolutionKind::Error;
1651+
}
1652+
1653+
if (TC.getLangOpts().DebugConstraintSolver) {
1654+
auto &log = getASTContext().TypeCheckerDebug->getStream();
1655+
log << "---Initial constraints for the given expression---\n";
1656+
expr->print(log);
1657+
log << "\n";
1658+
print(log);
1659+
}
1660+
1661+
// Try to solve the constraint system using computed suggestions.
1662+
solve(solutions, allowFreeTypeVariables);
1663+
1664+
// If there are no solutions let's mark system as unsolved,
1665+
// and solved otherwise even if there are multiple solutions still present.
1666+
return solutions.empty() ? SolutionKind::Unsolved : SolutionKind::Solved;
1667+
}
1668+
13521669
bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
13531670
FreeTypeVariableBinding allowFreeTypeVariables) {
13541671
assert(!solverState && "use solveRec for recursive calls");
@@ -1372,7 +1689,7 @@ bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
13721689

13731690
// Remove the solver state.
13741691
this->solverState = nullptr;
1375-
1692+
13761693
// We fail if there is no solution.
13771694
return solutions.empty();
13781695
}

0 commit comments

Comments
 (0)