Skip to content

Commit 795d438

Browse files
michielp1807folkertdev
authored andcommitted
find_frame_size_info: use Result
1 parent b5d8064 commit 795d438

File tree

1 file changed

+40
-60
lines changed

1 file changed

+40
-60
lines changed

lib/decompress/zstd_decompress.rs

Lines changed: 40 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ unsafe fn ZSTD_decompressLegacy(
288288
}
289289

290290
// FIXME: this should be totally safe at this point.
291-
unsafe fn find_frame_size_info_legacy(src: &[u8]) -> ZSTD_frameSizeInfo {
291+
unsafe fn find_frame_size_info_legacy(src: &[u8]) -> Result<ZSTD_frameSizeInfo, Error> {
292292
let mut frameSizeInfo = ZSTD_frameSizeInfo::default();
293293

294294
match is_legacy(src) {
@@ -316,14 +316,16 @@ unsafe fn find_frame_size_info_legacy(src: &[u8]) -> ZSTD_frameSizeInfo {
316316
);
317317
}
318318
_ => {
319-
frameSizeInfo.compressedSize = Error::prefix_unknown.to_error_code();
320-
frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR;
319+
return Err(Error::prefix_unknown);
321320
}
322321
}
323322

324-
if !ERR_isError(frameSizeInfo.compressedSize) && frameSizeInfo.compressedSize > src.len() {
325-
frameSizeInfo.compressedSize = Error::srcSize_wrong.to_error_code();
326-
frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR;
323+
if let Some(err) = Error::from_error_code(frameSizeInfo.compressedSize) {
324+
return Err(err);
325+
}
326+
327+
if frameSizeInfo.compressedSize > src.len() {
328+
return Err(Error::srcSize_wrong);
327329
}
328330

329331
if frameSizeInfo.decompressedBound != ZSTD_CONTENTSIZE_ERROR {
@@ -332,7 +334,7 @@ unsafe fn find_frame_size_info_legacy(src: &[u8]) -> ZSTD_frameSizeInfo {
332334
as size_t;
333335
}
334336

335-
frameSizeInfo
337+
Ok(frameSizeInfo)
336338
}
337339

338340
#[inline]
@@ -1329,10 +1331,10 @@ fn find_decompressed_size(mut src: &[u8]) -> u64 {
13291331
};
13301332

13311333
// skip to next frame
1332-
let frameSrcSize = ZSTD_findFrameCompressedSize_advanced(src, Format::ZSTD_f_zstd1);
1333-
if ERR_isError(frameSrcSize) {
1334+
let Ok(frameSrcSize) = ZSTD_findFrameCompressedSize_advanced(src, Format::ZSTD_f_zstd1)
1335+
else {
13341336
return ZSTD_CONTENTSIZE_ERROR;
1335-
}
1337+
};
13361338
src = &src[frameSrcSize..];
13371339
}
13381340
}
@@ -1399,15 +1401,7 @@ fn ZSTD_decodeFrameHeader(dctx: &mut ZSTD_DCtx, src: &[u8]) -> Result<size_t, Er
13991401
Ok(0)
14001402
}
14011403

1402-
fn ZSTD_errorFrameSizeInfo(ret: size_t) -> ZSTD_frameSizeInfo {
1403-
ZSTD_frameSizeInfo {
1404-
nbBlocks: 0,
1405-
compressedSize: ret,
1406-
decompressedBound: ZSTD_CONTENTSIZE_ERROR,
1407-
}
1408-
}
1409-
1410-
fn find_frame_size_info(src: &[u8], format: Format) -> ZSTD_frameSizeInfo {
1404+
fn find_frame_size_info(src: &[u8], format: Format) -> Result<ZSTD_frameSizeInfo, Error> {
14111405
let mut frameSizeInfo = ZSTD_frameSizeInfo::default();
14121406

14131407
if format == Format::ZSTD_f_zstd1 && is_legacy(src) != 0 {
@@ -1418,25 +1412,19 @@ fn find_frame_size_info(src: &[u8], format: Format) -> ZSTD_frameSizeInfo {
14181412
&& src.len() >= ZSTD_SKIPPABLEHEADERSIZE as usize
14191413
&& is_skippable_frame(src)
14201414
{
1421-
frameSizeInfo.compressedSize =
1422-
read_skippable_frame_size(src).unwrap_or_else(Error::to_error_code);
1423-
debug_assert!(
1424-
ERR_isError(frameSizeInfo.compressedSize) || frameSizeInfo.compressedSize <= src.len()
1425-
);
1426-
frameSizeInfo
1415+
frameSizeInfo.compressedSize = read_skippable_frame_size(src)?;
1416+
debug_assert!(frameSizeInfo.compressedSize <= src.len());
1417+
Ok(frameSizeInfo)
14271418
} else {
14281419
let mut ip = 0;
14291420
let mut remainingSize = src.len();
14301421
let mut nbBlocks = 0usize;
14311422
let mut zfh = ZSTD_FrameHeader::default();
14321423

14331424
// extract Frame Header
1434-
let ret = match get_frame_header_advanced(&mut zfh, src, format) {
1435-
Ok(ret) => ret,
1436-
Err(err) => return ZSTD_errorFrameSizeInfo(err.to_error_code()),
1437-
};
1425+
let ret = get_frame_header_advanced(&mut zfh, src, format)?;
14381426
if ret > 0 {
1439-
return ZSTD_errorFrameSizeInfo(Error::srcSize_wrong.to_error_code());
1427+
return Err(Error::srcSize_wrong);
14401428
}
14411429

14421430
ip += zfh.headerSize as usize;
@@ -1445,12 +1433,9 @@ fn find_frame_size_info(src: &[u8], format: Format) -> ZSTD_frameSizeInfo {
14451433
// iterate over each block
14461434
loop {
14471435
let mut blockProperties = blockProperties_t::default();
1448-
let cBlockSize = match ZSTD_getcBlockSize(&src[ip..], &mut blockProperties) {
1449-
Ok(size) => size,
1450-
Err(err) => return ZSTD_errorFrameSizeInfo(err.to_error_code()),
1451-
};
1436+
let cBlockSize = ZSTD_getcBlockSize(&src[ip..], &mut blockProperties)?;
14521437
if ZSTD_blockHeaderSize.wrapping_add(cBlockSize) > remainingSize {
1453-
return ZSTD_errorFrameSizeInfo(Error::srcSize_wrong.to_error_code());
1438+
return Err(Error::srcSize_wrong);
14541439
}
14551440

14561441
ip += ZSTD_blockHeaderSize.wrapping_add(cBlockSize) as usize;
@@ -1466,7 +1451,7 @@ fn find_frame_size_info(src: &[u8], format: Format) -> ZSTD_frameSizeInfo {
14661451
// final frame content checksum
14671452
if zfh.checksumFlag != 0 {
14681453
if remainingSize < 4 {
1469-
return ZSTD_errorFrameSizeInfo(Error::srcSize_wrong.to_error_code());
1454+
return Err(Error::srcSize_wrong);
14701455
}
14711456
ip += 4;
14721457
}
@@ -1479,12 +1464,12 @@ fn find_frame_size_info(src: &[u8], format: Format) -> ZSTD_frameSizeInfo {
14791464
(nbBlocks as core::ffi::c_ulonglong)
14801465
.wrapping_mul(zfh.blockSizeMax as core::ffi::c_ulonglong)
14811466
};
1482-
frameSizeInfo
1467+
Ok(frameSizeInfo)
14831468
}
14841469
}
14851470

1486-
fn ZSTD_findFrameCompressedSize_advanced(src: &[u8], format: Format) -> size_t {
1487-
find_frame_size_info(src, format).compressedSize
1471+
fn ZSTD_findFrameCompressedSize_advanced(src: &[u8], format: Format) -> Result<size_t, Error> {
1472+
Ok(find_frame_size_info(src, format)?.compressedSize)
14881473
}
14891474

14901475
/// Find frame compressed size, compatible with legacy mode
@@ -1522,6 +1507,7 @@ pub unsafe extern "C" fn ZSTD_findFrameCompressedSize(
15221507
};
15231508

15241509
ZSTD_findFrameCompressedSize_advanced(src, Format::ZSTD_f_zstd1)
1510+
.unwrap_or_else(|err| err.to_error_code())
15251511
}
15261512

15271513
/// Get an upper-bound on the decompressed size
@@ -1558,14 +1544,11 @@ fn decompress_bound(mut src: &[u8]) -> core::ffi::c_ulonglong {
15581544

15591545
// iterate over each frame
15601546
while !src.is_empty() {
1561-
let frameSizeInfo = find_frame_size_info(src, Format::ZSTD_f_zstd1);
1562-
let compressedSize = frameSizeInfo.compressedSize;
1563-
let decompressedBound = frameSizeInfo.decompressedBound;
1564-
if ERR_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR {
1547+
let Ok(frameSizeInfo) = find_frame_size_info(src, Format::ZSTD_f_zstd1) else {
15651548
return ZSTD_CONTENTSIZE_ERROR;
1566-
}
1567-
src = &src[compressedSize as usize..];
1568-
bound += decompressedBound;
1549+
};
1550+
src = &src[frameSizeInfo.compressedSize..];
1551+
bound += frameSizeInfo.decompressedBound;
15691552
}
15701553

15711554
bound
@@ -1607,26 +1590,22 @@ pub unsafe extern "C" fn ZSTD_decompressionMargin(
16071590
} else {
16081591
core::slice::from_raw_parts(src.cast(), srcSize)
16091592
})
1593+
.unwrap_or_else(|err| err.to_error_code())
16101594
}
16111595

1612-
fn decompression_margin(mut src: &[u8]) -> size_t {
1596+
fn decompression_margin(mut src: &[u8]) -> Result<size_t, Error> {
16131597
let mut margin = 0;
16141598
let mut maxBlockSize = 0;
16151599

16161600
// iterate over each frame
16171601
while !src.is_empty() {
16181602
let frameSizeInfo = find_frame_size_info(src, Format::ZSTD_f_zstd1);
1619-
let compressedSize = frameSizeInfo.compressedSize;
1620-
let decompressedBound = frameSizeInfo.decompressedBound;
16211603

16221604
let mut zfh = ZSTD_FrameHeader::default();
1623-
if let Err(err) = get_frame_header(&mut zfh, src) {
1624-
return err.to_error_code();
1625-
};
1605+
get_frame_header(&mut zfh, src)?;
16261606

1627-
if ERR_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR {
1628-
return Error::corruption_detected.to_error_code();
1629-
}
1607+
let frameSizeInfo = frameSizeInfo.map_err(|_| Error::corruption_detected)?;
1608+
let compressedSize = frameSizeInfo.compressedSize;
16301609

16311610
if zfh.frameType as core::ffi::c_uint == ZSTD_frame as core::ffi::c_uint {
16321611
// add the frame header to our margin
@@ -1640,13 +1619,13 @@ fn decompression_margin(mut src: &[u8]) -> size_t {
16401619
margin += compressedSize;
16411620
}
16421621

1643-
src = &src[compressedSize as usize..];
1622+
src = &src[compressedSize..];
16441623
}
16451624

16461625
// add the max block size back to the margin
16471626
margin += maxBlockSize as size_t;
16481627

1649-
margin
1628+
Ok(margin)
16501629
}
16511630

16521631
/// Insert `src` block into `dctx` history. Useful to track uncompressed blocks.
@@ -1893,9 +1872,8 @@ unsafe fn ZSTD_decompressMultiFrame<'a>(
18931872

18941873
while src.len() >= ZSTD_startingInputLength((*dctx).format) {
18951874
if (*dctx).format == Format::ZSTD_f_zstd1 && is_legacy(src.as_slice()) != 0 {
1896-
let frameSizeInfo = find_frame_size_info_legacy(src.as_slice());
1875+
let frameSizeInfo = find_frame_size_info_legacy(src.as_slice())?;
18971876
let frameSize = frameSizeInfo.compressedSize;
1898-
Error::from_error_code(frameSize).map_or(Ok(()), Err)?;
18991877

19001878
if (*dctx).staticSize != 0 {
19011879
return Err(Error::memory_allocation);
@@ -3424,7 +3402,9 @@ pub unsafe extern "C" fn ZSTD_decompressStream(
34243402
iend.offset_from_unsigned(istart),
34253403
),
34263404
zds.format,
3427-
);
3405+
)
3406+
.unwrap_or_else(Error::to_error_code);
3407+
34283408
if cSize <= iend.offset_from_unsigned(istart) {
34293409
let decompressedSize = ZSTD_decompress_usingDDict(
34303410
zds,

0 commit comments

Comments
 (0)