@@ -6,12 +6,12 @@ use vortex_array::stats::{ArrayStats, StatsSetRef};
66use vortex_array:: validity:: Validity ;
77use vortex_array:: vtable:: {
88 ArrayVTable , CanonicalVTable , NotSupported , OperationsVTable , VTable , ValidityHelper ,
9- ValidityVTableFromValidityHelper ,
9+ ValiditySliceHelper , ValidityVTableFromValiditySliceHelper ,
1010} ;
1111use vortex_array:: { ArrayRef , Canonical , EncodingId , EncodingRef , IntoArray , ToCanonical , vtable} ;
1212use vortex_buffer:: { Alignment , ByteBuffer , ByteBufferMut } ;
1313use vortex_dtype:: DType ;
14- use vortex_error:: { VortexError , VortexResult , vortex_err} ;
14+ use vortex_error:: { VortexError , VortexResult , vortex_bail , vortex_err} ;
1515use vortex_scalar:: Scalar ;
1616
1717use crate :: serde:: { ZstdFrameMetadata , ZstdMetadata } ;
@@ -20,26 +20,24 @@ use crate::serde::{ZstdFrameMetadata, ZstdMetadata};
2020const MIN_SAMPLES_FOR_DICTIONARY : usize = 8 ;
2121
2222// Overall approach here:
23- // Zstd can be used on the whole array (rows_per_frame = 0), resulting in a single Zstd
24- // frame, or it can be used with a dictionary (rows_per_frame < # rows ), resulting in
23+ // Zstd can be used on the whole array (values_per_frame = 0), resulting in a single Zstd
24+ // frame, or it can be used with a dictionary (values_per_frame < # values ), resulting in
2525// multiple Zstd frames sharing a common dictionary. This latter case is helpful if you
2626// want somewhat faster access to slices or individual rows, allowing us to only
2727// decompress the necessary frames.
2828
29- // Visually, during compression and decompression, we have an interval of frames we're
30- // compressing/ decompressing and a tighter interval of the slice we actually care about:
29+ // Visually, during decompression, we have an interval of frames we're
30+ // decompressing and a tighter interval of the slice we actually care about.
3131//
32- // |=====================validity========================|
33- // |=======================rows==========================|
34- // |----------------frames_rows-------------------|
35- // <--row_offset->|----slice-------------------|
36- // ^ ^
37- // |<------slice_n_rows-------->|
38- // slice_start slice_stop
32+ // |=============values (all valid elements)==============|
33+ // |<-skipped_uncompressed->|----decompressed-------------|
34+ // |------slice-------|
35+ // ^ ^
36+ // |<-slice_uncompressed_start->| |
37+ // |<------------slice_uncompressed_stop---------->|
3938//
40- // |=====values (all valid elements)====|
41- // |-------frames_values------|
42- // |----slice_values-----|
39+ // We then insert these values into the correct position using a primitive array
40+ // constructor.
4341
4442vtable ! ( Zstd ) ;
4543
@@ -50,7 +48,7 @@ impl VTable for ZstdVTable {
5048 type ArrayVTable = Self ;
5149 type CanonicalVTable = Self ;
5250 type OperationsVTable = Self ;
53- type ValidityVTable = ValidityVTableFromValidityHelper ;
51+ type ValidityVTable = ValidityVTableFromValiditySliceHelper ;
5452 type VisitorVTable = Self ;
5553 type ComputeVTable = NotSupported ;
5654 type EncodeVTable = Self ;
@@ -72,9 +70,10 @@ pub struct ZstdEncoding;
7270pub struct ZstdArray {
7371 pub ( crate ) dictionary : Option < ByteBuffer > ,
7472 pub ( crate ) frames : Vec < ByteBuffer > ,
75- pub ( crate ) validity : Validity ,
7673 pub ( crate ) metadata : ZstdMetadata ,
7774 dtype : DType ,
75+ pub ( crate ) unsliced_validity : Validity ,
76+ unsliced_n_rows : usize ,
7877 stats_set : ArrayStats ,
7978 slice_start : usize ,
8079 slice_stop : usize ,
@@ -105,48 +104,42 @@ impl ZstdArray {
105104 Self {
106105 dictionary,
107106 frames,
108- validity,
109107 metadata,
110108 dtype,
109+ unsliced_validity : validity,
110+ unsliced_n_rows : n_rows,
111111 stats_set : Default :: default ( ) ,
112112 slice_start : 0 ,
113113 slice_stop : n_rows,
114114 }
115115 }
116116
117- pub fn uncompressed_size ( & self ) -> usize {
118- ( self . slice_stop - self . slice_start ) * self . dtype . as_ptype ( ) . byte_width ( )
119- }
120-
121117 pub fn from_primitive (
122118 parray : & PrimitiveArray ,
123119 level : i32 ,
124- rows_per_frame : usize ,
120+ values_per_frame : usize ,
125121 ) -> VortexResult < Self > {
126122 let dtype = parray. dtype ( ) . clone ( ) ;
127123 let byte_width = parray. ptype ( ) . byte_width ( ) ;
128- let mask = parray. validity_mask ( ) ?;
129- let n_rows = parray. len ( ) ;
130- let rows_per_frame = if rows_per_frame > 0 {
131- rows_per_frame
132- } else {
133- n_rows
134- } ;
135- let frame_row_indices = ( 0 ..n_rows) . step_by ( rows_per_frame) . collect :: < Vec < _ > > ( ) ;
136- let n_frames = frame_row_indices. len ( ) ;
137124
138125 // We compress only the valid elements.
139126 let values = collect_valid ( parray) ?;
140- let mut valid_counts = mask. valid_counts_for_indices ( & frame_row_indices) ?;
141- valid_counts. push ( values. len ( ) ) ; // for convenience
142- let values = values. byte_buffer ( ) ;
143- let value_bytes = values. inner ( ) ;
127+ let n_values = values. len ( ) ;
128+ let values_per_frame = if values_per_frame > 0 {
129+ values_per_frame
130+ } else {
131+ n_values
132+ } ;
133+
134+ let mut frame_value_starts = ( 0 ..n_values) . step_by ( values_per_frame) . collect :: < Vec < _ > > ( ) ;
135+ let n_frames = frame_value_starts. len ( ) ;
136+ frame_value_starts. push ( values. len ( ) ) ; // for convenience, include the stop of the last frame
137+ let value_bytes = values. byte_buffer ( ) ;
144138
145139 // Would-be sample sizes if we end up applying zstd dictionary
146- let sample_sizes: Vec < usize > = valid_counts
140+ let sample_sizes: Vec < usize > = frame_value_starts
147141 . windows ( 2 )
148142 . map ( |pair| ( pair[ 1 ] - pair[ 0 ] ) * byte_width)
149- . filter ( |& size| size > 0 )
150143 . collect ( ) ;
151144 debug_assert_eq ! ( sample_sizes. iter( ) . sum:: <usize >( ) , value_bytes. len( ) ) ;
152145
@@ -155,7 +148,7 @@ impl ZstdArray {
155148 ( None , zstd:: bulk:: Compressor :: new ( level) ?)
156149 } else {
157150 // with dictionary
158- let max_dict_size = choose_max_dict_size ( values . len ( ) ) ;
151+ let max_dict_size = choose_max_dict_size ( value_bytes . len ( ) ) ;
159152 let dict = zstd:: dict:: from_continuous ( value_bytes, & sample_sizes, max_dict_size)
160153 . map_err ( |err| VortexError :: from ( err) . with_context ( "while training dictionary" ) ) ?;
161154
@@ -166,16 +159,12 @@ impl ZstdArray {
166159 let mut frame_metas = vec ! [ ] ;
167160 let mut frames = vec ! [ ] ;
168161 for i in 0 ..n_frames {
169- let uncompressed =
170- & value_bytes . slice ( valid_counts [ i] * byte_width..valid_counts [ i + 1 ] * byte_width) ;
162+ let uncompressed = & value_bytes
163+ . slice ( frame_value_starts [ i] * byte_width..frame_value_starts [ i + 1 ] * byte_width) ;
171164 let compressed = compressor
172165 . compress ( uncompressed)
173166 . map_err ( |err| VortexError :: from ( err) . with_context ( "while compressing" ) ) ?;
174- let frame_n_rows = ( frame_row_indices. get ( i + 1 ) . cloned ( ) . unwrap_or ( n_rows)
175- - frame_row_indices[ i] ) as u64 ;
176167 frame_metas. push ( ZstdFrameMetadata {
177- n_rows : frame_n_rows,
178- compressed_size : compressed. len ( ) as u64 ,
179168 uncompressed_size : uncompressed. len ( ) as u64 ,
180169 } ) ;
181170 frames. push ( ByteBuffer :: from ( compressed) ) ;
@@ -194,14 +183,14 @@ impl ZstdArray {
194183 frames,
195184 dtype,
196185 metadata,
197- n_rows ,
186+ parray . len ( ) ,
198187 parray. validity ( ) . clone ( ) ,
199188 ) )
200189 }
201190
202- pub fn from_array ( array : ArrayRef , level : i32 , rows_per_frame : usize ) -> VortexResult < Self > {
191+ pub fn from_array ( array : ArrayRef , level : i32 , values_per_frame : usize ) -> VortexResult < Self > {
203192 if let Some ( parray) = array. as_opt :: < PrimitiveVTable > ( ) {
204- Self :: from_primitive ( parray, level, rows_per_frame )
193+ Self :: from_primitive ( parray, level, values_per_frame )
205194 } else {
206195 Err ( vortex_err ! ( "Zstd can only encode primitive arrays" ) )
207196 }
@@ -210,78 +199,79 @@ impl ZstdArray {
210199 pub fn decompress ( & self ) -> VortexResult < ArrayRef > {
211200 // To start, we figure out which frames we need to decompress, and with
212201 // what row offset into the first such frame.
202+ let ptype = self . dtype . as_ptype ( ) ;
203+ let byte_width = ptype. byte_width ( ) ;
213204 let slice_n_rows = self . slice_stop - self . slice_start ;
214- let byte_width = self . dtype . as_ptype ( ) . byte_width ( ) ;
215- let mut frame_start_row = 0 ;
216- let mut frame_idx_lb = 0 ;
217- let mut frame_idx_ub = 0 ;
218- let mut row_offset = 0 ;
219- for ( i, frame_meta) in self . metadata . frames . iter ( ) . enumerate ( ) {
220- let buf_stop = frame_start_row + usize:: try_from ( frame_meta. n_rows ) ?;
221- if frame_start_row < self . slice_start {
222- frame_idx_lb = i;
223- row_offset = self . slice_start - frame_start_row
205+ let slice_value_indices = self
206+ . unsliced_validity
207+ . to_mask ( self . unsliced_n_rows ) ?
208+ . valid_counts_for_indices ( & [ self . slice_start , self . slice_stop ] ) ?;
209+ let slice_uncompressed_start = slice_value_indices[ 0 ] * byte_width;
210+ let slice_uncompressed_stop = slice_value_indices[ 1 ] * byte_width;
211+
212+ let mut frames_to_decompress = vec ! [ ] ;
213+ let mut uncompressed_start = 0 ;
214+ let mut uncompressed_size_to_decompress = 0 ;
215+ let mut skipped_uncompressed = 0 ;
216+ for ( frame, frame_meta) in self . frames . iter ( ) . zip ( & self . metadata . frames ) {
217+ if uncompressed_start >= slice_uncompressed_stop {
218+ break ;
224219 }
225- if frame_start_row < self . slice_stop {
226- frame_idx_ub = i + 1
220+ let frame_uncompressed = usize:: try_from ( frame_meta. uncompressed_size ) ?;
221+
222+ let uncompressed_stop = uncompressed_start + frame_uncompressed;
223+ if uncompressed_stop > slice_uncompressed_start {
224+ // we need this frame
225+ frames_to_decompress. push ( frame) ;
226+ uncompressed_size_to_decompress += frame_uncompressed;
227+ } else {
228+ skipped_uncompressed += frame_uncompressed;
227229 }
228- frame_start_row = buf_stop ;
230+ uncompressed_start = uncompressed_stop ;
229231 }
230232
231233 // then we actually decompress those frames
232- let frame_metas = & self . metadata . frames [ frame_idx_lb..frame_idx_ub] ;
233- let total_uncompressed_size: usize = frame_metas
234- . iter ( )
235- . map ( |meta| meta. uncompressed_size )
236- . sum :: < u64 > ( )
237- . try_into ( ) ?;
238-
239234 let mut decompressor = if let Some ( dictionary) = & self . dictionary {
240235 zstd:: bulk:: Decompressor :: with_dictionary ( dictionary)
241236 } else {
242237 zstd:: bulk:: Decompressor :: new ( )
243238 } ?;
244-
245- // we could make this empty initialized for better performance
246- let mut frames_values_bytes = ByteBufferMut :: with_capacity_aligned (
247- total_uncompressed_size,
239+ let mut decompressed = ByteBufferMut :: with_capacity_aligned (
240+ uncompressed_size_to_decompress,
248241 Alignment :: new ( byte_width) ,
249242 ) ;
250243 unsafe {
251244 // safety: we immediately fill all bytes in the following loop,
252245 // assuming our metadata's uncompressed size is correct
253- frames_values_bytes. set_len ( total_uncompressed_size) ;
246+ decompressed. set_len ( uncompressed_size_to_decompress) ;
247+ }
248+ let mut uncompressed_start = 0 ;
249+ for frame in frames_to_decompress {
250+ let uncompressed_written = decompressor
251+ . decompress_to_buffer ( frame. as_slice ( ) , & mut decompressed[ uncompressed_start..] ) ?;
252+ uncompressed_start += uncompressed_written;
254253 }
255- let mut start_byte = 0 ;
256- for ( frame, meta) in self . frames [ frame_idx_lb..frame_idx_ub]
257- . iter ( )
258- . zip ( frame_metas)
259- {
260- let stop_byte = start_byte + usize:: try_from ( meta. uncompressed_size ) ?;
261- decompressor. decompress_to_buffer (
262- frame. as_slice ( ) ,
263- & mut frames_values_bytes[ start_byte..stop_byte] ,
264- ) ?;
265- start_byte = stop_byte;
254+ if uncompressed_start != uncompressed_size_to_decompress {
255+ vortex_bail ! (
256+ "Zstd metadata or frames were corrupt; expected {} byte but decompressed {}" ,
257+ uncompressed_size_to_decompress,
258+ uncompressed_start
259+ ) ;
266260 }
267261
268- // Last, we apply our offset. We need to copy since the decompressed
269- // frame start/end might not align with our slice. And we need to
270- // align the data to our (dynamic) dtype.
271- let frames_validity = self
272- . validity
273- . slice ( self . slice_start - row_offset, self . slice_stop ) ?;
274- let frames_mask = frames_validity. to_mask ( row_offset + slice_n_rows) ?;
275- let frames_values_start_stop =
276- frames_mask. valid_counts_for_indices ( & [ row_offset, row_offset + slice_n_rows] ) ?;
277- let slice_values_buffer = frames_values_bytes. freeze ( ) . slice (
278- frames_values_start_stop[ 0 ] * byte_width..frames_values_start_stop[ 1 ] * byte_width,
262+ // Last, we slice the exact values requested out of the decompressed data.
263+ let slice_validity = self
264+ . unsliced_validity
265+ . slice ( self . slice_start , self . slice_stop ) ?;
266+ let slice_values_buffer = decompressed. freeze ( ) . slice (
267+ slice_uncompressed_start - skipped_uncompressed
268+ ..slice_uncompressed_stop - skipped_uncompressed,
279269 ) ;
280270
281271 let primitive = PrimitiveArray :: from_values_byte_buffer (
282272 slice_values_buffer,
283- self . dtype . as_ptype ( ) ,
284- frames_validity . slice ( row_offset , row_offset + slice_n_rows ) ? ,
273+ ptype ,
274+ slice_validity ,
285275 slice_n_rows,
286276 ) ?;
287277
@@ -297,9 +287,9 @@ impl ZstdArray {
297287 }
298288}
299289
300- impl ValidityHelper for ZstdArray {
301- fn validity ( & self ) -> & Validity {
302- & self . validity
290+ impl ValiditySliceHelper for ZstdArray {
291+ fn unsliced_validity_and_slice ( & self ) -> ( & Validity , usize , usize ) {
292+ ( & self . unsliced_validity , self . slice_start , self . slice_stop )
303293 }
304294}
305295
0 commit comments