@@ -187,6 +187,45 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
+    /// Get the logits for the last token in the context.
+    ///
+    /// # Returns
+    /// An iterator over unsorted `LlamaTokenData` containing the
+    /// logits for the last token in the context.
+    ///
+    /// # Panics
+    ///
+    /// - underlying logits data is null
+    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
+        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
+            let token = LlamaToken::new(i);
+            LlamaTokenData::new(token, *logit, 0_f32)
+        })
+    }
+
+    /// Token logits obtained from the last call to `decode()`.
+    /// The logits for which `batch.logits[i] != 0` are stored contiguously
+    /// in the order they appeared in the batch.
+    /// Rows: number of tokens for which `batch.logits[i] != 0`
+    /// Cols: `n_vocab`
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the logits for the last decoded token.
+    /// The size corresponds to the `n_vocab` parameter of the context's model.
+    ///
+    /// # Panics
+    ///
+    /// - `n_vocab` does not fit into a `usize`
+    /// - token data returned is null
+    pub fn get_logits(&self) -> &[f32] {
+        let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
+        assert!(!data.is_null(), "logits data for last token is null");
+        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
+
+        unsafe { slice::from_raw_parts(data, len) }
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics
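For orientation, here is a rough sketch (not part of the commit) of how the new `candidates()` iterator might be consumed for greedy sampling after a `decode()` call. It assumes `ctx` is a `LlamaContext` whose last decoded batch requested logits for its final token, and that `LlamaTokenData` exposes `id()` and `logit()` accessors, which are not shown in this diff.

// Hypothetical usage sketch, not part of the diff above.
// Picks the highest-logit token among the candidates for the last decoded token.
let next_token = ctx
    .candidates()
    .max_by(|a, b| a.logit().total_cmp(&b.logit()))
    .map(|data| data.id())
    .expect("model vocabulary is non-empty");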