2 files changed: +11, -1 lines
@@ -812,7 +812,12 @@ def chat(
 
     elif not generator_args.is_torchtune_model:
         max_seq_length = min(
-            encoded.size(0) + generator_args.max_new_tokens, max_seq_length
+            encoded.size(0) + generator_args.max_new_tokens,
+            (
+                text_transformer_args.block_size
+                if text_transformer_args is not None
+                else 2048
+            ),
         )
 
     max_seq_length = (
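The net effect of this hunk: for non-torchtune models, max_seq_length is now clamped to the text transformer's block_size (falling back to 2048 when no text transformer args are attached) instead of to the previous max_seq_length value. A minimal standalone sketch of that clamping logic, assuming hypothetical TransformerArgs and GeneratorArgs stand-ins for the real config classes:

from dataclasses import dataclass

import torch


@dataclass
class TransformerArgs:
    # Hypothetical stand-in: the real args class carries many more fields.
    block_size: int


@dataclass
class GeneratorArgs:
    # Hypothetical stand-in for the generator's CLI-derived arguments.
    max_new_tokens: int
    is_torchtune_model: bool = False


def clamp_max_seq_length(encoded, generator_args, text_transformer_args):
    # Prompt length plus the tokens we intend to generate...
    requested = encoded.size(0) + generator_args.max_new_tokens
    # ...clamped to what the text transformer can attend over, with 2048
    # as the conservative default when no text transformer args exist.
    limit = (
        text_transformer_args.block_size
        if text_transformer_args is not None
        else 2048
    )
    return min(requested, limit)


# A 16-token prompt asking for 4096 new tokens clamps to the default.
prompt = torch.zeros(16, dtype=torch.long)
print(clamp_max_seq_length(prompt, GeneratorArgs(max_new_tokens=4096), None))  # 2048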
@@ -170,6 +170,8 @@ class ModelArgs:
     transformer_args (Dict[str, Dict[str, Any]]): A dictionary containing the parameters for each transformer in the model.
         The outer dictionary has transformer names as keys and inner dictionaries as values. Each inner dictionary contains
         the parameter names and their corresponding values for the respective transformer.
+        TODO: reconcile Dict[str, Any] into transformer-arg-family classes in future PRs.
+
     use_tiktoken (bool): A flag indicating whether to use TikToken as the tokenizer for the model.
 Note:
     It is recommended to use factory functions to create instances of this class instead of directly using the constructor.
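For concreteness, a hypothetical transformer_args value matching the docstring's nested layout; the transformer names and parameter keys below are illustrative, not taken from the repository:

# Outer keys name each transformer; inner dicts hold its parameters.
transformer_args = {
    "text": {
        "block_size": 2048,
        "n_layers": 32,
        "n_heads": 32,
        "dim": 4096,
    },
    "encoder": {
        "block_size": 577,
        "n_layers": 24,
        "dim": 1024,
    },
}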
@@ -304,6 +306,9 @@ def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
         self.model = self.build_model()
+
+        # text_transformer_args holds the args for the text transformer in the model.
+        # It should be assigned by the concrete model implementation, if any.
         self.text_transformer_args = None
 
     def build_model(self) -> nn.Module:
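A sketch of the contract the new comment describes: a concrete subclass builds its module, then assigns text_transformer_args so generation code (like the chat() change above) can read block_size. The class and field names here are hypothetical, and SimpleNamespace stands in for whatever args object the real model attaches:

from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, Dict

import torch.nn as nn


@dataclass
class ModelArgs:
    # Simplified stand-in for the ModelArgs described above.
    transformer_args: Dict[str, Dict[str, Any]] = field(default_factory=dict)


class Model(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.model = self.build_model()
        # Assigned by concrete subclasses that have a text transformer.
        self.text_transformer_args = None

    def build_model(self) -> nn.Module:
        raise NotImplementedError


class TextOnlyModel(Model):  # hypothetical subclass
    def build_model(self) -> nn.Module:
        # Stand-in for the real transformer construction.
        args = self.config.transformer_args["text"]
        return nn.Embedding(args["vocab_size"], args["dim"])

    def __init__(self, config: ModelArgs) -> None:
        super().__init__(config)
        # Fulfil the contract: expose the text transformer's args so
        # generation code can read fields like block_size.
        self.text_transformer_args = SimpleNamespace(
            **self.config.transformer_args["text"]
        )


model = TextOnlyModel(
    ModelArgs(transformer_args={"text": {"vocab_size": 32000, "dim": 64, "block_size": 2048}})
)
print(model.text_transformer_args.block_size)  # -> 2048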