1- from typing import Union , List
1+ from typing import List
22
33import numpy as np
44from matplotlib import pyplot as plt
5- from matplotlib .colors import Colormap , to_rgba
6-
7- import plotly .graph_objects as go
8- from plotly .colors import convert_to_RGB_255
95
106from reeds .function_libs .visualization import plots_style as ps
117from reeds .function_libs .visualization .utils import nice_s_vals
@@ -380,76 +376,4 @@ def plot_stateOccurence_matrix(data: dict,
380376
381377 if (not out_dir is None ):
382378 fig .savefig (out_dir + '/sampling_maxContrib_matrix.png' , bbox_inches = 'tight' )
383- plt .close ()
384-
385- def plot_state_transitions (state_transitions : np .ndarray , title : str = None , colors : Union [List [str ], Colormap ] = ps .qualitative_tab_map , out_path : str = None ):
386- """
387- Make a Sankey plot showing the flows between states.
388-
389- Parameters
390- ----------
391- state_transitions : np.ndarray
392- num_states * num_states 2D array containing the number of transitions between states
393- title: str, optional
394- printed title of the plot
395- colors: Union[List[str], Colormap], optional
396- if you don't like the default colors
397- out_path: str, optional
398- path to save the image to. if none, the image is returned as a plotly figure
399- Returns
400- -------
401- None or fig
402- plotly figure if if was not saved
403- """
404- num_states = len (state_transitions )
405-
406- if isinstance (colors , Colormap ):
407- colors = [colors (i ) for i in np .linspace (0 , 1 , num_states )]
408- elif len (colors ) < num_states :
409- raise Exception ("Insufficient colors to plot all states" )
410-
411- def v_distribute (total_transitions ):
412- # Vertically distribute nodes in plot based on total number of transitions per state
413- box_sizes = total_transitions / total_transitions .sum ()
414- box_vplace = [np .sum (box_sizes [:i ]) + box_sizes [i ]/ 2 for i in range (len (box_sizes ))]
415- return box_vplace
416-
417- y_placements = v_distribute (np .sum (state_transitions , axis = 1 )) + v_distribute (np .sum (state_transitions , axis = 0 ))
418-
419- # Convert colors to plotly format and make them transparent
420- rgba_colors = []
421- for color in colors :
422- rgba = to_rgba (color )
423- rgba_plotly = convert_to_RGB_255 (rgba [:- 1 ])
424- # Add opacity
425- rgba_plotly = rgba_plotly + (0.8 ,)
426- # Make string
427- rgba_colors .append ("rgba" + str (rgba_plotly ))
428-
429- # Indices 0..n-1 are the source and n..2n-1 are the target.
430- fig = go .Figure (data = [go .Sankey (
431- node = dict (
432- pad = 5 ,
433- thickness = 20 ,
434- line = dict (color = "black" , width = 2 ),
435- label = [f"state { i + 1 } " for i in range (num_states )]* 2 ,
436- color = rgba_colors [:num_states ]* 2 ,
437- x = [0.1 ]* num_states + [1 ]* num_states ,
438- y = y_placements
439- ),
440- link = dict (
441- arrowlen = 30 ,
442- source = np .array ([[i ]* num_states for i in range (num_states )]).flatten (),
443- target = np .array ([[i for i in range (num_states , 2 * num_states )] for _ in range (num_states )]).flatten (),
444- value = state_transitions .flatten (),
445- color = np .array ([[c ]* num_states for c in rgba_colors [:num_states ]]).flatten ()
446- ),
447- arrangement = "fixed" ,
448- )])
449- fig .update_layout (title_text = title , font_size = 20 , title_x = 0.5 , height = max (600 , num_states * 100 ))
450-
451- if out_path :
452- fig .write_image (out_path )
453- return None
454- else :
455- return fig
379+ plt .close ()
0 commit comments