This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d11f0e4

update flamingo model for tune

1 parent 6d2ef4a

File tree

torchchat/model.py (18 additions & 8 deletions)
@@ -535,18 +535,28 @@ def reset_caches(self):
 class FlamingoModel(Model):
     def forward(
         self,
-        tokens: Tensor,
-        encoder_input: Optional[Dict[str, Tensor]] = None,
-        encoder_mask: Optional[Tensor] = None,
+        tokens: torch.Tensor,
+        *,
+        mask: Optional[torch.Tensor] = None,
+        encoder_input: Optional[Dict] = None,
+        encoder_mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> Tensor:
-        if encoder_input is None:
-            return self.model(tokens, encoder_mask=encoder_mask)
         return self.model(
-            tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
+            tokens,
+            mask=mask,
+            encoder_input=encoder_input,
+            encoder_mask=encoder_mask,
+            input_pos=input_pos,
         )
 
-    def setup_caches(self, max_batch_size, dtype):
-        self.model.setup_caches(max_batch_size, dtype=dtype)
+    def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
+        self.model.setup_caches(
+            batch_size=batch_size,
+            dtype=dtype,
+            encoder_max_seq_len=encoder_max_seq_len,
+            decoder_max_seq_len=decoder_max_seq_len,
+        )
 
     def reset_caches(self):
         self.model.reset_caches()
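For illustration, a minimal sketch of the calling convention after this change. FakeFlamingoModel below is a stand-in stub mirroring the updated surface so the snippet runs on its own; the real FlamingoModel delegates these calls to the underlying torchtune model, and the batch and sequence sizes here are assumptions, not values from the repository.

import torch
from torch import Tensor
from typing import Dict, Optional

# Stand-in stub with the same surface as the updated FlamingoModel, so this
# sketch runs without a torchchat checkout. The real class forwards to
# self.model (a torchtune Flamingo model); the dummy logits below are not.
class FakeFlamingoModel:
    def __init__(self, vocab_size: int = 128):
        self.vocab_size = vocab_size

    def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
        # New in this commit: encoder and decoder max lengths are set up front.
        print(f"caches: bsz={batch_size}, enc={encoder_max_seq_len}, dec={decoder_max_seq_len}")

    def forward(
        self,
        tokens: torch.Tensor,
        *,
        mask: Optional[torch.Tensor] = None,
        encoder_input: Optional[Dict] = None,
        encoder_mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> Tensor:
        # Dummy logits; the real model runs the fused vision/text forward pass.
        return torch.zeros(*tokens.shape, self.vocab_size)

    def reset_caches(self):
        pass

model = FakeFlamingoModel()
model.setup_caches(
    batch_size=1,                # assumed values for the sketch
    dtype=torch.bfloat16,
    encoder_max_seq_len=4096,
    decoder_max_seq_len=2048,
)

# All optional inputs are keyword-only after the change; a text-only decode
# step simply omits encoder_input/encoder_mask rather than taking the old
# separate `if encoder_input is None` branch.
tokens = torch.tensor([[1, 2, 3]])
logits = model.forward(tokens, input_pos=torch.arange(tokens.shape[1]))
print(logits.shape)  # torch.Size([1, 3, 128])

model.reset_caches()

Folding both paths into one call keeps the wrapper aligned with the keyword-only forward signature that torchtune's fusion models expose, which appears to be the point of the commit ("update flamingo model for tune").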
