44import numpy as np
55import zarr
66from matplotlib import animation
7+ from matplotlib .collections import PathCollection
78
89from pyvcell .data_model .var_types import NDArray2D
9- from pyvcell .data_model .zarr_types import Channel
10+ from pyvcell .data_model .zarr_types import ChannelMetadata , ZarrMetadata
11+ from pyvcell .data_model .zarr_types import ChannelMetadata as Channel
1012from pyvcell .simdata .mesh import CartesianMesh
1113from pyvcell .simdata .postprocessing import PostProcessing , VariableInfo
1214from pyvcell .utils import slice_dataset
1719
1820
1921class Plotter :
22+ times : list [float ]
23+ concentrations : NDArray2D
24+ channels : list [Channel ]
25+ post_processing : PostProcessing
26+ zarr_dataset : Union [zarr .Group , zarr .Array ]
27+ mesh : CartesianMesh
28+ metadata : ZarrMetadata
29+
2030 def __init__ (
2131 self ,
2232 times : list [float ],
@@ -25,6 +35,7 @@ def __init__(
2535 post_processing : PostProcessing ,
2636 zarr_dataset : Union [zarr .Group , zarr .Array ],
2737 mesh : CartesianMesh ,
38+ metadata : ZarrMetadata ,
2839 ) -> None :
2940 self .times = times
3041 self .num_timepoints = len (times )
@@ -33,6 +44,18 @@ def __init__(
3344 self .post_processing = post_processing
3445 self .zarr_dataset = zarr_dataset
3546 self .mesh = mesh
47+ self .metadata = metadata
48+
49+ def get_channel (self , label : str ) -> ChannelMetadata :
50+ getter = filter (lambda c : c .label == label , self .channels )
51+ channel_data = next (getter , None )
52+
53+ if channel_data is None :
54+ raise ValueError (f"No channel found with label '{ label } '" )
55+ if next (getter , None ) is not None :
56+ raise ValueError (f"More than one '{ label } ' channel found" )
57+
58+ return channel_data
3659
3760 def plot_concentrations (self ) -> None :
3861 t = self .times
@@ -45,19 +68,18 @@ def plot_concentrations(self) -> None:
4568 ax .grid ()
4669 return plt .show ()
4770
48- def plot_slice_2d (self , time_index : int , channel_index : int , z_index : int ) -> None :
49- data_slice = slice_dataset (self .zarr_dataset , time_index , channel_index , z_index )
71+ def plot_slice_2d (self , time_index : int , channel_name : str , z_index : int ) -> None :
72+ specified_channel = self .get_channel (channel_name )
73+ data_slice = slice_dataset (specified_channel , self .zarr_dataset , time_index , z_index )
5074
5175 t = self .zarr_dataset .attrs .asdict ()["metadata" ]["times" ][time_index ]
5276 channel_label = None
5377 channel_domain = None
5478
5579 for channel in self .channels :
56- if channel .index == channel_index :
80+ if channel .index == specified_channel . index :
5781 channel_label = channel .label
5882 channel_domain = channel .domain_name
59- # channel_label = self.channels[channel_index].label
60- # channel_domain = self.channels[channel_index].domain_name
6183
6284 # z_coord = self.mesh.origin[2] + z_index * self.mesh.extent[2] / (self.mesh.size[2] - 1)
6385 title = f"{ channel_label } (in { channel_domain } ) at t={ t } "
@@ -66,22 +88,31 @@ def plot_slice_2d(self, time_index: int, channel_index: int, z_index: int) -> No
6688 # Display the slice as an image
6789 plt .imshow (data_slice )
6890 plt .title (title )
69- return plt .show ()
91+ plt .show ()
7092
71- def plot_slice_3d (self , time_index : int , channel_index : int ) -> None :
93+ def plot_slice_3d (self , time_index : int , channel_id : str ) -> None :
7294 # Select a 3D volume for a single time point and channel, shape is (z, y, x)
73- volume = self .zarr_dataset [time_index , channel_index , :, :, :]
95+ channel = self .get_channel (channel_id )
96+ volume = self .zarr_dataset [time_index , channel .index , :, :, :]
7497
7598 # Create a figure for 3D plotting
7699 fig = plt .figure ()
77100 ax = fig .add_subplot (111 , projection = "3d" )
78101
79102 # Define a mask to display the volume (use 'region_mask' channel)
80- mask = np .copy (self .zarr_dataset [3 , 0 , :, :, :])
81- z , y , x = np . where ( mask == 1 )
103+ mask = np .copy (self .zarr_dataset [time_index , 0 , :, :, :])
104+ domain = channel . domain_name
82105
83- # Get the intensity values for these points
84- intensities = volume [z , y , x ]
106+ if channel .domain_name == "all" :
107+ z , y , x = np .where (mask > - 1 ) # everywhere
108+ # Get the intensity values for these points
109+ intensities = volume [z , y , x ]
110+ else :
111+ idx : set [int ] = self .mesh .get_volume_region_ids (volume_domain_name = domain )
112+ region_func = lambda region_index : region_index in idx
113+ z , y , x = np .where (np .vectorize (region_func )(mask ))
114+ # Get the intensity values for these points
115+ intensities = volume [z , y , x ]
85116
86117 # Create a 3D scatter plot
87118 scatter = ax .scatter (x , y , z , c = intensities , cmap = "viridis" )
@@ -93,7 +124,9 @@ def plot_slice_3d(self, time_index: int, channel_index: int) -> None:
93124 ax .set_xlabel ("X" )
94125 ax .set_ylabel ("Y" )
95126 ax .set_zlabel ("Z" ) # type: ignore[attr-defined]
96-
127+ t = self .times [time_index ]
128+ title = f"{ channel .label } (in { channel .domain_name } ) at t={ t } "
129+ plt .title (title )
97130 # Show the plot
98131 return plt .show ()
99132
@@ -116,8 +149,7 @@ def get_3d_slice_animation(self, channel_index: int, interval: int = 200) -> ani
116149 interval (int): Time interval between frames in milliseconds.
117150 """
118151 # Extract metadata and the number of time points
119- channel_list = self .channels
120- channel_domain = channel_list [channel_index - 5 ].domain_name
152+ channel : Channel = self .channels [channel_index ]
121153 num_timepoints = self .num_timepoints
122154
123155 # Create a figure for 3D plotting
@@ -130,19 +162,19 @@ def get_3d_slice_animation(self, channel_index: int, interval: int = 200) -> ani
130162 ax .set_zlabel ("Z" ) # type: ignore[attr-defined]
131163 sc = None
132164
133- @no_type_check
134- def update (frame : int ):
165+ def update (frame : int ) -> tuple [PathCollection ]:
135166 """Update function for animation"""
136- # Define a mask to display the volume (use 'region_mask' channel)
137- mask = np .copy (self .zarr_dataset [frame , 0 , :, :, :])
138- z , y , x = np .where (mask == 1 )
167+ mask = np .copy (self .zarr_dataset [3 , 0 , :, :, :])
168+ print (f"Any mask: { np .any (mask )} " )
139169
170+ z , y , x = np .where (mask > 0 )
171+ print (f"got shapes: { z .shape } , { y .shape } , { x .shape } " )
140172 volume = self .zarr_dataset [frame , channel_index , :, :, :]
141173 intensities = volume [z , y , x ]
142174
143175 # Initialize the scatter plot with empty data
144176 scatter = ax .scatter (x , y , z , c = intensities , cmap = "viridis" )
145- ax .set_title (f"Channel: { channel_domain } , Time Index: { frame } " )
177+ ax .set_title (f"Channel: { channel . domain_name } , Time Index: { frame } " )
146178 return (scatter ,)
147179
148180 # Create the animation
0 commit comments