@@ -2243,6 +2243,272 @@ bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {
2243
2243
return false ;
2244
2244
}
2245
2245
2246
+ bool RefactoringActionConvertToSwitchStmt::
2247
+ isApplicable (ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
2248
+
2249
+ class ConditionalChecker : public ASTWalker {
2250
+ public:
2251
+ bool ParamsUseSameVars = true ;
2252
+ bool ConditionUseOnlyAllowedFunctions = false ;
2253
+ StringRef ExpectName;
2254
+
2255
+ Expr *walkToExprPost (Expr *E) {
2256
+ if (E->getKind () != ExprKind::DeclRef)
2257
+ return E;
2258
+ auto D = dyn_cast<DeclRefExpr>(E)->getDecl ();
2259
+ if (D->getKind () == DeclKind::Var || D->getKind () == DeclKind::Param)
2260
+ ParamsUseSameVars = checkName (dyn_cast<VarDecl>(D));
2261
+ if (D->getKind () == DeclKind::Func)
2262
+ ConditionUseOnlyAllowedFunctions = checkName (dyn_cast<FuncDecl>(D));
2263
+ if (allCheckPassed ())
2264
+ return E;
2265
+ return nullptr ;
2266
+ }
2267
+
2268
+ bool allCheckPassed () {
2269
+ return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
2270
+ }
2271
+
2272
+ private:
2273
+ bool checkName (VarDecl *VD) {
2274
+ auto Name = VD->getName ().str ();
2275
+ if (ExpectName.empty ())
2276
+ ExpectName = Name;
2277
+ return Name == ExpectName;
2278
+ }
2279
+
2280
+ bool checkName (FuncDecl *FD) {
2281
+ auto Name = FD->getName ().str ();
2282
+ return Name == " ~="
2283
+ || Name == " =="
2284
+ || Name == " __derived_enum_equals"
2285
+ || Name == " __derived_struct_equals"
2286
+ || Name == " ||"
2287
+ || Name == " ..." ;
2288
+ }
2289
+ };
2290
+
2291
+ class SwitchConvertable {
2292
+ public:
2293
+ SwitchConvertable (ResolvedRangeInfo Info) {
2294
+ this ->Info = Info;
2295
+ }
2296
+
2297
+ bool isApplicable () {
2298
+ if (Info.Kind != RangeKind::SingleStatement)
2299
+ return false ;
2300
+ if (!findIfStmt ())
2301
+ return false ;
2302
+ return checkEachCondition ();
2303
+ }
2304
+
2305
+ private:
2306
+ ResolvedRangeInfo Info;
2307
+ IfStmt *If = nullptr ;
2308
+ ConditionalChecker checker;
2309
+
2310
+ bool findIfStmt () {
2311
+ if (Info.ContainedNodes .size () != 1 )
2312
+ return false ;
2313
+ if (auto S = Info.ContainedNodes .front ().dyn_cast <Stmt*>())
2314
+ If = dyn_cast<IfStmt>(S);
2315
+ return If != nullptr ;
2316
+ }
2317
+
2318
+ bool checkEachCondition () {
2319
+ checker = ConditionalChecker ();
2320
+ do {
2321
+ if (!checkEachElement ())
2322
+ return false ;
2323
+ } while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt ())));
2324
+ return true ;
2325
+ }
2326
+
2327
+ bool checkEachElement () {
2328
+ bool result = true ;
2329
+ auto ConditionalList = If->getCond ();
2330
+ for (auto Element : ConditionalList) {
2331
+ result &= check (Element);
2332
+ }
2333
+ return result;
2334
+ }
2335
+
2336
+ bool check (StmtConditionElement ConditionElement) {
2337
+ if (ConditionElement.getKind () == StmtConditionElement::CK_Availability)
2338
+ return false ;
2339
+ if (ConditionElement.getKind () == StmtConditionElement::CK_PatternBinding)
2340
+ checker.ConditionUseOnlyAllowedFunctions = true ;
2341
+ ConditionElement.walk (checker);
2342
+ return checker.allCheckPassed ();
2343
+ }
2344
+ };
2345
+ return SwitchConvertable (Info).isApplicable ();
2346
+ }
2347
+
2348
+ bool RefactoringActionConvertToSwitchStmt::performChange () {
2349
+
2350
+ class VarNameFinder : public ASTWalker {
2351
+ public:
2352
+ std::string VarName;
2353
+
2354
+ Expr *walkToExprPost (Expr *E) {
2355
+ if (E->getKind () != ExprKind::DeclRef)
2356
+ return E;
2357
+ auto D = dyn_cast<DeclRefExpr>(E)->getDecl ();
2358
+ if (D->getKind () != DeclKind::Var && D->getKind () != DeclKind::Param)
2359
+ return E;
2360
+ VarName = dyn_cast<VarDecl>(D)->getName ().str ().str ();
2361
+ return nullptr ;
2362
+ }
2363
+ };
2364
+
2365
+ class ConditionalPatternFinder : public ASTWalker {
2366
+ public:
2367
+ ConditionalPatternFinder (SourceManager &SM) : SM(SM) {}
2368
+
2369
+ SmallString<64 > ConditionalPattern = SmallString<64 >();
2370
+
2371
+ Expr *walkToExprPost (Expr *E) {
2372
+ if (E->getKind () != ExprKind::Binary)
2373
+ return E;
2374
+ auto BE = dyn_cast<BinaryExpr>(E);
2375
+ if (isFunctionNameAllowed (BE))
2376
+ appendPattern (dyn_cast<BinaryExpr>(E)->getArg ());
2377
+ return E;
2378
+ }
2379
+
2380
+ std::pair<bool , Pattern*> walkToPatternPre (Pattern *P) {
2381
+ ConditionalPattern.append (Lexer::getCharSourceRangeFromSourceRange (SM, P->getSourceRange ()).str ());
2382
+ if (P->getKind () == PatternKind::OptionalSome)
2383
+ ConditionalPattern.append (" ?" );
2384
+ return { true , nullptr };
2385
+ }
2386
+
2387
+ private:
2388
+
2389
+ SourceManager &SM;
2390
+
2391
+ bool isFunctionNameAllowed (BinaryExpr *E) {
2392
+ auto FunctionBody = dyn_cast<DotSyntaxCallExpr>(E->getFn ())->getFn ();
2393
+ auto FunctionDeclaration = dyn_cast<DeclRefExpr>(FunctionBody)->getDecl ();
2394
+ auto FunctionName = dyn_cast<FuncDecl>(FunctionDeclaration)->getName ().str ();
2395
+ return FunctionName == " ~="
2396
+ || FunctionName == " =="
2397
+ || FunctionName == " __derived_enum_equals"
2398
+ || FunctionName == " __derived_struct_equals" ;
2399
+ }
2400
+
2401
+ void appendPattern (TupleExpr *Tuple) {
2402
+ auto PatternArgument = Tuple->getElements ().back ();
2403
+ if (PatternArgument->getKind () == ExprKind::DeclRef)
2404
+ PatternArgument = Tuple->getElements ().front ();
2405
+ if (ConditionalPattern.size () > 0 )
2406
+ ConditionalPattern.append (" , " );
2407
+ ConditionalPattern.append (Lexer::getCharSourceRangeFromSourceRange (SM, PatternArgument->getSourceRange ()).str ());
2408
+ }
2409
+ };
2410
+
2411
+ class ConverterToSwitch {
2412
+ public:
2413
+ ConverterToSwitch (ResolvedRangeInfo Info, SourceManager &SM) : SM(SM) {
2414
+ this ->Info = Info;
2415
+ }
2416
+
2417
+ void performConvert (SmallString<64 > &Out) {
2418
+ If = findIf ();
2419
+ OptionalLabel = If->getLabelInfo ().Name .str ().str ();
2420
+ ControlExpression = findControlExpression ();
2421
+ findPatternsAndBodies (PatternsAndBodies);
2422
+ DefaultStatements = findDefaultStatements ();
2423
+ makeSwitchStatement (Out);
2424
+ }
2425
+
2426
+ private:
2427
+ ResolvedRangeInfo Info;
2428
+ SourceManager &SM;
2429
+
2430
+ IfStmt *If;
2431
+ IfStmt *PreviousIf;
2432
+
2433
+ std::string OptionalLabel;
2434
+ std::string ControlExpression;
2435
+ SmallVector<std::pair<std::string, std::string>, 16 > PatternsAndBodies;
2436
+ std::string DefaultStatements;
2437
+
2438
+ IfStmt *findIf () {
2439
+ auto S = Info.ContainedNodes [0 ].dyn_cast <Stmt*>();
2440
+ return dyn_cast<IfStmt>(S);
2441
+ }
2442
+
2443
+ std::string findControlExpression () {
2444
+ auto ConditionElement = If->getCond ().front ();
2445
+ auto Finder = VarNameFinder ();
2446
+ ConditionElement.walk (Finder);
2447
+ return Finder.VarName ;
2448
+ }
2449
+
2450
+ void findPatternsAndBodies (SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
2451
+ do {
2452
+ auto pattern = findPattern ();
2453
+ auto body = findBodyStatements ();
2454
+ Out.push_back (std::make_pair (pattern, body));
2455
+ PreviousIf = If;
2456
+ } while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt ())));
2457
+ }
2458
+
2459
+ std::string findPattern () {
2460
+ auto ConditionElement = If->getCond ().front ();
2461
+ auto Finder = ConditionalPatternFinder (SM);
2462
+ ConditionElement.walk (Finder);
2463
+ return Finder.ConditionalPattern .str ().str ();
2464
+ }
2465
+
2466
+ std::string findBodyStatements () {
2467
+ return findBodyWithoutBraces (If->getThenStmt ());
2468
+ }
2469
+
2470
+ std::string findDefaultStatements () {
2471
+ auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt ());
2472
+ if (!ElseBody)
2473
+ return getTokenText (tok::kw_break);
2474
+ return findBodyWithoutBraces (ElseBody);
2475
+ }
2476
+
2477
+ std::string findBodyWithoutBraces (Stmt *body) {
2478
+ auto BS = dyn_cast<BraceStmt>(body);
2479
+ if (!BS)
2480
+ return Lexer::getCharSourceRangeFromSourceRange (SM, body->getSourceRange ()).str ().str ();
2481
+ if (BS->getElements ().empty ())
2482
+ return getTokenText (tok::kw_break);
2483
+ SourceRange BodyRange = BS->getElements ().front ().getSourceRange ();
2484
+ BodyRange.widen (BS->getElements ().back ().getSourceRange ());
2485
+ return Lexer::getCharSourceRangeFromSourceRange (SM, BodyRange).str ().str ();
2486
+ }
2487
+
2488
+ void makeSwitchStatement (SmallString<64 > &Out) {
2489
+ StringRef Space = " " ;
2490
+ StringRef NewLine = " \n " ;
2491
+ llvm::raw_svector_ostream OS (Out);
2492
+ if (OptionalLabel.size () > 0 )
2493
+ OS << OptionalLabel << " :" << Space;
2494
+ OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
2495
+ for (auto &pair : PatternsAndBodies) {
2496
+ OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
2497
+ OS << pair.second << NewLine;
2498
+ }
2499
+ OS << tok::kw_default << tok::colon << NewLine;
2500
+ OS << DefaultStatements << NewLine;
2501
+ OS << tok::r_brace;
2502
+ }
2503
+
2504
+ };
2505
+
2506
+ SmallString<64 > result;
2507
+ ConverterToSwitch (RangeInfo, SM).performConvert (result);
2508
+ EditConsumer.accept (SM, RangeInfo.ContentRange , result.str ());
2509
+ return false ;
2510
+ }
2511
+
2246
2512
// / Struct containing info about an IfStmt that can be converted into an IfExpr.
2247
2513
struct ConvertToTernaryExprInfo {
2248
2514
ConvertToTernaryExprInfo () {}
0 commit comments