|
8 | 8 | from http import HTTPStatus |
9 | 9 | from http.server import HTTPServer, SimpleHTTPRequestHandler |
10 | 10 | from multiprocessing.context import Process |
11 | | -from typing import Type |
| 11 | +from queue import Queue |
| 12 | +from typing import Type, Union |
12 | 13 | from unittest import TestCase |
13 | | -from urllib.parse import urlparse, parse_qs |
| 14 | +from urllib.parse import parse_qs, urlparse |
14 | 15 | from urllib.request import Request, urlopen |
15 | 16 |
|
16 | | -from tests.helpers import get_mock_server_mode |
| 17 | +from tests.helpers import ReceivedRequests, get_mock_server_mode |
17 | 18 |
|
18 | 19 |
|
19 | 20 | class MockHandler(SimpleHTTPRequestHandler): |
@@ -78,6 +79,8 @@ def set_common_headers(self): |
78 | 79 |
|
79 | 80 | def _handle(self): |
80 | 81 | try: |
| 82 | + # put_nowait is common between Queue & asyncio.Queue, it does not need to be awaited |
| 83 | + self.server.queue.put_nowait(self.path) |
81 | 84 | if self.path == "/received_requests.json": |
82 | 85 | self.send_response(200) |
83 | 86 | self.set_common_headers() |
@@ -313,32 +316,44 @@ def stop(self): |
313 | 316 |
|
314 | 317 |
|
315 | 318 | class MockServerThread(threading.Thread): |
316 | | - def __init__(self, test: TestCase, handler: Type[SimpleHTTPRequestHandler] = MockHandler): |
| 319 | + def __init__( |
| 320 | + self, queue: Union[Queue, asyncio.Queue], test: TestCase, handler: Type[SimpleHTTPRequestHandler] = MockHandler |
| 321 | + ): |
317 | 322 | threading.Thread.__init__(self) |
318 | 323 | self.handler = handler |
319 | 324 | self.test = test |
| 325 | + self.queue = queue |
320 | 326 |
|
321 | 327 | def run(self): |
322 | 328 | self.server = HTTPServer(("localhost", 8888), self.handler) |
| 329 | + self.server.queue = self.queue |
323 | 330 | self.test.server_url = "http://localhost:8888" |
324 | 331 | self.test.host, self.test.port = self.server.socket.getsockname() |
325 | 332 | self.test.server_started.set() # threading.Event() |
326 | 333 |
|
327 | 334 | self.test = None |
328 | 335 | try: |
329 | | - self.server.serve_forever() |
| 336 | + self.server.serve_forever(0.05) |
330 | 337 | finally: |
331 | 338 | self.server.server_close() |
332 | 339 |
|
333 | 340 | def stop(self): |
| 341 | + with self.server.queue.mutex: |
| 342 | + del self.server.queue |
| 343 | + self.server.shutdown() |
| 344 | + self.join() |
| 345 | + |
| 346 | + def stop_unsafe(self): |
| 347 | + del self.server.queue |
334 | 348 | self.server.shutdown() |
335 | 349 | self.join() |
336 | 350 |
|
337 | 351 |
|
338 | 352 | def setup_mock_web_api_server(test: TestCase): |
339 | 353 | if get_mock_server_mode() == "threading": |
340 | 354 | test.server_started = threading.Event() |
341 | | - test.thread = MockServerThread(test) |
| 355 | + test.received_requests = ReceivedRequests(Queue()) |
| 356 | + test.thread = MockServerThread(test.received_requests.queue, test) |
342 | 357 | test.thread.start() |
343 | 358 | test.server_started.wait() |
344 | 359 | else: |
@@ -389,37 +404,65 @@ def cleanup_mock_web_api_server(test: TestCase): |
389 | 404 | test.process = None |
390 | 405 |
|
391 | 406 |
|
392 | | -def assert_auth_test_count(test: TestCase, expected_count: int): |
393 | | - time.sleep(0.1) |
394 | | - retry_count = 0 |
| 407 | +def assert_received_request_count(test: TestCase, path: str, min_count: int, timeout: float = 1): |
| 408 | + start_time = time.time() |
395 | 409 | error = None |
396 | | - while retry_count < 3: |
| 410 | + while time.time() - start_time < timeout: |
397 | 411 | try: |
398 | | - test.mock_received_requests["/auth.test"] == expected_count |
399 | | - break |
| 412 | + received_count = test.received_requests.get(path, 0) |
| 413 | + assert ( |
| 414 | + received_count == min_count |
| 415 | + ), f"Expected {min_count} '{path}' {'requests' if min_count > 1 else 'request'}, but got {received_count}!" |
| 416 | + return |
400 | 417 | except Exception as e: |
401 | 418 | error = e |
402 | | - retry_count += 1 |
403 | | - # waiting for mock_received_requests updates |
404 | | - time.sleep(0.1) |
| 419 | + # waiting for some requests to be received |
| 420 | + time.sleep(0.05) |
405 | 421 |
|
406 | 422 | if error is not None: |
407 | 423 | raise error |
408 | 424 |
|
409 | 425 |
|
410 | | -async def assert_auth_test_count_async(test: TestCase, expected_count: int): |
411 | | - await asyncio.sleep(0.1) |
412 | | - retry_count = 0 |
| 426 | +def assert_auth_test_count(test: TestCase, expected_count: int): |
| 427 | + assert_received_request_count(test, "/auth.test", expected_count, 0.5) |
| 428 | + |
| 429 | + |
| 430 | +######### |
| 431 | +# async # |
| 432 | +######### |
| 433 | + |
| 434 | + |
| 435 | +def setup_mock_web_api_server_async(test: TestCase): |
| 436 | + test.server_started = threading.Event() |
| 437 | + test.received_requests = ReceivedRequests(asyncio.Queue()) |
| 438 | + test.thread = MockServerThread(test.received_requests.queue, test) |
| 439 | + test.thread.start() |
| 440 | + test.server_started.wait() |
| 441 | + |
| 442 | + |
| 443 | +def cleanup_mock_web_api_server_async(test: TestCase): |
| 444 | + test.thread.stop_unsafe() |
| 445 | + test.thread = None |
| 446 | + |
| 447 | + |
| 448 | +async def assert_received_request_count_async(test: TestCase, path: str, min_count: int, timeout: float = 1): |
| 449 | + start_time = time.time() |
413 | 450 | error = None |
414 | | - while retry_count < 3: |
| 451 | + while time.time() - start_time < timeout: |
415 | 452 | try: |
416 | | - test.mock_received_requests["/auth.test"] == expected_count |
417 | | - break |
| 453 | + received_count = await test.received_requests.get_async(path, 0) |
| 454 | + assert ( |
| 455 | + received_count == min_count |
| 456 | + ), f"Expected {min_count} '{path}' {'requests' if min_count > 1 else 'request'}, but got {received_count}!" |
| 457 | + return |
418 | 458 | except Exception as e: |
419 | 459 | error = e |
420 | | - retry_count += 1 |
421 | 460 | # waiting for mock_received_requests updates |
422 | | - await asyncio.sleep(0.1) |
| 461 | + await asyncio.sleep(0.05) |
423 | 462 |
|
424 | 463 | if error is not None: |
425 | 464 | raise error |
| 465 | + |
| 466 | + |
| 467 | +async def assert_auth_test_count_async(test: TestCase, expected_count: int): |
| 468 | + await assert_received_request_count_async(test, "/auth.test", expected_count, 0.5) |
0 commit comments