Skip to content

Commit 9655fb7

Browse files
committed
Add text to image
1 parent e044eb5 commit 9655fb7

File tree

5 files changed

+143
-2
lines changed

5 files changed

+143
-2
lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
**/target/
22
**/*.rs.bk
3-
ata.toml
43
Cargo.lock
54
test.env
6-
**/tmp*
5+
**/tmp*
6+
**/*.png
7+
**/*.jpg

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,19 @@ $ export OPENAI_KEY="$(grep 'OPENAI_KEY' .env | cut -d= -f2)"
1212
$ cat myfile.txt | trf tts | vlc - --intf dummy
1313
```
1414

15+
Or generate an image from text:
16+
17+
```sh
18+
$ export DEEPINFRA_KEY="$(grep 'DEEPINFRA_KEY' .env | cut -d= -f2)"
19+
20+
$ echo "A photo of a beach in Hawaii" | cargo run -- --verbose tti \
21+
--model=black-forest-labs/FLUX-1-dev \
22+
--output=myfile --width=1024 --height=1024 --steps=25
23+
24+
$ ls myfile*
25+
myfile.png
26+
```
27+
1528
You can also chat with an LLM:
1629

1730
```sh

src/main.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
mod chat;
2+
mod tti;
23
mod tts;
34

45
use chat::ChatArgs;
56
use clap::Parser;
67
use std::io::Read;
78
use tracing::subscriber::SetGlobalDefaultError;
89
use transformrs::Key;
10+
use tti::TextToImageArgs;
911
use tts::TextToSpeechArgs;
1012

1113
#[derive(clap::Subcommand)]
@@ -15,6 +17,11 @@ enum Commands {
1517
/// Takes text input from stdin and chats with an AI model.
1618
#[command()]
1719
Chat(ChatArgs),
20+
/// Convert text to image
21+
///
22+
/// Takes text input from stdin and converts it to an image using text-to-image models.
23+
#[command()]
24+
Tti(TextToImageArgs),
1825
/// Convert text to speech
1926
///
2027
/// Takes text input from stdin and converts it to speech using text-to-speech models.
@@ -83,6 +90,9 @@ async fn main() {
8390
Commands::Chat(args) => {
8491
chat::chat(&args, &key, &input).await;
8592
}
93+
Commands::Tti(args) => {
94+
tti::tti(&args, &key, &input).await;
95+
}
8696
Commands::Tts(args) => {
8797
tts::tts(&args, &key, &input).await;
8898
}

src/tti.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use std::fs::File;
2+
use std::io::Write;
3+
use transformrs::text_to_image::TTIConfig;
4+
use transformrs::Provider;
5+
6+
#[derive(clap::Parser)]
7+
pub(crate) struct TextToImageArgs {
8+
/// Model to use (optional)
9+
#[arg(long)]
10+
model: Option<String>,
11+
12+
/// Number of steps (optional)
13+
#[arg(long, default_value_t = 10)]
14+
steps: u32,
15+
16+
/// CFG scale (optional)
17+
#[arg(long, default_value_t = 3)]
18+
cfg_scale: u32,
19+
20+
/// Height (optional)
21+
#[arg(long, default_value_t = 512)]
22+
height: u32,
23+
24+
/// Width (optional)
25+
#[arg(long, default_value_t = 512)]
26+
width: u32,
27+
28+
/// Output filename without extension
29+
#[arg(long, short = 'o')]
30+
output: Option<String>,
31+
}
32+
33+
fn default_model(provider: &Provider) -> String {
34+
match provider {
35+
Provider::DeepInfra => "black-forest-labs/FLUX-1-schnell".to_string(),
36+
_ => "black-forest-labs/FLUX-1-dev".to_string(),
37+
}
38+
}
39+
40+
pub(crate) async fn tti(args: &TextToImageArgs, key: &transformrs::Key, input: &str) {
41+
let provider = key.provider.clone();
42+
let config = TTIConfig {
43+
model: args
44+
.model
45+
.clone()
46+
.unwrap_or_else(|| default_model(&provider)),
47+
steps: Some(args.steps),
48+
cfg_scale: Some(args.cfg_scale),
49+
height: Some(args.height),
50+
width: Some(args.width),
51+
};
52+
let resp = transformrs::text_to_image::text_to_image(key, config, input)
53+
.await
54+
.unwrap()
55+
.structured()
56+
.unwrap();
57+
let encoded = &resp.images[0];
58+
let image = encoded.base64_decode().unwrap();
59+
if let Some(output) = &args.output {
60+
let filename = format!("{}.{}", output, image.filetype);
61+
let mut file = File::create(filename).unwrap();
62+
file.write_all(&image.image).unwrap();
63+
} else {
64+
std::io::stdout().write_all(&image.image).unwrap();
65+
}
66+
}

tests/tti.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
mod common;
2+
3+
use common::load_key;
4+
use common::trf;
5+
use transformrs::Provider;
6+
7+
#[test]
8+
fn no_args() -> Result<(), Box<dyn std::error::Error>> {
9+
let dir = tempfile::tempdir().unwrap();
10+
let mut cmd = trf();
11+
let key = load_key(&Provider::DeepInfra);
12+
let cmd = cmd
13+
.arg("tti")
14+
.env("DEEPINFRA_KEY", key)
15+
.write_stdin("Hello world")
16+
.current_dir(&dir);
17+
let output = cmd.assert().success().get_output().stdout.clone();
18+
19+
assert!(output.len() > 0);
20+
21+
Ok(())
22+
}
23+
24+
fn default_settings_helper(provider: &Provider) -> Result<(), Box<dyn std::error::Error>> {
25+
let dir = tempfile::tempdir().unwrap();
26+
let mut cmd = trf();
27+
let key = load_key(provider);
28+
let name = provider.key_name();
29+
cmd.arg("--verbose")
30+
.arg("tti")
31+
.arg("--output=myfile")
32+
.arg("--width=128")
33+
.arg("--height=128")
34+
.arg("--steps=10")
35+
.arg("--cfg-scale=3")
36+
.env(name, key)
37+
.write_stdin("image of a beach")
38+
.current_dir(&dir)
39+
.assert()
40+
.success();
41+
42+
let path = dir.path().join("myfile.png");
43+
assert!(path.exists());
44+
45+
Ok(())
46+
}
47+
48+
#[test]
49+
fn default_settings_deepinfra() -> Result<(), Box<dyn std::error::Error>> {
50+
default_settings_helper(&Provider::DeepInfra)
51+
}

0 commit comments

Comments
 (0)