diff --git a/.github/actions/test-python/action.yaml b/.github/actions/test-python/action.yaml index 99ebf94b2e..1d6676d821 100644 --- a/.github/actions/test-python/action.yaml +++ b/.github/actions/test-python/action.yaml @@ -25,6 +25,10 @@ inputs: description: Whether to run mypy type checking required: false default: "true" + type_check_dir: + description: The directory to run mypy type checking in + required: false + default: "." format_check: description: Whether to run formatting checks (isort, black) required: false @@ -62,7 +66,7 @@ runs: - name: Check types if: inputs.type_check == 'true' - run: uv run mypy . + run: uv run mypy ${{ inputs.type_check_dir }} shell: bash working-directory: ${{ inputs.directory }} diff --git a/.github/workflows/test.client.yaml b/.github/workflows/test.client.yaml index 510b859bb5..af7b022c2c 100644 --- a/.github/workflows/test.client.yaml +++ b/.github/workflows/test.client.yaml @@ -267,7 +267,7 @@ jobs: coverage: true coverage_module: synnax coverage_flag: client-py - type_check: false + type_check_dir: "./synnax" ts: name: Test - TypeScript diff --git a/client/py/examples/simulators/press.py b/client/py/examples/simulators/press.py index d3b0aaff4d..5e6258ca57 100644 --- a/client/py/examples/simulators/press.py +++ b/client/py/examples/simulators/press.py @@ -130,7 +130,7 @@ def _run_loop(self) -> None: with self.client.open_writer( start=sy.TimeStamp.now(), channels=[ - self.daq_time_ch.key, + self.daq_time_ch.name, "press_vlv_state", "vent_vlv_state", "press_pt", diff --git a/client/py/examples/simulators/thermal.py b/client/py/examples/simulators/thermal.py index 993006c030..9ce2898727 100644 --- a/client/py/examples/simulators/thermal.py +++ b/client/py/examples/simulators/thermal.py @@ -96,7 +96,7 @@ def _create_channels(self) -> None: client.write( now, { - self.daq_time.key: [now], + self.daq_time.name: [now], "temp_sensor": [self.AMBIENT_TEMP], "heater_state": [0], }, @@ -119,7 +119,7 @@ def _run_loop(self) -> None: ) as streamer: with self.client.open_writer( start=sy.TimeStamp.now(), - channels=[self.daq_time.key, "temp_sensor", "heater_state"], + channels=[self.daq_time.name, "temp_sensor", "heater_state"], name="Thermal Sim DAQ", ) as writer: force_overheat = False diff --git a/client/py/pyproject.toml b/client/py/pyproject.toml index 097ff957da..8bd60bce19 100644 --- a/client/py/pyproject.toml +++ b/client/py/pyproject.toml @@ -32,6 +32,7 @@ dev = [ "black>=26.1.0,<27", "isort>=7.0.0,<8", "mypy>=1.19.1,<2", + "pandas-stubs>=2.0.0", "pymodbus>=3.11.4,<4", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", @@ -46,9 +47,14 @@ packages = ["synnax", "examples"] profile = "black" [tool.mypy] -plugins = ["numpy.typing.mypy_plugin", "pydantic.mypy"] +plugins = ["pydantic.mypy"] strict = true +[[tool.mypy.overrides]] +module = "nptdms.*" +ignore_missing_imports = true +follow_imports = "skip" + [tool.pydantic-mypy] init_forbid_extra = true init_types = true @@ -62,6 +68,7 @@ markers = [ "channel: mark test as a channel test", "cli: mark test as a cli test", "control: mark test as a control test", + "deprecation: mark test as a deprecation test", "device: mark test as a device test", "ethercat: mark test as an ethercat test", "framer: mark test as a framer test", diff --git a/client/py/synnax/__init__.py b/client/py/synnax/__init__.py index a5c16e026a..71b6f117ac 100644 --- a/client/py/synnax/__init__.py +++ b/client/py/synnax/__init__.py @@ -20,7 +20,7 @@ Handle, Position, ) -from synnax.arc import Task as ArcTask +from synnax.arc import Task as _ArcTask from synnax.arc import ( Text, ) @@ -57,8 +57,8 @@ from synnax.ranger import Range from synnax.status import Status from synnax.synnax import Synnax -from synnax.task import Status as TaskStatus -from synnax.task import StatusDetails as TaskStatusDetails +from synnax.task import Status as _TaskStatus +from synnax.task import StatusDetails as _TaskStatusDetails from synnax.task import Task from synnax.telem import ( Alignment, @@ -86,13 +86,20 @@ ) from synnax.timing import Loop, Timer, sleep from synnax.user.payload import User +from synnax.util.deprecation import deprecated_getattr -SynnaxOptions = Options +_DEPRECATED: dict[str, str | tuple[str, str]] = { + "ArcTask": ("synnax.arc.Task", "_ArcTask"), + "TaskStatus": ("synnax.task.Status", "_TaskStatus"), + "TaskStatusDetails": ("synnax.task.StatusDetails", "_TaskStatusDetails"), + "SynnaxOptions": "Options", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) __all__ = [ "Alignment", "Arc", - "ArcTask", "AUTO_SPAN", "AuthError", "Authority", @@ -143,10 +150,7 @@ "Status", "Streamer", "Synnax", - "SynnaxOptions", "Task", - "TaskStatus", - "TaskStatusDetails", "Text", "Timer", "TimeRange", @@ -165,4 +169,5 @@ "ni", "opcua", "status", + "Status", ] diff --git a/client/py/synnax/access/policy/__init__.py b/client/py/synnax/access/policy/__init__.py index 653a28470c..abde9245b3 100644 --- a/client/py/synnax/access/policy/__init__.py +++ b/client/py/synnax/access/policy/__init__.py @@ -16,13 +16,17 @@ Policy, ontology_id, ) +from synnax.util.deprecation import deprecated_getattr -# Backwards compatibility -PolicyClient = Client -CREATE_ACTION = ACTION_CREATE -DELETE_ACTION = ACTION_DELETE -RETRIEVE_ACTION = ACTION_RETRIEVE -UPDATE_ACTION = ACTION_UPDATE +_DEPRECATED = { + "PolicyClient": "Client", + "CREATE_ACTION": "ACTION_CREATE", + "DELETE_ACTION": "ACTION_DELETE", + "RETRIEVE_ACTION": "ACTION_RETRIEVE", + "UPDATE_ACTION": "ACTION_UPDATE", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) __all__ = [ "Client", @@ -32,9 +36,4 @@ "ACTION_RETRIEVE", "ACTION_UPDATE", "ontology_id", - "PolicyClient", - "CREATE_ACTION", - "DELETE_ACTION", - "RETRIEVE_ACTION", - "UPDATE_ACTION", ] diff --git a/client/py/synnax/access/policy/payload.py b/client/py/synnax/access/policy/payload.py index b1fb3513d4..2fd1f00879 100644 --- a/client/py/synnax/access/policy/payload.py +++ b/client/py/synnax/access/policy/payload.py @@ -34,8 +34,13 @@ def ontology_id(key: UUID | None = None) -> ontology.ID: return ontology.ID(type="policy", key=key if key is None else str(key)) -# Backwards compatibility -CREATE_ACTION = ACTION_CREATE -DELETE_ACTION = ACTION_DELETE -RETRIEVE_ACTION = ACTION_RETRIEVE -UPDATE_ACTION = ACTION_UPDATE +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "CREATE_ACTION": "ACTION_CREATE", + "DELETE_ACTION": "ACTION_DELETE", + "RETRIEVE_ACTION": "ACTION_RETRIEVE", + "UPDATE_ACTION": "ACTION_UPDATE", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/access/role/__init__.py b/client/py/synnax/access/role/__init__.py index 915f84850b..a0cf144211 100644 --- a/client/py/synnax/access/role/__init__.py +++ b/client/py/synnax/access/role/__init__.py @@ -10,8 +10,12 @@ from synnax.access.role.client import Client from synnax.access.role.payload import ONTOLOGY_TYPE, Role, ontology_id +from synnax.util.deprecation import deprecated_getattr -# Backwards compatibility -RoleClient = Client +_DEPRECATED = { + "RoleClient": "Client", +} -__all__ = ["Role", "Client", "ONTOLOGY_TYPE", "ontology_id", "RoleClient"] +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) + +__all__ = ["Role", "Client", "ONTOLOGY_TYPE", "ontology_id"] diff --git a/client/py/synnax/access/role/client.py b/client/py/synnax/access/role/client.py index da831b1a04..a587ad299e 100644 --- a/client/py/synnax/access/role/client.py +++ b/client/py/synnax/access/role/client.py @@ -65,7 +65,7 @@ def __init__( @overload def create( self, - role: Role, + roles: Role, ) -> Role: ... @overload @@ -89,6 +89,7 @@ def retrieve(self, key: UUID) -> Role: ... @overload def retrieve( self, + *, keys: list[UUID] | None = None, limit: int | None = None, offset: int | None = None, @@ -104,7 +105,7 @@ def retrieve( internal: bool | None = None, ) -> Role | list[Role]: is_single = key is not None - if is_single: + if is_single and key is not None: keys = [key] req = _RetrieveRequest(keys=keys, limit=limit, offset=offset, internal=internal) res = send_required( diff --git a/client/py/synnax/arc/__init__.py b/client/py/synnax/arc/__init__.py index 505cf01ef1..f5186b579f 100644 --- a/client/py/synnax/arc/__init__.py +++ b/client/py/synnax/arc/__init__.py @@ -20,14 +20,18 @@ Text, ) from synnax.arc.types import Task, TaskConfig +from synnax.util.deprecation import deprecated_getattr -# Backwards compatibility -ArcTask = Task -ArcTaskConfig = TaskConfig -ArcClient = Client -ArcKey = Key -ArcMode = Mode -ArcPayload = Payload +_DEPRECATED = { + "ArcTask": "Task", + "ArcTaskConfig": "TaskConfig", + "ArcClient": "Client", + "ArcKey": "Key", + "ArcMode": "Mode", + "ArcPayload": "Payload", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) __all__ = [ "Arc", @@ -43,10 +47,4 @@ "Handle", "Position", "Text", - "ArcTask", - "ArcTaskConfig", - "ArcClient", - "ArcKey", - "ArcMode", - "ArcPayload", ] diff --git a/client/py/synnax/arc/payload.py b/client/py/synnax/arc/payload.py index 4b3a6fbd40..4fc25d5dc7 100644 --- a/client/py/synnax/arc/payload.py +++ b/client/py/synnax/arc/payload.py @@ -98,8 +98,13 @@ class Payload(BaseModel): """Visual graph representation of the program.""" -# Backwards compatibility -ARC_ONTOLOGY_TYPE = ONTOLOGY_TYPE -ArcKey = Key -ArcMode = Mode -ArcPayload = Payload +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "ARC_ONTOLOGY_TYPE": "ONTOLOGY_TYPE", + "ArcKey": "Key", + "ArcMode": "Mode", + "ArcPayload": "Payload", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/arc/types.py b/client/py/synnax/arc/types.py index a66a0642e7..1a0dbb3f19 100644 --- a/client/py/synnax/arc/types.py +++ b/client/py/synnax/arc/types.py @@ -57,6 +57,11 @@ def __init__( self.config = TaskConfig(arc_key=str(arc_key), auto_start=auto_start) -# Backwards compatibility -ArcTask = Task -ArcTaskConfig = TaskConfig +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "ArcTask": "Task", + "ArcTaskConfig": "TaskConfig", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/auth.py b/client/py/synnax/auth.py index 9cfe4ff8be..2b7baef697 100644 --- a/client/py/synnax/auth.py +++ b/client/py/synnax/auth.py @@ -69,8 +69,8 @@ def authenticate(self) -> None: self.user = res.user self.authenticated = True - def middleware(self) -> list[Middleware]: - def mw(ctx: Context, _next: Next): + def middleware(self) -> Middleware: + def mw(ctx: Context, _next: Next) -> tuple[Context, Exception | None]: if not self.authenticated: self.authenticate() @@ -86,8 +86,10 @@ def mw(ctx: Context, _next: Next): return mw - def async_middleware(self) -> list[AsyncMiddleware]: - async def mw(ctx: Context, _next: AsyncNext): + def async_middleware(self) -> AsyncMiddleware: + async def mw( + ctx: Context, _next: AsyncNext + ) -> tuple[Context, Exception | None]: if not self.authenticated: self.authenticate() @@ -112,5 +114,10 @@ def maybe_refresh_token( self.token = refresh -# Backwards compatibility -AuthenticationClient = Client +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "AuthenticationClient": "Client", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/channel/__init__.py b/client/py/synnax/channel/__init__.py index 526bbcbdc3..dd63ca5d3b 100644 --- a/client/py/synnax/channel/__init__.py +++ b/client/py/synnax/channel/__init__.py @@ -10,9 +10,12 @@ from synnax.channel.client import Channel, Client from synnax.channel.payload import ( Key, + NormalizedKeyResult, + NormalizedNameResult, Operation, Params, Payload, + has_params, normalize_params, ) from synnax.channel.retrieve import ( @@ -22,14 +25,18 @@ retrieve_required, ) from synnax.channel.writer import Writer +from synnax.util.deprecation import deprecated_getattr -# Backwards compatibility -ChannelClient = Client -ChannelKey = Key -ChannelParams = Params -ChannelPayload = Payload -ChannelRetriever = Retriever -normalize_channel_params = normalize_params +_DEPRECATED = { + "ChannelClient": "Client", + "ChannelKey": "Key", + "ChannelParams": "Params", + "ChannelPayload": "Payload", + "ChannelRetriever": "Retriever", + "normalize_channel_params": "normalize_params", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) __all__ = [ "Channel", @@ -44,10 +51,7 @@ "Retriever", "retrieve_required", "Writer", - "ChannelClient", - "ChannelKey", - "ChannelParams", - "ChannelPayload", - "ChannelRetriever", - "normalize_channel_params", + "has_params", + "NormalizedNameResult", + "NormalizedKeyResult", ] diff --git a/client/py/synnax/channel/client.py b/client/py/synnax/channel/client.py index 1e6d3e6025..f18a6438de 100644 --- a/client/py/synnax/channel/client.py +++ b/client/py/synnax/channel/client.py @@ -151,6 +151,8 @@ def rename(self, name: str) -> None: :param name: The new name for the channel. :returns: None. """ + if self.__client is None: + raise ValidationError("Cannot rename a channel that has not been created.") self.__client.rename(self.key, name) @property @@ -168,7 +170,9 @@ def __frame_client(self) -> framer.Client: def __hash__(self) -> int: return hash(self.key) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, Channel): + return NotImplemented return self.key == other.key def to_payload(self) -> Payload: @@ -217,7 +221,7 @@ def create( is_index: bool = False, leaseholder: int = 0, virtual: bool | None = None, - expression: str | None = None, + expression: str = "", operations: list[Operation] | None = None, retrieve_if_name_exists: bool = False, ) -> Channel: ... @@ -367,9 +371,12 @@ def rename( :param names: The new names for the channels. :returns: None. """ + ... def rename( - self, keys: list[Key] | tuple[Key], names: list[str] | tuple[str] + self, + keys: Key | list[Key] | tuple[Key], + names: str | list[str] | tuple[str], ) -> None: """Renames one or more channels in the cluster. diff --git a/client/py/synnax/channel/payload.py b/client/py/synnax/channel/payload.py index 59c6dbc930..598d017ac5 100644 --- a/client/py/synnax/channel/payload.py +++ b/client/py/synnax/channel/payload.py @@ -10,7 +10,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal, cast +from typing import Literal, Sequence, TypeAlias from pydantic import BaseModel @@ -19,15 +19,13 @@ from synnax.util.normalize import normalize Key = int -Params = Key | list[Key] | tuple[Key] | str | list[str] | tuple[str] - ONTOLOGY_TYPE = ontology.ID(type="channel") def ontology_id(key: Key) -> ontology.ID: """Returns the ontology ID for the Channel entity.""" - return ontology.ID(type=ONTOLOGY_TYPE.type, key=key) + return ontology.ID(type=ONTOLOGY_TYPE.type, key=str(key)) OPERATION_TYPES = Literal["min", "max", "avg", "none"] @@ -38,7 +36,7 @@ class Operation(BaseModel): type: OPERATION_TYPES reset_channel: Key = 0 - duration: TimeSpan = 0 + duration: TimeSpan = TimeSpan(0) class Payload(BaseModel): @@ -57,7 +55,7 @@ class Payload(BaseModel): expression: str | None = "" operations: list[Operation] | None = None - def __str__(self): + def __str__(self) -> str: return f"Channel(name={self.name}, key={self.key})" def __hash__(self) -> int: @@ -65,55 +63,76 @@ def __hash__(self) -> int: @dataclass -class NormalizedChannelKeyResult: +class NormalizedKeyResult: single: bool - variant: Literal["keys"] channels: list[Key] | tuple[Key] @dataclass -class NormalizedChannelNameResult: +class NormalizedNameResult: single: bool - variant: Literal["names"] channels: list[str] +Params: TypeAlias = ( + Key | str | Payload | Sequence[Key] | Sequence[str] | Sequence[Payload] +) + + def normalize_params( channels: Params, -) -> NormalizedChannelKeyResult | NormalizedChannelNameResult: +) -> NormalizedKeyResult | NormalizedNameResult: """Determine if a list of keys or names is a single key or name.""" normalized = normalize(channels) if len(normalized) == 0: - return NormalizedChannelKeyResult(single=False, variant="keys", channels=[]) + return NormalizedKeyResult(single=False, channels=[]) single = isinstance(channels, (Key, str)) - if isinstance(normalized[0], str): + first = normalized[0] + if isinstance(first, str): + str_list = [s for s in normalized if isinstance(s, str)] + if len(str_list) != len(normalized): + raise TypeError( + "channel params must be all keys or all names, got a mix of both" + ) try: - numeric_strings = [Key(s) for s in normalized] - return NormalizedChannelKeyResult( + return NormalizedKeyResult( single=single, - variant="keys", - channels=numeric_strings, + channels=[Key(s) for s in str_list], ) - except ValueError: - return NormalizedChannelNameResult( + except (ValueError, TypeError): + return NormalizedNameResult( single=single, - variant="names", - channels=cast(list[str] | tuple[str], normalized), + channels=str_list, + ) + elif isinstance(first, Payload): + payload_list = [c.key for c in normalized if isinstance(c, Payload)] + if len(payload_list) != len(normalized): + raise TypeError( + "channel params must be all keys or all names, got a mix of both" ) - elif isinstance(normalized[0], Payload): - return NormalizedChannelNameResult( - single=single, - variant="keys", - channels=[c.key for c in normalized], + return NormalizedKeyResult(single=single, channels=payload_list) + key_list = [k for k in normalized if isinstance(k, int)] + if len(key_list) != len(normalized): + raise TypeError( + "channel params must be all keys or all names, got a mix of both" ) - return NormalizedChannelKeyResult( - single=single, - variant="keys", - channels=normalized, - ) + return NormalizedKeyResult(single=single, channels=key_list) + + +def has_params(channels: Params | None) -> bool: + if channels is None: + return False + if isinstance(channels, (Key, str, Payload)): + return True + return len(channels) > 0 + + +from synnax.util.deprecation import deprecated_getattr +_DEPRECATED = { + "ChannelKey": "Key", + "ChannelParams": "Params", + "ChannelPayload": "Payload", +} -# Backwards compatibility -ChannelKey = Key -ChannelParams = Params -ChannelPayload = Payload +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/channel/retrieve.py b/client/py/synnax/channel/retrieve.py index dafb9594bf..e89c6a9d31 100644 --- a/client/py/synnax/channel/retrieve.py +++ b/client/py/synnax/channel/retrieve.py @@ -17,6 +17,7 @@ from synnax.channel.payload import ( Key, + NormalizedNameResult, Params, Payload, normalize_params, @@ -63,11 +64,14 @@ def retrieve(self, channels: Params) -> list[Payload]: normal = normalize_params(channels) if len(normal.channels) == 0: return list() - req = _Request(**{normal.variant: normal.channels}) + if isinstance(normal, NormalizedNameResult): + req = _Request(names=normal.channels) + else: + req = _Request(keys=normal.channels) return self.__exec_retrieve(req) @trace("debug") - def retrieve_one(self, param: Key | str) -> Payload | None: + def retrieve_one(self, param: Key | str) -> Payload: req = _Request() if isinstance(param, Key): req.keys = [param] @@ -100,7 +104,7 @@ def __init__( def delete(self, keys: Params) -> None: normal = normalize_params(keys) - if normal.variant == "names": + if isinstance(normal, NormalizedNameResult): matches = { ch for ch in self._channels.values() if ch.name in normal.channels } @@ -169,41 +173,47 @@ def _set_one(self, channel: Payload) -> None: @trace("debug") def retrieve(self, channels: Params) -> list[Payload]: normal = normalize_params(channels) - results = list() - to_retrieve: list[Key] | tuple[Key] | list[str] | tuple[str] = list() # type: ignore - for p in normal.channels: - ch = self._get(p) + results: list[Payload] = [] + missed: list[int] = [] + if isinstance(normal, NormalizedNameResult): + params: list[Key | str] = list(normal.channels) + else: + params = list(normal.channels) + for i, param in enumerate(params): + ch = self._get(param) if ch is None: - to_retrieve.append(p) # type: ignore + missed.append(i) else: results.extend(ch) - - if len(to_retrieve) == 0: + if not missed: return results - + if isinstance(normal, NormalizedNameResult): + to_retrieve: Params = [normal.channels[i] for i in missed] + else: + to_retrieve = [normal.channels[i] for i in missed] retrieved = self._retriever.retrieve(to_retrieve) self.set(retrieved) results.extend(retrieved) return results - def retrieve_one(self, param: Key | str) -> Payload | None: + def retrieve_one(self, param: Key | str) -> Payload: ch = self._get_one(param) if ch is not None: return ch retrieved = self._retriever.retrieve_one(param) - if retrieved is not None: - self._set_one(retrieved) + self._set_one(retrieved) return retrieved def retrieve_required(r: Retriever, channels: Params) -> list[Payload]: normal = normalize_params(channels) results = r.retrieve(channels) - not_found = list() - for p in normal.channels: - ch = next((c for c in results if c.key == p or c.name == p), None) - if ch is None: - not_found.append(p) + found: set[Key | str] + if isinstance(normal, NormalizedNameResult): + found = {c.name for c in results} + else: + found = {c.key for c in results} + not_found: list[Key | str] = [p for p in normal.channels if p not in found] if len(not_found) > 0: raise NotFoundError(f"Could not find channels: {not_found}") return results diff --git a/client/py/synnax/channel/writer.py b/client/py/synnax/channel/writer.py index 1af8b166e6..9e0a06ff28 100644 --- a/client/py/synnax/channel/writer.py +++ b/client/py/synnax/channel/writer.py @@ -13,6 +13,7 @@ from synnax.channel.payload import ( Key, + NormalizedNameResult, Params, Payload, normalize_params, @@ -66,7 +67,10 @@ def create( @trace("debug") def delete(self, channels: Params) -> None: normal = normalize_params(channels) - req = _DeleteRequest(**{normal.variant: normal.channels}) + if isinstance(normal, NormalizedNameResult): + req = _DeleteRequest(names=normal.channels) + else: + req = _DeleteRequest(keys=normal.channels) send_required(self._client, "/channel/delete", req, Empty) if self._cache is not None: self._cache.delete(normal.channels) @@ -78,4 +82,4 @@ def rename( req = _RenameRequest(keys=keys, names=names) send_required(self._client, "/channel/rename", req, Empty) if self._cache is not None: - self._cache.rename(keys, names) + self._cache.rename(list(keys), list(names)) diff --git a/client/py/synnax/cli/channel.py b/client/py/synnax/cli/channel.py index 23593bd3b4..0572bfdb3c 100644 --- a/client/py/synnax/cli/channel.py +++ b/client/py/synnax/cli/channel.py @@ -8,16 +8,16 @@ # included in the file licenses/APL.txt. import fnmatch +from typing import Any from synnax.channel import Channel -from synnax.cli.console.sugared import AskKwargs from synnax.cli.flow import Context def channel_name_table( ctx: Context, names: list[str], -): +) -> None: """Creates a table containing names of the channels. :param ctx: The current flow context. @@ -33,7 +33,7 @@ def maybe_select_channel( ctx: Context, channels: list[Channel], param: str, - **kwargs: AskKwargs[str], + **kwargs: Any, ) -> Channel | None: """Asks the user to select a channel if there are multiple channels available. @@ -52,7 +52,7 @@ def maybe_select_channel( def select_channel( ctx: Context, channels: list[Channel], - **kwargs: AskKwargs[str], + **kwargs: Any, ) -> Channel | None: """Prompts the user to select a channel from a list of channels. @@ -68,6 +68,8 @@ def select_channel( rows=[c.model_dump() for c in channels], **kwargs, ) + if i is None: + return None return channels[i] @@ -87,14 +89,17 @@ def prompt_group_channel_names( 3) A pattern to match (e.g. 'channel*, sensor*') """ ) - return group_channel_names(ctx, options, ctx.console.ask("channels").split(",")) + response = ctx.console.ask("channels") + if response is None: + return None + return group_channel_names(ctx, options, response.split(",")) def group_channel_names( ctx: Context, options: list[str], matchers: list[str], -): +) -> dict[str, list[str]] | None: """Groups channel names by matching them against a list of matchers. :param ctx: The current flow Context. diff --git a/client/py/synnax/cli/check_timing.py b/client/py/synnax/cli/check_timing.py index 761ccaa10b..bd2c448e74 100644 --- a/client/py/synnax/cli/check_timing.py +++ b/client/py/synnax/cli/check_timing.py @@ -7,6 +7,8 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. +from typing import cast + import click import matplotlib.pyplot as plt import numpy as np @@ -32,11 +34,11 @@ def check_timing(ctx: click.Context) -> None: time_channels = [ch for ch in channels if ch.data_type == sy.DataType.TIMESTAMP] if not time_channels: - ctx.console.error("No time channels found in the database") + default.context().console.error("No time channels found in the database") return # Let user select the time channel - time_channel = select_channel(default.context(), time_channels, key="name") + time_channel = select_channel(default.context(), time_channels) if time_channel is None: return @@ -46,6 +48,7 @@ def check_timing(ctx: click.Context) -> None: type_=int, default=10, ) + assert duration is not None span = sy.TimeSpan.SECOND * duration # Collect samples @@ -61,7 +64,9 @@ def collect_samples( client: sy.Synnax, time_channel: sy.channel.Key, span: sy.TimeSpan, -): +) -> tuple[ + list[sy.TimeSpan], list[sy.TimeSpan], list[sy.TimeStamp], sy.TimeStamp, sy.TimeStamp +]: # Tracks the offset between the local clock and the time channel offsets: list[sy.TimeSpan] = list() # Tracks the spacing between the samples inside individual reads @@ -77,11 +82,13 @@ def collect_samples( while now < end: now = sy.TimeStamp.now() data = streamer.read()[time_channel] - offset = sy.TimeSpan(sy.TimeStamp.now() - sy.TimeStamp(data[-1])) + last = cast(int, data[-1]) + second_last = cast(int, data[-2]) + offset = sy.TimeSpan(sy.TimeStamp.now() - sy.TimeStamp(last)) offsets.append(offset) - diff = sy.TimeSpan(data[-1] - data[-2]) + diff = sy.TimeSpan(last - second_last) diffs.append(diff) - times.extend(data) + times.extend(data) # type: ignore[arg-type] local_end = sy.TimeStamp.now() return offsets, diffs, times, local_start, local_end @@ -134,15 +141,16 @@ def create_timing_report( np.abs(offsets_array - offset_mean) > 5 * offset_std ] - bins_offset = np.concatenate( - [np.linspace(min(offsets_array), max(offsets_array), 1000)] - ) - hist_offset, bins_offset, _ = ax2.hist( - offsets_array, bins=bins_offset, alpha=0.7, color="cyan" + bins_offset_list: list[float] = np.linspace( + min(offsets_array), max(offsets_array), 1000 + ).tolist() + hist_offset, _, _ = ax2.hist( + offsets_array, bins=bins_offset_list, alpha=0.7, color="cyan" ) x_offset = np.linspace(min(offsets_array), max(offsets_array), 100) - gaussian_offset = hist_offset.max() * np.exp( + hist_offset_arr = np.asarray(hist_offset) + gaussian_offset = hist_offset_arr.max() * np.exp( -((x_offset - offset_mean) ** 2) / (2 * offset_std**2) ) ax2.plot(x_offset, gaussian_offset, "magenta", lw=2, label="Gaussian fit") @@ -168,13 +176,16 @@ def create_timing_report( rates = [sy.Rate(d) for d in diffs] avg_rate = np.mean([float(r) for r in rates]) - bins_diff = np.concatenate([np.linspace(min(diffs_array), max(diffs_array), 500)]) - hist_diff, bins_diff, _ = ax3.hist( - diffs_array, bins=bins_diff, alpha=0.7, color="cyan" + bins_diff_list: list[float] = np.linspace( + min(diffs_array), max(diffs_array), 500 + ).tolist() + hist_diff, _, _ = ax3.hist( + diffs_array, bins=bins_diff_list, alpha=0.7, color="cyan" ) x_diff = np.linspace(min(diffs_array), max(diffs_array), 100) - gaussian_diff = hist_diff.max() * np.exp( + hist_diff_arr = np.asarray(hist_diff) + gaussian_diff = hist_diff_arr.max() * np.exp( -((x_diff - diff_mean) ** 2) / (2 * diff_std**2) ) ax3.plot(x_diff, gaussian_diff, "magenta", lw=2, label="Gaussian fit") diff --git a/client/py/synnax/cli/connect.py b/client/py/synnax/cli/connect.py index c85b984459..5f1c891019 100644 --- a/client/py/synnax/cli/connect.py +++ b/client/py/synnax/cli/connect.py @@ -58,13 +58,12 @@ def connect_from_options(ctx: Context, opts: Options) -> Synnax | None: try: client = Synnax(**opts.model_dump()) except Unreachable: - return ctx.console.error( - f"Cannot reach Synnax server at {opts.host}:{opts.port}" - ) + ctx.console.error(f"Cannot reach Synnax server at {opts.host}:{opts.port}") + return None except AuthError: - return ctx.console.error("Invalid credentials") + ctx.console.error("Invalid credentials") + return None except Exception as e: raise e - # return ctx.console.error(f"An error occurred: {e}") ctx.console.success("Connection successful!") return client diff --git a/client/py/synnax/cli/console/__init__.py b/client/py/synnax/cli/console/__init__.py index 0de9343bed..f9720d43b9 100644 --- a/client/py/synnax/cli/console/__init__.py +++ b/client/py/synnax/cli/console/__init__.py @@ -11,3 +11,5 @@ from synnax.cli.console.protocol import Console from synnax.cli.console.rich import RichConsole from synnax.cli.console.sugared import SugaredConsole + +__all__ = ["MockConsole", "Console", "RichConsole", "SugaredConsole"] diff --git a/client/py/synnax/cli/console/mock.py b/client/py/synnax/cli/console/mock.py index 076b904f67..e9dd6d677d 100644 --- a/client/py/synnax/cli/console/mock.py +++ b/client/py/synnax/cli/console/mock.py @@ -7,7 +7,7 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. -from typing import Generic, TextIO +from typing import Any, Generic, TextIO from pydantic import BaseModel @@ -24,7 +24,7 @@ class Entry(BaseModel, Generic[R]): message: str | None = None columns: list[str] | None = None - rows: list[dict] | None = None + rows: list[dict[str, str]] | None = None choices: list[R] | None = None default: R | None = None response: R | None = None @@ -33,16 +33,16 @@ class Entry(BaseModel, Generic[R]): class Output(BaseModel): - entries: list[Entry] + entries: list[Entry[Any]] - def __init__(self, entries: list[Entry] | None = None): + def __init__(self, entries: list[Entry[Any]] | None = None): super().__init__(entries=entries or list()) - def append(self, entry: Entry): + def append(self, entry: Entry[Any]) -> None: assert self.entries is not None self.entries.append(entry) - def write(self, f: TextIO): + def write(self, f: TextIO) -> None: f.write(self.json()) @@ -62,27 +62,27 @@ def __init__(self, output: Output, verbose: bool = False): def _(self) -> Print: return self - def info(self, message: str): + def info(self, message: str) -> None: self.output.append(Entry(message=message)) if self.verbose is not None: self.verbose.info(message) - def error(self, message: str): + def error(self, message: str) -> None: self.output.append(Entry(message=message)) if self.verbose is not None: self.verbose.error(message) - def warn(self, message: str): + def warn(self, message: str) -> None: self.output.append(Entry(message=message)) if self.verbose is not None: self.verbose.warn(message) - def success(self, message: str): + def success(self, message: str) -> None: self.output.append(Entry(message=message)) if self.verbose is not None: self.verbose.success(message) - def table(self, columns: list, rows: list): + def table(self, columns: list[str], rows: list[dict[str, str]]) -> None: self.output.append(Entry(columns=columns, rows=rows)) if self.verbose is not None: self.verbose.table(columns, rows) @@ -92,9 +92,9 @@ class MockPrompt: """A mock implementation of the Prompt protocol for testing purposes.""" output: Output - responses: list + responses: list[Any] - def __init__(self, output: Output, responses: list): + def __init__(self, output: Output, responses: list[Any]): """ :param output: The output list to append entries to. :param responses: A list of responses to return in order. These responses @@ -114,22 +114,26 @@ def ask( default: R | None = None, password: bool = False, ) -> R | None: - e = Entry( + resolved_type = assign_default_ask_type(type_, choices, default) + response: R | None = ( + self.responses.pop(0) if len(self.responses) > 0 else default + ) + e: Entry[Any] = Entry( message=question, choices=choices, default=default, - type_=assign_default_ask_type(type_, choices, default), + type_=resolved_type, password=password, + response=response, ) - e.response = self.responses.pop(0) if len(self.responses) > 0 else default - if type(e.response) != e.type_: + if type(response) != resolved_type: raise TypeError(f""" Mock Prompt: Invalid response type Question: {question} Expected type: {type_} - Actual response: {e.response} + Actual response: {response} """) - return e.response + return response class MockConsole(MockPrint, MockPrompt): @@ -138,7 +142,7 @@ class MockConsole(MockPrint, MockPrompt): def __init__( self, output: Output = Output(), - responses: list | None = None, + responses: list[Any] | None = None, verbose: bool = False, ): """ @@ -148,3 +152,6 @@ def __init__( """ MockPrint.__init__(self, output, verbose) MockPrompt.__init__(self, output, responses or list()) + + def _(self) -> Console: + return self diff --git a/client/py/synnax/cli/console/protocol.py b/client/py/synnax/cli/console/protocol.py index edaf8f29d4..948fc6d457 100644 --- a/client/py/synnax/cli/console/protocol.py +++ b/client/py/synnax/cli/console/protocol.py @@ -9,7 +9,7 @@ from typing import Protocol, TypeVar -R = TypeVar("R", str, int, float, bool, None) +R = TypeVar("R", str, int, float, bool) class Prompt(Protocol): @@ -83,8 +83,8 @@ def success( def table( self, columns: list[str], - rows: list[dict], - ): + rows: list[dict[str, str]], + ) -> None: """Prints a table to the console. :param columns: A list of column names. @@ -116,7 +116,7 @@ def assign_default_ask_type( if choices is not None: type_ = type(choices[0]) elif default is not None: - type_ = type(default) # type: ignore + type_ = type(default) else: type_ = str # type: ignore return type_ # type: ignore diff --git a/client/py/synnax/cli/console/rich.py b/client/py/synnax/cli/console/rich.py index 7b3531955f..9c92bda269 100644 --- a/client/py/synnax/cli/console/rich.py +++ b/client/py/synnax/cli/console/rich.py @@ -57,7 +57,7 @@ def success(self, message: str) -> None: def table( self, columns: list[str], - rows: list[dict], + rows: list[dict[str, str]], ) -> None: from rich.table import Table @@ -81,7 +81,7 @@ def ask( ) -> R | None: if type_ is None: if default is not None: - type_ = type(default) # type: ignore + type_ = type(default) elif choices is not None and len(choices) > 0: type_ = type(choices[0]) else: @@ -98,7 +98,7 @@ def ask( question, default=default, choices=[str(choice) for choice in choices] if choices else None, # type: ignore - ) # type: ignore + ) if type_ == float: return FloatPrompt.ask( question, @@ -109,4 +109,4 @@ def ask( choices=choices, # type: ignore default=default, password=password, - ) # type: ignore + ) diff --git a/client/py/synnax/cli/console/sugared.py b/client/py/synnax/cli/console/sugared.py index 897fd16557..19cf1c300a 100644 --- a/client/py/synnax/cli/console/sugared.py +++ b/client/py/synnax/cli/console/sugared.py @@ -7,7 +7,7 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. -from typing import Any, Generic, NotRequired, TypedDict, Unpack, overload +from typing import Any, Generic, NotRequired, TypedDict, cast from synnax.cli.console.protocol import Print, Prompt, R from synnax.exceptions import ValidationError @@ -52,54 +52,21 @@ def error(self, message: str) -> None: if self.enabled: self.print.error(message) - def table(self, columns: list[str], rows: list[dict]) -> None: + def table(self, columns: list[str], rows: list[dict[str, str]]) -> None: if self.enabled: self.print.table(columns, rows) - @overload def ask( self, question: str, type_: type[R] | None = None, choices: list[R] | None = None, password: bool = False, - **kwargs: Unpack[AskKwargs[R]], - ) -> R: ... - - @overload - def ask( - self, - question: str, - type_: type[R] | None = None, - choices: list[str] | None = None, - password: bool = False, - **kwargs: Unpack[DefaultAskKwargs[R]], - ) -> R: ... - - @overload - def ask( - self, - question: str, - type_: type[R] | None = None, - choices: list[R] | None = None, - password: bool = False, - **kwargs: Unpack[NoneDefaultAskKwargs[R]], - ) -> R | None: - ... - - ... - - def ask( - self, - question: str, - type_: type[R] | None = None, - choices: list[R] | None = None, - password: bool = False, - **kwargs: Unpack[NoneDefaultAskKwargs[R]], + **kwargs: Any, ) -> R | None: v, default, should_return, has_default = self._validate(kwargs) if should_return: - return v + return cast(R, v) v = self.prompt.ask( question=question, type_=type_, @@ -113,9 +80,7 @@ def ask( self.print.error("You must provide a value.") return self.ask(question, type_, choices, **kwargs) - def _validate( - self, kwargs: NoneDefaultAskKwargs[R] - ) -> tuple[R | None, R | None, bool, bool]: + def _validate(self, kwargs: dict[str, Any]) -> tuple[Any, Any, bool, bool]: has_default = "default" in kwargs default = kwargs.get("default", None) @@ -129,46 +94,13 @@ def _validate( return default, default, not self.enabled, has_default - @overload - def select( - self, rows: list[R], type_: type[R] = str, **kwargs: Unpack[DefaultAskKwargs[R]] - ) -> tuple[R, int]: ... - - @overload - def select( - self, - rows: list[dict[str, Any]], - type_: type[R] = str, - columns: list[str] | None = None, - key: str | None = None, - **kwargs: Unpack[DefaultAskKwargs[R]], - ) -> tuple[R, int]: ... - - @overload - def select( - self, - rows: list[R] | list[dict[str, Any]], - type_: type[R], - columns: list[str] | None = None, - **kwargs: Unpack[AskKwargs[R]], - ) -> tuple[R, int]: ... - - @overload - def select( - self, - rows: list[R] | list[dict[str, Any]], - type_: type[R], - columns: list[str] | None = None, - **kwargs: Unpack[NoneDefaultAskKwargs[R]], - ) -> tuple[R | None, int | None]: ... - def select( self, rows: list[R] | list[dict[str, Any]], - type_: type[R] = str, + type_: type[R] | None = None, columns: list[str] | None = None, key: str | None = None, - **kwargs: Unpack[NoneDefaultAskKwargs[R]], + **kwargs: Any, ) -> tuple[R | None, int | None]: """Prompts the user to select a row from a table. @@ -187,7 +119,7 @@ def select( raise ValidationError("Missing key argument.") _key: str = key or "value" - _rows = list() + _rows: list[dict[str, Any]] = list() default_idx = 0 no_cols = columns is None _columns = columns or list() @@ -202,10 +134,10 @@ def select( _columns.append(k) else: is_default = row == default - key = "value" + col_key = "value" if len(_columns) > 0: - key = _columns[0] - _rows.append({"choice": str(i), key: row}) + col_key = _columns[0] + _rows.append({"choice": str(i), col_key: row}) if is_default: default_idx = len(_rows) - 1 @@ -219,19 +151,19 @@ def select( return v, default_idx self.table(columns=_columns, rows=_rows) - i = self.ask( + selected: int | None = self.ask( "Select an option #", int, choices=[i for i in range(len(rows))], default=default_idx, ) - if i is not None: - r = rows[i] + if selected is not None: + r = rows[selected] assert r is not None - return (r[_key], i) if isinstance(r, dict) else (r, i) + return (r[_key], selected) if isinstance(r, dict) else (r, selected) if has_default: return default, default_idx if self.print is not None: self.print.error("You must make a selection.") - return self.select(type_, rows, columns, key, **kwargs) # type: ignore + return self.select(rows, type_, columns, key, **kwargs) diff --git a/client/py/synnax/cli/flow/__init__.py b/client/py/synnax/cli/flow/__init__.py index 206c2fbb0f..f874e4dadc 100644 --- a/client/py/synnax/cli/flow/__init__.py +++ b/client/py/synnax/cli/flow/__init__.py @@ -29,10 +29,10 @@ def __init__(self, ctx: Context): self.steps = {} self.context = ctx - def add(self, name: str, step: Callable[[Context, T], str | None]): + def add(self, name: str, step: Callable[[Context, T], str | None]) -> None: self.steps[name] = step - def run(self, req: T, root: str): + def run(self, req: T, root: str) -> None: root_step = self.steps[root] self._run(root_step, req) @@ -40,7 +40,7 @@ def _run( self, step: Callable[[Context, T], str | None], request: T, - ): + ) -> None: next_step = step(self.context, request) if next_step is not None: self._run(self.steps[next_step], request) diff --git a/client/py/synnax/cli/ingest.py b/client/py/synnax/cli/ingest.py index 8b0dd5a917..50e2cffe28 100644 --- a/client/py/synnax/cli/ingest.py +++ b/client/py/synnax/cli/ingest.py @@ -48,7 +48,7 @@ def pure_ingest( client: Synnax | None = None, ctx: Context = default.context(), ) -> None: - flow = Flow(ctx) + flow: Flow[Any] = Flow(ctx) flow.add("initialize_reader", initialize_reader) flow.add("connect_client", _connect_client) flow.add("ingest_all", ingest_all) @@ -98,6 +98,7 @@ def run_ingestion(ctx: Context, cli: IngestionCLI) -> None: raise NotImplementedError("Only row ingestion is supported at this time.") ctx.console.info("Starting ingestion process...") engine.run() + assert cli.name is not None cli.client.ranges.create(name=cli.name, time_range=TimeRange(cli.start, engine.end)) @@ -157,6 +158,7 @@ def skip_invalid_channels(ctx: Context, cli: IngestionCLI) -> str | None: channel_name_table(ctx, [ch for ch in data_types.keys()]) if not ctx.console.ask("Skip these channels?", default=True): return None + assert cli.filtered_channels is not None cli.filtered_channels = [ ch for ch in cli.filtered_channels if ch.name not in data_types.keys() ] @@ -179,7 +181,8 @@ def validate_data_types(ctx: Context, cli: IngestionCLI) -> str | None: samples_type = d_types[ch.name].np ch_type = ch.data_type.np if not np.can_cast(samples_type, ch_type): - return cannot_cast_error(ctx, samples_type, ch) + cannot_cast_error(ctx, samples_type, ch) + return None elif samples_type != ch_type: ctx.console.warn( f"""Channel {ch.name} has data type {ch_type} but the file data type is @@ -258,8 +261,10 @@ def validate_start_time(ctx: Context, cli: IngestionCLI) -> str | None: """Please enter the start timestamp of the file as a nanosecond UTC integer. If you'd like a converter, use https://www.epochconverter.com/""", - default=TimeStamp.now(), + type_=int, + default=int(TimeStamp.now()), ) + assert _start is not None cli.start = TimeStamp(_start) else: idx = _idx[0] @@ -307,8 +312,8 @@ def assign_data_type( ) -> dict[DataType, list[ChannelMeta]] | None: assert cli.not_found is not None - grouped = {GROUP_ALL: cli.db_channels} - assigned = {} + grouped: dict[str, list[ChannelMeta]] = {GROUP_ALL: cli.not_found} + assigned: dict[DataType, list[ChannelMeta]] = {} ctx.console.info("Please select an option for assigning data types:") opt, _ = ctx.console.select( rows=DATA_TYPE_OPTIONS, @@ -316,7 +321,6 @@ def assign_data_type( ) if opt == DATA_TYPE_OPTIONS[0]: data_types = read_data_types(ctx, cli) - assigned = {} for ch in cli.not_found: dt = data_types[ch.name] if dt not in assigned: @@ -334,22 +338,24 @@ def assign_data_type( for key, group in grouped.items(): if key != GROUP_ALL: ctx.console.info(f"Assigning data type to {key}") - dt = select_data_type(ctx) - assigned[dt] = group + selected_dt = select_data_type(ctx) + if selected_dt is None: + return None + assigned[selected_dt] = group return assigned def assign_index_or_rate( ctx: Context, cli: IngestionCLI, -) -> dict[Rate | str, list[ChannelMeta]] | None: +) -> dict[Rate | int, list[ChannelMeta]] | None: """Prompts the user to assign an index/rate to the channels in the given group""" assert cli.client is not None assert cli.not_found is not None client = cli.client - grouped = {GROUP_ALL: cli.not_found} + grouped: dict[str, list[ChannelMeta]] = {GROUP_ALL: cli.not_found} if not ctx.console.ask( "Do all non-indexed channels have the same data rate or index?", bool, @@ -359,14 +365,14 @@ def assign_index_or_rate( "Can you group channels by data rate or index?", default=True ): grouped = {v.name: [v] for v in cli.not_found} - grouped = prompt_group_channel_names(ctx, [ch.name for ch in cli.not_found]) - if grouped is None or len(grouped) == 0: + groups = prompt_group_channel_names(ctx, [ch.name for ch in cli.not_found]) + if groups is None or len(groups) == 0: return None grouped = { - k: [ch for ch in cli.not_found if ch.name in v] for k, v in grouped.items() + k: [ch for ch in cli.not_found if ch.name in v] for k, v in groups.items() } - def assign_to_group(key: str, group: list[ChannelMeta]): + def assign_to_group(key: str, group: list[ChannelMeta]) -> Rate | int | None: if key != GROUP_ALL: ctx.console.info(f"Assigning data rate or index to {key}") _choice = ctx.console.ask("Enter the name of an index or a data rate") @@ -385,7 +391,7 @@ def assign_to_group(key: str, group: list[ChannelMeta]): return None return idx.key - assigned: dict[Rate | str, list[ChannelMeta]] = dict() + assigned: dict[Rate | int, list[ChannelMeta]] = dict() for key, group in grouped.items(): idx = assign_to_group(key, group) if idx is None: @@ -409,13 +415,16 @@ def create_channels(ctx: Context, cli: IngestionCLI) -> str | None: to_create = list() for rate_or_index, channels in idx_grouped.items(): - is_rate = isinstance(rate_or_index, Rate) + if isinstance(rate_or_index, Rate): + index = 0 + else: + index = rate_or_index for ch in channels: to_create.append( Channel( name=ch.name, is_index=False, - index="" if is_rate else rate_or_index, + index=index, data_type=[dt for dt, chs in dt_grouped.items() if ch in chs][0], ) ) @@ -429,6 +438,7 @@ def prompt_name(ctx: Context, cli: IngestionCLI) -> str | None: assert cli.db_channels is not None assert cli.not_found is not None assert cli.client is not None + assert cli.path is not None path: Path = cli.path ctx.console.info("Please enter a name for the data set") cli.name = ctx.console.ask("Name", default=path.name) diff --git a/client/py/synnax/cli/login.py b/client/py/synnax/cli/login.py index c6c3701d43..bc1b0913ec 100644 --- a/client/py/synnax/cli/login.py +++ b/client/py/synnax/cli/login.py @@ -26,14 +26,14 @@ def login(ctx: click.Context) -> None: file. """ warning(ctx) - ctx = Context(console=RichConsole()) - options = prompt_client_options(ctx) - synnax = connect_from_options(ctx, options) + flow_ctx = Context(console=RichConsole()) + options = prompt_client_options(flow_ctx) + synnax = connect_from_options(flow_ctx, options) if synnax is None: return cfg = ClustersConfig(ConfigFile(Path(os.path.expanduser("~/.synnax")))) cfg.set(ClusterConfig(options=options)) - ctx.console.info(SUCCESSFUL_LOGIN) + flow_ctx.console.info(SUCCESSFUL_LOGIN) SUCCESSFUL_LOGIN = """Saved credentials. You can now use the Synnax Client diff --git a/client/py/synnax/cli/populate.py b/client/py/synnax/cli/populate.py index 11fca96dc6..90a18e60d8 100644 --- a/client/py/synnax/cli/populate.py +++ b/client/py/synnax/cli/populate.py @@ -23,7 +23,7 @@ default=10, help="Number of samples per range in each channel", ) -def populate(num_channels, num_ranges, num_samples): +def populate(num_channels: int, num_ranges: int, num_samples: int) -> None: client = instantiate_client() for channel_index in range(num_channels): @@ -35,23 +35,23 @@ def populate(num_channels, num_ranges, num_samples): client.populate_range(channel_name, range_data) -def instantiate_client(): +def instantiate_client() -> MockClient: return MockClient() -def generate_channel_name(channel_index): +def generate_channel_name(channel_index: int) -> str: return f"Channel_{channel_index}" -def generate_fake_data(num_samples): +def generate_fake_data(num_samples: int) -> list[int]: return [random.randint(0, 100) for _ in range(num_samples)] class MockClient: - def create_channel(self, channel_name): + def create_channel(self, channel_name: str) -> None: print(f"Creating channel: {channel_name}") - def populate_range(self, channel_name, range_data): + def populate_range(self, channel_name: str, range_data: list[int]) -> None: print(f"Populating range in channel {channel_name} with data: {range_data}") diff --git a/client/py/synnax/cli/telem.py b/client/py/synnax/cli/telem.py index 439b2ec613..2cb22a223d 100644 --- a/client/py/synnax/cli/telem.py +++ b/client/py/synnax/cli/telem.py @@ -8,16 +8,26 @@ # included in the file licenses/APL.txt. -from typing import Unpack +from typing import Any -from synnax.cli.console.sugared import AskKwargs from synnax.cli.flow import Context from synnax.telem import DataType, TimeSpan, TimeSpanUnits +_VALID_TIME_UNITS: dict[str, TimeSpanUnits] = { + "iso": "iso", + "ns": "ns", + "us": "us", + "ms": "ms", + "s": "s", + "m": "m", + "h": "h", + "d": "d", +} + def select_data_type( ctx: Context, - **kwargs: Unpack[AskKwargs[str]], + **kwargs: Any, ) -> DataType | None: """Prompts the user to select a data type from a list of all available data types. @@ -25,20 +35,21 @@ def select_data_type( :param ctx: The current flow Context. :param allow_none: Whether to allow the user to select None. """ - return DataType( - ctx.console.select( - rows=[str(name) for name in DataType.ALL], - type_=str, - columns=["data_type"], - **kwargs, - )[0] + selected, _ = ctx.console.select( + rows=[str(name) for name in DataType.ALL], + type_=str, + columns=["data_type"], + **kwargs, ) + if selected is None: + return None + return DataType(selected) def ask_time_units_select( ctx: Context, question: str | None = None, - **kwargs: Unpack[AskKwargs[str]], + **kwargs: Any, ) -> TimeSpanUnits: """Prompts the user to select a time unit from a list of all available time units. @@ -48,9 +59,16 @@ def ask_time_units_select( """ if question is not None: ctx.console.info(question) - return ctx.console.select( - rows=["iso", *list(TimeSpan.UNITS.keys())], + unit_rows: list[str] = ["iso", *list(TimeSpan.UNITS.keys())] + selected, _ = ctx.console.select( + rows=unit_rows, type_=str, columns=["unit"], **kwargs, - )[0] + ) + if selected is None: + raise ValueError("no time unit selected") + unit = _VALID_TIME_UNITS.get(selected) + if unit is None: + raise ValueError(f"invalid time unit: {selected}") + return unit diff --git a/client/py/synnax/cli/ts_convert.py b/client/py/synnax/cli/ts_convert.py index 9c1550745c..38011b6761 100644 --- a/client/py/synnax/cli/ts_convert.py +++ b/client/py/synnax/cli/ts_convert.py @@ -158,16 +158,17 @@ def pure_tsconvert( arg_name=OUTPUT_CHANNEL_ARG, ) - output_path = Path( - ctx.console.ask( - "Where would you like to save the converted data?", - default=str( - input_path.parent / f"{input_path.stem}_converted{input_path.suffix}" - ), - arg=str(output_path) if output_path is not None else None, - arg_name=OUTPUT_PATH_ARG, - ) + output_path_str = ctx.console.ask( + "Where would you like to save the converted data?", + default=str( + input_path.parent / f"{input_path.stem}_converted{input_path.suffix}" + ), + arg=str(output_path) if output_path is not None else None, + arg_name=OUTPUT_PATH_ARG, ) + if output_path_str is None: + raise ValueError("Output path is required") + output_path = Path(output_path_str) writer = IO_FACTORY.open_writer(output_path) @@ -184,7 +185,7 @@ def pure_tsconvert( for chunk in reader: t0 = datetime.now() converted = convert_time_units( - chunk[input_channel], input_precision, output_precision + chunk[input_channel].to_numpy(), input_precision, output_precision ) chunk[output_channel] = converted writer.write(chunk) @@ -198,11 +199,13 @@ def pure_tsconvert( def ask_channel_and_check_exists( ctx: Context, reader: BaseReader, - question="Enter a channel name", - arg_name="channel", + question: str = "Enter a channel name", + arg_name: str = "channel", arg: str | None = None, ) -> str: _ch = ctx.console.ask(question, arg_name=arg_name, arg=arg) + if _ch is None: + raise ValueError("Channel name is required") try: next(ch for ch in reader.channels() if ch.name == _ch) except StopIteration: diff --git a/client/py/synnax/color/__init__.py b/client/py/synnax/color/__init__.py index 10c9664bee..d9168f6dc3 100644 --- a/client/py/synnax/color/__init__.py +++ b/client/py/synnax/color/__init__.py @@ -8,3 +8,5 @@ # included in the file licenses/APL.txt. from synnax.color.color import Color, Crude + +__all__ = ["Color", "Crude"] diff --git a/client/py/synnax/color/color.py b/client/py/synnax/color/color.py index 17db534519..86a9171b3a 100644 --- a/client/py/synnax/color/color.py +++ b/client/py/synnax/color/color.py @@ -69,7 +69,7 @@ def is_zero(self) -> bool: return self.r == 0 and self.g == 0 and self.b == 0 and self.a == 0 -def _from_hex(s: str) -> dict: +def _from_hex(s: str) -> dict[str, int | float]: s = s.lstrip("#") if len(s) == 0: return {"r": 0, "g": 0, "b": 0, "a": 0} diff --git a/client/py/synnax/config/__init__.py b/client/py/synnax/config/__init__.py index d5686df0f7..9978705530 100644 --- a/client/py/synnax/config/__init__.py +++ b/client/py/synnax/config/__init__.py @@ -15,6 +15,8 @@ from synnax.exceptions import ValidationError from synnax.options import Options +__all__ = ["ClusterConfig", "ClustersConfig", "ConfigFile"] + CONFIG_FILE_PATH = Path(os.path.expanduser("~/.synnax")) diff --git a/client/py/synnax/config/clusters.py b/client/py/synnax/config/clusters.py index 4423910612..332604e58c 100644 --- a/client/py/synnax/config/clusters.py +++ b/client/py/synnax/config/clusters.py @@ -33,7 +33,7 @@ def get(self, key: str = "default") -> ClusterConfig | None: pwd = pwd or "" return ClusterConfig(options=Options(**opts, password=pwd)) - def set(self, c: ClusterConfig, key: str = "default"): + def set(self, c: ClusterConfig, key: str = "default") -> None: p = c.model_dump() keyring.set_password("synnax", key, p["options"].pop("password")) self.internal.set(f"clusters.{key}", p) diff --git a/client/py/synnax/config/file.py b/client/py/synnax/config/file.py index 13e7f62ae8..8707a0392b 100644 --- a/client/py/synnax/config/file.py +++ b/client/py/synnax/config/file.py @@ -9,6 +9,7 @@ import json import pathlib +from typing import Any CONFIG_DIR_NAME = "./synnax" @@ -17,7 +18,7 @@ class ConfigFile: """The global synnax py configuration file.""" file: pathlib.Path - config: dict + config: dict[str, Any] def __init__( self, @@ -28,7 +29,7 @@ def __init__( self.config = {} self.load() - def load(self): + def load(self) -> None: """Loads the config file from disk. If the file does not exist, it will be created. """ @@ -37,7 +38,7 @@ def load(self): with open(self.config_file, "r") as f: self.config = json.load(f) - def save(self): + def save(self) -> None: """Saves the config file to disk.""" self.config_file.parent.mkdir( parents=True, @@ -46,29 +47,29 @@ def save(self): with open(self.config_file, "w") as f: json.dump(self.config, f) - def get(self, key): + def get(self, key: str) -> Any: """Gets a value from the config file.""" return get_nested(self.config, key) - def set(self, key, value): + def set(self, key: str, value: Any) -> None: """Sets a value in the config file.""" set_nested(self.config, key, value) self.save() - def delete(self, key): + def delete(self, key: str) -> None: """Deletes a value from the config file.""" del self.config[key] self.save() -def set_nested(d, key, value): +def set_nested(d: dict[str, Any], key: str, value: Any) -> None: keys = key.split(".") for key in keys[:-1]: d = d.setdefault(key, {}) d[keys[-1]] = value -def get_nested(d, key): +def get_nested(d: dict[str, Any], key: str) -> Any: keys = key.split(".") for key in keys[:-1]: d = d.get(key, {}) diff --git a/client/py/synnax/control/__init__.py b/client/py/synnax/control/__init__.py index 2ba82a6c21..f177d655c6 100644 --- a/client/py/synnax/control/__init__.py +++ b/client/py/synnax/control/__init__.py @@ -9,3 +9,5 @@ from synnax.control.client import Client from synnax.control.controller import Controller, ScheduledCommand + +__all__ = ["Client", "Controller", "ScheduledCommand"] diff --git a/client/py/synnax/control/controller.py b/client/py/synnax/control/controller.py index 52368a8bee..ca1ba206a4 100644 --- a/client/py/synnax/control/controller.py +++ b/client/py/synnax/control/controller.py @@ -12,7 +12,7 @@ from asyncio import Future from collections.abc import Callable from threading import Event, Lock -from typing import Any, Protocol, overload +from typing import Any, Protocol, cast, overload import numpy as np @@ -102,7 +102,7 @@ def __init__( write_authorities: CrudeAuthority | list[CrudeAuthority], ) -> None: self._retriever = retriever - if write is not None and len(write) > 0: + if write is not None and channel_.has_params(write): write_channels = channel_.retrieve_required(self._retriever, write) write_keys = [ch.index for ch in write_channels if ch.index != 0] write_keys.extend([ch.key for ch in write_channels]) @@ -112,7 +112,7 @@ def __init__( channels=write_keys, authorities=write_authorities, ) - if read is not None and len(read) > 0: + if read is not None and channel_.has_params(read): self._receiver_opt = _Receiver(frame_client, read, retriever, self) self._receiver.start() self._receiver.startup_ack.wait() @@ -136,16 +136,24 @@ def _receiver(self) -> _Receiver: return self._receiver_opt @overload - def set(self, ch: channel_.Key | str, value: SampleValue): ... + def set(self, channel: channel_.Key | str, value: SampleValue) -> None: ... @overload - def set(self, ch: dict[channel_.Key | str, SampleValue]): ... + def set(self, channel: dict[channel_.Key, SampleValue]) -> None: ... + + @overload + def set(self, channel: dict[str, SampleValue]) -> None: ... def set( self, - channel: channel_.Key | str | dict[channel_.Key | str, SampleValue], + channel: ( + channel_.Key + | str + | dict[channel_.Key, SampleValue] + | dict[str, SampleValue] + ), value: SampleValue | None = None, - ): + ) -> None: """Sets the provided channel(s) to the provided value(s). :param channel: A single channel key or name, or a dictionary of channel keys and @@ -162,7 +170,9 @@ def set( """ if isinstance(channel, dict): values = list(channel.values()) - channels = channel_.retrieve_required(self._retriever, list(channel.keys())) + # Overloads guarantee keys are homogeneous (all Key or all str) + ch_keys = cast(channel_.Params, list(channel.keys())) + channels = channel_.retrieve_required(self._retriever, ch_keys) now = TimeStamp.now() updated = {channels[i].key: values[i] for i in range(len(channels))} updated_idx = { @@ -173,53 +183,79 @@ def set( self._writer.write({**updated, **updated_idx}) return ch = self._retriever.retrieve_one(channel) - to_write = {ch.key: value} + assert value is not None + to_write: dict[channel_.Key, SampleValue] = {ch.key: value} if not ch.virtual: to_write[ch.index] = TimeStamp.now() - self._writer.write(to_write) + self._writer.write(cast(framer.CrudeFrame, to_write)) @overload def set_authority( self, value: CrudeAuthority, - ) -> bool: ... + ) -> None: ... @overload def set_authority( self, - value: dict[channel_.Key | str, CrudeAuthority], - ) -> bool: ... + value: dict[channel_.Key, CrudeAuthority], + ) -> None: ... @overload def set_authority( self, - ch: channel_.Key | str, - value: CrudeAuthority, - ) -> bool: ... + value: dict[str, CrudeAuthority], + ) -> None: ... + + @overload + def set_authority( + self, + value: dict[channel_.Payload, CrudeAuthority], + ) -> None: ... + + @overload + def set_authority( + self, + value: channel_.Key | str, + authority: CrudeAuthority, + ) -> None: ... def set_authority( self, value: ( - dict[channel_.Key | str | channel_.Payload, CrudeAuthority] + dict[channel_.Key, CrudeAuthority] + | dict[str, CrudeAuthority] + | dict[channel_.Payload, CrudeAuthority] | channel_.Key | str | CrudeAuthority ), authority: CrudeAuthority | None = None, - ) -> bool: + ) -> None: if isinstance(value, dict): - channels = channel_.retrieve_required(self._retriever, list(value.keys())) + # Overloads guarantee homogeneous key types; widen for uniform access + auth = cast( + dict[channel_.Key | str | channel_.Payload, CrudeAuthority], value + ) + auth_keys = cast(channel_.Params, list(auth.keys())) + channels = channel_.retrieve_required(self._retriever, auth_keys) for ch in channels: - value[ch.index] = value.get(ch.key, value.get(ch.name)) + resolved = auth.get(ch.key) or auth.get(ch.name) + if resolved is not None: + auth[ch.index] = resolved + self._writer.set_authority(auth) elif authority is not None: ch = self._retriever.retrieve_one(value) - value = {ch.key: authority, ch.index: authority} - return self._writer.set_authority(value) + self._writer.set_authority({ch.key: authority, ch.index: authority}) + elif isinstance(value, str): + raise TypeError("authority must be provided when setting by channel name") + else: + self._writer.set_authority(value) def wait_until( self, cond: Callable[[Controller], bool], - timeout: float | int | TimeSpan = None, + timeout: CrudeTimeSpan | None = None, ) -> bool: """Blocks the controller, calling the provided callback on every new sample received by the controller. Once the callback returns True, the method will @@ -247,7 +283,7 @@ def wait_until( def wait_while( self, cond: Callable[[Controller], bool], - timeout: CrudeTimeSpan = None, + timeout: CrudeTimeSpan | None = None, ) -> bool: """Blocks the controller, calling the provided callback on every new sample received. The controller will continue to block until the @@ -259,9 +295,9 @@ def wait_while( def _internal_wait_until( self, cond: Callable[[Controller], bool], - timeout: CrudeTimeSpan = None, + timeout: CrudeTimeSpan | None = None, reverse: bool = False, - ): + ) -> bool: if not callable(cond): raise ValueError("First argument to wait_until must be a callable.") processor = WaitUntil(cond, reverse) @@ -277,7 +313,7 @@ def _internal_wait_until( raise processor.exc return ok - def sleep(self, dur: float | int | TimeSpan, precise: bool = False): + def sleep(self, dur: float | int | TimeSpan, precise: bool = False) -> None: """Sleeps the controller for the provided duration. :param dur: The duration to sleep for. This can be a flot or int representing @@ -291,8 +327,8 @@ def sleep(self, dur: float | int | TimeSpan, precise: bool = False): def wait_until_defined( self, - channels: channel_.Key | str | list[channel_.Key | str], - timeout: CrudeTimeSpan = None, + channels: channel_.Params, + timeout: CrudeTimeSpan | None = None, ) -> bool: """Blocks until the controller has received at least one value from all the provided channels. This is useful for ensuring that the controlled has reached @@ -351,7 +387,7 @@ def remains_true_for( raise processor.exc return ok - def release(self): + def release(self) -> None: """Release control and shuts down the controller. No further control operations can be performed after calling this method. """ @@ -360,13 +396,11 @@ def release(self): if self._receiver_opt is not None: self._receiver.stop() - def __setitem__( - self, ch: channel_.Key | str | channel_.Payload, value: int | float - ): + def __setitem__(self, ch: channel_.Key | str, value: int | float) -> None: self.set(ch, value) @property - def state(self) -> dict[channel_.Key, np.number]: + def state(self) -> dict[channel_.Key, np.number | int | float]: """ :returns: The current state of all channels passed to read_from in the acquire method. This is a dictionary of channel keys to their most recent values. It's @@ -375,21 +409,23 @@ def state(self) -> dict[channel_.Key, np.number]: """ return self._receiver.state - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: try: super().__setattr__(key, value) except AttributeError: self.set(key, value) @overload - def get(self, ch: channel_.Key | str) -> int | float | None: ... + def get(self, ch: channel_.Key | str) -> np.number | int | float | None: ... @overload - def get(self, ch: channel_.Key | str, default: int | float) -> int | float: ... + def get( + self, ch: channel_.Key | str, default: int | float + ) -> np.number | int | float: ... def get( - self, ch: channel_.Key | str, default: int | float = None - ) -> int | float | None: + self, ch: channel_.Key | str, default: int | float | None = None + ) -> np.number | int | float | None: """Gets the most recent value for the provided channel, and returns the default value if no value has been received yet. @@ -401,21 +437,21 @@ def get( >>> controller.get("my_channel") >>> controller.get("my_channel", 42) """ - ch = self._retriever.retrieve_one(ch) - return self._receiver.state.get(ch.key, default) + ch_pld = self._retriever.retrieve_one(ch) + return self._receiver.state.get(ch_pld.key, default) def schedule( self, *commands: ScheduledCommand, ) -> tuple[Callable[[], None], bool]: - def start(): + def start() -> None: for cmd in commands: self.sleep(cmd.delay, precise=True) self.set(cmd.channel, cmd.value) return start, True - def __getitem__(self, item): + def __getitem__(self, item: channel_.Key | str) -> np.number | int | float: ch = self._retriever.retrieve_one(item) try: return self._receiver.state[ch.key] @@ -431,7 +467,7 @@ def __getitem__(self, item): method. """) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: try: return super().__getattribute__(item) except AttributeError: @@ -440,12 +476,12 @@ def __getattr__(self, item): def __enter__(self) -> Controller: return self - def __exit__(self, exc_type, exc_value, traceback) -> None: + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None: self.release() class _Receiver(AsyncThread): - state: dict[channel_.Key, np.number] + state: dict[channel_.Key, np.number | int | float] channels: channel_.Params client: framer.Client streamer: framer.AsyncStreamer @@ -454,7 +490,7 @@ class _Receiver(AsyncThread): retriever: channel_.Retriever controller: Controller startup_ack: Event - shutdown_future: Future + shutdown_future: Future[None] def __init__( self, @@ -462,7 +498,7 @@ def __init__( channels: channel_.Params, retriever: channel_.Retriever, controller: Controller, - ): + ) -> None: super().__init__() self.channels = retriever.retrieve(channels) self.client = client @@ -472,34 +508,37 @@ def __init__( self.startup_ack = Event() self.processors = set() - def add_processor(self, processor: Processor): + def add_processor(self, processor: Processor) -> None: with self.processor_lock: self.processors.add(processor) - def remove_processor(self, processor: Processor): + def remove_processor(self, processor: Processor) -> None: with self.processor_lock: self.processors.remove(processor) - def _process(self): + def _process(self) -> None: with self.processor_lock: for p in self.processors: p.process(self.controller) - async def _listen_for_close(self): + async def _listen_for_close(self) -> None: await self.shutdown_future await self.streamer.close_loop() - async def run_async(self): + async def run_async(self) -> None: self.streamer = await self.client.open_async_streamer(self.channels) self.shutdown_future = self.loop.create_future() self.loop.create_task(self._listen_for_close()) self.startup_ack.set() async for frame in self.streamer: for i, key in enumerate(frame.channels): - self.state[key] = frame.series[i][-1] + if isinstance(key, int): + v = frame.series[i][-1] + if isinstance(v, (np.number, int, float)): + self.state[key] = v self._process() - def stop(self): + def stop(self) -> None: self.loop.call_soon_threadsafe(self.shutdown_future.set_result, None) self.join() diff --git a/client/py/synnax/device/client.py b/client/py/synnax/device/client.py index 650fc32906..888a8085f2 100644 --- a/client/py/synnax/device/client.py +++ b/client/py/synnax/device/client.py @@ -7,7 +7,7 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. -from typing import overload +from typing import Literal, overload from alamos import NOOP, Instrumentation, trace from freighter import Empty, UnaryClient, send_required @@ -66,13 +66,13 @@ def create( model: str = "", configured: bool = False, properties: str = "", - ): ... + ) -> Device: ... @overload - def create(self, devices: Device): ... + def create(self, devices: Device) -> Device: ... @overload - def create(self, devices: list[Device]): ... + def create(self, devices: list[Device]) -> list[Device]: ... def create( self, @@ -86,7 +86,7 @@ def create( model: str = "", configured: bool = False, properties: str = "", - ): + ) -> Device | list[Device]: is_single = not isinstance(devices, list) if devices is None: devices = [ @@ -123,6 +123,31 @@ def retrieve( name: str | None = None, model: str | None = None, location: str | None = None, + ignore_not_found: Literal[True], + ) -> Device | None: ... + + @overload + def retrieve( + self, + *, + keys: list[str] | None = None, + makes: list[str] | None = None, + models: list[str] | None = None, + names: list[str] | None = None, + locations: list[str] | None = None, + ignore_not_found: Literal[True], + ) -> list[Device]: ... + + @overload + def retrieve( + self, + *, + key: str | None = None, + make: str | None = None, + name: str | None = None, + model: str | None = None, + location: str | None = None, + ignore_not_found: Literal[False] = ..., ) -> Device: ... @overload @@ -134,10 +159,9 @@ def retrieve( models: list[str] | None = None, names: list[str] | None = None, locations: list[str] | None = None, - ignore_not_found: bool = False, + ignore_not_found: Literal[False] = ..., ) -> list[Device]: ... - @trace("debug") def retrieve( self, *, @@ -152,7 +176,7 @@ def retrieve( names: list[str] | None = None, locations: list[str] | None = None, ignore_not_found: bool = False, - ) -> list[Device] | Device: + ) -> list[Device] | Device | None: is_single = check_for_none(keys, makes, models, locations, names) res = send_required( self._client, diff --git a/client/py/synnax/ethercat/types.py b/client/py/synnax/ethercat/types.py index cf00399114..f23f094d29 100644 --- a/client/py/synnax/ethercat/types.py +++ b/client/py/synnax/ethercat/types.py @@ -39,10 +39,10 @@ """ import json -from typing import Literal +from typing import Any, Literal from uuid import uuid4 -from pydantic import BaseModel, Field, conint, field_validator +from pydantic import BaseModel, Field, field_validator from synnax import channel, device, task from synnax.telem import CrudeRate @@ -85,7 +85,7 @@ class BaseChan(BaseModel): device: str = Field(min_length=1) "The key of the Synnax slave device this channel belongs to." - def __init__(self, **data): + def __init__(self, **data: Any) -> None: if "key" not in data or not data["key"]: data["key"] = str(uuid4()) super().__init__(**data) @@ -274,7 +274,7 @@ class ReadTaskConfig(task.BaseReadConfig): "A list of input channel configurations to acquire data from." @field_validator("channels") - def validate_channels_not_empty(cls, v): + def validate_channels_not_empty(cls, v: list[InputChan]) -> list[InputChan]: """Validate that at least one channel is provided.""" if len(v) == 0: raise ValueError("Task must have at least one channel") @@ -310,7 +310,7 @@ class WriteTaskConfig(task.BaseWriteConfig): "A list of output channel configurations to write to." @field_validator("channels") - def validate_channels_not_empty(cls, v): + def validate_channels_not_empty(cls, v: list[OutputChan]) -> list[OutputChan]: """Validate that at least one channel is provided.""" if len(v) == 0: raise ValueError("Task must have at least one channel") diff --git a/client/py/synnax/exceptions.py b/client/py/synnax/exceptions.py index 2a5b4a840d..c1aa9680d2 100644 --- a/client/py/synnax/exceptions.py +++ b/client/py/synnax/exceptions.py @@ -142,11 +142,14 @@ def _decode(encoded: freighter.ExceptionPayload) -> Exception | None: return UnexpectedError(encoded.data) try: data = json.loads(encoded.data) + decoded_err = freighter.decode_exception( + freighter.ExceptionPayload(**data["error"]) + ) + if decoded_err is None: + return UnexpectedError(encoded.data) return PathError( data["path"], - freighter.decode_exception( - freighter.ExceptionPayload(**data["error"]) - ), + decoded_err, ) except Exception as e: return UnexpectedError(f"Failed to decode PathError: {e}") diff --git a/client/py/synnax/framer/__init__.py b/client/py/synnax/framer/__init__.py index 5690503c8f..a68709abe9 100644 --- a/client/py/synnax/framer/__init__.py +++ b/client/py/synnax/framer/__init__.py @@ -9,7 +9,20 @@ from synnax.framer.client import Client from synnax.framer.deleter import Deleter -from synnax.framer.frame import Frame +from synnax.framer.frame import CrudeFrame, Frame from synnax.framer.iterator import AUTO_SPAN, Iterator from synnax.framer.streamer import AsyncStreamer, Streamer from synnax.framer.writer import Writer, WriterMode + +__all__ = [ + "Client", + "Deleter", + "Frame", + "CrudeFrame", + "AUTO_SPAN", + "Iterator", + "AsyncStreamer", + "Streamer", + "Writer", + "WriterMode", +] diff --git a/client/py/synnax/framer/adapter.py b/client/py/synnax/framer/adapter.py index 98695aa0f3..edba031b8c 100644 --- a/client/py/synnax/framer/adapter.py +++ b/client/py/synnax/framer/adapter.py @@ -8,6 +8,7 @@ # included in the file licenses/APL.txt. import warnings +from typing import Any, cast from pandas import DataFrame @@ -33,7 +34,7 @@ def __init__(self, retriever: ChannelRetriever): self.keys = list() self.codec = Codec() - def update(self, channels: channel.Params): + def update(self, channels: channel.Params) -> None: normal = channel.normalize_params(channels) fetched = self.retriever.retrieve(normal.channels) self.codec.update( @@ -41,9 +42,9 @@ def update(self, channels: channel.Params): [ch.data_type for ch in fetched], ) - if normal.variant == "keys": + if isinstance(normal, channel.NormalizedKeyResult): self.__adapter = None - self.keys = normal.channels + self.keys = list(normal.channels) return self.__adapter = dict[int, str]() @@ -54,7 +55,7 @@ def update(self, channels: channel.Params): self.__adapter[ch.key] = ch.name self.keys = list(self.__adapter.keys()) - def adapt(self, fr: Frame): + def adapt(self, fr: Frame) -> Frame: if self.__adapter is None: return fr @@ -65,14 +66,17 @@ def adapt(self, fr: Frame): for i, k in enumerate(fr.channels): try: if isinstance(k, channel.Key): - fr.channels[i] = self.__adapter[k] + fr.channels[i] = self.__adapter[k] # type: ignore[call-overload] except KeyError: if to_purge is None: to_purge = [i] else: to_purge.append(i) if to_purge is not None: - fr.channels = [k for i, k in enumerate(fr.channels) if i not in to_purge] + fr.channels = cast( + list[channel.Key] | list[str], + [k for i, k in enumerate(fr.channels) if i not in to_purge], + ) fr.series = [s for i, s in enumerate(fr.series) if i not in to_purge] return fr @@ -80,7 +84,7 @@ def adapt(self, fr: Frame): class WriteFrameAdapter: _adapter: dict[str, channel.Key] | None - _keys: list[channel.Key] | None + _keys: list[channel.Key] _err_on_extra_chans: bool _strict_data_types: bool _suppress_warnings: bool @@ -97,13 +101,13 @@ def __init__( ): self.retriever = retriever self._adapter = None - self._keys = None + self._keys = list() self._err_on_extra_chans = err_on_extra_chans self._strict_data_types = strict_data_types self._suppress_warnings = suppress_warnings self.codec = Codec() - def update(self, channels: channel.Params): + def update(self, channels: channel.Params) -> None: results = retrieve_required_channel(self.retriever, channels) self._adapter = {ch.name: ch.key for ch in results} self._keys = [ch.key for ch in results] @@ -113,15 +117,15 @@ def update(self, channels: channel.Params): ) def adapt_dict_keys( - self, data: dict[channel.Payload | channel.Key | str, any] - ) -> dict[channel.Key, any]: + self, data: dict[channel.Payload | channel.Key | str, Any] + ) -> dict[channel.Key, Any]: out = dict() for k in data.keys(): out[self.__adapt_to_key(k)] = data[k] return out @property - def keys(self): + def keys(self) -> list[channel.Key]: return self._keys def __adapt_to_key(self, ch: channel.Payload | channel.Key | str) -> channel.Key: @@ -145,14 +149,14 @@ def adapt( channel.Payload | list[channel.Payload] | channel.Params | CrudeFrame ), series: CrudeSeries | list[CrudeSeries] | None = None, - ): + ) -> Frame: frame = self._adapt(channels_or_data, series) extra = set(frame.channels) - set(self.keys) if extra: raise PathError("keys", ValidationError(f"frame has extra keys {extra}")) for i, (col, series) in enumerate(frame.items()): - ch = self.retriever.retrieve(col)[0] # type: ignore + ch = self.retriever.retrieve(col)[0] if series.data_type != ch.data_type: if self._strict_data_types: raise PathError( @@ -200,55 +204,67 @@ def _adapt( """) pld = self.__adapt_ch(channels_or_data) - return Frame([pld.key], [series]) + return Frame([pld.key], [Series(cast(CrudeSeries, series))]) if isinstance(channels_or_data, list): if series is None: raise ValidationError(f""" Received {len(channels_or_data)} channels but no series. """) - channels = list() - o_series = list() + series_list: list[CrudeSeries] = ( + [series] + if not isinstance(series, list) + else cast(list[CrudeSeries], series) + ) + channels: list[channel.Key] = list() + o_series: list[Series] = list() for i, ch in enumerate(channels_or_data): pld = self.__adapt_ch(ch) - if i >= len(series): + if i >= len(series_list): raise ValidationError(f""" - Received {len(channels_or_data)} channels but only {len(series)} series. + Received {len(channels_or_data)} channels but only {len(series_list)} series. """) channels.append(pld.key) - o_series.append(series[i]) + o_series.append(Series(series_list[i])) return Frame(channels, o_series) - is_frame = isinstance(channels_or_data, Frame) - is_df = isinstance(channels_or_data, DataFrame) - if is_frame or is_df: - cols = channels_or_data.channels if is_frame else channels_or_data.columns + if isinstance(channels_or_data, (Frame, DataFrame)): + if isinstance(channels_or_data, Frame): + cols: list[channel.Key] | list[str] = channels_or_data.channels + else: + cols = cast(list[str], list(channels_or_data.columns)) if self._adapter is None: - return channels_or_data - channels = list() - series = list() + return ( + Frame(channels_or_data) + if isinstance(channels_or_data, DataFrame) + else channels_or_data + ) + adapted_channels: list[channel.Key] = list() + adapted_series: list[Series] = list() for col in cols: try: - channels.append(self._adapter[col] if isinstance(col, str) else col) - series.append(Series(channels_or_data[col])) + adapted_channels.append( + self._adapter[col] if isinstance(col, str) else col + ) + adapted_series.append(Series(channels_or_data[col])) except KeyError as e: if self._err_on_extra_chans: raise ValidationError( f"Channel {e} was not provided in the list of " f"channels when the writer was opened." ) - return Frame(channels=channels, series=series) + return Frame(channels=adapted_channels, series=adapted_series) if isinstance(channels_or_data, dict): - channels = list() - series = list() + dict_channels: list[channel.Key] = list() + dict_series: list[Series] = list() for k, v in channels_or_data.items(): pld = self.__adapt_ch(k) - channels.append(pld.key) - series.append(Series(v)) + dict_channels.append(pld.key) + dict_series.append(Series(v)) - return Frame(channels, series) + return Frame(dict_channels, dict_series) raise TypeError( f"""Cannot construct frame from {channels_or_data} and {series}""" diff --git a/client/py/synnax/framer/client.py b/client/py/synnax/framer/client.py index 7399c5afb3..44df85f87b 100644 --- a/client/py/synnax/framer/client.py +++ b/client/py/synnax/framer/client.py @@ -7,7 +7,7 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. -from typing import overload +from typing import cast, overload import pandas as pd from alamos import NOOP, Instrumentation @@ -56,7 +56,7 @@ def __init__( retriever: Retriever, deleter: Deleter, instrumentation: Instrumentation = NOOP, - ): + ) -> None: self.__stream_client = stream_client self.__async_client = async_client self.__unary_client = unary_client @@ -133,7 +133,7 @@ def open_iterator( self, tr: TimeRange, channels: channel.Params, - chunk_size: int = 1e5, + chunk_size: int = 100000, downsample_factor: int = 1, ) -> Iterator: """Opens a new iterator over the given channels within the provided time range. @@ -163,56 +163,37 @@ def open_iterator( def write( self, start: CrudeTimeStamp, - frame: CrudeFrame, + channels: CrudeFrame, + *, strict: bool = False, - ): ... + ) -> None: ... @overload def write( self, start: CrudeTimeStamp, - channel: channel.Key | str | channel.Payload, - data: CrudeSeries, - strict: bool = False, - ): - """Writes telemetry to the given channel starting at the given timestamp. - - :param channel: The key of the channel to write to. - :param start: The starting timestamp of the first sample in data. - :param data: The telemetry to write to the channel. - :returns: None. - """ - ... - - @overload - def write( - self, - start: CrudeTimeStamp, - channel: ( - list[channel.Key] - | tuple[channel.Key] - | list[str] - | tuple[str] - | list[channel.Payload] - ), - series: list[CrudeSeries], + channels: channel.Params, + series: CrudeSeries | list[CrudeSeries], + *, strict: bool = False, - ): ... + ) -> None: ... def write( self, start: CrudeTimeStamp, - channels: channel.Params | channel.Payload | list[channel.Payload] | CrudeFrame, + channels: channel.Params | CrudeFrame, series: CrudeSeries | list[CrudeSeries] | None = None, + *, strict: bool = False, - ): - parsed_channels = list() + ) -> None: + parsed_channels: channel.Params = list() if isinstance(channels, (list, channel.Key, channel.Payload, str)): parsed_channels = channels elif isinstance(channels, dict): parsed_channels = list(channels.keys()) elif isinstance(channels, Frame): - parsed_channels = channels.channels + # Frame channels are homogeneous (all Key or all str) at runtime + parsed_channels = cast(channel.Params, channels.channels) elif isinstance(channels, pd.DataFrame): parsed_channels = list(channels.columns) with self.open_writer( @@ -223,7 +204,7 @@ def write( err_on_unauthorized=True, auto_index_persist_interval=TimeSpan.MAX, ) as w: - w.write(channels, series) + w.write(channels, series) # type: ignore[arg-type] @overload def read( @@ -264,12 +245,14 @@ def read( ) return series + @overload def read_latest( self, channels: channel.Key | str, n: int = 1, ) -> MultiSeries: ... + @overload def read_latest( self, channels: list[channel.Key] | tuple[channel.Key] | list[str] | tuple[str], @@ -280,7 +263,7 @@ def read_latest( self, channels: channel.Params, n: int = 1, - ) -> Frame: + ) -> Frame | MultiSeries: """ Reads the latest n samples from time_channel and data_channel. @@ -303,7 +286,10 @@ def read_latest( aggregate.append(i.value) if len(normal.channels) > 1: return aggregate - return aggregate.get(normal.channels[0], MultiSeries([])) + result = aggregate.get(normal.channels[0], MultiSeries([])) + if result is None: + return MultiSeries([]) + return result def open_streamer( self, diff --git a/client/py/synnax/framer/codec.py b/client/py/synnax/framer/codec.py index 23a3aa05a9..b9024f102a 100644 --- a/client/py/synnax/framer/codec.py +++ b/client/py/synnax/framer/codec.py @@ -72,13 +72,11 @@ def decode(cls, b: int) -> CodecFlags: class CodecState: - keys: list[channel.Key] | tuple[channel.Key] + keys: list[channel.Key] data_types: dict[channel.Key, DataType] has_variable_data_types: bool - def __init__( - self, keys: list[channel.Key] | tuple[channel.Key], data_types: list[DataType] - ) -> None: + def __init__(self, keys: list[channel.Key], data_types: list[DataType]) -> None: self.keys = sorted(keys) self.data_types = {k: dt for k, dt in zip(keys, data_types)} self.has_variable_data_types = any(dt.is_variable for dt in data_types) @@ -88,26 +86,24 @@ class Codec: _has_variable_data_types: bool _seq_num: int _states: dict[int, CodecState] - _curr_state: CodecState = None + _curr_state: CodecState | None = None def __init__( self, - keys: list[channel.Key] | tuple[channel.Key] = None, - data_types: list[DataType] = None, + keys: list[channel.Key] | None = None, + data_types: list[DataType] | None = None, ) -> None: self._seq_num = 0 self._states = dict() - if keys is not None: + if keys is not None and data_types is not None: self.update(keys, data_types) - def update( - self, keys: list[channel.Key] | tuple[channel.Key], data_types: list[DataType] - ): + def update(self, keys: list[channel.Key], data_types: list[DataType]) -> None: self._seq_num += 1 self._curr_state = CodecState(keys, data_types) self._states[self._seq_num] = self._curr_state - def throw_if_not_updated(self, op_name: str): + def throw_if_not_updated(self, op_name: str) -> None: if self._curr_state is None: raise ValueError( "Codec has not been updated with keys and data types. " @@ -116,6 +112,7 @@ def throw_if_not_updated(self, op_name: str): def encode(self, frame: Frame | FramePayload, start_offset: int = 0) -> bytes: self.throw_if_not_updated("encode") + assert self._curr_state is not None pld = frame if isinstance(frame, FramePayload) else frame.to_payload() indices = sorted(range(len(pld.keys)), key=lambda i: pld.keys[i]) sorted_keys = [pld.keys[i] for i in indices] @@ -228,6 +225,7 @@ def encode(self, frame: Frame | FramePayload, start_offset: int = 0) -> bytes: def decode(self, data: bytes, offset: int = 0) -> FramePayload: self.throw_if_not_updated("decode") + assert self._curr_state is not None buffer = memoryview(data) idx = offset flags = CodecFlags.decode(buffer[idx]) diff --git a/client/py/synnax/framer/deleter.py b/client/py/synnax/framer/deleter.py index 9a214491ad..a0893e9436 100644 --- a/client/py/synnax/framer/deleter.py +++ b/client/py/synnax/framer/deleter.py @@ -43,10 +43,9 @@ def delete( tr: TimeRange, ) -> None: normal = channel.normalize_params(channels) - req = _Request( - **{ - normal.variant: normal.channels, - "bounds": tr, - } - ) + req = _Request(bounds=tr) + if isinstance(normal, channel.NormalizedKeyResult): + req.keys = normal.channels + else: + req.names = normal.channels send_required(self._client, "/frame/delete", req, _Response) diff --git a/client/py/synnax/framer/frame.py b/client/py/synnax/framer/frame.py index f6a3391476..34216845ef 100644 --- a/client/py/synnax/framer/frame.py +++ b/client/py/synnax/framer/frame.py @@ -9,8 +9,8 @@ from __future__ import annotations -from collections.abc import Iterator -from typing import overload +from collections.abc import Iterator, Sequence +from typing import TypeAlias, cast, overload from pandas import DataFrame from pydantic import BaseModel, Field @@ -21,7 +21,7 @@ class FramePayload(BaseModel): - keys: list[channel.Key] | tuple[channel.Key] + keys: list[channel.Key] series: list[Series] def __init__( @@ -29,9 +29,6 @@ def __init__( keys: list[int] | None = None, series: list[Series] | None = None, ): - # This is a workaround to allow for a None value to be - # passed to the arrays field, but still have required - # type hinting. if series is None: series = list() if keys is None: @@ -45,7 +42,7 @@ class Frame: can be keyed by channel name or channel key, but not both. """ - channels: list[channel.Key | str] + channels: list[channel.Key] | list[str] series: list[Series] = Field(default_factory=list) def __init__( @@ -62,25 +59,29 @@ def __init__( | dict[str, TypedCrudeSeries] | None ) = None, - series: list[TypedCrudeSeries] | None = None, + series: Sequence[TypedCrudeSeries] | None = None, ): if isinstance(channels, Frame): self.channels = channels.channels self.series = channels.series elif isinstance(channels, FramePayload): - self.channels = channels.keys + self.channels = list(channels.keys) self.series = channels.series elif isinstance(channels, DataFrame): self.channels = channels.columns.to_list() self.series = [Series(data=channels[k]) for k in self.channels] elif isinstance(channels, dict): - self.channels = list(channels.keys()) + self.channels = cast(list[channel.Key] | list[str], list(channels.keys())) self.series = [Series(d) for d in channels.values()] elif (series is None or isinstance(series, list)) and ( channels is None or isinstance(channels, list) ): self.series = list() if series is None else [Series(d) for d in series] - self.channels = channels or list[channel.Key]() + self.channels = ( + cast(list[channel.Key] | list[str], list(channels)) + if channels + else list() + ) else: raise ValueError(f""" [Frame] - invalid construction arguments. Received {channels} @@ -119,11 +120,13 @@ def append( or name. """ if isinstance(col_or_frame, Frame): - self.series.extend(col_or_frame.series) # type: ignore - self.channels.extend(col_or_frame.channels) # type: ignore + self.series.extend(col_or_frame.series) + self.channels.extend(col_or_frame.channels) # type: ignore[arg-type] else: + if series is None: + raise ValueError("series must be provided when appending a channel") self.series.append(series) - self.channels.append(col_or_frame) # type: ignore + self.channels.append(col_or_frame) # type: ignore[arg-type] def items( self, @@ -132,16 +135,16 @@ def items( Returns a generator of tuples containing the channel and series for each channel in the frame. """ - return zip(self.channels, self.series) # type: ignore + return zip(self.channels, self.series) - def __getitem__(self, key: channel.Key | str | any) -> MultiSeries: + def __getitem__(self, key: channel.Key | str) -> MultiSeries: if not isinstance(key, (channel.Key, str)): return self.to_df()[key] indexes = [i for i, k in enumerate(self.channels) if k == key] return MultiSeries([self.series[i] for i in indexes]) def get( - self, key: channel.Key | str, default: Series | None = None + self, key: channel.Key | str, default: MultiSeries | None = None ) -> MultiSeries | None: """Gets the series for the given channel key or name. If the channel does not exist in the frame, returns the default value or None if no default is provided. @@ -153,7 +156,7 @@ def get( except ValueError: return default - def to_payload(self): + def to_payload(self) -> FramePayload: """Converts the frame to its payload representation for transport over the network. This method should typically only be used internally. :raises: ValidationError if the frame is keyed by channel name instead of key. @@ -164,24 +167,29 @@ def to_payload(self): Cannot convert a frame that contains channel names to a payload. The following channels are invalid: {diff} """) - return FramePayload(keys=self.channels, series=self.series) + keys: list[int] = [k for k in self.channels if isinstance(k, int)] + return FramePayload(keys=keys, series=self.series) def to_df(self) -> DataFrame: """Converts the frame to a pandas DataFrame. Each column in the DataFrame corresponds to a channel in the frame. """ - base = dict() - for k in set(self.channels): - # Try to convert the value to a numpy array. If it fails (such as in the - # case of strings or JSON objects), convert it to a primitive list instead. + base: dict[channel.Key | str, object] = dict() + channels = cast(list[channel.Key | str], self.channels) + for k in set(channels): + v = self.get(k) + if v is None: + continue try: - base[k] = self.get(k).__array__() + base[k] = v.__array__() except TypeError: - base[k] = list(self.get(k)) + base[k] = list(v) return DataFrame(base) def __contains__(self, key: channel.Key | str) -> bool: return key in self.channels -CrudeFrame = Frame | FramePayload | dict[channel.Key, CrudeSeries] | DataFrame +CrudeFrame: TypeAlias = ( + Frame | FramePayload | dict[channel.Key, CrudeSeries] | DataFrame +) diff --git a/client/py/synnax/framer/iterator.py b/client/py/synnax/framer/iterator.py index 7540af1a46..37a993f728 100644 --- a/client/py/synnax/framer/iterator.py +++ b/client/py/synnax/framer/iterator.py @@ -7,6 +7,8 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. +from __future__ import annotations + from enum import Enum from alamos import NOOP, Instrumentation @@ -45,7 +47,7 @@ class _Request(BaseModel): span: TimeSpan | None = None bounds: TimeRange | None = None stamp: TimeStamp | None = None - keys: list[channel.Key] | tuple[channel.Key] | None = None + keys: list[channel.Key] | None = None chunk_size: int | None = None downsample_factor: int | None = None @@ -82,7 +84,7 @@ def __init__( tr: TimeRange, client: StreamClient, adapter: ReadFrameAdapter, - chunk_size: int = 1e5, + chunk_size: int = 100000, downsample_factor: int = 1, instrumentation: Instrumentation = NOOP, ) -> None: @@ -94,7 +96,7 @@ def __init__( self._downsample_factor = downsample_factor self.__open() - def __open(self): + def __open(self) -> None: """Opens the iterator, configuring it to iterate over the telemetry in the channels with the given keys within the provided time range. @@ -183,7 +185,7 @@ def valid(self) -> bool: """ return self._exec(command=_Command.VALID) - def close(self): + def close(self) -> None: """Close closes the iterator. An iterator MUST be closed after use, and this method should probably be placed in a 'finally' block. If the iterator is not closed, it may leak resources and threads. @@ -202,22 +204,22 @@ def close(self): elif not isinstance(exc, EOF): raise exc - def __iter__(self): + def __iter__(self) -> Iterator: self.seek_first() return self - def __next__(self): + def __next__(self) -> Frame: if not self.next(AUTO_SPAN): raise StopIteration return self.value - def __enter__(self): + def __enter__(self) -> Iterator: return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None: self.close() - def _exec(self, **kwargs) -> bool: + def _exec(self, **kwargs: object) -> bool: exc = self.__stream.send(_Request(**kwargs)) if exc is not None: raise exc diff --git a/client/py/synnax/framer/streamer.py b/client/py/synnax/framer/streamer.py index 1d05850c6f..5e39338219 100644 --- a/client/py/synnax/framer/streamer.py +++ b/client/py/synnax/framer/streamer.py @@ -7,7 +7,9 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. -from typing import overload +from __future__ import annotations + +from typing import cast, overload from freighter import ( EOF, @@ -16,6 +18,7 @@ Stream, WebsocketClient, ) +from freighter.transport import P from freighter.websocket import Message from pydantic import BaseModel @@ -28,7 +31,7 @@ class _Request(BaseModel): - keys: list[channel.Key] | tuple[channel.Key] + keys: list[channel.Key] downsample_factor: int throttle_rate_hz: float | None = None @@ -38,15 +41,14 @@ class _Response(BaseModel): class WSStreamerCodec(WSFramerCodec): - def encode(self, pld: Message) -> bytes: - return self.lower_perf_codec.encode(pld) + def encode(self, data: BaseModel) -> bytes: + return self.lower_perf_codec.encode(data) - def decode(self, data: bytes, pld_t: Message[_Response]) -> object: + def decode(self, data: bytes, pld_t: type[P]) -> P: if data[0] == LOW_PERF_SPECIAL_CHAR: - msg = self.lower_perf_codec.decode(data[1:], pld_t) - return msg + return self.lower_perf_codec.decode(data[1:], pld_t) frame = self.codec.decode(data, 1) - return Message(type="data", payload=_Response(frame=frame)) + return cast(P, Message(type="data", payload=_Response(frame=frame))) _ENDPOINT = "/frame/stream" @@ -137,14 +139,16 @@ def read(self, timeout: float | None = None) -> Frame | None: timeout. """ try: - res, exc = self._stream.receive(TimeSpan.to_seconds(timeout)) - if exc is not None: - raise exc - return self._adapter.adapt(Frame(res.frame)) + # mypy does not understand destructured union tuples, so we keep [pld, exc] as + # a single tuple. + res = self._stream.receive(TimeSpan.to_seconds(timeout)) + if res[1] is not None: + raise res[1] + return self._adapter.adapt(Frame(res[0].frame)) except TimeoutError: return None - def update_channels(self, channels: channel.Params): + def update_channels(self, channels: channel.Params) -> None: """Updates the list of channels to stream. This method will replace the current list of channels with the new list, not add to it. @@ -160,7 +164,7 @@ def update_channels(self, channels: channel.Params): ) ) - def close(self, timeout: float | int | TimeSpan | None = None): + def close(self, timeout: float | int | TimeSpan | None = None) -> None: """Closes the streamer and frees all network resources. :param timeout: The maximum amount of time to wait for the server to acknowledge @@ -186,22 +190,22 @@ def close(self, timeout: float | int | TimeSpan | None = None): raise exc break - def __iter__(self): + def __iter__(self) -> Streamer: """Returns an iterator object that can be used to iterate over the frames of telemetry as they are received. This is useful when you want to process each frame as it is received. """ return self - def __enter__(self): + def __enter__(self) -> Streamer: """Returns the streamer object when used as a context manager.""" return self - def __next__(self): + def __next__(self) -> Frame: """Reads the next frame of telemetry from the streamer.""" return self.read() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: self.close() @@ -237,7 +241,7 @@ def __init__( self._downsample_factor = downsample_factor self._throttle_rate = throttle_rate - async def _open(self): + async def _open(self) -> None: self._stream = await self._client.stream(_ENDPOINT, _Request, _Response) await self._stream.send( _Request( @@ -250,21 +254,18 @@ async def _open(self): if exc is not None: raise exc - @property - def received(self) -> bool: - """Returns True if a frame has been received, False otherwise.""" - return self._stream.received() - async def read(self) -> Frame: """Reads the next frame of telemetry from the streamer. If an error occurs while reading the frame, an exception will be raised. """ - res, exc = await self._stream.receive() - if exc is not None: - raise exc - return self._adapter.adapt(Frame(res.frame)) - - async def close_loop(self): + # mypy does not understand destructured union tuples, so we keep [pld, exc] as + # a single tuple. + res = await self._stream.receive() + if res[1] is not None: + raise res[1] + return self._adapter.adapt(Frame(res[0].frame)) + + async def close_loop(self) -> None: """Closes the sending end of the streamer, requiring the caller to process all remaining frames and close acknowledgements by calling read. This method is useful for managing the lifecycle of a streamer within a separate event loop or @@ -272,7 +273,7 @@ async def close_loop(self): """ await self._stream.close_send() - async def close(self): + async def close(self) -> None: """Close the streamer and free all network resources, waiting for the server to acknowledge the close request. """ @@ -288,24 +289,26 @@ async def close(self): elif not isinstance(exc, EOF): raise exc - async def __aenter__(self): + async def __aenter__(self) -> AsyncStreamer: """Returns the async streamer object when used as an async context manager.""" return self - def __aiter__(self): + def __aiter__(self) -> AsyncStreamer: """Returns an async iterator object that can be used to iterate over the frames of telemetry as they are received. This is useful when you want to process each frame as it is received. """ return self - async def __anext__(self): + async def __anext__(self) -> Frame: """Reads the next frame of telemetry from the streamer.""" try: return await self.read() except EOF: raise StopAsyncIteration - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, exc_type: object, exc_val: object, exc_tb: object + ) -> None: """Closes the streamer when used as an async context manager""" await self.close() diff --git a/client/py/synnax/framer/writer.py b/client/py/synnax/framer/writer.py index b1e512036a..f1ccc846b3 100644 --- a/client/py/synnax/framer/writer.py +++ b/client/py/synnax/framer/writer.py @@ -7,8 +7,10 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. +from __future__ import annotations + from enum import Enum -from typing import Literal, TypeAlias, overload +from typing import Literal, TypeAlias, cast, overload from uuid import uuid4 from freighter import ( @@ -18,10 +20,12 @@ WebsocketClient, decode_exception, ) +from freighter.transport import P from freighter.websocket import Message from pydantic import BaseModel import synnax.channel.payload as channel +from synnax.exceptions import UnexpectedError from synnax.framer.adapter import WriteFrameAdapter from synnax.framer.codec import ( HIGH_PERF_SPECIAL_CHAR, @@ -57,10 +61,10 @@ class WriterMode(int, Enum): class WriterConfig(BaseModel): - authorities: list[int] = Authority.ABSOLUTE + authorities: list[int] = [Authority.ABSOLUTE] control_subject: Subject = Subject(name="", key=str(uuid4())) start: TimeStamp | None = None - keys: list[channel.Key] | tuple[channel.Key] + keys: list[channel.Key] mode: WriterMode = WriterMode.PERSIST_STREAM err_on_unauthorized: bool = False enable_auto_commit: bool = True @@ -80,22 +84,27 @@ class WriterResponse(BaseModel): class WSWriterCodec(WSFramerCodec): - def encode(self, pld: Message[WriterRequest]) -> bytes: - if pld.type == "close" or pld.payload.command != WriterCommand.WRITE: - data = self.lower_perf_codec.encode(pld) - return bytes([LOW_PERF_SPECIAL_CHAR]) + data - data = self.codec.encode(pld.payload.frame, 1) - data = bytearray(data) - data[0] = HIGH_PERF_SPECIAL_CHAR - return bytes(data) - - def decode(self, data: bytes, pld_t: Message[WriterResponse]) -> object: + def encode(self, data: BaseModel) -> bytes: + if not isinstance(data, Message): + raise TypeError(f"expected Message, got {type(data)}") + if ( + data.type == "close" + or data.payload is None + or data.payload.command != WriterCommand.WRITE + ): + return bytes([LOW_PERF_SPECIAL_CHAR]) + self.lower_perf_codec.encode(data) + encoded = self.codec.encode(data.payload.frame, 1) + buf = bytearray(encoded) + buf[0] = HIGH_PERF_SPECIAL_CHAR + return bytes(buf) + + def decode(self, data: bytes, pld_t: type[P]) -> P: if data[0] == LOW_PERF_SPECIAL_CHAR: return self.lower_perf_codec.decode(data[1:], pld_t) frame = self.codec.decode(data, 1) - msg = Message[WriterRequest](type="data") + msg: Message[WriterRequest] = Message(type="data") msg.payload = WriterRequest(command=WriterCommand.WRITE, frame=frame) - return msg + return cast(P, msg) def parse_writer_mode(mode: CrudeWriterMode) -> WriterMode: @@ -118,7 +127,7 @@ def parse_writer_mode(mode: CrudeWriterMode) -> WriterMode: ALWAYS_INDEX_PERSIST_ON_AUTO_COMMIT: TimeSpan = TimeSpan(-1) -class WriterClosed(BaseException): ... +class WriterClosed(Exception): ... class Writer: @@ -175,7 +184,7 @@ def __init__( client: WebsocketClient, adapter: WriteFrameAdapter, name: str = "", - authorities: list[Authority] | Authority = Authority.ABSOLUTE, + authorities: CrudeAuthority | list[CrudeAuthority] = Authority.ABSOLUTE, mode: CrudeWriterMode = WriterMode.PERSIST_STREAM, err_on_unauthorized: bool = False, enable_auto_commit: bool = True, @@ -207,7 +216,9 @@ def __init__( raise exc @overload - def write(self, channels_or_data: channel.Key | str, series: CrudeSeries): ... + def write( + self, channels_or_data: channel.Key | str, series: CrudeSeries + ) -> None: ... @overload def write( @@ -216,13 +227,13 @@ def write( list[channel.Key] | tuple[channel.Key] | list[str] | tuple[str] ), series: list[CrudeSeries], - ): ... + ) -> None: ... @overload def write( self, channels_or_data: CrudeFrame, - ): ... + ) -> None: ... def write( self, @@ -366,16 +377,21 @@ def set_authority( if isinstance(value, int) and authority is None: cfg = WriterConfig(keys=[], authorities=[value]) else: + auth_map: dict[channel.Key | str | channel.Payload, CrudeAuthority] if isinstance(value, (channel.Key, str)): if authority is None: raise ValueError( "authority must be provided when setting a single channel" ) - value = {value: authority} - value = self._adapter.adapt_dict_keys(value) + auth_map = {value: authority} + elif isinstance(value, dict): + auth_map = value + else: + raise ValueError(f"unexpected authority value type: {type(value)}") + resolved = self._adapter.adapt_dict_keys(auth_map) cfg = WriterConfig( - keys=list(value.keys()), - authorities=list(value.values()), + keys=list(resolved.keys()), + authorities=list(resolved.values()), ) self._exec(WriterRequest(command=WriterCommand.SET_AUTHORITY, config=cfg)) @@ -391,9 +407,13 @@ def commit(self) -> TimeStamp: if self._close_exc is not None: raise self._close_exc res = self._exec(WriterRequest(command=WriterCommand.COMMIT)) + if res is None: + raise UnexpectedError("commit did not receive a response") + if res.end is None: + raise UnexpectedError("commit response missing end timestamp") return res.end - def close(self): + def close(self) -> None: """Closes the writer, raising any accumulated error encountered during operation. A writer MUST be closed after use, and this method should probably be placed in a 'finally' block. @@ -412,30 +432,37 @@ def _close(self, exc: Exception | None) -> None: if isinstance(self._close_exc, WriterClosed): return raise self._close_exc - res, exc = self._stream.receive() - if exc is not None: - self._close_exc = WriterClosed() if isinstance(exc, EOF) else exc - else: + res, recv_exc = self._stream.receive() + if recv_exc is not None: + self._close_exc = ( + WriterClosed() if isinstance(recv_exc, EOF) else recv_exc + ) + elif res is not None: self._close_exc = decode_exception(res.err) def _exec( self, req: WriterRequest, timeout: int | None = None ) -> WriterResponse | None: - exc = self._stream.send(req) - if exc is not None: - return self._close(exc) + send_exc = self._stream.send(req) + if send_exc is not None: + self._close(send_exc) + return None while True: - res, exc = self._stream.receive(timeout) - if exc is not None: - return self._close(exc) - exc = decode_exception(res.err) - if exc is not None: - return self._close(exc) + res, recv_exc = self._stream.receive(timeout) + if recv_exc is not None: + self._close(recv_exc) + return None + if res is None: + continue + decoded_exc = decode_exception(res.err) + if decoded_exc is not None: + self._close(decoded_exc) + return None if res.command == req.command: return res - def __enter__(self): + def __enter__(self) -> Writer: return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None: self.close() diff --git a/client/py/synnax/group/__init__.py b/client/py/synnax/group/__init__.py index 168a7e073d..7c249ff333 100644 --- a/client/py/synnax/group/__init__.py +++ b/client/py/synnax/group/__init__.py @@ -9,3 +9,5 @@ from synnax.group.client import Client from synnax.group.payload import Group + +__all__ = ["Client", "Group"] diff --git a/client/py/synnax/ingest/row.py b/client/py/synnax/ingest/row.py index f7b1b8844e..94af99a42d 100644 --- a/client/py/synnax/ingest/row.py +++ b/client/py/synnax/ingest/row.py @@ -58,11 +58,11 @@ def __init__( ) self.end = start - def get_chunk_size(self): + def get_chunk_size(self) -> int: """Sum the density of all channels to determine the chunk size.""" return self.mem_limit // sum(ch.data_type.density for ch in self.channels) - def run(self): + def run(self) -> None: """Run the ingestion engine.""" self.reader.seek_first() try: @@ -88,7 +88,7 @@ def run(self): self.reader.close() self.writer.close() - def _write(self, df: DataFrame): + def _write(self, df: DataFrame) -> None: for channel in self.channels: if channel.name in df.columns: df.rename(columns={channel.name: channel.key}, inplace=True) diff --git a/client/py/synnax/io/__init__.py b/client/py/synnax/io/__init__.py index 07aa7418b1..5f9ff8df71 100644 --- a/client/py/synnax/io/__init__.py +++ b/client/py/synnax/io/__init__.py @@ -17,3 +17,15 @@ ReaderType, RowFileReader, ) + +__all__ = [ + "BaseReader", + "ColumnFileReader", + "DataFrameWriter", + "FileWriter", + "IO_FACTORY", + "IOFactory", + "ReaderType", + "RowFileReader", + "ChannelMeta", +] diff --git a/client/py/synnax/io/csv.py b/client/py/synnax/io/csv.py index 0d49578e99..70550040ac 100644 --- a/client/py/synnax/io/csv.py +++ b/client/py/synnax/io/csv.py @@ -7,6 +7,8 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. +from __future__ import annotations + import csv from collections.abc import Iterator from pathlib import Path @@ -61,7 +63,7 @@ def __detect_delimiter(self) -> str: dialect = csv.Sniffer().sniff(sample) return dialect.delimiter - def seek_first(self): + def seek_first(self) -> None: self.close() self.__reader = pd.read_csv( self._path, @@ -113,7 +115,7 @@ def channels(self) -> list[ChannelMeta]: ] return self._channels - def set_chunk_size(self, chunk_size: int): + def set_chunk_size(self, chunk_size: int) -> None: self.chunk_size = chunk_size def read(self) -> pd.DataFrame: @@ -141,7 +143,7 @@ def _reader(self) -> TextFileReader: assert self.__reader is not None return self.__reader - def close(self): + def close(self) -> None: if self.__reader is not None: self.__reader.close() self.__reader = None @@ -179,14 +181,14 @@ def __init__( def _(self) -> FileWriter: return self - def write(self, df: pd.DataFrame): + def write(self, df: pd.DataFrame) -> None: df.to_csv(self._path, index=False, mode="a", header=self._header) self._header = False def path(self) -> Path: return self._path - def close(self): + def close(self) -> None: pass @@ -196,8 +198,8 @@ class CSVReaderIterator: def __init__(self, base: Iterator[pd.DataFrame]): self.base = base - def __iter__(self): + def __iter__(self) -> CSVReaderIterator: return self - def __next__(self): + def __next__(self) -> pd.DataFrame: return convert_df(next(self.base)) diff --git a/client/py/synnax/io/protocol.py b/client/py/synnax/io/protocol.py index 3d0e8a5889..ea2f2bc4a1 100644 --- a/client/py/synnax/io/protocol.py +++ b/client/py/synnax/io/protocol.py @@ -16,6 +16,16 @@ from synnax.io.meta import ChannelMeta +__all__ = [ + "ChannelMeta", + "ReaderType", + "BaseReader", + "RowFileReader", + "ColumnFileReader", + "DataFrameWriter", + "FileWriter", +] + class ReaderType(Enum): Row = "row" @@ -40,7 +50,7 @@ def match(cls, path: Path) -> bool: class Closer(Protocol): """Closer is a closable buffer""" - def close(self): + def close(self) -> None: """Closes the buffer.""" ... @@ -86,7 +96,7 @@ def nsamples(self) -> int: """:returns: the number of samples in the file.""" ... - def seek_first(self): + def seek_first(self) -> None: """Seeks the reader to the first sample in the file.""" ... @@ -98,7 +108,7 @@ class RowFileReader(BaseReader, Protocol): csv files). """ - def set_chunk_size(self, chunk_size: int): + def set_chunk_size(self, chunk_size: int) -> None: """Set the chunk size for the reader. It's generally unsafe to assume the reader position after calling set_chunk_size, so it's recommended to call reset afterwards. diff --git a/client/py/synnax/io/tdms.py b/client/py/synnax/io/tdms.py index 618e3e0bca..8cbe69ded7 100644 --- a/client/py/synnax/io/tdms.py +++ b/client/py/synnax/io/tdms.py @@ -74,10 +74,10 @@ def channels(self) -> list[ChannelMeta]: return self._channels - def set_chunk_size(self, chunk_size: int): + def set_chunk_size(self, chunk_size: int) -> None: self.chunk_size = chunk_size - def set_keys(self, keys: list[str]): + def set_keys(self, keys: list[str]) -> None: self.channel_keys = set(keys) def set_keys_from_file(self) -> set[str]: @@ -97,7 +97,7 @@ def nsamples(self) -> int: """:returns: the number of samples in the file.""" return self.chunk_size * self.n_chunks * len(self.channels()) - def seek_first(self): + def seek_first(self) -> None: """Seeks the reader to the first sample in the file.""" self._current_chunk = 0 @@ -109,8 +109,7 @@ def read(self, *keys: str) -> pd.DataFrame: if self._current_chunk >= self.n_chunks: return pd.DataFrame() - # if keys is empty, use default keys - keys: set[str] = self.channel_keys if (len(keys) == 0) else set(keys) + _keys: set[str] = self.channel_keys if (len(keys) == 0) else set(keys) # https://nptdms.readthedocs.io/en/stable/reading.html # https://github.com/adamreeve/npTDMS/issues/263 @@ -119,7 +118,7 @@ def read(self, *keys: str) -> pd.DataFrame: with TdmsFile.open(self._path) as tdms_file: for group in tdms_file.groups(): for channel in group.channels(): - if channel.name in keys: + if channel.name in _keys: data[channel.name] = channel[ self._current_chunk * self.chunk_size : (self._current_chunk + 1) diff --git a/client/py/synnax/labjack/types.py b/client/py/synnax/labjack/types.py index ae8ea87e35..343228da2b 100644 --- a/client/py/synnax/labjack/types.py +++ b/client/py/synnax/labjack/types.py @@ -8,10 +8,10 @@ # included in the file licenses/APL.txt. import json -from typing import Literal, get_args +from typing import Annotated, Any, Literal, TypeAlias, get_args from uuid import uuid4 -from pydantic import BaseModel, Field, confloat, conint, field_validator +from pydantic import BaseModel, Field, field_validator from synnax import channel as channel_ from synnax import device, task @@ -23,7 +23,7 @@ T4 = "LJM_dtT4" T7 = "LJM_dtT7" T8 = "LJM_dtT8" -SUPPORTED_MODELS = Literal[T4, T7, T8] +SUPPORTED_MODELS: TypeAlias = Literal["LJM_dtT4", "LJM_dtT7", "LJM_dtT8"] class BaseChan(BaseModel): @@ -36,7 +36,7 @@ class BaseChan(BaseModel): port: str = Field(min_length=1) "The port location of the channel (e.g., 'AIN0', 'DIO4')." - def __init__(self, **data): + def __init__(self, **data: Any) -> None: if "key" not in data or not data["key"]: data["key"] = str(uuid4()) super().__init__(**data) @@ -96,7 +96,7 @@ class AIChan(BaseChan): type: Literal["AI"] = "AI" channel: channel_.Key "The Synnax channel key that will be written to during acquisition." - range: confloat(gt=0) = 10.0 + range: Annotated[float, Field(gt=0)] = 10.0 "The voltage range for the channel (±range volts)." neg_chan: int = 199 "The negative channel for differential measurements. 199 = single-ended (GND)." @@ -336,13 +336,13 @@ class ReadTaskConfig(task.BaseReadConfig): device: str = Field(min_length=1) "The key of the Synnax LabJack device to read from." - sample_rate: conint(ge=0, le=100000) - stream_rate: conint(ge=0, le=100000) + sample_rate: Annotated[int, Field(ge=0, le=100000)] + stream_rate: Annotated[int, Field(ge=0, le=100000)] channels: list[InputChan] "A list of input channel configurations to acquire data from." @field_validator("channels") - def validate_channels_not_empty(cls, v): + def validate_channels_not_empty(cls, v: list[InputChan]) -> list[InputChan]: """Validate that at least one channel is provided.""" if len(v) == 0: raise ValueError("Task must have at least one channel") @@ -360,13 +360,13 @@ class WriteTaskConfig(task.BaseWriteConfig): data_saving: bool = True "Whether to persist state feedback data to disk (True) or only stream it (False)." - state_rate: conint(ge=0, le=10000) + state_rate: Annotated[int, Field(ge=0, le=10000)] "The rate at which to write task channel states to the Synnax cluster (Hz)." channels: list[OutputChan] "A list of output channel configurations to write to." @field_validator("channels") - def validate_channels_not_empty(cls, v): + def validate_channels_not_empty(cls, v: list[OutputChan]) -> list[OutputChan]: """Validate that at least one channel is provided.""" if len(v) == 0: raise ValueError("Task must have at least one channel") @@ -410,7 +410,7 @@ def __init__( stream_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[InputChan] = None, + channels: list[InputChan] | None = None, ) -> None: if internal is not None: self._internal = internal @@ -478,7 +478,7 @@ def __init__( state_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[OutputChan] = None, + channels: list[OutputChan] | None = None, ): if internal is not None: self._internal = internal diff --git a/client/py/synnax/modbus/__init__.py b/client/py/synnax/modbus/__init__.py index d49ecf3765..5b3c02904c 100644 --- a/client/py/synnax/modbus/__init__.py +++ b/client/py/synnax/modbus/__init__.py @@ -21,3 +21,17 @@ WriteTask, WriteTaskConfig, ) + +__all__ = [ + "CoilInputChan", + "CoilOutputChan", + "Device", + "DiscreteInputChan", + "HoldingRegisterInputChan", + "HoldingRegisterOutputChan", + "InputRegisterChan", + "ReadTask", + "ReadTaskConfig", + "WriteTask", + "WriteTaskConfig", +] diff --git a/client/py/synnax/modbus/types.py b/client/py/synnax/modbus/types.py index bad0f901f3..95c85a589b 100644 --- a/client/py/synnax/modbus/types.py +++ b/client/py/synnax/modbus/types.py @@ -8,10 +8,10 @@ # included in the file licenses/APL.txt. import json -from typing import Literal +from typing import Annotated, Any, Literal from uuid import uuid4 -from pydantic import BaseModel, Field, confloat, conint, field_validator +from pydantic import BaseModel, Field, field_validator from synnax import channel as channel_ from synnax import device, task @@ -32,7 +32,7 @@ class BaseChan(BaseModel): address: int = Field(ge=0, le=65535) "The Modbus register address (0-65535)." - def __init__(self, **data): + def __init__(self, **data: Any) -> None: if "key" not in data or not data["key"]: data["key"] = str(uuid4()) super().__init__(**data) @@ -392,13 +392,13 @@ class ReadTaskConfig(task.BaseReadConfig): device: str = Field(min_length=1) "The key of the Synnax Modbus device to read from." - sample_rate: conint(ge=0, le=10000) - stream_rate: conint(ge=0, le=10000) + sample_rate: Annotated[int, Field(ge=0, le=10000)] + stream_rate: Annotated[int, Field(ge=0, le=10000)] channels: list[InputChan] "A list of input channel configurations to acquire data from." @field_validator("channels") - def validate_channels_not_empty(cls, v): + def validate_channels_not_empty(cls, v: list[InputChan]) -> list[InputChan]: """Validate that at least one channel is provided.""" if len(v) == 0: raise ValueError("Task must have at least one channel") @@ -418,7 +418,7 @@ class WriteTaskConfig(task.BaseWriteConfig): "A list of output channel configurations to write to." @field_validator("channels") - def validate_channels_not_empty(cls, v): + def validate_channels_not_empty(cls, v: list[OutputChan]) -> list[OutputChan]: """Validate that at least one channel is provided.""" if len(v) == 0: raise ValueError("Task must have at least one channel") @@ -462,7 +462,7 @@ def __init__( stream_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[InputChan] = None, + channels: list[InputChan] | None = None, ) -> None: if internal is not None: self._internal = internal @@ -547,7 +547,7 @@ def __init__( device: device.Key = "", name: str = "", auto_start: bool = False, - channels: list[OutputChan] = None, + channels: list[OutputChan] | None = None, ): if internal is not None: self._internal = internal diff --git a/client/py/synnax/ni/types.py b/client/py/synnax/ni/types.py index 5c799733fd..50103e5096 100644 --- a/client/py/synnax/ni/types.py +++ b/client/py/synnax/ni/types.py @@ -8,10 +8,10 @@ # included in the file licenses/APL.txt. import json -from typing import Literal +from typing import Annotated, Any, Literal from uuid import uuid4 -from pydantic import BaseModel, Field, confloat, conint, field_validator, validator +from pydantic import BaseModel, Field, field_validator, validator from synnax import device, task from synnax.exceptions import ValidationError @@ -160,7 +160,7 @@ class BaseChan(BaseModel): key: str enabled: bool = True - def __init__(self, **data): + def __init__(self, **data: Any) -> None: if "key" not in data: data["key"] = str(uuid4()) super().__init__(**data) @@ -349,7 +349,7 @@ class AIBridgeChan(BaseAIChan, MinMaxVal): bridge_config: Literal["FullBridge", "HalfBridge", "QuarterBridge"] voltage_excit_source: ExcitationSource voltage_excit_val: float - nominal_bridge_resistance: confloat(gt=0) + nominal_bridge_resistance: Annotated[float, Field(gt=0)] custom_scale: Scale = NoScale() @@ -425,7 +425,7 @@ class AICurrentChan(BaseAIChan, MinMaxVal): terminal_config: TerminalConfig = "Cfg_Default" units: Literal["Amps"] = "Amps" shunt_resistor_loc: Literal["Default", "Internal", "External"] - ext_shunt_resistor_val: confloat(gt=0) + ext_shunt_resistor_val: Annotated[float, Field(gt=0)] custom_scale: Scale = NoScale() @@ -466,7 +466,7 @@ class AICurrentRMSChan(BaseAIChan, MinMaxVal): terminal_config: TerminalConfig = "Cfg_Default" units: Literal["Amps"] = "Amps" shunt_resistor_loc: Literal["Default", "Internal", "External"] - ext_shunt_resistor_val: confloat(gt=0) + ext_shunt_resistor_val: Annotated[float, Field(gt=0)] custom_scale: Scale = NoScale() @@ -2412,11 +2412,11 @@ class AnalogReadTaskConfig(task.BaseReadConfig): device: str = "" "The key of the Synnax NI device to read from (optional, can be set per channel)." - sample_rate: conint(gt=0, le=1000000) + sample_rate: Annotated[int, Field(gt=0, le=1000000)] channels: list[AIChan] @field_validator("channels") - def validate_channel_ports(cls, v, values): + def validate_channel_ports(cls, v: list[AIChan], values: Any) -> list[AIChan]: ports = {c.port for c in v} if len(ports) < len(v): used_ports = [c.port for c in v] @@ -2435,7 +2435,7 @@ class AnalogWriteConfig(task.BaseWriteConfig): data_saving: bool = True "Whether to persist state feedback data to disk (True) or only stream it (False)." - state_rate: conint(gt=0, le=50000) + state_rate: Annotated[int, Field(gt=0, le=50000)] "The rate at which to write task channel states to the Synnax cluster (Hz)." channels: list[AOChan] @@ -2450,11 +2450,11 @@ class CounterReadConfig(task.BaseReadConfig): device: str = "" "The key of the Synnax NI device to read from (optional, can be set per channel)." - sample_rate: conint(gt=0, le=1000000) + sample_rate: Annotated[int, Field(gt=0, le=1000000)] channels: list[CIChan] @field_validator("channels") - def validate_channel_ports(cls, v): + def validate_channel_ports(cls, v: list[CIChan]) -> list[CIChan]: ports = {c.port for c in v} if len(ports) < len(v): used_ports = [c.port for c in v] @@ -2473,7 +2473,7 @@ class DigitalReadConfig(task.BaseReadConfig): device: str = Field(min_length=1) "The key of the Synnax NI device to read from." - sample_rate: conint(gt=0, le=1000000) + sample_rate: Annotated[int, Field(gt=0, le=1000000)] channels: list[DIChan] @@ -2487,7 +2487,7 @@ class DigitalWriteConfig(task.BaseWriteConfig): data_saving: bool = True "Whether to persist state feedback data to disk (True) or only stream it (False)." - state_rate: conint(gt=0, le=50000) + state_rate: Annotated[int, Field(gt=0, le=50000)] "The rate at which to write task channel states to the Synnax cluster (Hz)." channels: list[DOChan] @@ -2534,7 +2534,7 @@ def __init__( stream_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[AIChan] = None, + channels: list[AIChan] | None = None, ) -> None: if internal is not None: self._internal = internal @@ -2592,7 +2592,7 @@ def __init__( state_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[AOChan] = None, + channels: list[AOChan] | None = None, ): if internal is not None: self._internal = internal @@ -2641,7 +2641,7 @@ def __init__( stream_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[CIChan] = None, + channels: list[CIChan] | None = None, ) -> None: if internal is not None: self._internal = internal @@ -2702,7 +2702,7 @@ def __init__( stream_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[DIChan] = None, + channels: list[DIChan] | None = None, ) -> None: if internal is not None: self._internal = internal @@ -2748,7 +2748,7 @@ def __init__( state_rate: CrudeRate = 0, data_saving: bool = False, auto_start: bool = False, - channels: list[DOChan] = None, + channels: list[DOChan] | None = None, ): if internal is not None: self._internal = internal diff --git a/client/py/synnax/ontology/client.py b/client/py/synnax/ontology/client.py index 3c4fcf904b..b0a52f275a 100644 --- a/client/py/synnax/ontology/client.py +++ b/client/py/synnax/ontology/client.py @@ -68,7 +68,7 @@ def retrieve( @overload def retrieve( self, - ids: list[CrudeID], + id: list[CrudeID], *, children: bool = False, parents: bool = False, @@ -78,7 +78,7 @@ def retrieve( def retrieve( self, - id: CrudeID | list[CrudeID], + id: CrudeID | list[CrudeID] | None = None, *, children: bool = False, parents: bool = False, @@ -87,8 +87,11 @@ def retrieve( ) -> Resource | list[Resource]: is_single = False if not isinstance(id, list): - id = [id] - is_single = True + if id is None: + id = [] + else: + id = [id] + is_single = True req = RetrieveReq( ids=[ID(i) for i in id], children=children, @@ -105,8 +108,9 @@ def retrieve_children( self, id: CrudeID | list[CrudeID], ) -> list[Resource]: + normalized: list[CrudeID] = normalize(id) return self.__exec_retrieve( - RetrieveReq(ids=[ID(i) for i in normalize(id)], children=True) + RetrieveReq(ids=[ID(i) for i in normalized], children=True) ) def __exec_retrieve(self, req: RetrieveReq) -> list[Resource]: @@ -117,9 +121,10 @@ def __exec_retrieve(self, req: RetrieveReq) -> list[Resource]: def retrieve_parents( self, id: CrudeID | list[CrudeID], - ): + ) -> list[Resource]: + normalized: list[CrudeID] = normalize(id) return self.__exec_retrieve( - RetrieveReq(ids=[ID(i) for i in normalize(id)], parents=True) + RetrieveReq(ids=[ID(i) for i in normalized], parents=True) ) def move_children(self, from_: CrudeID, to: CrudeID, *children: CrudeID) -> None: @@ -145,6 +150,6 @@ def add_children(self, id: CrudeID, *children: CrudeID) -> None: send_required( self._client, "/ontology/add-children", - AddChildrenReq(parent=ID(id), children=[ID(i) for i in children]), + AddChildrenReq(id=ID(id), children=[ID(i) for i in children]), Empty, ) diff --git a/client/py/synnax/ontology/payload.py b/client/py/synnax/ontology/payload.py index 6e0dea0886..b7f8dc720b 100644 --- a/client/py/synnax/ontology/payload.py +++ b/client/py/synnax/ontology/payload.py @@ -9,6 +9,8 @@ from __future__ import annotations +from typing import Any + from pydantic import BaseModel @@ -22,13 +24,13 @@ def __init__(self, key: CrudeID | None = None, type: str | None = None): elif isinstance(key, tuple): key, type = key super().__init__(key=key, type=type) - elif type is None: + elif type is None and key is not None: type, key = key.split(":") super().__init__(key=key, type=type) else: super().__init__(key=key, type=type) - def __str__(self): + def __str__(self) -> str: return f"{self.key}:{self.type}" @@ -40,7 +42,7 @@ def __str__(self): class Resource(BaseModel): id: ID name: str - data: dict + data: dict[str, Any] class Relationship(BaseModel): diff --git a/client/py/synnax/opcua/types.py b/client/py/synnax/opcua/types.py index 2f87b857ed..abfaeda85d 100644 --- a/client/py/synnax/opcua/types.py +++ b/client/py/synnax/opcua/types.py @@ -315,6 +315,7 @@ class ReadTask(task.StarterStopperMixin, task.JSONConfigMixin, task.Protocol): """ TYPE = "opc_read" + config: NonArraySamplingReadTaskConfig | ArraySamplingReadTaskConfig _internal: task.Task def __init__( @@ -329,7 +330,7 @@ def __init__( auto_start: bool = False, array_mode: bool = False, array_size: int = 1, - channels: list[ReadChannel] = None, + channels: list[ReadChannel] | None = None, ): if internal is not None: self._internal = internal @@ -393,7 +394,7 @@ class WriteTaskConfig(task.BaseWriteConfig): "A list of WriteChannel objects that specify which OPC UA nodes to write to." @field_validator("channels") - def validate_channels_not_empty(cls, v): + def validate_channels_not_empty(cls, v: list[WriteChannel]) -> list[WriteChannel]: """Validate that at least one channel is provided.""" if len(v) == 0: raise ValueError("Task must have at least one channel") @@ -425,7 +426,7 @@ def __init__( device: device.Key = "", name: str = "", auto_start: bool = False, - channels: list[WriteChannel] = None, + channels: list[WriteChannel] | None = None, ): if internal is not None: self._internal = internal diff --git a/client/py/synnax/rack/client.py b/client/py/synnax/rack/client.py index 3ebe00e10d..82b0867e0d 100644 --- a/client/py/synnax/rack/client.py +++ b/client/py/synnax/rack/client.py @@ -61,7 +61,7 @@ def __init__( def create(self, *, key: int = 0, name: str = "") -> Rack: ... @overload - def create(self, rack: Rack) -> Rack: ... + def create(self, racks: Rack) -> Rack: ... @overload def create(self, racks: list[Rack]) -> list[Rack]: ... @@ -86,7 +86,7 @@ def create( return res.racks[0] return res.racks - def delete(self, keys: list[int]): + def delete(self, keys: list[int]) -> None: req = _DeleteRequest(keys=keys) send_required(self._client, "/rack/delete", req, Empty) @@ -95,10 +95,23 @@ def retrieve( self, key: int | None = None, name: str | None = None, + *, embedded: bool | None = None, host_is_node: bool | None = None, ) -> Rack: ... + @overload + def retrieve( + self, + key: int | None = None, + name: str | None = None, + keys: list[int] | None = None, + names: list[str] | None = None, + *, + embedded: bool | None = None, + host_is_node: bool | None = None, + ) -> list[Rack]: ... + def retrieve( self, key: int | None = None, @@ -108,7 +121,7 @@ def retrieve( *, host_is_node: bool | None = None, embedded: bool | None = None, - ) -> list[Rack]: + ) -> Rack | list[Rack]: is_single = check_for_none(keys, names) res = send_required( self._client, diff --git a/client/py/synnax/rack/payload.py b/client/py/synnax/rack/payload.py index cdab386da6..c96064ac00 100644 --- a/client/py/synnax/rack/payload.py +++ b/client/py/synnax/rack/payload.py @@ -17,7 +17,7 @@ def ontology_id(key: int) -> ontology.ID: """Returns the ontology ID for the Rack entity.""" - return ontology.ID(type=ONTOLOGY_TYPE.type, key=key) + return ontology.ID(type=ONTOLOGY_TYPE.type, key=str(key)) class StatusDetails(BaseModel): diff --git a/client/py/synnax/ranger/__init__.py b/client/py/synnax/ranger/__init__.py index 8aa95ddf40..eb19323311 100644 --- a/client/py/synnax/ranger/__init__.py +++ b/client/py/synnax/ranger/__init__.py @@ -10,7 +10,13 @@ from synnax.ranger.client import Client, Range from synnax.ranger.retrieve import Retriever from synnax.ranger.writer import Writer +from synnax.util.deprecation import deprecated_getattr -# Backwards compatibility -RangeRetriever = Retriever -RangeWriter = Writer +_DEPRECATED = { + "RangeRetriever": "Retriever", + "RangeWriter": "Writer", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) + +__all__ = ["Client", "Range", "Retriever", "Writer"] diff --git a/client/py/synnax/ranger/alias/__init__.py b/client/py/synnax/ranger/alias/__init__.py index 0d3814fc9a..7f38346805 100644 --- a/client/py/synnax/ranger/alias/__init__.py +++ b/client/py/synnax/ranger/alias/__init__.py @@ -11,3 +11,5 @@ Aliaser = Client """Deprecated: Use Client instead.""" + +__all__ = ["Client", "Aliaser"] diff --git a/client/py/synnax/ranger/alias/client.py b/client/py/synnax/ranger/alias/client.py index 0960b370b1..d0cfabe757 100644 --- a/client/py/synnax/ranger/alias/client.py +++ b/client/py/synnax/ranger/alias/client.py @@ -8,6 +8,7 @@ # included in the file licenses/APL.txt. import uuid +from typing import overload from freighter import UnaryClient from pydantic import BaseModel @@ -42,17 +43,19 @@ def __init__(self, rng: uuid.UUID, client: UnaryClient) -> None: self.__rng = rng self.__cache = {} - def resolve(self, alias: str) -> channel.Key: ... + @overload + def resolve(self, aliases: str) -> channel.Key: ... + @overload def resolve(self, aliases: list[str]) -> dict[str, channel.Key]: ... def resolve(self, aliases: str | list[str]) -> dict[str, channel.Key] | channel.Key: - to_fetch = list() - aliases = normalize(aliases) is_single = isinstance(aliases, str) + normalized_aliases = normalize(aliases) + to_fetch: list[str] = list() - results = {} - for alias in aliases: + results: dict[str, channel.Key] = {} + for alias in normalized_aliases: key = self.__cache.get(alias, None) if key is not None: results[alias] = key @@ -60,18 +63,24 @@ def resolve(self, aliases: str | list[str]) -> dict[str, channel.Key] | channel. to_fetch.append(alias) if len(to_fetch) == 0: + if is_single: + return results[normalized_aliases[0]] return results req = _ResolveRequest(range=self.__rng, aliases=to_fetch) res, exc = self.__client.send("/range/alias/resolve", req, _ResolveResponse) if exc is not None: raise exc + if res is None: + if is_single: + raise KeyError(f"Alias not found: {aliases}") + return results for alias, key in res.aliases.items(): self.__cache[alias] = key if is_single: - return res.aliases[aliases] + return res.aliases[normalized_aliases[0]] return {**results, **res.aliases} def set(self, aliases: dict[channel.Key, str]) -> None: diff --git a/client/py/synnax/ranger/client.py b/client/py/synnax/ranger/client.py index 29c52b5c52..5c2bcadb36 100644 --- a/client/py/synnax/ranger/client.py +++ b/client/py/synnax/ranger/client.py @@ -12,8 +12,8 @@ import functools import warnings -from collections.abc import Callable -from typing import overload +from collections.abc import Callable, Iterator +from typing import Any, cast, overload from uuid import UUID import numpy as np @@ -21,15 +21,13 @@ from pydantic import PrivateAttr import synnax.channel.payload as channel +import synnax.ranger.alias as alias +import synnax.ranger.kv as kv +from synnax import framer from synnax.channel.retrieve import Retriever as ChannelRetriever from synnax.exceptions import QueryError -from synnax.framer.client import Client -from synnax.framer.frame import CrudeFrame -from synnax.ni import AnalogReadTask from synnax.ontology import Client as OntologyClient from synnax.ontology.payload import ID -from synnax.ranger.alias import Client as AliasClient -from synnax.ranger.kv import Client as KVClient from synnax.ranger.payload import ( Key, Payload, @@ -44,6 +42,7 @@ from synnax.telem import ( CrudeSeries, DataType, + MultiSeries, Rate, SampleValue, Series, @@ -59,27 +58,27 @@ class _InternalScopedChannel(channel.Payload): __range: Range | None = PrivateAttr(None) """The range that this channel belongs to.""" - __frame_client: Client | None = PrivateAttr(None) + __frame_client: framer.Client | None = PrivateAttr(None) """The frame client for executing read operations.""" - __aliaser: AliasClient | None = PrivateAttr(None) + __aliaser: alias.Client | None = PrivateAttr(None) """An aliaser for setting the channel's alias.""" - __cache: Series | None = PrivateAttr(None) + __cache: MultiSeries | None = PrivateAttr(None) """An internal cache to prevent repeated reads from the same channel.""" __tasks: TaskClient | None = PrivateAttr(None) __ontology: OntologyClient | None = PrivateAttr(None) - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> _InternalScopedChannel: cls = overload_comparison_operators(cls, "__array__") - return super().__new__(cls) + return super().__new__(cls) # type: ignore[no-any-return] def __init__( self, rng: Range, - frame_client: Client, + frame_client: framer.Client, tasks: TaskClient, ontology: OntologyClient, payload: channel.Payload, - aliaser: AliasClient | None = None, + aliaser: alias.Client | None = None, ): super().__init__(**payload.model_dump()) self.__range = rng @@ -89,21 +88,16 @@ def __init__( self.__ontology = ontology @property - def time_range(self) -> TimeRange: - return self.__range.time_range + def _range(self) -> Range: + if self.__range is None: + raise _RANGE_NOT_CREATED + return self.__range @property - def calibrations(self): - snapshots = self.__range.snapshots() - ni_tasks = [AnalogReadTask(t) for t in snapshots] - for t in ni_tasks: - for chan in t.config.channels: - if chan.channel == self.key: - return chan - - return None + def time_range(self) -> TimeRange: + return self._range.time_range - def __array__(self, *args, **kwargs) -> np.ndarray: + def __array__(self, *args: Any, **kwargs: Any) -> np.ndarray: """Converts the channel to a numpy array. This method is necessary for numpy interop.""" return self.read().__array__(*args, **kwargs) @@ -117,18 +111,24 @@ def to_numpy(self) -> np.ndarray: """ return self.read().to_numpy() - def read(self) -> Series: + @property + def _frame_client(self) -> framer.Client: + if self.__frame_client is None: + raise _RANGE_NOT_CREATED + return self.__frame_client + + def read(self) -> MultiSeries: if self.__cache is None: - self.__cache = self.__frame_client.read(self.time_range, self.key) + self.__cache = self._frame_client.read(self.time_range, self.key) return self.__cache - def set_alias(self, alias: str): - self.__range.set_alias(self.key, alias) + def set_alias(self, alias: str) -> None: + self._range.set_alias(self.key, alias) def __str__(self) -> str: return f"{super().__str__()} between {self.time_range.start} and {self.time_range.end}" - def __len__(self): + def __len__(self) -> int: return len(self.read()) @@ -148,9 +148,9 @@ class ScopedChannel: __internal: list[_InternalScopedChannel] __query: str - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> ScopedChannel: cls = overload_comparison_operators(cls, "__array__") - return super().__new__(cls) + return super().__new__(cls) # type: ignore[no-any-return] def __init__( self, @@ -160,13 +160,13 @@ def __init__( self.__internal = internal self.__query = query - def __guard(self): + def __guard(self) -> None: if len(self.__internal) > 1: raise QueryError(f"""Multiple channels found for query '{self.__query}': {[str(ch) for ch in self.__internal]} """) - def __array__(self, *args, **kwargs): + def __array__(self, *args: object, **kwargs: object) -> np.ndarray: """Converts the scoped channel to a numpy array. This method is necessary for numpy interop.""" self.__guard() @@ -211,24 +211,14 @@ def leaseholder(self) -> int: self.__guard() return self.__internal[0].leaseholder - @property - def rate(self) -> Rate: - self.__guard() - return self.__internal[0].rate - - @property - def calibrations(self): - self.__guard() - return self.__internal[0].calibrations - - def set_alias(self, alias: str): + def set_alias(self, alias: str) -> None: self.__guard() self.__internal[0].set_alias(alias) - def __iter__(self): + def __iter__(self) -> Iterator[_InternalScopedChannel]: return iter(self.__internal) - def __len__(self): + def __len__(self) -> int: return sum(len(ch) for ch in self.__internal) @@ -243,19 +233,19 @@ class Range(Payload): and how they work. """ - __frame_client: Client | None = PrivateAttr(None) + __frame_client: framer.Client | None = PrivateAttr(None) """The frame client for executing read and write operations.""" _channels: ChannelRetriever | None = PrivateAttr(None) """For retrieving channels from the cluster.""" - _kv: KVClient | None = PrivateAttr(None) + _kv: kv.Client | None = PrivateAttr(None) """Key-value store for storing metadata about the range.""" - __aliaser: AliasClient | None = PrivateAttr(None) + __aliaser: alias.Client | None = PrivateAttr(None) """For setting and resolving aliases.""" - _cache: dict[Key, _InternalScopedChannel] = PrivateAttr(dict()) + _cache: dict[channel.Key, _InternalScopedChannel] = PrivateAttr(dict()) """A writer for creating child ranges""" - _client: Client | None = PrivateAttr(None) - _tasks: TaskClient | None = PrivateAttr(None) - _ontology: OntologyClient | None = PrivateAttr(None) + __client: Client | None = PrivateAttr(None) + __tasks: TaskClient | None = PrivateAttr(None) + __ontology: OntologyClient | None = PrivateAttr(None) def __init__( self, @@ -264,10 +254,10 @@ def __init__( key: UUID = UUID(int=0), color: str = "", *, - _frame_client: Client | None = None, + _frame_client: framer.Client | None = None, _channel_retriever: ChannelRetriever | None = None, - _kv: KVClient | None = None, - _aliaser: AliasClient | None = None, + _kv: kv.Client | None = None, + _aliaser: alias.Client | None = None, _client: Client | None = None, _tasks: TaskClient | None = None, _ontology: OntologyClient | None = None, @@ -296,18 +286,21 @@ def __init__( self._channels = _channel_retriever self._kv = _kv self.__aliaser = _aliaser - self._client = _client - self._tasks = _tasks - self._ontology = _ontology + self.__client = _client + self.__tasks = _tasks + self.__ontology = _ontology - def _get_scoped_channel(self, channels: list[Payload], query: str) -> ScopedChannel: + def _get_scoped_channel( + self, channels: list[channel.Payload], query: str + ) -> ScopedChannel: if len(channels) == 0: raise QueryError(f"Channel matching {query} not found") return ScopedChannel(query, self.__splice_cached(channels)) def __getattr__(self, query: str) -> ScopedChannel: try: - return super().__getattr__(query) + # BaseModel.__getattr__ exists at runtime but not in mypy stubs + return super().__getattr__(query) # type: ignore[misc,no-any-return] except AttributeError: pass channels = self._channel_retriever.retrieve(query) @@ -321,7 +314,9 @@ def __getitem__(self, name: str | channel.Key) -> ScopedChannel: return self._get_scoped_channel(channels, name.__str__()) return self.__getattr__(name) - def __splice_cached(self, channels: list[Payload]) -> list[_InternalScopedChannel]: + def __splice_cached( + self, channels: list[channel.Payload] + ) -> list[_InternalScopedChannel]: results = list() for pld in channels: cached = self._cache.get(pld.key, None) @@ -342,23 +337,41 @@ def ontology_id(self) -> ID: return ontology_id(self.key) @property - def meta_data(self): + def meta_data(self) -> kv.Client: if self._kv is None: raise _RANGE_NOT_CREATED return self._kv @property - def _aliaser(self): + def _aliaser(self) -> alias.Client: if self.__aliaser is None: raise _RANGE_NOT_CREATED return self.__aliaser @property - def _frame_client(self) -> Client: + def _frame_client(self) -> framer.Client: if self.__frame_client is None: raise _RANGE_NOT_CREATED return self.__frame_client + @property + def _client(self) -> Client: + if self.__client is None: + raise _RANGE_NOT_CREATED + return self.__client + + @property + def _tasks(self) -> TaskClient: + if self.__tasks is None: + raise _RANGE_NOT_CREATED + return self.__tasks + + @property + def _ontology(self) -> OntologyClient: + if self.__ontology is None: + raise _RANGE_NOT_CREATED + return self.__ontology + @property def _channel_retriever(self) -> ChannelRetriever: if self._channels is None: @@ -366,16 +379,16 @@ def _channel_retriever(self) -> ChannelRetriever: return self._channels @overload - def set_alias(self, channel: channel.Key | str, alias: str): ... + def set_alias(self, channel: channel.Key | str, alias: str) -> None: ... @overload - def set_alias(self, channel: dict[channel.Key | str, str]): ... + def set_alias(self, channel: dict[channel.Key | str, str]) -> None: ... def set_alias( self, channel: channel.Key | str | dict[channel.Key | str, str], - alias: str = None, - ): + alias: str | None = None, + ) -> None: if not isinstance(channel, dict): if alias is None: raise ValueError("Alias must be provided if channel is not a dict") @@ -394,32 +407,28 @@ def set_alias( def to_payload(self) -> Payload: return Payload(name=self.name, time_range=self.time_range, key=self.key) - @overload - def write(self, to: channel.Key | str | Payload, data: CrudeSeries): ... - @overload def write( - self, - to: ( - list[channel.Key] - | tuple[channel.Key] - | list[str] - | tuple[str] - | list[Payload] - ), - series: list[CrudeSeries], - ): ... + self, channels: channel.Params, series: CrudeSeries | list[CrudeSeries] + ) -> None: ... @overload - def write(self, frame: CrudeFrame): ... + def write(self, channels: framer.CrudeFrame) -> None: ... def write( self, - to: channel.Params | Payload | list[Payload] | CrudeFrame, + channels: channel.Params | framer.CrudeFrame, series: CrudeSeries | list[CrudeSeries] | None = None, ) -> None: start = self.time_range.start - self.__frame_client.write(start, to, series) + if series is None: + self._frame_client.write(start, cast(framer.CrudeFrame, channels)) + return + if not isinstance(channels, (int, str, list, tuple, channel.Payload)): + raise TypeError( + "channels must be a channel key, name, or list when series is provided" + ) + self._frame_client.write(start, channels, series) def create_child_range( self, @@ -469,16 +478,21 @@ def children(self) -> list[Range]: range_children = [r for r in res if r.id.type == "range"] if len(range_children) == 0: return [] - return self._client.retrieve(keys=[r.id.key for r in range_children]) + child_keys: list[Key] = [ + r.id.key for r in range_children if r.id.key is not None + ] + return self._client.retrieve(keys=child_keys) def snapshots(self) -> list[Task]: res = self._ontology.retrieve_children(self.ontology_id) tasks = [t for t in res if t.id.type == "task"] - return self._tasks.retrieve(keys=[t.id.key for t in tasks]) + return self._tasks.retrieve( + keys=[int(t.id.key) for t in tasks if t.id.key is not None] + ) class Client: - _frame_client: Client + _frame_client: framer.Client _channels: ChannelRetriever _retriever: Retriever _writer: Writer @@ -490,7 +504,7 @@ class Client: def __init__( self, unary_client: UnaryClient, - frame_client: Client, + frame_client: framer.Client, writer: Writer, retriever: Retriever, channel_retriever: ChannelRetriever, @@ -536,6 +550,7 @@ def create( def create( self, ranges: Range, + *, retrieve_if_name_exists: bool = False, parent: ID | None = None, ) -> Range: @@ -555,6 +570,7 @@ def create( def create( self, ranges: list[Range], + *, retrieve_if_name_exists: bool = False, parent: ID | None = None, ) -> list[Range]: @@ -615,6 +631,24 @@ def create( res.extend(self.__sugar(self._writer.create(to_create, parent=parent))) return res if not is_single else res[0] + @overload + def retrieve( + self, + *, + key: Key | None = None, + name: str | None = None, + ) -> Range: ... + + @overload + def retrieve( + self, + *, + key: None = None, + name: None = None, + names: list[str] | tuple[str] | None = None, + keys: list[Key] | tuple[Key] | None = None, + ) -> list[Range]: ... + @require_named_params(example_params=("name", "My Range")) def retrieve( self, @@ -648,14 +682,14 @@ def search( _ranges = self._retriever.search(term) return self.__sugar(_ranges) - def __sugar(self, ranges: list[Payload]): + def __sugar(self, ranges: list[Payload]) -> list[Range]: return [ Range( **r.model_dump(), _frame_client=self._frame_client, _channel_retriever=self._channels, - _kv=KVClient(r.key, self._unary_client), - _aliaser=AliasClient(r.key, self._unary_client), + _kv=kv.Client(r.key, self._unary_client), + _aliaser=alias.Client(r.key, self._unary_client), _client=self, _ontology=self._ontology, _tasks=self._tasks, @@ -663,9 +697,9 @@ def __sugar(self, ranges: list[Payload]): for r in ranges ] - def on_create(self, f: Callable[[Range], None]): + def on_create(self, f: Callable[[Range], None]) -> Callable[[LatestState], None]: @functools.wraps(f) - def wrapper(state: LatestState): + def wrapper(state: LatestState) -> None: d = state[RANGE_SET_CHANNEL] f( Range( @@ -677,8 +711,8 @@ def wrapper(state: LatestState): ), _frame_client=self._frame_client, _channel_retriever=self._channels, - _kv=KVClient(d["key"], self._unary_client), - _aliaser=AliasClient(d["key"], self._unary_client), + _kv=kv.Client(d["key"], self._unary_client), + _aliaser=alias.Client(d["key"], self._unary_client), ) ) diff --git a/client/py/synnax/ranger/kv/__init__.py b/client/py/synnax/ranger/kv/__init__.py index 7d8b39890c..5eabf81cdd 100644 --- a/client/py/synnax/ranger/kv/__init__.py +++ b/client/py/synnax/ranger/kv/__init__.py @@ -15,3 +15,5 @@ KVPair = Pair """Deprecated: Use Pair instead.""" + +__all__ = ["Client", "Pair", "KV", "KVPair"] diff --git a/client/py/synnax/ranger/kv/client.py b/client/py/synnax/ranger/kv/client.py index 4c09de4a7e..36013b7e6e 100644 --- a/client/py/synnax/ranger/kv/client.py +++ b/client/py/synnax/ranger/kv/client.py @@ -8,7 +8,7 @@ # included in the file licenses/APL.txt. import uuid -from typing import overload +from typing import Any, overload from freighter import UnaryClient, send_required from pydantic import BaseModel @@ -50,6 +50,9 @@ def __init__(self, rng: uuid.UUID, client: UnaryClient) -> None: @overload def get(self, keys: str) -> str: ... + @overload + def get(self, keys: list[str]) -> dict[str, str]: ... + def get(self, keys: str | list[str]) -> dict[str, str] | str: req = _GetRequest(range=self._rng_key, keys=normalize(keys)) res = send_required(self._client, "/range/kv/get", req, _GetResponse) @@ -58,12 +61,12 @@ def get(self, keys: str | list[str]) -> dict[str, str] | str: return {pair.key: pair.value for pair in res.pairs} @overload - def set(self, key: str, value: any): ... + def set(self, key: str, value: Any) -> None: ... @overload - def set(self, key: dict[str, any]): ... + def set(self, key: dict[str, Any]) -> None: ... - def set(self, key: str | dict[str, any], value: any = None) -> None: + def set(self, key: str | dict[str, Any], value: Any = None) -> None: pairs = list() if isinstance(key, str): pairs.append(Pair(range=self._rng_key, key=key, value=value)) @@ -80,7 +83,7 @@ def delete(self, keys: str | list[str]) -> None: def __getitem__(self, key: str) -> str: return self.get(key) - def __setitem__(self, key: str, value: str) -> None: + def __setitem__(self, key: str, value: Any) -> None: self.set(key, value) def __delitem__(self, key: str) -> None: diff --git a/client/py/synnax/ranger/kv/payload.py b/client/py/synnax/ranger/kv/payload.py index 16c2176efe..8035792d5e 100644 --- a/client/py/synnax/ranger/kv/payload.py +++ b/client/py/synnax/ranger/kv/payload.py @@ -20,10 +20,11 @@ class Pair(BaseModel): key: str value: str - def __init__(self, **kwargs): + def __init__(self, **kwargs: object) -> None: value = kwargs.get("value") if not isinstance(value, str): - if not is_primitive(value) and type(value).__str__ == object.__str__: + str_method = getattr(type(value), "__str__", None) + if not is_primitive(value) and str_method is object.__str__: raise ValidationError(f""" Synnax has no way of casting {value} to a string when setting metadata on a range. Please convert the value to a string before setting it. diff --git a/client/py/synnax/ranger/payload.py b/client/py/synnax/ranger/payload.py index ea69a974a3..2f8b4db65c 100644 --- a/client/py/synnax/ranger/payload.py +++ b/client/py/synnax/ranger/payload.py @@ -39,5 +39,10 @@ class Payload(BaseModel): RangeParams = Key | list[Key] | tuple[Key] | str | list[str] | tuple[str] """Parameters that can be used to query a range""" -# Backwards compatibility -RangePayload = Payload +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "RangePayload": "Payload", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/ranger/retrieve.py b/client/py/synnax/ranger/retrieve.py index cf1fe5a9c1..64845d4a91 100644 --- a/client/py/synnax/ranger/retrieve.py +++ b/client/py/synnax/ranger/retrieve.py @@ -61,6 +61,6 @@ def __execute(self, req: _Request) -> list[Payload]: res, exc = self.__client.send("/range/retrieve", req, _Response) if exc is not None: raise exc - if res.ranges is None: + if res is None or res.ranges is None: return list() return res.ranges diff --git a/client/py/synnax/ranger/writer.py b/client/py/synnax/ranger/writer.py index f81a1c7bd6..81f8606c89 100644 --- a/client/py/synnax/ranger/writer.py +++ b/client/py/synnax/ranger/writer.py @@ -47,6 +47,6 @@ def create( return send_required(self._client, "/range/create", req, _CreateResponse).ranges @trace("debug", "range.delete") - def delete(self, keys: list[Key]): + def delete(self, keys: list[Key]) -> None: req = _DeleteRequest(keys=keys) send_required(self._client, "/range/delete", req, Empty) diff --git a/client/py/synnax/signals/__init__.py b/client/py/synnax/signals/__init__.py index 8b5887da48..e0648dca05 100644 --- a/client/py/synnax/signals/__init__.py +++ b/client/py/synnax/signals/__init__.py @@ -8,3 +8,5 @@ # included in the file licenses/APL.txt. from synnax.signals.signals import Registry + +__all__ = ["Registry"] diff --git a/client/py/synnax/signals/signals.py b/client/py/synnax/signals/signals.py index bae8ad9ef1..e5e2117a8b 100644 --- a/client/py/synnax/signals/signals.py +++ b/client/py/synnax/signals/signals.py @@ -7,9 +7,10 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. +from __future__ import annotations + from collections.abc import Callable from functools import wraps -from multiprocessing import Pool from synnax import channel, framer from synnax.state import LatestState, State @@ -18,76 +19,80 @@ class Registry: - __handlers: list[_InternalHandler] - __channels: set[channel.Key | str] - __frame_client: framer.Client - __channel_retriever: channel.Retriever + _handlers: list[_InternalHandler] + _channels: list[channel.Key] + _frame_client: framer.Client + _channel_retriever: channel.Retriever def __init__(self, frame_client: framer.Client, channels: channel.Retriever): - self.__handlers = list() - self.__channels = set() - self.__frame_client = frame_client - self.__channel_retriever = channels + self._handlers = list() + self._channels = list() + self._frame_client = frame_client + self._channel_retriever = channels def on( self, channels: channel.Params, filter_f: Callable[[LatestState], bool], - ) -> Callable[[Callable[[LatestState], None]], Callable[[], None] | None]: - self.__channels.update(channels) - - def decorator(f: Callable[[LatestState], None]) -> None: + ) -> Callable[[Callable[[LatestState], None]], Callable[[LatestState], None]]: + normal = channel.normalize_params(channels) + if isinstance(normal, channel.NormalizedKeyResult): + self._channels.extend(normal.channels) + else: + resolved = self._channel_retriever.retrieve(channels) + self._channels.extend(ch.key for ch in resolved) + + def decorator( + f: Callable[[LatestState], None], + ) -> Callable[[LatestState], None]: @wraps(f) - def wrapper(state: State) -> Callable[[], None] | None: + def wrapper(state: State) -> Callable[[LatestState], None] | None: if filter_f(LatestState(state)): return f return None - self.__handlers.append(wrapper) - return wrapper + self._handlers.append(wrapper) + return f return decorator async def process(self) -> None: await Scheduler( - channels=self.__channels, - handlers=self.__handlers, - frame_client=self.__frame_client, - channel_retriever=self.__channel_retriever, + channels=self._channels, + handlers=self._handlers, + frame_client=self._frame_client, + channel_retriever=self._channel_retriever, ).start() class Scheduler: - __pool: Pool - __streamer: framer.AsyncStreamer | None = None - __handlers: list[_InternalHandler] - __channels: channel.Params - __state: State - __frame_client: framer.Client - __channel_retriever: channel.Retriever + _streamer: framer.AsyncStreamer | None = None + _handlers: list[_InternalHandler] + _channels: list[channel.Key] + _state: State + _frame_client: framer.Client + _channel_retriever: channel.Retriever def __init__( self, - channels: channel.Params, + channels: list[channel.Key], handlers: list[_InternalHandler], frame_client: framer.Client, channel_retriever: channel.Retriever, ): - self.__frame_client = frame_client - self.__channels = channels - self.__handlers = handlers - self.__channel_retriever = channel_retriever - self.__state = State(channel_retriever) - - async def start(self): - self.__streamer = await self.__frame_client.open_async_streamer( - list(self.__channels) - ) - async for frame in self.__streamer: - self.__state.update(frame) - for handler in self.__handlers: - res = handler(self.__state) + self._frame_client = frame_client + self._channels = channels + self._handlers = handlers + self._channel_retriever = channel_retriever + self._state = State(channel_retriever) + + async def start(self) -> None: + self._streamer = await self._frame_client.open_async_streamer(self._channels) + async for frame in self._streamer: + self._state.update(frame) + for handler in self._handlers: + res = handler(self._state) if res is not None: - res(LatestState(self.__state)) + res(LatestState(self._state)) - def stop(self): ... + def stop(self) -> None: ... diff --git a/client/py/synnax/state/state.py b/client/py/synnax/state/state.py index 20cc724736..05993ffb17 100644 --- a/client/py/synnax/state/state.py +++ b/client/py/synnax/state/state.py @@ -7,39 +7,42 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. +from typing import Any + from synnax import channel from synnax.framer import Frame -from synnax.telem import Series +from synnax.telem import MultiSeries class State: - value: dict[channel.Key, Series] - __retriever: channel.Retriever + value: dict[channel.Key, MultiSeries] + _retriever: channel.Retriever def __init__(self, retriever: channel.Retriever): - self.__retriever = retriever + self._retriever = retriever self.value = dict() - def update(self, frame: Frame): + def update(self, frame: Frame) -> None: for key in frame.channels: - self.value[key] = frame[key] + if isinstance(key, int): + self.value[key] = frame[key] - def __getitem__(self, ch: channel.Key): - ch = channel.retrieve_required(self.__retriever, ch)[0] - return self.value[ch.key] + def __getitem__(self, ch: channel.Key | str) -> MultiSeries: + payload = channel.retrieve_required(self._retriever, ch)[0] + return self.value[payload.key] - def __getattr__(self, ch: channel.Key): - return self.__getitem__(ch) + def __getattr__(self, name: str) -> Any: + return self.__getitem__(name) class LatestState: - __state: State + _state: State def __init__(self, state: State) -> None: - self.__state = state + self._state = state - def __getitem__(self, ch: channel.Key | str): - return self.__state.value[ch][-1] + def __getitem__(self, ch: channel.Key | str) -> Any: + return self._state[ch][-1] - def __getattr__(self, ch: channel.Key | str): - return self.__getitem__(ch) + def __getattr__(self, name: str) -> Any: + return self.__getitem__(name) diff --git a/client/py/synnax/status/__init__.py b/client/py/synnax/status/__init__.py index 188237ce9a..868384d1a2 100644 --- a/client/py/synnax/status/__init__.py +++ b/client/py/synnax/status/__init__.py @@ -19,14 +19,18 @@ Variant, ontology_id, ) +from synnax.util.deprecation import deprecated_getattr -# Backwards compatibility -SUCCESS_VARIANT = VARIANT_SUCCESS -INFO_VARIANT = VARIANT_INFO -WARNING_VARIANT = VARIANT_WARNING -ERROR_VARIANT = VARIANT_ERROR -DISABLED_VARIANT = VARIANT_DISABLED -LOADING_VARIANT = VARIANT_LOADING +_DEPRECATED = { + "SUCCESS_VARIANT": "VARIANT_SUCCESS", + "INFO_VARIANT": "VARIANT_INFO", + "WARNING_VARIANT": "VARIANT_WARNING", + "ERROR_VARIANT": "VARIANT_ERROR", + "DISABLED_VARIANT": "VARIANT_DISABLED", + "LOADING_VARIANT": "VARIANT_LOADING", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) __all__ = [ "Client", @@ -38,11 +42,5 @@ "VARIANT_ERROR", "VARIANT_DISABLED", "VARIANT_LOADING", - "SUCCESS_VARIANT", - "INFO_VARIANT", - "WARNING_VARIANT", - "ERROR_VARIANT", - "DISABLED_VARIANT", - "LOADING_VARIANT", "ontology_id", ] diff --git a/client/py/synnax/status/client.py b/client/py/synnax/status/client.py index 458a0681da..ae4bf9a60b 100644 --- a/client/py/synnax/status/client.py +++ b/client/py/synnax/status/client.py @@ -7,7 +7,7 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. -from typing import overload +from typing import Any, overload from uuid import UUID from freighter import Empty, UnaryClient, send_required @@ -21,11 +21,11 @@ class _SetRequest(BaseModel): parent: ID | None = None - statuses: list[Status] + statuses: list[Status[Any]] class _SetResponse(BaseModel): - statuses: list[Status] + statuses: list[Status[Any]] class _RetrieveRequest(BaseModel): @@ -38,7 +38,7 @@ class _RetrieveRequest(BaseModel): class _RetrieveResponse(BaseModel): - statuses: list[Status] | None = None + statuses: list[Status[Any]] | None = None class _DeleteRequest(BaseModel): @@ -71,25 +71,25 @@ def __init__(self, transport: UnaryClient) -> None: @overload def set( self, - status: Status, + status: Status[Any], *, parent: ID | None = None, - ) -> Status: ... + ) -> Status[Any]: ... @overload def set( self, - statuses: list[Status], + status: list[Status[Any]], *, parent: ID | None = None, - ) -> list[Status]: ... + ) -> list[Status[Any]]: ... def set( self, - status: Status | list[Status] | None = None, + status: Status[Any] | list[Status[Any]] | None = None, *, parent: ID | None = None, - ) -> Status | list[Status]: + ) -> Status[Any] | list[Status[Any]]: """Create or update a status. Args: @@ -137,12 +137,12 @@ def set( return res @overload - def retrieve(self, *, key: str, include_labels: bool = False) -> Status: ... + def retrieve(self, *, key: str, include_labels: bool = False) -> Status[Any]: ... @overload def retrieve( self, *, keys: list[str], include_labels: bool = False - ) -> list[Status]: ... + ) -> list[Status[Any]]: ... @overload def retrieve( @@ -153,7 +153,7 @@ def retrieve( limit: int | None = None, include_labels: bool = False, has_labels: list[UUID] | None = None, - ) -> list[Status]: ... + ) -> list[Status[Any]]: ... @require_named_params(example_params=("key", "'status-key-123'")) def retrieve( @@ -166,7 +166,7 @@ def retrieve( limit: int | None = None, include_labels: bool = False, has_labels: list[UUID] | None = None, - ) -> Status | list[Status]: + ) -> Status[Any] | list[Status[Any]]: """Retrieve statuses from the cluster. Args: @@ -205,7 +205,7 @@ def retrieve( ... ) """ single = key is not None - if single: + if single and key is not None: keys = [key] res = send_required( @@ -223,7 +223,9 @@ def retrieve( ).statuses if res is None: - return [] if not single else None + if single: + raise ValueError(f"Status with key '{key}' not found") + return [] if single: if len(res) == 0: diff --git a/client/py/synnax/status/payload.py b/client/py/synnax/status/payload.py index c633d0ad7e..5766477610 100644 --- a/client/py/synnax/status/payload.py +++ b/client/py/synnax/status/payload.py @@ -66,7 +66,7 @@ class Status(BaseModel, Generic[D]): """The time the status was created.""" labels: list[Any] | None = None """Optional labels attached to the status (only present in responses).""" - details: D = None + details: D | None = None """The details are customizable details for component specific statuses.""" @property @@ -79,10 +79,15 @@ def ontology_id(self) -> ID: return ontology_id(self.key) -# Backwards compatibility -SUCCESS_VARIANT = VARIANT_SUCCESS -INFO_VARIANT = VARIANT_INFO -WARNING_VARIANT = VARIANT_WARNING -ERROR_VARIANT = VARIANT_ERROR -DISABLED_VARIANT = VARIANT_DISABLED -LOADING_VARIANT = VARIANT_LOADING +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "SUCCESS_VARIANT": "VARIANT_SUCCESS", + "INFO_VARIANT": "VARIANT_INFO", + "WARNING_VARIANT": "VARIANT_WARNING", + "ERROR_VARIANT": "VARIANT_ERROR", + "DISABLED_VARIANT": "VARIANT_DISABLED", + "LOADING_VARIANT": "VARIANT_LOADING", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/synnax.py b/client/py/synnax/synnax.py index 3e2ce0dfcb..2782524ea1 100644 --- a/client/py/synnax/synnax.py +++ b/client/py/synnax/synnax.py @@ -120,14 +120,21 @@ def __init__( self._transport.use(self.auth.middleware()) self._transport.use_async(self.auth.async_middleware()) - ch_retriever = channel.ClusterRetriever(self._transport.unary, instrumentation) + cluster_retriever = channel.ClusterRetriever( + self._transport.unary, instrumentation + ) + cache_retriever: channel.CacheRetriever | None = None + ch_retriever: channel.Retriever if cache_channels: - ch_retriever = channel.CacheRetriever(ch_retriever, instrumentation) + cache_retriever = channel.CacheRetriever(cluster_retriever, instrumentation) + ch_retriever = cache_retriever + else: + ch_retriever = cluster_retriever deleter = framer.Deleter(self._transport.unary, instrumentation) ch_creator = channel.Writer( self._transport.unary, instrumentation, - ch_retriever if cache_channels else None, + cache_retriever, ) super().__init__( stream_client=self._transport.stream, @@ -183,7 +190,7 @@ def hardware(self) -> "Synnax": ) return self - def close(self): + def close(self) -> None: """Shuts down the client and closes all connections. All open iterators or writers must be closed before calling this method. """ diff --git a/client/py/synnax/task/__init__.py b/client/py/synnax/task/__init__.py index 91fc1250e8..f72474e4b5 100644 --- a/client/py/synnax/task/__init__.py +++ b/client/py/synnax/task/__init__.py @@ -18,15 +18,19 @@ Task, ) from synnax.task.payload import Payload, Status, StatusDetails +from synnax.util.deprecation import deprecated_getattr -# Backwards compatibility -TaskPayload = Payload -TaskStatus = Status -TaskStatusDetails = StatusDetails -BaseTaskConfig = BaseConfig -BaseReadTaskConfig = BaseReadConfig -BaseWriteTaskConfig = BaseWriteConfig -TaskProtocol = Protocol +_DEPRECATED = { + "TaskPayload": "Payload", + "TaskStatus": "Status", + "TaskStatusDetails": "StatusDetails", + "BaseTaskConfig": "BaseConfig", + "BaseReadTaskConfig": "BaseReadConfig", + "BaseWriteTaskConfig": "BaseWriteConfig", + "TaskProtocol": "Protocol", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) __all__ = [ "Client", @@ -40,11 +44,4 @@ "JSONConfigMixin", "StarterStopperMixin", "Protocol", - "TaskPayload", - "TaskStatus", - "TaskStatusDetails", - "BaseTaskConfig", - "BaseReadTaskConfig", - "BaseWriteTaskConfig", - "TaskProtocol", ] diff --git a/client/py/synnax/task/client.py b/client/py/synnax/task/client.py index 1f6eab54c5..b5af85ed1b 100644 --- a/client/py/synnax/task/client.py +++ b/client/py/synnax/task/client.py @@ -11,19 +11,22 @@ import json import warnings +from collections.abc import Generator from contextlib import contextmanager +from typing import TYPE_CHECKING, Annotated, Any from typing import Protocol as BaseProtocol from typing import overload from uuid import uuid4 from alamos import NOOP, Instrumentation from freighter import Empty, UnaryClient, send_required -from pydantic import BaseModel, Field, ValidationError, conint, field_validator +from pydantic import BaseModel, Field, ValidationError, field_validator from synnax.device import Client as DeviceClient from synnax.device import Device from synnax.exceptions import ConfigurationError, UnexpectedError from synnax.framer import Client as FrameClient +from synnax.ontology.payload import ID from synnax.rack import Client as RackClient from synnax.rack import Rack from synnax.status import VARIANT_ERROR, VARIANT_SUCCESS @@ -102,13 +105,13 @@ class BaseReadConfig(BaseConfig): data_saving: bool = True "Whether to persist acquired data to disk (True) or only stream it (False)." - sample_rate: conint(ge=0, le=50000) + sample_rate: Annotated[int, Field(ge=0, le=50000)] "The rate at which to sample data from the hardware device (Hz)." - stream_rate: conint(ge=0, le=50000) + stream_rate: Annotated[int, Field(ge=0, le=50000)] "The rate at which acquired data will be streamed to the Synnax cluster (Hz)." @field_validator("stream_rate") - def validate_stream_rate(cls, v, info): + def validate_stream_rate(cls, v: int, info: Any) -> int: """Validate that stream_rate is less than or equal to sample_rate.""" if "sample_rate" in info.data and v > info.data["sample_rate"]: raise ValueError( @@ -138,7 +141,7 @@ class Task: config: str = "" snapshot: bool = False status: Status | None = None - _frame_client: FrameClient | None = None + __frame_client: FrameClient | None = None def __init__( self, @@ -160,7 +163,15 @@ def __init__( self.config = config self.snapshot = snapshot self.status = status - self._frame_client = _frame_client + self.__frame_client = _frame_client + + @property + def _frame_client(self) -> FrameClient: + if self.__frame_client is None: + raise RuntimeError( + "Cannot execute commands on a task that has not been created or retrieved from the cluster." + ) + return self.__frame_client def to_payload(self) -> Payload: return Payload( @@ -170,21 +181,16 @@ def to_payload(self) -> Payload: config=self.config, ) - def set_internal(self, task: Task): + def set_internal(self, task: Task) -> None: self.key = task.key self.name = task.name self.type = task.type self.config = task.config self.snapshot = task.snapshot - self._frame_client = task._frame_client + self.__frame_client = task.__frame_client @property - def ontology_id(self) -> dict: - """Get the ontology ID for the task. - - Returns: - An ontology ID dictionary with type "task" and the task key. - """ + def ontology_id(self) -> ID: return ontology_id(self.key) def update_device_properties(self, device_client: DeviceClient) -> Device | None: @@ -199,7 +205,7 @@ def update_device_properties(self, device_client: DeviceClient) -> Device | None """ return None - def execute_command(self, type_: str, args: dict | None = None) -> str: + def execute_command(self, type_: str, args: dict[str, Any] | None = None) -> str: """Executes a command on the task and returns the unique key assigned to the command. @@ -219,7 +225,7 @@ def execute_command(self, type_: str, args: dict | None = None) -> str: def execute_command_sync( self, type_: str, - args: dict | None = None, + args: dict[str, Any] | None = None, timeout: float | TimeSpan = 5, ) -> Status: """Executes a command on the task and waits for the driver to acknowledge the @@ -243,7 +249,11 @@ def execute_command_sync( continue try: status = Status.model_validate(frame[_TASK_STATE_CHANNEL][0]) - if status.details.cmd is not None and status.details.cmd == key: + if ( + status.details is not None + and status.details.cmd is not None + and status.details.cmd == key + ): return status except ValidationError as e: raise UnexpectedError(f""" @@ -252,11 +262,12 @@ def execute_command_sync( class Protocol(BaseProtocol): - key: int + @property + def key(self) -> int: ... def to_payload(self) -> Payload: ... - def set_internal(self, task: Task): ... + def set_internal(self, task: Task) -> None: ... def update_device_properties(self, device_client: DeviceClient) -> Device | None: """Update device properties before task configuration. @@ -277,7 +288,7 @@ def update_device_properties(self, device_client: DeviceClient) -> Device | None class StarterStopperMixin: _internal: Task - def start(self, timeout: float | TimeSpan = 5): + def start(self, timeout: float | TimeSpan = 5) -> None: """Starts the task and blocks until the Synnax cluster has acknowledged the command or the specified timeout has elapsed. @@ -287,7 +298,7 @@ def start(self, timeout: float | TimeSpan = 5): """ self._internal.execute_command_sync("start", timeout=timeout) - def stop(self, timeout: float | TimeSpan = 5): + def stop(self, timeout: float | TimeSpan = 5) -> None: """Stops the task and blocks until the Synnax cluster has acknowledged the command or the specified timeout has elapsed. @@ -298,7 +309,7 @@ def stop(self, timeout: float | TimeSpan = 5): self._internal.execute_command_sync("stop", timeout=timeout) @contextmanager - def run(self, timeout: float | TimeSpan = 5): + def run(self, timeout: float | TimeSpan = 5) -> Generator[None, None, None]: """Context manager that starts the task before entering the block and stops the task after exiting the block. This is useful for ensuring that the task is properly stopped even if an exception occurs during execution. @@ -329,7 +340,7 @@ def to_payload(self) -> Payload: pld.config = json.dumps(self.config.model_dump()) return pld - def set_internal(self, task: Task): + def set_internal(self, task: Task) -> None: """Implements TaskProtocol protocol""" self._internal = task @@ -366,7 +377,7 @@ def create( type: str = "", config: str = "", rack: int = 0, - ): ... + ) -> Task: ... @overload def create(self, tasks: Task) -> Task: ... @@ -386,24 +397,24 @@ def create( ) -> Task | list[Task]: is_single = True if tasks is None: - tasks = [Payload(key=key, name=name, type=type, config=config)] + payloads = [Payload(key=key, name=name, type=type, config=config)] elif isinstance(tasks, Task): - tasks = [tasks.to_payload()] + payloads = [tasks.to_payload()] else: is_single = False - tasks = [t.to_payload() for t in tasks] - for pld in tasks: + payloads = [t.to_payload() for t in tasks] + for pld in payloads: self.maybe_assign_def_rack(pld, rack) - req = _CreateRequest(tasks=tasks) - tasks = self.__exec_create(req) - sugared = self.sugar(tasks) + req = _CreateRequest(tasks=payloads) + created = self.__exec_create(req) + sugared = self.sugar(created) return sugared[0] if is_single else sugared def __exec_create(self, req: _CreateRequest) -> list[Payload]: res = send_required(self._client, "/task/create", req, _CreateResponse) return res.tasks - def maybe_assign_def_rack(self, pld: Payload, rack: int = 0) -> Rack: + def maybe_assign_def_rack(self, pld: Payload, rack: int = 0) -> Payload: if self._default_rack is None: # Hardcoded as this value for now. Will be changed once we have multi-rack # systems @@ -438,7 +449,7 @@ def configure(self, task: Protocol, timeout: float = 5) -> Protocol: warnings.warn("task - unexpected missing state in frame") continue status = Status.model_validate(frame[_TASK_STATE_CHANNEL][0]) - if status.details.task != task.key: + if status.details is None or status.details.task != task.key: continue if status.variant == VARIANT_SUCCESS: break @@ -446,14 +457,13 @@ def configure(self, task: Protocol, timeout: float = 5) -> Protocol: raise ConfigurationError(status.message) return task - def delete(self, keys: int | list[int]): + def delete(self, keys: int | list[int]) -> None: req = _DeleteRequest(keys=normalize(keys)) send_required(self._client, "/task/delete", req, Empty) @overload def retrieve( self, - *, key: int | None = None, name: str | None = None, type: str | None = None, @@ -462,6 +472,9 @@ def retrieve( @overload def retrieve( self, + key: None = None, + name: None = None, + type: None = None, names: list[str] | None = None, keys: list[int] | None = None, types: list[str] | None = None, @@ -487,7 +500,7 @@ def retrieve( ), _RetrieveResponse, ) - sug = self.sugar(res.tasks) + sug = self.sugar(res.tasks or []) # Warn if multiple tasks found when retrieving by name if is_single and name is not None and len(sug) > 1: @@ -501,6 +514,9 @@ def retrieve( return sug[0] if is_single else sug + def sugar(self, tasks: list[Payload]) -> list[Task]: + return [Task(**t.model_dump(), _frame_client=self._frame_client) for t in tasks] + def list(self, rack: int | None = None) -> list[Task]: """Lists all tasks on a rack. If no rack is specified, lists all tasks on the default rack. Excludes internal system tasks (scanner tasks and rack state). @@ -508,9 +524,9 @@ def list(self, rack: int | None = None) -> list[Task]: :param rack: The rack key to list tasks from. If None, uses the default rack. :return: A list of all user-created tasks on the specified rack. """ - if rack is None and self._default_rack is None: - self._default_rack = self._racks.retrieve_embedded_rack() if rack is None: + if self._default_rack is None: + self._default_rack = self._racks.retrieve_embedded_rack() rack = self._default_rack.key res = send_required( @@ -519,7 +535,7 @@ def list(self, rack: int | None = None) -> list[Task]: _RetrieveRequest(rack=rack, internal=False), _RetrieveResponse, ) - return self.sugar(res.tasks) + return self.sugar(res.tasks or []) def copy( self, @@ -535,6 +551,3 @@ def copy( req = _CopyRequest(key=key, name=name, snapshot=False) res = send_required(self._client, _COPY_ENDPOINT, req, _CopyResponse) return self.sugar([res.task])[0] - - def sugar(self, tasks: list[Payload]): - return [Task(**t.model_dump(), _frame_client=self._frame_client) for t in tasks] diff --git a/client/py/synnax/task/payload.py b/client/py/synnax/task/payload.py index fa93b858c8..c91ea5f0cc 100644 --- a/client/py/synnax/task/payload.py +++ b/client/py/synnax/task/payload.py @@ -8,6 +8,8 @@ # included in the file licenses/APL.txt. +from typing import Any + from pydantic import BaseModel from synnax import ontology, status @@ -24,7 +26,7 @@ class StatusDetails(BaseModel): """The key of the task.""" running: bool = False """Whether the task is running.""" - data: dict | None = None + data: dict[str, Any] | None = None """Arbitrary data about the task.""" cmd: str | None = None @@ -56,8 +58,13 @@ def ontology_id(key: int) -> ontology.ID: return ontology.ID(type=ONTOLOGY_TYPE.type, key=str(key)) -# Backwards compatibility -TASK_ONTOLOGY_TYPE = ONTOLOGY_TYPE -TaskPayload = Payload -TaskStatus = Status -TaskStatusDetails = StatusDetails +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "TASK_ONTOLOGY_TYPE": "ONTOLOGY_TYPE", + "TaskPayload": "Payload", + "TaskStatus": "Status", + "TaskStatusDetails": "StatusDetails", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/telem/series.py b/client/py/synnax/telem/series.py index e90ab96eb5..a9d6c09449 100644 --- a/client/py/synnax/telem/series.py +++ b/client/py/synnax/telem/series.py @@ -11,7 +11,9 @@ import json import uuid +from collections.abc import Iterator from datetime import datetime +from typing import Any, TypeAlias import numpy as np import pandas as pd @@ -70,7 +72,7 @@ def __len__(self) -> int: def __init__( self, - data: CrudeSeries, + data: CrudeSeries | MultiSeries, data_type: CrudeDataType | None = None, time_range: TimeRange | None = None, alignment: CrudeAlignment = 0, @@ -98,9 +100,9 @@ def __init__( raise ValueError( "[Series] - MultiSeries with more than one series cannot be converted to a Series" ) - data_type = data_type or data + data_type = data_type or data.data_type elif isinstance(data, pd.Series): - data_type = data_type or DataType(data.dtype) + data_type = data_type or DataType(str(data.dtype)) data_ = data.to_numpy(dtype=data_type.np).tobytes() elif isinstance(data, np.ndarray): data_type = data_type or DataType(data.dtype) @@ -112,12 +114,19 @@ def __init__( b"\n".join([json.dumps(d).encode("utf-8") for d in data]) + b"\n" ) elif data_type == DataType.STRING: - data_ = b"\n".join([d.encode("utf-8") for d in data]) + b"\n" + data_ = b"\n".join([str(d).encode("utf-8") for d in data]) + b"\n" elif data_type == DataType.UUID: - data_ = b"".join(d.bytes for d in data) + uuids = [d for d in data if isinstance(d, uuid.UUID)] + data_ = b"".join(d.bytes for d in uuids) else: data_ = np.array(data, dtype=data_type.np).tobytes() data_type = data_type or DataType(data) + elif isinstance(data, uuid.UUID): + data_type = DataType.UUID + data_ = data.bytes + elif isinstance(data, dict): + data_type = data_type or DataType.JSON + data_ = json.dumps(data).encode("utf-8") + b"\n" elif isinstance(data, str): data_ = bytes(f"{data}\n", "utf-8") data_type = DataType.STRING @@ -175,13 +184,13 @@ def __getitem__(self, index: int) -> SampleValue: if self.data_type == DataType.JSON: d = self.__newline_getitem__(index) - return json.loads(d) + return json.loads(d) # type: ignore[no-any-return] if self.data_type == DataType.STRING: d = self.__newline_getitem__(index) return d.decode("utf-8") - return self.__array__()[index] + return self.__array__()[index] # type: ignore[no-any-return] def __newline_getitem__(self, index: int) -> bytes: if index == 0: @@ -201,7 +210,7 @@ def __newline_getitem__(self, index: int) -> bytes: end = len(self.data) return self.data[start:end] - def __iter__(self): + def __iter__(self) -> Iterator[SampleValue]: # type: ignore[override] if self.data_type == DataType.UUID: yield from [self[i] for i in range(len(self))] elif self.data_type == DataType.JSON: @@ -213,7 +222,7 @@ def __iter__(self): else: yield from self.__array__() - def __iter__newline(self): + def __iter__newline(self) -> Iterator[bytes]: curr = 0 while curr < len(self.data): end = self.data.find(b"\n", curr) @@ -253,27 +262,34 @@ def astype(self, data_type: DataType) -> Series: def to_datetime(self) -> list[datetime]: return [pd.Timestamp(t).to_pydatetime() for t in self.__array__()] - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Series): return self.data == other.data elif isinstance(other, np.ndarray): - return self.__array__() == other + return self.__array__() == other # type: ignore[no-any-return] else: return False -Series = overload_comparison_operators(Series, "__array__") +overload_comparison_operators(Series, "__array__") -SampleValue = np.number | uuid.UUID | dict | str | int | float | TimeStamp -TypedCrudeSeries = Series | pd.Series | np.ndarray -CrudeSeries = ( +SampleValue: TypeAlias = ( + np.number | uuid.UUID | dict[str, Any] | str | int | float | TimeStamp +) +TypedCrudeSeries: TypeAlias = Series | pd.Series | np.ndarray +CrudeSeries: TypeAlias = ( Series | bytes | pd.Series | np.ndarray | list[float] | list[str] - | list[dict] + | list[dict[str, Any]] + | list[uuid.UUID] + | np.number + | str + | uuid.UUID + | dict[str, Any] | float | int | TimeStamp @@ -286,7 +302,7 @@ def elapsed_seconds(d: np.ndarray) -> np.ndarray: :param d: A Series of timestamps. :returns: A Series of elapsed seconds. """ - return (d - d[0]) / TimeSpan.SECOND + return (d - d[0]) / TimeSpan.SECOND # type: ignore[no-any-return] class MultiSeries: @@ -379,11 +395,11 @@ def __getitem__(self, index: int) -> SampleValue: index -= len(s) raise IndexError(f"[MultiSeries] - Index {index} out of bounds for {len(self)}") - def __iter__(self): + def __iter__(self) -> Iterator[SampleValue]: for s in self.series: yield from s - def __str__(self): + def __str__(self) -> str: return str(list(self)) @property @@ -391,4 +407,4 @@ def size(self) -> Size: return Size(sum(s.size for s in self.series)) -MultiSeries = overload_comparison_operators(MultiSeries, "__array__") +overload_comparison_operators(MultiSeries, "__array__") diff --git a/client/py/synnax/telem/telem.py b/client/py/synnax/telem/telem.py index 747b842038..87d7be7b5f 100644 --- a/client/py/synnax/telem/telem.py +++ b/client/py/synnax/telem/telem.py @@ -12,7 +12,7 @@ import uuid from datetime import UTC, datetime, timedelta, tzinfo from math import trunc -from typing import Any, ClassVar, Literal, TypeAlias, cast, get_args +from typing import Any, ClassVar, Literal, TypeAlias import numpy as np import pandas as pd @@ -66,12 +66,12 @@ def __new__(cls, value: CrudeTimeStamp) -> TimeStamp: @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler - ): + ) -> core_schema.CoreSchema: """Implemented for pydantic validation. Should not be used externally.""" return core_schema.no_info_after_validator_function(cls, handler(int)) @classmethod - def validate(cls, value, *args, **kwargs): + def validate(cls, value: Any, *args: Any, **kwargs: Any) -> TimeStamp: """Implemented for pydantic validation. Should not be used externally.""" return cls(value) @@ -171,9 +171,11 @@ def __gt__(self, rhs: CrudeTimeStamp) -> bool: return self.after(rhs) def __eq__(self, rhs: object) -> bool: - if not isinstance(rhs, get_args(CrudeTimeStamp)): + if not isinstance( + rhs, (int, TimeSpan, datetime, timedelta, np.datetime64, np.int64) + ): return False - return super().__eq__(TimeStamp(cast(CrudeTimeStamp, rhs))) + return super().__eq__(TimeStamp(rhs)) def __str__(self) -> str: return self.datetime().isoformat() @@ -199,7 +201,7 @@ class TimeSpan(int): * np.timedelta64 - The duration of the timedelta64. """ - def __new__(cls, value: CrudeTimeSpan): + def __new__(cls, value: CrudeTimeSpan) -> TimeSpan: if isinstance(value, str): value = int(value) if isinstance(value, timedelta): @@ -213,12 +215,12 @@ def __new__(cls, value: CrudeTimeSpan): @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler - ): + ) -> core_schema.CoreSchema: """Implemented for pydantic validation. Should not be used externally.""" return core_schema.no_info_after_validator_function(cls, handler(int)) @classmethod - def validate(cls, value, *args, **kwargs): + def validate(cls, value: Any, *args: Any, **kwargs: Any) -> TimeSpan: """Implemented for pydantic validation. Should not be used externally.""" return cls(value) @@ -412,25 +414,27 @@ def __truediv__(self, rhs: CrudeTimeSpan) -> TimeSpan: def __mod__(self, rhs: CrudeTimeSpan) -> TimeSpan: return TimeSpan(super().__mod__(TimeSpan(rhs))) - def __rmul__(self, rhs: CrudeTimeSpan) -> TimeSpan: + # override widens int's parameter type to accept CrudeTimeSpan + def __rmul__(self, rhs: CrudeTimeSpan) -> TimeSpan: # type: ignore[misc] return self.__mul__(rhs) - def __gt__(self, rhs: CrudeTimeSpan) -> bool: + # overrides below widen int's parameter type to accept CrudeTimeSpan + def __gt__(self, rhs: CrudeTimeSpan) -> bool: # type: ignore[misc] return super().__gt__(TimeSpan(rhs)) - def __ge__(self, rhs: CrudeTimeSpan) -> bool: + def __ge__(self, rhs: CrudeTimeSpan) -> bool: # type: ignore[misc] return super().__ge__(TimeSpan(rhs)) - def __lt__(self, rhs: CrudeTimeSpan) -> bool: + def __lt__(self, rhs: CrudeTimeSpan) -> bool: # type: ignore[misc] return super().__lt__(TimeSpan(rhs)) - def __le__(self, rhs: CrudeTimeSpan) -> bool: + def __le__(self, rhs: CrudeTimeSpan) -> bool: # type: ignore[misc] return super().__le__(TimeSpan(rhs)) def __eq__(self, rhs: object) -> bool: - if not isinstance(rhs, get_args(CrudeTimeSpan)): + if not isinstance(rhs, (int, float, timedelta, np.timedelta64)): return NotImplemented - return super().__eq__(int(TimeSpan(cast(CrudeTimeSpan, rhs)))) + return super().__eq__(int(TimeSpan(rhs))) NANOSECOND: TimeSpan """A nanosecond.""" @@ -502,7 +506,9 @@ def __eq__(self, rhs: object) -> bool: TimeSpanUnits = Literal["ns", "us", "ms", "s", "m", "h", "d", "iso"] -def convert_time_units(data: np.ndarray, _from: TimeSpanUnits, to: TimeSpanUnits): +def convert_time_units( + data: np.ndarray, _from: TimeSpanUnits, to: TimeSpanUnits +) -> np.ndarray: """Converts the data from one time unit to another. :param data: the numpy array to convert. @@ -532,7 +538,7 @@ def convert_time_units(data: np.ndarray, _from: TimeSpanUnits, to: TimeSpanUnits class Rate(float): """Rate represents a data rate measured in Hz.""" - def __new__(cls, value: CrudeRate): + def __new__(cls, value: CrudeRate) -> Rate: if isinstance(value, float): return super().__new__(cls, value) if isinstance(value, TimeSpan): @@ -546,12 +552,12 @@ def __new__(cls, value: CrudeRate): @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler - ): + ) -> core_schema.CoreSchema: """Implemented for pydantic validation. Should not be used externally.""" return core_schema.no_info_after_validator_function(cls, handler(float)) @classmethod - def validate(cls, v, *args, **kwargs): + def validate(cls, v: Any, *args: Any, **kwargs: Any) -> Rate: """Implemented for pydantic validation. Should not be used externally.""" return cls(v) @@ -586,18 +592,18 @@ def size_span(self, size: Size, density: Density) -> TimeSpan: raise ContiguityError(f"Size {size} is not a multiple of density {density}") return self.span(int(size / density)) - def __str__(self): + def __str__(self) -> str: if self < 1: return f"{self.period} per cycle" return f"{round(self, 2)} Hz" - def __repr__(self): + def __repr__(self) -> str: return f"Rate({super().__repr__()} Hz)" def __mul__(self, rhs: CrudeRate) -> Rate: return Rate(super().__mul__(Rate(rhs))) - def __rmul__(self, other) -> Rate: + def __rmul__(self, other: CrudeRate) -> Rate: return self.__mul__(other) HZ: Rate @@ -639,24 +645,15 @@ def __init__( self, start: CrudeTimeStamp | TimeRange, end: CrudeTimeStamp | None = None, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): if isinstance(start, TimeRange): - start_ = cast(TimeRange, start) - start, end = start_.start, start_.end + end = start.end + start = start.start end = start if end is None else end super().__init__(start=TimeStamp(start), end=TimeStamp(end)) - @classmethod - def validate(cls, v, *args, **kwargs): - """Implemented for pydantic validation. Should not be used externally.""" - if isinstance(v, TimeRange): - return cls(v.start, v.end) - elif isinstance(v, dict): - return cls(**v) - return cls(start=v[0], end=v[1]) - @property def span(self) -> TimeSpan: """:returns: the TimeSpan between the start and end TimeStamps of the TimeRange.""" @@ -717,7 +714,7 @@ def swap(self) -> TimeRange: self.copy() return TimeRange(start=self.end, end=self.start) - def copy(self, *args, **kwargs) -> TimeRange: + def copy(self, *args: Any, **kwargs: Any) -> TimeRange: """:returns: A copy of the time range.""" return TimeRange(start=self.start, end=self.end) @@ -759,7 +756,7 @@ def __eq__(self, rhs: object) -> bool: class Density(int): """Density is the number of bytes contained in a single sample.""" - def __new__(cls, value: CrudeDensity): + def __new__(cls, value: CrudeDensity) -> Density: if isinstance(value, Density): return value if isinstance(value, int): @@ -769,12 +766,12 @@ def __new__(cls, value: CrudeDensity): @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler - ): + ) -> core_schema.CoreSchema: """Implemented for pydantic validation. Should not be used externally.""" return core_schema.no_info_after_validator_function(cls, handler(int)) @classmethod - def validate(cls, v, *args, **kwargs): + def validate(cls, v: Any, *args: Any, **kwargs: Any) -> Density: """Implemented for pydantic validation. Should not be used externally.""" return cls(v) @@ -786,7 +783,7 @@ def size_span(self, sample_count: int) -> Size: """:returns: The number of bytes occupied by the given number of samples.""" return Size(sample_count * self) - def __repr__(self): + def __repr__(self) -> str: return f"Density({super().__repr__()})" UNKNOWN: Density @@ -899,13 +896,13 @@ def __add__(self, rhs: CrudeSize) -> Size: return Size(int(self) + Size(rhs)) BYTE: Size - BYTE_UNITS: SizeUnits + BYTE_UNITS: str KB: Size - KB_UNITS: SizeUnits + KB_UNITS: str MB: Size - MB_UNITS: SizeUnits + MB_UNITS: str GB: Size - GB_UNITS: SizeUnits + GB_UNITS: str Size.BYTE = Size(1) @@ -923,7 +920,7 @@ def __add__(self, rhs: CrudeSize) -> Size: class DataType(str): """DataType represents a data type as a string.""" - def __new__(cls, value: CrudeDataType): + def __new__(cls, value: CrudeDataType) -> DataType: if isinstance(value, DataType): return value @@ -931,9 +928,9 @@ def __new__(cls, value: CrudeDataType): return super().__new__(cls, value) if isinstance(value, np.number): - value = DataType._FROM_NUMPY.get(np.dtype(value), None) - if value is not None: - return value + result = DataType._FROM_NUMPY.get(np.dtype(value), None) + if result is not None: + return result if isinstance(value, float): return DataType.FLOAT64 @@ -971,31 +968,33 @@ def __new__(cls, value: CrudeDataType): if isinstance(value, dict): return DataType.JSON - value = DataType._FROM_NUMPY.get(np.dtype(value), None) - if value is not None: - return value + result = DataType._FROM_NUMPY.get(np.dtype(value), None) + if result is not None: + return result raise TypeError(f"Cannot convert {type(value)} to DataType") @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler - ): + ) -> core_schema.CoreSchema: """Implemented for pydantic validation. Should not be used externally.""" return core_schema.no_info_after_validator_function(cls, handler(str)) @classmethod - def validate(cls, v, *args, **kwargs): + def validate(cls, v: Any, *args: Any, **kwargs: Any) -> DataType: """Implemented for pydantic validation. Should not be used externally.""" return cls(v) @classmethod - def __modify_schema__(cls, field_schema): + def __modify_schema__(cls, field_schema: dict[str, Any]) -> None: """Implemented for pydantic validation. Should not be used externally.""" field_schema.update(type="string") @classmethod - def __get_pydantic_json_schema__(cls, _schema_generator, _field_schema): + def __get_pydantic_json_schema__( + cls, _schema_generator: Any, _field_schema: Any + ) -> dict[str, str]: """Implemented for pydantic validation. Should not be used externally.""" return {"type": "string"} @@ -1005,9 +1004,9 @@ def np(self) -> np.dtype: :return: The numpy type """ npt = DataType._TO_NUMPY.get(self, None) - if npt is None: + if not isinstance(npt, np.dtype): raise TypeError(f"Cannot convert {self} to numpy type") - return cast(np.dtype, npt) + return npt @property def is_variable(self) -> bool: @@ -1025,7 +1024,7 @@ def density(self) -> Density: """ return DataType._DENSITIES.get(self, Density.UNKNOWN) - def __repr__(self): + def __repr__(self) -> str: return f"DataType({super().__repr__()})" UNKNOWN: DataType @@ -1088,7 +1087,9 @@ def __repr__(self): ) CrudeRate: TypeAlias = int | float | TimeSpan | Rate CrudeDensity: TypeAlias = Density | int -CrudeDataType: TypeAlias = DTypeLike | DataType | str | list | np.number +CrudeDataType: TypeAlias = ( + DTypeLike | DataType | str | list[str] | np.number | int | float +) CrudeSize: TypeAlias = int | float | Size DataType._TO_NUMPY = { @@ -1208,7 +1209,7 @@ def __new__( @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler - ): + ) -> core_schema.CoreSchema: """Implemented for pydantic validation. Should not be used externally.""" return core_schema.no_info_after_validator_function(cls._validate, handler(int)) diff --git a/client/py/synnax/timing/timing.py b/client/py/synnax/timing/timing.py index 4964ad3aff..47388afa40 100644 --- a/client/py/synnax/timing/timing.py +++ b/client/py/synnax/timing/timing.py @@ -7,18 +7,21 @@ # License, use of this software will be governed by the Apache License, Version 2.0, # included in the file licenses/APL.txt. +from __future__ import annotations + import math import time +from typing import Literal from synnax.telem import Rate, TimeSpan, TimeStamp RESOLUTION = (100 * TimeSpan.MICROSECOND).seconds -def _precise_sleep(dur: float | int): +def _precise_sleep(dur: float | int) -> None: estimate = RESOLUTION * 10 # Initial overestimate mean = RESOLUTION * 10 - m2 = 0 + m2: float = 0 count = 1 end_time = time.perf_counter() + dur nanoseconds = dur * 1e9 @@ -38,7 +41,7 @@ def _precise_sleep(dur: float | int): pass -def sleep(dur: Rate | TimeSpan | float | int, precise: bool = False): +def sleep(dur: Rate | TimeSpan | float | int, precise: bool = False) -> None: """Sleeps for the given duration, with the option to use a high-precision sleep that is more accurate than Python's default time.sleep implementation. @@ -67,18 +70,18 @@ class Timer: _start: TimeStamp - def __init__(self): + def __init__(self) -> None: self.reset() def elapsed(self) -> TimeSpan: """Returns the time elapsed since the timer was started.""" return TimeSpan(time.perf_counter_ns() - self._start) - def start(self): + def start(self) -> None: """Starts the timer.""" self.reset() - def reset(self): + def reset(self) -> None: """Resets the timer to zero.""" self._start = TimeStamp(time.perf_counter_ns()) @@ -119,17 +122,17 @@ def __init__(self, interval: Rate | TimeSpan | float | int, precise: bool = Fals self.precise = precise self._last = time.perf_counter_ns() - def __enter__(self): + def __enter__(self) -> Loop: self._last = time.perf_counter_ns() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: pass - def __iter__(self): + def __iter__(self) -> Loop: return self - def __next__(self): + def __next__(self) -> None: elapsed = self._timer.elapsed() if elapsed < self.interval: sleep_for = self.interval - elapsed - self._correction @@ -143,10 +146,10 @@ def __next__(self): self._correction = self.average - self.interval self._timer.reset() - def __call__(self): + def __call__(self) -> None: return self.__next__() - def wait(self) -> True: + def wait(self) -> Literal[True]: """Waits for the next iteration of the loop, automatically sleeping for the remainder of the interval if the calling block of code executes faster than the interval. diff --git a/client/py/synnax/transport.py b/client/py/synnax/transport.py index 170dd025da..e87297ade8 100644 --- a/client/py/synnax/transport.py +++ b/client/py/synnax/transport.py @@ -17,7 +17,6 @@ JSONCodec, Middleware, MsgPackCodec, - StreamClient, UnaryClient, WebsocketClient, async_instrumentation_middleware, @@ -30,7 +29,7 @@ class Transport: url: URL - stream: StreamClient + stream: WebsocketClient stream_async: AsyncStreamClient unary: UnaryClient secure: bool @@ -70,9 +69,9 @@ def __init__( self.use(instrumentation_middleware(instrumentation)) self.use_async(async_instrumentation_middleware(instrumentation)) - def use(self, *middleware: Middleware): + def use(self, *middleware: Middleware) -> None: self.unary.use(*middleware) self.stream.use(*middleware) - def use_async(self, *middleware: AsyncMiddleware): + def use_async(self, *middleware: AsyncMiddleware) -> None: self.stream_async.use(*middleware) diff --git a/client/py/synnax/user/__init__.py b/client/py/synnax/user/__init__.py index d12c7365d8..279f267ea6 100644 --- a/client/py/synnax/user/__init__.py +++ b/client/py/synnax/user/__init__.py @@ -9,3 +9,5 @@ from synnax.user.client import Client from synnax.user.payload import New, User + +__all__ = ["Client", "New", "User"] diff --git a/client/py/synnax/user/client.py b/client/py/synnax/user/client.py index 4d4d2a1b82..7338ed0d37 100644 --- a/client/py/synnax/user/client.py +++ b/client/py/synnax/user/client.py @@ -13,6 +13,7 @@ from freighter import Empty, UnaryClient, send_required from pydantic import BaseModel +from synnax.exceptions import NotFoundError from synnax.user.payload import New, User from synnax.util.normalize import normalize from synnax.util.params import require_named_params @@ -98,8 +99,10 @@ def create( key=key, ) single = user is not None - if single: + if user is not None: users = [user] + if users is None: + raise ValueError("Either username, user, or users must be provided") res = send_required( self.client, "/user/create", @@ -153,12 +156,19 @@ def retrieve( keys = normalize(key) if username is not None: usernames = normalize(username) - return send_required( + single = key is not None or username is not None + res = send_required( self.client, "/user/retrieve", _RetrieveRequest(keys=keys, usernames=usernames), _RetrieveResponse, - ).users + ) + users = res.users or [] + if not single: + return users + if len(users) == 0: + raise NotFoundError(f"User matching {key or username} not found") + return users[0] def delete(self, keys: UUID | list[UUID] | None = None) -> None: send_required( diff --git a/client/py/synnax/user/payload.py b/client/py/synnax/user/payload.py index 17e61e4db1..d27262039d 100644 --- a/client/py/synnax/user/payload.py +++ b/client/py/synnax/user/payload.py @@ -48,5 +48,10 @@ def ontology_id(key: UUID) -> ontology.ID: return ontology.ID(type=ONTOLOGY_TYPE.type, key=str(key)) -# Backwards compatibility -NewUser = New +from synnax.util.deprecation import deprecated_getattr + +_DEPRECATED = { + "NewUser": "New", +} + +__getattr__ = deprecated_getattr(__name__, _DEPRECATED, globals()) diff --git a/client/py/synnax/util/deprecation.py b/client/py/synnax/util/deprecation.py new file mode 100644 index 0000000000..4a72e995f0 --- /dev/null +++ b/client/py/synnax/util/deprecation.py @@ -0,0 +1,49 @@ +# Copyright 2026 Synnax Labs, Inc. +# +# Use of this software is governed by the Business Source License included in the file +# licenses/BSL.txt. +# +# As of the Change Date specified in that file, in accordance with the Business Source +# License, use of this software will be governed by the Apache License, Version 2.0, +# included in the file licenses/APL.txt. + +from __future__ import annotations + +import warnings +from collections.abc import Mapping +from typing import Any + + +def deprecated_getattr( + module_name: str, + deprecated: Mapping[str, str | tuple[str, str]], + module_globals: dict[str, Any], +) -> Any: + """Creates a module-level __getattr__ that warns on deprecated name access. + + :param module_name: The module's __name__ (used in AttributeError messages). + :param deprecated: Mapping of old_name to either new_name (str) or a tuple of + (display_name, globals_key) for cases where the display name in the warning + message differs from the key used to look up the value in module globals. + :param module_globals: The module's globals() dict to resolve new names from. + :returns: A __getattr__ function suitable for assignment at module level. + """ + + def __getattr__(name: str) -> Any: + if name in deprecated: + entry = deprecated[name] + if isinstance(entry, tuple): + display_name, globals_key = entry + else: + display_name = globals_key = entry + warnings.warn( + f"{name} is deprecated, use {display_name} instead", + DeprecationWarning, + stacklevel=2, + ) + val = module_globals[globals_key] + module_globals[name] = val + return val + raise AttributeError(f"module {module_name!r} has no attribute {name!r}") + + return __getattr__ diff --git a/client/py/synnax/util/params.py b/client/py/synnax/util/params.py index 8a4339f9a4..2655e4eca1 100644 --- a/client/py/synnax/util/params.py +++ b/client/py/synnax/util/params.py @@ -9,10 +9,10 @@ import functools from collections.abc import Callable -from typing import Any, TypeVar, cast +from typing import ParamSpec, TypeVar, overload -# Define a generic type variable for the function -F = TypeVar("F", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") class RequiresNamedParams(TypeError): @@ -24,33 +24,42 @@ class RequiresNamedParams(TypeError): pass +@overload +def require_named_params(func: Callable[P, R]) -> Callable[P, R]: ... + + +@overload def require_named_params( - func: F | None = None, *, example_params: tuple[str, str] | None = None -) -> Callable[[F], F]: + func: None = None, *, example_params: tuple[str, str] | None = None +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +def require_named_params( + func: Callable[P, R] | None = None, + *, + example_params: tuple[str, str] | None = None, +) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: """ Decorator that catches TypeError exceptions related to positional arguments and re-raises them with a more helpful error message. Args: func: The function to decorate - example_params: Optional tuple of (param_name, param_value) to show in the error message - Example: example_params=("user_id", "12345") + example_params: Optional tuple of (param_name, param_value) to show in the + error message. Example: example_params=("user_id", "12345") Returns: The decorated function with improved error messages for positional arguments """ - def decorator(func: F) -> F: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: + def decorator(f: Callable[P, R]) -> Callable[P, R]: + @functools.wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: try: - return func(*args, **kwargs) + return f(*args, **kwargs) except TypeError as e: - # Check if this is the "takes X positional arguments but Y were given" error if "positional argument" in str(e) and "were given" in str(e): - func_name = func.__qualname__ - - # Use custom example if provided, otherwise use generic example + func_name = f.__qualname__ if example_params: param_name, param_value = example_params param_example = f"{func_name}({param_name}='{param_value}')" @@ -58,18 +67,17 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: else: param_example = f"{func_name}(name='value')" value_example = "'value'" - message = ( - f"{str(e)}. '{func_name}' only accepts named parameters.\n" - f"Try using named parameters like: {param_example} instead of {func_name}({value_example})" + f"{str(e)}. '{func_name}' only accepts named" + f" parameters.\nTry using named parameters" + f" like: {param_example} instead of" + f" {func_name}({value_example})" ) raise RequiresNamedParams(message) from None - # Re-raise other TypeErrors unchanged raise - return cast(F, wrapper) + return wrapper - # Handle both @require_named_params and @require_named_params(example_params="...") if func is None: return decorator return decorator(func) diff --git a/client/py/tests/test_channel.py b/client/py/tests/test_channel.py index 9d9b5161ce..dbf4b8b850 100644 --- a/client/py/tests/test_channel.py +++ b/client/py/tests/test_channel.py @@ -107,7 +107,7 @@ def test_create_from_list(self, client: sy.Synnax): assert len(channels) == 2 for channel in channels: assert channel.name.startswith("test") - assert channel.key != "" + assert channel.key != 0 def test_create_from_single_instance(self, client: sy.Synnax): """Should create a single channel from a channel instance""" @@ -425,7 +425,7 @@ def test_create_list(self, hundred_channels: list[sy.Channel]): assert len(hundred_channels) == 100 for channel in hundred_channels: assert channel.name.startswith("sensor") - assert channel.key != "" + assert channel.key != 0 def test_retrieve_list(self, client: sy.Synnax, hundred_channels: list[sy.Channel]): """Should retrieve a list of 100 valid channels""" @@ -434,7 +434,7 @@ def test_retrieve_list(self, client: sy.Synnax, hundred_channels: list[sy.Channe assert len(res_channels) == 100 for channel in res_channels: assert channel.name.startswith("sensor") - assert channel.key != "" + assert channel.key != 0 assert isinstance(channel.data_type.density, sy.Density) def test_retrieve_zero_key_single(self, client: sy.Synnax): @@ -471,6 +471,7 @@ def test_create_channel_with_avg_operation_duration(self, client: sy.Synnax): assert created.name == channel.name assert created.virtual is True + assert created.operations is not None assert len(created.operations) == 1 assert created.operations[0].type == "avg" assert created.operations[0].duration == sy.TimeSpan.SECOND * 10 @@ -506,6 +507,7 @@ def test_create_channel_with_min_operation_reset_channel(self, client: sy.Synnax assert created.name == channel.name assert created.virtual is True + assert created.operations is not None assert len(created.operations) == 1 assert created.operations[0].type == "min" assert created.operations[0].reset_channel == reset_channel.key @@ -537,6 +539,7 @@ def test_create_channel_with_max_operation(self, client: sy.Synnax): created = client.channels.create(channel) assert created.name == channel.name + assert created.operations is not None assert len(created.operations) == 1 assert created.operations[0].type == "max" assert created.operations[0].duration == sy.TimeSpan.SECOND * 5 @@ -566,6 +569,7 @@ def test_retrieve_channel_with_operations(self, client: sy.Synnax): # Retrieve and verify operations are preserved retrieved = client.channels.retrieve(created.key) assert retrieved.name == channel.name + assert retrieved.operations is not None assert len(retrieved.operations) == 1 assert retrieved.operations[0].type == "avg" assert retrieved.operations[0].duration == sy.TimeSpan.SECOND * 15 @@ -605,6 +609,7 @@ def test_write_read(self, client: sy.Synnax): start = 1 * sy.TimeSpan.SECOND channel.write(start, d) data = channel.read(start, (start + len(d)) * sy.TimeSpan.SECOND) + assert data.time_range is not None assert data.time_range.start == start assert len(d) == len(data) assert data.time_range.end == start + (len(d) - 1) * sy.TimeSpan.SECOND + 1 diff --git a/client/py/tests/test_deprecation.py b/client/py/tests/test_deprecation.py new file mode 100644 index 0000000000..24f521d746 --- /dev/null +++ b/client/py/tests/test_deprecation.py @@ -0,0 +1,115 @@ +# Copyright 2026 Synnax Labs, Inc. +# +# Use of this software is governed by the Business Source License included in the file +# licenses/BSL.txt. +# +# As of the Change Date specified in that file, in accordance with the Business Source +# License, use of this software will be governed by the Apache License, Version 2.0, +# included in the file licenses/APL.txt. + +import types +import warnings + +import pytest + +from synnax.util.deprecation import deprecated_getattr + + +def _make_module( + deprecated: dict, + **globals_entries: object, +) -> types.ModuleType: + """Create a fake module with deprecated_getattr configured.""" + mod = types.ModuleType("test_module") + mod.__dict__.update(globals_entries) + mod.__dict__["__getattr__"] = deprecated_getattr( + "test_module", deprecated, mod.__dict__ + ) + return mod + + +@pytest.mark.deprecation +class TestDeprecatedGetattr: + def test_emits_deprecation_warning(self): + """Should emit a DeprecationWarning when accessing a deprecated name.""" + mod = _make_module({"OldName": "NewName"}, NewName="value") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", DeprecationWarning) + result = mod.OldName + assert result == "value" + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "OldName is deprecated, use NewName instead" in str(w[0].message) + + def test_returns_correct_value(self): + """Should return the same object as the new name.""" + sentinel = object() + mod = _make_module({"Old": "New"}, New=sentinel) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always", DeprecationWarning) + assert mod.Old is sentinel + + def test_caches_after_first_access(self): + """Should not call __getattr__ on subsequent accesses.""" + mod = _make_module({"Old": "New"}, New="value") + with warnings.catch_warnings(record=True): + warnings.simplefilter("always", DeprecationWarning) + _ = mod.Old + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", DeprecationWarning) + _ = mod.Old + assert len(w) == 0 + + def test_raises_attribute_error_for_unknown(self): + """Should raise AttributeError for names that are not deprecated.""" + mod = _make_module({}, NewName="value") + with pytest.raises(AttributeError, match="test_module"): + _ = mod.NonExistent + + def test_tuple_form_custom_display_name(self): + """Should use the display name from a tuple entry in the warning.""" + mod = _make_module( + {"OldName": ("package.module.NewName", "_internal")}, + _internal="value", + ) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", DeprecationWarning) + result = mod.OldName + assert result == "value" + assert len(w) == 1 + assert "use package.module.NewName instead" in str(w[0].message) + + def test_tuple_form_caches(self): + """Should cache after first access with tuple form.""" + mod = _make_module( + {"OldName": ("pkg.New", "_internal")}, + _internal="value", + ) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always", DeprecationWarning) + _ = mod.OldName + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", DeprecationWarning) + _ = mod.OldName + assert len(w) == 0 + + def test_multiple_deprecated_names(self): + """Should handle multiple deprecated names independently.""" + mod = _make_module( + {"OldA": "NewA", "OldB": "NewB"}, + NewA="a", + NewB="b", + ) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", DeprecationWarning) + assert mod.OldA == "a" + assert mod.OldB == "b" + assert len(w) == 2 + assert "OldA" in str(w[0].message) + assert "OldB" in str(w[1].message) + + def test_non_deprecated_access_unaffected(self): + """Should not interfere with normal attribute access.""" + mod = _make_module({"Old": "New"}, New="value", Other="other") + assert mod.Other == "other" + assert mod.New == "value" diff --git a/uv.lock b/uv.lock index fbe385daee..6110fe844e 100644 --- a/uv.lock +++ b/uv.lock @@ -1003,6 +1003,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/46/e1c6876d71c14332be70239acce9ad435975a80541086e5ffba2f249bcf6/pandas-3.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:940eebffe55528074341a5a36515f3e4c5e25e958ebbc764c9502cfc35ba3faa", size = 10473771, upload-time = "2026-01-21T15:51:25.285Z" }, ] +[[package]] +name = "pandas-stubs" +version = "3.0.0.260204" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/1d/297ff2c7ea50a768a2247621d6451abb2a07c0e9be7ca6d36ebe371658e5/pandas_stubs-3.0.0.260204.tar.gz", hash = "sha256:bf9294b76352effcffa9cb85edf0bed1339a7ec0c30b8e1ac3d66b4228f1fbc3", size = 109383, upload-time = "2026-02-04T15:17:17.247Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/2f/f91e4eee21585ff548e83358332d5632ee49f6b2dcd96cb5dca4e0468951/pandas_stubs-3.0.0.260204-py3-none-any.whl", hash = "sha256:5ab9e4d55a6e2752e9720828564af40d48c4f709e6a2c69b743014a6fcb6c241", size = 168540, upload-time = "2026-02-04T15:17:15.615Z" }, +] + [[package]] name = "pathspec" version = "1.0.4" @@ -1393,8 +1405,8 @@ name = "secretstorage" version = "3.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cryptography" }, - { name = "jeepney" }, + { name = "cryptography", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "jeepney", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1c/03/e834bcd866f2f8a49a85eaff47340affa3bfa391ee9912a952a1faa68c7b/secretstorage-3.5.0.tar.gz", hash = "sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be", size = 19884, upload-time = "2025-11-23T19:02:53.191Z" } wheels = [ @@ -1443,6 +1455,7 @@ dev = [ { name = "black" }, { name = "isort" }, { name = "mypy" }, + { name = "pandas-stubs" }, { name = "pymodbus" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -1471,6 +1484,7 @@ dev = [ { name = "black", specifier = ">=26.1.0,<27" }, { name = "isort", specifier = ">=7.0.0,<8" }, { name = "mypy", specifier = ">=1.19.1,<2" }, + { name = "pandas-stubs", specifier = ">=2.0.0" }, { name = "pymodbus", specifier = ">=3.11.4,<4" }, { name = "pytest", specifier = ">=9.0.2,<10" }, { name = "pytest-asyncio", specifier = ">=1.3.0,<2" },