@@ -3,38 +3,42 @@ use std::{
33 env,
44 net:: { IpAddr , SocketAddr } ,
55} ;
6-
6+ use std :: ptr :: null ;
77use axum:: { http:: StatusCode , response:: IntoResponse , routing:: post, Json , Router } ;
88use config:: { validate_registered_detectors, DetectorConfig , GatewayConfig } ;
9- use serde:: Serialize ;
109use serde_json:: json;
1110use serde_json:: { Map , Value } ;
1211use tower_http:: trace:: { self , TraceLayer } ;
1312use tracing:: Level ;
13+ use crate :: api:: { GenerationChoice , GenerationMessage , OrchestratorResponse } ;
14+ use crate :: config:: RouteConfig ;
1415
1516mod config;
17+ mod api;
1618
17- #[ derive( Debug , Serialize ) ]
18- struct OrchestratorDetector {
19- input : HashMap < String , serde_json:: Value > ,
20- // implement when output is completed, also need to see about splitting detectors in config to input/output
21- // output: HashMap<String, serde_json::Value>,
22- }
2319
2420fn get_orchestrator_detectors (
2521 detectors : Vec < String > ,
2622 detector_config : Vec < DetectorConfig > ,
27- ) -> OrchestratorDetector {
23+ ) -> api :: OrchestratorDetector {
2824 let mut input_detectors = HashMap :: new ( ) ;
25+ let mut output_detectors = HashMap :: new ( ) ;
2926
3027 for detector in detector_config {
3128 if detectors. contains ( & detector. name ) && detector. detector_params . is_some ( ) {
32- input_detectors. insert ( detector. name , detector. detector_params . unwrap ( ) ) ;
29+ let detector_params = detector. detector_params . unwrap ( ) ;
30+ if detector. input {
31+ input_detectors. insert ( detector. name . clone ( ) , detector_params. clone ( ) ) ;
32+ }
33+ if detector. output {
34+ output_detectors. insert ( detector. name , detector_params) ;
35+ }
3336 }
3437 }
3538
36- OrchestratorDetector {
39+ api :: OrchestratorDetector {
3740 input : input_detectors,
41+ output : output_detectors,
3842 }
3943}
4044
@@ -59,10 +63,11 @@ async fn main() {
5963 let gateway_config = gateway_config. clone ( ) ;
6064 let detectors = route. detectors . clone ( ) ;
6165 let path = format ! ( "/{}/v1/chat/completions" , route. name) ;
66+ let fallback_message = route. fallback_message . clone ( ) ;
6267 app = app. route (
6368 & path,
6469 post ( |Json ( payload) : Json < serde_json:: Value > | {
65- handle_generation ( Json ( payload) , detectors, gateway_config)
70+ handle_generation ( Json ( payload) , detectors, gateway_config, fallback_message )
6671 } ) ,
6772 ) ;
6873 tracing:: info!( "exposed endpoints: {}" , path) ;
@@ -86,27 +91,66 @@ async fn main() {
8691 axum:: serve ( listener, app) . await . unwrap ( ) ;
8792}
8893
94+
95+ async fn handle_orchestrator_payload_parsing ( orchestrator_response : & mut OrchestratorResponse , route_fallback_message : Option < String > ) {
96+ if route_fallback_message. is_some ( ) && orchestrator_response. detections . is_some ( ) {
97+ let fallback_generation = GenerationMessage {
98+ content : route_fallback_message. clone ( ) . unwrap ( ) ,
99+ refusal : None ,
100+ role : String :: from ( "assistant" ) ,
101+ tool_calls : None ,
102+ audio : None ,
103+ } ;
104+
105+ orchestrator_response. choices = vec ! [ GenerationChoice {
106+ message: fallback_generation,
107+ finish_reason: String :: from( "stop" ) ,
108+ index: 0 ,
109+ logprobs: None
110+ } ] ;
111+ }
112+ }
113+
89114async fn handle_generation (
90115 Json ( mut payload) : Json < serde_json:: Value > ,
91116 detectors : Vec < String > ,
92117 gateway_config : GatewayConfig ,
118+ route_fallback_message : Option < String > ,
93119) -> Result < impl IntoResponse , ( StatusCode , String ) > {
94120 let orchestrator_detectors = get_orchestrator_detectors ( detectors, gateway_config. detectors ) ;
95121
96122 let mut payload = payload. as_object_mut ( ) ;
97123
98- let url = format ! (
99- "http://{}:{}/api/v2/chat/completions-detection" ,
100- gateway_config. orchestrator. host, gateway_config. orchestrator. port
101- ) ;
124+ let url;
125+ if gateway_config. orchestrator . port . is_some ( ) {
126+ url = format ! (
127+ "http://{}:{}/api/v2/chat/completions-detection" ,
128+ gateway_config. orchestrator. host, gateway_config. orchestrator. port. unwrap( )
129+ ) ;
130+ } else {
131+ url = format ! (
132+ "https://{}/api/v2/chat/completions-detection" ,
133+ gateway_config. orchestrator. host
134+ ) ;
135+ }
136+
102137
103138 payload. as_mut ( ) . unwrap ( ) . insert (
104139 "detectors" . to_string ( ) ,
105140 serde_json:: to_value ( & orchestrator_detectors) . unwrap ( ) ,
106141 ) ;
107142 let response_payload = orchestrator_post_request ( payload, & url) . await ;
108-
109- Ok ( Json ( json ! ( response_payload. unwrap( ) ) ) . into_response ( ) )
143+ if response_payload. is_ok ( ) {
144+ let response_unwrapped = response_payload. unwrap ( ) ;
145+ println ! ( "{}" , response_unwrapped) ;
146+
147+ let mut response : OrchestratorResponse = serde_json:: from_value ( response_unwrapped) . unwrap ( ) ;
148+ handle_orchestrator_payload_parsing ( & mut response, route_fallback_message) . await ;
149+ Ok ( Json ( json ! ( response) ) . into_response ( ) )
150+ } else {
151+ //println!("{:#?}", response_payload.err().unwrap().to_string());
152+ Err ( ( StatusCode :: INTERNAL_SERVER_ERROR , response_payload. err ( ) . unwrap ( ) . to_string ( ) ) )
153+ }
110154}
111155
112156async fn orchestrator_post_request (
@@ -117,6 +161,5 @@ async fn orchestrator_post_request(
117161 let response = client. post ( url) . json ( & payload) . send ( ) ;
118162
119163 let json = response. await ?. json ( ) . await ?;
120- println ! ( "{:#?}" , json) ;
121164 Ok ( json)
122165}
0 commit comments