Skip to content

Commit 7e7df3d

Browse files
committed
Default to stream
1 parent 6b799fc commit 7e7df3d

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ clap = { version = "4.5.29", features = ["derive"] }
1010
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
1111
transformrs = { git = "https://github.com/rikhuijzer/transformrs.git", rev = "06b759f" }
1212
anyhow = "1"
13+
futures-util = "0.3.31"
1314

1415
[dev-dependencies]
1516
assert_cmd = "2"

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ Use plain text only; no markdown.
5959
Here is the text to check:
6060
6161
"
62+
MODEL="deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
6263

63-
(echo "$PROMPT"; cat README.md) | ata chat --model="deepseek-ai/DeepSeek-R1"
64+
(echo "$PROMPT"; cat README.md) | ata chat --model="$MODEL"
6465
```
6566

6667
### Text to Speech in Bash

src/chat.rs

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::fs::File;
22
use std::io::Write;
33
use transformrs::Message;
44
use transformrs::Provider;
5+
use futures_util::stream::StreamExt;
56

67
#[derive(clap::Parser)]
78
pub(crate) struct ChatArgs {
@@ -42,19 +43,29 @@ pub(crate) async fn chat(args: &ChatArgs, key: &transformrs::Key, input: &str) {
4243
.clone()
4344
.unwrap_or_else(|| default_model(&provider));
4445
let messages = vec![Message::from_str("user", input)];
45-
let resp = transformrs::chat::chat_completion(&provider, key, &model, &messages)
46-
.await
47-
.expect("Chat completion failed");
48-
if args.raw_json {
49-
let json = resp.raw_value();
50-
println!("{}", json.unwrap());
51-
}
52-
let resp = resp.structured().expect("Could not parse response");
53-
let content = resp.choices[0].message.content.clone();
54-
if let Some(output) = args.output.clone() {
55-
let mut file = File::create(output).unwrap();
56-
file.write_all(content.to_string().as_bytes()).unwrap();
46+
if args.stream {
47+
let mut stream = transformrs::chat::stream_chat_completion(&provider, key, &model, &messages)
48+
.await
49+
.expect("Streaming chat completion failed");
50+
while let Some(resp) = stream.next().await {
51+
let msg = resp.choices[0].delta.content.clone().unwrap_or_default();
52+
print!("{}", msg);
53+
}
5754
} else {
58-
println!("{}", content);
55+
let resp = transformrs::chat::chat_completion(&provider, key, &model, &messages)
56+
.await
57+
.expect("Chat completion failed");
58+
if args.raw_json {
59+
let json = resp.raw_value();
60+
println!("{}", json.unwrap());
61+
}
62+
let resp = resp.structured().expect("Could not parse response");
63+
let content = resp.choices[0].message.content.clone();
64+
if let Some(output) = args.output.clone() {
65+
let mut file = File::create(output).unwrap();
66+
file.write_all(content.to_string().as_bytes()).unwrap();
67+
} else {
68+
println!("{}", content);
69+
}
5970
}
6071
}

0 commit comments

Comments
 (0)