diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index b3e9e6d94c13..1b0bb9a13f56 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -474,12 +474,8 @@ impl<'m, 'v> VariantObject<'m, 'v> { let field_name = self.metadata.get_field_by(field_id)?; let start_offset = field_offsets[i]; - let end_offset = field_offsets[i + 1]; - let value_bytes = slice_from_slice( - self.value, - self.header.values_start_byte + start_offset - ..self.header.values_start_byte + end_offset, - )?; + let value_bytes = + slice_from_slice(self.value, self.header.values_start_byte + start_offset..)?; let variant = Variant::try_new_with_metadata(self.metadata, value_bytes)?; fields.push((field_name, variant)); diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs index 34470995e5fe..7c165967c9e8 100644 --- a/parquet-variant/tests/variant_interop.rs +++ b/parquet-variant/tests/variant_interop.rs @@ -23,9 +23,8 @@ use std::fs; use std::path::{Path, PathBuf}; -use arrow_schema::ArrowError; use chrono::NaiveDate; -use parquet_variant::{Variant, VariantMetadata}; +use parquet_variant::Variant; fn cases_dir() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")) @@ -34,11 +33,24 @@ fn cases_dir() -> PathBuf { .join("variant") } -fn load_case(name: &str) -> Result<(Vec, Vec), ArrowError> { - let root = cases_dir(); - let meta = fs::read(root.join(format!("{name}.metadata")))?; - let val = fs::read(root.join(format!("{name}.value")))?; - Ok((meta, val)) +struct Case { + metadata: Vec, + value: Vec, +} + +impl Case { + /// Load the case with the given name from the parquet testing repository. + fn load(name: &str) -> Self { + let root = cases_dir(); + let metadata = fs::read(root.join(format!("{name}.metadata"))).unwrap(); + let value = fs::read(root.join(format!("{name}.value"))).unwrap(); + Self { metadata, value } + } + + /// Return the Variant for this case. + fn variant(&self) -> Variant<'_, '_> { + Variant::try_new(&self.metadata, &self.value).expect("Failed to parse variant") + } } /// Return a list of the values from the parquet testing repository: @@ -67,47 +79,98 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { ("short_string", Variant::ShortString("Less than 64 bytes (❤\u{fe0f} with utf8)")), ] } - -fn get_non_primitive_cases() -> Vec<&'static str> { - vec!["object_primitive", "array_primitive"] -} - #[test] -fn variant_primitive() -> Result<(), ArrowError> { +fn variant_primitive() { let cases = get_primitive_cases(); for (case, want) in cases { - let (metadata, value) = load_case(case)?; - let got = Variant::try_new(&metadata, &value)?; + let case = Case::load(case); + let got = case.variant(); assert_eq!(got, want); } - Ok(()) } - #[test] -fn variant_non_primitive() -> Result<(), ArrowError> { - let cases = get_non_primitive_cases(); - for case in cases { - let (metadata, value) = load_case(case)?; - let variant_metadata = VariantMetadata::try_new(&metadata)?; - let variant = Variant::try_new(&metadata, &value)?; - match case { - "object_primitive" => { - assert!(matches!(variant, Variant::Object(_))); - assert_eq!(variant_metadata.dictionary_size(), 7); - let dict_val = variant_metadata.get_field_by(0)?; - assert_eq!(dict_val, "int_field"); - } - "array_primitive" => match variant { - Variant::List(arr) => { - let v = arr.get(0)?; - assert!(matches!(v, Variant::Int8(2))); - let v = arr.get(1)?; - assert!(matches!(v, Variant::Int8(1))); - } - _ => panic!("expected an array"), +fn variant_object_empty() { + let case = Case::load("object_empty"); + let Variant::Object(variant_object) = case.variant() else { + panic!("expected an object"); + }; + assert_eq!(variant_object.len(), 0); + assert!(variant_object.is_empty()); +} +#[test] +fn variant_object_primitive() { + // the data is defined in + // https://github.com/apache/parquet-testing/blob/84d525a8731cec345852fb4ea2e7c581fbf2ef29/variant/data_dictionary.json#L46-L53 + // + // ```json + // " "object_primitive": { + // "boolean_false_field": false, + // "boolean_true_field": true, + // "double_field": 1.23456789, + // "int_field": 1, + // "null_field": null, + // "string_field": "Apache Parquet", + // "timestamp_field": "2025-04-16T12:34:56.78" + // }, + // ``` + let case = Case::load("object_primitive"); + let Variant::Object(variant_object) = case.variant() else { + panic!("expected an object"); + }; + let expected_fields = vec![ + ("boolean_false_field", Variant::BooleanFalse), + ("boolean_true_field", Variant::BooleanTrue), + // spark wrote this as a decimal4 (not a double) + ( + "double_field", + Variant::Decimal4 { + integer: 123456789, + scale: 8, }, - _ => unreachable!(), - } + ), + ("int_field", Variant::Int8(1)), + ("null_field", Variant::Null), + ("string_field", Variant::ShortString("Apache Parquet")), + ( + // apparently spark wrote this as a string (not a timestamp) + "timestamp_field", + Variant::ShortString("2025-04-16T12:34:56.78"), + ), + ]; + let actual_fields: Vec<_> = variant_object.fields().unwrap().collect(); + assert_eq!(actual_fields, expected_fields); +} +#[test] +fn variant_array_primitive() { + // The data is defined in + // https://github.com/apache/parquet-testing/blob/84d525a8731cec345852fb4ea2e7c581fbf2ef29/variant/data_dictionary.json#L24-L29 + // + // ```json + // "array_primitive": [ + // 2, + // 1, + // 5, + // 9 + // ], + // ``` + let case = Case::load("array_primitive"); + let Variant::List(list) = case.variant() else { + panic!("expected an array"); + }; + let expected = vec![ + Variant::Int8(2), + Variant::Int8(1), + Variant::Int8(5), + Variant::Int8(9), + ]; + let actual: Vec<_> = list.values().unwrap().collect(); + assert_eq!(actual, expected); + + // Call `get` for each individual element + for (i, expected_value) in expected.iter().enumerate() { + let got = list.get(i).unwrap(); + assert_eq!(&got, expected_value); } - Ok(()) } + +// TODO: Add tests for object_nested and array_nested