|
| 1 | +use std::error::Error; |
| 2 | +use std::fmt::Error as FormatterError; |
| 3 | +use std::fmt::{Debug, Display, Formatter}; |
| 4 | +use std::marker::PhantomData; |
| 5 | +use std::time::Duration; |
| 6 | + |
| 7 | +use serde::de::DeserializeOwned; |
| 8 | +use serde::{Deserialize, Serialize}; |
| 9 | + |
| 10 | +use super::{ |
| 11 | + DeviceCode, EndUserVerificationUrl, ErrorResponse, ErrorResponseType, RequestTokenError, |
| 12 | + StandardErrorResponse, TokenResponse, TokenType, UserCode, |
| 13 | +}; |
| 14 | +use crate::basic::BasicErrorResponseType; |
| 15 | +use crate::types::VerificationUriComplete; |
| 16 | + |
| 17 | +/// The minimum amount of time in seconds that the client SHOULD wait |
| 18 | +/// between polling requests to the token endpoint. If no value is |
| 19 | +/// provided, clients MUST use 5 as the default. |
| 20 | +fn default_devicecode_interval() -> u64 { |
| 21 | + 5 |
| 22 | +} |
| 23 | + |
| 24 | +/// |
| 25 | +/// Trait for adding extra fields to the `DeviceAuthorizationResponse`. |
| 26 | +/// |
| 27 | +pub trait ExtraDeviceAuthorizationFields: DeserializeOwned + Debug + Serialize {} |
| 28 | + |
| 29 | +#[derive(Clone, Debug, Deserialize, Serialize)] |
| 30 | +/// |
| 31 | +/// Empty (default) extra token fields. |
| 32 | +/// |
| 33 | +pub struct EmptyExtraDeviceAuthorizationFields {} |
| 34 | +impl ExtraDeviceAuthorizationFields for EmptyExtraDeviceAuthorizationFields {} |
| 35 | + |
| 36 | +/// |
| 37 | +/// Standard OAuth2 device authorization response. |
| 38 | +/// |
| 39 | +#[derive(Clone, Debug, Deserialize, Serialize)] |
| 40 | +pub struct DeviceAuthorizationResponse<EF> |
| 41 | +where |
| 42 | + EF: ExtraDeviceAuthorizationFields, |
| 43 | +{ |
| 44 | + /// The device verification code. |
| 45 | + device_code: DeviceCode, |
| 46 | + |
| 47 | + /// The end-user verification code. |
| 48 | + user_code: UserCode, |
| 49 | + |
| 50 | + /// The end-user verification URI on the authorization The URI should be |
| 51 | + /// short and easy to remember as end users will be asked to manually type |
| 52 | + /// it into their user agent. |
| 53 | + /// |
| 54 | + /// The `verification_url` alias here is a deviation from the RFC, as |
| 55 | + /// implementations of device code flow predate RFC 8628. |
| 56 | + #[serde(alias = "verification_url")] |
| 57 | + verification_uri: EndUserVerificationUrl, |
| 58 | + |
| 59 | + /// A verification URI that includes the "user_code" (or other information |
| 60 | + /// with the same function as the "user_code"), which is designed for |
| 61 | + /// non-textual transmission. |
| 62 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 63 | + verification_uri_complete: Option<VerificationUriComplete>, |
| 64 | + |
| 65 | + /// The lifetime in seconds of the "device_code" and "user_code". |
| 66 | + expires_in: u64, |
| 67 | + |
| 68 | + /// The minimum amount of time in seconds that the client SHOULD wait |
| 69 | + /// between polling requests to the token endpoint. If no value is |
| 70 | + /// provided, clients MUST use 5 as the default. |
| 71 | + #[serde(default = "default_devicecode_interval")] |
| 72 | + interval: u64, |
| 73 | + |
| 74 | + #[serde(bound = "EF: ExtraDeviceAuthorizationFields", flatten)] |
| 75 | + extra_fields: EF, |
| 76 | +} |
| 77 | + |
| 78 | +impl<EF> DeviceAuthorizationResponse<EF> |
| 79 | +where |
| 80 | + EF: ExtraDeviceAuthorizationFields, |
| 81 | +{ |
| 82 | + /// The device verification code. |
| 83 | + pub fn device_code(&self) -> &DeviceCode { |
| 84 | + &self.device_code |
| 85 | + } |
| 86 | + |
| 87 | + /// The end-user verification code. |
| 88 | + pub fn user_code(&self) -> &UserCode { |
| 89 | + &self.user_code |
| 90 | + } |
| 91 | + |
| 92 | + /// The end-user verification URI on the authorization The URI should be |
| 93 | + /// short and easy to remember as end users will be asked to manually type |
| 94 | + /// it into their user agent. |
| 95 | + pub fn verification_uri(&self) -> &EndUserVerificationUrl { |
| 96 | + &self.verification_uri |
| 97 | + } |
| 98 | + |
| 99 | + /// A verification URI that includes the "user_code" (or other information |
| 100 | + /// with the same function as the "user_code"), which is designed for |
| 101 | + /// non-textual transmission. |
| 102 | + pub fn verification_uri_complete(&self) -> Option<&VerificationUriComplete> { |
| 103 | + self.verification_uri_complete.as_ref() |
| 104 | + } |
| 105 | + |
| 106 | + /// The lifetime in seconds of the "device_code" and "user_code". |
| 107 | + pub fn expires_in(&self) -> Duration { |
| 108 | + Duration::from_secs(self.expires_in) |
| 109 | + } |
| 110 | + |
| 111 | + /// The minimum amount of time in seconds that the client SHOULD wait |
| 112 | + /// between polling requests to the token endpoint. If no value is |
| 113 | + /// provided, clients MUST use 5 as the default. |
| 114 | + pub fn interval(&self) -> Duration { |
| 115 | + Duration::from_secs(self.interval) |
| 116 | + } |
| 117 | + |
| 118 | + /// Any extra fields returned on the response. |
| 119 | + pub fn extra_fields(&self) -> &EF { |
| 120 | + &self.extra_fields |
| 121 | + } |
| 122 | +} |
| 123 | + |
| 124 | +/// |
| 125 | +/// Standard implementation of DeviceAuthorizationResponse which throws away |
| 126 | +/// extra received response fields. |
| 127 | +/// |
| 128 | +pub type StandardDeviceAuthorizationResponse = |
| 129 | + DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields>; |
| 130 | + |
| 131 | +/// |
| 132 | +/// Basic access token error types. |
| 133 | +/// |
| 134 | +/// These error types are defined in |
| 135 | +/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2) and |
| 136 | +/// [Section 3.5 of RFC 6749](https://tools.ietf.org/html/rfc8628#section-3.5) |
| 137 | +/// |
| 138 | +#[derive(Clone, PartialEq)] |
| 139 | +pub enum DeviceCodeErrorResponseType { |
| 140 | + /// |
| 141 | + /// The authorization request is still pending as the end user hasn't |
| 142 | + /// yet completed the user-interaction steps. The client SHOULD repeat the |
| 143 | + /// access token request to the token endpoint. Before each new request, |
| 144 | + /// the client MUST wait at least the number of seconds specified by the |
| 145 | + /// "interval" parameter of the device authorization response, or 5 seconds |
| 146 | + /// if none was provided, and respect any increase in the polling interval |
| 147 | + /// required by the "slow_down" error. |
| 148 | + /// |
| 149 | + AuthorizationPending, |
| 150 | + /// |
| 151 | + /// A variant of "authorization_pending", the authorization request is |
| 152 | + /// still pending and polling should continue, but the interval MUST be |
| 153 | + /// increased by 5 seconds for this and all subsequent requests. |
| 154 | + SlowDown, |
| 155 | + /// |
| 156 | + /// The authorization request was denied. |
| 157 | + /// |
| 158 | + AccessDenied, |
| 159 | + /// |
| 160 | + /// The "device_code" has expired, and the device authorization session has |
| 161 | + /// concluded. The client MAY commence a new device authorization request |
| 162 | + /// but SHOULD wait for user interaction before restarting to avoid |
| 163 | + /// unnecessary polling. |
| 164 | + ExpiredToken, |
| 165 | + /// |
| 166 | + /// A Basic response type |
| 167 | + /// |
| 168 | + Basic(BasicErrorResponseType), |
| 169 | +} |
| 170 | +impl DeviceCodeErrorResponseType { |
| 171 | + fn from_str(s: &str) -> Self { |
| 172 | + match BasicErrorResponseType::from_str(s) { |
| 173 | + BasicErrorResponseType::Extension(ext) => match ext.as_str() { |
| 174 | + "authorization_pending" => DeviceCodeErrorResponseType::AuthorizationPending, |
| 175 | + "slow_down" => DeviceCodeErrorResponseType::SlowDown, |
| 176 | + "access_denied" => DeviceCodeErrorResponseType::AccessDenied, |
| 177 | + "expired_token" => DeviceCodeErrorResponseType::ExpiredToken, |
| 178 | + _ => DeviceCodeErrorResponseType::Basic(BasicErrorResponseType::Extension(ext)), |
| 179 | + }, |
| 180 | + basic => DeviceCodeErrorResponseType::Basic(basic), |
| 181 | + } |
| 182 | + } |
| 183 | +} |
| 184 | +impl AsRef<str> for DeviceCodeErrorResponseType { |
| 185 | + fn as_ref(&self) -> &str { |
| 186 | + match self { |
| 187 | + DeviceCodeErrorResponseType::AuthorizationPending => "authorization_pending", |
| 188 | + DeviceCodeErrorResponseType::SlowDown => "slow_down", |
| 189 | + DeviceCodeErrorResponseType::AccessDenied => "access_denied", |
| 190 | + DeviceCodeErrorResponseType::ExpiredToken => "expired_token", |
| 191 | + DeviceCodeErrorResponseType::Basic(basic) => basic.as_ref(), |
| 192 | + } |
| 193 | + } |
| 194 | +} |
| 195 | +impl<'de> serde::Deserialize<'de> for DeviceCodeErrorResponseType { |
| 196 | + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> |
| 197 | + where |
| 198 | + D: serde::de::Deserializer<'de>, |
| 199 | + { |
| 200 | + let variant_str = String::deserialize(deserializer)?; |
| 201 | + Ok(Self::from_str(&variant_str)) |
| 202 | + } |
| 203 | +} |
| 204 | +impl serde::ser::Serialize for DeviceCodeErrorResponseType { |
| 205 | + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
| 206 | + where |
| 207 | + S: serde::ser::Serializer, |
| 208 | + { |
| 209 | + serializer.serialize_str(self.as_ref()) |
| 210 | + } |
| 211 | +} |
| 212 | +impl ErrorResponseType for DeviceCodeErrorResponseType {} |
| 213 | +impl Debug for DeviceCodeErrorResponseType { |
| 214 | + fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> { |
| 215 | + Display::fmt(self, f) |
| 216 | + } |
| 217 | +} |
| 218 | + |
| 219 | +impl Display for DeviceCodeErrorResponseType { |
| 220 | + fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> { |
| 221 | + write!(f, "{}", self.as_ref()) |
| 222 | + } |
| 223 | +} |
| 224 | + |
| 225 | +/// |
| 226 | +/// Error response specialization for device code OAuth2 implementation. |
| 227 | +/// |
| 228 | +pub type DeviceCodeErrorResponse = StandardErrorResponse<DeviceCodeErrorResponseType>; |
| 229 | + |
| 230 | +pub(crate) enum DeviceAccessTokenPollResult<TR, RE, TE, TT> |
| 231 | +where |
| 232 | + TE: ErrorResponse + 'static, |
| 233 | + TR: TokenResponse<TT>, |
| 234 | + TT: TokenType, |
| 235 | + RE: Error + 'static, |
| 236 | +{ |
| 237 | + ContinueWithNewPollInterval(Duration), |
| 238 | + Done(Result<TR, RequestTokenError<RE, TE>>, PhantomData<TT>), |
| 239 | +} |
0 commit comments