forked from ggml-org/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 18
Integrate TQ2_0 into Vulkan #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 10 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
af6603d
ggml-vulkan: Add TQ2_0 dequantize and mul_mat vec
makaveli10 87d471b
ggml-vulkan: Enable coopmat support for Android
makaveli10 9a7ba54
ggml-vulkan: Add mul_mm path for TQ2_0
makaveli10 aafd00f
Use the correct subgroup size for TQ2_0.
zoq 911d0d9
Add Vulkan TQ2_0 shader.
zoq 5f19b2a
SET_ROWS and GET_ROWS has no TQ2_0 support yet.
zoq 6651f60
Use the vector/matrix shader for larger matrix/vector computations.
zoq 01fe180
Link against "lc++" on Android, for exception handling symbols.
zoq 7b0b9af
Linking with c++_shared for Android/Termux compatibility.
zoq 9c941a8
Test TQ2_0 dequant + pipelines
787fcba
Make sure the output model can start with a number.
zoq f09743f
Linking against c++_shared is done automatically.
zoq cb8128c
Add support for microsoft/bitnet-b1.58-2B-4T (HF to GGUF).
zoq 6d0777e
Merge branch 'temp-latest' into vulkan_tq2_0_type
zoq File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| #version 450 | ||
|
|
||
| #extension GL_EXT_shader_16bit_storage : require | ||
|
|
||
| #include "types.comp" | ||
|
|
||
| layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; /* input: read-only array of A_TYPE; main() reads .d and .qs[] of each element */ | ||
| layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; /* output: write-only array of D_TYPE, one entry per dequantized element */ | ||
|
|
||
| layout (push_constant) uniform parameter { | ||
| uint ne; /* exclusive upper bound on the output index; main() only writes indices below ne */ | ||
| } p; | ||
|
|
||
| layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; /* 256 invocations per workgroup along x */ | ||
|
|
||
| void main() { | ||
| // Dequantizes TQ2_0 data: every invocation writes up to four consecutive output elements starting at i. | ||
| const uint i = gl_GlobalInvocationID.x * 4; | ||
|
|
||
| if (i >= p.ne) { | ||
| return; | ||
| } | ||
|
|
||
| const uint ib = i / QUANT_K; // block index | ||
| const uint r = i % QUANT_K; // position of the first element inside its block (always a multiple of 4) | ||
|
|
||
| // TQ2_0 does not pack four consecutive elements into one byte. Element e of a block lives in | ||
| // byte 32*(e/128) + (e%32) at bit offset 2*((e%128)/32), the same mapping used by the mul_mat_vec | ||
| // TQ2_0 shader and the CPU kernel ((qs[j + k] >> (l*2)) & 3). The constants assume QUANT_K == 256. | ||
| // Because r is a multiple of 4, the four elements handled here share one bit offset and sit in | ||
| // four consecutive bytes. | ||
| const uint byte_base = 32u * (r / 128u) + (r % 32u); | ||
| const uint shift = 2u * ((r % 128u) / 32u); | ||
|
|
||
| const float d = float(data_a[ib].d); // per-block scale | ||
|
|
||
| // The (i + j) < p.ne test guards the tail when ne is not a multiple of 4. | ||
| for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) { | ||
| const uint vui = uint(data_a[ib].qs[byte_base + j]); | ||
| const uint q = (vui >> shift) & 3; | ||
| data_b[i + j] = D_TYPE(d * (float(q) - 1.0f)); // stored value q in {0,1,2} maps to (q - 1) * d | ||
| } | ||
| } |
66 changes: 66 additions & 0 deletions
66
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| #version 450 | ||
| #extension GL_EXT_shader_explicit_arithmetic_types : require | ||
|
|
||
| #include "mul_mat_vec_base.comp" | ||
|
|
||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; /* x workgroup size is supplied through specialization constant 0 */ | ||
|
|
||
| FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; /* partial sums indexed [column][row]; zeroed and accumulated in compute_outputs(), then passed to reduce_result() */ | ||
|
|
||
| void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | ||
| // For rows [first_row, first_row + num_rows): multiply dequantized TQ2_0 weights from data_a with data_b, | ||
| // accumulate into temp, and pass the partial sums to reduce_result. | ||
| uint a_offset, b_offset, d_offset; | ||
| get_offsets(a_offset, b_offset, d_offset); | ||
|
|
||
| const uint blocks_per_row = p.ncols / QUANT_K; | ||
| const uint lane = gl_LocalInvocationID.x; | ||
|
|
||
| // Clear every accumulator before summing. | ||
| [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) { | ||
| [[unroll]] for (uint r = 0; r < NUM_ROWS; ++r) { | ||
| temp[c][r] = FLOAT_TYPE(0); | ||
| } | ||
| } | ||
|
|
||
| // This invocation handles blocks lane, lane + gl_WorkGroupSize.x, ... of each row. | ||
| [[unroll]] for (uint blk = lane; blk < blocks_per_row; blk += gl_WorkGroupSize.x) { | ||
|
|
||
| [[unroll]] for (uint r = 0; r < num_rows; ++r) { | ||
| const uint a_blk = a_offset / QUANT_K + (first_row + r) * blocks_per_row + blk; | ||
| const float scale = float(data_a[a_blk].d); | ||
|
|
||
| // Traversal order: two 32-byte groups, four 2-bit fields per byte, 32 bytes per field. | ||
| [[unroll]] for (uint grp = 0; grp < 64; grp += 32) { | ||
| [[unroll]] for (uint field = 0; field < 4; ++field) { | ||
| [[unroll]] for (uint m = 0; m < 32; ++m) { | ||
| // Quantized value: ((qs[grp + m] >> (field*2)) & 3) - 1, scaled by d | ||
| const uint packed = uint(data_a[a_blk].qs[grp + m]); | ||
| const uint q = (packed >> (field * 2)) & 3; | ||
| const FLOAT_TYPE w = FLOAT_TYPE(scale * (float(q) - 1.0f)); // CPU kernel: (q-1)*d | ||
|
|
||
| // Matching column of the B operand: block start + grp*4 + field*32 + m | ||
| const uint col = blk * QUANT_K + grp * 4 + field * 32 + m; | ||
| if (col < p.ncols) { | ||
| [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) { | ||
| temp[c][r] += w * FLOAT_TYPE(data_b[c * p.batch_stride_b + b_offset + col]); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| reduce_result(temp, d_offset, first_row, num_rows, lane); | ||
| } | ||
|
|
||
| void main() { | ||
| // First output row owned by this workgroup; each workgroup covers NUM_ROWS rows. | ||
| const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); | ||
|
|
||
| // Full tile: pass the compile-time row count. | ||
| if (first_row + NUM_ROWS <= p.stride_d) { | ||
| compute_outputs(first_row, NUM_ROWS); | ||
| return; | ||
| } | ||
|
|
||
| // Workgroup starts at or beyond p.stride_d: nothing to do. | ||
| if (first_row >= p.stride_d) { | ||
| return; | ||
| } | ||
|
|
||
| // Partial tile at the end: only the rows below p.stride_d. | ||
| compute_outputs(first_row, p.stride_d - first_row); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| #version 450 | ||
|
|
||
| #include "types.comp" | ||
| #include "generic_binary_head.comp" | ||
| #include "dequant_funcs.comp" | ||
|
|
||
| const uint num_threads = 256; /* workgroup size along x; main() also advances idx by this amount on every iteration */ | ||
| layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; | ||
|
|
||
| void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { | ||
| // Splits the flat index idx into four coordinates using the extents p.ne20, p.ne21 and p.ne22 (i20 varies fastest). | ||
| const uint slice_elems = p.ne21*p.ne20; | ||
| const uint cube_elems = p.ne22*slice_elems; | ||
|
|
||
| i23 = fastdiv(idx, cube_elems); | ||
| uint rest = idx - i23 * cube_elems; | ||
|
|
||
| i22 = fastdiv(rest, slice_elems); | ||
| rest -= i22 * slice_elems; | ||
|
|
||
| i21 = rest / p.ne20; | ||
| i20 = rest - i21*p.ne20; | ||
| } | ||
|
|
||
| void main() { | ||
| // Every invocation produces num_iter outputs; num_threads * num_iter has to be 512 so the split agrees with wg_denoms and get_idx | ||
| const uint num_iter = 2; | ||
|
|
||
| const uint broadcast2 = uint(p.param2); | ||
| const uint broadcast3 = p.param3; | ||
|
|
||
| uint out_idx = get_idx(); | ||
|
|
||
| [[unroll]] for (uint pass = 0; pass < num_iter; ++pass) { | ||
| if (out_idx < p.ne) { | ||
| uint i0, i1, i2, i3; | ||
| get_dst_indices(out_idx, i0, i1, i2, i3); | ||
|
|
||
| // Location of i0 inside its quant block; none of these values depend on k. | ||
| const uint blk_in_row = i0 / QUANT_K; | ||
| const uint r = i0 % QUANT_K; | ||
| const uint iqs = 32u * (r / 128u) + (r % 32u); | ||
| const uint lane = (r % 128u) / 32u; // picks one of the four components returned by dequantize4 | ||
|
|
||
| float sum = 0.0f; | ||
|
|
||
| for (uint k = 0; k < p.ne01; ++k) { | ||
| const uint ib = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01 + blk_in_row; | ||
|
|
||
| const vec4 quad = dequantize4(ib, iqs, 0); | ||
| const vec2 dm = get_dm(ib, 0); | ||
|
|
||
| float qv; | ||
| if (lane == 0u) { | ||
| qv = quad.x; | ||
| } else if (lane == 1u) { | ||
| qv = quad.y; | ||
| } else if (lane == 2u) { | ||
| qv = quad.z; | ||
| } else { | ||
| qv = quad.w; | ||
| } | ||
| const float a_val = qv * dm.x + dm.y; | ||
|
|
||
| const uint b_pos = src1_idx(i1, k, i2, i3); | ||
| const float b_val = data_b[get_boffset() + b_pos]; | ||
| sum += a_val * b_val; | ||
| } | ||
|
|
||
| const uint d_pos = dst_idx(i0, i1, i2, i3); | ||
| data_d[get_doffset() + d_pos] = sum; | ||
| } | ||
| out_idx += num_threads; | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.