Skip to content

Commit 3955cca

Browse files
bjorn3folkertdev
authored andcommitted
Make HUFv07_decompress4X4_usingDTable_internal safe
1 parent 2ed504d commit 3955cca

File tree

2 files changed

+37
-46
lines changed

2 files changed

+37
-46
lines changed

lib/decompress/huf_decompress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2224,7 +2224,7 @@ impl<'a> Writer<'a> {
22242224
self.ptr = unsafe { NonNull::new(ptr.as_ptr().add(1)) }
22252225
}
22262226

2227-
fn write_symbol_x2(&mut self, value: u16, length: u8) {
2227+
pub(crate) fn write_symbol_x2(&mut self, value: u16, length: u8) {
22282228
debug_assert!(length <= 2);
22292229

22302230
let Some(ptr) = self.ptr else {

lib/legacy/zstd_v07.rs

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,25 +1189,24 @@ fn HUFv07_readDTableX4(DTable: &mut HUFv07_DTable, src: &[u8]) -> Result<usize,
11891189
DTable.description = dtd;
11901190
Ok(iSize)
11911191
}
1192-
unsafe fn HUFv07_decodeSymbolX4(
1192+
fn HUFv07_decodeSymbolX4(
11931193
dst: &mut Writer<'_>,
11941194
DStream: &mut BITv07_DStream_t,
11951195
dt: &[HUFv07_DEltX4; 4096],
11961196
dtLog: u32,
11971197
) {
11981198
let val = DStream.look_bits_fast(dtLog);
1199-
ptr::write(dst.as_mut_ptr() as *mut [u8; 2], dt[val].sequence.0);
1199+
dst.write_symbol_x2(u16::from_le_bytes(dt[val].sequence.0), dt[val].length);
12001200
DStream.skip_bits(dt[val].nbBits as u32);
1201-
*dst = dst.subslice(usize::from(dt[val].length)..);
12021201
}
1203-
unsafe fn HUFv07_decodeLastSymbolX4(
1202+
fn HUFv07_decodeLastSymbolX4(
12041203
dst: &mut Writer<'_>,
12051204
DStream: &mut BITv07_DStream_t,
12061205
dt: &[HUFv07_DEltX4],
12071206
dtLog: u32,
12081207
) {
12091208
let val = DStream.look_bits_fast(dtLog);
1210-
ptr::write(dst.as_mut_ptr(), dt[val].sequence.0[0]);
1209+
dst.write_u8(dt[val].sequence.0[0]);
12111210
if (dt[val]).length == 1 {
12121211
DStream.skip_bits(dt[val].nbBits as u32);
12131212
} else if DStream.bitsConsumed < usize::BITS {
@@ -1228,24 +1227,24 @@ fn HUFv07_decodeStreamX4(
12281227
let dst_capacity = dst.capacity();
12291228
while bitDPtr.reload() == StreamStatus::Unfinished && dst.capacity() >= 8 {
12301229
if MEM_64bits() {
1231-
unsafe { HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog) };
1230+
HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog);
12321231
}
12331232
if MEM_64bits() || HUFv07_TABLELOG_MAX <= 12 {
1234-
unsafe { HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog) };
1233+
HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog);
12351234
}
12361235
if MEM_64bits() {
1237-
unsafe { HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog) };
1236+
HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog);
12381237
}
1239-
unsafe { HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog) };
1238+
HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog);
12401239
}
12411240
while bitDPtr.reload() == StreamStatus::Unfinished && dst.capacity() >= 2 {
1242-
unsafe { HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog) };
1241+
HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog);
12431242
}
12441243
while dst.capacity() >= 2 {
1245-
unsafe { HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog) };
1244+
HUFv07_decodeSymbolX4(&mut dst, bitDPtr, dt, dtLog);
12461245
}
12471246
if dst.capacity() > 0 {
1248-
unsafe { HUFv07_decodeLastSymbolX4(&mut dst, bitDPtr, dt, dtLog) };
1247+
HUFv07_decodeLastSymbolX4(&mut dst, bitDPtr, dt, dtLog);
12491248
}
12501249
dst_capacity - dst.capacity()
12511250
}
@@ -1274,48 +1273,40 @@ fn HUFv07_decompress1X4_usingDTable(
12741273
}
12751274
HUFv07_decompress1X4_usingDTable_internal(dst, cSrc, DTable)
12761275
}
1277-
unsafe fn HUFv07_decompress4X4_usingDTable_internal(
1276+
fn HUFv07_decompress4X4_usingDTable_internal(
12781277
mut dst: Writer<'_>,
12791278
cSrc: &[u8],
12801279
DTable: &HUFv07_DTable,
12811280
) -> Result<usize, Error> {
12821281
if cSrc.len() < 10 {
12831282
return Err(Error::corruption_detected);
12841283
}
1285-
let istart = cSrc.as_ptr();
1284+
let mut ip = cSrc;
12861285
let dstSize = dst.capacity();
1287-
let ostart = dst.as_mut_ptr();
1288-
let oend = ostart.add(dstSize);
12891286
let dt = DTable.as_x4();
1290-
let length1 = MEM_readLE16(istart as *const core::ffi::c_void) as usize;
1291-
let length2 = MEM_readLE16(istart.add(2) as *const core::ffi::c_void) as usize;
1292-
let length3 = MEM_readLE16(istart.add(4) as *const core::ffi::c_void) as usize;
1293-
let length4 = cSrc.len().wrapping_sub(
1294-
length1
1295-
.wrapping_add(length2)
1296-
.wrapping_add(length3)
1297-
.wrapping_add(6),
1298-
);
1299-
let istart1 = istart.add(6);
1300-
let istart2 = istart1.add(length1);
1301-
let istart3 = istart2.add(length2);
1302-
let istart4 = istart3.add(length3);
1303-
let segmentSize = dstSize.wrapping_add(3) / 4;
1304-
let opStart2 = ostart.add(segmentSize);
1305-
let opStart3 = opStart2.add(segmentSize);
1306-
let opStart4 = opStart3.add(segmentSize);
1307-
let mut op1 = Writer::from_range(ostart, opStart2);
1308-
let mut op2 = Writer::from_range(opStart2, opStart3);
1309-
let mut op3 = Writer::from_range(opStart3, opStart3);
1310-
let mut op4 = Writer::from_range(opStart4, oend);
1311-
let dtLog = DTable.description.tableLog as u32;
1312-
if length4 > cSrc.len() {
1287+
let length1 = usize::from(u16::from_le_bytes(ip[0..2].try_into().unwrap()));
1288+
let length2 = usize::from(u16::from_le_bytes(ip[2..4].try_into().unwrap()));
1289+
let length3 = usize::from(u16::from_le_bytes(ip[4..6].try_into().unwrap()));
1290+
ip = &ip[6..];
1291+
let (istart1, istart2, istart3, istart4);
1292+
(istart1, ip) = ip
1293+
.split_at_checked(length1)
1294+
.ok_or(Error::corruption_detected)?;
1295+
(istart2, ip) = ip
1296+
.split_at_checked(length2)
1297+
.ok_or(Error::corruption_detected)?;
1298+
(istart3, ip) = ip
1299+
.split_at_checked(length3)
1300+
.ok_or(Error::corruption_detected)?;
1301+
istart4 = ip;
1302+
let Some((mut op1, mut op2, mut op3, mut op4)) = dst.quarter() else {
13131303
return Err(Error::corruption_detected);
1314-
}
1315-
let mut bitD1 = BITv07_DStream_t::new(core::slice::from_raw_parts(istart1, length1))?;
1316-
let mut bitD2 = BITv07_DStream_t::new(core::slice::from_raw_parts(istart2, length2))?;
1317-
let mut bitD3 = BITv07_DStream_t::new(core::slice::from_raw_parts(istart3, length3))?;
1318-
let mut bitD4 = BITv07_DStream_t::new(core::slice::from_raw_parts(istart4, length4))?;
1304+
};
1305+
let dtLog = DTable.description.tableLog as u32;
1306+
let mut bitD1 = BITv07_DStream_t::new(istart1)?;
1307+
let mut bitD2 = BITv07_DStream_t::new(istart2)?;
1308+
let mut bitD3 = BITv07_DStream_t::new(istart3)?;
1309+
let mut bitD4 = BITv07_DStream_t::new(istart4)?;
13191310
let mut endSignal = true;
13201311
endSignal &= bitD1.reload() == StreamStatus::Unfinished;
13211312
endSignal &= bitD2.reload() == StreamStatus::Unfinished;
@@ -1373,7 +1364,7 @@ fn HUFv07_decompress4X4_DCtx(
13731364
if hSize >= cSrc.len() {
13741365
return Err(Error::srcSize_wrong);
13751366
}
1376-
unsafe { HUFv07_decompress4X4_usingDTable_internal(dst, &cSrc[hSize..], dctx) }
1367+
HUFv07_decompress4X4_usingDTable_internal(dst, &cSrc[hSize..], dctx)
13771368
}
13781369

13791370
#[repr(C)]

0 commit comments

Comments
 (0)