
Commit 68764cc

[Refactor]: Phase1 for rebasing_additional_info (#1394)
Signed-off-by: dsinghvi <divyanshsinghvi@gmail.com>
Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com>
Signed-off-by: Zhou Taichang <tzhouam@connect.ust.hk>
Co-authored-by: Zhou Taichang <tzhouam@connect.ust.hk>
1 parent fec0182 commit 68764cc

7 files changed: +122 −46 lines changed


tests/worker/test_omni_gpu_model_runner.py

Lines changed: 56 additions & 0 deletions
@@ -64,6 +64,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
     # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward
     runner.input_batch = DummyInputBatch(list(req_ids))
     runner.requests = {rid: DummyReqState() for rid in req_ids}
+    runner.model_intermediate_buffer = {}
 
     # query_start_loc.cpu[req_index] is used to locate the token position
     # in the flattened `inputs_embeds`.
@@ -167,6 +168,61 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
     assert torch.allclose(inputs_embeds, before)
 
 
+def test_update_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch):
+    """Validate that _update_intermediate_buffer writes to model_intermediate_buffer
+    (forward path) and mirrors to additional_information_cpu setattr (backward compat)."""
+    import vllm_omni.worker.gpu_model_runner as mod
+
+    monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
+
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    update = {"my_tensor": torch.tensor([1.0, 2.0]), "my_list": [3, 4]}
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", update)
+
+    # Forward: buffer is populated
+    assert "r1" in runner.model_intermediate_buffer
+    buf = runner.model_intermediate_buffer["r1"]
+    assert torch.allclose(buf["my_tensor"], torch.tensor([1.0, 2.0]))
+    assert buf["my_list"] == [3, 4]
+
+    # Backward compat: setattr is also populated
+    info_cpu = runner.requests["r1"].additional_information_cpu
+    assert torch.allclose(info_cpu["my_tensor"], torch.tensor([1.0, 2.0]))
+    assert info_cpu["my_list"] == [3, 4]
+
+
+def test_update_intermediate_buffer_accumulates():
+    """Validate that successive merges accumulate keys in the buffer."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {"a": torch.tensor([1.0])})
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {"b": torch.tensor([2.0])})
+
+    buf = runner.model_intermediate_buffer["r1"]
+    assert "a" in buf and "b" in buf
+    assert torch.allclose(buf["a"], torch.tensor([1.0]))
+    assert torch.allclose(buf["b"], torch.tensor([2.0]))
+
+
+def test_update_intermediate_buffer_skips_empty_update():
+    """Validate that an empty update dict is a no-op."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "r1", {})
+
+    assert "r1" not in runner.model_intermediate_buffer
+
+
+def test_update_intermediate_buffer_skips_unknown_req_id():
+    """Validate that merge is a no-op when req_id is not in self.requests."""
+    runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+    OmniGPUModelRunner._update_intermediate_buffer(runner, "unknown_req", {"key": torch.tensor([1.0])})
+
+    assert "unknown_req" not in runner.model_intermediate_buffer
+
+
 def test_maybe_attach_mimo_audio_req_infos_enriches_dict():
     runner = _make_runner_for_mimo()
     req_id = "r_mimo"
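
For readers skimming the diff, the merge semantics these tests pin down can be summarized outside the runner. The sketch below is illustrative only (a stand-in `_TinyRunner`, not the vLLM-Omni class); it mirrors the `_update_intermediate_buffer` behaviour added further down in this commit: dict updates are merged per request id, tensors are detached to CPU, and empty or unknown-request updates are no-ops.

import torch

class _TinyRunner:  # stand-in for OmniGPUModelRunner state, illustration only
    def __init__(self, req_ids):
        self.requests = {rid: object() for rid in req_ids}
        self.model_intermediate_buffer: dict[str, dict] = {}

    def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None:
        if not isinstance(upd, dict) or not upd:      # empty update -> no-op
            return
        if req_id not in self.requests:               # unknown request -> no-op
            return
        buf = self.model_intermediate_buffer.setdefault(req_id, {})
        for k, v in upd.items():                      # tensors land on CPU, detached
            buf[k] = v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v

runner = _TinyRunner(req_ids=("r1",))
runner._update_intermediate_buffer("r1", {"a": torch.tensor([1.0])})
runner._update_intermediate_buffer("r1", {"b": [3, 4]})
assert set(runner.model_intermediate_buffer["r1"]) == {"a", "b"}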

vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py

Lines changed: 7 additions & 2 deletions
@@ -410,11 +410,16 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
         talker_hidden = model_outputs
         # merge the code_predictor_codes from the info_dict list into a single tensor
         multimodal_outputs: dict = None
-        # Here is the only place to use runtime_additional_information. After MTP in the
+        # Here is the only place to use model_intermediate_buffer. After MTP in the
         # preprocess function, the code_predictor_codes are stored in the info_dict list.
         # We need to merge the tensors from different requests into a single tensor.
         # In the future, we may allow user to custom an aggregated function.
-        info_dicts = kwargs.get("runtime_additional_information")
+        info_dicts = kwargs.get("model_intermediate_buffer")
+        if info_dicts is None:
+            info_dicts = kwargs.get("runtime_additional_information")
+
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
         code_predictor_codes = [info.get("code_predictor_codes") for info in info_dicts]
         multimodal_outputs = {"code_predictor_codes": torch.cat(code_predictor_codes, dim=0)}
         span_len = multimodal_outputs["code_predictor_codes"].shape[0]
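
The same resolution pattern recurs in the TTS and talker models below: prefer the new model_intermediate_buffer kwarg, fall back to the deprecated runtime_additional_information, and warn once when only the old name is passed. A minimal sketch of that pattern as a free function, using a generic warnings-based warning instead of the project logger (the helper name resolve_intermediate_buffer is illustrative, not part of this commit):

import warnings

def resolve_intermediate_buffer(kwargs: dict, default=None):
    # Illustrative helper: the new kwarg wins, the deprecated kwarg is a fallback.
    info = kwargs.get("model_intermediate_buffer")
    if info is None:
        info = kwargs.get("runtime_additional_information", default)
    if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
        warnings.warn(
            "runtime_additional_information is deprecated, use model_intermediate_buffer",
            DeprecationWarning,
        )
    return info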

vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py

Lines changed: 7 additions & 1 deletion
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
+import copy
 import io
 import urllib.request
 from collections.abc import Iterable
@@ -151,9 +152,14 @@ def forward(
 
         # Extract additional parameters from kwargs that the generation methods expect
 
-        runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
+        runtime_additional_information = kwargs.get("model_intermediate_buffer")
+        if runtime_additional_information is None:
+            runtime_additional_information = kwargs.get("runtime_additional_information", [{}])
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
         if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0:
             runtime_additional_information = runtime_additional_information[0]
+        runtime_additional_information = copy.deepcopy(runtime_additional_information)
         text = runtime_additional_information.pop("text", [""])[0]
         # Extract task_type from kwargs, default to self.task_type
         task_type = _normalize_task_type(runtime_additional_information.pop("task_type", [self.task_type])[0])
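
The new copy.deepcopy guards against mutating shared state: forward pops keys such as "text" and "task_type" out of the per-request dict, and with this refactor that dict is (or may be) the runner-owned buffer entry rather than a throwaway kwarg, so popping in place would drop the keys for later steps. A tiny illustration of the aliasing it avoids, with hypothetical data:

import copy

shared = {"text": ["hello"], "task_type": ["tts"]}   # stand-in for the runner-owned buffer entry
aliased = shared                                     # without a copy, pop() would mutate the buffer
safe = copy.deepcopy(shared)
safe.pop("text")
assert "text" in shared and "text" not in safe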

vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py

Lines changed: 5 additions & 2 deletions
@@ -332,7 +332,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.have_multimodal_outputs = True
         self.has_preprocess = True
         self.has_postprocess = True
-
         # Used by OmniGPUModelRunner for the GPU-side MTP fast-path.
         self.mtp_hidden_size = int(self.talker_config.hidden_size)
         # OmniGPUModelRunner will store talker_mtp output under this key in
@@ -441,7 +440,11 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
             return model_outputs
 
         hidden = model_outputs
-        info_dicts = kwargs.get("runtime_additional_information") or []
+        info_dicts = kwargs.get("model_intermediate_buffer")
+        if info_dicts is None:
+            info_dicts = kwargs.get("runtime_additional_information") or []
+        if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
+            logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
         audio_codes_list: list[torch.Tensor] = []
         ref_code_len_list: list[torch.Tensor] = []
         codec_streaming_list: list[torch.Tensor] = []

vllm_omni/platforms/npu/worker/npu_ar_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -586,7 +586,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
         if not req_state:
             return req_id
 
-        add_info = getattr(req_state, "additional_information_cpu", {}) or {}
+        add_info = self.model_intermediate_buffer.get(req_id, {})
         global_id = add_info.get("global_request_id")
         if global_id:
             if isinstance(global_id, list) and global_id:

vllm_omni/worker/gpu_ar_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -590,7 +590,7 @@ def _resolve_global_request_id(self, req_id: str) -> str:
         if not req_state:
             return req_id
 
-        add_info = getattr(req_state, "additional_information_cpu", {}) or {}
+        add_info = self.model_intermediate_buffer.get(req_id, {})
         global_id = add_info.get("global_request_id")
         if global_id:
             if isinstance(global_id, list) and global_id:
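
Both AR runners (NPU and GPU) now resolve the global request id from the runner-owned buffer instead of the additional_information_cpu attribute. A hedged sketch of the lookup with a pre-populated buffer; the fall-back to req_id and the first-element handling of list payloads are assumptions about the surrounding, unchanged code:

model_intermediate_buffer = {"req-7": {"global_request_id": ["sess-1::req-7"]}}

def resolve_global_request_id(req_id: str) -> str:
    add_info = model_intermediate_buffer.get(req_id, {})
    global_id = add_info.get("global_request_id")
    if isinstance(global_id, list) and global_id:
        global_id = global_id[0]   # assumption: list payloads carry the id as their first element
    return global_id or req_id

assert resolve_global_request_id("req-7") == "sess-1::req-7"
assert resolve_global_request_id("missing") == "missing"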

vllm_omni/worker/gpu_model_runner.py

Lines changed: 45 additions & 39 deletions
@@ -37,7 +37,7 @@
 class OmniGPUModelRunner(GPUModelRunner):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._omni_per_req_additional_information: dict[str, dict] | None = None
+        self.model_intermediate_buffer: dict[str, dict[str, Any]] = {}
         self._omni_num_scheduled_tokens_np: np.ndarray | None = None
         self._omni_last_model_output: object | None = None
 
@@ -234,6 +234,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.model_intermediate_buffer.pop(req_id, None)
             self.num_prompt_logprobs.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -328,6 +329,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Decode additional_information payloads (dictionary)
             try:
                 if getattr(new_req_data, "additional_information", None) is not None:
+                    logger.warning_once(
+                        "additional_information on request data is deprecated, use model_intermediate_buffer"
+                    )
                     payload_info = new_req_data.additional_information
                     info_dict = {}
                     if isinstance(payload_info, dict):
@@ -345,6 +349,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                         else:
                             info_dict[k] = entry.list_data
                     if info_dict:
+                        self.model_intermediate_buffer[req_id] = info_dict
+                        # Backward compatible: mirror to old setattr location
                         setattr(
                             self.requests[req_id],
                             "additional_information_cpu",
@@ -873,6 +879,9 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
                 # additional_information
                 payload_info = getattr(nr, "additional_information", None)
                 if payload_info is not None:
+                    logger.warning_once(
+                        "additional_information on request data is deprecated, use model_intermediate_buffer"
+                    )
                     info_dict = {}
                     if isinstance(payload_info, dict):
                         info_dict = payload_info
@@ -891,17 +900,18 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput"
                             else:
                                 info_dict[k] = getattr(entry, "list_data", None)
                     if info_dict and req_id in self.requests:
+                        self.model_intermediate_buffer[req_id] = info_dict
+                        # Backward compatible: mirror to old setattr location
                         setattr(self.requests[req_id], "additional_information_cpu", info_dict)
         except Exception as e:
             logger.error(f"Error decoding prompt_embeds / additional_information: {e}")
 
     def _gather_runtime_additional_information(self) -> list[dict]:
-        """Gather per-request additional_information stored in request state in batch order."""
+        """Gather per-request model_intermediate_buffer in batch order."""
         per_req_runtime_info = []
         for req_id in self.input_batch.req_ids:
-            req_state = self.requests.get(req_id)
-            info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
-            if info and isinstance(info, dict):
+            info = self.model_intermediate_buffer.get(req_id, {})
+            if info:
                 per_req_runtime_info.append(info)
                 if "thinker_reply_part_per_request" in info:
                     q = info["thinker_reply_part_per_request"]
@@ -921,12 +931,13 @@ def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[in
         return req_token_spans
 
     def _build_model_kwargs_extra(self) -> dict:
-        """Build extra keyword arguments passed to the model for this step, including:
-        - runtime_additional_information: per-request additional information stored in request state
-        """
+        """Build extra keyword arguments passed to the model for this step."""
         model_kwargs_extra: dict[str, object] = {}
         try:
-            model_kwargs_extra["runtime_additional_information"] = self._gather_runtime_additional_information()
+            buffer_map = self._gather_runtime_additional_information()
+            model_kwargs_extra["model_intermediate_buffer"] = buffer_map
+            # Backward compatible: also emit old name
+            model_kwargs_extra["runtime_additional_information"] = buffer_map
         except Exception as e:
             logger.error(f"[OMNI DEBUG] Error building model_kwargs_extra: {e}")
             import traceback
@@ -941,23 +952,20 @@ def _process_additional_information_updates(
         num_scheduled_tokens_np: np.ndarray,
         scheduler_output: "SchedulerOutput",
     ) -> None:
-        """Process model-provided per-request additional_information updates and merge into request state."""
+        """Process model-provided per-request updates and merge into model_intermediate_buffer."""
        try:
             # execute the custom postprocess function
             # TODO(Peiqi): do we have a more elegant way to do this?
             if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
                 for req_index, req_id in enumerate(self.input_batch.req_ids):
-                    req_state = self.requests.get(req_id)
-                    req_infos = (
-                        getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
-                    )
+                    req_infos = self.model_intermediate_buffer.get(req_id, {})
                     start_offset = int(self.query_start_loc.cpu[req_index])
                     sched_tokens = int(num_scheduled_tokens_np[req_index])
                     s, e = start_offset, start_offset + sched_tokens
                     # only consider to store data into update dict.
                     hidden_states_slice = hidden_states[s:e]
                     update_dict = self.model.postprocess(hidden_states_slice, **req_infos)
-                    self._merge_additional_information_update(req_id, update_dict)
+                    self._update_intermediate_buffer(req_id, update_dict)
         except Exception as e:
             logger.error(
                 f"Error merging for requests:{self.input_batch.req_ids} "
@@ -991,27 +999,23 @@ def _collect_additional_information_for_prefill(
             start_offset = int(self.query_start_loc.cpu[req_index])
             self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src)
 
-    def _update_request_information(self, request_id: str, payload_info: dict) -> None:
-        """Update per-request additional_information stored in request state."""
-        req_state = self.requests.get(request_id)
-        if req_state is None:
-            return
-
-        info_dict = getattr(req_state, "additional_information_cpu", None)
-        if isinstance(payload_info, dict) and info_dict is not None:
-            info_dict.update(payload_info)
-
     def _update_additional_information(self, scheduler_output: "SchedulerOutput") -> None:
         for new_req in scheduler_output.scheduled_new_reqs:
             payload_info = getattr(new_req, "additional_information", None)
             if isinstance(payload_info, dict):
-                self._update_request_information(new_req.req_id, payload_info)
+                logger.warning_once(
+                    "additional_information on request data is deprecated, use model_intermediate_buffer"
+                )
+                self._update_intermediate_buffer(new_req.req_id, payload_info)
 
         if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"):
+            logger.warning_once(
+                "additional_information on scheduled_cached_reqs is deprecated, use model_intermediate_buffer"
+            )
             cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {})
             if isinstance(cached_infos, dict):
                 for req_id, req_infos in cached_infos.items():
-                    self._update_request_information(req_id, req_infos)
+                    self._update_intermediate_buffer(req_id, req_infos)
 
     def _maybe_attach_mimo_audio_req_infos(
         self,
@@ -1158,10 +1162,10 @@ def _preprocess(
         if self.vllm_config.model_config.async_chunk:
             self._update_additional_information(scheduler_output)
         for req_index, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests.get(req_id)
-            req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
+            req_infos = self.model_intermediate_buffer.get(req_id, {})
 
             # mimo-audio check
+            req_state = self.requests.get(req_id)
             req_infos = self._maybe_attach_mimo_audio_req_infos(req_state, req_infos, req_id)
 
             start_offset = int(self.query_start_loc.cpu[req_index])
@@ -1270,23 +1274,25 @@ def _model_forward(
         self._omni_last_model_output = model_output
         return model_output
 
-    def _merge_additional_information_update(self, req_id: str, upd: dict | None) -> None:
-        if not isinstance(upd, dict):
+    def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None:
+        if not isinstance(upd, dict) or not upd:
             return
         req_state = self.requests.get(req_id)
         if req_state is None:
             return
-        existing = getattr(req_state, "additional_information_cpu", {})
-        if not isinstance(existing, dict):
-            existing = {}
-        merged = dict(existing)
+        existing = self.model_intermediate_buffer.setdefault(req_id, {})
         for k, v in upd.items():
             if isinstance(v, torch.Tensor):
-                merged[k] = v.detach().to("cpu").contiguous()
+                existing[k] = v.detach().to("cpu").contiguous()
             elif isinstance(v, list):
-                merged[k] = [
+                existing[k] = [
                     (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v
                 ]
             else:
-                merged[k] = v
-        setattr(req_state, "additional_information_cpu", merged)
+                existing[k] = v
+        # Backward compatible: mirror to old setattr location
+        setattr(req_state, "additional_information_cpu", existing)
+
+    def _merge_additional_information_update(self, req_id, upd):
+        logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer")
+        return self._update_intermediate_buffer(req_id, upd)
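
Taken together, the gpu_model_runner.py changes give model_intermediate_buffer a full lifecycle that was previously scattered across additional_information_cpu attributes. A condensed sketch of that lifecycle with stand-in objects (not the real scheduler or model types):

import torch

buffer: dict[str, dict] = {}            # self.model_intermediate_buffer
requests = {"r1": object()}             # self.requests
batch_order = ["r1"]                    # self.input_batch.req_ids

# 1. Request admission: a decoded additional_information payload seeds the buffer
#    (and is mirrored to additional_information_cpu for backward compatibility).
buffer["r1"] = {"global_request_id": ["g-42"]}

# 2. Each step: per-request dicts are gathered in batch order and handed to the
#    model under the new kwarg name plus the deprecated alias.
gathered = [buffer[rid] for rid in batch_order if buffer.get(rid)]
model_kwargs = {"model_intermediate_buffer": gathered,
                "runtime_additional_information": gathered}  # deprecated alias

# 3. Postprocess output is merged back per request (tensors detached to CPU).
buffer["r1"]["code_predictor_codes"] = torch.zeros(3, dtype=torch.long)

# 4. Request completion: the entry is evicted together with the request state.
requests.pop("r1", None)
buffer.pop("r1", None)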
