Skip to content

Commit 40d7af2

Browse files
authored
Cache FSST compressor for running compare (#3343)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 10e84ca commit 40d7af2

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

encodings/fsst/src/array.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use fsst::{Decompressor, Symbol};
1+
use std::fmt::{Debug, Formatter};
2+
use std::sync::{Arc, LazyLock};
3+
4+
use fsst::{Compressor, Decompressor, Symbol};
25
use vortex_array::arrays::VarBinArray;
36
use vortex_array::stats::{ArrayStats, StatsSetRef};
47
use vortex_array::vtable::{
@@ -33,7 +36,7 @@ impl VTable for FSSTVTable {
3336
}
3437
}
3538

36-
#[derive(Clone, Debug)]
39+
#[derive(Clone)]
3740
pub struct FSSTArray {
3841
dtype: DType,
3942
symbols: Buffer<Symbol>,
@@ -42,6 +45,21 @@ pub struct FSSTArray {
4245
/// Lengths of the original values before compression, can be compressed.
4346
uncompressed_lengths: ArrayRef,
4447
stats_set: ArrayStats,
48+
49+
/// Memoized compressor used for push-down of compute by compressing the RHS.
50+
compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
51+
}
52+
53+
/// Hand-written `Debug` implementation for `FSSTArray`.
///
/// Prints the `dtype`, `symbols`, `symbol_lengths`, `codes` and
/// `uncompressed_lengths` fields. The `stats_set` and `compressor` fields are
/// left out of the output.
// NOTE(review): because some fields are omitted, `finish_non_exhaustive()` may be a
// better fit than `finish()` here — confirm nothing depends on the exact output.
impl Debug for FSSTArray {
54+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55+
f.debug_struct("FSSTArray")
56+
.field("dtype", &self.dtype)
57+
.field("symbols", &self.symbols)
58+
.field("symbol_lengths", &self.symbol_lengths)
59+
.field("codes", &self.codes)
60+
.field("uncompressed_lengths", &self.uncompressed_lengths)
61+
.finish()
62+
}
4563
}
4664

4765
#[derive(Clone, Debug)]
@@ -84,13 +102,21 @@ impl FSSTArray {
84102
vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
85103
}
86104

105+
let symbols2 = symbols.clone();
106+
let symbol_lengths2 = symbol_lengths.clone();
107+
let compressor = Arc::new(LazyLock::new(Box::new(move || {
108+
Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
109+
})
110+
as Box<dyn Fn() -> Compressor + Send>));
111+
87112
Ok(Self {
88113
dtype,
89114
symbols,
90115
symbol_lengths,
91116
codes,
92117
uncompressed_lengths,
93118
stats_set: Default::default(),
119+
compressor,
94120
})
95121
}
96122

@@ -133,6 +159,10 @@ impl FSSTArray {
133159
/// Builds a `Decompressor` from this array's symbols and symbol lengths.
///
/// A new `Decompressor` is constructed on every call; the result is not cached.
pub(crate) fn decompressor(&self) -> Decompressor {
134160
Decompressor::new(self.symbols().as_slice(), self.symbol_lengths().as_slice())
135161
}
162+
163+
/// Returns a reference to the memoized `Compressor` for this array.
///
/// The compressor is held in a `LazyLock`, so it is built the first time it is
/// accessed and reused by later calls.
pub(crate) fn compressor(&self) -> &Compressor {
164+
// `as_ref()` borrows the `LazyLock` inside the `Arc`; coercing that borrow to
// `&Compressor` dereferences the lock, which runs its initializer if needed.
self.compressor.as_ref()
165+
}
136166
}
137167

138168
impl ArrayVTable<FSSTVTable> for FSSTVTable {

encodings/fsst/src/compute/compare.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ impl CompareKernel for FSSTVTable {
1919
operator: Operator,
2020
) -> VortexResult<Option<ArrayRef>> {
2121
match rhs.as_constant() {
22-
Some(constant) => {
23-
compare_fsst_constant(lhs, &ConstantArray::new(constant, lhs.len()), operator)
24-
}
22+
Some(constant) => compare_fsst_constant(lhs, &constant, operator),
2523
// Otherwise, fall back to the default comparison behavior.
2624
_ => Ok(None),
2725
}
@@ -33,16 +31,15 @@ register_kernel!(CompareKernelAdapter(FSSTVTable).lift());
3331
/// Specialized compare implementation used when comparing against a constant
3432
fn compare_fsst_constant(
3533
left: &FSSTArray,
36-
right: &ConstantArray,
34+
right: &Scalar,
3735
operator: Operator,
3836
) -> VortexResult<Option<ArrayRef>> {
39-
let rhs_scalar = right.scalar();
40-
let is_rhs_empty = match rhs_scalar.dtype() {
41-
DType::Binary(_) => rhs_scalar
37+
let is_rhs_empty = match right.dtype() {
38+
DType::Binary(_) => right
4239
.as_binary()
4340
.is_empty()
4441
.vortex_expect("RHS should not be null"),
45-
DType::Utf8(_) => rhs_scalar
42+
DType::Utf8(_) => right
4643
.as_utf8()
4744
.is_empty()
4845
.vortex_expect("RHS should not be null"),
@@ -77,20 +74,17 @@ fn compare_fsst_constant(
7774
return Ok(None);
7875
}
7976

80-
let compressor = fsst::Compressor::rebuild_from(left.symbols(), left.symbol_lengths());
81-
77+
let compressor = left.compressor();
8278
let encoded_buffer = match left.dtype() {
8379
DType::Utf8(_) => {
8480
let value = right
85-
.scalar()
8681
.as_utf8()
8782
.value()
8883
.vortex_expect("Expected non-null scalar");
8984
ByteBuffer::from(compressor.compress(value.as_bytes()))
9085
}
9186
DType::Binary(_) => {
9287
let value = right
93-
.scalar()
9488
.as_binary()
9589
.value()
9690
.vortex_expect("Expected non-null scalar");

0 commit comments

Comments
 (0)