Skip to content

Commit f5ece59

Browse files
talmoclaude
andauthored
Add ROI/mask support to Labels class (#68)
Add ROI and SegmentationMask integration to the Labels class. - New `rois` and `masks` array fields with empty defaults - `staticRois` / `temporalRois` computed properties for filtering by frame scope - `getRois()` and `getMasks()` methods with AND-logic filtering by video, frameIdx, annotationType, category, track, and instance Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e28ff87 commit f5ece59

File tree

2 files changed

+234
-0
lines changed

2 files changed

+234
-0
lines changed

src/model/labels.ts

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ import { RecordingSession } from "./camera.js";
77
import { toDict } from "../codecs/dictionary.js";
88
import { labelsFromNumpy } from "../codecs/numpy.js";
99
import type { LazyDataStore, LazyFrameList } from "./lazy.js";
10+
import type { ROI } from "./roi.js";
11+
import type { SegmentationMask } from "./mask.js";
12+
import type { AnnotationType } from "./roi.js";
1013

1114
export class Labels {
1215
labeledFrames: LabeledFrame[];
@@ -16,6 +19,8 @@ export class Labels {
1619
suggestions: SuggestionFrame[];
1720
sessions: RecordingSession[];
1821
provenance: Record<string, unknown>;
22+
rois: ROI[];
23+
masks: SegmentationMask[];
1924

2025
/** @internal Lazy frame list for on-demand materialization. */
2126
_lazyFrameList: LazyFrameList | null = null;
@@ -30,6 +35,8 @@ export class Labels {
3035
suggestions?: SuggestionFrame[];
3136
sessions?: RecordingSession[];
3237
provenance?: Record<string, unknown>;
38+
rois?: ROI[];
39+
masks?: SegmentationMask[];
3340
}) {
3441
this.labeledFrames = options?.labeledFrames ?? [];
3542
this.videos = options?.videos ?? [];
@@ -38,6 +45,8 @@ export class Labels {
3845
this.suggestions = options?.suggestions ?? [];
3946
this.sessions = options?.sessions ?? [];
4047
this.provenance = options?.provenance ?? {};
48+
this.rois = options?.rois ?? [];
49+
this.masks = options?.masks ?? [];
4150

4251
if (!this.videos.length && this.labeledFrames.length) {
4352
const uniqueVideos = new Map<string | Video, Video>();
@@ -137,6 +146,76 @@ export class Labels {
137146
return toDict(this, options);
138147
}
139148

149+
get staticRois(): ROI[] {
150+
return this.rois.filter((roi) => roi.isStatic);
151+
}
152+
153+
get temporalRois(): ROI[] {
154+
return this.rois.filter((roi) => !roi.isStatic);
155+
}
156+
157+
getRois(filters?: {
158+
video?: Video;
159+
frameIdx?: number;
160+
annotationType?: AnnotationType;
161+
category?: string;
162+
track?: Track;
163+
instance?: Instance | PredictedInstance;
164+
}): ROI[] {
165+
if (!filters) return [...this.rois];
166+
let results = this.rois;
167+
if (filters.video !== undefined) {
168+
results = results.filter((r) => r.video === filters.video);
169+
}
170+
if (filters.frameIdx !== undefined) {
171+
results = results.filter((r) => r.frameIdx === filters.frameIdx);
172+
}
173+
if (filters.annotationType !== undefined) {
174+
results = results.filter((r) => r.annotationType === filters.annotationType);
175+
}
176+
if (filters.category !== undefined) {
177+
results = results.filter((r) => r.category === filters.category);
178+
}
179+
if (filters.track !== undefined) {
180+
results = results.filter((r) => r.track === filters.track);
181+
}
182+
if (filters.instance !== undefined) {
183+
results = results.filter((r) => r.instance === filters.instance);
184+
}
185+
return results;
186+
}
187+
188+
getMasks(filters?: {
189+
video?: Video;
190+
frameIdx?: number;
191+
annotationType?: AnnotationType;
192+
category?: string;
193+
track?: Track;
194+
instance?: Instance | PredictedInstance;
195+
}): SegmentationMask[] {
196+
if (!filters) return [...this.masks];
197+
let results = this.masks;
198+
if (filters.video !== undefined) {
199+
results = results.filter((m) => m.video === filters.video);
200+
}
201+
if (filters.frameIdx !== undefined) {
202+
results = results.filter((m) => m.frameIdx === filters.frameIdx);
203+
}
204+
if (filters.annotationType !== undefined) {
205+
results = results.filter((m) => m.annotationType === filters.annotationType);
206+
}
207+
if (filters.category !== undefined) {
208+
results = results.filter((m) => m.category === filters.category);
209+
}
210+
if (filters.track !== undefined) {
211+
results = results.filter((m) => m.track === filters.track);
212+
}
213+
if (filters.instance !== undefined) {
214+
results = results.filter((m) => m.instance === filters.instance);
215+
}
216+
return results;
217+
}
218+
140219
static fromNumpy(
141220
data: number[][][][],
142221
options: { videos?: Video[]; video?: Video; skeletons?: Skeleton[] | Skeleton; skeleton?: Skeleton; trackNames?: string[]; firstFrame?: number; returnConfidence?: boolean }

tests/labels-rois-masks.test.ts

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/* @vitest-environment node */
2+
import { describe, it, expect } from "vitest";
3+
import {
4+
Labels,
5+
Video,
6+
Skeleton,
7+
Instance,
8+
Track,
9+
ROI,
10+
SegmentationMask,
11+
AnnotationType,
12+
} from "../src/index.js";
13+
14+
describe("Labels ROI and Mask integration", () => {
15+
it("stores rois and masks", () => {
16+
const video = new Video({ filename: "test.mp4" });
17+
const roi = ROI.fromBbox(10, 20, 100, 200, { video });
18+
const mask = SegmentationMask.fromArray(new Uint8Array(16), 4, 4, { video });
19+
const labels = new Labels({ videos: [video], rois: [roi], masks: [mask] });
20+
expect(labels.rois).toHaveLength(1);
21+
expect(labels.masks).toHaveLength(1);
22+
expect(labels.rois[0]).toBe(roi);
23+
expect(labels.masks[0]).toBe(mask);
24+
});
25+
26+
it("defaults rois and masks to empty arrays", () => {
27+
const labels = new Labels();
28+
expect(labels.rois).toEqual([]);
29+
expect(labels.masks).toEqual([]);
30+
});
31+
32+
it("filters staticRois and temporalRois", () => {
33+
const video = new Video({ filename: "test.mp4" });
34+
const staticRoi = ROI.fromBbox(0, 0, 10, 10, { video });
35+
const temporalRoi = ROI.fromBbox(0, 0, 20, 20, { video, frameIdx: 5 });
36+
const labels = new Labels({ rois: [staticRoi, temporalRoi] });
37+
38+
expect(labels.staticRois).toHaveLength(1);
39+
expect(labels.staticRois[0]).toBe(staticRoi);
40+
expect(labels.temporalRois).toHaveLength(1);
41+
expect(labels.temporalRois[0]).toBe(temporalRoi);
42+
});
43+
44+
it("getRois filters by video", () => {
45+
const v1 = new Video({ filename: "a.mp4" });
46+
const v2 = new Video({ filename: "b.mp4" });
47+
const roi1 = ROI.fromBbox(0, 0, 10, 10, { video: v1 });
48+
const roi2 = ROI.fromBbox(0, 0, 10, 10, { video: v2 });
49+
const labels = new Labels({ videos: [v1, v2], rois: [roi1, roi2] });
50+
51+
expect(labels.getRois({ video: v1 })).toEqual([roi1]);
52+
expect(labels.getRois({ video: v2 })).toEqual([roi2]);
53+
});
54+
55+
it("getRois filters by frameIdx", () => {
56+
const roi1 = ROI.fromBbox(0, 0, 10, 10, { frameIdx: 0 });
57+
const roi2 = ROI.fromBbox(0, 0, 10, 10, { frameIdx: 5 });
58+
const roi3 = ROI.fromBbox(0, 0, 10, 10);
59+
const labels = new Labels({ rois: [roi1, roi2, roi3] });
60+
61+
expect(labels.getRois({ frameIdx: 0 })).toEqual([roi1]);
62+
expect(labels.getRois({ frameIdx: 5 })).toEqual([roi2]);
63+
});
64+
65+
it("getRois filters by category", () => {
66+
const roi1 = ROI.fromBbox(0, 0, 10, 10, { category: "animal" });
67+
const roi2 = ROI.fromBbox(0, 0, 10, 10, { category: "arena" });
68+
const labels = new Labels({ rois: [roi1, roi2] });
69+
70+
expect(labels.getRois({ category: "animal" })).toEqual([roi1]);
71+
expect(labels.getRois({ category: "arena" })).toEqual([roi2]);
72+
});
73+
74+
it("getRois filters by annotationType", () => {
75+
const roi1 = ROI.fromBbox(0, 0, 10, 10);
76+
const roi2 = ROI.fromPolygon([[0, 0], [10, 0], [10, 10], [0, 10]]);
77+
const labels = new Labels({ rois: [roi1, roi2] });
78+
79+
expect(labels.getRois({ annotationType: AnnotationType.BOUNDING_BOX })).toEqual([roi1]);
80+
expect(labels.getRois({ annotationType: AnnotationType.SEGMENTATION })).toEqual([roi2]);
81+
});
82+
83+
it("getRois filters by track and instance", () => {
84+
const skeleton = new Skeleton({ nodes: ["A"] });
85+
const track = new Track({ name: "track1" });
86+
const inst = new Instance({ points: { A: [1, 2] }, skeleton, track });
87+
const roi1 = ROI.fromBbox(0, 0, 10, 10, { track, instance: inst });
88+
const roi2 = ROI.fromBbox(0, 0, 10, 10);
89+
const labels = new Labels({ rois: [roi1, roi2] });
90+
91+
expect(labels.getRois({ track })).toEqual([roi1]);
92+
expect(labels.getRois({ instance: inst })).toEqual([roi1]);
93+
});
94+
95+
it("getRois with combined filters uses AND logic", () => {
96+
const v1 = new Video({ filename: "a.mp4" });
97+
const roi1 = ROI.fromBbox(0, 0, 10, 10, { video: v1, category: "animal", frameIdx: 0 });
98+
const roi2 = ROI.fromBbox(0, 0, 10, 10, { video: v1, category: "arena", frameIdx: 0 });
99+
const roi3 = ROI.fromBbox(0, 0, 10, 10, { video: v1, category: "animal", frameIdx: 5 });
100+
const labels = new Labels({ rois: [roi1, roi2, roi3] });
101+
102+
const result = labels.getRois({ video: v1, category: "animal", frameIdx: 0 });
103+
expect(result).toEqual([roi1]);
104+
});
105+
106+
it("getMasks filters by frameIdx", () => {
107+
const m1 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, { frameIdx: 0 });
108+
const m2 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, { frameIdx: 3 });
109+
const labels = new Labels({ masks: [m1, m2] });
110+
111+
expect(labels.getMasks({ frameIdx: 0 })).toEqual([m1]);
112+
expect(labels.getMasks({ frameIdx: 3 })).toEqual([m2]);
113+
});
114+
115+
it("getMasks filters by category", () => {
116+
const m1 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, { category: "bg" });
117+
const m2 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, { category: "fg" });
118+
const labels = new Labels({ masks: [m1, m2] });
119+
120+
expect(labels.getMasks({ category: "bg" })).toEqual([m1]);
121+
});
122+
123+
it("getMasks filters by video", () => {
124+
const v1 = new Video({ filename: "a.mp4" });
125+
const v2 = new Video({ filename: "b.mp4" });
126+
const m1 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, { video: v1 });
127+
const m2 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, { video: v2 });
128+
const labels = new Labels({ masks: [m1, m2] });
129+
130+
expect(labels.getMasks({ video: v1 })).toEqual([m1]);
131+
expect(labels.getMasks({ video: v2 })).toEqual([m2]);
132+
});
133+
134+
it("getMasks filters by track", () => {
135+
const track = new Track({ name: "t1" });
136+
const m1 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, { track });
137+
const m2 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2);
138+
const labels = new Labels({ masks: [m1, m2] });
139+
140+
expect(labels.getMasks({ track })).toEqual([m1]);
141+
});
142+
143+
it("getMasks filters by annotationType", () => {
144+
const m1 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, {
145+
annotationType: AnnotationType.SEGMENTATION,
146+
});
147+
const m2 = SegmentationMask.fromArray(new Uint8Array(4), 2, 2, {
148+
annotationType: AnnotationType.ARENA,
149+
});
150+
const labels = new Labels({ masks: [m1, m2] });
151+
152+
expect(labels.getMasks({ annotationType: AnnotationType.SEGMENTATION })).toEqual([m1]);
153+
expect(labels.getMasks({ annotationType: AnnotationType.ARENA })).toEqual([m2]);
154+
});
155+
});

0 commit comments

Comments
 (0)