diff --git a/src/lib.rs b/src/lib.rs index 2a9c6206..f9bd6df7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -899,7 +899,7 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: BuildHasher + Clone> DashMap { /// **Locking behaviour:** May deadlock if called when holding any sort of reference into the map. pub fn entry_ref<'q, Q>(&'a self, key: &'q Q) -> EntryRef<'a, 'q, K, Q, V> where - Q: Hash + Equivalent, + Q: Hash + Equivalent + ?Sized, { self._entry_ref(key) } @@ -1199,7 +1199,7 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> DashMap fn _entry_ref<'q, Q>(&'a self, key: &'q Q) -> EntryRef<'a, 'q, K, Q, V> where - Q: Hash + Equivalent, + Q: Hash + Equivalent + ?Sized, { let hash = self.hash_u64(&key); diff --git a/src/mapref/entry_ref.rs b/src/mapref/entry_ref.rs index da8e7c29..2f53ed27 100644 --- a/src/mapref/entry_ref.rs +++ b/src/mapref/entry_ref.rs @@ -6,12 +6,12 @@ use core::hash::Hash; use std::mem; /// Entry with a borrowed key. -pub enum EntryRef<'a, 'q, K, Q, V> { +pub enum EntryRef<'a, 'q, K, Q: ?Sized, V> { Occupied(OccupiedEntryRef<'a, 'q, K, Q, V>), Vacant(VacantEntryRef<'a, 'q, K, Q, V>), } -impl<'a, 'q, K: Eq + Hash, Q, V> EntryRef<'a, 'q, K, Q, V> { +impl<'a, 'q, K: Eq + Hash, Q: ?Sized, V> EntryRef<'a, 'q, K, Q, V> { /// Apply a function to the stored value if it exists. pub fn and_modify(self, f: impl FnOnce(&mut V)) -> Self { match self { @@ -26,7 +26,7 @@ impl<'a, 'q, K: Eq + Hash, Q, V> EntryRef<'a, 'q, K, Q, V> { } } -impl<'a, 'q, K: Eq + Hash + From<&'q Q>, Q, V> EntryRef<'a, 'q, K, Q, V> { +impl<'a, 'q, K: Eq + Hash + From<&'q Q>, Q: ?Sized, V> EntryRef<'a, 'q, K, Q, V> { /// Get the key of the entry. pub fn key(&self) -> &Q { match *self { @@ -114,13 +114,13 @@ impl<'a, 'q, K: Eq + Hash + From<&'q Q>, Q, V> EntryRef<'a, 'q, K, Q, V> { } } -pub struct VacantEntryRef<'a, 'q, K, Q, V> { +pub struct VacantEntryRef<'a, 'q, K, Q: ?Sized, V> { shard: RwLockWriteGuardDetached<'a>, entry: hash_table::VacantEntry<'a, (K, V)>, key: &'q Q, } -impl<'a, 'q, K: Eq + Hash, Q, V> VacantEntryRef<'a, 'q, K, Q, V> { +impl<'a, 'q, K: Eq + Hash, Q: ?Sized, V> VacantEntryRef<'a, 'q, K, Q, V> { pub(crate) fn new( shard: RwLockWriteGuardDetached<'a>, key: &'q Q, @@ -162,13 +162,13 @@ impl<'a, 'q, K: Eq + Hash, Q, V> VacantEntryRef<'a, 'q, K, Q, V> { } } -pub struct OccupiedEntryRef<'a, 'q, K, Q, V> { +pub struct OccupiedEntryRef<'a, 'q, K, Q: ?Sized, V> { shard: RwLockWriteGuardDetached<'a>, entry: hash_table::OccupiedEntry<'a, (K, V)>, key: &'q Q, } -impl<'a, 'q, K: Eq + Hash, Q, V> OccupiedEntryRef<'a, 'q, K, Q, V> { +impl<'a, 'q, K: Eq + Hash, Q: ?Sized, V> OccupiedEntryRef<'a, 'q, K, Q, V> { pub(crate) fn new( shard: RwLockWriteGuardDetached<'a>, key: &'q Q, @@ -316,4 +316,76 @@ mod tests { assert_eq!(*map.get(&1).unwrap(), 2); } + + #[test] + fn test_str_insert_into_vacant() { + let map: DashMap = DashMap::new(); + + let entry = map.entry_ref("1"); + + assert!(matches!(entry, EntryRef::Vacant(_))); + + let val = entry.insert(2); + + assert_eq!(*val, 2); + + drop(val); + + assert_eq!(*map.get("1").unwrap(), 2); + } + + #[test] + fn test_str_insert_into_occupied() { + let map: DashMap = DashMap::new(); + + map.insert("1".to_owned(), 1000); + + let entry = map.entry_ref("1"); + + assert!(matches!(&entry, EntryRef::Occupied(entry) if *entry.get() == 1000)); + + let val = entry.insert(2); + + assert_eq!(*val, 2); + + drop(val); + + assert_eq!(*map.get("1").unwrap(), 2); + } + + #[test] + fn test_str_insert_entry_into_vacant() { + let map: DashMap = DashMap::new(); + + let entry = map.entry_ref("1"); + + assert!(matches!(entry, EntryRef::Vacant(_))); + + let entry = entry.insert_entry(2); + + assert_eq!(*entry.get(), 2); + + drop(entry); + + assert_eq!(*map.get("1").unwrap(), 2); + } + + #[test] + fn test_str_insert_entry_into_occupied() { + let map: DashMap = DashMap::new(); + + map.insert("1".to_owned(), 1000); + + let entry = map.entry_ref("1"); + + assert!(matches!(&entry, EntryRef::Occupied(entry) if *entry.get() == 1000)); + + let entry = entry.insert_entry(2); + + assert_eq!(*entry.get(), 2); + + drop(entry); + + assert_eq!(*map.get("1").unwrap(), 2); + } }