Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions tavern/internal/cryptocodec/cryptocodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ import (
"github.com/cloudflare/circl/dh/x25519"
lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/crypto/chacha20poly1305"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status"
)

var session_pub_keys = NewSyncMap()
Expand Down Expand Up @@ -93,7 +95,10 @@ func (s StreamDecryptCodec) Marshal(v any) (mem.BufferSlice, error) {
}

func (s StreamDecryptCodec) Unmarshal(buf mem.BufferSlice, v any) error {
dec_buf, _ := s.Csvc.Decrypt(buf.Materialize())
dec_buf, _, err := s.Csvc.Decrypt(buf.Materialize())
if err != nil {
return status.Error(codes.Unauthenticated, "auth failure")
}

proto := encoding.GetCodecV2("proto")
if proto == nil {
Expand Down Expand Up @@ -140,10 +145,10 @@ func (csvc *CryptoSvc) generate_shared_key(client_pub_key_bytes []byte) []byte {
return shared_key
}

func (csvc *CryptoSvc) Decrypt(in_arr []byte) ([]byte, []byte) {
func (csvc *CryptoSvc) Decrypt(in_arr []byte) ([]byte, []byte, error) {
if len(in_arr) < x25519.Size {
slog.Error(fmt.Sprintf("input bytes too short %d expected at least %d", len(in_arr), x25519.Size))
return FAILURE_BYTES, FAILURE_BYTES
return FAILURE_BYTES, FAILURE_BYTES, fmt.Errorf("input bytes too short")
}

// CRITICAL FIX: Make a distinct copy of the public key.
Expand All @@ -155,7 +160,7 @@ func (csvc *CryptoSvc) Decrypt(in_arr []byte) ([]byte, []byte) {
ids, err := goAllIds()
if err != nil {
slog.Error("failed to get goid")
return FAILURE_BYTES, FAILURE_BYTES
return FAILURE_BYTES, FAILURE_BYTES, err
}
session_pub_keys.Store(ids.Id, client_pub_key_bytes)

Expand All @@ -165,7 +170,7 @@ func (csvc *CryptoSvc) Decrypt(in_arr []byte) ([]byte, []byte) {
aead, err := chacha20poly1305.NewX(derived_key)
if err != nil {
slog.Error(fmt.Sprintf("failed to create xchacha key %v", err))
return FAILURE_BYTES, FAILURE_BYTES
return FAILURE_BYTES, FAILURE_BYTES, err
}

// Progress in_arr buf
Expand All @@ -174,18 +179,18 @@ func (csvc *CryptoSvc) Decrypt(in_arr []byte) ([]byte, []byte) {
// Read nonce & cipher text
if len(in_arr) < aead.NonceSize() {
slog.Error(fmt.Sprintf("input bytes to short %d expected at least %d", len(in_arr), aead.NonceSize()))
return FAILURE_BYTES, FAILURE_BYTES
return FAILURE_BYTES, FAILURE_BYTES, fmt.Errorf("input bytes too short")
}
nonce, ciphertext := in_arr[:aead.NonceSize()], in_arr[aead.NonceSize():]

// Decrypt
plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
if err != nil {
slog.Error(fmt.Sprintf("failed to decrypt %v", err))
return FAILURE_BYTES, FAILURE_BYTES
return FAILURE_BYTES, FAILURE_BYTES, err
}

return plaintext, client_pub_key_bytes
return plaintext, client_pub_key_bytes, nil
}

// TODO: Don't use [] ref.
Expand Down
2 changes: 1 addition & 1 deletion tavern/internal/cryptocodec/cryptocodec_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func BenchmarkDecrypt(b *testing.B) {
b.SetBytes(int64(len(plaintext)))

for i := 0; i < b.N; i++ {
decrypted, _ := svc.Decrypt(payload)
decrypted, _, _ := svc.Decrypt(payload)
if len(decrypted) == 0 {
b.Fatal("Decrypt returned empty slice")
}
Expand Down
9 changes: 6 additions & 3 deletions tavern/internal/cryptocodec/cryptocodec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ func TestDecrypt(t *testing.T) {

// Test Decrypt
svc := NewCryptoSvc(serverPrivKey)
decrypted, pubKey := svc.Decrypt(payload)
decrypted, pubKey, err := svc.Decrypt(payload)

assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
assert.Equal(t, clientPubKey, pubKey)
}
Expand All @@ -118,7 +119,8 @@ func TestDecrypt_ShortInput(t *testing.T) {
require.NoError(t, err)
svc := NewCryptoSvc(serverPrivKey)

res, _ := svc.Decrypt([]byte{0x00})
res, _, err := svc.Decrypt([]byte{0x00})
assert.Error(t, err)
assert.Equal(t, FAILURE_BYTES, res)
}

Expand All @@ -129,7 +131,8 @@ func TestDecrypt_ShortInputAfterKey(t *testing.T) {

// Input long enough for key (32 bytes) but not nonce
input := make([]byte, 32+1)
res, _ := svc.Decrypt(input)
res, _, err := svc.Decrypt(input)
assert.Error(t, err)
assert.Equal(t, FAILURE_BYTES, res)
}

Expand Down
Loading