Skip to content

Commit 2fb5786

Browse files
authored
feat: Add Optional Support for Rocket-Okapi in IntrospectedUser (#559)
Allow Rocket-Okapi to be used with a specialized feature.
1 parent cff3c30 commit 2fb5786

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ oidc = ["credentials", "dep:base64-compat"]
8787
## Refer to the rocket module for more information.
8888
rocket = ["credentials", "oidc", "dep:rocket"]
8989

90+
## Feature that enables support for the [rocket okapi](https://github.com/GREsau/okapi).
91+
rocket_okapi = ["rocket", "dep:rocket_okapi", "dep:schemars"]
92+
9093
# @@protoc_deletion_point(features)
9194
# This section is automatically generated by protoc-gen-prost-crate.
9295
# Changes in this area may be lost on regeneration.
@@ -167,6 +170,8 @@ tokio = { version = "1.37.0", optional = true, features = [
167170
tonic = { version = "0.12.1", features = [
168171
"tls",
169172
], optional = true }
173+
rocket_okapi = { version = "0.8.0", optional = true, default-features = false }
174+
schemars = {version = "0.8.21", optional = true}
170175
tonic-types = { version = "0.12.1", optional = true }
171176

172177
[dev-dependencies]

src/rocket/introspection/config.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use openidconnect::IntrospectionUrl;
2+
use serde::Deserialize;
23

34
#[cfg(feature = "introspection_cache")]
45
use crate::oidc::introspection::cache::IntrospectionCache;
@@ -18,3 +19,23 @@ pub struct IntrospectionConfig {
1819
#[cfg(feature = "introspection_cache")]
1920
pub(crate) cache: Option<Box<dyn IntrospectionCache>>,
2021
}
22+
23+
#[cfg(feature = "rocket_okapi")]
24+
/// Configuration for OAuth token introspection read from a Rocket.toml file.
25+
///
26+
/// # Fields
27+
/// - `authority`: A string representing the authority URL used for introspection. This is typically
28+
/// the base URL of the OAuth provider that will validate the tokens.
29+
///
30+
/// # Example
31+
/// ```toml
32+
/// [default]
33+
/// authority = "https://auth.example.com/"
34+
/// ```
35+
///
36+
/// # Features
37+
/// This struct is only available when the `rocket_okapi` feature is enabled.
38+
#[derive(Debug, Deserialize)]
39+
pub struct IntrospectionRocketConfig {
40+
pub(crate) authority: String,
41+
}

src/rocket/introspection/guard.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,31 @@
11
use custom_error::custom_error;
22
use openidconnect::TokenIntrospectionResponse;
3+
use rocket::figment::Figment;
34
use rocket::http::Status;
45
use rocket::request::{FromRequest, Outcome};
56
use rocket::{async_trait, Request};
7+
use std::collections::BTreeSet;
68
use std::collections::HashMap;
79

10+
#[cfg(feature = "rocket_okapi")]
11+
use rocket_okapi::{
12+
gen::OpenApiGenerator,
13+
okapi::openapi3::{
14+
Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData, MediaType,
15+
RefOr, Response,
16+
},
17+
okapi::Map,
18+
request::{OpenApiFromRequest, RequestHeaderInput},
19+
};
20+
#[cfg(feature = "rocket_okapi")]
21+
use schemars::schema::{InstanceType, ObjectValidation, Schema, SchemaObject};
22+
#[cfg(feature = "rocket_okapi")]
23+
824
use crate::oidc::introspection::{introspect, IntrospectionError, ZitadelIntrospectionResponse};
925
use crate::rocket::introspection::IntrospectionConfig;
1026

27+
use super::config::IntrospectionRocketConfig;
28+
1129
custom_error! {
1230
/// Error type for guard related errors.
1331
pub IntrospectionGuardError
@@ -147,6 +165,127 @@ impl<'request> FromRequest<'request> for &'request IntrospectedUser {
147165
}
148166
}
149167

168+
#[cfg(feature = "rocket_okapi")]
169+
impl<'a> OpenApiFromRequest<'a> for &'a IntrospectedUser {
170+
fn from_request_input(
171+
_gen: &mut OpenApiGenerator,
172+
_name: String,
173+
_required: bool,
174+
) -> rocket_okapi::Result<RequestHeaderInput> {
175+
let figment: Figment = rocket::Config::figment();
176+
let config: IntrospectionRocketConfig = figment
177+
.extract()
178+
.expect("authority must be set in Rocket.toml");
179+
180+
// Setup global requirement for Security scheme
181+
let security_scheme = SecurityScheme {
182+
description: Some(
183+
"Use OpenID Connect to authenticate. (does not work in RapiDoc at all)".to_owned(),
184+
),
185+
data: SecuritySchemeData::OpenIdConnect {
186+
open_id_connect_url: format!(
187+
"{}/.well-known/openid-configuration",
188+
config.authority
189+
),
190+
},
191+
extensions: Object::default(),
192+
};
193+
// Add the requirement for this route/endpoint
194+
// This can change between routes.
195+
let mut security_req = SecurityRequirement::new();
196+
// Each security requirement needs to be met before access is allowed.
197+
security_req.insert("OpenID".to_owned(), Vec::new());
198+
// These vvvv-------^^^^^^^ values need to match exactly!
199+
Ok(RequestHeaderInput::Security(
200+
"OpenID".to_owned(),
201+
security_scheme,
202+
security_req,
203+
))
204+
}
205+
206+
fn get_responses(_gen: &mut OpenApiGenerator) -> rocket_okapi::Result<Responses> {
207+
let mut res = Responses::default();
208+
209+
// Manually defining the error response schema
210+
let error_detail_schema = SchemaObject {
211+
instance_type: Some(InstanceType::Object.into()),
212+
object: Some(Box::new(ObjectValidation {
213+
properties: {
214+
let mut properties = Map::new();
215+
properties.insert(
216+
"code".to_owned(),
217+
Schema::Object(SchemaObject {
218+
instance_type: Some(InstanceType::Integer.into()),
219+
..Default::default()
220+
}),
221+
);
222+
properties.insert(
223+
"reason".to_owned(),
224+
Schema::Object(SchemaObject {
225+
instance_type: Some(InstanceType::String.into()),
226+
..Default::default()
227+
}),
228+
);
229+
properties.insert(
230+
"description".to_owned(),
231+
Schema::Object(SchemaObject {
232+
instance_type: Some(InstanceType::String.into()),
233+
..Default::default()
234+
}),
235+
);
236+
properties
237+
},
238+
required: vec!["code".to_owned(), "reason".to_owned(), "description".to_owned()]
239+
.into_iter()
240+
.collect::<BTreeSet<_>>(), // Convert Vec to BTreeSet
241+
..Default::default()
242+
})),
243+
..Default::default()
244+
};
245+
246+
let error_response_schema = SchemaObject {
247+
instance_type: Some(InstanceType::Object.into()),
248+
object: Some(Box::new(ObjectValidation {
249+
properties: {
250+
let mut properties = Map::new();
251+
properties.insert("error".to_owned(), Schema::Object(error_detail_schema));
252+
properties
253+
},
254+
required: vec!["error".to_owned()].into_iter().collect::<BTreeSet<_>>(), // Convert Vec to BTreeSet
255+
..Default::default()
256+
})),
257+
..Default::default()
258+
};
259+
260+
// Create the content for the error response
261+
let mut content = Map::new();
262+
content.insert(
263+
"application/json".to_owned(),
264+
MediaType {
265+
schema: Some(Schema::Object(error_response_schema).into()),
266+
..Default::default()
267+
},
268+
);
269+
270+
// Adding 400 BadRequest response
271+
let bad_request_response = Response {
272+
description: "Bad Request - Multiple authorization headers found.".to_owned(),
273+
content: content.clone(),
274+
..Default::default()
275+
};
276+
res.responses.insert("400".to_owned(), RefOr::Object(bad_request_response));
277+
278+
// Adding 401 Unauthorized response
279+
let unauthorized_response = Response {
280+
description: "Unauthorized - The request requires user authentication.".to_owned(),
281+
content: content.clone(),
282+
..Default::default()
283+
};
284+
res.responses.insert("401".to_owned(), RefOr::Object(unauthorized_response));
285+
286+
Ok(res)
287+
}}
288+
150289
#[cfg(test)]
151290
mod tests {
152291
#![allow(clippy::all)]

0 commit comments

Comments
 (0)