1+ // See LICENSE for license details.
2+
3+ package dsptools .dspmath
4+
5+ import org .scalatest .{FlatSpec , Matchers }
6+
7+ case class RadPow (rad : Int , pow : Int ) {
8+ /** `r ^ p` */
9+ def get : Int = BigInt (rad).pow(pow).toInt
10+ /** Factorize i.e. rad = 4, pow = 3 -> Seq(4, 4, 4) */
11+ def factorize : Seq [Int ] = Seq .fill(pow)(rad)
12+ }
13+
14+ case class Factorization (supportedRadsUnsorted : Seq [Seq [Int ]]) {
15+ /** Supported radices, MSD First */
16+ private val supportedRads = supportedRadsUnsorted.map(_.sorted.reverse)
17+
18+ /** Factor n into powers of supported radices and store RadPow i.e. r^p, separated by coprimes
19+ * i.e. if supportedRads = Seq(Seq(4, 2), Seq(3)),
20+ * output = Seq(Seq(RadPow(4, 5), RadPow(2, 1)), Seq(RadPow(3, 7)))
21+ * implies n = 4^5 * 2^1 * 3^7
22+ */
23+ private def getRadPows (n : Int ): Seq [Seq [RadPow ]] = {
24+ // Test if n can be factored by each of the supported radices (mod = 0)
25+ // Count # of times it can be factored
26+ var unfactorized = n
27+ val radPows = for (primeGroup <- supportedRads) yield { for (rad <- primeGroup) yield {
28+ var (mod, pow) = (0 , 0 )
29+ while (mod == 0 ) {
30+ mod = unfactorized % rad
31+ if (mod == 0 ) {
32+ pow = pow + 1
33+ unfactorized = unfactorized / rad
34+ }
35+ }
36+ RadPow (rad, pow)
37+ }}
38+ // If n hasn't completely been factorized, then an unsupported radix is required
39+ require(unfactorized == 1 , s " $n is invalid for supportedRads. " )
40+ radPows
41+ }
42+
43+ /** Factor n into powers of supported radices (flattened)
44+ * i.e. if supportedRads = Seq(Seq(4, 2), Seq(3)),
45+ * output = Seq(5, 1, 7)
46+ * implies `n = 4^5 * 2^1 * 3^7`
47+ * If supportedRads contains more radices than the ones used, a power of 0 will be
48+ * associated with the unused radices.
49+ */
50+ def getPowsFlat (n : Int ): Seq [Int ] = {
51+ getRadPows(n).flatMap(_.map(_.pow))
52+ }
53+
54+ /** Break n into coprimes i.e.
55+ * n = 4^5 * 2^1 * 3^7
56+ * would result in Seq(4^5 * 2^1, 3^7)
57+ * If supportedRads contains more coprime groups than the ones used, 1 will be
58+ * associated with the unused groups.
59+ */
60+ def getCoprimes (n : Int ): Seq [Int ] = {
61+ getRadPows(n).map(_.map(_.get).product)
62+ }
63+
64+ /** Factorizes the coprime into digit radices (mixed radix)
65+ * i.e. n = 8 -> Seq(4, 2)
66+ * Note: there's no padding!
67+ */
68+ def factorizeCoprime (n : Int ): Seq [Int ] = {
69+ // i.e. if supportedRads = Seq(Seq(4, 2), Seq(3)) and n = 8,
70+ // correspondingPrimeGroup = Seq(4, 2)
71+ val correspondingPrimeGroup = supportedRads.filter(n % _.min == 0 )
72+ require(correspondingPrimeGroup.length == 1 , " n (coprime) must not be divisible by other primes." )
73+ // Factorize coprime -- only correspondingPrimeGroup should actually add to factorization length
74+ getRadPows(n).flatten.flatMap(_.factorize)
75+ }
76+
77+ /** Gets associated base prime for n (assuming n isn't divisible by other primes)
78+ * WARNING: Assumes supportedRads contains the base prime!
79+ */
80+ def getBasePrime (n : Int ): Int = {
81+ val primeTemp = supportedRads.map(_.min).filter(n % _ == 0 )
82+ require(primeTemp.length == 1 , " n should only be divisible by 1 prime" )
83+ primeTemp.head
84+ }
85+
86+ }
87+
88+ class FactorizationSpec extends FlatSpec with Matchers {
89+
90+ val testSupportedRads = Seq (Seq (4 , 2 ), Seq (3 ), Seq (5 ), Seq (7 ))
91+
92+ behavior of " Factorization"
93+ it should " properly factorize" in {
94+ case class FactorizationTest (n : Int , pows : Seq [Int ], coprimes : Seq [Int ])
95+ val tests = Seq (
96+ FactorizationTest (
97+ n = (math.pow(4 , 5 ) * math.pow(2 , 1 ) * math.pow(3 , 7 )).toInt,
98+ pows = Seq (5 , 1 , 7 ),
99+ coprimes = Seq ((math.pow(4 , 5 ) * math.pow(2 , 1 )).toInt, math.pow(3 , 7 ).toInt)
100+ ),
101+ FactorizationTest (n = 15 , pows = Seq (0 , 0 , 1 , 1 ), coprimes = Seq (1 , 3 , 5 ))
102+ )
103+
104+ tests foreach { case FactorizationTest (n, pows, coprimes) =>
105+ val powsFill = Seq .fill(testSupportedRads.flatten.length - pows.length)(0 )
106+ val coprimesFill = Seq .fill(testSupportedRads.length - coprimes.length)(1 )
107+ require(
108+ Factorization (testSupportedRads).getPowsFlat(n) == pows ++ powsFill,
109+ " Should factorize to get the right powers -- includes padding."
110+ )
111+ require(
112+ Factorization (testSupportedRads).getCoprimes(n) == coprimes ++ coprimesFill,
113+ " Should factorize into the right coprimes -- includes padding."
114+ )
115+ }
116+ }
117+
118+ it should " properly factorize coprime" in {
119+ case class CoprimeFactorizationTest (n : Int , factorization : Seq [Int ], basePrime : Int )
120+ val tests = Seq (
121+ CoprimeFactorizationTest (n = 8 , factorization = Seq (4 , 2 ), basePrime = 2 ),
122+ CoprimeFactorizationTest (n = 16 , factorization = Seq (4 , 4 ), basePrime = 2 )
123+ )
124+ tests foreach { case CoprimeFactorizationTest (n, factorization, basePrime) =>
125+ require(
126+ Factorization (testSupportedRads).factorizeCoprime(n) == factorization,
127+ " Should factorize coprime properly."
128+ )
129+ require(
130+ Factorization (testSupportedRads).getBasePrime(n) == basePrime,
131+ " Should get the correct base prime."
132+ )
133+ }
134+ }
135+ }
0 commit comments