|
1 | 1 | """Query service.""" |
2 | 2 |
|
3 | 3 | import logging |
| 4 | +import re |
4 | 5 | from typing import Any, Awaitable, Callable, Dict, List, Union |
5 | 6 |
|
6 | 7 | from app.models.query_core import Chunk, FormatType, QueryType, Rule |
@@ -45,61 +46,52 @@ def replace_keywords( |
45 | 46 | if not text or not keyword_replacements: |
46 | 47 | return text, {} |
47 | 48 |
|
48 | | - transformations: Dict[str, str] = {} |
49 | | - |
50 | 49 | # Handle list of strings |
51 | 50 | if isinstance(text, list): |
| 51 | + original_text = text.copy() |
52 | 52 | result = [] |
53 | | - # Track which strings were modified |
| 53 | + modified = False |
| 54 | + |
| 55 | + # Create a single regex pattern for all keywords |
| 56 | + pattern = '|'.join(map(re.escape, keyword_replacements.keys())) |
| 57 | + regex = re.compile(f'\\b({pattern})\\b') |
| 58 | + |
54 | 59 | for item in text: |
55 | | - if any(keyword in item.split() for keyword in keyword_replacements): |
56 | | - # Only process strings that contain keywords |
57 | | - transformed_item, item_transformations = replace_keywords_in_string(item, keyword_replacements) |
58 | | - result.append(transformed_item) |
59 | | - # Store the full before/after for the list item |
60 | | - transformations[item] = transformed_item |
61 | | - else: |
62 | | - result.append(item) |
63 | | - return result, transformations |
| 60 | + # Single pass replacement for all keywords |
| 61 | + new_item = regex.sub(lambda m: keyword_replacements[m.group()], item) |
| 62 | + result.append(new_item) |
| 63 | + if new_item != item: |
| 64 | + modified = True |
| 65 | + |
| 66 | + # Only return transformation if something actually changed |
| 67 | + if modified: |
| 68 | + return result, { |
| 69 | + "original": original_text, |
| 70 | + "resolved": result |
| 71 | + } |
| 72 | + return result, {} |
64 | 73 |
|
65 | 74 | # Handle single string |
66 | 75 | return replace_keywords_in_string(text, keyword_replacements) |
67 | 76 |
|
68 | | - |
69 | 77 | def replace_keywords_in_string( |
70 | 78 | text: str, keyword_replacements: dict[str, str] |
71 | 79 | ) -> tuple[str, dict[str, str]]: |
72 | 80 | """Keywords for single string.""" |
73 | 81 | if not text: |
74 | 82 | return text, {} |
75 | 83 |
|
76 | | - result = text |
77 | | - transformations: Dict[str, str] = {} |
78 | | - |
79 | | - for original, new_word in keyword_replacements.items(): |
80 | | - if original in text: |
81 | | - current_pos = 0 |
82 | | - while True: |
83 | | - start_idx = text.find(original, current_pos) |
84 | | - if start_idx == -1: # No more occurrences |
85 | | - break |
86 | | - |
87 | | - end_idx = start_idx + len(original) |
88 | | - current_pos = end_idx |
89 | | - |
90 | | - while end_idx < len(text) and ( |
91 | | - text[end_idx].isalnum() or text[end_idx] in "()" |
92 | | - ): |
93 | | - end_idx += 1 |
94 | | - |
95 | | - full_original = text[start_idx:end_idx] |
96 | | - suffix = text[start_idx + len(original) : end_idx] |
97 | | - full_new = new_word + suffix |
98 | | - |
99 | | - result = result.replace(full_original, full_new) |
100 | | - transformations[full_original] = full_new |
101 | | - |
102 | | - return result, transformations |
| 84 | + # Create a single regex pattern for all keywords |
| 85 | + pattern = '|'.join(map(re.escape, keyword_replacements.keys())) |
| 86 | + regex = re.compile(f'\\b({pattern})\\b') |
| 87 | + |
| 88 | + # Single pass replacement |
| 89 | + result = regex.sub(lambda m: keyword_replacements[m.group()], text) |
| 90 | + |
| 91 | + # Only return transformation if something changed |
| 92 | + if result != text: |
| 93 | + return result, {"original": text, "resolved": result} |
| 94 | + return text, {} |
103 | 95 |
|
104 | 96 |
|
105 | 97 | async def process_query( |
@@ -141,31 +133,48 @@ async def process_query( |
141 | 133 | else chunks |
142 | 134 | ) |
143 | 135 |
|
| 136 | + # First populate the replacements dictionary |
144 | 137 | replacements: Dict[str, str] = {} |
145 | | - |
146 | 138 | if resolve_entity_rules and answer_value: |
147 | | - # Combine all replacements from all resolve_entity rules |
148 | 139 | for rule in resolve_entity_rules: |
149 | 140 | if rule.options: |
150 | 141 | rule_replacements = dict( |
151 | 142 | option.split(":") for option in rule.options |
152 | 143 | ) |
153 | 144 | replacements.update(rule_replacements) |
154 | 145 |
|
| 146 | + # Then apply the replacements if we have any |
155 | 147 | if replacements: |
156 | 148 | print(f"Resolving entities in answer: {answer_value}") |
157 | | - # Handle both string and list cases |
158 | | - answer_value, transformations = replace_keywords( |
159 | | - answer_value, replacements |
160 | | - ) |
| 149 | + if isinstance(answer_value, list): |
| 150 | + # Transform the list but keep track of both original and transformed |
| 151 | + transformed_list, _ = replace_keywords(answer_value, replacements) |
| 152 | + transformations = { |
| 153 | + "original": answer_value, # Keep as list |
| 154 | + "resolved": transformed_list # Keep as list |
| 155 | + } |
| 156 | + answer_value = transformed_list |
| 157 | + else: |
| 158 | + # Handle single string case |
| 159 | + transformed_value, _ = replace_keywords(answer_value, replacements) |
| 160 | + transformations = { |
| 161 | + "original": answer_value, |
| 162 | + "resolved": transformed_value |
| 163 | + } |
| 164 | + answer_value = transformed_value |
| 165 | + |
161 | 166 |
|
162 | 167 | return QueryResult( |
163 | 168 | answer=answer_value, |
164 | 169 | chunks=result_chunks[:10], |
165 | | - resolved_entities=transformations if transformations else None, |
| 170 | + resolved_entities=[{ |
| 171 | + "original": transformations["original"], |
| 172 | + "resolved": transformations["resolved"], |
| 173 | + "source": {"type": "column", "id": "some-id"}, |
| 174 | + "entityType": "some-type" |
| 175 | + }] if transformations else None |
166 | 176 | ) |
167 | 177 |
|
168 | | - |
169 | 178 | # Convenience functions for specific query types |
170 | 179 | async def decomposition_query( |
171 | 180 | query: str, |
|
0 commit comments