Skip to content

Commit 7d13294

Browse files
authored
{agent, graph}: implement generic getter for typed state retrieval (#752)
- Added GetStateValue function to retrieve typed values from the invocation state. - Introduced GetStateValueFromContext to fetch typed values from the invocation stored in the context. - Introduced state_test.go to validate the behavior of GetStateValue for various scenarios, including key not found, nil state, type matching, type mismatch, and complex types.
1 parent 7521bd4 commit 7d13294

File tree

6 files changed

+476
-0
lines changed

6 files changed

+476
-0
lines changed

agent/invocation.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,35 @@ func (inv *Invocation) GetState(key string) (any, bool) {
558558
return value, ok
559559
}
560560

561+
// GetStateValue retrieves a typed value from the invocation state.
562+
//
563+
// Returns the typed value and true if the key exists and the type matches,
564+
// or the zero value and false otherwise.
565+
//
566+
// Example:
567+
//
568+
// if startTime, ok := GetStateValue[time.Time](inv, "agent:start_time"); ok {
569+
// duration := time.Since(startTime)
570+
// }
571+
// if requestID, ok := GetStateValue[string](inv, "middleware:request_id"); ok {
572+
// log.Printf("Request ID: %s", requestID)
573+
// }
574+
func GetStateValue[T any](inv *Invocation, key string) (T, bool) {
575+
var zero T
576+
if inv == nil {
577+
return zero, false
578+
}
579+
val, ok := inv.GetState(key)
580+
if !ok {
581+
return zero, false
582+
}
583+
typedVal, ok := val.(T)
584+
if !ok {
585+
return zero, false
586+
}
587+
return typedVal, true
588+
}
589+
561590
// GetOrCreateTimingInfo gets or creates timing info for this invocation.
562591
// Only the first LLM call will create and populate timing info; subsequent calls reuse it.
563592
// This ensures timing metrics only reflect the first LLM call in scenarios with multiple calls (e.g., tool calls).

agent/invocation_state_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,124 @@ func TestInvocation_State_ComplexStruct(t *testing.T) {
323323
assert.Equal(t, data.ID, retrieved.ID)
324324
assert.Equal(t, data.Metadata, retrieved.Metadata)
325325
}
326+
327+
func TestGetStateValue(t *testing.T) {
328+
t.Run("key not found", func(t *testing.T) {
329+
inv := NewInvocation()
330+
val, ok := GetStateValue[string](inv, "nonexistent")
331+
assert.False(t, ok)
332+
assert.Equal(t, "", val)
333+
})
334+
335+
t.Run("matching type", func(t *testing.T) {
336+
inv := NewInvocation()
337+
inv.SetState("agent:string", "hello")
338+
inv.SetState("agent:int", 42)
339+
inv.SetState("agent:float", 3.14)
340+
inv.SetState("agent:bool", true)
341+
inv.SetState("agent:time", time.Now())
342+
343+
// Test string.
344+
strVal, ok := GetStateValue[string](inv, "agent:string")
345+
assert.True(t, ok)
346+
assert.Equal(t, "hello", strVal)
347+
348+
// Test int.
349+
intVal, ok := GetStateValue[int](inv, "agent:int")
350+
assert.True(t, ok)
351+
assert.Equal(t, 42, intVal)
352+
353+
// Test float64.
354+
floatVal, ok := GetStateValue[float64](inv, "agent:float")
355+
assert.True(t, ok)
356+
assert.Equal(t, 3.14, floatVal)
357+
358+
// Test bool.
359+
boolVal, ok := GetStateValue[bool](inv, "agent:bool")
360+
assert.True(t, ok)
361+
assert.Equal(t, true, boolVal)
362+
363+
// Test time.Time.
364+
timeVal, ok := GetStateValue[time.Time](inv, "agent:time")
365+
assert.True(t, ok)
366+
assert.IsType(t, time.Time{}, timeVal)
367+
})
368+
369+
t.Run("type mismatch", func(t *testing.T) {
370+
inv := NewInvocation()
371+
inv.SetState("agent:value", "hello")
372+
373+
// Try to get as int when it's actually string.
374+
intVal, ok := GetStateValue[int](inv, "agent:value")
375+
assert.False(t, ok)
376+
assert.Equal(t, 0, intVal)
377+
378+
// Try to get as string when it's actually int.
379+
inv.SetState("agent:number", 42)
380+
strVal, ok := GetStateValue[string](inv, "agent:number")
381+
assert.False(t, ok)
382+
assert.Equal(t, "", strVal)
383+
})
384+
385+
t.Run("nil invocation", func(t *testing.T) {
386+
var inv *Invocation
387+
val, ok := GetStateValue[string](inv, "key")
388+
assert.False(t, ok)
389+
assert.Equal(t, "", val)
390+
})
391+
392+
t.Run("complex struct type", func(t *testing.T) {
393+
type CustomData struct {
394+
ID string
395+
Timestamp time.Time
396+
Metadata map[string]string
397+
}
398+
399+
inv := NewInvocation()
400+
data := CustomData{
401+
ID: "test-123",
402+
Timestamp: time.Now(),
403+
Metadata: map[string]string{
404+
"key1": "value1",
405+
"key2": "value2",
406+
},
407+
}
408+
inv.SetState("agent:custom_data", data)
409+
410+
retrieved, ok := GetStateValue[CustomData](inv, "agent:custom_data")
411+
require.True(t, ok)
412+
assert.Equal(t, data.ID, retrieved.ID)
413+
assert.Equal(t, data.Metadata, retrieved.Metadata)
414+
})
415+
416+
t.Run("pointer type", func(t *testing.T) {
417+
inv := NewInvocation()
418+
str := "hello"
419+
inv.SetState("agent:ptr", &str)
420+
421+
ptrVal, ok := GetStateValue[*string](inv, "agent:ptr")
422+
assert.True(t, ok)
423+
require.NotNil(t, ptrVal)
424+
assert.Equal(t, "hello", *ptrVal)
425+
})
426+
427+
t.Run("slice type", func(t *testing.T) {
428+
inv := NewInvocation()
429+
slice := []int{1, 2, 3}
430+
inv.SetState("agent:slice", slice)
431+
432+
sliceVal, ok := GetStateValue[[]int](inv, "agent:slice")
433+
assert.True(t, ok)
434+
assert.Equal(t, []int{1, 2, 3}, sliceVal)
435+
})
436+
437+
t.Run("map type", func(t *testing.T) {
438+
inv := NewInvocation()
439+
m := map[string]int{"a": 1, "b": 2}
440+
inv.SetState("agent:map", m)
441+
442+
mapVal, ok := GetStateValue[map[string]int](inv, "agent:map")
443+
assert.True(t, ok)
444+
assert.Equal(t, map[string]int{"a": 1, "b": 2}, mapVal)
445+
})
446+
}

agent/invocationcontext.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,29 @@ func InvocationFromContext(ctx context.Context) (*Invocation, bool) {
3232
return invocation, ok
3333
}
3434

35+
// GetStateValueFromContext retrieves a typed value from the invocation state
36+
// stored in the context.
37+
//
38+
// Returns the typed value and true if the invocation exists, the key exists,
39+
// and the type matches, or the zero value and false otherwise.
40+
//
41+
// Example:
42+
//
43+
// if startTime, ok := GetStateValueFromContext[time.Time](ctx, "agent:start_time"); ok {
44+
// duration := time.Since(startTime)
45+
// }
46+
// if requestID, ok := GetStateValueFromContext[string](ctx, "middleware:request_id"); ok {
47+
// log.Printf("Request ID: %s", requestID)
48+
// }
49+
func GetStateValueFromContext[T any](ctx context.Context, key string) (T, bool) {
50+
var zero T
51+
inv, ok := InvocationFromContext(ctx)
52+
if !ok {
53+
return zero, false
54+
}
55+
return GetStateValue[T](inv, key)
56+
}
57+
3558
// CheckContextCancelled check context cancelled
3659
func CheckContextCancelled(ctx context.Context) error {
3760
select {

agent/invocationcontext_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,85 @@ func TestInvocationFromContext(t *testing.T) {
114114
})
115115
}
116116
}
117+
118+
func TestGetStateValueFromContext(t *testing.T) {
119+
t.Run("context without invocation", func(t *testing.T) {
120+
ctx := context.Background()
121+
val, ok := GetStateValueFromContext[string](ctx, "key")
122+
assert.False(t, ok)
123+
assert.Equal(t, "", val)
124+
})
125+
126+
t.Run("context with invocation but key not found", func(t *testing.T) {
127+
inv := NewInvocation()
128+
ctx := NewInvocationContext(context.Background(), inv)
129+
val, ok := GetStateValueFromContext[string](ctx, "nonexistent")
130+
assert.False(t, ok)
131+
assert.Equal(t, "", val)
132+
})
133+
134+
t.Run("context with invocation and matching type", func(t *testing.T) {
135+
inv := NewInvocation()
136+
inv.SetState("agent:string", "hello")
137+
inv.SetState("agent:int", 42)
138+
inv.SetState("agent:time", time.Now())
139+
ctx := NewInvocationContext(context.Background(), inv)
140+
141+
// Test string value.
142+
strVal, ok := GetStateValueFromContext[string](ctx, "agent:string")
143+
assert.True(t, ok)
144+
assert.Equal(t, "hello", strVal)
145+
146+
// Test int value.
147+
intVal, ok := GetStateValueFromContext[int](ctx, "agent:int")
148+
assert.True(t, ok)
149+
assert.Equal(t, 42, intVal)
150+
151+
// Test time.Time value.
152+
timeVal, ok := GetStateValueFromContext[time.Time](ctx, "agent:time")
153+
assert.True(t, ok)
154+
assert.IsType(t, time.Time{}, timeVal)
155+
})
156+
157+
t.Run("context with invocation but type mismatch", func(t *testing.T) {
158+
inv := NewInvocation()
159+
inv.SetState("agent:value", "hello")
160+
ctx := NewInvocationContext(context.Background(), inv)
161+
162+
// Try to get as int when it's actually string.
163+
intVal, ok := GetStateValueFromContext[int](ctx, "agent:value")
164+
assert.False(t, ok)
165+
assert.Equal(t, 0, intVal)
166+
})
167+
168+
t.Run("context with nil invocation", func(t *testing.T) {
169+
ctx := NewInvocationContext(context.Background(), nil)
170+
val, ok := GetStateValueFromContext[string](ctx, "key")
171+
assert.False(t, ok)
172+
assert.Equal(t, "", val)
173+
})
174+
175+
t.Run("complex struct type", func(t *testing.T) {
176+
type CustomData struct {
177+
ID string
178+
Timestamp time.Time
179+
Metadata map[string]string
180+
}
181+
182+
inv := NewInvocation()
183+
data := CustomData{
184+
ID: "test-123",
185+
Timestamp: time.Now(),
186+
Metadata: map[string]string{
187+
"key1": "value1",
188+
},
189+
}
190+
inv.SetState("agent:custom_data", data)
191+
ctx := NewInvocationContext(context.Background(), inv)
192+
193+
retrieved, ok := GetStateValueFromContext[CustomData](ctx, "agent:custom_data")
194+
require.True(t, ok)
195+
assert.Equal(t, data.ID, retrieved.ID)
196+
assert.Equal(t, data.Metadata, retrieved.Metadata)
197+
})
198+
}

graph/state.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,35 @@ const (
5353
// This is the shared data structure that flows between nodes.
5454
type State map[string]any
5555

56+
// GetStateValue retrieves a typed value from the state.
57+
//
58+
// Returns the typed value and true if the key exists and the type matches,
59+
// or the zero value and false otherwise.
60+
//
61+
// Example:
62+
//
63+
// if messages, ok := GetStateValue[[]model.Message](state, StateKeyMessages); ok {
64+
// // use messages
65+
// }
66+
// if userInput, ok := GetStateValue[string](state, StateKeyUserInput); ok {
67+
// // use userInput
68+
// }
69+
func GetStateValue[T any](s State, key string) (T, bool) {
70+
var zero T
71+
if s == nil {
72+
return zero, false
73+
}
74+
val, ok := s[key]
75+
if !ok {
76+
return zero, false
77+
}
78+
typedVal, ok := val.(T)
79+
if !ok {
80+
return zero, false
81+
}
82+
return typedVal, true
83+
}
84+
5685
// Clone creates a deep copy of the state.
5786
func (s State) Clone() State {
5887
clone := make(State, len(s))

0 commit comments

Comments
 (0)