Skip to content

Commit 424b0af

Browse files
updating the first iteration
1 parent 3b0ab27 commit 424b0af

File tree

2 files changed

+183
-40
lines changed

2 files changed

+183
-40
lines changed

docs/plot_dataset/colours.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,34 @@
2121
"Unknown": "#94a3b8",
2222
}
2323

24+
MODALITY_EMOJI = {
25+
"Visual": "👁️",
26+
"Auditory": "👂",
27+
"Sleep": "🌙",
28+
"Multisensory": "🧩",
29+
"Tactile": "✋",
30+
"Motor": "🏃",
31+
"Resting State": "🧘",
32+
"Rest": "🧘",
33+
"Other": "🧭",
34+
"Unknown": "❔",
35+
}
36+
37+
PATHOLOGY_PASTEL_OVERRIDES = {
38+
"Healthy": "#bbf7d0",
39+
"Unknown": "#d0d7df",
40+
"Dementia": "#fcd4d4",
41+
"Schizophrenia": "#f9d0e7",
42+
"Psychosis": "#f9d0e7",
43+
"Epilepsy": "#f9d7c4",
44+
"Parkinson's": "#f8c8c8",
45+
"TBI": "#f9cabd",
46+
"Surgery": "#f7d9b8",
47+
"Other": "#f8cbdc",
48+
"Clinical": "#f8d0d0",
49+
}
50+
51+
2452
TYPE_COLOR_MAP = {
2553
"Perception": "#3b82f6",
2654
"Decision-making": "#eab308",

docs/plot_dataset/treemap.py

Lines changed: 155 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414
CANONICAL_MAP,
1515
MODALITY_COLOR_MAP,
1616
PATHOLOGY_COLOR_MAP,
17+
PATHOLOGY_PASTEL_OVERRIDES,
18+
MODALITY_EMOJI,
1719
hex_to_rgba,
1820
)
1921
except ImportError: # pragma: no cover - fallback for direct script execution
2022
from colours import ( # type: ignore
2123
CANONICAL_MAP,
2224
MODALITY_COLOR_MAP,
2325
PATHOLOGY_COLOR_MAP,
26+
PATHOLOGY_PASTEL_OVERRIDES,
27+
MODALITY_EMOJI,
2428
hex_to_rgba,
2529
)
2630

@@ -36,18 +40,6 @@
3640
_SEPARATORS = ("/", "|", ";", ",")
3741
_DEFAULT_COLOR = "#94a3b8"
3842

39-
MODALITY_EMOJI = {
40-
"Visual": "👁️",
41-
"Auditory": "👂",
42-
"Sleep": "🌙",
43-
"Multisensory": "🧩",
44-
"Tactile": "✋",
45-
"Motor": "🏃",
46-
"Resting State": "🧘",
47-
"Rest": "🧘",
48-
"Other": "🧭",
49-
}
50-
5143

5244
def _tokenise_cell(value: object, column_key: str) -> list[str]:
5345
"""Split multi-valued cells, normalise, and keep Unknown buckets."""
@@ -160,6 +152,12 @@ def _abbreviate(value: float | int) -> str:
160152
if num == 0:
161153
return "0"
162154

155+
if abs(num) < 1000:
156+
rounded = round(num / 10.0) * 10.0
157+
if rounded == 0 and num > 0:
158+
rounded = 10.0
159+
return f"{int(rounded):,}"
160+
163161
thresholds = [
164162
(1_000_000_000, "B"),
165163
(1_000_000, "M"),
@@ -168,11 +166,50 @@ def _abbreviate(value: float | int) -> str:
168166
for divisor, suffix in thresholds:
169167
if abs(num) >= divisor:
170168
scaled = num / divisor
169+
scaled = round(scaled, 1)
171170
text = f"{scaled:.1f}".rstrip("0").rstrip(".")
172171
return f"{text}{suffix}"
173172
return f"{num:.0f}"
174173

175174

175+
def _lighten_hex(hex_color: str, factor: float = 0.55) -> str:
176+
if not isinstance(hex_color, str) or not hex_color.startswith("#"):
177+
return _DEFAULT_COLOR
178+
hex_color = hex_color.lstrip("#")
179+
if len(hex_color) != 6:
180+
return _DEFAULT_COLOR
181+
try:
182+
r = int(hex_color[0:2], 16)
183+
g = int(hex_color[2:4], 16)
184+
b = int(hex_color[4:6], 16)
185+
except ValueError:
186+
return _DEFAULT_COLOR
187+
r = int(r + (255 - r) * factor)
188+
g = int(g + (255 - g) * factor)
189+
b = int(b + (255 - b) * factor)
190+
return f"#{r:02x}{g:02x}{b:02x}"
191+
192+
193+
def _pathology_colors(name: str) -> tuple[str, str, str]:
194+
"""Return (fill_rgba, legend_hex, group_key)."""
195+
base_hex = PATHOLOGY_PASTEL_OVERRIDES.get(name)
196+
if not base_hex:
197+
fallback = PATHOLOGY_COLOR_MAP.get(name)
198+
if fallback:
199+
base_hex = _lighten_hex(fallback, 0.6)
200+
else:
201+
base_hex = PATHOLOGY_PASTEL_OVERRIDES.get("Clinical", _DEFAULT_COLOR)
202+
203+
fill = hex_to_rgba(base_hex, alpha=0.65)
204+
if name == "Healthy":
205+
group = "healthy"
206+
elif name == "Unknown":
207+
group = "unknown"
208+
else:
209+
group = "clinical"
210+
return fill, base_hex, group
211+
212+
176213
def _filter_zero_nodes(df: pd.DataFrame, column: str) -> pd.DataFrame:
177214
mask = (df["subjects"] > 0) | (df[column] == "Unknown")
178215
return df.loc[mask].copy()
@@ -198,14 +235,16 @@ def _format_label(
198235
elif fallback_value > 0:
199236
secondary_text = f"{_abbreviate(records_value)} rec"
200237
else:
201-
secondary_text = "0 h"
238+
secondary_text = "records unavailable"
202239
return (
203240
f"{name}<br><span style='font-size:{font_px}px;'>{subjects_text} subj"
204241
f" | {secondary_text}</span>"
205242
)
206243

207244

208-
def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
245+
def _build_nodes(
246+
dataset_level: pd.DataFrame,
247+
) -> tuple[list[dict[str, object]], list[dict[str, str]]]:
209248
dataset_level = dataset_level.sort_values(
210249
["population_type", "experimental_modality", "dataset_name"]
211250
).reset_index(drop=True)
@@ -229,7 +268,12 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
229268
level1 = _filter_zero_nodes(level1, "population_type")
230269

231270
nodes: list[dict[str, object]] = []
232-
level1_meta: list[dict[str, str]] = []
271+
legend_entries: list[dict[str, str]] = []
272+
seen_groups: set[str] = set()
273+
modality_meta: dict[str, dict[str, str]] = {}
274+
modality_priority = {
275+
name: idx for idx, name in enumerate(MODALITY_COLOR_MAP.keys())
276+
}
233277

234278
total_subjects = level1["subjects"].sum()
235279
total_hours = level1["hours"].sum()
@@ -267,19 +311,16 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
267311
row["hours_from_records"],
268312
font_px=16,
269313
)
270-
base_color = PATHOLOGY_COLOR_MAP.get(name)
271-
if not base_color:
272-
base_color = PATHOLOGY_COLOR_MAP.get("Clinical", _DEFAULT_COLOR)
273-
color = hex_to_rgba(base_color, alpha=0.75)
274-
level1_meta.append({"name": name, "color": base_color})
314+
fill_color, _, group = _pathology_colors(name)
315+
seen_groups.add(group)
275316
nodes.append(
276317
{
277318
"id": node_id,
278319
"parent": "EEG Dash datasets",
279320
"name": name,
280321
"text": label,
281322
"value": float(row["subjects"]),
282-
"color": color,
323+
"color": fill_color,
283324
"hover": label,
284325
}
285326
)
@@ -288,10 +329,10 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
288329
modality = row["experimental_modality"] or "Unknown"
289330
parent = row["population_type"] or "Unknown"
290331
node_id = f"{parent} / {modality}"
332+
emoji_symbol = MODALITY_EMOJI.get(modality)
291333
modality_label = modality
292-
emoji = MODALITY_EMOJI.get(modality)
293-
if emoji:
294-
modality_label = f"{emoji} {modality}"
334+
if emoji_symbol and row["subjects"] >= 120:
335+
modality_label = f"{emoji_symbol} {modality}"
295336
label = _format_label(
296337
modality_label,
297338
row["subjects"],
@@ -301,6 +342,18 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
301342
font_px=16,
302343
)
303344
color = MODALITY_COLOR_MAP.get(modality, _DEFAULT_COLOR)
345+
legend_label = (
346+
f"{(emoji_symbol + ' ') if emoji_symbol else ''}{modality}".strip()
347+
)
348+
if modality not in modality_meta:
349+
order = modality_priority.get(modality, len(modality_priority))
350+
modality_meta[modality] = {
351+
"name": legend_label,
352+
"color": color,
353+
"group": "level2",
354+
"order": 100 + order,
355+
"legendgroup": "modalities",
356+
}
304357
nodes.append(
305358
{
306359
"id": node_id,
@@ -327,7 +380,6 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
327380
row["hours_from_records"],
328381
font_px=16,
329382
)
330-
_ = row["population_type"] or "Unknown"
331383
if dataset_name == "Unknown":
332384
color = _DEFAULT_COLOR
333385
else:
@@ -344,7 +396,41 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
344396
}
345397
)
346398

347-
return nodes, level1_meta
399+
group_config = {
400+
"healthy": {
401+
"name": "Healthy",
402+
"color": PATHOLOGY_PASTEL_OVERRIDES.get("Healthy", "#bbf7d0"),
403+
"order": 0,
404+
},
405+
"clinical": {
406+
"name": "Clinical",
407+
"color": PATHOLOGY_PASTEL_OVERRIDES.get("Clinical", "#f8d0d0"),
408+
"order": 1,
409+
},
410+
"unknown": {
411+
"name": "To be Categorised",
412+
"color": PATHOLOGY_PASTEL_OVERRIDES.get("Unknown", "#d0d7df"),
413+
"order": 2,
414+
},
415+
}
416+
417+
for key, cfg in group_config.items():
418+
if key in seen_groups:
419+
legend_entries.append(
420+
{
421+
"name": cfg["name"],
422+
"color": cfg["color"],
423+
"group": "level1",
424+
"order": cfg["order"],
425+
"legendgroup": "populations",
426+
}
427+
)
428+
429+
legend_entries.extend(
430+
sorted(modality_meta.values(), key=lambda item: item["order"])
431+
)
432+
433+
return nodes, legend_entries
348434

349435

350436
def _build_figure(
@@ -355,6 +441,17 @@ def _build_figure(
355441
if not node_list:
356442
raise ValueError("No data available to render the treemap.")
357443

444+
legend_list = list(legend_entries)
445+
seen: set[str] = set()
446+
deduped: list[dict[str, str]] = []
447+
for entry in legend_list:
448+
if entry["name"] in seen:
449+
continue
450+
seen.add(entry["name"])
451+
deduped.append(entry)
452+
453+
deduped.sort(key=lambda item: item.get("order", 999))
454+
358455
fig = go.Figure(
359456
go.Treemap(
360457
ids=[node["id"] for node in node_list],
@@ -367,41 +464,58 @@ def _build_figure(
367464
marker=dict(
368465
colors=[node["color"] for node in node_list],
369466
line=dict(color="white", width=1),
370-
pad=dict(t=6, r=6, b=6, l=6),
467+
pad=dict(t=10, r=10, b=10, l=10),
371468
),
372469
textinfo="text",
373470
hovertemplate="%{customdata[0]}<extra></extra>",
374-
pathbar=dict(visible=True, edgeshape="/", thickness=34),
471+
pathbar=dict(
472+
visible=True, edgeshape="/", thickness=34, textfont=dict(size=14)
473+
),
375474
textfont=dict(size=24),
376475
insidetextfont=dict(size=24),
377-
tiling=dict(pad=6, packing="squarify"),
476+
tiling=dict(pad=10, packing="squarify"),
378477
root=dict(color="rgba(255,255,255,0.95)"),
379478
)
380479
)
381480

382-
for entry in legend_entries:
481+
for entry in deduped:
383482
fig.add_trace(
384483
go.Scatter(
385-
x=[None],
386-
y=[None],
484+
x=[0],
485+
y=[0],
387486
mode="markers",
388-
marker=dict(size=14, symbol="square", color=entry["color"]),
487+
marker=dict(size=12, symbol="square", color=entry["color"]),
389488
name=entry["name"],
390489
showlegend=True,
391490
hoverinfo="skip",
491+
xaxis="x2",
492+
yaxis="y2",
493+
legendgroup=entry.get("legendgroup"),
392494
)
393495
)
394496

395497
fig.update_layout(
396498
legend=dict(
397499
orientation="h",
398-
yanchor="bottom",
500+
yanchor="top",
399501
y=1.08,
400-
xanchor="left",
401-
x=0.0,
502+
xanchor="center",
503+
x=0.5,
402504
font=dict(size=14),
403-
itemwidth=80,
404-
)
505+
itemsizing="constant",
506+
traceorder="normal",
507+
itemclick=False,
508+
itemdoubleclick=False,
509+
bgcolor="rgba(255,255,255,0)",
510+
bordercolor="rgba(0,0,0,0)",
511+
borderwidth=0,
512+
),
513+
legend_traceorder="normal",
514+
)
515+
516+
fig.update_layout(
517+
xaxis2=dict(visible=False),
518+
yaxis2=dict(visible=False),
405519
)
406520

407521
return fig
@@ -429,8 +543,9 @@ def generate_dataset_treemap(
429543
fig = _build_figure(nodes, legend_entries)
430544
fig.update_layout(
431545
uniformtext=dict(minsize=18, mode="hide"),
432-
margin=dict(t=140, l=24, r=24, b=16),
433-
hoverlabel=dict(font_size=16),
546+
margin=dict(t=60, l=32, r=220, b=40),
547+
hoverlabel=dict(font=dict(size=16), align="left"),
548+
height=860,
434549
)
435550

436551
out_path = Path(out_html)

0 commit comments

Comments
 (0)