diff --git a/src/c_api.rs b/src/c_api.rs index 95b75d6d..d4c9b728 100644 --- a/src/c_api.rs +++ b/src/c_api.rs @@ -342,3 +342,22 @@ pub unsafe extern "C" fn deflateInit2_( ) -> libc::c_int { crate::deflate::init2(strm, level, method, windowBits, memLevel, strategy) as _ } + +pub unsafe extern "C" fn deflateInit2( + strm: z_streamp, + level: c_int, + method: c_int, + windowBits: c_int, + memLevel: c_int, + strategy: c_int, +) -> libc::c_int { + crate::deflate::init2(strm, level, method, windowBits, memLevel, strategy) as _ +} + +pub unsafe extern "C" fn deflateReset(strm: *mut z_stream) -> i32 { + if let Some(stream) = DeflateStream::from_stream_mut(strm) { + crate::deflate::reset(stream) as _ + } else { + ReturnCode::StreamError as _ + } +} diff --git a/src/deflate.rs b/src/deflate.rs index fe967fc5..e0fba03e 100644 --- a/src/deflate.rs +++ b/src/deflate.rs @@ -304,7 +304,7 @@ pub unsafe fn end(strm: *mut z_stream) -> i32 { } } -fn reset(stream: &mut DeflateStream) -> ReturnCode { +pub fn reset(stream: &mut DeflateStream) -> ReturnCode { let ret = reset_keep(stream); if ret == ReturnCode::Ok { diff --git a/tests/deflate.rs b/tests/deflate.rs index e4a36055..d37f31fb 100644 --- a/tests/deflate.rs +++ b/tests/deflate.rs @@ -1,3 +1,8 @@ +use std::{ + ffi::CStr, + sync::atomic::{AtomicU8, Ordering}, +}; + use zlib::*; use libc::{c_char, c_int}; @@ -447,3 +452,141 @@ fn deflate_medium_fizzle_bug() { assert_eq!(&output[..dest_len], EXPECTED); } + +#[test] +fn test_deflate_concurrency() { + static mut BUF: [u8; 8 * 1024] = [0; 8 * 1024]; + + let zbuf: &mut [u8] = &mut [0; 4 * 1024]; + let tmp: &mut [u8] = &mut [0; 8 * 1024]; + + const PAUSED: u8 = 0; + const RUNNING: u8 = 1; + const STOPPED: u8 = 2; + + static STATE: AtomicU8 = AtomicU8::new(PAUSED); + static TARGET_STATE: AtomicU8 = AtomicU8::new(PAUSED); + + std::thread::spawn(|| loop { + STATE.store(TARGET_STATE.load(Ordering::Relaxed), Ordering::Relaxed); + + match STATE.load(Ordering::Relaxed) { + PAUSED => continue, + STOPPED => break, + _ => unsafe { + for i in 0..BUF.len() { + let ptr = BUF.as_mut_ptr().add(i); + *ptr = *ptr + 1; + } + }, + } + }); + + fn transition(target_state: u8) { + TARGET_STATE.store(target_state, Ordering::Relaxed); + + while STATE.load(Ordering::Relaxed) != target_state { + std::hint::spin_loop() + } + } + + unsafe { + let mut dstrm = zlib::z_stream::default(); + + let err = deflateInit2( + &mut dstrm, + Z_BEST_SPEED, + Z_DEFLATED, + -15, + 8, + Z_DEFAULT_STRATEGY, + ); + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(dstrm.msg)); + + let mut istrm = zlib::z_stream::default(); + let err = inflateInit2(&mut istrm, -15); + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(istrm.msg)); + + // Iterate for a certain amount of time. + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1); + while std::time::Instant::now() < deadline { + /* Start each iteration with a fresh stream state. */ + let err = deflateReset(&mut dstrm); + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(dstrm.msg)); + + let err = inflateReset(&mut istrm); + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(istrm.msg)); + + // Mutate and compress the first half of buf concurrently. + // Decompress and throw away the results, which are unpredictable. + transition(RUNNING); + dstrm.next_in = BUF.as_mut_ptr(); + dstrm.avail_in = (BUF.len() / 2) as _; + while dstrm.avail_in > 0 { + dstrm.next_out = zbuf.as_mut_ptr(); + dstrm.avail_out = zbuf.len() as _; + let err = deflate(&mut dstrm, Z_NO_FLUSH); + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(dstrm.msg)); + istrm.next_in = zbuf.as_mut_ptr(); + istrm.avail_in = (zbuf.len() - dstrm.avail_out as usize) as _; + while istrm.avail_in > 0 { + istrm.next_out = tmp.as_mut_ptr(); + istrm.avail_out = tmp.len() as _; + let err = inflate(&mut istrm, Z_NO_FLUSH); + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(istrm.msg)); + } + } + + // Stop mutation and compress the second half of buf. + // Decompress and check that the result matches. + transition(PAUSED); + dstrm.next_in = BUF.as_mut_ptr().add(BUF.len() / 2); + dstrm.avail_in = (BUF.len() - BUF.len() / 2) as _; + while dstrm.avail_in > 0 { + dstrm.next_out = zbuf.as_mut_ptr(); + dstrm.avail_out = zbuf.len() as _; + dbg!(dstrm.avail_in); + let err = deflate(&mut dstrm, Z_FINISH); + if err == Z_STREAM_END { + assert_eq!(0, dstrm.avail_in); + } else { + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(dstrm.msg)); + } + istrm.next_in = zbuf.as_mut_ptr(); + istrm.avail_in = (zbuf.len() - dstrm.avail_out as usize) as _; + while istrm.avail_in > 0 { + let orig_total_out = istrm.total_out as usize; + istrm.next_out = tmp.as_mut_ptr(); + istrm.avail_out = tmp.len() as _; + let err = inflate(&mut istrm, Z_NO_FLUSH); + if err == Z_STREAM_END { + assert_eq!(0, istrm.avail_in); + } else { + assert_eq!(Z_OK, err, "{:?}", CStr::from_ptr(dstrm.msg)); + } + let concurrent_size = BUF.len() - BUF.len() / 2; + if istrm.total_out as usize > concurrent_size { + let (tmp_offset, buf_offset, size) = if orig_total_out >= concurrent_size { + ( + 0, + orig_total_out - concurrent_size, + istrm.total_out as usize - orig_total_out, + ) + } else { + ( + concurrent_size - orig_total_out, + 0, + istrm.total_out as usize - concurrent_size, + ) + }; + + let a = &tmp[tmp_offset as usize..][..size]; + let b = &BUF[BUF.len() / 2 + buf_offset..][..size]; + + assert_eq!(a, b); + } + } + } + } + } +}