Skip to content

Commit b1420f3

Browse files
authored
Merge pull request #504 from brittlewis12/ubatch
Expose `n_ubatch` context param
2 parents ea798fa + 56625a6 commit b1420f3

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

llama-cpp-2/src/context.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,18 @@ impl<'model> LlamaContext<'model> {
5252
}
5353
}
5454

55-
/// Gets the max number of tokens in a batch.
55+
/// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to n_ubatch.
5656
#[must_use]
5757
pub fn n_batch(&self) -> u32 {
5858
unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
5959
}
6060

61+
/// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to n_batch.
62+
#[must_use]
63+
pub fn n_ubatch(&self) -> u32 {
64+
unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
65+
}
66+
6167
/// Gets the size of the context.
6268
#[must_use]
6369
pub fn n_ctx(&self) -> u32 {

llama-cpp-2/src/context/params.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,37 @@ impl LlamaContextParams {
166166
self.context_params.n_batch
167167
}
168168

169+
/// Set the `n_ubatch`
170+
///
171+
/// # Examples
172+
///
173+
/// ```rust
174+
/// # use std::num::NonZeroU32;
175+
/// use llama_cpp_2::context::params::LlamaContextParams;
176+
/// let params = LlamaContextParams::default()
177+
/// .with_n_ubatch(512);
178+
/// assert_eq!(params.n_ubatch(), 512);
179+
/// ```
180+
#[must_use]
181+
pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
182+
self.context_params.n_ubatch = n_ubatch;
183+
self
184+
}
185+
186+
/// Get the `n_ubatch`
187+
///
188+
/// # Examples
189+
///
190+
/// ```rust
191+
/// use llama_cpp_2::context::params::LlamaContextParams;
192+
/// let params = LlamaContextParams::default();
193+
/// assert_eq!(params.n_ubatch(), 512);
194+
/// ```
195+
#[must_use]
196+
pub fn n_ubatch(&self) -> u32 {
197+
self.context_params.n_ubatch
198+
}
199+
169200
/// Set the type of rope scaling.
170201
///
171202
/// # Examples

0 commit comments

Comments
 (0)