11# -*- coding: utf-8 -*-
2+ from __future__ import annotations
3+
24import contextlib
35import logging
46import threading
57import random
68import itertools
7- from . import connection as conn_impl , issues , settings as settings_impl , _apis
9+ import typing
10+ from . import connection as conn_impl , driver , issues , settings as settings_impl , _apis
11+
12+
13+ # Workaround for good IDE and universal for runtime
14+ if typing .TYPE_CHECKING :
15+ from ._grpc .v4 .protos import ydb_discovery_pb2
16+ else :
17+ from ._grpc .common .protos import ydb_discovery_pb2
18+
819
920logger = logging .getLogger (__name__ )
1021
@@ -22,7 +33,7 @@ class EndpointInfo(object):
2233 "node_id" ,
2334 )
2435
25- def __init__ (self , endpoint_info ):
36+ def __init__ (self , endpoint_info : ydb_discovery_pb2 . EndpointInfo ):
2637 self .address = endpoint_info .address
2738 self .endpoint = "%s:%s" % (endpoint_info .address , endpoint_info .port )
2839 self .location = endpoint_info .location
@@ -33,7 +44,7 @@ def __init__(self, endpoint_info):
3344 self .ssl_target_name_override = endpoint_info .ssl_target_name_override
3445 self .node_id = endpoint_info .node_id
3546
36- def endpoints_with_options (self ):
47+ def endpoints_with_options (self ) -> typing . Generator [ typing . Tuple [ str , conn_impl . EndpointOptions ], None , None ] :
3748 ssl_target_name_override = None
3849 if self .ssl :
3950 if self .ssl_target_name_override :
@@ -73,14 +84,14 @@ def __eq__(self, other):
7384 return self .endpoint == other .endpoint
7485
7586
76- def _list_endpoints_request_factory (connection_params ) :
87+ def _list_endpoints_request_factory (connection_params : driver . DriverConfig ) -> _apis . ydb_discovery . ListEndpointsRequest :
7788 request = _apis .ydb_discovery .ListEndpointsRequest ()
7889 request .database = connection_params .database
7990 return request
8091
8192
8293class DiscoveryResult (object ):
83- def __init__ (self , self_location , endpoints ):
94+ def __init__ (self , self_location : str , endpoints : "list[EndpointInfo]" ):
8495 self .self_location = self_location
8596 self .endpoints = endpoints
8697
@@ -94,7 +105,12 @@ def __repr__(self):
94105 return self .__str__ ()
95106
96107 @classmethod
97- def from_response (cls , rpc_state , response , use_all_nodes = False ):
108+ def from_response (
109+ cls ,
110+ rpc_state : conn_impl ._RpcState ,
111+ response : ydb_discovery_pb2 .ListEndpointsResponse ,
112+ use_all_nodes : bool = False ,
113+ ) -> DiscoveryResult :
98114 issues ._process_response (response .operation )
99115 message = _apis .ydb_discovery .ListEndpointsResult ()
100116 response .operation .result .Unpack (message )
@@ -123,7 +139,7 @@ def from_response(cls, rpc_state, response, use_all_nodes=False):
123139
124140
125141class DiscoveryEndpointsResolver (object ):
126- def __init__ (self , driver_config ):
142+ def __init__ (self , driver_config : driver . DriverConfig ):
127143 self .logger = logger .getChild (self .__class__ .__name__ )
128144 self ._driver_config = driver_config
129145 self ._ready_timeout = getattr (self ._driver_config , "discovery_request_timeout" , 10 )
@@ -136,27 +152,27 @@ def __init__(self, driver_config):
136152 random .shuffle (self ._endpoints )
137153 self ._endpoints_iter = itertools .cycle (self ._endpoints )
138154
139- def _add_debug_details (self , message , * args ):
155+ def _add_debug_details (self , message : str , * args ):
140156 self .logger .debug (message , * args )
141157 message = message % args
142158 with self ._lock :
143159 self ._debug_details_items .append (message )
144160 if len (self ._debug_details_items ) > self ._debug_details_history_size :
145161 self ._debug_details_items .pop ()
146162
147- def debug_details (self ):
163+ def debug_details (self ) -> str :
148164 """
149165 Returns last resolver errors as a debug string.
150166 """
151167 with self ._lock :
152168 return "\n " .join (self ._debug_details_items )
153169
154- def resolve (self ):
170+ def resolve (self ) -> typing . ContextManager [ typing . Optional [ DiscoveryResult ]] :
155171 with self .context_resolve () as result :
156172 return result
157173
158174 @contextlib .contextmanager
159- def context_resolve (self ):
175+ def context_resolve (self ) -> typing . ContextManager [ typing . Optional [ DiscoveryResult ]] :
160176 self .logger .debug ("Preparing initial endpoint to resolve endpoints" )
161177 endpoint = next (self ._endpoints_iter )
162178 initial = conn_impl .Connection .ready_factory (endpoint , self ._driver_config , ready_timeout = self ._ready_timeout )
0 commit comments