Skip to content

Commit 0618ffb

Browse files
committed
Added the enable_backward option
Fixed a bug where the min_distance option was sometimes ignored.
1 parent fefc826 commit 0618ffb

File tree

4 files changed

+350
-39
lines changed

4 files changed

+350
-39
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ Amplifies coordinate movement with directional control for X/Y axes, preserving
3535
- GridPointGeneratorNode
3636
Generates a grid of coordinate points.
3737

38+
### 2025-6-8
39+
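Added the `enable_backward` option. This is an experimental feature intended for tracking objects that don't appear in the first frame: when enabled, the grid is additionally tracked over the time-reversed video and the backward result is integrated with the forward result.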
40+
Fixed a bug where the min_distance option was sometimes ignored.
3841

3942

4043
### Related resources

cotracker_node.py

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import gc
88

99
import comfy.model_management as mm
10-
10+
from .trajectory_integration import trajectory_integration
1111

1212

1313
class CoTrackerNode:
@@ -57,6 +57,7 @@ def INPUT_TYPES(cls):
5757
"tooltip": "Minimum distance between tracking points"
5858
}),
5959
"force_offload": ("BOOLEAN", {"default": True}),
60+
"enable_backward": ("BOOLEAN", {"default": False}),
6061
}
6162
}
6263

@@ -121,7 +122,7 @@ def prepare_query_points(self, points, video_shape):
121122

122123
return query_points_tensor
123124

124-
def track_points(self, images, tracking_points, grid_size, max_num_of_points, tracking_mask=None, confidence_threshold=0.5, min_distance=60, force_offload=True):
125+
def track_points(self, images, tracking_points, grid_size, max_num_of_points, tracking_mask=None, confidence_threshold=0.5, min_distance=60, force_offload=True, enable_backward=False):
125126

126127
self.load_model("cotracker3_online")
127128

@@ -147,26 +148,31 @@ def track_points(self, images, tracking_points, grid_size, max_num_of_points, tr
147148

148149
results = []
149150

150-
if len(points) > 0:
151-
print(f"forward - queries")
151+
def _tracking(video, grid_size, queries, add_support_grid):
152152
with torch.no_grad():
153153
self.model(
154154
video_chunk=video,
155155
is_first_step=True,
156-
grid_size=0,
156+
grid_size=grid_size,
157157
queries=queries,
158-
add_support_grid=True
158+
add_support_grid=add_support_grid
159159
)
160160
for ind in range(0, video.shape[1] - self.model.step, self.model.step):
161161
pred_tracks, pred_visibility = self.model(
162162
video_chunk=video[:, ind : ind + self.model.step * 2],
163163
is_first_step=False,
164-
grid_size=0,
164+
grid_size=grid_size,
165165
queries=queries,
166-
add_support_grid=True
166+
add_support_grid=add_support_grid
167167
) # B T N 2, B T N 1
168+
return pred_tracks, pred_visibility
169+
170+
171+
if len(points) > 0:
172+
print(f"forward - queries")
168173

169-
results, images_np = self.format_results(pred_tracks, pred_visibility, None, confidence_threshold, points, max_num_of_points, min_distance, images_np)
174+
pred_tracks, pred_visibility = _tracking(video, 0, queries, True)
175+
results, images_np = self.format_results(pred_tracks, pred_visibility, None, confidence_threshold, points, max_num_of_points, 1, images_np)
170176

171177
print(f"{len(results)=}")
172178

@@ -179,24 +185,16 @@ def track_points(self, images, tracking_points, grid_size, max_num_of_points, tr
179185

180186
if grid_size > 0:
181187
print(f"forward - grid")
182-
with torch.no_grad():
183-
self.model(
184-
video_chunk=video,
185-
is_first_step=True,
186-
grid_size=grid_size,
187-
queries=None,
188-
add_support_grid=False
189-
)
190-
for ind in range(0, video.shape[1] - self.model.step, self.model.step):
191-
pred_tracks, pred_visibility = self.model(
192-
video_chunk=video[:, ind : ind + self.model.step * 2],
193-
is_first_step=False,
194-
grid_size=grid_size,
195-
queries=None,
196-
add_support_grid=False
197-
) # B T N 2, B T N 1
198188

199-
results2, images_np = self.format_results(pred_tracks, pred_visibility, tracking_mask, confidence_threshold, points, max_num_of_points, min_distance, images_np)
189+
pred_tracks, pred_visibility = _tracking(video, grid_size, None, False)
190+
191+
if enable_backward:
192+
pred_tracks_b, pred_visibility_b = _tracking(video.flip(1), grid_size, None, False)
193+
_,_,_,H,W = video.shape
194+
pred_tracks, pred_visibility = trajectory_integration(pred_tracks, pred_visibility, pred_tracks_b, pred_visibility_b, (H,W) , grid_size)
195+
196+
results2, images_np = self.format_results(pred_tracks, pred_visibility, tracking_mask, confidence_threshold, points, max_num_of_points, min_distance, images_np, enable_backward)
197+
200198
print(f"{len(results2)=}")
201199

202200
results = results + results2
@@ -213,7 +211,6 @@ def track_points(self, images, tracking_points, grid_size, max_num_of_points, tr
213211
return (results,images_with_markers)
214212

215213

216-
217214
def select_diverse_points(self, motion_sorted_indices, tracks, visibility, max_points, min_distance):
218215
"""
219216
Selects spatially diverse points from among those with large motion.
@@ -312,7 +309,8 @@ def select_points(self, tracks, visibility, vis_threshold=0.5, max_points=9, min
312309
# 3. Point selection
313310
selected_indices = []
314311

315-
if len(valid_indices) <= max_points:
312+
# if len(valid_indices) <= max_points:
313+
if False:
316314
selected_indices = valid_indices.tolist()
317315
else:
318316
# Sort points in descending order of motion magnitude
@@ -336,11 +334,14 @@ def select_points(self, tracks, visibility, vis_threshold=0.5, max_points=9, min
336334
return selected_indices
337335

338336

339-
def format_results(self, tracks, visibility, mask, confidence_threshold, original_points, max_points, min_distance, images_np):
337+
def format_results(self, tracks, visibility, mask, confidence_threshold, original_points, max_points, min_distance, images_np, enable_backward=False):
340338
# tracks : (B, T, N, 2) where B=batch, T=frames, N=points
341339
tracks = tracks.squeeze(0).cpu().numpy() # (T, N, 2)
342340
visibility = visibility.squeeze(0).cpu().numpy() # (T, N)
343341

342+
if enable_backward:
343+
confidence_threshold = 0
344+
344345
num_frames, num_points, _ = tracks.shape
345346

346347
def filter_by_mask(trs, vis, mask):
@@ -398,17 +399,25 @@ def filter_by_mask(trs, vis, mask):
398399
"y": int(y),
399400
})
400401
else:
401-
# Use the previous coordinates
402-
if len(point_track) > 0:
403-
last_point = point_track[-1].copy()
404-
point_track.append(last_point)
405-
x = last_point["x"]
406-
y = last_point["y"]
407-
else:
402+
if enable_backward:
408403
point_track.append({
409-
"x": int(x),
410-
"y": int(y),
404+
"x": -100,
405+
"y": -100,
411406
})
407+
x = -100
408+
y = -100
409+
else:
410+
# Use the previous coordinates
411+
if len(point_track) > 0:
412+
last_point = point_track[-1].copy()
413+
point_track.append(last_point)
414+
x = last_point["x"]
415+
y = last_point["y"]
416+
else:
417+
point_track.append({
418+
"x": int(x),
419+
"y": int(y),
420+
})
412421

413422
if frame_idx < images_np.shape[0]:
414423
cv2.circle(images_np[frame_idx], (int(x), int(y)), marker_radius, marker_color, marker_thickness)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "cotracker_node"
33
description = "This is a node that outputs tracking results of a grid or specified points using CoTracker. It can be directly connected to the WanVideo ATI Tracks Node."
4-
version = "1.0.1"
4+
version = "1.0.2"
55
license = {file = "LICENSE"}
66

77
[project.urls]

0 commit comments

Comments
 (0)