Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 18 additions & 229 deletions examples/plotting_tutorials.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions ggseg_py/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""ggseg_py core functions."""

from ggseg_py.ggseg_py import merge_data, rda2gpd
from ggseg_py.ggseg_py import atlas2df, rda2gpd

__all__ = [
'merge_data',
'atlas2df',
'rda2gpd',
'__version__',
]
28 changes: 0 additions & 28 deletions ggseg_py/conversion_dicts.py

This file was deleted.

106 changes: 86 additions & 20 deletions ggseg_py/ggseg_py.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from collections.abc import Sequence
from pathlib import Path

import geopandas as gpd
import pandas as pd
Expand Down Expand Up @@ -31,46 +32,72 @@ def _list_to_multipolygon(coords: list[Sequence[Sequence[float]]]) -> MultiPolyg
return MultiPolygon(polys)


def rda2gpd(path2atlas: str, atlas_name: str) -> gpd.GeoDataFrame:
def rda2gpd(atlas: Path, return_atlas_fname: bool = False) -> gpd.GeoDataFrame:
"""
Load atlas data from an R .rda file and convert to GeoDataFrame.

Parameters
----------
path2atlas : str
Filepath to the .rda atlas file.
atlas_name : str
Name of the object inside the .rda to extract (e.g., 'aseg').
atlas : str
Name of an atlas or filepath to an .rda atlas file.
When using a "custom" rda file be aware that the object inside the .rda to extract data (e.g., 'aseg').
should be labeled according to the atlas itself.

Returns
-------
GeoDataFrame
A GeoDataFrame with 'geometry', 'region', 'label', and optional 'roi' columns.
"""

atlas_split = str(atlas).split('/')

if len(atlas_split) == 1:
atlas_name = atlas_split[0]

atlas_loc = Path(__file__).parent.parent
if atlas_name == 'aseg':
path2atlas = atlas_loc / 'ggseg_py' / 'atlases' / 'aseg.rda'
elif atlas_name == 'glasser':
path2atlas = atlas_loc / 'ggseg_py' / 'atlases' / 'glasser.rda'
elif atlas_name == 'dk':
path2atlas = atlas_loc / 'ggseg_py' / 'atlases' / 'dk.rda'
else:
raise ValueError(
'Currently only aseg, glasser and dk atlasses are supported directly. '
'If you want to use a different ggseg compatible atlas taken from an rda file you need'
'to directly supply the path to said file.'
)
else:
path2atlas = atlas
atlas_name = atlas_split[-1].split('.')[0]

with warnings.catch_warnings():
warnings.simplefilter('ignore') # ignoring because fixing issues below
atlas_r: dict = read_rda(path2atlas)
df: pd.DataFrame = atlas_r[atlas_name]['data'] # type: ignore
df_atlas: pd.DataFrame = atlas_r[atlas_name]['data'] # type: ignore

# Convert nested lists to MultiPolygon
df['geometry'] = df['geometry'].apply(_list_to_multipolygon)
df_atlas['geometry'] = df_atlas['geometry'].apply(_list_to_multipolygon)

# Clean up region and label fields
regions: list[str] = []
labels: list[str] = []
for region, label in zip(df['region'], df['label']):
for region, label in zip(df_atlas['region'], df_atlas['label']):
regions.append(region if region is not None else '???')
labels.append(label if label is not None else '???')
df['region'] = regions
df['label'] = labels
df_atlas['region'] = regions
df_atlas['label'] = labels

# Add 'roi' column for aseg atlas
if atlas_name == 'aseg':
df['roi'] = df['hemi'] + '_' + df['label'] # type: ignore
df_atlas['roi'] = df_atlas['label'] # df_atlas['hemi'] + '_' + df_atlas['label'] # type: ignore
elif atlas_name == 'glasser':
df['roi'] = [label.split('_')[1] + '_' + label.split('_')[-1] + '_ROI' for label in df['label']]
df_atlas['roi'] = [label.split('_')[1] + '_' + label.split('_')[-1] + '_ROI' for label in df_atlas['label']]

return gpd.GeoDataFrame(df, geometry='geometry')
if return_atlas_fname:
return gpd.GeoDataFrame(df_atlas, geometry='geometry'), atlas_name
else:
return gpd.GeoDataFrame(df_atlas, geometry='geometry')


def merge_data(data: pd.DataFrame, geo_df: gpd.GeoDataFrame, atlas_name: str) -> gpd.GeoDataFrame:
Expand All @@ -93,13 +120,52 @@ def merge_data(data: pd.DataFrame, geo_df: gpd.GeoDataFrame, atlas_name: str) ->
"""
# Dynamically import the correct conversion dictionary
if atlas_name == 'aseg':
from ggseg_py.conversion_dicts import aseg_dict as mapping
aseg_dict: dict = {
'x3rd-ventricle': '3rd-Ventricle',
'x4th-ventricle': '4th-Ventricle',
'Left-Thalamus-Proper': 'Left-Thalamus',
'Right-Thalamus-Proper': 'Right-Thalamus',
}

# Create ROI column in data
data = data.copy()
data['roi'] = data['StructName'].replace(mapping) # type: ignore
else:
raise ValueError(f'Unsupported atlas_name: {atlas_name}')
geo_df['roi'] = geo_df['roi'].replace(aseg_dict) # type: ignore

geo_mg = geo_df.merge(data, on='roi')
geo_mg = pd.concat([geo_mg, geo_df.query('roi == "???"')]) # add empty cortex back-in

else:
geo_mg = geo_df.merge(data, on='roi')
# Merge and return
return geo_df.merge(data, on='roi', how='outer')
return geo_mg


def atlas2df(atlas: Path, df: pd.DataFrame, col2merge: str) -> gpd.GeoDataFrame:
"""
Load atlas data from an R .rda file, convert to GeoDataFrame and merge with existing data.

Parameters
----------
atlas : str
Name of an atlas or filepath to an .rda atlas file.
When using a "custom" rda file be aware that the object inside the .rda to extract data (e.g., 'aseg').
should be labeled according to the atlas itself.
df : pd.DataFrame
Pandas Datframe containing the data you want to plot on an atlas.
Be aware that the information the labels of the atlas should match the labels that you want to plot.
col2merge : str
a string denoting the column in your pandas dataframe upon which
you want to merge the data of your dataframe with the atlas.

Returns
-------
GeoDataFrame
A GeoDataFrame with 'geometry' a 'roi' column denoting the name of the brain "location"
on which you merged your data, as well as all the other data in you original pandas dataframe.
"""

geo_df, atlas_name = rda2gpd(atlas=atlas, return_atlas_fname=True)

df.rename(columns={col2merge: 'roi'}, inplace=True)

df_mg = merge_data(df, geo_df, atlas_name)

return df_mg
40 changes: 20 additions & 20 deletions ggseg_py/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _add_colorbar(
def _plot_views(
gdf: gpd.GeoDataFrame,
views: list[dict],
column: str,
value: str,
cmap: mcolors.Colormap,
mask_region: str | None,
edgecolor: str,
Expand All @@ -120,8 +120,8 @@ def _plot_views(
Input geospatial data with 'side', 'hemi', and specified column.
views : list of dict
Each dict must have 'side' and 'hemi' keys for filtering.
column : str
Column name to color by.
value : str
Column name containing the data to visualize.
cmap : Colormap
Colormap for numeric plots or used to generate category colors.
mask_region : str or None
Expand All @@ -145,7 +145,7 @@ def _plot_views(
sel = sel[sel['hemi'].isin(hemi)] if isinstance(hemi, list | tuple) else sel[sel['hemi'] == hemi]
if norm is not None:
sel.plot(
column=column,
column=value,
cmap=cmap,
norm=norm,
edgecolor=edgecolor,
Expand All @@ -155,7 +155,7 @@ def _plot_views(
ax=ax,
)
elif color_map is not None:
sel_colors = sel[column].map(color_map)
sel_colors = sel[value].map(color_map)
sel.plot(color=sel_colors, edgecolor=edgecolor, linewidth=linewidth, aspect=aspect, ax=ax)
if mask_region:
mask = gdf[(gdf['side'] == view['side']) & (gdf.get('region') == mask_region)]
Expand All @@ -167,7 +167,7 @@ def _plot_multi(
gdf: gpd.GeoDataFrame,
views: list,
layout: tuple[int, int],
column: str,
value: str,
cmap: str | mcolors.Colormap,
mask_region: str | None,
edgecolor: str,
Expand All @@ -189,8 +189,8 @@ def _plot_multi(
View definitions with 'side' and 'hemi'.
layout : tuple of ints
(n_rows, n_cols) specifying subplot grid.
column : str
Column name to color by.
value : str
Column name containing the data to visualize.
cmap : str or Colormap
Colormap for mapping data values.
mask_region : str or None
Expand All @@ -217,12 +217,12 @@ def _plot_multi(
axes : array-like
Array of Axes objects corresponding to each view.
"""
is_num, norm, color_map, cmap = _prepare_coloring(gdf[column], cmap, vmin, vmax)
is_num, norm, color_map, cmap = _prepare_coloring(gdf[value], cmap, vmin, vmax)
fig, axes = plt.subplots(layout[0], layout[1], figsize=figsize, squeeze=False)
_plot_views(gdf, views, column, cmap, mask_region, edgecolor, linewidth, aspect, axes, norm, color_map)
_plot_views(gdf, views, value, cmap, mask_region, edgecolor, linewidth, aspect, axes, norm, color_map)
if show_cbar:
gdf[column]
_add_colorbar(fig, axes, column, cmap, norm=norm, color_map=color_map)
gdf[value]
_add_colorbar(fig, axes, value, cmap, norm=norm, color_map=color_map)
return fig, axes


Expand Down Expand Up @@ -268,7 +268,7 @@ def plot_aseg(

def plot_surface(
gdf: gpd.GeoDataFrame,
column: str = 'label',
value: str = 'label',
cmap: str | mcolors.Colormap = 'tab20',
edgecolor: str = 'black',
linewidth: float = 1.5,
Expand All @@ -285,8 +285,8 @@ def plot_surface(
----------
gdf : GeoDataFrame
Surface atlas geodata with 'side', 'hemi', and data column.
column : str
Column name containing labels or measurements.
value : str
Column name containing the data to visualize.
cmap : str or Colormap
Colormap for mapping values or categories.
edgecolor : str
Expand Down Expand Up @@ -318,15 +318,15 @@ def plot_surface(
{'side': 'medial', 'hemi': 'right'},
]
return _plot_multi(
gdf, views, (2, 2), column, cmap, None, edgecolor, linewidth, aspect, figsize, vmin, vmax, show_cbar
gdf, views, (2, 2), value, cmap, None, edgecolor, linewidth, aspect, figsize, vmin, vmax, show_cbar
)


def plot_view(
gdf: gpd.GeoDataFrame,
side: str,
hemi: str,
column: str = 'label',
value: str = 'label',
cmap: str | mcolors.Colormap = 'viridis',
edgecolor: str = 'black',
linewidth: float = 1.5,
Expand All @@ -347,8 +347,8 @@ def plot_view(
View orientation (e.g., 'lateral', 'medial', 'coronal', 'sagittal').
hemi : str
Hemisphere identifier ('left', 'right', or 'midline').
column : str
Column name for values or labels to plot.
value : str
Column name containing the data to visualize.
cmap : str or Colormap
Colormap for numeric or categorical data.
edgecolor : str
Expand All @@ -375,6 +375,6 @@ def plot_view(
"""
views = [{'side': side, 'hemi': hemi}]
fig, axes = _plot_multi(
gdf, views, (1, 1), column, cmap, None, edgecolor, linewidth, aspect, figsize, vmin, vmax, show_cbar
gdf, views, (1, 1), value, cmap, None, edgecolor, linewidth, aspect, figsize, vmin, vmax, show_cbar
)
return fig, axes[0, 0]
Loading
Loading