diff --git a/configs/demo.yaml b/configs/demo.yaml index d5d06d9..7443217 100644 --- a/configs/demo.yaml +++ b/configs/demo.yaml @@ -1,12 +1,11 @@ ### model scaling_factor: 4 -attn_implementation: "flash_attention_2" +attn_implementation: "sdpa" ### data sample_fps: 4 max_num_frames: 12 -# max_num_frames: 2048 longsize_resolution: 448 ### generate -do_sample: false \ No newline at end of file +do_sample: false diff --git a/demo.py b/demo.py index 980c32e..05190e4 100644 --- a/demo.py +++ b/demo.py @@ -48,13 +48,6 @@ def load_specific_frames(cap, frame_indices): def load_video(video_path: str, max_num_frames: int, fps: Union[int, float]=None, frame_extraction_fps: Union[int, float]=None, timestamps=None): - """Load video frames at fps. If total frames larger than `max_num_frames`, do downsample. - If 'fps' is `None`, load uniformly sample `max_num_frames` frames. - - video_path: Should either be a videofile or a directory of extracted frames. - - # NOTE: The extract frames must have name pattern of `%06d.(ext)`, or the loaded frame order will be wrong. - """ if video_path.startswith("file://"): video_path = video_path[7:] if os.path.isdir(video_path): # directory extracted frames @@ -102,7 +95,7 @@ def load_video(video_path: str, max_num_frames: int, fps: Union[int, float]=None cap.release() return frames, sampling_fps, timestamps - timestamps = [idx / frame_extraction_fps for idx in frame_indices] # ← 添加此行,计算每帧的时间戳 + timestamps = [idx / frame_extraction_fps for idx in frame_indices] # Convert into RGB format frames = [ @@ -173,38 +166,19 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device): transformers.models.qwen2_5_vl.processing_qwen2_5_vl.Qwen2_5_VLProcessor.__call__ = date_processing_qwen2_5_vl__call__ transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.get_rope_index = date_get_rope_index - processor = AutoProcessor.from_pretrained(hf_model_path) + # FIX: Force use_fast=False to avoid "unexpected keyword argument 'videos'" error + processor = AutoProcessor.from_pretrained(hf_model_path, use_fast=False) else: raise NotImplementedError return model, processor - - -# if __name__ == "__main__": -# #------------------- Modify the following configs ------------------# -# hf_model_path = 'Qwen/Qwen2.5-VL-7B-Instruct' -# model_name = 'qwen2_5_vl' -# #------------------- Modify the following settings ------------------# -# DEMO_VIDEO = 'asserts/demo.mp4' -# question = "When did his the straps of his pants slip off when he turned his back to dance?" -# max_frames = 24 -# Q2C = True # Query to Caption, if set to False, it will use the question directly. -# SAMPLE = True # Sample the video frames with TASS, if set to False, it will use the original method. - - -# config_path = 'configs/demo.yaml' -# # NOTE: for 7B models in Nvidia GPUs -# device = 'cuda:0' -# # NOTE: for 72B models in Nvidia GPUs -# # device = 'auto' import argparse if __name__ == "__main__": parser = argparse.ArgumentParser(description="Qwen2.5-VL-7B-Instruct Demo") - # 添加命令行参数 parser.add_argument("--hf_model_path", type=str, default='Qwen/Qwen2.5-VL-7B-Instruct', help="Huggingface model path") parser.add_argument("--model_name", type=str, default='qwen2_5_vl', help="Model name") parser.add_argument("--video", type=str, default='asserts/demo.mp4', help="Input demo video path") @@ -213,21 +187,18 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device): parser.add_argument("--q2c", action='store_true', help="Enable Query to Caption") parser.add_argument("--tass_sample", action='store_true', help="Enable TASS sampling") parser.add_argument("--config_path", type=str, default='configs/demo.yaml', help="Path to config YAML") - parser.add_argument("--device", type=str, default='cuda:0', help="Device to use. Use 'auto' for automatic device selection (e.g., for 72B models in Nvidia GPUs)") + parser.add_argument("--device", type=str, default='cuda:0', help="Device to use.") args = parser.parse_args() - - #------------------------ No need to change ------------------------# video_info = {"type": "video", "video": args.video, "fps": 2.0} exp_configs = load_yaml(args.config_path) - model, processor = load_and_patch_model(args.model_name, args.hf_model_path, args.config_path, args.device) + model, processor = load_and_patch_model(args.model_name, args.hf_model_path, exp_configs, args.device) caption = args.question - # query->caption if args.q2c: from utils.query import chatcompletions caption = chatcompletions(question = args.question) @@ -243,8 +214,6 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device): else: final_timestamps = None - - # Video conversation = [ { "role": "system", @@ -259,13 +228,11 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device): } ] - # If final_timestamps is provided, it will sample the corresponding frames; if not provided or set to None, it will sample uniformly. video, sampling_fps, timestamps = fetch_video(video_info, exp_configs['max_num_frames'], exp_configs['sample_fps'], exp_configs['longsize_resolution'], timestamps = final_timestamps) text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) videos_kwargs = dict(fps=sampling_fps) - # If `timestamps` is provided, it will embed the timestamps; if not provided or set to None, it will use the original method. inputs = processor(text=[text_prompt], videos=[video], padding=True, return_tensors="pt", timestamps=timestamps, **videos_kwargs) if args.device == 'auto': @@ -274,10 +241,8 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device): inputs = inputs.to(args.device) inputs['pixel_values_videos'] = inputs['pixel_values_videos'].to(torch.bfloat16) - # Inference: Generation of the output output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=128) generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) - output_text = output_text - print('Output text:\n', output_text) + print('Output text:\n', output_text) \ No newline at end of file diff --git a/qwen2_5_vl_date.py b/qwen2_5_vl_date.py index e14ed2c..582fb15 100644 --- a/qwen2_5_vl_date.py +++ b/qwen2_5_vl_date.py @@ -1,22 +1,22 @@ - -from typing import List, Optional, Tuple, Union - +from typing import List, Optional, Tuple, Union, Any import torch +import math +from transformers.image_utils import ImageInput + +# FIX: VideoInput is missing in some transformer versions +VideoInput = Union[List[Any], Any] -from transformers.image_utils import ImageInput, VideoInput from transformers.processing_utils import Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs - - def date_processing_qwen2_5_vl__call__( self, images: ImageInput = None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, videos: VideoInput = None, - timestamps = None, # 新增传入参数 + timestamps = None, **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], ) -> BatchFeature: output_kwargs = self._merge_kwargs( @@ -24,6 +24,7 @@ def date_processing_qwen2_5_vl__call__( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + if images is not None: image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] @@ -32,8 +33,49 @@ def date_processing_qwen2_5_vl__call__( image_grid_thw = None if videos is not None: - videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) - video_grid_thw = videos_inputs["video_grid_thw"] + # FIX: Flatten the batch of videos into a single list of frames + # videos is expected to be List[List[Frame]] (batch of videos) + all_frames = [] + video_frame_counts = [] + + for video in videos: + if isinstance(video, list): + all_frames.extend(video) + video_frame_counts.append(len(video)) + else: + # Fallback if a single tensor/image is passed (unlikely for video) + all_frames.append(video) + video_frame_counts.append(1) + + # Process all frames as "images" since the processor removed 'videos' arg + videos_inputs = self.image_processor(images=all_frames, **output_kwargs["images_kwargs"]) + + # The processor returns a grid for EACH frame (Treating them as independent images) + # We must aggregate these grids back into video volumes (T, H, W) + raw_grid = videos_inputs.pop("image_grid_thw") # Shape: (Total_Frames, 3) + + new_video_grids = [] + start_idx = 0 + for count in video_frame_counts: + # Extract grids for the current video + video_segment = raw_grid[start_idx : start_idx + count] + + # Assuming frames in one video have the same resolution (H, W) + h = video_segment[0, 1] + w = video_segment[0, 2] + t = count # Total frames + + new_video_grids.append([t, h, w]) + start_idx += count + + # Assign the aggregated grid + videos_inputs["video_grid_thw"] = torch.tensor(new_video_grids, device=raw_grid.device) + + # Rename pixel_values to what the model expects + if "pixel_values" in videos_inputs: + videos_inputs["pixel_values_videos"] = videos_inputs.pop("pixel_values") + + video_grid_thw = videos_inputs.get("video_grid_thw") fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) if fps is None: @@ -78,42 +120,41 @@ def date_processing_qwen2_5_vl__call__( t = videos_inputs["video_grid_thw"][index][0] stamps = timestamps timestamps_length = len(stamps) - segment = timestamps_length//t + # Safety check to avoid division by zero if timestamps is empty + if t > 0: + segment = max(1, timestamps_length // t) + else: + segment = 1 + new_timestamps = [] for j in range(t): - - # - # tmp_time = stamps[j*segment] - # minutes = math.floor(tmp_time // 60) - # seconds = math.floor(tmp_time % 60) - # # # tmp_time = "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds) - # tmp_time = "{:02d}:{:02d}".format(minutes, seconds) - # new_timestamps.append("<" + tmp_time + ">") - - # - new_timestamps.append("<" + str(round(stamps[j*segment],1)) + "s>") + # Ensure we don't go out of bounds + idx = min(j * segment, len(stamps) - 1) + new_timestamps.append("<" + str(round(stamps[idx],1)) + "s>") - num_tokens = video_grid_thw[index].prod() // merge_length n_segments = len(new_timestamps) - interval = num_tokens // n_segments - + if n_segments > 0: + interval = num_tokens // n_segments + else: + interval = num_tokens - # video_token + timestamp + video_token + ... video_with_stamps = "" for j in range(n_segments): video_with_stamps += "<|placeholder|>" * interval video_with_stamps += new_timestamps[j] + + # Append remaining tokens if division wasn't exact + remaining = num_tokens - (interval * n_segments) + if remaining > 0: + video_with_stamps += "<|placeholder|>" * remaining - text[i] = text[i].replace(self.video_token, video_with_stamps, 1) - index += 1 text[i] = text[i].replace("<|placeholder|>", self.video_token) else: for i in range(len(text)): - while self.video_token in text[i]: text[i] = text[i].replace( self.video_token, @@ -126,7 +167,6 @@ def date_processing_qwen2_5_vl__call__( text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) - def date_get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, @@ -208,20 +248,13 @@ def date_get_rope_index( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - # 修改:顺序编码1,2,3,4,5... range_tensor = torch.arange(llm_grid_t).view(-1, 1) expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) t_index = expanded_range.long().flatten() - # 原始代码 - # time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second - # time_tensor_long = time_tensor.long() - # t_index = time_tensor_long.flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - # 修改:实现vision token与时间戳token的交替位置编码 v_pos = torch.stack([t_index, h_index, w_index]) + text_len + st_idx tokens_per_timestamp = llm_grid_h * llm_grid_w my_st = st @@ -242,11 +275,6 @@ def date_get_rope_index( st = my_st - # 原始代码 - # llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - # st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st diff --git a/utils/clip_sim.py b/utils/clip_sim.py index 23c08c3..76e190a 100644 --- a/utils/clip_sim.py +++ b/utils/clip_sim.py @@ -6,8 +6,6 @@ from decord import VideoReader import numpy as np - - device = "cuda" if torch.cuda.is_available() else "cpu" model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") model.to(device) @@ -16,7 +14,6 @@ batch_size = 2048 - def extract_frames(video_path, fps): if os.path.isdir(video_path): all_frames = [ @@ -91,4 +88,3 @@ def process_video_question(video_path, question, topk=512, fps=4, max_frames=256 embedding = torch.cat(embedding,0) return score, timestamps, ids, embedding - diff --git a/utils/sampling.py b/utils/sampling.py index f3100c6..423d281 100644 --- a/utils/sampling.py +++ b/utils/sampling.py @@ -1,8 +1,7 @@ import torch - def get_final_indices(topk_indices, timestamps, attn_weights, interval=20.0, max_frames=256): - if len(topk_indices)<=max_frames: + if len(topk_indices) <= max_frames: selected_timestamps = [timestamps[x] for x in topk_indices] return selected_timestamps, topk_indices @@ -25,15 +24,31 @@ def get_final_indices(topk_indices, timestamps, attn_weights, interval=20.0, max def tass_sampling(timestamps, attn_weights, topk=None, max_frames=256): - if topk is not None: + # FIX: Calculate dynamic topk if it is None (or if it was passed as not None, per original logic) + # We use 'if topk is None' to handle the default case from demo.py + if topk is None: length = len(attn_weights) - trg = sum(attn_weights)/length - topk = len([x for x in attn_weights if x > trg]) # only keep the positive samples - + if length > 0: + trg = sum(attn_weights) / length + # Calculate count of frames with above-average attention + topk = len([x for x in attn_weights if x > trg]) + else: + topk = max_frames + + # Ensure topk is at least 1 to prevent errors + if topk < 1: + topk = 1 + up = max_frames * 4 topk = min(topk, up) + attn_weights = torch.tensor(attn_weights) + + # Safety check: k cannot be larger than the number of available weights + if topk > len(attn_weights): + topk = len(attn_weights) + topk_indices = torch.topk(attn_weights, k=topk).indices.tolist() - selected_timestamps, selected_indices = get_final_indices(topk_indices, timestamps, attn_weights, interval = 20, max_frames=max_frames) + selected_timestamps, selected_indices = get_final_indices(topk_indices, timestamps, attn_weights, interval=20, max_frames=max_frames) selected_timestamps = sorted(selected_timestamps) - return selected_timestamps, selected_indices + return selected_timestamps, selected_indices \ No newline at end of file