Skip to content

Commit d1c2d9c

Browse files
authored
feat: Make FieldNames comparison more ergonomic (#4239)
Just been playing around with select exprs recently, and I think this would make it nicer to work with field names. Signed-off-by: Adam Gutglick <[email protected]>
1 parent 4e8178f commit d1c2d9c

File tree

7 files changed

+96
-20
lines changed

7 files changed

+96
-20
lines changed

vortex-array/src/arrays/struct_/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ mod test {
531531
[0i64, 1, 2, 3, 4]
532532
);
533533

534-
assert_eq!(struct_a.names(), &[FieldName::from("ys")].into());
534+
assert_eq!(struct_a.names(), &["ys"]);
535535
assert_eq!(struct_a.fields.len(), 1);
536536
assert_eq!(struct_a.len(), 5);
537537
assert_eq!(
@@ -550,7 +550,7 @@ mod test {
550550
empty.is_none(),
551551
"Expected None when removing non-existent column"
552552
);
553-
assert_eq!(struct_a.names(), &[FieldName::from("ys")].into());
553+
assert_eq!(struct_a.names(), &["ys"]);
554554
}
555555

556556
#[test]

vortex-dtype/src/dtype.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,48 @@ pub type FieldName = Arc<str>;
2323
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
2424
pub struct FieldNames(Arc<[FieldName]>);
2525

26+
impl PartialEq<&FieldNames> for FieldNames {
27+
fn eq(&self, other: &&FieldNames) -> bool {
28+
self == *other
29+
}
30+
}
31+
32+
impl PartialEq<&[&str]> for FieldNames {
33+
fn eq(&self, other: &&[&str]) -> bool {
34+
self.len() == other.len() && self.iter().zip_eq(other.iter()).all(|(l, r)| &**l == *r)
35+
}
36+
}
37+
38+
impl PartialEq<&[&str]> for &FieldNames {
39+
fn eq(&self, other: &&[&str]) -> bool {
40+
*self == other
41+
}
42+
}
43+
44+
impl<const N: usize> PartialEq<[&str; N]> for FieldNames {
45+
fn eq(&self, other: &[&str; N]) -> bool {
46+
self == other.as_slice()
47+
}
48+
}
49+
50+
impl<const N: usize> PartialEq<[&str; N]> for &FieldNames {
51+
fn eq(&self, other: &[&str; N]) -> bool {
52+
*self == other.as_slice()
53+
}
54+
}
55+
56+
impl PartialEq<&[FieldName]> for FieldNames {
57+
fn eq(&self, other: &&[FieldName]) -> bool {
58+
self.0.as_ref() == *other
59+
}
60+
}
61+
62+
impl PartialEq<&[FieldName]> for &FieldNames {
63+
fn eq(&self, other: &&[FieldName]) -> bool {
64+
self.0.as_ref() == *other
65+
}
66+
}
67+
2668
impl FieldNames {
2769
/// Returns the number of elements.
2870
pub fn len(&self) -> usize {
@@ -457,4 +499,40 @@ mod tests {
457499
assert_eq!(iter.next(), Some("b".into()));
458500
assert_eq!(iter.next(), None);
459501
}
502+
503+
#[test]
504+
fn test_field_names_equality() {
505+
let field_names = FieldNames::from(["field1", "field2", "field3"]);
506+
507+
// FieldNames == &FieldNames
508+
let field_names_ref = &field_names;
509+
assert_eq!(field_names, field_names_ref);
510+
511+
// FieldNames == &[&str]
512+
let str_slice = &["field1", "field2", "field3"][..];
513+
assert_eq!(field_names, str_slice);
514+
515+
// &FieldNames == &[&str]
516+
assert_eq!(&field_names, str_slice);
517+
518+
// FieldNames == [&str; N] (array)
519+
assert_eq!(field_names, ["field1", "field2", "field3"]);
520+
521+
// &FieldNames == [&str; N] (array)
522+
assert_eq!(&field_names, ["field1", "field2", "field3"]);
523+
524+
// FieldNames == &[FieldName]
525+
let field_name_vec: Vec<FieldName> =
526+
vec!["field1".into(), "field2".into(), "field3".into()];
527+
let field_name_slice = field_name_vec.as_slice();
528+
assert_eq!(field_names, field_name_slice);
529+
530+
// &FieldNames == &[FieldName]
531+
assert_eq!(&field_names, field_name_slice);
532+
533+
// Test inequality cases
534+
assert_ne!(field_names, &["field1", "field2"][..]);
535+
assert_ne!(field_names, ["different", "fields", "here"]);
536+
assert_ne!(field_names, &["field1", "field2", "field3", "extra"][..]);
537+
}
460538
}

vortex-dtype/src/struct_.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -436,23 +436,21 @@ mod test {
436436
let sdt = dtype.as_struct().unwrap();
437437
assert_eq!(sdt.names().len(), 2);
438438
assert_eq!(sdt.fields().len(), 2);
439-
assert_eq!(sdt.names()[0], "A".into());
440-
assert_eq!(sdt.names()[1], "B".into());
439+
assert_eq!(sdt.names(), ["A", "B"]);
441440
assert_eq!(sdt.field_by_index(0).unwrap(), a_type);
442441
assert_eq!(sdt.field_by_index(1).unwrap(), b_type);
443442

444443
let proj = sdt.project(&["B".into(), "A".into()]).unwrap();
445-
assert_eq!(proj.names()[0], "B".into());
444+
assert_eq!(proj.names(), ["B", "A"]);
446445
assert_eq!(proj.field_by_index(0).unwrap(), b_type);
447-
assert_eq!(proj.names()[1], "A".into());
448446
assert_eq!(proj.field_by_index(1).unwrap(), a_type);
449447

450448
assert_eq!(sdt.find("A").unwrap(), 0);
451449
assert_eq!(sdt.find("B").unwrap(), 1);
452450
assert!(sdt.find("C").is_none());
453451

454452
let without_a = sdt.without_field(0).unwrap();
455-
assert_eq!(without_a.names()[0], "B".into());
453+
assert_eq!(without_a.names(), ["B"]);
456454
assert_eq!(without_a.field_by_index(0).unwrap(), b_type);
457455
assert_eq!(without_a.nfields(), 1);
458456
}
@@ -478,7 +476,7 @@ mod test {
478476
let sdt = StructFields::from_iter([("A", a_type), ("B", b_type.clone())]);
479477

480478
let without_a = sdt.without_field(0).unwrap();
481-
assert_eq!(without_a.names()[0], "B".into());
479+
assert_eq!(without_a.names(), ["B"]);
482480
assert_eq!(without_a.field_by_index(0).unwrap(), b_type);
483481
assert_eq!(without_a.nfields(), 1);
484482
}
@@ -494,7 +492,7 @@ mod test {
494492
let sf2 = StructFields::from_iter([("C", child_c.clone())]);
495493

496494
let merged = StructFields::disjoint_merge(&sf1, &sf2).unwrap();
497-
assert_eq!(merged.names(), &FieldNames::from_iter(["A", "B", "C"]));
495+
assert_eq!(merged.names(), ["A", "B", "C"]);
498496
assert_eq!(
499497
merged.fields().collect_vec(),
500498
vec![child_a, child_b, child_c]

vortex-expr/src/exprs/merge.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ mod tests {
262262

263263
assert_eq!(
264264
actual_array.as_struct_typed().names(),
265-
&["a", "b", "c", "d", "e"].into()
265+
["a", "b", "c", "d", "e"]
266266
);
267267

268268
assert_eq!(
@@ -402,7 +402,7 @@ mod tests {
402402
.to_struct()
403403
.unwrap();
404404

405-
assert_eq!(actual_array.names(), &["a", "c", "b", "d"].into());
405+
assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
406406
}
407407

408408
#[test]

vortex-expr/src/exprs/pack.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ mod tests {
274274
.unwrap()
275275
.to_struct()
276276
.unwrap();
277-
let expected_names: FieldNames = ["one", "two", "three"].into();
278-
assert_eq!(actual_array.names(), &expected_names);
277+
278+
assert_eq!(actual_array.names(), ["one", "two", "three"]);
279279
assert_eq!(actual_array.validity(), &Validity::NonNullable);
280280

281281
assert_eq!(
@@ -322,8 +322,8 @@ mod tests {
322322
.unwrap()
323323
.to_struct()
324324
.unwrap();
325-
let expected_names = FieldNames::from(["one", "two", "three"]);
326-
assert_eq!(actual_array.names(), &expected_names);
325+
326+
assert_eq!(actual_array.names(), ["one", "two", "three"]);
327327

328328
assert_eq!(
329329
primitive_field(actual_array.as_ref(), &["one"])
@@ -365,8 +365,8 @@ mod tests {
365365
.unwrap()
366366
.to_struct()
367367
.unwrap();
368-
let expected_names: FieldNames = ["one", "two", "three"].into();
369-
assert_eq!(actual_array.names(), &expected_names);
368+
369+
assert_eq!(actual_array.names(), ["one", "two", "three"]);
370370
assert_eq!(actual_array.validity(), &Validity::AllValid);
371371
}
372372
}

vortex-python/src/scalar/factory.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ fn scalar_helper_inner(value: &Bound<'_, PyAny>, dtype: Option<&DType>) -> PyRes
107107
.into();
108108

109109
if let Some(DType::Struct(dtype, nullability)) = dtype {
110-
if &names != dtype.names() {
110+
if names != dtype.names() {
111111
return Err(PyValueError::new_err(format!(
112112
"Dictionary field names {:?} do not match target dtype names {:?}",
113113
&names,

vortex-scalar/src/struct_.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,11 +511,11 @@ mod tests {
511511
let projected_struct = projected.as_struct();
512512

513513
assert_eq!(projected_struct.names().len(), 1);
514-
assert_eq!(projected_struct.names()[0], "b".into());
514+
assert_eq!(projected_struct.names()[0].as_ref(), "b");
515515

516516
let fields = projected_struct.fields().unwrap();
517517
assert_eq!(fields.len(), 1);
518-
assert_eq!(fields[0].as_utf8().value().unwrap(), "hello".into());
518+
assert_eq!(fields[0].as_utf8().value().unwrap().as_str(), "hello");
519519
}
520520

521521
#[test]

0 commit comments

Comments
 (0)