
Commit bddba68

Merge pull request #48 from infinitalo/vineet/vulkan_tq2_0
Add TQ2_0 model support to Vulkan backend
2 parents 9f133bb + 8ebd79e commit bddba68

File tree

5 files changed: +89 -17 lines

convert_hf_to_gguf.py

Lines changed: 60 additions & 3 deletions
@@ -2641,18 +2641,47 @@ def prepare_tensors(self):
         super().prepare_tensors()
 
 
-@ModelBase.register("BitnetForCausalLM")
+@ModelBase.register("BitnetForCausalLM", "BitNetForCausalLM")
 class BitnetModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BITNET
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._bitnet_weight_scales: dict[str, torch.Tensor] = {}
+
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        else:
+            self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
+    @staticmethod
+    def _unpack_bitnet_weights(packed: torch.Tensor) -> torch.Tensor:
+        if packed.dtype != torch.uint8:
+            raise ValueError(f"Expected packed BitNet weights to be torch.uint8, got {packed.dtype}")
+
+        values_per_item = 4
+        rows = packed.shape[0]
+        rest = packed.shape[1:]
+
+        unpacked_chunks: list[torch.Tensor] = []
+        mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0], dtype=torch.float32, device=packed.device)
+
+        for i in range(values_per_item):
+            chunk = (packed >> (2 * i)) & 0x03
+            chunk = mapping[chunk.long()].reshape((rows, *rest))
+            unpacked_chunks.append(chunk)
+
+        if not unpacked_chunks:
+            raise ValueError("Failed to unpack BitNet weights: no chunks produced")
+
+        return torch.cat(unpacked_chunks, dim=0)
+
     def weight_quant(self, weight: Tensor) -> Tensor:
         dtype = weight.dtype
         weight = weight.float()
@@ -2665,8 +2694,36 @@ def weight_quant(self, weight: Tensor) -> Tensor:
         return result.type(dtype)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith(".weight_scale"):
+            weight_name = name[:-13] + ".weight"
+            mapped_weight_name = self.map_tensor_name(weight_name)
+            if isinstance(data_torch, LazyTorchTensor):
+                data_torch = LazyTorchTensor.to_eager(data_torch)
+
+            scale_tensor = data_torch.to(torch.float32)
+            self._bitnet_weight_scales[mapped_weight_name] = scale_tensor
+            return []
+
         new_name = self.map_tensor_name(name)
 
+        ternary_weight = False
+
+        if name.endswith(".weight"):
+            scale_tensor = self._bitnet_weight_scales.pop(new_name, None)
+            if scale_tensor is not None:
+                scale_tensor = scale_tensor.to(torch.float32)
+                if scale_tensor.numel() != 1:
+                    raise ValueError(f"Expected scalar weight_scale for '{name}', got shape {tuple(scale_tensor.shape)}")
+
+                if isinstance(data_torch, LazyTorchTensor):
+                    data_torch = LazyTorchTensor.to_eager(data_torch)
+
+                packed = data_torch.to(torch.uint8)
+                unpacked = self._unpack_bitnet_weights(packed)
+                scale_value = scale_tensor.reshape(-1)[0].item()
+                data_torch = unpacked * scale_value
+                ternary_weight = True
+
         if any(self.match_model_tensor_name(new_name, key, bid) for key in [
             gguf.MODEL_TENSOR.ATTN_Q,
             gguf.MODEL_TENSOR.ATTN_K,
@@ -2675,7 +2732,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
             gguf.MODEL_TENSOR.FFN_UP,
             gguf.MODEL_TENSOR.FFN_DOWN,
             gguf.MODEL_TENSOR.FFN_GATE,
-        ]):
+        ]) and not ternary_weight:
             # transform weight into 1/0/-1 (in fp32)
             data_torch = self.weight_quant(data_torch)
 
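A note on the packing being undone above: _unpack_bitnet_weights expects each uint8 to carry four 2-bit codes, stacked along dim 0 in four row-blocks and decoded through the mapping [-1, 0, 1, 0] before the scalar weight_scale is applied. Below is a minimal round-trip sketch of that scheme in plain PyTorch; pack_ternary is a hypothetical helper written only to exercise the decode path (real checkpoints arrive already packed) and is not part of the converter.

import torch

# Hypothetical inverse of _unpack_bitnet_weights, for illustration only.
# Encoding inferred from the mapping tensor in the diff:
#   -1 -> 0b00, 0 -> 0b01, +1 -> 0b10 (0b11 unused)
def pack_ternary(w: torch.Tensor) -> torch.Tensor:
    assert w.shape[0] % 4 == 0
    rows = w.shape[0] // 4
    codes = (w + 1).to(torch.uint8)  # -1/0/+1 -> 0/1/2
    packed = torch.zeros((rows, *w.shape[1:]), dtype=torch.uint8)
    for i in range(4):
        # row-block i of the weight lands in bit pair i of each byte
        packed |= codes[i * rows:(i + 1) * rows] << (2 * i)
    return packed

# Round-trip: decode exactly as the converter does and compare.
w = torch.randint(-1, 2, (8, 6)).float()
packed = pack_ternary(w)
mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0])
chunks = [mapping[((packed >> (2 * i)) & 3).long()] for i in range(4)]
assert torch.equal(torch.cat(chunks, dim=0), w)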

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 12 additions & 1 deletion
@@ -216,7 +216,7 @@ class vk_memory_logger;
 class vk_perf_logger;
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
-static constexpr uint32_t mul_mat_vec_max_cols = 16;
+static constexpr uint32_t mul_mat_vec_max_cols = 8;
 static constexpr uint32_t p021_max_gqa_ratio = 8;
 
 enum vk_device_architecture {
@@ -2587,6 +2587,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_TQ2_0], matmul_tq2_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
@@ -2616,6 +2617,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_TQ2_0], matmul_id_subgroup_tq2_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -2677,6 +2679,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
         CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2699,6 +2702,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2733,6 +2737,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_TQ2_0], matmul_id_subgroup_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2838,6 +2843,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+        CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_TQ2_0], matmul_id_subgroup_tq2_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
@@ -2864,6 +2870,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+        CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_TQ2_0], matmul_id_tq2_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
@@ -2919,6 +2926,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+        CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
 
         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -2957,6 +2965,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+        CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_TQ2_0].f32acc, matmul_id_subgroup_tq2_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
         CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
@@ -2983,6 +2992,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+        CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_TQ2_0].f32acc, matmul_id_tq2_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
@@ -3090,6 +3100,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f16_f32", arr_dmmv_tq2_0_f16_f32_len[reduc], arr_dmmv_tq2_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
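The C++ changes follow one pattern: each pipeline table that already carries a Q8_0 entry (the matmul, matmul-id, and mul_mat_vec f16/f32 variants, across the subgroup and non-subgroup paths) gains a parallel TQ2_0 entry, so the new type rides the existing dispatch machinery; the only tuning tweak is lowering mul_mat_vec_max_cols from 16 to 8. For a sense of why TQ2_0 earns the plumbing, a quick size computation, assuming ggml's TQ2_0 block layout of 256 weights (QK_K) per block with a trailing fp16 scale:

# TQ2_0 size arithmetic; QK_K = 256 block layout assumed from ggml.
QK_K = 256
qs_bytes = QK_K // 4           # four 2-bit ternary codes per byte -> 64 bytes
block_bytes = qs_bytes + 2     # plus one fp16 scale d per block
print(8 * block_bytes / QK_K)  # 2.0625 bits per weight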

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 6 additions & 7 deletions
@@ -663,13 +663,12 @@ float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords
 {
     const float16_t d = bl.block.d;
     const uint idx = coordInBlock[1];
-    const uint iqs = idx / 4;
-    const uint iqs_offset = idx % 4;
-    const uint vui = uint(bl.block.qs[iqs]);
-    const uint c = (vui >> (2 * iqs_offset)) & 3;
-    const float q = float(c) - 1.0f;
-    float16_t ret = d * float16_t(q);
-    return ret;
+
+    const uint byte_idx = ((idx >> 7) << 5) + (idx & 31u);
+    const uint qsshift = (((idx & 127u) >> 5) << 1);
+
+    const uint c = (uint(bl.block.qs[byte_idx]) >> qsshift) & 3u;
+    return d * float16_t(float(c) - 1.0f);
 }
 #endif
 
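The shader change replaces the old linear in-block addressing (byte idx / 4, shift 2 * (idx % 4)) with TQ2_0's actual layout, in which a 256-element block is processed in two 128-element halves and the four elements sharing a byte sit 32 positions apart. The exhaustive check below, a standalone Python sketch, confirms the new byte_idx/qsshift math maps the 256 in-block indices onto all 64 bytes times 4 bit pairs without collision:

# Standalone sanity check of the new TQ2_0 in-block index math.
seen = set()
for idx in range(256):
    byte_idx = ((idx >> 7) << 5) + (idx & 31)  # 0..63
    qsshift = ((idx & 127) >> 5) << 1          # 0, 2, 4 or 6
    assert 0 <= byte_idx < 64 and qsshift in (0, 2, 4, 6)
    seen.add((byte_idx, qsshift))
assert len(seen) == 256  # every (byte, bit pair) slot hit exactly once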
