Commit 2cbc925

add a simple sim-eval metric assessment script

1 parent de0d5b8 commit 2cbc925

File tree: 5 files changed, 1233 additions & 1180 deletions

README.md

Lines changed: 4 additions & 1 deletion

```diff
@@ -20,6 +20,7 @@ We hope that our work guides and inspires future real-to-sim evaluation efforts.
 - [Examples](#examples)
 - [Current Environments](#current-environments)
 - [Customizing Evaluation Configs](#customizing-evaluation-configs)
+- [Metrics for Assessing the Effectiveness of Simulated Evaluation Pipelines](#metrics-for-assessing-the-effectiveness-of-simulated-evaluation-pipelines)
 - [Code Structure](#code-structure)
 - [Adding New Policies](#adding-new-policies)
 - [Adding New Real-to-Sim Evaluation Environments and Robots](#adding-new-real-to-sim-evaluation-environments-and-robots)
@@ -131,7 +132,9 @@ By default, Google Robot environments use a control frequency of 3hz, and Bridge
 
 Please see `scripts/` for examples of how to customize evaluation configs. The inference script `simpler_env/main_inference.py` supports advanced environment building and logging. For example, you can perform a sweep over object and robot poses for evaluation. (Note, however, varying robot poses is not meaningful under the visual matching evaluation setup.)
 
+## Metrics for Assessing the Effectiveness of Simulated Evaluation Pipelines
 
+In our paper, we use the Mean Maximum Rank Violation (MMRV) metric and the Pearson Correlation Coefficient metric to assess the correlation between real and simulated evaluation results. You can reproduce the metrics in `tools/calc_metrics.py` and assess your own real-to-sim evaluation pipeline.
 
 ## Code Structure
 
@@ -165,7 +168,7 @@ simpler_env/
 tools/
   robot_object_visualization/: tools for visualizing robots and objects when creating new environments
   sysid/: tools for system identification when adding new robots
-  calc_metrics.py: tools for summarizing eval results and calculating metrics, such as Normalized Rank Loss, Pearson Correlation, and Kruskal-Wallis test, to reproduce our paper results
+  calc_metrics.py: tools for summarizing eval results and calculating metrics, such as Mean Maximum Rank Violation (MMRV) and Pearson Correlation
   coacd_process_mesh.py: tools for generating convex collision meshes through CoACD when adding new assets
   merge_videos.py: tools for merging videos into one
   ...
```
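The two metrics named in this README addition are implemented in `simpler_env/utils/metrics.py` (diffed below). As a quick illustration of how they behave, here is a minimal sketch of calling them; the import path follows this repository's layout, and the success rates are hypothetical numbers chosen so that exactly one pair of checkpoints is ranked differently in sim and real.

```python
# A minimal, hypothetical example of the two metrics added in this commit.
# The success rates below are made up; only the ranking of checkpoints 1
# and 3 differs between sim and real.
from simpler_env.utils.metrics import (
    mean_maximum_rank_violation,
    pearson_correlation,
)

perf_sim = [0.82, 0.55, 0.30, 0.71]   # simulated success rate per checkpoint
perf_real = [0.78, 0.60, 0.25, 0.52]  # real-robot success rate per checkpoint

# Pearson correlation: closer to 1 means sim and real performance co-vary.
print(pearson_correlation(perf_sim, perf_real))

# MMRV: for each checkpoint, take the largest real-performance gap over all
# pairs whose sim and real orderings disagree, then average over checkpoints.
# Here the only flipped pair is (1, 3) with gap |0.60 - 0.52| = 0.08, so
# MMRV = (0 + 0.08 + 0 + 0.08) / 4 = 0.04; 0 would mean no rank violations.
print(mean_maximum_rank_violation(perf_sim, perf_real))
```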

simpler_env/utils/metrics.py

Lines changed: 33 additions & 24 deletions

```diff
@@ -1,49 +1,52 @@
 import glob
 from pathlib import Path
+from typing import Sequence, Optional
 
 import numpy as np
-from scipy.stats import kruskal
 
 
-def pearson_correlation(x, y):
-    x, y = np.array(x), np.array(y)
-    assert x.shape == y.shape
-    x = x - np.mean(x)
-    y = y - np.mean(y)
-    if np.all(x == y):
+def pearson_correlation(perf_sim: Sequence[float], perf_real: Sequence[float]) -> float:
+    perf_sim, perf_real = np.array(perf_sim), np.array(perf_real)
+    assert perf_sim.shape == perf_real.shape
+    perf_sim = perf_sim - np.mean(perf_sim)
+    perf_real = perf_real - np.mean(perf_real)
+    if np.all(perf_sim == perf_real):
         pearson = 1
     else:
-        pearson = np.sum(x * y) / (np.sqrt(np.sum(x**2) * np.sum(y**2)) + 1e-8)
+        pearson = np.sum(perf_sim * perf_real) / (
+            np.sqrt(np.sum(perf_sim**2) * np.sum(perf_real**2)) + 1e-8
+        )
     return pearson
 
 
-def mean_maximum_rank_violation(x, y):
-    # assuming x is sim result and y is real result
-    x, y = np.array(x), np.array(y)
-    assert x.shape == y.shape
+def mean_maximum_rank_violation(
+    perf_sim: Sequence[float], perf_real: Sequence[float]
+) -> float:
+    perf_sim, perf_real = np.array(perf_sim), np.array(perf_real)
+    assert perf_sim.shape == perf_real.shape
     rank_violations = []
-    for i in range(len(x)):
+    for i in range(len(perf_sim)):
         rank_violation = 0.0
-        for j in range(len(x)):
-            if (x[i] > x[j]) != (y[i] > y[j]):
-                rank_violation = max(rank_violation, np.abs(y[i] - y[j]))
+        for j in range(len(perf_sim)):
+            if (perf_sim[i] > perf_sim[j]) != (perf_real[i] > perf_real[j]):
+                rank_violation = max(
+                    rank_violation, np.abs(perf_real[i] - perf_real[j])
+                )
         rank_violations.append(rank_violation)
     rank_violation = np.mean(rank_violations)
-    # rank_violation = 0.0
-    # for i in range(len(x) - 1):
-    #     for j in range(i + 1, len(x)):
-    #         if (x[i] > x[j]) != (y[i] > y[j]):
-    #             rank_violation = max(rank_violation, np.abs(y[i] - y[j]))
     return rank_violation
 
 
-def print_all_kruskal_results(sim, real, title):
+def print_all_kruskal_results(
+    sim: Sequence[Sequence[float]], real: Sequence[Sequence[float]], title: str
+) -> None:
     """
     sim, real: shape [n_ckpt, n_trials]
     The trial-by-trial success indicator of each checkpoint
     (within each checkpoint, the ordering doesn't matter)
     Prints out the Kruskal-Wallis test for each checkpoint
     """
+    from scipy.stats import kruskal
     sim, real = np.array(sim), np.array(real)
     assert sim.shape == real.shape
     print(title)
@@ -57,7 +60,9 @@ def print_all_kruskal_results(sim, real, title):
         print(" " * 12, kruskal(sim[i], real[i]))
 
 
-def construct_unordered_trial_results(n_trials_per_ckpt, success):
+def construct_unordered_trial_results(
+    n_trials_per_ckpt: int, success: Sequence[float]
+) -> np.ndarray:
     success = np.array(success)
     success = np.where(np.isnan(success), 0, success)
     n_success_trials = np.round(n_trials_per_ckpt * success).astype(np.int32)
@@ -68,7 +73,11 @@ def construct_unordered_trial_results(n_trials_per_ckpt, success):
 
 
 # util to get success / failure results from a directory
-def get_dir_stats(dir_name, extra_pattern_require=[], succ_fail_pattern=["success", "failure"]):
+def get_dir_stats(
+    dir_name: str,
+    extra_pattern_require: Optional[Sequence[str]] = [],
+    succ_fail_pattern: Sequence[str] = ["success", "failure"],
+) -> Sequence[int]:
     if dir_name[-1] == "/":
         dir_name = dir_name[:-1]
 
```

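For the remaining helpers in this file, here is a hedged end-to-end sketch. The directory path, trial count, and real-robot success rate are hypothetical, and it assumes, based on the signatures and docstring above, that `get_dir_stats` returns one 0/1 outcome per result file and that `construct_unordered_trial_results` returns a `[n_ckpt, n_trials]` array.

```python
# Hypothetical end-to-end sketch for the remaining helpers; the directory
# path, trial count, and real-robot success rate are illustrative only.
import numpy as np

from simpler_env.utils.metrics import (
    construct_unordered_trial_results,
    get_dir_stats,
    print_all_kruskal_results,
)

# 1) Read per-trial 0/1 outcomes from a results directory whose file names
#    contain "success" or "failure" (hypothetical path).
sim_outcomes = get_dir_stats("results/sim/ckpt_a")
sim_rate = float(np.mean(sim_outcomes))
real_rate = 0.6  # real-robot success rate for the same checkpoint

# 2) Expand aggregate success rates into unordered 0/1 trial arrays of shape
#    [n_ckpt, n_trials], the format print_all_kruskal_results expects.
n_trials_per_ckpt = 24
sim_trials = construct_unordered_trial_results(n_trials_per_ckpt, [sim_rate])
real_trials = construct_unordered_trial_results(n_trials_per_ckpt, [real_rate])

# 3) Kruskal-Wallis test per checkpoint: a large p-value is consistent with
#    sim and real trials being drawn from similar outcome distributions.
print_all_kruskal_results(sim_trials, real_trials, "ckpt_a: sim vs real")
```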
simpler_env/utils/visualization.py

Lines changed: 0 additions & 122 deletions

```diff
@@ -87,125 +87,3 @@ def plot_pred_and_gt_action_trajectory(predicted_actions, gt_actions, stacked_im
 
     plt.legend()
     plt.show()
-
-
-def colorize_mask(pred_mask: np.ndarray) -> np.ndarray:
-    """Colorize a predicted mask
-    :param pred_mask: [H, W] bool/np.uint8 np.ndarray
-    :return mask: colorized mask, [H, W, 3] np.uint8 np.ndarray
-    """
-    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
-    save_mask = save_mask.convert(mode="P")
-    save_mask.putpalette(_palette)
-    save_mask = save_mask.convert(mode="RGB")
-    return np.asarray(save_mask)
-
-
-def draw_mask(rgb_img, mask, alpha=0.5, id_countour=False) -> np.ndarray:
-    """Overlay predicted mask on rgb image
-    :param rgb_img: RGB image, [H, W, 3] np.uint8 np.ndarray
-    :param mask: [H, W] bool/np.uint8 np.ndarray
-    :param alpha: overlay transparency
-    :return img_mask: mask-overlayed image, [H, W, 3] np.uint8 np.ndarray
-    """
-    img_mask = rgb_img.copy()
-    if id_countour:
-        # very slow ~ 1s per image
-        obj_ids = np.unique(mask)
-        obj_ids = obj_ids[obj_ids != 0]
-
-        for id in obj_ids:
-            # Overlay color on binary mask
-            if id <= 255:
-                color = _palette[id * 3 : id * 3 + 3]
-            else:
-                color = [0, 0, 0]
-            foreground = rgb_img * (1 - alpha) + np.ones_like(rgb_img) * alpha * np.asarray(color)
-            binary_mask = mask == id
-
-            # Compose image
-            img_mask[binary_mask] = foreground[binary_mask]
-
-            countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
-            img_mask[countours, :] = 0
-    else:
-        binary_mask = mask != 0
-        countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
-        foreground = rgb_img * (1 - alpha) + colorize_mask(mask) * alpha
-        img_mask[binary_mask] = foreground[binary_mask]
-        img_mask[countours, :] = 0
-    return img_mask
-
-
-def draw_bbox(
-    rgb_image: np.ndarray,
-    labels: List[str],
-    bboxes: np.ndarray,
-    pred_indices: np.ndarray,
-    pred_scores: np.ndarray,
-    bbox_width=2,
-    text_size=25,
-    sort_by_score=True,
-) -> np.ndarray:
-    """Draw bbox predictions on rgb image
-
-    :param rgb_image: RGB image, [H, W, 3] np.uint8 np.ndarray
-    :param labels: list of label strings
-    :param bboxes: bbox as XYXY pixel coordinates, [n_bbox, 4] np.float32 np.ndarray
-    :param pred_indices: predicted label indices, [n_bbox,] integer np.ndarray
-    :param pred_scores: predicted scores, [n_bbox,] np.float32 np.ndarray
-    :param bbox_width: line width to draw bbox
-    :param text_size: text size to write predicted label
-    :param sort_by_score: plot bboxes with lower scores first
-        so bboxes with higher score are visible
-    :return out_image: rgb_image with drawn bboxes, [H, W, 3] np.uint8 np.ndarray
-    """
-    font = ImageFont.truetype(FONT_PATH, text_size)
-
-    H, W = rgb_image.shape[:2]
-    rgb_im = Image.fromarray(rgb_image).convert("RGBA")
-    # make a blank image for text, initialized to transparent text color
-    txt_im = Image.new("RGBA", rgb_im.size, (255, 255, 255, 0))
-    d = ImageDraw.Draw(txt_im)
-
-    if sort_by_score:
-        sorted_idx = pred_scores.argsort()
-        bboxes = bboxes[sorted_idx]
-        pred_indices = pred_indices[sorted_idx]
-        pred_scores = pred_scores[sorted_idx]
-
-    def _pad_bbox(bbox: Tuple[float], pad: float) -> Tuple[float]:
-        left, top, right, bottom = bbox
-        return (left - pad, top - pad, right + pad, bottom + pad)
-
-    for (x1, y1, x2, y2), pred_index, pred_score in zip(bboxes, pred_indices, pred_scores):
-        # draw bbox (left, top, right, bottom)
-        d.rectangle([x1, y1, x2, y2], fill=None, outline=(255, 0, 0), width=bbox_width)
-
-        # draw text
-        text = f"{labels[pred_index]}: {pred_score:1.2f}"
-        anchor_xy = [x1 + text_size * 0.1, y2 + text_size * 0.1 + 1]
-        anchor = "lt"
-        text_bbox = d.textbbox(anchor_xy, text, font=font, anchor=anchor)
-        text_bbox = _pad_bbox(text_bbox, text_size * 0.1)
-        if text_bbox[3] > H and text_bbox[2] > W:  # bottom-right
-            anchor_xy = [x2 - text_size * 0.1, y1 - text_size * 0.1 - 1]
-            anchor = "rb"
-            text_bbox = d.textbbox(anchor_xy, text, font=font, anchor=anchor)
-            text_bbox = _pad_bbox(text_bbox, text_size * 0.1)
-        elif text_bbox[3] > H:  # bottom
-            anchor_xy = [x1 + text_size * 0.1, y1 - text_size * 0.1 - 1]
-            anchor = "lb"
-            text_bbox = d.textbbox(anchor_xy, text, font=font, anchor=anchor)
-            text_bbox = _pad_bbox(text_bbox, text_size * 0.1)
-        elif text_bbox[2] > W:  # right
-            anchor_xy = [x2 - text_size * 0.1, y2 + text_size * 0.1 + 1]
-            anchor = "rt"
-            text_bbox = d.textbbox(anchor_xy, text, font=font, anchor=anchor)
-            text_bbox = _pad_bbox(text_bbox, text_size * 0.1)
-        # draw text bbox (bg only)
-        d.rectangle(text_bbox, fill=(255, 255, 255), outline=None, width=1)
-        d.text(anchor_xy, text, fill=(0, 0, 0), font=font, anchor=anchor)
-
-    out_im = Image.alpha_composite(rgb_im, txt_im).convert("RGB")
-    return np.asarray(out_im).copy()  # copy makes it writable
```
