Skip to content

Commit fc67969

Browse files
Fix DotsOCR tensor type (#26281)
Signed-off-by: what_in_the_nim <[email protected]>
1 parent ab5e7d9 commit fc67969

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

vllm/model_executor/models/dots_ocr.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def dtype(self) -> torch.dtype:
617617
def device(self) -> torch.device:
618618
return self.patch_embed.patchifier.proj.weight.device
619619

620-
def get_pos_ids_by_grid(self, grid_thw):
620+
def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
621621
pos_ids = []
622622
for t, h, w in grid_thw:
623623
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@@ -643,10 +643,10 @@ def get_pos_ids_by_grid(self, grid_thw):
643643

644644
return pos_ids
645645

646-
def rot_pos_emb(self, grid_thw):
646+
def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
647647
pos_ids = self.get_pos_ids_by_grid(grid_thw)
648648
pos_ids = torch.cat(pos_ids, dim=0)
649-
max_grid_size = grid_thw[:, 1:].max()
649+
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
650650
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
651651
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
652652
return rotary_pos_emb
@@ -667,13 +667,13 @@ def compute_attn_mask_seqlen(
667667
def forward(
668668
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
669669
) -> torch.Tensor:
670+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
671+
670672
# Convert grid_thw to tensor (always expecting list format now)
671673
grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
672674
hidden_states = hidden_states.to(self.dtype)
673675
hidden_states = self.patch_embed(hidden_states, grid_thw)
674676

675-
rotary_pos_emb = self.rot_pos_emb(grid_thw)
676-
677677
cu_seqlens = torch.repeat_interleave(
678678
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
679679
).cumsum(
@@ -807,7 +807,7 @@ def _process_image_input(
807807
rope_type="rope_3d",
808808
)
809809
else:
810-
image_embeds = self.vision_tower(pixel_values, grid_thw)[
810+
image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
811811
:, : self.config.hidden_size
812812
]
813813

0 commit comments

Comments (0)