11from collections import namedtuple
22from collections .abc import Callable , Coroutine
3- from http import HTTPStatus
3+ from types import TracebackType
4+ from typing import Any , Optional , Type
45
56import aiohttp
67from aiohttp import ClientSession
78
8- from .response import Response , Headers
9- from ..client_configuration import ClientConfiguration
9+ from .response import Response
1010from ..errors .client_error import ClientError
1111from ..errors .custom_error import CustomError
1212from ..errors .internal_error import InternalError
1313from ..errors .server_error import ServerError
1414from ..util import url
15- from ..util .retry .retry_result import Retry , Return , RetryResult
16- from ..util .retry .retry_with_backoff import retry_with_backoff
17- from ..util .retry .retry_error import RetryError
1815
1916Range = namedtuple ('Range' , 'lower, upper' )
2017
@@ -26,71 +23,62 @@ class UninitializedError(CustomError):
2623class HttpClient :
2724 def __init__ (
2825 self ,
29- client_configuration : ClientConfiguration
26+ base_url : str ,
27+ api_token : str ,
3028 ):
31- self .__client_configuration : ClientConfiguration = client_configuration
29+ self .__base_url = base_url
30+ self .__api_token = api_token
3231 self .__session : ClientSession | None = None
32+
33+ async def __aenter__ (self ):
34+ await self .initialize ()
35+ return self
36+
37+ async def __aexit__ (
38+ self ,
39+ exc_type : Optional [Type [BaseException ]],
40+ exc_val : Optional [BaseException ],
41+ exc_tb : Optional [TracebackType ]
42+ ) -> None :
43+ await self .close ()
3344
45+ # Keep for backward compatibility
3446 async def initialize (self ) -> None :
35- self .__session = aiohttp .ClientSession (
36- timeout = aiohttp .ClientTimeout (
37- connect = self .__client_configuration .timeout_seconds ,
38- sock_read = self .__client_configuration .timeout_seconds
39- )
40- )
47+ self .__session = aiohttp .ClientSession ()
4148
49+ # Keep for backward compatibility
4250 async def close (self ):
4351 if self .__session is not None :
4452 await self .__session .close ()
53+ self .__session = None
4554
4655 async def __execute_request (
4756 self ,
48- execute_request : Callable [[], Coroutine [None , None , RetryResult [ Response ] ]]
57+ execute_request : Callable [[], Coroutine [Any , Any , Response ]]
4958 ) -> Response :
5059 try :
51- return await retry_with_backoff (
52- self .__client_configuration .max_tries ,
53- execute_request
54- )
55- except RetryError as retry_error :
56- raise ServerError (str (retry_error )) from retry_error
60+ result = await execute_request ()
61+ return result
5762 except CustomError as custom_error :
5863 raise custom_error
5964 except aiohttp .ClientError as request_error :
6065 raise ServerError (str (request_error )) from request_error
6166 except Exception as other_error :
6267 raise InternalError (str (other_error )) from other_error
6368
64- def __validate_protocol_version (self , http_status_code : int , headers : Headers ) -> None :
65- if http_status_code != HTTPStatus .UNPROCESSABLE_ENTITY :
66- return
67-
68- server_protocol_version = headers .get (
69- 'x-eventsourcingdb-protocol-version'
70- )
71-
72- if server_protocol_version is None :
73- server_protocol_version = 'unknown version'
74-
75- raise ClientError (
76- f'Protocol version mismatch, server \' { server_protocol_version } \' ,'
77- f' client \' { self .__client_configuration .protocol_version } \' .'
78- )
7969
8070 @staticmethod
8171 async def __get_error_message (response : Response ):
8272 error_message = f'Request failed with status code \' { response .status_code } \' '
8373
8474 # We want to purposefully ignore all errors here, as we're already error handling,
8575 # and this function just tries to get more information on a best-effort basis.
86- # pylint: disable=too-many-try-statements
8776 try :
8877 encoded_error_reason = await response .body .read ()
8978 error_reason = encoded_error_reason .decode ('utf-8' )
9079 error_message += f" { error_reason } "
9180 finally :
9281 pass
93- # pylint: enable=too-many-try-statements
9482
9583 error_message += '.'
9684
@@ -99,23 +87,20 @@ async def __get_error_message(response: Response):
9987 async def __validate_response (
10088 self ,
10189 response : Response
102- ) -> RetryResult [ Response ] :
90+ ) -> Response :
10391 server_failure_range = Range (500 , 600 )
10492 if server_failure_range .lower <= response .status_code < server_failure_range .upper :
105- return Retry ( ServerError (await self .__get_error_message (response ) ))
93+ raise ServerError (await self .__get_error_message (response ))
10694
10795 client_failure_range = Range (400 , 500 )
10896 if client_failure_range .lower <= response .status_code < client_failure_range .upper :
10997 raise ClientError (await self .__get_error_message (response ))
11098
111- self .__validate_protocol_version (response .status_code , response .headers )
112-
113- return Return (response )
99+ return response
114100
115101 def __get_post_request_headers (self ) -> dict [str , str ]:
116102 headers = {
117- 'X-EventSourcingDB-Protocol-Version' : self .__client_configuration .protocol_version ,
118- 'Authorization' : f'Bearer { self .__client_configuration .api_token } ' ,
103+ 'Authorization' : f'Bearer { self .__api_token } ' ,
119104 'Content-Type' : 'application/json'
120105 }
121106
@@ -125,37 +110,31 @@ async def post(self, path: str, request_body: str) -> Response:
125110 if self .__session is None :
126111 raise UninitializedError ()
127112
128- async def execute_request () -> RetryResult [ Response ] :
129- response = await self .__session .post (
113+ async def execute_request () -> Response :
114+ async_response = await self .__session .post ( # type: ignore
130115 url .join_segments (
131- self .__client_configuration . base_url ,
116+ self .__base_url ,
132117 path
133118 ),
134119 data = request_body ,
135120 headers = self .__get_post_request_headers (),
136121 )
137122
138- response = Response (response )
123+ response = Response (async_response )
139124 try :
140125 result = await self .__validate_response (response )
126+ return result
141127 except Exception as error :
142128 response .close ()
143129 raise error
144130
145- if isinstance (result , Retry ):
146- response .close ()
147-
148- return result
149-
150131 return await self .__execute_request (execute_request )
151132
152133 def __get_get_request_headers (self , with_authorization : bool ) -> dict [str , str ]:
153- headers = {
154- 'X-EventSourcingDB-Protocol-Version' : self .__client_configuration .protocol_version ,
155- }
134+ headers = {}
156135
157136 if with_authorization :
158- headers ['Authorization' ] = f'Bearer { self .__client_configuration . api_token } '
137+ headers ['Authorization' ] = f'Bearer { self .__api_token } '
159138
160139 return headers
161140
@@ -167,23 +146,21 @@ async def get(
167146 if self .__session is None :
168147 raise UninitializedError ()
169148
170- async def execute_request () -> RetryResult [ Response ] :
171- response = await self .__session .get (
149+ async def execute_request () -> Response :
150+ async_response = await self .__session .get ( # type: ignore
172151 url .join_segments (
173- self .__client_configuration .base_url , path ),
152+ self .__base_url ,
153+ path
154+ ),
174155 headers = self .__get_get_request_headers (with_authorization ),
175156 )
176157
177- response = Response (response )
158+ response = Response (async_response )
178159 try :
179160 result = await self .__validate_response (response )
161+ return result
180162 except Exception as error :
181163 response .close ()
182164 raise error
183165
184- if isinstance (result , Retry ):
185- response .close ()
186-
187- return result
188-
189166 return await self .__execute_request (execute_request )
0 commit comments