
Commit caf4e29

JackTemaki and albertz authored
Add BlstmEncoder (#19)
Co-authored-by: Albert Zeyer <[email protected]>
1 parent ab4b708 commit caf4e29

4 files changed: +147 -0 lines changed


.github/workflows/model_tests.yml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ jobs:
       - run: |
           pip install pytest
           pip install -r requirements.txt
+          pip install -r requirements_dev.txt
       - name: Test Models
         run: |
           python -m pytest tests

i6_models/parts/blstm.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
from dataclasses import dataclass
import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class BlstmEncoderV1Config(ModelConfiguration):
    """
    Attributes:
        num_layers: number of bi-directional LSTM layers, minimum 2
        input_dim: input dimension size
        hidden_dim: hidden dimension of one direction of the LSTM, the total output size is twice this
        dropout: nn.LSTM supports internal dropout applied between each layer of the BLSTM (but not on input/output)
        enforce_sorted:
            True: expects that sequences are sorted by sequence length in decreasing order,
                and will not do any sorting.
                This is required for ONNX export, and thus the recommended setting.
            False: no expectation.
                It will internally enforce that the sequences are sorted
                and undo the reordering at the output.

            Sorting can for example be performed independently of the ONNX export, e.g. in train_step:

                audio_features_len, indices = torch.sort(audio_features_len, descending=True)
                audio_features = audio_features[indices, :, :]
                labels = labels[indices, :]
                labels_len = labels_len[indices]
    """

    num_layers: int
    input_dim: int
    hidden_dim: int
    dropout: float
    enforce_sorted: bool


class BlstmEncoderV1(torch.nn.Module):
    """
    Simple multi-layer BLSTM model including dropout, batch-first variant,
    hardcoded to use [B, T, F] input.

    Supports: TorchScript, ONNX export.
    """

    def __init__(self, config: BlstmEncoderV1Config):
        """
        :param config: configuration object
        """
        super().__init__()
        self.dropout = config.dropout
        self.enforce_sorted = config.enforce_sorted
        self.blstm_stack = nn.LSTM(
            input_size=config.input_dim,
            hidden_size=config.hidden_dim,
            bidirectional=True,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=self.dropout,
        )

    def forward(self, x: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
        """
        :param x: [B, T, input_dim]
        :param seq_len: [B], should be on CPU for Script/Trace mode
        :return: [B, T, 2 * hidden_dim]
        """
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            # in graph mode we have to assume all tensors are already on the correct device;
            # otherwise, move the lengths to the CPU if they are on the GPU
            if seq_len.get_device() >= 0:
                seq_len = seq_len.cpu()

        blstm_packed_in = nn.utils.rnn.pack_padded_sequence(
            input=x,
            lengths=seq_len,
            enforce_sorted=self.enforce_sorted,
            batch_first=True,
        )
        blstm_out, _ = self.blstm_stack(blstm_packed_in)
        blstm_out, _ = nn.utils.rnn.pad_packed_sequence(blstm_out, padding_value=0.0, batch_first=True)

        return blstm_out
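
For reference, a minimal usage sketch of the new encoder (not part of the commit; sizes and variable names are illustrative): build a config, sort the batch by decreasing length as enforce_sorted=True expects, following the train_step example from the docstring, and run a forward pass.

import torch
from i6_models.parts.blstm import BlstmEncoderV1, BlstmEncoderV1Config

# illustrative sizes, not taken from the commit
config = BlstmEncoderV1Config(num_layers=2, input_dim=40, hidden_dim=64, dropout=0.1, enforce_sorted=True)
encoder = BlstmEncoderV1(config=config)

audio_features = torch.randn(4, 25, 40)  # [B, T, input_dim]
audio_features_len = torch.IntTensor([25, 10, 20, 15])  # [B], not yet sorted

# enforce_sorted=True expects lengths in decreasing order,
# so sort the batch first, as in the docstring's train_step example
audio_features_len, indices = torch.sort(audio_features_len, descending=True)
audio_features = audio_features[indices, :, :]

out = encoder(audio_features, audio_features_len)
assert out.shape == (4, 25, 2 * config.hidden_dim)  # [B, T, 2 * hidden_dim]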

requirements_dev.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
onnx
onnxruntime

tests/test_blstm.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import onnxruntime as ort
import tempfile
import torch
from torch.onnx import export as export_onnx

from i6_models.parts.blstm import BlstmEncoderV1, BlstmEncoderV1Config


def test_blstm_onnx_export():
    with torch.no_grad(), tempfile.NamedTemporaryFile() as f:
        config = BlstmEncoderV1Config(num_layers=4, input_dim=50, hidden_dim=128, dropout=0.1, enforce_sorted=True)
        model = BlstmEncoderV1(config=config)
        scripted_model = torch.jit.optimize_for_inference(torch.jit.script(model.eval()))

        dummy_data = torch.randn(3, 30, 50)
        dummy_data_len = torch.IntTensor([30, 20, 15])
        dummy_data_len_2 = torch.IntTensor([30, 15, 10])

        outputs_normal = model(dummy_data, dummy_data_len)
        outputs_scripted = scripted_model(dummy_data, dummy_data_len)
        assert torch.allclose(outputs_normal, outputs_scripted)

        export_onnx(
            scripted_model,
            (dummy_data, dummy_data_len),
            f=f,
            verbose=True,
            input_names=["data", "data_len"],
            output_names=["classes"],
            dynamic_axes={
                # dict value: manually named axes
                "data": {0: "batch", 1: "time"},
                "data_len": {0: "batch"},
                "classes": {0: "batch", 1: "time"},
            },
        )

        session = ort.InferenceSession(f.name)
        outputs_onnx = torch.FloatTensor(
            session.run(None, {"data": dummy_data.numpy(), "data_len": dummy_data_len.numpy()})[0]
        )
        outputs_onnx_other = torch.FloatTensor(
            session.run(None, {"data": dummy_data.numpy(), "data_len": dummy_data_len_2.numpy()})[0]
        )
        # the default atol of 1e-8 was slightly too strict
        assert torch.allclose(outputs_normal, outputs_onnx, atol=1e-6)
        # check that for different lengths we really get a different result
        assert not torch.allclose(outputs_normal, outputs_onnx_other, atol=1e-6)

        # check with different batching and max size
        outputs_onnx_diff_batch = torch.FloatTensor(
            session.run(
                None,
                {
                    "data": dummy_data[(1, 2), :20, :].numpy(),
                    "data_len": dummy_data_len[(1, 2),].numpy(),
                },
            )[0]
        )
        assert torch.allclose(outputs_normal[2, :20], outputs_onnx_diff_batch[1], atol=1e-6)
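
The test above only exercises the enforce_sorted=True path that the export requires. For contrast, a small sketch (illustrative, not part of the commit) of the enforce_sorted=False variant from the config docstring, where pack_padded_sequence sorts internally and the reordering is undone at the output, so an unsorted batch can be fed directly in eager mode:

import torch
from i6_models.parts.blstm import BlstmEncoderV1, BlstmEncoderV1Config

# illustrative config, not taken from the commit
config = BlstmEncoderV1Config(num_layers=4, input_dim=50, hidden_dim=128, dropout=0.0, enforce_sorted=False)
model = BlstmEncoderV1(config=config).eval()

with torch.no_grad():
    data = torch.randn(3, 30, 50)
    lengths_unsorted = torch.IntTensor([20, 30, 15])  # deliberately not decreasing

    # no manual sorting needed here; the module handles it internally
    out = model(data, lengths_unsorted)
    assert out.shape == (3, 30, 2 * config.hidden_dim)

As the docstring notes, enforce_sorted=True remains the recommended setting for ONNX export.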
