Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
[![Discord](https://img.shields.io/discord/1115206893015662663?label=Discord&logo=discord&logoColor=white&color=d82679)](https://discord.gg/Ca2xhfBf3v)

## Features 🔥
- Source generator to define functions natively through C# interfaces
- Source generator to define functions natively through C# interfaces and individual methods
- Doesn't use Reflection
- All modern .NET features - nullability, trimming, NativeAOT, etc.
- Tested for compatibility with OpenAI/Ollama/Anthropic/LangChain/Gemini

## Usage

### Interface
```csharp
using CSharpToJsonSchema;

Expand Down Expand Up @@ -56,6 +58,32 @@ public class WeatherService : IWeatherFunctions
var tools = service.AsTools();
```

### Methods

```csharp

[FunctionTool]
public Task<Weather> GetCurrentWeatherAsync(string location, Unit unit = Unit.Celsius, CancellationToken cancellationToken = default)
{
return Task.FromResult(new Weather
{
Location = location,
Temperature = 22.0,
Unit = unit,
Description = "Sunny",
});
}

var tools = new Tools([GetCurrentWeatherAsync])

//Access list of CSharpToJsonSchema.Tool
var myTools = tools.AvailableTools

//Implicit Conversion to list of CSharpToJsonSchema.Tool
List<Tool> myTools = tools
```


## Support

Priority place for bugs: https://github.com/tryAGI/CSharpToJsonSchema/issues
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,59 @@ public static class SymbolGenerator
var classSymbol = semanticModel.GetDeclaredSymbol(classNode) as ITypeSymbol;
return classSymbol;
}

public static INamedTypeSymbol? GenerateToolJsonSerializerContext(
string rootNamespace,
Compilation originalCompilation)
{


// Example: we create a class name
var className = $"ToolsJsonSerializerContext";


// Build a class declaration
var classDecl = SyntaxFactory.ClassDeclaration(className)
.AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword))
.AddModifiers(SyntaxFactory.Token(SyntaxKind.PartialKeyword));

// We create a compilation unit holding our new class



var namespaceName =rootNamespace; // choose your own
var ns = SyntaxFactory.NamespaceDeclaration(SyntaxFactory.IdentifierName(namespaceName))
.AddMembers(classDecl);

var compilationUnit = SyntaxFactory.CompilationUnit()
.AddMembers(ns) // if ns is a NamespaceDeclarationSyntax
.NormalizeWhitespace();

var parseOptions = CSharpParseOptions.Default.WithLanguageVersion(originalCompilation.GetLanguageVersion()?? LanguageVersion.Default);
var syntaxTree =CSharpSyntaxTree.Create(compilationUnit,parseOptions);
//CSharpSyntaxTree.Create(ns.NormalizeWhitespace());

// Now we need to add this syntax tree to a new or existing compilation
var assemblyName = "TemporaryAssembly";
var compilation = originalCompilation
.AddSyntaxTrees(syntaxTree);
//.WithAssemblyName(assemblyName);


// Get the semantic model for our newly added syntax tree
var semanticModel = compilation.GetSemanticModel(syntaxTree);

// Find the class syntax node in the syntax tree
var classNode = syntaxTree.GetRoot().DescendantNodes()
.OfType<ClassDeclarationSyntax>()
.FirstOrDefault();

if (classNode == null) return null;

// Retrieve the ITypeSymbol from the semantic model
var classSymbol = semanticModel.GetDeclaredSymbol(classNode);
return classSymbol;
}

public static AttributeSyntax GetConverter(string propertyType)
{
Expand Down
103 changes: 78 additions & 25 deletions src/libs/CSharpToJsonSchema.Generators/Conversion/ToModels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public static InterfaceData PrepareData(

var isStrict = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "Strict").Value.Value is bool strict &&
strict;
var generateGoogleFunctionTool = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "GoogleFunctionTool").Value.Value is bool googleFunctionTool &&
googleFunctionTool;
var methods = interfaceSymbol
.GetMembers()
.OfType<IMethodSymbol>()
Expand Down Expand Up @@ -43,42 +45,93 @@ public static InterfaceData PrepareData(
return new InterfaceData(
Namespace: interfaceSymbol.ContainingNamespace.ToDisplayString(),
Name: interfaceSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat),
GoogleFunctionTool:generateGoogleFunctionTool,
Methods: methods);
}

public static InterfaceData PrepareMethodData(
this IMethodSymbol interfaceSymbol,
AttributeData attributeData)
List<(IMethodSymbol, AttributeData)> list)
{
interfaceSymbol = interfaceSymbol ?? throw new ArgumentNullException(nameof(interfaceSymbol));
attributeData = attributeData ?? throw new ArgumentNullException(nameof(attributeData));
//interfaceSymbol = interfaceSymbol ?? throw new ArgumentNullException(nameof(interfaceSymbol));
list = list ?? throw new ArgumentNullException(nameof(list));

var namespaceName = "CSharpToJsonSchema";
var className = "Tools";
List<MethodData> methodList = new();
List<string> namespaces = new();
bool generateGoogleFunctionTools = false;
foreach (var l in list)
{
var (interfaceSymbol, attributeData) = l;
var isStrict = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "Strict").Value.Value is bool strict &&
strict;
var ggft = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "GoogleFunctionTool").Value.Value is bool googleFunctionTool &&
googleFunctionTool;
if(ggft)
generateGoogleFunctionTools = true;

var x = interfaceSymbol;
var parameters = x.Parameters
//.Where(static x => x.Type.MetadataName != "CancellationToken")
.ToArray();

var methodData = new MethodData(
Name: x.Name,
Description: GetDescription(x),
IsAsync: x.IsAsync || x.ReturnType.Name == "Task",
IsVoid: x.ReturnsVoid || x.ReturnType.MetadataName == "Task",
IsStrict: isStrict,
Parameters: parameters.Select(static y => y).ToArray(),
Descriptions: parameters.Select(static l => GetParameterDescriptions(l)).SelectMany(s => s)
.ToDictionary(s => s.Key, s => s.Value),
ReturnType:x.ReturnType
);
methodList.Add(methodData);
namespaces.Add(interfaceSymbol.ContainingNamespace.ToDisplayString());
}

var isStrict = attributeData.NamedArguments.FirstOrDefault(x => x.Key == "Strict").Value.Value is bool strict &&
strict;
var x = interfaceSymbol;
var parameters = x.Parameters
//.Where(static x => x.Type.MetadataName != "CancellationToken")
.ToArray();
return new InterfaceData(
Namespace: GetCommonRootNamespace(namespaces)??namespaceName,
Name: className,
GoogleFunctionTool: generateGoogleFunctionTools,
Methods: methodList.ToArray());
}
public static string? GetCommonRootNamespace(IEnumerable<string> namespaces)
{
// Convert the list of namespaces to a list of arrays split by "."
var splitNamespaces = namespaces
.Select(ns => ns.Split('.'))
.ToList();

if (!splitNamespaces.Any() || !splitNamespaces[0].Any())
{
return null;
}

var methodData = new MethodData(
Name: x.Name,
Description: GetDescription(x),
IsAsync: x.IsAsync || x.ReturnType.Name == "Task",
IsVoid: x.ReturnsVoid || x.ReturnType.MetadataName == "Task",
IsStrict: isStrict,
Parameters: parameters.Select(static y => y).ToArray(),
Descriptions: parameters.Select(static l => GetParameterDescriptions(l)).SelectMany(s => s)
.ToDictionary(s => s.Key, s => s.Value),
ReturnType:x.ReturnType
);
// Start with the first namespace parts
var firstNsParts = splitNamespaces[0];
var commonParts = new List<string>();

// For each index in the first namespace
for (int i = 0; i < firstNsParts.Length; i++)
{
// Check if every other namespace has the same part at this index
string currentPart = firstNsParts[i];
if (splitNamespaces.All(nsArr => nsArr.Length > i && nsArr[i] == currentPart))
{
commonParts.Add(currentPart);
}
else
{
// Stop the moment there is a mismatch
break;
}
}

return new InterfaceData(
Namespace: interfaceSymbol.ContainingNamespace.ToDisplayString(),
Name: "I"+interfaceSymbol.Name,
Methods: [methodData]);
return string.Join(".", commonParts);
}


// private static Dictionary<string, bool> GetIsRequired(IParameterSymbol[] parameters, Dictionary<string, bool>? dics = null)
// {
// dics ??= new Dictionary<string, bool>();
Expand Down
Loading
Loading