77import gc
88
99import comfy .model_management as mm
10-
10+ from . trajectory_integration import trajectory_integration
1111
1212
1313class CoTrackerNode :
@@ -57,6 +57,7 @@ def INPUT_TYPES(cls):
5757 "tooltip" : "Minimum distance between tracking points"
5858 }),
5959 "force_offload" : ("BOOLEAN" , {"default" : True }),
60+ "enable_backward" : ("BOOLEAN" , {"default" : False }),
6061 }
6162 }
6263
@@ -121,7 +122,7 @@ def prepare_query_points(self, points, video_shape):
121122
122123 return query_points_tensor
123124
124- def track_points (self , images , tracking_points , grid_size , max_num_of_points , tracking_mask = None , confidence_threshold = 0.5 , min_distance = 60 , force_offload = True ):
125+ def track_points (self , images , tracking_points , grid_size , max_num_of_points , tracking_mask = None , confidence_threshold = 0.5 , min_distance = 60 , force_offload = True , enable_backward = False ):
125126
126127 self .load_model ("cotracker3_online" )
127128
@@ -147,26 +148,31 @@ def track_points(self, images, tracking_points, grid_size, max_num_of_points, tr
147148
148149 results = []
149150
150- if len (points ) > 0 :
151- print (f"forward - queries" )
151+ def _tracking (video , grid_size , queries , add_support_grid ):
152152 with torch .no_grad ():
153153 self .model (
154154 video_chunk = video ,
155155 is_first_step = True ,
156- grid_size = 0 ,
156+ grid_size = grid_size ,
157157 queries = queries ,
158- add_support_grid = True
158+ add_support_grid = add_support_grid
159159 )
160160 for ind in range (0 , video .shape [1 ] - self .model .step , self .model .step ):
161161 pred_tracks , pred_visibility = self .model (
162162 video_chunk = video [:, ind : ind + self .model .step * 2 ],
163163 is_first_step = False ,
164- grid_size = 0 ,
164+ grid_size = grid_size ,
165165 queries = queries ,
166- add_support_grid = True
166+ add_support_grid = add_support_grid
167167 ) # B T N 2, B T N 1
168+ return pred_tracks , pred_visibility
169+
170+
171+ if len (points ) > 0 :
172+ print (f"forward - queries" )
168173
169- results , images_np = self .format_results (pred_tracks , pred_visibility , None , confidence_threshold , points , max_num_of_points , min_distance , images_np )
174+ pred_tracks , pred_visibility = _tracking (video , 0 , queries , True )
175+ results , images_np = self .format_results (pred_tracks , pred_visibility , None , confidence_threshold , points , max_num_of_points , 1 , images_np )
170176
171177 print (f"{ len (results )= } " )
172178
@@ -179,24 +185,16 @@ def track_points(self, images, tracking_points, grid_size, max_num_of_points, tr
179185
180186 if grid_size > 0 :
181187 print (f"forward - grid" )
182- with torch .no_grad ():
183- self .model (
184- video_chunk = video ,
185- is_first_step = True ,
186- grid_size = grid_size ,
187- queries = None ,
188- add_support_grid = False
189- )
190- for ind in range (0 , video .shape [1 ] - self .model .step , self .model .step ):
191- pred_tracks , pred_visibility = self .model (
192- video_chunk = video [:, ind : ind + self .model .step * 2 ],
193- is_first_step = False ,
194- grid_size = grid_size ,
195- queries = None ,
196- add_support_grid = False
197- ) # B T N 2, B T N 1
198188
199- results2 , images_np = self .format_results (pred_tracks , pred_visibility , tracking_mask , confidence_threshold , points , max_num_of_points , min_distance , images_np )
189+ pred_tracks , pred_visibility = _tracking (video , grid_size , None , False )
190+
191+ if enable_backward :
192+ pred_tracks_b , pred_visibility_b = _tracking (video .flip (1 ), grid_size , None , False )
193+ _ ,_ ,_ ,H ,W = video .shape
194+ pred_tracks , pred_visibility = trajectory_integration (pred_tracks , pred_visibility , pred_tracks_b , pred_visibility_b , (H ,W ) , grid_size )
195+
196+ results2 , images_np = self .format_results (pred_tracks , pred_visibility , tracking_mask , confidence_threshold , points , max_num_of_points , min_distance , images_np , enable_backward )
197+
200198 print (f"{ len (results2 )= } " )
201199
202200 results = results + results2
@@ -213,7 +211,6 @@ def track_points(self, images, tracking_points, grid_size, max_num_of_points, tr
213211 return (results ,images_with_markers )
214212
215213
216-
217214 def select_diverse_points (self , motion_sorted_indices , tracks , visibility , max_points , min_distance ):
218215 """
219216 Selects spatially diverse points from among those with large motion.
@@ -312,7 +309,8 @@ def select_points(self, tracks, visibility, vis_threshold=0.5, max_points=9, min
312309 # 3. Point selection
313310 selected_indices = []
314311
315- if len (valid_indices ) <= max_points :
312+ # if len(valid_indices) <= max_points:
313+ if False :
316314 selected_indices = valid_indices .tolist ()
317315 else :
318316 # Sort points in descending order of motion magnitude
@@ -336,11 +334,14 @@ def select_points(self, tracks, visibility, vis_threshold=0.5, max_points=9, min
336334 return selected_indices
337335
338336
339- def format_results (self , tracks , visibility , mask , confidence_threshold , original_points , max_points , min_distance , images_np ):
337+ def format_results (self , tracks , visibility , mask , confidence_threshold , original_points , max_points , min_distance , images_np , enable_backward = False ):
340338 # tracks : (B, T, N, 2) where B=batch, T=frames, N=points
341339 tracks = tracks .squeeze (0 ).cpu ().numpy () # (T, N, 2)
342340 visibility = visibility .squeeze (0 ).cpu ().numpy () # (T, N)
343341
342+ if enable_backward :
343+ confidence_threshold = 0
344+
344345 num_frames , num_points , _ = tracks .shape
345346
346347 def filter_by_mask (trs , vis , mask ):
@@ -398,17 +399,25 @@ def filter_by_mask(trs, vis, mask):
398399 "y" : int (y ),
399400 })
400401 else :
401- # Use the previous coordinates
402- if len (point_track ) > 0 :
403- last_point = point_track [- 1 ].copy ()
404- point_track .append (last_point )
405- x = last_point ["x" ]
406- y = last_point ["y" ]
407- else :
402+ if enable_backward :
408403 point_track .append ({
409- "x" : int ( x ) ,
410- "y" : int ( y ) ,
404+ "x" : - 100 ,
405+ "y" : - 100 ,
411406 })
407+ x = - 100
408+ y = - 100
409+ else :
410+ # Use the previous coordinates
411+ if len (point_track ) > 0 :
412+ last_point = point_track [- 1 ].copy ()
413+ point_track .append (last_point )
414+ x = last_point ["x" ]
415+ y = last_point ["y" ]
416+ else :
417+ point_track .append ({
418+ "x" : int (x ),
419+ "y" : int (y ),
420+ })
412421
413422 if frame_idx < images_np .shape [0 ]:
414423 cv2 .circle (images_np [frame_idx ], (int (x ), int (y )), marker_radius , marker_color , marker_thickness )
0 commit comments