|
1 | 1 | import abc |
| 2 | +import asyncio |
2 | 3 | import enum |
3 | 4 | import functools |
4 | 5 |
|
@@ -196,3 +197,116 @@ def wrap_execute_query_response( |
196 | 197 | return convert.ResultSet.from_message(response_pb.result_set, settings) |
197 | 198 |
|
198 | 199 | return None |
| 200 | + |
| 201 | + |
| 202 | +class TxListener: |
| 203 | + def _on_before_commit(self): |
| 204 | + pass |
| 205 | + |
| 206 | + def _on_after_commit(self, exc: typing.Optional[BaseException]): |
| 207 | + pass |
| 208 | + |
| 209 | + def _on_before_rollback(self): |
| 210 | + pass |
| 211 | + |
| 212 | + def _on_after_rollback(self, exc: typing.Optional[BaseException]): |
| 213 | + pass |
| 214 | + |
| 215 | + |
| 216 | +class TxListenerAsyncIO: |
| 217 | + async def _on_before_commit(self): |
| 218 | + pass |
| 219 | + |
| 220 | + async def _on_after_commit(self, exc: typing.Optional[BaseException]): |
| 221 | + pass |
| 222 | + |
| 223 | + async def _on_before_rollback(self): |
| 224 | + pass |
| 225 | + |
| 226 | + async def _on_after_rollback(self, exc: typing.Optional[BaseException]): |
| 227 | + pass |
| 228 | + |
| 229 | + |
| 230 | +def with_transaction_events(method): |
| 231 | + @functools.wraps(method) |
| 232 | + def wrapper(self, *args, **kwargs): |
| 233 | + method_name = method.__name__ |
| 234 | + before_event = f"_on_before_{method_name}" |
| 235 | + after_event = f"_on_after_{method_name}" |
| 236 | + |
| 237 | + self._notify_listeners_sync(before_event) |
| 238 | + |
| 239 | + try: |
| 240 | + result = method(self, *args, **kwargs) |
| 241 | + |
| 242 | + self._notify_listeners_sync(after_event, exc=None) |
| 243 | + |
| 244 | + return result |
| 245 | + except BaseException as e: |
| 246 | + self._notify_listeners_sync(after_event, exc=e) |
| 247 | + raise |
| 248 | + |
| 249 | + return wrapper |
| 250 | + |
| 251 | + |
| 252 | +def with_async_transaction_events(method): |
| 253 | + @functools.wraps(method) |
| 254 | + async def wrapper(self, *args, **kwargs): |
| 255 | + method_name = method.__name__ |
| 256 | + before_event = f"_on_before_{method_name}" |
| 257 | + after_event = f"_on_after_{method_name}" |
| 258 | + |
| 259 | + await self._notify_listeners_async(before_event) |
| 260 | + |
| 261 | + try: |
| 262 | + result = await method(self, *args, **kwargs) |
| 263 | + |
| 264 | + await self._notify_listeners_async(after_event, exc=None) |
| 265 | + |
| 266 | + return result |
| 267 | + except BaseException as e: |
| 268 | + await self._notify_listeners_async(after_event, exc=e) |
| 269 | + raise |
| 270 | + |
| 271 | + return wrapper |
| 272 | + |
| 273 | + |
| 274 | +class ListenerHandlerMixin: |
| 275 | + def _init_listener_handler(self): |
| 276 | + self.listeners = [] |
| 277 | + |
| 278 | + def _add_listener(self, listener): |
| 279 | + if listener not in self.listeners: |
| 280 | + self.listeners.append(listener) |
| 281 | + return self |
| 282 | + |
| 283 | + def _remove_listener(self, listener): |
| 284 | + if listener in self.listeners: |
| 285 | + self.listeners.remove(listener) |
| 286 | + return self |
| 287 | + |
| 288 | + def _clear_listeners(self): |
| 289 | + self.listeners.clear() |
| 290 | + return self |
| 291 | + |
| 292 | + def _notify_sync_listeners(self, event_name: str, **kwargs) -> None: |
| 293 | + for listener in self.listeners: |
| 294 | + if isinstance(listener, TxListener) and hasattr(listener, event_name): |
| 295 | + getattr(listener, event_name)(**kwargs) |
| 296 | + |
| 297 | + async def _notify_async_listeners(self, event_name: str, **kwargs) -> None: |
| 298 | + coros = [] |
| 299 | + for listener in self.listeners: |
| 300 | + if isinstance(listener, TxListenerAsyncIO) and hasattr(listener, event_name): |
| 301 | + coros.append(getattr(listener, event_name)(**kwargs)) |
| 302 | + |
| 303 | + if coros: |
| 304 | + await asyncio.gather(*coros) |
| 305 | + |
| 306 | + def _notify_listeners_sync(self, event_name: str, **kwargs) -> None: |
| 307 | + self._notify_sync_listeners(event_name, **kwargs) |
| 308 | + |
| 309 | + async def _notify_listeners_async(self, event_name: str, **kwargs) -> None: |
| 310 | + # self._notify_sync_listeners(event_name, **kwargs) |
| 311 | + |
| 312 | + await self._notify_async_listeners(event_name, **kwargs) |
0 commit comments