Skip to content

Commit 1635e56

Browse files
authored
Add eval and parallel dataset (#4651)
1 parent c8c45fd commit 1635e56

File tree

4 files changed

+411
-136
lines changed

4 files changed

+411
-136
lines changed

research/deep_speech/data/dataset.py

Lines changed: 101 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import functools
21+
import multiprocessing
22+
2023
import numpy as np
2124
import scipy.io.wavfile as wavfile
2225
from six.moves import xrange # pylint: disable=redefined-builtin
2326
import tensorflow as tf
2427

25-
# pylint: disable=g-bad-import-order
26-
from data.featurizer import AudioFeaturizer
27-
from data.featurizer import TextFeaturizer
28+
import data.featurizer as featurizer # pylint: disable=g-bad-import-order
2829

2930

3031
class AudioConfig(object):
@@ -44,7 +45,7 @@ def __init__(self,
4445
frame_length: an integer for the length of a spectrogram frame, in ms.
4546
frame_step: an integer for the frame stride, in ms.
4647
fft_length: an integer for the number of fft bins.
47-
normalize: a boolean for whether apply normalization on the audio tensor.
48+
normalize: a boolean for whether to apply normalization on the audio feature.
4849
spect_type: a string for the type of spectrogram to be extracted.
4950
"""
5051

@@ -78,90 +79,122 @@ def __init__(self, audio_config, data_path, vocab_file_path):
7879
self.vocab_file_path = vocab_file_path
7980

8081

82+
def _normalize_audio_feature(audio_feature):
  """Standardize a spectrogram feature along its first axis.

  The mean and variance are computed over axis 0 and applied to every row.

  Args:
    audio_feature: a numpy array for the spectrogram feature.

  Returns:
    a numpy array of the normalized spectrogram.
  """
  centered = audio_feature - np.mean(audio_feature, axis=0)
  # The small constant keeps the division finite for constant columns.
  scale = np.sqrt(np.var(audio_feature, axis=0)) + 1e-6
  return centered / scale
96+
97+
98+
def _preprocess_audio(
    audio_file_path, audio_sample_rate, audio_featurizer, normalize):
  """Load an audio file into memory and compute its spectrogram feature.

  Args:
    audio_file_path: a string for the path of the wav file to read.
    audio_sample_rate: an integer for the sample rate the file must have.
    audio_featurizer: an object whose frame_length, frame_step and fft_length
      attributes are forwarded to the spectrogram computation.
    normalize: a boolean for whether to apply mean and variance normalization
      on the computed feature.

  Returns:
    the spectrogram feature returned by
    featurizer.compute_spectrogram_feature, normalized when requested.

  Raises:
    ValueError: if the sample rate stored in the file differs from
      audio_sample_rate.
  """
  tf.logging.info(
      "Extracting spectrogram feature for {}".format(audio_file_path))
  sample_rate, data = wavfile.read(audio_file_path)
  # Validate explicitly: an assert statement is removed when Python runs
  # with -O, which would let mismatched audio through silently.
  if sample_rate != audio_sample_rate:
    raise ValueError(
        "Sample rate of {} is {}, expected {}".format(
            audio_file_path, sample_rate, audio_sample_rate))
  # Integer PCM samples are scaled by the maximum value of their dtype.
  if data.dtype not in [np.float32, np.float64]:
    data = data.astype(np.float32) / np.iinfo(data.dtype).max
  feature = featurizer.compute_spectrogram_feature(
      data, audio_featurizer.frame_length, audio_featurizer.frame_step,
      audio_featurizer.fft_length)
  if normalize:
    feature = _normalize_audio_feature(feature)
  return feature
113+
114+
115+
def _preprocess_transcript(transcript, token_to_index):
  """Turn a transcript string into its label feature.

  Args:
    transcript: a string for the text of one example.
    token_to_index: the mapping from character to its index.

  Returns:
    the value returned by featurizer.compute_label_feature.
  """
  label_feature = featurizer.compute_label_feature(transcript, token_to_index)
  return label_feature
118+
119+
120+
def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
  """Generate the lists of audio features and transcript labels.

  Each dataset file contains three tab-separated columns: "wav_filename",
  "wav_filesize", and "transcript". This function parses the file and orders
  the examples by increasing audio length (indicated by wav_filesize).
  As the waveforms are ordered in increasing length, audio samples in a
  mini-batch have similar length.

  Args:
    dataset_config: an instance of DatasetConfig.
    audio_featurizer: an instance of AudioFeaturizer.
    token_to_index: the mapping from character to its index.

  Returns:
    a (features, labels) tuple of lists processed from the audio/text input,
    both in the sorted example order.
  """
  file_path = dataset_config.data_path
  sample_rate = dataset_config.audio_config.sample_rate
  normalize = dataset_config.audio_config.normalize

  with tf.gfile.Open(file_path, "r") as f:
    lines = f.read().splitlines()
  # Skip the csv header, then split each row into its columns.
  lines = [line.split("\t") for line in lines[1:]]
  # Sort input data by the length of waveform.
  lines.sort(key=lambda item: int(item[1]))

  # Use multiprocessing for feature/label extraction.
  pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
  try:
    features = pool.map(
        functools.partial(
            _preprocess_audio, audio_sample_rate=sample_rate,
            audio_featurizer=audio_featurizer, normalize=normalize),
        [line[0] for line in lines])
    labels = pool.map(
        functools.partial(
            _preprocess_transcript, token_to_index=token_to_index),
        [line[2] for line in lines])
  finally:
    # Always release the worker processes, even when a worker raised, so a
    # failed extraction does not leave a pool of processes behind.
    pool.terminate()
    pool.join()
  return features, labels
166+
167+
81168
class DeepSpeechDataset(object):
82169
"""Dataset class for training/evaluation of DeepSpeech model."""
83170

84171
def __init__(self, dataset_config):
85-
"""Initialize the class.
86-
87-
Each dataset file contains three columns: "wav_filename", "wav_filesize",
88-
and "transcript". This function parses the csv file and stores each example
89-
by the increasing order of audio length (indicated by wav_filesize).
172+
"""Initialize the DeepSpeechDataset class.
90173
91174
Args:
92175
dataset_config: DatasetConfig object.
93176
"""
94177
self.config = dataset_config
95178
# Instantiate audio feature extractor.
96-
self.audio_featurizer = AudioFeaturizer(
179+
self.audio_featurizer = featurizer.AudioFeaturizer(
97180
sample_rate=self.config.audio_config.sample_rate,
98181
frame_length=self.config.audio_config.frame_length,
99182
frame_step=self.config.audio_config.frame_step,
100-
fft_length=self.config.audio_config.fft_length,
101-
spect_type=self.config.audio_config.spect_type)
183+
fft_length=self.config.audio_config.fft_length)
102184
# Instantiate text feature extractor.
103-
self.text_featurizer = TextFeaturizer(
185+
self.text_featurizer = featurizer.TextFeaturizer(
104186
vocab_file=self.config.vocab_file_path)
105187

106188
self.speech_labels = self.text_featurizer.speech_labels
107-
self.features, self.labels = self._preprocess_data(self.config.data_path)
189+
self.features, self.labels = _preprocess_data(
190+
self.config,
191+
self.audio_featurizer,
192+
self.text_featurizer.token_to_idx
193+
)
194+
108195
self.num_feature_bins = (
109196
self.features[0].shape[1] if len(self.features) else None)
110197

111-
def _preprocess_data(self, file_path):
112-
"""Generate a list of waveform, transcript pair.
113-
114-
Note that the waveforms are ordered in increasing length, so that audio
115-
samples in a mini-batch have similar length.
116-
117-
Args:
118-
file_path: a string specifying the csv file path for a data set.
119-
120-
Returns:
121-
features and labels array processed from the audio/text input.
122-
"""
123-
124-
with tf.gfile.Open(file_path, "r") as f:
125-
lines = f.read().splitlines()
126-
lines = [line.split("\t") for line in lines]
127-
# Skip the csv header.
128-
lines = lines[1:]
129-
# Sort input data by the length of waveform.
130-
lines.sort(key=lambda item: int(item[1]))
131-
features = [self._preprocess_audio(line[0]) for line in lines]
132-
labels = [self._preprocess_transcript(line[2]) for line in lines]
133-
return features, labels
134-
135-
def _normalize_audio_tensor(self, audio_tensor):
136-
"""Perform mean and variance normalization on the spectrogram tensor.
137-
138-
Args:
139-
audio_tensor: a tensor for the spectrogram feature.
140-
141-
Returns:
142-
a tensor for the normalized spectrogram.
143-
"""
144-
mean, var = tf.nn.moments(audio_tensor, axes=[0])
145-
normalized = (audio_tensor - mean) / (tf.sqrt(var) + 1e-6)
146-
return normalized
147-
148-
def _preprocess_audio(self, audio_file_path):
149-
"""Load the audio file in memory."""
150-
tf.logging.info(
151-
"Extracting spectrogram feature for {}".format(audio_file_path))
152-
sample_rate, data = wavfile.read(audio_file_path)
153-
assert sample_rate == self.config.audio_config.sample_rate
154-
if data.dtype not in [np.float32, np.float64]:
155-
data = data.astype(np.float32) / np.iinfo(data.dtype).max
156-
feature = self.audio_featurizer.featurize(data)
157-
if self.config.audio_config.normalize:
158-
feature = self._normalize_audio_tensor(feature)
159-
return tf.Session().run(
160-
feature) # return a numpy array rather than a tensor
161-
162-
def _preprocess_transcript(self, transcript):
163-
return self.text_featurizer.featurize(transcript)
164-
165198

166199
def input_fn(batch_size, deep_speech_dataset, repeat=1):
167200
"""Input function for model training and evaluation.

research/deep_speech/data/featurizer.py

Lines changed: 20 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,21 @@
1818
from __future__ import print_function
1919

2020
import codecs
21-
import functools
2221
import numpy as np
23-
import tensorflow as tf
22+
from scipy import signal
23+
24+
25+
def compute_spectrogram_feature(
    waveform, frame_length, frame_step, fft_length):
  """Compute the magnitude spectrogram for the input waveform.

  Args:
    waveform: a 1-D array of audio samples.
    frame_length: an integer for the number of samples in each frame.
    frame_step: an integer for the number of samples between the starts of
      two consecutive frames (the frame stride).
    fft_length: an integer for the number of fft bins.

  Returns:
    a numpy array of shape [time_steps, feature_num_bins] holding the
    magnitude of the short-time Fourier transform.

  Raises:
    ValueError: if frame_step is not in the range (0, frame_length].
  """
  if not 0 < frame_step <= frame_length:
    raise ValueError(
        "frame_step must be in (0, frame_length], got frame_step={} and "
        "frame_length={}".format(frame_step, frame_length))
  # scipy expects the number of samples shared by neighbouring segments,
  # not the stride, so convert: overlap = frame_length - frame_step.
  _, _, stft = signal.stft(
      waveform,
      nperseg=frame_length,
      noverlap=frame_length - frame_step,
      nfft=fft_length)

  # Perform transpose to set its shape as [time_steps, feature_num_bins]
  spectrogram = np.transpose(np.absolute(stft), (1, 0))
  return spectrogram
2436

2537

2638
class AudioFeaturizer(object):
@@ -30,64 +42,26 @@ def __init__(self,
3042
sample_rate=16000,
3143
frame_length=25,
3244
frame_step=10,
33-
fft_length=None,
34-
window_fn=functools.partial(
35-
tf.contrib.signal.hann_window, periodic=True),
36-
spect_type="linear"):
45+
fft_length=None):
3746
"""Initialize the audio featurizer class according to the configs.
3847
3948
Args:
4049
sample_rate: an integer specifying the sample rate of the input waveform.
4150
frame_length: an integer for the length of a spectrogram frame, in ms.
4251
frame_step: an integer for the frame stride, in ms.
4352
fft_length: an integer for the number of fft bins.
44-
window_fn: windowing function.
45-
spect_type: a string for the type of spectrogram to be extracted.
46-
Currently only support 'linear', otherwise will raise a value error.
47-
48-
Raises:
49-
ValueError: In case of invalid arguments for `spect_type`.
5053
"""
51-
if spect_type != "linear":
52-
raise ValueError("Unsupported spectrogram type: %s" % spect_type)
53-
self.window_fn = window_fn
5454
self.frame_length = int(sample_rate * frame_length / 1e3)
5555
self.frame_step = int(sample_rate * frame_step / 1e3)
5656
self.fft_length = fft_length if fft_length else int(2**(np.ceil(
5757
np.log2(self.frame_length))))
5858

59-
def featurize(self, waveform):
60-
"""Extract spectrogram feature tensors from the waveform."""
61-
return self._compute_linear_spectrogram(waveform)
6259

63-
def _compute_linear_spectrogram(self, waveform):
64-
"""Compute the linear-scale, magnitude spectrograms for the input waveform.
65-
66-
Args:
67-
waveform: a float32 audio tensor.
68-
Returns:
69-
a float 32 tensor with shape [len, num_bins]
70-
"""
71-
72-
# `stfts` is a complex64 Tensor representing the Short-time Fourier
73-
# Transform of each signal in `signals`. Its shape is
74-
# [?, fft_unique_bins] where fft_unique_bins = fft_length // 2 + 1.
75-
stfts = tf.contrib.signal.stft(
76-
waveform,
77-
frame_length=self.frame_length,
78-
frame_step=self.frame_step,
79-
fft_length=self.fft_length,
80-
window_fn=self.window_fn,
81-
pad_end=True)
82-
83-
# An energy spectrogram is the magnitude of the complex-valued STFT.
84-
# A float32 Tensor of shape [?, 257].
85-
magnitude_spectrograms = tf.abs(stfts)
86-
return magnitude_spectrograms
87-
88-
def _compute_mel_filterbank_features(self, waveform):
89-
"""Compute the mel filterbank features."""
90-
raise NotImplementedError("MFCC feature extraction not supported yet.")
60+
def compute_label_feature(text, token_to_idx):
  """Map every character of the text to its index.

  Surrounding whitespace is stripped and the text is lowercased before the
  lookup.

  Args:
    text: a string to convert.
    token_to_idx: the mapping from character to its index.

  Returns:
    a list of integers, one per character of the cleaned text.
  """
  cleaned = text.strip().lower()
  return [token_to_idx[char] for char in cleaned]
9165

9266

9367
class TextFeaturizer(object):
@@ -114,9 +88,3 @@ def __init__(self, vocab_file):
11488
self.idx_to_token[idx] = line
11589
self.speech_labels += line
11690
idx += 1
117-
118-
def featurize(self, text):
119-
"""Convert string to a list of integers."""
120-
tokens = list(text.strip().lower())
121-
feats = [self.token_to_idx[token] for token in tokens]
122-
return feats

0 commit comments

Comments
 (0)