Skip to content

Commit f9c6848

Browse files
committed
Push BigDecimal exclusion into Numeric
1 parent 4eac431 commit f9c6848

File tree

2 files changed

+22
-27
lines changed

2 files changed

+22
-27
lines changed

library/src/scala/collection/IterableOnce.scala

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -936,11 +936,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A] =>
936936

937937
/** Sums the elements of this collection.
938938
*
939-
* The default implementation uses `reduce` for a known non-empty collection,
940-
* `foldLeft` otherwise.
941-
*
942-
* If `foldLeft` is used, this implementation works around pollution of the math context
943-
* by ignoring the identity element.
939+
* The default implementation uses `reduce` for a known non-empty collection, `foldLeft` otherwise.
944940
*
945941
* $willNotTerminateInf
946942
*
@@ -951,24 +947,14 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A] =>
951947
*/
952948
def sum[B >: A](implicit num: Numeric[B]): B =
953949
knownSize match {
954-
case -1 =>
955-
val z = num.zero
956-
if ((num eq Numeric.BigDecimalIsFractional) || (num eq Numeric.BigDecimalAsIfIntegral)) {
957-
def nonPollutingPlus(x: B, y: B): B = if (x.asInstanceOf[AnyRef] eq z.asInstanceOf[AnyRef]) y else num.plus(x, y)
958-
foldLeft(z)(nonPollutingPlus)
959-
}
960-
else foldLeft(z)(num.plus)
950+
case -1 => foldLeft(num.zero)(num.plus)
961951
case 0 => num.zero
962952
case _ => reduce(num.plus)
963953
}
964954

965955
/** Multiplies together the elements of this collection.
966956
*
967-
* The default implementation uses `reduce` for a known non-empty collection,
968-
* `foldLeft` otherwise.
969-
*
970-
* If `foldLeft` is used, this implementation works around pollution of the math context
971-
* by ignoring the identity element.
957+
* The default implementation uses `reduce` for a known non-empty collection, `foldLeft` otherwise.
972958
*
973959
* $willNotTerminateInf
974960
*
@@ -979,13 +965,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A] =>
979965
*/
980966
def product[B >: A](implicit num: Numeric[B]): B =
981967
knownSize match {
982-
case -1 =>
983-
val u = num.one
984-
if ((num eq Numeric.BigDecimalIsFractional) || (num eq Numeric.BigDecimalAsIfIntegral)) {
985-
def nonPollutingProd(x: B, y: B): B = if (x.asInstanceOf[AnyRef] eq u.asInstanceOf[AnyRef]) y else num.times(x, y)
986-
foldLeft(u)(nonPollutingProd)
987-
}
988-
else foldLeft(u)(num.times)
968+
case -1 => foldLeft(num.one)(num.times)
989969
case 0 => num.one
990970
case _ => reduce(num.times)
991971
}

library/src/scala/math/Numeric.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,20 @@ object Numeric {
177177
implicit object DoubleIsFractional extends DoubleIsFractional with Ordering.Double.IeeeOrdering
178178

179179
trait BigDecimalIsConflicted extends Numeric[BigDecimal] {
180-
def plus(x: BigDecimal, y: BigDecimal): BigDecimal = x + y
181-
def minus(x: BigDecimal, y: BigDecimal): BigDecimal = x - y
182-
def times(x: BigDecimal, y: BigDecimal): BigDecimal = x * y
180+
// works around pollution of math context by ignoring identity element
181+
def plus(x: BigDecimal, y: BigDecimal): BigDecimal = {
182+
import BigDecimalIsConflicted._0
183+
if (x eq _0) y else x + y
184+
}
185+
def minus(x: BigDecimal, y: BigDecimal): BigDecimal = {
186+
import BigDecimalIsConflicted._0
187+
if (x eq _0) -y else x - y
188+
}
189+
// works around pollution of math context by ignoring identity element
190+
def times(x: BigDecimal, y: BigDecimal): BigDecimal = {
191+
import BigDecimalIsConflicted._1
192+
if (x eq _1) y else x * y
193+
}
183194
def negate(x: BigDecimal): BigDecimal = -x
184195
def fromInt(x: Int): BigDecimal = BigDecimal(x)
185196
def parseString(str: String): Option[BigDecimal] = Try(BigDecimal(str)).toOption
@@ -188,6 +199,10 @@ object Numeric {
188199
def toFloat(x: BigDecimal): Float = x.floatValue
189200
def toDouble(x: BigDecimal): Double = x.doubleValue
190201
}
202+
private object BigDecimalIsConflicted {
203+
private val _0 = BigDecimal(0) // cached zero is ordinarily cached for default math context
204+
private val _1 = BigDecimal(1) // cached one is ordinarily cached for default math context
205+
}
191206

192207
trait BigDecimalIsFractional extends BigDecimalIsConflicted with Fractional[BigDecimal] {
193208
def div(x: BigDecimal, y: BigDecimal): BigDecimal = x / y

0 commit comments

Comments
 (0)