1414 CANONICAL_MAP ,
1515 MODALITY_COLOR_MAP ,
1616 PATHOLOGY_COLOR_MAP ,
17+ PATHOLOGY_PASTEL_OVERRIDES ,
18+ MODALITY_EMOJI ,
1719 hex_to_rgba ,
1820 )
1921except ImportError : # pragma: no cover - fallback for direct script execution
2022 from colours import ( # type: ignore
2123 CANONICAL_MAP ,
2224 MODALITY_COLOR_MAP ,
2325 PATHOLOGY_COLOR_MAP ,
26+ PATHOLOGY_PASTEL_OVERRIDES ,
27+ MODALITY_EMOJI ,
2428 hex_to_rgba ,
2529 )
2630
3640_SEPARATORS = ("/" , "|" , ";" , "," )
3741_DEFAULT_COLOR = "#94a3b8"
3842
39- MODALITY_EMOJI = {
40- "Visual" : "👁️" ,
41- "Auditory" : "👂" ,
42- "Sleep" : "🌙" ,
43- "Multisensory" : "🧩" ,
44- "Tactile" : "✋" ,
45- "Motor" : "🏃" ,
46- "Resting State" : "🧘" ,
47- "Rest" : "🧘" ,
48- "Other" : "🧭" ,
49- }
50-
5143
5244def _tokenise_cell (value : object , column_key : str ) -> list [str ]:
5345 """Split multi-valued cells, normalise, and keep Unknown buckets."""
@@ -160,6 +152,12 @@ def _abbreviate(value: float | int) -> str:
160152 if num == 0 :
161153 return "0"
162154
155+ if abs (num ) < 1000 :
156+ rounded = round (num / 10.0 ) * 10.0
157+ if rounded == 0 and num > 0 :
158+ rounded = 10.0
159+ return f"{ int (rounded ):,} "
160+
163161 thresholds = [
164162 (1_000_000_000 , "B" ),
165163 (1_000_000 , "M" ),
@@ -168,11 +166,50 @@ def _abbreviate(value: float | int) -> str:
168166 for divisor , suffix in thresholds :
169167 if abs (num ) >= divisor :
170168 scaled = num / divisor
169+ scaled = round (scaled , 1 )
171170 text = f"{ scaled :.1f} " .rstrip ("0" ).rstrip ("." )
172171 return f"{ text } { suffix } "
173172 return f"{ num :.0f} "
174173
175174
175+ def _lighten_hex (hex_color : str , factor : float = 0.55 ) -> str :
176+ if not isinstance (hex_color , str ) or not hex_color .startswith ("#" ):
177+ return _DEFAULT_COLOR
178+ hex_color = hex_color .lstrip ("#" )
179+ if len (hex_color ) != 6 :
180+ return _DEFAULT_COLOR
181+ try :
182+ r = int (hex_color [0 :2 ], 16 )
183+ g = int (hex_color [2 :4 ], 16 )
184+ b = int (hex_color [4 :6 ], 16 )
185+ except ValueError :
186+ return _DEFAULT_COLOR
187+ r = int (r + (255 - r ) * factor )
188+ g = int (g + (255 - g ) * factor )
189+ b = int (b + (255 - b ) * factor )
190+ return f"#{ r :02x} { g :02x} { b :02x} "
191+
192+
193+ def _pathology_colors (name : str ) -> tuple [str , str , str ]:
194+ """Return (fill_rgba, legend_hex, group_key)."""
195+ base_hex = PATHOLOGY_PASTEL_OVERRIDES .get (name )
196+ if not base_hex :
197+ fallback = PATHOLOGY_COLOR_MAP .get (name )
198+ if fallback :
199+ base_hex = _lighten_hex (fallback , 0.6 )
200+ else :
201+ base_hex = PATHOLOGY_PASTEL_OVERRIDES .get ("Clinical" , _DEFAULT_COLOR )
202+
203+ fill = hex_to_rgba (base_hex , alpha = 0.65 )
204+ if name == "Healthy" :
205+ group = "healthy"
206+ elif name == "Unknown" :
207+ group = "unknown"
208+ else :
209+ group = "clinical"
210+ return fill , base_hex , group
211+
212+
176213def _filter_zero_nodes (df : pd .DataFrame , column : str ) -> pd .DataFrame :
177214 mask = (df ["subjects" ] > 0 ) | (df [column ] == "Unknown" )
178215 return df .loc [mask ].copy ()
@@ -198,14 +235,16 @@ def _format_label(
198235 elif fallback_value > 0 :
199236 secondary_text = f"{ _abbreviate (records_value )} rec"
200237 else :
201- secondary_text = "0 h "
238+ secondary_text = "records unavailable "
202239 return (
203240 f"{ name } <br><span style='font-size:{ font_px } px;'>{ subjects_text } subj"
204241 f" | { secondary_text } </span>"
205242 )
206243
207244
208- def _build_nodes (dataset_level : pd .DataFrame ) -> list [dict [str , object ]]:
245+ def _build_nodes (
246+ dataset_level : pd .DataFrame ,
247+ ) -> tuple [list [dict [str , object ]], list [dict [str , str ]]]:
209248 dataset_level = dataset_level .sort_values (
210249 ["population_type" , "experimental_modality" , "dataset_name" ]
211250 ).reset_index (drop = True )
@@ -229,7 +268,12 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
229268 level1 = _filter_zero_nodes (level1 , "population_type" )
230269
231270 nodes : list [dict [str , object ]] = []
232- level1_meta : list [dict [str , str ]] = []
271+ legend_entries : list [dict [str , str ]] = []
272+ seen_groups : set [str ] = set ()
273+ modality_meta : dict [str , dict [str , str ]] = {}
274+ modality_priority = {
275+ name : idx for idx , name in enumerate (MODALITY_COLOR_MAP .keys ())
276+ }
233277
234278 total_subjects = level1 ["subjects" ].sum ()
235279 total_hours = level1 ["hours" ].sum ()
@@ -267,19 +311,16 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
267311 row ["hours_from_records" ],
268312 font_px = 16 ,
269313 )
270- base_color = PATHOLOGY_COLOR_MAP .get (name )
271- if not base_color :
272- base_color = PATHOLOGY_COLOR_MAP .get ("Clinical" , _DEFAULT_COLOR )
273- color = hex_to_rgba (base_color , alpha = 0.75 )
274- level1_meta .append ({"name" : name , "color" : base_color })
314+ fill_color , _ , group = _pathology_colors (name )
315+ seen_groups .add (group )
275316 nodes .append (
276317 {
277318 "id" : node_id ,
278319 "parent" : "EEG Dash datasets" ,
279320 "name" : name ,
280321 "text" : label ,
281322 "value" : float (row ["subjects" ]),
282- "color" : color ,
323+ "color" : fill_color ,
283324 "hover" : label ,
284325 }
285326 )
@@ -288,10 +329,10 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
288329 modality = row ["experimental_modality" ] or "Unknown"
289330 parent = row ["population_type" ] or "Unknown"
290331 node_id = f"{ parent } / { modality } "
332+ emoji_symbol = MODALITY_EMOJI .get (modality )
291333 modality_label = modality
292- emoji = MODALITY_EMOJI .get (modality )
293- if emoji :
294- modality_label = f"{ emoji } { modality } "
334+ if emoji_symbol and row ["subjects" ] >= 120 :
335+ modality_label = f"{ emoji_symbol } { modality } "
295336 label = _format_label (
296337 modality_label ,
297338 row ["subjects" ],
@@ -301,6 +342,18 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
301342 font_px = 16 ,
302343 )
303344 color = MODALITY_COLOR_MAP .get (modality , _DEFAULT_COLOR )
345+ legend_label = (
346+ f"{ (emoji_symbol + ' ' ) if emoji_symbol else '' } { modality } " .strip ()
347+ )
348+ if modality not in modality_meta :
349+ order = modality_priority .get (modality , len (modality_priority ))
350+ modality_meta [modality ] = {
351+ "name" : legend_label ,
352+ "color" : color ,
353+ "group" : "level2" ,
354+ "order" : 100 + order ,
355+ "legendgroup" : "modalities" ,
356+ }
304357 nodes .append (
305358 {
306359 "id" : node_id ,
@@ -327,7 +380,6 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
327380 row ["hours_from_records" ],
328381 font_px = 16 ,
329382 )
330- _ = row ["population_type" ] or "Unknown"
331383 if dataset_name == "Unknown" :
332384 color = _DEFAULT_COLOR
333385 else :
@@ -344,7 +396,41 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]:
344396 }
345397 )
346398
347- return nodes , level1_meta
399+ group_config = {
400+ "healthy" : {
401+ "name" : "Healthy" ,
402+ "color" : PATHOLOGY_PASTEL_OVERRIDES .get ("Healthy" , "#bbf7d0" ),
403+ "order" : 0 ,
404+ },
405+ "clinical" : {
406+ "name" : "Clinical" ,
407+ "color" : PATHOLOGY_PASTEL_OVERRIDES .get ("Clinical" , "#f8d0d0" ),
408+ "order" : 1 ,
409+ },
410+ "unknown" : {
411+ "name" : "To be Categorised" ,
412+ "color" : PATHOLOGY_PASTEL_OVERRIDES .get ("Unknown" , "#d0d7df" ),
413+ "order" : 2 ,
414+ },
415+ }
416+
417+ for key , cfg in group_config .items ():
418+ if key in seen_groups :
419+ legend_entries .append (
420+ {
421+ "name" : cfg ["name" ],
422+ "color" : cfg ["color" ],
423+ "group" : "level1" ,
424+ "order" : cfg ["order" ],
425+ "legendgroup" : "populations" ,
426+ }
427+ )
428+
429+ legend_entries .extend (
430+ sorted (modality_meta .values (), key = lambda item : item ["order" ])
431+ )
432+
433+ return nodes , legend_entries
348434
349435
350436def _build_figure (
@@ -355,6 +441,17 @@ def _build_figure(
355441 if not node_list :
356442 raise ValueError ("No data available to render the treemap." )
357443
444+ legend_list = list (legend_entries )
445+ seen : set [str ] = set ()
446+ deduped : list [dict [str , str ]] = []
447+ for entry in legend_list :
448+ if entry ["name" ] in seen :
449+ continue
450+ seen .add (entry ["name" ])
451+ deduped .append (entry )
452+
453+ deduped .sort (key = lambda item : item .get ("order" , 999 ))
454+
358455 fig = go .Figure (
359456 go .Treemap (
360457 ids = [node ["id" ] for node in node_list ],
@@ -367,41 +464,58 @@ def _build_figure(
367464 marker = dict (
368465 colors = [node ["color" ] for node in node_list ],
369466 line = dict (color = "white" , width = 1 ),
370- pad = dict (t = 6 , r = 6 , b = 6 , l = 6 ),
467+ pad = dict (t = 10 , r = 10 , b = 10 , l = 10 ),
371468 ),
372469 textinfo = "text" ,
373470 hovertemplate = "%{customdata[0]}<extra></extra>" ,
374- pathbar = dict (visible = True , edgeshape = "/" , thickness = 34 ),
471+ pathbar = dict (
472+ visible = True , edgeshape = "/" , thickness = 34 , textfont = dict (size = 14 )
473+ ),
375474 textfont = dict (size = 24 ),
376475 insidetextfont = dict (size = 24 ),
377- tiling = dict (pad = 6 , packing = "squarify" ),
476+ tiling = dict (pad = 10 , packing = "squarify" ),
378477 root = dict (color = "rgba(255,255,255,0.95)" ),
379478 )
380479 )
381480
382- for entry in legend_entries :
481+ for entry in deduped :
383482 fig .add_trace (
384483 go .Scatter (
385- x = [None ],
386- y = [None ],
484+ x = [0 ],
485+ y = [0 ],
387486 mode = "markers" ,
388- marker = dict (size = 14 , symbol = "square" , color = entry ["color" ]),
487+ marker = dict (size = 12 , symbol = "square" , color = entry ["color" ]),
389488 name = entry ["name" ],
390489 showlegend = True ,
391490 hoverinfo = "skip" ,
491+ xaxis = "x2" ,
492+ yaxis = "y2" ,
493+ legendgroup = entry .get ("legendgroup" ),
392494 )
393495 )
394496
395497 fig .update_layout (
396498 legend = dict (
397499 orientation = "h" ,
398- yanchor = "bottom " ,
500+ yanchor = "top " ,
399501 y = 1.08 ,
400- xanchor = "left " ,
401- x = 0.0 ,
502+ xanchor = "center " ,
503+ x = 0.5 ,
402504 font = dict (size = 14 ),
403- itemwidth = 80 ,
404- )
505+ itemsizing = "constant" ,
506+ traceorder = "normal" ,
507+ itemclick = False ,
508+ itemdoubleclick = False ,
509+ bgcolor = "rgba(255,255,255,0)" ,
510+ bordercolor = "rgba(0,0,0,0)" ,
511+ borderwidth = 0 ,
512+ ),
513+ legend_traceorder = "normal" ,
514+ )
515+
516+ fig .update_layout (
517+ xaxis2 = dict (visible = False ),
518+ yaxis2 = dict (visible = False ),
405519 )
406520
407521 return fig
@@ -429,8 +543,9 @@ def generate_dataset_treemap(
429543 fig = _build_figure (nodes , legend_entries )
430544 fig .update_layout (
431545 uniformtext = dict (minsize = 18 , mode = "hide" ),
432- margin = dict (t = 140 , l = 24 , r = 24 , b = 16 ),
433- hoverlabel = dict (font_size = 16 ),
546+ margin = dict (t = 60 , l = 32 , r = 220 , b = 40 ),
547+ hoverlabel = dict (font = dict (size = 16 ), align = "left" ),
548+ height = 860 ,
434549 )
435550
436551 out_path = Path (out_html )
0 commit comments