Skip to content

Commit 7c1be5c

Browse files
committed
Misc improvements
1 parent 5b7d79f commit 7c1be5c

File tree

2 files changed

+78
-88
lines changed

2 files changed

+78
-88
lines changed

keystore/encryptor.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ import (
2121

2222
// Opaque error messages to prevent information leakage
2323
var (
24-
ErrEncryptionFailed = fmt.Errorf("encryption operation failed")
25-
ErrDecryptionFailed = fmt.Errorf("decryption operation failed")
24+
ErrSharedSecretFailed = fmt.Errorf("shared secret derivation failed")
25+
ErrEncryptionFailed = fmt.Errorf("encryption operation failed")
26+
ErrDecryptionFailed = fmt.Errorf("decryption operation failed")
2627
)
2728

2829
type EncryptRequest struct {
@@ -103,6 +104,7 @@ func (k *keystore) Encrypt(ctx context.Context, req EncryptRequest) (EncryptResp
103104
if len(req.Data) == 0 || len(req.Data) > MaxEncryptionPayloadSize {
104105
return EncryptResponse{}, ErrEncryptionFailed
105106
}
107+
106108
key, ok := k.keystore[req.KeyName]
107109
if !ok {
108110
return EncryptResponse{}, ErrEncryptionFailed
@@ -135,7 +137,7 @@ func (k *keystore) Encrypt(ctx context.Context, req EncryptRequest) (EncryptResp
135137
if err != nil {
136138
return EncryptResponse{}, ErrEncryptionFailed
137139
}
138-
// The magic here is the the receipient can compute the same
140+
// The magic here is that the receipient can compute the same
139141
// shared secret because ephPriv*G*recipientPriv = ephPub*G.
140142
// This lets them derive the same ephemeral key used for encryption
141143
// so they can decrypt the ciphertext.
@@ -289,17 +291,17 @@ func (k *keystore) DeriveSharedSecret(ctx context.Context, req DeriveSharedSecre
289291

290292
key, ok := k.keystore[req.LocalKeyName]
291293
if !ok {
292-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
294+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
293295
}
294296

295297
switch key.keyType {
296298
case X25519:
297299
if len(req.RemotePubKey) != 32 {
298-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
300+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
299301
}
300302
sharedSecret, err := curve25519.X25519(internal.Bytes(key.privateKey), req.RemotePubKey)
301303
if err != nil {
302-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
304+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
303305
}
304306
return DeriveSharedSecretResponse{
305307
SharedSecret: sharedSecret,
@@ -308,23 +310,23 @@ func (k *keystore) DeriveSharedSecret(ctx context.Context, req DeriveSharedSecre
308310
curve := ecdh.P256()
309311
priv, err := curve.NewPrivateKey(internal.Bytes(key.privateKey))
310312
if err != nil {
311-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
313+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
312314
}
313315
// P-256 uncompressed public keys are 65 bytes (0x04 || x || y)
314316
if len(req.RemotePubKey) != 65 {
315-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
317+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
316318
}
317319
remotePub, err := curve.NewPublicKey(req.RemotePubKey)
318320
if err != nil {
319-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
321+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
320322
}
321323
shared, err := priv.ECDH(remotePub)
322324
if err != nil {
323-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
325+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
324326
}
325327
return DeriveSharedSecretResponse{SharedSecret: shared}, nil
326328
default:
327-
return DeriveSharedSecretResponse{}, ErrEncryptionFailed
329+
return DeriveSharedSecretResponse{}, ErrSharedSecretFailed
328330
}
329331
}
330332

keystore/encryptor_test.go

Lines changed: 65 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,48 @@ func TestEncryptDecrypt(t *testing.T) {
4545
}{keyType: keys.Keys[1].KeyInfo.KeyType, publicKey: keys.Keys[1].KeyInfo.PublicKey}
4646
}
4747

48-
var tt []struct {
49-
name string
50-
fromKey string
51-
toKey string
52-
expectedError error
48+
type testCase struct {
49+
name string
50+
encryptKey string
51+
remotePubKey []byte
52+
decryptKey string
53+
payload []byte
54+
expectedEncryptError error
55+
expectedDecryptError error
56+
}
57+
58+
var tt = []testCase{
59+
{
60+
name: "Non-existent encrypt key",
61+
encryptKey: "blah",
62+
remotePubKey: testKeysByType[keyName(keystore.X25519, 0)].publicKey,
63+
decryptKey: keyName(keystore.X25519, 0),
64+
payload: []byte("hello world"),
65+
expectedEncryptError: keystore.ErrEncryptionFailed,
66+
},
67+
{
68+
name: "Non-existent decrypt key",
69+
encryptKey: keyName(keystore.X25519, 0),
70+
remotePubKey: testKeysByType[keyName(keystore.X25519, 0)].publicKey,
71+
decryptKey: "blah",
72+
payload: []byte("hello world"),
73+
expectedDecryptError: keystore.ErrDecryptionFailed,
74+
},
75+
{
76+
name: "Max payload",
77+
encryptKey: keyName(keystore.X25519, 0),
78+
remotePubKey: testKeysByType[keyName(keystore.X25519, 0)].publicKey,
79+
decryptKey: keyName(keystore.X25519, 0),
80+
payload: make([]byte, keystore.MaxEncryptionPayloadSize),
81+
},
82+
{
83+
name: "Payload too large",
84+
encryptKey: keyName(keystore.X25519, 0),
85+
remotePubKey: testKeysByType[keyName(keystore.X25519, 0)].publicKey,
86+
decryptKey: keyName(keystore.X25519, 0),
87+
payload: make([]byte, keystore.MaxEncryptionPayloadSize+1),
88+
expectedEncryptError: keystore.ErrEncryptionFailed,
89+
},
5390
}
5491

5592
for _, fromType := range keystore.AllKeyTypes {
@@ -60,48 +97,51 @@ func TestEncryptDecrypt(t *testing.T) {
6097
fromKey := keyName(fromType, 0) // Always use key 0 as source
6198
toKey := keyName(toType, keyIndex)
6299

63-
var expectedError error
100+
var expectedEncryptError error
64101
if fromType == toType && fromType.IsEncryptionKeyType() {
65102
// Same key types should succeed
66-
expectedError = nil
103+
expectedEncryptError = nil
67104
} else {
68105
// Different key types or non-encryption key types should fail
69-
expectedError = keystore.ErrEncryptionFailed
106+
expectedEncryptError = keystore.ErrEncryptionFailed
70107
}
71108

72-
tt = append(tt, struct {
73-
name string
74-
fromKey string
75-
toKey string
76-
expectedError error
77-
}{
78-
name: testName,
79-
fromKey: fromKey,
80-
toKey: toKey,
81-
expectedError: expectedError,
109+
tt = append(tt, testCase{
110+
name: testName,
111+
encryptKey: fromKey,
112+
remotePubKey: testKeysByType[toKey].publicKey,
113+
decryptKey: toKey,
114+
expectedEncryptError: expectedEncryptError,
115+
payload: []byte("hello world"),
82116
})
83117
}
84118
}
85119
}
120+
86121
for _, tt := range tt {
87122
t.Run(tt.name, func(t *testing.T) {
88123
encryptResp, err := ks.Encrypt(ctx, keystore.EncryptRequest{
89-
KeyName: tt.fromKey,
90-
RemotePubKey: testKeysByType[tt.toKey].publicKey,
91-
Data: []byte("hello world"),
124+
KeyName: tt.encryptKey,
125+
RemotePubKey: tt.remotePubKey,
126+
Data: tt.payload,
92127
})
93-
if tt.expectedError != nil {
128+
if tt.expectedEncryptError != nil {
94129
require.Error(t, err)
95-
require.True(t, errors.Is(err, tt.expectedError))
130+
require.True(t, errors.Is(err, tt.expectedEncryptError))
96131
return
97132
}
98133
require.NoError(t, err)
99134
decryptResp, err := ks.Decrypt(ctx, keystore.DecryptRequest{
100-
KeyName: tt.toKey,
135+
KeyName: tt.decryptKey,
101136
EncryptedData: encryptResp.EncryptedData,
102137
})
138+
if tt.expectedDecryptError != nil {
139+
require.Error(t, err)
140+
require.True(t, errors.Is(err, tt.expectedDecryptError))
141+
return
142+
}
103143
require.NoError(t, err)
104-
require.Equal(t, []byte("hello world"), decryptResp.Data)
144+
require.Equal(t, tt.payload, decryptResp.Data)
105145
})
106146
}
107147
}
@@ -132,58 +172,6 @@ func TestEncryptDecrypt_SharedSecret(t *testing.T) {
132172
}
133173
}
134174

135-
func TestEncryptDecrypt_PayloadSizeLimit(t *testing.T) {
136-
ctx := context.Background()
137-
ks, err := keystore.LoadKeystore(ctx, storage.NewMemoryStorage(), keystore.EncryptionParams{
138-
Password: "test-password",
139-
ScryptParams: keystore.FastScryptParams,
140-
})
141-
require.NoError(t, err)
142-
143-
for _, keyType := range keystore.AllEncryptionKeyTypes {
144-
t.Run(fmt.Sprintf("keyType_%s", keyType), func(t *testing.T) {
145-
keyName := fmt.Sprintf("test-key-%s", keyType)
146-
keys, err := ks.CreateKeys(ctx, keystore.CreateKeysRequest{
147-
Keys: []keystore.CreateKeyRequest{
148-
{KeyName: keyName, KeyType: keyType},
149-
},
150-
})
151-
require.NoError(t, err)
152-
// Test encrypting at the limit
153-
maxPayload := make([]byte, keystore.MaxEncryptionPayloadSize)
154-
maxEncryptResp, err := ks.Encrypt(ctx, keystore.EncryptRequest{
155-
KeyName: keyName,
156-
RemotePubKey: keys.Keys[0].KeyInfo.PublicKey,
157-
Data: maxPayload,
158-
})
159-
require.NoError(t, err)
160-
161-
// Test decrypting at max (confirm overhead sufficient)
162-
maxDecryptResp, err := ks.Decrypt(ctx, keystore.DecryptRequest{
163-
KeyName: keyName,
164-
EncryptedData: maxEncryptResp.EncryptedData,
165-
})
166-
require.NoError(t, err)
167-
require.Equal(t, len(maxDecryptResp.Data), len(maxPayload))
168-
169-
// Test encrypting above the limit
170-
_, err = ks.Encrypt(ctx, keystore.EncryptRequest{
171-
KeyName: keyName,
172-
RemotePubKey: keys.Keys[0].KeyInfo.PublicKey,
173-
Data: make([]byte, keystore.MaxEncryptionPayloadSize+1),
174-
})
175-
require.Error(t, err)
176-
177-
// Test decrypting above the limit
178-
_, err = ks.Decrypt(ctx, keystore.DecryptRequest{
179-
KeyName: keyName,
180-
EncryptedData: make([]byte, keystore.MaxEncryptionPayloadSize+1025),
181-
})
182-
require.Error(t, err)
183-
})
184-
}
185-
}
186-
187175
func FuzzEncryptDecryptRoundtrip(f *testing.F) {
188176
// Add seed corpus with various input sizes and patterns
189177
seedCorpus := [][]byte{

0 commit comments

Comments
 (0)