Skip to content

Commit 8a64896

Browse files
chore: fix OpenAI GPT template
1 parent 3ba6107 commit 8a64896

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

gemma_template/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
__url__ = "https://github.com/thewebscraping/gemma-template"
44
__author__ = "Tu Pham"
55
__author_email__ = "[email protected]"
6-
__version__ = "0.1.2"
6+
__version__ = "0.1.3"
77
__license__ = "Apache-2.0"

gemma_template/models.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import json
5+
from math import ceil
56
from pathlib import Path
67
from string import punctuation
78
from typing import (Callable, ClassVar, Literal, Optional, Sequence, Union,
@@ -124,8 +125,6 @@ class StructureField(BaseTemplate):
124125
"description": [
125126
"Description",
126127
"Introduction",
127-
"Summary",
128-
"Intro",
129128
"Meta Description",
130129
],
131130
"document": ["Article", "Edit Article"],
@@ -401,6 +400,7 @@ def load_dataset(
401400
min_chars_length: int = 2,
402401
max_chars_length: int = 0,
403402
max_concurrency: int = 4,
403+
is_remove_data: bool = True,
404404
is_close_async_loop: bool = True,
405405
**kwargs,
406406
) -> Union[Dataset, DatasetDict]:
@@ -436,6 +436,8 @@ def load_dataset(
436436
Maximum character length of a word, used to create unigrams, bigrams, and trigrams. Default is 0.
437437
max_concurrency (int):
438438
Maximum number of concurrent threads for processing data. Default is 4.
439+
is_remove_data (bool):
440+
If True, the original data is removed from the dataset; otherwise it is kept under the `data` field in the dataset.
439441
is_close_async_loop (bool):
440442
By default, the asyncio event loop is closed after each dataset processing run.
441443
Although the `RuntimeError` exception is handled, you should set this to False when running on Kaggle Notebooks or Colab.
@@ -470,14 +472,15 @@ def load_dataset(
470472
```
471473
""" # noqa: E501
472474

473-
async def create_task(config, hidden_count: int = 0):
475+
async def create_task(config, max_hidden_count: int = 0, hidden_count: int = 0):
474476
async with semaphore:
475477
config.update(kwargs)
476478
config.update(
477479
dict(
478480
min_chars_length=min_chars_length,
479481
max_chars_length=max_chars_length,
480482
excluded_fields=excluded_fields,
483+
is_remove_data=is_remove_data,
481484
)
482485
)
483486
if max_hidden_ratio > 0 and hidden_count < max_hidden_count:
@@ -518,9 +521,10 @@ async def create_task(config, hidden_count: int = 0):
518521
hidden_count += 1
519522

520523
async def run_task(ds):
524+
max_hidden_count = ceil(len(ds) * max_hidden_ratio)
521525
await asyncio.wait(
522526
[
523-
loop.create_task(create_task(config, idx))
527+
loop.create_task(create_task(config, max_hidden_count, idx))
524528
for idx, config in enumerate(ds)
525529
]
526530
)
@@ -558,8 +562,6 @@ def _close():
558562
)
559563
)
560564

561-
items = []
562-
max_hidden_count = int(len(dataset) * max_hidden_ratio)
563565
try:
564566
loop = asyncio.get_running_loop()
565567
except RuntimeError:
@@ -580,7 +582,6 @@ def _close():
580582
with tqdm(total=len(dataset)) as pbar:
581583
for field in dataset.column_names:
582584
items = []
583-
max_hidden_count = int(len(dataset[field]) * max_hidden_ratio)
584585
_ = loop.run_until_complete(run_task(dataset[field]))
585586
mapping[field] = Dataset.from_list(items)
586587

@@ -836,8 +837,7 @@ def generate_user_prompt(
836837

837838
def generate_model_prompt(
838839
self,
839-
structure_template: Optional[TemplateTypes] = None,
840-
excluded_fields: Optional[Sequence[str]] = (),
840+
structure_template: Optional[TemplateTypes] = "",
841841
bullet_style: Optional[Union[str, Literal["dash", "number"]]] = "dash",
842842
**kwargs,
843843
) -> str:
@@ -849,7 +849,6 @@ def generate_model_prompt(
849849
850850
Args:
851851
structure_template (Optional[Union[str, Callable]]): A structure template defining the generating structure prompt.
852-
excluded_fields (Sequence[str]): Fields excluded to response. Default is empty sequence.
853852
bullet_style (Optional[str]): Bullet list style start dash or number. Default is dash.
854853
**kwargs: See also `Template.template`.
855854
@@ -866,11 +865,6 @@ def generate_model_prompt(
866865
""" # noqa: E501
867866

868867
output_document = kwargs.get("output", "")
869-
if excluded_fields:
870-
for excluded_field in excluded_fields:
871-
if excluded_field in kwargs:
872-
kwargs.pop(excluded_field)
873-
874868
if isinstance(structure_template, (str, Callable)):
875869
kwargs["document"] = output_document
876870
if isinstance(structure_template, Callable):
@@ -916,6 +910,7 @@ def to_text(
916910
language_code=user_kwargs.get("language_code", "auto"),
917911
language=user_kwargs.get("language"),
918912
is_masked=bool(user_kwargs.get("is_masked")),
913+
data=self._get_origin_data(**kwargs),
919914
)
920915

921916
def to_alpaca(
@@ -941,6 +936,7 @@ def to_alpaca(
941936
language_code=user_kwargs.get("language_code", "auto"),
942937
language=user_kwargs.get("language"),
943938
is_masked=bool(user_kwargs.get("is_masked")),
939+
data=self._get_origin_data(**kwargs),
944940
)
945941

946942
def to_openai(
@@ -955,8 +951,16 @@ def to_openai(
955951
user_template, instruction_template, structure_template, **kwargs
956952
)
957953
return dict(
958-
human=user_template,
959-
gpt=model_template,
954+
conversations=[
955+
{
956+
"from": "human",
957+
"value": user_template,
958+
},
959+
{
960+
"from": "gpt",
961+
"value": model_template,
962+
},
963+
],
960964
is_instructed=bool(instruction_template is not None),
961965
is_structured=bool(structure_template is not None),
962966
unigrams=user_kwargs.get("unigrams", []) or [],
@@ -965,6 +969,7 @@ def to_openai(
965969
language_code=user_kwargs.get("language_code", "auto"),
966970
language=user_kwargs.get("language"),
967971
is_masked=bool(user_kwargs.get("is_masked")),
972+
data=self._get_origin_data(**kwargs),
968973
)
969974

970975
def _get_template(
@@ -1069,10 +1074,14 @@ def _ftm_template(word):
10691074
def _formatting_structure_user_fn(
10701075
self,
10711076
structure_template: str = STRUCTURE_TEMPLATE,
1077+
excluded_fields: Sequence[str] = (),
10721078
**kwargs,
10731079
) -> str:
10741080
prompts = []
1075-
for _, data in self._get_structure_attrs(**kwargs).items():
1081+
for field, data in self._get_structure_attrs(**kwargs).items():
1082+
if excluded_fields and field in excluded_fields:
1083+
continue
1084+
10761085
prompts.append(
10771086
"{field} {prompt}".format(
10781087
field=data["bold_value"], prompt=data["prompt"]
@@ -1085,6 +1094,7 @@ def _formatting_structure_model_fn(
10851094
self,
10861095
structure_data: dict,
10871096
bullet_style: str = None,
1097+
excluded_fields: Sequence[str] = (),
10881098
*args,
10891099
**kwargs,
10901100
) -> str:
@@ -1097,6 +1107,9 @@ def _formatting_structure_model_fn(
10971107
if field not in kwargs:
10981108
continue
10991109

1110+
if excluded_fields and field in excluded_fields:
1111+
continue
1112+
11001113
value = kwargs[field]
11011114
if not value:
11021115
continue
@@ -1129,6 +1142,11 @@ def _get_structure_attrs(self, **kwargs):
11291142
}
11301143
return mapping
11311144

1145+
def _get_origin_data(self, **kwargs) -> dict:
1146+
if not kwargs.get("is_remove_data", True):
1147+
return {k: v for k, v in kwargs.items() if hasattr(self, k)}
1148+
return {}
1149+
11321150

11331151
gemma_template = Template()
11341152
vietnamese_gemma_template = Template(

0 commit comments

Comments
 (0)