@@ -468,16 +468,50 @@ fn FASTCOVER_convertToFastCoverParams(
468468 fastCoverParams. shrinkDict = coverParams. shrinkDict ;
469469}
470470
471+ /// # Safety
472+ ///
473+ /// Behavior is undefined if any of the following conditions are violated:
474+ ///
475+ /// - `dictBufferCapacity` is 0 or `dictBuffer` and `dictBufferCapacity` satisfy the requirements
476+ /// of [`core::slice::from_raw_parts_mut`].
477+ /// - `nbSamples` is 0 or `samplesSizes` and `nbSamples` satisfy the requirements
478+ /// of [`core::slice::from_raw_parts`].
479+ /// - `sum(samplesSizes)` is 0 or `samplesBuffer` and `sum(samplesSizes)` satisfy the requirements
480+ /// of [`core::slice::from_raw_parts`].
471481#[ cfg_attr( feature = "export-symbols" , export_name = crate :: prefix!( ZDICT_trainFromBuffer_fastCover ) ) ]
472482pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover (
473483 dictBuffer : * mut core:: ffi:: c_void ,
474484 dictBufferCapacity : size_t ,
475485 samplesBuffer : * const core:: ffi:: c_void ,
476486 samplesSizes : * const size_t ,
477487 nbSamples : core:: ffi:: c_uint ,
488+ parameters : ZDICT_fastCover_params_t ,
489+ ) -> size_t {
490+ let dict = unsafe { core:: slice:: from_raw_parts_mut ( dictBuffer. cast ( ) , dictBufferCapacity) } ;
491+
492+ let samplesSizes = if samplesSizes. is_null ( ) || nbSamples == 0 {
493+ & [ ]
494+ } else {
495+ core:: slice:: from_raw_parts ( samplesSizes, nbSamples as usize )
496+ } ;
497+ let totalSamplesSize = samplesSizes. iter ( ) . sum :: < usize > ( ) ;
498+ let samples = if samplesBuffer. is_null ( ) || totalSamplesSize == 0 {
499+ & [ ]
500+ } else {
501+ core:: slice:: from_raw_parts ( samplesBuffer. cast :: < u8 > ( ) , totalSamplesSize)
502+ } ;
503+
504+ train_from_buffer_fastcover ( dict, samples, samplesSizes, parameters)
505+ }
506+
507+ fn train_from_buffer_fastcover (
508+ dict : & mut [ MaybeUninit < u8 > ] ,
509+ samples : & [ u8 ] ,
510+ samplesSizes : & [ usize ] ,
478511 mut parameters : ZDICT_fastCover_params_t ,
479512) -> size_t {
480- let dict = dictBuffer as * mut u8 ;
513+ let dictBufferCapacity = dict. len ( ) ;
514+
481515 let mut ctx = FASTCOVER_ctx_t :: default ( ) ;
482516 let displayLevel = parameters. zParams . notificationLevel as core:: ffi:: c_int ;
483517 parameters. splitPoint = 1.0f64 ;
@@ -504,7 +538,7 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
504538 }
505539 return Error :: parameter_outOfBound. to_error_code ( ) ;
506540 }
507- if nbSamples == 0 {
541+ if samplesSizes . is_empty ( ) {
508542 if displayLevel >= 1 {
509543 eprintln ! ( "FASTCOVER must have at least one input file" ) ;
510544 }
@@ -518,18 +552,6 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
518552 }
519553 let accelParams = FASTCOVER_defaultAccelParameters [ parameters. accel as usize ] ;
520554
521- let samplesSizes = if samplesSizes. is_null ( ) || nbSamples == 0 {
522- & [ ]
523- } else {
524- core:: slice:: from_raw_parts ( samplesSizes, nbSamples as usize )
525- } ;
526- let totalSamplesSize = samplesSizes. iter ( ) . sum :: < usize > ( ) ;
527- let samples = if samplesBuffer. is_null ( ) || totalSamplesSize == 0 {
528- & [ ]
529- } else {
530- core:: slice:: from_raw_parts ( samplesBuffer. cast :: < u8 > ( ) , totalSamplesSize)
531- } ;
532-
533555 let initVal = FASTCOVER_ctx_init (
534556 & mut ctx,
535557 samples,
@@ -553,27 +575,27 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_fastCover(
553575 let mut segmentFreqs: Box < [ u16 ] > = Box :: from ( vec ! [ 0u16 ; 1 << parameters. f] ) ;
554576
555577 let mut freqs = core:: mem:: take ( & mut ctx. freqs ) ;
556- let dict_tail = FASTCOVER_buildDictionary (
557- & ctx,
558- & mut freqs,
559- unsafe { core:: slice:: from_raw_parts_mut ( dictBuffer. cast ( ) , dictBufferCapacity) } ,
560- coverParams,
561- & mut segmentFreqs,
562- ) ;
578+ let dict_tail =
579+ FASTCOVER_buildDictionary ( & ctx, & mut freqs, dict, coverParams, & mut segmentFreqs) ;
563580 ctx. freqs = freqs;
564581
565582 let nbFinalizeSamples =
566583 ( ctx. nbTrainSamples * ctx. accelParams . finalize as size_t / 100 ) as core:: ffi:: c_uint ;
567- let dictionarySize = ZDICT_finalizeDictionary (
568- dict as * mut core:: ffi:: c_void ,
569- dictBufferCapacity,
570- dict_tail. as_ptr ( ) as * const core:: ffi:: c_void ,
571- dict_tail. len ( ) ,
572- samplesBuffer,
573- samplesSizes. as_ptr ( ) ,
574- nbFinalizeSamples,
575- coverParams. zParams ,
576- ) ;
584+ let customDictContentSize = dict_tail. len ( ) ;
585+ let dictBuffer = dict. as_mut_ptr ( ) as * mut core:: ffi:: c_void ;
586+ let customDictContent = dictBuffer. wrapping_add ( dictBufferCapacity - customDictContentSize) ;
587+ let dictionarySize = unsafe {
588+ ZDICT_finalizeDictionary (
589+ dictBuffer,
590+ dictBufferCapacity,
591+ customDictContent,
592+ customDictContentSize,
593+ samples. as_ptr ( ) as * const core:: ffi:: c_void ,
594+ samplesSizes. as_ptr ( ) ,
595+ nbFinalizeSamples,
596+ coverParams. zParams ,
597+ )
598+ } ;
577599 if !ERR_isError ( dictionarySize) && displayLevel >= 2 {
578600 eprintln ! ( "Constructed dictionary of size {}" , dictionarySize) ;
579601 }
0 commit comments