Skip to content

Commit 929c372

Browse files
authored
Chore: refactor dict module (#5501)
Again, a purely cosmetic change. (wanted to do this before implementing the batch execute) Signed-off-by: Connor Tsui <[email protected]>
1 parent b5d28a1 commit 929c372

File tree

13 files changed

+395
-364
lines changed

13 files changed

+395
-364
lines changed

vortex-array/src/arrays/dict/array.rs

Lines changed: 16 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -1,140 +1,44 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use std::fmt::Debug;
5-
use std::hash::Hash;
6-
7-
use vortex_buffer::{BitBuffer, ByteBuffer};
8-
use vortex_dtype::{DType, Nullability, PType, match_each_integer_ptype};
9-
use vortex_error::{
10-
VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_ensure, vortex_err,
11-
};
12-
use vortex_mask::{AllOr, Mask};
13-
14-
use crate::builders::dict::dict_encode;
15-
use crate::serde::ArrayChildren;
16-
use crate::stats::{ArrayStats, StatsSetRef};
17-
use crate::vtable::{
18-
ArrayId, ArrayVTable, ArrayVTableExt, BaseArrayVTable, EncodeVTable, NotSupported, VTable,
19-
ValidityVTable, VisitorVTable,
20-
};
21-
use crate::{
22-
Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef, Canonical,
23-
DeserializeMetadata, Precision, ProstMetadata, SerializeMetadata, ToCanonical, vtable,
24-
};
25-
26-
vtable!(Dict);
4+
use vortex_buffer::BitBuffer;
5+
use vortex_dtype::{DType, PType, match_each_integer_ptype};
6+
use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_ensure};
7+
use vortex_mask::AllOr;
8+
9+
use crate::stats::ArrayStats;
10+
use crate::{Array, ArrayRef, ToCanonical};
2711

2812
#[derive(Clone, prost::Message)]
2913
pub struct DictMetadata {
3014
#[prost(uint32, tag = "1")]
3115
pub(super) values_len: u32,
3216
#[prost(enumeration = "PType", tag = "2")]
3317
pub(super) codes_ptype: i32,
34-
// nullable codes are optional since they were added after stabilisation
18+
// nullable codes are optional since they were added after stabilisation.
3519
#[prost(optional, bool, tag = "3")]
3620
pub(super) is_nullable_codes: Option<bool>,
37-
// all_values_referenced is optional for backward compatibility
38-
// true = all dictionary values are definitely referenced by at least one code
39-
// false/None = unknown whether all values are referenced (conservative default)
21+
// all_values_referenced is optional for backward compatibility.
22+
// true = all dictionary values are definitely referenced by at least one code.
23+
// false/None = unknown whether all values are referenced (conservative default).
4024
#[prost(optional, bool, tag = "4")]
4125
pub(super) all_values_referenced: Option<bool>,
4226
}
4327

44-
impl VTable for DictVTable {
45-
type Array = DictArray;
46-
47-
type Metadata = ProstMetadata<DictMetadata>;
48-
49-
type ArrayVTable = Self;
50-
type CanonicalVTable = Self;
51-
type OperationsVTable = Self;
52-
type ValidityVTable = Self;
53-
type VisitorVTable = Self;
54-
type ComputeVTable = NotSupported;
55-
type EncodeVTable = Self;
56-
type OperatorVTable = NotSupported;
57-
58-
fn id(&self) -> ArrayId {
59-
ArrayId::new_ref("vortex.dict")
60-
}
61-
62-
fn encoding(_array: &Self::Array) -> ArrayVTable {
63-
DictVTable.as_vtable()
64-
}
65-
66-
fn metadata(array: &DictArray) -> VortexResult<Self::Metadata> {
67-
Ok(ProstMetadata(DictMetadata {
68-
codes_ptype: PType::try_from(array.codes().dtype())? as i32,
69-
values_len: u32::try_from(array.values().len()).map_err(|_| {
70-
vortex_err!(
71-
"Dictionary values size {} overflowed u32",
72-
array.values().len()
73-
)
74-
})?,
75-
is_nullable_codes: Some(array.codes().dtype().is_nullable()),
76-
all_values_referenced: Some(array.all_values_referenced),
77-
}))
78-
}
79-
80-
fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
81-
Ok(Some(metadata.serialize()))
82-
}
83-
84-
fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
85-
let metadata = <Self::Metadata as DeserializeMetadata>::deserialize(buffer)?;
86-
Ok(ProstMetadata(metadata))
87-
}
88-
89-
fn build(
90-
&self,
91-
dtype: &DType,
92-
len: usize,
93-
metadata: &Self::Metadata,
94-
_buffers: &[ByteBuffer],
95-
children: &dyn ArrayChildren,
96-
) -> VortexResult<DictArray> {
97-
if children.len() != 2 {
98-
vortex_bail!(
99-
"Expected 2 children for dict encoding, found {}",
100-
children.len()
101-
)
102-
}
103-
let codes_nullable = metadata
104-
.is_nullable_codes
105-
.map(Nullability::from)
106-
// If no `is_nullable_codes` metadata use the nullability of the values
107-
// (and whole array) as before.
108-
.unwrap_or_else(|| dtype.nullability());
109-
let codes_dtype = DType::Primitive(metadata.codes_ptype(), codes_nullable);
110-
let codes = children.get(0, &codes_dtype, len)?;
111-
let values = children.get(1, dtype, metadata.values_len as usize)?;
112-
let all_values_referenced = metadata.all_values_referenced.unwrap_or(false);
113-
114-
// SAFETY: We've validated the metadata and children
115-
Ok(unsafe {
116-
DictArray::new_unchecked(codes, values).set_all_values_referenced(all_values_referenced)
117-
})
118-
}
119-
}
120-
12128
#[derive(Debug, Clone)]
12229
pub struct DictArray {
123-
codes: ArrayRef,
124-
values: ArrayRef,
125-
stats_set: ArrayStats,
126-
dtype: DType,
30+
pub(super) codes: ArrayRef,
31+
pub(super) values: ArrayRef,
32+
pub(super) stats_set: ArrayStats,
33+
pub(super) dtype: DType,
12734
/// Indicates whether all dictionary values are definitely referenced by at least one code.
12835
/// `true` = all values are referenced (computed during encoding).
12936
/// `false` = unknown/might have unreferenced values.
13037
/// In case this is incorrect never use this to enable memory unsafe behaviour just semantically
13138
/// incorrect behaviour.
132-
all_values_referenced: bool,
39+
pub(super) all_values_referenced: bool,
13340
}
13441

135-
#[derive(Clone, Debug)]
136-
pub struct DictVTable;
137-
13842
impl DictArray {
13943
/// Build a new `DictArray` without validating the codes or values.
14044
///
@@ -286,105 +190,6 @@ impl DictArray {
286190
}
287191
}
288192

289-
impl BaseArrayVTable<DictVTable> for DictVTable {
290-
fn len(array: &DictArray) -> usize {
291-
array.codes.len()
292-
}
293-
294-
fn dtype(array: &DictArray) -> &DType {
295-
&array.dtype
296-
}
297-
298-
fn stats(array: &DictArray) -> StatsSetRef<'_> {
299-
array.stats_set.to_ref(array.as_ref())
300-
}
301-
302-
fn array_hash<H: std::hash::Hasher>(array: &DictArray, state: &mut H, precision: Precision) {
303-
array.dtype.hash(state);
304-
array.codes.array_hash(state, precision);
305-
array.values.array_hash(state, precision);
306-
}
307-
308-
fn array_eq(array: &DictArray, other: &DictArray, precision: Precision) -> bool {
309-
array.dtype == other.dtype
310-
&& array.codes.array_eq(&other.codes, precision)
311-
&& array.values.array_eq(&other.values, precision)
312-
}
313-
}
314-
315-
impl ValidityVTable<DictVTable> for DictVTable {
316-
fn is_valid(array: &DictArray, index: usize) -> bool {
317-
let scalar = array.codes().scalar_at(index);
318-
319-
if scalar.is_null() {
320-
return false;
321-
};
322-
let values_index: usize = scalar
323-
.as_ref()
324-
.try_into()
325-
.vortex_expect("Failed to convert dictionary code to usize");
326-
array.values().is_valid(values_index)
327-
}
328-
329-
fn all_valid(array: &DictArray) -> bool {
330-
array.codes().all_valid() && array.values().all_valid()
331-
}
332-
333-
fn all_invalid(array: &DictArray) -> bool {
334-
array.codes().all_invalid() || array.values().all_invalid()
335-
}
336-
337-
fn validity_mask(array: &DictArray) -> Mask {
338-
let codes_validity = array.codes().validity_mask();
339-
match codes_validity.bit_buffer() {
340-
AllOr::All => {
341-
let primitive_codes = array.codes().to_primitive();
342-
let values_mask = array.values().validity_mask();
343-
let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
344-
let codes_slice = primitive_codes.as_slice::<P>();
345-
BitBuffer::collect_bool(array.len(), |idx| {
346-
#[allow(clippy::cast_possible_truncation)]
347-
values_mask.value(codes_slice[idx] as usize)
348-
})
349-
});
350-
Mask::from_buffer(is_valid_buffer)
351-
}
352-
AllOr::None => Mask::AllFalse(array.len()),
353-
AllOr::Some(validity_buff) => {
354-
let primitive_codes = array.codes().to_primitive();
355-
let values_mask = array.values().validity_mask();
356-
let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
357-
let codes_slice = primitive_codes.as_slice::<P>();
358-
#[allow(clippy::cast_possible_truncation)]
359-
BitBuffer::collect_bool(array.len(), |idx| {
360-
validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
361-
})
362-
});
363-
Mask::from_buffer(is_valid_buffer)
364-
}
365-
}
366-
}
367-
}
368-
369-
impl EncodeVTable<DictVTable> for DictVTable {
370-
fn encode(
371-
_vtable: &DictVTable,
372-
canonical: &Canonical,
373-
_like: Option<&DictArray>,
374-
) -> VortexResult<Option<DictArray>> {
375-
Ok(Some(dict_encode(canonical.as_ref())?))
376-
}
377-
}
378-
379-
impl VisitorVTable<DictVTable> for DictVTable {
380-
fn visit_buffers(_array: &DictArray, _visitor: &mut dyn ArrayBufferVisitor) {}
381-
382-
fn visit_children(array: &DictArray, visitor: &mut dyn ArrayChildVisitor) {
383-
visitor.visit_child("codes", array.codes());
384-
visitor.visit_child("values", array.values());
385-
}
386-
}
387-
388193
#[cfg(test)]
389194
mod test {
390195
#[allow(unused_imports)]

vortex-array/src/arrays/dict/arrow.rs

Lines changed: 0 additions & 19 deletions
This file was deleted.

vortex-array/src/arrays/dict/display.rs

Lines changed: 0 additions & 73 deletions
This file was deleted.

vortex-array/src/arrays/dict/mod.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
//!
66
//! Expose a [DictArray] which is zero-copy equivalent to Arrow's
77
//! [DictionaryArray](https://docs.rs/arrow/latest/arrow/array/struct.DictionaryArray.html).
8-
pub use array::*;
98
109
mod array;
11-
mod arrow;
12-
mod canonical;
10+
pub use array::*;
11+
1312
mod compute;
14-
mod display;
15-
mod ops;
13+
14+
pub mod vtable;
15+
pub use vtable::*;
16+
17+
#[cfg(test)]
18+
mod tests;

0 commit comments

Comments
 (0)