@@ -4195,6 +4195,13 @@ struct AsyncHandlerDesc {
4195
4195
return params ();
4196
4196
}
4197
4197
4198
+ // / If the completion handler has an Error parameter, return it.
4199
+ Optional<AnyFunctionType::Param> getErrorParam () const {
4200
+ if (HasError && Type == HandlerType::PARAMS)
4201
+ return params ().back ();
4202
+ return None;
4203
+ }
4204
+
4198
4205
// / Get the type of the error that will be thrown by the \c async method or \c
4199
4206
// / None if the completion handler doesn't accept an error parameter.
4200
4207
// / This may be more specialized than the generic 'Error' type if the
@@ -5397,6 +5404,41 @@ class AsyncConverter : private SourceEntityWalker {
5397
5404
return true ;
5398
5405
}
5399
5406
5407
+ // / Creates an async alternative function that forwards onto the completion
5408
+ // / handler function through
5409
+ // / withCheckedContinuation/withCheckedThrowingContinuation.
5410
+ bool createAsyncWrapper () {
5411
+ assert (Buffer.empty () && " AsyncConverter can only be used once" );
5412
+ auto *FD = cast<FuncDecl>(StartNode.get <Decl *>());
5413
+
5414
+ // First add the new async function declaration.
5415
+ addFuncDecl (FD);
5416
+ OS << tok::l_brace << " \n " ;
5417
+
5418
+ // Then add the body.
5419
+ OS << tok::kw_return << " " ;
5420
+ if (TopHandler.HasError )
5421
+ OS << tok::kw_try << " " ;
5422
+
5423
+ OS << " await " ;
5424
+
5425
+ // withChecked[Throwing]Continuation { cont in
5426
+ if (TopHandler.HasError ) {
5427
+ OS << " withCheckedThrowingContinuation" ;
5428
+ } else {
5429
+ OS << " withCheckedContinuation" ;
5430
+ }
5431
+ OS << " " << tok::l_brace << " cont " << tok::kw_in << " \n " ;
5432
+
5433
+ // fnWithHandler(args...) { ... }
5434
+ auto ClosureStr = getAsyncWrapperCompletionClosure (" cont" , TopHandler);
5435
+ addForwardingCallTo (FD, TopHandler, /* HandlerReplacement*/ ClosureStr);
5436
+
5437
+ OS << tok::r_brace << " \n " ; // end continuation closure
5438
+ OS << tok::r_brace << " \n " ; // end function body
5439
+ return true ;
5440
+ }
5441
+
5400
5442
void replace (ASTNode Node, SourceEditConsumer &EditConsumer,
5401
5443
SourceLoc StartOverride = SourceLoc()) {
5402
5444
SourceRange Range = Node.getSourceRange ();
@@ -5446,6 +5488,116 @@ class AsyncConverter : private SourceEntityWalker {
5446
5488
OS << tok::r_paren;
5447
5489
}
5448
5490
5491
+ // / Retrieve the completion handler closure argument for an async wrapper
5492
+ // / function.
5493
+ std::string
5494
+ getAsyncWrapperCompletionClosure (StringRef ContName,
5495
+ const AsyncHandlerParamDesc &HandlerDesc) {
5496
+ std::string OutputStr;
5497
+ llvm::raw_string_ostream OS (OutputStr);
5498
+
5499
+ OS << " " << tok::l_brace; // start closure
5500
+
5501
+ // Prepare parameter names for the closure.
5502
+ auto SuccessParams = HandlerDesc.getSuccessParams ();
5503
+ SmallVector<SmallString<4 >, 2 > SuccessParamNames;
5504
+ for (auto idx : indices (SuccessParams)) {
5505
+ SuccessParamNames.emplace_back (" res" );
5506
+
5507
+ // If we have multiple success params, number them e.g res1, res2...
5508
+ if (SuccessParams.size () > 1 )
5509
+ SuccessParamNames.back ().append (std::to_string (idx + 1 ));
5510
+ }
5511
+ Optional<SmallString<4 >> ErrName;
5512
+ if (HandlerDesc.getErrorParam ())
5513
+ ErrName.emplace (" err" );
5514
+
5515
+ auto HasAnyParams = !SuccessParamNames.empty () || ErrName;
5516
+ if (HasAnyParams)
5517
+ OS << " " ;
5518
+
5519
+ // res1, res2
5520
+ llvm::interleave (
5521
+ SuccessParamNames, [&](auto Name) { OS << Name; },
5522
+ [&]() { OS << tok::comma << " " ; });
5523
+
5524
+ // , err
5525
+ if (ErrName) {
5526
+ if (!SuccessParamNames.empty ())
5527
+ OS << tok::comma << " " ;
5528
+
5529
+ OS << *ErrName;
5530
+ }
5531
+ if (HasAnyParams)
5532
+ OS << " " << tok::kw_in;
5533
+
5534
+ OS << " \n " ;
5535
+
5536
+ // The closure body.
5537
+ switch (HandlerDesc.Type ) {
5538
+ case HandlerType::PARAMS: {
5539
+ // For a (Success?, Error?) -> Void handler, we do an if let on the error.
5540
+ if (ErrName) {
5541
+ // if let err = err {
5542
+ OS << tok::kw_if << " " << tok::kw_let << " " ;
5543
+ OS << *ErrName << " " << tok::equal << " " << *ErrName << " " ;
5544
+ OS << tok::l_brace << " \n " ;
5545
+
5546
+ // cont.resume(throwing: err)
5547
+ OS << ContName << tok::period << " resume" << tok::l_paren;
5548
+ OS << " throwing" << tok::colon << " " << *ErrName;
5549
+ OS << tok::r_paren << " \n " ;
5550
+
5551
+ // return }
5552
+ OS << tok::kw_return << " \n " ;
5553
+ OS << tok::r_brace << " \n " ;
5554
+ }
5555
+
5556
+ // If we have any success params that we need to unwrap, insert a guard.
5557
+ for (auto Idx : indices (SuccessParamNames)) {
5558
+ auto &Name = SuccessParamNames[Idx];
5559
+ auto ParamTy = SuccessParams[Idx].getParameterType ();
5560
+ if (!HandlerDesc.shouldUnwrap (ParamTy))
5561
+ continue ;
5562
+
5563
+ // guard let res = res else {
5564
+ OS << tok::kw_guard << " " << tok::kw_let << " " ;
5565
+ OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
5566
+ OS << " " << tok::l_brace << " \n " ;
5567
+
5568
+ // fatalError(...)
5569
+ OS << " fatalError" << tok::l_paren;
5570
+ OS << " \" Expected non-nil success param '" << Name;
5571
+ OS << " ' for nil error\" " ;
5572
+ OS << tok::r_paren << " \n " ;
5573
+
5574
+ // End guard.
5575
+ OS << tok::r_brace << " \n " ;
5576
+ }
5577
+
5578
+ // cont.resume(returning: (res1, res2, ...))
5579
+ OS << ContName << tok::period << " resume" << tok::l_paren;
5580
+ OS << " returning" << tok::colon << " " ;
5581
+ addTupleOf (llvm::makeArrayRef (SuccessParamNames), OS,
5582
+ [&](auto Ref) { OS << Ref; });
5583
+ OS << tok::r_paren << " \n " ;
5584
+ break ;
5585
+ }
5586
+ case HandlerType::RESULT: {
5587
+ // cont.resume(with: res)
5588
+ assert (SuccessParamNames.size () == 1 );
5589
+ OS << ContName << tok::period << " resume" << tok::l_paren;
5590
+ OS << " with" << tok::colon << " " << SuccessParamNames[0 ];
5591
+ OS << tok::r_paren << " \n " ;
5592
+ break ;
5593
+ }
5594
+ case HandlerType::INVALID:
5595
+ llvm_unreachable (" Should not have an invalid handler here" );
5596
+ }
5597
+
5598
+ OS << tok::r_brace << " \n " ; // end closure
5599
+ return OutputStr;
5600
+ }
5449
5601
5450
5602
// / Retrieves the location for the start of a comment attached to the token
5451
5603
// / at the provided location, or the location itself if there is no comment.
@@ -6472,6 +6624,24 @@ class AsyncConverter : private SourceEntityWalker {
6472
6624
}
6473
6625
};
6474
6626
6627
+ // / Adds an attribute to describe a completion handler function's async
6628
+ // / alternative if necessary.
6629
+ void addCompletionHandlerAsyncAttrIfNeccessary (
6630
+ ASTContext &Ctx, const FuncDecl *FD,
6631
+ const AsyncHandlerParamDesc &HandlerDesc,
6632
+ SourceEditConsumer &EditConsumer) {
6633
+ if (!Ctx.LangOpts .EnableExperimentalConcurrency )
6634
+ return ;
6635
+
6636
+ llvm::SmallString<0 > HandlerAttribute;
6637
+ llvm::raw_svector_ostream OS (HandlerAttribute);
6638
+ OS << " @completionHandlerAsync(\" " ;
6639
+ HandlerDesc.printAsyncFunctionName (OS);
6640
+ OS << " \" , completionHandlerIndex: " << HandlerDesc.Index << " )\n " ;
6641
+ EditConsumer.accept (Ctx.SourceMgr , FD->getAttributeInsertionLoc (false ),
6642
+ HandlerAttribute);
6643
+ }
6644
+
6475
6645
} // namespace asyncrefactorings
6476
6646
6477
6647
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -6593,16 +6763,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
6593
6763
" @available(*, deprecated, message: \" Prefer async "
6594
6764
" alternative instead\" )\n " );
6595
6765
6596
- if (Ctx.LangOpts .EnableExperimentalConcurrency ) {
6597
- // Add an attribute to describe its async alternative
6598
- llvm::SmallString<0 > HandlerAttribute;
6599
- llvm::raw_svector_ostream OS (HandlerAttribute);
6600
- OS << " @completionHandlerAsync(\" " ;
6601
- HandlerDesc.printAsyncFunctionName (OS);
6602
- OS << " \" , completionHandlerIndex: " << HandlerDesc.Index << " )\n " ;
6603
- EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
6604
- HandlerAttribute);
6605
- }
6766
+ addCompletionHandlerAsyncAttrIfNeccessary (Ctx, FD, HandlerDesc, EditConsumer);
6606
6767
6607
6768
AsyncConverter LegacyBodyCreator (TheFile, SM, DiagEngine, FD, HandlerDesc);
6608
6769
if (LegacyBodyCreator.createLegacyBody ()) {
@@ -6614,6 +6775,43 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
6614
6775
6615
6776
return false ;
6616
6777
}
6778
+
6779
+ bool RefactoringActionAddAsyncWrapper::isApplicable (
6780
+ const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) {
6781
+ using namespace asyncrefactorings ;
6782
+
6783
+ auto *FD = findFunction (CursorInfo);
6784
+ if (!FD)
6785
+ return false ;
6786
+
6787
+ auto HandlerDesc =
6788
+ AsyncHandlerParamDesc::find (FD, /* RequireAttributeOrName=*/ false );
6789
+ return HandlerDesc.isValid ();
6790
+ }
6791
+
6792
+ bool RefactoringActionAddAsyncWrapper::performChange () {
6793
+ using namespace asyncrefactorings ;
6794
+
6795
+ auto *FD = findFunction (CursorInfo);
6796
+ assert (FD &&
6797
+ " Should not run performChange when refactoring is not applicable" );
6798
+
6799
+ auto HandlerDesc =
6800
+ AsyncHandlerParamDesc::find (FD, /* RequireAttributeOrName=*/ false );
6801
+ assert (HandlerDesc.isValid () &&
6802
+ " Should not run performChange when refactoring is not applicable" );
6803
+
6804
+ AsyncConverter Converter (TheFile, SM, DiagEngine, FD, HandlerDesc);
6805
+ if (!Converter.createAsyncWrapper ())
6806
+ return true ;
6807
+
6808
+ addCompletionHandlerAsyncAttrIfNeccessary (Ctx, FD, HandlerDesc, EditConsumer);
6809
+
6810
+ // Add the async wrapper.
6811
+ Converter.insertAfter (FD, EditConsumer);
6812
+ return false ;
6813
+ }
6814
+
6617
6815
} // end of anonymous namespace
6618
6816
6619
6817
StringRef swift::ide::
0 commit comments