Skip to content

Commit 151fc79

Browse files
authored
Merge pull request swiftlang#29770 from DougGregor/builder-transform-node-types
[Constraint solver] Fix handling of node types for function builder application
2 parents 0dc9c02 + 4c98b31 commit 151fc79

File tree

5 files changed

+54
-83
lines changed

5 files changed

+54
-83
lines changed

lib/Sema/BuilderTransform.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,9 @@ Optional<BraceStmt *> TypeChecker::applyFunctionBuilderBodyTransform(
11011101
// The system was salvaged; continue on as if nothing happened.
11021102
}
11031103

1104+
// FIXME: Shouldn't need to do this.
1105+
cs.applySolution(solutions.front());
1106+
11041107
// Apply the solution to the function body.
11051108
if (auto result = cs.applySolution(
11061109
solutions.front(),

lib/Sema/CSApply.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7538,14 +7538,8 @@ ProtocolConformanceRef Solution::resolveConformance(
75387538
}
75397539

75407540
Type Solution::getType(const Expr *expr) const {
7541-
auto result = llvm::find_if(
7542-
addedNodeTypes, [&](const std::pair<TypedNode, Type> &node) -> bool {
7543-
if (auto *e = node.first.dyn_cast<const Expr *>())
7544-
return expr == e;
7545-
return false;
7546-
});
7547-
7548-
if (result != addedNodeTypes.end())
7541+
auto result = nodeTypes.find(expr);
7542+
if (result != nodeTypes.end())
75497543
return result->second;
75507544

75517545
auto &cs = getConstraintSystem();

lib/Sema/CSSolver.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ Solution ConstraintSystem::finalize() {
164164
solution.DefaultedConstraints.insert(DefaultedConstraints.begin(),
165165
DefaultedConstraints.end());
166166

167-
for (auto &nodeType : addedNodeTypes) {
168-
solution.addedNodeTypes.insert(nodeType);
167+
for (auto &nodeType : NodeTypes) {
168+
solution.nodeTypes.insert(nodeType);
169169
}
170170

171171
// Remember contextual types.
@@ -231,9 +231,8 @@ void ConstraintSystem::applySolution(const Solution &solution) {
231231
solution.DefaultedConstraints.end());
232232

233233
// Add the node types back.
234-
for (auto &nodeType : solution.addedNodeTypes) {
235-
if (!hasType(nodeType.first))
236-
setType(nodeType.first, nodeType.second);
234+
for (auto &nodeType : solution.nodeTypes) {
235+
setType(nodeType.first, nodeType.second);
237236
}
238237

239238
// Add the contextual types.
@@ -511,8 +510,13 @@ ConstraintSystem::SolverScope::~SolverScope() {
511510
truncate(cs.DefaultedConstraints, numDefaultedConstraints);
512511

513512
// Remove any node types we registered.
514-
for (unsigned i : range(numAddedNodeTypes, cs.addedNodeTypes.size())) {
515-
cs.eraseType(cs.addedNodeTypes[i].first);
513+
for (unsigned i :
514+
reverse(range(numAddedNodeTypes, cs.addedNodeTypes.size()))) {
515+
TypedNode node = cs.addedNodeTypes[i].first;
516+
if (Type oldType = cs.addedNodeTypes[i].second)
517+
cs.NodeTypes[node] = oldType;
518+
else
519+
cs.NodeTypes.erase(node);
516520
}
517521
truncate(cs.addedNodeTypes, numAddedNodeTypes);
518522

lib/Sema/ConstraintSystem.h

Lines changed: 24 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ class Solution {
856856
llvm::SmallPtrSet<ConstraintLocator *, 2> DefaultedConstraints;
857857

858858
/// The node -> type mappings introduced by this solution.
859-
llvm::MapVector<TypedNode, Type> addedNodeTypes;
859+
llvm::MapVector<TypedNode, Type> nodeTypes;
860860

861861
/// Contextual types introduced by this solution.
862862
std::vector<std::pair<const Expr *, ContextualTypeInfo>> contextualTypes;
@@ -953,8 +953,8 @@ class Solution {
953953

954954
/// Retrieve the type of the given node, as recorded in this solution.
955955
Type getType(TypedNode node) const {
956-
auto known = addedNodeTypes.find(node);
957-
assert(known != addedNodeTypes.end());
956+
auto known = nodeTypes.find(node);
957+
assert(known != nodeTypes.end());
958958
return known->second;
959959
}
960960

@@ -1479,15 +1479,12 @@ class ConstraintSystem {
14791479
/// from declared parameters/result and body.
14801480
llvm::MapVector<const ClosureExpr *, FunctionType *> ClosureTypes;
14811481

1482-
/// Maps expression types used within all portions of the constraint
1483-
/// system, instead of directly using the types on the expression
1484-
/// nodes themselves. This allows us to typecheck an expression and
1482+
/// Maps node types used within all portions of the constraint
1483+
/// system, instead of directly using the types on the
1484+
/// nodes themselves. This allows us to typecheck and
14851485
/// run through various diagnostics passes without actually mutating
1486-
/// the types on the expression nodes.
1487-
llvm::DenseMap<const Expr *, TypeBase *> ExprTypes;
1488-
llvm::DenseMap<const TypeLoc *, TypeBase *> TypeLocTypes;
1489-
llvm::DenseMap<const VarDecl *, TypeBase *> VarTypes;
1490-
llvm::DenseMap<const Pattern *, TypeBase *> PatternTypes;
1486+
/// the types on the nodes.
1487+
llvm::MapVector<TypedNode, Type> NodeTypes;
14911488
llvm::DenseMap<std::pair<const KeyPathExpr *, unsigned>, TypeBase *>
14921489
KeyPathComponentTypes;
14931490

@@ -1555,7 +1552,8 @@ class ConstraintSystem {
15551552
SmallVector<std::pair<ConstraintLocator *, OpenedArchetypeType *>, 4>
15561553
OpenedExistentialTypes;
15571554

1558-
/// The node -> type mappings introduced by generating constraints.
1555+
/// The nodes for which we have produced types, along with the prior type
1556+
/// each node had before introducing this type.
15591557
llvm::SmallVector<std::pair<TypedNode, Type>, 8> addedNodeTypes;
15601558

15611559
std::vector<std::pair<ConstraintLocator *, ProtocolConformanceRef>>
@@ -2238,19 +2236,12 @@ class ConstraintSystem {
22382236
assert(type && "Expected non-null type");
22392237

22402238
// Record the type.
2241-
if (auto expr = node.dyn_cast<const Expr *>()) {
2242-
ExprTypes[expr] = type.getPointer();
2243-
} else if (auto typeLoc = node.dyn_cast<const TypeLoc *>()) {
2244-
TypeLocTypes[typeLoc] = type.getPointer();
2245-
} else if (auto var = node.dyn_cast<const VarDecl *>()) {
2246-
VarTypes[var] = type.getPointer();
2247-
} else {
2248-
auto pattern = node.get<const Pattern *>();
2249-
PatternTypes[pattern] = type.getPointer();
2250-
}
2239+
Type &entry = NodeTypes[node];
2240+
Type oldType = entry;
2241+
entry = type;
22512242

22522243
// Record the fact that we ascribed a type to this node.
2253-
addedNodeTypes.push_back({node, type});
2244+
addedNodeTypes.push_back({node, oldType});
22542245
}
22552246

22562247
/// Set the type in our type map for a given expression. The side
@@ -2261,49 +2252,20 @@ class ConstraintSystem {
22612252
setType(TypedNode(&L), T);
22622253
}
22632254

2264-
/// Erase the type for the given node.
2265-
void eraseType(TypedNode node) {
2266-
if (auto expr = node.dyn_cast<const Expr *>()) {
2267-
ExprTypes.erase(expr);
2268-
} else if (auto typeLoc = node.dyn_cast<const TypeLoc *>()) {
2269-
TypeLocTypes.erase(typeLoc);
2270-
} else if (auto var = node.dyn_cast<const VarDecl *>()) {
2271-
VarTypes.erase(var);
2272-
} else {
2273-
auto pattern = node.get<const Pattern *>();
2274-
PatternTypes.erase(pattern);
2275-
}
2276-
}
2277-
22782255
void setType(KeyPathExpr *KP, unsigned I, Type T) {
22792256
assert(KP && "Expected non-null key path parameter!");
22802257
assert(T && "Expected non-null type!");
22812258
KeyPathComponentTypes[std::make_pair(KP, I)] = T.getPointer();
22822259
}
22832260

2284-
/// Check to see if we have a type for an expression.
2285-
bool hasType(const Expr *E) const {
2286-
assert(E != nullptr && "Expected non-null expression!");
2287-
return ExprTypes.find(E) != ExprTypes.end();
2288-
}
2289-
22902261
bool hasType(const TypeLoc &L) const {
22912262
return hasType(TypedNode(&L));
22922263
}
22932264

22942265
/// Check to see if we have a type for a node.
22952266
bool hasType(TypedNode node) const {
22962267
assert(!node.isNull() && "Expected non-null node");
2297-
if (auto expr = node.dyn_cast<const Expr *>()) {
2298-
return ExprTypes.find(expr) != ExprTypes.end();
2299-
} else if (auto typeLoc = node.dyn_cast<const TypeLoc *>()) {
2300-
return TypeLocTypes.find(typeLoc) != TypeLocTypes.end();
2301-
} else if (auto var = node.dyn_cast<const VarDecl *>()) {
2302-
return VarTypes.find(var) != VarTypes.end();
2303-
} else {
2304-
auto pattern = node.get<const Pattern *>();
2305-
return PatternTypes.find(pattern) != PatternTypes.end();
2306-
}
2268+
return NodeTypes.count(node) > 0;
23072269
}
23082270

23092271
bool hasType(const KeyPathExpr *KP, unsigned I) const {
@@ -2312,35 +2274,29 @@ class ConstraintSystem {
23122274
!= KeyPathComponentTypes.end();
23132275
}
23142276

2315-
/// Get the type for an expression.
2316-
Type getType(const Expr *E) const {
2317-
assert(hasType(E) && "Expected type to have been set!");
2277+
/// Get the type for an node.
2278+
Type getType(TypedNode node) const {
2279+
assert(hasType(node) && "Expected type to have been set!");
23182280
// FIXME: lvalue differences
23192281
// assert((!E->getType() ||
23202282
// E->getType()->isEqual(ExprTypes.find(E)->second)) &&
23212283
// "Mismatched types!");
2322-
return ExprTypes.find(E)->second;
2284+
return NodeTypes.find(node)->second;
23232285
}
23242286

23252287
Type getType(const TypeLoc &L) const {
2326-
assert(hasType(L) && "Expected type to have been set!");
2327-
return TypeLocTypes.find(&L)->second;
2328-
}
2329-
2330-
Type getType(const VarDecl *VD) const {
2331-
assert(hasType(VD) && "Expected type to have been set!");
2332-
return VarTypes.find(VD)->second;
2288+
return getType(TypedNode(&L));
23332289
}
23342290

23352291
Type getType(const KeyPathExpr *KP, unsigned I) const {
23362292
assert(hasType(KP, I) && "Expected type to have been set!");
23372293
return KeyPathComponentTypes.find(std::make_pair(KP, I))->second;
23382294
}
23392295

2340-
/// Retrieve the type of the variable, if known.
2341-
Type getTypeIfAvailable(const VarDecl *VD) const {
2342-
auto known = VarTypes.find(VD);
2343-
if (known == VarTypes.end())
2296+
/// Retrieve the type of the node, if known.
2297+
Type getTypeIfAvailable(TypedNode node) const {
2298+
auto known = NodeTypes.find(node);
2299+
if (known == NodeTypes.end())
23442300
return Type();
23452301

23462302
return known->second;

test/Constraints/function_builder_diags.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,17 @@ func checkImplicitSelfInClosure() {
327327
}
328328
}
329329
}
330+
331+
// rdar://problem/59239224 - crash because some nodes don't have type
332+
// information during solution application.
333+
struct X<T> {
334+
init(_: T) { }
335+
}
336+
337+
@TupleBuilder func foo(cond: Bool) -> some Any {
338+
if cond {
339+
tuplify(cond) { x in
340+
X(x)
341+
}
342+
}
343+
}

0 commit comments

Comments
 (0)