diff --git a/pyproject.toml b/pyproject.toml index e05e51f..394f246 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,8 @@ dependencies = [ "snakemake-interface-common>=1.12.0", "wrapt>=1.15.0", "reretry>=0.11.8", - "throttler>=1.2.2" + "throttler>=1.2.2", + "humanfriendly" ] [[project.authors]] diff --git a/snakemake_interface_storage_plugins/io.py b/snakemake_interface_storage_plugins/io.py index ed35301..05bce7b 100644 --- a/snakemake_interface_storage_plugins/io.py +++ b/snakemake_interface_storage_plugins/io.py @@ -6,6 +6,7 @@ from abc import abstractmethod import re from typing import Dict +from typing import Optional WILDCARD_REGEX = re.compile( @@ -26,7 +27,7 @@ ) -def get_constant_prefix(pattern: str, strip_incomplete_parts: bool = False): +def get_constant_prefix(pattern: str, strip_incomplete_parts: bool = False) -> str: """Return constant prefix of a pattern, removing everything from the first wildcard on. @@ -53,23 +54,31 @@ def get_constant_prefix(pattern: str, strip_incomplete_parts: bool = False): class Mtime: __slots__ = ["_local", "_local_target", "_storage"] + _local: Optional[float] + _local_target: Optional[float] + _storage: Optional[float] - def __init__(self, local=None, local_target=None, storage=None): + def __init__( + self, + local: Optional[float] = None, + local_target: Optional[float] = None, + storage: Optional[float] = None, + ): self._local = local self._local_target = local_target self._storage = storage - def local_or_storage(self, follow_symlinks=False): + def local_or_storage(self, follow_symlinks: bool = False) -> Optional[float]: if self._storage is not None: return self._storage return self.local(follow_symlinks=follow_symlinks) def storage( self, - ): + ) -> Optional[float]: return self._storage - def local(self, follow_symlinks=False): + def local(self, follow_symlinks: bool = False) -> Optional[float]: if follow_symlinks and self._local_target is not None: return self._local_target return self._local diff --git a/snakemake_interface_storage_plugins/py.typed b/snakemake_interface_storage_plugins/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/snakemake_interface_storage_plugins/registry/__init__.py b/snakemake_interface_storage_plugins/registry/__init__.py index c5e4fb5..018799c 100644 --- a/snakemake_interface_storage_plugins/registry/__init__.py +++ b/snakemake_interface_storage_plugins/registry/__init__.py @@ -20,14 +20,14 @@ from snakemake_interface_storage_plugins.storage_provider import StorageProviderBase -class StoragePluginRegistry(PluginRegistryBase): +class StoragePluginRegistry(PluginRegistryBase[Plugin]): """This class is a singleton that holds all registered executor plugins.""" def get_registered_read_write_plugins(self) -> List[str]: return [ plugin.name for plugin in self.plugins.values() - if plugin.storage_provider.is_read_write + if plugin.storage_provider.is_read_write is True ] @property diff --git a/snakemake_interface_storage_plugins/registry/plugin.py b/snakemake_interface_storage_plugins/registry/plugin.py index 5c9fb39..d69819f 100644 --- a/snakemake_interface_storage_plugins/registry/plugin.py +++ b/snakemake_interface_storage_plugins/registry/plugin.py @@ -4,7 +4,7 @@ __license__ = "MIT" from dataclasses import dataclass -from typing import Optional, Type +from typing import Optional, Type, TYPE_CHECKING from snakemake_interface_storage_plugins.settings import ( StorageProviderSettingsBase, ) @@ -18,10 +18,15 @@ ) +if TYPE_CHECKING: + from snakemake_interface_storage_plugins.storage_provider import StorageProviderBase + from snakemake_interface_storage_plugins.storage_object import StorageObjectBase + + @dataclass -class Plugin(PluginBase): - storage_provider: object - storage_object: object +class Plugin(PluginBase[StorageProviderSettingsBase]): + storage_provider: Type["StorageProviderBase"] + storage_object: Type["StorageObjectBase"] _storage_settings_cls: Optional[Type[StorageProviderSettingsBase]] _name: str @@ -30,18 +35,19 @@ def support_tagged_values(self) -> bool: return True @property - def name(self): + def name(self) -> str: return self._name @property - def cli_prefix(self): + def cli_prefix(self) -> str: return "storage-" + self.name.replace(common.storage_plugin_module_prefix, "") @property - def settings_cls(self): + def settings_cls(self) -> Optional[Type[StorageProviderSettingsBase]]: return self._storage_settings_cls - def is_read_write(self): + @property + def is_read_write(self) -> bool: return issubclass(self.storage_object, StorageObjectWrite) and issubclass( self.storage_object, StorageObjectRead ) diff --git a/snakemake_interface_storage_plugins/storage_object.py b/snakemake_interface_storage_plugins/storage_object.py index 48a22c2..4df4e09 100644 --- a/snakemake_interface_storage_plugins/storage_object.py +++ b/snakemake_interface_storage_plugins/storage_object.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from pathlib import Path import shutil -from typing import Iterable, Optional +from typing import Iterable, Optional, AsyncContextManager, Dict, TypeVar, Generic from wrapt import ObjectProxy from reretry import retry @@ -34,48 +34,56 @@ class StaticStorageObjectProxy(ObjectProxy): """ - def exists(self): + def exists(self) -> bool: return True def mtime(self) -> float: return float("-inf") - def is_newer(self, time): + def is_newer(self, time: float) -> bool: return False - def __copy__(self): + def __copy__(self) -> "StaticStorageObjectProxy": copied_wrapped = copy.copy(self.__wrapped__) return type(self)(copied_wrapped) - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Dict) -> "StaticStorageObjectProxy": copied_wrapped = copy.deepcopy(self.__wrapped__, memo) return type(self)(copied_wrapped) -class StorageObjectBase(ABC): +TStorageProviderBase = TypeVar("TStorageProviderBase", bound=StorageProviderBase) + + +class StorageObjectBase(ABC, Generic[TStorageProviderBase]): """This is an abstract class to be used to derive storage object classes for different cloud storage providers. For example, there could be classes for interacting with Amazon AWS S3 and Google Cloud Storage, both derived from this common base class. """ + query: str + keep_local: bool + retrieve: bool + provider: TStorageProviderBase + print_query: str + _overwrite_local_path: Optional[Path] = None + _is_ondemand_eligible: bool = False def __init__( self, query: str, keep_local: bool, retrieve: bool, - provider: StorageProviderBase, - ): - self.query: str = query - self.keep_local: bool = keep_local - self.retrieve: bool = retrieve - self.provider: StorageProviderBase = provider - self.print_query: str = self.provider.safe_print(self.query) - self._overwrite_local_path: Optional[Path] = None - self._is_ondemand_eligible: bool = False + provider: TStorageProviderBase, + ) -> None: + self.query = query + self.keep_local = keep_local + self.retrieve = retrieve + self.provider = provider + self.print_query = self.provider.safe_print(self.query) self.__post_init__() - def __post_init__(self): # noqa B027 + def __post_init__(self) -> None: # noqa B027 pass @property @@ -83,7 +91,7 @@ def is_ondemand_eligible(self) -> bool: return self._is_ondemand_eligible and not self.keep_local @is_ondemand_eligible.setter - def is_ondemand_eligible(self, value: bool): + def is_ondemand_eligible(self, value: bool) -> None: self._is_ondemand_eligible = value def set_local_path(self, path: Path) -> None: @@ -92,7 +100,7 @@ def set_local_path(self, path: Path) -> None: def is_valid_query(self) -> bool: """Return True is the query is valid for this storage provider.""" - return self.provider.is_valid_query(self.query) + return bool(self.provider.is_valid_query(self.query)) def local_path(self) -> Path: """Return the local path that would represent the query.""" @@ -116,13 +124,13 @@ def local_suffix(self) -> str: # part and any optional parameters if that does not hamper the uniqueness. ... - def _rate_limiter(self, operation: Operation): + def _rate_limiter(self, operation: Operation) -> AsyncContextManager: return self.provider.rate_limiter(self.query, operation) class StorageObjectRead(StorageObjectBase): @abstractmethod - async def inventory(self, cache: IOCacheStorageInterface): + async def inventory(self, cache: IOCacheStorageInterface) -> None: """From this file, try to find as much existence and modification date information as possible. """ @@ -134,7 +142,7 @@ async def inventory(self, cache: IOCacheStorageInterface): def get_inventory_parent(self) -> Optional[str]: ... @abstractmethod - def cleanup(self): + def cleanup(self) -> None: """Perform local cleanup of any remainders of the storage object.""" ... @@ -157,7 +165,7 @@ def local_footprint(self) -> int: return self.size() @abstractmethod - def retrieve_object(self): + def retrieve_object(self) -> None: """Ensure that the object is accessible locally under self.local_path() Optionally, this can make use of the attribute self.is_ondemand_eligible, @@ -194,7 +202,7 @@ async def managed_exists(self) -> bool: except Exception as e: raise WorkflowError(f"Failed to check existence of {self.print_query}", e) - async def managed_retrieve(self): + async def managed_retrieve(self) -> None: await self.wait_for_free_space() try: self.local_path().parent.mkdir(parents=True, exist_ok=True) @@ -223,7 +231,7 @@ async def managed_local_footprint(self) -> int: e, ) - async def wait_for_free_space(self): + async def wait_for_free_space(self) -> None: """Wait for free space on the disk.""" size = await self.managed_local_footprint() disk_free = get_disk_free(self.local_path()) @@ -253,19 +261,19 @@ async def wait_for_free_space(self): raise WorkflowError( f"Cannot store {self.local_path()} " f"({format_size(size)} > {format_size(disk_free)}), " - f"waited {format_timespan(self.provider.wait_for_free_local_storage)} " + f"waited {format_timespan(self.provider.wait_for_free_local_storage or 0)} " "for more space." ) class StorageObjectWrite(StorageObjectBase): @abstractmethod - def store_object(self): ... + def store_object(self) -> None: ... @abstractmethod - def remove(self): ... + def remove(self) -> None: ... - async def managed_remove(self): + async def managed_remove(self) -> None: try: async with self._rate_limiter(Operation.REMOVE): self.remove() @@ -274,7 +282,7 @@ async def managed_remove(self): f"Failed to remove storage object {self.print_query}", e ) - async def managed_store(self): + async def managed_store(self) -> None: try: async with self._rate_limiter(Operation.STORE): self.store_object() @@ -295,11 +303,11 @@ def list_candidate_matches(self) -> Iterable[str]: class StorageObjectTouch(StorageObjectBase): @abstractmethod - def touch(self): + def touch(self) -> None: """Touch the object.""" ... - async def managed_touch(self): + async def managed_touch(self) -> None: try: async with self._rate_limiter(Operation.TOUCH): self.touch() diff --git a/snakemake_interface_storage_plugins/storage_provider.py b/snakemake_interface_storage_plugins/storage_provider.py index 5430006..83d6563 100644 --- a/snakemake_interface_storage_plugins/storage_provider.py +++ b/snakemake_interface_storage_plugins/storage_provider.py @@ -12,7 +12,7 @@ from pathlib import Path import sys from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Dict, Generic, List, Optional, AsyncGenerator, TypeVar from throttler import Throttler from snakemake_interface_common.exceptions import WorkflowError @@ -22,21 +22,52 @@ @dataclass class StorageQueryValidationResult: + """Result of validating a storage query string. + + Represents whether a query string is valid for a storage provider + and provides a reason if the validation fails. + + Parameters + ---------- + query : str + The query string being validated. + valid : bool + Whether the query is valid for the storage provider. + reason : str, optional + If invalid, explanation of why validation failed. + """ + query: str valid: bool reason: Optional[str] = None - def __str__(self): + def __str__(self) -> str: if self.valid: return f"query {self.query} is valid" - else: + elif self.reason: return f"query {self.query} is invalid: {self.reason}" + else: + return f"query {self.query} is invalid. Reason Unknown" - def __bool__(self): + def __bool__(self) -> bool: return self.valid class QueryType(Enum): + """Enumeration of query types for storage providers. + + Defines the context in which a query will be used within a workflow. + + Attributes + ---------- + INPUT : int + Query used for reading/retrieving data from storage. + OUTPUT : int + Query used for writing/storing data to storage. + ANY : int + Query usable for both input and output operations. + """ + INPUT = 0 OUTPUT = 1 ANY = 2 @@ -44,29 +75,76 @@ class QueryType(Enum): @dataclass class ExampleQuery: + """Example query for a storage provider with description and intended usage. + + Provides documentation and examples for users to understand how to construct + valid queries for a specific storage provider. + + Parameters + ---------- + query : str + Example query string demonstrating correct format. + description : str + Human-readable explanation of what the query does. + type : QueryType + Whether this example is for input, output, or both. + """ + query: str description: str type: QueryType -class StorageProviderBase(ABC): - """This is an abstract class to be used to derive remote provider classes. - These might be used to hold common credentials, - and are then passed to StorageObjects. +TStorageProviderSettings = TypeVar( + "TStorageProviderSettings", + bound="StorageProviderSettingsBase", +) + + +class StorageProviderBase(ABC, Generic[TStorageProviderSettings]): + """Abstract base class for Snakemake storage providers. + + Defines the interface for interacting with external storage systems + like S3, GCS, HTTP, etc. Storage providers handle authentication, + rate limiting, caching, and mapping between remote resources and + local files within Snakemake workflows. + + Parameters + ---------- + local_prefix : Path + Directory where remote files are cached locally. + settings : StorageProviderSettingsBase, optional + Provider-specific configuration options. + keep_local : bool, default=False + Whether to retain local copies after workflow completion. + retrieve : bool, default=True + Whether to automatically fetch remote files when referenced. + is_default : bool, default=False + Whether this provider is the default for its protocol. """ + # Class attributes with type hints + local_prefix: Path + logger: Logger + wait_for_free_local_storage: Optional[int] + settings: Optional[TStorageProviderSettings] + keep_local: bool + retrieve: bool + is_default: bool + _rate_limiters: Dict[Any, Throttler] + def __init__( self, local_prefix: Path, logger: Logger, wait_for_free_local_storage: Optional[int] = None, - settings: Optional[StorageProviderSettingsBase] = None, - keep_local=False, - retrieve=True, - is_default=False, + settings: Optional[TStorageProviderSettings] = None, + keep_local: bool = False, + retrieve: bool = True, + is_default: bool = False, ): self.logger: Logger = logger - self.wait_for_free_local_storage: int = wait_for_free_local_storage + self.wait_for_free_local_storage = wait_for_free_local_storage try: local_prefix.mkdir(parents=True, exist_ok=True) except OSError as e: @@ -79,19 +157,30 @@ def __init__( self.retrieve = retrieve self.is_default = is_default self._rate_limiters = dict() + try: + self.local_prefix.mkdir(parents=True, exist_ok=True) + except OSError as e: + raise WorkflowError( + f"Failed to create local storage prefix {self.local_prefix}", e + ) self.__post_init__() - def __post_init__(self): # noqa B027 + def __post_init__(self) -> None: + """Hook for subclasses to perform additional initialization. + + Subclasses may override this method to perform additional setup + after the base class has been initialized. + """ pass - def rate_limiter(self, query: str, operation: Operation): + def rate_limiter(self, query: str, operation: Operation) -> Throttler: if not self.use_rate_limiter(): return self._noop_context() else: key = self.rate_limiter_key(query, operation) if key not in self._rate_limiters: max_status_checks_frac = Fraction( - self.settings.max_requests_per_second + (self.settings.max_requests_per_second if self.settings else None) or self.default_max_requests_per_second() ).limit_denominator() self._rate_limiters[key] = Throttler( @@ -101,7 +190,7 @@ def rate_limiter(self, query: str, operation: Operation): return self._rate_limiters[key] @asynccontextmanager - async def _noop_context(self): + async def _noop_context(self) -> AsyncGenerator[Any, Any]: yield @classmethod @@ -159,13 +248,17 @@ def safe_print(self, query: str) -> str: @property def is_read_write(self) -> bool: from snakemake_interface_storage_plugins.storage_object import ( - StorageObjectReadWrite, + StorageObjectRead, + StorageObjectWrite, ) - return isinstance(self.storage_object_cls, StorageObjectReadWrite) + cls = self.get_storage_object_cls() + return issubclass(cls, StorageObjectRead) and issubclass( + cls, StorageObjectWrite + ) @classmethod - def get_storage_object_cls(cls): + def get_storage_object_cls(cls) -> type: provider = sys.modules[cls.__module__] # get module of derived class return provider.StorageObject @@ -175,7 +268,7 @@ def object( keep_local: Optional[bool] = None, retrieve: Optional[bool] = None, static: bool = False, - ): + ) -> Any: from snakemake_interface_storage_plugins.storage_object import ( StaticStorageObjectProxy, )