Skip to content

Commit 722e3b3

Browse files
committed
update encoding/decoding to match envelope encryption format, fix ekm endpoint calls
1 parent dff05ee commit 722e3b3

File tree

6 files changed

+131
-38
lines changed

6 files changed

+131
-38
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ celerybeat.pid
105105
# Environments
106106
.env
107107
.venv
108+
.venv312
108109
env/
109110
venv/
110111
ENV/

tests/test_vault.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,19 @@ def mock_object_versions(self):
5959

6060
@pytest.fixture
6161
def mock_data_key(self):
62-
return MockDataKey(
63-
"key_01234567890abcdef", "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY="
64-
).dict()
62+
return {
63+
"id": "key_01234567890abcdef",
64+
"data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=",
65+
}
6566

6667
@pytest.fixture
6768
def mock_data_key_pair(self):
68-
return MockDataKeyPair().dict()
69+
return {
70+
"context": {"key": "test-key"},
71+
"id": "key_01234567890abcdef",
72+
"data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=",
73+
"encrypted_keys": "ZW5jcnlwdGVkX2tleXNfZGF0YQ==",
74+
}
6975

7076
def test_read_object_success(
7177
self, mock_vault_object, capture_and_mock_http_client_request
@@ -186,7 +192,7 @@ def test_create_object_success(
186192
assert request_kwargs["url"].endswith("/vault/v1/kv")
187193
assert request_kwargs["json"]["name"] == "test-secret"
188194
assert request_kwargs["json"]["value"] == "secret-value"
189-
assert request_kwargs["json"]["key_context"] == {"key": "test-key"}
195+
assert request_kwargs["json"]["context"] == KeyContext({"key": "test-key"})
190196
assert vault_object.id == "vault_01234567890abcdef"
191197
assert vault_object.name == "test-secret"
192198
assert vault_object.value == "secret-value"
@@ -310,7 +316,7 @@ def test_create_data_key_success(
310316

311317
assert request_kwargs["method"] == "post"
312318
assert request_kwargs["url"].endswith("/vault/v1/keys/data-key")
313-
assert request_kwargs["json"]["key_context"] == {"key": "test-key"}
319+
assert request_kwargs["json"]["context"] == KeyContext({"key": "test-key"})
314320
assert data_key_pair.data_key.id == "key_01234567890abcdef"
315321
assert data_key_pair.encrypted_keys == "ZW5jcnlwdGVkX2tleXNfZGF0YQ=="
316322

@@ -345,7 +351,7 @@ def test_encrypt_success(
345351
# Verify create_data_key was called
346352
assert request_kwargs["method"] == "post"
347353
assert request_kwargs["url"].endswith("/vault/v1/keys/data-key")
348-
assert request_kwargs["json"]["key_context"] == {"key": "test-key"}
354+
assert request_kwargs["json"]["context"] == KeyContext({"key": "test-key"})
349355

350356
# Verify we got encrypted data back
351357
assert isinstance(encrypted_data, str)
@@ -371,7 +377,12 @@ def test_encrypt_with_associated_data(
371377

372378
def test_decrypt_success(self, mock_data_key, capture_and_mock_http_client_request):
373379
# First encrypt some data to get a valid encrypted payload
374-
mock_data_key_pair = MockDataKeyPair().dict()
380+
mock_data_key_pair = {
381+
"context": {"key": "test-key"},
382+
"id": "key_01234567890abcdef",
383+
"data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=",
384+
"encrypted_keys": "ZW5jcnlwdGVkX2tleXNfZGF0YQ==",
385+
}
375386

376387
# Mock create_data_key for encryption
377388
capture_and_mock_http_client_request(self.http_client, mock_data_key_pair, 200)
@@ -393,7 +404,12 @@ def test_decrypt_with_associated_data(
393404
self, mock_data_key, capture_and_mock_http_client_request
394405
):
395406
# First encrypt some data with associated data
396-
mock_data_key_pair = MockDataKeyPair().dict()
407+
mock_data_key_pair = {
408+
"context": {"key": "test-key"},
409+
"id": "key_01234567890abcdef",
410+
"data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=",
411+
"encrypted_keys": "ZW5jcnlwdGVkX2tleXNfZGF0YQ==",
412+
}
397413

398414
# Mock create_data_key for encryption
399415
capture_and_mock_http_client_request(self.http_client, mock_data_key_pair, 200)

workos/async_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from workos.utils.http_client import AsyncHTTPClient
1515
from workos.webhooks import WebhooksModule
1616
from workos.widgets import WidgetsModule
17+
from workos.vault import VaultModule
1718

1819

1920
class AsyncClient(BaseClient):
@@ -112,3 +113,9 @@ def widgets(self) -> WidgetsModule:
112113
raise NotImplementedError(
113114
"Widgets APIs are not yet supported in the async client."
114115
)
116+
117+
@property
118+
def vault(self) -> VaultModule:
119+
raise NotImplementedError(
120+
"Vault APIs are not yet supported in the async client."
121+
)

workos/client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from workos.user_management import UserManagement
1515
from workos.utils.http_client import SyncHTTPClient
1616
from workos.widgets import Widgets
17+
from workos.vault import Vault
1718

1819

1920
class SyncClient(BaseClient):
@@ -116,3 +117,9 @@ def widgets(self) -> Widgets:
116117
if not getattr(self, "_widgets", None):
117118
self._widgets = Widgets(http_client=self._http_client)
118119
return self._widgets
120+
121+
@property
122+
def vault(self) -> Vault:
123+
if not getattr(self, "_vault", None):
124+
self._vault = Vault(http_client=self._http_client)
125+
return self._vault

workos/types/vault/key.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Dict
2-
from pydantic import RootModel
2+
from pydantic import BaseModel, RootModel
33
from workos.types.workos_model import WorkOSModel
44

55

@@ -16,3 +16,10 @@ class DataKeyPair(WorkOSModel):
1616
context: KeyContext
1717
data_key: DataKey
1818
encrypted_keys: str
19+
20+
21+
class DecodedKeys(BaseModel):
22+
iv: bytes
23+
tag: bytes
24+
keys: str # Base64-encoded string
25+
ciphertext: bytes

workos/vault.py

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import base64
22
import struct
3-
from typing import Dict, Optional, Protocol, Sequence
3+
from typing import Dict, Optional, Protocol, Sequence, Tuple
44
from workos.types.vault import VaultObject, ObjectVersion
5-
from workos.types.vault.key import DataKey, DataKeyPair, KeyContext
5+
from workos.types.vault.key import DataKey, DataKeyPair, KeyContext, DecodedKeys
66
from workos.types.list_resource import (
77
ListArgs,
88
ListMetadata,
@@ -285,7 +285,7 @@ def create_object(
285285
request_data = {
286286
"name": name,
287287
"value": value,
288-
"key_context": key_context.root,
288+
"context": key_context,
289289
}
290290

291291
response = self._http_client.request(
@@ -341,7 +341,7 @@ def delete_object(
341341

342342
def create_data_key(self, *, key_context: KeyContext) -> DataKeyPair:
343343
request_data = {
344-
"key_context": key_context.root,
344+
"context": key_context,
345345
}
346346

347347
response = self._http_client.request(
@@ -350,7 +350,13 @@ def create_data_key(self, *, key_context: KeyContext) -> DataKeyPair:
350350
json=request_data,
351351
)
352352

353-
return DataKeyPair.model_validate(response)
353+
return DataKeyPair.model_validate(
354+
{
355+
"context": response["context"],
356+
"data_key": {"id": response["id"], "key": response["data_key"]},
357+
"encrypted_keys": response["encrypted_keys"],
358+
}
359+
)
354360

355361
def decrypt_data_key(
356362
self,
@@ -367,7 +373,9 @@ def decrypt_data_key(
367373
json=request_data,
368374
)
369375

370-
return DataKey.model_validate(response)
376+
return DataKey.model_validate(
377+
{"id": response["id"], "key": response["data_key"]}
378+
)
371379

372380
def encrypt(
373381
self, *, data: str, context: KeyContext, associated_data: Optional[str] = None
@@ -376,7 +384,7 @@ def encrypt(
376384

377385
key = self._base64_to_bytes(key_pair.data_key.key)
378386
key_blob = self._base64_to_bytes(key_pair.encrypted_keys)
379-
prefix_len_buffer = self._encode_uint32(len(key_blob))
387+
prefix_len_buffer = self._encode_u32(len(key_blob))
380388
aad_buffer = associated_data.encode("utf-8") if associated_data else None
381389
iv = self._crypto_provider.random_bytes(12)
382390

@@ -398,16 +406,16 @@ def decrypt(
398406
self, *, encrypted_data: str, associated_data: Optional[str] = None
399407
) -> str:
400408
decoded = self._decode(encrypted_data)
401-
data_key = self.decrypt_data_key(keys=self._bytes_to_base64(decoded["keys"]))
409+
data_key = self.decrypt_data_key(keys=decoded.keys)
402410

403411
key = self._base64_to_bytes(data_key.key)
404412
aad_buffer = associated_data.encode("utf-8") if associated_data else None
405413

406414
decrypted_bytes = self._crypto_provider.decrypt(
407-
ciphertext=decoded["ciphertext"],
415+
ciphertext=decoded.ciphertext,
408416
key=key,
409-
iv=decoded["iv"],
410-
tag=decoded["tag"],
417+
iv=decoded.iv,
418+
tag=decoded.tag,
411419
aad=aad_buffer,
412420
)
413421

@@ -419,30 +427,77 @@ def _base64_to_bytes(self, data: str) -> bytes:
419427
def _bytes_to_base64(self, data: bytes) -> str:
420428
return base64.b64encode(data).decode("utf-8")
421429

422-
def _encode_uint32(self, value: int) -> bytes:
423-
return struct.pack(">I", value) # Big-endian unsigned int (4 bytes)
430+
def _encode_u32(self, value: int) -> bytes:
431+
"""
432+
Encode a 32-bit unsigned integer as LEB128.
424433
425-
def _decode(self, encrypted_data_b64: str) -> Dict[str, bytes]:
434+
Returns:
435+
bytes: LEB128-encoded representation of the input value.
436+
"""
437+
if value < 0 or value > 0xFFFFFFFF:
438+
raise ValueError("Value must be a 32-bit unsigned integer")
439+
440+
encoded = bytearray()
441+
while True:
442+
byte = value & 0x7F
443+
value >>= 7
444+
if value != 0:
445+
byte |= 0x80 # Set continuation bit
446+
encoded.append(byte)
447+
if value == 0:
448+
break
449+
450+
return bytes(encoded)
451+
452+
def _decode(self, encrypted_data_b64: str) -> DecodedKeys:
426453
"""
427454
This function extracts IV, tag, keyBlobLength, keyBlob, and ciphertext
428-
from a base64-encoded payload. You must define this according to your encoding format.
429-
Assumes format: [IV][TAG][4B Length][keyBlob][ciphertext]
455+
from a base64-encoded payload.
456+
Encoding format: [IV][TAG][4B Length][keyBlob][ciphertext]
430457
"""
431-
raw = base64.b64decode(encrypted_data_b64)
432-
offset = 0
458+
try:
459+
payload = base64.b64decode(encrypted_data_b64)
460+
except Exception as e:
461+
raise ValueError("Base64 decoding failed") from e
462+
463+
iv = payload[0:12]
464+
tag = payload[12:28]
465+
466+
try:
467+
key_len, leb_len = self._decode_u32(payload[28:])
468+
except Exception as e:
469+
raise ValueError("Failed to decode key length") from e
470+
471+
keys_index = 28 + leb_len
472+
keys_end = keys_index + key_len
473+
keys_slice = payload[keys_index:keys_end]
474+
keys = base64.b64encode(keys_slice).decode("utf-8")
475+
ciphertext = payload[keys_end:]
433476

434-
iv = raw[offset : offset + 12]
435-
offset += 12
477+
return DecodedKeys(iv=iv, tag=tag, keys=keys, ciphertext=ciphertext)
478+
479+
def _decode_u32(self, buf: bytes) -> Tuple[int, int]:
480+
"""
481+
Decode an unsigned LEB128-encoded 32-bit integer from bytes.
482+
483+
Returns:
484+
(value, length_consumed)
485+
486+
Raises:
487+
ValueError if decoding fails or overflows.
488+
"""
489+
res = 0
490+
bit = 0
436491

437-
tag = raw[offset : offset + 16]
438-
offset += 16
492+
for i, b in enumerate(buf):
493+
if i > 4:
494+
raise ValueError("LEB128 integer overflow (was more than 4 bytes)")
439495

440-
key_len = int.from_bytes(raw[offset : offset + 4], byteorder="big")
441-
offset += 4
496+
res |= (b & 0x7F) << (7 * bit)
442497

443-
key_blob = raw[offset : offset + key_len]
444-
offset += key_len
498+
if (b & 0x80) == 0:
499+
return res, i + 1
445500

446-
ciphertext = raw[offset:]
501+
bit += 1
447502

448-
return {"iv": iv, "tag": tag, "keys": key_blob, "ciphertext": ciphertext}
503+
raise ValueError("LEB128 integer not found")

0 commit comments

Comments
 (0)