@@ -91,8 +91,9 @@ def response_hook(span, instance, response):
9191---
9292"""
9393
94- import typing
95- from typing import Any , Collection
94+ from __future__ import annotations
95+
96+ from typing import TYPE_CHECKING , Any , Callable , Collection
9697
9798import redis
9899from wrapt import wrap_function_wrapper
@@ -109,18 +110,43 @@ def response_hook(span, instance, response):
109110from opentelemetry .instrumentation .redis .version import __version__
110111from opentelemetry .instrumentation .utils import unwrap
111112from opentelemetry .semconv .trace import SpanAttributes
112- from opentelemetry .trace import Span , StatusCode
113+ from opentelemetry .trace import Span , StatusCode , Tracer
113114
114- _DEFAULT_SERVICE = "redis"
115+ if TYPE_CHECKING :
116+ from typing import Awaitable , TypeVar
115117
116- _RequestHookT = typing .Optional [
117- typing .Callable [
118- [Span , redis .connection .Connection , typing .List , typing .Dict ], None
118+ import redis .asyncio .client
119+ import redis .asyncio .cluster
120+ import redis .client
121+ import redis .cluster
122+ import redis .connection
123+
124+ _RequestHookT = Callable [
125+ [Span , redis .connection .Connection , list [Any ], dict [str , Any ]], None
119126 ]
120- ]
121- _ResponseHookT = typing .Optional [
122- typing .Callable [[Span , redis .connection .Connection , Any ], None ]
123- ]
127+ _ResponseHookT = Callable [[Span , redis .connection .Connection , Any ], None ]
128+
129+ AsyncPipelineInstance = TypeVar (
130+ "AsyncPipelineInstance" ,
131+ redis .asyncio .client .Pipeline ,
132+ redis .asyncio .cluster .ClusterPipeline ,
133+ )
134+ AsyncRedisInstance = TypeVar (
135+ "AsyncRedisInstance" , redis .asyncio .Redis , redis .asyncio .RedisCluster
136+ )
137+ PipelineInstance = TypeVar (
138+ "PipelineInstance" ,
139+ redis .client .Pipeline ,
140+ redis .cluster .ClusterPipeline ,
141+ )
142+ RedisInstance = TypeVar (
143+ "RedisInstance" , redis .client .Redis , redis .cluster .RedisCluster
144+ )
145+ R = TypeVar ("R" )
146+
147+
148+ _DEFAULT_SERVICE = "redis"
149+
124150
125151_REDIS_ASYNCIO_VERSION = (4 , 2 , 0 )
126152if redis .VERSION >= _REDIS_ASYNCIO_VERSION :
@@ -132,7 +158,9 @@ def response_hook(span, instance, response):
132158_FIELD_TYPES = ["NUMERIC" , "TEXT" , "GEO" , "TAG" , "VECTOR" ]
133159
134160
135- def _set_connection_attributes (span , conn ):
161+ def _set_connection_attributes (
162+ span : Span , conn : RedisInstance | AsyncRedisInstance
163+ ) -> None :
136164 if not span .is_recording () or not hasattr (conn , "connection_pool" ):
137165 return
138166 for key , value in _extract_conn_attributes (
@@ -141,7 +169,9 @@ def _set_connection_attributes(span, conn):
141169 span .set_attribute (key , value )
142170
143171
144- def _build_span_name (instance , cmd_args ):
172+ def _build_span_name (
173+ instance : RedisInstance | AsyncRedisInstance , cmd_args : tuple [Any , ...]
174+ ) -> str :
145175 if len (cmd_args ) > 0 and cmd_args [0 ]:
146176 if cmd_args [0 ] == "FT.SEARCH" :
147177 name = "redis.search"
@@ -154,7 +184,9 @@ def _build_span_name(instance, cmd_args):
154184 return name
155185
156186
157- def _build_span_meta_data_for_pipeline (instance ):
187+ def _build_span_meta_data_for_pipeline (
188+ instance : PipelineInstance | AsyncPipelineInstance ,
189+ ) -> tuple [list [Any ], str , str ]:
158190 try :
159191 command_stack = (
160192 instance .command_stack
@@ -184,11 +216,16 @@ def _build_span_meta_data_for_pipeline(instance):
184216
185217# pylint: disable=R0915
186218def _instrument (
187- tracer ,
188- request_hook : _RequestHookT = None ,
189- response_hook : _ResponseHookT = None ,
219+ tracer : Tracer ,
220+ request_hook : _RequestHookT | None = None ,
221+ response_hook : _ResponseHookT | None = None ,
190222):
191- def _traced_execute_command (func , instance , args , kwargs ):
223+ def _traced_execute_command (
224+ func : Callable [..., R ],
225+ instance : RedisInstance ,
226+ args : tuple [Any , ...],
227+ kwargs : dict [str , Any ],
228+ ) -> R :
192229 query = _format_command_args (args )
193230 name = _build_span_name (instance , args )
194231 with tracer .start_as_current_span (
@@ -210,7 +247,12 @@ def _traced_execute_command(func, instance, args, kwargs):
210247 response_hook (span , instance , response )
211248 return response
212249
213- def _traced_execute_pipeline (func , instance , args , kwargs ):
250+ def _traced_execute_pipeline (
251+ func : Callable [..., R ],
252+ instance : PipelineInstance ,
253+ args : tuple [Any , ...],
254+ kwargs : dict [str , Any ],
255+ ) -> R :
214256 (
215257 command_stack ,
216258 resource ,
@@ -242,7 +284,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
242284
243285 return response
244286
245- def _add_create_attributes (span , args ):
287+ def _add_create_attributes (span : Span , args : tuple [ Any , ...] ):
246288 _set_span_attribute_if_value (
247289 span , "redis.create_index.index" , _value_or_none (args , 1 )
248290 )
@@ -266,7 +308,7 @@ def _add_create_attributes(span, args):
266308 field_attribute ,
267309 )
268310
269- def _add_search_attributes (span , response , args ):
311+ def _add_search_attributes (span : Span , response , args ):
270312 _set_span_attribute_if_value (
271313 span , "redis.search.index" , _value_or_none (args , 1 )
272314 )
@@ -326,7 +368,12 @@ def _add_search_attributes(span, response, args):
326368 _traced_execute_pipeline ,
327369 )
328370
329- async def _async_traced_execute_command (func , instance , args , kwargs ):
371+ async def _async_traced_execute_command (
372+ func : Callable [..., Awaitable [R ]],
373+ instance : AsyncRedisInstance ,
374+ args : tuple [Any , ...],
375+ kwargs : dict [str , Any ],
376+ ) -> Awaitable [R ]:
330377 query = _format_command_args (args )
331378 name = _build_span_name (instance , args )
332379
@@ -344,7 +391,12 @@ async def _async_traced_execute_command(func, instance, args, kwargs):
344391 response_hook (span , instance , response )
345392 return response
346393
347- async def _async_traced_execute_pipeline (func , instance , args , kwargs ):
394+ async def _async_traced_execute_pipeline (
395+ func : Callable [..., Awaitable [R ]],
396+ instance : AsyncPipelineInstance ,
397+ args : tuple [Any , ...],
398+ kwargs : dict [str , Any ],
399+ ) -> Awaitable [R ]:
348400 (
349401 command_stack ,
350402 resource ,
@@ -408,14 +460,15 @@ async def _async_traced_execute_pipeline(func, instance, args, kwargs):
408460
409461
410462class RedisInstrumentor (BaseInstrumentor ):
411- """An instrumentor for Redis
463+ """An instrumentor for Redis.
464+
412465 See `BaseInstrumentor`
413466 """
414467
415468 def instrumentation_dependencies (self ) -> Collection [str ]:
416469 return _instruments
417470
418- def _instrument (self , ** kwargs ):
471+ def _instrument (self , ** kwargs : Any ):
419472 """Instruments the redis module
420473
421474 Args:
@@ -436,7 +489,7 @@ def _instrument(self, **kwargs):
436489 response_hook = kwargs .get ("response_hook" ),
437490 )
438491
439- def _uninstrument (self , ** kwargs ):
492+ def _uninstrument (self , ** kwargs : Any ):
440493 if redis .VERSION < (3 , 0 , 0 ):
441494 unwrap (redis .StrictRedis , "execute_command" )
442495 unwrap (redis .StrictRedis , "pipeline" )
0 commit comments