Skip to content

Commit fa328f1

Browse files
committed
minor update
1 parent 2ca8db8 commit fa328f1

File tree

4 files changed

+93
-42
lines changed

4 files changed

+93
-42
lines changed
Lines changed: 93 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ module VX_tcu_fedp_bhf #(
4242
localparam FMT_DELAY = FMUL_LATENCY + FRND_LATENCY;
4343
localparam C_DELAY = (FMUL_LATENCY + FRND_LATENCY) + 1 + FRED_LATENCY;
4444

45+
localparam MUL_EXP = 8 + LEVELS;
46+
localparam MUL_SIG = 32 + LEVELS;
47+
localparam MUL_WIDTH = 1 + MUL_EXP + MUL_SIG;
48+
49+
localparam ACC_EXP = `MAX(MUL_EXP, 8);
50+
localparam ACC_SIG = `MAX(MUL_SIG, 24);
51+
localparam ACC_WIDTH = 1 + ACC_EXP + ACC_SIG;
52+
4553
`UNUSED_VAR ({fmt_s[3], fmt_d, c_val});
4654

4755
wire [2:0] frm = `round_near_even;
@@ -71,21 +79,21 @@ module VX_tcu_fedp_bhf #(
7179
.data_out(fmt_s_delayed)
7280
);
7381

74-
wire [32:0] mult_result [TCK];
82+
wire [MUL_WIDTH-1:0] mult_result [TCK];
7583

7684
for (genvar i = 0; i < TCK; i++) begin : g_multiply
7785

78-
wire [32:0] mult_result_fp16;
79-
wire [32:0] mult_result_bf16;
86+
wire [MUL_WIDTH-1:0] mult_result_fp16;
87+
wire [MUL_WIDTH-1:0] mult_result_bf16;
8088

81-
wire [32:0] mult_result_fp8;
82-
wire [32:0] mult_result_bf8;
89+
wire [MUL_WIDTH-1:0] mult_result_fp8;
90+
wire [MUL_WIDTH-1:0] mult_result_bf8;
8391

84-
VX_tcu_bhf_fp8mul #(
92+
VX_tcu_bhf_fmul8 #(
8593
.IN_EXPW (4),
8694
.IN_SIGW (3+1),
87-
.OUT_EXPW(8),
88-
.OUT_SIGW(24),
95+
.OUT_EXPW(MUL_EXP),
96+
.OUT_SIGW(MUL_SIG),
8997
.IN_REC (0), // input in IEEE format
9098
.OUT_REC (1), // output in recoded format
9199
.MUL_LATENCY (FMUL_LATENCY),
@@ -103,11 +111,11 @@ module VX_tcu_fedp_bhf #(
103111
`UNUSED_PIN(fflags)
104112
);
105113

106-
VX_tcu_bhf_fp8mul #(
114+
VX_tcu_bhf_fmul8 #(
107115
.IN_EXPW (5),
108116
.IN_SIGW (2+1),
109-
.OUT_EXPW(8),
110-
.OUT_SIGW(24),
117+
.OUT_EXPW(MUL_EXP),
118+
.OUT_SIGW(MUL_SIG),
111119
.IN_REC (0), // input in IEEE format
112120
.OUT_REC (1), // output in recoded format
113121
.MUL_LATENCY (FMUL_LATENCY),
@@ -128,8 +136,8 @@ module VX_tcu_fedp_bhf #(
128136
VX_tcu_bhf_fmul #(
129137
.IN_EXPW (5),
130138
.IN_SIGW (10+1),
131-
.OUT_EXPW(8),
132-
.OUT_SIGW(24),
139+
.OUT_EXPW(MUL_EXP),
140+
.OUT_SIGW(MUL_SIG),
133141
.IN_REC (0), // input in IEEE format
134142
.OUT_REC (1), // output in recoded format
135143
.MUL_LATENCY (FMUL_LATENCY),
@@ -148,8 +156,8 @@ module VX_tcu_fedp_bhf #(
148156
VX_tcu_bhf_fmul #(
149157
.IN_EXPW (8),
150158
.IN_SIGW (7+1),
151-
.OUT_EXPW(8),
152-
.OUT_SIGW(24),
159+
.OUT_EXPW(MUL_EXP),
160+
.OUT_SIGW(MUL_SIG),
153161
.IN_REC (0), // input in IEEE format
154162
.OUT_REC (1), // output in recoded format
155163
.MUL_LATENCY (FMUL_LATENCY),
@@ -165,7 +173,7 @@ module VX_tcu_fedp_bhf #(
165173
`UNUSED_PIN(fflags)
166174
);
167175

168-
logic [32:0] mult_result_mux;
176+
logic [MUL_WIDTH-1:0] mult_result_mux;
169177
always_comb begin
170178
case(fmt_s_delayed)
171179
3'd1: mult_result_mux = mult_result_fp16;
@@ -177,7 +185,7 @@ module VX_tcu_fedp_bhf #(
177185
end
178186

179187
VX_pipe_register #(
180-
.DATAW (33),
188+
.DATAW (MUL_WIDTH),
181189
.DEPTH (1)
182190
) pipe_mulsel (
183191
.clk (clk),
@@ -188,75 +196,118 @@ module VX_tcu_fedp_bhf #(
188196
);
189197
end
190198

191-
// Accumulate reduction tree
199+
// Product terms reduction
192200

193-
wire [32:0] red_in [LEVELS+1][TCK];
201+
for (genvar lvl = 0; lvl < LEVELS; lvl++) begin : g_levels
202+
localparam CURSZ = TCK >> lvl;
203+
localparam OUTSZ = CURSZ >> 1;
204+
localparam in_expw = (lvl == 0) ? MUL_EXP : ACC_EXP;
205+
localparam in_sigw = (lvl == 0) ? MUL_SIG : ACC_SIG;
206+
localparam in_w = 1 + in_expw + in_sigw;
207+
localparam out_expw = ACC_EXP;
208+
localparam out_sigw = ACC_SIG;
209+
localparam out_w = 1 + out_expw + out_sigw;
194210

195-
for (genvar i = 0; i < TCK; i++) begin : g_red_in
196-
assign red_in[0][i] = mult_result[i];
197-
end
211+
wire [OUTSZ-1:0][out_w-1:0] sum;
198212

199-
for (genvar lvl = 0; lvl < LEVELS; lvl++) begin : g_accumulate
200-
localparam CURSZ = TCK >> lvl;
201-
localparam OUTSZ = CURSZ >> 1;
202213
for (genvar i = 0; i < OUTSZ; i++) begin : g_add
214+
wire [in_w-1:0] a, b;
215+
if (lvl == 0) begin
216+
assign a = mult_result[2*i+0];
217+
assign b = mult_result[2*i+1];
218+
end else begin
219+
assign a = g_levels[lvl-1].sum[2*i+0];
220+
assign b = g_levels[lvl-1].sum[2*i+1];
221+
end
222+
203223
VX_tcu_bhf_fadd #(
204-
.IN_EXPW (8),
205-
.IN_SIGW (23+1),
206-
.IN_REC (1), // input in recoded format
207-
.OUT_REC (1), // output in recoded format
224+
.IN_EXPW (in_expw),
225+
.IN_SIGW (in_sigw),
226+
.OUT_EXPW (out_expw), // ACC
227+
.OUT_SIGW (out_sigw), // ACC
228+
.IN_REC (1),
229+
.OUT_REC (1),
208230
.ADD_LATENCY (FADD_LATENCY),
209231
.RND_LATENCY (FRND_LATENCY)
210232
) reduce_add (
211233
.clk (clk),
212234
.reset (reset),
213235
.enable (enable),
214-
.frm (frm),
215-
.a (red_in[lvl][2*i+0]),
216-
.b (red_in[lvl][2*i+1]),
217-
.y (red_in[lvl+1][i]),
236+
.frm (frm), // still RNE
237+
.a (a),
238+
.b (b),
239+
.y (sum[i]),
218240
`UNUSED_PIN(fflags)
219241
);
220242
end
221243
end
222244

223-
// Final accumulation with C
245+
// Final reduced result is ACC width now:
246+
wire [ACC_WIDTH-1:0] red_result = g_levels[LEVELS-1].sum[0];
224247

225-
wire [32:0] c_rec, c_delayed;
248+
// Final accumulation with C
249+
wire [32:0] c_rec, c_rec2;
250+
wire [ACC_WIDTH-1:0] c_up, c_delayed;
226251
wire [31:0] result;
227252

228253
fNToRecFN #(
229254
.expWidth (8),
230255
.sigWidth (24)
231-
) conv_c (
256+
) conv_c_rec (
232257
.in (c_val[31:0]),
233258
.out (c_rec)
234259
);
235260

236261
VX_pipe_register #(
237262
.DATAW (33),
238-
.DEPTH (C_DELAY)
239-
) pipe_c (
263+
.DEPTH (1)
264+
) pipe_c1 (
240265
.clk (clk),
241266
.reset (reset),
242267
.enable (enable),
243268
.data_in (c_rec),
269+
.data_out(c_rec2)
270+
);
271+
272+
recFNToRecFN #(
273+
.inExpWidth (8),
274+
.inSigWidth (24),
275+
.outExpWidth (ACC_EXP),
276+
.outSigWidth (ACC_SIG)
277+
) conv_c_up (
278+
.control (`flControl_tininessAfterRounding),
279+
.roundingMode (`round_near_even),
280+
.in (c_rec2),
281+
.out (c_up),
282+
`UNUSED_PIN (exceptionFlags)
283+
);
284+
285+
VX_pipe_register #(
286+
.DATAW (ACC_WIDTH),
287+
.DEPTH (C_DELAY-1)
288+
) pipe_c2 (
289+
.clk (clk),
290+
.reset (reset),
291+
.enable (enable),
292+
.data_in (c_up),
244293
.data_out(c_delayed)
245294
);
246295

247296
VX_tcu_bhf_fadd #(
248-
.IN_EXPW (8),
249-
.IN_SIGW (23+1),
297+
.IN_EXPW (ACC_EXP),
298+
.IN_SIGW (ACC_SIG),
299+
.OUT_EXPW(8),
300+
.OUT_SIGW(24),
250301
.IN_REC (1), // input in recoded format
251302
.OUT_REC (0), // output in IEEE format
252303
.ADD_LATENCY (FADD_LATENCY),
253304
.RND_LATENCY (FRND_LATENCY)
254-
) final_add (
305+
) acc (
255306
.clk (clk),
256307
.reset (reset),
257308
.enable (enable),
258309
.frm (frm),
259-
.a (red_in[LEVELS][0]),
310+
.a (red_result),
260311
.b (c_delayed),
261312
.y (result),
262313
`UNUSED_PIN(fflags)

0 commit comments

Comments
 (0)