Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/worker/test_omni_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
# Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward
runner.input_batch = DummyInputBatch(list(req_ids))
runner.requests = {rid: DummyReqState() for rid in req_ids}
runner.model_intermediate_buffer = {}

# query_start_loc.cpu[req_index] is used to locate the token position
# in the flattened `inputs_embeds`.
Expand Down Expand Up @@ -132,3 +133,58 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):

# Ensure no changes were made
assert torch.allclose(inputs_embeds, before)


def test_merge_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch):
    """_merge_intermediate_buffer must populate model_intermediate_buffer (the
    forward path) and also mirror the same data onto the request's
    additional_information_cpu attribute (backward-compat path)."""
    import vllm_omni.worker.gpu_model_runner as mod

    monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)

    runner = _make_runner(req_ids=("r1",), hidden_size=4)

    expected_tensor = torch.tensor([1.0, 2.0])
    OmniGPUModelRunner._merge_intermediate_buffer(
        runner, "r1", {"my_tensor": expected_tensor, "my_list": [3, 4]}
    )

    # Forward path: the shared buffer now holds the merged entries.
    assert "r1" in runner.model_intermediate_buffer
    merged = runner.model_intermediate_buffer["r1"]
    assert merged["my_list"] == [3, 4]
    assert torch.allclose(merged["my_tensor"], expected_tensor)

    # Backward-compat path: the per-request attribute mirror matches too.
    mirrored = runner.requests["r1"].additional_information_cpu
    assert mirrored["my_list"] == [3, 4]
    assert torch.allclose(mirrored["my_tensor"], expected_tensor)


def test_merge_intermediate_buffer_accumulates():
    """Keys from successive merges must all accumulate in the per-request buffer."""
    runner = _make_runner(req_ids=("r1",), hidden_size=4)

    # Merge two single-key updates back to back; neither should clobber the other.
    for key, value in (("a", torch.tensor([1.0])), ("b", torch.tensor([2.0]))):
        OmniGPUModelRunner._merge_intermediate_buffer(runner, "r1", {key: value})

    merged = runner.model_intermediate_buffer["r1"]
    assert "a" in merged and "b" in merged
    assert torch.allclose(merged["a"], torch.tensor([1.0]))
    assert torch.allclose(merged["b"], torch.tensor([2.0]))


def test_merge_intermediate_buffer_skips_empty_update():
    """Merging an empty update dict must be a complete no-op."""
    runner = _make_runner(req_ids=("r1",), hidden_size=4)

    OmniGPUModelRunner._merge_intermediate_buffer(runner, "r1", {})

    # No buffer entry should have been created for the request.
    assert "r1" not in runner.model_intermediate_buffer


def test_merge_intermediate_buffer_skips_unknown_req_id():
    """A req_id absent from self.requests must be ignored entirely."""
    runner = _make_runner(req_ids=("r1",), hidden_size=4)

    payload = {"key": torch.tensor([1.0])}
    OmniGPUModelRunner._merge_intermediate_buffer(runner, "unknown_req", payload)

    # The unknown request must not have been added to the buffer.
    assert "unknown_req" not in runner.model_intermediate_buffer
18 changes: 16 additions & 2 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class Qwen3OmniMoeForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.have_multimodal_outputs = True
self._additional_info_compat_warned: bool = False
self.has_preprocess = False
self.has_postprocess = False
config: Qwen3OmniMoeConfig = vllm_config.model_config.hf_config
Expand Down Expand Up @@ -437,11 +438,24 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
talker_hidden = model_outputs
# merge the code_predictor_codes from the info_dict list into a single tensor
multimodal_outputs: dict = None
# Here is the only place to use runtime_additional_information. After MTP in the
# Here is the only place to use model_intermediate_buffer. After MTP in the
# preprocess function, the code_predictor_codes are stored in the info_dict list.
# We need to merge the tensors from different requests into a single tensor.
# In the future, we may allow user to custom an aggregated function.
info_dicts = kwargs.get("runtime_additional_information")
info_dicts = kwargs.get("model_intermediate_buffer")
if info_dicts is None:
info_dicts = kwargs.get("runtime_additional_information")

if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
if not self._additional_info_compat_warned:
self._additional_info_compat_warned = True
import warnings

warnings.warn(
"runtime_additional_information is deprecated, use model_intermediate_buffer",
DeprecationWarning,
stacklevel=2,
)
code_predictor_codes = [info.get("code_predictor_codes") for info in info_dicts]
multimodal_outputs = {"code_predictor_codes": torch.cat(code_predictor_codes, dim=0)}
span_len = multimodal_outputs["code_predictor_codes"].shape[0]
Expand Down
17 changes: 16 additions & 1 deletion vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class VoiceClonePromptItem:
class Qwen3TTSModelForGeneration(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self._additional_info_compat_warned: bool = False
model_path = vllm_config.model_config.model

# Check if flash-attn is installed
Expand Down Expand Up @@ -112,9 +113,23 @@ def forward(

# Extract additional parameters from kwargs that the generation methods expect

runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
runtime_additional_information = kwargs.get("model_intermediate_buffer")
if runtime_additional_information is None:
runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
if not self._additional_info_compat_warned:
self._additional_info_compat_warned = True
import warnings

warnings.warn(
"runtime_additional_information is deprecated, use model_intermediate_buffer",
DeprecationWarning,
stacklevel=2,
)
if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0:
runtime_additional_information = runtime_additional_information[0]
# Copy to avoid mutating the shared buffer dict
runtime_additional_information = dict(runtime_additional_information)
text = runtime_additional_information.pop("text", [""])[0]
# Extract task_type from kwargs, default to "instruct"
task_type = runtime_additional_information.pop("task_type", [self.task_type])[0]
Expand Down
15 changes: 14 additions & 1 deletion vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.have_multimodal_outputs = True
self.has_preprocess = True
self.has_postprocess = True
self._additional_info_compat_warned: bool = False

# Used by OmniGPUModelRunner for the GPU-side MTP fast-path.
self.mtp_hidden_size = int(self.talker_config.hidden_size)
Expand Down Expand Up @@ -441,7 +442,19 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
return model_outputs

hidden = model_outputs
info_dicts = kwargs.get("runtime_additional_information") or []
info_dicts = kwargs.get("model_intermediate_buffer")
if info_dicts is None:
info_dicts = kwargs.get("runtime_additional_information") or []
if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
if not self._additional_info_compat_warned:
self._additional_info_compat_warned = True
import warnings

warnings.warn(
"runtime_additional_information is deprecated, use model_intermediate_buffer",
DeprecationWarning,
stacklevel=2,
)
audio_codes_list: list[torch.Tensor] = []
ref_code_len_list: list[torch.Tensor] = []
codec_streaming_list: list[torch.Tensor] = []
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
if not req_state:
return req_id

add_info = getattr(req_state, "additional_information_cpu", {}) or {}
add_info = self.model_intermediate_buffer.get(req_id, {})
global_id = add_info.get("global_request_id")
if global_id:
if isinstance(global_id, list) and global_id:
Expand Down
Loading