|
1 | 1 | use custom_error::custom_error;
|
2 | 2 | use openidconnect::TokenIntrospectionResponse;
|
| 3 | +use rocket::figment::Figment; |
3 | 4 | use rocket::http::Status;
|
4 | 5 | use rocket::request::{FromRequest, Outcome};
|
5 | 6 | use rocket::{async_trait, Request};
|
| 7 | +use std::collections::BTreeSet; |
6 | 8 | use std::collections::HashMap;
|
7 | 9 |
|
| 10 | +#[cfg(feature = "rocket_okapi")] |
| 11 | +use rocket_okapi::{ |
| 12 | + gen::OpenApiGenerator, |
| 13 | + okapi::openapi3::{ |
| 14 | + Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData, MediaType, |
| 15 | + RefOr, Response, |
| 16 | + }, |
| 17 | + okapi::Map, |
| 18 | + request::{OpenApiFromRequest, RequestHeaderInput}, |
| 19 | +}; |
| 20 | +#[cfg(feature = "rocket_okapi")] |
| 21 | +use schemars::schema::{InstanceType, ObjectValidation, Schema, SchemaObject}; |
| 22 | +#[cfg(feature = "rocket_okapi")] |
| 23 | + |
8 | 24 | use crate::oidc::introspection::{introspect, IntrospectionError, ZitadelIntrospectionResponse};
|
9 | 25 | use crate::rocket::introspection::IntrospectionConfig;
|
10 | 26 |
|
| 27 | +use super::config::IntrospectionRocketConfig; |
| 28 | + |
11 | 29 | custom_error! {
|
12 | 30 | /// Error type for guard related errors.
|
13 | 31 | pub IntrospectionGuardError
|
@@ -147,6 +165,127 @@ impl<'request> FromRequest<'request> for &'request IntrospectedUser {
|
147 | 165 | }
|
148 | 166 | }
|
149 | 167 |
|
| 168 | +#[cfg(feature = "rocket_okapi")] |
| 169 | +impl<'a> OpenApiFromRequest<'a> for &'a IntrospectedUser { |
| 170 | + fn from_request_input( |
| 171 | + _gen: &mut OpenApiGenerator, |
| 172 | + _name: String, |
| 173 | + _required: bool, |
| 174 | + ) -> rocket_okapi::Result<RequestHeaderInput> { |
| 175 | + let figment: Figment = rocket::Config::figment(); |
| 176 | + let config: IntrospectionRocketConfig = figment |
| 177 | + .extract() |
| 178 | + .expect("authority must be set in Rocket.toml"); |
| 179 | + |
| 180 | + // Setup global requirement for Security scheme |
| 181 | + let security_scheme = SecurityScheme { |
| 182 | + description: Some( |
| 183 | + "Use OpenID Connect to authenticate. (does not work in RapiDoc at all)".to_owned(), |
| 184 | + ), |
| 185 | + data: SecuritySchemeData::OpenIdConnect { |
| 186 | + open_id_connect_url: format!( |
| 187 | + "{}/.well-known/openid-configuration", |
| 188 | + config.authority |
| 189 | + ), |
| 190 | + }, |
| 191 | + extensions: Object::default(), |
| 192 | + }; |
| 193 | + // Add the requirement for this route/endpoint |
| 194 | + // This can change between routes. |
| 195 | + let mut security_req = SecurityRequirement::new(); |
| 196 | + // Each security requirement needs to be met before access is allowed. |
| 197 | + security_req.insert("OpenID".to_owned(), Vec::new()); |
| 198 | + // These vvvv-------^^^^^^^ values need to match exactly! |
| 199 | + Ok(RequestHeaderInput::Security( |
| 200 | + "OpenID".to_owned(), |
| 201 | + security_scheme, |
| 202 | + security_req, |
| 203 | + )) |
| 204 | + } |
| 205 | + |
| 206 | + fn get_responses(_gen: &mut OpenApiGenerator) -> rocket_okapi::Result<Responses> { |
| 207 | + let mut res = Responses::default(); |
| 208 | + |
| 209 | + // Manually defining the error response schema |
| 210 | + let error_detail_schema = SchemaObject { |
| 211 | + instance_type: Some(InstanceType::Object.into()), |
| 212 | + object: Some(Box::new(ObjectValidation { |
| 213 | + properties: { |
| 214 | + let mut properties = Map::new(); |
| 215 | + properties.insert( |
| 216 | + "code".to_owned(), |
| 217 | + Schema::Object(SchemaObject { |
| 218 | + instance_type: Some(InstanceType::Integer.into()), |
| 219 | + ..Default::default() |
| 220 | + }), |
| 221 | + ); |
| 222 | + properties.insert( |
| 223 | + "reason".to_owned(), |
| 224 | + Schema::Object(SchemaObject { |
| 225 | + instance_type: Some(InstanceType::String.into()), |
| 226 | + ..Default::default() |
| 227 | + }), |
| 228 | + ); |
| 229 | + properties.insert( |
| 230 | + "description".to_owned(), |
| 231 | + Schema::Object(SchemaObject { |
| 232 | + instance_type: Some(InstanceType::String.into()), |
| 233 | + ..Default::default() |
| 234 | + }), |
| 235 | + ); |
| 236 | + properties |
| 237 | + }, |
| 238 | + required: vec!["code".to_owned(), "reason".to_owned(), "description".to_owned()] |
| 239 | + .into_iter() |
| 240 | + .collect::<BTreeSet<_>>(), // Convert Vec to BTreeSet |
| 241 | + ..Default::default() |
| 242 | + })), |
| 243 | + ..Default::default() |
| 244 | + }; |
| 245 | + |
| 246 | + let error_response_schema = SchemaObject { |
| 247 | + instance_type: Some(InstanceType::Object.into()), |
| 248 | + object: Some(Box::new(ObjectValidation { |
| 249 | + properties: { |
| 250 | + let mut properties = Map::new(); |
| 251 | + properties.insert("error".to_owned(), Schema::Object(error_detail_schema)); |
| 252 | + properties |
| 253 | + }, |
| 254 | + required: vec!["error".to_owned()].into_iter().collect::<BTreeSet<_>>(), // Convert Vec to BTreeSet |
| 255 | + ..Default::default() |
| 256 | + })), |
| 257 | + ..Default::default() |
| 258 | + }; |
| 259 | + |
| 260 | + // Create the content for the error response |
| 261 | + let mut content = Map::new(); |
| 262 | + content.insert( |
| 263 | + "application/json".to_owned(), |
| 264 | + MediaType { |
| 265 | + schema: Some(Schema::Object(error_response_schema).into()), |
| 266 | + ..Default::default() |
| 267 | + }, |
| 268 | + ); |
| 269 | + |
| 270 | + // Adding 400 BadRequest response |
| 271 | + let bad_request_response = Response { |
| 272 | + description: "Bad Request - Multiple authorization headers found.".to_owned(), |
| 273 | + content: content.clone(), |
| 274 | + ..Default::default() |
| 275 | + }; |
| 276 | + res.responses.insert("400".to_owned(), RefOr::Object(bad_request_response)); |
| 277 | + |
| 278 | + // Adding 401 Unauthorized response |
| 279 | + let unauthorized_response = Response { |
| 280 | + description: "Unauthorized - The request requires user authentication.".to_owned(), |
| 281 | + content: content.clone(), |
| 282 | + ..Default::default() |
| 283 | + }; |
| 284 | + res.responses.insert("401".to_owned(), RefOr::Object(unauthorized_response)); |
| 285 | + |
| 286 | + Ok(res) |
| 287 | + }} |
| 288 | + |
150 | 289 | #[cfg(test)]
|
151 | 290 | mod tests {
|
152 | 291 | #![allow(clippy::all)]
|
|
0 commit comments