Skip to content

Commit 0448ada

Browse files
authored
Merge pull request #363 from RaileyHartheim/main Add type hints for resolver.py
2 parents b5ba69f + bdbbf32 commit 0448ada

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

ydb/resolver.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
# -*- coding: utf-8 -*-
2+
from __future__ import annotations
3+
24
import contextlib
35
import logging
46
import threading
57
import random
68
import itertools
7-
from . import connection as conn_impl, issues, settings as settings_impl, _apis
9+
import typing
10+
from . import connection as conn_impl, driver, issues, settings as settings_impl, _apis
11+
12+
13+
# Workaround for good IDE and universal for runtime
14+
if typing.TYPE_CHECKING:
15+
from ._grpc.v4.protos import ydb_discovery_pb2
16+
else:
17+
from ._grpc.common.protos import ydb_discovery_pb2
18+
819

920
logger = logging.getLogger(__name__)
1021

@@ -22,7 +33,7 @@ class EndpointInfo(object):
2233
"node_id",
2334
)
2435

25-
def __init__(self, endpoint_info):
36+
def __init__(self, endpoint_info: ydb_discovery_pb2.EndpointInfo):
2637
self.address = endpoint_info.address
2738
self.endpoint = "%s:%s" % (endpoint_info.address, endpoint_info.port)
2839
self.location = endpoint_info.location
@@ -33,7 +44,7 @@ def __init__(self, endpoint_info):
3344
self.ssl_target_name_override = endpoint_info.ssl_target_name_override
3445
self.node_id = endpoint_info.node_id
3546

36-
def endpoints_with_options(self):
47+
def endpoints_with_options(self) -> typing.Generator[typing.Tuple[str, conn_impl.EndpointOptions], None, None]:
3748
ssl_target_name_override = None
3849
if self.ssl:
3950
if self.ssl_target_name_override:
@@ -73,14 +84,14 @@ def __eq__(self, other):
7384
return self.endpoint == other.endpoint
7485

7586

76-
def _list_endpoints_request_factory(connection_params):
87+
def _list_endpoints_request_factory(connection_params: driver.DriverConfig) -> _apis.ydb_discovery.ListEndpointsRequest:
7788
request = _apis.ydb_discovery.ListEndpointsRequest()
7889
request.database = connection_params.database
7990
return request
8091

8192

8293
class DiscoveryResult(object):
83-
def __init__(self, self_location, endpoints):
94+
def __init__(self, self_location: str, endpoints: "list[EndpointInfo]"):
8495
self.self_location = self_location
8596
self.endpoints = endpoints
8697

@@ -94,7 +105,12 @@ def __repr__(self):
94105
return self.__str__()
95106

96107
@classmethod
97-
def from_response(cls, rpc_state, response, use_all_nodes=False):
108+
def from_response(
109+
cls,
110+
rpc_state: conn_impl._RpcState,
111+
response: ydb_discovery_pb2.ListEndpointsResponse,
112+
use_all_nodes: bool = False,
113+
) -> DiscoveryResult:
98114
issues._process_response(response.operation)
99115
message = _apis.ydb_discovery.ListEndpointsResult()
100116
response.operation.result.Unpack(message)
@@ -123,7 +139,7 @@ def from_response(cls, rpc_state, response, use_all_nodes=False):
123139

124140

125141
class DiscoveryEndpointsResolver(object):
126-
def __init__(self, driver_config):
142+
def __init__(self, driver_config: driver.DriverConfig):
127143
self.logger = logger.getChild(self.__class__.__name__)
128144
self._driver_config = driver_config
129145
self._ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10)
@@ -136,27 +152,27 @@ def __init__(self, driver_config):
136152
random.shuffle(self._endpoints)
137153
self._endpoints_iter = itertools.cycle(self._endpoints)
138154

139-
def _add_debug_details(self, message, *args):
155+
def _add_debug_details(self, message: str, *args):
140156
self.logger.debug(message, *args)
141157
message = message % args
142158
with self._lock:
143159
self._debug_details_items.append(message)
144160
if len(self._debug_details_items) > self._debug_details_history_size:
145161
self._debug_details_items.pop()
146162

147-
def debug_details(self):
163+
def debug_details(self) -> str:
148164
"""
149165
Returns last resolver errors as a debug string.
150166
"""
151167
with self._lock:
152168
return "\n".join(self._debug_details_items)
153169

154-
def resolve(self):
170+
def resolve(self) -> typing.ContextManager[typing.Optional[DiscoveryResult]]:
155171
with self.context_resolve() as result:
156172
return result
157173

158174
@contextlib.contextmanager
159-
def context_resolve(self):
175+
def context_resolve(self) -> typing.ContextManager[typing.Optional[DiscoveryResult]]:
160176
self.logger.debug("Preparing initial endpoint to resolve endpoints")
161177
endpoint = next(self._endpoints_iter)
162178
initial = conn_impl.Connection.ready_factory(endpoint, self._driver_config, ready_timeout=self._ready_timeout)

0 commit comments

Comments
 (0)