
Commit 9cdff5c

hidream update
1 parent 888ef67 commit 9cdff5c

File tree

3 files changed: +135 -53 lines changed


EasyQuantizationGUI.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def main():
     quantize_label = tk.Label(quantize_frame, text="Quantize Level:")
     quantize_label.pack(side=tk.LEFT)
 
-    quantize_levels = ["Q2_K", "Q3_K_S", "Q4_0", "Q4_1", "Q4_K_S", "Q5_0", "Q5_1", "Q5_K_S", "Q6_K", "Q8_0", "F16"]
+    quantize_levels = ["Q2_K", "Q2_K_S", "Q3_K", "Q3_K_L", "Q3_K_M", "Q3_K_S", "Q4_0", "Q4_1", "Q4_K", "Q4_K_M", "Q4_K_S", "Q5_0", "Q5_1", "Q5_K", "Q5_K_M", "Q5_K_S", "Q6_K", "Q8_0", "F16", "BF16", "F32"]
     quantize_level_var = tk.StringVar(root)
     quantize_level_var.set("Q8_0") # Set default value to Q8_0
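The expanded quantize_levels list now matches the full set of k-quant and float type names accepted by llama-quantize. A minimal sketch of how this list is presumably consumed by the dropdown, assuming the surrounding tkinter wiring (the OptionMenu call itself is outside this hunk):

import tkinter as tk

root = tk.Tk()
quantize_levels = ["Q2_K", "Q4_K_M", "Q8_0", "F16", "BF16", "F32"]  # abridged
quantize_level_var = tk.StringVar(root)
quantize_level_var.set("Q8_0")  # default, as in the diff
# assumed widget wiring; only the list and the StringVar appear in the hunk
tk.OptionMenu(root, quantize_level_var, *quantize_levels).pack(side=tk.LEFT)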
convert.py

Lines changed: 134 additions & 52 deletions
@@ -1,21 +1,26 @@
 # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
 import os
-import torch
 import gguf
+import torch
+import logging
 import argparse
 from tqdm import tqdm
-
-from safetensors.torch import load_file
+from safetensors.torch import load_file, save_file
 
 QUANTIZATION_THRESHOLD = 1024
 REARRANGE_THRESHOLD = 512
 MAX_TENSOR_NAME_LENGTH = 127
+MAX_TENSOR_DIMS = 4
 
 class ModelTemplate:
     arch = "invalid" # string describing architecture
     shape_fix = False # whether to reshape tensors
     keys_detect = [] # list of lists to match in state dict
     keys_banned = [] # list of keys that should mark model as invalid for conversion
+    keys_hiprec = [] # list of keys that need to be kept in fp32 for some reason
+
+    def handle_nd_tensor(self, key, data):
+        raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")
 
 class ModelFlux(ModelTemplate):
     arch = "flux"
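With keys_hiprec and a default handle_nd_tensor on the template, per-architecture behavior becomes declarative. A hypothetical subclass, purely for illustration (these names are not from this commit):

class ModelExample(ModelTemplate):
    arch = "example"
    keys_detect = [
        # every key within one tuple must be present for a match
        ("blocks.0.attn.qkv.weight", "final_norm.weight"),
    ]
    keys_hiprec = [
        ".scale_table",  # substring-matched; such tensors stay in F32
    ]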
@@ -41,6 +46,51 @@ class ModelAura(ModelTemplate):
     ]
     keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight",]
 
+class ModelHiDream(ModelTemplate):
+    arch = "hidream"
+    keys_detect = [
+        (
+            "caption_projection.0.linear.weight",
+            "double_stream_blocks.0.block.ff_i.shared_experts.w3.weight"
+        )
+    ]
+    keys_hiprec = [
+        # nn.parameter, can't load from BF16 ver
+        ".ff_i.gate.weight",
+        "img_emb.emb_pos"
+    ]
+
+class ModelHyVid(ModelTemplate):
+    arch = "hyvid"
+    keys_detect = [
+        (
+            "double_blocks.0.img_attn_proj.weight",
+            "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
+        )
+    ]
+
+    def handle_nd_tensor(self, key, data):
+        # hacky but don't have any better ideas
+        path = f"./fix_5d_tensors_{self.arch}.safetensors" # TODO: somehow get a path here??
+        if os.path.isfile(path):
+            raise RuntimeError(f"5D tensor fix file already exists! {path}")
+        fsd = {key: torch.from_numpy(data)}
+        tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
+        save_file(fsd, path)
+
+class ModelWan(ModelHyVid):
+    arch = "wan"
+    keys_detect = [
+        (
+            "blocks.0.self_attn.norm_q.weight",
+            "text_embedding.2.weight",
+            "head.modulation",
+        )
+    ]
+    keys_hiprec = [
+        ".modulation" # nn.parameter, can't load from BF16 ver
+    ]
+
 class ModelLTXV(ModelTemplate):
     arch = "ltxv"
     keys_detect = [
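To see what the ModelHyVid.handle_nd_tensor override does, a minimal round-trip sketch (the key name is hypothetical; data arrives as a numpy array, hence the torch.from_numpy above):

import numpy as np
from safetensors.torch import load_file

arch = ModelHyVid()
data = np.zeros((2, 2, 2, 2, 2), dtype=np.float16)  # 5 dims > MAX_TENSOR_DIMS
arch.handle_nd_tensor("some.5d.weight", data)       # parks tensor in a sidecar file
parked = load_file(f"./fix_5d_tensors_{arch.arch}.safetensors")
assert parked["some.5d.weight"].shape == (2, 2, 2, 2, 2)

This sidecar is the file the fix_5d_tensors.py warning at the end of convert_file points at.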
@@ -50,6 +100,9 @@ class ModelLTXV(ModelTemplate):
             "caption_projection.linear_2.weight",
         )
     ]
+    keys_hiprec = [
+        "scale_shift_table" # nn.parameter, can't load from BF16 base quant
+    ]
 
 class ModelSDXL(ModelTemplate):
     arch = "sdxl"
@@ -75,7 +128,7 @@ class ModelSD1(ModelTemplate):
     ]
 
 # The architectures are checked in order and the first successful match terminates the search.
-arch_list = [ModelFlux, ModelSD3, ModelAura, ModelLTXV, ModelSDXL, ModelSD1]
+arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1]
 
 def is_model_arch(model, state_dict):
     # check if model is correct
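The body of is_model_arch is untouched by this commit and not shown in the diff; judging from the fields it reads, it presumably amounts to something like this sketch:

def is_model_arch_sketch(model, state_dict):
    # a model matches when every key in any one keys_detect tuple is present
    matched = any(
        all(key in state_dict for key in match_list)
        for match_list in model.keys_detect
    )
    # keys_banned flags lookalike checkpoints that must not be converted
    banned = any(key in state_dict for key in model.keys_banned)
    return matched and not banned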
@@ -93,7 +146,7 @@ def detect_arch(state_dict):
     model_arch = None
     for arch in arch_list:
         if is_model_arch(arch, state_dict):
-            model_arch = arch
+            model_arch = arch()
             break
     assert model_arch is not None, "Unknown model architecture!"
     return model_arch
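Returning arch() instead of the class matters because handle_nd_tensor is an instance method: handle_tensors can call model_arch.handle_nd_tensor(key, data) on the detection result and get the per-architecture override (e.g. ModelHyVid's safetensors dump) without extra plumbing.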
@@ -109,13 +162,7 @@ def parse_args():
 
     return args
 
-def load_state_dict(path):
-    if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
-        state_dict = torch.load(path, map_location="cpu", weights_only=True)
-        state_dict = state_dict.get("model", state_dict)
-    else:
-        state_dict = load_file(path)
-
+def strip_prefix(state_dict):
     # only keep unet with no prefix!
     prefix = None
     for pfx in ["model.diffusion_model.", "model."]:
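The remainder of strip_prefix is pre-existing context outside this hunk. Its effect, sketched with a hypothetical key (and assuming keys pass through untouched when no prefix is found):

import torch

sd = {"model.diffusion_model.input_blocks.0.weight": torch.zeros(4)}
sd = strip_prefix(sd)
# -> {"input_blocks.0.weight": tensor([0., 0., 0., 0.])}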
@@ -133,14 +180,21 @@
 
     return sd
 
-def load_model(path):
-    state_dict = load_state_dict(path)
-    model_arch = detect_arch(state_dict)
-    print(f"* Architecture detected from input: {model_arch.arch}")
-    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
-    return (writer, state_dict, model_arch)
+def load_state_dict(path):
+    if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
+        state_dict = torch.load(path, map_location="cpu", weights_only=True)
+        for subkey in ["model", "module"]:
+            if subkey in state_dict:
+                state_dict = state_dict[subkey]
+                break
+        if len(state_dict) < 20:
+            raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}")
+    else:
+        state_dict = load_file(path)
+
+    return strip_prefix(state_dict)
 
-def handle_tensors(args, writer, state_dict, model_arch):
+def handle_tensors(writer, state_dict, model_arch):
     name_lengths = tuple(sorted(
         ((key, len(key)) for key in state_dict.keys()),
         key=lambda item: item[1],
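A quick sketch of the new .pt branch above (file and key names invented; note the sanity check demands at least 20 tensors after unwrapping the subkey):

import torch

ckpt = {"module": {f"blocks.{i}.weight": torch.zeros(1) for i in range(25)}}
torch.save(ckpt, "demo.pt")
sd = load_state_dict("demo.pt")  # unwraps "module", then strip_prefix()
assert "blocks.0.weight" in sd   # no UNet prefix, so keys pass through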
@@ -165,28 +219,23 @@
 
         n_dims = len(data.shape)
         data_shape = data.shape
-        data_qtype = getattr(
-            gguf.GGMLQuantizationType,
-            "BF16" if old_dtype == torch.bfloat16 else "F16"
-        )
+        if old_dtype == torch.bfloat16:
+            data_qtype = gguf.GGMLQuantizationType.BF16
+        # elif old_dtype == torch.float32:
+        #     data_qtype = gguf.GGMLQuantizationType.F32
+        else:
+            data_qtype = gguf.GGMLQuantizationType.F16
+
+        # The max no. of dimensions that can be handled by the quantization code is 4
+        if len(data.shape) > MAX_TENSOR_DIMS:
+            model_arch.handle_nd_tensor(key, data)
+            continue # needs to be added back later
 
         # get number of parameters (AKA elements) in this tensor
         n_params = 1
         for dim_size in data_shape:
             n_params *= dim_size
 
-        # keys to keep as max precision
-        blacklist = {
-            "time_embedding.",
-            "add_embedding.",
-            "time_in.",
-            "txt_in.",
-            "vector_in.",
-            "img_in.",
-            "guidance_in.",
-            "final_layer.",
-        }
-
         if old_dtype in (torch.float32, torch.bfloat16):
             if n_dims == 1:
                 # one-dimensional tensors should be kept in F32
@@ -197,7 +246,8 @@
                 # very small tensors
                 data_qtype = gguf.GGMLQuantizationType.F32
 
-            elif ".weight" in key and any(x in key for x in blacklist):
+            elif any(x in key for x in model_arch.keys_hiprec):
+                # tensors that require max precision
                 data_qtype = gguf.GGMLQuantizationType.F32
 
         if (model_arch.shape_fix # NEVER reshape for models such as flux
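Note the hiprec test replacing the old hardcoded blacklist is a plain substring match against the full tensor name, so e.g. ModelWan's single ".modulation" entry covers every per-block modulation parameter:

key = "blocks.7.modulation"                  # hypothetical Wan key
any(x in key for x in ModelWan.keys_hiprec)  # True -> kept as F32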
@@ -224,25 +274,57 @@
 
         writer.add_tensor(new_name, data, raw_dtype=data_qtype)
 
-if __name__ == "__main__":
-    args = parse_args()
-    path = args.src
-    writer, state_dict, model_arch = load_model(path)
-
-    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-    if next(iter(state_dict.values())).dtype == torch.bfloat16:
-        out_path = f"{os.path.splitext(path)[0]}-BF16.gguf"
-        writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
+def convert_file(path, dst_path=None, interact=True, overwrite=False):
+    # load & run model detection logic
+    state_dict = load_state_dict(path)
+    model_arch = detect_arch(state_dict)
+    logging.info(f"* Architecture detected from input: {model_arch.arch}")
+
+    # detect & set dtype for output file
+    dtypes = [x.dtype for x in state_dict.values()]
+    dtypes = {x:dtypes.count(x) for x in set(dtypes)}
+    main_dtype = max(dtypes, key=dtypes.get)
+
+    if main_dtype == torch.bfloat16:
+        ftype_name = "BF16"
+        ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
+    # elif main_dtype == torch.float32:
+    #     ftype_name = "F32"
+    #     ftype_gguf = None
     else:
-        out_path = f"{os.path.splitext(path)[0]}-F16.gguf"
-        writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)
+        ftype_name = "F16"
+        ftype_gguf = gguf.LlamaFileType.MOSTLY_F16
+
+    if dst_path is None:
+        dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
+    elif "{ftype}" in dst_path: # lcpp logic
+        dst_path = dst_path.replace("{ftype}", ftype_name)
+
+    if os.path.isfile(dst_path) and not overwrite:
+        if interact:
+            input("Output exists enter to continue or ctrl+c to abort!")
+        else:
+            raise OSError("Output exists and overwriting is disabled!")
 
-    out_path = args.dst or out_path
-    if os.path.isfile(out_path):
-        input("Output exists enter to continue or ctrl+c to abort!")
+    # handle actual file
+    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
+    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+    if ftype_gguf is not None:
+        writer.add_file_type(ftype_gguf)
 
-    handle_tensors(path, writer, state_dict, model_arch)
-    writer.write_header_to_file(path=out_path)
+    handle_tensors(writer, state_dict, model_arch)
+    writer.write_header_to_file(path=dst_path)
     writer.write_kv_data_to_file()
     writer.write_tensors_to_file(progress=True)
     writer.close()
+
+    fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
+    if os.path.isfile(fix):
+        logging.warning(f"\n### Warning! Fix file found at '{fix}'")
+        logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")
+
+    return dst_path, model_arch
+
+if __name__ == "__main__":
+    args = parse_args()
+    convert_file(args.src, args.dst)
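With this refactor convert.py is importable as a library as well as runnable as a script. A usage sketch of convert_file based on the signature above (the source path is hypothetical):

from convert import convert_file

dst_path, model_arch = convert_file(
    "hidream-dev.safetensors",        # hypothetical input checkpoint
    dst_path="hidream-{ftype}.gguf",  # "{ftype}" expands to BF16 or F16
    interact=False,                   # raise instead of prompting on existing output
)
print(dst_path, model_arch.arch)

The __main__ block keeps the old CLI behavior, passing args.src and args.dst from parse_args.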

llama-quantize.exe

5.09 MB (binary file not shown)
