Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,11 +675,6 @@ def read_prompt():
raise click.ClickException(str(ex))
extract = template_obj.extract
extract_last = template_obj.extract_last
# Combine with template fragments/system_fragments
if template_obj.fragments:
fragments = [*template_obj.fragments, *fragments]
if template_obj.system_fragments:
system_fragments = [*template_obj.system_fragments, *system_fragments]
if template_obj.schema_object:
schema = template_obj.schema_object
if template_obj.tools:
Expand Down Expand Up @@ -711,6 +706,12 @@ def read_prompt():
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:
model_id = template_obj.model
# Combine with template fragments/system_fragments AFTER evaluation
# so that any variables in the fragments have been interpolated
if template_obj.fragments:
fragments = [*template_obj.fragments, *fragments]
if template_obj.system_fragments:
system_fragments = [*template_obj.system_fragments, *system_fragments]
# Merge in any attachments
if template_obj.attachments:
attachments = [
Expand Down
15 changes: 15 additions & 0 deletions llm/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,29 @@ def evaluate(
else:
prompt = self.interpolate(self.prompt, params)
system = self.interpolate(self.system, params)

# Interpolate fragments
if self.fragments:
self.fragments = [interpolated for fragment in self.fragments if (interpolated := self.interpolate(fragment, params)) is not None]
if self.system_fragments:
self.system_fragments = [interpolated for fragment in self.system_fragments if (interpolated := self.interpolate(fragment, params)) is not None]

return prompt, system

def vars(self) -> set:
all_vars = set()
# Check prompt and system
for text in [self.prompt, self.system]:
if not text:
continue
all_vars.update(self.extract_vars(string.Template(text)))
# Check fragments and system_fragments
for fragment_list in [self.fragments, self.system_fragments]:
if not fragment_list:
continue
for fragment in fragment_list:
if fragment:
all_vars.update(self.extract_vars(string.Template(fragment)))
return all_vars

@classmethod
Expand Down
89 changes: 54 additions & 35 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
),
),
)
def test_template_evaluate(
prompt, system, defaults, params, expected_prompt, expected_system, expected_error
):
def test_template_evaluate(prompt, system, defaults, params, expected_prompt, expected_system, expected_error):
t = Template(name="t", prompt=prompt, system=system, defaults=defaults)
if expected_error:
with pytest.raises(Template.MissingVariables) as ex:
Expand All @@ -47,6 +45,49 @@ def test_template_evaluate(
assert system == expected_system


def test_template_evaluate_with_fragments():
"""Test that fragments and system_fragments support interpolation"""
t = Template(
name="t",
prompt="Main prompt: $input",
fragments=["Fragment 1: $input", "Fragment 2: $var2"],
system_fragments=["System fragment: $sys_var"],
)
prompt, system = t.evaluate("test input", {"var2": "value2", "sys_var": "sys_value"})

# Check that prompt and system are correctly interpolated
assert prompt == "Main prompt: test input"

# Check that fragments are interpolated
assert t.fragments == ["Fragment 1: test input", "Fragment 2: value2"]
assert t.system_fragments == ["System fragment: sys_value"]


def test_template_evaluate_with_fragments_missing_vars():
"""Test that missing variables in fragments raise an error"""
t = Template(
name="t",
prompt="Main prompt: $input",
fragments=["Fragment with $missing_var"],
)
with pytest.raises(Template.MissingVariables) as ex:
t.evaluate("test input", {})
assert "missing_var" in ex.value.args[0]


def test_template_vars_includes_fragments():
"""Test that the vars() method includes variables from fragments"""
t = Template(
name="t",
prompt="Prompt with $prompt_var",
system="System with $system_var",
fragments=["Fragment with $fragment_var"],
system_fragments=["System fragment with $sys_fragment_var"],
)
vars = t.vars()
assert vars == {"prompt_var", "system_var", "fragment_var", "sys_fragment_var"}


def test_templates_list_no_templates_found():
runner = CliRunner()
result = runner.invoke(cli, ["templates", "list"])
Expand All @@ -58,15 +99,9 @@ def test_templates_list_no_templates_found():
def test_templates_list(templates_path, args):
(templates_path / "one.yaml").write_text("template one", "utf-8")
(templates_path / "two.yaml").write_text("template two", "utf-8")
(templates_path / "three.yaml").write_text(
"template three is very long " * 4, "utf-8"
)
(templates_path / "four.yaml").write_text(
"'this one\n\nhas newlines in it'", "utf-8"
)
(templates_path / "both.yaml").write_text(
"system: summarize this\nprompt: $input", "utf-8"
)
(templates_path / "three.yaml").write_text("template three is very long " * 4, "utf-8")
(templates_path / "four.yaml").write_text("'this one\n\nhas newlines in it'", "utf-8")
(templates_path / "both.yaml").write_text("system: summarize this\nprompt: $input", "utf-8")
(templates_path / "sys.yaml").write_text("system: Summarize this", "utf-8")
(templates_path / "invalid.yaml").write_text("system2: This is invalid", "utf-8")
runner = CliRunner()
Expand Down Expand Up @@ -115,11 +150,7 @@ def test_templates_list(templates_path, args):
"--schema",
'{"properties": {"b": {"type": "string"}, "a": {"type": "string"}}}',
],
{
"schema_object": {
"properties": {"b": {"type": "string"}, "a": {"type": "string"}}
}
},
{"schema_object": {"properties": {"b": {"type": "string"}, "a": {"type": "string"}}}},
None,
),
# And fragments and system_fragments
Expand Down Expand Up @@ -164,9 +195,7 @@ def test_templates_prompt_save(templates_path, args, expected, expected_error):
yaml_data = yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8"))
# Adjust attachment and attachment_types paths to be just the filename
if "attachments" in yaml_data:
yaml_data["attachments"] = [
os.path.basename(path) for path in yaml_data["attachments"]
]
yaml_data["attachments"] = [os.path.basename(path) for path in yaml_data["attachments"]]
for item in yaml_data.get("attachment_types", []):
item["value"] = os.path.basename(item["value"])
assert yaml_data == expected
Expand All @@ -177,18 +206,12 @@ def test_templates_prompt_save(templates_path, args, expected, expected_error):

def test_templates_error_on_missing_schema(templates_path):
runner = CliRunner()
runner.invoke(
cli, ["the-prompt", "--save", "prompt_no_schema"], catch_exceptions=False
)
runner.invoke(cli, ["the-prompt", "--save", "prompt_no_schema"], catch_exceptions=False)
# This should complain about no schema
result = runner.invoke(
cli, ["hi", "--schema", "t:prompt_no_schema"], catch_exceptions=False
)
result = runner.invoke(cli, ["hi", "--schema", "t:prompt_no_schema"], catch_exceptions=False)
assert result.output == "Error: Template 'prompt_no_schema' has no schema\n"
# And this is just an invalid template
result2 = runner.invoke(
cli, ["hi", "--schema", "t:bad_template"], catch_exceptions=False
)
result2 = runner.invoke(cli, ["hi", "--schema", "t:bad_template"], catch_exceptions=False)
assert result2.output == "Error: Invalid template: bad_template\n"


Expand Down Expand Up @@ -316,9 +339,7 @@ def test_execute_prompt_with_a_template(
runner = CliRunner()
result = runner.invoke(
cli,
["--no-stream", "-t", "template"]
+ ([input_text] if input_text else [])
+ extra_args,
["--no-stream", "-t", "template"] + ([input_text] if input_text else []) + extra_args,
catch_exceptions=False,
)
if isinstance(expected_input, str):
Expand Down Expand Up @@ -446,9 +467,7 @@ def register_tools(self, register):
("plugin", True, False),
),
)
def test_tools_in_templates(
source, expected_tool_success, expected_functions_success, httpx_mock, tmpdir
):
def test_tools_in_templates(source, expected_tool_success, expected_functions_success, httpx_mock, tmpdir):
template_yaml = textwrap.dedent(
"""
name: test
Expand Down
Loading