Skip to content

Commit aa583ea

Browse files
committed
[Refactoring] Rework how conditions are classified
- Expand ConditionType to include `.success` and `.failure` patterns. - Introduce ConditionPath to represent whether a classified condition represents a success or failure path. - Add methods to CallbackClassifier that deal with deciding whether a given condition is a success or failure path. - Use these methods to simplify `classifyConditional`.
1 parent a9084d4 commit aa583ea

File tree

1 file changed

+194
-113
lines changed

1 file changed

+194
-113
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 194 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -4427,20 +4427,37 @@ struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
44274427
}
44284428
};
44294429

4430-
enum class ConditionType { NIL, NOT_NIL };
4430+
/// The type of a condition in a conditional statement.
4431+
enum class ConditionType {
4432+
NIL, // == nil
4433+
NOT_NIL, // != nil
4434+
SUCCESS_PATTERN, // case .success
4435+
FAILURE_PATTEN // case .failure
4436+
};
4437+
4438+
/// Indicates whether a condition describes a success or failure path. For
4439+
/// example, a check for whether an error parameter is present is a failure
4440+
/// path. A check for a nil error parameter is a success path. This is distinct
4441+
/// from ConditionType, as it relies on contextual information about what values
4442+
/// need to be checked for success or failure.
4443+
enum class ConditionPath { SUCCESS, FAILURE };
4444+
4445+
static ConditionPath flippedConditionPath(ConditionPath Path) {
4446+
switch (Path) {
4447+
case ConditionPath::SUCCESS:
4448+
return ConditionPath::FAILURE;
4449+
case ConditionPath::FAILURE:
4450+
return ConditionPath::SUCCESS;
4451+
}
4452+
llvm_unreachable("Unhandled case in switch!");
4453+
}
44314454

44324455
/// Finds the `Subject` being compared to in various conditions. Also finds any
44334456
/// pattern that may have a bound name.
44344457
struct CallbackCondition {
44354458
Optional<ConditionType> Type;
44364459
const Decl *Subject = nullptr;
44374460
const Pattern *BindPattern = nullptr;
4438-
// Bit of a hack. When the `Subject` is a `Result` type we use this to
4439-
// distinguish between the `.success` and `.failure` case (as opposed to just
4440-
// checking whether `Subject` == `TheErrDecl`)
4441-
bool ErrorCase = false;
4442-
4443-
CallbackCondition() = default;
44444461

44454462
/// Initializes a `CallbackCondition` with a `!=` or `==` comparison of
44464463
/// an `Optional` typed `Subject` to `nil`, ie.
@@ -4506,64 +4523,16 @@ struct CallbackCondition {
45064523

45074524
bool isValid() const { return Type.hasValue(); }
45084525

4509-
/// Given an `if` condition `Cond` and a set of `Decls`, find any
4510-
/// `CallbackCondition`s in `Cond` that use one of those `Decls` and add them
4511-
/// to the map `AddTo`. Return `true` if all elements in the condition are
4512-
/// "handled", ie. every condition can be mapped to a single `Decl` in
4513-
/// `Decls`.
4514-
static bool all(StmtCondition Cond, llvm::DenseSet<const Decl *> Decls,
4515-
llvm::DenseMap<const Decl *, CallbackCondition> &AddTo) {
4516-
bool Handled = true;
4517-
for (auto &CondElement : Cond) {
4518-
if (auto *BoolExpr = CondElement.getBooleanOrNull()) {
4519-
SmallVector<Expr *, 1> Exprs;
4520-
Exprs.push_back(BoolExpr);
4521-
4522-
while (!Exprs.empty()) {
4523-
auto *Next = Exprs.pop_back_val();
4524-
if (auto *ACE = dyn_cast<AutoClosureExpr>(Next))
4525-
Next = ACE->getSingleExpressionBody();
4526-
4527-
if (auto *BE = dyn_cast_or_null<BinaryExpr>(Next)) {
4528-
auto *Operator = isOperator(BE);
4529-
if (Operator) {
4530-
if (Operator->getBaseName() == "&&") {
4531-
auto Args = BE->getArg()->getElements();
4532-
Exprs.insert(Exprs.end(), Args.begin(), Args.end());
4533-
} else {
4534-
addCond(CallbackCondition(BE, Operator), Decls, AddTo, Handled);
4535-
}
4536-
continue;
4537-
}
4538-
}
4539-
4540-
Handled = false;
4541-
}
4542-
} else if (auto *P = CondElement.getPatternOrNull()) {
4543-
addCond(CallbackCondition(P, CondElement.getInitializer()), Decls,
4544-
AddTo, Handled);
4545-
}
4546-
}
4547-
return Handled && !AddTo.empty();
4548-
}
4549-
45504526
private:
4551-
static void addCond(const CallbackCondition &CC,
4552-
llvm::DenseSet<const Decl *> Decls,
4553-
llvm::DenseMap<const Decl *, CallbackCondition> &AddTo,
4554-
bool &Handled) {
4555-
if (!CC.isValid() || !Decls.count(CC.Subject) ||
4556-
!AddTo.try_emplace(CC.Subject, CC).second)
4557-
Handled = false;
4558-
}
4559-
45604527
void initFromEnumPattern(const Decl *D, const EnumElementPattern *EEP) {
45614528
if (auto *EED = EEP->getElementDecl()) {
45624529
if (!isResultType(EED->getParentEnum()->getDeclaredType()))
45634530
return;
4564-
if (EED->getNameStr() == StringRef("failure"))
4565-
ErrorCase = true;
4566-
Type = ConditionType::NOT_NIL;
4531+
if (EED->getNameStr() == StringRef("failure")) {
4532+
Type = ConditionType::FAILURE_PATTEN;
4533+
} else {
4534+
Type = ConditionType::SUCCESS_PATTERN;
4535+
}
45674536
Subject = D;
45684537
BindPattern = EEP->getSubPattern();
45694538
}
@@ -4597,6 +4566,27 @@ struct CallbackCondition {
45974566
}
45984567
};
45994568

4569+
/// A CallbackCondition with additional semantic information about whether it
4570+
/// is for a success path or failure path.
4571+
struct ClassifiedCondition : public CallbackCondition {
4572+
ConditionPath Path;
4573+
4574+
explicit ClassifiedCondition(CallbackCondition Cond, ConditionPath Path)
4575+
: CallbackCondition(Cond), Path(Path) {}
4576+
};
4577+
4578+
/// A wrapper for a map of parameter decls to their classified conditions, or
4579+
/// \c None if they are not present in any conditions.
4580+
struct ClassifiedCallbackConditions final
4581+
: llvm::MapVector<const Decl *, ClassifiedCondition> {
4582+
Optional<ClassifiedCondition> lookup(const Decl *D) const {
4583+
auto Res = find(D);
4584+
if (Res == end())
4585+
return None;
4586+
return Res->second;
4587+
}
4588+
};
4589+
46004590
/// A list of nodes to print, along with a list of locations that may have
46014591
/// preceding comments attached, which also need printing. For example:
46024592
///
@@ -4742,7 +4732,7 @@ class ClassifiedBlock {
47424732
Nodes.addNode(Node);
47434733
}
47444734

4745-
void addBinding(const CallbackCondition &FromCondition,
4735+
void addBinding(const ClassifiedCondition &FromCondition,
47464736
DiagnosticEngine &DiagEngine) {
47474737
if (!FromCondition.BindPattern)
47484738
return;
@@ -4766,9 +4756,8 @@ class ClassifiedBlock {
47664756
BoundNames.try_emplace(FromCondition.Subject, Name);
47674757
}
47684758

4769-
void addAllBindings(
4770-
const llvm::DenseMap<const Decl *, CallbackCondition> &FromConditions,
4771-
DiagnosticEngine &DiagEngine) {
4759+
void addAllBindings(const ClassifiedCallbackConditions &FromConditions,
4760+
DiagnosticEngine &DiagEngine) {
47724761
for (auto &Entry : FromConditions) {
47734762
addBinding(Entry.second, DiagEngine);
47744763
if (DiagEngine.hadAnyError())
@@ -4857,12 +4846,108 @@ struct CallbackClassifier {
48574846
CurrentBlock->addPossibleCommentLoc(endCommentLoc);
48584847
}
48594848

4849+
/// Given a callback condition, classify it as a success or failure path, or
4850+
/// \c None if it cannot be classified.
4851+
Optional<ClassifiedCondition>
4852+
classifyCallbackCondition(const CallbackCondition &Cond) {
4853+
if (!Cond.isValid())
4854+
return None;
4855+
4856+
// For certain types of condition, they need to appear in certain lists.
4857+
auto CondType = *Cond.Type;
4858+
switch (CondType) {
4859+
case ConditionType::NOT_NIL:
4860+
case ConditionType::NIL:
4861+
if (!UnwrapParams.count(Cond.Subject))
4862+
return None;
4863+
break;
4864+
case ConditionType::SUCCESS_PATTERN:
4865+
case ConditionType::FAILURE_PATTEN:
4866+
if (!IsResultParam || Cond.Subject != ErrParam)
4867+
return None;
4868+
break;
4869+
}
4870+
4871+
// Let's start with a success path, and flip any negative conditions.
4872+
auto Path = ConditionPath::SUCCESS;
4873+
4874+
// If it's an error param, that's a flip.
4875+
if (Cond.Subject == ErrParam && !IsResultParam)
4876+
Path = flippedConditionPath(Path);
4877+
4878+
// If we have a nil or failure condition, that's a flip.
4879+
switch (CondType) {
4880+
case ConditionType::NIL:
4881+
case ConditionType::FAILURE_PATTEN:
4882+
Path = flippedConditionPath(Path);
4883+
break;
4884+
case ConditionType::NOT_NIL:
4885+
case ConditionType::SUCCESS_PATTERN:
4886+
break;
4887+
}
4888+
return ClassifiedCondition(Cond, Path);
4889+
}
4890+
4891+
/// Classifies all the conditions present in a given StmtCondition. Returns
4892+
/// \c true if there were any conditions that couldn't be classified,
4893+
/// \c false otherwise.
4894+
bool classifyConditionsOf(StmtCondition Cond,
4895+
ClassifiedCallbackConditions &Conditions) {
4896+
bool UnhandledConditions = false;
4897+
auto TryAddCond = [&](CallbackCondition CC) {
4898+
auto Classified = classifyCallbackCondition(CC);
4899+
if (!Classified) {
4900+
UnhandledConditions = true;
4901+
return;
4902+
}
4903+
// If we've seen multiple conditions for the same subject, don't handle
4904+
// this.
4905+
if (!Conditions.insert({CC.Subject, *Classified}).second)
4906+
UnhandledConditions = true;
4907+
};
4908+
4909+
for (auto &CondElement : Cond) {
4910+
if (auto *BoolExpr = CondElement.getBooleanOrNull()) {
4911+
SmallVector<Expr *, 1> Exprs;
4912+
Exprs.push_back(BoolExpr);
4913+
4914+
while (!Exprs.empty()) {
4915+
auto *Next = Exprs.pop_back_val();
4916+
if (auto *ACE = dyn_cast<AutoClosureExpr>(Next))
4917+
Next = ACE->getSingleExpressionBody();
4918+
4919+
if (auto *BE = dyn_cast_or_null<BinaryExpr>(Next)) {
4920+
auto *Operator = isOperator(BE);
4921+
if (Operator) {
4922+
// If we have an && operator, decompose its arguments.
4923+
if (Operator->getBaseName() == "&&") {
4924+
auto Args = BE->getArg()->getElements();
4925+
Exprs.insert(Exprs.end(), Args.begin(), Args.end());
4926+
} else {
4927+
// Otherwise check to see if we have an == nil or != nil
4928+
// condition.
4929+
TryAddCond(CallbackCondition(BE, Operator));
4930+
}
4931+
continue;
4932+
}
4933+
}
4934+
UnhandledConditions = true;
4935+
}
4936+
} else if (auto *P = CondElement.getPatternOrNull()) {
4937+
TryAddCond(CallbackCondition(P, CondElement.getInitializer()));
4938+
}
4939+
}
4940+
return UnhandledConditions || Conditions.empty();
4941+
}
4942+
4943+
/// Classifies the conditions of a conditional statement, and adds the
4944+
/// necessary nodes to either the success or failure block.
48604945
void classifyConditional(Stmt *Statement, StmtCondition Condition,
48614946
NodesToPrint ThenNodesToPrint, Stmt *ElseStmt) {
4862-
llvm::DenseMap<const Decl *, CallbackCondition> CallbackConditions;
4863-
bool UnhandledConditions =
4864-
!CallbackCondition::all(Condition, UnwrapParams, CallbackConditions);
4865-
CallbackCondition ErrCondition = CallbackConditions.lookup(ErrParam);
4947+
ClassifiedCallbackConditions CallbackConditions;
4948+
bool UnhandledConditions = classifyConditionsOf(
4949+
Condition, CallbackConditions);
4950+
auto ErrCondition = CallbackConditions.lookup(ErrParam);
48664951

48674952
if (UnhandledConditions) {
48684953
// Some unknown conditions. If there's an else, assume we can't handle
@@ -4878,12 +4963,11 @@ struct CallbackClassifier {
48784963
} else if (ElseStmt) {
48794964
DiagEngine.diagnose(Statement->getStartLoc(),
48804965
diag::unknown_callback_conditions);
4881-
} else if (ErrCondition.isValid() &&
4882-
ErrCondition.Type == ConditionType::NOT_NIL) {
4966+
} else if (ErrCondition && ErrCondition->Path == ConditionPath::FAILURE) {
48834967
Blocks.ErrorBlock.addNode(Statement);
48844968
} else {
48854969
for (auto &Entry : CallbackConditions) {
4886-
if (Entry.second.Type == ConditionType::NIL) {
4970+
if (Entry.second.Path == ConditionPath::FAILURE) {
48874971
Blocks.ErrorBlock.addNode(Statement);
48884972
return;
48894973
}
@@ -4893,42 +4977,36 @@ struct CallbackClassifier {
48934977
return;
48944978
}
48954979

4896-
ClassifiedBlock *ThenBlock = &Blocks.SuccessBlock;
4897-
ClassifiedBlock *ElseBlock = &Blocks.ErrorBlock;
4898-
4899-
if (ErrCondition.isValid() && (!IsResultParam || ErrCondition.ErrorCase) &&
4900-
ErrCondition.Type == ConditionType::NOT_NIL) {
4901-
ClassifiedBlock *TempBlock = ThenBlock;
4902-
ThenBlock = ElseBlock;
4903-
ElseBlock = TempBlock;
4904-
} else {
4905-
Optional<ConditionType> CondType;
4906-
for (auto &Entry : CallbackConditions) {
4907-
if (IsResultParam || Entry.second.Subject != ErrParam) {
4908-
if (!CondType) {
4909-
CondType = Entry.second.Type;
4910-
} else if (CondType != Entry.second.Type) {
4911-
// Similar to the unknown conditions case. Add the whole if unless
4912-
// there's an else, in which case use the fallback instead.
4913-
// TODO: Split the `if` statement
4914-
4915-
if (ElseStmt) {
4916-
DiagEngine.diagnose(Statement->getStartLoc(),
4917-
diag::mixed_callback_conditions);
4918-
} else {
4919-
CurrentBlock->addNode(Statement);
4920-
}
4921-
return;
4922-
}
4980+
// If all the conditions were classified, make sure they're all consistently
4981+
// on the success or failure path.
4982+
Optional<ConditionPath> Path;
4983+
for (auto &Entry : CallbackConditions) {
4984+
auto &Cond = Entry.second;
4985+
if (!Path) {
4986+
Path = Cond.Path;
4987+
} else if (*Path != Cond.Path) {
4988+
// Similar to the unknown conditions case. Add the whole if unless
4989+
// there's an else, in which case use the fallback instead.
4990+
// TODO: Split the `if` statement
4991+
4992+
if (ElseStmt) {
4993+
DiagEngine.diagnose(Statement->getStartLoc(),
4994+
diag::mixed_callback_conditions);
4995+
} else {
4996+
CurrentBlock->addNode(Statement);
49234997
}
4924-
}
4925-
4926-
if (CondType == ConditionType::NIL) {
4927-
ClassifiedBlock *TempBlock = ThenBlock;
4928-
ThenBlock = ElseBlock;
4929-
ElseBlock = TempBlock;
4998+
return;
49304999
}
49315000
}
5001+
assert(Path && "Didn't classify a path?");
5002+
5003+
auto *ThenBlock = &Blocks.SuccessBlock;
5004+
auto *ElseBlock = &Blocks.ErrorBlock;
5005+
5006+
// If the condition is for a failure path, the error block is ThenBlock, and
5007+
// the success block is ElseBlock.
5008+
if (*Path == ConditionPath::FAILURE)
5009+
std::swap(ThenBlock, ElseBlock);
49325010

49335011
// We'll be dropping the statement, but make sure to keep any attached
49345012
// comments.
@@ -4999,13 +5077,15 @@ struct CallbackClassifier {
49995077
return;
50005078
}
50015079

5002-
CallbackCondition CC(ErrParam, &Items[0]);
5003-
ClassifiedBlock *Block = &Blocks.SuccessBlock;
5004-
ClassifiedBlock *OtherBlock = &Blocks.ErrorBlock;
5005-
if (CC.ErrorCase) {
5006-
Block = &Blocks.ErrorBlock;
5007-
OtherBlock = &Blocks.SuccessBlock;
5008-
}
5080+
auto *Block = &Blocks.SuccessBlock;
5081+
auto *OtherBlock = &Blocks.ErrorBlock;
5082+
auto SuccessNodes = NodesToPrint::inBraceStmt(CS->getBody());
5083+
5084+
// Classify the case pattern.
5085+
auto CC = classifyCallbackCondition(
5086+
CallbackCondition(ErrParam, &Items[0]));
5087+
if (CC && CC->Path == ConditionPath::FAILURE)
5088+
std::swap(Block, OtherBlock);
50095089

50105090
// We'll be dropping the case, but make sure to keep any attached
50115091
// comments. Because these comments will effectively be part of the
@@ -5016,8 +5096,9 @@ struct CallbackClassifier {
50165096
if (CS == Cases.back())
50175097
Block->addPossibleCommentLoc(SS->getRBraceLoc());
50185098

5019-
setNodes(Block, OtherBlock, NodesToPrint::inBraceStmt(CS->getBody()));
5020-
Block->addBinding(CC, DiagEngine);
5099+
setNodes(Block, OtherBlock, std::move(SuccessNodes));
5100+
if (CC)
5101+
Block->addBinding(*CC, DiagEngine);
50215102
if (DiagEngine.hadAnyError())
50225103
return;
50235104
}

0 commit comments

Comments
 (0)