Skip to content

Commit 85c82b2

Browse files
authored
Merge pull request #42021 from nkcsgexi/enum-in-array-const
2 parents e2a394b + 6206d7c commit 85c82b2

File tree

6 files changed

+51
-23
lines changed

6 files changed

+51
-23
lines changed

include/swift/AST/Expr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ class alignas(8) Expr : public ASTAllocated<Expr> {
496496
return getSemanticsProvidingExpr()->getKind() == ExprKind::InOut;
497497
}
498498

499-
bool printConstExprValue(llvm::raw_ostream *OS) const;
500-
bool isSemanticallyConstExpr() const;
499+
bool printConstExprValue(llvm::raw_ostream *OS, llvm::function_ref<bool(Expr*)> additionalCheck) const;
500+
bool isSemanticallyConstExpr(llvm::function_ref<bool(Expr*)> additionalCheck = nullptr) const;
501501

502502
/// Returns false if this expression needs to be wrapped in parens when
503503
/// used inside of a any postfix expression, true otherwise.

lib/APIDigester/ModuleAnalyzerNodes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2265,7 +2265,7 @@ class ConstExtractor: public ASTWalker {
22652265
void record(Expr *E, Expr *ValueProvider, StringRef ReferecedD = "") {
22662266
std::string content;
22672267
llvm::raw_string_ostream os(content);
2268-
ValueProvider->printConstExprValue(&os);
2268+
ValueProvider->printConstExprValue(&os, nullptr);
22692269
assert(!content.empty());
22702270
auto buffered = SCtx.buffer(content);
22712271
switch(ValueProvider->getKind()) {

lib/AST/Expr.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ Expr *Expr::getSemanticsProvidingExpr() {
198198
return this;
199199
}
200200

201-
bool Expr::printConstExprValue(llvm::raw_ostream *OS) const {
201+
bool Expr::printConstExprValue(llvm::raw_ostream *OS,
202+
llvm::function_ref<bool(Expr*)> additionalCheck) const {
202203
auto print = [&](StringRef text) {
203204
if (OS) {
204205
*OS << text;
@@ -242,7 +243,7 @@ bool Expr::printConstExprValue(llvm::raw_ostream *OS) const {
242243
for (unsigned N = CE->getNumElements(), I = 0; I != N; I ++) {
243244
auto Ele = CE->getElement(I);
244245
auto needComma = I + 1 != N;
245-
if (!Ele->printConstExprValue(OS)) {
246+
if (!Ele->printConstExprValue(OS, additionalCheck)) {
246247
return false;
247248
}
248249
if (needComma)
@@ -257,7 +258,7 @@ bool Expr::printConstExprValue(llvm::raw_ostream *OS) const {
257258
for (unsigned N = TE->getNumElements(), I = 0; I != N; I ++) {
258259
auto Ele = TE->getElement(I);
259260
auto needComma = I + 1 != N;
260-
if (!Ele->printConstExprValue(OS)) {
261+
if (!Ele->printConstExprValue(OS, additionalCheck)) {
261262
return false;
262263
}
263264
if (needComma)
@@ -267,12 +268,13 @@ bool Expr::printConstExprValue(llvm::raw_ostream *OS) const {
267268
return true;
268269
}
269270
default:
270-
return false;
271+
return additionalCheck && additionalCheck(const_cast<Expr*>(this));
271272
}
272273
}
273274

274-
bool Expr::isSemanticallyConstExpr() const {
275-
return printConstExprValue(nullptr);
275+
bool Expr::isSemanticallyConstExpr(
276+
llvm::function_ref<bool(Expr*)> additionalCheck) const {
277+
return printConstExprValue(nullptr, additionalCheck);
276278
}
277279

278280
Expr *Expr::getValueProvidingExpr() {

lib/Sema/CSFix.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,14 +1143,26 @@ NotCompileTimeConst::create(ConstraintSystem &cs, Type paramTy,
11431143

11441144
bool NotCompileTimeConst::diagnose(const Solution &solution, bool asNote) const {
11451145
auto *locator = getLocator();
1146-
// Referencing an enum element directly is considered a compile-time literal.
1147-
if (auto *d = solution.resolveLocatorToDecl(locator).getDecl()) {
1148-
if (isa<EnumElementDecl>(d)) {
1149-
if (!d->hasParameterList()) {
1150-
return true;
1146+
if (auto *E = getAsExpr(locator->getAnchor())) {
1147+
auto isAccepted = E->isSemanticallyConstExpr([&](Expr *E) {
1148+
if (auto *UMC = dyn_cast<UnresolvedMemberChainResultExpr>(E)) {
1149+
E = UMC->getSubExpr();
11511150
}
1152-
}
1151+
auto locator = solution.getConstraintSystem().getConstraintLocator(E);
1152+
// Referencing an enum element directly is considered a compile-time literal.
1153+
if (auto *d = solution.resolveLocatorToDecl(locator).getDecl()) {
1154+
if (isa<EnumElementDecl>(d)) {
1155+
if (!d->hasParameterList()) {
1156+
return true;
1157+
}
1158+
}
1159+
}
1160+
return false;
1161+
});
1162+
if (isAccepted)
1163+
return true;
11531164
}
1165+
11541166
NotCompileTimeConstFailure failure(solution, locator);
11551167
return failure.diagnose(asNote);
11561168
}

lib/Sema/CSSimplify.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,14 +1961,8 @@ static ConstraintSystem::TypeMatchResult matchCallArguments(
19611961
auto *locator = cs.getConstraintLocator(loc);
19621962
SourceRange range;
19631963
// simplify locator so the anchor is the exact argument.
1964-
locator = simplifyLocator(cs, locator, range);
1965-
if (locator->getPath().empty() &&
1966-
locator->getAnchor().isExpr(ExprKind::UnresolvedMemberChainResult)) {
1967-
locator =
1968-
cs.getConstraintLocator(cast<UnresolvedMemberChainResultExpr>(
1969-
locator->getAnchor().get<Expr*>())->getSubExpr());
1970-
}
1971-
cs.recordFix(NotCompileTimeConst::create(cs, paramTy, locator));
1964+
cs.recordFix(NotCompileTimeConst::create(cs, paramTy,
1965+
simplifyLocator(cs, locator, range)));
19721966
}
19731967

19741968
cs.addConstraint(
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
enum E {
4+
case a
5+
case b
6+
7+
var z: E { .b }
8+
}
9+
10+
func getAE() -> E { return .a }
11+
12+
func test_without_const(_ : [E]) {}
13+
func testArr(_ : _const [E]) {}
14+
15+
testArr([])
16+
testArr([.a])
17+
testArr([.a, .b])
18+
19+
testArr([getAE()]) // expected-error {{expect a compile-time constant literal}}
20+
testArr([.a, .b, .a.z]) // expected-error {{expect a compile-time constant literal}}

0 commit comments

Comments
 (0)