-
Notifications
You must be signed in to change notification settings - Fork 3.1k
examples/traffic_analysis: improve CLI argument handling #2059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
03800a9
6f164e9
7b62fa2
ade47d0
805cef7
3646830
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,5 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import os | ||
| from collections.abc import Iterable | ||
|
|
||
|
|
@@ -180,62 +179,47 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray: | |
| return self.annotate_frame(frame, detections) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser( | ||
| description="Traffic Flow Analysis with Inference and ByteTrack" | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "--model_id", | ||
| default="vehicle-count-in-drone-video/6", | ||
| help="Roboflow model ID", | ||
| type=str, | ||
| ) | ||
| parser.add_argument( | ||
| "--roboflow_api_key", | ||
| default=None, | ||
| help="Roboflow API KEY", | ||
| type=str, | ||
| ) | ||
| parser.add_argument( | ||
| "--source_video_path", | ||
| required=True, | ||
| help="Path to the source video file", | ||
| type=str, | ||
| ) | ||
| parser.add_argument( | ||
| "--target_video_path", | ||
| default=None, | ||
| help="Path to the target video file (output)", | ||
| type=str, | ||
| ) | ||
| parser.add_argument( | ||
| "--confidence_threshold", | ||
| default=0.3, | ||
| help="Confidence threshold for the model", | ||
| type=float, | ||
| ) | ||
| parser.add_argument( | ||
| "--iou_threshold", default=0.7, help="IOU threshold for the model", type=float | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| api_key = args.roboflow_api_key | ||
| def main( | ||
| source_video_path: str, | ||
| target_video_path: str, | ||
| roboflow_api_key: str, | ||
| model_id: str = "vehicle-count-in-drone-video/6", | ||
| confidence_threshold: float = 0.3, | ||
| iou_threshold: float = 0.7, | ||
| ) -> None: | ||
| """ | ||
| Traffic Flow Analysis with Inference and ByteTrack. | ||
|
|
||
| Args: | ||
| source_video_path: Path to the source video file | ||
| target_video_path: Path to the target video file (output) | ||
| roboflow_api_key: Roboflow API key | ||
|
||
| model_id: Roboflow model ID | ||
| confidence_threshold: Confidence threshold for the model | ||
| iou_threshold: IOU threshold for the model | ||
| """ | ||
| api_key = roboflow_api_key | ||
| api_key = os.environ.get("ROBOFLOW_API_KEY", api_key) | ||
| if api_key is None: | ||
| raise ValueError( | ||
| "Roboflow API KEY is missing. Please provide it as an argument or set the " | ||
| "ROBOFLOW_API_KEY environment variable." | ||
| ) | ||
| args.roboflow_api_key = api_key | ||
| roboflow_api_key = api_key | ||
|
|
||
| processor = VideoProcessor( | ||
| roboflow_api_key=args.roboflow_api_key, | ||
| model_id=args.model_id, | ||
| source_video_path=args.source_video_path, | ||
| target_video_path=args.target_video_path, | ||
| confidence_threshold=args.confidence_threshold, | ||
| iou_threshold=args.iou_threshold, | ||
| roboflow_api_key=roboflow_api_key, | ||
| model_id=model_id, | ||
| source_video_path=source_video_path, | ||
| target_video_path=target_video_path, | ||
| confidence_threshold=confidence_threshold, | ||
| iou_threshold=iou_threshold, | ||
| ) | ||
| processor.process_video() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from jsonargparse import auto_cli, set_parsing_settings | ||
|
|
||
| set_parsing_settings(parse_optionals_as_positionals=True) | ||
| auto_cli(main, as_positional=False) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,3 +3,4 @@ inference | |
| supervision | ||
| tqdm | ||
| ultralytics | ||
| jsonargparse[signatures] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,5 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| from collections.abc import Iterable | ||
|
|
||
| import cv2 | ||
|
|
@@ -177,45 +176,35 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray: | |
| return self.annotate_frame(frame, detections) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser( | ||
| description="Traffic Flow Analysis with YOLO and ByteTrack" | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "--source_weights_path", | ||
| required=True, | ||
| help="Path to the source weights file", | ||
| type=str, | ||
| ) | ||
| parser.add_argument( | ||
| "--source_video_path", | ||
| required=True, | ||
| help="Path to the source video file", | ||
| type=str, | ||
| ) | ||
| parser.add_argument( | ||
| "--target_video_path", | ||
| default=None, | ||
| help="Path to the target video file (output)", | ||
| type=str, | ||
| ) | ||
| parser.add_argument( | ||
| "--confidence_threshold", | ||
| default=0.3, | ||
| help="Confidence threshold for the model", | ||
| type=float, | ||
| ) | ||
| parser.add_argument( | ||
| "--iou_threshold", default=0.7, help="IOU threshold for the model", type=float | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
| def main( | ||
| source_weights_path: str, | ||
| source_video_path: str, | ||
| target_video_path: str, | ||
|
||
| confidence_threshold: float = 0.3, | ||
| iou_threshold: float = 0.7, | ||
| ) -> None: | ||
| """ | ||
| Traffic Flow Analysis with YOLO and ByteTrack. | ||
|
|
||
| Args: | ||
| source_weights_path: Path to the source weights file | ||
| source_video_path: Path to the source video file | ||
| target_video_path: Path to the target video file (output) | ||
| confidence_threshold: Confidence threshold for the model | ||
| iou_threshold: IOU threshold for the model | ||
| """ | ||
| processor = VideoProcessor( | ||
| source_weights_path=args.source_weights_path, | ||
| source_video_path=args.source_video_path, | ||
| target_video_path=args.target_video_path, | ||
| confidence_threshold=args.confidence_threshold, | ||
| iou_threshold=args.iou_threshold, | ||
| source_weights_path=source_weights_path, | ||
| source_video_path=source_video_path, | ||
| target_video_path=target_video_path, | ||
| confidence_threshold=confidence_threshold, | ||
| iou_threshold=iou_threshold, | ||
| ) | ||
| processor.process_video() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from jsonargparse import auto_cli, set_parsing_settings | ||
|
|
||
| set_parsing_settings(parse_optionals_as_positionals=True) | ||
| auto_cli(main, as_positional=False) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
target_video_pathparameter should be optional with a default value ofNoneto maintain backward compatibility with the previous argparse implementation. The original argparse version haddefault=Nonefor this parameter. Without making it optional, this is a breaking API change.