Skip to content

Commit 036bb56

Browse files
authored
Nats integration (#137)
* Add TLS config creation with Ed25519 key validation for NATS * Enhance concurrency safety with Replace() and Keys() * Security improvement, more deterministic Constant-Time comparison comparison (only 0 or 1 is possible) * Add tests to cover concurrency improvements and coverage for Key management logic
1 parent 5df51ea commit 036bb56

File tree

2 files changed

+222
-3
lines changed

2 files changed

+222
-3
lines changed

rpc/mtls/mtls.go

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ func NewTransportSigner(signer crypto.Signer, pubKeys []ed25519.PublicKey) (cred
4545
return credentials.NewTLS(c), nil
4646
}
4747

48+
func NewTLSConfig(privKey ed25519.PrivateKey, pubKeys []ed25519.PublicKey) (*tls.Config, error) {
49+
priv, err := ValidPrivateKeyFromEd25519(privKey)
50+
if err != nil {
51+
return nil, err
52+
}
53+
54+
pubs, err := ValidPublicKeysFromEd25519(pubKeys...)
55+
if err != nil {
56+
return nil, err
57+
}
58+
59+
return newMutualTLSConfig(priv.key, pubs)
60+
}
61+
4862
// newMutualTLSConfig uses the private key and public keys to construct a mutual
4963
// TLS 1.3 config.
5064
//
@@ -117,6 +131,9 @@ type PublicKeys struct {
117131
}
118132

119133
func ValidPublicKeysFromEd25519(keys ...ed25519.PublicKey) (*PublicKeys, error) {
134+
if len(keys) == 0 {
135+
return nil, errors.New("no public keys provided")
136+
}
120137
for _, key := range keys {
121138
if len(key) != ed25519.PublicKeySize {
122139
return nil, fmt.Errorf("invalid key length: %d, expected: %d", len(key), ed25519.PublicKeySize)
@@ -129,7 +146,13 @@ func ValidPublicKeysFromEd25519(keys ...ed25519.PublicKey) (*PublicKeys, error)
129146
}
130147

131148
func (r *PublicKeys) Keys() []ed25519.PublicKey {
132-
return r.keys
149+
r.mu.RLock()
150+
defer r.mu.RUnlock()
151+
152+
// Return a copy to prevent race conditions
153+
keysCopy := make([]ed25519.PublicKey, len(r.keys))
154+
copy(keysCopy, r.keys)
155+
return keysCopy
133156
}
134157

135158
// Verifies that the certificate's public key matches with one of the keys in
@@ -160,17 +183,22 @@ func (r *PublicKeys) VerifyPeerCertificate() func(rawCerts [][]byte, verifiedCha
160183
// Replace replaces the existing keys with new keys. Use this to dynamically
161184
// update the allowable keys at runtime.
162185
func (r *PublicKeys) Replace(pubs *PublicKeys) {
186+
pubs.mu.RLock()
187+
newKeys := make([]ed25519.PublicKey, len(pubs.keys))
188+
copy(newKeys, pubs.keys)
189+
pubs.mu.RUnlock()
190+
163191
r.mu.Lock()
164192
defer r.mu.Unlock()
165-
r.keys = pubs.keys
193+
r.keys = newKeys
166194
}
167195

168196
// isValidPublicKey checks the public key against a list of valid keys.
169197
func (r *PublicKeys) isValidPublicKey(pub ed25519.PublicKey) bool {
170198
r.mu.RLock()
171199
defer r.mu.RUnlock()
172200
for _, vpub := range r.keys {
173-
if subtle.ConstantTimeCompare(pub, vpub) > 0 {
201+
if subtle.ConstantTimeCompare(pub, vpub) == 1 {
174202
return true
175203
}
176204
}

rpc/mtls/mtls_test.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66
"crypto/tls"
77
"crypto/x509"
88
"math/big"
9+
"sync"
910
"testing"
11+
"time"
1012

1113
"github.com/stretchr/testify/assert"
1214
"github.com/stretchr/testify/require"
@@ -154,3 +156,192 @@ func Test_NewPublicKeys(t *testing.T) {
154156
require.Error(t, err)
155157
})
156158
}
159+
160+
func Test_NewTLSConfig(t *testing.T) {
161+
t.Run("nil_arguments", func(t *testing.T) {
162+
cfg, err := NewTLSConfig(nil, nil)
163+
require.Error(t, err)
164+
assert.Nil(t, cfg)
165+
})
166+
167+
t.Run("nil_private_key", func(t *testing.T) {
168+
pub, _, err := ed25519.GenerateKey(nil)
169+
require.NoError(t, err)
170+
171+
cfg, err := NewTLSConfig(nil, []ed25519.PublicKey{pub})
172+
require.Error(t, err)
173+
assert.Nil(t, cfg)
174+
})
175+
176+
t.Run("nil_public_keys", func(t *testing.T) {
177+
_, priv, err := ed25519.GenerateKey(nil)
178+
require.NoError(t, err)
179+
180+
cfg, err := NewTLSConfig(priv, nil)
181+
require.Error(t, err)
182+
assert.Nil(t, cfg)
183+
})
184+
185+
t.Run("empty_public_keys", func(t *testing.T) {
186+
_, priv, err := ed25519.GenerateKey(nil)
187+
require.NoError(t, err)
188+
189+
cfg, err := NewTLSConfig(priv, []ed25519.PublicKey{})
190+
require.Error(t, err)
191+
assert.Nil(t, cfg)
192+
})
193+
194+
t.Run("valid_single_key", func(t *testing.T) {
195+
pub, priv, err := ed25519.GenerateKey(nil)
196+
require.NoError(t, err)
197+
198+
cfg, err := NewTLSConfig(priv, []ed25519.PublicKey{pub})
199+
require.NoError(t, err)
200+
assert.NotNil(t, cfg)
201+
assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
202+
assert.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion)
203+
assert.True(t, cfg.InsecureSkipVerify)
204+
assert.Len(t, cfg.Certificates, 1)
205+
})
206+
207+
t.Run("valid_multiple_keys", func(t *testing.T) {
208+
pub1, _, err := ed25519.GenerateKey(nil)
209+
require.NoError(t, err)
210+
211+
pub2, _, err := ed25519.GenerateKey(nil)
212+
require.NoError(t, err)
213+
214+
_, priv, err := ed25519.GenerateKey(nil)
215+
require.NoError(t, err)
216+
217+
cfg, err := NewTLSConfig(priv, []ed25519.PublicKey{pub1, pub2})
218+
require.NoError(t, err)
219+
assert.NotNil(t, cfg)
220+
assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
221+
assert.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion)
222+
assert.True(t, cfg.InsecureSkipVerify)
223+
assert.Len(t, cfg.Certificates, 1)
224+
})
225+
226+
t.Run("invalid_public_key_length", func(t *testing.T) {
227+
_, priv, err := ed25519.GenerateKey(nil)
228+
require.NoError(t, err)
229+
230+
// Create an invalid public key with wrong length
231+
invalidPub := make([]byte, ed25519.PublicKeySize+1)
232+
cfg, err := NewTLSConfig(priv, []ed25519.PublicKey{invalidPub})
233+
require.Error(t, err)
234+
assert.Nil(t, cfg)
235+
})
236+
}
237+
238+
func Test_PublicKeys_Keys(t *testing.T) {
239+
// Create keys
240+
pub1, _, err := ed25519.GenerateKey(nil)
241+
require.NoError(t, err)
242+
243+
pub2, _, err := ed25519.GenerateKey(nil)
244+
require.NoError(t, err)
245+
246+
// Create PublicKeys with both keys
247+
pks, err := ValidPublicKeysFromEd25519(pub1, pub2)
248+
require.NoError(t, err)
249+
250+
// Get a copy of the keys
251+
keysCopy := pks.Keys()
252+
require.Equal(t, 2, len(keysCopy))
253+
254+
// Verify the keys match
255+
assert.ElementsMatch(t, []ed25519.PublicKey{pub1, pub2}, keysCopy)
256+
257+
// Check original is unaffected
258+
keysAfter := pks.Keys()
259+
require.Equal(t, 2, len(keysAfter))
260+
}
261+
262+
func Test_PublicKeys_Replace(t *testing.T) {
263+
// Create original keys
264+
pub1, _, err := ed25519.GenerateKey(nil)
265+
require.NoError(t, err)
266+
267+
pks1, err := ValidPublicKeysFromEd25519(pub1)
268+
require.NoError(t, err)
269+
270+
// Create replacement keys
271+
pub2, _, err := ed25519.GenerateKey(nil)
272+
require.NoError(t, err)
273+
274+
pks2, err := ValidPublicKeysFromEd25519(pub2)
275+
require.NoError(t, err)
276+
277+
// Replace keys
278+
pks1.Replace(pks2)
279+
280+
// Verify the replacement worked
281+
assert.False(t, pks1.isValidPublicKey(pub1))
282+
assert.True(t, pks1.isValidPublicKey(pub2))
283+
284+
// Modify the source after replace (shouldn't affect replaced keys)
285+
pub3, _, err := ed25519.GenerateKey(nil)
286+
require.NoError(t, err)
287+
288+
pks2.Replace(&PublicKeys{keys: []ed25519.PublicKey{pub3}})
289+
290+
// Original replacement should be unaffected
291+
assert.True(t, pks1.isValidPublicKey(pub2))
292+
assert.False(t, pks1.isValidPublicKey(pub3))
293+
}
294+
295+
func Test_PublicKeys_Concurrency(t *testing.T) {
296+
pub1, _, err := ed25519.GenerateKey(nil)
297+
require.NoError(t, err)
298+
299+
pub2, _, err := ed25519.GenerateKey(nil)
300+
require.NoError(t, err)
301+
302+
pks, err := ValidPublicKeysFromEd25519(pub1)
303+
require.NoError(t, err)
304+
305+
// Simulate concurrent reads and writes
306+
var wg sync.WaitGroup
307+
start := make(chan struct{})
308+
309+
// Add multiple concurrent readers
310+
for i := 0; i < 5; i++ {
311+
wg.Add(1)
312+
go func() {
313+
defer wg.Done()
314+
<-start
315+
for j := 0; j < 10; j++ {
316+
_ = pks.isValidPublicKey(pub1)
317+
_ = pks.Keys()
318+
time.Sleep(time.Millisecond)
319+
}
320+
}()
321+
}
322+
323+
// Add concurrent writers
324+
for i := 0; i < 3; i++ {
325+
wg.Add(1)
326+
go func(idx int) {
327+
defer wg.Done()
328+
<-start
329+
330+
// Alternate between pub1 and pub2
331+
for j := 0; j < 5; j++ {
332+
if (idx+j)%2 == 0 {
333+
pks.Replace(&PublicKeys{keys: []ed25519.PublicKey{pub1}})
334+
} else {
335+
pks.Replace(&PublicKeys{keys: []ed25519.PublicKey{pub2}})
336+
}
337+
time.Sleep(time.Millisecond * 2)
338+
}
339+
}(i)
340+
}
341+
342+
// Start all goroutines
343+
close(start)
344+
wg.Wait()
345+
346+
// No assertion needed - if there are no race conditions, the test passes
347+
}

0 commit comments

Comments
 (0)