Skip to content

Commit 49e7c61

Browse files
authored
[Diagnostics] Diagnose comparisons with '.nan' and suggest using '.isNan' instead (swiftlang#33860)
* [AST] Add 'FloatingPoint' known protocol kind * [Sema] Emit a diagnostic for comparisons with '.nan' instead of using '.isNan' * [Sema] Update '.nan' comparison diagnostic wording * [Sema] Explicitly check for either arguments to be '.nan' and tweak a comment * [Test] Tweak some comments * [Diagnostic] Change 'isNan' to 'isNaN' * [Sema] Fix a bug where firstArg was checked twice for FloatingPoint conformance and update some comments * [Test] Fix comments in test file * [NFC] Add a new 'isStandardComparisonOperator' method to 'Identifier' and use it in ConstraintSystem * [NFC] Reuse argument decl extraction code and switch over to the new 'isStandardComparisonOperator' method * [NFC] Update conformsToKnownProtocol to accept DeclContext and use it to check for FloatingPoint conformance
1 parent f9c70b4 commit 49e7c61

File tree

10 files changed

+191
-13
lines changed

10 files changed

+191
-13
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3351,6 +3351,16 @@ ERROR(unordered_adjacent_operators,none,
33513351
ERROR(missing_builtin_precedence_group,none,
33523352
"broken standard library: missing builtin precedence group %0",
33533353
(Identifier))
3354+
WARNING(nan_comparison, none,
3355+
"comparison with '.nan' using %0 is always %select{false|true}1, use "
3356+
"'%2.isNaN' to check if '%3' %select{is not a number|is a number}1",
3357+
(Identifier, bool, StringRef, StringRef))
3358+
WARNING(nan_comparison_without_isnan, none,
3359+
"comparison with '.nan' using %0 is always %select{false|true}1",
3360+
(Identifier, bool))
3361+
WARNING(nan_comparison_both_nan, none,
3362+
"'.nan' %0 '.nan' is always %select{false|true}1",
3363+
(StringRef, bool))
33543364

33553365
// If you change this, also change enum TryKindForDiagnostics.
33563366
#define TRY_KIND_SELECT(SUB) "%select{try|try!|try?|await}" #SUB

include/swift/AST/Identifier.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,14 @@ class Identifier {
109109
// Handle the high unicode case out of line.
110110
return isOperatorSlow();
111111
}
112-
112+
113+
// Returns whether this is a standard comparison operator,
114+
// such as '==', '>=' or '!=='.
115+
bool isStandardComparisonOperator() const {
116+
return is("==") || is("!=") || is("===") || is("!==") || is("<") ||
117+
is(">") || is("<=") || is(">=");
118+
}
119+
113120
/// isOperatorStartCodePoint - Return true if the specified code point is a
114121
/// valid start of an operator.
115122
static bool isOperatorStartCodePoint(uint32_t C) {

include/swift/AST/KnownProtocols.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ PROTOCOL(StringInterpolationProtocol)
8787
PROTOCOL(AdditiveArithmetic)
8888
PROTOCOL(Differentiable)
8989

90+
PROTOCOL(FloatingPoint)
91+
9092
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)
9193
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByBooleanLiteral, "BooleanLiteralType", true)
9294
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByDictionaryLiteral, "Dictionary", false)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5044,6 +5044,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
50445044
case KnownProtocolKind::StringInterpolationProtocol:
50455045
case KnownProtocolKind::AdditiveArithmetic:
50465046
case KnownProtocolKind::Differentiable:
5047+
case KnownProtocolKind::FloatingPoint:
50475048
return SpecialProtocol::None;
50485049
}
50495050

lib/Sema/CSDiagnostics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ Type FailureDiagnostic::restoreGenericParameters(
165165
bool FailureDiagnostic::conformsToKnownProtocol(
166166
Type type, KnownProtocolKind protocol) const {
167167
auto &cs = getConstraintSystem();
168-
return constraints::conformsToKnownProtocol(cs, type, protocol);
168+
return constraints::conformsToKnownProtocol(cs.DC, type, protocol);
169169
}
170170

171171
Type RequirementFailure::getOwnerType() const {

lib/Sema/CSGen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,11 +1767,11 @@ namespace {
17671767

17681768
auto type = contextualType->lookThroughAllOptionalTypes();
17691769
if (conformsToKnownProtocol(
1770-
CS, type, KnownProtocolKind::ExpressibleByArrayLiteral))
1770+
CS.DC, type, KnownProtocolKind::ExpressibleByArrayLiteral))
17711771
return false;
17721772

17731773
return conformsToKnownProtocol(
1774-
CS, type, KnownProtocolKind::ExpressibleByDictionaryLiteral);
1774+
CS.DC, type, KnownProtocolKind::ExpressibleByDictionaryLiteral);
17751775
};
17761776

17771777
if (isDictionaryContextualType(contextualType)) {

lib/Sema/ConstraintSystem.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4173,11 +4173,11 @@ bool constraints::hasAppliedSelf(const OverloadChoice &choice,
41734173
doesMemberRefApplyCurriedSelf(baseType, decl);
41744174
}
41754175

4176-
bool constraints::conformsToKnownProtocol(ConstraintSystem &cs, Type type,
4176+
bool constraints::conformsToKnownProtocol(DeclContext *dc, Type type,
41774177
KnownProtocolKind protocol) {
41784178
if (auto *proto =
4179-
TypeChecker::getProtocol(cs.getASTContext(), SourceLoc(), protocol))
4180-
return (bool)TypeChecker::conformsToProtocol(type, proto, cs.DC);
4179+
TypeChecker::getProtocol(dc->getASTContext(), SourceLoc(), protocol))
4180+
return (bool)TypeChecker::conformsToProtocol(type, proto, dc);
41814181
return false;
41824182
}
41834183

@@ -4202,7 +4202,8 @@ Type constraints::isRawRepresentable(
42024202
ConstraintSystem &cs, Type type,
42034203
KnownProtocolKind rawRepresentableProtocol) {
42044204
Type rawTy = isRawRepresentable(cs, type);
4205-
if (!rawTy || !conformsToKnownProtocol(cs, rawTy, rawRepresentableProtocol))
4205+
if (!rawTy ||
4206+
!conformsToKnownProtocol(cs.DC, rawTy, rawRepresentableProtocol))
42064207
return Type();
42074208

42084209
return rawTy;
@@ -4496,9 +4497,7 @@ bool constraints::isStandardComparisonOperator(ASTNode node) {
44964497
if (!expr) return false;
44974498

44984499
if (auto opName = getOperatorName(expr)) {
4499-
return opName->is("==") || opName->is("!=") || opName->is("===") ||
4500-
opName->is("!==") || opName->is("<") || opName->is(">") ||
4501-
opName->is("<=") || opName->is(">=");
4500+
return opName->isStandardComparisonOperator();
45024501
}
45034502
return false;
45044503
}

lib/Sema/ConstraintSystem.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5521,7 +5521,7 @@ bool hasAppliedSelf(const OverloadChoice &choice,
55215521
llvm::function_ref<Type(Type)> getFixedType);
55225522

55235523
/// Check whether type conforms to a given known protocol.
5524-
bool conformsToKnownProtocol(ConstraintSystem &cs, Type type,
5524+
bool conformsToKnownProtocol(DeclContext *dc, Type type,
55255525
KnownProtocolKind protocol);
55265526

55275527
/// Check whether given type conforms to `RawPepresentable` protocol

lib/Sema/MiscDiagnostics.cpp

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
//===----------------------------------------------------------------------===//
1616

1717
#include "MiscDiagnostics.h"
18-
#include "TypeChecker.h"
18+
#include "ConstraintSystem.h"
1919
#include "TypeCheckAvailability.h"
20+
#include "TypeChecker.h"
2021
#include "swift/AST/ASTWalker.h"
2122
#include "swift/AST/NameLookup.h"
2223
#include "swift/AST/NameLookupRequests.h"
@@ -34,6 +35,7 @@
3435

3536
#define DEBUG_TYPE "Sema"
3637
using namespace swift;
38+
using namespace constraints;
3739

3840
/// Return true if this expression is an implicit promotion from T to T?.
3941
static Expr *isImplicitPromotionToOptional(Expr *E) {
@@ -4438,6 +4440,131 @@ static void diagnoseExplicitUseOfLazyVariableStorage(const Expr *E,
44384440
const_cast<Expr *>(E)->walk(Walker);
44394441
}
44404442

4443+
static void diagnoseComparisonWithNaN(const Expr *E, const DeclContext *DC) {
4444+
class ComparisonWithNaNFinder : public ASTWalker {
4445+
const ASTContext &C;
4446+
const DeclContext *DC;
4447+
4448+
public:
4449+
ComparisonWithNaNFinder(const DeclContext *dc)
4450+
: C(dc->getASTContext()), DC(dc) {}
4451+
4452+
void tryDiagnoseComparisonWithNaN(BinaryExpr *BE) {
4453+
ValueDecl *comparisonDecl = nullptr;
4454+
4455+
// Comparison functions like == or <= take two arguments.
4456+
if (BE->getArg()->getNumElements() != 2) {
4457+
return;
4458+
}
4459+
4460+
// Dig out the function declaration.
4461+
if (auto Fn = BE->getFn()) {
4462+
if (auto DSCE = dyn_cast<DotSyntaxCallExpr>(Fn)) {
4463+
comparisonDecl = DSCE->getCalledValue();
4464+
} else {
4465+
comparisonDecl = BE->getCalledValue();
4466+
}
4467+
}
4468+
4469+
// Bail out if it isn't a function.
4470+
if (!comparisonDecl || !isa<FuncDecl>(comparisonDecl)) {
4471+
return;
4472+
}
4473+
4474+
// We're only interested in comparison functions like == or <=.
4475+
auto comparisonDeclName = comparisonDecl->getBaseIdentifier();
4476+
if (!comparisonDeclName.isStandardComparisonOperator()) {
4477+
return;
4478+
}
4479+
4480+
auto firstArg = BE->getArg()->getElement(0);
4481+
auto secondArg = BE->getArg()->getElement(1);
4482+
4483+
// Both arguments must conform to FloatingPoint protocol.
4484+
if (!conformsToKnownProtocol(const_cast<DeclContext *>(DC),
4485+
firstArg->getType(),
4486+
KnownProtocolKind::FloatingPoint) ||
4487+
!conformsToKnownProtocol(const_cast<DeclContext *>(DC),
4488+
secondArg->getType(),
4489+
KnownProtocolKind::FloatingPoint)) {
4490+
return;
4491+
}
4492+
4493+
// Convenience utility to extract argument decl.
4494+
auto extractArgumentDecl = [&](Expr *arg) -> ValueDecl * {
4495+
if (auto DRE = dyn_cast<DeclRefExpr>(arg)) {
4496+
return DRE->getDecl();
4497+
} else if (auto MRE = dyn_cast<MemberRefExpr>(arg)) {
4498+
return MRE->getMember().getDecl();
4499+
}
4500+
return nullptr;
4501+
};
4502+
4503+
// Dig out the declarations for the arguments.
4504+
auto *firstVal = extractArgumentDecl(firstArg);
4505+
auto *secondVal = extractArgumentDecl(secondArg);
4506+
4507+
// If we can't find declarations for both arguments, bail out,
4508+
// because one of them has to be '.nan'.
4509+
if (!firstArg && !secondArg) {
4510+
return;
4511+
}
4512+
4513+
// Convenience utility to check if this is a 'nan' variable.
4514+
auto isNanDecl = [&](ValueDecl *VD) {
4515+
return VD && isa<VarDecl>(VD) && VD->getBaseIdentifier().is("nan");
4516+
};
4517+
4518+
// Diagnose comparison with '.nan'.
4519+
//
4520+
// If the comparison is done using '<=', '<', '==', '>', '>=', then
4521+
// the result is always false. If the comparison is done using '!=',
4522+
// then the result is always true.
4523+
//
4524+
// Emit a different diagnostic which doesn't mention using '.isNaN' if
4525+
// the comparison isn't done using '==' or '!=' or if both sides are
4526+
// '.nan'.
4527+
if (isNanDecl(firstVal) && isNanDecl(secondVal)) {
4528+
C.Diags.diagnose(BE->getLoc(), diag::nan_comparison_both_nan,
4529+
comparisonDeclName.str(), comparisonDeclName.is("!="));
4530+
} else if (isNanDecl(firstVal) || isNanDecl(secondVal)) {
4531+
if (comparisonDeclName.is("==") || comparisonDeclName.is("!=")) {
4532+
auto exprStr =
4533+
C.SourceMgr
4534+
.extractText(Lexer::getCharSourceRangeFromSourceRange(
4535+
C.SourceMgr, firstArg->getSourceRange()))
4536+
.str();
4537+
auto prefix = exprStr;
4538+
if (comparisonDeclName.is("!=")) {
4539+
prefix = "!" + prefix;
4540+
}
4541+
C.Diags.diagnose(BE->getLoc(), diag::nan_comparison,
4542+
comparisonDeclName, comparisonDeclName.is("!="),
4543+
prefix, exprStr);
4544+
} else {
4545+
C.Diags.diagnose(BE->getLoc(), diag::nan_comparison_without_isnan,
4546+
comparisonDeclName, comparisonDeclName.is("!="));
4547+
}
4548+
}
4549+
}
4550+
4551+
std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
4552+
if (!E || isa<ErrorExpr>(E) || !E->getType())
4553+
return {false, E};
4554+
4555+
if (auto *BE = dyn_cast<BinaryExpr>(E)) {
4556+
tryDiagnoseComparisonWithNaN(BE);
4557+
return {false, E};
4558+
}
4559+
4560+
return {true, E};
4561+
}
4562+
};
4563+
4564+
ComparisonWithNaNFinder Walker(DC);
4565+
const_cast<Expr *>(E)->walk(Walker);
4566+
}
4567+
44414568
//===----------------------------------------------------------------------===//
44424569
// High-level entry points.
44434570
//===----------------------------------------------------------------------===//
@@ -4454,6 +4581,7 @@ void swift::performSyntacticExprDiagnostics(const Expr *E,
44544581
diagnoseUnintendedOptionalBehavior(E, DC);
44554582
maybeDiagnoseCallToKeyValueObserveMethod(E, DC);
44564583
diagnoseExplicitUseOfLazyVariableStorage(E, DC);
4584+
diagnoseComparisonWithNaN(E, DC);
44574585
if (!ctx.isSwiftVersionAtLeast(5))
44584586
diagnoseDeprecatedWritableKeyPath(E, DC);
44594587
if (!ctx.LangOpts.DisableAvailabilityChecking)

test/decl/var/nan_comparisons.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
//////////////////////////////////////////////////////////////////////////////////////////////////
4+
/////// Comparison with '.nan' static property instead of using '.isNaN' instance property ///////
5+
//////////////////////////////////////////////////////////////////////////////////////////////////
6+
7+
// One side is '.nan' and the other isn't.
8+
// Using '==' or '!=' for comparison should suggest using '.isNaN'.
9+
10+
let double: Double = 0.0
11+
_ = double == .nan // expected-warning {{comparison with '.nan' using '==' is always false, use 'double.isNaN' to check if 'double' is not a number}}
12+
_ = double != .nan // expected-warning {{comparison with '.nan' using '!=' is always true, use '!double.isNaN' to check if 'double' is a number}}
13+
_ = 0.0 == .nan // // expected-warning {{comparison with '.nan' using '==' is always false, use '0.0.isNaN' to check if '0.0' is not a number}}
14+
15+
// One side is '.nan' and the other isn't. Using '>=', '>', '<', '<=' for comparison:
16+
// We can't suggest using '.isNaN' here.
17+
18+
_ = 0.0 >= .nan // expected-warning {{comparison with '.nan' using '>=' is always false}}
19+
_ = .nan > 1.1 // expected-warning {{comparison with '.nan' using '>' is always false}}
20+
_ = .nan < 2.2 // expected-warning {{comparison with '.nan' using '<' is always false}}
21+
_ = 3.3 <= .nan // expected-warning {{comparison with '.nan' using '<=' is always false}}
22+
23+
// Both sides are '.nan':
24+
// We can't suggest using '.isNaN' here.
25+
26+
_ = Double.nan == Double.nan // expected-warning {{'.nan' == '.nan' is always false}}
27+
_ = Double.nan != Double.nan // expected-warning {{'.nan' != '.nan' is always true}}
28+
_ = Double.nan < Double.nan // expected-warning {{'.nan' < '.nan' is always false}}
29+
_ = Double.nan <= Double.nan // expected-warning {{'.nan' <= '.nan' is always false}}
30+
_ = Double.nan > Double.nan // expected-warning {{'.nan' > '.nan' is always false}}
31+
_ = Double.nan >= Double.nan // expected-warning {{'.nan' >= '.nan' is always false}}

0 commit comments

Comments
 (0)