diff --git a/fastdeploy/entrypoints/chat_utils.py b/fastdeploy/entrypoints/chat_utils.py
index 08abc8ed391..c40aaab6f4b 100644
--- a/fastdeploy/entrypoints/chat_utils.py
+++ b/fastdeploy/entrypoints/chat_utils.py
@@ -199,11 +199,10 @@ def parse_chat_messages(messages: List[ChatCompletionMessageParam]):
role = message["role"]
content = message["content"]
- parsed_content = []
if content is None:
- parsed_content = []
+ parsed_content = content
elif isinstance(content, str):
- parsed_content = [{"type": "text", "text": content}]
+ parsed_content = content
else:
parsed_content = [parse_content_part(mm_parser, part) for part in content]
diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index 1d18421a55d..e646f667428 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -276,7 +276,7 @@ def chat(
raise RuntimeError(f"Failed to validate 'tools' parameter in chat method: {e}") from e
req_ids = self._add_request(
- prompts=[{"messages": msg} for msg in messages],
+ prompts=messages,
sampling_params=sampling_params,
chat_template_kwargs=chat_template_kwargs,
chat_template=chat_template,
@@ -326,11 +326,16 @@ def _add_request(
"prompt": prompts[i],
"request_id": request_id,
}
- elif isinstance(prompts[i], list) and isinstance(prompts[i][0], int):
+ elif isinstance(prompts[i], list) and len(prompts[i]) > 0 and isinstance(prompts[i][0], int):
tasks = {
"prompt_token_ids": prompts[i],
"request_id": request_id,
}
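+        # A non-empty list of dicts is treated as a chat message list.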
+ elif isinstance(prompts[i], list) and len(prompts[i]) > 0 and isinstance(prompts[i][0], dict):
+ tasks = {
+ "messages": prompts[i],
+ "request_id": request_id,
+ }
elif isinstance(prompts[i], dict):
tasks = prompts[i]
tasks["request_id"] = request_id
diff --git a/fastdeploy/input/multimodal/__init__.py b/fastdeploy/input/multimodal/__init__.py
new file mode 100644
index 00000000000..05aeda0e4bf
--- /dev/null
+++ b/fastdeploy/input/multimodal/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multimodal processors for FastDeploy."""
+
+from fastdeploy.input.multimodal.ernie4_5_vl import Ernie4_5VLProcessor
+from fastdeploy.input.multimodal.mm_processor import MMProcessor
+from fastdeploy.input.multimodal.paddleocr_vl import PaddleOCRVLProcessor
+from fastdeploy.input.multimodal.qwen3_vl import Qwen3VLProcessor
+from fastdeploy.input.multimodal.qwen_vl import QwenVLProcessor
+
+__all__ = [
+ "MMProcessor",
+ "QwenVLProcessor",
+ "Qwen3VLProcessor",
+ "Ernie4_5VLProcessor",
+ "PaddleOCRVLProcessor",
+]
diff --git a/fastdeploy/input/multimodal/common.py b/fastdeploy/input/multimodal/common.py
new file mode 100644
index 00000000000..891363b7428
--- /dev/null
+++ b/fastdeploy/input/multimodal/common.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared image utility functions for all VL image processors."""
+
+import math
+
+import numpy as np
+
+from fastdeploy.utils import data_processor_logger
+
+__all__ = [
+ "round_by_factor",
+ "ceil_by_factor",
+ "floor_by_factor",
+ "is_scaled_image",
+ "smart_resize",
+ "smart_resize_qwen",
+ "smart_resize_paddleocr",
+]
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer >= 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer <= 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
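+# Examples (illustrative): with factor=28, round_by_factor(100, 28) == 112,
+# ceil_by_factor(100, 28) == 112, and floor_by_factor(100, 28) == 84.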
+
+
+def is_scaled_image(image: np.ndarray) -> bool:
+ """Check if image pixel values are already normalized to [0, 1] range."""
+ if image.dtype == np.uint8:
+ return False
+ return np.min(image) >= 0 and np.max(image) <= 1
+
+
+def smart_resize_qwen(
+ height: int,
+ width: int,
+ factor: int,
+ min_pixels: int,
+ max_pixels: int,
+ max_ratio: int = 200,
+) -> tuple:
+ """Smart image resizing for ERNIE / Qwen2.5 / Qwen3 models."""
+ if max(height, width) / min(height, width) > max_ratio:
+ if height > width:
+ new_width = max(factor, round_by_factor(width, factor))
+ new_height = floor_by_factor(new_width * max_ratio, factor)
+ else:
+ new_height = max(factor, round_by_factor(height, factor))
+ new_width = floor_by_factor(new_height * max_ratio, factor)
+
+ data_processor_logger.info(
+ f"absolute aspect ratio must be smaller than {max_ratio}, "
+ f"got {max(height, width) / min(height, width)}, "
+ f"resize to {max(new_height, new_width) / min(new_height, new_width)}"
+ )
+ height = new_height
+ width = new_width
+
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = floor_by_factor(height / beta, factor)
+ w_bar = floor_by_factor(width / beta, factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+
+ if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
+ raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")
+
+ return h_bar, w_bar
+
+
+def smart_resize_paddleocr(
+ height: int,
+ width: int,
+ factor: int = 28,
+ min_pixels: int = 28 * 28 * 130,
+ max_pixels: int = 28 * 28 * 1280,
+) -> tuple:
+ """Smart image resizing for PaddleOCR-VL model."""
+ if height < factor:
+ data_processor_logger.debug(f"smart_resize_paddleocr: height={height} < factor={factor}, reset height=factor")
+ width = round((width * factor) / height)
+ height = factor
+
+ if width < factor:
+ data_processor_logger.debug(f"smart_resize_paddleocr: width={width} < factor={factor}, reset width=factor")
+ height = round((height * factor) / width)
+ width = factor
+
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 200, " f"got {max(height, width) / min(height, width)}"
+ )
+
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = math.floor(height / beta / factor) * factor
+ w_bar = math.floor(width / beta / factor) * factor
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+
+ return h_bar, w_bar
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int,
+ min_pixels: int,
+ max_pixels: int,
+ max_ratio: int = 200,
+ variant: str = "qwen",
+) -> tuple:
+ """Unified smart_resize dispatcher."""
+ if variant == "paddleocr":
+ return smart_resize_paddleocr(height, width, factor, min_pixels, max_pixels)
+ return smart_resize_qwen(height, width, factor, min_pixels, max_pixels, max_ratio)
diff --git a/fastdeploy/input/multimodal/ernie4_5_vl.py b/fastdeploy/input/multimodal/ernie4_5_vl.py
new file mode 100644
index 00000000000..33e9860df5c
--- /dev/null
+++ b/fastdeploy/input/multimodal/ernie4_5_vl.py
@@ -0,0 +1,466 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Ernie4_5VLProcessor — multimodal processor for ERNIE 4.5 VL."""
+
+import copy
+from collections import defaultdict
+
+import numpy as np
+import paddle
+from paddleformers.transformers.image_utils import ChannelDimension
+
+from fastdeploy.engine.request import ImagePosition
+from fastdeploy.input.multimodal.image_processors import AdaptiveImageProcessor
+from fastdeploy.input.multimodal.mm_processor import MMProcessor
+from fastdeploy.input.utils import IDS_TYPE_FLAG, MAX_IMAGE_DIMENSION
+
+
+class Ernie4_5VLProcessor(MMProcessor):
+ """Multimodal processor for ERNIE 4.5 VL.
+
+ Key differences from QwenVLProcessor:
+ - Uses AdaptiveImageProcessor
+ - Has boundary tokens (IMG_START/END, VID_START/END)
+ - Position IDs use a list-of-lists (Nx3) format
+ - Overrides _write_back to preserve the original prompt_token_ids on the pretokenized path
+ - Has get_mm_max_tokens_per_item for scheduler
+ """
+
+ # ---- Class-level declarations ----
+ image_placeholder = "<|image@placeholder|>"
+ video_placeholder = "<|video@placeholder|>"
+ image_token_str = "<|IMAGE_PLACEHOLDER|>"
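+ # Video frames reuse the image placeholder token (see preprocess_video).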
+ video_token_str = "<|IMAGE_PLACEHOLDER|>"
+ tokenizer_type = "ernie4_5"
+
+ # Boundary token constants
+ IMG_START = "<|IMAGE_START|>"
+ IMG_END = "<|IMAGE_END|>"
+ VID_START = "<|VIDEO_START|>"
+ VID_END = "<|VIDEO_END|>"
+
+ def _init_extra(self, processor_kwargs):
+ """Ernie-specific extra initialisation."""
+ processor_kwargs = processor_kwargs or {}
+
+ # Image processor
+ self.image_processor = AdaptiveImageProcessor.from_pretrained(self.model_name_or_path)
+
+ # Conv params from image_processor
+ self.spatial_conv_size = self.image_processor.merge_size
+ self.temporal_conv_size = self.image_processor.temporal_conv_size
+
+ # Special token IDs
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token_str)
+ self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token_str)
+
+ # Pixel bounds
+ self.image_min_pixels = processor_kwargs.get("image_min_pixels", 4 * 28 * 28)
+ self.image_max_pixels = processor_kwargs.get("image_max_pixels", 6177 * 28 * 28)
+ self.video_min_pixels = processor_kwargs.get("video_min_pixels", 299 * 28 * 28)
+ self.video_max_pixels = processor_kwargs.get("video_max_pixels", 1196 * 28 * 28)
+ self.frames_sample = processor_kwargs.get("video_frames_sample", self.default_frames_sample)
+
+ # Build token-type mapping for ernie boundary tokens
+ self.token_type_mapping = self._build_token_type_mapping()
+
+ def _init_role_prefixes(self) -> dict:
+ """Ernie supports a 'tool' role in addition to standard roles."""
+ return {
+ "system": "",
+ "user": "User: ",
+ "bot": "Assistant: ",
+ "assistant": "Assistant: ",
+ "tool": "Tool: ",
+ }
+
+ def _build_token_type_mapping(self):
+ mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"])
+ for token in (self.IMG_START, self.IMG_END, self.VID_START, self.VID_END):
+ mapping[token] = IDS_TYPE_FLAG["image"]
+ mapping[self.image_token_id] = IDS_TYPE_FLAG["image"]
+ return mapping
+
+ # ------------------------------------------------------------------
+ # Image processing
+ # ------------------------------------------------------------------
+
+ def preprocess_image(self, img, outputs, uuid, token_len=None):
+ patches_h, patches_w = self.image_processor.get_smarted_resize(
+ img.height,
+ img.width,
+ min_pixels=self.image_min_pixels,
+ max_pixels=self.image_max_pixels,
+ )[1]
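+ # Each spatial_conv_size x spatial_conv_size block of patches collapses
+ # into one placeholder token.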
+ num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
+ if token_len and token_len != num_tokens:
+ raise ValueError("image tokens num not match the size")
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+ outputs["num_input_image_tokens"] += num_tokens
+
+ pos_ids = self._compute_3d_positions(1, patches_h, patches_w, outputs["cur_position"])
+ outputs["position_ids"].extend(pos_ids)
+ outputs["cur_position"] = np.max(pos_ids) + 1
+
+ ret = self.image_processor.preprocess(
+ images=[img.convert("RGB")],
+ do_normalize=False,
+ do_rescale=False,
+ predetermined_grid_thw=np.array([[patches_h, patches_w]]),
+ do_convert_rgb=True,
+ input_data_format=ChannelDimension.LAST,
+ )
+ outputs["images"].append(ret["pixel_values"])
+ outputs["grid_thw"].append(ret["image_grid_thw"])
+ outputs["image_type_ids"].append(0)
+
+ def preprocess_cached_image(self, img_cache, outputs, uuid, token_len=None):
+ img, meta = img_cache
+ num_tokens = img.shape[0] // (self.spatial_conv_size**2)
+ if token_len and num_tokens != token_len:
+ raise ValueError("image tokens num not match the size")
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+ outputs["num_input_image_tokens"] += num_tokens
+
+ _, h, w = meta["thw"]
+ pos_ids = self._compute_3d_positions(1, h, w, outputs["cur_position"])
+ outputs["position_ids"].extend(pos_ids)
+ outputs["cur_position"] = np.max(pos_ids) + 1
+
+ outputs["images"].append(img)
+ outputs["grid_thw"].append(np.array([[1, h, w]]))
+ outputs["image_type_ids"].append(0)
+
+ # ------------------------------------------------------------------
+ # Video processing
+ # ------------------------------------------------------------------
+
+ def preprocess_video(self, frames, outputs, uuid, token_len=None, meta=None):
+ patches_h, patches_w = self.image_processor.get_smarted_resize(
+ frames[0].height,
+ frames[0].width,
+ min_pixels=self.video_min_pixels,
+ max_pixels=self.video_max_pixels,
+ )[1]
+ num_frames = len(frames)
+ num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
+ if token_len and num_tokens != token_len:
+ raise ValueError("video tokens num not match the size")
+
+ pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
+ ret = self.image_processor.preprocess(
+ images=None,
+ videos=pixel_stack,
+ do_normalize=False,
+ do_rescale=False,
+ predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames),
+ do_convert_rgb=True,
+ input_data_format=ChannelDimension.LAST,
+ )
+ outputs["images"].append(ret["pixel_values_videos"])
+ outputs["grid_thw"].append(ret["video_grid_thw"])
+ outputs["image_type_ids"].extend([1] * num_frames)
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+ outputs["num_input_video_tokens"] += num_tokens
+
+ pos_ids = self._compute_3d_positions(num_frames, patches_h, patches_w, outputs["cur_position"])
+ outputs["position_ids"].extend(pos_ids)
+ outputs["cur_position"] = np.max(pos_ids) + 1
+
+ def preprocess_cached_video(self, frames_cache, outputs, uuid, token_len=None):
+ frames, meta = frames_cache
+ num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
+ if token_len and num_tokens != token_len:
+ raise ValueError("video tokens num not match the size")
+
+ t, h, w = meta["thw"]
+ outputs["images"].append(frames)
+ outputs["grid_thw"].append(np.array([[t, h, w]]))
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+ outputs["num_input_video_tokens"] += num_tokens
+ outputs["image_type_ids"].extend([1] * t)
+
+ pos_ids = self._compute_3d_positions(t, h, w, outputs["cur_position"])
+ outputs["position_ids"].extend(pos_ids)
+ outputs["cur_position"] = np.max(pos_ids) + 1
+
+ def load_video(self, url, item):
+ from fastdeploy.input.utils.render_timestamp import render_frame_timestamp
+ from fastdeploy.input.utils.video import read_frames_decord, read_video_decord
+
+ reader, meta, path = read_video_decord(url, save_to_disk=False)
+
+ video_frame_args = {
+ "fps": item.get("fps", self.fps),
+ "min_frames": item.get("min_frames", self.min_frames),
+ "max_frames": item.get("max_frames", self.max_frames),
+ "target_frames": item.get("target_frames", self.target_frames),
+ "frames_sample": item.get("frames_sample", self.frames_sample),
+ }
+ video_frame_args = self._set_video_frame_args(video_frame_args, meta)
+
+ frames_data, _, timestamps = read_frames_decord(
+ path,
+ reader,
+ meta,
+ target_frames=video_frame_args["target_frames"],
+ target_fps=video_frame_args["fps"],
+ frames_sample=video_frame_args["frames_sample"],
+ save_to_disk=False,
+ )
+
+ frames = []
+ for img_array, ts in zip(frames_data, timestamps):
+ frames.append(render_frame_timestamp(img_array, ts))
+ # Ensure even number of frames for temporal conv
+ if len(frames) % 2 != 0:
+ frames.append(copy.deepcopy(frames[-1]))
+ return frames, {}
+
+ def _set_video_frame_args(self, video_frame_args, video_meta):
+ """Set final frame sampling args based on priorities."""
+ if video_frame_args["target_frames"] > 0:
+ if video_frame_args["fps"] >= 0:
+ raise ValueError("fps must be negative if target_frames is given")
+ if (
+ video_frame_args["min_frames"] > 0
+ and video_frame_args["target_frames"] < video_frame_args["min_frames"]
+ ):
+ raise ValueError("target_frames must be larger than min_frames")
+ if (
+ video_frame_args["max_frames"] > 0
+ and video_frame_args["target_frames"] > video_frame_args["max_frames"]
+ ):
+ raise ValueError("target_frames must be smaller than max_frames")
+ else:
+ if video_frame_args["fps"] < 0:
+ raise ValueError("Must provide either positive target_fps or positive target_frames.")
+ frames_to_extract = int(video_meta["duration"] * video_frame_args["fps"])
+ if (
+ video_frame_args["min_frames"] > 0
+ and video_frame_args["max_frames"] > 0
+ and video_frame_args["min_frames"] > video_frame_args["max_frames"]
+ ):
+ raise ValueError("min_frames must be smaller than max_frames")
+ if video_frame_args["min_frames"] > 0 and frames_to_extract < video_frame_args["min_frames"]:
+ video_frame_args["target_frames"] = video_frame_args["min_frames"]
+ video_frame_args["fps"] = -1
+ if video_frame_args["max_frames"] > 0 and frames_to_extract > video_frame_args["max_frames"]:
+ video_frame_args["target_frames"] = video_frame_args["max_frames"]
+ video_frame_args["fps"] = -1
+ return video_frame_args
+
+ # ------------------------------------------------------------------
+ # Position IDs
+ # ------------------------------------------------------------------
+
+ def add_text_positions(self, outputs, num_tokens):
+ """Write text position IDs in ernie [pos, pos, pos] format."""
+ start = outputs["cur_position"]
+ for i in range(num_tokens):
+ outputs["position_ids"].append([start + i] * 3)
+ outputs["cur_position"] += num_tokens
+
+ def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
+ num_tokens = len(completion_token_ids)
+ multimodal_inputs["input_ids"].extend(completion_token_ids)
+ multimodal_inputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
+
+ start = multimodal_inputs["cur_position"]
+ for i in range(num_tokens):
+ multimodal_inputs["position_ids"].append([start + i] * 3)
+ multimodal_inputs["cur_position"] += num_tokens
+
+ def pack_position_ids(self, outputs):
+ """Ernie: position_ids is np.array (list-of-lists -> ndarray)."""
+ outputs["position_ids"] = np.array(outputs["position_ids"], dtype=np.int64)
+ outputs["image_patch_id"] = self.image_token_id
+
+ def _compute_3d_positions(self, t, h, w, start_idx):
+ """Compute 3D position IDs as list-of-lists for ernie format."""
+ t_eff = t // self.temporal_conv_size if t != 1 else 1
+ gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size
+ time_idx = np.repeat(np.arange(t_eff), gh * gw)
+ h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff)
+ w_idx = np.tile(np.arange(gw), t_eff * gh)
+
+ coords = list(zip(time_idx, h_idx, w_idx))
+ return [[start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords]
+
+ # ------------------------------------------------------------------
+ # Token counting
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def mm_num_tokens(grid_thw):
+ """Ernie mm_num_tokens: video (t>1) divides by an extra 2."""
+ if isinstance(grid_thw, paddle.Tensor):
+ grid_thw = grid_thw.numpy()
+ if len(grid_thw) == 0:
+ return 0
+
+ def calc_one(thw):
+ t, h, w = map(int, thw)
+ if t == 1:
+ return t * h * w // 4
+ else:
+ return t * h * w // 4 // 2
+
+ if isinstance(grid_thw[0], (list, tuple, np.ndarray)):
+ return [calc_one(x) for x in grid_thw]
+ return calc_one(grid_thw)
+
+ # ------------------------------------------------------------------
+ # Prompt token IDs path
+ # ------------------------------------------------------------------
+
+ def prompt_token_ids2outputs(self, mm_context):
+ outputs = self._make_outputs()
+ prompt_token_ids = mm_context.prompt_token_ids
+ prompt_token_ids_len = len(prompt_token_ids)
+
+ if not mm_context.images and not mm_context.videos:
+ outputs["input_ids"].extend(prompt_token_ids)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
+ for i in range(prompt_token_ids_len):
+ outputs["position_ids"].append([i] * 3)
+ outputs["cur_position"] += prompt_token_ids_len
+ return outputs
+
+ images = mm_context.images
+ videos = mm_context.videos
+
+ image_start_id = self.tokenizer.convert_tokens_to_ids(self.IMG_START)
+ image_end_id = self.tokenizer.convert_tokens_to_ids(self.IMG_END)
+ video_start_id = self.tokenizer.convert_tokens_to_ids(self.VID_START)
+ video_end_id = self.tokenizer.convert_tokens_to_ids(self.VID_END)
+
+ st, image_idx, video_idx = 0, 0, 0
+ while st < prompt_token_ids_len:
+ cur_token_id = prompt_token_ids[st]
+ if cur_token_id == image_start_id:
+ if image_idx >= len(images):
+ raise ValueError("prompt token ids has more image placeholder than in messages")
+ # append image_start_id
+ outputs["input_ids"].extend([cur_token_id])
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]])
+ outputs["position_ids"].append([outputs["cur_position"]] * 3)
+ outputs["cur_position"] += 1
+ st += 1
+ # process placeholder token ids
+ cur_idx = st
+ while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != image_end_id:
+ cur_idx += 1
+ if cur_idx >= prompt_token_ids_len:
+ raise ValueError("image token ids not complete")
+ item = images[image_idx]
+ token_len = cur_idx - st
+ if not isinstance(item.data, tuple):
+ self.preprocess_image(item.data, outputs, item.uuid, token_len)
+ else:
+ self.preprocess_cached_image(item.data, outputs, item.uuid, token_len)
+ image_idx += 1
+ st = cur_idx
+ elif cur_token_id == video_start_id:
+ if video_idx >= len(videos):
+ raise ValueError("prompt token ids has more video placeholder than in messages")
+ # append video_start_id
+ outputs["input_ids"].extend([cur_token_id])
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]])
+ outputs["position_ids"].append([outputs["cur_position"]] * 3)
+ outputs["cur_position"] += 1
+ st += 1
+ # process placeholder token ids
+ cur_idx = st
+ while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != video_end_id:
+ cur_idx += 1
+ if cur_idx >= prompt_token_ids_len:
+ raise ValueError("video token ids not complete")
+ item = videos[video_idx]
+ token_len = cur_idx - st
+ if not isinstance(item.data, tuple):
+ if isinstance(item.data, dict):
+ frames, _ = self.load_video(item.data["video"], item.data)
+ else:
+ frames, _ = self.load_video(item.data, {})
+ self.preprocess_video(frames, outputs, item.uuid, token_len=token_len)
+ else:
+ self.preprocess_cached_video(item.data, outputs, item.uuid, token_len)
+ video_idx += 1
+ st = cur_idx
+ else:
+ outputs["input_ids"].extend([cur_token_id])
+ type_flag = (
+ IDS_TYPE_FLAG["image"] if cur_token_id in (image_end_id, video_end_id) else IDS_TYPE_FLAG["text"]
+ )
+ outputs["token_type_ids"].extend([type_flag])
+ outputs["position_ids"].append([outputs["cur_position"]] * 3)
+ outputs["cur_position"] += 1
+ st += 1
+
+ if image_idx != len(images):
+ raise ValueError("number of image placeholders does not match the number of images")
+ if video_idx != len(videos):
+ raise ValueError("number of video placeholders does not match the number of videos")
+
+ return outputs
+
+ # ------------------------------------------------------------------
+ # Scheduler helper
+ # ------------------------------------------------------------------
+
+ def get_mm_max_tokens_per_item(self, seq_len):
+ """Per-modality max token counts for ernie."""
+ target_height, target_width = self._get_image_size_with_most_features()
+ # image
+ patches_h, patches_w = self.image_processor.get_smarted_resize(
+ height=target_height,
+ width=target_width,
+ min_pixels=self.image_min_pixels,
+ max_pixels=self.image_max_pixels,
+ )[1]
+ max_image_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
+ max_image_tokens = min(max_image_tokens, seq_len)
+ # video
+ patches_h, patches_w = self.image_processor.get_smarted_resize(
+ height=target_height,
+ width=target_width,
+ min_pixels=self.video_min_pixels,
+ max_pixels=self.video_max_pixels,
+ )[1]
+ max_video_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
+ max_video_tokens = min(max_video_tokens, seq_len)
+ return {"image": max_image_tokens, "video": max_video_tokens}
+
+ def _get_image_size_with_most_features(self):
+ resized_height, resized_width = self.image_processor.get_smarted_resize(
+ height=MAX_IMAGE_DIMENSION,
+ width=MAX_IMAGE_DIMENSION,
+ min_pixels=self.image_min_pixels,
+ max_pixels=self.image_max_pixels,
+ )[0]
+ return (resized_height, resized_width)
diff --git a/fastdeploy/input/multimodal/image_processors/__init__.py b/fastdeploy/input/multimodal/image_processors/__init__.py
new file mode 100644
index 00000000000..59573ea4856
--- /dev/null
+++ b/fastdeploy/input/multimodal/image_processors/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ernie import AdaptiveImageProcessor
+from .paddleocr import PaddleOCRImageProcessor
+from .qwen import QwenImageProcessor
+from .qwen3 import Qwen3ImageProcessor
+
+__all__ = [
+ "QwenImageProcessor",
+ "Qwen3ImageProcessor",
+ "AdaptiveImageProcessor",
+ "PaddleOCRImageProcessor",
+]
diff --git a/fastdeploy/input/multimodal/image_processors/ernie.py b/fastdeploy/input/multimodal/image_processors/ernie.py
new file mode 100644
index 00000000000..d29f07547ea
--- /dev/null
+++ b/fastdeploy/input/multimodal/image_processors/ernie.py
@@ -0,0 +1,379 @@
+"""
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+"""AdaptiveImageProcessor for ERNIE 4.5 VL model."""
+
+from typing import List, Optional, Union
+
+import numpy as np
+import paddle
+import PIL
+from paddleformers.transformers.feature_extraction_utils import BatchFeature
+from paddleformers.transformers.image_processing_utils import BaseImageProcessor
+from paddleformers.transformers.image_transforms import (
+ convert_to_rgb,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from paddleformers.transformers.image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_valid_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+)
+from paddleformers.transformers.legacy.tokenizer_utils_base import TensorType
+from PIL import Image
+
+from fastdeploy.input.multimodal.common import is_scaled_image
+from fastdeploy.input.multimodal.common import smart_resize_qwen as smart_resize
+from fastdeploy.utils import data_processor_logger
+
+OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+
+VideoInput = Union[
+ List["PIL.Image.Image"],
+ "np.ndarray",
+ "paddle.Tensor",
+ List["np.ndarray"],
+ List["paddle.Tensor"],
+ List[List["PIL.Image.Image"]],
+ List[List["np.ndarray"]],
+ List[List["paddle.Tensor"]],
+]
+
+
+__all__ = [
+ "AdaptiveImageProcessor",
+ "get_image_preprocessor",
+ "make_batched_images",
+ "make_batched_videos",
+]
+
+
+def make_batched_images(images) -> List[List[ImageInput]]:
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
+ return [img for img_list in images for img in img_list]
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
+ return images
+ elif is_valid_image(images):
+ return [images]
+ raise ValueError(f"Could not make batched images from {images}")
+
+
+def make_batched_videos(videos) -> List[VideoInput]:
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+ return videos
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
+ if isinstance(videos[0], Image.Image):
+ return [videos]
+ elif len(videos[0].shape) == 4:
+ return [list(video) for video in videos]
+ elif is_valid_image(videos) and len(videos.shape) == 4:
+ return [list(videos)]
+ raise ValueError(f"Could not make batched video from {videos}")
+
+
+class AdaptiveImageProcessor(BaseImageProcessor):
+ """Adaptive image processor for ERNIE 4.5 VL."""
+
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int = 56 * 56,
+ max_pixels: int = 28 * 28 * 1280,
+ patch_size: int = 14,
+ temporal_conv_size: int = 2,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+ self.patch_size = patch_size
+ self.temporal_conv_size = temporal_conv_size
+ self.merge_size = merge_size
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
+ self.do_convert_rgb = do_convert_rgb
+
+ def set_pixels(self, min_pixels=None, max_pixels=None, msg=""):
+ if min_pixels is not None:
+ assert isinstance(min_pixels, int) and min_pixels >= 0, "min_pixels must be a non-negative int"
+ data_processor_logger.info(f"{msg} AdaptiveImageProcessor set min_pixels = {min_pixels}")
+ self.min_pixels = min_pixels
+ self.size["min_pixels"] = int(min_pixels)
+ if max_pixels is not None:
+ assert isinstance(max_pixels, int) and max_pixels > 0, "max_pixels must be a positive int"
+ data_processor_logger.info(f"{msg} AdaptiveImageProcessor set max_pixels = {max_pixels}")
+ self.max_pixels = max_pixels
+ self.size["max_pixels"] = int(max_pixels)
+
+ def get_smarted_resize(self, height, width, min_pixels=None, max_pixels=None):
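+ # Returns ((resized_height, resized_width), (grid_h, grid_w)), with the grid
+ # measured in patch_size units. Example (illustrative, constructor defaults):
+ # a 224x224 image with patch_size=14 and merge_size=2 stays at (224, 224)
+ # and maps to a 16x16 patch grid.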
+ actual_min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ actual_max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=self.patch_size * self.merge_size,
+ min_pixels=actual_min_pixels,
+ max_pixels=actual_max_pixels,
+ )
+ return (resized_height, resized_width), (
+ resized_height // self.patch_size,
+ resized_width // self.patch_size,
+ )
+
+ def _preprocess(
+ self,
+ images: Union[ImageInput, VideoInput],
+ do_resize: bool = True,
+ resample: PILImageResampling = None,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = False,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ predetermined_grid_thw=None,
+ ):
+ images = make_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ data_processor_logger.warning(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
+ resized_height, resized_width = height, width
+ processed_images = []
+
+ if predetermined_grid_thw is not None:
+ assert len(predetermined_grid_thw) == len(
+ images
+ ), f"len(predetermined_grid_thw) ({len(predetermined_grid_thw)}) must equal len(images) ({len(images)})"
+
+ for img_idx, image in enumerate(images):
+ if do_resize:
+ if predetermined_grid_thw is not None:
+ (resized_height, resized_width) = predetermined_grid_thw[img_idx]
+ resized_height *= self.patch_size
+ resized_width *= self.patch_size
+ else:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=self.patch_size * self.merge_size,
+ min_pixels=self.min_pixels,
+ max_pixels=self.max_pixels,
+ )
+ image = image.astype("uint8")
+ image = Image.fromarray(image)
+ image = resize(
+ image,
+ size=(resized_height, resized_width),
+ resample=resample,
+ data_format=input_data_format,
+ )
+ if do_rescale:
+ image = rescale(image, scale=rescale_factor, data_format=input_data_format)
+
+ if do_normalize:
+ image = normalize(
+ image=image,
+ mean=image_mean,
+ std=image_std,
+ data_format=input_data_format,
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ if data_format == ChannelDimension.LAST:
+ patches = patches.transpose([0, 3, 1, 2])
+
+ channel = patches.shape[1]
+ grid_t = patches.shape[0]
+ grid_h, grid_w = (
+ resized_height // self.patch_size,
+ resized_width // self.patch_size,
+ )
+ patches = patches.reshape(
+ [
+ grid_t,
+ channel,
+ grid_h // self.merge_size,
+ self.merge_size,
+ self.patch_size,
+ grid_w // self.merge_size,
+ self.merge_size,
+ self.patch_size,
+ ]
+ )
+ patches = patches.transpose([0, 2, 5, 3, 6, 1, 4, 7])
+
+ flatten_patches = patches.reshape(
+ [
+ grid_t * grid_h * grid_w,
+ channel * self.patch_size * self.patch_size,
+ ]
+ )
+
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ videos: VideoInput = None,
+ do_resize: bool = True,
+ size: Optional[Union[int, List[int]]] = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = False,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ predetermined_grid_thw=None,
+ ):
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ if images is not None:
+ images = make_batched_images(images)
+ if videos is not None:
+ videos = make_batched_videos(videos)
+
+ if images is not None and not valid_images(images):
+ raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "paddle.Tensor.")
+
+ data = {}
+
+ if images is not None:
+ pixel_values, vision_grid_thws = [], []
+ for img_idx, image in enumerate(images):
+ if predetermined_grid_thw is not None:
+ predetermined_grid_thw_one = [predetermined_grid_thw[img_idx]]
+ else:
+ predetermined_grid_thw_one = None
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ predetermined_grid_thw=predetermined_grid_thw_one,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data["pixel_values"] = pixel_values
+ data["image_grid_thw"] = vision_grid_thws
+
+ if videos is not None:
+ pixel_values, vision_grid_thws = [], []
+ for images in videos:
+ patches, video_grid_thw = self._preprocess(
+ images,
+ do_resize=do_resize,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ predetermined_grid_thw=predetermined_grid_thw,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(video_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data["pixel_values_videos"] = pixel_values
+ data["video_grid_thw"] = vision_grid_thws
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+def get_image_preprocessor(args):
+ if args.vision_model_name_or_path is None:
+ return None
+ data_processor_logger.info("use AdaptiveImageProcessor")
+ image_preprocess = AdaptiveImageProcessor.from_pretrained(args.vision_model_name_or_path)
+ return image_preprocess
diff --git a/fastdeploy/input/multimodal/image_processors/paddleocr.py b/fastdeploy/input/multimodal/image_processors/paddleocr.py
new file mode 100644
index 00000000000..39f222fe87d
--- /dev/null
+++ b/fastdeploy/input/multimodal/image_processors/paddleocr.py
@@ -0,0 +1,214 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+"""PaddleOCRImageProcessor for PaddleOCR-VL model."""
+
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+from paddleformers.transformers.feature_extraction_utils import BatchFeature
+from paddleformers.transformers.image_processing_utils import BaseImageProcessor
+from paddleformers.transformers.image_utils import (
+ ImageInput,
+ is_valid_image,
+ make_list_of_images,
+ to_numpy_array,
+)
+
+from fastdeploy.input.multimodal.common import smart_resize_paddleocr as smart_resize
+
+_OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+_OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+
+
+def make_batched_images(images) -> List[ImageInput]:
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
+ return [img for img_list in images for img in img_list]
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
+ return images
+ elif is_valid_image(images):
+ return [images]
+ raise ValueError(f"Could not make batched images from {images}")
+
+
+def adjust_size(size, patch_size):
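+ # Round size down to an even number of patch_size units.
+ # Example (illustrative): adjust_size(70, 14) trims 5 patches to 4 -> 56.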
+ num_patches = size // patch_size
+ if num_patches % 2 != 0:
+ num_patches -= 1
+ return num_patches * patch_size
+
+
+class PaddleOCRImageProcessor(BaseImageProcessor):
+ """Image processor for PaddleOCR-VL. temporal_patch_size=1, 4D output."""
+
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ resample: int = 3,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int = 28 * 28 * 130,
+ max_pixels: int = 28 * 28 * 1280,
+ patch_size: int = 14,
+ temporal_patch_size: int = 1,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else _OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else _OPENAI_CLIP_STD
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
+ self.do_convert_rgb = do_convert_rgb
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_dir):
+ pretrained_model_dir = Path(pretrained_model_dir)
+ image_processor_config_path = pretrained_model_dir / "preprocessor_config.json"
+ with open(image_processor_config_path, "r", encoding="utf-8") as f:
+ image_processor_config = json.load(f)
+ return cls(**image_processor_config)
+
+ def _preprocess(
+ self,
+ images,
+ do_resize: Optional[bool] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ ):
+ images = make_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [image.convert("RGB") for image in images]
+
+ width, height = images[0].size
+ resized_height, resized_width = height, width
+ processed_images = []
+
+ for image in images:
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=self.patch_size * self.merge_size,
+ min_pixels=self.min_pixels,
+ max_pixels=self.max_pixels,
+ )
+
+ image = image.resize((resized_width, resized_height), resample=self.resample)
+
+ image = to_numpy_array(image)
+
+ if do_rescale:
+ image = (image * rescale_factor).astype(np.float32)
+
+ if do_normalize:
+ image = image.astype(np.float32)
+ image -= np.array(image_mean, dtype=np.float32)
+ image /= np.array(image_std, dtype=np.float32)
+
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ patches = patches.transpose(0, 3, 1, 2)
+ if patches.shape[0] == 1:
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // self.temporal_patch_size
+ grid_h, grid_w = (
+ resized_height // self.patch_size,
+ resized_width // self.patch_size,
+ )
+
+ patches = patches.reshape(
+ grid_t,
+ self.temporal_patch_size,
+ channel,
+ grid_h,
+ self.patch_size,
+ grid_w,
+ self.patch_size,
+ )
+ patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
+ assert self.temporal_patch_size == 1
+ flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size)
+ return flatten_patches, np.array([grid_t, grid_h, grid_w])
+
+ def preprocess(
+ self,
+ images,
+ videos=None,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors=None,
+ ):
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ if videos is not None:
+ raise NotImplementedError("Videos are not yet supported")
+
+ patches, image_grid_thw = self._preprocess(
+ images,
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_convert_rgb=do_convert_rgb,
+ )
+ pixel_values = np.array(patches)
+ data = {"pixel_values": pixel_values, "grid_thw": image_grid_thw}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/fastdeploy/input/multimodal/image_processors/qwen.py b/fastdeploy/input/multimodal/image_processors/qwen.py
new file mode 100644
index 00000000000..ab7da5615bc
--- /dev/null
+++ b/fastdeploy/input/multimodal/image_processors/qwen.py
@@ -0,0 +1,253 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+"""QwenImageProcessor for Qwen2.5-VL model."""
+
+from typing import List, Optional, Union
+
+import numpy as np
+import paddle
+import PIL
+from paddleformers.transformers.feature_extraction_utils import BatchFeature
+from paddleformers.transformers.image_processing_utils import BaseImageProcessor
+from paddleformers.transformers.image_transforms import (
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from paddleformers.transformers.image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+)
+from paddleformers.transformers.legacy.tokenizer_utils_base import TensorType
+from PIL import Image
+
+from fastdeploy.input.multimodal.common import is_scaled_image, smart_resize
+from fastdeploy.utils import data_processor_logger
+
+OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+
+
+VideoInput = Union[
+ List["PIL.Image.Image"],
+ "np.ndarray",
+ "paddle.Tensor",
+ List["np.ndarray"],
+ List["paddle.Tensor"],
+ List[List["PIL.Image.Image"]],
+ List[List["np.ndarray"]],
+ List[List["paddle.Tensor"]],
+]
+
+
+class QwenImageProcessor(BaseImageProcessor):
+ """Image processor for Qwen2.5-VL. patch_size=14, CLIP mean/std."""
+
+ def __init__(
+ self,
+ patch_size: int = 14,
+ merge_size: int = 2,
+ temporal_patch_size: int = 2,
+ min_pixels: int = MIN_PIXELS,
+ max_pixels: int = MAX_PIXELS,
+ image_mean: Union[float, List[float]] = OPENAI_CLIP_MEAN,
+ image_std: Union[float, List[float]] = OPENAI_CLIP_STD,
+ rescale_factor: float = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.patch_size = patch_size
+ self.merge_size = merge_size
+ self.temporal_patch_size = temporal_patch_size
+
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.rescale_factor = rescale_factor
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+
+ self.resample = resample
+
+ def _preprocess(
+ self,
+ images: Union[ImageInput, VideoInput],
+ min_pixels: int,
+ max_pixels: int,
+ image_mean: Optional[Union[float, List[float]]],
+ image_std: Optional[Union[float, List[float]]],
+ rescale_factor: float,
+ do_rescale: bool,
+ do_normalize: bool,
+ resample: PILImageResampling,
+ data_format: Optional[ChannelDimension],
+ input_data_format: Optional[Union[str, ChannelDimension]],
+ ):
+ images = make_list_of_images(images)
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ data_processor_logger.warning(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=self.patch_size * self.merge_size,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+
+ processed_images = []
+ for image in images:
+ if height != resized_height or width != resized_width:
+ image = image.astype("uint8")
+ image = Image.fromarray(image)
+ image = resize(
+ image,
+ size=(resized_height, resized_width),
+ resample=resample,
+ data_format=input_data_format,
+ )
+
+ if do_rescale and do_normalize:
+ image_mean = np.array(image_mean, dtype=np.float32) * (1.0 / rescale_factor)
+ image_std = np.array(image_std, dtype=np.float32) * (1.0 / rescale_factor)
+ do_rescale = False
+
+ if do_rescale:
+ image = image.astype(np.float32)
+ image = rescale(image, scale=rescale_factor, data_format=input_data_format)
+
+ if do_normalize:
+ image = image.astype(np.float32)
+ image = normalize(
+ image=image,
+ mean=image_mean,
+ std=image_std,
+ data_format=input_data_format,
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+
+ if patches.shape[0] % self.temporal_patch_size != 0:
+ repeats = np.repeat(
+ patches[-1][np.newaxis],
+ self.temporal_patch_size - (patches.shape[0] % self.temporal_patch_size),
+ axis=0,
+ )
+ patches = np.concatenate([patches, repeats], axis=0)
+
+ if data_format == ChannelDimension.LAST:
+ patches = patches.transpose([0, 3, 1, 2])
+
+ grid_t, channel = patches.shape[:2]
+ grid_t = grid_t // self.temporal_patch_size
+
+ grid_h, grid_w = (
+ resized_height // self.patch_size,
+ resized_width // self.patch_size,
+ )
+ patches = patches.reshape(
+ [
+ grid_t,
+ self.temporal_patch_size,
+ channel,
+ grid_h // self.merge_size,
+ self.merge_size,
+ self.patch_size,
+ grid_w // self.merge_size,
+ self.merge_size,
+ self.patch_size,
+ ]
+ )
+ patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8])
+
+ flatten_patches = patches.reshape(
+ [
+ grid_t * grid_h * grid_w,
+ channel * self.temporal_patch_size * self.patch_size * self.patch_size,
+ ]
+ )
+
+ return flatten_patches, np.array([grid_t, grid_h, grid_w])
+
+ def preprocess(
+ self,
+ images: Union[ImageInput, VideoInput],
+ min_pixels: Optional[int] = None,
+ max_pixels: Optional[int] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ rescale_factor: Optional[float] = None,
+ do_rescale: Optional[bool] = None,
+ do_normalize: Optional[bool] = None,
+ resample: Optional[PILImageResampling] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
+ ):
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+
+ if images is not None and not valid_images(images):
+ raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "paddle.Tensor.")
+
+ pixel_values, grid_thw = self._preprocess(
+ images,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ image_mean=image_mean,
+ image_std=image_std,
+ rescale_factor=rescale_factor,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ data = {"pixel_values": pixel_values, "grid_thw": grid_thw}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/fastdeploy/input/multimodal/image_processors/qwen3.py b/fastdeploy/input/multimodal/image_processors/qwen3.py
new file mode 100644
index 00000000000..00af9d6ecc3
--- /dev/null
+++ b/fastdeploy/input/multimodal/image_processors/qwen3.py
@@ -0,0 +1,63 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+"""Qwen3ImageProcessor — inherits QwenImageProcessor with different defaults."""
+
+from typing import List, Union
+
+from paddleformers.transformers.image_utils import PILImageResampling
+
+from .qwen import QwenImageProcessor
+
+IMAGE_MEAN = [0.5, 0.5, 0.5]
+IMAGE_STD = [0.5, 0.5, 0.5]
+
+MIN_PIXELS = 65536
+MAX_PIXELS = 16777216
+
+
+class Qwen3ImageProcessor(QwenImageProcessor):
+ """Image processor for Qwen3-VL. patch_size=16, mean/std=[0.5,0.5,0.5]."""
+
+ def __init__(
+ self,
+ patch_size: int = 16,
+ merge_size: int = 2,
+ temporal_patch_size: int = 2,
+ min_pixels: int = MIN_PIXELS,
+ max_pixels: int = MAX_PIXELS,
+ image_mean: Union[float, List[float]] = IMAGE_MEAN,
+ image_std: Union[float, List[float]] = IMAGE_STD,
+ rescale_factor: float = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ patch_size=patch_size,
+ merge_size=merge_size,
+ temporal_patch_size=temporal_patch_size,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ image_mean=image_mean,
+ image_std=image_std,
+ rescale_factor=rescale_factor,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ resample=resample,
+ **kwargs,
+ )
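+
+
+# Usage sketch (checkpoint path hypothetical; from_pretrained is inherited):
+#     processor = Qwen3ImageProcessor.from_pretrained("Qwen/Qwen3-VL")
+#     out = processor.preprocess(images=[img])    # img: a PIL.Image, assumed
+#     out["pixel_values"], out["grid_thw"]        # flattened patches, (t, h, w)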
diff --git a/fastdeploy/input/multimodal/mm_processor.py b/fastdeploy/input/multimodal/mm_processor.py
new file mode 100644
index 00000000000..e65c79482b7
--- /dev/null
+++ b/fastdeploy/input/multimodal/mm_processor.py
@@ -0,0 +1,565 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MMProcessor abstract base class for multimodal processing.
+
+Only one public method: process(request).
+Responsible for converting prompt + multimodal_data into token IDs and
+pixel features, writing them back into the request dict.
+"""
+
+import pickle
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import zmq
+
+from fastdeploy.input.utils import IDS_TYPE_FLAG
+from fastdeploy.multimodal.hasher import MultimodalHasher
+from fastdeploy.utils import data_processor_logger
+
+_DEFAULT_MM_LIMITS = {"image": 1, "video": 1, "audio": 1}
+
+
+# ------------------------------------------------------------------
+# Data classes for structured multimodal context
+# ------------------------------------------------------------------
+
+
+class TokenizationPath(Enum):
+ """Processing path for multimodal requests."""
+
+ PRETOKENIZED = "pretokenized" # request already has prompt_token_ids
+ FROM_TEXT = "from_text" # prompt text + multimodal_data -> tokenize
+
+
+@dataclass
+class MMItem:
+ """A normalized multimodal element (image or video)."""
+
+ type: str # "image" | "video"
+ data: Any = None # raw data (PIL Image, frames, etc.) or None if pending cache fetch
+ uuid: Optional[str] = None
+
+
+@dataclass
+class MMContext:
+ """Normalized multimodal context passed between process() steps."""
+
+ images: List[MMItem] = field(default_factory=list)
+ videos: List[MMItem] = field(default_factory=list)
+ mm_order: List[str] = field(default_factory=list) # interleaved type order: ["image", "video", ...]
+ path: TokenizationPath = TokenizationPath.FROM_TEXT
+ prompt_token_ids: Optional[List[int]] = None # used by PRETOKENIZED path
+
+
+# ------------------------------------------------------------------
+# Cache client (centralized ZMQ connection management)
+# ------------------------------------------------------------------
+
+
+class _CacheClient:
+ """Lazy-initialized ZMQ DEALER client for processor cache."""
+
+ _IPC_ADDR = "ipc:///dev/shm/processor_cache.ipc"
+
+ def __init__(self):
+ self._socket = None
+
+ @property
+ def socket(self):
+ if self._socket is None:
+ ctx = zmq.Context()
+ self._socket = ctx.socket(zmq.DEALER)
+ self._socket.connect(self._IPC_ADDR)
+ return self._socket
+
+ def get(self, hashes: list) -> list:
+ """Retrieve cached multimodal data by hash list."""
+ req = pickle.dumps(hashes)
+ self.socket.send_multipart([b"", req])
+ _, resp = self.socket.recv_multipart()
+ items = pickle.loads(resp)
+ data_processor_logger.info(f"Get cache of mm_hashes: {hashes}")
+ return items
+
+ def put(self, hashes: list, items: list) -> None:
+ """Write processed multimodal items to cache."""
+ req = pickle.dumps((hashes, items))
+ self.socket.send_multipart([b"", req])
+ data_processor_logger.info(f"Update cache of mm_hashes: {hashes}")
+
+
+# ------------------------------------------------------------------
+# MMProcessor abstract base class
+# ------------------------------------------------------------------
+
+
+class MMProcessor(ABC):
+ """Abstract base class for multimodal processors.
+
+ Only public method: process(request) -> None
+ Uses a template method pattern: base class provides the orchestration
+ flow, subclasses implement hooks for model-specific logic.
+ """
+
+ # ---- Subclass must declare ----
+ image_placeholder: str = ""
+ video_placeholder: str = ""
+ image_token_str: str = ""
+ video_token_str: str = ""
+ tokenizer_type: str = "auto"
+
+ # ---- Video defaults (subclass can override) ----
+ default_min_frames: int = 4
+ default_max_frames: int = 768
+ default_target_frames: int = -1
+ default_fps: float = 2.0
+ default_frames_sample: str = "leading"
+
+ # ---- processor_kwargs type validation whitelist ----
+ expected_kwargs: Dict[str, type] = {}
+
+ def __init__(
+ self,
+ tokenizer,
+ model_name_or_path: str,
+ config=None,
+ processor_kwargs: Optional[dict] = None,
+ limit_mm_per_prompt: Optional[dict] = None,
+ enable_processor_cache: bool = False,
+ ):
+ self.tokenizer = tokenizer
+ self.model_name_or_path = model_name_or_path
+ self.config = config
+ self.enable_processor_cache = enable_processor_cache
+ self._cache = _CacheClient() if enable_processor_cache else None
+
+ kw = processor_kwargs or {}
+ self.fps = kw.get("video_fps", self.default_fps)
+ self.min_frames = kw.get("video_min_frames", self.default_min_frames)
+ self.max_frames = kw.get("video_max_frames", self.default_max_frames)
+ self.target_frames = kw.get("video_target_frames", self.default_target_frames)
+
+ self.role_prefixes = self._init_role_prefixes()
+ self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
+
+ # Subclass extra init hook
+ self._init_extra(processor_kwargs)
+
+ # ------------------------------------------------------------------
+ # Public interface (only method exposed to Processor)
+ # ------------------------------------------------------------------
+
+ def process(self, request: dict) -> None:
+ """Multimodal data processing (template method).
+
+ Reads from request:
+ request["prompt"] or request["prompt_token_ids"]
+ request["multimodal_data"]
+ request["messages"] (for prompt_token_ids path with media items)
+
+ Writes into request:
+ request["prompt_token_ids"]
+ request["multimodal_inputs"]
+ """
+ # Step 1: Resolve and normalize multimodal data
+ mm_context = self._resolve_mm_data(request)
+ # Step 2: Fetch missing data from cache (if enabled)
+ self._fetch_from_cache(mm_context)
+ # Step 3: Core tokenization + preprocessing
+ outputs = self._tokenize_and_preprocess(request, mm_context)
+ # Step 4: Append completion tokens (speculative decoding)
+ self._process_post_tokens(request, outputs)
+ # Step 5: Compute mm_hashes and update processor cache (before packing)
+ self._update_cache(mm_context, outputs)
+ # Step 6: Pack to numpy
+ outputs = self._pack_outputs(outputs)
+ # Step 7: Write back (subclass can override)
+ self._write_back(request, outputs)
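+
+    # Call-site sketch for a qwen-family subclass (names and placeholder token
+    # are illustrative, not part of this class):
+    #     request = {
+    #         "request_id": "req-0",
+    #         "prompt": "Describe this image: <|image_pad|>",
+    #         "multimodal_data": {"image": [pil_image]},  # pil_image: PIL.Image
+    #     }
+    #     mm_processor.process(request)   # mutates request in place
+    #     request["prompt_token_ids"]     # text + per-patch placeholder tokens
+    #     request["multimodal_inputs"]    # packed numpy features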
+
+ # ------------------------------------------------------------------
+ # Step 1: Resolve multimodal data
+ # ------------------------------------------------------------------
+
+ def _resolve_mm_data(self, request: dict) -> MMContext:
+ """Parse request and build a normalized MMContext.
+
+ Multimodal data is read from request["multimodal_data"] (populated by
+ Processor.process_messages when messages are present).
+ Path is determined by whether prompt_token_ids exists.
+ """
+ if not request.get("prompt_token_ids") and not request.get("prompt"):
+ raise ValueError("Request must contain 'prompt_token_ids', 'prompt', or 'messages'")
+
+ mm_data = request.get("multimodal_data") or {}
+ raw_images = mm_data.get("image", [])
+ raw_videos = mm_data.get("video", [])
+ self._check_mm_limits({"image": raw_images, "video": raw_videos})
+
+ images = []
+ for img in raw_images:
+ if isinstance(img, dict):
+ images.append(MMItem(type="image", data=img.get("data"), uuid=img.get("uuid")))
+ else:
+ images.append(MMItem(type="image", data=img, uuid=None))
+
+ videos = []
+ for vid in raw_videos:
+ if isinstance(vid, dict):
+ videos.append(MMItem(type="video", data=vid.get("data"), uuid=vid.get("uuid")))
+ else:
+ videos.append(MMItem(type="video", data=vid, uuid=None))
+
+ # Interleaved type order: directly from mm_data, or default images-then-videos.
+ mm_order = mm_data.get("mm_order")
+ if not mm_order:
+ mm_order = ["image"] * len(images) + ["video"] * len(videos)
+
+ if request.get("prompt_token_ids"):
+ return MMContext(
+ images=images,
+ videos=videos,
+ mm_order=mm_order,
+ path=TokenizationPath.PRETOKENIZED,
+ prompt_token_ids=request["prompt_token_ids"],
+ )
+
+ if not request.get("prompt"):
+ raise ValueError("Request must contain 'prompt_token_ids', 'prompt', or 'messages'")
+
+ return MMContext(images=images, videos=videos, mm_order=mm_order, path=TokenizationPath.FROM_TEXT)
+
+ # ------------------------------------------------------------------
+ # Step 2: Fetch from cache
+ # ------------------------------------------------------------------
+
+ def _fetch_from_cache(self, mm_context: MMContext) -> None:
+ """Retrieve missing multimodal data from processor cache."""
+ missing_hashes = []
+ missing_items_ref = []
+
+ for item in mm_context.images + mm_context.videos:
+ if item.data is None and item.uuid is not None:
+ missing_hashes.append(item.uuid)
+ missing_items_ref.append(item)
+
+ if not missing_hashes:
+ return
+
+ if not self._cache:
+ raise ValueError("Missing items cannot be retrieved without processor cache.")
+
+ cached_data = self._cache.get(missing_hashes)
+ for i, data in enumerate(cached_data):
+ if not data:
+ raise ValueError(f"Missing item {i} not found in processor cache")
+ missing_items_ref[i].data = data
+
+ # ------------------------------------------------------------------
+ # Step 3: Tokenize and preprocess
+ # ------------------------------------------------------------------
+
+ def _tokenize_and_preprocess(self, request: dict, mm_context: MMContext) -> dict:
+ """Core tokenization and preprocessing, dispatching by path."""
+ if mm_context.path == TokenizationPath.PRETOKENIZED:
+ return self.prompt_token_ids2outputs(mm_context)
+ else:
+ return self._build_outputs_from_text(request["prompt"], mm_context)
+
+ def _build_outputs_from_text(self, text: str, mm_context: MMContext) -> dict:
+ """Build outputs by scanning text for placeholders and tokenizing segments.
+
+ All multimodal data in mm_context is already resolved (no cache logic here).
+ """
+ outputs = self._make_outputs()
+
+ IMAGE_PLACEHOLDER = self.image_placeholder
+ VIDEO_PLACEHOLDER = self.video_placeholder
+ IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
+ VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
+
+ st, image_idx, video_idx = 0, 0, 0
+ while st < len(text):
+ image_pos = text.find(IMAGE_PLACEHOLDER, st)
+ image_pos = len(text) if image_pos == -1 else image_pos
+ video_pos = text.find(VIDEO_PLACEHOLDER, st)
+ video_pos = len(text) if video_pos == -1 else video_pos
+ ed = min(image_pos, video_pos)
+
+ self._add_text(text[st:ed], outputs)
+ if ed == len(text):
+ break
+
+ if ed == image_pos:
+ if image_idx >= len(mm_context.images):
+ raise ValueError("prompt has more image placeholders than provided images")
+ mm_item = mm_context.images[image_idx]
+ if not isinstance(mm_item.data, tuple):
+ self.preprocess_image(mm_item.data, outputs, mm_item.uuid)
+ else:
+ self.preprocess_cached_image(mm_item.data, outputs, mm_item.uuid)
+ image_idx += 1
+ st = ed + IMAGE_PLACEHOLDER_LEN
+ else:
+ if video_idx >= len(mm_context.videos):
+ raise ValueError("prompt has more video placeholders than provided videos")
+ mm_item = mm_context.videos[video_idx]
+ if not isinstance(mm_item.data, tuple):
+ if isinstance(mm_item.data, dict):
+ frames, meta = self.load_video(mm_item.data.get("video", mm_item.data), mm_item.data)
+ else:
+ frames, meta = self.load_video(mm_item.data, {})
+ self.preprocess_video(frames, outputs, mm_item.uuid, meta=meta)
+ else:
+ self.preprocess_cached_video(mm_item.data, outputs, mm_item.uuid)
+ video_idx += 1
+ st = ed + VIDEO_PLACEHOLDER_LEN
+
+ return outputs
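+
+    # Scan sketch: for text "A <|image_pad|> B <|video_pad|>" (qwen-style
+    # placeholders, illustrative) with one image and one video, the loop emits
+    # tokens("A "), image patches, tokens(" B "), then video patches;
+    # placeholders are consumed strictly in order of appearance.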
+
+ # ------------------------------------------------------------------
+ # Step 4: Post-tokens
+ # ------------------------------------------------------------------
+
+ def _process_post_tokens(self, request, outputs):
+ """Handle completion_token_ids for speculative decoding."""
+ completion_token_ids = request.get("completion_token_ids") or request.get("generated_token_ids")
+ if completion_token_ids:
+ self.append_completion_tokens(outputs, completion_token_ids)
+
+ # ------------------------------------------------------------------
+ # Step 6: Pack outputs
+ # ------------------------------------------------------------------
+
+ def _pack_outputs(self, outputs) -> dict:
+ """Convert lists to numpy arrays."""
+ if not outputs["images"]:
+ outputs["images"] = None
+ outputs["grid_thw"] = None
+ outputs["image_type_ids"] = None
+ else:
+ outputs["images"] = np.vstack(outputs["images"])
+ outputs["grid_thw"] = np.vstack(outputs["grid_thw"])
+ outputs["image_type_ids"] = np.array(outputs["image_type_ids"])
+
+ outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64)
+ outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64)
+ outputs["mm_num_token_func"] = self.mm_num_tokens
+
+ # Position IDs: delegate to subclass
+ self.pack_position_ids(outputs)
+
+ return outputs
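+
+    # Packed-shape sketch for a single image of 1700 patches (grid 1x34x50,
+    # 1176 features per patch; numbers illustrative):
+    #     outputs["images"].shape == (1700, 1176)     # vstacked pixel values
+    #     outputs["grid_thw"]     == array([[1, 34, 50]])
+    #     outputs["input_ids"]    # 1-D np.int64, text + placeholder tokens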
+
+ # ------------------------------------------------------------------
+ # Step 5: Compute hashes and update cache
+ # ------------------------------------------------------------------
+
+ def _update_cache(self, mm_context: MMContext, outputs: dict) -> None:
+ """Compute mm_hashes for all items and optionally update processor cache.
+
+ Hash computation is centralized here: use item.uuid if available,
+ otherwise compute hash from the processed pixel_values.
+ outputs["mm_hashes"] is always populated (needed by downstream engine).
+ Processor cache is only updated when self._cache is enabled.
+
+ NOTE: Must run BEFORE _pack_outputs(), because outputs["images"] is
+ still a per-item list at this point (not yet vstack'd).
+ """
+ # Reconstruct interleaved item list using mm_order
+ all_items = []
+ img_idx, vid_idx = 0, 0
+ for t in mm_context.mm_order:
+ if t == "image":
+ all_items.append(mm_context.images[img_idx])
+ img_idx += 1
+ else:
+ all_items.append(mm_context.videos[vid_idx])
+ vid_idx += 1
+
+ hashes_to_cache, items_to_cache = [], []
+ for idx, item in enumerate(all_items):
+ if outputs["images"] is None or idx >= len(outputs["images"]):
+ continue
+ pixel_values = outputs["images"][idx]
+ # Compute hash: prefer uuid, fallback to content hash
+ cache_key = item.uuid if item.uuid else MultimodalHasher.hash_features(pixel_values)
+ outputs["mm_hashes"].append(cache_key)
+
+ # Only cache newly-processed items (not those fetched from cache)
+ if self._cache and not isinstance(item.data, tuple):
+ meta = {}
+ grid_thw_list = outputs.get("grid_thw")
+ if grid_thw_list is not None and idx < len(grid_thw_list):
+                    # grid_thw_list is non-None inside this branch; index it directly.
+                    grid_thw = np.asarray(grid_thw_list[idx])
+                    if grid_thw.ndim > 1:
+                        t_val, h, w = grid_thw[0]
+                    else:
+                        t_val, h, w = grid_thw
+                    meta["thw"] = (int(t_val), int(h), int(w))
+ if "fps" in outputs and idx < len(outputs.get("fps", [])):
+ meta["fps"] = outputs["fps"][idx]
+ hashes_to_cache.append(cache_key)
+ items_to_cache.append((pixel_values, meta))
+
+ if hashes_to_cache:
+ self._cache.put(hashes_to_cache, items_to_cache)
+
+ # ------------------------------------------------------------------
+ # Step 7: Write-back hook (subclass can override)
+ # ------------------------------------------------------------------
+
+ def _write_back(self, request: dict, outputs: dict) -> None:
+ """Write processing results back to request.
+
+ Default: unconditionally overwrite prompt_token_ids.
+ Subclasses can override to customize write-back behavior.
+ """
+ request["prompt_token_ids"] = outputs["input_ids"].tolist()
+ request["multimodal_inputs"] = outputs
+
+ # ------------------------------------------------------------------
+ # Text tokenization helper
+ # ------------------------------------------------------------------
+
+ def _add_text(self, tokens, outputs):
+ """Tokenize text and add to outputs."""
+ if not tokens:
+ return
+ if isinstance(tokens, str):
+ tokens_str = self.tokenizer.tokenize(tokens)
+ tokens = self.tokenizer.convert_tokens_to_ids(tokens_str)
+ num_tokens = len(tokens)
+ outputs["input_ids"].extend(tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
+ self.add_text_positions(outputs, num_tokens)
+
+ # ------------------------------------------------------------------
+ # Outputs accumulator
+ # ------------------------------------------------------------------
+
+ def _make_outputs(self) -> dict:
+ """Create the mutable accumulator dict. Subclass can override to add fields."""
+ return {
+ "input_ids": [],
+ "token_type_ids": [],
+ "position_ids": [],
+ "images": [],
+ "grid_thw": [],
+ "image_type_ids": [],
+ "labels": [],
+ "cur_position": 0,
+ "video_cnt": 0,
+ "num_input_image_tokens": 0,
+ "num_input_video_tokens": 0,
+ "mm_positions": [],
+ "mm_hashes": [],
+ }
+
+ # ------------------------------------------------------------------
+ # Init helpers
+ # ------------------------------------------------------------------
+
+ def _init_role_prefixes(self) -> dict:
+ """Set up role prefixes for message parsing. Subclass can override."""
+ return {
+ "system": "",
+ "user": "User: ",
+ "bot": "Assistant: ",
+ "assistant": "Assistant: ",
+ }
+
+ def _parse_limits(self, limits: Optional[dict]) -> dict:
+ if not limits:
+ return dict(_DEFAULT_MM_LIMITS)
+ try:
+ if not isinstance(limits, dict):
+ raise ValueError("limit-mm-per-prompt must be a dictionary")
+ data_processor_logger.info(f"_parse_limits:{limits}")
+ return {**_DEFAULT_MM_LIMITS, **limits}
+ except Exception as e:
+ data_processor_logger.warning(f"Invalid limit-mm-per-prompt format: {e}, using default limits")
+ return dict(_DEFAULT_MM_LIMITS)
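+
+    # e.g. _parse_limits({"image": 10}) merges over the defaults and returns
+    # {"image": 10, "video": 1, "audio": 1}; malformed input falls back to the
+    # defaults with a warning.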
+
+ def _check_mm_limits(self, mm_data):
+ """Validate that request does not exceed per-modality limits."""
+ if isinstance(mm_data, dict):
+ for modality, data in mm_data.items():
+ if modality in self.limit_mm_per_prompt and data:
+ limit = self.limit_mm_per_prompt[modality]
+ if len(data) > limit:
+ raise ValueError(f"Too many {modality} items in prompt, got {len(data)} but limit is {limit}")
+
+ def _init_extra(self, processor_kwargs):
+ """Model-specific extra initialization. Override in subclass."""
+ pass
+
+ # ------------------------------------------------------------------
+ # Public helpers (called by Processor)
+ # ------------------------------------------------------------------
+
+ def get_mm_max_tokens_per_item(self, seq_len: int) -> Optional[Mapping[str, int]]:
+ """Per-modality max token counts for the scheduler. None = not applicable."""
+ return None
+
+ def append_completion_tokens(self, multimodal_inputs: dict, completion_token_ids):
+ """Append completion tokens. Must be implemented by subclass."""
+ raise NotImplementedError
+
+ # ------------------------------------------------------------------
+ # Abstract methods (subclass must implement)
+ # ------------------------------------------------------------------
+
+ @abstractmethod
+ def preprocess_image(self, img, outputs: dict, uuid, token_len=None):
+ """Process a raw image and append results to outputs."""
+
+ @abstractmethod
+ def preprocess_cached_image(self, img_cache, outputs: dict, uuid, token_len=None):
+ """Append a pre-processed (cached) image to outputs."""
+
+ @abstractmethod
+ def preprocess_video(self, frames, outputs: dict, uuid, token_len=None, meta=None):
+ """Process video frames and append results to outputs."""
+
+ @abstractmethod
+ def preprocess_cached_video(self, frames_cache, outputs: dict, uuid, token_len=None):
+ """Append a pre-processed (cached) video to outputs."""
+
+ @abstractmethod
+ def load_video(self, url, item: dict) -> Tuple[Any, dict]:
+ """Decode a video and return (frames, meta)."""
+
+ @abstractmethod
+ def add_text_positions(self, outputs: dict, num_tokens: int):
+ """Append text position IDs to outputs."""
+
+ @abstractmethod
+ def pack_position_ids(self, outputs: dict):
+ """Convert intermediate position ID lists into final packed format."""
+
+ @staticmethod
+ @abstractmethod
+ def mm_num_tokens(grid_thw) -> int:
+ """Calculate number of multimodal tokens for given grid_thw."""
+
+ def prompt_token_ids2outputs(self, mm_context: "MMContext") -> dict:
+ """Build outputs from pre-tokenized prompt_token_ids. Override if supported."""
+ raise NotImplementedError(f"{type(self).__name__} does not support prompt_token_ids path")
diff --git a/fastdeploy/input/multimodal/paddleocr_vl.py b/fastdeploy/input/multimodal/paddleocr_vl.py
new file mode 100644
index 00000000000..4bb57ceba40
--- /dev/null
+++ b/fastdeploy/input/multimodal/paddleocr_vl.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PaddleOCRVLProcessor — multimodal processor for PaddleOCR-VL."""
+
+import numpy as np
+
+from fastdeploy.engine.request import ImagePosition
+from fastdeploy.input.multimodal.image_processors import PaddleOCRImageProcessor
+from fastdeploy.input.multimodal.qwen_vl import QwenVLProcessor
+from fastdeploy.input.utils import IDS_TYPE_FLAG
+from fastdeploy.input.utils.video import read_video_decord
+from fastdeploy.input.utils.video import sample_frames_paddleocr as _sample_paddleocr
+
+
+class PaddleOCRVLProcessor(QwenVLProcessor):
+ """Multimodal processor for PaddleOCR-VL.
+
+ Inherits from QwenVLProcessor and overrides:
+ - _make_outputs: add vit_seqlen / vit_position_ids
+ - preprocess_image/video: append vit fields
+ - preprocess_video / preprocess_cached_video: use video_token_id
+ - load_video: use sample_frames_paddleocr
+ """
+
+ # ---- Class-level declarations ----
+ image_placeholder = "<|IMAGE_PLACEHOLDER|>"
+ video_placeholder = "<|video_pad|>"
+ image_token_str = "<|IMAGE_PLACEHOLDER|>"
+ video_token_str = "<|video_pad|>"
+
+ # PaddleOCR default: video not typically used
+ default_fps: float = -1.0
+
+ def _init_extra(self, processor_kwargs):
+ """Initialize PaddleOCR-specific attributes."""
+ processor_kwargs = processor_kwargs or {}
+
+ # Use PaddleOCRImageProcessor
+ self.image_processor = PaddleOCRImageProcessor.from_pretrained(self.model_name_or_path)
+
+ # Conv params from image_processor
+ self.spatial_conv_size = self.image_processor.merge_size
+ self.temporal_conv_size = self.image_processor.temporal_patch_size
+
+ # Special token IDs
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token_str)
+ self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token_str)
+
+ # tokens_per_second from vision_config
+ vision_config = getattr(self.config, "vision_config", None)
+ self.tokens_per_second = getattr(vision_config, "tokens_per_second", 2)
+
+ # ------------------------------------------------------------------
+ # Outputs accumulator (adds vit fields)
+ # ------------------------------------------------------------------
+
+ def _make_outputs(self) -> dict:
+ outputs = super()._make_outputs()
+ outputs["vit_seqlen"] = []
+ outputs["vit_position_ids"] = []
+ return outputs
+
+ # ------------------------------------------------------------------
+ # Image processing (overrides to add vit fields)
+ # ------------------------------------------------------------------
+
+ def preprocess_image(self, img, outputs, uuid, token_len=None):
+ ret = self.image_processor.preprocess(images=[img.convert("RGB")])
+ num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
+ grid_thw = ret["grid_thw"].tolist()
+ if token_len is not None and token_len != num_tokens:
+ raise ValueError("image tokens num not match the size")
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+ outputs["num_input_image_tokens"] += int(num_tokens)
+
+ outputs["images"].append(ret["pixel_values"])
+ outputs["grid_thw"].append(grid_thw)
+ outputs["image_type_ids"].append(0)
+
+ t, h, w = grid_thw
+ pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ outputs["fps"].append(0)
+
+ # paddleocr vit fields
+ numel = h * w
+ outputs["vit_seqlen"].append(numel)
+ outputs["vit_position_ids"].append(np.arange(numel) % numel)
+
+ def preprocess_cached_image(self, img_cache, outputs, uuid, token_len=None):
+ super().preprocess_cached_image(img_cache, outputs, uuid, token_len)
+ _, h, w = img_cache[1]["thw"]
+ numel = h * w
+ outputs["vit_seqlen"].append(numel)
+ outputs["vit_position_ids"].append(np.arange(numel) % numel)
+
+ # ------------------------------------------------------------------
+ # Video processing (uses video_token_id + vit fields)
+ # ------------------------------------------------------------------
+
+ def preprocess_video(self, frames, outputs, uuid, token_len=None, meta=None):
+ preprocess_kwargs = {}
+ if self.video_min_pixels is not None:
+ preprocess_kwargs["min_pixels"] = self.video_min_pixels
+ preprocess_kwargs["max_pixels"] = self.video_max_pixels
+
+ ret = self.image_processor.preprocess(images=frames, **preprocess_kwargs)
+
+ num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
+ grid_thw = ret["grid_thw"].tolist()
+ if token_len is not None and token_len != num_tokens:
+ raise ValueError("video tokens num not match the size")
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ # PaddleOCR uses video_token_id for video (not image_token_id)
+ outputs["input_ids"].extend([self.video_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+ outputs["num_input_video_tokens"] += int(num_tokens)
+
+ outputs["images"].append(ret["pixel_values"])
+ outputs["grid_thw"].append(grid_thw)
+ outputs["image_type_ids"].extend([1] * grid_thw[0])
+
+ fps = meta["fps"] if meta else 0
+ second_per_grid_t = self.temporal_conv_size / fps if fps else 0
+ t, h, w = grid_thw
+ pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ outputs["fps"].append(fps)
+
+ # paddleocr vit fields
+ numel = h * w
+ outputs["vit_seqlen"].append(numel)
+ outputs["vit_position_ids"].append(np.arange(numel) % numel)
+
+ def preprocess_cached_video(self, frames_cache, outputs, uuid, token_len=None):
+ frames, meta = frames_cache
+ num_tokens = frames.shape[0] // self.image_processor.merge_size**2
+ if token_len is not None and token_len != num_tokens:
+ raise ValueError("video tokens num not match the size")
+
+ t, h, w = meta["thw"]
+ outputs["images"].append(frames)
+ outputs["grid_thw"].append(np.array([[t, h, w]]))
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ # PaddleOCR uses video_token_id for video
+ outputs["input_ids"].extend([self.video_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+ outputs["num_input_video_tokens"] += num_tokens
+ outputs["image_type_ids"].extend([1] * t)
+
+ fps = meta["fps"]
+ second_per_grid_t = self.temporal_conv_size / fps
+ pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ outputs["fps"].append(fps)
+
+ # paddleocr vit fields
+ numel = h * w
+ outputs["vit_seqlen"].append(numel)
+ outputs["vit_position_ids"].append(np.arange(numel) % numel)
+
+ # ------------------------------------------------------------------
+ # Video loading (uses sample_frames_paddleocr)
+ # ------------------------------------------------------------------
+
+ def load_video(self, url, item):
+ reader, meta, _ = read_video_decord(url, save_to_disk=False)
+
+ fps = item.get("fps", self.fps)
+ num_frames = item.get("target_frames", self.target_frames)
+
+ frame_indices = list(range(meta["num_of_frame"]))
+ if fps > 0 or num_frames > 0:
+ min_frames = item.get("min_frames", self.min_frames)
+ max_frames = item.get("max_frames", self.max_frames)
+
+ frame_indices = _sample_paddleocr(
+ frame_factor=self.temporal_conv_size,
+ min_frames=min_frames,
+ max_frames=max_frames,
+ metadata=meta,
+ fps=fps,
+ num_frames=num_frames,
+ )
+
+ meta["num_of_frame"] = len(frame_indices)
+        if fps and fps > 0:
+            meta["fps"] = fps
+            meta["duration"] = len(frame_indices) / fps
+        else:
+            # No positive fps requested: derive it from the clip duration.
+            meta["fps"] = len(frame_indices) / meta["duration"]
+
+        # decord already yields RGB uint8 frames, so stack them directly into a
+        # (T, H, W, C) array without a PIL round-trip.
+        frames = np.stack([reader[idx].asnumpy() for idx in frame_indices], axis=0)
+
+ return frames, meta
diff --git a/fastdeploy/input/multimodal/qwen3_vl.py b/fastdeploy/input/multimodal/qwen3_vl.py
new file mode 100644
index 00000000000..bdde1323cdc
--- /dev/null
+++ b/fastdeploy/input/multimodal/qwen3_vl.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Qwen3VLProcessor — multimodal processor for Qwen3-VL."""
+
+from fastdeploy.input.multimodal.image_processors import Qwen3ImageProcessor
+from fastdeploy.input.multimodal.qwen_vl import QwenVLProcessor
+
+
+class Qwen3VLProcessor(QwenVLProcessor):
+ """Multimodal processor for Qwen3-VL.
+
+ Inherits QwenVLProcessor with:
+ - Qwen3ImageProcessor (patch_size=16, mean/std=[0.5,0.5,0.5])
+ - Video pixel bounds for Qwen3-VL
+ """
+
+ # Qwen3-VL video pixel bounds
+ video_min_pixels = 128 * 28 * 28
+ video_max_pixels = 768 * 28 * 28
+
+ def _init_extra(self, processor_kwargs):
+ """Initialize Qwen3VL-specific attributes."""
+ processor_kwargs = processor_kwargs or {}
+
+ # Use Qwen3ImageProcessor instead of QwenImageProcessor
+ self.image_processor = Qwen3ImageProcessor.from_pretrained(self.model_name_or_path)
+
+ # Conv params from image_processor
+ self.spatial_conv_size = self.image_processor.merge_size
+ self.temporal_conv_size = self.image_processor.temporal_patch_size
+
+ # Special token IDs
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token_str)
+ self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token_str)
+
+ # tokens_per_second from vision_config
+ vision_config = getattr(self.config, "vision_config", None)
+ self.tokens_per_second = getattr(vision_config, "tokens_per_second", 2)
diff --git a/fastdeploy/input/multimodal/qwen_vl.py b/fastdeploy/input/multimodal/qwen_vl.py
new file mode 100644
index 00000000000..c2aa3df059e
--- /dev/null
+++ b/fastdeploy/input/multimodal/qwen_vl.py
@@ -0,0 +1,375 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""QwenVLProcessor — multimodal processor for Qwen2.5-VL."""
+
+from typing import Optional
+
+import numpy as np
+import paddle
+
+from fastdeploy.engine.request import ImagePosition
+from fastdeploy.input.multimodal.image_processors import QwenImageProcessor
+from fastdeploy.input.multimodal.mm_processor import MMProcessor
+from fastdeploy.input.utils import IDS_TYPE_FLAG
+from fastdeploy.input.utils.video import read_video_decord
+from fastdeploy.input.utils.video import sample_frames_qwen as _sample_qwen
+
+
+class QwenVLProcessor(MMProcessor):
+ """Multimodal processor for Qwen2.5-VL (qwen_vl).
+
+ Implements qwen-family position ID computation (3D: temporal, height, width)
+ and image/video preprocessing using QwenImageProcessor.
+ """
+
+ # ---- Class-level declarations ----
+ image_placeholder = "<|image_pad|>"
+ video_placeholder = "<|video_pad|>"
+ image_token_str = "<|image_pad|>"
+ video_token_str = "<|video_pad|>"
+ tokenizer_type = "auto"
+
+ FRAME_FACTOR = 2
+
+ # Video pixel bounds (None means use image_processor defaults)
+ video_min_pixels: Optional[int] = None
+ video_max_pixels: Optional[int] = None
+
+ def _init_extra(self, processor_kwargs):
+ """Initialize QwenVL-specific attributes."""
+ processor_kwargs = processor_kwargs or {}
+
+ # Image processor
+ self.image_processor = QwenImageProcessor.from_pretrained(self.model_name_or_path)
+
+ # Conv params from image_processor
+ self.spatial_conv_size = self.image_processor.merge_size
+ self.temporal_conv_size = self.image_processor.temporal_patch_size
+
+ # Special token IDs
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token_str)
+ self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token_str)
+
+ # tokens_per_second from vision_config
+ vision_config = getattr(self.config, "vision_config", None)
+ self.tokens_per_second = getattr(vision_config, "tokens_per_second", 2)
+
+ # ------------------------------------------------------------------
+ # Outputs accumulator (adds fps field)
+ # ------------------------------------------------------------------
+
+ def _make_outputs(self) -> dict:
+ outputs = super()._make_outputs()
+ outputs["fps"] = []
+ return outputs
+
+ # ------------------------------------------------------------------
+ # Image processing
+ # ------------------------------------------------------------------
+
+ def preprocess_image(self, img, outputs, uuid, token_len=None):
+ ret = self.image_processor.preprocess(images=[img.convert("RGB")])
+ num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
+ grid_thw = ret["grid_thw"].tolist()
+ if token_len is not None and token_len != num_tokens:
+ raise ValueError("image tokens num not match the size")
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+ outputs["num_input_image_tokens"] += int(num_tokens)
+
+ outputs["images"].append(ret["pixel_values"])
+ outputs["grid_thw"].append(grid_thw)
+ outputs["image_type_ids"].append(0)
+
+ t, h, w = grid_thw
+ pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ outputs["fps"].append(0)
+
+ def preprocess_cached_image(self, img_cache, outputs, uuid, token_len=None):
+ img, meta = img_cache
+ num_tokens = img.shape[0] // self.image_processor.merge_size**2
+ if token_len is not None and token_len != num_tokens:
+ raise ValueError("image tokens num not match the size")
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+ outputs["num_input_image_tokens"] += num_tokens
+
+ _, h, w = meta["thw"]
+ pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ outputs["images"].append(img)
+ outputs["grid_thw"].append(np.array([[1, h, w]]))
+ outputs["image_type_ids"].append(0)
+
+ outputs["fps"].append(0)
+
+ # ------------------------------------------------------------------
+ # Video processing
+ # ------------------------------------------------------------------
+
+ def preprocess_video(self, frames, outputs, uuid, token_len=None, meta=None):
+ preprocess_kwargs = {}
+ if self.video_min_pixels is not None:
+ preprocess_kwargs["min_pixels"] = self.video_min_pixels
+ preprocess_kwargs["max_pixels"] = self.video_max_pixels
+
+ ret = self.image_processor.preprocess(images=frames, **preprocess_kwargs)
+
+ num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
+ grid_thw = ret["grid_thw"].tolist()
+ if token_len is not None and token_len != num_tokens:
+ raise ValueError("video tokens num not match the size")
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+ outputs["num_input_video_tokens"] += int(num_tokens)
+
+ outputs["images"].append(ret["pixel_values"])
+ outputs["grid_thw"].append(grid_thw)
+ outputs["image_type_ids"].extend([1] * grid_thw[0])
+
+ fps = meta["fps"] if meta else 0
+ second_per_grid_t = self.temporal_conv_size / fps if fps else 0
+ t, h, w = grid_thw
+ pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ outputs["fps"].append(fps)
+
+ def preprocess_cached_video(self, frames_cache, outputs, uuid, token_len=None):
+ frames, meta = frames_cache
+ num_tokens = frames.shape[0] // self.image_processor.merge_size**2
+ if token_len is not None and token_len != num_tokens:
+ raise ValueError("video tokens num not match the size")
+
+ t, h, w = meta["thw"]
+ outputs["images"].append(frames)
+ outputs["grid_thw"].append(np.array([[t, h, w]]))
+
+ outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+ outputs["input_ids"].extend([self.image_token_id] * num_tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+ outputs["num_input_video_tokens"] += num_tokens
+ outputs["image_type_ids"].extend([1] * t)
+
+ fps = meta["fps"]
+ second_per_grid_t = self.temporal_conv_size / fps
+ pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ outputs["fps"].append(fps)
+
+ def load_video(self, url, item):
+ reader, meta, _ = read_video_decord(url, save_to_disk=False)
+
+ fps = item.get("fps", self.fps)
+ num_frames = item.get("target_frames", self.target_frames)
+
+ frame_indices = list(range(meta["num_of_frame"]))
+ if fps > 0 or num_frames > 0:
+ min_frames = item.get("min_frames", self.min_frames)
+ max_frames = item.get("max_frames", self.max_frames)
+
+ frame_indices = _sample_qwen(
+ frame_factor=self.FRAME_FACTOR,
+ min_frames=min_frames,
+ max_frames=max_frames,
+ metadata=meta,
+ fps=-1 if num_frames > 0 else fps,
+ num_frames=num_frames,
+ )
+
+ meta["num_of_frame"] = len(frame_indices)
+        if fps and fps > 0:
+            meta["fps"] = fps
+            meta["duration"] = len(frame_indices) / fps
+        else:
+            # No positive fps requested: derive it from the clip duration.
+            meta["fps"] = len(frame_indices) / meta["duration"]
+
+        # decord already yields RGB uint8 frames, so stack them directly into a
+        # (T, H, W, C) array without a PIL round-trip.
+        frames = np.stack([reader[idx].asnumpy() for idx in frame_indices], axis=0)
+
+ return frames, meta
+
+ # ------------------------------------------------------------------
+ # Position IDs
+ # ------------------------------------------------------------------
+
+ def add_text_positions(self, outputs, num_tokens):
+ """Write text position IDs in qwen 3xN ndarray format."""
+ pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
+ outputs["position_ids"].append(pos_ids)
+ outputs["cur_position"] = pos_ids.max() + 1
+
+ def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
+ num_tokens = len(completion_token_ids)
+ multimodal_inputs["input_ids"].extend(completion_token_ids)
+ multimodal_inputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
+
+ pos_ids = self._compute_text_positions(multimodal_inputs["cur_position"], num_tokens)
+ multimodal_inputs["position_ids"].append(pos_ids)
+ multimodal_inputs["cur_position"] += num_tokens
+
+ def pack_position_ids(self, outputs):
+ """Qwen: concatenate 3xN arrays, then transpose to Nx3."""
+ outputs["position_ids"] = np.concatenate(outputs["position_ids"], axis=1, dtype=np.int64)
+ outputs["image_patch_id"] = self.image_token_id
+ outputs["video_patch_id"] = self.video_token_id
+ outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)
+
+ # ------------------------------------------------------------------
+ # Token counting
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def mm_num_tokens(grid_thw):
+ """Qwen mm_num_tokens: t * h * w // 4."""
+ if isinstance(grid_thw, paddle.Tensor):
+ grid_thw = grid_thw.numpy()
+ if len(grid_thw) == 0:
+ return 0
+
+ def calc_one(thw):
+ t, h, w = map(int, thw)
+ return t * h * w // 4
+
+ if isinstance(grid_thw[0], (list, tuple, np.ndarray)):
+ return [calc_one(x) for x in grid_thw]
+ return calc_one(grid_thw)
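+
+    # e.g. mm_num_tokens([1, 34, 50]) == 1 * 34 * 50 // 4 == 425, and
+    #      mm_num_tokens([[1, 34, 50], [2, 16, 16]]) == [425, 128].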
+
+ # ------------------------------------------------------------------
+ # Prompt token IDs path
+ # ------------------------------------------------------------------
+
+ def prompt_token_ids2outputs(self, mm_context):
+ """Build outputs from prompt_token_ids."""
+ outputs = self._make_outputs()
+ prompt_token_ids = mm_context.prompt_token_ids
+ prompt_token_ids_len = len(prompt_token_ids)
+
+ if not mm_context.images and not mm_context.videos:
+ self._add_text_tokens(prompt_token_ids, outputs)
+ return outputs
+
+ # Reconstruct interleaved list using mm_order
+ mm_items = []
+ img_idx, vid_idx = 0, 0
+ for t in mm_context.mm_order:
+ if t == "image":
+ item = mm_context.images[img_idx]
+ mm_items.append(item)
+ img_idx += 1
+ else:
+ item = mm_context.videos[vid_idx]
+ mm_items.append(item)
+ vid_idx += 1
+
+ st, mm_idx = 0, 0
+ while st < prompt_token_ids_len:
+ if prompt_token_ids[st] != self.image_token_id:
+ cur_idx = st
+ while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_token_id:
+ cur_idx += 1
+ self._add_text_tokens(prompt_token_ids[st:cur_idx], outputs)
+ st = cur_idx
+ continue
+
+ if mm_idx >= len(mm_items):
+ raise ValueError("prompt token ids has more multimodal placeholder than in messages")
+
+ cur_idx = st
+ while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] == self.image_token_id:
+ cur_idx += 1
+
+ item = mm_items[mm_idx]
+ uuid = item.uuid
+ token_len = cur_idx - st
+ if item.type == "image":
+ if not isinstance(item.data, tuple):
+ self.preprocess_image(item.data, outputs, uuid, token_len)
+ else:
+ self.preprocess_cached_image(item.data, outputs, uuid, token_len)
+ elif item.type == "video":
+ if not isinstance(item.data, tuple):
+ if isinstance(item.data, dict):
+ frames, meta = self.load_video(item.data["video"], item.data)
+ else:
+ frames, meta = self.load_video(item.data, {})
+ self.preprocess_video(frames, outputs, uuid, token_len=token_len, meta=meta)
+ else:
+ self.preprocess_cached_video(item.data, outputs, uuid, token_len)
+ else:
+ raise ValueError(f"Unsupported multimodal type: {item.type}")
+ mm_idx += 1
+ st = cur_idx
+
+ if mm_idx != len(mm_items):
+ raise ValueError("number of multimodal items does not match prompt token ids")
+
+ return outputs
+
+ # ------------------------------------------------------------------
+ # Internal helpers
+ # ------------------------------------------------------------------
+
+ def _add_text_tokens(self, tokens, outputs):
+ """Helper: add text tokens with position IDs."""
+ if not tokens:
+ return
+ num_tokens = len(tokens)
+ outputs["input_ids"].extend(tokens)
+ outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
+ self.add_text_positions(outputs, num_tokens)
+
+ def _compute_text_positions(self, start_pos, num_tokens):
+ """3xN ndarray for qwen-family text positions."""
+ text_array = np.arange(num_tokens).reshape(1, -1)
+ text_index = np.broadcast_to(text_array, (3, num_tokens))
+ return text_index + start_pos
+
+ def _compute_vision_positions(self, start_pos, t, h, w, second_per_grid_t):
+ """3D position IDs as 3xN ndarray for qwen-family."""
+ h //= self.spatial_conv_size
+ w //= self.spatial_conv_size
+
+ tn = np.arange(t).reshape(-1, 1)
+ tn = np.broadcast_to(tn, (t, h * w))
+        # Scale temporal indices by real time and truncate only after the
+        # multiply; int(second_per_grid_t) first would zero out fractional
+        # grid intervals.
+        tn = (tn * second_per_grid_t * self.tokens_per_second).astype(np.int64)
+ t_index = tn.flatten()
+
+ hn = np.arange(h).reshape(1, -1, 1)
+ h_index = np.broadcast_to(hn, (t, h, w)).flatten()
+
+ wn = np.arange(w).reshape(1, 1, -1)
+ w_index = np.broadcast_to(wn, (t, h, w)).flatten()
+
+ return np.stack([t_index, h_index, w_index]) + start_pos
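+
+
+# Position sketch: _compute_vision_positions(start_pos=10, t=1, h=4, w=4,
+# second_per_grid_t=0) with spatial_conv_size=2 merges the grid to 2x2, so the
+# returned 3x4 array is
+#     [[10, 10, 10, 10],    # temporal
+#      [10, 10, 11, 11],    # height
+#      [10, 11, 10, 11]]    # width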
diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py
index 1462306c967..f53f93bfe6b 100644
--- a/fastdeploy/input/preprocess.py
+++ b/fastdeploy/input/preprocess.py
@@ -73,8 +73,8 @@ def create_processor(self):
try:
from fastdeploy.plugins.input_processor import load_input_processor_plugins
- Processor = load_input_processor_plugins()
- self.processor = Processor(
+ PluginProcessor = load_input_processor_plugins()
+ self.processor = PluginProcessor(
model_name_or_path=self.model_name_or_path,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
@@ -83,45 +83,65 @@ def create_processor(self):
)
except Exception as e:
logger.info(f"Plugin input processor not available ({e}), using built-in processor")
- if not self.enable_mm_runtime:
- from fastdeploy.input.text_processor import TextProcessor
+ from fastdeploy.input.processor import Processor
+ if not self.enable_mm_runtime:
tokenizer_type = "ernie4_5" if ErnieArchitectures.contains_ernie_arch(architecture) else "auto"
- self.processor = TextProcessor(
+ self.processor = Processor(
model_name_or_path=self.model_name_or_path,
tokenizer_type=tokenizer_type,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
)
else:
- from fastdeploy.input.mm_model_config import (
- ERNIE4_5_VL,
- PADDLEOCR_VL,
- QWEN3_VL,
- QWEN_VL,
+ from fastdeploy.input.multimodal import (
+ Ernie4_5VLProcessor,
+ PaddleOCRVLProcessor,
+ Qwen3VLProcessor,
+ QwenVLProcessor,
)
- from fastdeploy.input.multimodal_processor import MultiModalProcessor
+ # Determine mm_processor class and Processor-level flags
if ErnieArchitectures.contains_ernie_arch(architecture):
- model_type = ERNIE4_5_VL
+ mm_proc_cls = Ernie4_5VLProcessor
+ force_disable_thinking = False
+ set_default_reasoning_max_tokens = True
elif "PaddleOCRVL" in architecture:
- model_type = PADDLEOCR_VL
+ mm_proc_cls = PaddleOCRVLProcessor
+ force_disable_thinking = False
+ set_default_reasoning_max_tokens = False
elif "Qwen2_5_VL" in architecture:
- model_type = QWEN_VL
+ mm_proc_cls = QwenVLProcessor
+ force_disable_thinking = True
+ set_default_reasoning_max_tokens = False
elif "Qwen3VL" in architecture:
- model_type = QWEN3_VL
+ mm_proc_cls = Qwen3VLProcessor
+ force_disable_thinking = True
+ set_default_reasoning_max_tokens = False
else:
raise ValueError(f"Unsupported model processor architecture: {architecture}. ")
- self.processor = MultiModalProcessor(
+ tokenizer_type = mm_proc_cls.tokenizer_type
+
+ # Create the unified Processor first (loads tokenizer)
+ self.processor = Processor(
model_name_or_path=self.model_name_or_path,
- model_type=model_type,
- config=self.model_config,
- limit_mm_per_prompt=self.limit_mm_per_prompt,
- mm_processor_kwargs=self.mm_processor_kwargs,
+ tokenizer_type=tokenizer_type,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
+ force_disable_thinking=force_disable_thinking,
+ set_default_reasoning_max_tokens=set_default_reasoning_max_tokens,
+ )
+
+ # Create and attach the multimodal processor
+ mm_processor = mm_proc_cls(
+ tokenizer=self.processor.tokenizer,
+ model_name_or_path=self.model_name_or_path,
+ config=self.model_config,
+ processor_kwargs=self.mm_processor_kwargs,
+ limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_processor_cache=self.enable_processor_cache,
)
+ self.processor.mm_processor = mm_processor
return self.processor
diff --git a/fastdeploy/input/processor.py b/fastdeploy/input/processor.py
new file mode 100644
index 00000000000..43678e5eb2f
--- /dev/null
+++ b/fastdeploy/input/processor.py
@@ -0,0 +1,808 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unified Processor for both text-only and multimodal models.
+
+Merges the former BaseTextProcessor + TextProcessor into a single concrete
+class. Multimodal support is opt-in via the ``mm_processor`` parameter.
+"""
+
+from collections import OrderedDict
+from typing import Dict
+
+import numpy as np
+from paddleformers.generation import GenerationConfig
+from paddleformers.transformers import Llama3Tokenizer, LlamaTokenizer
+
+from fastdeploy import envs
+from fastdeploy.input.utils import process_stop_token_ids
+from fastdeploy.logger.request_logger import RequestLogLevel, log_request
+from fastdeploy.utils import data_processor_logger
+
+_SAMPLING_EPS = 1e-5
+
+
+class Processor:
+ """Unified processor for text-only and multimodal models.
+
+ Handles the full request pre-processing pipeline, response decoding,
+ and optionally delegates multimodal data processing to an MMProcessor.
+ """
+
+ def __init__(
+ self,
+ model_name_or_path: str,
+ tokenizer_type: str = "auto",
+ reasoning_parser_obj=None,
+ tool_parser_obj=None,
+ mm_processor=None,
+ force_disable_thinking: bool = False,
+ set_default_reasoning_max_tokens: bool = False,
+ ):
+ self.model_name_or_path = model_name_or_path
+ self.tokenizer_type = tokenizer_type
+ self.mm_processor = mm_processor
+ self.force_disable_thinking = force_disable_thinking
+ self.set_default_reasoning_max_tokens = set_default_reasoning_max_tokens
+
+ # Response-handling state.
+ self.decode_status: Dict[str, list] = {}
+ self.model_status_dict: Dict[str, dict] = {}
+ self.tool_parser_dict: Dict = {}
+ # Token-encode cache.
+ self._tokenize_cache: OrderedDict = OrderedDict()
+ self._tokenize_cache_capacity: int = 128
+
+ # Generation config
+ try:
+ self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
+ except Exception as e:
+ data_processor_logger.warning(
+ f"Can't find generation config: {e}, so it will not use generation_config field in the model config"
+ )
+ self.generation_config = None
+
+ # Tokenizer
+ self.tokenizer = self._load_tokenizer()
+ data_processor_logger.info(
+ f"tokenizer information: bos_token is {self.tokenizer.bos_token}, "
+ f"{self.tokenizer.bos_token_id}, "
+ f"eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id}"
+ )
+
+ # EOS tokens
+ try:
+ from paddleformers.trl.llm_utils import get_eos_token_id
+ except Exception:
+ from paddleformers.cli.utils.llm_utils import get_eos_token_id
+
+ self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
+ data_processor_logger.info(
+ f"The eos_token_ids obtained by merging tokenizer and generation_config is {self.eos_token_ids}"
+ )
+ self.eos_token_id_len = len(self.eos_token_ids)
+ self.pad_token_id = self.get_pad_id()
+ self.tokenizer.pad_token_id = self.pad_token_id
+ self._init_parsers(reasoning_parser_obj, tool_parser_obj)
+
+ # ------------------------------------------------------------------
+ # Tokenizer loading
+ # ------------------------------------------------------------------
+
+ def _load_tokenizer(self):
+ if self.tokenizer_type == "ernie4_5":
+ return self._load_ernie4_5_tokenizer()
+ return self._load_auto_tokenizer()
+
+ def _load_auto_tokenizer(self):
+ if envs.FD_USE_HF_TOKENIZER:
+ from transformers import AutoTokenizer
+
+ return AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=False)
+ else:
+ from paddleformers.transformers import AutoTokenizer
+
+ return AutoTokenizer.from_pretrained(self.model_name_or_path, padding_side="left", use_fast=True)
+
+ def _load_ernie4_5_tokenizer(self):
+ import os
+
+ from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
+
+ vocab_file_names = ["tokenizer.model", "spm.model", "ernie_token_100k.model"]
+ for name in vocab_file_names:
+ if os.path.exists(os.path.join(self.model_name_or_path, name)):
+ Ernie4_5Tokenizer.resource_files_names["vocab_file"] = name
+ break
+ return Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)
+
+ # ------------------------------------------------------------------
+ # Parser initialisation helper
+ # ------------------------------------------------------------------
+
+ def _init_parsers(self, reasoning_parser_obj, tool_parser_obj):
+ """Initialise reasoning / tool parser attributes."""
+ self.reasoning_parser = None
+ self.tool_parser_obj = tool_parser_obj
+ if reasoning_parser_obj:
+ self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
+
+ # ------------------------------------------------------------------
+ # Text tokenization
+ # ------------------------------------------------------------------
+
+ def text2ids(self, text, max_model_len=None, **kwargs):
+ """Convert text to token IDs."""
+ if self.tokenizer_type == "ernie4_5":
+ return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+ add_special_tokens = kwargs.get("add_special_tokens", False)
+ if envs.FD_USE_HF_TOKENIZER:
+ tokens = self.tokenizer(text, return_tensors="np", padding=True, truncation=True)
+ else:
+ text_input = [text] if isinstance(text, str) else text
+ tokens = self.tokenizer(
+ text_input,
+ return_tensors="np",
+ padding=True,
+ truncation=True,
+ max_length=max_model_len,
+ add_special_tokens=add_special_tokens,
+ )
+ return tokens["input_ids"][0]
+
+ def messages2ids(self, request, **kwargs):
+ """Convert a chat-template request into a token-ID list."""
+ if self.tokenizer.chat_template is None:
+ raise ValueError("This model does not support chat_template.")
+ if self.tokenizer_type != "ernie4_5":
+ if "add_generation_prompt" not in kwargs:
+ kwargs["add_generation_prompt"] = request.get("add_generation_prompt", True)
+ spliced_message = self.tokenizer.apply_chat_template(
+ request,
+ tokenize=False,
+ split_special_tokens=False,
+ add_special_tokens=False,
+ **kwargs,
+ )
+ request["prompt_tokens"] = spliced_message
+ req_id = request.get("request_id", None) if isinstance(request, dict) else None
+ if self.tokenizer_type == "ernie4_5":
+ token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(spliced_message))
+ else:
+ token_ids = self.tokenizer.encode(spliced_message, add_special_tokens=False)
+ if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids):
+ token_ids = token_ids["input_ids"]
+ if hasattr(token_ids, "ndim") and token_ids.ndim > 1:
+ token_ids = token_ids[0]
+ if hasattr(token_ids, "tolist"):
+ token_ids = token_ids.tolist()
+ if not isinstance(token_ids, list):
+ token_ids = list(token_ids)
+ log_request(
+ level=1,
+ message="req_id:{req_id}, token_ids: {token_ids}",
+ req_id=req_id,
+ token_ids=token_ids,
+ )
+ return token_ids
+
+ # ------------------------------------------------------------------
+ # ids2tokens
+ # ------------------------------------------------------------------
+
+ def ids2tokens(self, token_id, task_id):
+ """Incrementally decode *token_id* and return a three-tuple.
+
+ Returns:
+ (delta_text, previous_token_ids, previous_texts)
+ """
+ if envs.FD_USE_HF_TOKENIZER:
+ if task_id not in self.decode_status:
+ self.decode_status[task_id] = [[], [], ""]
+ status = self.decode_status[task_id]
+ status[0].extend(token_id)
+ decode_str = self.tokenizer.batch_decode(
+ [status[0]],
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+            # Capture the text as it stood before this delta, so callers can build
+            # full_text = previous_texts + delta_text without double counting.
+            previous_texts = status[2]
+            if isinstance(decode_str, list) and len(decode_str):
+                new_str = decode_str[0].replace(status[2], "", 1)
+                status[1].append(new_str)
+                status[2] = decode_str[0]
+            else:
+                new_str = ""
+            return new_str, [], previous_texts
+ else:
+ if task_id not in self.decode_status:
+ self.decode_status[task_id] = [0, 0, [], ""]
+ status = self.decode_status[task_id]
+ previous_texts = status[3]
+ status[2].extend(token_id)
+ decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(status[2], status[0], status[1])
+ status[0] = prefix_offset
+ status[1] = read_offset
+ status[3] += decode_str
+ return decode_str, status[2], previous_texts
+
+ # ------------------------------------------------------------------
+ # Response processing
+ # ------------------------------------------------------------------
+
+ def process_response_dict(self, response_dict, **kwargs):
+ """Dispatch to streaming or non-streaming handler."""
+ if isinstance(response_dict, dict):
+ outputs = response_dict.get("outputs")
+ error_code = response_dict.get("error_code", 200)
+ else:
+ outputs = getattr(response_dict, "outputs", None)
+ error_code = getattr(response_dict, "error_code", 200)
+ if outputs is None or error_code != 200:
+ return response_dict
+
+ stream = kwargs.get("stream", True)
+ if stream:
+ return self.process_response_dict_streaming(response_dict, **kwargs)
+ else:
+ return self.process_response_dict_normal(response_dict, **kwargs)
+
+ def process_response_dict_normal(self, response_dict, **kwargs):
+ """Accumulate tokens and build the full completion text (non-streaming)."""
+ token_ids = response_dict["outputs"]["token_ids"]
+ is_end = response_dict["finished"]
+ req_id = response_dict["request_id"]
+ request = kwargs.get("request", None)
+ direct_decode = kwargs.get("direct_decode", False)
+
+ if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
+ if token_ids[-1] in self.eos_token_ids:
+ token_ids = token_ids[:-1]
+
+ if direct_decode:
+ delta_text = self.tokenizer.decode(token_ids)
+ previous_texts = ""
+ else:
+ delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
+
+ if is_end:
+ full_text = previous_texts + delta_text
+ response_dict["outputs"]["completion_tokens"] = full_text
+ response_dict["outputs"]["text"] = full_text
+
+ if self.reasoning_parser:
+ reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
+ full_text, request, self.model_status_dict[req_id]
+ )
+ response_dict["outputs"]["text"] = text
+ response_dict["outputs"]["reasoning_content"] = reasoning_content
+                reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
+ response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
+
+ if self.tool_parser_obj:
+ tool_parser = self.tool_parser_obj(self.tokenizer)
+ tool_call_info = tool_parser.extract_tool_calls(full_text, request)
+ if tool_call_info.tools_called:
+ response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
+
+ if req_id in self.decode_status:
+ del self.decode_status[req_id]
+ if req_id in self.model_status_dict:
+ del self.model_status_dict[req_id]
+
+ return response_dict
+
+ def process_response_dict_streaming(self, response_dict, **kwargs):
+ """Incrementally decode and populate streaming output fields."""
+ is_end = response_dict["finished"]
+ req_id = response_dict["request_id"]
+ token_ids = response_dict["outputs"]["token_ids"]
+ request = kwargs.get("request", None)
+
+ if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
+ if token_ids[-1] in self.eos_token_ids:
+ token_ids = token_ids[:-1]
+
+ delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
+
+ response_dict["outputs"]["text"] = delta_text
+ response_dict["outputs"]["completion_tokens"] = delta_text
+ response_dict["outputs"]["skipped"] = False
+ response_dict["outputs"]["tool_calls"] = None
+ response_dict["outputs"]["reasoning_content"] = ""
+
+ if self.reasoning_parser:
+ reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
+ previous_texts,
+ previous_texts + delta_text,
+ delta_text,
+ previous_token_ids,
+ previous_token_ids + token_ids,
+ token_ids,
+ self.model_status_dict[req_id],
+ )
+ if reasoning_delta_message:
+ reasoning_content = reasoning_delta_message.reasoning_content
+ reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
+ response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
+ response_dict["outputs"]["reasoning_content"] = reasoning_content or ""
+ response_dict["outputs"]["text"] = reasoning_delta_message.content or ""
+ else:
+ if not is_end:
+ response_dict["outputs"]["skipped"] = True
+
+ if self.tool_parser_obj:
+ if req_id not in self.tool_parser_dict:
+ self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
+ tool_parser = self.tool_parser_dict[req_id]
+ tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
+ previous_texts,
+ previous_texts + delta_text,
+ delta_text,
+ previous_token_ids,
+ previous_token_ids + token_ids,
+ token_ids,
+ request,
+ )
+ if tool_call_delta_message:
+ if tool_call_delta_message.tool_calls:
+ response_dict["outputs"]["text"] = tool_call_delta_message.content
+ response_dict["outputs"]["tool_calls"] = tool_call_delta_message.tool_calls
+ response_dict["outputs"]["skipped"] = False
+ else:
+ if not is_end:
+ response_dict["outputs"]["skipped"] = True
+
+ if is_end:
+ del self.decode_status[req_id]
+ if req_id in self.tool_parser_dict:
+ del self.tool_parser_dict[req_id]
+ if req_id in self.model_status_dict:
+ del self.model_status_dict[req_id]
+
+ return response_dict
+
+ # ------------------------------------------------------------------
+ # Request processing (unified flow for text + multimodal)
+ # ------------------------------------------------------------------
+
+ def process_request_dict(self, request, max_model_len=None, **kwargs):
+ """Unified request pre-processing for all models (text and multimodal)."""
+ request = self._apply_default_parameters(request)
+
+ if not request.get("eos_token_ids"):
+ request["eos_token_ids"] = self.eos_token_ids
+
+ # Stop tokens (universal)
+ process_stop_token_ids(request, self.update_stop_seq)
+
+ # Bad words (universal — no-op if none present in request)
+ bad_words = request.get("bad_words")
+ bad_words_token_ids = request.get("bad_words_token_ids")
+ if bad_words:
+ bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
+ request["bad_words_token_ids"] = bad_words_token_ids
+
+ # Logits processor: think stop sentence (universal — no-op if no data)
+ logits_processors_args = self._prepare_think_stop_sentence(
+ request.get("logits_processors_args") or {}, max_model_len
+ )
+ request["logits_processors_args"] = logits_processors_args
+
+        # Step 6: messages → prompt + multimodal_data (generic pre-processing)
+ if request.get("messages") and not request.get("prompt"):
+ self.process_messages(request)
+
+ # Step 7: tokenization (multimodal or text-only)
+ if self.mm_processor is not None:
+ self.mm_processor.process(request)
+ else:
+ self._tokenize_text_request(request, max_model_len)
+
+ # Force disable thinking (Processor-level config)
+ if self.force_disable_thinking:
+ request["enable_thinking"] = False
+
+ # Truncation
+ if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
+ request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
+
+ request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
+
+ # Update thinking prompt state (universal — no-op if no budget data)
+ logits_processors_args = self._update_thinking_prompt_state(
+ request["prompt_token_ids"], request.get("logits_processors_args") or {}
+ )
+ request["logits_processors_args"] = logits_processors_args
+
+        # max_tokens (only derivable when max_model_len is known)
+        if max_model_len is not None:
+            max_tokens = max_model_len - len(request["prompt_token_ids"])
+            if request.get("max_tokens") is None:
+                request["max_tokens"] = max(1, max_tokens)
+            else:
+                request["max_tokens"] = min(max_tokens, request["max_tokens"])
+
+ # Default reasoning_max_tokens (only for models that need it, e.g. Ernie)
+ if self.set_default_reasoning_max_tokens and request.get("reasoning_max_tokens") is None:
+ request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
+
+ # Temperature handling
+ if request.get("temperature") < _SAMPLING_EPS:
+ request["temperature"] = 1
+ request["top_k"] = 1
+
+ # Clamp top_p
+ if request.get("top_p") < _SAMPLING_EPS:
+ request["top_p"] = _SAMPLING_EPS
+ request["top_k"] = 1
+
+ # Reasoning parser
+ if self.reasoning_parser:
+ self._apply_reasoning_parser(request)
+
+ # Cap response_max_tokens (universal)
+ if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
+ request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
+
+ log_request(RequestLogLevel.CONTENT, message="Processed request dict: {request}", request=request)
+ return request
+
+ def process_messages(self, request):
+ """将 messages 格式转换为 prompt + multimodal_data(通用,与模型无关)。
+
+ 职责:
+ 1. 从 messages 中提取多模态内容(图片/视频)
+ → 写入 request["multimodal_data"] = {"image": [...], "video": [...], "mm_order": [...]}
+ 2. 调用 tokenizer.apply_chat_template(messages) 拼接 prompt
+ → 写入 request["prompt"]
+
+ 调用时机:request 含 "messages" 且尚未有 "prompt"/"prompt_token_ids" 时。
+ """
+ messages = request.get("messages")
+
+ # Apply chat_template_kwargs
+ chat_template_kwargs = request.get("chat_template_kwargs", {})
+ if chat_template_kwargs:
+ if isinstance(chat_template_kwargs, dict):
+ for k, v in chat_template_kwargs.items():
+ if k not in request or request[k] is None:
+ request[k] = v
+ else:
+ raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+
+ request.setdefault("enable_thinking", True)
+
+        # Step 1: Parse messages (download images/videos, convert to the standard format)
+ from fastdeploy.entrypoints.chat_utils import parse_chat_messages
+
+ parsed_messages = parse_chat_messages(messages)
+
+        # Step 2: Extract multimodal content from the parsed messages
+ images, videos, mm_order = [], [], []
+ for msg in parsed_messages:
+ content = msg.get("content") if isinstance(msg, dict) else None
+ if not isinstance(content, list):
+ continue
+ for item in content:
+ if isinstance(item, dict):
+ if item.get("type") == "image":
+ images.append(item)
+ mm_order.append("image")
+ elif item.get("type") == "video":
+ videos.append(item)
+ mm_order.append("video")
+
+ if images or videos:
+ request["multimodal_data"] = {"image": images, "video": videos, "mm_order": mm_order}
+
+ # Step 3: apply_chat_template → prompt
+ if self.tokenizer.chat_template is None:
+ raise ValueError("This model does not support chat_template.")
+
+ if self.tokenizer_type == "ernie4_5":
+ prompt = self.tokenizer.apply_chat_template(
+ request,
+ tokenize=False,
+ add_generation_prompt=request.get("add_generation_prompt", True),
+ **chat_template_kwargs,
+ )
+ else:
+ prompt = self.tokenizer.apply_chat_template(
+ parsed_messages,
+ tokenize=False,
+ add_generation_prompt=request.get("add_generation_prompt", True),
+ **chat_template_kwargs,
+ )
+
+ request["prompt"] = prompt
+ request["prompt_tokens"] = prompt
+
+ def _has_multimodal_content(self, request):
+ """Check if request contains multimodal data."""
+ # Check multimodal_data field (from prompt path)
+ multimodal_data = request.get("multimodal_data") or {}
+ if multimodal_data.get("image") or multimodal_data.get("video"):
+ return True
+ # Check messages for multimodal content
+ messages = request.get("messages")
+ if messages:
+ for msg in messages:
+ content = msg.get("content") if isinstance(msg, dict) else None
+ if isinstance(content, list):
+ for part in content:
+ if isinstance(part, dict) and part.get("type") in (
+ "image",
+ "image_url",
+ "video",
+ "video_url",
+ ):
+ return True
+ return False
+
+ def _tokenize_text_request(self, request, max_model_len=None):
+ """Text-only tokenization path."""
+ if not request.get("prompt_token_ids"):
+ if request.get("prompt"):
+ prompt = request.get("prompt")
+ assert isinstance(prompt, str) or (
+ isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
+ ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):
+                    # Already a list of token ids; use it directly.
+                    request["prompt_token_ids"] = prompt
+                else:
+                    request["prompt_tokens"] = prompt
+                    add_special_tokens = request.get("add_special_tokens", False)
+                    token_ids = self.text2ids(prompt, max_model_len, add_special_tokens=add_special_tokens)
+                    if hasattr(token_ids, "tolist"):
+                        token_ids = token_ids.tolist()
+                    request["prompt_token_ids"] = token_ids
+ else:
+ raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
+
+ if len(request["prompt_token_ids"]) == 0:
+ raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+ if request.get("completion_token_ids"):
+ request["prompt_token_ids"].extend(request["completion_token_ids"])
+
+ # ------------------------------------------------------------------
+ # Reasoning parser
+ # ------------------------------------------------------------------
+
+ def _apply_reasoning_parser(self, request):
+ """Apply reasoning parser to determine model thinking status."""
+ model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
+ parts = request["request_id"].split("_")
+ if len(parts) > 1:
+ real_req_id = parts[0]
+ index = int(parts[1])
+ n = request.get("n", 1)
+ for idx in range(index * n, (index + 1) * n):
+ self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
+ else:
+ self.model_status_dict[request["request_id"]] = model_status
+ request["enable_thinking"] = model_status == "think_start"
+
+ def clear_request_status(self, task_id):
+ """Clear all per-request decode state and return the accumulated text."""
+ results_all = ""
+ if task_id in self.decode_status:
+ if envs.FD_USE_HF_TOKENIZER:
+ results_all = self.decode_status[task_id][2]
+ else:
+ results_all = "".join(self.decode_status[task_id][3])
+ del self.decode_status[task_id]
+ return results_all
+
+ # ------------------------------------------------------------------
+ # Common utility methods
+ # ------------------------------------------------------------------
+
+ def update_stop_seq(self, stop_sequences):
+ """Convert stop strings to padded token-id sequences."""
+ if isinstance(stop_sequences, str):
+ stop_sequences = [stop_sequences]
+ stop_seqs = []
+ for seq in stop_sequences:
+            if seq != self.tokenizer.eos_token:
+ stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
+ stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
+ log_request(
+ level=3,
+ message="processed stop_seqs: {stop_seqs}, {stop_seqs_len}",
+ stop_seqs=stop_seqs,
+ stop_seqs_len=stop_seqs_len,
+ )
+ return stop_seqs, stop_seqs_len
+
+ def _apply_default_parameters(self, request):
+ """Apply default values for sampling parameters in request."""
+
+ def set_value(req, key, value):
+ value = getattr(self.generation_config, key, value)
+ if isinstance(req, dict):
+ if key not in req or req[key] is None:
+ req[key] = value
+ else:
+ if req.get(key) is None:
+ req.set(key, value)
+
+ set_value(request, "top_p", 0.7)
+ set_value(request, "temperature", 1.0)
+ set_value(request, "repetition_penalty", 1.0)
+ set_value(request, "frequency_penalty", 0.0)
+ set_value(request, "presence_penalty", 0.0)
+ return request
+
+ def _encode_literal_text_with_cache(self, text):
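+        # Small LRU cache: a repeated literal (e.g. the think-stop sentence) is
+        # tokenised once; past capacity the least recently used entry is evicted.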
+ if not hasattr(self, "_tokenize_cache"):
+ self._tokenize_cache = OrderedDict()
+ self._tokenize_cache_capacity = 128
+ key = ("literal_text", text)
+ cached = self._tokenize_cache.get(key)
+ if cached is not None:
+ self._tokenize_cache.move_to_end(key)
+ return cached
+ token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+ if hasattr(token_ids, "tolist"):
+ token_ids = token_ids.tolist()
+ elif not isinstance(token_ids, list):
+ token_ids = list(token_ids)
+ self._tokenize_cache[key] = token_ids
+ if len(self._tokenize_cache) > self._tokenize_cache_capacity:
+ self._tokenize_cache.popitem(last=False)
+ return token_ids
+
+ def _get_think_token_ids(self):
+ think_token_ids = getattr(self, "_think_token_ids", None)
+ if think_token_ids is not None:
+ return think_token_ids
+ tokenizer = getattr(self, "tokenizer", None)
+ vocab = tokenizer.get_vocab() if tokenizer is not None else {}
+ think_start_id = vocab.get("", -1)
+ think_end_id = vocab.get("", -1)
+ self._think_token_ids = (think_start_id, think_end_id)
+ return self._think_token_ids
+
+ def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
+ if not isinstance(logits_processors_args, dict):
+ return logits_processors_args
+ think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+ if isinstance(think_stop_sentence, str) and think_stop_sentence:
+ sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
+ logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
+ logits_processors_args.pop("think_stop_sentence", None)
+ return logits_processors_args
+
+ def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+ if not isinstance(logits_processors_args, dict):
+ return logits_processors_args
+ thinking_budget = logits_processors_args.get("thinking_budget")
+ if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+ return logits_processors_args
+ if logits_processors_args.get("think_prompt_checked"):
+ return logits_processors_args
+ if prompt_token_ids is None:
+ return logits_processors_args
+ token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+ if token_len == 0:
+ return logits_processors_args
+ think_start_id, think_end_id = self._get_think_token_ids()
+ if think_start_id < 0 or think_end_id < 0:
+ return logits_processors_args
+
+ if hasattr(prompt_token_ids, "tolist"):
+ token_list = prompt_token_ids.tolist()
+ else:
+ token_list = list(prompt_token_ids)
+
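+        # Scan for thinking markers. Illustrative example (hypothetical ids):
+        # [.., <think>, a, b, </think>] -> started=True, ended=True,
+        # tokens_after_start=2; a trailing unclosed <think> leaves ended=False.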
+        started = False
+        ended = False
+        tokens_after_start = 0
+        last_token_id = None
+        in_thinking = False
+        for token_id in token_list:
+            if token_id == think_start_id:
+                started = True
+                ended = False
+                in_thinking = True
+                tokens_after_start = 0
+            elif token_id == think_end_id and in_thinking:
+                ended = True
+                in_thinking = False
+            elif in_thinking:
+                tokens_after_start += 1
+        if started and token_list:
+            last_token_id = int(token_list[-1])
+
+ logits_processors_args["think_prompt_checked"] = True
+ logits_processors_args["think_prompt_started"] = started
+ logits_processors_args["think_prompt_ended"] = ended
+ logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+ if last_token_id is not None:
+ logits_processors_args["think_prompt_last_token_id"] = last_token_id
+ else:
+ logits_processors_args.pop("think_prompt_last_token_id", None)
+ return logits_processors_args
+
+ def update_bad_words(self, bad_words, bad_words_token_ids):
+ """Tokenize bad-word strings and merge with existing bad-word token ids."""
+ token_ids = bad_words_token_ids
+ if token_ids is None:
+ token_ids = []
+ for bad_word in bad_words:
+ for add_prefix_space in [False, True]:
+ prefix = " " if add_prefix_space else ""
+ prompt = prefix + bad_word.lstrip()
+ prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
+ if len(prompt_token_ids) != 1:
+ if not add_prefix_space:
+ log_request(
+ level=1,
+ message="bad_words: '{prompt}' tokenises to {num_tokens} tokens, skipping",
+ prompt=prompt,
+ num_tokens=len(prompt_token_ids),
+ )
+ continue
+ if prompt_token_ids[0] > self.tokenizer.vocab_size:
+ if not add_prefix_space:
+ log_request(
+ level=1,
+ message="bad_words: '{prompt}' token id {token_id} > vocab_size, skipping",
+ prompt=prompt,
+ token_id=prompt_token_ids[0],
+ )
+ continue
+                if prompt_token_ids[0] not in token_ids:
+ token_ids.extend(prompt_token_ids)
+ return token_ids
+
+ def get_pad_id(self):
+ """Return the padding token id, with LlamaTokenizer fallback."""
+ if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
+            return self.tokenizer.eos_token_id
+ return self.tokenizer.pad_token_id
+
+ def pad_batch_data(self, insts, pad_id=0, return_seq_len=False, return_array=True, pad_style="right"):
+ """Pad a list of variable-length lists to a rectangular array."""
+ if len(insts) == 0:
+ padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
+ if return_seq_len:
+ seq_len = np.array([], dtype=np.int64) if return_array else []
+ return padded_insts, seq_len
+ return padded_insts
+ max_len = max(map(len, insts))
+ if pad_style == "left":
+ padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
+ else:
+ padded_insts = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
+ if return_array:
+ padded_insts = np.array(padded_insts, dtype=np.int64).reshape([-1, max_len])
+ if return_seq_len:
+ seq_len = [len(inst) for inst in insts]
+ if return_array:
+ seq_len = np.array(seq_len, dtype=np.int64).reshape(-1, 1)
+ return padded_insts, seq_len
+ return padded_insts
+
+ def process_logprob_response(self, token_ids, **kwargs):
+ """Decode a list of token ids to a string for logprob responses."""
+ return self.tokenizer.decode(token_ids, **kwargs)
+
+ def get_mm_max_tokens_per_item(self, seq_len: int):
+ """Return the maximum number of tokens per item for each modality."""
+ if self.mm_processor is not None:
+ return self.mm_processor.get_mm_max_tokens_per_item(seq_len)
+ return None
+
+ def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
+ """Append completion tokens — delegates to mm_processor if available."""
+ if self.mm_processor is not None:
+ self.mm_processor.append_completion_tokens(multimodal_inputs, completion_token_ids)
diff --git a/tests/entrypoints/test_chat.py b/tests/entrypoints/test_chat.py
index 75ff3a3e050..1205f326ff1 100644
--- a/tests/entrypoints/test_chat.py
+++ b/tests/entrypoints/test_chat.py
@@ -93,15 +93,15 @@ def test_chat_with_tools(self):
data_processor = self.llm.llm_engine.data_processor
captured_spliced_message = None
- def capture_spliced_message(request_or_messages, **kwargs):
- """Wrap original messages2ids to capture spliced_message"""
- token_ids = data_processor.original_messages2ids(request_or_messages, **kwargs)
+ original_process_messages = data_processor.process_messages
+
+ def capture_process_messages(request):
+ """Wrap original process_messages to capture prompt_tokens"""
+ original_process_messages(request)
nonlocal captured_spliced_message
- captured_spliced_message = request_or_messages.get("prompt_tokens")
- return token_ids
+ captured_spliced_message = request.get("prompt_tokens")
- data_processor.original_messages2ids = data_processor.messages2ids
- data_processor.messages2ids = capture_spliced_message
+ data_processor.process_messages = capture_process_messages
try:
outputs = self.llm.chat(
@@ -112,7 +112,7 @@ def capture_spliced_message(request_or_messages, **kwargs):
stream=False,
)
- self.assertIsNotNone(captured_spliced_message, "Failed to capture spliced_message from messages2ids")
+ self.assertIsNotNone(captured_spliced_message, "Failed to capture spliced_message from process_messages")
self.assertIn(
"",
captured_spliced_message,
@@ -124,7 +124,7 @@ def capture_spliced_message(request_or_messages, **kwargs):
self.assertTrue(hasattr(output, "outputs"))
self.assertTrue(hasattr(output.outputs, "text"))
finally:
- data_processor.messages2ids = data_processor.original_messages2ids
+ data_processor.process_messages = original_process_messages
def test_validate_tools(self):
"""Test both valid and invalid scenarios for _validate_tools method"""
diff --git a/tests/entrypoints/test_generation.py b/tests/entrypoints/test_generation.py
index 1e238cd350f..1e528eaa223 100644
--- a/tests/entrypoints/test_generation.py
+++ b/tests/entrypoints/test_generation.py
@@ -123,15 +123,13 @@ def test_multiple_sampling_params(self):
self.assertEqual(len(self.PROMPTS), len(outputs))
def test_consistency_single_prompt_tokens_chat(self):
- """Test consistency between different prompt input formats"""
+ """Test deterministic output for prompt_token_ids via chat interface"""
sampling_params = SamplingParams(temperature=1.0, top_p=0.0)
for prompt_token_ids in self.TOKEN_IDS:
with self.subTest(prompt_token_ids=prompt_token_ids):
output1 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
- output2 = self.llm.chat(
- [{"prompt": "", "prompt_token_ids": prompt_token_ids}], sampling_params=sampling_params
- )
+ output2 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
self.assert_outputs_equal(output1, output2)
def test_multiple_sampling_params_chat(self):
diff --git a/tests/input/multimodal/test_common.py b/tests/input/multimodal/test_common.py
new file mode 100644
index 00000000000..4d5b1898bcb
--- /dev/null
+++ b/tests/input/multimodal/test_common.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for fastdeploy.input.multimodal.common."""
+
+import unittest
+
+import numpy as np
+
+from fastdeploy.input.multimodal.common import (
+ ceil_by_factor,
+ floor_by_factor,
+ is_scaled_image,
+ round_by_factor,
+ smart_resize,
+ smart_resize_paddleocr,
+ smart_resize_qwen,
+)
+
+
+class TestRoundByFactor(unittest.TestCase):
+ def test_exact_multiple(self):
+ self.assertEqual(round_by_factor(28, 28), 28)
+
+ def test_round_up(self):
+        # 25 / 14 ≈ 1.79 rounds up to 2 -> 28
+        self.assertEqual(round_by_factor(25, 14), 28)
+
+ def test_round_down(self):
+ self.assertEqual(round_by_factor(20, 14), 14)
+
+ def test_zero(self):
+ self.assertEqual(round_by_factor(0, 28), 0)
+
+
+class TestCeilByFactor(unittest.TestCase):
+ def test_exact_multiple(self):
+ self.assertEqual(ceil_by_factor(28, 28), 28)
+
+ def test_round_up(self):
+ self.assertEqual(ceil_by_factor(29, 28), 56)
+
+ def test_zero(self):
+ self.assertEqual(ceil_by_factor(0, 14), 0)
+
+
+class TestFloorByFactor(unittest.TestCase):
+ def test_exact_multiple(self):
+ self.assertEqual(floor_by_factor(56, 28), 56)
+
+ def test_floor_down(self):
+ self.assertEqual(floor_by_factor(55, 28), 28)
+
+ def test_zero(self):
+ self.assertEqual(floor_by_factor(0, 14), 0)
+
+
+class TestIsScaledImage(unittest.TestCase):
+ def test_uint8_not_scaled(self):
+ img = np.zeros((10, 10, 3), dtype=np.uint8)
+ self.assertFalse(is_scaled_image(img))
+
+ def test_float_in_range_scaled(self):
+ img = np.random.rand(10, 10, 3).astype(np.float32)
+ self.assertTrue(is_scaled_image(img))
+
+ def test_float_out_of_range_not_scaled(self):
+ img = np.array([[0.0, 2.0]], dtype=np.float32)
+ self.assertFalse(is_scaled_image(img))
+
+ def test_float_negative_not_scaled(self):
+ img = np.array([[-0.1, 0.5]], dtype=np.float32)
+ self.assertFalse(is_scaled_image(img))
+
+
+class TestSmartResizeQwen(unittest.TestCase):
+ def test_normal_image(self):
+ h, w = smart_resize_qwen(224, 224, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_high_aspect_ratio_height(self):
+ """Height >> width triggers aspect ratio clamping."""
+ h, w = smart_resize_qwen(10000, 10, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+ self.assertLessEqual(max(h, w) / min(h, w), 200)
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_high_aspect_ratio_width(self):
+ """Width >> height triggers aspect ratio clamping."""
+ h, w = smart_resize_qwen(10, 10000, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+ self.assertLessEqual(max(h, w) / min(h, w), 200)
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_too_large_scales_down(self):
+ h, w = smart_resize_qwen(10000, 10000, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+ self.assertLessEqual(h * w, 28 * 28 * 1280)
+
+ def test_too_small_scales_up(self):
+ h, w = smart_resize_qwen(10, 10, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+ self.assertGreaterEqual(h * w, 56 * 56)
+
+ def test_invalid_raises(self):
+ with self.assertRaises(ValueError):
+ smart_resize_qwen(1, 1, factor=100000, min_pixels=100, max_pixels=1000)
+
+
+class TestSmartResizePaddleocr(unittest.TestCase):
+ def test_normal_image(self):
+ h, w = smart_resize_paddleocr(224, 224)
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_height_below_factor(self):
+ """Height < factor triggers rescale."""
+ h, w = smart_resize_paddleocr(10, 100, factor=28)
+ self.assertGreaterEqual(h, 28)
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_width_below_factor(self):
+ """Width < factor triggers rescale."""
+ h, w = smart_resize_paddleocr(100, 10, factor=28)
+ self.assertGreaterEqual(w, 28)
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_extreme_aspect_ratio_raises(self):
+ with self.assertRaisesRegex(ValueError, "aspect ratio"):
+ smart_resize_paddleocr(6000, 28, factor=28)
+
+ def test_above_max_pixels_scales_down(self):
+ h, w = smart_resize_paddleocr(2000, 2000, factor=28, max_pixels=28 * 28 * 100)
+ self.assertLessEqual(h * w, 28 * 28 * 100)
+
+ def test_below_min_pixels_scales_up(self):
+ h, w = smart_resize_paddleocr(56, 56, factor=28, min_pixels=28 * 28 * 130)
+ self.assertGreaterEqual(h * w, 28 * 28 * 130)
+
+
+class TestSmartResizeDispatcher(unittest.TestCase):
+ def test_qwen_variant(self):
+ h, w = smart_resize(224, 224, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280, variant="qwen")
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_paddleocr_variant(self):
+ h, w = smart_resize(224, 224, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280, variant="paddleocr")
+ self.assertEqual(h % 28, 0)
+ self.assertEqual(w % 28, 0)
+
+ def test_default_is_qwen(self):
+ h1, w1 = smart_resize(224, 224, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+ h2, w2 = smart_resize_qwen(224, 224, factor=28, min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+ self.assertEqual((h1, w1), (h2, w2))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/input/multimodal/test_ernie4_5_vl.py b/tests/input/multimodal/test_ernie4_5_vl.py
new file mode 100644
index 00000000000..06088a3be97
--- /dev/null
+++ b/tests/input/multimodal/test_ernie4_5_vl.py
@@ -0,0 +1,495 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for Ernie4_5VLProcessor."""
+
+import unittest
+from collections import defaultdict
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+
+from fastdeploy.input.multimodal.ernie4_5_vl import Ernie4_5VLProcessor
+from fastdeploy.input.multimodal.mm_processor import MMContext, MMItem, TokenizationPath
+from fastdeploy.input.utils import IDS_TYPE_FLAG
+
+
+def _make_ernie_processor(**overrides):
+ """Create an Ernie4_5VLProcessor with mocked dependencies."""
+ with patch.object(Ernie4_5VLProcessor, "__init__", return_value=None):
+ proc = Ernie4_5VLProcessor.__new__(Ernie4_5VLProcessor)
+
+ proc.tokenizer = MagicMock()
+ token_map = {
+ "<|IMAGE_PLACEHOLDER|>": 204,
+ "<|IMAGE_START|>": 200,
+ "<|IMAGE_END|>": 201,
+ "<|VIDEO_START|>": 202,
+ "<|VIDEO_END|>": 203,
+ }
+ proc.tokenizer.convert_tokens_to_ids.side_effect = lambda x: token_map.get(x, 999)
+ proc.tokenizer.tokenize.return_value = ["tok"]
+
+ proc.model_name_or_path = "test-model"
+ proc.config = None
+ proc._cache = None
+ proc.enable_processor_cache = False
+
+ proc.image_placeholder = "<|image@placeholder|>"
+ proc.video_placeholder = "<|video@placeholder|>"
+ proc.image_token_str = "<|IMAGE_PLACEHOLDER|>"
+ proc.video_token_str = "<|IMAGE_PLACEHOLDER|>"
+ proc.tokenizer_type = "ernie4_5"
+
+ proc.IMG_START = "<|IMAGE_START|>"
+ proc.IMG_END = "<|IMAGE_END|>"
+ proc.VID_START = "<|VIDEO_START|>"
+ proc.VID_END = "<|VIDEO_END|>"
+
+ proc.image_token_id = 204
+ proc.video_token_id = 204
+
+ proc.spatial_conv_size = 2
+ proc.temporal_conv_size = 2
+
+ proc.image_min_pixels = 4 * 28 * 28
+ proc.image_max_pixels = 6177 * 28 * 28
+ proc.video_min_pixels = 299 * 28 * 28
+ proc.video_max_pixels = 1196 * 28 * 28
+ proc.frames_sample = "leading"
+
+ proc.fps = 2.0
+ proc.min_frames = 4
+ proc.max_frames = 768
+ proc.target_frames = -1
+
+ proc.limit_mm_per_prompt = {"image": 10, "video": 10, "audio": 1}
+
+ # Build token_type_mapping
+ mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"])
+ for token in ("<|IMAGE_START|>", "<|IMAGE_END|>", "<|VIDEO_START|>", "<|VIDEO_END|>"):
+ mapping[token] = IDS_TYPE_FLAG["image"]
+ mapping[204] = IDS_TYPE_FLAG["image"]
+ proc.token_type_mapping = mapping
+
+ # Mock image processor
+ mock_ip = MagicMock()
+ mock_ip.merge_size = 2
+ mock_ip.temporal_conv_size = 2
+ proc.image_processor = mock_ip
+
+ for k, v in overrides.items():
+ setattr(proc, k, v)
+ return proc
+
+
+# ==================================================================
+# Test classes
+# ==================================================================
+
+
+class TestErnieInitExtra(unittest.TestCase):
+ def test_init_extra_defaults(self):
+ proc = _make_ernie_processor()
+ self.assertEqual(proc.image_min_pixels, 4 * 28 * 28)
+ self.assertEqual(proc.image_max_pixels, 6177 * 28 * 28)
+ self.assertEqual(proc.video_min_pixels, 299 * 28 * 28)
+ self.assertEqual(proc.video_max_pixels, 1196 * 28 * 28)
+
+ def test_init_extra_custom(self):
+ proc = _make_ernie_processor(image_min_pixels=100, image_max_pixels=200)
+ self.assertEqual(proc.image_min_pixels, 100)
+ self.assertEqual(proc.image_max_pixels, 200)
+
+
+class TestErnieMakeOutputs(unittest.TestCase):
+ def test_no_fps_no_vit(self):
+ proc = _make_ernie_processor()
+ outputs = proc._make_outputs()
+ self.assertNotIn("fps", outputs)
+ self.assertNotIn("vit_seqlen", outputs)
+ self.assertIn("input_ids", outputs)
+ self.assertIn("mm_hashes", outputs)
+
+
+class TestErnieComputePositions(unittest.TestCase):
+ def setUp(self):
+ self.proc = _make_ernie_processor()
+
+ def test_compute_3d_positions_image(self):
+ """t=1 image: 1*1*1=1 token (h/2=1, w/2=1, t/2=1 but t==1 so t_eff=1)."""
+ pos = self.proc._compute_3d_positions(t=1, h=2, w=2, start_idx=0)
+ # t_eff=1 (since t==1), gh=1, gw=1 -> 1 token
+ self.assertEqual(len(pos), 1)
+ self.assertEqual(pos[0], [0, 0, 0])
+
+ def test_compute_3d_positions_video(self):
+ """t=4 video: t_eff=4//2=2, gh=2, gw=2 -> 2*2*2=8 tokens."""
+ pos = self.proc._compute_3d_positions(t=4, h=4, w=4, start_idx=5)
+ self.assertEqual(len(pos), 8)
+ # First frame tokens
+ self.assertEqual(pos[0], [5, 5, 5]) # time_idx=0
+ # Second frame tokens (time_idx=1)
+ self.assertEqual(pos[4], [6, 5, 5])
+
+ def test_add_text_positions(self):
+ proc = _make_ernie_processor()
+ outputs = proc._make_outputs()
+ proc.add_text_positions(outputs, 3)
+ self.assertEqual(len(outputs["position_ids"]), 3)
+ self.assertEqual(outputs["position_ids"][0], [0, 0, 0])
+ self.assertEqual(outputs["position_ids"][1], [1, 1, 1])
+ self.assertEqual(outputs["position_ids"][2], [2, 2, 2])
+ self.assertEqual(outputs["cur_position"], 3)
+
+
+class TestErniePreprocessImage(unittest.TestCase):
+ def setUp(self):
+ self.proc = _make_ernie_processor()
+ self.proc.image_processor.get_smarted_resize.return_value = ((56, 56), (2, 2))
+ self.proc.image_processor.preprocess.return_value = {
+ "pixel_values": np.ones((1, 3), dtype=np.float32),
+ "image_grid_thw": np.array([[1, 2, 2]]),
+ }
+
+ def test_raw_image(self):
+ outputs = self.proc._make_outputs()
+ mock_img = MagicMock()
+ mock_img.height = 224
+ mock_img.width = 224
+ mock_img.convert.return_value = mock_img
+
+ self.proc.preprocess_image(mock_img, outputs, uuid="img_uuid")
+
+ # patches_h=2, patches_w=2, num_tokens = 4 // 4 = 1
+ self.assertEqual(len(outputs["images"]), 1)
+ self.assertEqual(outputs["input_ids"], [204])
+ self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["image"]])
+ self.assertEqual(outputs["num_input_image_tokens"], 1)
+
+ def test_cached_image(self):
+ outputs = self.proc._make_outputs()
+ cached_pixels = np.ones((4, 3), dtype=np.float32) # 4 // 4 = 1 token
+ meta = {"thw": (1, 2, 2)}
+
+ self.proc.preprocess_cached_image((cached_pixels, meta), outputs, uuid="u1")
+
+ self.assertEqual(len(outputs["images"]), 1)
+ self.assertEqual(outputs["input_ids"], [204])
+
+ def test_cached_image_token_mismatch(self):
+ outputs = self.proc._make_outputs()
+ cached_pixels = np.ones((4, 3), dtype=np.float32)
+ meta = {"thw": (1, 2, 2)}
+
+ with self.assertRaises(ValueError):
+ self.proc.preprocess_cached_image((cached_pixels, meta), outputs, uuid="u", token_len=999)
+
+ def test_token_len_mismatch(self):
+ outputs = self.proc._make_outputs()
+ mock_img = MagicMock()
+ mock_img.height = 224
+ mock_img.width = 224
+
+ with self.assertRaises(ValueError):
+ self.proc.preprocess_image(mock_img, outputs, uuid="u", token_len=999)
+
+
+class TestErniePreprocessVideo(unittest.TestCase):
+ def setUp(self):
+ self.proc = _make_ernie_processor()
+ self.proc.image_processor.get_smarted_resize.return_value = ((56, 56), (2, 2))
+ self.proc.image_processor.preprocess.return_value = {
+ "pixel_values_videos": np.ones((4, 3), dtype=np.float32),
+ "video_grid_thw": np.array([[4, 2, 2]]),
+ }
+
+ def test_raw_video(self):
+ outputs = self.proc._make_outputs()
+ # 4 frames, each is PIL-like
+ frames = [MagicMock() for _ in range(4)]
+ for f in frames:
+ f.height = 224
+ f.width = 224
+ f.convert.return_value = f
+
+ self.proc.preprocess_video(frames, outputs, uuid="vid_uuid")
+
+ # patches_h=2, patches_w=2, num_frames=4
+ # num_tokens = (4*2*2) / (2*2*2) = 16/8 = 2
+ self.assertEqual(len(outputs["images"]), 1)
+ self.assertEqual(len(outputs["input_ids"]), 2)
+ self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["video"]] * 2)
+
+ def test_cached_video(self):
+ outputs = self.proc._make_outputs()
+ # 8 pixels, spatial^2 * temporal = 4*2 = 8, num_tokens = 8//8 = 1
+ cached_pixels = np.ones((8, 3), dtype=np.float32)
+ meta = {"thw": (2, 2, 2)}
+
+ self.proc.preprocess_cached_video((cached_pixels, meta), outputs, uuid="v1")
+
+ self.assertEqual(len(outputs["images"]), 1)
+ self.assertEqual(len(outputs["input_ids"]), 1)
+
+ def test_cached_video_token_mismatch(self):
+ outputs = self.proc._make_outputs()
+ cached_pixels = np.ones((8, 3), dtype=np.float32)
+ meta = {"thw": (2, 2, 2)}
+
+ with self.assertRaises(ValueError):
+ self.proc.preprocess_cached_video((cached_pixels, meta), outputs, uuid="v", token_len=999)
+
+
+class TestErnieMmNumTokens(unittest.TestCase):
+ def test_image(self):
+ # t=1: t*h*w//4
+ result = Ernie4_5VLProcessor.mm_num_tokens([1, 4, 4])
+ self.assertEqual(result, 1 * 4 * 4 // 4)
+
+ def test_video(self):
+ # t>1: t*h*w//4//2
+ result = Ernie4_5VLProcessor.mm_num_tokens([4, 4, 4])
+ self.assertEqual(result, 4 * 4 * 4 // 4 // 2)
+
+ def test_list(self):
+ result = Ernie4_5VLProcessor.mm_num_tokens([[1, 4, 4], [4, 4, 4]])
+ self.assertEqual(result, [4, 8])
+
+ def test_empty(self):
+ result = Ernie4_5VLProcessor.mm_num_tokens([])
+ self.assertEqual(result, 0)
+
+
+class TestErniePromptTokenIds2Outputs(unittest.TestCase):
+ def setUp(self):
+ self.proc = _make_ernie_processor()
+
+ def test_text_only(self):
+ ctx = MMContext(
+ images=[],
+ videos=[],
+ mm_order=[],
+ path=TokenizationPath.PRETOKENIZED,
+ prompt_token_ids=[1, 2, 3],
+ )
+ outputs = self.proc.prompt_token_ids2outputs(ctx)
+ self.assertEqual(outputs["input_ids"], [1, 2, 3])
+ self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * 3)
+ self.assertEqual(len(outputs["position_ids"]), 3)
+
+ def test_with_processed_image(self):
+ """Scans for IMG_START (200) ... IMG_END (201) boundary tokens."""
+ cached_pixels = np.ones((4, 3), dtype=np.float32) # 4//4=1 token
+ meta = {"thw": (1, 2, 2)}
+ img_item = MMItem(type="image", data=(cached_pixels, meta), uuid="u1")
+
+ ctx = MMContext(
+ images=[img_item],
+ videos=[],
+ mm_order=["image"],
+ path=TokenizationPath.PRETOKENIZED,
+ # text(1) IMG_START(200) placeholder(204) IMG_END(201) text(2)
+ prompt_token_ids=[1, 200, 204, 201, 2],
+ )
+ outputs = self.proc.prompt_token_ids2outputs(ctx)
+
+ self.assertIn(200, outputs["input_ids"]) # IMG_START
+ self.assertIn(201, outputs["input_ids"]) # IMG_END
+ self.assertEqual(len(outputs["images"]), 1)
+
+ def test_image_placeholder_overflow(self):
+ """More IMG_START tokens than images raises ValueError."""
+ img_item = MMItem(type="image", data=(np.ones((4, 3)), {"thw": (1, 2, 2)}), uuid="u1")
+ ctx = MMContext(
+ images=[img_item],
+ videos=[],
+ mm_order=["image"],
+ path=TokenizationPath.PRETOKENIZED,
+ # Two IMG_START tokens but only 1 image
+ prompt_token_ids=[200, 204, 201, 200, 204, 201],
+ )
+ with self.assertRaises(ValueError):
+ self.proc.prompt_token_ids2outputs(ctx)
+
+ def test_image_tokens_incomplete(self):
+ """Missing IMG_END token raises ValueError."""
+ img_item = MMItem(type="image", data=(np.ones((4, 3)), {"thw": (1, 2, 2)}), uuid="u1")
+ ctx = MMContext(
+ images=[img_item],
+ videos=[],
+ mm_order=["image"],
+ path=TokenizationPath.PRETOKENIZED,
+ prompt_token_ids=[200, 204, 204], # no IMG_END (201)
+ )
+ with self.assertRaises(ValueError):
+ self.proc.prompt_token_ids2outputs(ctx)
+
+ def test_video_placeholder_overflow(self):
+ vid_item = MMItem(type="video", data=(np.ones((8, 3)), {"thw": (2, 2, 2)}), uuid="v1")
+ ctx = MMContext(
+ images=[],
+ videos=[vid_item],
+ mm_order=["video"],
+ path=TokenizationPath.PRETOKENIZED,
+ prompt_token_ids=[202, 204, 203, 202, 204, 203], # 2 VID_START but only 1 video
+ )
+ with self.assertRaises(ValueError):
+ self.proc.prompt_token_ids2outputs(ctx)
+
+ def test_video_tokens_incomplete(self):
+ vid_item = MMItem(type="video", data=(np.ones((8, 3)), {"thw": (2, 2, 2)}), uuid="v1")
+ ctx = MMContext(
+ images=[],
+ videos=[vid_item],
+ mm_order=["video"],
+ path=TokenizationPath.PRETOKENIZED,
+ prompt_token_ids=[202, 204, 204], # no VID_END (203)
+ )
+ with self.assertRaises(ValueError):
+ self.proc.prompt_token_ids2outputs(ctx)
+
+ def test_image_count_mismatch(self):
+ """Fewer placeholders than images raises ValueError."""
+ img_item = MMItem(type="image", data=(np.ones((4, 3)), {"thw": (1, 2, 2)}), uuid="u1")
+ img_item2 = MMItem(type="image", data=(np.ones((4, 3)), {"thw": (1, 2, 2)}), uuid="u2")
+ ctx = MMContext(
+ images=[img_item, img_item2],
+ videos=[],
+ mm_order=["image", "image"],
+ path=TokenizationPath.PRETOKENIZED,
+ prompt_token_ids=[200, 204, 201], # only 1 placeholder
+ )
+ with self.assertRaises(ValueError):
+ self.proc.prompt_token_ids2outputs(ctx)
+
+ def test_video_count_mismatch(self):
+ vid_item = MMItem(type="video", data=(np.ones((8, 3)), {"thw": (2, 2, 2)}), uuid="v1")
+ vid_item2 = MMItem(type="video", data=(np.ones((8, 3)), {"thw": (2, 2, 2)}), uuid="v2")
+ ctx = MMContext(
+ images=[],
+ videos=[vid_item, vid_item2],
+ mm_order=["video", "video"],
+ path=TokenizationPath.PRETOKENIZED,
+ prompt_token_ids=[202, 204, 203], # only 1 placeholder
+ )
+ with self.assertRaises(ValueError):
+ self.proc.prompt_token_ids2outputs(ctx)
+
+
+class TestErnieSetVideoFrameArgs(unittest.TestCase):
+ def setUp(self):
+ self.proc = _make_ernie_processor()
+ self.meta = {"duration": 10.0, "num_of_frame": 300}
+
+ def test_target_frames(self):
+ args = {"target_frames": 16, "fps": -1, "min_frames": 4, "max_frames": 768, "frames_sample": "leading"}
+ result = self.proc._set_video_frame_args(args, self.meta)
+ self.assertEqual(result["target_frames"], 16)
+
+ def test_fps_positive_with_target_raises(self):
+ args = {"target_frames": 16, "fps": 2.0, "min_frames": 4, "max_frames": 768, "frames_sample": "leading"}
+ with self.assertRaises(ValueError):
+ self.proc._set_video_frame_args(args, self.meta)
+
+ def test_below_min_raises(self):
+ args = {"target_frames": 2, "fps": -1, "min_frames": 4, "max_frames": 768, "frames_sample": "leading"}
+ with self.assertRaises(ValueError):
+ self.proc._set_video_frame_args(args, self.meta)
+
+ def test_above_max_raises(self):
+ args = {"target_frames": 1000, "fps": -1, "min_frames": 4, "max_frames": 768, "frames_sample": "leading"}
+ with self.assertRaises(ValueError):
+ self.proc._set_video_frame_args(args, self.meta)
+
+ def test_fps_negative_no_target_raises(self):
+ args = {"target_frames": -1, "fps": -1, "min_frames": 4, "max_frames": 768, "frames_sample": "leading"}
+ with self.assertRaises(ValueError):
+ self.proc._set_video_frame_args(args, self.meta)
+
+ def test_min_greater_than_max_raises(self):
+ args = {"target_frames": -1, "fps": 2.0, "min_frames": 100, "max_frames": 10, "frames_sample": "leading"}
+ with self.assertRaises(ValueError):
+ self.proc._set_video_frame_args(args, self.meta)
+
+ def test_fps_clamp_to_min(self):
+ """fps*duration < min_frames -> target_frames set to min_frames."""
+ args = {"target_frames": -1, "fps": 0.1, "min_frames": 4, "max_frames": 768, "frames_sample": "leading"}
+ result = self.proc._set_video_frame_args(args, self.meta)
+ self.assertEqual(result["target_frames"], 4)
+ self.assertEqual(result["fps"], -1)
+
+ def test_fps_clamp_to_max(self):
+ """fps*duration > max_frames -> target_frames set to max_frames."""
+ args = {"target_frames": -1, "fps": 100.0, "min_frames": 4, "max_frames": 768, "frames_sample": "leading"}
+ result = self.proc._set_video_frame_args(args, self.meta)
+ self.assertEqual(result["target_frames"], 768)
+ self.assertEqual(result["fps"], -1)
+
+
+class TestErnieGetMmMaxTokens(unittest.TestCase):
+ def test_returns_image_and_video(self):
+ proc = _make_ernie_processor()
+ proc.image_processor.get_smarted_resize.return_value = ((56, 56), (14, 14))
+ result = proc.get_mm_max_tokens_per_item(seq_len=99999)
+ self.assertIn("image", result)
+ self.assertIn("video", result)
+
+ def test_capped_by_seq_len(self):
+ proc = _make_ernie_processor()
+ proc.image_processor.get_smarted_resize.return_value = ((56, 56), (14, 14))
+ result = proc.get_mm_max_tokens_per_item(seq_len=10)
+ self.assertLessEqual(result["image"], 10)
+ self.assertLessEqual(result["video"], 10)
+
+
+class TestErnieAppendCompletionTokens(unittest.TestCase):
+ def test_appends_tokens_and_positions(self):
+ proc = _make_ernie_processor()
+ outputs = proc._make_outputs()
+
+ proc.append_completion_tokens(outputs, [50, 60, 70])
+
+ self.assertEqual(outputs["input_ids"], [50, 60, 70])
+ self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * 3)
+ self.assertEqual(len(outputs["position_ids"]), 3)
+ self.assertEqual(outputs["position_ids"][0], [0, 0, 0])
+ self.assertEqual(outputs["position_ids"][2], [2, 2, 2])
+
+
+class TestErnieTokenTypeMapping(unittest.TestCase):
+ def test_boundary_tokens_mapped(self):
+ proc = _make_ernie_processor()
+ self.assertEqual(proc.token_type_mapping["<|IMAGE_START|>"], IDS_TYPE_FLAG["image"])
+ self.assertEqual(proc.token_type_mapping["<|IMAGE_END|>"], IDS_TYPE_FLAG["image"])
+ self.assertEqual(proc.token_type_mapping["<|VIDEO_START|>"], IDS_TYPE_FLAG["image"])
+ self.assertEqual(proc.token_type_mapping["<|VIDEO_END|>"], IDS_TYPE_FLAG["image"])
+ self.assertEqual(proc.token_type_mapping[204], IDS_TYPE_FLAG["image"])
+
+
+class TestErniePackPositionIds(unittest.TestCase):
+ def test_pack_position_ids(self):
+ proc = _make_ernie_processor()
+ outputs = proc._make_outputs()
+ outputs["position_ids"] = [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
+
+ proc.pack_position_ids(outputs)
+
+ self.assertEqual(outputs["position_ids"].dtype, np.int64)
+ self.assertEqual(outputs["position_ids"].shape, (3, 3))
+ self.assertEqual(outputs["image_patch_id"], 204)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/input/multimodal/test_image_processors.py b/tests/input/multimodal/test_image_processors.py
new file mode 100644
index 00000000000..5cf0dade2e0
--- /dev/null
+++ b/tests/input/multimodal/test_image_processors.py
@@ -0,0 +1,490 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for multimodal image processors (paddleocr, qwen, qwen3, ernie)."""
+
+import unittest
+from unittest.mock import mock_open, patch
+
+import numpy as np
+from PIL import Image
+
+from fastdeploy.input.multimodal.image_processors.paddleocr import (
+ PaddleOCRImageProcessor,
+ adjust_size,
+)
+from fastdeploy.input.multimodal.image_processors.paddleocr import (
+ make_batched_images as paddleocr_make_batched,
+)
+
+# ==================================================================
+# PaddleOCR ImageProcessor
+# ==================================================================
+
+
+class TestPaddleOCRAdjustSize(unittest.TestCase):
+ def test_even_patches(self):
+ # 224 // 14 = 16 (even) -> stays 224
+ self.assertEqual(adjust_size(224, 14), 224)
+
+ def test_odd_patches(self):
+ # 210 // 14 = 15 (odd) -> (15-1)*14 = 196
+ self.assertEqual(adjust_size(210, 14), 196)
+
+
+class TestPaddleOCRMakeBatchedImages(unittest.TestCase):
+ def test_single_image(self):
+ img = Image.new("RGB", (100, 100))
+ result = paddleocr_make_batched(img)
+ self.assertEqual(len(result), 1)
+
+ def test_list_of_images(self):
+ imgs = [Image.new("RGB", (100, 100)) for _ in range(3)]
+ result = paddleocr_make_batched(imgs)
+ self.assertEqual(len(result), 3)
+
+ def test_nested_list(self):
+ imgs = [[Image.new("RGB", (100, 100)) for _ in range(2)] for _ in range(2)]
+ result = paddleocr_make_batched(imgs)
+ self.assertEqual(len(result), 4)
+
+ def test_invalid_raises(self):
+ with self.assertRaises(ValueError):
+ paddleocr_make_batched("not_an_image")
+
+
+class TestPaddleOCRImageProcessorInit(unittest.TestCase):
+ def test_default_init(self):
+ proc = PaddleOCRImageProcessor()
+ self.assertEqual(proc.patch_size, 14)
+ self.assertEqual(proc.temporal_patch_size, 1)
+ self.assertEqual(proc.merge_size, 2)
+ self.assertTrue(proc.do_resize)
+ self.assertTrue(proc.do_rescale)
+ self.assertTrue(proc.do_normalize)
+
+ def test_custom_init(self):
+ proc = PaddleOCRImageProcessor(patch_size=16, merge_size=4, temporal_patch_size=2)
+ self.assertEqual(proc.patch_size, 16)
+ self.assertEqual(proc.merge_size, 4)
+ self.assertEqual(proc.temporal_patch_size, 2)
+
+
+class TestPaddleOCRImageProcessorFromPretrained(unittest.TestCase):
+ @patch("builtins.open", unittest.mock.mock_open(read_data='{"do_resize": false, "patch_size": 16}'))
+ def test_from_pretrained(self):
+ proc = PaddleOCRImageProcessor.from_pretrained("/fake/path")
+ self.assertFalse(proc.do_resize)
+ self.assertEqual(proc.patch_size, 16)
+
+
+class TestPaddleOCRImageProcessorPreprocess(unittest.TestCase):
+ def setUp(self):
+ self.proc = PaddleOCRImageProcessor(
+ min_pixels=28 * 28 * 4,
+ max_pixels=28 * 28 * 100,
+ )
+
+ def test_single_image(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img])
+ self.assertIn("pixel_values", result)
+ self.assertIn("grid_thw", result)
+ # pixel_values: [N, C, patch_size, patch_size]
+ self.assertEqual(result["pixel_values"].ndim, 4)
+ self.assertEqual(result["pixel_values"].shape[1], 3)
+ self.assertEqual(result["pixel_values"].shape[2], 14)
+ self.assertEqual(result["pixel_values"].shape[3], 14)
+
+ def test_batch_images(self):
+ imgs = [Image.new("RGB", (224, 224)), Image.new("RGB", (224, 224))]
+ result = self.proc.preprocess(images=imgs)
+ single = self.proc.preprocess(images=[imgs[0]])
+ # batch should have double the patches
+ self.assertEqual(result["pixel_values"].shape[0], single["pixel_values"].shape[0] * 2)
+
+ def test_grid_thw_shape(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img])
+ grid_thw = result["grid_thw"]
+ self.assertEqual(len(grid_thw), 3)
+ # t should be 1 for single image with temporal_patch_size=1
+ self.assertEqual(grid_thw[0], 1)
+
+ def test_no_resize(self):
+ img = Image.new("RGB", (56, 56))
+ result = self.proc.preprocess(images=[img], do_resize=False)
+ self.assertIn("pixel_values", result)
+
+ def test_no_rescale(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img], do_rescale=False)
+ self.assertIn("pixel_values", result)
+
+ def test_no_normalize(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img], do_normalize=False)
+ self.assertIn("pixel_values", result)
+
+ def test_custom_mean_std(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img], image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
+ self.assertIn("pixel_values", result)
+
+ def test_do_convert_rgb_false(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img], do_convert_rgb=False)
+ self.assertIn("pixel_values", result)
+
+ def test_videos_not_implemented(self):
+ img = Image.new("RGB", (224, 224))
+ with self.assertRaises(NotImplementedError):
+ self.proc.preprocess(images=[img], videos=["video"])
+
+
+# ==================================================================
+# Qwen ImageProcessor
+# ==================================================================
+
+
+class TestQwenImageProcessorInit(unittest.TestCase):
+ def test_default_init(self):
+ from fastdeploy.input.multimodal.image_processors.qwen import QwenImageProcessor
+
+ proc = QwenImageProcessor()
+ self.assertEqual(proc.patch_size, 14)
+ self.assertEqual(proc.merge_size, 2)
+ self.assertEqual(proc.temporal_patch_size, 2)
+ self.assertTrue(proc.do_rescale)
+ self.assertTrue(proc.do_normalize)
+
+
+class TestQwenImageProcessorPreprocess(unittest.TestCase):
+ def setUp(self):
+ from fastdeploy.input.multimodal.image_processors.qwen import QwenImageProcessor
+
+ self.proc = QwenImageProcessor(min_pixels=4 * 28 * 28, max_pixels=100 * 28 * 28)
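+ # 28 = patch_size (14) x merge_size (2); these bounds bracket the
+ # 224x224 test images used below.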
+
+ def test_single_image(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img])
+ self.assertIn("pixel_values", result)
+ self.assertIn("grid_thw", result)
+ # pixel_values: [grid_t * grid_h * grid_w, C * temporal_patch_size * patch_size * patch_size]
+ self.assertEqual(result["pixel_values"].ndim, 2)
+
+ def test_grid_thw_values(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img])
+ grid_thw = result["grid_thw"]
+ self.assertEqual(len(grid_thw), 3)
+ t, h, w = grid_thw
+ # For a single image with temporal_patch_size=2, t should be 1
+ self.assertEqual(t, 1)
+ self.assertGreater(h, 0)
+ self.assertGreater(w, 0)
+
+ def test_video_frames(self):
+ """Multiple frames → temporal dimension > 1."""
+ frames = [Image.new("RGB", (224, 224)) for _ in range(4)]
+ result = self.proc.preprocess(images=frames)
+ grid_thw = result["grid_thw"]
+ t = grid_thw[0]
+ # 4 frames // temporal_patch_size (2) = 2
+ self.assertEqual(t, 2)
+
+ def test_odd_frames_padded(self):
+ """Odd number of frames gets padded to next multiple of temporal_patch_size."""
+ frames = [Image.new("RGB", (224, 224)) for _ in range(3)]
+ result = self.proc.preprocess(images=frames)
+ grid_thw = result["grid_thw"]
+ t = grid_thw[0]
+ # 3 frames -> padded to 4 -> t=4//2=2
+ self.assertEqual(t, 2)
+
+ def test_no_rescale_no_normalize(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img], do_rescale=False, do_normalize=False)
+ self.assertIn("pixel_values", result)
+
+ def test_rescale_only(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img], do_rescale=True, do_normalize=False)
+ self.assertIn("pixel_values", result)
+
+ def test_invalid_images_raises(self):
+ with self.assertRaises(ValueError):
+ self.proc.preprocess(images="invalid")
+
+ def test_custom_pixels(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=[img], min_pixels=28 * 28, max_pixels=50 * 28 * 28)
+ self.assertIn("pixel_values", result)
+
+
+# ==================================================================
+# Qwen3 ImageProcessor
+# ==================================================================
+
+
+class TestQwen3ImageProcessorInit(unittest.TestCase):
+ def test_default_init(self):
+ from fastdeploy.input.multimodal.image_processors.qwen3 import (
+ Qwen3ImageProcessor,
+ )
+
+ proc = Qwen3ImageProcessor()
+ self.assertEqual(proc.patch_size, 16)
+ self.assertEqual(proc.merge_size, 2)
+ self.assertEqual(proc.temporal_patch_size, 2)
+ self.assertEqual(proc.image_mean, [0.5, 0.5, 0.5])
+ self.assertEqual(proc.image_std, [0.5, 0.5, 0.5])
+
+ def test_preprocess(self):
+ from fastdeploy.input.multimodal.image_processors.qwen3 import (
+ Qwen3ImageProcessor,
+ )
+
+ proc = Qwen3ImageProcessor(min_pixels=32 * 32, max_pixels=100 * 32 * 32)
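+ # Qwen3 uses patch_size 16 and merge_size 2, so its grid unit is 32 px.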
+ img = Image.new("RGB", (224, 224))
+ result = proc.preprocess(images=[img])
+ self.assertIn("pixel_values", result)
+ self.assertIn("grid_thw", result)
+
+
+# ==================================================================
+# Ernie AdaptiveImageProcessor
+# ==================================================================
+
+
+class TestErnieAdaptiveImageProcessorInit(unittest.TestCase):
+ def test_default_init(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ AdaptiveImageProcessor,
+ )
+
+ proc = AdaptiveImageProcessor()
+ self.assertEqual(proc.patch_size, 14)
+ self.assertEqual(proc.merge_size, 2)
+ self.assertEqual(proc.temporal_conv_size, 2)
+ self.assertTrue(proc.do_resize)
+
+ def test_custom_init(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ AdaptiveImageProcessor,
+ )
+
+ proc = AdaptiveImageProcessor(min_pixels=100, max_pixels=50000, patch_size=16)
+ self.assertEqual(proc.min_pixels, 100)
+ self.assertEqual(proc.max_pixels, 50000)
+ self.assertEqual(proc.patch_size, 16)
+
+
+class TestErnieAdaptiveImageProcessorSetPixels(unittest.TestCase):
+ def setUp(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ AdaptiveImageProcessor,
+ )
+
+ self.proc = AdaptiveImageProcessor()
+
+ def test_set_min_pixels(self):
+ self.proc.set_pixels(min_pixels=1000, msg="test")
+ self.assertEqual(self.proc.min_pixels, 1000)
+ self.assertEqual(self.proc.size["min_pixels"], 1000)
+
+ def test_set_max_pixels(self):
+ self.proc.set_pixels(max_pixels=100000, msg="test")
+ self.assertEqual(self.proc.max_pixels, 100000)
+ self.assertEqual(self.proc.size["max_pixels"], 100000)
+
+ def test_invalid_min_pixels_raises(self):
+ with self.assertRaises(AssertionError):
+ self.proc.set_pixels(min_pixels=-1, msg="test")
+
+ def test_invalid_max_pixels_raises(self):
+ with self.assertRaises(AssertionError):
+ self.proc.set_pixels(max_pixels=0, msg="test")
+
+
+class TestErnieAdaptiveGetSmartedResize(unittest.TestCase):
+ def setUp(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ AdaptiveImageProcessor,
+ )
+
+ self.proc = AdaptiveImageProcessor(min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
+
+ def test_default_pixels(self):
+ (rh, rw), (ph, pw) = self.proc.get_smarted_resize(224, 224)
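+ # Resized dims snap to multiples of patch_size * merge_size = 28,
+ # and the returned patch grid is the resized dims over patch_size (14).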
+ self.assertEqual(rh % 28, 0)
+ self.assertEqual(rw % 28, 0)
+ self.assertEqual(ph, rh // 14)
+ self.assertEqual(pw, rw // 14)
+
+ def test_custom_pixels(self):
+ (rh, rw), (ph, pw) = self.proc.get_smarted_resize(224, 224, min_pixels=100, max_pixels=10000)
+ self.assertEqual(rh % 28, 0)
+ self.assertEqual(rw % 28, 0)
+
+
+class TestErnieMakeBatchedImages(unittest.TestCase):
+ def test_single_image(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_images,
+ )
+
+ img = Image.new("RGB", (100, 100))
+ result = make_batched_images(img)
+ self.assertEqual(len(result), 1)
+
+ def test_list_of_images(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_images,
+ )
+
+ imgs = [Image.new("RGB", (100, 100)) for _ in range(3)]
+ result = make_batched_images(imgs)
+ self.assertEqual(len(result), 3)
+
+ def test_nested_list(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_images,
+ )
+
+ imgs = [[Image.new("RGB", (100, 100)) for _ in range(2)] for _ in range(2)]
+ result = make_batched_images(imgs)
+ self.assertEqual(len(result), 4)
+
+ def test_invalid_raises(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_images,
+ )
+
+ with self.assertRaises(ValueError):
+ make_batched_images("invalid")
+
+
+class TestErnieMakeBatchedVideos(unittest.TestCase):
+ def test_list_of_pil_images(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_videos,
+ )
+
+ imgs = [Image.new("RGB", (100, 100)) for _ in range(4)]
+ result = make_batched_videos(imgs)
+ self.assertEqual(len(result), 1)
+ self.assertEqual(len(result[0]), 4)
+
+ def test_nested_list(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_videos,
+ )
+
+ videos = [[Image.new("RGB", (100, 100)) for _ in range(2)] for _ in range(3)]
+ result = make_batched_videos(videos)
+ self.assertEqual(len(result), 3)
+
+ def test_4d_ndarray(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_videos,
+ )
+
+ video = np.random.randint(0, 255, (4, 100, 100, 3), dtype=np.uint8)
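+ # A single 4-D array of shape (frames, H, W, C) counts as one video.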
+ result = make_batched_videos(video)
+ self.assertEqual(len(result), 1)
+
+ def test_list_of_4d_ndarrays(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_videos,
+ )
+
+ videos = [np.random.randint(0, 255, (4, 100, 100, 3), dtype=np.uint8)]
+ result = make_batched_videos(videos)
+ self.assertEqual(len(result), 1)
+
+ def test_invalid_raises(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ make_batched_videos,
+ )
+
+ with self.assertRaises(ValueError):
+ make_batched_videos("invalid")
+
+
+class TestErnieAdaptivePreprocess(unittest.TestCase):
+ def setUp(self):
+ from fastdeploy.input.multimodal.image_processors.ernie import (
+ AdaptiveImageProcessor,
+ )
+
+ self.proc = AdaptiveImageProcessor(min_pixels=56 * 56, max_pixels=28 * 28 * 100)
+
+ def test_single_image(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=img)
+ self.assertIn("pixel_values", result)
+ self.assertIn("image_grid_thw", result)
+
+ def test_multiple_images(self):
+ imgs = [Image.new("RGB", (224, 224)) for _ in range(3)]
+ result = self.proc.preprocess(images=imgs)
+ self.assertIn("pixel_values", result)
+ self.assertIn("image_grid_thw", result)
+
+ def test_video_input(self):
+ frames = [Image.new("RGB", (224, 224)) for _ in range(4)]
+ result = self.proc.preprocess(images=None, videos=frames)
+ self.assertIn("pixel_values_videos", result)
+ self.assertIn("video_grid_thw", result)
+
+ def test_invalid_images_raises(self):
+ with self.assertRaises(ValueError):
+ self.proc.preprocess(images="invalid_string")
+
+ def test_do_convert_rgb(self):
+ img = Image.new("L", (224, 224))
+ result = self.proc.preprocess(images=img, do_convert_rgb=True)
+ self.assertIn("pixel_values", result)
+
+ def test_predetermined_grid_thw(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=img, predetermined_grid_thw=[(16, 16)])
+ self.assertIn("pixel_values", result)
+
+ def test_no_resize(self):
+ img = Image.new("RGB", (56, 56))
+ result = self.proc.preprocess(images=img, do_resize=False)
+ self.assertIn("pixel_values", result)
+
+ def test_no_rescale(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=img, do_rescale=False)
+ self.assertIn("pixel_values", result)
+
+ def test_no_normalize(self):
+ img = Image.new("RGB", (224, 224))
+ result = self.proc.preprocess(images=img, do_normalize=False)
+ self.assertIn("pixel_values", result)
+
+ def test_both_images_and_videos(self):
+ imgs = [Image.new("RGB", (224, 224))]
+ videos = [[Image.new("RGB", (224, 224)) for _ in range(4)]]
+ result = self.proc.preprocess(images=imgs, videos=videos)
+ self.assertIn("pixel_values", result)
+ self.assertIn("pixel_values_videos", result)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/input/multimodal/test_mm_processor.py b/tests/input/multimodal/test_mm_processor.py
new file mode 100644
index 00000000000..623320bfc7c
--- /dev/null
+++ b/tests/input/multimodal/test_mm_processor.py
@@ -0,0 +1,802 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for MMProcessor base class, data classes, and _CacheClient."""
+
+import pickle
+import unittest
+from typing import Any, Tuple
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+
+from fastdeploy.input.multimodal.mm_processor import (
+ _DEFAULT_MM_LIMITS,
+ MMContext,
+ MMItem,
+ MMProcessor,
+ TokenizationPath,
+ _CacheClient,
+)
+from fastdeploy.input.utils import IDS_TYPE_FLAG
+
+# ------------------------------------------------------------------
+# Concrete subclass for testing abstract base class methods
+# ------------------------------------------------------------------
+
+
+class _ConcreteProcessor(MMProcessor):
+ """Minimal concrete implementation for testing base class logic."""
+
+ image_placeholder = ""
+ video_placeholder = "