 # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
 import os
-import torch
 import gguf
+import torch
+import logging
 import argparse
 from tqdm import tqdm
-
-from safetensors.torch import load_file
+from safetensors.torch import load_file, save_file

 QUANTIZATION_THRESHOLD = 1024
 REARRANGE_THRESHOLD = 512
 MAX_TENSOR_NAME_LENGTH = 127
+MAX_TENSOR_DIMS = 4

 class ModelTemplate:
     arch = "invalid" # string describing architecture
     shape_fix = False # whether to reshape tensors
     keys_detect = [] # list of lists to match in state dict
     keys_banned = [] # list of keys that should mark model as invalid for conversion
+    keys_hiprec = [] # list of keys that need to be kept in fp32 for some reason
+
+    def handle_nd_tensor(self, key, data):
+        raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")

 class ModelFlux(ModelTemplate):
     arch = "flux"
@@ -41,6 +46,51 @@ class ModelAura(ModelTemplate):
     ]
     keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight",]

+class ModelHiDream(ModelTemplate):
+    arch = "hidream"
+    keys_detect = [
+        (
+            "caption_projection.0.linear.weight",
+            "double_stream_blocks.0.block.ff_i.shared_experts.w3.weight"
+        )
+    ]
+    keys_hiprec = [
+        # nn.parameter, can't load from BF16 ver
+        ".ff_i.gate.weight",
+        "img_emb.emb_pos"
+    ]
+
+class ModelHyVid(ModelTemplate):
+    arch = "hyvid"
+    keys_detect = [
+        (
+            "double_blocks.0.img_attn_proj.weight",
+            "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
+        )
+    ]
+
+    def handle_nd_tensor(self, key, data):
+        # hacky but don't have any better ideas
+        path = f"./fix_5d_tensors_{self.arch}.safetensors" # TODO: somehow get a path here??
+        if os.path.isfile(path):
+            raise RuntimeError(f"5D tensor fix file already exists! {path}")
+        fsd = {key: torch.from_numpy(data)}
+        tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
+        save_file(fsd, path)
+
+class ModelWan(ModelHyVid):
+    arch = "wan"
+    keys_detect = [
+        (
+            "blocks.0.self_attn.norm_q.weight",
+            "text_embedding.2.weight",
+            "head.modulation",
+        )
+    ]
+    keys_hiprec = [
+        ".modulation" # nn.parameter, can't load from BF16 ver
+    ]
+
 class ModelLTXV(ModelTemplate):
     arch = "ltxv"
     keys_detect = [
@@ -50,6 +100,9 @@ class ModelLTXV(ModelTemplate):
50100 "caption_projection.linear_2.weight" ,
51101 )
52102 ]
103+ keys_hiprec = [
104+ "scale_shift_table" # nn.parameter, can't load from BF16 base quant
105+ ]
53106
54107class ModelSDXL (ModelTemplate ):
55108 arch = "sdxl"
@@ -75,7 +128,7 @@ class ModelSD1(ModelTemplate):
     ]

 # The architectures are checked in order and the first successful match terminates the search.
-arch_list = [ModelFlux, ModelSD3, ModelAura, ModelLTXV, ModelSDXL, ModelSD1]
+arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1]

 def is_model_arch(model, state_dict):
     # check if model is correct
@@ -93,7 +146,7 @@ def detect_arch(state_dict):
     model_arch = None
     for arch in arch_list:
         if is_model_arch(arch, state_dict):
-            model_arch = arch
+            model_arch = arch()
             break
     assert model_arch is not None, "Unknown model architecture!"
     return model_arch
@@ -109,13 +162,7 @@ def parse_args():

     return args

-def load_state_dict(path):
-    if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
-        state_dict = torch.load(path, map_location="cpu", weights_only=True)
-        state_dict = state_dict.get("model", state_dict)
-    else:
-        state_dict = load_file(path)
-
+def strip_prefix(state_dict):
     # only keep unet with no prefix!
     prefix = None
     for pfx in ["model.diffusion_model.", "model."]:
@@ -133,14 +180,21 @@ def load_state_dict(path):

     return sd

-def load_model(path):
-    state_dict = load_state_dict(path)
-    model_arch = detect_arch(state_dict)
-    print(f"* Architecture detected from input: {model_arch.arch}")
-    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
-    return (writer, state_dict, model_arch)
+def load_state_dict(path):
+    if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
+        state_dict = torch.load(path, map_location="cpu", weights_only=True)
+        for subkey in ["model", "module"]:
+            if subkey in state_dict:
+                state_dict = state_dict[subkey]
+                break
+        if len(state_dict) < 20:
+            raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}")
+    else:
+        state_dict = load_file(path)
+
+    return strip_prefix(state_dict)

-def handle_tensors(args, writer, state_dict, model_arch):
+def handle_tensors(writer, state_dict, model_arch):
     name_lengths = tuple(sorted(
         ((key, len(key)) for key in state_dict.keys()),
         key=lambda item: item[1],
@@ -165,28 +219,23 @@ def handle_tensors(args, writer, state_dict, model_arch):

         n_dims = len(data.shape)
         data_shape = data.shape
-        data_qtype = getattr(
-            gguf.GGMLQuantizationType,
-            "BF16" if old_dtype == torch.bfloat16 else "F16"
-        )
+        if old_dtype == torch.bfloat16:
+            data_qtype = gguf.GGMLQuantizationType.BF16
+        # elif old_dtype == torch.float32:
+        #     data_qtype = gguf.GGMLQuantizationType.F32
+        else:
+            data_qtype = gguf.GGMLQuantizationType.F16
+
+        # The max no. of dimensions that can be handled by the quantization code is 4
+        if len(data.shape) > MAX_TENSOR_DIMS:
+            model_arch.handle_nd_tensor(key, data)
+            continue # needs to be added back later

         # get number of parameters (AKA elements) in this tensor
         n_params = 1
         for dim_size in data_shape:
             n_params *= dim_size

-        # keys to keep as max precision
-        blacklist = {
-            "time_embedding.",
-            "add_embedding.",
-            "time_in.",
-            "txt_in.",
-            "vector_in.",
-            "img_in.",
-            "guidance_in.",
-            "final_layer.",
-        }
-
         if old_dtype in (torch.float32, torch.bfloat16):
             if n_dims == 1:
                 # one-dimensional tensors should be kept in F32
@@ -197,7 +246,8 @@ def handle_tensors(args, writer, state_dict, model_arch):
                 # very small tensors
                 data_qtype = gguf.GGMLQuantizationType.F32

-            elif ".weight" in key and any(x in key for x in blacklist):
+            elif any(x in key for x in model_arch.keys_hiprec):
+                # tensors that require max precision
                 data_qtype = gguf.GGMLQuantizationType.F32

         if (model_arch.shape_fix # NEVER reshape for models such as flux
@@ -224,25 +274,57 @@ def handle_tensors(args, writer, state_dict, model_arch):

         writer.add_tensor(new_name, data, raw_dtype=data_qtype)

-if __name__ == "__main__":
-    args = parse_args()
-    path = args.src
-    writer, state_dict, model_arch = load_model(path)
-
-    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-    if next(iter(state_dict.values())).dtype == torch.bfloat16:
-        out_path = f"{os.path.splitext(path)[0]}-BF16.gguf"
-        writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
+def convert_file(path, dst_path=None, interact=True, overwrite=False):
+    # load & run model detection logic
+    state_dict = load_state_dict(path)
+    model_arch = detect_arch(state_dict)
+    logging.info(f"* Architecture detected from input: {model_arch.arch}")
+
+    # detect & set dtype for output file
+    dtypes = [x.dtype for x in state_dict.values()]
+    dtypes = {x: dtypes.count(x) for x in set(dtypes)}
+    main_dtype = max(dtypes, key=dtypes.get)
+
+    if main_dtype == torch.bfloat16:
+        ftype_name = "BF16"
+        ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
+    # elif main_dtype == torch.float32:
+    #     ftype_name = "F32"
+    #     ftype_gguf = None
     else:
-        out_path = f"{os.path.splitext(path)[0]}-F16.gguf"
-        writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)
+        ftype_name = "F16"
+        ftype_gguf = gguf.LlamaFileType.MOSTLY_F16
+
+    if dst_path is None:
+        dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
+    elif "{ftype}" in dst_path: # lcpp logic
+        dst_path = dst_path.replace("{ftype}", ftype_name)
+
+    if os.path.isfile(dst_path) and not overwrite:
+        if interact:
+            input("Output exists enter to continue or ctrl+c to abort!")
+        else:
+            raise OSError("Output exists and overwriting is disabled!")

-    out_path = args.dst or out_path
-    if os.path.isfile(out_path):
-        input("Output exists enter to continue or ctrl+c to abort!")
+    # handle actual file
+    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
+    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+    if ftype_gguf is not None:
+        writer.add_file_type(ftype_gguf)

-    handle_tensors(path, writer, state_dict, model_arch)
-    writer.write_header_to_file(path=out_path)
+    handle_tensors(writer, state_dict, model_arch)
+    writer.write_header_to_file(path=dst_path)
     writer.write_kv_data_to_file()
     writer.write_tensors_to_file(progress=True)
     writer.close()
+
+    fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
+    if os.path.isfile(fix):
+        logging.warning(f"\n### Warning! Fix file found at '{fix}'")
+        logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")
+
+    return dst_path, model_arch
+
+if __name__ == "__main__":
+    args = parse_args()
+    convert_file(args.src, args.dst)
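
For reference, a minimal sketch of calling the new convert_file() entry point added above from another script. The import path and the file names here are illustrative assumptions, not part of this commit:

# minimal usage sketch; "convert" is an assumed module name for this script,
# and the checkpoint/output names are placeholders
from convert import convert_file

# "{ftype}" in dst_path is replaced with the detected output type (BF16 or F16),
# based on the dominant dtype in the source state dict
dst_path, model_arch = convert_file(
    "model-bf16.safetensors",      # source checkpoint (placeholder)
    dst_path="model-{ftype}.gguf",
    interact=False,                # raise OSError instead of prompting if the output already exists
)
print(f"Wrote {dst_path} for arch '{model_arch.arch}'")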