@@ -113,10 +113,10 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
113113 self .vllm_config = get_current_vllm_config ()
114114 self .post_init ()
115115
116- def pre_load (self ) :
117- tp_rank = get_tensor_model_parallel_rank ()
118- state_dict = get_full_state_dict ( self . od_config . model )
119- non_layer_prefixes = [
116+ def load_weights (self , weights : Iterable [ tuple [ str , torch . Tensor ]]) -> set [ str ] :
117+ skip_prefixes = [ "lm_head." ] if self . hf_config . tie_word_embeddings else []
118+ # Names of non-model-layer submodules; each one found in named_modules() is moved to this TP rank's device below
119+ non_model_layer_prefixes = [
120120 "vae" ,
121121 "vision_model" ,
122122 "vision_aligner" ,
@@ -129,32 +129,15 @@ def pre_load(self):
129129 "time_embed_2" ,
130130 "final_layer.model" ,
131131 ]
132- filtered_sd = {}
133- for k , v in state_dict .items ():
134- if any (k .startswith (prefix ) for prefix in non_layer_prefixes ):
135- filtered_sd [k ] = v
136-
137- missing , unexpected = self .load_state_dict (filtered_sd , strict = False )
138-
139- for prefix in non_layer_prefixes :
140- if hasattr (self , prefix .split ("." )[0 ]):
141- module = dict (self .named_modules ()).get (prefix )
142- if module :
143- module .to (f"{ self .model .device .type } :{ tp_rank } " )
132+ tp_rank = get_tensor_model_parallel_rank ()
133+ device_str = f"{ self .model .device .type } :{ tp_rank } "
134+ named_modules = dict (self .named_modules ())
135+ for prefix in non_model_layer_prefixes :
136+ mod = named_modules .get (prefix )
137+ if mod :
138+ mod .to (device_str )
144139
145- def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]) -> set [str ]:
146- self .pre_load ()
147- skip_prefixes = ["lm_head." ] if self .hf_config .tie_word_embeddings else []
148- # List of unexpected keywords in weight names
149140 unexpected_keywords = [
150- "vae" ,
151- "vision_aligner" ,
152- "vision_model" ,
153- "final_layer" ,
154- "patch_embed" ,
155- "timestep_emb" ,
156- "time_embed" ,
157- "time_embed_2" ,
158141 "guidance_emb" ,
159142 "timestep_r_emb" ,
160143 ]
@@ -892,7 +875,7 @@ def forward_call(
892875
893876 custom_pos_emb = self .get_pos_emb (custom_pos_emb , position_ids )
894877
895- inputs_embeds = self .model .wte (input_ids )
878+ inputs_embeds = self .model .embed_tokens (input_ids )
896879
897880 bsz , seq_len , n_embd = inputs_embeds .shape
898881
0 commit comments