Skip to content

Commit bf3c3cc

Browse files
committed
parse objects
1 parent 0706e46 commit bf3c3cc

File tree

7 files changed

+1015
-41
lines changed

7 files changed

+1015
-41
lines changed

bool_or_schema.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ func NewBoolOrSchema(v any) *BoolOrSchema {
7878
return &BoolOrSchema{Allowed: v}
7979
case *RefOrSpec[Schema]:
8080
return &BoolOrSchema{Schema: v}
81+
case *SchemaBulder:
82+
return &BoolOrSchema{Schema: v.Build()}
8183
default:
8284
return nil
8385
}

components.go

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package openapi
22

3+
import (
4+
"regexp"
5+
)
6+
37
// Components holds a set of reusable objects for different aspects of the OAS.
48
// All objects defined within the components object will have no effect on the API unless they are explicitly referenced
59
// from properties outside the components object.
@@ -160,57 +164,77 @@ func (o *Components) Add(name string, v any) *Components {
160164
return o
161165
}
162166

167+
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9\.\-_]+$`)
168+
163169
func (o *Components) validateSpec(location string, validator *Validator) []*validationError {
164170
var errs []*validationError
165-
if o.Schemas != nil {
166-
for k, v := range o.Schemas {
167-
errs = append(errs, v.validateSpec(joinLoc(location, "schemas", k), validator)...)
171+
for k, v := range o.Schemas {
172+
if !namePattern.MatchString(k) {
173+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
168174
}
175+
errs = append(errs, v.validateSpec(joinLoc(location, "schemas", k), validator)...)
169176
}
170-
if o.Responses != nil {
171-
for k, v := range o.Responses {
172-
errs = append(errs, v.validateSpec(joinLoc(location, "responses", k), validator)...)
177+
178+
for k, v := range o.Responses {
179+
if !namePattern.MatchString(k) {
180+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
173181
}
182+
errs = append(errs, v.validateSpec(joinLoc(location, "responses", k), validator)...)
174183
}
175-
if o.Parameters != nil {
176-
for k, v := range o.Parameters {
177-
errs = append(errs, v.validateSpec(joinLoc(location, "parameters", k), validator)...)
184+
for k, v := range o.Parameters {
185+
if !namePattern.MatchString(k) {
186+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
178187
}
188+
errs = append(errs, v.validateSpec(joinLoc(location, "parameters", k), validator)...)
179189
}
180-
if o.Examples != nil {
181-
for k, v := range o.Examples {
182-
errs = append(errs, v.validateSpec(joinLoc(location, "examples", k), validator)...)
190+
191+
for k, v := range o.Examples {
192+
if !namePattern.MatchString(k) {
193+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
183194
}
195+
errs = append(errs, v.validateSpec(joinLoc(location, "examples", k), validator)...)
184196
}
185-
if o.RequestBodies != nil {
186-
for k, v := range o.RequestBodies {
187-
errs = append(errs, v.validateSpec(joinLoc(location, "requestBodies", k), validator)...)
197+
198+
for k, v := range o.RequestBodies {
199+
if !namePattern.MatchString(k) {
200+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
188201
}
202+
errs = append(errs, v.validateSpec(joinLoc(location, "requestBodies", k), validator)...)
189203
}
190-
if o.Headers != nil {
191-
for k, v := range o.Headers {
192-
errs = append(errs, v.validateSpec(joinLoc(location, "headers", k), validator)...)
204+
205+
for k, v := range o.Headers {
206+
if !namePattern.MatchString(k) {
207+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
193208
}
209+
errs = append(errs, v.validateSpec(joinLoc(location, "headers", k), validator)...)
194210
}
195-
if o.SecuritySchemes != nil {
196-
for k, v := range o.SecuritySchemes {
197-
errs = append(errs, v.validateSpec(joinLoc(location, "securitySchemes", k), validator)...)
211+
212+
for k, v := range o.SecuritySchemes {
213+
if !namePattern.MatchString(k) {
214+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
198215
}
216+
errs = append(errs, v.validateSpec(joinLoc(location, "securitySchemes", k), validator)...)
199217
}
200-
if o.Links != nil {
201-
for k, v := range o.Links {
202-
errs = append(errs, v.validateSpec(joinLoc(location, "links", k), validator)...)
218+
219+
for k, v := range o.Links {
220+
if !namePattern.MatchString(k) {
221+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
203222
}
223+
errs = append(errs, v.validateSpec(joinLoc(location, "links", k), validator)...)
204224
}
205-
if o.Callbacks != nil {
206-
for k, v := range o.Callbacks {
207-
errs = append(errs, v.validateSpec(joinLoc(location, "callbacks", k), validator)...)
225+
226+
for k, v := range o.Callbacks {
227+
if !namePattern.MatchString(k) {
228+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
208229
}
230+
errs = append(errs, v.validateSpec(joinLoc(location, "callbacks", k), validator)...)
209231
}
210-
if o.Paths != nil {
211-
for k, v := range o.Paths {
212-
errs = append(errs, v.validateSpec(joinLoc(location, "paths", k), validator)...)
232+
233+
for k, v := range o.Paths {
234+
if !namePattern.MatchString(k) {
235+
errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String()))
213236
}
237+
errs = append(errs, v.validateSpec(joinLoc(location, "paths", k), validator)...)
214238
}
215239

216240
return errs

parser.go

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
package openapi
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"reflect"
7+
"strings"
8+
)
9+
10+
const is64Bit = uint64(^uintptr(0)) == ^uint64(0)
11+
12+
// ParseObject parses the object and returns the schema or the reference to the schema.
13+
//
14+
// The object can be a struct, pointer to struct, map, slice, pointer to map or slice, or any other type.
15+
// The object can contain fields with `json`, `yaml` or `openapi` tags.
16+
//
17+
// `opanapi:"<name>[,schema:<ref> || named fields]"` tag
18+
// - <name> is the name of the field in the schema, can be "-" to skip the field or empty to use the name from json, yaml tags or original field name.
19+
// - ref:<ref> is a reference to the schema, can not be used with jsonschema fields
20+
// jsonschema fields:
21+
// - required
22+
// - deprecated
23+
// - title:<title>
24+
// - description:<description>
25+
// - type:<type> (boolean, integer, number, string, array, object), may be used multiple times.
26+
// The first type overrides the default type, all other types are added.
27+
// - addtype:<type>, adds additional type, may be used multiple times.
28+
// - format:<format>
29+
//
30+
// The components is needed to store the schemas of the structs, and to avoid the circular references.
31+
// In case of the given object is struct, the function will return a reference to the schema.
32+
// Otherwise, the function will return the schema itself.
33+
func ParseObject(obj any, components *Extendable[Components]) (*SchemaBulder, error) {
34+
t := reflect.TypeOf(obj)
35+
if t == nil {
36+
return NewSchemaBuilder().Type(NullType).GoType("nil"), nil
37+
}
38+
value := reflect.ValueOf(obj)
39+
return parseObject(joinLoc("", t.String()), value, components)
40+
}
41+
42+
func parseObject(location string, obj reflect.Value, components *Extendable[Components]) (*SchemaBulder, error) {
43+
t := obj.Type()
44+
if t == nil {
45+
return NewSchemaBuilder().Type(NullType).GoType("nil"), nil
46+
}
47+
kind := t.Kind()
48+
if kind == reflect.Ptr {
49+
return parseObject(location, obj.Elem(), components)
50+
}
51+
if kind == reflect.Interface {
52+
return NewSchemaBuilder().GoType("any"), nil
53+
}
54+
builder := NewSchemaBuilder().GoType(fmt.Sprintf("%T", obj.Interface()))
55+
switch obj.Interface().(type) {
56+
case bool:
57+
builder.Type(BooleanType)
58+
case int, uint:
59+
if is64Bit {
60+
builder.Type(IntegerType).Format(Int64Format)
61+
} else {
62+
builder.Type(IntegerType).Format(Int32Format)
63+
}
64+
case int8, int16, int32, uint8, uint16, uint32:
65+
builder.Type(IntegerType).Format(Int32Format)
66+
case int64, uint64:
67+
builder.Type(IntegerType).Format(Int64Format)
68+
case float32:
69+
builder.Type(NumberType).Format(FloatFormat)
70+
case float64:
71+
builder.Type(NumberType).Format(DoubleFormat)
72+
case string:
73+
builder.Type(StringType)
74+
case []byte:
75+
builder.Type(StringType).ContentEncoding(Base64Encoding).GoType("[]byte") // TODO: create an option for default ContentEncoding
76+
case json.Number:
77+
builder.Type(NumberType).GoPackage(t.PkgPath())
78+
case json.RawMessage:
79+
builder.Type(StringType).ContentMediaType("application/json").GoPackage(t.PkgPath())
80+
default:
81+
switch kind {
82+
case reflect.Array, reflect.Slice:
83+
var elemSchema any
84+
if t.Elem().Kind() == reflect.Interface {
85+
elemSchema = true
86+
} else {
87+
var err error
88+
elemSchema, err = parseObject(location, reflect.New(t.Elem()), components)
89+
if err != nil {
90+
return nil, err
91+
}
92+
}
93+
builder.Type(ArrayType).Items(NewBoolOrSchema(elemSchema)).GoType("")
94+
case reflect.Map:
95+
if k := t.Key().Kind(); k != reflect.String {
96+
return nil, fmt.Errorf("%s: unsupported map key type %s, expected string", location, k)
97+
}
98+
var elemSchema any
99+
if t.Elem().Kind() == reflect.Interface {
100+
elemSchema = true
101+
} else {
102+
var err error
103+
elemSchema, err = parseObject(location, reflect.New(t.Elem()), components)
104+
if err != nil {
105+
return nil, err
106+
}
107+
}
108+
builder.Type(ObjectType).AdditionalProperties(NewBoolOrSchema(elemSchema)).GoType("")
109+
case reflect.Struct:
110+
objName := strings.ReplaceAll(t.PkgPath()+"."+t.Name(), "/", ".")
111+
if components.Spec.Schemas[objName] != nil {
112+
return NewSchemaBuilder().Ref("#/components/schemas/" + objName), nil
113+
}
114+
// add a temporary schema to avoid circular references
115+
if components.Spec.Schemas == nil {
116+
components.Spec.Schemas = make(map[string]*RefOrSpec[Schema], 1)
117+
}
118+
// reserve the name of the schema
119+
components.Spec.Schemas[objName] = NewSchemaBuilder().Ref("to be deleted").Build()
120+
var allOf []*RefOrSpec[Schema]
121+
for i := 0; i < t.NumField(); i++ {
122+
field := t.Field(i)
123+
// skip unexported fields
124+
if !field.IsExported() {
125+
continue
126+
}
127+
fieldSchema, err := parseObject(joinLoc(location, field.Name), obj.Field(i), components)
128+
if err != nil {
129+
// remove the temporary schema
130+
delete(components.Spec.Schemas, objName)
131+
return nil, err
132+
}
133+
if field.Type.Kind() == reflect.Ptr {
134+
if fieldSchema.IsRef() {
135+
fieldSchema = NewSchemaBuilder().OneOf(
136+
fieldSchema.Build(),
137+
NewSchemaBuilder().Type(NullType).Build(),
138+
)
139+
} else {
140+
fieldSchema.AddType(NullType)
141+
}
142+
}
143+
if field.Anonymous {
144+
allOf = append(allOf, fieldSchema.Build())
145+
continue
146+
}
147+
name := applyTag(field, fieldSchema, builder)
148+
// skip the field if it's marked as "-"
149+
if name == "-" {
150+
continue
151+
}
152+
builder.AddProperty(name, fieldSchema.Build())
153+
}
154+
if len(allOf) > 0 {
155+
allOf = append(allOf, builder.Type(ObjectType).GoType("").Build())
156+
builder = NewSchemaBuilder().AllOf(allOf...).GoType(t.String())
157+
} else {
158+
builder.Type(ObjectType)
159+
}
160+
builder.GoPackage(t.PkgPath())
161+
components.Spec.Schemas[objName] = builder.Build()
162+
builder = NewSchemaBuilder().Ref("#/components/schemas/" + objName)
163+
}
164+
}
165+
166+
return builder, nil
167+
}
168+
169+
func applyTag(field reflect.StructField, schema *SchemaBulder, parent *SchemaBulder) (name string) {
170+
name = field.Name
171+
172+
for _, tagName := range []string{"json", "yaml"} {
173+
if tag, ok := field.Tag.Lookup(tagName); ok {
174+
parts := strings.SplitN(tag, ",", 2)
175+
if len(parts) > 0 {
176+
part := strings.TrimSpace(parts[0])
177+
if part != "" {
178+
name = part
179+
break
180+
}
181+
}
182+
}
183+
}
184+
185+
tag, ok := field.Tag.Lookup("openapi")
186+
if !ok {
187+
return
188+
}
189+
parts := strings.Split(tag, ",")
190+
if len(parts) == 0 {
191+
return
192+
}
193+
194+
if parts[0] != "" {
195+
name = parts[0]
196+
}
197+
if name == "-" {
198+
return parts[0]
199+
}
200+
parts = parts[1:]
201+
if len(parts) == 0 {
202+
return
203+
}
204+
205+
if strings.HasPrefix("ref:", parts[0]) {
206+
schema.Ref(parts[0][4:])
207+
return
208+
}
209+
210+
var isTypeOverriden bool
211+
212+
for _, part := range parts {
213+
prefixIndex := strings.Index(part, ":")
214+
var prefix string
215+
if prefixIndex == -1 {
216+
prefix = part
217+
} else {
218+
prefix = part[:prefixIndex]
219+
if prefixIndex == len(part)-1 {
220+
part = ""
221+
}
222+
part = part[prefixIndex+1:]
223+
}
224+
switch prefix {
225+
case "required":
226+
parent.AddRequired(name)
227+
case "deprecated":
228+
schema.Deprecated(true)
229+
case "title":
230+
schema.Title(part)
231+
case "description":
232+
schema.Description(part)
233+
case "type":
234+
// first type overrides the default type
235+
// all other types are added
236+
if !isTypeOverriden {
237+
schema.Type(part)
238+
isTypeOverriden = true
239+
} else {
240+
schema.AddType(part)
241+
}
242+
case "addtype":
243+
schema.AddType(part)
244+
case "format":
245+
schema.Format(part)
246+
}
247+
}
248+
249+
return
250+
}

0 commit comments

Comments
 (0)