diff --git a/transformers/causal_llm/src/bin/client.rs b/transformers/causal_llm/src/bin/client.rs index 64f4d9f605..6d5a0e5033 100644 --- a/transformers/causal_llm/src/bin/client.rs +++ b/transformers/causal_llm/src/bin/client.rs @@ -64,6 +64,7 @@ impl Api { model: &str, prompt: impl Into, max_tokens: usize, + temperature: Option, ) -> Result { match &self { Api::OpenAICompletions(client, endpoint) => { @@ -71,6 +72,7 @@ impl Api { prompt: prompt.into(), model: model.to_string(), max_tokens, + temperature, stop: vec![], }; let response = client @@ -315,11 +317,13 @@ struct CompleteArgs { prompt: String, #[arg(short('n'), default_value = "50")] max_tokens: usize, + #[arg(short('t'), long("temperature"))] + temperature: Option, } impl CompleteArgs { async fn handle(&self, clients: &Clients) -> Result<()> { - let reply = clients.complete(&self.prompt, self.max_tokens).await?; + let reply = clients.complete(&self.prompt, self.max_tokens, self.temperature).await?; println!("{}", reply.text); eprintln!("prompt:{:?} generated:{}", reply.prompt_tokens, reply.generated_tokens); Ok(()) @@ -381,15 +385,16 @@ impl Clients { } async fn run_one_generate(&self, pp: usize, tg: usize) -> Result { - self.api.generate(&self.model, &self.get_one_prompt(pp), tg).await + self.api.generate(&self.model, &self.get_one_prompt(pp), tg, None).await } async fn complete( &self, prompt: impl Into, max_tokens: usize, + temperature: Option, ) -> Result { - self.api.generate(&self.model, prompt.into(), max_tokens).await + self.api.generate(&self.model, prompt.into(), max_tokens, temperature).await } async fn bench_one_generate(&self, pp: usize, tg: usize) -> Result { diff --git a/transformers/causal_llm/src/bin/common/mod.rs b/transformers/causal_llm/src/bin/common/mod.rs index af67340bdd..233e090f0b 100644 --- a/transformers/causal_llm/src/bin/common/mod.rs +++ b/transformers/causal_llm/src/bin/common/mod.rs @@ -5,6 +5,7 @@ pub struct OpenAICompletionQuery { pub prompt: String, pub model: String, pub max_tokens: usize, + pub temperature: Option, pub stop: Vec, }