Skip to content

Commit ce02277

Browse files
committed
fixed updates
1 parent 4f17cfd commit ce02277

File tree

2 files changed

+46
-40
lines changed

2 files changed

+46
-40
lines changed

llama-cpp-2/src/model.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::context::LlamaContext;
99
use crate::llama_backend::LlamaBackend;
1010
use crate::model::params::LlamaModelParams;
1111
use crate::token::LlamaToken;
12-
use crate::token_type::LlamaTokenType;
12+
use crate::token_type::LlamaTokenAttr;
1313
use crate::{
1414
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
1515
NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
@@ -238,9 +238,9 @@ impl LlamaModel {
238238
///
239239
/// If the token type is not known to this library.
240240
#[must_use]
241-
pub fn token_type(&self, LlamaToken(id): LlamaToken) -> LlamaTokenType {
242-
let token_type = unsafe { llama_cpp_sys_2::llama_token_get_type(self.model.as_ptr(), id) };
243-
LlamaTokenType::try_from(token_type).expect("token type is valid")
241+
pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttr {
242+
let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.model.as_ptr(), id) };
243+
LlamaTokenAttr::try_from(token_type).expect("token type is valid")
244244
}
245245

246246
/// Convert a token to a string with a specified buffer size.
@@ -292,18 +292,23 @@ impl LlamaModel {
292292
return Ok(String::from("\n").into_bytes());
293293
}
294294

295-
// unsure what to do with this in the face of the 'special' arg
296-
match self.token_type(token) {
297-
LlamaTokenType::Normal | LlamaTokenType::UserDefined => {}
298-
LlamaTokenType::Control => {
295+
// unsure what to do with this in the face of the 'special' arg + attr changes
296+
match self.token_attr(token) {
297+
LlamaTokenAttr::Normal
298+
| LlamaTokenAttr::UserDefined
299+
| LlamaTokenAttr::Normalized
300+
| LlamaTokenAttr::LStrip
301+
| LlamaTokenAttr::RStrip
302+
| LlamaTokenAttr::SingleWord => {}
303+
LlamaTokenAttr::Control => {
299304
if token == self.token_bos() || token == self.token_eos() {
300305
return Ok(Vec::new());
301306
}
302307
}
303-
LlamaTokenType::Unknown
304-
| LlamaTokenType::Undefined
305-
| LlamaTokenType::Byte
306-
| LlamaTokenType::Unused => {
308+
LlamaTokenAttr::Unknown
309+
| LlamaTokenAttr::Undefined
310+
| LlamaTokenAttr::Byte
311+
| LlamaTokenAttr::Unused => {
307312
return Ok(Vec::new());
308313
}
309314
}

llama-cpp-2/src/token_type.rs

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,19 @@
33
/// A rust flavored equivalent of `llama_token_type`.
44
#[repr(u32)]
55
#[derive(Eq, PartialEq, Debug, Clone, Copy)]
6-
#[allow(clippy::module_name_repetitions)]
7-
pub enum LlamaTokenType {
8-
/// An undefined token type.
9-
Undefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNDEFINED as _,
10-
/// A normal token type.
11-
Normal = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_NORMAL as _,
12-
/// An unknown token type.
13-
Unknown = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNKNOWN as _,
14-
/// A control token type.
15-
Control = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_CONTROL as _,
16-
/// A user defined token type.
17-
UserDefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED as _,
18-
/// An unused token type.
19-
Unused = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED as _,
20-
/// A byte token type.
21-
Byte = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE as _,
6+
#[allow(clippy::module_name_repetitions, missing_docs)]
7+
pub enum LlamaTokenAttr {
8+
Undefined = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNDEFINED as _,
9+
Unknown = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNKNOWN as _,
10+
Unused = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNUSED as _,
11+
Normal = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_NORMAL as _,
12+
Control = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_CONTROL as _,
13+
UserDefined = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_USER_DEFINED as _,
14+
Byte = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_BYTE as _,
15+
Normalized = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_NORMALIZED as _,
16+
LStrip = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_LSTRIP as _,
17+
RStrip = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_RSTRIP as _,
18+
SingleWord = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_SINGLE_WORD as _,
2219
}
2320

2421
/// A safe wrapper for converting potentially deceptive `llama_token_type` values into
@@ -31,27 +28,31 @@ pub enum LlamaTokenType {
3128
/// # use std::ffi::c_int;
3229
/// # use std::num::TryFromIntError;
3330
/// # use std::result::Result;
34-
/// # use llama_cpp_2::token_type::{LlamaTokenTypeFromIntError, LlamaTokenType};
31+
/// # use llama_cpp_2::token_type::{LlamaTokenTypeFromIntError, LlamaTokenAttr};
3532
/// # fn main() -> Result<(), LlamaTokenTypeFromIntError> {
36-
/// let llama_token_type = LlamaTokenType::try_from(0 as llama_cpp_sys_2::llama_token_type)?;
37-
/// assert_eq!(llama_token_type, LlamaTokenType::Undefined);
33+
/// let llama_token_type = LlamaTokenAttr::try_from(0 as llama_cpp_sys_2::llama_token_type)?;
34+
/// assert_eq!(llama_token_type, LlamaTokenAttr::Undefined);
3835
///
39-
/// let bad_llama_token_type = LlamaTokenType::try_from(100 as llama_cpp_sys_2::llama_token_type);
36+
/// let bad_llama_token_type = LlamaTokenAttr::try_from(100 as llama_cpp_sys_2::llama_token_type);
4037
/// assert_eq!(Err(LlamaTokenTypeFromIntError::UnknownValue(100)), bad_llama_token_type);
4138
/// # Ok(())
4239
/// # }
43-
impl TryFrom<llama_cpp_sys_2::llama_token_type> for LlamaTokenType {
40+
impl TryFrom<llama_cpp_sys_2::llama_token_type> for LlamaTokenAttr {
4441
type Error = LlamaTokenTypeFromIntError;
4542

4643
fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
4744
match value {
48-
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNDEFINED => Ok(LlamaTokenType::Undefined),
49-
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_NORMAL => Ok(LlamaTokenType::Normal),
50-
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNKNOWN => Ok(LlamaTokenType::Unknown),
51-
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_CONTROL => Ok(LlamaTokenType::Control),
52-
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED => Ok(LlamaTokenType::UserDefined),
53-
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED => Ok(LlamaTokenType::Unused),
54-
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE => Ok(LlamaTokenType::Byte),
45+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNDEFINED => Ok(Self::Undefined),
46+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNKNOWN => Ok(Self::Unknown),
47+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNUSED => Ok(Self::Unused),
48+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_NORMAL => Ok(Self::Normal),
49+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_CONTROL => Ok(Self::Control),
50+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_USER_DEFINED => Ok(Self::UserDefined),
51+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_BYTE => Ok(Self::Byte),
52+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_NORMALIZED => Ok(Self::Normalized),
53+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_LSTRIP => Ok(Self::LStrip),
54+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_RSTRIP => Ok(Self::RStrip),
55+
llama_cpp_sys_2::LLAMA_TOKEN_ATTR_SINGLE_WORD => Ok(Self::SingleWord),
5556
_ => Err(LlamaTokenTypeFromIntError::UnknownValue(value as _)),
5657
}
5758
}

0 commit comments

Comments
 (0)