Skip to content

Commit 10038b6

Browse files
committed
[Refactoring] Add async wrapper refactoring action
This allows an async alternative function to be created that forwards onto the user's completion handler function through the use of `withCheckedContinuation`/`withCheckedThrowingContinuation`. rdar://77802486
1 parent e4a7a93 commit 10038b6

File tree

5 files changed

+397
-11
lines changed

5 files changed

+397
-11
lines changed

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ CURSOR_REFACTORING(ConvertToAsync, "Convert Function to Async", convert.func-to-
6060

6161
CURSOR_REFACTORING(AddAsyncAlternative, "Add Async Alternative", add.async-alternative)
6262

63+
CURSOR_REFACTORING(AddAsyncWrapper, "Add Async Wrapper", add.async-wrapper)
64+
6365
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
6466

6567
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

lib/IDE/Refactoring.cpp

Lines changed: 208 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4195,6 +4195,13 @@ struct AsyncHandlerDesc {
41954195
return params();
41964196
}
41974197

4198+
/// If the completion handler has an Error parameter, return it.
4199+
Optional<AnyFunctionType::Param> getErrorParam() const {
4200+
if (HasError && Type == HandlerType::PARAMS)
4201+
return params().back();
4202+
return None;
4203+
}
4204+
41984205
/// Get the type of the error that will be thrown by the \c async method or \c
41994206
/// None if the completion handler doesn't accept an error parameter.
42004207
/// This may be more specialized than the generic 'Error' type if the
@@ -5397,6 +5404,41 @@ class AsyncConverter : private SourceEntityWalker {
53975404
return true;
53985405
}
53995406

5407+
/// Creates an async alternative function that forwards onto the completion
5408+
/// handler function through
5409+
/// withCheckedContinuation/withCheckedThrowingContinuation.
5410+
bool createAsyncWrapper() {
5411+
assert(Buffer.empty() && "AsyncConverter can only be used once");
5412+
auto *FD = cast<FuncDecl>(StartNode.get<Decl *>());
5413+
5414+
// First add the new async function declaration.
5415+
addFuncDecl(FD);
5416+
OS << tok::l_brace << "\n";
5417+
5418+
// Then add the body.
5419+
OS << tok::kw_return << " ";
5420+
if (TopHandler.HasError)
5421+
OS << tok::kw_try << " ";
5422+
5423+
OS << "await ";
5424+
5425+
// withChecked[Throwing]Continuation { cont in
5426+
if (TopHandler.HasError) {
5427+
OS << "withCheckedThrowingContinuation";
5428+
} else {
5429+
OS << "withCheckedContinuation";
5430+
}
5431+
OS << " " << tok::l_brace << " cont " << tok::kw_in << "\n";
5432+
5433+
// fnWithHandler(args...) { ... }
5434+
auto ClosureStr = getAsyncWrapperCompletionClosure("cont", TopHandler);
5435+
addForwardingCallTo(FD, TopHandler, /*HandlerReplacement*/ ClosureStr);
5436+
5437+
OS << tok::r_brace << "\n"; // end continuation closure
5438+
OS << tok::r_brace << "\n"; // end function body
5439+
return true;
5440+
}
5441+
54005442
void replace(ASTNode Node, SourceEditConsumer &EditConsumer,
54015443
SourceLoc StartOverride = SourceLoc()) {
54025444
SourceRange Range = Node.getSourceRange();
@@ -5446,6 +5488,116 @@ class AsyncConverter : private SourceEntityWalker {
54465488
OS << tok::r_paren;
54475489
}
54485490

5491+
/// Retrieve the completion handler closure argument for an async wrapper
5492+
/// function.
5493+
std::string
5494+
getAsyncWrapperCompletionClosure(StringRef ContName,
5495+
const AsyncHandlerParamDesc &HandlerDesc) {
5496+
std::string OutputStr;
5497+
llvm::raw_string_ostream OS(OutputStr);
5498+
5499+
OS << " " << tok::l_brace; // start closure
5500+
5501+
// Prepare parameter names for the closure.
5502+
auto SuccessParams = HandlerDesc.getSuccessParams();
5503+
SmallVector<SmallString<4>, 2> SuccessParamNames;
5504+
for (auto idx : indices(SuccessParams)) {
5505+
SuccessParamNames.emplace_back("res");
5506+
5507+
// If we have multiple success params, number them e.g res1, res2...
5508+
if (SuccessParams.size() > 1)
5509+
SuccessParamNames.back().append(std::to_string(idx + 1));
5510+
}
5511+
Optional<SmallString<4>> ErrName;
5512+
if (HandlerDesc.getErrorParam())
5513+
ErrName.emplace("err");
5514+
5515+
auto HasAnyParams = !SuccessParamNames.empty() || ErrName;
5516+
if (HasAnyParams)
5517+
OS << " ";
5518+
5519+
// res1, res2
5520+
llvm::interleave(
5521+
SuccessParamNames, [&](auto Name) { OS << Name; },
5522+
[&]() { OS << tok::comma << " "; });
5523+
5524+
// , err
5525+
if (ErrName) {
5526+
if (!SuccessParamNames.empty())
5527+
OS << tok::comma << " ";
5528+
5529+
OS << *ErrName;
5530+
}
5531+
if (HasAnyParams)
5532+
OS << " " << tok::kw_in;
5533+
5534+
OS << "\n";
5535+
5536+
// The closure body.
5537+
switch (HandlerDesc.Type) {
5538+
case HandlerType::PARAMS: {
5539+
// For a (Success?, Error?) -> Void handler, we do an if let on the error.
5540+
if (ErrName) {
5541+
// if let err = err {
5542+
OS << tok::kw_if << " " << tok::kw_let << " ";
5543+
OS << *ErrName << " " << tok::equal << " " << *ErrName << " ";
5544+
OS << tok::l_brace << "\n";
5545+
5546+
// cont.resume(throwing: err)
5547+
OS << ContName << tok::period << "resume" << tok::l_paren;
5548+
OS << "throwing" << tok::colon << " " << *ErrName;
5549+
OS << tok::r_paren << "\n";
5550+
5551+
// return }
5552+
OS << tok::kw_return << "\n";
5553+
OS << tok::r_brace << "\n";
5554+
}
5555+
5556+
// If we have any success params that we need to unwrap, insert a guard.
5557+
for (auto Idx : indices(SuccessParamNames)) {
5558+
auto &Name = SuccessParamNames[Idx];
5559+
auto ParamTy = SuccessParams[Idx].getParameterType();
5560+
if (!HandlerDesc.shouldUnwrap(ParamTy))
5561+
continue;
5562+
5563+
// guard let res = res else {
5564+
OS << tok::kw_guard << " " << tok::kw_let << " ";
5565+
OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
5566+
OS << " " << tok::l_brace << "\n";
5567+
5568+
// fatalError(...)
5569+
OS << "fatalError" << tok::l_paren;
5570+
OS << "\"Expected non-nil success param '" << Name;
5571+
OS << "' for nil error\"";
5572+
OS << tok::r_paren << "\n";
5573+
5574+
// End guard.
5575+
OS << tok::r_brace << "\n";
5576+
}
5577+
5578+
// cont.resume(returning: (res1, res2, ...))
5579+
OS << ContName << tok::period << "resume" << tok::l_paren;
5580+
OS << "returning" << tok::colon << " ";
5581+
addTupleOf(llvm::makeArrayRef(SuccessParamNames), OS,
5582+
[&](auto Ref) { OS << Ref; });
5583+
OS << tok::r_paren << "\n";
5584+
break;
5585+
}
5586+
case HandlerType::RESULT: {
5587+
// cont.resume(with: res)
5588+
assert(SuccessParamNames.size() == 1);
5589+
OS << ContName << tok::period << "resume" << tok::l_paren;
5590+
OS << "with" << tok::colon << " " << SuccessParamNames[0];
5591+
OS << tok::r_paren << "\n";
5592+
break;
5593+
}
5594+
case HandlerType::INVALID:
5595+
llvm_unreachable("Should not have an invalid handler here");
5596+
}
5597+
5598+
OS << tok::r_brace << "\n"; // end closure
5599+
return OutputStr;
5600+
}
54495601

54505602
/// Retrieves the location for the start of a comment attached to the token
54515603
/// at the provided location, or the location itself if there is no comment.
@@ -6472,6 +6624,24 @@ class AsyncConverter : private SourceEntityWalker {
64726624
}
64736625
};
64746626

6627+
/// Adds an attribute to describe a completion handler function's async
6628+
/// alternative if necessary.
6629+
void addCompletionHandlerAsyncAttrIfNeccessary(
6630+
ASTContext &Ctx, const FuncDecl *FD,
6631+
const AsyncHandlerParamDesc &HandlerDesc,
6632+
SourceEditConsumer &EditConsumer) {
6633+
if (!Ctx.LangOpts.EnableExperimentalConcurrency)
6634+
return;
6635+
6636+
llvm::SmallString<0> HandlerAttribute;
6637+
llvm::raw_svector_ostream OS(HandlerAttribute);
6638+
OS << "@completionHandlerAsync(\"";
6639+
HandlerDesc.printAsyncFunctionName(OS);
6640+
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6641+
EditConsumer.accept(Ctx.SourceMgr, FD->getAttributeInsertionLoc(false),
6642+
HandlerAttribute);
6643+
}
6644+
64756645
} // namespace asyncrefactorings
64766646

64776647
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -6593,16 +6763,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65936763
"@available(*, deprecated, message: \"Prefer async "
65946764
"alternative instead\")\n");
65956765

6596-
if (Ctx.LangOpts.EnableExperimentalConcurrency) {
6597-
// Add an attribute to describe its async alternative
6598-
llvm::SmallString<0> HandlerAttribute;
6599-
llvm::raw_svector_ostream OS(HandlerAttribute);
6600-
OS << "@completionHandlerAsync(\"";
6601-
HandlerDesc.printAsyncFunctionName(OS);
6602-
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6603-
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
6604-
HandlerAttribute);
6605-
}
6766+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
66066767

66076768
AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
66086769
if (LegacyBodyCreator.createLegacyBody()) {
@@ -6614,6 +6775,43 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
66146775

66156776
return false;
66166777
}
6778+
6779+
bool RefactoringActionAddAsyncWrapper::isApplicable(
6780+
const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) {
6781+
using namespace asyncrefactorings;
6782+
6783+
auto *FD = findFunction(CursorInfo);
6784+
if (!FD)
6785+
return false;
6786+
6787+
auto HandlerDesc =
6788+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6789+
return HandlerDesc.isValid();
6790+
}
6791+
6792+
bool RefactoringActionAddAsyncWrapper::performChange() {
6793+
using namespace asyncrefactorings;
6794+
6795+
auto *FD = findFunction(CursorInfo);
6796+
assert(FD &&
6797+
"Should not run performChange when refactoring is not applicable");
6798+
6799+
auto HandlerDesc =
6800+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6801+
assert(HandlerDesc.isValid() &&
6802+
"Should not run performChange when refactoring is not applicable");
6803+
6804+
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
6805+
if (!Converter.createAsyncWrapper())
6806+
return true;
6807+
6808+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
6809+
6810+
// Add the async wrapper.
6811+
Converter.insertAfter(FD, EditConsumer);
6812+
return false;
6813+
}
6814+
66176815
} // end of anonymous namespace
66186816

66196817
StringRef swift::ide::

test/SourceKit/Refactoring/basic.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,16 +244,20 @@ func hasCallToAsyncAlternative() {
244244
// CHECK-ASYNC-NEXT: Convert Function to Async
245245
// CHECK-ASYNC-NEXT: source.refactoring.kind.add.async-alternative
246246
// CHECK-ASYNC-NEXT: Add Async Alternative
247+
// CHECK-ASYNC-NEXT: source.refactoring.kind.add.async-wrapper
248+
// CHECK-ASYNC-NEXT: Add Async Wrapper
247249
// CHECK-ASYNC-NOT: source.refactoring.kind.convert.call-to-async
248250
// CHECK-ASYNC: ACTIONS END
249251

250252
// CHECK-CALLASYNC: ACTIONS BEGIN
251253
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-alternative
252254
// CHECK-CALLASYNC-NOT: source.refactoring.kind.convert.func-to-async
255+
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-wrapper
253256
// CHECK-CALLASYNC: source.refactoring.kind.convert.call-to-async
254257
// CHECK-CALLASYNC-NEXT: Convert Call to Async Alternative
255258
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-alternative
256259
// CHECK-CALLASYNC-NOT: source.refactoring.kind.convert.func-to-async
260+
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-wrapper
257261
// CHECK-CALLASYNC: ACTIONS END
258262

259263
// REQUIRES: OS=macosx || OS=linux-gnu

0 commit comments

Comments
 (0)