Skip to content

Commit 2ef68ae

Browse files
Extends Duration to support minutes
1 parent 319861e commit 2ef68ae

File tree

2 files changed

+120
-9
lines changed

2 files changed

+120
-9
lines changed

pkg/config/duration.go

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import (
44
"database/sql/driver"
55
"encoding/json"
66
"fmt"
7+
"regexp"
8+
"strconv"
9+
"strings"
710
"time"
811
)
912

@@ -59,25 +62,24 @@ func (d Duration) String() string {
5962

6063
// MarshalJSON implements the json.Marshaler interface.
6164
func (d Duration) MarshalJSON() ([]byte, error) {
62-
return json.Marshal(d.String())
65+
// json.Marshal to get a proper JSON string with quotes/escaping
66+
return json.Marshal(formatDuration(d.d))
6367
}
6468

6569
// UnmarshalJSON implements the json.Unmarshaler interface.
6670
func (d *Duration) UnmarshalJSON(input []byte) error {
6771
var txt string
68-
err := json.Unmarshal(input, &txt)
69-
if err != nil {
72+
if err := json.Unmarshal(input, &txt); err != nil {
7073
return err
7174
}
72-
v, err := time.ParseDuration(txt)
75+
76+
v, err := parseDuration(txt)
7377
if err != nil {
7478
return err
7579
}
80+
7681
*d, err = NewDuration(v)
77-
if err != nil {
78-
return err
79-
}
80-
return nil
82+
return err
8183
}
8284

8385
func (d *Duration) Scan(v any) (err error) {
@@ -102,14 +104,81 @@ func (d Duration) MarshalText() ([]byte, error) {
102104

103105
// UnmarshalText implements the text.Unmarshaler interface.
104106
func (d *Duration) UnmarshalText(input []byte) error {
105-
v, err := time.ParseDuration(string(input))
107+
v, err := parseDuration(string(input))
106108
if err != nil {
107109
return err
108110
}
111+
109112
pd, err := NewDuration(v)
110113
if err != nil {
111114
return err
112115
}
116+
113117
*d = pd
118+
114119
return nil
115120
}
121+
122+
func formatDuration(dur time.Duration) string {
123+
if dur == 0 {
124+
return "0s"
125+
}
126+
127+
var parts []string
128+
129+
days := dur / (24 * time.Hour)
130+
if days > 0 {
131+
parts = append(parts, fmt.Sprintf("%dd", days))
132+
dur %= 24 * time.Hour
133+
}
134+
135+
if days > 0 && dur == 0 {
136+
return strings.Join(parts, "")
137+
}
138+
139+
hours := dur / time.Hour
140+
minutes := (dur % time.Hour) / time.Minute
141+
seconds := (dur % time.Minute) / time.Second
142+
nanos := dur % time.Second
143+
144+
if days > 0 {
145+
return fmt.Sprintf("%dd%dh%dm%ds", days, hours, minutes, seconds)
146+
}
147+
148+
if hours > 0 {
149+
parts = append(parts, fmt.Sprintf("%dh", hours))
150+
}
151+
if minutes > 0 || (hours > 0) {
152+
parts = append(parts, fmt.Sprintf("%dm", minutes))
153+
}
154+
155+
if dur < time.Second {
156+
return dur.String()
157+
}
158+
159+
if nanos == 0 {
160+
if seconds > 0 || (minutes > 0) || (hours > 0) {
161+
parts = append(parts, fmt.Sprintf("%ds", seconds))
162+
}
163+
} else {
164+
parts = append(parts, fmt.Sprintf("%g", float64(seconds)+float64(nanos)/1e9)+"s")
165+
}
166+
167+
if len(parts) == 0 {
168+
return dur.String()
169+
}
170+
171+
return strings.Join(parts, "")
172+
}
173+
174+
func parseDuration(s string) (time.Duration, error) {
175+
if strings.ContainsAny(s, "dhms") {
176+
// Replace "d" with "h" and multiply the number by 24
177+
re := regexp.MustCompile(`(\d+)d`)
178+
s = re.ReplaceAllStringFunc(s, func(match string) string {
179+
val, _ := strconv.Atoi(strings.TrimSuffix(match, "d"))
180+
return fmt.Sprintf("%dh", val*24)
181+
})
182+
}
183+
return time.ParseDuration(s)
184+
}

pkg/config/duration_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ func TestDuration_MarshalJSON(t *testing.T) {
2020
{"one minute", *MustNewDuration(time.Minute), `"1m0s"`},
2121
{"one hour", *MustNewDuration(time.Hour), `"1h0m0s"`},
2222
{"one hour thirty minutes", *MustNewDuration(time.Hour + 30*time.Minute), `"1h30m0s"`},
23+
{"1 day", *MustNewDuration(24 * time.Hour), `"1d"`},
24+
{"2 days", *MustNewDuration(48 * time.Hour), `"2d"`},
25+
{"1 day 12 hours", *MustNewDuration(36 * time.Hour), `"1d12h0m0s"`},
26+
{"3 days 6 hours 30 minutes", *MustNewDuration(78*time.Hour + 30*time.Minute), `"3d6h30m0s"`},
27+
{"4 days 1 hour 0 minutes 10 seconds", *MustNewDuration(97*time.Hour + 10*time.Second), `"4d1h0m10s"`},
28+
{"4 days 0 hours 0 minutes 10 seconds", *MustNewDuration(4*24*time.Hour + 10*time.Second), `"4d0h0m10s"`},
29+
{"4 days 0 hours 22 minutes 10 seconds", *MustNewDuration(4*24*time.Hour + 22*time.Minute + 10*time.Second), `"4d0h22m10s"`},
2330
}
2431
for _, test := range tests {
2532
t.Run(test.name, func(t *testing.T) {
@@ -30,6 +37,41 @@ func TestDuration_MarshalJSON(t *testing.T) {
3037
}
3138
}
3239

40+
func TestDuration_UnmarshalJSON(t *testing.T) {
41+
tests := []struct {
42+
name string
43+
input string
44+
want Duration
45+
wantErr bool
46+
}{
47+
{"zero", `"0s"`, *MustNewDuration(0), false},
48+
{"one second", `"1s"`, *MustNewDuration(time.Second), false},
49+
{"one minute", `"1m0s"`, *MustNewDuration(time.Minute), false},
50+
{"one hour", `"1h0m0s"`, *MustNewDuration(time.Hour), false},
51+
{"one hour thirty minutes", `"1h30m0s"`, *MustNewDuration(time.Hour + 30*time.Minute), false},
52+
{"1 day", `"1d"`, *MustNewDuration(24 * time.Hour), false},
53+
{"2 days", `"2d"`, *MustNewDuration(48 * time.Hour), false},
54+
{"1 day 12 hours", `"1d12h0m0s"`, *MustNewDuration(36 * time.Hour), false},
55+
{"3 days 6 hours 30 minutes", `"3d6h30m0s"`, *MustNewDuration(78*time.Hour + 30*time.Minute), false},
56+
{"4 days 1 hour 0 minutes 10 seconds", `"4d1h0m10s"`, *MustNewDuration(97*time.Hour + 10*time.Second), false},
57+
{"4 days 0 hours 0 minutes 10 seconds", `"4d0h0m10s"`, *MustNewDuration(4*24*time.Hour + 10*time.Second), false},
58+
{"4 days 0 hours 22 minutes 10 seconds", `"4d0h22m10s"`, *MustNewDuration(4*24*time.Hour + 22*time.Minute + 10*time.Second), false},
59+
{"invalid", `"invalid"`, Duration{}, true},
60+
}
61+
for _, tt := range tests {
62+
t.Run(tt.name, func(t *testing.T) {
63+
var d Duration
64+
err := json.Unmarshal([]byte(tt.input), &d)
65+
if tt.wantErr {
66+
assert.Error(t, err)
67+
} else {
68+
assert.NoError(t, err)
69+
assert.Equal(t, tt.want, d)
70+
}
71+
})
72+
}
73+
}
74+
3375
func TestDuration_Scan_Value(t *testing.T) {
3476
t.Parallel()
3577

0 commit comments

Comments
 (0)