diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index 007fe0c8f..d35c8efbd 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -359,6 +359,12 @@ def _add_generation_args(parser, verb: str) -> None:
         default=1,
         help="Number of samples",
     )
+    generator_parser.add_argument(
+        "--accumulate-tokens",
+        type=int,
+        default=8,
+        help="Number of generated tokens to accumulate before calling the callback on each one of them.",
+    )
     generator_parser.add_argument(
         "--image-prompts",
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 53d9d8f8c..3d639bbe7 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -230,6 +230,7 @@ class GeneratorArgs:
     max_autotune: bool = False
     # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
+    accumulate_tokens: int = 8
 
     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
             is_torchtune_model=args.model and args.model.endswith("tune"),
+            accumulate_tokens=getattr(args, "accumulate_tokens", 8),
         )
 
 
@@ -530,11 +532,13 @@ def decode_n_tokens(
         need_probs: bool,
         batch=Optional[Dict[str, Any]],  # Inputs for multimodal models
         callback=lambda _: _,
+        accumulate_tokens: int = 8,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
+        new_tokens = []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
@@ -552,29 +556,52 @@ def decode_n_tokens(
                 **sampling_kwargs,
             )
             input_pos += 1
-            callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
+            new_tokens.append(next_token.clone())
+
+            done_generating = _i == num_new_tokens - 2
+            if need_probs:
+                callback(new_tokens[-1], done_generating=done_generating)
             if not need_probs or next_prob is None:
                 yield out_token, None
             else:
                 yield out_token, next_prob.clone()
             cur_token = next_token
 
-            # encountered eos
-            if next_token.item() == eos_token_id or (
-                eot_id is not None and next_token.item() == eot_id
-            ):
-                encountered_eos = True
-                final_token, next_prob = self.decode_one_token(
-                    model,
-                    cur_token,
-                    input_pos,
-                    need_probs,
-                    batch=batch,
-                    **sampling_kwargs,
-                )
-                input_pos += 1
-                yield cur_token.clone(), next_prob.clone()
-                break
+            if need_probs:
+                # encountered eos
+                if next_token.item() == eos_token_id or (
+                    eot_id is not None and next_token.item() == eot_id
+                ):
+                    encountered_eos = True
+                    final_token, next_prob = self.decode_one_token(
+                        model,
+                        cur_token,
+                        input_pos,
+                        need_probs,
+                        batch=batch,
+                        **sampling_kwargs,
+                    )
+                    input_pos += 1
+                    yield cur_token.clone(), next_prob.clone()
+                    break
+            else:
+                callback_pos = _i % accumulate_tokens + 1
+                if done_generating or callback_pos == accumulate_tokens:
+                    callback_num = min(accumulate_tokens, callback_pos)
+                    for i in range(callback_num, 0, -1):
+                        callback(new_tokens[-i], done_generating=done_generating)
+
+                        token_item = new_tokens[-i].item()
+                        # encountered eos
+                        if token_item == eos_token_id or (
+                            eot_id is not None and token_item == eot_id
+                        ):
+                            encountered_eos = True
+                            input_pos += 1
+                            yield new_tokens[-i].clone(), None
+                            break
+                    if encountered_eos:
+                        break
 
         if not encountered_eos:
             eos_token = torch.tensor(
@@ -681,6 +708,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
+        accumulate_tokens: int,
         max_seq_length: int,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
@@ -791,6 +819,7 @@ def generate(
                 max_new_tokens - 1,
                 batch=batch,
                 callback=callback,
+                accumulate_tokens=accumulate_tokens,
                 need_probs=False,
                 eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
                 eot_id=(
@@ -1179,6 +1208,7 @@ def callback(x, *, done_generating=False):
             chat_mode=generator_args.chat_mode,
             batch=batch,
             callback=callback,
+            accumulate_tokens=generator_args.accumulate_tokens,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
             sequential_prefill=generator_args.sequential_prefill,
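
The interesting piece of the patch is the windowing arithmetic in decode_n_tokens: when need_probs is False, tokens are buffered in new_tokens and the callback fires in bursts of accumulate_tokens, with a shorter final burst when generation finishes or EOS is hit. The following is a minimal standalone sketch of just that flush logic, not torchchat code: plain ints stand in for tensors, and appending to a `flushed` list stands in for the real callback, so the invariant is easy to check.

# Standalone sketch of the flush arithmetic from decode_n_tokens
# (illustrative; ints replace tensors, flushed.append replaces callback).
accumulate_tokens = 4
eos_token_id = 2
tokens = [7, 5, 9, 3, 8, 6, 1, 4, 7, 5]  # pretend decode_one_token outputs
num_new_tokens = len(tokens) + 1          # mirrors the `num_new_tokens - 1` loop bound

new_tokens, flushed = [], []
encountered_eos = False
for _i in range(num_new_tokens - 1):
    new_tokens.append(tokens[_i])
    done_generating = _i == num_new_tokens - 2

    # 1-based position inside the current accumulation window.
    callback_pos = _i % accumulate_tokens + 1
    if done_generating or callback_pos == accumulate_tokens:
        # Flush the buffered window oldest-first: new_tokens[-i] walks from
        # the start of the window (i == callback_num) down to its end (i == 1).
        callback_num = min(accumulate_tokens, callback_pos)
        for i in range(callback_num, 0, -1):
            flushed.append(new_tokens[-i])
            if new_tokens[-i] == eos_token_id:
                encountered_eos = True
                break
        if encountered_eos:
            break

# With accumulate_tokens=4 and 10 tokens, the callback fires in bursts of
# 4, 4, and 2 -- yet every token is delivered exactly once, in order.
assert flushed == tokens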
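
Usage note (the model alias and prompt are illustrative, not from the patch): with this change, streamed output arrives in bursts of up to --accumulate-tokens tokens, e.g. `python3 torchchat.py generate llama3.1 --prompt "once upon a time" --accumulate-tokens 16`. Passing `--accumulate-tokens 1` restores per-token callbacks, since `_i % 1 + 1` flushes the window on every iteration. The apparent motivation is to keep per-token callback overhead (detokenization and printing) off the hot decode path. One behavioral difference worth noting: on the accumulate path, EOS ends generation immediately and yields `(token, None)`, without the extra decode_one_token call that the preserved need_probs branch still performs.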