diff --git a/json/codec.go b/json/codec.go index 77fe264..e4b6ab1 100644 --- a/json/codec.go +++ b/json/codec.go @@ -10,6 +10,7 @@ import ( "sort" "strconv" "strings" + "sync" "sync/atomic" "time" "unicode" @@ -32,13 +33,29 @@ type codec struct { type encoder struct { flags AppendFlags - // ptrDepth tracks the depth of pointer cycles, when it reaches the value + // refDepth tracks the depth of pointer cycles, when it reaches the value // of startDetectingCyclesAfter, the ptrSeen map is allocated and the // encoder starts tracking pointers it has seen as an attempt to detect // whether it has entered a pointer cycle and needs to error before the // goroutine runs out of stack space. - ptrDepth uint32 - ptrSeen map[unsafe.Pointer]struct{} + // + // This relies on encoder being passed as a value, + // and encoder methods calling each other in a traditional stack + // (not using trampoline techniques), + // since refDepth is never decremented. + refDepth uint32 + refSeen cycleMap +} + +type cycleKey struct { + ptr unsafe.Pointer + len int // 0 for pointers or maps; length for slices or array pointers. +} + +type cycleMap map[cycleKey]struct{} + +var cycleMapPool = sync.Pool{ + New: func() any { return make(cycleMap) }, } type decoder struct { @@ -63,6 +80,17 @@ type ( // lookup time for simple types like bool, int, etc.. 
var cache atomic.Pointer[map[unsafe.Pointer]codec] +func cachedCodec(t reflect.Type) codec { + cache := cacheLoad() + + c, found := cache[typeid(t)] + if !found { + c = constructCachedCodec(t, cache) + } + + return c +} + func cacheLoad() map[unsafe.Pointer]codec { p := cache.Load() if p == nil { diff --git a/json/encode.go b/json/encode.go index 2a6da07..a65dc41 100644 --- a/json/encode.go +++ b/json/encode.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "reflect" + "runtime" "sort" "strconv" "sync" @@ -17,6 +18,23 @@ import ( const hex = "0123456789abcdef" +func (e encoder) appendAny(b []byte, x any) ([]byte, error) { + if x == nil { + // Special case for nil values because it makes the rest of the code + // simpler to assume that it won't be seeing nil pointers. + return e.encodeNull(b, nil) + } + + t := reflect.TypeOf(x) + p := (*iface)(unsafe.Pointer(&x)).ptr + c := cachedCodec(t) + + b, err := c.encode(e, b, p) + runtime.KeepAlive(x) + + return b, err +} + func (e encoder) encodeNull(b []byte, p unsafe.Pointer) ([]byte, error) { return append(b, "null"...), nil } @@ -241,7 +259,7 @@ func (e encoder) encodeToString(b []byte, p unsafe.Pointer, encode encodeFunc) ( func (e encoder) encodeBytes(b []byte, p unsafe.Pointer) ([]byte, error) { v := *(*[]byte)(p) if v == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } n := base64.StdEncoding.EncodedLen(len(v)) + 2 @@ -278,6 +296,15 @@ func (e encoder) encodeTime(b []byte, p unsafe.Pointer) ([]byte, error) { } func (e encoder) encodeArray(b []byte, p unsafe.Pointer, n int, size uintptr, t reflect.Type, encode encodeFunc) ([]byte, error) { + if shouldCheckForRefCycle(&e) { + key := cycleKey{ptr: p} + if hasRefCycle(&e, key) { + return b, refCycleError(t, p) + } + + defer freeRefCycleInfo(&e, key) + } + start := len(b) var err error b = append(b, '[') @@ -299,7 +326,7 @@ func (e encoder) encodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect s := (*slice)(p) if s.data == nil && s.len == 0 && 
s.cap == 0 { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } return e.encodeArray(b, s.data, s.len, size, t, encode) @@ -308,7 +335,20 @@ func (e encoder) encodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect func (e encoder) encodeMap(b []byte, p unsafe.Pointer, t reflect.Type, encodeKey, encodeValue encodeFunc, sortKeys sortFunc) ([]byte, error) { m := reflect.NewAt(t, p).Elem() if m.IsNil() { - return append(b, "null"...), nil + return e.encodeNull(b, nil) + } + + // hasRefCycle/freeRefCycleInfo expect the map header pointer itself, + // rather than a pointer to the header. + p = *(*unsafe.Pointer)(p) + + if shouldCheckForRefCycle(&e) { + key := cycleKey{ptr: p} + if hasRefCycle(&e, key) { + return b, refCycleError(t, p) + } + + defer freeRefCycleInfo(&e, key) } keys := m.MapKeys() @@ -317,9 +357,10 @@ func (e encoder) encodeMap(b []byte, p unsafe.Pointer, t reflect.Type, encodeKey } start := len(b) - var err error b = append(b, '{') + var err error + for i, k := range keys { v := m.MapIndex(k) @@ -363,7 +404,20 @@ var mapslicePool = sync.Pool{ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]any)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) + } + + // hasRefCycle/freeRefCycleInfo expect the map header pointer itself, + // rather than a pointer to the header. 
+ p = *(*unsafe.Pointer)(p) + + if shouldCheckForRefCycle(&e) { + key := cycleKey{ptr: p} + if hasRefCycle(&e, key) { + return b, refCycleError(mapStringInterfaceType, p) + } + + defer freeRefCycleInfo(&e, key) } if (e.flags & SortMapKeys) == 0 { @@ -383,7 +437,7 @@ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e b, _ = e.encodeString(b, unsafe.Pointer(&k)) b = append(b, ':') - b, err = Append(b, v, e.flags) + b, err = e.appendAny(b, v) if err != nil { return b, err } @@ -406,9 +460,10 @@ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e sort.Sort(s) start := len(b) - var err error b = append(b, '{') + var err error + for i, elem := range s.elements { if i != 0 { b = append(b, ',') @@ -417,7 +472,7 @@ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e b, _ = e.encodeString(b, unsafe.Pointer(&elem.key)) b = append(b, ':') - b, err = Append(b, elem.val, e.flags) + b, err = e.appendAny(b, elem.val) if err != nil { break } @@ -441,7 +496,7 @@ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]RawMessage)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } if (e.flags & SortMapKeys) == 0 { @@ -520,7 +575,7 @@ func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte, func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]string)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } if (e.flags & SortMapKeys) == 0 { @@ -586,7 +641,7 @@ func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, erro func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string][]string)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } 
stringSize := unsafe.Sizeof("") @@ -667,7 +722,7 @@ func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, func (e encoder) encodeMapStringBool(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]bool)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } if (e.flags & SortMapKeys) == 0 { @@ -794,30 +849,31 @@ func (e encoder) encodeEmbeddedStructPointer(b []byte, p unsafe.Pointer, t refle } func (e encoder) encodePointer(b []byte, p unsafe.Pointer, t reflect.Type, encode encodeFunc) ([]byte, error) { - if p = *(*unsafe.Pointer)(p); p != nil { - if e.ptrDepth++; e.ptrDepth >= startDetectingCyclesAfter { - if _, seen := e.ptrSeen[p]; seen { - // TODO: reconstruct the reflect.Value from p + t so we can set - // the erorr's Value field? - return b, &UnsupportedValueError{Str: fmt.Sprintf("encountered a cycle via %s", t)} - } - if e.ptrSeen == nil { - e.ptrSeen = make(map[unsafe.Pointer]struct{}) - } - e.ptrSeen[p] = struct{}{} - defer delete(e.ptrSeen, p) + // p was a pointer to the actual user data pointer: + // dereference it to operate on the user data pointer. 
+ p = *(*unsafe.Pointer)(p) + if p == nil { + return e.encodeNull(b, nil) + } + + if shouldCheckForRefCycle(&e) { + key := cycleKey{ptr: p} + if hasRefCycle(&e, key) { + return b, refCycleError(t, p) } - return encode(e, b, p) + + defer freeRefCycleInfo(&e, key) } - return e.encodeNull(b, nil) + + return encode(e, b, p) } func (e encoder) encodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) { - return Append(b, *(*any)(p), e.flags) + return e.appendAny(b, *(*any)(p)) } func (e encoder) encodeMaybeEmptyInterface(b []byte, p unsafe.Pointer, t reflect.Type) ([]byte, error) { - return Append(b, reflect.NewAt(t, p).Elem().Interface(), e.flags) + return e.appendAny(b, reflect.NewAt(t, p).Elem().Interface()) } func (e encoder) encodeUnsupportedTypeError(b []byte, p unsafe.Pointer, t reflect.Type) ([]byte, error) { @@ -828,7 +884,7 @@ func (e encoder) encodeRawMessage(b []byte, p unsafe.Pointer) ([]byte, error) { v := *(*RawMessage)(p) if v == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } var s []byte @@ -862,7 +918,7 @@ func (e encoder) encodeJSONMarshaler(b []byte, p unsafe.Pointer, t reflect.Type, switch v.Kind() { case reflect.Ptr, reflect.Interface: if v.IsNil() { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } } @@ -968,3 +1024,64 @@ func appendCompactEscapeHTML(dst []byte, src []byte) []byte { return dst } + +// shouldCheckForRefCycle determines whether checking for reference cycles +// is reasonable to do at this time. +// +// When true, hasRefCycle should be called and a detected cycle reported as an error, +// and then a deferred call to freeRefCycleInfo should be made. +// +// This should only be called from encoder methods that are possible points +// that could directly contribute to a reference cycle. +func shouldCheckForRefCycle(e *encoder) bool { + // Note: do not combine this with hasRefCycle, + // because hasRefCycle is too large to be inlined, + // and a non-inlined depth check leads to ~5%+ benchmark degradation. 
+ e.refDepth++ + return e.refDepth >= startDetectingCyclesAfter +} + +// refCycleError constructs an [UnsupportedValueError]. +func refCycleError(t reflect.Type, p unsafe.Pointer) error { + v := reflect.NewAt(t, p) + return &UnsupportedValueError{ + Value: v, + Str: fmt.Sprintf("encountered a cycle via %s", t), + } +} + +// hasRefCycle reports whether key was already seen (a reference cycle was detected); otherwise it records key as seen. +// The data pointer passed in via key should be equivalent to one of: +// +// - A normal Go pointer, e.g. `unsafe.Pointer(&T)` +// - The pointer to a map header, e.g. `*(*unsafe.Pointer)(&map[K]V)` +// +// Many [encoder] methods accept a pointer-to-a-pointer, +// and so those may need to be dereferenced in order to safely pass them here. +func hasRefCycle(e *encoder, key cycleKey) bool { + _, seen := e.refSeen[key] + if seen { + return true + } + + if e.refSeen == nil { + e.refSeen = cycleMapPool.Get().(cycleMap) + } + + e.refSeen[key] = struct{}{} + + return false +} + +// freeRefCycleInfo performs the cleanup operation for [hasRefCycle]. +// key must be the same value passed into a prior call to hasRefCycle. +func freeRefCycleInfo(e *encoder, key cycleKey) { + delete(e.refSeen, key) + if len(e.refSeen) == 0 { + // There are no remaining elements, + // so we can release this map for later reuse. + m := e.refSeen + e.refSeen = nil + cycleMapPool.Put(m) + } +} diff --git a/json/json.go b/json/json.go index 028fd1f..a3138f8 100644 --- a/json/json.go +++ b/json/json.go @@ -6,7 +6,6 @@ import ( "io" "math/bits" "reflect" - "runtime" "sync" "unsafe" ) @@ -194,25 +193,9 @@ func (k Kind) Class() Kind { return Kind(1 << uint(bits.Len(uint(k))-1)) } // Append acts like Marshal but appends the json representation to b instead of // always reallocating a new slice. func Append(b []byte, x any, flags AppendFlags) ([]byte, error) { - if x == nil { - // Special case for nil values because it makes the rest of the code - // simpler to assume that it won't be seeing nil pointers. 
- return append(b, "null"...), nil - } - - t := reflect.TypeOf(x) - p := (*iface)(unsafe.Pointer(&x)).ptr - - cache := cacheLoad() - c, found := cache[typeid(t)] - - if !found { - c = constructCachedCodec(t, cache) - } + e := encoder{flags: flags} - b, err := c.encode(encoder{flags: flags}, b, p) - runtime.KeepAlive(x) - return b, err + return e.appendAny(b, x) } // Escape is a convenience helper to construct an escaped JSON string from s. @@ -330,14 +313,9 @@ func Parse(b []byte, x any, flags ParseFlags) ([]byte, error) { } return r, &InvalidUnmarshalError{Type: t} } - t = t.Elem() - cache := cacheLoad() - c, found := cache[typeid(t)] - - if !found { - c = constructCachedCodec(t, cache) - } + t = t.Elem() + c := cachedCodec(t) r, err := c.decode(d, b, p) return skipSpaces(r), err