
Commit ce6eb1b

fixed up LlamaContextParams with new CB

1 parent 40d4b04

4 files changed: +105 -145 lines

Cargo.lock

Lines changed: 2 additions & 2 deletions
(Generated file; diff not rendered.)

llama-cpp-2/examples/simple.rs

Lines changed: 3 additions & 5 deletions
@@ -56,11 +56,9 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;
 
     // initialize the context
-    let ctx_params = LlamaContextParams {
-        seed: 1234,
-        n_ctx: NonZeroU32::new(2048),
-        ..LlamaContextParams::default()
-    };
+    let ctx_params = LlamaContextParams::default()
+        .with_n_ctx(NonZeroU32::new(2048))
+        .with_seed(1234);
 
     let mut ctx = model.new_context(&backend, ctx_params)
         .with_context(|| "unable to create the llama_context")?;

llama-cpp-2/src/context/params.rs

Lines changed: 95 additions & 132 deletions
@@ -1,5 +1,5 @@
 //! A safe wrapper around `llama_context_params`.
-use llama_cpp_sys_2::{ggml_type, llama_context_params};
+use llama_cpp_sys_2;
 use std::fmt::Debug;
 use std::num::NonZeroU32;
 
@@ -43,152 +43,115 @@ impl From<RopeScalingType> for i8 {
 }
 
 /// A safe wrapper around `llama_context_params`.
-#[derive(Debug, PartialEq)]
+///
+/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
+///
+/// # Examples
+///
+/// ```rust
+/// # use std::num::NonZeroU32;
+/// use llama_cpp_2::context::params::LlamaContextParams;
+///
+/// let ctx_params = LlamaContextParams::default()
+///     .with_n_ctx(NonZeroU32::new(2048))
+///     .with_seed(1234);
+///
+/// assert_eq!(ctx_params.seed(), 1234);
+/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
+/// ```
+#[derive(Debug, Clone)]
 #[allow(
     missing_docs,
     clippy::struct_excessive_bools,
     clippy::module_name_repetitions
 )]
 pub struct LlamaContextParams {
-    /// The random seed
-    pub seed: u32,
-    /// the number of tokens in the context - [`None`] if defined by the model.
-    pub n_ctx: Option<NonZeroU32>,
-    pub n_batch: u32,
-    pub n_threads: u32,
-    pub n_threads_batch: u32,
-    pub rope_scaling_type: RopeScalingType,
-    pub rope_freq_base: f32,
-    pub rope_freq_scale: f32,
-    pub yarn_ext_factor: f32,
-    pub yarn_attn_factor: f32,
-    pub yarn_beta_fast: f32,
-    pub yarn_beta_slow: f32,
-    pub yarn_orig_ctx: u32,
-    pub type_k: ggml_type,
-    pub type_v: ggml_type,
-    pub mul_mat_q: bool,
-    pub logits_all: bool,
-    pub embedding: bool,
-    pub offload_kqv: bool,
-    pub cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
-    pub cb_eval_user_data: *mut std::ffi::c_void,
+    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
+}
+
+impl LlamaContextParams {
+    /// Set the seed of the context.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// let params = params.with_seed(1234);
+    /// assert_eq!(params.seed(), 1234);
+    /// ```
+    pub fn with_seed(mut self, seed: u32) -> Self {
+        self.context_params.seed = seed;
+        self
+    }
+
+    /// Get the seed of the context.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_seed(1234);
+    /// assert_eq!(params.seed(), 1234);
+    /// ```
+    pub fn seed(&self) -> u32 {
+        self.context_params.seed
+    }
+
+    /// Set the size of the context.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// # use std::num::NonZeroU32;
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
+    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
+    /// ```
+    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
+        self.context_params.n_ctx = n_ctx.map_or(0, |n_ctx| n_ctx.get());
+        self
+    }
+
+    /// Get the size of the context.
+    ///
+    /// [`None`] if the context size is specified by the model and not the context.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
+    /// ```
+    pub fn n_ctx(&self) -> Option<NonZeroU32> {
+        NonZeroU32::new(self.context_params.n_ctx)
+    }
+
+    /// Get the type of rope scaling.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
+    /// ```
+    pub fn rope_scaling_type(&self) -> RopeScalingType {
+        RopeScalingType::from(self.context_params.rope_scaling_type)
+    }
 }
 
 /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
 /// ```
 /// # use std::num::NonZeroU32;
 /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
 /// let params = LlamaContextParams::default();
-/// assert_eq!(params.n_ctx, NonZeroU32::new(512), "n_ctx should be 512");
-/// assert_eq!(params.rope_scaling_type, RopeScalingType::Unspecified);
+/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
+/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
 /// ```
 impl Default for LlamaContextParams {
     fn default() -> Self {
-        Self::from(unsafe { llama_cpp_sys_2::llama_context_default_params() })
-    }
-}
-
-impl From<llama_context_params> for LlamaContextParams {
-    fn from(
-        llama_context_params {
-            seed,
-            n_ctx,
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            cb_eval,
-            cb_eval_user_data,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type,
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-        }: llama_context_params,
-    ) -> Self {
-        Self {
-            seed,
-            n_ctx: NonZeroU32::new(n_ctx),
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type: RopeScalingType::from(rope_scaling_type),
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-            cb_eval,
-            cb_eval_user_data,
-        }
+        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
+        Self { context_params }
     }
 }
-
-impl From<LlamaContextParams> for llama_context_params {
-    fn from(
-        LlamaContextParams {
-            seed,
-            n_ctx,
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type,
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-            cb_eval,
-            cb_eval_user_data,
-        }: LlamaContextParams,
-    ) -> Self {
-        llama_context_params {
-            seed,
-            n_ctx: n_ctx.map_or(0, NonZeroU32::get),
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type: i8::from(rope_scaling_type),
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-            cb_eval,
-            cb_eval_user_data,
-        }
-    }
-}
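
Because the wrapper now stores the raw `llama_cpp_sys_2::llama_context_params` directly, further accessors are mechanical to add. Below is a sketch of one more setter/getter pair following the committed pattern; `with_n_batch` and `n_batch` are hypothetical and not part of this commit, though the sys struct does expose an `n_batch` field (it appears in the removed public field list above):

impl LlamaContextParams {
    /// Set the batch size. Hypothetical: mirrors the with_seed / with_n_ctx pattern above.
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the batch size. Hypothetical getter for the same field.
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }
}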

llama-cpp-2/src/model.rs

Lines changed: 5 additions & 6 deletions
@@ -6,7 +6,6 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
 use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
-use llama_cpp_sys_2::{llama_context_params, llama_token_get_type, llama_vocab_type};
 use std::ffi::CString;
 use std::os::raw::c_int;
 use std::path::Path;
@@ -184,7 +183,7 @@ impl LlamaModel {
     /// If the token type is not known to this library.
     #[must_use]
     pub fn token_type(&self, LlamaToken(id): LlamaToken) -> LlamaTokenType {
-        let token_type = unsafe { llama_token_get_type(self.model.as_ptr(), id) };
+        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_type(self.model.as_ptr(), id) };
         LlamaTokenType::try_from(token_type).expect("token type is valid")
     }
 
@@ -314,7 +313,7 @@ impl LlamaModel {
         _: &LlamaBackend,
         params: LlamaContextParams,
     ) -> Result<LlamaContext, LlamaContextLoadError> {
-        let context_params = llama_context_params::from(params);
+        let context_params = params.context_params;
         let context = unsafe {
             llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
         };
@@ -345,13 +344,13 @@ pub enum VocabType {
 pub enum LlamaTokenTypeFromIntError {
     /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
     #[error("Unknown Value {0}")]
-    UnknownValue(llama_vocab_type),
+    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
 }
 
-impl TryFrom<llama_vocab_type> for VocabType {
+impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
     type Error = LlamaTokenTypeFromIntError;
 
-    fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
+    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
         match value {
             llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
             llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
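
Call sites are unaffected by this change; only the spelling of the sys types moved to fully qualified paths. A minimal usage sketch of the conversion, assuming `VocabType` is reachable at `llama_cpp_2::model::VocabType` as defined in this file:

use llama_cpp_2::model::VocabType;

// Unknown raw values surface as LlamaTokenTypeFromIntError::UnknownValue
// through the TryFrom impl above; here they simply yield false.
fn is_bpe(raw: llama_cpp_sys_2::llama_vocab_type) -> bool {
    matches!(VocabType::try_from(raw), Ok(VocabType::BPE))
}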
