
Commit bceee7d (1 parent: fd6f9ab)

instead of pad(), call tokenizer to create attn mask and padding in _convert_omni_to_inputs()

1 file changed: +16 additions, -49 deletions

src/transformers/models/minicpm_o_2_6/processing_minicpm_o_2_6.py

@@ -432,6 +432,7 @@ def _convert_omni_to_inputs(
         else:
             images, image_sizes, tgt_sizes = [[]] * bs, [[]] * bs, [[]] * bs
 
+        final_texts_list = []
         input_ids_list = []
         image_bounds_list = []
         audio_bounds_list = []
@@ -467,14 +468,26 @@ def _convert_omni_to_inputs(
             final_text = "".join(text_chunks)
             input_ids, image_bounds, audio_bounds, spk_bounds = self._convert(final_text, max_length, **kwargs)
 
+            final_texts_list.append(final_text)
             input_ids_list.append(input_ids)
             image_bounds_list.append(image_bounds)
             audio_bounds_list.append(audio_bounds)
             spk_bounds_list.append(spk_bounds)
 
-        padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
-        attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
-        for i, length in enumerate(padding_lengths):
+        model_inputs = self.tokenizer(
+            final_texts_list,
+            padding="longest",
+            padding_side="left",
+            return_tensors=return_tensors,
+            truncation=truncation,
+            max_length=max_length,
+            **kwargs,
+        )
+
+        padded_input_ids = model_inputs["input_ids"]
+        attention_mask = model_inputs["attention_mask"]
+        for i in range(bs):
+            length = (attention_mask[i] == 0).sum().item()
             image_bounds_list[i] = image_bounds_list[i] + length
             audio_bounds_list[i] = audio_bounds_list[i] + length
             spk_bounds_list[i] = spk_bounds_list[i] + length
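
A minimal standalone sketch of the new padding path, assuming a recent transformers release whose tokenizer __call__ accepts padding_side (which this diff relies on); the checkpoint name is illustrative only, not the model this processor targets. It shows why the loop above works: with left padding, the per-sample offset needed to shift precomputed bound indices is exactly the number of zeros in the attention mask.

from transformers import AutoTokenizer

# Illustrative checkpoint, used only to demonstrate the call pattern.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
texts = ["a short prompt", "a noticeably longer prompt with extra tokens"]

model_inputs = tokenizer(
    texts,
    padding="longest",    # pad every sample to the longest in the batch
    padding_side="left",  # pad tokens on the left; real tokens stay right-aligned
    return_tensors="pt",
)

attention_mask = model_inputs["attention_mask"]
for i in range(len(texts)):
    # Zeros in the mask mark pad positions; with left padding their count
    # is the amount by which pre-padding token indices must shift.
    pad_len = (attention_mask[i] == 0).sum().item()
    print(f"sample {i}: left-pad length = {pad_len}")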
@@ -501,52 +514,6 @@ def model_input_names(self):
         feature_extractor_input_names = self.feature_extractor.model_input_names
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extractor_input_names))
 
-    def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
-        items = []
-        if isinstance(inputs[0], list):
-            assert isinstance(inputs[0][0], torch.Tensor)
-            for it in inputs:
-                for tr in it:
-                    items.append(tr)
-        else:
-            assert isinstance(inputs[0], torch.Tensor)
-            items = inputs
-
-        batch_size = len(items)
-        shape = items[0].shape
-        dim = len(shape)
-        assert dim <= 2
-        if max_length is None:
-            max_length = 0
-        max_length = max(max_length, max(item.shape[-1] for item in items))
-        min_length = min(item.shape[-1] for item in items)
-        dtype = items[0].dtype
-
-        if dim == 0:
-            return torch.stack([item for item in items], dim=0), [0]
-        elif dim == 1:
-            if max_length == min_length:
-                return torch.stack([item for item in items], dim=0), [0] * batch_size
-            tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
-        else:
-            tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
-
-        padding_length = []
-        for i, item in enumerate(items):
-            if dim == 1:
-                if padding_side == "left":
-                    tensor[i, -len(item) :] = item.clone()
-                else:
-                    tensor[i, : len(item)] = item.clone()
-            elif dim == 2:
-                if padding_side == "left":
-                    tensor[i, -len(item) :, :] = item.clone()
-                else:
-                    tensor[i, : len(item), :] = item.clone()
-            padding_length.append(tensor.shape[-1] - len(item))
-
-        return tensor, padding_length
-
 
 
 class MelSpectrogramFeatures(torch.nn.Module):
     def __init__(
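
For reference, the 1-D behavior of the deleted pad() helper can be reproduced in a few lines. This is a standalone illustration under the same left-padding convention, not part of the repository's API; the helper name is hypothetical.

import torch

def left_pad_1d(items, padding_value=0):
    # Pad a list of 1-D tensors on the left to a common length and
    # report the per-sample padding, mirroring what pad() returned.
    max_len = max(t.shape[-1] for t in items)
    out = torch.full((len(items), max_len), padding_value, dtype=items[0].dtype)
    pad_lens = []
    for i, t in enumerate(items):
        out[i, max_len - len(t):] = t
        pad_lens.append(max_len - len(t))
    return out, pad_lens

ids, pads = left_pad_1d([torch.tensor([5, 6, 7]), torch.tensor([8])])
# ids -> tensor([[5, 6, 7], [0, 0, 8]]); pads -> [0, 2]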
