@@ -89,7 +89,30 @@ impl DType {
8989
9090 /// Check if `self` and `other` are equal, ignoring nullability
9191 pub fn eq_ignore_nullability ( & self , other : & Self ) -> bool {
92- self . as_nullable ( ) . eq ( & other. as_nullable ( ) )
92+ match ( self , other) {
93+ ( Null , Null ) => true ,
94+ ( Null , _) => false ,
95+ ( Bool ( _) , Bool ( _) ) => true ,
96+ ( Bool ( _) , _) => false ,
97+ ( Primitive ( lhs_ptype, _) , Primitive ( rhs_ptype, _) ) => lhs_ptype == rhs_ptype,
98+ ( Primitive ( ..) , _) => false ,
99+ ( Utf8 ( _) , Utf8 ( _) ) => true ,
100+ ( Utf8 ( _) , _) => false ,
101+ ( Binary ( _) , Binary ( _) ) => true ,
102+ ( Binary ( _) , _) => false ,
103+ ( List ( lhs_dtype, _) , List ( rhs_dtype, _) ) => lhs_dtype. eq_ignore_nullability ( rhs_dtype) ,
104+ ( List ( ..) , _) => false ,
105+ ( Struct ( lhs_dtype, _) , Struct ( rhs_dtype, _) ) => {
106+ ( lhs_dtype. names ( ) == rhs_dtype. names ( ) )
107+ && ( lhs_dtype
108+ . dtypes ( )
109+ . zip_eq ( rhs_dtype. dtypes ( ) )
110+ . all ( |( l, r) | l. eq_ignore_nullability ( & r) ) )
111+ }
112+ ( Struct ( ..) , _) => false ,
113+ ( Extension ( lhs_extdtype) , Extension ( rhs_extdtype) ) => lhs_extdtype == rhs_extdtype,
114+ ( Extension ( _) , _) => false ,
115+ }
93116 }
94117
95118 /// Check if `self` is a `StructDType`
0 commit comments