Skip to content

Commit 55b90a9

Browse files
authored
Change WER signature (#190)
1 parent e7ade1e commit 55b90a9

File tree

2 files changed

+20
-27
lines changed

2 files changed

+20
-27
lines changed

src/trustyai/metrics/language.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@
44
# pylint: disable = import-error
55
from typing import List, Optional, Union, Callable
66

7-
from org.kie.trustyai.metrics.language.wer import (
7+
from org.kie.trustyai.metrics.language.levenshtein import (
88
WordErrorRate as _WordErrorRate,
9-
WordErrorRateResult as _WordErrorRateResult,
9+
ErrorRateResult as _ErrorRateResult,
1010
)
11-
1211
from opennlp.tools.tokenize import Tokenizer
1312

1413

1514
@dataclass
16-
class TokenSequenceAlignmentCounters:
17-
"""Token Sequence Alignment Counters"""
15+
class LevenshteinCounters:
16+
"""LevenshteinCounters Counters"""
1817

1918
substitutions: int
2019
insertions: int
@@ -23,25 +22,19 @@ class TokenSequenceAlignmentCounters:
2322

2423

2524
@dataclass
26-
class WordErrorRateResult:
25+
class ErrorRateResult:
2726
"""Word Error Rate Result"""
2827

29-
wer: float
30-
aligned_reference: str
31-
aligned_input: str
32-
alignment_counters: TokenSequenceAlignmentCounters
28+
value: float
29+
alignment_counters: LevenshteinCounters
3330

3431
@staticmethod
35-
def convert(wer_result: _WordErrorRateResult):
36-
"""Converts a Java WordErrorRateResult to a Python WordErrorRateResult"""
37-
wer = wer_result.getWordErrorRate()
38-
aligned_reference = wer_result.getAlignedReferenceString()
39-
aligned_input = wer_result.getAlignedInputString()
40-
alignment_counters = wer_result.getAlignmentCounters()
41-
return WordErrorRateResult(
42-
wer=wer,
43-
aligned_reference=aligned_reference,
44-
aligned_input=aligned_input,
32+
def convert(result: _ErrorRateResult):
33+
"""Converts a Java ErrorRateResult to a Python ErrorRateResult"""
34+
value = result.getValue()
35+
alignment_counters = result.getAlignmentCounters()
36+
return ErrorRateResult(
37+
value=value,
4538
alignment_counters=alignment_counters,
4639
)
4740

@@ -50,7 +43,7 @@ def word_error_rate(
5043
reference: str,
5144
hypothesis: str,
5245
tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None,
53-
) -> WordErrorRateResult:
46+
) -> ErrorRateResult:
5447
"""Calculate Word Error Rate between reference and hypothesis strings"""
5548
if not tokenizer:
5649
_wer = _WordErrorRate()
@@ -60,9 +53,9 @@ def word_error_rate(
6053
tokenized_reference = tokenizer(reference)
6154
tokenized_hypothesis = tokenizer(hypothesis)
6255
_wer = _WordErrorRate()
63-
return WordErrorRateResult.convert(
56+
return ErrorRateResult.convert(
6457
_wer.calculate(tokenized_reference, tokenized_hypothesis)
6558
)
6659
else:
6760
raise ValueError("Unsupported tokenizer")
68-
return WordErrorRateResult.convert(_wer.calculate(reference, hypothesis))
61+
return ErrorRateResult.convert(_wer.calculate(reference, hypothesis))

tests/general/test_metrics_language.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_default_tokenizer():
2222
"""Test default tokenizer"""
2323
results = [4 / 7, 1 / 26, 1]
2424
for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
25-
wer = word_error_rate(reference, hypothesis).wer
25+
wer = word_error_rate(reference, hypothesis).value
2626
assert math.isclose(wer, results[i], rel_tol=tolerance), \
2727
f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
2828

@@ -36,7 +36,7 @@ def tokenizer(text: str) -> List[str]:
3636
return CommonsStringTokenizer(text).getTokenList()
3737

3838
for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
39-
wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
39+
wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
4040
assert math.isclose(wer, results[i], rel_tol=tolerance), \
4141
f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
4242

@@ -47,7 +47,7 @@ def test_opennlp_tokenizer():
4747
results = [9 / 14., 3 / 78., 1.0]
4848
tokenizer = OpenNLPTokenizer()
4949
for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
50-
wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
50+
wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
5151
assert math.isclose(wer, results[i], rel_tol=tolerance), \
5252
f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
5353

@@ -61,6 +61,6 @@ def tokenizer(text: str) -> List[str]:
6161
return text.split(" ")
6262

6363
for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
64-
wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
64+
wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
6565
assert math.isclose(wer, results[i], rel_tol=tolerance), \
6666
f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."

0 commit comments

Comments (0)