|
59 | 59 | //! ).unwrap(); |
60 | 60 | //! ``` |
61 | 61 |
|
62 | | -use anyhow::{anyhow, Result}; |
| 62 | +use anyhow::{anyhow, Context, Result}; |
| 63 | +use base64::engine::general_purpose::URL_SAFE_NO_PAD; |
| 64 | +use base64::prelude::*; |
| 65 | +use jsonwebtoken::jwk::{Jwk, PublicKeyUse}; |
63 | 66 | use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey}; |
| 67 | +use rsa::pkcs1::DecodeRsaPublicKey; |
| 68 | +use rsa::traits::PublicKeyParts; |
| 69 | +use rsa::RsaPublicKey; |
| 70 | +use rsa::sha2::Sha256; |
| 71 | +use rsa::sha2::Digest; |
| 72 | +use serde::{Deserialize, Serialize}; |
| 73 | +use serde_json::json; |
64 | 74 | use std::fs::File; |
65 | 75 | use std::io::Read; |
66 | 76 | use std::path::Path; |
@@ -506,3 +516,115 @@ impl JwtKeyConfig { |
506 | 516 | ) |
507 | 517 | } |
508 | 518 | } |
| 519 | + |
| 520 | +/// JSON Web Key Set |
| 521 | +/// |
| 522 | +/// This structure represents a set of JSON Web Keys (JWKs) as defined in RFC 7517. |
| 523 | +/// It can be used to generate and manipulate JWK representations of RSA keys for |
| 524 | +/// use with OpenID Connect discovery endpoints. |
| 525 | +#[derive(Debug, Serialize, Deserialize)] |
| 526 | +pub struct JwkKeySet { |
| 527 | + /// The set of JWKs |
| 528 | + pub keys: Vec<Jwk>, |
| 529 | +} |
| 530 | + |
| 531 | +impl JwkKeySet { |
| 532 | + /// Create a new JWK from a PEM encoded RSA public key |
| 533 | + /// |
| 534 | + /// This function converts a PEM encoded RSA public key to a JWK (JSON Web Key) |
| 535 | + /// representation suitable for use with OpenID Connect discovery endpoints. |
| 536 | + /// |
| 537 | + /// # Parameters |
| 538 | + /// |
| 539 | + /// * `pem_data` - The PEM encoded RSA public key as bytes |
| 540 | + /// |
| 541 | + /// # Returns |
| 542 | + /// |
| 543 | + /// A JWK representing the RSA public key, or an error if parsing fails |
| 544 | + pub fn create_jwk_from_pem(pem_data: &[u8]) -> Result<Jwk> { |
| 545 | + // Parse the PEM key |
| 546 | + let public_key = DecodeRsaPublicKey::from_pkcs1_pem(std::str::from_utf8(pem_data)?) |
| 547 | + .context("Failed to parse RSA public key from PEM")?; |
| 548 | + |
| 549 | + // Convert to JWK |
| 550 | + Self::create_jwk_from_public_key(&public_key) |
| 551 | + } |
| 552 | + |
| 553 | + /// Create a JWK from an RSA public key |
| 554 | + /// |
| 555 | + /// Converts an RSA public key to a JWK representation with the necessary |
| 556 | + /// parameters for use with OpenID Connect. |
| 557 | + /// |
| 558 | + /// # Parameters |
| 559 | + /// |
| 560 | + /// * `public_key` - The RSA public key |
| 561 | + /// |
| 562 | + /// # Returns |
| 563 | + /// |
| 564 | + /// A JWK representing the RSA public key |
| 565 | + pub fn create_jwk_from_public_key(public_key: &RsaPublicKey) -> Result<Jwk> { |
| 566 | + // Get the modulus (n) and exponent (e) from the public key |
| 567 | + let n = public_key.n(); |
| 568 | + let n = BASE64_STANDARD.encode(&public_key.n().to_bytes_be()); |
| 569 | + let e = BASE64_STANDARD.encode(&public_key.e().to_bytes_be()); |
| 570 | + |
| 571 | + // Calculate the key ID (kid) as a SHA-256 thumbprint |
| 572 | + let jwk_thumbprint = Self::calculate_jwk_thumbprint(&n, &e)?; |
| 573 | + |
| 574 | + // Build the JWK |
| 575 | + let jwk = Jwk { |
| 576 | + common: jsonwebtoken::jwk::CommonParameters { |
| 577 | + public_key_use: Some(PublicKeyUse::Signature), |
| 578 | + key_id: Some(jwk_thumbprint), |
| 579 | + key_algorithm: Some(jsonwebtoken::jwk::KeyAlgorithm::RS256), // Correct field name and type |
| 580 | + ..Default::default() |
| 581 | + }, |
| 582 | + algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA( |
| 583 | + jsonwebtoken::jwk::RSAKeyParameters { |
| 584 | + key_type: jsonwebtoken::jwk::RSAKeyType::RSA, |
| 585 | + n, |
| 586 | + e, |
| 587 | + ..Default::default() |
| 588 | + }, |
| 589 | + ), |
| 590 | + }; |
| 591 | + |
| 592 | + Ok(jwk) |
| 593 | + } |
| 594 | + |
| 595 | + /// Calculate a JWK thumbprint according to RFC 7638 |
| 596 | + /// |
| 597 | + /// This function calculates a thumbprint for a JWK which can be used as |
| 598 | + /// a key ID (kid) parameter. The thumbprint is a SHA-256 hash of the |
| 599 | + /// canonical JSON representation of the JWK. |
| 600 | + /// |
| 601 | + /// # Parameters |
| 602 | + /// |
| 603 | + /// * `n` - Base64URL encoded modulus |
| 604 | + /// * `e` - Base64URL encoded exponent |
| 605 | + /// |
| 606 | + /// # Returns |
| 607 | + /// |
| 608 | + /// Base64URL encoded SHA-256 thumbprint |
| 609 | + fn calculate_jwk_thumbprint(n: &str, e: &str) -> Result<String> { |
| 610 | + // Create canonical JWK representation |
| 611 | + let canonical = json!({ |
| 612 | + "e": e, |
| 613 | + "kty": "RSA", |
| 614 | + "n": n |
| 615 | + }); |
| 616 | + |
| 617 | + // Serialize to bytes in lexicographic order |
| 618 | + let canonical_bytes = serde_json::to_vec(&canonical)?; |
| 619 | + |
| 620 | + // Calculate SHA-256 hash |
| 621 | + let mut hasher = Sha256::new(); |
| 622 | + hasher.update(&canonical_bytes); |
| 623 | + let hash = hasher.finalize(); |
| 624 | + |
| 625 | + // Encode as Base64URL |
| 626 | + let thumbprint = URL_SAFE_NO_PAD.encode(hash); |
| 627 | + |
| 628 | + Ok(thumbprint) |
| 629 | + } |
| 630 | +} |
0 commit comments