Skip to content

Commit 22b7fc6

Browse files
bjorn3folkertdev
authored andcommitted
Use safe types for the input buffer in a bunch of places in zstd_v07
1 parent 8f9c19f commit 22b7fc6

File tree

1 file changed

+48
-65
lines changed

1 file changed

+48
-65
lines changed

lib/legacy/zstd_v07.rs

Lines changed: 48 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use libc::{calloc, free, malloc};
66

77
use crate::lib::common::error_private::Error;
88
use crate::lib::common::mem::{MEM_32bits, MEM_64bits, MEM_readLE32, MEM_readLEST};
9+
use crate::lib::common::reader::Reader;
910
use crate::lib::common::xxhash::{
1011
XXH64_state_t, ZSTD_XXH64_digest, ZSTD_XXH64_reset, ZSTD_XXH64_update, ZSTD_XXH64_update_slice,
1112
};
@@ -1813,12 +1814,11 @@ fn ZSTDv07_copyRawBlock(mut dst: Writer<'_>, src: &[u8]) -> Result<usize, Error>
18131814
}
18141815
Ok(src.len())
18151816
}
1816-
unsafe fn ZSTDv07_decodeLiteralsBlock(
1817-
dctx: &mut ZSTDv07_DCtx,
1818-
srcPtr: *const u8,
1819-
srcSize: usize,
1820-
) -> Result<usize, Error> {
1821-
let src = unsafe { core::slice::from_raw_parts(srcPtr, srcSize) };
1817+
1818+
/// # Safety
1819+
///
1820+
/// `src` must outlive the last decompress call that covers the same compressed block.
1821+
unsafe fn ZSTDv07_decodeLiteralsBlock(dctx: &mut ZSTDv07_DCtx, src: &[u8]) -> Result<usize, Error> {
18221822
if src.len() < MIN_CBLOCK_SIZE as usize {
18231823
return Err(Error::corruption_detected);
18241824
}
@@ -1942,7 +1942,7 @@ unsafe fn ZSTDv07_decodeLiteralsBlock(
19421942
dctx.litBuffer[dctx.litSize..dctx.litSize + WILDCOPY_OVERLENGTH].fill(0);
19431943
return Ok(lhSize.wrapping_add(litSize));
19441944
}
1945-
dctx.litPtr = unsafe { srcPtr.add(lhSize) };
1945+
dctx.litPtr = src[lhSize..].as_ptr();
19461946
dctx.litSize = litSize;
19471947
Ok(lhSize + litSize)
19481948
}
@@ -1978,6 +1978,7 @@ unsafe fn ZSTDv07_decodeLiteralsBlock(
19781978
_ => Err(Error::corruption_detected),
19791979
}
19801980
}
1981+
19811982
fn ZSTDv07_buildSeqTable<const N: usize>(
19821983
DTable: &mut FSEv07_DTable<N>,
19831984
type_0: u32,
@@ -2345,21 +2346,24 @@ unsafe fn ZSTDv07_checkContinuity(dctx: &mut ZSTDv07_DCtx, dst: *const u8) {
23452346
dctx.previousDstEnd = dst;
23462347
}
23472348
}
2349+
2350+
/// # Safety
2351+
///
2352+
/// `src` must outlive the last decompress call that covers the same compressed block.
23482353
unsafe fn ZSTDv07_decompressBlock_internal(
23492354
dctx: &mut ZSTDv07_DCtx,
23502355
dst: Writer<'_>,
2351-
src: *const u8,
2352-
mut srcSize: usize,
2356+
src: &[u8],
23532357
) -> Result<usize, Error> {
23542358
let mut ip = src;
2355-
if srcSize >= ZSTDv07_BLOCKSIZE_ABSOLUTEMAX {
2359+
if src.len() >= ZSTDv07_BLOCKSIZE_ABSOLUTEMAX {
23562360
return Err(Error::srcSize_wrong);
23572361
}
2358-
let litCSize = ZSTDv07_decodeLiteralsBlock(dctx, src, srcSize)?;
2359-
ip = ip.add(litCSize);
2360-
srcSize = srcSize.wrapping_sub(litCSize);
2361-
ZSTDv07_decompressSequences(dctx, dst, core::slice::from_raw_parts(ip, srcSize))
2362+
let litCSize = ZSTDv07_decodeLiteralsBlock(dctx, src)?;
2363+
ip = &ip[litCSize..];
2364+
ZSTDv07_decompressSequences(dctx, dst, ip)
23622365
}
2366+
23632367
fn ZSTDv07_generateNxBytes(mut dst: Writer<'_>, byte: u8, length: usize) -> Result<usize, Error> {
23642368
if length > dst.capacity() {
23652369
return Err(Error::dstSize_tooSmall);
@@ -2420,8 +2424,7 @@ unsafe fn ZSTDv07_decompressFrame(
24202424
bt_compressed => ZSTDv07_decompressBlock_internal(
24212425
dctx,
24222426
Writer::from_range(op, oend),
2423-
ip,
2424-
cBlockSize,
2427+
core::slice::from_raw_parts(ip, cBlockSize),
24252428
)?,
24262429
bt_raw => ZSTDv07_copyRawBlock(
24272430
Writer::from_range(op, oend),
@@ -2544,23 +2547,24 @@ fn ZSTDv07_isSkipFrame(dctx: &ZSTDv07_DCtx) -> bool {
25442547
unsafe fn ZSTDv07_decompressContinue(
25452548
dctx: &mut ZSTDv07_DCtx,
25462549
mut dst: Writer<'_>,
2547-
src: *const core::ffi::c_void,
2548-
srcSize: usize,
2550+
src: Reader<'_>,
25492551
) -> Result<usize, Error> {
2550-
if srcSize != dctx.expected {
2552+
if src.len() != dctx.expected {
25512553
return Err(Error::srcSize_wrong);
25522554
}
25532555
if dst.capacity() != 0 {
25542556
ZSTDv07_checkContinuity(dctx, dst.as_ptr());
25552557
}
25562558
match dctx.stage as core::ffi::c_uint {
25572559
0 => {
2558-
if srcSize != ZSTDv07_frameHeaderSize_min {
2560+
if src.len() != ZSTDv07_frameHeaderSize_min {
25592561
return Err(Error::srcSize_wrong);
25602562
}
2561-
if MEM_readLE32(src) & 0xfffffff0 == ZSTDv07_MAGIC_SKIPPABLE_START {
2563+
if u32::from_le_bytes(src.subslice(..4).as_slice().try_into().unwrap()) & 0xfffffff0
2564+
== ZSTDv07_MAGIC_SKIPPABLE_START
2565+
{
25622566
ptr::copy_nonoverlapping(
2563-
src as *const u8,
2567+
src.as_ptr(),
25642568
dctx.headerBuffer.as_mut_ptr(),
25652569
ZSTDv07_frameHeaderSize_min,
25662570
);
@@ -2569,12 +2573,10 @@ unsafe fn ZSTDv07_decompressContinue(
25692573
dctx.stage = ZSTDds_decodeSkippableHeader;
25702574
return Ok(0);
25712575
}
2572-
dctx.headerSize = ZSTDv07_frameHeaderSize(core::slice::from_raw_parts(
2573-
src.cast::<u8>(),
2574-
ZSTDv07_frameHeaderSize_min,
2575-
))?;
2576+
dctx.headerSize =
2577+
ZSTDv07_frameHeaderSize(src.subslice(..ZSTDv07_frameHeaderSize_min).as_slice())?;
25762578
ptr::copy_nonoverlapping(
2577-
src as *const u8,
2579+
src.as_ptr(),
25782580
dctx.headerBuffer.as_mut_ptr(),
25792581
ZSTDv07_frameHeaderSize_min,
25802582
);
@@ -2591,19 +2593,17 @@ unsafe fn ZSTDv07_decompressContinue(
25912593
blockType: bt_compressed,
25922594
origSize: 0,
25932595
};
2594-
let cBlockSize = ZSTDv07_getcBlockSize(
2595-
core::slice::from_raw_parts(src.cast::<u8>(), ZSTDv07_blockHeaderSize),
2596-
&mut bp,
2597-
)?;
2596+
let cBlockSize =
2597+
ZSTDv07_getcBlockSize(src.subslice(..ZSTDv07_blockHeaderSize).as_slice(), &mut bp)?;
25982598
if bp.blockType == bt_end {
25992599
if dctx.fParams.checksumFlag != 0 {
26002600
let h64 = ZSTD_XXH64_digest(&mut dctx.xxhState);
2601-
let h32 = (h64 >> 11) as u32 & (((1) << 22) - 1) as u32;
2602-
let ip = src as *const u8;
2603-
let check32 = (*ip.add(2) as core::ffi::c_int
2604-
+ ((*ip.add(1) as core::ffi::c_int) << 8)
2605-
+ ((*ip as core::ffi::c_int & 0x3f as core::ffi::c_int) << 16))
2606-
as u32;
2601+
let h32 = (h64 >> 11) as u32 & ((1 << 22) - 1) as u32;
2602+
let ip = src.subslice(..3);
2603+
let ip = ip.as_slice();
2604+
let check32 = u32::from(ip[2])
2605+
+ (u32::from(ip[1]) << 8)
2606+
+ ((u32::from(ip[0]) & 0x3f) << 16);
26072607
if check32 != h32 {
26082608
return Err(Error::checksum_wrong);
26092609
}
@@ -2619,16 +2619,8 @@ unsafe fn ZSTDv07_decompressContinue(
26192619
}
26202620
3 => {
26212621
let rSize = match dctx.bType {
2622-
0 => ZSTDv07_decompressBlock_internal(
2623-
dctx,
2624-
dst.subslice(..),
2625-
src.cast::<u8>(),
2626-
srcSize,
2627-
),
2628-
1 => ZSTDv07_copyRawBlock(
2629-
dst.subslice(..),
2630-
core::slice::from_raw_parts(src.cast::<u8>(), srcSize),
2631-
),
2622+
0 => ZSTDv07_decompressBlock_internal(dctx, dst.subslice(..), src.as_slice()),
2623+
1 => ZSTDv07_copyRawBlock(dst.subslice(..), src.as_slice()),
26322624
2 => return Err(Error::GENERIC),
26332625
3 => Ok(0),
26342626
_ => return Err(Error::GENERIC),
@@ -2647,15 +2639,12 @@ unsafe fn ZSTDv07_decompressContinue(
26472639
}
26482640
4 => {
26492641
ptr::copy_nonoverlapping(
2650-
src as *const u8,
2651-
dctx.headerBuffer
2652-
.as_mut_ptr()
2653-
.add(ZSTDv07_frameHeaderSize_min),
2642+
src.as_ptr(),
2643+
dctx.headerBuffer[ZSTDv07_frameHeaderSize_min..].as_mut_ptr(),
26542644
dctx.expected,
26552645
);
26562646
dctx.expected =
2657-
MEM_readLE32(dctx.headerBuffer.as_mut_ptr().add(4) as *const core::ffi::c_void)
2658-
as usize;
2647+
u32::from_le_bytes(dctx.headerBuffer[4..8].try_into().unwrap()) as usize;
26592648
dctx.stage = ZSTDds_skipFrame;
26602649
return Ok(0);
26612650
}
@@ -2667,10 +2656,8 @@ unsafe fn ZSTDv07_decompressContinue(
26672656
_ => return Err(Error::GENERIC),
26682657
}
26692658
ptr::copy_nonoverlapping(
2670-
src as *const u8,
2671-
dctx.headerBuffer
2672-
.as_mut_ptr()
2673-
.add(ZSTDv07_frameHeaderSize_min),
2659+
src.as_ptr(),
2660+
dctx.headerBuffer[ZSTDv07_frameHeaderSize_min..].as_mut_ptr(),
26742661
dctx.expected,
26752662
);
26762663
ZSTDv07_decodeFrameHeader(dctx, &(&dctx.headerBuffer)[..dctx.headerSize])?;
@@ -2933,17 +2920,15 @@ pub(crate) unsafe fn ZBUFFv07_decompressContinue(
29332920
ZSTDv07_decompressContinue(
29342921
&mut *zbd.zd,
29352922
Writer::from_slice(&mut []),
2936-
zbd.headerBuffer.as_mut_ptr() as *const core::ffi::c_void,
2937-
h1Size,
2923+
Reader::from_raw_parts(zbd.headerBuffer.as_ptr(), h1Size),
29382924
)?;
29392925
if h1Size < zbd.lhSize {
29402926
// long header
29412927
let h2Size = ZSTDv07_nextSrcSizeToDecompress(&*zbd.zd);
29422928
ZSTDv07_decompressContinue(
29432929
&mut *zbd.zd,
29442930
Writer::from_slice(&mut []),
2945-
zbd.headerBuffer.as_mut_ptr().add(h1Size) as *const core::ffi::c_void,
2946-
h2Size,
2931+
Reader::from_raw_parts(zbd.headerBuffer.as_ptr().add(h1Size), h2Size),
29472932
)?;
29482933
}
29492934
zbd.fParams.windowSize = core::cmp::max(zbd.fParams.windowSize, 1 << 10);
@@ -3005,8 +2990,7 @@ pub(crate) unsafe fn ZBUFFv07_decompressContinue(
30052990
Writer::from_raw_parts(zbd.outBuff, zbd.outBuffSize)
30062991
.subslice(zbd.outStart..)
30072992
},
3008-
ip as *const core::ffi::c_void,
3009-
neededInSize,
2993+
Reader::from_raw_parts(ip, neededInSize),
30102994
)?;
30112995
ip = ip.add(neededInSize);
30122996
if decodedSize == 0 && !isSkipFrame {
@@ -3053,8 +3037,7 @@ pub(crate) unsafe fn ZBUFFv07_decompressContinue(
30533037
let decodedSize_0 = ZSTDv07_decompressContinue(
30543038
&mut *zbd.zd,
30553039
Writer::from_raw_parts(zbd.outBuff, zbd.outBuffSize).subslice(zbd.outStart..),
3056-
zbd.inBuff as *const core::ffi::c_void,
3057-
neededInSize_0,
3040+
Reader::from_raw_parts(zbd.inBuff, neededInSize_0),
30583041
)?;
30593042
zbd.inPos = 0; // input is consumed
30603043
if decodedSize_0 == 0 && !isSkipFrame_0 {

0 commit comments

Comments
 (0)