Skip to content

Commit 1893874

Browse files
committed
Merge branch 'master' into 1.4.x
2 parents 89dd536 + d5a6300 commit 1893874

File tree

21 files changed

+448
-11
lines changed

21 files changed

+448
-11
lines changed

.scalafmt.conf

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
maxColumn = 120
2+
align = most
3+
continuationIndent.defnSite = 2
4+
assumeStandardLibraryStripMargin = true
5+
docstrings = ScalaDoc
6+
lineEndings = preserve
7+
includeCurlyBraceInSelectChains = false
8+
danglingParentheses = true
9+
10+
align.tokens.add = [
11+
{
12+
code = ":"
13+
}
14+
]
15+
16+
newlines.alwaysBeforeCurlyBraceLambdaParams
 = false
17+
18+
optIn.annotationNewlines = true
19+
20+
rewrite.rules = [SortImports, PreferCurlyFors, AvoidInfix]

rocket/src/main/scala/amba/apb/Node.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ case class BundleBridgeToAPBNode(masterParams: APBMasterPortParameters)(implicit
4242
dFn = { mp =>
4343
masterParams
4444
},
45-
uFn = { slaveParams => BundleBridgeParams(None) }
45+
uFn = { slaveParams => BundleBridgeNull() }
4646
)
4747

4848
object BundleBridgeToAPBNode {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// See LICENSE for license details.
2+
3+
package dsptools.dspmath
4+
5+
object ExtendedEuclid {
6+
/** Extended Euclidean Algorithm
7+
* ax + by = gcd(a, b)
8+
* Inputs: a, b
9+
* Outputs: gcd, x, y
10+
*/
11+
def egcd(a: Int, b: Int): (Int, Int, Int) = {
12+
if (a == 0) {
13+
(b, 0, 1)
14+
} else {
15+
val (gcd, y, x) = egcd(b % a, a)
16+
(gcd, x - (b / a) * y, y)
17+
}
18+
}
19+
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// See LICENSE for license details.
2+
3+
package dsptools.dspmath
4+
5+
import org.scalatest.{FlatSpec, Matchers}
6+
7+
case class RadPow(rad: Int, pow: Int) {
8+
/** `r ^ p` */
9+
def get: Int = BigInt(rad).pow(pow).toInt
10+
/** Factorize i.e. rad = 4, pow = 3 -> Seq(4, 4, 4) */
11+
def factorize: Seq[Int] = Seq.fill(pow)(rad)
12+
}
13+
14+
case class Factorization(supportedRadsUnsorted: Seq[Seq[Int]]) {
15+
/** Supported radices, MSD First */
16+
private val supportedRads = supportedRadsUnsorted.map(_.sorted.reverse)
17+
18+
/** Factor n into powers of supported radices and store RadPow i.e. r^p, separated by coprimes
19+
* i.e. if supportedRads = Seq(Seq(4, 2), Seq(3)),
20+
* output = Seq(Seq(RadPow(4, 5), RadPow(2, 1)), Seq(RadPow(3, 7)))
21+
* implies n = 4^5 * 2^1 * 3^7
22+
*/
23+
private def getRadPows(n: Int): Seq[Seq[RadPow]] = {
24+
// Test if n can be factored by each of the supported radices (mod = 0)
25+
// Count # of times it can be factored
26+
var unfactorized = n
27+
val radPows = for (primeGroup <- supportedRads) yield { for (rad <- primeGroup) yield {
28+
var (mod, pow) = (0, 0)
29+
while (mod == 0) {
30+
mod = unfactorized % rad
31+
if (mod == 0) {
32+
pow = pow + 1
33+
unfactorized = unfactorized / rad
34+
}
35+
}
36+
RadPow(rad, pow)
37+
}}
38+
// If n hasn't completely been factorized, then an unsupported radix is required
39+
require(unfactorized == 1, s"$n is invalid for supportedRads.")
40+
radPows
41+
}
42+
43+
/** Factor n into powers of supported radices (flattened)
44+
* i.e. if supportedRads = Seq(Seq(4, 2), Seq(3)),
45+
* output = Seq(5, 1, 7)
46+
* implies `n = 4^5 * 2^1 * 3^7`
47+
* If supportedRads contains more radices than the ones used, a power of 0 will be
48+
* associated with the unused radices.
49+
*/
50+
def getPowsFlat(n: Int): Seq[Int] = {
51+
getRadPows(n).flatMap(_.map(_.pow))
52+
}
53+
54+
/** Break n into coprimes i.e.
55+
* n = 4^5 * 2^1 * 3^7
56+
* would result in Seq(4^5 * 2^1, 3^7)
57+
* If supportedRads contains more coprime groups than the ones used, 1 will be
58+
* associated with the unused groups.
59+
*/
60+
def getCoprimes(n: Int): Seq[Int] = {
61+
getRadPows(n).map(_.map(_.get).product)
62+
}
63+
64+
/** Factorizes the coprime into digit radices (mixed radix)
65+
* i.e. n = 8 -> Seq(4, 2)
66+
* Note: there's no padding!
67+
*/
68+
def factorizeCoprime(n: Int): Seq[Int] = {
69+
// i.e. if supportedRads = Seq(Seq(4, 2), Seq(3)) and n = 8,
70+
// correspondingPrimeGroup = Seq(4, 2)
71+
val correspondingPrimeGroup = supportedRads.filter(n % _.min == 0)
72+
require(correspondingPrimeGroup.length == 1, "n (coprime) must not be divisible by other primes.")
73+
// Factorize coprime -- only correspondingPrimeGroup should actually add to factorization length
74+
getRadPows(n).flatten.flatMap(_.factorize)
75+
}
76+
77+
/** Gets associated base prime for n (assuming n isn't divisible by other primes)
78+
* WARNING: Assumes supportedRads contains the base prime!
79+
*/
80+
def getBasePrime(n: Int): Int = {
81+
val primeTemp = supportedRads.map(_.min).filter(n % _ == 0)
82+
require(primeTemp.length == 1, "n should only be divisible by 1 prime")
83+
primeTemp.head
84+
}
85+
86+
}
87+
88+
class FactorizationSpec extends FlatSpec with Matchers {
89+
90+
val testSupportedRads = Seq(Seq(4, 2), Seq(3), Seq(5), Seq(7))
91+
92+
behavior of "Factorization"
93+
it should "properly factorize" in {
94+
case class FactorizationTest(n: Int, pows: Seq[Int], coprimes: Seq[Int])
95+
val tests = Seq(
96+
FactorizationTest(
97+
n = (math.pow(4, 5) * math.pow(2, 1) * math.pow(3, 7)).toInt,
98+
pows = Seq(5, 1, 7),
99+
coprimes = Seq((math.pow(4, 5) * math.pow(2, 1)).toInt, math.pow(3, 7).toInt)
100+
),
101+
FactorizationTest(n = 15, pows = Seq(0, 0, 1, 1), coprimes = Seq(1, 3, 5))
102+
)
103+
104+
tests foreach { case FactorizationTest(n, pows, coprimes) =>
105+
val powsFill = Seq.fill(testSupportedRads.flatten.length - pows.length)(0)
106+
val coprimesFill = Seq.fill(testSupportedRads.length - coprimes.length)(1)
107+
require(
108+
Factorization(testSupportedRads).getPowsFlat(n) == pows ++ powsFill,
109+
"Should factorize to get the right powers -- includes padding."
110+
)
111+
require(
112+
Factorization(testSupportedRads).getCoprimes(n) == coprimes ++ coprimesFill,
113+
"Should factorize into the right coprimes -- includes padding."
114+
)
115+
}
116+
}
117+
118+
it should "properly factorize coprime" in {
119+
case class CoprimeFactorizationTest(n: Int, factorization: Seq[Int], basePrime: Int)
120+
val tests = Seq(
121+
CoprimeFactorizationTest(n = 8, factorization = Seq(4, 2), basePrime = 2),
122+
CoprimeFactorizationTest(n = 16, factorization = Seq(4, 4), basePrime = 2)
123+
)
124+
tests foreach { case CoprimeFactorizationTest(n, factorization, basePrime) =>
125+
require(
126+
Factorization(testSupportedRads).factorizeCoprime(n) == factorization,
127+
"Should factorize coprime properly."
128+
)
129+
require(
130+
Factorization(testSupportedRads).getBasePrime(n) == basePrime,
131+
"Should get the correct base prime."
132+
)
133+
}
134+
}
135+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// See LICENSE for license details.
2+
3+
package dsptools.intervals
4+
5+
import chisel3.internal.firrtl.IntervalRange
6+
import firrtl.ir.Closed
7+
8+
object IAUtility {
9+
10+
/** Expand range by n (twice of `halfn`). If n is negative, shrink. */
11+
def expandBy(range: IntervalRange, halfn: Double): IntervalRange = {
12+
val newMinT = getMin(range) - halfn
13+
val newMax = getMax(range) + halfn
14+
val newMin = if (newMinT > newMax) newMax else newMinT
15+
if (newMinT > newMax) println("Attempting to shrink range too much!")
16+
/*
17+
println(
18+
(if (halfn < 0) s"[shrink $halfn]: " else s"[expand $halfn]: ") +
19+
s"old min: ${getMin(range)} old max: ${getMax(range)}; " +
20+
s"new min: $newMin new max: $newMax"
21+
)
22+
*/
23+
IntervalRange(Closed(newMin), Closed(newMax), range.binaryPoint)
24+
}
25+
26+
/** Shift range to the right by n. If n is negative, shift left. */
27+
def shiftRightBy(range: IntervalRange, n: Double): IntervalRange = {
28+
val newMin = getMin(range) + n
29+
val newMax = getMax(range) + n
30+
IntervalRange(Closed(newMin), Closed(newMax), range.binaryPoint)
31+
}
32+
33+
/** Check if range contains negative numbers. */
34+
def containsNegative(range: IntervalRange): Boolean = getMin(range) < 0
35+
36+
/** Get the # of bits required to represent the integer portion of the
37+
* rounded bounds (including sign bit if necessary).
38+
*/
39+
def getIntWidth(range: IntervalRange):Int = {
40+
require(range.binaryPoint.get == 0, "getIntWidth only works for bp = 0")
41+
val min = range.getLowestPossibleValue.get.toBigInt()
42+
val max = range.getHighestPossibleValue.get.toBigInt()
43+
44+
val minWidth = if (min < 0) min.bitLength + 1 else min.bitLength
45+
val maxWidth = if (max < 0) max.bitLength + 1 else max.bitLength
46+
math.max(minWidth, maxWidth)
47+
}
48+
49+
/** Gets min */
50+
def getMin(range: IntervalRange): BigDecimal = range.getLowestPossibleValue.get
51+
/** Gets max */
52+
def getMax(range: IntervalRange): BigDecimal = range.getPossibleValues.max
53+
/** Gets width of range */
54+
def getRange(range: IntervalRange): BigDecimal = getMax(range) - getMin(range)
55+
56+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// See LICENSE for license details.
2+
3+
package dsptools.misc
4+
5+
object BitWidth {
6+
/**
7+
* Utility function that computes bits required for a number
8+
*
9+
* @param n number of interest
10+
* @return
11+
*/
12+
def computeBits(n: BigInt): Int = {
13+
n.bitLength + (if(n < 0) 1 else 0)
14+
}
15+
16+
/**
17+
* return the smallest number of bits required to hold the given number in
18+
* an SInt
19+
* Note: positive numbers will get one minimum width one higher than would be
20+
* required for a UInt
21+
*
22+
* @param num number to find width for
23+
* @return minimum required bits for an SInt
24+
*/
25+
def requiredBitsForSInt(num: BigInt): Int = {
26+
if(num == BigInt(0) || num == -BigInt(1)) {
27+
1
28+
}
29+
else {
30+
if (num < 0) {
31+
computeBits(num)
32+
}
33+
else {
34+
computeBits(num) + 1
35+
}
36+
}
37+
}
38+
39+
def requiredBitsForSInt(low: BigInt, high: BigInt): Int = {
40+
requiredBitsForSInt(low).max(requiredBitsForSInt(high))
41+
}
42+
43+
/**
44+
* return the smallest number of bits required to hold the given number in
45+
* an UInt
46+
* Note: positive numbers will get one minimum width one higher than would be
47+
* required for a UInt
48+
*
49+
* @param num number to find width for
50+
* @return minimum required bits for an SInt
51+
*/
52+
def requiredBitsForUInt(num: BigInt): Int = {
53+
if(num == BigInt(0)) {
54+
1
55+
}
56+
else {
57+
computeBits(num)
58+
}
59+
}
60+
}

src/main/scala/dsptools/numbers/algebra_types/helpers/Sign.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ object Sign {
4848
case object Negative extends Sign(Some(false), Some(true))
4949

5050
def apply(zero: Bool, neg: Bool): Sign = {
51-
val zeroLit = zero.litArg.map{_.num != BigInt(0)}
52-
val negLit = neg.litArg.map{_.num != BigInt(0)}
51+
val zeroLit = zero.litOption.map{_ != BigInt(0)}
52+
val negLit = neg.litOption.map{_ != BigInt(0)}
5353
val isLit = zeroLit.isDefined && negLit.isDefined
5454
val wireWrapIfNotLit: Sign => Sign = s => if (isLit) { s } else Wire(s)
5555
val bundle = wireWrapIfNotLit(

src/main/scala/dsptools/numbers/binary_types/BinaryRepresentation.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// See LICENSE for license details.
2+
13
package dsptools.numbers
24

35
import chisel3.{Data, UInt, Bool}
@@ -20,5 +22,8 @@ trait BinaryRepresentation[A <: Data] extends Any {
2022
def mul2(a: A, n: Int): A = shl(a, n)
2123
// Trim to n fractional bits (with DspContext) -- doesn't affect DspReal
2224
def trimBinary(a: A, n: Int): A = trimBinary(a, Some(n))
23-
def trimBinary(a: A, n: Option[Int]): A
25+
def trimBinary(a: A, n: Option[Int]): A
26+
27+
// Clip A to B (range)
28+
def clip(a: A, b: A): A
2429
}

src/main/scala/dsptools/numbers/chisel_concrete/DspComplex.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package dsptools.numbers
55
import chisel3._
66
import chisel3.experimental.{FixedPoint, Interval}
77
import dsptools.DspException
8-
import implicits._
98
import breeze.math.Complex
109

1110
object DspComplex {

src/main/scala/dsptools/numbers/chisel_types/DspComplexTypeClass.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,15 @@ class DspComplexEq[T <: Data:Eq] extends Eq[DspComplex[T]] with hasContext {
9696
}
9797
}
9898

99-
class DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] extends
99+
class DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] extends
100100
BinaryRepresentation[DspComplex[T]] with hasContext {
101101
override def shl(a: DspComplex[T], n: Int): DspComplex[T] = throw DspException("Can't shl on complex")
102102
override def shl(a: DspComplex[T], n: UInt): DspComplex[T] = throw DspException("Can't shl on complex")
103103
override def shr(a: DspComplex[T], n: Int): DspComplex[T] = throw DspException("Can't shr on complex")
104104
override def shr(a: DspComplex[T], n: UInt): DspComplex[T] = throw DspException("Can't shr on complex")
105105
override def div2(a: DspComplex[T], n: Int): DspComplex[T] = DspComplex.wire(a.real.div2(n), a.imag.div2(n))
106106
override def mul2(a: DspComplex[T], n: Int): DspComplex[T] = DspComplex.wire(a.real.mul2(n), a.imag.mul2(n))
107+
def clip(a: DspComplex[T], b: DspComplex[T]): DspComplex[T] = throw DspException("Can't clip on complex")
107108
def signBit(a: DspComplex[T]): Bool = throw DspException("Can't get sign bit on complex")
108109
def trimBinary(a: DspComplex[T], n: Option[Int]): DspComplex[T] =
109110
DspComplex.wire(BinaryRepresentation[T].trimBinary(a.real, n), BinaryRepresentation[T].trimBinary(a.imag, n))

0 commit comments

Comments
 (0)