@@ -47,7 +47,8 @@ class TransformerDecoderBlockV1Config(ModelConfiguration):
     Attributes:
         ff_cfg: Configuration for ConformerPositionwiseFeedForwardV2
         mhsa_cfg: Configuration for CausalSelfAttentionV1
-        cross_cfg: Configuration for CrossAttentionV1
+        cross_cfg: Configuration for CrossAttentionV1.
+            Can be set to None in case there is no cross attention block (e.g. for LM usage).
         modules: List of modules to use for TransformerDecoderBlockV1:
             - "ff" for feed forward module
             - "mhcsa" for multi-head causal self attention module
@@ -58,7 +59,7 @@ class TransformerDecoderBlockV1Config(ModelConfiguration):
 
     ff_cfg: ConformerPositionwiseFeedForwardV2Config
     mhsa_cfg: CausalSelfAttentionV1Config
-    cross_cfg: CrossAttentionV1Config
+    cross_cfg: Optional[CrossAttentionV1Config]
     modules: List[str] = field(default_factory=lambda: ["mhcsa", "cross", "ff"])
     scales: List[float] = field(default_factory=lambda: [1.0, 1.0, 1.0])
 
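
With this change, a decoder-only setup (e.g. LM training) can omit the cross attention block entirely. A minimal sketch of how the updated config might be constructed; ff_cfg and mhsa_cfg are placeholder sub-config instances, since their fields are not part of this diff:

    # Sketch only: ff_cfg / mhsa_cfg stand in for fully specified sub-configs.
    lm_block_cfg = TransformerDecoderBlockV1Config(
        ff_cfg=ff_cfg,            # a ConformerPositionwiseFeedForwardV2Config
        mhsa_cfg=mhsa_cfg,        # a CausalSelfAttentionV1Config
        cross_cfg=None,           # no encoder output, hence no cross attention
        modules=["mhcsa", "ff"],  # "cross" omitted from the module list
        scales=[1.0, 1.0],        # must match len(modules)
    )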
@@ -67,6 +68,9 @@ def __post_init__(self):
 
         assert len(self.modules) == len(self.scales), "modules and scales must have same length"
         assert all(name in ["ff", "mhcsa", "cross"] for name in self.modules), "module type not supported"
+        assert "cross" not in self.modules or self.cross_cfg is not None, (
+            "must specify cross attention config when enabling the cross attention module"
+        )
 
 
 class TransformerDecoderBlockV1State(TypedDict):
@@ -88,6 +92,9 @@ def __init__(self, cfg: TransformerDecoderBlockV1Config):
             elif module_name == "mhcsa":
                 modules.append(CausalSelfAttentionV1(cfg.mhsa_cfg))
             elif module_name == "cross":
+                assert cfg.cross_cfg is not None, (
+                    "must specify cross attention config when enabling the cross attention module"
+                )
                 modules.append(CrossAttentionV1(cfg.cross_cfg))
             else:
                 raise NotImplementedError
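
For illustration, the added asserts make the misconfiguration fail fast: listing "cross" while leaving cross_cfg as None now raises during config validation rather than somewhere inside CrossAttentionV1. The repeated assert in __init__ presumably also narrows the Optional type so that CrossAttentionV1(cfg.cross_cfg) passes static type checking. A sketch, with the same placeholder sub-configs as above:

    # Sketch: expected to raise AssertionError, since "cross" is listed
    # but no cross attention config is given.
    bad_cfg = TransformerDecoderBlockV1Config(
        ff_cfg=ff_cfg,
        mhsa_cfg=mhsa_cfg,
        cross_cfg=None,
        modules=["mhcsa", "cross", "ff"],
        scales=[1.0, 1.0, 1.0],
    )
    # AssertionError: must specify cross attention config when enabling the cross attention module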