Skip to content

Commit 818ec2b

Browse files
committed
improved api
1 parent 667d612 commit 818ec2b

File tree

4 files changed

+25
-52
lines changed

4 files changed

+25
-52
lines changed

llama-cpp-2/src/model.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,10 @@ impl LlamaModel {
277277
///
278278
/// See [`LlamaModelLoadError`] for more information.
279279
#[tracing::instrument(skip_all)]
280-
pub fn load_from_file<T>(
280+
pub fn load_from_file(
281281
_: &LlamaBackend,
282282
path: impl AsRef<Path>,
283-
params: &LlamaModelParams<T>,
283+
params: &LlamaModelParams,
284284
) -> Result<Self, LlamaModelLoadError> {
285285
let path = path.as_ref();
286286
debug_assert!(Path::new(path).exists(), "{path:?} does not exist");

llama-cpp-2/src/model/params.rs

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,12 @@ pub mod kv_overrides;
1313
/// [`T`] is the type of the backing storage for the key-value overrides. Generally it can be left to [`()`] which will
1414
/// make your life with the borrow checker much easier.
1515
#[allow(clippy::module_name_repetitions)]
16-
pub struct LlamaModelParams<T> {
16+
pub struct LlamaModelParams {
1717
pub(crate) params: llama_cpp_sys_2::llama_model_params,
18-
kv_overrides: T,
18+
kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
1919
}
2020

21-
impl Debug for LlamaModelParams<()> {
22-
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23-
f.debug_struct("LlamaModelParams")
24-
.field("n_gpu_layers", &self.params.n_gpu_layers)
25-
.field("main_gpu", &self.params.main_gpu)
26-
.field("vocab_only", &self.params.vocab_only)
27-
.field("use_mmap", &self.params.use_mmap)
28-
.field("use_mlock", &self.params.use_mlock)
29-
.field("kv_overrides", &self.kv_overrides)
30-
.finish()
31-
}
32-
}
33-
34-
impl Debug for LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>> {
21+
impl Debug for LlamaModelParams {
3522
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
3623
f.debug_struct("LlamaModelParams")
3724
.field("n_gpu_layers", &self.params.n_gpu_layers)
@@ -44,36 +31,14 @@ impl Debug for LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>> {
4431
}
4532
}
4633

47-
impl LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>> {
48-
/// Creates a new `LlamaModelParams` with the default parameters and a backing storage. As this struct will be
49-
/// self-referential, it cannot be moved and thus is pinned.
50-
///
51-
/// # Examples
52-
/// ```rust
53-
/// # use llama_cpp_2::model::params::LlamaModelParams;
54-
/// let params = LlamaModelParams::new_with_kv_overrides(LlamaModelParams::default());
55-
/// ```
56-
#[must_use]
57-
pub fn new_with_kv_overrides(params: LlamaModelParams<()>) -> Pin<Box<Self>> {
58-
Box::pin(Self {
59-
params: params.params,
60-
kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
61-
key: [0; 128],
62-
tag: 0,
63-
__bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
64-
int_value: 0,
65-
},
66-
}],
67-
})
68-
}
69-
34+
impl LlamaModelParams {
7035
/// See [`KvOverrides`]
7136
///
7237
/// # Examples
7338
///
7439
/// ```rust
7540
/// # use llama_cpp_2::model::params::LlamaModelParams;
76-
/// let params = LlamaModelParams::new_with_kv_overrides(LlamaModelParams::default());
41+
/// let params = Box::pin(LlamaModelParams::default());
7742
/// let kv_overrides = params.kv_overrides();
7843
/// let count = kv_overrides.into_iter().count();
7944
/// assert_eq!(count, 0);
@@ -83,15 +48,15 @@ impl LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>> {
8348
KvOverrides::new(self)
8449
}
8550

86-
/// Appends a key-value override to the model parameters.
51+
/// Appends a key-value override to the model parameters. It must be pinned as this creates a self-referential struct.
8752
///
8853
/// # Examples
8954
///
9055
/// ```rust
9156
/// # use std::ffi::{CStr, CString};
9257
/// # use llama_cpp_2::model::params::LlamaModelParams;
9358
/// # use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
94-
/// let mut params = LlamaModelParams::new_with_kv_overrides(LlamaModelParams::default());
59+
/// let mut params = Box::pin(LlamaModelParams::default());
9560
/// let key = CString::new("key").expect("CString::new failed");
9661
/// params.append_kv_override(&key, ParamOverrideValue::Int(50));
9762
///
@@ -128,7 +93,8 @@ impl LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>> {
12893
self.params.kv_overrides = null();
12994

13095
// push the next one to ensure we maintain the iterator invariant of ending with a 0
131-
self.kv_overrides.push(llama_cpp_sys_2::llama_model_kv_override {
96+
self.kv_overrides
97+
.push(llama_cpp_sys_2::llama_model_kv_override {
13298
key: [0; 128],
13399
tag: 0,
134100
__bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
@@ -143,7 +109,7 @@ impl LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>> {
143109
}
144110
}
145111

146-
impl<T> LlamaModelParams<T> {
112+
impl LlamaModelParams {
147113
/// Get the number of layers to offload to the GPU.
148114
#[must_use]
149115
pub fn n_gpu_layers(&self) -> i32 {
@@ -222,12 +188,18 @@ impl<T> LlamaModelParams<T> {
222188
/// assert_eq!(params.use_mmap(), true, "use_mmap should be true");
223189
/// assert_eq!(params.use_mlock(), false, "use_mlock should be false");
224190
/// ```
225-
impl Default for LlamaModelParams<()> {
191+
impl Default for LlamaModelParams {
226192
fn default() -> Self {
227193
let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
228194
LlamaModelParams {
229195
params: default_params,
230-
kv_overrides: (),
196+
kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
197+
key: [0; 128],
198+
tag: 0,
199+
__bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
200+
int_value: 0,
201+
},
202+
}],
231203
}
232204
}
233205
}

llama-cpp-2/src/model/params/kv_overrides.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue {
6767
/// A struct implementing [`IntoIterator`] over the key-value overrides for a model.
6868
#[derive(Debug)]
6969
pub struct KvOverrides<'a> {
70-
model_params: &'a LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>>,
70+
model_params: &'a LlamaModelParams,
7171
}
7272

7373
impl KvOverrides<'_> {
7474
pub(super) fn new(
75-
model_params: &LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>>,
75+
model_params: &LlamaModelParams,
7676
) -> KvOverrides {
7777
KvOverrides { model_params }
7878
}
@@ -95,7 +95,7 @@ impl<'a> IntoIterator for KvOverrides<'a> {
9595
/// An iterator over the key-value overrides for a model.
9696
#[derive(Debug)]
9797
pub struct KvOverrideValueIterator<'a> {
98-
model_params: &'a LlamaModelParams<Vec<llama_cpp_sys_2::llama_model_kv_override>>,
98+
model_params: &'a LlamaModelParams,
9999
current: usize,
100100
}
101101

simple/src/main.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ fn main() -> Result<()> {
120120
LlamaModelParams::default()
121121
};
122122

123-
let mut model_params = LlamaModelParams::new_with_kv_overrides(model_params);
123+
let mut model_params = Box::pin(model_params);
124+
124125
for (k, v) in key_value_overrides.iter() {
125126
let k = CString::new(k.as_bytes()).with_context(|| format!("invalid key: {}", k))?;
126127
model_params.append_kv_override(k.as_c_str(), *v);

0 commit comments

Comments
 (0)