|
| 1 | +// Package trust — JWKSKeyResolver implements SD-JWT VC spec §5.3 key resolution |
| 2 | +// via JWT VC Issuer Metadata (.well-known/jwt-vc-issuer). |
| 3 | +package trust |
| 4 | + |
| 5 | +import ( |
| 6 | + "context" |
| 7 | + "crypto" |
| 8 | + "encoding/json" |
| 9 | + "fmt" |
| 10 | + "net/http" |
| 11 | + "strings" |
| 12 | + "time" |
| 13 | + |
| 14 | + "github.com/jellydator/ttlcache/v3" |
| 15 | +) |
| 16 | + |
| 17 | +const ( |
| 18 | + // DefaultJWKSCacheTTL is the default TTL for cached JWKS entries. |
| 19 | + DefaultJWKSCacheTTL = 5 * time.Minute |
| 20 | + |
| 21 | + // DefaultJWKSMaxCapacity is the default max capacity for the JWKS cache. |
| 22 | + DefaultJWKSMaxCapacity = 100 |
| 23 | +) |
| 24 | + |
| 25 | +// JWKSResolverConfig contains configuration for the JWKSKeyResolver. |
| 26 | +type JWKSResolverConfig struct { |
| 27 | + // HTTPClient is the HTTP client used for fetching metadata and JWKS. |
| 28 | + // If nil, a default client with 30s timeout is used. |
| 29 | + HTTPClient *http.Client |
| 30 | + |
| 31 | + // CacheTTL is the time-to-live for cached JWKS entries per issuer. |
| 32 | + // Default: 5 minutes. |
| 33 | + CacheTTL time.Duration |
| 34 | + |
| 35 | + // MaxCapacity is the maximum number of issuers to cache. |
| 36 | + // Default: 100. |
| 37 | + MaxCapacity uint64 |
| 38 | + |
| 39 | + // ParseJWKToPublicKey converts a JWK map to a crypto.PublicKey. |
| 40 | + // If nil, a default implementation using lestrrat-go/jwx is expected |
| 41 | + // to be injected by the caller (avoids coupling pkg/trust to pkg/jose). |
| 42 | + ParseJWKToPublicKey func(jwkData any) (crypto.PublicKey, error) |
| 43 | +} |
| 44 | + |
| 45 | +// JWKSKeyResolver resolves issuer public keys via JWT VC Issuer Metadata (SD-JWT VC §5.3). |
| 46 | +// |
| 47 | +// Resolution flow: |
| 48 | +// 1. Fetch {issuer}/.well-known/jwt-vc-issuer → JWT VC Issuer Metadata |
| 49 | +// 2. Validate metadata.issuer matches the expected issuer |
| 50 | +// 3. Obtain JWKS from inline jwks field or follow jwks_uri |
| 51 | +// 4. Cache the resolved JWKS per issuer URL |
| 52 | +// 5. Match by kid to return the correct key |
| 53 | +type JWKSKeyResolver struct { |
| 54 | + httpClient *http.Client |
| 55 | + cache *ttlcache.Cache[string, *cachedJWKS] |
| 56 | + parseJWK func(jwkData any) (crypto.PublicKey, error) |
| 57 | +} |
| 58 | + |
| 59 | +// cachedJWKS holds the parsed JWKS keys for an issuer. |
| 60 | +type cachedJWKS struct { |
| 61 | + keys []jwkEntry |
| 62 | +} |
| 63 | + |
| 64 | +// jwkEntry holds a single JWK as both a map (for trust evaluation) and parsed public key. |
| 65 | +type jwkEntry struct { |
| 66 | + kid string |
| 67 | + jwkMap map[string]any |
| 68 | + publicKey crypto.PublicKey |
| 69 | +} |
| 70 | + |
| 71 | +// jwtVCIssuerMetadata represents the JWT VC Issuer Metadata response per SD-JWT VC §5.3. |
| 72 | +type jwtVCIssuerMetadata struct { |
| 73 | + Issuer string `json:"issuer"` |
| 74 | + JWKSURI string `json:"jwks_uri,omitempty"` |
| 75 | + JWKS *struct { |
| 76 | + Keys []json.RawMessage `json:"keys"` |
| 77 | + } `json:"jwks,omitempty"` |
| 78 | +} |
| 79 | + |
| 80 | +// NewJWKSKeyResolver creates a new resolver for SD-JWT VC issuer key resolution. |
| 81 | +// The parseJWK function must be provided to convert JWK maps to crypto.PublicKey |
| 82 | +// (this avoids coupling pkg/trust to pkg/jose). |
| 83 | +func NewJWKSKeyResolver(config JWKSResolverConfig) *JWKSKeyResolver { |
| 84 | + httpClient := config.HTTPClient |
| 85 | + if httpClient == nil { |
| 86 | + httpClient = &http.Client{Timeout: 30 * time.Second} |
| 87 | + } |
| 88 | + |
| 89 | + cacheTTL := config.CacheTTL |
| 90 | + if cacheTTL <= 0 { |
| 91 | + cacheTTL = DefaultJWKSCacheTTL |
| 92 | + } |
| 93 | + |
| 94 | + maxCapacity := config.MaxCapacity |
| 95 | + if maxCapacity == 0 { |
| 96 | + maxCapacity = DefaultJWKSMaxCapacity |
| 97 | + } |
| 98 | + |
| 99 | + cache := ttlcache.New( |
| 100 | + ttlcache.WithTTL[string, *cachedJWKS](cacheTTL), |
| 101 | + ttlcache.WithCapacity[string, *cachedJWKS](maxCapacity), |
| 102 | + ) |
| 103 | + go cache.Start() |
| 104 | + |
| 105 | + return &JWKSKeyResolver{ |
| 106 | + httpClient: httpClient, |
| 107 | + cache: cache, |
| 108 | + parseJWK: config.ParseJWKToPublicKey, |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +// ResolveKeyByKID resolves the public key for the given issuer and kid. |
| 113 | +// Returns the public key and the JWK map (for trust evaluation). |
| 114 | +// |
| 115 | +// Per SD-JWT VC §5.3, the metadata is fetched from {issuer}/.well-known/jwt-vc-issuer. |
| 116 | +// Resolved JWKS are cached per issuer URL. |
| 117 | +func (r *JWKSKeyResolver) ResolveKeyByKID(ctx context.Context, issuerURL, kid string) (crypto.PublicKey, map[string]any, error) { |
| 118 | + if issuerURL == "" { |
| 119 | + return nil, nil, fmt.Errorf("issuer URL is empty") |
| 120 | + } |
| 121 | + if kid == "" { |
| 122 | + return nil, nil, fmt.Errorf("kid is empty") |
| 123 | + } |
| 124 | + |
| 125 | + // Check cache first |
| 126 | + jwks, err := r.getOrFetchJWKS(ctx, issuerURL) |
| 127 | + if err != nil { |
| 128 | + return nil, nil, err |
| 129 | + } |
| 130 | + |
| 131 | + // Find the key matching the kid |
| 132 | + for _, entry := range jwks.keys { |
| 133 | + if entry.kid == kid { |
| 134 | + return entry.publicKey, entry.jwkMap, nil |
| 135 | + } |
| 136 | + } |
| 137 | + |
| 138 | + return nil, nil, fmt.Errorf("no key found in issuer JWKS matching kid %q", kid) |
| 139 | +} |
| 140 | + |
| 141 | +// getOrFetchJWKS returns the cached JWKS for the issuer, or fetches and caches it. |
| 142 | +func (r *JWKSKeyResolver) getOrFetchJWKS(ctx context.Context, issuerURL string) (*cachedJWKS, error) { |
| 143 | + // Check cache |
| 144 | + item := r.cache.Get(issuerURL) |
| 145 | + if item != nil { |
| 146 | + return item.Value(), nil |
| 147 | + } |
| 148 | + |
| 149 | + // Cache miss — fetch from issuer |
| 150 | + jwks, err := r.fetchIssuerJWKS(ctx, issuerURL) |
| 151 | + if err != nil { |
| 152 | + return nil, err |
| 153 | + } |
| 154 | + |
| 155 | + r.cache.Set(issuerURL, jwks, ttlcache.DefaultTTL) |
| 156 | + return jwks, nil |
| 157 | +} |
| 158 | + |
| 159 | +// fetchIssuerJWKS fetches the JWT VC Issuer Metadata and resolves the JWKS. |
| 160 | +func (r *JWKSKeyResolver) fetchIssuerJWKS(ctx context.Context, issuerURL string) (*cachedJWKS, error) { |
| 161 | + // Fetch JWT VC Issuer Metadata per SD-JWT VC §5.3 |
| 162 | + metadataURL := strings.TrimRight(issuerURL, "/") + "/.well-known/jwt-vc-issuer" |
| 163 | + var metadata jwtVCIssuerMetadata |
| 164 | + if _, err := r.fetchJSON(ctx, metadataURL, &metadata); err != nil { |
| 165 | + return nil, fmt.Errorf("failed to fetch JWT VC Issuer Metadata from %s: %w", metadataURL, err) |
| 166 | + } |
| 167 | + |
| 168 | + // Validate issuer match (security requirement per §5.3) |
| 169 | + if metadata.Issuer != issuerURL { |
| 170 | + return nil, fmt.Errorf("metadata issuer %q does not match expected issuer %q", metadata.Issuer, issuerURL) |
| 171 | + } |
| 172 | + |
| 173 | + // Get raw JWKS keys: inline or via jwks_uri |
| 174 | + var rawKeys []json.RawMessage |
| 175 | + if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 { |
| 176 | + rawKeys = metadata.JWKS.Keys |
| 177 | + } else if metadata.JWKSURI != "" { |
| 178 | + var fetchErr error |
| 179 | + rawKeys, fetchErr = r.fetchJWKSKeys(ctx, metadata.JWKSURI) |
| 180 | + if fetchErr != nil { |
| 181 | + return nil, fmt.Errorf("failed to fetch JWKS from %s: %w", metadata.JWKSURI, fetchErr) |
| 182 | + } |
| 183 | + } else { |
| 184 | + return nil, fmt.Errorf("issuer metadata contains neither jwks nor jwks_uri") |
| 185 | + } |
| 186 | + |
| 187 | + // Parse all keys |
| 188 | + entries := make([]jwkEntry, 0, len(rawKeys)) |
| 189 | + for _, raw := range rawKeys { |
| 190 | + var jwkMap map[string]any |
| 191 | + if err := json.Unmarshal(raw, &jwkMap); err != nil { |
| 192 | + continue // skip unparseable keys |
| 193 | + } |
| 194 | + |
| 195 | + kid, _ := jwkMap["kid"].(string) |
| 196 | + |
| 197 | + publicKey, err := r.parseJWK(jwkMap) |
| 198 | + if err != nil { |
| 199 | + continue // skip keys that can't be parsed |
| 200 | + } |
| 201 | + |
| 202 | + entries = append(entries, jwkEntry{ |
| 203 | + kid: kid, |
| 204 | + jwkMap: jwkMap, |
| 205 | + publicKey: publicKey, |
| 206 | + }) |
| 207 | + } |
| 208 | + |
| 209 | + if len(entries) == 0 { |
| 210 | + return nil, fmt.Errorf("issuer JWKS contains no usable keys") |
| 211 | + } |
| 212 | + |
| 213 | + return &cachedJWKS{keys: entries}, nil |
| 214 | +} |
| 215 | + |
| 216 | +// fetchJWKSKeys fetches a JWKS from a URI and returns the raw key entries. |
| 217 | +func (r *JWKSKeyResolver) fetchJWKSKeys(ctx context.Context, jwksURI string) ([]json.RawMessage, error) { |
| 218 | + var jwks struct { |
| 219 | + Keys []json.RawMessage `json:"keys"` |
| 220 | + } |
| 221 | + if _, err := r.fetchJSON(ctx, jwksURI, &jwks); err != nil { |
| 222 | + return nil, err |
| 223 | + } |
| 224 | + return jwks.Keys, nil |
| 225 | +} |
| 226 | + |
| 227 | +// fetchJSON fetches a URL and decodes the JSON response into the given target. |
| 228 | +// Returns the decoded target and any error. |
| 229 | +func (r *JWKSKeyResolver) fetchJSON(ctx context.Context, url string, target any) (any, error) { |
| 230 | + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) |
| 231 | + if err != nil { |
| 232 | + return nil, fmt.Errorf("failed to create request for %s: %w", url, err) |
| 233 | + } |
| 234 | + req.Header.Set("Accept", "application/json") |
| 235 | + |
| 236 | + resp, err := r.httpClient.Do(req) |
| 237 | + if err != nil { |
| 238 | + return nil, fmt.Errorf("failed to fetch %s: %w", url, err) |
| 239 | + } |
| 240 | + defer resp.Body.Close() //nolint:errcheck |
| 241 | + |
| 242 | + if resp.StatusCode != http.StatusOK { |
| 243 | + return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, url) |
| 244 | + } |
| 245 | + |
| 246 | + if err := json.NewDecoder(resp.Body).Decode(target); err != nil { |
| 247 | + return nil, fmt.Errorf("failed to decode response from %s: %w", url, err) |
| 248 | + } |
| 249 | + |
| 250 | + return target, nil |
| 251 | +} |
| 252 | + |
| 253 | +// Stop stops the cache's automatic expiration goroutine. |
| 254 | +func (r *JWKSKeyResolver) Stop() { |
| 255 | + r.cache.Stop() |
| 256 | +} |
| 257 | + |
| 258 | +// InvalidateIssuer removes a cached JWKS for a specific issuer. |
| 259 | +// Useful when key rotation is detected (e.g., kid not found in cached JWKS). |
| 260 | +func (r *JWKSKeyResolver) InvalidateIssuer(issuerURL string) { |
| 261 | + r.cache.Delete(issuerURL) |
| 262 | +} |
| 263 | + |
| 264 | +// Len returns the number of issuers currently cached. |
| 265 | +func (r *JWKSKeyResolver) Len() int { |
| 266 | + return r.cache.Len() |
| 267 | +} |
0 commit comments