transformerlab/shared/shared.py: 157 changes (94 additions & 63 deletions)
@@ -172,7 +172,7 @@ async def async_run_python_script_and_update_status(python_script: list[str], jo
         raise asyncio.CancelledError()


-async def read_process_output(process, job_id):
+async def read_process_output(process, job_id, log_handle=None):
     await process.wait()
     returncode = process.returncode
     if returncode == 0:
@@ -181,9 +181,22 @@ async def read_process_output(process, job_id):
         print("Worker Process stopped by user")
     else:
         print(f"ERROR: Worker Process ended with exit code {returncode}.")
-    with open(get_global_log_path(), "a") as log:
-        log.write(f"Inference Server Terminated with {returncode}.\n")
-        log.flush()
+
+    # Close the log handle if one was passed (from async_run_python_daemon_and_update_status)
+    if log_handle:
+        try:
+            log_handle.close()
+        except Exception:
+            pass
+
+    # Wrap log write in try-except to handle errors gracefully during shutdown
+    try:
+        with open(get_global_log_path(), "a") as log:
+            log.write(f"Inference Server Terminated with {returncode}.\n")
+            log.flush()
+    except Exception:
+        # Silently ignore logging errors during shutdown to prevent error bursts
+        pass
     # so we should delete the pid file:
     from lab.dirs import get_temp_dir
 
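The hunk above moves responsibility for closing the inference-server log into the watcher coroutine itself: `read_process_output` now accepts an optional `log_handle`, closes it once the worker exits, and tolerates logging failures during interpreter shutdown. A minimal, self-contained sketch of that pattern is below; the names `watch_process` and `spawn_and_watch`, the `worker.log` path, and the demo command are illustrative assumptions, not code from this PR.

```python
import asyncio
import sys


async def watch_process(process, log_handle=None):
    # Wait for the subprocess to finish and report how it ended.
    returncode = await process.wait()
    print(f"worker exited with code {returncode}")

    # Close the handle the launcher handed off to us; swallow errors because
    # the event loop (or the whole interpreter) may already be shutting down.
    if log_handle:
        try:
            log_handle.close()
        except Exception:
            pass


async def spawn_and_watch():
    log = open("worker.log", "a")  # ownership of this handle moves to the watcher
    process = await asyncio.create_subprocess_exec(
        sys.executable, "-c", "print('hello from worker')",
        stdout=log, stderr=asyncio.subprocess.STDOUT,
    )
    # Schedule the watcher and return immediately; the watcher closes the log.
    watcher = asyncio.create_task(watch_process(process, log))
    await watcher  # only for the demo; the real caller returns to the REST handler


if __name__ == "__main__":
    asyncio.run(spawn_and_watch())
```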
@@ -217,79 +230,97 @@ async def async_run_python_daemon_and_update_status(
             break

     # Open a file to write the output to:
-    log = open(get_global_log_path(), "a")
+    # Use context manager to ensure proper cleanup, but we need to keep it open
+    # so we'll use a different approach - store the handle and close it later
+    log = None
+    try:
+        log = open(get_global_log_path(), "a")

-    # Check if plugin has a venv directory
-    if plugin_location:
-        plugin_location = os.path.normpath(plugin_location)
-        from lab.dirs import get_plugin_dir
+        # Check if plugin has a venv directory
+        if plugin_location:
+            plugin_location = os.path.normpath(plugin_location)
+            from lab.dirs import get_plugin_dir

-        plugin_dir_root = get_plugin_dir()
-        if not plugin_location.startswith(plugin_dir_root):
-            print(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
-            raise Exception(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
-        if os.path.exists(os.path.join(plugin_location, "venv")) and os.path.isdir(
-            os.path.join(plugin_location, "venv")
-        ):
-            venv_path = os.path.join(plugin_location, "venv")
-            print(f">Plugin has virtual environment, activating venv from {venv_path}")
-            venv_python = os.path.join(venv_path, "bin", "python")
-            command = [venv_python, *python_script]
-        else:
-            print(">Using system Python interpreter")
-            command = [sys.executable, *python_script]
+            plugin_dir_root = get_plugin_dir()
+            if not plugin_location.startswith(plugin_dir_root):
+                print(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
+                raise Exception(f"Plugin location {plugin_location} is not in {plugin_dir_root}")
+            if os.path.exists(os.path.join(plugin_location, "venv")) and os.path.isdir(
+                os.path.join(plugin_location, "venv")
+            ):
+                venv_path = os.path.join(plugin_location, "venv")
+                print(f">Plugin has virtual environment, activating venv from {venv_path}")
+                venv_python = os.path.join(venv_path, "bin", "python")
+                command = [venv_python, *python_script]
+            else:
+                print(">Using system Python interpreter")
+                command = [sys.executable, *python_script]

-    else:
-        print(">Using system Python interpreter")
-        command = [sys.executable, *python_script]  # Skip the original Python interpreter
+        else:
+            print(">Using system Python interpreter")
+            command = [sys.executable, *python_script]  # Skip the original Python interpreter

-    process = await asyncio.create_subprocess_exec(
-        *command, stdin=None, stderr=subprocess.STDOUT, stdout=subprocess.PIPE
-    )
+        process = await asyncio.create_subprocess_exec(
+            *command, stdin=None, stderr=subprocess.STDOUT, stdout=subprocess.PIPE
+        )

-    pid = process.pid
-    from lab.dirs import get_temp_dir
+        pid = process.pid
+        from lab.dirs import get_temp_dir

-    pid_file = os.path.join(get_temp_dir(), f"worker_job_{job_id}.pid")
-    with open(pid_file, "w") as f:
-        f.write(str(pid))
+        pid_file = os.path.join(get_temp_dir(), f"worker_job_{job_id}.pid")
+        with open(pid_file, "w") as f:
+            f.write(str(pid))

-    # keep a tail of recent lines so we can show them on failure
-    recent_lines = deque(maxlen=10)
+        # keep a tail of recent lines so we can show them on failure:
+        recent_lines = deque(maxlen=10)

-    line = await process.stdout.readline()
-    error_msg = None
-    while line:
-        decoded = line.decode()
-
-        recent_lines.append(decoded.strip())
+        line = await process.stdout.readline()
+        error_msg = None
+        while line:
+            decoded = line.decode()
+            recent_lines.append(decoded.strip())

-        # If we hit the begin_string then the daemon is started and we can return!
-        if begin_string in decoded:
-            if set_process_id_function is not None:
-                set_process_id_function(process)
-            print(f"Worker job {job_id} started successfully")
-            job = job_service.job_get(job_id)
-            experiment_id = job["experiment_id"]
-            await job_update_status(job_id=job_id, status="COMPLETE", experiment_id=experiment_id)
+            # If we hit the begin_string then the daemon is started and we can return!
+            if begin_string in decoded:
+                if set_process_id_function:
+                    set_process_id_function(process)
+                print(f"Worker job {job_id} started successfully")
+                job = job_service.job_get(job_id)
+                experiment_id = job["experiment_id"]
+                await job_update_status(job_id=job_id, status="COMPLETE", experiment_id=experiment_id)

-            # Schedule the read_process_output coroutine in the current event
-            # so we can keep watching this process, but return back to the caller
-            # so that the REST call can complete
-            asyncio.create_task(read_process_output(process, job_id))
+                # Schedule the read_process_output coroutine in the current event
+                # so we can keep watching this process, but return back to the caller
+                # so that the REST call can complete
+                # Pass the log handle to read_process_output so it can close it
+                # Set log to None so the finally block doesn't close it
+                log_handle_to_pass = log
+                log = None
+                asyncio.create_task(read_process_output(process, job_id, log_handle_to_pass))

-            return process
+                return process

-        # Watch the output for any errors and store the latest error
-        elif ("stderr" in decoded) and ("ERROR" in decoded):
-            error_msg = decoded.split("| ")[-1]
+            # Watch the output for any errors and store the latest error
+            elif ("stderr" in decoded) and ("ERROR" in decoded):
+                error_msg = decoded.split("| ")[-1]

-        log.write(decoded)
-        log.flush()
-        line = await process.stdout.readline()
+            # Wrap log write in try-except to handle errors gracefully during shutdown
+            if log:
+                try:
+                    log.write(decoded)
+                    log.flush()
+                except Exception:
+                    # Silently ignore logging errors during shutdown
+                    pass
+            line = await process.stdout.readline()
+    finally:
+        # Ensure log file is closed even if there's an error
+        if log:
+            try:
+                log.close()
+            except Exception:
+                pass

     # If we're here then stdout didn't return and we didn't start the daemon
     # Wait on the process and return the error
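Taken together, this hunk implements a hand-off of resource ownership: the log is opened inside a try block so the finally clause can close it on every failure path, and only when the daemon prints its readiness string is the local variable cleared and the handle passed to the background watcher, which becomes responsible for closing it. The sketch below shows just that control flow under assumed names (`launch_daemon`, `watch_process`, and `READY` in place of the PR's `begin_string`); it is not the project's actual implementation.

```python
import asyncio

READY = "READY"  # stand-in for the begin_string the daemon prints when it is up


async def watch_process(process, log_handle=None):
    await process.wait()
    if log_handle:
        try:
            log_handle.close()  # the watcher now owns the handle
        except Exception:
            pass


async def launch_daemon(command: list[str], log_path: str):
    log = None
    try:
        log = open(log_path, "a")
        process = await asyncio.create_subprocess_exec(
            *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
        )
        line = await process.stdout.readline()
        while line:
            decoded = line.decode()
            if READY in decoded:
                # Success path: transfer ownership of the handle to the watcher
                # and clear the local so the finally block below leaves it open.
                handle, log = log, None
                asyncio.create_task(watch_process(process, handle))
                return process
            if log:
                try:
                    log.write(decoded)
                    log.flush()
                except Exception:
                    pass  # ignore logging errors during shutdown
            line = await process.stdout.readline()
        raise RuntimeError("daemon exited before printing its readiness string")
    finally:
        # Every other path still owns the handle here and must close it.
        if log:
            try:
                log.close()
            except Exception:
                pass
```

A caller would then do something like `process = await launch_daemon([sys.executable, "-u", "my_worker.py"], "worker.log")`, where `my_worker.py` is assumed to print `READY` once it is serving.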