Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion compiler/base/orchestrator/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion compiler/base/orchestrator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ snafu = { version = "0.8.0", default-features = false, features = ["futures", "s
strum_macros = { version = "0.26.1", default-features = false }
tokio = { version = "1.28", default-features = false, features = ["fs", "io-std", "io-util", "macros", "process", "rt", "time", "sync"] }
tokio-stream = { version = "0.1.14", default-features = false }
tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util"] }
tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util", "rt"] }
toml = { version = "0.8.2", default-features = false, features = ["parse", "display"] }
tracing = { version = "0.1.37", default-features = false, features = ["attributes"] }

Expand Down
28 changes: 13 additions & 15 deletions compiler/base/orchestrator/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use tokio::{
process::{Child, ChildStdin, ChildStdout, Command},
select,
sync::{mpsc, oneshot, OnceCell},
task::{JoinHandle, JoinSet},
task::JoinSet,
time::{self, MissedTickBehavior},
try_join,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::{io::SyncIoBridge, sync::CancellationToken};
use tokio_util::{io::SyncIoBridge, sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{error, info, info_span, instrument, trace, trace_span, warn, Instrument};

use crate::{
Expand All @@ -30,7 +30,7 @@ use crate::{
ExecuteCommandResponse, JobId, Multiplexed, OneToOneResponse, ReadFileRequest,
ReadFileResponse, SerializedError2, WorkerMessage, WriteFileRequest,
},
DropErrorDetailsExt,
DropErrorDetailsExt, TaskAbortExt as _,
};

pub mod limits;
Expand Down Expand Up @@ -1161,7 +1161,7 @@ impl Drop for CancelOnDrop {
#[derive(Debug)]
struct Container {
permit: Box<dyn ContainerPermit>,
task: JoinHandle<Result<()>>,
task: AbortOnDropHandle<Result<()>>,
kill_child: TerminateContainer,
modify_cargo_toml: ModifyCargoToml,
commander: Commander,
Expand All @@ -1186,7 +1186,8 @@ impl Container {

let (command_tx, command_rx) = mpsc::channel(8);
let demultiplex_task =
tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span());
tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span())
.abort_on_drop();

let task = tokio::spawn(
async move {
Expand Down Expand Up @@ -1216,7 +1217,8 @@ impl Container {
Ok(())
}
.in_current_span(),
);
)
.abort_on_drop();

let commander = Commander {
to_worker_tx,
Expand Down Expand Up @@ -1865,7 +1867,8 @@ impl Container {
}
}
.instrument(trace_span!("cargo task").or_current())
});
})
.abort_on_drop();

Ok(SpawnCargo {
permit,
Expand Down Expand Up @@ -2128,7 +2131,7 @@ pub enum DoRequestError {

struct SpawnCargo {
permit: Box<dyn ProcessPermit>,
task: JoinHandle<Result<ExecuteCommandResponse, SpawnCargoError>>,
task: AbortOnDropHandle<Result<ExecuteCommandResponse, SpawnCargoError>>,
stdin_tx: mpsc::Sender<String>,
stdout_rx: mpsc::Receiver<String>,
stderr_rx: mpsc::Receiver<String>,
Expand Down Expand Up @@ -2842,14 +2845,9 @@ fn spawn_io_queue(stdin: ChildStdin, stdout: ChildStdout, token: CancellationTok
let handle = tokio::runtime::Handle::current();

loop {
let coordinator_msg = handle.block_on(async {
select! {
() = token.cancelled() => None,
msg = rx.recv() => msg,
}
});
let coordinator_msg = handle.block_on(token.run_until_cancelled(rx.recv()));

let Some(coordinator_msg) = coordinator_msg else {
let Some(Some(coordinator_msg)) = coordinator_msg else {
break;
};

Expand Down
10 changes: 10 additions & 0 deletions compiler/base/orchestrator/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ pub mod coordinator;
mod message;
pub mod worker;

/// Extension trait that adds a postfix `.abort_on_drop()` conversion to a
/// task handle, turning it into a [`tokio_util::task::AbortOnDropHandle`].
///
/// `T` is the output type of the underlying task. The `Sized` supertrait is
/// needed because the conversion takes `self` by value.
pub trait TaskAbortExt<T>: Sized {
/// Consumes the handle and returns it wrapped in an
/// [`tokio_util::task::AbortOnDropHandle`]. Per tokio-util's documentation,
/// that wrapper aborts the task when it is dropped, rather than leaving the
/// task detached and running.
fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle<T>;
}

/// Lets the result of `tokio::spawn(...)` be chained directly with
/// `.abort_on_drop()`.
impl<T> TaskAbortExt<T> for tokio::task::JoinHandle<T> {
// Thin wrapper over `AbortOnDropHandle::new`; no other behavior is added.
fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle<T> {
tokio_util::task::AbortOnDropHandle::new(self)
}
}

/// Extension trait that converts `self` into a `Result` whose error type is
/// `tokio::sync::mpsc::error::SendError<()>`.
///
/// NOTE(review): the implementation is not visible in this view. Judging by
/// the name and signature, it presumably replaces the unsent value carried by
/// a `SendError<_>` with `()` — confirm against the impl.
pub trait DropErrorDetailsExt<T> {
/// Consumes `self` and yields `Ok(T)` or a payload-free `SendError<()>`.
fn drop_error_details(self) -> Result<T, tokio::sync::mpsc::error::SendError<()>>;
}
Expand Down
18 changes: 10 additions & 8 deletions compiler/base/orchestrator/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use tokio::{
sync::mpsc,
task::JoinSet,
};
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};

use crate::{
bincode_input_closed,
Expand All @@ -55,7 +55,7 @@ use crate::{
ExecuteCommandResponse, JobId, Multiplexed, ReadFileRequest, ReadFileResponse,
SerializedError2, WorkerMessage, WriteFileRequest, WriteFileResponse,
},
DropErrorDetailsExt,
DropErrorDetailsExt as _, TaskAbortExt as _,
};

pub async fn listen(project_dir: impl Into<PathBuf>) -> Result<(), Error> {
Expand All @@ -66,14 +66,16 @@ pub async fn listen(project_dir: impl Into<PathBuf>) -> Result<(), Error> {
let mut io_tasks = spawn_io_queue(coordinator_msg_tx, worker_msg_rx);

let (process_tx, process_rx) = mpsc::channel(8);
let process_task = tokio::spawn(manage_processes(process_rx, project_dir.clone()));
let process_task =
tokio::spawn(manage_processes(process_rx, project_dir.clone())).abort_on_drop();

let handler_task = tokio::spawn(handle_coordinator_message(
coordinator_msg_rx,
worker_msg_tx,
project_dir,
process_tx,
));
))
.abort_on_drop();

select! {
Some(io_task) = io_tasks.join_next() => {
Expand Down Expand Up @@ -403,7 +405,7 @@ struct ProcessState {
processes: JoinSet<Result<(), ProcessError>>,
stdin_senders: HashMap<JobId, mpsc::Sender<String>>,
stdin_shutdown_tx: mpsc::Sender<JobId>,
kill_tokens: HashMap<JobId, CancellationToken>,
kill_tokens: HashMap<JobId, DropGuard>,
}

impl ProcessState {
Expand Down Expand Up @@ -456,7 +458,7 @@ impl ProcessState {

let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);

self.kill_tokens.insert(job_id, token.clone());
self.kill_tokens.insert(job_id, token.clone().drop_guard());

self.processes.spawn({
let stdin_shutdown_tx = self.stdin_shutdown_tx.clone();
Expand Down Expand Up @@ -508,8 +510,8 @@ impl ProcessState {
}

fn kill(&mut self, job_id: JobId) {
if let Some(token) = self.kill_tokens.get(&job_id) {
token.cancel();
if let Some(token) = self.kill_tokens.remove(&job_id) {
drop(token);
}
}
}
Expand Down
12 changes: 10 additions & 2 deletions ui/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions ui/src/server_axum/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use futures::{
future::{Fuse, FusedFuture as _},
FutureExt as _,
};
use orchestrator::DropErrorDetailsExt as _;
use orchestrator::{DropErrorDetailsExt as _, TaskAbortExt as _};
use snafu::prelude::*;
use std::{
future::Future,
Expand All @@ -13,9 +13,9 @@ use std::{
use tokio::{
select,
sync::{mpsc, oneshot},
task::JoinHandle,
time,
};
use tokio_util::task::AbortOnDropHandle;
use tracing::warn;

const ONE_HUNDRED_MILLISECONDS: Duration = Duration::from_millis(100);
Expand Down Expand Up @@ -48,12 +48,12 @@ where
{
pub fn spawn<Fut>(
f: impl FnOnce(mpsc::Receiver<CacheTaskItem<T, E>>) -> Fut,
) -> (JoinHandle<()>, Self)
) -> (AbortOnDropHandle<()>, Self)
where
Fut: Future<Output = ()> + Send + 'static,
{
let (tx, rx) = mpsc::channel(8);
let task = tokio::spawn(f(rx));
let task = tokio::spawn(f(rx)).abort_on_drop();
let cache_tx = CacheTx(tx);
(task, cache_tx)
}
Expand Down Expand Up @@ -148,7 +148,8 @@ where
let new_value = generator().await.map_err(CacheError::from);
CacheInfo::build(new_value)
}
});
})
.abort_on_drop();

new_value.set(new_value_task.fuse());
}
Expand Down
13 changes: 8 additions & 5 deletions ui/src/server_axum/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use tokio::{
task::{AbortHandle, JoinSet},
time,
};
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{error, info, instrument, warn, Instrument};

#[derive(Debug, serde::Deserialize, serde::Serialize)]
Expand Down Expand Up @@ -525,7 +525,7 @@ async fn handle_idle(manager: &mut CoordinatorManager, tx: &ResponseTx) -> Contr
ControlFlow::Continue(())
}

type ActiveExecutionInfo = (CancellationToken, Option<mpsc::Sender<String>>);
type ActiveExecutionInfo = (DropGuard, Option<mpsc::Sender<String>>);

async fn handle_msg(
txt: &str,
Expand All @@ -545,7 +545,10 @@ async fn handle_msg(

let guard = db.clone().start_with_guard("ws.Execute", txt).await;

active_executions.insert(meta.sequence_number, (token.clone(), Some(execution_tx)));
active_executions.insert(
meta.sequence_number,
(token.clone().drop_guard(), Some(execution_tx)),
);

// TODO: Should a single execute / build / etc. session have a timeout of some kind?
let spawned = manager
Expand Down Expand Up @@ -602,11 +605,11 @@ async fn handle_msg(
}

Ok(ExecuteKill { meta }) => {
let Some((token, _)) = active_executions.get(&meta.sequence_number) else {
let Some((token, _)) = active_executions.remove(&meta.sequence_number) else {
warn!("Received kill for an execution that is no longer active");
return;
};
token.cancel();
drop(token);
}

Err(e) => {
Expand Down