Skip to content

Commit 1f36f02

Browse files
Author: Krzysztof Czajkowski
Commit message: Add speed related improvements
1 parent: 950cb97 · commit: 1f36f02

File tree

4 files changed

Lines changed: 152 additions & 91 deletions

editdistance/_edit_distance_osa.cpp

Lines changed: 51 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -4,58 +4,67 @@
44

55
std::vector<std::vector<double>> compute_dp_table(
66
const std::string& a,
7-
const std::string& b,
8-
const std::map<CppEditopName, double>& cost_map
7+
const std::string& b,
8+
double replace_weight,
9+
double insert_weight,
10+
double delete_weight,
11+
double swap_weight
912
) {
1013
int len_a = a.length();
1114
int len_b = b.length();
1215
std::vector<std::vector<double>> dp(len_a + 1, std::vector<double>(len_b + 1, 0.0));
1316

1417
for (int i = 0; i <= len_a; ++i) {
15-
dp[i][0] = i * cost_map.at(DELETE);
18+
dp[i][0] = i * delete_weight;
1619
}
1720
for (int j = 0; j <= len_b; ++j) {
18-
dp[0][j] = j * cost_map.at(INSERT);
21+
dp[0][j] = j * insert_weight;
1922
}
2023

2124
for (int i = 1; i <= len_a; ++i) {
2225
for (int j = 1; j <= len_b; ++j) {
23-
double deletion = dp[i-1][j] + cost_map.at(DELETE);
24-
double insertion = dp[i][j-1] + cost_map.at(INSERT);
25-
double substitution_cost = (a[i-1] == b[j-1]) ? 0.0 : cost_map.at(REPLACE);
26-
double substitution = dp[i-1][j-1] + substitution_cost;
27-
26+
if (a[i-1] == b[j-1]) {
27+
dp[i][j] = dp[i-1][j-1]; // match, no cost
28+
continue; // skip swap and other ops, match is optimal
29+
}
30+
double deletion = dp[i-1][j] + delete_weight;
31+
double insertion = dp[i][j-1] + insert_weight;
32+
double substitution = dp[i-1][j-1] + replace_weight;
2833
dp[i][j] = std::min({deletion, insertion, substitution});
29-
3034
if (i > 1 && j > 1 &&
3135
a[i-1] == b[j-2] && a[i-2] == b[j-1]) {
3236
dp[i][j] = std::min(dp[i][j],
33-
dp[i-2][j-2] + cost_map.at(SWAP));
37+
dp[i-2][j-2] + swap_weight);
3438
}
3539
}
3640
}
3741

3842
return dp;
3943
}
4044

41-
4245
double cpp_compute_distance(
4346
const std::string& a,
44-
const std::string& b,
45-
const std::map<CppEditopName, double>& cost_map
47+
const std::string& b,
48+
double replace_weight,
49+
double insert_weight,
50+
double delete_weight,
51+
double swap_weight
4652
) {
47-
auto dp = compute_dp_table(a, b, cost_map);
53+
auto dp = compute_dp_table(a, b, replace_weight, insert_weight, delete_weight, swap_weight);
4854
return dp[a.length()][b.length()];
4955
}
5056

5157
std::vector<std::vector<CppEditop>> backtrack_all_paths(
5258
const std::string& a,
5359
const std::string& b,
54-
const std::map<CppEditopName, double>& cost_map,
5560
const std::vector<std::vector<double>>& dp,
5661
int i,
5762
int j,
58-
std::vector<CppEditop>& current_path
63+
std::vector<CppEditop>& current_path,
64+
double replace_weight,
65+
double insert_weight,
66+
double delete_weight,
67+
double swap_weight
5968
) {
6069
if (i == 0 && j == 0) {
6170
std::vector<CppEditop> reversed_path = current_path;
@@ -67,70 +76,71 @@ std::vector<std::vector<CppEditop>> backtrack_all_paths(
6776
double current_cost = dp[i][j];
6877
const double tol = 1e-6;
6978

70-
71-
if (i > 0 && std::abs((dp[i-1][j] + cost_map.at(DELETE)) - current_cost) < tol) {
72-
CppEditop op(DELETE, i-1, i-1, cost_map.at(DELETE), std::string(1, a[i-1]));
79+
if (i > 0 && std::abs((dp[i-1][j] + delete_weight) - current_cost) < tol) {
80+
CppEditop op(DELETE, i-1, i-1, delete_weight, std::string(1, a[i-1]));
7381
current_path.push_back(op);
74-
auto paths = backtrack_all_paths(a, b, cost_map, dp, i-1, j, current_path);
82+
auto paths = backtrack_all_paths(a, b, dp, i-1, j, current_path, replace_weight, insert_weight, delete_weight, swap_weight);
7583
all_paths.insert(all_paths.end(), paths.begin(), paths.end());
7684
current_path.pop_back();
7785
}
7886

79-
if (j > 0 && std::abs((dp[i][j-1] + cost_map.at(INSERT)) - current_cost) < tol) {
80-
CppEditop op(INSERT, i, i, cost_map.at(INSERT), std::string(1, b[j-1]));
87+
if (j > 0 && std::abs((dp[i][j-1] + insert_weight) - current_cost) < tol) {
88+
CppEditop op(INSERT, i, i, insert_weight, std::string(1, b[j-1]));
8189
current_path.push_back(op);
82-
auto paths = backtrack_all_paths(a, b, cost_map, dp, i, j-1, current_path);
90+
auto paths = backtrack_all_paths(a, b, dp, i, j-1, current_path, replace_weight, insert_weight, delete_weight, swap_weight);
8391
all_paths.insert(all_paths.end(), paths.begin(), paths.end());
8492
current_path.pop_back();
8593
}
8694

87-
8895
if (i > 0 && j > 0) {
89-
double sub_cost = (a[i-1] == b[j-1]) ? 0.0 : cost_map.at(REPLACE);
96+
double sub_cost = (a[i-1] == b[j-1]) ? 0.0 : replace_weight;
9097
if (std::abs((dp[i-1][j-1] + sub_cost) - current_cost) < tol) {
9198
std::string out_char = (sub_cost == 0.0) ? std::string(1, a[i-1]) : std::string(1, b[j-1]);
9299
CppEditop op(REPLACE, i-1, j-1, sub_cost, out_char);
93100
current_path.push_back(op);
94-
auto paths = backtrack_all_paths(a, b, cost_map, dp, i-1, j-1, current_path);
101+
auto paths = backtrack_all_paths(a, b, dp, i-1, j-1, current_path, replace_weight, insert_weight, delete_weight, swap_weight);
95102
all_paths.insert(all_paths.end(), paths.begin(), paths.end());
96103
current_path.pop_back();
97104
}
98105
}
99106

100-
101107
if (i > 1 && j > 1 &&
102108
a[i-1] == b[j-2] && a[i-2] == b[j-1] &&
103-
std::abs((dp[i-2][j-2] + cost_map.at(SWAP)) - current_cost) < tol) {
109+
std::abs((dp[i-2][j-2] + swap_weight) - current_cost) < tol) {
104110
std::string swap_str = std::string(1, b[j-2]) + std::string(1, b[j-1]);
105-
CppEditop op(SWAP, i-2, j-2, cost_map.at(SWAP), swap_str);
111+
CppEditop op(SWAP, i-2, j-2, swap_weight, swap_str);
106112
current_path.push_back(op);
107-
auto paths = backtrack_all_paths(a, b, cost_map, dp, i-2, j-2, current_path);
113+
auto paths = backtrack_all_paths(a, b, dp, i-2, j-2, current_path, replace_weight, insert_weight, delete_weight, swap_weight);
108114
all_paths.insert(all_paths.end(), paths.begin(), paths.end());
109115
current_path.pop_back();
110116
}
111117

112118
return all_paths;
113119
}
114120

115-
116121
std::vector<std::vector<CppEditop>> cpp_compute_all_paths(
117122
const std::string& a,
118-
const std::string& b,
119-
const std::map<CppEditopName, double>& cost_map
123+
const std::string& b,
124+
double replace_weight,
125+
double insert_weight,
126+
double delete_weight,
127+
double swap_weight
120128
) {
121-
auto dp = compute_dp_table(a, b, cost_map);
129+
auto dp = compute_dp_table(a, b, replace_weight, insert_weight, delete_weight, swap_weight);
122130
std::vector<CppEditop> current_path;
123-
return backtrack_all_paths(a, b, cost_map, dp, a.length(), b.length(), current_path);
131+
return backtrack_all_paths(a, b, dp, a.length(), b.length(), current_path, replace_weight, insert_weight, delete_weight, swap_weight);
124132
}
125133

126-
127134
void cpp_print_all_paths(
128135
const std::string& a,
129-
const std::string& b,
130-
const std::map<CppEditopName, double>& cost_map
136+
const std::string& b,
137+
double replace_weight,
138+
double insert_weight,
139+
double delete_weight,
140+
double swap_weight
131141
) {
132-
auto paths = cpp_compute_all_paths(a, b, cost_map);
133-
double distance = cpp_compute_distance(a, b, cost_map);
142+
auto paths = cpp_compute_all_paths(a, b, replace_weight, insert_weight, delete_weight, swap_weight);
143+
double distance = cpp_compute_distance(a, b, replace_weight, insert_weight, delete_weight, swap_weight);
134144

135145
std::cout << "OSA Distance from '" << a << "' to '" << b << "': " << distance << std::endl;
136146
std::cout << "Number of optimal edit sequences: " << paths.size() << std::endl;

editdistance/_edit_distance_osa.hpp

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -29,40 +29,55 @@ struct CppEditop {
2929

3030
std::vector<std::vector<double>> compute_dp_table(
3131
const std::string& a,
32-
const std::string& b,
33-
const std::map<CppEditopName, double>& cost_map
32+
const std::string& b,
33+
double replace_weight,
34+
double insert_weight,
35+
double delete_weight,
36+
double swap_weight
3437
);
3538

3639

3740
double cpp_compute_distance(
3841
const std::string& a,
39-
const std::string& b,
40-
const std::map<CppEditopName, double>& cost_map
42+
const std::string& b,
43+
double replace_weight,
44+
double insert_weight,
45+
double delete_weight,
46+
double swap_weight
4147
);
4248

4349

4450
std::vector<std::vector<CppEditop>> backtrack_all_paths(
4551
const std::string& a,
4652
const std::string& b,
47-
const std::map<CppEditopName, double>& cost_map,
4853
const std::vector<std::vector<double>>& dp,
4954
int i,
5055
int j,
51-
std::vector<CppEditop>& current_path
56+
std::vector<CppEditop>& current_path,
57+
double replace_weight,
58+
double insert_weight,
59+
double delete_weight,
60+
double swap_weight
5261
);
5362

5463

5564
std::vector<std::vector<CppEditop>> cpp_compute_all_paths(
5665
const std::string& a,
57-
const std::string& b,
58-
const std::map<CppEditopName, double>& cost_map
66+
const std::string& b,
67+
double replace_weight,
68+
double insert_weight,
69+
double delete_weight,
70+
double swap_weight
5971
);
6072

6173

6274
void cpp_print_all_paths(
6375
const std::string& a,
64-
const std::string& b,
65-
const std::map<CppEditopName, double>& cost_map
76+
const std::string& b,
77+
double replace_weight,
78+
double insert_weight,
79+
double delete_weight,
80+
double swap_weight
6681
);
6782

6883

editdistance/edit_distance_osa.pyx

Lines changed: 15 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,15 @@ cdef extern from "_edit_distance_osa.hpp":
2222
double cost
2323
string output_string
2424

25-
vector[vector[CppEditop]] cpp_compute_all_paths(const string& a, const string& b, const map[CppEditopName, double]& cost_map)
26-
void cpp_print_all_paths(const string& a, const string& b, const map[CppEditopName, double]& cost_map)
27-
double cpp_compute_distance(const string& a, const string& b, const map[CppEditopName, double]& cost_map)
25+
vector[vector[CppEditop]] cpp_compute_all_paths(
26+
const string& a, const string& b,
27+
double replace_weight, double insert_weight, double delete_weight, double swap_weight)
28+
void cpp_print_all_paths(
29+
const string& a, const string& b,
30+
double replace_weight, double insert_weight, double delete_weight, double swap_weight)
31+
double cpp_compute_distance(
32+
const string& a, const string& b,
33+
double replace_weight, double insert_weight, double delete_weight, double swap_weight)
2834

2935

3036
class EditopName(Enum):
@@ -52,19 +58,6 @@ cdef class Editop:
5258
return f"Editop(name={self.name}, src_idx={self.src_idx}, dst_idx={self.dst_idx}, cost={self.cost}, output_string='{self.output_string}')"
5359

5460

55-
cdef map[CppEditopName, double] _convert_cost_map(dict cost_map):
56-
cdef map[CppEditopName, double] cpp_cost_map
57-
if EditopName.INSERT in cost_map:
58-
cpp_cost_map[INSERT] = cost_map[EditopName.INSERT]
59-
if EditopName.DELETE in cost_map:
60-
cpp_cost_map[DELETE] = cost_map[EditopName.DELETE]
61-
if EditopName.REPLACE in cost_map:
62-
cpp_cost_map[REPLACE] = cost_map[EditopName.REPLACE]
63-
if EditopName.SWAP in cost_map:
64-
cpp_cost_map[SWAP] = cost_map[EditopName.SWAP]
65-
return cpp_cost_map
66-
67-
6861
def get_all_paths(
6962
str a,
7063
str b,
@@ -73,16 +66,10 @@ def get_all_paths(
7366
double delete_weight=1.0,
7467
double swap_weight=1.0
7568
):
76-
cdef dict cost_map = {
77-
EditopName.REPLACE: replace_weight,
78-
EditopName.INSERT: insert_weight,
79-
EditopName.DELETE: delete_weight,
80-
EditopName.SWAP: swap_weight
81-
}
8269
cdef string cpp_a = a.encode("utf-8")
8370
cdef string cpp_b = b.encode("utf-8")
84-
cdef map[CppEditopName, double] cpp_cost_map = _convert_cost_map(cost_map)
85-
cdef vector[vector[CppEditop]] cpp_paths = cpp_compute_all_paths(cpp_a, cpp_b, cpp_cost_map)
71+
cdef vector[vector[CppEditop]] cpp_paths = cpp_compute_all_paths(
72+
cpp_a, cpp_b, replace_weight, insert_weight, delete_weight, swap_weight)
8673
python_paths = []
8774
cdef vector[CppEditop] cpp_path
8875
cdef CppEditop cpp_op
@@ -120,16 +107,10 @@ def print_all_paths(
120107
double delete_weight=1.0,
121108
double swap_weight=1.0
122109
):
123-
cdef dict cost_map = {
124-
EditopName.REPLACE: replace_weight,
125-
EditopName.INSERT: insert_weight,
126-
EditopName.DELETE: delete_weight,
127-
EditopName.SWAP: swap_weight
128-
}
129110
cdef string cpp_a = a.encode("utf-8")
130111
cdef string cpp_b = b.encode("utf-8")
131-
cdef map[CppEditopName, double] cpp_cost_map = _convert_cost_map(cost_map)
132-
cpp_print_all_paths(cpp_a, cpp_b, cpp_cost_map)
112+
cpp_print_all_paths(
113+
cpp_a, cpp_b, replace_weight, insert_weight, delete_weight, swap_weight)
133114

134115

135116
def compute_distance(
@@ -140,13 +121,7 @@ def compute_distance(
140121
double delete_weight=1.0,
141122
double swap_weight=1.0
142123
):
143-
cdef dict cost_map = {
144-
EditopName.REPLACE: replace_weight,
145-
EditopName.INSERT: insert_weight,
146-
EditopName.DELETE: delete_weight,
147-
EditopName.SWAP: swap_weight
148-
}
149124
cdef string cpp_a = a.encode("utf-8")
150125
cdef string cpp_b = b.encode("utf-8")
151-
cdef map[CppEditopName, double] cpp_cost_map = _convert_cost_map(cost_map)
152-
return cpp_compute_distance(cpp_a, cpp_b, cpp_cost_map)
126+
return cpp_compute_distance(
127+
cpp_a, cpp_b, replace_weight, insert_weight, delete_weight, swap_weight)

0 commit comments

Comments
 (0)