55from typing import Any , Awaitable , Callable , Dict , List , Union
66
77from app .models .query_core import Chunk , FormatType , QueryType , Rule
8- from app .schemas .query_api import QueryResult , SearchResponse
8+ from app .schemas .query_api import (
9+ QueryResult ,
10+ ResolvedEntitySchema ,
11+ SearchResponse ,
12+ )
913from app .services .llm_service import (
1014 CompletionService ,
1115 generate_inferred_response ,
@@ -40,58 +44,62 @@ def extract_chunks(search_response: SearchResponse) -> List[Chunk]:
4044
4145
4246def replace_keywords (
43- text : Union [str , List [str ]], keyword_replacements : dict [str , str ]
44- ) -> tuple [Union [str , List [str ]], dict [str , str ]]:
47+ text : Union [str , List [str ]], keyword_replacements : Dict [str , str ]
48+ ) -> tuple [
49+ Union [str , List [str ]], Dict [str , Union [str , List [str ]]]
50+ ]: # Changed return type
4551 """Replace keywords in text and return both the modified text and transformation details."""
4652 if not text or not keyword_replacements :
47- return text , {}
53+ return text , {
54+ "original" : text ,
55+ "resolved" : text ,
56+ } # Return dict instead of TransformationDict
4857
4958 # Handle list of strings
5059 if isinstance (text , list ):
5160 original_text = text .copy ()
5261 result = []
5362 modified = False
54-
63+
5564 # Create a single regex pattern for all keywords
56- pattern = '|' .join (map (re .escape , keyword_replacements .keys ()))
57- regex = re .compile (f' \\ b({ pattern } )\\ b' )
58-
65+ pattern = "|" .join (map (re .escape , keyword_replacements .keys ()))
66+ regex = re .compile (f" \\ b({ pattern } )\\ b" )
67+
5968 for item in text :
6069 # Single pass replacement for all keywords
61- new_item = regex .sub (lambda m : keyword_replacements [m .group ()], item )
70+ new_item = regex .sub (
71+ lambda m : keyword_replacements [m .group ()], item
72+ )
6273 result .append (new_item )
6374 if new_item != item :
6475 modified = True
65-
66- # Only return transformation if something actually changed
76+
6777 if modified :
68- return result , {
69- "original" : original_text ,
70- "resolved" : result
71- }
72- return result , {}
78+ return result , {"original" : original_text , "resolved" : result }
79+ return result , {"original" : original_text , "resolved" : result }
7380
7481 # Handle single string
7582 return replace_keywords_in_string (text , keyword_replacements )
7683
84+
7785def replace_keywords_in_string (
78- text : str , keyword_replacements : dict [str , str ]
79- ) -> tuple [str , dict [str , str ]]:
86+ text : str , keyword_replacements : Dict [str , str ]
87+ ) -> tuple [str , Dict [str , Union [ str , List [ str ]]]]: # Changed return type
8088 """Keywords for single string."""
8189 if not text :
82- return text , {}
90+ return text , {"original" : text , "resolved" : text }
8391
8492 # Create a single regex pattern for all keywords
85- pattern = '|' .join (map (re .escape , keyword_replacements .keys ()))
86- regex = re .compile (f' \\ b({ pattern } )\\ b' )
87-
93+ pattern = "|" .join (map (re .escape , keyword_replacements .keys ()))
94+ regex = re .compile (f" \\ b({ pattern } )\\ b" )
95+
8896 # Single pass replacement
8997 result = regex .sub (lambda m : keyword_replacements [m .group ()], text )
90-
98+
9199 # Only return transformation if something changed
92100 if result != text :
93101 return result , {"original" : text , "resolved" : result }
94- return text , {}
102+ return text , {"original" : text , "resolved" : text }
95103
96104
97105async def process_query (
@@ -115,7 +123,10 @@ async def process_query(
115123 )
116124 answer_value = answer ["answer" ]
117125
118- transformations : Dict [str , str ] = {}
126+ transformations : Dict [str , Union [str , List [str ]]] = {
127+ "original" : "" ,
128+ "resolved" : "" ,
129+ }
119130
120131 result_chunks = []
121132
@@ -147,34 +158,36 @@ async def process_query(
147158 if replacements :
148159 print (f"Resolving entities in answer: { answer_value } " )
149160 if isinstance (answer_value , list ):
150- # Transform the list but keep track of both original and transformed
151- transformed_list , _ = replace_keywords (answer_value , replacements )
152- transformations = {
153- "original" : answer_value , # Keep as list
154- "resolved" : transformed_list # Keep as list
155- }
161+ transformed_list , transform_dict = replace_keywords (
162+ answer_value , replacements
163+ )
164+ transformations = transform_dict
156165 answer_value = transformed_list
157166 else :
158- # Handle single string case
159- transformed_value , _ = replace_keywords (answer_value , replacements )
160- transformations = {
161- "original" : answer_value ,
162- "resolved" : transformed_value
163- }
167+ transformed_value , transform_dict = replace_keywords (
168+ answer_value , replacements
169+ )
170+ transformations = transform_dict
164171 answer_value = transformed_value
165172
166-
167173 return QueryResult (
168174 answer = answer_value ,
169175 chunks = result_chunks [:10 ],
170- resolved_entities = [{
171- "original" : transformations ["original" ],
172- "resolved" : transformations ["resolved" ],
173- "source" : {"type" : "column" , "id" : "some-id" },
174- "entityType" : "some-type"
175- }] if transformations else None
176+ resolved_entities = (
177+ [
178+ ResolvedEntitySchema (
179+ original = transformations ["original" ],
180+ resolved = transformations ["resolved" ],
181+ source = {"type" : "column" , "id" : "some-id" },
182+ entityType = "some-type" ,
183+ )
184+ ]
185+ if transformations ["original" ] or transformations ["resolved" ]
186+ else None
187+ ),
176188 )
177189
190+
178191# Convenience functions for specific query types
179192async def decomposition_query (
180193 query : str ,
0 commit comments