Skip to content

Commit f7a4a54

Browse files
authored
agent: add runtime state retrieval from context (#768)
1 parent b910954 commit f7a4a54

File tree

2 files changed

+169
-0
lines changed

2 files changed

+169
-0
lines changed

agent/invocationcontext.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@ func GetStateValueFromContext[T any](ctx context.Context, key string) (T, bool)
5555
return GetStateValue[T](inv, key)
5656
}
5757

58+
// GetRuntimeStateValueFromContext retrieves a typed value from the runtime state
59+
// stored in the invocation's RunOptions within the context.
60+
//
61+
// Returns the typed value and true if the invocation exists, the key exists in
62+
// RuntimeState, and the type matches, or the zero value and false otherwise.
63+
//
64+
// Example:
65+
//
66+
// if userID, ok := GetRuntimeStateValueFromContext[string](ctx, "user_id"); ok {
67+
// log.Printf("User ID: %s", userID)
68+
// }
69+
// if roomID, ok := GetRuntimeStateValueFromContext[int](ctx, "room_id"); ok {
70+
// log.Printf("Room ID: %d", roomID)
71+
// }
72+
func GetRuntimeStateValueFromContext[T any](ctx context.Context, key string) (T, bool) {
73+
var zero T
74+
inv, ok := InvocationFromContext(ctx)
75+
if !ok || inv == nil {
76+
return zero, false
77+
}
78+
return GetRuntimeStateValue[T](&inv.RunOptions, key)
79+
}
80+
5881
// CheckContextCancelled check context cancelled
5982
func CheckContextCancelled(ctx context.Context) error {
6083
select {

agent/invocationcontext_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,149 @@ func TestGetStateValueFromContext(t *testing.T) {
196196
assert.Equal(t, data.Metadata, retrieved.Metadata)
197197
})
198198
}
199+
200+
func TestGetRuntimeStateValueFromContext(t *testing.T) {
201+
t.Run("context without invocation", func(t *testing.T) {
202+
ctx := context.Background()
203+
val, ok := GetRuntimeStateValueFromContext[string](ctx, "key")
204+
assert.False(t, ok)
205+
assert.Equal(t, "", val)
206+
})
207+
208+
t.Run("context with invocation but key not found", func(t *testing.T) {
209+
inv := NewInvocation()
210+
ctx := NewInvocationContext(context.Background(), inv)
211+
val, ok := GetRuntimeStateValueFromContext[string](ctx, "nonexistent")
212+
assert.False(t, ok)
213+
assert.Equal(t, "", val)
214+
})
215+
216+
t.Run("context with invocation but nil RuntimeState", func(t *testing.T) {
217+
inv := NewInvocation()
218+
ctx := NewInvocationContext(context.Background(), inv)
219+
val, ok := GetRuntimeStateValueFromContext[string](ctx, "key")
220+
assert.False(t, ok)
221+
assert.Equal(t, "", val)
222+
})
223+
224+
t.Run("context with invocation and matching type", func(t *testing.T) {
225+
inv := NewInvocation(
226+
WithInvocationRunOptions(RunOptions{
227+
RuntimeState: map[string]any{
228+
"user_id": "12345",
229+
"room_id": 678,
230+
"config": true,
231+
"score": 3.14,
232+
},
233+
}),
234+
)
235+
ctx := NewInvocationContext(context.Background(), inv)
236+
237+
// Test string value.
238+
userID, ok := GetRuntimeStateValueFromContext[string](ctx, "user_id")
239+
assert.True(t, ok)
240+
assert.Equal(t, "12345", userID)
241+
242+
// Test int value.
243+
roomID, ok := GetRuntimeStateValueFromContext[int](ctx, "room_id")
244+
assert.True(t, ok)
245+
assert.Equal(t, 678, roomID)
246+
247+
// Test bool value.
248+
config, ok := GetRuntimeStateValueFromContext[bool](ctx, "config")
249+
assert.True(t, ok)
250+
assert.Equal(t, true, config)
251+
252+
// Test float64 value.
253+
score, ok := GetRuntimeStateValueFromContext[float64](ctx, "score")
254+
assert.True(t, ok)
255+
assert.Equal(t, 3.14, score)
256+
})
257+
258+
t.Run("context with invocation but type mismatch", func(t *testing.T) {
259+
inv := NewInvocation(
260+
WithInvocationRunOptions(RunOptions{
261+
RuntimeState: map[string]any{
262+
"value": "hello",
263+
},
264+
}),
265+
)
266+
ctx := NewInvocationContext(context.Background(), inv)
267+
268+
// Try to get as int when it's actually string.
269+
intVal, ok := GetRuntimeStateValueFromContext[int](ctx, "value")
270+
assert.False(t, ok)
271+
assert.Equal(t, 0, intVal)
272+
})
273+
274+
t.Run("context with nil invocation", func(t *testing.T) {
275+
ctx := NewInvocationContext(context.Background(), nil)
276+
val, ok := GetRuntimeStateValueFromContext[string](ctx, "key")
277+
assert.False(t, ok)
278+
assert.Equal(t, "", val)
279+
})
280+
281+
t.Run("slice type", func(t *testing.T) {
282+
inv := NewInvocation(
283+
WithInvocationRunOptions(RunOptions{
284+
RuntimeState: map[string]any{
285+
"tags": []string{"tag1", "tag2", "tag3"},
286+
},
287+
}),
288+
)
289+
ctx := NewInvocationContext(context.Background(), inv)
290+
291+
tags, ok := GetRuntimeStateValueFromContext[[]string](ctx, "tags")
292+
assert.True(t, ok)
293+
assert.Equal(t, []string{"tag1", "tag2", "tag3"}, tags)
294+
})
295+
296+
t.Run("map type", func(t *testing.T) {
297+
inv := NewInvocation(
298+
WithInvocationRunOptions(RunOptions{
299+
RuntimeState: map[string]any{
300+
"metadata": map[string]string{
301+
"key1": "value1",
302+
"key2": "value2",
303+
},
304+
},
305+
}),
306+
)
307+
ctx := NewInvocationContext(context.Background(), inv)
308+
309+
metadata, ok := GetRuntimeStateValueFromContext[map[string]string](ctx, "metadata")
310+
assert.True(t, ok)
311+
assert.Equal(t, "value1", metadata["key1"])
312+
assert.Equal(t, "value2", metadata["key2"])
313+
})
314+
315+
t.Run("complex struct type", func(t *testing.T) {
316+
type UserContext struct {
317+
UserID string
318+
RoomID int
319+
Metadata map[string]string
320+
}
321+
322+
userCtx := UserContext{
323+
UserID: "user-123",
324+
RoomID: 456,
325+
Metadata: map[string]string{
326+
"key1": "value1",
327+
},
328+
}
329+
inv := NewInvocation(
330+
WithInvocationRunOptions(RunOptions{
331+
RuntimeState: map[string]any{
332+
"user_context": userCtx,
333+
},
334+
}),
335+
)
336+
ctx := NewInvocationContext(context.Background(), inv)
337+
338+
retrieved, ok := GetRuntimeStateValueFromContext[UserContext](ctx, "user_context")
339+
require.True(t, ok)
340+
assert.Equal(t, userCtx.UserID, retrieved.UserID)
341+
assert.Equal(t, userCtx.RoomID, retrieved.RoomID)
342+
assert.Equal(t, userCtx.Metadata, retrieved.Metadata)
343+
})
344+
}

0 commit comments

Comments
 (0)