Skip to content

Commit 1b09b11

Browse files
committed
Use require and assert in a few more Nebula test functions
1 parent 74d30d9 commit 1b09b11

File tree

1 file changed

+32
-58
lines changed

1 file changed

+32
-58
lines changed

authority/provisioner/nebula_test.go

Lines changed: 32 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,15 @@ import (
88
"crypto/elliptic"
99
"crypto/rand"
1010
"crypto/x509"
11-
"fmt"
1211
"net"
1312
"net/url"
1413
"reflect"
1514
"strings"
1615
"testing"
1716
"time"
1817

19-
"github.com/google/go-cmp/cmp"
2018
"github.com/slackhq/nebula/cert"
19+
"github.com/stretchr/testify/assert"
2120
"github.com/stretchr/testify/require"
2221
"golang.org/x/crypto/ssh"
2322

@@ -30,9 +29,7 @@ import (
3029
func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
3130
t.Helper()
3231
ip, ipNet, err := net.ParseCIDR(s)
33-
if err != nil {
34-
t.Fatal(err)
35-
}
32+
require.NoError(t, err)
3633
if ip = ip.To4(); ip == nil {
3734
t.Fatalf("nebula only supports ipv4, have %s", s)
3835
}
@@ -43,9 +40,7 @@ func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
4340
func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
4441
t.Helper()
4542
pub, priv, err := ed25519.GenerateKey(rand.Reader)
46-
if err != nil {
47-
t.Fatal(err)
48-
}
43+
require.NoError(t, err)
4944
nc := &cert.NebulaCertificate{
5045
Details: cert.NebulaCertificateDetails{
5146
Name: "TestCA",
@@ -61,9 +56,9 @@ func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
6156
Curve: cert.Curve_CURVE25519,
6257
},
6358
}
64-
if err := nc.Sign(cert.Curve_CURVE25519, priv); err != nil {
65-
t.Fatal(err)
66-
}
59+
err = nc.Sign(cert.Curve_CURVE25519, priv)
60+
require.NoError(t, err)
61+
6762
return nc, priv
6863
}
6964

@@ -99,14 +94,10 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
9994
t.Helper()
10095

10196
pub, priv, err := x25519.GenerateKey(rand.Reader)
102-
if err != nil {
103-
t.Fatal(err)
104-
}
97+
require.NoError(t, err)
10598

10699
issuer, err := ca.Sha256Sum()
107-
if err != nil {
108-
t.Fatal(err)
109-
}
100+
require.NoError(t, err)
110101

111102
invertedGroups := make(map[string]struct{}, len(groups))
112103
for _, name := range groups {
@@ -130,9 +121,8 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
130121
},
131122
}
132123

133-
if err := nc.Sign(cert.Curve_CURVE25519, signer); err != nil {
134-
t.Fatal(err)
135-
}
124+
err = nc.Sign(cert.Curve_CURVE25519, signer)
125+
require.NoError(t, err)
136126

137127
return nc, priv
138128
}
@@ -184,9 +174,7 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
184174

185175
nc, signer := mustNebulaCA(t)
186176
ncPem, err := nc.MarshalToPEM()
187-
if err != nil {
188-
t.Fatal(err)
189-
}
177+
require.NoError(t, err)
190178
bTrue := true
191179
p := &Nebula{
192180
Type: TypeNebula.String(),
@@ -196,12 +184,11 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
196184
EnableSSHCA: &bTrue,
197185
},
198186
}
199-
if err := p.Init(Config{
187+
err = p.Init(Config{
200188
Claims: globalProvisionerClaims,
201189
Audiences: testAudiences,
202-
}); err != nil {
203-
t.Fatal(err)
204-
}
190+
})
191+
require.NoError(t, err)
205192

206193
return p, nc, signer
207194
}
@@ -310,9 +297,7 @@ func mustNebulaSSHToken(t *testing.T, sub, iss, aud string, iat time.Time, opts
310297
func TestNebula_Init(t *testing.T) {
311298
nc, _ := mustNebulaCA(t)
312299
ncPem, err := nc.MarshalToPEM()
313-
if err != nil {
314-
t.Fatal(err)
315-
}
300+
require.NoError(t, err)
316301

317302
cfg := Config{
318303
Claims: globalProvisionerClaims,
@@ -416,9 +401,7 @@ func TestNebula_GetTokenID(t *testing.T) {
416401
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
417402
t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv, jose.XEdDSA)
418403
_, claims, err := parseToken(t1)
419-
if err != nil {
420-
t.Fatal(err)
421-
}
404+
require.NoError(t, err)
422405

423406
type args struct {
424407
token string
@@ -838,13 +821,9 @@ func TestNebula_authorizeToken(t *testing.T) {
838821

839822
// Not a nebula token
840823
jwk, err := generateJSONWebKey()
841-
if err != nil {
842-
t.Fatal(err)
843-
}
824+
require.NoError(t, err)
844825
simpleToken, err := generateSimpleToken("iss", "aud", jwk)
845-
if err != nil {
846-
t.Fatal(err)
847-
}
826+
require.NoError(t, err)
848827

849828
// Provisioner with a different CA
850829
p2, _, _ := mustNebulaProvisioner(t)
@@ -911,22 +890,20 @@ func TestNebula_authorizeToken(t *testing.T) {
911890
for _, tt := range tests {
912891
t.Run(tt.name, func(t *testing.T) {
913892
got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences)
914-
if (err != nil) != tt.wantErr {
915-
t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr)
893+
if tt.wantErr {
894+
assert.Error(t, err)
895+
assert.Nil(t, got)
896+
assert.Nil(t, got1)
916897
return
917898
}
918-
if !reflect.DeepEqual(got, tt.want) {
919-
t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want)
920-
t.Error(cmp.Equal(got, tt.want))
921-
}
922899

923900
if got1 != nil && tt.want1 != nil {
924901
tt.want1.ID = got1.ID
925902
}
926903

927-
if !reflect.DeepEqual(got1, tt.want1) {
928-
t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1)
929-
}
904+
assert.NoError(t, err)
905+
assert.Equal(t, tt.want, got)
906+
assert.Equal(t, tt.want1, got1)
930907
})
931908
}
932909
}
@@ -1021,23 +998,20 @@ func TestNebula_authorizeToken_P256(t *testing.T) {
1021998
for _, tt := range tests {
1022999
t.Run(tt.name, func(t *testing.T) {
10231000
got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences)
1024-
if (err != nil) != tt.wantErr {
1025-
fmt.Println(err)
1026-
t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr)
1001+
if tt.wantErr {
1002+
assert.Error(t, err)
1003+
assert.Nil(t, got)
1004+
assert.Nil(t, got1)
10271005
return
10281006
}
1029-
if !reflect.DeepEqual(got, tt.want) {
1030-
t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want)
1031-
t.Error(cmp.Equal(got, tt.want))
1032-
}
10331007

10341008
if got1 != nil && tt.want1 != nil {
10351009
tt.want1.ID = got1.ID
10361010
}
10371011

1038-
if !reflect.DeepEqual(got1, tt.want1) {
1039-
t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1)
1040-
}
1012+
assert.NoError(t, err)
1013+
assert.Equal(t, tt.want, got)
1014+
assert.Equal(t, tt.want1, got1)
10411015
})
10421016
}
10431017
}

0 commit comments

Comments
 (0)