Skip to content

Commit f99d318

Browse files
committed
update eval metrics
1 parent bbcf4b4 commit f99d318

File tree

3 files changed

+266
-31
lines changed

3 files changed

+266
-31
lines changed

simpler_env/utils/metrics.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,23 @@ def pearson_correlation(x, y):
1717
return pearson
1818

1919

20-
def normalized_rank_loss(x, y):
20+
def mean_maximum_rank_violation(x, y):
2121
# assuming x is sim result and y is real result
2222
x, y = np.array(x), np.array(y)
2323
assert x.shape == y.shape
24-
rank_violation = 0.0
25-
for i in range(len(x) - 1):
26-
for j in range(i + 1, len(x)):
24+
rank_violations = []
25+
for i in range(len(x)):
26+
rank_violation = 0.0
27+
for j in range(len(x)):
2728
if (x[i] > x[j]) != (y[i] > y[j]):
2829
rank_violation = max(rank_violation, np.abs(y[i] - y[j]))
30+
rank_violations.append(rank_violation)
31+
rank_violation = np.mean(rank_violations)
32+
# rank_violation = 0.0
33+
# for i in range(len(x) - 1):
34+
# for j in range(i + 1, len(x)):
35+
# if (x[i] > x[j]) != (y[i] > y[j]):
36+
# rank_violation = max(rank_violation, np.abs(y[i] - y[j]))
2937
return rank_violation
3038

3139

0 commit comments

Comments
 (0)