diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index d048962..fded83c 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -1,6 +1,6 @@
 {
-    "name": "vortex-image-segmentation-devcontainer",
-    "image": "vortex-image-segmentation:latest",
+    "name": "vortex-deep-learning-pipelines-devcontainer",
+    "image": "vortex-deep-learning-pipelines:latest",
     "customizations": {
         "vscode": {
             "settings": {
diff --git a/.dockerignore b/.dockerignore
index ab2a35e..0aae7b1 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -27,4 +27,4 @@ venv/
 *.swp
 *.swo
 node_modules/
-bags/
+rosbags/
diff --git a/.gitignore b/.gitignore
index 452c82c..efeac54 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,18 @@ share/python-wheels/
 *.egg
 MANIFEST
 
+# VSCode
+.vscode/
+
+# OS generated files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
 # PyInstaller
 # Usually these files are written by a python script from a template
 # before PyInstaller builds the exe, so as to inject date/other infos into it.
@@ -173,7 +185,7 @@ build/
 log/
 
 # data
-data/
+rosbags/
 
 # Training outputs and Roboflow datasets
 results/
diff --git a/requirements.txt b/requirements.txt
index 298b8ab..747fa7e 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/vortex_image_segmentation/LICENSE b/vortex_image_segmentation/LICENSE
new file mode 100644
index 0000000..30e8e2e
--- /dev/null
+++ b/vortex_image_segmentation/LICENSE
@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/vortex_image_segmentation/launch/yolo_segmentation.launch.py b/vortex_image_segmentation/launch/yolo_segmentation.launch.py
new file mode 100644
index 0000000..3935f7a
--- /dev/null
+++ b/vortex_image_segmentation/launch/yolo_segmentation.launch.py
@@ -0,0 +1,19 @@
+from launch import LaunchDescription
+from launch_ros.actions import Node
+from ament_index_python.packages import get_package_share_directory
+import os
+
+
+def generate_launch_description():
+    pkg_share = get_package_share_directory('vortex_image_segmentation')
+    yolo_params = os.path.join(pkg_share, 'params', 'yolo_params.yaml')
+
+    return LaunchDescription([
+        Node(
+            package='vortex_image_segmentation',
+            executable='yolo_seg_node',
+            name='yolo_segmentation_node',
+            output='screen',
+            parameters=[yolo_params]
+        )
+    ])
diff --git a/vortex_image_segmentation/package.xml b/vortex_image_segmentation/package.xml
new file mode 100644
index 0000000..3b75ff8
--- /dev/null
+++ b/vortex_image_segmentation/package.xml
@@ -0,0 +1,21 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>vortex_image_segmentation</name>
+  <version>0.0.0</version>
+  <description>Package for image segmentation nodes</description>
+  <maintainer email="mjengesv@ntnu.no">Mads Engesvoll</maintainer>
+  <license>MIT</license>
+
+  <depend>rclpy</depend>
+  <depend>std_msgs</depend>
+
+  <test_depend>ament_copyright</test_depend>
+  <test_depend>ament_flake8</test_depend>
+  <test_depend>ament_pep257</test_depend>
+  <test_depend>python3-pytest</test_depend>
+
+  <export>
+    <build_type>ament_python</build_type>
+  </export>
+</package>
diff --git a/vortex_image_segmentation/params/yolo_params.yaml b/vortex_image_segmentation/params/yolo_params.yaml
new file mode 100644
index 0000000..40bf69e
--- /dev/null
+++ b/vortex_image_segmentation/params/yolo_params.yaml
@@ -0,0 +1,19 @@
+yolo_segmentation_node:
+  ros__parameters:
+    # Node parameters
+    input_topic: "/gripper_camera/image_raw"
+    output_bbox_topic: "/segmentation/bboxes"
+    output_mask_topic: "/segmentation/mask"
+    debug_topic: "/image_debug"
+    pub_bbox: True
+    pub_mask: True
+    pub_debug: True
+
+    # Implementation parameters
+    model_path: "/ros2_ws/yolo_segmentation_training/results/custom_yolov89/weights/last.pt"
+    device: "cpu"
+    imgsz: 640
+    confidence_threshold: 0.3
+    max_detections: 1
+    compile: True
+    retina_masks: False  # We don't need high-res masks
diff --git a/vortex_image_segmentation/resource/vortex_image_segmentation b/vortex_image_segmentation/resource/vortex_image_segmentation
new file mode 100644
index 0000000..e69de29
diff --git a/vortex_image_segmentation/setup.cfg b/vortex_image_segmentation/setup.cfg
new file mode 100644
index 0000000..9c05dd7
--- /dev/null
+++ b/vortex_image_segmentation/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/vortex_image_segmentation
+[install]
+install_scripts=$base/lib/vortex_image_segmentation
diff --git a/vortex_image_segmentation/setup.py b/vortex_image_segmentation/setup.py
new file mode 100644
index 0000000..b635c18
--- /dev/null
+++ b/vortex_image_segmentation/setup.py
@@ -0,0 +1,30 @@
+
+from setuptools import find_packages, setup
+import glob
+
+package_name = 'vortex_image_segmentation'
+
+setup(
+    name=package_name,
+    version='0.0.0',
+    packages=find_packages(exclude=['test']),
+    data_files=[
+        ('share/ament_index/resource_index/packages',
+            ['resource/' + package_name]),
+        ('share/' + package_name, ['package.xml']),
+        ('share/' + package_name + '/launch', glob.glob('launch/*.launch.py')),
+        ('share/' + package_name + '/params', glob.glob('params/*.yaml')),
+    ],
+    install_requires=['setuptools'],
+    zip_safe=True,
+    maintainer='mads',
+    maintainer_email='mjengesv@ntnu.no',
+    description='ROS 2 package that provides a YOLO-based instance segmentation node (yolo_seg_node) for real-time segmentation.',
+    license='MIT',
+    tests_require=['pytest'],
+    entry_points={
+        'console_scripts': [
+            'yolo_seg_node = vortex_image_segmentation.yolo_seg_node:main',
+        ],
+    },
+)
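Note: the parameters in yolo_params.yaml are nested under the node name and ros__parameters, and the top-level key must match the name given to the Node action in the launch file or none of the parameters will be applied. A minimal sketch for sanity-checking the file outside ROS, assuming PyYAML is installed and the script is run from the repository root:

import yaml

# Top-level key must match the launch file's node name.
with open("vortex_image_segmentation/params/yolo_params.yaml") as f:
    params = yaml.safe_load(f)["yolo_segmentation_node"]["ros__parameters"]

print(params["model_path"], params["confidence_threshold"])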
diff --git a/vortex_image_segmentation/vortex_image_segmentation/__init__.py b/vortex_image_segmentation/vortex_image_segmentation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vortex_image_segmentation/vortex_image_segmentation/yolo_seg.py b/vortex_image_segmentation/vortex_image_segmentation/yolo_seg.py
new file mode 100644
index 0000000..dffc4aa
--- /dev/null
+++ b/vortex_image_segmentation/vortex_image_segmentation/yolo_seg.py
@@ -0,0 +1,78 @@
+"""
+YOLO segmentation model wrapper for inference and visualization using Ultralytics.
+Defines parameter dataclass and main segmentation class.
+"""
+
+from dataclasses import dataclass
+from typing import List
+
+import numpy as np
+import torch
+from ultralytics import YOLO
+from ultralytics.engine.results import Results
+
+
+@dataclass
+class YoloSegmentationParams:
+    """
+    Dataclass for storing YOLO segmentation parameters.
+    """
+    model_path: str
+    device: str
+    confidence_threshold: float
+    max_detections: int
+    imgsz: int
+    compile: bool
+    retina_masks: bool
+
+
+class YoloSegmentation:
+    """
+    Wrapper class for YOLO segmentation model inference and visualization.
+    """
+
+    def __init__(self, params: YoloSegmentationParams) -> None:
+        """
+        Initialize the YOLO segmentation model with given parameters.
+        Args:
+            params (YoloSegmentationParams): Parameters for YOLO segmentation.
+        """
+        self.params: YoloSegmentationParams = params
+        self.model: YOLO = self._load_model()
+
+    def _load_model(self) -> YOLO:
+        """
+        Load the YOLO segmentation model.
+        Returns:
+            YOLO: Ultralytics YOLO model instance.
+        """
+        return YOLO(self.params.model_path, task="segment")
+
+    def predict(self, cv_image: np.ndarray) -> List[Results]:
+        """
+        Run prediction on an input image and return the Results object(s).
+        Args:
+            cv_image (np.ndarray): Input image in OpenCV format.
+        Returns:
+            List[Results]: Ultralytics Results object(s) for the prediction.
+        """
+        results = self.model.predict(
+            source=cv_image,
+            imgsz=self.params.imgsz,
+            conf=self.params.confidence_threshold,
+            device=torch.device(self.params.device),
+            max_det=self.params.max_detections,
+            retina_masks=self.params.retina_masks,
+            # compile=self.params.compile,
+        )
+        return results
+
+    def visualize(self, result: Results) -> np.ndarray:
+        """
+        Generate a visualization image from the Results object.
+        Args:
+            result (Results): Ultralytics Results object from prediction.
+        Returns:
+            np.ndarray: Visualization image with segmentation overlays.
+        """
+        return result.plot()
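Note: the wrapper is plain Python and can be exercised without ROS. A minimal sketch, using the model path from yolo_params.yaml and a hypothetical local test image:

import cv2

from vortex_image_segmentation.yolo_seg import (
    YoloSegmentation,
    YoloSegmentationParams,
)

params = YoloSegmentationParams(
    model_path="/ros2_ws/yolo_segmentation_training/results/custom_yolov89/weights/last.pt",
    device="cpu",
    confidence_threshold=0.3,
    max_detections=1,
    imgsz=640,
    compile=True,
    retina_masks=False,
)
seg = YoloSegmentation(params)

image = cv2.imread("test.png")       # hypothetical test image
results = seg.predict(image)         # list with one Results object
overlay = seg.visualize(results[0])  # BGR image with mask/box overlays
cv2.imwrite("overlay.png", overlay)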
+""" + +import rclpy +from dataclasses import dataclass +from typing import List, Optional +import numpy as np +from cv_bridge import CvBridge +from rclpy.node import Node +from rclpy.parameter import Parameter +from rclpy.qos import QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy +from rclpy.publisher import Publisher +from rclpy.subscription import Subscription +from sensor_msgs.msg import Image +from std_msgs.msg import Header +from ultralytics.engine.results import Results +from vision_msgs.msg import ( + BoundingBox2D, + Detection2D, + Detection2DArray, + ObjectHypothesisWithPose, +) + +from .yolo_seg import YoloSegmentation, YoloSegmentationParams + + +@dataclass +class YoloNodeParams: + input_topic: str + output_bbox_topic: str + output_mask_topic: str + debug_topic: str + pub_bbox: bool + pub_mask: bool + pub_debug: bool + + +class YoloSegmentationNode(Node): + """ + ROS2 node for running YOLO segmentation and publishing results. + Subscribes to an input image topic, runs segmentation, and publishes output images, masks, and confidences. + """ + + def __init__(self) -> None: + """ + Initialize the YoloSegmentationNode, set up publishers, subscribers, and segmentation model. + """ + super().__init__("yolo_segmentation_node") + node_params = self.load_node_params() + self.input_topic: str = node_params.input_topic + self.output_bbox_topic: str = node_params.output_bbox_topic + self.output_mask_topic: str = node_params.output_mask_topic + self.debug_topic: str = node_params.debug_topic + self.pub_bbox: bool = node_params.pub_bbox + self.pub_mask: bool = node_params.pub_mask + self.pub_debug: bool = node_params.pub_debug + + self.bridge: CvBridge = CvBridge() + qos_profile = QoSProfile( + reliability=QoSReliabilityPolicy.BEST_EFFORT, + history=QoSHistoryPolicy.KEEP_LAST, + depth=1, + ) + self.subscription: Subscription = self.create_subscription( + Image, self.input_topic, self.image_callback, qos_profile + ) + self.debug_publisher: Optional[Publisher] = None + self.bbox_pub: Optional[Publisher] = None + self.mask_pub: Optional[Publisher] = None + + if self.pub_debug: + self.debug_publisher = self.create_publisher( + Image, self.debug_topic, qos_profile + ) + if self.pub_bbox: + self.bbox_pub = self.create_publisher( + Detection2DArray, self.output_bbox_topic, qos_profile + ) + if self.pub_mask: + self.mask_pub = self.create_publisher( + Image, self.output_mask_topic, qos_profile + ) + + self.params: YoloSegmentationParams = self.load_params() + self.get_logger().info( + f"Loading YOLO model from: {self.params.model_path}" + ) + self.segmentation: YoloSegmentation = YoloSegmentation(self.params) + self.get_logger().info( + f"Node initialized. Subscribing to '{self.input_topic}'" + ) + + def image_callback(self, msg: Image) -> None: + """ + Callback for incoming images. Runs segmentation, publishes results and debug images. + Args: + msg (sensor_msgs.msg.Image): Input image message. + """ + cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8") + results: List[Results] = self.segmentation.predict(cv_image) + # When passing in one image to predict, we always want the first result from the list + result: Results = results[0] + + self.publish_bboxes_and_confidences(result, msg.header) + self.publish_masks(result, msg.header) + self.publish_debug_image(result, msg.header) + + def publish_bboxes_and_confidences(self, result: Results, header: Header) -> None: + """ + Publish bounding boxes and confidences as Detection2DArray. 
+ """ + if getattr(self, "bbox_pub", None) is None: + return + det_array = Detection2DArray() + det_array.header = header + boxes = result.boxes.xyxy.cpu().numpy() + confs = result.boxes.conf.cpu().numpy() + clss = result.boxes.cls.cpu().numpy() + for (x1, y1, x2, y2), conf, cls_id in zip(boxes, confs, clss): + det = Detection2D() + det.header = header + + hyp = ObjectHypothesisWithPose() + hyp.hypothesis.class_id = str(int(cls_id)) + hyp.hypothesis.score = float(conf) + det.results.append(hyp) + + bbox = BoundingBox2D() + bbox.center.position.x = float((x1 + x2) / 2.0) + bbox.center.position.y = float((y1 + y2) / 2.0) + bbox.center.theta = 0.0 + bbox.size_x = float(x2 - x1) + bbox.size_y = float(y2 - y1) + det.bbox = bbox + det_array.detections.append(det) + + self.bbox_pub.publish(det_array) + + def publish_masks(self, result: Results, header: Header) -> None: + """ + Publish segmentation masks as mono8 Image messages. + """ + if getattr(self, "mask_pub", None) is None: + return + if result.masks is None: + return + + masks: np.ndarray = result.masks.data.cpu().numpy() + + if masks.shape[0] == 0: + return + + # Create a single binary mask where any non-zero instance pixel becomes detection + binary_masks = masks > 0.0 + combined = np.any(binary_masks, axis=0).astype("uint8") * 255 + + mask_msg: Image = self.bridge.cv2_to_imgmsg(combined, encoding="mono8") + mask_msg.header = header + self.mask_pub.publish(mask_msg) + + def publish_debug_image(self, result: Results, header: Header) -> None: + """ + Publish debug visualization image. + """ + if getattr(self, "debug_publisher", None) is None: + return + debug_img: np.ndarray = self.segmentation.visualize(result) + debug_msg: Image = self.bridge.cv2_to_imgmsg(debug_img, "bgr8") + debug_msg.header = header + self.debug_publisher.publish(debug_msg) + + def load_node_params(self) -> YoloNodeParams: + """ + Load node-specific parameters (topics, debug). + Returns: + YoloNodeParams: Node parameter dataclass. + """ + self.declare_parameter("input_topic", Parameter.Type.STRING) + self.declare_parameter("output_bbox_topic", Parameter.Type.STRING) + self.declare_parameter("output_mask_topic", Parameter.Type.STRING) + self.declare_parameter("debug_topic", Parameter.Type.STRING) + self.declare_parameter("pub_bbox", Parameter.Type.BOOL) + self.declare_parameter("pub_mask", Parameter.Type.BOOL) + self.declare_parameter("pub_debug", Parameter.Type.BOOL) + return YoloNodeParams( + input_topic=self.get_parameter("input_topic") + .get_parameter_value() + .string_value, + output_bbox_topic=self.get_parameter("output_bbox_topic") + .get_parameter_value() + .string_value, + output_mask_topic=self.get_parameter("output_mask_topic") + .get_parameter_value() + .string_value, + debug_topic=self.get_parameter("debug_topic") + .get_parameter_value() + .string_value, + pub_bbox=self.get_parameter("pub_bbox") + .get_parameter_value() + .bool_value, + pub_mask=self.get_parameter("pub_mask") + .get_parameter_value() + .bool_value, + pub_debug=self.get_parameter("pub_debug") + .get_parameter_value() + .bool_value, + ) + + def load_params(self) -> YoloSegmentationParams: + """ + Load segmentation parameters. + Returns: + YoloSegmentationParams: Segmentation parameters dataclass. 
+ """ + self.declare_parameter("device", Parameter.Type.STRING) + self.declare_parameter("model_path", Parameter.Type.STRING) + self.declare_parameter("confidence_threshold", Parameter.Type.DOUBLE) + self.declare_parameter("max_detections", Parameter.Type.INTEGER) + self.declare_parameter("imgsz", Parameter.Type.INTEGER) + self.declare_parameter("compile", Parameter.Type.BOOL) + self.declare_parameter("retina_masks", Parameter.Type.BOOL) + return YoloSegmentationParams( + device=self.get_parameter("device") + .get_parameter_value() + .string_value, + model_path=self.get_parameter("model_path") + .get_parameter_value() + .string_value, + confidence_threshold=self.get_parameter("confidence_threshold") + .get_parameter_value() + .double_value, + max_detections=self.get_parameter("max_detections") + .get_parameter_value() + .integer_value, + imgsz=self.get_parameter("imgsz") + .get_parameter_value() + .integer_value, + compile=self.get_parameter("compile") + .get_parameter_value() + .bool_value, + retina_masks=self.get_parameter("retina_masks") + .get_parameter_value() + .bool_value, + ) + + +def main(args: Optional[List[str]] = None) -> None: + """ + Entry point for the ROS2 node. Initializes and spins the YoloSegmentationNode. + """ + rclpy.init(args=args) + node = YoloSegmentationNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/yolo_roboflow_training/Job.slurm b/yolo_roboflow_training/Job.slurm new file mode 100644 index 0000000..7ff5bb5 --- /dev/null +++ b/yolo_roboflow_training/Job.slurm @@ -0,0 +1,50 @@ +#!/bin/bash +#SBATCH --partition=GPUQ +#SBATCH --account=studiegrupper-vortex +#SBATCH --time=2:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:a100:4 +#SBATCH --constraint="gpu40g|gpu80g|gpu32g" +#SBATCH --job-name="vortex-img-process" +#SBATCH --output=vortex_img_process_log.out +#SBATCH --mem=32G + +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/cluster/apps/eb/software/Python/3.10.4-GCCcore-11.3.0/lib/ + +set -e + +module purge +module --ignore_cache load foss/2022a +module --ignore_cache load Python/3.10.4-GCCcore-11.3.0 + +pip cache purge + +# makes sure that the pip is up to date +python3 -m pip install --upgrade pip + +# Create a temporary virtual environment +VENV_DIR=$(mktemp -d -t env-repaint-XXXXXXXXXX) +python3 -m venv $VENV_DIR +source $VENV_DIR/bin/activate + +pip install --upgrade pip + +# install the required packages +pip install -r requirements.txt +#pip install pyyaml # used to read the configuration file +#pip install blobfile # install blobfile to download the dataset +#pip install kagglehub # install kagglehub to download the dataset +pip install --force-reinstall torch -U +pip install torchvision torchaudio +#pip install diffusers transformers accelerate --user + +# Mixing expandable_segments:True with max_split_size doesn't make sense because the expandable segment is the size of RAM and so it could never be split with max_split_size. 
+# export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:128"
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
+
+python3 train.py
+
+# Deactivate and remove the virtual environment
+deactivate
+rm -rf $VENV_DIR
diff --git a/yolo_roboflow_training/config.yaml b/yolo_roboflow_training/config.yaml
new file mode 100644
index 0000000..1d1565a
--- /dev/null
+++ b/yolo_roboflow_training/config.yaml
@@ -0,0 +1,12 @@
+# General training parameters for YOLO Roboflow training
+project_id: "pipe-line-instance-seg-rxqdn"
+version: "1" # Dataset version number (exclude 'v')
+model_type: "yolov8n-seg.pt"
+epochs: 5
+patience: 2
+imgsz: 640
+batch: 16
+results_dir: "results"
+export_formats:
+  - onnx
+dataset_format: "YOLOv8"
diff --git a/yolo_roboflow_training/requirements.txt b/yolo_roboflow_training/requirements.txt
new file mode 100644
index 0000000..c2c343c
--- /dev/null
+++ b/yolo_roboflow_training/requirements.txt
@@ -0,0 +1,49 @@
+-i https://pypi.org/simple
+certifi==2025.10.5; python_version >= '3.7'
+charset-normalizer==3.4.4; python_version >= '3.7'
+contourpy==1.3.3; python_version >= '3.11'
+cycler==0.12.1; python_version >= '3.8'
+filelock==3.20.0; python_version >= '3.10'
+filetype==1.2.0
+fonttools==4.60.1; python_version >= '3.9'
+fsspec==2025.9.0; python_version >= '3.9'
+hf-xet==1.1.10; python_version >= '3.8'
+huggingface-hub==0.35.3; python_full_version >= '3.8.0'
+idna==3.7; python_version >= '3.5'
+jinja2==3.1.6; python_version >= '3.7'
+kiwisolver==1.4.9; python_version >= '3.10'
+markupsafe==3.0.3; python_version >= '3.9'
+matplotlib==3.10.7; python_version >= '3.10'
+mpmath==1.3.0
+networkx==3.5; python_version >= '3.11'
+numpy==2.2.6; python_version >= '3.10'
+opencv-contrib-python==4.12.0.88; python_version >= '3.6'
+opencv-python==4.12.0.88; python_version >= '3.6'
+opencv-python-headless==4.10.0.84; python_version >= '3.6'
+packaging==25.0; python_version >= '3.8'
+pafy==0.5.5
+pi-heif==1.1.1; python_version >= '3.9'
+pillow==12.0.0; python_version >= '3.10'
+pillow-avif-plugin==1.5.2
+polars==1.34.0; python_version >= '3.9'
+polars-runtime-32==1.34.0; python_version >= '3.9'
+protobuf==6.33.0; python_version >= '3.9'
+psutil==7.1.1; python_version >= '3.6'
+pyparsing==3.2.5; python_version >= '3.9'
+python-dateutil==2.9.0.post0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
+python-dotenv==1.1.1; python_version >= '3.9'
+pyyaml==6.0.3; python_version >= '3.8'
+requests==2.32.5; python_version >= '3.9'
+requests-toolbelt==1.0.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+roboflow==1.2.11; python_version >= '3.8'
+scipy==1.16.2; python_version >= '3.11'
+setuptools==80.9.0; python_version >= '3.9'
+six==1.17.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'
+sympy==1.14.0; python_version >= '3.9'
+torch==2.9.0; python_version >= '3.10'
+torchvision==0.24.0; python_version >= '3.10'
+tqdm==4.67.1; python_version >= '3.7'
+typing-extensions==4.15.0; python_version >= '3.9'
+ultralytics==8.3.220; python_version >= '3.8'
+ultralytics-thop==2.0.17; python_version >= '3.8'
+urllib3==2.5.0; python_version >= '3.9'
diff --git a/yolo_roboflow_training/train.py b/yolo_roboflow_training/train.py
new file mode 100644
index 0000000..6181b15
--- /dev/null
+++ b/yolo_roboflow_training/train.py
@@ -0,0 +1,90 @@
+import os
+from pathlib import Path
+
+import torch
+import yaml
+from dotenv import load_dotenv
+from roboflow import Roboflow
+from ultralytics import YOLO
+
+# Load environment variables from .env file
+load_dotenv()
+
+# Print CUDA info
+print(f"CUDA available: {torch.cuda.is_available()}")
+print(f"CUDA version: {torch.version.cuda}")
+print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
+
+# Load config
+with open("config.yaml", "r") as f:
+    config = yaml.safe_load(f)
+
+# Parameters
+ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")
+PROJECT_ID = config["project_id"]
+VERSION = config["version"]
+MODEL_TYPE = config.get("model_type")
+EPOCHS = config.get("epochs")
+PATIENCE = config.get("patience")
+IMGSZ = config.get("imgsz")
+BATCH = config.get("batch")
+RESULTS_DIR = config.get("results_dir")
+EXPORT_FORMATS = config.get("export_formats")
+DATASET_FORMAT = config.get("dataset_format")
+
+if not ROBOFLOW_API_KEY:
+    raise ValueError("ROBOFLOW_API_KEY not set in environment or .env file.")
+
+
+# Download dataset from Roboflow into 'roboflow_data/'
+roboflow_data_dir = Path(__file__).parent / "roboflow_data"
+roboflow_data_dir.mkdir(exist_ok=True)
+rf = Roboflow(api_key=ROBOFLOW_API_KEY)
+project = rf.workspace().project(PROJECT_ID)
+versions = project.versions()
+dataset = project.version(VERSION).download(
+    DATASET_FORMAT, location=str(roboflow_data_dir)
+)
+
+# Set up training configuration
+model = YOLO(MODEL_TYPE)
+subfolder_name = f"{dataset.name.replace(' ', '-')}-{VERSION}"
+data_yaml_path = Path(dataset.location) / subfolder_name / "data.yaml"
+
+
+# Select device: CUDA, MPS (Apple Silicon), or CPU
+if torch.cuda.is_available():
+    device = 0
+    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
+elif torch.backends.mps.is_available():
+    device = 'mps'
+    print("Using Apple Silicon MPS device.")
+else:
+    device = 'cpu'
+    print("Using CPU.")
+
+WORKERS = os.cpu_count() // 2 if os.cpu_count() else 2
+print(f"Using {WORKERS} workers")
+
+model.train(
+    data=str(data_yaml_path),
+    epochs=EPOCHS,
+    patience=PATIENCE,
+    imgsz=IMGSZ,
+    batch=BATCH,
+    device=device,
+    project=str(RESULTS_DIR),
+    name="custom_yolov8",
+    workers=WORKERS,
+    # compile=True
+)
+
+
+# Evaluate the model
+metrics = model.val(data=str(data_yaml_path))
+print(f"Validation metrics: {metrics}")
+
+
+# Export the trained model
+for fmt in EXPORT_FORMATS:
+    model.export(format=fmt, device=device)
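Note: with project=results and name="custom_yolov8", Ultralytics writes weights under results/custom_yolov8*/weights/, appending a run index on repeated runs — likely the origin of the custom_yolov89 path in yolo_params.yaml. A minimal smoke test for the weights the ROS node will load, assuming the first run's directory:

import numpy as np
from ultralytics import YOLO

# Adjust the run directory to match the actual training run
model = YOLO("results/custom_yolov8/weights/last.pt", task="segment")
dummy = np.zeros((640, 640, 3), dtype=np.uint8)
results = model.predict(source=dummy, imgsz=640, conf=0.3, device="cpu")
print(results[0].boxes, results[0].masks)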