diff --git a/transformerlab/shared/shared.py b/transformerlab/shared/shared.py
index 5dafcbcc5..e3dcf0058 100644
--- a/transformerlab/shared/shared.py
+++ b/transformerlab/shared/shared.py
@@ -172,7 +172,7 @@ async def async_run_python_script_and_update_status(python_script: list[str], jo
         raise asyncio.CancelledError()
 
 
-async def read_process_output(process, job_id):
+async def read_process_output(process, job_id, log_handle=None):
     await process.wait()
     returncode = process.returncode
     if returncode == 0:
@@ -181,9 +181,22 @@ async def read_process_output(process, job_id):
         print("Worker Process stopped by user")
     else:
         print(f"ERROR: Worker Process ended with exit code {returncode}.")
-        with open(get_global_log_path(), "a") as log:
-            log.write(f"Inference Server Terminated with {returncode}.\n")
-            log.flush()
+
+    # Close the log handle if one was passed (from async_run_python_daemon_and_update_status)
+    if log_handle:
+        try:
+            log_handle.close()
+        except Exception:
+            pass
+
+    # Wrap log write in try-except to handle errors gracefully during shutdown
+    try:
+        with open(get_global_log_path(), "a") as log:
+            log.write(f"Inference Server Terminated with {returncode}.\n")
+            log.flush()
+    except Exception:
+        # Silently ignore logging errors during shutdown to prevent error bursts
+        pass
 
     # so we should delete the pid file:
     from lab.dirs import get_temp_dir
@@ -217,79 +230,97 @@ async def async_run_python_daemon_and_update_status(
             break
 
     # Open a file to write the output to:
-    log = open(get_global_log_path(), "a")
+    # Use context manager to ensure proper cleanup, but we need to keep it open
+    # so we'll use a different approach - store the handle and close it later
+    log = None
+    try:
+        log = open(get_global_log_path(), "a")
 
-    # Check if plugin has a venv directory
-    if plugin_location:
-        plugin_location = os.path.normpath(plugin_location)
-        from lab.dirs import get_plugin_dir
+        # Check if plugin has a venv directory
+        if plugin_location:
+            plugin_location = os.path.normpath(plugin_location)
+            from lab.dirs import get_plugin_dir
+
+            plugin_dir_root = get_plugin_dir()
+            if not plugin_location.startswith(plugin_dir_root):
+                print(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
+                raise Exception(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
+            if os.path.exists(os.path.join(plugin_location, "venv")) and os.path.isdir(
+                os.path.join(plugin_location, "venv")
+            ):
+                venv_path = os.path.join(plugin_location, "venv")
+                print(f">Plugin has virtual environment, activating venv from {venv_path}")
+                venv_python = os.path.join(venv_path, "bin", "python")
+                command = [venv_python, *python_script]
+            else:
+                print(">Using system Python interpreter")
+                command = [sys.executable, *python_script]
 
-        plugin_dir_root = get_plugin_dir()
-        if not plugin_location.startswith(plugin_dir_root):
-            print(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
-            raise Exception(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
-        if os.path.exists(os.path.join(plugin_location, "venv")) and os.path.isdir(
-            os.path.join(plugin_location, "venv")
-        ):
-            venv_path = os.path.join(plugin_location, "venv")
-            print(f">Plugin has virtual environment, activating venv from {venv_path}")
-            venv_python = os.path.join(venv_path, "bin", "python")
-            command = [venv_python, *python_script]
         else:
             print(">Using system Python interpreter")
-            command = [sys.executable, *python_script]
-
-    else:
-        print(">Using system Python interpreter")
-        command = [sys.executable, *python_script]  # Skip the original Python interpreter
-
-    process = await asyncio.create_subprocess_exec(
-        *command, stdin=None, stderr=subprocess.STDOUT, stdout=subprocess.PIPE
-    )
-
-    pid = process.pid
-    from lab.dirs import get_temp_dir
+            command = [sys.executable, *python_script]  # Skip the original Python interpreter
 
-    pid_file = os.path.join(get_temp_dir(), f"worker_job_{job_id}.pid")
-    with open(pid_file, "w") as f:
-        f.write(str(pid))
+        process = await asyncio.create_subprocess_exec(
+            *command, stdin=None, stderr=subprocess.STDOUT, stdout=subprocess.PIPE
+        )
 
-    # keep a tail of recent lines so we can show them on failure
-    recent_lines = deque(maxlen=10)
+        pid = process.pid
+        from lab.dirs import get_temp_dir
 
-    line = await process.stdout.readline()
-    error_msg = None
-    while line:
-        decoded = line.decode()
+        pid_file = os.path.join(get_temp_dir(), f"worker_job_{job_id}.pid")
+        with open(pid_file, "w") as f:
+            f.write(str(pid))
 
-        recent_lines.append(decoded.strip())
+        # keep a tail of recent lines so we can show them on failure:
+        recent_lines = deque(maxlen=10)
 
-        # If we hit the begin_string then the daemon is started and we can return!
-        if begin_string in decoded:
-            if set_process_id_function is not None:
-                if set_process_id_function:
-                    set_process_id_function(process)
-            print(f"Worker job {job_id} started successfully")
-            job = job_service.job_get(job_id)
-            experiment_id = job["experiment_id"]
-            await job_update_status(job_id=job_id, status="COMPLETE", experiment_id=experiment_id)
+        line = await process.stdout.readline()
+        error_msg = None
+        while line:
+            decoded = line.decode()
+            recent_lines.append(decoded.strip())
+
+            # If we hit the begin_string then the daemon is started and we can return!
+            if begin_string in decoded:
+                if set_process_id_function is not None:
+                    if set_process_id_function:
+                        set_process_id_function(process)
+                print(f"Worker job {job_id} started successfully")
+                job = job_service.job_get(job_id)
+                experiment_id = job["experiment_id"]
+                await job_update_status(job_id=job_id, status="COMPLETE", experiment_id=experiment_id)
 
-            # Schedule the read_process_output coroutine in the current event
-            # so we can keep watching this process, but return back to the caller
-            # so that the REST call can complete
-            asyncio.create_task(read_process_output(process, job_id))
+                # Schedule the read_process_output coroutine in the current event
+                # so we can keep watching this process, but return back to the caller
+                # so that the REST call can complete
+                # Pass the log handle to read_process_output so it can close it
+                # Set log to None so the finally block doesn't close it
+                log_handle_to_pass = log
+                log = None
+                asyncio.create_task(read_process_output(process, job_id, log_handle_to_pass))
 
-            return process
+                return process
 
-        # Watch the output for any errors and store the latest error
-        elif ("stderr" in decoded) and ("ERROR" in decoded):
-            error_msg = decoded.split("| ")[-1]
+            # Watch the output for any errors and store the latest error
+            elif ("stderr" in decoded) and ("ERROR" in decoded):
+                error_msg = decoded.split("| ")[-1]
+            # Wrap log write in try-except to handle errors gracefully during shutdown
+            if log:
+                try:
+                    log.write(decoded)
+                    log.flush()
+                except Exception:
+                    # Silently ignore logging errors during shutdown
+                    pass
+            line = await process.stdout.readline()
+    finally:
+        # Ensure log file is closed even if there's an error
         if log:
-            log.write(decoded)
-            log.flush()
-        log.flush()
-        line = await process.stdout.readline()
+            try:
+                log.close()
+            except Exception:
+                pass
 
     # If we're here then stdout didn't return and we didn't start the daemon
     # Wait on the process and return the error