This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 72d2d20

Fix support for ET repo generated pte by adding batch dim (#1177)
1 parent 971ed93

1 file changed: +5 −0

torchchat/model.py

Lines changed: 5 additions & 0 deletions
@@ -951,6 +951,11 @@ def forward(self, x, input_pos):
         # the first element to get the tensor
         assert len(logits) == 1
         logits = logits[0]
+
+        # Add a batch dimension, if it's missing (e.g. some pte's
+        # exported from the ExecuTorch repo)
+        if logits.dim() == 2:
+            logits = logits.unsqueeze(0)
         return logits
 
     def setup_caches(self, max_batch_size, max_seq_length):
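
The guard added above handles .pte programs exported from the ExecuTorch repo whose forward call returns logits without a leading batch dimension. Below is a minimal sketch of the same check in isolation; the shapes (seq_len=4, vocab_size=32000) are illustrative placeholders, not values taken from the commit.

import torch

# A 2-D logits tensor, as some ExecuTorch-exported .pte programs return:
# [seq_len, vocab_size] with no leading batch dimension.
seq_len, vocab_size = 4, 32000            # illustrative sizes only
logits = torch.randn(seq_len, vocab_size)

# Same guard as in the commit: insert a batch dimension when it is missing,
# so downstream code can rely on [batch, seq_len, vocab_size].
if logits.dim() == 2:
    logits = logits.unsqueeze(0)

print(logits.shape)  # torch.Size([1, 4, 32000])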
