@@ -29,17 +29,22 @@ module Share.JWT
2929
3030 -- * Utilities
3131 JWTParam (.. ),
32+ Issuer (.. ),
33+ Audience (.. ),
3234 textToSignedJWT ,
3335 signedJWTToText ,
3436 createSignedCookie ,
3537
3638 -- * Re-exports
3739 CryptoError (.. ),
40+ JWK. JWK ,
41+ JWK. JWKSet ,
3842 )
3943where
4044
4145import Control.Lens
4246import Control.Monad.Except
47+ import Control.Monad.Trans.Except (except )
4348import Crypto.Error (CryptoError (.. ), CryptoFailable (.. ))
4449import Crypto.JOSE.JWA.JWS qualified as JWS
4550import Crypto.JOSE.JWK qualified as JWK
@@ -50,32 +55,41 @@ import Data.Aeson qualified as Aeson
5055import Data.ByteArray qualified as ByteArray
5156import Data.ByteString qualified as BS
5257import Data.ByteString.Base64.URL qualified as Base64URL
58+ import Data.Map (Map )
5359import Data.Map qualified as Map
5460import Data.Set (Set )
5561import Data.Set qualified as Set
5662import Data.Text (Text )
5763import Data.Text qualified as Text
5864import Data.Text.Encoding qualified as Text
65+ import Data.Traversable (for )
5966import Servant
6067import Share.JWT.Types
6168import Share.OAuth.Orphans ()
6269import Share.Utils.Servant.Cookies qualified as Cookies
6370import UnliftIO (MonadIO (.. ))
71+ import UnliftIO.STM
6472
65- -- | Get the JWK Set value which is safe to expose to the public, e.g. in a JWKS endpoint.
73+ -- | Get the JWK Set for an issuer which is safe to expose to the public, e.g. in a JWKS endpoint.
6674-- This will only include public keys.
6775--
6876-- Note that this will not include the legacy key or any HS256 keys, since those don't have a
6977-- safe public component.
70- publicJWKSet :: JWTSettings -> JWK. JWKSet
71- publicJWKSet JWTSettings {validationKeys = KeyMap {byKeyId}} =
72- JWK. JWKSet
73- ( byKeyId
74- & foldMap (\ jwk -> jwk ^.. JWK. asPublicKey . _Just)
75- )
78+ publicJWKSet :: (MonadIO m ) => JWTSettings -> Issuer -> m JWK. JWKSet
79+ publicJWKSet JWTSettings {validationKeys = KeyMap {keysVar}} issuer = do
80+ keyMap <- liftIO $ readTVarIO keysVar
81+ pure $
82+ JWK. JWKSet
83+ ( keyMap
84+ & Map. lookup issuer
85+ & foldMap (\ jwk -> jwk ^.. folded . JWK. asPublicKey . _Just)
86+ )
7687
7788-- | Create a 'JWTSettings' using the required information.
7889defaultJWTSettings ::
90+ (MonadIO m ) =>
91+ -- | Which issuer is the current service
92+ Issuer ->
7993 -- | The key used to sign JWTs.
8094 KeyDescription ->
8195 -- | The legacy key used to verify old JWTs from before key IDs were used. This will be used to verify tokens that don't have a key id.
@@ -86,25 +100,42 @@ defaultJWTSettings ::
86100 -- Tokens must have an audience which is present in this set.
87101 --
88102 -- E.g. https://api.unison.cloud
89- Set URI ->
103+ Set Audience ->
90104 -- | Valid issuers when validating tokens
91- Set URI ->
92- Either CryptoError JWTSettings
93- defaultJWTSettings signingKey legacyKey oldValidKeys acceptedAudiences acceptedIssuers = do
94- sjwk@ (_, signingJWK) <- keyDescToJWK signingKey
95- verificationJWKs <- (sjwk : ) <$> traverse keyDescToJWK (Set. toList oldValidKeys)
96- let byKeyId = Map. fromList verificationJWKs
97- legacyKey <- traverse keyDescToJWK legacyKey <&> fmap snd
105+ Set Issuer ->
106+ -- | Mapping of issuers to either their:
107+ -- * JWK json endpoint.
108+ -- * JWK set.
109+ --
110+ -- If a JWK URI is provided for an issuer it will be fetched and kept up to date as needed.
111+ Map Issuer (Either URI JWT. JWKSet ) ->
112+ m (Either CryptoError JWTSettings )
113+ defaultJWTSettings myIssuer signingKey legacyKey oldValidKeys acceptedAudiences acceptedIssuers externalJWKsMap = runExceptT $ do
114+ signingJWK <- except $ keyDescToJWK signingKey
115+ myVerificationJWKs <- (signingJWK : ) <$> traverse (except . keyDescToJWK) (Set. toList oldValidKeys)
116+ let (externalJWKs, externalJWKLocations) =
117+ externalJWKsMap
118+ & foldMap \ case
119+ Left uri -> (mempty , Map. singleton myIssuer uri)
120+ Right (JWT. JWKSet jwks) -> (Map. singleton myIssuer jwks, mempty )
121+
122+ let myJWKs = Map. singleton myIssuer myVerificationJWKs
123+ let keysMap = myJWKs <> externalJWKs
124+ keysVar <- liftIO $ newTVarIO keysMap
125+ lastCheckedVar <- liftIO $ newTVarIO Map. empty
126+ legacyKey <- for legacyKey (except . keyDescToJWK)
98127 pure $
99128 JWTSettings
100129 { signingJWK,
101- validationKeys = KeyMap {byKeyId, legacyKey},
130+ validationKeys = KeyMap {keysVar, legacyKey},
131+ externalJWKLocations,
132+ lastCheckedVar,
102133 acceptedAudiences,
103134 acceptedIssuers
104135 }
105136
106137-- | Converts a 'KeyDescription' to a 'JWK' and a 'KeyThumbprint'.
107- keyDescToJWK :: KeyDescription -> Either CryptoError ( KeyThumbprint , JWK. JWK)
138+ keyDescToJWK :: KeyDescription -> Either CryptoError JWK. JWK
108139keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
109140 case alg of
110141 HS256 -> do
@@ -113,7 +144,7 @@ keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
113144 & JWK. jwkUse .~ Just JWK. Sig
114145 & JWK. jwkAlg .~ Just (JWK. JWSAlg JWS. HS256 )
115146 let thumbprint = jwkThumbprint jwk
116- pure (KeyThumbprint thumbprint, jwk & JWK. jwkKid .~ Just thumbprint)
147+ pure (jwk & JWK. jwkKid .~ Just thumbprint)
117148 Ed25519 -> do
118149 privKey <- Ed25519. secretKey key
119150 let pubKey = Ed25519. toPublic privKey
@@ -124,7 +155,7 @@ keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
124155 & JWK. jwkUse .~ Just JWK. Sig
125156 & JWK. jwkAlg .~ Just (JWK. JWSAlg JWS. EdDSA )
126157 let thumbprint = jwkThumbprint jwk
127- pure (KeyThumbprint thumbprint, jwk & JWK. jwkKid .~ Just thumbprint)
158+ pure (jwk & JWK. jwkKid .~ Just thumbprint)
128159 where
129160 cryptoFailableToEither :: CryptoFailable a -> Either CryptoError a
130161 cryptoFailableToEither (CryptoFailed err) = Left err
@@ -160,8 +191,8 @@ signJWTWithJWK jwk v = runExceptT $ do
160191--
161192-- Any other checks should be performed on the returned claims.
162193verifyJWT :: forall claims m . (AsJWTClaims claims , MonadIO m ) => JWTSettings -> JWT. SignedJWT -> m (Either JWT. JWTError claims )
163- verifyJWT JWTSettings {validationKeys, acceptedAudiences, acceptedIssuers} signedJWT = runExceptT do
164- jwtClaimsMap <- ExceptT . liftIO . runExceptT $ JWT. verifyJWT validators validationKeys signedJWT
194+ verifyJWT jwtSettings @ JWTSettings {acceptedAudiences, acceptedIssuers} signedJWT = runExceptT do
195+ jwtClaimsMap <- ExceptT . liftIO . runExceptT $ JWT. verifyJWT validators jwtSettings signedJWT
165196 case fromClaims jwtClaimsMap of
166197 Left err -> throwError $ JWT. JWTClaimsSetDecodeError (Text. unpack err)
167198 Right claims -> pure claims
@@ -170,12 +201,12 @@ verifyJWT JWTSettings {validationKeys, acceptedAudiences, acceptedIssuers} signe
170201 auds =
171202 -- Annoyingly StringOrURI doesn't have an ord instance.
172203 Set. toList acceptedAudiences
173- & map (review CryptoJWT. uri)
204+ & map (\ ( Audience aud) -> review CryptoJWT. uri aud )
174205 issuers :: [CryptoJWT. StringOrURI ]
175206 issuers =
176207 -- Annoyingly StringOrURI doesn't have an ord instance.
177208 Set. toList acceptedIssuers
178- & map (review CryptoJWT. uri)
209+ & map (\ ( Issuer iss) -> review CryptoJWT. uri iss )
179210 validators =
180211 CryptoJWT. defaultJWTValidationSettings (`elem` auds)
181212 & CryptoJWT. issuerPredicate .~ (`elem` issuers)
0 commit comments