Skip to content

Commit ffbd54c

Browse files
committed
Expose pooling type
1 parent 4333caa commit ffbd54c

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

llama-cpp-2/src/context/params.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,49 @@ impl From<RopeScalingType> for i32 {
4141
}
4242
}
4343

44+
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
45+
#[repr(i8)]
46+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
47+
pub enum LlamaPoolingType {
48+
/// The pooling type is unspecified
49+
Unspecified = -1,
50+
/// No pooling
51+
None = 0,
52+
/// Mean pooling
53+
Mean = 1,
54+
/// CLS pooling
55+
Cls = 2,
56+
/// Last pooling
57+
Last = 3,
58+
}
59+
60+
/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
61+
/// the value is not recognized.
62+
impl From<i32> for LlamaPoolingType {
63+
fn from(value: i32) -> Self {
64+
match value {
65+
0 => Self::None,
66+
1 => Self::Mean,
67+
2 => Self::Cls,
68+
3 => Self::Last,
69+
_ => Self::Unspecified,
70+
}
71+
}
72+
}
73+
74+
/// Create a `c_int` from a `LlamaPoolingType`.
75+
impl From<LlamaPoolingType> for i32 {
76+
fn from(value: LlamaPoolingType) -> Self {
77+
match value {
78+
LlamaPoolingType::None => 0,
79+
LlamaPoolingType::Mean => 1,
80+
LlamaPoolingType::Cls => 2,
81+
LlamaPoolingType::Last => 3,
82+
LlamaPoolingType::Unspecified => -1,
83+
}
84+
}
85+
}
86+
4487
/// A safe wrapper around `llama_context_params`.
4588
///
4689
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
@@ -471,6 +514,35 @@ impl LlamaContextParams {
471514
self.context_params.cb_eval_user_data = cb_eval_user_data;
472515
self
473516
}
517+
518+
/// Set the type of pooling.
519+
///
520+
/// # Examples
521+
///
522+
/// ```rust
523+
/// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
524+
/// let params = LlamaContextParams::default()
525+
/// .with_pooling_type(LlamaPoolingType::Last);
526+
/// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
527+
/// ```
528+
#[must_use]
529+
pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
530+
self.context_params.pooling_type = i32::from(pooling_type);
531+
self
532+
}
533+
534+
/// Get the type of pooling.
535+
///
536+
/// # Examples
537+
///
538+
/// ```rust
539+
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
540+
/// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
541+
/// ```
542+
#[must_use]
543+
pub fn pooling_type(&self) -> LlamaPoolingType {
544+
LlamaPoolingType::from(self.context_params.pooling_type)
545+
}
474546
}
475547

476548
/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)

0 commit comments

Comments
 (0)