@@ -5,130 +5,10 @@ use crate::grammar::LlamaGrammar;
 use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;

-/// struct to hold params for sampling
-#[derive(Debug)]
-#[deprecated(
-    since = "0.1.32",
-    note = "this does not scale well with many params and does not allow for changing of orders."
-)]
-pub struct Sampler<'grammar> {
-    token_data_array: LlamaTokenDataArray,
-    grammar: Option<&'grammar mut LlamaGrammar>,
-    temperature: Option<f32>,
-}
-
-impl<'grammar> Sampler<'grammar> {
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    fn sample(self, llama_context: &mut LlamaContext) -> LlamaToken {
-        match self {
-            Sampler {
-                token_data_array,
-                grammar: None,
-                temperature: None,
-            } => llama_context.sample_token_greedy(token_data_array),
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: None,
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: None,
-                temperature: Some(temp),
-            } => {
-                llama_context.sample_temp(&mut token_data_array, temp);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                token_data_array.data[0].id()
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: Some(temperature),
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                llama_context.sample_temp(&mut token_data_array, temperature);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-        }
-    }
-
-    /// Create a new sampler.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn new(llama_token_data_array: LlamaTokenDataArray) -> Self {
-        Self {
-            token_data_array: llama_token_data_array,
-            grammar: None,
-            temperature: None,
-        }
-    }
-
-    /// Set the grammar for sampling.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn with_grammar(mut self, grammar: &'grammar mut LlamaGrammar) -> Self {
-        self.grammar = Some(grammar);
-        self
-    }
-
-    /// Set the temperature for sampling.
-    ///
-    /// ```
-    /// # use llama_cpp_2::context::LlamaContext;
-    /// # use llama_cpp_2::context::sample::Sampler;
-    /// # use llama_cpp_2::grammar::LlamaGrammar;
-    /// # use llama_cpp_2::token::data::LlamaTokenData;
-    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
-    /// # use llama_cpp_2::token::LlamaToken;
-    ///
-    /// let _sampler = Sampler::new(LlamaTokenDataArray::new(vec![LlamaTokenData::new(LlamaToken(0), 0.0, 0.0)], false))
-    ///     .with_temperature(0.5);
-    /// ```
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn with_temperature(mut self, temperature: f32) -> Self {
-        if temperature == 0.0 {
-            return self;
-        }
-        self.temperature = Some(temperature);
-        self
-    }
-}
+#[cfg(feature = "sampler")]
+pub mod sampler;

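The deprecated builder above is superseded by the feature-gated `sampler` module. For reference, the removed `Sampler::new(..).with_temperature(..)` path maps onto the surviving `LlamaContext` methods called directly; a minimal sketch, where `ctx` (a `LlamaContext`) and `token_data` (a `LlamaTokenDataArray`) are placeholder names:

```rust
// Mirrors the removed match arm for `temperature: Some(temp), grammar: None`.
ctx.sample_temp(&mut token_data, 0.5);     // scale logits by the temperature
ctx.sample_token_softmax(&mut token_data); // sort by logit and fill probabilities
let token = token_data.data[0].id();       // highest-probability candidate
```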
 impl LlamaContext<'_> {
-    /// Sample a token.
-    ///
-    /// # Panics
-    ///
-    /// - sampler contains no tokens
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn sample(&mut self, sampler: Sampler) -> LlamaToken {
-        sampler.sample(self)
-    }
-
     /// Accept a token into the grammar.
     pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) {
         unsafe {
@@ -157,38 +37,20 @@ impl LlamaContext<'_> {
         }
     }

-    /// Modify [`token_data`] in place using temperature sampling.
-    ///
-    /// # Panics
-    ///
-    /// - [`temperature`] is not between 0.0 and 1.0
-    pub fn sample_temp(&self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
-        assert!(
-            temperature >= 0.0,
-            "temperature must be positive (was {temperature})"
-        );
-        assert!(
-            temperature <= 1.0,
-            "temperature must be less than or equal to 1.0 (was {temperature})"
-        );
-        if temperature == 0.0 {
-            return;
-        }
-        let ctx: *mut llama_cpp_sys_2::llama_context = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_temp`]
+    pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
+        token_data.sample_temp(Some(self), temperature);
     }

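From here on, the `LlamaContext` sampling methods are thin wrappers over `LlamaTokenDataArray`. Judging by the `Some(self)` argument, the array-side methods take an optional context, so the two call forms below should be interchangeable; a sketch with the same placeholder names as above:

```rust
// Wrapper form, as kept on LlamaContext:
ctx.sample_temp(&mut token_data, 0.8);
// Delegated form; the Option<&mut LlamaContext> parameter is inferred from
// the `Some(self)` call in the diff, not confirmed from the module itself.
token_data.sample_temp(Some(&mut ctx), 0.8);
```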
-    /// Sample a token greedily.
+    /// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits.
+    ///
+    /// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead.
     ///
     /// # Panics
     ///
-    /// - [`token_data`] is empty
+    /// - if `token_data` is empty
     #[must_use]
-    pub fn sample_token_greedy(&self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
+    pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
         assert!(!token_data.data.is_empty(), "no tokens");
         let mut data_arr = llama_cpp_sys_2::llama_token_data_array {
             data: token_data
@@ -207,39 +69,34 @@
         LlamaToken(token)
     }

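The new doc comment warns that greedy sampling reads logits only. The grammar-constrained flow from the removed builder still composes out of the remaining methods; a sketch, with `grammar` a placeholder `LlamaGrammar`:

```rust
// Mirrors the removed `grammar: Some(..), temperature: None` arm.
ctx.sample_grammar(&mut token_data, &mut grammar); // mask grammar-rejected tokens
let token = ctx.sample_token_greedy(token_data);   // pick the best surviving logit
ctx.grammar_accept_token(&mut grammar, token);     // advance the grammar state
```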
-    /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/).
-    pub fn sample_tail_free(&self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_tail_free`]
+    pub fn sample_tail_free(
+        &mut self,
+        token_data: &mut LlamaTokenDataArray,
+        z: f32,
+        min_keep: usize,
+    ) {
+        token_data.sample_tail_free(Some(self), z, min_keep);
     }

-    /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666).
-    pub fn sample_typical(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_typical`]
+    pub fn sample_typical(
+        &mut self,
+        token_data: &mut LlamaTokenDataArray,
+        p: f32,
+        min_keep: usize,
+    ) {
+        token_data.sample_typical(Some(self), p, min_keep);
     }

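Tail-free and locally typical sampling follow the same delegation pattern; an illustrative sketch (the `z` and `p` values are arbitrary, not recommendations):

```rust
// Tail-free sampling: in llama.cpp, z >= 1.0 leaves the candidates untouched,
// so only values below 1.0 actually truncate the tail.
ctx.sample_tail_free(&mut token_data, 0.95, 1);
// Locally typical sampling with p = 0.9, keeping at least one candidate.
ctx.sample_typical(&mut token_data, 0.9, 1);
```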
-    /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)"
-    pub fn sample_top_p(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_top_p`]
+    pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
+        token_data.sample_top_p(Some(self), p, min_keep);
     }

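Chaining several of these stages gives the usual llama.cpp-style pipeline; a minimal sketch with illustrative hyperparameters:

```rust
ctx.sample_top_k(&mut token_data, 40, 1);   // keep the 40 highest-logit candidates
ctx.sample_top_p(&mut token_data, 0.95, 1); // nucleus: smallest set with mass >= 0.95
ctx.sample_temp(&mut token_data, 0.8);      // temperature-scale the survivors
ctx.sample_token_softmax(&mut token_data);  // normalize to a sorted distribution
let token = token_data.data[0].id();        // or sample from the distribution
```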
     /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
     pub fn sample_min_p(
-        &self,
+        &mut self,
         llama_token_data: &mut LlamaTokenDataArray,
         p: f32,
         min_keep: usize,
@@ -252,24 +109,14 @@ impl LlamaContext<'_> {
         }
     }

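Per the linked PR, min-p discards candidates whose probability falls below `p` times that of the most likely token, while `min_keep` guarantees a floor on how many survive; a sketch:

```rust
// Drop tokens with prob < 0.05 * max_prob, but always keep at least one.
ctx.sample_min_p(&mut token_data, 0.05, 1);
```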
-    /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
-    pub fn sample_top_k(&self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_top_k`]
+    pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
+        token_data.sample_top_k(Some(self), k, min_keep);
     }

-    /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    pub fn sample_token_softmax(&self, token_data: &mut LlamaTokenDataArray) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_softmax`]
+    pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) {
+        token_data.sample_softmax(Some(self));
     }

     /// See [`LlamaTokenDataArray::sample_repetition_penalty`]