diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/decode_map.go b/decode_map.go index 52e0526..6e48513 100644 --- a/decode_map.go +++ b/decode_map.go @@ -279,49 +279,33 @@ func decodeStructValue(d *Decoder, v reflect.Value) error { n, err := d.mapLen(c) if err == nil { - return d.decodeStruct(v, n) + return d.decodeStructValueAsMap(v, n) } var err2 error n, err2 = d.arrayLen(c) - if err2 != nil { - return err - } - - if n <= 0 { - v.Set(reflect.Zero(v.Type())) - return nil - } - - fields := structs.Fields(v.Type(), d.structTag) - if n != len(fields.List) { - return errArrayStruct + if err2 == nil { + return d.decodeStructValueAsArray(v, n) } - for _, f := range fields.List { - if err := f.DecodeValue(d, v); err != nil { - return err - } - } - - return nil + return err } -func (d *Decoder) decodeStruct(v reflect.Value, n int) error { +func (d *Decoder) decodeStructValueAsMap(v reflect.Value, n int) error { if n == -1 { v.Set(reflect.Zero(v.Type())) return nil } - fields := structs.Fields(v.Type(), d.structTag) + structFields := structs.Fields(v.Type(), d.structTag) for i := 0; i < n; i++ { name, err := d.decodeStringTemp() if err != nil { return err } - if f := fields.Map[name]; f != nil { - if err := f.DecodeValue(d, v); err != nil { + if f := structFields.Map[name]; f != nil { + if err := decodeStructFieldValue(d, v, f); err != nil { return err } continue @@ -337,3 +321,31 @@ func (d *Decoder) decodeStruct(v reflect.Value, n int) error { return nil } + +func (d *Decoder) decodeStructValueAsArray(v reflect.Value, n int) error { + if n <= 0 { + v.Set(reflect.Zero(v.Type())) + return nil + } + + fields := structs.Fields(v.Type(), d.structTag) + + if n != len(fields.List) && !fields.hasKeyFields { + return errArrayStruct + } + + for _, f := range fields.List { + if err := decodeStructFieldValue(d, v, f); err != nil { + return err + } + } + + return nil +} + +func decodeStructFieldValue(d *Decoder, v reflect.Value, f *field) error { + if f == nil { + return d.Skip() + } + return f.DecodeValue(d, v) +} diff --git a/encode_map.go b/encode_map.go index ba4c61b..2dd5687 100644 --- a/encode_map.go +++ b/encode_map.go @@ -145,9 +145,14 @@ func (e *Encoder) EncodeMapLen(l int) error { func encodeStructValue(e *Encoder, strct reflect.Value) error { structFields := structs.Fields(strct.Type(), e.structTag) - if e.flags&arrayEncodedStructsFlag != 0 || structFields.AsArray { - return encodeStructValueAsArray(e, strct, structFields.List) + + if e.flags&arrayEncodedStructsFlag != 0 || structFields.AsArray || structFields.hasKeyFields { + return encodeStructValueAsArray(e, strct, structFields) } + return encodeStructValueAsMap(e, strct, structFields) +} + +func encodeStructValueAsMap(e *Encoder, strct reflect.Value, structFields *fields) error { fields := structFields.OmitEmpty(strct, e.flags&omitEmptyFlag != 0) if err := e.EncodeMapLen(len(fields)); err != nil { @@ -158,7 +163,7 @@ func encodeStructValue(e *Encoder, strct reflect.Value) error { if err := e.EncodeString(f.name); err != nil { return err } - if err := f.EncodeValue(e, strct); err != nil { + if err := encodeStructFieldValue(e, strct, f); err != nil { return err } } @@ -166,14 +171,21 @@ func encodeStructValue(e *Encoder, strct reflect.Value) error { return nil } -func encodeStructValueAsArray(e *Encoder, strct reflect.Value, fields []*field) error { - if err := e.EncodeArrayLen(len(fields)); err != nil { +func encodeStructValueAsArray(e *Encoder, strct reflect.Value, structFields *fields) error { + if err := e.EncodeArrayLen(len(structFields.List)); err != nil { return err } - for _, f := range fields { - if err := f.EncodeValue(e, strct); err != nil { + for _, f := range structFields.List { + if err := encodeStructFieldValue(e, strct, f); err != nil { return err } } return nil } + +func encodeStructFieldValue(e *Encoder, strct reflect.Value, f *field) error { + if f == nil { + return e.EncodeNil() + } + return f.EncodeValue(e, strct) +} diff --git a/example_test.go b/example_test.go index 1408a59..11b6689 100644 --- a/example_test.go +++ b/example_test.go @@ -164,6 +164,29 @@ func ExampleMarshal_asArray() { // Output: [foo bar] } +func ExampleMarshal_useKeys() { + type Item struct { + Foo string `msgpack:"key:0"` + // Bar string `msgpack:"1"` + Baz string `msgpack:"key:2"` + } + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + err := enc.Encode(&Item{Foo: "foo", Baz: "baz"}) + if err != nil { + panic(err) + } + + dec := msgpack.NewDecoder(&buf) + v, err := dec.DecodeInterface() + if err != nil { + panic(err) + } + fmt.Println(v) + // Output: [foo baz] +} + func ExampleMarshal_omitEmpty() { type Item struct { Foo string diff --git a/types.go b/types.go index 69aca61..95c494c 100644 --- a/types.go +++ b/types.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "reflect" + "strconv" "sync" "github.com/vmihailenco/tagparser/v2" @@ -92,6 +93,7 @@ func (m *structCache) Fields(typ reflect.Type, tag string) *fields { type field struct { name string index []int + key int omitEmpty bool encoder encoderFunc decoder decoderFunc @@ -126,6 +128,7 @@ type fields struct { List []*field AsArray bool + hasKeyFields bool hasOmitEmpty bool } @@ -137,11 +140,26 @@ func newFields(typ reflect.Type) *fields { } } -func (fs *fields) Add(field *field) { - fs.warnIfFieldExists(field.name) - fs.Map[field.name] = field - fs.List = append(fs.List, field) - if field.omitEmpty { +func (fs *fields) Add(f *field) { + fs.warnIfFieldExists(f.name) + fs.Map[f.name] = f + + if f.key == -1 { + fs.List = append(fs.List, f) + } else { + if len(fs.List) <= f.key { + if cap(fs.List) > f.key { + fs.List = fs.List[0 : f.key+1] + } else { + fsListOld := fs.List + fs.List = make([]*field, f.key+1) + copy(fs.List, fsListOld) + } + } + fs.List[f.key] = f + } + + if f.omitEmpty { fs.hasOmitEmpty = true } } @@ -175,12 +193,7 @@ func getFields(typ reflect.Type, fallbackTag string) *fields { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) - tagStr := f.Tag.Get(defaultStructTag) - if tagStr == "" && fallbackTag != "" { - tagStr = f.Tag.Get(fallbackTag) - } - - tag := tagparser.Parse(tagStr) + tag := getFieldTag(&f, fallbackTag) if tag.Name == "-" { continue } @@ -199,6 +212,7 @@ func getFields(typ reflect.Type, fallbackTag string) *fields { field := &field{ name: tag.Name, index: f.Index, + key: getKeyFromTag(tag), omitEmpty: omitEmpty || tag.HasOption("omitempty"), } @@ -223,6 +237,11 @@ func getFields(typ reflect.Type, fallbackTag string) *fields { field.name = f.Name } + if field.key != -1 { + fs.AsArray = true + fs.hasKeyFields = true + } + if f.Anonymous && !tag.HasOption("noinline") { inline := tag.HasOption("inline") if inline { @@ -232,9 +251,7 @@ func getFields(typ reflect.Type, fallbackTag string) *fields { } if inline { - if _, ok := fs.Map[field.name]; ok { - log.Printf("msgpack: %s already has field=%s", fs.Type, field.name) - } + fs.warnIfFieldExists(field.name) fs.Map[field.name] = field continue } @@ -250,6 +267,26 @@ func getFields(typ reflect.Type, fallbackTag string) *fields { return fs } +func getFieldTag(f *reflect.StructField, fallbackTag string) *tagparser.Tag { + tagStr := f.Tag.Get(defaultStructTag) + if tagStr == "" && fallbackTag != "" { + tagStr = f.Tag.Get(fallbackTag) + } + return tagparser.Parse(tagStr) +} + +func getKeyFromTag(tag *tagparser.Tag) int { + if key, ok := tag.Options["key"]; ok { + keyInt, err := strconv.Atoi(key) + if err != nil { + err := fmt.Errorf("msgpack: key value should be int: %s", key) + panic(err) + } + return keyInt + } + return -1 +} + var ( encodeStructValuePtr uintptr decodeStructValuePtr uintptr diff --git a/types_test.go b/types_test.go index fbc1ad3..9fb81c1 100644 --- a/types_test.go +++ b/types_test.go @@ -231,6 +231,18 @@ type AsArrayTest struct { OmitEmptyTest } +type AsArrayByKeysTest struct { + Foo string `msgpack:"key:0"` + Bar string `msgpack:"key:1"` + Baz string `msgpack:"key:2"` +} + +type AsArrayByKeysCompatibilityTest struct { + Foo string `msgpack:"key:0"` + // Bar string `msgpack:"key:1"` + Baz string `msgpack:"key:2"` +} + type ExtTestField struct { ExtTest ExtTest } @@ -280,6 +292,8 @@ var encoderTests = []encoderTest{ {&InlinePtrTest{OmitEmptyTest: &OmitEmptyTest{Bar: "world"}}, "81a3426172a5776f726c64"}, {&AsArrayTest{}, "92a0a0"}, + {&AsArrayByKeysTest{}, "93a0a0a0"}, + {&AsArrayByKeysCompatibilityTest{}, "93a0c0a0"}, {&JSONFallbackTest{Foo: "hello"}, "82a3666f6fa568656c6c6fa3626172a0"}, {&JSONFallbackTest{Bar: "world"}, "81a3626172a5776f726c64"}, @@ -597,6 +611,20 @@ var ( decErr: "msgpack: number of fields in array-encoded struct has changed", }, + {in: nil, out: new(*AsArrayByKeysTest), wantnil: true}, + {in: nil, out: new(AsArrayByKeysTest), wantzero: true}, + {in: AsArrayByKeysTest{Foo: "foo", Bar: "bar", Baz: "baz"}, out: new(AsArrayByKeysTest)}, + { + in: AsArrayByKeysTest{Foo: "foo", Bar: "bar", Baz: "baz"}, + out: new(AsArrayByKeysCompatibilityTest), + wanted: AsArrayByKeysCompatibilityTest{Foo: "foo", Baz: "baz"}, + }, + { + in: AsArrayByKeysCompatibilityTest{Foo: "foo", Baz: "baz"}, + out: new(AsArrayByKeysTest), + wanted: AsArrayByKeysTest{Foo: "foo", Baz: "baz"}, + }, + {in: (*EventTime)(nil), out: new(*EventTime)}, {in: &EventTime{time.Unix(0, 0)}, out: new(*EventTime)},