diff --git a/compiler/base/orchestrator/src/coordinator.rs b/compiler/base/orchestrator/src/coordinator.rs index 7c19f732..038e0b29 100644 --- a/compiler/base/orchestrator/src/coordinator.rs +++ b/compiler/base/orchestrator/src/coordinator.rs @@ -1,8 +1,4 @@ -use futures::{ - future::{BoxFuture, OptionFuture}, - stream::BoxStream, - Future, FutureExt, Stream, StreamExt, -}; +use futures::{future::BoxFuture, stream::BoxStream, Future, FutureExt, Stream, StreamExt}; use serde::Deserialize; use snafu::prelude::*; use std::{ @@ -16,12 +12,12 @@ use std::{ time::Duration, }; use tokio::{ - join, process::{Child, ChildStdin, ChildStdout, Command}, select, sync::{mpsc, oneshot, OnceCell}, task::{JoinHandle, JoinSet}, time::{self, MissedTickBehavior}, + try_join, }; use tokio_stream::wrappers::ReceiverStream; use tokio_util::{io::SyncIoBridge, sync::CancellationToken}; @@ -366,12 +362,12 @@ pub struct ExecuteRequest { pub code: String, } -impl ExecuteRequest { - pub(crate) fn delete_previous_main_request(&self) -> DeleteFileRequest { +impl LowerRequest for ExecuteRequest { + fn delete_previous_main_request(&self) -> DeleteFileRequest { delete_previous_primary_file_request(self.crate_type) } - pub(crate) fn write_main_request(&self) -> WriteFileRequest { + fn write_main_request(&self) -> WriteFileRequest { write_primary_file_request(self.crate_type, &self.code) } @@ -457,15 +453,39 @@ pub struct CompileRequest { } impl CompileRequest { - pub(crate) fn delete_previous_main_request(&self) -> DeleteFileRequest { + const OUTPUT_PATH: &str = "compilation"; + + fn read_output_request(&self) -> ReadFileRequest { + ReadFileRequest { + path: Self::OUTPUT_PATH.to_owned(), + } + } + + pub(crate) fn postprocess_result(&self, mut code: String) -> String { + if let CompileTarget::Assembly(_, demangle, process) = self.target { + if demangle == DemangleAssembly::Demangle { + code = asm_cleanup::demangle_asm(&code); + } + + if process == ProcessAssembly::Filter { + code = asm_cleanup::filter_asm(&code); + } + } + + code + } +} + +impl LowerRequest for CompileRequest { + fn delete_previous_main_request(&self) -> DeleteFileRequest { delete_previous_primary_file_request(self.crate_type) } - pub(crate) fn write_main_request(&self) -> WriteFileRequest { + fn write_main_request(&self) -> WriteFileRequest { write_primary_file_request(self.crate_type, &self.code) } - pub(crate) fn execute_cargo_request(&self, output_path: &str) -> ExecuteCommandRequest { + fn execute_cargo_request(&self) -> ExecuteCommandRequest { use CompileTarget::*; let mut args = if let Wasm = self.target { @@ -495,8 +515,8 @@ impl CompileRequest { } LlvmIr => args.extend(&["--", "--emit", "llvm-ir=compilation"]), Mir => args.extend(&["--", "--emit", "mir=compilation"]), - Hir => args.extend(&["--", "-Zunpretty=hir", "-o", output_path]), - Wasm => args.extend(&["-o", output_path]), + Hir => args.extend(&["--", "-Zunpretty=hir", "-o", Self::OUTPUT_PATH]), + Wasm => args.extend(&["-o", Self::OUTPUT_PATH]), } let mut envs = HashMap::new(); if self.backtrace { @@ -510,20 +530,6 @@ impl CompileRequest { cwd: None, } } - - pub(crate) fn postprocess_result(&self, mut code: String) -> String { - if let CompileTarget::Assembly(_, demangle, process) = self.target { - if demangle == DemangleAssembly::Demangle { - code = asm_cleanup::demangle_asm(&code); - } - - if process == ProcessAssembly::Filter { - code = asm_cleanup::filter_asm(&code); - } - } - - code - } } impl CargoTomlModifier for CompileRequest { @@ -563,15 +569,23 @@ pub struct FormatRequest { } impl FormatRequest { - pub(crate) fn delete_previous_main_request(&self) -> DeleteFileRequest { + fn read_output_request(&self) -> ReadFileRequest { + ReadFileRequest { + path: self.crate_type.primary_path().to_owned(), + } + } +} + +impl LowerRequest for FormatRequest { + fn delete_previous_main_request(&self) -> DeleteFileRequest { delete_previous_primary_file_request(self.crate_type) } - pub(crate) fn write_main_request(&self) -> WriteFileRequest { + fn write_main_request(&self) -> WriteFileRequest { write_primary_file_request(self.crate_type, &self.code) } - pub(crate) fn execute_cargo_request(&self) -> ExecuteCommandRequest { + fn execute_cargo_request(&self) -> ExecuteCommandRequest { ExecuteCommandRequest { cmd: "cargo".to_owned(), args: vec!["fmt".to_owned()], @@ -611,16 +625,16 @@ pub struct ClippyRequest { pub code: String, } -impl ClippyRequest { - pub(crate) fn delete_previous_main_request(&self) -> DeleteFileRequest { +impl LowerRequest for ClippyRequest { + fn delete_previous_main_request(&self) -> DeleteFileRequest { delete_previous_primary_file_request(self.crate_type) } - pub(crate) fn write_main_request(&self) -> WriteFileRequest { + fn write_main_request(&self) -> WriteFileRequest { write_primary_file_request(self.crate_type, &self.code) } - pub(crate) fn execute_cargo_request(&self) -> ExecuteCommandRequest { + fn execute_cargo_request(&self) -> ExecuteCommandRequest { ExecuteCommandRequest { cmd: "cargo".to_owned(), args: vec!["clippy".to_owned()], @@ -659,16 +673,16 @@ pub struct MiriRequest { pub code: String, } -impl MiriRequest { - pub(crate) fn delete_previous_main_request(&self) -> DeleteFileRequest { +impl LowerRequest for MiriRequest { + fn delete_previous_main_request(&self) -> DeleteFileRequest { delete_previous_primary_file_request(self.crate_type) } - pub(crate) fn write_main_request(&self) -> WriteFileRequest { + fn write_main_request(&self) -> WriteFileRequest { write_primary_file_request(self.crate_type, &self.code) } - pub(crate) fn execute_cargo_request(&self) -> ExecuteCommandRequest { + fn execute_cargo_request(&self) -> ExecuteCommandRequest { ExecuteCommandRequest { cmd: "cargo".to_owned(), args: vec!["miri-playground".to_owned()], @@ -707,16 +721,16 @@ pub struct MacroExpansionRequest { pub code: String, } -impl MacroExpansionRequest { - pub(crate) fn delete_previous_main_request(&self) -> DeleteFileRequest { +impl LowerRequest for MacroExpansionRequest { + fn delete_previous_main_request(&self) -> DeleteFileRequest { delete_previous_primary_file_request(self.crate_type) } - pub(crate) fn write_main_request(&self) -> WriteFileRequest { + fn write_main_request(&self) -> WriteFileRequest { write_primary_file_request(self.crate_type, &self.code) } - pub(crate) fn execute_cargo_request(&self) -> ExecuteCommandRequest { + fn execute_cargo_request(&self) -> ExecuteCommandRequest { ExecuteCommandRequest { cmd: "cargo".to_owned(), args: ["rustc", "--", "-Zunpretty=expanded"] @@ -781,11 +795,10 @@ impl WithOutput { where F: Future>, { - let stdout = stdout_rx.collect(); - let stderr = stderr_rx.collect(); + let stdout = stdout_rx.collect().map(Ok); + let stderr = stderr_rx.collect().map(Ok); - let (result, stdout, stderr) = join!(task, stdout, stderr); - let response = result?; + let (response, stdout, stderr) = try_join!(task, stdout, stderr)?; Ok(WithOutput { response, @@ -880,7 +893,7 @@ pub struct Coordinator { stable: OnceCell, beta: OnceCell, nightly: OnceCell, - token: CancellationToken, + token: CancelOnDrop, } /// Runs things. @@ -897,7 +910,7 @@ where B: Backend, { pub fn new(limits: Arc, backend: B) -> Self { - let token = CancellationToken::new(); + let token = CancelOnDrop(CancellationToken::new()); Self { limits, @@ -918,11 +931,11 @@ where c.versions().await.map_err(VersionsChannelError::from) }); - let (stable, beta, nightly) = join!(stable, beta, nightly); + let stable = async { stable.await.context(StableSnafu) }; + let beta = async { beta.await.context(BetaSnafu) }; + let nightly = async { nightly.await.context(NightlySnafu) }; - let stable = stable.context(StableSnafu)?; - let beta = beta.context(BetaSnafu)?; - let nightly = nightly.context(NightlySnafu)?; + let (stable, beta, nightly) = try_join!(stable, beta, nightly)?; Ok(Versions { stable, @@ -1110,16 +1123,17 @@ where let token = mem::take(token); token.cancel(); - let channels = - [stable, beta, nightly].map(|c| OptionFuture::from(c.take().map(|c| c.shutdown()))); + let channels = [stable, beta, nightly].map(|c| async { + match c.take() { + Some(c) => c.shutdown().await, + _ => Ok(()), + } + }); let [stable, beta, nightly] = channels; - let (stable, beta, nightly) = join!(stable, beta, nightly); - - stable.transpose()?; - beta.transpose()?; - nightly.transpose()?; + let (stable, beta, nightly) = try_join!(stable, beta, nightly)?; + let _: [(); 3] = [stable, beta, nightly]; Ok(()) } @@ -1139,13 +1153,28 @@ where container .get_or_try_init(|| { let limits = self.limits.clone(); - let token = self.token.clone(); + let token = self.token.0.clone(); Container::new(channel, limits, token, &self.backend) }) .await } } +#[derive(Debug, Default)] +struct CancelOnDrop(CancellationToken); + +impl CancelOnDrop { + fn cancel(&self) { + self.0.cancel(); + } +} + +impl Drop for CancelOnDrop { + fn drop(&mut self) { + self.0.cancel(); + } +} + #[derive(Debug)] struct Container { permit: Box, @@ -1178,14 +1207,28 @@ impl Container { let task = tokio::spawn( async move { - let (c, d, t) = join!(child.wait(), demultiplex_task, tasks.join_next()); + let child = async { + let _: std::process::ExitStatus = + child.wait().await.context(JoinWorkerSnafu)?; + Ok(()) + }; - c.context(JoinWorkerSnafu)?; - d.context(DemultiplexerTaskPanickedSnafu)? - .context(DemultiplexerTaskFailedSnafu)?; - if let Some(t) = t { - t.context(IoQueuePanickedSnafu)??; - } + let demultiplex_task = async { + demultiplex_task + .await + .context(DemultiplexerTaskPanickedSnafu)? + .context(DemultiplexerTaskFailedSnafu) + }; + + let task = async { + if let Some(t) = tasks.join_next().await { + t.context(IoQueuePanickedSnafu)??; + } + Ok(()) + }; + + let (c, d, t) = try_join!(child, demultiplex_task, task)?; + let _: [(); 3] = [c, d, t]; Ok(()) } @@ -1216,19 +1259,41 @@ impl Container { let token = CancellationToken::new(); - let rustc = self.rustc_version(token.clone()); - let rustfmt = self.tool_version(token.clone(), "fmt"); - let clippy = self.tool_version(token.clone(), "clippy"); - let miri = self.tool_version(token, "miri"); + let rustc = { + let token = token.clone(); + async { + self.rustc_version(token) + .await + .context(RustcSnafu)? + .context(RustcMissingSnafu) + } + }; + let rustfmt = { + let token = token.clone(); + async { + self.tool_version(token, "fmt") + .await + .context(RustfmtSnafu)? + .context(RustfmtMissingSnafu) + } + }; + let clippy = { + let token = token.clone(); + async { + self.tool_version(token, "clippy") + .await + .context(ClippySnafu)? + .context(ClippyMissingSnafu) + } + }; + let miri = { + let token = token.clone(); + async { self.tool_version(token, "miri").await.context(MiriSnafu) } + }; - let (rustc, rustfmt, clippy, miri) = join!(rustc, rustfmt, clippy, miri); + let _token = token.drop_guard(); - let rustc = rustc.context(RustcSnafu)?.context(RustcMissingSnafu)?; - let rustfmt = rustfmt - .context(RustfmtSnafu)? - .context(RustfmtMissingSnafu)?; - let clippy = clippy.context(ClippySnafu)?.context(ClippyMissingSnafu)?; - let miri = miri.context(MiriSnafu)?; + let (rustc, rustfmt, clippy, miri) = try_join!(rustc, rustfmt, clippy, miri)?; Ok(ChannelVersions { rustc, @@ -1264,7 +1329,6 @@ impl Container { token: CancellationToken, cmd: ExecuteCommandRequest, ) -> Result, VersionError> { - let v = self.spawn_cargo_task(token.clone(), cmd).await?; let SpawnCargo { permit: _permit, task, @@ -1272,7 +1336,7 @@ impl Container { stdout_rx, stderr_rx, status_rx, - } = v; + } = self.spawn_cargo_task(token, cmd).await?; drop(stdin_tx); drop(status_rx); @@ -1320,21 +1384,6 @@ impl Container { ) -> Result { use execute_error::*; - let delete_previous_main = request.delete_previous_main_request(); - let write_main = request.write_main_request(); - let execute_cargo = request.execute_cargo_request(); - - let delete_previous_main = self.commander.one(delete_previous_main); - let write_main = self.commander.one(write_main); - let modify_cargo_toml = self.modify_cargo_toml.modify_for(&request); - - let (delete_previous_main, write_main, modify_cargo_toml) = - join!(delete_previous_main, write_main, modify_cargo_toml); - - delete_previous_main.context(CouldNotDeletePreviousCodeSnafu)?; - write_main.context(CouldNotWriteCodeSnafu)?; - modify_cargo_toml.context(CouldNotModifyCargoTomlSnafu)?; - let SpawnCargo { permit, task, @@ -1342,10 +1391,7 @@ impl Container { stdout_rx, stderr_rx, status_rx, - } = self - .spawn_cargo_task(token, execute_cargo) - .await - .context(CouldNotStartCargoSnafu)?; + } = self.do_request(request, token).await?; let task = async move { let ExecuteCommandResponse { @@ -1409,26 +1455,6 @@ impl Container { ) -> Result { use compile_error::*; - let output_path: &str = "compilation"; - - let delete_previous_main = request.delete_previous_main_request(); - let write_main = request.write_main_request(); - let execute_cargo = request.execute_cargo_request(output_path); - let read_output = ReadFileRequest { - path: output_path.to_owned(), - }; - - let delete_previous_main = self.commander.one(delete_previous_main); - let write_main = self.commander.one(write_main); - let modify_cargo_toml = self.modify_cargo_toml.modify_for(&request); - - let (delete_previous_main, write_main, modify_cargo_toml) = - join!(delete_previous_main, write_main, modify_cargo_toml); - - delete_previous_main.context(CouldNotDeletePreviousCodeSnafu)?; - write_main.context(CouldNotWriteCodeSnafu)?; - modify_cargo_toml.context(CouldNotModifyCargoTomlSnafu)?; - let SpawnCargo { permit, task, @@ -1436,10 +1462,7 @@ impl Container { stdout_rx, stderr_rx, status_rx, - } = self - .spawn_cargo_task(token, execute_cargo) - .await - .context(CouldNotStartCargoSnafu)?; + } = self.do_request(&request, token).await?; drop(stdin_tx); drop(status_rx); @@ -1455,6 +1478,8 @@ impl Container { .context(CargoFailedSnafu)?; let code = if success { + let read_output = request.read_output_request(); + let file: ReadFileResponse = commander .one(read_output) .await @@ -1506,24 +1531,6 @@ impl Container { ) -> Result { use format_error::*; - let delete_previous_main = request.delete_previous_main_request(); - let write_main = request.write_main_request(); - let execute_cargo = request.execute_cargo_request(); - let read_output = ReadFileRequest { - path: request.crate_type.primary_path().to_owned(), - }; - - let delete_previous_main = self.commander.one(delete_previous_main); - let write_main = self.commander.one(write_main); - let modify_cargo_toml = self.modify_cargo_toml.modify_for(&request); - - let (delete_previous_main, write_main, modify_cargo_toml) = - join!(delete_previous_main, write_main, modify_cargo_toml); - - delete_previous_main.context(CouldNotDeletePreviousCodeSnafu)?; - write_main.context(CouldNotWriteCodeSnafu)?; - modify_cargo_toml.context(CouldNotModifyCargoTomlSnafu)?; - let SpawnCargo { permit, task, @@ -1531,10 +1538,7 @@ impl Container { stdout_rx, stderr_rx, status_rx, - } = self - .spawn_cargo_task(token, execute_cargo) - .await - .context(CouldNotStartCargoSnafu)?; + } = self.do_request(&request, token).await?; drop(stdin_tx); drop(status_rx); @@ -1549,6 +1553,7 @@ impl Container { .context(CargoTaskPanickedSnafu)? .context(CargoFailedSnafu)?; + let read_output = request.read_output_request(); let file = commander .one(read_output) .await @@ -1594,21 +1599,6 @@ impl Container { ) -> Result { use clippy_error::*; - let delete_previous_main = request.delete_previous_main_request(); - let write_main = request.write_main_request(); - let execute_cargo = request.execute_cargo_request(); - - let delete_previous_main = self.commander.one(delete_previous_main); - let write_main = self.commander.one(write_main); - let modify_cargo_toml = self.modify_cargo_toml.modify_for(&request); - - let (delete_previous_main, write_main, modify_cargo_toml) = - join!(delete_previous_main, write_main, modify_cargo_toml); - - delete_previous_main.context(CouldNotDeletePreviousCodeSnafu)?; - write_main.context(CouldNotWriteCodeSnafu)?; - modify_cargo_toml.context(CouldNotModifyCargoTomlSnafu)?; - let SpawnCargo { permit, task, @@ -1616,10 +1606,7 @@ impl Container { stdout_rx, stderr_rx, status_rx, - } = self - .spawn_cargo_task(token, execute_cargo) - .await - .context(CouldNotStartCargoSnafu)?; + } = self.do_request(request, token).await?; drop(stdin_tx); drop(status_rx); @@ -1668,21 +1655,6 @@ impl Container { ) -> Result { use miri_error::*; - let delete_previous_main = request.delete_previous_main_request(); - let write_main = request.write_main_request(); - let execute_cargo = request.execute_cargo_request(); - - let delete_previous_main = self.commander.one(delete_previous_main); - let write_main = self.commander.one(write_main); - let modify_cargo_toml = self.modify_cargo_toml.modify_for(&request); - - let (delete_previous_main, write_main, modify_cargo_toml) = - join!(delete_previous_main, write_main, modify_cargo_toml); - - delete_previous_main.context(CouldNotDeletePreviousCodeSnafu)?; - write_main.context(CouldNotWriteCodeSnafu)?; - modify_cargo_toml.context(CouldNotModifyCargoTomlSnafu)?; - let SpawnCargo { permit, task, @@ -1690,10 +1662,7 @@ impl Container { stdout_rx, stderr_rx, status_rx, - } = self - .spawn_cargo_task(token, execute_cargo) - .await - .context(CouldNotStartCargoSnafu)?; + } = self.do_request(request, token).await?; drop(stdin_tx); drop(status_rx); @@ -1745,21 +1714,6 @@ impl Container { ) -> Result { use macro_expansion_error::*; - let delete_previous_main = request.delete_previous_main_request(); - let write_main = request.write_main_request(); - let execute_cargo = request.execute_cargo_request(); - - let delete_previous_main = self.commander.one(delete_previous_main); - let write_main = self.commander.one(write_main); - let modify_cargo_toml = self.modify_cargo_toml.modify_for(&request); - - let (delete_previous_main, write_main, modify_cargo_toml) = - join!(delete_previous_main, write_main, modify_cargo_toml); - - delete_previous_main.context(CouldNotDeletePreviousCodeSnafu)?; - write_main.context(CouldNotWriteCodeSnafu)?; - modify_cargo_toml.context(CouldNotModifyCargoTomlSnafu)?; - let SpawnCargo { permit, task, @@ -1767,10 +1721,7 @@ impl Container { stdout_rx, stderr_rx, status_rx, - } = self - .spawn_cargo_task(token, execute_cargo) - .await - .context(CouldNotStartCargoSnafu)?; + } = self.do_request(request, token).await?; drop(stdin_tx); drop(status_rx); @@ -1799,6 +1750,45 @@ impl Container { }) } + async fn do_request( + &self, + request: impl LowerRequest + CargoTomlModifier, + token: CancellationToken, + ) -> Result { + use do_request_error::*; + + let delete_previous_main = async { + self.commander + .one(request.delete_previous_main_request()) + .await + .context(CouldNotDeletePreviousCodeSnafu) + .map(drop::) + }; + + let write_main = async { + self.commander + .one(request.write_main_request()) + .await + .context(CouldNotWriteCodeSnafu) + .map(drop::) + }; + + let modify_cargo_toml = async { + self.modify_cargo_toml + .modify_for(&request) + .await + .context(CouldNotModifyCargoTomlSnafu) + }; + + let (d, w, m) = try_join!(delete_previous_main, write_main, modify_cargo_toml)?; + let _: [(); 3] = [d, w, m]; + + let execute_cargo = request.execute_cargo_request(); + self.spawn_cargo_task(token, execute_cargo) + .await + .context(CouldNotStartCargoSnafu) + } + async fn spawn_cargo_task( &self, token: CancellationToken, @@ -1950,17 +1940,8 @@ pub enum ExecuteError { #[snafu(display("Could not start the container"))] CouldNotStartContainer { source: Error }, - #[snafu(display("Could not modify Cargo.toml"))] - CouldNotModifyCargoToml { source: ModifyCargoTomlError }, - - #[snafu(display("Could not delete previous source code"))] - CouldNotDeletePreviousCode { source: CommanderError }, - - #[snafu(display("Could not write source code"))] - CouldNotWriteCode { source: CommanderError }, - - #[snafu(display("Could not start Cargo task"))] - CouldNotStartCargo { source: SpawnCargoError }, + #[snafu(transparent)] + DoRequest { source: DoRequestError }, #[snafu(display("The Cargo task panicked"))] CargoTaskPanicked { source: tokio::task::JoinError }, @@ -1992,17 +1973,8 @@ pub enum CompileError { #[snafu(display("Could not start the container"))] CouldNotStartContainer { source: Error }, - #[snafu(display("Could not modify Cargo.toml"))] - CouldNotModifyCargoToml { source: ModifyCargoTomlError }, - - #[snafu(display("Could not delete previous source code"))] - CouldNotDeletePreviousCode { source: CommanderError }, - - #[snafu(display("Could not write source code"))] - CouldNotWriteCode { source: CommanderError }, - - #[snafu(display("Could not start Cargo task"))] - CouldNotStartCargo { source: SpawnCargoError }, + #[snafu(transparent)] + DoRequest { source: DoRequestError }, #[snafu(display("The Cargo task panicked"))] CargoTaskPanicked { source: tokio::task::JoinError }, @@ -2040,17 +2012,8 @@ pub enum FormatError { #[snafu(display("Could not start the container"))] CouldNotStartContainer { source: Error }, - #[snafu(display("Could not modify Cargo.toml"))] - CouldNotModifyCargoToml { source: ModifyCargoTomlError }, - - #[snafu(display("Could not delete previous source code"))] - CouldNotDeletePreviousCode { source: CommanderError }, - - #[snafu(display("Could not write source code"))] - CouldNotWriteCode { source: CommanderError }, - - #[snafu(display("Could not start Cargo task"))] - CouldNotStartCargo { source: SpawnCargoError }, + #[snafu(transparent)] + DoRequest { source: DoRequestError }, #[snafu(display("The Cargo task panicked"))] CargoTaskPanicked { source: tokio::task::JoinError }, @@ -2088,17 +2051,8 @@ pub enum ClippyError { #[snafu(display("Could not start the container"))] CouldNotStartContainer { source: Error }, - #[snafu(display("Could not modify Cargo.toml"))] - CouldNotModifyCargoToml { source: ModifyCargoTomlError }, - - #[snafu(display("Could not delete previous source code"))] - CouldNotDeletePreviousCode { source: CommanderError }, - - #[snafu(display("Could not write source code"))] - CouldNotWriteCode { source: CommanderError }, - - #[snafu(display("Could not start Cargo task"))] - CouldNotStartCargo { source: SpawnCargoError }, + #[snafu(transparent)] + DoRequest { source: DoRequestError }, #[snafu(display("The Cargo task panicked"))] CargoTaskPanicked { source: tokio::task::JoinError }, @@ -2130,17 +2084,8 @@ pub enum MiriError { #[snafu(display("Could not start the container"))] CouldNotStartContainer { source: Error }, - #[snafu(display("Could not modify Cargo.toml"))] - CouldNotModifyCargoToml { source: ModifyCargoTomlError }, - - #[snafu(display("Could not delete previous source code"))] - CouldNotDeletePreviousCode { source: CommanderError }, - - #[snafu(display("Could not write source code"))] - CouldNotWriteCode { source: CommanderError }, - - #[snafu(display("Could not start Cargo task"))] - CouldNotStartCargo { source: SpawnCargoError }, + #[snafu(transparent)] + DoRequest { source: DoRequestError }, #[snafu(display("The Cargo task panicked"))] CargoTaskPanicked { source: tokio::task::JoinError }, @@ -2172,6 +2117,19 @@ pub enum MacroExpansionError { #[snafu(display("Could not start the container"))] CouldNotStartContainer { source: Error }, + #[snafu(transparent)] + DoRequest { source: DoRequestError }, + + #[snafu(display("The Cargo task panicked"))] + CargoTaskPanicked { source: tokio::task::JoinError }, + + #[snafu(display("Cargo task failed"))] + CargoFailed { source: SpawnCargoError }, +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum DoRequestError { #[snafu(display("Could not modify Cargo.toml"))] CouldNotModifyCargoToml { source: ModifyCargoTomlError }, @@ -2183,12 +2141,6 @@ pub enum MacroExpansionError { #[snafu(display("Could not start Cargo task"))] CouldNotStartCargo { source: SpawnCargoError }, - - #[snafu(display("The Cargo task panicked"))] - CargoTaskPanicked { source: tokio::task::JoinError }, - - #[snafu(display("Cargo task failed"))] - CargoFailed { source: SpawnCargoError }, } struct SpawnCargo { @@ -2232,10 +2184,44 @@ struct Commander { id: Arc, } +trait LowerRequest { + fn delete_previous_main_request(&self) -> DeleteFileRequest; + + fn write_main_request(&self) -> WriteFileRequest; + + fn execute_cargo_request(&self) -> ExecuteCommandRequest; +} + +impl LowerRequest for &S +where + S: LowerRequest, +{ + fn delete_previous_main_request(&self) -> DeleteFileRequest { + S::delete_previous_main_request(self) + } + + fn write_main_request(&self) -> WriteFileRequest { + S::write_main_request(self) + } + + fn execute_cargo_request(&self) -> ExecuteCommandRequest { + S::execute_cargo_request(self) + } +} + trait CargoTomlModifier { fn modify_cargo_toml(&self, cargo_toml: toml::Value) -> toml::Value; } +impl CargoTomlModifier for &C +where + C: CargoTomlModifier, +{ + fn modify_cargo_toml(&self, cargo_toml: toml::Value) -> toml::Value { + C::modify_cargo_toml(self, cargo_toml) + } +} + #[derive(Debug)] struct ModifyCargoToml { commander: Commander, @@ -2572,9 +2558,9 @@ impl TerminateContainer { .insert(name.into()); if was_inserted { - info!(%name, "Started tracking container"); + info!("Started tracking container"); } else { - error!(%name, "Started tracking container, but it was already tracked"); + error!("Started tracking container, but it was already tracked"); } } @@ -2586,9 +2572,9 @@ impl TerminateContainer { .remove(name); if was_tracked { - info!(%name, "Stopped tracking container"); + info!("Stopped tracking container"); } else { - error!(%name, "Stopped tracking container, but it was not in the tracking set"); + error!("Stopped tracking container, but it was not in the tracking set"); } } diff --git a/ui/src/main.rs b/ui/src/main.rs index 8decaaad..42eb0b05 100644 --- a/ui/src/main.rs +++ b/ui/src/main.rs @@ -8,12 +8,17 @@ use std::{ net::SocketAddr, path::{Path, PathBuf}, sync::Arc, + time::Duration, }; use tracing::{error, info, warn}; use tracing_subscriber::EnvFilter; const DEFAULT_ADDRESS: &str = "127.0.0.1"; const DEFAULT_PORT: u16 = 5000; + +const DEFAULT_WEBSOCKET_IDLE_TIMEOUT: Duration = Duration::from_secs(60); +const DEFAULT_WEBSOCKET_SESSION_TIMEOUT: Duration = Duration::from_secs(45 * 60); + const DEFAULT_COORDINATORS_LIMIT: usize = 25; const DEFAULT_PROCESSES_LIMIT: usize = 10; @@ -50,6 +55,7 @@ struct Config { metrics_token: Option, feature_flags: FeatureFlags, request_db_path: Option, + websocket_config: WebSocketConfig, limits: Arc, port: u16, root: PathBuf, @@ -108,6 +114,23 @@ impl Config { let request_db_path = env::var_os("PLAYGROUND_REQUEST_DATABASE").map(Into::into); + let websocket_config = { + let idle_timeout = env::var("PLAYGROUND_WEBSOCKET_IDLE_TIMEOUT_S") + .ok() + .and_then(|l| l.parse().map(Duration::from_secs).ok()) + .unwrap_or(DEFAULT_WEBSOCKET_IDLE_TIMEOUT); + + let session_timeout = env::var("PLAYGROUND_WEBSOCKET_SESSION_TIMEOUT_S") + .ok() + .and_then(|l| l.parse().map(Duration::from_secs).ok()) + .unwrap_or(DEFAULT_WEBSOCKET_SESSION_TIMEOUT); + + WebSocketConfig { + idle_timeout, + session_timeout, + } + }; + let coordinators_limit = env::var("PLAYGROUND_COORDINATORS_LIMIT") .ok() .and_then(|l| l.parse().ok()) @@ -131,6 +154,7 @@ impl Config { metrics_token, feature_flags, request_db_path, + websocket_config, limits, port, root, @@ -232,3 +256,9 @@ impl limits::Lifecycle for LifecycleMetrics { metrics::PROCESS_ACTIVE.dec(); } } + +#[derive(Debug, Copy, Clone)] +struct WebSocketConfig { + idle_timeout: Duration, + session_timeout: Duration, +} diff --git a/ui/src/server_axum.rs b/ui/src/server_axum.rs index bbe2cad8..6d01422b 100644 --- a/ui/src/server_axum.rs +++ b/ui/src/server_axum.rs @@ -5,7 +5,7 @@ use crate::{ UNAVAILABLE_WS, }, request_database::Handle, - Config, GhToken, MetricsToken, + Config, GhToken, MetricsToken, WebSocketConfig, }; use async_trait::async_trait; use axum::{ @@ -111,7 +111,8 @@ pub(crate) async fn serve(config: Config) { .layer(Extension(db_handle)) .layer(Extension(Arc::new(SandboxCache::default()))) .layer(Extension(config.github_token())) - .layer(Extension(config.feature_flags)); + .layer(Extension(config.feature_flags)) + .layer(Extension(config.websocket_config)); if let Some(token) = config.metrics_token() { app = app.layer(Extension(token)) @@ -652,11 +653,12 @@ async fn metrics(_: MetricsAuthorization) -> Result, StatusCode> { async fn websocket( ws: WebSocketUpgrade, + Extension(config): Extension, Extension(factory): Extension, Extension(feature_flags): Extension, Extension(db): Extension, ) -> impl IntoResponse { - ws.on_upgrade(move |s| websocket::handle(s, factory.0, feature_flags.into(), db)) + ws.on_upgrade(move |s| websocket::handle(s, config, factory.0, feature_flags.into(), db)) } #[derive(Debug, serde::Deserialize)] diff --git a/ui/src/server_axum/websocket.rs b/ui/src/server_axum/websocket.rs index 1b9892b5..1895c8ac 100644 --- a/ui/src/server_axum/websocket.rs +++ b/ui/src/server_axum/websocket.rs @@ -2,6 +2,7 @@ use crate::{ metrics::{self, record_metric, Endpoint, HasLabelsCore, Outcome}, request_database::Handle, server_axum::api_orchestrator_integration_impls::*, + WebSocketConfig, }; use axum::extract::ws::{Message, WebSocket}; @@ -28,7 +29,7 @@ use tokio::{ time, }; use tokio_util::sync::CancellationToken; -use tracing::{error, instrument, warn, Instrument}; +use tracing::{error, info, instrument, warn, Instrument}; #[derive(Debug, serde::Deserialize, serde::Serialize)] #[serde(rename_all = "camelCase")] @@ -199,6 +200,7 @@ struct ExecuteResponse { #[instrument(skip_all, fields(ws_id))] pub(crate) async fn handle( socket: WebSocket, + config: WebSocketConfig, factory: Arc, feature_flags: FeatureFlags, db: Handle, @@ -210,9 +212,11 @@ pub(crate) async fn handle( let id = WEBSOCKET_ID.fetch_add(1, Ordering::SeqCst); tracing::Span::current().record("ws_id", &id); + info!("WebSocket started"); - handle_core(socket, factory, feature_flags, db).await; + handle_core(socket, config, factory, feature_flags, db).await; + info!("WebSocket ending"); metrics::LIVE_WS.dec(); let elapsed = start.elapsed(); metrics::DURATION_WS.observe(elapsed.as_secs_f64()); @@ -240,9 +244,6 @@ struct CoordinatorManager { } impl CoordinatorManager { - const IDLE_TIMEOUT: Duration = Duration::from_secs(60); - const SESSION_TIMEOUT: Duration = Duration::from_secs(45 * 60); - const N_PARALLEL: usize = 2; const N_KINDS: usize = 1; @@ -341,6 +342,7 @@ type CoordinatorManagerResult = std::result::Res async fn handle_core( mut socket: WebSocket, + config: WebSocketConfig, factory: Arc, feature_flags: FeatureFlags, db: Handle, @@ -361,7 +363,7 @@ async fn handle_core( } let mut manager = CoordinatorManager::new(&factory); - let mut session_timeout = pin!(time::sleep(CoordinatorManager::SESSION_TIMEOUT)); + let mut session_timeout = pin!(time::sleep(config.session_timeout)); let mut idle_timeout = pin!(Fuse::terminated()); let mut active_executions = BTreeMap::new(); @@ -407,7 +409,7 @@ async fn handle_core( // The last task has completed which means we are a // candidate for idling in a little while. if manager.is_empty() { - idle_timeout.set(time::sleep(CoordinatorManager::IDLE_TIMEOUT).fuse()); + idle_timeout.set(time::sleep(config.idle_timeout).fuse()); } let (error, meta) = match task {