@@ -132,6 +132,50 @@ impl LlamaContextParams {
132
132
NonZeroU32 :: new ( self . context_params . n_ctx )
133
133
}
134
134
135
+ /// Set the n_batch
136
+ ///
137
+ /// # Examples
138
+ ///
139
+ /// ```rust
140
+ /// # use std::num::NonZeroU32;
141
+ /// use llama_cpp_2::context::params::LlamaContextParams;
142
+ /// let params = LlamaContextParams::default()
143
+ /// .with_n_batch(2048);
144
+ /// assert_eq!(params.n_batch(), 2048);
145
+ /// ```
146
+ pub fn with_n_batch ( mut self , n_batch : u32 ) -> Self {
147
+ self . context_params . n_batch = n_batch;
148
+ self
149
+ }
150
+
151
+ /// Get the n_batch
152
+ ///
153
+ /// # Examples
154
+ ///
155
+ /// ```rust
156
+ /// use llama_cpp_2::context::params::LlamaContextParams;
157
+ /// let params = LlamaContextParams::default();
158
+ /// assert_eq!(params.n_batch(), 512);
159
+ /// ```
160
+ pub fn n_batch ( & self ) -> u32 {
161
+ self . context_params . n_batch
162
+ }
163
+
164
+ /// Set the type of rope scaling.
165
+ ///
166
+ /// # Examples
167
+ ///
168
+ /// ```rust
169
+ /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
170
+ /// let params = LlamaContextParams::default()
171
+ /// .with_rope_scaling_type(RopeScalingType::Linear);
172
+ /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
173
+ /// ```
174
+ pub fn with_rope_scaling_type ( mut self , rope_scaling_type : RopeScalingType ) -> Self {
175
+ self . context_params . rope_scaling_type = i8:: from ( rope_scaling_type) ;
176
+ self
177
+ }
178
+
135
179
/// Get the type of rope scaling.
136
180
///
137
181
/// # Examples
@@ -143,6 +187,60 @@ impl LlamaContextParams {
143
187
pub fn rope_scaling_type ( & self ) -> RopeScalingType {
144
188
RopeScalingType :: from ( self . context_params . rope_scaling_type )
145
189
}
190
+
191
+ /// Set the rope frequency base.
192
+ ///
193
+ /// # Examples
194
+ ///
195
+ /// ```rust
196
+ /// use llama_cpp_2::context::params::LlamaContextParams;
197
+ /// let params = LlamaContextParams::default()
198
+ /// .with_rope_freq_base(0.5);
199
+ /// assert_eq!(params.rope_freq_base(), 0.5);
200
+ /// ```
201
+ pub fn with_rope_freq_base ( mut self , rope_freq_base : f32 ) -> Self {
202
+ self . context_params . rope_freq_base = rope_freq_base;
203
+ self
204
+ }
205
+
206
+ /// Get the rope frequency base.
207
+ ///
208
+ /// # Examples
209
+ ///
210
+ /// ```rust
211
+ /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
212
+ /// assert_eq!(params.rope_freq_base(), 0.0);
213
+ /// ```
214
+ pub fn rope_freq_base ( & self ) -> f32 {
215
+ self . context_params . rope_freq_base
216
+ }
217
+
218
+ /// Set the rope frequency scale.
219
+ ///
220
+ /// # Examples
221
+ ///
222
+ /// ```rust
223
+ /// use llama_cpp_2::context::params::LlamaContextParams;
224
+ /// let params = LlamaContextParams::default()
225
+ /// .with_rope_freq_scale(0.5);
226
+ /// assert_eq!(params.rope_freq_scale(), 0.5);
227
+ /// ```
228
+ pub fn with_rope_freq_scale ( mut self , rope_freq_scale : f32 ) -> Self {
229
+ self . context_params . rope_freq_scale = rope_freq_scale;
230
+ self
231
+ }
232
+
233
+ /// Get the rope frequency scale.
234
+ ///
235
+ /// # Examples
236
+ ///
237
+ /// ```rust
238
+ /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
239
+ /// assert_eq!(params.rope_freq_scale(), 0.0);
240
+ /// ```
241
+ pub fn rope_freq_scale ( & self ) -> f32 {
242
+ self . context_params . rope_freq_scale
243
+ }
146
244
}
147
245
148
246
/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
@@ -156,6 +254,6 @@ impl LlamaContextParams {
156
254
impl Default for LlamaContextParams {
157
255
fn default ( ) -> Self {
158
256
let context_params = unsafe { llama_cpp_sys_2:: llama_context_default_params ( ) } ;
159
- Self { context_params, }
257
+ Self { context_params }
160
258
}
161
259
}
0 commit comments