Skip to content

Commit 5923803

Browse files
SimBe195 and curufinwe authored
Add Python bindings for SearchAlgorithmV2 (#117)
Co-authored-by: Eugen Beck <[email protected]>
1 parent db96b2f commit 5923803

File tree

7 files changed

+331
-9
lines changed

7 files changed

+331
-9
lines changed

src/Python/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ LIBPYTHON_O = \
2121
$(OBJDIR)/Configuration.o \
2222
$(OBJDIR)/Init.o \
2323
$(OBJDIR)/Numpy.o \
24+
$(OBJDIR)/Search.o \
2425
$(OBJDIR)/Utilities.o
2526

2627
CHECK_O = $(OBJDIR)/check.o \

src/Python/Search.cc

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#include "Search.hh"
17+
18+
#include <Search/Module.hh>
19+
#include <Speech/ModelCombination.hh>
20+
21+
namespace py = pybind11;
22+
23+
/**
 * Creates the wrapped search algorithm from the "search-algorithm" sub-configuration
 * and supplies it with a model combination built from this component's configuration.
 *
 * @param c configuration this component (and the search algorithm) is set up from
 */
SearchAlgorithm::SearchAlgorithm(const Core::Configuration& c)
        : Core::Component(c),
          searchAlgorithm_(Search::Module::instance().createSearchAlgorithmV2(select("search-algorithm"))) {
    // Query the freshly created algorithm for the model parts it requires ...
    auto const modelCombinationMode = searchAlgorithm_->requiredModelCombination();
    auto const acousticModelMode    = searchAlgorithm_->requiredAcousticModel();

    // ... and hand it a model combination constructed with exactly those modes.
    searchAlgorithm_->setModelCombination({config, modelCombinationMode, acousticModelMode});
}
28+
29+
/** Forwards to reset() of the wrapped search algorithm. */
void SearchAlgorithm::reset() {
    searchAlgorithm_->reset();
}
32+
33+
/** Forwards to enterSegment() of the wrapped search algorithm. */
void SearchAlgorithm::enterSegment() {
    searchAlgorithm_->enterSegment();
}
36+
37+
/** Forwards to finishSegment() of the wrapped search algorithm. */
void SearchAlgorithm::finishSegment() {
    searchAlgorithm_->finishSegment();
}
40+
41+
/**
 * Pass a single feature vector to the wrapped search algorithm.
 *
 * @param feature numpy array of shape [F] or [1, F]
 *
 * An array of any other rank, or a 2-D array whose first dimension is not 1, is
 * reported through error() and is NOT forwarded to the search algorithm.
 */
void SearchAlgorithm::putFeature(py::array_t<f32> const& feature) {
    size_t F = 0ul;
    if (feature.ndim() == 2) {
        if (feature.shape(0) != 1) {
            error() << "Received feature tensor with non-trivial batch dimension " << feature.shape(0) << "; should be 1";
            // NOTE(review): error() presumably may return control instead of aborting
            // (depends on the component's error handling) - confirm. Bail out so that
            // not just the first row of a batch is silently consumed.
            return;
        }
        F = feature.shape(1);
    }
    else if (feature.ndim() == 1) {
        F = feature.shape(0);
    }
    else {
        error() << "Received feature vector of invalid dim " << feature.ndim() << "; should be 1 or 2";
        // Without a valid shape F would still be 0; do not forward such a feature.
        return;
    }

    // The array is handed over together with its number of elements F.
    searchAlgorithm_->putFeature({feature, F});
}
58+
59+
/**
 * Pass a sequence of feature vectors to the wrapped search algorithm.
 *
 * @param features numpy array of shape [T, F] or [1, T, F]
 *
 * An array of any other rank, or a 3-D array whose first dimension is not 1, is
 * reported through error() and is NOT forwarded to the search algorithm.
 */
void SearchAlgorithm::putFeatures(py::array_t<f32> const& features) {
    size_t T = 0ul;
    size_t F = 0ul;
    if (features.ndim() == 3) {
        if (features.shape(0) != 1) {
            error() << "Received feature tensor with non-trivial batch dimension " << features.shape(0) << "; should be 1";
            // NOTE(review): error() presumably may return control instead of aborting
            // (depends on the component's error handling) - confirm. Bail out so that
            // not just the first batch entry is silently consumed.
            return;
        }
        T = features.shape(1);
        F = features.shape(2);
    }
    else if (features.ndim() == 2) {
        T = features.shape(0);
        F = features.shape(1);
    }
    else {
        error() << "Received feature tensor of invalid dim " << features.ndim() << "; should be 2 or 3";
        // Without a valid shape T and F would still be 0; do not forward such input.
        return;
    }

    // The array is handed over with its total element count (T * F) and the number of frames T.
    searchAlgorithm_->putFeatures({features, T * F}, T);
}
79+
80+
/**
 * Runs decodeManySteps() on the wrapped search algorithm and converts its current
 * best traceback into a list of TracebackItem.
 *
 * Traceback entries without a pronunciation or without a lemma are left out. The
 * start time of each returned item is the end time of the previously returned
 * item (0 for the first one).
 *
 * @return the converted traceback
 */
Traceback SearchAlgorithm::getCurrentBestTraceback() {
    searchAlgorithm_->decodeManySteps();

    auto bestTraceback = searchAlgorithm_->getCurrentBestTraceback();

    Traceback items;
    items.reserve(bestTraceback->size());

    // End time of the last item that was emitted; becomes the next item's start time.
    u32 lastEndTime = 0;

    for (auto entry = bestTraceback->begin(); entry != bestTraceback->end(); ++entry) {
        bool const hasLemma = entry->pronunciation and entry->pronunciation->lemma();
        if (hasLemma) {
            items.push_back({entry->pronunciation->lemma()->symbol(),
                             entry->score.acoustic,
                             entry->score.lm,
                             lastEndTime,
                             entry->time});
            lastEndTime = entry->time;
        }
    }
    return items;
}
104+
105+
/**
 * Convenience wrapper that recognizes one complete segment in a single call.
 *
 * Performs, in this order: reset(), enterSegment(), putFeatures(features),
 * finishSegment(), and finally returns getCurrentBestTraceback().
 *
 * @param features feature array as accepted by putFeatures()
 * @return the traceback obtained after the segment has been finished
 */
Traceback SearchAlgorithm::recognizeSegment(py::array_t<f32> const& features) {
    // Start from a clean state and open the segment.
    reset();
    enterSegment();

    // Feed all features at once, then close the segment.
    putFeatures(features);
    finishSegment();

    Traceback segmentResult = getCurrentBestTraceback();
    return segmentResult;
}

src/Python/Search.hh

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#ifndef _PYTHON_SEARCH_HH
17+
#define _PYTHON_SEARCH_HH
18+
19+
#include <Search/SearchV2.hh>
20+
21+
#pragma push_macro("ensure") // Macro duplication in numpy.h
22+
#undef ensure
23+
#include <pybind11/numpy.h>
24+
#include <pybind11/pybind11.h>
25+
#pragma pop_macro("ensure")
26+
27+
namespace py = pybind11;
28+
29+
struct TracebackItem {
30+
std::string lemma;
31+
f32 amScore;
32+
f32 lmScore;
33+
u32 startTime;
34+
u32 endTime;
35+
};
36+
37+
// A recognition result: the sequence of traceback items in the order they were produced.
using Traceback = std::vector<TracebackItem>;
38+
39+
class SearchAlgorithm : public Core::Component {
40+
public:
41+
SearchAlgorithm(const Core::Configuration& c);
42+
43+
// Call before starting a new recognition. Clean up existing data structures
44+
// from the previous run.
45+
void reset();
46+
47+
// Call at the beginning of a new segment.
48+
void enterSegment();
49+
50+
// Call after all features of the current segment have been passed
51+
void finishSegment();
52+
53+
// Pass a feature array of shape [F] or [1, F]
54+
void putFeature(py::array_t<f32> const& feature);
55+
56+
// Pass an array of features of shape [T, F] or [1, T, F]
57+
void putFeatures(py::array_t<f32> const& features);
58+
59+
// Return the current best result. May contain unstable results.
60+
Traceback getCurrentBestTraceback();
61+
62+
// Convenience function to recognize a full segment given all the features as a tensor of shape [T, F]
63+
// Returns the recognition result
64+
Traceback recognizeSegment(py::array_t<f32> const& features);
65+
66+
private:
67+
std::unique_ptr<Search::SearchAlgorithmV2> searchAlgorithm_;
68+
};
69+
70+
#endif // _PYTHON_SEARCH_HH

src/Tools/LibRASR/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ CXXFLAGS += -fPIC
1313
LDFLAGS += -shared
1414

1515
RASR_LIB_O = $(OBJDIR)/LibRASR.o \
16+
$(OBJDIR)/Search.o \
1617
../../Flf/libSprintFlf.$(a) \
1718
../../Flf/FlfCore/libSprintFlfCore.$(a) \
1819
../../Speech/libSprintSpeech.$(a) \

src/Tools/LibRASR/PybindModule.cc

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#include <string>
2-
32
#include <pybind11/pybind11.h>
43

54
#include <Python/AllophoneStateFsaBuilder.hh>
65
#include <Python/Configuration.hh>
76

87
#include "LibRASR.hh"
8+
#include "Search.hh"
99

1010
namespace py = pybind11;
1111

@@ -18,15 +18,13 @@ PYBIND11_MODULE(librasr, m) {
1818

1919
py::class_<PyConfiguration> pyRasrConfig(m, "Configuration", baseConfigClass);
2020
pyRasrConfig.def(py::init<>());
21-
pyRasrConfig.def("set_from_file",
22-
(bool(Core::Configuration::*)(const std::string&)) & Core::Configuration::setFromFile);
21+
pyRasrConfig.def("set_from_file", static_cast<bool (Core::Configuration::*)(const std::string&)>(&Core::Configuration::setFromFile));
2322

2423
py::class_<AllophoneStateFsaBuilder> pyFsaBuilder(m, "AllophoneStateFsaBuilder");
2524
pyFsaBuilder.def(py::init<const Core::Configuration&>());
26-
pyFsaBuilder.def("get_orthography_by_segment_name",
27-
&AllophoneStateFsaBuilder::getOrthographyBySegmentName);
28-
pyFsaBuilder.def("build_by_orthography",
29-
&AllophoneStateFsaBuilder::buildByOrthography);
30-
pyFsaBuilder.def("build_by_segment_name",
31-
&AllophoneStateFsaBuilder::buildBySegmentName);
25+
pyFsaBuilder.def("get_orthography_by_segment_name", &AllophoneStateFsaBuilder::getOrthographyBySegmentName);
26+
pyFsaBuilder.def("build_by_orthography", &AllophoneStateFsaBuilder::buildByOrthography);
27+
pyFsaBuilder.def("build_by_segment_name", &AllophoneStateFsaBuilder::buildBySegmentName);
28+
29+
bindSearchAlgorithm(m);
3230
}

src/Tools/LibRASR/Search.cc

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#include "Search.hh"
17+
18+
#include <sstream>
19+
20+
#include <Python/Search.hh>
21+
22+
/*
 * Registers the Python classes `TracebackItem` and `SearchAlgorithm` on the given module.
 * `TracebackItem` is registered first, followed by `SearchAlgorithm` and its methods.
 */
void bindSearchAlgorithm(py::module_& module) {
    // ---------------------------------------------------------------
    // TracebackItem: plain data holder with read/write attributes
    // ---------------------------------------------------------------
    py::class_<TracebackItem>(module, "TracebackItem", "Represents attributes of a single traceback item.")
            .def_readwrite("lemma", &TracebackItem::lemma)
            .def_readwrite("am_score", &TracebackItem::amScore)
            .def_readwrite("lm_score", &TracebackItem::lmScore)
            .def_readwrite("start_time", &TracebackItem::startTime)
            .def_readwrite("end_time", &TracebackItem::endTime)
            .def("__repr__",
                 [](TracebackItem const& item) {
                     // Full description listing every attribute.
                     std::ostringstream repr;
                     repr << "<TracebackItem(lemma='" << item.lemma << "'"
                          << ", am_score=" << item.amScore
                          << ", lm_score=" << item.lmScore
                          << ", start_time=" << item.startTime
                          << ", end_time=" << item.endTime
                          << ")>";
                     return repr.str();
                 })
            .def("__str__",
                 [](TracebackItem const& item) {
                     // The short string form is just the lemma.
                     return item.lemma;
                 });

    // ---------------------------------------------------------------
    // SearchAlgorithm: recognition interface
    // ---------------------------------------------------------------
    py::class_<SearchAlgorithm>(
            module,
            "SearchAlgorithm",
            "Class that can perform recognition using RASR.\n\n"
            "The search algorithm is configured with a RASR config object.\n"
            "It works by calling `enter_segment()`, passing segment features\n"
            "via `put_feature` or `put_features` and finally calling `finish_segment()`.\n"
            "Intermediate and final results can be retrieved via `get_current_best_traceback()`.\n"
            "Before recognizing the next segment, `reset` should be called.\n"
            "There is also a convenience function `recognize_segment` that performs all\n"
            "these steps in one go given an array of segment features.")
            .def(py::init<const Core::Configuration&>(),
                 py::arg("config"),
                 "Initialize search algorithm using a RASR config.")
            .def("reset",
                 &SearchAlgorithm::reset,
                 "Call before starting a new recognition. Cleans up existing data structures from the previous run.")
            .def("enter_segment",
                 &SearchAlgorithm::enterSegment,
                 "Call at the beginning of a new segment.")
            .def("finish_segment",
                 &SearchAlgorithm::finishSegment,
                 "Call after all features of the current segment have been passed")
            .def("put_feature",
                 &SearchAlgorithm::putFeature,
                 py::arg("feature_vector"),
                 "Pass a single feature as a numpy array of shape [F] or [1, F].")
            .def("put_features",
                 &SearchAlgorithm::putFeatures,
                 py::arg("feature_array"),
                 "Pass multiple features as a numpy array of shape [T, F] or [1, T, F].")
            .def("get_current_best_traceback",
                 &SearchAlgorithm::getCurrentBestTraceback,
                 "Get the best traceback given all features that have been passed thus far.")
            .def("recognize_segment",
                 &SearchAlgorithm::recognizeSegment,
                 py::arg("features"),
                 "Convenience function to reset the search algorithm, start a segment, pass all the features as a numpy array of shape [T, F] or [1, T, F], finish the segment, and return the recognition result.");
}

src/Tools/LibRASR/Search.hh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/** Copyright 2025 RWTH Aachen University. All rights reserved.
2+
*
3+
* Licensed under the RWTH ASR License (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
// Include guard so the header can safely be pulled in more than once per translation unit.
#ifndef _LIBRASR_SEARCH_HH
#define _LIBRASR_SEARCH_HH

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // converts std::vector / std::string to Python list / str in bindings

namespace py = pybind11;

/**
 * Create bindings for search functionalities and tracebacks.
 *
 * @param module the Python module the classes are registered on
 */
void bindSearchAlgorithm(py::module_& module);

#endif  // _LIBRASR_SEARCH_HH

0 commit comments

Comments
 (0)