diff --git a/torchchat/model.py b/torchchat/model.py index aaa72cb2a..79bd1f188 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -951,6 +951,11 @@ def forward(self, x, input_pos): # the first element to get the tensor assert len(logits) == 1 logits = logits[0] + + # Add a batch dimension, if it's missing (e.g. some pte's + # exported from the ExecuTorch repo) + if logits.dim() == 2: + logits = logits.unsqueeze(0) return logits def setup_caches(self, max_batch_size, max_seq_length):