From 133eed05d6ab3c19b6b644270dbd02941c2e0b8f Mon Sep 17 00:00:00 2001 From: mucunwuxian Date: Wed, 12 Jul 2023 13:56:41 +0900 Subject: [PATCH 1/3] add vid_stride argument --- tools/demo.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tools/demo.py b/tools/demo.py index 5799e8a..c18636a 100644 --- a/tools/demo.py +++ b/tools/demo.py @@ -301,6 +301,10 @@ def make_parser(): default=True, type=bool, help='whether save visualization results') + parser.add_argument('--vid_stride', + default=1, + type=int, + help='video frame-rate stride') return parser @@ -334,21 +338,26 @@ def main(): vid_writer = cv2.VideoWriter( save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))) + n = 0 while True: - ret_val, frame = cap.read() - if ret_val: - bboxes, scores, cls_inds = infer_engine.forward(frame) - result_frame = infer_engine.visualize(frame, bboxes, scores, cls_inds, conf=args.conf, save_result=False) - if args.save_result: - vid_writer.write(result_frame) + # ret_val, frame = cap.read() + cap.grab() # .read() = .grab() followed by .retrieve() + if n % args.vid_stride == 0: + ret_val, frame = cap.retrieve() + if ret_val: + bboxes, scores, cls_inds = infer_engine.forward(frame) + result_frame = infer_engine.visualize(frame, bboxes, scores, cls_inds, conf=args.conf, save_result=False) + if args.save_result: + vid_writer.write(result_frame) + else: + cv2.namedWindow("DAMO-YOLO", cv2.WINDOW_NORMAL) + cv2.imshow("DAMO-YOLO", result_frame) + ch = cv2.waitKey(1) + if ch == 27 or ch == ord("q") or ch == ord("Q"): + break else: - cv2.namedWindow("DAMO-YOLO", cv2.WINDOW_NORMAL) - cv2.imshow("DAMO-YOLO", result_frame) - ch = cv2.waitKey(1) - if ch == 27 or ch == ord("q") or ch == ord("Q"): break - else: - break + n += 1 From e2db698595219fcac6890f05c4730e0ff9dea46b Mon Sep 17 00:00:00 2001 From: mucunwuxian Date: Wed, 12 Jul 2023 15:58:21 +0900 Subject: [PATCH 2/3] adjust demo.py --- tools/demo.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/tools/demo.py b/tools/demo.py index c18636a..cac1751 100644 --- a/tools/demo.py +++ b/tools/demo.py @@ -6,6 +6,7 @@ import cv2 import numpy as np import torch +from tqdm import tqdm from loguru import logger from PIL import Image @@ -23,7 +24,6 @@ class Infer(): def __init__(self, config, infer_size=[640,640], device='cuda', output_dir='./', ckpt=None, end2end=False): - self.ckpt_path = ckpt suffix = ckpt.split('.')[-1] if suffix == 'onnx': @@ -66,9 +66,7 @@ def _pad_image(self, img, target_size): return ImageList(pad_imgs, img_sizes, pad_sizes) - def _build_engine(self, config, engine_type): - print(f'Inference with {engine_type} engine!') if engine_type == 'torch': model = build_local_model(config, self.device) @@ -88,7 +86,6 @@ def _build_engine(self, config, engine_type): return model def build_tensorRT_engine(self, trt_path): - import tensorrt as trt from cuda import cuda loggert = trt.Logger(trt.Logger.INFO) @@ -145,11 +142,7 @@ def predict(batch): # result gets copied into output return predict - - - def build_onnx_engine(self, onnx_path): - import onnxruntime session = onnxruntime.InferenceSession(onnx_path) @@ -166,7 +159,6 @@ def build_onnx_engine(self, onnx_path): def preprocess(self, origin_img): - img = transform_img(origin_img, 0, **self.config.test.augment.transform, infer_size=self.infer_size) @@ -178,10 +170,8 @@ def preprocess(self, origin_img): return img, (ow, oh) def postprocess(self, preds, image, origin_shape=None): - if self.engine_type == 'torch': output = preds - elif self.engine_type == 'onnx': scores = torch.Tensor(preds[0]) bboxes = torch.Tensor(preds[1]) @@ -225,9 +215,7 @@ def postprocess(self, preds, image, origin_shape=None): return bboxes, scores, cls_inds - def forward(self, origin_image): - image, origin_shape = self.preprocess(origin_image) if self.engine_type == 'torch': @@ -306,7 +294,6 @@ def make_parser(): type=int, help='video frame-rate stride') - return parser @@ -338,11 +325,10 @@ def main(): vid_writer = cv2.VideoWriter( save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))) - n = 0 - while True: + for i in tqdm(range(int(capture.get(cv2.CAP_PROP_FRAME_COUNT)))): # ret_val, frame = cap.read() cap.grab() # .read() = .grab() followed by .retrieve() - if n % args.vid_stride == 0: + if i % args.vid_stride == 0: ret_val, frame = cap.retrieve() if ret_val: bboxes, scores, cls_inds = infer_engine.forward(frame) @@ -357,8 +343,6 @@ def main(): break else: break - n += 1 - if __name__ == '__main__': From 40d478bafcec09f08570270ba3330812515fe5a6 Mon Sep 17 00:00:00 2001 From: mucunwuxian Date: Wed, 12 Jul 2023 16:01:30 +0900 Subject: [PATCH 3/3] re-adjust demo.py --- tools/demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/demo.py b/tools/demo.py index cac1751..4c8b8c9 100644 --- a/tools/demo.py +++ b/tools/demo.py @@ -325,7 +325,7 @@ def main(): vid_writer = cv2.VideoWriter( save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))) - for i in tqdm(range(int(capture.get(cv2.CAP_PROP_FRAME_COUNT)))): + for i in tqdm(range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))): # ret_val, frame = cap.read() cap.grab() # .read() = .grab() followed by .retrieve() if i % args.vid_stride == 0: