|
8 | 8 | import threading |
9 | 9 | import time |
10 | 10 | import unittest |
| 11 | +from ctypes import c_long, py_object, pythonapi |
11 | 12 | from datetime import datetime |
12 | 13 | from functools import partial |
13 | 14 | from unittest.mock import patch |
|
30 | 31 | "This plugin does not support raise()", |
31 | 32 | ] |
32 | 33 |
|
| 34 | +_TIMEOUT_CONTEXT = { |
| 35 | + "testId": None, |
| 36 | + "timeout": None, |
| 37 | +} |
| 38 | + |
| 39 | + |
| 40 | +class TestTimeoutError(BaseException): |
| 41 | + pass |
| 42 | + |
| 43 | + |
| 44 | +def _raiseAsyncException(threadId: int, exceptionType: type): |
| 45 | + result = pythonapi.PyThreadState_SetAsyncExc(c_long(threadId), py_object(exceptionType)) |
| 46 | + if result > 1: |
| 47 | + pythonapi.PyThreadState_SetAsyncExc(c_long(threadId), py_object(None)) |
| 48 | + |
| 49 | + |
| 50 | +def _triggerTestTimeout(mainThreadId: int, testId: str, timeoutSeconds: float): |
| 51 | + _TIMEOUT_CONTEXT["testId"] = testId |
| 52 | + _TIMEOUT_CONTEXT["timeout"] = timeoutSeconds |
| 53 | + _raiseAsyncException(mainThreadId, TestTimeoutError) |
| 54 | + |
| 55 | + |
| 56 | +class TimeoutTextTestResult(unittest.TextTestResult): |
| 57 | + def addError(self, test, err): |
| 58 | + errorType, errorValue, errorTraceback = err |
| 59 | + if issubclass(errorType, TestTimeoutError): |
| 60 | + testId = _TIMEOUT_CONTEXT.get("testId") or test.id() |
| 61 | + timeoutSeconds = _TIMEOUT_CONTEXT.get("timeout") |
| 62 | + timeoutMessage = f"Test timed out after {timeoutSeconds}s: {testId}" |
| 63 | + super().addFailure( |
| 64 | + test, |
| 65 | + (AssertionError, AssertionError(timeoutMessage), errorTraceback) |
| 66 | + ) |
| 67 | + self.stop() |
| 68 | + return |
| 69 | + super().addError(test, err) |
| 70 | + |
| 71 | + |
| 72 | +class TimeoutTextTestRunner(unittest.TextTestRunner): |
| 73 | + resultclass = TimeoutTextTestResult |
| 74 | + |
| 75 | + |
| 76 | +unittest.TextTestRunner = TimeoutTextTestRunner |
| 77 | + |
33 | 78 |
|
34 | 79 | def _qt_message_handler(type: QtMsgType, context: QMessageLogContext, msg: str): |
35 | 80 | if type == QtMsgType.QtWarningMsg and msg in knownQtWarnings: |
@@ -145,6 +190,25 @@ def _init_with_trace(instance, *args, **kwargs): |
145 | 190 |
|
146 | 191 |
|
147 | 192 | class TestBase(unittest.TestCase): |
| 193 | + TEST_TIMEOUT_SECONDS = 30 |
| 194 | + |
| 195 | + def run(self, result=None): |
| 196 | + timeoutSeconds = self.TEST_TIMEOUT_SECONDS |
| 197 | + if timeoutSeconds <= 0: |
| 198 | + return super().run(result) |
| 199 | + |
| 200 | + watchdog = threading.Timer( |
| 201 | + timeoutSeconds, |
| 202 | + _triggerTestTimeout, |
| 203 | + args=(threading.main_thread().ident, self.id(), timeoutSeconds), |
| 204 | + ) |
| 205 | + watchdog.daemon = True |
| 206 | + watchdog.start() |
| 207 | + try: |
| 208 | + return super().run(result) |
| 209 | + finally: |
| 210 | + watchdog.cancel() |
| 211 | + |
148 | 212 | def setUp(self): |
149 | 213 | self.oldDir = Git.REPO_DIR or os.getcwd() |
150 | 214 | self.gitDir = None |
|
0 commit comments