Skip to content

Commit ff2004d

Browse files
committed
feat: Added Cohere AllOf inheritance/polymorphism support.
1 parent b8da9f5 commit ff2004d

File tree

2,203 files changed

+427
-23265
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

2,203 files changed

+427
-23265
lines changed

src/libs/AutoSDK/Models/ModelData.cs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ public readonly record struct ModelData(
1313
ImmutableArray<PropertyData> Properties,
1414
ImmutableArray<PropertyData> EnumValues,
1515
string Summary,
16-
bool IsDeprecated
16+
bool IsDeprecated,
17+
string BaseClass,
18+
bool IsBaseClass,
19+
bool IsDerivedClass,
20+
string DiscriminatorPropertyName,
21+
EquatableArray<(string ClassName, string Discriminator)> DerivedTypes
1722
)
1823
{
1924
public static ModelData FromSchemaContext(
@@ -41,16 +46,31 @@ public static ModelData FromSchemaContext(
4146
Namespace: context.Settings.Namespace,
4247
Style: context.Schema.IsEnum() ? ModelStyle.Enumeration : context.Settings.ModelStyle,
4348
Settings: context.Settings,
44-
Properties: !context.Schema.IsEnum()
45-
? context.Children
49+
Properties: context.IsDerivedClass
50+
? context.DerivedClassContext.Children
4651
.Where(x => x is { IsProperty: true, PropertyData: not null })
4752
.SelectMany(x => x.ComputedProperties)
48-
.ToImmutableArray() : [],
53+
.ToImmutableArray()
54+
: !context.Schema.IsEnum()
55+
? context.Children
56+
.Where(x => x is { IsProperty: true, PropertyData: not null })
57+
.SelectMany(x => x.ComputedProperties)
58+
.ToImmutableArray()
59+
: [],
4960
EnumValues: context.Schema.IsEnum()
5061
? context.ComputeEnum().Values.ToImmutableArray()
5162
: [],
5263
Summary: context.Schema.GetSummary(),
53-
IsDeprecated: context.Schema.Deprecated
64+
IsDeprecated: context.Schema.Deprecated,
65+
BaseClass: context.IsDerivedClass
66+
? context.BaseClassContext.Id
67+
: string.Empty,
68+
IsBaseClass: context.IsBaseClass,
69+
IsDerivedClass: context.IsDerivedClass,
70+
DiscriminatorPropertyName: context.Schema.Discriminator?.PropertyName ?? string.Empty,
71+
DerivedTypes: context.Schema.Discriminator?.Mapping?
72+
.Select(x => (ClassName: x.Value.Replace("#/components/schemas/", string.Empty), Discriminator: x.Key))
73+
.ToImmutableArray() ?? []
5474
);
5575
}
5676

src/libs/AutoSDK/Models/SchemaContext.cs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ public class SchemaContext
1212
public IList<SchemaContext> Children { get; set; } = [];
1313

1414
public required Settings Settings { get; init; }
15-
public required OpenApiSchema Schema { get; init; }
15+
public required OpenApiSchema Schema { get; set; }
1616
public required string Id { get; set; }
17-
public required string Type { get; init; }
17+
public required string Type { get; set; }
1818

1919
public string? ReferenceId { get; init; }
2020
public bool IsReference => ReferenceId != null;
@@ -52,7 +52,9 @@ public class SchemaContext
5252

5353
public TypeData TypeData { get; set; } = TypeData.Default;
5454

55-
public bool IsClass => Type == "class";// || ResolvedReference?.IsClass == true;
55+
public bool IsClass =>
56+
Type == "class" ||
57+
IsDerivedClass;// || ResolvedReference?.IsClass == true;
5658
//public ModelData? ClassData { get; set; }
5759
public ModelData? ClassData => IsClass
5860
? //IsReference
@@ -78,7 +80,32 @@ public class SchemaContext
7880

7981
public bool IsAnyOf => Schema.IsAnyOf();
8082
public bool IsOneOf => Schema.IsOneOf();
81-
public bool IsAllOf => Schema.IsAllOf();
83+
public bool IsAllOf => Schema.IsAllOf() && !IsDerivedClass;
84+
public bool IsBaseClass => this is { IsComponent: true, Schema.Discriminator.Mapping: not null };
85+
public bool IsDerivedClass => Schema.IsAllOf() &&
86+
Schema.AllOf is { Count: 2 } allOf &&
87+
(allOf[0].Reference != null &&
88+
allOf[0].ResolveIfRequired().Discriminator?.Mapping != null ||
89+
allOf[1].Reference != null &&
90+
allOf[1].ResolveIfRequired().Discriminator?.Mapping != null);
91+
public SchemaContext DerivedClassContext =>
92+
Schema.IsAllOf() &&
93+
Schema.AllOf is { Count: 2 } allOf
94+
? allOf[0].Reference != null &&
95+
allOf[0].ResolveIfRequired().Discriminator?.Mapping != null
96+
? Children.First(x => x.ReferenceId == allOf[1].Reference?.Id)
97+
: Children.First(x => x.ReferenceId == allOf[0].Reference?.Id)
98+
: throw new InvalidOperationException("Schema is not derived class.");
99+
100+
public SchemaContext BaseClassContext =>
101+
Schema.IsAllOf() &&
102+
Schema.AllOf is { Count: 2 } allOf
103+
? allOf[0].Reference != null &&
104+
allOf[0].ResolveIfRequired().Discriminator?.Mapping != null
105+
? Children.First(x => x.ReferenceId == allOf[0].Reference?.Id)
106+
: Children.First(x => x.ReferenceId == allOf[1].Reference?.Id)
107+
: throw new InvalidOperationException("Schema is not derived class.");
108+
82109
public bool IsAnyOfLikeStructure => IsAnyOf || IsOneOf || IsAllOf;
83110
public bool IsNamedAnyOfLike => IsAnyOfLikeStructure &&
84111
(IsComponent || Schema.Discriminator != null);

src/libs/AutoSDK/Models/TypeData.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ public static string GetCSharpType(SchemaContext context)
229229
$"{context.Children.FirstOrDefault(x => x.Hint == Hint.ArrayItem)?.TypeData.CSharpTypeWithoutNullability}".AsArray(),
230230

231231
(_, _) when context.IsNamedAnyOfLike => $"global::{context.Settings.Namespace}.{context.Id}",
232+
(_, _) when context.IsDerivedClass => $"global::{context.Settings.Namespace}.{context.Id}",
232233

233234
(_, _) when context.Schema.IsAnyOf() => $"global::{context.Settings.Namespace}.AnyOf<{string.Join(", ", context.Children.Where(x => x.Hint == Hint.AnyOf).Select(x => x.TypeData.CSharpTypeWithNullabilityForValueTypes))}>",
234235
(_, _) when context.Schema.IsOneOf() => $"global::{context.Settings.Namespace}.OneOf<{string.Join(", ", context.Children.Where(x => x.Hint == Hint.OneOf).Select(x => x.TypeData.CSharpTypeWithNullabilityForValueTypes))}>",

src/libs/AutoSDK/Sources/Sources.Models.Json.cs

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,18 @@ public static string GenerateClassFromToJsonMethods(
1010
ModelData modelData,
1111
CancellationToken cancellationToken = default)
1212
{
13+
if (modelData.IsDerivedClass)
14+
{
15+
return string.Empty;
16+
}
17+
1318
return GenerateModelFromToJsonMethods(
1419
@namespace: modelData.Namespace,
1520
className: modelData.ClassName,
1621
settings: modelData.Settings,
1722
isValueType: false,
23+
baseClassName: modelData.BaseClass,
24+
isBaseClass: modelData.IsBaseClass,
1825
cancellationToken);
1926
}
2027

@@ -32,6 +39,8 @@ public static string GenerateAnyOfFromToJsonMethods(
3239
className: className,
3340
settings: anyOfData.Settings,
3441
isValueType: true,
42+
baseClassName: string.Empty,
43+
isBaseClass: false,
3544
cancellationToken);
3645
}
3746

@@ -40,12 +49,15 @@ public static string GenerateModelFromToJsonMethods(
4049
string className,
4150
Settings settings,
4251
bool isValueType,
52+
string baseClassName,
53+
bool isBaseClass,
4354
CancellationToken cancellationToken = default)
4455
{
4556
var typeName = $"global::{@namespace}.{className}";
4657
var modifiers = isValueType
4758
? "readonly partial struct"
48-
: "sealed partial class";
59+
: $"{(isBaseClass ? "" : "sealed ")}partial class";
60+
var isDerivedClass = !string.IsNullOrWhiteSpace(baseClassName);
4961

5062
return settings.JsonSerializerType == JsonSerializerType.SystemTextJson
5163
? @$"#nullable enable
@@ -60,7 +72,7 @@ public string ToJson(
6072
{{
6173
return global::System.Text.Json.JsonSerializer.Serialize(
6274
this,
63-
this.GetType(),
75+
{(isDerivedClass ? $"typeof({baseClassName})" : "this.GetType()")},
6476
jsonSerializerContext);
6577
}}
6678
@@ -74,45 +86,49 @@ public string ToJson(
7486
{{
7587
return global::System.Text.Json.JsonSerializer.Serialize(
7688
this,
89+
{(isDerivedClass ? $"typeof({baseClassName})," : string.Empty)}
7790
jsonSerializerOptions);
7891
}}
7992
8093
{"Deserializes a JSON string using the provided JsonSerializerContext.".ToXmlDocumentationSummary(level: 8)}
81-
public static {typeName}? FromJson(
94+
public static {typeName}? FromJson{(isDerivedClass ? "<T>" : string.Empty)}(
8295
string json,
8396
global::System.Text.Json.Serialization.JsonSerializerContext jsonSerializerContext)
97+
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
8498
{{
8599
return global::System.Text.Json.JsonSerializer.Deserialize(
86100
json,
87-
typeof({typeName}),
88-
jsonSerializerContext) as {typeName}{(isValueType ? "?" : "")};
101+
typeof({(isDerivedClass ? baseClassName : typeName)}),
102+
jsonSerializerContext) as {(isDerivedClass ? "T" : typeName)}{(isValueType ? "?" : "")};
89103
}}
90104
91105
{"Deserializes a JSON string using the provided JsonSerializerOptions.".ToXmlDocumentationSummary(level: 8)}
92106
#if NET8_0_OR_GREATER
93107
[global::System.Diagnostics.CodeAnalysis.RequiresUnreferencedCode(""JSON serialization and deserialization might require types that cannot be statically analyzed. Use the overload that takes a JsonTypeInfo or JsonSerializerContext, or make sure all of the required types are preserved."")]
94108
[global::System.Diagnostics.CodeAnalysis.RequiresDynamicCode(""JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. Use System.Text.Json source generation for native AOT applications."")]
95109
#endif
96-
public static {typeName}? FromJson(
110+
public static {typeName}? FromJson{(isDerivedClass ? "<T>" : string.Empty)}(
97111
string json,
98112
global::System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null)
113+
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
99114
{{
100-
return global::System.Text.Json.JsonSerializer.Deserialize<{typeName}>(
115+
return global::System.Text.Json.JsonSerializer.Deserialize<{(isDerivedClass ? baseClassName : typeName)}>(
101116
json,
102-
jsonSerializerOptions);
117+
jsonSerializerOptions){(isDerivedClass ? " as T" : string.Empty)};
103118
}}
104119
105120
/// <summary>
106121
/// Deserializes a JSON stream using the provided JsonSerializerContext.
107122
/// </summary>
108-
public static async global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync(
123+
public static async global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync{(isDerivedClass ? "<T>" : string.Empty)}(
109124
global::System.IO.Stream jsonStream,
110125
global::System.Text.Json.Serialization.JsonSerializerContext jsonSerializerContext)
126+
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
111127
{{
112128
return (await global::System.Text.Json.JsonSerializer.DeserializeAsync(
113129
jsonStream,
114-
typeof({typeName}),
115-
jsonSerializerContext).ConfigureAwait(false)) as {typeName}{(isValueType ? "?" : "")};
130+
typeof({(isDerivedClass ? baseClassName : typeName)}),
131+
jsonSerializerContext).ConfigureAwait(false)) as {(isDerivedClass ? "T" : typeName)}{(isValueType ? "?" : "")};
116132
}}
117133
118134
/// <summary>
@@ -122,17 +138,18 @@ public string ToJson(
122138
[global::System.Diagnostics.CodeAnalysis.RequiresUnreferencedCode(""JSON serialization and deserialization might require types that cannot be statically analyzed. Use the overload that takes a JsonTypeInfo or JsonSerializerContext, or make sure all of the required types are preserved."")]
123139
[global::System.Diagnostics.CodeAnalysis.RequiresDynamicCode(""JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. Use System.Text.Json source generation for native AOT applications."")]
124140
#endif
125-
public static global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync(
141+
public static global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync{(isDerivedClass ? "<T>" : string.Empty)}(
126142
global::System.IO.Stream jsonStream,
127143
global::System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null)
144+
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
128145
{{
129146
return global::System.Text.Json.JsonSerializer.DeserializeAsync<{typeName}?>(
130147
jsonStream,
131-
jsonSerializerOptions);
148+
jsonSerializerOptions){(isDerivedClass ? " as T" : string.Empty)};
132149
}}
133150
}}
134151
}}
135-
"
152+
".RemoveBlankLinesWhereOnlyWhitespaces()
136153
: @$"#nullable enable
137154
138155
namespace {@namespace}

src/libs/AutoSDK/Sources/Sources.Models.Validation.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ public static string GenerateClassValidationMethods(
99
ModelData modelData,
1010
CancellationToken cancellationToken = default)
1111
{
12-
return GenerateModelFromToJsonMethods(
12+
return GenerateModelValidationMethods(
1313
@namespace: modelData.Namespace,
1414
className: modelData.ClassName,
1515
settings: modelData.Settings,
@@ -26,7 +26,7 @@ public static string GenerateAnyOfValidationMethods(
2626
? $"{anyOfData.SubType}{types}"
2727
: anyOfData.Name;
2828

29-
return GenerateModelFromToJsonMethods(
29+
return GenerateModelValidationMethods(
3030
@namespace: anyOfData.Namespace,
3131
className: className,
3232
settings: anyOfData.Settings,

src/libs/AutoSDK/Sources/Sources.Models.cs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,23 @@ public static string GenerateClassModel(
8181
var additionalPropertiesPostfix = modelData.ClassName == "AdditionalProperties"
8282
? "2"
8383
: string.Empty;
84+
var properties = modelData.Properties.Where(x =>
85+
!modelData.IsBaseClass ||
86+
x.Id != modelData.DiscriminatorPropertyName).ToArray();
8487

8588
return $@"
8689
{modelData.Summary.ToXmlDocumentationSummary(level: 4)}
8790
{(modelData.IsDeprecated ? "[global::System.Obsolete(\"This model marked as deprecated.\")]" : " ")}
88-
public sealed partial class {modelData.ClassName}
91+
{(modelData.Settings.JsonSerializerType == JsonSerializerType.SystemTextJson && modelData.IsBaseClass ? @$"
92+
[global::System.Text.Json.Serialization.JsonPolymorphic(
93+
TypeDiscriminatorPropertyName = ""{modelData.DiscriminatorPropertyName}"",
94+
IgnoreUnrecognizedTypeDiscriminators = true,
95+
UnknownDerivedTypeHandling = global::System.Text.Json.Serialization.JsonUnknownDerivedTypeHandling.FallBackToBaseType)]
96+
{modelData.DerivedTypes.Select(x => $@"
97+
[global::System.Text.Json.Serialization.JsonDerivedType(typeof({modelData.Namespace}.{x.ClassName}), typeDiscriminator: ""{x.Discriminator}"")]").Inject()}" : " ")}
98+
public{(modelData.IsBaseClass ? "" : " sealed")} partial class {modelData.ClassName}{(!string.IsNullOrWhiteSpace(modelData.BaseClass) ? $" : {modelData.BaseClass}" : "")}
8999
{{
90-
{modelData.Properties.Select(property => @$"
100+
{properties.Select(property => @$"
91101
{property.Summary.ToXmlDocumentationSummary(level: 8)}
92102
{property.DefaultValue?.ClearForXml().ToXmlDocumentationDefault(level: 8)}
93103
{property.Example?.ToXmlDocumentationExample(level: 8)}
@@ -98,31 +108,35 @@ public sealed partial class {modelData.ClassName}
98108
public{(property.IsRequired ? requiredKeyword : "")} {property.Type.CSharpType} {property.Name} {{ get; set; }}{GetDefaultValue(property, isRequiredKeywordSupported)}
99109
").Inject()}
100110
111+
{(!modelData.IsDerivedClass ? $@"
101112
{"Additional properties that are not explicitly defined in the schema".ToXmlDocumentationSummary(level: 8)}
102113
{jsonSerializer.GenerateExtensionDataAttribute()}
103114
public global::System.Collections.Generic.IDictionary<string, object> AdditionalProperties{additionalPropertiesPostfix} {{ get; set; }} = new global::System.Collections.Generic.Dictionary<string, object>();
104-
115+
" : " ")}
116+
117+
{( properties.Any(static x => x.IsRequired || !x.IsDeprecated) ? $@"
105118
/// <summary>
106119
/// Initializes a new instance of the <see cref=""{modelData.ClassName}"" /> class.
107120
/// </summary>
108-
{modelData.Properties.Where(static x => x.IsRequired || !x.IsDeprecated).Select(x => $@"
121+
{properties.Where(static x => x.IsRequired || !x.IsDeprecated).Select(x => $@"
109122
{x.Summary.ToXmlDocumentationForParam(x.ParameterName, level: 8)}").Inject()}
110123
{(modelData.Settings.TargetFramework.StartsWith("net8", StringComparison.OrdinalIgnoreCase) ? "[global::System.Diagnostics.CodeAnalysis.SetsRequiredMembers]" : " ")}
111124
public {modelData.ClassName}(
112125
{string.Join(",",
113-
modelData.Properties.Where(static x => x.IsRequired).Select(x => $@"
126+
properties.Where(static x => x.IsRequired).Select(x => $@"
114127
{x.Type.CSharpType} {x.ParameterName}").Concat(
115-
modelData.Properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && (x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
128+
properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && (x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
116129
{x.Type.CSharpType} {x.ParameterName}")).Concat(
117-
modelData.Properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && !(x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
130+
properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && !(x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
118131
{x.Type.CSharpType} {x.ParameterName}{GetDefaultValue(x, isRequiredKeywordSupported).TrimEnd(';')}")))})
119132
{{
120-
{modelData.Properties.Where(static x => x.IsRequired).Select(x => $@"
133+
{properties.Where(static x => x.IsRequired).Select(x => $@"
121134
this.{x.Name} = {x.ParameterName}{(x.Type.IsValueType ? "" : $" ?? throw new global::System.ArgumentNullException(nameof({x.ParameterName}))")};").Inject()}
122-
{modelData.Properties.Where(static x => x is { IsRequired: false, IsDeprecated: false }).Select(x => $@"
135+
{properties.Where(static x => x is { IsRequired: false, IsDeprecated: false }).Select(x => $@"
123136
this.{x.Name} = {x.ParameterName};").Inject()}
124137
}}
125-
{(modelData.Properties.Any(static x => !x.IsDeprecated) ? $@"
138+
" : " ")}
139+
{(properties.Any(static x => !x.IsDeprecated) ? $@"
126140
/// <summary>
127141
/// Initializes a new instance of the <see cref=""{modelData.ClassName}"" /> class.
128142
/// </summary>

src/tests/AutoSDK.SnapshotTests/Snapshots/ai21/NewtonsoftJson/_#G.Models.ChatCompletionMeta.g.verified.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,5 @@ public sealed partial class ChatCompletionMeta
1515
/// </summary>
1616
[global::Newtonsoft.Json.JsonExtensionData]
1717
public global::System.Collections.Generic.IDictionary<string, object> AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary<string, object>();
18-
19-
/// <summary>
20-
/// Initializes a new instance of the <see cref="ChatCompletionMeta" /> class.
21-
/// </summary>
22-
public ChatCompletionMeta(
23-
)
24-
{
25-
}
2618
}
2719
}

src/tests/AutoSDK.SnapshotTests/Snapshots/ai21/NewtonsoftJson/_#G.Models.ChatCompletionVllmStreamingMessageMeta.g.verified.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,5 @@ public sealed partial class ChatCompletionVllmStreamingMessageMeta
1515
/// </summary>
1616
[global::Newtonsoft.Json.JsonExtensionData]
1717
public global::System.Collections.Generic.IDictionary<string, object> AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary<string, object>();
18-
19-
/// <summary>
20-
/// Initializes a new instance of the <see cref="ChatCompletionVllmStreamingMessageMeta" /> class.
21-
/// </summary>
22-
public ChatCompletionVllmStreamingMessageMeta(
23-
)
24-
{
25-
}
2618
}
2719
}

0 commit comments

Comments
 (0)