Skip to content

Commit f01ccbf

Browse files
authored
Merge pull request #58986 from xedin/multi-statement-closure-improvements-5.7
[5.7][ConstraintSystem] A couple of improvements to multi-statement closure handling
2 parents 56f91f4 + 7ca7811 commit f01ccbf

File tree

10 files changed

+170
-28
lines changed

10 files changed

+170
-28
lines changed

include/swift/Sema/CSFix.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,8 @@ class RelabelArguments final
528528

529529
bool diagnose(const Solution &solution, bool asNote = false) const override;
530530

531+
bool diagnoseForAmbiguity(CommonFixesArray commonFixes) const override;
532+
531533
static RelabelArguments *create(ConstraintSystem &cs,
532534
llvm::ArrayRef<Identifier> correctLabels,
533535
ConstraintLocator *locator);

lib/Sema/CSClosure.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ Expr *getVoidExpr(ASTContext &ctx) {
3737

3838
/// Find any type variable references inside of an AST node.
3939
class TypeVariableRefFinder : public ASTWalker {
40+
/// A stack of all closures the walker encountered so far.
41+
SmallVector<DeclContext *> ClosureDCs;
42+
4043
ConstraintSystem &CS;
4144
ASTNode Parent;
4245

@@ -46,9 +49,16 @@ class TypeVariableRefFinder : public ASTWalker {
4649
TypeVariableRefFinder(
4750
ConstraintSystem &cs, ASTNode parent,
4851
llvm::SmallPtrSetImpl<TypeVariableType *> &referencedVars)
49-
: CS(cs), Parent(parent), ReferencedVars(referencedVars) {}
52+
: CS(cs), Parent(parent), ReferencedVars(referencedVars) {
53+
if (auto *closure = getAsExpr<ClosureExpr>(Parent))
54+
ClosureDCs.push_back(closure);
55+
}
5056

5157
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
58+
if (auto *closure = dyn_cast<ClosureExpr>(expr)) {
59+
ClosureDCs.push_back(closure);
60+
}
61+
5262
if (auto *DRE = dyn_cast<DeclRefExpr>(expr)) {
5363
auto *decl = DRE->getDecl();
5464

@@ -81,20 +91,33 @@ class TypeVariableRefFinder : public ASTWalker {
8191
return {true, expr};
8292
}
8393

94+
Expr *walkToExprPost(Expr *expr) override {
95+
if (auto *closure = dyn_cast<ClosureExpr>(expr)) {
96+
ClosureDCs.pop_back();
97+
}
98+
return expr;
99+
}
100+
84101
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
85102
// Return statements have to reference outside result type
86103
// since all of them are joined by it if it's not specified
87104
// explicitly.
88105
if (isa<ReturnStmt>(stmt)) {
89106
if (auto *closure = getAsExpr<ClosureExpr>(Parent)) {
90-
inferVariables(CS.getClosureType(closure)->getResult());
107+
// Return is only viable if it belongs to a parent closure.
108+
if (currentClosureDC() == closure)
109+
inferVariables(CS.getClosureType(closure)->getResult());
91110
}
92111
}
93112

94113
return {true, stmt};
95114
}
96115

97116
private:
117+
DeclContext *currentClosureDC() const {
118+
return ClosureDCs.empty() ? nullptr : ClosureDCs.back();
119+
}
120+
98121
void inferVariables(Type type) {
99122
type = type->getWithoutSpecifierType();
100123
// Record the type variable itself because it has to
@@ -440,7 +463,7 @@ class ClosureConstraintGenerator
440463

441464
cs.addConstraint(
442465
ConstraintKind::Conversion, elementType, initType,
443-
cs.getConstraintLocator(contextualLocator,
466+
cs.getConstraintLocator(sequenceLocator,
444467
ConstraintLocator::SequenceElementType));
445468

446469
// Reference the makeIterator witness.
@@ -449,9 +472,10 @@ class ClosureConstraintGenerator
449472

450473
Type makeIteratorType =
451474
cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
452-
cs.addValueWitnessConstraint(LValueType::get(sequenceType), makeIterator,
453-
makeIteratorType, closure,
454-
FunctionRefKind::Compound, contextualLocator);
475+
cs.addValueWitnessConstraint(
476+
LValueType::get(sequenceType), makeIterator, makeIteratorType,
477+
closure, FunctionRefKind::Compound,
478+
cs.getConstraintLocator(sequenceLocator, ConstraintLocator::Witness));
455479

456480
// After successful constraint generation, let's record
457481
// solution application target with all relevant information.
@@ -1176,10 +1200,6 @@ class ClosureConstraintApplication
11761200
if (isa<IfConfigDecl>(decl))
11771201
return;
11781202

1179-
// Variable declaration would be handled by a pattern binding.
1180-
if (isa<VarDecl>(decl))
1181-
return;
1182-
11831203
// Generate constraints for pattern binding declarations.
11841204
if (auto patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
11851205
SolutionApplicationTarget target(patternBinding);

lib/Sema/CSDiagnostics.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,6 +1188,12 @@ class MissingMemberFailure : public InvalidMemberRefFailure {
11881188
: InvalidMemberRefFailure(solution, baseType, memberName, locator) {}
11891189

11901190
SourceLoc getLoc() const override {
1191+
auto *locator = getLocator();
1192+
1193+
if (locator->findLast<LocatorPathElt::ClosureBodyElement>()) {
1194+
return constraints::getLoc(getAnchor());
1195+
}
1196+
11911197
// Diagnostic should point to the member instead of its base expression.
11921198
return constraints::getLoc(getRawAnchor());
11931199
}

lib/Sema/CSFix.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,38 @@ bool RelabelArguments::diagnose(const Solution &solution, bool asNote) const {
242242
return failure.diagnose(asNote);
243243
}
244244

245+
bool RelabelArguments::diagnoseForAmbiguity(
246+
CommonFixesArray commonFixes) const {
247+
SmallPtrSet<ValueDecl *, 4> overloadChoices;
248+
249+
// First, let's find overload choice associated with each
250+
// re-labeling fix.
251+
for (const auto &fix : commonFixes) {
252+
auto &solution = *fix.first;
253+
254+
auto calleeLocator = solution.getCalleeLocator(getLocator());
255+
if (!calleeLocator)
256+
return false;
257+
258+
auto overloadChoice = solution.getOverloadChoiceIfAvailable(calleeLocator);
259+
if (!overloadChoice)
260+
return false;
261+
262+
auto *decl = overloadChoice->choice.getDeclOrNull();
263+
if (!decl)
264+
return false;
265+
266+
(void)overloadChoices.insert(decl);
267+
}
268+
269+
// If all of the fixes point to the same overload choice then it's
270+
// exactly the same issue since the call site is static.
271+
if (overloadChoices.size() == 1)
272+
return diagnose(*commonFixes.front().first);
273+
274+
return false;
275+
}
276+
245277
RelabelArguments *
246278
RelabelArguments::create(ConstraintSystem &cs,
247279
llvm::ArrayRef<Identifier> correctLabels,

lib/Sema/CSStep.cpp

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -942,13 +942,10 @@ StepResult ConjunctionStep::resume(bool prevFailed) {
942942
// and scoring information.
943943
Snapshot.reset();
944944

945-
// Restore original scores of outer context before
946-
// trying to produce a combined solution.
947-
restoreOriginalScores();
948-
949945
// Apply all of the information deduced from the
950946
// conjunction (up to the point of ambiguity)
951947
// back to the outer context and form a joined solution.
948+
unsigned numSolutions = 0;
952949
for (auto &solution : Solutions) {
953950
ConstraintSystem::SolverScope scope(CS);
954951

@@ -958,24 +955,47 @@ StepResult ConjunctionStep::resume(bool prevFailed) {
958955
// of the constraint system, so they have to be
959956
// restored right afterwards because score of the
960957
// element does contribute to the overall score.
961-
restoreOriginalScores();
958+
restoreBestScore();
959+
restoreCurrentScore(solution.getFixedScore());
960+
961+
// Transform all of the unbound outer variables into
962+
// placeholders since we are not going to solve for
963+
// each ambguous solution.
964+
{
965+
unsigned numHoles = 0;
966+
for (auto *typeVar : CS.getTypeVariables()) {
967+
if (!typeVar->getImpl().hasRepresentativeOrFixed()) {
968+
CS.assignFixedType(
969+
typeVar, PlaceholderType::get(CS.getASTContext(), typeVar));
970+
++numHoles;
971+
}
972+
}
973+
CS.increaseScore(SK_Hole, numHoles);
974+
}
975+
976+
if (CS.worseThanBestSolution())
977+
continue;
962978

963979
// Note that `worseThanBestSolution` isn't checked
964980
// here because `Solutions` were pre-filtered, and
965981
// outer score is the same for all of them.
966982
OuterSolutions.push_back(CS.finalize());
983+
++numSolutions;
967984
}
968985

969-
return done(/*isSuccess=*/true);
986+
return done(/*isSuccess=*/numSolutions > 0);
970987
}
971988

989+
auto solution = Solutions.pop_back_val();
990+
auto score = solution.getFixedScore();
991+
972992
// Restore outer type variables and prepare to solve
973993
// constraints associated with outer context together
974994
// with information deduced from the conjunction.
975-
Snapshot->setupOuterContext(Solutions.pop_back_val());
995+
Snapshot->setupOuterContext(std::move(solution));
976996

977-
// Pretend that conjunction never happend.
978-
restoreOuterState();
997+
// Pretend that conjunction never happened.
998+
restoreOuterState(score);
979999

9801000
// Now that all of the information from the conjunction has
9811001
// been applied, let's attempt to solve the outer scope.
@@ -987,10 +1007,11 @@ StepResult ConjunctionStep::resume(bool prevFailed) {
9871007
return take(prevFailed);
9881008
}
9891009

990-
void ConjunctionStep::restoreOuterState() const {
1010+
void ConjunctionStep::restoreOuterState(const Score &solutionScore) const {
9911011
// Restore best/current score, since upcoming step is going to
9921012
// work with outer scope in relation to the conjunction.
993-
restoreOriginalScores();
1013+
restoreBestScore();
1014+
restoreCurrentScore(solutionScore);
9941015

9951016
// Active all of the previously out-of-scope constraints
9961017
// because conjunction can propagate type information up

lib/Sema/CSStep.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -962,8 +962,11 @@ class ConjunctionStep : public BindingStep<ConjunctionElementProducer> {
962962

963963
// Restore best score only if conjunction fails because
964964
// successful outcome should keep a score set by `restoreOuterState`.
965-
if (HadFailure)
966-
restoreOriginalScores();
965+
if (HadFailure) {
966+
auto solutionScore = Score();
967+
restoreBestScore();
968+
restoreCurrentScore(solutionScore);
969+
}
967970

968971
if (OuterTimeRemaining) {
969972
auto anchor = OuterTimeRemaining->first;
@@ -1015,16 +1018,19 @@ class ConjunctionStep : public BindingStep<ConjunctionElementProducer> {
10151018

10161019
private:
10171020
/// Restore best and current scores as they were before conjunction.
1018-
void restoreOriginalScores() const {
1019-
CS.solverState->BestScore = BestScore;
1021+
void restoreCurrentScore(const Score &solutionScore) const {
10201022
CS.CurrentScore = CurrentScore;
1023+
CS.increaseScore(SK_Fix, solutionScore.Data[SK_Fix]);
1024+
CS.increaseScore(SK_Hole, solutionScore.Data[SK_Hole]);
10211025
}
10221026

1027+
void restoreBestScore() const { CS.solverState->BestScore = BestScore; }
1028+
10231029
// Restore constraint system state before conjunction.
10241030
//
10251031
// Note that this doesn't include conjunction constraint
10261032
// itself because we don't want to re-solve it.
1027-
void restoreOuterState() const;
1033+
void restoreOuterState(const Score &solutionScore) const;
10281034
};
10291035

10301036
} // end namespace constraints

test/Constraints/closures.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,6 @@ func test(arr: [[Int]]) {
11701170
}
11711171

11721172
arr.map { ($0 as? [Int]).map { A($0) } } // expected-error {{missing argument label 'arg:' in call}} {{36-36=arg: }}
1173-
// expected-warning@-1 {{conditional cast from '[Int]' to '[Int]' always succeeds}}
11741173
}
11751174

11761175
func closureWithCaseArchetype<T>(_: T.Type) {

test/Constraints/generics.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,8 @@ func rdar78781552() {
902902
// expected-error@-1 {{generic struct 'Test' requires that '(((Int) throws -> Bool) throws -> [Int])?' conform to 'RandomAccessCollection'}}
903903
// expected-error@-2 {{generic parameter 'Content' could not be inferred}} expected-note@-2 {{explicitly specify the generic arguments to fix this issue}}
904904
// expected-error@-3 {{cannot convert value of type '(((Int) throws -> Bool) throws -> [Int])?' to expected argument type '[(((Int) throws -> Bool) throws -> [Int])?]'}}
905-
// expected-error@-4 {{missing argument for parameter 'filter' in call}}
905+
// expected-error@-4 {{missing argument label 'data:' in call}}
906+
// expected-error@-5 {{missing argument for parameter 'filter' in call}}
906907
}
907908
}
908909

test/expr/closure/multi_statement.swift

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,20 @@ func test_local_function_capturing_vars() {
234234
}
235235
}
236236

237+
func test_test_invalid_redeclaration() {
238+
func test(_: () -> Void) {
239+
}
240+
241+
test {
242+
let foo = 0 // expected-note {{'foo' previously declared here}}
243+
let foo = foo // expected-error {{invalid redeclaration of 'foo'}}
244+
}
245+
246+
test {
247+
let (foo, foo) = (5, 6) // expected-error {{invalid redeclaration of 'foo'}} expected-note {{'foo' previously declared here}}
248+
}
249+
}
250+
237251
func test_pattern_ambiguity_doesnot_crash_compiler() {
238252
enum E {
239253
case hello(result: Int) // expected-note 2 {{found this candidate}}
@@ -350,3 +364,43 @@ func test_no_crash_with_circular_ref_due_to_error() {
350364
return 0
351365
}
352366
}
367+
368+
func test_diagnosing_on_missing_member_in_case() {
369+
enum E {
370+
case one
371+
}
372+
373+
func test(_: (E) -> Void) {}
374+
375+
test {
376+
switch $0 {
377+
case .one: break
378+
case .unknown: break // expected-error {{type 'E' has no member 'unknown'}}
379+
}
380+
}
381+
}
382+
383+
// Type finder shouldn't bring external closure result type
384+
// into the scope of an inner closure e.g. while solving
385+
// init of pattern binding `x`.
386+
func test_type_finder_doesnt_walk_into_inner_closures() {
387+
func test<T>(fn: () -> T) -> T { fn() }
388+
389+
_ = test { // Ok
390+
let x = test {
391+
42
392+
}
393+
394+
let _ = test {
395+
test { "" }
396+
}
397+
398+
// multi-statement
399+
let _ = test {
400+
_ = 42
401+
return test { "" }
402+
}
403+
404+
return x
405+
}
406+
}

test/expr/unary/keypath/keypath.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ func testLabeledSubscript() {
507507
// TODO: These ought to work without errors.
508508
let _ = \AA.[keyPath: k]
509509
// expected-error@-1 {{cannot convert value of type 'KeyPath<AA, Int>' to expected argument type 'Int'}}
510+
// expected-error@-2 {{extraneous argument label 'keyPath:' in call}}
510511

511512
let _ = \AA.[keyPath: \AA.[labeled: 0]] // expected-error {{extraneous argument label 'keyPath:' in call}}
512513
// expected-error@-1 {{cannot convert value of type 'KeyPath<AA, Int>' to expected argument type 'Int'}}

0 commit comments

Comments
 (0)