Skip to content

Commit 3ac14a8

Browse files
Merge pull request #3 from resoluteCoder/refactor-fallback-message
refactor fallback message pr
2 parents 13024d5 + bf33f17 commit 3ac14a8

File tree

5 files changed

+94
-83
lines changed

5 files changed

+94
-83
lines changed

config/config.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
orchestrator:
2-
host: guardrails-orchestrator-http-model-namespace.apps.rosa.trustyai-rob2.n1ai.p3.openshiftapps.com
2+
host: localhost
3+
port: 8085
34
detectors:
4-
- name: regex_language
5+
- name: regex-language
6+
input: false
7+
output: true
58
detector_params:
69
regex:
710
- email
811
- ssn
9-
- name: other_detector
1012
routes:
1113
- name: pii
1214
detectors:
13-
- regex_language
14-
fallback_message: "I'm sorry Dave, I'm afraid I can't do that."
15+
- regex-language
16+
fallback_message: "I'm sorry, I'm afraid I can't do that."
1517
- name: passthrough
1618
detectors:

curl.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
#!/bin/bash
22

33
# curl "localhost:8032/v1/chat/completions" \
4-
curl "localhost:8090/passthrough" \
4+
curl "localhost:8090/pii/v1/chat/completions" \
55
-H "Content-Type: application/json" \
66
-d '{
7-
"max_completion_tokens": 1,
87
"model": "Qwen/Qwen2.5-1.5B-Instruct",
98
"messages": [
109
{

src/api.rs

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,40 @@
1-
use std::collections::HashMap;
21
use serde::{Deserialize, Serialize};
2+
use std::collections::HashMap;
33

44
#[derive(Debug, Serialize)]
5-
pub(crate) struct OrchestratorDetector {
6-
pub(crate) input: HashMap<String, serde_json::Value>,
7-
pub(crate) output: HashMap<String, serde_json::Value>,
8-
// implement when output is completed, also need to see about splitting detectors in config to input/output
9-
// output: HashMap<String, serde_json::Value>,
5+
pub struct OrchestratorDetector {
6+
pub input: HashMap<String, serde_json::Value>,
7+
pub output: HashMap<String, serde_json::Value>,
108
}
119

1210
#[derive(Serialize, Deserialize, Debug)]
13-
pub(crate) struct GenerationMessage {
14-
pub(crate) content: String,
15-
pub(crate) refusal: Option<String>,
16-
pub(crate) role: String,
17-
pub(crate) tool_calls: Option<serde_json::Value>,
18-
pub(crate) audio: Option<serde_json::Value>,
11+
pub struct GenerationMessage {
12+
pub content: String,
13+
pub refusal: Option<String>,
14+
pub role: String,
15+
pub tool_calls: Option<serde_json::Value>,
16+
pub audio: Option<serde_json::Value>,
1917
}
2018

21-
#[derive(Serialize, Deserialize, Debug)]
22-
pub(crate) struct GenerationChoice {
23-
pub(crate) finish_reason: String,
24-
pub(crate) index: u32,
25-
pub(crate) message: GenerationMessage,
26-
pub(crate) logprobs: Option<serde_json::Value>,
19+
impl GenerationMessage {
20+
pub fn new(message: String) -> Self {
21+
GenerationMessage {
22+
content: message,
23+
refusal: None,
24+
role: String::from("assistant"),
25+
tool_calls: None,
26+
audio: None,
27+
}
28+
}
2729
}
2830

31+
#[derive(Serialize, Deserialize, Debug)]
32+
pub struct GenerationChoice {
33+
pub finish_reason: String,
34+
pub index: u32,
35+
pub message: GenerationMessage,
36+
pub logprobs: Option<serde_json::Value>,
37+
}
2938

3039
#[derive(Serialize, Deserialize, Debug)]
3140
struct DetectionResult {
@@ -41,32 +50,31 @@ struct DetectionResult {
4150
#[derive(Serialize, Deserialize, Debug)]
4251
struct InputDetection {
4352
message_index: u16,
44-
results: Option<Vec<DetectionResult>>
53+
results: Option<Vec<DetectionResult>>,
4554
}
4655

4756
#[derive(Serialize, Deserialize, Debug)]
4857
struct OutputDetection {
4958
choice_index: u32,
50-
results: Option<Vec<DetectionResult>>
59+
results: Option<Vec<DetectionResult>>,
5160
}
5261

5362
#[derive(Serialize, Deserialize, Debug)]
54-
pub(crate) struct Detections {
63+
pub struct Detections {
5564
input: Option<Vec<InputDetection>>,
5665
output: Option<Vec<OutputDetection>>,
5766
}
5867

59-
6068
#[derive(Serialize, Deserialize, Debug)]
61-
pub(crate) struct OrchestratorResponse {
69+
pub struct OrchestratorResponse {
6270
id: String,
63-
pub(crate) choices: Vec<GenerationChoice>,
71+
pub choices: Vec<GenerationChoice>,
6472
created: u64,
6573
model: String,
6674
service_tier: Option<String>,
6775
system_fingerprint: Option<String>,
6876
object: Option<String>,
6977
usage: serde_json::Value,
70-
pub(crate) detections: Option<Detections>,
71-
pub(crate) warnings: Option<Vec<HashMap<String, String>>>
72-
}
78+
pub detections: Option<Detections>,
79+
pub warnings: Option<Vec<HashMap<String, String>>>,
80+
}

src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@ mod tests {
6969
},
7070
detectors: vec![DetectorConfig {
7171
name: "regex".to_string(),
72+
input: false,
73+
output: false,
7274
detector_params: None,
7375
}],
7476
routes: vec![RouteConfig {
7577
name: "route1".to_string(),
7678
detectors: vec!["regex".to_string(), "not_existent_detector".to_string()],
79+
fallback_message: None,
7780
}],
7881
};
7982

src/main.rs

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
1+
use axum::{http::StatusCode, response::IntoResponse, routing::post, Json, Router};
2+
use config::{validate_registered_detectors, DetectorConfig, GatewayConfig};
3+
use serde_json::json;
4+
use serde_json::{Map, Value};
15
use std::{
26
collections::HashMap,
37
env,
48
net::{IpAddr, SocketAddr},
59
};
6-
use std::ptr::null;
7-
use axum::{http::StatusCode, response::IntoResponse, routing::post, Json, Router};
8-
use config::{validate_registered_detectors, DetectorConfig, GatewayConfig};
9-
use serde_json::json;
10-
use serde_json::{Map, Value};
1110
use tower_http::trace::{self, TraceLayer};
1211
use tracing::Level;
13-
use crate::api::{GenerationChoice, GenerationMessage, OrchestratorResponse};
14-
use crate::config::RouteConfig;
1512

16-
mod config;
1713
mod api;
14+
mod config;
1815

16+
use api::{
17+
Detections, GenerationChoice, GenerationMessage, OrchestratorDetector, OrchestratorResponse,
18+
};
1919

2020
fn get_orchestrator_detectors(
2121
detectors: Vec<String>,
2222
detector_config: Vec<DetectorConfig>,
23-
) -> api::OrchestratorDetector {
23+
) -> OrchestratorDetector {
2424
let mut input_detectors = HashMap::new();
2525
let mut output_detectors = HashMap::new();
2626

@@ -36,7 +36,7 @@ fn get_orchestrator_detectors(
3636
}
3737
}
3838

39-
api::OrchestratorDetector {
39+
OrchestratorDetector {
4040
input: input_detectors,
4141
output: output_detectors,
4242
}
@@ -85,30 +85,27 @@ async fn main() {
8585

8686
let ip: IpAddr = host.parse().expect("Failed to parse host IP address");
8787
let addr = SocketAddr::from((ip, http_port));
88-
tracing::info!("listening on {}", addr);
8988

9089
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
90+
tracing::info!("listening on {}", addr);
91+
9192
axum::serve(listener, app).await.unwrap();
9293
}
9394

94-
95-
async fn handle_orchestrator_payload_parsing(orchestrator_response: &mut OrchestratorResponse, route_fallback_message: Option<String>) {
96-
if route_fallback_message.is_some() && orchestrator_response.detections.is_some() {
97-
let fallback_generation = GenerationMessage {
98-
content: route_fallback_message.clone().unwrap(),
99-
refusal: None,
100-
role: String::from("assistant"),
101-
tool_calls: None,
102-
audio: None,
103-
};
104-
105-
orchestrator_response.choices = vec![GenerationChoice {
106-
message: fallback_generation,
95+
fn check_payload_detections(
96+
detections: &Option<Detections>,
97+
route_fallback_message: Option<String>,
98+
) -> Option<GenerationChoice> {
99+
if let (Some(fallback_message), Some(_)) = (route_fallback_message, detections) {
100+
return Some(GenerationChoice {
101+
message: GenerationMessage::new(fallback_message),
107102
finish_reason: String::from("stop"),
108103
index: 0,
109-
logprobs: None
110-
}];
104+
logprobs: None,
105+
});
111106
}
107+
108+
None
112109
}
113110

114111
async fn handle_generation(
@@ -121,45 +118,47 @@ async fn handle_generation(
121118

122119
let mut payload = payload.as_object_mut();
123120

124-
let url;
125-
if gateway_config.orchestrator.port.is_some() {
126-
url = format!(
121+
let url: String = match gateway_config.orchestrator.port {
122+
Some(port) => format!(
127123
"http://{}:{}/api/v2/chat/completions-detection",
128-
gateway_config.orchestrator.host, gateway_config.orchestrator.port.unwrap()
129-
);
130-
} else {
131-
url = format!(
124+
gateway_config.orchestrator.host, port
125+
),
126+
None => format!(
132127
"https://{}/api/v2/chat/completions-detection",
133128
gateway_config.orchestrator.host
134-
);
135-
}
136-
129+
),
130+
};
137131

138132
payload.as_mut().unwrap().insert(
139133
"detectors".to_string(),
140134
serde_json::to_value(&orchestrator_detectors).unwrap(),
141135
);
142-
let response_payload = orchestrator_post_request(payload, &url).await;
143-
if response_payload.is_ok() {
144-
let response_unwrapped = response_payload.unwrap();
145-
println!("{}", response_unwrapped);
146-
147-
let mut response : OrchestratorResponse = serde_json::from_value(response_unwrapped).unwrap();
148-
handle_orchestrator_payload_parsing(&mut response, route_fallback_message).await;
149-
Ok(Json(json!(response)).into_response())
150-
} else {
151-
//println!("{:#?}", response_payload.err().unwrap().to_string());
152-
Err((StatusCode::INTERNAL_SERVER_ERROR, response_payload.err().unwrap().to_string()))
136+
137+
let response_result = orchestrator_post_request(payload, &url).await;
138+
139+
match response_result {
140+
Ok(mut orchestrator_response) => {
141+
let detection =
142+
check_payload_detections(&orchestrator_response.detections, route_fallback_message);
143+
if let Some(message) = detection {
144+
orchestrator_response.choices = vec![message];
145+
}
146+
Ok(Json(json!(orchestrator_response)).into_response())
147+
}
148+
Err(_) => Err((
149+
StatusCode::INTERNAL_SERVER_ERROR,
150+
response_result.err().unwrap().to_string(),
151+
)),
153152
}
154153
}
155154

156155
async fn orchestrator_post_request(
157156
payload: Option<&mut Map<String, Value>>,
158157
url: &str,
159-
) -> Result<serde_json::Value, anyhow::Error> {
158+
) -> Result<OrchestratorResponse, anyhow::Error> {
160159
let client = reqwest::Client::new();
161160
let response = client.post(url).json(&payload).send();
162161

163162
let json = response.await?.json().await?;
164-
Ok(json)
163+
Ok(serde_json::from_value(json).expect("unexpected json response from request"))
165164
}

0 commit comments

Comments
 (0)