
Commit 5557283

Fix regressions from security PRs #4042, #4044, and #4045 (#4062)
* Fix security-regression fallout in chat templates and PDL patching

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Drop security regression test files from PR scope

* Apply suggestion from @danielhanchen

---------

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f1a11f0 commit 5557283

2 files changed: +100 lines, -55 lines


unsloth/chat_templates.py

Lines changed: 81 additions & 51 deletions
@@ -1959,41 +1959,61 @@ def _parse_combined_prompt(combined_prompt, dataset):
 
 
 def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
-    # Start final prompt!
-    function = ["def __combined_prompt_processor__(examples):"]
-    columns = list(set(possible_columns))
-    for column in columns:
-        function.append(f"{' '*4}{column}__ = examples['{column}']")
-    function.append(f"{' '*4}texts = []")
-    function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
-
-    # Add optional tags as well!
-    final_prompt = ""
-    formatter = []
+    columns = list(dict.fromkeys(possible_columns))
+    merged_prompt_parts = []
+    formatter_templates = []
 
     for j, optional_prompt in enumerate(final_optional_prompts):
         if type(optional_prompt) is str:
-            columns = re.findall(r"\{(.+?)\}", optional_prompt)
-            formatter += columns
-            # Must escape \n \r
-            final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
-        else:
-            where, prompt = optional_prompt
-            # Strip [[...]]
-            # Must escape \n \r
-            prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
-            columns = re.findall(r"\{(.+?)\}", prompt)
-            x = f"__optional_{j}__"
-            prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
-            function.append(prompt)
-            formatter.append(x)
-            final_prompt += "{" + x + "}"
-
-    function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
-    function.append(f"{' '*8}texts.append("\
-        f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
-    function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
-    return "\n".join(function)
+            needed_columns = re.findall(r"\{(.+?)\}", optional_prompt)
+            formatter_templates.append(("required", optional_prompt, needed_columns))
+            merged_prompt_parts.append(optional_prompt)
+            continue
+
+        _, prompt = optional_prompt
+        prompt = prompt[2:-2]
+        needed_columns = re.findall(r"\{(.+?)\}", prompt)
+        if len(needed_columns) == 0:
+            raise IndexError("Unsloth: Optional [[...]] blocks must contain at least 1 {column}.")
+        optional_name = f"__optional_{j}__"
+        formatter_templates.append(("optional", optional_name, prompt, needed_columns))
+        merged_prompt_parts.append("{" + optional_name + "}")
+
+    merged_prompt = "".join(merged_prompt_parts)
+
+    def __combined_prompt_processor__(examples):
+        if len(examples) == 0:
+            return {user_column_name: []}
+
+        first_key = next(iter(examples.keys()), None)
+        if first_key is None:
+            return {user_column_name: []}
+        n_rows = len(examples[first_key])
+
+        texts = []
+        for row_idx in range(n_rows):
+            row_values = {column: examples[column][row_idx] for column in columns}
+            formatter_values = {}
+
+            for formatter_template in formatter_templates:
+                if formatter_template[0] == "required":
+                    _, _, needed_columns = formatter_template
+                    for column in needed_columns:
+                        formatter_values[column] = row_values[column]
+                    continue
+
+                _, optional_name, prompt, needed_columns = formatter_template
+                if row_values[needed_columns[0]] not in (None, ""):
+                    prompt_values = {column: row_values[column] for column in needed_columns}
+                    formatter_values[optional_name] = prompt.format(**prompt_values)
+                else:
+                    formatter_values[optional_name] = ""
+
+            texts.append(merged_prompt.format(**formatter_values))
+
+        return {user_column_name: texts}
+
+    return __combined_prompt_processor__
 
 
 def to_sharegpt(
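The key change in the hunk above is that _create_formatter no longer builds Python source and exec()s it; it returns an ordinary closure that dataset.map can call batch by batch. A minimal sketch of that closure pattern, with made-up column names and a simplified template (no optional [[...]] handling), for illustration only:

def make_formatter(merged_prompt, columns, user_column_name):
    # Build the template once, then format every row of a batched examples dict.
    def processor(examples):
        first_key = next(iter(examples.keys()), None)
        if first_key is None:
            return {user_column_name: []}
        n_rows = len(examples[first_key])
        texts = []
        for row_idx in range(n_rows):
            row_values = {column: examples[column][row_idx] for column in columns}
            texts.append(merged_prompt.format(**row_values))
        return {user_column_name: texts}
    return processor

formatter = make_formatter("Question: {question}\nContext: {context}",
                           ["question", "context"], "text")
print(formatter({"question": ["What is 2+2?"], "context": ["Basic arithmetic."]}))
# {'text': ['Question: What is 2+2?\nContext: Basic arithmetic.']}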
@@ -2025,13 +2045,17 @@ def to_sharegpt(
         raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
 
     possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
-    function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
-    exec(function, globals())
-    dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
+    formatter = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
+    dataset = dataset.map(formatter, batched = True, desc = "Merging columns")
 
     def __convert_to_sharegpt__(examples):
         users = examples[merged_column_name]
         assistants = examples[output_column_name]
+        if len(users) != len(assistants):
+            raise ValueError(
+                "Unsloth: Input and output columns must have matching batch lengths. "
+                f"Got {len(users)} {merged_column_name} rows and {len(assistants)} {output_column_name} rows."
+            )
         texts = [
             [
                 {"from" : "human", "value" : str(user) },
@@ -2062,19 +2086,18 @@ def __convert_to_sharegpt__(examples):
     dataset = concatenate_datasets(all_shuffled, axis = 1)
 
     # Combine them into 1
-    function = "def __combine_conversations__(examples):\n"
     n_extensions += 1
-    for j in range(n_extensions):
-        function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
-    function += f"{' '*4}convos = []\n"
-    function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
-        f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
-    function += f"{' '*8}convos.append("\
-        f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
-    function += f"{' '*4}return " + "{ " + "'conversations' : convos" + " }"
-
-    # Map function
-    exec(function, globals())
+    conversation_columns = [f"conversations{j}" for j in range(n_extensions)]
+    def __combine_conversations__(examples):
+        columns = [examples[column] for column in conversation_columns]
+        convos = []
+        for conversations in zip(*columns):
+            merged_conversation = []
+            for conversation in conversations:
+                merged_conversation.extend(conversation)
+            convos.append(merged_conversation)
+        return {"conversations" : convos}
+
     dataset = dataset.map(
         __combine_conversations__,
         batched = True,
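The rewritten __combine_conversations__ above relies on zip(*columns) to transpose a batch of per-column conversation lists into per-row tuples, then concatenates each row's lists. A toy illustration with hypothetical data, not the library code:

# Toy data: one row, two "conversationsN" columns produced by conversation extension.
examples = {
    "conversations0": [[{"from": "human", "value": "Hi"},
                        {"from": "gpt", "value": "Hello!"}]],
    "conversations1": [[{"from": "human", "value": "Bye"},
                        {"from": "gpt", "value": "Goodbye!"}]],
}
conversation_columns = ["conversations0", "conversations1"]

columns = [examples[column] for column in conversation_columns]
convos = []
for conversations in zip(*columns):   # one tuple of conversation lists per row
    merged_conversation = []
    for conversation in conversations:
        merged_conversation.extend(conversation)
    convos.append(merged_conversation)

print(len(convos), len(convos[0]))    # 1 4 -> the row's two chats were concatenated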
@@ -2682,16 +2705,23 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
 
     if tokenizer.chat_template is not None:
         prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
-        prompt = prompt.replace("'", "") # Subprocess does not like ''
         prompt = remove_special_tokens(tokenizer, prompt)
         prompts.append(prompt)
 
     for prompt in prompts:
-        command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
-            f"--check-tensors -p '{prompt}'"
+        # Use a list of args with shell=False so prompt content is passed literally.
+        command = [
+            "./llama.cpp/llama-cli",
+            "-m", gguf_model,
+            "-n", "0",
+            "--temp", "0.0",
+            "--verbose-prompt",
+            "--check-tensors",
+            "-p", prompt,
+        ]
 
         datas = []
-        with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
+        with subprocess.Popen(command, shell = False, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
             for line in sp.stdout:
                 datas.append(line.decode("utf-8", errors = "replace"))
         gguf_tokens = "".join(datas)
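The point of swapping the shell string for an argument list is that with shell=False the prompt reaches the child process as a single literal argv entry, so quotes and shell metacharacters in chat-template output no longer need stripping (the old replace("'", "")) and cannot be interpreted as shell syntax. A small standalone illustration, using POSIX echo as a stand-in binary rather than llama-cli:

import subprocess

prompt = "It's a \"quoted\" prompt; with $pecial & characters"

# List + shell=False: the prompt is passed through verbatim as one argument.
result = subprocess.run(["echo", prompt], capture_output=True, text=True)
print(result.stdout.rstrip("\n") == prompt)  # True

# The old style embedded the prompt in a shell command line, so quotes,
# ';', '&' and '$' were all shell syntax there:
#   subprocess.Popen(f"./llama.cpp/llama-cli ... -p '{prompt}'", shell=True)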

unsloth/import_fixes.py

Lines changed: 19 additions & 4 deletions
@@ -1045,10 +1045,10 @@ def _spec_exists(name):
         return
 
     # Check if vLLM version includes the fix
-    VLLM_PDL_FIX_VERSION = "0.13.2"
+    VLLM_PDL_FIX_VERSION = "0.15.0"
     try:
         vllm_version = Version(importlib_version("vllm"))
-        if vllm_version > Version(VLLM_PDL_FIX_VERSION):
+        if vllm_version >= Version(VLLM_PDL_FIX_VERSION):
             logger.info(
                 f"Unsloth: SM100 ({sm100_gpu_name}) detected but vLLM {vllm_version} "
                 f"should include PDL fix - skipping workaround"
@@ -1066,6 +1066,12 @@ def fake_supports_pdl(*args, **kwargs):
         return False
 
     patched = []
+    patched_names = set()
+
+    def _record_patch(name):
+        if name not in patched_names:
+            patched.append(name)
+            patched_names.add(name)
 
     # First, patch the source module (utils.py) where supports_pdl is defined.
     # This is critical because supports_pdl uses @lru_cache - we must clear the
@@ -1077,7 +1083,7 @@ def fake_supports_pdl(*args, **kwargs):
         if hasattr(original_fn, "cache_clear"):
             original_fn.cache_clear()
         utils_module.supports_pdl = fake_supports_pdl
-        patched.append("utils")
+        _record_patch("utils")
     except (ImportError, ModuleNotFoundError, AttributeError):
         pass
 
@@ -1094,10 +1100,19 @@ def fake_supports_pdl(*args, **kwargs):
             module = importlib.import_module(path)
             if hasattr(module, "supports_pdl"):
                 module.supports_pdl = fake_supports_pdl
-                patched.append(name)
+                _record_patch(name)
         except (ImportError, ModuleNotFoundError, AttributeError):
             pass
 
+    # Patch any additional already-loaded triton ops consumers that expose supports_pdl.
+    for module_name, module in tuple(sys.modules.items()):
+        if not module_name.startswith("vllm.lora.ops.triton_ops."):
+            continue
+        if module is None or not hasattr(module, "supports_pdl"):
+            continue
+        module.supports_pdl = fake_supports_pdl
+        _record_patch(module_name.rsplit(".", 1)[-1])
+
     if patched:
         logger.info(
             f"Unsloth: Applied PDL fix for SM100 ({sm100_gpu_name}) - "
