Skip to content

Commit 714040b

Browse files
authored
Add Levenshtein distance (#191)
* Add Levenshtein distance * Fix linting and formatting * Fix matrix plot ranges
1 parent 55b90a9 commit 714040b

File tree

2 files changed

+121
-11
lines changed

2 files changed

+121
-11
lines changed

src/trustyai/metrics/distance.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
""""Distance metrics"""
2+
# pylint: disable = import-error
3+
from dataclasses import dataclass
4+
from typing import List, Optional, Union, Callable
5+
6+
from org.kie.trustyai.metrics.language.distance import (
7+
Levenshtein as _Levenshtein,
8+
LevenshteinResult as _LevenshteinResult,
9+
LevenshteinCounters as _LevenshteinCounters,
10+
)
11+
from opennlp.tools.tokenize import Tokenizer
12+
import numpy as np
13+
import matplotlib.pyplot as plt
14+
from trustyai import _default_initializer # pylint: disable=unused-import
15+
16+
17+
@dataclass
18+
class LevenshteinCounters:
19+
"""LevenshteinCounters Counters"""
20+
21+
substitutions: int
22+
insertions: int
23+
deletions: int
24+
correct: int
25+
26+
@staticmethod
27+
def convert(result: _LevenshteinCounters):
28+
"""Converts a Java LevenshteinCounters to a Python LevenshteinCounters"""
29+
return LevenshteinCounters(
30+
substitutions=result.getSubstitutions(),
31+
insertions=result.getInsertions(),
32+
deletions=result.getDeletions(),
33+
correct=result.getCorrect(),
34+
)
35+
36+
37+
@dataclass
38+
class LevenshteinResult:
39+
"""Levenshtein Result"""
40+
41+
distance: float
42+
counters: LevenshteinCounters
43+
matrix: np.ndarray
44+
reference: List[str]
45+
hypothesis: List[str]
46+
47+
@staticmethod
48+
def convert(result: _LevenshteinResult):
49+
"""Converts a Java LevenshteinResult to a Python LevenshteinResult"""
50+
distance = result.getDistance()
51+
counters = LevenshteinCounters.convert(result.getCounters())
52+
data = result.getDistanceMatrix().getData()
53+
numpy_array = np.array(data)[1:, 1:]
54+
reference = result.getReferenceTokens()
55+
hypothesis = result.getHypothesisTokens()
56+
57+
return LevenshteinResult(
58+
distance=distance,
59+
counters=counters,
60+
matrix=numpy_array,
61+
reference=reference,
62+
hypothesis=hypothesis,
63+
)
64+
65+
def plot(self):
66+
"""Plot the Levenshtein distance matrix"""
67+
cmap = plt.cm.viridis
68+
69+
_, axes = plt.subplots()
70+
cax = axes.imshow(self.matrix, cmap=cmap, interpolation="nearest")
71+
72+
plt.colorbar(cax)
73+
74+
axes.set_xticks(np.arange(len(self.reference)))
75+
axes.set_yticks(np.arange(len(self.hypothesis)))
76+
axes.set_xticklabels(self.reference)
77+
axes.set_yticklabels(self.hypothesis)
78+
79+
plt.setp(
80+
axes.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor"
81+
)
82+
83+
nrows, ncols = self.matrix.shape
84+
for i in range(nrows):
85+
for j in range(ncols):
86+
color = (
87+
"white" if self.matrix[i, j] < self.matrix.max() / 2 else "black"
88+
)
89+
axes.text(
90+
j, i, int(self.matrix[i, j]), ha="center", va="center", color=color
91+
)
92+
93+
plt.show()
94+
95+
96+
def levenshtein(
97+
reference: str,
98+
hypothesis: str,
99+
tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None,
100+
) -> LevenshteinResult:
101+
"""Calculate Levenshtein distance between two strings"""
102+
if not tokenizer:
103+
return LevenshteinResult.convert(
104+
_Levenshtein.calculateToken(reference, hypothesis)
105+
)
106+
if isinstance(tokenizer, Tokenizer):
107+
return LevenshteinResult.convert(
108+
_Levenshtein.calculateToken(reference, hypothesis, tokenizer)
109+
)
110+
if callable(tokenizer):
111+
tokenized_reference = tokenizer(reference)
112+
tokenized_hypothesis = tokenizer(hypothesis)
113+
return LevenshteinResult.convert(
114+
_Levenshtein.calculateToken(tokenized_reference, tokenized_hypothesis)
115+
)
116+
117+
raise ValueError("Unsupported tokenizer")

src/trustyai/metrics/language.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
""""Group fairness metrics"""
1+
""""Language metrics"""
2+
# pylint: disable = import-error
23
from dataclasses import dataclass
34

4-
# pylint: disable = import-error
55
from typing import List, Optional, Union, Callable
66

77
from org.kie.trustyai.metrics.language.levenshtein import (
88
WordErrorRate as _WordErrorRate,
99
ErrorRateResult as _ErrorRateResult,
1010
)
1111
from opennlp.tools.tokenize import Tokenizer
12+
from trustyai import _default_initializer # pylint: disable=unused-import
1213

13-
14-
@dataclass
15-
class LevenshteinCounters:
16-
"""LevenshteinCounters Counters"""
17-
18-
substitutions: int
19-
insertions: int
20-
deletions: int
21-
correct: int
14+
from .distance import LevenshteinCounters
2215

2316

2417
@dataclass

0 commit comments

Comments
 (0)