Skip to content

Commit ed4da72

Browse files
committed
Implement chat and tests
1 parent 5363c22 commit ed4da72

File tree

6 files changed

+177
-26
lines changed

6 files changed

+177
-26
lines changed

README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,38 @@
44

55
## Examples
66

7+
### Chat in Bash
8+
9+
We can chat straight from the command line.
10+
For example, via the DeepInfra API:
11+
12+
```sh
13+
$ DEEPINFRA_KEY="$(cat /path/to/key)"; echo "hi there" | ata chat
14+
```
15+
16+
This defaults to the `meta-llama/Llama-3.3-70B-Instruct` model.
17+
We can also create a Bash script to provide some default settings to the chat.
18+
For example, create a file called `chat.sh` with the following content:
19+
20+
```bash
21+
#!/usr/bin/env bash
22+
23+
# Exit on (pipe) errors.
24+
set -euo pipefail
25+
26+
export OPENAI_KEY="$(cat /path/to/key)"
27+
28+
ata chat --model="gpt-4o"
29+
```
30+
31+
and add it to your PATH.
32+
Now, we can use it like this:
33+
34+
```sh
35+
$ echo "This is a test. Respond with 'hello'." | ata chat
36+
hello
37+
```
38+
739
### Text to Speech in Bash
840

941
We can read a file out loud from the command line.

src/chat.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
use std::fs::File;
2+
use std::io::Write;
3+
use transformrs::Message;
4+
use transformrs::Provider;
5+
6+
#[derive(clap::Parser)]
7+
pub(crate) struct ChatArgs {
8+
/// Model to use (optional)
9+
#[arg(long)]
10+
model: Option<String>,
11+
12+
/// Output file (optional)
13+
#[arg(long, short = 'o')]
14+
output: Option<String>,
15+
16+
/// Stream output
17+
#[arg(long, default_value_t = true)]
18+
stream: bool,
19+
20+
/// Raw JSON output
21+
#[arg(long)]
22+
raw_json: bool,
23+
24+
/// Language code (optional)
25+
#[arg(long)]
26+
language_code: Option<String>,
27+
}
28+
29+
fn default_model(provider: &Provider) -> String {
30+
match provider {
31+
Provider::Google => "models/gemini-1.5-flash",
32+
Provider::OpenAI => "gpt-4o-mini",
33+
_ => "meta-llama/Llama-3.3-70B-Instruct",
34+
}
35+
.to_string()
36+
}
37+
38+
pub(crate) async fn chat(args: &ChatArgs, key: &transformrs::Key, input: &str) {
39+
let provider = key.provider.clone();
40+
let model = args
41+
.model
42+
.clone()
43+
.unwrap_or_else(|| default_model(&provider));
44+
let messages = vec![Message::from_str("user", input)];
45+
let resp = transformrs::chat::chat_completion(&provider, key, &model, &messages)
46+
.await
47+
.expect("Chat completion failed");
48+
if args.raw_json {
49+
let json = resp.raw_value();
50+
println!("{}", json.unwrap());
51+
}
52+
let resp = resp.structured().expect("Could not parse response");
53+
let content = resp.choices[0].message.content.clone();
54+
if let Some(output) = args.output.clone() {
55+
let mut file = File::create(output).unwrap();
56+
file.write_all(content.to_string().as_bytes()).unwrap();
57+
} else {
58+
println!("{}", content);
59+
}
60+
}

src/main.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
mod chat;
12
mod tts;
23

4+
use chat::ChatArgs;
35
use clap::Parser;
46
use std::io::Read;
57
use transformrs::Key;
68
use tts::TextToSpeechArgs;
79

810
#[derive(clap::Subcommand)]
911
enum Commands {
12+
/// OpenAI-compatible chat.
13+
///
14+
/// Takes text input from stdin and chats with an AI model.
15+
#[command()]
16+
Chat(ChatArgs),
1017
/// Convert text to speech
1118
///
1219
/// Takes text input from stdin and converts it to speech using text-to-speech models.
@@ -50,6 +57,9 @@ async fn main() {
5057
let key = find_single_key(keys);
5158

5259
match args.command {
60+
Commands::Chat(args) => {
61+
chat::chat(&args, &key, &input).await;
62+
}
5363
Commands::Tts(args) => {
5464
tts::tts(&args, &key, &input).await;
5565
}

tests/chat.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
mod common;
2+
3+
use common::ata;
4+
use common::load_key;
5+
use predicates::prelude::*;
6+
use transformrs::Provider;
7+
8+
fn canonicalize_response(text: &str) -> String {
9+
text.to_lowercase()
10+
.trim()
11+
.trim_end_matches('.')
12+
.trim_end_matches('!')
13+
.to_string()
14+
}
15+
16+
#[test]
17+
fn unexpected_argument() -> Result<(), Box<dyn std::error::Error>> {
18+
let mut cmd = ata();
19+
cmd.arg("foobar");
20+
cmd.assert()
21+
.failure()
22+
.stderr(predicate::str::contains("unrecognized subcommand"));
23+
24+
Ok(())
25+
}
26+
27+
#[test]
28+
fn tts_no_args() -> Result<(), Box<dyn std::error::Error>> {
29+
let dir = tempfile::tempdir().unwrap();
30+
let mut cmd = ata();
31+
let key = load_key(&Provider::DeepInfra);
32+
let cmd = cmd
33+
.arg("chat")
34+
.env("DEEPINFRA_KEY", key)
35+
.write_stdin("This is a test. Respond with 'hello'.")
36+
.current_dir(&dir);
37+
let output = cmd.assert().success().get_output().stdout.clone();
38+
39+
let text = String::from_utf8(output.clone()).unwrap();
40+
let content = canonicalize_response(&text);
41+
assert_eq!(content, "hello");
42+
Ok(())
43+
}

tests/common/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use assert_cmd::Command;
2+
use std::io::BufRead;
3+
use transformrs::Provider;
4+
5+
pub fn ata() -> Command {
6+
Command::cargo_bin("ata").unwrap()
7+
}
8+
9+
#[allow(dead_code)]
10+
/// Load a key from the local .env file.
11+
///
12+
/// This is used for testing only. Expects the .env file to contain keys for providers in the following format:
13+
///
14+
/// ```
15+
/// DEEPINFRA_KEY="<KEY>"
16+
/// OPENAI_KEY="<KEY>"
17+
/// ```
18+
pub fn load_key(provider: &Provider) -> String {
19+
fn finder(line: &Result<String, std::io::Error>, provider: &Provider) -> bool {
20+
line.as_ref().unwrap().starts_with(&provider.key_name())
21+
}
22+
let path = std::path::Path::new("test.env");
23+
let file = std::fs::File::open(path).expect("Failed to open .env file");
24+
let reader = std::io::BufReader::new(file);
25+
let mut lines = reader.lines();
26+
let key = lines.find(|line| finder(line, provider)).unwrap().unwrap();
27+
key.split("=").nth(1).unwrap().to_string()
28+
}

tests/cli.rs renamed to tests/tts.rs

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
use assert_cmd::Command;
1+
mod common;
2+
3+
use common::ata;
4+
use common::load_key;
25
use predicates::prelude::*;
3-
use std::io::BufRead;
46
use transformrs::Provider;
57

6-
fn ata() -> Command {
7-
Command::cargo_bin("ata").unwrap()
8-
}
9-
108
#[test]
119
fn unexpected_argument() -> Result<(), Box<dyn std::error::Error>> {
1210
let mut cmd = ata();
@@ -29,26 +27,6 @@ fn help() -> Result<(), Box<dyn std::error::Error>> {
2927
Ok(())
3028
}
3129

32-
/// Load a key from the local .env file.
33-
///
34-
/// This is used for testing only. Expects the .env file to contain keys for providers in the following format:
35-
///
36-
/// ```
37-
/// DEEPINFRA_KEY="<KEY>"
38-
/// OPENAI_KEY="<KEY>"
39-
/// ```
40-
fn load_key(provider: &Provider) -> String {
41-
fn finder(line: &Result<String, std::io::Error>, provider: &Provider) -> bool {
42-
line.as_ref().unwrap().starts_with(&provider.key_name())
43-
}
44-
let path = std::path::Path::new("test.env");
45-
let file = std::fs::File::open(path).expect("Failed to open .env file");
46-
let reader = std::io::BufReader::new(file);
47-
let mut lines = reader.lines();
48-
let key = lines.find(|line| finder(line, provider)).unwrap().unwrap();
49-
key.split("=").nth(1).unwrap().to_string()
50-
}
51-
5230
#[test]
5331
fn tts_no_args() -> Result<(), Box<dyn std::error::Error>> {
5432
let dir = tempfile::tempdir().unwrap();

0 commit comments

Comments
 (0)