Skip to content

Commit 904bd0b

Browse files
committed
SR-5740 Refactoring action to convert if statement to switch
1 parent cc3b6a0 commit 904bd0b

20 files changed

+2744
-8
lines changed

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ RANGE_REFACTORING(ConvertGuardExprToIfLetExpr, "Convert To IfLet Expression", co
7272

7373
RANGE_REFACTORING(ConvertToComputedProperty, "Convert To Computed Property", convert.to.computed.property)
7474

75+
RANGE_REFACTORING(ConvertToSwitchStmt, "Convert To Switch Statement", convert.switch.stmt)
76+
7577
// These internal refactorings are designed to be helpful for working on
7678
// the compiler/standard library, etc., but are likely to be just confusing and
7779
// noise for general development.

lib/IDE/Refactoring.cpp

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2243,6 +2243,270 @@ bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {
22432243
return false;
22442244
}
22452245

2246+
bool RefactoringActionConvertToSwitchStmt::
2247+
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
2248+
2249+
class ConditionalChecker : public ASTWalker {
2250+
public:
2251+
bool ParamsUseSameVars = true;
2252+
bool ConditionUseOnlyAllowedFunctions = true;
2253+
StringRef ExpectName;
2254+
2255+
Expr *walkToExprPost(Expr *E) {
2256+
if (E->getKind() != ExprKind::DeclRef)
2257+
return E;
2258+
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
2259+
if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
2260+
ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
2261+
if (D->getKind() == DeclKind::Func)
2262+
ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
2263+
if (allCheckPassed())
2264+
return E;
2265+
return nullptr;
2266+
}
2267+
2268+
bool allCheckPassed() {
2269+
return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
2270+
}
2271+
2272+
private:
2273+
bool checkName(VarDecl *VD) {
2274+
auto Name = VD->getName().str();
2275+
if (ExpectName.empty())
2276+
ExpectName = Name;
2277+
return Name == ExpectName;
2278+
}
2279+
2280+
bool checkName(FuncDecl *FD) {
2281+
auto Name = FD->getName().str();
2282+
return Name == "~="
2283+
|| Name == "=="
2284+
|| Name == "__derived_enum_equals"
2285+
|| Name == "__derived_struct_equals"
2286+
|| Name == "||"
2287+
|| Name == "...";
2288+
}
2289+
};
2290+
2291+
class SwitchConvertable {
2292+
public:
2293+
SwitchConvertable(ResolvedRangeInfo Info) {
2294+
this->Info = Info;
2295+
}
2296+
2297+
bool isApplicable() {
2298+
if (Info.Kind != RangeKind::SingleStatement)
2299+
return false;
2300+
if (!findIfStmt())
2301+
return false;
2302+
return checkEachCondition();
2303+
}
2304+
2305+
private:
2306+
ResolvedRangeInfo Info;
2307+
IfStmt *If = nullptr;
2308+
ConditionalChecker checker;
2309+
2310+
bool findIfStmt() {
2311+
if (Info.ContainedNodes.size() != 1)
2312+
return false;
2313+
if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt*>())
2314+
If = dyn_cast<IfStmt>(S);
2315+
return If != nullptr;
2316+
}
2317+
2318+
bool checkEachCondition() {
2319+
checker = ConditionalChecker();
2320+
do {
2321+
if (!checkEachElement())
2322+
return false;
2323+
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
2324+
return true;
2325+
}
2326+
2327+
bool checkEachElement() {
2328+
bool result = true;
2329+
auto ConditionalList = If->getCond();
2330+
for (auto Element : ConditionalList) {
2331+
result &= check(Element);
2332+
}
2333+
return result;
2334+
}
2335+
2336+
bool check(StmtConditionElement ConditionElement) {
2337+
if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
2338+
return false;
2339+
ConditionElement.walk(checker);
2340+
return checker.allCheckPassed();
2341+
}
2342+
};
2343+
return SwitchConvertable(Info).isApplicable();
2344+
}
2345+
2346+
bool RefactoringActionConvertToSwitchStmt::performChange() {
2347+
2348+
class VarNameFinder : public ASTWalker {
2349+
public:
2350+
std::string VarName;
2351+
2352+
Expr *walkToExprPost(Expr *E) {
2353+
if (E->getKind() != ExprKind::DeclRef)
2354+
return E;
2355+
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
2356+
if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
2357+
return E;
2358+
VarName = dyn_cast<VarDecl>(D)->getName().str().str();
2359+
return nullptr;
2360+
}
2361+
};
2362+
2363+
class ConditionalPatternFinder : public ASTWalker {
2364+
public:
2365+
ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}
2366+
2367+
SmallString<64> ConditionalPattern = SmallString<64>();
2368+
2369+
Expr *walkToExprPost(Expr *E) {
2370+
if (E->getKind() != ExprKind::Binary)
2371+
return E;
2372+
auto BE = dyn_cast<BinaryExpr>(E);
2373+
if (isFunctionNameAllowed(BE))
2374+
appendPattern(dyn_cast<BinaryExpr>(E)->getArg());
2375+
return E;
2376+
}
2377+
2378+
std::pair<bool, Pattern*> walkToPatternPre(Pattern *P) {
2379+
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str());
2380+
if (P->getKind() == PatternKind::OptionalSome)
2381+
ConditionalPattern.append("?");
2382+
return { true, nullptr };
2383+
}
2384+
2385+
private:
2386+
2387+
SourceManager &SM;
2388+
2389+
bool isFunctionNameAllowed(BinaryExpr *E) {
2390+
auto FunctionBody = dyn_cast<DotSyntaxCallExpr>(E->getFn())->getFn();
2391+
auto FunctionDeclaration = dyn_cast<DeclRefExpr>(FunctionBody)->getDecl();
2392+
auto FunctionName = dyn_cast<FuncDecl>(FunctionDeclaration)->getName().str();
2393+
return FunctionName == "~="
2394+
|| FunctionName == "=="
2395+
|| FunctionName == "__derived_enum_equals"
2396+
|| FunctionName == "__derived_struct_equals";
2397+
}
2398+
2399+
void appendPattern(TupleExpr *Tuple) {
2400+
auto PatternArgument = Tuple->getElements().back();
2401+
if (PatternArgument->getKind() == ExprKind::DeclRef)
2402+
PatternArgument = Tuple->getElements().front();
2403+
if (ConditionalPattern.size() > 0)
2404+
ConditionalPattern.append(", ");
2405+
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str());
2406+
}
2407+
};
2408+
2409+
class ConverterToSwitch {
2410+
public:
2411+
ConverterToSwitch(ResolvedRangeInfo Info, SourceManager &SM) : SM(SM) {
2412+
this->Info = Info;
2413+
}
2414+
2415+
void performConvert(SmallString<64> &Out) {
2416+
If = findIf();
2417+
OptionalLabel = If->getLabelInfo().Name.str().str();
2418+
ControlExpression = findControlExpression();
2419+
findPatternsAndBodies(PatternsAndBodies);
2420+
DefaultStatements = findDefaultStatements();
2421+
makeSwitchStatement(Out);
2422+
}
2423+
2424+
private:
2425+
ResolvedRangeInfo Info;
2426+
SourceManager &SM;
2427+
2428+
IfStmt *If;
2429+
IfStmt *PreviousIf;
2430+
2431+
std::string OptionalLabel;
2432+
std::string ControlExpression;
2433+
SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
2434+
std::string DefaultStatements;
2435+
2436+
IfStmt *findIf() {
2437+
auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>();
2438+
return dyn_cast<IfStmt>(S);
2439+
}
2440+
2441+
std::string findControlExpression() {
2442+
auto ConditionElement = If->getCond().front();
2443+
auto Finder = VarNameFinder();
2444+
ConditionElement.walk(Finder);
2445+
return Finder.VarName;
2446+
}
2447+
2448+
void findPatternsAndBodies(SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
2449+
do {
2450+
auto pattern = findPattern();
2451+
auto body = findBodyStatements();
2452+
Out.push_back(std::make_pair(pattern, body));
2453+
PreviousIf = If;
2454+
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
2455+
}
2456+
2457+
std::string findPattern() {
2458+
auto ConditionElement = If->getCond().front();
2459+
auto Finder = ConditionalPatternFinder(SM);
2460+
ConditionElement.walk(Finder);
2461+
return Finder.ConditionalPattern.str().str();
2462+
}
2463+
2464+
std::string findBodyStatements() {
2465+
return findBodyWithoutBraces(If->getThenStmt());
2466+
}
2467+
2468+
std::string findDefaultStatements() {
2469+
auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
2470+
if (!ElseBody)
2471+
return getTokenText(tok::kw_break);
2472+
return findBodyWithoutBraces(ElseBody);
2473+
}
2474+
2475+
std::string findBodyWithoutBraces(Stmt *body) {
2476+
auto BS = dyn_cast<BraceStmt>(body);
2477+
if (!BS)
2478+
return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str();
2479+
if (BS->getElements().empty())
2480+
return getTokenText(tok::kw_break);
2481+
SourceRange BodyRange = BS->getElements().front().getSourceRange();
2482+
BodyRange.widen(BS->getElements().back().getSourceRange());
2483+
return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str();
2484+
}
2485+
2486+
void makeSwitchStatement(SmallString<64> &Out) {
2487+
StringRef Space = " ";
2488+
StringRef NewLine = "\n";
2489+
llvm::raw_svector_ostream OS(Out);
2490+
if (OptionalLabel.size() > 0)
2491+
OS << OptionalLabel << ":" << Space;
2492+
OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
2493+
for (auto &pair : PatternsAndBodies) {
2494+
OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
2495+
OS << pair.second << NewLine;
2496+
}
2497+
OS << tok::kw_default << tok::colon << NewLine;
2498+
OS << DefaultStatements << NewLine;
2499+
OS << tok::r_brace;
2500+
}
2501+
2502+
};
2503+
2504+
SmallString<64> result;
2505+
ConverterToSwitch(RangeInfo, SM).performConvert(result);
2506+
EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
2507+
return false;
2508+
}
2509+
22462510
/// Struct containing info about an IfStmt that can be converted into an IfExpr.
22472511
struct ConvertToTernaryExprInfo {
22482512
ConvertToTernaryExprInfo() {}

0 commit comments

Comments
 (0)