1414 number_to_string , datetime_normalize , KEY_TO_VAL_STR , short_repr ,
1515 get_truncate_datetime , dict_ , add_root_to_paths )
1616from deepdiff .base import Base
17+
18+ try :
19+ import pandas
20+ except ImportError :
21+ pandas = False
22+
23+ try :
24+ import polars
25+ except ImportError :
26+ polars = False
27+
1728logger = logging .getLogger (__name__ )
1829
1930UNPROCESSED_KEY = object ()
@@ -139,6 +150,7 @@ def __init__(self,
139150 ignore_numeric_type_changes = False ,
140151 ignore_type_subclasses = False ,
141152 ignore_string_case = False ,
153+ use_enum_value = False ,
142154 exclude_obj_callback = None ,
143155 number_to_string_func = None ,
144156 ignore_private_variables = True ,
@@ -154,7 +166,7 @@ def __init__(self,
154166 "exclude_paths, include_paths, exclude_regex_paths, hasher, ignore_repetition, "
155167 "number_format_notation, apply_hash, ignore_type_in_groups, ignore_string_type_changes, "
156168 "ignore_numeric_type_changes, ignore_type_subclasses, ignore_string_case "
157- "number_to_string_func, ignore_private_variables, parent "
169+ "number_to_string_func, ignore_private_variables, parent, use_enum_value "
158170 "encodings, ignore_encoding_errors" ) % ', ' .join (kwargs .keys ()))
159171 if isinstance (hashes , MutableMapping ):
160172 self .hashes = hashes
@@ -170,6 +182,7 @@ def __init__(self,
170182 self .exclude_regex_paths = convert_item_or_items_into_compiled_regexes_else_none (exclude_regex_paths )
171183 self .hasher = default_hasher if hasher is None else hasher
172184 self .hashes [UNPROCESSED_KEY ] = []
185+ self .use_enum_value = use_enum_value
173186
174187 self .significant_digits = self .get_significant_digits (significant_digits , ignore_numeric_type_changes )
175188 self .truncate_datetime = get_truncate_datetime (truncate_datetime )
@@ -206,10 +219,10 @@ def __init__(self,
206219 sha1hex = sha1hex
207220
208221 def __getitem__ (self , obj , extract_index = 0 ):
209- return self ._getitem (self .hashes , obj , extract_index = extract_index )
222+ return self ._getitem (self .hashes , obj , extract_index = extract_index , use_enum_value = self . use_enum_value )
210223
211224 @staticmethod
212- def _getitem (hashes , obj , extract_index = 0 ):
225+ def _getitem (hashes , obj , extract_index = 0 , use_enum_value = False ):
213226 """
214227 extract_index is zero for hash and 1 for count and None to get them both.
215228 To keep it backward compatible, we only get the hash by default so it is set to zero by default.
@@ -220,6 +233,8 @@ def _getitem(hashes, obj, extract_index=0):
220233 key = BoolObj .TRUE
221234 elif obj is False :
222235 key = BoolObj .FALSE
236+ elif use_enum_value and isinstance (obj , Enum ):
237+ key = obj .value
223238
224239 result_n_count = (None , 0 )
225240
@@ -256,14 +271,14 @@ def get(self, key, default=None, extract_index=0):
256271 return self .get_key (self .hashes , key , default = default , extract_index = extract_index )
257272
258273 @staticmethod
259- def get_key (hashes , key , default = None , extract_index = 0 ):
274+ def get_key (hashes , key , default = None , extract_index = 0 , use_enum_value = False ):
260275 """
261276 get_key method for the hashes dictionary.
262277 It can extract the hash for a given key that is already calculated when extract_index=0
263278 or the count of items that went to building the object whenextract_index=1.
264279 """
265280 try :
266- result = DeepHash ._getitem (hashes , key , extract_index = extract_index )
281+ result = DeepHash ._getitem (hashes , key , extract_index = extract_index , use_enum_value = use_enum_value )
267282 except KeyError :
268283 result = default
269284 return result
@@ -444,7 +459,6 @@ def _prep_path(self, obj):
444459 type_ = obj .__class__ .__name__
445460 return KEY_TO_VAL_STR .format (type_ , obj )
446461
447-
448462 def _prep_number (self , obj ):
449463 type_ = "number" if self .ignore_numeric_type_changes else obj .__class__ .__name__
450464 if self .significant_digits is not None :
@@ -475,12 +489,14 @@ def _prep_tuple(self, obj, parent, parents_ids):
475489 return result , counts
476490
477491 def _hash (self , obj , parent , parents_ids = EMPTY_FROZENSET ):
478- """The main diff method"""
492+ """The main hash method"""
479493 counts = 1
480494
481495 if isinstance (obj , bool ):
482496 obj = self ._prep_bool (obj )
483497 result = None
498+ elif self .use_enum_value and isinstance (obj , Enum ):
499+ obj = obj .value
484500 else :
485501 result = not_hashed
486502 try :
@@ -523,6 +539,19 @@ def _hash(self, obj, parent, parents_ids=EMPTY_FROZENSET):
523539 elif isinstance (obj , tuple ):
524540 result , counts = self ._prep_tuple (obj = obj , parent = parent , parents_ids = parents_ids )
525541
542+ elif (pandas and isinstance (obj , pandas .DataFrame )):
543+ def gen ():
544+ yield ('dtype' , obj .dtypes )
545+ yield ('index' , obj .index )
546+ yield from obj .items () # which contains (column name, series tuples)
547+ result , counts = self ._prep_iterable (obj = gen (), parent = parent , parents_ids = parents_ids )
548+ elif (polars and isinstance (obj , polars .DataFrame )):
549+ def gen ():
550+ yield from obj .columns
551+ yield from list (obj .schema .items ())
552+ yield from obj .rows ()
553+ result , counts = self ._prep_iterable (obj = gen (), parent = parent , parents_ids = parents_ids )
554+
526555 elif isinstance (obj , Iterable ):
527556 result , counts = self ._prep_iterable (obj = obj , parent = parent , parents_ids = parents_ids )
528557
0 commit comments