Skip to content

Commit 352f848

Browse files
committed
Add NullRawMessage type
1 parent 67e7b41 commit 352f848

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

json.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package pqtype
2+
3+
import (
4+
"database/sql/driver"
5+
"encoding/json"
6+
"fmt"
7+
)
8+
9+
// NullRawMessage represents a json.RawMessage that may be null.
10+
// NullRawMessage implements the Scanner interface so
11+
// it can be used as a scan destination, similar to NullString.
12+
type NullRawMessage struct {
13+
RawMessage json.RawMessage
14+
Valid bool // Valid is true if RawMessage is not NULL
15+
}
16+
17+
// Scan implements the Scanner interface.
18+
func (n *NullRawMessage) Scan(src interface{}) error {
19+
if src == nil {
20+
n.Valid = false
21+
return nil
22+
}
23+
switch src := src.(type) {
24+
case []byte:
25+
srcCopy := make([]byte, len(src))
26+
copy(srcCopy, src)
27+
n.RawMessage, n.Valid = srcCopy, true
28+
default:
29+
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, []byte{})
30+
}
31+
return nil
32+
}
33+
34+
// Value implements the driver Valuer interface.
35+
func (n NullRawMessage) Value() (driver.Value, error) {
36+
if !n.Valid {
37+
return nil, nil
38+
}
39+
return n.RawMessage, nil
40+
}

tests/json_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package tests
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/google/go-cmp/cmp"
9+
"github.com/tabbed/pqtype"
10+
)
11+
12+
func TestJSONRawMessage(t *testing.T) {
13+
for _, payload := range []string{
14+
`{}`,
15+
`[]`,
16+
`1`,
17+
`1.2`,
18+
`"a"`,
19+
`true`,
20+
`false`,
21+
`{"foo": "bar"}`,
22+
} {
23+
payload = payload
24+
t.Run(payload, func(t *testing.T) {
25+
var n pqtype.NullRawMessage
26+
if err := db.QueryRow(fmt.Sprintf(`SELECT '%s'::json`, payload)).Scan(&n); err != nil {
27+
t.Fatal(err)
28+
}
29+
if diff := cmp.Diff(true, n.Valid); diff != "" {
30+
t.Errorf("json valid mismatch (-want +got):\n%s", diff)
31+
}
32+
if diff := cmp.Diff(string(json.RawMessage(payload)), string(n.RawMessage)); diff != "" {
33+
t.Errorf("json mismatch (-want +got):\n%s", diff)
34+
}
35+
if err := db.QueryRow(fmt.Sprintf(`SELECT '%s'::jsonb`, payload)).Scan(&n); err != nil {
36+
t.Fatal(err)
37+
}
38+
if diff := cmp.Diff(true, n.Valid); diff != "" {
39+
t.Errorf("jsonb valid mismatch (-want +got):\n%s", diff)
40+
}
41+
if diff := cmp.Diff(string(json.RawMessage(payload)), string(n.RawMessage)); diff != "" {
42+
t.Errorf("jsonb mismatch (-want +got):\n%s", diff)
43+
}
44+
})
45+
}
46+
t.Run("NULL", func(t *testing.T) {
47+
var n pqtype.NullRawMessage
48+
if err := db.QueryRow(`SELECT NULL::json`).Scan(&n); err != nil {
49+
t.Fatal(err)
50+
}
51+
if diff := cmp.Diff(false, n.Valid); diff != "" {
52+
t.Errorf("valid mismatch (-want +got):\n%s", diff)
53+
}
54+
if err := db.QueryRow(`SELECT NULL::jsonb`).Scan(&n); err != nil {
55+
t.Fatal(err)
56+
}
57+
if diff := cmp.Diff(false, n.Valid); diff != "" {
58+
t.Errorf("valid mismatch (-want +got):\n%s", diff)
59+
}
60+
})
61+
}

0 commit comments

Comments
 (0)