diff --git a/README.md b/README.md index 884a299..1a2c893 100644 --- a/README.md +++ b/README.md @@ -6,12 +6,14 @@ [![Discord](https://img.shields.io/discord/1115206893015662663?label=Discord&logo=discord&logoColor=white&color=d82679)](https://discord.gg/Ca2xhfBf3v) ## Features 🔥 -- Source generator to define functions natively through C# interfaces +- Source generator to define functions natively through C# interfaces and individual methods - Doesn't use Reflection - All modern .NET features - nullability, trimming, NativeAOT, etc. - Tested for compatibility with OpenAI/Ollama/Anthropic/LangChain/Gemini ## Usage + +### Interface ```csharp using CSharpToJsonSchema; @@ -56,6 +58,32 @@ public class WeatherService : IWeatherFunctions var tools = service.AsTools(); ``` +### Methods + +```csharp + +[FunctionTool] +public Task GetCurrentWeatherAsync(string location, Unit unit = Unit.Celsius, CancellationToken cancellationToken = default) +{ + return Task.FromResult(new Weather + { + Location = location, + Temperature = 22.0, + Unit = unit, + Description = "Sunny", + }); +} + +var tools = new Tools([GetCurrentWeatherAsync]) + +//Access list of CSharpToJsonSchema.Tool +var myTools = tools.AvailableTools + +//Implicit Conversion to list of CSharpToJsonSchema.Tool +List myTools = tools +``` + + ## Support Priority place for bugs: https://github.com/tryAGI/CSharpToJsonSchema/issues diff --git a/src/libs/CSharpToJsonSchema.Generators/Conversion/SymbolGenerator.cs b/src/libs/CSharpToJsonSchema.Generators/Conversion/SymbolGenerator.cs index 912eb22..50672ad 100644 --- a/src/libs/CSharpToJsonSchema.Generators/Conversion/SymbolGenerator.cs +++ b/src/libs/CSharpToJsonSchema.Generators/Conversion/SymbolGenerator.cs @@ -96,6 +96,59 @@ public static class SymbolGenerator var classSymbol = semanticModel.GetDeclaredSymbol(classNode) as ITypeSymbol; return classSymbol; } + + public static INamedTypeSymbol? GenerateToolJsonSerializerContext( + string rootNamespace, + Compilation originalCompilation) + { + + + // Example: we create a class name + var className = $"ToolsJsonSerializerContext"; + + + // Build a class declaration + var classDecl = SyntaxFactory.ClassDeclaration(className) + .AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword)) + .AddModifiers(SyntaxFactory.Token(SyntaxKind.PartialKeyword)); + + // We create a compilation unit holding our new class + + + + var namespaceName =rootNamespace; // choose your own + var ns = SyntaxFactory.NamespaceDeclaration(SyntaxFactory.IdentifierName(namespaceName)) + .AddMembers(classDecl); + + var compilationUnit = SyntaxFactory.CompilationUnit() + .AddMembers(ns) // if ns is a NamespaceDeclarationSyntax + .NormalizeWhitespace(); + + var parseOptions = CSharpParseOptions.Default.WithLanguageVersion(originalCompilation.GetLanguageVersion()?? LanguageVersion.Default); + var syntaxTree =CSharpSyntaxTree.Create(compilationUnit,parseOptions); + //CSharpSyntaxTree.Create(ns.NormalizeWhitespace()); + + // Now we need to add this syntax tree to a new or existing compilation + var assemblyName = "TemporaryAssembly"; + var compilation = originalCompilation + .AddSyntaxTrees(syntaxTree); + //.WithAssemblyName(assemblyName); + + + // Get the semantic model for our newly added syntax tree + var semanticModel = compilation.GetSemanticModel(syntaxTree); + + // Find the class syntax node in the syntax tree + var classNode = syntaxTree.GetRoot().DescendantNodes() + .OfType() + .FirstOrDefault(); + + if (classNode == null) return null; + + // Retrieve the ITypeSymbol from the semantic model + var classSymbol = semanticModel.GetDeclaredSymbol(classNode); + return classSymbol; + } public static AttributeSyntax GetConverter(string propertyType) { diff --git a/src/libs/CSharpToJsonSchema.Generators/Conversion/ToModels.cs b/src/libs/CSharpToJsonSchema.Generators/Conversion/ToModels.cs index 15f951f..f310202 100644 --- a/src/libs/CSharpToJsonSchema.Generators/Conversion/ToModels.cs +++ b/src/libs/CSharpToJsonSchema.Generators/Conversion/ToModels.cs @@ -16,6 +16,8 @@ public static InterfaceData PrepareData( var isStrict = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "Strict").Value.Value is bool strict && strict; + var generateGoogleFunctionTool = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "GoogleFunctionTool").Value.Value is bool googleFunctionTool && + googleFunctionTool; var methods = interfaceSymbol .GetMembers() .OfType() @@ -43,42 +45,93 @@ public static InterfaceData PrepareData( return new InterfaceData( Namespace: interfaceSymbol.ContainingNamespace.ToDisplayString(), Name: interfaceSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat), + GoogleFunctionTool:generateGoogleFunctionTool, Methods: methods); } public static InterfaceData PrepareMethodData( - this IMethodSymbol interfaceSymbol, - AttributeData attributeData) + List<(IMethodSymbol, AttributeData)> list) { - interfaceSymbol = interfaceSymbol ?? throw new ArgumentNullException(nameof(interfaceSymbol)); - attributeData = attributeData ?? throw new ArgumentNullException(nameof(attributeData)); + //interfaceSymbol = interfaceSymbol ?? throw new ArgumentNullException(nameof(interfaceSymbol)); + list = list ?? throw new ArgumentNullException(nameof(list)); + + var namespaceName = "CSharpToJsonSchema"; + var className = "Tools"; + List methodList = new(); + List namespaces = new(); + bool generateGoogleFunctionTools = false; + foreach (var l in list) + { + var (interfaceSymbol, attributeData) = l; + var isStrict = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "Strict").Value.Value is bool strict && + strict; + var ggft = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "GoogleFunctionTool").Value.Value is bool googleFunctionTool && + googleFunctionTool; + if(ggft) + generateGoogleFunctionTools = true; + + var x = interfaceSymbol; + var parameters = x.Parameters + //.Where(static x => x.Type.MetadataName != "CancellationToken") + .ToArray(); + + var methodData = new MethodData( + Name: x.Name, + Description: GetDescription(x), + IsAsync: x.IsAsync || x.ReturnType.Name == "Task", + IsVoid: x.ReturnsVoid || x.ReturnType.MetadataName == "Task", + IsStrict: isStrict, + Parameters: parameters.Select(static y => y).ToArray(), + Descriptions: parameters.Select(static l => GetParameterDescriptions(l)).SelectMany(s => s) + .ToDictionary(s => s.Key, s => s.Value), + ReturnType:x.ReturnType + ); + methodList.Add(methodData); + namespaces.Add(interfaceSymbol.ContainingNamespace.ToDisplayString()); + } - var isStrict = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "Strict").Value.Value is bool strict && - strict; - var x = interfaceSymbol; - var parameters = x.Parameters - //.Where(static x => x.Type.MetadataName != "CancellationToken") - .ToArray(); + return new InterfaceData( + Namespace: GetCommonRootNamespace(namespaces)??namespaceName, + Name: className, + GoogleFunctionTool: generateGoogleFunctionTools, + Methods: methodList.ToArray()); + } + public static string? GetCommonRootNamespace(IEnumerable namespaces) + { + // Convert the list of namespaces to a list of arrays split by "." + var splitNamespaces = namespaces + .Select(ns => ns.Split('.')) + .ToList(); + + if (!splitNamespaces.Any() || !splitNamespaces[0].Any()) + { + return null; + } - var methodData = new MethodData( - Name: x.Name, - Description: GetDescription(x), - IsAsync: x.IsAsync || x.ReturnType.Name == "Task", - IsVoid: x.ReturnsVoid || x.ReturnType.MetadataName == "Task", - IsStrict: isStrict, - Parameters: parameters.Select(static y => y).ToArray(), - Descriptions: parameters.Select(static l => GetParameterDescriptions(l)).SelectMany(s => s) - .ToDictionary(s => s.Key, s => s.Value), - ReturnType:x.ReturnType - ); + // Start with the first namespace parts + var firstNsParts = splitNamespaces[0]; + var commonParts = new List(); + // For each index in the first namespace + for (int i = 0; i < firstNsParts.Length; i++) + { + // Check if every other namespace has the same part at this index + string currentPart = firstNsParts[i]; + if (splitNamespaces.All(nsArr => nsArr.Length > i && nsArr[i] == currentPart)) + { + commonParts.Add(currentPart); + } + else + { + // Stop the moment there is a mismatch + break; + } + } - return new InterfaceData( - Namespace: interfaceSymbol.ContainingNamespace.ToDisplayString(), - Name: "I"+interfaceSymbol.Name, - Methods: [methodData]); + return string.Join(".", commonParts); } + // private static Dictionary GetIsRequired(IParameterSymbol[] parameters, Dictionary? dics = null) // { // dics ??= new Dictionary(); diff --git a/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Parser.cs b/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Parser.cs index afdbb84..ef1e907 100644 --- a/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Parser.cs +++ b/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Parser.cs @@ -99,7 +99,184 @@ public Parser(KnownTypeSymbols knownSymbols) _builtInSupportTypes = (knownSymbols.BuiltInSupportTypes ??= CreateBuiltInSupportTypeSet(knownSymbols)); } - public ContextGenerationSpec? ParseContextGenerationSpec(InterfaceDeclarationSyntax contextClassDeclaration, + public ContextGenerationSpec? ParseContextGenerationSpec(ImmutableArray<((MethodDeclarationSyntax ContextClass, SemanticModel SemanticModel) Left, KnownTypeSymbols Right)> values, CancellationToken cancellationToken) + { + if (!_compilationContainsCoreJsonTypes) + { + return null; + } + + Debug.Assert(_knownSymbols.JsonSerializerContextType != null); + + // Ensure context-scoped metadata caches are empty. + Debug.Assert(_typesToGenerate.Count == 0); + Debug.Assert(_generatedTypes.Count == 0); + Debug.Assert(_contextClassLocation is null); + + var nsList = new List(); + foreach (var l in values) + { + var (contextClassDeclaration, semanticModel) = l.Left; + nsList.Add(semanticModel.GetDeclaredSymbol(contextClassDeclaration, cancellationToken).ContainingNamespace.ToDisplayString()); + } + + var nameSpace = ToModels.GetCommonRootNamespace(nsList); + INamedTypeSymbol contextTypeSymbol_ = null; + var rootTypesToGenerate_ = new List(); + SourceGenerationOptionsSpec? options = null; + INamedTypeSymbol? jsonSerializerContext = null; + foreach (var l in values) + { + var (contextClassDeclaration, semanticModel) = l.Left; + IMethodSymbol? contextTypeSymbol = + semanticModel.GetDeclaredSymbol(contextClassDeclaration, cancellationToken); + Debug.Assert(contextTypeSymbol != null); + + contextTypeSymbol_ = contextTypeSymbol.ContainingType;//.ContainingType; + _contextClassLocation = contextClassDeclaration.GetLocation(); + Debug.Assert(_contextClassLocation is not null); + + // if (!_knownSymbols.JsonSerializerContextType.IsAssignableFrom(contextTypeSymbol)) + // { + // ReportDiagnostic(DiagnosticDescriptors.JsonSerializableAttributeOnNonContextType, _contextClassLocation, contextTypeSymbol.ToDisplayString()); + // return null; + // } + + ParseJsonSerializerContextAttributes(contextTypeSymbol, + semanticModel.Compilation, + nameSpace, + out List? rootSerializableTypes, + out options); + + (rootTypesToGenerate_??=new List()).AddRange(rootSerializableTypes ?? Enumerable.Empty()); + if (rootSerializableTypes is null) + { + // No types were annotated with JsonSerializableAttribute. + // Can only be reached if a [JsonSerializable(null)] declaration has been made. + // Do not emit a diagnostic since a NRT warning will also be emitted. + return null; + } + + Debug.Assert(rootSerializableTypes.Count > 0); + + LanguageVersion? langVersion = _knownSymbols.Compilation.GetLanguageVersion(); + if (langVersion is null or < MinimumSupportedLanguageVersion) + { + // Unsupported lang version should be the first (and only) diagnostic emitted by the generator. + ReportDiagnostic( + global::CSharpToJsonSchema.Generators.JsonGen.JsonSourceGenerator.DiagnosticDescriptors + .JsonUnsupportedLanguageVersion, _contextClassLocation, langVersion?.ToDisplayString(), + MinimumSupportedLanguageVersion.ToDisplayString()); + return null; + } + + if (jsonSerializerContext == null) + { + jsonSerializerContext = SymbolGenerator.GenerateToolJsonSerializerContext( + nameSpace, semanticModel.Compilation); + contextTypeSymbol_ = jsonSerializerContext; + } + + } + + if (!TryGetNestedTypeDeclarations(nameSpace, cancellationToken, + out List? classDeclarationList)) + { + // Class or one of its containing types is not partial so we can't add to it. + // ReportDiagnostic( + // global::CSharpToJsonSchema.Generators.JsonGen.JsonSourceGenerator.DiagnosticDescriptors + // .ContextClassesMustBePartial, _contextClassLocation, contextTypeSymbol.Name); + return null; + } + + // Enqueue attribute data for spec generation + foreach (TypeToGenerate rootSerializableType in rootTypesToGenerate_) + { + _typesToGenerate.Enqueue(rootSerializableType); + } + // Walk the transitive type graph generating specs for every encountered type. + while (_typesToGenerate.Count > 0) + { + cancellationToken.ThrowIfCancellationRequested(); + TypeToGenerate typeToGenerate = _typesToGenerate.Dequeue(); + if (!_generatedTypes.ContainsKey(typeToGenerate.Type)) + { + TypeGenerationSpec spec = ParseTypeGenerationSpec(typeToGenerate, contextTypeSymbol_, options); + _generatedTypes.Add(typeToGenerate.Type, spec); + } + } + + + + Debug.Assert(_generatedTypes.Count > 0); + + ContextGenerationSpec contextGenSpec = new() + { + ContextType = new(jsonSerializerContext), + GeneratedTypes = _generatedTypes.Values.OrderBy(t => t.TypeRef.FullyQualifiedName) + .ToImmutableEquatableArray(), + Namespace = nameSpace, + ContextClassDeclarations = classDeclarationList.ToImmutableEquatableArray(), + GeneratedOptionsSpec = options, + }; + + // Clear the caches of generated metadata between the processing of context classes. + _generatedTypes.Clear(); + _typesToGenerate.Clear(); + _contextClassLocation = null; + return contextGenSpec; + } + private static bool TryGetNestedTypeDeclarations(string nameSpace, CancellationToken cancellationToken, + [NotNullWhen(true)] out List? typeDeclarations) + { + typeDeclarations = new List(); + typeDeclarations.Add("public partial class ToolsJsonSerializerContext"); + // typeDeclarations = null; + // + // for (TypeDeclarationSyntax? currentType = contextClassSyntax; + // currentType != null; + // currentType = currentType.Parent as TypeDeclarationSyntax) + // { + // StringBuilder stringBuilder = new(); + // bool isPartialType = false; + // bool isInterface = currentType is InterfaceDeclarationSyntax; + // foreach (SyntaxToken modifier in currentType.Modifiers) + // { + // stringBuilder.Append(modifier.Text); + // stringBuilder.Append(' '); + // isPartialType |= modifier.IsKind(SyntaxKind.PartialKeyword); + // } + // + // if (!isPartialType && !isInterface) + // { + // typeDeclarations = null; + // return false; + // } + // + // var kind = currentType.GetTypeKindKeyword(); + // if (isInterface) + // kind = "partial class"; + // + // stringBuilder.Append(kind); + // stringBuilder.Append(' '); + // + // INamedTypeSymbol? typeSymbol = semanticModel.GetDeclaredSymbol(currentType, cancellationToken); + // Debug.Assert(typeSymbol != null); + // + // string typeName = (typeSymbol.TypeKind == TypeKind.Interface) + // ? typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat) + + // "ExtensionsJsonSerializerContext" + // : typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + // stringBuilder.Append(typeName); + // + // (typeDeclarations ??= new()).Add(stringBuilder.ToString()); + // } + // + // Debug.Assert(typeDeclarations?.Count > 0); + return true; + } + + public ContextGenerationSpec? ParseContextGenerationSpec(InterfaceDeclarationSyntax contextClassDeclaration, SemanticModel semanticModel, CancellationToken cancellationToken) { if (!_compilationContainsCoreJsonTypes) @@ -130,7 +307,7 @@ public Parser(KnownTypeSymbols knownSymbols) ParseJsonSerializerContextAttributes(contextTypeSymbol, semanticModel.Compilation, out List? rootSerializableTypes, - out SourceGenerationOptionsSpec? options ); + out SourceGenerationOptionsSpec? options); if (rootSerializableTypes is null) { @@ -201,6 +378,7 @@ public Parser(KnownTypeSymbols knownSymbols) _contextClassLocation = null; return contextGenSpec; } + public ContextGenerationSpec? ParseContextGenerationSpec(ClassDeclarationSyntax contextClassDeclaration, SemanticModel semanticModel, CancellationToken cancellationToken) { @@ -302,6 +480,7 @@ public Parser(KnownTypeSymbols knownSymbols) _contextClassLocation = null; return contextGenSpec; } + private static bool TryGetNestedTypeDeclarations(InterfaceDeclarationSyntax contextClassSyntax, SemanticModel semanticModel, CancellationToken cancellationToken, [NotNullWhen(true)] out List? typeDeclarations) @@ -329,7 +508,7 @@ private static bool TryGetNestedTypeDeclarations(InterfaceDeclarationSyntax cont } var kind = currentType.GetTypeKindKeyword(); - if(isInterface) + if (isInterface) kind = "partial class"; stringBuilder.Append(kind); @@ -338,7 +517,10 @@ private static bool TryGetNestedTypeDeclarations(InterfaceDeclarationSyntax cont INamedTypeSymbol? typeSymbol = semanticModel.GetDeclaredSymbol(currentType, cancellationToken); Debug.Assert(typeSymbol != null); - string typeName = (typeSymbol.TypeKind == TypeKind.Interface)? typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat).Substring(1)+"ExtensionsJsonSerializerContext":typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + string typeName = (typeSymbol.TypeKind == TypeKind.Interface) + ? typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat).Substring(1) + + "ExtensionsJsonSerializerContext" + : typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); stringBuilder.Append(typeName); (typeDeclarations ??= new()).Add(stringBuilder.ToString()); @@ -421,7 +603,8 @@ private static bool TryGetNestedTypeDeclarations(ClassDeclarationSyntax contextC INamedTypeSymbol? typeSymbol = semanticModel.GetDeclaredSymbol(currentType, cancellationToken); Debug.Assert(typeSymbol != null); - string typeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)+"JsonSerializerContext"; + string typeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat) + + "JsonSerializerContext"; stringBuilder.Append(typeName); (typeDeclarations ??= new()).Add(stringBuilder.ToString()); @@ -452,47 +635,47 @@ private TypeRef EnqueueType(ITypeSymbol type, JsonSourceGenerationMode? generati return new TypeRef(type); } + private void ParseJsonSerializerContextAttributes( - - INamedTypeSymbol contextClassSymbol, - out List? rootSerializableTypes, - out SourceGenerationOptionsSpec? options) - { - Debug.Assert(_knownSymbols.JsonSerializableAttributeType != null); - Debug.Assert(_knownSymbols.JsonSourceGenerationOptionsAttributeType != null); - - rootSerializableTypes = null; - options = null; - - foreach (AttributeData attributeData in contextClassSymbol.GetAttributes()) - { - INamedTypeSymbol? attributeClass = attributeData.AttributeClass; - - if (SymbolEqualityComparer.Default.Equals(attributeClass, _knownSymbols.JsonSerializableAttributeType)) - { - TypeToGenerate? typeToGenerate = ParseJsonSerializableAttribute(attributeData); - if (typeToGenerate is null) - { - continue; - } - - (rootSerializableTypes ??= new()).Add(typeToGenerate.Value); - } - else if (SymbolEqualityComparer.Default.Equals(attributeClass, _knownSymbols.JsonSourceGenerationOptionsAttributeType)) - { - options = ParseJsonSourceGenerationOptionsAttribute(contextClassSymbol, attributeData); - } - } - - if (contextClassSymbol.TypeKind == TypeKind.Interface) - { - - // - } - } - + INamedTypeSymbol contextClassSymbol, + out List? rootSerializableTypes, + out SourceGenerationOptionsSpec? options) + { + Debug.Assert(_knownSymbols.JsonSerializableAttributeType != null); + Debug.Assert(_knownSymbols.JsonSourceGenerationOptionsAttributeType != null); + + rootSerializableTypes = null; + options = null; + + foreach (AttributeData attributeData in contextClassSymbol.GetAttributes()) + { + INamedTypeSymbol? attributeClass = attributeData.AttributeClass; + + if (SymbolEqualityComparer.Default.Equals(attributeClass, + _knownSymbols.JsonSerializableAttributeType)) + { + TypeToGenerate? typeToGenerate = ParseJsonSerializableAttribute(attributeData); + if (typeToGenerate is null) + { + continue; + } + + (rootSerializableTypes ??= new()).Add(typeToGenerate.Value); + } + else if (SymbolEqualityComparer.Default.Equals(attributeClass, + _knownSymbols.JsonSourceGenerationOptionsAttributeType)) + { + options = ParseJsonSourceGenerationOptionsAttribute(contextClassSymbol, attributeData); + } + } + + if (contextClassSymbol.TypeKind == TypeKind.Interface) + { + // + } + } + private void ParseJsonSerializerContextAttributes( - INamedTypeSymbol contextClassSymbol, Compilation compilation, out List? rootSerializableTypes, @@ -508,7 +691,8 @@ private void ParseJsonSerializerContextAttributes( { INamedTypeSymbol? attributeClass = attributeData.AttributeClass; - if (SymbolEqualityComparer.Default.Equals(attributeClass, _knownSymbols.JsonSerializableAttributeType)) + if (SymbolEqualityComparer.Default.Equals(attributeClass, + _knownSymbols.JsonSerializableAttributeType)) { TypeToGenerate? typeToGenerate = ParseJsonSerializableAttribute(attributeData); if (typeToGenerate is null) @@ -518,7 +702,8 @@ private void ParseJsonSerializerContextAttributes( (rootSerializableTypes ??= new()).Add(typeToGenerate.Value); } - else if (SymbolEqualityComparer.Default.Equals(attributeClass, _knownSymbols.JsonSourceGenerationOptionsAttributeType)) + else if (SymbolEqualityComparer.Default.Equals(attributeClass, + _knownSymbols.JsonSourceGenerationOptionsAttributeType)) { options = ParseJsonSourceGenerationOptionsAttribute(contextClassSymbol, attributeData); } @@ -526,12 +711,11 @@ private void ParseJsonSerializerContextAttributes( if (contextClassSymbol.TypeKind == TypeKind.Interface) { - foreach (var member in contextClassSymbol.GetMembers().OfType()) { var type = SymbolGenerator.GenerateParameterBasedClassSymbol( contextClassSymbol.ContainingNamespace.ToDisplayString(), member, compilation); - + var typeToGenerate = new TypeToGenerate { Type = type, @@ -541,8 +725,9 @@ private void ParseJsonSerializerContextAttributes( AttributeLocation = null }; (rootSerializableTypes ??= new()).Add(typeToGenerate); - - if(member.ReturnsVoid || member.ReturnType.MetadataName == "Task")//, || member.ReturnsByRefReadonly) + + if (member.ReturnsVoid || + member.ReturnType.MetadataName == "Task") //, || member.ReturnsByRefReadonly) continue; if (member.ReturnType.BaseType?.Name == "Task") { @@ -558,7 +743,96 @@ private void ParseJsonSerializerContextAttributes( (rootSerializableTypes ??= new()).Add(typeToGenerate3); continue; } - + + var typeToGenerate2 = new TypeToGenerate + { + Type = member.ReturnType, + Mode = null, + TypeInfoPropertyName = $"{member.ReturnType.Name}", + Location = member.GetLocation(), + AttributeLocation = null + }; + (rootSerializableTypes ??= new()).Add(typeToGenerate2); + } + } + + options = new SourceGenerationOptionsSpec + { + GenerationMode = JsonSourceGenerationMode.Default, + Defaults = JsonSerializerDefaults.Web, + AllowOutOfOrderMetadataProperties = true, + AllowTrailingCommas = false, + DefaultBufferSize = 1024, + Converters = null, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + DictionaryKeyPolicy = null, + RespectNullableAnnotations = true, + RespectRequiredConstructorParameters = false, + IgnoreReadOnlyFields = false, + IgnoreReadOnlyProperties = false, + IncludeFields = false, + MaxDepth = 64, + NewLine = "\n", + NumberHandling = JsonNumberHandling.Strict, + PreferredObjectCreationHandling = JsonObjectCreationHandling.Replace, + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + ReadCommentHandling = JsonCommentHandling.Disallow, + ReferenceHandler = null, + UnknownTypeHandling = JsonUnknownTypeHandling.JsonElement, + UnmappedMemberHandling = JsonUnmappedMemberHandling.Skip, + UseStringEnumConverter = true, + WriteIndented = false, + IndentCharacter = ' ', + IndentSize = 2, + }; + } + + private void ParseJsonSerializerContextAttributes( + IMethodSymbol contextClassSymbol, + Compilation compilation, + string nameSpace, + out List? rootSerializableTypes, + out SourceGenerationOptionsSpec? options) + { + Debug.Assert(_knownSymbols.JsonSerializableAttributeType != null); + Debug.Assert(_knownSymbols.JsonSourceGenerationOptionsAttributeType != null); + + rootSerializableTypes = null; + options = null; + var member = contextClassSymbol; + var type = SymbolGenerator.GenerateParameterBasedClassSymbol( + nameSpace, member, compilation); + + var typeToGenerate = new TypeToGenerate + { + Type = type, + Mode = null, + TypeInfoPropertyName = $"{member.Name}Args", + Location = type.GetLocation(), + AttributeLocation = null + }; + (rootSerializableTypes ??= new()).Add(typeToGenerate); + + + if (!member.ReturnsVoid && member.ReturnType.MetadataName != "Task") + { + if (member.ReturnType.BaseType?.Name == "Task") + { + var types = (member.ReturnType as INamedTypeSymbol).TypeArguments[0]; + var typeToGenerate3 = new TypeToGenerate + { + Type = types, + Mode = null, + TypeInfoPropertyName = $"{types.Name}", + Location = types.GetLocation(), + AttributeLocation = null + }; + (rootSerializableTypes ??= new()).Add(typeToGenerate3); + + } + else + { var typeToGenerate2 = new TypeToGenerate { Type = member.ReturnType, @@ -570,7 +844,6 @@ private void ParseJsonSerializerContextAttributes( (rootSerializableTypes ??= new()).Add(typeToGenerate2); } - } options = new SourceGenerationOptionsSpec @@ -1387,7 +1660,7 @@ private List ParsePropertyGenerationSpecs( // Walk the type hierarchy starting from the current type up to the base type(s) foreach (INamedTypeSymbol currentType in typeToGenerate.Type.GetSortedTypeHierarchy()) { - if(currentType.Name == "CancellationToken") + if (currentType.Name == "CancellationToken") continue; var declaringTypeRef = new TypeRef(currentType); ImmutableArray members = currentType.GetMembers(); @@ -2135,7 +2408,7 @@ private static string DetermineEffectiveJsonPropertyName(string propertyName, st return propertyName; } - + public static string ToCamelCase(string str) { if (!string.IsNullOrEmpty(str) && str.Length > 1) @@ -2318,7 +2591,8 @@ private bool TryGetDeserializationConstructor( } private bool IsSymbolAccessibleWithin(ISymbol symbol, INamedTypeSymbol within) - => true;// symbol.Name.EndsWith("Args") || _knownSymbols.Compilation.IsSymbolAccessibleWithin(symbol, within); + => + true; // symbol.Name.EndsWith("Args") || _knownSymbols.Compilation.IsSymbolAccessibleWithin(symbol, within); private bool IsUnsupportedType(ITypeSymbol type) { diff --git a/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Roslyn4.0.cs b/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Roslyn4.0.cs index c8309e1..e0c1d29 100644 --- a/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Roslyn4.0.cs +++ b/src/libs/CSharpToJsonSchema.Generators/JsonGen/JsonSourceGenerator.Roslyn4.0.cs @@ -21,6 +21,7 @@ public sealed partial class JsonSourceGenerator { //private const string JsonSerializableAttributeFullName = "CSharpToJsonSchema.Generators.JsonGen.System.Text.Json.JsonSerializableAttribute"; private const string JsonSerializableAttributeFullName = "CSharpToJsonSchema.GenerateJsonSchemaAttribute"; + private const string FunctionToolAttributeFullName = "CSharpToJsonSchema.FunctionToolAttribute"; #if ROSLYN4_4_OR_GREATER public const string SourceGenerationSpecTrackingName = "SourceGenerationSpec"; #endif @@ -64,6 +65,52 @@ public void Initialize2(IncrementalGeneratorInitializationContext context) context.RegisterSourceOutput(contextGenerationSpecs, ReportDiagnosticsAndEmitSource); } + public void InitializeForFunctionTools(IncrementalGeneratorInitializationContext context) + { +#if LAUNCH_DEBUGGER + System.Diagnostics.Debugger.Launch(); +#endif + IncrementalValueProvider knownTypeSymbols = context.CompilationProvider + .Select((compilation, _) => new KnownTypeSymbols(compilation)); + + IncrementalValueProvider<(Model.ContextGenerationSpec?, ImmutableEquatableArray)> + contextGenerationSpecs = IncrementalValueProviderExtensions.Combine( + context.CompilationProvider + .ForAttributeWithMetadataName<(MethodDeclarationSyntax ContextClass, SemanticModel + SemanticModel)>( +#if !ROSLYN4_4_OR_GREATER + context, +#endif + FunctionToolAttributeFullName, + (node, _) => node is MethodDeclarationSyntax, + (context, _) => + (ContextClass: (MethodDeclarationSyntax)context.TargetNode, context.SemanticModel)), + knownTypeSymbols) + .Collect().Select(static (tuple,CancellationToken) => + { + + var known = tuple.FirstOrDefault(); + var knowntype = known.Right; + if (knowntype == null && !tuple.Any()) + { + return (null, ImmutableEquatableArray.Empty); + //knowntype =new KnownTypeSymbols(tuple.First().Left.SemanticModel.Compilation); + } + Parser parser = new(knowntype); + Model.ContextGenerationSpec? contextGenerationSpec = + parser.ParseContextGenerationSpec(tuple,CancellationToken); + ImmutableEquatableArray diagnostics = + parser.Diagnostics.ToImmutableEquatableArray(); + return (contextGenerationSpec, diagnostics); + })//.Select(sx=>(sx.contextGenerationSpec,sx.contextGenerationSpec)) +#if ROSLYN4_4_OR_GREATER + .WithTrackingName(SourceGenerationSpecTrackingName) +#endif + ; + + context.RegisterSourceOutput(contextGenerationSpecs, ReportDiagnosticsAndEmitSource); + } + public void Initialize(IncrementalGeneratorInitializationContext context) { #if LAUNCH_DEBUGGER diff --git a/src/libs/CSharpToJsonSchema.Generators/JsonSchemaGenerator.cs b/src/libs/CSharpToJsonSchema.Generators/JsonSchemaGenerator.cs index bac2e1d..2a96b6d 100755 --- a/src/libs/CSharpToJsonSchema.Generators/JsonSchemaGenerator.cs +++ b/src/libs/CSharpToJsonSchema.Generators/JsonSchemaGenerator.cs @@ -1,3 +1,4 @@ +using System.Collections.Immutable; using CSharpToJsonSchema.Generators.Conversion; using CSharpToJsonSchema.Generators.JsonGen; using H.Generators; @@ -27,7 +28,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ProcessInterfaces(context); //Process Methods - // ProcessMethods(context); + ProcessMethods(context); } private void ProcessMethods(IncrementalGeneratorInitializationContext context) @@ -42,12 +43,16 @@ private void ProcessMethods(IncrementalGeneratorInitializationContext context) .SelectAndReportExceptions(AsFunctionTools, context, Id) .AddSource(context); - // attributes - // .SelectAndReportExceptions(AsCalls, context, Id) - // .AddSource(context); + attributes + .SelectAndReportExceptions(AsFunctionCalls, context, Id) + .AddSource(context); + + attributes + .SelectAndReportExceptions(AsGoogleFunctionToolsForMethods, context, Id) + .AddSource(context); - // var generator = new JsonSourceGenerator(); - // generator.Initialize2(context); + var generator = new JsonSourceGenerator(); + generator.InitializeForFunctionTools(context); } private void ProcessInterfaces(IncrementalGeneratorInitializationContext context) @@ -65,6 +70,9 @@ private void ProcessInterfaces(IncrementalGeneratorInitializationContext context attributes .SelectAndReportExceptions(AsCalls, context, Id) .AddSource(context); + attributes + .SelectAndReportExceptions(AsGoogleFunctionToolsForInterface, context, Id) + .AddSource(context); var generator = new JsonSourceGenerator(); generator.Initialize2(context); @@ -80,11 +88,10 @@ private static InterfaceData PrepareData( } private static InterfaceData PrepareMethodData( - (SemanticModel SemanticModel, AttributeData AttributeData, MethodDeclarationSyntax InterfaceSyntax, IMethodSymbol InterfaceSymbol) tuple) + ImmutableArray<(SemanticModel SemanticModel, AttributeData AttributeData, MethodDeclarationSyntax InterfaceSyntax, IMethodSymbol InterfaceSymbol)> list) { - var (_, attributeData, _, interfaceSymbol) = tuple; - - return interfaceSymbol.PrepareMethodData(attributeData); + var lst = list.Select(s => (s.InterfaceSymbol, s.AttributeData)).ToList(); + return ToModels.PrepareMethodData(lst); } private static FileWithName AsTools(InterfaceData @interface) @@ -106,7 +113,26 @@ private static FileWithName AsCalls(InterfaceData @interface) Name: $"{@interface.Name}.Calls.generated.cs", Text: Sources.GenerateCalls(@interface)); } + private static FileWithName AsFunctionCalls(InterfaceData @interface) + { + return new FileWithName( + Name: $"{@interface.Name}.FunctionCalls.generated.cs", + Text: Sources.GenerateFunctionCalls(@interface)); + } + + private static FileWithName AsGoogleFunctionToolsForMethods(InterfaceData @interface) + { + return new FileWithName( + Name: $"{@interface.Name}.GoogleFunctionTools.generated.cs", + Text: Sources.GenerateGoogleFunctionToolForMethods(@interface)); + } + private static FileWithName AsGoogleFunctionToolsForInterface(InterfaceData @interface) + { + return new FileWithName( + Name: $"{@interface.Name}.GoogleFunctionToolsExtensions.generated.cs", + Text: Sources.GenerateGoogleFunctionToolForInterface(@interface)); + } public static (string hintName, SourceText sourceText) AsCalls2(InterfaceData @interface) diff --git a/src/libs/CSharpToJsonSchema.Generators/Models/InterfaceData.cs b/src/libs/CSharpToJsonSchema.Generators/Models/InterfaceData.cs index 71ca4c2..2a773a8 100644 --- a/src/libs/CSharpToJsonSchema.Generators/Models/InterfaceData.cs +++ b/src/libs/CSharpToJsonSchema.Generators/Models/InterfaceData.cs @@ -3,4 +3,5 @@ namespace CSharpToJsonSchema.Generators.Models; public readonly record struct InterfaceData( string Namespace, string Name, + bool GoogleFunctionTool, IReadOnlyCollection Methods); \ No newline at end of file diff --git a/src/libs/CSharpToJsonSchema.Generators/Sources.Method.Calls.cs b/src/libs/CSharpToJsonSchema.Generators/Sources.Method.Calls.cs index dd22153..ddca22c 100644 --- a/src/libs/CSharpToJsonSchema.Generators/Sources.Method.Calls.cs +++ b/src/libs/CSharpToJsonSchema.Generators/Sources.Method.Calls.cs @@ -5,40 +5,42 @@ namespace CSharpToJsonSchema.Generators; -public class Sources_Method_Calls +internal static partial class Sources { - public static string GenerateCalls(InterfaceData @interface) + public static string GenerateFunctionCalls(InterfaceData @interface) { - var extensionsClassName = @interface.Name.Substring(startIndex: 1) + "Extensions"; + if(@interface.Methods.Count == 0) + return string.Empty; + var extensionsClassName = @interface.Name; var res = @$"#nullable enable + #pragma warning disable CS8602 namespace {@interface.Namespace} {{ {@interface.Methods.Select(static method => $@" public class {method.Name}Args - {{ - - {string.Join("\n ", method.Parameters.Select(static x => $@"[global::System.ComponentModel.Description({"\""}{ToModels.GetDescription(x)}{"\""})] + {{ + {string.Join("\n ", method.Parameters.Where(s=>s.Type.Name!="CancellationToken").Select(static x => $@"[global::System.ComponentModel.Description({"\""}{ToModels.GetDescription(x)}{"\""})] public {x.Type.ToDisplayString()}{(x.Type.IsNullableType() ? "?" : "")} {x.Name.ToPropertyName()} {{ get; set; }}{(!string.IsNullOrEmpty(ToModels.GetDefaultValue(x.Type)) ? $" = {ToModels.GetDefaultValue(x.Type)};" : "")}"))} }} ").Inject()} - public static partial class {extensionsClassName} + public partial class {extensionsClassName} {{ - public static global::System.Collections.Generic.IReadOnlyDictionary>> AsCalls(this {@interface.Name} service) + public global::System.Collections.Generic.IReadOnlyDictionary>> AsCalls() {{ return new global::System.Collections.Generic.Dictionary>> {{ {@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: false }).Select(method => $@" [""{method.Name}""] = (json, _) => {{ - return global::System.Threading.Tasks.Task.FromResult(service.Call{method.Name}(json)); + return global::System.Threading.Tasks.Task.FromResult(Call{method.Name}(json)); }}, ").Inject()} {@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: true }).Select(method => $@" [""{method.Name}""] = (json, _) => {{ - service.Call{method.Name}(json); + Call{method.Name}(json); return global::System.Threading.Tasks.Task.FromResult(string.Empty); }}, @@ -46,13 +48,13 @@ public static partial class {extensionsClassName} {@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: false }).Select(method => $@" [""{method.Name}""] = async (json, cancellationToken) => {{ - return await service.Call{method.Name}(json, cancellationToken); + return await Call{method.Name}(json, cancellationToken); }}, ").Inject()} {@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: true }).Select(method => $@" [""{method.Name}""] = async (json, cancellationToken) => {{ - await service.Call{method.Name}(json, cancellationToken); + await Call{method.Name}(json, cancellationToken); return string.Empty; }}, @@ -61,8 +63,7 @@ public static partial class {extensionsClassName} }} {@interface.Methods.Select(method => $@" - public static {method.Name}Args As{method.Name}Args( - this {@interface.Name} functions, + private {method.Name}Args As{method.Name}Args( string json) {{ #if NET6_0_OR_GREATER @@ -95,10 +96,11 @@ public static partial class {extensionsClassName} ").Inject()} {@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: false }).Select(method => $@" - public static string Call{method.Name}(this {@interface.Name} functions, string json) + private string Call{method.Name}(string json) {{ - var args = functions.As{method.Name}Args(json); - var jsonResult = functions.{method.Name}({string.Join(", ", method.Parameters.Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}); + var args = As{method.Name}Args(json); + var func = (dynamic) Delegates[""{method.Name}""]; + var jsonResult = func.Invoke({string.Join(", ", method.Parameters.Where(s=>s.Type.Name!="CancellationToken").Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}); #if NET6_0_OR_GREATER if(global::System.Text.Json.JsonSerializer.IsReflectionEnabledByDefault) @@ -125,21 +127,22 @@ public static partial class {extensionsClassName} }} ").Inject()} {@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: true }).Select(method => $@" - public static void Call{method.Name}(this {@interface.Name} functions, string json) + private void Call{method.Name}(string json) {{ - var args = functions.As{method.Name}Args(json); - functions.{method.Name}({string.Join(", ", method.Parameters.Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}); + var func = (dynamic) Delegates[""{method.Name}""]; + var args = As{method.Name}Args(json); + func.Invoke({string.Join(", ", method.Parameters.Where(s=>s.Type.Name!="CancellationToken").Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}); }} ").Inject()} {@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: false }).Select(method => $@" - public static async global::System.Threading.Tasks.Task Call{method.Name}( - this {@interface.Name} functions, + private async global::System.Threading.Tasks.Task Call{method.Name}( string json, global::System.Threading.CancellationToken cancellationToken = default) {{ - var args = functions.As{method.Name}Args(json); - var jsonResult = await functions.{method.Name}({string.Join(", ", method.Parameters + var args = As{method.Name}Args(json); + var func = (dynamic) Delegates[""{method.Name}""]; + var jsonResult = await func.Invoke({string.Join(", ", method.Parameters.Where(s=>s.Type.Name!="CancellationToken") .Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}").Append("cancellationToken"))}); #if NET6_0_OR_GREATER @@ -170,25 +173,25 @@ public static partial class {extensionsClassName} ").Inject()} {@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: true }).Select(method => $@" - public static async global::System.Threading.Tasks.Task Call{method.Name}( - this {@interface.Name} functions, + private async global::System.Threading.Tasks.Task Call{method.Name}( string json, global::System.Threading.CancellationToken cancellationToken = default) {{ - var args = functions.As{method.Name}Args(json); - await functions.{method.Name}({string.Join(", ", method.Parameters.Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}, cancellationToken); + var args = As{method.Name}Args(json); + //var func = (global::System.Func<{GetInputsTypes(method)}>) Delegates[""{method.Name}""]; + var func = (dynamic) Delegates[""{method.Name}""]; + await func.Invoke({string.Join(", ", method.Parameters.Where(s=>s.Type.Name!="CancellationToken").Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}, cancellationToken); return string.Empty; }} ").Inject()} - public static async global::System.Threading.Tasks.Task CallAsync( - this {@interface.Name} service, + public async global::System.Threading.Tasks.Task CallAsync( string functionName, string argumentsAsJson, global::System.Threading.CancellationToken cancellationToken = default) {{ - var calls = service.AsCalls(); + var calls = AsCalls(); var func = calls[functionName]; return await func(argumentsAsJson, cancellationToken); @@ -198,7 +201,9 @@ public static partial class {extensionsClassName} public partial class {extensionsClassName}JsonSerializerContext: global::System.Text.Json.Serialization.JsonSerializerContext {{ }} -}}"; +}} +#pragma warning restore CS8602 +"; return res; } } \ No newline at end of file diff --git a/src/libs/CSharpToJsonSchema.Generators/Sources.Method.GoogleFunctionTools.cs b/src/libs/CSharpToJsonSchema.Generators/Sources.Method.GoogleFunctionTools.cs new file mode 100644 index 0000000..00cbed9 --- /dev/null +++ b/src/libs/CSharpToJsonSchema.Generators/Sources.Method.GoogleFunctionTools.cs @@ -0,0 +1,54 @@ +using CSharpToJsonSchema.Generators.Models; +using H.Generators.Extensions; + +namespace CSharpToJsonSchema.Generators; + +internal static partial class Sources +{ + public static string GenerateGoogleFunctionToolForMethods(InterfaceData @interface) + { + if(@interface.Methods.Count == 0 || !@interface.GoogleFunctionTool) + return string.Empty; + var extensionsClassName = @interface.Name; + + return @$" +#nullable enable + +namespace {@interface.Namespace} +{{ + public partial class {extensionsClassName} + {{ + public static implicit operator global::GenerativeAI.Tools.GenericFunctionTool ({@interface.Namespace}.{extensionsClassName} tools) + {{ + return tools.AsGoogleFunctionTool(); + }} + + public global::GenerativeAI.Tools.GenericFunctionTool AsGoogleFunctionTool() + {{ + return new global::GenerativeAI.Tools.GenericFunctionTool(this.AsTools(), this.AsCalls()); + }} + }} +}}"; + } + + public static string GenerateGoogleFunctionToolForInterface(InterfaceData @interface) + { + if(!@interface.GoogleFunctionTool) + return string.Empty; + var extensionsClassName = @interface.Name.Substring(startIndex: 1) + "Extensions"; + + return @$" +#nullable enable + +namespace {@interface.Namespace} +{{ + public partial class {extensionsClassName} + {{ + public global::GenerativeAI.Core.IFunctionTool AsGoogleFunctionTool(this {@interface.Name} service) + {{ + return new global::GenerativeAI.Tools.GenericFunctionTool(service.AsTools(), service.AsCalls()); + }} + }} +}}"; + } +} \ No newline at end of file diff --git a/src/libs/CSharpToJsonSchema.Generators/Sources.Method.Tools.cs b/src/libs/CSharpToJsonSchema.Generators/Sources.Method.Tools.cs new file mode 100644 index 0000000..724b2a9 --- /dev/null +++ b/src/libs/CSharpToJsonSchema.Generators/Sources.Method.Tools.cs @@ -0,0 +1,100 @@ +using CSharpToJsonSchema.Generators.Models; +using H.Generators.Extensions; +using Microsoft.CodeAnalysis; + +namespace CSharpToJsonSchema.Generators; + +internal static partial class Sources +{ + public static string GenerateFunctionToolClientImplementation(InterfaceData @interface) + { + if(@interface.Methods.Count == 0) + return string.Empty; + var extensionsClassName = @interface.Name; + var constructorsToAdd = new List(); + HashSet methods = new(); + foreach (var method in @interface.Methods) + { + if (methods.Add(GetInputsTypes(method))) + { + constructorsToAdd.Add(method); + } + } + //var methodName = @interface.Methods.First().Name; + //var funcsToAdd = @interface.Methods.Select(GetInputsTypes).Distinct().ToList(); + return @$" +#nullable enable + +namespace {@interface.Namespace} +{{ + public partial class {extensionsClassName} + {{ + static {extensionsClassName}() + {{ + AddAllTools(); + }} + private static global::System.Collections.Generic.IDictionary? AllTools {{get; set;}} + public global::System.Collections.Generic.IList? AvailableTools {{get; private set;}} + private global::System.Collections.Generic.IDictionary? Delegates {{get; set;}} + private static void AddAllTools() + {{ + var list = new global::System.Collections.Generic.Dictionary(); + global::CSharpToJsonSchema.Tool tool; + {@interface.Methods.Select(method => $@" + tool = new global::CSharpToJsonSchema.Tool + {{ + Name = ""{method.Name}"", + Description = ""{method.Description}"", + Strict = {(method.IsStrict ? "true" : "false")}, + Parameters = global::CSharpToJsonSchema.SchemaBuilder.ConvertToSchema(global::{@interface.Namespace}.{extensionsClassName}JsonSerializerContext.Default.{method.Name}Args,{"\""}{GetDictionaryString(method)}{"\""}), + }}; + if(!list.ContainsKey(tool.Name)) + list.Add(""{method.Name}"",tool); + ").Inject()} + AllTools = list; + }} + + public void AddTool(global::System.Delegate tool) + {{ + AvailableTools ??= new global::System.Collections.Generic.List(); + Delegates ??= new global::System.Collections.Generic.Dictionary(); + + var name = tool.Method.Name; + if(AllTools == null || !AllTools.ContainsKey(name)) + throw new global::System.Exception({"$\"Function {name} is not registered. Please make sure you have added proper attribute to the method.\""}); + var newTool = AllTools[name]; + AvailableTools.Add(newTool); + if(Delegates.ContainsKey(name)) + throw new global::System.Exception({"$\"Function {name} is already registered\""}); + Delegates.Add(name, tool); + }} + + public Tools(global::System.Delegate[] tools) + {{ + foreach (var tool in tools) + {{ + AddTool(tool); + }} + }} + + public static implicit operator global::System.Collections.Generic.List({@interface.Namespace}.{extensionsClassName} tools) + {{ + return tools.AsTools(); + }} + + public global::System.Collections.Generic.List AsTools() + {{ + return (global::System.Collections.Generic.List) (this.AvailableTools??= new global::System.Collections.Generic.List()); + }} + }} +}}"; + } + + private static string GetInputsTypes(MethodData first, bool addReturnType = true) + { + var f = first.Parameters.Select(s => s.Type.ToDisplayString()).ToList(); + if(addReturnType) + f.Add(first.ReturnType.ToDisplayString()); + return string.Join(", ", f); + } +} \ No newline at end of file diff --git a/src/libs/CSharpToJsonSchema.Generators/Sources.Tools.cs b/src/libs/CSharpToJsonSchema.Generators/Sources.Tools.cs index 8cc5597..cf6b829 100644 --- a/src/libs/CSharpToJsonSchema.Generators/Sources.Tools.cs +++ b/src/libs/CSharpToJsonSchema.Generators/Sources.Tools.cs @@ -90,40 +90,8 @@ public static partial class {extensionsClassName} }} }}"; } - - public static string GenerateFunctionToolClientImplementation(InterfaceData @interface) - { - var extensionsClassName = @interface.Name.Substring(startIndex: 1) + "Extensions"; - - return @$" -#nullable enable -namespace {@interface.Namespace} -{{ - public static partial class {extensionsClassName} - {{ - public static global::CSharpToJsonSchema.Tool AsTool(this Func<{GetInputsTypes(@interface.Methods.First())}> functions) - {{ - {@interface.Methods.Select(method => $@" - return new global::CSharpToJsonSchema.Tool - {{ - Name = ""{method.Name}"", - Description = ""{method.Description}"", - Strict = {(method.IsStrict ? "true" : "false")}, - Parameters = null - }}; - ").Inject()} - }} - }} -}}"; - } - - private static string GetInputsTypes(MethodData first) - { - var f = first.Parameters.Select(s => s.Type.ToDisplayString()).ToList(); - f.Add(first.ReturnType.ToDisplayString()); - return string.Join(", ", f); - } + private static string GetDictionaryString(MethodData data) { diff --git a/src/libs/CSharpToJsonSchema.Generators/Steps.cs b/src/libs/CSharpToJsonSchema.Generators/Steps.cs index 166cc87..ddfc40b 100644 --- a/src/libs/CSharpToJsonSchema.Generators/Steps.cs +++ b/src/libs/CSharpToJsonSchema.Generators/Steps.cs @@ -1,3 +1,4 @@ +using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -43,7 +44,7 @@ public static IncrementalValuesProvider ForAttr ClassSymbol: (INamedTypeSymbol)context.TargetSymbol))); } - public static IncrementalValuesProvider<(SemanticModel SemanticModel, AttributeData AttributeData, MethodDeclarationSyntax InterfaceSyntax, IMethodSymbol InterfaceSymbol)> + public static IncrementalValueProvider> SelectManyAllAttributesOfCurrentMethodSyntax( this IncrementalValuesProvider source) { @@ -56,7 +57,7 @@ public static IncrementalValuesProvider ForAttr AttributeData: x, ClassSyntax: (MethodDeclarationSyntax)context.TargetNode, ClassSymbol: (IMethodSymbol)context.TargetSymbol) - )); + )).Collect(); return items; } } diff --git a/src/libs/CSharpToJsonSchema/FunctionToolAttribute.cs b/src/libs/CSharpToJsonSchema/FunctionToolAttribute.cs index d2c0acb..ae33c0d 100644 --- a/src/libs/CSharpToJsonSchema/FunctionToolAttribute.cs +++ b/src/libs/CSharpToJsonSchema/FunctionToolAttribute.cs @@ -11,4 +11,9 @@ public sealed class FunctionToolAttribute : Attribute /// /// public bool Strict { get; set; } + + /// + /// Generate Google Function Tools extensions for Google_GenerativeAI SDK + /// + public bool GoogleFunctionTool { get; set; } } \ No newline at end of file diff --git a/src/libs/CSharpToJsonSchema/GenerateJsonSchemaAttribute.cs b/src/libs/CSharpToJsonSchema/GenerateJsonSchemaAttribute.cs index 3d9502f..07e21c7 100644 --- a/src/libs/CSharpToJsonSchema/GenerateJsonSchemaAttribute.cs +++ b/src/libs/CSharpToJsonSchema/GenerateJsonSchemaAttribute.cs @@ -12,4 +12,9 @@ public sealed class GenerateJsonSchemaAttribute : Attribute /// /// public bool Strict { get; set; } + + /// + /// Generate Google Function Tools extensions for Google_GenerativeAI SDK + /// + public bool GoogleFunctionTool { get; set; } } \ No newline at end of file diff --git a/src/tests/CSharpToJsonSchema.AotTests/JsonSerializationTests.cs b/src/tests/CSharpToJsonSchema.AotTests/JsonSerializationTests.cs index aa970a9..c0bd0e9 100644 --- a/src/tests/CSharpToJsonSchema.AotTests/JsonSerializationTests.cs +++ b/src/tests/CSharpToJsonSchema.AotTests/JsonSerializationTests.cs @@ -30,31 +30,33 @@ public Task ShouldDeserializeWithJsonTypeInfo() Numbers = new List { 1, 2, 3, 4, 5 } } }; - - var serialized = JsonSerializer.Serialize(args, WeatherToolsExtensionsJsonSerializerContext.Default.ComplexClassSerializerTools); + + var serialized = JsonSerializer.Serialize(args, + WeatherToolsExtensionsJsonSerializerContext.Default.ComplexClassSerializerTools); serialized.Should().NotBeNullOrEmpty(); var deserialized = JsonSerializer.Deserialize(serialized, WeatherToolsExtensionsJsonSerializerContext.Default.ComplexClassSerializerTools); - + deserialized.Should().NotBeNull(); deserialized!.Name.Should().Be(args.Name); deserialized.Age.Should().Be(args.Age); deserialized.IsActive.Should().Be(args.IsActive); - deserialized.CreatedAt.Should().BeCloseTo(args.CreatedAt, TimeSpan.FromSeconds(1)); // Accounting for serialization precision + deserialized.CreatedAt.Should() + .BeCloseTo(args.CreatedAt, TimeSpan.FromSeconds(1)); // Accounting for serialization precision deserialized.Tags.Should().BeEquivalentTo(args.Tags); //foreach (var key in args.Metadata.Keys) //{ // deserialized.Metadata[key].Should().Be(args.Metadata[key]); //} - + deserialized.Details.Should().NotBeNull(); deserialized.Details!.Description.Should().Be(args.Details.Description); deserialized.Details.Value.Should().Be(args.Details.Value); deserialized.Details.Numbers.Should().BeEquivalentTo(args.Details.Numbers); - + return Task.CompletedTask; } @@ -94,16 +96,19 @@ public async Task ShouldWorkWithService() Temperature = 25, Unit = Unit.Celsius }; - var serialize = JsonSerializer.Serialize(args2, WeatherToolsExtensionsJsonSerializerContext.Default.GetComplexDataTypeArgs); + var serialize = JsonSerializer.Serialize(args2, + WeatherToolsExtensionsJsonSerializerContext.Default.GetComplexDataTypeArgs); var dx = await calls["GetComplexDataType"].Invoke(serialize, default); - var deserialized = JsonSerializer.Deserialize(dx, WeatherToolsExtensionsJsonSerializerContext.Default.ComplexClassSerializerTools); + var deserialized = JsonSerializer.Deserialize(dx, + WeatherToolsExtensionsJsonSerializerContext.Default.ComplexClassSerializerTools); deserialized.Should().NotBeNull(); deserialized!.Name.Should().Be(args.Name); deserialized.Age.Should().Be(args.Age); deserialized.IsActive.Should().Be(args.IsActive); - deserialized.CreatedAt.Should().BeCloseTo(args.CreatedAt, TimeSpan.FromSeconds(1)); // Accounting for serialization precision + deserialized.CreatedAt.Should() + .BeCloseTo(args.CreatedAt, TimeSpan.FromSeconds(1)); // Accounting for serialization precision deserialized.Tags.Should().BeEquivalentTo(args.Tags); //foreach (var key in args.Metadata.Keys) //{ @@ -115,7 +120,6 @@ public async Task ShouldWorkWithService() deserialized.Details!.Description.Should().Be(args.Details.Description); deserialized.Details.Value.Should().Be(args.Details.Value); deserialized.Details.Numbers.Should().BeEquivalentTo(args.Details.Numbers); - } [Fact] @@ -123,6 +127,6 @@ public void ShouldCreateToolWithComplexStudentClass() { var service = new StudenRecordService(); var tools = service.AsTools(); - } + } \ No newline at end of file diff --git a/src/tests/CSharpToJsonSchema.AotTests/MethodFunctionTools_Tests.cs b/src/tests/CSharpToJsonSchema.AotTests/MethodFunctionTools_Tests.cs new file mode 100644 index 0000000..b90e5e4 --- /dev/null +++ b/src/tests/CSharpToJsonSchema.AotTests/MethodFunctionTools_Tests.cs @@ -0,0 +1,78 @@ +namespace CSharpToJsonSchema.IntegrationTests; + +public class MethodFunctionTools_Tests +{ + [Fact] + public async Task Should_SampleFunctionTool_StringAsync() + { + var mf = new MethodFunctionTools(); + var tools = new Tools([mf.SampleFunctionTool_StringAsync]); + var value = await tools.CallAsync(nameof(mf.SampleFunctionTool_StringAsync), "{\"input\":\"test\"}"); + + Assert.Equal(value, "\"Hello world\""); + + } + + [Fact] + public async Task Should_SampleFunctionTool_NoReturnAsync() + { + var mf = new MethodFunctionTools(); + + var tools = new Tools([mf.SampleFunctionTool_NoReturnAsync]); + await tools.CallAsync(nameof(mf.SampleFunctionTool_NoReturnAsync), "{\"input\":\"test\"}"); + } + + [Fact] + public async Task Should_SampleFunctionTool_String() + { + var mf = new MethodFunctionTools(); + var tools = new Tools([mf.SampleFunctionTool_String]); + var value = await tools.CallAsync(nameof(mf.SampleFunctionTool_String), "{\"input\":\"test\"}"); + + Assert.Equal(value, "\"Hello world from string return\""); + } + + [Fact] + public async Task Should_SampleFunctionTool_Void() + { + var mf = new MethodFunctionTools(); + var tools = new Tools([mf.SampleFunctionTool_Void]); + await tools.CallAsync(nameof(mf.SampleFunctionTool_Void), "{\"input\":\"test\"}"); + } + + //Static Methods + + [Fact] + public async Task Should_SampleFunctionTool_Static_StringAsync() + { + + var tools = new Tools([MethodFunctionTools.SampleFunctionTool_Static_StringAsync]); + var value = await tools.CallAsync(nameof(MethodFunctionTools.SampleFunctionTool_Static_StringAsync), "{\"input\":\"test\"}"); + + Assert.Equal(value, "\"Hello world\""); + } + + [Fact] + public async Task Should_SampleFunctionTool_Static_NoReturnAsync() + { + var tools = new Tools([MethodFunctionTools.SampleFunctionTool_Static_NoReturnAsync]); + await tools.CallAsync(nameof(MethodFunctionTools.SampleFunctionTool_Static_NoReturnAsync), "{\"input\":\"test\"}"); + } + + [Fact] + public async Task Should_SampleFunctionTool_Static_String() + { + var tools = new Tools([MethodFunctionTools.SampleFunctionTool_Static_String]); + var value = await tools.CallAsync(nameof(MethodFunctionTools.SampleFunctionTool_Static_String), "{\"input\":\"test\"}"); + Assert.Equal(value, "\"Hello world from string return\""); + } + + [Fact] + public async Task Should_SampleFunctionTool_Static_Void() + { + var tools = new Tools([MethodFunctionTools.SampleFunctionTool_Static_Void]); + await tools.CallAsync(nameof(MethodFunctionTools.SampleFunctionTool_Static_Void), "{\"input\":\"test\"}"); + + } + +} \ No newline at end of file diff --git a/src/tests/CSharpToJsonSchema.AotTests/Services/MethodFunctionTools.cs b/src/tests/CSharpToJsonSchema.AotTests/Services/MethodFunctionTools.cs new file mode 100644 index 0000000..063e4df --- /dev/null +++ b/src/tests/CSharpToJsonSchema.AotTests/Services/MethodFunctionTools.cs @@ -0,0 +1,55 @@ + +using System.Threading; +using System.Threading.Tasks; + +namespace CSharpToJsonSchema.IntegrationTests; +public class MethodFunctionTools +{ + [FunctionTool] + public Task SampleFunctionTool_StringAsync(string input, CancellationToken cancellationToken = default) + { + return Task.FromResult("Hello world"); + } + + [FunctionTool] + public Task SampleFunctionTool_NoReturnAsync(string input, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + [FunctionTool] + public void SampleFunctionTool_Void(string input) + { + + } + + [FunctionTool] + public string SampleFunctionTool_String(string input) + { + return "Hello world from string return"; + } + + [FunctionTool] + public static Task SampleFunctionTool_Static_StringAsync(string input, CancellationToken cancellationToken = default) + { + return Task.FromResult("Hello world"); + } + + [FunctionTool] + public static Task SampleFunctionTool_Static_NoReturnAsync(string input, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + [FunctionTool] + public static void SampleFunctionTool_Static_Void(string input) + { + + } + + [FunctionTool] + public static string SampleFunctionTool_Static_String(string input) + { + return "Hello world from string return"; + } +} \ No newline at end of file diff --git a/src/tests/CSharpToJsonSchema.IntegrationTests/MethodFunctionTools.cs b/src/tests/CSharpToJsonSchema.IntegrationTests/MethodFunctionTools.cs index 05451bf..cb00bce 100644 --- a/src/tests/CSharpToJsonSchema.IntegrationTests/MethodFunctionTools.cs +++ b/src/tests/CSharpToJsonSchema.IntegrationTests/MethodFunctionTools.cs @@ -6,15 +6,54 @@ namespace CSharpToJsonSchema.IntegrationTests; public class MethodFunctionTools { [FunctionTool] - public Task SampleFunctionToolAsync(string input, CancellationToken cancellationToken = default) + public Task SampleFunctionTool_StringAsync(string input, CancellationToken cancellationToken = default) { return Task.FromResult("Hello world"); } - - public Task SampleFunctionToolAsync2(string input, CancellationToken cancellationToken = default) - { - //SampleFunctionToolAsync; - //var t = this.SampleFunctionToolAsync.AsTool(); + + [FunctionTool] + public Task SampleFunctionTool_NoReturnAsync(string input, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + [FunctionTool] + public void SampleFunctionTool_Void(string input) + { + + } + + [FunctionTool] + public string SampleFunctionTool_String(string input) + { + return "Hello world from string return"; + } + + + [FunctionTool] + public Task SampleFunctionTool_Static_StringAsync(string input, CancellationToken cancellationToken = default) + { return Task.FromResult("Hello world"); } + + [FunctionTool] + public Task SampleFunctionTool_Static_NoReturnAsync(string input, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + [FunctionTool] + public void SampleFunctionTool_Static_Void(string input) + { + + } + + [FunctionTool] + public string SampleFunctionTool_Static_String(string input) + { + return "Hello world from string return"; + } + + + } \ No newline at end of file diff --git a/src/tests/CSharpToJsonSchema.SnapshotTests/SnapshotTests.cs b/src/tests/CSharpToJsonSchema.SnapshotTests/SnapshotTests.cs index 484a83c..2d1c132 100755 --- a/src/tests/CSharpToJsonSchema.SnapshotTests/SnapshotTests.cs +++ b/src/tests/CSharpToJsonSchema.SnapshotTests/SnapshotTests.cs @@ -3,6 +3,12 @@ namespace CSharpToJsonSchema.SnapshotTests; [TestClass] public class ToolTests : VerifyBase { + // [TestMethod] + // public Task MethodFunction() + // { + // return this.CheckSourceAsync(H.Resources.MethodFunctionTools_cs.AsString()); + // } + [TestMethod] public Task Weather() {