Skip to content

Commit 1cfe3ce

Browse files
committed
fix: correct unzip directory & conditionally send training logs
1 parent b4aed8a commit 1cfe3ce

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

sleap_RTC/worker/worker_class.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ async def run_all_training_jobs(self, channel: RTCDataChannel, train_script_path
135135
"trainer_config.ckpt_dir=models",
136136
f"trainer_config.run_name={job_name}",
137137
"trainer_config.zmq.controller_port=9000",
138-
"trainer_config.zmq.publish_port=9001"
138+
"trainer_config.zmq.publish_port=9001",
139139
]
140-
logging.info(f"[RUNNING] {' '.join(cmd)} (cwd={self.zip_dir})")
140+
logging.info(f"[RUNNING] {' '.join(cmd)} (cwd={self.unzipped_dir})")
141141

142142
process = await asyncio.create_subprocess_exec(
143143
*cmd,
@@ -281,7 +281,7 @@ async def unzip_results(self, file_path: str):
281281
try:
282282
shutil.unpack_archive(file_path, self.save_dir)
283283
logging.info(f"Results unzipped from {file_path} to {self.save_dir}")
284-
self.unzipped_dir = f"{self.save_dir}/{file_path.split(".")[0]}"
284+
self.unzipped_dir = f"{self.save_dir}/{file_path[:-4]}" # remove .zip extension
285285
logging.info(f"Unzipped contents to {self.unzipped_dir}")
286286
except Exception as e:
287287
logging.error(f"Error unzipping results: {e}")
@@ -591,15 +591,16 @@ async def on_message(message):
591591
# Start ZMQ progress listener.
592592
# Don't need to send ZMQ progress reports if User just using CLI sleap-rtc.
593593
# (Will print sleap-nn train logs directly to terminal instead.)
594-
progress_listener_task = asyncio.create_task(self.start_progress_listener(channel))
595-
logging.info(f'{channel.label} progress listener started')
596-
597-
# Start ZMQ control socket.
598-
self.start_zmq_control()
599-
logging.info(f'{channel.label} ZMQ control socket started')
600-
601-
# Give SUB socket time to connect.
602-
await asyncio.sleep(1)
594+
if self.gui:
595+
progress_listener_task = asyncio.create_task(self.start_progress_listener(channel))
596+
logging.info(f'{channel.label} progress listener started')
597+
598+
# Start ZMQ control socket.
599+
self.start_zmq_control()
600+
logging.info(f'{channel.label} ZMQ control socket started')
601+
602+
# Give SUB socket time to connect.
603+
await asyncio.sleep(1)
603604

604605
logging.info(f"Running training script: {train_script_path}")
605606

0 commit comments

Comments
 (0)