
Commit c401d3e

Transformer Decoder: make cross attention config optional (#90)
This way you don't have to specify a dummy config when using the module as a pure LM without cross attention.
1 parent 12faf98 commit c401d3e

1 file changed: +9 −2 lines


i6_models/assemblies/transformer/transformer_decoder_v1.py

Lines changed: 9 additions & 2 deletions
@@ -47,7 +47,8 @@ class TransformerDecoderBlockV1Config(ModelConfiguration):
     Attributes:
         ff_cfg: Configuration for ConformerPositionwiseFeedForwardV1
         mhsa_cfg: Configuration for CausalSelfAttentionV1
-        cross_cfg: Configuration for CrossAttentionV1
+        cross_cfg: Configuration for CrossAttentionV1.
+            Can be set to None in case there is no cross attention block (e.g. for LM usage).
         modules: List of modules to use for ConformerBlockV2:
             - "ff" for feed forward module
             - "mhcsa" for multi-head causal self attention module
@@ -58,7 +59,7 @@ class TransformerDecoderBlockV1Config(ModelConfiguration):
 
     ff_cfg: ConformerPositionwiseFeedForwardV2Config
     mhsa_cfg: CausalSelfAttentionV1Config
-    cross_cfg: CrossAttentionV1Config
+    cross_cfg: Optional[CrossAttentionV1Config]
     modules: List[str] = field(default_factory=lambda: ["mhcsa", "cross", "ff"])
     scales: List[float] = field(default_factory=lambda: [1.0, 1.0, 1.0])
 
@@ -67,6 +68,9 @@ def __post__init__(self):
 
         assert len(self.modules) == len(self.scales), "modules and scales must have same length"
        assert all(name in ["ff", "mhcsa", "cross"] for name in self.modules), "module type not supported"
+        assert "cross" not in self.modules or self.cross_cfg is not None, (
+            "must specify cross attention config when enabling the cross attention module"
+        )
 
 
 class TransformerDecoderBlockV1State(TypedDict):
@@ -88,6 +92,9 @@ def __init__(self, cfg: TransformerDecoderBlockV1Config):
             elif module_name == "mhcsa":
                 modules.append(CausalSelfAttentionV1(cfg.mhsa_cfg))
             elif module_name == "cross":
+                assert cfg.cross_cfg is not None, (
+                    "must specify cross attention config when enabling the cross attention module"
+                )
                 modules.append(CrossAttentionV1(cfg.cross_cfg))
             else:
                 raise NotImplementedError
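
For illustration, a minimal sketch (not part of this commit) of how the relaxed config can be used for a pure LM block: "cross" is left out of modules and cross_cfg now stays None instead of a dummy CrossAttentionV1Config. The helper name make_lm_block_cfg and the direct import path i6_models.assemblies.transformer.transformer_decoder_v1 are assumptions for this example; the feed-forward and self-attention sub-configs are built exactly as before and just passed through.

from typing import Any

# Assumed import path, derived from the file location shown in this commit.
from i6_models.assemblies.transformer.transformer_decoder_v1 import TransformerDecoderBlockV1Config


def make_lm_block_cfg(ff_cfg: Any, mhsa_cfg: Any) -> TransformerDecoderBlockV1Config:
    """Hypothetical helper: decoder block config without cross attention, e.g. for LM training."""
    return TransformerDecoderBlockV1Config(
        ff_cfg=ff_cfg,            # a ConformerPositionwiseFeedForwardV2Config, built as before
        mhsa_cfg=mhsa_cfg,        # a CausalSelfAttentionV1Config, built as before
        cross_cfg=None,           # allowed now; previously a dummy config was required
        modules=["mhcsa", "ff"],  # no "cross" entry, so the new config-level assert is satisfied
        scales=[1.0, 1.0],        # must match len(modules)
    )

With a config like this, the decoder block only instantiates the causal self attention and feed forward modules, and the cross attention branch is never built.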
