@@ -41,6 +41,49 @@ impl From<RopeScalingType> for i32 {
4141 }
4242}
4343
44+ /// A rusty wrapper around `LLAMA_POOLING_TYPE`.
45+ #[ repr( i8 ) ]
46+ #[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
47+ pub enum LlamaPoolingType {
48+ /// The pooling type is unspecified
49+ Unspecified = -1 ,
50+ /// No pooling
51+ None = 0 ,
52+ /// Mean pooling
53+ Mean = 1 ,
54+ /// CLS pooling
55+ Cls = 2 ,
56+ /// Last pooling
57+ Last = 3 ,
58+ }
59+
60+ /// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
61+ /// the value is not recognized.
62+ impl From < i32 > for LlamaPoolingType {
63+ fn from ( value : i32 ) -> Self {
64+ match value {
65+ 0 => Self :: None ,
66+ 1 => Self :: Mean ,
67+ 2 => Self :: Cls ,
68+ 3 => Self :: Last ,
69+ _ => Self :: Unspecified ,
70+ }
71+ }
72+ }
73+
74+ /// Create a `c_int` from a `LlamaPoolingType`.
75+ impl From < LlamaPoolingType > for i32 {
76+ fn from ( value : LlamaPoolingType ) -> Self {
77+ match value {
78+ LlamaPoolingType :: None => 0 ,
79+ LlamaPoolingType :: Mean => 1 ,
80+ LlamaPoolingType :: Cls => 2 ,
81+ LlamaPoolingType :: Last => 3 ,
82+ LlamaPoolingType :: Unspecified => -1 ,
83+ }
84+ }
85+ }
86+
4487/// A safe wrapper around `llama_context_params`.
4588///
4689/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
@@ -471,6 +514,35 @@ impl LlamaContextParams {
471514 self . context_params . cb_eval_user_data = cb_eval_user_data;
472515 self
473516 }
517+
518+ /// Set the type of pooling.
519+ ///
520+ /// # Examples
521+ ///
522+ /// ```rust
523+ /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
524+ /// let params = LlamaContextParams::default()
525+ /// .with_pooling_type(LlamaPoolingType::Last);
526+ /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
527+ /// ```
528+ #[ must_use]
529+ pub fn with_pooling_type ( mut self , pooling_type : LlamaPoolingType ) -> Self {
530+ self . context_params . pooling_type = i32:: from ( pooling_type) ;
531+ self
532+ }
533+
534+ /// Get the type of pooling.
535+ ///
536+ /// # Examples
537+ ///
538+ /// ```rust
539+ /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
540+ /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
541+ /// ```
542+ #[ must_use]
543+ pub fn pooling_type ( & self ) -> LlamaPoolingType {
544+ LlamaPoolingType :: from ( self . context_params . pooling_type )
545+ }
474546}
475547
476548/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
0 commit comments