11import base64
22import struct
3- from typing import Dict , Optional , Protocol , Sequence
3+ from typing import Dict , Optional , Protocol , Sequence , Tuple
44from workos .types .vault import VaultObject , ObjectVersion
5- from workos .types .vault .key import DataKey , DataKeyPair , KeyContext
5+ from workos .types .vault .key import DataKey , DataKeyPair , KeyContext , DecodedKeys
66from workos .types .list_resource import (
77 ListArgs ,
88 ListMetadata ,
@@ -285,7 +285,7 @@ def create_object(
285285 request_data = {
286286 "name" : name ,
287287 "value" : value ,
288- "key_context " : key_context . root ,
288+ "context " : key_context ,
289289 }
290290
291291 response = self ._http_client .request (
@@ -341,7 +341,7 @@ def delete_object(
341341
342342 def create_data_key (self , * , key_context : KeyContext ) -> DataKeyPair :
343343 request_data = {
344- "key_context " : key_context . root ,
344+ "context " : key_context ,
345345 }
346346
347347 response = self ._http_client .request (
@@ -350,7 +350,13 @@ def create_data_key(self, *, key_context: KeyContext) -> DataKeyPair:
350350 json = request_data ,
351351 )
352352
353- return DataKeyPair .model_validate (response )
353+ return DataKeyPair .model_validate (
354+ {
355+ "context" : response ["context" ],
356+ "data_key" : {"id" : response ["id" ], "key" : response ["data_key" ]},
357+ "encrypted_keys" : response ["encrypted_keys" ],
358+ }
359+ )
354360
355361 def decrypt_data_key (
356362 self ,
@@ -367,7 +373,9 @@ def decrypt_data_key(
367373 json = request_data ,
368374 )
369375
370- return DataKey .model_validate (response )
376+ return DataKey .model_validate (
377+ {"id" : response ["id" ], "key" : response ["data_key" ]}
378+ )
371379
372380 def encrypt (
373381 self , * , data : str , context : KeyContext , associated_data : Optional [str ] = None
@@ -376,7 +384,7 @@ def encrypt(
376384
377385 key = self ._base64_to_bytes (key_pair .data_key .key )
378386 key_blob = self ._base64_to_bytes (key_pair .encrypted_keys )
379- prefix_len_buffer = self ._encode_uint32 (len (key_blob ))
387+ prefix_len_buffer = self ._encode_u32 (len (key_blob ))
380388 aad_buffer = associated_data .encode ("utf-8" ) if associated_data else None
381389 iv = self ._crypto_provider .random_bytes (12 )
382390
@@ -398,16 +406,16 @@ def decrypt(
398406 self , * , encrypted_data : str , associated_data : Optional [str ] = None
399407 ) -> str :
400408 decoded = self ._decode (encrypted_data )
401- data_key = self .decrypt_data_key (keys = self . _bytes_to_base64 ( decoded [ " keys" ]) )
409+ data_key = self .decrypt_data_key (keys = decoded . keys )
402410
403411 key = self ._base64_to_bytes (data_key .key )
404412 aad_buffer = associated_data .encode ("utf-8" ) if associated_data else None
405413
406414 decrypted_bytes = self ._crypto_provider .decrypt (
407- ciphertext = decoded [ " ciphertext" ] ,
415+ ciphertext = decoded . ciphertext ,
408416 key = key ,
409- iv = decoded [ "iv" ] ,
410- tag = decoded [ " tag" ] ,
417+ iv = decoded . iv ,
418+ tag = decoded . tag ,
411419 aad = aad_buffer ,
412420 )
413421
@@ -419,30 +427,77 @@ def _base64_to_bytes(self, data: str) -> bytes:
419427 def _bytes_to_base64 (self , data : bytes ) -> str :
420428 return base64 .b64encode (data ).decode ("utf-8" )
421429
422- def _encode_uint32 (self , value : int ) -> bytes :
423- return struct .pack (">I" , value ) # Big-endian unsigned int (4 bytes)
430+ def _encode_u32 (self , value : int ) -> bytes :
431+ """
432+ Encode a 32-bit unsigned integer as LEB128.
424433
425- def _decode (self , encrypted_data_b64 : str ) -> Dict [str , bytes ]:
434+ Returns:
435+ bytes: LEB128-encoded representation of the input value.
436+ """
437+ if value < 0 or value > 0xFFFFFFFF :
438+ raise ValueError ("Value must be a 32-bit unsigned integer" )
439+
440+ encoded = bytearray ()
441+ while True :
442+ byte = value & 0x7F
443+ value >>= 7
444+ if value != 0 :
445+ byte |= 0x80 # Set continuation bit
446+ encoded .append (byte )
447+ if value == 0 :
448+ break
449+
450+ return bytes (encoded )
451+
452+ def _decode (self , encrypted_data_b64 : str ) -> DecodedKeys :
426453 """
427454 This function extracts IV, tag, keyBlobLength, keyBlob, and ciphertext
428- from a base64-encoded payload. You must define this according to your encoding format.
429- Assumes format: [IV][TAG][4B Length][keyBlob][ciphertext]
455+ from a base64-encoded payload.
456+ Encoding format: [IV][TAG][4B Length][keyBlob][ciphertext]
430457 """
431- raw = base64 .b64decode (encrypted_data_b64 )
432- offset = 0
458+ try :
459+ payload = base64 .b64decode (encrypted_data_b64 )
460+ except Exception as e :
461+ raise ValueError ("Base64 decoding failed" ) from e
462+
463+ iv = payload [0 :12 ]
464+ tag = payload [12 :28 ]
465+
466+ try :
467+ key_len , leb_len = self ._decode_u32 (payload [28 :])
468+ except Exception as e :
469+ raise ValueError ("Failed to decode key length" ) from e
470+
471+ keys_index = 28 + leb_len
472+ keys_end = keys_index + key_len
473+ keys_slice = payload [keys_index :keys_end ]
474+ keys = base64 .b64encode (keys_slice ).decode ("utf-8" )
475+ ciphertext = payload [keys_end :]
433476
434- iv = raw [offset : offset + 12 ]
435- offset += 12
477+ return DecodedKeys (iv = iv , tag = tag , keys = keys , ciphertext = ciphertext )
478+
479+ def _decode_u32 (self , buf : bytes ) -> Tuple [int , int ]:
480+ """
481+ Decode an unsigned LEB128-encoded 32-bit integer from bytes.
482+
483+ Returns:
484+ (value, length_consumed)
485+
486+ Raises:
487+ ValueError if decoding fails or overflows.
488+ """
489+ res = 0
490+ bit = 0
436491
437- tag = raw [offset : offset + 16 ]
438- offset += 16
492+ for i , b in enumerate (buf ):
493+ if i > 4 :
494+ raise ValueError ("LEB128 integer overflow (was more than 4 bytes)" )
439495
440- key_len = int .from_bytes (raw [offset : offset + 4 ], byteorder = "big" )
441- offset += 4
496+ res |= (b & 0x7F ) << (7 * bit )
442497
443- key_blob = raw [ offset : offset + key_len ]
444- offset += key_len
498+ if ( b & 0x80 ) == 0 :
499+ return res , i + 1
445500
446- ciphertext = raw [ offset :]
501+ bit += 1
447502
448- return { "iv" : iv , "tag" : tag , "keys" : key_blob , "ciphertext" : ciphertext }
503+ raise ValueError ( "LEB128 integer not found" )
0 commit comments