56 changes: 56 additions & 0 deletions tests/worker/test_omni_gpu_model_runner.py
@@ -64,6 +64,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
# Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward
runner.input_batch = DummyInputBatch(list(req_ids))
runner.requests = {rid: DummyReqState() for rid in req_ids}
runner.model_intermediate_buffer = {}

# query_start_loc.cpu[req_index] is used to locate the token position
# in the flattened `inputs_embeds`.
@@ -167,6 +168,61 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
assert torch.allclose(inputs_embeds, before)


def test_update_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch):
"""Validate that _update_intermediate_buffer writes to model_intermediate_buffer
(forward path) and mirrors to additional_information_cpu setattr (backward compat)."""
import vllm_omni.worker.gpu_model_runner as mod

monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)

runner = _make_runner(req_ids=("r1",), hidden_size=4)

update = {"my_tensor": torch.tensor([1.0, 2.0]), "my_list": [3, 4]}
OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", update)

# Forward: buffer is populated
assert "r1" in runner.model_intermediate_buffer
buf = runner.model_intermediate_buffer["r1"]
assert torch.allclose(buf["my_tensor"], torch.tensor([1.0, 2.0]))
assert buf["my_list"] == [3, 4]

# Backward compat: the legacy attribute is also populated
info_cpu = runner.requests["r1"].additional_information_cpu
assert torch.allclose(info_cpu["my_tensor"], torch.tensor([1.0, 2.0]))
assert info_cpu["my_list"] == [3, 4]


def test_update_intermediate_buffer_accumulates():
"""Validate that successive merges accumulate keys in the buffer."""
runner = _make_runner(req_ids=("r1",), hidden_size=4)

OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {"a": torch.tensor([1.0])})
OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {"b": torch.tensor([2.0])})

buf = runner.model_intermediate_buffer["r1"]
assert "a" in buf and "b" in buf
assert torch.allclose(buf["a"], torch.tensor([1.0]))
assert torch.allclose(buf["b"], torch.tensor([2.0]))


def test_update_intermediate_buffer_skips_empty_update():
"""Validate that an empty update dict is a no-op."""
runner = _make_runner(req_ids=("r1",), hidden_size=4)

OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {})

assert "r1" not in runner.model_intermediate_buffer


def test_update_intermediate_buffer_skips_unknown_req_id():
"""Validate that merge is a no-op when req_id is not in self.requests."""
runner = _make_runner(req_ids=("r1",), hidden_size=4)

OmniGPUModelRunner._update_intermediate_buffer(runner, "unknown_req", {"key": torch.tensor([1.0])})

assert "unknown_req" not in runner.model_intermediate_buffer


def test_maybe_attach_mimo_audio_req_infos_enriches_dict():
runner = _make_runner_for_mimo()
req_id = "r_mimo"
9 changes: 7 additions & 2 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -410,11 +410,16 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
talker_hidden = model_outputs
# merge the code_predictor_codes from the info_dict list into a single tensor
multimodal_outputs: dict = None
# Here is the only place to use runtime_additional_information. After MTP in the
# Here is the only place to use model_intermediate_buffer. After MTP in the
# preprocess function, the code_predictor_codes are stored in the info_dict list.
# We need to merge the tensors from different requests into a single tensor.
# In the future, we may allow users to supply a custom aggregation function.
info_dicts = kwargs.get("runtime_additional_information")
info_dicts = kwargs.get("model_intermediate_buffer")
if info_dicts is None:
info_dicts = kwargs.get("runtime_additional_information")

if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
code_predictor_codes = [info.get("code_predictor_codes") for info in info_dicts]
multimodal_outputs = {"code_predictor_codes": torch.cat(code_predictor_codes, dim=0)}
span_len = multimodal_outputs["code_predictor_codes"].shape[0]
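The same new-key/old-key fallback recurs below in qwen3_tts.py and in a second make_omni_output hunk. A minimal sketch of the pattern in isolation (the helper name is hypothetical and not part of this PR; the stdlib logger stands in for vLLM's warning_once):

import logging

logger = logging.getLogger(__name__)


def get_intermediate_buffer(kwargs: dict, default=None):
    # Prefer the new kwarg; fall back to the deprecated name.
    buf = kwargs.get("model_intermediate_buffer")
    if buf is None:
        buf = kwargs.get("runtime_additional_information", default)
    # Warn when only the deprecated kwarg was supplied.
    if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
        logger.warning("runtime_additional_information is deprecated, use model_intermediate_buffer")
    return buf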
8 changes: 7 additions & 1 deletion vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
import io
import urllib.request
from collections.abc import Iterable
@@ -151,9 +152,14 @@ def forward(

# Extract additional parameters from kwargs that the generation methods expect

runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
runtime_additional_information = kwargs.get("model_intermediate_buffer")
if runtime_additional_information is None:
runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0:
runtime_additional_information = runtime_additional_information[0]
runtime_additional_information = copy.deepcopy(runtime_additional_information)
text = runtime_additional_information.pop("text", [""])[0]
# Extract task_type from kwargs, default to self.task_type
task_type = _normalize_task_type(runtime_additional_information.pop("task_type", [self.task_type])[0])
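The copy.deepcopy added above matters because the buffer entry is shared per request and the lines that follow pop keys out of it. A standalone illustration of the hazard the copy avoids (toy data, not project code):

import copy

shared = {"text": ["hello"], "task_type": ["tts"]}

# Without the copy, pop() would strip keys from the shared buffer entry,
# breaking any later step that reads the same request's buffer.
local = copy.deepcopy(shared)
text = local.pop("text", [""])[0]
assert "text" in shared  # the shared entry is untouched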
@@ -332,7 +332,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.have_multimodal_outputs = True
self.has_preprocess = True
self.has_postprocess = True

# Used by OmniGPUModelRunner for the GPU-side MTP fast-path.
self.mtp_hidden_size = int(self.talker_config.hidden_size)
# OmniGPUModelRunner will store talker_mtp output under this key in
@@ -441,7 +440,11 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
return model_outputs

hidden = model_outputs
info_dicts = kwargs.get("runtime_additional_information") or []
info_dicts = kwargs.get("model_intermediate_buffer")
if info_dicts is None:
info_dicts = kwargs.get("runtime_additional_information") or []
if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
audio_codes_list: list[torch.Tensor] = []
ref_code_len_list: list[torch.Tensor] = []
codec_streaming_list: list[torch.Tensor] = []
2 changes: 1 addition & 1 deletion vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
@@ -586,7 +586,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
if not req_state:
return req_id

add_info = getattr(req_state, "additional_information_cpu", {}) or {}
add_info = self.model_intermediate_buffer.get(req_id, {})
global_id = add_info.get("global_request_id")
if global_id:
if isinstance(global_id, list) and global_id:
2 changes: 1 addition & 1 deletion vllm_omni/worker/gpu_ar_model_runner.py
@@ -590,7 +590,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
if not req_state:
return req_id

add_info = getattr(req_state, "additional_information_cpu", {}) or {}
add_info = self.model_intermediate_buffer.get(req_id, {})
global_id = add_info.get("global_request_id")
if global_id:
if isinstance(global_id, list) and global_id:
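Both AR runners now resolve the global request id from model_intermediate_buffer instead of the legacy additional_information_cpu attribute. A simplified, self-contained sketch of the lookup (list unwrapping as in the diff; the final fallback to req_id is assumed from the truncated context):

def resolve_global_request_id(buffer: dict[str, dict], req_id: str) -> str:
    # Fall back to the local id when no buffer entry exists.
    add_info = buffer.get(req_id, {})
    global_id = add_info.get("global_request_id")
    # Payloads may arrive as single-element lists; unwrap before use.
    if isinstance(global_id, list) and global_id:
        global_id = global_id[0]
    return str(global_id) if global_id else req_id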
84 changes: 45 additions & 39 deletions vllm_omni/worker/gpu_model_runner.py
@@ -37,7 +37,7 @@
class OmniGPUModelRunner(GPUModelRunner):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._omni_per_req_additional_information: dict[str, dict] | None = None
self.model_intermediate_buffer: dict[str, dict[str, Any]] = {}
self._omni_num_scheduled_tokens_np: np.ndarray | None = None
self._omni_last_model_output: object | None = None

@@ -234,6 +234,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.model_intermediate_buffer.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -328,6 +329,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Decode additional_information payloads (dictionary)
try:
if getattr(new_req_data, "additional_information", None) is not None:
logger.warning_once(
"additional_information on request data is deprecated, use model_intermediate_buffer"
)
payload_info = new_req_data.additional_information
info_dict = {}
if isinstance(payload_info, dict):
@@ -345,6 +349,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
else:
info_dict[k] = entry.list_data
if info_dict:
self.model_intermediate_buffer[req_id] = info_dict
# Backward compatible: mirror to old setattr location
setattr(
self.requests[req_id],
"additional_information_cpu",
@@ -873,6 +879,9 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
# additional_information
payload_info = getattr(nr, "additional_information", None)
if payload_info is not None:
logger.warning_once(
"additional_information on request data is deprecated, use model_intermediate_buffer"
)
info_dict = {}
if isinstance(payload_info, dict):
info_dict = payload_info
@@ -891,17 +900,18 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
else:
info_dict[k] = getattr(entry, "list_data", None)
if info_dict and req_id in self.requests:
self.model_intermediate_buffer[req_id] = info_dict
# Backward compatible: mirror to old setattr location
setattr(self.requests[req_id], "additional_information_cpu", info_dict)
except Exception as e:
logger.error(f"Error decoding prompt_embeds / additional_information: {e}")

def _gather_runtime_additional_information(self) -> list[dict]:
"""Gather per-request additional_information stored in request state in batch order."""
"""Gather per-request model_intermediate_buffer in batch order."""
per_req_runtime_info = []
for req_id in self.input_batch.req_ids:
req_state = self.requests.get(req_id)
info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
if info and isinstance(info, dict):
info = self.model_intermediate_buffer.get(req_id, {})
if info:
per_req_runtime_info.append(info)
if "thinker_reply_part_per_request" in info:
q = info["thinker_reply_part_per_request"]
@@ -921,12 +931,13 @@ def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[in
return req_token_spans

def _build_model_kwargs_extra(self) -> dict:
"""Build extra keyword arguments passed to the model for this step, including:
- runtime_additional_information: per-request additional information stored in request state
"""
"""Build extra keyword arguments passed to the model for this step."""
model_kwargs_extra: dict[str, object] = {}
try:
model_kwargs_extra["runtime_additional_information"] = self._gather_runtime_additional_information()
buffer_map = self._gather_runtime_additional_information()
model_kwargs_extra["model_intermediate_buffer"] = buffer_map
# Backward compatible: also emit old name
model_kwargs_extra["runtime_additional_information"] = buffer_map
except Exception as e:
logger.error(f"[OMNI DEBUG] Error building model_kwargs_extra: {e}")
import traceback
@@ -941,23 +952,20 @@ def _process_additional_information_updates(
num_scheduled_tokens_np: np.ndarray,
scheduler_output: "SchedulerOutput",
) -> None:
"""Process model-provided per-request additional_information updates and merge into request state."""
"""Process model-provided per-request updates and merge into model_intermediate_buffer."""
try:
# execute the custom postprocess function
# TODO(Peiqi): do we have a more elegant way to do this?
if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests.get(req_id)
req_infos = (
getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
)
req_infos = self.model_intermediate_buffer.get(req_id, {})
start_offset = int(self.query_start_loc.cpu[req_index])
sched_tokens = int(num_scheduled_tokens_np[req_index])
s, e = start_offset, start_offset + sched_tokens
# Only data returned in the update dict is stored.
hidden_states_slice = hidden_states[s:e]
update_dict = self.model.postprocess(hidden_states_slice, **req_infos)
self._merge_additional_information_update(req_id, update_dict)
self._update_intermediate_buffer(req_id, update_dict)
except Exception as e:
logger.error(
f"Error merging for requests:{self.input_batch.req_ids} "
@@ -991,27 +999,23 @@ def _collect_additional_information_for_prefill(
start_offset = int(self.query_start_loc.cpu[req_index])
self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src)

def _update_request_information(self, request_id: str, payload_info: dict) -> None:
"""Update per-request additional_information stored in request state."""
req_state = self.requests.get(request_id)
if req_state is None:
return

info_dict = getattr(req_state, "additional_information_cpu", None)
if isinstance(payload_info, dict) and info_dict is not None:
info_dict.update(payload_info)

def _update_additional_information(self, scheduler_output: "SchedulerOutput") -> None:
for new_req in scheduler_output.scheduled_new_reqs:
payload_info = getattr(new_req, "additional_information", None)
if isinstance(payload_info, dict):
self._update_request_information(new_req.req_id, payload_info)
logger.warning_once(
"additional_information on request data is deprecated, use model_intermediate_buffer"
)
self._update_intermediate_buffer(new_req.req_id, payload_info)

if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"):
logger.warning_once(
"additional_information on scheduled_cached_reqs is deprecated, use model_intermediate_buffer"
)
cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {})
if isinstance(cached_infos, dict):
for req_id, req_infos in cached_infos.items():
self._update_request_information(req_id, req_infos)
self._update_intermediate_buffer(req_id, req_infos)

def _maybe_attach_mimo_audio_req_infos(
self,
@@ -1158,10 +1162,10 @@ def _preprocess(
if self.vllm_config.model_config.async_chunk:
self._update_additional_information(scheduler_output)
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests.get(req_id)
req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
req_infos = self.model_intermediate_buffer.get(req_id, {})

# mimo-audio check
req_state = self.requests.get(req_id)
req_infos = self._maybe_attach_mimo_audio_req_infos(req_state, req_infos, req_id)

start_offset = int(self.query_start_loc.cpu[req_index])
@@ -1270,23 +1274,25 @@ def _model_forward(
self._omni_last_model_output = model_output
return model_output

def _merge_additional_information_update(self, req_id: str, upd: dict | None) -> None:
if not isinstance(upd, dict):
def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None:
if not isinstance(upd, dict) or not upd:
return
req_state = self.requests.get(req_id)
if req_state is None:
return
existing = getattr(req_state, "additional_information_cpu", {})
if not isinstance(existing, dict):
existing = {}
merged = dict(existing)
existing = self.model_intermediate_buffer.setdefault(req_id, {})
for k, v in upd.items():
if isinstance(v, torch.Tensor):
merged[k] = v.detach().to("cpu").contiguous()
existing[k] = v.detach().to("cpu").contiguous()
elif isinstance(v, list):
merged[k] = [
existing[k] = [
(item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v
]
else:
merged[k] = v
setattr(req_state, "additional_information_cpu", merged)
existing[k] = v
# Backward compatible: mirror to old setattr location
setattr(req_state, "additional_information_cpu", existing)

def _merge_additional_information_update(self, req_id, upd):
logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer")
return self._update_intermediate_buffer(req_id, upd)
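
Taken together, the buffer lifecycle is: populate on request arrival or via postprocess merges, read in batch order when building model kwargs, and drop the entry when the request finishes. A compact, runnable sketch of that lifecycle, with plain dicts standing in for runner state (illustrative only):

import torch

model_intermediate_buffer: dict[str, dict] = {}

# 1. Merge a postprocess update; tensors are detached to CPU as in the diff.
upd = {"codes": torch.tensor([1.0, 2.0])}
buf = model_intermediate_buffer.setdefault("r1", {})
for k, v in upd.items():
    buf[k] = v.detach().to("cpu").contiguous() if isinstance(v, torch.Tensor) else v

# 2. Gather per-request entries in batch order (cf. _gather_runtime_additional_information).
batch_order = ["r1"]
per_req = [model_intermediate_buffer.get(rid, {}) for rid in batch_order]

# 3. Drop the entry when the request finishes (cf. _update_states).
model_intermediate_buffer.pop("r1", None)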