Skip to content

Commit 177ab6e

Browse files
committed
[SourceKit] Support cancellation of code completion like requests
Essentially, just wire up cancellation tokens and cancellation flags for `CompletionInstance` and make sure to return `CancellableResult::cancelled()` when cancellation is detected. rdar://83391488
1 parent d482ee3 commit 177ab6e

File tree

11 files changed

+280
-146
lines changed

11 files changed

+280
-146
lines changed

include/swift/IDE/CompletionInstance.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class CompletionInstance {
106106
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
107107
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
108108
DiagnosticConsumer *DiagC,
109+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
109110
llvm::function_ref<void(CancellableResult<CompletionInstanceResult>)>
110111
Callback);
111112

@@ -119,6 +120,7 @@ class CompletionInstance {
119120
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
120121
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
121122
DiagnosticConsumer *DiagC,
123+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
122124
llvm::function_ref<void(CancellableResult<CompletionInstanceResult>)>
123125
Callback);
124126

@@ -129,6 +131,14 @@ class CompletionInstance {
129131
/// In case of failure or cancellation, the callback receives the
130132
/// corresponding failed or cancelled result.
131133
///
134+
/// If \p CancellationFlag is not \c nullptr, code completion can be cancelled
135+
/// by setting the flag to \c true.
136+
/// IMPORTANT: If \p CancellationFlag is not \c nullptr, then completion might
137+
/// be cancelled in the secon pass that's invoked inside \p Callback.
138+
/// Therefore, \p Callback MUST check whether completion was cancelled before
139+
/// interpreting the results, since invalid results may be returned in case
140+
/// of cancellation.
141+
///
132142
/// NOTE: \p Args is only used for checking the equaity of the invocation.
133143
/// Since this function assumes that it is already normalized, exact the same
134144
/// arguments including their order is considered as the same invocation.
@@ -137,6 +147,7 @@ class CompletionInstance {
137147
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
138148
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
139149
DiagnosticConsumer *DiagC,
150+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
140151
llvm::function_ref<void(CancellableResult<CompletionInstanceResult>)>
141152
Callback);
142153

@@ -155,13 +166,15 @@ class CompletionInstance {
155166
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
156167
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
157168
DiagnosticConsumer *DiagC, ide::CodeCompletionContext &CompletionContext,
169+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
158170
llvm::function_ref<void(CancellableResult<CodeCompleteResult>)> Callback);
159171

160172
void typeContextInfo(
161173
swift::CompilerInvocation &Invocation, llvm::ArrayRef<const char *> Args,
162174
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
163175
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
164176
DiagnosticConsumer *DiagC,
177+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
165178
llvm::function_ref<void(CancellableResult<TypeContextInfoResult>)>
166179
Callback);
167180

@@ -170,6 +183,7 @@ class CompletionInstance {
170183
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
171184
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
172185
DiagnosticConsumer *DiagC, ArrayRef<const char *> ExpectedTypeNames,
186+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
173187
llvm::function_ref<void(CancellableResult<ConformingMethodListResults>)>
174188
Callback);
175189
};

lib/IDE/CompletionInstance.cpp

Lines changed: 105 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ bool CompletionInstance::performCachedOperationIfPossible(
286286
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
287287
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
288288
DiagnosticConsumer *DiagC,
289+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
289290
llvm::function_ref<void(CancellableResult<CompletionInstanceResult>)>
290291
Callback) {
291292
llvm::PrettyStackTraceString trace(
@@ -338,6 +339,7 @@ bool CompletionInstance::performCachedOperationIfPossible(
338339
std::unique_ptr<ASTContext> tmpCtx(
339340
ASTContext::get(langOpts, typeckOpts, silOpts, searchPathOpts, clangOpts,
340341
symbolOpts, tmpSM, tmpDiags));
342+
tmpCtx->CancellationFlag = CancellationFlag;
341343
registerParseRequestFunctions(tmpCtx->evaluator);
342344
registerIDERequestFunctions(tmpCtx->evaluator);
343345
registerTypeCheckerRequestFunctions(tmpCtx->evaluator);
@@ -502,12 +504,18 @@ bool CompletionInstance::performCachedOperationIfPossible(
502504
// The diagnostic engine is keeping track of state which might modify
503505
// parsing and type checking behaviour. Clear the flags.
504506
CI.getDiags().resetHadAnyError();
507+
CI.getASTContext().CancellationFlag = CancellationFlag;
505508

506509
if (DiagC)
507510
CI.addDiagnosticConsumer(DiagC);
508511

509-
Callback(CancellableResult<CompletionInstanceResult>::success(
510-
{CI, /*reusingASTContext=*/true, /*DidFindCodeCompletionToken=*/true}));
512+
if (CancellationFlag && CancellationFlag->load(std::memory_order_relaxed)) {
513+
Callback(CancellableResult<CompletionInstanceResult>::cancelled());
514+
} else {
515+
Callback(CancellableResult<CompletionInstanceResult>::success(
516+
{CI, /*reusingASTContext=*/true,
517+
/*DidFindCodeCompletionToken=*/true}));
518+
}
511519

512520
if (DiagC)
513521
CI.removeDiagnosticConsumer(DiagC);
@@ -523,11 +531,13 @@ void CompletionInstance::performNewOperation(
523531
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
524532
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
525533
DiagnosticConsumer *DiagC,
534+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
526535
llvm::function_ref<void(CancellableResult<CompletionInstanceResult>)>
527536
Callback) {
528537
llvm::PrettyStackTraceString trace("While performing new completion");
529538

530-
auto isCachedCompletionRequested = ArgsHash.hasValue();
539+
// If ArgsHash is None we shouldn't cache the compiler instance.
540+
bool ShouldCacheCompilerInstance = ArgsHash.hasValue();
531541

532542
auto TheInstance = std::make_unique<CompilerInstance>();
533543

@@ -536,7 +546,6 @@ void CompletionInstance::performNewOperation(
536546
Invocation.getFrontendOptions().IntermoduleDependencyTracking =
537547
IntermoduleDepTrackingMode::ExcludeSystem;
538548

539-
bool DidFindCodeCompletionToken = false;
540549
{
541550
auto &CI = *TheInstance;
542551
if (DiagC)
@@ -557,6 +566,7 @@ void CompletionInstance::performNewOperation(
557566
"failed to setup compiler instance"));
558567
return;
559568
}
569+
CI.getASTContext().CancellationFlag = CancellationFlag;
560570
registerIDERequestFunctions(CI.getASTContext().evaluator);
561571

562572
// If we're expecting a standard library, but there either isn't one, or it
@@ -570,19 +580,32 @@ void CompletionInstance::performNewOperation(
570580

571581
CI.performParseAndResolveImportsOnly();
572582

573-
DidFindCodeCompletionToken = CI.getCodeCompletionFile()
574-
->getDelayedParserState()
575-
->hasCodeCompletionDelayedDeclState();
576-
577-
Callback(CancellableResult<CompletionInstanceResult>::success(
578-
{CI, /*ReuisingASTContext=*/false, DidFindCodeCompletionToken}));
583+
bool DidFindCodeCompletionToken = CI.getCodeCompletionFile()
584+
->getDelayedParserState()
585+
->hasCodeCompletionDelayedDeclState();
586+
ShouldCacheCompilerInstance &= DidFindCodeCompletionToken;
587+
588+
auto CancellationFlag = CI.getASTContext().CancellationFlag;
589+
if (CancellationFlag && CancellationFlag->load(std::memory_order_relaxed)) {
590+
Callback(CancellableResult<CompletionInstanceResult>::cancelled());
591+
// The completion instance may be in an invalid state when it's been
592+
// cancelled. Don't cache it.
593+
ShouldCacheCompilerInstance = false;
594+
} else {
595+
Callback(CancellableResult<CompletionInstanceResult>::success(
596+
{CI, /*ReuisingASTContext=*/false, DidFindCodeCompletionToken}));
597+
if (CancellationFlag &&
598+
CancellationFlag->load(std::memory_order_relaxed)) {
599+
ShouldCacheCompilerInstance = false;
600+
}
601+
}
579602
}
580603

581604
// Cache the compiler instance if fast completion is enabled.
582605
// If we didn't find a code compleiton token, we can't cache the instance
583606
// because performCachedOperationIfPossible wouldn't have an old code
584607
// completion state to compare the new one to.
585-
if (isCachedCompletionRequested && DidFindCodeCompletionToken)
608+
if (ShouldCacheCompilerInstance)
586609
cacheCompilerInstance(std::move(TheInstance), *ArgsHash);
587610
}
588611

@@ -623,20 +646,9 @@ void swift::ide::CompletionInstance::performOperation(
623646
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
624647
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
625648
DiagnosticConsumer *DiagC,
649+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
626650
llvm::function_ref<void(CancellableResult<CompletionInstanceResult>)>
627651
Callback) {
628-
629-
// Always disable source location resolutions from .swiftsourceinfo file
630-
// because they're somewhat heavy operations and aren't needed for completion.
631-
Invocation.getFrontendOptions().IgnoreSwiftSourceInfo = true;
632-
633-
// Disable to build syntax tree because code-completion skips some portion of
634-
// source text. That breaks an invariant of syntax tree building.
635-
Invocation.getLangOptions().BuildSyntaxTree = false;
636-
637-
// We don't need token list.
638-
Invocation.getLangOptions().CollectParsedToken = false;
639-
640652
// Compute the signature of the invocation.
641653
llvm::hash_code ArgsHash(0);
642654
for (auto arg : Args)
@@ -647,32 +659,48 @@ void swift::ide::CompletionInstance::performOperation(
647659
std::lock_guard<std::mutex> lock(mtx);
648660

649661
if (performCachedOperationIfPossible(ArgsHash, FileSystem, completionBuffer,
650-
Offset, DiagC, Callback)) {
662+
Offset, DiagC, CancellationFlag,
663+
Callback)) {
651664
// We were able to reuse a cached AST. Callback has already been invoked
652665
// and we don't need to build a new AST. We are done.
653666
return;
654667
}
655668

669+
// Always disable source location resolutions from .swiftsourceinfo file
670+
// because they're somewhat heavy operations and aren't needed for completion.
671+
Invocation.getFrontendOptions().IgnoreSwiftSourceInfo = true;
672+
673+
// Disable to build syntax tree because code-completion skips some portion of
674+
// source text. That breaks an invariant of syntax tree building.
675+
Invocation.getLangOptions().BuildSyntaxTree = false;
676+
677+
// We don't need token list.
678+
Invocation.getLangOptions().CollectParsedToken = false;
679+
656680
performNewOperation(ArgsHash, Invocation, FileSystem, completionBuffer,
657-
Offset, DiagC, Callback);
681+
Offset, DiagC, CancellationFlag, Callback);
658682
}
659683

660684
void swift::ide::CompletionInstance::codeComplete(
661685
swift::CompilerInvocation &Invocation, llvm::ArrayRef<const char *> Args,
662686
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
663687
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
664688
DiagnosticConsumer *DiagC, ide::CodeCompletionContext &CompletionContext,
689+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
665690
llvm::function_ref<void(CancellableResult<CodeCompleteResult>)> Callback) {
666691
using ResultType = CancellableResult<CodeCompleteResult>;
667692

668693
struct ConsumerToCallbackAdapter
669694
: public SimpleCachingCodeCompletionConsumer {
670695
SwiftCompletionInfo SwiftContext;
696+
std::shared_ptr<std::atomic<bool>> CancellationFlag;
671697
llvm::function_ref<void(ResultType)> Callback;
672698
bool HandleResultsCalled = false;
673699

674-
ConsumerToCallbackAdapter(llvm::function_ref<void(ResultType)> Callback)
675-
: Callback(Callback) {}
700+
ConsumerToCallbackAdapter(
701+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
702+
llvm::function_ref<void(ResultType)> Callback)
703+
: CancellationFlag(CancellationFlag), Callback(Callback) {}
676704

677705
void setContext(swift::ASTContext *context,
678706
const swift::CompilerInvocation *invocation,
@@ -685,20 +713,28 @@ void swift::ide::CompletionInstance::codeComplete(
685713

686714
void handleResults(CodeCompletionContext &context) override {
687715
HandleResultsCalled = true;
688-
MutableArrayRef<CodeCompletionResult *> Results = context.takeResults();
689-
assert(SwiftContext.swiftASTContext);
690-
Callback(ResultType::success({Results, SwiftContext}));
716+
if (CancellationFlag &&
717+
CancellationFlag->load(std::memory_order_relaxed)) {
718+
Callback(ResultType::cancelled());
719+
} else {
720+
MutableArrayRef<CodeCompletionResult *> Results = context.takeResults();
721+
assert(SwiftContext.swiftASTContext);
722+
Callback(ResultType::success({Results, SwiftContext}));
723+
}
691724
}
692725
};
693726

694727
performOperation(
695728
Invocation, Args, FileSystem, completionBuffer, Offset, DiagC,
729+
CancellationFlag,
696730
[&](CancellableResult<CompletionInstanceResult> CIResult) {
697731
CIResult.mapAsync<CodeCompleteResult>(
698-
[&CompletionContext](auto &Result, auto DeliverTransformed) {
732+
[&CompletionContext, &CancellationFlag](auto &Result,
733+
auto DeliverTransformed) {
699734
CompletionContext.ReusingASTContext = Result.DidReuseAST;
700735
CompilerInstance &CI = Result.CI;
701-
ConsumerToCallbackAdapter Consumer(DeliverTransformed);
736+
ConsumerToCallbackAdapter Consumer(CancellationFlag,
737+
DeliverTransformed);
702738

703739
std::unique_ptr<CodeCompletionCallbacksFactory> callbacksFactory(
704740
ide::makeCodeCompletionCallbacksFactory(CompletionContext,
@@ -737,32 +773,43 @@ void swift::ide::CompletionInstance::typeContextInfo(
737773
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
738774
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
739775
DiagnosticConsumer *DiagC,
776+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
740777
llvm::function_ref<void(CancellableResult<TypeContextInfoResult>)>
741778
Callback) {
742779
using ResultType = CancellableResult<TypeContextInfoResult>;
743780

744781
struct ConsumerToCallbackAdapter : public ide::TypeContextInfoConsumer {
745782
bool ReusingASTContext;
783+
std::shared_ptr<std::atomic<bool>> CancellationFlag;
746784
llvm::function_ref<void(ResultType)> Callback;
747785
bool HandleResultsCalled = false;
748786

749-
ConsumerToCallbackAdapter(bool ReusingASTContext,
750-
llvm::function_ref<void(ResultType)> Callback)
751-
: ReusingASTContext(ReusingASTContext), Callback(Callback) {}
787+
ConsumerToCallbackAdapter(
788+
bool ReusingASTContext,
789+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
790+
llvm::function_ref<void(ResultType)> Callback)
791+
: ReusingASTContext(ReusingASTContext),
792+
CancellationFlag(CancellationFlag), Callback(Callback) {}
752793

753794
void handleResults(ArrayRef<ide::TypeContextInfoItem> Results) override {
754795
HandleResultsCalled = true;
755-
Callback(ResultType::success({Results, ReusingASTContext}));
796+
if (CancellationFlag &&
797+
CancellationFlag->load(std::memory_order_relaxed)) {
798+
Callback(ResultType::cancelled());
799+
} else {
800+
Callback(ResultType::success({Results, ReusingASTContext}));
801+
}
756802
}
757803
};
758804

759805
performOperation(
760806
Invocation, Args, FileSystem, completionBuffer, Offset, DiagC,
807+
CancellationFlag,
761808
[&](CancellableResult<CompletionInstanceResult> CIResult) {
762809
CIResult.mapAsync<TypeContextInfoResult>(
763-
[](auto &Result, auto DeliverTransformed) {
764-
ConsumerToCallbackAdapter Consumer(Result.DidReuseAST,
765-
DeliverTransformed);
810+
[&CancellationFlag](auto &Result, auto DeliverTransformed) {
811+
ConsumerToCallbackAdapter Consumer(
812+
Result.DidReuseAST, CancellationFlag, DeliverTransformed);
766813
std::unique_ptr<CodeCompletionCallbacksFactory> callbacksFactory(
767814
ide::makeTypeContextInfoCallbacksFactory(Consumer));
768815

@@ -793,33 +840,45 @@ void swift::ide::CompletionInstance::conformingMethodList(
793840
llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
794841
llvm::MemoryBuffer *completionBuffer, unsigned int Offset,
795842
DiagnosticConsumer *DiagC, ArrayRef<const char *> ExpectedTypeNames,
843+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
796844
llvm::function_ref<void(CancellableResult<ConformingMethodListResults>)>
797845
Callback) {
798846
using ResultType = CancellableResult<ConformingMethodListResults>;
799847

800848
struct ConsumerToCallbackAdapter
801849
: public swift::ide::ConformingMethodListConsumer {
802850
bool ReusingASTContext;
851+
std::shared_ptr<std::atomic<bool>> CancellationFlag;
803852
llvm::function_ref<void(ResultType)> Callback;
804853
bool HandleResultsCalled = false;
805854

806-
ConsumerToCallbackAdapter(bool ReusingASTContext,
807-
llvm::function_ref<void(ResultType)> Callback)
808-
: ReusingASTContext(ReusingASTContext), Callback(Callback) {}
855+
ConsumerToCallbackAdapter(
856+
bool ReusingASTContext,
857+
std::shared_ptr<std::atomic<bool>> CancellationFlag,
858+
llvm::function_ref<void(ResultType)> Callback)
859+
: ReusingASTContext(ReusingASTContext),
860+
CancellationFlag(CancellationFlag), Callback(Callback) {}
809861

810862
void handleResult(const ide::ConformingMethodListResult &result) override {
811863
HandleResultsCalled = true;
812-
Callback(ResultType::success({&result, ReusingASTContext}));
864+
if (CancellationFlag &&
865+
CancellationFlag->load(std::memory_order_relaxed)) {
866+
Callback(ResultType::cancelled());
867+
} else {
868+
Callback(ResultType::success({&result, ReusingASTContext}));
869+
}
813870
}
814871
};
815872

816873
performOperation(
817874
Invocation, Args, FileSystem, completionBuffer, Offset, DiagC,
875+
CancellationFlag,
818876
[&](CancellableResult<CompletionInstanceResult> CIResult) {
819877
CIResult.mapAsync<ConformingMethodListResults>(
820-
[&ExpectedTypeNames](auto &Result, auto DeliverTransformed) {
821-
ConsumerToCallbackAdapter Consumer(Result.DidReuseAST,
822-
DeliverTransformed);
878+
[&ExpectedTypeNames, &CancellationFlag](auto &Result,
879+
auto DeliverTransformed) {
880+
ConsumerToCallbackAdapter Consumer(
881+
Result.DidReuseAST, CancellationFlag, DeliverTransformed);
823882
std::unique_ptr<CodeCompletionCallbacksFactory> callbacksFactory(
824883
ide::makeConformingMethodListCallbacksFactory(
825884
ExpectedTypeNames, Consumer));

0 commit comments

Comments
 (0)