1111#include < string>
1212#include < cmath>
1313
14+ #define VERIFY (cond, msg ) do { if (!(cond)) { std::cerr << " FAIL: " << msg << std::endl; ++nrOfFailedTestCases; } } while (0 )
15+
1416namespace sw { namespace universal {
1517
1618// Test that carry analysis converges
1719int TestConvergence () {
1820 int nrOfFailedTestCases = 0 ;
19-
2021 ExprGraph g;
2122 int x = g.variable (" x" , 1.0 , 8.0 );
2223 int y = g.variable (" y" , 1.0 , 8.0 );
2324 int z = g.mul (x, y);
24-
2525 g.require_nsb (z, 10 );
26-
2726 CarryAnalyzer ca;
2827 int iters = ca.refine (g);
29-
3028 std::cout << " Simple mul converged in " << iters << " iterations\n " ;
3129 ca.report (std::cout, g);
32-
33- // Should converge (not exceed max iterations)
34- if (iters >= 10 ) {
35- std::cerr << " FAIL: carry analysis did not converge" << std::endl;
36- ++nrOfFailedTestCases;
37- }
38-
39- // z should still meet its requirement
40- if (g.get_nsb (z) < 10 ) {
41- std::cerr << " FAIL: z requirement not met after carry analysis" << std::endl;
42- ++nrOfFailedTestCases;
43- }
44-
30+ VERIFY (iters < 10 , " carry analysis did not converge" );
31+ VERIFY (g.get_nsb (z) >= 10 , " z requirement not met after carry analysis" );
4532 return nrOfFailedTestCases;
4633}
4734
4835// Test that carry refinement can reduce total bits
4936int TestBitReduction () {
5037 int nrOfFailedTestCases = 0 ;
51-
52- // Build graph: det = a*d - b*c
5338 ExprGraph g;
5439 int a = g.variable (" a" , 8.0 , 12.0 );
5540 int b = g.variable (" b" , 8.0 , 12.0 );
@@ -58,12 +43,10 @@ int TestBitReduction() {
5843 int ad = g.mul (a, d);
5944 int bc = g.mul (b, c);
6045 int det = g.sub (ad, bc);
61-
6246 g.require_nsb (det, 20 );
6347
64- // Solve with conservative carries (all 1)
6548 PopSolver conservative;
66- {
49+ { // Solve with conservative carries
6750 ExprGraph g2;
6851 int a2 = g2.variable (" a" , 8.0 , 12.0 );
6952 int b2 = g2.variable (" b" , 8.0 , 12.0 );
@@ -77,102 +60,134 @@ int TestBitReduction() {
7760 std::cout << " Conservative total: " << conservative.total_nsb () << " \n " ;
7861 }
7962
80- // Solve with carry refinement
8163 CarryAnalyzer ca;
8264 ca.refine (g);
83-
8465 double refined_total = 0 ;
85- for (int i = 0 ; i < g.size (); ++i) {
86- refined_total += g.get_nsb (i);
87- }
88-
66+ for (int i = 0 ; i < g.size (); ++i) refined_total += g.get_nsb (i);
8967 std::cout << " Refined total: " << refined_total << " \n " ;
9068 ca.report (std::cout, g);
91-
92- // Refined should be <= conservative
93- if (refined_total > conservative.total_nsb () + 1 ) {
94- std::cerr << " FAIL: refined total exceeds conservative" << std::endl;
95- ++nrOfFailedTestCases;
96- }
97-
98- // Output should still meet requirement
99- if (g.get_nsb (det) < 20 ) {
100- std::cerr << " FAIL: det requirement not met after carry refinement" << std::endl;
101- ++nrOfFailedTestCases;
102- }
103-
69+ VERIFY (refined_total <= conservative.total_nsb () + 1 , " refined total exceeds conservative" );
70+ VERIFY (g.get_nsb (det) >= 20 , " det requirement not met after carry refinement" );
10471 return nrOfFailedTestCases;
10572}
10673
107- // Test chain with addition (where carry analysis is most effective)
74+ // Test chain with addition (carry analysis most effective here )
10875int TestAdditionChain () {
10976 int nrOfFailedTestCases = 0 ;
110-
111- // z = (a + b) + c with values of very different magnitudes
11277 ExprGraph g;
113- int a = g.variable (" a" , 1000.0 , 2000.0 ); // ufp ~= 10
114- int b = g.variable (" b" , 0.001 , 0.002 ); // ufp ~= -10
115- int c = g.variable (" c" , 1000.0 , 2000.0 ); // ufp ~= 10
116-
78+ int a = g.variable (" a" , 1000.0 , 2000.0 );
79+ int b = g.variable (" b" , 0.001 , 0.002 );
80+ int c = g.variable (" c" , 1000.0 , 2000.0 );
11781 int ab = g.add (a, b);
11882 int z = g.add (ab, c);
119-
12083 g.require_nsb (z, 12 );
84+ CarryAnalyzer ca;
85+ int iters = ca.refine (g);
86+ std::cout << " Addition chain converged in " << iters << " iterations\n " ;
87+ ca.report (std::cout, g);
88+ VERIFY (g.get_nsb (z) >= 12 , " addition chain z requirement not met" );
89+ return nrOfFailedTestCases;
90+ }
12191
92+ // Test variables-only graph (no operations to refine)
93+ int TestVariablesOnly () {
94+ int nrOfFailedTestCases = 0 ;
95+ ExprGraph g;
96+ int a = g.variable (" a" , 1.0 , 10.0 );
97+ int b = g.variable (" b" , 1.0 , 10.0 );
98+ g.require_nsb (a, 8 );
99+ g.require_nsb (b, 12 );
122100 CarryAnalyzer ca;
123101 int iters = ca.refine (g);
102+ std::cout << " Variables-only converged in " << iters << " iterations\n " ;
103+ VERIFY (g.get_nsb (a) >= 8 , " a requirement not met" );
104+ VERIFY (g.get_nsb (b) >= 12 , " b requirement not met" );
105+ return nrOfFailedTestCases;
106+ }
124107
125- std::cout << " Addition chain converged in " << iters << " iterations\n " ;
108+ // Test repeated refine() calls (idempotent)
109+ int TestRepeatedRefine () {
110+ int nrOfFailedTestCases = 0 ;
111+ ExprGraph g;
112+ int x = g.variable (" x" , 1.0 , 8.0 );
113+ int y = g.variable (" y" , 1.0 , 8.0 );
114+ int z = g.mul (x, y);
115+ g.require_nsb (z, 10 );
116+ CarryAnalyzer ca;
117+ ca.refine (g);
118+ int nsb_x1 = g.get_nsb (x), nsb_y1 = g.get_nsb (y), nsb_z1 = g.get_nsb (z);
119+ CarryAnalyzer ca2;
120+ ca2.refine (g);
121+ VERIFY (g.get_nsb (x) == nsb_x1 && g.get_nsb (y) == nsb_y1 && g.get_nsb (z) == nsb_z1, " repeated refine changed nsb values" );
122+ return nrOfFailedTestCases;
123+ }
124+
125+ // Test carry analysis with division (always carry=1)
126+ int TestDivisionCarry () {
127+ int nrOfFailedTestCases = 0 ;
128+ ExprGraph g;
129+ int a = g.variable (" a" , 10.0 , 100.0 );
130+ int b = g.variable (" b" , 1.0 , 10.0 );
131+ int z = g.div (a, b);
132+ g.require_nsb (z, 14 );
133+ CarryAnalyzer ca;
134+ int iters = ca.refine (g);
135+ std::cout << " Division carry converged in " << iters << " iterations\n " ;
126136 ca.report (std::cout, g);
137+ VERIFY (g.get_node (z).carry == 1 , " division carry should remain 1, got " << g.get_node (z).carry );
138+ VERIFY (g.get_nsb (z) >= 14 , " div z requirement not met" );
139+ return nrOfFailedTestCases;
140+ }
127141
128- if (g.get_nsb (z) < 12 ) {
129- std::cerr << " FAIL: addition chain z requirement not met" << std::endl;
130- ++nrOfFailedTestCases;
131- }
142+ // Test carry analysis with sqrt
143+ int TestSqrtCarry () {
144+ int nrOfFailedTestCases = 0 ;
145+ ExprGraph g;
146+ int x = g.variable (" x" , 4.0 , 100.0 );
147+ int z = g.sqrt (x);
148+ g.require_nsb (z, 12 );
149+ CarryAnalyzer ca;
150+ int iters = ca.refine (g);
151+ std::cout << " Sqrt carry converged in " << iters << " iterations\n " ;
152+ VERIFY (g.get_nsb (z) >= 12 , " sqrt z requirement not met" );
153+ return nrOfFailedTestCases;
154+ }
132155
156+ // Test iterations() accessor
157+ int TestIterationsAccessor () {
158+ int nrOfFailedTestCases = 0 ;
159+ ExprGraph g;
160+ int x = g.variable (" x" , 1.0 , 8.0 );
161+ int y = g.variable (" y" , 1.0 , 8.0 );
162+ int z = g.mul (x, y);
163+ g.require_nsb (z, 10 );
164+ CarryAnalyzer ca;
165+ int iters = ca.refine (g);
166+ VERIFY (ca.iterations () == iters, " iterations() != refine() return value" );
133167 return nrOfFailedTestCases;
134168}
135169
136170}} // namespace sw::universal
137171
138- #define TEST_CASE (name, func ) \
139- do { \
140- int fails = func; \
141- if (fails) { \
142- std::cout << name << " : FAIL (" << fails << " errors)" << std::endl; \
143- nrOfFailedTestCases += fails; \
144- } else { \
145- std::cout << name << " : PASS" << std::endl; \
146- } \
147- } while (0 )
172+ #define TEST_CASE (name, func ) do { int f_ = func; if (f_) { std::cout << name << " : FAIL (" << f_ << " errors)\n " ; nrOfFailedTestCases += f_; } else { std::cout << name << " : PASS\n " ; } } while (0 )
148173
149174int main ()
150175try {
151176 using namespace sw ::universal;
152-
153177 int nrOfFailedTestCases = 0 ;
154-
155- std::cout << " POP Carry Analysis Tests\n " ;
156- std::cout << std::string (40 , ' =' ) << " \n\n " ;
178+ std::cout << " POP Carry Analysis Tests\n " << std::string (40 , ' =' ) << " \n\n " ;
157179
158180 TEST_CASE (" Convergence" , TestConvergence ());
159181 TEST_CASE (" Bit reduction" , TestBitReduction ());
160182 TEST_CASE (" Addition chain" , TestAdditionChain ());
183+ TEST_CASE (" Variables only" , TestVariablesOnly ());
184+ TEST_CASE (" Repeated refine" , TestRepeatedRefine ());
185+ TEST_CASE (" Division carry" , TestDivisionCarry ());
186+ TEST_CASE (" Sqrt carry" , TestSqrtCarry ());
187+ TEST_CASE (" Iterations accessor" , TestIterationsAccessor ());
161188
162- std::cout << " \n " ;
163- if (nrOfFailedTestCases == 0 ) {
164- std::cout << " All carry analysis tests PASSED\n " ;
165- } else {
166- std::cout << nrOfFailedTestCases << " test(s) FAILED\n " ;
167- }
168-
189+ std::cout << " \n " << (nrOfFailedTestCases == 0 ? " All carry analysis tests PASSED" : std::to_string (nrOfFailedTestCases) + " test(s) FAILED" ) << " \n " ;
169190 return (nrOfFailedTestCases > 0 ? EXIT_FAILURE : EXIT_SUCCESS);
170191}
171- catch (const char * msg) {
172- std::cerr << " Caught exception: " << msg << std::endl;
173- return EXIT_FAILURE;
174- }
175- catch (...) {
176- std::cerr << " Caught unknown exception" << std::endl;
177- return EXIT_FAILURE;
178- }
192+ catch (const char * msg) { std::cerr << " Caught exception: " << msg << std::endl; return EXIT_FAILURE; }
193+ catch (...) { std::cerr << " Caught unknown exception" << std::endl; return EXIT_FAILURE; }
0 commit comments