@@ -58,13 +58,9 @@ async def ServerStreamingMethod(self, request, context):
5858
5959
6060async def run_with_test_server (
61- runnable , servicer = Servicer (), add_interceptor = True
61+ runnable , servicer = Servicer (), interceptors = None
6262):
63- if add_interceptor :
64- interceptors = [aio_server_interceptor ()]
65- server = grpc .aio .server (interceptors = interceptors )
66- else :
67- server = grpc .aio .server ()
63+ server = grpc .aio .server (interceptors = interceptors )
6864
6965 add_GRPCTestServerServicer_to_server (servicer , server )
7066
@@ -95,7 +91,7 @@ async def request(channel):
9591 msg = request .SerializeToString ()
9692 return await channel .unary_unary (rpc_call )(msg )
9793
98- await run_with_test_server (request , add_interceptor = False )
94+ await run_with_test_server (request )
9995
10096 spans_list = self .memory_exporter .get_finished_spans ()
10197 self .assertEqual (len (spans_list ), 1 )
@@ -140,7 +136,7 @@ async def request(channel):
140136 msg = request .SerializeToString ()
141137 return await channel .unary_unary (rpc_call )(msg )
142138
143- await run_with_test_server (request , add_interceptor = False )
139+ await run_with_test_server (request )
144140
145141 spans_list = self .memory_exporter .get_finished_spans ()
146142 self .assertEqual (len (spans_list ), 0 )
@@ -154,7 +150,9 @@ async def request(channel):
154150 msg = request .SerializeToString ()
155151 return await channel .unary_unary (rpc_call )(msg )
156152
157- await run_with_test_server (request )
153+ await run_with_test_server (
154+ request , interceptors = [aio_server_interceptor ()]
155+ )
158156
159157 spans_list = self .memory_exporter .get_finished_spans ()
160158 self .assertEqual (len (spans_list ), 1 )
@@ -206,7 +204,11 @@ async def request(channel):
206204 msg = request .SerializeToString ()
207205 return await channel .unary_unary (rpc_call )(msg )
208206
209- await run_with_test_server (request , servicer = TwoSpanServicer ())
207+ await run_with_test_server (
208+ request ,
209+ servicer = TwoSpanServicer (),
210+ interceptors = [aio_server_interceptor ()],
211+ )
210212
211213 spans_list = self .memory_exporter .get_finished_spans ()
212214 self .assertEqual (len (spans_list ), 2 )
@@ -253,7 +255,9 @@ async def request(channel):
253255 async for response in channel .unary_stream (rpc_call )(msg ):
254256 print (response )
255257
256- await run_with_test_server (request )
258+ await run_with_test_server (
259+ request , interceptors = [aio_server_interceptor ()]
260+ )
257261
258262 spans_list = self .memory_exporter .get_finished_spans ()
259263 self .assertEqual (len (spans_list ), 1 )
@@ -307,7 +311,11 @@ async def request(channel):
307311 async for response in channel .unary_stream (rpc_call )(msg ):
308312 print (response )
309313
310- await run_with_test_server (request , servicer = TwoSpanServicer ())
314+ await run_with_test_server (
315+ request ,
316+ servicer = TwoSpanServicer (),
317+ interceptors = [aio_server_interceptor ()],
318+ )
311319
312320 spans_list = self .memory_exporter .get_finished_spans ()
313321 self .assertEqual (len (spans_list ), 2 )
@@ -367,7 +375,11 @@ async def request(channel):
367375 lifetime_servicer = SpanLifetimeServicer ()
368376 active_span_before_call = trace .get_current_span ()
369377
370- await run_with_test_server (request , servicer = lifetime_servicer )
378+ await run_with_test_server (
379+ request ,
380+ servicer = lifetime_servicer ,
381+ interceptors = [aio_server_interceptor ()],
382+ )
371383
372384 active_span_in_handler = lifetime_servicer .span
373385 active_span_after_call = trace .get_current_span ()
@@ -390,7 +402,9 @@ async def sequential_requests(channel):
390402 await request (channel )
391403 await request (channel )
392404
393- await run_with_test_server (sequential_requests )
405+ await run_with_test_server (
406+ sequential_requests , interceptors = [aio_server_interceptor ()]
407+ )
394408
395409 spans_list = self .memory_exporter .get_finished_spans ()
396410 self .assertEqual (len (spans_list ), 2 )
@@ -450,7 +464,9 @@ async def concurrent_requests(channel):
450464 await asyncio .gather (request (channel ), request (channel ))
451465
452466 await run_with_test_server (
453- concurrent_requests , servicer = LatchedServicer ()
467+ concurrent_requests ,
468+ servicer = LatchedServicer (),
469+ interceptors = [aio_server_interceptor ()],
454470 )
455471
456472 spans_list = self .memory_exporter .get_finished_spans ()
@@ -504,7 +520,11 @@ async def request(channel):
504520 self .assertEqual (cm .exception .code (), grpc .StatusCode .INTERNAL )
505521 self .assertEqual (cm .exception .details (), failure_message )
506522
507- await run_with_test_server (request , servicer = AbortServicer ())
523+ await run_with_test_server (
524+ request ,
525+ servicer = AbortServicer (),
526+ interceptors = [aio_server_interceptor ()],
527+ )
508528
509529 spans_list = self .memory_exporter .get_finished_spans ()
510530 self .assertEqual (len (spans_list ), 1 )
@@ -569,7 +589,11 @@ async def request(channel):
569589 )
570590 self .assertEqual (cm .exception .details (), failure_message )
571591
572- await run_with_test_server (request , servicer = AbortServicer ())
592+ await run_with_test_server (
593+ request ,
594+ servicer = AbortServicer (),
595+ interceptors = [aio_server_interceptor ()],
596+ )
573597
574598 spans_list = self .memory_exporter .get_finished_spans ()
575599 self .assertEqual (len (spans_list ), 1 )
@@ -602,6 +626,60 @@ async def request(channel):
602626 },
603627 )
604628
629+ async def test_non_list_interceptors (self ):
630+ """Check that we handle non-list interceptors correctly."""
631+
632+ grpc_server_instrumentor = GrpcAioInstrumentorServer ()
633+ grpc_server_instrumentor .instrument ()
634+
635+ try :
636+ rpc_call = "/GRPCTestServer/SimpleMethod"
637+
638+ async def request (channel ):
639+ request = Request (client_id = 1 , request_data = "test" )
640+ msg = request .SerializeToString ()
641+ return await channel .unary_unary (rpc_call )(msg )
642+
643+ class MockInterceptor (grpc .aio .ServerInterceptor ):
644+ async def intercept_service (
645+ self , continuation , handler_call_details
646+ ):
647+ return await continuation (handler_call_details )
648+
649+ await run_with_test_server (
650+ request , interceptors = (MockInterceptor (),)
651+ )
652+
653+ finally :
654+ grpc_server_instrumentor .uninstrument ()
655+
656+ spans_list = self .memory_exporter .get_finished_spans ()
657+ self .assertEqual (len (spans_list ), 1 )
658+ span = spans_list [0 ]
659+
660+ self .assertEqual (span .name , rpc_call )
661+ self .assertIs (span .kind , trace .SpanKind .SERVER )
662+
663+ # Check version and name in span's instrumentation info
664+ self .assertEqualSpanInstrumentationScope (
665+ span , opentelemetry .instrumentation .grpc
666+ )
667+
668+ # Check attributes
669+ self .assertSpanHasAttributes (
670+ span ,
671+ {
672+ SpanAttributes .NET_PEER_IP : "[::1]" ,
673+ SpanAttributes .NET_PEER_NAME : "localhost" ,
674+ SpanAttributes .RPC_METHOD : "SimpleMethod" ,
675+ SpanAttributes .RPC_SERVICE : "GRPCTestServer" ,
676+ SpanAttributes .RPC_SYSTEM : "grpc" ,
677+ SpanAttributes .RPC_GRPC_STATUS_CODE : grpc .StatusCode .OK .value [
678+ 0
679+ ],
680+ },
681+ )
682+
605683
606684def get_latch (num ):
607685 """Get a countdown latch function for use in n threads."""
0 commit comments