We're removing the TensorFlow and Jax parts of the library. This will help us focus fully on torch
going forward and will greatly reduce the maintenance cost of models. We are still working with tools from
the Jax ecosystem (such as MaxText) to see how we can remain compatible with them
while keeping torch as the only backend for now.
Linked PR: huggingface#40760
We introduce a new weight loading API in transformers that significantly improves on the previous one. This
weight loading API is designed to apply operations to the checkpoints loaded by transformers.
Instead of loading a checkpoint exactly as it is serialized, these operations can reshape, merge, and split layers according to the conversions defined with this new API. Such operations are often necessary when working with quantization or parallelism algorithms.
This new API is centered around the new WeightConverter class:
```python
class WeightConverter(WeightTransform):
    operations: list[ConversionOps]
    source_keys: Union[str, list[str]]
    target_keys: Union[str, list[str]]
```

The weight converter is designed to apply a list of operations on the source keys, resulting in target keys. A common operation on attention layers is to fuse the query, key, and value layers. With this API, that amounts to defining the following conversion:

```python
conversion = WeightConverter(
    ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],  # The input layers
    "self_attn.qkv_proj",  # The single layer as output
    operations=[Concatenate(dim=0)],
)
```

In this situation, we apply the Concatenate operation, which accepts a list of layers as input and returns a single
layer.
This allows us to define a mapping from each architecture to a list of weight conversions, and those conversions
can apply arbitrary transformations to the layers themselves. This significantly simplified the from_pretrained method
and helped us remove a lot of technical debt that we accumulated over the past few years.
This results in several improvements:
- Much cleaner definition of transformations applied to the checkpoint
- Reversible transformations, so loading and saving a checkpoint should result in the same checkpoint
- Faster model loading thanks to scheduling of tensor materialization
- Enables complex combinations of transformations that wouldn't otherwise be possible (such as quantization + MoEs, or TP + MoEs)
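The reversible fuse-and-split behind a conversion like the q/k/v example above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the transformers implementation. The names fuse, unfuse, SOURCES and TARGET are hypothetical. Weights are modelled as lists of rows, so "concatenate along dim 0" is plain list concatenation. The sketch also records each source's row count so that the split back is exact.

```python
# Hypothetical pure-Python sketch, NOT the transformers WeightConverter/Concatenate code.
# Weights are lists of rows, so concatenating along dim 0 is list concatenation.
Matrix = list[list[float]]


def fuse(state_dict: dict[str, Matrix], sources: list[str], target: str):
    """Concatenate each group of `sources` sharing a layer prefix into one `target` key."""
    fused: dict[str, Matrix] = {}
    split_sizes: dict[str, list[int]] = {}
    prefixes = {key[: -len(s)] for key in state_dict for s in sources if key.endswith(s)}
    for key, value in state_dict.items():
        if not any(key.endswith(s) for s in sources):
            fused[key] = value  # keys without a conversion pass through unchanged
    for prefix in sorted(prefixes):
        parts = [state_dict[prefix + s] for s in sources]
        fused[prefix + target] = [row for part in parts for row in part]
        split_sizes[prefix] = [len(part) for part in parts]  # rows per source, kept for reversal
    return fused, split_sizes


def unfuse(state_dict: dict[str, Matrix], split_sizes: dict[str, list[int]], sources: list[str], target: str):
    """Reverse `fuse`: split each fused matrix back into its original source keys."""
    restored = {k: v for k, v in state_dict.items() if not k.endswith(target)}
    for prefix, sizes in split_sizes.items():
        fused, start = state_dict[prefix + target], 0
        for suffix, n_rows in zip(sources, sizes):
            restored[prefix + suffix] = fused[start : start + n_rows]
            start += n_rows
    return restored


SOURCES = ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"]
TARGET = "self_attn.qkv_proj.weight"
```

Calling fuse on a checkpoint and then unfuse on the result gives back the original dictionary. This is the "load then save yields the same checkpoint" property listed above, in miniature.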
While this is being implemented, expect varying levels of support across different release candidates.
Linked PR: huggingface#41580
Just as we moved towards a single backend library for model definition, we want tokenizers to be a lot more intuitive.
With v5, you can now initialize an empty LlamaTokenizer and train it directly on your new task!
Defining a new tokenizer object should be as simple as this:
```python
from transformers import TokenizersBackend, generate_merges
from tokenizers import pre_tokenizers, Tokenizer
from tokenizers.models import BPE


class Llama5Tokenizer(TokenizersBackend):
    def __init__(self, unk_token="<unk>", bos_token="<s>", eos_token="</s>", vocab=None, merges=None, add_prefix_space=True):
        if vocab is None:
            # Start from a vocabulary that only contains the special tokens
            self._vocab = {
                str(unk_token): 0,
                str(bos_token): 1,
                str(eos_token): 2,
            }
        else:
            self._vocab = vocab

        if merges is not None:
            self._merges = merges
        else:
            # Derive the merges from the vocabulary
            self._merges = generate_merges(self._vocab)

        self._tokenizer = Tokenizer(
            BPE(vocab=self._vocab, merges=self._merges, unk_token=str(unk_token), fuse_unk=True)
        )
        # Replace spaces with "▁"; add_prefix_space controls whether "▁" is prepended to the input
        self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
            replacement="▁", prepend_scheme="always" if add_prefix_space else "never", split=False
        )

        super().__init__(
            tokenizer_object=self._tokenizer,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
        )
```

Now, calling Llama5Tokenizer() gives you an empty, trainable tokenizer that follows the definition of the Llama5 authors (Llama5 does not exist yet 😉).
This is the main motivation behind the tokenization refactor: we want people to instantiate a tokenizer just like they would a model, empty or not, with exactly what they defined.
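The training itself is handled by 🤗 tokenizers. Here is a minimal sketch that calls the tokenizers library directly, rather than any transformers v5 training helper (which is not covered here). It trains an empty BPE model with the same Metaspace pre-tokenizer on an in-memory corpus using BpeTrainer. The corpus and vocab_size are arbitrary illustration values.

```python
# Sketch using 🤗 tokenizers directly (not a transformers API): train an empty BPE model.
from tokenizers import Tokenizer, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

# Empty BPE model: no vocabulary and no merges yet
tokenizer = Tokenizer(BPE(unk_token="<unk>", fuse_unk=True))
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=False)

# The trainer adds special tokens first, so they get ids 0, 1 and 2 in this order
trainer = BpeTrainer(vocab_size=200, special_tokens=["<unk>", "<s>", "</s>"])

# Toy in-memory corpus; any iterator of strings works
corpus = ["hello world", "hello there", "the world of tokenizers"] * 50
tokenizer.train_from_iterator(corpus, trainer=trainer)

encoding = tokenizer.encode("hello world")
print(encoding.tokens)
```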
If your tokenizer is not a common one, or you don't want to rely on either sentencepiece or tokenizers, you can import the PythonBackend (previously PreTrainedTokenizer), which has all the API and logic for added tokens, encoding and decoding with them, etc.
If you want even fewer features, you can use the common PreTrainedTokenizerBase mixin, which mostly defines the transformers tokenizer API: encode, decode, vocab_size, get_vocab, convert_tokens_to_ids, convert_ids_to_tokens, from_pretrained, save_pretrained, etc.
Moving away from "slow" vs "fast" tokenizers:
Previously, transformers maintained two parallel implementations for many tokenizers:
- "Slow" tokenizers (tokenization_<model>.py): Python-based implementations, often using SentencePiece as the backend.
- "Fast" tokenizers (tokenization_<model>_fast.py): Rust-based implementations using the 🤗 tokenizers library.
In v5, we consolidate to a single tokenizer file per model: tokenization_<model>.py. This file will use the most appropriate backend available:
- TokenizersBackend (preferred): Rust-based tokenizers from the 🤗 tokenizers library. In general it performs better, and it also offers many more features that are commonly adopted across the ecosystem, like handling additional tokens, easily updating the state of the tokenizer, automatic parallelisation, etc.
- SentencePieceBackend: For models requiring SentencePiece
- PythonBackend: Pure Python implementations
- MistralCommonBackend: Relies on MistralCommon's tokenization library (previously MistralCommonTokenizer)
The AutoTokenizer automatically selects the appropriate backend based on available files and dependencies. This is transparent: you continue to use AutoTokenizer.from_pretrained() as before. This keeps transformers future-proof and modular, so future backends can be supported easily.
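To illustrate the preference order above, here is a hypothetical pick_backend function. It is not the actual AutoTokenizer code. It assumes that a tokenizer.json file signals a 🤗 tokenizers checkpoint and that a tokenizer.model file signals a SentencePiece one. MistralCommonBackend is left out for brevity.

```python
# Hypothetical sketch of "use the most appropriate backend available";
# NOT how AutoTokenizer is implemented, only the preference order described above.
import importlib.util
from collections.abc import Callable


def is_installed(package: str) -> bool:
    return importlib.util.find_spec(package) is not None


def pick_backend(files: set[str], installed: Callable[[str], bool] = is_installed) -> str:
    # Assumption: a serialized tokenizer.json can be loaded by the Rust backend
    if "tokenizer.json" in files and installed("tokenizers"):
        return "TokenizersBackend"
    # Assumption: tokenizer.model is a SentencePiece model file
    if "tokenizer.model" in files and installed("sentencepiece"):
        return "SentencePieceBackend"
    # Fall back to the pure Python implementation
    return "PythonBackend"
```

Passing installed as a parameter keeps the choice testable without actually uninstalling packages.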
1. Direct tokenizer initialization with vocab and merges:
In v5, you can now initialize tokenizers directly with a vocabulary and merges, which lets you train custom tokenizers from scratch:
```python
# v5: Initialize a blank tokenizer for training
from transformers import LlamaTokenizer

# Create a tokenizer with a custom vocabulary and merges
# (both parts of each merge, and the merged token, must be in the vocabulary)
vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "h": 3, "e": 4, "l": 5, "o": 6, "he": 7, "ll": 8}
merges = [("h", "e"), ("l", "l")]
tokenizer = LlamaTokenizer(vocab=vocab, merges=merges)

# Or initialize a blank tokenizer to train on your own dataset
tokenizer = LlamaTokenizer()  # Creates a blank Llama-like tokenizer
```
However, you can no longer pass a vocab file, as that use case is covered by from_pretrained.
2. Simplified decoding API:
The batch_decode method has been unified with decode. Both single and batch decoding now use the same method:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
inputs = ["hey how are you?", "fine"]
tokenizer.decode(tokenizer.encode(inputs))
```
Gives:
```diff
- 'hey how are you?</s> fine</s>'
+ ['hey how are you?</s>', 'fine</s>']
```
This is mostly because people get a list[list[int]] out of generate and then naturally call decode on it (since they used encode), which previously failed with:
```
...: tokenizer.decode([[1,2], [1,4]])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 4
2 tokenizer = AutoTokenizer.from_pretrained("t5-small")
3 inputs = ["hey how are you?", "fine"]
----> 4 tokenizer.decode([[1,2], [1,4]])
File /raid/arthur/transformers/src/transformers/tokenization_utils_base.py:3948, in PreTrainedTokenizerBase.decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
3945 # Convert inputs to python lists
3946 token_ids = to_py_obj(token_ids)
-> 3948 return self._decode(
3949 token_ids=token_ids,
3950 skip_special_tokens=skip_special_tokens,
3951 clean_up_tokenization_spaces=clean_up_tokenization_spaces,
3952 **kwargs,
3953 )
File /raid/arthur/transformers/src/transformers/tokenization_utils_fast.py:682, in PreTrainedTokenizerFast._decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
680 if isinstance(token_ids, int):
681 token_ids = [token_ids]
--> 682 text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
684 clean_up_tokenization_spaces = (
685 clean_up_tokenization_spaces
686 if clean_up_tokenization_spaces is not None
687 else self.clean_up_tokenization_spaces
688 )
689 if clean_up_tokenization_spaces:
TypeError: argument 'ids': 'list' object cannot be interpreted as an integer
```
3. Unified encoding API:
encode_plus is deprecated; call the tokenizer directly (__call__) instead.
4. apply_chat_template returns BatchEncoding:
Previously, apply_chat_template returned input_ids for backward compatibility. In v5, it now consistently returns a BatchEncoding dict like other tokenizer methods:
```python
# v5
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there!"},
]

# Now returns BatchEncoding with input_ids, attention_mask, etc.
outputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
print(outputs.keys())  # dict_keys(['input_ids', 'attention_mask'])
```
Legacy tokenizer files have been consolidated:
- special_tokens_map.json: special tokens are now stored in tokenizer_config.json.
- added_tokens.json: added tokens are now stored in tokenizer.json.
- added_tokens_decoder is only stored when there is no tokenizer.json.
When loading older tokenizers, these files are still read for backward compatibility, but new saves use the consolidated format.
Several models that had identical tokenizers now import from their base implementation:
- LayoutLM → uses BertTokenizer
- LED → uses BartTokenizer
- Longformer → uses RobertaTokenizer
- LXMert → uses BertTokenizer
- MT5 → uses T5Tokenizer
- MVP → uses BartTokenizer
We will eventually remove these files.
Removed T5-specific workarounds:
The internal _eventually_correct_t5_max_length method has been removed. T5 tokenizers now handle max length consistently with other models.
Model-specific tokenization test files now focus on integration tests.
Common tokenization API tests (e.g., add_tokens, encode, decode) are now centralized and automatically applied across all tokenizers. This reduces test duplication and ensures consistent behavior across tokenizers.
For legacy implementations, the original BERT Python tokenizer code (including WhitespaceTokenizer, BasicTokenizer, etc.) is preserved in bert_legacy.py for reference purposes.
Linked PRs:
The use_auth_token argument/parameter is deprecated in favor of token everywhere.
You should be able to search and replace use_auth_token with token and get the same logic.
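If you have many call sites, the search-and-replace can be scripted. The following is a minimal sketch, not a transformers tool. The names migrate_source and migrate_tree are hypothetical. It matches use_auth_token only as a whole identifier, so names such as my_use_auth_token_helper are left alone. Run it on a version-controlled checkout and review the diff.

```python
# Hypothetical migration helper (not part of transformers):
# rename the `use_auth_token` identifier to `token` in Python sources.
import re
from pathlib import Path

# \b on both sides: only the whole identifier matches, not e.g. `my_use_auth_token_helper`
PATTERN = re.compile(r"\buse_auth_token\b")


def migrate_source(text: str) -> str:
    """Return `text` with every whole-word `use_auth_token` replaced by `token`."""
    return PATTERN.sub("token", text)


def migrate_tree(root: str) -> list[Path]:
    """Rewrite all .py files under `root` in place; return the files that changed."""
    changed = []
    for path in Path(root).rglob("*.py"):
        old = path.read_text(encoding="utf-8")
        new = migrate_source(old)
        if new != old:
            path.write_text(new, encoding="utf-8")
            changed.append(path)
    return changed
```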
Linked PR: huggingface#41666
We decided to remove some features for the upcoming v5, as they are currently only supported in a few old models and are no longer integrated in new model additions. We recommend sticking to v4.x if you need them. The following features are affected:
- No more head masking, see #41076. This feature allowed turning off certain heads during the attention calculation and only worked for eager.
- No more relative positional biases in Bert-like models, see #41170. This feature was introduced to allow relative position scores within attention calculations (similar to T5). However, it is barely used in official models and added a lot of complexity. It also only worked with eager.
- No more head pruning, see #41417 by @gante. As the name suggests, it allowed pruning heads within your attention layers.
We dropped support for two torch APIs:
- torchscript in huggingface#41688
- torch.fx in huggingface#41683

Those APIs were deprecated by the PyTorch team, and we're instead focusing on the supported APIs: dynamo and export.
We clean up the quantization API in transformers, and significantly refactor the weight loading as highlighted above.
We drop support for two quantization arguments that have been deprecated for some time:
- load_in_4bit
- load_in_8bit
We remove them in favor of the quantization_config argument, which is much more complete. As an example, here is how
you would load a 4-bit bitsandbytes model using this argument:
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    device_map="auto",
    quantization_config=quantization_config,
)
```
- Methods to init a nested config, such as from_xxx_config, are deleted. Configs can be initialized through the __init__ method in the same way. See #41314.
- It is no longer possible to load a config class from a URL file. Configs must be loaded from either a local path or a repo on the Hub. See #42383.
- All parameters for configuring a model's rotary embedding are now stored under config.rope_parameters, including the rope_theta and rope_type. A model's config.rope_parameters is a simple dictionary in most cases, and can also be a nested dict in special cases (e.g. Gemma3 and ModernBert) with a different rope parameterization for each layer type. See #39847
- Qwen-VL family configuration is in a nested format, and trying to access keys directly (e.g. config.vocab_size) will throw an error. Users are expected to access keys from their respective sub-configs (config.text_config.vocab_size).
- Slow tokenizer files (tokenization_<model>.py) will be removed in favor of the fast tokenizer files (tokenization_<model>_fast.py), which will be renamed to tokenization_<model>.py. As fast tokenizers are backed by 🤗 tokenizers, they include a wider range of features that are maintainable and reliable.
- Other backends (sentencepiece, tokenizers, etc.) will be supported through a light layer if loading a fast tokenizer fails
- Remove legacy files like special_tokens_map.json and added_tokens.json
- Remove _eventually_correct_t5_max_length
- encode_plus --> __call__
- batch_decode --> decode

apply_chat_template used to return naked input_ids by default rather than a BatchEncoding dict.
This was inconvenient, since it should return a BatchEncoding dict like tokenizer.__call__(), but we were stuck with
it for backward compatibility. The method now returns a BatchEncoding.
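Code that must run on both v4 and v5 can normalize the return value. Here is a minimal sketch with a hypothetical helper name, chat_input_ids, which is not a transformers API. It relies on BatchEncoding subclassing UserDict and therefore being a Mapping, while v4's default return value is a bare list of token ids.

```python
# Hypothetical compatibility helper (not part of transformers).
from collections.abc import Mapping


def chat_input_ids(output):
    """Return input_ids from apply_chat_template's output on either v4 or v5."""
    if isinstance(output, Mapping):  # v5: BatchEncoding (a UserDict, hence a Mapping)
        return output["input_ids"]
    return output  # v4 default: the token ids themselves
```

Usage: ids = chat_input_ids(tokenizer.apply_chat_template(messages)).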
Linked PRs:
- In processing classes, each attribute will be serialized under processor_config.json as a nested dict, instead of serializing attributes in their own config files. Loading will be supported for all old-format processors (huggingface#41474)
- XXXFeatureExtractors classes are completely removed in favor of the XXXImageProcessor class for all vision models (huggingface#41174)
- Minor change: XXXFastImageProcessorKwargs is removed in favor of XXXImageProcessorKwargs, which will be shared between fast and slow processors (huggingface#40931)
- Some RotaryEmbeddings layers will start returning a dict of tuples when the model uses several RoPE configurations (Gemma2, ModernBert). Each value will be a (cos, sin) tuple per RoPE type.
- The config attribute for RotaryEmbeddings layers will be unified and accessed via config.rope_parameters. The rope_theta config attribute might no longer be accessible for some models; it will instead be in config.rope_parameters['rope_theta']. BC will be supported for a while as much as possible, and in the near future we'll gradually move to the new RoPE format (huggingface#39847)
- Vision Language models will no longer have shortcut access to their language and vision components from the generative model via model.language_model. It is recommended to access the module with either model.model.language_model or model.get_decoder(). See #42156
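Downstream code that reads RoPE settings can handle the old and new layouts in one place. This is a minimal sketch with hypothetical helpers, get_rope_theta and select_cos_sin, which are not transformers APIs. SimpleNamespace stands in for a real config. The layer-type keys full_attention and sliding_attention and the theta values in the demo are illustrative assumptions.

```python
# Hypothetical helpers (not transformers APIs) for the v5 RoPE layout,
# with a fallback to the old flat `rope_theta` config attribute.
from types import SimpleNamespace


def get_rope_theta(config, layer_type=None):
    params = getattr(config, "rope_parameters", None)
    if params is None:
        return config.rope_theta  # old flat layout
    if "rope_theta" in params:
        return params["rope_theta"]  # simple dict shared by all layers
    # Nested dict keyed by layer type (e.g. Gemma3, ModernBert)
    if layer_type is None:
        raise ValueError(f"nested rope_parameters: pass one of {sorted(params)}")
    return params[layer_type]["rope_theta"]


def select_cos_sin(rope_output, layer_type=None):
    """Accept either a (cos, sin) tuple or a dict of such tuples keyed by RoPE type."""
    if isinstance(rope_output, dict):
        return rope_output[layer_type]
    return rope_output


if __name__ == "__main__":
    old_style = SimpleNamespace(rope_theta=10_000.0)
    nested = SimpleNamespace(
        rope_parameters={
            "full_attention": {"rope_type": "default", "rope_theta": 1_000_000.0},
            "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
        }
    )
    print(get_rope_theta(old_style), get_rope_theta(nested, "sliding_attention"))
```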
- Old, deprecated output type aliases were removed (e.g. GreedySearchEncoderDecoderOutput). We now only have 4 output classes built from the following matrix: decoder-only vs encoder-decoder, uses beams vs doesn't use beams (huggingface#40998)
- Removed deprecated classes regarding decoding methods that were moved to the Hub due to low usage (constraints and beam scores) (huggingface#41223)
- If generate doesn't receive any KV Cache argument, the default cache class used is now defined by the model (as opposed to always being DynamicCache) (huggingface#41505)
- Generation parameters are no longer accessible via the model's config. If generation parameters are serialized in config.json for an old model, they will be loaded into the model's generation config. Users are expected to access or modify generation parameters only through model.generation_config, e.g. model.generation_config.do_sample = True.
- mp_parameters -> legacy param that was later added to the SageMaker trainer
- _n_gpu -> not intended for users to set; we will initialize it correctly instead of putting it in the TrainingArguments
- overwrite_output_dir -> replaced by resume_from_checkpoint; it was only used in example scripts, so there is no impact on Trainer.
- logging_dir -> only used for TensorBoard; set the TENSORBOARD_LOGGING_DIR env var instead
- jit_mode_eval -> use use_torch_compile instead, as torchscript is not recommended anymore
- tpu_num_cores -> setting the number of cores is not recommended, so the argument is removed. By default, all TPU cores are used. Set the TPU_NUM_CORES env var instead
- past_index -> it was only used for a very small number of models with special architectures like Transformer-XL, and how to train those models was not documented at all
- ray_scope -> only a minor arg for the Ray integration. Set the RAY_SCOPE env var instead
- warmup_ratio -> use warmup_steps instead. We combined both args by allowing float values in warmup_steps.
- fsdp_min_num_params and fsdp_transformer_layer_cls_to_wrap -> use fsdp_config
- tpu_metrics_debug -> debug
- push_to_hub_token -> hub_token
- push_to_hub_model_id and push_to_hub_organization -> hub_model_id
- include_inputs_for_metrics -> include_for_metrics
- per_gpu_train_batch_size -> per_device_train_batch_size
- per_gpu_eval_batch_size -> per_device_eval_batch_size
- use_mps_device -> mps will be used by default if detected
- fp16_backend and half_precision_backend -> we will only rely on torch.amp, as everything has been upstreamed to torch
- no_cuda -> use_cpu
- include_tokens_per_second -> include_num_input_tokens_seen
- use_legacy_prediction_loop -> we only use the evaluation_loop function from now on
- tokenizer in initialization -> processing_class
- model_path in train() -> resume_from_checkpoint
- SigOpt integration for hyperparameter search was removed, as the library was archived and the API stopped working
- drop support for sagemaker API <1.10
- bump accelerate minimum version to 1.1.0
- use_cache in the model config will be set to False. You can still change the cache value through the TrainingArguments use_cache argument if needed.
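Existing TrainingArguments keyword dictionaries can be migrated semi-automatically. This is a minimal sketch; migrate_training_kwargs, RENAMED and NEEDS_REVIEW are hypothetical names, not part of transformers. Only the entries from the list above that look like plain one-to-one keyword swaps are renamed. Everything else, including any rename that would clash with an argument you already pass, is returned for manual review.

```python
# Hypothetical migration helper (not part of transformers) for TrainingArguments kwargs.
RENAMED = {
    "push_to_hub_token": "hub_token",
    "push_to_hub_model_id": "hub_model_id",
    "per_gpu_train_batch_size": "per_device_train_batch_size",
    "per_gpu_eval_batch_size": "per_device_eval_batch_size",
    "no_cuda": "use_cpu",
    "warmup_ratio": "warmup_steps",  # v5 accepts a float in warmup_steps
}

# Removed, or replaced by something that is not a simple rename
NEEDS_REVIEW = {
    "mp_parameters", "_n_gpu", "overwrite_output_dir", "logging_dir", "jit_mode_eval",
    "tpu_num_cores", "past_index", "ray_scope", "fsdp_min_num_params",
    "fsdp_transformer_layer_cls_to_wrap", "tpu_metrics_debug", "push_to_hub_organization",
    "include_inputs_for_metrics", "use_mps_device", "fp16_backend",
    "half_precision_backend", "include_tokens_per_second", "use_legacy_prediction_loop",
}


def migrate_training_kwargs(kwargs):
    """Return (migrated kwargs, sorted list of keys that need manual attention)."""
    migrated, review = {}, []
    for key, value in kwargs.items():
        if key in RENAMED:
            new_key = RENAMED[key]
            if new_key in kwargs:  # old and new names both given: let a human decide
                review.append(key)
            else:
                migrated[new_key] = value
        elif key in NEEDS_REVIEW:
            review.append(key)
        else:
            migrated[key] = value  # unaffected arguments pass through
    return migrated, sorted(review)
```

The migrated dictionary can then be passed on as TrainingArguments(output_dir=..., **migrated).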
- Image text to text pipelines will no longer accept images as a separate argument along with conversation chats. Image data has to be embedded in the chat's "content" field. See #42359
- removed deprecated organization and repo_url from PushToHubMixin. You must pass a repo_id instead.
- removed ignore_metadata_errors from PushToHubMixin. In practice, if we ignore errors while loading the model card, we won't be able to push the card back to the Hub, so it's better to fail early than to provide the option to fail later.
- push_to_hub does not accept **kwargs anymore. All accepted parameters are explicitly documented.
- arguments of push_to_hub are now keyword-only to avoid confusion. Only repo_id can be positional, since it's the main arg.
- removed the use_temp_dir argument from push_to_hub. We now use a tmp dir in all cases.
Linked PR: huggingface#42391.
The deprecated transformers-cli ... command has been replaced: transformers ... is now the only CLI entry point.
transformers CLI has been migrated to Typer, making it easier to maintain + adding some nice features out of
the box (improved --help section, autocompletion).
The biggest breaking change is in transformers chat. This command starts a terminal UI to interact with a chat model.
It used to also be able to start a Chat Completion server powered by transformers and chat with it. In this revamped
version, this feature has been removed in favor of transformers serve. The goal of splitting transformers chat
and transformers serve is to define clear boundaries between client and server code. It helps with maintenance
but also makes the commands less bloated. The new signature of transformers chat is:
```
Usage: transformers chat [OPTIONS] BASE_URL MODEL_ID [GENERATE_FLAGS]...

Chat with a model from the command line.

Example:
transformers chat https://router.huggingface.co/v1 HuggingFaceTB/SmolLM3-3B
```
Linked PRs:
The transformers run command (previously transformers-cli run) is an artefact of the past: it was neither documented nor tested.
We're removing it for now. Please let us know if you are using it,
in which case we should bring it back with better support.
Linked PR: huggingface#42447
- Legacy environment variables like TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, and PYTORCH_PRETRAINED_BERT_CACHE have been removed. Please use HF_HOME instead.
- Constants HUGGINGFACE_CO_EXAMPLES_TELEMETRY, HUGGINGFACE_CO_PREFIX, and HUGGINGFACE_CO_RESOLVE_ENDPOINT have been removed. Please use huggingface_hub.constants.ENDPOINT instead.
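Because v5 no longer reads these variables, a custom cache location set through them would be silently ignored. A small pre-flight check can catch this. The following is a minimal sketch with a hypothetical function name, legacy_cache_vars, which is not a transformers API.

```python
# Hypothetical pre-flight check (not part of transformers): report legacy cache
# variables that are still set, since v5 ignores them in favor of HF_HOME.
import os

LEGACY_CACHE_VARS = ("TRANSFORMERS_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "PYTORCH_PRETRAINED_BERT_CACHE")


def legacy_cache_vars(environ=os.environ):
    """Return the legacy cache variables that are set (non-empty) in `environ`."""
    return [name for name in LEGACY_CACHE_VARS if environ.get(name)]


if __name__ == "__main__":
    for name in legacy_cache_vars():
        print(f"{name} is set but ignored by transformers v5; set HF_HOME instead.")
```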
Linked PR: huggingface#42391.
transformers v5 pins the huggingface_hub version to >=1.0.0. See this migration guide to learn more about this major release. Here are the main aspects to know about:
- switched the HTTP backend from requests to httpx. This change was made to improve performance and to support both synchronous and asynchronous requests the same way. If you are currently catching requests.HTTPError errors in your codebase, you'll need to switch to httpx.HTTPError.
- related to the previous point, it is not possible to set proxies from your script. To handle proxies, you must set the HTTP_PROXY / HTTPS_PROXY environment variables.
- hf_transfer, and therefore HF_HUB_ENABLE_HF_TRANSFER, have been completely dropped in favor of hf_xet. This should be transparent for most users. Please let us know if you notice any downside!
typer-slim has been added as a required dependency, used to implement both the hf and transformers CLIs.