@@ -10,7 +10,9 @@ module Share.JWT
1010 ServantAuth. ToJWT (.. ),
1111 ServantAuth. FromJWT (.. ),
1212 signJWT ,
13+ signJWTWithJWK ,
1314 verifyJWT ,
15+ keyDescToJWK ,
1416
1517 -- * Additional Helpers
1618 textToSignedJWT ,
@@ -97,7 +99,8 @@ newtype KeyThumbprint = KeyThumbprint Text
9799
98100data KeyMap = KeyMap
99101 { byKeyId :: (Map KeyThumbprint JWT. JWK ),
100- -- | The key from before the introduction of key ids. This will be used to verify legacy tokens, but can eventually be removed.
102+ -- | The key from before the introduction of key ids.
103+ -- This will be used to verify legacy tokens, but is also used to sign HashJWTs on share.
101104 legacyKey :: Maybe JWT. JWK
102105 }
103106 deriving (Show ) via (Censored KeyMap )
@@ -134,11 +137,11 @@ defaultJWTSettings ::
134137 -- E.g. https://api.unison-lang.org
135138 URI ->
136139 Either CryptoError JWTSettings
137- defaultJWTSettings signingKey legacyKey oldValidKeys acceptedAudiences issuer = toEither do
138- sjwk@ (_, signingJWK) <- toJWK signingKey
139- verificationJWKs <- (sjwk : ) <$> traverse toJWK (Set. toList oldValidKeys)
140+ defaultJWTSettings signingKey legacyKey oldValidKeys acceptedAudiences issuer = do
141+ sjwk@ (_, signingJWK) <- keyDescToJWK signingKey
142+ verificationJWKs <- (sjwk : ) <$> traverse keyDescToJWK (Set. toList oldValidKeys)
140143 let byKeyId = Map. fromList verificationJWKs
141- legacyKey <- traverse toJWK legacyKey <&> fmap snd
144+ legacyKey <- traverse keyDescToJWK legacyKey <&> fmap snd
142145 pure $
143146 JWTSettings
144147 { signingJWK,
@@ -148,31 +151,33 @@ defaultJWTSettings signingKey legacyKey oldValidKeys acceptedAudiences issuer =
148151 acceptedAudiences,
149152 issuer
150153 }
154+
155+ -- | Converts a 'KeyDescription' to a 'JWK' and a 'KeyThumbprint'.
156+ keyDescToJWK :: KeyDescription -> Either CryptoError (KeyThumbprint , JWK. JWK )
157+ keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
158+ case alg of
159+ HS256 -> do
160+ let jwk =
161+ JWK. fromOctets key
162+ & JWK. jwkUse .~ Just JWK. Sig
163+ & JWK. jwkAlg .~ Just (JWK. JWSAlg JWS. HS256 )
164+ let thumbprint = jwkThumbprint jwk
165+ pure (KeyThumbprint thumbprint, jwk & JWK. jwkKid .~ Just thumbprint)
166+ Ed25519 -> do
167+ privKey <- Ed25519. secretKey key
168+ let pubKey = Ed25519. toPublic privKey
169+ let jwk =
170+ (JWT. Ed25519Key pubKey (Just privKey))
171+ & JWT. OKPKeyMaterial
172+ & JWK. fromKeyMaterial
173+ & JWK. jwkUse .~ Just JWK. Sig
174+ & JWK. jwkAlg .~ Just (JWK. JWSAlg JWS. EdDSA )
175+ let thumbprint = jwkThumbprint jwk
176+ pure (KeyThumbprint thumbprint, jwk & JWK. jwkKid .~ Just thumbprint)
151177 where
152- toEither :: CryptoFailable a -> Either CryptoError a
153- toEither (CryptoFailed err) = Left err
154- toEither (CryptoPassed a) = Right a
155- toJWK :: KeyDescription -> CryptoFailable (KeyThumbprint , JWK. JWK )
156- toJWK (KeyDescription {key, alg}) =
157- case alg of
158- HS256 -> do
159- let jwk =
160- JWK. fromOctets key
161- & JWK. jwkUse .~ Just JWK. Sig
162- & JWK. jwkAlg .~ Just (JWK. JWSAlg JWS. HS256 )
163- let thumbprint = jwkThumbprint jwk
164- pure (KeyThumbprint thumbprint, jwk & JWK. jwkKid .~ Just thumbprint)
165- Ed25519 -> do
166- privKey <- Ed25519. secretKey key
167- let pubKey = Ed25519. toPublic privKey
168- let jwk =
169- (JWT. Ed25519Key pubKey (Just privKey))
170- & JWT. OKPKeyMaterial
171- & JWK. fromKeyMaterial
172- & JWK. jwkUse .~ Just JWK. Sig
173- & JWK. jwkAlg .~ Just (JWK. JWSAlg JWS. EdDSA )
174- let thumbprint = jwkThumbprint jwk
175- pure (KeyThumbprint thumbprint, jwk & JWK. jwkKid .~ Just thumbprint)
178+ cryptoFailableToEither :: CryptoFailable a -> Either CryptoError a
179+ cryptoFailableToEither (CryptoFailed err) = Left err
180+ cryptoFailableToEither (CryptoPassed a) = Right a
176181
177182 jwkThumbprint :: JWK. JWK -> Text
178183 jwkThumbprint jwk =
@@ -216,10 +221,15 @@ textToSignedJWT jwtText = JWT.decodeCompact (TL.encodeUtf8 . TL.fromStrict $ jwt
216221
217222-- | Signs and encodes a JWT using the given 'JWTSettings'.
218223signJWT :: forall m v . (MonadIO m , ServantAuth. ToJWT v ) => JWTSettings -> v -> m (Either JWT. JWTError JWT. SignedJWT )
219- signJWT JWTSettings {signingJWK} v = runExceptT $ do
224+ signJWT JWTSettings {signingJWK} v = signJWTWithJWK signingJWK v
225+
226+ -- | Signs and encodes a JWT using the given JWK, you should typically use 'signJWT' instead
227+ -- unless you have a specific reason to use a different JWK.
228+ signJWTWithJWK :: forall m v . (MonadIO m , ServantAuth. ToJWT v ) => JWK. JWK -> v -> m (Either JWT. JWTError JWT. SignedJWT )
229+ signJWTWithJWK jwk v = runExceptT $ do
220230 let claimsSet = ServantAuth. encodeJWT v
221- jwtHeader <- mapExceptT liftIO (JWT. makeJWSHeader signingJWK )
222- mapExceptT liftIO (JWT. signClaims signingJWK jwtHeader claimsSet)
231+ jwtHeader <- mapExceptT liftIO (JWT. makeJWSHeader jwk )
232+ mapExceptT liftIO (JWT. signClaims jwk jwtHeader claimsSet)
223233
224234-- | Decodes a JWT and verifies the following:
225235-- * algorithm (except for legacy tokens)
0 commit comments