@@ -44,16 +44,13 @@ def extract_chunks(search_response: SearchResponse) -> List[Chunk]:
4444
4545
4646def replace_keywords (
47- text : Union [str , List [str ]], keyword_replacements : Dict [ str , str ]
48- ) -> tuple [
49- Union [ str , List [str ]], Dict [ str , Union [ str , List [ str ]]]
50- ]: # Changed return type
47+ text : Union [str , List [str ]],
48+ keyword_replacements : Dict [ str , str ],
49+ conditional_replacements : List [tuple [ List [ str ], str ]] = [],
50+ ) -> tuple [ Union [ str , List [ str ]], Dict [ str , Union [ str , List [ str ]]]]:
5151 """Replace keywords in text and return both the modified text and transformation details."""
52- if not text or not keyword_replacements :
53- return text , {
54- "original" : text ,
55- "resolved" : text ,
56- } # Return dict instead of TransformationDict
52+ if not text or (not keyword_replacements and not conditional_replacements ):
53+ return text , {"original" : text , "resolved" : text }
5754
5855 # Handle list of strings
5956 if isinstance (text , list ):
@@ -62,13 +59,12 @@ def replace_keywords(
6259 modified = False
6360
6461 # Create a single regex pattern for all keywords
65- pattern = "|" .join (map (re .escape , keyword_replacements .keys ()))
66- regex = re .compile (f"\\ b({ pattern } )\\ b" )
62+ # pattern = "|".join(map(re.escape, keyword_replacements.keys()))
63+ # regex = re.compile(f"\\b({pattern})\\b")
6764
6865 for item in text :
69- # Single pass replacement for all keywords
70- new_item = regex .sub (
71- lambda m : keyword_replacements [m .group ()], item
66+ new_item , _ = replace_keywords_in_string (
67+ item , keyword_replacements , conditional_replacements
7268 )
7369 result .append (new_item )
7470 if new_item != item :
@@ -79,24 +75,46 @@ def replace_keywords(
7975 return result , {"original" : original_text , "resolved" : result }
8076
8177 # Handle single string
82- return replace_keywords_in_string (text , keyword_replacements )
78+ return replace_keywords_in_string (
79+ text , keyword_replacements , conditional_replacements
80+ )
81+
82+
83+ def parse_conditional_replacement (option : str ) -> tuple [List [str ], str ]:
84+ """Parse a conditional replacement rule like 'word a + word b : word c'."""
85+ conditions , replacement = option .split (":" )
86+ required_words = [word .strip () for word in conditions .split ("+" )]
87+ return required_words , replacement .strip ()
8388
8489
8590def replace_keywords_in_string (
86- text : str , keyword_replacements : Dict [str , str ]
87- ) -> tuple [str , Dict [str , Union [str , List [str ]]]]: # Changed return type
91+ text : str ,
92+ keyword_replacements : Dict [str , str ],
93+ conditional_replacements : List [tuple [List [str ], str ]] = [],
94+ ) -> tuple [str , Dict [str , Union [str , List [str ]]]]:
8895 """Keywords for single string."""
89- if not text :
96+ if not text or ( not keyword_replacements and not conditional_replacements ) :
9097 return text , {"original" : text , "resolved" : text }
9198
92- # Create a single regex pattern for all keywords
93- pattern = "|" .join (map (re .escape , keyword_replacements .keys ()))
94- regex = re .compile (f"\\ b({ pattern } )\\ b" )
99+ result = text
100+
101+ # First check conditional replacements
102+ for required_words , replacement in conditional_replacements :
103+ # Check if all required words are present
104+ if all (word .lower () in text .lower () for word in required_words ):
105+ # Create a pattern that matches any of the required words
106+ pattern = "|" .join (map (re .escape , required_words ))
107+ # Replace all occurrences of the required words with the replacement
108+ result = re .sub (
109+ f"\\ b({ pattern } )\\ b" , replacement , result , flags = re .IGNORECASE
110+ )
95111
96- # Single pass replacement
97- result = regex .sub (lambda m : keyword_replacements [m .group ()], text )
112+ # Then do normal replacements
113+ if keyword_replacements :
114+ pattern = "|" .join (map (re .escape , keyword_replacements .keys ()))
115+ regex = re .compile (f"\\ b({ pattern } )\\ b" )
116+ result = regex .sub (lambda m : keyword_replacements [m .group ()], result )
98117
99- # Only return transformation if something changed
100118 if result != text :
101119 return result , {"original" : text , "resolved" : result }
102120 return text , {"original" : text , "resolved" : text }
@@ -131,11 +149,13 @@ async def process_query(
131149 result_chunks = []
132150
133151 if format in ["str" , "str_array" ]:
134-
135- # Extract and apply keyword replacements from all resolve_entity rules
152+ # Extract rules by type
136153 resolve_entity_rules = [
137154 rule for rule in rules if rule .type == "resolve_entity"
138155 ]
156+ conditional_rules = [
157+ rule for rule in rules if rule .type == "resolve_conditional"
158+ ]
139159
140160 result_chunks = (
141161 []
@@ -144,28 +164,43 @@ async def process_query(
144164 else chunks
145165 )
146166
147- # First populate the replacements dictionary
148- replacements : Dict [str , str ] = {}
149- if resolve_entity_rules and answer_value :
150- for rule in resolve_entity_rules :
151- if rule .options :
152- rule_replacements = dict (
153- option .split (":" ) for option in rule .options
154- )
155- replacements .update (rule_replacements )
156-
157- # Then apply the replacements if we have any
158- if replacements :
167+ # Process both types of replacements if we have an answer
168+ if answer_value and (resolve_entity_rules or conditional_rules ):
169+ # Build regular replacements dictionary
170+ replacements : Dict [str , str ] = {}
171+ if resolve_entity_rules :
172+ for rule in resolve_entity_rules :
173+ if rule .options :
174+ rule_replacements = dict (
175+ option .split (":" ) for option in rule .options
176+ )
177+ replacements .update (rule_replacements )
178+
179+ # Build conditional replacements list
180+ conditional_replacements : List [tuple [List [str ], str ]] = []
181+ if conditional_rules :
182+ for rule in conditional_rules :
183+ if rule .options :
184+ for option in rule .options :
185+ required_words , replacement = (
186+ parse_conditional_replacement (option )
187+ )
188+ conditional_replacements .append (
189+ (required_words , replacement )
190+ )
191+
192+ # Apply replacements if we have any
193+ if replacements or conditional_replacements :
159194 print (f"Resolving entities in answer: { answer_value } " )
160195 if isinstance (answer_value , list ):
161196 transformed_list , transform_dict = replace_keywords (
162- answer_value , replacements
197+ answer_value , replacements , conditional_replacements
163198 )
164199 transformations = transform_dict
165200 answer_value = transformed_list
166201 else :
167202 transformed_value , transform_dict = replace_keywords (
168- answer_value , replacements
203+ answer_value , replacements , conditional_replacements
169204 )
170205 transformations = transform_dict
171206 answer_value = transformed_value
@@ -256,31 +291,47 @@ async def inference_query(
256291 llm_service : CompletionService ,
257292) -> QueryResult :
258293 """Generate a response, no need for vector retrieval."""
259- # Since we are just answering this query based on data provided in the query,
260- # ther is no need to retrieve any chunks from the vector database.
261-
262294 answer = await generate_inferred_response (
263295 llm_service , query , rules , format
264296 )
265297 answer_value = answer ["answer" ]
266298
267- # Extract and apply keyword replacements from all resolve_entity rules
299+ # Extract rules by type
268300 resolve_entity_rules = [
269301 rule for rule in rules if rule .type == "resolve_entity"
270302 ]
303+ conditional_rules = [
304+ rule for rule in rules if rule .type == "resolve_conditional"
305+ ]
271306
272- if resolve_entity_rules and answer_value :
273- # Combine all replacements from all resolve_entity rules
307+ if answer_value and ( resolve_entity_rules or conditional_rules ) :
308+ # Build regular replacements
274309 replacements = {}
275- for rule in resolve_entity_rules :
276- if rule .options :
277- rule_replacements = dict (
278- option .split (":" ) for option in rule .options
279- )
280- replacements .update (rule_replacements )
310+ if resolve_entity_rules :
311+ for rule in resolve_entity_rules :
312+ if rule .options :
313+ rule_replacements = dict (
314+ option .split (":" ) for option in rule .options
315+ )
316+ replacements .update (rule_replacements )
281317
282- if replacements :
318+ # Build conditional replacements
319+ conditional_replacements = []
320+ if conditional_rules :
321+ for rule in conditional_rules :
322+ if rule .options :
323+ for option in rule .options :
324+ required_words , replacement = (
325+ parse_conditional_replacement (option )
326+ )
327+ conditional_replacements .append (
328+ (required_words , replacement )
329+ )
330+
331+ if replacements or conditional_replacements :
283332 print (f"Resolving entities in answer: { answer_value } " )
284- answer_value = replace_keywords (answer_value , replacements )
333+ answer_value , _ = replace_keywords (
334+ answer_value , replacements , conditional_replacements
335+ )
285336
286337 return QueryResult (answer = answer_value , chunks = [])
0 commit comments