Skip to content

Commit c4fc250

Browse files
authored
Merge branch 'main' into PRIV-192
2 parents 8cb381f + 98903c7 commit c4fc250

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

pkg/sqlutil/interval.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package sqlutil
2+
3+
import (
4+
"database/sql/driver"
5+
"fmt"
6+
"time"
7+
)
8+
9+
// Interval represents a time.Duration stored as a Postgres interval type
10+
type Interval time.Duration
11+
12+
// NewInterval creates Interval for specified duration
13+
func NewInterval(d time.Duration) *Interval {
14+
i := new(Interval)
15+
*i = Interval(d)
16+
return i
17+
}
18+
19+
func (i Interval) Duration() time.Duration {
20+
return time.Duration(i)
21+
}
22+
23+
// MarshalText implements the text.Marshaler interface.
24+
func (i Interval) MarshalText() ([]byte, error) {
25+
return []byte(time.Duration(i).String()), nil
26+
}
27+
28+
// UnmarshalText implements the text.Unmarshaler interface.
29+
func (i *Interval) UnmarshalText(input []byte) error {
30+
v, err := time.ParseDuration(string(input))
31+
if err != nil {
32+
return err
33+
}
34+
*i = Interval(v)
35+
return nil
36+
}
37+
38+
func (i *Interval) Scan(v interface{}) error {
39+
if v == nil {
40+
*i = Interval(time.Duration(0))
41+
return nil
42+
}
43+
asInt64, is := v.(int64)
44+
if !is {
45+
return fmt.Errorf("models.Interval#Scan() wanted int64, got %T", v)
46+
}
47+
*i = Interval(time.Duration(asInt64) * time.Nanosecond)
48+
return nil
49+
}
50+
51+
func (i Interval) Value() (driver.Value, error) {
52+
return time.Duration(i).Nanoseconds(), nil
53+
}
54+
55+
func (i Interval) IsZero() bool {
56+
return time.Duration(i) == time.Duration(0)
57+
}

pkg/sqlutil/interval_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package sqlutil
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestNewInterval(t *testing.T) {
11+
t.Parallel()
12+
13+
duration := 33 * time.Second
14+
interval := NewInterval(duration)
15+
16+
require.Equal(t, duration, interval.Duration())
17+
}
18+
19+
func TestInterval_IsZero(t *testing.T) {
20+
t.Parallel()
21+
22+
i := NewInterval(0)
23+
require.NotNil(t, i)
24+
require.True(t, i.IsZero())
25+
26+
i = NewInterval(1)
27+
require.NotNil(t, i)
28+
require.False(t, i.IsZero())
29+
}
30+
31+
func TestInterval_Scan_Value(t *testing.T) {
32+
t.Parallel()
33+
34+
i := NewInterval(100)
35+
require.NotNil(t, i)
36+
37+
val, err := i.Value()
38+
require.NoError(t, err)
39+
40+
iNew := NewInterval(0)
41+
err = iNew.Scan(val)
42+
require.NoError(t, err)
43+
44+
require.Equal(t, i, iNew)
45+
}
46+
47+
func TestInterval_MarshalText_UnmarshalText(t *testing.T) {
48+
t.Parallel()
49+
50+
i := NewInterval(100)
51+
require.NotNil(t, i)
52+
53+
txt, err := i.MarshalText()
54+
require.NoError(t, err)
55+
56+
iNew := NewInterval(0)
57+
err = iNew.UnmarshalText(txt)
58+
require.NoError(t, err)
59+
60+
require.Equal(t, i, iNew)
61+
}

0 commit comments

Comments
 (0)