diff --git a/Cargo.toml b/Cargo.toml index de163c4c1..243a08279 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,23 +23,45 @@ time = { version = "0.3.36", default-features = false } [dependencies] aes = { version = "0.8.4", optional = true } +async-stream = { version = "0.3.5", optional = true } byteorder = "1.5.0" bzip2 = { version = "0.4.4", optional = true } chrono = { version = "0.4.38", optional = true } +cfg-if = "1" constant_time_eq = { version = "0.3.0", optional = true } crc32fast = "1.4.0" +displaydoc = "0.2.4" +thiserror = "1.0.48" flate2 = { version = "1.0.28", default-features = false, optional = true } +futures-core = { version = "0.3", optional = true } +futures-util = { version = "0.3", optional = true } hmac = { version = "0.12.1", optional = true, features = ["reset"] } +indexmap = { version = "2", features = ["rayon"], optional = true } +libc = { version = "0.2.148", optional = true } +num_enum = "0.6.1" +once_cell = { version = "1.18.0", optional = true } +parking_lot = { version = "0.12.1", features = ["arc_lock"], optional = true } pbkdf2 = { version = "0.12.2", optional = true } +rayon = { version = "1.8.0", optional = true } sha1 = { version = "0.10.6", optional = true } +static_assertions = { version = "1.1.0", optional = true } +tempfile = { version = "3.8.0", optional = true } time = { workspace = true, optional = true, features = [ "std", ] } zstd = { version = "0.13.1", optional = true, default-features = false } +tokio = { version = "1", features = ["rt", "io-util", "sync", "fs", "macros"], optional = true } +tokio-pipe = { git = "https://github.com/cosmicexplorer/tokio-pipe", rev = "c44321ae17b4324a8ccaa4f687a8f68259fdca30", optional = true } +tokio-stream = { version = "0.1.14", optional = true } zopfli = { version = "0.8.0", optional = true } deflate64 = { version = "0.1.8", optional = true } lzma-rs = { version = "0.3.0", default-features = false, optional = true } +[dependencies.memchr2] +version = "2.6.4" +optional = true +package = "memchr" + [target.'cfg(any(all(target_arch = "arm", target_pointer_width = "32"), target_arch = "mips", target_arch = "powerpc"))'.dependencies] crossbeam-utils = "0.8.19" @@ -48,10 +70,15 @@ arbitrary = { version = "1.3.2", features = ["derive"] } [dev-dependencies] bencher = "0.1.5" -getrandom = { version = "0.2.14", features = ["js"] } -walkdir = "2.5.0" +criterion = { version = "0.5", features = ["async_tokio"] } +getrandom = "0.2.14" +tempfile = "3.8.0" time = { workspace = true, features = ["formatting", "macros"] } -anyhow = "1" +tokio = { version = "1", features = ["rt", "rt-multi-thread"] } +tokio-test = "0.4.3" +uuid = { version = "1.4.1", features = ["v4"] } +walkdir = "2.5.0" + [features] aes-crypto = ["aes", "constant_time_eq", "hmac", "pbkdf2", "sha1"] chrono = ["chrono/default"] @@ -65,6 +92,11 @@ deflate-zlib = ["flate2/zlib", "_deflate-any"] deflate-zlib-ng = ["flate2/zlib-ng", "_deflate-any"] deflate-zopfli = ["zopfli", "_deflate-any"] lzma = ["lzma-rs/stream"] +tokio-async = [ + "dep:tokio", "dep:memchr2", "dep:tokio-stream", "dep:tokio-pipe", "dep:parking_lot", "dep:libc", "dep:futures-core", + "dep:futures-util", "dep:async-stream", "dep:indexmap", "dep:once_cell", "dep:static_assertions", "dep:rayon", + "dep:tempfile", +] unreserved = [] default = [ "aes-crypto", @@ -76,6 +108,7 @@ default = [ "lzma", "time", "zstd", + "tokio-async" ] [[bench]] @@ -85,3 +118,20 @@ harness = false [[bench]] name = "read_metadata" harness = false + +[[bench]] +name = "extract" +harness = false + +# [[bench]] +# name = 
"merge_archive" +# harness = false + +[profile.release] +strip = false +debug = true +# lto = true + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/benches/extract.rs b/benches/extract.rs new file mode 100644 index 000000000..d4a869bb0 --- /dev/null +++ b/benches/extract.rs @@ -0,0 +1,221 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; + +use std::io::Cursor; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use getrandom::getrandom; +use once_cell::sync::Lazy; +use tempfile::tempdir; +use tokio::{fs, io, runtime::Runtime}; +use uuid::Uuid; + +use zip::{result::ZipResult, write::FileOptions, ZipWriter}; + +fn generate_random_archive( + num_entries: usize, + entry_size: usize, + options: FileOptions, +) -> ZipResult>> { + use std::io::Write; + + let buf = Cursor::new(Vec::new()); + let mut zip = ZipWriter::new(buf); + + let mut bytes = vec![0u8; entry_size]; + for i in 0..num_entries { + let name = format!("random{}.dat", i); + zip.start_file(name, options)?; + getrandom(&mut bytes).unwrap(); + zip.write_all(&bytes)?; + } + + let buf = zip.finish()?.into_inner(); + + Ok(Cursor::new(buf.into_boxed_slice())) +} + +static BIG_ARCHIVE_PATH: Lazy = + Lazy::new(|| Path::new(env!("CARGO_MANIFEST_DIR")).join("benches/target.zip")); + +static SMALL_ARCHIVE_PATH: Lazy = + Lazy::new(|| Path::new(env!("CARGO_MANIFEST_DIR")).join("benches/small-target.zip")); + +const NUM_ENTRIES: usize = 1_000; +const ENTRY_SIZE: usize = 10_000; + +pub fn bench_io(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let mut group = c.benchmark_group("io"); + + let options = FileOptions::default().compression_method(zip::CompressionMethod::Deflated); + + let gen_td = tempdir().unwrap(); + let random_path = gen_td.path().join("random.zip"); + std::io::copy( + &mut generate_random_archive(NUM_ENTRIES, ENTRY_SIZE, options).unwrap(), + &mut std::fs::File::create(&random_path).unwrap(), + ) + .unwrap(); + + for (path, desc, n) in [ + (&*BIG_ARCHIVE_PATH, "big archive", Some(30)), + (&*SMALL_ARCHIVE_PATH, "small archive", None), + (&random_path, "random archive", None), + ] { + let len = std::fs::metadata(&path).unwrap().len() as usize; + let id = format!("{}({} bytes)", desc, len); + + group.throughput(Throughput::Bytes(len as u64)); + + if let Some(n) = n { + group.sample_size(n); + } + + group.bench_function(BenchmarkId::new(&id, ""), |b| { + use zip::tokio::os::copy_file_range::*; + + let td = tempdir().unwrap(); + b.to_async(&rt).iter(|| async { + let sync_handle = std::fs::File::open(path).unwrap(); + let mut src = MutateInnerOffset::new(sync_handle, Role::Readable).unwrap(); + let mut src = Pin::new(&mut src); + + let cur_name = format!("{}.zip", Uuid::new_v4()); + let tf = td.path().join(cur_name); + let out = std::fs::File::create(tf).unwrap(); + let mut dst = MutateInnerOffset::new(out, Role::Writable).unwrap(); + let mut dst = Pin::new(&mut dst); + + let written = copy_file_range(src, dst, len).await.unwrap(); + assert_eq!(written, len); + }); + }); + + group.bench_function(BenchmarkId::new(&id, ""), |b| { + let td = tempdir().unwrap(); + b.to_async(&rt).iter(|| async { + let handle = fs::File::open(path).await.unwrap(); + let mut buf_handle = io::BufReader::with_capacity(len, handle); + + let cur_name = format!("{}.zip", Uuid::new_v4()); + let tf = td.path().join(cur_name); + + let mut out = fs::File::create(tf).await.unwrap(); + assert_eq!( + len as u64, + 
+                    io::copy_buf(&mut buf_handle, &mut out).await.unwrap()
+                );
+            });
+        });
+
+        group.bench_function(BenchmarkId::new(&id, "io::copy"), |b| {
+            let td = tempdir().unwrap();
+            b.to_async(&rt).iter(|| async {
+                let mut handle = fs::File::open(path).await.unwrap();
+
+                let cur_name = format!("{}.zip", Uuid::new_v4());
+                let tf = td.path().join(cur_name);
+
+                let mut out = fs::File::create(tf).await.unwrap();
+                assert_eq!(len as u64, io::copy(&mut handle, &mut out).await.unwrap());
+            });
+        });
+
+        group.bench_function(BenchmarkId::new(&id, "std::io::copy"), |b| {
+            let td = tempdir().unwrap();
+            b.iter(|| {
+                let mut sync_handle = std::fs::File::open(path).unwrap();
+
+                let cur_name = format!("{}.zip", Uuid::new_v4());
+                let tf = td.path().join(cur_name);
+
+                let mut out = std::fs::File::create(tf).unwrap();
+                assert_eq!(
+                    len as u64,
+                    /* NB: this doesn't use copy_buf like the async case, because std::io has no
+                     * corresponding function, and using an std::io::BufReader wrapper actually
+                     * hurts perf by a lot!! */
+                    std::io::copy(&mut sync_handle, &mut out).unwrap()
+                );
+            });
+        });
+    }
+}
+
+pub fn bench_extract(c: &mut Criterion) {
+    let rt = Runtime::new().unwrap();
+    let mut group = c.benchmark_group("extract");
+
+    let options = FileOptions::default().compression_method(zip::CompressionMethod::Deflated);
+
+    let gen_td = tempdir().unwrap();
+    let random_path = gen_td.path().join("random.zip");
+    std::io::copy(
+        &mut generate_random_archive(NUM_ENTRIES, ENTRY_SIZE, options).unwrap(),
+        &mut std::fs::File::create(&random_path).unwrap(),
+    )
+    .unwrap();
+
+    for (path, desc, n, t) in [
+        (
+            &*BIG_ARCHIVE_PATH,
+            "big archive",
+            Some(10),
+            Some(Duration::from_secs(10)),
+        ),
+        (&*SMALL_ARCHIVE_PATH, "small archive", None, None),
+        (&random_path, "random archive", None, None),
+    ] {
+        let len = std::fs::metadata(&path).unwrap().len() as usize;
+        let id = format!("{}({} bytes)", desc, len);
+
+        group.throughput(Throughput::Bytes(len as u64));
+
+        if let Some(n) = n {
+            group.sample_size(n);
+        }
+        if let Some(t) = t {
+            group.measurement_time(t);
+        }
+
+        group.bench_function(BenchmarkId::new(&id, "extract"), |b| {
+            let td = tempdir().unwrap();
+            b.to_async(&rt).iter(|| async {
+                let out_dir = Arc::new(td.path().to_path_buf());
+                let handle = fs::File::open(path).await.unwrap();
+                let mut zip = zip::tokio::read::ZipArchive::new(Box::pin(handle))
+                    .await
+                    .unwrap();
+                Pin::new(&mut zip).extract(out_dir).await.unwrap();
+            })
+        });
+
+        group.bench_function(BenchmarkId::new(&id, "extract_simple"), |b| {
+            let td = tempdir().unwrap();
+            b.to_async(&rt).iter(|| async {
+                let out_dir = Arc::new(td.path().to_path_buf());
+                let handle = fs::File::open(path).await.unwrap();
+                let mut zip = zip::tokio::read::ZipArchive::new(Box::pin(handle))
+                    .await
+                    .unwrap();
+                Pin::new(&mut zip).extract_simple(out_dir).await.unwrap();
+            })
+        });
+
+        group.bench_function(BenchmarkId::new(&id, "sync extract"), |b| {
+            let td = tempdir().unwrap();
+            b.iter(|| {
+                let sync_handle = std::fs::File::open(path).unwrap();
+                let mut zip = zip::read::ZipArchive::new(sync_handle).unwrap();
+                zip.extract(td.path()).unwrap();
+            })
+        });
+    }
+}
+
+criterion_group!(benches, bench_io, bench_extract);
+criterion_main!(benches);
diff --git a/benches/merge_archive.rs b/benches/merge_archive.rs
new file mode 100644
index 000000000..833b85c20
--- /dev/null
+++ b/benches/merge_archive.rs
@@ -0,0 +1,138 @@
+/* use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; */
+
+/* use std::io::Cursor; */
+/* use std::path::{Path, PathBuf}; */
+/* use std::pin::Pin;
*/ +/* use std::sync::Arc; */ +/* use std::time::Duration; */ + +/* use getrandom::getrandom; */ +/* use once_cell::sync::Lazy; */ +/* use tempfile::tempdir; */ +/* use tokio::{fs, io, runtime::Runtime}; */ +/* use uuid::Uuid; */ + +/* use zip::{result::ZipResult, write::FileOptions, ZipWriter}; */ + +/* fn generate_random_archive( */ +/* num_entries: usize, */ +/* entry_size: usize, */ +/* options: FileOptions, */ +/* ) -> ZipResult>> { */ +/* use std::io::Write; */ + +/* let buf = Cursor::new(Vec::new()); */ +/* let mut zip = ZipWriter::new(buf); */ + +/* let mut bytes = vec![0u8; entry_size]; */ +/* for i in 0..num_entries { */ +/* let name = format!("random{}.dat", i); */ +/* zip.start_file(name, options)?; */ +/* getrandom(&mut bytes).unwrap(); */ +/* zip.write_all(&bytes)?; */ +/* } */ + +/* let buf = zip.finish()?.into_inner(); */ + +/* Ok(Cursor::new(buf.into_boxed_slice())) */ +/* } */ + +/* static BIG_ARCHIVE_PATH: Lazy = */ +/* Lazy::new(|| Path::new(env!("CARGO_MANIFEST_DIR")).join("benches/target.zip")); */ + +/* static SMALL_ARCHIVE_PATH: Lazy = */ +/* Lazy::new(|| Path::new(env!("CARGO_MANIFEST_DIR")).join("benches/small-target.zip")); */ + +/* const NUM_ENTRIES: usize = 1_000; */ +/* const ENTRY_SIZE: usize = 10_000; */ + +/* fn perform_merge( */ +/* mut src: ZipArchive, */ +/* mut target: ZipWriter, */ +/* ) -> ZipResult> { */ +/* (target).merge_archive(Pin::new(&mut src))?; */ +/* Ok(target) */ +/* } */ + +/* fn perform_raw_copy_file( */ +/* mut src: ZipArchive, */ +/* mut target: ZipWriter, */ +/* ) -> ZipResult> { */ +/* for i in 0..src.len() { */ +/* let entry = src.by_index(i)?; */ +/* target.raw_copy_file(entry)?; */ +/* } */ +/* Ok(target) */ +/* } */ + +/* const NUM_ENTRIES: usize = 100; */ +/* const ENTRY_SIZE: usize = 1024; */ + +/* fn merge_archive_stored(bench: &mut Bencher) { */ +/* let options = FileOptions::default().compression_method(zip::CompressionMethod::Stored); */ +/* let (len, src) = generate_random_archive(NUM_ENTRIES, ENTRY_SIZE, options).unwrap(); */ + +/* bench.bytes = len as u64; */ + +/* bench.iter(|| { */ +/* let buf = Cursor::new(Vec::new()); */ +/* let zip = ZipWriter::new(buf); */ +/* let mut zip = perform_merge(src.clone(), zip).unwrap(); */ +/* let buf = zip.finish().unwrap().into_inner(); */ +/* assert_eq!(buf.len(), len); */ +/* }); */ +/* } */ + +/* fn merge_archive_compressed(bench: &mut Bencher) { */ +/* let options = FileOptions::default().compression_method(zip::CompressionMethod::Deflated); */ +/* let (len, src) = generate_random_archive(NUM_ENTRIES, ENTRY_SIZE, options).unwrap(); */ + +/* bench.bytes = len as u64; */ + +/* bench.iter(|| { */ +/* let buf = Cursor::new(Vec::new()); */ +/* let zip = ZipWriter::new(buf); */ +/* let mut zip = perform_merge(src.clone(), zip).unwrap(); */ +/* let buf = zip.finish().unwrap().into_inner(); */ +/* assert_eq!(buf.len(), len); */ +/* }); */ +/* } */ + +/* fn merge_archive_raw_copy_file_stored(bench: &mut Bencher) { */ +/* let options = FileOptions::default().compression_method(zip::CompressionMethod::Stored); */ +/* let (len, src) = generate_random_archive(NUM_ENTRIES, ENTRY_SIZE, options).unwrap(); */ + +/* bench.bytes = len as u64; */ + +/* bench.iter(|| { */ +/* let buf = Cursor::new(Vec::new()); */ +/* let zip = ZipWriter::new(buf); */ +/* let mut zip = perform_raw_copy_file(src.clone(), zip).unwrap(); */ +/* let buf = zip.finish().unwrap().into_inner(); */ +/* assert_eq!(buf.len(), len); */ +/* }); */ +/* } */ + +/* fn merge_archive_raw_copy_file_compressed(bench: &mut Bencher) { */ +/* 
let options = FileOptions::default().compression_method(zip::CompressionMethod::Deflated); */
+/*     let (len, src) = generate_random_archive(NUM_ENTRIES, ENTRY_SIZE, options).unwrap(); */
+
+/*     bench.bytes = len as u64; */
+
+/*     bench.iter(|| { */
+/*         let buf = Cursor::new(Vec::new()); */
+/*         let zip = ZipWriter::new(buf); */
+/*         let mut zip = perform_raw_copy_file(src.clone(), zip).unwrap(); */
+/*         let buf = zip.finish().unwrap().into_inner(); */
+/*         assert_eq!(buf.len(), len); */
+/*     }); */
+/* } */
+
+/* benchmark_group!( */
+/*     benches, */
+/*     merge_archive_stored, */
+/*     merge_archive_compressed, */
+/*     merge_archive_raw_copy_file_stored, */
+/*     merge_archive_raw_copy_file_compressed, */
+/* ); */
+/* benchmark_main!(benches); */
diff --git a/benches/small-target.zip b/benches/small-target.zip
new file mode 100644
index 000000000..7451faea8
Binary files /dev/null and b/benches/small-target.zip differ
diff --git a/benches/target.zip b/benches/target.zip
new file mode 100644
index 000000000..f452e1881
Binary files /dev/null and b/benches/target.zip differ
diff --git a/examples/write_dir.rs b/examples/write_dir.rs
index 81305e214..8d9d72f6f 100644
--- a/examples/write_dir.rs
+++ b/examples/write_dir.rs
@@ -2,6 +2,7 @@ use anyhow::Context;
 use std::io::prelude::*;
 use zip::{result::ZipError, write::SimpleFileOptions};
+use cfg_if::cfg_if;
 use std::fs::File;
 use std::path::Path;
 use walkdir::{DirEntry, WalkDir};
@@ -17,15 +18,21 @@ const METHOD_DEFLATED: Option<zip::CompressionMethod> = Some(zip::CompressionMet
 #[cfg(not(feature = "_deflate-any"))]
 const METHOD_DEFLATED: Option<zip::CompressionMethod> = None;
 
-#[cfg(feature = "bzip2")]
-const METHOD_BZIP2: Option<zip::CompressionMethod> = Some(zip::CompressionMethod::Bzip2);
-#[cfg(not(feature = "bzip2"))]
-const METHOD_BZIP2: Option<zip::CompressionMethod> = None;
+cfg_if! {
+    if #[cfg(feature = "bzip2")] {
+        const METHOD_BZIP2: Option<zip::CompressionMethod> = Some(zip::CompressionMethod::Bzip2);
+    } else {
+        const METHOD_BZIP2: Option<zip::CompressionMethod> = None;
+    }
+}
 
-#[cfg(feature = "zstd")]
-const METHOD_ZSTD: Option<zip::CompressionMethod> = Some(zip::CompressionMethod::Zstd);
-#[cfg(not(feature = "zstd"))]
-const METHOD_ZSTD: Option<zip::CompressionMethod> = None;
+cfg_if! {
+    if #[cfg(feature = "zstd")] {
+        const METHOD_ZSTD: Option<zip::CompressionMethod> = Some(zip::CompressionMethod::Zstd);
+    } else {
+        const METHOD_ZSTD: Option<zip::CompressionMethod> = None;
+    }
+}
 
 fn real_main() -> i32 {
     let args: Vec<_> = std::env::args().collect();
diff --git a/src/compression.rs b/src/compression.rs
index 1862c976c..23b53563e 100644
--- a/src/compression.rs
+++ b/src/compression.rs
@@ -1,5 +1,6 @@
 //! Possible ZIP compression methods.
 
+use cfg_if::cfg_if;
 use std::fmt;
 
 #[allow(deprecated)]
@@ -18,6 +19,14 @@ pub enum CompressionMethod {
     Stored,
     /// Compress the file using Deflate
     #[cfg(feature = "_deflate-any")]
+    #[cfg_attr(
+        docsrs,
+        doc(cfg(any(
+            feature = "deflate",
+            feature = "deflate-miniz",
+            feature = "deflate-zlib"
+        )))
+    )]
     Deflated,
     /// Compress the file using Deflate64.
     /// Decoding deflate64 is supported but encoding deflate64 is not supported.
     #[cfg(feature = "deflate64")]
     Deflate64,
     /// Compress the file using BZIP2
     #[cfg(feature = "bzip2")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "bzip2")))]
     Bzip2,
     /// Encrypted using AES.
     ///
     /// The actual compression method has to be taken from the AES extra data field
     /// or from `ZipFileData`.
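+    /// (Stored in archives as compression method ID 99.)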
#[cfg(feature = "aes-crypto")] + #[cfg_attr(docsrs, doc(cfg(feature = "aes-crypto")))] Aes, /// Compress the file using ZStandard #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] Zstd, /// Compress the file using LZMA #[cfg(feature = "lzma")] @@ -68,26 +80,33 @@ impl CompressionMethod { pub const BZIP2: Self = CompressionMethod::Bzip2; #[cfg(not(feature = "bzip2"))] pub const BZIP2: Self = CompressionMethod::Unsupported(12); - #[cfg(not(feature = "lzma"))] pub const LZMA: Self = CompressionMethod::Unsupported(14); #[cfg(feature = "lzma")] pub const LZMA: Self = CompressionMethod::Lzma; pub const IBM_ZOS_CMPSC: Self = CompressionMethod::Unsupported(16); pub const IBM_TERSE: Self = CompressionMethod::Unsupported(18); pub const ZSTD_DEPRECATED: Self = CompressionMethod::Unsupported(20); - #[cfg(feature = "zstd")] - pub const ZSTD: Self = CompressionMethod::Zstd; - #[cfg(not(feature = "zstd"))] - pub const ZSTD: Self = CompressionMethod::Unsupported(93); + cfg_if! { + if #[cfg(feature = "zstd")] { + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + pub const ZSTD: Self = CompressionMethod::Zstd; + } else { + pub const ZSTD: Self = CompressionMethod::Unsupported(93); + } + } pub const MP3: Self = CompressionMethod::Unsupported(94); pub const XZ: Self = CompressionMethod::Unsupported(95); pub const JPEG: Self = CompressionMethod::Unsupported(96); pub const WAVPACK: Self = CompressionMethod::Unsupported(97); pub const PPMD: Self = CompressionMethod::Unsupported(98); - #[cfg(feature = "aes-crypto")] - pub const AES: Self = CompressionMethod::Aes; - #[cfg(not(feature = "aes-crypto"))] - pub const AES: Self = CompressionMethod::Unsupported(99); + cfg_if! { + if #[cfg(feature = "aes-crypto")] { + #[cfg_attr(docsrs, doc(cfg(feature = "aes-crypto")))] + pub const AES: Self = CompressionMethod::Aes; + } else { + pub const AES: Self = CompressionMethod::Unsupported(99); + } + } } impl CompressionMethod { /// Converts an u16 to its corresponding CompressionMethod @@ -174,12 +193,20 @@ impl fmt::Display for CompressionMethod { pub const SUPPORTED_COMPRESSION_METHODS: &[CompressionMethod] = &[ CompressionMethod::Stored, #[cfg(feature = "_deflate-any")] + /* NB: these don't appear to show up in the docs. */ + #[cfg_attr(docsrs, doc(cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))))] CompressionMethod::Deflated, #[cfg(feature = "deflate64")] CompressionMethod::Deflate64, #[cfg(feature = "bzip2")] + #[cfg_attr(docsrs, doc(cfg(feature = "bzip2")))] CompressionMethod::Bzip2, #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] CompressionMethod::Zstd, ]; diff --git a/src/crc32.rs b/src/crc32.rs index 7152c0853..85c19c528 100644 --- a/src/crc32.rs +++ b/src/crc32.rs @@ -1,7 +1,6 @@ //! Helper module to compute a CRC32 checksum use std::io; -use std::io::prelude::*; use crc32fast::Hasher; @@ -27,6 +26,7 @@ impl Crc32Reader { } } + #[inline] fn check_matches(&self) -> bool { self.check == self.hasher.clone().finalize() } @@ -36,7 +36,7 @@ impl Crc32Reader { } } -impl Read for Crc32Reader { +impl std::io::Read for Crc32Reader { fn read(&mut self, buf: &mut [u8]) -> io::Result { let invalid_check = !buf.is_empty() && !self.check_matches() && !self.ae2_encrypted; @@ -55,6 +55,7 @@ impl Read for Crc32Reader { #[cfg(test)] mod test { use super::*; + use std::io::Read; #[test] fn test_empty_reader() { diff --git a/src/lib.rs b/src/lib.rs index baf29a1c4..32f4d974b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,12 @@ //! 
| ZipCrypto deprecated encryption | ✅ | ✅ | //! //! +//! + +#![feature(can_vector)] + #![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_cfg))] pub use crate::compression::{CompressionMethod, SUPPORTED_COMPRESSION_METHODS}; pub use crate::read::ZipArchive; @@ -59,3 +64,6 @@ zip = \"="] #[doc = "\"\n\ ```"] pub mod unstable; + +#[cfg(feature = "tokio-async")] +pub mod tokio; diff --git a/src/read.rs b/src/read.rs index 0a39faef0..81e198f2e 100644 --- a/src/read.rs +++ b/src/read.rs @@ -11,6 +11,7 @@ use crate::spec; use crate::types::{AesMode, AesVendorVersion, DateTime, System, ZipFileData}; use crate::zipcrypto::{ZipCryptoReader, ZipCryptoReaderValid, ZipCryptoValidator}; use byteorder::{LittleEndian, ReadBytesExt}; +use cfg_if::cfg_if; use std::borrow::{Borrow, Cow}; use std::collections::HashMap; use std::io::{self, prelude::*}; @@ -123,16 +124,19 @@ impl<'a> CryptoReader<'a> { /// Returns `true` if the data is encrypted using AE2. pub const fn is_ae2_encrypted(&self) -> bool { - #[cfg(feature = "aes-crypto")] - return matches!( - self, - CryptoReader::Aes { - vendor_version: AesVendorVersion::Ae2, - .. + cfg_if! { + if #[cfg(feature = "aes-crypto")] { + return matches!( + self, + CryptoReader::Aes { + vendor_version: AesVendorVersion::Ae2, + .. + } + ); + } else { + false } - ); - #[cfg(not(feature = "aes-crypto"))] - false + } } } @@ -250,20 +254,21 @@ pub(crate) fn make_crypto_reader<'a>( } let reader = match (password, aes_info) { - #[cfg(not(feature = "aes-crypto"))] - (Some(_), Some(_)) => { - return Err(ZipError::UnsupportedArchive( - "AES encrypted files cannot be decrypted without the aes-crypto feature.", - )) - } - #[cfg(feature = "aes-crypto")] (Some(password), Some((aes_mode, vendor_version))) => { - match AesReader::new(reader, aes_mode, compressed_size).validate(password)? { - None => return Err(InvalidPassword), - Some(r) => CryptoReader::Aes { - reader: r, - vendor_version, - }, + cfg_if! { + if #[cfg(feature = "aes-crypto")] { + match AesReader::new(reader, aes_mode, compressed_size).validate(password)? { + None => return Err(InvalidPassword), + Some(r) => CryptoReader::Aes { + reader: r, + vendor_version, + }, + } + } else { + return Err(ZipError::UnsupportedArchive( + "AES encrypted files cannot be decrypted without the aes-crypto feature.", + )) + } } } (Some(password), None) => { @@ -812,7 +817,7 @@ fn central_header_to_zip_file_inner( // Construct the result let mut result = ZipFileData { - system: System::from_u8((version_made_by >> 8) as u8), + system: System::from((version_made_by >> 8) as u8), version_made_by: version_made_by as u8, encrypted, using_data_descriptor, @@ -1178,7 +1183,7 @@ pub fn read_zipfile_from_stream<'a, R: Read>(reader: &'a mut R) -> ZipResult> 8) as u8), + system: System::from((version_made_by >> 8) as u8), version_made_by: version_made_by as u8, encrypted, using_data_descriptor, diff --git a/src/result.rs b/src/result.rs index f2bb46099..d3e7cb5e5 100644 --- a/src/result.rs +++ b/src/result.rs @@ -1,28 +1,32 @@ //! 
Error types that can be emitted from this library +use displaydoc::Display; +use thiserror::Error; + use std::error::Error; use std::fmt; use std::io; use std::io::IntoInnerError; use std::num::TryFromIntError; +use std::ops::{Range, RangeInclusive}; /// Generic result type with ZipError as its error variant pub type ZipResult = Result; /// Error type for Zip -#[derive(Debug)] +#[derive(Debug, Display, Error)] #[non_exhaustive] pub enum ZipError { - /// An Error caused by I/O - Io(io::Error), + /// i/o error: {0} + Io(#[from] io::Error), - /// This file is probably not a zip archive + /// invalid Zip archive: {0} InvalidArchive(&'static str), - /// This archive is not supported + /// unsupported Zip archive: {0} UnsupportedArchive(&'static str), - /// The requested file could not be found in the archive + /// specified file not found in archive FileNotFound, /// The password provided is incorrect @@ -110,5 +114,3 @@ impl fmt::Display for DateTimeRangeError { ) } } - -impl Error for DateTimeRangeError {} diff --git a/src/spec.rs b/src/spec.rs index b620c01e7..f09e38034 100644 --- a/src/spec.rs +++ b/src/spec.rs @@ -1,8 +1,12 @@ use crate::result::{ZipError, ZipResult}; + use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::borrow::Cow; -use std::io; +use tokio::io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use std::io::prelude::*; +use std::io::Seek; +use std::pin::Pin; +use std::{cmp, io::IoSlice, mem}; use std::path::{Component, Path}; pub const LOCAL_FILE_HEADER_SIGNATURE: u32 = 0x04034b50; @@ -24,6 +28,18 @@ pub struct CentralDirectoryEnd { pub zip_file_comment: Vec, } +#[repr(packed)] +struct CentralDirectoryEndBuffer { + pub magic: u32, + pub disk_number: u16, + pub disk_with_central_directory: u16, + pub number_of_files_on_this_disk: u16, + pub number_of_files: u16, + pub central_directory_size: u32, + pub central_directory_offset: u32, + pub zip_file_comment_length: u16, +} + impl CentralDirectoryEnd { pub fn parse(reader: &mut T) -> ZipResult { let magic = reader.read_u32::()?; @@ -51,23 +67,26 @@ impl CentralDirectoryEnd { }) } - pub fn find_and_parse(reader: &mut T) -> ZipResult<(CentralDirectoryEnd, u64)> { + pub async fn parse_async(mut reader: Pin<&mut T>) -> ZipResult<(CentralDirectoryEnd, u64)> { + static_assertions::assert_eq_size!([u8; 22], CentralDirectoryEndBuffer); + { const HEADER_SIZE: u64 = 22; const BYTES_BETWEEN_MAGIC_AND_COMMENT_SIZE: u64 = HEADER_SIZE - 6; let file_length = reader.seek(io::SeekFrom::End(0))?; - let search_upper_bound = file_length.saturating_sub(HEADER_SIZE + u16::MAX as u64); + let search_upper_bound = + file_length.saturating_sub(Self::HEADER_SIZE + u16::MAX as u64); - if file_length < HEADER_SIZE { + if file_length < Self::HEADER_SIZE { return Err(ZipError::InvalidArchive("Invalid zip header")); } - let mut pos = file_length - HEADER_SIZE; + let mut pos = file_length - Self::HEADER_SIZE; while pos >= search_upper_bound { reader.seek(io::SeekFrom::Start(pos))?; if reader.read_u32::()? 
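+                /* (the EOCD magic is stored little-endian; this scans backwards
+                 * from the end of the file until the signature is found) */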
== CENTRAL_DIRECTORY_END_SIGNATURE { reader.seek(io::SeekFrom::Current( - BYTES_BETWEEN_MAGIC_AND_COMMENT_SIZE as i64, + Self::BYTES_BETWEEN_MAGIC_AND_COMMENT_SIZE as i64, ))?; let cde_start_pos = reader.seek(io::SeekFrom::Start(pos))?; if let Ok(end_header) = CentralDirectoryEnd::parse(reader) { @@ -96,6 +115,34 @@ impl CentralDirectoryEnd { writer.write_all(&self.zip_file_comment)?; Ok(()) } + + pub async fn write_async(&self, mut writer: Pin<&mut T>) -> ZipResult<()> { + let block: [u8; 22] = unsafe { + mem::transmute(CentralDirectoryEndBuffer { + magic: CENTRAL_DIRECTORY_END_SIGNATURE, + disk_number: self.disk_number, + disk_with_central_directory: self.disk_with_central_directory, + number_of_files_on_this_disk: self.number_of_files_on_this_disk, + number_of_files: self.number_of_files, + central_directory_size: self.central_directory_size, + central_directory_offset: self.central_directory_offset, + zip_file_comment_length: self.zip_file_comment.len() as u16, + }) + }; + + if writer.is_write_vectored() { + /* TODO: zero-copy!! */ + let block = IoSlice::new(&block); + let comment = IoSlice::new(&self.zip_file_comment); + writer.write_vectored(&[block, comment]).await?; + } else { + /* If no special vector write support, just perform two separate writes. */ + writer.write_all(&block).await?; + writer.write_all(&self.zip_file_comment).await?; + } + + Ok(()) + } } pub struct Zip64CentralDirectoryEndLocator { @@ -104,6 +151,14 @@ pub struct Zip64CentralDirectoryEndLocator { pub number_of_disks: u32, } +#[repr(packed)] +struct Zip64CentralDirectoryEndLocatorBuffer { + pub magic: u32, + pub disk_with_central_directory: u32, + pub end_of_central_directory_offset: u64, + pub number_of_disks: u32, +} + impl Zip64CentralDirectoryEndLocator { pub fn parse(reader: &mut T) -> ZipResult { let magic = reader.read_u32::()?; @@ -123,6 +178,31 @@ impl Zip64CentralDirectoryEndLocator { }) } + pub async fn parse_async(mut reader: Pin<&mut T>) -> ZipResult { + static_assertions::assert_eq_size!([u8; 20], Zip64CentralDirectoryEndLocatorBuffer); + let mut info = [0u8; 20]; + reader.read_exact(&mut info[..]).await?; + + let Zip64CentralDirectoryEndLocatorBuffer { + magic, + disk_with_central_directory, + end_of_central_directory_offset, + number_of_disks, + } = unsafe { mem::transmute(info) }; + + if magic != ZIP64_CENTRAL_DIRECTORY_END_LOCATOR_SIGNATURE { + return Err(ZipError::InvalidArchive( + "Invalid zip64 locator digital signature header", + )); + } + + Ok(Zip64CentralDirectoryEndLocator { + disk_with_central_directory, + end_of_central_directory_offset, + number_of_disks, + }) + } + pub fn write(&self, writer: &mut T) -> ZipResult<()> { writer.write_u32::(ZIP64_CENTRAL_DIRECTORY_END_LOCATOR_SIGNATURE)?; writer.write_u32::(self.disk_with_central_directory)?; @@ -130,6 +210,28 @@ impl Zip64CentralDirectoryEndLocator { writer.write_u32::(self.number_of_disks)?; Ok(()) } + + pub async fn write_async(&self, mut writer: Pin<&mut T>) -> ZipResult<()> { + let block: [u8; 20] = unsafe { + mem::transmute(Zip64CentralDirectoryEndLocatorBuffer { + magic: ZIP64_CENTRAL_DIRECTORY_END_LOCATOR_SIGNATURE, + disk_with_central_directory: self.disk_with_central_directory, + end_of_central_directory_offset: self.end_of_central_directory_offset, + number_of_disks: self.number_of_disks, + }) + }; + + if writer.is_write_vectored() { + /* TODO: zero-copy?? */ + let block = IoSlice::new(&block); + writer.write_vectored(&[block]).await?; + } else { + /* If no special vector write support, just perform a normal write. 
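+         * (Note: write_vectored may itself perform a short write; a fully robust
+         * version would loop until the whole block is flushed.)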
*/ + writer.write_all(&block).await?; + } + + Ok(()) + } } pub struct Zip64CentralDirectoryEnd { @@ -144,6 +246,20 @@ pub struct Zip64CentralDirectoryEnd { //pub extensible_data_sector: Vec, <-- We don't do anything with this at the moment. } +#[repr(packed)] +struct Zip64CentralDirectoryEndBuffer { + pub magic: u32, + pub record_size: u64, /* this should always be 44! */ + pub version_made_by: u16, + pub version_needed_to_extract: u16, + pub disk_number: u32, + pub disk_with_central_directory: u32, + pub number_of_files_on_this_disk: u64, + pub number_of_files: u64, + pub central_directory_size: u64, + pub central_directory_offset: u64, +} + impl Zip64CentralDirectoryEnd { pub fn find_and_parse( reader: &mut T, @@ -200,6 +316,83 @@ impl Zip64CentralDirectoryEnd { } } + pub async fn parse_async(mut reader: Pin<&mut T>) -> ZipResult { + static_assertions::assert_eq_size!([u8; 56], Zip64CentralDirectoryEndBuffer); + let mut info = [0u8; 56]; + reader.read_exact(&mut info[..]).await?; + + let Zip64CentralDirectoryEndBuffer { + magic, + record_size, + version_made_by, + version_needed_to_extract, + disk_number, + disk_with_central_directory, + number_of_files_on_this_disk, + number_of_files, + central_directory_size, + central_directory_offset, + } = unsafe { mem::transmute(info) }; + + assert_eq!(record_size, 44); + + if magic != ZIP64_CENTRAL_DIRECTORY_END_SIGNATURE { + return Err(ZipError::InvalidArchive("Invalid digital signature header")); + } + + Ok(Zip64CentralDirectoryEnd { + version_made_by, + version_needed_to_extract, + disk_number, + disk_with_central_directory, + number_of_files_on_this_disk, + number_of_files, + central_directory_size, + central_directory_offset, + }) + } + + pub const ZIP64_SEARCH_BUFFER_SIZE: usize = 2 * CentralDirectoryEnd::SEARCH_BUFFER_SIZE; + + pub async fn find_and_parse_async( + mut reader: Pin<&mut T>, + nominal_offset: u64, + search_upper_bound: u64, + ) -> ZipResult<(Self, u64)> { + let mut rightmost_frontier = reader.seek(io::SeekFrom::Start(nominal_offset)).await?; + + let mut buf = [0u8; Self::ZIP64_SEARCH_BUFFER_SIZE]; + while rightmost_frontier <= search_upper_bound { + let remaining = search_upper_bound - rightmost_frontier; + let cur_len = cmp::min(remaining as usize, buf.len()); + let cur_buf: &mut [u8] = &mut buf[..cur_len]; + + reader.read_exact(cur_buf).await?; + + if let Some(index_within_buffer) = memchr2::memmem::find( + &cur_buf, + &ZIP64_CENTRAL_DIRECTORY_END_SIGNATURE.to_le_bytes()[..], + ) { + let zip64_central_directory_end = rightmost_frontier + index_within_buffer as u64; + + reader + .seek(io::SeekFrom::Start(zip64_central_directory_end)) + .await?; + + let archive_offset = zip64_central_directory_end - nominal_offset; + return Zip64CentralDirectoryEnd::parse_async(reader) + .await + .map(|cde| (cde, archive_offset)); + } else { + rightmost_frontier += cur_len as u64; + } + } + + Err(ZipError::InvalidArchive( + "Could not find ZIP64 central directory end", + )) + } + pub fn write(&self, writer: &mut T) -> ZipResult<()> { writer.write_u32::(ZIP64_CENTRAL_DIRECTORY_END_SIGNATURE)?; writer.write_u64::(44)?; // record size @@ -213,6 +406,34 @@ impl Zip64CentralDirectoryEnd { writer.write_u64::(self.central_directory_offset)?; Ok(()) } + + pub async fn write_async(&self, mut writer: Pin<&mut T>) -> ZipResult<()> { + let block: [u8; 56] = unsafe { + mem::transmute(Zip64CentralDirectoryEndBuffer { + magic: ZIP64_CENTRAL_DIRECTORY_END_SIGNATURE, + record_size: 44, + version_made_by: self.version_made_by, + version_needed_to_extract: 
self.version_needed_to_extract, + disk_number: self.disk_number, + disk_with_central_directory: self.disk_with_central_directory, + number_of_files_on_this_disk: self.number_of_files_on_this_disk, + number_of_files: self.number_of_files, + central_directory_size: self.central_directory_size, + central_directory_offset: self.central_directory_offset, + }) + }; + + if writer.is_write_vectored() { + /* TODO: zero-copy?? */ + let block = IoSlice::new(&block); + writer.write_vectored(&[block]).await?; + } else { + /* If no special vector write support, just perform a normal write. */ + writer.write_all(&block).await?; + } + + Ok(()) + } } /// Converts a path to the ZIP format (forward-slash-delimited and normalized). @@ -237,3 +458,48 @@ pub(crate) fn path_to_string>(path: T) -> String { } normalized_components.join("/") } + +#[repr(packed)] +pub struct LocalHeaderBuffer { + // local file header signature + pub magic: u32, + // version needed to extract + pub version_needed_to_extract: u16, + // general purpose bit flag + pub flag: u16, + // Compression method + pub compression_method: u16, + // last mod file time and last mod file date + pub last_modified_time_timepart: u16, + pub last_modified_time_datepart: u16, + // crc-32 + pub crc32: u32, + // compressed size and uncompressed size + pub compressed_size: u32, + pub uncompressed_size: u32, + // file name length + pub file_name_length: u16, + // extra field length + pub extra_field_length: u16, +} + +#[repr(packed)] +pub struct CentralDirectoryHeaderBuffer { + pub magic: u32, + pub version_made_by: u16, + pub version_needed: u16, + pub flag: u16, + pub compression_method: u16, + pub last_modified_time_timepart: u16, + pub last_modified_time_datepart: u16, + pub crc32: u32, + pub compressed_size: u32, + pub uncompressed_size: u32, + pub file_name_length: u16, + pub extra_field_length: u16, + pub file_comment_length: u16, + pub disk_number_start: u16, + pub internal_attributes: u16, + pub external_attributes: u32, + pub header_start: u32, +} diff --git a/src/tokio/buf_reader.rs b/src/tokio/buf_reader.rs new file mode 100644 index 000000000..cd8cc11d8 --- /dev/null +++ b/src/tokio/buf_reader.rs @@ -0,0 +1,226 @@ +/* Taken from https://docs.rs/tokio/latest/src/tokio/io/util/buf_reader.rs.html to fix a few + * issues. */ + +use crate::tokio::WrappedPin; + +#[cfg(doc)] +use tokio::io::AsyncRead; +use tokio::io::{self, AsyncBufRead}; + +use std::{ + cmp, fmt, + num::NonZeroUsize, + pin::Pin, + task::{ready, Context, Poll}, +}; + +// used by `BufReader` and `BufWriter` +// https://github.com/rust-lang/rust/blob/master/library/std/src/sys_common/io.rs#L1 +const DEFAULT_BUF_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(8 * 1024) }; + +/// The `BufReader` struct adds buffering to any reader. +/// +/// It can be excessively inefficient to work directly with a [`AsyncRead`] +/// instance. A `BufReader` performs large, infrequent reads on the underlying +/// [`AsyncRead`] and maintains an in-memory buffer of the results. +/// +/// `BufReader` can improve the speed of programs that make *small* and +/// *repeated* read calls to the same file or network socket. It does not +/// help when reading very large amounts at once, or reading just one or a few +/// times. It also provides no advantage when reading from a source that is +/// already in memory, like a `Vec`. +/// +/// When the `BufReader` is dropped, the contents of its buffer will be +/// discarded. Creating multiple instances of a `BufReader` on the same +/// stream can cause data loss. 
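+/// The internal buffer holds 8 KiB by default; use [`BufReader::with_capacity`]
+/// to pick a different size.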
+/// +///``` +/// # fn main() -> std::io::Result<()> { tokio_test::block_on(async { +/// use zip::tokio::buf_reader::BufReader; +/// use tokio::io::AsyncReadExt; +/// use std::{io::Cursor, pin::Pin}; +/// +/// let msg = "hello"; +/// let buf = Cursor::new(msg.as_bytes()); +/// let mut buf_reader = BufReader::new(Box::pin(buf)); +/// +/// let mut s = String::new(); +/// buf_reader.read_to_string(&mut s).await?; +/// assert_eq!(&s, &msg); +/// # Ok(()) +/// # })} +///``` +pub struct BufReader { + inner: Pin>, + buf: Box<[u8]>, + pos: usize, + cap: usize, +} + +struct BufProj<'a, R> { + pub inner: Pin<&'a mut R>, + pub buf: &'a mut Box<[u8]>, +} + +impl BufReader { + /// Creates a new `BufReader` with a default buffer capacity. The default is currently 8 KB, + /// but may change in the future. + pub fn new(inner: Pin>) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, inner) + } + + /// Creates a new `BufReader` with the specified buffer capacity. + pub fn with_capacity(capacity: NonZeroUsize, inner: Pin>) -> Self { + let buffer = vec![0; capacity.into()]; + Self { + inner, + buf: buffer.into_boxed_slice(), + pos: 0, + cap: 0, + } + } + + #[inline] + pub fn capacity(&self) -> NonZeroUsize { + unsafe { NonZeroUsize::new_unchecked(self.buf.len()) } + } + + #[inline] + fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().inner + } + + #[inline] + fn project(self: Pin<&mut Self>) -> BufProj<'_, R> { + unsafe { + let Self { inner, buf, .. } = self.get_unchecked_mut(); + BufProj { + inner: Pin::new_unchecked(inner.as_mut().get_unchecked_mut()), + buf, + } + } + } + + #[inline] + fn is_empty(&self) -> bool { + self.cap == self.pos + } + + /// Returns a reference to the internally buffered data. + /// + /// Unlike `fill_buf`, this will not attempt to fill the buffer if it is empty. + #[inline] + pub fn buffer(&self) -> &[u8] { + &self.buf[self.pos..self.cap] + } + + /// Invalidates all data in the internal buffer. + #[inline] + fn discard_buffer(&mut self) { + self.pos = 0; + self.cap = 0; + } + + #[inline] + fn reset_buffer(&mut self, len: usize) { + self.pos = 0; + self.cap = len; + } + + #[inline] + fn request_is_larger_than_buffer(&self, buf: &io::ReadBuf<'_>) -> bool { + buf.remaining() >= self.capacity().get() + } + + #[inline] + fn should_bypass_buffer(&self, buf: &io::ReadBuf<'_>) -> bool { + self.is_empty() && self.request_is_larger_than_buffer(buf) + } +} + +impl WrappedPin for BufReader { + /// Consumes this `BufReader`, returning the underlying reader. + /// + /// Note that any leftover data in the internal buffer is lost. + fn unwrap_inner_pin(self) -> Pin> { + self.inner + } +} + +impl BufReader { + fn bypass_buffer( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); + self.discard_buffer(); + Poll::Ready(res) + } + + fn reset_to_single_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let len = { + let me = self.as_mut().project(); + let mut buf = io::ReadBuf::new(me.buf); + ready!(me.inner.poll_read(cx, &mut buf))?; + buf.filled().len() + }; + + self.reset_buffer(len); + Poll::Ready(Ok(())) + } +} + +impl io::AsyncRead for BufReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + // If we don't have any buffered data and we're doing a massive read + // (larger than our internal buffer), bypass our internal buffer + // entirely. 
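+        // (As in std::io::BufReader, the bypass only fires when our buffer is
+        // empty and the caller's buffer is at least our full capacity, so no
+        // buffered bytes can be skipped.)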
+ if self.should_bypass_buffer(buf) { + return self.bypass_buffer(cx, buf); + } + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[..amt]); + self.consume(amt); + Poll::Ready(Ok(())) + } +} + +impl io::AsyncBufRead for BufReader { + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_empty() { + ready!(self.as_mut().reset_to_single_read(cx))?; + } + let buf: &[u8] = self.into_ref().get_ref().buffer(); + Poll::Ready(Ok(buf)) + } + + #[inline] + fn consume(self: Pin<&mut Self>, amt: usize) { + if amt == 0 { + return; + } + let me = self.get_mut(); + me.pos = cmp::min(me.pos + amt, me.cap); + } +} + +impl fmt::Debug for BufReader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BufReader") + .field("reader", &self.inner) + .field( + "buffer", + &format_args!("{}/{}", self.cap - self.pos, self.buf.len()), + ) + .finish() + } +} diff --git a/src/tokio/buf_writer.rs b/src/tokio/buf_writer.rs new file mode 100644 index 000000000..1449fdfa4 --- /dev/null +++ b/src/tokio/buf_writer.rs @@ -0,0 +1,301 @@ +use crate::tokio::WrappedPin; + +use tokio::io; + +use std::{ + cell::UnsafeCell, + num::NonZeroUsize, + pin::Pin, + task::{ready, Context, Poll}, +}; + +pub trait AsyncBufWrite: io::AsyncWrite { + fn consume_read(self: Pin<&mut Self>, amt: NonZeroUsize); + fn readable_data(&self) -> &[u8]; + + fn consume_write(self: Pin<&mut Self>, amt: NonZeroUsize); + fn try_writable(self: Pin<&mut Self>) -> Option>; + fn poll_writable( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>; + + fn reset(self: Pin<&mut Self>); +} + +// used by `BufReader` and `BufWriter` +// https://github.com/rust-lang/rust/blob/master/library/std/src/sys_common/io.rs#L1 +const DEFAULT_BUF_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(8 * 1024) }; + +///``` +/// # fn main() -> std::io::Result<()> { tokio_test::block_on(async { +/// use zip::tokio::{WrappedPin, buf_writer::{AsyncBufWrite, BufWriter}}; +/// use tokio::io::AsyncWriteExt; +/// use std::{io::Cursor, pin::Pin}; +/// +/// let msg = "hello\n"; +/// let mut buf_writer = BufWriter::new(Box::pin(Cursor::new(Vec::new()))); +/// +/// buf_writer.write_all(msg.as_bytes()).await?; +/// buf_writer.flush().await?; +/// buf_writer.shutdown().await?; +/// let buf: Vec = Pin::into_inner(buf_writer.unwrap_inner_pin()).into_inner(); +/// let s = std::str::from_utf8(&buf).unwrap(); +/// assert_eq!(&s, &msg); +/// # Ok(()) +/// # })} +///``` +pub struct BufWriter { + inner: Pin>, + buf: Box<[u8]>, + read_end: usize, + write_end: usize, +} + +struct BufProj<'a, W> { + pub inner: Pin<&'a mut W>, + pub buf: &'a mut Box<[u8]>, + pub read_end: &'a mut usize, + pub write_end: &'a mut usize, +} + +impl BufWriter { + pub fn new(inner: Pin>) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, inner) + } + + pub fn with_capacity(capacity: NonZeroUsize, inner: Pin>) -> Self { + let buffer = vec![0; capacity.get()]; + Self { + inner, + buf: buffer.into_boxed_slice(), + read_end: 0, + write_end: 0, + } + } + + #[inline] + pub fn capacity(&self) -> NonZeroUsize { + unsafe { NonZeroUsize::new_unchecked(self.buf.len()) } + } + + #[inline] + fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().inner + } + + #[inline] + fn project(self: Pin<&mut Self>) -> BufProj<'_, W> { + unsafe { + let Self { + inner, + buf, + read_end, + write_end, + } = self.get_unchecked_mut(); + BufProj { + inner: 
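+                // SAFETY: `inner` is never moved out of `self`, so re-pinning the
+                // projected reference here upholds the `Pin` contract.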
Pin::new_unchecked(inner.as_mut().get_unchecked_mut()), + buf, + read_end, + write_end, + } + } + } +} + +impl WrappedPin for BufWriter { + fn unwrap_inner_pin(self) -> Pin> { + self.inner + } +} + +impl BufWriter { + fn flush_one_readable(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + assert!(!self.readable_data().is_empty()); + + let me = self.as_mut().project(); + let read_buf: &[u8] = &me.buf[*me.read_end..*me.write_end]; + let written_to_inner: usize = ready!(me.inner.poll_write(cx, read_buf))?; + match NonZeroUsize::new(written_to_inner) { + None => { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + Some(read) => { + self.consume_read(read); + } + } + + Poll::Ready(Ok(())) + } + + fn flush_readable(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while !self.readable_data().is_empty() { + ready!(self.as_mut().flush_one_readable(cx))?; + } + Poll::Ready(Ok(())) + } +} + +impl AsyncBufWrite for BufWriter { + #[inline] + fn consume_read(self: Pin<&mut Self>, amt: NonZeroUsize) { + debug_assert!(self.readable_data().len() >= amt.get()); + let me = self.project(); + *me.read_end += amt.get(); + } + + #[inline] + fn readable_data(&self) -> &[u8] { + debug_assert!(self.read_end <= self.write_end); + debug_assert!(self.write_end <= self.buf.len()); + &self.buf[self.read_end..self.write_end] + } + + #[inline] + fn consume_write(self: Pin<&mut Self>, amt: NonZeroUsize) { + debug_assert!(self.capacity().get() - self.write_end >= amt.get()); + let me = self.project(); + *me.write_end += amt.get(); + } + + #[inline] + fn try_writable(self: Pin<&mut Self>) -> Option> { + if self.write_end == self.buf.len() { + return None; + } + let me = self.project(); + NonEmptyWriteSlice::new(&mut me.buf[*me.write_end..]) + } + + fn poll_writable( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let s = UnsafeCell::new(self); + if let Some(write_buf) = unsafe { &mut *s.get() }.as_mut().try_writable() { + return Poll::Ready(Ok(write_buf)); + } + + ready!(unsafe { &mut *s.get() }.as_mut().flush_readable(cx))?; + + unsafe { &mut *s.get() }.as_mut().reset(); + + Poll::Ready(Ok(s.into_inner().try_writable().unwrap())) + } + + #[inline] + fn reset(self: Pin<&mut Self>) { + let me = self.project(); + *me.read_end = 0; + *me.write_end = 0; + } +} + +impl io::AsyncWrite for BufWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let buf = NonEmptyReadSlice::new(buf).unwrap(); + let mut rem: NonEmptyWriteSlice<'_, u8> = ready!(self.as_mut().poll_writable(cx))?; + + let amt = rem.copy_from_slice(buf); + dbg!(amt); + self.as_mut().consume_write(amt); + + Poll::Ready(Ok(amt.get())) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().flush_readable(cx))?; + self.get_pin_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + self.get_pin_mut().poll_shutdown(cx) + } +} + +pub mod slices { + use std::{cmp, mem, num::NonZeroUsize, ops}; + + #[derive(Debug, Copy, Clone)] + pub struct NonEmptyReadSlice<'a, T> { + data: &'a [T], + } + + impl<'a, T> NonEmptyReadSlice<'a, T> { + pub fn new(data: &'a [T]) -> Option { + NonZeroUsize::new(data.len()).map(|_| Self { data }) + } + + #[inline] + pub fn len(&self) -> NonZeroUsize { + unsafe { NonZeroUsize::new_unchecked(self.data.len()) } + } + + #[inline] + pub fn maybe_uninit(&self) -> &'a [mem::MaybeUninit] { + unsafe { 
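+            // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and this shared
+            // view is only ever read from, never written through.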
mem::transmute(&*self.data) } + } + } + + impl<'a, T> ops::Deref for NonEmptyReadSlice<'a, T> { + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + &self.data + } + } + + #[derive(Debug)] + pub struct NonEmptyWriteSlice<'a, T> { + data: &'a mut [T], + } + + impl<'a, T> ops::Deref for NonEmptyWriteSlice<'a, T> { + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + &self.data + } + } + + impl<'a, T> ops::DerefMut for NonEmptyWriteSlice<'a, T> { + #[inline] + fn deref_mut(&mut self) -> &mut [T] { + &mut self.data + } + } + + impl<'a, T> NonEmptyWriteSlice<'a, T> { + pub fn new(data: &'a mut [T]) -> Option { + NonZeroUsize::new(data.len()).map(|_| Self { data }) + } + + #[inline] + pub fn len(&self) -> NonZeroUsize { + unsafe { NonZeroUsize::new_unchecked(self.data.len()) } + } + + #[inline] + pub fn maybe_uninit(&mut self) -> &mut [mem::MaybeUninit] { + unsafe { mem::transmute(&mut *self.data) } + } + } + + impl<'a, T: Copy> NonEmptyWriteSlice<'a, T> { + pub fn copy_from_slice(&mut self, src: NonEmptyReadSlice<'a, T>) -> NonZeroUsize { + let amt = cmp::min(self.len(), src.len()); + let dst: &mut [mem::MaybeUninit] = self.maybe_uninit(); + let src: &[mem::MaybeUninit] = src.maybe_uninit(); + dst[..amt.get()].copy_from_slice(&src[..amt.get()]); + amt + } + } +} +pub use slices::{NonEmptyReadSlice, NonEmptyWriteSlice}; diff --git a/src/tokio/channels.rs b/src/tokio/channels.rs new file mode 100755 index 000000000..5469c5c02 --- /dev/null +++ b/src/tokio/channels.rs @@ -0,0 +1,852 @@ +#![allow(missing_docs)] + +use std::sync::atomic::{AtomicU8, Ordering}; + +#[derive(Debug)] +pub enum LeaseBehavior { + AllowSpuriousFailures, + NoSpuriousFailures, +} + +impl LeaseBehavior { + #[inline] + pub fn try_acquire_lease(self, state: &AtomicU8) -> bool { + match self { + Self::AllowSpuriousFailures => state.compare_exchange_weak( + PermitState::Unleashed.into(), + PermitState::TakenOut.into(), + Ordering::AcqRel, + Ordering::Relaxed, + ), + Self::NoSpuriousFailures => state.compare_exchange( + PermitState::Unleashed.into(), + PermitState::TakenOut.into(), + Ordering::AcqRel, + Ordering::Relaxed, + ), + } + .is_ok() + } +} + +#[derive(Debug, Copy, Clone)] +pub enum Lease { + NoSpace, + PossiblyTaken, + Taken(Permit), +} + +impl Lease { + #[inline] + pub fn option(self) -> Option { + match self { + Self::NoSpace => None, + Self::PossiblyTaken => None, + Self::Taken(permit) => Some(permit), + } + } +} + +#[repr(u8)] +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Default, num_enum::TryFromPrimitive, num_enum::IntoPrimitive, +)] +pub enum PermitState { + #[default] + Unleashed = 0, + TakenOut = 1, +} + +pub mod ring { + use super::{Lease, LeaseBehavior, PermitState}; + + use std::{ + cmp, ops, slice, + sync::atomic::{AtomicU8, AtomicUsize, Ordering}, + }; + + ///``` + /// use zip::tokio::channels::*; + /// use std::sync::Arc; + /// + /// let ring = Arc::new(Ring::with_capacity(10)); + /// + /// assert!(matches![ring.request_read_lease_strong(1), Lease::NoSpace]); + /// { + /// let mut write_lease = ring.request_write_lease_strong(5).option().unwrap(); + /// write_lease.copy_from_slice(b"world"); + /// } + /// { + /// let read_lease = ring.request_read_lease_strong(5).option().unwrap(); + /// assert_eq!(std::str::from_utf8(&*read_lease).unwrap(), "world"); + /// } + /// { + /// let mut write_lease = ring.request_write_lease_strong(6).option().unwrap(); + /// assert_eq!(5, write_lease.len()); + /// write_lease.copy_from_slice(b"hello"); + /// } + /// { + /// let read_lease = 
ring.request_read_lease_strong(4).option().unwrap(); + /// assert_eq!(std::str::from_utf8(&*read_lease).unwrap(), "hell"); + /// } + /// { + /// let mut write_lease = ring.request_write_lease_strong(2).option().unwrap(); + /// write_lease.copy_from_slice(b"k!"); + /// } + /// let mut buf = Vec::new(); + /// { + /// let read_lease = ring.request_read_lease_strong(3).option().unwrap(); + /// assert_eq!(1, read_lease.len()); + /// buf.extend_from_slice(&read_lease); + /// } + /// { + /// let read_lease = ring.request_read_lease_strong(3).option().unwrap(); + /// assert_eq!(2, read_lease.len()); + /// buf.extend_from_slice(&read_lease); + /// } + /// assert_eq!(std::str::from_utf8(&buf).unwrap(), "ok!"); + /// assert!(matches![ring.request_read_lease_strong(1), Lease::NoSpace]); + ///``` + #[derive(Debug)] + pub struct Ring { + buf: Box<[u8]>, + write_head: AtomicUsize, + remaining_inline_write: AtomicUsize, + write_state: AtomicU8, + read_head: AtomicUsize, + remaining_inline_read: AtomicUsize, + read_state: AtomicU8, + } + + impl Ring { + pub fn clear(&mut self) { + *self.write_head.get_mut() = 0; + *self.remaining_inline_write.get_mut() = self.capacity(); + *self.write_state.get_mut() = PermitState::Unleashed.into(); + *self.read_head.get_mut() = 0; + *self.remaining_inline_read.get_mut() = 0; + *self.read_state.get_mut() = PermitState::Unleashed.into(); + } + + pub fn with_capacity(capacity: usize) -> Self { + assert!(capacity > 0); + Self { + buf: vec![0u8; capacity].into_boxed_slice(), + write_head: AtomicUsize::new(0), + remaining_inline_write: AtomicUsize::new(capacity), + write_state: AtomicU8::new(PermitState::Unleashed.into()), + read_head: AtomicUsize::new(0), + remaining_inline_read: AtomicUsize::new(0), + read_state: AtomicU8::new(PermitState::Unleashed.into()), + } + } + + #[inline] + pub fn capacity(&self) -> usize { + self.buf.len() + } + + pub(crate) fn return_write_lease(&self, permit: &WritePermit<'_>) { + debug_assert!( + self.write_state.load(Ordering::Relaxed) + == >::into(PermitState::TakenOut) + ); + + let truncated_length = permit.truncated_length(); + if truncated_length > 0 { + self.remaining_inline_write + .fetch_add(truncated_length, Ordering::Release); + } + + let len = permit.len(); + + if len > 0 { + let new_write_head = len + self.write_head.load(Ordering::Acquire); + let read_head = self.read_head.load(Ordering::Acquire); + + if new_write_head == self.capacity() { + debug_assert_eq!(0, self.remaining_inline_write.load(Ordering::Acquire)); + self.remaining_inline_write + .store(read_head, Ordering::Release); + self.write_head.store(0, Ordering::Release); + } else { + self.write_head.store(new_write_head, Ordering::Release); + } + + if new_write_head > read_head { + self.remaining_inline_read.fetch_add(len, Ordering::Release); + } + } + + self.write_state + .store(PermitState::Unleashed.into(), Ordering::Release); + } + + #[inline] + pub fn request_write_lease_weak(&self, requested_length: usize) -> Lease> { + self.request_write_lease(requested_length, LeaseBehavior::AllowSpuriousFailures) + } + + #[inline] + pub fn request_write_lease_strong( + &self, + requested_length: usize, + ) -> Lease> { + self.request_write_lease(requested_length, LeaseBehavior::NoSpuriousFailures) + } + + pub fn request_write_lease( + &self, + requested_length: usize, + behavior: LeaseBehavior, + ) -> Lease> { + assert!(requested_length > 0); + if self.remaining_inline_write.load(Ordering::Relaxed) == 0 { + return Lease::NoSpace; + } + + if 
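+        /* (compare_exchange_weak may fail spuriously, so a weak lease request
+         * can report PossiblyTaken even when the permit was actually free) */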
!behavior.try_acquire_lease(&self.write_state) { + return Lease::PossiblyTaken; + } + + let remaining_inline_write = self.remaining_inline_write.load(Ordering::Acquire); + if remaining_inline_write == 0 { + self.write_state + .store(PermitState::Unleashed.into(), Ordering::Release); + return Lease::NoSpace; + } + + let limited_length = cmp::min(remaining_inline_write, requested_length); + debug_assert!(limited_length > 0); + self.remaining_inline_write + .fetch_sub(limited_length, Ordering::Release); + + let prev_write_head = self.write_head.load(Ordering::Acquire); + + let buf: &mut [u8] = unsafe { + let buf: *const u8 = self.buf.as_ptr(); + let start = buf.add(prev_write_head) as *mut u8; + slice::from_raw_parts_mut(start, limited_length) + }; + Lease::Taken(WritePermit::view(buf, self)) + } + + pub(crate) fn return_read_lease(&self, permit: &ReadPermit<'_>) { + debug_assert!( + self.read_state.load(Ordering::Relaxed) + == >::into(PermitState::TakenOut) + ); + + let truncated_length = permit.truncated_length(); + if truncated_length > 0 { + self.remaining_inline_read + .fetch_add(truncated_length, Ordering::Release); + } + + let len = permit.len(); + + if len > 0 { + let new_read_head = len + self.read_head.load(Ordering::Acquire); + let write_head = self.write_head.load(Ordering::Acquire); + + if new_read_head == self.capacity() { + debug_assert_eq!(0, self.remaining_inline_read.load(Ordering::Acquire)); + self.remaining_inline_read + .store(write_head, Ordering::Release); + self.read_head.store(0, Ordering::Release); + } else { + self.read_head.store(new_read_head, Ordering::Release); + } + + if new_read_head > write_head { + self.remaining_inline_write + .fetch_add(len, Ordering::Release); + } + } + + self.read_state + .store(PermitState::Unleashed.into(), Ordering::Release); + } + + #[inline] + pub fn request_read_lease_weak(&self, requested_length: usize) -> Lease> { + self.request_read_lease(requested_length, LeaseBehavior::AllowSpuriousFailures) + } + + #[inline] + pub fn request_read_lease_strong(&self, requested_length: usize) -> Lease> { + self.request_read_lease(requested_length, LeaseBehavior::NoSpuriousFailures) + } + + pub fn request_read_lease( + &self, + requested_length: usize, + behavior: LeaseBehavior, + ) -> Lease> { + assert!(requested_length > 0); + if self.remaining_inline_read.load(Ordering::Relaxed) == 0 { + return Lease::NoSpace; + } + + if !behavior.try_acquire_lease(&self.read_state) { + return Lease::PossiblyTaken; + } + + let remaining_inline_read = self.remaining_inline_read.load(Ordering::Acquire); + if remaining_inline_read == 0 { + self.read_state + .store(PermitState::Unleashed.into(), Ordering::Release); + return Lease::NoSpace; + } + + let limited_length = cmp::min(remaining_inline_read, requested_length); + debug_assert!(limited_length > 0); + self.remaining_inline_read + .fetch_sub(limited_length, Ordering::Release); + + let prev_read_head = self.read_head.load(Ordering::Acquire); + + let buf: &[u8] = unsafe { + let buf: *const u8 = self.buf.as_ptr(); + let start = buf.add(prev_read_head); + slice::from_raw_parts(start, limited_length) + }; + Lease::Taken(ReadPermit::view(buf, self)) + } + } + + ///``` + /// use zip::tokio::channels::*; + /// use std::sync::Arc; + /// + /// let msg = "hello world"; + /// let ring = Arc::new(Ring::with_capacity(30)); + /// + /// let mut buf = Vec::new(); + /// { + /// let mut write_lease = ring.request_write_lease_strong(5).option().unwrap(); + /// write_lease.copy_from_slice(&msg.as_bytes()[..5]); + /// 
write_lease.truncate(4); + /// assert_eq!(4, write_lease.len()); + /// } + /// { + /// let mut read_lease = ring.request_read_lease_strong(5).option().unwrap(); + /// assert_eq!(4, read_lease.len()); + /// buf.extend_from_slice(read_lease.truncate(1)); + /// assert_eq!(1, buf.len()); + /// assert_eq!(1, read_lease.len()); + /// } + /// { + /// let mut write_lease = ring.request_write_lease_strong(msg.len() - 4).option().unwrap(); + /// write_lease.copy_from_slice(&msg.as_bytes()[4..]); + /// } + /// { + /// let read_lease = ring.request_read_lease_strong(msg.len() - 1).option().unwrap(); + /// assert_eq!(read_lease.len(), msg.len() - 1); + /// buf.extend_from_slice(&read_lease); + /// } + /// assert_eq!(msg, std::str::from_utf8(&buf).unwrap()); + ///``` + pub trait TruncateLength { + fn truncated_length(&self) -> usize; + fn truncate(&mut self, len: usize) -> &mut Self; + } + + #[derive(Debug)] + pub struct ReadPermit<'a> { + view: &'a [u8], + parent: &'a Ring, + original_length: usize, + } + + impl<'a> ReadPermit<'a> { + pub(crate) fn view(view: &'a [u8], parent: &'a Ring) -> Self { + let original_length = view.len(); + Self { + view, + parent, + original_length, + } + } + } + + impl<'a> TruncateLength for ReadPermit<'a> { + #[inline] + fn truncated_length(&self) -> usize { + self.original_length - self.len() + } + + #[inline] + fn truncate(&mut self, len: usize) -> &mut Self { + assert!(len <= self.len()); + self.view = &self.view[..len]; + self + } + } + + impl<'a> ops::Drop for ReadPermit<'a> { + fn drop(&mut self) { + self.parent.return_read_lease(self); + } + } + + impl<'a> AsRef<[u8]> for ReadPermit<'a> { + #[inline] + fn as_ref(&self) -> &[u8] { + &self.view + } + } + + impl<'a> ops::Deref for ReadPermit<'a> { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &[u8] { + &self.view + } + } + + #[derive(Debug)] + pub struct WritePermit<'a> { + view: &'a mut [u8], + parent: &'a Ring, + original_length: usize, + } + + impl<'a> WritePermit<'a> { + pub(crate) fn view(view: &'a mut [u8], parent: &'a Ring) -> Self { + let original_length = view.len(); + Self { + view, + parent, + original_length, + } + } + } + + impl<'a> TruncateLength for WritePermit<'a> { + #[inline] + fn truncated_length(&self) -> usize { + self.original_length - self.len() + } + + #[inline] + fn truncate(&mut self, len: usize) -> &mut Self { + assert!(len <= self.len()); + self.view = unsafe { slice::from_raw_parts_mut(self.view.as_ptr() as *mut u8, len) }; + self + } + } + + impl<'a> ops::Drop for WritePermit<'a> { + fn drop(&mut self) { + self.parent.return_write_lease(self); + } + } + + impl<'a> AsRef<[u8]> for WritePermit<'a> { + #[inline] + fn as_ref(&self) -> &[u8] { + &self.view + } + } + + impl<'a> AsMut<[u8]> for WritePermit<'a> { + #[inline] + fn as_mut(&mut self) -> &mut [u8] { + &mut self.view + } + } + + impl<'a> ops::Deref for WritePermit<'a> { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &[u8] { + &self.view + } + } + + impl<'a> ops::DerefMut for WritePermit<'a> { + #[inline] + fn deref_mut(&mut self) -> &mut [u8] { + &mut self.view + } + } +} +pub use ring::{ReadPermit, Ring, TruncateLength, WritePermit}; + +pub mod push { + use super::PermitState; + + use std::{ + cell, mem, + sync::atomic::{AtomicU8, Ordering}, + }; + + pub struct Pusher { + elements: cell::UnsafeCell>, + state: AtomicU8, + } + + impl Pusher { + pub fn new() -> Self { + Self { + elements: cell::UnsafeCell::new(Vec::new()), + state: AtomicU8::new(PermitState::Unleashed.into()), + } + } + + fn within_lock) -> O>(&self, 
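+        /* `state` acts as a spin lock: TakenOut is held for the duration of
+         * `f`, making the UnsafeCell access below exclusive. */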
f: F) -> O { + while let Err(_) = self.state.compare_exchange_weak( + PermitState::Unleashed.into(), + PermitState::TakenOut.into(), + Ordering::AcqRel, + Ordering::Relaxed, + ) {} + + let v: &mut Vec = unsafe { &mut *self.elements.get() }; + let ret = f(v); + + self.state + .store(PermitState::Unleashed.into(), Ordering::Release); + + ret + } + + pub fn push(&self, x: T) { + self.within_lock(|v| v.push(x)) + } + + pub fn extract(&self) -> Vec { + self.within_lock(|v| mem::take(v)) + } + + pub fn take_owned(&mut self) -> Vec { + mem::replace(&mut self.elements, cell::UnsafeCell::new(Vec::new())).into_inner() + } + } + + unsafe impl Sync for Pusher {} +} + +pub mod futurized { + use super::{ + push::Pusher, + ring::{ReadPermit, Ring, WritePermit}, + Lease, LeaseBehavior, + }; + + use once_cell::sync::Lazy; + use parking_lot::Mutex; + + use std::{ + collections::VecDeque, + mem, ops, + task::{Context, Poll, Waker}, + }; + + static RING_BUF_FREE_LIST: Lazy>> = + Lazy::new(|| Mutex::new(VecDeque::new())); + + fn get_or_create_ring Ring>(f: F) -> Ring { + RING_BUF_FREE_LIST.lock().pop_front().unwrap_or_else(f) + } + + fn return_ring(mut ring: Ring) { + ring.clear(); + RING_BUF_FREE_LIST.lock().push_back(ring); + } + + ///``` + /// # fn main() { tokio_test::block_on(async { + /// use zip::tokio::channels::{*, futurized::*}; + /// use futures_util::future::poll_fn; + /// use tokio::task; + /// use std::{cell::UnsafeCell, pin::Pin}; + /// + /// let ring = UnsafeCell::new(RingFuturized::new()); + /// let read_lease = poll_fn(|cx| unsafe { &mut *ring.get() }.poll_read(cx, 5)); + /// { + /// let mut write_lease = poll_fn(|cx| { + /// unsafe { &mut *ring.get() }.poll_write(cx, 20) + /// }).await; + /// write_lease.truncate(5).copy_from_slice(b"hello"); + /// } + /// { + /// let read_lease = read_lease.await; + /// assert_eq!("hello", std::str::from_utf8(&read_lease).unwrap()); + /// } + /// # })} + ///``` + pub struct RingFuturized { + buf: mem::ManuallyDrop, + read_wakers: Pusher, + write_wakers: Pusher, + } + + impl ops::Drop for RingFuturized { + fn drop(&mut self) { + let Self { + buf, + read_wakers, + write_wakers, + } = self; + for waker in read_wakers + .take_owned() + .into_iter() + .chain(write_wakers.take_owned().into_iter()) + { + waker.wake(); + } + return_ring(unsafe { mem::ManuallyDrop::take(buf) }); + } + } + + impl RingFuturized { + #[inline] + pub fn capacity(&self) -> usize { + self.buf.capacity() + } + + pub fn new() -> Self { + let ring = get_or_create_ring(|| Ring::with_capacity(8 * 1024)); + Self { + buf: mem::ManuallyDrop::new(ring), + read_wakers: Pusher::::new(), + write_wakers: Pusher::::new(), + } + } + + /* pub fn wrap_ring(buf: Ring) -> Self { */ + /* Self { */ + /* buf: Arc::new(buf), */ + /* read_wakers: Arc::new(Pusher::::new()), */ + /* write_wakers: Arc::new(Pusher::::new()), */ + /* } */ + /* } */ + + /* pub fn poll_read_until_no_space( */ + /* &mut self, */ + /* cx: &mut Context<'_>, */ + /* ) -> Poll> { */ + /* match self.buf.request_read_lease(self.capacity()) { */ + /* Lease::NoSpace => Poll::Ready(None), */ + /* Lease::PossiblyTaken => { */ + /* self.read_wakers.push(cx.waker().clone()); */ + /* Poll::Pending */ + /* } */ + /* Lease::Taken(permit) => Poll::Ready(Some(ReadPermitFuturized::for_buf( */ + /* permit, */ + /* self.read_wakers.clone(), */ + /* self.write_wakers.clone(), */ + /* ))), */ + /* } */ + /* } */ + + pub fn poll_read( + &mut self, + cx: &mut Context<'_>, + requested_length: usize, + ) -> Poll> { + match self + .buf + 
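+                // A weak lease can fail spuriously: both `NoSpace` (ring empty)
+                // and `PossiblyTaken` (another permit briefly held) park this
+                // task below; dropping any permit re-wakes the registered wakers.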
.request_read_lease(requested_length, LeaseBehavior::AllowSpuriousFailures) + { + Lease::NoSpace | Lease::PossiblyTaken => { + self.read_wakers.push(cx.waker().clone()); + Poll::Pending + } + Lease::Taken(permit) => Poll::Ready(ReadPermitFuturized::for_buf( + permit, + &self.read_wakers, + &self.write_wakers, + )), + } + } + + pub fn poll_write( + &mut self, + cx: &mut Context<'_>, + requested_length: usize, + ) -> Poll> { + match self + .buf + .request_write_lease(requested_length, LeaseBehavior::AllowSpuriousFailures) + { + Lease::NoSpace | Lease::PossiblyTaken => { + self.write_wakers.push(cx.waker().clone()); + Poll::Pending + } + Lease::Taken(permit) => Poll::Ready(WritePermitFuturized::for_buf( + permit, + &self.read_wakers, + &self.write_wakers, + )), + } + } + } + + pub struct ReadPermitFuturized<'a> { + buf: mem::ManuallyDrop>, + read_wakers: &'a Pusher, + write_wakers: &'a Pusher, + } + + impl<'a> ReadPermitFuturized<'a> { + pub(crate) fn for_buf( + buf: ReadPermit<'a>, + read_wakers: &'a Pusher, + write_wakers: &'a Pusher, + ) -> Self { + Self { + buf: mem::ManuallyDrop::new(buf), + read_wakers, + write_wakers, + } + } + } + + impl<'a> ops::Drop for ReadPermitFuturized<'a> { + fn drop(&mut self) { + let Self { + buf, + read_wakers, + write_wakers, + } = self; + let was_empty = buf.is_empty(); + /* Drop the ReadPermit first to close out the owned region in the parent Ring before + * waking up any tasks. */ + unsafe { + mem::ManuallyDrop::drop(buf); + } + /* Notify any blocked readers. */ + for waker in read_wakers.extract().into_iter() { + waker.wake(); + } + if !was_empty { + /* Notify any blocked writers. */ + for waker in write_wakers.extract().into_iter() { + waker.wake(); + } + } + } + } + + impl<'a> AsRef> for ReadPermitFuturized<'a> { + #[inline] + fn as_ref(&self) -> &ReadPermit<'a> { + &self.buf + } + } + + impl<'a> AsMut> for ReadPermitFuturized<'a> { + #[inline] + fn as_mut(&mut self) -> &mut ReadPermit<'a> { + &mut self.buf + } + } + + impl<'a> ops::Deref for ReadPermitFuturized<'a> { + type Target = ReadPermit<'a>; + + #[inline] + fn deref(&self) -> &ReadPermit<'a> { + &self.buf + } + } + + impl<'a> ops::DerefMut for ReadPermitFuturized<'a> { + #[inline] + fn deref_mut(&mut self) -> &mut ReadPermit<'a> { + &mut self.buf + } + } + + pub struct WritePermitFuturized<'a> { + buf: mem::ManuallyDrop>, + read_wakers: &'a Pusher, + write_wakers: &'a Pusher, + } + + impl<'a> WritePermitFuturized<'a> { + pub(crate) fn for_buf( + buf: WritePermit<'a>, + read_wakers: &'a Pusher, + write_wakers: &'a Pusher, + ) -> Self { + Self { + buf: mem::ManuallyDrop::new(buf), + read_wakers, + write_wakers, + } + } + } + + impl<'a> ops::Drop for WritePermitFuturized<'a> { + fn drop(&mut self) { + let Self { + buf, + read_wakers, + write_wakers, + } = self; + let was_empty = buf.is_empty(); + /* Drop the WritePermit first to close out the owned region in the parent Ring before + * waking up any tasks. */ + unsafe { + mem::ManuallyDrop::drop(buf); + } + /* Notify any blocked writers. */ + for waker in write_wakers.extract().into_iter() { + waker.wake(); + } + if !was_empty { + /* Notify any blocked readers. 
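+                 * A nonempty write permit has just published bytes into the
+                 * ring, so parked readers can now make progress as well.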
*/ + for waker in read_wakers.extract().into_iter() { + waker.wake(); + } + } + } + } + + impl<'a> AsRef> for WritePermitFuturized<'a> { + #[inline] + fn as_ref(&self) -> &WritePermit<'a> { + &self.buf + } + } + + impl<'a> AsMut> for WritePermitFuturized<'a> { + #[inline] + fn as_mut(&mut self) -> &mut WritePermit<'a> { + &mut self.buf + } + } + + impl<'a> ops::Deref for WritePermitFuturized<'a> { + type Target = WritePermit<'a>; + + #[inline] + fn deref(&self) -> &WritePermit<'a> { + &self.buf + } + } + + impl<'a> ops::DerefMut for WritePermitFuturized<'a> { + #[inline] + fn deref_mut(&mut self) -> &mut WritePermit<'a> { + &mut self.buf + } + } +} + +/* impl std::io::Read for RingBuffer { */ +/* fn read(&mut self, buf: &mut [u8]) -> io::Result { */ +/* debug_assert!(!buf.is_empty()); */ +/* /\* TODO: is this sufficient to make underflow unambiguous? *\/ */ +/* static_assertions::const_assert!(N < (usize::MAX >> 1)); */ + +/* let requested_length: usize = cmp::min(N, buf.len()); */ +/* self.remaining */ +/* } */ +/* } */ diff --git a/src/tokio/combinators.rs b/src/tokio/combinators.rs new file mode 100755 index 000000000..c3a737263 --- /dev/null +++ b/src/tokio/combinators.rs @@ -0,0 +1,467 @@ +use crate::tokio::WrappedPin; + +use tokio::io; + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +pub mod stream_adaptors { + use super::*; + use crate::tokio::channels::{futurized::*, *}; + + use std::{cmp, task::ready}; + + pub trait KnownExpanse { + /* TODO: make this have a parameterized Self::Index type, used e.g. with RangeInclusive or + * something. */ + fn full_length(&self) -> usize; + } + + ///``` + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use std::{io::{SeekFrom, Cursor}, pin::Pin}; + /// use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + /// use zip::tokio::combinators::Limiter; + /// + /// let mut buf = Cursor::new(Vec::new()); + /// buf.write_all(b"hello\n").await?; + /// buf.seek(SeekFrom::Start(1)).await?; + /// + /// let mut limited = Limiter::take(1, Box::pin(buf), 3); + /// let mut s = String::new(); + /// limited.read_to_string(&mut s).await?; + /// assert_eq!(s, "ell"); + /// + /// limited.seek(SeekFrom::End(-1)).await?; + /// s.clear(); + /// limited.read_to_string(&mut s).await?; + /// assert_eq!(s, "l"); + /// # Ok(()) + /// # })} + ///``` + #[derive(Debug)] + pub struct Limiter { + pub max_len: usize, + pub internal_pos: usize, + pub start_pos: u64, + pub source_stream: Pin>, + } + + impl Limiter { + pub fn take(start_pos: u64, source_stream: Pin>, limit: usize) -> Self { + Self { + max_len: limit, + internal_pos: 0, + start_pos, + source_stream, + } + } + + #[inline] + fn pin_stream(self: Pin<&mut Self>) -> Pin<&mut S> { + self.get_mut().source_stream.as_mut() + } + + #[inline] + fn remaining_len(&self) -> usize { + debug_assert!(self.internal_pos <= self.max_len); + self.max_len - self.internal_pos + } + + #[inline] + fn limit_length(&self, requested_length: usize) -> usize { + cmp::min(self.remaining_len(), requested_length) + } + + #[inline] + fn push_cursor(&mut self, len: usize) { + debug_assert!(len <= self.remaining_len()); + self.internal_pos += len; + } + + #[inline] + fn convert_seek_request_to_relative(&self, op: io::SeekFrom) -> i64 { + let cur = self.internal_pos as u64; + let new_point = cmp::min( + self.max_len as u64, + match op { + io::SeekFrom::Start(new_point) => new_point, + io::SeekFrom::End(from_end) => { + cmp::max(0, self.max_len as i64 + from_end) as u64 + } + 
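+                    // e.g. with max_len = 3 and internal_pos = 3, SeekFrom::End(-1)
+                    // clamps new_point to 2, i.e. a relative seek of -1 (cf. the
+                    // `Limiter` doc test above).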
io::SeekFrom::Current(from_cur) => cmp::max(0, cur as i64 + from_cur) as u64, + }, + ); + let diff = new_point as i64 - cur as i64; + diff + } + + #[inline] + fn interpret_new_pos(&mut self, new_pos: u64) { + assert!(new_pos >= self.start_pos); + assert!(new_pos <= self.start_pos + self.max_len as u64); + self.internal_pos = (new_pos - self.start_pos) as usize; + } + } + + impl WrappedPin for Limiter { + fn unwrap_inner_pin(self) -> Pin> { + self.source_stream + } + } + + impl KnownExpanse for Limiter { + #[inline] + fn full_length(&self) -> usize { + self.max_len + } + } + + impl io::AsyncRead for Limiter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + debug_assert!(buf.remaining() > 0); + + let num_bytes_to_read: usize = self.as_mut().limit_length(buf.remaining()); + if num_bytes_to_read == 0 { + return Poll::Ready(Ok(())); + } + + buf.initialize_unfilled_to(num_bytes_to_read); + let mut unfilled_buf = buf.take(num_bytes_to_read); + match self.as_mut().pin_stream().poll_read(cx, &mut unfilled_buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => { + let bytes_read = unfilled_buf.filled().len(); + Poll::Ready(x.map(|()| { + assert!(bytes_read <= num_bytes_to_read); + if bytes_read > 0 { + buf.advance(bytes_read); + self.push_cursor(bytes_read); + } + })) + } + } + } + } + + impl io::AsyncSeek for Limiter { + fn start_seek(self: Pin<&mut Self>, op: io::SeekFrom) -> io::Result<()> { + let diff = self.convert_seek_request_to_relative(op); + let s = self.get_mut(); + Pin::new(&mut s.source_stream).start_seek(io::SeekFrom::Current(diff)) + } + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let s = self.get_mut(); + let result = ready!(Pin::new(&mut s.source_stream).poll_complete(cx)); + if let Ok(ref cur_pos) = result { + s.interpret_new_pos(*cur_pos); + } + Poll::Ready(result) + } + } + + impl io::AsyncWrite for Limiter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + debug_assert!(!buf.is_empty()); + + let num_bytes_to_write: usize = self.limit_length(buf.len()); + if num_bytes_to_write == 0 { + return Poll::Ready(Ok(0)); + } + + let s = self.get_mut(); + match Pin::new(&mut s.source_stream).poll_write(cx, &buf[..num_bytes_to_write]) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => Poll::Ready(x.map(|bytes_written| { + assert!(bytes_written <= num_bytes_to_write); + if bytes_written > 0 { + s.push_cursor(bytes_written); + } + bytes_written + })), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let s = self.get_mut(); + Pin::new(&mut s.source_stream).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let s = self.get_mut(); + Pin::new(&mut s.source_stream).poll_shutdown(cx) + } + } + + /* ///``` */ + /* /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { */ + /* /// use std::{io::{Cursor, prelude::*}, pin::Pin, sync::Arc}; */ + /* /// use tokio::{io::{self, AsyncReadExt}, fs}; */ + /* /// */ + /* /// let mut buf = Cursor::new(Vec::new()); */ + /* /// buf.write_all(b"hello\n")?; */ + /* /// buf.rewind()?; */ + /* /// let mut f = zip::combinators::AsyncIoAdapter::new(buf); */ + /* /// let mut buf: Vec = Vec::new(); */ + /* /// f.read_to_end(&mut buf).await?; */ + /* /// assert_eq!(&buf, b"hello\n"); */ + /* /// # Ok(()) */ + /* /// # })} */ + /* ///``` */ + /* pub struct AsyncIoAdapter { */ + /* inner: Arc>>, */ + /* tx: Arc>>>>, */ + 
/* rx: oneshot::Receiver>, */ + /* ring: RingFuturized, */ + /* } */ + /* impl AsyncIoAdapter { */ + /* pub fn new(inner: S) -> Self { */ + /* let (tx, rx) = oneshot::channel::>(); */ + /* Self { */ + /* inner: Arc::new(sync::Mutex::new(Some(inner))), */ + /* tx: Arc::new(parking_lot::Mutex::new(Some(tx))), */ + /* rx, */ + /* ring: RingFuturized::new(), */ + /* } */ + /* } */ + + /* pub fn into_inner(self) -> S { */ + /* self.inner */ + /* .try_lock_owned() */ + /* .expect("there should be no further handles to this mutex") */ + /* .take() */ + /* .unwrap() */ + /* } */ + /* } */ + + /* impl AsyncIoAdapter { */ + /* fn do_write( */ + /* mut write_lease: WritePermitFuturized, */ + /* mut inner: sync::OwnedMappedMutexGuard, S>, */ + /* tx: Arc>>>>, */ + /* ) { */ + /* assert!(!write_lease.is_empty()); */ + /* match inner.read(&mut write_lease) { */ + /* Err(e) => { */ + /* /\* If any error occurs, assume no bytes were written, as per std::io::Read */ + /* * docs. *\/ */ + /* write_lease.truncate(0); */ + /* if e.kind() != io::ErrorKind::Interrupted { */ + /* /\* Interrupted is the only retryable error according to the docs. *\/ */ + /* if let Some(tx) = tx.lock().take() { */ + /* tx.send(Err(e)) */ + /* .expect("receiver should not have been dropped yet!"); */ + /* } */ + /* } */ + /* } */ + /* Ok(n) => { */ + /* write_lease.truncate(n); */ + /* /\* If we received 0 after providing a non-empty output buf, assume we are */ + /* * OVER! *\/ */ + /* if n == 0 { */ + /* if let Some(tx) = tx.lock().take() { */ + /* tx.send(Ok(())) */ + /* .expect("receiver should not have been dropped yet!"); */ + /* } */ + /* } */ + /* } */ + /* } */ + /* } */ + /* } */ + + /* impl io::AsyncRead for AsyncIoAdapter { */ + /* fn poll_read( */ + /* self: Pin<&mut Self>, */ + /* cx: &mut Context<'_>, */ + /* buf: &mut io::ReadBuf<'_>, */ + /* ) -> Poll> { */ + /* debug_assert!(buf.remaining() > 0); */ + + /* let s = self.get_mut(); */ + + /* if let Poll::Ready(read_data) = s.ring.poll_read(cx, buf.remaining()) { */ + /* debug_assert!(!read_data.is_empty()); */ + /* buf.put_slice(&**read_data); */ + /* return Poll::Ready(Ok(())); */ + /* } */ + + /* if let Poll::Ready(result) = Pin::new(&mut s.rx).poll(cx) { */ + /* return Poll::Ready( */ + /* result.expect("sender should not have been dropped without sending!"), */ + /* ); */ + /* } */ + + /* let write_data = ready!(s.ring.poll_write(cx, s.ring.capacity())); */ + /* /\* FIXME: i'm pretty sure this can cause a segfault with the current */ + /* * mem::ManuallyDrop::take() strategy and global mutex for Ring instances. 
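+     * (the permit is transmuted to a 'static lifetime below, so if the adapter
+     * is dropped first, the permit's Drop can touch a Ring that has already been
+     * returned to the global free list and handed to another adapter)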
*\/ */ + /* let write_data: WritePermitFuturized<'static> = unsafe { mem::transmute(write_data) }; */ + /* let tx = s.tx.clone(); */ + /* if let Ok(inner) = s.inner.clone().try_lock_owned() { */ + /* let inner = sync::OwnedMutexGuard::map(inner, |inner| inner.as_mut().unwrap()); */ + /* task::spawn_blocking(move || { */ + /* Self::do_write(write_data, inner, tx); */ + /* }); */ + /* Poll::Pending */ + /* } else { */ + /* let inner = s.inner.clone(); */ + /* task::spawn(async move { */ + /* let inner = sync::OwnedMutexGuard::map(inner.lock_owned().await, |inner| { */ + /* inner.as_mut().unwrap() */ + /* }); */ + /* task::spawn_blocking(move || { */ + /* Self::do_write(write_data, inner, tx); */ + /* }); */ + /* }); */ + /* Poll::Pending */ + /* } */ + /* } */ + /* } */ + + /* impl AsyncIoAdapter { */ + /* fn do_read( */ + /* mut read_lease: ReadPermitFuturized, */ + /* mut inner: sync::OwnedMappedMutexGuard, S>, */ + /* tx: Arc>>>>, */ + /* ) { */ + /* match inner.write(&read_lease) { */ + /* Err(e) => { */ + /* if e.kind() == io::ErrorKind::Interrupted { */ + /* read_lease.truncate(0); */ + /* } else { */ + /* if let Some(tx) = tx.lock().take() { */ + /* tx.send(Err(e)) */ + /* .expect("receiver should not have been dropped yet!"); */ + /* } */ + /* } */ + /* } */ + /* Ok(n) => { */ + /* read_lease.truncate(n); */ + /* if n == 0 { */ + /* if let Some(tx) = tx.lock().take() { */ + /* tx.send(Ok(())) */ + /* .expect("receiver should not have been dropped yet!"); */ + /* } */ + /* } */ + /* } */ + /* } */ + /* } */ + /* } */ + + /* impl io::AsyncWrite for AsyncIoAdapter { */ + /* fn poll_write( */ + /* self: Pin<&mut Self>, */ + /* cx: &mut Context<'_>, */ + /* buf: &[u8], */ + /* ) -> Poll> { */ + /* debug_assert!(!buf.is_empty()); */ + + /* let s = self.get_mut(); */ + + /* if let Poll::Ready(mut write_data) = s.ring.poll_write(cx, buf.len()) { */ + /* debug_assert!(!write_data.is_empty()); */ + /* let len = write_data.len(); */ + /* write_data.copy_from_slice(&buf[..len]); */ + /* return Poll::Ready(Ok(len)); */ + /* } */ + + /* if let Poll::Ready(result) = Pin::new(&mut s.rx).poll(cx) { */ + /* return Poll::Ready( */ + /* result */ + /* .expect("sender should not have been dropped without sending!") */ + /* .map(|()| 0), */ + /* ); */ + /* } */ + + /* let tx = s.tx.clone(); */ + /* let read_data = ready!(s.ring.poll_read(cx, buf.len())); */ + /* if let Ok(inner) = s.inner.clone().try_lock_owned() { */ + /* let inner = sync::OwnedMutexGuard::map(inner, |inner| inner.as_mut().unwrap()); */ + /* task::spawn_blocking(move || { */ + /* Self::do_read(read_data, inner, tx); */ + /* }); */ + /* Poll::Pending */ + /* } else { */ + /* let inner = s.inner.clone(); */ + /* task::spawn(async move { */ + /* let inner = sync::OwnedMutexGuard::map(inner.lock_owned().await, |inner| { */ + /* inner.as_mut().unwrap() */ + /* }); */ + /* task::spawn_blocking(move || { */ + /* Self::do_read(read_data, inner, tx); */ + /* }); */ + /* }); */ + /* Poll::Pending */ + /* } */ + /* } */ + + /* fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { */ + /* let s = self.get_mut(); */ + + /* let tx = s.tx.clone(); */ + /* match ready!(s.ring.poll_read_until_no_space(cx)) { */ + /* Some(read_data) => { */ + /* if let Ok(inner) = s.inner.clone().try_lock_owned() { */ + /* let inner = */ + /* sync::OwnedMutexGuard::map(inner, |inner| inner.as_mut().unwrap()); */ + /* task::spawn_blocking(move || { */ + /* Self::do_read(read_data, inner, tx); */ + /* }); */ + /* Poll::Pending */ + /* } else { */ + /* let 
inner = s.inner.clone(); */ + /* task::spawn(async move { */ + /* let inner = */ + /* sync::OwnedMutexGuard::map(inner.lock_owned().await, |inner| { */ + /* inner.as_mut().unwrap() */ + /* }); */ + /* task::spawn_blocking(move || { */ + /* Self::do_read(read_data, inner, tx); */ + /* }); */ + /* }); */ + /* Poll::Pending */ + /* } */ + /* } */ + /* None => { */ + /* if let Ok(inner) = s.inner.clone().try_lock_owned() { */ + /* let mut inner = */ + /* sync::OwnedMutexGuard::map(inner, |inner| inner.as_mut().unwrap()); */ + /* task::spawn_blocking(move || { */ + /* match inner.flush() { */ + /* Ok(()) */ + /* } */ + /* Self::do_read(read_data, inner, tx); */ + /* }); */ + /* Poll::Pending */ + /* } else { */ + /* let inner = s.inner.clone(); */ + /* task::spawn(async move { */ + /* let inner = */ + /* sync::OwnedMutexGuard::map(inner.lock_owned().await, |inner| { */ + /* inner.as_mut().unwrap() */ + /* }); */ + /* task::spawn_blocking(move || { */ + /* Self::do_read(read_data, inner, tx); */ + /* }); */ + /* }); */ + /* Poll::Pending */ + /* } */ + /* } */ + /* } */ + /* } */ + + /* fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { */ + /* self.poll_flush(cx) */ + /* } */ + /* } */ +} +pub use stream_adaptors::{/* AsyncIoAdapter, */ KnownExpanse, Limiter}; diff --git a/src/tokio/crc32.rs b/src/tokio/crc32.rs new file mode 100644 index 000000000..32632a32f --- /dev/null +++ b/src/tokio/crc32.rs @@ -0,0 +1,151 @@ +//! Helper module to compute a CRC32 checksum + +use crate::tokio::WrappedPin; + +use crc32fast::Hasher; +use tokio::io; + +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +/// Reader that validates the CRC32 when it reaches the EOF. +pub struct Crc32Reader { + inner: Pin>, + hasher: Hasher, + check: u32, + /// Signals if `inner` stores aes encrypted data. + /// AE-2 encrypted data doesn't use crc and sets the value to 0. + ae2_encrypted: bool, +} + +struct Crc32Proj<'a, R> { + pub inner: Pin<&'a mut R>, + pub hasher: &'a mut Hasher, + pub check: &'a mut u32, + pub ae2_encrypted: &'a mut bool, +} + +impl Crc32Reader { + /// Get a new Crc32Reader which checks the inner reader against checksum. + /// The check is disabled if `ae2_encrypted == true`. 
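+    ///
+    /// A minimal usage sketch, mirroring the unit tests below (`0x9be3e0a3` is
+    /// the CRC32 of `b"1234"`); the checksum is only validated once the reader
+    /// reaches EOF:
+    ///
+    ///```
+    /// # fn main() { tokio_test::block_on(async {
+    /// use tokio::io::AsyncReadExt;
+    ///
+    /// let data: &[u8] = b"1234";
+    /// let mut reader = Crc32Reader::new(Box::pin(data), 0x9be3e0a3, false);
+    /// let mut out = Vec::new();
+    /// reader.read_to_end(&mut out).await.unwrap();
+    /// assert_eq!(&out, b"1234");
+    /// # })}
+    ///```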
+ pub(crate) fn new(inner: Pin>, checksum: u32, ae2_encrypted: bool) -> Self { + Crc32Reader { + inner, + hasher: Hasher::new(), + check: checksum, + ae2_encrypted, + } + } + + #[inline] + fn project(self: Pin<&mut Self>) -> Crc32Proj<'_, R> { + unsafe { + let Self { + inner, + hasher, + check, + ae2_encrypted, + } = self.get_unchecked_mut(); + Crc32Proj { + inner: Pin::new_unchecked(inner.as_mut().get_unchecked_mut()), + hasher, + check, + ae2_encrypted, + } + } + } +} + +impl WrappedPin for Crc32Reader { + fn unwrap_inner_pin(self) -> Pin> { + self.inner + } +} + +impl io::AsyncRead for Crc32Reader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + let start = buf.filled().len(); + + let me = self.project(); + + if let Err(e) = ready!(me.inner.poll_read(cx, buf)) { + return Poll::Ready(Err(e)); + } + + let written: usize = buf.filled().len() - start; + if written == 0 { + return Poll::Ready( + if !*me.ae2_encrypted && (*me.check != me.hasher.clone().finalize()) { + Err(io::Error::new(io::ErrorKind::Other, "Invalid checksum")) + } else { + Ok(()) + }, + ); + } + + me.hasher.update(&buf.filled()[start..]); + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::result::ZipResult; + + use tokio::io::AsyncReadExt; + + #[tokio::test] + async fn test_empty_reader() -> ZipResult<()> { + let data: &[u8] = b""; + let mut buf = [0; 1]; + + let mut reader = Crc32Reader::new(Box::pin(data.clone()), 0, false); + assert_eq!(reader.read(&mut buf).await?, 0); + + let mut reader = Crc32Reader::new(Box::pin(data), 1, false); + assert!(reader + .read(&mut buf) + .await + .unwrap_err() + .to_string() + .contains("Invalid checksum")); + Ok(()) + } + + #[tokio::test] + async fn test_byte_by_byte() -> ZipResult<()> { + let data: &[u8] = b"1234"; + let mut buf = [0; 1]; + + let mut reader = Crc32Reader::new(Box::pin(data), 0x9be3e0a3, false); + assert_eq!(reader.read(&mut buf).await?, 1); + assert_eq!(reader.read(&mut buf).await?, 1); + assert_eq!(reader.read(&mut buf).await?, 1); + assert_eq!(reader.read(&mut buf).await?, 1); + assert_eq!(reader.read(&mut buf).await?, 0); + // Can keep reading 0 bytes after the end + assert_eq!(reader.read(&mut buf).await?, 0); + + Ok(()) + } + + #[tokio::test] + async fn test_zero_read() -> ZipResult<()> { + let data: &[u8] = b"1234"; + let mut buf = [0; 5]; + + let mut reader = Crc32Reader::new(Box::pin(data), 0x9be3e0a3, false); + assert_eq!(reader.read(&mut buf[..0]).await?, 0); + assert_eq!(reader.read(&mut buf).await?, 4); + + Ok(()) + } +} diff --git a/src/tokio/extraction.rs b/src/tokio/extraction.rs new file mode 100755 index 000000000..fca4c0bfb --- /dev/null +++ b/src/tokio/extraction.rs @@ -0,0 +1,80 @@ +#![allow(missing_docs)] + +use indexmap::IndexSet; + +use std::{os::unix::ffi::OsStrExt, path::Path, str}; + +#[derive(Debug, Clone)] +pub struct CompletedPaths<'a> { + seen: IndexSet<&'a Path>, +} + +impl<'a> CompletedPaths<'a> { + pub fn new() -> Self { + Self { + seen: IndexSet::new(), + } + } + + #[inline] + pub fn contains(&self, path: &'a Path) -> bool { + self.seen.contains(Self::normalize_trailing_slashes(path)) + } + + #[inline] + pub fn is_dir(path: &'a Path) -> bool { + Self::path_str(path).ends_with('/') + } + + #[inline] + pub(crate) fn path_str(path: &'a Path) -> &'a str { + debug_assert!(path.to_str().is_some()); + unsafe { str::from_utf8_unchecked(path.as_os_str().as_bytes()) } + } + + #[inline] + 
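+    /// Strip trailing `/`s so that e.g. `Path::new("a/b/")` and
+    /// `Path::new("a/b")` compare equal when recording which directories have
+    /// already been created.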
pub fn normalize_trailing_slashes(path: &'a Path) -> &'a Path { + Path::new(Self::path_str(path).trim_end_matches('/')) + } + + pub fn containing_dirs(path: &'a Path) -> impl Iterator { + let is_dir = Self::is_dir(path.as_ref()); + path.ancestors() + .inspect(|p| { + if p == &Path::new("/") { + unreachable!("did not expect absolute paths") + } + }) + .filter_map(move |p| { + if &p == &path { + if is_dir { + Some(p) + } else { + None + } + } else if p == Path::new("") { + None + } else { + Some(p) + } + }) + .map(Self::normalize_trailing_slashes) + } + + pub fn new_containing_dirs_needed(&self, path: &'a Path) -> Vec<&'a Path> { + let mut ret: Vec<_> = Self::containing_dirs(path) + /* Assuming we are given ancestors in order from child to parent. */ + .take_while(|p| !self.contains(p)) + .collect(); + /* Get dirs in order from parent to child. */ + ret.reverse(); + ret + } + + pub fn confirm_dir(&mut self, dir: &'a Path) { + let dir = Self::normalize_trailing_slashes(dir); + if !self.seen.contains(dir) { + self.seen.insert(dir); + } + } +} diff --git a/src/tokio/mod.rs b/src/tokio/mod.rs new file mode 100755 index 000000000..92db06fbb --- /dev/null +++ b/src/tokio/mod.rs @@ -0,0 +1,78 @@ +#![allow(missing_docs)] + +pub mod buf_reader; +pub mod buf_writer; +pub mod channels; +pub mod combinators; +pub(crate) mod crc32; +pub(crate) mod extraction; +pub mod os; +pub mod read; +pub mod stream_impls; +pub mod write; + +use std::pin::Pin; + +pub trait WrappedPin { + fn unwrap_inner_pin(self) -> Pin>; +} + +pub(crate) mod utils { + use std::{ + mem::{self, ManuallyDrop, MaybeUninit}, + ptr, + }; + + pub(crate) fn map_take_manual_drop U>( + slot: &mut ManuallyDrop, + f: F, + ) -> ManuallyDrop { + let taken = unsafe { ManuallyDrop::take(slot) }; + ManuallyDrop::new(f(taken)) + } + + pub(crate) fn map_swap_uninit T>(slot: &mut T, f: F) { + let mut other = MaybeUninit::::uninit(); + /* `other` IS UNINIT!!!!! */ + unsafe { + ptr::swap(slot, other.as_mut_ptr()); + } + /* `slot` IS NOW UNINIT!!!! */ + let mut wrapped = f(unsafe { + /* `other` will be dropped at the end of f(). */ + other.assume_init() + }); + /* `wrapped` has a valid value, returned by f(). */ + unsafe { + ptr::swap(slot, &mut wrapped); + } + /* `wrapped` IS NOW UNINIT!!!! 
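+         * Running its destructor would drop uninitialized memory, and the real
+         * value produced by f() already lives in `slot`, so leak the shell: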
*/ + mem::forget(wrapped); + } + + #[cfg(test)] + mod test { + use super::*; + + #[test] + fn test_map_swap_uninit_primitive() { + let mut x = 5; + map_swap_uninit(&mut x, |x| x + 3); + assert_eq!(x, 8); + } + + #[test] + fn test_map_swap_uninit_heap_alloc() { + let mut x: Vec = (0..12).map(|x| x * 3).collect(); + map_swap_uninit(&mut x, |mut x| { + x.push(192); + x.insert(0, 123); + x + }); + assert_eq!( + x, + vec![123, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 192] + ); + } + } +} diff --git a/src/tokio/os/copy_file_range.rs b/src/tokio/os/copy_file_range.rs new file mode 100644 index 000000000..a9f898d8d --- /dev/null +++ b/src/tokio/os/copy_file_range.rs @@ -0,0 +1,986 @@ +use super::{SyscallAvailability, INVALID_FD}; +use crate::{cvt, try_libc}; + +use cfg_if::cfg_if; +use displaydoc::Display; +use libc; +use once_cell::sync::Lazy; +use tokio::{io, task}; +use tokio_pipe::{PipeRead, PipeWrite}; + +use std::{ + future::Future, + io::{IoSlice, IoSliceMut}, + mem::{self, MaybeUninit}, + os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}, + pin::Pin, + ptr, + task::{ready, Context, Poll, Waker}, +}; + +fn invalid_copy_file_range() -> io::Error { + let ret = unsafe { + libc::copy_file_range( + INVALID_FD, + ptr::null_mut(), + INVALID_FD, + ptr::null_mut(), + 1, + 0, + ) + }; + assert_eq!(-1, ret); + io::Error::last_os_error() +} + +pub static HAS_COPY_FILE_RANGE: Lazy = Lazy::new(|| { + cfg_if! { + if #[cfg(target_os = "linux")] { + match invalid_copy_file_range().raw_os_error().unwrap() { + libc::EBADF => SyscallAvailability::Available, + errno => SyscallAvailability::FailedProbe(io::Error::from_raw_os_error(errno)), + } + } else { + SyscallAvailability::NotOnThisPlatform + } + } +}); + +pub struct RawArgs<'a> { + fd: libc::c_int, + off: Option<&'a mut libc::off64_t>, +} + +pub trait CopyFileRangeHandle { + fn role(&self) -> Role; + fn as_args(self: Pin<&mut Self>) -> RawArgs<'_>; +} + +pub struct MutateInnerOffset { + role: Role, + owned_fd: OwnedFd, + offset: u64, +} + +impl MutateInnerOffset { + pub async fn new(f: impl IntoRawFd, role: Role) -> io::Result { + let raw_fd = validate_raw_fd(f.into_raw_fd(), role)?; + let offset: libc::off64_t = + task::spawn_blocking(move || unsafe { cvt!(libc::lseek(raw_fd, 0, libc::SEEK_CUR)) }) + .await + .unwrap()?; + let owned_fd = unsafe { OwnedFd::from_raw_fd(raw_fd) }; + Ok(Self { + role, + owned_fd, + offset: offset as u64, + }) + } + + pub fn into_owned(self) -> OwnedFd { + self.owned_fd + } +} + +impl IntoRawFd for MutateInnerOffset { + fn into_raw_fd(self) -> RawFd { + self.into_owned().into_raw_fd() + } +} + +impl From for std::fs::File { + fn from(x: MutateInnerOffset) -> Self { + x.into_owned().into() + } +} + +impl CopyFileRangeHandle for MutateInnerOffset { + fn role(&self) -> Role { + self.role + } + fn as_args(self: Pin<&mut Self>) -> RawArgs<'_> { + RawArgs { + fd: self.owned_fd.as_raw_fd(), + off: None, + } + } +} + +impl std::io::Read for MutateInnerOffset { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let fd = self.owned_fd.as_raw_fd(); + /* FIXME: make this truly async instead of sync, perf permitting! */ + let num_read = + unsafe { cvt!(libc::read(fd, mem::transmute(buf.as_mut_ptr()), buf.len())) }?; + assert!(num_read >= 0); + Ok(num_read as usize) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let fd = self.owned_fd.as_raw_fd(); + /* FIXME: make this truly async instead of sync, perf permitting! 
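+          * (readv(2) runs synchronously on the calling thread here, just like
+          * the plain read() above)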
*/ + let num_read = unsafe { + cvt!(libc::readv( + fd, + mem::transmute(bufs.as_ptr()), + bufs.len() as libc::c_int + )) + }?; + assert!(num_read >= 0); + Ok(num_read as usize) + } + + #[inline] + fn is_read_vectored(&self) -> bool { + true + } +} + +impl std::io::Seek for MutateInnerOffset { + fn seek(&mut self, arg: io::SeekFrom) -> io::Result { + let (offset, whence): (libc::off64_t, libc::c_int) = match arg { + io::SeekFrom::Start(pos) => (pos as libc::off64_t, libc::SEEK_SET), + io::SeekFrom::Current(diff) => (diff, libc::SEEK_CUR), + io::SeekFrom::End(diff) => (diff, libc::SEEK_END), + }; + let fd = self.owned_fd.as_raw_fd(); + let new_offset = cvt!(unsafe { libc::lseek(fd, offset, whence) })?; + self.offset = new_offset as u64; + Ok(self.offset) + } +} + +impl std::io::Write for MutateInnerOffset { + fn write(&mut self, buf: &[u8]) -> io::Result { + let fd = self.owned_fd.as_raw_fd(); + let num_written = + cvt!(unsafe { libc::write(fd, mem::transmute(buf.as_ptr()), buf.len()) })?; + assert!(num_written > 0); + Ok(num_written as usize) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + let fd = self.owned_fd.as_raw_fd(); + let num_written = cvt!(unsafe { + libc::writev(fd, mem::transmute(bufs.as_ptr()), bufs.len() as libc::c_int) + })?; + assert!(num_written > 0); + Ok(num_written as usize) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + true + } + + fn flush(&mut self) -> io::Result<()> { + let _ = cvt!(unsafe { libc::fdatasync(self.owned_fd.as_raw_fd()) })?; + Ok(()) + } +} + +impl io::AsyncRead for MutateInnerOffset { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + let fd = self.owned_fd.as_raw_fd(); + /* FIXME: make this truly async instead of sync, perf permitting! */ + let num_read = unsafe { + cvt!(libc::read( + fd, + mem::transmute(buf.initialize_unfilled().as_mut_ptr()), + buf.remaining(), + )) + }?; + assert!(num_read >= 0); + buf.set_filled(buf.filled().len() + num_read as usize); + Poll::Ready(Ok(())) + } +} + +impl io::AsyncSeek for MutateInnerOffset { + fn start_seek(mut self: Pin<&mut Self>, arg: io::SeekFrom) -> io::Result<()> { + let (offset, whence): (libc::off64_t, libc::c_int) = match arg { + io::SeekFrom::Start(pos) => (pos as libc::off64_t, libc::SEEK_SET), + io::SeekFrom::Current(diff) => (diff, libc::SEEK_CUR), + io::SeekFrom::End(diff) => (diff, libc::SEEK_END), + }; + let fd = self.owned_fd.as_raw_fd(); + /* FIXME: make this async/pollable! */ + let new_offset = cvt!(unsafe { libc::lseek(fd, offset, whence) })?; + self.offset = new_offset as u64; + Ok(()) + } + + fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(self.offset)) + } +} + +impl io::AsyncWrite for MutateInnerOffset { + #[inline] + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let fd = self.owned_fd.as_raw_fd(); + let num_written = cvt!(unsafe { + /* FIXME: make this async/pollable! */ + libc::writev(fd, mem::transmute(bufs.as_ptr()), bufs.len() as libc::c_int) + })?; + assert!(num_written > 0); + Poll::Ready(Ok(num_written as usize)) + } + + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let fd = self.owned_fd.as_raw_fd(); + /* FIXME: make this async/pollable! 
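+          * (the write(2) below runs inline, so this poll_write never actually
+          * returns Pending)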
*/ + let num_written = + cvt!(unsafe { libc::write(fd, mem::transmute(buf.as_ptr()), buf.len()) })?; + assert!(num_written > 0); + Poll::Ready(Ok(num_written as usize)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let _ = cvt!(unsafe { libc::fdatasync(self.owned_fd.as_raw_fd()) })?; + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +#[derive(Clone)] +pub struct FromGivenOffset { + fd: RawFd, + pub offset: i64, + role: Role, +} + +impl FromGivenOffset { + pub fn new(f: &impl AsRawFd, role: Role, init: u32) -> io::Result { + let raw_fd = f.as_raw_fd(); + let fd = validate_raw_fd(raw_fd, role)?; + Ok(Self { + fd, + role, + offset: init as i64, + }) + } + + fn stat_sync(&self) -> io::Result { + let mut stat = MaybeUninit::::uninit(); + + try_libc!(unsafe { libc::fstat(self.fd, stat.as_mut_ptr()) }); + + Ok(unsafe { stat.assume_init() }) + } + + fn len_sync(&self) -> io::Result { + Ok(self.stat_sync()?.st_size) + } + + pub async fn stat(&self) -> io::Result { + let mut stat = MaybeUninit::::uninit(); + + let fd = self.fd; + task::spawn_blocking(move || cvt!(unsafe { libc::fstat(fd, stat.as_mut_ptr()) })) + .await + .unwrap()?; + + Ok(unsafe { stat.assume_init() }) + } + + pub async fn len(&self) -> io::Result { + Ok(self.stat().await?.st_size) + } +} + +impl std::io::Read for FromGivenOffset { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let num_read = cvt!(unsafe { + libc::pread( + self.fd, + mem::transmute(buf.as_mut_ptr()), + buf.len(), + self.offset, + ) + })?; + assert!(num_read >= 0); + self.offset += num_read as i64; + Ok(num_read as usize) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let num_read = cvt!(unsafe { + libc::preadv( + self.fd, + mem::transmute(bufs.as_mut_ptr()), + bufs.len() as libc::c_int, + self.offset, + ) + })?; + assert!(num_read >= 0); + self.offset += num_read as i64; + Ok(num_read as usize) + } + + #[inline] + fn is_read_vectored(&self) -> bool { + true + } +} + +impl std::io::Seek for FromGivenOffset { + fn seek(&mut self, arg: io::SeekFrom) -> io::Result { + self.offset = match arg { + io::SeekFrom::Start(from_start) => from_start as i64, + io::SeekFrom::Current(diff) => self.offset + diff, + io::SeekFrom::End(from_end) => { + assert!(from_end <= 0); + let full_len = self.len_sync()?; + full_len + from_end + } + }; + Ok(self.offset as u64) + } +} + +impl std::io::Write for FromGivenOffset { + fn write(&mut self, buf: &[u8]) -> io::Result { + let num_written = cvt!(unsafe { + libc::pwrite( + self.fd, + mem::transmute(buf.as_ptr()), + buf.len(), + self.offset, + ) + })?; + assert!(num_written > 0); + self.offset += num_written as i64; + Ok(num_written as usize) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + let num_written = cvt!(unsafe { + libc::pwritev( + self.fd, + mem::transmute(bufs.as_ptr()), + bufs.len() as libc::c_int, + self.offset, + ) + })?; + assert!(num_written > 0); + self.offset += num_written as i64; + Ok(num_written as usize) + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn flush(&mut self) -> io::Result<()> { + let _ = cvt!(unsafe { libc::fdatasync(self.fd) })?; + Ok(()) + } +} + +impl io::AsyncRead for FromGivenOffset { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + debug_assert!(buf.remaining() > 0); + + let prev_filled = buf.filled().len(); + /* FIXME: don't block here! 
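+          * (pread(2) completes on the executor thread; a task::spawn_blocking
+          * round-trip, as in iter_copy_file_range below, would make it truly
+          * async)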
*/ + let num_read = cvt!(unsafe { + libc::pread( + self.fd, + mem::transmute(buf.initialize_unfilled().as_mut_ptr()), + buf.remaining(), + self.offset, + ) + })?; + self.offset += num_read as i64; + buf.set_filled(prev_filled + num_read as usize); + + Poll::Ready(Ok(())) + } +} + +impl io::AsyncSeek for FromGivenOffset { + fn start_seek(mut self: Pin<&mut Self>, op: io::SeekFrom) -> io::Result<()> { + self.offset = match op { + io::SeekFrom::Start(from_start) => from_start as i64, + io::SeekFrom::Current(diff) => self.offset + diff, + io::SeekFrom::End(from_end) => { + assert!(from_end <= 0); + let full_len = self.len_sync()?; + full_len + from_end + } + }; + Ok(()) + } + + fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + assert!(self.offset >= 0); + Poll::Ready(Ok(self.offset as u64)) + } +} + +impl io::AsyncWrite for FromGivenOffset { + #[inline] + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + /* FIXME: don't block here! */ + let num_written = cvt!(unsafe { + libc::pwritev( + self.fd, + mem::transmute(bufs.as_ptr()), + bufs.len() as libc::c_int, + self.offset, + ) + })?; + assert!(num_written > 0); + self.offset += num_written as i64; + + Poll::Ready(Ok(num_written as usize)) + } + + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + debug_assert!(buf.len() > 0); + + /* FIXME: don't block here! */ + let num_written = cvt!(unsafe { + libc::pwrite( + self.fd, + buf.as_ptr() as *const libc::c_void, + buf.len(), + self.offset, + ) + })?; + assert!(num_written > 0); + self.offset += num_written as i64; + + Poll::Ready(Ok(num_written as usize)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let _ = cvt!(unsafe { libc::fdatasync(self.fd) })?; + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.poll_flush(cx))?; + Poll::Ready(Ok(())) + } +} + +impl AsRawFd for FromGivenOffset { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +impl CopyFileRangeHandle for FromGivenOffset { + fn role(&self) -> Role { + self.role + } + fn as_args(self: Pin<&mut Self>) -> RawArgs { + let Self { + fd, ref mut offset, .. + } = self.get_mut(); + RawArgs { + fd: fd.as_raw_fd(), + off: Some(offset), + } + } +} + +#[inline] +fn convert_option_ptr(mut p: Option<&mut T>) -> *mut T { + if let Some(ref mut val) = p { + &mut **val + } else { + ptr::null_mut() + } +} + +pub async fn iter_copy_file_range( + src: Pin<&mut impl CopyFileRangeHandle>, + dst: Pin<&mut impl CopyFileRangeHandle>, + len: usize, +) -> io::Result { + assert_eq!(src.role(), Role::Readable); + let RawArgs { + fd: fd_in, + off: off_in, + } = src.as_args(); + let off_in = convert_option_ptr(off_in); + let off_in = off_in as usize; + + assert_eq!(dst.role(), Role::Writable); + let RawArgs { + fd: fd_out, + off: off_out, + } = dst.as_args(); + let off_out = convert_option_ptr(off_out); + let off_out = off_out as usize; + + /* These must always be set to 0 for now. 
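+      * copy_file_range(2) defines no flags yet; passing anything nonzero makes
+      * the syscall fail with EINVAL.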
*/ + const FUTURE_FLAGS: libc::c_uint = 0; + let written: libc::ssize_t = task::spawn_blocking(move || { + let off_in = off_in as *mut libc::off64_t; + let off_out = off_out as *mut libc::off64_t; + cvt!(unsafe { libc::copy_file_range(fd_in, off_in, fd_out, off_out, len, FUTURE_FLAGS) }) + }) + .await + .unwrap()?; + assert!(written >= 0); + Ok(written as usize) +} + +pub async fn iter_splice_from_pipe( + mut src: Pin<&mut PipeRead>, + dst: Pin<&mut impl CopyFileRangeHandle>, + len: usize, +) -> io::Result { + assert_eq!(dst.role(), Role::Writable); + let RawArgs { + fd: fd_out, + off: off_out, + } = dst.as_args(); + + src.splice_to_blocking_fd(fd_out, off_out, len, false).await +} + +pub async fn splice_from_pipe( + mut src: Pin<&mut PipeRead>, + mut dst: Pin<&mut impl CopyFileRangeHandle>, + full_len: usize, +) -> io::Result { + let mut remaining = full_len; + + while remaining > 0 { + let cur_written = iter_splice_from_pipe(src.as_mut(), dst.as_mut(), remaining).await?; + assert!(cur_written <= remaining); + if cur_written == 0 { + return Ok(full_len - remaining); + } + remaining -= cur_written; + } + Ok(full_len) +} + +pub async fn iter_splice_to_pipe( + src: Pin<&mut impl CopyFileRangeHandle>, + mut dst: Pin<&mut PipeWrite>, + len: usize, +) -> io::Result { + assert_eq!(src.role(), Role::Readable); + let RawArgs { + fd: fd_in, + off: off_in, + } = src.as_args(); + + dst.splice_from_blocking_fd(fd_in, off_in, len).await +} + +pub async fn splice_to_pipe( + mut src: Pin<&mut impl CopyFileRangeHandle>, + mut dst: Pin<&mut PipeWrite>, + full_len: usize, +) -> io::Result { + let mut remaining = full_len; + + while remaining > 0 { + let cur_written = iter_splice_to_pipe(src.as_mut(), dst.as_mut(), remaining).await?; + assert!(cur_written <= remaining); + if cur_written == 0 { + return Ok(full_len - remaining); + } + remaining -= cur_written; + } + Ok(full_len) +} + +pub async fn copy_file_range( + mut src: Pin<&mut impl CopyFileRangeHandle>, + mut dst: Pin<&mut impl CopyFileRangeHandle>, + full_len: usize, +) -> io::Result { + let mut remaining = full_len; + + while remaining > 0 { + let cur_written = iter_copy_file_range(src.as_mut(), dst.as_mut(), remaining).await?; + assert!(cur_written <= remaining); + if cur_written == 0 { + return Ok(full_len - remaining); + } + remaining -= cur_written; + } + Ok(full_len) +} + +fn check_regular_file(fd: RawFd) -> io::Result<()> { + let mut stat = MaybeUninit::::uninit(); + + try_libc!(unsafe { libc::fstat(fd, stat.as_mut_ptr()) }); + + let stat = unsafe { stat.assume_init() }; + if (stat.st_mode & libc::S_IFMT) == libc::S_IFREG { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Fd is not a regular file", + )) + } +} + +fn get_status_flags(fd: RawFd) -> io::Result { + Ok(try_libc!(unsafe { libc::fcntl(fd, libc::F_GETFL) })) +} + +#[derive(Copy, Clone, Debug, Display, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub enum Role { + /// fd has the read capability + Readable, + /// fd has the write capability + Writable, +} + +impl Role { + fn allowed_modes(&self) -> &'static [libc::c_int] { + static READABLE: &'static [libc::c_int] = &[libc::O_RDONLY, libc::O_RDWR]; + static WRITABLE: &'static [libc::c_int] = &[libc::O_WRONLY, libc::O_RDWR]; + match self { + Self::Readable => READABLE, + Self::Writable => WRITABLE, + } + } + + fn check_append(&self, flags: libc::c_int) -> io::Result<()> { + if let Self::Writable = self { + if (flags & libc::O_APPEND) != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "Writable Fd was set for 
append!", + )); + } + } + Ok(()) + } + + fn errmsg(&self) -> &'static str { + static READABLE: &'static str = "Fd is not readable!"; + static WRITABLE: &'static str = "Fd is not writable!"; + match self { + Self::Readable => READABLE, + Self::Writable => WRITABLE, + } + } + + pub(crate) fn validate_flags(&self, flags: libc::c_int) -> io::Result<()> { + let access_mode = flags & libc::O_ACCMODE; + + if !self.allowed_modes().contains(&access_mode) { + return Err(io::Error::new(io::ErrorKind::Other, self.errmsg())); + } + self.check_append(flags)?; + + Ok(()) + } +} + +fn validate_raw_fd(fd: RawFd, role: Role) -> io::Result { + check_regular_file(fd)?; + + let status_flags = get_status_flags(fd)?; + role.validate_flags(status_flags)?; + + Ok(fd) +} + +#[cfg(test)] +mod test { + use super::*; + + use std::fs; + + #[test] + fn check_copy_file_range() { + assert!(matches!( + *HAS_COPY_FILE_RANGE, + SyscallAvailability::Available + )); + } + + #[test] + fn check_readable_writable_file() { + let f = tempfile::tempfile().unwrap(); + let fd: RawFd = f.as_raw_fd(); + + validate_raw_fd(fd, Role::Readable).unwrap(); + validate_raw_fd(fd, Role::Writable).unwrap(); + } + + #[test] + fn check_only_writable() { + let td = tempfile::tempdir().unwrap(); + let f = fs::OpenOptions::new() + .create_new(true) + .read(false) + .write(true) + .open(td.path().join("asdf.txt")) + .unwrap(); + let fd: RawFd = f.as_raw_fd(); + + validate_raw_fd(fd, Role::Writable).unwrap(); + assert!(validate_raw_fd(fd, Role::Readable).is_err()); + } + + #[test] + fn check_only_readable() { + let td = tempfile::tempdir().unwrap(); + let p = td.path().join("asdf.txt"); + fs::write(&p, b"wow!").unwrap(); + + let f = fs::OpenOptions::new() + .read(true) + .write(false) + .open(&p) + .unwrap(); + let fd: RawFd = f.as_raw_fd(); + + validate_raw_fd(fd, Role::Readable).unwrap(); + assert!(validate_raw_fd(fd, Role::Writable).is_err()); + } + + #[test] + fn check_no_append() { + let td = tempfile::tempdir().unwrap(); + + let f = fs::OpenOptions::new() + .create(true) + .append(true) + .write(true) + .open(td.path().join("asdf.txt")) + .unwrap(); + let fd: RawFd = f.as_raw_fd(); + + assert!(validate_raw_fd(fd, Role::Writable).is_err()); + assert!(validate_raw_fd(fd, Role::Readable).is_err()); + } + + #[tokio::test] + async fn read_ref_into_write_owned() { + use std::io::{Read, Seek}; + + let td = tempfile::tempdir().unwrap(); + let p = td.path().join("asdf.txt"); + fs::write(&p, b"wow!").unwrap(); + + let in_file = fs::File::open(&p).unwrap(); + let mut src = FromGivenOffset::new(&in_file, Role::Readable, 0).unwrap(); + let sp = Pin::new(&mut src); + + let p2 = td.path().join("asdf2.txt"); + let out_file = fs::OpenOptions::new() + .create_new(true) + .write(true) + /* Need this to read the output file contents at the end! */ + .read(true) + .open(&p2) + .unwrap(); + let mut dst = MutateInnerOffset::new(out_file, Role::Writable) + .await + .unwrap(); + let dp = Pin::new(&mut dst); + + /* Explicit offset begins at 0. */ + assert_eq!(0, sp.offset); + + /* 4 bytes were written. */ + assert_eq!( + 4, + /* NB: 5 bytes were requested! 
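+             * copy_file_range(2) stops at end-of-input, and the source file
+             * only holds the 4 bytes of "wow!".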
*/ + copy_file_range(sp, dp, 5).await.unwrap() + ); + assert_eq!(4, src.offset); + + let mut dst: fs::File = dst.into_owned().into(); + assert_eq!(4, dst.stream_position().unwrap()); + dst.rewind().unwrap(); + let mut s = String::new(); + dst.read_to_string(&mut s).unwrap(); + assert_eq!(&s, "wow!"); + } + + #[tokio::test] + async fn test_splice_blocking() { + use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + + let mut in_file = tokio::fs::File::from_std(tempfile::tempfile().unwrap()); + in_file.write_all(b"hello").await.unwrap(); + in_file.rewind().await.unwrap(); + let mut in_file = MutateInnerOffset::new(in_file.into_std().await, Role::Readable) + .await + .unwrap(); + + let mut out_file = tokio::fs::File::from_std(tempfile::tempfile().unwrap()); + let mut out_file_handle = FromGivenOffset::new(&out_file, Role::Writable, 0).unwrap(); + + let (mut r, mut w) = tokio_pipe::pipe().unwrap(); + + let w_task = tokio::spawn(async move { + assert_eq!( + 5, + splice_to_pipe(Pin::new(&mut in_file), Pin::new(&mut w), 6) + .await + .unwrap() + ); + + let in_file: fs::File = in_file.into(); + let mut in_file = tokio::fs::File::from_std(in_file); + assert_eq!(5, in_file.stream_position().await.unwrap()); + }); + + let r_task = tokio::spawn(async move { + assert_eq!( + 5, + splice_from_pipe(Pin::new(&mut r), Pin::new(&mut out_file_handle), 6) + .await + .unwrap() + ); + assert_eq!(out_file_handle.offset, 5); + }); + + tokio::try_join!(w_task, r_task).unwrap(); + + assert_eq!(0, out_file.stream_position().await.unwrap()); + let mut s = String::new(); + out_file.read_to_string(&mut s).await.unwrap(); + assert_eq!(&s, "hello"); + } + + #[test] + fn test_has_vectored_write() { + use tokio::io::AsyncWrite; + + let f = tokio::fs::File::from_std(tempfile::tempfile().unwrap()); + assert!(!f.is_write_vectored()); + + let f = std::io::Cursor::new(Vec::new()); + assert!(f.is_write_vectored()); + } + + #[tokio::test] + async fn test_wrappers_have_vectored_write() { + use tokio::io::AsyncWrite; + + let f = tempfile::tempfile().unwrap(); + let f = MutateInnerOffset::new(f, Role::Writable).await.unwrap(); + assert!(f.is_write_vectored()); + + let f = tokio::fs::File::from_std(tempfile::tempfile().unwrap()); + assert!(!f.is_write_vectored()); + let f = FromGivenOffset::new(&f, Role::Writable, 0).unwrap(); + assert!(f.is_write_vectored()); + } + + #[tokio::test] + async fn test_io_copy_for_owned_wrapper() { + use std::io::prelude::*; + + let mut f = MutateInnerOffset::new(tempfile::tempfile().unwrap(), Role::Writable) + .await + .unwrap(); + f.write_all(b"hello").unwrap(); + f.seek(io::SeekFrom::Start(0)).unwrap(); + let mut f_in = MutateInnerOffset::new(f, Role::Readable).await.unwrap(); + + let mut f_out = MutateInnerOffset::new(tempfile::tempfile().unwrap(), Role::Writable) + .await + .unwrap(); + + std::io::copy(&mut f_in, &mut f_out).unwrap(); + + let mut f: std::fs::File = f_out.into(); + f.seek(io::SeekFrom::Start(0)).unwrap(); + let mut s = String::new(); + f.read_to_string(&mut s).unwrap(); + assert_eq!(&s, "hello"); + } + + #[tokio::test] + async fn test_io_copy_for_non_owned_wrapper() { + use tokio::io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + + let in_backing_file = tokio::fs::File::from_std(tempfile::tempfile().unwrap()); + let mut f_in = FromGivenOffset::new(&in_backing_file, Role::Writable, 0).unwrap(); + f_in.write_all(b"hello").await.unwrap(); + let mut f_in = FromGivenOffset::new(&in_backing_file, Role::Readable, 0).unwrap(); + + let out_backing_file = 
tokio::fs::File::from_std(tempfile::tempfile().unwrap()); + let mut f_out = FromGivenOffset::new(&out_backing_file, Role::Writable, 0).unwrap(); + + io::copy(&mut Pin::new(&mut f_in), &mut Pin::new(&mut f_out)) + .await + .unwrap(); + + f_out.seek(io::SeekFrom::Start(0)).await.unwrap(); + + let mut s = String::new(); + f_out.read_to_string(&mut s).await.unwrap(); + assert_eq!(&s, "hello"); + } + + #[tokio::test] + async fn test_cloneable_wrapper() { + use tokio::io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + + let backing_file = tokio::fs::File::from_std(tempfile::tempfile().unwrap()); + let mut f1 = FromGivenOffset::new(&backing_file, Role::Writable, 0).unwrap(); + let mut f1b = f1.clone(); + f1.write_all(b"hell").await.unwrap(); + f1b.seek(io::SeekFrom::Current(4)).await.unwrap(); + f1b.write_all(b"o").await.unwrap(); + + let mut f2 = FromGivenOffset::new(&backing_file, Role::Readable, 0).unwrap(); + let mut f3 = f2.clone(); + + let mut s = String::new(); + f2.read_to_string(&mut s).await.unwrap(); + assert_eq!(&s, "hello"); + s.clear(); + assert_eq!(&s, ""); + f3.read_to_string(&mut s).await.unwrap(); + assert_eq!(&s, "hello"); + } +} diff --git a/src/tokio/os/mod.rs b/src/tokio/os/mod.rs new file mode 100644 index 000000000..4beb5987f --- /dev/null +++ b/src/tokio/os/mod.rs @@ -0,0 +1,403 @@ +pub mod copy_file_range; + +#[macro_export] +macro_rules! cvt { + ($e:expr) => {{ + let ret = $e; + if ret == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(ret) + } + }}; +} + +#[macro_export] +macro_rules! try_libc { + ($e: expr) => {{ + let ret = $e; + if ret == -1 { + return Err(io::Error::last_os_error()); + } + ret + }}; +} + +pub enum SyscallAvailability { + Available, + FailedProbe(std::io::Error), + NotOnThisPlatform, +} + +/// Invalid file descriptor. +/// +/// Valid file descriptors are guaranteed to be positive numbers (see `open()` manpage) +/// while negative values are used to indicate errors. +/// Thus -1 will never be overlap with a valid open file. 
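+///
+/// The `copy_file_range` availability probe in `copy_file_range.rs` relies on
+/// this: calling the syscall with `INVALID_FD` must fail with `EBADF` (rather
+/// than touching any real file) whenever the syscall exists at all.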
+const INVALID_FD: std::os::fd::RawFd = -1; + +pub mod subset { + use crate::{ + tokio::read::{Shared, SharedData}, + types::ZipFileData, + }; + + use std::{cmp, ops, sync::Arc}; + + #[derive(Debug)] + pub struct SharedSubset { + parent: Arc, + content_range: ops::Range, + entry_range: ops::RangeInclusive, + } + + impl SharedData for SharedSubset { + #[inline] + fn content_range(&self) -> ops::Range { + debug_assert!(self.content_range.start <= self.content_range.end); + debug_assert!(self.content_range.start >= self.parent.offset()); + debug_assert!(self.content_range.end <= self.parent.directory_start()); + self.content_range.clone() + } + #[inline] + fn comment(&self) -> &[u8] { + self.parent.comment() + } + #[inline] + fn contiguous_entries(&self) -> &indexmap::map::Slice { + debug_assert!(self.entry_range.start() <= self.entry_range.end()); + &self.parent.contiguous_entries()[self.entry_range.clone()] + } + } + + impl SharedSubset { + pub fn parent(&self) -> Arc { + self.parent.clone() + } + + pub fn split_contiguous_chunks( + parent: Arc, + num_chunks: usize, + ) -> Box<[SharedSubset]> { + let chunk_size = cmp::max(1, parent.len() / num_chunks); + let all_entry_indices: Vec = (0..parent.len()).collect(); + let chunked_entry_ranges: Vec> = all_entry_indices + .chunks(chunk_size) + .map(|chunk_indices| { + let min = *chunk_indices.first().unwrap(); + let max = *chunk_indices.last().unwrap(); + min..=max + }) + .collect(); + assert!(chunked_entry_ranges.len() <= num_chunks); + + let parent_slice = parent.contiguous_entries(); + let chunked_slices: Vec<&indexmap::map::Slice> = + chunked_entry_ranges + .iter() + .map(|range| &parent_slice[range.clone()]) + .collect(); + + let chunk_regions: Vec> = chunked_slices + .iter() + .enumerate() + .map(|(i, chunk)| { + fn begin_pos(chunk: &indexmap::map::Slice) -> u64 { + assert!(!chunk.is_empty()); + chunk.get_index(0).unwrap().1.header_start + } + if i == 0 { + assert_eq!(parent.offset(), begin_pos(chunk)); + } + let beg = begin_pos(chunk); + let end = chunked_slices + .get(i + 1) + .map(|chunk| *chunk) + .map(begin_pos) + .unwrap_or(parent.directory_start()); + beg..end + }) + .collect(); + + let subsets: Vec = chunk_regions + .into_iter() + .zip(chunked_entry_ranges.into_iter()) + .map(|(content_range, entry_range)| SharedSubset { + parent: parent.clone(), + content_range, + entry_range, + }) + .collect(); + + subsets.into_boxed_slice() + } + } + + #[cfg(test)] + mod test { + use super::*; + + use crate::{result::ZipResult, tokio::write::ZipWriter, write::FileOptions}; + use std::{io::Cursor, pin::Pin}; + use tokio::io::AsyncWriteExt; + + #[tokio::test] + async fn test_split_contiguous_chunks() -> ZipResult<()> { + let options = FileOptions::default(); + + let buf = Cursor::new(Vec::new()); + let mut zip = ZipWriter::new(Box::pin(buf)); + let mut zp = Pin::new(&mut zip); + zp.as_mut().start_file("a.txt", options).await?; + zp.write_all(b"hello\n").await?; + let mut src = zip.finish_into_readable().await?; + let src = Pin::new(&mut src); + + let buf = Cursor::new(Vec::new()); + let mut zip = ZipWriter::new(Box::pin(buf)); + let mut zp = Pin::new(&mut zip); + zp.as_mut().start_file("b.txt", options).await?; + zp.write_all(b"hey\n").await?; + let mut src2 = zip.finish_into_readable().await?; + let src2 = Pin::new(&mut src2); + + let buf = Cursor::new(Vec::new()); + let mut zip = ZipWriter::new(Box::pin(buf)); + let mut zp = Pin::new(&mut zip); + zp.as_mut().start_file("c/d.txt", options).await?; + zp.write_all(b"asdf!\n").await?; + let mut src3 = 
zip.finish_into_readable().await?; + let src3 = Pin::new(&mut src3); + + let prefix = [0u8; 200]; + let mut buf = Cursor::new(Vec::new()); + buf.write_all(&prefix).await?; + let mut zip = ZipWriter::new(Box::pin(buf)); + let mut zp = Pin::new(&mut zip); + zp.as_mut().merge_archive(src).await?; + zp.as_mut().merge_archive(src2).await?; + zp.merge_archive(src3).await?; + let result = zip.finish_into_readable().await?; + + let parent = result.shared(); + assert_eq!(parent.len(), 3); + assert_eq!(parent.offset(), prefix.len() as u64); + assert_eq!(200..329, parent.content_range()); + + let split_result = SharedSubset::split_contiguous_chunks(parent.clone(), 3); + assert_eq!(split_result.len(), 3); + + assert_eq!(200..243, split_result[0].content_range); + assert_eq!(1, split_result[0].len()); + assert_eq!(0..=0, split_result[0].entry_range); + + assert_eq!(243..284, split_result[1].content_range); + assert_eq!(1, split_result[1].len()); + assert_eq!(1..=1, split_result[1].entry_range); + + assert_eq!(284..329, split_result[2].content_range); + assert_eq!(1, split_result[2].len()); + assert_eq!(2..=2, split_result[2].entry_range); + assert_eq!(329, parent.directory_start()); + + Ok(()) + } + } +} +pub use subset::SharedSubset; + +pub mod mapped_archive { + use super::{ + copy_file_range::{self, CopyFileRangeHandle, FromGivenOffset, MutateInnerOffset, Role}, + subset::SharedSubset, + }; + use crate::{ + result::ZipResult, + tokio::{ + combinators::Limiter, + read::{Shared, SharedData, ZipArchive}, + WrappedPin, + }, + }; + + use tempfile; + use tokio::{ + fs, + io::{self, AsyncSeekExt}, + task, + }; + + use std::{marker::Unpin, os::unix::io::AsRawFd, pin::Pin, sync::Arc}; + + async fn split_mapped_archive_impl( + mut in_handle: impl CopyFileRangeHandle + Unpin, + split_chunks: Vec, + ) -> ZipResult, SharedSubset>]>> { + let mut ret: Vec, SharedSubset>> = + Vec::with_capacity(split_chunks.len()); + + for chunk in split_chunks.into_iter() { + let cur_len = chunk.content_len() as usize; + let cur_start = chunk.content_range().start; + + let backing_file = task::spawn_blocking(|| tempfile::tempfile()) + .await + .unwrap()?; + let mut out_handle = FromGivenOffset::new(&backing_file, Role::Writable, 0)?; + + assert_eq!( + cur_len, + copy_file_range::copy_file_range( + Pin::new(&mut in_handle), + Pin::new(&mut out_handle), + cur_len, + ) + .await? 
+ ); + + let inner = Limiter::take( + cur_start, + Box::pin(fs::File::from_std(backing_file)), + cur_len, + ); + ret.push(ZipArchive::mapped(Arc::new(chunk), Box::pin(inner))); + } + + Ok(ret.into_boxed_slice()) + } + + ///``` + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use zip::{result::ZipError, tokio::{read::ZipArchive, write::ZipWriter, os}}; + /// use futures_util::{pin_mut, TryStreamExt}; + /// use tokio::{io::{self, AsyncReadExt, AsyncWriteExt}, fs}; + /// use std::{io::Cursor, pin::Pin, sync::Arc, path::PathBuf}; + /// + /// let f: ZipArchive<_, _> = { + /// let buf = fs::File::from_std(tempfile::tempfile()?); + /// let mut f = zip::tokio::write::ZipWriter::new(Box::pin(buf)); + /// let mut fp = Pin::new(&mut f); + /// let options = zip::write::FileOptions::default() + /// .compression_method(zip::CompressionMethod::Deflated); + /// fp.as_mut().start_file("a.txt", options).await?; + /// fp.as_mut().write_all(b"hello\n").await?; + /// fp.as_mut().start_file("b.txt", options).await?; + /// fp.as_mut().write_all(b"hello2\n").await?; + /// fp.as_mut().start_file("c.txt", options).await?; + /// fp.write_all(b"hello3\n").await?; + /// f.finish_into_readable().await? + /// }; + /// + /// let split_f: Vec> = os::split_into_mapped_archive(f, 3).await?.into(); + /// + /// let mut ret: Vec<(PathBuf, String)> = Vec::with_capacity(3); + /// + /// for mut mapped_archive in split_f.into_iter() { + /// let entries = Pin::new(&mut mapped_archive).entries_stream(); + /// pin_mut!(entries); + /// + /// while let Some(mut zf) = entries.try_next().await? { + /// let name = zf.name()?.to_path_buf(); + /// let mut contents = String::new(); + /// zf.read_to_string(&mut contents).await?; + /// ret.push((name, contents)); + /// } + /// } + /// + /// assert_eq!( + /// ret, + /// vec![ + /// (PathBuf::from("a.txt"), "hello\n".to_string()), + /// (PathBuf::from("b.txt"), "hello2\n".to_string()), + /// (PathBuf::from("c.txt"), "hello3\n".to_string()), + /// ], + /// ); + /// + /// # Ok(()) + /// # })} + ///``` + pub async fn split_into_mapped_archive( + archive: ZipArchive, + num_chunks: usize, + ) -> ZipResult, SharedSubset>]>> { + let shared = archive.shared(); + + let inner: Pin> = archive.unwrap_inner_pin(); + let inner: Box = Pin::into_inner(inner); + let mut inner: fs::File = *inner; + + inner.seek(io::SeekFrom::Start(shared.offset())).await?; + + let in_handle = MutateInnerOffset::new(inner.into_std().await, Role::Readable).await?; + + let split_chunks: Vec = + SharedSubset::split_contiguous_chunks(shared, num_chunks).into(); + + split_mapped_archive_impl(in_handle, split_chunks).await + } + + ///``` + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use zip::{result::ZipError, tokio::{read::ZipArchive, write::ZipWriter, os}}; + /// use futures_util::{pin_mut, TryStreamExt}; + /// use tokio::{io::{self, AsyncReadExt, AsyncWriteExt}, fs}; + /// use std::{io::Cursor, pin::Pin, sync::Arc, path::PathBuf}; + /// + /// let f: ZipArchive<_, _> = { + /// let buf = fs::File::from_std(tempfile::tempfile()?); + /// let mut f = zip::tokio::write::ZipWriter::new(Box::pin(buf)); + /// let mut fp = Pin::new(&mut f); + /// let options = zip::write::FileOptions::default() + /// .compression_method(zip::CompressionMethod::Deflated); + /// fp.as_mut().start_file("a.txt", options).await?; + /// fp.as_mut().write_all(b"hello\n").await?; + /// fp.as_mut().start_file("b.txt", options).await?; + /// fp.as_mut().write_all(b"hello2\n").await?; + /// 
fp.as_mut().start_file("c.txt", options).await?; + /// fp.write_all(b"hello3\n").await?; + /// f.finish_into_readable().await? + /// }; + /// + /// let split_f: Vec> = os::split_mapped_archive_ref(&f, 3).await?.into(); + /// + /// let mut ret: Vec<(PathBuf, String)> = Vec::with_capacity(3); + /// + /// for mut mapped_archive in split_f.into_iter() { + /// let entries = Pin::new(&mut mapped_archive).entries_stream(); + /// pin_mut!(entries); + /// + /// while let Some(mut zf) = entries.try_next().await? { + /// let name = zf.name()?.to_path_buf(); + /// let mut contents = String::new(); + /// zf.read_to_string(&mut contents).await?; + /// ret.push((name, contents)); + /// } + /// } + /// + /// assert_eq!( + /// ret, + /// vec![ + /// (PathBuf::from("a.txt"), "hello\n".to_string()), + /// (PathBuf::from("b.txt"), "hello2\n".to_string()), + /// (PathBuf::from("c.txt"), "hello3\n".to_string()), + /// ], + /// ); + /// + /// # Ok(()) + /// # })} + ///``` + pub async fn split_mapped_archive_ref( + archive: &ZipArchive, + num_chunks: usize, + ) -> ZipResult, SharedSubset>]>> { + let shared = archive.shared(); + + let in_handle = FromGivenOffset::new(archive, Role::Readable, shared.offset() as u32)?; + + let split_chunks: Vec = + SharedSubset::split_contiguous_chunks(shared, num_chunks).into(); + + split_mapped_archive_impl(in_handle, split_chunks).await + } +} +pub use mapped_archive::{split_into_mapped_archive, split_mapped_archive_ref}; diff --git a/src/tokio/read.rs b/src/tokio/read.rs new file mode 100755 index 000000000..f07f69a00 --- /dev/null +++ b/src/tokio/read.rs @@ -0,0 +1,1361 @@ +use crate::compression::CompressionMethod; +use crate::result::{ZipError, ZipResult}; +use crate::spec::{self, LocalHeaderBuffer}; +use crate::tokio::{ + buf_reader::BufReader, combinators::Limiter, crc32::Crc32Reader, extraction::CompletedPaths, + stream_impls::deflate, utils::map_take_manual_drop, WrappedPin, +}; +use crate::types::ZipFileData; + +#[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" +))] +use flate2::Decompress; + +use async_stream::try_stream; +use cfg_if::cfg_if; +use futures_core::stream::Stream; +use futures_util::{pin_mut, stream::TryStreamExt}; +use indexmap::IndexMap; +use tokio::{ + fs, + io::{self, AsyncReadExt, AsyncSeekExt}, + sync::{self, mpsc}, + task, +}; + +use std::{ + cell::UnsafeCell, + mem::{self, ManuallyDrop}, + num, ops, + os::unix::io::{AsRawFd, RawFd}, + path::{Path, PathBuf}, + pin::Pin, + str, + sync::Arc, + task::{Context, Poll}, +}; + +pub(crate) trait ReaderWrapper { + fn construct(data: &ZipFileData, s: Pin>) -> Self + where + Self: Sized; +} + +pub struct StoredReader(Crc32Reader); + +impl StoredReader { + #[inline] + fn pin_stream(self: Pin<&mut Self>) -> Pin<&mut Crc32Reader> { + unsafe { self.map_unchecked_mut(|Self(inner)| inner) } + } +} + +impl io::AsyncRead for StoredReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + self.pin_stream().poll_read(cx, buf) + } +} + +impl WrappedPin for StoredReader { + fn unwrap_inner_pin(self) -> Pin> { + self.0.unwrap_inner_pin() + } +} + +impl ReaderWrapper for StoredReader { + fn construct(data: &ZipFileData, s: Pin>) -> Self { + Self(Crc32Reader::new(s, data.crc32, false)) + } +} + +pub struct DeflateReader(Crc32Reader>>); + +impl DeflateReader { + #[inline] + fn pin_stream( + self: Pin<&mut Self>, + ) -> Pin<&mut Crc32Reader>>> { + unsafe { self.map_unchecked_mut(|Self(inner)| inner) } + } +} + +impl 
io::AsyncRead for DeflateReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + self.pin_stream().poll_read(cx, buf) + } +} + +impl WrappedPin for DeflateReader { + fn unwrap_inner_pin(self) -> Pin> { + Pin::into_inner(Pin::into_inner(self.0.unwrap_inner_pin()).unwrap_inner_pin()) + .unwrap_inner_pin() + } +} + +impl ReaderWrapper for DeflateReader { + fn construct(data: &ZipFileData, s: Pin>) -> Self { + let buf_reader = BufReader::with_capacity(num::NonZeroUsize::new(32 * 1024).unwrap(), s); + let deflater = deflate::Reader::with_state(Decompress::new(false), Box::pin(buf_reader)); + Self(Crc32Reader::new(Box::pin(deflater), data.crc32, false)) + } +} + +pub enum ZipFileWrappedReader { + Stored(StoredReader), + Deflated(DeflateReader), +} + +enum WrappedProj<'a, S> { + Stored(Pin<&'a mut StoredReader>), + Deflated(Pin<&'a mut DeflateReader>), +} + +impl ZipFileWrappedReader { + #[inline] + fn project(self: Pin<&mut Self>) -> WrappedProj<'_, S> { + unsafe { + let s = self.get_unchecked_mut(); + match s { + Self::Stored(s) => WrappedProj::Stored(Pin::new_unchecked(s)), + Self::Deflated(s) => WrappedProj::Deflated(Pin::new_unchecked(s)), + } + } + } +} + +impl io::AsyncRead for ZipFileWrappedReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + match self.project() { + WrappedProj::Stored(r) => r.poll_read(cx, buf), + WrappedProj::Deflated(r) => r.poll_read(cx, buf), + } + } +} + +impl WrappedPin for ZipFileWrappedReader { + fn unwrap_inner_pin(self) -> Pin> { + match self { + Self::Stored(r) => r.unwrap_inner_pin(), + Self::Deflated(r) => r.unwrap_inner_pin(), + } + } +} + +impl WrappedPin for ZipFileWrappedReader> { + fn unwrap_inner_pin(self) -> Pin> { + match self { + Self::Stored(r) => Pin::into_inner(r.unwrap_inner_pin()).unwrap_inner_pin(), + Self::Deflated(r) => Pin::into_inner(r.unwrap_inner_pin()).unwrap_inner_pin(), + } + } +} + +impl ReaderWrapper for ZipFileWrappedReader { + fn construct(data: &ZipFileData, s: Pin>) -> Self { + match data.compression_method { + CompressionMethod::Stored => Self::Stored(StoredReader::::construct(data, s)), + #[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))] + CompressionMethod::Deflated => Self::Deflated(DeflateReader::::construct(data, s)), + _ => todo!("other compression methods not supported yet!"), + } + } +} + +pub(crate) async fn find_content( + data: &ZipFileData, + mut reader: Pin>, +) -> ZipResult> { + let cur_pos = { + // Parse local header + reader.seek(io::SeekFrom::Start(data.header_start)).await?; + + static_assertions::assert_eq_size!([u8; 30], LocalHeaderBuffer); + let mut info = [0u8; 30]; + reader.read_exact(&mut info[..]).await?; + + let LocalHeaderBuffer { + magic, + file_name_length, + /* NB: zip files have separate local and central extra data records. The length of the + * local extra field is being parsed here. The value of this field cannot be inferred + * from the central record data alone. */ + extra_field_length, + .. + } = unsafe { mem::transmute(info) }; + + if magic != spec::LOCAL_FILE_HEADER_SIGNATURE { + return Err(ZipError::InvalidArchive("Invalid local file header")); + } + + let data_start = data.header_start + + info.len() as u64 + + file_name_length as u64 + + extra_field_length as u64; + data.data_start.store(data_start); + + reader.seek(io::SeekFrom::Start(data_start)).await? 
+ }; + Ok(Limiter::take( + cur_pos, + reader, + data.compressed_size as usize, + )) +} + +pub trait SharedData { + #[inline] + fn len(&self) -> usize { + self.contiguous_entries().len() + } + #[inline] + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn content_range(&self) -> ops::Range; + #[inline] + fn content_len(&self) -> u64 { + let r = self.content_range(); + debug_assert!(r.start <= r.end); + r.end - r.start + } + fn comment(&self) -> &[u8]; + fn contiguous_entries(&self) -> &indexmap::map::Slice; +} + +#[derive(Debug)] +pub struct Shared { + files: IndexMap, + offset: u64, + directory_start: u64, + comment: Vec, +} + +impl SharedData for Shared { + #[inline] + fn content_range(&self) -> ops::Range { + ops::Range { + start: self.offset(), + end: self.directory_start(), + } + } + #[inline] + fn comment(&self) -> &[u8] { + &self.comment + } + #[inline] + fn contiguous_entries(&self) -> &indexmap::map::Slice { + self.files.as_slice() + } +} + +impl Shared { + #[inline] + pub fn offset(&self) -> u64 { + self.offset + } + #[inline] + pub fn directory_start(&self) -> u64 { + self.directory_start + } + + #[inline] + pub fn file_names(&self) -> impl Iterator { + self.files.keys().map(|s| s.as_str()) + } + + pub(crate) async fn get_directory_counts( + mut reader: Pin<&mut S>, + footer: &spec::CentralDirectoryEnd, + cde_end_pos: u64, + ) -> ZipResult<(u64, u64, usize)> { + // See if there's a ZIP64 footer. The ZIP64 locator if present will + // have its signature 20 bytes in front of the standard footer. The + // standard footer, in turn, is 22+N bytes large, where N is the + // comment length. Therefore: + let zip64locator = if reader + .as_mut() + .seek(io::SeekFrom::End( + -(20 + 22 + footer.zip_file_comment.len() as i64), + )) + .await + .is_ok() + { + match spec::Zip64CentralDirectoryEndLocator::parse_async(reader.as_mut()).await { + Ok(loc) => Some(loc), + Err(ZipError::InvalidArchive(_)) => { + // No ZIP64 header; that's actually fine. We're done here. + None + } + Err(e) => { + // Yikes, a real problem + return Err(e); + } + } + } else { + // Empty Zip files will have nothing else so this error might be fine. If + // not, we'll find out soon. + None + }; + + match zip64locator { + None => { + // Some zip files have data prepended to them, resulting in the + // offsets all being too small. Get the amount of error by comparing + // the actual file position we found the CDE at with the offset + // recorded in the CDE. + let archive_offset = cde_end_pos + .checked_sub(footer.central_directory_size as u64) + .and_then(|x| x.checked_sub(footer.central_directory_offset as u64)) + .ok_or(ZipError::InvalidArchive( + "Invalid central directory size or offset", + ))?; + + let directory_start = footer.central_directory_offset as u64 + archive_offset; + let number_of_files = footer.number_of_files_on_this_disk as usize; + Ok((archive_offset, directory_start, number_of_files)) + } + Some(locator64) => { + // If we got here, this is indeed a ZIP64 file. + + if !footer.record_too_small() + && footer.disk_number as u32 != locator64.disk_with_central_directory + { + return Err(ZipError::UnsupportedArchive( + "Support for multi-disk files is not implemented", + )); + } + + // We need to reassess `archive_offset`. We know where the ZIP64 + // central-directory-end structure *should* be, but unfortunately we + // don't know how to precisely relate that location to our current + // actual offset in the file, since there may be junk at its + // beginning. 
Therefore we need to perform another search, as in + // read::CentralDirectoryEnd::find_and_parse, except now we search + // forward. + + let search_upper_bound = cde_end_pos + .checked_sub(60) // minimum size of Zip64CentralDirectoryEnd + Zip64CentralDirectoryEndLocator + .ok_or(ZipError::InvalidArchive( + "File cannot contain ZIP64 central directory end", + ))?; + let (footer, archive_offset) = + spec::Zip64CentralDirectoryEnd::find_and_parse_async( + reader.as_mut(), + locator64.end_of_central_directory_offset, + search_upper_bound, + ) + .await?; + + if footer.disk_number != footer.disk_with_central_directory { + return Err(ZipError::UnsupportedArchive( + "Support for multi-disk files is not implemented", + )); + } + + let directory_start = footer + .central_directory_offset + .checked_add(archive_offset) + .ok_or({ + ZipError::InvalidArchive("Invalid central directory size or offset") + })?; + + Ok(( + archive_offset, + directory_start, + footer.number_of_files as usize, + )) + } + } + } + + pub async fn parse( + mut reader: Pin>, + ) -> ZipResult<(Self, Pin>)> { + let (footer, cde_end_pos) = + spec::CentralDirectoryEnd::find_and_parse_async(reader.as_mut()).await?; + + if !footer.record_too_small() && footer.disk_number != footer.disk_with_central_directory { + return Err(ZipError::UnsupportedArchive( + "Support for multi-disk files is not implemented", + )); + } + + let (archive_offset, directory_start, number_of_files) = + Self::get_directory_counts(reader.as_mut(), &footer, cde_end_pos).await?; + + // If the parsed number of files is greater than the offset then + // something fishy is going on and we shouldn't trust number_of_files. + let file_capacity = if number_of_files > directory_start as usize { + 0 + } else { + number_of_files + }; + + let mut files = IndexMap::with_capacity(file_capacity); + + reader + .seek(io::SeekFrom::Start(directory_start)) + .await + .map_err(|_| { + ZipError::InvalidArchive("Could not seek to start of central directory") + })?; + + for i in 0..number_of_files { + let file = + read_spec::central_header_to_zip_file(reader.as_mut(), archive_offset).await?; + if i == 0 { + assert_eq!(archive_offset, file.header_start); + } + assert!(files.insert(file.file_name.clone(), file).is_none()); + } + + Ok(( + Self { + files, + offset: archive_offset, + directory_start, + comment: footer.zip_file_comment, + }, + reader, + )) + } +} + +async fn create_dir_idempotent>(dir: P) -> io::Result> { + match fs::create_dir(dir).await { + Ok(()) => Ok(Some(())), + Err(e) if e.kind() == io::ErrorKind::AlreadyExists => Ok(None), + Err(e) => Err(e), + } +} + +#[derive(Debug)] +pub struct ZipFile<'a, S, R: WrappedPin, Sh> { + data: &'a ZipFileData, + wrapped_reader: ManuallyDrop, + parent: &'a mut ZipArchive, +} + +impl<'a, S, R: WrappedPin, Sh> ZipFile<'a, S, R, Sh> { + #[inline] + pub fn name(&self) -> ZipResult<&Path> { + self.data + .enclosed_name() + .ok_or(ZipError::InvalidArchive("Invalid file path")) + } + + #[inline] + pub fn is_dir(&self) -> bool { + /* TODO: '\\' too? 
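+ * (APPNOTE 4.4.17.1 mandates '/' as the path separator, but archives produced by
+ * some Windows tools use '\\' in practice, so a check for it may be warranted.)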
*/ + self.data.file_name.ends_with('/') + } + + #[inline] + pub fn unix_mode(&self) -> Option { + self.data.unix_mode() + } + + #[inline] + pub fn data(&self) -> &ZipFileData { + &self.data + } + + #[inline] + fn pin_stream(self: Pin<&mut Self>) -> Pin<&mut R> { + unsafe { self.map_unchecked_mut(|s| &mut *s.wrapped_reader) } + } +} + +impl<'a, S, R: WrappedPin, Sh> ops::Drop for ZipFile<'a, S, R, Sh> { + fn drop(&mut self) { + inner_drop(unsafe { Pin::new_unchecked(self) }); + fn inner_drop<'a, S, R: WrappedPin, Sh>(this: Pin<&mut ZipFile<'a, S, R, Sh>>) { + let ZipFile { + ref mut wrapped_reader, + ref mut parent, + .. + } = unsafe { this.get_unchecked_mut() }; + let _ = parent + .reader + .insert(unsafe { ManuallyDrop::take(wrapped_reader) }.unwrap_inner_pin()); + } + } +} + +impl<'a, S, R: WrappedPin + 'a, Sh> ZipFile<'a, S, R, Sh> { + pub(crate) fn decode_stream + WrappedPin>( + self, + ) -> ZipFile<'a, S, T, Sh> { + let s = UnsafeCell::new(ManuallyDrop::new(self)); + + let data: &'a ZipFileData = unsafe { &*s.get() }.data; + let wrapped_reader: &mut ManuallyDrop = &mut unsafe { &mut *s.get() }.wrapped_reader; + let parent: &'a mut ZipArchive = unsafe { &mut *s.get() }.parent; + let wrapped_reader = map_take_manual_drop(wrapped_reader, move |wrapped_reader: R| { + T::construct(data, Box::pin(wrapped_reader)) + }); + + let data: &'a ZipFileData = unsafe { &*s.get() }.data; + ZipFile { + data, + wrapped_reader, + parent, + } + } +} + +impl<'a, S, R: WrappedPin + io::AsyncRead, Sh> io::AsyncRead for ZipFile<'a, S, R, Sh> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + self.pin_stream().poll_read(cx, buf) + } +} + +#[derive(Debug)] +pub struct ZipArchive { + reader: Option>>, + shared: Arc, +} + +impl ZipArchive { + #[inline] + fn pin_reader_mut_option(self: Pin<&mut Self>) -> &mut Option>> { + &mut self.get_mut().reader + } + + #[inline] + fn pin_reader_assert(self: Pin<&mut Self>) -> Pin<&mut S> { + unsafe { + let s = self.get_unchecked_mut(); + s.reader.as_mut().unwrap().as_mut() + } + } + + #[inline] + pub fn shared(&self) -> Arc { + self.shared.clone() + } +} + +impl AsRawFd for ZipArchive { + fn as_raw_fd(&self) -> RawFd { + self.reader.as_ref().unwrap().as_raw_fd() + } +} + +impl ZipArchive { + pub async fn new(reader: Pin>) -> ZipResult { + let (shared, reader) = Shared::parse(reader).await?; + Ok(Self { + reader: Some(reader), + shared: Arc::new(shared), + }) + } +} + +impl ZipArchive { + pub fn mapped(shared: Arc, reader: Pin>) -> Self { + Self { + reader: Some(reader), + shared, + } + } +} + +impl WrappedPin for ZipArchive { + fn unwrap_inner_pin(self) -> Pin> { + self.reader.unwrap() + } +} + +impl ZipArchive { + pub async fn by_index( + self: Pin<&mut Self>, + index: usize, + ) -> ZipResult>, Sh>>>> { + let raw_entry: Pin, Sh>>> = self.by_index_raw(index).await?; + let decoded_entry: Pin>, Sh>>> = + Box::pin(Pin::into_inner(raw_entry).decode_stream()); + Ok(decoded_entry) + } + + pub async fn by_index_raw( + self: Pin<&mut Self>, + index: usize, + ) -> ZipResult, Sh>>>> { + let s = UnsafeCell::new(self); + let data = match unsafe { &*s.get() } + .shared + .contiguous_entries() + .get_index(index) + { + None => { + return Err(ZipError::FileNotFound); + } + Some((_, data)) => data, + }; + + let limited_reader = find_content( + data, + unsafe { &mut *s.get() } + .as_mut() + .pin_reader_mut_option() + .take() + .unwrap(), + ) + .await?; + + Ok(Box::pin(ZipFile { + data, + wrapped_reader: 
ManuallyDrop::new(limited_reader), + parent: unsafe { &mut *s.get() }, + })) + } + + pub fn raw_entries_stream( + self: Pin<&mut Self>, + ) -> impl Stream, Sh>>>>> + '_ { + let len = self.shared.len(); + let s = std::cell::UnsafeCell::new(self); + /* FIXME: make this a stream with a known length! */ + try_stream! { + for i in 0..len { + let f = Pin::new(unsafe { &mut **s.get() }).by_index_raw(i).await?; + yield f; + } + } + } + + pub fn entries_stream( + self: Pin<&mut Self>, + ) -> impl Stream>, Sh>>>>> + '_ + { + use futures_util::StreamExt; + + self.raw_entries_stream() + .map(|result| result.map(|entry| Box::pin(Pin::into_inner(entry).decode_stream()))) + } +} + +impl ZipArchive { + pub async fn by_name( + self: Pin<&mut Self>, + name: &str, + ) -> ZipResult>, Shared>>>> { + let index = match self.shared.files.get_index_of(name) { + None => { + return Err(ZipError::FileNotFound); + } + Some(n) => n, + }; + self.by_index(index).await + } + + pub(crate) async fn merge_contents( + mut self: Pin<&mut Self>, + mut w: Pin<&mut W>, + ) -> ZipResult> { + use rayon::prelude::*; + + let mut new_files: Box<[ZipFileData]> = self + .shared + .files + .par_values() + .cloned() + .collect::>() + .into_boxed_slice(); + if new_files.is_empty() { + return Ok(new_files); + } + /* The first file header will probably start at the beginning of the file, but zip doesn't + * enforce that, and executable zips like PEX files will have a shebang line so will + * definitely be greater than 0. + * + * assert_eq!(0, new_files[0].header_start); // Avoid this. + */ + + let new_initial_header_start = w.stream_position().await?; + + /* Push back file header starts for all entries in the covered files. */ + new_files + .par_iter_mut() + .map(|f| { + /* This is probably the only really important thing to change. */ + f.header_start = f.header_start.checked_add(new_initial_header_start).ok_or( + ZipError::InvalidArchive( + "new header start from merge would have been too large", + ), + )?; + /* This is only ever used internally to cache metadata lookups (it's not part of the + * zip spec), and 0 is the sentinel value. */ + f.central_header_start = 0; + /* This is an atomic variable so it can be updated from another thread in the + * implementation (which is good!). */ + let new_data_start = f + .data_start + /* NB: it's annoying there's no .checked_fetch_add(), but we don't need it here + * because nothing else has any reference to this data. */ + .load() + .checked_add(new_initial_header_start) + .ok_or(ZipError::InvalidArchive( + "new data start from merge would have been too large", + ))?; + f.data_start.store(new_data_start); + Ok(()) + }) + .collect::>()?; + + let shared = Arc::clone(&self.shared); + + /* Rewind to the beginning of the file. + * + * NB: we *could* decide to start copying from shared.offset instead, which + * would avoid copying over e.g. any pex shebangs or other file contents that start before + * the first zip file entry. However, zip files actually shouldn't care about garbage data + * in *between* real entries, since the central directory header records the correct start + * location of each, and keeping track of that math is more complicated logic that will only + * rarely be used, since most zips that get merged together are likely to be produced + * specifically for that purpose (and therefore are unlikely to have a shebang or other + * preface). Finally, this preserves any data that might actually be desired.
+ */ + self.as_mut() + .pin_reader_assert() + .seek(io::SeekFrom::Start(0)) + .await?; + /* Find the end of the file data. */ + let length_to_read = shared.directory_start as usize; + + let inner = self.as_mut().pin_reader_mut_option().take().unwrap(); + /* Produce an AsyncRead that reads bytes up until the start of the central directory + * header. */ + let mut limited_raw = Limiter::take(0, inner, length_to_read); + io::copy(&mut limited_raw, &mut w).await?; + + let _ = self + .pin_reader_mut_option() + .insert(limited_raw.unwrap_inner_pin()); + + /* Return the files we've just written to the data stream. */ + Ok(new_files) + } + + ///``` + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use std::{io::Cursor, pin::Pin, sync::Arc}; + /// use tokio::{io::{self, AsyncReadExt, AsyncWriteExt}, fs}; + /// + /// let mut f = { + /// let buf = Cursor::new(Vec::new()); + /// let mut f = zip::tokio::write::ZipWriter::new(Box::pin(buf)); + /// let mut fp = Pin::new(&mut f); + /// let options = zip::write::FileOptions::default() + /// .compression_method(zip::CompressionMethod::Deflated); + /// fp.as_mut().start_file("a/b.txt", options).await?; + /// fp.write_all(b"hello\n").await?; + /// f.finish_into_readable().await? + /// }; + /// + /// let t = tempfile::tempdir()?; + /// + /// let root = t.path(); + /// Pin::new(&mut f).extract_simple(Arc::new(root.to_path_buf())).await?; + /// let msg = fs::read_to_string(root.join("a/b.txt")).await?; + /// assert_eq!(&msg, "hello\n"); + /// # Ok(()) + /// # })} + ///``` + pub async fn extract_simple(self: Pin<&mut Self>, root: Arc) -> ZipResult<()> { + fs::create_dir_all(&*root).await?; + + let entries = self.entries_stream(); + pin_mut!(entries); + + while let Some(mut file) = entries.try_next().await? { + let name = file.name()?; + let outpath = root.join(name); + + if CompletedPaths::is_dir(name) { + fs::create_dir_all(&outpath).await?; + } else { + if let Some(p) = outpath.parent() { + if !p.exists() { + fs::create_dir_all(p).await?; + } + } + let mut outfile = fs::File::create(&outpath).await?; + io::copy(&mut file, &mut outfile).await?; + } + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + if let Some(mode) = file.data().unix_mode() { + fs::set_permissions(&outpath, std::fs::Permissions::from_mode(mode)).await?; + } + } + } + Ok(()) + } + + ///``` + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use std::{io::Cursor, pin::Pin, sync::Arc}; + /// use tokio::{io::{self, AsyncReadExt, AsyncWriteExt}, fs}; + /// + /// let mut f = { + /// let buf = Cursor::new(Vec::new()); + /// let mut f = zip::tokio::write::ZipWriter::new(Box::pin(buf)); + /// let mut fp = Pin::new(&mut f); + /// let options = zip::write::FileOptions::default() + /// .compression_method(zip::CompressionMethod::Deflated); + /// fp.as_mut().start_file("a/b.txt", options).await?; + /// fp.write_all(b"hello\n").await?; + /// f.finish_into_readable().await? 
+ /// }; + /// + /// let t = tempfile::tempdir()?; + /// + /// let root = t.path(); + /// Pin::new(&mut f).extract(Arc::new(root.to_path_buf())).await?; + /// let msg = fs::read_to_string(root.join("a/b.txt")).await?; + /// assert_eq!(&msg, "hello\n"); + /// # Ok(()) + /// # })} + ///``` + pub async fn extract(self: Pin<&mut Self>, root: Arc) -> ZipResult<()> { + fs::create_dir_all(&*root).await?; + + let names: Vec<&Path> = self + .shared + .file_names() + .map(|name| Path::new::(unsafe { mem::transmute(name) })) + .collect(); + + let paths = Arc::new(sync::RwLock::new(CompletedPaths::new())); + let (path_tx, path_rx) = mpsc::unbounded_channel::<&Path>(); + let (compressed_tx, compressed_rx) = mpsc::unbounded_channel::<(&Path, Box<[u8]>)>(); + let (paired_tx, paired_rx) = mpsc::unbounded_channel::<(&Path, Box<[u8]>)>(); + + /* (1) Before we even start reading from the file handle, we know what our output paths are + * going to be from the ZipFileData, so create any necessary subdirectory structures. */ + let root2 = root.clone(); + let paths2 = paths.clone(); + let dirs_task = task::spawn(async move { + use futures_util::{stream, StreamExt}; + + let path_tx = &path_tx; + stream::iter(names.into_iter()) + .map(Ok) + .try_for_each_concurrent(None, move |name| { + let root2 = root2.clone(); + let paths2 = paths2.clone(); + async move { + /* dbg!(&name); */ + let new_dirs = paths2.read().await.new_containing_dirs_needed(&name); + for dir in new_dirs.into_iter() { + if paths2.read().await.contains(&dir) { + continue; + } + let full_dir = root2.join(&dir); + if create_dir_idempotent(full_dir).await?.is_some() { + paths2.write().await.confirm_dir(dir); + } + } + + path_tx.send(name).unwrap(); + + Ok::<_, ZipError>(()) + } + }) + .await?; + + Ok::<_, ZipError>(()) + }); + + /* (2) Match up the compressed buffers with the output paths they will be extracted to! */ + let shared = self.shared.clone(); + let matching_task = task::spawn(async move { + use futures_util::{select, FutureExt}; + use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt}; + + let mut path_rx = UnboundedReceiverStream::new(path_rx); + + let mut compressed_rx = UnboundedReceiverStream::new(compressed_rx); + + let mut remaining_unmatched_paths: IndexMap<&Path, (bool, Option>)> = shared + .files + .values() + .map(|data| { + data.enclosed_name() + .ok_or(ZipError::InvalidArchive("Invalid file path")) + .map(|name| (name, (false, None))) + }) + .collect::>>()?; + + let mut stopped_path = false; + let mut stopped_compressed = false; + loop { + let (name, val) = select! { + x = path_rx.next().fuse() => match x { + Some(name) => { + let val = remaining_unmatched_paths.get_mut(&name).unwrap(); + assert_eq!(val.0, false); + val.0 = true; + (name, val) + }, + None => { + stopped_path = true; + continue; + }, + }, + x = compressed_rx.next().fuse() => match x { + Some((name, buf)) => { + let val = remaining_unmatched_paths.get_mut(&name).unwrap(); + assert!(val.1.is_none()); + let _ = val.1.insert(buf); + (name, val) + }, + None => { + stopped_compressed = true; + continue; + }, + }, + complete => break, + }; + /* dbg!(&name); */ + if val.0 && val.1.is_some() { + let buf = mem::take(&mut val.1).unwrap(); + remaining_unmatched_paths.remove(&name).unwrap(); + paired_tx.send((name, buf)).unwrap(); + } + if stopped_path && stopped_compressed { + break; + } + if remaining_unmatched_paths.is_empty() { + break; + } + } + + Ok::<_, ZipError>(()) + }); + + /* (3) Attempt to offload decompression to as many threads as possible.
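+ * Each (path, compressed bytes) pair handed over by stage (2) is inflated and
+ * written to its output file concurrently via try_for_each_concurrent, so slow
+ * decompression never stalls the sequential read pass in stage (4) below.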
*/ + let shared = self.shared.clone(); + let root2 = root.clone(); + let decompress_task = task::spawn(async move { + use futures_util::StreamExt; + use tokio_stream::wrappers::UnboundedReceiverStream; + + let paired_rx = UnboundedReceiverStream::new(paired_rx); + paired_rx + .map(Ok) + .try_for_each_concurrent(None, move |(name, buf)| { + let shared = shared.clone(); + let root2 = root2.clone(); + async move { + /* dbg!(&name); */ + let data = shared.files.get(CompletedPaths::path_str(&name)).unwrap(); + + /* Get the file to write to. */ + let full_path = root2.join(&name); + let mut handle = fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(full_path) + .await?; + /* Set the length, in case this improves performance writing to the handle + * just below. */ + handle.set_len(data.uncompressed_size).await?; + + let uncompressed_size = data.uncompressed_size as usize; + /* We already know *exactly* how many bytes we will need to read out + * (because this info is recorded in the zip file entry), so we can + * allocate exactly that much to minimize allocation as well as + * blocking on memory availability for the decompressor. */ + let mut wrapped = BufReader::with_capacity( + num::NonZeroUsize::new(uncompressed_size).unwrap(), + Box::pin(ZipFileWrappedReader::construct( + data, + Box::pin(buf.as_ref()), + )), + ); + + assert_eq!( + uncompressed_size as u64, + /* NB: This appears to be faster than calling .read_to_end() and + * .write_all() with an intermediate buffer for some reason! */ + io::copy_buf(&mut wrapped, &mut handle).await? + ); + + cfg_if! { + if #[cfg(unix)] { + use std::os::unix::fs::PermissionsExt; + + if let Some(mode) = data.unix_mode() { + handle + .set_permissions(std::fs::Permissions::from_mode(mode)) + .await?; + } + handle.sync_all().await?; + } else { + handle.sync_data().await?; + } + } + + Ok::<_, ZipError>(()) + } + }) + .await?; + + Ok::<_, ZipError>(()) + }); + + /* (4) In order, scan off the raw memory for every file entry into a Box<[u8]> to avoid + * interleaving decompression with read i/o. */ + let entries = self.raw_entries_stream(); + pin_mut!(entries); + + while let Some(mut file) = entries.try_next().await? { + let name: &'static Path = unsafe { mem::transmute(file.name()?) }; + /* dbg!(&name); */ + if CompletedPaths::is_dir(&name) { + continue; + } + let compressed_size = file.data.compressed_size as usize; + + let mut compressed_contents: Vec = Vec::with_capacity(compressed_size); + assert_eq!( + compressed_size, + file.read_to_end(&mut compressed_contents).await? + ); + compressed_tx + .send((name, compressed_contents.into_boxed_slice())) + .unwrap(); + } + mem::drop(compressed_tx); + + dirs_task.await.expect("panic in subtask")?; + matching_task.await.expect("panic in subtask")?; + decompress_task.await.expect("panic in subtask")?; + + Ok(()) + } +} + +impl ZipArchive { + pub(crate) fn from_finalized_writer( + files: Vec, + comment: Vec, + stream: Pin>, + directory_start: u64, + ) -> ZipResult { + use rayon::prelude::*; + + /* This is where the whole file starts. */ + if let Some(initial_offset) = files.first().map(|d| d.header_start) { + let files: IndexMap = files + .into_par_iter() + .map(|d| (d.file_name.clone(), d)) + .collect(); + let shared = Shared { + files, + offset: initial_offset, + directory_start, + comment, + }; + Ok(Self { + reader: Some(stream), + shared: Arc::new(shared), + }) + } else { + /* We currently require at least 1 file in order to determine the `initial_offset`.
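+ * (An empty archive leaves us no first header to anchor the content start, so we
+ * surface an InvalidArchive error below rather than guessing an offset.)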
*/ + Err(ZipError::InvalidArchive( + "attempt to finalize empty zip writer into readable", + )) + } + } +} + +pub(crate) mod read_spec { + use crate::{ + compression::CompressionMethod, + result::{ZipError, ZipResult}, + spec::{self, CentralDirectoryHeaderBuffer}, + types::ZipFileData, + }; + + use byteorder::{ByteOrder, LittleEndian}; + use tokio::io::{self, AsyncReadExt, AsyncSeekExt}; + + use std::{mem, pin::Pin, slice}; + + /// Parse a central directory entry to collect the information for the file. + pub async fn central_header_to_zip_file( + mut reader: Pin<&mut R>, + archive_offset: u64, + ) -> ZipResult { + use crate::cp437::FromCp437; + use crate::types::{AtomicU64, DateTime}; + + let central_header_start = reader.stream_position().await?; + + static_assertions::assert_eq_size!([u8; 46], CentralDirectoryHeaderBuffer); + let mut info = [0u8; 46]; + reader.read_exact(&mut info[..]).await?; + + let CentralDirectoryHeaderBuffer { + magic, + version_made_by, + /* version_needed: _, */ + flag, + compression_method, + last_modified_time_timepart, + last_modified_time_datepart, + crc32, + compressed_size, + uncompressed_size, + file_name_length, + extra_field_length, + file_comment_length, + /* disk_number_start: _, */ + /* _internal_file_attributes, */ + external_attributes, + header_start, + .. + } = unsafe { mem::transmute(info) }; + + if magic != spec::CENTRAL_DIRECTORY_HEADER_SIGNATURE { + return Err(ZipError::InvalidArchive("Invalid Central Directory header")); + } + + let encrypted = flag & 1 == 1; + let is_utf8 = flag & (1 << 11) != 0; + let using_data_descriptor = flag & (1 << 3) != 0; + + let mut file_name_raw = vec![0; file_name_length as usize]; + reader.read_exact(&mut file_name_raw).await?; + let mut extra_field = vec![0; extra_field_length as usize]; + reader.read_exact(&mut extra_field).await?; + let mut file_comment_raw = vec![0; file_comment_length as usize]; + reader.read_exact(&mut file_comment_raw).await?; + + let file_name = match is_utf8 { + true => String::from_utf8_lossy(&file_name_raw).into_owned(), + false => file_name_raw.clone().from_cp437(), + }; + let file_comment = match is_utf8 { + true => String::from_utf8_lossy(&file_comment_raw).into_owned(), + false => file_comment_raw.from_cp437(), + }; + + // Construct the result + let mut result = ZipFileData { + system: ((version_made_by >> 8) as u8).into(), + version_made_by: version_made_by as u8, + encrypted, + using_data_descriptor, + compression_method: { + #[allow(deprecated)] + CompressionMethod::from_u16(compression_method) + }, + compression_level: None, + last_modified_time: DateTime::from_msdos( + last_modified_time_datepart, + last_modified_time_timepart, + ), + crc32, + compressed_size: compressed_size as u64, + uncompressed_size: uncompressed_size as u64, + file_name, + file_name_raw, + extra_field, + file_comment, + header_start: header_start as u64, + central_header_start, + data_start: AtomicU64::new(0), + external_attributes, + large_file: false, + aes_mode: None, + }; + + match parse_extra_field(&mut result).await { + Ok(..) | Err(ZipError::Io(..)) => {} + Err(e) => return Err(e), + } + + let aes_enabled = result.compression_method == CompressionMethod::AES; + if aes_enabled && result.aes_mode.is_none() { + return Err(ZipError::InvalidArchive( + "AES encryption without AES extra data field", + )); + } + + // Account for shifted zip offsets. 
+ result.header_start = result + .header_start + .checked_add(archive_offset) + .ok_or(ZipError::InvalidArchive("Archive header is too large"))?; + + Ok(result) + } + + async fn parse_extra_field(file: &mut ZipFileData) -> ZipResult<()> { + let mut reader = std::io::Cursor::new(&file.extra_field); + + while (reader.position() as usize) < file.extra_field.len() { + let mut buf = [0u8; 32]; + + reader.read_exact(&mut buf[..]).await?; + + static_assertions::assert_eq_size!([u8; 32], (u16, u16, u64, u64, u64)); + + LittleEndian::from_slice_u16(unsafe { + slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u16, 14) + }); + + let args: (u16, u16, u64, u64, u64) = unsafe { mem::transmute(buf) }; + let (kind, len, uncompressed_size, compressed_size, header_start) = args; + + let mut len_left = len as i64; + match kind { + // Zip64 extended information extra field + 0x0001 => { + if file.uncompressed_size >= spec::ZIP64_BYTES_THR { + file.large_file = true; + file.uncompressed_size = uncompressed_size; + len_left -= 8; + } + if file.compressed_size >= spec::ZIP64_BYTES_THR { + file.large_file = true; + file.compressed_size = compressed_size; + len_left -= 8; + } + if file.header_start == spec::ZIP64_BYTES_THR { + file.header_start = header_start; + len_left -= 8; + } + } + _ => { + // Other fields are ignored + } + } + + assert!(len_left >= 0); + if len_left > 0 { + reader.seek(io::SeekFrom::Current(len_left)).await?; + } + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + compression::CompressionMethod, + tokio::{combinators::KnownExpanse, write::ZipWriter}, + write::FileOptions, + }; + + use tokio::io::AsyncWriteExt; + + use std::io::Cursor; + + #[tokio::test] + async fn test_find_content() -> ZipResult<()> { + let f = { + let buf = Cursor::new(Vec::new()); + let mut f = ZipWriter::new(Box::pin(buf)); + let mut fp = Pin::new(&mut f); + let options = FileOptions::default().compression_method(CompressionMethod::Stored); + fp.as_mut().start_file("a/b.txt", options).await?; + fp.write_all(b"hello\n").await?; + f.finish_into_readable().await? + }; + + assert_eq!(1, f.shared.len()); + let data = f + .shared + .contiguous_entries() + .get_index(0) + .unwrap() + .1 + .clone(); + assert_eq!("a/b.txt", &data.file_name); + + let mut limited = find_content(&data, f.unwrap_inner_pin()).await?; + + let mut buf = String::new(); + limited.read_to_string(&mut buf).await?; + assert_eq!(&buf, "hello\n"); + + Ok(()) + } + + #[tokio::test] + async fn test_get_reader() -> ZipResult<()> { + let f = { + let buf = Cursor::new(Vec::new()); + let mut f = ZipWriter::new(Box::pin(buf)); + let mut fp = Pin::new(&mut f); + let options = FileOptions::default().compression_method(CompressionMethod::Deflated); + fp.as_mut().start_file("a/b.txt", options).await?; + fp.write_all(b"hello\n").await?; + f.finish_into_readable().await? + }; + + assert_eq!(1, f.shared.len()); + let data = f + .shared + .contiguous_entries() + .get_index(0) + .unwrap() + .1 + .clone(); + assert_eq!(data.crc32, 909783072); + assert_eq!("a/b.txt", &data.file_name); + + let mut limited = find_content(&data, f.unwrap_inner_pin()).await?; + + let mut buf: Vec = Vec::new(); + io::AsyncReadExt::read_to_end(&mut limited, &mut buf).await?; + /* This is compressed, so it should NOT match! 
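+ * (We can still verify its exact length: the deflated stream must have precisely
+ * the compressed length recorded in the central directory, as the asserts below check.)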
*/ + assert_ne!(&buf, b"hello\n"); + assert_eq!(buf.len(), limited.full_length()); + assert_eq!(buf.len(), data.compressed_size as usize); + assert_eq!(b"hello\n".len(), data.uncompressed_size as usize); + + io::AsyncSeekExt::rewind(&mut limited).await?; + /* This stream should decode the compressed content! */ + let mut decoded = ZipFileWrappedReader::construct(&data, Box::pin(limited)); + let mut buf = String::new(); + io::AsyncReadExt::read_to_string(&mut decoded, &mut buf).await?; + assert_eq!(&buf, "hello\n"); + + Ok(()) + } +} diff --git a/src/tokio/stream_impls.rs b/src/tokio/stream_impls.rs new file mode 100755 index 000000000..319cc18ec --- /dev/null +++ b/src/tokio/stream_impls.rs @@ -0,0 +1,556 @@ +/// As 2 separate read calls: +///``` +/// # fn main() -> std::io::Result<()> { tokio_test::block_on(async { +/// use zip::tokio::{WrappedPin, buf_reader::BufReader, buf_writer::BufWriter, stream_impls::deflate::*}; +/// use flate2::{Decompress, Compress, Compression}; +/// use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; +/// use std::{io::Cursor, pin::Pin, num::NonZeroUsize}; +/// +/// let msg = "hello"; +/// let c = Compression::default(); +/// let buf_reader = BufReader::new(Box::pin(Cursor::new(msg.as_bytes()))); +/// let mut def = Reader::with_state(Compress::new(c, false), Box::pin(buf_reader)); +/// +/// let mut buf = Vec::new(); +/// { +/// use tokio::io::{AsyncReadExt, AsyncSeekExt}; +/// def.read_to_end(&mut buf).await?; +/// assert_eq!(&buf, &[203, 72, 205, 201, 201, 7, 0]); +/// } +/// +/// let buf_writer = BufWriter::new(Box::pin(Cursor::new(Vec::new()))); +/// let mut out_inf = Writer::with_state( +/// Decompress::new(false), +/// Box::pin(buf_writer), +/// ); +/// { +/// use tokio::io::{AsyncReadExt, AsyncSeekExt}; +/// out_inf.write_all(&buf).await?; +/// out_inf.flush().await?; +/// out_inf.shutdown().await?; +/// let buf: Vec = Pin::into_inner(Pin::into_inner(out_inf.unwrap_inner_pin()).unwrap_inner_pin()).into_inner(); +/// assert_eq!(&buf[..], b"hello"); +/// } +/// # Ok(()) +/// # })} +///``` +/// +/// Or, within a single `tokio::io::copy{,_buf}()`: +///``` +/// # fn main() -> std::io::Result<()> { tokio_test::block_on(async { +/// use zip::tokio::{WrappedPin, buf_reader::BufReader, buf_writer::BufWriter, stream_impls::deflate::*}; +/// use flate2::{Decompress, Compress, Compression}; +/// use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; +/// use std::{io::Cursor, pin::Pin, num::NonZeroUsize}; +/// +/// let msg = "hello"; +/// let c = Compression::default(); +/// let buf_reader = BufReader::new(Box::pin(Cursor::new(msg.as_bytes()))); +/// let mut def = Reader::with_state(Compress::new(c, false), Box::pin(buf_reader)); +/// +/// let buf_writer = BufWriter::new(Box::pin(Cursor::new(Vec::new()))); +/// let mut out_inf = Writer::with_state( +/// Decompress::new(false), +/// Box::pin(buf_writer), +/// ); +/// +/// io::copy(&mut def, &mut out_inf).await?; +/// out_inf.flush().await?; +/// out_inf.shutdown().await?; +/// +/// let final_buf: Vec = Pin::into_inner(Pin::into_inner(out_inf.unwrap_inner_pin()).unwrap_inner_pin()).into_inner(); +/// let s = std::str::from_utf8(&final_buf[..]).unwrap(); +/// assert_eq!(&s, &msg); +/// # Ok(()) +/// # })} +///``` +#[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" +))] +pub mod deflate { + use crate::tokio::{ + buf_writer::{AsyncBufWrite, NonEmptyWriteSlice}, + WrappedPin, + }; + + use flate2::{ + Compress, CompressError, Decompress, DecompressError, FlushCompress, FlushDecompress, + 
Status,
+    };
+    use tokio::io;
+
+    use std::{
+        fmt,
+        num::NonZeroUsize,
+        pin::Pin,
+        task::{ready, Context, Poll},
+    };
+
+    pub trait Ops {
+        type Flush: Flush;
+        type E: fmt::Display;
+        fn total_in(&self) -> u64;
+        fn total_out(&self) -> u64;
+        fn encode_frame(
+            &mut self,
+            input: &[u8],
+            output: &mut [u8],
+            flush: Self::Flush,
+        ) -> Result<Status, Self::E>;
+    }
+
+    /// Compress:
+    ///```
+    /// # fn main() -> std::io::Result<()> { tokio_test::block_on(async {
+    /// use zip::tokio::{stream_impls::deflate::Reader, buf_reader::BufReader};
+    /// use flate2::{Compress, Compression};
+    /// use tokio::io::{self, AsyncReadExt, AsyncBufRead};
+    /// use std::{io::Cursor, pin::Pin};
+    ///
+    /// let msg = "hello";
+    /// let c = Compression::default();
+    /// let buf = BufReader::new(Box::pin(Cursor::new(msg.as_bytes())));
+    /// let mut def = Reader::with_state(Compress::new(c, false), Box::pin(buf));
+    ///
+    /// let mut b = Vec::new();
+    /// def.read_to_end(&mut b).await?;
+    /// assert_eq!(&b, &[203, 72, 205, 201, 201, 7, 0]);
+    /// # Ok(())
+    /// # })}
+    ///```
+    ///
+    /// Decompress:
+    ///```
+    /// # fn main() -> std::io::Result<()> { tokio_test::block_on(async {
+    /// use zip::tokio::{stream_impls::deflate::Reader, buf_reader::BufReader};
+    /// use flate2::Decompress;
+    /// use tokio::io::{self, AsyncReadExt};
+    /// use std::{io::Cursor, pin::Pin};
+    ///
+    /// let msg: &[u8] = &[203, 72, 205, 201, 201, 7, 0];
+    /// let buf = BufReader::new(Box::pin(Cursor::new(msg)));
+    /// let mut inf = Reader::with_state(Decompress::new(false), Box::pin(buf));
+    ///
+    /// let mut s = String::new();
+    /// inf.read_to_string(&mut s).await?;
+    /// assert_eq!(&s, "hello");
+    /// # Ok(())
+    /// # })}
+    ///```
+    pub struct Reader<O, S> {
+        state: O,
+        inner: Pin<Box<S>>,
+    }
+
+    struct ReadProj<'a, O, S> {
+        pub state: &'a mut O,
+        pub inner: Pin<&'a mut S>,
+    }
+
+    impl<O, S> Reader<O, S> {
+        #[inline]
+        fn pin_stream(self: Pin<&mut Self>) -> Pin<&mut S> {
+            self.project().inner
+        }
+
+        #[inline]
+        fn project(self: Pin<&mut Self>) -> ReadProj<'_, O, S> {
+            unsafe {
+                let Self { inner, state } = self.get_unchecked_mut();
+                ReadProj {
+                    inner: Pin::new_unchecked(inner.as_mut().get_unchecked_mut()),
+                    state,
+                }
+            }
+        }
+    }
+
+    impl<O, S> WrappedPin<S> for Reader<O, S> {
+        fn unwrap_inner_pin(self) -> Pin<Box<S>> {
+            self.inner
+        }
+    }
+
+    impl<O, S> Reader<O, S> {
+        pub fn with_state(state: O, inner: Pin<Box<S>>) -> Self {
+            Self { state, inner }
+        }
+    }
+
+    impl<O, S: io::AsyncBufRead> io::AsyncBufRead for Reader<O, S> {
+        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
+            self.pin_stream().poll_fill_buf(cx)
+        }
+        fn consume(self: Pin<&mut Self>, amt: usize) {
+            self.pin_stream().consume(amt);
+        }
+    }
+
+    impl<O: Ops, S: io::AsyncBufRead> io::AsyncRead for Reader<O, S> {
+        fn poll_read(
+            self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+            buf: &mut io::ReadBuf<'_>,
+        ) -> Poll<io::Result<()>> {
+            debug_assert!(buf.remaining() > 0);
+
+            let mut me = self.project();
+
+            loop {
+                let input: &[u8] = ready!(me.inner.as_mut().poll_fill_buf(cx))?;
+                let eof: bool = input.is_empty();
+                let (before_out, before_in): (u64, u64) =
+                    { (me.state.total_out(), me.state.total_in()) };
+                let flush = if eof {
+                    O::Flush::finish()
+                } else {
+                    O::Flush::none()
+                };
+
+                let ret = me
+                    .state
+                    .encode_frame(input, buf.initialize_unfilled(), flush);
+
+                let (num_read, num_consumed): (usize, usize) = (
+                    (me.state.total_out() - before_out) as usize,
+                    (me.state.total_in() - before_in) as usize,
+                );
+
+                buf.set_filled(buf.filled().len() + num_read);
+                me.inner.as_mut().consume(num_consumed);
+
+                match ret {
+                    Ok(Status::Ok | Status::BufError) if num_read == 0 && !eof => (),
+                    Ok(Status::Ok | Status::BufError | Status::StreamEnd) => {
+                        return Poll::Ready(Ok(()));
+                    }
+                    Err(e) => {
+                        return Poll::Ready(Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("corrupt read stream({})", e),
+                        )))
+                    }
+                }
+            }
+        }
+    }
+
+    /// Compress:
+    ///```
+    /// # fn main() -> std::io::Result<()> { tokio_test::block_on(async {
+    /// use zip::tokio::{WrappedPin, stream_impls::deflate::Writer, buf_writer::BufWriter};
+    /// use flate2::{Compress, Compression};
+    /// use tokio::io::{self, AsyncWriteExt};
+    /// use std::{io::Cursor, pin::Pin};
+    ///
+    /// let msg = "hello";
+    /// let c = Compression::default();
+    /// let buf = BufWriter::new(Box::pin(Cursor::new(Vec::new())));
+    /// let mut def = Writer::with_state(
+    ///     Compress::new(c, false),
+    ///     Box::pin(buf),
+    /// );
+    ///
+    /// def.write_all(msg.as_bytes()).await?;
+    /// def.flush().await?;
+    /// def.shutdown().await?;
+    /// let buf: Vec<u8> = Pin::into_inner(Pin::into_inner(def.unwrap_inner_pin()).unwrap_inner_pin()).into_inner();
+    /// let expected: &[u8] = &[202, 72, 205, 201, 201, 7, 0, 0, 0, 255, 255, 3, 0];
+    /// assert_eq!(&buf[..], expected);
+    /// # Ok(())
+    /// # })}
+    ///```
+    ///
+    /// Decompress:
+    ///```
+    /// # fn main() -> std::io::Result<()> { tokio_test::block_on(async {
+    /// use zip::tokio::{WrappedPin, buf_writer::BufWriter, stream_impls::deflate::Writer};
+    /// use flate2::Decompress;
+    /// use tokio::io::{self, AsyncWriteExt};
+    /// use std::{cmp, io::Cursor, pin::Pin};
+    ///
+    /// let msg: &[u8] = &[202, 72, 205, 201, 201, 231, 2, 0, 0, 0, 255, 255, 3, 0];
+    /// let buf = BufWriter::new(Box::pin(Cursor::new(Vec::new())));
+    /// let mut inf = Writer::with_state(Decompress::new(false), Box::pin(buf));
+    ///
+    /// inf.write_all(msg).await?;
+    /// inf.flush().await?;
+    /// inf.shutdown().await?;
+    /// let buf: Vec<u8> = Pin::into_inner(Pin::into_inner(inf.unwrap_inner_pin()).unwrap_inner_pin()).into_inner();
+    /// let expected = b"hello\n";
+    /// assert_eq!(&buf, &expected);
+    /// # Ok(())
+    /// # })}
+    ///```
+    pub struct Writer<O, S> {
+        state: O,
+        inner: Pin<Box<S>>,
+    }
+
+    struct WriteProj<'a, O, S> {
+        pub state: &'a mut O,
+        pub inner: Pin<&'a mut S>,
+    }
+
+    impl<O, S> Writer<O, S> {
+        #[inline]
+        fn pin_stream(self: Pin<&mut Self>) -> Pin<&mut S> {
+            self.project().inner
+        }
+
+        #[inline]
+        fn project(self: Pin<&mut Self>) -> WriteProj<'_, O, S> {
+            unsafe {
+                let Self { inner, state } = self.get_unchecked_mut();
+                WriteProj {
+                    inner: Pin::new_unchecked(inner.as_mut().get_unchecked_mut()),
+                    state,
+                }
+            }
+        }
+    }
+
+    impl<O, S> WrappedPin<S> for Writer<O, S> {
+        fn unwrap_inner_pin(self) -> Pin<Box<S>> {
+            self.inner
+        }
+    }
+
+    impl<O, S> Writer<O, S> {
+        pub fn with_state(state: O, inner: Pin<Box<S>>) -> Self {
+            Self { state, inner }
+        }
+    }
+
+    impl<O: Ops, S: AsyncBufWrite> io::AsyncWrite for Writer<O, S> {
+        fn poll_write(
+            self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+            buf: &[u8],
+        ) -> Poll<io::Result<usize>> {
+            debug_assert!(buf.len() > 0);
+
+            let mut me = self.project();
+
+            loop {
+                let mut write_buf: NonEmptyWriteSlice<'_, u8> =
+                    ready!(me.inner.as_mut().poll_writable(cx))?;
+
+                let before_in = me.state.total_in();
+                let before_out = me.state.total_out();
+
+                let status = me
+                    .state
+                    .encode_frame(buf, &mut *write_buf, O::Flush::none())
+                    .map_err(|e| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("corrupt write stream({})/1", e),
+                        )
+                    })?;
+
+                let num_read = (me.state.total_in() - before_in) as usize;
+                let num_consumed = (me.state.total_out() - before_out) as usize;
+
+                if let Some(num_consumed) = NonZeroUsize::new(num_consumed) {
+                    me.inner.as_mut().consume_write(num_consumed);
+                }
+
+                match (num_read, status) {
+                    (0, Status::Ok | Status::BufError) => {
+                        continue;
+                    }
+                    (n, _) => {
+                        return Poll::Ready(Ok(n));
+                    }
+                }
+            }
+        }
+
+        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+            let mut me = self.project();
+
+            {
+                let mut write_buf: NonEmptyWriteSlice<'_, u8> =
+                    ready!(me.inner.as_mut().poll_writable(cx))?;
+
+                let before_out = me.state.total_out();
+
+                me.state
+                    .encode_frame(&[], &mut *write_buf, O::Flush::sync())
+                    .map_err(|e| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("corrupt write stream({})/2", e),
+                        )
+                    })?;
+
+                let num_consumed = (me.state.total_out() - before_out) as usize;
+                if let Some(num_consumed) = NonZeroUsize::new(num_consumed) {
+                    me.inner.as_mut().consume_write(num_consumed);
+                }
+            }
+
+            loop {
+                let mut write_buf = ready!(me.inner.as_mut().poll_writable(cx))?;
+
+                let before_out = me.state.total_out();
+
+                me.state
+                    .encode_frame(&[], &mut *write_buf, O::Flush::none())
+                    .map_err(|e| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("corrupt write stream({})/3", e),
+                        )
+                    })?;
+
+                let num_consumed = (me.state.total_out() - before_out) as usize;
+
+                if let Some(num_consumed) = NonZeroUsize::new(num_consumed) {
+                    me.inner.as_mut().consume_write(num_consumed);
+                } else {
+                    break;
+                }
+            }
+
+            me.inner.poll_flush(cx)
+        }
+
+        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+            let mut me = self.project();
+
+            loop {
+                let mut write_buf = ready!(me.inner.as_mut().poll_writable(cx))?;
+
+                let before_out = me.state.total_out();
+
+                me.state
+                    .encode_frame(&[], &mut *write_buf, O::Flush::finish())
+                    .map_err(|e| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("corrupt write stream({})/4", e),
+                        )
+                    })?;
+
+                let num_consumed = (me.state.total_out() - before_out) as usize;
+                if let Some(num_consumed) = NonZeroUsize::new(num_consumed) {
+                    me.inner.as_mut().consume_write(num_consumed);
+                } else {
+                    break;
+                }
+            }
+
+            me.inner.poll_shutdown(cx)
+        }
+    }
+
+    impl<O, S: AsyncBufWrite> AsyncBufWrite for Writer<O, S> {
+        #[inline]
+        fn consume_read(self: Pin<&mut Self>, amt: NonZeroUsize) {
+            self.pin_stream().consume_read(amt);
+        }
+        #[inline]
+        fn readable_data(&self) -> &[u8] {
+            self.inner.readable_data()
+        }
+
+        #[inline]
+        fn consume_write(self: Pin<&mut Self>, amt: NonZeroUsize) {
+            self.pin_stream().consume_write(amt);
+        }
+        #[inline]
+        fn try_writable(self: Pin<&mut Self>) -> Option<NonEmptyWriteSlice<'_, u8>> {
+            self.pin_stream().try_writable()
+        }
+        #[inline]
+        fn poll_writable(
+            self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+        ) -> Poll<io::Result<NonEmptyWriteSlice<'_, u8>>> {
+            self.pin_stream().poll_writable(cx)
+        }
+
+        #[inline]
+        fn reset(self: Pin<&mut Self>) {
+            self.pin_stream().reset();
+        }
+    }
+
+    impl Ops for Compress {
+        type Flush = FlushCompress;
+        type E = CompressError;
+        #[inline]
+        fn total_in(&self) -> u64 {
+            self.total_in()
+        }
+        #[inline]
+        fn total_out(&self) -> u64 {
+            self.total_out()
+        }
+        #[inline]
+        fn encode_frame(
+            &mut self,
+            input: &[u8],
+            output: &mut [u8],
+            flush: Self::Flush,
+        ) -> Result<Status, Self::E> {
+            self.compress(input, output, flush)
+        }
+    }
+
+    impl Ops for Decompress {
+        type Flush = FlushDecompress;
+        type E = DecompressError;
+        #[inline]
+        fn total_in(&self) -> u64 {
+            self.total_in()
+        }
+        #[inline]
+        fn total_out(&self) -> u64 {
+            self.total_out()
+        }
+        #[inline]
+        fn encode_frame(
+            &mut self,
+            input: &[u8],
+            output: &mut [u8],
+            flush: Self::Flush,
+        ) -> Result<Status, Self::E> {
+            self.decompress(input, output, flush)
+        }
+    }
+
+    pub trait Flush {
+        fn none() -> Self;
+        fn sync() -> Self;
+        fn finish() -> Self;
+    }
+
+    impl Flush for FlushCompress {
+        fn none() -> Self {
+            Self::None
+        }
+        fn sync() -> Self {
+            Self::Sync
+        }
+        fn finish() -> Self {
+            Self::Finish
+        }
+    }
+
+    impl Flush for FlushDecompress {
+        fn none() -> Self {
+            Self::None
+        }
+        fn sync() -> Self {
+            Self::Sync
+        }
+        fn finish() -> Self {
+            Self::Finish
+        }
+    }
+}
diff --git a/src/tokio/write.rs b/src/tokio/write.rs
new file mode 100644
index 000000000..69ed333ef
--- /dev/null
+++ b/src/tokio/write.rs
@@ -0,0 +1,1315 @@
+use crate::{
+    compression::CompressionMethod,
+    result::{ZipError, ZipResult},
+    spec,
+    tokio::{
+        buf_writer::BufWriter,
+        read::{read_spec, Shared, ZipArchive},
+        stream_impls::deflate,
+        utils::map_swap_uninit,
+        WrappedPin,
+    },
+    types::{ZipFileData, DEFAULT_VERSION},
+    write::{FileOptions, ZipRawValues, ZipWriterStats},
+};
+
+use tokio::io::{self, AsyncSeekExt, AsyncWriteExt};
+
+use std::{
+    cmp, mem, ops,
+    pin::Pin,
+    ptr,
+    task::{ready, Context, Poll},
+};
+
+#[cfg(any(
+    feature = "deflate",
+    feature = "deflate-miniz",
+    feature = "deflate-zlib"
+))]
+use flate2::{Compress, Compression};
+
+pub struct StoredWriter<S>(Pin<Box<S>>);
+
+impl<S> StoredWriter<S> {
+    pub fn new(inner: Pin<Box<S>>) -> Self {
+        Self(inner)
+    }
+
+    #[inline]
+    fn pin_stream(self: Pin<&mut Self>) -> Pin<&mut S> {
+        unsafe { self.get_unchecked_mut() }.0.as_mut()
+    }
+}
+
+impl<S> WrappedPin<S> for StoredWriter<S> {
+    fn unwrap_inner_pin(self) -> Pin<Box<S>> {
+        self.0
+    }
+}
+
+impl<S: io::AsyncWrite> io::AsyncWrite for StoredWriter<S> {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.pin_stream().poll_write(cx, buf)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.pin_stream().poll_flush(cx)
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.pin_stream().poll_shutdown(cx)
+    }
+}
+
+pub struct DeflatedWriter<S>(deflate::Writer<Compress, BufWriter<S>>);
+
+impl<S> DeflatedWriter<S> {
+    #[inline]
+    fn pin_stream(self: Pin<&mut Self>) -> Pin<&mut deflate::Writer<Compress, BufWriter<S>>> {
+        unsafe { self.map_unchecked_mut(|Self(inner)| inner) }
+    }
+}
+
+impl<S: io::AsyncWrite> DeflatedWriter<S> {
+    pub fn new(compression: Compression, inner: Pin<Box<S>>) -> Self {
+        let compress = Compress::new(compression, false);
+        let buf_writer = BufWriter::new(inner);
+        Self(deflate::Writer::with_state(compress, Box::pin(buf_writer)))
+    }
+}
+
+impl<S> WrappedPin<S> for DeflatedWriter<S> {
+    fn unwrap_inner_pin(self) -> Pin<Box<S>> {
+        Pin::into_inner(self.0.unwrap_inner_pin()).unwrap_inner_pin()
+    }
+}
+
+impl<S: io::AsyncWrite> io::AsyncWrite for DeflatedWriter<S> {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.pin_stream().poll_write(cx, buf)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.pin_stream().poll_flush(cx)
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.pin_stream().poll_shutdown(cx)
+    }
+}
+
+pub enum ZipFileWrappedWriter<S> {
+    Stored(StoredWriter<S>),
+    Deflated(DeflatedWriter<S>),
+}
+
+enum WrappedProj<'a, S> {
+    Stored(Pin<&'a mut StoredWriter<S>>),
+    Deflated(Pin<&'a mut DeflatedWriter<S>>),
+}
+
+impl<S> ZipFileWrappedWriter<S> {
+    #[inline]
+    fn project(self: Pin<&mut Self>) -> WrappedProj<'_, S> {
+        unsafe {
+            let s = self.get_unchecked_mut();
+            match s {
+                Self::Stored(s) => WrappedProj::Stored(Pin::new_unchecked(s)),
+                Self::Deflated(s) => WrappedProj::Deflated(Pin::new_unchecked(s)),
+            }
+        }
+    }
+}
+
+impl<S: io::AsyncWrite> io::AsyncWrite for ZipFileWrappedWriter<S> {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf:
&[u8], + ) -> Poll> { + match self.project() { + WrappedProj::Stored(w) => w.poll_write(cx, buf), + WrappedProj::Deflated(w) => w.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + WrappedProj::Stored(w) => w.poll_flush(cx), + WrappedProj::Deflated(w) => w.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + WrappedProj::Stored(w) => w.poll_shutdown(cx), + WrappedProj::Deflated(w) => w.poll_shutdown(cx), + } + } +} + +impl WrappedPin for ZipFileWrappedWriter { + fn unwrap_inner_pin(self) -> Pin> { + match self { + Self::Stored(s) => s.unwrap_inner_pin(), + Self::Deflated(s) => s.unwrap_inner_pin(), + } + } +} + +enum InnerWriter { + NoActiveFile(Pin>), + FileWriter(ZipFileWrappedWriter), +} + +enum InnerProj<'a, S> { + NoActiveFile(Pin<&'a mut S>), + FileWriter(Pin<&'a mut ZipFileWrappedWriter>), +} + +impl InnerWriter { + #[inline] + pub(crate) fn project(self: Pin<&mut Self>) -> InnerProj<'_, S> { + unsafe { + let s = self.get_unchecked_mut(); + match s { + Self::NoActiveFile(s) => InnerProj::NoActiveFile(s.as_mut()), + Self::FileWriter(s) => InnerProj::FileWriter(Pin::new_unchecked(s)), + } + } + } +} + +enum WriterWrapResult { + Stored, + Deflated(Compression), +} + +impl WriterWrapResult { + pub(crate) fn wrap(self, s: Pin>) -> ZipFileWrappedWriter { + match self { + Self::Stored => ZipFileWrappedWriter::Stored(StoredWriter::new(s)), + Self::Deflated(c) => ZipFileWrappedWriter::Deflated(DeflatedWriter::new(c, s)), + } + } +} + +impl InnerWriter { + /// Returns `directory_start`. + pub(crate) async fn finalize( + self: Pin<&mut Self>, + files: &[ZipFileData], + comment: &[u8], + ) -> ZipResult { + match self.project() { + InnerProj::FileWriter(_) => unreachable!("stream should be unwrapped!"), + InnerProj::NoActiveFile(mut inner) => { + let central_start = inner.stream_position().await?; + for file in files.iter() { + write_spec::write_central_directory_header(inner.as_mut(), file).await?; + } + let central_size = inner.stream_position().await? - central_start; + + /* If we have to create a zip64 file, generate the appropriate additional footer. 
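+                     * (That is, the zip64 end of central directory record followed
+                     * by the zip64 end of central directory locator, which the two
+                     * `zip64_footer` values constructed below write out.)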
*/ + if files.len() > spec::ZIP64_ENTRY_THR + || central_size.max(central_start) > spec::ZIP64_BYTES_THR + { + let zip64_footer = spec::Zip64CentralDirectoryEnd { + version_made_by: DEFAULT_VERSION as u16, + version_needed_to_extract: DEFAULT_VERSION as u16, + disk_number: 0, + disk_with_central_directory: 0, + number_of_files_on_this_disk: files.len() as u64, + number_of_files: files.len() as u64, + central_directory_size: central_size, + central_directory_offset: central_start, + }; + zip64_footer.write_async(inner.as_mut()).await?; + + let zip64_footer = spec::Zip64CentralDirectoryEndLocator { + disk_with_central_directory: 0, + end_of_central_directory_offset: central_start + central_size, + number_of_disks: 1, + }; + zip64_footer.write_async(inner.as_mut()).await?; + } + + let number_of_files = files.len().min(spec::ZIP64_ENTRY_THR) as u16; + let footer = spec::CentralDirectoryEnd { + disk_number: 0, + disk_with_central_directory: 0, + zip_file_comment: comment.to_vec(), + number_of_files_on_this_disk: number_of_files, + number_of_files, + central_directory_size: central_size.min(spec::ZIP64_BYTES_THR) as u32, + central_directory_offset: central_start.min(spec::ZIP64_BYTES_THR) as u32, + }; + footer.write_async(inner.as_mut()).await?; + + Ok(central_start) + } + } + } + + pub(crate) async fn initialize_entry( + self: Pin<&mut Self>, + raw: ZipRawValues, + options: FileOptions, + file_name: String, + stats: &mut ZipWriterStats, + ) -> ZipResult { + match self.project() { + InnerProj::FileWriter(_) => unreachable!("stream should be unwrapped!"), + InnerProj::NoActiveFile(mut inner) => { + let header_start = inner.stream_position().await?; + + let mut file = ZipFileData::initialize(raw, options, header_start, file_name); + write_spec::write_local_file_header(inner.as_mut(), &file).await?; + + let header_end = inner.stream_position().await?; + stats.start = header_end; + *file.data_start.get_mut() = header_end; + + stats.bytes_written = 0; + stats.hasher.reset(); + + Ok(file) + } + } + } + + pub(crate) async fn update_header_and_unwrap_stream( + mut self: Pin<&mut Self>, + stats: &mut ZipWriterStats, + file: &mut ZipFileData, + ) -> ZipResult<()> { + match self.as_mut().project() { + InnerProj::NoActiveFile(mut inner) => { + inner.shutdown().await?; + } + InnerProj::FileWriter(mut inner) => { + /* NB: we need to ensure the compression stream writes out everything it has left + * before reclaiming the stream handle! 
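+                 * (Concretely, this `shutdown()` ends up in
+                 * `deflate::Writer::poll_shutdown()`, which repeatedly calls
+                 * `encode_frame()` with `Flush::finish()` and only returns once the
+                 * encoder produces no further output.)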
*/ + inner.shutdown().await?; + } + } + + let s = self.get_mut(); + + map_swap_uninit(s, |s| Self::NoActiveFile(s.unwrap_inner_pin())); + + match s { + Self::NoActiveFile(ref mut inner) => { + file.crc32 = mem::take(&mut stats.hasher).finalize(); + file.uncompressed_size = stats.bytes_written; + + let file_end = inner.stream_position().await?; + file.compressed_size = file_end - stats.start; + + write_spec::update_local_file_header(inner.as_mut(), file).await?; + inner.seek(io::SeekFrom::Start(file_end)).await?; + } + _ => unreachable!(), + } + Ok(()) + } + + pub(crate) fn wrap_compressor_stream( + self: Pin<&mut Self>, + file: &ZipFileData, + ) -> ZipResult<()> { + let s = self.get_mut(); + let wrap_result = Self::try_wrap_writer(file.compression_method, file.compression_level)?; + map_swap_uninit(s, move |s| match s { + Self::FileWriter(_) => unreachable!("writer should be unwrapped!"), + Self::NoActiveFile(s) => Self::FileWriter(wrap_result.wrap(s)), + }); + Ok(()) + } + + pub(crate) async fn write_extra_data_and_wrap_compressor_stream( + self: Pin<&mut Self>, + stats: &mut ZipWriterStats, + file: &mut ZipFileData, + ) -> ZipResult<()> { + let s = self.get_mut(); + match s { + Self::FileWriter(_) => unreachable!("writer should be unwrapped!"), + Self::NoActiveFile(inner) => { + let data_start = file.data_start.get_mut(); + + // Append extra data to local file header and keep it for central file header. + inner.write_all(&file.extra_field).await?; + + // Update final `data_start`. + let header_end = *data_start + file.extra_field.len() as u64; + stats.start = header_end; + *data_start = header_end; + + // Update extra field length in local file header. + let extra_field_length = + if file.large_file { 20 } else { 0 } + file.extra_field.len() as u16; + inner + .seek(io::SeekFrom::Start(file.header_start + 28)) + .await?; + inner.write_u16_le(extra_field_length).await?; + inner.seek(io::SeekFrom::Start(header_end)).await?; + } + } + + Pin::new(s).wrap_compressor_stream(file)?; + + Ok(()) + } +} + +impl InnerWriter { + pub fn try_wrap_writer( + compression: CompressionMethod, + compression_level: Option, + ) -> ZipResult { + match compression { + CompressionMethod::Stored => { + if compression_level.is_some() { + return Err(ZipError::UnsupportedArchive( + "Unsupported compression level", + )); + } + Ok(WriterWrapResult::Stored) + } + CompressionMethod::Deflated => { + let compression = Compression::new( + clamp_opt( + compression_level.unwrap_or(Compression::default().level() as i32), + deflate_compression_level_range(), + ) + .ok_or(ZipError::UnsupportedArchive( + "Unsupported compression level", + ))? 
as u32, + ); + Ok(WriterWrapResult::Deflated(compression)) + } + _ => todo!("other compression methods not yet supported!"), + } + } +} + +fn clamp_opt(value: T, range: ops::RangeInclusive) -> Option { + if range.contains(&value) { + Some(value) + } else { + None + } +} + +fn deflate_compression_level_range() -> ops::RangeInclusive { + let min = Compression::none().level() as i32; + let max = Compression::best().level() as i32; + min..=max +} + +impl WrappedPin for InnerWriter { + fn unwrap_inner_pin(self) -> Pin> { + match self { + Self::NoActiveFile(s) => s, + Self::FileWriter(s) => s.unwrap_inner_pin(), + } + } +} + +/// To a [`Cursor`](std::io::Cursor): +///``` +/// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { +/// use zip::{write::FileOptions, tokio::{read::ZipArchive, write::ZipWriter}}; +/// use std::{io::Cursor, pin::Pin}; +/// use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; +/// +/// let buf = Cursor::new(Vec::new()); +/// let mut f = ZipWriter::new(Box::pin(buf)); +/// let mut fp = Pin::new(&mut f); +/// +/// let opts = FileOptions::default(); +/// fp.as_mut().start_file("asdf.txt", opts).await?; +/// fp.write_all(b"hello!").await?; +/// +/// let mut f = f.finish_into_readable().await?; +/// let mut f = Pin::new(&mut f); +/// let mut s = String::new(); +/// { +/// let mut zf = f.by_name("asdf.txt").await?; +/// zf.read_to_string(&mut s).await?; +/// } +/// assert_eq!(&s, "hello!"); +/// # Ok(()) +/// # })} +///``` +/// +/// To a [`File`](tokio::fs::File): +///``` +/// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { +/// use zip::{write::FileOptions, tokio::{read::ZipArchive, write::ZipWriter}}; +/// use std::pin::Pin; +/// use tokio::{fs, io::{self, AsyncReadExt, AsyncWriteExt}}; +/// +/// let file = fs::File::from_std(tempfile::tempfile()?); +/// let mut f = ZipWriter::new(Box::pin(file)); +/// let mut fp = Pin::new(&mut f); +/// +/// let opts = FileOptions::default(); +/// fp.as_mut().start_file("asdf.txt", opts).await?; +/// fp.write_all(b"hello!").await?; +/// +/// let mut f = f.finish_into_readable().await?; +/// let mut f = Pin::new(&mut f); +/// let mut s = String::new(); +/// { +/// let mut zf = f.by_name("asdf.txt").await?; +/// zf.read_to_string(&mut s).await?; +/// } +/// assert_eq!(&s, "hello!"); +/// # Ok(()) +/// # })} +///``` +pub struct ZipWriter { + inner: InnerWriter, + files: Vec, + stats: ZipWriterStats, + writing_to_file: bool, + writing_to_extra_field: bool, + writing_to_central_extra_field_only: bool, + writing_raw: bool, + comment: Vec, +} + +struct WriteProj<'a, S> { + pub inner: Pin<&'a mut InnerWriter>, + pub files: &'a mut Vec, + pub stats: &'a mut ZipWriterStats, + pub writing_to_file: &'a mut bool, + pub writing_to_extra_field: &'a mut bool, + pub writing_to_central_extra_field_only: &'a mut bool, + pub writing_raw: &'a mut bool, + pub comment: &'a mut Vec, +} + +impl ZipWriter { + pub fn new(inner: Pin>) -> Self { + Self { + inner: InnerWriter::NoActiveFile(inner), + files: Vec::new(), + stats: Default::default(), + writing_to_file: false, + writing_to_extra_field: false, + writing_to_central_extra_field_only: false, + writing_raw: false, + comment: Vec::new(), + } + } + + fn for_append(inner: Pin>, files: Vec, comment: Vec) -> Self { + Self { + inner: InnerWriter::NoActiveFile(inner), + files, + stats: Default::default(), + writing_to_file: false, + writing_to_extra_field: false, + writing_to_central_extra_field_only: false, + comment, + writing_raw: true, /* avoid recomputing the last file's 
header */ + } + } + + #[inline] + fn project(self: Pin<&mut Self>) -> WriteProj<'_, S> { + unsafe { + let Self { + inner, + files, + stats, + writing_to_file, + writing_to_extra_field, + writing_to_central_extra_field_only, + writing_raw, + comment, + } = self.get_unchecked_mut(); + WriteProj { + inner: Pin::new_unchecked(inner), + files, + stats, + writing_to_file, + writing_to_extra_field, + writing_to_central_extra_field_only, + writing_raw, + comment, + } + } + } + + pub fn set_raw_comment(self: Pin<&mut Self>, comment: Vec) { + *self.project().comment = comment; + } + + pub fn set_comment(self: Pin<&mut Self>, comment: impl Into) { + self.set_raw_comment(comment.into().into()); + } +} + +impl ZipWriter { + pub async fn start_file( + mut self: Pin<&mut Self>, + name: impl Into, + mut options: FileOptions, + ) -> ZipResult<()> { + if options.permissions.is_none() { + options.permissions = Some(0o644); + } + *options.permissions.as_mut().unwrap() |= 0o100000; + + self.as_mut().start_entry(name, options, None).await?; + + let me = self.project(); + + me.inner.wrap_compressor_stream(me.files.last().unwrap())?; + + *me.writing_to_file = true; + + Ok(()) + } + + pub async fn add_directory( + mut self: Pin<&mut Self>, + name: impl Into, + mut options: FileOptions, + ) -> ZipResult<()> { + if options.permissions.is_none() { + options.permissions = Some(0o755); + } + *options.permissions.as_mut().unwrap() |= 0o40000; + options.compression_method = CompressionMethod::Stored; + + let mut name_as_string = name.into(); + // Append a slash to the filename if it does not end with it. + if !name_as_string.ends_with('/') { + /* TODO: ends_with('\\') as well? */ + name_as_string.push('/'); + } + + self.as_mut() + .start_entry(name_as_string, options, None) + .await?; + self.writing_to_file = false; + + Ok(()) + } + + pub async fn add_symlink( + mut self: Pin<&mut Self>, + name: impl Into, + target: impl Into, + mut options: FileOptions, + ) -> ZipResult<()> { + if options.permissions.is_none() { + options.permissions = Some(0o777); + } + *options.permissions.as_mut().unwrap() |= 0o120000; + // The symlink target is stored as file content. And compressing the target path + // likely wastes space. So always store. + options.compression_method = CompressionMethod::Stored; + + self.as_mut().start_entry(name, options, None).await?; + self.writing_to_file = true; + self.write_all(target.into().as_bytes()).await?; + self.writing_to_file = false; + + Ok(()) + } + + pub async fn start_file_with_extra_data( + mut self: Pin<&mut Self>, + name: impl Into, + mut options: FileOptions, + ) -> ZipResult { + if options.permissions.is_none() { + options.permissions = Some(0o644); + } + *options.permissions.as_mut().unwrap() |= 0o100000; + + self.as_mut().start_entry(name, options, None).await?; + + self.writing_to_file = true; + self.writing_to_extra_field = true; + + Ok(self.files.last().unwrap().data_start.load()) + } + + pub async fn start_file_aligned( + mut self: Pin<&mut Self>, + name: impl Into, + options: FileOptions, + align: u16, + ) -> ZipResult { + let data_start = self + .as_mut() + .start_file_with_extra_data(name, options) + .await?; + let align = align as u64; + + if align > 1 && data_start % align != 0 { + let pad_length = (align - (data_start + 4) % align) % align; + let pad = vec![0; pad_length as usize]; + self.write_all(b"za").await?; // 0x617a + self.write_u16_le(pad.len() as u16).await?; + self.write_all(&pad).await?; + assert_eq!( + self.as_mut().end_local_start_central_extra_data().await? 
% align, + 0 + ); + } + let extra_data_end = self.end_extra_data().await?; + Ok(extra_data_end - data_start) + } + + async fn start_entry( + mut self: Pin<&mut Self>, + name: impl Into, + options: FileOptions, + raw_values: Option, + ) -> ZipResult<()> { + self.as_mut().finish_file().await?; + + let raw_values = raw_values.unwrap_or_default(); + + let mut me = self.project(); + + let file = me + .inner + .as_mut() + .initialize_entry(raw_values, options, name.into(), me.stats) + .await?; + me.files.push(file); + Ok(()) + } + + async fn finish_file(mut self: Pin<&mut Self>) -> ZipResult<()> { + if self.writing_to_extra_field { + // Implicitly calling [`ZipWriter::end_extra_data`] for empty files. + self.as_mut().end_extra_data().await?; + } + + let mut me = self.project(); + + if !*me.writing_raw { + let file = match me.files.last_mut() { + None => return Ok(()), + Some(f) => f, + }; + + me.inner + .as_mut() + .update_header_and_unwrap_stream(me.stats, file) + .await?; + } + + *me.writing_to_file = false; + *me.writing_raw = false; + Ok(()) + } + + pub async fn end_local_start_central_extra_data(mut self: Pin<&mut Self>) -> ZipResult { + let data_start = self.as_mut().end_extra_data().await?; + self.files.last_mut().unwrap().extra_field.clear(); + self.writing_to_extra_field = true; + self.writing_to_central_extra_field_only = true; + Ok(data_start) + } + + pub async fn end_extra_data(self: Pin<&mut Self>) -> ZipResult { + if !self.writing_to_extra_field { + return Err(io::Error::new(io::ErrorKind::Other, "Not writing to extra field").into()); + } + + let mut me = self.project(); + + let file = me.files.last_mut().unwrap(); + + write_spec::validate_extra_data(file).await?; + + if !*me.writing_to_central_extra_field_only { + me.inner + .as_mut() + .write_extra_data_and_wrap_compressor_stream(me.stats, file) + .await?; + } + + let data_start = file.data_start.get_mut(); + *me.writing_to_extra_field = false; + *me.writing_to_central_extra_field_only = false; + Ok(*data_start) + } + + pub async fn finish(mut self) -> ZipResult>> { + let _ = Pin::new(&mut self).finalize().await?; + Pin::new(&mut self).shutdown().await?; + Ok(self.unwrap_inner_pin()) + } + + async fn finalize(mut self: Pin<&mut Self>) -> ZipResult { + self.as_mut().finish_file().await?; + + let me = self.project(); + me.inner.finalize(me.files, me.comment).await + } + + /// Copy over the entire contents of another archive verbatim. + /// + /// This method extracts file metadata from the `source` archive, then simply performs a single + /// big [`io::copy()`](io::copy) to transfer all the actual file contents without any + /// decompression or decryption. This is more performant than the equivalent operation of + /// calling [`Self::raw_copy_file()`] for each entry from the `source` archive in sequence. 
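+    /// (Any entry currently being written to `self` is finished before the copy begins.)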
+ /// + ///``` + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use zip::{tokio::{read::ZipArchive, write::ZipWriter}, write::FileOptions}; + /// use tokio::io::{self, AsyncReadExt, AsyncWriteExt, AsyncSeekExt}; + /// use std::{io::Cursor, pin::Pin}; + /// + /// let buf = Cursor::new(Vec::new()); + /// let mut zip = ZipWriter::new(Box::pin(buf)); + /// let mut zp = Pin::new(&mut zip); + /// zp.as_mut().start_file("a.txt", FileOptions::default()).await?; + /// zp.write_all(b"hello\n").await?; + /// let mut src = zip.finish_into_readable().await?; + /// let src = Pin::new(&mut src); + /// + /// let buf = Cursor::new(Vec::new()); + /// let mut zip = ZipWriter::new(Box::pin(buf)); + /// let mut zp = Pin::new(&mut zip); + /// zp.as_mut().start_file("b.txt", FileOptions::default()).await?; + /// zp.write_all(b"hey\n").await?; + /// let mut src2 = zip.finish_into_readable().await?; + /// let src2 = Pin::new(&mut src2); + /// + /// let buf = Cursor::new(Vec::new()); + /// let mut zip = ZipWriter::new(Box::pin(buf)); + /// let mut zp = Pin::new(&mut zip); + /// zp.as_mut().merge_archive(src).await?; + /// zp.merge_archive(src2).await?; + /// let mut result = zip.finish_into_readable().await?; + /// let mut zp = Pin::new(&mut result); + /// + /// let mut s: String = String::new(); + /// zp.as_mut().by_name("a.txt").await?.read_to_string(&mut s).await?; + /// assert_eq!(s, "hello\n"); + /// s.clear(); + /// zp.by_name("b.txt").await?.read_to_string(&mut s).await?; + /// assert_eq!(s, "hey\n"); + /// # Ok(()) + /// # })} + ///``` + pub async fn merge_archive( + mut self: Pin<&mut Self>, + source: Pin<&mut ZipArchive>, + ) -> ZipResult<()> + where + R: io::AsyncRead + io::AsyncSeek, + { + self.as_mut().finish_file().await?; + + /* Ensure we accept the file contents on faith (and avoid overwriting the data). + * See raw_copy_file_rename(). */ + self.writing_to_file = true; + self.writing_raw = true; + + let mut me = self.project(); + + /* Get the file entries from the source archive. */ + let new_files = match me.inner.as_mut().project() { + InnerProj::FileWriter(_) => unreachable!("should never merge with unfinished file!"), + InnerProj::NoActiveFile(writer) => source.merge_contents(writer).await?, + }; + /* These file entries are now ours! */ + let new_files: Vec = new_files.into(); + me.files.extend(new_files.into_iter()); + + Ok(()) + } +} + +impl ZipWriter { + /// Initializes the archive from an existing ZIP archive, making it ready for append. 
+    ///```
+ /// + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use zip::{tokio::{read::ZipArchive, write::ZipWriter}, write::FileOptions}; + /// use tokio::io::{self, AsyncReadExt, AsyncWriteExt, AsyncSeekExt}; + /// use std::{io::Cursor, pin::Pin}; + /// + /// let buf = Cursor::new(Vec::new()); + /// let mut zip = ZipWriter::new(Box::pin(buf)); + /// let mut zp = Pin::new(&mut zip); + /// zp.as_mut().start_file("a.txt", FileOptions::default()).await?; + /// zp.write_all(b"hello\n").await?; + /// let src = zip.finish().await?; + /// + /// let buf = Cursor::new(Vec::new()); + /// let mut zip = ZipWriter::new(Box::pin(buf)); + /// let mut zp = Pin::new(&mut zip); + /// zp.as_mut().start_file("b.txt", FileOptions::default()).await?; + /// zp.write_all(b"hey\n").await?; + /// let mut src2 = zip.finish_into_readable().await?; + /// let src2 = Pin::new(&mut src2); + /// + /// let mut zip = ZipWriter::new_append(src)?; + /// let mut zp = Pin::new(&mut zip); + /// zp.merge_archive(src2).await?; + /// let mut result = zip.finish_into_readable().await?; + /// let mut zp = Pin::new(&mut result); + /// + /// let mut s: String = String::new(); + /// { + /// use zip::tokio::read::SharedData; + /// assert_eq!(zp.shared().len(), 2); + /// } + /// zp.as_mut().by_name("a.txt").await?.read_to_string(&mut s).await?; + /// assert_eq!(s, "hello\n"); + /// s.clear(); + /// zp.by_name("b.txt").await?.read_to_string(&mut s).await?; + /// assert_eq!(s, "hey\n"); + /// + /// # Ok(()) + /// # })} + ///``` + pub async fn new_append(mut readwriter: Pin>) -> ZipResult { + let (footer, cde_end_pos) = + spec::CentralDirectoryEnd::find_and_parse_async(readwriter.as_mut()).await?; + + if footer.disk_number != footer.disk_with_central_directory { + return Err(ZipError::UnsupportedArchive( + "Support for multi-disk files is not implemented", + )); + } + + let (archive_offset, directory_start, number_of_files) = + Shared::get_directory_counts(readwriter.as_mut(), &footer, cde_end_pos).await?; + + readwriter + .seek(io::SeekFrom::Start(directory_start)) + .await + .map_err(|_| { + ZipError::InvalidArchive("Could not seek to start of central directory") + })?; + + let mut files: Vec = Vec::with_capacity(number_of_files); + for _ in 0..number_of_files { + let file = + read_spec::central_header_to_zip_file(readwriter.as_mut(), archive_offset).await?; + files.push(file); + } + + /* seek directory_start to overwrite it */ + readwriter + .seek(io::SeekFrom::Start(directory_start)) + .await?; + + Ok(Self::for_append(readwriter, files, footer.zip_file_comment)) + } + + /// Write the zip file into the backing stream, then produce a readable archive of that data. + /// + /// This method avoids parsing the central directory records at the end of the stream for + /// a slight performance improvement over running [`ZipArchive::new()`] on the output of + /// [`Self::finish()`]. 
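+    /// (The entry metadata accumulated while writing is handed to the resulting archive
+    /// directly, so the central directory is never re-parsed from the stream.)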
+ /// + ///``` + /// # fn main() -> zip::result::ZipResult<()> { tokio_test::block_on(async { + /// use zip::{tokio::{read::ZipArchive, write::ZipWriter}, write::FileOptions}; + /// use tokio::io::{self, AsyncReadExt, AsyncWriteExt, AsyncSeekExt}; + /// use std::{io::Cursor, pin::Pin}; + /// + /// let buf = Cursor::new(Vec::new()); + /// let mut zip = ZipWriter::new(Box::pin(buf)); + /// let mut zp = Pin::new(&mut zip); + /// let options = FileOptions::default(); + /// zp.as_mut().start_file("a.txt", options).await?; + /// zp.write_all(b"hello\n").await?; + /// + /// let mut zip = zip.finish_into_readable().await?; + /// let mut zp = Pin::new(&mut zip); + /// let mut s: String = String::new(); + /// zp.by_name("a.txt").await?.read_to_string(&mut s).await?; + /// assert_eq!(s, "hello\n"); + /// # Ok(()) + /// # })} + ///``` + pub async fn finish_into_readable(mut self) -> ZipResult> { + let directory_start = Pin::new(&mut self).finalize().await?; + Pin::new(&mut self).shutdown().await?; + let files = mem::take(&mut self.files); + let comment = mem::take(&mut self.comment); + let inner = self.unwrap_inner_pin(); + ZipArchive::from_finalized_writer(files, comment, inner, directory_start) + } +} + +impl io::AsyncWrite for ZipWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if !self.writing_to_file { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "No file has been started", + ))); + } + let mut me = self.project(); + match me.inner.as_mut().project() { + InnerProj::NoActiveFile(_) => { + assert!(*me.writing_to_extra_field); + let field = Pin::new(&mut me.files.last_mut().unwrap().extra_field); + field.poll_write(cx, buf) + } + InnerProj::FileWriter(wrapped) => { + let num_written = ready!(wrapped.poll_write(cx, buf))?; + me.stats.update(&buf[..num_written]); + if me.stats.bytes_written > spec::ZIP64_BYTES_THR + && !me.files.last_mut().unwrap().large_file + { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Large file option has not been set", + ))); + } + Poll::Ready(Ok(num_written)) + } + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().inner.project() { + InnerProj::NoActiveFile(inner) => inner.poll_flush(cx), + InnerProj::FileWriter(wrapped) => wrapped.poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + match self.project().inner.project() { + InnerProj::NoActiveFile(inner) => inner.poll_shutdown(cx), + InnerProj::FileWriter(wrapped) => wrapped.poll_shutdown(cx), + } + } +} + +/* impl ops::Drop for ZipWriter { */ +/* fn drop(&mut self) { */ +/* unreachable!("must call .finish()!"); */ +/* } */ +/* } */ + +impl WrappedPin for ZipWriter { + fn unwrap_inner_pin(self) -> Pin> { + let inner: InnerWriter = unsafe { ptr::read(&self.inner) }; + mem::forget(self); + inner.unwrap_inner_pin() + } +} + +pub(crate) mod write_spec { + use crate::{ + result::ZipResult, + spec::{self, CentralDirectoryHeaderBuffer, LocalHeaderBuffer}, + types::ZipFileData, + }; + + use tokio::io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + + use std::{io::IoSlice, mem, pin::Pin}; + + pub async fn validate_extra_data(file: &ZipFileData) -> ZipResult<()> { + if file.extra_field.len() > spec::ZIP64_ENTRY_THR { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Extra data exceeds extra field", + ) + .into()); + } + + let mut data = file.extra_field.as_slice(); + + while 
!data.is_empty() { + let left = data.len(); + if left < 4 { + return Err( + io::Error::new(io::ErrorKind::Other, "Incomplete extra data header").into(), + ); + } + + let mut buf = [0u8; 4]; + data.read_exact(&mut buf[..]).await?; + let (kind, size): (u16, u16) = unsafe { mem::transmute(buf) }; + + let left = left - 4; + + if kind == 0x0001 { + return Err(io::Error::new( + io::ErrorKind::Other, + "No custom ZIP64 extra data allowed", + ) + .into()); + } + + #[cfg(not(feature = "unreserved"))] + { + if kind <= 31 || crate::write::EXTRA_FIELD_MAPPING.contains(&kind) { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "Extra data header ID {kind:#06} requires crate feature \"unreserved\"", + ), + ) + .into()); + } + } + + if size > left as u16 { + return Err(io::Error::new( + io::ErrorKind::Other, + "Extra data size exceeds extra field", + ) + .into()); + } + + data = &data[size as usize..]; + } + + Ok(()) + } + + pub async fn update_local_file_header( + mut writer: Pin<&mut S>, + file: &ZipFileData, + ) -> ZipResult<()> { + const CRC32_OFFSET: u64 = 14; + writer + .seek(io::SeekFrom::Start(file.header_start + CRC32_OFFSET)) + .await?; + + if file.large_file { + writer.write_u32_le(file.crc32).await?; + update_local_zip64_extra_field(writer.as_mut(), file).await?; + } else { + // check compressed size as well as it can also be slightly larger than uncompressed + if file.compressed_size > spec::ZIP64_BYTES_THR { + return Err(io::Error::new( + io::ErrorKind::Other, + "Large file option has not been set", + ) + .into()); + } + let buf: [u32; 3] = [ + file.crc32, + file.compressed_size as u32, + file.uncompressed_size as u32, + ]; + let buf: [u8; 12] = unsafe { mem::transmute(buf) }; + writer.write_all(&buf[..]).await?; + }; + Ok(()) + } + + async fn update_local_zip64_extra_field( + mut writer: Pin<&mut S>, + file: &ZipFileData, + ) -> ZipResult<()> { + let zip64_extra_field = file.header_start + 30 + file.file_name.as_bytes().len() as u64; + writer + .seek(io::SeekFrom::Start(zip64_extra_field + 4)) + .await?; + + let buf: [u64; 2] = [file.uncompressed_size, file.compressed_size]; + let buf: [u8; 16] = unsafe { mem::transmute(buf) }; + writer.write_all(&buf[..]).await?; + // Excluded fields: + // u32: disk start number + Ok(()) + } + + pub async fn write_local_file_header( + mut writer: Pin<&mut S>, + file: &ZipFileData, + ) -> ZipResult<()> { + let (compressed_size, uncompressed_size): (u32, u32) = if file.large_file { + (spec::ZIP64_BYTES_THR as u32, spec::ZIP64_BYTES_THR as u32) + } else { + (file.compressed_size as u32, file.uncompressed_size as u32) + }; + #[allow(deprecated)] + let block: [u8; 30] = unsafe { + mem::transmute(LocalHeaderBuffer { + magic: spec::LOCAL_FILE_HEADER_SIGNATURE, + version_needed_to_extract: file.version_needed(), + flag: if !file.file_name.is_ascii() { + 1u16 << 11 + } else { + 0 + } | if file.encrypted { 1u16 << 0 } else { 0 }, + compression_method: file.compression_method.to_u16(), + last_modified_time_timepart: file.last_modified_time.timepart(), + last_modified_time_datepart: file.last_modified_time.datepart(), + crc32: file.crc32, + compressed_size, + uncompressed_size, + file_name_length: file.file_name.as_bytes().len() as u16, + extra_field_length: if file.large_file { 20 } else { 0 } + + file.extra_field.len() as u16, + }) + }; + + let maybe_extra_field = if file.large_file { + // This entry in the Local header MUST include BOTH original + // and compressed file size fields. 
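+            // For example, a 5 GiB stored entry writes 0xFFFFFFFF (that is,
+            // ZIP64_BYTES_THR truncated to u32) into both 32-bit size fields of
+            // the block above, and carries the real 64-bit sizes in the zip64
+            // extra field appended below.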
+ assert!(file.uncompressed_size > spec::ZIP64_BYTES_THR); + assert!(file.compressed_size > spec::ZIP64_BYTES_THR); + Some(get_central_zip64_extra_field(file)) + } else { + None + }; + + let fname = file.file_name.as_bytes(); + + if writer.is_write_vectored() { + /* TODO: zero-copy!! */ + let block = IoSlice::new(&block); + let fname = IoSlice::new(&fname); + if let Some(extra_block) = maybe_extra_field { + let extra_field = IoSlice::new(&extra_block); + writer.write_vectored(&[block, fname, extra_field]).await?; + } else { + writer.write_vectored(&[block, fname]).await?; + } + } else { + /* If no special vector write support, just perform a series of normal writes. */ + writer.write_all(&block).await?; + writer.write_all(&fname).await?; + if let Some(extra_block) = maybe_extra_field { + writer.write_all(&extra_block).await?; + } + } + + Ok(()) + } + + pub async fn write_central_directory_header( + mut writer: Pin<&mut S>, + file: &ZipFileData, + ) -> ZipResult<()> { + let zip64_extra_field = get_central_zip64_extra_field(file); + + #[allow(deprecated)] + let block: [u8; 46] = unsafe { + mem::transmute(CentralDirectoryHeaderBuffer { + magic: spec::CENTRAL_DIRECTORY_HEADER_SIGNATURE, + version_made_by: (file.system as u16) << 8 | (file.version_made_by as u16), + version_needed: file.version_needed(), + flag: if !file.file_name.is_ascii() { + 1u16 << 11 + } else { + 0 + } | if file.encrypted { 1u16 << 0 } else { 0 }, + compression_method: file.compression_method.to_u16(), + last_modified_time_timepart: file.last_modified_time.timepart(), + last_modified_time_datepart: file.last_modified_time.datepart(), + crc32: file.crc32, + compressed_size: file.compressed_size.min(spec::ZIP64_BYTES_THR) as u32, + uncompressed_size: file.uncompressed_size.min(spec::ZIP64_BYTES_THR) as u32, + file_name_length: file.file_name.as_bytes().len() as u16, + extra_field_length: zip64_extra_field.len() as u16, + file_comment_length: 0, + disk_number_start: 0, + internal_attributes: 0, + external_attributes: file.external_attributes, + header_start: file.header_start.min(spec::ZIP64_BYTES_THR) as u32, + }) + }; + + let fname = file.file_name.as_bytes(); + + if writer.is_write_vectored() { + /* TODO: zero-copy!! */ + let block = IoSlice::new(&block); + let fname = IoSlice::new(&fname); + let z64_extra = IoSlice::new(&zip64_extra_field); + let extra = IoSlice::new(&file.extra_field); + writer + .write_vectored(&[block, fname, z64_extra, extra]) + .await?; + } else { + writer.write_all(&block).await?; + // file name + writer.write_all(&fname).await?; + // zip64 extra field + writer.write_all(&zip64_extra_field).await?; + // extra field + writer.write_all(&file.extra_field).await?; + // file comment + // + } + + Ok(()) + } + + fn get_central_zip64_extra_field(file: &ZipFileData) -> Vec { + // The order of the fields in the zip64 extended + // information record is fixed, but the fields MUST + // only appear if the corresponding Local or Central + // directory record field is set to 0xFFFF or 0xFFFFFFFF. 
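+        //
+        // For example, when only the two sizes overflow (and header_start still
+        // fits in 32 bits), the field built below is 20 bytes, all little-endian:
+        //
+        //   [0x01, 0x00]  header ID 0x0001 (zip64 extended information)
+        //   [0x10, 0x00]  data size = 16
+        //   [u64]         uncompressed size
+        //   [u64]         compressed size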
+        let mut ret: Vec<u8> = Vec::new();
+        let mut size: u16 = 0;
+        let uncompressed_size = file.uncompressed_size > spec::ZIP64_BYTES_THR;
+        let compressed_size = file.compressed_size > spec::ZIP64_BYTES_THR;
+        let header_start = file.header_start > spec::ZIP64_BYTES_THR;
+
+        let zip64_kind: u16 = 0x0001;
+
+        if uncompressed_size {
+            size += 8;
+        }
+        if compressed_size {
+            size += 8;
+        }
+        if header_start {
+            size += 8;
+        }
+        if size > 0 {
+            ret.extend_from_slice(&zip64_kind.to_le_bytes());
+            ret.extend_from_slice(&size.to_le_bytes());
+            size += 4;
+
+            if uncompressed_size {
+                ret.extend_from_slice(&file.uncompressed_size.to_le_bytes());
+            }
+            if compressed_size {
+                ret.extend_from_slice(&file.compressed_size.to_le_bytes());
+            }
+            if header_start {
+                ret.extend_from_slice(&file.header_start.to_le_bytes());
+            }
+            // Excluded fields:
+            // u32: disk start number
+        }
+
+        assert_eq!(size as usize, ret.len());
+
+        ret
+    }
+}
diff --git a/src/types.rs b/src/types.rs
index b4079f6c7..a5dc24eda 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -6,53 +6,78 @@ use std::sync::{Arc, OnceLock};
 #[cfg(feature = "chrono")]
 use chrono::{Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
 #[cfg(doc)]
-use {crate::read::ZipFile, crate::write::FileOptions};
+use crate::read::ZipFile;
+
+use crate::write::{FileOptions, ZipRawValues};
+
+use cfg_if::cfg_if;
+use num_enum::{FromPrimitive, IntoPrimitive};
+
+use std::convert::TryInto;
+use std::ops::{Range, RangeInclusive};
+use std::path;
 
 pub(crate) mod ffi {
     pub const S_IFDIR: u32 = 0o0040000;
     pub const S_IFREG: u32 = 0o0100000;
 }
 
-#[cfg(any(
-    all(target_arch = "arm", target_pointer_width = "32"),
-    target_arch = "mips",
-    target_arch = "powerpc"
-))]
-mod atomic {
-    use crossbeam_utils::sync::ShardedLock;
-    pub use std::sync::atomic::Ordering;
-
-    #[derive(Debug, Default)]
-    pub struct AtomicU64 {
-        value: ShardedLock<u64>,
-    }
+cfg_if! {
+    if #[cfg(any(
+        all(target_arch = "arm", target_pointer_width = "32"),
+        target_arch = "mips",
+        target_arch = "powerpc"
+    ))] {
+        mod atomic {
+            use crossbeam_utils::sync::ShardedLock;
+            pub use std::sync::atomic::Ordering;
+
+            #[derive(Debug, Default)]
+            pub struct AtomicU64 {
+                value: ShardedLock<u64>,
+            }
 
-    impl AtomicU64 {
-        pub fn new(v: u64) -> Self {
-            Self {
-                value: ShardedLock::new(v),
+            impl AtomicU64 {
+                pub fn new(v: u64) -> Self {
+                    Self {
+                        value: ShardedLock::new(v),
+                    }
+                }
+                pub fn get_mut(&mut self) -> &mut u64 {
+                    self.value.get_mut().unwrap()
+                }
+                pub fn load(&self, _: Ordering) -> u64 {
+                    *self.value.read().unwrap()
+                }
+                pub fn store(&self, value: u64, _: Ordering) {
+                    *self.value.write().unwrap() = value;
+                }
             }
         }
-        pub fn get_mut(&mut self) -> &mut u64 {
-            self.value.get_mut().unwrap()
-        }
-        pub fn load(&self, _: Ordering) -> u64 {
-            *self.value.read().unwrap()
-        }
-        pub fn store(&self, value: u64, _: Ordering) {
-            *self.value.write().unwrap() = value;
-        }
+    } else {
+        use std::sync::atomic;
     }
 }
 
-use crate::result::DateTimeRangeError;
-#[cfg(feature = "time")]
-use time::{error::ComponentRange, Date, Month, OffsetDateTime, PrimitiveDateTime, Time};
+cfg_if! {
+    if #[cfg(feature = "time")] {
+        use crate::result::DateTimeRangeError;
+        use time::{
+            Date, Month, OffsetDateTime, UtcOffset, PrimitiveDateTime, Time,
+            error::ComponentRange,
+        };
+    } else {
+        use std::time::SystemTime;
+    }
+}
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, FromPrimitive, IntoPrimitive)]
+#[repr(u8)]
 pub enum System {
     Dos = 0,
     Unix = 3,
+    #[num_enum(default)]
     Unknown,
 }
@@ -140,9 +165,15 @@ impl TryInto<NaiveDateTime> for DateTime {
 }
 
 impl Default for DateTime {
-    /// Constructs an 'default' datetime of 1980-01-01 00:00:00
     fn default() -> DateTime {
-        DateTime {
+        Self::zero()
+    }
+}
+
+impl DateTime {
+    /// Constructs a 'default' datetime of 1980-01-01 00:00:00
+    pub const fn zero() -> Self {
+        Self {
             year: 1980,
             month: 1,
             day: 1,
@@ -151,9 +182,74 @@ impl Default for DateTime {
             second: 0,
         }
     }
-}
 
-impl DateTime {
+    /// The allowed range for years in a zip file's timestamp.
+    pub const YEAR_RANGE: RangeInclusive<u16> = 1980..=2107;
+    /// The allowed range for months in a zip file's timestamp.
+    pub const MONTH_RANGE: RangeInclusive<u8> = 1..=12;
+    /// The allowed range for days in a zip file's timestamp.
+    pub const DAY_RANGE: RangeInclusive<u8> = 1..=31;
+    /// The allowed range for hours in a zip file's timestamp.
+    pub const HOUR_RANGE: Range<u8> = 0..24;
+    /// The allowed range for minutes in a zip file's timestamp.
+    pub const MINUTE_RANGE: Range<u8> = 0..60;
+    /// The allowed range for seconds in a zip file's timestamp.
+    pub const SECOND_RANGE: RangeInclusive<u8> = 0..=60;
+
+    fn check_year(year: u16) -> Result<(), DateTimeRangeError> {
+        if Self::YEAR_RANGE.contains(&year) {
+            Ok(())
+        } else {
+            Err(DateTimeRangeError::InvalidYear(year, Self::YEAR_RANGE))
+        }
+    }
+
+    fn check_month(month: u8) -> Result<(), DateTimeRangeError> {
+        if Self::MONTH_RANGE.contains(&month) {
+            Ok(())
+        } else {
+            Err(DateTimeRangeError::InvalidMonth(month, Self::MONTH_RANGE))
+        }
+    }
+
+    fn check_day(day: u8) -> Result<(), DateTimeRangeError> {
+        if Self::DAY_RANGE.contains(&day) {
+            Ok(())
+        } else {
+            Err(DateTimeRangeError::InvalidDay(day, Self::DAY_RANGE))
+        }
+    }
+
+    fn check_hour(hour: u8) -> Result<(), DateTimeRangeError> {
+        if Self::HOUR_RANGE.contains(&hour) {
+            Ok(())
+        } else {
+            Err(DateTimeRangeError::InvalidHour(hour, Self::HOUR_RANGE))
+        }
+    }
+
+    fn check_minute(minute: u8) -> Result<(), DateTimeRangeError> {
+        if Self::MINUTE_RANGE.contains(&minute) {
+            Ok(())
+        } else {
+            Err(DateTimeRangeError::InvalidMinute(
+                minute,
+                Self::MINUTE_RANGE,
+            ))
+        }
+    }
+
+    fn check_second(second: u8) -> Result<(), DateTimeRangeError> {
+        if Self::SECOND_RANGE.contains(&second) {
+            Ok(())
+        } else {
+            Err(DateTimeRangeError::InvalidSecond(
+                second,
+                Self::SECOND_RANGE,
+            ))
+        }
+    }
 
     /// Converts an msdos (u16, u16) pair to a DateTime object
     pub const fn from_msdos(datepart: u16, timepart: u16) -> DateTime {
         let seconds = (timepart & 0b0000000000011111) << 1;
@@ -182,6 +278,8 @@ impl DateTime {
     /// * hour: [0, 23]
     /// * minute: [0, 59]
     /// * second: [0, 60]
+    #[allow(clippy::result_unit_err)]
+    #[deprecated(note = "use DateTime::parse_from_date_and_time() instead")]
     pub fn from_date_and_time(
         year: u16,
         month: u8,
@@ -189,25 +287,51 @@
         hour: u8,
         minute: u8,
         second: u8,
+    ) -> Result<DateTime, ()> {
+        Self::parse_from_date_and_time(year, month, day, hour, minute, second).map_err(|_| ())
+    }
+
+    /// Constructs a DateTime from a specific date and time
+    ///
+    /// The bounds are:
+    /// * year ([`Self::YEAR_RANGE`]): [1980, 2107]
+    /// * month ([`Self::MONTH_RANGE`]): [1, 12]
+    /// * day ([`Self::DAY_RANGE`]): [1, 31]
+    /// * hour ([`Self::HOUR_RANGE`]): [0, 23]
+    /// * minute ([`Self::MINUTE_RANGE`]): [0, 59]
+    /// * second ([`Self::SECOND_RANGE`]): [0, 60]
+    ///
+    ///```
+    /// use zip::{DateTime, result::DateTimeRangeError};
+    ///
+    /// assert!(DateTime::parse_from_date_and_time(1980, 1, 1, 0, 0, 0).is_ok());
+    /// assert!(matches![
+    ///     DateTime::parse_from_date_and_time(1979, 1, 1, 0, 0, 0),
+    ///     Err(DateTimeRangeError::InvalidYear(1979, _)),
+    /// ]);
+    ///```
+    pub fn parse_from_date_and_time(
+        year: u16,
+        month: u8,
+        day: u8,
+        hour: u8,
+        minute: u8,
+        second: u8,
     ) -> Result<DateTime, DateTimeRangeError> {
-        if (1980..=2107).contains(&year)
-            && (1..=12).contains(&month)
-            && (1..=31).contains(&day)
-            && hour <= 23
-            && minute <= 59
-            && second <= 60
-        {
-            Ok(DateTime {
-                year,
-                month,
-                day,
-                hour,
-                minute,
-                second,
-            })
-        } else {
-            Err(DateTimeRangeError)
-        }
+        Self::check_year(year)?;
+        Self::check_month(month)?;
+        Self::check_day(day)?;
+        Self::check_hour(hour)?;
+        Self::check_minute(minute)?;
+        Self::check_second(second)?;
+        Ok(Self {
+            year,
+            month,
+            day,
+            hour,
+            minute,
+            second,
+        })
     }
 
     /// Indicates whether this date and time can be written to a zip archive.
@@ -223,10 +347,14 @@
         .is_ok()
     }
 
-    #[cfg(feature = "time")]
     /// Converts a OffsetDateTime object to a DateTime
     ///
     /// Returns `Err` when this object is out of bounds
+    #[cfg(feature = "time")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "time")))]
     #[deprecated(note = "use `DateTime::try_from()`")]
     pub fn from_time(dt: OffsetDateTime) -> Result<DateTime, DateTimeRangeError> {
-        dt.try_into().map_err(|_err| DateTimeRangeError)
+        dt.try_into()
@@ -242,13 +370,22 @@ impl DateTime {
         (self.day as u16) | ((self.month as u16) << 5) | ((self.year - 1980) << 9)
     }
 
-    #[cfg(feature = "time")]
     /// Converts the DateTime to a OffsetDateTime structure
+    #[cfg(feature = "time")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "time")))]
+    #[deprecated(note = "use `DateTime::to_time_with_offset()`")]
     pub fn to_time(&self) -> Result<OffsetDateTime, ComponentRange> {
+        self.to_time_with_offset(UtcOffset::UTC)
+    }
+
+    /// Converts the DateTime to a OffsetDateTime structure, given a UTC offset.
+    #[cfg(feature = "time")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "time")))]
+    pub fn to_time_with_offset(&self, offset: UtcOffset) -> Result<OffsetDateTime, ComponentRange> {
         let date =
             Date::from_calendar_date(self.year as i32, Month::try_from(self.month)?, self.day)?;
         let time = Time::from_hms(self.hour, self.minute, self.second)?;
-        Ok(PrimitiveDateTime::new(date, time).assume_utc())
+        Ok(PrimitiveDateTime::new(date, time).assume_offset(offset))
     }
 
     /// Get the year. There is no epoch, i.e. 2018 will be returned as 2018.
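+    ///
+    /// A short doctest sketch (the date is arbitrary):
+    ///```
+    /// use zip::DateTime;
+    ///
+    /// let dt = DateTime::parse_from_date_and_time(2018, 11, 17, 10, 38, 30).unwrap();
+    /// assert_eq!(dt.year(), 2018);
+    ///```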
@@ -303,6 +440,7 @@
 }
 
 #[cfg(feature = "time")]
+#[cfg_attr(docsrs, doc(cfg(feature = "time")))]
 impl TryFrom<OffsetDateTime> for DateTime {
     type Error = DateTimeRangeError;
 
@@ -374,6 +512,37 @@ pub struct ZipFileData {
 }
 
 impl ZipFileData {
+    pub(crate) fn initialize(
+        raw: ZipRawValues,
+        options: FileOptions,
+        header_start: u64,
+        file_name: String,
+    ) -> Self {
+        let permissions = options.permissions.unwrap_or(0o100644);
+        Self {
+            system: System::Unix,
+            version_made_by: DEFAULT_VERSION,
+            encrypted: options.encrypt_with.is_some(),
+            using_data_descriptor: false,
+            compression_method: options.compression_method,
+            compression_level: options.compression_level,
+            last_modified_time: options.last_modified_time,
+            crc32: raw.crc32,
+            compressed_size: raw.compressed_size,
+            uncompressed_size: raw.uncompressed_size,
+            file_name,
+            file_name_raw: Vec::new(), // Never used for saving
+            extra_field: Vec::new(),
+            file_comment: String::new(),
+            header_start,
+            data_start: AtomicU64::new(0),
+            central_header_start: 0,
+            external_attributes: permissions << 16,
+            large_file: options.large_file,
+            aes_mode: None,
+        }
+    }
+
     pub fn file_name_sanitized(&self) -> PathBuf {
         let no_null_filename = match self.file_name.find('\0') {
             Some(index) => &self.file_name[0..index],
@@ -419,6 +588,7 @@
     }
 
     /// Get unix mode for the file
+    #[inline]
     pub(crate) const fn unix_mode(&self) -> Option<u32> {
         if self.external_attributes == 0 {
             return None;
@@ -493,6 +663,7 @@ pub enum AesMode {
 }
 
 #[cfg(feature = "aes-crypto")]
+#[cfg_attr(docsrs, doc(cfg(feature = "aes-crypto")))]
 impl AesMode {
     pub const fn salt_length(&self) -> usize {
         self.key_length() / 2
@@ -508,14 +679,19 @@ impl AesMode {
 }
 
 #[cfg(test)]
+#[allow(deprecated)]
 mod test {
     #[test]
     fn system() {
         use super::System;
-        assert_eq!(System::Dos as u16, 0u16);
-        assert_eq!(System::Unix as u16, 3u16);
-        assert_eq!(System::from_u8(0), System::Dos);
-        assert_eq!(System::from_u8(3), System::Unix);
+        assert_eq!(u8::from(System::Dos), 0u8);
+        assert_eq!(System::Dos as u8, 0u8);
+        assert_eq!(System::Unix as u8, 3u8);
+        assert_eq!(u8::from(System::Unix), 3u8);
+        assert_eq!(System::from(0), System::Dos);
+        assert_eq!(System::from(3), System::Unix);
+        assert_eq!(u8::from(System::Unknown), 4u8);
+        assert_eq!(System::Unknown as u8, 4u8);
     }
 
    #[test]
@@ -561,7 +737,7 @@ mod test {
     #[allow(clippy::unusual_byte_groupings)]
     fn datetime_max() {
         use super::DateTime;
-        let dt = DateTime::from_date_and_time(2107, 12, 31, 23, 59, 60).unwrap();
+        let dt = DateTime::parse_from_date_and_time(2107, 12, 31, 23, 59, 60).unwrap();
         assert_eq!(dt.timepart(), 0b10111_111011_11110);
         assert_eq!(dt.datepart(), 0b1111111_1100_11111);
     }
 
@@ -570,23 +746,23 @@
     fn datetime_bounds() {
         use super::DateTime;
 
-        assert!(DateTime::from_date_and_time(2000, 1, 1, 23, 59, 60).is_ok());
-        assert!(DateTime::from_date_and_time(2000, 1, 1, 24, 0, 0).is_err());
-        assert!(DateTime::from_date_and_time(2000, 1, 1, 0, 60, 0).is_err());
-        assert!(DateTime::from_date_and_time(2000, 1, 1, 0, 0, 61).is_err());
+        assert!(DateTime::parse_from_date_and_time(2000, 1, 1, 23, 59, 60).is_ok());
+        assert!(DateTime::parse_from_date_and_time(2000, 1, 1, 24, 0, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(2000, 1, 1, 0, 60, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(2000, 1, 1, 0, 0, 61).is_err());
 
-        assert!(DateTime::from_date_and_time(2107, 12, 31, 0, 0, 0).is_ok());
-        assert!(DateTime::from_date_and_time(1980, 1, 1, 0, 0, 0).is_ok());
-        assert!(DateTime::from_date_and_time(1979, 1, 1, 0, 0, 0).is_err());
-        assert!(DateTime::from_date_and_time(1980, 0, 1, 0, 0, 0).is_err());
-        assert!(DateTime::from_date_and_time(1980, 1, 0, 0, 0, 0).is_err());
-        assert!(DateTime::from_date_and_time(2108, 12, 31, 0, 0, 0).is_err());
-        assert!(DateTime::from_date_and_time(2107, 13, 31, 0, 0, 0).is_err());
-        assert!(DateTime::from_date_and_time(2107, 12, 32, 0, 0, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(2107, 12, 31, 0, 0, 0).is_ok());
+        assert!(DateTime::parse_from_date_and_time(1980, 1, 1, 0, 0, 0).is_ok());
+        assert!(DateTime::parse_from_date_and_time(1979, 1, 1, 0, 0, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(1980, 0, 1, 0, 0, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(1980, 1, 0, 0, 0, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(2108, 12, 31, 0, 0, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(2107, 13, 31, 0, 0, 0).is_err());
+        assert!(DateTime::parse_from_date_and_time(2107, 12, 32, 0, 0, 0).is_err());
     }
 
     #[cfg(feature = "time")]
-    use time::{format_description::well_known::Rfc3339, OffsetDateTime};
+    use time::{format_description::well_known::Rfc3339, OffsetDateTime, UtcOffset};
 
     #[cfg(feature = "time")]
     #[test]
@@ -621,10 +797,14 @@ mod test {
         assert_eq!(dt.second(), 30);
 
         #[cfg(feature = "time")]
-        assert_eq!(
-            dt.to_time().unwrap().format(&Rfc3339).unwrap(),
-            "2018-11-17T10:38:30Z"
-        );
+        {
+            let offset_time = dt.to_time_with_offset(UtcOffset::UTC).unwrap();
+            assert_eq!(dt.to_time().unwrap(), offset_time);
+            assert_eq!(
+                offset_time.format(&Rfc3339).unwrap(),
+                "2018-11-17T10:38:30Z"
+            );
+        }
    }
 
    #[test]
@@ -639,7 +819,10 @@ mod test {
         assert_eq!(dt.second(), 62);
 
         #[cfg(feature = "time")]
-        assert!(dt.to_time().is_err());
+        {
+            assert!(dt.to_time().is_err());
+            assert!(dt.to_time_with_offset(UtcOffset::UTC).is_err());
+        }
 
         let dt = DateTime::from_msdos(0x0000, 0x0000);
         assert_eq!(dt.year(), 1980);
         assert_eq!(dt.month(), 1);
         assert_eq!(dt.day(), 1);
         assert_eq!(dt.hour(), 0);
         assert_eq!(dt.minute(), 0);
         assert_eq!(dt.second(), 0);
 
         #[cfg(feature = "time")]
-        assert!(dt.to_time().is_err());
+        {
+            assert!(dt.to_time().is_err());
+            assert!(dt.to_time_with_offset(UtcOffset::UTC).is_err());
+        }
     }
 
     #[cfg(feature = "time")]
diff --git a/src/write.rs b/src/write.rs
index 0051f253d..86b7ba269 100644
--- a/src/write.rs
+++ b/src/write.rs
@@ -132,16 +132,26 @@ use crate::CompressionMethod::Stored;
 pub use zip_writer::ZipWriter;
 
 #[derive(Default)]
-struct ZipWriterStats {
-    hasher: Hasher,
-    start: u64,
-    bytes_written: u64,
+pub(crate) struct ZipWriterStats {
+    pub(crate) hasher: Hasher,
+    pub(crate) start: u64,
+    pub(crate) bytes_written: u64,
 }
 
-struct ZipRawValues {
-    crc32: u32,
-    compressed_size: u64,
-    uncompressed_size: u64,
+pub(crate) struct ZipRawValues {
+    pub(crate) crc32: u32,
+    pub(crate) compressed_size: u64,
+    pub(crate) uncompressed_size: u64,
+}
+
+impl Default for ZipRawValues {
+    fn default() -> Self {
+        Self {
+            crc32: 0,
+            compressed_size: 0,
+            uncompressed_size: 0,
+        }
+    }
 }
 
 mod sealed {
     use std::sync::Arc;
@@ -207,12 +217,12 @@ pub struct ExtendedFileOptions {
 impl arbitrary::Arbitrary<'_> for FileOptions {
     fn arbitrary(u: &mut arbitrary::Unstructured) -> arbitrary::Result<Self> {
         let mut options = FullFileOptions {
             compression_method: CompressionMethod::arbitrary(u)?,
             compression_level: None,
             last_modified_time: DateTime::arbitrary(u)?,
             permissions: Option::arbitrary(u)?,
             large_file: bool::arbitrary(u)?,
             encrypt_with: Option::arbitrary(u)?,
             alignment: u16::arbitrary(u)?,
             #[cfg(feature = "deflate-zopfli")]
             zopfli_buffer_size: None,
@@ -403,10 +413,7 @@ impl Default for FileOptions {
         Self {
             compression_method: Default::default(),
             compression_level: None,
-            #[cfg(feature = "time")]
-            last_modified_time: OffsetDateTime::now_utc().try_into().unwrap_or_default(),
-            #[cfg(not(feature = "time"))]
-            last_modified_time: DateTime::default(),
+            last_modified_time,
             permissions: None,
             large_file: false,
             encrypt_with: None,
@@ -465,7 +472,7 @@
 }
 
 impl ZipWriterStats {
-    fn update(&mut self, buf: &[u8]) {
+    pub(crate) fn update(&mut self, buf: &[u8]) {
         self.hasher.update(buf);
         self.bytes_written += buf.len() as u64;
     }
@@ -1747,7 +1754,7 @@ mod test {
             .add_directory(
                 "test",
                 SimpleFileOptions::default().last_modified_time(
-                    DateTime::from_date_and_time(2018, 8, 15, 20, 45, 6).unwrap(),
+                    DateTime::parse_from_date_and_time(2018, 8, 15, 20, 45, 6).unwrap(),
                 ),
             )
             .unwrap();
@@ -1776,7 +1783,7 @@ mod test {
                 "name",
                 "target",
                 SimpleFileOptions::default().last_modified_time(
-                    DateTime::from_date_and_time(2018, 8, 15, 20, 45, 6).unwrap(),
+                    DateTime::parse_from_date_and_time(2018, 8, 15, 20, 45, 6).unwrap(),
                 ),
             )
             .unwrap();
@@ -1821,7 +1828,7 @@ mod test {
                 "directory\\link",
                 "/absolute/symlink\\with\\mixed/slashes",
                 SimpleFileOptions::default().last_modified_time(
-                    DateTime::from_date_and_time(2018, 8, 15, 20, 45, 6).unwrap(),
+                    DateTime::parse_from_date_and_time(2018, 8, 15, 20, 45, 6).unwrap(),
                ),
            )
            .unwrap();
@@ -2171,7 +2178,7 @@ mod test {
     #[test]
     fn crash_with_no_features() -> ZipResult<()> {
         const ORIGINAL_FILE_NAME: &str = "PK\u{6}\u{6}\0\0\0\0\0\0\0\0\0\u{2}g\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\u{1}\0\0\0\0\0\0\0\0\0\0PK\u{6}\u{7}\0\0\0\0\0\0\0\0\0\0\0\0\u{7}\0\t'";
         let mut writer = ZipWriter::new(io::Cursor::new(Vec::new()));
         let mut options = SimpleFileOptions::default();
         options = options
diff --git a/tmp/.gitignore b/tmp/.gitignore
new file mode 100644
index 000000000..ddc9e8a3c
--- /dev/null
+++ b/tmp/.gitignore
@@ -0,0 +1,3 @@
+perf*
+/tmp-out/
+/tmp-copy-out
diff --git a/tmp/Cargo.toml b/tmp/Cargo.toml
new file mode 100644
index 000000000..3adea6285
--- /dev/null
+++ b/tmp/Cargo.toml
@@ -0,0 +1,27 @@
+[package]
+name = "tmp"
+version = "0.0.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+zip = { path = ".."
diff --git a/tmp/Cargo.toml b/tmp/Cargo.toml
new file mode 100644
index 000000000..3adea6285
--- /dev/null
+++ b/tmp/Cargo.toml
@@ -0,0 +1,27 @@
+[package]
+name = "tmp"
+version = "0.0.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+zip = { path = ".." }
+libc = "0.2"
+tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
+tempfile = "3"
+getrandom = "0.2"
+once_cell = "1"
+
+[[bin]]
+name = "main"
+path = "src/main.rs"
+
+[[bin]]
+name = "copy-file-range"
+path = "src/copy-file-range.rs"
+
+[profile.release]
+strip = false
+debug = true
+# lto = true
diff --git a/tmp/src/copy-file-range.rs b/tmp/src/copy-file-range.rs
new file mode 100755
index 000000000..d882f6213
--- /dev/null
+++ b/tmp/src/copy-file-range.rs
@@ -0,0 +1,144 @@
+use std::{
+    env,
+    io::Cursor,
+    num::NonZeroUsize,
+    path::{Path, PathBuf},
+    pin::Pin,
+    str::FromStr,
+    sync::Arc,
+};
+
+use getrandom::getrandom;
+use once_cell::sync::Lazy;
+use tokio::{fs, io, task};
+
+use zip::{
+    result::{ZipError, ZipResult},
+    write::FileOptions,
+    CompressionMethod, ZipWriter,
+};
+
+fn generate_random_archive(
+    num_entries: usize,
+    entry_size: usize,
+    out_path: &Path,
+) -> ZipResult<()> {
+    use std::io::Write;
+
+    eprintln!("num_entries = {}", num_entries);
+    eprintln!("entry_size = {}", entry_size);
+
+    let out_handle = std::fs::File::create(out_path)?;
+    let mut zip = ZipWriter::new(out_handle);
+    /* No point compressing random entries. */
+    let options = FileOptions::default().compression_method(CompressionMethod::Stored);
+
+    let mut bytes = vec![0u8; entry_size];
+    for i in 0..num_entries {
+        let name = format!("random{}.dat", i);
+        zip.start_file(name, options)?;
+        getrandom(&mut bytes).unwrap();
+        zip.write_all(&bytes)?;
+    }
+
+    let out_handle = zip.finish()?;
+    out_handle.sync_all()?;
+
+    Ok(())
+}
+
+async fn get_len(p: &Path) -> io::Result<u64> {
+    Ok(fs::metadata(p).await?.len())
+}
+
+static BIG_ARCHIVE_PATH: Lazy<PathBuf> =
+    Lazy::new(|| Path::new("../benches/target.zip").to_path_buf());
+
+static SMALL_ARCHIVE_PATH: Lazy<PathBuf> =
+    Lazy::new(|| Path::new("../benches/small-target.zip").to_path_buf());
+
+fn flag_var(var_name: &str) -> bool {
+    env::var(var_name)
+        .ok()
+        .filter(|v| v.starts_with('y'))
+        .is_some()
+}
+
+fn num_var(var_name: &str) -> Option<usize> {
+    let var = env::var(var_name).ok()?;
+    let n = usize::from_str(&var).ok()?;
+    Some(n)
+}
+
+fn path_var(var_name: &str) -> Option<PathBuf> {
+    let var = env::var(var_name).ok()?;
+    Some(var.into())
+}
+
+#[tokio::main]
+async fn main() -> ZipResult<()> {
+    let source = if flag_var("SMALL") || flag_var("small") {
+        println!("small!");
+        &*SMALL_ARCHIVE_PATH
+    } else {
+        println!("big!");
+        &*BIG_ARCHIVE_PATH
+    }
+    .to_path_buf();
+
+    use zip::tokio::{buf_reader::BufReader, os::copy_file_range::*};
+
+    let len: u64 = get_len(&source).await?;
+    println!("len = {}", len);
+
+    let out_path: PathBuf = path_var("OUT")
+        .or(path_var("out"))
+        .unwrap_or_else(|| PathBuf::from("./tmp-copy-out"));
+    println!("out = {}", out_path.display());
+
+    let num_iters: usize = num_var("N").or(num_var("n")).unwrap_or(15);
+    println!("num_iters = {}", num_iters);
+
+    for _ in 0..num_iters {
+        if flag_var("ASYNC") || flag_var("async") {
+            println!("async!");
+            use tokio::io::AsyncWriteExt;
+
+            let non_zero_len = NonZeroUsize::new(len as usize).unwrap();
+
+            let handle = fs::File::open(&source).await?;
+            let mut out = fs::File::create(&out_path).await?;
+
+            let mut buf_reader = BufReader::with_capacity(non_zero_len, Box::pin(handle));
+
+            let written = io::copy_buf(&mut buf_reader, &mut out).await?;
+            assert_eq!(written, len);
+
+            out.shutdown().await?;
+        } else {
+            println!("copy_file_range!");
+            let source = source.clone();
+            let out_path = out_path.clone();
+            task::spawn_blocking(move || {
+                use std::{fs, io};
+
+                let handle = fs::File::open(source)?;
+                let mut src = MutateInnerOffset::new(handle, Role::Readable)?;
+                let src = Pin::new(&mut src);
+
+                let out = fs::File::create(out_path)?;
+                let mut dst = MutateInnerOffset::new(out, Role::Writable)?;
+                let dst = Pin::new(&mut dst);
+
+                let len = len as usize;
+                assert_eq!(len, copy_file_range(src, dst, len)?);
+
+                Ok::<_, io::Error>(())
+            })
+            .await
+            .unwrap()?;
+        };
+    }
+
+    Ok(())
+}
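The `ASYNC=y` branch above pairs the crate's capacity-aware `BufReader` with `tokio::io::copy_buf`. The same flow built from only stock tokio parts looks like this (a sketch; `copy_whole_file` is a hypothetical helper, not part of the diff):

```rust
use std::path::Path;

use tokio::{
    fs,
    io::{self, AsyncWriteExt, BufReader},
};

async fn copy_whole_file(src: &Path, dst: &Path) -> io::Result<u64> {
    let input = fs::File::open(src).await?;
    let mut reader = BufReader::new(input);
    let mut out = fs::File::create(dst).await?;
    // copy_buf drains the reader's internal buffer directly instead of
    // allocating an intermediate one, matching the copy_buf call above.
    let written = io::copy_buf(&mut reader, &mut out).await?;
    // Flush before dropping, as the binary does via shutdown().
    out.shutdown().await?;
    Ok(written)
}
```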
diff --git a/tmp/src/main.rs b/tmp/src/main.rs
new file mode 100755
index 000000000..5a35b3f09
--- /dev/null
+++ b/tmp/src/main.rs
@@ -0,0 +1,154 @@
+#![allow(unused_imports)]
+#![allow(dead_code)]
+
+use std::env;
+use std::io::{Cursor, Read, Seek, Write};
+use std::path::{Path, PathBuf};
+use std::pin::Pin;
+use std::str::FromStr;
+use std::sync::Arc;
+
+use getrandom::getrandom;
+use once_cell::sync::Lazy;
+use tokio::{fs, io, task};
+
+use zip::{
+    result::{ZipError, ZipResult},
+    write::FileOptions,
+    CompressionMethod, ZipWriter,
+};
+
+fn generate_random_archive(
+    num_entries: usize,
+    entry_size: usize,
+    out_path: &Path,
+) -> ZipResult<()> {
+    eprintln!("num_entries = {}", num_entries);
+    eprintln!("entry_size = {}", entry_size);
+
+    let out_handle = std::fs::File::create(out_path)?;
+    let mut zip = ZipWriter::new(out_handle);
+    /* No point compressing random entries. */
+    let options = FileOptions::default().compression_method(CompressionMethod::Stored);
+
+    let mut bytes = vec![0u8; entry_size];
+    for i in 0..num_entries {
+        let name = format!("random{}.dat", i);
+        zip.start_file(name, options)?;
+        getrandom(&mut bytes).unwrap();
+        zip.write_all(&bytes)?;
+    }
+
+    let out_handle = zip.finish()?;
+    out_handle.sync_all()?;
+
+    Ok(())
+}
+
+async fn get_len(p: &Path) -> io::Result<u64> {
+    Ok(fs::metadata(p).await?.len())
+}
+
+static BIG_ARCHIVE_PATH: Lazy<PathBuf> =
+    Lazy::new(|| Path::new("../benches/target.zip").to_path_buf());
+
+static SMALL_ARCHIVE_PATH: Lazy<PathBuf> =
+    Lazy::new(|| Path::new("../benches/small-target.zip").to_path_buf());
+
+fn flag_var(var_name: &str) -> bool {
+    env::var(var_name)
+        .ok()
+        .filter(|v| v.starts_with('y'))
+        .is_some()
+}
+
+fn num_var(var_name: &str) -> Option<usize> {
+    let var = env::var(var_name).ok()?;
+    let n = usize::from_str(&var).ok()?;
+    Some(n)
+}
+
+fn path_var(var_name: &str) -> Option<PathBuf> {
+    let var = env::var(var_name).ok()?;
+    Some(var.into())
+}
+
+#[tokio::main]
+async fn main() -> ZipResult<()> {
+    let n = num_var("N").or(num_var("n")).unwrap_or(5);
+    eprintln!("n = {}", n);
+
+    let td = task::spawn_blocking(move || tempfile::tempdir())
+        .await
+        .unwrap()?;
+
+    let test_archive_path = if flag_var("RANDOM") || flag_var("random") {
+        let zip_out_path = td.path().join("random.zip");
+        let num_entries: usize = num_var("RANDOM_N").or(num_var("random_n")).unwrap_or(1_000);
+        let entry_size: usize = num_var("RANDOM_SIZE")
+            .or(num_var("random_size"))
+            .unwrap_or(10_000);
+        {
+            let z2 = zip_out_path.clone();
+            task::spawn_blocking(move || generate_random_archive(num_entries, entry_size, &z2))
+                .await
+                .unwrap()?;
+        }
+        eprintln!(
+            "random({}) = {}",
+            get_len(&zip_out_path).await?,
+            zip_out_path.display()
+        );
+        zip_out_path
+    } else if flag_var("SMALL") || flag_var("small") {
+        eprintln!(
+            "small({}) = {}",
+            get_len(&*SMALL_ARCHIVE_PATH).await?,
+            SMALL_ARCHIVE_PATH.display()
+        );
+        SMALL_ARCHIVE_PATH.to_path_buf()
+    } else {
+        eprintln!(
+            "big({}) = {}",
+            get_len(&*BIG_ARCHIVE_PATH).await?,
+            BIG_ARCHIVE_PATH.display()
+        );
+        BIG_ARCHIVE_PATH.to_path_buf()
+    };
+
+    let out_path = path_var("OUT")
+        .or(path_var("out"))
+        .unwrap_or_else(|| PathBuf::from("./tmp-out"));
+    eprintln!("out = {}", out_path.display());
+
+    if flag_var("SYNC") || flag_var("sync") {
+        eprintln!("synchronous!");
+        task::spawn_blocking(move || {
+            for _ in 0..n {
+                let handle = std::fs::OpenOptions::new()
+                    .read(true)
+                    .open(&test_archive_path)?;
+                let mut src = zip::read::ZipArchive::new(handle)?;
+                src.extract(&out_path)?;
+            }
+
+            Ok::<_, ZipError>(())
+        })
+        .await
+        .unwrap()?;
+    } else {
+        eprintln!("async!");
+        let out = Arc::new(out_path);
+        for _ in 0..n {
+            let handle = fs::OpenOptions::new()
+                .read(true)
+                /* .custom_flags(libc::O_NONBLOCK) */
+                .open(&test_archive_path)
+                .await?;
+            let mut src = zip::tokio::read::ZipArchive::new(Box::pin(handle)).await?;
+            Pin::new(&mut src).extract(out.clone()).await?;
+        }
+    }
+
+    Ok(())
+}
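Taken together, the `tmp/` harness drives the new async front end end-to-end. The essential call sequence from `main.rs`, condensed into one helper as a hedged sketch (the `zip::tokio` paths and the `Arc<PathBuf>`-taking `extract` signature are as used in this diff, not a published API; `extract_async` itself is hypothetical):

```rust
use std::{path::PathBuf, pin::Pin, sync::Arc};

use tokio::fs;
use zip::result::ZipResult;

async fn extract_async(archive: PathBuf, out: Arc<PathBuf>) -> ZipResult<()> {
    let handle = fs::File::open(&archive).await?;
    // Construction performs IO (locating the central directory), hence
    // the .await on new().
    let mut src = zip::tokio::read::ZipArchive::new(Box::pin(handle)).await?;
    // extract() takes a pinned &mut self and a shared output root, which
    // is why main.rs wraps the destination in an Arc and clones it.
    Pin::new(&mut src).extract(out).await?;
    Ok(())
}
```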