Skip to content

Commit b910954

Browse files
authored
server/agui: still return existing messages even on snapshot error (#763)
1 parent d5a4524 commit b910954

File tree

4 files changed

+273
-12
lines changed

4 files changed

+273
-12
lines changed

server/agui/internal/reduce/reduce.go

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,27 @@ type toolCallState struct {
6464
}
6565

6666
// Reduce reduces the AG-UI track events into message snapshots.
67+
// In order to fetch the history messages as much as possible, still return the messages even if there is an error.
6768
func Reduce(appName, userID string, events []session.TrackEvent) ([]aguievents.Message, error) {
6869
r := new(appName, userID)
70+
var err error
6971
for _, trackEvent := range events {
70-
if err := r.reduce(trackEvent); err != nil {
71-
return nil, fmt.Errorf("reduce: %w", err)
72+
if err = r.reduce(trackEvent); err != nil {
73+
err = fmt.Errorf("reduce: %w", err)
74+
break
7275
}
7376
}
74-
if err := r.finalize(); err != nil {
75-
return nil, fmt.Errorf("finalize: %w", err)
77+
if err == nil {
78+
if finalizeErr := r.finalize(); finalizeErr != nil {
79+
err = fmt.Errorf("finalize: %w", finalizeErr)
80+
}
7681
}
7782
messages := make([]aguievents.Message, 0, len(r.messages))
7883
for _, message := range r.messages {
7984
messages = append(messages, *message)
8085
}
81-
return messages, nil
86+
// In order to fetch the history messages as much as possible, still return the messages even if there is an error.
87+
return messages, err
8288
}
8389

8490
// new creates a new reducer.
@@ -101,13 +107,19 @@ func (r *reducer) reduce(trackEvent session.TrackEvent) error {
101107
if err != nil {
102108
return fmt.Errorf("unmarshal track event payload: %w", err)
103109
}
110+
return r.reduceEvent(evt)
111+
}
112+
113+
func (r *reducer) reduceEvent(evt aguievents.Event) error {
104114
switch e := evt.(type) {
105115
case *aguievents.TextMessageStartEvent:
106116
return r.handleTextStart(e)
107117
case *aguievents.TextMessageContentEvent:
108118
return r.handleTextContent(e)
109119
case *aguievents.TextMessageEndEvent:
110120
return r.handleTextEnd(e)
121+
case *aguievents.TextMessageChunkEvent:
122+
return r.handleTextChunk(e)
111123
case *aguievents.ToolCallStartEvent:
112124
return r.handleToolStart(e)
113125
case *aguievents.ToolCallArgsEvent:
@@ -184,6 +196,49 @@ func (r *reducer) handleTextEnd(e *aguievents.TextMessageEndEvent) error {
184196
return nil
185197
}
186198

199+
// handleTextChunk handles the text message chunk event.
200+
func (r *reducer) handleTextChunk(e *aguievents.TextMessageChunkEvent) error {
201+
if e.MessageID == nil || *e.MessageID == "" {
202+
return fmt.Errorf("text message chunk missing id")
203+
}
204+
if _, exists := r.texts[*e.MessageID]; exists {
205+
return fmt.Errorf("duplicate text message chunk: %s", *e.MessageID)
206+
}
207+
role := string(model.RoleAssistant)
208+
if e.Role != nil && *e.Role != "" {
209+
role = string(*e.Role)
210+
}
211+
name := ""
212+
switch role {
213+
case string(model.RoleUser):
214+
name = r.userID
215+
case string(model.RoleAssistant):
216+
name = r.appName
217+
default:
218+
return fmt.Errorf("unsupported role: %s", role)
219+
}
220+
content := ""
221+
if e.Delta != nil {
222+
content = strings.Clone(*e.Delta)
223+
}
224+
r.messages = append(r.messages, &aguievents.Message{
225+
ID: *e.MessageID,
226+
Role: role,
227+
Name: &name,
228+
Content: &content,
229+
})
230+
builder := strings.Builder{}
231+
builder.WriteString(content)
232+
r.texts[*e.MessageID] = &textState{
233+
role: role,
234+
name: name,
235+
content: builder,
236+
phase: textEnded,
237+
index: len(r.messages) - 1,
238+
}
239+
return nil
240+
}
241+
187242
// handleToolStart handles the tool call start event.
188243
func (r *reducer) handleToolStart(e *aguievents.ToolCallStartEvent) error {
189244
if e.ToolCallID == "" {

server/agui/internal/reduce/reduce_test.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,160 @@ func TestBuildMessagesHappyPath(t *testing.T) {
109109
}
110110
}
111111

112+
func TestReduceReturnsMessagesOnReduceError(t *testing.T) {
113+
events := trackEventsFrom(
114+
aguievents.NewTextMessageStartEvent("user-1", aguievents.WithRole("user")),
115+
aguievents.NewTextMessageContentEvent("user-1", "hello"),
116+
aguievents.NewTextMessageEndEvent("user-1"),
117+
aguievents.NewTextMessageContentEvent("user-1", "!"),
118+
)
119+
msgs, err := Reduce(testAppName, testUserID, events)
120+
if err == nil || !strings.Contains(err.Error(), "reduce: text message content after end: user-1") {
121+
t.Fatalf("unexpected error %v", err)
122+
}
123+
if len(msgs) != 1 {
124+
t.Fatalf("expected 1 message, got %d", len(msgs))
125+
}
126+
if msgs[0].Content == nil || *msgs[0].Content != "hello" {
127+
t.Fatalf("unexpected content %v", msgs[0].Content)
128+
}
129+
}
130+
131+
func TestReduceReturnsMessagesOnFinalizeError(t *testing.T) {
132+
events := trackEventsFrom(
133+
aguievents.NewTextMessageStartEvent("user-1", aguievents.WithRole("user")),
134+
aguievents.NewTextMessageContentEvent("user-1", "hello"),
135+
)
136+
msgs, err := Reduce(testAppName, testUserID, events)
137+
if err == nil || !strings.Contains(err.Error(), "finalize: text message user-1 not closed") {
138+
t.Fatalf("unexpected error %v", err)
139+
}
140+
if len(msgs) != 1 {
141+
t.Fatalf("expected 1 message, got %d", len(msgs))
142+
}
143+
if msgs[0].Content != nil {
144+
t.Fatalf("expected nil content, got %v", msgs[0].Content)
145+
}
146+
}
147+
148+
func TestHandleTextChunkSuccess(t *testing.T) {
149+
tests := []struct {
150+
name string
151+
chunk *aguievents.TextMessageChunkEvent
152+
wantRole string
153+
wantName string
154+
wantContent string
155+
}{
156+
{
157+
name: "assistant default role empty delta",
158+
chunk: aguievents.NewTextMessageChunkEvent().WithChunkMessageID("msg-1"),
159+
wantRole: "assistant",
160+
wantName: testAppName,
161+
wantContent: "",
162+
},
163+
{
164+
name: "user role with delta",
165+
chunk: aguievents.NewTextMessageChunkEvent().
166+
WithChunkMessageID("msg-2").
167+
WithChunkRole("user").
168+
WithChunkDelta("hi"),
169+
wantRole: "user",
170+
wantName: testUserID,
171+
wantContent: "hi",
172+
},
173+
}
174+
for _, tt := range tests {
175+
t.Run(tt.name, func(t *testing.T) {
176+
r := new(testAppName, testUserID)
177+
if err := r.handleTextChunk(tt.chunk); err != nil {
178+
t.Fatalf("handleTextChunk err: %v", err)
179+
}
180+
if err := r.finalize(); err != nil {
181+
t.Fatalf("finalize err: %v", err)
182+
}
183+
if len(r.messages) != 1 {
184+
t.Fatalf("expected 1 message, got %d", len(r.messages))
185+
}
186+
msg := r.messages[0]
187+
if msg.Role != tt.wantRole {
188+
t.Fatalf("unexpected role %q", msg.Role)
189+
}
190+
if msg.Name == nil || *msg.Name != tt.wantName {
191+
t.Fatalf("unexpected name %v", msg.Name)
192+
}
193+
if msg.Content == nil || *msg.Content != tt.wantContent {
194+
t.Fatalf("unexpected content %v", msg.Content)
195+
}
196+
state, ok := r.texts[*tt.chunk.MessageID]
197+
if !ok {
198+
t.Fatalf("expected text state for %s", *tt.chunk.MessageID)
199+
}
200+
if state.phase != textEnded {
201+
t.Fatalf("unexpected phase %v", state.phase)
202+
}
203+
if got := state.content.String(); got != tt.wantContent {
204+
t.Fatalf("unexpected builder content %q", got)
205+
}
206+
if state.index != 0 {
207+
t.Fatalf("unexpected state index %d", state.index)
208+
}
209+
})
210+
}
211+
}
212+
213+
func TestHandleTextChunkErrors(t *testing.T) {
214+
t.Run("missing id", func(t *testing.T) {
215+
chunk := aguievents.NewTextMessageChunkEvent()
216+
r := new(testAppName, testUserID)
217+
if err := r.handleTextChunk(chunk); err == nil || !strings.Contains(err.Error(), "text message chunk missing id") {
218+
t.Fatalf("unexpected error %v", err)
219+
}
220+
})
221+
t.Run("duplicate id", func(t *testing.T) {
222+
chunk := aguievents.NewTextMessageChunkEvent().WithChunkMessageID("msg-1")
223+
r := new(testAppName, testUserID)
224+
if err := r.handleTextChunk(chunk); err != nil {
225+
t.Fatalf("handleTextChunk err: %v", err)
226+
}
227+
if err := r.handleTextChunk(chunk); err == nil || !strings.Contains(err.Error(), "duplicate text message chunk: msg-1") {
228+
t.Fatalf("unexpected error %v", err)
229+
}
230+
})
231+
t.Run("unsupported role", func(t *testing.T) {
232+
chunk := aguievents.NewTextMessageChunkEvent().WithChunkMessageID("msg-3").WithChunkRole("tool")
233+
r := new(testAppName, testUserID)
234+
if err := r.handleTextChunk(chunk); err == nil || !strings.Contains(err.Error(), "unsupported role: tool") {
235+
t.Fatalf("unexpected error %v", err)
236+
}
237+
})
238+
t.Run("empty string id pointer", func(t *testing.T) {
239+
chunk := aguievents.NewTextMessageChunkEvent()
240+
empty := ""
241+
chunk.MessageID = &empty
242+
r := new(testAppName, testUserID)
243+
if err := r.handleTextChunk(chunk); err == nil || !strings.Contains(err.Error(), "text message chunk missing id") {
244+
t.Fatalf("unexpected error %v", err)
245+
}
246+
})
247+
}
248+
249+
func TestReduceEventDispatchesChunk(t *testing.T) {
250+
r := new(testAppName, testUserID)
251+
chunk := aguievents.NewTextMessageChunkEvent().WithChunkMessageID("msg-1").WithChunkDelta("hi")
252+
if err := r.reduceEvent(chunk); err != nil {
253+
t.Fatalf("reduceEvent err: %v", err)
254+
}
255+
if err := r.finalize(); err != nil {
256+
t.Fatalf("finalize err: %v", err)
257+
}
258+
if len(r.messages) != 1 {
259+
t.Fatalf("expected 1 message, got %d", len(r.messages))
260+
}
261+
if r.messages[0].Content == nil || *r.messages[0].Content != "hi" {
262+
t.Fatalf("unexpected content %v", r.messages[0].Content)
263+
}
264+
}
265+
112266
func TestAssistantOnlyToolCall(t *testing.T) {
113267
events := []session.TrackEvent{
114268
newTrackEvent(aguievents.NewTextMessageStartEvent("assistant-1", aguievents.WithRole("assistant"))),

server/agui/runner/messagessnapshot.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,21 +79,30 @@ func (r *runner) messagesSnapshot(ctx context.Context, input *runInput, events c
7979
messagesSnapshotEvent, err := r.getMessagesSnapshotEvent(ctx, input.key)
8080
if err != nil {
8181
log.Errorf("agui messages snapshot: threadID: %s, runID: %s, load history: %v", threadID, runID, err)
82-
r.emitEvent(ctx, events, aguievents.NewRunErrorEvent(fmt.Sprintf("load history: %v", err),
83-
aguievents.WithRunID(runID)), input)
84-
return
82+
if messagesSnapshotEvent == nil {
83+
r.emitEvent(ctx, events, aguievents.NewRunErrorEvent(fmt.Sprintf("load history: %v", err),
84+
aguievents.WithRunID(runID)), input)
85+
return
86+
}
8587
}
88+
// In order to fetch the history messages as much as possible, still emit the messages even if there is an error.
8689
// Emit a MESSAGES_SNAPSHOT event to send the snapshot payload.
8790
if !r.emitEvent(ctx, events, messagesSnapshotEvent, input) {
8891
return
8992
}
93+
if err != nil {
94+
r.emitEvent(ctx, events, aguievents.NewRunErrorEvent(fmt.Sprintf("load history: %v", err),
95+
aguievents.WithRunID(runID)), input)
96+
return
97+
}
9098
// Emit a RUN_FINISHED event to signal downstream consumers there is no more data.
9199
if !r.emitEvent(ctx, events, aguievents.NewRunFinishedEvent(threadID, runID), input) {
92100
return
93101
}
94102
}
95103

96104
// getMessagesSnapshotEvent loads AG-UI track events and converts them to an AG-UI MessagesSnapshotEvent.
105+
// In order to fetch the history messages as much as possible, still return the messages even if there is an error.
97106
func (r *runner) getMessagesSnapshotEvent(ctx context.Context,
98107
sessionKey session.Key) (*aguievents.MessagesSnapshotEvent, error) {
99108
trackEvents, err := r.tracker.GetEvents(ctx, sessionKey)
@@ -102,7 +111,7 @@ func (r *runner) getMessagesSnapshotEvent(ctx context.Context,
102111
}
103112
messages, err := reduce.Reduce(r.appName, sessionKey.UserID, trackEvents.Events)
104113
if err != nil {
105-
return nil, fmt.Errorf("reduce track events: %w", err)
114+
err = fmt.Errorf("reduce track events: %w", err)
106115
}
107-
return aguievents.NewMessagesSnapshotEvent(messages), nil
116+
return aguievents.NewMessagesSnapshotEvent(messages), err
108117
}

server/agui/runner/messagessnapshot_test.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,51 @@ func TestMessagesSnapshotReduceError(t *testing.T) {
246246
stream, err := r.MessagesSnapshot(context.Background(), &adapter.RunAgentInput{ThreadID: "thread", RunID: "run"})
247247
require.NoError(t, err)
248248
collected := collectAGUIEvents(t, stream)
249-
require.Len(t, collected, 2)
250-
errEvt, ok := collected[1].(*aguievents.RunErrorEvent)
249+
require.Len(t, collected, 3)
250+
if _, ok := collected[0].(*aguievents.RunStartedEvent); !ok {
251+
t.Fatalf("expected RUN_STARTED")
252+
}
253+
snapshot, ok := collected[1].(*aguievents.MessagesSnapshotEvent)
254+
require.True(t, ok)
255+
require.Len(t, snapshot.Messages, 1)
256+
errEvt, ok := collected[2].(*aguievents.RunErrorEvent)
257+
require.True(t, ok)
258+
assert.Contains(t, errEvt.Message, "reduce track events")
259+
}
260+
261+
func TestMessagesSnapshotReduceErrorEmitsSnapshotThenError(t *testing.T) {
262+
svc := &testSessionService{
263+
trackEvents: []session.TrackEvent{
264+
newTrackEvent(t, aguievents.NewTextMessageStartEvent("user-1", aguievents.WithRole("user"))),
265+
newTrackEvent(t, aguievents.NewTextMessageContentEvent("user-1", "hello")),
266+
newTrackEvent(t, aguievents.NewTextMessageEndEvent("user-1")),
267+
newTrackEvent(t, aguievents.NewTextMessageContentEvent("user-1", "!")),
268+
},
269+
}
270+
tracker, err := track.New(svc)
271+
require.NoError(t, err)
272+
r := &runner{
273+
runner: noopBaseRunner{},
274+
userIDResolver: NewOptions().UserIDResolver,
275+
runAgentInputHook: NewOptions().RunAgentInputHook,
276+
appName: "demo",
277+
tracker: tracker,
278+
}
279+
280+
stream, err := r.MessagesSnapshot(context.Background(), &adapter.RunAgentInput{ThreadID: "thread", RunID: "run"})
281+
require.NoError(t, err)
282+
collected := collectAGUIEvents(t, stream)
283+
require.Len(t, collected, 3)
284+
if _, ok := collected[0].(*aguievents.RunStartedEvent); !ok {
285+
t.Fatalf("expected RUN_STARTED")
286+
}
287+
snapshot, ok := collected[1].(*aguievents.MessagesSnapshotEvent)
288+
require.True(t, ok)
289+
require.Len(t, snapshot.Messages, 1)
290+
if snapshot.Messages[0].Content == nil || *snapshot.Messages[0].Content != "hello" {
291+
t.Fatalf("unexpected snapshot content %v", snapshot.Messages[0].Content)
292+
}
293+
errEvt, ok := collected[2].(*aguievents.RunErrorEvent)
251294
require.True(t, ok)
252295
assert.Contains(t, errEvt.Message, "reduce track events")
253296
}

0 commit comments

Comments
 (0)