Commit 6aa5e45

fix hardcoded dtype

Signed-off-by: Connor Tsui <[email protected]>

1 parent c93b6a7

1 file changed (+132 -14)

vortex-layout/src/layouts/dict/writer.rs
Lines changed: 132 additions & 14 deletions
@@ -27,7 +27,6 @@ use vortex_array::builders::dict::DictEncoder;
 use vortex_array::builders::dict::dict_encoder;
 use vortex_btrblocks::BtrBlocksCompressor;
 use vortex_dtype::DType;
-use vortex_dtype::Nullability::NonNullable;
 use vortex_dtype::PType;
 use vortex_error::VortexError;
 use vortex_error::VortexResult;
@@ -49,10 +48,22 @@ use crate::sequence::SequentialStream;
 use crate::sequence::SequentialStreamAdapter;
 use crate::sequence::SequentialStreamExt;
 
+/// Constraints for dictionary layout encoding.
+///
+/// Note that [`max_len`](Self::max_len) is limited to `u16` (65,535 entries) by design. Since
+/// layout chunks are typically ~8k elements, having more than 64k unique values in a dictionary
+/// means dictionary encoding provides little compression benefit. If a column has very high
+/// cardinality, the fallback encoding strategy should be used instead.
 #[derive(Clone)]
 pub struct DictLayoutConstraints {
+    /// Maximum size of the dictionary in bytes.
     pub max_bytes: usize,
-    // Dict layout codes currently only support u16 codes
+    /// Maximum dictionary length. Limited to `u16` because dictionaries with more than 64k unique
+    /// values provide diminishing compression returns given typical chunk sizes (~8k elements).
+    ///
+    /// The codes dtype is chosen dynamically based on the actual dictionary size:
+    /// - [`PType::U8`] when the dictionary has at most 255 entries
+    /// - [`PType::U16`] when the dictionary has more than 255 entries
     pub max_len: u16,
 }
 
@@ -387,7 +398,7 @@ impl Stream for DictionaryTransformer {
         }
 
         match self.input.poll_next_unpin(cx) {
-            Poll::Ready(Some(Ok(DictionaryChunk::Codes(codes)))) => {
+            Poll::Ready(Some(Ok(DictionaryChunk::Codes((seq_id, codes))))) => {
                 if self.active_codes_tx.is_none() {
                     // Start a new group
                     let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
@@ -396,13 +407,17 @@ impl Stream for DictionaryTransformer {
                     self.active_codes_tx = Some(codes_tx.clone());
                     self.active_values_tx = Some(values_tx);
 
-                    // Send first codes
+                    let codes_dtype = codes.dtype().clone();
+
+                    // Send first codes.
                     self.pending_send =
-                        Some(Box::pin(async move { codes_tx.send(Ok(codes)).await }));
+                        Some(Box::pin(
+                            async move { codes_tx.send(Ok((seq_id, codes))).await },
+                        ));
 
-                    // Create output streams
+                    // Create output streams.
                     let codes_stream = SequentialStreamAdapter::new(
-                        DType::Primitive(PType::U16, NonNullable),
+                        codes_dtype,
                         codes_rx.into_stream().boxed(),
                     )
                     .sendable();
@@ -416,13 +431,13 @@ impl Stream for DictionaryTransformer {
                     .boxed();
 
                     return Poll::Ready(Some((codes_stream, values_future)));
-                } else {
-                    // Continue streaming codes to existing group
-                    if let Some(tx) = &self.active_codes_tx {
-                        let tx = tx.clone();
-                        self.pending_send =
-                            Some(Box::pin(async move { tx.send(Ok(codes)).await }));
-                    }
+                }
+
+                // Continue streaming codes to existing group
+                if let Some(tx) = &self.active_codes_tx {
+                    let tx = tx.clone();
+                    self.pending_send =
+                        Some(Box::pin(async move { tx.send(Ok((seq_id, codes))).await }));
                 }
             }
             Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
@@ -514,3 +529,106 @@ fn encode_chunk(mut encoder: Box<dyn DictEncoder>, chunk: &dyn Array) -> Encodin
 fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
     (encoded_len < array.len()).then(|| array.slice(encoded_len..array.len()))
 }
+
+#[cfg(test)]
+mod tests {
+    use futures::StreamExt;
+    use vortex_array::IntoArray;
+    use vortex_array::arrays::VarBinArray;
+    use vortex_array::builders::dict::DictConstraints;
+    use vortex_dtype::DType;
+    use vortex_dtype::Nullability::NonNullable;
+    use vortex_dtype::PType;
+
+    use super::DictionaryTransformer;
+    use super::dict_encode_stream;
+    use crate::sequence::SequenceId;
+    use crate::sequence::SequentialStream;
+    use crate::sequence::SequentialStreamAdapter;
+    use crate::sequence::SequentialStreamExt;
+
+    /// Regression test for a bug where the codes stream dtype was hardcoded to U16 instead of
+    /// using the actual codes dtype from the array. When `max_len <= 255`, the dict encoder
+    /// produces U8 codes, but the stream was incorrectly typed as U16, causing a dtype mismatch
+    /// assertion failure in [`SequentialStreamAdapter`].
+    #[tokio::test]
+    async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
+        // Use max_len = 100 to force U8 codes (since 100 <= 255).
+        let constraints = DictConstraints {
+            max_bytes: 1024 * 1024,
+            max_len: 100,
+        };
+
+        // Create a simple string array with a few unique values.
+        let arr = VarBinArray::from(vec!["hello", "world", "hello", "world"]).into_array();
+
+        // Wrap into a sequential stream.
+        let mut pointer = SequenceId::root();
+        let input_stream = SequentialStreamAdapter::new(
+            arr.dtype().clone(),
+            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
+        )
+        .sendable();
+
+        // Encode into dict chunks.
+        let dict_stream = dict_encode_stream(input_stream, constraints);
+
+        // Transform into codes/values streams.
+        let mut transformer = DictionaryTransformer::new(dict_stream);
+
+        // Get the first (and only) run.
+        let (codes_stream, _values_fut) = transformer
+            .next()
+            .await
+            .expect("expected at least one dictionary run");
+
+        // The key assertion: codes stream dtype should be U8, not U16.
+        assert_eq!(
+            codes_stream.dtype(),
+            &DType::Primitive(PType::U8, NonNullable),
+            "codes stream should use U8 dtype for small dictionaries, not U16"
+        );
+    }
+
+    /// Test that the codes stream uses U16 dtype when the dictionary has more than 255 entries.
+    #[tokio::test]
+    async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
+        // Use max_len = 1000 to allow U16 codes (since 1000 > 255).
+        let constraints = DictConstraints {
+            max_bytes: 1024 * 1024,
+            max_len: 1000,
+        };
+
+        // Create an array with more than 255 distinct values to force U16 codes.
+        let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
+        let arr =
+            VarBinArray::from(values.iter().map(|s| s.as_str()).collect::<Vec<_>>()).into_array();
+
+        // Wrap into a sequential stream.
+        let mut pointer = SequenceId::root();
+        let input_stream = SequentialStreamAdapter::new(
+            arr.dtype().clone(),
+            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
+        )
+        .sendable();
+
+        // Encode into dict chunks.
+        let dict_stream = dict_encode_stream(input_stream, constraints);
+
+        // Transform into codes/values streams.
+        let mut transformer = DictionaryTransformer::new(dict_stream);
+
+        // Get the first (and only) run.
+        let (codes_stream, _values_fut) = transformer
+            .next()
+            .await
+            .expect("expected at least one dictionary run");
+
+        // Codes stream dtype should be U16 since we have more than 255 distinct values.
+        assert_eq!(
+            codes_stream.dtype(),
+            &DType::Primitive(PType::U16, NonNullable),
+            "codes stream should use U16 dtype for dictionaries with >255 entries"
+        );
+    }
+}
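For context, the rule the new tests pin down can be stated as a tiny standalone sketch. The helper below, expected_codes_dtype, is hypothetical (it is not part of the vortex crates); it simply mirrors the dtype choice documented on max_len: U8 codes when the dictionary holds at most 255 entries, U16 codes otherwise.

use vortex_dtype::DType;
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::PType;

/// Hypothetical helper: the dtype a dictionary-codes stream is expected to carry,
/// given the number of unique entries in the dictionary.
fn expected_codes_dtype(dict_len: usize) -> DType {
    if dict_len <= u8::MAX as usize {
        // Up to 255 entries fit in u8 codes.
        DType::Primitive(PType::U8, NonNullable)
    } else {
        // Anything larger (bounded by the `max_len: u16` constraint) uses u16 codes.
        DType::Primitive(PType::U16, NonNullable)
    }
}

fn main() {
    // Mirrors the two regression tests above: a 100-entry dictionary yields U8 codes,
    // a 300-entry dictionary yields U16 codes.
    assert_eq!(expected_codes_dtype(100), DType::Primitive(PType::U8, NonNullable));
    assert_eq!(expected_codes_dtype(300), DType::Primitive(PType::U16, NonNullable));
}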
