import torch

-from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.extension.llm.tokenizer.utils import get_tokenizer

@@ -47,11 +46,35 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:


class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
-        self.params = model_args
+    def __init__(
+        self,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        use_kv_cache: bool,
+        vocab_size: int,
+        has_full_logits: bool = False,
+        device: str = "cpu",
+    ):
+        """
+        Constructor.
+
+        Args:
+            tokenizer_path: path to the tokenizer.model file.
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            has_full_logits: whether the model returns the full logits or only the last logit.
+            device: device to run the runner on.
+        """
+        self.max_seq_len = max_seq_len
+        self.max_batch_size = max_batch_size
+        self.use_kv_cache = use_kv_cache
        self.tokenizer = get_tokenizer(tokenizer_path)
-        assert model_args.vocab_size == self.tokenizer.n_words
+        self.has_full_logits = has_full_logits
        self.device = device
+        assert vocab_size == self.tokenizer.n_words

    @abstractmethod
    def forward(
@@ -75,17 +98,20 @@ def generate( # noqa: C901
            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
            input_pos=(
                torch.tensor([pos_base], dtype=torch.long, device=self.device)
-                if self.params.use_kv_cache
+                if self.use_kv_cache
                else None
            ),
        )

-        current_token = next_token(logits, temperature, top_p)
+        if self.has_full_logits:
+            current_token = next_token(logits[:, -1, :], temperature, top_p)
+        else:
+            current_token = next_token(logits, temperature, top_p)
        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        tokens = prompt_tokens + [current_token]

        while len(tokens) < max_seq_len:
-            if self.params.use_kv_cache:
+            if self.use_kv_cache:
                logits = self.forward(
                    tokens=torch.tensor(
                        [[current_token]], dtype=torch.long, device=self.device
@@ -100,13 +126,20 @@ def generate( # noqa: C901
                logits = self.forward(
                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                )
-            current_token = next_token(logits, temperature, top_p)
+
+            # If the logits aren't already clipped to only contain the last logit, clip them.
+            if self.has_full_logits:
+                current_token = next_token(logits[:, -1, :], temperature, top_p)
+            else:
+                current_token = next_token(logits, temperature, top_p)
            tokens.append(current_token)
+
            if current_token == self.tokenizer.eos_id or (
                hasattr(self.tokenizer, "stop_tokens")
                and current_token in self.tokenizer.stop_tokens
            ):
                break
+
            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        print("\n")

@@ -136,7 +169,7 @@ def text_completion(
        """
        return self.generate(
            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
-            max_seq_len=self.params.max_seq_len,
+            max_seq_len=self.max_seq_len,
            temperature=temperature,
            top_p=top_p,
            echo=echo,
@@ -171,7 +204,7 @@ def chat_completion(
            prompt_tokens=self.tokenizer.encode(
                self._format_prompt(prompt), bos=True, eos=False
            ),
-            max_seq_len=self.params.max_seq_len,
+            max_seq_len=self.max_seq_len,
            temperature=temperature,
            top_p=top_p,
            echo=True,
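
For orientation, here is a minimal sketch of how a concrete runner might construct the base class after this change, passing the individual model parameters instead of a ModelArgs object. The subclass name, parameter values, paths, and the forward() body are assumptions for illustration only, not part of this PR.

# Hypothetical subclass; names, values, and the forward() body are illustrative assumptions.
class MyRunner(LlamaRunner):
    def __init__(self, tokenizer_path: str):
        super().__init__(
            tokenizer_path=tokenizer_path,
            max_seq_len=128,
            max_batch_size=1,
            use_kv_cache=True,
            vocab_size=32000,          # must equal tokenizer.n_words (asserted in __init__)
            has_full_logits=False,     # this model already returns only the last-position logit
        )

    def forward(self, tokens, input_pos=None):
        ...  # invoke the exported model here and return its logits

# runner = MyRunner("tokenizer.model")
# runner.text_completion(prompt="Hello", temperature=0.6, top_p=0.9)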
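A small sketch of the has_full_logits handling in generate(): when a model emits logits for every sequence position, only the final position is needed to sample the next token, hence the logits[:, -1, :] slice in the hunks above. The tensor shapes below are assumptions for illustration.

import torch

full_logits = torch.randn(1, 5, 32000)   # assumed [batch, seq_len, vocab_size] output
last_logits = full_logits[:, -1, :]      # [1, 32000]: logits for the final position only
# next_token(last_logits, temperature, top_p) then samples a single token id from this slice.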