Skip to content

Commit f5ee0ad

Browse files
michielp1807 authored and folkertdev committed
fn ZDICT_trainFromBuffer_cover: make safe
1 parent 4f870b5 commit f5ee0ad

File tree

1 file changed

+53
-34
lines changed

1 file changed

+53
-34
lines changed

lib/dictBuilder/cover.rs

Lines changed: 53 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -555,9 +555,7 @@ fn COVER_ctx_init<'a>(
555555
totalSamplesSize
556556
};
557557
let testSamplesSize = if splitPoint < 1.0f64 {
558-
samplesSizes[nbTrainSamples..][..nbTestSamples]
559-
.iter()
560-
.sum()
558+
samplesSizes[nbTrainSamples..][..nbTestSamples].iter().sum()
561559
} else {
562560
totalSamplesSize
563561
};
@@ -803,16 +801,50 @@ pub(super) const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T]
803801
unsafe { &*(slice as *const [MaybeUninit<T>] as *const [T]) }
804802
}
805803

804+
/// # Safety
805+
///
806+
/// Behavior is undefined if any of the following conditions are violated:
807+
///
808+
/// - `dictBufferCapacity` is 0 or `dictBuffer` and `dictBufferCapacity` satisfy the requirements
809+
/// of [`core::slice::from_raw_parts_mut`].
810+
/// - `nbSamples` is 0 or `samplesSizes` and `nbSamples` satisfy the requirements
811+
/// of [`core::slice::from_raw_parts`].
812+
/// - `sum(samplesSizes)` is 0 or `samplesBuffer` and `sum(samplesSizes)` satisfy the requirements
813+
/// of [`core::slice::from_raw_parts`].
806814
#[cfg_attr(feature = "export-symbols", export_name = crate::prefix!(ZDICT_trainFromBuffer_cover))]
807815
pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
808816
dictBuffer: *mut core::ffi::c_void,
809817
dictBufferCapacity: size_t,
810818
samplesBuffer: *const core::ffi::c_void,
811819
samplesSizes: *const size_t,
812820
nbSamples: core::ffi::c_uint,
821+
parameters: ZDICT_cover_params_t,
822+
) -> size_t {
823+
let dict = unsafe { core::slice::from_raw_parts_mut(dictBuffer.cast(), dictBufferCapacity) };
824+
825+
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
826+
&[]
827+
} else {
828+
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
829+
};
830+
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
831+
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
832+
&[]
833+
} else {
834+
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
835+
};
836+
837+
train_from_buffer_cover(dict, samples, samplesSizes, parameters)
838+
}
839+
840+
fn train_from_buffer_cover(
841+
dict: &mut [MaybeUninit<u8>],
842+
samples: &[u8],
843+
samplesSizes: &[usize],
813844
mut parameters: ZDICT_cover_params_t,
814845
) -> size_t {
815-
let dict = dictBuffer as *mut u8;
846+
let dictBufferCapacity = dict.len();
847+
816848
let mut ctx = COVER_ctx_t::default();
817849
let displayLevel = parameters.zParams.notificationLevel as core::ffi::c_int;
818850
parameters.splitPoint = 1.0f64;
@@ -822,7 +854,7 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
822854
}
823855
return Error::parameter_outOfBound.to_error_code();
824856
}
825-
if nbSamples == 0 {
857+
if samplesSizes.is_empty() {
826858
if displayLevel >= 1 {
827859
eprintln!("Cover must have at least one input file");
828860
}
@@ -835,18 +867,6 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
835867
return Error::dstSize_tooSmall.to_error_code();
836868
}
837869

838-
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
839-
&[]
840-
} else {
841-
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
842-
};
843-
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
844-
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
845-
&[]
846-
} else {
847-
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
848-
};
849-
850870
let initVal = COVER_ctx_init(
851871
&mut ctx,
852872
samples,
@@ -866,25 +886,24 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
866886
}
867887

868888
let mut freqs = core::mem::take(&mut ctx.freqs);
869-
let dict_tail = COVER_buildDictionary(
870-
&ctx,
871-
&mut freqs,
872-
&mut activeDmers,
873-
unsafe { core::slice::from_raw_parts_mut(dictBuffer.cast(), dictBufferCapacity) },
874-
parameters,
875-
);
889+
let dict_tail = COVER_buildDictionary(&ctx, &mut freqs, &mut activeDmers, dict, parameters);
876890
ctx.freqs = freqs;
877891

878-
let dictionarySize = ZDICT_finalizeDictionary(
879-
dict as *mut core::ffi::c_void,
880-
dictBufferCapacity,
881-
dict_tail.as_ptr() as *const core::ffi::c_void,
882-
dict_tail.len(),
883-
samplesBuffer,
884-
samplesSizes.as_ptr(),
885-
nbSamples,
886-
parameters.zParams,
887-
);
892+
let customDictContentSize = dict_tail.len();
893+
let dictBuffer = dict.as_mut_ptr() as *mut core::ffi::c_void;
894+
let customDictContent = dictBuffer.wrapping_add(dictBufferCapacity - customDictContentSize);
895+
let dictionarySize = unsafe {
896+
ZDICT_finalizeDictionary(
897+
dictBuffer,
898+
dictBufferCapacity,
899+
customDictContent,
900+
customDictContentSize,
901+
samples.as_ptr() as *const core::ffi::c_void,
902+
samplesSizes.as_ptr(),
903+
samplesSizes.len() as core::ffi::c_uint,
904+
parameters.zParams,
905+
)
906+
};
888907
if !ERR_isError(dictionarySize) && displayLevel >= 2 {
889908
eprintln!("Constructed dictionary of size {}", dictionarySize,);
890909
}

0 commit comments

Comments
 (0)