Skip to content

Commit 34d3df5

Browse files
committed
[SourceKit] Cancel in-flight builds on editor.close
When closing a document, cancel any in-flight builds happening for it. rdar://127126348
1 parent 06e4733 commit 34d3df5

File tree

6 files changed

+243
-2
lines changed

6 files changed

+243
-2
lines changed

tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,9 @@ class ASTBuildOperation
451451
/// consumer, removes it from the \c Consumers severed by this build operation
452452
/// and, if no consumers are left, cancels the AST build of this operation.
453453
void requestConsumerCancellation(SwiftASTConsumerRef Consumer);
454+
455+
/// Cancels all consumers for the given operation.
456+
void cancelAllConsumers();
454457
};
455458

456459
using ASTBuildOperationRef = std::shared_ptr<ASTBuildOperation>;
@@ -517,6 +520,9 @@ class ASTProducer : public std::enable_shared_from_this<ASTProducer> {
517520
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,
518521
SwiftASTManagerRef Mgr);
519522

523+
/// Cancel all currently running build operations.
524+
void cancelAllBuilds();
525+
520526
size_t getMemoryCost() const {
521527
size_t Cost = sizeof(*this);
522528
for (auto &BuildOp : BuildOperations) {
@@ -848,6 +854,14 @@ void SwiftASTManager::removeCachedAST(SwiftInvocationRef Invok) {
848854
Impl.ASTCache.remove(Invok->Impl.Key);
849855
}
850856

857+
void SwiftASTManager::cancelBuildsForCachedAST(SwiftInvocationRef Invok) {
858+
auto Result = Impl.getASTProducer(Invok);
859+
if (!Result)
860+
return;
861+
862+
(*Result)->cancelAllBuilds();
863+
}
864+
851865
ASTProducerRef SwiftASTManager::Implementation::getOrCreateASTProducer(
852866
SwiftInvocationRef InvokRef) {
853867
llvm::sys::ScopedLock L(CacheMtx);
@@ -1006,6 +1020,24 @@ void ASTBuildOperation::requestConsumerCancellation(
10061020
});
10071021
}
10081022

1023+
void ASTBuildOperation::cancelAllConsumers() {
1024+
if (isFinished())
1025+
return;
1026+
1027+
llvm::sys::ScopedLock L(ConsumersAndResultMtx);
1028+
CancellationFlag->store(true, std::memory_order_relaxed);
1029+
1030+
// Take the consumers, and notify them of the cancellation.
1031+
decltype(this->Consumers) Consumers;
1032+
std::swap(Consumers, this->Consumers);
1033+
1034+
ASTManager->Impl.ConsumerNotificationQueue.dispatch(
1035+
[Consumers = std::move(Consumers)] {
1036+
for (auto &Consumer : Consumers)
1037+
Consumer->cancelled();
1038+
});
1039+
}
1040+
10091041
static void collectModuleDependencies(ModuleDecl *TopMod,
10101042
llvm::SmallPtrSetImpl<ModuleDecl *> &Visited,
10111043
SmallVectorImpl<std::string> &Filenames) {
@@ -1330,6 +1362,15 @@ ASTBuildOperationRef ASTProducer::getBuildOperationForConsumer(
13301362
return LatestUsableOp;
13311363
}
13321364

1365+
void ASTProducer::cancelAllBuilds() {
1366+
// Cancel all build operations, cleanup will happen when each operation
1367+
// terminates.
1368+
BuildOperationsQueue.dispatch([This = shared_from_this()] {
1369+
for (auto &BuildOp : This->BuildOperations)
1370+
BuildOp->cancelAllConsumers();
1371+
});
1372+
}
1373+
13331374
void ASTProducer::enqueueConsumer(
13341375
SwiftASTConsumerRef Consumer,
13351376
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FileSystem,

tools/SourceKit/lib/SwiftLang/SwiftASTManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ class SwiftASTManager : public std::enable_shared_from_this<SwiftASTManager> {
313313
bool AllowInputs = true);
314314

315315
void removeCachedAST(SwiftInvocationRef Invok);
316+
void cancelBuildsForCachedAST(SwiftInvocationRef Invok);
316317

317318
struct Implementation;
318319
Implementation &Impl;

tools/SourceKit/lib/SwiftLang/SwiftEditor.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,13 @@ class SwiftDocumentSemanticInfo :
707707
}
708708
}
709709

710+
void cancelBuildsForCachedAST() {
711+
if (!InvokRef)
712+
return;
713+
if (auto ASTMgr = this->ASTMgr.lock())
714+
ASTMgr->cancelBuildsForCachedAST(InvokRef);
715+
}
716+
710717
private:
711718
std::vector<SwiftSemanticToken> takeSemanticTokens(
712719
ImmutableTextSnapshotRef NewSnapshot);
@@ -2174,6 +2181,10 @@ void SwiftEditorDocument::removeCachedAST() {
21742181
Impl.SemanticInfo->removeCachedAST();
21752182
}
21762183

2184+
void SwiftEditorDocument::cancelBuildsForCachedAST() {
2185+
Impl.SemanticInfo->cancelBuildsForCachedAST();
2186+
}
2187+
21772188
void SwiftEditorDocument::applyFormatOptions(OptionsDictionary &FmtOptions) {
21782189
static UIdent KeyUseTabs("key.editor.format.usetabs");
21792190
static UIdent KeyIndentWidth("key.editor.format.indentwidth");
@@ -2444,8 +2455,14 @@ void SwiftLangSupport::editorClose(StringRef Name, bool RemoveCache) {
24442455
IFaceGenContexts.remove(Name);
24452456
}
24462457

2447-
if (Removed && RemoveCache)
2448-
Removed->removeCachedAST();
2458+
if (Removed) {
2459+
// Cancel any in-flight builds for the given AST.
2460+
Removed->cancelBuildsForCachedAST();
2461+
2462+
// Then remove the cached AST if we've been asked to do so.
2463+
if (RemoveCache)
2464+
Removed->removeCachedAST();
2465+
}
24492466
// FIXME: Report error if Name did not apply to anything ?
24502467
}
24512468

tools/SourceKit/lib/SwiftLang/SwiftLangSupport.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class SwiftEditorDocument :
107107
void updateSemaInfo(SourceKitCancellationToken CancellationToken);
108108

109109
void removeCachedAST();
110+
void cancelBuildsForCachedAST();
110111

111112
ImmutableTextSnapshotRef getLatestSnapshot() const;
112113

unittests/SourceKit/SwiftLang/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
if(NOT SWIFT_HOST_VARIANT MATCHES "${SWIFT_DARWIN_EMBEDDED_VARIANTS}")
22
add_swift_unittest(SourceKitSwiftLangTests
33
CursorInfoTest.cpp
4+
CloseTest.cpp
45
EditingTest.cpp
56
)
67
target_link_libraries(SourceKitSwiftLangTests PRIVATE SourceKitSwiftLang)
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2024 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "NullEditorConsumer.h"
14+
#include "SourceKit/Core/Context.h"
15+
#include "SourceKit/Core/LangSupport.h"
16+
#include "SourceKit/Core/NotificationCenter.h"
17+
#include "SourceKit/Support/Concurrency.h"
18+
#include "SourceKit/SwiftLang/Factory.h"
19+
#include "swift/Basic/LLVMInitialize.h"
20+
#include "gtest/gtest.h"
21+
22+
#include <chrono>
23+
#include <condition_variable>
24+
#include <mutex>
25+
#include <thread>
26+
27+
using namespace SourceKit;
28+
using namespace llvm;
29+
30+
static StringRef getRuntimeLibPath() {
31+
return sys::path::parent_path(SWIFTLIB_DIR);
32+
}
33+
34+
static SmallString<128> getSwiftExecutablePath() {
35+
SmallString<128> path = sys::path::parent_path(getRuntimeLibPath());
36+
sys::path::append(path, "bin", "swift-frontend");
37+
return path;
38+
}
39+
40+
namespace {
41+
class CompileTrackingConsumer final : public trace::TraceConsumer {
42+
std::mutex Mtx;
43+
std::condition_variable CV;
44+
bool HasStarted = false;
45+
46+
public:
47+
CompileTrackingConsumer() {}
48+
CompileTrackingConsumer(const CompileTrackingConsumer &) = delete;
49+
50+
void operationStarted(uint64_t OpId, trace::OperationKind OpKind,
51+
const trace::SwiftInvocation &Inv,
52+
const trace::StringPairs &OpArgs) override {
53+
std::unique_lock<std::mutex> lk(Mtx);
54+
HasStarted = true;
55+
CV.notify_all();
56+
}
57+
58+
void waitForBuildToStart() {
59+
std::unique_lock<std::mutex> lk(Mtx);
60+
auto secondsToWait = std::chrono::seconds(20);
61+
auto when = std::chrono::system_clock::now() + secondsToWait;
62+
CV.wait_until(lk, when, [&]() { return HasStarted; });
63+
HasStarted = false;
64+
}
65+
66+
void operationFinished(uint64_t OpId, trace::OperationKind OpKind,
67+
ArrayRef<DiagnosticEntryInfo> Diagnostics) override {}
68+
69+
swift::OptionSet<trace::OperationKind> desiredOperations() override {
70+
return trace::OperationKind::PerformSema;
71+
}
72+
};
73+
74+
class CloseTest : public ::testing::Test {
75+
std::shared_ptr<SourceKit::Context> Ctx;
76+
std::shared_ptr<CompileTrackingConsumer> CompileTracker;
77+
78+
NullEditorConsumer Consumer;
79+
80+
public:
81+
CloseTest() {
82+
INITIALIZE_LLVM();
83+
Ctx = std::make_shared<SourceKit::Context>(
84+
getSwiftExecutablePath(), getRuntimeLibPath(),
85+
/*diagnosticDocumentationPath*/ "", SourceKit::createSwiftLangSupport,
86+
/*dispatchOnMain=*/false);
87+
}
88+
89+
CompileTrackingConsumer &getCompileTracker() const { return *CompileTracker; }
90+
LangSupport &getLang() { return Ctx->getSwiftLangSupport(); }
91+
92+
void SetUp() override {
93+
CompileTracker = std::make_shared<CompileTrackingConsumer>();
94+
trace::registerConsumer(CompileTracker.get());
95+
}
96+
97+
void TearDown() override {
98+
trace::unregisterConsumer(CompileTracker.get());
99+
CompileTracker = nullptr;
100+
}
101+
102+
void open(const char *DocName, StringRef Text, ArrayRef<const char *> CArgs) {
103+
auto Args = makeArgs(DocName, CArgs);
104+
auto Buf = MemoryBuffer::getMemBufferCopy(Text, DocName);
105+
getLang().editorOpen(DocName, Buf.get(), Consumer, Args, std::nullopt);
106+
}
107+
108+
void close(const char *DocName, bool removeCache) {
109+
getLang().editorClose(DocName, removeCache);
110+
}
111+
112+
void getDiagnosticsAsync(
113+
const char *DocName, ArrayRef<const char *> CArgs,
114+
llvm::function_ref<void(const RequestResult<DiagnosticsResult> &)>
115+
callback) {
116+
auto Args = makeArgs(DocName, CArgs);
117+
getLang().getDiagnostics(DocName, Args, /*VFSOpts*/ std::nullopt,
118+
/*CancelToken*/ {}, callback);
119+
}
120+
121+
private:
122+
std::vector<const char *> makeArgs(const char *DocName,
123+
ArrayRef<const char *> CArgs) {
124+
std::vector<const char *> Args = CArgs;
125+
Args.push_back(DocName);
126+
return Args;
127+
}
128+
};
129+
130+
} // end anonymous namespace
131+
132+
static const char *getComplexSourceText() {
133+
// best of luck, type-checker
134+
return
135+
"struct A: ExpressibleByIntegerLiteral { init(integerLiteral value: Int) {} }\n"
136+
"struct B: ExpressibleByIntegerLiteral { init(integerLiteral value: Int) {} }\n"
137+
"struct C: ExpressibleByIntegerLiteral { init(integerLiteral value: Int) {} }\n"
138+
139+
"func + (lhs: A, rhs: B) -> A { fatalError() }\n"
140+
"func + (lhs: B, rhs: C) -> A { fatalError() }\n"
141+
"func + (lhs: C, rhs: A) -> A { fatalError() }\n"
142+
143+
"func + (lhs: B, rhs: A) -> B { fatalError() }\n"
144+
"func + (lhs: C, rhs: B) -> B { fatalError() }\n"
145+
"func + (lhs: A, rhs: C) -> B { fatalError() }\n"
146+
147+
"func + (lhs: C, rhs: B) -> C { fatalError() }\n"
148+
"func + (lhs: B, rhs: C) -> C { fatalError() }\n"
149+
"func + (lhs: A, rhs: A) -> C { fatalError() }\n"
150+
151+
"let x: C = 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8\n";
152+
}
153+
154+
TEST_F(CloseTest, Cancel) {
155+
const char *DocName = "test.swift";
156+
auto *Contents = getComplexSourceText();
157+
const char *Args[] = {"-parse-as-library"};
158+
159+
// Test twice with RemoveCache = false to test both the prior state of
160+
// the ASTProducer being cached and not cached.
161+
for (auto RemoveCache : {true, false, false}) {
162+
open(DocName, Contents, Args);
163+
164+
Semaphore BuildResultSema(0);
165+
166+
getDiagnosticsAsync(DocName, Args,
167+
[&](const RequestResult<DiagnosticsResult> &Result) {
168+
EXPECT_TRUE(Result.isCancelled());
169+
BuildResultSema.signal();
170+
});
171+
172+
getCompileTracker().waitForBuildToStart();
173+
174+
close(DocName, RemoveCache);
175+
176+
bool Expired = BuildResultSema.wait(30 * 1000);
177+
if (Expired)
178+
llvm::report_fatal_error("Did not receive a response for the request");
179+
}
180+
}

0 commit comments

Comments
 (0)