@@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
+
 use std::ffi::CString;
 use std::io::Write;
 use std::num::NonZeroU32;
@@ -174,9 +176,9 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;

     // initialize the context
-    let mut ctx_params = LlamaContextParams::default()
-        .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
-        .with_seed(seed.unwrap_or(1234));
+    let mut ctx_params =
+        LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())));
+
     if let Some(threads) = threads {
         ctx_params = ctx_params.with_n_threads(threads);
     }
@@ -244,31 +246,31 @@ either reduce n_len or increase n_ctx"
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();

+    let sampler_params = LlamaSamplerChainParams::default();
+    let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));
+
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates();
-
-            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);

-            // sample the most likely token
-            let new_token_id = ctx.sample_token_greedy(candidates_p);
+            sampler.accept(token);

             // is it an end of stream?
-            if model.is_eog_token(new_token_id) {
+            if model.is_eog_token(token) {
                 eprintln!();
                 break;
             }

-            let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?;
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
             // use `Decoder.decode_to_string()` to avoid the intermediate buffer
             let mut output_string = String::with_capacity(32);
             let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
             print!("{output_string}");
             std::io::stdout().flush()?;

             batch.clear();
-            batch.add(new_token_id, n_cur, &[0], true)?;
+            batch.add(token, n_cur, &[0], true)?;
         }

         n_cur += 1;
0 commit comments