
[Bugfix] Fix transformers 5.x compat issues in online TTS serving #1536

Open

linyueqian wants to merge 4 commits into vllm-project:main from linyueqian:bugfix/tts-transformers5-compat

Conversation

@linyueqian
Contributor

Summary

  • Remove fix_mistral_regex=True from AutoTokenizer.from_pretrained (parameter removed in transformers 5.x)
  • Add fallback for 'default' rope_type missing from ROPE_INIT_FUNCTIONS in transformers 5.x (inline standard sinusoidal RoPE)
  • Clamp num_cached_tokens to max(0, ...) in OmniGenerationScheduler to prevent negative value crash

These fixes are required for online TTS serving to work with the current environment (transformers 5.2.0, pinned via uv.lock).
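The first bullet drops a keyword argument that no longer exists in transformers 5.x. A more general pattern, sketched below, is to gate the removed kwarg on the installed library version instead of deleting it unconditionally; `tokenizer_kwargs` and `parse_version` are illustrative helpers, not part of the PR, and the cutoff at major version 5 is an assumption based on the PR description:

```python
def parse_version(v):
    # Minimal version parser: keep leading numeric components only,
    # e.g. "5.2.0" -> (5, 2, 0), "4.44.2.dev0" -> (4, 44, 2).
    parts = []
    for piece in v.split("."):
        digits = "".join(ch for ch in piece if ch.isdigit())
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def tokenizer_kwargs(transformers_version):
    # fix_mistral_regex was removed in transformers 5.x (per the PR
    # description), so only pass it on older versions.
    kwargs = {"trust_remote_code": True}
    if parse_version(transformers_version) < (5,):
        kwargs["fix_mistral_regex"] = True
    return kwargs
```

The resulting dict would then be splatted into `AutoTokenizer.from_pretrained(path, **kwargs)`, keeping one code path for both library generations.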


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 339b3ddb2b


Collaborator

@hsliuustc0106 left a comment


Summary

This PR fixes three compatibility issues with transformers 5.x that were breaking online TTS serving. The changes are minimal, focused, and address real breaking changes in the transformers library.

Pros:

  • Addresses actual breaking changes in transformers 5.x
  • Small, focused fixes (21 additions, 4 deletions)
  • Good inline documentation explaining the 'default' rope_type fallback
  • Defensive programming with the max(0, ...) clamp
  • Clear error message for unsupported rope types

Cons:

  • No test coverage for the new fallback logic
  • The num_cached_tokens negative value issue suggests a deeper problem upstream

Recommendation: Approve with suggestions for follow-up investigation.

def _default_rope_init(config, device=None, seq_len=None, layer_type=None):
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    inv_freq = 1.0 / (
        config.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
    )
    return inv_freq, 1.0
Collaborator


Good: Well-documented fallback

The inline implementation of 'default' RoPE is well-documented and correct. The comment clearly explains why this is needed (transformers 5.x removed 'default' from ROPE_INIT_FUNCTIONS).

Suggestion: Consider adding a reference to the transformers version where this changed:

# transformers>=5.0 removed 'default' from ROPE_INIT_FUNCTIONS (see transformers PR #xxxxx)
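For reference, the "default" schedule the fallback inlines is the standard sinusoidal RoPE inverse-frequency formula. A torch-free sketch (function name is illustrative; the PR's version operates on a config object and returns a tensor):

```python
def default_rope_inv_freq(head_dim, rope_theta=10000.0):
    # Standard sinusoidal RoPE inverse frequencies:
    # inv_freq[i] = rope_theta ** (-2i / head_dim) for i in 0 .. head_dim // 2 - 1.
    # The first frequency is always 1.0; the rest decay geometrically.
    return [rope_theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]

freqs = default_rope_inv_freq(8)
# freqs has head_dim // 2 = 4 entries, strictly decreasing from 1.0
```

The attention scaling factor returned alongside it is a constant 1.0 for the "default" type, which matches the `return inv_freq, 1.0` in the diff.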

    raise ValueError(
        f"Unsupported rope_type '{self.rope_type}'. Expected one of {list(ROPE_INIT_FUNCTIONS)} or 'default'."
    )

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
Collaborator


Good: Clear error message

The error message provides helpful context about what rope types are supported. This will make debugging easier if an unsupported type is encountered.

events=request.take_events(),
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
Collaborator


Issue: Symptom fix, not root cause

Clamping num_cached_tokens to max(0, ...) prevents the crash, but it's treating the symptom rather than the root cause. A negative num_cached_tokens suggests:

  1. There's a bug upstream where request.num_cached_tokens is being set to a negative value
  2. Or there's a logic error in how cached tokens are being counted

Recommendation:

  • Add a warning log when clamping occurs to help track down the root cause:
num_cached = request.num_cached_tokens
if num_cached < 0:
    logger.warning(f"Negative num_cached_tokens ({num_cached}) detected for request {request.request_id}, clamping to 0")
    num_cached = 0
num_cached_tokens=num_cached,
  • File a follow-up issue to investigate why num_cached_tokens can be negative

This defensive fix is fine for now, but understanding the root cause would prevent potential issues elsewhere.

num_cached_tokens=max(0, request.num_cached_tokens),
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
Collaborator


Same issue here

Same recommendation as above - consider adding logging to track when this clamping occurs.
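The warn-and-clamp pattern both reviewers are asking for can be factored into a small helper so the two call sites share one implementation; the helper and logger names below are illustrative, not from the PR:

```python
import logging

logger = logging.getLogger("omni_scheduler")

def clamped_num_cached_tokens(num_cached_tokens, request_id):
    # Clamp a negative cached-token count to 0, but log a warning so the
    # upstream bug producing the negative value stays visible rather than
    # being silently masked by the clamp.
    if num_cached_tokens < 0:
        logger.warning(
            "Negative num_cached_tokens (%d) detected for request %s, clamping to 0",
            num_cached_tokens,
            request_id,
        )
        return 0
    return num_cached_tokens
```

Each `num_cached_tokens=max(0, request.num_cached_tokens)` site would then become `num_cached_tokens=clamped_num_cached_tokens(request.num_cached_tokens, request.request_id)`.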

        config.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
    )
    return inv_freq, 1.0

Contributor


Nit: _default_rope_init doesn't close over anything — pull it out to module level so you're not creating a new function object per instance.

events=request.take_events(),
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
Contributor


+1 to adding a logger.warning when clamping fires. Silent clamps on negative values will mask whatever upstream bug is producing them.

Contributor

@lishunyang12 left a comment


Left a couple of minor comments. The fixes look correct overall.

Signed-off-by: linyueqian <linyueqian@outlook.com>
…known types

Signed-off-by: linyueqian <linyueqian@outlook.com>
…hed_tokens

Signed-off-by: linyueqian <linyueqian@outlook.com>
@linyueqian force-pushed the bugfix/tts-transformers5-compat branch from 50daf92 to d911ac2 (February 28, 2026 04:06)
@linyueqian
Contributor Author

@hsliuustc0106 check this again?
