-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
171 lines (149 loc) · 6.64 KB
/
test.py
File metadata and controls
171 lines (149 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import cv2
import numpy as np
import threading
from ultralytics import YOLO
import torch
from collections import defaultdict
import math
import sys
import queue
from tracking_config import *
import serial
import time
from move_control import SerialController, kinematics_solver
class Tracker:
    """Follow one target class with YOLO tracking and a serially driven base.

    Each call to ``target_track`` runs the model on one frame. It picks the
    best detection whose class name equals ``TARGET_NAME``, draws overlays
    onto the frame in place, and sends wheel speeds so that the target moves
    toward a reference point above the image centre. ``TARGET_NAME`` and
    ``TRACKING_CONFIG`` presumably come from ``tracking_config`` via the star
    import; confirm.
    """

    # Pixels the reference point sits above the geometric image centre.
    center_y_offset = 120
    # Minimum detection confidence passed to ``model.track``.
    detection_conf = 0.7

    def __init__(self, model_path):
        """Load the model, open the serial controller and stop the base.

        Args:
            model_path: Model file handed to ``YOLO`` (e.g. a ``.engine`` file).
        """
        self.model = YOLO(model_path, task="detect")
        # Serial controller that receives wheel-speed commands.
        self.serial_controller = SerialController()
        # Speed control: pixel error is scaled by speed_gain, then capped at
        # max_speed. NOTE(review): with 0.1 / 0.1 the cap is reached at 1 px
        # of error, so the base effectively always moves at max_speed.
        self.speed_gain = 0.1
        self.max_speed = 0.1
        # Control parameters from the tracking configuration.
        self.min_error_threshold = TRACKING_CONFIG['min_error_threshold']
        # NOTE(review): stored but not used anywhere in this file.
        self.max_no_target_time = TRACKING_CONFIG['max_no_target_time']
        # Time of the most recent frame that contained a target.
        self.last_target_time = time.time()
        self.serial_controller.stop()
        # Per-track history; not used in this file, kept for external readers.
        self.track_history = defaultdict(list)
        print("初始化,停止移动")

    def target_track(self, img):
        """Run one tracking step on ``img`` and steer the base toward the target.

        Among detections whose class name equals ``TARGET_NAME``, the one
        with the highest confidence is followed, so each frame sends at most
        one motion command. If no such detection exists, the base is stopped.

        Args:
            img: Frame to process. Assumed to be a cv2/numpy image whose
                ``shape[:2]`` is ``(height, width)``. Overlays are drawn
                into it in place.

        Returns:
            The same ``img`` object, on every path.
        """
        result = self.model.track(
            img, imgsz=640, verbose=False, conf=self.detection_conf)[0]
        target = self._select_target(result)
        if target is None:
            # No target in view: stop rather than keep the last command.
            self.serial_controller.stop()
            return img

        box, conf, class_name = target
        self.last_target_time = time.time()

        # Reference point the target should be driven toward.
        image_height, image_width = img.shape[:2]
        center_x = image_width // 2
        center_y = (image_height // 2) - self.center_y_offset
        cv2.circle(img, (center_x, center_y), radius=10,
                   color=(0, 0, 255), thickness=-1)

        # xywh is (centre x, centre y, width, height); convert to corners.
        x, y, box_w, box_h = box
        x1, y1 = int(x - box_w / 2), int(y - box_h / 2)
        x2, y2 = int(x + box_w / 2), int(y + box_h / 2)
        cx = int((x1 + x2) / 2)
        cy = int((y1 + y2) / 2)

        label = f'{class_name} {float(conf):.2f}'
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 255), 2)
        cv2.putText(img, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        error_x = cx - center_x
        error_y = cy - center_y
        if (abs(error_x) > self.min_error_threshold
                or abs(error_y) > self.min_error_threshold):
            vx, vy = self._error_to_velocity(error_x, error_y)
            # No rotation: the base only translates toward the target.
            v0, v1, v2 = kinematics_solver(vx, vy, 0.0)
            self.serial_controller.send_command(v0, v1, v2)
        else:
            # Error is within the dead band: target is centred, stop moving.
            self.serial_controller.stop()
            print("目标已居中,停止移动")
        return img

    def _select_target(self, result):
        """Return ``(xywh, conf, class_name)`` of the best target, or None.

        Track IDs are deliberately not read. ``result.boxes.id`` can be
        ``None`` when no track IDs are available, and this class never used
        them.
        """
        if result is None or getattr(result, 'boxes', None) is None \
                or len(result.boxes) == 0:
            return None
        names = getattr(result, 'names', None) or {}
        boxes = result.boxes.xywh.cpu().tolist()
        cls_ids = result.boxes.cls.int().cpu().tolist()
        confs = result.boxes.conf.cpu().tolist()

        best = None
        for box, cls_id, conf in zip(boxes, cls_ids, confs):
            class_name = names.get(cls_id, str(cls_id))
            if class_name != TARGET_NAME:
                continue
            if best is None or conf > best[1]:
                best = (box, conf, class_name)
        return best

    def _error_to_velocity(self, error_x, error_y):
        """Map a pixel error to ``(vx, vy)`` along the error direction.

        The speed is ``norm * speed_gain`` capped at ``max_speed``. Vertical
        image error drives vx (forward/back) and horizontal error drives vy
        (left/right), both with negated sign as in the original mapping. The
        physical meaning of the signs depends on ``kinematics_solver`` and
        camera mounting; confirm on hardware.
        """
        norm = math.hypot(error_x, error_y)
        if norm == 0:
            return 0.0, 0.0
        speed = min(norm * self.speed_gain, self.max_speed)
        vx = -(error_y / norm) * speed  # forward / backward
        vy = -(error_x / norm) * speed  # left / right
        return vx, vy
def main():
    """Open the camera, run the tracker on every frame and show the result.

    Press ``q`` in the preview window to quit. On every exit path (``q``,
    a failed camera read, or an exception) the base is stopped first. The
    serial controller, camera and preview windows are then released.
    """
    # Initialization
    model_path = "./models/best_s.engine"
    tracker = Tracker(model_path)
    cap = None
    try:
        cap = cv2.VideoCapture(CAMERA_CONFIG["camera_id"])
        # Request MJPEG so the camera can deliver high resolution / frame rate.
        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
        # Requested resolution and frame rate.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        cap.set(cv2.CAP_PROP_FPS, 120)
        # The driver may not honour the request; print what was applied.
        print("Width:", cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        print("Height:", cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print("FPS:", cap.get(cv2.CAP_PROP_FPS))

        while True:
            ret, frame = cap.read()
            # Check ret before touching frame: on failure frame is None.
            if not ret:
                break
            frame = cv2.flip(frame, 1)  # horizontal mirror
            tracker.target_track(frame)
            cv2.imshow("result", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Stop the base on every exit so it never keeps the last command.
        try:
            tracker.serial_controller.stop()
            print("停止移动")
        finally:
            tracker.serial_controller.close()
            if cap is not None:
                cap.release()
            cv2.destroyAllWindows()


if __name__ == "__main__":
    main()