
Commit c631133

updated type for RopeScalingType + fmt
1 parent e94e54b commit c631133

6 files changed: 53 additions & 38 deletions

llama-cpp-2/examples/simple.rs

Lines changed: 24 additions & 19 deletions
@@ -1,21 +1,20 @@
 //! This is an translation of simple.cpp in llama.cpp using llama-cpp-2.
 #![allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
 
-use std::io::Write;
-use std::num::NonZeroU32;
-use std::path::PathBuf;
-use std::time::Duration;
+use anyhow::{bail, Context, Result};
 use clap::Parser;
 use llama_cpp_2::context::params::LlamaContextParams;
-use llama_cpp_2::llama_backend::LlamaBackend;
-use llama_cpp_2::model::LlamaModel;
-use llama_cpp_2::model::params::LlamaModelParams;
-use anyhow::{bail, Context, Result};
 use llama_cpp_2::ggml_time_us;
+use llama_cpp_2::llama_backend::LlamaBackend;
 use llama_cpp_2::llama_batch::LlamaBatch;
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::AddBos;
-
+use llama_cpp_2::model::LlamaModel;
+use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use std::io::Write;
+use std::num::NonZeroU32;
+use std::path::PathBuf;
+use std::time::Duration;
 
 #[derive(clap::Parser)]
 struct Args {
@@ -30,7 +29,6 @@ struct Args {
     disable_gpu: bool,
 }
 
-
 fn main() -> Result<()> {
     let params = Args::parse();
 
@@ -60,12 +58,14 @@ fn main() -> Result<()> {
         .with_n_ctx(NonZeroU32::new(2048))
         .with_seed(1234);
 
-    let mut ctx = model.new_context(&backend, ctx_params)
+    let mut ctx = model
+        .new_context(&backend, ctx_params)
         .with_context(|| "unable to create the llama_context")?;
 
     // tokenize the prompt
 
-    let tokens_list = model.str_to_token(&params.prompt, AddBos::Always)
+    let tokens_list = model
+        .str_to_token(&params.prompt, AddBos::Always)
         .with_context(|| format!("failed to tokenize {}", params.prompt))?;
 
     let n_cxt = ctx.n_ctx() as i32;
@@ -75,8 +75,10 @@ fn main() -> Result<()> {
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if n_kv_req > n_cxt {
-        bail!("n_kv_req > n_ctx, the required kv cache size is not big enough
-either reduce n_len or increase n_ctx")
+        bail!(
+            "n_kv_req > n_ctx, the required kv cache size is not big enough
+either reduce n_len or increase n_ctx"
+        )
     }
 
     // print the prompt token-by-token
@@ -137,7 +139,6 @@ either reduce n_len or increase n_ctx")
         ctx.decode(&mut batch).with_context(|| "failed to eval")?;
 
         n_decode += 1;
-
     }
 
     eprintln!("\n");
@@ -146,10 +147,14 @@ either reduce n_len or increase n_ctx")
 
     let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
 
-    eprintln!("decoded {} tokens in {:.2} s, speed {:.2} t/s\n", n_decode, duration.as_secs_f32(), n_decode as f32 / duration.as_secs_f32());
+    eprintln!(
+        "decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
+        n_decode,
+        duration.as_secs_f32(),
+        n_decode as f32 / duration.as_secs_f32()
+    );
 
     println!("{}", ctx.timings());
 
     Ok(())
-
-}
+}

llama-cpp-2/src/context/params.rs

Lines changed: 4 additions & 4 deletions
@@ -19,8 +19,8 @@ pub enum RopeScalingType {
 
 /// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
 /// the value is not recognized.
-impl From<i8> for RopeScalingType {
-    fn from(value: i8) -> Self {
+impl From<i32> for RopeScalingType {
+    fn from(value: i32) -> Self {
         match value {
             0 => Self::None,
             1 => Self::Linear,
@@ -31,7 +31,7 @@ impl From<i8> for RopeScalingType {
 }
 
 /// Create a `c_int` from a `RopeScalingType`.
-impl From<RopeScalingType> for i8 {
+impl From<RopeScalingType> for i32 {
     fn from(value: RopeScalingType) -> Self {
         match value {
             RopeScalingType::None => 0,
@@ -172,7 +172,7 @@ impl LlamaContextParams {
     /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
     /// ```
     pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
-        self.context_params.rope_scaling_type = i8::from(rope_scaling_type);
+        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
     }
 
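For context, a minimal caller-side sketch of the widened conversion. It assumes only what the hunks above show: the From<i32> pair, the with_rope_scaling_type builder, and the doctest's rope_scaling_type() getter; LlamaContextParams::default() is assumed from that doctest, so treat this as illustrative rather than documented API.

use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};

fn main() {
    // Round-trip through the FFI-sized integer: the field is now an i32
    // (matching llama.cpp's c_int) instead of the old i8.
    let raw: i32 = i32::from(RopeScalingType::Linear);
    assert_eq!(RopeScalingType::from(raw), RopeScalingType::Linear);

    // Builder usage, mirroring the doctest visible in the last hunk.
    let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);
    assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
}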

llama-cpp-2/src/llama_batch.rs

Lines changed: 12 additions & 5 deletions
@@ -44,8 +44,10 @@ impl LlamaBatch {
         seq_ids: &[i32],
         logits: bool,
     ) -> Result<(), BatchAddError> {
-        if self.allocated < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize") {
-            return Err(BatchAddError::InsufficientSpace(self.allocated))
+        if self.allocated
+            < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize")
+        {
+            return Err(BatchAddError::InsufficientSpace(self.allocated));
         }
         let offset = self.llama_batch.n_tokens;
         let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
@@ -55,8 +57,10 @@ impl LlamaBatch {
             // batch.pos [batch.n_tokens] = pos,
             self.llama_batch.pos.add(offset_usize).write(pos);
             // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
-            self.llama_batch.n_seq_id.add(offset_usize).write(llama_seq_id::try_from(seq_ids.len())
-                .expect("cannot fit seq_ids.len() into a llama_seq_id"));
+            self.llama_batch.n_seq_id.add(offset_usize).write(
+                llama_seq_id::try_from(seq_ids.len())
+                    .expect("cannot fit seq_ids.len() into a llama_seq_id"),
+            );
             // for (size_t i = 0; i < seq_ids.size(); ++i) {
             //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
             // }
@@ -65,7 +69,10 @@ impl LlamaBatch {
                 tmp.add(i).write(*seq_id);
             }
             // batch.logits [batch.n_tokens] = logits;
-            self.llama_batch.logits.add(offset_usize).write(i8::from(logits));
+            self.llama_batch
+                .logits
+                .add(offset_usize)
+                .write(i8::from(logits));
         }
 
         if logits {
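A rough caller-side sketch of the add method reformatted above. The seq_ids and logits parameters and the BatchAddError result are taken from the hunk; the leading token and position arguments, the single sequence id 0, and the use of anyhow for error propagation (assuming BatchAddError implements std::error::Error) are assumptions and only illustrative.

use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::token::LlamaToken;

// Push a prompt into the batch, requesting logits only for the final token.
fn fill_batch(batch: &mut LlamaBatch, tokens: &[LlamaToken]) -> anyhow::Result<()> {
    let last_index = tokens.len() as i32 - 1;
    for (i, token) in (0_i32..).zip(tokens.iter()) {
        // The (token, pos, seq_ids, logits) argument order is assumed here.
        batch.add(*token, i, &[0], i == last_index)?;
    }
    Ok(())
}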

llama-cpp-2/src/model.rs

Lines changed: 1 addition & 3 deletions
@@ -126,7 +126,7 @@ impl LlamaModel {
     ) -> Result<Vec<LlamaToken>, StringToTokenError> {
         let add_bos = match add_bos {
             AddBos::Always => true,
-            AddBos::Never => false
+            AddBos::Never => false,
         };
 
         let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
@@ -136,8 +136,6 @@ impl LlamaModel {
         let buffer_capacity =
             c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
 
-
-
         let size = unsafe {
             llama_cpp_sys_2::llama_tokenize(
                 self.model.as_ptr(),

llama-cpp-2/src/token.rs

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ pub mod data_array;
 #[repr(transparent)]
 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
 #[allow(clippy::module_name_repetitions)]
-pub struct LlamaToken( pub llama_cpp_sys_2::llama_token);
+pub struct LlamaToken(pub llama_cpp_sys_2::llama_token);
 
 impl Display for LlamaToken {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {

llama-cpp-sys-2/build.rs

Lines changed: 11 additions & 6 deletions
@@ -1,26 +1,32 @@
 use std::env;
-use std::path::PathBuf;
 use std::path::Path;
+use std::path::PathBuf;
 
 fn main() {
     println!("cargo:rerun-if-changed=llama.cpp");
 
     let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();
 
     if !Path::new("llama.cpp/ggml.c").exists() {
-        panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
+        panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
     }
 
     let mut ggml = cc::Build::new();
-    let mut ggml_cuda = if cublas_enabled { Some(cc::Build::new()) } else { None };
+    let mut ggml_cuda = if cublas_enabled {
+        Some(cc::Build::new())
+    } else {
+        None
+    };
     let mut llama_cpp = cc::Build::new();
 
     ggml.cpp(false);
     llama_cpp.cpp(true);
 
     // https://github.com/ggerganov/llama.cpp/blob/a836c8f534ab789b02da149fbdaf7735500bff74/Makefile#L364-L368
     if let Some(ggml_cuda) = &mut ggml_cuda {
-        for lib in ["cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt"] {
+        for lib in [
+            "cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
+        ] {
             println!("cargo:rustc-link-lib={}", lib);
         }
 
@@ -66,8 +72,7 @@ fn main() {
         ggml.define("_GNU_SOURCE", None);
     }
 
-    ggml
-        .std("c17")
+    ggml.std("c17")
         .file("llama.cpp/ggml.c")
         .file("llama.cpp/ggml-alloc.c")
         .file("llama.cpp/ggml-backend.c")
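As an aside on the feature gate at the top of this file: Cargo exposes enabled features to build scripts as CARGO_FEATURE_<NAME> environment variables, which is why the optional CUDA build (ggml_cuda) is toggled by checking CARGO_FEATURE_CUBLAS. A minimal sketch of that pattern, independent of this crate:

use std::env;

fn main() {
    // Cargo sets CARGO_FEATURE_CUBLAS when the crate is built with `--features cublas`.
    let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

    if cublas_enabled {
        // Emit the same kind of link directives the loop in the first hunk produces.
        println!("cargo:rustc-link-lib=cublas");
        println!("cargo:rustc-link-lib=cudart");
    }
}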
