Skip to content

Commit 53c297d

Browse files
Better s2v memory estimation. (Comfy-Org#9584)
1 parent 924073d commit 53c297d

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

comfy/ldm/wan/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,7 @@ def forward_orig(
12781278
x = torch.cat([x, ref], dim=1)
12791279
freqs = torch.cat([freqs, freqs_ref], dim=1)
12801280
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
1281+
del ref, freqs_ref
12811282

12821283
if reference_motion is not None:
12831284
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
@@ -1287,6 +1288,7 @@ def forward_orig(
12871288

12881289
t = torch.repeat_interleave(t, 2, dim=1)
12891290
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
1291+
del motion_encoded, freqs_motion
12901292

12911293
# time embeddings
12921294
e = self.time_embedding(

comfy/model_base.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
150150
logging.debug("adm {}".format(self.adm_channels))
151151
self.memory_usage_factor = model_config.memory_usage_factor
152152
self.memory_usage_factor_conds = ()
153+
self.memory_usage_shape_process = {}
153154

154155
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
155156
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -350,8 +351,15 @@ def memory_required(self, input_shape, cond_shapes={}):
350351
input_shapes = [input_shape]
351352
for c in self.memory_usage_factor_conds:
352353
shape = cond_shapes.get(c, None)
353-
if shape is not None and len(shape) > 0:
354-
input_shapes += shape
354+
if shape is not None:
355+
if c in self.memory_usage_shape_process:
356+
out = []
357+
for s in shape:
358+
out.append(self.memory_usage_shape_process[c](s))
359+
shape = out
360+
361+
if len(shape) > 0:
362+
input_shapes += shape
355363

356364
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
357365
dtype = self.get_dtype()
@@ -1204,6 +1212,8 @@ def extra_conds(self, **kwargs):
12041212
class WAN22_S2V(WAN21):
12051213
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
12061214
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
1215+
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
1216+
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
12071217

12081218
def extra_conds(self, **kwargs):
12091219
out = super().extra_conds(**kwargs)
@@ -1224,6 +1234,17 @@ def extra_conds(self, **kwargs):
12241234
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
12251235
return out
12261236

1237+
def extra_conds_shapes(self, **kwargs):
1238+
out = {}
1239+
ref_latents = kwargs.get("reference_latents", None)
1240+
if ref_latents is not None:
1241+
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
1242+
1243+
reference_motion = kwargs.get("reference_motion", None)
1244+
if reference_motion is not None:
1245+
out['reference_motion'] = reference_motion.shape
1246+
return out
1247+
12271248
class WAN22(BaseModel):
12281249
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
12291250
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

0 commit comments

Comments (0)