Skip to content

Commit 5ed7cae

Browse files
committed
[Refactoring] Rework how conditions are classified
- Expand ConditionType to include `.success` and `.failure` patterns. - Introduce ConditionPath to represent whether a classified condition represents a success or failure path. - Add methods to CallbackClassifier that deal with deciding whether a given condition is a success or failure path. - Use these methods to simplify `classifyConditional`.
1 parent f2d5f38 commit 5ed7cae

File tree

1 file changed

+194
-113
lines changed

1 file changed

+194
-113
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 194 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -4410,20 +4410,37 @@ struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
44104410
}
44114411
};
44124412

4413-
enum class ConditionType { NIL, NOT_NIL };
4413+
/// The type of a condition in a conditional statement.
4414+
enum class ConditionType {
4415+
NIL, // == nil
4416+
NOT_NIL, // != nil
4417+
SUCCESS_PATTERN, // case .success
4418+
FAILURE_PATTEN // case .failure
4419+
};
4420+
4421+
/// Indicates whether a condition describes a success or failure path. For
4422+
/// example, a check for whether an error parameter is present is a failure
4423+
/// path. A check for a nil error parameter is a success path. This is distinct
4424+
/// from ConditionType, as it relies on contextual information about what values
4425+
/// need to be checked for success or failure.
4426+
enum class ConditionPath { SUCCESS, FAILURE };
4427+
4428+
static ConditionPath flippedConditionPath(ConditionPath Path) {
4429+
switch (Path) {
4430+
case ConditionPath::SUCCESS:
4431+
return ConditionPath::FAILURE;
4432+
case ConditionPath::FAILURE:
4433+
return ConditionPath::SUCCESS;
4434+
}
4435+
llvm_unreachable("Unhandled case in switch!");
4436+
}
44144437

44154438
/// Finds the `Subject` being compared to in various conditions. Also finds any
44164439
/// pattern that may have a bound name.
44174440
struct CallbackCondition {
44184441
Optional<ConditionType> Type;
44194442
const Decl *Subject = nullptr;
44204443
const Pattern *BindPattern = nullptr;
4421-
// Bit of a hack. When the `Subject` is a `Result` type we use this to
4422-
// distinguish between the `.success` and `.failure` case (as opposed to just
4423-
// checking whether `Subject` == `TheErrDecl`)
4424-
bool ErrorCase = false;
4425-
4426-
CallbackCondition() = default;
44274444

44284445
/// Initializes a `CallbackCondition` with a `!=` or `==` comparison of
44294446
/// an `Optional` typed `Subject` to `nil`, ie.
@@ -4489,65 +4506,17 @@ struct CallbackCondition {
44894506

44904507
bool isValid() const { return Type.hasValue(); }
44914508

4492-
/// Given an `if` condition `Cond` and a set of `Decls`, find any
4493-
/// `CallbackCondition`s in `Cond` that use one of those `Decls` and add them
4494-
/// to the map `AddTo`. Return `true` if all elements in the condition are
4495-
/// "handled", ie. every condition can be mapped to a single `Decl` in
4496-
/// `Decls`.
4497-
static bool all(StmtCondition Cond, llvm::DenseSet<const Decl *> Decls,
4498-
llvm::DenseMap<const Decl *, CallbackCondition> &AddTo) {
4499-
bool Handled = true;
4500-
for (auto &CondElement : Cond) {
4501-
if (auto *BoolExpr = CondElement.getBooleanOrNull()) {
4502-
SmallVector<Expr *, 1> Exprs;
4503-
Exprs.push_back(BoolExpr);
4504-
4505-
while (!Exprs.empty()) {
4506-
auto *Next = Exprs.pop_back_val();
4507-
if (auto *ACE = dyn_cast<AutoClosureExpr>(Next))
4508-
Next = ACE->getSingleExpressionBody();
4509-
4510-
if (auto *BE = dyn_cast_or_null<BinaryExpr>(Next)) {
4511-
auto *Operator = isOperator(BE);
4512-
if (Operator) {
4513-
if (Operator->getBaseName() == "&&") {
4514-
Exprs.push_back(BE->getLHS());
4515-
Exprs.push_back(BE->getRHS());
4516-
} else {
4517-
addCond(CallbackCondition(BE, Operator), Decls, AddTo, Handled);
4518-
}
4519-
continue;
4520-
}
4521-
}
4522-
4523-
Handled = false;
4524-
}
4525-
} else if (auto *P = CondElement.getPatternOrNull()) {
4526-
addCond(CallbackCondition(P, CondElement.getInitializer()), Decls,
4527-
AddTo, Handled);
4528-
}
4529-
}
4530-
return Handled && !AddTo.empty();
4531-
}
4532-
45334509
private:
4534-
static void addCond(const CallbackCondition &CC,
4535-
llvm::DenseSet<const Decl *> Decls,
4536-
llvm::DenseMap<const Decl *, CallbackCondition> &AddTo,
4537-
bool &Handled) {
4538-
if (!CC.isValid() || !Decls.count(CC.Subject) ||
4539-
!AddTo.try_emplace(CC.Subject, CC).second)
4540-
Handled = false;
4541-
}
4542-
45434510
void initFromEnumPattern(const Decl *D, const EnumElementPattern *EEP) {
45444511
if (auto *EED = EEP->getElementDecl()) {
45454512
auto eedTy = EED->getParentEnum()->getDeclaredType();
45464513
if (!eedTy || !eedTy->isResult())
45474514
return;
4548-
if (EED->getNameStr() == StringRef("failure"))
4549-
ErrorCase = true;
4550-
Type = ConditionType::NOT_NIL;
4515+
if (EED->getNameStr() == StringRef("failure")) {
4516+
Type = ConditionType::FAILURE_PATTEN;
4517+
} else {
4518+
Type = ConditionType::SUCCESS_PATTERN;
4519+
}
45514520
Subject = D;
45524521
BindPattern = EEP->getSubPattern();
45534522
}
@@ -4581,6 +4550,27 @@ struct CallbackCondition {
45814550
}
45824551
};
45834552

4553+
/// A CallbackCondition with additional semantic information about whether it
4554+
/// is for a success path or failure path.
4555+
struct ClassifiedCondition : public CallbackCondition {
4556+
ConditionPath Path;
4557+
4558+
explicit ClassifiedCondition(CallbackCondition Cond, ConditionPath Path)
4559+
: CallbackCondition(Cond), Path(Path) {}
4560+
};
4561+
4562+
/// A wrapper for a map of parameter decls to their classified conditions, or
4563+
/// \c None if they are not present in any conditions.
4564+
struct ClassifiedCallbackConditions final
4565+
: llvm::MapVector<const Decl *, ClassifiedCondition> {
4566+
Optional<ClassifiedCondition> lookup(const Decl *D) const {
4567+
auto Res = find(D);
4568+
if (Res == end())
4569+
return None;
4570+
return Res->second;
4571+
}
4572+
};
4573+
45844574
/// A list of nodes to print, along with a list of locations that may have
45854575
/// preceding comments attached, which also need printing. For example:
45864576
///
@@ -4726,7 +4716,7 @@ class ClassifiedBlock {
47264716
Nodes.addNode(Node);
47274717
}
47284718

4729-
void addBinding(const CallbackCondition &FromCondition,
4719+
void addBinding(const ClassifiedCondition &FromCondition,
47304720
DiagnosticEngine &DiagEngine) {
47314721
if (!FromCondition.BindPattern)
47324722
return;
@@ -4750,9 +4740,8 @@ class ClassifiedBlock {
47504740
BoundNames.try_emplace(FromCondition.Subject, Name);
47514741
}
47524742

4753-
void addAllBindings(
4754-
const llvm::DenseMap<const Decl *, CallbackCondition> &FromConditions,
4755-
DiagnosticEngine &DiagEngine) {
4743+
void addAllBindings(const ClassifiedCallbackConditions &FromConditions,
4744+
DiagnosticEngine &DiagEngine) {
47564745
for (auto &Entry : FromConditions) {
47574746
addBinding(Entry.second, DiagEngine);
47584747
if (DiagEngine.hadAnyError())
@@ -4841,12 +4830,108 @@ struct CallbackClassifier {
48414830
CurrentBlock->addPossibleCommentLoc(endCommentLoc);
48424831
}
48434832

4833+
/// Given a callback condition, classify it as a success or failure path, or
4834+
/// \c None if it cannot be classified.
4835+
Optional<ClassifiedCondition>
4836+
classifyCallbackCondition(const CallbackCondition &Cond) {
4837+
if (!Cond.isValid())
4838+
return None;
4839+
4840+
// For certain types of condition, they need to appear in certain lists.
4841+
auto CondType = *Cond.Type;
4842+
switch (CondType) {
4843+
case ConditionType::NOT_NIL:
4844+
case ConditionType::NIL:
4845+
if (!UnwrapParams.count(Cond.Subject))
4846+
return None;
4847+
break;
4848+
case ConditionType::SUCCESS_PATTERN:
4849+
case ConditionType::FAILURE_PATTEN:
4850+
if (!IsResultParam || Cond.Subject != ErrParam)
4851+
return None;
4852+
break;
4853+
}
4854+
4855+
// Let's start with a success path, and flip any negative conditions.
4856+
auto Path = ConditionPath::SUCCESS;
4857+
4858+
// If it's an error param, that's a flip.
4859+
if (Cond.Subject == ErrParam && !IsResultParam)
4860+
Path = flippedConditionPath(Path);
4861+
4862+
// If we have a nil or failure condition, that's a flip.
4863+
switch (CondType) {
4864+
case ConditionType::NIL:
4865+
case ConditionType::FAILURE_PATTEN:
4866+
Path = flippedConditionPath(Path);
4867+
break;
4868+
case ConditionType::NOT_NIL:
4869+
case ConditionType::SUCCESS_PATTERN:
4870+
break;
4871+
}
4872+
return ClassifiedCondition(Cond, Path);
4873+
}
4874+
4875+
/// Classifies all the conditions present in a given StmtCondition. Returns
4876+
/// \c true if there were any conditions that couldn't be classified,
4877+
/// \c false otherwise.
4878+
bool classifyConditionsOf(StmtCondition Cond,
4879+
ClassifiedCallbackConditions &Conditions) {
4880+
bool UnhandledConditions = false;
4881+
auto TryAddCond = [&](CallbackCondition CC) {
4882+
auto Classified = classifyCallbackCondition(CC);
4883+
if (!Classified) {
4884+
UnhandledConditions = true;
4885+
return;
4886+
}
4887+
// If we've seen multiple conditions for the same subject, don't handle
4888+
// this.
4889+
if (!Conditions.insert({CC.Subject, *Classified}).second)
4890+
UnhandledConditions = true;
4891+
};
4892+
4893+
for (auto &CondElement : Cond) {
4894+
if (auto *BoolExpr = CondElement.getBooleanOrNull()) {
4895+
SmallVector<Expr *, 1> Exprs;
4896+
Exprs.push_back(BoolExpr);
4897+
4898+
while (!Exprs.empty()) {
4899+
auto *Next = Exprs.pop_back_val();
4900+
if (auto *ACE = dyn_cast<AutoClosureExpr>(Next))
4901+
Next = ACE->getSingleExpressionBody();
4902+
4903+
if (auto *BE = dyn_cast_or_null<BinaryExpr>(Next)) {
4904+
auto *Operator = isOperator(BE);
4905+
if (Operator) {
4906+
// If we have an && operator, decompose its arguments.
4907+
if (Operator->getBaseName() == "&&") {
4908+
Exprs.push_back(BE->getLHS());
4909+
Exprs.push_back(BE->getRHS());
4910+
} else {
4911+
// Otherwise check to see if we have an == nil or != nil
4912+
// condition.
4913+
TryAddCond(CallbackCondition(BE, Operator));
4914+
}
4915+
continue;
4916+
}
4917+
}
4918+
UnhandledConditions = true;
4919+
}
4920+
} else if (auto *P = CondElement.getPatternOrNull()) {
4921+
TryAddCond(CallbackCondition(P, CondElement.getInitializer()));
4922+
}
4923+
}
4924+
return UnhandledConditions || Conditions.empty();
4925+
}
4926+
4927+
/// Classifies the conditions of a conditional statement, and adds the
4928+
/// necessary nodes to either the success or failure block.
48444929
void classifyConditional(Stmt *Statement, StmtCondition Condition,
48454930
NodesToPrint ThenNodesToPrint, Stmt *ElseStmt) {
4846-
llvm::DenseMap<const Decl *, CallbackCondition> CallbackConditions;
4847-
bool UnhandledConditions =
4848-
!CallbackCondition::all(Condition, UnwrapParams, CallbackConditions);
4849-
CallbackCondition ErrCondition = CallbackConditions.lookup(ErrParam);
4931+
ClassifiedCallbackConditions CallbackConditions;
4932+
bool UnhandledConditions = classifyConditionsOf(
4933+
Condition, CallbackConditions);
4934+
auto ErrCondition = CallbackConditions.lookup(ErrParam);
48504935

48514936
if (UnhandledConditions) {
48524937
// Some unknown conditions. If there's an else, assume we can't handle
@@ -4862,12 +4947,11 @@ struct CallbackClassifier {
48624947
} else if (ElseStmt) {
48634948
DiagEngine.diagnose(Statement->getStartLoc(),
48644949
diag::unknown_callback_conditions);
4865-
} else if (ErrCondition.isValid() &&
4866-
ErrCondition.Type == ConditionType::NOT_NIL) {
4950+
} else if (ErrCondition && ErrCondition->Path == ConditionPath::FAILURE) {
48674951
Blocks.ErrorBlock.addNode(Statement);
48684952
} else {
48694953
for (auto &Entry : CallbackConditions) {
4870-
if (Entry.second.Type == ConditionType::NIL) {
4954+
if (Entry.second.Path == ConditionPath::FAILURE) {
48714955
Blocks.ErrorBlock.addNode(Statement);
48724956
return;
48734957
}
@@ -4877,42 +4961,36 @@ struct CallbackClassifier {
48774961
return;
48784962
}
48794963

4880-
ClassifiedBlock *ThenBlock = &Blocks.SuccessBlock;
4881-
ClassifiedBlock *ElseBlock = &Blocks.ErrorBlock;
4882-
4883-
if (ErrCondition.isValid() && (!IsResultParam || ErrCondition.ErrorCase) &&
4884-
ErrCondition.Type == ConditionType::NOT_NIL) {
4885-
ClassifiedBlock *TempBlock = ThenBlock;
4886-
ThenBlock = ElseBlock;
4887-
ElseBlock = TempBlock;
4888-
} else {
4889-
Optional<ConditionType> CondType;
4890-
for (auto &Entry : CallbackConditions) {
4891-
if (IsResultParam || Entry.second.Subject != ErrParam) {
4892-
if (!CondType) {
4893-
CondType = Entry.second.Type;
4894-
} else if (CondType != Entry.second.Type) {
4895-
// Similar to the unknown conditions case. Add the whole if unless
4896-
// there's an else, in which case use the fallback instead.
4897-
// TODO: Split the `if` statement
4898-
4899-
if (ElseStmt) {
4900-
DiagEngine.diagnose(Statement->getStartLoc(),
4901-
diag::mixed_callback_conditions);
4902-
} else {
4903-
CurrentBlock->addNode(Statement);
4904-
}
4905-
return;
4906-
}
4964+
// If all the conditions were classified, make sure they're all consistently
4965+
// on the success or failure path.
4966+
Optional<ConditionPath> Path;
4967+
for (auto &Entry : CallbackConditions) {
4968+
auto &Cond = Entry.second;
4969+
if (!Path) {
4970+
Path = Cond.Path;
4971+
} else if (*Path != Cond.Path) {
4972+
// Similar to the unknown conditions case. Add the whole if unless
4973+
// there's an else, in which case use the fallback instead.
4974+
// TODO: Split the `if` statement
4975+
4976+
if (ElseStmt) {
4977+
DiagEngine.diagnose(Statement->getStartLoc(),
4978+
diag::mixed_callback_conditions);
4979+
} else {
4980+
CurrentBlock->addNode(Statement);
49074981
}
4908-
}
4909-
4910-
if (CondType == ConditionType::NIL) {
4911-
ClassifiedBlock *TempBlock = ThenBlock;
4912-
ThenBlock = ElseBlock;
4913-
ElseBlock = TempBlock;
4982+
return;
49144983
}
49154984
}
4985+
assert(Path && "Didn't classify a path?");
4986+
4987+
auto *ThenBlock = &Blocks.SuccessBlock;
4988+
auto *ElseBlock = &Blocks.ErrorBlock;
4989+
4990+
// If the condition is for a failure path, the error block is ThenBlock, and
4991+
// the success block is ElseBlock.
4992+
if (*Path == ConditionPath::FAILURE)
4993+
std::swap(ThenBlock, ElseBlock);
49164994

49174995
// We'll be dropping the statement, but make sure to keep any attached
49184996
// comments.
@@ -4983,13 +5061,15 @@ struct CallbackClassifier {
49835061
return;
49845062
}
49855063

4986-
CallbackCondition CC(ErrParam, &Items[0]);
4987-
ClassifiedBlock *Block = &Blocks.SuccessBlock;
4988-
ClassifiedBlock *OtherBlock = &Blocks.ErrorBlock;
4989-
if (CC.ErrorCase) {
4990-
Block = &Blocks.ErrorBlock;
4991-
OtherBlock = &Blocks.SuccessBlock;
4992-
}
5064+
auto *Block = &Blocks.SuccessBlock;
5065+
auto *OtherBlock = &Blocks.ErrorBlock;
5066+
auto SuccessNodes = NodesToPrint::inBraceStmt(CS->getBody());
5067+
5068+
// Classify the case pattern.
5069+
auto CC = classifyCallbackCondition(
5070+
CallbackCondition(ErrParam, &Items[0]));
5071+
if (CC && CC->Path == ConditionPath::FAILURE)
5072+
std::swap(Block, OtherBlock);
49935073

49945074
// We'll be dropping the case, but make sure to keep any attached
49955075
// comments. Because these comments will effectively be part of the
@@ -5000,8 +5080,9 @@ struct CallbackClassifier {
50005080
if (CS == Cases.back())
50015081
Block->addPossibleCommentLoc(SS->getRBraceLoc());
50025082

5003-
setNodes(Block, OtherBlock, NodesToPrint::inBraceStmt(CS->getBody()));
5004-
Block->addBinding(CC, DiagEngine);
5083+
setNodes(Block, OtherBlock, std::move(SuccessNodes));
5084+
if (CC)
5085+
Block->addBinding(*CC, DiagEngine);
50055086
if (DiagEngine.hadAnyError())
50065087
return;
50075088
}

0 commit comments

Comments
 (0)