diff --git a/Cargo.lock b/Cargo.lock index 25fb1fe..8b5f4f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,6 +471,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + [[package]] name = "deflate64" version = "0.1.9" @@ -774,7 +780,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 1.9.3", "slab", "tokio", @@ -841,6 +847,16 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -848,7 +864,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -875,7 +891,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -1465,6 +1481,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "proc-macro2" version = "1.0.94" @@ -1502,6 +1527,35 @@ version = "5.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.2", +] + [[package]] name = "rayon" version = "1.9.0" @@ -1554,7 +1608,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-tls", @@ -1944,6 +1998,7 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "tokio", + "tokio-tungstenite", "tokio-util", "toml 0.7.8", "walkdir", @@ -2110,6 +2165,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.8" @@ -2215,6 +2282,23 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http 1.4.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror 2.0.12", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" @@ -2292,6 +2376,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.2" @@ -2762,6 +2852,26 @@ version = "1.0.0-rc.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1367295b8f788d371ce2dbc842c7b709c73ee1364d30351dd300ec2203b12377" +[[package]] +name = "zerocopy" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 1047bf2..3a89060 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,10 @@ tokio = { version = "1.44", features = [ "macros", "process", "rt-multi-thread", + "net", ] } tokio-util = { version = "0.7", features = ["io"] } +tokio-tungstenite = "0.28" async-compression = { version = "0.4", features = ["futures-io", "gzip"] } # parsers, serializations, and other data processing diff --git a/src/main.rs b/src/main.rs index b063f35..d172b22 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,8 @@ #![allow(dead_code)] -use std::{env, io}; +use std::env; use std::path::PathBuf; +use std::sync::Arc; use clap::Parser; use cli::{ExternSubcommand, InitSubcommand}; @@ -11,7 +12,6 @@ use once_cell::sync::Lazy; use project::error::ProjectError; use project::ProjectKind; use ts::error::ApiError; -use ts::v1::models::ecosystem::GamePlatform; use wildmatch::WildMatch; use crate::cli::{Args, Commands, ListSubcommand}; @@ -23,8 +23,7 @@ use crate::package::Package; use crate::project::lock::LockFile; use crate::project::overrides::ProjectOverrides; use crate::project::Project; -use crate::ts::experimental; -use crate::ui::reporter::IndicatifReporter; +use crate::ui::progress::{self, TerminalSink}; mod cli; mod config; @@ -155,12 +154,11 @@ async fn main() -> Result<(), Error> { sync, } => { ts::init_repository("https://thunderstore.io", None); - - let reporter = Box::new(IndicatifReporter); + progress::set_sink(Arc::new(TerminalSink::new())); let project = Project::open(&project_path)?; project.add_packages(&packages[..])?; - project.commit(reporter, sync).await?; + project.commit(sync).await?; Ok(()) } @@ -170,11 +168,11 @@ async fn main() -> Result<(), Error> { sync, } => { ts::init_repository("https://thunderstore.io", None); - let reporter = Box::new(IndicatifReporter); + progress::set_sink(Arc::new(TerminalSink::new())); let project = Project::open(&project_path)?; project.remove_packages(&packages[..])?; - project.commit(reporter, sync).await?; + project.commit(sync).await?; Ok(()) } @@ -356,9 +354,7 @@ async fn main() -> Result<(), Error> { } }, Commands::Server { project_path } => { - let read = io::stdin(); - let write = io::stdout(); - server::spawn(read, write, &project_path).await?; + server::spawn_stdio(&project_path).await?; Ok(()) } diff --git a/src/package/install/api.rs b/src/package/install/api.rs index 322e084..f073d4c 100644 --- a/src/package/install/api.rs +++ b/src/package/install/api.rs @@ -20,7 +20,7 @@ pub enum FileAction { } #[derive(Serialize, Deserialize, Clone, Debug)] -pub struct TrackedFile { +pub struct LinkedFile { pub action: FileAction, pub path: PathBuf, pub context: Option, @@ -48,7 +48,7 @@ pub enum Request { package_dir: PathBuf, state_dir: PathBuf, staging_dir: PathBuf, - tracked_files: Vec, + tracked_files: Vec, }, StartGame { mods_enabled: bool, @@ -68,7 +68,7 @@ pub enum Response { protocol: Version, }, PackageInstall { - tracked_files: Vec, + tracked_files: Vec, post_hook_context: Option, }, PackageUninstall { diff --git a/src/package/install/bepinex.rs b/src/package/install/bepinex.rs new file mode 100644 index 0000000..5a11c4f --- /dev/null +++ b/src/package/install/bepinex.rs @@ -0,0 +1,161 @@ +use std::fs; +use std::path::Path; + +use walkdir::WalkDir; + +use crate::error::Error; +use crate::package::install::tracked::TrackedFs; +use crate::package::install::PackageInstaller; +use crate::package::Package; + +pub struct BpxInstaller { + fs: T, +} + +impl BpxInstaller { + pub fn new(fs: T) -> Self { + BpxInstaller { fs } + } +} + +impl PackageInstaller for BpxInstaller { + async fn install_package( + &mut self, + package: &Package, + package_dir: &Path, + state_dir: &Path, + staging_dir: &Path, + is_modloader: bool, + ) -> Result<(), Error> { + if is_modloader { + // Figure out the root bepinex directory. This should, in theory, always be the folder + // that contains the winhttp.dll binary. + let bepinex_root = WalkDir::new(package_dir) + .into_iter() + .filter_map(|x| x.ok()) + .filter(|x| x.path().is_file()) + .find(|x| x.path().file_name().unwrap() == "winhttp.dll") + .expect("Failed to find winhttp.dll within BepInEx directory."); + let bepinex_root = bepinex_root.path().parent().unwrap(); + + let bep_dir = bepinex_root.join("BepInEx"); + let bep_dst = state_dir.join("BepInEx"); + + self.fs.dir_copy(&bep_dir, &bep_dst).await.unwrap(); + + // Install top-level doorstop files. + let files = fs::read_dir(bepinex_root) + .unwrap() + .filter_map(|x| x.ok()) + .filter(|x| x.path().is_file()); + + for file in files { + let dest = staging_dir.join(file.path().file_name().unwrap()); + self.fs.file_copy(&file.path(), &dest, None).await?; + } + + return Ok(()); + } + + let state_dir = state_dir.canonicalize()?; + let full_name= format!("{}-{}", package.identifier.namespace, package.identifier.name); + + let targets = vec![ + ("plugins", true), + ("patchers", true), + ("monomod", true), + ("config", false), + ].into_iter() + .map(|(x, y)| (Path::new(x), y)); + + let default = state_dir.join("BepInEx/plugins"); + for (target, relocate) in targets { + // Packages may either have the target at their tld or BepInEx/target. + let src = match package_dir.join("BepInEx").exists() { + true => package_dir.join("BepInEx").join(target), + false => package_dir.join(target), + }; + + // let src = package_dir.join(target); + let dest = state_dir.join("BepInEx").join(target); + + if !src.exists() { + continue; + } + + if !dest.exists() { + fs::create_dir_all(&dest)?; + } + + // Copy the directory contents of the target into the destination. + let entries = fs::read_dir(&src)? + .filter_map(|x| x.ok()); + + for entry in entries { + let entry = entry.path(); + + let entry_dest = match relocate { + true => dest.join(&full_name).join(entry.file_name().unwrap()), + false => dest.join(entry.file_name().unwrap()), + }; + + let entry_parent = entry_dest.parent().unwrap(); + + if !entry_parent.is_dir() { + fs::create_dir_all(entry_parent)?; + } + + if entry.is_dir(){ + self.fs.dir_copy(&entry, &entry_dest).await?; + } + + if entry.is_file() { + self.fs.file_copy(&entry, &entry_dest, None).await?; + } + } + } + + // Copy top-level files into the plugin directory. + let tl_files = fs::read_dir(package_dir)? + .filter_map(|x| x.ok()) + .filter(|x| x.path().is_file()); + + for file in tl_files { + let parent = default.join(&full_name); + let dest = parent.join(file.file_name()); + + if !parent.exists() { + fs::create_dir_all(&parent)?; + } + + self.fs.file_copy(&file.path(), &dest, None).await?; + } + + Ok(()) + } + + async fn uninstall_package( + &mut self, + _package: &Package, + _package_dir: &Path, + _state_dir: &Path, + _staging_dir: &Path, + _is_modloader: bool, + ) -> Result<(), Error> { + todo!() + } + + async fn start_game( + _mods_enabled: bool, + _state_dir: &Path, + _game_dir: &Path, + _game_exe: &Path, + _args: Vec, + ) -> Result { + todo!() + } + + fn extract_state(self) -> crate::project::state::StateEntry { + self.fs.extract_state() + } +} diff --git a/src/package/install/mod.rs b/src/package/install/mod.rs index ba45745..5a201f9 100644 --- a/src/package/install/mod.rs +++ b/src/package/install/mod.rs @@ -1,265 +1,95 @@ -use std::env; -use std::fs; -use std::path::{Path, PathBuf}; -use std::process::Stdio; +use std::collections::{HashMap, HashSet}; +use std::path::Path; -use colored::Colorize; -use tokio::io::AsyncReadExt; -use tokio::process::Command; - -use self::api::Response; -use self::api::PROTOCOL_VERSION; -use self::api::{Request, TrackedFile}; -use self::manifest::InstallerManifest; -use super::error::PackageError; -use super::Package; use crate::error::Error; -use crate::error::IoError; -use crate::ui::reporter::{Progress, ProgressBarTrait, VoidProgress}; +use crate::game::ecosystem; +use crate::package::install::bepinex::BpxInstaller; +use crate::package::install::tracked::TrackedFs; +use crate::package::Package; +use crate::project::state::StateEntry; +use crate::ts::package_reference::PackageReference; +use crate::ts::v1::models::ecosystem::R2MLLoader; pub mod api; mod legacy_compat; pub mod manifest; - -pub struct Installer { - pub exec_path: PathBuf, -} - -impl Installer { - /// Loads the given package as an Installer and prepares it for execution. - /// Note that cached installers can skip the prepare step. - pub async fn load_and_prepare(package: &Package) -> Result { - // Temp, we'll figure out a good solution from the progress reporter later. - let test = VoidProgress {}; - let cache_dir = match package.get_path().await { - Some(x) => x, - None => package.download(test.add_bar().as_ref()).await?, - }; - - let manifest = { - let path = cache_dir.join("installer.json"); - if !path.is_file() { - Err(PackageError::InstallerNoManifest)? - } else { - let contents = fs::read_to_string(path)?; - serde_json::from_str::(&contents)? - } - }; - - // Determine the absolute path of the installer's executable based on the current architecture. - let current_arch = env::consts::ARCH; - let current_os = env::consts::OS; - - let matrix = manifest - .matrix - .iter() - .find(|x| { - x.architecture.to_string() == current_arch && x.target_os.to_string() == current_os - }) - .ok_or(PackageError::InstallerNotExecutable)?; - - let exec_path = { - let abs = cache_dir.join(&matrix.executable); - - if abs.is_file() { - Ok(abs) - } else { - Err(IoError::FileNotFound(abs)) - } - }?; - - let installer = Installer { exec_path }; - - // Validate that the installer is (a) executable and (b) is using a valid protocol version. - let response = installer.run(&Request::Version).await?; - let Response::Version { - author: _, - identifier: _, - protocol, - } = response - else { - Err(PackageError::InstallerBadResponse { - package_id: package.identifier.to_string(), - message: "The installer did not respond with a valid or otherwise serializable Version response variant.".to_string(), - })? - }; - - if protocol.major != PROTOCOL_VERSION.major { - Err(PackageError::InstallerBadVersion { - package_id: package.identifier.to_string(), - given_version: protocol, - our_version: PROTOCOL_VERSION, - })? - } - - Ok(installer) - } - - pub fn override_new() -> Self { - let override_installer = PathBuf::from(std::env::var("TCLI_INSTALLER_OVERRIDE").unwrap()); - - if !override_installer.is_file() { - panic!( - "TCLI_INSTALLER_OVERRIDE is set to {}, which does not point to a file that actually exists.", override_installer.to_str().unwrap() - ) - } - - Installer { - exec_path: override_installer, - } - } - - pub async fn install_package( - &self, +pub mod bepinex; +pub mod tracked; + +pub trait PackageInstaller { + /// Install a package into this profile. + /// + /// `state_dir` is the directory that is "linked" to at runtime by the modloader. + /// `staging_dir` is the directory that contains files that are directly installed into the game directory. + async fn install_package( + &mut self, package: &Package, package_dir: &Path, state_dir: &Path, staging_dir: &Path, - reporter: &dyn ProgressBarTrait, - ) -> Result, Error> { - // Determine if the package is a modloader or not. - let is_modloader = package.identifier.name.to_lowercase().contains("bepinex"); - - let request = Request::PackageInstall { - is_modloader, - package: package.identifier.clone(), - package_deps: package.dependencies.clone(), - package_dir: package_dir.to_path_buf(), - state_dir: state_dir.to_path_buf(), - staging_dir: staging_dir.to_path_buf(), - }; + is_modloader: bool, + ) -> Result<(), Error>; - let progress_message = format!( - "{}-{} {}", - package.identifier.namespace.bold(), - package.identifier.name.bold(), - package.identifier.version.to_string().truecolor(90, 90, 90) - ); - reporter.set_message(format!("Installing {progress_message}")); - - let response = self.run(&request).await?; - match response { - Response::PackageInstall { - tracked_files, - post_hook_context: _, - } => Ok(tracked_files), - - Response::Error { message } => Err(PackageError::InstallerError { message })?, - - x => { - let message = - format!("Didn't recieve one of the expected variants: Response::PackageInstall or Response::Error. Got: {x:#?}"); - - Err(PackageError::InstallerBadResponse { - package_id: package.identifier.to_string(), - message, - })? - } - } - } - - pub async fn uninstall_package( - &self, + /// Uninstall a package from this profile. + async fn uninstall_package( + &mut self, package: &Package, package_dir: &Path, state_dir: &Path, staging_dir: &Path, - tracked_files: Vec, - reporter: &dyn ProgressBarTrait, - ) -> Result<(), Error> { - let is_modloader = package.identifier.name.to_lowercase().contains("bepinex"); - let request = Request::PackageUninstall { - is_modloader, - package: package.identifier.clone(), - package_deps: package.dependencies.clone(), - package_dir: package_dir.to_path_buf(), - state_dir: state_dir.to_path_buf(), - staging_dir: staging_dir.to_path_buf(), - tracked_files, - }; - - let progress_message = format!( - "{}-{} {}", - package.identifier.namespace.bold(), - package.identifier.name.bold(), - package.identifier.version.to_string().truecolor(90, 90, 90) - ); - reporter.set_message(format!("Uninstalling {progress_message}")); - - let response = self.run(&request).await?; - match response { - Response::PackageUninstall { - post_hook_context: _, - } => Ok(()), - Response::Error { message } => Err(PackageError::InstallerError { message })?, - x => { - let message = - format!("Didn't recieve one of the expected variants: Response::PackageInstall or Response::Error. Got: {x:#?}"); - - Err(PackageError::InstallerBadResponse { - package_id: package.identifier.to_string(), - message, - })? - } - } - } + is_modloader: bool, + ) -> Result<(), Error>; - /// Start the game and drop a PID file in the state directory of the current project. - pub async fn start_game( - &self, + /// Start the game. + async fn start_game( mods_enabled: bool, state_dir: &Path, game_dir: &Path, game_exe: &Path, args: Vec, - ) -> Result { - let request = Request::StartGame { - mods_enabled, - project_state: state_dir.to_path_buf(), - game_dir: game_dir.to_path_buf(), - game_exe: game_exe.to_path_buf(), - args, - }; + ) -> Result; - let response = self.run(&request).await?; - - let Response::StartGame { pid } = response else { - panic!("Invalid response."); - }; + /// Extract the tracked state from this installer, consuming it. + fn extract_state(self) -> StateEntry; +} - Ok(pid) +/// Get the proper installer for the provided modloader variant. +pub fn get_installer(ml_variant: &R2MLLoader, fs: T) -> impl PackageInstaller { + match ml_variant { + R2MLLoader::BepInEx => BpxInstaller::new(fs), + _ => panic!("Support for modloader {ml_variant:?} has not been implemented."), } +} - pub async fn run(&self, arg: &Request) -> Result { - let args_json = serde_json::to_string(arg)?; - - let child = Command::new(&self.exec_path) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .arg(&args_json) - .spawn()?; - - // Execute the installer, capturing and deserializing any output. - // TODO: Safety check here to warn / stop an installer from blowing up the heap. - let mut output_str = String::new(); - child - .stdout - .unwrap() - .read_to_string(&mut output_str) - .await?; - - let mut err_str = String::new(); - child.stderr.unwrap().read_to_string(&mut err_str).await?; - - if !err_str.is_empty() { - println!("installer stderr:"); - println!("{err_str}"); - } - - // println!("installer stdout:"); - // println!("{output_str}"); +/// Determine the modloader to use for the given packages. +pub async fn guess_modloader(packages: &[PackageReference]) -> Option { + let schema = ecosystem::get_schema().await.ok()?; + let ml: HashMap = schema + .modloader_packages + .into_iter() + .map(|x| (x.package_id, x.loader)) + .collect(); + + packages + .iter() + .find_map(|x| ml.get(&x.to_loose_ident_string()).cloned()) +} - let response = serde_json::from_str(&output_str)?; - Ok(response) - } +/// Determine which packages are modloaders. +pub async fn get_modloader_packages(packages: &[PackageReference]) -> HashSet { + let Ok(schema) = ecosystem::get_schema().await else { + return HashSet::new(); + }; + + let ml_ids: HashSet = schema + .modloader_packages + .into_iter() + .map(|x| x.package_id) + .collect(); + + packages + .iter() + .filter(|x| ml_ids.contains(&x.to_loose_ident_string())) + .map(|x| x.to_loose_ident_string()) + .collect() } diff --git a/src/package/install/tracked.rs b/src/package/install/tracked.rs new file mode 100644 index 0000000..bb02008 --- /dev/null +++ b/src/package/install/tracked.rs @@ -0,0 +1,106 @@ +use std::path::Path; +use tokio::fs; +use walkdir::WalkDir; + +use crate::package::install::api::{FileAction, LinkedFile}; +use crate::project::state::{StagedFile, StateEntry}; + +use crate::error::Error; + +pub trait TrackedFs { + /// Create a new instance dedicated to tracking filesystem edits during the + /// installation of the provided package. + /// + /// This essentially creates or opens the cooresponding entry within the + /// tracked_files.json file and writes any tracked fs modifications to it. + fn new(state: StateEntry) -> Self; + + /// Extract the new StateEntry from this instance. + fn extract_state(self) -> StateEntry; + + /// Copy a file from a source to a destination, overwriting it if the file + /// already exists. + /// + /// This will append (or overwrite) a FileAction::Create entry. + async fn file_copy(&mut self, src: &Path, dst: &Path, stage_dst: Option<&Path>) -> Result<(), Error>; + + /// Delete some target file. + /// + /// If `tracked` is set this this will append a FileAction::Delete entry, + /// overwriting one if it already exists for this file. + async fn file_delete(&mut self, target: &Path, tracked: bool); + + /// Recursively copy a source directory to a destination, overwriting it if + /// it already exists. + /// + /// This will append (or overwrite) a FileAction::Create entry for each file + /// copied while recursing. + async fn dir_copy(&mut self, src: &Path, dst: &Path) -> Result<(), Error>; + + /// Recursively delete some target directory. + /// + /// If `tracked` if set then this will append a FileAction::Delete entry + /// for each file deleted while recursing, otherwise matching entries are + /// deleted. + async fn dir_delete(&mut self, target: &Path, tracked: bool); +} + +#[derive(Debug)] +pub struct ConcreteFs { + state: StateEntry, +} + +impl TrackedFs for ConcreteFs { + fn new(state: StateEntry) -> Self { + ConcreteFs { + state + } + } + + fn extract_state(self) -> StateEntry { + self.state + } + + async fn file_copy(&mut self, src: &Path, dst: &Path, stage_dst: Option<&Path>) -> Result<(), Error> { + fs::copy(src, dst).await?; + let tracked = LinkedFile { action: FileAction::Create, path: dst.to_path_buf(), context: None }; + + if let Some(stage_dst) = stage_dst { + let mut staged = StagedFile::new(tracked)?; + staged.dest.push(stage_dst.to_path_buf()); + self.state.add_staged(staged, false); + } else { + self.state.add_linked(tracked, false); + } + + Ok(()) + } + + async fn file_delete(&mut self, _target: &Path, _tracked: bool) { + todo!() + } + + async fn dir_copy(&mut self, src: &Path, dst: &Path) -> Result<(), Error> { + let files = WalkDir::new(&src) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|x| x.path().is_file()); + + for file in files { + let dest = dst.join(file.path().strip_prefix(&src).unwrap()); + let dest_parent = dest.parent().unwrap(); + + if !dest_parent.is_dir() { + fs::create_dir_all(dest_parent).await?; + } + + self.file_copy(file.path(), &dest, None).await?; + } + + Ok(()) + } + + async fn dir_delete(&mut self, _target: &Path, _tracked: bool) { + todo!() + } +} diff --git a/src/package/mod.rs b/src/package/mod.rs index 72545d1..0ddebd3 100644 --- a/src/package/mod.rs +++ b/src/package/mod.rs @@ -21,7 +21,7 @@ use crate::error::{Error, IoError, IoResultToTcli}; use crate::ts::package_manifest::PackageManifestV1; use crate::ts::package_reference::PackageReference; use crate::ts::{self, CLIENT}; -use crate::ui::reporter::ProgressBarTrait; +use crate::ui::progress; use crate::TCLI_HOME; #[derive(Serialize, Deserialize, Debug)] @@ -165,7 +165,7 @@ impl Package { })) } - pub async fn download(&self, reporter: &dyn ProgressBarTrait) -> Result { + pub async fn download(&self) -> Result { let PackageSource::Remote(package_source) = &self.source else { panic!("Invalid use, this is a local package.") }; @@ -173,41 +173,36 @@ impl Package { let output_path = cache::get_cache_location(&self.identifier); if output_path.is_dir() { - reporter.finish(); return Ok(output_path); } + let pkg_id = self.identifier.to_string(); let download_result = CLIENT.get(package_source).send().await.unwrap(); - let download_size = download_result.content_length().unwrap(); + let download_size = download_result.content_length().unwrap_or(0); - let progress_message = format!( - "{}-{} ({})", - self.identifier.namespace.bold(), - self.identifier.name.bold(), - self.identifier.version.to_string().truecolor(90, 90, 90) - ); - - reporter.set_length(download_size); - reporter.set_message(format!("Downloading {progress_message}...")); + progress::scope_progress(&pkg_id, 0, Some("downloading")); let mut download_stream = download_result.bytes_stream(); let mut temp_file = cache::get_temp_zip_file(&self.identifier).await?; let zip_file = temp_file.file_mut(); + let mut downloaded: u64 = 0; while let Some(chunk) = download_stream.next().await { let chunk = chunk.unwrap(); zip_file.write_all(&chunk).await.unwrap(); - - reporter.inc(chunk.len() as u64); + downloaded += chunk.len() as u64; + + if download_size > 0 { + let pct = (downloaded * 100) / download_size; + progress::scope_progress(&pkg_id, pct, None); + } } - reporter.set_message(format!("Extracting {progress_message}...")); + progress::scope_progress(&pkg_id, 100, Some("extracting")); let cache_path = add_to_cache(&self.identifier, temp_file.into_std().await.file())?; - // reporter.finish(); - Ok(cache_path) } } diff --git a/src/project/mod.rs b/src/project/mod.rs index 834b716..2bd68d8 100644 --- a/src/project/mod.rs +++ b/src/project/mod.rs @@ -1,4 +1,4 @@ -use std::borrow::Borrow; + use std::collections::HashMap; use std::fs; use std::fs::File; @@ -6,7 +6,7 @@ use std::io::{ErrorKind, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use colored::Colorize; + use error::ProjectError; use futures::future::try_join_all; pub use publish::publish; @@ -18,24 +18,24 @@ use crate::error::{Error, IoError, IoResultToTcli}; use crate::game::registry::GameData; use crate::game::{proc, registry}; use crate::package::index::PackageIndex; -use crate::package::install::api::TrackedFile; -use crate::package::install::Installer; +use crate::package::install::tracked::{ConcreteFs, TrackedFs}; +use crate::package::install::{self, PackageInstaller}; use crate::package::resolver::DependencyGraph; use crate::package::{resolver, Package}; use crate::project::manifest::ProjectManifest; use crate::project::overrides::ProjectOverrides; -use crate::project::state::{StagedFile, StateFile}; +use crate::project::state::{StateEntry, StateFile}; use crate::ts::package_manifest::PackageManifestV1; use crate::ts::package_reference::PackageReference; -use crate::ui::reporter::{Progress, Reporter}; +use crate::ui::progress; use crate::{util, TCLI_HOME}; pub mod error; pub mod lock; pub mod manifest; pub mod overrides; -mod publish; -mod state; +pub mod publish; +pub mod state; pub enum ProjectKind { Dev(ProjectOverrides), @@ -289,85 +289,78 @@ impl Project { async fn install_packages( &self, - installer: &Installer, statefile: &mut StateFile, packages: Vec<&PackageReference>, - multi: &dyn Progress, + all_resolved: &[PackageReference], ) -> Result<(), Error> { - let packages = try_join_all( - packages - .into_iter() - .map(|x| async move { Package::from_any(x).await }), - ) - .await?; - - let sem = Arc::new(Semaphore::new(5)); + // Determine the modloader using the full resolved package list. + let modloader = install::guess_modloader(all_resolved) + .await + .expect("Could not determine modloader. Ensure a modloader package is in your dependencies."); + let modloader_packages = install::get_modloader_packages(all_resolved).await; - let jobs = packages.into_iter().map(|package| async { - let _permit = sem.acquire().await.unwrap(); - - let bar = multi.add_bar(); - let bar = bar.as_ref(); - - // Resolve the package, either downloading it or returning its cached path. - let package_dir = match package.get_path().await { - Some(x) => x, - None => package.download(bar).await?, - }; - let tracked_files = installer - .install_package( - &package, - &package_dir, - &self.state_dir, - &self.staging_dir, - bar, - ) - .await; - - let finished_msg = match tracked_files { - Ok(_) => format!( - "{} Installed {}-{} {}", - "[✓]".green(), - package.identifier.namespace.bold(), - package.identifier.name.bold(), - package.identifier.version.to_string().truecolor(90, 90, 90) - ), - Err(ref e) => format!( - "{} Error {}-{} {}\n\t{}", - "[x]".red(), - package.identifier.namespace.bold(), - package.identifier.name.bold(), - package.identifier.version.to_string().truecolor(90, 90, 90), - e, - ), - }; + let packages = packages + .into_iter() + .map(|x| async move { Package::from_any(x).await }); - bar.println(&finished_msg); + let sem = Arc::new(Semaphore::new(5)); - tracked_files.map(|x| (package.identifier, x)) + let jobs = packages.into_iter().map(|package| { + let modloader = modloader.clone(); + let modloader_packages = modloader_packages.clone(); + let sem = sem.clone(); + + async move { + let _permit = sem.acquire().await.unwrap(); + let package = package.await?; + let pkg_id = package.identifier.to_string(); + let is_modloader = modloader_packages.contains(&package.identifier.to_loose_ident_string()); + + progress::scope_start_child(&pkg_id, "install", &package.identifier.name); + + // Resolve the package, either downloading it or returning its cached path. + progress::scope_progress(&pkg_id, 0, Some("resolving")); + let package_dir = match package.get_path().await { + Some(x) => x, + None => package.download().await?, + }; + + let mut installer = install::get_installer(&modloader, ConcreteFs::new(StateEntry::default())); + + progress::scope_progress(&pkg_id, 0, Some("installing")); + let install_result = installer + .install_package( + &package, + &package_dir, + &self.state_dir, + &self.staging_dir, + is_modloader, + ) + .await; + + // On success, extract the tracked state and return it with the package id. + match install_result { + Ok(_) => { + progress::scope_complete(&pkg_id); + let state = installer.extract_state(); + Ok((package.identifier, state)) + } + Err(e) => { + progress::scope_fail(&pkg_id, e.to_string()); + Err(e) + } + } + } }); - let tracked_files = try_join_all(jobs) - .await? - .into_iter() - .collect::)>>(); - - // Iterate through each installed mod, separate tracked files into "link" and "stage" variants. - // TODO: Make this less hacky, we shouldn't be relying on path ops to determine this. - for (package, tracked_files) in tracked_files { - let staged_files = tracked_files - .iter() - .filter(|x| x.path.starts_with(&self.staging_dir)) - .map(|x| StagedFile::new(x.clone())) - .collect::, _>>()?; - - let linked_files = tracked_files - .into_iter() - .filter(|x| x.path.starts_with(&self.state_dir)); + let results = try_join_all(jobs).await?; - let group = statefile.state.entry(package).or_default(); - group.staged.extend(staged_files); - group.linked.extend(linked_files); + // Merge tracked files into the statefile. + for (package_id, state_entry) in results { + // Merge the new state with any existing state for this package. + let existing = statefile.entry(package_id); + existing.staged.extend(state_entry.staged); + existing.linked.extend(state_entry.linked); } Ok(()) @@ -375,10 +368,8 @@ impl Project { async fn uninstall_packages( &self, - installer: &Installer, statefile: &mut StateFile, packages: Vec<&PackageReference>, - multi: &dyn Progress, ) -> Result<(), Error> { let packages = try_join_all( packages @@ -387,65 +378,39 @@ impl Project { ) .await?; - // Uninstall each package in parallel. - try_join_all(packages.iter().map(|package| async { - let bar = multi.add_bar(); - let bar = bar.as_ref(); + // For each package to uninstall: + // 1. Remove all staged files that were copied to game directories + // 2. Remove all linked files from the state directory + // 3. Remove the package's entry from the statefile + for package in packages { + let pkg_id = package.identifier.to_string(); + progress::scope_start_child(&pkg_id, "uninstall", &package.identifier.name); - let package_dir = match package.get_path().await { - Some(x) => x, - None => package.download(bar).await?, + let Some(entry) = statefile.get(&package.identifier) else { + progress::scope_complete(&pkg_id); + continue; }; - let state_entry = statefile.state.get(&package.identifier); - let tracked_files = state_entry - .map_or(vec![], |x| x.staged.clone()) - .into_iter() - .map(|x| x.action) - .chain(state_entry.map_or(vec![], |x| x.linked.clone())) - .collect::>(); - - installer - .uninstall_package( - package, - &package_dir, - &self.state_dir, - &self.staging_dir, - tracked_files, - bar, - ) - .await - })) - .await?; - - // Run post-uninstall cleanup / validation ops. - // 1. Invalidate staged files, removing them from the statefile if they no longer exist. - // 2. Cleanup empty directories within staging and state. - // 3. Remove uninstalled / invalidated entries from the statefile. - for package in packages { - let entry = statefile.state.get(&package.identifier).unwrap(); - let staged = &entry.staged; - - // Determine the list of entries that will be invalidated. - let invalid_staged_entries = staged.iter().filter(|x| !x.action.path.is_file()); - - for staged_entry in invalid_staged_entries { - // Each dest is checked if it (a) exists and (b) is the same as orig. - let dests_to_remove = - staged_entry.dest.iter().filter_map(|path| { - match staged_entry.is_same_as(path) { - Ok(x) if x => Some(Ok(path)), - Ok(_) => None, - Err(e) => Some(Err(e)), - } - }); - - for dest in dests_to_remove { - fs::remove_file(dest?)?; + // Remove staged file destinations (files copied to game dir at launch) + for staged in &entry.staged { + for dest in &staged.dest { + // Only remove if the file still matches what we installed + if let Ok(true) = staged.is_same_as(dest) { + let _ = fs::remove_file(dest); + } } + // Remove the source file in staging dir + let _ = fs::remove_file(&staged.file.path); + } + + // Remove linked files from state dir + for linked in &entry.linked { + let _ = fs::remove_file(&linked.path); } - statefile.state.remove(&package.identifier); + // Remove package from statefile + statefile.remove(&package.identifier); + progress::scope_complete(&pkg_id); } // Cleanup empty directories in the state and staging dirs. @@ -456,43 +421,46 @@ impl Project { } /// Commit changes made to the project manifest to the project. - pub async fn commit(&self, reporter: Box, sync: bool) -> Result<(), Error> { + pub async fn commit(&self, sync: bool) -> Result<(), Error> { if sync { + progress::scope_start("sync", "Syncing package index"); PackageIndex::sync(&TCLI_HOME).await?; + progress::scope_complete("sync"); } + let lockfile = LockFile::open_or_new(&self.lockfile_path)?; let lockfile_graph = DependencyGraph::from_graph(lockfile.package_graph); let manifest = ProjectManifest::read_from_file(&self.manifest_path)?; let package_graph = resolver::resolve_packages(manifest.dependencies.dependencies).await?; - // Compare the lockfile and new graphs to determine the + // Get the full list of resolved packages for modloader detection. + let all_resolved: Vec<_> = package_graph.digest().into_iter().cloned().collect(); + let delta = lockfile_graph.graph_delta(&package_graph); - println!( - "{} packages will be installed, {} will be removed.", + progress::info(format!( + "{} packages to install, {} to remove", delta.add.len(), delta.del.len() - ); + )); - let installer = Installer::override_new(); let mut statefile = StateFile::open_or_new(&self.statefile_path)?; - let multi = reporter.create_progress(); - let packages_to_remove = delta.del.iter().rev().collect::>(); let packages_to_add = delta.add.iter().rev().collect::>(); - self.uninstall_packages( - &installer, - &mut statefile, - packages_to_remove, - multi.borrow(), - ) - .await?; + if !packages_to_remove.is_empty() { + progress::scope_start("uninstall", "Removing packages"); + self.uninstall_packages(&mut statefile, packages_to_remove).await?; + progress::scope_complete("uninstall"); + } - self.install_packages(&installer, &mut statefile, packages_to_add, multi.borrow()) - .await?; + if !packages_to_add.is_empty() { + progress::scope_start("install", "Installing packages"); + self.install_packages(&mut statefile, packages_to_add, &all_resolved).await?; + progress::scope_complete("install"); + } // Write the statefile with changes made during unins statefile.write(&self.statefile_path)?; @@ -507,8 +475,8 @@ impl Project { pub async fn start_game( &self, game_id: &str, - mods_enabled: bool, - args: Vec, + _mods_enabled: bool, + _args: Vec, ) -> Result<(), Error> { let game_data = registry::get_game_data(&self.game_registry_path, game_id) .ok_or_else(|| ProjectError::InvalidGameId(game_id.to_string()))?; @@ -520,7 +488,7 @@ impl Project { let staged_files = statefile.state.values_mut().flat_map(|x| &mut x.staged); for file in staged_files { - let rel = file.action.path.strip_prefix(&self.staging_dir).unwrap(); + let rel = file.file.path.strip_prefix(&self.staging_dir).unwrap(); let dest = game_dir.join(rel); if file.is_same_as(&dest)? { @@ -532,37 +500,37 @@ impl Project { fs::create_dir_all(dest_parent)?; } - fs::copy(&file.action.path, &dest)?; + fs::copy(&file.file.path, &dest)?; file.dest.push(dest); } statefile.write(&self.statefile_path)?; - let installer = Installer::override_new(); - let pid = installer - .start_game( - mods_enabled, - &self.state_dir, - &game_dist.game_dir, - &game_dist.exe_path, - args, - ) - .await?; - - // The PID file is contained within the state dir and is of name `game.exe.pid`. - let pid_path = self - .base_dir - .join(".tcli") - .join(format!("{}.pid", game_data.identifier)); - - let mut pid_file = File::create(pid_path)?; - pid_file.write_all(format!("{}", pid).as_bytes())?; - - println!( - "{} has been started with PID {}.", - game_data.display_name.green(), - pid - ); + // let installer = Installer::override_new(); + // let pid = installer + // .start_game( + // mods_enabled, + // &self.state_dir, + // &game_dist.game_dir, + // &game_dist.exe_path, + // args, + // ) + // .await?; + + // // The PID file is contained within the state dir and is of name `game.exe.pid`. + // let pid_path = self + // .base_dir + // .join(".tcli") + // .join(format!("{}.pid", game_data.identifier)); + + // let mut pid_file = File::create(pid_path)?; + // pid_file.write_all(format!("{}", pid).as_bytes())?; + + // println!( + // "{} has been started with PID {}.", + // game_data.display_name.green(), + // pid + // ); Ok(()) } diff --git a/src/project/state.rs b/src/project/state.rs index deeb47a..41f1860 100644 --- a/src/project/state.rs +++ b/src/project/state.rs @@ -6,22 +6,22 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize}; use crate::error::Error; -use crate::package::install::api::TrackedFile; +use crate::package::install::api::LinkedFile; use crate::ts::package_reference::PackageReference; use crate::util; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct StagedFile { - pub action: TrackedFile, + pub file: LinkedFile, pub dest: Vec, pub md5: String, } impl StagedFile { - pub fn new(action: TrackedFile) -> Result { - let md5 = util::file::md5(&action.path)?; + pub fn new(file: LinkedFile) -> Result { + let md5 = util::file::md5(&file.path)?; Ok(StagedFile { - action, + file, dest: vec![], md5, }) @@ -37,10 +37,85 @@ impl StagedFile { } } -#[derive(Serialize, Deserialize, Default)] +#[derive(Serialize, Deserialize, Default, Debug, Clone)] pub struct StateEntry { pub staged: Vec, - pub linked: Vec, + pub linked: Vec, +} + +impl StateEntry { + /// Add a new staged file. If overwrite is set then already existing + /// entries with the same source path will be replaced. + pub fn add_staged(&mut self, file: StagedFile, overwrite: bool) { + let existing_idx = self.staged.iter().position(|x| x.file.path == file.file.path); + + match existing_idx { + Some(idx) if overwrite => { + self.staged[idx] = file; + } + Some(_) => { + // Entry exists and overwrite is false — do nothing + } + None => { + self.staged.push(file); + } + } + } + + /// Add a new linked file. If overwrite is set then already existing + /// entries with the same path will be replaced. + pub fn add_linked(&mut self, file: LinkedFile, overwrite: bool) { + let existing_idx = self.linked.iter().position(|x| x.path == file.path); + + match existing_idx { + Some(idx) if overwrite => { + self.linked[idx] = file; + } + Some(_) => { + // Entry exists and overwrite is false — do nothing + } + None => { + self.linked.push(file); + } + } + } + + /// Remove a staged file entry by its source path. + pub fn remove_staged(&mut self, path: &Path) -> Option { + let idx = self.staged.iter().position(|x| x.file.path == path)?; + Some(self.staged.remove(idx)) + } + + /// Remove a linked file entry by its path. + pub fn remove_linked(&mut self, path: &Path) -> Option { + let idx = self.linked.iter().position(|x| x.path == path)?; + Some(self.linked.remove(idx)) + } + + /// Get a staged file by its source path. + pub fn get_staged(&self, path: &Path) -> Option<&StagedFile> { + self.staged.iter().find(|x| x.file.path == path) + } + + /// Get a mutable staged file by its source path. + pub fn get_staged_mut(&mut self, path: &Path) -> Option<&mut StagedFile> { + self.staged.iter_mut().find(|x| x.file.path == path) + } + + /// Get a linked file by its path. + pub fn get_linked(&self, path: &Path) -> Option<&LinkedFile> { + self.linked.iter().find(|x| x.path == path) + } + + /// Check if this entry has any tracked files. + pub fn is_empty(&self) -> bool { + self.staged.is_empty() && self.linked.is_empty() + } + + /// Get the total count of tracked files. + pub fn file_count(&self) -> usize { + self.staged.len() + self.linked.len() + } } #[derive(Serialize, Deserialize, Default)] @@ -53,6 +128,7 @@ impl StateFile { if !path.is_file() { let empty = StateFile::default(); empty.write(path)?; + return Ok(StateFile::default()); } let contents = fs::read_to_string(path)?; @@ -61,7 +137,7 @@ impl StateFile { Ok(statefile) } - pub fn write(self, path: &Path) -> Result<(), Error> { + pub fn write(&self, path: &Path) -> Result<(), Error> { let ser = serde_json::to_string_pretty(&self)?; let mut file = OpenOptions::new() .write(true) @@ -72,4 +148,74 @@ impl StateFile { Ok(()) } + + /// Get or create a state entry for the given package. + pub fn entry(&mut self, package: PackageReference) -> &mut StateEntry { + self.state.entry(package).or_default() + } + + /// Get the state entry for a package, if it exists. + pub fn get(&self, package: &PackageReference) -> Option<&StateEntry> { + self.state.get(package) + } + + /// Get a mutable state entry for a package, if it exists. + pub fn get_mut(&mut self, package: &PackageReference) -> Option<&mut StateEntry> { + self.state.get_mut(package) + } + + /// Remove a package's state entry entirely, returning it if it existed. + pub fn remove(&mut self, package: &PackageReference) -> Option { + self.state.remove(package) + } + + /// Check if a package has any tracked state. + pub fn contains(&self, package: &PackageReference) -> bool { + self.state.contains_key(package) + } + + /// Get all packages that have tracked state. + pub fn packages(&self) -> impl Iterator { + self.state.keys() + } + + /// Get all staged files across all packages. + pub fn all_staged(&self) -> impl Iterator { + self.state + .iter() + .flat_map(|(pkg, entry)| entry.staged.iter().map(move |f| (pkg, f))) + } + + /// Get all linked files across all packages. + pub fn all_linked(&self) -> impl Iterator { + self.state + .iter() + .flat_map(|(pkg, entry)| entry.linked.iter().map(move |f| (pkg, f))) + } + + /// Find which package owns a staged file by its source path. + pub fn find_staged_owner(&self, path: &Path) -> Option<&PackageReference> { + self.state + .iter() + .find(|(_, entry)| entry.staged.iter().any(|f| f.file.path == path)) + .map(|(pkg, _)| pkg) + } + + /// Find which package owns a linked file by its path. + pub fn find_linked_owner(&self, path: &Path) -> Option<&PackageReference> { + self.state + .iter() + .find(|(_, entry)| entry.linked.iter().any(|f| f.path == path)) + .map(|(pkg, _)| pkg) + } + + /// Remove empty entries (packages with no tracked files). + pub fn prune_empty(&mut self) { + self.state.retain(|_, entry| !entry.is_empty()); + } + + /// Get total count of tracked files across all packages. + pub fn total_file_count(&self) -> usize { + self.state.values().map(|e| e.file_count()).sum() + } } diff --git a/src/server/method/mod.rs b/src/server/method/mod.rs index dc44c04..72585eb 100644 --- a/src/server/method/mod.rs +++ b/src/server/method/mod.rs @@ -1,21 +1,12 @@ pub mod package; pub mod project; -use std::sync::RwLock; - -use futures::channel::mpsc::Sender; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use self::package::PackageMethod; use self::project::ProjectMethod; -use super::proto::Response; use super::{Error, ServerError}; -use crate::project::Project; - -pub trait Routeable { - async fn route(&self, ctx: RwLock, send: Sender>); -} #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub enum Method { diff --git a/src/server/method/package.rs b/src/server/method/package.rs index b8fafa0..676a327 100644 --- a/src/server/method/package.rs +++ b/src/server/method/package.rs @@ -4,7 +4,7 @@ use super::Error; use crate::package::cache; use crate::package::index::PackageIndex; use crate::server::proto::{Id, Response}; -use crate::server::{Runtime, ServerError}; +use crate::server::{Runtime, ServerError, Transport}; use crate::ts::package_reference::PackageReference; use crate::TCLI_HOME; @@ -14,7 +14,7 @@ pub enum PackageMethod { GetMetadata(GetMetadata), /// Determine if the package exists within the cache. IsCached(IsCached), - /// Syncronize the package index. + /// Synchronize the package index. SyncIndex, } @@ -28,20 +28,25 @@ impl PackageMethod { }) } - pub async fn route(&self, rt: &mut Runtime) -> Result<(), Error> { + pub async fn route( + &self, + id: Id, + rt: &Runtime, + transport: &mut T, + ) -> Result<(), Error> { match self { Self::GetMetadata(data) => { let index = PackageIndex::open(&TCLI_HOME).await?; let package = index.lock().unwrap().get_package(&data.package).unwrap(); - rt.send(Response::data_ok(Id::String("diowadaw".into()), package)); + rt.send_response(transport, Response::ok(id, package)).await; } Self::IsCached(data) => { let is_cached = cache::is_cached(&data.package); - rt.send(Response::data_ok(Id::String("dwdawdwa".into()), is_cached)); + rt.send_response(transport, Response::ok(id, is_cached)).await; } Self::SyncIndex => { PackageIndex::sync(&TCLI_HOME).await?; - rt.send(Response::ok(Id::String("dwada".into()))); + rt.send_response(transport, Response::ok(id, serde_json::json!({ "synced": true }))).await; } } diff --git a/src/server/method/project.rs b/src/server/method/project.rs index 63314fc..a6c7f7b 100644 --- a/src/server/method/project.rs +++ b/src/server/method/project.rs @@ -1,14 +1,12 @@ use std::path::PathBuf; -use std::sync::Arc; use serde::{Deserialize, Serialize}; use super::Error; use crate::project::ProjectKind; -use crate::server::proto::{Id, Response, ResponseData}; -use crate::server::{Runtime, ServerError}; +use crate::server::proto::{Id, Response}; +use crate::server::{Runtime, ServerError, Transport}; use crate::ts::package_reference::PackageReference; -use crate::{project::Project, ui::reporter::VoidReporter}; #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub enum ProjectMethod { @@ -24,12 +22,6 @@ pub enum ProjectMethod { InstalledPackages, } -impl From> for ServerError { - fn from(_val: Option) -> Self { - ServerError::InvalidContext - } -} - impl ProjectMethod { pub fn from_value(method: &str, value: serde_json::Value) -> Result { Ok(match method { @@ -43,39 +35,62 @@ impl ProjectMethod { } /// Route and execute various project methods. - /// Each of these call and interact directly with global project state. - pub async fn route(&self, rt: &mut Runtime) -> Result<(), Error> { + pub async fn route( + &self, + id: Id, + rt: &Runtime, + transport: &mut T, + ) -> Result<(), Error> { match self { ProjectMethod::Open(OpenProject { path }) => { - // Unlock the previous ctx (if it exists) and relock this one. - rt.proj = Arc::new(Project::open(path).unwrap_or(Project::create_new( - path, - true, - ProjectKind::Profile, - )?)) + // Replace the project in the runtime + let new_project = crate::project::Project::open(path).unwrap_or( + crate::project::Project::create_new(path, true, ProjectKind::Profile)?, + ); + + let mut proj = rt.proj.write().map_err(|_| ServerError::InvalidContext)?; + *proj = new_project; + drop(proj); + + rt.send_response(transport, Response::ok(id, serde_json::json!({ "path": path }))).await; } ProjectMethod::GetMetadata => { - rt.send(Response { - id: Id::String("OK".into()), - data: ResponseData::Result(format!("{:?}", rt.proj.statefile_path)), - }); + let proj = rt.proj.read().map_err(|_| ServerError::InvalidContext)?; + rt.send_response(transport, Response::ok(id, serde_json::json!({ + "statefile_path": proj.statefile_path, + "manifest_path": proj.manifest_path, + "lockfile_path": proj.lockfile_path, + }))).await; } ProjectMethod::AddPackages(packages) => { - rt.proj.add_packages(&packages.packages[..])?; - rt.proj.commit(Box::new(VoidReporter), false).await?; + { + let proj = rt.proj.read().map_err(|_| ServerError::InvalidContext)?; + proj.add_packages(&packages.packages[..])?; + proj.commit(false).await?; + } + rt.send_response( + transport, + Response::ok(id, serde_json::json!({ "added": packages.packages.len() })), + ).await; } ProjectMethod::RemovePackages(packages) => { - rt.proj.remove_packages(&packages.packages[..])?; - rt.proj.commit(Box::new(VoidReporter), false).await?; + { + let proj = rt.proj.read().map_err(|_| ServerError::InvalidContext)?; + proj.remove_packages(&packages.packages[..])?; + proj.commit(false).await?; + } + rt.send_response( + transport, + Response::ok(id, serde_json::json!({ "removed": packages.packages.len() })), + ).await; } ProjectMethod::InstalledPackages => { - let lock = rt.proj.get_lockfile()?; - let installed = lock.installed_packages().await?; - - rt.send(Response { - id: Id::Int(installed.len() as _), - data: ResponseData::Result(serde_json::to_string(&installed)?), - }); + let installed = { + let proj = rt.proj.read().map_err(|_| ServerError::InvalidContext)?; + let lock = proj.get_lockfile()?; + lock.installed_packages().await? + }; + rt.send_response(transport, Response::ok(id, installed)).await; } } @@ -85,7 +100,7 @@ impl ProjectMethod { #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct OpenProject { - path: PathBuf, + pub path: PathBuf, } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] diff --git a/src/server/mod.rs b/src/server/mod.rs index cf60aa4..40ade29 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,209 +1,336 @@ -use std::io::{Read, Write}; +use std::io::{self, BufRead, Write}; +use std::net::SocketAddr; use std::path::{Path, PathBuf}; -use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::{Arc, RwLock}; -use std::{io, thread}; +use futures_util::{SinkExt, StreamExt}; use lock::ProjectLock; use once_cell::sync::Lazy; -use proto::ResponseData; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::tungstenite::Message as WsMessage; +use tokio_tungstenite::WebSocketStream; -use self::proto::{Message, Request, Response}; +use self::proto::{Id, Message, Request, Response, RpcError}; use crate::error::Error; use crate::project::Project; use crate::ts; mod lock; -mod method; -mod proto; +pub mod method; +pub mod proto; + +/// Transport trait for receiving and sending JSON-RPC messages. +/// Implementations can be stdin/stdout, WebSocket, TCP, etc. +pub trait Transport { + /// Receive the next message. Returns None on EOF/disconnect. + fn recv(&mut self) -> impl std::future::Future> + Send; + + /// Send a message. + fn send(&mut self, msg: &str) -> impl std::future::Future> + Send; +} + +/// Stdin/stdout transport for CLI usage. +pub struct StdioTransport { + stdin: io::Stdin, + stdout: io::Stdout, + buf: String, +} + +impl StdioTransport { + pub fn new() -> Self { + Self { + stdin: io::stdin(), + stdout: io::stdout(), + buf: String::new(), + } + } +} + +impl Transport for StdioTransport { + async fn recv(&mut self) -> Option { + self.buf.clear(); + match self.stdin.lock().read_line(&mut self.buf) { + Ok(0) => None, // EOF + Ok(_) => Some(self.buf.trim().to_string()), + Err(_) => None, + } + } + + async fn send(&mut self, msg: &str) -> Result<(), Error> { + let mut out = self.stdout.lock(); + writeln!(out, "{}", msg).map_err(|e| Error::Server(ServerError::InvalidRequest(e.to_string())))?; + out.flush().map_err(|e| Error::Server(ServerError::InvalidRequest(e.to_string())))?; + Ok(()) + } +} + +/// WebSocket transport for GUI/remote usage. +pub struct WebSocketTransport { + ws: WebSocketStream, +} + +impl WebSocketTransport { + pub fn new(ws: WebSocketStream) -> Self { + Self { ws } + } +} -trait ToJson { - fn to_json(&self) -> Result; +impl Transport for WebSocketTransport { + async fn recv(&mut self) -> Option { + loop { + match self.ws.next().await { + Some(Ok(WsMessage::Text(text))) => return Some(text.to_string()), + Some(Ok(WsMessage::Close(_))) => return None, + Some(Ok(WsMessage::Ping(data))) => { + // Respond to ping with pong + let _ = self.ws.send(WsMessage::Pong(data)).await; + continue; + } + Some(Ok(_)) => continue, // Ignore binary, pong, etc. + Some(Err(_)) => return None, + None => return None, + } + } + } + + async fn send(&mut self, msg: &str) -> Result<(), Error> { + self.ws + .send(WsMessage::Text(msg.into())) + .await + .map_err(|e| Error::Server(ServerError::InvalidRequest(e.to_string()))) + } } /// This is our project dir singleton. It will likely be refactored, but also likely not. -/// It's buried within a couple layers of abstraction. The Lazy is because PathBuf does not have -/// a static new(), RwLock is so we can have thread-safe interior mutability. static PROJECT_DIR: Lazy> = Lazy::new(Default::default); -/// This error type exists to wrap library errors into a single easy-to-use package. +/// Server-specific errors. These map to JSON-RPC error codes. +/// TODO: Replace with macro-based system for inline error metadata. #[derive(thiserror::Error, Debug)] -#[repr(isize)] pub enum ServerError { - /// A partial implementation of the error variants described by the JRPC spec. - #[error("Failed to serialize JSON: {0:?}")] - InvalidJson(#[from] serde_json::Error) = -32700, + #[error("Failed to parse JSON: {0}")] + InvalidJson(#[from] serde_json::Error), - #[error("The method {0} is not valid.")] - InvalidMethod(String) = -32601, + #[error("Invalid request: {0}")] + InvalidRequest(String), - #[error("Recieved invalid params for method {0}: {1}")] - InvalidParams(String, String) = -32602, + #[error("Method not found: {0}")] + InvalidMethod(String), - #[error("")] - InvalidContext = 0, -} + #[error("Invalid params for {0}: {1}")] + InvalidParams(String, String), -impl Error { - pub fn discriminant(&self) -> isize { - // SAFETY: `Self` is `repr(isize)` with layout `repr(C)`, with each variant having an isize - // as its first field, so we can access this value without a pointer offset. - unsafe { *<*const _>::from(self).cast::() } - } -} + #[error("No project context available")] + InvalidContext, -impl ToJson for Result { - fn to_json(&self) -> Result { - todo!() - } + #[error("Project is locked by another process")] + ProjectLocked, + + #[error("WebSocket error: {0}")] + WebSocket(String), } -/// Runtime context for the server. This is mutable state, protected through a RwLock. -/// Mutations require a lock first, attainable through ::lock(). -struct Runtime { - tx: Sender, - proj: Arc, +/// Runtime context for the server. +pub struct Runtime { + pub proj: Arc>, + _lock: ProjectLock, } impl Runtime { - pub fn send(&self, response: Response) { - self.tx - .send(Message::Response(response)) - .expect("Failed to write to mpsc tx channel."); + /// Create a new runtime, acquiring an exclusive lock on the project. + pub fn new(project_dir: &Path) -> Result { + let lock = ProjectLock::lock(project_dir) + .ok_or(ServerError::ProjectLocked)?; + + let project = Project::open(project_dir)?; + + Ok(Self { + proj: Arc::new(RwLock::new(project)), + _lock: lock, + }) + } + + /// Send a response through the provided transport. + pub async fn send_response(&self, transport: &mut T, response: Response) { + let json = serde_json::to_string(&response).unwrap_or_else(|e| { + serde_json::to_string(&Response::err( + Id::Null, + RpcError::internal_error(format!("Serialization failed: {e}")), + )) + .unwrap() + }); + let _ = transport.send(&json).await; } -} -/// Runtime context, mutable or otherwise. This contains the project, by which most -/// project-specific ops go through. -struct RtContext { - pub project: Project, - pub lock: ProjectLock, + /// Execute a read-only operation on the project. + pub fn with_project(&self, f: F) -> Result + where + F: FnOnce(&Project) -> Result, + { + let proj = self.proj.read().map_err(|_| ServerError::InvalidContext)?; + f(&proj) + } + + /// Execute a mutable operation on the project. + pub fn with_project_mut(&self, f: F) -> Result + where + F: FnOnce(&mut Project) -> Result, + { + let mut proj = self.proj.write().map_err(|_| ServerError::InvalidContext)?; + f(&mut proj) + } } -/// Create the server runtime from the provided read and write channels. -/// This lives for the lifespan of the process. -pub async fn spawn(_read: impl Read, _write: impl Write, project_dir: &Path) -> Result<(), Error> { - let (tx, rx) = mpsc::channel::(); - let cancel = RwLock::new(false); +/// Create and run the server with the given transport. +pub async fn run(mut transport: T, project_dir: &Path) -> Result<(), Error> { + let rt = Runtime::new(project_dir)?; + + ts::init_repository("https://thunderstore.io", None); - // This thread recieves internal mpsc messages, serializes, and writes them to stdout. - thread::spawn(move || respond_msg(rx, cancel)); + while let Some(line) = transport.recv().await { + if line.is_empty() { + continue; + } - // Begin looping over stdin messages. - let stdin = io::stdin(); - let mut line = String::new(); + // Parse the message + let msg = match Message::from_json(&line) { + Ok(msg) => msg, + Err(e) => { + let rpc_err = match &e { + Error::Server(se) => RpcError::from(se), + _ => RpcError::parse_error(e.to_string()), + }; + rt.send_response(&mut transport, Response::err(Id::Null, rpc_err)).await; + continue; + } + }; - let mut rt = Runtime { - tx, - proj: Arc::new(Project::open(project_dir)?), - }; + // Route the message + if let Err(e) = route(msg, &rt, &mut transport).await { + let rpc_err = match &e { + Error::Server(se) => RpcError::from(se), + _ => RpcError::internal_error(e.to_string()), + }; + rt.send_response(&mut transport, Response::err(Id::Null, rpc_err)).await; + } + } + + Ok(()) +} +/// Convenience function to spawn with stdio transport. +pub async fn spawn_stdio(project_dir: &Path) -> Result<(), Error> { + run(StdioTransport::new(), project_dir).await +} + +/// Start a WebSocket server on the given address. +/// Each connection is handled sequentially to avoid Send requirements +/// from the non-Send Reporter/Progress traits used in the project code. +pub async fn spawn_websocket(addr: SocketAddr, project_dir: &Path) -> Result<(), Error> { + let listener = TcpListener::bind(addr) + .await + .map_err(|e| ServerError::WebSocket(e.to_string()))?; + + println!("WebSocket server listening on ws://{}", addr); + + let rt = Runtime::new(project_dir)?; ts::init_repository("https://thunderstore.io", None); loop { - if let Err(_) = stdin.read_line(&mut line) { - panic!(""); - }; + let (stream, peer) = listener + .accept() + .await + .map_err(|e| ServerError::WebSocket(e.to_string()))?; + + match tokio_tungstenite::accept_async(stream).await { + Ok(ws) => { + println!("New WebSocket connection from {}", peer); + let mut transport = WebSocketTransport::new(ws); + if let Err(e) = run_with_runtime(&rt, &mut transport).await { + eprintln!("Connection error from {}: {}", peer, e); + } + println!("Connection closed: {}", peer); + } + Err(e) => { + eprintln!("WebSocket handshake failed from {}: {}", peer, e); + } + } + } +} - println!("LINE: {line}"); +/// Run the server loop with an existing runtime (for shared WebSocket connections). +async fn run_with_runtime(rt: &Runtime, transport: &mut T) -> Result<(), Error> { + while let Some(line) = transport.recv().await { + if line.is_empty() { + continue; + } - match Message::from_json(&line) { - Ok(msg) => route(msg, &mut rt).await?, + let msg = match Message::from_json(&line) { + Ok(msg) => msg, Err(e) => { - rt.tx - .send(Message::Response(Response { - id: proto::Id::String("FUCK".into()), - data: ResponseData::Error(e.to_string()), - })) - .unwrap(); + let rpc_err = match &e { + Error::Server(se) => RpcError::from(se), + _ => RpcError::parse_error(e.to_string()), + }; + rt.send_response(transport, Response::err(Id::Null, rpc_err)).await; + continue; } }; - // if let Ok(msg) = Message::from_json(&line) { - // } else { - // } - - // let msg = Message::from_json(&line); - // route(msg, &rt).await?; + if let Err(e) = route(msg, rt, transport).await { + let rpc_err = match &e { + Error::Server(se) => RpcError::from(se), + _ => RpcError::internal_error(e.to_string()), + }; + rt.send_response(transport, Response::err(Id::Null, rpc_err)).await; + } } + + Ok(()) } -/// Route -async fn route(msg: Message, rt: &mut Runtime) -> Result<(), Error> { +/// Route a message to its handler. +async fn route(msg: Message, rt: &Runtime, transport: &mut T) -> Result<(), Error> { match msg { - Message::Request(rq) => route_rq(Request::try_from(rq)?, rt).await?, - Message::Response(_) => panic!(), + Message::Request(rq) => { + let id = rq.id.clone(); + match Request::try_from(rq) { + Ok(request) => route_rq(request, rt, transport).await, + Err(e) => { + let rpc_err = match &e { + Error::Server(se) => RpcError::from(se), + Error::Parse(pe) => RpcError::invalid_params(pe.to_string()), + _ => RpcError::internal_error(e.to_string()), + }; + rt.send_response(transport, Response::err(id, rpc_err)).await; + Ok(()) + } + } + } + Message::Response(_) => Ok(()), } - - Ok(()) } -// Request routing -async fn route_rq(rq: Request, rt: &mut Runtime) -> Result<(), Error> { - match rq.method { - method::Method::Exit => todo!(), - method::Method::Project(proj) => proj.route(rt).await?, - method::Method::Package(pack) => pack.route(rt).await?, +/// Route a validated request to its method handler. +async fn route_rq(rq: Request, rt: &Runtime, transport: &mut T) -> Result<(), Error> { + let id = rq.id.clone(); + + let result = match rq.method { + method::Method::Exit => { + rt.send_response(transport, Response::ok(id, "exiting")).await; + return Ok(()); + } + method::Method::Project(proj) => proj.route(id.clone(), rt, transport).await, + method::Method::Package(pack) => pack.route(id.clone(), rt, transport).await, + }; + + if let Err(e) = result { + let rpc_err = match &e { + Error::Server(se) => RpcError::from(se), + _ => RpcError::internal_error(e.to_string()), + }; + rt.send_response(transport, Response::err(id, rpc_err)).await; } Ok(()) } - -// /// The daemon's entrypoint. This is a psuedo event loop which does the following in step: -// /// 1. Read JSON-RPC input(s) from stdin. -// /// 2. Route each input. -// /// 3. Serialize the output and write to stdout. -// async fn start() { -// let stdin = io::stdin(); -// let mut line = String::new(); -// let (send, recv) = mpsc::channel::>(); - -// let cancel = RwLock::new(false); - -// // Responses are published through the tx send channel. -// // thread::spawn(move || respond_msg(recv, cancel)); - -// loop { -// // Block the main thread until we have an input line available to be read. -// // This is ok because, in theory, tasks will be processed on background threads. -// if let Err(e) = stdin.read_line(&mut line) { -// panic!("") -// } -// let res = route(&line, self.ctx, send.clone()).await; -// res.to_json().unwrap(); -// } -// } - -fn respond_msg(recv: Receiver, _cancel: RwLock) { - let mut stdout = io::stdout(); - while let Ok(res) = recv.recv() { - let msg = serde_json::to_string(&res); - stdout.write_all(msg.unwrap().as_bytes()).unwrap(); - stdout.write_all("\n".as_bytes()).unwrap(); - } -} - -// Route and execute the request, returning the result. -// Messages, including the result of subsequent computation, are sent over the sender channel. -// async fn route(line: &str, ctx: RwLock, send: Sender>) -> Result { -// let req = Message::from_json(line)?; -// match req { -// Message::Request(rq) => route_rq(Request::try_from(rq)?, ctx, send).await, -// Message::Response(_) => panic!(), -// } -// } - -// /// Do the actual Request routing here. -// /// One more level of abstraction. This routes calls to their actual implementation within -// /// the method module. -// async fn route_rq( -// request: Request, -// ctx: RwLock, -// send: Sender>, -// ) -> Result { -// match request.method { -// method::Method::Exit => todo!(), -// method::Method::Project(x) => x.route(ctx, send).await, -// method::Method::Package(_) => todo!(), -// } -// diff --git a/src/server/proto.rs b/src/server/proto.rs index 82f9024..8645456 100644 --- a/src/server/proto.rs +++ b/src/server/proto.rs @@ -1,12 +1,12 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; -use super::ServerError; use crate::server::method::Method; -use crate::server::Error; +use crate::server::{Error, ServerError}; const JRPC_VER: &str = "2.0"; +/// A JSON-RPC 2.0 message - either a request or response. #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] #[serde(untagged)] pub enum Message { @@ -23,60 +23,52 @@ impl Message { .map_err(ServerError::InvalidJson)?; match msg { - Message::Request(x) if x.jsonrpc != JRPC_VER => { - Err(ServerError::InvalidMethod(x.jsonrpc))? + Message::Request(ref x) if x.jsonrpc != JRPC_VER => { + Err(ServerError::InvalidRequest("jsonrpc must be \"2.0\"".into()))? } _ => Ok(msg), } } } -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +/// JSON-RPC request/response identifier. Can be integer, string, or null. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(untagged)] pub enum Id { - Int(isize), + Int(i64), String(String), + Null, } -/// This is the raw representation of a JSON-RPC request *before* we convert it -/// into a structured type. +/// Raw representation of a JSON-RPC request before method routing. #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct RequestInner { - /// The JSON-RPC protocol version. Must be 2.0 as per the spec. + /// Must be "2.0". pub jsonrpc: String, - /// An identifier which, as per the JSON-RPC spec, can either be an integer - /// or string. We use an untagged enum to allow serde to transparenly parse these types. + /// Request identifier, echoed back in the response. pub id: Id, - /// This field is deserialized into a Method enum variant via Method::from_str. - /// Unfortunately this means that errors returned from Method::from_str are lost. - // #[serde_as(as = "DisplayFromStr")] + /// Method name in "namespace/method" format. pub method: String, - /// This field is null for notifications. + /// Method parameters (optional). #[serde(default = "Value::default")] #[serde(skip_serializing_if = "Value::is_null")] pub params: Value, } -/// A structured JSON-RPC request. -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +/// A validated, routable JSON-RPC request. +#[derive(Debug, PartialEq, Eq)] pub struct Request { - /// The JSON-RPC identifier. pub id: Id, - /// The method with data, if any. pub method: Method, } impl TryFrom for Request { - type Error = super::Error; + type Error = Error; fn try_from(value: RequestInner) -> Result { - // We deserialize the params value depending on the provided method. - // This is done by passing it to the from_value function of the Method type, - // which iterates down through nested enums until we have a concrete type for the Value - // and a valid method variant. let method = Method::from_value(&value.method, value.params)?; Ok(Self { id: value.id, @@ -85,49 +77,178 @@ impl TryFrom for Request { } } +/// A JSON-RPC 2.0 response. #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct Response { + /// Always "2.0". + pub jsonrpc: String, + + /// The request ID this response corresponds to. pub id: Id, - #[serde(flatten)] - pub data: ResponseData, + /// Success result (mutually exclusive with error). + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + + /// Error object (mutually exclusive with result). + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, } impl Response { - pub fn data_ok(id: Id, data: impl Serialize) -> Response { + /// Create a success response with a serializable result. + pub fn ok(id: Id, data: T) -> Response { Response { + jsonrpc: JRPC_VER.into(), id, - data: ResponseData::Result(serde_json::to_string(&data).unwrap()), + result: Some(serde_json::to_value(data).unwrap_or(Value::Null)), + error: None, } } - pub fn ok(id: Id) -> Response { + /// Create a success response with no result data. + pub fn ok_empty(id: Id) -> Response { Response { + jsonrpc: JRPC_VER.into(), id, - data: ResponseData::Result("OK".into()), + result: Some(Value::Null), + error: None, } } -} -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] -pub enum ResponseData { - #[serde(rename = "result")] - Result(String), - #[serde(rename = "error")] - Error(String), + /// Create an error response. + pub fn err(id: Id, error: RpcError) -> Response { + Response { + jsonrpc: JRPC_VER.into(), + id, + result: None, + error: Some(error), + } + } } -#[derive(Serialize, Deserialize, Debug)] +/// JSON-RPC 2.0 error object. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct RpcError { - pub code: isize, + /// Numeric error code. + pub code: i32, + + /// Short description of the error. pub message: String, + + /// Additional error data. Contains `kind` for i18n error identification. + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, } -impl From for RpcError { - fn from(value: Error) -> Self { +/// Additional error data for i18n and debugging. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct RpcErrorData { + /// Stable string identifier for the error type (e.g., "project.not_found"). + /// Used by clients for i18n lookup. + pub kind: String, + + /// Additional context, if any. + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, +} + +impl RpcError { + /// Standard JSON-RPC error: Parse error (-32700) + pub fn parse_error(msg: impl Into) -> Self { + Self { + code: -32700, + message: msg.into(), + data: Some(RpcErrorData { + kind: "rpc.parse_error".into(), + context: None, + }), + } + } + + /// Standard JSON-RPC error: Invalid request (-32600) + pub fn invalid_request(msg: impl Into) -> Self { + Self { + code: -32600, + message: msg.into(), + data: Some(RpcErrorData { + kind: "rpc.invalid_request".into(), + context: None, + }), + } + } + + /// Standard JSON-RPC error: Method not found (-32601) + pub fn method_not_found(method: impl Into) -> Self { + let method = method.into(); Self { - code: value.discriminant(), - message: value.to_string(), + code: -32601, + message: format!("Method not found: {method}"), + data: Some(RpcErrorData { + kind: "rpc.method_not_found".into(), + context: Some(serde_json::json!({ "method": method })), + }), + } + } + + /// Standard JSON-RPC error: Invalid params (-32602) + pub fn invalid_params(msg: impl Into) -> Self { + Self { + code: -32602, + message: msg.into(), + data: Some(RpcErrorData { + kind: "rpc.invalid_params".into(), + context: None, + }), + } + } + + /// Standard JSON-RPC error: Internal error (-32603) + pub fn internal_error(msg: impl Into) -> Self { + Self { + code: -32603, + message: msg.into(), + data: Some(RpcErrorData { + kind: "rpc.internal_error".into(), + context: None, + }), + } + } + + /// Application-level error (code >= -32000) + /// TODO: This will be replaced by macro-generated error conversion + pub fn app_error(kind: impl Into, msg: impl Into) -> Self { + Self { + code: -32000, + message: msg.into(), + data: Some(RpcErrorData { + kind: kind.into(), + context: None, + }), + } + } +} + +/// Convert server errors to RPC errors. +/// TODO: Replace with macro-based system for inline error metadata. +impl From<&ServerError> for RpcError { + fn from(err: &ServerError) -> Self { + match err { + ServerError::InvalidJson(e) => RpcError::parse_error(e.to_string()), + ServerError::InvalidRequest(msg) => RpcError::invalid_request(msg), + ServerError::InvalidMethod(method) => RpcError::method_not_found(method), + ServerError::InvalidParams(method, msg) => { + RpcError::invalid_params(format!("{method}: {msg}")) + } + ServerError::InvalidContext => { + RpcError::app_error("server.invalid_context", "No project context available") + } + ServerError::ProjectLocked => { + RpcError::app_error("server.project_locked", "Project is locked by another process") + } + ServerError::WebSocket(e) => { + RpcError::app_error("server.websocket_error", e) + } } } } @@ -135,21 +256,23 @@ impl From for RpcError { #[cfg(test)] mod test { use super::*; - use crate::server::method::package::PackageMethod; use crate::server::method::project::{OpenProject, ProjectMethod}; - use crate::server::ServerError; #[test] fn test_jrpc_ver_validate() { - let data = r#"{ "jsonrpc": "2.0", "id": 1, "method": "oksamies""#; + // Invalid version should fail + let data = r#"{ "jsonrpc": "1.0", "id": 1, "method": "project/open", "params": {} }"#; + let result = Message::from_json(data); + assert!(result.is_err()); } #[test] fn test_request_deserialize() { - let data = r#"{ "jsonrpc": "2.0", "id": 1, "method": "project/set_context", "params": { "path": "/some/path" } }"#; - let rq: RequestInner = serde_json::from_str(&data).unwrap(); + // Valid request with params: project/open + let data = r#"{ "jsonrpc": "2.0", "id": 1, "method": "project/open", "params": { "path": "/some/path" } }"#; + let rq: RequestInner = serde_json::from_str(data).unwrap(); assert_eq!(rq.id, Id::Int(1)); - assert_eq!(rq.method, "project/set_context"); + assert_eq!(rq.method, "project/open"); assert!(matches!(rq.params, Value::Object(..))); let rq = Request::try_from(rq).unwrap(); @@ -159,22 +282,23 @@ mod test { Method::Project(ProjectMethod::Open(OpenProject { .. })) )); - let data = r#"{ "jsonrpc": "2.0", "id": "oksamies", "method": "package/get_metadata" }"#; - let rq: RequestInner = serde_json::from_str(&data).unwrap(); + // Valid request without params: project/get_metadata + let data = r#"{ "jsonrpc": "2.0", "id": "oksamies", "method": "project/get_metadata" }"#; + let rq: RequestInner = serde_json::from_str(data).unwrap(); assert_eq!(rq.id, Id::String("oksamies".into())); - assert_eq!(rq.method, "package/get_metadata"); + assert_eq!(rq.method, "project/get_metadata"); assert_eq!(rq.params, Value::Null); - // let rq = Request::try_from(rq).unwrap(); - // assert_eq!(rq.id, Id::String("oksamies".into())); - // assert!(matches!( - // rq.method, - // Method::Package(PackageMethod::GetMetadata) - // )); + let rq = Request::try_from(rq).unwrap(); + assert_eq!(rq.id, Id::String("oksamies".into())); + assert!(matches!( + rq.method, + Method::Project(ProjectMethod::GetMetadata) + )); - // Invalid methods should still be deserialized aok as they're checked by typed Request struct. + // Invalid methods should still be deserialized ok as they're checked by typed Request struct. let data = r#"{ "jsonrpc": "2.0", "id": "oksamies", "method": "null/null" }"#; - let rq: RequestInner = serde_json::from_str(&data).unwrap(); + let rq: RequestInner = serde_json::from_str(data).unwrap(); assert_eq!(rq.id, Id::String("oksamies".into())); assert_eq!(rq.method, "null/null"); assert_eq!(rq.params, Value::Null); @@ -184,20 +308,36 @@ mod test { assert!(matches!( rq, Err(Error::Server(ServerError::InvalidMethod(..))) - )); // Invalid methods should still be deserialized aok as they're checked by typed Request struct. + )); // Likewise, valid methods with garbage data should also fail when converted to typed. - let data = r#"{ "jsonrpc": "2.0", "id": "oksamies", "method": "project/set_context", "params": { "garbage": 1 } }"#; - let rq: RequestInner = serde_json::from_str(&data).unwrap(); + let data = r#"{ "jsonrpc": "2.0", "id": "oksamies", "method": "project/open", "params": { "garbage": 1 } }"#; + let rq: RequestInner = serde_json::from_str(data).unwrap(); assert_eq!(rq.id, Id::String("oksamies".into())); - assert_eq!(rq.method, "project/set_context"); + assert_eq!(rq.method, "project/open"); assert!(matches!(rq.params, Value::Object(..))); let rq = Request::try_from(rq); - panic!("{rq:?}"); - assert!(matches!( - rq, - Err(Error::Server(ServerError::InvalidJson(..))) - )); + assert!(matches!(rq, Err(Error::Parse(..)))); + } + + #[test] + fn test_response_serialize() { + // Success response + let resp = Response::ok(Id::Int(1), "hello"); + let json = serde_json::to_string(&resp).unwrap(); + assert!(json.contains(r#""jsonrpc":"2.0""#)); + assert!(json.contains(r#""id":1"#)); + assert!(json.contains(r#""result":"hello""#)); + assert!(!json.contains("error")); + + // Error response + let resp = Response::err(Id::String("req-1".into()), RpcError::method_not_found("foo/bar")); + let json = serde_json::to_string(&resp).unwrap(); + assert!(json.contains(r#""jsonrpc":"2.0""#)); + assert!(json.contains(r#""id":"req-1""#)); + assert!(json.contains(r#""code":-32601"#)); + assert!(json.contains(r#""kind":"rpc.method_not_found""#)); + assert!(!json.contains("result")); } } diff --git a/src/ts/package_reference/mod.rs b/src/ts/package_reference/mod.rs index 41d3a4d..fd4d7c3 100644 --- a/src/ts/package_reference/mod.rs +++ b/src/ts/package_reference/mod.rs @@ -39,6 +39,7 @@ impl PackageReference { .ok_or(PackageReferenceParseError::NumSections { expected: 2, got: 1, + provided: fullname.as_ref().to_string(), })?; Ok(PackageReference { namespace: namespace.to_string(), @@ -63,6 +64,7 @@ impl FromStr for PackageReference { .map_err(|v: Vec<&str>| PackageReferenceParseError::NumSections { expected: 3, got: v.len() - 1, + provided: s.to_string(), })?; Ok(PackageReference { @@ -81,8 +83,8 @@ impl Display for PackageReference { #[derive(thiserror::Error, Debug)] pub enum PackageReferenceParseError { - #[error("Expected {expected} sections, got {got}.")] - NumSections { expected: usize, got: usize }, + #[error("Expected {expected} sections, got {got} for string '{provided}'")] + NumSections { expected: usize, got: usize, provided: String }, #[error("Failed to parse version: {0}.")] VersionParseFail(#[from] VersionParseError), } diff --git a/src/ts/v1/models/ecosystem.rs b/src/ts/v1/models/ecosystem.rs index cb12f10..56c4b79 100644 --- a/src/ts/v1/models/ecosystem.rs +++ b/src/ts/v1/models/ecosystem.rs @@ -2,10 +2,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use crate::error::Error; use crate::ts::version::Version; -use crate::game::ecosystem; -use crate::game::error::GameError; #[derive(Serialize, Deserialize, Debug)] @@ -123,7 +120,23 @@ pub struct GameDefR2MM { pub struct R2MMModLoaderPackage { pub package_id: String, pub root_folder: String, - pub loader: String, + pub loader: R2MLLoader, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "lowercase")] +pub enum R2MLLoader { + BepInEx, + GDWeave, + GodotML, + Lovely, + MelonLoader, + Northstar, + #[serde(rename = "recursive-melonloader")] + RecursiveMelonLoader, + #[serde(rename = "return-of-modding")] + ReturnOfModding, + Shimloader, } #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/src/ui/mod.rs b/src/ui/mod.rs index e20cfc8..b0e14fb 100644 --- a/src/ui/mod.rs +++ b/src/ui/mod.rs @@ -1,3 +1,4 @@ +pub mod progress; pub mod reporter; use indicatif::ProgressStyle; diff --git a/src/ui/progress.rs b/src/ui/progress.rs new file mode 100644 index 0000000..631c880 --- /dev/null +++ b/src/ui/progress.rs @@ -0,0 +1,232 @@ +use std::sync::{Arc, OnceLock, RwLock}; + +use serde::Serialize; + +static SINK: OnceLock>>> = OnceLock::new(); + +fn get_sink_lock() -> &'static RwLock>> { + SINK.get_or_init(|| RwLock::new(None)) +} + +/// Install a progress sink globally. +pub fn set_sink(sink: Arc) { + *get_sink_lock().write().unwrap() = Some(sink); +} + +/// Clear the global progress sink. +pub fn clear_sink() { + *get_sink_lock().write().unwrap() = None; +} + +/// Emit a progress event to the current sink, if any. +pub fn emit(event: ProgressEvent) { + if let Some(sink) = get_sink_lock().read().unwrap().as_ref() { + sink.emit(event.clone()); + } +} + +/// Start a new scope. +pub fn scope_start(id: impl Into, label: impl Into) { + emit(ProgressEvent::ScopeStarted { + id: id.into(), + parent: None, + label: label.into(), + total: None, + }); +} + +/// Start a new child scope under an existing parent. +pub fn scope_start_child(id: impl Into, parent: impl Into, label: impl Into) { + emit(ProgressEvent::ScopeStarted { + id: id.into(), + parent: Some(parent.into()), + label: label.into(), + total: None, + }); +} + +/// Update progress within a scope. +pub fn scope_progress(id: impl Into, current: u64, status: Option<&str>) { + emit(ProgressEvent::ScopeProgress { + id: id.into(), + current, + status: status.map(String::from), + }); +} + +/// Mark a scope as completed. +pub fn scope_complete(id: impl Into) { + emit(ProgressEvent::ScopeCompleted { id: id.into() }); +} + +/// Mark a scope as failed. +pub fn scope_fail(id: impl Into, error: impl Into) { + emit(ProgressEvent::ScopeFailed { + id: id.into(), + error: error.into(), + }); +} + +/// Emit a log message. +pub fn log(level: LogLevel, message: impl Into) { + emit(ProgressEvent::Log { + level, + message: message.into(), + scope: None, + }); +} + +/// Emit an info log message. +pub fn info(message: impl Into) { + log(LogLevel::Info, message); +} + +/// Emit a warning log message. +pub fn warn(message: impl Into) { + log(LogLevel::Warn, message); +} + +/// Emit an error log message. +pub fn error(message: impl Into) { + log(LogLevel::Error, message); +} + +/// Progress events emitted during operations. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ProgressEvent { + /// A scoped operation has started. + ScopeStarted { + id: String, + parent: Option, + label: String, + total: Option, + }, + + /// Progress within a scope. + ScopeProgress { + id: String, + current: u64, + status: Option, + }, + + /// Scope completed successfully. + ScopeCompleted { id: String }, + + /// Scope failed. + ScopeFailed { id: String, error: String }, + + /// Log message. + Log { + level: LogLevel, + message: String, + scope: Option, + }, +} + +#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum LogLevel { + Debug, + Info, + Warn, + Error, +} + +/// Receiver for progress events. +pub trait ProgressSink: Send + Sync { + fn emit(&self, event: ProgressEvent); +} + +/// Discards all events. +pub struct VoidSink; + +impl ProgressSink for VoidSink { + fn emit(&self, _event: ProgressEvent) {} +} + +/// Collects events for testing. +pub struct CollectorSink { + events: std::sync::Mutex>, +} + +impl CollectorSink { + pub fn new() -> Self { + Self { + events: std::sync::Mutex::new(Vec::new()), + } + } + + pub fn events(&self) -> Vec { + self.events.lock().unwrap().clone() + } +} + +impl ProgressSink for CollectorSink { + fn emit(&self, event: ProgressEvent) { + self.events.lock().unwrap().push(event); + } +} + +/// Terminal sink using indicatif for progress bars. +pub struct TerminalSink { + multi: indicatif::MultiProgress, + bars: std::sync::Mutex>, +} + +impl TerminalSink { + pub fn new() -> Self { + Self { + multi: indicatif::MultiProgress::new(), + bars: std::sync::Mutex::new(std::collections::HashMap::new()), + } + } + + fn get_style() -> indicatif::ProgressStyle { + indicatif::ProgressStyle::with_template("{spinner:.cyan} {msg}") + .unwrap() + .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]) + } +} + +impl ProgressSink for TerminalSink { + fn emit(&self, event: ProgressEvent) { + let mut bars = self.bars.lock().unwrap(); + + match event { + ProgressEvent::ScopeStarted { id, parent: _, label, total: _ } => { + let bar = self.multi.add(indicatif::ProgressBar::new_spinner()); + bar.set_style(Self::get_style()); + bar.set_message(label.clone()); + bar.enable_steady_tick(std::time::Duration::from_millis(80)); + bars.insert(id, (bar, label)); + } + ProgressEvent::ScopeProgress { id, current: _, status } => { + if let Some((bar, label)) = bars.get(&id) { + if let Some(s) = status { + bar.set_message(format!("{} ({})", label, s)); + } + } + } + ProgressEvent::ScopeCompleted { id } => { + if let Some((bar, _)) = bars.remove(&id) { + bar.finish_and_clear(); + } + } + ProgressEvent::ScopeFailed { id, error } => { + if let Some((bar, label)) = bars.remove(&id) { + bar.abandon_with_message(format!("✗ {}: {}", label, error)); + } + } + ProgressEvent::Log { level, message, .. } => { + let prefix = match level { + LogLevel::Debug => "DEBUG", + LogLevel::Info => "INFO ", + LogLevel::Warn => "WARN ", + LogLevel::Error => "ERROR", + }; + self.multi.println(format!("[{}] {}", prefix, message)).ok(); + } + } + } +}