Skip to content

Commit 3a0206a

Browse files
authored
[tune] Parallel Coordinate Visualization Notebook (#1218)
1 parent c70430f commit 3a0206a

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Tune Visualization"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"In order to visualize results, please install `plotly` with the following command:\n",
15+
"\n",
16+
" `pip install plotly`"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"metadata": {
23+
"collapsed": true
24+
},
25+
"outputs": [],
26+
"source": [
27+
"import pandas as pd\n",
28+
"from ray.tune.visual_utils import load_results_to_df, generate_plotly_dim_dict\n",
29+
"import plotly\n",
30+
"import plotly.graph_objs as go\n",
31+
"plotly.offline.init_notebook_mode(connected=True)"
32+
]
33+
},
34+
{
35+
"cell_type": "markdown",
36+
"metadata": {},
37+
"source": [
38+
"### Specify the directory where all your results are in the variable `RESULTS_DIR`."
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"metadata": {
45+
"collapsed": true
46+
},
47+
"outputs": [],
48+
"source": [
49+
"RESULTS_DIR = \"/tmp/ray/\"\n",
50+
"df = load_results_to_df(RESULTS_DIR)\n",
51+
"[key for key in df]"
52+
]
53+
},
54+
{
55+
"cell_type": "markdown",
56+
"metadata": {},
57+
"source": [
58+
"### Choose the fields you wish to visualize over in `GOOD_FIELDS`."
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"metadata": {
65+
"collapsed": true,
66+
"scrolled": true
67+
},
68+
"outputs": [],
69+
"source": [
70+
"GOOD_FIELDS = ['experiment_id',\n",
71+
" 'num_sgd_iter',\n",
72+
" 'timesteps_total',\n",
73+
" 'episode_len_mean',\n",
74+
" 'episode_reward_mean']\n",
75+
"\n",
76+
"visualization_df = df[GOOD_FIELDS]\n",
77+
"visualization_df"
78+
]
79+
},
80+
{
81+
"cell_type": "markdown",
82+
"metadata": {},
83+
"source": [
84+
"### Enjoy.\n",
85+
"\n",
86+
"Documentation for this Plotly visualization can be found here: https://plot.ly/python/parallel-coordinates-plot/"
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": null,
92+
"metadata": {
93+
"collapsed": true
94+
},
95+
"outputs": [],
96+
"source": [
97+
"data = [go.Parcoords(\n",
98+
" line = dict(color = 'blue'),\n",
99+
" dimensions = [generate_plotly_dim_dict(visualization_df, field) \n",
100+
" for field in visualization_df])\n",
101+
"]\n",
102+
"\n",
103+
"plotly.offline.iplot(data)"
104+
]
105+
}
106+
],
107+
"metadata": {
108+
"kernelspec": {
109+
"display_name": "Python 3",
110+
"language": "python",
111+
"name": "python3"
112+
},
113+
"language_info": {
114+
"codemirror_mode": {
115+
"name": "ipython",
116+
"version": 3
117+
},
118+
"file_extension": ".py",
119+
"mimetype": "text/x-python",
120+
"name": "python",
121+
"nbconvert_exporter": "python",
122+
"pygments_lexer": "ipython3",
123+
"version": "3.6.2"
124+
}
125+
},
126+
"nbformat": 4,
127+
"nbformat_minor": 2
128+
}

python/ray/tune/visual_utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import pandas as pd
6+
from pandas.api.types import is_string_dtype, is_numeric_dtype
7+
8+
import os
9+
import os.path as osp
10+
import numpy as np
11+
import json
12+
13+
14+
def _flatten_dict(dt):
15+
while any(type(v) is dict for v in dt.values()):
16+
remove = []
17+
add = {}
18+
for key, value in dt.items():
19+
if type(value) is dict:
20+
for subkey, v in value.items():
21+
add[":".join([key, subkey])] = v
22+
remove.append(key)
23+
dt.update(add)
24+
for k in remove:
25+
del dt[k]
26+
return dt
27+
28+
29+
def _parse_results(res_path):
30+
res_dict = {}
31+
try:
32+
with open(res_path) as f:
33+
# Get last line in file
34+
for line in f:
35+
pass
36+
res_dict = _flatten_dict(json.loads(line.strip()))
37+
except Exception as e:
38+
print("Importing %s failed...Perhaps empty?" % res_path)
39+
return res_dict
40+
41+
42+
def _parse_configs(cfg_path):
43+
try:
44+
with open(cfg_path) as f:
45+
cfg_dict = _flatten_dict(json.load(f))
46+
except Exception as e:
47+
print(e)
48+
return cfg_dict
49+
50+
51+
def _resolve(directory, result_fname):
52+
resultp = osp.join(directory, result_fname)
53+
res_dict = _parse_results(resultp)
54+
cfgp = osp.join(directory, "config.json")
55+
cfg_dict = _parse_configs(cfgp)
56+
cfg_dict.update(res_dict)
57+
return cfg_dict
58+
59+
60+
def load_results_to_df(directory, result_name="result.json"):
61+
exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory)
62+
for f in files if f == result_name]
63+
data = [_resolve(directory, result_name) for directory in exp_directories]
64+
return pd.DataFrame(data)
65+
66+
67+
def generate_plotly_dim_dict(df, field):
68+
dim_dict = {}
69+
dim_dict["label"] = field
70+
column = df[field]
71+
if is_numeric_dtype(column):
72+
dim_dict["values"] = column
73+
elif is_string_dtype(column):
74+
texts = column.unique()
75+
dim_dict["values"] = [np.argwhere(texts == x).flatten()[0]
76+
for x in column]
77+
dim_dict["tickvals"] = list(range(len(texts)))
78+
dim_dict["ticktext"] = texts
79+
else:
80+
raise Exception("Unidentifiable Type")
81+
82+
return dim_dict

0 commit comments

Comments
 (0)