diff --git a/sleap/io/convert.py b/sleap/io/convert.py index 1f900666d..d5c35109d 100644 --- a/sleap/io/convert.py +++ b/sleap/io/convert.py @@ -158,29 +158,43 @@ def main(args: list = None): outnames.append(dflt_name) if "csv" in args.format: - from sleap.info.write_tracking_h5 import main as write_analysis + import sleap_io as sio for video, output_path in zip(vids, outnames): - write_analysis( + # Check for labeled frames before exporting + labeled_frames = labels.find(video) + if not labeled_frames: + print(f"No labeled frames in {video.filename}. Skipping.") + continue + + sio.save_csv( labels, - output_path=output_path, - labels_path=args.input_path, - all_frames=True, + output_path, + format="sleap", video=video, - csv=True, + include_score=True, + include_empty=True, + save_metadata=True, ) else: - from sleap.info.write_tracking_h5 import main as write_analysis + import sleap_io as sio for video, output_path in zip(vids, outnames): - write_analysis( - labels, - output_path=output_path, - labels_path=args.input_path, - all_frames=True, - video=video, - ) + try: + sio.save_analysis_h5( + labels, + output_path, + video=video, + labels_path=args.input_path, + all_frames=True, + preset="matlab", + ) + except ValueError as e: + if "No labeled frames" in str(e): + print(f"No labeled frames in {video.filename}. Skipping.") + else: + raise elif len(args.outputs) > 0: print(f"Output SLEAP dataset: {args.outputs[0]}") diff --git a/sleap/io/format/csv.py b/sleap/io/format/csv.py index eb6b4de9a..38973a9ae 100644 --- a/sleap/io/format/csv.py +++ b/sleap/io/format/csv.py @@ -1,8 +1,13 @@ -"""Adaptor for writing SLEAP analysis as csv.""" +"""Adaptor for writing SLEAP analysis as CSV. + +This adaptor uses sleap-io for CSV export, providing a consistent format +with the analysis HDF5 export. +""" from sleap.io import format from sleap_io import Labels, Video +import sleap_io as sio class CSVAdaptor(format.adaptor.Adaptor): @@ -46,24 +51,34 @@ def write( source_path: str = None, video: Video = None, ): - """Writes csv file for :py:class:`Labels` `source_object`. + """Writes CSV file for :py:class:`Labels` `source_object`. Args: filename: The filename for the output file. source_object: The :py:class:`Labels` from which to get data from. - source_path: Path for the labels object - video: The :py:class:`Video` from which toget data from. If no `video` is + source_path: Path for the labels object (stored as metadata). + video: The :py:class:`Video` from which to get data from. If no `video` is specified, then the first video in `source_object` videos list will be - used. If there are no :py:class:`Labeled Frame`s in the `video`, + used. If there are no :py:class:`LabeledFrame`s in the `video`, then no analysis file will be written. """ - from sleap.info.write_tracking_h5 import main as write_analysis - - write_analysis( - labels=source_object, - output_path=filename, - labels_path=source_path, - all_frames=True, + # Resolve video + if video is None: + video = source_object.videos[0] if source_object.videos else None + + # Check for labeled frames before exporting (sleap-io may not raise error) + if video is not None: + labeled_frames = source_object.find(video) + if not labeled_frames: + print("No labeled frames in video. Skipping CSV export.") + return + + sio.save_csv( + source_object, + filename, + format="sleap", video=video, - csv=True, + include_score=True, + include_empty=True, # Include all frames from 0 to last labeled + save_metadata=True, ) diff --git a/sleap/io/format/sleap_analysis.py b/sleap/io/format/sleap_analysis.py index 510845ab1..be8378b95 100644 --- a/sleap/io/format/sleap_analysis.py +++ b/sleap/io/format/sleap_analysis.py @@ -1,23 +1,20 @@ -""" -Adaptor to read and write analysis HDF5 files. +"""Adaptor to read and write analysis HDF5 files. These contain location and track data, but lack other metadata included in a full SLEAP dataset file. -Note that this adaptor will use default track names and skeleton node names -if these cannot be read from the HDF5 (some files have these, some don't). +This adaptor uses sleap-io for both reading and writing analysis HDF5 files, +providing a consistent format with additional metadata like dimension labels +and skeleton symmetries. To determine whether this adaptor can read a file, we check it's an HDF5 file with a `track_occupancy` dataset. """ -import numpy as np - from typing import Union -from sleap_io import Labels, Video, LabeledFrame -from sleap_io import Skeleton -from sleap_io.model.instance import Track +from sleap_io import Labels, Video +import sleap_io as sio from .adaptor import Adaptor, SleapObjectType from .filehandle import FileHandle @@ -66,72 +63,18 @@ def read( *args, **kwargs, ) -> Labels: - connect_adj_nodes = False - - if video is None: - raise ValueError("Cannot read analysis hdf5 if no video specified.") - - if not isinstance(video, Video): - video = Video.from_filename(video) - - f = file.file - tracks_matrix = f["tracks"][:].T - - # shape: frames * nodes * 2 * tracks - frame_count, node_count, _, track_count = tracks_matrix.shape - - if "track_names" in f and len(f["track_names"]): - track_names_list = f["track_names"][:].T - tracks = [ - Track(name=track_name.decode()) for track_name in track_names_list - ] - else: - tracks = [Track(name=f"track_{i}") for i in range(track_count)] - - if "node_names" in f: - node_names_dset = f["node_names"][:].T - node_names = [name.decode() for name in node_names_dset] - else: - node_names = [f"node {i}" for i in range(node_count)] - - skeleton = Skeleton() - last_node_name = None - for node_name in node_names: - skeleton.add_node(node_name) - if connect_adj_nodes and last_node_name: - skeleton.add_edge(last_node_name, node_name) - last_node_name = node_name - - frames = [] - for frame_idx in range(frame_count): - instances = [] - for track_idx in range(track_count): - points = tracks_matrix[frame_idx, ..., track_idx] - if not np.all(np.isnan(points)): - point_scores = np.ones(len(points)) - # make everything a PredictedInstance since the usual use - # case is to export predictions for analysis - from sleap.sleap_io_adaptors.instance_utils import ( - predicted_instance_from_numpy_compat, - ) - - instances.append( - predicted_instance_from_numpy_compat( - points=points, - point_confidences=point_scores, - skeleton=skeleton, - track=tracks[track_idx], - instance_score=1, - ) - ) - if instances: - frames.append( - LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) - ) - - labels = Labels(labeled_frames=frames) - labels.update() - return labels + """Reads analysis HDF5 file using sleap-io. + + Args: + file: The file handle for the HDF5 file. + video: The video to associate with the data. Can be a Video object + or a path string. + + Returns: + Labels object with loaded pose data. + """ + # Use sleap-io's load_analysis_h5 which handles all format variants + return sio.load_analysis_h5(file.filename, video=video) @classmethod def write( @@ -146,17 +89,25 @@ def write( Args: filename: The filename for the output file. source_object: The :py:class:`Labels` from which to get data from. - video: The :py:class:`Video` from which toget data from. If no `video` is + source_path: Path to the source labels file (stored as metadata). + video: The :py:class:`Video` from which to get data from. If no `video` is specified, then the first video in `source_object` videos list will be - used. If there are no :py:class:`Labeled Frame`s in the `video`, + used. If there are no :py:class:`LabeledFrame`s in the `video`, then no analysis file will be written. """ - from sleap.info.write_tracking_h5 import main as write_analysis - - write_analysis( - labels=source_object, - output_path=filename, - labels_path=source_path, - all_frames=True, - video=video, - ) + try: + sio.save_analysis_h5( + source_object, + filename, + video=video, + labels_path=source_path, + all_frames=True, + preset="matlab", # SLEAP-compatible format + ) + except ValueError as e: + # Handle case where video has no labeled frames + # sleap-io raises ValueError, but we silently skip like old behavior + if "No labeled frames" in str(e): + print("No labeled frames in video. Skipping analysis export.") + else: + raise