diff --git a/decode.go b/decode.go index ea645aa..4aba00f 100644 --- a/decode.go +++ b/decode.go @@ -175,6 +175,15 @@ func (d *Decoder) UsePreallocateValues(on bool) { } } +// UseUIntStructKeys enables support for decoding uint struct tag keys to their corresponding struct field. +func (d *Decoder) UseUIntStructKeys(on bool) { + if on { + d.flags |= useUIntStructKeysFlag + } else { + d.flags &= ^useUIntStructKeysFlag + } +} + // DisableAllocLimit enables fully allocating slices/maps when the size is known func (d *Decoder) DisableAllocLimit(on bool) { if on { diff --git a/decode_map.go b/decode_map.go index c54dae3..644e7ae 100644 --- a/decode_map.go +++ b/decode_map.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "strconv" "github.com/vmihailenco/msgpack/v5/msgpcode" ) @@ -332,9 +333,30 @@ func (d *Decoder) decodeStruct(v reflect.Value, n int) error { fields := structs.Fields(v.Type(), d.structTag) for i := 0; i < n; i++ { - name, err := d.decodeStringTemp() - if err != nil { - return err + var name string + var err error + + // Try to decode the key as a uint if useUIntStructKeysFlag is set + if d.flags&useUIntStructKeysFlag != 0 { + code, err := d.PeekCode() + if err != nil { + return err + } + if msgpcode.IsUint(code) { + intKey, err := d.DecodeUint() + if err != nil { + return err + } + name = strconv.FormatUint(uint64(intKey), 10) + } + } + + // if useUIntStructKeysFlag is not set or if the current key is not a uint, fallback to a string key + if name == "" { + name, err = d.decodeStringTemp() + if err != nil { + return err + } } if f := fields.Map[name]; f != nil { diff --git a/encode.go b/encode.go index 135adc8..b8a69c6 100644 --- a/encode.go +++ b/encode.go @@ -17,6 +17,7 @@ const ( useCompactFloatsFlag useInternedStringsFlag omitEmptyFlag + useUIntStructKeysFlag ) type writer interface { @@ -196,6 +197,15 @@ func (e *Encoder) UseInternedStrings(on bool) { } } +// UseUIntStructKeys causes the Encoder to encode struct fields that have a valid uint tag key as uints. +func (e *Encoder) UseUIntStructKeys(on bool) { + if on { + e.flags |= useUIntStructKeysFlag + } else { + e.flags &= ^useUIntStructKeysFlag + } +} + func (e *Encoder) Encode(v interface{}) error { switch v := v.(type) { case nil: diff --git a/encode_map.go b/encode_map.go index a5aa31b..a7b9ac4 100644 --- a/encode_map.go +++ b/encode_map.go @@ -4,6 +4,7 @@ import ( "math" "reflect" "sort" + "strconv" "github.com/vmihailenco/msgpack/v5/msgpcode" ) @@ -201,8 +202,22 @@ func encodeStructValue(e *Encoder, strct reflect.Value) error { } for _, f := range fields { - if err := e.EncodeString(f.name); err != nil { - return err + uintKeyEncoded := false + if e.flags&useUIntStructKeysFlag != 0 { + // Try to encode the key as a uint if useUIntStructKeysFlag is set + uintKey, err := strconv.ParseUint(f.name, 10, 0) + if err == nil { + if err := e.EncodeUint(uintKey); err != nil { + return err + } + uintKeyEncoded = true + } + } + // If useUIntStructKeysFlag is not set or if the key was not encoded as a uint, encode it as a string + if !uintKeyEncoded { + if err := e.EncodeString(f.name); err != nil { + return err + } } if err := f.EncodeValue(e, strct); err != nil { return err diff --git a/example_test.go b/example_test.go index 1408a59..b68b74c 100644 --- a/example_test.go +++ b/example_test.go @@ -26,6 +26,41 @@ func ExampleMarshal() { // Output: bar } +func ExampleMarshal_intKeys() { + type Item struct { + Foo string `msgpack:"100"` + Bar string `msgpack:"200"` + Baz string `msgpack:"not_int_key"` + } + + buf := new(bytes.Buffer) + enc := msgpack.NewEncoder(buf) + enc.UseUIntStructKeys(true) + err := enc.Encode(&Item{ + Foo: "foo", + Bar: "bar", + Baz: "baz", + }) + if err != nil { + panic(err) + } + + var item Item + dec := msgpack.NewDecoder(buf) + dec.UseUIntStructKeys(true) + err = dec.Decode(&item) + if err != nil { + panic(err) + } + fmt.Println(item.Foo) + fmt.Println(item.Bar) + fmt.Println(item.Baz) + // Output: + // foo + // bar + // baz +} + func ExampleMarshal_mapStringInterface() { in := map[string]interface{}{"foo": 1, "hello": "world"} b, err := msgpack.Marshal(in) diff --git a/msgpcode/msgpcode.go b/msgpcode/msgpcode.go index e35389c..6a2046d 100644 --- a/msgpcode/msgpcode.go +++ b/msgpcode/msgpcode.go @@ -86,3 +86,7 @@ func IsFixedExt(c byte) bool { func IsExt(c byte) bool { return IsFixedExt(c) || c == Ext8 || c == Ext16 || c == Ext32 } + +func IsUint(c byte) bool { + return c <= PosFixedNumHigh || c == Uint8 || c == Uint16 || c == Uint32 || c == Uint64 +}