Skip to content

Commit 2b2806a

Browse files
committed
Use unorderedHash, LinkedHashMap/Set are equals to other Maps
Plus review cleanups. Reuse a single iterator implementation.
1 parent 200f520 commit 2b2806a

File tree

3 files changed

+111
-138
lines changed

3 files changed

+111
-138
lines changed

library/src/scala/collection/mutable/HashMap.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,12 @@ class HashMap[K, V](initialCapacity: Int, loadFactor: Double)
109109
}
110110
this
111111
case lhm: mutable.LinkedHashMap[K, V] =>
112-
val iter = lhm.entryIterator
113-
while (iter.hasNext) {
114-
val entry = iter.next()
115-
put0(entry.key, entry.value,entry.hash,getOld = false)
116-
}
117-
this
112+
val iter = lhm.entryIterator
113+
while (iter.hasNext) {
114+
val entry = iter.next()
115+
put0(entry.key, entry.value, entry.hash, getOld = false)
116+
}
117+
this
118118
case thatMap: Map[K, V] =>
119119
thatMap.foreachEntry { (key: K, value: V) =>
120120
put0(key, value, improveHash(key.##), getOld = false)

library/src/scala/collection/mutable/LinkedHashMap.scala

Lines changed: 89 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class LinkedHashMap[K, V]
9393
}
9494
override def sizeHint(size: Int): Unit = {
9595
val target = tableSizeFor(((size + 1).toDouble / LinkedHashMap.defaultLoadFactor).toInt)
96-
if(target > table.length) growTable(target)
96+
if (target > table.length) growTable(target)
9797
}
9898

9999
override def contains(key: K): Boolean = {
@@ -131,23 +131,23 @@ class LinkedHashMap[K, V]
131131
// subclasses of LinkedHashMap might customise `get` ...
132132
super.getOrElseUpdate(key, defaultValue)
133133
} else {
134-
val hash = computeHash(key)
135-
val idx = index(hash)
136-
val nd = table(idx) match {
137-
case null => null
138-
case nd => nd.findEntry(key, hash)
139-
}
140-
if (nd != null) nd.value
141-
else {
142-
val table0 = table
143-
val default = defaultValue
144-
if (contentSize + 1 >= threshold) growTable(table.length * 2)
145-
// Avoid recomputing index if the `defaultValue()` or new element hasn't triggered a table resize.
146-
val newIdx = if (table0 eq table) idx else index(hash)
147-
put0(key, default, false, hash, newIdx)
148-
default
149-
}
150-
}
134+
val hash = computeHash(key)
135+
val idx = index(hash)
136+
val nd = table(idx) match {
137+
case null => null
138+
case nd => nd.findEntry(key, hash)
139+
}
140+
if (nd != null) nd.value
141+
else {
142+
val table0 = table
143+
val default = defaultValue
144+
if (contentSize + 1 >= threshold) growTable(table.length * 2)
145+
// Avoid recomputing index if the `defaultValue()` or new element hasn't triggered a table resize.
146+
val newIdx = if (table0 eq table) idx else index(hash)
147+
put0(key, default, false, hash, newIdx)
148+
default
149+
}
150+
}
151151
}
152152

153153
private[this] def removeEntry0(elem: K): Entry = removeEntry0(elem, computeHash(elem))
@@ -215,104 +215,103 @@ class LinkedHashMap[K, V]
215215
this
216216
}
217217

218-
def iterator: Iterator[(K, V)] = new AbstractIterator[(K, V)] {
218+
private[this] abstract class LinkedHashMapIterator[T] extends AbstractIterator[T] {
219219
private[this] var cur = firstEntry
220-
def hasNext = cur ne null
221-
def next() =
222-
if (hasNext) { val res = (cur.key, cur.value); cur = cur.later; res }
220+
def extract(nd: Entry): T
221+
def hasNext: Boolean = cur ne null
222+
def next(): T =
223+
if (hasNext) { val r = extract(cur); cur = cur.later; r }
223224
else Iterator.empty.next()
224225
}
225226

227+
def iterator: Iterator[(K, V)] =
228+
if (size == 0) Iterator.empty
229+
else new LinkedHashMapIterator[(K, V)] {
230+
def extract(nd: Entry): (K, V) = (nd.key, nd.value)
231+
}
232+
226233
protected class LinkedKeySet extends KeySet {
227234
override def iterableFactory: IterableFactory[collection.Set] = LinkedHashSet
228235
}
229236

230237
override def keySet: collection.Set[K] = new LinkedKeySet
231238

232-
override def keysIterator: Iterator[K] = new AbstractIterator[K] {
233-
private[this] var cur = firstEntry
234-
def hasNext = cur ne null
235-
def next() =
236-
if (hasNext) { val res = cur.key; cur = cur.later; res }
237-
else Iterator.empty.next()
238-
}
239+
override def keysIterator: Iterator[K] =
240+
if (size == 0) Iterator.empty
241+
else new LinkedHashMapIterator[K] {
242+
def extract(nd: Entry): K = nd.key
243+
}
239244

240-
private[collection] def entryIterator: Iterator[Entry] = new AbstractIterator[Entry] {
241-
private[this] var cur = firstEntry
245+
private[collection] def entryIterator: Iterator[Entry] =
246+
if (size == 0) Iterator.empty
247+
else new LinkedHashMapIterator[Entry] {
248+
def extract(nd: Entry): Entry = nd
249+
}
242250

243-
def hasNext = cur ne null
244251

245-
def next() =
246-
if (hasNext) {
247-
val res = cur; cur = cur.later; res
248-
}
249-
else Iterator.empty.next()
250-
}
251252
// Override updateWith for performance, so we can do the update while hashing
252253
// the input key only once and performing one lookup into the hash table
253254
override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
254255
if (getClass != classOf[LinkedHashMap[_, _]]) {
255256
// subclasses of LinkedHashMap might customise `get` ...
256257
super.updateWith(key)(remappingFunction)
257258
} else {
258-
val hash = computeHash(key)
259-
val indexedHash = index(hash)
260-
261-
var foundEntry: Entry = null
262-
var previousEntry: Entry = null
263-
table(indexedHash) match {
264-
case null =>
265-
case nd =>
266-
@tailrec
267-
def findEntry(prev: Entry, nd: Entry, k: K, h: Int): Unit = {
268-
if (h == nd.hash && k == nd.key) {
269-
previousEntry = prev
270-
foundEntry = nd
271-
}
272-
else if ((nd.next eq null) || (nd.hash > h)) ()
273-
else findEntry(nd, nd.next, k, h)
274-
}
275-
276-
findEntry(null, nd, key, hash)
259+
val hash = computeHash(key)
260+
val indexedHash = index(hash)
261+
262+
var foundEntry: Entry = null
263+
var previousEntry: Entry = null
264+
table(indexedHash) match {
265+
case null =>
266+
case nd =>
267+
@tailrec
268+
def findEntry(prev: Entry, nd: Entry, k: K, h: Int): Unit = {
269+
if (h == nd.hash && k == nd.key) {
270+
previousEntry = prev
271+
foundEntry = nd
272+
}
273+
else if ((nd.next eq null) || (nd.hash > h)) ()
274+
else findEntry(nd, nd.next, k, h)
277275
}
278276

279-
val previousValue = foundEntry match {
280-
case null => None
281-
case nd => Some(nd.value)
282-
}
277+
findEntry(null, nd, key, hash)
278+
}
279+
280+
val previousValue = foundEntry match {
281+
case null => None
282+
case nd => Some(nd.value)
283+
}
283284

284-
val nextValue = remappingFunction(previousValue)
285+
val nextValue = remappingFunction(previousValue)
285286

286-
(previousValue, nextValue) match {
287-
case (None, None) => // do nothing
287+
(previousValue, nextValue) match {
288+
case (None, None) => // do nothing
288289

289-
case (Some(_), None) =>
290-
if (previousEntry != null) previousEntry.next = foundEntry.next
291-
else table(indexedHash) = foundEntry.next
292-
deleteEntry(foundEntry)
293-
contentSize -= 1
290+
case (Some(_), None) =>
291+
if (previousEntry != null) previousEntry.next = foundEntry.next
292+
else table(indexedHash) = foundEntry.next
293+
deleteEntry(foundEntry)
294+
contentSize -= 1
294295

295-
case (None, Some(value)) =>
296-
val newIndexedHash =
297-
if (contentSize + 1 >= threshold) {
298-
growTable(table.length * 2)
299-
index(hash)
300-
} else indexedHash
301-
put0(key, value, false, hash, newIndexedHash)
296+
case (None, Some(value)) =>
297+
val newIndexedHash =
298+
if (contentSize + 1 >= threshold) {
299+
growTable(table.length * 2)
300+
index(hash)
301+
} else indexedHash
302+
put0(key, value, false, hash, newIndexedHash)
302303

303-
case (Some(_), Some(newValue)) => foundEntry.value = newValue
304-
}
305-
nextValue
304+
case (Some(_), Some(newValue)) => foundEntry.value = newValue
305+
}
306+
nextValue
306307
}
307308
}
308309

309-
override def valuesIterator: Iterator[V] = new AbstractIterator[V] {
310-
private[this] var cur = firstEntry
311-
def hasNext = cur ne null
312-
def next() =
313-
if (hasNext) { val res = cur.value; cur = cur.later; res }
314-
else Iterator.empty.next()
315-
}
310+
override def valuesIterator: Iterator[V] =
311+
if (size == 0) Iterator.empty
312+
else new LinkedHashMapIterator[V] {
313+
def extract(nd: Entry): V = nd.value
314+
}
316315

317316

318317
override def foreach[U](f: ((K, V)) => U): Unit = {
@@ -452,31 +451,18 @@ class LinkedHashMap[K, V]
452451
}
453452
}
454453

455-
override def hashCode(): Int = {
456-
abstract class LinkedHashMapIterator[A](val firstentry: Entry) extends AbstractIterator[A] {
457-
var cur = firstentry
458-
def extract(nd: Entry): A
459-
def hasNext: Boolean = cur ne null
460-
def next(): A =
461-
if(hasNext) {
462-
val r = extract(cur)
463-
cur = cur.later
464-
r
465-
} else Iterator.empty.next()
466-
}
467-
454+
override def hashCode: Int = {
468455
if (isEmpty) MurmurHash3.emptyMapHash
469456
else {
470-
val tupleHashIterator = new LinkedHashMapIterator[Any](firstEntry) {
457+
val tupleHashIterator = new LinkedHashMapIterator[Any] {
471458
var hash: Int = 0
472459
override def hashCode: Int = hash
473-
override def extract(nd: Entry): Any = {
460+
override def extract(nd: Entry): Any = {
474461
hash = MurmurHash3.tuple2Hash(unimproveHash(nd.hash), nd.value.##)
475462
this
476463
}
477464
}
478-
479-
MurmurHash3.orderedHash(tupleHashIterator, MurmurHash3.mapSeed)
465+
MurmurHash3.unorderedHash(tupleHashIterator, MurmurHash3.mapSeed)
480466
}
481467
}
482468
@nowarn("""cat=deprecation&origin=scala\.collection\.Iterable\.stringPrefix""")

library/src/scala/collection/mutable/LinkedHashSet.scala

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LinkedHashSet[A]
4343
// stepper is not overridden to use XTableStepper because that stepper would not return the
4444
// elements in insertion order
4545

46-
type Entry = LinkedHashSet.Entry[A]
46+
/*private*/ type Entry = LinkedHashSet.Entry[A]
4747

4848
protected var firstEntry: Entry = null
4949

@@ -100,23 +100,21 @@ class LinkedHashSet[A]
100100

101101
override def remove(elem: A): Boolean = remove0(elem, computeHash(elem))
102102

103-
def iterator: Iterator[A] = new AbstractIterator[A] {
103+
private[this] abstract class LinkedHashSetIterator[T] extends AbstractIterator[T] {
104104
private[this] var cur = firstEntry
105-
def hasNext = cur ne null
106-
def next() =
107-
if (hasNext) { val res = cur.key; cur = cur.later; res }
105+
def extract(nd: Entry): T
106+
def hasNext: Boolean = cur ne null
107+
def next(): T =
108+
if (hasNext) { val r = extract(cur); cur = cur.later; r }
108109
else Iterator.empty.next()
109110
}
110-
private[collection] def entryIterator: Iterator[Entry] = new AbstractIterator[Entry] {
111-
private[this] var cur = firstEntry
112111

113-
def hasNext = cur ne null
112+
def iterator: Iterator[A] = new LinkedHashSetIterator[A] {
113+
override def extract(nd: Entry): A = nd.key
114+
}
114115

115-
def next() =
116-
if (hasNext) {
117-
val res = cur; cur = cur.later; res
118-
}
119-
else Iterator.empty.next()
116+
private[collection] def entryIterator: Iterator[Entry] = new LinkedHashSetIterator[Entry] {
117+
override def extract(nd: Entry): Entry = nd
120118
}
121119

122120
override def foreach[U](f: A => U): Unit = {
@@ -284,22 +282,11 @@ class LinkedHashSet[A]
284282
}
285283
}
286284

287-
override def hashCode(): Int = {
288-
abstract class LinkedHashSetIterator[B](val firstentry: Entry) extends AbstractIterator[B] {
289-
var cur = firstentry
290-
def extract(nd: Entry): B
291-
def hasNext: Boolean = cur ne null
292-
def next(): B =
293-
if (hasNext) {
294-
val r = extract(cur)
295-
cur = cur.later
296-
r
297-
} else Iterator.empty.next()
298-
}
299-
300-
val setHashIterator = if (isEmpty) this.iterator
285+
override def hashCode: Int = {
286+
val setHashIterator =
287+
if (isEmpty) this.iterator
301288
else {
302-
new LinkedHashSetIterator[Any](firstEntry) {
289+
new LinkedHashSetIterator[Any] {
303290
var hash: Int = 0
304291
override def hashCode: Int = hash
305292
override def extract(nd: Entry): Any = {
@@ -308,7 +295,7 @@ class LinkedHashSet[A]
308295
}
309296
}
310297
}
311-
MurmurHash3.orderedHash(setHashIterator, MurmurHash3.setSeed)
298+
MurmurHash3.unorderedHash(setHashIterator, MurmurHash3.setSeed)
312299
}
313300

314301
@nowarn("""cat=deprecation&origin=scala\.collection\.Iterable\.stringPrefix""")

0 commit comments

Comments
 (0)