| 
1 | 1 | """RPLogger class for low-level logging in tests."""  | 
2 | 2 | 
 
  | 
3 | 3 | import logging  | 
 | 4 | +import threading  | 
4 | 5 | from contextlib import contextmanager  | 
5 | 6 | from functools import wraps  | 
 | 7 | + | 
 | 8 | +from reportportal_client.client import RPClient  | 
 | 9 | + | 
 | 10 | +from reportportal_client._local import current, set_current  | 
6 | 11 | from reportportal_client import RPLogger  | 
7 | 12 | 
 
  | 
8 | 13 | 
 
  | 
 | 14 | +@contextmanager  | 
 | 15 | +def patching_thread_class(config):  | 
 | 16 | +    """  | 
 | 17 | +    Add patch for Thread class.  | 
 | 18 | +
  | 
 | 19 | +    Set the parent thread client as the child thread's local client  | 
 | 20 | +    """  | 
 | 21 | +    if not config.rp_thread_logging:  | 
 | 22 | +        # Do nothing  | 
 | 23 | +        yield  | 
 | 24 | +    else:  | 
 | 25 | +        original_start = threading.Thread.start  | 
 | 26 | +        original_run = threading.Thread.run  | 
 | 27 | +        try:  | 
 | 28 | +            def wrap_start(original_func):  | 
 | 29 | +                @wraps(original_func)  | 
 | 30 | +                def _start(self, *args, **kwargs):  | 
 | 31 | +                    """Save the invoking thread's client if there is one."""  | 
 | 32 | +                    # Prevent an endless loop of workers being spawned  | 
 | 33 | +                    if "_monitor" not in self.name:  | 
 | 34 | +                        current_client = current()  | 
 | 35 | +                        self.parent_rp_client = current_client  | 
 | 36 | +                    return original_func(self, *args, **kwargs)  | 
 | 37 | + | 
 | 38 | +                return _start  | 
 | 39 | + | 
 | 40 | +            def wrap_run(original_func):  | 
 | 41 | +                @wraps(original_func)  | 
 | 42 | +                def _run(self, *args, **kwargs):  | 
 | 43 | +                    """Create a new client for the invoked thread."""  | 
 | 44 | +                    client = None  | 
 | 45 | +                    if (  | 
 | 46 | +                        hasattr(self, "parent_rp_client")  | 
 | 47 | +                        and self.parent_rp_client  | 
 | 48 | +                        and not current()  | 
 | 49 | +                    ):  | 
 | 50 | +                        parent = self.parent_rp_client  | 
 | 51 | +                        client = RPClient(  | 
 | 52 | +                            endpoint=parent.endpoint,  | 
 | 53 | +                            project=parent.project,  | 
 | 54 | +                            token=parent.token,  | 
 | 55 | +                            log_batch_size=parent.log_batch_size,  | 
 | 56 | +                            is_skipped_an_issue=parent.is_skipped_an_issue,  | 
 | 57 | +                            verify_ssl=parent.verify_ssl,  | 
 | 58 | +                            retries=config.rp_retries,  | 
 | 59 | +                            launch_id=parent.launch_id  | 
 | 60 | +                        )  | 
 | 61 | +                        if parent.current_item():  | 
 | 62 | +                            client._item_stack.append(  | 
 | 63 | +                                parent.current_item()  | 
 | 64 | +                            )  | 
 | 65 | +                        client.start()  | 
 | 66 | +                    try:  | 
 | 67 | +                        return original_func(self, *args, **kwargs)  | 
 | 68 | +                    finally:  | 
 | 69 | +                        if client:  | 
 | 70 | +                            # Stop the client and remove any references  | 
 | 71 | +                            client.terminate()  | 
 | 72 | +                            self.parent_rp_client = None  | 
 | 73 | +                            del self.parent_rp_client  | 
 | 74 | +                            set_current(None)  | 
 | 75 | + | 
 | 76 | +                return _run  | 
 | 77 | + | 
 | 78 | +            if not hasattr(threading.Thread, "patched"):  | 
 | 79 | +                # patch  | 
 | 80 | +                threading.Thread.patched = True  | 
 | 81 | +                threading.Thread.start = wrap_start(original_start)  | 
 | 82 | +                threading.Thread.run = wrap_run(original_run)  | 
 | 83 | +            yield  | 
 | 84 | + | 
 | 85 | +        finally:  | 
 | 86 | +            if hasattr(threading.Thread, "patched"):  | 
 | 87 | +                threading.Thread.start = original_start  | 
 | 88 | +                threading.Thread.run = original_run  | 
 | 89 | +                del threading.Thread.patched  | 
 | 90 | + | 
 | 91 | + | 
9 | 92 | @contextmanager  | 
10 | 93 | def patching_logger_class():  | 
11 | 94 |     """  | 
 | 
0 commit comments