1+ use axum:: { http:: StatusCode , response:: IntoResponse , routing:: post, Json , Router } ;
2+ use config:: { validate_registered_detectors, DetectorConfig , GatewayConfig } ;
3+ use serde_json:: json;
4+ use serde_json:: { Map , Value } ;
15use std:: {
26 collections:: HashMap ,
37 env,
48 net:: { IpAddr , SocketAddr } ,
59} ;
6- use std:: ptr:: null;
7- use axum:: { http:: StatusCode , response:: IntoResponse , routing:: post, Json , Router } ;
8- use config:: { validate_registered_detectors, DetectorConfig , GatewayConfig } ;
9- use serde_json:: json;
10- use serde_json:: { Map , Value } ;
1110use tower_http:: trace:: { self , TraceLayer } ;
1211use tracing:: Level ;
13- use crate :: api:: { GenerationChoice , GenerationMessage , OrchestratorResponse } ;
14- use crate :: config:: RouteConfig ;
1512
16- mod config;
1713mod api;
14+ mod config;
1815
16+ use api:: {
17+ Detections , GenerationChoice , GenerationMessage , OrchestratorDetector , OrchestratorResponse ,
18+ } ;
1919
2020fn get_orchestrator_detectors (
2121 detectors : Vec < String > ,
2222 detector_config : Vec < DetectorConfig > ,
23- ) -> api :: OrchestratorDetector {
23+ ) -> OrchestratorDetector {
2424 let mut input_detectors = HashMap :: new ( ) ;
2525 let mut output_detectors = HashMap :: new ( ) ;
2626
@@ -36,7 +36,7 @@ fn get_orchestrator_detectors(
3636 }
3737 }
3838
39- api :: OrchestratorDetector {
39+ OrchestratorDetector {
4040 input : input_detectors,
4141 output : output_detectors,
4242 }
@@ -85,30 +85,27 @@ async fn main() {
8585
8686 let ip: IpAddr = host. parse ( ) . expect ( "Failed to parse host IP address" ) ;
8787 let addr = SocketAddr :: from ( ( ip, http_port) ) ;
88- tracing:: info!( "listening on {}" , addr) ;
8988
9089 let listener = tokio:: net:: TcpListener :: bind ( addr) . await . unwrap ( ) ;
90+ tracing:: info!( "listening on {}" , addr) ;
91+
9192 axum:: serve ( listener, app) . await . unwrap ( ) ;
9293}
9394
94-
95- async fn handle_orchestrator_payload_parsing ( orchestrator_response : & mut OrchestratorResponse , route_fallback_message : Option < String > ) {
96- if route_fallback_message. is_some ( ) && orchestrator_response. detections . is_some ( ) {
97- let fallback_generation = GenerationMessage {
98- content : route_fallback_message. clone ( ) . unwrap ( ) ,
99- refusal : None ,
100- role : String :: from ( "assistant" ) ,
101- tool_calls : None ,
102- audio : None ,
103- } ;
104-
105- orchestrator_response. choices = vec ! [ GenerationChoice {
106- message: fallback_generation,
95+ fn check_payload_detections (
96+ detections : & Option < Detections > ,
97+ route_fallback_message : Option < String > ,
98+ ) -> Option < GenerationChoice > {
99+ if let ( Some ( fallback_message) , Some ( _) ) = ( route_fallback_message, detections) {
100+ return Some ( GenerationChoice {
101+ message : GenerationMessage :: new ( fallback_message) ,
107102 finish_reason : String :: from ( "stop" ) ,
108103 index : 0 ,
109- logprobs: None
110- } ] ;
104+ logprobs : None ,
105+ } ) ;
111106 }
107+
108+ None
112109}
113110
114111async fn handle_generation (
@@ -121,45 +118,47 @@ async fn handle_generation(
121118
122119 let mut payload = payload. as_object_mut ( ) ;
123120
124- let url;
125- if gateway_config. orchestrator . port . is_some ( ) {
126- url = format ! (
121+ let url: String = match gateway_config. orchestrator . port {
122+ Some ( port) => format ! (
127123 "http://{}:{}/api/v2/chat/completions-detection" ,
128- gateway_config. orchestrator. host, gateway_config. orchestrator. port. unwrap( )
129- ) ;
130- } else {
131- url = format ! (
124+ gateway_config. orchestrator. host, port
125+ ) ,
126+ None => format ! (
132127 "https://{}/api/v2/chat/completions-detection" ,
133128 gateway_config. orchestrator. host
134- ) ;
135- }
136-
129+ ) ,
130+ } ;
137131
138132 payload. as_mut ( ) . unwrap ( ) . insert (
139133 "detectors" . to_string ( ) ,
140134 serde_json:: to_value ( & orchestrator_detectors) . unwrap ( ) ,
141135 ) ;
142- let response_payload = orchestrator_post_request ( payload, & url) . await ;
143- if response_payload. is_ok ( ) {
144- let response_unwrapped = response_payload. unwrap ( ) ;
145- println ! ( "{}" , response_unwrapped) ;
146-
147- let mut response : OrchestratorResponse = serde_json:: from_value ( response_unwrapped) . unwrap ( ) ;
148- handle_orchestrator_payload_parsing ( & mut response, route_fallback_message) . await ;
149- Ok ( Json ( json ! ( response) ) . into_response ( ) )
150- } else {
151- //println!("{:#?}", response_payload.err().unwrap().to_string());
152- Err ( ( StatusCode :: INTERNAL_SERVER_ERROR , response_payload. err ( ) . unwrap ( ) . to_string ( ) ) )
136+
137+ let response_result = orchestrator_post_request ( payload, & url) . await ;
138+
139+ match response_result {
140+ Ok ( mut orchestrator_response) => {
141+ let detection =
142+ check_payload_detections ( & orchestrator_response. detections , route_fallback_message) ;
143+ if let Some ( message) = detection {
144+ orchestrator_response. choices = vec ! [ message] ;
145+ }
146+ Ok ( Json ( json ! ( orchestrator_response) ) . into_response ( ) )
147+ }
148+ Err ( _) => Err ( (
149+ StatusCode :: INTERNAL_SERVER_ERROR ,
150+ response_result. err ( ) . unwrap ( ) . to_string ( ) ,
151+ ) ) ,
153152 }
154153}
155154
156155async fn orchestrator_post_request (
157156 payload : Option < & mut Map < String , Value > > ,
158157 url : & str ,
159- ) -> Result < serde_json :: Value , anyhow:: Error > {
158+ ) -> Result < OrchestratorResponse , anyhow:: Error > {
160159 let client = reqwest:: Client :: new ( ) ;
161160 let response = client. post ( url) . json ( & payload) . send ( ) ;
162161
163162 let json = response. await ?. json ( ) . await ?;
164- Ok ( json)
163+ Ok ( serde_json :: from_value ( json) . expect ( "unexpected json response from request" ) )
165164}
0 commit comments