@@ -34,7 +34,9 @@ def __init__(self,
                  num_heads,
                  window_size=(-1, -1),
                  qk_norm=True,
-                 eps=1e-6, operation_settings={}):
+                 eps=1e-6,
+                 kv_dim=None,
+                 operation_settings={}):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -43,11 +45,13 @@ def __init__(self,
         self.window_size = window_size
         self.qk_norm = qk_norm
         self.eps = eps
+        if kv_dim is None:
+            kv_dim = dim
 
         # layers
         self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
         self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
@@ -402,6 +406,7 @@ def __init__(self,
                  eps=1e-6,
                  flf_pos_embed_token_number=None,
                  in_dim_ref_conv=None,
+                 wan_attn_block_class=WanAttentionBlock,
                  image_model=None,
                  device=None,
                  dtype=None,
@@ -479,8 +484,8 @@ def __init__(self,
         # blocks
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
-            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+            wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
+                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
             for _ in range(num_layers)
         ])
 
@@ -1325,3 +1330,255 @@ def block_wrap(args):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x
+
+
+class WanT2VCrossAttentionGather(WanSelfAttention):
+
+    def forward(self, x, context, transformer_options={}, **kwargs):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L1, C] - video tokens
+            context(Tensor): Shape [B, L2, C] - audio tokens, where L2 = frames * 16 and C = 1536
+        """
+        b, n, d = x.size(0), self.num_heads, self.head_dim
+
+        q = self.norm_q(self.q(x))
+        k = self.norm_k(self.k(context))
+        v = self.v(context)
+
+        # Handle audio temporal structure (16 tokens per frame)
+        k = k.reshape(-1, 16, n, d).transpose(1, 2)
+        v = v.reshape(-1, 16, n, d).transpose(1, 2)
+
+        # Handle video spatial structure
+        q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
+
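+        # q: [B*frames, heads, tokens_per_frame, d]; k, v: [B*frames, heads, 16, d],
+        # so each frame's video tokens attend only to that frame's 16 audio tokens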
+        x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
+
+        x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
+        x = self.o(x)
+        return x
+
+
+class AudioCrossAttentionWrapper(nn.Module):
+    def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
+        super().__init__()
+
+        self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
+        self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x, audio, transformer_options={}):
+        x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
+        return x
+
+
+class WanAttentionBlockAudio(WanAttentionBlock):
+
+    def __init__(self,
+                 cross_attn_type,
+                 dim,
+                 ffn_dim,
+                 num_heads,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=False,
+                 eps=1e-6, operation_settings={}):
+        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
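+        # kv_dim=1536 matches AudioProjModel.output_dim, the width of the projected audio tokens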
+        self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
+
+    def forward(
+        self,
+        x,
+        e,
+        freqs,
+        context,
+        context_img_len=257,
+        audio=None,
+        transformer_options={},
+    ):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L, C]
+            e(Tensor): Shape [B, 6, C]
+            freqs(Tensor): RoPE freqs, shape [1024, C / num_heads / 2]
+        """
+        # assert e.dtype == torch.float32
+
+        if e.ndim < 4:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+        # assert e[0].dtype == torch.float32
+
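+        # e[0:3] are the shift/scale/gate for self-attention, e[3:6] the shift/scale/gate for the FFN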
+        # self-attention
+        y = self.self_attn(
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+            freqs, transformer_options=transformer_options)
+
+        x = torch.addcmul(x, y, repeat_e(e[2], x))
+
+        # cross-attention & ffn
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+        if audio is not None:
+            x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))
+        return x
+
+class DummyAdapterLayer(nn.Module):
+    def __init__(self, layer):
+        super().__init__()
+        self.layer = layer
+
+    def forward(self, *args, **kwargs):
+        return self.layer(*args, **kwargs)
+
+
+class AudioProjModel(nn.Module):
+    def __init__(
+        self,
+        seq_len=5,
+        blocks=13,  # number of feature blocks per audio window
+        channels=768,  # channels per feature block
+        intermediate_dim=512,
+        output_dim=1536,
+        context_tokens=16,
+        device=None,
+        dtype=None,
+        operations=None,
+    ):
+        super().__init__()
+
+        self.seq_len = seq_len
+        self.blocks = blocks
+        self.channels = channels
+        self.input_dim = seq_len * blocks * channels  # the whole audio window is flattened into a single vector
+        self.intermediate_dim = intermediate_dim
+        self.context_tokens = context_tokens
+        self.output_dim = output_dim
+
+        # define multiple linear layers
+        self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
+        self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
+        self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
+
+        self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
+
+    def forward(self, audio_embeds):
+        video_length = audio_embeds.shape[1]
+        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
+        batch_size, window_size, blocks, channels = audio_embeds.shape
+        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
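+        # -> (bz*f, window_size * blocks * channels): one flat audio feature vector per frame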
+
+        audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
+        audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
+
+        context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
+
+        context_tokens = self.audio_proj_glob_norm(context_tokens)
+        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
+
+        return context_tokens
+
+
+class HumoWanModel(WanModel):
+    r"""
+    Wan diffusion backbone with audio cross-attention for HuMo (audio-conditioned text-to-video).
+    """
+
+    def __init__(self,
+                 model_type='humo',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 flf_pos_embed_token_number=None,
+                 image_model=None,
+                 audio_token_num=16,
+                 device=None,
+                 dtype=None,
+                 operations=None,
+                 ):
+
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+        self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
+
+    def forward_orig(
+        self,
+        x,
+        t,
+        context,
+        freqs=None,
+        audio_embed=None,
+        reference_latent=None,
+        transformer_options={},
+        **kwargs,
+    ):
+        bs, _, time, height, width = x.shape
+
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # time embeddings
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+        e = e.reshape(t.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
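+        # reference frames are appended as extra tokens, with RoPE time positions starting
+        # after the video (t_start=time) so they do not collide with real frame positions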
+        if reference_latent is not None:
+            ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+            ref = ref.flatten(2).transpose(1, 2)
+            freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
+            x = torch.cat([x, ref], dim=1)
+            freqs = torch.cat([freqs, freqs_ref], dim=1)
+            del ref, freqs_ref
+
+        # context
+        context = self.text_embedding(context)
+        context_img_len = None
+
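+        # project audio features to 16 tokens per frame: [B, frames, 16, 1536] -> [B, frames*16, 1536]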
+        if audio_embed is not None:
+            audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
+        else:
+            audio = None
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
+
+        # head
+        x = self.head(x, e)
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
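
A quick shape sanity check for the new audio path (a minimal sketch, not part of the patch: _Ops is a hypothetical stand-in for the operations object ComfyUI normally injects, exposing the same Linear/LayerNorm constructors; the hyperparameters mirror what HumoWanModel passes to AudioProjModel above, and einops.rearrange is already imported by this module):

    import torch
    import torch.nn as nn

    class _Ops:  # hypothetical stand-in; real callers pass comfy ops with the same API
        Linear = nn.Linear
        LayerNorm = nn.LayerNorm

    proj = AudioProjModel(seq_len=8, blocks=5, channels=1280,
                          intermediate_dim=512, output_dim=1536,
                          context_tokens=16, operations=_Ops)

    audio_embeds = torch.randn(1, 24, 8, 5, 1280)  # [batch, frames, window, blocks, channels]
    tokens = proj(audio_embeds)
    print(tokens.shape)  # torch.Size([1, 24, 16, 1536]): 16 tokens of width 1536 per frame

forward_orig then flattens this to [1, 24*16, 1536], which is exactly the context layout WanT2VCrossAttentionGather expects.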