Skip to content

Commit 81c6b2d

Browse files
authored
{flow,agent_tool,processor/functioncall}: improve the notification wating time during multiple rounds of flow sessions to prevent message non persistence in multi-agent mode (#424)
1. Migration of completion waiting notification from processor to flow. 2. In multi-agent mode, it is necessary to ensure that the events of the previous agent have been appended to the session when executing the next agent
1 parent 77e2974 commit 81c6b2d

File tree

8 files changed

+116
-146
lines changed

8 files changed

+116
-146
lines changed

agent/invocation.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,11 @@ func EmitEvent(ctx context.Context, inv *Invocation, ch chan<- *event.Event,
270270
return event.EmitEvent(ctx, ch, e)
271271
}
272272

273+
// GetAppendEventNoticeKey get append event notice key.
274+
func GetAppendEventNoticeKey(eventID string) string {
275+
return AppendEventNoticeKeyPrefix + eventID
276+
}
277+
273278
// AddNoticeChannelAndWait add notice channel and wait it complete
274279
func (inv *Invocation) AddNoticeChannelAndWait(ctx context.Context, key string, timeout time.Duration) error {
275280
ch := inv.AddNoticeChannel(ctx, key)

internal/flow/llmflow/llmflow.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ package llmflow
1313
import (
1414
"context"
1515
"errors"
16+
"time"
1617

1718
oteltrace "go.opentelemetry.io/otel/trace"
1819

@@ -29,6 +30,9 @@ import (
2930

3031
const (
3132
defaultChannelBufferSize = 256
33+
34+
// Timeout for event completion signaling.
35+
eventCompletionTimeout = 5 * time.Second
3236
)
3337

3438
// Options contains configuration options for creating a Flow.
@@ -74,8 +78,8 @@ func (f *Flow) Run(ctx context.Context, invocation *agent.Invocation) (<-chan *e
7478
defer close(eventChan)
7579

7680
for {
77-
// Check if context is cancelled.
78-
if err := agent.CheckContextCancelled(ctx); err != nil {
81+
// emit start event and wait for completion notice.
82+
if err := f.emitStartEventAndWait(ctx, invocation, eventChan); err != nil {
7983
return
8084
}
8185

@@ -124,6 +128,27 @@ func (f *Flow) Run(ctx context.Context, invocation *agent.Invocation) (<-chan *e
124128
return eventChan, nil
125129
}
126130

131+
func (f *Flow) emitStartEventAndWait(ctx context.Context, invocation *agent.Invocation,
132+
eventChan chan<- *event.Event) error {
133+
invocationID, agentName := "", ""
134+
if invocation != nil {
135+
invocationID = invocation.InvocationID
136+
agentName = invocation.AgentName
137+
}
138+
startEvent := event.New(invocationID, agentName)
139+
startEvent.RequiresCompletion = true
140+
agent.EmitEvent(ctx, invocation, eventChan, startEvent)
141+
142+
// Wait for completion notice.
143+
// Ensure that the events of the previous agent or the previous step have been synchronized to the session.
144+
completionID := agent.GetAppendEventNoticeKey(startEvent.ID)
145+
err := invocation.AddNoticeChannelAndWait(ctx, completionID, eventCompletionTimeout)
146+
if errors.Is(err, context.Canceled) {
147+
return err
148+
}
149+
return nil
150+
}
151+
127152
// runOneStep executes one step of the flow (one LLM call cycle).
128153
// Returns the last event generated, or nil if no events.
129154
func (f *Flow) runOneStep(

internal/flow/llmflow/llmflow_endinvocation_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func TestPreprocess_StopsAfterEndInvocation(t *testing.T) {
4242
}
4343

4444
f := New(reqProcs, nil, Options{})
45-
inv := &agent.Invocation{InvocationID: "inv-pre", AgentName: "agent-pre"}
45+
inv := agent.NewInvocation()
4646

4747
// Act
4848
ch, err := f.Run(context.Background(), inv)
@@ -55,8 +55,8 @@ func TestPreprocess_StopsAfterEndInvocation(t *testing.T) {
5555

5656
// Assert
5757
require.True(t, called, "subsequent processors should run after EndInvocation")
58-
require.Len(t, events, 2)
59-
require.Equal(t, "preprocess.end", events[0].Object)
58+
require.Len(t, events, 3)
59+
require.Equal(t, "preprocess.end", events[1].Object)
6060
}
6161

6262
// twoChunkModel returns two streaming chunks to ensure we break after EndInvocation.

internal/flow/llmflow/llmflow_test.go

Lines changed: 80 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -178,31 +178,31 @@ func TestModelCallbacks_BeforeSkip(t *testing.T) {
178178
})
179179

180180
llmFlow := New(nil, nil, Options{ModelCallbacks: modelCallbacks})
181-
invocation := &agent.Invocation{
182-
InvocationID: "test-invocation",
183-
AgentName: "test-agent",
184-
Model: &mockModel{
181+
invocation := agent.NewInvocation(
182+
agent.WithInvocationModel(&mockModel{
185183
responses: []*model.Response{{ID: "should-not-be-called"}},
186-
},
187-
Session: &session.Session{
188-
ID: "test-session",
189-
},
190-
}
184+
}),
185+
agent.WithInvocationSession(&session.Session{ID: "test-session"}),
186+
)
191187
eventChan, err := llmFlow.Run(ctx, invocation)
192188
require.NoError(t, err)
193189
var events []*event.Event
194190
for evt := range eventChan {
191+
if evt.RequiresCompletion {
192+
key := agent.AppendEventNoticeKeyPrefix + evt.ID
193+
invocation.NotifyCompletion(ctx, key)
194+
}
195195
events = append(events, evt)
196-
// Receive the first event and cancel ctx to prevent deadlock.
197-
cancel()
198-
break
196+
if len(events) >= 2 {
197+
break
198+
}
199199
}
200-
require.Equal(t, 1, len(events))
201-
require.Equal(t, "skip-response", events[0].Response.ID)
200+
require.Equal(t, 2, len(events))
201+
require.Equal(t, "skip-response", events[1].Response.ID)
202202
}
203203

204204
func TestModelCBs_BeforeCustom(t *testing.T) {
205-
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
205+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
206206
defer cancel()
207207

208208
modelCallbacks := model.NewCallbacks()
@@ -211,27 +211,27 @@ func TestModelCBs_BeforeCustom(t *testing.T) {
211211
})
212212

213213
llmFlow := New(nil, nil, Options{ModelCallbacks: modelCallbacks})
214-
invocation := &agent.Invocation{
215-
InvocationID: "test-invocation",
216-
AgentName: "test-agent",
217-
Model: &mockModel{
214+
invocation := agent.NewInvocation(
215+
agent.WithInvocationModel(&mockModel{
218216
responses: []*model.Response{{ID: "should-not-be-called"}},
219-
},
220-
Session: &session.Session{
221-
ID: "test-session",
222-
},
223-
}
217+
}),
218+
agent.WithInvocationSession(&session.Session{ID: "test-session"}),
219+
)
224220
eventChan, err := llmFlow.Run(ctx, invocation)
225221
require.NoError(t, err)
226222
var events []*event.Event
227223
for evt := range eventChan {
224+
if evt.RequiresCompletion {
225+
key := agent.AppendEventNoticeKeyPrefix + evt.ID
226+
invocation.NotifyCompletion(ctx, key)
227+
}
228228
events = append(events, evt)
229-
// Receive the first event and cancel ctx to prevent deadlock.
230-
cancel()
231-
break
229+
if len(events) >= 2 {
230+
break
231+
}
232232
}
233-
require.Equal(t, 1, len(events))
234-
require.Equal(t, "custom-before", events[0].Response.ID)
233+
require.Equal(t, 2, len(events))
234+
require.Equal(t, "custom-before", events[1].Response.ID)
235235
}
236236

237237
func TestModelCallbacks_BeforeError(t *testing.T) {
@@ -244,26 +244,32 @@ func TestModelCallbacks_BeforeError(t *testing.T) {
244244
})
245245

246246
llmFlow := New(nil, nil, Options{ModelCallbacks: modelCallbacks})
247-
invocation := &agent.Invocation{
248-
InvocationID: "test-invocation",
249-
AgentName: "test-agent",
250-
Model: &mockModel{
247+
invocation := agent.NewInvocation(
248+
agent.WithInvocationModel(&mockModel{
251249
responses: []*model.Response{{ID: "should-not-be-called"}},
252-
},
253-
}
250+
}),
251+
agent.WithInvocationSession(&session.Session{ID: "test-session"}),
252+
)
254253
eventChan, err := llmFlow.Run(ctx, invocation)
255254
require.NoError(t, err)
256255
var events []*event.Event
257256
for evt := range eventChan {
257+
if evt.RequiresCompletion {
258+
key := agent.AppendEventNoticeKeyPrefix + evt.ID
259+
invocation.NotifyCompletion(ctx, key)
260+
}
258261
events = append(events, evt)
262+
if len(events) >= 2 {
263+
break
264+
}
259265
// Receive the first error event and cancel ctx to prevent deadlock.
260266
if evt.Error != nil && evt.Error.Message == "before error" {
261267
cancel()
262268
break
263269
}
264270
}
265-
require.Equal(t, 1, len(events))
266-
require.Equal(t, "before error", events[0].Error.Message)
271+
require.Equal(t, 2, len(events))
272+
require.Equal(t, "before error", events[1].Error.Message)
267273
}
268274

269275
func TestModelCBs_AfterOverride(t *testing.T) {
@@ -278,28 +284,28 @@ func TestModelCBs_AfterOverride(t *testing.T) {
278284
)
279285

280286
llmFlow := New(nil, nil, Options{ModelCallbacks: modelCallbacks})
281-
invocation := &agent.Invocation{
282-
InvocationID: "test-invocation",
283-
AgentName: "test-agent",
284-
Model: &mockModel{
287+
invocation := agent.NewInvocation(
288+
agent.WithInvocationModel(&mockModel{
285289
responses: []*model.Response{{ID: "original"}},
286-
},
287-
Session: &session.Session{
288-
ID: "test-session",
289-
},
290-
}
290+
}),
291+
agent.WithInvocationSession(&session.Session{ID: "test-session"}),
292+
)
291293
eventChan, err := llmFlow.Run(ctx, invocation)
292294
require.NoError(t, err)
293295
var events []*event.Event
294296
for evt := range eventChan {
297+
if evt.RequiresCompletion {
298+
key := agent.AppendEventNoticeKeyPrefix + evt.ID
299+
invocation.NotifyCompletion(ctx, key)
300+
}
295301
events = append(events, evt)
296-
// Receive the first event and cancel ctx to prevent deadlock.
297-
cancel()
298-
break
302+
if len(events) >= 2 {
303+
break
304+
}
299305
}
300-
require.Equal(t, 1, len(events))
306+
require.Equal(t, 2, len(events))
301307
t.Log(events[0])
302-
require.Equal(t, "after-override", events[0].Response.Object)
308+
require.Equal(t, "after-override", events[1].Response.Object)
303309
}
304310

305311
func TestModelCallbacks_AfterError(t *testing.T) {
@@ -314,29 +320,32 @@ func TestModelCallbacks_AfterError(t *testing.T) {
314320
)
315321

316322
llmFlow := New(nil, nil, Options{ModelCallbacks: modelCallbacks})
317-
invocation := &agent.Invocation{
318-
InvocationID: "test-invocation",
319-
AgentName: "test-agent",
320-
Model: &mockModel{
323+
invocation := agent.NewInvocation(
324+
agent.WithInvocationModel(&mockModel{
321325
responses: []*model.Response{{ID: "original"}},
322-
},
323-
Session: &session.Session{
324-
ID: "test-session",
325-
},
326-
}
326+
}),
327+
agent.WithInvocationSession(&session.Session{ID: "test-session"}),
328+
)
327329
eventChan, err := llmFlow.Run(ctx, invocation)
328330
require.NoError(t, err)
329331
var events []*event.Event
330332
for evt := range eventChan {
333+
if evt.RequiresCompletion {
334+
key := agent.AppendEventNoticeKeyPrefix + evt.ID
335+
invocation.NotifyCompletion(ctx, key)
336+
}
331337
events = append(events, evt)
338+
if len(events) >= 2 {
339+
break
340+
}
332341
// Receive the first error event and cancel ctx to prevent deadlock.
333342
if evt.Error != nil && evt.Error.Message == "after error" {
334343
cancel()
335344
break
336345
}
337346
}
338-
require.Equal(t, 1, len(events))
339-
require.Equal(t, "after error", events[0].Error.Message)
347+
require.Equal(t, 2, len(events))
348+
require.Equal(t, "after error", events[1].Error.Message)
340349
}
341350

342351
// noResponseModel returns a closed channel without emitting any responses.
@@ -356,15 +365,21 @@ func TestRun_NoPanicWhenModelReturnsNoResponses(t *testing.T) {
356365
defer cancel()
357366

358367
f := New(nil, nil, Options{})
359-
inv := &agent.Invocation{InvocationID: "inv-nil", AgentName: "agent-nil", Model: &noResponseModel{}}
368+
inv := agent.NewInvocation(
369+
agent.WithInvocationModel(&noResponseModel{}),
370+
)
360371

361372
ch, err := f.Run(ctx, inv)
362373
require.NoError(t, err)
363374

364375
// Collect all events until channel closes. Expect none and, importantly, no panic.
365376
var count int
366-
for range ch {
377+
for evt := range ch {
378+
if evt.RequiresCompletion {
379+
key := agent.AppendEventNoticeKeyPrefix + evt.ID
380+
inv.NotifyCompletion(ctx, key)
381+
}
367382
count++
368383
}
369-
require.Equal(t, 0, count)
384+
require.Equal(t, 1, count)
370385
}

internal/flow/processor/functioncall.go

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ package processor
1212
import (
1313
"context"
1414
"encoding/json"
15-
"errors"
1615
"fmt"
1716
"io"
1817
"sync"
@@ -39,9 +38,6 @@ const (
3938
ErrorStreamableToolExecution = "Error: streamable tool execution failed"
4039
// ErrorMarshalResult is the error message for failed to marshal result.
4140
ErrorMarshalResult = "Error: failed to marshal result"
42-
43-
// Timeout for event completion signaling.
44-
eventCompletionTimeout = 5 * time.Second
4541
)
4642

4743
// summarizationSkipper is implemented by tools that can indicate whether
@@ -114,17 +110,6 @@ func (p *FunctionCallResponseProcessor) ProcessResponse(
114110
invocation.EndInvocation = true
115111
return
116112
}
117-
118-
// Wait for completion if required.
119-
if err := p.waitForCompletion(ctx, invocation, functioncallResponseEvent); err != nil {
120-
agent.EmitEvent(ctx, invocation, ch, event.NewErrorEvent(
121-
invocation.InvocationID,
122-
invocation.AgentName,
123-
model.ErrorTypeFlowError,
124-
err.Error(),
125-
))
126-
return
127-
}
128113
}
129114

130115
func (p *FunctionCallResponseProcessor) handleFunctionCallsAndSendEvent(
@@ -406,7 +391,6 @@ func (p *FunctionCallResponseProcessor) buildMergedParallelEvent(
406391
} else {
407392
mergedEvent = mergeParallelToolCallResponseEvents(toolCallEvents)
408393
}
409-
mergedEvent.RequiresCompletion = true
410394
if len(toolCallEvents) > 1 {
411395
_, span := trace.Tracer.Start(
412396
ctx, fmt.Sprintf("%s (merged)", itelemetry.SpanNamePrefixExecuteTool),
@@ -494,20 +478,6 @@ func (p *FunctionCallResponseProcessor) executeToolCall(
494478
}, modifiedArgs, nil
495479
}
496480

497-
// waitForCompletion waits for event completion if required.
498-
func (p *FunctionCallResponseProcessor) waitForCompletion(ctx context.Context, invocation *agent.Invocation, lastEvent *event.Event) error {
499-
if !lastEvent.RequiresCompletion {
500-
return nil
501-
}
502-
503-
completionID := agent.AppendEventNoticeKeyPrefix + lastEvent.ID
504-
err := invocation.AddNoticeChannelAndWait(ctx, completionID, eventCompletionTimeout)
505-
if errors.Is(err, context.Canceled) {
506-
return err
507-
}
508-
return nil
509-
}
510-
511481
// createErrorChoice creates an error choice for tool execution failures.
512482
func (p *FunctionCallResponseProcessor) createErrorChoice(index int, toolID string,
513483
errorMsg string) *model.Choice {

0 commit comments

Comments
 (0)