Skip to content

Commit d37b05f

Browse files
committed
Updating to add conditional rules
1 parent 6f3890f commit d37b05f

File tree

7 files changed

+170
-67
lines changed

7 files changed

+170
-67
lines changed

backend/src/app/models/query_core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ class TransformationDict(BaseModel):
3131
class Rule(BaseModel):
3232
"""Rule model."""
3333

34-
type: Literal["must_return", "may_return", "max_length", "resolve_entity"]
34+
type: Literal[
35+
"must_return",
36+
"may_return",
37+
"max_length",
38+
"resolve_entity",
39+
"resolve_conditional",
40+
]
3541
options: Optional[List[str]] = None
3642
length: Optional[int] = None
3743

backend/src/app/services/query_service.py

Lines changed: 105 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,13 @@ def extract_chunks(search_response: SearchResponse) -> List[Chunk]:
4444

4545

4646
def replace_keywords(
47-
text: Union[str, List[str]], keyword_replacements: Dict[str, str]
48-
) -> tuple[
49-
Union[str, List[str]], Dict[str, Union[str, List[str]]]
50-
]: # Changed return type
47+
text: Union[str, List[str]],
48+
keyword_replacements: Dict[str, str],
49+
conditional_replacements: List[tuple[List[str], str]] = [],
50+
) -> tuple[Union[str, List[str]], Dict[str, Union[str, List[str]]]]:
5151
"""Replace keywords in text and return both the modified text and transformation details."""
52-
if not text or not keyword_replacements:
53-
return text, {
54-
"original": text,
55-
"resolved": text,
56-
} # Return dict instead of TransformationDict
52+
if not text or (not keyword_replacements and not conditional_replacements):
53+
return text, {"original": text, "resolved": text}
5754

5855
# Handle list of strings
5956
if isinstance(text, list):
@@ -62,13 +59,12 @@ def replace_keywords(
6259
modified = False
6360

6461
# Create a single regex pattern for all keywords
65-
pattern = "|".join(map(re.escape, keyword_replacements.keys()))
66-
regex = re.compile(f"\\b({pattern})\\b")
62+
# pattern = "|".join(map(re.escape, keyword_replacements.keys()))
63+
# regex = re.compile(f"\\b({pattern})\\b")
6764

6865
for item in text:
69-
# Single pass replacement for all keywords
70-
new_item = regex.sub(
71-
lambda m: keyword_replacements[m.group()], item
66+
new_item, _ = replace_keywords_in_string(
67+
item, keyword_replacements, conditional_replacements
7268
)
7369
result.append(new_item)
7470
if new_item != item:
@@ -79,24 +75,46 @@ def replace_keywords(
7975
return result, {"original": original_text, "resolved": result}
8076

8177
# Handle single string
82-
return replace_keywords_in_string(text, keyword_replacements)
78+
return replace_keywords_in_string(
79+
text, keyword_replacements, conditional_replacements
80+
)
81+
82+
83+
def parse_conditional_replacement(option: str) -> tuple[List[str], str]:
    """Parse a conditional replacement rule like 'word a + word b : word c'.

    The text before the first ':' is a '+'-separated list of words that must
    all be present; the text after it is the replacement.

    Args:
        option: Raw rule option in the form 'cond 1 + cond 2 : replacement'.

    Returns:
        A ``(required_words, replacement)`` tuple with surrounding whitespace
        stripped from every part. Empty condition words are dropped.

    Raises:
        ValueError: If ``option`` contains no ':' separator.
    """
    # Split on the FIRST colon only so the replacement may itself contain
    # colons (e.g. a URL or a time) instead of failing to unpack.
    conditions, separator, replacement = option.partition(":")
    if not separator:
        raise ValueError(
            f"Conditional rule {option!r} must look like "
            "'word a + word b : replacement'"
        )

    # Drop empty words (from a dangling or doubled '+'): an empty word would
    # later become an empty regex alternative that matches everywhere.
    required_words = [
        word.strip() for word in conditions.split("+") if word.strip()
    ]
    return required_words, replacement.strip()
8388

8489

8590
def replace_keywords_in_string(
    text: str,
    keyword_replacements: Dict[str, str],
    conditional_replacements: Union[List[tuple[List[str], str]], None] = None,
) -> tuple[str, Dict[str, Union[str, List[str]]]]:
    """Apply conditional and keyword replacements to a single string.

    Conditional rules run first: when every required word of a rule occurs in
    ``text`` (case-insensitive substring check), each whole-word occurrence of
    any required word is replaced with the rule's replacement. Plain keyword
    replacements (case-sensitive, whole-word) are then applied to the result.

    Args:
        text: The string to transform.
        keyword_replacements: Mapping of keyword -> replacement.
        conditional_replacements: Optional list of
            ``(required_words, replacement)`` rules.

    Returns:
        A ``(resolved_text, transformation)`` tuple where ``transformation``
        is ``{"original": text, "resolved": resolved_text}``.
    """
    # ``None`` instead of a shared mutable ``[]`` default.
    conditional_replacements = conditional_replacements or []

    if not text or not (keyword_replacements or conditional_replacements):
        return text, {"original": text, "resolved": text}

    result = text
    lowered_text = text.lower()  # loop-invariant; presence is checked on the input

    # First apply conditional replacements.
    for required_words, replacement in conditional_replacements:
        # Ignore empty words: an empty alternative in the regex would match
        # at every word boundary and spray the replacement across the text.
        words = [word for word in required_words if word]
        if not words:
            continue
        if all(word.lower() in lowered_text for word in words):
            pattern = "|".join(map(re.escape, words))
            # A callable keeps the replacement literal; passing the string
            # directly would make re.sub interpret backslashes / group refs.
            result = re.sub(
                rf"\b({pattern})\b",
                lambda _match: replacement,
                result,
                flags=re.IGNORECASE,
            )

    # Then apply the normal keyword replacements in a single pass.
    keywords = [keyword for keyword in keyword_replacements if keyword]
    if keywords:
        pattern = "|".join(map(re.escape, keywords))
        regex = re.compile(rf"\b({pattern})\b")
        result = regex.sub(lambda m: keyword_replacements[m.group()], result)

    if result != text:
        return result, {"original": text, "resolved": result}
    return text, {"original": text, "resolved": text}
@@ -131,11 +149,13 @@ async def process_query(
131149
result_chunks = []
132150

133151
if format in ["str", "str_array"]:
134-
135-
# Extract and apply keyword replacements from all resolve_entity rules
152+
# Extract rules by type
136153
resolve_entity_rules = [
137154
rule for rule in rules if rule.type == "resolve_entity"
138155
]
156+
conditional_rules = [
157+
rule for rule in rules if rule.type == "resolve_conditional"
158+
]
139159

140160
result_chunks = (
141161
[]
@@ -144,28 +164,43 @@ async def process_query(
144164
else chunks
145165
)
146166

147-
# First populate the replacements dictionary
148-
replacements: Dict[str, str] = {}
149-
if resolve_entity_rules and answer_value:
150-
for rule in resolve_entity_rules:
151-
if rule.options:
152-
rule_replacements = dict(
153-
option.split(":") for option in rule.options
154-
)
155-
replacements.update(rule_replacements)
156-
157-
# Then apply the replacements if we have any
158-
if replacements:
167+
# Process both types of replacements if we have an answer
168+
if answer_value and (resolve_entity_rules or conditional_rules):
169+
# Build regular replacements dictionary
170+
replacements: Dict[str, str] = {}
171+
if resolve_entity_rules:
172+
for rule in resolve_entity_rules:
173+
if rule.options:
174+
rule_replacements = dict(
175+
option.split(":") for option in rule.options
176+
)
177+
replacements.update(rule_replacements)
178+
179+
# Build conditional replacements list
180+
conditional_replacements: List[tuple[List[str], str]] = []
181+
if conditional_rules:
182+
for rule in conditional_rules:
183+
if rule.options:
184+
for option in rule.options:
185+
required_words, replacement = (
186+
parse_conditional_replacement(option)
187+
)
188+
conditional_replacements.append(
189+
(required_words, replacement)
190+
)
191+
192+
# Apply replacements if we have any
193+
if replacements or conditional_replacements:
159194
print(f"Resolving entities in answer: {answer_value}")
160195
if isinstance(answer_value, list):
161196
transformed_list, transform_dict = replace_keywords(
162-
answer_value, replacements
197+
answer_value, replacements, conditional_replacements
163198
)
164199
transformations = transform_dict
165200
answer_value = transformed_list
166201
else:
167202
transformed_value, transform_dict = replace_keywords(
168-
answer_value, replacements
203+
answer_value, replacements, conditional_replacements
169204
)
170205
transformations = transform_dict
171206
answer_value = transformed_value
@@ -256,31 +291,47 @@ async def inference_query(
256291
llm_service: CompletionService,
257292
) -> QueryResult:
258293
"""Generate a response, no need for vector retrieval."""
259-
# Since we are just answering this query based on data provided in the query,
260-
# ther is no need to retrieve any chunks from the vector database.
261-
262294
answer = await generate_inferred_response(
263295
llm_service, query, rules, format
264296
)
265297
answer_value = answer["answer"]
266298

267-
# Extract and apply keyword replacements from all resolve_entity rules
299+
# Extract rules by type
268300
resolve_entity_rules = [
269301
rule for rule in rules if rule.type == "resolve_entity"
270302
]
303+
conditional_rules = [
304+
rule for rule in rules if rule.type == "resolve_conditional"
305+
]
271306

272-
if resolve_entity_rules and answer_value:
273-
# Combine all replacements from all resolve_entity rules
307+
if answer_value and (resolve_entity_rules or conditional_rules):
308+
# Build regular replacements
274309
replacements = {}
275-
for rule in resolve_entity_rules:
276-
if rule.options:
277-
rule_replacements = dict(
278-
option.split(":") for option in rule.options
279-
)
280-
replacements.update(rule_replacements)
310+
if resolve_entity_rules:
311+
for rule in resolve_entity_rules:
312+
if rule.options:
313+
rule_replacements = dict(
314+
option.split(":") for option in rule.options
315+
)
316+
replacements.update(rule_replacements)
281317

282-
if replacements:
318+
# Build conditional replacements
319+
conditional_replacements = []
320+
if conditional_rules:
321+
for rule in conditional_rules:
322+
if rule.options:
323+
for option in rule.options:
324+
required_words, replacement = (
325+
parse_conditional_replacement(option)
326+
)
327+
conditional_replacements.append(
328+
(required_words, replacement)
329+
)
330+
331+
if replacements or conditional_replacements:
283332
print(f"Resolving entities in answer: {answer_value}")
284-
answer_value = replace_keywords(answer_value, replacements)
333+
answer_value, _ = replace_keywords(
334+
answer_value, replacements, conditional_replacements
335+
)
285336

286337
return QueryResult(answer=answer_value, chunks=[])

frontend/src/components/kt/kt-controls/kt-global-rules.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ export function KTGlobalRules(props: BoxProps) {
127127
max_length,3,
128128
<br />
129129
resolve_entity,"blue:ultramarine,red:crimson",Color
130+
<br />
131+
resolve_conditional,"word a + word b:word c",Words
130132
</Code>
131133
</Box>
132134
</Group>
@@ -247,7 +249,7 @@ export function KTGlobalRules(props: BoxProps) {
247249

248250
const csvJsonSchema = z.array(
249251
z.object({
250-
rule_type: z.enum(["must_return", "may_return", "max_length", "resolve_entity"]),
252+
rule_type: z.enum(["must_return", "may_return", "max_length", "resolve_entity", "resolve_conditional"]),
251253
value: z.string(),
252254
entity_type: z.string().optional()
253255
})

frontend/src/components/kt/kt-table/kt-cells/kt-column-settings/kt-column-settings.tsx

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,31 @@ const rulesMenu = (
157157
}}
158158
/>
159159
</Group>
160+
) : rule.type === "resolve_conditional" ? (
161+
<Group gap="xs" wrap="nowrap">
162+
<TextInput
163+
w={150}
164+
placeholder="word a + word b"
165+
value={rule.options?.[0]?.split(":")[0] ?? ""}
166+
onChange={e => {
167+
const after = rule.options?.[0]?.split(":")[1] ?? "";
168+
handleRuleChange(rule, {
169+
options: [`${e.target.value}:${after}`]
170+
});
171+
}}
172+
/>
173+
<TextInput
174+
w={100}
175+
placeholder="word c"
176+
value={rule.options?.[0]?.split(":")[1] ?? ""}
177+
onChange={e => {
178+
const before = rule.options?.[0]?.split(":")[0] ?? "";
179+
handleRuleChange(rule, {
180+
options: [`${before}:${e.target.value}`]
181+
});
182+
}}
183+
/>
184+
</Group>
160185
) : (
161186
<TagsInput
162187
w={210}

frontend/src/config/store/store.ts

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -420,12 +420,27 @@ export const useStore = create<Store>()(
420420
? entity.original.join(' ')
421421
: entity.original;
422422

423-
return globalRules.some(rule =>
424-
rule.type === 'resolve_entity' &&
425-
rule.options?.some(pattern =>
426-
originalText.toLowerCase().includes(pattern.toLowerCase())
427-
)
428-
);
423+
return globalRules.some(rule => {
424+
// Handle regular resolve_entity rules
425+
if (rule.type === 'resolve_entity') {
426+
return rule.options?.some(pattern =>
427+
originalText.toLowerCase().includes(pattern.split(':')[0].toLowerCase())
428+
);
429+
}
430+
431+
// Handle conditional resolve rules
432+
if (rule.type === 'resolve_conditional') {
433+
return rule.options?.some(pattern => {
434+
const [conditions] = pattern.split(':');
435+
const requiredWords = conditions.split('+').map(word => word.trim());
436+
return requiredWords.every(word =>
437+
originalText.toLowerCase().includes(word.toLowerCase())
438+
);
439+
});
440+
}
441+
442+
return false;
443+
});
429444
};
430445

431446
editTable(activeTableId, {
@@ -451,7 +466,7 @@ export const useStore = create<Store>()(
451466
})),
452467
globalRules: currentTable.globalRules.map(rule => ({
453468
...rule,
454-
resolvedEntities: rule.type === 'resolve_entity'
469+
resolvedEntities: (rule.type === 'resolve_entity' || rule.type === 'resolve_conditional')
455470
? [
456471
...(rule.resolvedEntities || []),
457472
...(resolvedEntities || [])

frontend/src/config/store/store.types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ export interface AnswerTableGlobalRule extends AnswerTableRule {
111111
}
112112

113113
export interface AnswerTableRule {
114-
type: "must_return" | "may_return" | "max_length" | "resolve_entity";
114+
type: "must_return" | "may_return" | "max_length" | "resolve_entity" | "resolve_conditional";
115115
options?: string[];
116116
length?: number;
117117
}

frontend/src/config/store/store.utils.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ export const defaultRules: Record<AnswerTableRule["type"], AnswerTableRule> = {
8080
must_return: { type: "must_return", options: [] },
8181
may_return: { type: "may_return", options: [] },
8282
max_length: { type: "max_length", length: 1 },
83-
resolve_entity: { type: "resolve_entity", options: [] }
83+
resolve_entity: { type: "resolve_entity", options: [] },
84+
resolve_conditional: { type: "resolve_conditional", options: [] }
8485
};
8586

8687
export const ruleOptions: {
@@ -90,15 +91,18 @@ export const ruleOptions: {
9091
{ value: "must_return", label: "Must return" },
9192
{ value: "may_return", label: "May return" },
9293
{ value: "max_length", label: "Allowed # of responses" },
93-
{ value: "resolve_entity", label: "Resolve entity" }
94+
{ value: "resolve_entity", label: "Resolve entity" },
95+
{ value: "resolve_conditional", label: "Resolve conditional" }
9496
];
9597

9698
export const ruleInfo: Record<AnswerTableRule["type"], string> = {
9799
must_return: "The column must return the specified values",
98100
may_return: "The column may return the specified values",
99101
max_length: "The column must return at most N values",
100102
resolve_entity:
101-
"Replace all specified values with the first one from the list (i.e. 'turquioise:blue')"
103+
"Replace all specified values with the first one from the list (i.e. 'turquioise:blue')",
104+
resolve_conditional:
105+
"Replace all specified values with the first one from the list (i.e. 'word a + word b:word c')"
102106
};
103107

104108
// Casting

0 commit comments

Comments
 (0)