Skip to content

Commit 691b308

Browse files
Tulio Coppola (tuliocoppola)
authored and committed
feat: req validation generator & template v2
Signed-off-by: Tulio Coppola <[email protected]>
1 parent b8fc8e1 commit 691b308

File tree

25 files changed

+605
-27
lines changed

25 files changed

+605
-27
lines changed

cli/decompose/decompose.py

Lines changed: 28 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import json
22
import keyword
3+
import shutil
34
from enum import Enum
45
from pathlib import Path
56
from typing import Annotated
@@ -14,7 +15,7 @@
1415
class DecompVersion(str, Enum):
1516
latest = "latest"
1617
v1 = "v1"
17-
# v2 = "v2"
18+
v2 = "v2"
1819

1920

2021
this_file_dir = Path(__file__).resolve().parent
@@ -170,23 +171,40 @@ def run(
170171
backend_api_key=backend_api_key,
171172
)
172173

173-
with open(out_dir / f"{out_name}.json", "w") as f:
174+
decomp_dir = out_dir / out_name
175+
val_fn_dir = decomp_dir / "validations"
176+
val_fn_dir.mkdir(parents=True)
177+
178+
(val_fn_dir / "__init__.py").touch()
179+
180+
for constraint in decomp_data["identified_constraints"]:
181+
if constraint["val_fn"] is not None:
182+
with open(val_fn_dir / f"{constraint['val_fn_name']}.py", "w") as f:
183+
f.write(constraint["val_fn"] + "\n")
184+
185+
with open(decomp_dir / f"{out_name}.json", "w") as f:
174186
json.dump(decomp_data, f, indent=2)
175187

176-
with open(out_dir / f"{out_name}.py", "w") as f:
188+
with open(decomp_dir / f"{out_name}.py", "w") as f:
177189
f.write(
178190
m_template.render(
179-
subtasks=decomp_data["subtasks"], user_inputs=input_var
191+
subtasks=decomp_data["subtasks"],
192+
user_inputs=input_var,
193+
identified_constraints=decomp_data["identified_constraints"],
180194
)
181195
+ "\n"
182196
)
183197
except Exception:
184-
created_json = Path(out_dir / f"{out_name}.json")
185-
created_py = Path(out_dir / f"{out_name}.py")
198+
# created_json = Path(out_dir / f"{out_name}.json")
199+
# created_py = Path(out_dir / f"{out_name}.py")
200+
201+
# if created_json.exists() and created_json.is_file():
202+
# created_json.unlink()
203+
# if created_py.exists() and created_py.is_file():
204+
# created_py.unlink()
186205

187-
if created_json.exists() and created_json.is_file():
188-
created_json.unlink()
189-
if created_py.exists() and created_py.is_file():
190-
created_py.unlink()
206+
decomp_dir = out_dir / out_name
207+
if decomp_dir.exists() and decomp_dir.is_dir():
208+
shutil.rmtree(decomp_dir)
191209

192210
raise Exception

cli/decompose/m_decomp_result_v1.py.jinja2

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,14 @@ import os
44
import textwrap
55

66
import mellea
7+
{%- if "code" in identified_constraints | map(attribute="val_strategy") %}
8+
from mellea.stdlib.requirement import req
9+
{% for c in identified_constraints %}
10+
{%- if c.val_fn %}
11+
from validations.{{ c.val_fn_name }} import validate_input as {{ c.val_fn_name }}
12+
{%- endif %}
13+
{%- endfor %}
14+
{%- endif %}
715

816
m = mellea.start_session()
917
{%- if user_inputs %}
@@ -30,7 +38,14 @@ except KeyError as e:
3038
{%- if item.constraints %}
3139
requirements=[
3240
{%- for c in item.constraints %}
41+
{%- if c.val_fn %}
42+
req(
43+
{{ c.constraint | tojson}},
44+
validation_fn={{ c.val_fn_name }},
45+
),
46+
{%- else %}
3347
{{ c.constraint | tojson}},
48+
{%- endif %}
3449
{%- endfor %}
3550
],
3651
{%- else %}
Lines changed: 91 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,91 @@
1+
{% if user_inputs -%}
2+
import os
3+
{% endif -%}
4+
import textwrap
5+
6+
import mellea
7+
{%- if "code" in identified_constraints | map(attribute="val_strategy") %}
8+
from mellea.stdlib.requirement import req
9+
{% for c in identified_constraints %}
10+
{%- if c.val_fn %}
11+
from validations.{{ c.val_fn_name }} import validate_input as {{ c.val_fn_name }}
12+
{%- endif %}
13+
{%- endfor %}
14+
{%- endif %}
15+
16+
m = mellea.start_session()
17+
{%- if user_inputs %}
18+
19+
20+
# User Input Variables
21+
try:
22+
{%- for var in user_inputs %}
23+
{{ var | lower }} = os.environ["{{ var | upper }}"]
24+
{%- endfor %}
25+
except KeyError as e:
26+
print(f"ERROR: One or more required environment variables are not set; {e}")
27+
exit(1)
28+
{%- endif %}
29+
{%- for item in subtasks %}
30+
31+
32+
{{ item.tag | lower }}_gnrl = textwrap.dedent(
33+
R"""
34+
{{ item.general_instructions | trim | indent(width=4, first=False) }}
35+
""".strip()
36+
)
37+
{{ item.tag | lower }} = m.instruct(
38+
{%- if not item.input_vars_required %}
39+
{{ item.subtask[3:] | trim | tojson }},
40+
{%- else %}
41+
textwrap.dedent(
42+
R"""
43+
{{ item.subtask[3:] | trim }}
44+
45+
Here are the input variables and their content:
46+
{%- for var in item.input_vars_required %}
47+
48+
- {{ var | upper }} = {{ "{{" }}{{ var | upper }}{{ "}}" }}
49+
{%- endfor %}
50+
""".strip()
51+
),
52+
{%- endif %}
53+
{%- if item.constraints %}
54+
requirements=[
55+
{%- for c in item.constraints %}
56+
{%- if c.val_fn %}
57+
req(
58+
{{ c.constraint | tojson}},
59+
validation_fn={{ c.val_fn_name }},
60+
),
61+
{%- else %}
62+
{{ c.constraint | tojson}},
63+
{%- endif %}
64+
{%- endfor %}
65+
],
66+
{%- else %}
67+
requirements=None,
68+
{%- endif %}
69+
{%- if item.input_vars_required %}
70+
user_variables={
71+
{%- for var in item.input_vars_required %}
72+
{{ var | upper | tojson }}: {{ var | lower }},
73+
{%- endfor %}
74+
},
75+
{%- endif %}
76+
grounding_context={
77+
"GENERAL_INSTRUCTIONS": {{ item.tag | lower }}_gnrl,
78+
{%- for var in item.depends_on %}
79+
{{ var | upper | tojson }}: {{ var | lower }}.value,
80+
{%- endfor %}
81+
},
82+
)
83+
assert {{ item.tag | lower }}.value is not None, 'ERROR: task "{{ item.tag | lower }}" execution failed'
84+
{%- if loop.last %}
85+
86+
87+
final_answer = {{ item.tag | lower }}.value
88+
89+
print(final_answer)
90+
{%- endif -%}
91+
{%- endfor -%}

cli/decompose/pipeline.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,36 @@
1111

1212
from .prompt_modules import (
1313
constraint_extractor,
14-
# general_instructions,
14+
general_instructions,
1515
subtask_constraint_assign,
1616
subtask_list,
1717
subtask_prompt_generator,
18+
validation_code_generator,
1819
validation_decision,
1920
)
2021
from .prompt_modules.subtask_constraint_assign import SubtaskPromptConstraintsItem
2122
from .prompt_modules.subtask_list import SubtaskItem
2223
from .prompt_modules.subtask_prompt_generator import SubtaskPromptItem
2324

2425

26+
class ConstraintValData(TypedDict):
27+
val_strategy: Literal["code", "llm"]
28+
val_fn: str | None
29+
30+
2531
class ConstraintResult(TypedDict):
2632
constraint: str
27-
validation_strategy: str
33+
val_strategy: Literal["code", "llm"]
34+
val_fn: str | None
35+
val_fn_name: str
2836

2937

3038
class DecompSubtasksResult(TypedDict):
3139
subtask: str
3240
tag: str
3341
constraints: list[ConstraintResult]
3442
prompt_template: str
35-
# general_instructions: str
43+
general_instructions: str
3644
input_vars_required: list[str]
3745
depends_on: list[str]
3846
generated_response: NotRequired[str]
@@ -72,7 +80,9 @@ def decompose(
7280
case DecompBackend.ollama:
7381
m_session = MelleaSession(
7482
OllamaModelBackend(
75-
model_id=model_id, model_options={ModelOption.CONTEXT_WINDOW: 16384}
83+
model_id=model_id,
84+
base_url=backend_endpoint,
85+
model_options={ModelOption.CONTEXT_WINDOW: 16384},
7686
)
7787
)
7888
case DecompBackend.openai:
@@ -115,11 +125,27 @@ def decompose(
115125
m_session, task_prompt, enforce_same_words=False
116126
).parse()
117127

118-
constraint_validation_strategies: dict[str, Literal["code", "llm"]] = {
119-
cons_key: validation_decision.generate(m_session, cons_key).parse()
128+
constraint_val_strategy: dict[
129+
str, dict[Literal["val_strategy"], Literal["code", "llm"]]
130+
] = {
131+
cons_key: {
132+
"val_strategy": validation_decision.generate(m_session, cons_key).parse()
133+
}
120134
for cons_key in task_prompt_constraints
121135
}
122136

137+
constraint_val_data: dict[str, ConstraintValData] = {}
138+
139+
for cons_key in constraint_val_strategy:
140+
constraint_val_data[cons_key] = {
141+
"val_strategy": constraint_val_strategy[cons_key]["val_strategy"],
142+
"val_fn": None,
143+
}
144+
if constraint_val_data[cons_key]["val_strategy"] == "code":
145+
constraint_val_data[cons_key]["val_fn"] = (
146+
validation_code_generator.generate(m_session, cons_key).parse()
147+
)
148+
123149
subtask_prompts: list[SubtaskPromptItem] = subtask_prompt_generator.generate(
124150
m_session,
125151
task_prompt,
@@ -142,14 +168,21 @@ def decompose(
142168
constraints=[
143169
{
144170
"constraint": cons_str,
145-
"validation_strategy": constraint_validation_strategies[cons_str],
171+
"val_strategy": constraint_val_data[cons_str]["val_strategy"],
172+
"val_fn_name": f"val_fn_{task_prompt_constraints.index(cons_str) + 1}",
173+
# >> Always include generated "val_fn" code (experimental)
174+
"val_fn": constraint_val_data[cons_str]["val_fn"],
175+
# >> Include generated "val_fn" code only for the last subtask (experimental)
176+
# "val_fn": constraint_val_data[cons_str]["val_fn"]
177+
# if subtask_i + 1 == len(subtask_prompts_with_constraints)
178+
# else None,
146179
}
147180
for cons_str in subtask_data.constraints
148181
],
149182
prompt_template=subtask_data.prompt_template,
150-
# general_instructions=general_instructions.generate(
151-
# m_session, input_str=subtask_data.prompt_template
152-
# ).parse(),
183+
general_instructions=general_instructions.generate(
184+
m_session, input_str=subtask_data.prompt_template
185+
).parse(),
153186
input_vars_required=list(
154187
dict.fromkeys( # Remove duplicates while preserving the original order.
155188
[
@@ -173,7 +206,7 @@ def decompose(
173206
)
174207
),
175208
)
176-
for subtask_data in subtask_prompts_with_constraints
209+
for subtask_i, subtask_data in enumerate(subtask_prompts_with_constraints)
177210
]
178211

179212
return DecompPipelineResult(
@@ -182,9 +215,11 @@ def decompose(
182215
identified_constraints=[
183216
{
184217
"constraint": cons_str,
185-
"validation_strategy": constraint_validation_strategies[cons_str],
218+
"val_strategy": constraint_val_data[cons_str]["val_strategy"],
219+
"val_fn": constraint_val_data[cons_str]["val_fn"],
220+
"val_fn_name": f"val_fn_{cons_i + 1}",
186221
}
187-
for cons_str in task_prompt_constraints
222+
for cons_i, cons_str in enumerate(task_prompt_constraints)
188223
],
189224
subtasks=decomp_subtask_result,
190225
)

cli/decompose/prompt_modules/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@
77
from .subtask_prompt_generator import (
88
subtask_prompt_generator as subtask_prompt_generator,
99
)
10+
from .validation_code_generator import (
11+
validation_code_generator as validation_code_generator,
12+
)
1013
from .validation_decision import validation_decision as validation_decision

cli/decompose/prompt_modules/subtask_constraint_assign/_prompt/system_template.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ You will be provided with the following 4 parameters inside their respective tag
88
4. <all_constraints> : A list of candidate (possible) constraints that can be assigned to the target task.
99
</parameters>
1010

11-
The <all_constraints> list contain the constraints of all tasks on the <execution_plan>, your job is to filter and select only the constraints belonging to your target task.
11+
The <all_constraints> is a list of constraints identified for the entire <execution_plan>, your job is to filter and select only the constraints belonging to your target task.
1212
It is possible that none of the constraints in the <all_constraints> are relevant or related to your target task.
1313

1414
Below, enclosed in <general_instructions> tags, are instructions to guide you on how to complete your assignment:
Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
from ._exceptions import (
2+
BackendGenerationError as BackendGenerationError,
3+
TagExtractionError as TagExtractionError,
4+
)
5+
from ._validation_code_generator import (
6+
validation_code_generator as validation_code_generator,
7+
)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Any
2+
3+
4+
class ValidationCodeGeneratorError(Exception):
5+
def __init__(self, error_message: str, **kwargs: dict[str, Any]):
6+
self.error_message = error_message
7+
self.__dict__.update(kwargs)
8+
super().__init__(
9+
f'Module Error "validation_code_generator"; {self.error_message}'
10+
)
11+
12+
13+
class BackendGenerationError(ValidationCodeGeneratorError):
14+
"""Raised when LLM generation fails in the "validation_code_generator" prompt module."""
15+
16+
def __init__(self, error_message: str, **kwargs: dict[str, Any]):
17+
super().__init__(error_message, **kwargs)
18+
19+
20+
class TagExtractionError(ValidationCodeGeneratorError):
21+
"""Raised when tag extraction fails in the "validation_code_generator" prompt module."""
22+
23+
def __init__(self, error_message: str, **kwargs: dict[str, Any]):
24+
super().__init__(error_message, **kwargs)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ._icl_examples import icl_examples as default_icl_examples
2+
from ._prompt import (
3+
get_system_prompt as get_system_prompt,
4+
get_user_prompt as get_user_prompt,
5+
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from ._icl_examples import icl_examples as icl_examples
2+
from ._types import ICLExample as ICLExample

0 commit comments

Comments
 (0)