Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 65 additions & 13 deletions dvc_webdav/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import threading
from getpass import getpass
from typing import ClassVar
from typing import ClassVar, Optional

from funcy import memoize, wrap_prop, wrap_with

from dvc.repo import Repo
from dvc.utils.objects import cached_property
from dvc_objects.fs.base import FileSystem
from dvc_webdav.bearer_auth_client import BearerAuthClient

logger = logging.getLogger(__name__)
logger = logging.getLogger("dvc")


@wrap_with(threading.Lock())
Expand All @@ -17,6 +19,16 @@ def ask_password(host, user):
return getpass(f"Enter a password for host '{host}' user '{user}':\n")


@wrap_with(threading.Lock())
@memoize
def get_bearer_auth_client(bearer_token_command: str):
logger.debug(
"Bearer token command provided, using BearerAuthClient, command: %s",
bearer_token_command,
)
return BearerAuthClient(bearer_token_command)


class WebDAVFileSystem(FileSystem): # pylint:disable=abstract-method
protocol = "webdav"
root_marker = ""
Expand All @@ -37,32 +49,72 @@ def __init__(self, **config):
"timeout": config.get("timeout", 30),
}
)
if bearer_token_command := config.get("bearer_token_command"):
client = get_bearer_auth_client(bearer_token_command)
client.save_token_cb = self._save_token
if token := config.get("token"):
client.update_token(token)
self.fs_args["http_client"] = client

def unstrip_protocol(self, path: str) -> str:
return self.fs_args["base_url"] + "/" + path

@staticmethod
def _get_kwargs_from_urls(urlpath):
def _normalize_url(url):
from urllib.parse import urlparse, urlunparse

parsed = urlparse(url)
scheme = parsed.scheme.replace("webdav", "http")
path = parsed.path.rstrip("/")
return urlunparse((scheme, parsed.netloc, path, None, None, None))

@classmethod
def _get_kwargs_from_urls(cls, urlpath):
from urllib.parse import urlparse, urlunparse

parsed = urlparse(urlpath)
scheme = parsed.scheme.replace("webdav", "http")
return {
"prefix": parsed.path.rstrip("/"),
"host": urlunparse((scheme, parsed.netloc, "", None, None, None)),
"url": urlunparse(
(
scheme,
parsed.netloc,
parsed.path.rstrip("/"),
None,
None,
None,
)
),
"url": cls._normalize_url(urlpath),
"user": parsed.username,
}

def _find_remote_name(self) -> Optional[str]:
"""Find the remote name for the current filesystem."""
repo = Repo()
base_url = self.fs_args["base_url"]
for remote_name, remote_config in repo.config["remote"].items():
remote_url = remote_config.get("url")
if not remote_url:
continue

normalized_remote_url = self._normalize_url(remote_url)
if normalized_remote_url == base_url:
return remote_name
return None

@wrap_with(threading.Lock())
def _save_token(self, token: Optional[str]) -> None:
"""Save or unset the token in the local DVC config."""
remote_name = self._find_remote_name()
if not remote_name:
logger.warning(
"Skipping token persistence - Could not find remote name to save token."
)
return

with Repo().config.edit("local") as conf:
remote_conf = conf.setdefault("remote", {}).setdefault(remote_name, {})
if token:
if remote_conf.get("token") != token:
remote_conf["token"] = token
logger.debug("Saved token for remote '%s'", remote_name)
elif "token" in remote_conf:
del remote_conf["token"]
logger.debug("Unset token for remote '%s'", remote_name)

def _prepare_credentials(self, **config):
user = config.get("user")
password = config.get("password")
Expand Down
202 changes: 202 additions & 0 deletions dvc_webdav/bearer_auth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import logging
import shlex
import subprocess
import sys
import threading
from typing import Optional, Protocol, Union

import httpx

logger = logging.getLogger("dvc")


def _log_with_thread(level: int, msg: str, *args) -> None:
"""
Universal helper to inject thread identity into logs.
Output format: [Thread-Name] Message...
"""
if logger.isEnabledFor(level):
thread_name = threading.current_thread().name
log_fmt = f"[{thread_name}] " + msg
logger.log(level, log_fmt, *args)


def execute_command(command: Union[list[str], str], timeout: int = 10) -> str:
"""Executes a command to retrieve the token."""
if isinstance(command, str):
command = shlex.split(command)

try:
result = subprocess.run( # noqa: S603
command,
shell=False,
capture_output=True,
text=True,
check=True,
timeout=timeout,
encoding="utf-8",
)
except (
FileNotFoundError,
subprocess.TimeoutExpired,
subprocess.CalledProcessError,
ValueError,
OSError,
) as e:
error_header = "\n" + "=" * 60
error_msg = (
f"{error_header}\n[CRITICAL] Bearer Token Retrieval Failed.\n"
"DVC may misinterpret this as 'File Not Found' and skip files.\n"
f"Command: {command}\n"
f"Error: {e}"
)

if isinstance(e, subprocess.CalledProcessError):
error_msg += f"\nStderr: {e.stderr.strip()}"

error_msg += f"\n{error_header}\n"

logger.critical(error_msg)
sys.stderr.write(error_msg)
sys.stderr.flush()

# Re-raise the exception so the caller knows it failed.
# DVC might catch this and swallow it, but we've done our duty to notify.
raise

token = result.stdout.strip()
if not token:
raise ValueError("Command executed successfully but returned an empty token.")
return token


class TokenSaver(Protocol):
"""Protocol defining the token persistence interface"""

def __call__(self, token: Optional[str]) -> None: ...


def safe_callback(
cb: Optional[TokenSaver], value: Optional[str], operation: str
) -> None:
"""Safely execute callback function with error handling"""
if not cb:
return

try:
cb(value)
except Exception as e: # noqa: BLE001
_log_with_thread(
logging.WARNING,
"[BearerAuthClient] Failed to %s token: %s",
operation,
e,
)


class BearerAuthClient(httpx.Client):
"""HTTPX client that adds Bearer token authentication using a command.

Args:
bearer_token_command: The command to run to get the Bearer token.
save_token_cb: Optional callback to persist the token.
token: Optional initial token to use.
**kwargs: Additional arguments to pass to the httpx.Client constructor.
"""

def __init__(
self,
bearer_token_command: str,
save_token_cb: Optional[TokenSaver] = None,
**kwargs,
):
super().__init__(**kwargs)
if (
not isinstance(bearer_token_command, str)
or not bearer_token_command.strip()
):
raise ValueError(
"[BearerAuthClient] bearer_token_command must be a non-empty string"
)
self.bearer_token_command = bearer_token_command
self.save_token_cb = save_token_cb
self._token: Optional[str] = None
self._lock = threading.Lock()

def _refresh_token_locked(self) -> None:
"""Execute token command and update state."""
_log_with_thread(
logging.DEBUG, "[BearerAuthClient] Refreshing token via command..."
)

try:
new_token = execute_command(self.bearer_token_command)
# execute_command guarantees non-empty string or raises ValueError

self._token = new_token
self.headers["Authorization"] = f"Bearer {new_token}"
safe_callback(self.save_token_cb, new_token, "save")

_log_with_thread(
logging.DEBUG, "[BearerAuthClient] Token refreshed successfully."
)
except Exception:
# Clean up state on failure
self._token = None
# Clear persisted token but don't fail the refresh operation
safe_callback(self.save_token_cb, None, "clear")
raise

def _ensure_token(self) -> None:
"""Ensure a token exists before making requests"""
if self._token:
return

with self._lock:
if not self._token:
self._refresh_token_locked()

def update_token(self, token: Optional[str]) -> None:
"""Update the token with a new one"""
if not token:
return

with self._lock:
if self._token != token:
self._token = token
self.headers["Authorization"] = f"Bearer {token}"

def request(self, *args, **kwargs) -> httpx.Response:
"""Wraps httpx.request with auto-refresh logic for 401 Unauthorized."""
self._ensure_token()
response = super().request(*args, **kwargs)

if response.status_code != 401:
return response

_log_with_thread(
logging.DEBUG, "[BearerAuthClient] Received 401. Attempting recovery."
)
sent_auth_header = response.request.headers.get("Authorization")

try:
with self._lock:
current_auth_header = self.headers.get("Authorization")
if sent_auth_header == current_auth_header:
self._refresh_token_locked()
else:
_log_with_thread(
logging.DEBUG,
"[BearerAuthClient] Token already refreshed by another thread. "
"Retrying.",
)
except Exception:
logger.exception(
"[BearerAuthClient] Recovery failed: Token refresh threw exception"
)
return response

# Retry the request with the new valid token
# We must close the old 401 response to free connections
response.close()
return super().request(*args, **kwargs)