Skip to content

Commit 2c92e5e

Browse files
chore: fix excluded fields
1 parent 5a82a4e commit 2c92e5e

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

gemma_template/models.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def to_text(self, **kwargs) -> dict:
655655
**kwargs: see also `Template.apply_template`.
656656
"""
657657

658-
input_str, output_str, attr = self._build_template(**kwargs,)
658+
input_str, output_str, attr = self._build_template(**kwargs)
659659

660660
text = JinjaTemplate.from_string(self._position_value("template")).render(
661661
input=input_str,
@@ -677,7 +677,7 @@ def to_alpaca(self, **kwargs) -> dict:
677677
**kwargs: see also `Template.apply_template`.
678678
"""
679679

680-
input_str, output_str, attr = self._build_template(**kwargs,)
680+
input_str, output_str, attr = self._build_template(**kwargs)
681681

682682
return dict(
683683
instruction="\n\n".join(
@@ -714,7 +714,7 @@ def to_openai(self, **kwargs) -> dict:
714714
**kwargs: see also `Template.apply_template`.
715715
"""
716716

717-
input_str, output_str, attr = self._build_template(**kwargs,)
717+
input_str, output_str, attr = self._build_template(**kwargs)
718718

719719
return dict(
720720
messages=[
@@ -896,7 +896,7 @@ def get_template_attr(self, **kwargs) -> Attr:
896896
dict(
897897
system_prompt=system_prompt,
898898
prompt=prompt,
899-
prompt_structure=self._build_prompt_structure(structure_fields, prompt),
899+
prompt_structure=self._build_prompt_structure(structure_fields, prompt, **kwargs),
900900
instruction=self._build_instruction(document, analysis, **kwargs),
901901
structure_fields=structure_fields,
902902
input=document,
@@ -971,9 +971,10 @@ def _build_prompt_structure(
971971
**kwargs,
972972
) -> str:
973973
if self.prompt_template:
974+
excluded_fields = kwargs.get("excluded_fields", [])
974975
return (
975976
JinjaTemplate.from_string(self._position_value("prompt_template"))
976-
.render(prompt=prompt, structure_fields=structure_fields)
977+
.render(prompt=prompt, structure_fields=[field for field in structure_fields if field.key not in excluded_fields])
977978
.strip()
978979
)
979980
return ""

0 commit comments

Comments
 (0)