diff --git a/configs/qwen2.5-vl-32b-eagle3.json b/configs/qwen2.5-vl-32b-eagle3.json new file mode 100644 index 000000000..76aa04cdf --- /dev/null +++ b/configs/qwen2.5-vl-32b-eagle3.json @@ -0,0 +1,40 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 8192, + "max_window_layers": 28, + "model_type": "llama", + "target_model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16, + 24, + 24 + ] + }, + "rope_theta": 1000000, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 + } diff --git a/examples/run_qwen2.5_32b_vl_eagle3_online.sh b/examples/run_qwen2.5_32b_vl_eagle3_online.sh new file mode 100755 index 000000000..a7c86b0e5 --- /dev/null +++ b/examples/run_qwen2.5_32b_vl_eagle3_online.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# support tp1 train eagle3 for qwen2.5-vl-7b-instruct +NUM_GPUS=${1:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen2.5-VL-32B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen2.5-vl-32b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/allava4v_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen2.5-vl-32b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --dist-timeout 360 \ + --chat-template qwen2-vl \ + --target-model-backend sglang \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 4 \ + --sglang-mem-fraction-static 0.5 \ + --is-vlm \ + --min-pixels 200704 \ + --max-pixels 1003520 diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 3d81e42e3..b5ab6a1d3 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -268,7 +268,7 @@ def build_target_model( if ( args.is_vlm and draft_model_config.target_model_type == "qwen2_5_vl" - and args.tp_size == 1 + and args.target_model_backend == "custom" ): from transformers import Qwen2_5_VLForConditionalGeneration @@ -456,7 +456,6 @@ def build_dataloaders( ), is_vlm=args.is_vlm, ) - if args.eval_data_path is not None or args.eval_hidden_states_path is not None: if args.eval_data_path is not None: eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"] @@ -547,7 +546,7 @@ def run_forward( target_model: Optional[Eagle3TargetModel] = None, is_online: bool = True, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - if args.is_vlm: + if args.is_vlm and args.target_model_backend == "custom": plosses, _, acces = eagle3_model( input_ids=data["input_ids"].cuda(), attention_mask=data["attention_mask"].cuda(), @@ -556,13 +555,32 @@ def run_forward( image_grid_thw=data["image_grid_thw"].cuda(), ) else: + image_grid_thw = None if is_online: # we generate the eagle3 using the target model in an 
online fashion - eagle3_data = target_model.generate_eagle3_data( - input_ids=data["input_ids"].cuda(), - attention_mask=data["attention_mask"].cuda(), - loss_mask=data["loss_mask"].cuda(), - ) + # Handle VLM data: pixel_values and image_grid_thw are lists + # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None + if args.is_vlm: + image_grid_thw = ( + [thw.cuda().squeeze() for thw in data["image_grid_thw"]] + if args.is_vlm + else None + ) + pixel_values = data["pixel_values"].cuda() + eagle3_data = target_model.generate_eagle3_data( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + is_vlm=args.is_vlm, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + else: + eagle3_data = target_model.generate_eagle3_data( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + ) input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids) attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask) @@ -579,13 +597,14 @@ def run_forward( input_ids, target, loss_mask = target_model.preprocess( input_ids, target, loss_mask ) - plosses, _, acces = eagle3_model( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, target=target, hidden_states=hidden_states, + image_grid_thw=image_grid_thw, + is_vlm=args.is_vlm, ) return plosses, acces @@ -747,6 +766,8 @@ def main(): if ( args.is_vlm and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl" + and args.tp_size == 1 + and args.target_model_backend != "sglang" ): eagle3_model = QwenVLOnlineEagle3Model( target_model=target_model, @@ -756,12 +777,20 @@ def main(): attention_backend=args.attention_backend, ) else: - eagle3_model = OnlineEagle3Model( - draft_model=draft_model, - length=args.ttt_length, - attention_backend=args.attention_backend, - ) - + if is_online: + eagle3_model = OnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + else: + # offline: the target_model is TargetHead not a model + eagle3_model = OnlineEagle3Model( + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) eagle3_model = FSDP( eagle3_model, use_orig_params=True, @@ -910,7 +939,6 @@ def main(): tracker, mode="eval", ) - # ================================================ # 7.3 Save Checkpoints # ================================================ @@ -923,7 +951,6 @@ def main(): if args.max_num_steps is not None and global_step >= args.max_num_steps: break - # Save final checkpoint if training ended without saving if global_step % args.save_interval != 0: print_on_rank0( diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index 34c6d6b5e..23f1bb9eb 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -59,6 +59,7 @@ def __init__( draft_model: Eagle3DraftModel, length: int = 7, attention_backend="sdpa", + target_model: Optional[Eagle3Model] = None, ): """ Args: @@ -70,6 +71,7 @@ def __init__( self.draft_model = draft_model self.length = length self.attention_backend = attention_backend + self.target_model = target_model if self.attention_backend == "usp": self.extract_func = EXTRACT_FUNC_DICT["basic"] @@ -98,6 +100,8 @@ def forward( hidden_states: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.Tensor] = None, + 
image_grid_thw: Optional[torch.Tensor] = None, + is_vlm: bool = False, **kwargs, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: """ @@ -132,14 +136,22 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - device = hidden_states.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + if is_vlm: + mrope_positions_ids, mrope_position_delta = ( + self.target_model.get_rope_index( + input_ids=input_ids, image_grid_thw=image_grid_thw + ) + ) + position_ids = mrope_positions_ids + else: + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py index b0cb0e854..235544148 100644 --- a/specforge/modeling/target/eagle3_target_model.py +++ b/specforge/modeling/target/eagle3_target_model.py @@ -1,16 +1,29 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple +import sglang.srt.managers.mm_utils as mm_utils import torch import torch.distributed as dist import torch.nn as nn from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.layers.rotary_embedding import MRotaryEmbedding +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + init_mm_embedding_cache, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, + Req, + ScheduleBatch, +) from sglang.srt.managers.scheduler import Scheduler from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch +from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -241,9 +254,42 @@ def hook(module, input, output): class SGLangEagle3TargetModel(Eagle3TargetModel): - def __init__(self, model_runner: SGLangRunner): + def __init__(self, model_runner: SGLangRunner, hf_config=None): super().__init__() self.model_runner = model_runner + self.hf_config = hf_config + + # VLM-specific attributes (initialized from hf_config if available) + self._init_vlm_attributes() + + def _init_vlm_attributes(self): + """Initialize VLM-specific attributes from hf_config for models like Qwen2.5-VL""" + if self.hf_config is None: + self.is_vlm = False + return + + # Check if this is a VLM model by looking for vision_config + self.is_vlm = hasattr(self.hf_config, "vision_config") + + if not self.is_vlm: + return + + init_mm_embedding_cache(1024 * 1024 * 512) + # Model type (e.g., "qwen2_5_vl", "qwen2_vl") + self.model_type = getattr(self.hf_config, "model_type", None) + + # Vision config attributes + vision_config = self.hf_config.vision_config 
+ self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2) + self.tokens_per_second = getattr(vision_config, "tokens_per_second", None) + + # Special token IDs from hf_config + self.image_token_id = getattr(self.hf_config, "image_token_id", None) + self.video_token_id = getattr(self.hf_config, "video_token_id", None) + self.vision_start_token_id = getattr( + self.hf_config, "vision_start_token_id", None + ) + self.vision_end_token_id = getattr(self.hf_config, "vision_end_token_id", None) @classmethod def from_pretrained( @@ -286,7 +332,11 @@ def from_pretrained( wrap_eagle3_logits_processors_in_module( model_runner.model, return_full_logits=False ) - return cls(model_runner) + + # Get hf_config from model_config for VLM attributes + hf_config = getattr(model_config, "hf_config", None) + + return cls(model_runner, hf_config=hf_config) def set_aux_hidden_states_layers( self, aux_hidden_states_layers: Optional[List[int]] = None @@ -420,12 +470,208 @@ def extend( return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + def get_rope_index( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Get M-RoPE position indices for VLM models like Qwen2.5-VL. + + This is a wrapper around MRotaryEmbedding.get_rope_index that uses + the VLM-specific attributes initialized from hf_config. + + Args: + input_ids: (batch_size, seq_len) input token IDs + image_grid_thw: (num_images, 3) image grid dimensions (t, h, w) + video_grid_thw: (num_videos, 3) video grid dimensions (t, h, w) + second_per_grid_ts: Optional temporal information for videos + attention_mask: (batch_size, seq_len) attention mask + + Returns: + position_ids: (3, batch_size, seq_len) M-RoPE position IDs + rope_deltas: Optional position deltas for incremental decoding + """ + if not self.is_vlm: + raise ValueError("get_rope_index is only available for VLM models") + + from sglang.srt.layers.rotary_embedding import MRotaryEmbedding + + position_ids, rope_deltas = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + model_type=self.model_type, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, + tokens_per_second=self.tokens_per_second, + ) + + return position_ids, rope_deltas + + def extend_vlm( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + return_last_hidden_states: bool = False, + return_logits: bool = True, + pixel_values: Optional[List[torch.Tensor]] = None, + image_grid_thw: Optional[List[torch.Tensor]] = None, + ): + """ + Args: + input_ids: (batch_size, seq_len) or List of (1, seq_len) tensors + attention_mask: (batch_size, seq_len) or List of (1, seq_len) tensors + loss_mask: (batch_size, seq_len) or List of (1, seq_len) tensors + pixel_values: List of pixel_values tensors, one per sample in batch + image_grid_thw: List of image_grid_thw tensors, one per sample in batch + """ + mm_utils.embedding_cache.clear() + sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1) + reqs, data_cache = [], [] + + # Split tensors if needed + if 
isinstance(input_ids, torch.Tensor): + batch_size = input_ids.shape[0] + input_ids = torch.split(input_ids, 1, dim=0) + attention_mask = torch.split(attention_mask, 1, dim=0) + loss_mask = torch.split(loss_mask, 1, dim=0) + else: + batch_size = len(input_ids) + # Process image_grid_thw - convert to list if needed + if image_grid_thw is None: + image_grid_thw = [None] * batch_size + elif not isinstance(image_grid_thw, (list, tuple)): + image_grid_thw = [image_grid_thw] + + # pixel_values is a single 2D tensor (total_patches, patch_dim) for Qwen2.5-VL + # We need to track offset and slice it based on image_grid_thw for each sample + pixel_values_offset = 0 # Track current offset in pixel_values + + for idx, (input_id_, attention_mask_, loss_mask_, image_grid_thw_) in enumerate( + zip( + input_ids, + attention_mask, + loss_mask, + image_grid_thw, + ) + ): + # Compute num_patches for this sample from image_grid_thw_ + # image_grid_thw_: (num_images, 3) where each row is (t, h, w) + if image_grid_thw_ is not None: + # Ensure image_grid_thw_ is 2D: (num_images, 3) + if image_grid_thw_.dim() == 1: + image_grid_thw_ = image_grid_thw_.unsqueeze(0) # (3,) -> (1, 3) + elif image_grid_thw_.dim() == 0: + raise ValueError( + f"image_grid_thw_ is 0-dim tensor, expected at least 1D. Value: {image_grid_thw_}" + ) + + # Calculate num_patches for this sample: sum(t * h * w) for all images + num_patches = ( + ( + image_grid_thw_[:, 0] + * image_grid_thw_[:, 1] + * image_grid_thw_[:, 2] + ) + .sum() + .item() + ) + num_patches = int(num_patches) + + # Slice pixel_values for this sample + pixel_value_ = pixel_values[ + pixel_values_offset : pixel_values_offset + num_patches + ] + pixel_values_offset += num_patches + else: + pixel_value_ = None + num_patches = 0 + + # Compute mrope positions for VLM models (e.g., Qwen2.5-VL) + input_id_flat = input_id_.view(-1) + + # Count image tokens + num_img_tokens = (input_id_flat == self.image_token_id).sum().item() + # print(f"[extend_vlm] num_img_tokens in input_ids: {num_img_tokens}") + + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + model_type=self.model_type, + input_ids=input_id_flat.unsqueeze(0), + image_grid_thw=( + image_grid_thw_.cpu() if image_grid_thw_ is not None else None + ), + tokens_per_second=self.tokens_per_second, + ) + + offset = BaseMultimodalProcessor.get_mm_items_offset( + input_id_flat, self.image_token_id + ) + mm_item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=pixel_value_, # torch.Tensor: (num_patches, patch_dim) + pad_value=self.image_token_id, # Required for placeholder tensor creation + offsets=offset, # List of (start, end) tuples + ) + mm_item.set("image_grid_thw", image_grid_thw_.cpu()) + mm_item.set_pad_value() + mm_inputs = MultimodalInputs( + mm_items=[mm_item], + im_token_id=self.image_token_id, + im_start_id=self.vision_start_token_id, + im_end_id=self.vision_end_token_id, + mrope_positions=( + mrope_positions.squeeze(1) if mrope_positions is not None else None + ), + mrope_position_delta=mrope_position_delta, + ) + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + input_id_list = pattern.pad_input_tokens( + input_id_.view(-1).tolist(), mm_inputs + ) + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=input_id_list, + sampling_params=sampling_params, + ) + req.fill_ids = 
req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + req.multimodal_inputs = mm_inputs + data_cache.append([input_id_, attention_mask_, loss_mask_]) + reqs.append(req) + + logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend( + reqs, + capture_aux_hidden_states=True, + return_last_hidden_states=return_last_hidden_states, + return_logits=return_logits, + ) + + return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + @torch.no_grad() def generate_eagle3_data( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + is_vlm: bool = False, ) -> Eagle3TargetOutput: """ return: @@ -435,16 +681,31 @@ def generate_eagle3_data( - loss_mask: (1, seq_len) - target: (1, seq_len, vocab_size) or (1, seq_len, hidden_size) - hidden_states: (1, seq_len, hidden_size) + - pixel_values: (patch_len, patch_width) + - image_grid_thw (batch_size, 3) """ - data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( - self.extend( - input_ids, - attention_mask, - loss_mask, - return_last_hidden_states=False, - return_logits=True, + if is_vlm: + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( + self.extend_vlm( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + ) + else: + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( + self.extend( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + ) ) - ) aux_hidden_states_out = [] target_out = [] loss_mask_out = [] diff --git a/tests/test_modeling/test_target/test_sglang_backend/images/demo.jpeg b/tests/test_modeling/test_target/test_sglang_backend/images/demo.jpeg new file mode 100644 index 000000000..9fdc04005 Binary files /dev/null and b/tests/test_modeling/test_target/test_sglang_backend/images/demo.jpeg differ diff --git a/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py b/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py index 146c1bc08..ce7396fc8 100644 --- a/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py +++ b/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py @@ -84,6 +84,322 @@ def test_moe(rank, world_size, port, tp_size): ) +def test_vlm(rank, world_size, port, tp_size): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=tp_size) + set_seed(42) + + # model_path = "Qwen/Qwen2.5-VL-32B-Instruct" + model_path = "Qwen/Qwen2.5-VL-32B-Instruct" + image_path = os.path.join(os.path.dirname(__file__), "images", "demo.jpeg") + + # Use Qwen2.5-VL processor to prepare inputs + from qwen_vl_utils import process_vision_info + from transformers import Qwen2_5_VLProcessor + + processor = Qwen2_5_VLProcessor.from_pretrained(model_path) + + # Create test messages with images (batch_size=2) + # Sample 1: single image + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + # Sample 2: single image 
(can use same or different image) + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "What do you see in this picture?"}, + ], + } + ] + + # Process each sample separately to get correct format + batch_input_ids = [] + batch_attention_mask = [] + batch_pixel_values = [] + batch_image_grid_thw = [] + + for messages in [messages_1, messages_2]: + # Apply chat template + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + # Process vision info to get actual image data + image_inputs, video_inputs = process_vision_info(messages) + + # Process with processor + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + batch_input_ids.append(inputs["input_ids"]) + batch_attention_mask.append(inputs["attention_mask"]) + batch_pixel_values.append(inputs["pixel_values"]) + batch_image_grid_thw.append(inputs["image_grid_thw"]) + + # Debug: print shapes + if rank == 0: + print(f"[Debug] batch_input_ids shapes: {[x.shape for x in batch_input_ids]}") + print( + f"[Debug] batch_pixel_values shapes: {[x.shape for x in batch_pixel_values]}" + ) + print( + f"[Debug] batch_image_grid_thw shapes: {[x.shape for x in batch_image_grid_thw]}" + ) + print(f"[Debug] batch_image_grid_thw values: {batch_image_grid_thw}") + # Count image tokens in input_ids + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + for i, ids in enumerate(batch_input_ids): + num_img_tokens = (ids == image_token_id).sum().item() + print(f"[Debug] Sample {i}: {num_img_tokens} image tokens in input_ids") + + # Pad input_ids and attention_mask to same length + max_len = max(ids.shape[1] for ids in batch_input_ids) + padded_input_ids = [] + padded_attention_mask = [] + padded_loss_mask = [] + + for input_ids, attention_mask in zip(batch_input_ids, batch_attention_mask): + pad_len = max_len - input_ids.shape[1] + if pad_len > 0: + input_ids = torch.nn.functional.pad( + input_ids, (0, pad_len), value=processor.tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad_len), value=0 + ) + padded_input_ids.append(input_ids) + padded_attention_mask.append(attention_mask) + padded_loss_mask.append( + attention_mask.clone() + ) # loss_mask same as attention_mask + + # Stack into batches + input_ids = torch.cat(padded_input_ids, dim=0).cuda() + attention_mask = torch.cat(padded_attention_mask, dim=0).cuda() + loss_mask = torch.cat(padded_loss_mask, dim=0).cuda() + + # pixel_values and image_grid_thw remain as lists (one per sample) + pixel_values = torch.cat(batch_pixel_values, dim=0).cuda() + image_grid_thw = [thw.cuda() for thw in batch_image_grid_thw] + + sgl_target_model = SGLangEagle3TargetModel.from_pretrained( + model_path, + torch_dtype=torch.float16, + device="cuda", + attention_backend="fa3", + mem_fraction_static=0.75, + enable_torch_compile=True, + enable_nccl_nvls=False, + enable_symm_mem=False, # Disable to avoid nccl_allocator compilation issues + enable_dp_attention=True, + enable_dp_lm_head=True, + enable_piecewise_cuda_graph=True, + context_length=4096, + ) + sgl_target_model.set_aux_hidden_states_layers() + sgl_out = sgl_target_model.generate_eagle3_data( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + is_vlm=True, + ) + + if rank == 0: + # Verify output shapes + print(f"[Rank {rank}] 
hidden_states shape: {sgl_out.hidden_states.shape}") + print(f"[Rank {rank}] target shape: {sgl_out.target.shape}") + print(f"[Rank {rank}] input_ids shape: {sgl_out.input_ids.shape}") + + +def test_vlm_multi_batch(rank, world_size, port, tp_size): + """Test VLM with larger batch size (4 samples) and varying image counts.""" + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=tp_size) + set_seed(42) + + model_path = "Qwen/Qwen2.5-VL-32B-Instruct" + + from qwen_vl_utils import process_vision_info + from transformers import Qwen2_5_VLProcessor + + processor = Qwen2_5_VLProcessor.from_pretrained(model_path) + + image_path = os.path.join(os.path.dirname(__file__), "images", "demo.jpeg") + + # Create test messages with different configurations (batch_size=4) + # Sample 1: single image + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "Describe this image in detail."}, + ], + } + ] + + # Sample 2: single image with different prompt + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "What objects can you see in this picture?"}, + ], + } + ] + + # Sample 3: single image with longer prompt + messages_3 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + { + "type": "text", + "text": "Please analyze this image and describe the main subject, background, colors, and any notable details you observe.", + }, + ], + } + ] + + # Sample 4: single image with short prompt + messages_4 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "What is this?"}, + ], + } + ] + + all_messages = [messages_1, messages_2, messages_3, messages_4] + batch_size = len(all_messages) + + # Process each sample separately to get correct format + batch_input_ids = [] + batch_attention_mask = [] + batch_pixel_values = [] + batch_image_grid_thw = [] + + for messages in all_messages: + # Apply chat template + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + # Process vision info to get actual image data + image_inputs, video_inputs = process_vision_info(messages) + + # Process with processor + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + batch_input_ids.append(inputs["input_ids"]) + batch_attention_mask.append(inputs["attention_mask"]) + batch_pixel_values.append(inputs["pixel_values"]) + batch_image_grid_thw.append(inputs["image_grid_thw"]) + + # Pad input_ids and attention_mask to same length + max_len = max(ids.shape[1] for ids in batch_input_ids) + padded_input_ids = [] + padded_attention_mask = [] + padded_loss_mask = [] + + for input_ids, attention_mask in zip(batch_input_ids, batch_attention_mask): + pad_len = max_len - input_ids.shape[1] + if pad_len > 0: + input_ids = torch.nn.functional.pad( + input_ids, (0, pad_len), value=processor.tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad_len), value=0 + ) + padded_input_ids.append(input_ids) + padded_attention_mask.append(attention_mask) + padded_loss_mask.append( + attention_mask.clone() + ) # loss_mask same as attention_mask + + # Stack into batches + input_ids = torch.cat(padded_input_ids, 
dim=0).cuda() + attention_mask = torch.cat(padded_attention_mask, dim=0).cuda() + loss_mask = torch.cat(padded_loss_mask, dim=0).cuda() + + # pixel_values and image_grid_thw remain as lists (one per sample) + pixel_values = torch.cat(batch_pixel_values, dim=0).cuda() + image_grid_thw = [thw.cuda() for thw in batch_image_grid_thw] + sgl_target_model = SGLangEagle3TargetModel.from_pretrained( + model_path, + torch_dtype=torch.float16, + device="cuda", + attention_backend="fa3", + mem_fraction_static=0.4, + enable_torch_compile=True, + enable_nccl_nvls=False, + enable_symm_mem=False, + enable_dp_attention=True, + enable_dp_lm_head=True, + enable_piecewise_cuda_graph=True, + context_length=4096, + ) + sgl_target_model.set_aux_hidden_states_layers() + sgl_out = sgl_target_model.generate_eagle3_data( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + is_vlm=True, + ) + + if rank == 0: + # Verify output shapes + print(f"\n{'='*60}") + print(f"[test_vlm_multi_batch] Results:") + print(f"[Rank {rank}] hidden_states shape: {sgl_out.hidden_states.shape}") + print(f"[Rank {rank}] target shape: {sgl_out.target.shape}") + print(f"[Rank {rank}] input_ids shape: {sgl_out.input_ids.shape}") + + # Verify batch dimension matches + assert ( + sgl_out.input_ids.shape[0] == batch_size + ), f"Expected batch_size={batch_size}, got {sgl_out.input_ids.shape[0]}" + print(f"[Rank {rank}] Batch size verification: PASSED") + print(f"{'='*60}\n") + + class TestTargetModelBackend(unittest.TestCase): def test_sglang_backend_with_dense(self): @@ -96,6 +412,16 @@ def test_sglang_backend_with_moe(self): port = get_available_port() mp.spawn(test_moe, nprocs=world_size, args=(world_size, port, 2)) + def test_sglang_backend_with_vlm(self): + world_size = 2 + port = get_available_port() + mp.spawn(test_vlm, nprocs=world_size, args=(world_size, port, 2)) + + def test_sglang_backend_with_vlm_multi_batch(self): + world_size = 2 + port = get_available_port() + mp.spawn(test_vlm_multi_batch, nprocs=world_size, args=(world_size, port, 2)) + if __name__ == "__main__": suite = unittest.TestSuite()
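
Note on the pixel_values layout handled in extend_vlm: as the comment in that hunk says, Qwen2.5-VL processors return one flat (total_patches, patch_dim) tensor for the whole batch, so the patch walks an offset and slices per sample using each sample's image_grid_thw. The following is a minimal, self-contained sketch of that slicing logic only; the grid sizes and patch_dim below are made-up illustrative values, not taken from the patch.

# Sketch of the offset-based slicing performed in extend_vlm.
# Assumptions (illustrative only): two samples, one image each, with grids
# (1, 4, 6) and (1, 2, 8); their patches are concatenated into a single flat
# (total_patches, patch_dim) tensor, as the Qwen2.5-VL processor does.
import torch

image_grid_thw = [
    torch.tensor([[1, 4, 6]]),   # sample 0: t*h*w = 24 patches
    torch.tensor([[1, 2, 8]]),   # sample 1: t*h*w = 16 patches
]
patch_dim = 1176                 # hypothetical patch feature width
total_patches = sum(int(g.prod(dim=-1).sum()) for g in image_grid_thw)
pixel_values = torch.randn(total_patches, patch_dim)

offset = 0
per_sample_pixel_values = []
for grid in image_grid_thw:
    # num_patches for this sample: sum of t * h * w over all of its images
    num_patches = int((grid[:, 0] * grid[:, 1] * grid[:, 2]).sum())
    per_sample_pixel_values.append(pixel_values[offset : offset + num_patches])
    offset += num_patches

assert offset == pixel_values.shape[0]
for i, pv in enumerate(per_sample_pixel_values):
    print(f"sample {i}: {pv.shape}")  # (24, 1176) and (16, 1176)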
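
Note on the new VLM tests: test_vlm and test_vlm_multi_batch duplicate the same right-padding loop for input_ids and attention_mask before stacking the batch. A small helper like the one below could factor that out; the helper name and placement are hypothetical and not part of this patch, and loss_mask is taken to be a copy of attention_mask, matching the tests.

# Hypothetical helper (not part of this patch) extracting the padding loop
# shared by test_vlm and test_vlm_multi_batch.
from typing import List, Tuple

import torch
import torch.nn.functional as F


def pad_and_stack(
    batch_input_ids: List[torch.Tensor],
    batch_attention_mask: List[torch.Tensor],
    pad_token_id: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Right-pad (1, seq_len) tensors to a common length and stack along dim 0."""
    max_len = max(ids.shape[1] for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids, mask in zip(batch_input_ids, batch_attention_mask):
        pad_len = max_len - ids.shape[1]
        if pad_len > 0:
            ids = F.pad(ids, (0, pad_len), value=pad_token_id)
            mask = F.pad(mask, (0, pad_len), value=0)
        input_ids.append(ids)
        attention_mask.append(mask)
    input_ids = torch.cat(input_ids, dim=0)
    attention_mask = torch.cat(attention_mask, dim=0)
    loss_mask = attention_mask.clone()  # tests use attention_mask as loss_mask
    return input_ids, attention_mask, loss_mask

# Usage (mirroring the tests):
# input_ids, attention_mask, loss_mask = pad_and_stack(
#     batch_input_ids, batch_attention_mask, processor.tokenizer.pad_token_id
# )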
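
Note on configs/qwen2.5-vl-32b-eagle3.json: the draft config carries the target model's M-RoPE settings (rope_scaling.type "mrope" with mrope_section [16, 24, 24]) so the draft shares the same 3-D rotary layout. As a hedged sanity check, assuming the usual Qwen2-VL convention that mrope_section sums to head_dim // 2 (one entry per rotary pair across the temporal/height/width axes), a quick sketch:

# Sanity-check sketch for the new draft config. The head_dim // 2 relation is
# an assumed convention for Qwen2-VL-style configs, not something this patch
# enforces.
import json

with open("configs/qwen2.5-vl-32b-eagle3.json") as f:
    cfg = json.load(f)

head_dim = cfg["head_dim"]                             # 128
mrope_section = cfg["rope_scaling"]["mrope_section"]   # [16, 24, 24]

assert sum(mrope_section) == head_dim // 2, (
    f"mrope_section {mrope_section} should sum to head_dim // 2 = {head_dim // 2}"
)
print("mrope_section is consistent with head_dim")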