
Commit 3e93615

Zstd encoding improvements (#3537)
* Chunk frames by values instead of by rows. This should be more efficient in the case where almost all rows are null, and it reduces metadata and makes the implementation a bit easier. To call it out: I'm getting rid of metadata here, which should be ok since no one is actually using this yet.
* Nit to use to_primitive directly in tests.
* Removed unnecessary uncompressed_size function.
* Fixed and tested vtable implementation to use correct slice.

---------

Signed-off-by: mwlon <[email protected]>
1 parent 60533ae commit 3e93615
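
As a quick illustration of the chunk-by-values change described above, here is a minimal standalone sketch (plain Rust; the helper name and test values are illustrative, not part of vortex):

```rust
// Frame boundaries are now taken over the packed *valid values*, mirroring
// `(0..n_values).step_by(values_per_frame)` in `from_primitive` below.
// `values_per_frame == 0` means a single frame over all values.
fn frame_value_starts(n_values: usize, values_per_frame: usize) -> Vec<usize> {
    let values_per_frame = if values_per_frame > 0 {
        values_per_frame
    } else {
        n_values.max(1) // avoid step_by(0) in this toy version
    };
    let mut starts: Vec<usize> = (0..n_values).step_by(values_per_frame).collect();
    starts.push(n_values); // include the stop of the last frame
    starts
}

fn main() {
    // 10 valid values, 4 per frame -> frames cover [0, 4), [4, 8), [8, 10)
    assert_eq!(frame_value_starts(10, 4), vec![0, 4, 8, 10]);
    // 0 means one frame over everything
    assert_eq!(frame_value_starts(10, 0), vec![0, 10]);
}
```

Because the boundaries are taken over valid values rather than rows, an array that is almost entirely null still produces well-sized frames, which is the efficiency gain called out in the first bullet.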

File tree

5 files changed, +129 -141 lines changed


Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

encodings/zstd/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ vortex-array = { workspace = true }
 vortex-buffer = { workspace = true }
 vortex-dtype = { workspace = true }
 vortex-error = { workspace = true }
+vortex-mask = { workspace = true }
 vortex-scalar = { workspace = true }
 zstd = { workspace = true }

encodings/zstd/src/array.rs

Lines changed: 91 additions & 101 deletions
@@ -6,12 +6,12 @@ use vortex_array::stats::{ArrayStats, StatsSetRef};
 use vortex_array::validity::Validity;
 use vortex_array::vtable::{
     ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
-    ValidityVTableFromValidityHelper,
+    ValiditySliceHelper, ValidityVTableFromValiditySliceHelper,
 };
 use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
 use vortex_buffer::{Alignment, ByteBuffer, ByteBufferMut};
 use vortex_dtype::DType;
-use vortex_error::{VortexError, VortexResult, vortex_err};
+use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
 use vortex_scalar::Scalar;
 
 use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
@@ -20,26 +20,24 @@ use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
 const MIN_SAMPLES_FOR_DICTIONARY: usize = 8;
 
 // Overall approach here:
-// Zstd can be used on the whole array (rows_per_frame = 0), resulting in a single Zstd
-// frame, or it can be used with a dictionary (rows_per_frame < # rows), resulting in
+// Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
+// frame, or it can be used with a dictionary (values_per_frame < # values), resulting in
 // multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
 // want somewhat faster access to slices or individual rows, allowing us to only
 // decompress the necessary frames.
 
-// Visually, during compression and decompression, we have an interval of frames we're
-// compressing/decompressing and a tighter interval of the slice we actually care about:
+// Visually, during decompression, we have an interval of frames we're
+// decompressing and a tighter interval of the slice we actually care about.
 //
-// |=====================validity========================|
-// |=======================rows==========================|
-//        |----------------frames_rows-------------------|
-// <--row_offset->|----slice-------------------|
-//                ^                            ^
-//                |<------slice_n_rows-------->|
-//           slice_start                  slice_stop
+// |=============values (all valid elements)==============|
+// |<-skipped_uncompressed->|----decompressed-------------|
+//                              |------slice-------|
+//                              ^                  ^
+// |<-slice_uncompressed_start->|                  |
+// |<------------slice_uncompressed_stop---------->|
 //
-// |=====values (all valid elements)====|
-//      |-------frames_values------|
-//         |----slice_values-----|
+// We then insert these values to the correct position using a primitive array
+// constructor.
 
 vtable!(Zstd);
 
@@ -50,7 +48,7 @@ impl VTable for ZstdVTable {
     type ArrayVTable = Self;
     type CanonicalVTable = Self;
     type OperationsVTable = Self;
-    type ValidityVTable = ValidityVTableFromValidityHelper;
+    type ValidityVTable = ValidityVTableFromValiditySliceHelper;
     type VisitorVTable = Self;
     type ComputeVTable = NotSupported;
     type EncodeVTable = Self;
@@ -72,9 +70,10 @@ pub struct ZstdEncoding;
 pub struct ZstdArray {
     pub(crate) dictionary: Option<ByteBuffer>,
     pub(crate) frames: Vec<ByteBuffer>,
-    pub(crate) validity: Validity,
     pub(crate) metadata: ZstdMetadata,
     dtype: DType,
+    pub(crate) unsliced_validity: Validity,
+    unsliced_n_rows: usize,
     stats_set: ArrayStats,
     slice_start: usize,
     slice_stop: usize,
@@ -105,48 +104,42 @@ impl ZstdArray {
         Self {
             dictionary,
             frames,
-            validity,
             metadata,
             dtype,
+            unsliced_validity: validity,
+            unsliced_n_rows: n_rows,
             stats_set: Default::default(),
             slice_start: 0,
             slice_stop: n_rows,
         }
     }
 
-    pub fn uncompressed_size(&self) -> usize {
-        (self.slice_stop - self.slice_start) * self.dtype.as_ptype().byte_width()
-    }
-
     pub fn from_primitive(
         parray: &PrimitiveArray,
         level: i32,
-        rows_per_frame: usize,
+        values_per_frame: usize,
     ) -> VortexResult<Self> {
         let dtype = parray.dtype().clone();
         let byte_width = parray.ptype().byte_width();
-        let mask = parray.validity_mask()?;
-        let n_rows = parray.len();
-        let rows_per_frame = if rows_per_frame > 0 {
-            rows_per_frame
-        } else {
-            n_rows
-        };
-        let frame_row_indices = (0..n_rows).step_by(rows_per_frame).collect::<Vec<_>>();
-        let n_frames = frame_row_indices.len();
 
         // We compress only the valid elements.
         let values = collect_valid(parray)?;
-        let mut valid_counts = mask.valid_counts_for_indices(&frame_row_indices)?;
-        valid_counts.push(values.len()); // for convenience
-        let values = values.byte_buffer();
-        let value_bytes = values.inner();
+        let n_values = values.len();
+        let values_per_frame = if values_per_frame > 0 {
+            values_per_frame
+        } else {
+            n_values
+        };
+
+        let mut frame_value_starts = (0..n_values).step_by(values_per_frame).collect::<Vec<_>>();
+        let n_frames = frame_value_starts.len();
+        frame_value_starts.push(values.len()); // for convenience, include the stop of the last frame
+        let value_bytes = values.byte_buffer();
 
         // Would-be sample sizes if we end up applying zstd dictionary
-        let sample_sizes: Vec<usize> = valid_counts
+        let sample_sizes: Vec<usize> = frame_value_starts
             .windows(2)
             .map(|pair| (pair[1] - pair[0]) * byte_width)
-            .filter(|&size| size > 0)
             .collect();
         debug_assert_eq!(sample_sizes.iter().sum::<usize>(), value_bytes.len());
 
@@ -155,7 +148,7 @@
             (None, zstd::bulk::Compressor::new(level)?)
         } else {
             // with dictionary
-            let max_dict_size = choose_max_dict_size(values.len());
+            let max_dict_size = choose_max_dict_size(value_bytes.len());
             let dict = zstd::dict::from_continuous(value_bytes, &sample_sizes, max_dict_size)
                 .map_err(|err| VortexError::from(err).with_context("while training dictionary"))?;
 
@@ -166,16 +159,12 @@
         let mut frame_metas = vec![];
         let mut frames = vec![];
         for i in 0..n_frames {
-            let uncompressed =
-                &value_bytes.slice(valid_counts[i] * byte_width..valid_counts[i + 1] * byte_width);
+            let uncompressed = &value_bytes
+                .slice(frame_value_starts[i] * byte_width..frame_value_starts[i + 1] * byte_width);
             let compressed = compressor
                 .compress(uncompressed)
                 .map_err(|err| VortexError::from(err).with_context("while compressing"))?;
-            let frame_n_rows = (frame_row_indices.get(i + 1).cloned().unwrap_or(n_rows)
-                - frame_row_indices[i]) as u64;
             frame_metas.push(ZstdFrameMetadata {
-                n_rows: frame_n_rows,
-                compressed_size: compressed.len() as u64,
                 uncompressed_size: uncompressed.len() as u64,
             });
             frames.push(ByteBuffer::from(compressed));
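
The dictionary and per-frame compression path in the hunks above uses the zstd crate's dict and bulk APIs directly. Below is a self-contained sketch of that flow with made-up sample data; ZDICT training can legitimately refuse tiny or overly uniform inputs, so the sketch falls back gracefully rather than treating that as a hard error:

```rust
// Sketch of the dictionary + per-frame compression flow, using the same zstd
// crate calls as the diff (zstd::dict::from_continuous, zstd::bulk).
fn main() -> std::io::Result<()> {
    // 32 made-up "frames" of 1 KiB each, concatenated into one continuous buffer.
    let frames: Vec<Vec<u8>> = (0..32u8)
        .map(|i| (0..1024).map(|j| i.wrapping_mul(31).wrapping_add(j as u8)).collect())
        .collect();
    let continuous: Vec<u8> = frames.concat();
    let sample_sizes: Vec<usize> = frames.iter().map(|f| f.len()).collect();

    // Train a shared dictionary over the concatenated samples...
    let dict = match zstd::dict::from_continuous(&continuous, &sample_sizes, 4096) {
        Ok(d) => d,
        Err(e) => {
            eprintln!("dictionary training failed on this toy data: {e}");
            return Ok(());
        }
    };

    // ...then compress each frame independently with it, as from_primitive does.
    let mut compressor = zstd::bulk::Compressor::with_dictionary(3, &dict)?;
    let compressed: Vec<Vec<u8>> = frames
        .iter()
        .map(|f| compressor.compress(f))
        .collect::<std::io::Result<_>>()?;

    // Any single frame can later be decompressed on its own with the dictionary,
    // which is what makes sliced reads cheap.
    let mut decompressor = zstd::bulk::Decompressor::with_dictionary(&dict)?;
    let roundtrip = decompressor.decompress(&compressed[7], frames[7].len())?;
    assert_eq!(roundtrip, frames[7]);
    Ok(())
}
```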
@@ -194,14 +183,14 @@
             frames,
             dtype,
             metadata,
-            n_rows,
+            parray.len(),
             parray.validity().clone(),
         ))
     }
 
-    pub fn from_array(array: ArrayRef, level: i32, rows_per_frame: usize) -> VortexResult<Self> {
+    pub fn from_array(array: ArrayRef, level: i32, values_per_frame: usize) -> VortexResult<Self> {
         if let Some(parray) = array.as_opt::<PrimitiveVTable>() {
-            Self::from_primitive(parray, level, rows_per_frame)
+            Self::from_primitive(parray, level, values_per_frame)
         } else {
             Err(vortex_err!("Zstd can only encode primitive arrays"))
         }
@@ -210,78 +199,79 @@
     pub fn decompress(&self) -> VortexResult<ArrayRef> {
         // To start, we figure out which frames we need to decompress, and with
         // what row offset into the first such frame.
+        let ptype = self.dtype.as_ptype();
+        let byte_width = ptype.byte_width();
         let slice_n_rows = self.slice_stop - self.slice_start;
-        let byte_width = self.dtype.as_ptype().byte_width();
-        let mut frame_start_row = 0;
-        let mut frame_idx_lb = 0;
-        let mut frame_idx_ub = 0;
-        let mut row_offset = 0;
-        for (i, frame_meta) in self.metadata.frames.iter().enumerate() {
-            let buf_stop = frame_start_row + usize::try_from(frame_meta.n_rows)?;
-            if frame_start_row < self.slice_start {
-                frame_idx_lb = i;
-                row_offset = self.slice_start - frame_start_row
+        let slice_value_indices = self
+            .unsliced_validity
+            .to_mask(self.unsliced_n_rows)?
+            .valid_counts_for_indices(&[self.slice_start, self.slice_stop])?;
+        let slice_uncompressed_start = slice_value_indices[0] * byte_width;
+        let slice_uncompressed_stop = slice_value_indices[1] * byte_width;
+
+        let mut frames_to_decompress = vec![];
+        let mut uncompressed_start = 0;
+        let mut uncompressed_size_to_decompress = 0;
+        let mut skipped_uncompressed = 0;
+        for (frame, frame_meta) in self.frames.iter().zip(&self.metadata.frames) {
+            if uncompressed_start >= slice_uncompressed_stop {
+                break;
             }
-            if frame_start_row < self.slice_stop {
-                frame_idx_ub = i + 1
+            let frame_uncompressed = usize::try_from(frame_meta.uncompressed_size)?;
+
+            let uncompressed_stop = uncompressed_start + frame_uncompressed;
+            if uncompressed_stop > slice_uncompressed_start {
+                // we need this frame
+                frames_to_decompress.push(frame);
+                uncompressed_size_to_decompress += frame_uncompressed;
+            } else {
+                skipped_uncompressed += frame_uncompressed;
             }
-            frame_start_row = buf_stop;
+            uncompressed_start = uncompressed_stop;
         }
 
         // then we actually decompress those frames
-        let frame_metas = &self.metadata.frames[frame_idx_lb..frame_idx_ub];
-        let total_uncompressed_size: usize = frame_metas
-            .iter()
-            .map(|meta| meta.uncompressed_size)
-            .sum::<u64>()
-            .try_into()?;
-
         let mut decompressor = if let Some(dictionary) = &self.dictionary {
             zstd::bulk::Decompressor::with_dictionary(dictionary)
         } else {
             zstd::bulk::Decompressor::new()
         }?;
-
-        // we could make this empty initialized for better performance
-        let mut frames_values_bytes = ByteBufferMut::with_capacity_aligned(
-            total_uncompressed_size,
+        let mut decompressed = ByteBufferMut::with_capacity_aligned(
+            uncompressed_size_to_decompress,
             Alignment::new(byte_width),
         );
         unsafe {
             // safety: we immediately fill all bytes in the following loop,
             // assuming our metadata's uncompressed size is correct
-            frames_values_bytes.set_len(total_uncompressed_size);
+            decompressed.set_len(uncompressed_size_to_decompress);
+        }
+        let mut uncompressed_start = 0;
+        for frame in frames_to_decompress {
+            let uncompressed_written = decompressor
+                .decompress_to_buffer(frame.as_slice(), &mut decompressed[uncompressed_start..])?;
+            uncompressed_start += uncompressed_written;
         }
-        let mut start_byte = 0;
-        for (frame, meta) in self.frames[frame_idx_lb..frame_idx_ub]
-            .iter()
-            .zip(frame_metas)
-        {
-            let stop_byte = start_byte + usize::try_from(meta.uncompressed_size)?;
-            decompressor.decompress_to_buffer(
-                frame.as_slice(),
-                &mut frames_values_bytes[start_byte..stop_byte],
-            )?;
-            start_byte = stop_byte;
+        if uncompressed_start != uncompressed_size_to_decompress {
+            vortex_bail!(
+                "Zstd metadata or frames were corrupt; expected {} byte but decompressed {}",
+                uncompressed_size_to_decompress,
+                uncompressed_start
+            );
         }
 
-        // Last, we apply our offset. We need to copy since the decompressed
-        // frame start/end might not align with our slice. And we need to
-        // align the data to our (dynamic) dtype.
-        let frames_validity = self
-            .validity
-            .slice(self.slice_start - row_offset, self.slice_stop)?;
-        let frames_mask = frames_validity.to_mask(row_offset + slice_n_rows)?;
-        let frames_values_start_stop =
-            frames_mask.valid_counts_for_indices(&[row_offset, row_offset + slice_n_rows])?;
-        let slice_values_buffer = frames_values_bytes.freeze().slice(
-            frames_values_start_stop[0] * byte_width..frames_values_start_stop[1] * byte_width,
+        // Last, we slice the exact values requested out of the decompressed data.
+        let slice_validity = self
+            .unsliced_validity
+            .slice(self.slice_start, self.slice_stop)?;
+        let slice_values_buffer = decompressed.freeze().slice(
+            slice_uncompressed_start - skipped_uncompressed
+                ..slice_uncompressed_stop - skipped_uncompressed,
         );
 
         let primitive = PrimitiveArray::from_values_byte_buffer(
             slice_values_buffer,
-            self.dtype.as_ptype(),
-            frames_validity.slice(row_offset, row_offset + slice_n_rows)?,
+            ptype,
+            slice_validity,
             slice_n_rows,
         )?;
 
@@ -297,9 +287,9 @@
     }
 }
 
-impl ValidityHelper for ZstdArray {
-    fn validity(&self) -> &Validity {
-        &self.validity
+impl ValiditySliceHelper for ZstdArray {
+    fn unsliced_validity_and_slice(&self) -> (&Validity, usize, usize) {
+        (&self.unsliced_validity, self.slice_start, self.slice_stop)
     }
 }
 
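
The ValiditySliceHelper impl is the vtable fix from the commit message: the array hands back its unsliced validity together with the slice bounds, and the sliced view is derived from those. A plain-Rust analogue of that contract (toy types, not vortex's actual trait):

```rust
// Toy analogue: validity is stored unsliced, alongside the slice bounds, and
// the sliced validity is derived on demand. This keeps valid-count offsets
// into the packed values consistent with the unsliced frames.
struct SlicedValidity {
    unsliced_validity: Vec<bool>,
    slice_start: usize,
    slice_stop: usize,
}

impl SlicedValidity {
    fn unsliced_validity_and_slice(&self) -> (&[bool], usize, usize) {
        (&self.unsliced_validity, self.slice_start, self.slice_stop)
    }

    fn sliced_validity(&self) -> &[bool] {
        &self.unsliced_validity[self.slice_start..self.slice_stop]
    }
}

fn main() {
    let v = SlicedValidity {
        unsliced_validity: vec![true, false, true, true, false],
        slice_start: 1,
        slice_stop: 4,
    };
    let (_, start, stop) = v.unsliced_validity_and_slice();
    assert_eq!((start, stop), (1, 4));
    assert_eq!(v.sliced_validity(), &[false, true, true]);
}
```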

encodings/zstd/src/serde.rs

Lines changed: 1 addition & 5 deletions
@@ -11,10 +11,6 @@ use crate::{ZstdArray, ZstdEncoding, ZstdVTable};
 #[derive(Clone, prost::Message)]
 pub struct ZstdFrameMetadata {
     #[prost(uint64, tag = "1")]
-    pub n_rows: u64,
-    #[prost(uint64, tag = "2")]
-    pub compressed_size: u64,
-    #[prost(uint64, tag = "3")]
     pub uncompressed_size: u64,
 }
 

@@ -93,6 +89,6 @@ impl VisitorVTable<ZstdVTable> for ZstdVTable {
     }
 
     fn visit_children(array: &ZstdArray, visitor: &mut dyn ArrayChildVisitor) {
-        visitor.visit_validity(&array.validity, array.len());
+        visitor.visit_validity(&array.unsliced_validity, array.len());
     }
 }
