
Commit cff33fc

Merge branch 'main' into bug/dg-260-error-reoccurs-when-stopping-the-process
2 parents 9250baa + cb0be0f commit cff33fc

File tree

3 files changed: +308 -0 lines changed


inference/core/workflows/core_steps/loader.py

Lines changed: 4 additions & 0 deletions
@@ -477,6 +477,9 @@
 from inference.core.workflows.core_steps.visualizations.halo.v2 import (
     HaloVisualizationBlockV2,
 )
+from inference.core.workflows.core_steps.visualizations.heatmap.v1 import (
+    HeatmapVisualizationBlockV1,
+)
 from inference.core.workflows.core_steps.visualizations.icon.v1 import (
     IconVisualizationBlockV1,
 )

@@ -740,6 +743,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
     ImageContoursDetectionBlockV1,
     ImagePreprocessingBlockV1,
     ImageSlicerBlockV1,
+    HeatmapVisualizationBlockV1,
     ImageThresholdBlockV1,
     MotionDetectionBlockV1,
     BackgroundSubtractionBlockV1,
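
For context (not part of this diff): once load_blocks() returns the class, the block becomes addressable from workflow definitions via its manifest type. A minimal, illustrative sketch of wiring it behind a detector follows; the detector step, model_id, and output selector are assumptions for illustration, not something this commit adds.

# Hypothetical workflow definition exercising the newly registered block.
# The step "type" matches the TYPE constant defined in heatmap/v1.py below;
# the detector block and model_id are illustrative assumptions.
WORKFLOW_DEFINITION = {
    "version": "1.0",
    "inputs": [
        {"type": "WorkflowImage", "name": "image"},
        {"type": "WorkflowVideoMetadata", "name": "video_metadata"},
    ],
    "steps": [
        {
            "type": "roboflow_core/object_detection_model@v1",
            "name": "detector",
            "image": "$inputs.image",
            "model_id": "yolov8n-640",  # assumed model, swap in your own
        },
        {
            "type": "roboflow_core/heatmap_visualization@v1",
            "name": "heatmap",
            "image": "$inputs.image",
            "predictions": "$steps.detector.predictions",
            "metadata": "$inputs.video_metadata",
        },
    ],
    "outputs": [
        {"type": "JsonField", "name": "heatmap", "selector": "$steps.heatmap.image"},
    ],
}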

inference/core/workflows/core_steps/visualizations/heatmap/__init__.py

Whitespace-only changes.

inference/core/workflows/core_steps/visualizations/heatmap/v1.py

Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
import time
from typing import Dict, Literal, Optional, Type, Union

import numpy as np
import supervision as sv
from pydantic import ConfigDict, Field

from inference.core.workflows.core_steps.visualizations.common.base import (
    OUTPUT_IMAGE_KEY,
    PredictionsVisualizationBlock,
    PredictionsVisualizationManifest,
)
from inference.core.workflows.execution_engine.entities.base import (
    VideoMetadata,
    WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
    BOOLEAN_KIND,
    FLOAT_KIND,
    INTEGER_KIND,
    STRING_KIND,
    VIDEO_METADATA_KIND,
    Selector,
)
from inference.core.workflows.prototypes.block import BlockResult, WorkflowBlockManifest

TYPE: str = "roboflow_core/heatmap_visualization@v1"
SHORT_DESCRIPTION = "Draw a heatmap based on detections in an image."
LONG_DESCRIPTION = """
Draw heatmaps on an image based on provided detections. Heat accumulates over time and is drawn as a semi-transparent overlay of blurred circles.

## How This Block Works

This block takes an image and detection predictions and draws a heatmap. The block:

1. Takes an image and predictions as input.
2. Accumulates heat based on the position of detections.
3. Draws a semi-transparent overlay of blurred circles representing the heat.

## Common Use Cases

- **Density Analysis**: Visualize the density of objects in a scene.
- **Traffic Monitoring**: Identify high-traffic areas.
- **Retail Analytics**: Analyze foot traffic in stores.
"""


class HeatmapManifest(PredictionsVisualizationManifest):
    type: Literal[f"{TYPE}", "HeatmapVisualization"]
    model_config = ConfigDict(
        json_schema_extra={
            "name": "Heatmap Visualization",
            "version": "v1",
            "short_description": SHORT_DESCRIPTION,
            "long_description": LONG_DESCRIPTION,
            "license": "Apache-2.0",
            "block_type": "visualization",
            "search_keywords": ["annotator", "heatmap"],
            "ui_manifest": {
                "section": "visualization",
                "icon": "fas fa-fire",
                "blockPriority": 4,
                "supervision": True,
                "warnings": [
                    {
                        "property": "copy_image",
                        "value": False,
                        "message": "This setting will mutate its input image. If the input is used by other blocks, it may cause unexpected behavior.",
                    }
                ],
            },
        }
    )

    metadata: Selector(kind=[VIDEO_METADATA_KIND]) = Field(
        description="Video metadata containing video_identifier to maintain separate state for different videos.",
        default=None,
        examples=["$inputs.video_metadata"],
    )

    position: Union[
        Literal[
            "CENTER",
            "CENTER_LEFT",
            "CENTER_RIGHT",
            "TOP_CENTER",
            "TOP_LEFT",
            "TOP_RIGHT",
            "BOTTOM_CENTER",
            "BOTTOM_LEFT",
            "BOTTOM_RIGHT",
        ],
        Selector(kind=[STRING_KIND]),
    ] = Field(  # type: ignore
        default="BOTTOM_CENTER",
        description="The anchor point on each detection at which heat is accumulated.",
        examples=["BOTTOM_CENTER", "$inputs.position"],
    )

    opacity: Union[float, Selector(kind=[FLOAT_KIND])] = Field(  # type: ignore
        description="Opacity of the overlay mask, between 0 and 1.",
        default=0.2,
        examples=[0.2, "$inputs.opacity"],
    )

    radius: Union[int, Selector(kind=[INTEGER_KIND])] = Field(  # type: ignore
        description="Radius of the heat circle.",
        default=40,
        examples=[40, "$inputs.radius"],
    )

    kernel_size: Union[int, Selector(kind=[INTEGER_KIND])] = Field(  # type: ignore
        description="Kernel size for blurring the heatmap.",
        default=25,
        examples=[25, "$inputs.kernel_size"],
    )

    top_hue: Union[int, Selector(kind=[INTEGER_KIND])] = Field(  # type: ignore
        description="Hue at the top of the heatmap. Defaults to 0 (red).",
        default=0,
        examples=[0, "$inputs.top_hue"],
    )

    low_hue: Union[int, Selector(kind=[INTEGER_KIND])] = Field(  # type: ignore
        description="Hue at the bottom of the heatmap. Defaults to 125 (blue).",
        default=125,
        examples=[125, "$inputs.low_hue"],
    )

    ignore_stationary: Union[bool, Selector(kind=[BOOLEAN_KIND])] = Field(  # type: ignore
        description="If True, only moving objects (based on tracker ID) will contribute to the heatmap.",
        default=True,
        examples=[True, "$inputs.ignore_stationary"],
    )

    motion_threshold: Union[int, Selector(kind=[INTEGER_KIND])] = Field(  # type: ignore
        description="Minimum movement in pixels required to consider an object as moving.",
        default=25,
        examples=[25, "$inputs.motion_threshold"],
    )

    @classmethod
    def get_execution_engine_compatibility(cls) -> Optional[str]:
        return ">=1.3.0,<2.0.0"


class HeatmapVisualizationBlockV1(PredictionsVisualizationBlock):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.annotatorCache = {}
        # Dictionary to store track history: {video_id: {tracker_id: (x, y, timestamp)}}
        self._track_history: Dict[str, Dict[int, tuple]] = {}
        self._last_cleanup_time = time.time()
        self._cleanup_interval = 10.0  # seconds

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        return HeatmapManifest

    def _cleanup_history(self):
        current_time = time.time()
        if current_time - self._last_cleanup_time < self._cleanup_interval:
            return

        # Clean up stale trackers (e.g., older than 60s)
        # Using 60s as a conservative estimate for ~1800 frames at 30fps
        stale_threshold = 60.0
        empty_videos = []

        for video_id, history in self._track_history.items():
            expired_trackers = [
                tid
                for tid, data in history.items()
                if current_time - data[2] > stale_threshold
            ]
            for tid in expired_trackers:
                del history[tid]

            if not history:
                empty_videos.append(video_id)

        # Clean up empty video histories
        for video_id in empty_videos:
            del self._track_history[video_id]

        self._last_cleanup_time = current_time

    def getAnnotator(
        self,
        video_id: str,
        position: str,
        opacity: float,
        radius: int,
        kernel_size: int,
        top_hue: int,
        low_hue: int,
    ) -> sv.annotators.base.BaseAnnotator:
        key = "_".join(
            map(
                str,
                [
                    video_id,
                    position,
                    opacity,
                    radius,
                    kernel_size,
                    top_hue,
                    low_hue,
                ],
            )
        )

        if key not in self.annotatorCache:
            position_enum = getattr(sv.Position, position)
            self.annotatorCache[key] = sv.HeatMapAnnotator(
                position=position_enum,
                opacity=opacity,
                radius=radius,
                kernel_size=kernel_size,
                top_hue=top_hue,
                low_hue=low_hue,
            )

        return self.annotatorCache[key]

    def run(
        self,
        image: WorkflowImageData,
        predictions: sv.Detections,
        copy_image: bool,
        position: Optional[str],
        opacity: Optional[float],
        radius: Optional[int],
        kernel_size: Optional[int],
        top_hue: Optional[int],
        low_hue: Optional[int],
        metadata: Optional[VideoMetadata] = None,
        ignore_stationary: bool = True,
        motion_threshold: int = 25,
    ) -> BlockResult:
        self._cleanup_history()
        detections_to_plot = predictions
        video_id = metadata.video_identifier if metadata else "default_video"

        if ignore_stationary and predictions.tracker_id is not None:
            if video_id not in self._track_history:
                self._track_history[video_id] = {}

            current_history = self._track_history[video_id]
            moving_indices = []
            current_time = time.time()

            # Calculate centers for current detections
            # Use the specified position anchor for tracking consistency
            anchor_position = (
                getattr(sv.Position, position)
                if position
                else sv.Position.BOTTOM_CENTER
            )
            anchors = predictions.get_anchors_coordinates(anchor=anchor_position)

            for i, (tracker_id, point) in enumerate(
                zip(predictions.tracker_id, anchors)
            ):
                tracker_id = int(tracker_id)
                x, y = point

                if tracker_id in current_history:
                    # Check for movement
                    prev_x, prev_y, _ = current_history[tracker_id]
                    dist = np.sqrt((x - prev_x) ** 2 + (y - prev_y) ** 2)

                    if dist >= motion_threshold:
                        moving_indices.append(i)
                        # Update history with new position and timestamp
                        current_history[tracker_id] = (x, y, current_time)
                else:
                    # New track, initialize history
                    current_history[tracker_id] = (x, y, current_time)

            # Filter detections
            if len(moving_indices) > 0:
                detections_to_plot = predictions[np.array(moving_indices)]
            else:
                detections_to_plot = sv.Detections.empty()

        annotator = self.getAnnotator(
            video_id,
            position,
            opacity,
            radius,
            kernel_size,
            top_hue,
            low_hue,
        )
        annotated_image = annotator.annotate(
            scene=image.numpy_image.copy() if copy_image else image.numpy_image,
            detections=detections_to_plot,
        )
        return {
            OUTPUT_IMAGE_KEY: WorkflowImageData.copy_and_replace(
                origin_image_data=image, numpy_image=annotated_image
            )
        }
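
The stationary-object filter in run() is the part of this file that goes beyond supervision's built-in annotator: it keeps a per-video map from tracker ID to last anchor point and lets a detection contribute heat only after it has moved at least motion_threshold pixels. A standalone sketch of the same idea on synthetic data follows (the tracker IDs and coordinates are made up for illustration):

import numpy as np

# Minimal sketch of the motion-threshold filter used above.
MOTION_THRESHOLD = 25  # pixels

history: dict = {}  # tracker_id -> last recorded (x, y)

def moving_indices(tracker_ids, anchors):
    """Return indices of detections that moved >= MOTION_THRESHOLD px
    since their last recorded position; new tracks are recorded but
    contribute no heat on their first frame."""
    keep = []
    for i, (tid, (x, y)) in enumerate(zip(tracker_ids, anchors)):
        tid = int(tid)
        if tid in history:
            px, py = history[tid]
            if np.hypot(x - px, y - py) >= MOTION_THRESHOLD:
                keep.append(i)
                history[tid] = (x, y)  # only update on real movement
        else:
            history[tid] = (x, y)
    return keep

# Frame 1: both tracks are new -> nothing contributes heat yet.
print(moving_indices(np.array([1, 2]), np.array([[100.0, 50.0], [200.0, 80.0]])))  # []
# Frame 2: track 1 moved 30 px, track 2 moved 5 px -> only track 1 kept.
print(moving_indices(np.array([1, 2]), np.array([[130.0, 50.0], [205.0, 80.0]])))  # [0]

Note that the last position is only updated when movement crosses the threshold, so a slowly drifting object still registers once its cumulative drift exceeds motion_threshold; the block behaves the same way.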

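The overlay itself is delegated to sv.HeatMapAnnotator, which accumulates its heat mask on the annotator instance between annotate() calls; that is why getAnnotator caches one instance per video-and-settings key instead of constructing a fresh annotator per frame. A small self-contained sketch of that accumulation on synthetic frames (the frames and detections here are made up):

import numpy as np
import supervision as sv

# Reuse one annotator per video stream, as the cache above does:
# heat accumulates on the instance across annotate() calls.
annotator = sv.HeatMapAnnotator(
    position=sv.Position.BOTTOM_CENTER,
    opacity=0.2,
    radius=40,
    kernel_size=25,
    top_hue=0,
    low_hue=125,
)

frame = np.zeros((480, 640, 3), dtype=np.uint8)
for x in (100, 140, 180):  # one box sliding right across three "frames"
    detections = sv.Detections(xyxy=np.array([[x, 200, x + 40, 260]], dtype=float))
    out = annotator.annotate(scene=frame.copy(), detections=detections)
# `out` now shows heat along the whole path, not just the last position.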