@@ -8,16 +8,15 @@ import (
88 "crypto/elliptic"
99 "crypto/rand"
1010 "crypto/x509"
11- "fmt"
1211 "net"
1312 "net/url"
1413 "reflect"
1514 "strings"
1615 "testing"
1716 "time"
1817
19- "github.com/google/go-cmp/cmp"
2018 "github.com/slackhq/nebula/cert"
19+ "github.com/stretchr/testify/assert"
2120 "github.com/stretchr/testify/require"
2221 "golang.org/x/crypto/ssh"
2322
@@ -30,9 +29,7 @@ import (
3029func mustNebulaIPNet (t * testing.T , s string ) * net.IPNet {
3130 t .Helper ()
3231 ip , ipNet , err := net .ParseCIDR (s )
33- if err != nil {
34- t .Fatal (err )
35- }
32+ require .NoError (t , err )
3633 if ip = ip .To4 (); ip == nil {
3734 t .Fatalf ("nebula only supports ipv4, have %s" , s )
3835 }
@@ -43,9 +40,7 @@ func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
4340func mustNebulaCA (t * testing.T ) (* cert.NebulaCertificate , ed25519.PrivateKey ) {
4441 t .Helper ()
4542 pub , priv , err := ed25519 .GenerateKey (rand .Reader )
46- if err != nil {
47- t .Fatal (err )
48- }
43+ require .NoError (t , err )
4944 nc := & cert.NebulaCertificate {
5045 Details : cert.NebulaCertificateDetails {
5146 Name : "TestCA" ,
@@ -61,9 +56,9 @@ func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
6156 Curve : cert .Curve_CURVE25519 ,
6257 },
6358 }
64- if err : = nc .Sign (cert .Curve_CURVE25519 , priv ); err != nil {
65- t . Fatal ( err )
66- }
59+ err = nc .Sign (cert .Curve_CURVE25519 , priv )
60+ require . NoError ( t , err )
61+
6762 return nc , priv
6863}
6964
@@ -99,14 +94,10 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
9994 t .Helper ()
10095
10196 pub , priv , err := x25519 .GenerateKey (rand .Reader )
102- if err != nil {
103- t .Fatal (err )
104- }
97+ require .NoError (t , err )
10598
10699 issuer , err := ca .Sha256Sum ()
107- if err != nil {
108- t .Fatal (err )
109- }
100+ require .NoError (t , err )
110101
111102 invertedGroups := make (map [string ]struct {}, len (groups ))
112103 for _ , name := range groups {
@@ -130,9 +121,8 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
130121 },
131122 }
132123
133- if err := nc .Sign (cert .Curve_CURVE25519 , signer ); err != nil {
134- t .Fatal (err )
135- }
124+ err = nc .Sign (cert .Curve_CURVE25519 , signer )
125+ require .NoError (t , err )
136126
137127 return nc , priv
138128}
@@ -184,9 +174,7 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
184174
185175 nc , signer := mustNebulaCA (t )
186176 ncPem , err := nc .MarshalToPEM ()
187- if err != nil {
188- t .Fatal (err )
189- }
177+ require .NoError (t , err )
190178 bTrue := true
191179 p := & Nebula {
192180 Type : TypeNebula .String (),
@@ -196,12 +184,11 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
196184 EnableSSHCA : & bTrue ,
197185 },
198186 }
199- if err : = p .Init (Config {
187+ err = p .Init (Config {
200188 Claims : globalProvisionerClaims ,
201189 Audiences : testAudiences ,
202- }); err != nil {
203- t .Fatal (err )
204- }
190+ })
191+ require .NoError (t , err )
205192
206193 return p , nc , signer
207194}
@@ -310,9 +297,7 @@ func mustNebulaSSHToken(t *testing.T, sub, iss, aud string, iat time.Time, opts
310297func TestNebula_Init (t * testing.T ) {
311298 nc , _ := mustNebulaCA (t )
312299 ncPem , err := nc .MarshalToPEM ()
313- if err != nil {
314- t .Fatal (err )
315- }
300+ require .NoError (t , err )
316301
317302 cfg := Config {
318303 Claims : globalProvisionerClaims ,
@@ -416,9 +401,7 @@ func TestNebula_GetTokenID(t *testing.T) {
416401 c1 , priv := mustNebulaCert (t , "test.lan" , mustNebulaIPNet (t , "10.1.0.1/16" ), []string {"group" }, ca , signer )
417402 t1 := mustNebulaToken (t , "test.lan" , p .Name , p .ctl .Audiences .Sign [0 ], now (), []string {"test.lan" , "10.1.0.1" }, c1 , priv , jose .XEdDSA )
418403 _ , claims , err := parseToken (t1 )
419- if err != nil {
420- t .Fatal (err )
421- }
404+ require .NoError (t , err )
422405
423406 type args struct {
424407 token string
@@ -838,13 +821,9 @@ func TestNebula_authorizeToken(t *testing.T) {
838821
839822 // Not a nebula token
840823 jwk , err := generateJSONWebKey ()
841- if err != nil {
842- t .Fatal (err )
843- }
824+ require .NoError (t , err )
844825 simpleToken , err := generateSimpleToken ("iss" , "aud" , jwk )
845- if err != nil {
846- t .Fatal (err )
847- }
826+ require .NoError (t , err )
848827
849828 // Provisioner with a different CA
850829 p2 , _ , _ := mustNebulaProvisioner (t )
@@ -911,22 +890,20 @@ func TestNebula_authorizeToken(t *testing.T) {
911890 for _ , tt := range tests {
912891 t .Run (tt .name , func (t * testing.T ) {
913892 got , got1 , err := tt .p .authorizeToken (tt .args .token , tt .args .audiences )
914- if (err != nil ) != tt .wantErr {
915- t .Errorf ("Nebula.authorizeToken() error = %v, wantErr %v" , err , tt .wantErr )
893+ if tt .wantErr {
894+ assert .Error (t , err )
895+ assert .Nil (t , got )
896+ assert .Nil (t , got1 )
916897 return
917898 }
918- if ! reflect .DeepEqual (got , tt .want ) {
919- t .Errorf ("Nebula.authorizeToken() got = %#v, want %#v" , got , tt .want )
920- t .Error (cmp .Equal (got , tt .want ))
921- }
922899
923900 if got1 != nil && tt .want1 != nil {
924901 tt .want1 .ID = got1 .ID
925902 }
926903
927- if ! reflect . DeepEqual ( got1 , tt . want1 ) {
928- t . Errorf ( "Nebula.authorizeToken() got1 = %v, want %v" , got1 , tt . want1 )
929- }
904+ assert . NoError ( t , err )
905+ assert . Equal ( t , tt . want , got )
906+ assert . Equal ( t , tt . want1 , got1 )
930907 })
931908 }
932909}
@@ -1021,23 +998,20 @@ func TestNebula_authorizeToken_P256(t *testing.T) {
1021998 for _ , tt := range tests {
1022999 t .Run (tt .name , func (t * testing.T ) {
10231000 got , got1 , err := tt .p .authorizeToken (tt .args .token , tt .args .audiences )
1024- if (err != nil ) != tt .wantErr {
1025- fmt .Println (err )
1026- t .Errorf ("Nebula.authorizeToken() error = %v, wantErr %v" , err , tt .wantErr )
1001+ if tt .wantErr {
1002+ assert .Error (t , err )
1003+ assert .Nil (t , got )
1004+ assert .Nil (t , got1 )
10271005 return
10281006 }
1029- if ! reflect .DeepEqual (got , tt .want ) {
1030- t .Errorf ("Nebula.authorizeToken() got = %#v, want %#v" , got , tt .want )
1031- t .Error (cmp .Equal (got , tt .want ))
1032- }
10331007
10341008 if got1 != nil && tt .want1 != nil {
10351009 tt .want1 .ID = got1 .ID
10361010 }
10371011
1038- if ! reflect . DeepEqual ( got1 , tt . want1 ) {
1039- t . Errorf ( "Nebula.authorizeToken() got1 = %v, want %v" , got1 , tt . want1 )
1040- }
1012+ assert . NoError ( t , err )
1013+ assert . Equal ( t , tt . want , got )
1014+ assert . Equal ( t , tt . want1 , got1 )
10411015 })
10421016 }
10431017}
0 commit comments