Skip to content

Commit dadc19b

Browse files
committed
fn ZDICT_optimizeTrainFromBuffer_cover: split out unsafety
1 parent ec1d004 commit dadc19b

File tree

1 file changed

+45
-35
lines changed

1 file changed

+45
-35
lines changed

lib/dictBuilder/cover.rs

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -533,14 +533,14 @@ fn COVER_ctx_init<'a>(
533533
ctx: &'_ mut COVER_ctx_t<'a>,
534534
samples: &'a [u8],
535535
samplesSizes: &'a [size_t],
536-
nbSamples: core::ffi::c_uint,
537536
d: core::ffi::c_uint,
538537
splitPoint: core::ffi::c_double,
539538
displayLevel: core::ffi::c_int,
540539
) -> size_t {
540+
let nbSamples = samplesSizes.len();
541541
let totalSamplesSize = samples.len();
542542
let nbTrainSamples = if splitPoint < 1.0f64 {
543-
(nbSamples as core::ffi::c_double * splitPoint) as core::ffi::c_uint
543+
(nbSamples as core::ffi::c_double * splitPoint) as usize
544544
} else {
545545
nbSamples
546546
};
@@ -851,7 +851,6 @@ pub unsafe extern "C" fn ZDICT_trainFromBuffer_cover(
851851
&mut ctx,
852852
samples,
853853
samplesSizes,
854-
nbSamples,
855854
parameters.d,
856855
parameters.splitPoint,
857856
displayLevel,
@@ -1196,8 +1195,35 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
11961195
nbSamples: core::ffi::c_uint,
11971196
parameters: *mut ZDICT_cover_params_t,
11981197
) -> size_t {
1198+
let dict = if dictBuffer.is_null() || nbSamples == 0 {
1199+
&mut []
1200+
} else {
1201+
core::slice::from_raw_parts_mut(dictBuffer.cast(), dictBufferCapacity)
1202+
};
1203+
1204+
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
1205+
&[]
1206+
} else {
1207+
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
1208+
};
1209+
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
1210+
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
1211+
&[]
1212+
} else {
1213+
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
1214+
};
1215+
11991216
let parameters = unsafe { parameters.as_mut().unwrap() };
12001217

1218+
optimize_train_from_buffer_cover(dict, samples, samplesSizes, parameters)
1219+
}
1220+
1221+
unsafe fn optimize_train_from_buffer_cover(
1222+
dict: &mut [MaybeUninit<u8>],
1223+
samples: &[u8],
1224+
samplesSizes: &[size_t],
1225+
parameters: &mut ZDICT_cover_params_t,
1226+
) -> usize {
12011227
let nbThreads = parameters.nbThreads;
12021228
let splitPoint = if parameters.splitPoint <= 0.0f64 {
12031229
COVER_DEFAULT_SPLITPOINT
@@ -1222,17 +1248,14 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
12221248
} else {
12231249
1
12241250
};
1225-
let kIterations = (1 as core::ffi::c_uint)
1251+
let kIterations = 1u32
12261252
.wrapping_add(kMaxD.wrapping_sub(kMinD).wrapping_div(2))
1227-
.wrapping_mul(
1228-
(1 as core::ffi::c_uint)
1229-
.wrapping_add(kMaxK.wrapping_sub(kMinK).wrapping_div(kStepSize)),
1230-
);
1231-
let shrinkDict = 0 as core::ffi::c_uint;
1232-
let displayLevel = parameters.zParams.notificationLevel as core::ffi::c_int;
1233-
let mut iteration = 1 as core::ffi::c_uint;
1253+
.wrapping_mul(1u32.wrapping_add(kMaxK.wrapping_sub(kMinK).wrapping_div(kStepSize)));
1254+
let shrinkDict = 0;
1255+
let displayLevel = parameters.zParams.notificationLevel as i32;
1256+
let mut iteration = 1u32;
12341257
let mut pool = core::ptr::null_mut();
1235-
let mut warned = 0;
1258+
let mut warned = false;
12361259
let mut last_update_time = Instant::now();
12371260
if splitPoint <= 0.0 || splitPoint > 1.0 {
12381261
if displayLevel >= 1 {
@@ -1246,13 +1269,13 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
12461269
}
12471270
return Error::parameter_outOfBound.to_error_code();
12481271
}
1249-
if nbSamples == 0 {
1272+
if samplesSizes.is_empty() {
12501273
if displayLevel >= 1 {
12511274
eprintln!("Cover must have at least one input file");
12521275
}
12531276
return Error::srcSize_wrong.to_error_code();
12541277
}
1255-
if dictBufferCapacity < ZDICT_DICTSIZE_MIN as size_t {
1278+
if dict.len() < ZDICT_DICTSIZE_MIN as size_t {
12561279
if displayLevel >= 1 {
12571280
eprintln!("dictBufferCapacity must be at least {}", 256);
12581281
}
@@ -1268,18 +1291,6 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
12681291
eprintln!("Trying {} different sets of parameters", kIterations);
12691292
}
12701293

1271-
let samplesSizes = if samplesSizes.is_null() || nbSamples == 0 {
1272-
&[]
1273-
} else {
1274-
core::slice::from_raw_parts(samplesSizes, nbSamples as usize)
1275-
};
1276-
let totalSamplesSize = samplesSizes.iter().sum::<usize>();
1277-
let samples = if samplesBuffer.is_null() || totalSamplesSize == 0 {
1278-
&[]
1279-
} else {
1280-
core::slice::from_raw_parts(samplesBuffer.cast::<u8>(), totalSamplesSize)
1281-
};
1282-
12831294
let best = COVER_best_t::new();
12841295

12851296
for d in (kMinD..=kMaxD).step_by(2) {
@@ -1297,7 +1308,6 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
12971308
&mut ctx,
12981309
samples,
12991310
samplesSizes,
1300-
nbSamples,
13011311
d,
13021312
splitPoint,
13031313
childDisplayLevel,
@@ -1310,9 +1320,9 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
13101320
POOL_free(pool);
13111321
return initVal;
13121322
}
1313-
if warned == 0 {
1314-
COVER_warnOnSmallCorpus(dictBufferCapacity, ctx.suffix.len(), displayLevel);
1315-
warned = 1;
1323+
if !warned {
1324+
COVER_warnOnSmallCorpus(dict.len(), ctx.suffix.len(), displayLevel);
1325+
warned = true;
13161326
}
13171327

13181328
for k in (kMinK..=kMaxK).step_by(kStepSize as usize) {
@@ -1323,7 +1333,7 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
13231333
steps: kSteps,
13241334
shrinkDict,
13251335
zParams: ZDICT_params_t {
1326-
notificationLevel: ctx.displayLevel as core::ffi::c_uint,
1336+
notificationLevel: ctx.displayLevel as u32,
13271337
compressionLevel: 0,
13281338
dictID: 0,
13291339
},
@@ -1334,14 +1344,14 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
13341344
let data = Box::new(COVER_tryParameters_data_t {
13351345
ctx: &ctx,
13361346
best: &best,
1337-
dictBufferCapacity,
1347+
dictBufferCapacity: dict.len(),
13381348
parameters,
13391349
});
13401350

13411351
if displayLevel >= 3 {
13421352
eprintln!("k={}", k);
13431353
}
1344-
if !COVER_checkParameters(data.parameters, dictBufferCapacity) {
1354+
if !COVER_checkParameters(data.parameters, dict.len()) {
13451355
if displayLevel >= 1 {
13461356
eprintln!("Cover parameters incorrect");
13471357
}
@@ -1389,8 +1399,8 @@ pub unsafe extern "C" fn ZDICT_optimizeTrainFromBuffer_cover(
13891399
}
13901400
*parameters = best.parameters;
13911401
core::ptr::copy_nonoverlapping(
1392-
best.dict[..dictSize].as_ptr(),
1393-
dictBuffer.cast::<u8>(),
1402+
best.dict[..dictSize].as_ptr().cast::<MaybeUninit<u8>>(),
1403+
dict[..dictSize].as_mut_ptr(),
13941404
dictSize,
13951405
);
13961406
POOL_free(pool);

0 commit comments

Comments
 (0)