Skip to content

Commit 81fc988

Browse files
author
Gunpal Jain
committed
add: google embedding model and updated Google_Generative SDK to latest version
1 parent e4f39bc commit 81fc988

File tree

6 files changed

+95
-59
lines changed

6 files changed

+95
-59
lines changed

src/Directory.Packages.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
<PackageVersion Include="FluentAssertions" Version="8.0.1" />
1414
<PackageVersion Include="GitHubActionsTestLogger" Version="2.4.1" />
1515
<PackageVersion Include="Google.Cloud.AIPlatform.V1" Version="3.9.0" />
16-
<PackageVersion Include="Google_GenerativeAI" Version="1.0.2" />
16+
<PackageVersion Include="Google_GenerativeAI" Version="2.0.4" />
1717
<PackageVersion Include="GroqSharp" Version="1.1.2" />
1818
<PackageVersion Include="H.Generators.Extensions" Version="1.22.0" />
1919
<PackageVersion Include="H.Generators.Tests.Extensions" Version="1.22.0" />

src/Google/src/Extensions/GoogleGeminiExtensions.cs

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,63 @@
11
using System.Text.Json;
22
using System.Text.Json.Serialization;
33
using CSharpToJsonSchema;
4-
using GenerativeAI.Tools;
4+
using GenerativeAI;
55
using GenerativeAI.Types;
6+
using Tool = CSharpToJsonSchema.Tool;
67

78
namespace LangChain.Providers.Google.Extensions;
89

910
internal static class GoogleGeminiExtensions
1011
{
11-
public static bool IsFunctionCall(this EnhancedGenerateContentResponse response)
12+
public static bool IsFunctionCall(this GenerateContentResponse response)
1213
{
1314
return response.GetFunction() != null;
1415
}
1516

16-
public static List<GenerativeAITool> ToGenerativeAiTools(this IEnumerable<Tool> functions)
17+
public static List<GenerativeAI.Types.Tool?> ToGenerativeAiTools(this IEnumerable<Tool> functions)
1718
{
18-
return new List<GenerativeAITool>([
19-
new GenerativeAITool
19+
var declarations = functions
20+
.Where(x => x != null)
21+
.Select(x => new FunctionDeclaration
2022
{
21-
FunctionDeclaration = functions.Select(x => new ChatCompletionFunction
23+
Name = x.Name ?? string.Empty,
24+
Description = x.Description ?? string.Empty,
25+
Parameters = x.Parameters is OpenApiSchema schema ? ToFunctionParameters(schema) : null,
26+
})
27+
.ToList();
28+
29+
if (declarations.Any())
30+
{
31+
return new List<GenerativeAI.Types.Tool?>
32+
{
33+
new GenerativeAI.Types.Tool
2234
{
23-
Name = x.Name ?? string.Empty,
24-
Description = x.Description ?? string.Empty,
25-
Parameters = ToFunctionParameters((OpenApiSchema)x.Parameters!),
26-
}).ToList(),
27-
}
28-
]);
35+
FunctionDeclarations = declarations
36+
}
37+
};
38+
}
39+
40+
return null;
41+
}
42+
43+
public static string GetStringForFunctionArgs(this object? arguments)
44+
{
45+
if (arguments == null)
46+
return string.Empty;
47+
if (arguments is JsonElement jsonElement)
48+
return jsonElement.ToString();
49+
else
50+
{
51+
return null;
52+
}
2953
}
30-
public static ChatCompletionFunctionParameters ToFunctionParameters(this OpenApiSchema schema)
54+
55+
public static Schema? ToFunctionParameters(this OpenApiSchema openApiSchema)
3156
{
32-
if (schema.Items == null) return new ChatCompletionFunctionParameters();
33-
var parameters = new ChatCompletionFunctionParameters();
34-
35-
parameters.AdditionalProperties.Add("type", schema.Items.Type);
36-
if (schema.Items.Description != null && !string.IsNullOrEmpty(schema.Items.Description))
37-
parameters.AdditionalProperties.Add("description", schema.Items.Description);
38-
if (schema.Items.Properties != null)
39-
parameters.AdditionalProperties.Add("properties", schema.Items.Properties);
40-
if (schema.Items.Required != null)
41-
parameters.AdditionalProperties.Add("required", schema.Items.Required);
42-
43-
return parameters;
57+
var text = JsonSerializer.Serialize(openApiSchema);
58+
return JsonSerializer.Deserialize<Schema?>(text);
4459
}
60+
4561
public static string GetString(this IDictionary<string, object>? arguments)
4662
{
4763
if (arguments == null)

src/Google/src/Extensions/StringExtensions.cs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using System.Text.Json;
22
using System.Text.Json.Nodes;
3-
using GenerativeAI.Extensions;
4-
using GenerativeAI.Tools;
3+
using GenerativeAI;
54
using GenerativeAI.Types;
65

76
namespace LangChain.Providers.Google.Extensions;
@@ -56,10 +55,9 @@ public static Content AsFunctionCallContent(this string args, string functionNam
5655
var content = new Content([
5756
new Part
5857
{
59-
FunctionCall = new ChatFunctionCall
58+
FunctionCall = new FunctionCall()
6059
{
61-
Arguments = JsonSerializer.Deserialize(args, SourceGenerationContext.Default.DictionaryStringString)?
62-
.ToDictionary(x => x.Key, x => (object)x.Value) ?? [],
60+
Args = JsonNode.Parse(args),
6361
Name = functionName
6462
}
6563
}
@@ -76,6 +74,20 @@ public static Content AsFunctionCallContent(this string args, string functionNam
7674
[CLSCompliant(false)]
7775
public static Content AsFunctionResultContent(this string args, string functionName)
7876
{
79-
return JsonNode.Parse(args).ToFunctionCallContent(functionName);
77+
var functionResponse = new FunctionResponse()
78+
{
79+
Response = new
80+
{
81+
Name = functionName,
82+
Content = JsonNode.Parse(args)
83+
},
84+
Name = functionName
85+
};
86+
var content = new Content(){Role = Roles.Function};
87+
content.AddPart(new Part()
88+
{
89+
FunctionResponse = functionResponse
90+
});
91+
return content;
8092
}
8193
}

src/Google/src/GoogleChatModel.Tokens.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public async Task<int> CountTokens(string text)
1111

1212
public async Task<int> CountTokens(IEnumerable<Message> messages)
1313
{
14-
var response = await this.Api.CountTokens(new CountTokensRequest() { Contents = messages.Select(ToRequestMessage).ToArray() }).ConfigureAwait(false);
14+
var response = await this.Api.CountTokensAsync(new CountTokensRequest() { Contents = messages.Select(ToRequestMessage).ToList() }).ConfigureAwait(false);
1515

1616
return response.TotalTokens;
1717
}

src/Google/src/GoogleChatModel.cs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
using System.Diagnostics;
22
using System.Runtime.CompilerServices;
3-
using GenerativeAI.Models;
3+
using System.Text;
4+
using System.Text.Json;
5+
using System.Text.Json.Nodes;
6+
using GenerativeAI;
7+
using GenerativeAI.Core;
48
using GenerativeAI.Types;
59
using LangChain.Providers.Google.Extensions;
610

@@ -27,11 +31,14 @@ public partial class GoogleChatModel(
2731
private GenerativeModel Api { get; } = new(
2832
provider.ApiKey,
2933
id,
30-
provider.HttpClient)
34+
httpClient:provider.HttpClient)
3135
{
32-
AutoCallFunction = false,
33-
AutoReplyFunction = false,
34-
AutoHandleBadFunctionCalls = false
36+
FunctionCallingBehaviour = new FunctionCallingBehaviour()
37+
{
38+
AutoCallFunction = false,
39+
AutoReplyFunction = false,
40+
AutoHandleBadFunctionCalls = false
41+
}
3542
};
3643

3744
#endregion
@@ -52,13 +59,13 @@ private static Content ToRequestMessage(Message message)
5259
};
5360
}
5461

55-
private static Message ToMessage(EnhancedGenerateContentResponse message)
62+
private static Message ToMessage(GenerateContentResponse message)
5663
{
5764
if (message.GetFunction() != null)
5865
{
5966
var function = message.GetFunction();
6067

61-
return new Message(function?.Arguments.GetString() ?? string.Empty,
68+
return new Message( function?.Args.GetStringForFunctionArgs() ?? string.Empty,
6269
MessageRole.ToolCall, function?.Name);
6370
}
6471

@@ -67,13 +74,13 @@ private static Message ToMessage(EnhancedGenerateContentResponse message)
6774
MessageRole.Ai);
6875
}
6976

70-
private async Task<EnhancedGenerateContentResponse> CreateChatCompletionAsync(
77+
private async Task<GenerateContentResponse> CreateChatCompletionAsync(
7178
IReadOnlyCollection<Message> messages,
7279
CancellationToken cancellationToken = default)
7380
{
7481
var request = new GenerateContentRequest
7582
{
76-
Contents = messages.Select(ToRequestMessage).ToArray(),
83+
Contents = messages.Select(ToRequestMessage).ToList(),
7784
Tools = GlobalTools.ToGenerativeAiTools()
7885
};
7986

@@ -94,7 +101,7 @@ private async Task<Message> StreamCompletionAsync(IReadOnlyCollection<Message> m
94101
{
95102
var request = new GenerateContentRequest
96103
{
97-
Contents = messages.Select(ToRequestMessage).ToArray()
104+
Contents = messages.Select(ToRequestMessage).ToList()
98105
};
99106
if (provider.Configuration != null)
100107
request.GenerationConfig = new GenerationConfig
@@ -104,11 +111,18 @@ private async Task<Message> StreamCompletionAsync(IReadOnlyCollection<Message> m
104111
TopP = provider.Configuration.TopP,
105112
Temperature = provider.Configuration.Temperature
106113
};
107-
var res = await Api.StreamContentAsync(request, OnDeltaReceived, cancellationToken)
108-
.ConfigureAwait(false);
114+
StringBuilder sb = new StringBuilder();
115+
await foreach (var response in Api.StreamContentAsync(request, cancellationToken))
116+
{
117+
var text = response.Text() ?? string.Empty;
118+
119+
sb.Append(text);
120+
OnDeltaReceived(text);
121+
}
122+
109123

110124
return new Message(
111-
res,
125+
sb.ToString(),
112126
MessageRole.Ai);
113127
}
114128

@@ -173,7 +187,7 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(
173187

174188
if (Calls.TryGetValue(name, out var func))
175189
{
176-
var args = function?.Arguments.GetString() ?? string.Empty;
190+
var args = function?.Args.GetStringForFunctionArgs() ?? string.Empty;
177191

178192
var jsonResult = await func(args, cancellationToken).ConfigureAwait(false);
179193
messages.Add(jsonResult.AsToolResultMessage(name));
@@ -216,7 +230,7 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(
216230

217231
yield return chatResponse;
218232
}
219-
private Usage GetUsage(EnhancedGenerateContentResponse response)
233+
private Usage GetUsage(GenerateContentResponse response)
220234
{
221235
var outputTokens = response.UsageMetadata?.CandidatesTokenCount ?? 0;
222236
var inputTokens = response.UsageMetadata?.PromptTokenCount ?? 0;

src/Google/src/Predefined/GeminiModels.cs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
1-
using GenerativeAI.Models;
1+
using GenerativeAI;
22

33
namespace LangChain.Providers.Google.Predefined;
44

55
/// <inheritdoc cref="GoogleAIModels.GeminiPro" />
6-
public class GeminiProModel(GoogleProvider provider)
6+
public class Gemini15ProLatest(GoogleProvider provider)
77
: GoogleChatModel(
88
provider,
9-
GoogleAIModels.GeminiPro, 32 * 1024, 0.5 * 0.000001, 1.5 * 0.000001);
9+
GoogleAIModels.Gemini15ProLatest, 32 * 1024, 0.5 * 0.000001, 1.5 * 0.000001);
1010

11-
/// <inheritdoc cref="GoogleAIModels.GeminiProVision" />
12-
public class GeminiProVisionModel(GoogleProvider provider)
13-
: GoogleChatModel(
14-
provider,
15-
GoogleAIModels.GeminiProVision, 32 * 1024, 0.5 * 0.000001, 1.5 * 0.000001);
16-
17-
/// <inheritdoc cref="GoogleAIModels.GeminiProVision" />
11+
/// <inheritdoc cref="GoogleAIModels.Gemini15Flash" />
1812
public class Gemini15FlashModel(GoogleProvider provider)
1913
: GoogleChatModel(
2014
provider,
2115
GoogleAIModels.Gemini15Flash, 1024 * 1024, 0.35 * 0.000001, 1.05 * 0.000001, 0.70 * 0.000001, 2.1 * 0.000001);
2216

23-
/// <inheritdoc cref="GoogleAIModels.GeminiProVision" />
17+
/// <inheritdoc cref="GoogleAIModels.Gemini15Pro" />
2418
public class Gemini15ProModel(GoogleProvider provider)
2519
: GoogleChatModel(
2620
provider,

0 commit comments

Comments
 (0)