Skip to content

Commit 6f3890f

Browse files
committed
Updating for testing and linting
1 parent a093f7f commit 6f3890f

File tree

3 files changed

+86
-48
lines changed

3 files changed

+86
-48
lines changed

backend/src/app/models/query_core.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,29 @@
55
from pydantic import BaseModel
66

77

8+
class EntitySource(BaseModel):
    """Describe where a resolved entity came from.

    A source is a pair of a kind discriminator and an identifier.
    """

    # Kind of source; validation restricts it to "column" or "global".
    type: Literal["column", "global"]
    # Identifier of the source (presumably a column or global-rule id —
    # confirm against the callers that build this model).
    id: str
13+
14+
15+
class ResolvedEntity(BaseModel):
    """Record of one entity resolution: the value before and after, plus origin.

    ``original`` and ``resolved`` each hold either a single string or a list
    of strings.
    """

    # Value before resolution.
    original: Union[str, List[str]]
    # Value after resolution.
    resolved: Union[str, List[str]]
    # Where the resolution came from.
    source: EntitySource
    # Free-form entity type label (camelCase is part of the field name and
    # therefore of the serialized shape).
    entityType: str
22+
23+
24+
class TransformationDict(BaseModel):
    """Before/after pair describing a single text transformation.

    Each side is either one string or a list of strings.
    """

    # Text before the transformation.
    original: Union[str, List[str]]
    # Text after the transformation.
    resolved: Union[str, List[str]]
29+
30+
831
class Rule(BaseModel):
932
"""Rule model."""
1033

backend/src/app/schemas/query_api.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
from app.models.query_core import Chunk, FormatType, Rule
88

9-
class ResolvedEntity(BaseModel):
9+
10+
class ResolvedEntitySchema(BaseModel):
1011
"""Schema for resolved entity transformations."""
12+
1113
original: Union[str, List[str]]
1214
resolved: Union[str, List[str]]
1315
source: dict[str, str]
@@ -46,7 +48,7 @@ class QueryResult(BaseModel):
4648

4749
answer: Any
4850
chunks: List[Chunk]
49-
resolved_entities: Optional[List[ResolvedEntity]] = None
51+
resolved_entities: Optional[List[ResolvedEntitySchema]] = None
5052

5153

5254
class QueryResponseSchema(BaseModel):
@@ -58,7 +60,7 @@ class QueryResponseSchema(BaseModel):
5860
answer: Optional[Any] = None
5961
chunks: List[Chunk]
6062
type: str
61-
resolved_entities: Optional[List[ResolvedEntity]] = None
63+
resolved_entities: Optional[List[ResolvedEntitySchema]] = None
6264

6365

6466
class QueryAnswer(BaseModel):
@@ -76,7 +78,7 @@ class QueryAnswerResponse(BaseModel):
7678

7779
answer: QueryAnswer
7880
chunks: List[Chunk]
79-
resolved_entities: Optional[List[ResolvedEntity]] = None
81+
resolved_entities: Optional[List[ResolvedEntitySchema]] = None
8082

8183

8284
# Type for search responses (used in service layer)

backend/src/app/services/query_service.py

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from typing import Any, Awaitable, Callable, Dict, List, Union
66

77
from app.models.query_core import Chunk, FormatType, QueryType, Rule
8-
from app.schemas.query_api import QueryResult, SearchResponse
8+
from app.schemas.query_api import (
9+
QueryResult,
10+
ResolvedEntitySchema,
11+
SearchResponse,
12+
)
913
from app.services.llm_service import (
1014
CompletionService,
1115
generate_inferred_response,
@@ -40,58 +44,62 @@ def extract_chunks(search_response: SearchResponse) -> List[Chunk]:
4044

4145

4246
def _compile_keyword_regex(
    keyword_replacements: Dict[str, str],
) -> Union[re.Pattern[str], None]:
    """Build one regex matching any non-empty keyword as a whole token.

    Args:
        keyword_replacements: Mapping of keyword -> replacement text.

    Returns:
        A compiled pattern, or ``None`` when the mapping holds no usable
        (non-empty) keyword, so callers can skip substitution entirely.
    """
    # Longest keywords first: regex alternation is ordered, so without this a
    # short key that is a prefix of a longer one ("New York" vs
    # "New York City") would always win and the longer key could never match.
    keywords = sorted(
        (key for key in keyword_replacements if key), key=len, reverse=True
    )
    if not keywords:
        return None

    alternation = "|".join(map(re.escape, keywords))
    # Lookarounds instead of \b: \b needs a word character on the keyword
    # side, so keywords that start or end with punctuation ("C++", "U.S.")
    # would never match. For keywords made of word characters the two forms
    # are equivalent.
    return re.compile(rf"(?<!\w)(?:{alternation})(?!\w)")


def replace_keywords(
    text: Union[str, List[str]], keyword_replacements: Dict[str, str]
) -> tuple[Union[str, List[str]], Dict[str, Union[str, List[str]]]]:
    """Replace keywords in a string or a list of strings.

    Args:
        text: A single string or a list of strings to process.
        keyword_replacements: Mapping of keyword -> replacement text.

    Returns:
        A ``(resolved, transformation)`` tuple. ``transformation`` always has
        the keys ``"original"`` and ``"resolved"``; when nothing was replaced
        the two values are equal.
    """
    if not text or not keyword_replacements:
        return text, {"original": text, "resolved": text}

    if not isinstance(text, list):
        return replace_keywords_in_string(text, keyword_replacements)

    # Snapshot the input so the reported "original" is independent of the
    # caller's list.
    original_items = list(text)
    regex = _compile_keyword_regex(keyword_replacements)
    if regex is None:
        resolved_items = list(text)
    else:
        # One compiled pattern, one pass per item.
        resolved_items = [
            regex.sub(lambda m: keyword_replacements[m.group()], item)
            for item in text
        ]
    return resolved_items, {
        "original": original_items,
        "resolved": resolved_items,
    }


def replace_keywords_in_string(
    text: str, keyword_replacements: Dict[str, str]
) -> tuple[str, Dict[str, Union[str, List[str]]]]:
    """Replace keywords in a single string.

    Args:
        text: The string to process.
        keyword_replacements: Mapping of keyword -> replacement text.

    Returns:
        A ``(resolved, transformation)`` tuple. ``transformation`` always has
        the keys ``"original"`` and ``"resolved"``; they are equal when no
        keyword matched.
    """
    # Guard the empty mapping too: an empty alternation would match the empty
    # string and the replacement lookup would raise KeyError('').
    if not text or not keyword_replacements:
        return text, {"original": text, "resolved": text}

    regex = _compile_keyword_regex(keyword_replacements)
    if regex is None:
        return text, {"original": text, "resolved": text}

    resolved = regex.sub(lambda m: keyword_replacements[m.group()], text)
    return resolved, {"original": text, "resolved": resolved}
95103

96104

97105
async def process_query(
@@ -115,7 +123,10 @@ async def process_query(
115123
)
116124
answer_value = answer["answer"]
117125

118-
transformations: Dict[str, str] = {}
126+
transformations: Dict[str, Union[str, List[str]]] = {
127+
"original": "",
128+
"resolved": "",
129+
}
119130

120131
result_chunks = []
121132

@@ -147,34 +158,36 @@ async def process_query(
147158
if replacements:
148159
print(f"Resolving entities in answer: {answer_value}")
149160
if isinstance(answer_value, list):
150-
# Transform the list but keep track of both original and transformed
151-
transformed_list, _ = replace_keywords(answer_value, replacements)
152-
transformations = {
153-
"original": answer_value, # Keep as list
154-
"resolved": transformed_list # Keep as list
155-
}
161+
transformed_list, transform_dict = replace_keywords(
162+
answer_value, replacements
163+
)
164+
transformations = transform_dict
156165
answer_value = transformed_list
157166
else:
158-
# Handle single string case
159-
transformed_value, _ = replace_keywords(answer_value, replacements)
160-
transformations = {
161-
"original": answer_value,
162-
"resolved": transformed_value
163-
}
167+
transformed_value, transform_dict = replace_keywords(
168+
answer_value, replacements
169+
)
170+
transformations = transform_dict
164171
answer_value = transformed_value
165172

166-
167173
return QueryResult(
168174
answer=answer_value,
169175
chunks=result_chunks[:10],
170-
resolved_entities=[{
171-
"original": transformations["original"],
172-
"resolved": transformations["resolved"],
173-
"source": {"type": "column", "id": "some-id"},
174-
"entityType": "some-type"
175-
}] if transformations else None
176+
resolved_entities=(
177+
[
178+
ResolvedEntitySchema(
179+
original=transformations["original"],
180+
resolved=transformations["resolved"],
181+
source={"type": "column", "id": "some-id"},
182+
entityType="some-type",
183+
)
184+
]
185+
if transformations["original"] or transformations["resolved"]
186+
else None
187+
),
176188
)
177189

190+
178191
# Convenience functions for specific query types
179192
async def decomposition_query(
180193
query: str,

0 commit comments

Comments
 (0)