1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from unittest import TestCase
1514
16- from aiokafka import AIOKafkaConsumer , AIOKafkaProducer
15+ import uuid
16+ from typing import Any , List , Sequence , Tuple
17+ from unittest import IsolatedAsyncioTestCase , TestCase , mock
18+
19+ from aiokafka import (
20+ AIOKafkaConsumer ,
21+ AIOKafkaProducer ,
22+ ConsumerRecord ,
23+ TopicPartition ,
24+ )
1725from wrapt import BoundFunctionWrapper
1826
27+ from opentelemetry import baggage , context
1928from opentelemetry .instrumentation .aiokafka import AIOKafkaInstrumentor
29+ from opentelemetry .sdk .trace import ReadableSpan
30+ from opentelemetry .semconv ._incubating .attributes import messaging_attributes
31+ from opentelemetry .semconv .attributes import server_attributes
32+ from opentelemetry .test .test_base import TestBase
33+ from opentelemetry .trace import SpanKind , format_trace_id , set_span_in_context
2034
2135
22- class TestAIOKafka (TestCase ):
36+ class TestAIOKafkaInstrumentor (TestCase ):
2337 def test_instrument_api (self ) -> None :
2438 instrumentation = AIOKafkaInstrumentor ()
2539
@@ -28,13 +42,279 @@ def test_instrument_api(self) -> None:
2842 isinstance (AIOKafkaProducer .send , BoundFunctionWrapper )
2943 )
3044 self .assertTrue (
31- isinstance (AIOKafkaConsumer .__anext__ , BoundFunctionWrapper )
45+ isinstance (AIOKafkaConsumer .getone , BoundFunctionWrapper )
3246 )
3347
3448 instrumentation .uninstrument ()
3549 self .assertFalse (
3650 isinstance (AIOKafkaProducer .send , BoundFunctionWrapper )
3751 )
3852 self .assertFalse (
39- isinstance (AIOKafkaConsumer .__anext__ , BoundFunctionWrapper )
53+ isinstance (AIOKafkaConsumer .getone , BoundFunctionWrapper )
54+ )
55+
56+
57+ class TestAIOKafkaInstrumentation (TestBase , IsolatedAsyncioTestCase ):
58+ @staticmethod
59+ def consumer_record_factory (
60+ number : int , headers : Tuple [Tuple [str , bytes ], ...]
61+ ) -> ConsumerRecord :
62+ return ConsumerRecord (
63+ f"topic_{ number } " ,
64+ number ,
65+ number ,
66+ number ,
67+ number ,
68+ f"key_{ number } " .encode (),
69+ f"value_{ number } " .encode (),
70+ None ,
71+ number ,
72+ number ,
73+ headers = headers ,
74+ )
75+
76+ @staticmethod
77+ async def consumer_factory (** consumer_kwargs : Any ) -> AIOKafkaConsumer :
78+ consumer = AIOKafkaConsumer (** consumer_kwargs )
79+
80+ consumer ._client .bootstrap = mock .AsyncMock ()
81+ consumer ._client ._wait_on_metadata = mock .AsyncMock ()
82+
83+ await consumer .start ()
84+
85+ consumer ._fetcher .next_record = mock .AsyncMock ()
86+
87+ return consumer
88+
89+ @staticmethod
90+ async def producer_factory () -> AIOKafkaProducer :
91+ producer = AIOKafkaProducer (api_version = "1.0" )
92+
93+ producer .client ._wait_on_metadata = mock .AsyncMock ()
94+ producer .client .bootstrap = mock .AsyncMock ()
95+ producer ._message_accumulator .add_message = mock .AsyncMock ()
96+ producer ._sender .start = mock .AsyncMock ()
97+ producer ._partition = mock .Mock (return_value = 1 )
98+
99+ await producer .start ()
100+
101+ return producer
102+
103+ async def test_getone (self ) -> None :
104+ AIOKafkaInstrumentor ().uninstrument ()
105+ AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
106+
107+ client_id = str (uuid .uuid4 ())
108+ group_id = str (uuid .uuid4 ())
109+ consumer = await self .consumer_factory (
110+ client_id = client_id , group_id = group_id
111+ )
112+ next_record_mock : mock .AsyncMock = consumer ._fetcher .next_record
113+
114+ expected_spans = [
115+ {
116+ "name" : "topic_1 receive" ,
117+ "kind" : SpanKind .CONSUMER ,
118+ "attributes" : {
119+ messaging_attributes .MESSAGING_SYSTEM : messaging_attributes .MessagingSystemValues .KAFKA .value ,
120+ server_attributes .SERVER_ADDRESS : '"localhost"' ,
121+ messaging_attributes .MESSAGING_CLIENT_ID : client_id ,
122+ messaging_attributes .MESSAGING_DESTINATION_NAME : "topic_1" ,
123+ messaging_attributes .MESSAGING_DESTINATION_PARTITION_ID : "1" ,
124+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_KEY : "key_1" ,
125+ messaging_attributes .MESSAGING_CONSUMER_GROUP_NAME : group_id ,
126+ messaging_attributes .MESSAGING_OPERATION_NAME : "receive" ,
127+ messaging_attributes .MESSAGING_OPERATION_TYPE : messaging_attributes .MessagingOperationTypeValues .RECEIVE .value ,
128+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_OFFSET : 1 ,
129+ messaging_attributes .MESSAGING_MESSAGE_ID : "topic_1.1.1" ,
130+ },
131+ },
132+ {
133+ "name" : "topic_2 receive" ,
134+ "kind" : SpanKind .CONSUMER ,
135+ "attributes" : {
136+ messaging_attributes .MESSAGING_SYSTEM : messaging_attributes .MessagingSystemValues .KAFKA .value ,
137+ server_attributes .SERVER_ADDRESS : '"localhost"' ,
138+ messaging_attributes .MESSAGING_CLIENT_ID : client_id ,
139+ messaging_attributes .MESSAGING_DESTINATION_NAME : "topic_2" ,
140+ messaging_attributes .MESSAGING_DESTINATION_PARTITION_ID : "2" ,
141+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_KEY : "key_2" ,
142+ messaging_attributes .MESSAGING_CONSUMER_GROUP_NAME : group_id ,
143+ messaging_attributes .MESSAGING_OPERATION_NAME : "receive" ,
144+ messaging_attributes .MESSAGING_OPERATION_TYPE : messaging_attributes .MessagingOperationTypeValues .RECEIVE .value ,
145+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_OFFSET : 2 ,
146+ messaging_attributes .MESSAGING_MESSAGE_ID : "topic_2.2.2" ,
147+ },
148+ },
149+ ]
150+ self .memory_exporter .clear ()
151+
152+ next_record_mock .side_effect = [
153+ self .consumer_record_factory (
154+ 1 ,
155+ headers = (
156+ (
157+ "traceparent" ,
158+ b"00-03afa25236b8cd948fa853d67038ac79-405ff022e8247c46-01" ,
159+ ),
160+ ),
161+ ),
162+ self .consumer_record_factory (2 , headers = ()),
163+ ]
164+
165+ await consumer .getone ()
166+ next_record_mock .assert_awaited_with (())
167+
168+ first_span = self .memory_exporter .get_finished_spans ()[0 ]
169+ self .assertEqual (
170+ format_trace_id (first_span .get_span_context ().trace_id ),
171+ "03afa25236b8cd948fa853d67038ac79" ,
172+ )
173+
174+ await consumer .getone ()
175+ next_record_mock .assert_awaited_with (())
176+
177+ span_list = self .memory_exporter .get_finished_spans ()
178+ self ._compare_spans (span_list , expected_spans )
179+
180+ async def test_getone_baggage (self ) -> None :
181+ received_baggage = None
182+
183+ async def async_consume_hook (span , * _ ) -> None :
184+ nonlocal received_baggage
185+ received_baggage = baggage .get_all (set_span_in_context (span ))
186+
187+ AIOKafkaInstrumentor ().uninstrument ()
188+ AIOKafkaInstrumentor ().instrument (
189+ tracer_provider = self .tracer_provider ,
190+ async_consume_hook = async_consume_hook ,
191+ )
192+
193+ consumer = await self .consumer_factory ()
194+ next_record_mock : mock .AsyncMock = consumer ._fetcher .next_record
195+
196+ self .memory_exporter .clear ()
197+
198+ next_record_mock .side_effect = [
199+ self .consumer_record_factory (
200+ 1 ,
201+ headers = (
202+ (
203+ "traceparent" ,
204+ b"00-03afa25236b8cd948fa853d67038ac79-405ff022e8247c46-01" ,
205+ ),
206+ ("baggage" , b"foo=bar" ),
207+ ),
208+ ),
209+ ]
210+
211+ await consumer .getone ()
212+ next_record_mock .assert_awaited_with (())
213+
214+ self .assertEqual (received_baggage , {"foo" : "bar" })
215+
216+ async def test_getone_consume_hook (self ) -> None :
217+ async_consume_hook_mock = mock .AsyncMock ()
218+
219+ AIOKafkaInstrumentor ().uninstrument ()
220+ AIOKafkaInstrumentor ().instrument (
221+ tracer_provider = self .tracer_provider ,
222+ async_consume_hook = async_consume_hook_mock ,
223+ )
224+
225+ consumer = await self .consumer_factory ()
226+ next_record_mock : mock .AsyncMock = consumer ._fetcher .next_record
227+
228+ next_record_mock .side_effect = [
229+ self .consumer_record_factory (1 , headers = ())
230+ ]
231+
232+ await consumer .getone ()
233+
234+ async_consume_hook_mock .assert_awaited_once ()
235+
236+ async def test_send (self ) -> None :
237+ AIOKafkaInstrumentor ().uninstrument ()
238+ AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
239+
240+ producer = await self .producer_factory ()
241+ add_message_mock : mock .AsyncMock = (
242+ producer ._message_accumulator .add_message
243+ )
244+
245+ tracer = self .tracer_provider .get_tracer (__name__ )
246+ with tracer .start_as_current_span ("test_span" ) as span :
247+ await producer .send ("topic_1" , b"value_1" )
248+
249+ add_message_mock .assert_awaited_with (
250+ TopicPartition (topic = "topic_1" , partition = 1 ),
251+ None ,
252+ b"value_1" ,
253+ 40.0 ,
254+ timestamp_ms = None ,
255+ headers = [("traceparent" , mock .ANY )],
40256 )
257+ add_message_mock .call_args_list [0 ].kwargs ["headers" ][0 ][1 ].startswith (
258+ f"00-{ format_trace_id (span .get_span_context ().trace_id )} -" .encode ()
259+ )
260+
261+ await producer .send ("topic_2" , b"value_2" )
262+ add_message_mock .assert_awaited_with (
263+ TopicPartition (topic = "topic_2" , partition = 1 ),
264+ None ,
265+ b"value_2" ,
266+ 40.0 ,
267+ timestamp_ms = None ,
268+ headers = [("traceparent" , mock .ANY )],
269+ )
270+
271+ async def test_send_baggage (self ) -> None :
272+ AIOKafkaInstrumentor ().uninstrument ()
273+ AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
274+
275+ producer = await self .producer_factory ()
276+ add_message_mock : mock .AsyncMock = (
277+ producer ._message_accumulator .add_message
278+ )
279+
280+ tracer = self .tracer_provider .get_tracer (__name__ )
281+ ctx = baggage .set_baggage ("foo" , "bar" )
282+ context .attach (ctx )
283+
284+ with tracer .start_as_current_span ("test_span" , context = ctx ):
285+ await producer .send ("topic_1" , b"value_1" )
286+
287+ add_message_mock .assert_awaited_with (
288+ TopicPartition (topic = "topic_1" , partition = 1 ),
289+ None ,
290+ b"value_1" ,
291+ 40.0 ,
292+ timestamp_ms = None ,
293+ headers = [("traceparent" , mock .ANY ), ("baggage" , b"foo=bar" )],
294+ )
295+
296+ async def test_send_produce_hook (self ) -> None :
297+ async_produce_hook_mock = mock .AsyncMock ()
298+
299+ AIOKafkaInstrumentor ().uninstrument ()
300+ AIOKafkaInstrumentor ().instrument (
301+ tracer_provider = self .tracer_provider ,
302+ async_produce_hook = async_produce_hook_mock ,
303+ )
304+
305+ producer = await self .producer_factory ()
306+
307+ await producer .send ("topic_1" , b"value_1" )
308+
309+ async_produce_hook_mock .assert_awaited_once ()
310+
311+ def _compare_spans (
312+ self , spans : Sequence [ReadableSpan ], expected_spans : List [dict ]
313+ ) -> None :
314+ self .assertEqual (len (spans ), len (expected_spans ))
315+ for span , expected_span in zip (spans , expected_spans ):
316+ self .assertEqual (expected_span ["name" ], span .name )
317+ self .assertEqual (expected_span ["kind" ], span .kind )
318+ self .assertEqual (
319+ expected_span ["attributes" ], dict (span .attributes )
320+ )
0 commit comments