diff --git a/README.md b/README.md index 4b4c5e6..7f56624 100644 --- a/README.md +++ b/README.md @@ -157,9 +157,11 @@ The most common usage of JWTs is to generate signed tokens. You can achieve this when building a JWT. ```kotlin +val signingKey = SigningAlgorithm.HS256.newKey() + val token: JwtInstance = Jwt.builder() .subject("1234567890") - .signWith(JwsAlgorithm.HS256, hmacKey) + .signWith(signingKey) ``` The result of the operation is a `JwtInstance` object. That object is a Kotlin representation of the JWT. You can use @@ -181,10 +183,10 @@ Another common usage of JWTs is to verify the authenticity of a token. If you ar ensure that the user hasn't modified the token on their side. That is achieved by verifying the signature of the token. ```kotlin -val compactToken: String = // +val compactToken: String = // val jwtParser = Jwt.parser() - .verifyWith(JwsAlgorithm.HS256, hmacKey) + .verifyWith(signingKey) .build() val parsedToken = jwtParser.parse(compactToken) @@ -206,28 +208,48 @@ implement any cryptographic operations. Instead, we rely on the [Cryptography Kotlin](https://github.com/whyoleg/cryptography-kotlin) library. It's an amazing and robust library that provides a wide range of cryptographic operations, and providers for most of the Kotlin Multiplatform targets. -To generate the `hmacKey` we used in the previous examples, you can use the following code: +KJWT ships extension functions on each algorithm object that generate or decode keys without +requiring you to touch the `cryptography-kotlin` API directly: ```kotlin -val myKeyString = "a-string-secret-at-least-256-bits-long" - .encodeToByteArray() // Convert the string into a byte array to perform the crypto operations +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm + +// Generate a new random HMAC key +val key = SigningAlgorithm.HS256.newKey() + +// Decode an HMAC key from existing bytes +val key = SigningAlgorithm.HS256.parse(myKeyBytes) + +// Generate an RSA key pair (also available for RS384/RS512, PS*, ES*) +val key = SigningAlgorithm.RS256.newKey() + +// Decode individual RSA/ECDSA keys +val key = SigningAlgorithm.RS256.parsePublicKey(pemBytes) // verify only +val key = SigningAlgorithm.RS256.parsePrivateKey(pemBytes) // sign only +val key = SigningAlgorithm.RS256.parseKeyPair(pubPem, privPem) +``` + +The returned `SigningKey` can be passed directly to `signWith` or `verifyWith`: -val hmacKey = - CryptographyProvider.Default // Get the provider you use for your project. CryptographyProvider.Default is most common - .get(HMAC) // Get the HMAC algorithm - .keyDecoder(SHA256) // Use the correct digest for your operation. For HS256, use SHA256. For HS384 use SHA384, etc. - .decodeFromByteArray(HMAC.Key.Format.RAW, myKeyString) // Decode your key bytes into a HMAC key +```kotlin +val key = SigningAlgorithm.HS256.parse(myKeyBytes) -// Then you can use the HMAC key to sign or verify tokens val token: JwtInstance = Jwt.builder() .subject("1234567890") - .signWith(JwsAlgorithm.HS256, hmacKey) + .signWith(key) // Use the key to sign the token val jwtParser = Jwt.parser() - .verifyWith(JwsAlgorithm.HS256, hmacKey) + .verifyWith(key) // Use the key to verify the token .build() + +val compactToken: String = // ... +val parsedToken = jwtParser.parse(compactToken) // Token will get verified using the key used in the builder ``` +If you prefer to work with `cryptography-kotlin` directly, you can also construct keys manually +using its API and pass them to `signWith` / `verifyWith`. For a full reference of all key helper +overloads, see the [usage guide](docs/usage.md#keys). + ### More features For a more detailed list of features, check out the usage documentation available at the [docs](docs/USAGE.md). \ No newline at end of file diff --git a/build-logic/src/main/kotlin/kjwt/tests.kt b/build-logic/src/main/kotlin/kjwt/tests.kt index 77c7089..7e879a7 100644 --- a/build-logic/src/main/kotlin/kjwt/tests.kt +++ b/build-logic/src/main/kotlin/kjwt/tests.kt @@ -52,8 +52,6 @@ private fun KotlinMultiplatformExtension.configureJSTests() { enabled = false } } - - return@configureEach } whenBrowserConfigured { diff --git a/config/detekt/detekt.yml b/config/detekt/detekt.yml index e16e227..6e21875 100644 --- a/config/detekt/detekt.yml +++ b/config/detekt/detekt.yml @@ -168,11 +168,11 @@ complexity: TooManyFunctions: active: true excludes: ['**/test/**', '**/androidTest/**', '**/commonTest/**', '**/jvmTest/**', '**/androidUnitTest/**', '**/androidInstrumentedTest/**', '**/jsTest/**', '**/iosTest/**'] - thresholdInFiles: 25 - thresholdInClasses: 25 - thresholdInInterfaces: 25 - thresholdInObjects: 25 - thresholdInEnums: 25 + thresholdInFiles: 35 + thresholdInClasses: 35 + thresholdInInterfaces: 35 + thresholdInObjects: 35 + thresholdInEnums: 35 ignoreDeprecated: true ignorePrivate: true ignoreOverridden: false @@ -687,7 +687,7 @@ style: active: false ReturnCount: active: true - max: 2 + max: 4 excludedFunctions: - 'equals' excludeLabeled: false diff --git a/docs/usage.md b/docs/usage.md index f04a5fc..71be0a5 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -8,18 +8,20 @@ All signing, verifying, encrypting, and decrypting operations are `suspend` func ```kotlin import co.touchlab.kjwt.Jwt -import co.touchlab.kjwt.algorithm.JwsAlgorithm import co.touchlab.kjwt.model.JwtInstance +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm import kotlin.time.Clock import kotlin.time.Duration.Companion.hours +val signingKey = SigningAlgorithm.HS256.newKey() + val jws: JwtInstance.Jws = Jwt.builder() .issuer("my-app") .subject("user-123") .audience("api") .expiresIn(1.hours) .issuedAt(Clock.System.now()) - .signWith(JwsAlgorithm.HS256, hmacKey) + .signWith(signingKey) val token: String = jws.compact() ``` @@ -28,7 +30,7 @@ val token: String = jws.compact() ```kotlin val parser = Jwt.parser() - .verifyWith(JwsAlgorithm.HS256, hmacKey) + .verifyWith(signingKey) .requireIssuer("my-app") .requireAudience("api") .clockSkew(30L) // seconds of tolerance @@ -41,10 +43,15 @@ val subject: String = jws.payload.subject ### Encrypt a JWT (JWE) ```kotlin +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm +import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm + +val encKey = EncryptionAlgorithm.RsaOaep256.newKey() + val jwe: JwtInstance.Jwe = Jwt.builder() .subject("user-123") .expiresIn(1.hours) - .encryptWith(rsaPublicKey, JweKeyAlgorithm.RsaOaep256, JweContentAlgorithm.A256GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) val token: String = jwe.compact() ``` @@ -53,7 +60,7 @@ val token: String = jwe.compact() ```kotlin val parser = Jwt.parser() - .decryptWith(JweKeyAlgorithm.RsaOaep256, rsaPrivateKey) + .decryptWith(encKey) .build() val jwe = parser.parseEncrypted(token) @@ -62,11 +69,173 @@ val subject: String = jwe.payload.subject --- +## Keys + +The `co.touchlab.kjwt.ext` package provides extension functions on each algorithm family for +generating and decoding keys. The goal of those extensions are hide the `cryptography-kotlin` API, +and simplify the integration for the developers. + +### HMAC keys (HS256 / HS384 / HS512) + +```kotlin +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm + +// Generate a new random key +val signingKey = SigningAlgorithm.HS256.newKey() + +// Decode an existing key from raw bytes +val signingKey = SigningAlgorithm.HS256.parse(keyBytes) + +// Decode from a non-default format (e.g. JWK) +val signingKey = SigningAlgorithm.HS256.parse(keyBytes, format = HMAC.Key.Format.JWK) +``` + +The returned `SigningKey` is a `SigningKeyPair` — usable for both signing and verification because +HMAC uses a single symmetric key. + +### RSA PKCS#1 v1.5 keys (RS256 / RS384 / RS512) + +```kotlin +// Generate a new key pair (defaults: 4096-bit modulus, exponent 65537) +val signingKey = SigningAlgorithm.RS256.newKey() +val signingKey = SigningAlgorithm.RS256.newKey(keySize = 2048.bits) + +// Decode individual keys (PEM is the default format) +val verifyKey = SigningAlgorithm.RS256.parsePublicKey(pemBytes) // VerifyOnlyKey +val signKey = SigningAlgorithm.RS256.parsePrivateKey(pemBytes) // SigningOnlyKey + +// Decode both at once +val signingKey = SigningAlgorithm.RS256.parseKeyPair(publicPem, privatePem) +``` + +### RSA PSS keys (PS256 / PS384 / PS512) + +```kotlin +// Generate a new key pair +val signingKey = SigningAlgorithm.PS256.newKey() + +// Decode individual keys +val verifyKey = SigningAlgorithm.PS256.parsePublicKey(pemBytes) +val signKey = SigningAlgorithm.PS256.parsePrivateKey(pemBytes) + +// Decode both at once +val signingKey = SigningAlgorithm.PS256.parseKeyPair(publicPem, privatePem) +``` + +### ECDSA keys (ES256 / ES384 / ES512) + +The curve is inferred from the algorithm — P-256 for ES256, P-384 for ES384, P-521 for ES512. + +```kotlin +// Generate a new key pair +val signingKey = SigningAlgorithm.ES256.newKey() + +// Decode individual keys (RAW is the default format) +val verifyKey = SigningAlgorithm.ES256.parsePublicKey(rawBytes) +val signKey = SigningAlgorithm.ES256.parsePrivateKey(rawBytes) + +// Decode both at once +val signingKey = SigningAlgorithm.ES256.parseKeyPair(publicBytes, privateBytes) + +// PEM format +val verifyKey = SigningAlgorithm.ES256.parsePublicKey(pemBytes, format = EC.PublicKey.Format.PEM) +``` + +### Associating a `kid` with a key + +All helpers accept an optional `keyId` parameter. When set, it is embedded in the `SigningKey` +identifier so the parser can select the right key by matching the token's `kid` header field: + +```kotlin +val signingKey = SigningAlgorithm.RS256.parseKeyPair(publicPem, privatePem, keyId = "key-2024") +``` + +### Using a signing key with the parser + +The `SigningKey` returned by any of these helpers can be passed directly to +`JwtParserBuilder.verifyWith`: + +```kotlin +val key = SigningAlgorithm.HS256.parse(keyBytes) + +val parser = Jwt.parser() + .verifyWith(key) + .build() +``` + +--- + +## Encryption Keys + +The `co.touchlab.kjwt.ext` package also provides extension functions on each encryption algorithm +family for generating and decoding JWE keys. + +### Direct key (`dir`) + +The `dir` algorithm uses the raw key bytes directly as the Content Encryption Key (CEK). The byte +length must match the content algorithm's required size (16 bytes for A128GCM/A128CBC-HS256, +24 bytes for A192GCM/A192CBC-HS384, 32 bytes for A256GCM/A256CBC-HS512). + +```kotlin +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm + +// Wrap existing bytes (length must match the content algorithm) +val encKey = EncryptionAlgorithm.Dir.key(cekBytes) + +// Generate random bytes of a given size (defaults to 256 bits) +val encKey = EncryptionAlgorithm.Dir.newKey() +val encKey = EncryptionAlgorithm.Dir.newKey(keySize = 128.bits) +``` + +The returned `EncryptionKey` is an `EncryptionKeyPair` — usable for both encryption and +decryption since `dir` uses the same symmetric key for both operations. + +### RSA-OAEP keys (RSA-OAEP / RSA-OAEP-256) + +```kotlin +// Generate a new key pair (defaults: 4096-bit modulus, exponent 65537) +val encKey = EncryptionAlgorithm.RsaOaep.newKey() +val encKey = EncryptionAlgorithm.RsaOaep256.newKey(keySize = 2048.bits) + +// Decode individual keys (PEM is the default format) +val encryptKey = EncryptionAlgorithm.RsaOaep.parsePublicKey(pemBytes) // EncryptionOnlyKey +val decryptKey = EncryptionAlgorithm.RsaOaep.parsePrivateKey(pemBytes) // DecryptionOnlyKey + +// Decode both at once +val encKey = EncryptionAlgorithm.RsaOaep.parseKeyPair(publicPem, privatePem) +``` + +### Associating a `kid` with an encryption key + +All helpers accept an optional `keyId` parameter, which is embedded in the `EncryptionKey` +identifier so the parser can select the right key by matching the token's `kid` header field: + +```kotlin +val encKey = EncryptionAlgorithm.RsaOaep.parseKeyPair(publicPem, privatePem, keyId = "enc-key-2024") +``` + +### Using an encryption key with the parser + +The `EncryptionKey` returned by any of these helpers can be passed directly to +`JwtParserBuilder.decryptWith`: + +```kotlin +val key = EncryptionAlgorithm.RsaOaep.parsePrivateKey(pemBytes) + +val parser = Jwt.parser() + .decryptWith(key) + .build() +``` + +--- + ## Standard Claims All seven RFC 7519 registered claims are supported via the builder: ```kotlin +val signingKey = SigningAlgorithm.HS256.newKey() + val jws: JwtInstance.Jws = Jwt.builder() .issuer("my-app") // iss .subject("user-123") // sub @@ -79,7 +248,7 @@ val jws: JwtInstance.Jws = Jwt.builder() .issuedNow() // iat (convenience: now) .id("unique-token-id") // jti .randomId() // jti (convenience: random UUID, @ExperimentalUuidApi) - .signWith(JwsAlgorithm.HS256, hmacKey) + .signWith(signingKey) val token: String = jws.compact() ``` @@ -98,7 +267,7 @@ val jws: JwtInstance.Jws = Jwt.builder() .claim("metadata", MyMetadata.serializer(), MyMetadata(version = 2)) // raw JsonElement .claim("raw", JsonPrimitive(42)) - .signWith(JwsAlgorithm.HS256, hmacKey) + .signWith(signingKey) val token: String = jws.compact() ``` @@ -106,18 +275,192 @@ val token: String = jws.compact() ## Header Parameters ```kotlin +val rsaSigningKey = SigningAlgorithm.RS256.parsePrivateKey(pemBytes, keyId = "key-2024-01") + val jws: JwtInstance.Jws = Jwt.builder() .subject("user-123") - .keyId("key-2024-01") // kid header parameter .header { type = "JWT" // typ (default: "JWT") contentType = "application/json" // cty } - .signWith(JwsAlgorithm.RS256, rsaPrivateKey) + .signWith(rsaSigningKey) val token: String = jws.compact() ``` +## Key ID (`kid`) + +The `kid` header parameter identifies which key was used to sign or encrypt a token — defined in RFC 7515 §4.1.4 for JWS and RFC 7516 §4.1.6 for JWE. It is useful when a server holds multiple keys or rotates keys over time — the recipient can use `kid` to look up the correct verification or decryption key without trying each one. + +### Setting `kid` when signing + +Pass the key ID via the `keyId` parameter when constructing the key — it is automatically embedded +in the key's identifier and written to the JWT header: + +```kotlin +val signingKey = SigningAlgorithm.RS256.parsePrivateKey(pemBytes, keyId = "key-2024-01") + +val jws: JwtInstance.Jws = Jwt.builder() + .subject("user-123") + .signWith(signingKey) + +// → header: {"typ":"JWT","alg":"RS256","kid":"key-2024-01"} +``` + +### Setting `kid` when encrypting + +Same pattern for `encryptWith`: + +```kotlin +val encKey = EncryptionAlgorithm.RsaOaep256.parsePublicKey(pemBytes, keyId = "enc-key-1") + +val jwe: JwtInstance.Jwe = Jwt.builder() + .subject("user-123") + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) + +// → header: {"alg":"RSA-OAEP-256","enc":"A256GCM","kid":"enc-key-1"} +``` + +### JWK extensions + +When using the JWK builder extensions, `kid` defaults to the JWK's own `kid` field so you don't have to repeat it: + +```kotlin +val jwk = Jwk.Rsa(/* ... */, kid = "key-2024-01") + +Jwt.builder().subject("user-123") + .signWith(SigningAlgorithm.RS256, jwk) // kid = "key-2024-01" (from jwk.kid) + .signWith(SigningAlgorithm.RS256, jwk, null) // kid omitted + .signWith(SigningAlgorithm.RS256, jwk, "other-key") // kid = "other-key" (explicit override) +``` + +### Reading `kid` from a parsed token + +```kotlin +val jws = parser.parseSigned(token) +val kid: String? = jws.header.keyId // null if the header did not include kid +``` + +--- + +## Multiple Keys (Key Rotation) + +The parser can hold multiple keys for the same algorithm, each identified by an optional `kid`. This is useful for key rotation — where old and new keys need to coexist during a transition period — or for multi-tenant scenarios where different parties use different keys. + +### Registering multiple signing keys + +Embed a `kid` in each key via the `keyId` parameter when constructing it. Each `(algorithm, kid)` pair must be unique; registering the same combination twice throws `IllegalArgumentException` at builder time. + +```kotlin +val key2024 = SigningAlgorithm.RS256.parsePublicKey(pem2024, keyId = "key-2024") +val key2025 = SigningAlgorithm.RS256.parsePublicKey(pem2025, keyId = "key-2025") + +val parser = Jwt.parser() + .verifyWith(key2024) + .verifyWith(key2025) + .build() + +// Token signed with kid="key-2024" → verified with key2024 +// Token signed with kid="key-2025" → verified with key2025 +``` + +### Lookup priority + +When parsing a token the key is selected by this ordered strategy: + +1. **Exact match** — find a registered key whose algorithm and `kid` both match the token's header. +2. **Algo-only fallback** — if the token has a `kid` but no exact match exists, use the key registered for that algorithm *without* a `kid` (constructed without a `keyId`). This lets you register a single "catch-all" key alongside specific ones. +3. **`noVerify()` fallback** — if no key is found and `noVerify()` was configured on the builder, signature verification is skipped entirely. + +```kotlin +val specificKey = SigningAlgorithm.RS256.parsePublicKey(pem2024, keyId = "key-2024") +val fallbackKey = SigningAlgorithm.RS256.parsePublicKey(pemFallback) // no keyId → catch-all + +val parser = Jwt.parser() + .verifyWith(specificKey) // matched first by exact kid + .verifyWith(fallbackKey) // used when no exact kid match + .build() +``` + +If no key matches and `noVerify()` was not set, parsing throws `IllegalStateException`. + +### Multiple decryption keys (JWE) + +The same rules apply to `decryptWith`: + +```kotlin +val privateKey2024 = EncryptionAlgorithm.RsaOaep256.parsePrivateKey(pem2024, keyId = "enc-key-2024") +val privateKey2025 = EncryptionAlgorithm.RsaOaep256.parsePrivateKey(pem2025, keyId = "enc-key-2025") + +val parser = Jwt.parser() + .decryptWith(privateKey2024) + .decryptWith(privateKey2025) + .build() +``` + +### Using a shared `JwtKeyRegistry` + +`JwtKeyRegistry` is a centralised key store that can be shared across multiple builder and parser +instances. This is useful when you want to manage keys in one place — for example in a dependency +injection container — and reuse them without repeating configuration. + +#### Signing with a registry + +Pass a `JwtKeyRegistry` to `signWith` instead of a raw key: + +```kotlin +val registry = JwtKeyRegistry() +// Keys are added to the registry via JwtParserBuilder and shared by reference, +// or by registering them directly when both parties share the same module. + +val token = Jwt.builder() + .subject("user-123") + .signWith(SigningAlgorithm.HS256, registry) // looks up the private key from the registry + .compact() +``` + +If no matching key is found in the registry an `IllegalStateException` is thrown. + +#### Encrypting with a registry + +Same pattern for JWE encryption: + +```kotlin +val token = Jwt.builder() + .subject("user-123") + .encryptWith(registry, EncryptionAlgorithm.RsaOaep256, EncryptionContentAlgorithm.A256GCM) + .compact() +``` + +#### Sharing a registry with the parser — `useKeysFrom` + +`useKeysFrom` configures a parser to delegate key look-up to an existing registry. The registry is +searched **before** any keys registered directly on the parser builder, so a shared registry acts +as the primary key source. + +```kotlin +val parser = Jwt.parser() + .useKeysFrom(registry) // delegate to shared registry + .requireIssuer("my-app") + .build() + +val jws = parser.parseSigned(token) +``` + +You can combine `useKeysFrom` with direct `verifyWith` / `decryptWith` calls. The parser's own +keys take priority; the registry is only consulted when no local key matches: + +```kotlin +val localKey = SigningAlgorithm.HS256.newKey() + +val parser = Jwt.parser() + .verifyWith(localKey) // checked first + .useKeysFrom(sharedRegistry) // fallback if no local key matches + .build() +``` + +--- + ## Parsing Claims Access standard claims via extension properties. Mandatory variants throw `MissingClaimException` if the claim is absent; `OrNull` variants return `null`: @@ -151,12 +494,14 @@ val role: String? = payload.getClaimOrNull("role") Configure required claims on the parser; any failure throws an appropriate exception: ```kotlin +val ecSigningKey = SigningAlgorithm.ES256.parsePublicKey(rawBytes) + val parser = Jwt.parser() - .verifyWith(JwsAlgorithm.ES256, ecPublicKey) + .verifyWith(ecSigningKey) .requireIssuer("my-app") // throws IncorrectClaimException on mismatch .requireSubject("user-123") .requireAudience("api") - .require("role", "admin") // generic claim equality check + .requireClaim("role", "admin") // generic claim equality check .clockSkew(30L) // seconds of exp/nbf tolerance .build() ``` @@ -175,7 +520,7 @@ Permits tokens where `alg=none` was used at creation time. All other algorithms // Create an unsecured JWT val jws: JwtInstance.Jws = Jwt.builder() .subject("user-123") - .signWith(JwsAlgorithm.None) + .build() val token: String = jws.compact() @@ -250,20 +595,22 @@ println(payload.subject) For symmetric encryption where the key is used directly as the CEK (no key wrapping): ```kotlin -import co.touchlab.kjwt.algorithm.JweKeyAlgorithm -import co.touchlab.kjwt.algorithm.JweContentAlgorithm -import co.touchlab.kjwt.ext.encryptWith // extension for ByteArray / String keys +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm +import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm + +// Wrap existing raw bytes as a symmetric encryption key +val encKey = EncryptionAlgorithm.Dir.key(cekBytes) +// Or generate a fresh random key: EncryptionAlgorithm.Dir.newKey() -// Encrypt - key is the raw CEK bytes (must match content algorithm key size) val jwe: JwtInstance.Jwe = Jwt.builder() .subject("user-123") - .encryptWith(cekBytes, JweKeyAlgorithm.Dir, JweContentAlgorithm.A256GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) val token: String = jwe.compact() -// Decrypt +// Decrypt — use the same key (Dir is symmetric) val parser = Jwt.parser() - .decryptWith(JweKeyAlgorithm.Dir, SimpleKey(cekBytes)) + .decryptWith(encKey) .build() val jwe = parser.parseEncrypted(token) diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/builder/JwtBuilder.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/builder/JwtBuilder.kt index 708a6cf..eaa6393 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/builder/JwtBuilder.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/builder/JwtBuilder.kt @@ -10,6 +10,10 @@ import co.touchlab.kjwt.model.JwtPayload import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import co.touchlab.kjwt.model.registry.EncryptionKey +import co.touchlab.kjwt.model.registry.JwtKeyRegistry +import co.touchlab.kjwt.model.registry.SigningKey +import co.touchlab.kjwt.model.registry.SigningKey.Identifier import dev.whyoleg.cryptography.materials.key.Key import kotlinx.serialization.SerializationStrategy import kotlinx.serialization.json.JsonElement @@ -22,18 +26,20 @@ import kotlin.uuid.ExperimentalUuidApi * * Example — signed: * ```kotlin + * val signingKey = SigningAlgorithm.HS256.newKey() * val token = Jwt.builder() * .subject("user123") * .issuer("myapp") * .expiration(Clock.System.now() + 1.hours) - * .signWith(JwsAlgorithm.HS256, hmacKey) + * .signWith(signingKey) * ``` * * Example — encrypted: * ```kotlin + * val encKey = EncryptionAlgorithm.RsaOaep256.newKey() * val token = Jwt.builder() * .subject("user123") - * .encryptWith(rsaPublicKey, JweKeyAlgorithm.RsaOaep256, JweContentAlgorithm.A256GCM) + * .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) * ``` */ public class JwtBuilder { @@ -177,28 +183,94 @@ public class JwtBuilder { apply { headerBuilder.block() } /** - * Sets the key ID (`kid`) header parameter. + * Builds and returns a JWS compact serialization: `header.payload.signature`. * - * @param kid the key identifier - * @return this builder for chaining + * For [SigningAlgorithm.None] the signature part is empty, producing `header.payload.` + * + * @param algorithm the signing algorithm to use + * @param key the private key (or symmetric key) used to produce the signature + * @param keyId optional key ID to embed in the JWT header's `kid` field. Defaults to `null`. + * @return the resulting [JwtInstance.Jws] compact serialization */ - public fun keyId(kid: String): JwtBuilder = - apply { headerBuilder.keyId = kid } + public suspend fun signWith( + algorithm: SigningAlgorithm, + key: PrivateKey, + keyId: String? = null, + ): JwtInstance.Jws = + signWithSigningKey(SigningKey.SigningOnlyKey(Identifier(algorithm, keyId), key)) /** - * Builds and returns a JWS compact serialization: `header.payload.signature`. + * Looks up the private key from [registry] and builds a JWS compact serialization. * - * For [SigningAlgorithm.None] the signature part is empty, producing `header.payload.` + * The registry is searched using [algorithm] and [keyId] as the look-up criteria (see + * [JwtKeyRegistry] for the full look-up order). If no matching key is found an + * [IllegalStateException] is thrown. + * + * Passing [co.touchlab.kjwt.model.algorithm.SigningAlgorithm.None] delegates directly to + * [build] (unsecured token) without consulting the registry. + * + * @param algorithm the signing algorithm to use + * @param registry the key registry to look up the private key from + * @param keyId optional key ID used for registry look-up and embedded in the JWT header's + * `kid` field. Defaults to `null`. + * @return the resulting [JwtInstance.Jws] compact serialization + * @throws IllegalStateException if no signing key for [algorithm] (and [keyId]) is found in + * [registry] + * @see JwtKeyRegistry */ public suspend fun signWith( algorithm: SigningAlgorithm, - key: PrivateKey + registry: JwtKeyRegistry, + keyId: String? = null, ): JwtInstance.Jws { - val header = headerBuilder.build(algorithm) + if (algorithm == SigningAlgorithm.None) { + return build() + } + + val key = requireNotNull(registry.findBestSigningKey(algorithm, keyId)) { + "No signing key configured for ${algorithm.id}." + } + + require(key.canSign) { "The signing key for $keyId does not support signing" } + + return signWithSigningKey(key, keyId) + } + + /** + * Builds and returns a JWS compact serialization using a pre-built [SigningKey.SigningOnlyKey]. + * + * @param key the signing key (or key pair) used to produce the signature + * @param keyId optional key ID to embed in the JWT header's `kid` field. Defaults to the + * key ID stored in [key]'s identifier. + * @return the resulting [JwtInstance.Jws] compact serialization + */ + public suspend fun signWith( + key: SigningKey.SigningOnlyKey, + keyId: String? = key.identifier.keyId, + ): JwtInstance.Jws = signWithSigningKey(key, keyId) + + /** + * Builds and returns a JWS compact serialization using a pre-built [SigningKey.SigningKeyPair]. + * + * @param key the signing key (or key pair) used to produce the signature + * @param keyId optional key ID to embed in the JWT header's `kid` field. Defaults to the + * key ID stored in [key]'s identifier. + * @return the resulting [JwtInstance.Jws] compact serialization + */ + public suspend fun signWith( + key: SigningKey.SigningKeyPair, + keyId: String? = key.identifier.keyId, + ): JwtInstance.Jws = signWithSigningKey(key, keyId) + + private suspend fun signWithSigningKey( + key: SigningKey, + keyId: String? = key.identifier.keyId, + ): JwtInstance.Jws { + val header = headerBuilder.build(key.identifier.algorithm, keyId) val payload = payloadBuilder.build() val signingInput = "$header.$payload".encodeToByteArray() - val signature = algorithm.sign(key, signingInput) + val signature = if (key.identifier.algorithm == SigningAlgorithm.None) ByteArray(0) else key.sign(signingInput) return JwtInstance.Jws(header, payload, signature.encodeBase64Url()) } @@ -206,30 +278,107 @@ public class JwtBuilder { /** * Builds and returns an unsecured JWS token with `alg=none` and an empty signature. * - * @param algorithm the [SigningAlgorithm.None] sentinel value * @return the resulting [JwtInstance.Jws] with an empty signature segment * @see co.touchlab.kjwt.parser.JwtParserBuilder.allowUnsecured */ - public suspend fun signWith(algorithm: SigningAlgorithm.None): JwtInstance.Jws = - signWith(algorithm, SimpleKey.Empty) + public suspend fun build(): JwtInstance.Jws = + signWith(SigningAlgorithm.None, SimpleKey.Empty) /** * Builds and returns a JWE compact serialization: * `header.encryptedKey.iv.ciphertext.tag` + * + * @param key the public key used to encrypt the content encryption key + * @param keyAlgorithm the key encryption algorithm used to wrap the content encryption key + * @param contentAlgorithm the content encryption algorithm used to encrypt the payload + * @param keyId optional key ID to embed in the JWE header's `kid` field. Defaults to `null`. + * @return the resulting [JwtInstance.Jwe] compact serialization */ public suspend fun encryptWith( key: PublicKey, keyAlgorithm: EncryptionAlgorithm, contentAlgorithm: EncryptionContentAlgorithm, + keyId: String? = null, + ): JwtInstance.Jwe = encryptWithEncryptionKey( + key = EncryptionKey.EncryptionOnlyKey(EncryptionKey.Identifier(keyAlgorithm, keyId), key), + contentAlgorithm = contentAlgorithm + ) + + /** + * Looks up the public key from [registry] and builds a JWE compact serialization. + * + * The registry is searched using [keyAlgorithm] and [keyId] as the look-up criteria (see + * [JwtKeyRegistry] for the full look-up order). If no matching key is found an + * [IllegalStateException] is thrown. + * + * @param registry the key registry to look up the public encryption key from + * @param keyAlgorithm the key encryption algorithm used to wrap the content encryption key + * @param contentAlgorithm the content encryption algorithm used to encrypt the payload + * @param keyId optional key ID used for registry look-up and embedded in the JWE header's + * `kid` field. Defaults to `null`. + * @return the resulting [JwtInstance.Jwe] compact serialization + * @throws IllegalStateException if no encryption key for [keyAlgorithm] (and [keyId]) is + * found in [registry] + * @see JwtKeyRegistry + */ + public suspend fun encryptWith( + registry: JwtKeyRegistry, + keyAlgorithm: EncryptionAlgorithm, + contentAlgorithm: EncryptionContentAlgorithm, + keyId: String? = null, + ): JwtInstance.Jwe { + val key = requireNotNull(registry.findBestEncryptionKey(keyAlgorithm, keyId)) { + "No signing key configured for ${keyAlgorithm.id}." + } + + require(key.canEncrypt) { "The signing key for $keyId does not support encryption." } + + return encryptWithEncryptionKey(key, contentAlgorithm, keyId) + } + + /** + * Builds and returns a JWE compact serialization using a pre-built [EncryptionKey.EncryptionOnlyKey]. + * + * @param key the encryption key used to wrap the content encryption key + * @param contentAlgorithm the content encryption algorithm used to encrypt the payload + * @param keyId optional key ID to embed in the JWE header's `kid` field. Defaults to the + * key ID stored in [key]'s identifier. + * @return the resulting [JwtInstance.Jwe] compact serialization + */ + public suspend fun encryptWith( + key: EncryptionKey.EncryptionOnlyKey, + contentAlgorithm: EncryptionContentAlgorithm, + keyId: String? = key.identifier.keyId, + ): JwtInstance.Jwe = encryptWithEncryptionKey(key, contentAlgorithm, keyId) + + /** + * Builds and returns a JWE compact serialization using a pre-built [EncryptionKey.EncryptionKeyPair]. + * + * @param key the encryption key used to wrap the content encryption key + * @param contentAlgorithm the content encryption algorithm used to encrypt the payload + * @param keyId optional key ID to embed in the JWE header's `kid` field. Defaults to the + * key ID stored in [key]'s identifier. + * @return the resulting [JwtInstance.Jwe] compact serialization + */ + public suspend fun encryptWith( + key: EncryptionKey.EncryptionKeyPair, + contentAlgorithm: EncryptionContentAlgorithm, + keyId: String? = key.identifier.keyId, + ): JwtInstance.Jwe = encryptWithEncryptionKey(key, contentAlgorithm, keyId) + + private suspend fun encryptWithEncryptionKey( + key: EncryptionKey, + contentAlgorithm: EncryptionContentAlgorithm, + keyId: String? = key.identifier.keyId, ): JwtInstance.Jwe { - val header = headerBuilder.build(keyAlgorithm, contentAlgorithm) + val header = headerBuilder.build(key.identifier.algorithm, contentAlgorithm, keyId) val payload = payloadBuilder.build() val headerB64 = JwtJson.encodeToBase64Url(header) val aad = headerB64.encodeToByteArray() val plaintext = JwtJson.encodeToString(payload).encodeToByteArray() - val result = keyAlgorithm.encrypt(key, contentAlgorithm, plaintext, aad) + val result = key.encrypt(contentAlgorithm, plaintext, aad) return JwtInstance.Jwe( header = header, diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/EncryptionAlgorithmExt.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/EncryptionAlgorithmExt.kt new file mode 100644 index 0000000..311ff25 --- /dev/null +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/EncryptionAlgorithmExt.kt @@ -0,0 +1,178 @@ +@file:OptIn(dev.whyoleg.cryptography.DelicateCryptographyApi::class) + +package co.touchlab.kjwt.ext + +import co.touchlab.kjwt.cryptography.SimpleKey +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm.Dir +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm.OAEPBased +import co.touchlab.kjwt.model.registry.EncryptionKey +import dev.whyoleg.cryptography.BinarySize +import dev.whyoleg.cryptography.BinarySize.Companion.bits +import dev.whyoleg.cryptography.CryptographyProvider +import dev.whyoleg.cryptography.algorithms.RSA +import dev.whyoleg.cryptography.bigint.BigInt +import dev.whyoleg.cryptography.bigint.toBigInt +import kotlin.random.Random + +// ---- Dir --------------------------------------------------------------- + +/** + * Wraps an existing raw key [ByteArray] as a [Dir] encryption key. + * + * The byte length must match the Content Encryption Algorithm's required key size + * (e.g. 16 bytes for A128GCM, 32 bytes for A256GCM). The returned [EncryptionKey] is an + * [EncryptionKey.EncryptionKeyPair] usable for both encryption and decryption since `dir` + * uses the same symmetric key for both operations. + * + * @param key the raw symmetric key bytes to wrap. + * @param keyId optional key ID to associate with this key. Defaults to `null`. + * @return an [EncryptionKey] wrapping the provided key bytes. + */ +public fun Dir.key( + key: ByteArray, + keyId: String? = null, +): EncryptionKey.EncryptionKeyPair { + val simpleKey = SimpleKey(key) + return EncryptionKey.EncryptionKeyPair(EncryptionKey.Identifier(this, keyId), simpleKey, simpleKey) +} + +/** + * Generates a new random symmetric key for use with the `dir` algorithm. + * + * The returned [EncryptionKey] is an [EncryptionKey.EncryptionKeyPair] usable for both + * encryption and decryption since `dir` uses the same key for both operations. + * + * @param keySize the size of the key to generate in bits. Must match the Content Encryption + * Algorithm's required key size (e.g. 128, 192, or 256 bits). Defaults to 256 bits. + * @param keyId optional key ID to associate with the generated key. Defaults to `null`. + * @return an [EncryptionKey] wrapping the generated key bytes. + */ +public fun Dir.newKey( + keySize: BinarySize = 256.bits, + keyId: String? = null, +): EncryptionKey.EncryptionKeyPair = + key(Random.nextBytes(keySize.inBytes), keyId) + +// ---- OAEPBased --------------------------------------------------------- + +/** + * Generates a new RSA-OAEP key pair for use with this algorithm. + * + * The returned [EncryptionKey] is an [EncryptionKey.EncryptionKeyPair] containing both the + * public and private key, usable for encryption and decryption. + * + * @param keyId optional key ID to associate with the generated key pair. Defaults to `null`. + * @param keySize the RSA modulus size in bits. Defaults to 4096 bits. + * @param publicExponent the RSA public exponent. Defaults to 65537. + * @param cryptographyProvider the provider used to perform key generation. + * @return an [EncryptionKey] wrapping the generated [RSA.OAEP] key pair. + */ +public suspend fun OAEPBased.newKey( + keyId: String? = null, + keySize: BinarySize = 4096.bits, + publicExponent: BigInt = 65537.toBigInt(), + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): EncryptionKey.EncryptionKeyPair { + val rsaKeyPair = cryptographyProvider.get(RSA.OAEP) + .keyPairGenerator(keySize, digest, publicExponent) + .generateKey() + + return EncryptionKey.EncryptionKeyPair( + EncryptionKey.Identifier(this, keyId), + rsaKeyPair.publicKey, + rsaKeyPair.privateKey, + ) +} + +/** + * Decodes an RSA-OAEP public key from a [ByteArray] for use with this algorithm. + * + * The returned [EncryptionKey] is an [EncryptionKey.EncryptionOnlyKey] that can encrypt tokens + * but cannot decrypt them. + * + * @param key the public key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [RSA.PublicKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return an [EncryptionKey] wrapping the decoded [RSA.OAEP.PublicKey]. + */ +public suspend fun OAEPBased.parsePublicKey( + key: ByteArray, + keyId: String? = null, + format: RSA.PublicKey.Format = RSA.PublicKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): EncryptionKey.EncryptionOnlyKey { + val parsedKey = cryptographyProvider.get(RSA.OAEP) + .publicKeyDecoder(digest) + .decodeFromByteArray(format, key) + + return EncryptionKey.EncryptionOnlyKey( + EncryptionKey.Identifier(this, keyId), + parsedKey, + ) +} + +/** + * Decodes an RSA-OAEP private key from a [ByteArray] for use with this algorithm. + * + * The returned [EncryptionKey] is an [EncryptionKey.DecryptionOnlyKey] that can decrypt tokens + * but cannot encrypt them. + * + * @param key the private key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [RSA.PrivateKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return an [EncryptionKey] wrapping the decoded [RSA.OAEP.PrivateKey]. + */ +public suspend fun OAEPBased.parsePrivateKey( + key: ByteArray, + keyId: String? = null, + format: RSA.PrivateKey.Format = RSA.PrivateKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): EncryptionKey.DecryptionOnlyKey { + val parsedKey = cryptographyProvider.get(RSA.OAEP) + .privateKeyDecoder(digest) + .decodeFromByteArray(format, key) + + return EncryptionKey.DecryptionOnlyKey( + EncryptionKey.Identifier(this, keyId), + parsedKey, + ) +} + +/** + * Decodes an RSA-OAEP key pair from separate public and private key [ByteArray]s. + * + * The returned [EncryptionKey] is an [EncryptionKey.EncryptionKeyPair] containing both keys, + * usable for encryption and decryption. + * + * @param publicKey the public key material to decode. + * @param privateKey the private key material to decode. + * @param keyId optional key ID to associate with the decoded key pair. Defaults to `null`. + * @param publicKeyFormat the format in which [publicKey] is encoded. Defaults to [RSA.PublicKey.Format.PEM]. + * @param privateKeyFormat the format in which [privateKey] is encoded. Defaults to [RSA.PrivateKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return an [EncryptionKey] wrapping the decoded [RSA.OAEP] key pair. + */ +public suspend fun OAEPBased.parseKeyPair( + publicKey: ByteArray, + privateKey: ByteArray, + keyId: String? = null, + publicKeyFormat: RSA.PublicKey.Format = RSA.PublicKey.Format.PEM, + privateKeyFormat: RSA.PrivateKey.Format = RSA.PrivateKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): EncryptionKey.EncryptionKeyPair { + val parsedPublicKey = cryptographyProvider.get(RSA.OAEP) + .publicKeyDecoder(digest) + .decodeFromByteArray(publicKeyFormat, publicKey) + + val parsedPrivateKey = cryptographyProvider.get(RSA.OAEP) + .privateKeyDecoder(digest) + .decodeFromByteArray(privateKeyFormat, privateKey) + + return EncryptionKey.EncryptionKeyPair( + EncryptionKey.Identifier(this, keyId), + parsedPublicKey, + parsedPrivateKey, + ) +} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkBuilderExt.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkBuilderExt.kt index d9b82c9..608071a 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkBuilderExt.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkBuilderExt.kt @@ -16,10 +16,15 @@ import co.touchlab.kjwt.model.jwk.Jwk * * @param algorithm The HMAC-based signing algorithm (HS256, HS384, or HS512). * @param jwk The Oct JWK containing the raw symmetric key material. + * @param keyId Optional key ID override; when set, it is embedded in the token header's `kid` field. + * Defaults to the JWK's own `kid` field. * @return The signed [JwtInstance.Jws] token. */ -public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.HashBased, jwk: Jwk.Oct): JwtInstance.Jws = - signWith(algorithm, jwk.toHmacKey(algorithm.digest)) +public suspend fun JwtBuilder.signWith( + algorithm: SigningAlgorithm.HashBased, + jwk: Jwk.Oct, + keyId: String? = jwk.kid, +): JwtInstance.Jws = signWith(algorithm, jwk.toHmacKey(algorithm.digest), keyId) // --------------------------------------------------------------------------- // signWith — RSA PKCS1 (RS*) @@ -30,10 +35,15 @@ public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.HashBased, jw * * @param algorithm The RSA PKCS#1-based signing algorithm (RS256, RS384, or RS512). * @param jwk The RSA JWK containing the private key parameters. + * @param keyId Optional key ID override; when set, it is embedded in the token header's `kid` field. + * Defaults to the JWK's own `kid` field. * @return The signed [JwtInstance.Jws] token. */ -public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.PKCS1Based, jwk: Jwk.Rsa): JwtInstance.Jws = - signWith(algorithm, jwk.toRsaPkcs1PrivateKey(algorithm.digest)) +public suspend fun JwtBuilder.signWith( + algorithm: SigningAlgorithm.PKCS1Based, + jwk: Jwk.Rsa, + keyId: String? = jwk.kid, +): JwtInstance.Jws = signWith(algorithm, jwk.toRsaPkcs1PrivateKey(algorithm.digest), keyId) // --------------------------------------------------------------------------- // signWith — RSA PSS (PS*) @@ -44,10 +54,15 @@ public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.PKCS1Based, j * * @param algorithm The RSA PSS-based signing algorithm (PS256, PS384, or PS512). * @param jwk The RSA JWK containing the private key parameters. + * @param keyId Optional key ID override; when set, it is embedded in the token header's `kid` field. + * Defaults to the JWK's own `kid` field. * @return The signed [JwtInstance.Jws] token. */ -public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.PSSBased, jwk: Jwk.Rsa): JwtInstance.Jws = - signWith(algorithm, jwk.toRsaPssPrivateKey(algorithm.digest)) +public suspend fun JwtBuilder.signWith( + algorithm: SigningAlgorithm.PSSBased, + jwk: Jwk.Rsa, + keyId: String? = jwk.kid, +): JwtInstance.Jws = signWith(algorithm, jwk.toRsaPssPrivateKey(algorithm.digest), keyId) // --------------------------------------------------------------------------- // signWith — ECDSA (ES*) @@ -58,10 +73,15 @@ public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.PSSBased, jwk * * @param algorithm The ECDSA-based signing algorithm (ES256, ES384, or ES512). * @param jwk The EC JWK containing the private key parameter `d`. + * @param keyId Optional key ID override; when set, it is embedded in the token header's `kid` field. + * Defaults to the JWK's own `kid` field. * @return The signed [JwtInstance.Jws] token. */ -public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.ECDSABased, jwk: Jwk.Ec): JwtInstance.Jws = - signWith(algorithm, jwk.toEcdsaPrivateKey()) +public suspend fun JwtBuilder.signWith( + algorithm: SigningAlgorithm.ECDSABased, + jwk: Jwk.Ec, + keyId: String? = jwk.kid, +): JwtInstance.Jws = signWith(algorithm, jwk.toEcdsaPrivateKey(), keyId) // --------------------------------------------------------------------------- // encryptWith — RSA-OAEP / RSA-OAEP-256 @@ -73,6 +93,8 @@ public suspend fun JwtBuilder.signWith(algorithm: SigningAlgorithm.ECDSABased, j * @param jwk The RSA JWK containing the public key parameters `n` and `e`. * @param keyAlgorithm The OAEP-based key encryption algorithm (RSA-OAEP or RSA-OAEP-256). * @param contentAlgorithm The content encryption algorithm to use for the JWE payload. + * @param keyId Optional key ID override; when set, it is embedded in the token header's `kid` field. + * Defaults to the JWK's own `kid` field. * @return The encrypted [JwtInstance.Jwe] token. */ @OptIn(dev.whyoleg.cryptography.DelicateCryptographyApi::class) @@ -80,5 +102,6 @@ public suspend fun JwtBuilder.encryptWith( jwk: Jwk.Rsa, keyAlgorithm: EncryptionAlgorithm.OAEPBased, contentAlgorithm: EncryptionContentAlgorithm, + keyId: String? = jwk.kid, ): JwtInstance.Jwe = - encryptWith(jwk.toRsaOaepPublicKey(keyAlgorithm.digest), keyAlgorithm, contentAlgorithm) + encryptWith(jwk.toRsaOaepPublicKey(keyAlgorithm.digest), keyAlgorithm, contentAlgorithm, keyId) diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkParserExt.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkParserExt.kt index 3b8c442..0bc4b2f 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkParserExt.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwkParserExt.kt @@ -18,15 +18,21 @@ import dev.whyoleg.cryptography.algorithms.SHA512 * * @param algorithm The HMAC-based signing algorithm (HS256, HS384, or HS512). * @param jwk The Oct JWK containing the raw symmetric key material. + * @param keyId Optional key ID override; when set, the parser will only use this key if the token's + * `kid` header matches. Defaults to the JWK's own `kid` field. * @return This builder, configured with the HMAC verification key. */ -public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.HashBased, jwk: Jwk.Oct): JwtParserBuilder { +public suspend fun JwtParserBuilder.verifyWith( + algorithm: SigningAlgorithm.HashBased, + jwk: Jwk.Oct, + keyId: String? = jwk.kid, +): JwtParserBuilder { val digest = when (algorithm) { SigningAlgorithm.HS256 -> SHA256 SigningAlgorithm.HS384 -> SHA384 SigningAlgorithm.HS512 -> SHA512 } - return verifyWith(algorithm, jwk.toHmacKey(digest)) + return verifyWith(algorithm, jwk.toHmacKey(digest), keyId) } // --------------------------------------------------------------------------- @@ -38,15 +44,21 @@ public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.HashB * * @param algorithm The RSA PKCS#1-based signing algorithm (RS256, RS384, or RS512). * @param jwk The RSA JWK containing the public key parameters `n` and `e`. + * @param keyId Optional key ID override; when set, the parser will only use this key if the token's + * `kid` header matches. Defaults to the JWK's own `kid` field. * @return This builder, configured with the RSA PKCS#1 verification key. */ -public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.PKCS1Based, jwk: Jwk.Rsa): JwtParserBuilder { +public suspend fun JwtParserBuilder.verifyWith( + algorithm: SigningAlgorithm.PKCS1Based, + jwk: Jwk.Rsa, + keyId: String? = jwk.kid, +): JwtParserBuilder { val digest = when (algorithm) { SigningAlgorithm.RS256 -> SHA256 SigningAlgorithm.RS384 -> SHA384 SigningAlgorithm.RS512 -> SHA512 } - return verifyWith(algorithm, jwk.toRsaPkcs1PublicKey(digest)) + return verifyWith(algorithm, jwk.toRsaPkcs1PublicKey(digest), keyId) } // --------------------------------------------------------------------------- @@ -58,15 +70,21 @@ public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.PKCS1 * * @param algorithm The RSA PSS-based signing algorithm (PS256, PS384, or PS512). * @param jwk The RSA JWK containing the public key parameters `n` and `e`. + * @param keyId Optional key ID override; when set, the parser will only use this key if the token's + * `kid` header matches. Defaults to the JWK's own `kid` field. * @return This builder, configured with the RSA PSS verification key. */ -public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.PSSBased, jwk: Jwk.Rsa): JwtParserBuilder { +public suspend fun JwtParserBuilder.verifyWith( + algorithm: SigningAlgorithm.PSSBased, + jwk: Jwk.Rsa, + keyId: String? = jwk.kid, +): JwtParserBuilder { val digest = when (algorithm) { SigningAlgorithm.PS256 -> SHA256 SigningAlgorithm.PS384 -> SHA384 SigningAlgorithm.PS512 -> SHA512 } - return verifyWith(algorithm, jwk.toRsaPssPublicKey(digest)) + return verifyWith(algorithm, jwk.toRsaPssPublicKey(digest), keyId) } // --------------------------------------------------------------------------- @@ -78,10 +96,15 @@ public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.PSSBa * * @param algorithm The ECDSA-based signing algorithm (ES256, ES384, or ES512). * @param jwk The EC JWK containing the public key parameters `crv`, `x`, and `y`. + * @param keyId Optional key ID override; when set, the parser will only use this key if the token's + * `kid` header matches. Defaults to the JWK's own `kid` field. * @return This builder, configured with the ECDSA verification key. */ -public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.ECDSABased, jwk: Jwk.Ec): JwtParserBuilder = - verifyWith(algorithm, jwk.toEcdsaPublicKey()) +public suspend fun JwtParserBuilder.verifyWith( + algorithm: SigningAlgorithm.ECDSABased, + jwk: Jwk.Ec, + keyId: String? = jwk.kid, +): JwtParserBuilder = verifyWith(algorithm, jwk.toEcdsaPublicKey(), keyId) // --------------------------------------------------------------------------- // decryptWith — RSA-OAEP / RSA-OAEP-256 @@ -92,16 +115,19 @@ public suspend fun JwtParserBuilder.verifyWith(algorithm: SigningAlgorithm.ECDSA * * @param algorithm The OAEP-based key encryption algorithm (RSA-OAEP or RSA-OAEP-256). * @param jwk The RSA JWK containing the private key parameters, including `d` and the CRT parameters. + * @param keyId Optional key ID override; when set, the parser will only use this key if the token's + * `kid` header matches. Defaults to the JWK's own `kid` field. * @return This builder, configured with the RSA OAEP decryption key. */ @OptIn(dev.whyoleg.cryptography.DelicateCryptographyApi::class) public suspend fun JwtParserBuilder.decryptWith( algorithm: EncryptionAlgorithm.OAEPBased, jwk: Jwk.Rsa, + keyId: String? = jwk.kid, ): JwtParserBuilder { val digest = when (algorithm) { EncryptionAlgorithm.RsaOaep -> SHA1 EncryptionAlgorithm.RsaOaep256 -> SHA256 } - return decryptWith(algorithm, jwk.toRsaOaepPrivateKey(digest)) + return decryptWith(algorithm, jwk.toRsaOaepPrivateKey(digest), keyId) } diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwtBuilderExt.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwtBuilderExt.kt index 621a037..04286cd 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwtBuilderExt.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/JwtBuilderExt.kt @@ -21,12 +21,14 @@ import dev.whyoleg.cryptography.algorithms.SHA512 * @param algorithm the HMAC-based signing algorithm (HS256, HS384, or HS512). * @param key the HMAC key material encoded as a String. * @param keyFormat the format in which [key] is encoded. + * @param keyId optional key ID to embed in the token header's `kid` field. Defaults to `null`. * @return the signed [JwtInstance.Jws] token. */ public suspend fun JwtBuilder.signWith( algorithm: SigningAlgorithm.HashBased, key: String, keyFormat: HMAC.Key.Format, + keyId: String? = null, ): JwtInstance.Jws { val parsedKey = CryptographyProvider.Default.get(HMAC) .keyDecoder( @@ -38,7 +40,7 @@ public suspend fun JwtBuilder.signWith( ) .decodeFromByteArray(keyFormat, key.encodeToByteArray()) - return signWith(algorithm, parsedKey) + return signWith(algorithm, parsedKey, keyId) } /** @@ -47,12 +49,14 @@ public suspend fun JwtBuilder.signWith( * @param algorithm the RSA PKCS#1-based signing algorithm (RS256, RS384, or RS512). * @param key the RSA private key material encoded as a String. * @param keyFormat the format in which [key] is encoded. + * @param keyId optional key ID to embed in the token header's `kid` field. Defaults to `null`. * @return the signed [JwtInstance.Jws] token. */ public suspend fun JwtBuilder.signWith( algorithm: SigningAlgorithm.PKCS1Based, key: String, keyFormat: RSA.PrivateKey.Format, + keyId: String? = null, ): JwtInstance.Jws { val parsedKey = CryptographyProvider.Default.get(RSA.PKCS1) .privateKeyDecoder( @@ -64,7 +68,7 @@ public suspend fun JwtBuilder.signWith( ) .decodeFromByteArray(keyFormat, key.encodeToByteArray()) - return signWith(algorithm, parsedKey) + return signWith(algorithm, parsedKey, keyId) } /** @@ -73,12 +77,14 @@ public suspend fun JwtBuilder.signWith( * @param algorithm the RSA PSS-based signing algorithm (PS256, PS384, or PS512). * @param key the RSA private key material encoded as a String. * @param keyFormat the format in which [key] is encoded. + * @param keyId optional key ID to embed in the token header's `kid` field. Defaults to `null`. * @return the signed [JwtInstance.Jws] token. */ public suspend fun JwtBuilder.signWith( algorithm: SigningAlgorithm.PSSBased, key: String, keyFormat: RSA.PrivateKey.Format, + keyId: String? = null, ): JwtInstance.Jws { val parsedKey = CryptographyProvider.Default.get(RSA.PSS) .privateKeyDecoder( @@ -90,7 +96,7 @@ public suspend fun JwtBuilder.signWith( ) .decodeFromByteArray(keyFormat, key.encodeToByteArray()) - return signWith(algorithm, parsedKey) + return signWith(algorithm, parsedKey, keyId) } /** @@ -99,24 +105,20 @@ public suspend fun JwtBuilder.signWith( * @param algorithm the ECDSA-based signing algorithm (ES256, ES384, or ES512). * @param key the EC private key material encoded as a String. * @param keyFormat the format in which [key] is encoded. + * @param keyId optional key ID to embed in the token header's `kid` field. Defaults to `null`. * @return the signed [JwtInstance.Jws] token. */ public suspend fun JwtBuilder.signWith( algorithm: SigningAlgorithm.ECDSABased, key: String, keyFormat: EC.PrivateKey.Format, + keyId: String? = null, ): JwtInstance.Jws { val parsedKey = CryptographyProvider.Default.get(ECDSA) - .privateKeyDecoder( - when (algorithm) { - SigningAlgorithm.ES256 -> EC.Curve.P256 - SigningAlgorithm.ES384 -> EC.Curve.P384 - SigningAlgorithm.ES512 -> EC.Curve.P521 - } - ) + .privateKeyDecoder(algorithm.curve) .decodeFromByteArray(keyFormat, key.encodeToByteArray()) - return signWith(algorithm, parsedKey) + return signWith(algorithm, parsedKey, keyId) } /** @@ -125,13 +127,15 @@ public suspend fun JwtBuilder.signWith( * @param key the raw symmetric key bytes used for direct encryption. * @param keyAlgorithm the direct key encryption algorithm ([EncryptionAlgorithm.Dir]). * @param contentAlgorithm the content encryption algorithm to apply to the JWT payload. + * @param keyId optional key ID to embed in the token header's `kid` field. Defaults to `null`. * @return the encrypted [JwtInstance.Jwe] token. */ public suspend fun JwtBuilder.encryptWith( key: ByteArray, keyAlgorithm: EncryptionAlgorithm.Dir, contentAlgorithm: EncryptionContentAlgorithm, -): JwtInstance.Jwe = encryptWith(SimpleKey(key), keyAlgorithm, contentAlgorithm) + keyId: String? = null, +): JwtInstance.Jwe = encryptWith(SimpleKey(key), keyAlgorithm, contentAlgorithm, keyId) /** * Encrypts the JWT using the direct key algorithm (`dir`) with a key supplied as a UTF-8 String. @@ -141,10 +145,12 @@ public suspend fun JwtBuilder.encryptWith( * @param key the symmetric key as a UTF-8 string. * @param keyAlgorithm the direct key encryption algorithm ([EncryptionAlgorithm.Dir]). * @param contentAlgorithm the content encryption algorithm to apply to the JWT payload. + * @param keyId optional key ID to embed in the token header's `kid` field. Defaults to `null`. * @return the encrypted [JwtInstance.Jwe] token. */ public suspend fun JwtBuilder.encryptWith( key: String, keyAlgorithm: EncryptionAlgorithm.Dir, contentAlgorithm: EncryptionContentAlgorithm, -): JwtInstance.Jwe = encryptWith(key.encodeToByteArray(), keyAlgorithm, contentAlgorithm) + keyId: String? = null, +): JwtInstance.Jwe = encryptWith(key.encodeToByteArray(), keyAlgorithm, contentAlgorithm, keyId) diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/SigningAlgorithmsExt.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/SigningAlgorithmsExt.kt new file mode 100644 index 0000000..1fabfb8 --- /dev/null +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/ext/SigningAlgorithmsExt.kt @@ -0,0 +1,424 @@ +package co.touchlab.kjwt.ext + +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm.ECDSABased +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm.HashBased +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm.PKCS1Based +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm.PSSBased +import co.touchlab.kjwt.model.registry.SigningKey +import dev.whyoleg.cryptography.BinarySize +import dev.whyoleg.cryptography.BinarySize.Companion.bits +import dev.whyoleg.cryptography.CryptographyProvider +import dev.whyoleg.cryptography.algorithms.EC +import dev.whyoleg.cryptography.algorithms.ECDSA +import dev.whyoleg.cryptography.algorithms.HMAC +import dev.whyoleg.cryptography.algorithms.RSA +import dev.whyoleg.cryptography.bigint.BigInt +import dev.whyoleg.cryptography.bigint.toBigInt + +/** + * Generates a new random HMAC key for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] usable for both signing and + * verification since HMAC uses a single symmetric key. + * + * @param keyId optional key ID to associate with the generated key. Defaults to `null`. + * @param cryptographyProvider the provider used to perform key generation. + * @return a [SigningKey] wrapping the generated [HMAC.Key]. + */ +public suspend fun HashBased.newKey( + keyId: String? = null, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val macKey = cryptographyProvider.get(HMAC) + .keyGenerator(digest) + .generateKey() + + return SigningKey.SigningKeyPair(identifier(keyId), macKey, macKey) +} + +/** + * Decodes an existing HMAC key from a [ByteArray] for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] usable for both signing and + * verification since HMAC uses a single symmetric key. + * + * @param key the raw key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [HMAC.Key.Format.RAW]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [HMAC.Key]. + */ +public suspend fun HashBased.parse( + key: ByteArray, + keyId: String? = null, + format: HMAC.Key.Format = HMAC.Key.Format.RAW, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val macKey = cryptographyProvider.get(HMAC) + .keyDecoder(digest) + .decodeFromByteArray(format, key) + + return SigningKey.SigningKeyPair(identifier(keyId), macKey, macKey) +} + +/** + * Generates a new RSA PKCS#1 v1.5 key pair for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] containing both the public and + * private key, usable for signing and verification. + * + * @param keyId optional key ID to associate with the generated key pair. Defaults to `null`. + * @param keySize the RSA modulus size in bits. Defaults to 4096 bits. + * @param publicExponent the RSA public exponent. Defaults to 65537. + * @param cryptographyProvider the provider used to perform key generation. + * @return a [SigningKey] wrapping the generated [RSA.PKCS1] key pair. + */ +public suspend fun PKCS1Based.newKey( + keyId: String? = null, + keySize: BinarySize = 4096.bits, + publicExponent: BigInt = 65537.toBigInt(), + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val rsaKeyPair = cryptographyProvider.get(RSA.PKCS1) + .keyPairGenerator(keySize, digest, publicExponent) + .generateKey() + + return SigningKey.SigningKeyPair( + identifier(keyId), + rsaKeyPair.publicKey, + rsaKeyPair.privateKey + ) +} + +/** + * Decodes an RSA PKCS#1 v1.5 public key from a [ByteArray] for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.VerifyOnlyKey] that can verify signatures but + * cannot produce them. + * + * @param key the public key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [RSA.PublicKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [RSA.PKCS1.PublicKey]. + */ +public suspend fun PKCS1Based.parsePublicKey( + key: ByteArray, + keyId: String? = null, + format: RSA.PublicKey.Format = RSA.PublicKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.VerifyOnlyKey { + val parsedKey = cryptographyProvider.get(RSA.PKCS1) + .publicKeyDecoder(digest) + .decodeFromByteArray(format, key) + + return SigningKey.VerifyOnlyKey( + identifier(keyId), + parsedKey, + ) +} + +/** + * Decodes an RSA PKCS#1 v1.5 private key from a [ByteArray] for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningOnlyKey] that can produce signatures but + * cannot verify them. + * + * @param key the private key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [RSA.PrivateKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [RSA.PKCS1.PrivateKey]. + */ +public suspend fun PKCS1Based.parsePrivateKey( + key: ByteArray, + keyId: String? = null, + format: RSA.PrivateKey.Format = RSA.PrivateKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningOnlyKey { + val parsedKey = cryptographyProvider.get(RSA.PKCS1) + .privateKeyDecoder(digest) + .decodeFromByteArray(format, key) + + return SigningKey.SigningOnlyKey( + identifier(keyId), + parsedKey, + ) +} + +/** + * Decodes an RSA PKCS#1 v1.5 key pair from separate public and private key [ByteArray]s. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] containing both keys, usable for + * signing and verification. + * + * @param publicKey the public key material to decode. + * @param privateKey the private key material to decode. + * @param keyId optional key ID to associate with the decoded key pair. Defaults to `null`. + * @param publicKeyFormat the format in which [publicKey] is encoded. Defaults to [RSA.PublicKey.Format.PEM]. + * @param privateKeyFormat the format in which [privateKey] is encoded. Defaults to [RSA.PrivateKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [RSA.PKCS1] key pair. + */ +public suspend fun PKCS1Based.parseKeyPair( + publicKey: ByteArray, + privateKey: ByteArray, + keyId: String? = null, + publicKeyFormat: RSA.PublicKey.Format = RSA.PublicKey.Format.PEM, + privateKeyFormat: RSA.PrivateKey.Format = RSA.PrivateKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val parsedPublicKey = cryptographyProvider.get(RSA.PKCS1) + .publicKeyDecoder(digest) + .decodeFromByteArray(publicKeyFormat, publicKey) + + val parsedPrivateKey = cryptographyProvider.get(RSA.PKCS1) + .privateKeyDecoder(digest) + .decodeFromByteArray(privateKeyFormat, privateKey) + + return SigningKey.SigningKeyPair( + identifier(keyId), + parsedPublicKey, + parsedPrivateKey, + ) +} + +/** + * Generates a new RSA PSS key pair for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] containing both the public and + * private key, usable for signing and verification. + * + * @param keyId optional key ID to associate with the generated key pair. Defaults to `null`. + * @param keySize the RSA modulus size in bits. Defaults to 4096 bits. + * @param publicExponent the RSA public exponent. Defaults to 65537. + * @param cryptographyProvider the provider used to perform key generation. + * @return a [SigningKey] wrapping the generated [RSA.PSS] key pair. + */ +public suspend fun PSSBased.newKey( + keyId: String? = null, + keySize: BinarySize = 4096.bits, + publicExponent: BigInt = 65537.toBigInt(), + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val rsaKeyPair = cryptographyProvider.get(RSA.PSS) + .keyPairGenerator(keySize, digest, publicExponent) + .generateKey() + + return SigningKey.SigningKeyPair( + identifier(keyId), + rsaKeyPair.publicKey, + rsaKeyPair.privateKey + ) +} + +/** + * Decodes an RSA PSS public key from a [ByteArray] for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.VerifyOnlyKey] that can verify signatures but + * cannot produce them. + * + * @param key the public key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [RSA.PublicKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [RSA.PSS.PublicKey]. + */ +public suspend fun PSSBased.parsePublicKey( + key: ByteArray, + keyId: String? = null, + format: RSA.PublicKey.Format = RSA.PublicKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.VerifyOnlyKey { + val parsedKey = cryptographyProvider.get(RSA.PSS) + .publicKeyDecoder(digest) + .decodeFromByteArray(format, key) + + return SigningKey.VerifyOnlyKey( + identifier(keyId), + parsedKey, + ) +} + +/** + * Decodes an RSA PSS private key from a [ByteArray] for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningOnlyKey] that can produce signatures but + * cannot verify them. + * + * @param key the private key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [RSA.PrivateKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [RSA.PSS.PrivateKey]. + */ +public suspend fun PSSBased.parsePrivateKey( + key: ByteArray, + keyId: String? = null, + format: RSA.PrivateKey.Format = RSA.PrivateKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningOnlyKey { + val parsedKey = cryptographyProvider.get(RSA.PSS) + .privateKeyDecoder(digest) + .decodeFromByteArray(format, key) + + return SigningKey.SigningOnlyKey( + identifier(keyId), + parsedKey, + ) +} + +/** + * Decodes an RSA PSS key pair from separate public and private key [ByteArray]s. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] containing both keys, usable for + * signing and verification. + * + * @param publicKey the public key material to decode. + * @param privateKey the private key material to decode. + * @param keyId optional key ID to associate with the decoded key pair. Defaults to `null`. + * @param publicKeyFormat the format in which [publicKey] is encoded. Defaults to [RSA.PublicKey.Format.PEM]. + * @param privateKeyFormat the format in which [privateKey] is encoded. Defaults to [RSA.PrivateKey.Format.PEM]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [RSA.PSS] key pair. + */ +public suspend fun PSSBased.parseKeyPair( + publicKey: ByteArray, + privateKey: ByteArray, + keyId: String? = null, + publicKeyFormat: RSA.PublicKey.Format = RSA.PublicKey.Format.PEM, + privateKeyFormat: RSA.PrivateKey.Format = RSA.PrivateKey.Format.PEM, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val parsedPublicKey = cryptographyProvider.get(RSA.PSS) + .publicKeyDecoder(digest) + .decodeFromByteArray(publicKeyFormat, publicKey) + + val parsedPrivateKey = cryptographyProvider.get(RSA.PSS) + .privateKeyDecoder(digest) + .decodeFromByteArray(privateKeyFormat, privateKey) + + return SigningKey.SigningKeyPair( + identifier(keyId), + parsedPublicKey, + parsedPrivateKey, + ) +} + +/** + * Generates a new ECDSA key pair for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] containing both the public and + * private key, usable for signing and verification. + * + * @param keyId optional key ID to associate with the generated key pair. Defaults to `null`. + * @param cryptographyProvider the provider used to perform key generation. + * @return a [SigningKey] wrapping the generated [ECDSA] key pair. + */ +public suspend fun ECDSABased.newKey( + keyId: String? = null, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val rsaKeyPair = cryptographyProvider.get(ECDSA) + .keyPairGenerator(curve) + .generateKey() + + return SigningKey.SigningKeyPair( + identifier(keyId), + rsaKeyPair.publicKey, + rsaKeyPair.privateKey + ) +} + +/** + * Decodes an ECDSA public key from a [ByteArray] for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.VerifyOnlyKey] that can verify signatures but + * cannot produce them. + * + * @param key the public key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [EC.PublicKey.Format.RAW]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [ECDSA.PublicKey]. + */ +public suspend fun ECDSABased.parsePublicKey( + key: ByteArray, + keyId: String? = null, + format: EC.PublicKey.Format = EC.PublicKey.Format.RAW, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.VerifyOnlyKey { + val parsedKey = cryptographyProvider.get(ECDSA) + .publicKeyDecoder(curve) + .decodeFromByteArray(format, key) + + return SigningKey.VerifyOnlyKey( + identifier(keyId), + parsedKey, + ) +} + +/** + * Decodes an ECDSA private key from a [ByteArray] for use with this algorithm. + * + * The returned [SigningKey] is a [SigningKey.SigningOnlyKey] that can produce signatures but + * cannot verify them. + * + * @param key the private key material to decode. + * @param keyId optional key ID to associate with the decoded key. Defaults to `null`. + * @param format the format in which [key] is encoded. Defaults to [EC.PrivateKey.Format.RAW]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [ECDSA.PrivateKey]. + */ +public suspend fun ECDSABased.parsePrivateKey( + key: ByteArray, + keyId: String? = null, + format: EC.PrivateKey.Format = EC.PrivateKey.Format.RAW, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningOnlyKey { + val parsedKey = cryptographyProvider.get(ECDSA) + .privateKeyDecoder(curve) + .decodeFromByteArray(format, key) + + return SigningKey.SigningOnlyKey( + identifier(keyId), + parsedKey, + ) +} + +/** + * Decodes an ECDSA key pair from separate public and private key [ByteArray]s. + * + * The returned [SigningKey] is a [SigningKey.SigningKeyPair] containing both keys, usable for + * signing and verification. + * + * @param publicKey the public key material to decode. + * @param privateKey the private key material to decode. + * @param keyId optional key ID to associate with the decoded key pair. Defaults to `null`. + * @param publicKeyFormat the format in which [publicKey] is encoded. Defaults to [EC.PublicKey.Format.RAW]. + * @param privateKeyFormat the format in which [privateKey] is encoded. Defaults to [EC.PrivateKey.Format.RAW]. + * @param cryptographyProvider the provider used to perform key decoding. + * @return a [SigningKey] wrapping the decoded [ECDSA] key pair. + */ +public suspend fun ECDSABased.parseKeyPair( + publicKey: ByteArray, + privateKey: ByteArray, + keyId: String? = null, + publicKeyFormat: EC.PublicKey.Format = EC.PublicKey.Format.RAW, + privateKeyFormat: EC.PrivateKey.Format = EC.PrivateKey.Format.RAW, + cryptographyProvider: CryptographyProvider = CryptographyProvider.Default, +): SigningKey.SigningKeyPair { + val parsedPublicKey = cryptographyProvider.get(ECDSA) + .publicKeyDecoder(curve) + .decodeFromByteArray(publicKeyFormat, publicKey) + + val parsedPrivateKey = cryptographyProvider.get(ECDSA) + .privateKeyDecoder(curve) + .decodeFromByteArray(privateKeyFormat, privateKey) + + return SigningKey.SigningKeyPair( + identifier(keyId), + parsedPublicKey, + parsedPrivateKey, + ) +} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/JwtHeader.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/JwtHeader.kt index ce30cec..a6fcae8 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/JwtHeader.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/JwtHeader.kt @@ -111,13 +111,6 @@ public class JwtHeader internal constructor( extra(CTY, value) } - /** The key ID (`kid`) header parameter identifying the key used to sign or encrypt the token. */ - public var keyId: String? = null - set(value) { - field = value - extra(KID, value) - } - /** * Sets an extra header parameter using a pre-built [JsonElement], or removes it if [value] is `null`. * @@ -153,19 +146,22 @@ public class JwtHeader internal constructor( extra(name, kotlinx.serialization.serializer(), value) } - internal fun build(algorithm: SigningAlgorithm<*, *>) = JwtHeader( + internal fun build(algorithm: SigningAlgorithm<*, *>, keyId: String?) = JwtHeader( buildToJson { put(ALG, algorithm.id) + if (keyId != null) put(KID, keyId) } ) internal fun build( keyAlgorithm: EncryptionAlgorithm<*, *>, contentAlgorithm: EncryptionContentAlgorithm, + keyId: String?, ) = JwtHeader( buildToJson { put(ALG, keyAlgorithm.id) put(ENC, contentAlgorithm.id) + if (keyId != null) put(KID, keyId) } ) diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/EncryptionAlgorithm.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/EncryptionAlgorithm.kt index dc057dc..1b10bfa 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/EncryptionAlgorithm.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/EncryptionAlgorithm.kt @@ -5,19 +5,13 @@ package co.touchlab.kjwt.model.algorithm import co.touchlab.kjwt.cryptography.SimpleKey import co.touchlab.kjwt.serializers.EncryptionAlgorithmSerializer import dev.whyoleg.cryptography.CryptographyAlgorithmId -import dev.whyoleg.cryptography.CryptographyProvider import dev.whyoleg.cryptography.DelicateCryptographyApi -import dev.whyoleg.cryptography.algorithms.AES import dev.whyoleg.cryptography.algorithms.Digest -import dev.whyoleg.cryptography.algorithms.HMAC import dev.whyoleg.cryptography.algorithms.RSA import dev.whyoleg.cryptography.algorithms.SHA1 import dev.whyoleg.cryptography.algorithms.SHA256 -import dev.whyoleg.cryptography.algorithms.SHA384 -import dev.whyoleg.cryptography.algorithms.SHA512 import dev.whyoleg.cryptography.materials.key.Key import kotlinx.serialization.Serializable -import kotlin.random.Random @Serializable(EncryptionAlgorithmSerializer::class) public sealed class EncryptionAlgorithm( @@ -131,243 +125,3 @@ public sealed class EncryptionAlgorithm( } } } - -public sealed class EncryptionContentAlgorithm(public val id: String) { - /** AES-128 in GCM mode (`A128GCM`) content encryption algorithm. */ - public data object A128GCM : AesGCMBased("A128GCM") - - /** AES-192 in GCM mode (`A192GCM`) content encryption algorithm. */ - public data object A192GCM : AesGCMBased("A192GCM") - - /** AES-256 in GCM mode (`A256GCM`) content encryption algorithm. */ - public data object A256GCM : AesGCMBased("A256GCM") - - /** AES-128 CBC with HMAC-SHA-256 (`A128CBC-HS256`) content encryption algorithm. */ - public data object A128CbcHs256 : AesCBCBased("A128CBC-HS256") - - /** AES-192 CBC with HMAC-SHA-384 (`A192CBC-HS384`) content encryption algorithm. */ - public data object A192CbcHs384 : AesCBCBased("A192CBC-HS384") - - /** AES-256 CBC with HMAC-SHA-512 (`A256CBC-HS512`) content encryption algorithm. */ - public data object A256CbcHs512 : AesCBCBased("A256CBC-HS512") - - internal abstract suspend fun encrypt( - cek: ByteArray, - plaintext: ByteArray, - aad: ByteArray, - encryptedKey: ByteArray, - ): JweEncryptResult - - internal abstract suspend fun decrypt( - cek: ByteArray, - iv: ByteArray, - ciphertext: ByteArray, - tag: ByteArray, - aad: ByteArray, - ): ByteArray - - /** - * Base class for AES GCM content encryption algorithms (A128GCM, A192GCM, A256GCM). - * - * Uses AES in Galois/Counter Mode, which provides both confidentiality and integrity. - */ - public sealed class AesGCMBased(id: String) : EncryptionContentAlgorithm(id) { - public companion object { - private const val GCM_IV_SIZE = 12 - private const val GCM_TAG_SIZE = 16 - } - - override suspend fun encrypt( - cek: ByteArray, - plaintext: ByteArray, - aad: ByteArray, - encryptedKey: ByteArray - ): JweEncryptResult { - val aesKey = CryptographyProvider.Default.get(AES.GCM) - .keyDecoder() - .decodeFromByteArray(AES.Key.Format.RAW, cek) - - val cipher = aesKey.cipher() - val iv = Random.nextBytes(GCM_IV_SIZE) - - // encryptWithIv returns ciphertext || auth_tag - val combined = cipher.encryptWithIv(iv, plaintext, aad) - val ctLen = combined.size - GCM_TAG_SIZE - val ciphertext = combined.copyOfRange(0, ctLen) - val tag = combined.copyOfRange(ctLen, combined.size) - - return JweEncryptResult(encryptedKey, iv, ciphertext, tag) - } - - override suspend fun decrypt( - cek: ByteArray, - iv: ByteArray, - ciphertext: ByteArray, - tag: ByteArray, - aad: ByteArray - ): ByteArray { - val aesKey = CryptographyProvider.Default.get(AES.GCM) - .keyDecoder() - .decodeFromByteArray(AES.Key.Format.RAW, cek) - // Recombine ciphertext || tag before passing to the cipher - return aesKey.cipher().decryptWithIv(iv, ciphertext + tag, aad) - } - } - - /** - * Base class for AES CBC + HMAC content encryption algorithms (A128CBC-HS256, A192CBC-HS384, A256CBC-HS512). - * - * Uses AES in CBC mode combined with an HMAC tag for authenticated encryption per RFC 7516. - */ - public sealed class AesCBCBased(id: String) : EncryptionContentAlgorithm(id) { - public companion object { - private const val CBC_IV_SIZE = 16 - } - - override suspend fun encrypt( - cek: ByteArray, - plaintext: ByteArray, - aad: ByteArray, - encryptedKey: ByteArray - ): JweEncryptResult { - val half = cek.size / 2 - val macKey = cek.copyOfRange(0, half) - val encKey = cek.copyOfRange(half, cek.size) - - val iv = Random.nextBytes(CBC_IV_SIZE) - - val aesKey = CryptographyProvider.Default.get(AES.CBC) - .keyDecoder() - .decodeFromByteArray(AES.Key.Format.RAW, encKey) - val ciphertext = aesKey.cipher().encryptWithIv(iv, plaintext) - - val tag = computeCbcHmacTag(macKey, aad, iv, ciphertext) - - return JweEncryptResult(encryptedKey, iv, ciphertext, tag) - } - - override suspend fun decrypt( - cek: ByteArray, - iv: ByteArray, - ciphertext: ByteArray, - tag: ByteArray, - aad: ByteArray - ): ByteArray { - val half = cek.size / 2 - val macKey = cek.copyOfRange(0, half) - val encKey = cek.copyOfRange(half, cek.size) - - val expectedTag = computeCbcHmacTag(macKey, aad, iv, ciphertext) - require(expectedTag.contentEquals(tag)) { - "JWE authentication tag verification failed" - } - - val aesKey = CryptographyProvider.Default.get(AES.CBC) - .keyDecoder() - .decodeFromByteArray(AES.Key.Format.RAW, encKey) - return aesKey.cipher().decryptWithIv(iv, ciphertext) - } - - private suspend fun computeCbcHmacTag( - macKey: ByteArray, - aad: ByteArray, - iv: ByteArray, - ciphertext: ByteArray, - ): ByteArray { - // MAC input: AAD || IV || Ciphertext || AL (RFC 7516 §5.2.2.1) - val al = aad.size.toLong() * 8 - val alBytes = ByteArray(8) { i -> ((al shr (56 - i * 8)) and 0xFF).toByte() } - val macInput = aad + iv + ciphertext + alBytes - - val (hmacDigest, tagLen) = when (this) { - A128CbcHs256 -> Pair(SHA256, 16) - A192CbcHs384 -> Pair(SHA384, 24) - A256CbcHs512 -> Pair(SHA512, 32) - } - - val hmacKey = CryptographyProvider.Default.get(HMAC) - .keyDecoder(hmacDigest) - .decodeFromByteArray(HMAC.Key.Format.RAW, macKey) - val fullMac = hmacKey.signatureGenerator().generateSignature(macInput) - - // Per RFC 7516: truncate to the first T_LEN bytes - return fullMac.copyOfRange(0, tagLen) - } - } - - override fun toString(): String = id - - /** - * Generates a random Content Encryption Key (CEK) of the appropriate byte length for this algorithm. - * - * @return a freshly generated random CEK as a [ByteArray] - */ - internal fun generateCek(): ByteArray = - Random.nextBytes( - when (this) { - A128GCM -> 16 - A192GCM -> 24 - A256GCM -> 32 - A128CbcHs256 -> 32 // 16 mac + 16 enc - A192CbcHs384 -> 48 // 24 mac + 24 enc - A256CbcHs512 -> 64 // 32 mac + 32 enc - } - ) - - public companion object { - /** - * List of all supported [EncryptionContentAlgorithm] instances. - */ - internal val entries: List by lazy { - listOf( - A128GCM, - A192GCM, - A256GCM, - A128CbcHs256, - A192CbcHs384, - A256CbcHs512, - ) - } - - /** - * Returns the [EncryptionContentAlgorithm] whose [id] matches the given string. - * - * @param id the JWE content algorithm identifier to look up (e.g. `"A256GCM"`) - * @return the matching [EncryptionContentAlgorithm] instance - * @throws IllegalArgumentException if no algorithm with the given [id] is registered - */ - public fun fromId(id: String): EncryptionContentAlgorithm = - requireNotNull(entries.firstOrNull { it.id == id }) { - "Unknown JWE content algorithm: '$id'" - } - } -} - -internal data class JweEncryptResult( - val encryptedKey: ByteArray, - val iv: ByteArray, - val ciphertext: ByteArray, - val tag: ByteArray, -) { - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other == null || this::class != other::class) return false - - other as JweEncryptResult - - if (!encryptedKey.contentEquals(other.encryptedKey)) return false - if (!iv.contentEquals(other.iv)) return false - if (!ciphertext.contentEquals(other.ciphertext)) return false - if (!tag.contentEquals(other.tag)) return false - - return true - } - - override fun hashCode(): Int { - var result = encryptedKey.contentHashCode() - result = 31 * result + iv.contentHashCode() - result = 31 * result + ciphertext.contentHashCode() - result = 31 * result + tag.contentHashCode() - return result - } -} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/EncryptionContentAlgorithm.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/EncryptionContentAlgorithm.kt new file mode 100644 index 0000000..b370728 --- /dev/null +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/EncryptionContentAlgorithm.kt @@ -0,0 +1,223 @@ +@file:OptIn(DelicateCryptographyApi::class) + +package co.touchlab.kjwt.model.algorithm + +import dev.whyoleg.cryptography.CryptographyProvider +import dev.whyoleg.cryptography.DelicateCryptographyApi +import dev.whyoleg.cryptography.algorithms.AES +import dev.whyoleg.cryptography.algorithms.HMAC +import dev.whyoleg.cryptography.algorithms.SHA256 +import dev.whyoleg.cryptography.algorithms.SHA384 +import dev.whyoleg.cryptography.algorithms.SHA512 +import kotlin.random.Random + +public sealed class EncryptionContentAlgorithm(public val id: String) { + /** AES-128 in GCM mode (`A128GCM`) content encryption algorithm. */ + public data object A128GCM : AesGCMBased("A128GCM") + + /** AES-192 in GCM mode (`A192GCM`) content encryption algorithm. */ + public data object A192GCM : AesGCMBased("A192GCM") + + /** AES-256 in GCM mode (`A256GCM`) content encryption algorithm. */ + public data object A256GCM : AesGCMBased("A256GCM") + + /** AES-128 CBC with HMAC-SHA-256 (`A128CBC-HS256`) content encryption algorithm. */ + public data object A128CbcHs256 : AesCBCBased("A128CBC-HS256") + + /** AES-192 CBC with HMAC-SHA-384 (`A192CBC-HS384`) content encryption algorithm. */ + public data object A192CbcHs384 : AesCBCBased("A192CBC-HS384") + + /** AES-256 CBC with HMAC-SHA-512 (`A256CBC-HS512`) content encryption algorithm. */ + public data object A256CbcHs512 : AesCBCBased("A256CBC-HS512") + + internal abstract suspend fun encrypt( + cek: ByteArray, + plaintext: ByteArray, + aad: ByteArray, + encryptedKey: ByteArray, + ): JweEncryptResult + + internal abstract suspend fun decrypt( + cek: ByteArray, + iv: ByteArray, + ciphertext: ByteArray, + tag: ByteArray, + aad: ByteArray, + ): ByteArray + + /** + * Base class for AES GCM content encryption algorithms (A128GCM, A192GCM, A256GCM). + * + * Uses AES in Galois/Counter Mode, which provides both confidentiality and integrity. + */ + public sealed class AesGCMBased(id: String) : EncryptionContentAlgorithm(id) { + public companion object { + private const val GCM_IV_SIZE = 12 + private const val GCM_TAG_SIZE = 16 + } + + override suspend fun encrypt( + cek: ByteArray, + plaintext: ByteArray, + aad: ByteArray, + encryptedKey: ByteArray + ): JweEncryptResult { + val aesKey = CryptographyProvider.Companion.Default.get(AES.GCM) + .keyDecoder() + .decodeFromByteArray(AES.Key.Format.RAW, cek) + + val cipher = aesKey.cipher() + val iv = Random.Default.nextBytes(GCM_IV_SIZE) + + // encryptWithIv returns ciphertext || auth_tag + val combined = cipher.encryptWithIv(iv, plaintext, aad) + val ctLen = combined.size - GCM_TAG_SIZE + val ciphertext = combined.copyOfRange(0, ctLen) + val tag = combined.copyOfRange(ctLen, combined.size) + + return JweEncryptResult(encryptedKey, iv, ciphertext, tag) + } + + override suspend fun decrypt( + cek: ByteArray, + iv: ByteArray, + ciphertext: ByteArray, + tag: ByteArray, + aad: ByteArray + ): ByteArray { + val aesKey = CryptographyProvider.Companion.Default.get(AES.GCM) + .keyDecoder() + .decodeFromByteArray(AES.Key.Format.RAW, cek) + // Recombine ciphertext || tag before passing to the cipher + return aesKey.cipher().decryptWithIv(iv, ciphertext + tag, aad) + } + } + + /** + * Base class for AES CBC + HMAC content encryption algorithms (A128CBC-HS256, A192CBC-HS384, A256CBC-HS512). + * + * Uses AES in CBC mode combined with an HMAC tag for authenticated encryption per RFC 7516. + */ + public sealed class AesCBCBased(id: String) : EncryptionContentAlgorithm(id) { + public companion object { + private const val CBC_IV_SIZE = 16 + } + + override suspend fun encrypt( + cek: ByteArray, + plaintext: ByteArray, + aad: ByteArray, + encryptedKey: ByteArray + ): JweEncryptResult { + val half = cek.size / 2 + val macKey = cek.copyOfRange(0, half) + val encKey = cek.copyOfRange(half, cek.size) + + val iv = Random.Default.nextBytes(CBC_IV_SIZE) + + val aesKey = CryptographyProvider.Companion.Default.get(AES.CBC) + .keyDecoder() + .decodeFromByteArray(AES.Key.Format.RAW, encKey) + val ciphertext = aesKey.cipher().encryptWithIv(iv, plaintext) + + val tag = computeCbcHmacTag(macKey, aad, iv, ciphertext) + + return JweEncryptResult(encryptedKey, iv, ciphertext, tag) + } + + override suspend fun decrypt( + cek: ByteArray, + iv: ByteArray, + ciphertext: ByteArray, + tag: ByteArray, + aad: ByteArray + ): ByteArray { + val half = cek.size / 2 + val macKey = cek.copyOfRange(0, half) + val encKey = cek.copyOfRange(half, cek.size) + + val expectedTag = computeCbcHmacTag(macKey, aad, iv, ciphertext) + require(expectedTag.contentEquals(tag)) { + "JWE authentication tag verification failed" + } + + val aesKey = CryptographyProvider.Companion.Default.get(AES.CBC) + .keyDecoder() + .decodeFromByteArray(AES.Key.Format.RAW, encKey) + return aesKey.cipher().decryptWithIv(iv, ciphertext) + } + + private suspend fun computeCbcHmacTag( + macKey: ByteArray, + aad: ByteArray, + iv: ByteArray, + ciphertext: ByteArray, + ): ByteArray { + // MAC input: AAD || IV || Ciphertext || AL (RFC 7516 §5.2.2.1) + val al = aad.size.toLong() * 8 + val alBytes = ByteArray(8) { i -> ((al shr (56 - i * 8)) and 0xFF).toByte() } + val macInput = aad + iv + ciphertext + alBytes + + val (hmacDigest, tagLen) = when (this) { + A128CbcHs256 -> Pair(SHA256, 16) + A192CbcHs384 -> Pair(SHA384, 24) + A256CbcHs512 -> Pair(SHA512, 32) + } + + val hmacKey = CryptographyProvider.Companion.Default.get(HMAC.Companion) + .keyDecoder(hmacDigest) + .decodeFromByteArray(HMAC.Key.Format.RAW, macKey) + val fullMac = hmacKey.signatureGenerator().generateSignature(macInput) + + // Per RFC 7516: truncate to the first T_LEN bytes + return fullMac.copyOfRange(0, tagLen) + } + } + + override fun toString(): String = id + + /** + * Generates a random Content Encryption Key (CEK) of the appropriate byte length for this algorithm. + * + * @return a freshly generated random CEK as a [ByteArray] + */ + internal fun generateCek(): ByteArray = + Random.Default.nextBytes( + when (this) { + A128GCM -> 16 + A192GCM -> 24 + A256GCM -> 32 + A128CbcHs256 -> 32 // 16 mac + 16 enc + A192CbcHs384 -> 48 // 24 mac + 24 enc + A256CbcHs512 -> 64 // 32 mac + 32 enc + } + ) + + public companion object { + /** + * List of all supported [EncryptionContentAlgorithm] instances. + */ + internal val entries: List by lazy { + listOf( + A128GCM, + A192GCM, + A256GCM, + A128CbcHs256, + A192CbcHs384, + A256CbcHs512, + ) + } + + /** + * Returns the [EncryptionContentAlgorithm] whose [id] matches the given string. + * + * @param id the JWE content algorithm identifier to look up (e.g. `"A256GCM"`) + * @return the matching [EncryptionContentAlgorithm] instance + * @throws IllegalArgumentException if no algorithm with the given [id] is registered + */ + public fun fromId(id: String): EncryptionContentAlgorithm = + requireNotNull(entries.firstOrNull { it.id == id }) { + "Unknown JWE content algorithm: '$id'" + } + } +} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/JweEncryptResult.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/JweEncryptResult.kt new file mode 100644 index 0000000..83d37ef --- /dev/null +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/JweEncryptResult.kt @@ -0,0 +1,30 @@ +package co.touchlab.kjwt.model.algorithm + +internal data class JweEncryptResult( + val encryptedKey: ByteArray, + val iv: ByteArray, + val ciphertext: ByteArray, + val tag: ByteArray, +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as JweEncryptResult + + if (!encryptedKey.contentEquals(other.encryptedKey)) return false + if (!iv.contentEquals(other.iv)) return false + if (!ciphertext.contentEquals(other.ciphertext)) return false + if (!tag.contentEquals(other.tag)) return false + + return true + } + + override fun hashCode(): Int { + var result = encryptedKey.contentHashCode() + result = 31 * result + iv.contentHashCode() + result = 31 * result + ciphertext.contentHashCode() + result = 31 * result + tag.contentHashCode() + return result + } +} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/SigningAlgorithm.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/SigningAlgorithm.kt index 6b2061f..f9117dd 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/SigningAlgorithm.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/algorithm/SigningAlgorithm.kt @@ -1,9 +1,11 @@ package co.touchlab.kjwt.model.algorithm import co.touchlab.kjwt.cryptography.SimpleKey +import co.touchlab.kjwt.model.registry.SigningKey import co.touchlab.kjwt.serializers.SigningAlgorithmSerializer import dev.whyoleg.cryptography.CryptographyAlgorithmId import dev.whyoleg.cryptography.algorithms.Digest +import dev.whyoleg.cryptography.algorithms.EC import dev.whyoleg.cryptography.algorithms.ECDSA import dev.whyoleg.cryptography.algorithms.HMAC import dev.whyoleg.cryptography.algorithms.RSA @@ -20,6 +22,8 @@ public sealed class SigningAlgorithm( internal abstract suspend fun sign(key: PrivateKey, signingInput: ByteArray): ByteArray internal abstract suspend fun verify(key: PublicKey, signingInput: ByteArray, signature: ByteArray): Boolean + internal fun identifier(keyId: String?) = SigningKey.Identifier(this, keyId) + /** HMAC with SHA-256 (`HS256`) signing algorithm using a symmetric [HMAC.Key]. */ public data object HS256 : HashBased("HS256") @@ -143,6 +147,13 @@ public sealed class SigningAlgorithm( ES512 -> SHA512 } + public val curve: EC.Curve + get() = when (this) { + ES256 -> EC.Curve.P256 + ES384 -> EC.Curve.P384 + ES512 -> EC.Curve.P521 + } + override suspend fun sign(key: ECDSA.PrivateKey, signingInput: ByteArray): ByteArray = key.signatureGenerator(digest, ECDSA.SignatureFormat.RAW).generateSignature(signingInput) diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/EncryptionKey.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/EncryptionKey.kt new file mode 100644 index 0000000..52812d0 --- /dev/null +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/EncryptionKey.kt @@ -0,0 +1,210 @@ +package co.touchlab.kjwt.model.registry + +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm +import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm +import co.touchlab.kjwt.model.algorithm.JweEncryptResult +import dev.whyoleg.cryptography.materials.key.Key + +/** + * Represents a cryptographic key (or key pair) used for JWE encryption and/or decryption. + * + * Instances are identified by a ([EncryptionAlgorithm], optional key ID) pair captured in + * [identifier]. Depending on which key material is available, an [EncryptionKey] may be: + * - [EncryptionOnlyKey] — holds only the public key; used by [co.touchlab.kjwt.builder.JwtBuilder] + * to encrypt tokens. + * - [DecryptionOnlyKey] — holds only the private key; used by [co.touchlab.kjwt.parser.JwtParser] + * to decrypt tokens. + * - [EncryptionKeyPair] — holds both keys; supports both encryption and decryption. + * + * Complementary keys that share the same [Identifier] can be merged into an [EncryptionKeyPair] + * via [mergeWith]. This happens automatically when both are registered with the same + * [JwtKeyRegistry]. + * + * @see JwtKeyRegistry + * @see co.touchlab.kjwt.parser.JwtParserBuilder.decryptWith + */ +public sealed class EncryptionKey { + public abstract val identifier: Identifier + public abstract val publicKey: PublicKey + public abstract val privateKey: PrivateKey + + public abstract val canEncrypt: Boolean + public abstract val canDecrypt: Boolean + + /** + * Identifies an [EncryptionKey] within a [JwtKeyRegistry] by algorithm and optional key ID. + * + * The combination of [algorithm] and [keyId] must be unique within a registry. When [keyId] + * is `null` the key acts as a catch-all for its algorithm (matched after any exact-`kid` key + * during look-up). + * + * @property algorithm the JWE key-encryption algorithm this key is associated with + * @property keyId the optional `kid` header value used to select this key; `null` matches any + * token for the given algorithm that has no more specific key registered + */ + public data class Identifier( + val algorithm: EncryptionAlgorithm, + val keyId: String?, + ) { + public companion object; + } + + /** + * An encryption-only key that holds only the public key material. + * + * Used when a token must be encrypted but decryption is not performed by the same key holder + * (e.g. asymmetric algorithms where only the public key is available). Accessing [privateKey] + * on this type throws. + */ + public class EncryptionOnlyKey internal constructor( + override val identifier: Identifier, + override val publicKey: PublicKey, + ) : EncryptionKey() { + @Deprecated("EncryptionOnlyKey does not have a private key", level = DeprecationLevel.ERROR) + override val privateKey: PrivateKey + get() = error("EncryptionOnlyKey does not have a private key") + + override val canEncrypt: Boolean = true + override val canDecrypt: Boolean = false + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as EncryptionOnlyKey<*, *> + + if (identifier != other.identifier) return false + if (publicKey != other.publicKey) return false + + return true + } + + override fun hashCode(): Int { + var result = identifier.hashCode() + result = 31 * result + publicKey.hashCode() + return result + } + + override fun toString(): String = + "EncryptionOnlyKey(identifier=$identifier, publicKey=$publicKey)" + } + + /** + * A decryption-only key that holds only the private key material. + * + * Used when tokens must be decrypted but encryption is not required (e.g. a service that only + * consumes encrypted tokens). Accessing [publicKey] on this type throws. + */ + public class DecryptionOnlyKey internal constructor( + override val identifier: Identifier, + override val privateKey: PrivateKey, + ) : EncryptionKey() { + @Deprecated("DecryptionOnlyKey does not have a public key", level = DeprecationLevel.ERROR) + override val publicKey: PublicKey + get() = error("DecryptionOnlyKey does not have a public key") + + override val canEncrypt: Boolean = false + override val canDecrypt: Boolean = true + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as DecryptionOnlyKey<*, *> + + if (identifier != other.identifier) return false + if (privateKey != other.privateKey) return false + + return true + } + + override fun hashCode(): Int { + var result = identifier.hashCode() + result = 31 * result + privateKey.hashCode() + return result + } + + override fun toString(): String = + "DecryptionOnlyKey(identifier=$identifier, privateKey=$privateKey)" + } + + /** + * A complete key pair that holds both public and private key material. + * + * Produced automatically by [mergeWith] when an [EncryptionOnlyKey] and a [DecryptionOnlyKey] + * with the same [Identifier] are both registered in a [JwtKeyRegistry]. Supports both + * encryption and decryption. + */ + public class EncryptionKeyPair internal constructor( + override val identifier: Identifier, + override val publicKey: PublicKey, + override val privateKey: PrivateKey, + ) : EncryptionKey() { + override val canEncrypt: Boolean = true + override val canDecrypt: Boolean = true + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as EncryptionKeyPair<*, *> + + if (identifier != other.identifier) return false + if (publicKey != other.publicKey) return false + if (privateKey != other.privateKey) return false + + return true + } + + override fun hashCode(): Int { + var result = identifier.hashCode() + result = 31 * result + publicKey.hashCode() + result = 31 * result + privateKey.hashCode() + return result + } + + override fun toString(): String = + "EncryptionKeyPair(identifier=$identifier, publicKey=$publicKey, privateKey=$privateKey)" + } + + internal suspend fun decrypt( + contentAlgorithm: EncryptionContentAlgorithm, + encryptedKey: ByteArray, + iv: ByteArray, + ciphertext: ByteArray, + tag: ByteArray, + aad: ByteArray, + ): ByteArray = identifier.algorithm.decrypt( + key = privateKey, + contentAlgorithm = contentAlgorithm, + encryptedKey = encryptedKey, + iv = iv, + ciphertext = ciphertext, + tag = tag, + aad = aad + ) + + internal suspend fun encrypt( + contentAlgorithm: EncryptionContentAlgorithm, + plaintext: ByteArray, + aad: ByteArray, + ): JweEncryptResult = identifier.algorithm.encrypt(publicKey, contentAlgorithm, plaintext, aad) + + internal fun mergeWith(other: EncryptionKey?): EncryptionKey { + if (other == null) return this + + require(identifier == other.identifier) { "Cannot merge keys with different identifiers" } + require(this::class != other::class) { "Cannot merge keys of the same type" } + require(this !is EncryptionKeyPair || other !is EncryptionKeyPair) { "Cannot merge when one key is complete" } + + return when (this) { + is EncryptionOnlyKey if other is DecryptionOnlyKey -> + EncryptionKeyPair(identifier, publicKey, other.privateKey) + + is DecryptionOnlyKey if other is EncryptionKeyPair -> + EncryptionKeyPair(identifier, other.publicKey, privateKey) + + else -> error("Cannot merge given keys") + } + } +} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/JwtKeyRegistry.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/JwtKeyRegistry.kt new file mode 100644 index 0000000..5a6da50 --- /dev/null +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/JwtKeyRegistry.kt @@ -0,0 +1,150 @@ +package co.touchlab.kjwt.model.registry + +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import dev.whyoleg.cryptography.materials.key.Key + +/** + * A centralised store of signing and encryption keys shared across [co.touchlab.kjwt.builder.JwtBuilder] + * and [co.touchlab.kjwt.parser.JwtParser] instances. + * + * A [JwtKeyRegistry] decouples key management from individual builder and parser configurations. + * Populate it once, then reuse it across multiple call sites: + * - Pass it to [co.touchlab.kjwt.builder.JwtBuilder.signWith] or + * [co.touchlab.kjwt.builder.JwtBuilder.encryptWith] to sign or encrypt tokens using the + * registered keys. + * - Pass it to [co.touchlab.kjwt.parser.JwtParserBuilder.useKeysFrom] so one or more parsers + * delegate key look-up to it. + * + * ### Key lookup order + * + * When a key is requested the registry searches in this order: + * 1. **Exact match** — a key registered in this registry whose algorithm and key ID both match + * the request. + * 2. **Algorithm-only fallback** — if the request includes a key ID that has no exact match, a + * key registered *without* a key ID for the same algorithm is used as a catch-all. + * 3. **Delegate registry** — if no local key is found and a delegate was configured (via + * [co.touchlab.kjwt.parser.JwtParserBuilder.useKeysFrom]), the delegate is searched last. + * + * This means locally registered keys always take precedence over the delegate. For signing-key + * lookups an additional `alg=none` sentinel is tried first when insecure mode is active (used + * internally by [co.touchlab.kjwt.parser.JwtParserBuilder.noVerify]). + * + * ### Example + * + * ```kotlin + * val registry = JwtKeyRegistry() + * // populate via JwtParserBuilder and share the reference, or + * // register signing keys directly (see registerSigningKey) + * + * val token = Jwt.builder() + * .subject("user-123") + * .signWith(JwsAlgorithm.HS256, registry) + * + * val parser = Jwt.parser() + * .useKeysFrom(registry) + * .build() + * ``` + * + * @see co.touchlab.kjwt.parser.JwtParserBuilder.useKeysFrom + * @see co.touchlab.kjwt.builder.JwtBuilder.signWith + * @see co.touchlab.kjwt.builder.JwtBuilder.encryptWith + */ +public class JwtKeyRegistry { + private var delegateKeyRegistry: JwtKeyRegistry? = null + private val signingKeys = mutableMapOf, SigningKey<*, *>>() + private val encryptionKeys = mutableMapOf, EncryptionKey<*, *>>() + + /** + * Registers a [SigningKey] in this registry. + * + * Keys are stored by their [SigningKey.Identifier] (algorithm + optional key ID). If a key + * with the same identifier already exists and the new key is its complement — a + * [SigningKey.SigningOnlyKey] paired with a [SigningKey.VerifyOnlyKey] or vice-versa — the + * two are automatically merged into a [SigningKey.SigningKeyPair]. + * + * @param key the signing key to register + * @throws IllegalArgumentException if a key with the same identifier is already registered + * and the two keys cannot be merged (e.g. two verify-only keys for the same identifier) + */ + public fun registerSigningKey(key: SigningKey) { + signingKeys[key.identifier] = try { + key.mergeWith(signingKeys[key.identifier] as? SigningKey) + } catch (error: IllegalArgumentException) { + throw IllegalArgumentException( + "Signing key with for '${key.identifier.algorithm.id}' " + + "identified by '${key.identifier.keyId}' already registered", + error + ) + } + } + + /** + * Registers an [EncryptionKey] in this registry. + * + * Keys are stored by their [EncryptionKey.Identifier] (algorithm + optional key ID). If a key + * with the same identifier already exists and the new key is its complement — an + * [EncryptionKey.EncryptionOnlyKey] paired with a [EncryptionKey.DecryptionOnlyKey] or + * vice-versa — the two are automatically merged into an [EncryptionKey.EncryptionKeyPair]. + * + * @param key the encryption key to register + * @throws IllegalArgumentException if a key with the same identifier is already registered + * and the two keys cannot be merged (e.g. two decryption-only keys for the same identifier) + */ + internal fun registerEncryptionKey(key: EncryptionKey) { + encryptionKeys[key.identifier] = try { + key.mergeWith(encryptionKeys[key.identifier] as? EncryptionKey) + } catch (error: IllegalArgumentException) { + throw IllegalArgumentException( + "Decryption key with for '${key.identifier.algorithm.id}' " + + "identified by '${key.identifier.keyId}' already registered", + error + ) + } + } + + internal fun delegateTo(other: JwtKeyRegistry) { + var cursor: JwtKeyRegistry? = other + while (cursor != null) { + require(cursor !== this) { + "Cyclic delegation detected: this registry is already in the delegate chain of the target" + } + cursor = cursor.delegateKeyRegistry + } + delegateKeyRegistry = other + } + + internal fun findBestSigningKey( + algorithm: SigningAlgorithm, + keyId: String?, + ): SigningKey? { + signingKeys[SigningKey.Identifier(algorithm, keyId)]?.let { + return it as SigningKey + } + + if (keyId != null) { + signingKeys[SigningKey.Identifier(algorithm, null)]?.let { + return it as SigningKey + } + } + + return delegateKeyRegistry?.findBestSigningKey(algorithm, keyId) + } + + internal fun findBestEncryptionKey( + algorithm: EncryptionAlgorithm, + keyId: String?, + ): EncryptionKey? { + encryptionKeys[EncryptionKey.Identifier(algorithm, keyId)]?.let { + return it as EncryptionKey + } + + if (keyId != null) { + encryptionKeys[EncryptionKey.Identifier(algorithm, null)]?.let { + return it as EncryptionKey + } + } + + return delegateKeyRegistry?.findBestEncryptionKey(algorithm, keyId) + } +} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/SigningKey.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/SigningKey.kt new file mode 100644 index 0000000..3e94307 --- /dev/null +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/model/registry/SigningKey.kt @@ -0,0 +1,198 @@ +package co.touchlab.kjwt.model.registry + +import co.touchlab.kjwt.cryptography.SimpleKey +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import dev.whyoleg.cryptography.materials.key.Key + +/** + * Represents a cryptographic key (or key pair) used for JWS signing and/or verification. + * + * Instances are identified by a ([SigningAlgorithm], optional key ID) pair captured in + * [identifier]. Depending on which key material is available, a [SigningKey] may be: + * - [SigningOnlyKey] — holds only a private key; used by [co.touchlab.kjwt.builder.JwtBuilder] + * to produce signatures. + * - [VerifyOnlyKey] — holds only a public key; used by [co.touchlab.kjwt.parser.JwtParser] to + * verify signatures. + * - [SigningKeyPair] — holds both keys; supports both signing and verification. + * + * Complementary keys that share the same [Identifier] can be merged into a [SigningKeyPair] via + * [mergeWith]. This happens automatically when both are registered with the same + * [JwtKeyRegistry]. + * + * @see JwtKeyRegistry + * @see co.touchlab.kjwt.parser.JwtParserBuilder.verifyWith + */ +public sealed class SigningKey { + public abstract val identifier: Identifier + public abstract val publicKey: PublicKey + public abstract val privateKey: PrivateKey + + public abstract val canSign: Boolean + public abstract val canVerify: Boolean + + /** + * Identifies a [SigningKey] within a [JwtKeyRegistry] by algorithm and optional key ID. + * + * The combination of [algorithm] and [keyId] must be unique within a registry. When [keyId] + * is `null` the key acts as a catch-all for its algorithm (matched after any exact-`kid` key + * during look-up). + * + * @property algorithm the JWS algorithm this key is associated with + * @property keyId the optional `kid` header value used to select this key; `null` matches any + * token for the given algorithm that has no more specific key registered + */ + public data class Identifier( + val algorithm: SigningAlgorithm, + val keyId: String?, + ) { + public companion object { + /** Sentinel identifier used for unsigned (`alg=none`) tokens. */ + public val None: Identifier = Identifier(SigningAlgorithm.None, null) + } + } + + /** + * A signing-only key that holds only the private key material. + * + * Used when a token must be signed but signature verification is not performed by the same + * key holder (e.g. asymmetric algorithms where only the private key is available). Accessing + * [publicKey] on this type throws. + */ + public class SigningOnlyKey internal constructor( + override val identifier: Identifier, + override val privateKey: PrivateKey, + ) : SigningKey() { + @Deprecated("SigningOnlyKey does not have a public key", level = DeprecationLevel.ERROR) + override val publicKey: PublicKey + get() = error("SigningOnlyKey does not have a public key") + + override val canSign: Boolean = true + override val canVerify: Boolean = false + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as SigningOnlyKey<*, *> + + if (identifier != other.identifier) return false + if (privateKey != other.privateKey) return false + + return true + } + + override fun hashCode(): Int { + var result = identifier.hashCode() + result = 31 * result + privateKey.hashCode() + return result + } + + override fun toString(): String = + "SigningOnlyKey(identifier=$identifier, privateKey=$privateKey)" + } + + /** + * A verify-only key that holds only the public key material. + * + * Used when tokens must be verified but signing is not required (e.g. a service that only + * consumes tokens). Accessing [privateKey] on this type throws. + */ + public class VerifyOnlyKey internal constructor( + override val identifier: Identifier, + override val publicKey: PublicKey, + ) : SigningKey() { + @Deprecated("VerifyOnlyKey does not have a private key", level = DeprecationLevel.ERROR) + override val privateKey: PrivateKey + get() = error("VerifyOnlyKey does not have a private key") + + override val canSign: Boolean = false + override val canVerify: Boolean = true + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as VerifyOnlyKey<*, *> + + if (identifier != other.identifier) return false + if (publicKey != other.publicKey) return false + + return true + } + + override fun hashCode(): Int { + var result = identifier.hashCode() + result = 31 * result + publicKey.hashCode() + return result + } + + override fun toString(): String = + "VerifyOnlyKey(publicKey=$publicKey, identifier=$identifier)" + } + + /** + * A complete key pair that holds both private and public key material. + * + * Produced automatically by [mergeWith] when a [SigningOnlyKey] and a [VerifyOnlyKey] with + * the same [Identifier] are both registered in a [JwtKeyRegistry]. Supports both signing and + * verification. + */ + public class SigningKeyPair internal constructor( + override val identifier: Identifier, + override val publicKey: PublicKey, + override val privateKey: PrivateKey, + ) : SigningKey() { + override val canSign: Boolean = true + override val canVerify: Boolean = true + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as SigningKeyPair<*, *> + + if (identifier != other.identifier) return false + if (publicKey != other.publicKey) return false + if (privateKey != other.privateKey) return false + + return true + } + + override fun hashCode(): Int { + var result = identifier.hashCode() + result = 31 * result + publicKey.hashCode() + result = 31 * result + privateKey.hashCode() + return result + } + + override fun toString(): String = + "SigningKeyPair(identifier=$identifier, publicKey=$publicKey, privateKey=$privateKey)" + } + + internal suspend fun verify(signingInput: ByteArray, signature: ByteArray): Boolean = try { + identifier.algorithm.verify(publicKey, signingInput, signature) + } catch (_: Throwable) { + false + } + + internal suspend fun sign(signingInput: ByteArray): ByteArray = + identifier.algorithm.sign(privateKey, signingInput) + + internal fun mergeWith(other: SigningKey?): SigningKey { + if (other == null) return this + + require(identifier == other.identifier) { "Cannot merge keys with different identifiers" } + require(this::class != other::class) { "Cannot merge keys of the same type" } + require(this !is SigningKeyPair || other !is SigningKeyPair) { "Cannot merge when one key is complete" } + + return when (this) { + is SigningOnlyKey if other is VerifyOnlyKey -> + SigningKeyPair(identifier, other.publicKey, privateKey) + + is VerifyOnlyKey if other is SigningOnlyKey -> + SigningKeyPair(identifier, publicKey, other.privateKey) + + else -> error("Cannot merge given keys") + } + } +} diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParser.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParser.kt index 2dda310..3ef1335 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParser.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParser.kt @@ -9,6 +9,7 @@ import co.touchlab.kjwt.exception.SignatureException import co.touchlab.kjwt.exception.UnsupportedJwtException import co.touchlab.kjwt.ext.encryption import co.touchlab.kjwt.ext.expirationOrNull +import co.touchlab.kjwt.ext.keyIdOrNull import co.touchlab.kjwt.ext.notBeforeOrNull import co.touchlab.kjwt.internal.decodeBase64Url import co.touchlab.kjwt.internal.encodeBase64Url @@ -56,12 +57,11 @@ public class JwtParser internal constructor(private val config: JwtParserBuilder val claims = JwtPayload(parts[1]) val signature = parts[2] - if (algorithm != SigningAlgorithm.None) { + if (algorithm != SigningAlgorithm.None && !config.skipVerification) { val verifier = checkNotNull( - config.jwsKeyVerifier?.takeIf { - it.algorithm == algorithm || it.algorithm == SigningAlgorithm.None && config.allowUnsecured - } + config.keyRegistry.findBestSigningKey(algorithm, header.keyIdOrNull) ) { "No verification key configured. Call verifyWith() or noVerify() on the parser builder." } + val signingInput = "${parts[0]}.${parts[1]}".encodeToByteArray() val signature = signature.decodeBase64Url() @@ -98,7 +98,7 @@ public class JwtParser internal constructor(private val config: JwtParserBuilder throw UnsupportedJwtException("Unsupported JWE content algorithm: '${header.encryption}'", e) } - val decryptor = requireNotNull(config.jweKeyDecryptor?.takeIf { it.algorithm == keyAlgorithm }) { + val decryptor = requireNotNull(config.keyRegistry.findBestEncryptionKey(keyAlgorithm, header.keyIdOrNull)) { "No decryption key configured. Call decryptWith() on the parser builder." } diff --git a/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParserBuilder.kt b/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParserBuilder.kt index c0e80f8..0229f13 100644 --- a/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParserBuilder.kt +++ b/lib/src/commonMain/kotlin/co/touchlab/kjwt/parser/JwtParserBuilder.kt @@ -1,6 +1,5 @@ package co.touchlab.kjwt.parser -import co.touchlab.kjwt.cryptography.SimpleKey import co.touchlab.kjwt.exception.IncorrectClaimException import co.touchlab.kjwt.exception.MissingClaimException import co.touchlab.kjwt.ext.audienceOrNull @@ -10,8 +9,10 @@ import co.touchlab.kjwt.ext.subjectOrNull import co.touchlab.kjwt.model.JwtHeader import co.touchlab.kjwt.model.JwtPayload import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm -import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import co.touchlab.kjwt.model.registry.EncryptionKey +import co.touchlab.kjwt.model.registry.JwtKeyRegistry +import co.touchlab.kjwt.model.registry.SigningKey import dev.whyoleg.cryptography.materials.key.Key /** @@ -19,22 +20,23 @@ import dev.whyoleg.cryptography.materials.key.Key * * Example: * ```kotlin + * val signingKey = SigningAlgorithm.HS256.newKey() * val parser = Jwt.parser() - * .verifyWith(JwsAlgorithm.HS256, hmacKey) + * .verifyWith(signingKey) * .requireIssuer("myapp") - * .clockSkew(DateTimePeriod(seconds = 30)) + * .clockSkew(30L) * .build() * val jws = parser.parse(token) * ``` */ public class JwtParserBuilder { - internal var jwsKeyVerifier: JwsKeyVerifier<*, *>? = null - internal var jweKeyDecryptor: JweKeyDecryptor<*, *>? = null + internal val keyRegistry = JwtKeyRegistry() @PublishedApi internal val validators: MutableList<(JwtPayload, JwtHeader) -> Unit> = mutableListOf() internal var clockSkewSeconds: Long = 0L internal var allowUnsecured: Boolean = false + internal var skipVerification: Boolean = false /** * Disables signature verification entirely, accepting any token regardless of its signature. @@ -46,7 +48,33 @@ public class JwtParserBuilder { */ public fun noVerify(): JwtParserBuilder = apply { allowUnsecured = true - jwsKeyVerifier = JwsKeyVerifier(SigningAlgorithm.None, SimpleKey.Empty) + skipVerification = true + } + + /** + * Delegates key look-up to the given [registry] before consulting this parser's own keys. + * + * Keys registered directly on this builder (via [verifyWith] or [decryptWith]) take + * precedence; the [registry] is only consulted when no local key matches. This makes it easy + * to share a central key store across multiple parsers while still allowing each parser to + * override individual keys locally. + * + * ```kotlin + * val sharedRegistry = JwtKeyRegistry() + * // keys are added to sharedRegistry elsewhere + * + * val parser = Jwt.parser() + * .useKeysFrom(sharedRegistry) + * .requireIssuer("my-app") + * .build() + * ``` + * + * @param registry the [JwtKeyRegistry] to fall back to when no local key matches + * @return this builder for chaining + * @see JwtKeyRegistry + */ + public fun useKeysFrom(registry: JwtKeyRegistry): JwtParserBuilder = apply { + keyRegistry.delegateTo(registry) } /** @@ -54,29 +82,95 @@ public class JwtParserBuilder { * * @param algorithm the signing algorithm to use for verification * @param key the public key (or symmetric key) for signature verification + * @param keyId optional key ID to associate with this verifier; when set, the parser will + * only use this key if the token's `kid` header matches. Defaults to `null` (matches any token). * @return this builder for chaining */ public fun verifyWith( algorithm: SigningAlgorithm, - key: PublicKey + key: PublicKey, + keyId: String? = null, ): JwtParserBuilder = apply { - jwsKeyVerifier = JwsKeyVerifier(algorithm, key) + keyRegistry.registerSigningKey( + SigningKey.VerifyOnlyKey( + identifier = SigningKey.Identifier(algorithm, keyId), + publicKey = key + ) + ) } + /** + * Registers a pre-built [SigningKey.VerifyOnlyKey] for JWS signature verification. + * + * The algorithm and `kid` are taken from [key]'s [SigningKey.Identifier]. + * + * @param key the verify-only signing key to register + * @return this builder for chaining + */ + public fun verifyWith( + key: SigningKey.VerifyOnlyKey, + ): JwtParserBuilder = apply { keyRegistry.registerSigningKey(key) } + + /** + * Registers a pre-built [SigningKey.SigningKeyPair] for JWS signature verification. + * + * The algorithm and `kid` are taken from [key]'s [SigningKey.Identifier]. Both the public and + * private key material are stored, but only the public key is used for verification. + * + * @param key the signing key pair to register + * @return this builder for chaining + */ + public fun verifyWith( + key: SigningKey.SigningKeyPair, + ): JwtParserBuilder = apply { keyRegistry.registerSigningKey(key) } + /** * Sets the algorithm and private key used to decrypt JWE tokens. * * @param algorithm the key encryption algorithm used to unwrap the content encryption key * @param privateKey the private key for decrypting the JWE token + * @param keyId optional key ID to associate with this decryptor; when set, the parser will + * only use this key if the token's `kid` header matches. Defaults to `null` (matches any token). * @return this builder for chaining */ public fun decryptWith( algorithm: EncryptionAlgorithm, - privateKey: PrivateKey + privateKey: PrivateKey, + keyId: String? = null, ): JwtParserBuilder = apply { - jweKeyDecryptor = JweKeyDecryptor(algorithm, privateKey) + keyRegistry.registerEncryptionKey( + EncryptionKey.DecryptionOnlyKey( + identifier = EncryptionKey.Identifier(algorithm, keyId), + privateKey = privateKey, + ) + ) } + /** + * Registers a pre-built [EncryptionKey.DecryptionOnlyKey] for JWE token decryption. + * + * The algorithm and `kid` are taken from [key]'s [EncryptionKey.Identifier]. + * + * @param key the decryption-only encryption key to register + * @return this builder for chaining + */ + public fun decryptWith( + key: EncryptionKey.DecryptionOnlyKey, + ): JwtParserBuilder = apply { keyRegistry.registerEncryptionKey(key) } + + /** + * Registers a pre-built [EncryptionKey.EncryptionKeyPair] for JWE token decryption. + * + * The algorithm and `kid` are taken from [key]'s [EncryptionKey.Identifier]. Both the public + * and private key material are stored, but only the private key is used for decryption. + * + * @param key the encryption key pair to register + * @return this builder for chaining + */ + public fun decryptWith( + key: EncryptionKey.EncryptionKeyPair, + ): JwtParserBuilder = apply { keyRegistry.registerEncryptionKey(key) } + /** * Adds a validator that requires the `iss` claim to equal the given value. * @@ -163,6 +257,7 @@ public class JwtParserBuilder { */ public fun allowUnsecured(allow: Boolean): JwtParserBuilder = apply { allowUnsecured = allow + if (!allow) skipVerification = false } /** @@ -172,28 +267,3 @@ public class JwtParserBuilder { */ public fun build(): JwtParser = JwtParser(this) } - -internal data class JwsKeyVerifier( - val algorithm: SigningAlgorithm, - val publicKey: PublicKey, -) { - suspend fun verify(signingInput: ByteArray, signature: ByteArray): Boolean = try { - algorithm.verify(publicKey, signingInput, signature) - } catch (_: Throwable) { - false - } -} - -internal data class JweKeyDecryptor( - val algorithm: EncryptionAlgorithm, - val privateKey: PrivateKey, -) { - suspend fun decrypt( - contentAlgorithm: EncryptionContentAlgorithm, - encryptedKey: ByteArray, - iv: ByteArray, - ciphertext: ByteArray, - tag: ByteArray, - aad: ByteArray, - ): ByteArray = algorithm.decrypt(privateKey, contentAlgorithm, encryptedKey, iv, ciphertext, tag, aad) -} diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/ClaimsValidationTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/ClaimsValidationTest.kt index 1a0a193..23eff9a 100644 --- a/lib/src/commonTest/kotlin/co/touchlab/kjwt/ClaimsValidationTest.kt +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/ClaimsValidationTest.kt @@ -6,7 +6,6 @@ import co.touchlab.kjwt.exception.MissingClaimException import co.touchlab.kjwt.exception.PrematureJwtException import co.touchlab.kjwt.ext.expirationOrNull import co.touchlab.kjwt.ext.subjectOrNull -import co.touchlab.kjwt.model.algorithm.SigningAlgorithm import io.kotest.core.spec.style.FunSpec import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -20,31 +19,31 @@ class ClaimsValidationTest : FunSpec({ context("expiration") { test("parse expired token throws ExpiredJwtException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .expiration(Clock.System.now() - 1.hours) // expired 1 hour ago - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) } } test("parse expired token contains claims in exception") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("expired-user") .expiration(Clock.System.now() - 1.hours) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) } @@ -55,32 +54,32 @@ class ClaimsValidationTest : FunSpec({ } test("parse expired token within clock skew passes") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Expired 30 seconds ago val token = Jwt.builder() .expiration(Clock.System.now() - 30.seconds) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // With 60-second skew, should pass Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .clockSkew(60L) .build() .parseSigned(token) // should not throw } test("parse expired token outside clock skew throws") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Expired 120 seconds ago val token = Jwt.builder() .expiration(Clock.System.now() - 120.seconds) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .clockSkew(60L) // skew of 60s is not enough for 120s expired .build() .parseSigned(token) @@ -91,47 +90,47 @@ class ClaimsValidationTest : FunSpec({ context("not before") { test("parse premature token throws PrematureJwtException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .notBefore(Clock.System.now() + 1.hours) // not valid for another hour - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) } } test("parse premature token within clock skew passes") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Not valid for another 30 seconds val token = Jwt.builder() .notBefore(Clock.System.now() + 30.seconds) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // With 60-second skew, should pass Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .clockSkew(60L) .build() .parseSigned(token) // should not throw } test("parse premature token outside clock skew throws") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Not valid for another 120 seconds val token = Jwt.builder() .notBefore(Clock.System.now() + 120.seconds) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .clockSkew(60L) .build() .parseSigned(token) @@ -142,15 +141,15 @@ class ClaimsValidationTest : FunSpec({ context("issuer") { test("requireIssuer mismatch throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .issuer("actual-issuer") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireIssuer("expected-issuer") .build() .parseSigned(token) @@ -161,16 +160,16 @@ class ClaimsValidationTest : FunSpec({ } test("requireIssuer missing throws MissingClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Token with no issuer claim val token = Jwt.builder() .subject("someone") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireIssuer("expected-issuer") .build() .parseSigned(token) @@ -180,16 +179,16 @@ class ClaimsValidationTest : FunSpec({ } test("requireIssuer case mismatch throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .issuer("Auth.MyApp.io") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // Default comparison is case-sensitive val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireIssuer("auth.myapp.io") .build() .parseSigned(token) @@ -200,15 +199,15 @@ class ClaimsValidationTest : FunSpec({ } test("requireIssuer ignore case passes") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .issuer("AUTH.MYAPP.IO") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // Should not throw — comparison is case-insensitive Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireIssuer("auth.myapp.io", ignoreCase = true) .build() .parseSigned(token) @@ -218,15 +217,15 @@ class ClaimsValidationTest : FunSpec({ context("subject") { test("requireSubject mismatch throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("actual-subject") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireSubject("expected-subject") .build() .parseSigned(token) @@ -236,15 +235,15 @@ class ClaimsValidationTest : FunSpec({ } test("requireSubject missing throws MissingClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .issuer("issuer") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireSubject("expected-subject") .build() .parseSigned(token) @@ -254,16 +253,16 @@ class ClaimsValidationTest : FunSpec({ } test("requireSubject case mismatch throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("User-123") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // Subject comparison is case-sensitive val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireSubject("user-123") .build() .parseSigned(token) @@ -276,15 +275,15 @@ class ClaimsValidationTest : FunSpec({ context("audience") { test("requireAudience mismatch throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .audience("actual-aud") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireAudience("expected-aud") .build() .parseSigned(token) @@ -292,15 +291,15 @@ class ClaimsValidationTest : FunSpec({ } test("requireAudience not in array throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .audience("aud1", "aud2") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireAudience("aud3") .build() .parseSigned(token) @@ -308,16 +307,16 @@ class ClaimsValidationTest : FunSpec({ } test("requireAudience case mismatch throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .audience("Mobile-App") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // Audience comparison is case-sensitive assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireAudience("mobile-app") .build() .parseSigned(token) @@ -325,15 +324,15 @@ class ClaimsValidationTest : FunSpec({ } test("requireAudience missing throws MissingClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("user") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireAudience("expected-aud") .build() .parseSigned(token) @@ -346,30 +345,30 @@ class ClaimsValidationTest : FunSpec({ context("custom required claims") { test("requireCustomClaim present match passes") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .claim("role", "admin") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // Should not throw Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireClaim("role", "admin") .build() .parseSigned(token) } test("requireCustomClaim present mismatch throws IncorrectClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .claim("role", "user") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireClaim("role", "admin") .build() .parseSigned(token) @@ -380,15 +379,15 @@ class ClaimsValidationTest : FunSpec({ } test("requireCustomClaim missing throws MissingClaimException") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("user") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val ex = assertFailsWith { Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireClaim("role", "admin") .build() .parseSigned(token) diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/JweEncodeTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JweEncodeTest.kt index a928e52..36709a8 100644 --- a/lib/src/commonTest/kotlin/co/touchlab/kjwt/JweEncodeTest.kt +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JweEncodeTest.kt @@ -5,6 +5,7 @@ package co.touchlab.kjwt import co.touchlab.kjwt.ext.subjectOrNull import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm +import co.touchlab.kjwt.model.registry.EncryptionKey import io.kotest.core.spec.style.FunSpec import kotlin.test.assertEquals import kotlin.test.assertNotEquals @@ -15,14 +16,14 @@ class JweEncodeTest : FunSpec({ context("Dir + GCM round-trips") { test("encrypt Dir A128GCM round trip") { - val cek = aesSimpleKey(128) + val encKey = dirEncKey(128) val token = Jwt.builder() .subject("a128gcm-user") - .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A128GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A128GCM) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.Dir, cek) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -30,14 +31,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt Dir A192GCM round trip").config(enabled = !isWebBrowserPlatform()) { - val cek = aesSimpleKey(192) + val encKey = dirEncKey(192) val token = Jwt.builder() .subject("a192gcm-user") - .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A192GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A192GCM) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.Dir, cek) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -45,14 +46,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt Dir A256GCM round trip") { - val cek = aesSimpleKey(256) + val encKey = dirEncKey(256) val token = Jwt.builder() .subject("a256gcm-user") - .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.Dir, cek) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -63,14 +64,14 @@ class JweEncodeTest : FunSpec({ context("Dir + CBC-HMAC round-trips") { test("encrypt Dir A128CbcHs256 round trip") { - val cek = aesSimpleKey(256) // 32 bytes: 16 MAC + 16 ENC + val encKey = dirEncKey(256) // 32 bytes: 16 MAC + 16 ENC val token = Jwt.builder() .subject("a128cbc-user") - .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A128CbcHs256) + .encryptWith(encKey, EncryptionContentAlgorithm.A128CbcHs256) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.Dir, cek) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -78,14 +79,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt Dir A192CbcHs384 round trip").config(enabled = !isWebBrowserPlatform()) { - val cek = aesSimpleKey(384) // 48 bytes: 24 MAC + 24 ENC + val encKey = dirEncKey(384) // 48 bytes: 24 MAC + 24 ENC val token = Jwt.builder() .subject("a192cbc-user") - .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A192CbcHs384) + .encryptWith(encKey, EncryptionContentAlgorithm.A192CbcHs384) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.Dir, cek) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -93,14 +94,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt Dir A256CbcHs512 round trip") { - val cek = aesSimpleKey(512) // 64 bytes: 32 MAC + 32 ENC + val encKey = dirEncKey(512) // 64 bytes: 32 MAC + 32 ENC val token = Jwt.builder() .subject("a256cbc-user") - .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256CbcHs512) + .encryptWith(encKey, EncryptionContentAlgorithm.A256CbcHs512) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.Dir, cek) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -111,14 +112,14 @@ class JweEncodeTest : FunSpec({ context("RSA-OAEP (SHA-1) round-trips") { test("encrypt RsaOaep A128GCM round trip") { - val keyPair = rsaOaepKeyPair() + val encKey = rsaOaepEncKey() val token = Jwt.builder() .subject("rsa-oaep-a128gcm") - .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep, EncryptionContentAlgorithm.A128GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A128GCM) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.RsaOaep, keyPair.privateKey) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -126,14 +127,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt RsaOaep A256GCM round trip") { - val keyPair = rsaOaepKeyPair() + val encKey = rsaOaepEncKey() val token = Jwt.builder() .subject("rsa-oaep-a256gcm") - .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep, EncryptionContentAlgorithm.A256GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.RsaOaep, keyPair.privateKey) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -141,14 +142,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt RsaOaep A256CbcHs512 round trip") { - val keyPair = rsaOaepKeyPair() + val encKey = rsaOaepEncKey() val token = Jwt.builder() .subject("rsa-oaep-cbc512") - .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep, EncryptionContentAlgorithm.A256CbcHs512) + .encryptWith(encKey, EncryptionContentAlgorithm.A256CbcHs512) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.RsaOaep, keyPair.privateKey) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -159,14 +160,14 @@ class JweEncodeTest : FunSpec({ context("RSA-OAEP-256 (SHA-256) round-trips") { test("encrypt RsaOaep256 A128GCM round trip") { - val keyPair = rsaOaep256KeyPair() + val encKey = rsaOaep256EncKey() val token = Jwt.builder() .subject("rsa-oaep256-a128gcm") - .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep256, EncryptionContentAlgorithm.A128GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A128GCM) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.RsaOaep256, keyPair.privateKey) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -174,14 +175,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt RsaOaep256 A256GCM round trip") { - val keyPair = rsaOaep256KeyPair() + val encKey = rsaOaep256EncKey() val token = Jwt.builder() .subject("rsa-oaep256-a256gcm") - .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep256, EncryptionContentAlgorithm.A256GCM) + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.RsaOaep256, keyPair.privateKey) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -189,14 +190,14 @@ class JweEncodeTest : FunSpec({ } test("encrypt RsaOaep256 A256CbcHs512 round trip") { - val keyPair = rsaOaep256KeyPair() + val encKey = rsaOaep256EncKey() val token = Jwt.builder() .subject("rsa-oaep256-cbc512") - .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep256, EncryptionContentAlgorithm.A256CbcHs512) + .encryptWith(encKey, EncryptionContentAlgorithm.A256CbcHs512) .compact() val jwe = Jwt.parser() - .decryptWith(EncryptionAlgorithm.RsaOaep256, keyPair.privateKey) + .decryptWith(encKey) .build() .parseEncrypted(token) @@ -276,14 +277,44 @@ class JweEncodeTest : FunSpec({ test("encrypt Dir A256GCM with kid") { val cek = aesSimpleKey(256) val token = Jwt.builder() - .keyId("enc-key-id") .subject("test") - .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM) + .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM, "enc-key-id") .compact() val headerJson = decodeTokenHeader(token) assertTrue(headerJson.contains("\"kid\":\"enc-key-id\""), "Header must contain kid, got: $headerJson") } + + test("encrypt RsaOaep A256GCM with kid") { + val keyPair = rsaOaepKeyPair() + val token = Jwt.builder() + .subject("test") + .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep, EncryptionContentAlgorithm.A256GCM, "rsa-enc-key-id") + .compact() + + val headerJson = decodeTokenHeader(token) + assertTrue(headerJson.contains("\"kid\":\"rsa-enc-key-id\""), "Header must contain kid, got: $headerJson") + } + } + + context("key capability checks") { + + test("encryptWith EncryptionOnlyKey succeeds") { + val encKey = rsaOaepEncKey() + val encryptionOnlyKey = EncryptionKey.EncryptionOnlyKey(encKey.identifier, encKey.publicKey) + + Jwt.builder() + .subject("test") + .encryptWith(encryptionOnlyKey, EncryptionContentAlgorithm.A256GCM) + } + + test("encryptWith EncryptionKeyPair succeeds") { + val encKey = rsaOaepEncKey() + + Jwt.builder() + .subject("test") + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) + } } context("uniqueness") { @@ -302,4 +333,37 @@ class JweEncodeTest : FunSpec({ assertNotEquals(t1, t2, "Each JWE encryption call must produce a unique token (random IV)") } } + + context("raw key API (backward compat)") { + + test("encryptWith and decryptWith raw Dir key") { + val cek = aesSimpleKey(256) + val token = Jwt.builder() + .subject("dir-compat") + .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM) + .compact() + + val jwe = Jwt.parser() + .decryptWith(EncryptionAlgorithm.Dir, cek) + .build() + .parseEncrypted(token) + + assertEquals("dir-compat", jwe.payload.subjectOrNull) + } + + test("encryptWith and decryptWith raw RSA OAEP key pair") { + val keyPair = rsaOaepKeyPair() + val token = Jwt.builder() + .subject("rsa-oaep-compat") + .encryptWith(keyPair.publicKey, EncryptionAlgorithm.RsaOaep, EncryptionContentAlgorithm.A256GCM) + .compact() + + val jwe = Jwt.parser() + .decryptWith(EncryptionAlgorithm.RsaOaep, keyPair.privateKey) + .build() + .parseEncrypted(token) + + assertEquals("rsa-oaep-compat", jwe.payload.subjectOrNull) + } + } }) diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsDecodeTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsDecodeTest.kt index 6ab478d..f747297 100644 --- a/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsDecodeTest.kt +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsDecodeTest.kt @@ -9,8 +9,6 @@ import co.touchlab.kjwt.ext.notBeforeOrNull import co.touchlab.kjwt.ext.subjectOrNull import co.touchlab.kjwt.ext.type import co.touchlab.kjwt.model.JwtInstance -import co.touchlab.kjwt.model.algorithm.SigningAlgorithm -import dev.whyoleg.cryptography.algorithms.EC import io.kotest.core.spec.style.FunSpec import kotlin.test.assertEquals import kotlin.test.assertIs @@ -25,9 +23,9 @@ class JwsDecodeTest : FunSpec({ context("parse known HS256 token") { test("parse Hs256 valid token") { - val key = hs256Key() + val signingKey = hs256SigningKey() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned( "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" + @@ -44,15 +42,15 @@ class JwsDecodeTest : FunSpec({ } test("parse Hs384 valid token") { - val key = hs384Key() + val signingKey = hs384SigningKey() val token = Jwt.builder() .subject("hs384-user") .issuedAt(Instant.fromEpochSeconds(1_700_000_000)) - .signWith(SigningAlgorithm.HS384, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS384, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -61,15 +59,15 @@ class JwsDecodeTest : FunSpec({ } test("parse Hs512 valid token") { - val key = hs512Key() + val signingKey = hs512SigningKey() val token = Jwt.builder() .subject("hs512-user") .issuedAt(Instant.fromEpochSeconds(1_700_000_000)) - .signWith(SigningAlgorithm.HS512, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS512, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -78,14 +76,14 @@ class JwsDecodeTest : FunSpec({ } test("parse Rs256 valid token") { - val keyPair = rsaPkcs1KeyPair() + val keyPair = rs256SigningKey() val token = Jwt.builder() .subject("rs256-user") - .signWith(SigningAlgorithm.RS256, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.RS256, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -95,14 +93,14 @@ class JwsDecodeTest : FunSpec({ } test("parse Es256 valid token") { - val keyPair = ecKeyPair(EC.Curve.P256) + val keyPair = es256SigningKey() val token = Jwt.builder() .subject("es256-user") - .signWith(SigningAlgorithm.ES256, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.ES256, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -111,14 +109,14 @@ class JwsDecodeTest : FunSpec({ } test("parse Ps256 valid token") { - val keyPair = rsaPssKeyPair() + val keyPair = ps256SigningKey() val token = Jwt.builder() .subject("ps256-user") - .signWith(SigningAlgorithm.PS256, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.PS256, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -132,7 +130,7 @@ class JwsDecodeTest : FunSpec({ test("parse none with allow unsecured") { val token = Jwt.builder() .subject("none-user") - .signWith(SigningAlgorithm.None) + .build() .compact() val jws = Jwt.parser() @@ -147,7 +145,7 @@ class JwsDecodeTest : FunSpec({ test("parse none with no verify succeeds") { val token = Jwt.builder() .subject("none-user") - .signWith(SigningAlgorithm.None) + .build() .compact() val jws = Jwt.parser() @@ -160,10 +158,10 @@ class JwsDecodeTest : FunSpec({ } test("no verify with signed token skips verification") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("user") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // noVerify() matches any algorithm and None.verify() always returns true @@ -180,15 +178,15 @@ class JwsDecodeTest : FunSpec({ context("audience normalization") { test("parse Hs256 audience normalized single string") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Build a token with single audience (serialized as plain string) val token = Jwt.builder() .audience("api.example.com") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -196,15 +194,15 @@ class JwsDecodeTest : FunSpec({ } test("parse Hs256 audience normalized array") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Build a token with multiple audiences (serialized as JSON array) val token = Jwt.builder() .audience("aud1", "aud2") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -215,16 +213,16 @@ class JwsDecodeTest : FunSpec({ context("typed custom claim access") { test("parse Hs256 custom claims typed access") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .claim("role", "admin") .claim("level", 5) .claim("active", true) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -237,14 +235,14 @@ class JwsDecodeTest : FunSpec({ context("auto-detect") { test("parse auto detect JWS token returns Jws") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("auto-detect-user") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val result = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parse(token) @@ -256,15 +254,15 @@ class JwsDecodeTest : FunSpec({ context("claim validation happy paths") { test("validate issuer match") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .issuer("my-issuer") .expiration(Clock.System.now() + 1.hours) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireIssuer("my-issuer") .build() .parseSigned(token) @@ -273,15 +271,15 @@ class JwsDecodeTest : FunSpec({ } test("validate subject match") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .subject("my-subject") .expiration(Clock.System.now() + 1.hours) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireSubject("my-subject") .build() .parseSigned(token) @@ -290,15 +288,15 @@ class JwsDecodeTest : FunSpec({ } test("validate audience match single") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .audience("my-api") .expiration(Clock.System.now() + 1.hours) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireAudience("my-api") .build() .parseSigned(token) @@ -307,15 +305,15 @@ class JwsDecodeTest : FunSpec({ } test("validate audience match one of many") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .audience("api1", "api2", "api3") .expiration(Clock.System.now() + 1.hours) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .requireAudience("api2") .build() .parseSigned(token) @@ -324,15 +322,15 @@ class JwsDecodeTest : FunSpec({ } test("validate exp not expired") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .expiration(Clock.System.now() + 1.hours) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // Should not throw val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -340,15 +338,15 @@ class JwsDecodeTest : FunSpec({ } test("validate nbf past time allowed") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .notBefore(Clock.System.now() - 1.hours) // already past, so valid - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // Should not throw val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -356,16 +354,16 @@ class JwsDecodeTest : FunSpec({ } test("validate clock skew slightly expired within skew") { - val key = hs256Key() + val signingKey = hs256SigningKey() // Expired 3 seconds ago val token = Jwt.builder() .expiration(Clock.System.now() - 3.seconds) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() // With 5-second skew, it should pass val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .clockSkew(5L) .build() .parseSigned(token) diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsEncodeTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsEncodeTest.kt index b2d4681..1e1dfeb 100644 --- a/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsEncodeTest.kt +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwsEncodeTest.kt @@ -10,9 +10,8 @@ import co.touchlab.kjwt.ext.keyId import co.touchlab.kjwt.ext.notBeforeOrNull import co.touchlab.kjwt.ext.subjectOrNull import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import co.touchlab.kjwt.model.registry.SigningKey import dev.whyoleg.cryptography.algorithms.EC -import dev.whyoleg.cryptography.algorithms.SHA384 -import dev.whyoleg.cryptography.algorithms.SHA512 import io.kotest.core.spec.style.FunSpec import kotlin.test.assertEquals import kotlin.test.assertNotEquals @@ -81,7 +80,7 @@ class JwsEncodeTest : FunSpec({ context("all registered claims") { test("sign Hs256 with all registered claims") { - val key = hs256Key() + val signingKey = hs256SigningKey() val now = Clock.System.now() val token = Jwt.builder() .issuer("test-issuer") @@ -91,11 +90,11 @@ class JwsEncodeTest : FunSpec({ .notBefore(now - 1.hours) .issuedAt(now) .id("unique-jwt-id") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -112,16 +111,16 @@ class JwsEncodeTest : FunSpec({ context("custom claims") { test("sign Hs256 with custom claims") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .claim("strClaim", "hello") .claim("numClaim", 42) .claim("boolClaim", true) - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) @@ -134,10 +133,10 @@ class JwsEncodeTest : FunSpec({ context("audience serialization") { test("sign Hs256 audience single string") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .audience("single-aud") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val payloadJson = decodeTokenPayload(token) @@ -146,10 +145,10 @@ class JwsEncodeTest : FunSpec({ } test("sign Hs256 audience multiple") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .audience("aud1", "aud2", "aud3") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val payloadJson = decodeTokenPayload(token) @@ -161,27 +160,56 @@ class JwsEncodeTest : FunSpec({ context("header fields") { test("sign Hs256 header kid included") { - val key = hs256Key() + val signingKey = hs256SigningKey(keyId = "my-key-id") val token = Jwt.builder() - .keyId("my-key-id") .subject("test") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.HS256, key) + .verifyWith(signingKey) .build() .parseSigned(token) assertEquals("my-key-id", jws.header.keyId) } + test("sign Rs256 header kid included") { + val signingKey = rs256SigningKey(keyId = "rsa-key-id") + val token = Jwt.builder() + .subject("test") + .signWith(signingKey) + .compact() + + val jws = Jwt.parser() + .verifyWith(signingKey) + .build() + .parseSigned(token) + + assertEquals("rsa-key-id", jws.header.keyId) + } + + test("sign Es256 header kid included") { + val signingKey = es256SigningKey(keyId = "ec-key-id") + val token = Jwt.builder() + .subject("test") + .signWith(signingKey) + .compact() + + val jws = Jwt.parser() + .verifyWith(signingKey) + .build() + .parseSigned(token) + + assertEquals("ec-key-id", jws.header.keyId) + } + test("sign Hs256 custom header fields") { - val key = hs256Key() + val signingKey = hs256SigningKey() val token = Jwt.builder() .header { extra("x-custom", JsonPrimitive("custom-value")) } .subject("test") - .signWith(SigningAlgorithm.HS256, key) + .signWith(signingKey) .compact() val headerJson = decodeTokenHeader(token) @@ -192,14 +220,14 @@ class JwsEncodeTest : FunSpec({ context("RSA PKCS1 round-trips") { test("sign Rs256 round trip") { - val keyPair = rsaPkcs1KeyPair() + val keyPair = rs256SigningKey() val token = Jwt.builder() .subject("rs256-subject") - .signWith(SigningAlgorithm.RS256, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.RS256, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -208,14 +236,14 @@ class JwsEncodeTest : FunSpec({ } test("sign Rs384 round trip") { - val keyPair = rsaPkcs1KeyPair(SHA384) + val keyPair = rs384SigningKey() val token = Jwt.builder() .subject("rs384-subject") - .signWith(SigningAlgorithm.RS384, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.RS384, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -223,14 +251,14 @@ class JwsEncodeTest : FunSpec({ } test("sign Rs512 round trip") { - val keyPair = rsaPkcs1KeyPair(SHA512) + val keyPair = rs512SigningKey() val token = Jwt.builder() .subject("rs512-subject") - .signWith(SigningAlgorithm.RS512, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.RS512, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -241,14 +269,14 @@ class JwsEncodeTest : FunSpec({ context("RSA PSS round-trips") { test("sign Ps256 round trip") { - val keyPair = rsaPssKeyPair() + val keyPair = ps256SigningKey() val token = Jwt.builder() .subject("ps256-subject") - .signWith(SigningAlgorithm.PS256, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.PS256, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -257,14 +285,14 @@ class JwsEncodeTest : FunSpec({ } test("sign Ps384 round trip") { - val keyPair = rsaPssKeyPair(SHA384) + val keyPair = ps384SigningKey() val token = Jwt.builder() .subject("ps384-subject") - .signWith(SigningAlgorithm.PS384, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.PS384, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -272,14 +300,14 @@ class JwsEncodeTest : FunSpec({ } test("sign Ps512 round trip") { - val keyPair = rsaPssKeyPair(SHA512) + val keyPair = ps512SigningKey() val token = Jwt.builder() .subject("ps512-subject") - .signWith(SigningAlgorithm.PS512, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.PS512, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -290,14 +318,14 @@ class JwsEncodeTest : FunSpec({ context("ECDSA round-trips") { test("sign Es256 round trip") { - val keyPair = ecKeyPair(EC.Curve.P256) + val keyPair = es256SigningKey() val token = Jwt.builder() .subject("es256-subject") - .signWith(SigningAlgorithm.ES256, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.ES256, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -306,14 +334,14 @@ class JwsEncodeTest : FunSpec({ } test("sign Es384 round trip") { - val keyPair = ecKeyPair(EC.Curve.P384) + val keyPair = es384SigningKey() val token = Jwt.builder() .subject("es384-subject") - .signWith(SigningAlgorithm.ES384, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.ES384, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -321,14 +349,14 @@ class JwsEncodeTest : FunSpec({ } test("sign Es512 round trip") { - val keyPair = ecKeyPair(EC.Curve.P521) + val keyPair = es512SigningKey() val token = Jwt.builder() .subject("es512-subject") - .signWith(SigningAlgorithm.ES512, keyPair.privateKey) + .signWith(keyPair) .compact() val jws = Jwt.parser() - .verifyWith(SigningAlgorithm.ES512, keyPair.publicKey) + .verifyWith(keyPair) .build() .parseSigned(token) @@ -337,10 +365,10 @@ class JwsEncodeTest : FunSpec({ test("sign Es256 signature is raw format") { // ES256 RAW signature = R‖S, each 32 bytes for P-256 → 64 bytes total - val keyPair = ecKeyPair(EC.Curve.P256) + val keyPair = es256SigningKey() val token = Jwt.builder() .subject("test") - .signWith(SigningAlgorithm.ES256, keyPair.privateKey) + .signWith(keyPair) .compact() val signatureB64 = token.split('.')[2] @@ -354,7 +382,7 @@ class JwsEncodeTest : FunSpec({ test("sign none produces empty signature part") { val token = Jwt.builder() .subject("test") - .signWith(SigningAlgorithm.None) + .build() .compact() val parts = token.split('.') @@ -364,6 +392,26 @@ class JwsEncodeTest : FunSpec({ } } + context("key capability checks") { + + test("signWith SigningOnlyKey succeeds") { + val keyPair = hs256SigningKey() + val signingOnlyKey = SigningKey.SigningOnlyKey(keyPair.identifier, keyPair.privateKey) + + Jwt.builder() + .subject("test") + .signWith(signingOnlyKey) + } + + test("signWith SigningKeyPair succeeds") { + val keyPair = hs256SigningKey() + + Jwt.builder() + .subject("test") + .signWith(keyPair) + } + } + context("determinism") { test("sign Hs256 two calls produce same token") { @@ -388,4 +436,37 @@ class JwsEncodeTest : FunSpec({ assertNotEquals(t1, t2, "ECDSA signatures should differ across calls due to random nonce") } } + + context("raw key API (backward compat)") { + + test("signWith and verifyWith raw HMAC key") { + val key = hs256Key() + val token = Jwt.builder() + .subject("hs256-compat") + .signWith(SigningAlgorithm.HS256, key) + .compact() + + val jws = Jwt.parser() + .verifyWith(SigningAlgorithm.HS256, key) + .build() + .parseSigned(token) + + assertEquals("hs256-compat", jws.payload.subjectOrNull) + } + + test("signWith and verifyWith raw RSA key pair") { + val keyPair = rsaPkcs1KeyPair() + val token = Jwt.builder() + .subject("rs256-compat") + .signWith(SigningAlgorithm.RS256, keyPair.privateKey) + .compact() + + val jws = Jwt.parser() + .verifyWith(SigningAlgorithm.RS256, keyPair.publicKey) + .build() + .parseSigned(token) + + assertEquals("rs256-compat", jws.payload.subjectOrNull) + } + } }) diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwtKeyRegistryTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwtKeyRegistryTest.kt new file mode 100644 index 0000000..e203754 --- /dev/null +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/JwtKeyRegistryTest.kt @@ -0,0 +1,362 @@ +package co.touchlab.kjwt + +import co.touchlab.kjwt.ext.subjectOrNull +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm +import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import co.touchlab.kjwt.model.registry.EncryptionKey +import co.touchlab.kjwt.model.registry.JwtKeyRegistry +import co.touchlab.kjwt.model.registry.SigningKey +import io.kotest.core.spec.style.FunSpec +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class JwtKeyRegistryTest : FunSpec({ + + context("sign using registry") { + + test("sign HS256 using registry signing key") { + val key = hs256Key() + val registry = JwtKeyRegistry() + registry.registerSigningKey( + SigningKey.SigningOnlyKey( + identifier = SigningKey.Identifier(SigningAlgorithm.HS256, null), + privateKey = key, + ) + ) + + val token = Jwt.builder() + .subject("user") + .signWith(SigningAlgorithm.HS256, registry) + .compact() + + val jws = Jwt.parser() + .verifyWith(SigningAlgorithm.HS256, key) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("sign HS256 using registry with kid") { + val key = hs256Key() + val registry = JwtKeyRegistry() + registry.registerSigningKey( + SigningKey.SigningOnlyKey( + identifier = SigningKey.Identifier(SigningAlgorithm.HS256, "sign-key"), + privateKey = key, + ) + ) + + val token = Jwt.builder() + .subject("user") + .signWith(SigningAlgorithm.HS256, registry, "sign-key") + .compact() + + val jws = Jwt.parser() + .verifyWith(SigningAlgorithm.HS256, key, "sign-key") + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("sign throws when no matching key in registry") { + val registry = JwtKeyRegistry() + + assertFailsWith { + Jwt.builder() + .subject("user") + .signWith(SigningAlgorithm.HS256, registry) + } + } + } + + context("verify using useKeysFrom") { + + test("parser delegates verification to shared registry") { + val key = hs256Key() + val sharedRegistry = JwtKeyRegistry() + sharedRegistry.registerSigningKey( + SigningKey.VerifyOnlyKey( + identifier = SigningKey.Identifier(SigningAlgorithm.HS256, null), + publicKey = key, + ) + ) + + val token = Jwt.builder() + .subject("user") + .signWith(SigningAlgorithm.HS256, key) + .compact() + + val jws = Jwt.parser() + .useKeysFrom(sharedRegistry) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("parser delegates verification to shared registry with kid") { + val key = hs256Key() + val sharedRegistry = JwtKeyRegistry() + sharedRegistry.registerSigningKey( + SigningKey.VerifyOnlyKey( + identifier = SigningKey.Identifier(SigningAlgorithm.HS256, "k1"), + publicKey = key, + ) + ) + + val token = Jwt.builder() + .subject("user") + .signWith(SigningAlgorithm.HS256, key, "k1") + .compact() + + val jws = Jwt.parser() + .useKeysFrom(sharedRegistry) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("local parser keys take priority over shared registry") { + val key = hs256Key() + val wrongKey = hmacKey( + dev.whyoleg.cryptography.algorithms.SHA256, + "wrong-secret-at-least-256-bits-long-padding".encodeToByteArray(), + ) + val sharedRegistry = JwtKeyRegistry() + sharedRegistry.registerSigningKey( + SigningKey.VerifyOnlyKey( + identifier = SigningKey.Identifier(SigningAlgorithm.HS256, null), + publicKey = wrongKey, + ) + ) + + val token = Jwt.builder() + .subject("user") + .signWith(SigningAlgorithm.HS256, key) + .compact() + + // Local key (correct) takes precedence over the shared registry (wrong key) + val jws = Jwt.parser() + .verifyWith(SigningAlgorithm.HS256, key) + .useKeysFrom(sharedRegistry) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + } + + context("encrypt using registry") { + + test("encrypt Dir A256GCM using registry encryption key") { + val cek = aesSimpleKey(256) + val registry = JwtKeyRegistry() + registry.registerEncryptionKey( + EncryptionKey.EncryptionOnlyKey( + identifier = EncryptionKey.Identifier(EncryptionAlgorithm.Dir, null), + publicKey = cek, + ) + ) + + val token = Jwt.builder() + .subject("user") + .encryptWith(registry, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM) + .compact() + + val jwe = Jwt.parser() + .decryptWith(EncryptionAlgorithm.Dir, cek) + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + + test("encrypt Dir A256GCM using registry with kid") { + val cek = aesSimpleKey(256) + val registry = JwtKeyRegistry() + registry.registerEncryptionKey( + EncryptionKey.EncryptionOnlyKey( + identifier = EncryptionKey.Identifier(EncryptionAlgorithm.Dir, "enc-k1"), + publicKey = cek, + ) + ) + + val token = Jwt.builder() + .subject("user") + .encryptWith(registry, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM, "enc-k1") + .compact() + + val jwe = Jwt.parser() + .decryptWith(EncryptionAlgorithm.Dir, cek, "enc-k1") + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + + test("encrypt throws when no matching key in registry") { + val registry = JwtKeyRegistry() + + assertFailsWith { + Jwt.builder() + .subject("user") + .encryptWith(registry, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM) + } + } + } + + context("decrypt using useKeysFrom") { + + test("parser delegates decryption to shared registry") { + val cek = aesSimpleKey(256) + val sharedRegistry = JwtKeyRegistry() + sharedRegistry.registerEncryptionKey( + EncryptionKey.DecryptionOnlyKey( + identifier = EncryptionKey.Identifier(EncryptionAlgorithm.Dir, null), + privateKey = cek, + ) + ) + + val token = Jwt.builder() + .subject("user") + .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM) + .compact() + + val jwe = Jwt.parser() + .useKeysFrom(sharedRegistry) + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + + test("parser delegates decryption to shared registry with kid") { + val cek = aesSimpleKey(256) + val sharedRegistry = JwtKeyRegistry() + sharedRegistry.registerEncryptionKey( + EncryptionKey.DecryptionOnlyKey( + identifier = EncryptionKey.Identifier(EncryptionAlgorithm.Dir, "enc-k1"), + privateKey = cek, + ) + ) + + val token = Jwt.builder() + .subject("user") + .encryptWith(cek, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM, "enc-k1") + .compact() + + val jwe = Jwt.parser() + .useKeysFrom(sharedRegistry) + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + } + + context("delegation cycle detection") { + + test("self-delegation throws") { + val registry = JwtKeyRegistry() + assertFailsWith { + registry.delegateTo(registry) + } + } + + test("direct cycle throws") { + val a = JwtKeyRegistry() + val b = JwtKeyRegistry() + a.delegateTo(b) + assertFailsWith { + b.delegateTo(a) + } + } + + test("transitive cycle throws") { + val a = JwtKeyRegistry() + val b = JwtKeyRegistry() + val c = JwtKeyRegistry() + a.delegateTo(b) + b.delegateTo(c) + assertFailsWith { + c.delegateTo(a) + } + } + + test("linear chain without cycle is allowed") { + val a = JwtKeyRegistry() + val b = JwtKeyRegistry() + val c = JwtKeyRegistry() + a.delegateTo(b) + b.delegateTo(c) + // no exception expected + } + } + + context("full round-trip via merged registry") { + + test("sign and verify using a registry with merged SigningKeyPair") { + val key = hs256Key() + val sharedRegistry = JwtKeyRegistry() + // Registering complementary keys merges them into a SigningKeyPair + sharedRegistry.registerSigningKey( + SigningKey.SigningOnlyKey( + identifier = SigningKey.Identifier(SigningAlgorithm.HS256, "k1"), + privateKey = key, + ) + ) + sharedRegistry.registerSigningKey( + SigningKey.VerifyOnlyKey( + identifier = SigningKey.Identifier(SigningAlgorithm.HS256, "k1"), + publicKey = key, + ) + ) + + val token = Jwt.builder() + .subject("registry-user") + .signWith(SigningAlgorithm.HS256, sharedRegistry, "k1") + .compact() + + val jws = Jwt.parser() + .useKeysFrom(sharedRegistry) + .build() + .parseSigned(token) + + assertEquals("registry-user", jws.payload.subjectOrNull) + } + + test("encrypt and decrypt using a registry with merged EncryptionKeyPair") { + val cek = aesSimpleKey(256) + val sharedRegistry = JwtKeyRegistry() + // Register DecryptionOnlyKey first, then EncryptionOnlyKey — merges into EncryptionKeyPair + sharedRegistry.registerEncryptionKey( + EncryptionKey.DecryptionOnlyKey( + identifier = EncryptionKey.Identifier(EncryptionAlgorithm.Dir, "enc-k1"), + privateKey = cek, + ) + ) + sharedRegistry.registerEncryptionKey( + EncryptionKey.EncryptionOnlyKey( + identifier = EncryptionKey.Identifier(EncryptionAlgorithm.Dir, "enc-k1"), + publicKey = cek, + ) + ) + + val token = Jwt.builder() + .subject("encrypted-user") + .encryptWith(sharedRegistry, EncryptionAlgorithm.Dir, EncryptionContentAlgorithm.A256GCM, "enc-k1") + .compact() + + val jwe = Jwt.parser() + .useKeysFrom(sharedRegistry) + .build() + .parseEncrypted(token) + + assertEquals("encrypted-user", jwe.payload.subjectOrNull) + } + } +}) \ No newline at end of file diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/MalformedTokenTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/MalformedTokenTest.kt index 9269856..e0b3b62 100644 --- a/lib/src/commonTest/kotlin/co/touchlab/kjwt/MalformedTokenTest.kt +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/MalformedTokenTest.kt @@ -163,7 +163,7 @@ class MalformedTokenTest : FunSpec({ test("parse none without allow unsecured throws UnsupportedJwtException") { val noneToken = Jwt.builder() .subject("user") - .signWith(SigningAlgorithm.None) + .build() .compact() assertFailsWith { diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/MultiKeyParserTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/MultiKeyParserTest.kt new file mode 100644 index 0000000..9124028 --- /dev/null +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/MultiKeyParserTest.kt @@ -0,0 +1,255 @@ +package co.touchlab.kjwt + +import co.touchlab.kjwt.ext.key +import co.touchlab.kjwt.ext.parse +import co.touchlab.kjwt.ext.subjectOrNull +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm +import co.touchlab.kjwt.model.algorithm.EncryptionContentAlgorithm +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import io.kotest.core.spec.style.FunSpec +import kotlin.random.Random +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class MultiKeyParserTest : FunSpec({ + + context("multi-key JWS verification") { + + test("exact match — token kid matches registered kid") { + val signingKey = hs256SigningKey(keyId = "k1") + val token = Jwt.builder() + .subject("user") + .signWith(signingKey) + .compact() + + val jws = Jwt.parser() + .verifyWith(signingKey) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("fallback — token has kid but parser has algo-only key") { + val signingKey = hs256SigningKey(keyId = "k1") + val token = Jwt.builder() + .subject("user") + .signWith(signingKey) + .compact() + + // No keyId on verifyWith → algo-only key, matches any kid for HS256 + val fallbackKey = hs256SigningKey() + val jws = Jwt.parser() + .verifyWith(fallbackKey) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("no kid token — parser has algo-only key") { + val signingKey = hs256SigningKey() + val token = Jwt.builder() + .subject("user") + .signWith(signingKey) + .compact() + + val jws = Jwt.parser() + .verifyWith(signingKey) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("multiple keys same algo — correct kid selected") { + val key1 = SigningAlgorithm.HS256.parse( + "secret-for-k1-at-least-256-bits-long".encodeToByteArray(), keyId = "k1" + ) + val key2 = SigningAlgorithm.HS256.parse( + "secret-for-k2-at-least-256-bits-long".encodeToByteArray(), keyId = "k2" + ) + + val token = Jwt.builder() + .subject("user") + .signWith(key2) + .compact() + + val jws = Jwt.parser() + .verifyWith(key1) + .verifyWith(key2) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("multiple keys same algo — unmatched kid falls back to algo-only key") { + val keyFallback = hs256SigningKey() // algo-only, same secret as hs256Secret + val key1 = SigningAlgorithm.HS256.parse( + "secret-for-k1-at-least-256-bits-long".encodeToByteArray(), keyId = "k1" + ) + + val token = Jwt.builder() + .subject("user") + .signWith(hs256SigningKey(keyId = "k-unknown")) + .compact() + + // kid="k-unknown" doesn't match "k1", so falls back to algo-only key + val jws = Jwt.parser() + .verifyWith(key1) + .verifyWith(keyFallback) + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("no key found — token with kid, parser has only different kid") { + val signingKey = hs256SigningKey(keyId = "k1") + val token = Jwt.builder() + .subject("user") + .signWith(signingKey) + .compact() + + assertFailsWith { + Jwt.parser() + .verifyWith(hs256SigningKey(keyId = "k2")) + .build() + .parseSigned(token) + } + } + + test("no key found — token has no kid, parser has only keyed entries") { + val signingKey = hs256SigningKey() + val token = Jwt.builder() + .subject("user") + .signWith(signingKey) + .compact() + + // Token has no kid → only (HS256, null) is looked up, but parser only has (HS256, "k1") + assertFailsWith { + Jwt.parser() + .verifyWith(hs256SigningKey(keyId = "k1")) + .build() + .parseSigned(token) + } + } + + test("none algorithm fallback — no matching key, noVerify registered") { + val key = hs256SigningKey(keyId = "k-unknown") + val key1 = SigningAlgorithm.HS256.parse( + "secret-for-k1-at-least-256-bits-long".encodeToByteArray(), keyId = "k1" + ) + + val token = Jwt.builder() + .subject("user") + .signWith(key) + .compact() + + // kid="k-unknown" matches neither "k1" nor algo-only (none registered); + // falls back to the None verifier registered by noVerify() + val jws = Jwt.parser() + .verifyWith(key1) + .noVerify() + .build() + .parseSigned(token) + + assertEquals("user", jws.payload.subjectOrNull) + } + + test("duplicate signing key registration throws") { + assertFailsWith { + Jwt.parser() + .verifyWith(hs256SigningKey(keyId = "k1")) + .verifyWith(hs256SigningKey(keyId = "k1")) + } + } + + test("duplicate algo-only signing key registration throws") { + assertFailsWith { + Jwt.parser() + .verifyWith(hs256SigningKey()) + .verifyWith(hs256SigningKey()) + } + } + } + + context("multi-key JWE decryption") { + + test("exact match — token kid matches registered kid") { + val encKey = dirEncKey(256, keyId = "k1") + val token = Jwt.builder() + .subject("user") + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) + .compact() + + val jwe = Jwt.parser() + .decryptWith(encKey) + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + + test("fallback — token has kid but parser has algo-only decryptor") { + val bytes = Random.Default.nextBytes(32) + val encKey = EncryptionAlgorithm.Dir.key(bytes, "k1") // token encrypted with kid="k1" + val fallbackKey = EncryptionAlgorithm.Dir.key(bytes) // same bytes, algo-only (no kid) + + val token = Jwt.builder() + .subject("user") + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) + .compact() + + // No keyId on decryptWith → algo-only, matches any kid for Dir + val jwe = Jwt.parser() + .decryptWith(fallbackKey) + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + + test("no kid token — parser has algo-only decryptor") { + val encKey = dirEncKey(256) + val token = Jwt.builder() + .subject("user") + .encryptWith(encKey, EncryptionContentAlgorithm.A256GCM) + .compact() + + val jwe = Jwt.parser() + .decryptWith(encKey) + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + + test("multiple keys same algo — correct kid selected") { + val encKey1 = dirEncKey(256, keyId = "k1") + val encKey2 = dirEncKey(256, keyId = "k2") + + val token = Jwt.builder() + .subject("user") + .encryptWith(encKey2, EncryptionContentAlgorithm.A256GCM) + .compact() + + val jwe = Jwt.parser() + .decryptWith(encKey1) + .decryptWith(encKey2) + .build() + .parseEncrypted(token) + + assertEquals("user", jwe.payload.subjectOrNull) + } + + test("duplicate decryption key registration throws") { + assertFailsWith { + Jwt.parser() + .decryptWith(dirEncKey(256, keyId = "k1")) + .decryptWith(dirEncKey(256, keyId = "k1")) + } + } + } +}) diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/TestFixtures.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/TestFixtures.kt index 56c218d..e4d56cc 100644 --- a/lib/src/commonTest/kotlin/co/touchlab/kjwt/TestFixtures.kt +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/TestFixtures.kt @@ -1,17 +1,24 @@ package co.touchlab.kjwt import co.touchlab.kjwt.cryptography.SimpleKey +import co.touchlab.kjwt.ext.key +import co.touchlab.kjwt.ext.newKey +import co.touchlab.kjwt.ext.parse +import co.touchlab.kjwt.model.algorithm.EncryptionAlgorithm +import co.touchlab.kjwt.model.algorithm.SigningAlgorithm +import co.touchlab.kjwt.model.registry.EncryptionKey +import co.touchlab.kjwt.model.registry.SigningKey +import dev.whyoleg.cryptography.CryptographyAlgorithmId import dev.whyoleg.cryptography.CryptographyProvider -import dev.whyoleg.cryptography.algorithms.ECDSA +import dev.whyoleg.cryptography.algorithms.Digest import dev.whyoleg.cryptography.algorithms.EC +import dev.whyoleg.cryptography.algorithms.ECDSA import dev.whyoleg.cryptography.algorithms.HMAC import dev.whyoleg.cryptography.algorithms.RSA import dev.whyoleg.cryptography.algorithms.SHA1 import dev.whyoleg.cryptography.algorithms.SHA256 import dev.whyoleg.cryptography.algorithms.SHA384 import dev.whyoleg.cryptography.algorithms.SHA512 -import dev.whyoleg.cryptography.algorithms.Digest -import dev.whyoleg.cryptography.CryptographyAlgorithmId import co.touchlab.kjwt.exception.JwtException import kotlin.io.encoding.Base64 import kotlin.io.encoding.ExperimentalEncodingApi @@ -76,6 +83,43 @@ fun decodeTokenPayload(token: String): String { return Base64.UrlSafe.decode(padded).decodeToString() } +// ---- SigningKey helpers (library API) ---- + +// HMAC: parse from known test secrets for deterministic use in tests +suspend fun hs256SigningKey(keyId: String? = null): SigningKey.SigningKeyPair = + SigningAlgorithm.HS256.parse(hs256Secret, keyId = keyId) + +suspend fun hs384SigningKey(keyId: String? = null): SigningKey.SigningKeyPair = + SigningAlgorithm.HS384.parse(hs384Secret, keyId = keyId) + +suspend fun hs512SigningKey(keyId: String? = null): SigningKey.SigningKeyPair = + SigningAlgorithm.HS512.parse(hs512Secret, keyId = keyId) + +// RSA PKCS1 +suspend fun rs256SigningKey(keyId: String? = null) = SigningAlgorithm.RS256.newKey(keyId = keyId) +suspend fun rs384SigningKey(keyId: String? = null) = SigningAlgorithm.RS384.newKey(keyId = keyId) +suspend fun rs512SigningKey(keyId: String? = null) = SigningAlgorithm.RS512.newKey(keyId = keyId) + +// RSA PSS +suspend fun ps256SigningKey(keyId: String? = null) = SigningAlgorithm.PS256.newKey(keyId = keyId) +suspend fun ps384SigningKey(keyId: String? = null) = SigningAlgorithm.PS384.newKey(keyId = keyId) +suspend fun ps512SigningKey(keyId: String? = null) = SigningAlgorithm.PS512.newKey(keyId = keyId) + +// ECDSA +suspend fun es256SigningKey(keyId: String? = null) = SigningAlgorithm.ES256.newKey(keyId = keyId) +suspend fun es384SigningKey(keyId: String? = null) = SigningAlgorithm.ES384.newKey(keyId = keyId) +suspend fun es512SigningKey(keyId: String? = null) = SigningAlgorithm.ES512.newKey(keyId = keyId) + +// ---- EncryptionKey helpers (library API) ---- + +// Dir: wrap random bytes; not suspend (Dir.key is a plain function) +fun dirEncKey(bits: Int, keyId: String? = null): EncryptionKey.EncryptionKeyPair = + EncryptionAlgorithm.Dir.key(Random.Default.nextBytes(bits / 8), keyId) + +// RSA-OAEP +suspend fun rsaOaepEncKey(keyId: String? = null) = EncryptionAlgorithm.RsaOaep.newKey(keyId = keyId) +suspend fun rsaOaep256EncKey(keyId: String? = null) = EncryptionAlgorithm.RsaOaep256.newKey(keyId = keyId) + /** Decodes the header (first) part of a compact JWT and returns it as a JSON string. */ @OptIn(ExperimentalEncodingApi::class) fun decodeTokenHeader(token: String): String { diff --git a/lib/src/commonTest/kotlin/co/touchlab/kjwt/jwk/JwkBuilderExtTest.kt b/lib/src/commonTest/kotlin/co/touchlab/kjwt/jwk/JwkBuilderExtTest.kt index 6f2ffcd..59f7810 100644 --- a/lib/src/commonTest/kotlin/co/touchlab/kjwt/jwk/JwkBuilderExtTest.kt +++ b/lib/src/commonTest/kotlin/co/touchlab/kjwt/jwk/JwkBuilderExtTest.kt @@ -1,6 +1,7 @@ package co.touchlab.kjwt.jwk import co.touchlab.kjwt.Jwt +import co.touchlab.kjwt.decodeTokenHeader import co.touchlab.kjwt.ext.signWith import co.touchlab.kjwt.ext.subjectOrNull import co.touchlab.kjwt.ext.verifyWith @@ -12,6 +13,7 @@ import co.touchlab.kjwt.model.algorithm.SigningAlgorithm import co.touchlab.kjwt.model.jwk.Jwk import io.kotest.core.spec.style.FunSpec import kotlin.test.assertEquals +import kotlin.test.assertTrue /** * Tests the JWK builder/parser extensions end-to-end. @@ -98,4 +100,43 @@ class JwkBuilderExtTest : FunSpec({ assertEquals("jwk-cross-verify", jws.payload.subjectOrNull) } } + + context("keyId propagation") { + + test("sign with HS256 JWK defaults kid to jwk.kid") { + val jwk = Jwk.Oct(k = hs256Secret.encodeBase64Url(), alg = "HS256", kid = "my-hmac-key") + + val token = Jwt.builder() + .subject("test") + .signWith(SigningAlgorithm.HS256, jwk) + .compact() + + val headerJson = decodeTokenHeader(token) + assertTrue(headerJson.contains("\"kid\":\"my-hmac-key\""), "Header must contain jwk kid, got: $headerJson") + } + + test("sign with HS256 JWK explicit keyId overrides jwk.kid") { + val jwk = Jwk.Oct(k = hs256Secret.encodeBase64Url(), alg = "HS256", kid = "jwk-kid") + + val token = Jwt.builder() + .subject("test") + .signWith(SigningAlgorithm.HS256, jwk, keyId = "override-kid") + .compact() + + val headerJson = decodeTokenHeader(token) + assertTrue(headerJson.contains("\"kid\":\"override-kid\""), "Header must contain overridden kid, got: $headerJson") + } + + test("sign with HS256 JWK null keyId omits kid from header") { + val jwk = Jwk.Oct(k = hs256Secret.encodeBase64Url(), alg = "HS256", kid = "some-kid") + + val token = Jwt.builder() + .subject("test") + .signWith(SigningAlgorithm.HS256, jwk, keyId = null) + .compact() + + val headerJson = decodeTokenHeader(token) + assertTrue(!headerJson.contains("\"kid\""), "Header must not contain kid when null, got: $headerJson") + } + } })