@@ -432,6 +432,7 @@ def _convert_omni_to_inputs(
         else:
             images, image_sizes, tgt_sizes = [[]] * bs, [[]] * bs, [[]] * bs
 
+        final_texts_list = []
         input_ids_list = []
         image_bounds_list = []
         audio_bounds_list = []
@@ -467,14 +468,26 @@ def _convert_omni_to_inputs(
             final_text = "".join(text_chunks)
             input_ids, image_bounds, audio_bounds, spk_bounds = self._convert(final_text, max_length, **kwargs)
 
+            final_texts_list.append(final_text)
             input_ids_list.append(input_ids)
             image_bounds_list.append(image_bounds)
             audio_bounds_list.append(audio_bounds)
             spk_bounds_list.append(spk_bounds)
 
-        padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
-        attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
-        for i, length in enumerate(padding_lengths):
+        model_inputs = self.tokenizer(
+            final_texts_list,
+            padding="longest",
+            padding_side="left",
+            return_tensors=return_tensors,
+            truncation=truncation,
+            max_length=max_length,
+            **kwargs,
+        )
+
+        padded_input_ids = model_inputs["input_ids"]
+        attention_mask = model_inputs["attention_mask"]
+        for i in range(bs):
+            length = (attention_mask[i] == 0).sum().item()
             image_bounds_list[i] = image_bounds_list[i] + length
             audio_bounds_list[i] = audio_bounds_list[i] + length
             spk_bounds_list[i] = spk_bounds_list[i] + length
@@ -501,52 +514,6 @@ def model_input_names(self):
         feature_extractor_input_names = self.feature_extractor.model_input_names
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extractor_input_names))
 
-    def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
-        items = []
-        if isinstance(inputs[0], list):
-            assert isinstance(inputs[0][0], torch.Tensor)
-            for it in inputs:
-                for tr in it:
-                    items.append(tr)
-        else:
-            assert isinstance(inputs[0], torch.Tensor)
-            items = inputs
-
-        batch_size = len(items)
-        shape = items[0].shape
-        dim = len(shape)
-        assert dim <= 2
-        if max_length is None:
-            max_length = 0
-        max_length = max(max_length, max(item.shape[-1] for item in items))
-        min_length = min(item.shape[-1] for item in items)
-        dtype = items[0].dtype
-
-        if dim == 0:
-            return torch.stack([item for item in items], dim=0), [0]
-        elif dim == 1:
-            if max_length == min_length:
-                return torch.stack([item for item in items], dim=0), [0] * batch_size
-            tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
-        else:
-            tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
-
-        padding_length = []
-        for i, item in enumerate(items):
-            if dim == 1:
-                if padding_side == "left":
-                    tensor[i, -len(item):] = item.clone()
-                else:
-                    tensor[i, :len(item)] = item.clone()
-            elif dim == 2:
-                if padding_side == "left":
-                    tensor[i, -len(item):, :] = item.clone()
-                else:
-                    tensor[i, :len(item), :] = item.clone()
-            padding_length.append(tensor.shape[-1] - len(item))
-
-        return tensor, padding_length
-
 
 class MelSpectrogramFeatures(torch.nn.Module):
     def __init__(
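Not part of the commit, just a minimal sketch of the idea behind the new bounds loop: with padding="longest" and padding_side="left", every zero in a row of the returned attention_mask is a pad token prepended to that sample, so the per-sample zero count is exactly the offset by which the precomputed image/audio/speaker bounds must be shifted. All tensor values below are made up for illustration.

import torch

# Left-padded batch of two samples, as the tokenizer call in the diff would produce.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],  # sample 0: 2 pad tokens on the left
    [1, 1, 1, 1, 1],  # sample 1: no padding
])

# Token-index pairs computed on the unpadded sequences (hypothetical values).
image_bounds_list = [torch.tensor([[0, 2]]), torch.tensor([[1, 3]])]

for i in range(attention_mask.size(0)):
    # Number of left-pad tokens == number of zeros in this row of the mask.
    length = (attention_mask[i] == 0).sum().item()
    image_bounds_list[i] = image_bounds_list[i] + length

print(image_bounds_list)  # [tensor([[2, 4]]), tensor([[1, 3]])]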