Skip to content

Commit 2701012

Browse files
committed
fix some reviewed issues
1 parent fb4bd68 commit 2701012

File tree

6 files changed

+186
-172
lines changed

6 files changed

+186
-172
lines changed

include/swift/AST/ASTContext.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,6 @@ class ASTContext final {
437437
void setStatsReporter(UnifiedStatsReporter *stats);
438438

439439
private:
440-
// get `<` or `==`
441-
FuncDecl *getBinaryComparisonOperatorIntDecl(StringRef op, FuncDecl **cached) const;
442-
443440
friend class TypeChecker;
444441

445442
void installGlobalTypeChecker(TypeChecker *TC);

lib/AST/ASTContext.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,35 +1086,36 @@ ASTContext::getBuiltinInitDecl(NominalTypeDecl *decl,
10861086
return witness;
10871087
}
10881088

1089-
FuncDecl *ASTContext::getBinaryComparisonOperatorIntDecl(StringRef op, FuncDecl **cached) const {
1090-
if (*cached)
1091-
return *cached;
1089+
static
1090+
FuncDecl *getBinaryComparisonOperatorIntDecl(const ASTContext &C, StringRef op, FuncDecl *&cached) {
1091+
if (cached)
1092+
return cached;
10921093

1093-
if (!getIntDecl() || !getBoolDecl())
1094+
if (!C.getIntDecl() || !C.getBoolDecl())
10941095
return nullptr;
10951096

1096-
auto intType = getIntDecl()->getDeclaredType();
1097+
auto intType = C.getIntDecl()->getDeclaredType();
10971098
auto isIntParam = [&](AnyFunctionType::Param param) {
10981099
return (!param.isVariadic() && !param.isInOut() &&
10991100
param.getPlainType()->isEqual(intType));
11001101
};
1101-
auto boolType = getBoolDecl()->getDeclaredType();
1102-
auto decl = lookupOperatorFunc(*this, op,
1103-
intType, [=](FunctionType *type) {
1102+
auto boolType = C.getBoolDecl()->getDeclaredType();
1103+
auto decl = lookupOperatorFunc(C, op, intType,
1104+
[=](FunctionType *type) {
11041105
// Check for the signature: (Int, Int) -> Bool
11051106
if (type->getParams().size() != 2) return false;
11061107
if (!isIntParam(type->getParams()[0]) ||
11071108
!isIntParam(type->getParams()[1])) return false;
11081109
return type->getResult()->isEqual(boolType);
11091110
});
1110-
*cached = decl;
1111+
cached = decl;
11111112
return decl;
11121113
}
11131114
FuncDecl *ASTContext::getLessThanIntDecl() const {
1114-
return getBinaryComparisonOperatorIntDecl("<", &getImpl().LessThanIntDecl);
1115+
return getBinaryComparisonOperatorIntDecl(*this, "<", getImpl().LessThanIntDecl);
11151116
}
11161117
FuncDecl *ASTContext::getEqualIntDecl() const {
1117-
return getBinaryComparisonOperatorIntDecl("==", &getImpl().EqualIntDecl);
1118+
return getBinaryComparisonOperatorIntDecl(*this, "==", getImpl().EqualIntDecl);
11181119
}
11191120

11201121
FuncDecl *ASTContext::getHashValueForDecl() const {

lib/Sema/DerivedConformanceComparable.cpp

Lines changed: 26 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -31,52 +31,6 @@
3131

3232
using namespace swift;
3333

34-
/// Returns a generated guard statement that checks whether the given lhs and
35-
/// rhs expressions are equal. If not equal, the else block for the guard
36-
/// returns lhs < rhs.
37-
/// \p C The AST context.
38-
/// \p lhsExpr The first expression to compare for equality.
39-
/// \p rhsExpr The second expression to compare for equality.
40-
static GuardStmt *returnComparisonIfNotEqualGuard(ASTContext &C,
41-
Expr *lhsExpr,
42-
Expr *rhsExpr) {
43-
SmallVector<StmtConditionElement, 1> conditions;
44-
SmallVector<ASTNode, 1> statements;
45-
46-
// First, generate the statement for the body of the guard.
47-
// return lhs < rhs
48-
auto ltFuncExpr = new (C) UnresolvedDeclRefExpr(
49-
DeclNameRef(C.Id_LessThanOperator), DeclRefKind::BinaryOperator,
50-
DeclNameLoc());
51-
auto ltArgsTuple = TupleExpr::create(C, SourceLoc(),
52-
{ lhsExpr, rhsExpr },
53-
{ }, { }, SourceLoc(),
54-
/*HasTrailingClosure*/false,
55-
/*Implicit*/true);
56-
auto ltExpr = new (C) BinaryExpr(ltFuncExpr, ltArgsTuple, /*Implicit*/true);
57-
auto returnStmt = new (C) ReturnStmt(SourceLoc(), ltExpr);
58-
statements.emplace_back(ASTNode(returnStmt));
59-
60-
// Next, generate the condition being checked.
61-
// lhs == rhs
62-
auto cmpFuncExpr = new (C) UnresolvedDeclRefExpr(
63-
DeclNameRef(C.Id_EqualsOperator), DeclRefKind::BinaryOperator,
64-
DeclNameLoc());
65-
auto cmpArgsTuple = TupleExpr::create(C, SourceLoc(),
66-
{ lhsExpr, rhsExpr },
67-
{ }, { }, SourceLoc(),
68-
/*HasTrailingClosure*/false,
69-
/*Implicit*/true);
70-
auto cmpExpr = new (C) BinaryExpr(cmpFuncExpr, cmpArgsTuple,
71-
/*Implicit*/true);
72-
conditions.emplace_back(cmpExpr);
73-
74-
// Build and return the complete guard statement.
75-
// guard lhs == rhs else { return lhs < rhs }
76-
auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
77-
return new (C) GuardStmt(SourceLoc(), C.AllocateCopy(conditions), body);
78-
}
79-
8034
// how does this code ever even get invoked? you can’t compare uninhabited enums...
8135
static std::pair<BraceStmt *, bool>
8236
deriveBodyComparable_enum_uninhabited_lt(AbstractFunctionDecl *ltDecl, void *) {
@@ -129,9 +83,9 @@ deriveBodyComparable_enum_noAssociatedValues_lt(AbstractFunctionDecl *ltDecl,
12983

13084
// Generate the conversion from the enums to integer indices.
13185
SmallVector<ASTNode, 6> statements;
132-
DeclRefExpr *aIndex = convertEnumToIndex(statements, parentDC, enumDecl,
86+
DeclRefExpr *aIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl,
13387
aParam, ltDecl, "index_a");
134-
DeclRefExpr *bIndex = convertEnumToIndex(statements, parentDC, enumDecl,
88+
DeclRefExpr *bIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl,
13589
bParam, ltDecl, "index_b");
13690

13791
// Generate the compare of the indices.
@@ -199,7 +153,7 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v
199153

200154
// .<elt>(let l0, let l1, ...)
201155
SmallVector<VarDecl*, 3> lhsPayloadVars;
202-
auto lhsSubpattern = enumElementPayloadSubpattern(elt, 'l', ltDecl,
156+
auto lhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'l', ltDecl,
203157
lhsPayloadVars);
204158
auto lhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
205159
SourceLoc(), DeclNameLoc(),
@@ -209,7 +163,7 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v
209163

210164
// .<elt>(let r0, let r1, ...)
211165
SmallVector<VarDecl*, 3> rhsPayloadVars;
212-
auto rhsSubpattern = enumElementPayloadSubpattern(elt, 'r', ltDecl,
166+
auto rhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'r', ltDecl,
213167
rhsPayloadVars);
214168
auto rhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
215169
SourceLoc(), DeclNameLoc(),
@@ -254,7 +208,8 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v
254208
auto rhsVar = rhsPayloadVars[varIdx];
255209
auto rhsExpr = new (C) DeclRefExpr(rhsVar, DeclNameLoc(),
256210
/*Implicit*/true);
257-
auto guardStmt = returnComparisonIfNotEqualGuard(C, lhsExpr, rhsExpr);
211+
auto guardStmt = DerivedConformance::returnComparisonIfNotEqualGuard(C,
212+
lhsExpr, rhsExpr);
258213
statementsInCase.emplace_back(guardStmt);
259214
}
260215

@@ -380,14 +335,17 @@ deriveComparable_lt(
380335

381336
bool
382337
DerivedConformance::canDeriveComparable(DeclContext *context, NominalTypeDecl *declaration) {
338+
auto enumeration = dyn_cast<EnumDecl>(declaration);
383339
// The type must be an enum.
384-
if (EnumDecl *const enumeration = dyn_cast<EnumDecl>(declaration)) {
385-
// The cases must not have non-comparable associated values or raw backing
386-
auto comparable = context->getASTContext().getProtocol(KnownProtocolKind::Comparable);
387-
return allAssociatedValuesConformToProtocol(context, enumeration, comparable) && !enumeration->hasRawType();
388-
} else {
389-
return false;
340+
if (!enumeration) {
341+
return false;
390342
}
343+
auto comparable = context->getASTContext().getProtocol(KnownProtocolKind::Comparable);
344+
if (!comparable) {
345+
return false; // not sure what should be done here instead
346+
}
347+
// The cases must not have non-comparable associated values or raw backing
348+
return allAssociatedValuesConformToProtocol(context, enumeration, comparable) && !enumeration->hasRawType();
391349
}
392350

393351
ValueDecl *DerivedConformance::deriveComparable(ValueDecl *requirement) {
@@ -396,12 +354,17 @@ ValueDecl *DerivedConformance::deriveComparable(ValueDecl *requirement) {
396354
// Build the necessary decl.
397355
if (requirement->getBaseName() == "<") {
398356
if (EnumDecl const *const enumeration = dyn_cast<EnumDecl>(this->Nominal)) {
399-
auto bodySynthesizer = !enumeration->hasCases()
400-
? &deriveBodyComparable_enum_uninhabited_lt
401-
: enumeration->hasOnlyCasesWithoutAssociatedValues()
402-
? &deriveBodyComparable_enum_noAssociatedValues_lt
403-
: &deriveBodyComparable_enum_hasAssociatedValues_lt;
404-
return deriveComparable_lt(*this, bodySynthesizer);
357+
std::pair<BraceStmt *, bool> (*synthesizer)(AbstractFunctionDecl *, void *);
358+
if (enumeration->hasCases()) {
359+
if (enumeration->hasOnlyCasesWithoutAssociatedValues()) {
360+
synthesizer = &deriveBodyComparable_enum_noAssociatedValues_lt;
361+
} else {
362+
synthesizer = &deriveBodyComparable_enum_hasAssociatedValues_lt;
363+
}
364+
} else {
365+
synthesizer = &deriveBodyComparable_enum_uninhabited_lt;
366+
}
367+
return deriveComparable_lt(*this, synthesizer);
405368
} else {
406369
llvm_unreachable("todo");
407370
}

lib/Sema/DerivedConformanceEquatableHashable.cpp

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ static bool canDeriveConformance(DeclContext *DC,
8282
if (auto enumDecl = dyn_cast<EnumDecl>(target)) {
8383
// The cases must not have associated values, or all associated values must
8484
// conform to the protocol.
85-
return allAssociatedValuesConformToProtocol(DC, enumDecl, protocol);
85+
return DerivedConformance::allAssociatedValuesConformToProtocol(DC, enumDecl, protocol);
8686
}
8787

8888
if (auto structDecl = dyn_cast<StructDecl>(target)) {
@@ -101,7 +101,7 @@ void diagnoseFailedDerivation(DeclContext *DC, NominalTypeDecl *nominal,
101101

102102
if (auto *enumDecl = dyn_cast<EnumDecl>(nominal)) {
103103
auto nonconformingAssociatedTypes =
104-
associatedValuesNotConformingToProtocol(DC, enumDecl, protocol);
104+
DerivedConformance::associatedValuesNotConformingToProtocol(DC, enumDecl, protocol);
105105
for (auto *typeToDiagnose : nonconformingAssociatedTypes) {
106106
SourceLoc reprLoc;
107107
if (auto *repr = typeToDiagnose->getTypeRepr())
@@ -135,45 +135,6 @@ void diagnoseFailedDerivation(DeclContext *DC, NominalTypeDecl *nominal,
135135
}
136136
}
137137

138-
/// Returns a generated guard statement that checks whether the given lhs and
139-
/// rhs expressions are equal. If not equal, the else block for the guard
140-
/// returns false.
141-
/// \p C The AST context.
142-
/// \p lhsExpr The first expression to compare for equality.
143-
/// \p rhsExpr The second expression to compare for equality.
144-
static GuardStmt *returnIfNotEqualGuard(ASTContext &C,
145-
Expr *lhsExpr,
146-
Expr *rhsExpr) {
147-
SmallVector<StmtConditionElement, 1> conditions;
148-
SmallVector<ASTNode, 1> statements;
149-
150-
// First, generate the statement for the body of the guard.
151-
// return false
152-
auto falseExpr = new (C) BooleanLiteralExpr(false, SourceLoc(),
153-
/*Implicit*/true);
154-
auto returnStmt = new (C) ReturnStmt(SourceLoc(), falseExpr);
155-
statements.emplace_back(ASTNode(returnStmt));
156-
157-
// Next, generate the condition being checked.
158-
// lhs == rhs
159-
auto cmpFuncExpr = new (C) UnresolvedDeclRefExpr(
160-
DeclNameRef(C.Id_EqualsOperator), DeclRefKind::BinaryOperator,
161-
DeclNameLoc());
162-
auto cmpArgsTuple = TupleExpr::create(C, SourceLoc(),
163-
{ lhsExpr, rhsExpr },
164-
{ }, { }, SourceLoc(),
165-
/*HasTrailingClosure*/false,
166-
/*Implicit*/true);
167-
auto cmpExpr = new (C) BinaryExpr(cmpFuncExpr, cmpArgsTuple,
168-
/*Implicit*/true);
169-
conditions.emplace_back(cmpExpr);
170-
171-
// Build and return the complete guard statement.
172-
// guard lhs == rhs else { return false }
173-
auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
174-
return new (C) GuardStmt(SourceLoc(), C.AllocateCopy(conditions), body);
175-
}
176-
177138
static std::pair<BraceStmt *, bool>
178139
deriveBodyEquatable_enum_uninhabited_eq(AbstractFunctionDecl *eqDecl, void *) {
179140
auto parentDC = eqDecl->getDeclContext();
@@ -225,9 +186,9 @@ deriveBodyEquatable_enum_noAssociatedValues_eq(AbstractFunctionDecl *eqDecl,
225186

226187
// Generate the conversion from the enums to integer indices.
227188
SmallVector<ASTNode, 6> statements;
228-
DeclRefExpr *aIndex = convertEnumToIndex(statements, parentDC, enumDecl,
189+
DeclRefExpr *aIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl,
229190
aParam, eqDecl, "index_a");
230-
DeclRefExpr *bIndex = convertEnumToIndex(statements, parentDC, enumDecl,
191+
DeclRefExpr *bIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl,
231192
bParam, eqDecl, "index_b");
232193

233194
// Generate the compare of the indices.
@@ -296,7 +257,7 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl,
296257

297258
// .<elt>(let l0, let l1, ...)
298259
SmallVector<VarDecl*, 3> lhsPayloadVars;
299-
auto lhsSubpattern = enumElementPayloadSubpattern(elt, 'l', eqDecl,
260+
auto lhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'l', eqDecl,
300261
lhsPayloadVars);
301262
auto lhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
302263
SourceLoc(), DeclNameLoc(),
@@ -306,7 +267,7 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl,
306267

307268
// .<elt>(let r0, let r1, ...)
308269
SmallVector<VarDecl*, 3> rhsPayloadVars;
309-
auto rhsSubpattern = enumElementPayloadSubpattern(elt, 'r', eqDecl,
270+
auto rhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'r', eqDecl,
310271
rhsPayloadVars);
311272
auto rhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
312273
SourceLoc(), DeclNameLoc(),
@@ -352,7 +313,8 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl,
352313
auto rhsVar = rhsPayloadVars[varIdx];
353314
auto rhsExpr = new (C) DeclRefExpr(rhsVar, DeclNameLoc(),
354315
/*Implicit*/true);
355-
auto guardStmt = returnIfNotEqualGuard(C, lhsExpr, rhsExpr);
316+
auto guardStmt = DerivedConformance::returnFalseIfNotEqualGuard(C,
317+
lhsExpr, rhsExpr);
356318
statementsInCase.emplace_back(guardStmt);
357319
}
358320

@@ -438,7 +400,8 @@ deriveBodyEquatable_struct_eq(AbstractFunctionDecl *eqDecl, void *) {
438400
auto bPropertyExpr = new (C) DotSyntaxCallExpr(bPropertyRef, SourceLoc(),
439401
bParamRef);
440402

441-
auto guardStmt = returnIfNotEqualGuard(C, aPropertyExpr, bPropertyExpr);
403+
auto guardStmt = DerivedConformance::returnFalseIfNotEqualGuard(C,
404+
aPropertyExpr, bPropertyExpr);
442405
statements.emplace_back(guardStmt);
443406
}
444407

@@ -763,7 +726,7 @@ deriveBodyHashable_enum_noAssociatedValues_hashInto(
763726

764727
// generate: switch self {...}
765728
SmallVector<ASTNode, 3> stmts;
766-
auto discriminatorExpr = convertEnumToIndex(stmts, parentDC, enumDecl,
729+
auto discriminatorExpr = DerivedConformance::convertEnumToIndex(stmts, parentDC, enumDecl,
767730
selfDecl, hashIntoDecl,
768731
"discriminator");
769732
// generate: hasher.combine(discriminator)
@@ -818,7 +781,7 @@ deriveBodyHashable_enum_hasAssociatedValues_hashInto(
818781
SmallVector<VarDecl*, 3> payloadVars;
819782
SmallVector<ASTNode, 3> statements;
820783

821-
auto payloadPattern = enumElementPayloadSubpattern(elt, 'a', hashIntoDecl,
784+
auto payloadPattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'a', hashIntoDecl,
822785
payloadVars);
823786
auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
824787
SourceLoc(), DeclNameLoc(),

0 commit comments

Comments
 (0)