@@ -14,7 +14,7 @@ uint csel = 0;
1414void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
1515 const uint y_idx = i * QUANT_K + y_offset;
1616
17- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
17+ for (uint n = 0; n < num_rows; ++n) {
1818 const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
1919 csel ^= 1;
2020
@@ -27,15 +27,39 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
2727 continue;
2828 }
2929
30- const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
31- const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
30+ #if 0
31+ const uint32_t ql0_u32 =
32+ uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) |
33+ (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
34+ const uint32_t ql32_u32 =
35+ uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) |
36+ (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
37+ const uint32_t qh_u32 =
38+ uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) |
39+ (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
40+ #else
41+ const uint32_t ql0_u32 =
42+ uint32_t(data_a[ib0 + i].ql[ql_offset]) |
43+ (uint32_t(data_a[ib0 + i].ql[ql_offset + 1]) << 8) |
44+ (uint32_t(data_a[ib0 + i].ql[ql_offset + 2]) << 16) |
45+ (uint32_t(data_a[ib0 + i].ql[ql_offset + 3]) << 24);
46+ const uint32_t ql32_u32 =
47+ uint32_t(data_a[ib0 + i].ql[ql_offset + 32]) |
48+ (uint32_t(data_a[ib0 + i].ql[ql_offset + 33]) << 8) |
49+ (uint32_t(data_a[ib0 + i].ql[ql_offset + 34]) << 16) |
50+ (uint32_t(data_a[ib0 + i].ql[ql_offset + 35]) << 24);
51+ const uint32_t qh_u32 =
52+ uint32_t(data_a[ib0 + i].qh[qh_offset + 0]) |
53+ (uint32_t(data_a[ib0 + i].qh[qh_offset + 1]) << 8) |
54+ (uint32_t(data_a[ib0 + i].qh[qh_offset + 2]) << 16) |
55+ (uint32_t(data_a[ib0 + i].qh[qh_offset + 3]) << 24);
56+ #endif
3257
3358 const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
3459 const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
3560 const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
3661 const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
3762
38- const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
3963 const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
4064 const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
4165 const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
@@ -46,10 +70,17 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
4670 const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
4771 const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
4872
73+ #if 0
4974 const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
5075 const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
5176 const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
5277 const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
78+ #else
79+ const vec4 q0 = vec4(float(q0_u32 & 0xFF), float((q0_u32 >> 8) & 0xFF), float((q0_u32 >> 16) & 0xFF), float(q0_u32 >> 24)) - 32;
80+ const vec4 q1 = vec4(float(q1_u32 & 0xFF), float((q1_u32 >> 8) & 0xFF), float((q1_u32 >> 16) & 0xFF), float(q1_u32 >> 24)) - 32;
81+ const vec4 q2 = vec4(float(q2_u32 & 0xFF), float((q2_u32 >> 8) & 0xFF), float((q2_u32 >> 16) & 0xFF), float(q2_u32 >> 24)) - 32;
82+ const vec4 q3 = vec4(float(q3_u32 & 0xFF), float((q3_u32 >> 8) & 0xFF), float((q3_u32 >> 16) & 0xFF), float(q3_u32 >> 24)) - 32;
83+ #endif
5384
5485 if (all_threads) {
5586 sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
@@ -58,14 +89,38 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
5889
5990 const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
6091
61- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
92+ for (uint j = 0; j < NUM_COLS; ++j) {
93+
94+ #if 0
6295 vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
6396 vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
6497 vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
6598 vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
99+ #else
100+ vec4 by0 =
101+ vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 0],
102+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 1],
103+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 2],
104+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 3]);
105+ vec4 by32 =
106+ vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8],
107+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8 + 1],
108+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8 + 2],
109+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8 + 3]);
110+ vec4 by64 =
111+ vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16],
112+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16 + 1],
113+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16 + 2],
114+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16 + 3]);
115+ vec4 by96 =
116+ vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24],
117+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24 + 1],
118+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24 + 2],
119+ data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24 + 3]);
120+ #endif
66121
67122 FLOAT_TYPE sum[4] = {0, 0, 0, 0};
68- [[unroll]] for (uint l = 0; l < 4; ++l) {
123+ for (uint l = 0; l < 4; ++l) {
69124 sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
70125 sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
71126 sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
@@ -99,16 +154,16 @@ void compute_outputs(const uint first_row, const uint num_rows) {
99154 const uint s_offset = 8*v_im + is;
100155 const uint y_offset = 128*v_im + l0;
101156
102- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
103- [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
157+ for (uint j = 0; j < NUM_COLS; ++j) {
158+ for (uint i = 0; i < NUM_ROWS; ++i) {
104159 temp[j][i] = FLOAT_TYPE(0);
105160 }
106161 }
107162
108163 const uint nbr_par_th = num_blocks_per_row%it_size;
109164 const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
110165 uint i0 = 0;
111- [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
166+ for (; i0 < nbr_all_th; i0 += it_size)
112167 calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
113168 calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
114169
0 commit comments