Skip to content

Commit 9163649

Browse files
authored
Merge pull request #120 from utilityai/kv-overrides
override model values
2 parents 1a4493d + 91081e7 commit 9163649

File tree

3 files changed

+271
-4
lines changed

3 files changed

+271
-4
lines changed

llama-cpp-2/src/model/params.rs

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,113 @@
11
//! A safe wrapper around `llama_model_params`.
22
3-
use std::fmt::Debug;
3+
use crate::model::params::kv_overrides::KvOverrides;
4+
use std::ffi::{c_char, CStr};
5+
use std::fmt::{Debug, Formatter};
6+
use std::pin::Pin;
7+
use std::ptr::null;
8+
9+
pub mod kv_overrides;
410

511
/// A safe wrapper around `llama_model_params`.
12+
///
13+
/// The key-value overrides are stored in a vector owned by this struct; `params.kv_overrides` is pointed at that
14+
/// vector by [`LlamaModelParams::append_kv_override`], which is why that method requires the value to be pinned.
615
#[allow(clippy::module_name_repetitions)]
7-
#[derive(Debug)]
816
pub struct LlamaModelParams {
917
pub(crate) params: llama_cpp_sys_2::llama_model_params,
18+
kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
19+
}
20+
21+
impl Debug for LlamaModelParams {
22+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23+
f.debug_struct("LlamaModelParams")
24+
.field("n_gpu_layers", &self.params.n_gpu_layers)
25+
.field("main_gpu", &self.params.main_gpu)
26+
.field("vocab_only", &self.params.vocab_only)
27+
.field("use_mmap", &self.params.use_mmap)
28+
.field("use_mlock", &self.params.use_mlock)
29+
.field("kv_overrides", &"vec of kv_overrides")
30+
.finish()
31+
}
32+
}
33+
34+
impl LlamaModelParams {
35+
/// See [`KvOverrides`]
36+
///
37+
/// # Examples
38+
///
39+
/// ```rust
40+
/// # use llama_cpp_2::model::params::LlamaModelParams;
41+
/// let params = Box::pin(LlamaModelParams::default());
42+
/// let kv_overrides = params.kv_overrides();
43+
/// let count = kv_overrides.into_iter().count();
44+
/// assert_eq!(count, 0);
45+
/// ```
46+
#[must_use]
47+
pub fn kv_overrides(&self) -> KvOverrides {
48+
KvOverrides::new(self)
49+
}
50+
51+
/// Appends a key-value override to the model parameters. It must be pinned as this creates a self-referential struct.
52+
///
53+
/// # Examples
54+
///
55+
/// ```rust
56+
/// # use std::ffi::{CStr, CString};
57+
/// use std::pin::pin;
58+
/// # use llama_cpp_2::model::params::LlamaModelParams;
59+
/// # use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
60+
/// let mut params = pin!(LlamaModelParams::default());
61+
/// let key = CString::new("key").expect("CString::new failed");
62+
/// params.append_kv_override(&key, ParamOverrideValue::Int(50));
63+
///
64+
/// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
65+
/// assert_eq!(kv_overrides.len(), 1);
66+
///
67+
/// let (k, v) = &kv_overrides[0];
68+
/// assert_eq!(v, &ParamOverrideValue::Int(50));
69+
///
70+
/// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k);
71+
/// ```
72+
#[allow(clippy::missing_panics_doc)] // panics are just to enforce internal invariants, not user errors
73+
pub fn append_kv_override(
74+
self: &mut Pin<&mut Self>,
75+
key: &CStr,
76+
value: kv_overrides::ParamOverrideValue,
77+
) {
78+
let kv_override = self
79+
.kv_overrides
80+
.get_mut(0)
81+
.expect("kv_overrides did not have a next allocated");
82+
83+
assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
84+
85+
// There should be some way to do this without iterating over everything.
86+
for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
87+
kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
88+
}
89+
90+
kv_override.tag = value.tag();
91+
kv_override.__bindgen_anon_1 = value.value();
92+
93+
// set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
94+
self.params.kv_overrides = null();
95+
96+
// push the next one to ensure we maintain the iterator invariant of ending with a 0
97+
self.kv_overrides
98+
.push(llama_cpp_sys_2::llama_model_kv_override {
99+
key: [0; 128],
100+
tag: 0,
101+
__bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
102+
int_value: 0,
103+
},
104+
});
105+
106+
// set the pointer to the (potentially) new vector
107+
self.params.kv_overrides = self.kv_overrides.as_ptr();
108+
109+
eprintln!("saved ptr: {:?}", self.params.kv_overrides);
110+
}
10111
}
11112

12113
impl LlamaModelParams {
@@ -90,8 +191,16 @@ impl LlamaModelParams {
90191
/// ```
91192
impl Default for LlamaModelParams {
92193
fn default() -> Self {
194+
let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
93195
LlamaModelParams {
94-
params: unsafe { llama_cpp_sys_2::llama_model_default_params() },
196+
params: default_params,
197+
kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
198+
key: [0; 128],
199+
tag: 0,
200+
__bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
201+
int_value: 0,
202+
},
203+
}],
95204
}
96205
}
97206
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
//! Key-value overrides for a model.
2+
3+
use crate::model::params::LlamaModelParams;
4+
use std::ffi::{CStr, CString};
5+
use std::fmt::Debug;
6+
7+
/// An override value for a model parameter.
8+
#[derive(Debug, Clone, Copy, PartialEq)]
9+
pub enum ParamOverrideValue {
10+
/// A boolean value
11+
Bool(bool),
12+
/// A float value
13+
Float(f64),
14+
/// An integer value
15+
Int(i64),
16+
}
17+
18+
impl ParamOverrideValue {
19+
pub(crate) fn tag(&self) -> llama_cpp_sys_2::llama_model_kv_override_type {
20+
match self {
21+
ParamOverrideValue::Bool(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL,
22+
ParamOverrideValue::Float(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
23+
ParamOverrideValue::Int(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT,
24+
}
25+
}
26+
27+
pub(crate) fn value(&self) -> llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
28+
match self {
29+
ParamOverrideValue::Bool(value) => {
30+
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { bool_value: *value }
31+
}
32+
ParamOverrideValue::Float(value) => {
33+
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
34+
float_value: *value,
35+
}
36+
}
37+
ParamOverrideValue::Int(value) => {
38+
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { int_value: *value }
39+
}
40+
}
41+
}
42+
}
43+
44+
impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue {
45+
fn from(
46+
llama_cpp_sys_2::llama_model_kv_override {
47+
key: _,
48+
tag,
49+
__bindgen_anon_1,
50+
}: &llama_cpp_sys_2::llama_model_kv_override,
51+
) -> Self {
52+
match *tag {
53+
llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT => {
54+
ParamOverrideValue::Int(unsafe { __bindgen_anon_1.int_value })
55+
}
56+
llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
57+
ParamOverrideValue::Float(unsafe { __bindgen_anon_1.float_value })
58+
}
59+
llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
60+
ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.bool_value })
61+
}
62+
_ => unreachable!("Unknown tag of {tag}"),
63+
}
64+
}
65+
}
66+
67+
/// A struct implementing [`IntoIterator`] over the key-value overrides for a model.
68+
#[derive(Debug)]
69+
pub struct KvOverrides<'a> {
70+
model_params: &'a LlamaModelParams,
71+
}
72+
73+
impl KvOverrides<'_> {
74+
pub(super) fn new(
75+
model_params: &LlamaModelParams,
76+
) -> KvOverrides {
77+
KvOverrides { model_params }
78+
}
79+
}
80+
81+
impl<'a> IntoIterator for KvOverrides<'a> {
82+
// I'm fairly certain this could be written returning by reference, but I'm not sure how to do it safely. I do not
83+
// expect this to be a performance bottleneck so the copy should be fine. (let me know if it's not fine!)
84+
type Item = (CString, ParamOverrideValue);
85+
type IntoIter = KvOverrideValueIterator<'a>;
86+
87+
fn into_iter(self) -> Self::IntoIter {
88+
KvOverrideValueIterator {
89+
model_params: self.model_params,
90+
current: 0,
91+
}
92+
}
93+
}
94+
95+
/// An iterator over the key-value overrides for a model.
96+
#[derive(Debug)]
97+
pub struct KvOverrideValueIterator<'a> {
98+
model_params: &'a LlamaModelParams,
99+
current: usize,
100+
}
101+
102+
impl<'a> Iterator for KvOverrideValueIterator<'a> {
103+
type Item = (CString, ParamOverrideValue);
104+
105+
fn next(&mut self) -> Option<Self::Item> {
106+
let overrides = self.model_params.params.kv_overrides;
107+
if overrides.is_null() {
108+
return None;
109+
}
110+
111+
// SAFETY: llama.cpp seems to guarantee that the last element contains an empty key or is valid. We've checked
112+
// the prev one in the last iteration, the next one should be valid or 0 (and thus safe to deref)
113+
let current = unsafe { *overrides.add(self.current) };
114+
115+
if current.key[0] == 0 {
116+
return None;
117+
}
118+
119+
let value = ParamOverrideValue::from(&current);
120+
121+
let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
122+
123+
self.current += 1;
124+
Some((key, value))
125+
}
126+
}

simple/src/main.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
clippy::cast_sign_loss
77
)]
88

9-
use anyhow::{bail, Context, Result};
9+
use std::collections::BTreeMap;
10+
use std::ffi::{CStr, CString};
11+
use anyhow::{anyhow, bail, Context, Result};
1012
use clap::Parser;
1113
use hf_hub::api::sync::ApiBuilder;
1214
use llama_cpp_2::context::params::LlamaContextParams;
@@ -20,7 +22,10 @@ use llama_cpp_2::token::data_array::LlamaTokenDataArray;
2022
use std::io::Write;
2123
use std::num::NonZeroU32;
2224
use std::path::PathBuf;
25+
use std::pin::pin;
26+
use std::str::FromStr;
2327
use std::time::Duration;
28+
use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
2429

2530
#[derive(clap::Parser, Debug, Clone)]
2631
struct Args {
@@ -33,12 +38,31 @@ struct Args {
3338
/// set the length of the prompt + output in tokens
3439
#[arg(long, default_value_t = 32)]
3540
n_len: i32,
41+
/// override some parameters of the model
42+
#[arg(short = 'o', value_parser = parse_key_val)]
43+
key_value_overrides: Vec<(String, ParamOverrideValue)>,
3644
/// Disable offloading layers to the gpu
3745
#[cfg(feature = "cublas")]
3846
#[clap(long)]
3947
disable_gpu: bool,
4048
}
4149

50+
/// Parse a single key-value pair
51+
fn parse_key_val(s: &str) -> Result<(String, ParamOverrideValue)> {
52+
let pos = s
53+
.find('=')
54+
.ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{}`", s))?;
55+
let key = s[..pos].parse()?;
56+
let value: String = s[pos + 1..].parse()?;
57+
let value = i64::from_str(&value).map(ParamOverrideValue::Int)
58+
.or_else(|_| f64::from_str(&value).map(ParamOverrideValue::Float))
59+
.or_else(|_| bool::from_str(&value).map(ParamOverrideValue::Bool))
60+
.map_err(|_| anyhow!("must be one of i64, f64, or bool"))?;
61+
62+
Ok((key, value))
63+
}
64+
65+
4266
#[derive(clap::Subcommand, Debug, Clone)]
4367
enum Model {
4468
/// Use an already downloaded model
@@ -79,6 +103,7 @@ fn main() -> Result<()> {
79103
prompt,
80104
#[cfg(feature = "cublas")]
81105
disable_gpu,
106+
key_value_overrides,
82107
} = Args::parse();
83108

84109
// init LLM
@@ -95,6 +120,13 @@ fn main() -> Result<()> {
95120
#[cfg(not(feature = "cublas"))]
96121
LlamaModelParams::default()
97122
};
123+
124+
let mut model_params = pin!(model_params);
125+
126+
for (k, v) in key_value_overrides.iter() {
127+
let k = CString::new(k.as_bytes()).with_context(|| format!("invalid key: {}", k))?;
128+
model_params.append_kv_override(k.as_c_str(), *v);
129+
}
98130

99131
let model_path = model
100132
.get_or_load()

0 commit comments

Comments
 (0)