Skip to content

Commit 65aa2ea

Browse files
⅄ main → narrow
2 parents 4c79d57 + af15e26 commit 65aa2ea

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1767
-191
lines changed

app/Env.hs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Data.ByteString.Char8 qualified as BS
1010
import Data.Char
1111
import Data.Char qualified as Char
1212
import Data.Either.Combinators
13+
import Data.Map qualified as Map
1314
import Data.Functor
1415
import Data.HashMap.Strict qualified as HM
1516
import Data.Set qualified as Set
@@ -73,6 +74,8 @@ withEnv action = do
7374
maxParallelismPerDownloadRequest <- fromEnv "SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST" (pure . maybeToEither "Invalid SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST" . readMaybe)
7475
maxParallelismPerUploadRequest <- fromEnv "SHARE_MAX_PARALLELISM_PER_UPLOAD_REQUEST" (pure . maybeToEither "Invalid SHARE_MAX_PARALLELISM_PER_UPLOAD_REQUEST" . readMaybe)
7576
cloudWebsiteOrigin <- fromEnv "SHARE_CLOUD_HOMEPAGE_ORIGIN" (pure . maybeToEither "Invalid SHARE_CLOUD_HOMEPAGE_ORIGIN" . parseURI)
77+
cloudAPIOrigin <- fromEnv "SHARE_CLOUD_API_ORIGIN" (pure . maybeToEither "Invalid SHARE_CLOUD_API_ORIGIN" . parseURI)
78+
cloudAPIJWKEndpoint <- fromEnv "SHARE_CLOUD_API_JWKS_ENDPOINT" (pure . maybeToEither "Invalid SHARE_CLOUD_API_JWKS_ENDPOINT" . parseURI)
7679

7780
sentryService <-
7881
lookupEnv "SHARE_SENTRY_DSN" >>= \case
@@ -90,18 +93,23 @@ withEnv action = do
9093
| Deployment.onLocal = Nothing
9194
| otherwise = Nothing
9295
in r {Redis.connectTLSParams = tlsParams}
93-
let acceptedAudiences = Set.singleton apiOrigin
94-
let acceptedIssuers = Set.singleton apiOrigin
96+
let shareAudience = JWT.Audience apiOrigin
97+
let shareIssuer = JWT.Issuer apiOrigin
98+
let cloudIssuer = JWT.Issuer cloudAPIOrigin
99+
let acceptedAudiences = Set.singleton $ shareAudience
100+
let acceptedIssuers = Set.fromList [shareIssuer, cloudIssuer]
95101
let legacyKey = JWT.KeyDescription {JWT.key = hs256Key, JWT.alg = JWT.HS256}
96102
let signingKey = JWT.KeyDescription {JWT.key = edDSAKey, JWT.alg = JWT.Ed25519}
103+
let externalJWKs = Map.fromList [ (cloudIssuer, Left cloudAPIJWKEndpoint)
104+
]
97105
hashJWTJWK <- case JWT.keyDescToJWK legacyKey of
98106
Left err -> throwIO err
99-
Right (_thumbprint, jwk) -> pure jwk
107+
Right jwk -> pure jwk
100108
-- I explicitly add the legacy key to the validation keys, so that the thumbprinted
101109
-- version of the key is used for validation, which is needed for HashJWTs which are signed
102110
-- with a 'kid'.
103111
let validationKeys = Set.fromList [legacyKey]
104-
jwtSettings <- case JWT.defaultJWTSettings signingKey (Just legacyKey) validationKeys acceptedAudiences acceptedIssuers of
112+
jwtSettings <- JWT.defaultJWTSettings shareIssuer signingKey (Just legacyKey) validationKeys acceptedAudiences acceptedIssuers externalJWKs >>= \case
105113
Left cryptoError -> throwIO cryptoError
106114
Right settings -> pure settings
107115
let cookieSettings = Cookies.defaultCookieSettings Deployment.onLocal (Just (realToFrac cookieSessionTTL))

docker/docker-compose.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ services:
8686
- SHARE_CLOUD_UI_ORIGIN=http://localhost:5678
8787
- SHARE_HOMEPAGE_ORIGIN=http://localhost:1111
8888
- SHARE_CLOUD_HOMEPAGE_ORIGIN=http://localhost:2222
89+
- SHARE_CLOUD_API_ORIGIN=http://localhost:3333
90+
- SHARE_CLOUD_API_JWKS_ENDPOINT=http://localhost:3333/.well-known/jwks.json
8991
- SHARE_LOG_LEVEL=DEBUG
9092
- SHARE_COMMIT=dev
9193
- SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST=1

local.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ export SHARE_SHARE_UI_ORIGIN="http://localhost:1234"
1515
export SHARE_CLOUD_UI_ORIGIN="http://localhost:5678"
1616
export SHARE_HOMEPAGE_ORIGIN="http://localhost:1111"
1717
export SHARE_CLOUD_HOMEPAGE_ORIGIN="http://localhost:2222"
18+
export SHARE_CLOUD_API_ORIGIN="http://localhost:3333"
19+
export SHARE_CLOUD_API_JWKS_ENDPOINT="http://localhost:3333/.well-known/jwks.json"
1820
export SHARE_LOG_LEVEL="DEBUG"
1921
export SHARE_COMMIT="dev"
2022
export SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST="1"

share-api.cabal

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ library
118118
Share.Ticket
119119
Share.User
120120
Share.UserProfile
121-
Share.Utils.API
122121
Share.Utils.Caching
123122
Share.Utils.Caching.JSON
124123
Share.Utils.Data

share-auth/example/package.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@ dependencies:
2222
- containers
2323
- share-auth
2424
- share-utils
25+
- jose
26+
- aeson
2527
- hedis
2628
- network-uri
29+
- raw-strings-qq
2730
- servant
2831
- servant-auth-server
2932
- servant-server

share-auth/example/src/Lib.hs

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
module Lib (main) where
55

6+
import Data.Aeson qualified as Aeson
7+
import Data.Map (Map)
8+
import Data.Map qualified as Map
69
import Data.Maybe (fromJust, fromMaybe)
710
import Data.Set qualified as Set
811
import Data.Text (Text)
@@ -22,6 +25,7 @@ import Share.OAuth.ServiceProvider qualified as Auth
2225
import Share.OAuth.Session (AuthCheckCtx, AuthenticatedUserId, MaybeAuthenticatedUserId, addAuthCheckCtx)
2326
import Share.OAuth.Types (OAuthClientId (..), OAuthClientSecret (OAuthClientSecret), RedirectReceiverErr, UserId)
2427
import Share.Utils.Servant.Cookies qualified as Cookies
28+
import Text.RawString.QQ (r)
2529
import UnliftIO
2630

2731
-- | An example application endpoint which is optionally authenticated.
@@ -78,10 +82,11 @@ main = do
7882
redisConn <- R.checkedConnect R.defaultConnectInfo
7983
putStrLn "booting up"
8084

81-
jwtSettings <- case JWT.defaultJWTSettings signingKey (Just legacyKey) rotatedKeys acceptedAudiences acceptedIssuers of
82-
Left cryptoError -> throwIO cryptoError
83-
Right jwtS -> do
84-
pure jwtS
85+
jwtSettings <-
86+
JWT.defaultJWTSettings issuer signingKey (Just legacyKey) rotatedKeys acceptedAudiences acceptedIssuers externalJWKs >>= \case
87+
Left cryptoError -> throwIO cryptoError
88+
Right jwtS -> do
89+
pure jwtS
8590

8691
Warp.run 3030 $ serveWithContext (Proxy @MyAPI) (ctx jwtSettings) (myServer redisConn jwtSettings)
8792
putStrLn "exiting"
@@ -135,7 +140,39 @@ main = do
135140
signingKey = JWT.KeyDescription {JWT.key = edDSAKey, JWT.alg = JWT.Ed25519}
136141
rotatedKeys = Set.empty
137142
api = unsafeURI "http://cloud:3030"
138-
serviceAudience = api
143+
serviceAudience = JWT.Audience api
139144
acceptedAudiences = Set.singleton serviceAudience
140-
issuer = unsafeURI "http://localhost:5424"
145+
issuer = JWT.Issuer $ unsafeURI "http://localhost:5424"
141146
acceptedIssuers = Set.singleton issuer
147+
externalJWKs :: Map JWT.Issuer (Either URI JWT.JWKSet)
148+
externalJWKs =
149+
Map.fromList
150+
[ -- This will fetch jwks from the identity provider directly, and keep them up to
151+
-- date.
152+
( JWT.Issuer $ unsafeURI "http://cloud:3030",
153+
Left $ unsafeURI "http://cloud:3030/.well-known/jwks.json"
154+
),
155+
-- This will use the provided static JWK set.
156+
( JWT.Issuer $ unsafeURI "https://api.unison.cloud",
157+
Right
158+
. fromJust
159+
. Aeson.decode @JWT.JWKSet
160+
$
161+
-- This is a sample JWK set, replace with your own.
162+
-- The key is an Ed25519 key, which is used for signing JWTs.
163+
[r|
164+
{
165+
"keys": [
166+
{
167+
"alg": "EdDSA",
168+
"crv": "Ed25519",
169+
"kid": "ZGRwKNuN0LlKkg2WCm4ZSQ1IRzBS2ej5NCTJW1KhFOY",
170+
"kty": "OKP",
171+
"use": "sig",
172+
"x": "rl4D9BawfhIP5M2UEKn30QG1BD3rjQMSLE9oFiUEJpo"
173+
}
174+
]
175+
}
176+
|]
177+
)
178+
]

share-auth/example/test-auth-app.cabal

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cabal-version: 1.12
22

3-
-- This file has been generated from package.yaml by hpack version 0.35.2.
3+
-- This file has been generated from package.yaml by hpack version 0.37.0.
44
--
55
-- see: https://github.com/sol/hpack
66

@@ -59,10 +59,13 @@ library
5959
ImportQualifiedPost
6060
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints
6161
build-depends:
62-
base >=4.7 && <5
62+
aeson
63+
, base >=4.7 && <5
6364
, containers
6465
, hedis
66+
, jose
6567
, network-uri
68+
, raw-strings-qq
6669
, servant
6770
, servant-auth-server
6871
, servant-server
@@ -110,10 +113,13 @@ executable test-auth-app-exe
110113
ImportQualifiedPost
111114
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
112115
build-depends:
113-
base >=4.7 && <5
116+
aeson
117+
, base >=4.7 && <5
114118
, containers
115119
, hedis
120+
, jose
116121
, network-uri
122+
, raw-strings-qq
117123
, servant
118124
, servant-auth-server
119125
, servant-server
@@ -163,10 +169,13 @@ test-suite test-auth-app-test
163169
ImportQualifiedPost
164170
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
165171
build-depends:
166-
base >=4.7 && <5
172+
aeson
173+
, base >=4.7 && <5
167174
, containers
168175
, hedis
176+
, jose
169177
, network-uri
178+
, raw-strings-qq
170179
, servant
171180
, servant-auth-server
172181
, servant-server

share-auth/src/Share/JWT.hs

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,22 @@ module Share.JWT
2929

3030
-- * Utilities
3131
JWTParam (..),
32+
Issuer (..),
33+
Audience (..),
3234
textToSignedJWT,
3335
signedJWTToText,
3436
createSignedCookie,
3537

3638
-- * Re-exports
3739
CryptoError (..),
40+
JWK.JWK,
41+
JWK.JWKSet,
3842
)
3943
where
4044

4145
import Control.Lens
4246
import Control.Monad.Except
47+
import Control.Monad.Trans.Except (except)
4348
import Crypto.Error (CryptoError (..), CryptoFailable (..))
4449
import Crypto.JOSE.JWA.JWS qualified as JWS
4550
import Crypto.JOSE.JWK qualified as JWK
@@ -50,32 +55,41 @@ import Data.Aeson qualified as Aeson
5055
import Data.ByteArray qualified as ByteArray
5156
import Data.ByteString qualified as BS
5257
import Data.ByteString.Base64.URL qualified as Base64URL
58+
import Data.Map (Map)
5359
import Data.Map qualified as Map
5460
import Data.Set (Set)
5561
import Data.Set qualified as Set
5662
import Data.Text (Text)
5763
import Data.Text qualified as Text
5864
import Data.Text.Encoding qualified as Text
65+
import Data.Traversable (for)
5966
import Servant
6067
import Share.JWT.Types
6168
import Share.OAuth.Orphans ()
6269
import Share.Utils.Servant.Cookies qualified as Cookies
6370
import UnliftIO (MonadIO (..))
71+
import UnliftIO.STM
6472

65-
-- | Get the JWK Set value which is safe to expose to the public, e.g. in a JWKS endpoint.
73+
-- | Get the JWK Set for an issuer which is safe to expose to the public, e.g. in a JWKS endpoint.
6674
-- This will only include public keys.
6775
--
6876
-- Note that this will not include the legacy key or any HS256 keys, since those don't have a
6977
-- safe public component.
70-
publicJWKSet :: JWTSettings -> JWK.JWKSet
71-
publicJWKSet JWTSettings {validationKeys = KeyMap {byKeyId}} =
72-
JWK.JWKSet
73-
( byKeyId
74-
& foldMap (\jwk -> jwk ^.. JWK.asPublicKey . _Just)
75-
)
78+
publicJWKSet :: (MonadIO m) => JWTSettings -> Issuer -> m JWK.JWKSet
79+
publicJWKSet JWTSettings {validationKeys = KeyMap {keysVar}} issuer = do
80+
keyMap <- liftIO $ readTVarIO keysVar
81+
pure $
82+
JWK.JWKSet
83+
( keyMap
84+
& Map.lookup issuer
85+
& foldMap (\jwk -> jwk ^.. folded . JWK.asPublicKey . _Just)
86+
)
7687

7788
-- | Create a 'JWTSettings' using the required information.
7889
defaultJWTSettings ::
90+
(MonadIO m) =>
91+
-- | Which issuer is the current service
92+
Issuer ->
7993
-- | The key used to sign JWTs.
8094
KeyDescription ->
8195
-- | The legacy key used to verify old JWTs from before key IDs were used. This will be used to verify tokens that don't have a key id.
@@ -86,25 +100,42 @@ defaultJWTSettings ::
86100
-- Tokens must have an audience which is present in this set.
87101
--
88102
-- E.g. https://api.unison.cloud
89-
Set URI ->
103+
Set Audience ->
90104
-- | Valid issuers when validating tokens
91-
Set URI ->
92-
Either CryptoError JWTSettings
93-
defaultJWTSettings signingKey legacyKey oldValidKeys acceptedAudiences acceptedIssuers = do
94-
sjwk@(_, signingJWK) <- keyDescToJWK signingKey
95-
verificationJWKs <- (sjwk :) <$> traverse keyDescToJWK (Set.toList oldValidKeys)
96-
let byKeyId = Map.fromList verificationJWKs
97-
legacyKey <- traverse keyDescToJWK legacyKey <&> fmap snd
105+
Set Issuer ->
106+
-- | Mapping of issuers to either their:
107+
-- * JWK json endpoint.
108+
-- * JWK set.
109+
--
110+
-- If a JWK URI is provided for an issuer it will be fetched and kept up to date as needed.
111+
Map Issuer (Either URI JWT.JWKSet) ->
112+
m (Either CryptoError JWTSettings)
113+
defaultJWTSettings myIssuer signingKey legacyKey oldValidKeys acceptedAudiences acceptedIssuers externalJWKsMap = runExceptT $ do
114+
signingJWK <- except $ keyDescToJWK signingKey
115+
myVerificationJWKs <- (signingJWK :) <$> traverse (except . keyDescToJWK) (Set.toList oldValidKeys)
116+
let (externalJWKs, externalJWKLocations) =
117+
externalJWKsMap
118+
& foldMap \case
119+
Left uri -> (mempty, Map.singleton myIssuer uri)
120+
Right (JWT.JWKSet jwks) -> (Map.singleton myIssuer jwks, mempty)
121+
122+
let myJWKs = Map.singleton myIssuer myVerificationJWKs
123+
let keysMap = myJWKs <> externalJWKs
124+
keysVar <- liftIO $ newTVarIO keysMap
125+
lastCheckedVar <- liftIO $ newTVarIO Map.empty
126+
legacyKey <- for legacyKey (except . keyDescToJWK)
98127
pure $
99128
JWTSettings
100129
{ signingJWK,
101-
validationKeys = KeyMap {byKeyId, legacyKey},
130+
validationKeys = KeyMap {keysVar, legacyKey},
131+
externalJWKLocations,
132+
lastCheckedVar,
102133
acceptedAudiences,
103134
acceptedIssuers
104135
}
105136

106137
-- | Converts a 'KeyDescription' to a 'JWK' and a 'KeyThumbprint'.
107-
keyDescToJWK :: KeyDescription -> Either CryptoError (KeyThumbprint, JWK.JWK)
138+
keyDescToJWK :: KeyDescription -> Either CryptoError JWK.JWK
108139
keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
109140
case alg of
110141
HS256 -> do
@@ -113,7 +144,7 @@ keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
113144
& JWK.jwkUse .~ Just JWK.Sig
114145
& JWK.jwkAlg .~ Just (JWK.JWSAlg JWS.HS256)
115146
let thumbprint = jwkThumbprint jwk
116-
pure (KeyThumbprint thumbprint, jwk & JWK.jwkKid .~ Just thumbprint)
147+
pure (jwk & JWK.jwkKid .~ Just thumbprint)
117148
Ed25519 -> do
118149
privKey <- Ed25519.secretKey key
119150
let pubKey = Ed25519.toPublic privKey
@@ -124,7 +155,7 @@ keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
124155
& JWK.jwkUse .~ Just JWK.Sig
125156
& JWK.jwkAlg .~ Just (JWK.JWSAlg JWS.EdDSA)
126157
let thumbprint = jwkThumbprint jwk
127-
pure (KeyThumbprint thumbprint, jwk & JWK.jwkKid .~ Just thumbprint)
158+
pure (jwk & JWK.jwkKid .~ Just thumbprint)
128159
where
129160
cryptoFailableToEither :: CryptoFailable a -> Either CryptoError a
130161
cryptoFailableToEither (CryptoFailed err) = Left err
@@ -160,8 +191,8 @@ signJWTWithJWK jwk v = runExceptT $ do
160191
--
161192
-- Any other checks should be performed on the returned claims.
162193
verifyJWT :: forall claims m. (AsJWTClaims claims, MonadIO m) => JWTSettings -> JWT.SignedJWT -> m (Either JWT.JWTError claims)
163-
verifyJWT JWTSettings {validationKeys, acceptedAudiences, acceptedIssuers} signedJWT = runExceptT do
164-
jwtClaimsMap <- ExceptT . liftIO . runExceptT $ JWT.verifyJWT validators validationKeys signedJWT
194+
verifyJWT jwtSettings@JWTSettings {acceptedAudiences, acceptedIssuers} signedJWT = runExceptT do
195+
jwtClaimsMap <- ExceptT . liftIO . runExceptT $ JWT.verifyJWT validators jwtSettings signedJWT
165196
case fromClaims jwtClaimsMap of
166197
Left err -> throwError $ JWT.JWTClaimsSetDecodeError (Text.unpack err)
167198
Right claims -> pure claims
@@ -170,12 +201,12 @@ verifyJWT JWTSettings {validationKeys, acceptedAudiences, acceptedIssuers} signe
170201
auds =
171202
-- Annoyingly StringOrURI doesn't have an ord instance.
172203
Set.toList acceptedAudiences
173-
& map (review CryptoJWT.uri)
204+
& map (\(Audience aud) -> review CryptoJWT.uri aud)
174205
issuers :: [CryptoJWT.StringOrURI]
175206
issuers =
176207
-- Annoyingly StringOrURI doesn't have an ord instance.
177208
Set.toList acceptedIssuers
178-
& map (review CryptoJWT.uri)
209+
& map (\(Issuer iss) -> review CryptoJWT.uri iss)
179210
validators =
180211
CryptoJWT.defaultJWTValidationSettings (`elem` auds)
181212
& CryptoJWT.issuerPredicate .~ (`elem` issuers)

0 commit comments

Comments
 (0)