@@ -150,6 +150,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
150150 logging .debug ("adm {}" .format (self .adm_channels ))
151151 self .memory_usage_factor = model_config .memory_usage_factor
152152 self .memory_usage_factor_conds = ()
153+ self .memory_usage_shape_process = {}
153154
154155 def apply_model (self , x , t , c_concat = None , c_crossattn = None , control = None , transformer_options = {}, ** kwargs ):
155156 return comfy .patcher_extension .WrapperExecutor .new_class_executor (
@@ -350,8 +351,15 @@ def memory_required(self, input_shape, cond_shapes={}):
350351 input_shapes = [input_shape ]
351352 for c in self .memory_usage_factor_conds :
352353 shape = cond_shapes .get (c , None )
353- if shape is not None and len (shape ) > 0 :
354- input_shapes += shape
354+ if shape is not None :
355+ if c in self .memory_usage_shape_process :
356+ out = []
357+ for s in shape :
358+ out .append (self .memory_usage_shape_process [c ](s ))
359+ shape = out
360+
361+ if len (shape ) > 0 :
362+ input_shapes += shape
355363
356364 if comfy .model_management .xformers_enabled () or comfy .model_management .pytorch_attention_flash_attention ():
357365 dtype = self .get_dtype ()
@@ -1204,6 +1212,8 @@ def extra_conds(self, **kwargs):
12041212class WAN22_S2V (WAN21 ):
12051213 def __init__ (self , model_config , model_type = ModelType .FLOW , device = None ):
12061214 super (WAN21 , self ).__init__ (model_config , model_type , device = device , unet_model = comfy .ldm .wan .model .WanModel_S2V )
1215+ self .memory_usage_factor_conds = ("reference_latent" , "reference_motion" )
1216+ self .memory_usage_shape_process = {"reference_motion" : lambda shape : [shape [0 ], shape [1 ], 1.5 , shape [- 2 ], shape [- 1 ]]}
12071217
12081218 def extra_conds (self , ** kwargs ):
12091219 out = super ().extra_conds (** kwargs )
@@ -1224,6 +1234,17 @@ def extra_conds(self, **kwargs):
12241234 out ['control_video' ] = comfy .conds .CONDRegular (self .process_latent_in (control_video ))
12251235 return out
12261236
1237+ def extra_conds_shapes (self , ** kwargs ):
1238+ out = {}
1239+ ref_latents = kwargs .get ("reference_latents" , None )
1240+ if ref_latents is not None :
1241+ out ['reference_latent' ] = list ([1 , 16 , sum (map (lambda a : math .prod (a .size ()), ref_latents )) // 16 ])
1242+
1243+ reference_motion = kwargs .get ("reference_motion" , None )
1244+ if reference_motion is not None :
1245+ out ['reference_motion' ] = reference_motion .shape
1246+ return out
1247+
12271248class WAN22 (BaseModel ):
12281249 def __init__ (self , model_config , model_type = ModelType .FLOW , image_to_video = False , device = None ):
12291250 super ().__init__ (model_config , model_type , device = device , unet_model = comfy .ldm .wan .model .WanModel )
0 commit comments