137137---
138138"""
139139
140+ from __future__ import annotations
141+
140142import logging
141- import typing
142- from typing import Collection
143+ from typing import Any , Callable , Collection , TypeVar
143144
144145import psycopg # pylint: disable=import-self
145- from psycopg import (
146- AsyncCursor as pg_async_cursor , # pylint: disable=import-self,no-name-in-module
147- )
148- from psycopg import (
149- Cursor as pg_cursor , # pylint: disable=no-name-in-module,import-self
150- )
151146from psycopg .sql import Composed # pylint: disable=no-name-in-module
152147
153148from opentelemetry .instrumentation import dbapi
154149from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
155150from opentelemetry .instrumentation .psycopg .package import _instruments
156151from opentelemetry .instrumentation .psycopg .version import __version__
152+ from opentelemetry .trace import TracerProvider
157153
158154_logger = logging .getLogger (__name__ )
159155_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
160156
157+ ConnectionT = TypeVar (
158+ "ConnectionT" , psycopg .Connection , psycopg .AsyncConnection
159+ )
160+ CursorT = TypeVar ("CursorT" , psycopg .Cursor , psycopg .AsyncCursor )
161+
161162
162163class PsycopgInstrumentor (BaseInstrumentor ):
163164 _CONNECTION_ATTRIBUTES = {
@@ -172,7 +173,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
172173 def instrumentation_dependencies (self ) -> Collection [str ]:
173174 return _instruments
174175
175- def _instrument (self , ** kwargs ):
176+ def _instrument (self , ** kwargs : Any ):
176177 """Integrate with PostgreSQL Psycopg library.
177178 Psycopg: http://initd.org/psycopg/
178179 """
@@ -223,7 +224,7 @@ def _instrument(self, **kwargs):
223224 enable_attribute_commenter = enable_attribute_commenter ,
224225 )
225226
226- def _uninstrument (self , ** kwargs ):
227+ def _uninstrument (self , ** kwargs : Any ):
227228 """ "Disable Psycopg instrumentation"""
228229 dbapi .unwrap_connect (psycopg , "connect" ) # pylint: disable=no-member
229230 dbapi .unwrap_connect (
@@ -237,7 +238,9 @@ def _uninstrument(self, **kwargs):
237238
238239 # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
239240 @staticmethod
240- def instrument_connection (connection , tracer_provider = None ):
241+ def instrument_connection (
242+ connection : ConnectionT , tracer_provider : TracerProvider | None = None
243+ ) -> ConnectionT :
241244 """Enable instrumentation in a psycopg connection.
242245
243246 Args:
@@ -269,7 +272,7 @@ def instrument_connection(connection, tracer_provider=None):
269272
270273 # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
271274 @staticmethod
272- def uninstrument_connection (connection ) :
275+ def uninstrument_connection (connection : ConnectionT ) -> ConnectionT :
273276 connection .cursor_factory = getattr (
274277 connection , _OTEL_CURSOR_FACTORY_KEY , None
275278 )
@@ -281,9 +284,9 @@ def uninstrument_connection(connection):
281284class DatabaseApiIntegration (dbapi .DatabaseApiIntegration ):
282285 def wrapped_connection (
283286 self ,
284- connect_method : typing . Callable [..., typing . Any ],
285- args : typing . Tuple [ typing . Any , typing . Any ],
286- kwargs : typing . Dict [ typing . Any , typing . Any ],
287+ connect_method : Callable [..., Any ],
288+ args : tuple [ Any , Any ],
289+ kwargs : dict [ Any , Any ],
287290 ):
288291 """Add object proxy to connection object."""
289292 base_cursor_factory = kwargs .pop ("cursor_factory" , None )
@@ -299,9 +302,9 @@ def wrapped_connection(
299302class DatabaseApiAsyncIntegration (dbapi .DatabaseApiIntegration ):
300303 async def wrapped_connection (
301304 self ,
302- connect_method : typing . Callable [..., typing . Any ],
303- args : typing . Tuple [ typing . Any , typing . Any ],
304- kwargs : typing . Dict [ typing . Any , typing . Any ],
305+ connect_method : Callable [..., Any ],
306+ args : tuple [ Any , Any ],
307+ kwargs : dict [ Any , Any ],
305308 ):
306309 """Add object proxy to connection object."""
307310 base_cursor_factory = kwargs .pop ("cursor_factory" , None )
@@ -317,7 +320,7 @@ async def wrapped_connection(
317320
318321
319322class CursorTracer (dbapi .CursorTracer ):
320- def get_operation_name (self , cursor , args ) :
323+ def get_operation_name (self , cursor : CursorT , args : list [ Any ]) -> str :
321324 if not args :
322325 return ""
323326
@@ -332,7 +335,7 @@ def get_operation_name(self, cursor, args):
332335
333336 return ""
334337
335- def get_statement (self , cursor , args ) :
338+ def get_statement (self , cursor : CursorT , args : list [ Any ]) -> str :
336339 if not args :
337340 return ""
338341
@@ -342,7 +345,11 @@ def get_statement(self, cursor, args):
342345 return statement
343346
344347
345- def _new_cursor_factory (db_api = None , base_factory = None , tracer_provider = None ):
348+ def _new_cursor_factory (
349+ db_api : DatabaseApiIntegration | None = None ,
350+ base_factory : type [psycopg .Cursor ] | None = None ,
351+ tracer_provider : TracerProvider | None = None ,
352+ ):
346353 if not db_api :
347354 db_api = DatabaseApiIntegration (
348355 __name__ ,
@@ -352,21 +359,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
352359 tracer_provider = tracer_provider ,
353360 )
354361
355- base_factory = base_factory or pg_cursor
362+ base_factory = base_factory or psycopg . Cursor
356363 _cursor_tracer = CursorTracer (db_api )
357364
358365 class TracedCursorFactory (base_factory ):
359- def execute (self , * args , ** kwargs ):
366+ def execute (self , * args : Any , ** kwargs : Any ):
360367 return _cursor_tracer .traced_execution (
361368 self , super ().execute , * args , ** kwargs
362369 )
363370
364- def executemany (self , * args , ** kwargs ):
371+ def executemany (self , * args : Any , ** kwargs : Any ):
365372 return _cursor_tracer .traced_execution (
366373 self , super ().executemany , * args , ** kwargs
367374 )
368375
369- def callproc (self , * args , ** kwargs ):
376+ def callproc (self , * args : Any , ** kwargs : Any ):
370377 return _cursor_tracer .traced_execution (
371378 self , super ().callproc , * args , ** kwargs
372379 )
@@ -375,7 +382,9 @@ def callproc(self, *args, **kwargs):
375382
376383
377384def _new_cursor_async_factory (
378- db_api = None , base_factory = None , tracer_provider = None
385+ db_api : DatabaseApiAsyncIntegration | None = None ,
386+ base_factory : type [psycopg .AsyncCursor ] | None = None ,
387+ tracer_provider : TracerProvider | None = None ,
379388):
380389 if not db_api :
381390 db_api = DatabaseApiAsyncIntegration (
@@ -385,21 +394,21 @@ def _new_cursor_async_factory(
385394 version = __version__ ,
386395 tracer_provider = tracer_provider ,
387396 )
388- base_factory = base_factory or pg_async_cursor
397+ base_factory = base_factory or psycopg . AsyncCursor
389398 _cursor_tracer = CursorTracer (db_api )
390399
391400 class TracedCursorAsyncFactory (base_factory ):
392- async def execute (self , * args , ** kwargs ):
401+ async def execute (self , * args : Any , ** kwargs : Any ):
393402 return await _cursor_tracer .traced_execution (
394403 self , super ().execute , * args , ** kwargs
395404 )
396405
397- async def executemany (self , * args , ** kwargs ):
406+ async def executemany (self , * args : Any , ** kwargs : Any ):
398407 return await _cursor_tracer .traced_execution (
399408 self , super ().executemany , * args , ** kwargs
400409 )
401410
402- async def callproc (self , * args , ** kwargs ):
411+ async def callproc (self , * args : Any , ** kwargs : Any ):
403412 return await _cursor_tracer .traced_execution (
404413 self , super ().callproc , * args , ** kwargs
405414 )
0 commit comments