@@ -1959,41 +1959,61 @@ def _parse_combined_prompt(combined_prompt, dataset):
19591959
19601960
19611961def _create_formatter (possible_columns , final_optional_prompts , user_column_name ):
1962- # Start final prompt!
1963- function = ["def __combined_prompt_processor__(examples):" ]
1964- columns = list (set (possible_columns ))
1965- for column in columns :
1966- function .append (f"{ ' ' * 4 } { column } __ = examples['{ column } ']" )
1967- function .append (f"{ ' ' * 4 } texts = []" )
1968- function .append (f"{ ' ' * 4 } for ({ ', ' .join (columns )} ) in zip({ ', ' .join (f'{ x } __' for x in columns )} ):" )
1969-
1970- # Add optional tags as well!
1971- final_prompt = ""
1972- formatter = []
1962+ columns = list (dict .fromkeys (possible_columns ))
1963+ merged_prompt_parts = []
1964+ formatter_templates = []
19731965
19741966 for j , optional_prompt in enumerate (final_optional_prompts ):
19751967 if type (optional_prompt ) is str :
1976- columns = re .findall (r"\{(.+?)\}" , optional_prompt )
1977- formatter += columns
1978- # Must escape \n \r
1979- final_prompt += optional_prompt .encode ("unicode-escape" ).decode ("utf-8" ).replace ("'" , "\\ '" ).replace ('"' , '\\ "' )
1980- else :
1981- where , prompt = optional_prompt
1982- # Strip [[...]]
1983- # Must escape \n \r
1984- prompt = prompt [2 :- 2 ].encode ("unicode-escape" ).decode ("utf-8" ).replace ("'" , "\\ '" ).replace ('"' , '\\ "' )
1985- columns = re .findall (r"\{(.+?)\}" , prompt )
1986- x = f"__optional_{ j } __"
1987- prompt = f"{ ' ' * 8 } { x } = '{ prompt } '.format({ ', ' .join (f'{ x } = { x } ' for x in columns )} ) if { columns [0 ]} else ''"
1988- function .append (prompt )
1989- formatter .append (x )
1990- final_prompt += "{" + x + "}"
1991-
1992- function .insert (1 , f"{ ' ' * 4 } __combined_prompt__ = '{ final_prompt } '" )
1993- function .append (f"{ ' ' * 8 } texts.append(" \
1994- f"__combined_prompt__.format({ ', ' .join (f'{ x } = { x } ' for x in formatter )} ))" )
1995- function .append (f"{ ' ' * 4 } return " + "{ " + f"'{ user_column_name } ' : texts" + " }" )
1996- return "\n " .join (function )
1968+ needed_columns = re .findall (r"\{(.+?)\}" , optional_prompt )
1969+ formatter_templates .append (("required" , optional_prompt , needed_columns ))
1970+ merged_prompt_parts .append (optional_prompt )
1971+ continue
1972+
1973+ _ , prompt = optional_prompt
1974+ prompt = prompt [2 :- 2 ]
1975+ needed_columns = re .findall (r"\{(.+?)\}" , prompt )
1976+ if len (needed_columns ) == 0 :
1977+ raise IndexError ("Unsloth: Optional [[...]] blocks must contain at least 1 {column}." )
1978+ optional_name = f"__optional_{ j } __"
1979+ formatter_templates .append (("optional" , optional_name , prompt , needed_columns ))
1980+ merged_prompt_parts .append ("{" + optional_name + "}" )
1981+
1982+ merged_prompt = "" .join (merged_prompt_parts )
1983+
1984+ def __combined_prompt_processor__ (examples ):
1985+ if len (examples ) == 0 :
1986+ return {user_column_name : []}
1987+
1988+ first_key = next (iter (examples .keys ()), None )
1989+ if first_key is None :
1990+ return {user_column_name : []}
1991+ n_rows = len (examples [first_key ])
1992+
1993+ texts = []
1994+ for row_idx in range (n_rows ):
1995+ row_values = {column : examples [column ][row_idx ] for column in columns }
1996+ formatter_values = {}
1997+
1998+ for formatter_template in formatter_templates :
1999+ if formatter_template [0 ] == "required" :
2000+ _ , _ , needed_columns = formatter_template
2001+ for column in needed_columns :
2002+ formatter_values [column ] = row_values [column ]
2003+ continue
2004+
2005+ _ , optional_name , prompt , needed_columns = formatter_template
2006+ if row_values [needed_columns [0 ]] not in (None , "" ):
2007+ prompt_values = {column : row_values [column ] for column in needed_columns }
2008+ formatter_values [optional_name ] = prompt .format (** prompt_values )
2009+ else :
2010+ formatter_values [optional_name ] = ""
2011+
2012+ texts .append (merged_prompt .format (** formatter_values ))
2013+
2014+ return {user_column_name : texts }
2015+
2016+ return __combined_prompt_processor__
19972017
19982018
19992019def to_sharegpt (
@@ -2025,13 +2045,17 @@ def to_sharegpt(
20252045 raise TypeError ("Unsloth: Your dataset is probably already in ShareGPT format!" )
20262046
20272047 possible_columns , final_optional_prompts = _parse_combined_prompt (merged_prompt , dataset )
2028- function = _create_formatter (possible_columns , final_optional_prompts , merged_column_name )
2029- exec (function , globals ())
2030- dataset = dataset .map (__combined_prompt_processor__ , batched = True , desc = "Merging columns" )
2048+ formatter = _create_formatter (possible_columns , final_optional_prompts , merged_column_name )
2049+ dataset = dataset .map (formatter , batched = True , desc = "Merging columns" )
20312050
20322051 def __convert_to_sharegpt__ (examples ):
20332052 users = examples [merged_column_name ]
20342053 assistants = examples [output_column_name ]
2054+ if len (users ) != len (assistants ):
2055+ raise ValueError (
2056+ "Unsloth: Input and output columns must have matching batch lengths. "
2057+ f"Got { len (users )} { merged_column_name } rows and { len (assistants )} { output_column_name } rows."
2058+ )
20352059 texts = [
20362060 [
20372061 {"from" : "human" , "value" : str (user ) },
@@ -2062,19 +2086,18 @@ def __convert_to_sharegpt__(examples):
20622086 dataset = concatenate_datasets (all_shuffled , axis = 1 )
20632087
20642088 # Combine them into 1
2065- function = "def __combine_conversations__(examples):\n "
20662089 n_extensions += 1
2067- for j in range (n_extensions ):
2068- function += f" { ' ' * 4 } conversations { j } __ = examples['conversations { j } '] \n "
2069- function += f" { ' ' * 4 } convos = [] \n "
2070- function += f" { ' ' * 4 } for ( { ', ' . join ( f'conversations { j } ' for j in range ( n_extensions )) } ) " \
2071- f" in zip({ ', ' . join ( f'conversations { j } __' for j in range ( n_extensions )) } ): \n "
2072- function += f" { ' ' * 8 } convos.append(" \
2073- f" { '+' . join ( f'conversations { j } ' for j in range ( n_extensions )) } ) \n "
2074- function += f" { ' ' * 4 } return " + "{ " + "'conversations' : convos" + " }"
2075-
2076- # Map function
2077- exec ( function , globals ())
conversation_columns = [f"conversations{j}" for j in range(n_extensions)]

def __combine_conversations__(examples):
    """Concatenate the parallel `conversationsN` columns row-wise into one list per row."""
    batches = [examples[name] for name in conversation_columns]
    convos = []
    for parts in zip(*batches):
        combined = []
        for part in parts:
            combined += part  # list concatenation, same as extend
        convos.append(combined)
    return {"conversations": convos}
2100+
20782101 dataset = dataset .map (
20792102 __combine_conversations__ ,
20802103 batched = True ,
@@ -2682,16 +2705,23 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
26822705
26832706 if tokenizer .chat_template is not None :
26842707 prompt = tokenizer .apply_chat_template (messages , tokenize = False , add_generation_prompt = True )
2685- prompt = prompt .replace ("'" , "" ) # Subprocess does not like ''
26862708 prompt = remove_special_tokens (tokenizer , prompt )
26872709 prompts .append (prompt )
26882710
26892711 for prompt in prompts :
2690- command = f"./llama.cpp/llama-cli -m { gguf_model } -n 0 --temp 0.0 --verbose-prompt " \
2691- f"--check-tensors -p '{ prompt } '"
2712+ # Use a list of args with shell=False so prompt content is passed literally.
2713+ command = [
2714+ "./llama.cpp/llama-cli" ,
2715+ "-m" , gguf_model ,
2716+ "-n" , "0" ,
2717+ "--temp" , "0.0" ,
2718+ "--verbose-prompt" ,
2719+ "--check-tensors" ,
2720+ "-p" , prompt ,
2721+ ]
26922722
26932723 datas = []
2694- with subprocess .Popen (command , shell = True , stdout = subprocess .PIPE , stderr = subprocess .STDOUT , bufsize = 1 ) as sp :
2724+ with subprocess .Popen (command , shell = False , stdout = subprocess .PIPE , stderr = subprocess .STDOUT , bufsize = 1 ) as sp :
26952725 for line in sp .stdout :
26962726 datas .append (line .decode ("utf-8" , errors = "replace" ))
26972727 gguf_tokens = "" .join (datas )
0 commit comments