
Commit 56625a6

Expose n_ubatch context param
* n_batch is the max number of tokens that llama_decode can accept in a single call (a single logical "batch").
* n_ubatch is lower level, corresponding to the hardware batch size used during decoding; it must be less than or equal to n_batch.

References:
- ggml-org/llama.cpp#6328 (comment)
- https://github.com/ggerganov/llama.cpp/blob/557410b8f06380560155ac7fcb8316d71ddc9837/common/common.h#L58
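To make the relationship concrete, here is a minimal sketch of configuring both values through the params builder; it assumes the crate's existing `with_n_batch` and `n_batch` methods alongside the `n_ubatch` pair added in this commit:

```rust
use llama_cpp_2::context::params::LlamaContextParams;

fn main() {
    // n_batch: logical cap on tokens per llama_decode call.
    // n_ubatch: physical (hardware-level) micro-batch size; must be <= n_batch.
    let params = LlamaContextParams::default()
        .with_n_batch(2048)
        .with_n_ubatch(512);
    assert!(params.n_ubatch() <= params.n_batch());
}
```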
Parent: ea798fa

2 files changed, 38 insertions(+), 1 deletion(-)


llama-cpp-2/src/context.rs

Lines changed: 7 additions & 1 deletion

@@ -52,12 +52,18 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
-    /// Gets the max number of tokens in a batch.
+    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
     #[must_use]
     pub fn n_batch(&self) -> u32 {
         unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
     }
 
+    /// Gets the max number of physical tokens (hardware level) to decode in a batch. Must be less than or equal to `n_batch`.
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
+    }
+
     /// Gets the size of the context.
     #[must_use]
     pub fn n_ctx(&self) -> u32 {
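With the new accessor, code holding a context can check the same invariant at runtime. A sketch under the assumption that the caller already has a constructed `LlamaContext` (model loading is omitted):

```rust
use llama_cpp_2::context::LlamaContext;

// Hypothetical helper: verifies the invariant stated in the doc comments.
fn assert_batch_invariant(ctx: &LlamaContext<'_>) {
    let n_batch = ctx.n_batch();   // logical max per llama_decode call
    let n_ubatch = ctx.n_ubatch(); // physical micro-batch size
    assert!(n_ubatch <= n_batch, "n_ubatch must not exceed n_batch");
}
```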

llama-cpp-2/src/context/params.rs

Lines changed: 31 additions & 0 deletions

@@ -166,6 +166,37 @@ impl LlamaContextParams {
         self.context_params.n_batch
     }
 
+    /// Set the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// # use std::num::NonZeroU32;
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_n_ubatch(512);
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
+        self.context_params.n_ubatch = n_ubatch;
+        self
+    }
+
+    /// Get the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        self.context_params.n_ubatch
+    }
+
     /// Set the type of rope scaling.
     ///
     /// # Examples