22# @Author : LG
33
44from PyQt5 import QtWidgets , QtGui , QtCore
5- from ISAT .widgets .polygon import Polygon , Vertex , PromptPoint , Line
5+ from ISAT .widgets .polygon import Polygon , Vertex , PromptPoint , Line , Rect
66from ISAT .configs import STATUSMode , CLICKMode , DRAWMode , CONTOURMode
77import numpy as np
88import cv2
@@ -18,6 +18,7 @@ def __init__(self, mainwindow):
1818 self .mask_item : QtWidgets .QGraphicsPixmapItem = None
1919 self .image_data = None
2020 self .current_graph : Polygon = None
21+ self .current_sam_rect : Rect = None
2122 self .current_line : Line = None
2223 self .mode = STATUSMode .VIEW
2324 self .click = CLICKMode .POSITIVE
@@ -82,6 +83,7 @@ def change_mode_to_create(self):
8283 self .mainwindow .actionNext .setEnabled (False )
8384
8485 self .mainwindow .actionSegment_anything .setEnabled (False )
86+ self .mainwindow .actionSegment_anything_box .setEnabled (False )
8587 self .mainwindow .actionPolygon .setEnabled (False )
8688 self .mainwindow .actionBackspace .setEnabled (True )
8789 self .mainwindow .actionFinish .setEnabled (True )
@@ -157,6 +159,7 @@ def change_mode_to_edit(self):
157159 self .mainwindow .actionNext .setEnabled (False )
158160
159161 self .mainwindow .actionSegment_anything .setEnabled (False )
162+ self .mainwindow .actionSegment_anything_box .setEnabled (False )
160163 self .mainwindow .actionPolygon .setEnabled (False )
161164 self .mainwindow .actionBackspace .setEnabled (False )
162165 self .mainwindow .actionFinish .setEnabled (False )
@@ -198,6 +201,7 @@ def change_mode_to_repaint(self):
198201 self .mainwindow .actionNext .setEnabled (False )
199202
200203 self .mainwindow .actionSegment_anything .setEnabled (False )
204+ self .mainwindow .actionSegment_anything_box .setEnabled (False )
201205 self .mainwindow .actionPolygon .setEnabled (False )
202206 self .mainwindow .actionBackspace .setEnabled (True )
203207 self .mainwindow .actionFinish .setEnabled (False )
@@ -242,6 +246,10 @@ def start_segment_anything(self):
242246 self .draw_mode = DRAWMode .SEGMENTANYTHING
243247 self .start_draw ()
244248
249+ def start_segment_anything_box (self ):
250+ self .draw_mode = DRAWMode .SEGMENTANYTHING_BOX
251+ self .start_draw ()
252+
245253 def start_draw_polygon (self ):
246254 self .draw_mode = DRAWMode .POLYGON
247255 self .start_draw ()
@@ -269,7 +277,7 @@ def finish_draw(self):
269277 is_crowd = False
270278 note = ''
271279
272- if self .draw_mode == DRAWMode .SEGMENTANYTHING :
280+ if self .draw_mode == DRAWMode .SEGMENTANYTHING or self . draw_mode == DRAWMode . SEGMENTANYTHING_BOX :
273281 # mask to polygon
274282 # --------------
275283 if self .masks is not None :
@@ -378,6 +386,12 @@ def finish_draw(self):
378386 self .mainwindow .annos_dock_widget .update_listwidget ()
379387
380388 self .current_graph = None
389+
390+ if self .current_sam_rect is not None :
391+ self .current_sam_rect .delete ()
392+ self .removeItem (self .current_sam_rect )
393+ self .current_sam_rect = None
394+
381395 self .change_mode_to_view ()
382396
383397 # mask清空
@@ -406,6 +420,11 @@ def cancel_draw(self):
406420 for item in self .selectedItems ():
407421 item .setSelected (False )
408422
423+ if self .current_sam_rect is not None :
424+ self .current_sam_rect .delete ()
425+ self .removeItem (self .current_sam_rect )
426+ self .current_sam_rect = None
427+
409428 self .change_mode_to_view ()
410429
411430 self .click_points .clear ()
@@ -763,6 +782,14 @@ def mousePressEvent(self, event: 'QtWidgets.QGraphicsSceneMouseEvent'):
763782 self .prompt_points .append (prompt_point )
764783 self .addItem (prompt_point )
765784
785+ elif self .draw_mode == DRAWMode .SEGMENTANYTHING_BOX : # sam 矩形框提示
786+ if self .current_sam_rect is None :
787+ self .current_sam_rect = Rect ()
788+ self .current_sam_rect .setZValue (2 )
789+ self .addItem (self .current_sam_rect )
790+ self .current_sam_rect .addPoint (QtCore .QPointF (sceneX , sceneY ))
791+ self .current_sam_rect .addPoint (QtCore .QPointF (sceneX , sceneY ))
792+
766793 elif self .draw_mode == DRAWMode .POLYGON :
767794 # 移除随鼠标移动的点
768795 self .current_graph .removePoint (len (self .current_graph .points ) - 1 )
@@ -877,6 +904,10 @@ def mouseMoveEvent(self, event: 'QtWidgets.QGraphicsSceneMouseEvent'):
877904 if self .draw_mode == DRAWMode .POLYGON :
878905 # 随鼠标位置实时更新多边形
879906 self .current_graph .movePoint (len (self .current_graph .points ) - 1 , pos )
907+ if self .draw_mode == DRAWMode .SEGMENTANYTHING_BOX :
908+ if self .current_sam_rect is not None :
909+ self .current_sam_rect .movePoint (len (self .current_sam_rect .points ) - 1 , pos )
910+ self .update_mask ()
880911
881912 if self .mode == STATUSMode .REPAINT :
882913 self .current_line .movePoint (len (self .current_line .points ) - 1 , pos )
@@ -946,6 +977,23 @@ def update_mask(self):
946977
947978 if len (self .click_points ) > 0 and len (self .click_points_mode ) > 0 :
948979 masks = self .mainwindow .segany .predict_with_point_prompt (self .click_points , self .click_points_mode )
980+ self .masks = masks
981+ color = np .array ([0 , 0 , 255 ])
982+ h , w = masks .shape [- 2 :]
983+ mask_image = masks .reshape (h , w , 1 ) * color .reshape (1 , 1 , - 1 )
984+ mask_image = mask_image .astype ("uint8" )
985+ mask_image = cv2 .cvtColor (mask_image , cv2 .COLOR_BGR2RGB )
986+ mask_image = cv2 .addWeighted (self .image_data , self .mask_alpha , mask_image , 1 , 0 )
987+ elif self .current_sam_rect is not None :
988+ point1 = self .current_sam_rect .points [0 ]
989+ point2 = self .current_sam_rect .points [1 ]
990+ box = np .array ([min (point1 .x (), point2 .x ()),
991+ min (point1 .y (), point2 .y ()),
992+ max (point1 .x (), point2 .x ()),
993+ max (point1 .y (), point2 .y ()),
994+ ])
995+ masks = self .mainwindow .segany .predict_with_box_prompt (box )
996+
949997 self .masks = masks
950998 color = np .array ([0 , 0 , 255 ])
951999 h , w = masks .shape [- 2 :]
@@ -954,19 +1002,14 @@ def update_mask(self):
9541002 mask_image = cv2 .cvtColor (mask_image , cv2 .COLOR_BGR2RGB )
9551003 # 这里通过调整原始图像的权重self.mask_alpha,来调整mask的明显程度。
9561004 mask_image = cv2 .addWeighted (self .image_data , self .mask_alpha , mask_image , 1 , 0 )
957- mask_image = QtGui .QImage (mask_image [:], mask_image .shape [1 ], mask_image .shape [0 ], mask_image .shape [1 ] * 3 ,
958- QtGui .QImage .Format_RGB888 )
959- mask_pixmap = QtGui .QPixmap (mask_image )
960- if self .mask_item is not None :
961- self .mask_item .setPixmap (mask_pixmap )
9621005 else :
9631006 mask_image = np .zeros (self .image_data .shape , dtype = np .uint8 )
9641007 mask_image = cv2 .addWeighted (self .image_data , 1 , mask_image , 0 , 0 )
965- mask_image = QtGui .QImage (mask_image [:], mask_image .shape [1 ], mask_image .shape [0 ], mask_image .shape [1 ] * 3 ,
966- QtGui .QImage .Format_RGB888 )
967- mask_pixmap = QtGui .QPixmap (mask_image )
968- if self .mask_item is not None :
969- self .mask_item .setPixmap (mask_pixmap )
1008+ mask_image = QtGui .QImage (mask_image [:], mask_image .shape [1 ], mask_image .shape [0 ], mask_image .shape [1 ] * 3 ,
1009+ QtGui .QImage .Format_RGB888 )
1010+ mask_pixmap = QtGui .QPixmap (mask_image )
1011+ if self .mask_item is not None :
1012+ self .mask_item .setPixmap (mask_pixmap )
9701013
9711014 def backspace (self ):
9721015 if self .mode == STATUSMode .CREATE :
0 commit comments