Skip to content

Commit 021b3f4

Browse files
committed
add fallback message support
1 parent b7175f7 commit 021b3f4

File tree

4 files changed

+143
-25
lines changed

4 files changed

+143
-25
lines changed

config/config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
orchestrator:
2-
host: localhost
3-
port: 8085
2+
host: guardrails-orchestrator-http-model-namespace.apps.rosa.trustyai-rob2.n1ai.p3.openshiftapps.com
43
detectors:
5-
- name: regex
4+
- name: regex_language
65
detector_params:
76
regex:
87
- email
@@ -11,6 +10,7 @@ detectors:
1110
routes:
1211
- name: pii
1312
detectors:
14-
- regex
13+
- regex_language
14+
fallback_message: "I'm sorry Dave, I'm afraid I can't do that."
1515
- name: passthrough
1616
detectors:

src/api.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use std::collections::HashMap;
2+
use serde::{Deserialize, Serialize};
3+
4+
#[derive(Debug, Serialize)]
5+
pub(crate) struct OrchestratorDetector {
6+
pub(crate) input: HashMap<String, serde_json::Value>,
7+
pub(crate) output: HashMap<String, serde_json::Value>,
8+
// implement when output is completed, also need to see about splitting detectors in config to input/output
9+
// output: HashMap<String, serde_json::Value>,
10+
}
11+
12+
#[derive(Serialize, Deserialize, Debug)]
13+
pub(crate) struct GenerationMessage {
14+
pub(crate) content: String,
15+
pub(crate) refusal: Option<String>,
16+
pub(crate) role: String,
17+
pub(crate) tool_calls: Option<serde_json::Value>,
18+
pub(crate) audio: Option<serde_json::Value>,
19+
}
20+
21+
#[derive(Serialize, Deserialize, Debug)]
22+
pub(crate) struct GenerationChoice {
23+
pub(crate) finish_reason: String,
24+
pub(crate) index: u32,
25+
pub(crate) message: GenerationMessage,
26+
pub(crate) logprobs: Option<serde_json::Value>,
27+
}
28+
29+
30+
#[derive(Serialize, Deserialize, Debug)]
31+
struct DetectionResult {
32+
start: serde_json::Value,
33+
end: u32,
34+
text: String,
35+
detection_type: String,
36+
detection: String,
37+
detector_id: String,
38+
score: f64,
39+
}
40+
41+
#[derive(Serialize, Deserialize, Debug)]
42+
struct InputDetection {
43+
message_index: u16,
44+
results: Option<Vec<DetectionResult>>
45+
}
46+
47+
#[derive(Serialize, Deserialize, Debug)]
48+
struct OutputDetection {
49+
choice_index: u32,
50+
results: Option<Vec<DetectionResult>>
51+
}
52+
53+
#[derive(Serialize, Deserialize, Debug)]
54+
pub(crate) struct Detections {
55+
input: Option<Vec<InputDetection>>,
56+
output: Option<Vec<OutputDetection>>,
57+
}
58+
59+
60+
#[derive(Serialize, Deserialize, Debug)]
61+
pub(crate) struct OrchestratorResponse {
62+
id: String,
63+
pub(crate) choices: Vec<GenerationChoice>,
64+
created: u64,
65+
model: String,
66+
service_tier: Option<String>,
67+
system_fingerprint: Option<String>,
68+
object: Option<String>,
69+
usage: serde_json::Value,
70+
pub(crate) detections: Option<Detections>,
71+
pub(crate) warnings: Option<Vec<HashMap<String, String>>>
72+
}

src/config.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,22 @@ pub struct GatewayConfig {
1212
#[derive(Debug, Deserialize, Clone)]
1313
pub struct OrchestratorConfig {
1414
pub host: String,
15-
pub port: u16,
15+
pub port: Option<u16>,
1616
}
1717

1818
#[derive(Debug, Deserialize, Clone)]
1919
pub struct DetectorConfig {
2020
pub name: String,
21+
pub input: bool,
22+
pub output: bool,
2123
pub detector_params: Option<serde_json::Value>,
2224
}
2325

2426
#[derive(Debug, Deserialize, Clone)]
2527
pub struct RouteConfig {
2628
pub name: String,
2729
pub detectors: Vec<String>,
30+
pub fallback_message: Option<String>,
2831
}
2932

3033
pub fn read_config(path: &str) -> GatewayConfig {
@@ -62,7 +65,7 @@ mod tests {
6265
let gc = GatewayConfig {
6366
orchestrator: OrchestratorConfig {
6467
host: "localhost".to_string(),
65-
port: 1234,
68+
port: Some(1234),
6669
},
6770
detectors: vec![DetectorConfig {
6871
name: "regex".to_string(),

src/main.rs

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,42 @@ use std::{
33
env,
44
net::{IpAddr, SocketAddr},
55
};
6-
6+
use std::ptr::null;
77
use axum::{http::StatusCode, response::IntoResponse, routing::post, Json, Router};
88
use config::{validate_registered_detectors, DetectorConfig, GatewayConfig};
9-
use serde::Serialize;
109
use serde_json::json;
1110
use serde_json::{Map, Value};
1211
use tower_http::trace::{self, TraceLayer};
1312
use tracing::Level;
13+
use crate::api::{GenerationChoice, GenerationMessage, OrchestratorResponse};
14+
use crate::config::RouteConfig;
1415

1516
mod config;
17+
mod api;
1618

17-
#[derive(Debug, Serialize)]
18-
struct OrchestratorDetector {
19-
input: HashMap<String, serde_json::Value>,
20-
// implement when output is completed, also need to see about splitting detectors in config to input/output
21-
// output: HashMap<String, serde_json::Value>,
22-
}
2319

2420
fn get_orchestrator_detectors(
2521
detectors: Vec<String>,
2622
detector_config: Vec<DetectorConfig>,
27-
) -> OrchestratorDetector {
23+
) -> api::OrchestratorDetector {
2824
let mut input_detectors = HashMap::new();
25+
let mut output_detectors = HashMap::new();
2926

3027
for detector in detector_config {
3128
if detectors.contains(&detector.name) && detector.detector_params.is_some() {
32-
input_detectors.insert(detector.name, detector.detector_params.unwrap());
29+
let detector_params = detector.detector_params.unwrap();
30+
if detector.input {
31+
input_detectors.insert(detector.name.clone(), detector_params.clone());
32+
}
33+
if detector.output {
34+
output_detectors.insert(detector.name, detector_params);
35+
}
3336
}
3437
}
3538

36-
OrchestratorDetector {
39+
api::OrchestratorDetector {
3740
input: input_detectors,
41+
output: output_detectors,
3842
}
3943
}
4044

@@ -59,10 +63,11 @@ async fn main() {
5963
let gateway_config = gateway_config.clone();
6064
let detectors = route.detectors.clone();
6165
let path = format!("/{}/v1/chat/completions", route.name);
66+
let fallback_message = route.fallback_message.clone();
6267
app = app.route(
6368
&path,
6469
post(|Json(payload): Json<serde_json::Value>| {
65-
handle_generation(Json(payload), detectors, gateway_config)
70+
handle_generation(Json(payload), detectors, gateway_config, fallback_message)
6671
}),
6772
);
6873
tracing::info!("exposed endpoints: {}", path);
@@ -86,27 +91,66 @@ async fn main() {
8691
axum::serve(listener, app).await.unwrap();
8792
}
8893

94+
95+
async fn handle_orchestrator_payload_parsing(orchestrator_response: &mut OrchestratorResponse, route_fallback_message: Option<String>) {
96+
if route_fallback_message.is_some() && orchestrator_response.detections.is_some() {
97+
let fallback_generation = GenerationMessage {
98+
content: route_fallback_message.clone().unwrap(),
99+
refusal: None,
100+
role: String::from("assistant"),
101+
tool_calls: None,
102+
audio: None,
103+
};
104+
105+
orchestrator_response.choices = vec![GenerationChoice {
106+
message: fallback_generation,
107+
finish_reason: String::from("stop"),
108+
index: 0,
109+
logprobs: None
110+
}];
111+
}
112+
}
113+
89114
async fn handle_generation(
90115
Json(mut payload): Json<serde_json::Value>,
91116
detectors: Vec<String>,
92117
gateway_config: GatewayConfig,
118+
route_fallback_message: Option<String>,
93119
) -> Result<impl IntoResponse, (StatusCode, String)> {
94120
let orchestrator_detectors = get_orchestrator_detectors(detectors, gateway_config.detectors);
95121

96122
let mut payload = payload.as_object_mut();
97123

98-
let url = format!(
99-
"http://{}:{}/api/v2/chat/completions-detection",
100-
gateway_config.orchestrator.host, gateway_config.orchestrator.port
101-
);
124+
let url;
125+
if gateway_config.orchestrator.port.is_some() {
126+
url = format!(
127+
"http://{}:{}/api/v2/chat/completions-detection",
128+
gateway_config.orchestrator.host, gateway_config.orchestrator.port.unwrap()
129+
);
130+
} else {
131+
url = format!(
132+
"https://{}/api/v2/chat/completions-detection",
133+
gateway_config.orchestrator.host
134+
);
135+
}
136+
102137

103138
payload.as_mut().unwrap().insert(
104139
"detectors".to_string(),
105140
serde_json::to_value(&orchestrator_detectors).unwrap(),
106141
);
107142
let response_payload = orchestrator_post_request(payload, &url).await;
108-
109-
Ok(Json(json!(response_payload.unwrap())).into_response())
143+
if response_payload.is_ok() {
144+
let response_unwrapped = response_payload.unwrap();
145+
println!("{}", response_unwrapped);
146+
147+
let mut response : OrchestratorResponse = serde_json::from_value(response_unwrapped).unwrap();
148+
handle_orchestrator_payload_parsing(&mut response, route_fallback_message).await;
149+
Ok(Json(json!(response)).into_response())
150+
} else {
151+
//println!("{:#?}", response_payload.err().unwrap().to_string());
152+
Err((StatusCode::INTERNAL_SERVER_ERROR, response_payload.err().unwrap().to_string()))
153+
}
110154
}
111155

112156
async fn orchestrator_post_request(
@@ -117,6 +161,5 @@ async fn orchestrator_post_request(
117161
let response = client.post(url).json(&payload).send();
118162

119163
let json = response.await?.json().await?;
120-
println!("{:#?}", json);
121164
Ok(json)
122165
}

0 commit comments

Comments
 (0)