Skip to content

Commit 88ad3b7

Browse files
authored
Merge pull request swiftlang#39412 from Jumhyn/sr-15192
Replace inferrable types before binding literals to contextual type
2 parents 8e23a38 + f6b6f0c commit 88ad3b7

File tree

6 files changed

+163
-25
lines changed

6 files changed

+163
-25
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,7 @@ using OpenedTypeMap =
908908
/// within a constraint system.
909909
struct ContextualTypeInfo {
910910
TypeLoc typeLoc;
911+
911912
ContextualTypePurpose purpose;
912913

913914
ContextualTypeInfo() : typeLoc(TypeLoc()), purpose(CTP_Unused) {}
@@ -2343,8 +2344,10 @@ class ConstraintSystem {
23432344
solutionApplicationTargets;
23442345

23452346
/// Contextual type information for expressions that are part of this
2346-
/// constraint system.
2347-
llvm::MapVector<ASTNode, ContextualTypeInfo> contextualTypes;
2347+
/// constraint system. The second type, if valid, contains the type as it
2348+
/// should appear in actual constraint. This will have unbound generic types
2349+
/// opened, placeholder types converted to type variables, etc.
2350+
llvm::MapVector<ASTNode, std::pair<ContextualTypeInfo, Type>> contextualTypes;
23482351

23492352
/// Information about each case label item tracked by the constraint system.
23502353
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
@@ -3184,21 +3187,47 @@ class ConstraintSystem {
31843187
assert(bool(node) && "Expected non-null expression!");
31853188
assert(contextualTypes.count(node) == 0 &&
31863189
"Already set this contextual type");
3187-
contextualTypes[node] = {T, purpose};
3190+
contextualTypes[node] = {{T, purpose}, Type()};
31883191
}
31893192

31903193
Optional<ContextualTypeInfo> getContextualTypeInfo(ASTNode node) const {
31913194
auto known = contextualTypes.find(node);
31923195
if (known == contextualTypes.end())
31933196
return None;
3194-
return known->second;
3195-
}
3196-
3197-
Type getContextualType(ASTNode node) const {
3198-
auto result = getContextualTypeInfo(node);
3199-
if (result)
3200-
return result->typeLoc.getType();
3201-
return Type();
3197+
return known->second.first;
3198+
}
3199+
3200+
/// Gets the contextual type recorded for an AST node. When fetching a type
3201+
/// for use in constraint solving, \c forConstraint should be set to \c true,
3202+
/// which will ensure that unbound generics have been opened and placeholder
3203+
/// types have been converted to type variables, etc.
3204+
Type getContextualType(ASTNode node, bool forConstraint = false) {
3205+
if (forConstraint) {
3206+
auto known = contextualTypes.find(node);
3207+
if (known == contextualTypes.end())
3208+
return Type();
3209+
3210+
// If we've already computed a type for use in the constraint system,
3211+
// use that.
3212+
if (known->second.second)
3213+
return known->second.second;
3214+
3215+
// Otherwise, compute a type that can be used in a constraint and record
3216+
// it.
3217+
auto info = known->second.first;
3218+
3219+
auto *locator = getConstraintLocator(
3220+
node, LocatorPathElt::ContextualType(info.purpose));
3221+
known->second.second = replaceInferableTypesWithTypeVars(info.getType(),
3222+
locator);
3223+
3224+
return known->second.second;
3225+
} else {
3226+
auto result = getContextualTypeInfo(node);
3227+
if (result)
3228+
return result->getType();
3229+
return Type();
3230+
}
32023231
}
32033232

32043233
TypeLoc getContextualTypeLoc(ASTNode node) const {

lib/Sema/CSGen.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,9 +1731,12 @@ namespace {
17311731
};
17321732

17331733
// If a contextual type exists for this expression, apply it directly.
1734-
Optional<Type> arrayElementType;
1735-
if (contextualType &&
1736-
(arrayElementType = ConstraintSystem::isArrayType(contextualType))) {
1734+
if (contextualType && ConstraintSystem::isArrayType(contextualType)) {
1735+
// Now that we know we're actually going to use the type, get the
1736+
// version for use in a constraint.
1737+
contextualType = CS.getContextualType(expr, /*forConstraint=*/true);
1738+
Optional<Type> arrayElementType =
1739+
ConstraintSystem::isArrayType(contextualType);
17371740
CS.addConstraint(ConstraintKind::LiteralConformsTo, contextualType,
17381741
arrayProto->getDeclaredInterfaceType(),
17391742
locator);
@@ -1836,14 +1839,19 @@ namespace {
18361839
auto locator = CS.getConstraintLocator(expr);
18371840
auto contextualType = CS.getContextualType(expr);
18381841
auto contextualPurpose = CS.getContextualTypePurpose(expr);
1839-
auto openedType =
1840-
CS.openOpaqueType(contextualType, contextualPurpose, locator);
18411842

18421843
// If a contextual type exists for this expression and is a dictionary
18431844
// type, apply it directly.
1844-
Optional<std::pair<Type, Type>> dictionaryKeyValue;
1845-
if (openedType && (dictionaryKeyValue =
1846-
ConstraintSystem::isDictionaryType(openedType))) {
1845+
if (contextualType && ConstraintSystem::isDictionaryType(contextualType)) {
1846+
// Now that we know we're actually going to use the type, get the
1847+
// version for use in a constraint.
1848+
contextualType = CS.getContextualType(expr, /*forConstraint=*/true);
1849+
auto openedType =
1850+
CS.openOpaqueType(contextualType, contextualPurpose, locator);
1851+
openedType = CS.replaceInferableTypesWithTypeVars(
1852+
openedType, CS.getConstraintLocator(expr));
1853+
auto dictionaryKeyValue =
1854+
ConstraintSystem::isDictionaryType(openedType);
18471855
Type contextualDictionaryKeyType;
18481856
Type contextualDictionaryValueType;
18491857
std::tie(contextualDictionaryKeyType,

lib/Sema/CSSolver.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ Solution ConstraintSystem::finalize() {
180180
}
181181

182182
// Remember contextual types.
183-
solution.contextualTypes.assign(
184-
contextualTypes.begin(), contextualTypes.end());
183+
for (auto &entry : contextualTypes) {
184+
solution.contextualTypes.push_back({entry.first, entry.second.first});
185+
}
185186

186187
solution.solutionApplicationTargets = solutionApplicationTargets;
187188
solution.caseLabelItems = caseLabelItems;

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,9 +1116,10 @@ void ConstraintSystem::print(raw_ostream &out) const {
11161116

11171117
out << "Score: " << CurrentScore << "\n";
11181118

1119-
for (const auto &contextualType : contextualTypes) {
1120-
out << "Contextual Type: " << contextualType.second.getType().getString(PO);
1121-
if (TypeRepr *TR = contextualType.second.typeLoc.getTypeRepr()) {
1119+
for (const auto &contextualTypeEntry : contextualTypes) {
1120+
auto info = contextualTypeEntry.second.first;
1121+
out << "Contextual Type: " << info.getType().getString(PO);
1122+
if (TypeRepr *TR = info.typeLoc.getTypeRepr()) {
11221123
out << " at ";
11231124
TR->getSourceRange().print(out, getASTContext().SourceMgr, /*text*/false);
11241125
}

unittests/Sema/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ add_swift_unittest(swiftSemaTests
33
SemaFixture.cpp
44
BindingInferenceTests.cpp
55
ConstraintSimplificationTests.cpp
6-
UnresolvedMemberLookupTests.cpp)
6+
UnresolvedMemberLookupTests.cpp
7+
PlaceholderTypeInferenceTests.cpp)
78

89
target_link_libraries(swiftSemaTests
910
PRIVATE
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
//===--- PlaceholderTypeInferenceTests.cpp --------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2021 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "SemaFixture.h"
14+
#include "swift/Sema/ConstraintSystem.h"
15+
16+
using namespace swift;
17+
using namespace swift::unittest;
18+
using namespace swift::constraints;
19+
20+
TEST_F(SemaTest, TestPlaceholderInferenceForArrayLiteral) {
21+
auto *intTypeDecl = getStdlibNominalTypeDecl("Int");
22+
23+
auto *intLiteral = new (Context) IntegerLiteralExpr("0", SourceLoc(), true);
24+
auto *arrayExpr = ArrayExpr::create(Context, SourceLoc(), {intLiteral}, {}, SourceLoc());
25+
26+
auto *placeholderRepr = new (Context) PlaceholderTypeRepr(SourceLoc());
27+
auto *arrayRepr = new (Context) ArrayTypeRepr(placeholderRepr, SourceRange());
28+
auto placeholderTy = PlaceholderType::get(Context, placeholderRepr);
29+
auto *arrayTy = ArraySliceType::get(placeholderTy);
30+
31+
auto *varDecl = new (Context) VarDecl(false, VarDecl::Introducer::Let, SourceLoc(), Context.getIdentifier("x"), DC);
32+
auto *namedPattern = new (Context) NamedPattern(varDecl);
33+
auto *typedPattern = new (Context) TypedPattern(namedPattern, arrayRepr);
34+
35+
auto target = SolutionApplicationTarget::forInitialization(arrayExpr, DC, arrayTy, typedPattern, /*bindPatternVarsOneWay=*/false);
36+
37+
ConstraintSystem cs(DC, ConstraintSystemOptions());
38+
cs.setContextualType(arrayExpr, {arrayRepr, arrayTy}, CTP_Initialization);
39+
cs.generateConstraints(target, FreeTypeVariableBinding::Disallow);
40+
SmallVector<Solution, 2> solutions;
41+
cs.solve(solutions);
42+
43+
// We should have a solution.
44+
ASSERT_EQ(solutions.size(), 1u);
45+
46+
auto &solution = solutions[0];
47+
48+
auto eltTy = ConstraintSystem::isArrayType(solution.simplifyType(solution.getType(arrayExpr)));
49+
ASSERT_TRUE(eltTy.hasValue());
50+
ASSERT_TRUE((*eltTy)->is<StructType>());
51+
ASSERT_EQ((*eltTy)->getAs<StructType>()->getDecl(), intTypeDecl);
52+
}
53+
54+
TEST_F(SemaTest, TestPlaceholderInferenceForDictionaryLiteral) {
55+
auto *intTypeDecl = getStdlibNominalTypeDecl("Int");
56+
auto *stringTypeDecl = getStdlibNominalTypeDecl("String");
57+
58+
auto *intLiteral = new (Context) IntegerLiteralExpr("0", SourceLoc(), true);
59+
auto *stringLiteral = new (Context) StringLiteralExpr("test", SourceRange(), true);
60+
auto *kvTupleExpr = TupleExpr::create(Context, SourceLoc(), {stringLiteral, intLiteral}, {}, {}, SourceLoc(), true);
61+
auto *dictExpr = DictionaryExpr::create(Context, SourceLoc(), {kvTupleExpr}, {}, SourceLoc());
62+
63+
auto *keyPlaceholderRepr = new (Context) PlaceholderTypeRepr(SourceLoc());
64+
auto *valPlaceholderRepr = new (Context) PlaceholderTypeRepr(SourceLoc());
65+
auto *dictRepr = new (Context) DictionaryTypeRepr(keyPlaceholderRepr, valPlaceholderRepr, SourceLoc(), SourceRange());
66+
auto keyPlaceholderTy = PlaceholderType::get(Context, keyPlaceholderRepr);
67+
auto valPlaceholderTy = PlaceholderType::get(Context, valPlaceholderRepr);
68+
auto *dictTy = DictionaryType::get(keyPlaceholderTy, valPlaceholderTy);
69+
70+
auto *varDecl = new (Context) VarDecl(false, VarDecl::Introducer::Let, SourceLoc(), Context.getIdentifier("x"), DC);
71+
auto *namedPattern = new (Context) NamedPattern(varDecl);
72+
auto *typedPattern = new (Context) TypedPattern(namedPattern, dictRepr);
73+
74+
auto target = SolutionApplicationTarget::forInitialization(dictExpr, DC, dictTy, typedPattern, /*bindPatternVarsOneWay=*/false);
75+
76+
ConstraintSystem cs(DC, ConstraintSystemOptions());
77+
cs.setContextualType(dictExpr, {dictRepr, dictTy}, CTP_Initialization);
78+
cs.generateConstraints(target, FreeTypeVariableBinding::Disallow);
79+
SmallVector<Solution, 2> solutions;
80+
cs.solve(solutions);
81+
82+
// We should have a solution.
83+
ASSERT_EQ(solutions.size(), 1u);
84+
85+
auto &solution = solutions[0];
86+
87+
auto keyValTys = ConstraintSystem::isDictionaryType(solution.simplifyType(solution.getType(dictExpr)));
88+
ASSERT_TRUE(keyValTys.hasValue());
89+
90+
Type keyTy;
91+
Type valTy;
92+
std::tie(keyTy, valTy) = *keyValTys;
93+
ASSERT_TRUE(keyTy->is<StructType>());
94+
ASSERT_EQ(keyTy->getAs<StructType>()->getDecl(), stringTypeDecl);
95+
96+
ASSERT_TRUE(valTy->is<StructType>());
97+
ASSERT_EQ(valTy->getAs<StructType>()->getDecl(), intTypeDecl);
98+
}

0 commit comments

Comments
 (0)