Skip to content

Commit a04ce34

Browse files
authored
Merge branch 'main' into feature/model-gemini
2 parents 9f49f4e + 7d13294 commit a04ce34

File tree

88 files changed

+6403
-760
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+6403
-760
lines changed

agent/a2aagent/a2a_converter.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,13 @@ func processDataPart(part protocol.Part) (content string, toolCall *model.ToolCa
274274
return
275275
}
276276

277+
// Try both standard "type" and ADK-compatible "adk_type" metadata keys
277278
typeVal, hasType := d.Metadata[ia2a.DataPartMetadataTypeKey]
278279
if !hasType {
279-
return
280+
typeVal, hasType = d.Metadata[ia2a.GetADKMetadataKey(ia2a.DataPartMetadataTypeKey)]
281+
if !hasType {
282+
return
283+
}
280284
}
281285

282286
// Convert typeVal to string for comparison

agent/a2aagent/a2a_converter_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,177 @@ func TestConvertDataPartToToolResponse(t *testing.T) {
12501250
}
12511251
}
12521252

1253+
// TestProcessDataPart_ADKMetadataKey tests handling of ADK-compatible metadata keys
1254+
func TestProcessDataPart_ADKMetadataKey(t *testing.T) {
1255+
type testCase struct {
1256+
name string
1257+
dataPart protocol.Part
1258+
validateFunc func(t *testing.T, content string, toolCall *model.ToolCall, toolResp *toolResponseInfo)
1259+
}
1260+
1261+
tests := []testCase{
1262+
{
1263+
name: "function call with adk_type metadata",
1264+
dataPart: &protocol.DataPart{
1265+
Data: map[string]any{
1266+
"id": "call-adk-1",
1267+
"type": "function",
1268+
"name": "get_weather",
1269+
"args": `{"location": "Beijing"}`,
1270+
},
1271+
Metadata: map[string]any{
1272+
"adk_type": "function_call", // Using ADK-compatible key
1273+
},
1274+
},
1275+
validateFunc: func(t *testing.T, content string, toolCall *model.ToolCall, toolResp *toolResponseInfo) {
1276+
if toolCall == nil {
1277+
t.Fatal("expected tool call, got nil")
1278+
}
1279+
if toolCall.ID != "call-adk-1" {
1280+
t.Errorf("expected ID 'call-adk-1', got %s", toolCall.ID)
1281+
}
1282+
if toolCall.Function.Name != "get_weather" {
1283+
t.Errorf("expected name 'get_weather', got %s", toolCall.Function.Name)
1284+
}
1285+
if toolResp != nil {
1286+
t.Errorf("expected nil tool response, got %v", toolResp)
1287+
}
1288+
},
1289+
},
1290+
{
1291+
name: "function response with adk_type metadata",
1292+
dataPart: &protocol.DataPart{
1293+
Data: map[string]any{
1294+
"id": "call-adk-2",
1295+
"name": "get_weather",
1296+
"response": "Beijing: Sunny, 20°C",
1297+
},
1298+
Metadata: map[string]any{
1299+
"adk_type": "function_response", // Using ADK-compatible key
1300+
},
1301+
},
1302+
validateFunc: func(t *testing.T, content string, toolCall *model.ToolCall, toolResp *toolResponseInfo) {
1303+
if content != "Beijing: Sunny, 20°C" {
1304+
t.Errorf("expected content 'Beijing: Sunny, 20°C', got %s", content)
1305+
}
1306+
if toolResp == nil {
1307+
t.Fatal("expected tool response info, got nil")
1308+
}
1309+
if toolResp.id != "call-adk-2" {
1310+
t.Errorf("expected tool response ID 'call-adk-2', got %s", toolResp.id)
1311+
}
1312+
if toolResp.name != "get_weather" {
1313+
t.Errorf("expected tool response name 'get_weather', got %s", toolResp.name)
1314+
}
1315+
if toolCall != nil {
1316+
t.Errorf("expected nil tool call, got %v", toolCall)
1317+
}
1318+
},
1319+
},
1320+
{
1321+
name: "no type metadata",
1322+
dataPart: &protocol.DataPart{
1323+
Data: map[string]any{
1324+
"id": "call-no-type",
1325+
"name": "some_function",
1326+
},
1327+
Metadata: map[string]any{
1328+
"other_key": "other_value",
1329+
},
1330+
},
1331+
validateFunc: func(t *testing.T, content string, toolCall *model.ToolCall, toolResp *toolResponseInfo) {
1332+
if content != "" {
1333+
t.Errorf("expected empty content, got %s", content)
1334+
}
1335+
if toolCall != nil {
1336+
t.Errorf("expected nil tool call, got %v", toolCall)
1337+
}
1338+
if toolResp != nil {
1339+
t.Errorf("expected nil tool response, got %v", toolResp)
1340+
}
1341+
},
1342+
},
1343+
{
1344+
name: "nil metadata",
1345+
dataPart: &protocol.DataPart{
1346+
Data: map[string]any{
1347+
"id": "call-nil-meta",
1348+
"name": "some_function",
1349+
},
1350+
Metadata: nil,
1351+
},
1352+
validateFunc: func(t *testing.T, content string, toolCall *model.ToolCall, toolResp *toolResponseInfo) {
1353+
if content != "" {
1354+
t.Errorf("expected empty content, got %s", content)
1355+
}
1356+
if toolCall != nil {
1357+
t.Errorf("expected nil tool call, got %v", toolCall)
1358+
}
1359+
if toolResp != nil {
1360+
t.Errorf("expected nil tool response, got %v", toolResp)
1361+
}
1362+
},
1363+
},
1364+
{
1365+
name: "non-string type value",
1366+
dataPart: &protocol.DataPart{
1367+
Data: map[string]any{
1368+
"id": "call-bad-type",
1369+
"name": "some_function",
1370+
},
1371+
Metadata: map[string]any{
1372+
"type": 12345, // Non-string type value
1373+
},
1374+
},
1375+
validateFunc: func(t *testing.T, content string, toolCall *model.ToolCall, toolResp *toolResponseInfo) {
1376+
if content != "" {
1377+
t.Errorf("expected empty content, got %s", content)
1378+
}
1379+
if toolCall != nil {
1380+
t.Errorf("expected nil tool call, got %v", toolCall)
1381+
}
1382+
if toolResp != nil {
1383+
t.Errorf("expected nil tool response, got %v", toolResp)
1384+
}
1385+
},
1386+
},
1387+
{
1388+
name: "standard type key takes precedence over adk_type",
1389+
dataPart: &protocol.DataPart{
1390+
Data: map[string]any{
1391+
"id": "call-precedence",
1392+
"type": "function",
1393+
"name": "test_func",
1394+
"args": `{"x": 1}`,
1395+
},
1396+
Metadata: map[string]any{
1397+
"type": "function_call", // Standard key
1398+
"adk_type": "function_response", // ADK key (should be ignored)
1399+
},
1400+
},
1401+
validateFunc: func(t *testing.T, content string, toolCall *model.ToolCall, toolResp *toolResponseInfo) {
1402+
// Should process as function_call, not function_response
1403+
if toolCall == nil {
1404+
t.Fatal("expected tool call, got nil")
1405+
}
1406+
if toolCall.Function.Name != "test_func" {
1407+
t.Errorf("expected name 'test_func', got %s", toolCall.Function.Name)
1408+
}
1409+
if toolResp != nil {
1410+
t.Errorf("expected nil tool response (standard key should take precedence), got %v", toolResp)
1411+
}
1412+
},
1413+
},
1414+
}
1415+
1416+
for _, tc := range tests {
1417+
t.Run(tc.name, func(t *testing.T) {
1418+
content, toolCall, toolResp := processDataPart(tc.dataPart)
1419+
tc.validateFunc(t, content, toolCall, toolResp)
1420+
})
1421+
}
1422+
}
1423+
12531424
// Helper function to create string pointer
12541425
func stringPtr(s string) *string {
12551426
return &s

agent/invocation.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,35 @@ func (inv *Invocation) GetState(key string) (any, bool) {
558558
return value, ok
559559
}
560560

561+
// GetStateValue retrieves a typed value from the invocation state.
562+
//
563+
// Returns the typed value and true if the key exists and the type matches,
564+
// or the zero value and false otherwise.
565+
//
566+
// Example:
567+
//
568+
// if startTime, ok := GetStateValue[time.Time](inv, "agent:start_time"); ok {
569+
// duration := time.Since(startTime)
570+
// }
571+
// if requestID, ok := GetStateValue[string](inv, "middleware:request_id"); ok {
572+
// log.Printf("Request ID: %s", requestID)
573+
// }
574+
func GetStateValue[T any](inv *Invocation, key string) (T, bool) {
575+
var zero T
576+
if inv == nil {
577+
return zero, false
578+
}
579+
val, ok := inv.GetState(key)
580+
if !ok {
581+
return zero, false
582+
}
583+
typedVal, ok := val.(T)
584+
if !ok {
585+
return zero, false
586+
}
587+
return typedVal, true
588+
}
589+
561590
// GetOrCreateTimingInfo gets or creates timing info for this invocation.
562591
// Only the first LLM call will create and populate timing info; subsequent calls reuse it.
563592
// This ensures timing metrics only reflect the first LLM call in scenarios with multiple calls (e.g., tool calls).

agent/invocation_state_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,124 @@ func TestInvocation_State_ComplexStruct(t *testing.T) {
323323
assert.Equal(t, data.ID, retrieved.ID)
324324
assert.Equal(t, data.Metadata, retrieved.Metadata)
325325
}
326+
327+
func TestGetStateValue(t *testing.T) {
328+
t.Run("key not found", func(t *testing.T) {
329+
inv := NewInvocation()
330+
val, ok := GetStateValue[string](inv, "nonexistent")
331+
assert.False(t, ok)
332+
assert.Equal(t, "", val)
333+
})
334+
335+
t.Run("matching type", func(t *testing.T) {
336+
inv := NewInvocation()
337+
inv.SetState("agent:string", "hello")
338+
inv.SetState("agent:int", 42)
339+
inv.SetState("agent:float", 3.14)
340+
inv.SetState("agent:bool", true)
341+
inv.SetState("agent:time", time.Now())
342+
343+
// Test string.
344+
strVal, ok := GetStateValue[string](inv, "agent:string")
345+
assert.True(t, ok)
346+
assert.Equal(t, "hello", strVal)
347+
348+
// Test int.
349+
intVal, ok := GetStateValue[int](inv, "agent:int")
350+
assert.True(t, ok)
351+
assert.Equal(t, 42, intVal)
352+
353+
// Test float64.
354+
floatVal, ok := GetStateValue[float64](inv, "agent:float")
355+
assert.True(t, ok)
356+
assert.Equal(t, 3.14, floatVal)
357+
358+
// Test bool.
359+
boolVal, ok := GetStateValue[bool](inv, "agent:bool")
360+
assert.True(t, ok)
361+
assert.Equal(t, true, boolVal)
362+
363+
// Test time.Time.
364+
timeVal, ok := GetStateValue[time.Time](inv, "agent:time")
365+
assert.True(t, ok)
366+
assert.IsType(t, time.Time{}, timeVal)
367+
})
368+
369+
t.Run("type mismatch", func(t *testing.T) {
370+
inv := NewInvocation()
371+
inv.SetState("agent:value", "hello")
372+
373+
// Try to get as int when it's actually string.
374+
intVal, ok := GetStateValue[int](inv, "agent:value")
375+
assert.False(t, ok)
376+
assert.Equal(t, 0, intVal)
377+
378+
// Try to get as string when it's actually int.
379+
inv.SetState("agent:number", 42)
380+
strVal, ok := GetStateValue[string](inv, "agent:number")
381+
assert.False(t, ok)
382+
assert.Equal(t, "", strVal)
383+
})
384+
385+
t.Run("nil invocation", func(t *testing.T) {
386+
var inv *Invocation
387+
val, ok := GetStateValue[string](inv, "key")
388+
assert.False(t, ok)
389+
assert.Equal(t, "", val)
390+
})
391+
392+
t.Run("complex struct type", func(t *testing.T) {
393+
type CustomData struct {
394+
ID string
395+
Timestamp time.Time
396+
Metadata map[string]string
397+
}
398+
399+
inv := NewInvocation()
400+
data := CustomData{
401+
ID: "test-123",
402+
Timestamp: time.Now(),
403+
Metadata: map[string]string{
404+
"key1": "value1",
405+
"key2": "value2",
406+
},
407+
}
408+
inv.SetState("agent:custom_data", data)
409+
410+
retrieved, ok := GetStateValue[CustomData](inv, "agent:custom_data")
411+
require.True(t, ok)
412+
assert.Equal(t, data.ID, retrieved.ID)
413+
assert.Equal(t, data.Metadata, retrieved.Metadata)
414+
})
415+
416+
t.Run("pointer type", func(t *testing.T) {
417+
inv := NewInvocation()
418+
str := "hello"
419+
inv.SetState("agent:ptr", &str)
420+
421+
ptrVal, ok := GetStateValue[*string](inv, "agent:ptr")
422+
assert.True(t, ok)
423+
require.NotNil(t, ptrVal)
424+
assert.Equal(t, "hello", *ptrVal)
425+
})
426+
427+
t.Run("slice type", func(t *testing.T) {
428+
inv := NewInvocation()
429+
slice := []int{1, 2, 3}
430+
inv.SetState("agent:slice", slice)
431+
432+
sliceVal, ok := GetStateValue[[]int](inv, "agent:slice")
433+
assert.True(t, ok)
434+
assert.Equal(t, []int{1, 2, 3}, sliceVal)
435+
})
436+
437+
t.Run("map type", func(t *testing.T) {
438+
inv := NewInvocation()
439+
m := map[string]int{"a": 1, "b": 2}
440+
inv.SetState("agent:map", m)
441+
442+
mapVal, ok := GetStateValue[map[string]int](inv, "agent:map")
443+
assert.True(t, ok)
444+
assert.Equal(t, map[string]int{"a": 1, "b": 2}, mapVal)
445+
})
446+
}

agent/invocationcontext.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,29 @@ func InvocationFromContext(ctx context.Context) (*Invocation, bool) {
3232
return invocation, ok
3333
}
3434

35+
// GetStateValueFromContext retrieves a typed value from the invocation state
36+
// stored in the context.
37+
//
38+
// Returns the typed value and true if the invocation exists, the key exists,
39+
// and the type matches, or the zero value and false otherwise.
40+
//
41+
// Example:
42+
//
43+
// if startTime, ok := GetStateValueFromContext[time.Time](ctx, "agent:start_time"); ok {
44+
// duration := time.Since(startTime)
45+
// }
46+
// if requestID, ok := GetStateValueFromContext[string](ctx, "middleware:request_id"); ok {
47+
// log.Printf("Request ID: %s", requestID)
48+
// }
49+
func GetStateValueFromContext[T any](ctx context.Context, key string) (T, bool) {
50+
var zero T
51+
inv, ok := InvocationFromContext(ctx)
52+
if !ok {
53+
return zero, false
54+
}
55+
return GetStateValue[T](inv, key)
56+
}
57+
3558
// CheckContextCancelled check context cancelled
3659
func CheckContextCancelled(ctx context.Context) error {
3760
select {

0 commit comments

Comments
 (0)