Skip to content

Commit 95bfa30

Browse files
feat: IPNet type (#236)
* feat: IPNet type
1 parent 12b797a commit 95bfa30

File tree

2 files changed

+134
-10
lines changed

2 files changed

+134
-10
lines changed

scw/custom_types.go

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@ import (
55
"encoding/json"
66
"fmt"
77
"io"
8+
"net"
89
"time"
9-
10-
"github.com/scaleway/scaleway-sdk-go/internal/errors"
1110
)
1211

1312
// ServiceInfo contains API metadata
@@ -123,7 +122,7 @@ type TimeSeriesPoint struct {
123122
Value float32
124123
}
125124

126-
func (tsp *TimeSeriesPoint) MarshalJSON() ([]byte, error) {
125+
func (tsp TimeSeriesPoint) MarshalJSON() ([]byte, error) {
127126
timestamp := tsp.Timestamp.Format(time.RFC3339)
128127
value, err := json.Marshal(tsp.Value)
129128
if err != nil {
@@ -142,25 +141,66 @@ func (tsp *TimeSeriesPoint) UnmarshalJSON(b []byte) error {
142141
}
143142

144143
if len(point) != 2 {
145-
return errors.New("invalid point array")
144+
return fmt.Errorf("invalid point array")
146145
}
147146

148147
strTimestamp, isStrTimestamp := point[0].(string)
149148
if !isStrTimestamp {
150-
return errors.New("%s timestamp is not a string in RFC 3339 format", point[0])
149+
return fmt.Errorf("%s timestamp is not a string in RFC 3339 format", point[0])
151150
}
152151
timestamp, err := time.Parse(time.RFC3339, strTimestamp)
153152
if err != nil {
154-
return errors.New("%s timestamp is not in RFC 3339 format", point[0])
153+
return fmt.Errorf("%s timestamp is not in RFC 3339 format", point[0])
155154
}
156155
tsp.Timestamp = timestamp
157156

158157
// By default, JSON unmarshal a float in float64 but the TimeSeriesPoint is a float32 value.
159158
value, isValue := point[1].(float64)
160159
if !isValue {
161-
return errors.New("%s is not a valid float32 value", point[1])
160+
return fmt.Errorf("%s is not a valid float32 value", point[1])
162161
}
163162
tsp.Value = float32(value)
164163

165164
return nil
166165
}
166+
167+
// IPNet inherits net.IPNet and represents an IP network.
168+
type IPNet struct {
169+
net.IPNet
170+
}
171+
172+
func (n IPNet) MarshalJSON() ([]byte, error) {
173+
value := n.String()
174+
if value == "<nil>" {
175+
value = ""
176+
}
177+
return []byte(`"` + value + `"`), nil
178+
}
179+
180+
func (n *IPNet) UnmarshalJSON(b []byte) error {
181+
var str string
182+
183+
err := json.Unmarshal(b, &str)
184+
if err != nil {
185+
return err
186+
}
187+
if str == "" {
188+
*n = IPNet{}
189+
return nil
190+
}
191+
192+
switch ip := net.ParseIP(str); {
193+
case ip.To4() != nil:
194+
str += "/32"
195+
case ip.To16() != nil:
196+
str += "/128"
197+
}
198+
199+
_, value, err := net.ParseCIDR(str)
200+
if err != nil {
201+
return err
202+
}
203+
n.IPNet = *value
204+
205+
return nil
206+
}

scw/custom_types_test.go

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ package scw
22

33
import (
44
"encoding/json"
5+
"fmt"
56
"io/ioutil"
7+
"net"
68
"testing"
79
"time"
810

9-
"github.com/scaleway/scaleway-sdk-go/internal/errors"
1011
"github.com/scaleway/scaleway-sdk-go/internal/testhelpers"
1112
)
1213

@@ -67,7 +68,7 @@ func TestTimeSeries_MarshallJSON(t *testing.T) {
6768
}
6869
}
6970

70-
func TestTimeSeries_UnmarshallJSON(t *testing.T) {
71+
func TestTimeSeries_UnmarshalJSON(t *testing.T) {
7172
cases := []struct {
7273
name string
7374
json string
@@ -108,7 +109,7 @@ func TestTimeSeries_UnmarshallJSON(t *testing.T) {
108109
{
109110
name: "with timestamp error",
110111
json: `{"name":"cpu_usage","points":[["2019/08/08T15-00-00Z",0.2]]}`,
111-
err: errors.New("2019/08/08T15-00-00Z timestamp is not in RFC 3339 format"),
112+
err: fmt.Errorf("2019/08/08T15-00-00Z timestamp is not in RFC 3339 format"),
112113
},
113114
}
114115

@@ -166,3 +167,86 @@ func TestFile_UnmarshalJSON(t *testing.T) {
166167
content: []byte("\x00\x00\x00\n"),
167168
}))
168169
}
170+
171+
func TestIPNet_MarshallJSON(t *testing.T) {
172+
cases := []struct {
173+
name string
174+
ipRange IPNet
175+
want string
176+
err error
177+
}{
178+
{
179+
name: "ip",
180+
ipRange: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(32, 32)}},
181+
want: `"42.42.42.42/32"`,
182+
},
183+
{
184+
name: "network",
185+
ipRange: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(16, 32)}},
186+
want: `"42.42.42.42/16"`,
187+
},
188+
}
189+
190+
for _, c := range cases {
191+
t.Run(c.name, func(t *testing.T) {
192+
got, err := json.Marshal(c.ipRange)
193+
194+
testhelpers.Equals(t, c.err, err)
195+
if c.err == nil {
196+
testhelpers.Equals(t, c.want, string(got))
197+
}
198+
})
199+
}
200+
}
201+
202+
func TestIPNet_UnmarshalJSON(t *testing.T) {
203+
cases := []struct {
204+
name string
205+
json string
206+
want IPNet
207+
err string
208+
}{
209+
{
210+
name: "IPv4 with CIDR",
211+
json: `"42.42.42.42/32"`,
212+
want: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(32, 32)}},
213+
},
214+
{
215+
name: "IPv4 with network",
216+
json: `"192.0.2.1/24"`,
217+
want: IPNet{IPNet: net.IPNet{IP: net.IPv4(192, 0, 2, 0), Mask: net.CIDRMask(24, 32)}},
218+
},
219+
{
220+
name: "IPv6 with network",
221+
json: `"2001:db8:abcd:8000::/50"`,
222+
want: IPNet{IPNet: net.IPNet{IP: net.ParseIP("2001:db8:abcd:8000::"), Mask: net.CIDRMask(50, 128)}},
223+
},
224+
{
225+
name: "IPv4 alone",
226+
json: `"42.42.42.42"`,
227+
want: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(32, 32)}},
228+
},
229+
{
230+
name: "IPv6 alone",
231+
json: `"2001:db8:abcd:8000::"`,
232+
want: IPNet{IPNet: net.IPNet{IP: net.ParseIP("2001:db8:abcd:8000::"), Mask: net.CIDRMask(128, 128)}},
233+
},
234+
{
235+
name: "invalid CIDR error",
236+
json: `"invalidvalue"`,
237+
err: "invalid CIDR address: invalidvalue",
238+
},
239+
}
240+
241+
for _, c := range cases {
242+
t.Run(c.name, func(t *testing.T) {
243+
ipNet := &IPNet{}
244+
err := json.Unmarshal([]byte(c.json), ipNet)
245+
if err != nil {
246+
testhelpers.Equals(t, c.err, err.Error())
247+
}
248+
249+
testhelpers.Equals(t, c.want.String(), ipNet.String())
250+
})
251+
}
252+
}

0 commit comments

Comments
 (0)