Skip to content

Commit 20d56ae

Browse files
committed
sparse gated fedp drl
1 parent 6655c81 commit 20d56ae

File tree

9 files changed

+176
-284
lines changed

9 files changed

+176
-284
lines changed

hw/rtl/libs/VX_csa_half_en.sv

Lines changed: 0 additions & 186 deletions
This file was deleted.

hw/rtl/tcu/drl/VX_tcu_drl_acc.sv

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//
88
// Unless required by applicable law or agreed to in writing, software
99
// distributed under the License is distributed on an "AS IS" BASIS,
10-
// WAITHOUT WAARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
// See the License for the specific language governing permissions and
1212
// limitations under the License.
1313

@@ -20,27 +20,48 @@ module VX_tcu_drl_acc #(
2020
) (
2121
input wire [N-1:0][W-1:0] sigsIn,
2222
input wire fmt_sel,
23+
input wire [N-2:0] sparse_mask,
2324
output logic [WA-1:0] sigOut,
2425
output logic [N-2:0] signOuts
2526
);
26-
// Sign-extend fp significands to WA bits (header)
27+
28+
//Input gating for power: force every input whose sparse_mask bit is 0 to zero; the last entry (c_val) is never gated
29+
wire [N-1:0][W-1:0] gated_sigsIn;
30+
for (genvar i = 0; i < N-1; i++) begin : g_power_gating
31+
assign gated_sigsIn[i] = ({W{sparse_mask[i]}} & sigsIn[i]);
32+
end
33+
assign gated_sigsIn[N-1] = sigsIn[N-1]; //c_val
34+
35+
//Sign-extend fp significands to WA bits (header)
2736
wire [N-1:0][WA-1:0] sigsIn_ext;
2837
for (genvar i = 0; i < N; i++) begin : g_ext_sign
29-
assign sigsIn_ext[i] = fmt_sel ? {{(WA-W){1'b0}}, sigsIn[i]} : {{(WA-W){sigsIn[i][W-1]}}, sigsIn[i]};
38+
assign sigsIn_ext[i] = fmt_sel ? {{(WA-W){1'b0}}, gated_sigsIn[i]} : {{(WA-W){gated_sigsIn[i][W-1]}}, gated_sigsIn[i]};
3039
end
3140

3241
//Carry-Save-Adder based significand accumulation
33-
VX_csa_half_en #(
34-
.N (N),
35-
.W (WA),
36-
.S (WA-1)
37-
) sig_csa (
38-
.operands (sigsIn_ext),
39-
.half_en (1'b1), // TODO: feed sparsity control signal when resolved
40-
.sum (sigOut[WA-2:0]),
41-
.cout (sigOut[WA-1])
42-
);
42+
if (N >= 7) begin : g_large_acc
43+
VX_csa_mod4 #(
44+
.N (N),
45+
.W (WA),
46+
.S (WA-1)
47+
) sig_csa (
48+
.operands (sigsIn_ext),
49+
.sum (sigOut[WA-2:0]),
50+
.cout (sigOut[WA-1])
51+
);
52+
end else begin : g_small_acc
53+
VX_csa_tree #(
54+
.N (N),
55+
.W (WA),
56+
.S (WA-1)
57+
) sig_csa (
58+
.operands (sigsIn_ext),
59+
.sum (sigOut[WA-2:0]),
60+
.cout (sigOut[WA-1])
61+
);
62+
end
4363

64+
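//Extract the sign bits of the product significands for INT (NOTE(review): taken from the ungated sigsIn, not gated_sigsIn - confirm this is intended)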
4465
for (genvar i = 0; i < N-1; i++) begin : g_signs
4566
assign signOuts[i] = sigsIn[i][W-1];
4667
end

hw/rtl/tcu/drl/VX_tcu_drl_align.sv

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,28 @@ module VX_tcu_drl_align #(
2020
input wire [N-1:0][7:0] shift_amounts,
2121
input wire [N-1:0][24:0] sigs_in,
2222
input wire fmt_sel,
23+
input wire [N-2:0] sparse_mask,
2324
output logic [N-1:0][W-1:0] sigs_out
2425
);
2526

27+
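//Input gating for power: force the significand and shift amount of every input whose sparse_mask bit is 0 to zero; the last entry (c_val) is never gated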
28+
wire [N-1:0][7:0] gated_shift_amounts;
29+
wire [N-1:0][24:0] gated_sigs_in;
30+
for (genvar i = 0; i < N-1; i++) begin : g_power_gating
31+
assign gated_sigs_in[i] = ({25{sparse_mask[i]}} & sigs_in[i]);
32+
assign gated_shift_amounts[i] = ({8{sparse_mask[i]}} & shift_amounts[i]);
33+
end
34+
assign gated_sigs_in[N-1] = sigs_in[N-1]; //c_val
35+
assign gated_shift_amounts[N-1] = shift_amounts[N-1];
36+
2637
//extend + align + sign significands
2738
for (genvar i = 0; i < N; i++) begin : g_align
28-
wire [W-1:0] ext_sigs_in = {sigs_in[i], {W-25{1'b0}}};
39+
wire [W-1:0] ext_sigs_in = {gated_sigs_in[i], {W-25{1'b0}}};
2940
wire fp_sign = ext_sigs_in[W-1];
3041
wire [W-2:0] fp_sig = ext_sigs_in[W-2:0];
31-
wire [W-2:0] adj_sig = fp_sig >> shift_amounts[i];
42+
wire [W-2:0] adj_sig = fp_sig >> gated_shift_amounts[i];
3243
wire [W-1:0] fp_val = fp_sign ? -adj_sig : {1'b0, adj_sig};
3344
assign sigs_out[i] = fmt_sel ? ext_sigs_in : fp_val;
3445
end
3546

3647
endmodule
37-
38-
/*
39-
wire [23:0] adj_sig = shift_amount[3] ? 24'd0 : full_sig[i] >> shift_amount; //reducing switching activity (power) by clamping to 0 if
40-
//input won't make a significant impact on accumulated value
41-
*/

hw/rtl/tcu/drl/VX_tcu_drl_exp_bias.sv

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@ module VX_tcu_drl_exp_bias (
2222
output logic exp_low_larger,
2323
output logic [6:0] raw_exp_diff
2424
);
25-
`UNUSED_VAR({a, b, enable});
25+
26+
//Gate the inputs to zero while enable is low so that downstream logic sees no switching activity
27+
wire [15:0] gated_a = {16{enable}} & a;
28+
wire [15:0] gated_b = {16{enable}} & b;
29+
wire [2:0] gated_fmt_s = {3{enable}} & fmt_s;
30+
`UNUSED_VAR({gated_a, gated_b});
2631

2732
//FP16 exponent addition and bias
2833
wire [7:0] raw_exp_fp16;
@@ -32,7 +37,7 @@ module VX_tcu_drl_exp_bias (
3237
.W(8),
3338
.S(8)
3439
) biasexp_fp16(
35-
.operands({{3'd0, a[14:10]}, {3'd0, b[14:10]}, fp16_32_conv_bias}),
40+
.operands({{3'd0, gated_a[14:10]}, {3'd0, gated_b[14:10]}, fp16_32_conv_bias}),
3641
.sum (raw_exp_fp16),
3742
`UNUSED_PIN (cout)
3843
);
@@ -44,10 +49,10 @@ module VX_tcu_drl_exp_bias (
4449
`UNUSED_VAR(raw_exp_bf16_signed);
4550
VX_csa_tree #(
4651
.N(3),
47-
.W(10), //8 + log2(3) extend for sign handling
52+
.W(10), //8 + log2(3)-extend for sign handling
4853
.S(10)
4954
) biasexp_bf16(
50-
.operands({{2'd0, a[14:7]}, {2'd0, b[14:7]}, neg_bias}),
55+
.operands({{2'd0, gated_a[14:7]}, {2'd0, gated_b[14:7]}, neg_bias}),
5156
.sum (raw_exp_bf16_signed),
5257
`UNUSED_PIN (cout)
5358
);
@@ -60,8 +65,8 @@ module VX_tcu_drl_exp_bias (
6065
VX_ks_adder #(
6166
.N(4)
6267
) raw_exp_fp8_sub_add (
63-
.dataa (a[(i*8)+6 -: 4]),
64-
.datab (b[(i*8)+6 -: 4]),
68+
.dataa (gated_a[(i*8)+6 -: 4]),
69+
.datab (gated_b[(i*8)+6 -: 4]),
6570
.sum (raw_exp_fp8_sub[i][3:0]),
6671
.cout (raw_exp_fp8_sub[i][4])
6772
);
@@ -86,8 +91,8 @@ module VX_tcu_drl_exp_bias (
8691
VX_ks_adder #(
8792
.N(5)
8893
) raw_exp_bf8_sub_add (
89-
.dataa (a[(j*8)+6 -: 5]),
90-
.datab (b[(j*8)+6 -: 5]),
94+
.dataa (gated_a[(j*8)+6 -: 5]),
95+
.datab (gated_b[(j*8)+6 -: 5]),
9196
.sum (raw_exp_bf8_sub[j][4:0]),
9297
.cout (raw_exp_bf8_sub[j][5])
9398
);
@@ -107,7 +112,7 @@ module VX_tcu_drl_exp_bias (
107112

108113
//Select exp out based on datatype
109114
always_comb begin
110-
case(fmt_s[2:0])
115+
case(gated_fmt_s[2:0])
111116
3'd1: begin
112117
raw_exp_y = raw_exp_fp16;
113118
exp_low_larger = 1'bx;

0 commit comments

Comments
 (0)