@@ -535,7 +535,6 @@ def decode_n_tokens(
535535 attention_backend : SDPBackend = torch .nn .attention .SDPBackend .MATH ,
536536 ** sampling_kwargs ,
537537 ):
538- new_tokens , new_probs = [], []
539538 encountered_eos = False
540539 for _i in range (
541540 num_new_tokens - 1
@@ -553,12 +552,10 @@ def decode_n_tokens(
553552 ** sampling_kwargs ,
554553 )
555554 input_pos += 1
556- new_tokens .append (next_token .clone ())
557- callback (new_tokens [- 1 ], done_generating = _i == num_new_tokens - 2 )
558- if need_probs or next_prob is None :
555+ callback (next_token .clone (), done_generating = _i == num_new_tokens - 2 )
556+ if not need_probs or next_prob is None :
559557 yield out_token , None
560558 else :
561- new_probs .append (next_prob .clone ())
562559 yield out_token , next_prob .clone ()
563560 cur_token = next_token
564561
@@ -585,7 +582,6 @@ def decode_n_tokens(
585582 dtype = cur_token .dtype ,
586583 device = cur_token .device ,
587584 )
588- new_tokens .append (eos_token .clone ())
589585 eos_token , next_prob = self .decode_one_token (
590586 model ,
591587 eos_token .view (1 , - 1 ),
@@ -788,7 +784,6 @@ def generate(
788784 input_pos = input_pos + num_added
789785 next_token = next_tokens [- 1 ]
790786 else :
791- generated_tokens = []
792787 for generated_token , _ in self .decode_n_tokens (
793788 model ,
794789 next_token ,
@@ -806,7 +801,6 @@ def generate(
806801 attention_backend = attention_backend ,
807802 ** sampling_kwargs ,
808803 ):
809- generated_tokens .append (generated_token .view (- 1 ))
810804 yield generated_token , None
811805
812806 generate_stats = {