Skip to content

Commit e7ade1e

Browse files
authored
Add WER metric (#189)
* Add WER language performance metric * Add test for pure Python tokenizer * Add documentation string to method * Fix lint errors * Fix additional lint errors
1 parent 7eae4d1 commit e7ade1e

File tree

3 files changed

+142
-0
lines changed

3 files changed

+142
-0
lines changed

src/trustyai/metrics/language.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
""""Group fairness metrics"""
2+
from dataclasses import dataclass
3+
4+
# pylint: disable = import-error
5+
from typing import List, Optional, Union, Callable
6+
7+
from org.kie.trustyai.metrics.language.wer import (
8+
WordErrorRate as _WordErrorRate,
9+
WordErrorRateResult as _WordErrorRateResult,
10+
)
11+
12+
from opennlp.tools.tokenize import Tokenizer
13+
14+
15+
@dataclass
class TokenSequenceAlignmentCounters:
    """Token Sequence Alignment Counters.

    Per-edit-operation counts produced when aligning an input token
    sequence against a reference token sequence.
    """

    # Reference tokens replaced by a different token in the input.
    substitutions: int
    # Tokens present in the input but absent from the reference.
    insertions: int
    # Reference tokens missing from the input.
    deletions: int
    # Tokens that align identically in reference and input.
    correct: int
23+
24+
25+
@dataclass
class WordErrorRateResult:
    """Word Error Rate result: the WER score plus the aligned strings and
    the alignment-operation counters that produced it."""

    wer: float
    aligned_reference: str
    aligned_input: str
    alignment_counters: TokenSequenceAlignmentCounters

    @staticmethod
    def convert(wer_result: _WordErrorRateResult):
        """Converts a Java WordErrorRateResult to a Python WordErrorRateResult"""
        # Build the Python dataclass directly from the Java accessors.
        # NOTE(review): alignment_counters is passed through as the raw Java
        # counters object rather than converted to a Python
        # TokenSequenceAlignmentCounters — confirm callers expect that.
        return WordErrorRateResult(
            wer=wer_result.getWordErrorRate(),
            aligned_reference=wer_result.getAlignedReferenceString(),
            aligned_input=wer_result.getAlignedInputString(),
            alignment_counters=wer_result.getAlignmentCounters(),
        )
47+
48+
49+
def word_error_rate(
    reference: str,
    hypothesis: str,
    tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None,
) -> WordErrorRateResult:
    """Calculate Word Error Rate between reference and hypothesis strings.

    :param reference: the ground-truth string to compare against
    :param hypothesis: the candidate string to score
    :param tokenizer: optional tokenizer. Either an OpenNLP ``Tokenizer``
        instance (handed straight to the Java implementation) or a Python
        callable mapping a string to a list of tokens. When ``None``, the
        Java-side default tokenizer is used.
    :return: a :class:`WordErrorRateResult` holding the WER score and alignment
    :raises ValueError: if ``tokenizer`` is neither a ``Tokenizer`` nor callable
    """
    # Compare against None explicitly: `not tokenizer` would silently treat a
    # falsy-but-valid tokenizer object as "no tokenizer given".
    if tokenizer is None:
        _wer = _WordErrorRate()
    elif isinstance(tokenizer, Tokenizer):
        _wer = _WordErrorRate(tokenizer)
    elif callable(tokenizer):
        # Tokenize in Python first, then hand the pre-split token sequences
        # to a default-configured Java metric for alignment.
        _wer = _WordErrorRate()
        return WordErrorRateResult.convert(
            _wer.calculate(tokenizer(reference), tokenizer(hypothesis))
        )
    else:
        raise ValueError("Unsupported tokenizer")
    # Default / OpenNLP paths: the Java side tokenizes the raw strings itself.
    return WordErrorRateResult.convert(_wer.calculate(reference, hypothesis))

src/trustyai/utils/tokenizers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
""""Default tokenizers for TrustyAI."""
2+
# pylint: disable = import-error
3+
4+
from org.apache.commons.text import StringTokenizer as _StringTokenizer
5+
from opennlp.tools.tokenize import SimpleTokenizer as _SimpleTokenizer
6+
7+
# Public aliases for the bundled Java tokenizers, so callers can import them
# without knowing the underlying Java package paths.
CommonsStringTokenizer = _StringTokenizer
OpenNLPTokenizer = _SimpleTokenizer
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code, unused-import
2+
"""Language metrics test suite"""
3+
4+
from common import *
5+
from trustyai.metrics.language import word_error_rate
6+
import math
7+
8+
tolerance = 1e-4
9+
10+
REFERENCES = [
11+
"This is the test reference, to which I will compare alignment against.",
12+
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur condimentum velit id velit posuere dictum. Fusce euismod tortor massa, nec euismod sapien laoreet non. Donec vulputate mi velit, eu ultricies nibh iaculis vel. Aenean posuere urna nec sapien consectetur, vitae porttitor sapien finibus. Duis nec libero convallis lectus pharetra blandit ut ac odio. Vivamus nec dui quis sem convallis pulvinar. Maecenas sodales sollicitudin leo a faucibus.",
13+
"The quick red fox jumped over the lazy brown dog"]
14+
15+
INPUTS = [
16+
"I'm a hypothesis reference, from which the aligner will compare against.",
17+
"Lorem ipsum sit amet, consectetur adipiscing elit. Curabitur condimentum velit id velit posuere dictum. Fusce blandit euismod tortor massa, nec euismod sapien blandit laoreet non. Donec vulputate mi velit, eu ultricies nibh iaculis vel. Aenean posuere urna nec sapien consectetur, vitae porttitor sapien finibus. Duis nec libero convallis lectus pharetra blandit ut ac odio. Vivamus nec dui quis sem convallis pulvinar. Maecenas sodales sollicitudin leo a faucibus.",
18+
"dog brown lazy the over jumped fox red quick The"]
19+
20+
21+
def test_default_tokenizer():
    """WER with the Java-side default tokenizer matches known scores."""
    results = [4 / 7, 1 / 26, 1]
    for i, reference in enumerate(REFERENCES):
        hypothesis = INPUTS[i]
        wer = word_error_rate(reference, hypothesis).wer
        failure = f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
        assert math.isclose(wer, results[i], rel_tol=tolerance), failure
28+
29+
30+
def test_commons_stringtokenizer():
    """WER with the Apache Commons StringTokenizer matches known scores."""
    from trustyai.utils.tokenizers import CommonsStringTokenizer
    results = [8 / 12.0, 3 / 66.0, 1.0]

    def tokenizer(text: str) -> List[str]:
        # Delegate splitting to the Apache Commons tokenizer.
        return CommonsStringTokenizer(text).getTokenList()

    for i, reference in enumerate(REFERENCES):
        hypothesis = INPUTS[i]
        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
        close_enough = math.isclose(wer, results[i], rel_tol=tolerance)
        assert close_enough, \
            f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
42+
43+
44+
def test_opennlp_tokenizer():
    """WER with the OpenNLP SimpleTokenizer matches known scores."""
    from trustyai.utils.tokenizers import OpenNLPTokenizer
    results = [9 / 14.0, 3 / 78.0, 1.0]
    # A single tokenizer instance is reused for every reference/input pair.
    tokenizer = OpenNLPTokenizer()
    for i, reference in enumerate(REFERENCES):
        hypothesis = INPUTS[i]
        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
        failure = f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
        assert math.isclose(wer, results[i], rel_tol=tolerance), failure
53+
54+
55+
def test_python_tokenizer():
    """WER with a pure-Python whitespace tokenizer matches known scores."""
    results = [3 / 4.0, 3 / 66.0, 1.0]

    def tokenizer(text: str) -> List[str]:
        # Naive whitespace split — no punctuation handling.
        return text.split(" ")

    for i, pair in enumerate(zip(REFERENCES, INPUTS)):
        reference, hypothesis = pair
        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
        assert math.isclose(wer, results[i], rel_tol=tolerance), (
            f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
        )

0 commit comments

Comments
 (0)