Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions configs/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
### model
scaling_factor: 4
attn_implementation: "flash_attention_2"
attn_implementation: "sdpa"

### data
sample_fps: 4
max_num_frames: 12
# max_num_frames: 2048
longsize_resolution: 448

### generate
do_sample: false
do_sample: false
47 changes: 6 additions & 41 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,6 @@ def load_specific_frames(cap, frame_indices):


def load_video(video_path: str, max_num_frames: int, fps: Union[int, float]=None, frame_extraction_fps: Union[int, float]=None, timestamps=None):
"""Load video frames at fps. If total frames larger than `max_num_frames`, do downsample.
If 'fps' is `None`, load uniformly sample `max_num_frames` frames.

video_path: Should either be a videofile or a directory of extracted frames.

# NOTE: The extract frames must have name pattern of `%06d.(ext)`, or the loaded frame order will be wrong.
"""
if video_path.startswith("file://"):
video_path = video_path[7:]
if os.path.isdir(video_path): # directory extracted frames
Expand Down Expand Up @@ -102,7 +95,7 @@ def load_video(video_path: str, max_num_frames: int, fps: Union[int, float]=None
cap.release()
return frames, sampling_fps, timestamps

timestamps = [idx / frame_extraction_fps for idx in frame_indices] # ← added: compute the timestamp of each frame
timestamps = [idx / frame_extraction_fps for idx in frame_indices]

# Convert into RGB format
frames = [
Expand Down Expand Up @@ -173,38 +166,19 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device):
transformers.models.qwen2_5_vl.processing_qwen2_5_vl.Qwen2_5_VLProcessor.__call__ = date_processing_qwen2_5_vl__call__
transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.get_rope_index = date_get_rope_index

processor = AutoProcessor.from_pretrained(hf_model_path)
# FIX: Force use_fast=False to avoid "unexpected keyword argument 'videos'" error
processor = AutoProcessor.from_pretrained(hf_model_path, use_fast=False)

else:
raise NotImplementedError
return model, processor




# if __name__ == "__main__":
# #------------------- Modify the following configs ------------------#
# hf_model_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
# model_name = 'qwen2_5_vl'
# #------------------- Modify the following settings ------------------#
# DEMO_VIDEO = 'asserts/demo.mp4'
# question = "When did his the straps of his pants slip off when he turned his back to dance?"
# max_frames = 24
# Q2C = True # Query to Caption, if set to False, it will use the question directly.
# SAMPLE = True # Sample the video frames with TASS, if set to False, it will use the original method.


# config_path = 'configs/demo.yaml'
# # NOTE: for 7B models in Nvidia GPUs
# device = 'cuda:0'
# # NOTE: for 72B models in Nvidia GPUs
# # device = 'auto'
import argparse

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Qwen2.5-VL-7B-Instruct Demo")

# Add command-line arguments
parser.add_argument("--hf_model_path", type=str, default='Qwen/Qwen2.5-VL-7B-Instruct', help="Huggingface model path")
parser.add_argument("--model_name", type=str, default='qwen2_5_vl', help="Model name")
parser.add_argument("--video", type=str, default='asserts/demo.mp4', help="Input demo video path")
Expand All @@ -213,21 +187,18 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device):
parser.add_argument("--q2c", action='store_true', help="Enable Query to Caption")
parser.add_argument("--tass_sample", action='store_true', help="Enable TASS sampling")
parser.add_argument("--config_path", type=str, default='configs/demo.yaml', help="Path to config YAML")
parser.add_argument("--device", type=str, default='cuda:0', help="Device to use. Use 'auto' for automatic device selection (e.g., for 72B models in Nvidia GPUs)")
parser.add_argument("--device", type=str, default='cuda:0', help="Device to use.")

args = parser.parse_args()


#------------------------ No need to change ------------------------#
video_info = {"type": "video",
"video": args.video,
"fps": 2.0}

exp_configs = load_yaml(args.config_path)
model, processor = load_and_patch_model(args.model_name, args.hf_model_path, args.config_path, args.device)
model, processor = load_and_patch_model(args.model_name, args.hf_model_path, exp_configs, args.device)
caption = args.question

# query->caption
if args.q2c:
from utils.query import chatcompletions
caption = chatcompletions(question = args.question)
Expand All @@ -243,8 +214,6 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device):
else:
final_timestamps = None


# Video
conversation = [
{
"role": "system",
Expand All @@ -259,13 +228,11 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device):
}
]

# If final_timestamps is provided, it will sample the corresponding frames; if not provided or set to None, it will sample uniformly.
video, sampling_fps, timestamps = fetch_video(video_info, exp_configs['max_num_frames'], exp_configs['sample_fps'], exp_configs['longsize_resolution'], timestamps = final_timestamps)

text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
videos_kwargs = dict(fps=sampling_fps)

# If `timestamps` is provided, it will embed the timestamps; if not provided or set to None, it will use the original method.
inputs = processor(text=[text_prompt], videos=[video], padding=True, return_tensors="pt", timestamps=timestamps, **videos_kwargs)

if args.device == 'auto':
Expand All @@ -274,10 +241,8 @@ def load_and_patch_model(model_name, hf_model_path, exp_configs, device):
inputs = inputs.to(args.device)
inputs['pixel_values_videos'] = inputs['pixel_values_videos'].to(torch.bfloat16)

# Inference: Generation of the output
output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=128)

generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
output_text = output_text
print('Output text:\n', output_text)
print('Output text:\n', output_text)
110 changes: 69 additions & 41 deletions qwen2_5_vl_date.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@

from typing import List, Optional, Tuple, Union

from typing import List, Optional, Tuple, Union, Any
import torch
import math
from transformers.image_utils import ImageInput

# FIX: VideoInput is missing in some transformer versions
VideoInput = Union[List[Any], Any]

from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs



def date_processing_qwen2_5_vl__call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
videos: VideoInput = None,
timestamps = None, # newly added input parameter
timestamps = None,
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
) -> BatchFeature:
output_kwargs = self._merge_kwargs(
Qwen2_5_VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if images is not None:
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
image_grid_thw = image_inputs["image_grid_thw"]
Expand All @@ -32,8 +33,49 @@ def date_processing_qwen2_5_vl__call__(
image_grid_thw = None

if videos is not None:
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
# FIX: Flatten the batch of videos into a single list of frames
# videos is expected to be List[List[Frame]] (batch of videos)
all_frames = []
video_frame_counts = []

for video in videos:
if isinstance(video, list):
all_frames.extend(video)
video_frame_counts.append(len(video))
else:
# Fallback if a single tensor/image is passed (unlikely for video)
all_frames.append(video)
video_frame_counts.append(1)

# Process all frames as "images" since the processor removed 'videos' arg
videos_inputs = self.image_processor(images=all_frames, **output_kwargs["images_kwargs"])

# The processor returns a grid for EACH frame (Treating them as independent images)
# We must aggregate these grids back into video volumes (T, H, W)
raw_grid = videos_inputs.pop("image_grid_thw") # Shape: (Total_Frames, 3)

new_video_grids = []
start_idx = 0
for count in video_frame_counts:
# Extract grids for the current video
video_segment = raw_grid[start_idx : start_idx + count]

# Assuming frames in one video have the same resolution (H, W)
h = video_segment[0, 1]
w = video_segment[0, 2]
t = count # Total frames

new_video_grids.append([t, h, w])
start_idx += count

# Assign the aggregated grid
videos_inputs["video_grid_thw"] = torch.tensor(new_video_grids, device=raw_grid.device)

# Rename pixel_values to what the model expects
if "pixel_values" in videos_inputs:
videos_inputs["pixel_values_videos"] = videos_inputs.pop("pixel_values")

video_grid_thw = videos_inputs.get("video_grid_thw")

fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
if fps is None:
Expand Down Expand Up @@ -78,42 +120,41 @@ def date_processing_qwen2_5_vl__call__(
t = videos_inputs["video_grid_thw"][index][0]
stamps = timestamps
timestamps_length = len(stamps)
segment = timestamps_length//t
# Safety check to avoid division by zero if timestamps is empty
if t > 0:
segment = max(1, timestamps_length // t)
else:
segment = 1

new_timestamps = []

for j in range(t):

# <mm:ss>
# tmp_time = stamps[j*segment]
# minutes = math.floor(tmp_time // 60)
# seconds = math.floor(tmp_time % 60)
# # # tmp_time = "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds)
# tmp_time = "{:02d}:{:02d}".format(minutes, seconds)
# new_timestamps.append("<" + tmp_time + ">")

# <x.xs>
new_timestamps.append("<" + str(round(stamps[j*segment],1)) + "s>")
# Ensure we don't go out of bounds
idx = min(j * segment, len(stamps) - 1)
new_timestamps.append("<" + str(round(stamps[idx],1)) + "s>")


num_tokens = video_grid_thw[index].prod() // merge_length
n_segments = len(new_timestamps)
interval = num_tokens // n_segments

if n_segments > 0:
interval = num_tokens // n_segments
else:
interval = num_tokens

# video_token + timestamp + video_token + ...
video_with_stamps = ""
for j in range(n_segments):
video_with_stamps += "<|placeholder|>" * interval
video_with_stamps += new_timestamps[j]

# Append remaining tokens if division wasn't exact
remaining = num_tokens - (interval * n_segments)
if remaining > 0:
video_with_stamps += "<|placeholder|>" * remaining


text[i] = text[i].replace(self.video_token, video_with_stamps, 1)

index += 1
text[i] = text[i].replace("<|placeholder|>", self.video_token)
else:
for i in range(len(text)):

while self.video_token in text[i]:
text[i] = text[i].replace(
self.video_token,
Expand All @@ -126,7 +167,6 @@ def date_processing_qwen2_5_vl__call__(
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})


def date_get_rope_index(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down Expand Up @@ -208,20 +248,13 @@ def date_get_rope_index(
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

# Modified: sequential temporal encoding 1, 2, 3, 4, 5, ...
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
t_index = expanded_range.long().flatten()

# Original code:
# time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
# time_tensor_long = time_tensor.long()
# t_index = time_tensor_long.flatten()

h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()

# Modified: interleave the position encodings of vision tokens and timestamp tokens
v_pos = torch.stack([t_index, h_index, w_index]) + text_len + st_idx
tokens_per_timestamp = llm_grid_h * llm_grid_w
my_st = st
Expand All @@ -242,11 +275,6 @@ def date_get_rope_index(

st = my_st

# Original code:
# llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
# st = ed + llm_grid_t * llm_grid_h * llm_grid_w


if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
Expand Down
4 changes: 0 additions & 4 deletions utils/clip_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from decord import VideoReader
import numpy as np



device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.to(device)
Expand All @@ -16,7 +14,6 @@

batch_size = 2048


def extract_frames(video_path, fps):
if os.path.isdir(video_path):
all_frames = [
Expand Down Expand Up @@ -91,4 +88,3 @@ def process_video_question(video_path, question, topk=512, fps=4, max_frames=256
embedding = torch.cat(embedding,0)

return score, timestamps, ids, embedding

Loading