@@ -19,6 +19,8 @@
 from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
+import torch
+
 from .....types import (
     ChatCompletion,
     ChatCompletionAudio,
@@ -35,20 +37,31 @@
 
 @register_transformer
 @register_non_default_model("qwen2.5-omni")
-class Qwen2_5OmniChatModel(PytorchMultiModalModel):
+@register_non_default_model("Qwen3-Omni-Thinking")
+@register_non_default_model("Qwen3-Omni-Instruct")
+class QwenOmniChatModel(PytorchMultiModalModel):
     DEFAULT_SYSTEM_PROMPT = (
         "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
         "capable of perceiving auditory and visual inputs, as well as generating text and speech."
     )
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Distinguish Qwen2.5-Omni ("2.5") from Qwen3-Omni ("3")
+        model_family = self.model_family.model_family or self.model_family.model_name
+        self._omni_version = "2.5" if "2.5" in model_family else "3"
+
     @classmethod
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
-        if "qwen2.5-omni".lower() in llm_family.lower():
+        if (
+            "qwen2.5-omni".lower() in llm_family.lower()
+            or "qwen3-omni".lower() in llm_family.lower()
+        ):
             return True
         return False
 
@@ -58,15 +71,25 @@ def decide_device(self):
         self._device = device
 
     def load_processor(self):
-        from transformers import Qwen2_5OmniProcessor
+        if self._omni_version == "2.5":
+            from transformers import Qwen2_5OmniProcessor as QwenOmniProcessor
+        else:
+            from transformers import Qwen3OmniMoeProcessor as QwenOmniProcessor
 
-        self._processor = Qwen2_5OmniProcessor.from_pretrained(
+        self._processor = QwenOmniProcessor.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         self._tokenizer = self._processor.tokenizer
 
     def load_multimodal_model(self):
-        from transformers import Qwen2_5OmniForConditionalGeneration
+        if self._omni_version == "2.5":
+            from transformers import (
+                Qwen2_5OmniForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )
+        else:
+            from transformers import (
+                Qwen3OmniMoeForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )
 
         # For multiple GPUs, set device_map back to "auto" so all devices are used
         device = "auto" if self._device == "cuda" else self._device
@@ -79,7 +102,7 @@ def load_multimodal_model(self):
         kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)
 
-        self._model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+        self._model = QwenOmniForConditionalGeneration.from_pretrained(
             self.model_path,
             torch_dtype="auto",
             device_map=device,
@@ -181,11 +204,37 @@ def generate_non_streaming(
         inputs = self.build_inputs_from_messages(messages, generate_config)  # type: ignore
         use_audio_in_video = generate_config.get("use_audio_in_video", True)
         gen_kwargs = dict(**inputs, **config, use_audio_in_video=use_audio_in_video)
-        generated_ids, audio = self._model.generate(**gen_kwargs)
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):]
-            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
+        # === Run model.generate() (handle both (ids, audio) and ids-only cases) ===
+        result = self._model.generate(**gen_kwargs)
+        if isinstance(result, tuple) and len(result) == 2:
+            # Qwen2.5-Omni returns (generated_ids, audio)
+            generated_ids, audio = result
+        else:
+            # Qwen3-Omni returns only generated_ids
+            generated_ids, audio = result, None
+        if hasattr(generated_ids, "sequences"):
+            generated_ids = generated_ids.sequences
+
+        # === Handle text decoding ===
+        input_len = inputs.input_ids.shape[1]
+        # Ensure we have a consistent 2D structure:
+        # normalize generated_ids to list[list[int]]
+        if isinstance(generated_ids, torch.Tensor):
+            generated_ids = generated_ids.tolist()
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, int) for x in generated_ids
+        ):
+            # A single sequence came back as a flat list of ints
+            generated_ids = [generated_ids]
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, list) for x in generated_ids
+        ):
+            pass  # already list[list[int]]
+        else:
+            raise TypeError(f"Unexpected generated_ids type: {type(generated_ids)}")
+
+        # Remove the prompt tokens before decoding
+        generated_ids_trimmed = [out_ids[input_len:] for out_ids in generated_ids]
         output_text = self._processor.batch_decode(
             generated_ids_trimmed,
             skip_special_tokens=True,
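
Note: the 2.5-vs-3 switch above is repeated in __init__, load_processor, and load_multimodal_model. A minimal sketch of the same dispatch collected into one helper, assuming the transformers classes this diff imports are available (pick_omni_classes is a hypothetical name, not part of this change):

def pick_omni_classes(model_family: str):
    # "2.5" in the family name selects Qwen2.5-Omni; anything else is Qwen3-Omni
    if "2.5" in model_family:
        from transformers import Qwen2_5OmniForConditionalGeneration as model_cls
        from transformers import Qwen2_5OmniProcessor as processor_cls
    else:
        from transformers import Qwen3OmniMoeForConditionalGeneration as model_cls
        from transformers import Qwen3OmniMoeProcessor as processor_cls
    return model_cls, processor_cls

# Usage: model_cls, processor_cls = pick_omni_classes("qwen2.5-omni")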
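The generate() handling is the subtle part of this change: Qwen2.5-Omni returns an (ids, audio) tuple, while Qwen3-Omni returns ids only, possibly wrapped in a GenerateOutput object whose token ids live in .sequences. A self-contained sketch of that normalization, assuming only torch as a dependency (_normalize_generate_output is a hypothetical name, not part of this change):

from typing import Any, List, Optional, Tuple

import torch


def _normalize_generate_output(result: Any) -> Tuple[List[List[int]], Optional[Any]]:
    """Reduce the possible generate() return shapes to (list[list[int]], audio)."""
    # Tuple form: (generated_ids, audio); otherwise ids only, no audio.
    if isinstance(result, tuple) and len(result) == 2:
        generated_ids, audio = result
    else:
        generated_ids, audio = result, None
    # GenerateOutput objects carry the token ids in a .sequences attribute.
    if hasattr(generated_ids, "sequences"):
        generated_ids = generated_ids.sequences
    if isinstance(generated_ids, torch.Tensor):
        generated_ids = generated_ids.tolist()
    elif isinstance(generated_ids, list) and all(isinstance(x, int) for x in generated_ids):
        generated_ids = [generated_ids]  # single flat sequence
    elif not (isinstance(generated_ids, list) and all(isinstance(x, list) for x in generated_ids)):
        raise TypeError(f"Unexpected generated_ids type: {type(generated_ids)}")
    return generated_ids, audio


# Both return conventions normalize to the same shape:
ids = torch.tensor([[1, 2, 3, 4]])
assert _normalize_generate_output((ids, "wav"))[0] == [[1, 2, 3, 4]]
assert _normalize_generate_output(ids) == ([[1, 2, 3, 4]], None)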