@@ -555,9 +555,7 @@ fn COVER_ctx_init<'a>(
555555 totalSamplesSize
556556 } ;
557557 let testSamplesSize = if splitPoint < 1.0f64 {
558- samplesSizes[ nbTrainSamples..] [ ..nbTestSamples]
559- . iter ( )
560- . sum ( )
558+ samplesSizes[ nbTrainSamples..] [ ..nbTestSamples] . iter ( ) . sum ( )
561559 } else {
562560 totalSamplesSize
563561 } ;
@@ -803,16 +801,50 @@ pub(super) const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T]
803801 unsafe { & * ( slice as * const [ MaybeUninit < T > ] as * const [ T ] ) }
804802}
805803
804+ /// # Safety
805+ ///
806+ /// Behavior is undefined if any of the following conditions are violated:
807+ ///
808+ /// - `dictBufferCapacity` is 0 or `dictBuffer` and `dictBufferCapacity` satisfy the requirements
809+ /// of [`core::slice::from_raw_parts_mut`].
810+ /// - `nbSamples` is 0 or `samplesSizes` and `nbSamples` satisfy the requirements
811+ /// of [`core::slice::from_raw_parts`].
812+ /// - `sum(samplesSizes)` is 0 or `samplesBuffer` and `sum(samplesSizes)` satisfy the requirements
813+ /// of [`core::slice::from_raw_parts`].
806814#[ cfg_attr( feature = "export-symbols" , export_name = crate :: prefix!( ZDICT_trainFromBuffer_cover ) ) ]
807815pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover (
808816 dictBuffer : * mut core:: ffi:: c_void ,
809817 dictBufferCapacity : size_t ,
810818 samplesBuffer : * const core:: ffi:: c_void ,
811819 samplesSizes : * const size_t ,
812820 nbSamples : core:: ffi:: c_uint ,
821+ parameters : ZDICT_cover_params_t ,
822+ ) -> size_t {
823+ let dict = unsafe { core:: slice:: from_raw_parts_mut ( dictBuffer. cast ( ) , dictBufferCapacity) } ;
824+
825+ let samplesSizes = if samplesSizes. is_null ( ) || nbSamples == 0 {
826+ & [ ]
827+ } else {
828+ core:: slice:: from_raw_parts ( samplesSizes, nbSamples as usize )
829+ } ;
830+ let totalSamplesSize = samplesSizes. iter ( ) . sum :: < usize > ( ) ;
831+ let samples = if samplesBuffer. is_null ( ) || totalSamplesSize == 0 {
832+ & [ ]
833+ } else {
834+ core:: slice:: from_raw_parts ( samplesBuffer. cast :: < u8 > ( ) , totalSamplesSize)
835+ } ;
836+
837+ train_from_buffer_cover ( dict, samples, samplesSizes, parameters)
838+ }
839+
840+ fn train_from_buffer_cover (
841+ dict : & mut [ MaybeUninit < u8 > ] ,
842+ samples : & [ u8 ] ,
843+ samplesSizes : & [ usize ] ,
813844 mut parameters : ZDICT_cover_params_t ,
814845) -> size_t {
815- let dict = dictBuffer as * mut u8 ;
846+ let dictBufferCapacity = dict. len ( ) ;
847+
816848 let mut ctx = COVER_ctx_t :: default ( ) ;
817849 let displayLevel = parameters. zParams . notificationLevel as core:: ffi:: c_int ;
818850 parameters. splitPoint = 1.0f64 ;
@@ -822,7 +854,7 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
822854 }
823855 return Error :: parameter_outOfBound. to_error_code ( ) ;
824856 }
825- if nbSamples == 0 {
857+ if samplesSizes . is_empty ( ) {
826858 if displayLevel >= 1 {
827859 eprintln ! ( "Cover must have at least one input file" ) ;
828860 }
@@ -835,18 +867,6 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
835867 return Error :: dstSize_tooSmall. to_error_code ( ) ;
836868 }
837869
838- let samplesSizes = if samplesSizes. is_null ( ) || nbSamples == 0 {
839- & [ ]
840- } else {
841- core:: slice:: from_raw_parts ( samplesSizes, nbSamples as usize )
842- } ;
843- let totalSamplesSize = samplesSizes. iter ( ) . sum :: < usize > ( ) ;
844- let samples = if samplesBuffer. is_null ( ) || totalSamplesSize == 0 {
845- & [ ]
846- } else {
847- core:: slice:: from_raw_parts ( samplesBuffer. cast :: < u8 > ( ) , totalSamplesSize)
848- } ;
849-
850870 let initVal = COVER_ctx_init (
851871 & mut ctx,
852872 samples,
@@ -866,25 +886,24 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
866886 }
867887
868888 let mut freqs = core:: mem:: take ( & mut ctx. freqs ) ;
869- let dict_tail = COVER_buildDictionary (
870- & ctx,
871- & mut freqs,
872- & mut activeDmers,
873- unsafe { core:: slice:: from_raw_parts_mut ( dictBuffer. cast ( ) , dictBufferCapacity) } ,
874- parameters,
875- ) ;
889+ let dict_tail = COVER_buildDictionary ( & ctx, & mut freqs, & mut activeDmers, dict, parameters) ;
876890 ctx. freqs = freqs;
877891
878- let dictionarySize = ZDICT_finalizeDictionary (
879- dict as * mut core:: ffi:: c_void ,
880- dictBufferCapacity,
881- dict_tail. as_ptr ( ) as * const core:: ffi:: c_void ,
882- dict_tail. len ( ) ,
883- samplesBuffer,
884- samplesSizes. as_ptr ( ) ,
885- nbSamples,
886- parameters. zParams ,
887- ) ;
892+ let customDictContentSize = dict_tail. len ( ) ;
893+ let dictBuffer = dict. as_mut_ptr ( ) as * mut core:: ffi:: c_void ;
894+ let customDictContent = dictBuffer. wrapping_add ( dictBufferCapacity - customDictContentSize) ;
895+ let dictionarySize = unsafe {
896+ ZDICT_finalizeDictionary (
897+ dictBuffer,
898+ dictBufferCapacity,
899+ customDictContent,
900+ customDictContentSize,
901+ samples. as_ptr ( ) as * const core:: ffi:: c_void ,
902+ samplesSizes. as_ptr ( ) ,
903+ samplesSizes. len ( ) as core:: ffi:: c_uint ,
904+ parameters. zParams ,
905+ )
906+ } ;
888907 if !ERR_isError ( dictionarySize) && displayLevel >= 2 {
889908 eprintln ! ( "Constructed dictionary of size {}" , dictionarySize, ) ;
890909 }
0 commit comments