@@ -2,6 +2,7 @@ use std::fs::File;
22use std:: io:: Write ;
33use transformrs:: Message ;
44use transformrs:: Provider ;
5+ use futures_util:: stream:: StreamExt ;
56
67#[ derive( clap:: Parser ) ]
78pub ( crate ) struct ChatArgs {
@@ -42,19 +43,29 @@ pub(crate) async fn chat(args: &ChatArgs, key: &transformrs::Key, input: &str) {
4243 . clone ( )
4344 . unwrap_or_else ( || default_model ( & provider) ) ;
4445 let messages = vec ! [ Message :: from_str( "user" , input) ] ;
45- let resp = transformrs:: chat:: chat_completion ( & provider, key, & model, & messages)
46- . await
47- . expect ( "Chat completion failed" ) ;
48- if args. raw_json {
49- let json = resp. raw_value ( ) ;
50- println ! ( "{}" , json. unwrap( ) ) ;
51- }
52- let resp = resp. structured ( ) . expect ( "Could not parse response" ) ;
53- let content = resp. choices [ 0 ] . message . content . clone ( ) ;
54- if let Some ( output) = args. output . clone ( ) {
55- let mut file = File :: create ( output) . unwrap ( ) ;
56- file. write_all ( content. to_string ( ) . as_bytes ( ) ) . unwrap ( ) ;
46+ if args. stream {
47+ let mut stream = transformrs:: chat:: stream_chat_completion ( & provider, key, & model, & messages)
48+ . await
49+ . expect ( "Streaming chat completion failed" ) ;
50+ while let Some ( resp) = stream. next ( ) . await {
51+ let msg = resp. choices [ 0 ] . delta . content . clone ( ) . unwrap_or_default ( ) ;
52+ print ! ( "{}" , msg) ;
53+ }
5754 } else {
58- println ! ( "{}" , content) ;
55+ let resp = transformrs:: chat:: chat_completion ( & provider, key, & model, & messages)
56+ . await
57+ . expect ( "Chat completion failed" ) ;
58+ if args. raw_json {
59+ let json = resp. raw_value ( ) ;
60+ println ! ( "{}" , json. unwrap( ) ) ;
61+ }
62+ let resp = resp. structured ( ) . expect ( "Could not parse response" ) ;
63+ let content = resp. choices [ 0 ] . message . content . clone ( ) ;
64+ if let Some ( output) = args. output . clone ( ) {
65+ let mut file = File :: create ( output) . unwrap ( ) ;
66+ file. write_all ( content. to_string ( ) . as_bytes ( ) ) . unwrap ( ) ;
67+ } else {
68+ println ! ( "{}" , content) ;
69+ }
5970 }
6071}
0 commit comments