
Commit 9a433cd

fix
Signed-off-by: guochenxu <[email protected]>
1 parent dbcdb7b commit 9a433cd


src/transformers/models/minicpm_o_2_6/modeling_minicpm_o_2_6.py (61 additions, 61 deletions)
@@ -723,42 +723,64 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):

         return input_lengths_after_cnn, input_lengths_after_pooling

-    def get_image_features(self, tgt_sizes, all_pixel_values, dtype, device):
-        tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
-        tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+    def get_image_features(self, pixel_values_list, tgt_sizes, dtype, device):
+        vision_hidden_states = []
+        all_pixel_values = []
+        img_cnt = []
+        for pixel_values in pixel_values_list:
+            img_cnt.append(len(pixel_values))
+            all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+        # exist image
+        if all_pixel_values:
+            tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
+            tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+
+            max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+            all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+                all_pixel_values, batch_first=True, padding_value=0.0
+            )
+            B, L, _ = all_pixel_values.shape
+            all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+            patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+            for i in range(B):
+                patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+            vision_batch_size = self.config.vision_batch_size
+            all_pixel_values = all_pixel_values.type(dtype)
+            if B > vision_batch_size:
+                hs = []
+                for i in range(0, B, vision_batch_size):
+                    start_idx = i
+                    end_idx = i + vision_batch_size
+                    tmp_hs = self.vpm(
+                        all_pixel_values[start_idx:end_idx],
+                        patch_attention_mask=patch_attn_mask[start_idx:end_idx],
+                        tgt_sizes=tgt_sizes[start_idx:end_idx],
+                    ).last_hidden_state
+                    hs.append(tmp_hs)
+                vision_embedding = torch.cat(hs, dim=0)
+            else:
+                vision_embedding = self.vpm(
+                    all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
+                ).last_hidden_state

-        max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+            vision_embedding = self.resampler(vision_embedding, tgt_sizes)

-        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
-            all_pixel_values, batch_first=True, padding_value=0.0
-        )
-        B, L, _ = all_pixel_values.shape
-        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
-
-        patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
-        for i in range(B):
-            patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
-
-        vision_batch_size = self.config.vision_batch_size
-        all_pixel_values = all_pixel_values.type(dtype)
-        if B > vision_batch_size:
-            hs = []
-            for i in range(0, B, vision_batch_size):
-                start_idx = i
-                end_idx = i + vision_batch_size
-                tmp_hs = self.vpm(
-                    all_pixel_values[start_idx:end_idx],
-                    patch_attention_mask=patch_attn_mask[start_idx:end_idx],
-                    tgt_sizes=tgt_sizes[start_idx:end_idx],
-                ).last_hidden_state
-                hs.append(tmp_hs)
-            vision_embedding = torch.cat(hs, dim=0)
-        else:
-            vision_embedding = self.vpm(
-                all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
-            ).last_hidden_state
-        vision_embedding = self.resampler(vision_embedding, tgt_sizes)
-        return vision_embedding
+            start = 0
+            for pixel_values in pixel_values_list:
+                img_cnt = len(pixel_values)
+                if img_cnt > 0:
+                    vision_hidden_states.append(vision_embedding[start : start + img_cnt])
+                    start += img_cnt
+                else:
+                    vision_hidden_states.append([])
+        else:  # no image
+            vision_hidden_states.extend([[]] * len(pixel_values_list))
+
+        return vision_hidden_states

     def get_vllm_embedding(self, data):
         """
@@ -773,34 +795,12 @@ def get_vllm_embedding(self, data):
         Returns:
             embedding with vision, vision_hidden_states
         """
+        dtype = self.language_model.embed_tokens.weight.dtype
+        device = self.language_model.embed_tokens.weight.device
         if "vision_hidden_states" not in data:
-            dtype = self.language_model.embed_tokens.weight.dtype
-            device = self.language_model.embed_tokens.weight.device
-            tgt_sizes = data["tgt_sizes"]
-            pixel_values_list = data["pixel_values"]
-            vision_hidden_states = []
-            all_pixel_values = []
-            img_cnt = []
-            for pixel_values in pixel_values_list:
-                img_cnt.append(len(pixel_values))
-                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
-
-            # exist image
-            if all_pixel_values:
-
-                vision_embedding = self.get_image_features(tgt_sizes=tgt_sizes, all_pixel_values=all_pixel_values, dtype=dtype, device=device)
-
-                start = 0
-                for pixel_values in pixel_values_list:
-                    img_cnt = len(pixel_values)
-                    if img_cnt > 0:
-                        vision_hidden_states.append(vision_embedding[start : start + img_cnt])
-                        start += img_cnt
-                    else:
-                        vision_hidden_states.append([])
-            else:  # no image
-                vision_hidden_states.extend([[]] * len(pixel_values_list))
-
+            vision_hidden_states = self.get_image_features(
+                data["pixel_values"], data["tgt_sizes"], dtype, device
+            )
         else:
             vision_hidden_states = data["vision_hidden_states"]

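After this change, get_image_features owns both the flattening of each sample's pixel values into one padded batch and the regrouping of the resulting embeddings back into one entry per sample (an empty list for samples without images), so get_vllm_embedding reduces to a single call. Below is a minimal sketch of the new call pattern; it assumes an already-instantiated model object `model` and a preprocessed `data` dict with "pixel_values" and "tgt_sizes" keys, which are illustrative names not defined in this diff.

    # Illustrative sketch only: `model` and `data` are assumed to come from the
    # usual MiniCPM-o-2.6 processor/model loading flow; they are not part of this diff.
    import torch

    # dtype/device come from the language model's embedding table, mirroring the
    # new get_vllm_embedding code above.
    dtype = model.language_model.embed_tokens.weight.dtype
    device = model.language_model.embed_tokens.weight.device

    # One call now flattens all images across samples, runs the vision tower in
    # chunks of config.vision_batch_size, applies the resampler, and regroups
    # the embeddings per sample.
    vision_hidden_states = model.get_image_features(
        data["pixel_values"],  # list (per sample) of per-image pixel-value tensors
        data["tgt_sizes"],     # per-image patch-grid sizes
        dtype,
        device,
    )

    # Result: one entry per input sample, a tensor of image embeddings when the
    # sample contains images, or an empty list when it does not.
    for sample_states in vision_hidden_states:
        print(len(sample_states))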