Skip to content

Commit 152efb0

Browse files
authored
Merge pull request scala/scala#10448 from dragonfly-ai/2.13.x
Prevent ArrayBuilder capacity overflow/infinite looping in ArrayBuilder.ensureSize(size:Int).
2 parents 3646036 + 1f61034 commit 152efb0

File tree

5 files changed

+47
-46
lines changed

5 files changed

+47
-46
lines changed

library/src/scala/collection/IterableOnce.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,6 @@ object IterableOnce {
283283
case src: Iterable[A] => src.copyToArray[B](xs, start, len)
284284
case src => src.iterator.copyToArray[B](xs, start, len)
285285
}
286-
287-
@inline private[collection] def checkArraySizeWithinVMLimit(size: Int): Unit = {
288-
import scala.runtime.PStatics.VM_MaxArraySize
289-
if (size > VM_MaxArraySize) {
290-
throw new Exception(s"Size of array-backed collection exceeds VM array size limit of ${VM_MaxArraySize}")
291-
}
292-
}
293286
}
294287

295288
/** This implementation trait can be mixed into an `IterableOnce` to get the basic methods that are shared between

library/src/scala/collection/mutable/ArrayBuffer.scala

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@ package collection
1515
package mutable
1616

1717
import java.util.Arrays
18-
19-
import scala.annotation.nowarn
20-
import scala.annotation.tailrec
18+
import scala.annotation.{nowarn, tailrec}
2119
import scala.collection.Stepper.EfficientSplit
2220
import scala.collection.generic.DefaultSerializable
21+
import scala.runtime.PStatics.VM_MaxArraySize
2322

2423
/** An implementation of the `Buffer` class using an array to
2524
* represent the assembled sequence internally. Append, update and random
@@ -70,13 +69,6 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
7069
array = ArrayBuffer.ensureSize(array, size0, n)
7170
}
7271

73-
// TODO 3.T: should be `protected`, perhaps `protected[this]`
74-
/** Ensure that the internal array has at least `n` additional cells more than `size0`. */
75-
private[mutable] def ensureAdditionalSize(n: Int): Unit = {
76-
// `.toLong` to ensure `Long` arithmetic is used and prevent `Int` overflow
77-
array = ArrayBuffer.ensureSize(array, size0, size0.toLong + n)
78-
}
79-
8072
/** Uses the given size to resize internal storage, if necessary.
8173
*
8274
* @param size Expected maximum number of elements.
@@ -147,10 +139,10 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
147139

148140
def addOne(elem: A): this.type = {
149141
mutationCount += 1
150-
ensureAdditionalSize(1)
151-
val oldSize = size0
152-
size0 = oldSize + 1
153-
this(oldSize) = elem
142+
val newSize = size0 + 1
143+
ensureSize(newSize)
144+
size0 = newSize
145+
this(size0 - 1) = elem
154146
this
155147
}
156148

@@ -161,7 +153,7 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
161153
val elemsLength = elems.size0
162154
if (elemsLength > 0) {
163155
mutationCount += 1
164-
ensureAdditionalSize(elemsLength)
156+
ensureSize(size0 + elemsLength)
165157
Array.copy(elems.array, 0, array, length, elemsLength)
166158
size0 = length + elemsLength
167159
}
@@ -173,7 +165,7 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
173165
def insert(@deprecatedName("n", "2.13.0") index: Int, elem: A): Unit = {
174166
checkWithinBounds(index, index)
175167
mutationCount += 1
176-
ensureAdditionalSize(1)
168+
ensureSize(size0 + 1)
177169
Array.copy(array, index, array, index + 1, size0 - index)
178170
size0 += 1
179171
this(index) = elem
@@ -191,7 +183,7 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
191183
val elemsLength = elems.size
192184
if (elemsLength > 0) {
193185
mutationCount += 1
194-
ensureAdditionalSize(elemsLength)
186+
ensureSize(size0 + elemsLength)
195187
val len = size0
196188
Array.copy(array, index, array, index + elemsLength, len - index)
197189
// if `elems eq this`, this copy is safe because
@@ -314,24 +306,35 @@ object ArrayBuffer extends StrictOptimizedSeqFactory[ArrayBuffer] {
314306

315307
def empty[A]: ArrayBuffer[A] = new ArrayBuffer[A]()
316308

309+
@inline private def checkArrayLengthLimit(length: Int, currentLength: Int): Unit =
310+
if (length > VM_MaxArraySize)
311+
throw new Exception(s"Array of array-backed collection exceeds VM length limit of $VM_MaxArraySize. Requested length: $length; current length: $currentLength")
312+
else if (length < 0)
313+
throw new Exception(s"Overflow while resizing array of array-backed collection. Requested length: $length; current length: $currentLength; increase: ${length - currentLength}")
314+
317315
/**
316+
* The increased size for an array-backed collection.
317+
*
318318
* @param arrayLen the length of the backing array
319319
* @param targetLen the minimum length to resize up to
320-
* @return -1 if no resizing is needed, or the size for the new array otherwise
320+
* @return
321+
* - `-1` if no resizing is needed, else
322+
* - `VM_MaxArraySize` if `arrayLen` is too large to be doubled, else
323+
* - `max(targetLen, arrayLen * 2, , DefaultInitialSize)`.
324+
* - Throws an exception if `targetLen` exceeds `VM_MaxArraySize` or is negative (overflow).
321325
*/
322-
private def resizeUp(arrayLen: Long, targetLen: Long): Int = {
323-
if (targetLen <= arrayLen) -1
326+
private[mutable] def resizeUp(arrayLen: Int, targetLen: Int): Int = {
327+
if (targetLen > 0 && targetLen <= arrayLen) -1
324328
else {
325-
if (targetLen > Int.MaxValue) throw new Exception(s"Collections cannot have more than ${Int.MaxValue} elements")
326-
IterableOnce.checkArraySizeWithinVMLimit(targetLen.toInt) // safe because `targetSize <= Int.MaxValue`
327-
328-
val newLen = math.max(targetLen, math.max(arrayLen * 2, DefaultInitialSize))
329-
math.min(newLen, scala.runtime.PStatics.VM_MaxArraySize).toInt
329+
checkArrayLengthLimit(targetLen, arrayLen)
330+
if (arrayLen > VM_MaxArraySize / 2) VM_MaxArraySize
331+
else math.max(targetLen, math.max(arrayLen * 2, DefaultInitialSize))
330332
}
331333
}
334+
332335
// if necessary, copy (curSize elements of) the array to a new array of capacity n.
333336
// Should use Array.copyOf(array, resizeEnsuring(array.length))?
334-
private def ensureSize(array: Array[AnyRef], curSize: Int, targetSize: Long): Array[AnyRef] = {
337+
private def ensureSize(array: Array[AnyRef], curSize: Int, targetSize: Int): Array[AnyRef] = {
335338
val newLen = resizeUp(array.length, targetSize)
336339
if (newLen < 0) array
337340
else {

library/src/scala/collection/mutable/ArrayBuilder.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
package scala.collection
1414
package mutable
1515

16+
import scala.collection.mutable.ArrayBuffer.resizeUp
1617
import scala.reflect.ClassTag
1718

1819
/** A builder class for arrays.
@@ -34,15 +35,11 @@ sealed abstract class ArrayBuilder[T]
3435
override def knownSize: Int = size
3536

3637
protected[this] final def ensureSize(size: Int): Unit = {
37-
if (capacity < size || capacity == 0) {
38-
var newsize = if (capacity == 0) 16 else capacity * 2
39-
while (newsize < size) newsize *= 2
40-
resize(newsize)
41-
}
38+
val newLen = resizeUp(capacity, size)
39+
if (newLen > 0) resize(newLen)
4240
}
4341

44-
override final def sizeHint(size: Int): Unit =
45-
if (capacity < size) resize(size)
42+
override final def sizeHint(size: Int): Unit = if (capacity < size) resize(size)
4643

4744
def clear(): Unit = size = 0
4845

@@ -491,17 +488,23 @@ object ArrayBuilder {
491488
protected def elems: Array[Unit] = throw new UnsupportedOperationException()
492489

493490
def addOne(elem: Unit): this.type = {
494-
size += 1
491+
val newSize = size + 1
492+
ensureSize(newSize)
493+
size = newSize
495494
this
496495
}
497496

498497
override def addAll(xs: IterableOnce[Unit]): this.type = {
499-
size += xs.iterator.size
498+
val newSize = size + xs.iterator.size
499+
ensureSize(newSize)
500+
size = newSize
500501
this
501502
}
502503

503504
override def addAll(xs: Array[_ <: Unit], offset: Int, length: Int): this.type = {
504-
size += length
505+
val newSize = size + length
506+
ensureSize(newSize)
507+
size = newSize
505508
this
506509
}
507510

@@ -517,7 +520,7 @@ object ArrayBuilder {
517520
case _ => false
518521
}
519522

520-
protected[this] def resize(size: Int): Unit = ()
523+
protected[this] def resize(size: Int): Unit = capacity = size
521524

522525
override def toString = "ArrayBuilder.ofUnit"
523526
}

library/src/scala/collection/mutable/PriorityQueue.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ sealed class PriorityQueue[A](implicit val ord: Ordering[A])
8989
def p_size0_=(s: Int) = size0 = s
9090
def p_array = array
9191
def p_ensureSize(n: Int) = super.ensureSize(n)
92-
def p_ensureAdditionalSize(n: Int) = super.ensureAdditionalSize(n)
92+
def p_ensureAdditionalSize(n: Int) = super.ensureSize(size0 + n)
9393
def p_swap(a: Int, b: Int): Unit = {
9494
val h = array(a)
9595
array(a) = array(b)

library/src/scala/runtime/PStatics.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ package scala.runtime
1515
// things that should be in `Statics`, but can't be yet for bincompat reasons
1616
// TODO 3.T: move to `Statics`
1717
private[scala] object PStatics {
18-
final val VM_MaxArraySize = 2147483645 // == `Int.MaxValue - 2`, hotspot limit
18+
// `Int.MaxValue - 8` traditional soft limit to maximize compatibility with diverse JVMs
19+
// See https://stackoverflow.com/a/8381338 for example
20+
final val VM_MaxArraySize = 2147483639
1921
}

0 commit comments

Comments
 (0)