Skip to content

Commit a6af5e2

Browse files
committed
feat: allow set tools in extra
1 parent ad0f876 commit a6af5e2

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

src/ai/mod.rs

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,13 @@ async fn test_groq_asr() {
128128
#[derive(Debug, Clone, serde::Serialize)]
129129
pub struct StableLlmRequest {
130130
stream: bool,
131-
#[serde(flatten, skip_serializing_if = "Option::is_none")]
132-
extra: Option<serde_json::Value>,
131+
#[serde(flatten)]
132+
extra: serde_json::Value,
133133
messages: Vec<llm::Content>,
134134
#[serde(skip_serializing_if = "String::is_empty")]
135135
model: String,
136-
#[serde(default, skip_serializing_if = "Vec::is_empty")]
137-
tools: Vec<llm::Tool>,
136+
// #[serde(default, skip_serializing_if = "Vec::is_empty")]
137+
// tools: Vec<llm::Tool>,
138138
#[serde(skip_serializing_if = "str::is_empty")]
139139
tool_choice: &'static str,
140140
}
@@ -143,12 +143,11 @@ pub struct StableLlmRequest {
143143
fn test_stable_llm_request_json() {
144144
let request = StableLlmRequest {
145145
stream: true,
146-
extra: Some(serde_json::json!({
146+
extra: serde_json::json!({
147147
"chat_id": "test-chat-id",
148-
})),
148+
}),
149149
messages: vec![],
150150
model: "test-model".to_string(),
151-
tools: vec![],
152151
tool_choice: "",
153152
};
154153

@@ -447,8 +446,6 @@ pub async fn llm_stable<'p, I: IntoIterator<Item = C>, C: AsRef<llm::Content>>(
447446
response_builder = response_builder.bearer_auth(token);
448447
};
449448

450-
let tool_choice = if tools.is_empty() { "" } else { "auto" };
451-
452449
let tool_name = tools
453450
.iter()
454451
.map(|t| t.function.name.as_str())
@@ -463,16 +460,42 @@ pub async fn llm_stable<'p, I: IntoIterator<Item = C>, C: AsRef<llm::Content>>(
463460
"model": model.to_string(),
464461
"tools": tool_name,
465462
"extra": extra,
466-
"tool_choice": tool_choice,
467463
}
468464
))?
469465
);
470466

467+
let mut tool_choice = "";
468+
469+
let tools = tools
470+
.iter()
471+
.map(|t| serde_json::to_value(&t).unwrap())
472+
.collect::<Vec<_>>();
473+
474+
let mut extra = extra.unwrap_or(serde_json::json!({}));
475+
if let Some(extra) = extra.as_object_mut() {
476+
match extra.entry("tools") {
477+
serde_json::map::Entry::Vacant(e) => {
478+
if !tools.is_empty() {
479+
e.insert(serde_json::Value::Array(tools));
480+
tool_choice = "auto";
481+
}
482+
}
483+
serde_json::map::Entry::Occupied(mut e) => {
484+
if let serde_json::Value::Array(arr) = e.get_mut() {
485+
tool_choice = "auto";
486+
487+
if !tools.is_empty() {
488+
arr.extend(tools);
489+
}
490+
}
491+
}
492+
}
493+
}
494+
471495
let request = StableLlmRequest {
472496
stream: true,
473497
messages,
474498
model: model.to_string(),
475-
tools,
476499
tool_choice,
477500
extra,
478501
};

0 commit comments

Comments
 (0)