@@ -109,6 +109,15 @@ func (e *ECDH1PUKeyAgreement) GenerateEphemeralKey() error {
109109// Returns the derived key and the ephemeral public key (for inclusion in JWE header).
110110// Note: This method does NOT include the cc_tag. Use DeriveKeyWithTag for ECDH-1PU key commitment.
111111func (e * ECDH1PUKeyAgreement ) DeriveKey () ([]byte , jwk.Key , error ) {
112+ return e .DeriveKeyWithTag (nil )
113+ }
114+
115+ // DeriveKeyWithTag derives the key using ECDH-1PU with optional key commitment (cc_tag).
116+ // If ccTag is nil or empty, uses standard Concat KDF.
117+ // If ccTag is provided, includes the content ciphertext authentication tag in the KDF
118+ // per draft-madden-jose-ecdh-1pu for key commitment.
119+ // Returns the derived key and the ephemeral public key.
120+ func (e * ECDH1PUKeyAgreement ) DeriveKeyWithTag (ccTag []byte ) ([]byte , jwk.Key , error ) {
112121 if e .SenderPrivateKey == nil {
113122 return nil , nil , fmt .Errorf ("%w: sender private key is required for ECDH-1PU encryption" , ErrInvalidKey )
114123 }
@@ -119,48 +128,22 @@ func (e *ECDH1PUKeyAgreement) DeriveKey() ([]byte, jwk.Key, error) {
119128 }
120129 }
121130
122- // Extract raw keys for ECDH operations
123- ephemeralPriv , err := extractECDHPrivateKey (e .EphemeralPrivateKey )
124- if err != nil {
125- return nil , nil , fmt .Errorf ("failed to extract ephemeral private key: %w" , err )
126- }
127-
128- senderPriv , err := extractECDHPrivateKey (e .SenderPrivateKey )
129- if err != nil {
130- return nil , nil , fmt .Errorf ("failed to extract sender private key: %w" , err )
131- }
132-
133- recipientPub , err := extractECDHPublicKey (e .RecipientPublicKey )
131+ // Compute the shared secret Z = Z_es || Z_ss
132+ z , err := e .computeSharedSecret ()
134133 if err != nil {
135- return nil , nil , fmt .Errorf ("failed to extract recipient public key: %w" , err )
136- }
137-
138- // Compute Z_es = ECDH(ephemeral_private, recipient_public)
139- zES , err := ephemeralPriv .ECDH (recipientPub )
140- if err != nil {
141- return nil , nil , fmt .Errorf ("failed to compute Z_es: %w" , err )
142- }
143-
144- // Compute Z_ss = ECDH(sender_private, recipient_public)
145- zSS , err := senderPriv .ECDH (recipientPub )
146- if err != nil {
147- return nil , nil , fmt .Errorf ("failed to compute Z_ss: %w" , err )
134+ return nil , nil , err
148135 }
149136
150- // Concatenate: Z = Z_es || Z_ss
151- z := append (zES , zSS ... )
152-
153137 // Get the required key size for the content encryption algorithm
154138 keySize , err := getKeyWrapKeySize (e .Algorithm )
155139 if err != nil {
156140 return nil , nil , err
157141 }
158142
159- // Derive the key wrapping key using Concat KDF
160- // For ECDH-1PU+A256KW, we use the same KDF as ECDH-ES
161- derivedKey , err := concatKDF (z , e .Algorithm , e .APU , e .APV , keySize )
143+ // Derive the key wrapping key using the appropriate KDF
144+ derivedKey , err := e .deriveKeyFromZ (z , ccTag , keySize )
162145 if err != nil {
163- return nil , nil , fmt . Errorf ( "failed to derive key: %w" , err )
146+ return nil , nil , err
164147 }
165148
166149 // Get ephemeral public key for inclusion in JWE header
@@ -172,70 +155,56 @@ func (e *ECDH1PUKeyAgreement) DeriveKey() ([]byte, jwk.Key, error) {
172155 return derivedKey , ephemeralPubKey , nil
173156}
174157
175- // DeriveKeyWithTag derives the key using ECDH-1PU with key commitment (cc_tag).
176- // This includes the content ciphertext authentication tag in the KDF per draft-madden-jose-ecdh-1pu.
177- // Returns the derived key and the ephemeral public key.
178- func (e * ECDH1PUKeyAgreement ) DeriveKeyWithTag (ccTag []byte ) ([]byte , jwk.Key , error ) {
179- if e .SenderPrivateKey == nil {
180- return nil , nil , fmt .Errorf ("%w: sender private key is required for ECDH-1PU encryption" , ErrInvalidKey )
181- }
182-
183- if e .EphemeralPrivateKey == nil {
184- if err := e .GenerateEphemeralKey (); err != nil {
185- return nil , nil , err
186- }
187- }
188-
158+ // computeSharedSecret computes Z = Z_es || Z_ss for ECDH-1PU encryption.
159+ func (e * ECDH1PUKeyAgreement ) computeSharedSecret () ([]byte , error ) {
189160 // Extract raw keys for ECDH operations
190161 ephemeralPriv , err := extractECDHPrivateKey (e .EphemeralPrivateKey )
191162 if err != nil {
192- return nil , nil , fmt .Errorf ("failed to extract ephemeral private key: %w" , err )
163+ return nil , fmt .Errorf ("failed to extract ephemeral private key: %w" , err )
193164 }
194165
195166 senderPriv , err := extractECDHPrivateKey (e .SenderPrivateKey )
196167 if err != nil {
197- return nil , nil , fmt .Errorf ("failed to extract sender private key: %w" , err )
168+ return nil , fmt .Errorf ("failed to extract sender private key: %w" , err )
198169 }
199170
200171 recipientPub , err := extractECDHPublicKey (e .RecipientPublicKey )
201172 if err != nil {
202- return nil , nil , fmt .Errorf ("failed to extract recipient public key: %w" , err )
173+ return nil , fmt .Errorf ("failed to extract recipient public key: %w" , err )
203174 }
204175
205176 // Compute Z_es = ECDH(ephemeral_private, recipient_public)
206177 zES , err := ephemeralPriv .ECDH (recipientPub )
207178 if err != nil {
208- return nil , nil , fmt .Errorf ("failed to compute Z_es: %w" , err )
179+ return nil , fmt .Errorf ("failed to compute Z_es: %w" , err )
209180 }
210181
211182 // Compute Z_ss = ECDH(sender_private, recipient_public)
212183 zSS , err := senderPriv .ECDH (recipientPub )
213184 if err != nil {
214- return nil , nil , fmt .Errorf ("failed to compute Z_ss: %w" , err )
185+ return nil , fmt .Errorf ("failed to compute Z_ss: %w" , err )
215186 }
216187
217188 // Concatenate: Z = Z_es || Z_ss
218- z := append (zES , zSS ... )
219-
220- // Get the required key size for the content encryption algorithm
221- keySize , err := getKeyWrapKeySize (e .Algorithm )
222- if err != nil {
223- return nil , nil , err
224- }
189+ return append (zES , zSS ... ), nil
190+ }
225191
226- // Derive the key wrapping key using Concat KDF with cc_tag (key commitment)
227- derivedKey , err := concatKDF1PU (z , e .Algorithm , e .APU , e .APV , ccTag , keySize )
228- if err != nil {
229- return nil , nil , fmt .Errorf ("failed to derive key with tag: %w" , err )
192+ // deriveKeyFromZ derives the key wrapping key from shared secret Z.
193+ // Uses concatKDF1PU if ccTag is provided, otherwise uses standard concatKDF.
194+ func (e * ECDH1PUKeyAgreement ) deriveKeyFromZ (z , ccTag []byte , keySize int ) ([]byte , error ) {
195+ if len (ccTag ) > 0 {
196+ derivedKey , err := concatKDF1PU (z , e .Algorithm , e .APU , e .APV , ccTag , keySize )
197+ if err != nil {
198+ return nil , fmt .Errorf ("failed to derive key with tag: %w" , err )
199+ }
200+ return derivedKey , nil
230201 }
231202
232- // Get ephemeral public key for inclusion in JWE header
233- ephemeralPubKey , err := e .EphemeralPrivateKey .PublicKey ()
203+ derivedKey , err := concatKDF (z , e .Algorithm , e .APU , e .APV , keySize )
234204 if err != nil {
235- return nil , nil , fmt .Errorf ("failed to get ephemeral public key: %w" , err )
205+ return nil , fmt .Errorf ("failed to derive key: %w" , err )
236206 }
237-
238- return derivedKey , ephemeralPubKey , nil
207+ return derivedKey , nil
239208}
240209
241210// DeriveKeyForDecryption derives the key for decryption given the ephemeral public key.
@@ -281,19 +250,8 @@ func (e *ECDH1PUKeyAgreement) DeriveKeyForDecryption(ephemeralPubKey jwk.Key, se
281250 return nil , err
282251 }
283252
284- // Derive the key wrapping key
285- // Use concatKDF1PU if CCTag is set (for ECDH-1PU key commitment)
286- var derivedKey []byte
287- if len (e .CCTag ) > 0 {
288- derivedKey , err = concatKDF1PU (z , e .Algorithm , e .APU , e .APV , e .CCTag , keySize )
289- } else {
290- derivedKey , err = concatKDF (z , e .Algorithm , e .APU , e .APV , keySize )
291- }
292- if err != nil {
293- return nil , fmt .Errorf ("failed to derive key: %w" , err )
294- }
295-
296- return derivedKey , nil
253+ // Derive the key wrapping key (uses CCTag if set for key commitment)
254+ return e .deriveKeyFromZ (z , e .CCTag , keySize )
297255}
298256
299257// generateECDHKey generates an ECDH key pair for the specified curve.
0 commit comments