Skip to content

Commit 297fe19

Browse files
authored
Add initial support for Device Code Flow as per RFC 8628 (#113)
* Add Client methods to set up the device authorization url and details * Add Client methods to exchange for device codes, and to exchange codes for a token. * Add an example that authorizes using Device Code Flow against Google.
1 parent b50328e commit 297fe19

File tree

7 files changed

+1390
-32
lines changed

7 files changed

+1390
-32
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ serde = { version = "1.0", features = ["derive"] }
2929
serde_json = "1.0"
3030
sha2 = "0.9"
3131
url = { version = "2.1", features = ["serde"] }
32+
chrono = "0.4"
3233

3334
[dev-dependencies]
3435
hex = "0.4"
3536
hmac = "0.8"
3637
uuid = { version = "0.8", features = ["v4"] }
3738
anyhow="1.0"
3839
tokio = { version = "0.2", features = ["full"] }
40+
async-std = "1.6.3"

examples/google_devicecode.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//!
2+
//! This example showcases the Google OAuth2 process for requesting access to the Google Calendar features
3+
//! and the user's profile.
4+
//!
5+
//! Before running it, you'll need to generate your own Google OAuth2 credentials.
6+
//!
7+
//! In order to run the example call:
8+
//!
9+
//! ```sh
10+
//! GOOGLE_CLIENT_ID=xxx GOOGLE_CLIENT_SECRET=yyy cargo run --example google
11+
//! ```
12+
//!
13+
//! ...and follow the instructions.
14+
//!
15+
16+
use oauth2::basic::BasicClient;
17+
// Alternatively, this can be oauth2::curl::http_client or a custom.
18+
use oauth2::devicecode::{DeviceAuthorizationResponse, ExtraDeviceAuthorizationFields};
19+
use oauth2::reqwest::http_client;
20+
use oauth2::{AuthType, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationUrl, Scope, TokenUrl};
21+
use serde::{Deserialize, Serialize};
22+
use std::collections::HashMap;
23+
use std::env;
24+
25+
#[derive(Debug, Serialize, Deserialize)]
26+
struct StoringFields(HashMap<String, serde_json::Value>);
27+
28+
impl ExtraDeviceAuthorizationFields for StoringFields {}
29+
type StoringDeviceAuthorizationResponse = DeviceAuthorizationResponse<StoringFields>;
30+
31+
fn main() {
32+
let google_client_id = ClientId::new(
33+
env::var("GOOGLE_CLIENT_ID").expect("Missing the GOOGLE_CLIENT_ID environment variable."),
34+
);
35+
let google_client_secret = ClientSecret::new(
36+
env::var("GOOGLE_CLIENT_SECRET")
37+
.expect("Missing the GOOGLE_CLIENT_SECRET environment variable."),
38+
);
39+
let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())
40+
.expect("Invalid authorization endpoint URL");
41+
let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string())
42+
.expect("Invalid token endpoint URL");
43+
let device_auth_url =
44+
DeviceAuthorizationUrl::new("https://oauth2.googleapis.com/device/code".to_string())
45+
.expect("Invalid device authorization endpoint URL");
46+
47+
// Set up the config for the Google OAuth2 process.
48+
//
49+
// Google's OAuth endpoint expects the client_id to be in the request body,
50+
// so ensure that option is set.
51+
let device_client = BasicClient::new(
52+
google_client_id,
53+
Some(google_client_secret),
54+
auth_url,
55+
Some(token_url),
56+
)
57+
.set_device_authorization_url(device_auth_url)
58+
.set_auth_type(AuthType::RequestBody);
59+
60+
// Request the set of codes from the Device Authorization endpoint.
61+
let details: StoringDeviceAuthorizationResponse = device_client
62+
.exchange_device_code()
63+
.add_scope(Scope::new("profile".to_string()))
64+
.request(http_client)
65+
.expect("Failed to request codes from device auth endpoint");
66+
67+
// Display the URL and user-code.
68+
println!(
69+
"Open this URL in your browser:\n{}\nand enter the code: {}",
70+
details.verification_uri().to_string(),
71+
details.user_code().secret().to_string()
72+
);
73+
74+
// Now poll for the token
75+
let token = device_client
76+
.exchange_device_access_token(&details)
77+
.request(http_client, std::thread::sleep, None)
78+
.expect("Failed to get token");
79+
80+
println!("Google returned the following token:\n{:?}\n", token);
81+
}

src/basic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ pub enum BasicErrorResponseType {
117117
Extension(String),
118118
}
119119
impl BasicErrorResponseType {
120-
fn from_str(s: &str) -> Self {
120+
pub(crate) fn from_str(s: &str) -> Self {
121121
match s {
122122
"invalid_client" => BasicErrorResponseType::InvalidClient,
123123
"invalid_grant" => BasicErrorResponseType::InvalidGrant,

src/devicecode.rs

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
use std::error::Error;
2+
use std::fmt::Error as FormatterError;
3+
use std::fmt::{Debug, Display, Formatter};
4+
use std::marker::PhantomData;
5+
use std::time::Duration;
6+
7+
use serde::de::DeserializeOwned;
8+
use serde::{Deserialize, Serialize};
9+
10+
use super::{
11+
DeviceCode, EndUserVerificationUrl, ErrorResponse, ErrorResponseType, RequestTokenError,
12+
StandardErrorResponse, TokenResponse, TokenType, UserCode,
13+
};
14+
use crate::basic::BasicErrorResponseType;
15+
use crate::types::VerificationUriComplete;
16+
17+
/// The minimum amount of time in seconds that the client SHOULD wait
18+
/// between polling requests to the token endpoint. If no value is
19+
/// provided, clients MUST use 5 as the default.
20+
fn default_devicecode_interval() -> u64 {
21+
5
22+
}
23+
24+
///
25+
/// Trait for adding extra fields to the `DeviceAuthorizationResponse`.
26+
///
27+
pub trait ExtraDeviceAuthorizationFields: DeserializeOwned + Debug + Serialize {}
28+
29+
#[derive(Clone, Debug, Deserialize, Serialize)]
30+
///
31+
/// Empty (default) extra token fields.
32+
///
33+
pub struct EmptyExtraDeviceAuthorizationFields {}
34+
impl ExtraDeviceAuthorizationFields for EmptyExtraDeviceAuthorizationFields {}
35+
36+
///
37+
/// Standard OAuth2 device authorization response.
38+
///
39+
#[derive(Clone, Debug, Deserialize, Serialize)]
40+
pub struct DeviceAuthorizationResponse<EF>
41+
where
42+
EF: ExtraDeviceAuthorizationFields,
43+
{
44+
/// The device verification code.
45+
device_code: DeviceCode,
46+
47+
/// The end-user verification code.
48+
user_code: UserCode,
49+
50+
/// The end-user verification URI on the authorization The URI should be
51+
/// short and easy to remember as end users will be asked to manually type
52+
/// it into their user agent.
53+
///
54+
/// The `verification_url` alias here is a deviation from the RFC, as
55+
/// implementations of device code flow predate RFC 8628.
56+
#[serde(alias = "verification_url")]
57+
verification_uri: EndUserVerificationUrl,
58+
59+
/// A verification URI that includes the "user_code" (or other information
60+
/// with the same function as the "user_code"), which is designed for
61+
/// non-textual transmission.
62+
#[serde(skip_serializing_if = "Option::is_none")]
63+
verification_uri_complete: Option<VerificationUriComplete>,
64+
65+
/// The lifetime in seconds of the "device_code" and "user_code".
66+
expires_in: u64,
67+
68+
/// The minimum amount of time in seconds that the client SHOULD wait
69+
/// between polling requests to the token endpoint. If no value is
70+
/// provided, clients MUST use 5 as the default.
71+
#[serde(default = "default_devicecode_interval")]
72+
interval: u64,
73+
74+
#[serde(bound = "EF: ExtraDeviceAuthorizationFields", flatten)]
75+
extra_fields: EF,
76+
}
77+
78+
impl<EF> DeviceAuthorizationResponse<EF>
79+
where
80+
EF: ExtraDeviceAuthorizationFields,
81+
{
82+
/// The device verification code.
83+
pub fn device_code(&self) -> &DeviceCode {
84+
&self.device_code
85+
}
86+
87+
/// The end-user verification code.
88+
pub fn user_code(&self) -> &UserCode {
89+
&self.user_code
90+
}
91+
92+
/// The end-user verification URI on the authorization The URI should be
93+
/// short and easy to remember as end users will be asked to manually type
94+
/// it into their user agent.
95+
pub fn verification_uri(&self) -> &EndUserVerificationUrl {
96+
&self.verification_uri
97+
}
98+
99+
/// A verification URI that includes the "user_code" (or other information
100+
/// with the same function as the "user_code"), which is designed for
101+
/// non-textual transmission.
102+
pub fn verification_uri_complete(&self) -> Option<&VerificationUriComplete> {
103+
self.verification_uri_complete.as_ref()
104+
}
105+
106+
/// The lifetime in seconds of the "device_code" and "user_code".
107+
pub fn expires_in(&self) -> Duration {
108+
Duration::from_secs(self.expires_in)
109+
}
110+
111+
/// The minimum amount of time in seconds that the client SHOULD wait
112+
/// between polling requests to the token endpoint. If no value is
113+
/// provided, clients MUST use 5 as the default.
114+
pub fn interval(&self) -> Duration {
115+
Duration::from_secs(self.interval)
116+
}
117+
118+
/// Any extra fields returned on the response.
119+
pub fn extra_fields(&self) -> &EF {
120+
&self.extra_fields
121+
}
122+
}
123+
124+
///
125+
/// Standard implementation of DeviceAuthorizationResponse which throws away
126+
/// extra received response fields.
127+
///
128+
pub type StandardDeviceAuthorizationResponse =
129+
DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields>;
130+
131+
///
132+
/// Basic access token error types.
133+
///
134+
/// These error types are defined in
135+
/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2) and
136+
/// [Section 3.5 of RFC 6749](https://tools.ietf.org/html/rfc8628#section-3.5)
137+
///
138+
#[derive(Clone, PartialEq)]
139+
pub enum DeviceCodeErrorResponseType {
140+
///
141+
/// The authorization request is still pending as the end user hasn't
142+
/// yet completed the user-interaction steps. The client SHOULD repeat the
143+
/// access token request to the token endpoint. Before each new request,
144+
/// the client MUST wait at least the number of seconds specified by the
145+
/// "interval" parameter of the device authorization response, or 5 seconds
146+
/// if none was provided, and respect any increase in the polling interval
147+
/// required by the "slow_down" error.
148+
///
149+
AuthorizationPending,
150+
///
151+
/// A variant of "authorization_pending", the authorization request is
152+
/// still pending and polling should continue, but the interval MUST be
153+
/// increased by 5 seconds for this and all subsequent requests.
154+
SlowDown,
155+
///
156+
/// The authorization request was denied.
157+
///
158+
AccessDenied,
159+
///
160+
/// The "device_code" has expired, and the device authorization session has
161+
/// concluded. The client MAY commence a new device authorization request
162+
/// but SHOULD wait for user interaction before restarting to avoid
163+
/// unnecessary polling.
164+
ExpiredToken,
165+
///
166+
/// A Basic response type
167+
///
168+
Basic(BasicErrorResponseType),
169+
}
170+
impl DeviceCodeErrorResponseType {
171+
fn from_str(s: &str) -> Self {
172+
match BasicErrorResponseType::from_str(s) {
173+
BasicErrorResponseType::Extension(ext) => match ext.as_str() {
174+
"authorization_pending" => DeviceCodeErrorResponseType::AuthorizationPending,
175+
"slow_down" => DeviceCodeErrorResponseType::SlowDown,
176+
"access_denied" => DeviceCodeErrorResponseType::AccessDenied,
177+
"expired_token" => DeviceCodeErrorResponseType::ExpiredToken,
178+
_ => DeviceCodeErrorResponseType::Basic(BasicErrorResponseType::Extension(ext)),
179+
},
180+
basic => DeviceCodeErrorResponseType::Basic(basic),
181+
}
182+
}
183+
}
184+
impl AsRef<str> for DeviceCodeErrorResponseType {
185+
fn as_ref(&self) -> &str {
186+
match self {
187+
DeviceCodeErrorResponseType::AuthorizationPending => "authorization_pending",
188+
DeviceCodeErrorResponseType::SlowDown => "slow_down",
189+
DeviceCodeErrorResponseType::AccessDenied => "access_denied",
190+
DeviceCodeErrorResponseType::ExpiredToken => "expired_token",
191+
DeviceCodeErrorResponseType::Basic(basic) => basic.as_ref(),
192+
}
193+
}
194+
}
195+
impl<'de> serde::Deserialize<'de> for DeviceCodeErrorResponseType {
196+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
197+
where
198+
D: serde::de::Deserializer<'de>,
199+
{
200+
let variant_str = String::deserialize(deserializer)?;
201+
Ok(Self::from_str(&variant_str))
202+
}
203+
}
204+
impl serde::ser::Serialize for DeviceCodeErrorResponseType {
205+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206+
where
207+
S: serde::ser::Serializer,
208+
{
209+
serializer.serialize_str(self.as_ref())
210+
}
211+
}
212+
impl ErrorResponseType for DeviceCodeErrorResponseType {}
213+
impl Debug for DeviceCodeErrorResponseType {
214+
fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
215+
Display::fmt(self, f)
216+
}
217+
}
218+
219+
impl Display for DeviceCodeErrorResponseType {
220+
fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
221+
write!(f, "{}", self.as_ref())
222+
}
223+
}
224+
225+
///
226+
/// Error response specialization for device code OAuth2 implementation.
227+
///
228+
pub type DeviceCodeErrorResponse = StandardErrorResponse<DeviceCodeErrorResponseType>;
229+
230+
pub(crate) enum DeviceAccessTokenPollResult<TR, RE, TE, TT>
231+
where
232+
TE: ErrorResponse + 'static,
233+
TR: TokenResponse<TT>,
234+
TT: TokenType,
235+
RE: Error + 'static,
236+
{
237+
ContinueWithNewPollInterval(Duration),
238+
Done(Result<TR, RequestTokenError<RE, TE>>, PhantomData<TT>),
239+
}

0 commit comments

Comments
 (0)