Skip to content

Commit 8b11c5c

Browse files
authored
Merge pull request #666 from vlovich/informational-methods
Expose model & backend informational methods
2 parents 6b4e52c + 899c217 commit 8b11c5c

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

llama-cpp-2/src/llama_backend.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,21 @@ impl LlamaBackend {
7070
Ok(LlamaBackend {})
7171
}
7272

73+
/// Was the code built for a GPU backend & is a supported one available.
74+
pub fn supports_gpu_offload(&self) -> bool {
75+
unsafe { llama_cpp_sys_2::llama_supports_gpu_offload() }
76+
}
77+
78+
/// Does this platform support loading the model via mmap.
79+
pub fn supports_mmap(&self) -> bool {
80+
unsafe { llama_cpp_sys_2::llama_supports_mmap() }
81+
}
82+
83+
/// Does this platform support locking the model in RAM.
84+
pub fn supports_mlock(&self) -> bool {
85+
unsafe { llama_cpp_sys_2::llama_supports_mlock() }
86+
}
87+
7388
/// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`.
7489
pub fn void_logs(&mut self) {
7590
unsafe extern "C" fn void_log(

llama-cpp-2/src/model.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ impl LlamaChatMessage {
9292
}
9393
}
9494

/// The RoPE (rotary position embedding) type that's used within the model.
///
/// Mirrors llama.cpp's `LLAMA_ROPE_TYPE_*` constants, minus `NONE` (which is
/// represented as `Option::None` by [`LlamaModel::rope_type`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    /// The standard ("normal") RoPE formulation.
    Norm,
    /// NeoX-style RoPE.
    NeoX,
    /// M-RoPE (multimodal RoPE).
    MRope,
    /// Vision RoPE.
    Vision,
}
95104
/// How to determine if we should prepend a bos token to tokens
96105
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
97106
pub enum AddBos {
@@ -446,6 +455,50 @@ impl LlamaModel {
446455
unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
447456
}
448457

458+
/// Returns the total size of all the tensors in the model in bytes.
459+
pub fn size(&self) -> u64 {
460+
unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
461+
}
462+
463+
/// Returns the number of parameters in the model.
464+
pub fn n_params(&self) -> u64 {
465+
unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
466+
}
467+
468+
/// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
469+
pub fn is_recurrent(&self) -> bool {
470+
unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
471+
}
472+
473+
/// Returns the number of layers within the model.
474+
pub fn n_layer(&self) -> u32 {
475+
// It's never possible for this to panic because while the API interface is defined as an int32_t,
476+
// the field it's accessing is a uint32_t.
477+
u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
478+
}
479+
480+
/// Returns the number of attention heads within the model.
481+
pub fn n_head(&self) -> u32 {
482+
// It's never possible for this to panic because while the API interface is defined as an int32_t,
483+
// the field it's accessing is a uint32_t.
484+
u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
485+
}
486+
487+
/// Returns the rope type of the model.
488+
pub fn rope_type(&self) -> Option<RopeType> {
489+
match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
490+
llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
491+
llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
492+
llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
493+
llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
494+
llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
495+
rope_type => {
496+
tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
497+
None
498+
}
499+
}
500+
}
501+
449502
fn get_chat_template_impl(
450503
&self,
451504
capacity: usize,

0 commit comments

Comments
 (0)