|
| 1 | +/** Copyright 2025 RWTH Aachen University. All rights reserved. |
| 2 | + * |
| 3 | + * Licensed under the RWTH ASR License (the "License"); |
| 4 | + * you may not use this file except in compliance with the License. |
| 5 | + * You may obtain a copy of the License at |
| 6 | + * |
| 7 | + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html |
| 8 | + * |
| 9 | + * Unless required by applicable law or agreed to in writing, software |
| 10 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | + * See the License for the specific language governing permissions and |
| 13 | + * limitations under the License. |
| 14 | + */ |
| 15 | + |
| 16 | +#include "Search.hh" |
| 17 | + |
| 18 | +#include <sstream> |
| 19 | + |
| 20 | +#include <Python/Search.hh> |
| 21 | + |
| 22 | +void bindSearchAlgorithm(py::module_& module) { |
| 23 | + /* |
| 24 | + * ======================== |
| 25 | + * === Traceback ========== |
| 26 | + * ======================== |
| 27 | + */ |
| 28 | + py::class_<TracebackItem> pyTracebackItem( |
| 29 | + module, |
| 30 | + "TracebackItem", |
| 31 | + "Represents attributes of a single traceback item."); |
| 32 | + pyTracebackItem.def_readwrite("lemma", &TracebackItem::lemma); |
| 33 | + pyTracebackItem.def_readwrite("am_score", &TracebackItem::amScore); |
| 34 | + pyTracebackItem.def_readwrite("lm_score", &TracebackItem::lmScore); |
| 35 | + pyTracebackItem.def_readwrite("start_time", &TracebackItem::startTime); |
| 36 | + pyTracebackItem.def_readwrite("end_time", &TracebackItem::endTime); |
| 37 | + |
| 38 | + pyTracebackItem.def( |
| 39 | + "__repr__", |
| 40 | + [](TracebackItem const& t) { |
| 41 | + std::stringstream ss; |
| 42 | + ss << "<TracebackItem("; |
| 43 | + ss << "lemma='" << t.lemma << "'"; |
| 44 | + ss << ", am_score=" << t.amScore; |
| 45 | + ss << ", lm_score=" << t.lmScore; |
| 46 | + ss << ", start_time=" << t.startTime; |
| 47 | + ss << ", end_time=" << t.endTime; |
| 48 | + ss << ")>"; |
| 49 | + return ss.str(); |
| 50 | + }); |
| 51 | + |
| 52 | + pyTracebackItem.def( |
| 53 | + "__str__", |
| 54 | + [](TracebackItem const& t) { |
| 55 | + return t.lemma; |
| 56 | + }); |
| 57 | + |
| 58 | + /* |
| 59 | + * ======================== |
| 60 | + * === Search Algorithm === |
| 61 | + * ======================== |
| 62 | + */ |
| 63 | + py::class_<SearchAlgorithm> pySearchAlgorithm( |
| 64 | + module, |
| 65 | + "SearchAlgorithm", |
| 66 | + "Class that can perform recognition using RASR.\n\n" |
| 67 | + "The search algorithm is configured with a RASR config object.\n" |
| 68 | + "It works by calling `enter_segment()`, passing segment features\n" |
| 69 | + "via `put_feature` or `put_features` and finally calling `finish_segment()`.\n" |
| 70 | + "Intermediate and final results can be retrieved via `get_current_best_traceback()`.\n" |
| 71 | + "Before recognizing the next segment, `reset` should be called.\n" |
| 72 | + "There is also a convenience function `recognize_segment` that performs all\n" |
| 73 | + "these steps in one go given an array of segment features."); |
| 74 | + |
| 75 | + pySearchAlgorithm.def( |
| 76 | + py::init<const Core::Configuration&>(), |
| 77 | + py::arg("config"), |
| 78 | + "Initialize search algorithm using a RASR config."); |
| 79 | + |
| 80 | + pySearchAlgorithm.def( |
| 81 | + "reset", |
| 82 | + &SearchAlgorithm::reset, |
| 83 | + "Call before starting a new recognition. Cleans up existing data structures from the previous run."); |
| 84 | + |
| 85 | + pySearchAlgorithm.def( |
| 86 | + "enter_segment", |
| 87 | + &SearchAlgorithm::enterSegment, |
| 88 | + "Call at the beginning of a new segment."); |
| 89 | + |
| 90 | + pySearchAlgorithm.def( |
| 91 | + "finish_segment", |
| 92 | + &SearchAlgorithm::finishSegment, |
| 93 | + "Call after all features of the current segment have been passed"); |
| 94 | + |
| 95 | + pySearchAlgorithm.def( |
| 96 | + "put_feature", |
| 97 | + &SearchAlgorithm::putFeature, |
| 98 | + py::arg("feature_vector"), |
| 99 | + "Pass a single feature as a numpy array of shape [F] or [1, F]."); |
| 100 | + |
| 101 | + pySearchAlgorithm.def( |
| 102 | + "put_features", |
| 103 | + &SearchAlgorithm::putFeatures, |
| 104 | + py::arg("feature_array"), |
| 105 | + "Pass multiple features as a numpy array of shape [T, F] or [1, T, F]."); |
| 106 | + |
| 107 | + pySearchAlgorithm.def( |
| 108 | + "get_current_best_traceback", |
| 109 | + &SearchAlgorithm::getCurrentBestTraceback, |
| 110 | + "Get the best traceback given all features that have been passed thus far."); |
| 111 | + |
| 112 | + pySearchAlgorithm.def( |
| 113 | + "recognize_segment", |
| 114 | + &SearchAlgorithm::recognizeSegment, |
| 115 | + py::arg("features"), |
| 116 | + "Convenience function to reset the search algorithm, start a segment, pass all the features as a numpy array of shape [T, F] or [1, T, F], finish the segment, and return the recognition result."); |
| 117 | +} |
0 commit comments