Skip to content

Commit 446180c

Browse files
committed
feat: implement category-specific benchmark scores and fix deployment
1 parent 1e5b86f commit 446180c

File tree

8 files changed

+180
-250
lines changed

8 files changed

+180
-250
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ benchmarking.py
3131

3232
# Data
3333
cached_datasets/
34-
datasets/
34+
datasets/
35+
*_evaluation_results*.json

SoundCodec/dataset/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ def load_dataset(dataset_name):
88
# Fallback to loading from Hugging Face Hub
99
ds = hf_load_dataset(dataset_name)
1010
if isinstance(ds, dict):
11-
if "test" in ds:
12-
return ds["test"]
13-
if "validation" in ds:
14-
return ds["validation"]
15-
if "train" in ds:
16-
return ds["train"]
17-
# return the first split if none of the above are found
18-
return ds[list(ds.keys())[0]]
11+
from datasets import concatenate_datasets
12+
all_ds = []
13+
for split_name, split_ds in ds.items():
14+
# Add category column if it doesn't exist
15+
if 'category' not in split_ds.column_names:
16+
split_ds = split_ds.add_column('category', [split_name] * len(split_ds))
17+
all_ds.append(split_ds)
18+
return concatenate_datasets(all_ds).shuffle(seed=42)
1919
return ds

update_leaderboard.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
import json
22

3-
# Load benchmark results
4-
with open('._datasets_voidful_codec-superb-tiny_synth_evaluation_results_20251218_204458.json', 'r') as f:
3+
import glob
4+
import os
5+
6+
# Load latest benchmark results
7+
json_files = glob.glob('*codec-superb-tiny_synth_evaluation_results*.json')
8+
if not json_files:
9+
raise FileNotFoundError("No benchmark results found.")
10+
latest_file = max(json_files, key=os.path.getmtime)
11+
print(f"Loading results from {latest_file}")
12+
with open(latest_file, 'r') as f:
513
benchmark_results = json.load(f)
614

715
# Hardcoded BPS mapping (bitrate in kbps or as used in data.js)
@@ -47,16 +55,32 @@
4755

4856
new_results = {}
4957

50-
for model_name, metrics in benchmark_results.items():
58+
new_results = {}
59+
60+
for model_name, metrics_data in benchmark_results.items():
5161
entry = {
5262
'bps': bps_mapping.get(model_name, 0)
5363
}
54-
for m in metrics_to_include:
55-
val = metrics.get(m, 0)
56-
# Handle NaN
57-
if val != val: # NaN check
58-
val = 0
59-
entry[m] = round(float(val), 3)
64+
65+
# Check if nested by category
66+
is_nested = any(isinstance(v, dict) for v in metrics_data.values())
67+
68+
if is_nested:
69+
for category, metrics in metrics_data.items():
70+
for m in metrics_to_include:
71+
val = metrics.get(m, 0)
72+
if val != val: # NaN check
73+
val = 0
74+
entry[f"{category.lower()}_{m}"] = round(float(val), 3)
75+
else:
76+
# Legacy format: metrics_data is {metric: value}
77+
# Map to 'overall' category by default
78+
for m in metrics_to_include:
79+
val = metrics_data.get(m, 0)
80+
if val != val: # NaN check
81+
val = 0
82+
entry[f"overall_{m}"] = round(float(val), 3)
83+
6084
new_results[model_name] = entry
6185

6286
# Format as JavaScript

web/src/App.css

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515
z-index: 10;
1616
}
1717

18+
.leaderboard-note {
19+
margin-bottom: 1.5rem;
20+
color: var(--text-secondary);
21+
font-size: 0.9375rem;
22+
}
23+
24+
.leaderboard-note strong {
25+
color: var(--accent-color);
26+
font-weight: 600;
27+
}
28+
1829
.main-footer {
1930
padding: 3rem;
2031
text-align: center;

web/src/App.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ function App() {
3131
</Card>
3232

3333
<Card title="Leaderboard" delay={0.4}>
34+
<p className="leaderboard-note">Comparing performance across <strong>Speech</strong>, <strong>Audio</strong>, and <strong>Music</strong> categories.</p>
3435
<div className="results-section">
3536
<Leaderboard results={results} />
3637
</div>

web/src/Leaderboard.css

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,23 @@ th {
2020
font-weight: 600;
2121
color: var(--text-secondary);
2222
border-bottom: 1px solid var(--surface-border);
23+
border-right: 1px solid rgba(255, 255, 255, 0.05);
2324
cursor: pointer;
2425
transition: var(--transition-fast);
2526
white-space: nowrap;
2627
}
2728

29+
thead tr:first-child th {
30+
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
31+
background: rgba(255, 255, 255, 0.05);
32+
color: var(--text-primary);
33+
text-align: center;
34+
font-size: 0.75rem;
35+
letter-spacing: 0.1em;
36+
text-transform: uppercase;
37+
padding: 0.5rem 1rem;
38+
}
39+
2840
th:hover {
2941
background: rgba(255, 255, 255, 0.06);
3042
color: var(--text-primary);

web/src/Leaderboard.js

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,41 @@ const Leaderboard = ({ results }) => {
1616

1717
const columns = React.useMemo(() => {
1818
const firstItem = results[Object.keys(results)[0]];
19-
return [
19+
const categories = ['Speech', 'Audio', 'Music', 'Overall'];
20+
const metrics_keys = ['mel', 'pesq', 'stoi', 'f0corr'];
21+
22+
const colGroups = [
2023
{
21-
Header: 'Model',
22-
accessor: 'model',
23-
},
24-
...Object.keys(firstItem).map(key => ({
25-
Header: key.toUpperCase(),
26-
accessor: key,
27-
sortType: 'basic',
28-
})),
24+
Header: 'Model Info',
25+
columns: [
26+
{ Header: 'Model', accessor: 'model' },
27+
{ Header: 'BPS', accessor: 'bps' }
28+
]
29+
}
2930
];
31+
32+
categories.forEach(cat => {
33+
const catColumns = metrics_keys.map(m => {
34+
const key = `${cat.toLowerCase()}_${m}`;
35+
if (key in firstItem || true) { // Force inclusion or check existence
36+
return {
37+
Header: m.toUpperCase(),
38+
accessor: key,
39+
sortType: 'basic',
40+
};
41+
}
42+
return null;
43+
}).filter(Boolean);
44+
45+
if (catColumns.length > 0) {
46+
colGroups.push({
47+
Header: cat.toUpperCase(),
48+
columns: catColumns
49+
});
50+
}
51+
});
52+
53+
return colGroups;
3054
}, [results]);
3155

3256
const {

0 commit comments

Comments
 (0)