@@ -92,6 +92,15 @@ impl LlamaChatMessage {
     }
 }
 
+/// The Rope type that's used within the model.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RopeType {
+    Norm,
+    NeoX,
+    MRope,
+    Vision,
+}
+
 /// How to determine if we should prepend a bos token to tokens
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum AddBos {
@@ -446,6 +455,50 @@ impl LlamaModel {
         unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
     }
 
+    /// Returns the total size of all the tensors in the model in bytes.
+    pub fn size(&self) -> u64 {
+        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
+    }
+
+    /// Returns the number of parameters in the model.
+    pub fn n_params(&self) -> u64 {
+        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
+    }
+
+    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc.)
+    pub fn is_recurrent(&self) -> bool {
+        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
+    }
+
+    /// Returns the number of layers within the model.
+    pub fn n_layer(&self) -> u32 {
+        // It's never possible for this to panic because while the API interface is defined as an int32_t,
+        // the field it's accessing is a uint32_t.
+        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
+    }
+
+    /// Returns the number of attention heads within the model.
+    pub fn n_head(&self) -> u32 {
+        // It's never possible for this to panic because while the API interface is defined as an int32_t,
+        // the field it's accessing is a uint32_t.
+        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
+    }
+
+    /// Returns the rope type of the model.
+    pub fn rope_type(&self) -> Option<RopeType> {
+        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
+            rope_type => {
+                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
+                None
+            }
+        }
+    }
+
     fn get_chat_template_impl(
         &self,
         capacity: usize,
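
For reviewers, a minimal sketch of how the new accessors could be exercised. The `LlamaBackend`/`LlamaModel::load_from_file` loading calls are assumed from the crate's existing API (not part of this change), and `model.gguf` is a placeholder path.

// Usage sketch only; loading API and model path are assumptions, not part of this diff.
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let backend = LlamaBackend::init()?;
    // "model.gguf" is a hypothetical path used for illustration.
    let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;

    // The new metadata accessors added in this change.
    println!("size on disk: {} bytes", model.size());
    println!("parameters:   {}", model.n_params());
    println!("layers:       {}", model.n_layer());
    println!("attn heads:   {}", model.n_head());
    println!("recurrent:    {}", model.is_recurrent());

    // rope_type() returns None when llama.cpp reports LLAMA_ROPE_TYPE_NONE.
    match model.rope_type() {
        Some(rope) => println!("rope type:    {rope:?}"),
        None => println!("rope type:    none"),
    }
    Ok(())
}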