11//! A safe wrapper around `llama_context_params`.
2- use llama_cpp_sys_2:: { ggml_type , llama_context_params } ;
2+ use llama_cpp_sys_2;
33use std:: fmt:: Debug ;
44use std:: num:: NonZeroU32 ;
55
@@ -43,152 +43,115 @@ impl From<RopeScalingType> for i8 {
4343}
4444
4545/// A safe wrapper around `llama_context_params`.
46- #[ derive( Debug , PartialEq ) ]
46+ ///
47+ /// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
48+ ///
49+ /// # Examples
50+ ///
51+ /// ```rust
52+ /// # use std::num::NonZeroU32;
53+ /// use llama_cpp_2::context::params::LlamaContextParams;
54+ ///
55+ ///let ctx_params = LlamaContextParams::default()
56+ /// .with_n_ctx(NonZeroU32::new(2048))
57+ /// .with_seed(1234);
58+ ///
59+ /// assert_eq!(ctx_params.seed(), 1234);
60+ /// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
61+ /// ```
62+ #[ derive( Debug , Clone ) ]
4763#[ allow(
4864 missing_docs,
4965 clippy:: struct_excessive_bools,
5066 clippy:: module_name_repetitions
5167) ]
5268pub struct LlamaContextParams {
53- /// The random seed
54- pub seed : u32 ,
55- /// the number of tokens in the context - [`None`] if defined by the model.
56- pub n_ctx : Option < NonZeroU32 > ,
57- pub n_batch : u32 ,
58- pub n_threads : u32 ,
59- pub n_threads_batch : u32 ,
60- pub rope_scaling_type : RopeScalingType ,
61- pub rope_freq_base : f32 ,
62- pub rope_freq_scale : f32 ,
63- pub yarn_ext_factor : f32 ,
64- pub yarn_attn_factor : f32 ,
65- pub yarn_beta_fast : f32 ,
66- pub yarn_beta_slow : f32 ,
67- pub yarn_orig_ctx : u32 ,
68- pub type_k : ggml_type ,
69- pub type_v : ggml_type ,
70- pub mul_mat_q : bool ,
71- pub logits_all : bool ,
72- pub embedding : bool ,
73- pub offload_kqv : bool ,
74- pub cb_eval : llama_cpp_sys_2:: ggml_backend_sched_eval_callback ,
75- pub cb_eval_user_data : * mut std:: ffi:: c_void ,
69+ pub ( crate ) context_params : llama_cpp_sys_2:: llama_context_params ,
70+ }
71+
72+ impl LlamaContextParams {
73+ /// Set the seed of the context
74+ ///
75+ /// # Examples
76+ ///
77+ /// ```rust
78+ /// use llama_cpp_2::context::params::LlamaContextParams;
79+ /// let params = LlamaContextParams::default();
80+ /// let params = params.with_seed(1234);
81+ /// assert_eq!(params.seed(), 1234);
82+ /// ```
83+ pub fn with_seed ( mut self , seed : u32 ) -> Self {
84+ self . context_params . seed = seed;
85+ self
86+ }
87+
88+ /// Get the seed of the context
89+ ///
90+ /// # Examples
91+ ///
92+ /// ```rust
93+ /// use llama_cpp_2::context::params::LlamaContextParams;
94+ /// let params = LlamaContextParams::default()
95+ /// .with_seed(1234);
96+ /// assert_eq!(params.seed(), 1234);
97+ /// ```
98+ pub fn seed ( & self ) -> u32 {
99+ self . context_params . seed
100+ }
101+
102+ /// Set the side of the context
103+ ///
104+ /// # Examples
105+ ///
106+ /// ```rust
107+ /// # use std::num::NonZeroU32;
108+ /// use llama_cpp_2::context::params::LlamaContextParams;
109+ /// let params = LlamaContextParams::default();
110+ /// let params = params.with_n_ctx(NonZeroU32::new(2048));
111+ /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
112+ /// ```
113+ pub fn with_n_ctx ( mut self , n_ctx : Option < NonZeroU32 > ) -> Self {
114+ self . context_params . n_ctx = n_ctx. map_or ( 0 , |n_ctx| n_ctx. get ( ) ) ;
115+ self
116+ }
117+
118+ /// Get the size of the context.
119+ ///
120+ /// [`None`] if the context size is specified by the model and not the context.
121+ ///
122+ /// # Examples
123+ ///
124+ /// ```rust
125+ /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
126+ /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
127+ pub fn n_ctx ( & self ) -> Option < NonZeroU32 > {
128+ NonZeroU32 :: new ( self . context_params . n_ctx )
129+ }
130+
131+ /// Get the type of rope scaling.
132+ ///
133+ /// # Examples
134+ ///
135+ /// ```rust
136+ /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
137+ /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
138+ /// ```
139+ pub fn rope_scaling_type ( & self ) -> RopeScalingType {
140+ RopeScalingType :: from ( self . context_params . rope_scaling_type )
141+ }
76142}
77143
78144/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
79145/// ```
80146/// # use std::num::NonZeroU32;
81147/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
82148/// let params = LlamaContextParams::default();
83- /// assert_eq!(params.n_ctx, NonZeroU32::new(512), "n_ctx should be 512");
84- /// assert_eq!(params.rope_scaling_type, RopeScalingType::Unspecified);
149+ /// assert_eq!(params.n_ctx() , NonZeroU32::new(512), "n_ctx should be 512");
150+ /// assert_eq!(params.rope_scaling_type() , RopeScalingType::Unspecified);
85151/// ```
86152impl Default for LlamaContextParams {
87153 fn default ( ) -> Self {
88- Self :: from ( unsafe { llama_cpp_sys_2:: llama_context_default_params ( ) } )
89- }
90- }
91-
92- impl From < llama_context_params > for LlamaContextParams {
93- fn from (
94- llama_context_params {
95- seed,
96- n_ctx,
97- n_batch,
98- n_threads,
99- n_threads_batch,
100- rope_freq_base,
101- rope_freq_scale,
102- cb_eval,
103- cb_eval_user_data,
104- type_k,
105- type_v,
106- mul_mat_q,
107- logits_all,
108- embedding,
109- rope_scaling_type,
110- yarn_ext_factor,
111- yarn_attn_factor,
112- yarn_beta_fast,
113- yarn_beta_slow,
114- yarn_orig_ctx,
115- offload_kqv,
116- } : llama_context_params ,
117- ) -> Self {
118- Self {
119- seed,
120- n_ctx : NonZeroU32 :: new ( n_ctx) ,
121- n_batch,
122- n_threads,
123- n_threads_batch,
124- rope_freq_base,
125- rope_freq_scale,
126- type_k,
127- type_v,
128- mul_mat_q,
129- logits_all,
130- embedding,
131- rope_scaling_type : RopeScalingType :: from ( rope_scaling_type) ,
132- yarn_ext_factor,
133- yarn_attn_factor,
134- yarn_beta_fast,
135- yarn_beta_slow,
136- yarn_orig_ctx,
137- offload_kqv,
138- cb_eval,
139- cb_eval_user_data,
140- }
154+ let context_params = unsafe { llama_cpp_sys_2:: llama_context_default_params ( ) } ;
155+ Self { context_params, }
141156 }
142157}
// NOTE(review): diff context — the lines below are the manual
// `From<LlamaContextParams> for llama_context_params` conversion that this
// commit deletes; the wrapper now stores the raw `llama_context_params`
// directly, so no field-by-field conversion is needed.
143-
144- impl From < LlamaContextParams > for llama_context_params {
145- fn from (
146- LlamaContextParams {
147- seed,
148- n_ctx,
149- n_batch,
150- n_threads,
151- n_threads_batch,
152- rope_freq_base,
153- rope_freq_scale,
154- type_k,
155- type_v,
156- mul_mat_q,
157- logits_all,
158- embedding,
159- rope_scaling_type,
160- yarn_ext_factor,
161- yarn_attn_factor,
162- yarn_beta_fast,
163- yarn_beta_slow,
164- yarn_orig_ctx,
165- offload_kqv,
166- cb_eval,
167- cb_eval_user_data,
168- } : LlamaContextParams ,
169- ) -> Self {
170- llama_context_params {
171- seed,
172- n_ctx : n_ctx. map_or ( 0 , NonZeroU32 :: get) ,
173- n_batch,
174- n_threads,
175- n_threads_batch,
176- rope_freq_base,
177- rope_freq_scale,
178- type_k,
179- type_v,
180- mul_mat_q,
181- logits_all,
182- embedding,
183- rope_scaling_type : i8:: from ( rope_scaling_type) ,
184- yarn_ext_factor,
185- yarn_attn_factor,
186- yarn_beta_fast,
187- yarn_beta_slow,
188- yarn_orig_ctx,
189- offload_kqv,
190- cb_eval,
191- cb_eval_user_data,
192- }
193- }
194- }
0 commit comments