Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions i6_models/parts/rasr_fsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,43 @@ def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsa:
return out_fsa


def join_fsas_fbw_v2(fsas: Iterable[FsaTuple]) -> WeightedFsaV2:
"""
Joins a set of FSAs represented as tuples into a single :classref:`WeightedFsaV2` object,
for consumption by the FBW V2 op.

:param fsas: FSAs to be concatenated, represented as tuples with the following fields:
* number of states S
* number of edges E
* integer edge array of shape [E, 3] where each row is an edge
consisting of from-state, to-state and the emission idx
* float weight array of shape [E,]
:return: Single FSA object corresponding to the joined FSAs passed as parameter.
"""
fsas = list(fsas) # ensure we can iterate multiple times over this iterable
num_states = [f[0] for f in fsas]
num_edges = [f[1] for f in fsas]
start_states = np.cumsum(np.array([0] + num_states, dtype=np.uint32))[:-1]
end_states = np.cumsum(num_states) - 1
weights = np.concatenate(tuple(f[3] for f in fsas))

edges = []
for idx, f in enumerate(fsas):
f_edges = f[2].reshape(3, -1).copy()
f_edges[:2, :] += start_states[idx]
edges.append(f_edges)

out_fsa = WeightedFsaV2(
torch.IntTensor(num_states).to(torch.uint32),
torch.IntTensor(num_edges).to(torch.uint32),
torch.IntTensor(np.concatenate(edges, axis=1)).contiguous(),
torch.Tensor(weights),
torch.IntTensor(np.array([start_states, end_states])),
)

return out_fsa


class _RasrFsaBuilderFbw2(_AbstractRasrFsaBuilder):
"""
Abstract base class for building an FSA.
Expand Down Expand Up @@ -334,28 +371,7 @@ def build_batched_fsa(self, fsas: Iterable[FsaTuple]) -> WeightedFsaV2:
:return: Single FSA object corresponding to the joined FSAs passed as parameter.
"""

fsas = list(fsas) # ensure we can iterate multiple times over this iterable
num_states = [f[0] for f in fsas]
num_edges = [f[1] for f in fsas]
start_states = np.cumsum(np.array([0] + num_states, dtype=np.uint32))[:-1]
end_states = np.cumsum(num_states) - 1
weights = np.concatenate(tuple(f[3] for f in fsas))

edges = []
for idx, f in enumerate(fsas):
f_edges = f[2].reshape(3, -1).copy()
f_edges[:2, :] += start_states[idx]
edges.append(f_edges)

out_fsa = WeightedFsaV2(
torch.IntTensor(num_states).to(torch.uint32),
torch.IntTensor(num_edges).to(torch.uint32),
torch.IntTensor(np.concatenate(edges, axis=1)).contiguous(),
torch.Tensor(weights),
torch.IntTensor(np.array([start_states, end_states])),
)

return out_fsa
return join_fsas_fbw_v2(fsas)


class RasrFsaBuilderV2(_RasrFsaBuilderFbw2):
Expand Down