diff --git a/src/Python/LabelScorer.cc b/src/Python/LabelScorer.cc
new file mode 100644
index 000000000..8c679aa72
--- /dev/null
+++ b/src/Python/LabelScorer.cc
@@ -0,0 +1,176 @@
+/** Copyright 2025 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LabelScorer.hh"
+
+// NOTE(review): the header names of the three bare `#include` lines below are missing from this patch text
+// (presumably pybind11 headers) - restore them before applying.
+#include
+#include
+#include
+
+#include "ScoringContext.hh"
+
+namespace py = pybind11;
+
+namespace Python {
+
+// Forwards the RASR configuration to `Core::Component` and to the `Nn::LabelScorer` base.
+PythonLabelScorer::PythonLabelScorer(Core::Configuration const& config)
+        : Core::Component(config),
+          Precursor(config) {
+}
+
+// Dispatches to the python override named "reset". The pybind11 "pure" macro raises if there is no override.
+void PythonLabelScorer::reset() {
+    PYBIND11_OVERRIDE_PURE(void, Nn::LabelScorer, reset);
+}
+
+// Dispatches to the python override named "signal_no_more_features" (pure: an override is required).
+void PythonLabelScorer::signalNoMoreFeatures() {
+    PYBIND11_OVERRIDE_PURE_NAME(
+            void,
+            Nn::LabelScorer,
+            "signal_no_more_features",
+            signalNoMoreFeatures);
+}
+
+// Wraps the object returned by the python method "get_initial_scoring_context" in a `PythonScoringContext`.
+Nn::ScoringContextRef PythonLabelScorer::getInitialScoringContext() {
+    py::gil_scoped_acquire gil;
+    // Store `py::object` from virtual python call in a `PythonScoringContext`
+    return Core::ref(new PythonScoringContext(getInitialPythonScoringContext()));
+}
+
+py::object PythonLabelScorer::getInitialPythonScoringContext() {
+    PYBIND11_OVERRIDE_PURE_NAME(
+            py::object,
+            Nn::LabelScorer,
+            "get_initial_scoring_context",
+            getInitialPythonScoringContext);
+}
+
+// Unwraps the python object of `request.context`, lets the python method "extended_scoring_context"
+// extend it and wraps the result in a new `PythonScoringContext`.
+Nn::ScoringContextRef PythonLabelScorer::extendedScoringContext(Request const& request) {
+    auto* pythonScoringContext = dynamic_cast<PythonScoringContext const*>(request.context.get());
+    // A context of any other type would otherwise be dereferenced as a null pointer below
+    verify(pythonScoringContext != nullptr);
+
+    py::gil_scoped_acquire gil;
+    // Store `py::object` from virtual python call in a `PythonScoringContext`
+    auto newScoringContext = extendedPythonScoringContext(pythonScoringContext->object, request.nextToken, request.transitionType);
+    return Core::ref(new PythonScoringContext(std::move(newScoringContext)));
+}
+
+py::object PythonLabelScorer::extendedPythonScoringContext(py::object const& context, Nn::LabelIndex nextToken, TransitionType transitionType) {
+    PYBIND11_OVERRIDE_PURE_NAME(
+            py::object,
+            Nn::LabelScorer,
+            "extended_scoring_context",
+            extendedPythonScoringContext,
+            context,
+            nextToken,
+            transitionType);
+}
+
+void PythonLabelScorer::addInput(Nn::DataView const& input) {
+    // Call batched version
+    addInputs(input, 1);
+}
+
+// Hands `input` to the python method "add_inputs" as a 2D array of shape
+// [nTimesteps, input.size() / nTimesteps].
+void PythonLabelScorer::addInputs(Nn::DataView const& input, size_t nTimesteps) {
+    // Guards the division below
+    verify(nTimesteps > 0);
+
+    py::gil_scoped_acquire gil;
+
+    // Convert `input` to a `py::array` for virtual python call
+    ssize_t featureDimSize = input.size() / nTimesteps;
+
+    py::array_t<f32> inputArray(
+            {static_cast<ssize_t>(nTimesteps), featureDimSize},
+            {sizeof(f32) * featureDimSize, sizeof(f32)},
+            input.data());
+
+    addPythonInputs(inputArray);
+}
+
+void PythonLabelScorer::addPythonInputs(py::array const& inputs) {
+    PYBIND11_OVERRIDE_PURE_NAME(
+            void,
+            Nn::LabelScorer,
+            "add_inputs",
+            addPythonInputs,
+            inputs);
+}
+
+// Scores a single request by calling the batched python method with one-element lists.
+// Returns an empty optional if the python side returned `None`.
+std::optional<Nn::LabelScorer::ScoreWithTime> PythonLabelScorer::computeScoreWithTime(Request const& request) {
+    // Extract the underlying `py::object` from ScoringContext in `request` to supply them to the virtual python call
+    auto* pythonScoringContext = dynamic_cast<PythonScoringContext const*>(request.context.get());
+    verify(pythonScoringContext != nullptr);
+
+    // Acquire the GIL before any `py::object` is copied: copying and destroying the handles in `contexts`
+    // changes python reference counts. Declaring `gil` first also means it is released last.
+    py::gil_scoped_acquire gil;
+
+    std::vector<py::object>     contexts        = {pythonScoringContext->object};
+    std::vector<Nn::LabelIndex> nextTokens      = {request.nextToken};
+    std::vector<TransitionType> transitionTypes = {request.transitionType};
+
+    // Call batched version
+    if (auto result = computePythonScoresWithTimes(contexts, nextTokens, transitionTypes)) {
+        verify(result->size() == 1);
+        ScoreWithTime scoreWithTime{result->front().first, result->front().second};
+        return scoreWithTime;
+    }
+
+    return {};
+}
+
+// Scores a batch of requests with a single call to the python method "compute_scores_with_times".
+// Returns an empty optional if the python side returned `None`.
+std::optional<Nn::LabelScorer::ScoresWithTimes> PythonLabelScorer::computeScoresWithTimes(std::vector<Request> const& requests) {
+    // Acquire the GIL before any `py::object` is copied into `contexts` (see `computeScoreWithTime`)
+    py::gil_scoped_acquire gil;
+
+    std::vector<py::object>     contexts;
+    std::vector<Nn::LabelIndex> nextTokens;
+    std::vector<TransitionType> transitionTypes;
+
+    contexts.reserve(requests.size());
+    nextTokens.reserve(requests.size());
+    transitionTypes.reserve(requests.size());
+
+    // Extract the underlying `py::object`s from ScoringContexts in `requests` to supply them to the virtual python call
+    for (auto const& request : requests) {
+        auto* pythonScoringContext = dynamic_cast<PythonScoringContext const*>(request.context.get());
+        verify(pythonScoringContext != nullptr);
+        contexts.push_back(pythonScoringContext->object);
+        nextTokens.push_back(request.nextToken);
+        transitionTypes.push_back(request.transitionType);
+    }
+
+    if (auto result = computePythonScoresWithTimes(contexts, nextTokens, transitionTypes)) {
+        verify(result->size() == requests.size());
+        ScoresWithTimes scoresWithTimes;
+        scoresWithTimes.scores.reserve(result->size());
+        for (auto const& [score, timeframe] : *result) {
+            scoresWithTimes.scores.push_back(score);
+            scoresWithTimes.timeframes.push_back(timeframe);
+        }
+        return scoresWithTimes;
+    }
+
+    return {};
+}
+
+// NOTE(review): the element types of the pair were stripped from this patch text; `Nn::Score` and
+// `Speech::TimeframeIndex` are assumed to match the members of `ScoreWithTime` - confirm.
+std::optional<std::vector<std::pair<Nn::Score, Speech::TimeframeIndex>>> PythonLabelScorer::computePythonScoresWithTimes(std::vector<py::object> const& contexts, std::vector<Nn::LabelIndex> const& nextTokens, std::vector<TransitionType> const& transitionTypes) {
+    using returnType = std::optional<std::vector<std::pair<Nn::Score, Speech::TimeframeIndex>>>;  // Macro can't handle types with commas inside properly
+    PYBIND11_OVERRIDE_PURE_NAME(
+            returnType,
+            Nn::LabelScorer,
+            "compute_scores_with_times",
+            computePythonScoresWithTimes,
+            contexts,
+            nextTokens,
+            transitionTypes);
+}
+
+// Stores a handle to the python wrapper object in `pyInstance_`.
+void PythonLabelScorer::setInstance(py::object const& instance) {
+    py::gil_scoped_acquire gil;
+    pyInstance_ = instance;
+}
+
+}  // namespace Python
diff --git a/src/Python/LabelScorer.hh b/src/Python/LabelScorer.hh
new file mode 100644
index 000000000..e2a44e34b
--- /dev/null
+++ b/src/Python/LabelScorer.hh
@@ -0,0 +1,80
@@
+/** Copyright 2025 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PYTHON_LABEL_SCORER_HH
+#define PYTHON_LABEL_SCORER_HH
+
+// NOTE(review): the header names of the two bare `#include` lines below are missing from this patch text -
+// restore them before applying.
+#include
+
+#include
+
+namespace py = pybind11;
+
+namespace Python {
+
+/*
+ * Trampoline class that is used in order to expose the LabelScorer class via pybind.
+ * It mainly specifies the signatures of abstract methods that need to be implemented in python
+ * and performs conversion between "C++ types" such as `DataView` and `ScoringContext`
+ * and "Python types" such as `py::array` and `py::object`.
+ *
+ * See https://pybind11.readthedocs.io/en/stable/advanced/classes.html for official documentation
+ * on the "trampoline" pattern.
+ */
+class PythonLabelScorer : public Nn::LabelScorer {
+public:
+    using Precursor = Nn::LabelScorer;
+
+    PythonLabelScorer(Core::Configuration const& config);
+    virtual ~PythonLabelScorer() = default;
+
+    // Must be overridden in python by name "reset"
+    virtual void reset() override;
+
+    // Must be overridden in python by name "signal_no_more_features"
+    // (the implementation dispatches through a "pure" pybind11 override macro, so there is no default)
+    virtual void signalNoMoreFeatures() override;
+
+    // Must be overridden in python by name "get_initial_scoring_context"
+    virtual Nn::ScoringContextRef getInitialScoringContext() override;
+    virtual py::object            getInitialPythonScoringContext();
+
+    // Must be overridden in python by name "extended_scoring_context"
+    virtual Nn::ScoringContextRef extendedScoringContext(Request const& request) override;
+    virtual py::object            extendedPythonScoringContext(py::object const& context, Nn::LabelIndex nextToken, TransitionType transitionType);
+
+    // Calls batched version with `nTimesteps = 1`
+    virtual void addInput(Nn::DataView const& input) override;
+
+    // Must be overridden in python by name "add_inputs"
+    virtual void addInputs(Nn::DataView const& input, size_t nTimesteps) override;
+    virtual void addPythonInputs(py::array const& inputs);
+
+    // Calls batched version
+    virtual std::optional<ScoreWithTime> computeScoreWithTime(Request const& request) override;
+
+    // Must be overridden in python by name "compute_scores_with_times"
+    // NOTE(review): the element types of the pair were stripped from this patch text; `Nn::Score` and
+    // `Speech::TimeframeIndex` are assumed to match the members of `ScoreWithTime` - confirm.
+    virtual std::optional<ScoresWithTimes>                                           computeScoresWithTimes(std::vector<Request> const& requests) override;
+    virtual std::optional<std::vector<std::pair<Nn::Score, Speech::TimeframeIndex>>> computePythonScoresWithTimes(std::vector<py::object> const& contexts, std::vector<Nn::LabelIndex> const& nextTokens, std::vector<TransitionType> const& transitionTypes);
+
+    // Keep track of python object as a member to make sure it doesn't get garbage collected
+    void setInstance(py::object const& instance);
+
+protected:
+    py::object pyInstance_;  // Hold the Python wrapper
+};
+
+}  // namespace Python
+
+#endif  // PYTHON_LABEL_SCORER_HH
diff --git a/src/Python/Makefile b/src/Python/Makefile
index 698d5ffb4..42d4a514a 100644
--- a/src/Python/Makefile
+++ b/src/Python/Makefile
@@ -20,7 +20,9 @@ LIBPYTHON_O = \
         $(OBJDIR)/AllophoneStateFsaBuilder.o \
         $(OBJDIR)/Configuration.o \
         $(OBJDIR)/Init.o \
+        $(OBJDIR)/LabelScorer.o \
         $(OBJDIR)/Numpy.o \
+        $(OBJDIR)/ScoringContext.o \
         $(OBJDIR)/Search.o \
         $(OBJDIR)/Utilities.o
 
diff --git a/src/Python/ScoringContext.cc b/src/Python/ScoringContext.cc
new
file mode 100644
index 000000000..284221e5e
--- /dev/null
+++ b/src/Python/ScoringContext.cc
@@ -0,0 +1,31 @@
+/** Copyright 2025 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ScoringContext.hh"
+
+namespace Python {
+
+// Hash of the wrapped python object, as computed by python's own hashing.
+// `py::hash` raises if the object is not hashable.
+size_t PythonScoringContext::hash() const {
+    // Hashing calls into the interpreter, so the GIL has to be held (as in `isEqual`)
+    py::gil_scoped_acquire gil;
+    return py::hash(object);
+}
+
+// Two contexts are equal if `other` is also a `PythonScoringContext` and the wrapped
+// python objects compare equal on the python side.
+bool PythonScoringContext::isEqual(Nn::ScoringContextRef const& other) const {
+    auto* otherPtr = dynamic_cast<PythonScoringContext const*>(other.get());
+    if (otherPtr == nullptr) {
+        // Not a python scoring context: cannot be equal, and must not be dereferenced
+        return false;
+    }
+
+    py::gil_scoped_acquire gil;
+    return object.equal(otherPtr->object);
+}
+
+}  // namespace Python
diff --git a/src/Python/ScoringContext.hh b/src/Python/ScoringContext.hh
new file mode 100644
index 000000000..34a9c9201
--- /dev/null
+++ b/src/Python/ScoringContext.hh
@@ -0,0 +1,47 @@
+/** Copyright 2025 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PYTHON_SCORING_CONTEXT_HH
+#define PYTHON_SCORING_CONTEXT_HH
+
+#include <utility>
+
+// NOTE(review): the header names of the two bare `#include` lines below are missing from this patch text -
+// restore them before applying.
+#include
+
+#include
+
+namespace py = pybind11;
+
+namespace Python {
+
+/*
+ * Scoring context containing some arbitrary (hashable) python object
+ *
+ * NOTE(review): `object` is released by the implicit destructor. If contexts can be destroyed on a thread
+ * that does not hold the GIL, a destructor that acquires it is needed - confirm against the callers.
+ */
+struct PythonScoringContext : public Nn::ScoringContext {
+    py::object object;
+
+    // Default context wraps python `None`
+    PythonScoringContext()
+            : object(py::none()) {}
+
+    // Takes over the handle from `object`. A named rvalue-reference parameter is an lvalue,
+    // so without `std::move` the handle would be copied instead of moved.
+    PythonScoringContext(py::object&& object)
+            : object(std::move(object)) {}
+
+    bool   isEqual(Nn::ScoringContextRef const& other) const;
+    size_t hash() const;
+};
+
+// NOTE(review): the template argument was stripped from this patch text; `const PythonScoringContext` is assumed - confirm.
+typedef Core::Ref<const PythonScoringContext> PythonScoringContextRef;
+
+}  // namespace Python
+
+#endif  // PYTHON_SCORING_CONTEXT_HH
diff --git a/src/Tools/LibRASR/LabelScorer.cc b/src/Tools/LibRASR/LabelScorer.cc
new file mode 100644
index 000000000..5ea89b410
--- /dev/null
+++ b/src/Tools/LibRASR/LabelScorer.cc
@@ -0,0 +1,150 @@
+/** Copyright 2025 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LabelScorer.hh"
+
+// NOTE(review): the header names of the five bare `#include` lines below are missing from this patch text -
+// restore them before applying.
+#include
+#include
+#include
+
+#include
+#include
+
+// Make it so that a `py::object` can use `Core::Ref` as a holder type instead of the usual `std::unique_ptr`.
+// See https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers for official documentation.
+PYBIND11_DECLARE_HOLDER_TYPE(T, Core::Ref<T>, true);
+
+// Registers a factory function under `name` in RASR's label scorer factory. When invoked with a
+// configuration, the factory function constructs an instance of the python class `pyLabelScorerClass`.
+void registerPythonLabelScorer(std::string const& name, py::object const& pyLabelScorerClass) {
+    Nn::Module::instance().labelScorerFactory().registerLabelScorer(
+            name.c_str(),
+            [pyLabelScorerClass](Core::Configuration const& config) {
+                py::gil_scoped_acquire gil;
+                // Call constructor of `pyLabelScorerClass`
+                py::object inst = pyLabelScorerClass(config);
+                // Let the C++ object hold on to its own python wrapper
+                inst.cast<Python::PythonLabelScorer*>()->setInstance(inst);
+                return inst.cast<Core::Ref<Nn::LabelScorer>>();
+            });
+}
+
+// Adds `register_label_scorer_type`, the `TransitionType` enum and the `LabelScorer` base class to `module`.
+void bindLabelScorer(py::module_& module) {
+    module.def(
+            "register_label_scorer_type",
+            &registerPythonLabelScorer,
+            py::arg("name"),
+            py::arg("label_scorer_cls"),
+            "Register a custom label scorer type in the internal label scorer factory of RASR.\n\n"
+            "Args:\n"
+            " name: The name under which the label scorer type is registered. The same name must be used in the RASR config\n"
+            " in order to make RASR instantiate a label scorer of this type later.\n"
+            " label_scorer_cls: A class that inherits from `librasr.LabelScorer` and implements the abstract methods.");
+
+    py::enum_<Nn::LabelScorer::TransitionType>(module, "TransitionType")
+            .value("LABEL_TO_LABEL", Nn::LabelScorer::TransitionType::LABEL_TO_LABEL)
+            .value("LABEL_LOOP", Nn::LabelScorer::TransitionType::LABEL_LOOP)
+            .value("LABEL_TO_BLANK", Nn::LabelScorer::TransitionType::LABEL_TO_BLANK)
+            .value("BLANK_TO_LABEL", Nn::LabelScorer::TransitionType::BLANK_TO_LABEL)
+            .value("BLANK_LOOP", Nn::LabelScorer::TransitionType::BLANK_LOOP)
+            .value("INITIAL_LABEL", Nn::LabelScorer::TransitionType::INITIAL_LABEL)
+            .value("INITIAL_BLANK", Nn::LabelScorer::TransitionType::INITIAL_BLANK);
+
+    // Specify `Python::PythonLabelScorer` as trampoline class and `Core::Ref` as holder type
+    py::class_<Nn::LabelScorer, Python::PythonLabelScorer, Core::Ref<Nn::LabelScorer>> pyLabelScorer(
+            module,
+            "LabelScorer",
+            "Abstract base class for label scorers. A label scorer is responsible for initializing and updating a 'scoring context'\n"
+            "and then computing scores for tokens given a scoring context. This scoring context can be an arbitrary (hashable)\n"
+            "python object depending on the needs of the model.\n"
+            "For example for a CTC model, the scoring context could just be the current timestep. For a transducer model with\n"
+            "LSTM prediction network it could be the timestep together with an LSTM hidden state tensor.\n"
+            "Label scorers implemented in python can be used in conjunction with native RASR label scorers such as\n"
+            "`CombineLabelScorer`, `EncoderDecoderLabelScorer` + `OnnxEncoder`, etc.\n"
+            "A label scorer instance can be used by RASR in order to perform procedures such as search or forced alignment.\n"
+            "Concrete subclasses need to implement the following methods:\n"
+            " - `reset`\n"
+            " - `signal_no_more_features`\n"
+            " - `get_initial_scoring_context`\n"
+            " - `extended_scoring_context`\n"
+            " - `add_inputs`\n"
+            " - `compute_scores_with_times`");
+
+    pyLabelScorer.def(
+            py::init<Core::Configuration const&>(),
+            py::arg("config"),
+            "Construct a label scorer from a RASR config.");
+
+    pyLabelScorer.def(
+            "reset",
+            &Nn::LabelScorer::reset,
+            "Reset any internal buffers and flags related to the current segment in order to prepare the label scorer for a new segment.");
+
+    pyLabelScorer.def(
+            "signal_no_more_features",
+            &Nn::LabelScorer::signalNoMoreFeatures,
+            "Signal to the label scorer that all features for the current segment have been passed.");
+
+    pyLabelScorer.def(
+            "get_initial_scoring_context",
+            [](Python::PythonLabelScorer& self) { return self.getInitialPythonScoringContext(); },
+            "Create some arbitrary (hashable) python object which symbolizes the scoring context in the first search step");
+
+    pyLabelScorer.def(
+            "extended_scoring_context",
+            [](Python::PythonLabelScorer&      self,
+               py::object const&               context,
+               Nn::LabelIndex                  nextToken,
+               Nn::LabelScorer::TransitionType transitionType) { return self.extendedPythonScoringContext(context, nextToken, transitionType); },
+            py::arg("context"),
+            py::arg("next_token"),
+            py::arg("transition_type"),
+            "Create a new extended scoring context given the previous context and next token.\n\n"
+            "Args:\n"
+            " context: The previous scoring context. The type of the object is the same as the one returned by `get_initial_scoring_context`.\n"
+            " next_token: The most recent token that has been hypothesized and can now be integrated into the scoring context.\n"
+            " transition_type: The kind of transition that has just been performed.\n\n"
+            "Returns:\n"
+            " An extended scoring context. The type of this object should be the same as the type of the input `context`.");
+
+    pyLabelScorer.def(
+            "add_inputs",
+            [](
+                    Python::PythonLabelScorer& self,
+                    py::array const&           inputs) { self.addPythonInputs(inputs); },
+            py::arg("inputs"),
+            "Feed an array of input features to the label scorer.\n\n"
+            "Args:\n"
+            " inputs: A numpy array of shape [T, F] containing the input features for `T` time steps.");
+
+    pyLabelScorer.def(
+            "compute_scores_with_times",
+            // All three lists are only forwarded, so take each of them by const reference
+            [](Python::PythonLabelScorer&                          self,
+               std::vector<py::object> const&                      contexts,
+               std::vector<Nn::LabelIndex> const&                  nextTokens,
+               std::vector<Nn::LabelScorer::TransitionType> const& transitionTypes) { return self.computePythonScoresWithTimes(contexts, nextTokens, transitionTypes); },
+            py::arg("contexts"),
+            py::arg("next_tokens"),
+            py::arg("transition_types"),
+            "Compute the scores and timestamps of tokens given the current scoring contexts. Timestamps need to be computed because\n"
+            "each label scorer may implement custom logic about how much time is advanced depending on the situation\n"
+            "(e.g. vertical vs. diagonal blank transitions in transducer).\n\n"
+            "Args:\n"
+            " contexts: A list of length `B` containing current scoring contexts for all requests. The type is the same as the one returned by"
+            " `get_initial_scoring_context` and `extended_scoring_context`.\n"
+            " next_tokens: A list of length `B` containing the tokens for which the score should be computed.\n"
+            " transition_types: A list of length `B` containing the types of the hypothesized transitions.\n\n"
+            "Returns:\n"
+            " Either `None` if the label scorer is not ready to process the requests (e.g. expects more features or segment end signal)\n"
+            " or a list of length `B` containing the scores and timestamps for each request. The returned timestamps will be used\n"
+            " to form word boundaries in the search traceback.");
+}
diff --git a/src/Tools/LibRASR/LabelScorer.hh b/src/Tools/LibRASR/LabelScorer.hh
new file mode 100644
index 000000000..c2eabbb22
--- /dev/null
+++ b/src/Tools/LibRASR/LabelScorer.hh
@@ -0,0 +1,23 @@
+/** Copyright 2025 RWTH Aachen University. All rights reserved.
+ *
+ * Licensed under the RWTH ASR License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBRASR_LABEL_SCORER_HH
+#define LIBRASR_LABEL_SCORER_HH
+
+// NOTE(review): the header name of the bare `#include` line below is missing from this patch text
+// (presumably a pybind11 header) - restore it before applying.
+#include
+
+namespace py = pybind11;
+
+/*
+ * Create bindings for label scorer classes
+ */
+void bindLabelScorer(py::module_& module);
+
+#endif  // LIBRASR_LABEL_SCORER_HH
diff --git a/src/Tools/LibRASR/Makefile b/src/Tools/LibRASR/Makefile
index 22b1b5ebf..699c568a8 100644
--- a/src/Tools/LibRASR/Makefile
+++ b/src/Tools/LibRASR/Makefile
@@ -12,8 +12,9 @@ TARGETS = librasr.so
 CXXFLAGS += -fPIC
 LDFLAGS += -shared
 
-RASR_LIB_O = $(OBJDIR)/LibRASR.o \
-        $(OBJDIR)/Search.o \
+RASR_LIB_O = $(OBJDIR)/LabelScorer.o \
+        $(OBJDIR)/LibRASR.o \
+        $(OBJDIR)/Search.o \
         ../../Flf/libSprintFlf.$(a) \
         ../../Flf/FlfCore/libSprintFlfCore.$(a) \
         ../../Speech/libSprintSpeech.$(a) \
diff --git a/src/Tools/LibRASR/PybindModule.cc b/src/Tools/LibRASR/PybindModule.cc
index 60a02c4ca..630d23ab9 100644
--- a/src/Tools/LibRASR/PybindModule.cc
+++ b/src/Tools/LibRASR/PybindModule.cc
@@ -4,6 +4,7 @@
 #include
 #include
 
+#include "LabelScorer.hh"
 #include "LibRASR.hh"
 #include "Search.hh"
 
@@ -14,7 +15,22 @@
 PYBIND11_MODULE(librasr, m) {
     m.doc() = "RASR python module";
 
+    // TODO: Overhaul Configuration pybinds to make Configurations better to interact with from python-side.
     py::class_ baseConfigClass(m, "_BaseConfig");
+    baseConfigClass.def(
+            "__getitem__",
+            [](Core::Configuration const& self, std::string const& key) {
+                std::string result;
+                if (self.get(key, result)) {
+                    return result;
+                }
+                else {
+                    std::cerr << "WARNING: Tried to get config value for key '" << key << "' but it was not configured. Return empty string.\n";
+                    return std::string();
+                }
+            },
+            py::arg("key"),
+            "Retrieve the configured value of a specific parameter key as an unprocessed string.");
 
     py::class_ pyRasrConfig(m, "Configuration", baseConfigClass);
     pyRasrConfig.def(py::init<>());
@@ -26,5 +42,6 @@ PYBIND11_MODULE(librasr, m) {
     pyFsaBuilder.def("build_by_orthography", &AllophoneStateFsaBuilder::buildByOrthography);
     pyFsaBuilder.def("build_by_segment_name", &AllophoneStateFsaBuilder::buildBySegmentName);
 
+    bindLabelScorer(m);
     bindSearchAlgorithm(m);
 }