Skip to content

Commit 6640316

Browse files
author
Nathan Hawes
authored
Merge pull request swiftlang#29596 from Regno/feature/vlasov/SR-5740
[Source Tooling] Refactoring action to convert if statement to switch
2 parents d40d47f + 6013431 commit 6640316

20 files changed

+2760
-8
lines changed

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ RANGE_REFACTORING(ConvertGuardExprToIfLetExpr, "Convert To IfLet Expression", co
7272

7373
RANGE_REFACTORING(ConvertToComputedProperty, "Convert To Computed Property", convert.to.computed.property)
7474

75+
RANGE_REFACTORING(ConvertToSwitchStmt, "Convert To Switch Statement", convert.switch.stmt)
76+
7577
// These internal refactorings are designed to be helpful for working on
7678
// the compiler/standard library, etc., but are likely to be just confusing and
7779
// noise for general development.

lib/IDE/Refactoring.cpp

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2243,6 +2243,272 @@ bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {
22432243
return false;
22442244
}
22452245

2246+
bool RefactoringActionConvertToSwitchStmt::
2247+
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
2248+
2249+
class ConditionalChecker : public ASTWalker {
2250+
public:
2251+
bool ParamsUseSameVars = true;
2252+
bool ConditionUseOnlyAllowedFunctions = false;
2253+
StringRef ExpectName;
2254+
2255+
Expr *walkToExprPost(Expr *E) {
2256+
if (E->getKind() != ExprKind::DeclRef)
2257+
return E;
2258+
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
2259+
if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
2260+
ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
2261+
if (D->getKind() == DeclKind::Func)
2262+
ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
2263+
if (allCheckPassed())
2264+
return E;
2265+
return nullptr;
2266+
}
2267+
2268+
bool allCheckPassed() {
2269+
return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
2270+
}
2271+
2272+
private:
2273+
bool checkName(VarDecl *VD) {
2274+
auto Name = VD->getName().str();
2275+
if (ExpectName.empty())
2276+
ExpectName = Name;
2277+
return Name == ExpectName;
2278+
}
2279+
2280+
bool checkName(FuncDecl *FD) {
2281+
auto Name = FD->getName().str();
2282+
return Name == "~="
2283+
|| Name == "=="
2284+
|| Name == "__derived_enum_equals"
2285+
|| Name == "__derived_struct_equals"
2286+
|| Name == "||"
2287+
|| Name == "...";
2288+
}
2289+
};
2290+
2291+
class SwitchConvertable {
2292+
public:
2293+
SwitchConvertable(ResolvedRangeInfo Info) {
2294+
this->Info = Info;
2295+
}
2296+
2297+
bool isApplicable() {
2298+
if (Info.Kind != RangeKind::SingleStatement)
2299+
return false;
2300+
if (!findIfStmt())
2301+
return false;
2302+
return checkEachCondition();
2303+
}
2304+
2305+
private:
2306+
ResolvedRangeInfo Info;
2307+
IfStmt *If = nullptr;
2308+
ConditionalChecker checker;
2309+
2310+
bool findIfStmt() {
2311+
if (Info.ContainedNodes.size() != 1)
2312+
return false;
2313+
if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt*>())
2314+
If = dyn_cast<IfStmt>(S);
2315+
return If != nullptr;
2316+
}
2317+
2318+
bool checkEachCondition() {
2319+
checker = ConditionalChecker();
2320+
do {
2321+
if (!checkEachElement())
2322+
return false;
2323+
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
2324+
return true;
2325+
}
2326+
2327+
bool checkEachElement() {
2328+
bool result = true;
2329+
auto ConditionalList = If->getCond();
2330+
for (auto Element : ConditionalList) {
2331+
result &= check(Element);
2332+
}
2333+
return result;
2334+
}
2335+
2336+
bool check(StmtConditionElement ConditionElement) {
2337+
if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
2338+
return false;
2339+
if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding)
2340+
checker.ConditionUseOnlyAllowedFunctions = true;
2341+
ConditionElement.walk(checker);
2342+
return checker.allCheckPassed();
2343+
}
2344+
};
2345+
return SwitchConvertable(Info).isApplicable();
2346+
}
2347+
2348+
bool RefactoringActionConvertToSwitchStmt::performChange() {
2349+
2350+
class VarNameFinder : public ASTWalker {
2351+
public:
2352+
std::string VarName;
2353+
2354+
Expr *walkToExprPost(Expr *E) {
2355+
if (E->getKind() != ExprKind::DeclRef)
2356+
return E;
2357+
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
2358+
if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
2359+
return E;
2360+
VarName = dyn_cast<VarDecl>(D)->getName().str().str();
2361+
return nullptr;
2362+
}
2363+
};
2364+
2365+
class ConditionalPatternFinder : public ASTWalker {
2366+
public:
2367+
ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}
2368+
2369+
SmallString<64> ConditionalPattern = SmallString<64>();
2370+
2371+
Expr *walkToExprPost(Expr *E) {
2372+
if (E->getKind() != ExprKind::Binary)
2373+
return E;
2374+
auto BE = dyn_cast<BinaryExpr>(E);
2375+
if (isFunctionNameAllowed(BE))
2376+
appendPattern(dyn_cast<BinaryExpr>(E)->getArg());
2377+
return E;
2378+
}
2379+
2380+
std::pair<bool, Pattern*> walkToPatternPre(Pattern *P) {
2381+
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str());
2382+
if (P->getKind() == PatternKind::OptionalSome)
2383+
ConditionalPattern.append("?");
2384+
return { true, nullptr };
2385+
}
2386+
2387+
private:
2388+
2389+
SourceManager &SM;
2390+
2391+
bool isFunctionNameAllowed(BinaryExpr *E) {
2392+
auto FunctionBody = dyn_cast<DotSyntaxCallExpr>(E->getFn())->getFn();
2393+
auto FunctionDeclaration = dyn_cast<DeclRefExpr>(FunctionBody)->getDecl();
2394+
auto FunctionName = dyn_cast<FuncDecl>(FunctionDeclaration)->getName().str();
2395+
return FunctionName == "~="
2396+
|| FunctionName == "=="
2397+
|| FunctionName == "__derived_enum_equals"
2398+
|| FunctionName == "__derived_struct_equals";
2399+
}
2400+
2401+
void appendPattern(TupleExpr *Tuple) {
2402+
auto PatternArgument = Tuple->getElements().back();
2403+
if (PatternArgument->getKind() == ExprKind::DeclRef)
2404+
PatternArgument = Tuple->getElements().front();
2405+
if (ConditionalPattern.size() > 0)
2406+
ConditionalPattern.append(", ");
2407+
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str());
2408+
}
2409+
};
2410+
2411+
class ConverterToSwitch {
2412+
public:
2413+
ConverterToSwitch(ResolvedRangeInfo Info, SourceManager &SM) : SM(SM) {
2414+
this->Info = Info;
2415+
}
2416+
2417+
void performConvert(SmallString<64> &Out) {
2418+
If = findIf();
2419+
OptionalLabel = If->getLabelInfo().Name.str().str();
2420+
ControlExpression = findControlExpression();
2421+
findPatternsAndBodies(PatternsAndBodies);
2422+
DefaultStatements = findDefaultStatements();
2423+
makeSwitchStatement(Out);
2424+
}
2425+
2426+
private:
2427+
ResolvedRangeInfo Info;
2428+
SourceManager &SM;
2429+
2430+
IfStmt *If;
2431+
IfStmt *PreviousIf;
2432+
2433+
std::string OptionalLabel;
2434+
std::string ControlExpression;
2435+
SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
2436+
std::string DefaultStatements;
2437+
2438+
IfStmt *findIf() {
2439+
auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>();
2440+
return dyn_cast<IfStmt>(S);
2441+
}
2442+
2443+
std::string findControlExpression() {
2444+
auto ConditionElement = If->getCond().front();
2445+
auto Finder = VarNameFinder();
2446+
ConditionElement.walk(Finder);
2447+
return Finder.VarName;
2448+
}
2449+
2450+
void findPatternsAndBodies(SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
2451+
do {
2452+
auto pattern = findPattern();
2453+
auto body = findBodyStatements();
2454+
Out.push_back(std::make_pair(pattern, body));
2455+
PreviousIf = If;
2456+
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
2457+
}
2458+
2459+
std::string findPattern() {
2460+
auto ConditionElement = If->getCond().front();
2461+
auto Finder = ConditionalPatternFinder(SM);
2462+
ConditionElement.walk(Finder);
2463+
return Finder.ConditionalPattern.str().str();
2464+
}
2465+
2466+
std::string findBodyStatements() {
2467+
return findBodyWithoutBraces(If->getThenStmt());
2468+
}
2469+
2470+
std::string findDefaultStatements() {
2471+
auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
2472+
if (!ElseBody)
2473+
return getTokenText(tok::kw_break);
2474+
return findBodyWithoutBraces(ElseBody);
2475+
}
2476+
2477+
std::string findBodyWithoutBraces(Stmt *body) {
2478+
auto BS = dyn_cast<BraceStmt>(body);
2479+
if (!BS)
2480+
return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str();
2481+
if (BS->getElements().empty())
2482+
return getTokenText(tok::kw_break);
2483+
SourceRange BodyRange = BS->getElements().front().getSourceRange();
2484+
BodyRange.widen(BS->getElements().back().getSourceRange());
2485+
return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str();
2486+
}
2487+
2488+
void makeSwitchStatement(SmallString<64> &Out) {
2489+
StringRef Space = " ";
2490+
StringRef NewLine = "\n";
2491+
llvm::raw_svector_ostream OS(Out);
2492+
if (OptionalLabel.size() > 0)
2493+
OS << OptionalLabel << ":" << Space;
2494+
OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
2495+
for (auto &pair : PatternsAndBodies) {
2496+
OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
2497+
OS << pair.second << NewLine;
2498+
}
2499+
OS << tok::kw_default << tok::colon << NewLine;
2500+
OS << DefaultStatements << NewLine;
2501+
OS << tok::r_brace;
2502+
}
2503+
2504+
};
2505+
2506+
SmallString<64> result;
2507+
ConverterToSwitch(RangeInfo, SM).performConvert(result);
2508+
EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
2509+
return false;
2510+
}
2511+
22462512
/// Struct containing info about an IfStmt that can be converted into an IfExpr.
22472513
struct ConvertToTernaryExprInfo {
22482514
ConvertToTernaryExprInfo() {}

0 commit comments

Comments
 (0)