Skip to content

Commit ec1d004

Browse files
committed
fn train_from_buffer_fastcover: make safe
1 parent fd724a0 commit ec1d004

File tree

1 file changed

+53
-31
lines changed

1 file changed

+53
-31
lines changed

lib/dictBuilder/fastcover.rs

Lines changed: 53 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -468,16 +468,50 @@ fn FASTCOVER_convertToFastCoverParams(
468468
fastCoverParams.shrinkDict = coverParams.shrinkDict;
469469
}
470470

471+
/// # Safety
472+
///
473+
/// Behavior is undefined if any of the following conditions are violated:
474+
///
475+
/// - `dictBufferCapacity` is 0 or `dictBuffer` and `dictBufferCapacity` satisfy the requirements
476+
/// of [`core::slice::from_raw_parts_mut`].
477+
/// - `nbSamples` is 0 or `samplesSizes` and `nbSamples` satisfy the requirements
478+
/// of [`core::slice::from_raw_parts`].
479+
/// - `sum(samplesSizes)` is 0 or `samplesBuffer` and `sum(samplesSizes)` satisfy the requirements
480+
/// of [`core::slice::from_raw_parts`].
471481
#[cfg_attr(feature = "export-symbols", export_name = crate::prefix!(ZDICT_trainFromBuffer_fastCover))]
472482
pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
473483
dictBuffer: *mut core::ffi::c_void,
474484
dictBufferCapacity: size_t,
475485
samplesBuffer: *const core::ffi::c_void,
476486
samplesSizes: *const size_t,
477487
nbSamples: core::ffi::c_uint,
488+
parameters: ZDICT_fastCover_params_t,
489+
) -> size_t {
490+
let dict = unsafe { core::slice::from_raw_parts_mut(dictBuffer.cast(), dictBufferCapacity) };
491+
492+
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
493+
&[]
494+
} else {
495+
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
496+
};
497+
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
498+
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
499+
&[]
500+
} else {
501+
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
502+
};
503+
504+
train_from_buffer_fastcover(dict, samples, samplesSizes, parameters)
505+
}
506+
507+
fn train_from_buffer_fastcover(
508+
dict: &mut [MaybeUninit<u8>],
509+
samples: &[u8],
510+
samplesSizes: &[usize],
478511
mut parameters: ZDICT_fastCover_params_t,
479512
) -> size_t {
480-
let dict = dictBuffer as *mut u8;
513+
let dictBufferCapacity = dict.len();
514+
481515
let mut ctx = FASTCOVER_ctx_t::default();
482516
let displayLevel = parameters.zParams.notificationLevel as core::ffi::c_int;
483517
parameters.splitPoint = 1.0f64;
@@ -504,7 +538,7 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
504538
}
505539
return Error::parameter_outOfBound.to_error_code();
506540
}
507-
if nbSamples == 0 {
541+
if samplesSizes.is_empty() {
508542
if displayLevel >= 1 {
509543
eprintln!("FASTCOVER must have at least one input file");
510544
}
@@ -518,18 +552,6 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
518552
}
519553
let accelParams = FASTCOVER_defaultAccelParameters[parameters.accel as usize];
520554

521-
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
522-
&[]
523-
} else {
524-
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
525-
};
526-
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
527-
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
528-
&[]
529-
} else {
530-
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
531-
};
532-
533555
let initVal = FASTCOVER_ctx_init(
534556
&mut ctx,
535557
samples,
@@ -553,27 +575,27 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
553575
let mut segmentFreqs: Box<[u16]> = Box::from(vec![0u16; 1 << parameters.f]);
554576

555577
let mut freqs = core::mem::take(&mut ctx.freqs);
556-
let dict_tail = FASTCOVER_buildDictionary(
557-
&ctx,
558-
&mut freqs,
559-
unsafe { core::slice::from_raw_parts_mut(dictBuffer.cast(), dictBufferCapacity) },
560-
coverParams,
561-
&mut segmentFreqs,
562-
);
578+
let dict_tail =
579+
FASTCOVER_buildDictionary(&ctx, &mut freqs, dict, coverParams, &mut segmentFreqs);
563580
ctx.freqs = freqs;
564581

565582
let nbFinalizeSamples =
566583
(ctx.nbTrainSamples * ctx.accelParams.finalize as size_t / 100) as core::ffi::c_uint;
567-
let dictionarySize = ZDICT_finalizeDictionary(
568-
dict as *mut core::ffi::c_void,
569-
dictBufferCapacity,
570-
dict_tail.as_ptr() as *const core::ffi::c_void,
571-
dict_tail.len(),
572-
samplesBuffer,
573-
samplesSizes.as_ptr(),
574-
nbFinalizeSamples,
575-
coverParams.zParams,
576-
);
584+
let customDictContentSize = dict_tail.len();
585+
let dictBuffer = dict.as_mut_ptr() as *mut core::ffi::c_void;
586+
let customDictContent = dictBuffer.wrapping_add(dictBufferCapacity - customDictContentSize);
587+
let dictionarySize = unsafe {
588+
ZDICT_finalizeDictionary(
589+
dictBuffer,
590+
dictBufferCapacity,
591+
customDictContent,
592+
customDictContentSize,
593+
samples.as_ptr() as *const core::ffi::c_void,
594+
samplesSizes.as_ptr(),
595+
nbFinalizeSamples,
596+
coverParams.zParams,
597+
)
598+
};
577599
if !ERR_isError(dictionarySize) && displayLevel >= 2 {
578600
eprintln!("Constructed dictionary of size {}", dictionarySize);
579601
}

0 commit comments

Comments
 (0)