22
33import asyncio
44import json
5+ from math import ceil
56from pathlib import Path
67from string import punctuation
78from typing import (Callable , ClassVar , Literal , Optional , Sequence , Union ,
@@ -124,8 +125,6 @@ class StructureField(BaseTemplate):
124125 "description" : [
125126 "Description" ,
126127 "Introduction" ,
127- "Summary" ,
128- "Intro" ,
129128 "Meta Description" ,
130129 ],
131130 "document" : ["Article" , "Edit Article" ],
@@ -401,6 +400,7 @@ def load_dataset(
401400 min_chars_length : int = 2 ,
402401 max_chars_length : int = 0 ,
403402 max_concurrency : int = 4 ,
403+ is_remove_data : bool = True ,
404404 is_close_async_loop : bool = True ,
405405 ** kwargs ,
406406 ) -> Union [Dataset , DatasetDict ]:
@@ -436,6 +436,8 @@ def load_dataset(
436436 Maximum character of a word, used to create unigrams, bigrams and trigrams. Default is 0.
437437 max_concurrency (int):
438438 Maximum number of concurrent threads for processing data. Default is 4.
439+ is_remove_data (bool):
440+ If True, the original data is removed from the dataset; otherwise it is kept in the dataset under the `data` field. Default is True.
439441 is_close_async_loop (bool):
440442 By default, the asyncio event loop is closed each time processing of the dataset finishes.
441443 The `RuntimeError` exception is already handled; however, you should set this to False when running on Kaggle Notebooks or Colab.
@@ -470,14 +472,15 @@ def load_dataset(
470472 ```
471473 """ # noqa: E501
472474
473- async def create_task (config , hidden_count : int = 0 ):
475+ async def create_task (config , max_hidden_count : int = 0 , hidden_count : int = 0 ):
474476 async with semaphore :
475477 config .update (kwargs )
476478 config .update (
477479 dict (
478480 min_chars_length = min_chars_length ,
479481 max_chars_length = max_chars_length ,
480482 excluded_fields = excluded_fields ,
483+ is_remove_data = is_remove_data ,
481484 )
482485 )
483486 if max_hidden_ratio > 0 and hidden_count < max_hidden_count :
@@ -518,9 +521,10 @@ async def create_task(config, hidden_count: int = 0):
518521 hidden_count += 1
519522
520523 async def run_task (ds ):
524+ max_hidden_count = ceil (len (ds ) * max_hidden_ratio )
521525 await asyncio .wait (
522526 [
523- loop .create_task (create_task (config , idx ))
527+ loop .create_task (create_task (config , max_hidden_count , idx ))
524528 for idx , config in enumerate (ds )
525529 ]
526530 )
@@ -558,8 +562,6 @@ def _close():
558562 )
559563 )
560564
561- items = []
562- max_hidden_count = int (len (dataset ) * max_hidden_ratio )
563565 try :
564566 loop = asyncio .get_running_loop ()
565567 except RuntimeError :
@@ -580,7 +582,6 @@ def _close():
580582 with tqdm (total = len (dataset )) as pbar :
581583 for field in dataset .column_names :
582584 items = []
583- max_hidden_count = int (len (dataset [field ]) * max_hidden_ratio )
584585 _ = loop .run_until_complete (run_task (dataset [field ]))
585586 mapping [field ] = Dataset .from_list (items )
586587
@@ -836,8 +837,7 @@ def generate_user_prompt(
836837
837838 def generate_model_prompt (
838839 self ,
839- structure_template : Optional [TemplateTypes ] = None ,
840- excluded_fields : Optional [Sequence [str ]] = (),
840+ structure_template : Optional [TemplateTypes ] = "" ,
841841 bullet_style : Optional [Union [str , Literal ["dash" , "number" ]]] = "dash" ,
842842 ** kwargs ,
843843 ) -> str :
@@ -849,7 +849,6 @@ def generate_model_prompt(
849849
850850 Args:
851851 structure_template (Optional[Union[str, Callable]]): A structure template defining the generating structure prompt.
852- excluded_fields (Sequence[str]): Fields excluded to response. Default is empty sequence.
853852 bullet_style (Optional[str]): Bullet list style start dash or number. Default is dash.
854853 **kwargs: See also `Template.template`.
855854
@@ -866,11 +865,6 @@ def generate_model_prompt(
866865 """ # noqa: E501
867866
868867 output_document = kwargs .get ("output" , "" )
869- if excluded_fields :
870- for excluded_field in excluded_fields :
871- if excluded_field in kwargs :
872- kwargs .pop (excluded_field )
873-
874868 if isinstance (structure_template , (str , Callable )):
875869 kwargs ["document" ] = output_document
876870 if isinstance (structure_template , Callable ):
@@ -916,6 +910,7 @@ def to_text(
916910 language_code = user_kwargs .get ("language_code" , "auto" ),
917911 language = user_kwargs .get ("language" ),
918912 is_masked = bool (user_kwargs .get ("is_masked" )),
913+ data = self ._get_origin_data (** kwargs ),
919914 )
920915
921916 def to_alpaca (
@@ -941,6 +936,7 @@ def to_alpaca(
941936 language_code = user_kwargs .get ("language_code" , "auto" ),
942937 language = user_kwargs .get ("language" ),
943938 is_masked = bool (user_kwargs .get ("is_masked" )),
939+ data = self ._get_origin_data (** kwargs ),
944940 )
945941
946942 def to_openai (
@@ -955,8 +951,16 @@ def to_openai(
955951 user_template , instruction_template , structure_template , ** kwargs
956952 )
957953 return dict (
958- human = user_template ,
959- gpt = model_template ,
954+ conversations = [
955+ {
956+ "from" : "human" ,
957+ "value" : user_template ,
958+ },
959+ {
960+ "from" : "gpt" ,
961+ "value" : model_template ,
962+ },
963+ ],
960964 is_instructed = bool (instruction_template is not None ),
961965 is_structured = bool (structure_template is not None ),
962966 unigrams = user_kwargs .get ("unigrams" , []) or [],
@@ -965,6 +969,7 @@ def to_openai(
965969 language_code = user_kwargs .get ("language_code" , "auto" ),
966970 language = user_kwargs .get ("language" ),
967971 is_masked = bool (user_kwargs .get ("is_masked" )),
972+ data = self ._get_origin_data (** kwargs ),
968973 )
969974
970975 def _get_template (
@@ -1069,10 +1074,14 @@ def _ftm_template(word):
10691074 def _formatting_structure_user_fn (
10701075 self ,
10711076 structure_template : str = STRUCTURE_TEMPLATE ,
1077+ excluded_fields : Sequence [str ] = (),
10721078 ** kwargs ,
10731079 ) -> str :
10741080 prompts = []
1075- for _ , data in self ._get_structure_attrs (** kwargs ).items ():
1081+ for field , data in self ._get_structure_attrs (** kwargs ).items ():
1082+ if excluded_fields and field in excluded_fields :
1083+ continue
1084+
10761085 prompts .append (
10771086 "{field} {prompt}" .format (
10781087 field = data ["bold_value" ], prompt = data ["prompt" ]
@@ -1085,6 +1094,7 @@ def _formatting_structure_model_fn(
10851094 self ,
10861095 structure_data : dict ,
10871096 bullet_style : str = None ,
1097+ excluded_fields : Sequence [str ] = (),
10881098 * args ,
10891099 ** kwargs ,
10901100 ) -> str :
@@ -1097,6 +1107,9 @@ def _formatting_structure_model_fn(
10971107 if field not in kwargs :
10981108 continue
10991109
1110+ if excluded_fields and field in excluded_fields :
1111+ continue
1112+
11001113 value = kwargs [field ]
11011114 if not value :
11021115 continue
@@ -1129,6 +1142,11 @@ def _get_structure_attrs(self, **kwargs):
11291142 }
11301143 return mapping
11311144
1145+ def _get_origin_data (self , ** kwargs ) -> dict :
1146+ if not kwargs .get ("is_remove_data" , True ):
1147+ return {k : v for k , v in kwargs .items () if hasattr (self , k )}
1148+ return {}
1149+
11321150
11331151gemma_template = Template ()
11341152vietnamese_gemma_template = Template (
0 commit comments