Commit cf3a695

Support the HuMo model. (Comfy-Org#9903)
1 parent 1b5cff7 commit cf3a695

6 files changed, +383 -6 lines changed


comfy/audio_encoders/audio_encoders.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def encode_audio(self, audio, sample_rate):
         outputs = {}
         outputs["encoded_audio"] = out
         outputs["encoded_audio_all_layers"] = all_layers
+        outputs["audio_samples"] = audio.shape[2]
         return outputs
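
The new audio_samples entry records the raw sample count of the input waveform (the last dimension of the audio tensor), so downstream nodes can recover the clip duration instead of inferring it from encoder frames. A minimal sketch of what the field holds, assuming a [batch, channels, samples] waveform and a hypothetical 16 kHz input:

import torch

sample_rate = 16000                          # hypothetical input sample rate
audio = torch.zeros(1, 1, 5 * sample_rate)   # [batch, channels, samples]: 5 seconds of silence

outputs = {}
outputs["audio_samples"] = audio.shape[2]    # 80000 raw samples, mirroring the new line above
print(outputs["audio_samples"] / sample_rate)  # 5.0 seconds
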
comfy/ldm/wan/model.py

Lines changed: 254 additions & 5 deletions
@@ -34,7 +34,9 @@ def __init__(self,
                  num_heads,
                  window_size=(-1, -1),
                  qk_norm=True,
-                 eps=1e-6, operation_settings={}):
+                 eps=1e-6,
+                 kv_dim=None,
+                 operation_settings={}):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -43,11 +45,13 @@ def __init__(self,
         self.window_size = window_size
         self.qk_norm = qk_norm
         self.eps = eps
+        if kv_dim is None:
+            kv_dim = dim

         # layers
         self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
         self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
@@ -402,6 +406,7 @@ def __init__(self,
                  eps=1e-6,
                  flf_pos_embed_token_number=None,
                  in_dim_ref_conv=None,
+                 wan_attn_block_class=WanAttentionBlock,
                  image_model=None,
                  device=None,
                  dtype=None,
@@ -479,8 +484,8 @@ def __init__(self,
         # blocks
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
-            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+            wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
+                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
             for _ in range(num_layers)
         ])
@@ -1325,3 +1330,247 @@ def block_wrap(args):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x
+
+
+class WanT2VCrossAttentionGather(WanSelfAttention):
+
+    def forward(self, x, context, transformer_options={}, **kwargs):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L1, C] - video tokens
+            context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
+        """
+        b, n, d = x.size(0), self.num_heads, self.head_dim
+
+        q = self.norm_q(self.q(x))
+        k = self.norm_k(self.k(context))
+        v = self.v(context)
+
+        # Handle audio temporal structure (16 tokens per frame)
+        k = k.reshape(-1, 16, n, d).transpose(1, 2)
+        v = v.reshape(-1, 16, n, d).transpose(1, 2)
+
+        # Handle video spatial structure
+        q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
+
+        x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
+
+        x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
+        x = self.o(x)
+        return x
+
+
+class AudioCrossAttentionWrapper(nn.Module):
+    def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
+        super().__init__()
+
+        self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings)
+        self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x, audio, transformer_options={}):
+        x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
+        return x
+
+
+class WanAttentionBlockAudio(WanAttentionBlock):
+
+    def __init__(self,
+                 cross_attn_type,
+                 dim,
+                 ffn_dim,
+                 num_heads,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=False,
+                 eps=1e-6, operation_settings={}):
+        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
+        self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
+
+    def forward(
+        self,
+        x,
+        e,
+        freqs,
+        context,
+        context_img_len=257,
+        audio=None,
+        transformer_options={},
+    ):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L, C]
+            e(Tensor): Shape [B, 6, C]
+            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+        """
+        # assert e.dtype == torch.float32
+
+        if e.ndim < 4:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+        # assert e[0].dtype == torch.float32
+
+        # self-attention
+        y = self.self_attn(
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+            freqs, transformer_options=transformer_options)
+
+        x = torch.addcmul(x, y, repeat_e(e[2], x))
+
+        # cross-attention & ffn
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+        if audio is not None:
+            x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))
+        return x
+
+
+class DummyAdapterLayer(nn.Module):
+    def __init__(self, layer):
+        super().__init__()
+        self.layer = layer
+
+    def forward(self, *args, **kwargs):
+        return self.layer(*args, **kwargs)
+
+
+class AudioProjModel(nn.Module):
+    def __init__(
+        self,
+        seq_len=5,
+        blocks=13,  # add a new parameter blocks
+        channels=768,  # add a new parameter channels
+        intermediate_dim=512,
+        output_dim=1536,
+        context_tokens=16,
+        device=None,
+        dtype=None,
+        operations=None,
+    ):
+        super().__init__()
+
+        self.seq_len = seq_len
+        self.blocks = blocks
+        self.channels = channels
+        self.input_dim = seq_len * blocks * channels  # update input_dim to be the product of blocks and channels.
+        self.intermediate_dim = intermediate_dim
+        self.context_tokens = context_tokens
+        self.output_dim = output_dim
+
+        # define multiple linear layers
+        self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
+        self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
+        self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
+
+        self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
+
+    def forward(self, audio_embeds):
+        video_length = audio_embeds.shape[1]
+        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
+        batch_size, window_size, blocks, channels = audio_embeds.shape
+        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
+
+        audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
+        audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
+
+        context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
+
+        context_tokens = self.audio_proj_glob_norm(context_tokens)
+        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
+
+        return context_tokens
+
+
+class HumoWanModel(WanModel):
+    r"""
+    Wan diffusion backbone supporting both text-to-video and image-to-video.
+    """
+
+    def __init__(self,
+                 model_type='humo',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 flf_pos_embed_token_number=None,
+                 image_model=None,
+                 audio_token_num=16,
+                 device=None,
+                 dtype=None,
+                 operations=None,
+                 ):
+
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+        self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
+
+    def forward_orig(
+        self,
+        x,
+        t,
+        context,
+        freqs=None,
+        audio_embed=None,
+        reference_latent=None,
+        transformer_options={},
+        **kwargs,
+    ):
+        bs, _, time, height, width = x.shape
+
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # time embeddings
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+        e = e.reshape(t.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+        if reference_latent is not None:
+            ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+            ref = ref.flatten(2).transpose(1, 2)
+            freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
+            x = torch.cat([x, ref], dim=1)
+            freqs = torch.cat([freqs, freqs_ref], dim=1)
+            del ref, freqs_ref
+
+        # context
+        context = self.text_embedding(context)
+        context_img_len = None
+
+        if audio_embed is not None:
+            audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
+        else:
+            audio = None
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
+
+        # head
+        x = self.head(x, e)
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
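
WanT2VCrossAttentionGather gives every latent frame its own private slice of 16 audio tokens: K and V come from the 1536-channel audio sequence (which is why WanSelfAttention gained the kv_dim argument above), both are reshaped so the frame index moves into the batch dimension, and the video tokens are regrouped the same way before attention. A rough standalone sketch of that reshaping with made-up sizes, using torch's scaled_dot_product_attention as a stand-in for ComfyUI's optimized_attention:

import torch
import torch.nn.functional as F

b, frames, video_tokens_per_frame = 1, 4, 880              # hypothetical latent layout
heads, head_dim = 12, 128                                   # so dim = 1536 in this sketch
audio_tokens_per_frame = 16
dim = heads * head_dim

q = torch.randn(b, frames * video_tokens_per_frame, dim)    # video tokens after the q projection, [B, L1, C]
kv = torch.randn(b, frames * audio_tokens_per_frame, dim)   # audio tokens after the k/v projections, [B, frames*16, C]

# audio: [B*frames, heads, 16, head_dim]
k = kv.reshape(-1, audio_tokens_per_frame, heads, head_dim).transpose(1, 2)
v = kv.reshape(-1, audio_tokens_per_frame, heads, head_dim).transpose(1, 2)

# video: [B*frames, heads, tokens_per_frame, head_dim], so each frame only attends to its own audio
q = q.reshape(k.shape[0], -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)                # per-frame cross-attention
out = out.transpose(1, 2).reshape(b, -1, dim)                # back to [B, L1, C]
print(out.shape)                                             # torch.Size([1, 3520, 1536])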

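AudioProjModel turns each frame's audio window (HumoWanModel configures it as 8 timesteps x 5 encoder blocks x 1280 channels) into 16 context tokens of width 1536, and forward_orig then permutes and flattens the per-frame tokens into the [B, frames*16, 1536] sequence the gather attention consumes. A shape-only sketch with plain nn.Linear standing in for the operations wrappers (final LayerNorm omitted):

import torch
import torch.nn as nn

bz, frames, window, blocks, channels = 1, 4, 8, 5, 1280
context_tokens, output_dim, intermediate_dim = 16, 1536, 512

proj = nn.Sequential(
    nn.Linear(window * blocks * channels, intermediate_dim), nn.ReLU(),
    nn.Linear(intermediate_dim, intermediate_dim), nn.ReLU(),
    nn.Linear(intermediate_dim, context_tokens * output_dim),
)

audio_embeds = torch.randn(bz, frames, window, blocks, channels)      # [bz, f, w, b, c]
flat = audio_embeds.reshape(bz * frames, window * blocks * channels)  # one flattened vector per frame
tokens = proj(flat).reshape(bz, frames, context_tokens, output_dim)   # [bz, f, 16, 1536]

# forward_orig: permute + flatten into the audio sequence handed to every block
audio = tokens.permute(0, 3, 1, 2).flatten(2).transpose(1, 2)         # [bz, f*16, 1536]
print(audio.shape)                                                    # torch.Size([1, 64, 1536])
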
comfy/model_base.py

Lines changed: 17 additions & 0 deletions
@@ -1213,6 +1213,23 @@ def extra_conds(self, **kwargs):
         out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
         return out

+class WAN21_HuMo(WAN21):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
+        self.image_to_video = image_to_video
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+
+        audio_embed = kwargs.get("audio_embed", None)
+        if audio_embed is not None:
+            out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
+
+        reference_latents = kwargs.get("reference_latents", None)
+        if reference_latents is not None:
+            out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+        return out
+
 class WAN22_S2V(WAN21):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
comfy/model_detection.py

Lines changed: 2 additions & 0 deletions
@@ -402,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["model_type"] = "camera_2.2"
         elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
             dit_config["model_type"] = "s2v"
+        elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
+            dit_config["model_type"] = "humo"
         else:
             if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                 dit_config["model_type"] = "i2v"

comfy/supported_models.py

Lines changed: 11 additions & 1 deletion
@@ -1073,6 +1073,16 @@ def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
         return out

+class WAN21_HuMo(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "humo",
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
+        return out
+
 class WAN22_S2V(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
@@ -1351,6 +1361,6 @@ def get_model(self, state_dict, prefix="", device=None):
         out = model_base.HunyuanImage21Refiner(self, device=device)
         return out

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]

 models += [SVD_img2vid]
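
Registration follows the existing pattern: WAN21_HuMo subclasses WAN21_T2V, declares the unet_config fields a detected config must carry ("image_model": "wan2.1", "model_type": "humo"), overrides get_model to build the WAN21_HuMo base model, and is appended to the models list so the loader can select it. A simplified sketch of the matching idea (illustrative only, not the actual supported_models_base implementation):

detected = {"image_model": "wan2.1", "model_type": "humo", "dim": 5120}  # hypothetical detected config

class WAN21_T2V:
    unet_config = {"image_model": "wan2.1", "model_type": "t2v"}

class WAN21_HuMo(WAN21_T2V):
    unet_config = {"image_model": "wan2.1", "model_type": "humo"}

def matches(entry, detected):
    # every field declared by the entry must be present and equal in the detected config
    return all(detected.get(k) == v for k, v in entry.unet_config.items())

models = [WAN21_T2V, WAN21_HuMo]
print(next(m.__name__ for m in models if matches(m, detected)))  # WAN21_HuMo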
