Skip to content

Commit a45cd03

Browse files
committed
[Refactoring] Move ExtractRepeatedExpr and ConvertToSwitchStmt to its own file
1 parent da69dc5 commit a45cd03

File tree

9 files changed

+785
-625
lines changed

9 files changed

+785
-625
lines changed

lib/Refactoring/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
add_swift_host_library(swiftRefactoring STATIC
22
CollapseNestedIfStmt.cpp
33
ConvertStringConcatenationToInterpolation.cpp
4+
ConvertToSwitchStmt.cpp
5+
ExtractRepeatedExpr.cpp
46
MoveMembersToExtension.cpp
57
Refactoring.cpp
8+
ExtractExprBase.cpp
69
ReplaceBodiesWithFatalError.cpp
10+
Utils.cpp
711
)
812

913
target_link_libraries(swiftRefactoring PRIVATE

lib/Refactoring/ContextFinder.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "swift/AST/ASTContext.h"
14+
#include "swift/AST/SourceFile.h"
15+
#include "swift/Basic/SourceManager.h"
16+
17+
namespace swift {
18+
namespace refactoring {
19+
20+
class ContextFinder : public SourceEntityWalker {
21+
SourceFile &SF;
22+
ASTContext &Ctx;
23+
SourceManager &SM;
24+
SourceRange Target;
25+
std::function<bool(ASTNode)> IsContext;
26+
SmallVector<ASTNode, 4> AllContexts;
27+
bool contains(ASTNode Enclosing) {
28+
auto Result = SM.rangeContainsRespectingReplacedRanges(
29+
Enclosing.getSourceRange(), Target);
30+
if (Result && IsContext(Enclosing)) {
31+
AllContexts.push_back(Enclosing);
32+
}
33+
return Result;
34+
}
35+
36+
public:
37+
ContextFinder(
38+
SourceFile &SF, ASTNode TargetNode,
39+
std::function<bool(ASTNode)> IsContext = [](ASTNode N) { return true; })
40+
: SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
41+
Target(TargetNode.getSourceRange()), IsContext(IsContext) {}
42+
43+
ContextFinder(
44+
SourceFile &SF, SourceLoc TargetLoc,
45+
std::function<bool(ASTNode)> IsContext = [](ASTNode N) { return true; })
46+
: SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr), Target(TargetLoc),
47+
IsContext(IsContext) {
48+
assert(TargetLoc.isValid() && "Invalid loc to find");
49+
}
50+
51+
// Only need expansions for the expands refactoring, but we
52+
// skip nodes that don't contain the passed location anyway.
53+
virtual MacroWalking getMacroWalkingBehavior() const override {
54+
return MacroWalking::ArgumentsAndExpansion;
55+
}
56+
57+
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
58+
return contains(D);
59+
}
60+
bool walkToStmtPre(Stmt *S) override { return contains(S); }
61+
bool walkToExprPre(Expr *E) override { return contains(E); }
62+
void resolve() { walk(SF); }
63+
ArrayRef<ASTNode> getContexts() const {
64+
return llvm::makeArrayRef(AllContexts);
65+
}
66+
};
67+
68+
} // namespace refactoring
69+
} // namespace swift
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "RefactoringActions.h"
14+
#include "swift/AST/Pattern.h"
15+
#include "swift/AST/Stmt.h"
16+
17+
using namespace swift::refactoring;
18+
19+
bool RefactoringActionConvertToSwitchStmt::isApplicable(
20+
const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
21+
22+
class ConditionalChecker : public ASTWalker {
23+
public:
24+
bool ParamsUseSameVars = true;
25+
bool ConditionUseOnlyAllowedFunctions = false;
26+
StringRef ExpectName;
27+
28+
MacroWalking getMacroWalkingBehavior() const override {
29+
return MacroWalking::Arguments;
30+
}
31+
32+
PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
33+
if (E->getKind() != ExprKind::DeclRef)
34+
return Action::Continue(E);
35+
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
36+
if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
37+
ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
38+
if (D->getKind() == DeclKind::Func)
39+
ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
40+
if (allCheckPassed())
41+
return Action::Continue(E);
42+
43+
return Action::Stop();
44+
}
45+
46+
bool allCheckPassed() {
47+
return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
48+
}
49+
50+
private:
51+
bool checkName(VarDecl *VD) {
52+
auto Name = VD->getName().str();
53+
if (ExpectName.empty())
54+
ExpectName = Name;
55+
return Name == ExpectName;
56+
}
57+
58+
bool checkName(FuncDecl *FD) {
59+
const auto Name = FD->getBaseIdentifier().str();
60+
return Name == "~=" || Name == "==" || Name == "__derived_enum_equals" ||
61+
Name == "__derived_struct_equals" || Name == "||" || Name == "...";
62+
}
63+
};
64+
65+
class SwitchConvertable {
66+
public:
67+
SwitchConvertable(const ResolvedRangeInfo &Info) : Info(Info) {}
68+
69+
bool isApplicable() {
70+
if (Info.Kind != RangeKind::SingleStatement)
71+
return false;
72+
if (!findIfStmt())
73+
return false;
74+
return checkEachCondition();
75+
}
76+
77+
private:
78+
const ResolvedRangeInfo &Info;
79+
IfStmt *If = nullptr;
80+
ConditionalChecker checker;
81+
82+
bool findIfStmt() {
83+
if (Info.ContainedNodes.size() != 1)
84+
return false;
85+
if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt *>())
86+
If = dyn_cast<IfStmt>(S);
87+
return If != nullptr;
88+
}
89+
90+
bool checkEachCondition() {
91+
checker = ConditionalChecker();
92+
do {
93+
if (!checkEachElement())
94+
return false;
95+
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
96+
return true;
97+
}
98+
99+
bool checkEachElement() {
100+
bool result = true;
101+
auto ConditionalList = If->getCond();
102+
for (auto Element : ConditionalList) {
103+
result &= check(Element);
104+
}
105+
return result;
106+
}
107+
108+
bool check(StmtConditionElement ConditionElement) {
109+
if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
110+
return false;
111+
if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding)
112+
checker.ConditionUseOnlyAllowedFunctions = true;
113+
ConditionElement.walk(checker);
114+
return checker.allCheckPassed();
115+
}
116+
};
117+
return SwitchConvertable(Info).isApplicable();
118+
}
119+
120+
bool RefactoringActionConvertToSwitchStmt::performChange() {
121+
122+
class VarNameFinder : public ASTWalker {
123+
public:
124+
std::string VarName;
125+
126+
MacroWalking getMacroWalkingBehavior() const override {
127+
return MacroWalking::Arguments;
128+
}
129+
130+
PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
131+
if (E->getKind() != ExprKind::DeclRef)
132+
return Action::Continue(E);
133+
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
134+
if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
135+
return Action::Continue(E);
136+
VarName = dyn_cast<VarDecl>(D)->getName().str().str();
137+
return Action::Stop();
138+
}
139+
};
140+
141+
class ConditionalPatternFinder : public ASTWalker {
142+
public:
143+
ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}
144+
145+
SmallString<64> ConditionalPattern = SmallString<64>();
146+
147+
MacroWalking getMacroWalkingBehavior() const override {
148+
return MacroWalking::Arguments;
149+
}
150+
151+
PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
152+
auto *BE = dyn_cast<BinaryExpr>(E);
153+
if (!BE)
154+
return Action::Continue(E);
155+
if (isFunctionNameAllowed(BE))
156+
appendPattern(BE->getLHS(), BE->getRHS());
157+
return Action::Continue(E);
158+
}
159+
160+
PreWalkResult<Pattern *> walkToPatternPre(Pattern *P) override {
161+
ConditionalPattern.append(
162+
Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange())
163+
.str());
164+
if (P->getKind() == PatternKind::OptionalSome)
165+
ConditionalPattern.append("?");
166+
return Action::Stop();
167+
}
168+
169+
private:
170+
SourceManager &SM;
171+
172+
bool isFunctionNameAllowed(BinaryExpr *E) {
173+
Expr *Fn = E->getFn();
174+
if (auto DotSyntaxCall = dyn_cast_or_null<DotSyntaxCallExpr>(Fn)) {
175+
Fn = DotSyntaxCall->getFn();
176+
}
177+
DeclRefExpr *DeclRef = dyn_cast_or_null<DeclRefExpr>(Fn);
178+
if (!DeclRef) {
179+
return false;
180+
}
181+
auto FunctionDeclaration = dyn_cast_or_null<FuncDecl>(DeclRef->getDecl());
182+
if (!FunctionDeclaration) {
183+
return false;
184+
}
185+
auto &ASTCtx = FunctionDeclaration->getASTContext();
186+
const auto FunctionName = FunctionDeclaration->getBaseIdentifier();
187+
return FunctionName == ASTCtx.Id_MatchOperator ||
188+
FunctionName == ASTCtx.Id_EqualsOperator ||
189+
FunctionName == ASTCtx.Id_derived_enum_equals ||
190+
FunctionName == ASTCtx.Id_derived_struct_equals;
191+
}
192+
193+
void appendPattern(Expr *LHS, Expr *RHS) {
194+
auto *PatternArgument = RHS;
195+
if (PatternArgument->getKind() == ExprKind::DeclRef)
196+
PatternArgument = LHS;
197+
if (ConditionalPattern.size() > 0)
198+
ConditionalPattern.append(", ");
199+
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(
200+
SM, PatternArgument->getSourceRange())
201+
.str());
202+
}
203+
};
204+
205+
class ConverterToSwitch {
206+
public:
207+
ConverterToSwitch(const ResolvedRangeInfo &Info, SourceManager &SM)
208+
: Info(Info), SM(SM) {}
209+
210+
void performConvert(SmallString<64> &Out) {
211+
If = findIf();
212+
OptionalLabel = If->getLabelInfo().Name.str().str();
213+
ControlExpression = findControlExpression();
214+
findPatternsAndBodies(PatternsAndBodies);
215+
DefaultStatements = findDefaultStatements();
216+
makeSwitchStatement(Out);
217+
}
218+
219+
private:
220+
const ResolvedRangeInfo &Info;
221+
SourceManager &SM;
222+
223+
IfStmt *If;
224+
IfStmt *PreviousIf;
225+
226+
std::string OptionalLabel;
227+
std::string ControlExpression;
228+
SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
229+
std::string DefaultStatements;
230+
231+
IfStmt *findIf() {
232+
auto S = Info.ContainedNodes[0].dyn_cast<Stmt *>();
233+
return dyn_cast<IfStmt>(S);
234+
}
235+
236+
std::string findControlExpression() {
237+
auto ConditionElement = If->getCond().front();
238+
auto Finder = VarNameFinder();
239+
ConditionElement.walk(Finder);
240+
return Finder.VarName;
241+
}
242+
243+
void findPatternsAndBodies(
244+
SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
245+
do {
246+
auto pattern = findPattern();
247+
auto body = findBodyStatements();
248+
Out.push_back(std::make_pair(pattern, body));
249+
PreviousIf = If;
250+
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
251+
}
252+
253+
std::string findPattern() {
254+
auto ConditionElement = If->getCond().front();
255+
auto Finder = ConditionalPatternFinder(SM);
256+
ConditionElement.walk(Finder);
257+
return Finder.ConditionalPattern.str().str();
258+
}
259+
260+
std::string findBodyStatements() {
261+
return findBodyWithoutBraces(If->getThenStmt());
262+
}
263+
264+
std::string findDefaultStatements() {
265+
auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
266+
if (!ElseBody)
267+
return getTokenText(tok::kw_break).str();
268+
return findBodyWithoutBraces(ElseBody);
269+
}
270+
271+
std::string findBodyWithoutBraces(Stmt *body) {
272+
auto BS = dyn_cast<BraceStmt>(body);
273+
if (!BS)
274+
return Lexer::getCharSourceRangeFromSourceRange(SM,
275+
body->getSourceRange())
276+
.str()
277+
.str();
278+
if (BS->getElements().empty())
279+
return getTokenText(tok::kw_break).str();
280+
SourceRange BodyRange = BS->getElements().front().getSourceRange();
281+
BodyRange.widen(BS->getElements().back().getSourceRange());
282+
return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange)
283+
.str()
284+
.str();
285+
}
286+
287+
void makeSwitchStatement(SmallString<64> &Out) {
288+
StringRef Space = " ";
289+
StringRef NewLine = "\n";
290+
llvm::raw_svector_ostream OS(Out);
291+
if (OptionalLabel.size() > 0)
292+
OS << OptionalLabel << ":" << Space;
293+
OS << tok::kw_switch << Space << ControlExpression << Space
294+
<< tok::l_brace << NewLine;
295+
for (auto &pair : PatternsAndBodies) {
296+
OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
297+
OS << pair.second << NewLine;
298+
}
299+
OS << tok::kw_default << tok::colon << NewLine;
300+
OS << DefaultStatements << NewLine;
301+
OS << tok::r_brace;
302+
}
303+
};
304+
305+
SmallString<64> result;
306+
ConverterToSwitch(RangeInfo, SM).performConvert(result);
307+
EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
308+
return false;
309+
}

0 commit comments

Comments
 (0)