|
10 | 10 | import urllib.parse |
11 | 11 | import requests |
12 | 12 | from pathlib import Path |
13 | | -from typing import Optional, Dict, Any, Tuple |
| 13 | +from typing import Optional, Dict, Any, Tuple, TYPE_CHECKING |
14 | 14 | from datetime import datetime |
15 | 15 | from pathlib import Path |
16 | | -from src.constants import CLIENT_ID, OAUTH_HOST, AUTH_TIMEOUT_SECONDS, ROOT_DIR |
17 | | -from src.config.app_config import app_config |
| 16 | +from src.config.config import CLIENT_ID, OAUTH_HOST, AUTH_TIMEOUT_SECONDS, ROOT_DIR |
| 17 | +from src.config.app_config import app_config, AuthMethod |
| 18 | + |
| 19 | +if TYPE_CHECKING: |
| 20 | + from auth_provider import SingleStoreOAuthProvider |
18 | 21 |
|
19 | 22 | # Scopes that are always required |
20 | 23 | ALWAYS_PRESENT_SCOPES = ["openid", "offline", "offline_access"] |
@@ -218,8 +221,9 @@ def refresh_token(token_set: TokenSet, client_id: str = CLIENT_ID) -> Optional[T |
218 | 221 | # Create new token set |
219 | 222 | new_token_set = TokenSet(token_data) |
220 | 223 |
|
221 | | - # Save updated credentials |
222 | | - save_credentials(new_token_set) |
| 224 | + # Only save to file in stdio mode |
| 225 | + if app_config.server_mode == "stdio": |
| 226 | + save_credentials(new_token_set) |
223 | 227 |
|
224 | 228 | return new_token_set |
225 | 229 |
|
@@ -354,58 +358,104 @@ def __init__(self, *args, **kwargs): |
354 | 358 | # Create token set |
355 | 359 | token_set = TokenSet(token_response) |
356 | 360 |
|
357 | | - # Save credentials |
358 | | - save_credentials(token_set) |
| 361 | + # Only save credentials to file in stdio mode |
| 362 | + if app_config.server_mode == "stdio": |
| 363 | + save_credentials(token_set) |
359 | 364 |
|
360 | 365 | return True, token_set |
361 | 366 |
|
362 | 367 | except Exception as e: |
363 | 368 | print(f"Authentication failed: {e}") |
364 | 369 | return False, None |
365 | 370 |
|
366 | | -def get_authentication_token(client_id: Optional[str] = None) -> Optional[str]: |
| 371 | +def get_authentication_token(client_id: Optional[str] = None, http_auth_header: Optional[str] = None) -> Optional[str]: |
367 | 372 | """ |
368 | | - Get authentication token from environment or credentials file. |
369 | | - If no valid token is available, prompt for authentication. |
| 373 | + Get authentication token from various sources based on server mode. |
| 374 | + For HTTP mode, prioritizes auth header, then app_config, then browser auth. |
| 375 | + For stdio mode, uses app_config, credentials file, then browser auth. |
370 | 376 | |
371 | 377 | Args: |
372 | 378 | client_id: Optional client ID to use for authentication |
| 379 | + http_auth_header: Optional HTTP Authorization header (for HTTP mode) |
373 | 380 | |
374 | 381 | Returns: |
375 | 382 | JWT token or API key if available, None otherwise |
376 | 383 | """ |
377 | | - # First check for API key in environment |
| 384 | + server_mode = app_config.server_mode # "stdio" or "http" |
| 385 | + |
| 386 | + print(f"Server mode: {server_mode}") |
| 387 | + print(f"Client ID: {client_id}") |
| 388 | + print(f"HTTP Authorization header: {http_auth_header}") |
| 389 | + print(f"App config auth token: {app_config.get_auth_token()}") |
| 390 | + |
| 391 | + # For HTTP mode, first check the Authorization header |
| 392 | + if server_mode == "http" and http_auth_header: |
| 393 | + print("Using token from HTTP Authorization header") |
| 394 | + if http_auth_header.startswith("Bearer "): |
| 395 | + token = http_auth_header[7:] |
| 396 | + print("Using token from HTTP Authorization header") |
| 397 | + app_config.set_auth_token(token, AuthMethod.JWT_TOKEN) |
| 398 | + return token |
| 399 | + |
| 400 | + # Next, check for existing token in app_config |
378 | 401 | api_key = app_config.get_auth_token() |
379 | | - print(f"API key from environment: {api_key}") |
| 402 | + auth_method = app_config.get_auth_method() |
| 403 | + |
380 | 404 | if api_key: |
| 405 | + print(f"Using existing authentication token (type: {auth_method.name})") |
381 | 406 | return api_key |
382 | 407 |
|
383 | | - # Then check for saved credentials |
384 | | - credentials = load_credentials() |
385 | | - if credentials and "token_set" in credentials: |
386 | | - token_set = TokenSet(credentials["token_set"]) |
387 | | - |
388 | | - # If token is expired, try to refresh it |
389 | | - if token_set.is_expired() and token_set.refresh_token: |
390 | | - print("Access token expired, refreshing...") |
391 | | - refreshed_token_set = refresh_token(token_set, client_id or CLIENT_ID) |
392 | | - if refreshed_token_set: |
393 | | - token_set = refreshed_token_set |
394 | | - else: |
395 | | - print("Token refresh failed, proceeding to re-authentication") |
396 | | - |
397 | | - # If we have a valid token, use it |
398 | | - if not token_set.is_expired() and token_set.access_token: |
399 | | - print("Using saved OAuth token.") |
400 | | - return token_set.access_token |
| 408 | + # For stdio mode, check saved credentials file |
| 409 | + if server_mode == "stdio": |
| 410 | + credentials = load_credentials() |
| 411 | + if credentials and "token_set" in credentials: |
| 412 | + token_set = TokenSet(credentials["token_set"]) |
| 413 | + |
| 414 | + # If token is expired, try to refresh it |
| 415 | + if token_set.is_expired() and token_set.refresh_token: |
| 416 | + print("Access token expired, refreshing...") |
| 417 | + refreshed_token_set = refresh_token(token_set, client_id or CLIENT_ID) |
| 418 | + if refreshed_token_set: |
| 419 | + token_set = refreshed_token_set |
| 420 | + # Update app config with the refreshed token |
| 421 | + app_config.set_auth_token(token_set.access_token, AuthMethod.OAUTH) |
| 422 | + else: |
| 423 | + print("Token refresh failed, proceeding to re-authentication") |
| 424 | + |
| 425 | + # If we have a valid token, use it |
| 426 | + if not token_set.is_expired() and token_set.access_token: |
| 427 | + print("Using saved OAuth token.") |
| 428 | + app_config.set_auth_token(token_set.access_token, AuthMethod.OAUTH) |
| 429 | + return token_set.access_token |
401 | 430 |
|
402 | | - # If no valid credentials, authenticate |
| 431 | + # If no valid credentials found, launch browser authentication |
403 | 432 | print("No API key or valid authentication token found.") |
404 | 433 | success, token_set = authenticate(client_id) |
405 | 434 |
|
406 | 435 | if success and token_set and token_set.access_token: |
407 | 436 | print("Authentication successful!") |
| 437 | + app_config.set_auth_token(token_set.access_token, AuthMethod.OAUTH) |
| 438 | + |
| 439 | + # Only save to credentials file in stdio mode |
| 440 | + # In HTTP mode, we just keep it in memory (app_config) |
| 441 | + if server_mode == "stdio" and token_set: |
| 442 | + save_credentials(token_set) |
| 443 | + |
408 | 444 | return token_set.access_token |
409 | 445 | else: |
410 | 446 | print("Authentication failed. Please try again or provide an API key.") |
411 | 447 | return None |
| 448 | + |
| 449 | +def get_oauth_provider() -> Optional["SingleStoreOAuthProvider"]: |
| 450 | + """ |
| 451 | + Get the singleton instance of the OAuth provider. |
| 452 | + |
| 453 | + Returns: |
| 454 | + The OAuth provider instance |
| 455 | + """ |
| 456 | + try: |
| 457 | + from src.auth.oauth_routes import oauth_provider |
| 458 | + return oauth_provider |
| 459 | + except ImportError: |
| 460 | + # Handle case where oauth_routes hasn't been initialized yet |
| 461 | + return None |
0 commit comments