Skip to content

Commit cb8128c

Browse files
committed
Add support for microsoft/bitnet-b1.58-2B-4T (HF to GGUF).
Signed-off-by: Marcus Edel <[email protected]>
1 parent f09743f commit cb8128c

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

convert_hf_to_gguf.py

Lines changed: 60 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2641,18 +2641,47 @@ def prepare_tensors(self):
26412641
super().prepare_tensors()
26422642

26432643

2644-
@ModelBase.register("BitnetForCausalLM")
2644+
@ModelBase.register("BitnetForCausalLM", "BitNetForCausalLM")
26452645
class BitnetModel(TextModel):
26462646
model_arch = gguf.MODEL_ARCH.BITNET
26472647

2648+
def __init__(self, *args, **kwargs):
    """Initialise the converter and its cache of pending weight scales.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)
    # Scales collected from ``*.weight_scale`` tensors, keyed by the mapped
    # name of the weight they belong to; ``modify_tensors`` pops an entry
    # when the matching weight is processed.
    pending_scales: dict[str, torch.Tensor] = {}
    self._bitnet_weight_scales = pending_scales
2651+
26482652
def set_vocab(self):
    """Load the vocabulary using the loader that matches the files on disk.

    Calls ``_set_vocab_sentencepiece()`` when ``tokenizer.model`` exists in
    the model directory, and ``_set_vocab_gpt2()`` otherwise.
    """
    sentencepiece_file = self.dir_model / "tokenizer.model"
    if not sentencepiece_file.is_file():
        self._set_vocab_gpt2()
        return
    self._set_vocab_sentencepiece()
26502657

26512658
def set_gguf_parameters(self):
    """Write the parent's GGUF parameters, then the RoPE scaling entries.

    After delegating to the parent implementation, records a linear RoPE
    scaling type with a scaling factor of 1.0.
    """
    super().set_gguf_parameters()
    writer = self.gguf_writer
    writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
    writer.add_rope_scaling_factor(1.0)
26552662

2663+
@staticmethod
def _unpack_bitnet_weights(packed: torch.Tensor) -> torch.Tensor:
    """Expand 2-bit packed BitNet weights into float32 ternary values.

    Every uint8 element carries four 2-bit codes, lowest bit pair first.
    The codes read from bit pair ``i`` of all elements form chunk ``i``;
    the four chunks are concatenated along dimension 0, so the result has
    four times as many rows as ``packed``.

    Codes decode as 0 -> -1.0, 1 -> 0.0, 2 -> 1.0 and 3 -> 0.0.

    Args:
        packed: uint8 tensor with at least one dimension.

    Returns:
        A float32 tensor of shape
        ``(4 * packed.shape[0], *packed.shape[1:])`` on the same device
        as ``packed``.

    Raises:
        ValueError: If ``packed`` is not ``torch.uint8`` or is 0-dimensional.
    """
    if packed.dtype != torch.uint8:
        raise ValueError(f"Expected packed BitNet weights to be torch.uint8, got {packed.dtype}")
    if packed.ndim == 0:
        # Fail with a clear message instead of an IndexError / cat failure.
        raise ValueError("Expected packed BitNet weights to have at least one dimension")

    values_per_item = 4  # 8 bits per byte / 2 bits per code
    # Lookup table indexed by the 2-bit code.
    mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0], dtype=torch.float32, device=packed.device)

    # Indexing ``mapping`` with an index tensor already yields a result of
    # the index tensor's shape, so each chunk has the shape of ``packed``.
    unpacked_chunks = [
        mapping[((packed >> (2 * i)) & 0x03).long()]
        for i in range(values_per_item)
    ]
    return torch.cat(unpacked_chunks, dim=0)
2684+
26562685
def weight_quant(self, weight: Tensor) -> Tensor:
26572686
dtype = weight.dtype
26582687
weight = weight.float()
@@ -2665,8 +2694,36 @@ def weight_quant(self, weight: Tensor) -> Tensor:
26652694
return result.type(dtype)
26662695

26672696
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
2697+
if name.endswith(".weight_scale"):
2698+
weight_name = name[:-13] + ".weight"
2699+
mapped_weight_name = self.map_tensor_name(weight_name)
2700+
if isinstance(data_torch, LazyTorchTensor):
2701+
data_torch = LazyTorchTensor.to_eager(data_torch)
2702+
2703+
scale_tensor = data_torch.to(torch.float32)
2704+
self._bitnet_weight_scales[mapped_weight_name] = scale_tensor
2705+
return []
2706+
26682707
new_name = self.map_tensor_name(name)
26692708

2709+
ternary_weight = False
2710+
2711+
if name.endswith(".weight"):
2712+
scale_tensor = self._bitnet_weight_scales.pop(new_name, None)
2713+
if scale_tensor is not None:
2714+
scale_tensor = scale_tensor.to(torch.float32)
2715+
if scale_tensor.numel() != 1:
2716+
raise ValueError(f"Expected scalar weight_scale for '{name}', got shape {tuple(scale_tensor.shape)}")
2717+
2718+
if isinstance(data_torch, LazyTorchTensor):
2719+
data_torch = LazyTorchTensor.to_eager(data_torch)
2720+
2721+
packed = data_torch.to(torch.uint8)
2722+
unpacked = self._unpack_bitnet_weights(packed)
2723+
scale_value = scale_tensor.reshape(-1)[0].item()
2724+
data_torch = unpacked * scale_value
2725+
ternary_weight = True
2726+
26702727
if any(self.match_model_tensor_name(new_name, key, bid) for key in [
26712728
gguf.MODEL_TENSOR.ATTN_Q,
26722729
gguf.MODEL_TENSOR.ATTN_K,
@@ -2675,7 +2732,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
26752732
gguf.MODEL_TENSOR.FFN_UP,
26762733
gguf.MODEL_TENSOR.FFN_DOWN,
26772734
gguf.MODEL_TENSOR.FFN_GATE,
2678-
]):
2735+
]) and not ternary_weight:
26792736
# transform weight into 1/0/-1 (in fp32)
26802737
data_torch = self.weight_quant(data_torch)
26812738

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -898,10 +898,12 @@ class TensorNameMap:
898898

899899
MODEL_TENSOR.ATTN_SUB_NORM: (
900900
"model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
901+
"model.layers.{bid}.self_attn.attn_sub_norm", # microsoft-bitnet
901902
),
902903

903904
MODEL_TENSOR.FFN_SUB_NORM: (
904905
"model.layers.{bid}.mlp.ffn_layernorm", # bitnet
906+
"model.layers.{bid}.mlp.ffn_sub_norm", # microsoft-bitnet
905907
),
906908

907909
MODEL_TENSOR.DEC_ATTN_NORM: (

0 commit comments

Comments
 (0)