Description
OS
Windows
GPU Library
CUDA 12.x
Python version
3.11
Pytorch version
2.5.0 and 2.7.1
Model
Entire Llama3 family (3.1, 3.2, 3.3) for now. Could also affect Gemma3 if the eos_token_id list ordering is shuffled around in a future update.
Describe the bug
Issue
The entire Llama3 family of models (3.1, 3.2, 3.3) defines a list of multiple EOS token IDs in its config.json: "eos_token_id": [128001, 128008, 128009]
Looking at the accompanying tokenizer.json file, 128001 corresponds to <|end_of_text|>, while the last entry, 128009, corresponds to <|eot_id|>, which is the actual EOS token as per Llama3's prompt template.
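For reference, here is a quick way to confirm that mapping locally (a minimal sketch using only the standard library; the model path is a placeholder):
import json, os

model_dir = "path/to/Llama-3.1-8B-Instruct"  # placeholder: local model directory

# Read the declared EOS IDs and the special-token id -> string map
with open(os.path.join(model_dir, "config.json")) as f:
    eos_ids = json.load(f)["eos_token_id"]  # [128001, 128008, 128009]
with open(os.path.join(model_dir, "tokenizer.json")) as f:
    id_to_str = {t["id"]: t["content"] for t in json.load(f)["added_tokens"]}

for i in eos_ids:
    print(i, id_to_str.get(i))  # e.g. 128001 -> <|end_of_text|>, 128009 -> <|eot_id|>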
However, in ExLlamaV2's config.py at line 253, the prepare() method of the ExLlamaV2Config class sets eos_token_id to only the first value at init:
if isinstance(self.eos_token_id, list):
self.eos_token_id = self.eos_token_id[0] # TODO: Figure out a way to maybe use all the EOS tokens somehow
Therefore, when defining the job object with eos_token_id as the stop condition, the wrong EOS token is used and generation never stops:
job = ExLlamaV2DynamicJob(
input_ids = EXL2_TOKENIZER.encode(tokenized_messages, encode_special_tokens=True),
max_new_tokens = int(max_new_tokens),
stop_conditions = [EXL2_TOKENIZER.eos_token_id], # sets to 128001, while 128009 is desired!
gen_settings = gen_settings
)
Potential Workaround 1 - Manually add the correct EOS token integer IDs to the stop condition:
job = ExLlamaV2DynamicJob(
input_ids = EXL2_TOKENIZER.encode(tokenized_messages, encode_special_tokens=True),
max_new_tokens = int(max_new_tokens),
stop_conditions = [EXL2_TOKENIZER.eos_token_id, 128009], # manually adding the correct eos_id from config & tokenizer.json
gen_settings = gen_settings
)
This is not great for obvious reasons: hard-coding 128009 may cause unpredictable behavior for other models that happen to use that same ID for something other than the EOS token.
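A less brittle variant of Workaround 1 (a sketch, assuming config.json sits in the local quantized-model directory, as it does for EXL2 quants converted from the HF repo) reads the full eos_token_id list from the config and passes every declared ID as a stop condition instead of hard-coding 128009:
import json, os

def declared_eos_ids(model_dir):
    # Return every EOS token ID declared in the model's config.json, always as a list
    with open(os.path.join(model_dir, "config.json")) as f:
        eos = json.load(f).get("eos_token_id")
    return eos if isinstance(eos, list) else [eos]

job = ExLlamaV2DynamicJob(
    input_ids = EXL2_TOKENIZER.encode(tokenized_messages, encode_special_tokens=True),
    max_new_tokens = int(max_new_tokens),
    stop_conditions = declared_eos_ids(quantized_model_path),  # all declared EOS IDs, no hard-coded 128009
    gen_settings = gen_settings
)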
Potential Workaround 2 - Import Transformers alongside & use AutoTokenizer:
Hear me out!!
from transformers import AutoTokenizer
# Set Transformers AutoTokenizer:
model_id = "meta-llama/Llama-3.1-8B-Instruct"
AUTO_TOKENIZER = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Load ExLlamaV2 Pipeline:
config = ExLlamaV2Config(quantized_model_path)
EXL2_MODEL = ExLlamaV2(config)
EXL2_CACHE = ExLlamaV2Cache(EXL2_MODEL, max_seq_len = exl2_max_seq_len, lazy = True) # or use ExLlamaV2Cache_Q4 / _Q8
EXL2_MODEL.load_autosplit(EXL2_CACHE, progress=True)
EXL2_TOKENIZER = ExLlamaV2Tokenizer(config)
# Define Generator and Job, now adding AUTO_TOKENIZER.eos_token_id to the stop_conditions:
exl2_dynamic_generator = ExLlamaV2DynamicGenerator(model = EXL2_MODEL, cache = EXL2_CACHE, tokenizer = EXL2_TOKENIZER)
job = ExLlamaV2DynamicJob(
input_ids = EXL2_TOKENIZER.encode(tokenized_messages, encode_special_tokens=True),
max_new_tokens = int(max_new_tokens),
stop_conditions = [EXL2_TOKENIZER.eos_token_id, AUTO_TOKENIZER.eos_token_id], # both Exl2 & Transformers EOS IDs!
gen_settings = gen_settings
)
exl2_dynamic_generator.enqueue(job)
Bonus: this also allows you to pass a standard list of messages and use AutoTokenizer's apply_chat_template() method to auto-apply the correct prompt-template format for any LLM!
messages = [
{"role": "system", "content": "You are a friendly chatbot."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
{"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
tokenized_messages = AUTO_TOKENIZER.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
But I digress!
Pros:
- Reliability: Much more reliable, as any such issues are far more likely to be detected and fixed in the transformers library, almost always guaranteeing that AUTO_TOKENIZER.eos_token_id is correct.
- Lightweight: Only the tokenizer config files are downloaded by Transformers for initializing the AutoTokenizer, which together amount to just a few KBs.
- Auto-templating: Enables a universal messages list which can be easily auto-formatted with the correct prompt-template format for any given model!
Cons: Need to install and import two libraries (Transformers & ExLlamaV2) instead of just one (honestly not a big deal)
Related Issues & Potential Pitfalls:
Any model whose config defines eos_token_id as a list of multiple values rather than a single int is at risk of similar issues. Both Gemma3 & Gemma3n are great examples of this, as their config.json defines:
"eos_token_id": [
1,
106
],
Coincidentally, 1 is the correct EOS ID here as per tokenizer_config.json (<eos>). However, if this were the other way around, or a future update adds to or shuffles this list, the same issue would be observed unless ExLlamaV2 correctly sets eos_token_id to the entire list!
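A quick consistency check along these lines can flag the problem before generation (a sketch; the function name and its model_dir / hf_model_id parameters are placeholders, and it simply compares the first config.json EOS ID against the EOS ID the tokenizer itself reports):
import json, os
from transformers import AutoTokenizer

def warn_if_eos_mismatch(model_dir, hf_model_id):
    # Compare the first EOS ID in config.json (the one ExLlamaV2 currently keeps)
    # against the EOS ID reported by the model's own tokenizer
    with open(os.path.join(model_dir, "config.json")) as f:
        eos = json.load(f).get("eos_token_id")
    eos_ids = eos if isinstance(eos, list) else [eos]
    tok = AutoTokenizer.from_pretrained(hf_model_id)
    if tok.eos_token_id is not None and tok.eos_token_id != eos_ids[0]:
        print(f"Warning: config.json declares {eos_ids}; ExLlamaV2 will keep {eos_ids[0]}, "
              f"but the tokenizer reports eos_token_id = {tok.eos_token_id} ({tok.eos_token})")

# e.g. warn_if_eos_mismatch(quantized_model_path, "meta-llama/Llama-3.1-8B-Instruct")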
Background:
When attempting to deploy Llama-3.3-70B quantized to 4.55BPW on my 2xRTX-3090 dual-GPU rig last week, I found that even simply saying 'Hi!' led to infinite generation: the model would keep generating text, quickly devolving into complete nonsense, until all the tokens allowed by the max_new_tokens param were generated.
After looking through the open issues on ExLlamaV2's GitHub, I could not find any other reports of this specific issue. Other models I tested (predominantly Phi4 and the Qwen3 family) worked fine, so I decided to try Llama-3.2-3B and even Llama-3.1-8B, both at 8BPW, only to find the exact same issue!
Suspecting an issue with Ampere/Torch-2.4.1/CUDA-12.5.0, I tried testing on an Nvidia Blackwell/Torch-2.7.1/CUDA-12.8 system, only to hit the exact same behavior.
As the issue I was seeing was infinite generation and not infinite repetition, it got me thinking that there might be an issue with the stop condition, wherein the end-of-sequence (EOS) token may not be correctly recognized. Googling around, I found that the Llama3 family indeed had a history of issues with the EOS token, but these had apparently been patched. Besides, the specific issues being described were different from what I was facing.
Nothing was left at this point but to start digging into the model's config and do extensive testing on encoding/decoding messages, checking EOS IDs, etc., which led to the findings above.
Reproduction steps
Simply try to run any of the Llama3 family of models from the official Meta repos on HuggingFace with the latest version of ExLlamaV2.
Expected behavior
Generation should stop once the model has finished its response to the user's query (i.e., once it emits its EOS token), instead of running until max_new_tokens is exhausted.
Logs
Happy to share as required.
Additional context
I'd be happy to contribute a fix, with a little guidance on any potential pitfalls of changing the eos_token_id variable in the ExLlamaV2Config class's prepare() method to accept a list instead of just a single int. A rough, untested sketch of what I have in mind is below. Thanks for everything!
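Something along these lines, purely for illustration (the eos_token_id_list attribute name is hypothetical, and a real patch would also need to update every downstream consumer of eos_token_id, e.g. the tokenizer and the generators' default stop conditions):
if isinstance(self.eos_token_id, list):
    self.eos_token_id_list = list(self.eos_token_id)  # hypothetical: keep every declared EOS ID
    self.eos_token_id = self.eos_token_id[0]          # preserve current behaviour for compatibility
else:
    self.eos_token_id_list = [self.eos_token_id]
Stop-condition defaults could then draw from the full list instead of the single collapsed value.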
Acknowledgements
- I have looked for similar issues before submitting this one.
- I understand that the developers have lives and my issue will be answered when possible.
- I understand the developers of this program are human, and I will ask my questions politely.