@@ -163,10 +163,11 @@ def _get_request_timeout(settings):
163163
164164
165165class EndpointOptions (object ):
166- __slots__ = ("ssl_target_name_override" ,)
166+ __slots__ = ("ssl_target_name_override" , "node_id" )
167167
168- def __init__ (self , ssl_target_name_override = None ):
168+ def __init__ (self , ssl_target_name_override = None , node_id = None ):
169169 self .ssl_target_name_override = ssl_target_name_override
170+ self .node_id = node_id
170171
171172
172173def _construct_channel_options (driver_config , endpoint_options = None ):
@@ -223,16 +224,18 @@ class _RpcState(object):
223224 "endpoint" ,
224225 "rendezvous" ,
225226 "metadata_kv" ,
227+ "endpoint_key" ,
226228 )
227229
228- def __init__ (self , stub_instance , rpc_name , endpoint ):
230+ def __init__ (self , stub_instance , rpc_name , endpoint , endpoint_key ):
229231 """Stores all RPC related data"""
230232 self .rpc_name = rpc_name
231233 self .rpc = getattr (stub_instance , rpc_name )
232234 self .request_id = uuid .uuid4 ()
233235 self .endpoint = endpoint
234236 self .rendezvous = None
235237 self .metadata_kv = None
238+ self .endpoint_key = endpoint_key
236239
237240 def __str__ (self ):
238241 return "RpcState(%s, %s, %s)" % (self .rpc_name , self .request_id , self .endpoint )
@@ -318,6 +321,14 @@ def channel_factory(
318321 )
319322
320323
324+ class EndpointKey (object ):
325+ __slots__ = ("endpoint" , "node_id" )
326+
327+ def __init__ (self , endpoint , node_id ):
328+ self .endpoint = endpoint
329+ self .node_id = node_id
330+
331+
321332class Connection (object ):
322333 __slots__ = (
323334 "endpoint" ,
@@ -330,6 +341,8 @@ class Connection(object):
330341 "lock" ,
331342 "calls" ,
332343 "closing" ,
344+ "endpoint_key" ,
345+ "node_id" ,
333346 )
334347
335348 def __init__ (self , endpoint , driver_config = None , endpoint_options = None ):
@@ -341,6 +354,10 @@ def __init__(self, endpoint, driver_config=None, endpoint_options=None):
341354 """
342355 global _stubs_list
343356 self .endpoint = endpoint
357+ self .node_id = getattr (endpoint_options , "node_id" , None )
358+ self .endpoint_key = EndpointKey (
359+ endpoint , getattr (endpoint_options , "node_id" , None )
360+ )
344361 self ._channel = channel_factory (
345362 self .endpoint , driver_config , endpoint_options = endpoint_options
346363 )
@@ -368,7 +385,9 @@ def _prepare_call(self, stub, rpc_name, request, settings):
368385 )
369386 _set_server_timeouts (request , settings , timeout )
370387 self ._prepare_stub_instance (stub )
371- rpc_state = _RpcState (self ._stub_instances [stub ], rpc_name , self .endpoint )
388+ rpc_state = _RpcState (
389+ self ._stub_instances [stub ], rpc_name , self .endpoint , self .endpoint_key
390+ )
372391 logger .debug ("%s: creating call state" , rpc_state )
373392 with self .lock :
374393 if self .closing :
0 commit comments