 VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
 
 
+# Based on combine_frames_into_images in
+# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
+def get_num_combined_frames(
+    num_frames: int,
+    max_grid_shape: tuple[int, int] = (3, 3),
+) -> int:
+    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
+
+    # Calculate the number of canvases needed.
+    num_canvases = num_frames // max_num_grids
+    leftover_frames = num_frames % max_num_grids
+
+    return num_canvases + (leftover_frames > 0)
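+
+# Illustrative note (ours, not from the patch): with the default (3, 3) grid
+# a canvas holds up to 9 frames, so get_num_combined_frames(9) == 1 and
+# get_num_combined_frames(10) == 2, i.e. the ceiling of num_frames / 9.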
+
+
 class HCXVisionMultimodalPixelInputs(TypedDict):
     type: Literal["pixel_values"]
     pixel_values_images: list[torch.Tensor]
@@ -172,23 +187,20 @@ def _call_hf_processor(
         def replace_multimodal_token(
             token_ids: torch.Tensor,
             target_token: int,
-            repeats: list,
+            repeats: list[int],
         ):
-            output = list()
+            output = list[int]()
             _repeats_idx = 0
             for token_id in token_ids:
                 if token_id == target_token:
-                    output += [
-                        token_id.item(),
-                    ] * repeats[_repeats_idx]
+                    output += [token_id.item()] * repeats[_repeats_idx]
                     _repeats_idx += 1
                 else:
-                    output += [
-                        token_id.item(),
-                    ]
+                    output += [token_id.item()]
+
             return torch.tensor(output, device=token_ids.device)
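+
+        # Illustrative example (ours, not from the patch): with
+        # target_token=7 and repeats=[3], [1, 7, 2] becomes [1, 7, 7, 7, 2];
+        # each placeholder is widened to the number of vision query tokens
+        # its media item produces.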
 
-        for video_idx, video_arr in enumerate(mm_data.get("videos", list())):
+        for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
             if video_arr.dtype == np.uint8:
                 continue
             mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
@@ -205,88 +217,68 @@ def replace_multimodal_token(
         if len(mm_data) > 0:
             # batchify input as a single item
             images = mm_data.get("images", None)
-            num_images = 0
-            if images is not None:
-                num_images = len(images)
-                images = [
-                    images,
-                ]  # batchify
-
-            videos = mm_data.get("videos",
-                                 None)  # list of video in single conversation
-            num_videos = 0
-            if videos is not None:
-                num_videos = len(videos)
-                videos = [
-                    videos,
-                ]  # batchify
+            batched_images = None if images is None else [images]
+
+            # list of videos in a single conversation
+            videos = mm_data.get("videos", None)
+            batched_videos = None if videos is None else [videos]
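+            # e.g. images = [img0, img1] -> batched_images = [[img0, img1]];
+            # the whole conversation is fed to the HF processor as a single
+            # batch element (comment ours, added for clarity).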
 
             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=images,
-                    videos=videos,
+                    images=batched_images,
+                    videos=batched_videos,
                 ),
             )  # mm-only
 
             for k, v in _processed_outputs.items():
-                if len(v) < 1:
-                    continue
-                elif k.endswith("_images"):
-                    # list of list of 4D tensor -> list of 4D tensor
+                if isinstance(v, list) and len(v) > 0:
+                    assert len(v) == 1
                     _processed_outputs[k] = v[0]
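+                    # Unbatch the single conversation wrapped above,
+                    # e.g. [[t0, t1]] -> [t0, t1] (comment ours).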
-                elif k.endswith("_videos"):
-                    # list of list of 4D tensor -> list of 4D tensor
-                    v = v[0]
-                    if k == "pixel_values_videos":
-                        v = torch.cat(v, dim=0)
-                        _c, _w, _h = v.shape[-3:]
-                        v = v.reshape(num_videos, -1, _c, _w, _h)
-                        v = list(torch.unbind(v, dim=0))
-                    _processed_outputs[k] = v
-
-            if num_images > 0:
+
+            if images:
                 tokenizer = self.info.get_tokenizer()
+                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
                 processed_outputs["input_ids"] = torch.stack([
                     replace_multimodal_token(
                         token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            IMAGE_TOKEN),
+                        target_token=image_token_id,
                         repeats=_processed_outputs[
                             "vision_query_lengths_images"],
                     ) for _input_ids in processed_outputs["input_ids"]
                 ],
                     dim=0)
 
-            if num_videos > 0:
+            if videos:
+                _num_per_videos = [
+                    get_num_combined_frames(len(video)) for video in videos
+                ]
+                _processed_outputs["pixel_values_videos"] = [
+                    _processed_outputs["pixel_values_videos"]
+                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
+                    for _i in range(len(videos))
+                ]
+                _processed_outputs["vision_query_lengths_videos"] = [
+                    _processed_outputs["vision_query_lengths_videos"]
+                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
+                    for _i in range(len(videos))
+                ]
+
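+                # Worked example (ours): two videos with 10 and 9 frames give
+                # _num_per_videos = [2, 1], so video 0 takes canvases [0:2]
+                # and video 1 takes canvas [2:3] of the flat HF outputs.
+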
                 tokenizer = self.info.get_tokenizer()
+                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
                 processed_outputs["input_ids"] = torch.stack([
                     replace_multimodal_token(
                         token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            VIDEO_TOKEN),
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_videos"],
+                        target_token=video_token_id,
+                        repeats=[
+                            sum(lens) for lens in
+                            _processed_outputs["vision_query_lengths_videos"]
+                        ],
                     ) for _input_ids in processed_outputs["input_ids"]
                 ],
                     dim=0)
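+                # Note (ours): unlike images, each video placeholder expands
+                # by the total query length over all of its canvases, hence
+                # the sum(lens) above.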
 
-                _ratios = [
-                    len(_pixel_values) for _pixel_values in
-                    _processed_outputs["pixel_values_videos"]
-                ]
-                _num_per_videos = [
-                    int(_e / sum(_ratios) *
-                        len(_processed_outputs["vision_query_lengths_videos"]))
-                    for _e in _ratios
-                ]
-                _processed_outputs["vision_query_lengths_videos"] = [
-                    _processed_outputs["vision_query_lengths_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(0, num_videos)
-                ]
-
             processed_outputs.update(_processed_outputs)
 
         return processed_outputs