Skip to content

Commit df6d4a0

Browse files
committed
policyfile: support SSH check period 'always'
This also renames the 'Duration' type to 'SSHCheckPeriod' to clarify that it is used exclusively for the 'CheckPeriod' on 'ACLSSH'. Updates tailscale/tailscale-client-go#128 Signed-off-by: Percy Wegmann <[email protected]>
1 parent 299d9b1 commit df6d4a0

File tree

4 files changed

+96
-34
lines changed

4 files changed

+96
-34
lines changed

client.go

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -369,31 +369,6 @@ func ErrorData(err error) []APIErrorData {
369369
return nil
370370
}
371371

372-
// Duration wraps a [time.Duration], allowing it to be JSON marshalled as a string like "20h" rather than
373-
// a numeric value.
374-
type Duration time.Duration
375-
376-
func (d Duration) String() string {
377-
return time.Duration(d).String()
378-
}
379-
380-
func (d Duration) MarshalText() ([]byte, error) {
381-
return []byte(d.String()), nil
382-
}
383-
384-
func (d *Duration) UnmarshalText(b []byte) error {
385-
text := string(b)
386-
if text == "" {
387-
text = "0s"
388-
}
389-
pd, err := time.ParseDuration(text)
390-
if err != nil {
391-
return err
392-
}
393-
*d = Duration(pd)
394-
return nil
395-
}
396-
397372
// PointerTo returns a pointer to the given value.
398373
// Pointers are used in PATCH requests to distinguish between specified and unspecified values.
399374
func PointerTo[T any](value T) *T {

policyfile.go

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,48 @@ import (
77
"context"
88
"fmt"
99
"net/http"
10+
"time"
1011
)
1112

13+
// CheckPeriodAlways is a magic value corresponding to the [SSHCheckPeriod]
14+
// "always". It indicates that re-authorization is required on every login.
15+
const CheckPeriodAlways SSHCheckPeriod = -1
16+
17+
const checkPeriodAlways = "always"
18+
19+
// SSHCheckPeriod wraps a [time.Duration], allowing it to be JSON marshalled as
20+
// a string like "20h" rather than a numeric value. It also supports the
21+
// special value "always", which forces a check on every connection.
22+
type SSHCheckPeriod time.Duration
23+
24+
func (d SSHCheckPeriod) String() string {
25+
return time.Duration(d).String()
26+
}
27+
28+
func (d SSHCheckPeriod) MarshalText() ([]byte, error) {
29+
if d == CheckPeriodAlways {
30+
return []byte(checkPeriodAlways), nil
31+
}
32+
return []byte(d.String()), nil
33+
}
34+
35+
func (d *SSHCheckPeriod) UnmarshalText(b []byte) error {
36+
text := string(b)
37+
if text == checkPeriodAlways {
38+
*d = SSHCheckPeriod(CheckPeriodAlways)
39+
return nil
40+
}
41+
if text == "" {
42+
text = "0s"
43+
}
44+
pd, err := time.ParseDuration(text)
45+
if err != nil {
46+
return err
47+
}
48+
*d = SSHCheckPeriod(pd)
49+
return nil
50+
}
51+
1252
// PolicyFileResource provides access to https://tailscale.com/api#tag/policyfile.
1353
type PolicyFileResource struct {
1454
*Client
@@ -98,13 +138,13 @@ type ACLDERPNode struct {
98138
}
99139

100140
type ACLSSH struct {
101-
Action string `json:"action,omitempty" hujson:"Action,omitempty"`
102-
Users []string `json:"users,omitempty" hujson:"Users,omitempty"`
103-
Source []string `json:"src,omitempty" hujson:"Src,omitempty"`
104-
Destination []string `json:"dst,omitempty" hujson:"Dst,omitempty"`
105-
CheckPeriod Duration `json:"checkPeriod,omitempty" hujson:"CheckPeriod,omitempty"`
106-
Recorder []string `json:"recorder,omitempty" hujson:"Recorder,omitempty"`
107-
EnforceRecorder bool `json:"enforceRecorder,omitempty" hujson:"EnforceRecorder,omitempty"`
141+
Action string `json:"action,omitempty" hujson:"Action,omitempty"`
142+
Users []string `json:"users,omitempty" hujson:"Users,omitempty"`
143+
Source []string `json:"src,omitempty" hujson:"Src,omitempty"`
144+
Destination []string `json:"dst,omitempty" hujson:"Dst,omitempty"`
145+
CheckPeriod SSHCheckPeriod `json:"checkPeriod,omitempty" hujson:"CheckPeriod,omitempty"`
146+
Recorder []string `json:"recorder,omitempty" hujson:"Recorder,omitempty"`
147+
EnforceRecorder bool `json:"enforceRecorder,omitempty" hujson:"EnforceRecorder,omitempty"`
108148
}
109149

110150
type NodeAttrGrant struct {

policyfile_test.go

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func TestACL_Unmarshal(t *testing.T) {
114114
Source: []string{"tag:logging"},
115115
Destination: []string{"tag:prod"},
116116
Users: []string{"root", "autogroup:nonroot"},
117-
CheckPeriod: Duration(time.Hour * 20),
117+
CheckPeriod: SSHCheckPeriod(time.Hour * 20),
118118
},
119119
},
120120
},
@@ -194,7 +194,14 @@ func TestACL_Unmarshal(t *testing.T) {
194194
Source: []string{"tag:logging"},
195195
Destination: []string{"tag:prod"},
196196
Users: []string{"root", "autogroup:nonroot"},
197-
CheckPeriod: Duration(time.Hour * 20),
197+
CheckPeriod: SSHCheckPeriod(time.Hour * 20),
198+
},
199+
{
200+
Action: "accept",
201+
Source: []string{"tag:logging2"},
202+
Destination: []string{"tag:prod2"},
203+
Users: []string{"root", "autogroup:nonroot"},
204+
CheckPeriod: CheckPeriodAlways,
198205
},
199206
},
200207
Tests: []ACLTest{
@@ -275,6 +282,7 @@ func TestClient_SetACL(t *testing.T) {
275282
assert.NoError(t, json.Unmarshal(server.Body.Bytes(), &actualACL))
276283
assert.EqualValues(t, expectedACL, actualACL)
277284
}
285+
278286
func TestClient_SetACL_HuJSON(t *testing.T) {
279287
t.Parallel()
280288

@@ -384,3 +392,35 @@ func TestClient_RawACL(t *testing.T) {
384392
assert.EqualValues(t, "application/hujson", server.Header.Get("Accept"))
385393
assert.EqualValues(t, "/api/v2/tailnet/example.com/acl", server.Path)
386394
}
395+
396+
func TestSSHCheckPeriod(t *testing.T) {
397+
testCases := []struct {
398+
inStr string
399+
period SSHCheckPeriod
400+
outStr string
401+
}{
402+
{"1h", SSHCheckPeriod(1 * time.Hour), "1h0m0s"},
403+
{"", 0, "0s"},
404+
{checkPeriodAlways, CheckPeriodAlways, checkPeriodAlways},
405+
}
406+
407+
for _, tc := range testCases {
408+
t.Run(tc.inStr, func(t *testing.T) {
409+
var got SSHCheckPeriod
410+
if err := got.UnmarshalText([]byte(tc.inStr)); err != nil {
411+
t.Fatalf("failed to marshal: %s", err)
412+
}
413+
if got != tc.period {
414+
t.Fatalf("want period %s, got period %s", tc.period, got)
415+
}
416+
gotBytes, err := got.MarshalText()
417+
if err != nil {
418+
t.Fatalf("failed to marshal: %s", err)
419+
}
420+
gotStr := string(gotBytes)
421+
if gotStr != tc.outStr {
422+
t.Fatalf("want string %q, got string %q", tc.outStr, gotStr)
423+
}
424+
})
425+
}
426+
}

testdata/acl.hujson

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,12 @@
5858
"users": ["root", "autogroup:nonroot"],
5959
"checkPeriod": "20h"
6060
},
61+
{
62+
"action": "accept",
63+
"src": ["tag:logging2"],
64+
"dst": ["tag:prod2"],
65+
"users": ["root", "autogroup:nonroot"],
66+
"checkPeriod": "always"
67+
},
6168
]
6269
}

0 commit comments

Comments
 (0)