33from collections .abc import AsyncIterator
44from typing import Any , AsyncGenerator , Iterator , Literal
55
6+ import grpc .aio
67import nanoid # type: ignore
78import pytest
9+ from opentelemetry import trace
10+ from opentelemetry .sdk .trace import TracerProvider
11+ from opentelemetry .sdk .trace .export import SimpleSpanProcessor
12+ from opentelemetry .sdk .trace .export .in_memory_span_exporter import InMemorySpanExporter
813from websockets .server import serve
914
1015from replit_river .client import Client
1116from replit_river .client_transport import UriAndMetadata
12- from replit_river .error_schema import RiverError
17+ from replit_river .error_schema import RiverError , RiverException
1318from replit_river .rpc import (
14- GrpcContext ,
1519 TransportMessage ,
1620 rpc_method_handler ,
1721 stream_method_handler ,
@@ -68,12 +72,12 @@ def deserialize_error(response: dict) -> RiverError:
6872
6973
7074# RPC method handlers for testing
71- async def rpc_handler (request : str , context : GrpcContext ) -> str :
75+ async def rpc_handler (request : str , context : grpc . aio . ServicerContext ) -> str :
7276 return f"Hello, { request } !"
7377
7478
7579async def subscription_handler (
76- request : str , context : GrpcContext
80+ request : str , context : grpc . aio . ServicerContext
7781) -> AsyncGenerator [str , None ]:
7882 for i in range (5 ):
7983 yield f"Subscription message { i } for { request } "
@@ -93,7 +97,8 @@ async def upload_handler(
9397
9498
9599async def stream_handler (
96- request : Iterator [str ] | AsyncIterator [str ], context : GrpcContext
100+ request : Iterator [str ] | AsyncIterator [str ],
101+ context : grpc .aio .ServicerContext ,
97102) -> AsyncGenerator [str , None ]:
98103 if isinstance (request , AsyncIterator ):
99104 async for data in request :
@@ -103,6 +108,14 @@ async def stream_handler(
103108 yield f"Stream response for { data } "
104109
105110
111+ async def stream_error_handler (
112+ request : Iterator [str ] | AsyncIterator [str ],
113+ context : grpc .aio .ServicerContext ,
114+ ) -> AsyncGenerator [str , None ]:
115+ raise RiverException ("INJECTED_ERROR" , "test error" )
116+ yield "test" # appease the type checker
117+
118+
106119@pytest .fixture
107120def transport_options () -> TransportOptions :
108121 return TransportOptions ()
@@ -137,6 +150,12 @@ def server(transport_options: TransportOptions) -> Server:
137150 stream_handler , deserialize_request , serialize_response
138151 ),
139152 ),
153+ ("test_service" , "stream_method_error" ): (
154+ "stream" ,
155+ stream_method_handler (
156+ stream_error_handler , deserialize_request , serialize_response
157+ ),
158+ ),
140159 }
141160 )
142161 return server
@@ -173,3 +192,18 @@ async def websocket_uri_factory() -> UriAndMetadata[None]:
173192 await server .close ()
174193 # Server should close normally
175194 no_logging_error ()
195+
196+
197+ @pytest .fixture (scope = "session" )
198+ def span_exporter () -> InMemorySpanExporter :
199+ exporter = InMemorySpanExporter ()
200+ processor = SimpleSpanProcessor (exporter )
201+ provider = TracerProvider ()
202+ provider .add_span_processor (processor )
203+ trace .set_tracer_provider (provider )
204+ return exporter
205+
206+
207+ @pytest .fixture (autouse = True )
208+ def reset_span_exporter (span_exporter : InMemorySpanExporter ) -> None :
209+ span_exporter .clear ()
0 commit comments