Skip to content

Commit d7245cb

Browse files
authored
fix: chat completions API now also passes tools along (openai#1167)
Prior to this PR, there were two big misses in `chat_completions.rs`: 1. The loop in `stream_chat_completions()` was only including items of type `ResponseItem::Message` when building up the `"messages"` JSON for the `POST` request to the `chat/completions` endpoint. This fixes things by ensuring other variants (`FunctionCall`, `LocalShellCall`, and `FunctionCallOutput`) are included, as well. 2. In `process_chat_sse()`, we were not recording tool calls and were only emitting items of type `ResponseEvent::OutputItemDone(ResponseItem::Message)` to the stream. Now we introduce `FunctionCallState`, which is used to accumulate the `delta`s of type `tool_calls`, so we can ultimately emit a `ResponseItem::FunctionCall` when appropriate. While function calling now appears to work for chat completions in my local testing, I believe that there are still edge cases that are not covered and that this codepath would benefit from a battery of integration tests. (As part of that further cleanup, we should also work to support streaming responses in the UI.) The other important part of this PR is some cleanup in `core/src/codex.rs`. In particular, it was hard to reason about how `run_task()` was building up the list of messages to include in a request across the various cases: - Responses API - Chat Completions API - Responses API used in concert with ZDR I like to think things are a bit cleaner now where: - `zdr_transcript` (if present) contains all messages in the history of the conversation, which includes function call outputs that have not been sent back to the model yet - `pending_input` includes any messages the user has submitted while the turn is in flight that need to be injected as part of the next `POST` to the model - `input_for_next_turn` includes the tool call outputs that have not been sent back to the model yet
1 parent e40f86b commit d7245cb

File tree

2 files changed

+301
-73
lines changed

2 files changed

+301
-73
lines changed

codex-rs/core/src/chat_completions.rs

Lines changed: 171 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ use crate::models::ResponseItem;
2828
use crate::openai_tools::create_tools_json_for_chat_completions_api;
2929
use crate::util::backoff;
3030

31-
/// Implementation for the classic Chat Completions API. This is intentionally
32-
/// minimal: we only stream back plain assistant text.
31+
/// Implementation for the classic Chat Completions API.
3332
pub(crate) async fn stream_chat_completions(
3433
prompt: &Prompt,
3534
model: &str,
@@ -43,17 +42,67 @@ pub(crate) async fn stream_chat_completions(
4342
messages.push(json!({"role": "system", "content": full_instructions}));
4443

4544
for item in &prompt.input {
46-
if let ResponseItem::Message { role, content } = item {
47-
let mut text = String::new();
48-
for c in content {
49-
match c {
50-
ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
51-
text.push_str(t);
45+
match item {
46+
ResponseItem::Message { role, content } => {
47+
let mut text = String::new();
48+
for c in content {
49+
match c {
50+
ContentItem::InputText { text: t }
51+
| ContentItem::OutputText { text: t } => {
52+
text.push_str(t);
53+
}
54+
_ => {}
5255
}
53-
_ => {}
5456
}
57+
messages.push(json!({"role": role, "content": text}));
58+
}
59+
ResponseItem::FunctionCall {
60+
name,
61+
arguments,
62+
call_id,
63+
} => {
64+
messages.push(json!({
65+
"role": "assistant",
66+
"content": null,
67+
"tool_calls": [{
68+
"id": call_id,
69+
"type": "function",
70+
"function": {
71+
"name": name,
72+
"arguments": arguments,
73+
}
74+
}]
75+
}));
76+
}
77+
ResponseItem::LocalShellCall {
78+
id,
79+
call_id: _,
80+
status,
81+
action,
82+
} => {
83+
// Confirm with API team.
84+
messages.push(json!({
85+
"role": "assistant",
86+
"content": null,
87+
"tool_calls": [{
88+
"id": id.clone().unwrap_or_else(|| "".to_string()),
89+
"type": "local_shell_call",
90+
"status": status,
91+
"action": action,
92+
}]
93+
}));
94+
}
95+
ResponseItem::FunctionCallOutput { call_id, output } => {
96+
messages.push(json!({
97+
"role": "tool",
98+
"tool_call_id": call_id,
99+
"content": output.content,
100+
}));
101+
}
102+
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
103+
// Omit these items from the conversation history.
104+
continue;
55105
}
56-
messages.push(json!({"role": role, "content": text}));
57106
}
58107
}
59108

@@ -68,9 +117,8 @@ pub(crate) async fn stream_chat_completions(
68117
let base_url = provider.base_url.trim_end_matches('/');
69118
let url = format!("{}/chat/completions", base_url);
70119

71-
debug!(url, "POST (chat)");
72-
trace!(
73-
"request payload: {}",
120+
debug!(
121+
"POST to {url}: {}",
74122
serde_json::to_string_pretty(&payload).unwrap_or_default()
75123
);
76124

@@ -140,6 +188,21 @@ where
140188

141189
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
142190

191+
// State to accumulate a function call across streaming chunks.
192+
// OpenAI may split the `arguments` string over multiple `delta` events
193+
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
194+
// keep collecting the pieces here and forward a single
195+
// `ResponseItem::FunctionCall` once the call is complete.
196+
#[derive(Default)]
197+
struct FunctionCallState {
198+
name: Option<String>,
199+
arguments: String,
200+
call_id: Option<String>,
201+
active: bool,
202+
}
203+
204+
let mut fn_call_state = FunctionCallState::default();
205+
143206
loop {
144207
let sse = match timeout(idle_timeout, stream.next()).await {
145208
Ok(Some(Ok(ev))) => ev,
@@ -179,23 +242,89 @@ where
179242
Ok(v) => v,
180243
Err(_) => continue,
181244
};
245+
trace!("chat_completions received SSE chunk: {chunk:?}");
246+
247+
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
248+
249+
if let Some(choice) = choice_opt {
250+
// Handle assistant content tokens.
251+
if let Some(content) = choice
252+
.get("delta")
253+
.and_then(|d| d.get("content"))
254+
.and_then(|c| c.as_str())
255+
{
256+
let item = ResponseItem::Message {
257+
role: "assistant".to_string(),
258+
content: vec![ContentItem::OutputText {
259+
text: content.to_string(),
260+
}],
261+
};
262+
263+
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
264+
}
265+
266+
// Handle streaming function / tool calls.
267+
if let Some(tool_calls) = choice
268+
.get("delta")
269+
.and_then(|d| d.get("tool_calls"))
270+
.and_then(|tc| tc.as_array())
271+
{
272+
if let Some(tool_call) = tool_calls.first() {
273+
// Mark that we have an active function call in progress.
274+
fn_call_state.active = true;
275+
276+
// Extract call_id if present.
277+
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
278+
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
279+
}
280+
281+
// Extract function details if present.
282+
if let Some(function) = tool_call.get("function") {
283+
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
284+
fn_call_state.name.get_or_insert_with(|| name.to_string());
285+
}
286+
287+
if let Some(args_fragment) =
288+
function.get("arguments").and_then(|a| a.as_str())
289+
{
290+
fn_call_state.arguments.push_str(args_fragment);
291+
}
292+
}
293+
}
294+
}
295+
296+
// Emit end-of-turn when finish_reason signals completion.
297+
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
298+
match finish_reason {
299+
"tool_calls" if fn_call_state.active => {
300+
// Build the FunctionCall response item.
301+
let item = ResponseItem::FunctionCall {
302+
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
303+
arguments: fn_call_state.arguments.clone(),
304+
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
305+
};
306+
307+
// Emit it downstream.
308+
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
309+
}
310+
"stop" => {
311+
// Regular turn without tool-call.
312+
}
313+
_ => {}
314+
}
182315

183-
let content_opt = chunk
184-
.get("choices")
185-
.and_then(|c| c.get(0))
186-
.and_then(|c| c.get("delta"))
187-
.and_then(|d| d.get("content"))
188-
.and_then(|c| c.as_str());
189-
190-
if let Some(content) = content_opt {
191-
let item = ResponseItem::Message {
192-
role: "assistant".to_string(),
193-
content: vec![ContentItem::OutputText {
194-
text: content.to_string(),
195-
}],
196-
};
197-
198-
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
316+
// Emit Completed regardless of reason so the agent can advance.
317+
let _ = tx_event
318+
.send(Ok(ResponseEvent::Completed {
319+
response_id: String::new(),
320+
}))
321+
.await;
322+
323+
// Prepare for potential next turn (should not happen in same stream).
324+
// fn_call_state = FunctionCallState::default();
325+
326+
return; // End processing for this SSE stream.
327+
}
199328
}
200329
}
201330
}
@@ -242,20 +371,28 @@ where
242371
Poll::Ready(None) => return Poll::Ready(None),
243372
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
244373
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
245-
// Accumulate *assistant* text but do not emit yet.
246-
if let crate::models::ResponseItem::Message { role, content } = &item {
247-
if role == "assistant" {
374+
// If this is an incremental assistant message chunk, accumulate but
375+
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
376+
// away so downstream consumers see it.
377+
378+
let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");
379+
380+
if is_assistant_delta {
381+
if let crate::models::ResponseItem::Message { content, .. } = &item {
248382
if let Some(text) = content.iter().find_map(|c| match c {
249383
crate::models::ContentItem::OutputText { text } => Some(text),
250384
_ => None,
251385
}) {
252386
this.cumulative.push_str(text);
253387
}
254388
}
389+
390+
// Swallow partial assistant chunk; keep polling.
391+
continue;
255392
}
256393

257-
// Swallow partial event; keep polling.
258-
continue;
394+
// Not an assistant message – forward immediately.
395+
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
259396
}
260397
Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
261398
if !this.cumulative.is_empty() {

0 commit comments

Comments
 (0)