Skip to content

Commit cb42069

Browse files
authored
Refactor code and add fixes for binary point edge cases (ucb-bar#6)
* Refactor code * Replace pattern matching with function calls where sensible * Remove superfluous use of asSInt * Minimize the use of _inferredBinaryPoint * Add source info to asFixedPoint * Handle BinaryPoint edge cases * Disallow negative binary point * Ensure that floor, ceil, and round don't change width or binary point * Don't force width in setBinaryPoint since binary point difference can be greater than data width * Add test cases for when binary point is greater than or equal to width
1 parent 35dda16 commit cb42069

File tree

4 files changed

+151
-104
lines changed

4 files changed

+151
-104
lines changed

src/main/scala/fixedpoint/FixedPoint.scala

Lines changed: 72 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -88,28 +88,24 @@ object FixedPoint extends NumObject {
8888
*/
8989
private[fixedpoint] def fromData(
9090
binaryPoint: BinaryPoint,
91-
data: SInt,
91+
data: Data,
9292
widthOption: Option[Width] = None
9393
)(
9494
implicit sourceInfo: SourceInfo,
9595
compileOptions: CompileOptions
9696
): FixedPoint = {
9797
val _new = Wire(
9898
FixedPoint(
99-
widthOption match {
100-
case Some(width) => width
101-
case None => recreateWidth(data)
102-
},
99+
widthOption.getOrElse(recreateWidth(data)),
103100
binaryPoint
104101
)
105102
)
106-
_new.data := data
103+
_new.data := data.asTypeOf(_new.data)
107104
_new
108105
}
109106

110-
private[fixedpoint] def recreateWidth[T <: Data](d: T): Width = d.widthOption match {
111-
case Some(w) => w.W
112-
case None => UnknownWidth()
107+
private[fixedpoint] def recreateWidth[T <: Data](d: T): Width = {
108+
d.widthOption.fold[Width](UnknownWidth())(_.W)
113109
}
114110

115111
/** Align all FixedPoints in a (possibly heterogeneous) sequence by width and binary point
@@ -120,34 +116,32 @@ object FixedPoint extends NumObject {
120116
implicit sourceInfo: SourceInfo,
121117
compileOptions: CompileOptions
122118
): Seq[T] = {
123-
124119
val bps = in.collect {
125120
case el: FixedPoint =>
126121
el.requireKnownBP()
127122
el.binaryPoint
128123
}
129124

130-
val out: Iterable[T] = if (bps.nonEmpty) {
131-
val maxBP = bps.fold(0.BP)(_.max(_))
132-
val maxWidth = in.map { el =>
133-
val width = recreateWidth(el)
134-
val extra = el match {
135-
case el: FixedPoint => maxBP.get - el.binaryPoint.get
136-
case _ => 0
125+
val out =
126+
if (bps.isEmpty) in
127+
else {
128+
val maxBP = bps.fold(0.BP)(_.max(_))
129+
val maxWidth = in.map {
130+
case el: FixedPoint => recreateWidth(el) + (maxBP.get - el.binaryPoint.get)
131+
case nonFp => recreateWidth(nonFp)
132+
}.fold(0.W)(_.max(_))
133+
134+
in.map {
135+
case el: FixedPoint =>
136+
val shift = maxBP.get - el.binaryPoint.get
137+
fromData(
138+
maxBP,
139+
if (shift > 0) el.data << shift else el.data,
140+
Some(maxWidth)
141+
).asInstanceOf[T]
142+
case nonFp => nonFp
137143
}
138-
width + extra.W
139-
}.fold(0.W)(_.max(_))
140-
in.map {
141-
case el: FixedPoint =>
142-
val shift = maxBP.get - el.binaryPoint.get
143-
fromData(
144-
maxBP,
145-
(if (shift > 0) el.data << shift else el.data).asSInt,
146-
Some(maxWidth)
147-
).asInstanceOf[T]
148-
case nonFp => nonFp
149144
}
150-
} else in
151145
out.toSeq
152146
}
153147

@@ -190,14 +184,14 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
190184
with OpaqueType
191185
with Num[FixedPoint]
192186
with HasBinaryPoint {
187+
if (binaryPoint.known) require(binaryPoint.get >= 0, "Negative binary point is not supported")
193188
private val data: SInt = SInt(width)
194189
val elements: SeqMap[String, SInt] = SeqMap("" -> data)
195190

196191
def binaryPoint: BinaryPoint = _inferredBinaryPoint
197192

198-
private def requireKnownBP(message: => Any = "Unknown binary point is not supported in this operation"): Unit = {
199-
require(_inferredBinaryPoint.isInstanceOf[KnownBinaryPoint], message)
200-
}
193+
private def requireKnownBP(message: String = "Unknown binary point is not supported in this operation"): Unit =
194+
if (!binaryPoint.known) throw new ChiselException(message)
201195

202196
private def additiveOp(
203197
that: FixedPoint,
@@ -207,7 +201,7 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
207201
compileOptions: CompileOptions
208202
): FixedPoint = {
209203
val Seq(_this, _that) = FixedPoint.dataAligned(this, that).map(WireDefault(_))
210-
FixedPoint.fromData(_inferredBinaryPoint.max(that._inferredBinaryPoint), f(_this.data, _that.data))
204+
FixedPoint.fromData(binaryPoint.max(that.binaryPoint), f(_this.data, _that.data))
211205
}
212206

213207
private def comparativeOp(that: FixedPoint, f: (SInt, SInt) => Bool): Bool = {
@@ -221,21 +215,22 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
221215
)(
222216
implicit sourceInfo: SourceInfo,
223217
connectCompileOptions: CompileOptions
224-
): Unit =
218+
): Unit = {
225219
that match {
226220
case that: FixedPoint =>
227-
if (_inferredBinaryPoint.isInstanceOf[KnownBinaryPoint]) {
228-
c(data, that.setBinaryPoint(_inferredBinaryPoint.get).data)
221+
if (binaryPoint.known) {
222+
c(data, that.setBinaryPoint(binaryPoint.get).data)
229223
} else {
230-
if (that._inferredBinaryPoint.isInstanceOf[KnownBinaryPoint]) {
231-
this._inferredBinaryPoint = BinaryPoint(that._inferredBinaryPoint.get)
224+
if (that.binaryPoint.known) {
225+
this._inferredBinaryPoint = BinaryPoint(that.binaryPoint.get)
232226
}
233227
c(data, that.data)
234228
}
235229
case that @ DontCare =>
236230
c(data, that)
237231
case _ => throw new ChiselException(s"Cannot connect ${this} and ${that}")
238232
}
233+
}
239234

240235
override def do_+(that: FixedPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
241236
additiveOp(that, _ + _)
@@ -256,13 +251,13 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
256251
additiveOp(that, _ -& _)
257252

258253
def do_unary_-(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
259-
FixedPoint.fromData(_inferredBinaryPoint, -data)
254+
FixedPoint.fromData(binaryPoint, -data)
260255

261256
def do_unary_-%(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
262-
FixedPoint.fromData(_inferredBinaryPoint, data.unary_-%)
257+
FixedPoint.fromData(binaryPoint, data.unary_-%)
263258

264259
override def do_*(that: FixedPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
265-
FixedPoint.fromData(_inferredBinaryPoint + that._inferredBinaryPoint, data * that.data)
260+
FixedPoint.fromData(binaryPoint + that.binaryPoint, data * that.data)
266261

267262
override def do_/(that: FixedPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
268263
throw new ChiselException(s"division is illegal on FixedPoint types")
@@ -283,27 +278,29 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
283278
comparativeOp(that, _ >= _)
284279

285280
override def do_abs(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
286-
FixedPoint.fromData(_inferredBinaryPoint, data.abs)
281+
FixedPoint.fromData(binaryPoint, data.abs)
287282

288283
def do_floor(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint = {
289284
requireKnownBP()
290285
// Set the fractional part to zeroes
291-
val floored = Cat(data >> binaryPoint.get, 0.U(binaryPoint.get.W)).asSInt
292-
FixedPoint.fromData(binaryPoint, floored)
286+
val floored = Cat(data >> binaryPoint.get, 0.U(binaryPoint.get.W.min(width)))
287+
FixedPoint.fromData(binaryPoint, floored, Some(width))
293288
}
294289

295290
def do_ceil(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint = {
296291
requireKnownBP()
297292
// Get a number with the fractional part set to ones
298-
val almostOne = ((1 << binaryPoint.get) - 1).U(width)
293+
val almostOne = ((1 << binaryPoint.get) - 1).S
299294
// Add it to the number and floor it
300-
(this + FixedPoint.fromData(binaryPoint, almostOne.asSInt)).floor
295+
val ceiled = (this + FixedPoint.fromData(binaryPoint, almostOne)).floor
296+
FixedPoint.fromData(binaryPoint, ceiled, Some(width))
301297
}
302298

303299
def do_round(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint = {
304300
requireKnownBP()
305301
// Add 0.5 to the number and then floor it
306-
(this + 0.5.F(1.BP)).floor.setBinaryPoint(binaryPoint.get)
302+
val rounded = (this + 0.5.F(1.BP)).floor.setBinaryPoint(binaryPoint.get)
303+
FixedPoint.fromData(binaryPoint, rounded, Some(width))
307304
}
308305

309306
def do_===(that: FixedPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool =
@@ -316,22 +313,22 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
316313
comparativeOp(that, _ =/= _)
317314

318315
def do_>>(that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
319-
FixedPoint.fromData(_inferredBinaryPoint, (data >> that).asSInt)
316+
FixedPoint.fromData(binaryPoint, data >> that)
320317

321318
def do_>>(that: BigInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
322-
FixedPoint.fromData(_inferredBinaryPoint, (data >> that).asSInt)
319+
FixedPoint.fromData(binaryPoint, data >> that)
323320

324321
def do_>>(that: UInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
325-
FixedPoint.fromData(_inferredBinaryPoint, (data >> that).asSInt)
322+
FixedPoint.fromData(binaryPoint, data >> that)
326323

327324
def do_<<(that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
328-
FixedPoint.fromData(_inferredBinaryPoint, (data << that).asSInt)
325+
FixedPoint.fromData(binaryPoint, data << that)
329326

330327
def do_<<(that: BigInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
331-
FixedPoint.fromData(_inferredBinaryPoint, (data << that).asSInt)
328+
FixedPoint.fromData(binaryPoint, data << that)
332329

333330
def do_<<(that: UInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint =
334-
FixedPoint.fromData(_inferredBinaryPoint, (data << that).asSInt)
331+
FixedPoint.fromData(binaryPoint, data << that)
335332

336333
def +%(that: FixedPoint): FixedPoint = macro SourceInfoTransform.thatArg
337334

@@ -381,7 +378,7 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
381378
implicit sourceInfo: SourceInfo,
382379
compileOptions: CompileOptions
383380
): Unit = {
384-
this.data := that.asSInt
381+
this.data := that.asTypeOf(this.data)
385382
}
386383

387384
def apply(x: BigInt): Bool = data.apply(x)
@@ -400,35 +397,29 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
400397

401398
final def asSInt: SInt = data.asSInt
402399

403-
final def asFixedPoint(binaryPoint: BinaryPoint): FixedPoint = {
404-
binaryPoint match {
405-
case KnownBinaryPoint(_) =>
406-
FixedPoint.fromData(binaryPoint, data, Some(width))
407-
case UnknownBinaryPoint =>
408-
throw new ChiselException(
409-
s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint"
410-
)
411-
}
400+
def do_asFixedPoint(
401+
binaryPoint: BinaryPoint
402+
)(
403+
implicit sourceInfo: SourceInfo,
404+
compileOptions: CompileOptions
405+
): FixedPoint = {
406+
requireKnownBP(s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint")
407+
FixedPoint.fromData(binaryPoint, data, Some(width))
412408
}
413409

414410
def do_setBinaryPoint(that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint = {
415-
_inferredBinaryPoint match {
416-
case KnownBinaryPoint(current) =>
417-
val diff = that - current
418-
FixedPoint.fromData(
419-
that.BP,
420-
(if (diff > 0) data << diff
421-
else if (diff < 0) data >> -diff
422-
else data).asSInt,
423-
Some(width + diff)
424-
)
425-
case UnknownBinaryPoint =>
426-
throw new ChiselException(
427-
s"cannot set new binary point if current binary point is unknown"
428-
)
429-
}
411+
requireKnownBP(s"cannot set new binary point if current binary point is unknown")
412+
val diff = that - binaryPoint.get
413+
FixedPoint.fromData(
414+
that.BP,
415+
if (diff > 0) data << diff
416+
else if (diff < 0) data >> -diff
417+
else data
418+
)
430419
}
431420

421+
final def asFixedPoint(that: BinaryPoint): FixedPoint = macro SourceInfoTransform.thatArg
422+
432423
def setBinaryPoint(that: Int): FixedPoint = macro SourceInfoTransform.thatArg
433424

434425
def widthKnown: Boolean = data.widthKnown
@@ -447,10 +438,9 @@ sealed class FixedPoint private[fixedpoint] (width: Width, private var _inferred
447438
case Some(value) => s"FixedPoint$width$binaryPoint($value)"
448439
case _ =>
449440
// Can't use stringAccessor so will have to extract from data field's toString...
450-
val suffix = ".*?([(].*[)])".r.findFirstMatchIn(data.toString) match {
451-
case Some(m) => m.group(1)
452-
case None => ""
453-
}
441+
val suffix = ".*?([(].*[)])".r
442+
.findFirstMatchIn(data.toString)
443+
.fold("")(_.group(1))
454444
s"FixedPoint$width$binaryPoint$suffix"
455445
}
456446
}

src/test/scala/ConnectSpec.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class ConnectSpec extends ChiselPropSpec with Utils {
3737
}
3838
property("FixedPoint := FixedPoint should succeed") {
3939
assertTesterPasses { new CrossConnectTester(FixedPoint(16.W, 8.BP), FixedPoint(16.W, 8.BP)) }
40+
assertTesterPasses { new CrossConnectTester(FixedPoint(2.W, 14.BP), FixedPoint(8.W, 6.BP)) }
4041
}
4142
property("FixedPoint := SInt should fail") {
4243
intercept[ChiselException] {

0 commit comments

Comments
 (0)