Skip to content

Commit ad47749

Browse files
committed
Merge branch '1.4.x' into 1.4-release
2 parents 0b44c51 + 89dd536 commit ad47749

File tree

5 files changed

+543
-6
lines changed

5 files changed

+543
-6
lines changed

src/main/scala/dsptools/numbers/chisel_types/DspComplexTypeClass.scala

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,30 @@
33
package dsptools.numbers
44

55
import chisel3._
6+
import chisel3.experimental.FixedPoint
67
import dsptools.hasContext
78
import implicits._
89
import chisel3.util.ShiftRegister
910
import dsptools.DspException
1011

11-
class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with hasContext {
12+
abstract class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with hasContext {
1213
def plus(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = {
1314
DspComplex.wire(f.real + g.real, f.imag + g.imag)
1415
}
1516
def plusContext(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = {
1617
DspComplex.wire(f.real context_+ g.real, f.imag context_+ g.imag)
1718
}
19+
20+
/**
21+
* The builtin times calls +. Ideally we'd like to use growing addition, but we're relying on typeclasses and the
22+
* default + for UInt, SInt, etc. is wrapping. Thus, we're making an escape hatch just for the default (non-context)
23+
* complex multiply.
24+
* @param l
25+
* @param r
26+
* @return the sum of l and r, preferrably growing
27+
*/
28+
protected def plusForTimes(l: T, r: T): T
29+
1830
def times(f: DspComplex[T], g: DspComplex[T]): DspComplex[T] = {
1931
val c_p_d = g.real + g.imag
2032
val a_p_b = f.real + f.imag
@@ -59,6 +71,22 @@ class DspComplexRing[T <: Data:Ring] extends Ring[DspComplex[T]] with hasContext
5971
}
6072
}
6173

74+
class DspComplexRingUInt extends DspComplexRing[UInt] {
75+
override def plusForTimes(l: UInt, r: UInt): UInt = l +& r
76+
}
77+
78+
class DspComplexRingSInt extends DspComplexRing[SInt] {
79+
override def plusForTimes(l: SInt, r: SInt): SInt = l +& r
80+
}
81+
82+
class DspComplexRingFixed extends DspComplexRing[FixedPoint] {
83+
override def plusForTimes(l: FixedPoint, r: FixedPoint): FixedPoint = l +& r
84+
}
85+
86+
class DspComplexRingData[T <: Data : Ring] extends DspComplexRing[T] {
87+
override protected def plusForTimes(l: T, r: T): T = l + r
88+
}
89+
6290
class DspComplexEq[T <: Data:Eq] extends Eq[DspComplex[T]] with hasContext {
6391
override def eqv(x: DspComplex[T], y: DspComplex[T]): Bool = {
6492
Eq[T].eqv(x.real, y.real) && Eq[T].eqv(x.imag, y.imag)
@@ -81,9 +109,15 @@ class DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] extend
81109
DspComplex.wire(BinaryRepresentation[T].trimBinary(a.real, n), BinaryRepresentation[T].trimBinary(a.imag, n))
82110
}
83111

84-
trait DspComplexImpl {
85-
implicit def DspComplexRingImpl[T<: Data:Ring] = new DspComplexRing[T]()
112+
trait GenericDspComplexImpl {
113+
implicit def DspComplexRingDataImpl[T<: Data:Ring] = new DspComplexRingData[T]()
86114
implicit def DspComplexEq[T <: Data:Eq] = new DspComplexEq[T]()
87115
implicit def DspComplexBinaryRepresentation[T <: Data:Ring:BinaryRepresentation] =
88116
new DspComplexBinaryRepresentation[T]()
89117
}
118+
119+
trait DspComplexImpl extends GenericDspComplexImpl {
120+
implicit def DspComplexRingUIntImpl = new DspComplexRingUInt
121+
implicit def DspComplexRingSIntImpl = new DspComplexRingSInt
122+
implicit def DspComplexRingFixedImpl = new DspComplexRingFixed
123+
}
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
package dsptools.numbers.rounding
2+
3+
import chisel3._
4+
import chisel3.experimental.{ChiselAnnotation, FixedPoint, RunFirrtlTransform, annotate, requireIsHardware}
5+
import firrtl.{CircuitForm, CircuitState, HighForm, MidForm, Transform}
6+
import firrtl.annotations.{
7+
SingleTargetAnnotation,
8+
ModuleName,
9+
Target
10+
}
11+
import firrtl.ir.{
12+
Block,
13+
DefModule,
14+
FixedType,
15+
IntWidth,
16+
Module => FModule,
17+
UIntType,
18+
SIntType
19+
}
20+
21+
import scala.collection.immutable.HashMap
22+
import scala.language.existentials
23+
24+
sealed trait SaturatingOp
25+
case object SaturatingAdd extends SaturatingOp
26+
case object SaturatingSub extends SaturatingOp
27+
28+
case class SaturateAnnotation(target: ModuleName, op: SaturatingOp, pipe: Int = 0) extends SingleTargetAnnotation[ModuleName] {
29+
def duplicate(t: ModuleName): SaturateAnnotation = this.copy(target = t)
30+
}
31+
32+
case class SaturateChiselAnnotation(target: SaturateDummyModule[_ <: Data], op: SaturatingOp, pipe: Int = 0) extends ChiselAnnotation with RunFirrtlTransform {
33+
def toFirrtl: SaturateAnnotation = SaturateAnnotation(target.toTarget, op = op, pipe = pipe)
34+
def transformClass: Class[SaturateTransform] = classOf[SaturateTransform]
35+
}
36+
37+
trait SaturateModule[T <: Data] extends MultiIOModule {
38+
val a: T
39+
val b: T
40+
val c: T
41+
}
42+
43+
class SaturateUIntAddModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[UInt] {
44+
require(pipe == 0, "pipe not implemented yet")
45+
46+
val a = IO(Input(UInt(aWidth.W)))
47+
val b = IO(Input(UInt(bWidth.W)))
48+
val c = IO(Output(UInt(cWidth.W)))
49+
50+
val max = ((1 << cWidth) - 1).U
51+
val sumWithGrow = a +& b
52+
val tooBig = sumWithGrow(cWidth)
53+
val sum = sumWithGrow(cWidth - 1, 0)
54+
55+
c := Mux(tooBig, max, sum)
56+
}
57+
58+
class SaturateUIntSubModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[UInt] {
59+
require(pipe == 0, "pipe not implemented yet")
60+
val a = IO(Input(UInt(aWidth.W)))
61+
val b = IO(Input(UInt(bWidth.W)))
62+
val c = IO(Output(UInt(cWidth.W)))
63+
64+
val tooSmall = a < b
65+
val diff = a -% b
66+
67+
c := Mux(tooSmall, 0.U, diff)
68+
}
69+
70+
class SaturateSIntAddModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[SInt] {
71+
require(pipe == 0, "pipe not implemented yet")
72+
val a = IO(Input(SInt(aWidth.W)))
73+
val b = IO(Input(SInt(bWidth.W)))
74+
val c = IO(Output(SInt(cWidth.W)))
75+
76+
val abWidth = aWidth max bWidth
77+
val max = ((1 << (cWidth - 1)) - 1).S
78+
val min = (-(1 << (cWidth - 1))).S
79+
val sumWithGrow = a +& b
80+
81+
val tooBig = !sumWithGrow(abWidth) && sumWithGrow(abWidth - 1)
82+
val tooSmall = sumWithGrow(abWidth) && !sumWithGrow(abWidth - 1)
83+
84+
val sum = sumWithGrow(abWidth - 1, 0).asSInt
85+
val fixTop = Mux(tooBig, max, sum)
86+
val fixTopAndBottom = Mux(tooSmall, min, fixTop)
87+
88+
c := fixTopAndBottom
89+
}
90+
91+
class SaturateSIntSubModule(aWidth: Int, bWidth: Int, cWidth: Int, pipe: Int) extends SaturateModule[SInt] {
92+
require(pipe == 0, "pipe not implemented yet")
93+
val a = IO(Input(SInt(aWidth.W)))
94+
val b = IO(Input(SInt(bWidth.W)))
95+
val c = IO(Output(SInt(cWidth.W)))
96+
97+
val abWidth = aWidth max bWidth
98+
val max = ((1 << (cWidth - 1)) - 1).S
99+
val min = (-(1 << (cWidth - 1))).S
100+
val sumWithGrow = a -& b
101+
102+
val tooBig = !sumWithGrow(abWidth) && sumWithGrow(abWidth - 1)
103+
val tooSmall = sumWithGrow(abWidth) && !sumWithGrow(abWidth - 1)
104+
105+
val sum = sumWithGrow(cWidth - 1, 0).asSInt
106+
val fixTop = Mux(tooBig, max, sum)
107+
val fixTopAndBottom = Mux(tooSmall, min, fixTop)
108+
109+
c := fixTopAndBottom
110+
}
111+
112+
class SaturateFixedPointAddModule(
113+
aWidth: Int, aBP: Int,
114+
bWidth: Int, bBP: Int,
115+
cWidth: Int, cBP: Int,
116+
pipe: Int) extends SaturateModule[FixedPoint] {
117+
require(pipe == 0, "pipe not implemented yet")
118+
119+
val a = IO(Input(FixedPoint(aWidth.W, aBP.BP)))
120+
val b = IO(Input(FixedPoint(bWidth.W, bBP.BP)))
121+
val c = IO(Output(FixedPoint(cWidth.W, cBP.BP)))
122+
123+
124+
val max = (math.pow(2, (cWidth - cBP - 1)) - math.pow(2, -cBP)).F(cWidth.W, cBP.BP)
125+
val min = (-math.pow(2, (cWidth - cBP - 1))).F(cWidth.W, cBP.BP)
126+
val sumWithGrow = a +& b
127+
128+
val tooBig = !sumWithGrow(cWidth) && sumWithGrow(cWidth - 1)
129+
val tooSmall = sumWithGrow(cWidth) && !sumWithGrow(cWidth - 1)
130+
131+
val sum = sumWithGrow(cWidth - 1, 0).asFixedPoint(cBP.BP)
132+
val fixTop = Mux(tooBig, max, sum)
133+
val fixTopAndBottom = Mux(tooSmall, min, fixTop)
134+
135+
c := fixTopAndBottom
136+
}
137+
138+
class SaturateFixedPointSubModule(
139+
aWidth: Int, aBP: Int,
140+
bWidth: Int, bBP: Int,
141+
cWidth: Int, cBP: Int,
142+
pipe: Int) extends SaturateModule[FixedPoint] {
143+
require(pipe == 0, "pipe not implemented yet")
144+
145+
val a = IO(Input(FixedPoint(aWidth.W, aBP.BP)))
146+
val b = IO(Input(FixedPoint(bWidth.W, bBP.BP)))
147+
val c = IO(Output(FixedPoint(cWidth.W, cBP.BP)))
148+
149+
val max = (math.pow(2, (cWidth - cBP - 1)) - math.pow(2, -cBP)).F(cWidth.W, cBP.BP)
150+
val min = (-math.pow(2, (cWidth - cBP - 1))).F(cWidth.W, cBP.BP)
151+
val diffWithGrow = a -& b
152+
153+
val tooBig = !diffWithGrow(cWidth) && diffWithGrow(cWidth - 1)
154+
val tooSmall = diffWithGrow(cWidth) && !diffWithGrow(cWidth - 1)
155+
156+
val diff = diffWithGrow(cWidth - 1, 0).asFixedPoint(cBP.BP)
157+
val fixTop = Mux(tooBig, max, diff)
158+
val fixTopAndBottom = Mux(tooSmall, min, fixTop)
159+
160+
c := fixTopAndBottom
161+
}
162+
163+
/**
164+
* A module that serves as a placeholder for a saturating op.
165+
* The frontend can't implement saturation easily when widths are unknown. This
166+
* module inserts a dummy op that has the desired behavior in FIRRTL's width
167+
* inference process. After width inference, this module will be replaced by an
168+
* implementation of a saturating op.
169+
*/
170+
class SaturateDummyModule[T <: Data](aOutside: T, bOutside: T, op: (T, T) => T) extends SaturateModule[T] {
171+
// this module should always be replaced in a transform
172+
// throw in this assertion in case it isn't
173+
assert(false.B)
174+
val a = IO(Input(chiselTypeOf(aOutside)))
175+
val b = IO(Input(chiselTypeOf(bOutside)))
176+
val res = op(a, b)
177+
val c = IO(Output(chiselTypeOf(res)))
178+
c := res
179+
}
180+
181+
object Saturate {
182+
private def op[T <: Data](a: T, b: T, widthOp: (T, T) => T, realOp: SaturatingOp, pipe: Int = 0): T = {
183+
requireIsHardware(a)
184+
requireIsHardware(b)
185+
val saturate = Module(new SaturateDummyModule(a, b, widthOp))
186+
val anno = SaturateChiselAnnotation(saturate, realOp, pipe)
187+
annotate(anno)
188+
saturate.a := a
189+
saturate.b := b
190+
saturate.c
191+
}
192+
def addUInt(a: UInt, b: UInt, pipe: Int = 0): UInt = {
193+
op(a, b, { (l: UInt, r: UInt) => l +% r }, SaturatingAdd, pipe)
194+
}
195+
def addSInt(a: SInt, b: SInt, pipe: Int = 0): SInt = {
196+
op(a, b, { (l: SInt, r: SInt) => l +% r }, SaturatingAdd, pipe)
197+
}
198+
def addFixedPoint(a: FixedPoint, b: FixedPoint, pipe: Int = 0): FixedPoint = {
199+
op(a, b, { (l: FixedPoint, r: FixedPoint) => (l +& r) >> 1 }, SaturatingAdd, pipe)
200+
}
201+
def subUInt(a: UInt, b: UInt, pipe: Int = 0): UInt = {
202+
op(a, b, { (l: UInt, r: UInt) => l -% r }, SaturatingSub, pipe)
203+
}
204+
def subSInt(a: SInt, b: SInt, pipe: Int = 0): SInt = {
205+
op(a, b, { (l: SInt, r: SInt) => l -% r }, SaturatingSub, pipe)
206+
}
207+
def subFixedPoint(a: FixedPoint, b: FixedPoint, pipe: Int = 0): FixedPoint = {
208+
op(a, b, { (l: FixedPoint, r: FixedPoint) => (l -& r) >> 1 }, SaturatingSub, pipe)
209+
}
210+
}
211+
212+
class SaturateTransform extends Transform {
213+
def inputForm: CircuitForm = MidForm
214+
def outputForm: CircuitForm = HighForm
215+
216+
private def replaceMod(m: FModule, anno: SaturateAnnotation): FModule = {
217+
val aTpe = m.ports.find(_.name == "a").map(_.tpe).getOrElse(throw new Exception("a not found"))
218+
val bTpe = m.ports.find(_.name == "b").map(_.tpe).getOrElse(throw new Exception("b not found"))
219+
val cTpe = m.ports.find(_.name == "c").map(_.tpe).getOrElse(throw new Exception("c not found"))
220+
221+
val newMod = (aTpe, bTpe, cTpe, anno) match {
222+
case (
223+
UIntType(IntWidth(aWidth)),
224+
UIntType(IntWidth(bWidth)),
225+
UIntType(IntWidth(cWidth)),
226+
SaturateAnnotation(_, SaturatingAdd, pipe)) =>
227+
() => new SaturateUIntAddModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe)
228+
case (
229+
UIntType(IntWidth(aWidth)),
230+
UIntType(IntWidth(bWidth)),
231+
UIntType(IntWidth(cWidth)),
232+
SaturateAnnotation(_, SaturatingSub, pipe)) =>
233+
() => new SaturateUIntSubModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe)
234+
case (
235+
SIntType(IntWidth(aWidth)),
236+
SIntType(IntWidth(bWidth)),
237+
SIntType(IntWidth(cWidth)),
238+
SaturateAnnotation(_, SaturatingAdd, pipe)) =>
239+
() => new SaturateSIntAddModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe)
240+
case (
241+
SIntType(IntWidth(aWidth)),
242+
SIntType(IntWidth(bWidth)),
243+
SIntType(IntWidth(cWidth)),
244+
SaturateAnnotation(_, SaturatingSub, pipe)) =>
245+
() => new SaturateSIntSubModule(aWidth.toInt, bWidth.toInt, cWidth.toInt, pipe = pipe)
246+
case (
247+
FixedType(IntWidth(aWidth), IntWidth(aBP)),
248+
FixedType(IntWidth(bWidth), IntWidth(bBP)),
249+
FixedType(IntWidth(cWidth), IntWidth(cBP)),
250+
SaturateAnnotation(_, SaturatingAdd, pipe)) =>
251+
() => new SaturateFixedPointAddModule(aWidth.toInt, aBP.toInt, bWidth.toInt, bBP.toInt, (cWidth - 1).toInt, cBP.toInt, pipe = pipe)
252+
case (
253+
FixedType(IntWidth(aWidth), IntWidth(aBP)),
254+
FixedType(IntWidth(bWidth), IntWidth(bBP)),
255+
FixedType(IntWidth(cWidth), IntWidth(cBP)),
256+
SaturateAnnotation(_, SaturatingSub, pipe)) =>
257+
() => new SaturateFixedPointSubModule(aWidth.toInt, aBP.toInt, bWidth.toInt, bBP.toInt, (cWidth - 1).toInt, cBP.toInt, pipe = pipe)
258+
}
259+
// get new body from newMod (must be single module!)
260+
val newBody = Driver.toFirrtl(Driver.elaborate(newMod)).modules.head match {
261+
case FModule(_, _, _, body) => body
262+
case _ => throw new Exception("Saw blackbox for some reason")
263+
}
264+
m.copy(body = newBody)
265+
}
266+
267+
private def onModule(annos: Seq[SaturateAnnotation]) = {
268+
val annoByName: HashMap[String, SaturateAnnotation] = HashMap(annos.map({ a => a.target.name -> a }): _*)
269+
object SaturateAnnotation {
270+
def unapply(name: String): Option[SaturateAnnotation] = {
271+
annoByName.get(name)
272+
}
273+
}
274+
def onModuleInner(m: DefModule): DefModule = m match {
275+
case m@FModule(_, SaturateAnnotation(a), _, _) =>
276+
replaceMod(m, a)
277+
case m => m
278+
}
279+
onModuleInner(_)
280+
}
281+
282+
def execute(state: CircuitState): CircuitState = {
283+
val annos = state.annotations.collect {
284+
case a: SaturateAnnotation => a
285+
}
286+
state.copy(circuit = state.circuit.copy(modules =
287+
state.circuit.modules.map(onModule(annos))))
288+
}
289+
}

src/test/scala/dsptools/numbers/DspComplexSpec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
// See LICENSE for license details.
22

3-
package dsptools.numbers
3+
package testing.dsptools.numbers
44

55
import chisel3._
66
import chisel3.iotesters.ChiselPropSpec
77
import chisel3.testers.BasicTester
8-
import dsptools.numbers.implicits._
8+
import dsptools.numbers._
99

1010
//scalastyle:off magic.number
1111
class DspComplexExamples extends Module {

0 commit comments

Comments
 (0)