Skip to content

Commit fd724a0

Browse files
committed
fn optimize_train_from_buffer_fastcover: make safe
1 parent 4019e41 commit fd724a0

File tree

1 file changed

+50
-31
lines changed

1 file changed

+50
-31
lines changed

lib/dictBuilder/fastcover.rs

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,13 @@ fn FASTCOVER_ctx_init<'a>(
222222
ctx: &mut FASTCOVER_ctx_t<'a>,
223223
samples: &'a [u8],
224224
samplesSizes: &'a [size_t],
225-
nbSamples: core::ffi::c_uint,
226225
d: core::ffi::c_uint,
227226
splitPoint: core::ffi::c_double,
228227
f: core::ffi::c_uint,
229228
accelParams: FASTCOVER_accel_t,
230229
displayLevel: core::ffi::c_int,
231230
) -> size_t {
231+
let nbSamples = samplesSizes.len() as core::ffi::c_uint;
232232
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
233233
let nbTrainSamples = if splitPoint < 1.0f64 {
234234
(nbSamples as core::ffi::c_double * splitPoint) as core::ffi::c_uint
@@ -534,7 +534,6 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
534534
&mut ctx,
535535
samples,
536536
samplesSizes,
537-
nbSamples,
538537
coverParams.d,
539538
parameters.splitPoint,
540539
parameters.f,
@@ -587,7 +586,13 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
587586
///
588587
/// Behavior is undefined if any of the following conditions are violated:
589588
///
590-
/// - `parameters` satisfies the conditions of [`pointer::as_mut`]
589+
/// - `dictBufferCapacity` is 0 or `dictBuffer` and `dictBufferCapacity` satisfy the requirements
590+
/// of [`core::slice::from_raw_parts_mut`].
591+
/// - `nbSamples` is 0 or `samplesSizes` and `nbSamples` satisfy the requirements
592+
/// of [`core::slice::from_raw_parts`].
593+
/// - `sum(samplesSizes)` is 0 or `samplesBuffer` and `sum(samplesSizes)` satisfy the requirements
594+
/// of [`core::slice::from_raw_parts`].
595+
/// - `parameters` satisfies the requirements of [`pointer::as_mut`]
591596
#[cfg_attr(feature = "export-symbols", export_name = crate::prefix!(ZDICT_optimizeTrainFromBuffer_fastCover))]
592597
pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
593598
dictBuffer: *mut core::ffi::c_void,
@@ -597,8 +602,33 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
597602
nbSamples: core::ffi::c_uint,
598603
parameters: *mut ZDICT_fastCover_params_t,
599604
) -> size_t {
605+
let dict = unsafe { core::slice::from_raw_parts_mut(dictBuffer.cast(), dictBufferCapacity) };
606+
607+
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
608+
&[]
609+
} else {
610+
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
611+
};
612+
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
613+
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
614+
&[]
615+
} else {
616+
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
617+
};
618+
600619
let parameters = unsafe { parameters.as_mut().unwrap() };
601620

621+
optimize_train_from_buffer_fastcover(dict, samples, samplesSizes, parameters)
622+
}
623+
624+
fn optimize_train_from_buffer_fastcover(
625+
dict: &mut [MaybeUninit<u8>],
626+
samples: &[u8],
627+
samplesSizes: &[usize],
628+
parameters: &mut ZDICT_fastCover_params_t,
629+
) -> size_t {
630+
let dictBufferCapacity = dict.len();
631+
602632
let nbThreads = parameters.nbThreads;
603633
let splitPoint = if parameters.splitPoint <= 0.0f64 {
604634
FASTCOVER_DEFAULT_SPLITPOINT
@@ -642,7 +672,6 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
642672
let shrinkDict = 0;
643673
let displayLevel = parameters.zParams.notificationLevel as core::ffi::c_int;
644674
let mut iteration = 1 as core::ffi::c_uint;
645-
let mut pool = core::ptr::null_mut();
646675
let mut warned = 0;
647676
let mut last_update_time = Instant::now();
648677
if splitPoint <= 0.0 || splitPoint > 1.0 {
@@ -663,24 +692,27 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
663692
}
664693
return Error::parameter_outOfBound.to_error_code();
665694
}
666-
if nbSamples == 0 {
695+
if samplesSizes.is_empty() {
667696
if displayLevel >= 1 {
668697
eprintln!("FASTCOVER must have at least one input file");
669698
}
670699
return Error::srcSize_wrong.to_error_code();
671700
}
672-
if dictBufferCapacity < ZDICT_DICTSIZE_MIN as size_t {
701+
if dict.len() < ZDICT_DICTSIZE_MIN as size_t {
673702
if displayLevel >= 1 {
674703
eprintln!("dictBufferCapacity must be at least {}", 256);
675704
}
676705
return Error::dstSize_tooSmall.to_error_code();
677706
}
707+
708+
let mut pool = core::ptr::null_mut();
678709
if nbThreads > 1 {
679-
pool = POOL_create(nbThreads as size_t, 1);
710+
pool = unsafe { POOL_create(nbThreads as size_t, 1) };
680711
if pool.is_null() {
681712
return Error::memory_allocation.to_error_code();
682713
}
683714
}
715+
684716
let best = COVER_best_t::new();
685717
let mut coverParams = ZDICT_cover_params_t::default();
686718
FASTCOVER_convertToCoverParams(*parameters, &mut coverParams);
@@ -689,18 +721,6 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
689721
eprintln!("Trying {} different sets of parameters", kIterations);
690722
}
691723

692-
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
693-
&[]
694-
} else {
695-
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
696-
};
697-
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
698-
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
699-
&[]
700-
} else {
701-
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
702-
};
703-
704724
for d in (kMinD..=kMaxD).step_by(2) {
705725
let mut ctx = FASTCOVER_ctx_t::default();
706726
if displayLevel >= 3 {
@@ -716,7 +736,6 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
716736
&mut ctx,
717737
samples,
718738
samplesSizes,
719-
nbSamples,
720739
d,
721740
splitPoint,
722741
f,
@@ -728,7 +747,7 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
728747
eprintln!("Failed to initialize context");
729748
}
730749
drop(COVER_best_wait(&best));
731-
POOL_free(pool);
750+
unsafe { POOL_free(pool) };
732751
return initVal;
733752
}
734753
if warned == 0 {
@@ -771,11 +790,13 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
771790
} else {
772791
COVER_best_start(&best);
773792
if !pool.is_null() {
774-
POOL_add(
775-
pool,
776-
FASTCOVER_tryParameters_wrapper,
777-
Box::leak(data) as *mut _ as *mut core::ffi::c_void,
778-
);
793+
unsafe {
794+
POOL_add(
795+
pool,
796+
FASTCOVER_tryParameters_wrapper,
797+
Box::leak(data) as *mut _ as *mut core::ffi::c_void,
798+
)
799+
}
779800
} else {
780801
FASTCOVER_tryParameters(data);
781802
}
@@ -806,13 +827,11 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_fastCover(
806827
let dictSize = best.dictSize;
807828
if ERR_isError(best.compressedSize) {
808829
let compressedSize = best.compressedSize;
809-
POOL_free(pool);
830+
unsafe { POOL_free(pool) };
810831
return compressedSize;
811832
}
812833
FASTCOVER_convertToFastCoverParams(best.parameters, parameters, f, accel);
813-
unsafe {
814-
core::ptr::copy_nonoverlapping(best.dict.as_ptr(), dictBuffer.cast::<u8>(), dictSize)
815-
};
816-
POOL_free(pool);
834+
dict[..dictSize].copy_from_slice(super::cover::as_uninit(&best.dict[..dictSize]));
835+
unsafe { POOL_free(pool) };
817836
dictSize
818837
}

0 commit comments

Comments
 (0)