/*
git clone --recursive https://github.com/utilityai/llama-cpp-rs
cd llama-cpp-rs/examples/usage
wget https://huggingface.co/Qwen/Qwen2-1.5B-Instruct-GGUF/resolve/main/qwen2-1_5b-instruct-q4_0.gguf
cargo run qwen2-1_5b-instruct-q4_0.gguf
*/
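// Assumed dependencies: the `llama_cpp_2` and `encoding_rs` crates used below;
// their versions are pinned in the example's Cargo.toml in the cloned repository.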
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use std::io::Write;

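// Minimal greedy-decoding example: load a GGUF model, tokenize a Qwen2 chat
// prompt, and stream the completion to stdout until EOS or `n_len` total tokens.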
fn main() {
    let model_path = std::env::args().nth(1).expect("Please specify model path");
    let backend = LlamaBackend::init().unwrap();
    let params = LlamaModelParams::default();

    // Qwen2 chat template: a single user turn followed by the assistant tag to complete.
    let prompt =
        "<|im_start|>user\nHello! how are you?<|im_end|>\n<|im_start|>assistant\n".to_string();
    let model =
        LlamaModel::load_from_file(&backend, model_path, &params).expect("unable to load model");
    let ctx_params = LlamaContextParams::default();
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .expect("unable to create the llama_context");
    let tokens_list = model
        .str_to_token(&prompt, AddBos::Always)
        .expect(&format!("failed to tokenize {prompt}"));
    // total sequence length (prompt + completion) to generate
    let n_len = 64;

    // create a llama_batch with size 512
    // we use this object to submit token data for decoding
    let mut batch = LlamaBatch::new(512, 1);

    let last_index: i32 = (tokens_list.len() - 1) as i32;
    for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
        // llama_decode will output logits only for the last token of the prompt
        let is_last = i == last_index;
        batch.add(token, i, &[0], is_last).unwrap();
    }
    ctx.decode(&mut batch).expect("llama_decode() failed");

    let mut n_cur = batch.n_tokens();

    // streaming UTF-8 decoder; it buffers multi-byte characters that are
    // split across token boundaries
    let mut decoder = encoding_rs::UTF_8.new_decoder();

    while n_cur <= n_len {
        // sample the next token
        {
            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);

            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

            // sample the most likely token
            let new_token_id = ctx.sample_token_greedy(candidates_p);

            // is it an end of stream?
            if new_token_id == model.token_eos() {
                eprintln!();
                break;
            }

            let output_bytes = model
                .token_to_bytes(new_token_id, Special::Tokenize)
                .unwrap();
            // use `Decoder.decode_to_string()` to avoid the intermediate buffer
            let mut output_string = String::with_capacity(32);
            let _decode_result =
                decoder.decode_to_string(&output_bytes, &mut output_string, false);
            print!("{output_string}");
            std::io::stdout().flush().unwrap();

            // feed the sampled token back so the next decode yields its logits
            batch.clear();
            batch.add(new_token_id, n_cur, &[0], true).unwrap();
        }

        n_cur += 1;

        // evaluate the new token so the next iteration can sample from its logits
        ctx.decode(&mut batch).expect("failed to eval");
    }
}