@@ -42,6 +42,14 @@ module VX_tcu_fedp_bhf #(
4242 localparam FMT_DELAY = FMUL_LATENCY + FRND_LATENCY ;
4343 localparam C_DELAY = (FMUL_LATENCY + FRND_LATENCY ) + 1 + FRED_LATENCY ;
4444
45+ localparam MUL_EXP = 8 + LEVELS ;
46+ localparam MUL_SIG = 32 + LEVELS ;
47+ localparam MUL_WIDTH = 1 + MUL_EXP + MUL_SIG ;
48+
49+ localparam ACC_EXP = `MAX (MUL_EXP , 8 );
50+ localparam ACC_SIG = `MAX (MUL_SIG , 24 );
51+ localparam ACC_WIDTH = 1 + ACC_EXP + ACC_SIG ;
52+
4553 `UNUSED_VAR ({ fmt_s[3 ], fmt_d, c_val} );
4654
4755 wire [2 : 0 ] frm = `round_near_even ;
@@ -71,21 +79,21 @@ module VX_tcu_fedp_bhf #(
7179 .data_out (fmt_s_delayed)
7280 );
7381
74- wire [32 : 0 ] mult_result [TCK ];
82+ wire [MUL_WIDTH - 1 : 0 ] mult_result [TCK ];
7583
7684 for (genvar i = 0 ; i < TCK ; i++ ) begin : g_multiply
7785
78- wire [32 : 0 ] mult_result_fp16;
79- wire [32 : 0 ] mult_result_bf16;
86+ wire [MUL_WIDTH - 1 : 0 ] mult_result_fp16;
87+ wire [MUL_WIDTH - 1 : 0 ] mult_result_bf16;
8088
81- wire [32 : 0 ] mult_result_fp8;
82- wire [32 : 0 ] mult_result_bf8;
89+ wire [MUL_WIDTH - 1 : 0 ] mult_result_fp8;
90+ wire [MUL_WIDTH - 1 : 0 ] mult_result_bf8;
8391
84- VX_tcu_bhf_fp8mul # (
92+ VX_tcu_bhf_fmul8 # (
8593 .IN_EXPW (4 ),
8694 .IN_SIGW (3 + 1 ),
87- .OUT_EXPW (8 ),
88- .OUT_SIGW (24 ),
95+ .OUT_EXPW (MUL_EXP ),
96+ .OUT_SIGW (MUL_SIG ),
8997 .IN_REC (0 ), // input in IEEE format
9098 .OUT_REC (1 ), // output in recoded format
9199 .MUL_LATENCY (FMUL_LATENCY ),
@@ -103,11 +111,11 @@ module VX_tcu_fedp_bhf #(
103111 `UNUSED_PIN (fflags)
104112 );
105113
106- VX_tcu_bhf_fp8mul # (
114+ VX_tcu_bhf_fmul8 # (
107115 .IN_EXPW (5 ),
108116 .IN_SIGW (2 + 1 ),
109- .OUT_EXPW (8 ),
110- .OUT_SIGW (24 ),
117+ .OUT_EXPW (MUL_EXP ),
118+ .OUT_SIGW (MUL_SIG ),
111119 .IN_REC (0 ), // input in IEEE format
112120 .OUT_REC (1 ), // output in recoded format
113121 .MUL_LATENCY (FMUL_LATENCY ),
@@ -128,8 +136,8 @@ module VX_tcu_fedp_bhf #(
128136 VX_tcu_bhf_fmul # (
129137 .IN_EXPW (5 ),
130138 .IN_SIGW (10 + 1 ),
131- .OUT_EXPW (8 ),
132- .OUT_SIGW (24 ),
139+ .OUT_EXPW (MUL_EXP ),
140+ .OUT_SIGW (MUL_SIG ),
133141 .IN_REC (0 ), // input in IEEE format
134142 .OUT_REC (1 ), // output in recoded format
135143 .MUL_LATENCY (FMUL_LATENCY ),
@@ -148,8 +156,8 @@ module VX_tcu_fedp_bhf #(
148156 VX_tcu_bhf_fmul # (
149157 .IN_EXPW (8 ),
150158 .IN_SIGW (7 + 1 ),
151- .OUT_EXPW (8 ),
152- .OUT_SIGW (24 ),
159+ .OUT_EXPW (MUL_EXP ),
160+ .OUT_SIGW (MUL_SIG ),
153161 .IN_REC (0 ), // input in IEEE format
154162 .OUT_REC (1 ), // output in recoded format
155163 .MUL_LATENCY (FMUL_LATENCY ),
@@ -165,7 +173,7 @@ module VX_tcu_fedp_bhf #(
165173 `UNUSED_PIN (fflags)
166174 );
167175
168- logic [32 : 0 ] mult_result_mux;
176+ logic [MUL_WIDTH - 1 : 0 ] mult_result_mux;
169177 always_comb begin
170178 case (fmt_s_delayed)
171179 3'd1 : mult_result_mux = mult_result_fp16;
@@ -177,7 +185,7 @@ module VX_tcu_fedp_bhf #(
177185 end
178186
179187 VX_pipe_register # (
180- .DATAW (33 ),
188+ .DATAW (MUL_WIDTH ),
181189 .DEPTH (1 )
182190 ) pipe_mulsel (
183191 .clk (clk),
@@ -188,75 +196,118 @@ module VX_tcu_fedp_bhf #(
188196 );
189197 end
190198
191- // Accumulate reduction tree
199+ // Product terms reduction
192200
193- wire [32 : 0 ] red_in [LEVELS + 1 ][TCK ];
201+ for (genvar lvl = 0 ; lvl < LEVELS ; lvl++ ) begin : g_levels
202+ localparam CURSZ = TCK >> lvl;
203+ localparam OUTSZ = CURSZ >> 1 ;
204+ localparam in_expw = (lvl == 0 ) ? MUL_EXP : ACC_EXP ;
205+ localparam in_sigw = (lvl == 0 ) ? MUL_SIG : ACC_SIG ;
206+ localparam in_w = 1 + in_expw + in_sigw;
207+ localparam out_expw = ACC_EXP ;
208+ localparam out_sigw = ACC_SIG ;
209+ localparam out_w = 1 + out_expw + out_sigw;
194210
195- for (genvar i = 0 ; i < TCK ; i++ ) begin : g_red_in
196- assign red_in[0 ][i] = mult_result[i];
197- end
211+ wire [OUTSZ - 1 : 0 ][out_w- 1 : 0 ] sum;
198212
199- for (genvar lvl = 0 ; lvl < LEVELS ; lvl++ ) begin : g_accumulate
200- localparam CURSZ = TCK >> lvl;
201- localparam OUTSZ = CURSZ >> 1 ;
202213 for (genvar i = 0 ; i < OUTSZ ; i++ ) begin : g_add
214+ wire [in_w- 1 : 0 ] a, b;
215+ if (lvl == 0 ) begin
216+ assign a = mult_result[2 * i+ 0 ];
217+ assign b = mult_result[2 * i+ 1 ];
218+ end else begin
219+ assign a = g_levels[lvl- 1 ].sum[2 * i+ 0 ];
220+ assign b = g_levels[lvl- 1 ].sum[2 * i+ 1 ];
221+ end
222+
203223 VX_tcu_bhf_fadd # (
204- .IN_EXPW (8 ),
205- .IN_SIGW (23 + 1 ),
206- .IN_REC (1 ), // input in recoded format
207- .OUT_REC (1 ), // output in recoded format
224+ .IN_EXPW (in_expw),
225+ .IN_SIGW (in_sigw),
226+ .OUT_EXPW (out_expw), // ACC
227+ .OUT_SIGW (out_sigw), // ACC
228+ .IN_REC (1 ),
229+ .OUT_REC (1 ),
208230 .ADD_LATENCY (FADD_LATENCY ),
209231 .RND_LATENCY (FRND_LATENCY )
210232 ) reduce_add (
211233 .clk (clk),
212234 .reset (reset),
213235 .enable (enable),
214- .frm (frm),
215- .a (red_in[lvl][ 2 * i + 0 ] ),
216- .b (red_in[lvl][ 2 * i + 1 ] ),
217- .y (red_in[lvl + 1 ] [i]),
236+ .frm (frm), // still RNE
237+ .a (a ),
238+ .b (b ),
239+ .y (sum [i]),
218240 `UNUSED_PIN (fflags)
219241 );
220242 end
221243 end
222244
223- // Final accumulation with C
245+ // Final reduced result is ACC width now:
246+ wire [ACC_WIDTH - 1 : 0 ] red_result = g_levels[LEVELS - 1 ].sum[0 ];
224247
225- wire [32 : 0 ] c_rec, c_delayed;
248+ // Final accumulation with C
249+ wire [32 : 0 ] c_rec, c_rec2;
250+ wire [ACC_WIDTH - 1 : 0 ] c_up, c_delayed;
226251 wire [31 : 0 ] result;
227252
228253 fNToRecFN # (
229254 .expWidth (8 ),
230255 .sigWidth (24 )
231- ) conv_c (
256+ ) conv_c_rec (
232257 .in (c_val[31 : 0 ]),
233258 .out (c_rec)
234259 );
235260
236261 VX_pipe_register # (
237262 .DATAW (33 ),
238- .DEPTH (C_DELAY )
239- ) pipe_c (
263+ .DEPTH (1 )
264+ ) pipe_c1 (
240265 .clk (clk),
241266 .reset (reset),
242267 .enable (enable),
243268 .data_in (c_rec),
269+ .data_out (c_rec2)
270+ );
271+
272+ recFNToRecFN # (
273+ .inExpWidth (8 ),
274+ .inSigWidth (24 ),
275+ .outExpWidth (ACC_EXP ),
276+ .outSigWidth (ACC_SIG )
277+ ) conv_c_up (
278+ .control (`flControl_tininessAfterRounding ),
279+ .roundingMode (`round_near_even ),
280+ .in (c_rec2),
281+ .out (c_up),
282+ `UNUSED_PIN (exceptionFlags)
283+ );
284+
285+ VX_pipe_register # (
286+ .DATAW (ACC_WIDTH ),
287+ .DEPTH (C_DELAY - 1 )
288+ ) pipe_c2 (
289+ .clk (clk),
290+ .reset (reset),
291+ .enable (enable),
292+ .data_in (c_up),
244293 .data_out (c_delayed)
245294 );
246295
247296 VX_tcu_bhf_fadd # (
248- .IN_EXPW (8 ),
249- .IN_SIGW (23 + 1 ),
297+ .IN_EXPW (ACC_EXP ),
298+ .IN_SIGW (ACC_SIG ),
299+ .OUT_EXPW (8 ),
300+ .OUT_SIGW (24 ),
250301 .IN_REC (1 ), // input in recoded format
251302 .OUT_REC (0 ), // output in IEEE format
252303 .ADD_LATENCY (FADD_LATENCY ),
253304 .RND_LATENCY (FRND_LATENCY )
254- ) final_add (
305+ ) acc (
255306 .clk (clk),
256307 .reset (reset),
257308 .enable (enable),
258309 .frm (frm),
259- .a (red_in[ LEVELS ][ 0 ] ),
310+ .a (red_result ),
260311 .b (c_delayed),
261312 .y (result),
262313 `UNUSED_PIN (fflags)
0 commit comments