Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions session/vertexai/vertexai_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,12 +443,14 @@ func aiplatformToGenaiContent(rpcResp *aiplatformpb.SessionEvent) *genai.Content
case *aiplatformpb.Part_FunctionCall:
argsMap := v.FunctionCall.Args.AsMap() // Converts *structpb.Struct -> map[string]any
part.FunctionCall = &genai.FunctionCall{
ID: v.FunctionCall.Id,
Name: v.FunctionCall.Name,
Args: argsMap,
}
case *aiplatformpb.Part_FunctionResponse:
responseMap := v.FunctionResponse.Response.AsMap() // Converts *structpb.Struct -> map[string]any
part.FunctionResponse = &genai.FunctionResponse{
ID: v.FunctionResponse.Id,
Name: v.FunctionResponse.Name,
Response: responseMap,
}
Expand Down
63 changes: 63 additions & 0 deletions session/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,71 @@ package vertexai

import (
"testing"

aiplatformpb "cloud.google.com/go/aiplatform/apiv1beta1/aiplatformpb"
"google.golang.org/protobuf/types/known/structpb"
)

func TestAiplatformToGenaiContent_PreservesFunctionCallAndResponseIDs(t *testing.T) {
callID := "test-call-id-123"
argsStruct, err := structpb.NewStruct(map[string]any{"param": "value"})
if err != nil {
t.Fatalf("failed to create args struct: %v", err)
}
respStruct, err := structpb.NewStruct(map[string]any{"result": "ok"})
if err != nil {
t.Fatalf("failed to create response struct: %v", err)
}

sessionEvent := &aiplatformpb.SessionEvent{
Content: &aiplatformpb.Content{
Role: "model",
Parts: []*aiplatformpb.Part{
{
Data: &aiplatformpb.Part_FunctionCall{
FunctionCall: &aiplatformpb.FunctionCall{
Id: callID,
Name: "my_tool",
Args: argsStruct,
},
},
},
},
},
}
gotCall := aiplatformToGenaiContent(sessionEvent)
if gotCall == nil || len(gotCall.Parts) == 0 || gotCall.Parts[0].FunctionCall == nil {
t.Fatal("expected FunctionCall part, got nil")
}
if got := gotCall.Parts[0].FunctionCall.ID; got != callID {
t.Errorf("FunctionCall.ID = %q, want %q", got, callID)
}

sessionEvent2 := &aiplatformpb.SessionEvent{
Content: &aiplatformpb.Content{
Role: "user",
Parts: []*aiplatformpb.Part{
{
Data: &aiplatformpb.Part_FunctionResponse{
FunctionResponse: &aiplatformpb.FunctionResponse{
Id: callID,
Name: "my_tool",
Response: respStruct,
},
},
},
},
},
}
gotResp := aiplatformToGenaiContent(sessionEvent2)
if gotResp == nil || len(gotResp.Parts) == 0 || gotResp.Parts[0].FunctionResponse == nil {
t.Fatal("expected FunctionResponse part, got nil")
}
if got := gotResp.Parts[0].FunctionResponse.ID; got != callID {
t.Errorf("FunctionResponse.ID = %q, want %q", got, callID)
}
}
Comment thread
nuthalapativarun marked this conversation as resolved.
Outdated

func TestGetReasoningEngineID(t *testing.T) {
tests := []struct {
name string
Expand Down