Skip to content

Commit e658019

Browse files
marijnvdwerfMarenz
authored andcommitted
Move ChatGPT OAuth callback onto API server
1 parent 4b70fef commit e658019

File tree

2 files changed

+68
-54
lines changed

2 files changed

+68
-54
lines changed

src/api/providers.rs

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ use super::state::ApiState;
33
use anyhow::Context as _;
44
use axum::Json;
55
use axum::extract::{Query, State};
6-
use axum::http::StatusCode;
7-
use axum::routing::get;
6+
use axum::http::{HeaderMap, StatusCode};
87
use axum::response::Html;
8+
use reqwest::Url;
99
use rig::agent::AgentBuilder;
1010
use rig::completion::{CompletionModel as _, Prompt as _};
1111
use serde::{Deserialize, Serialize};
@@ -14,14 +14,10 @@ use std::sync::{Arc, LazyLock};
1414
use tokio::sync::RwLock;
1515

1616
const OPENAI_BROWSER_OAUTH_SESSION_TTL_SECS: i64 = 15 * 60;
17-
const OPENAI_BROWSER_OAUTH_CALLBACK_BIND: &str = "127.0.0.1:1455";
18-
const OPENAI_BROWSER_OAUTH_REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
17+
const OPENAI_BROWSER_OAUTH_REDIRECT_PATH: &str = "/providers/openai/oauth/browser/callback";
1918

2019
static OPENAI_BROWSER_OAUTH_SESSIONS: LazyLock<RwLock<HashMap<String, BrowserOAuthSession>>> =
2120
LazyLock::new(|| RwLock::new(HashMap::new()));
22-
static OPENAI_BROWSER_OAUTH_CALLBACK_SERVER: LazyLock<
23-
RwLock<Option<BrowserOAuthCallbackServer>>,
24-
> = LazyLock::new(|| RwLock::new(None));
2521

2622
#[derive(Clone, Debug)]
2723
struct BrowserOAuthSession {
@@ -39,10 +35,6 @@ enum BrowserOAuthSessionStatus {
3935
Failed(String),
4036
}
4137

42-
struct BrowserOAuthCallbackServer {
43-
join_handle: tokio::task::JoinHandle<()>,
44-
}
45-
4638
#[derive(Serialize)]
4739
pub(super) struct ProviderStatus {
4840
anthropic: bool,
@@ -367,43 +359,64 @@ async fn prune_expired_browser_oauth_sessions() {
367359
sessions.retain(|_, session| session.created_at >= cutoff);
368360
}
369361

370-
async fn ensure_openai_browser_oauth_callback_server(state: Arc<ApiState>) -> anyhow::Result<()> {
371-
let mut callback_server = OPENAI_BROWSER_OAUTH_CALLBACK_SERVER.write().await;
372-
if let Some(existing) = callback_server.as_ref() {
373-
if existing.join_handle.is_finished() {
374-
*callback_server = None;
375-
} else {
376-
return Ok(());
362+
fn resolve_browser_oauth_redirect_uri(headers: &HeaderMap) -> Option<String> {
363+
if let Some(origin) = header_value(headers, axum::http::header::ORIGIN.as_str()) {
364+
if let Ok(origin_url) = Url::parse(origin) {
365+
let origin = origin_url.origin().ascii_serialization();
366+
if origin != "null" {
367+
return Some(format!("{origin}{OPENAI_BROWSER_OAUTH_REDIRECT_PATH}"));
368+
}
377369
}
378370
}
379371

380-
let listener = tokio::net::TcpListener::bind(OPENAI_BROWSER_OAUTH_CALLBACK_BIND)
381-
.await
382-
.with_context(|| {
383-
format!(
384-
"failed to bind local OAuth callback listener on {}",
385-
OPENAI_BROWSER_OAUTH_CALLBACK_BIND
386-
)
387-
})?;
388-
389-
let app = axum::Router::new()
390-
.route("/auth/callback", get(openai_browser_oauth_callback))
391-
.with_state(state);
392-
393-
let bind = OPENAI_BROWSER_OAUTH_CALLBACK_BIND.to_string();
394-
let join_handle = tokio::spawn(async move {
395-
tracing::info!(bind = %bind, "OpenAI browser OAuth callback listener started");
396-
if let Err(error) = axum::serve(listener, app).await {
397-
tracing::error!(
398-
%error,
399-
bind = %bind,
400-
"OpenAI browser OAuth callback listener stopped"
401-
);
402-
}
403-
});
372+
if let (Some(proto), Some(host)) = (
373+
header_value(headers, "x-forwarded-proto"),
374+
header_value(headers, "x-forwarded-host"),
375+
) {
376+
let proto = first_header_value(proto);
377+
let host = normalize_host(first_header_value(host));
378+
return Some(format!(
379+
"{proto}://{host}{OPENAI_BROWSER_OAUTH_REDIRECT_PATH}"
380+
));
381+
}
382+
383+
if let Some(host) = header_value(headers, "host") {
384+
let host = normalize_host(host);
385+
let scheme = if is_local_host(&host) { "http" } else { "https" };
386+
return Some(format!(
387+
"{scheme}://{host}{OPENAI_BROWSER_OAUTH_REDIRECT_PATH}"
388+
));
389+
}
390+
391+
None
392+
}
393+
394+
fn header_value(headers: &HeaderMap, name: impl AsRef<str>) -> Option<&str> {
395+
headers.get(name.as_ref()).and_then(|value| value.to_str().ok())
396+
}
397+
398+
fn first_header_value(value: &str) -> &str {
399+
value.split(',').next().map(str::trim).unwrap_or(value)
400+
}
401+
402+
fn normalize_host(host: &str) -> String {
403+
let host = host.trim();
404+
let colon_count = host.matches(':').count();
405+
if colon_count > 1 && !host.starts_with('[') {
406+
format!("[{host}]")
407+
} else {
408+
host.to_string()
409+
}
410+
}
404411

405-
*callback_server = Some(BrowserOAuthCallbackServer { join_handle });
406-
Ok(())
412+
fn is_local_host(host: &str) -> bool {
413+
let host = host
414+
.trim_start_matches('[')
415+
.trim_end_matches(']')
416+
.split(':')
417+
.next()
418+
.unwrap_or(host);
419+
matches!(host, "localhost" | "127.0.0.1" | "::1")
407420
}
408421

409422
fn browser_oauth_success_html() -> String {
@@ -598,7 +611,7 @@ pub(super) async fn get_providers(
598611
}
599612

600613
pub(super) async fn start_openai_browser_oauth(
601-
State(state): State<Arc<ApiState>>,
614+
headers: HeaderMap,
602615
Json(request): Json<OpenAiOAuthBrowserStartRequest>,
603616
) -> Result<Json<OpenAiOAuthBrowserStartResponse>, StatusCode> {
604617
if request.model.trim().is_empty() {
@@ -621,28 +634,25 @@ pub(super) async fn start_openai_browser_oauth(
621634
}));
622635
};
623636

624-
if let Err(error) = ensure_openai_browser_oauth_callback_server(state.clone()).await {
637+
let Some(redirect_uri) = resolve_browser_oauth_redirect_uri(&headers) else {
625638
return Ok(Json(OpenAiOAuthBrowserStartResponse {
626639
success: false,
627-
message: format!(
628-
"Failed to start local OAuth callback listener on {}: {}",
629-
OPENAI_BROWSER_OAUTH_CALLBACK_BIND, error
630-
),
640+
message: "Unable to determine OAuth callback URL. Check your Host/Origin headers."
641+
.to_string(),
631642
authorization_url: None,
632643
state: None,
633644
}));
634-
}
645+
};
635646

636647
prune_expired_browser_oauth_sessions().await;
637-
let browser_authorization =
638-
crate::openai_auth::start_browser_authorization(OPENAI_BROWSER_OAUTH_REDIRECT_URI);
648+
let browser_authorization = crate::openai_auth::start_browser_authorization(&redirect_uri);
639649
let state_key = browser_authorization.state.clone();
640650

641651
OPENAI_BROWSER_OAUTH_SESSIONS.write().await.insert(
642652
state_key.clone(),
643653
BrowserOAuthSession {
644654
pkce_verifier: browser_authorization.pkce_verifier,
645-
redirect_uri: OPENAI_BROWSER_OAUTH_REDIRECT_URI.to_string(),
655+
redirect_uri,
646656
model: chatgpt_model,
647657
created_at: chrono::Utc::now().timestamp(),
648658
status: BrowserOAuthSessionStatus::Pending,

src/api/server.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ pub async fn start_http_server(
131131
"/providers/openai/oauth/browser/status",
132132
get(providers::openai_browser_oauth_status),
133133
)
134+
.route(
135+
"/providers/openai/oauth/browser/callback",
136+
get(providers::openai_browser_oauth_callback),
137+
)
134138
.route("/providers/test", post(providers::test_provider_model))
135139
.route("/providers/{provider}", delete(providers::delete_provider))
136140
.route("/models", get(models::get_models))

0 commit comments

Comments
 (0)