Skip to content

Commit d903fa1

Browse files
committed
fix: uncapped fuzzy sort return and sort by "score"
1 parent 47a9856 commit d903fa1

File tree

1 file changed

+57
-11
lines changed

1 file changed

+57
-11
lines changed

fred/fred_commands/_command_utils.py

Lines changed: 57 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,30 +1,54 @@
11
from typing import Type
22

3-
from regex import ENHANCEMATCH, match, escape
3+
from regex import ENHANCEMATCH, match, escape, search as re_search
44

55
from ..config import Commands, Crashes, Misc
66
from ..libraries.common import new_logger
77

88
logger = new_logger("[Command/Crash Search]")
99

1010

11-
def search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> tuple[str | list[str], bool]:
    """Look up `pattern` in `column` of `table`, falling back to a fuzzy search.

    An exact lookup through `table.fetch_by(column, pattern)` is attempted
    first unless `force_fuzzy` is True. When it is skipped or returns a falsy
    result, every row from `table.fetch_all()` whose `column` value is a
    string that fuzzily contains `pattern` is collected. The allowed number
    of edits is one third of the pattern length, capped at 6.

    Fuzzy candidates are ranked by the Levenshtein distance between `pattern`
    and the full column value (lowest first); ties are broken alphabetically
    by the row's "name". The regex is used only to filter candidates, not to
    score them.

    Returns:
        `(exact_match[column], True)` when the exact lookup succeeds,
        otherwise `(names, False)` where `names` is the full (uncapped,
        possibly empty) list of matching row names, best match first.

    Raises:
        KeyError: if `column` is not an attribute of `table`.
    """

    if column not in dir(table):
        raise KeyError(f"`{column}` is not a column in the {table.__name__} table!")

    if not force_fuzzy and (exact_match := table.fetch_by(column, pattern)):
        return exact_match[column], True

    # Fuzzy tolerance: 1/3 of the pattern length, never more than 6 edits.
    max_edits = min(len(pattern) // 3, 6)
    substring_pattern = rf".*(?:{escape(pattern)}){{e<={max_edits}}}.*"

    scored_results: list[tuple[int, str]] = []
    for item in table.fetch_all():
        value = item.get(column)

        # Rows without a string value in this column cannot match.
        if not isinstance(value, str):
            continue
        if not re_search(substring_pattern, value, flags=ENHANCEMATCH):
            continue

        scored_results.append((levenshtein(pattern, value), item["name"]))

    # Tuples compare element-wise: by score first, then alphabetically by name.
    scored_results.sort()
    results = [name for _, name in scored_results]

    # Every candidate within the fuzzy tolerance is returned; no cap is applied here.
    logger.info(results)
    return results, False
51+
2852

2953

3054
def get_search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> str:
@@ -54,3 +78,25 @@ def get_search(table: Type[Commands | Crashes], pattern: str, column: str, force
5478
response = e.args[0]
5579

5680
return response
81+
82+
83+
# Levenshtein distance algorithm
84+
def levenshtein(a: str, b: str) -> int:
    """Return the Levenshtein (edit) distance between `a` and `b`.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn one string into the other.
    It is computed row by row, keeping only the previous row of the
    dynamic-programming table.
    """
    if a == b:
        return 0

    # The distance is symmetric, so put the shorter string on the row axis;
    # each row then holds only min(len(a), len(b)) + 1 entries.
    if len(a) < len(b):
        a, b = b, a
    if not b:
        return len(a)

    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(
                min(
                    cur[j - 1] + 1,  # insertion
                    prev[j] + 1,  # deletion
                    prev[j - 1] + (0 if ca == cb else 1),  # substitution
                )
            )
        prev = cur
    return prev[-1]

0 commit comments

Comments
 (0)