Skip to content

Commit 241f71d

Browse files
committed
fix hardcoded dtype
Signed-off-by: Connor Tsui <[email protected]>
1 parent 4c65aa5 commit 241f71d

File tree

1 file changed

+111
-5
lines changed

1 file changed

+111
-5
lines changed

vortex-layout/src/layouts/dict/writer.rs

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ use vortex_array::builders::dict::DictEncoder;
2727
use vortex_array::builders::dict::dict_encoder;
2828
use vortex_btrblocks::BtrBlocksCompressor;
2929
use vortex_dtype::DType;
30-
use vortex_dtype::Nullability::NonNullable;
3130
use vortex_dtype::PType;
3231
use vortex_error::VortexError;
3332
use vortex_error::VortexResult;
@@ -49,10 +48,12 @@ use crate::sequence::SequentialStream;
4948
use crate::sequence::SequentialStreamAdapter;
5049
use crate::sequence::SequentialStreamExt;
5150

51+
/// Constraints for dictionary layout encoding.
5252
#[derive(Clone)]
5353
pub struct DictLayoutConstraints {
54+
/// Maximum size of the dictionary in bytes.
5455
pub max_bytes: usize,
55-
// Dict layout codes currently only support u16 codes
56+
/// Maximum dictionary length.
5657
pub max_len: u16,
5758
}
5859

@@ -396,13 +397,15 @@ impl Stream for DictionaryTransformer {
396397
self.active_codes_tx = Some(codes_tx.clone());
397398
self.active_values_tx = Some(values_tx);
398399

399-
// Send first codes
400+
let codes_dtype = codes.1.dtype().clone();
401+
402+
// Send first codes.
400403
self.pending_send =
401404
Some(Box::pin(async move { codes_tx.send(Ok(codes)).await }));
402405

403-
// Create output streams
406+
// Create output streams.
404407
let codes_stream = SequentialStreamAdapter::new(
405-
DType::Primitive(PType::U16, NonNullable),
408+
codes_dtype,
406409
codes_rx.into_stream().boxed(),
407410
)
408411
.sendable();
@@ -514,3 +517,106 @@ fn encode_chunk(mut encoder: Box<dyn DictEncoder>, chunk: &dyn Array) -> Encodin
514517
fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
515518
(encoded_len < array.len()).then(|| array.slice(encoded_len..array.len()))
516519
}
520+
521+
#[cfg(test)]
522+
mod tests {
523+
use futures::StreamExt;
524+
use vortex_array::IntoArray;
525+
use vortex_array::arrays::VarBinArray;
526+
use vortex_array::builders::dict::DictConstraints;
527+
use vortex_dtype::DType;
528+
use vortex_dtype::Nullability::NonNullable;
529+
use vortex_dtype::PType;
530+
531+
use super::DictionaryTransformer;
532+
use super::dict_encode_stream;
533+
use crate::sequence::SequenceId;
534+
use crate::sequence::SequentialStream;
535+
use crate::sequence::SequentialStreamAdapter;
536+
use crate::sequence::SequentialStreamExt;
537+
538+
/// Regression test for a bug where the codes stream dtype was hardcoded to U16 instead of
539+
/// using the actual codes dtype from the array. When `max_len <= 255`, the dict encoder
540+
/// produces U8 codes, but the stream was incorrectly typed as U16, causing a dtype mismatch
541+
/// assertion failure in [`SequentialStreamAdapter`].
542+
#[tokio::test]
543+
async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
544+
// Use max_len = 100 to force U8 codes (since 100 <= 255).
545+
let constraints = DictConstraints {
546+
max_bytes: 1024 * 1024,
547+
max_len: 100,
548+
};
549+
550+
// Create a simple string array with a few unique values.
551+
let arr = VarBinArray::from(vec!["hello", "world", "hello", "world"]).into_array();
552+
553+
// Wrap into a sequential stream.
554+
let mut pointer = SequenceId::root();
555+
let input_stream = SequentialStreamAdapter::new(
556+
arr.dtype().clone(),
557+
futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
558+
)
559+
.sendable();
560+
561+
// Encode into dict chunks.
562+
let dict_stream = dict_encode_stream(input_stream, constraints);
563+
564+
// Transform into codes/values streams.
565+
let mut transformer = DictionaryTransformer::new(dict_stream);
566+
567+
// Get the first (and only) run.
568+
let (codes_stream, _values_fut) = transformer
569+
.next()
570+
.await
571+
.expect("expected at least one dictionary run");
572+
573+
// The key assertion: codes stream dtype should be U8, not U16.
574+
assert_eq!(
575+
codes_stream.dtype(),
576+
&DType::Primitive(PType::U8, NonNullable),
577+
"codes stream should use U8 dtype for small dictionaries, not U16"
578+
);
579+
}
580+
581+
/// Test that the codes stream uses U16 dtype when the dictionary has more than 255 entries.
582+
#[tokio::test]
583+
async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
584+
// Use max_len = 1000 to allow U16 codes (since 1000 > 255).
585+
let constraints = DictConstraints {
586+
max_bytes: 1024 * 1024,
587+
max_len: 1000,
588+
};
589+
590+
// Create an array with more than 255 distinct values to force U16 codes.
591+
let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
592+
let arr =
593+
VarBinArray::from(values.iter().map(|s| s.as_str()).collect::<Vec<_>>()).into_array();
594+
595+
// Wrap into a sequential stream.
596+
let mut pointer = SequenceId::root();
597+
let input_stream = SequentialStreamAdapter::new(
598+
arr.dtype().clone(),
599+
futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
600+
)
601+
.sendable();
602+
603+
// Encode into dict chunks.
604+
let dict_stream = dict_encode_stream(input_stream, constraints);
605+
606+
// Transform into codes/values streams.
607+
let mut transformer = DictionaryTransformer::new(dict_stream);
608+
609+
// Get the first (and only) run.
610+
let (codes_stream, _values_fut) = transformer
611+
.next()
612+
.await
613+
.expect("expected at least one dictionary run");
614+
615+
// Codes stream dtype should be U16 since we have more than 255 distinct values.
616+
assert_eq!(
617+
codes_stream.dtype(),
618+
&DType::Primitive(PType::U16, NonNullable),
619+
"codes stream should use U16 dtype for dictionaries with >255 entries"
620+
);
621+
}
622+
}

0 commit comments

Comments
 (0)