Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 83f8501

Browse files
committed
Merge branch 'unify-constuct-model' into llava-support
2 parents 6fbb460 + f224da7 commit 83f8501

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

torchchat/generate.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,12 @@ def chat(
812812

813813
elif not generator_args.is_torchtune_model:
814814
max_seq_length = min(
815-
encoded.size(0) + generator_args.max_new_tokens, max_seq_length
815+
encoded.size(0) + generator_args.max_new_tokens,
816+
(
817+
text_transformer_args.block_size
818+
if text_transformer_args is not None
819+
else 2048
820+
),
816821
)
817822

818823
max_seq_length = (

torchchat/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ class ModelArgs:
302302
transformer_args (Dict[str, Dict[str, Any]]): A dictionary containing the parameters for each transformer in the model.
303303
The outer dictionary has transformer names as keys and inner dictionaries as values. Each inner dictionary contains
304304
the parameter names and their corresponding values for the respective transformer.
305+
TODO: econcile Dict[str, Any] into tranformer-arg-family classes in future PRs.
306+
305307
use_tiktoken (bool): A flag indicating whether to use TikToken as the tokenizer for the model.
306308
Note:
307309
It is recommended to use factory functions to create instances of this class instead of directly using the constructor.
@@ -436,6 +438,9 @@ def __init__(self, config: ModelArgs) -> None:
436438
super().__init__()
437439
self.config = config
438440
self.model = self.build_model()
441+
442+
# text_transformer_args represents the args for the text transformer in the model.
443+
# It should be assigned in the actual model implementation, if any.
439444
self.text_transformer_args = None
440445

441446
def build_model(self) -> nn.Module:

0 commit comments

Comments
 (0)