Skip to content

Commit 08ed8aa

Browse files
authored
feat: Don't allocate another DType in DType::eq_ignore_nullability (#1948)
1 parent 16daa52 commit 08ed8aa

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

vortex-dtype/src/dtype.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,30 @@ impl DType {
8989

9090
/// Check if `self` and `other` are equal, ignoring nullability
9191
pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
92-
self.as_nullable().eq(&other.as_nullable())
92+
match (self, other) {
93+
(Null, Null) => true,
94+
(Null, _) => false,
95+
(Bool(_), Bool(_)) => true,
96+
(Bool(_), _) => false,
97+
(Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
98+
(Primitive(..), _) => false,
99+
(Utf8(_), Utf8(_)) => true,
100+
(Utf8(_), _) => false,
101+
(Binary(_), Binary(_)) => true,
102+
(Binary(_), _) => false,
103+
(List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
104+
(List(..), _) => false,
105+
(Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
106+
(lhs_dtype.names() == rhs_dtype.names())
107+
&& (lhs_dtype
108+
.dtypes()
109+
.zip_eq(rhs_dtype.dtypes())
110+
.all(|(l, r)| l.eq_ignore_nullability(&r)))
111+
}
112+
(Struct(..), _) => false,
113+
(Extension(lhs_extdtype), Extension(rhs_extdtype)) => lhs_extdtype == rhs_extdtype,
114+
(Extension(_), _) => false,
115+
}
93116
}
94117

95118
/// Check if `self` is a `StructDType`

vortex-dtype/src/extension.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,26 @@ impl From<&[u8]> for ExtMetadata {
5858
}
5959

6060
/// A type descriptor for an extension type
61-
#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)]
61+
#[derive(Debug, Clone, PartialOrd, Eq)]
6262
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
6363
pub struct ExtDType {
6464
id: ExtID,
6565
storage_dtype: Arc<DType>,
6666
metadata: Option<ExtMetadata>,
6767
}
6868

69+
impl PartialEq for ExtDType {
70+
fn eq(&self, other: &Self) -> bool {
71+
self.id == other.id
72+
}
73+
}
74+
75+
impl std::hash::Hash for ExtDType {
76+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
77+
self.id.hash(state);
78+
}
79+
}
80+
6981
impl ExtDType {
7082
/// Creates a new `ExtDType`.
7183
///

0 commit comments

Comments
 (0)