diff --git a/torchchat/generate.py b/torchchat/generate.py
index a8501328e..4e25912f4 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -140,6 +140,7 @@ class GeneratorArgs:
     speculate_k: int = 5
     sequential_prefill: bool = False
     max_autotune: bool = False
+    # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
 
     def __post_init__(self):
@@ -958,16 +959,19 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        # elif not generator_args.is_torchtune_model:
-        #     max_seq_length = min(
-        #         encoded.size(0) + generator_args.max_new_tokens,
-        #         (
-        #             text_transformer_args.block_size
-        #             if text_transformer_args is not None
-        #             else 2048
-        #         ),
-        #         max_seq_length,
-        #     )
+        # `is_torchtune_model` is a misnomer since it doesn't capture all
+        # torchtune models (e.g. Flamingo)
+        # See Issue: https://github.com/pytorch/torchchat/issues/1273
+        elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
+            max_seq_length = min(
+                encoded.size(0) + generator_args.max_new_tokens,
+                (
+                    text_transformer_args.block_size
+                    if text_transformer_args is not None
+                    else 2048
+                ),
+                max_seq_length,
+            )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
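
For readers skimming the patch: the restored `elif` branch clamps `max_seq_length` to the tightest of three bounds, namely what the prompt plus the requested generation actually needs, the model's context window (defaulting to 2048 when `text_transformer_args` is unavailable), and the previously computed bound. Below is a minimal standalone sketch of that logic; the helper name `clamp_max_seq_length` and its plain-integer parameters are hypothetical stand-ins for `encoded.size(0)`, `generator_args.max_new_tokens`, `text_transformer_args.block_size`, and the incoming `max_seq_length`:

```python
from typing import Optional


def clamp_max_seq_length(
    prompt_len: int,            # stands in for encoded.size(0)
    max_new_tokens: int,        # stands in for generator_args.max_new_tokens
    block_size: Optional[int],  # stands in for text_transformer_args.block_size
    max_seq_length: int,        # the previously computed upper bound
) -> int:
    # Mirror of the restored branch: take the tightest of the three bounds,
    # falling back to a 2048-token context window when block_size is unknown.
    return min(
        prompt_len + max_new_tokens,
        block_size if block_size is not None else 2048,
        max_seq_length,
    )


# A 100-token prompt requesting 200 new tokens stays capped at 300 tokens,
# even though both the default window (2048) and the prior bound (4096)
# would allow more.
assert clamp_max_seq_length(100, 200, None, 4096) == 300
```

The trailing context lines then pad the resulting bound by `speculate_k + 1`, which is why the restored branch has to run before the speculative-decoding adjustment.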