Skip to content

Commit b5df13d

Browse files
author
wjm41
committed
Merge branch 'fix/markers'
2 parents a0c5838 + fc8d154 commit b5df13d

File tree

2 files changed

+113
-42
lines changed

2 files changed

+113
-42
lines changed

example.ipynb

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,6 @@
3333
"import molplotly\n"
3434
]
3535
},
36-
{
37-
"cell_type": "code",
38-
"execution_count": 3,
39-
"metadata": {},
40-
"outputs": [],
41-
"source": [
42-
"%load_ext autoreload\n",
43-
"%autoreload 1"
44-
]
45-
},
4636
{
4737
"cell_type": "markdown",
4838
"metadata": {},
@@ -4459,7 +4449,7 @@
44594449
" "
44604450
],
44614451
"text/plain": [
4462-
"<IPython.lib.display.IFrame at 0x7fc93564f610>"
4452+
"<IPython.lib.display.IFrame at 0x7feaa9613490>"
44634453
]
44644454
},
44654455
"metadata": {},
@@ -4488,7 +4478,7 @@
44884478
},
44894479
{
44904480
"cell_type": "code",
4491-
"execution_count": 6,
4481+
"execution_count": 5,
44924482
"metadata": {},
44934483
"outputs": [
44944484
{
@@ -4506,7 +4496,7 @@
45064496
" "
45074497
],
45084498
"text/plain": [
4509-
"<IPython.lib.display.IFrame at 0x7fc9346330d0>"
4499+
"<IPython.lib.display.IFrame at 0x7feaa96d29e0>"
45104500
]
45114501
},
45124502
"metadata": {},
@@ -4542,7 +4532,7 @@
45424532
},
45434533
{
45444534
"cell_type": "code",
4545-
"execution_count": 9,
4535+
"execution_count": 6,
45464536
"metadata": {},
45474537
"outputs": [
45484538
{
@@ -4560,7 +4550,7 @@
45604550
" "
45614551
],
45624552
"text/plain": [
4563-
"<IPython.lib.display.IFrame at 0x20f8e103520>"
4553+
"<IPython.lib.display.IFrame at 0x7feaa96d2e30>"
45644554
]
45654555
},
45664556
"metadata": {},
@@ -4598,7 +4588,7 @@
45984588
},
45994589
{
46004590
"cell_type": "code",
4601-
"execution_count": 10,
4591+
"execution_count": 7,
46024592
"metadata": {},
46034593
"outputs": [
46044594
{
@@ -4616,7 +4606,7 @@
46164606
" "
46174607
],
46184608
"text/plain": [
4619-
"<IPython.lib.display.IFrame at 0x20f907deb80>"
4609+
"<IPython.lib.display.IFrame at 0x7feaaabd1420>"
46204610
]
46214611
},
46224612
"metadata": {},
@@ -4634,6 +4624,7 @@
46344624
" x=\"y_true\",\n",
46354625
" y=\"y_pred\",\n",
46364626
" size='Molecular Weight',\n",
4627+
" symbol='Minimum Degree',\n",
46374628
" color='dataset',\n",
46384629
" title='ESOL Regression (colored by random train/test split)',\n",
46394630
" labels={'y_pred': 'Predicted Solubility',\n",
@@ -4645,7 +4636,8 @@
46454636
" df=df_esol,\n",
46464637
" smiles_col='smiles',\n",
46474638
" title_col='Compound ID',\n",
4648-
" color_col='dataset')\n",
4639+
" color_col='dataset',\n",
4640+
" marker_col='Minimum Degree')\n",
46494641
"\n",
46504642
"app_train_test.run_server(mode='inline', port=8703, height=1000)\n"
46514643
]
@@ -4672,7 +4664,7 @@
46724664
},
46734665
{
46744666
"cell_type": "code",
4675-
"execution_count": 11,
4667+
"execution_count": 8,
46764668
"metadata": {},
46774669
"outputs": [
46784670
{
@@ -4690,7 +4682,7 @@
46904682
" "
46914683
],
46924684
"text/plain": [
4693-
"<IPython.lib.display.IFrame at 0x20f90876b80>"
4685+
"<IPython.lib.display.IFrame at 0x7feaaad74970>"
46944686
]
46954687
},
46964688
"metadata": {},
@@ -4733,7 +4725,7 @@
47334725
},
47344726
{
47354727
"cell_type": "code",
4736-
"execution_count": 12,
4728+
"execution_count": 9,
47374729
"metadata": {},
47384730
"outputs": [
47394731
{
@@ -4751,7 +4743,7 @@
47514743
" "
47524744
],
47534745
"text/plain": [
4754-
"<IPython.lib.display.IFrame at 0x20f909b2700>"
4746+
"<IPython.lib.display.IFrame at 0x7feaaad74760>"
47554747
]
47564748
},
47574749
"metadata": {},
@@ -4800,7 +4792,7 @@
48004792
},
48014793
{
48024794
"cell_type": "code",
4803-
"execution_count": null,
4795+
"execution_count": 10,
48044796
"metadata": {},
48054797
"outputs": [],
48064798
"source": [
@@ -4835,7 +4827,7 @@
48354827
},
48364828
{
48374829
"cell_type": "code",
4838-
"execution_count": 14,
4830+
"execution_count": 11,
48394831
"metadata": {},
48404832
"outputs": [
48414833
{
@@ -4853,7 +4845,7 @@
48534845
" "
48544846
],
48554847
"text/plain": [
4856-
"<IPython.lib.display.IFrame at 0x20f8b475b80>"
4848+
"<IPython.lib.display.IFrame at 0x7feaaadcfd90>"
48574849
]
48584850
},
48594851
"metadata": {},
@@ -4893,7 +4885,7 @@
48934885
},
48944886
{
48954887
"cell_type": "code",
4896-
"execution_count": 15,
4888+
"execution_count": 12,
48974889
"metadata": {},
48984890
"outputs": [],
48994891
"source": [
@@ -4932,7 +4924,7 @@
49324924
},
49334925
{
49344926
"cell_type": "code",
4935-
"execution_count": 16,
4927+
"execution_count": 13,
49364928
"metadata": {},
49374929
"outputs": [
49384930
{
@@ -4950,7 +4942,7 @@
49504942
" "
49514943
],
49524944
"text/plain": [
4953-
"<IPython.lib.display.IFrame at 0x20f91152070>"
4945+
"<IPython.lib.display.IFrame at 0x7feaab5afb20>"
49544946
]
49554947
},
49564948
"metadata": {},

molplotly/main.py

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,71 @@ def str2bool(v):
1818
return v.lower() in ("yes", "true", "t", "1")
1919

2020

21+
def test_groups(fig, df_grouped):
22+
"""Test if plotly figure curve names match up with pandas dataframe groups
23+
24+
Args:
25+
fig (plotly figure): _description_
26+
groups (pandas groupby object): _description_
27+
28+
Returns:
29+
_type_: Bool describing whether or not groups is the correct dataframe grouping descbrining the data in fig
30+
"""
31+
str_groups = {}
32+
for name, group in df_grouped:
33+
# if isinstance(name, bool) or isinstance(name, int):
34+
# str_groups[str(name)] = group
35+
if isinstance(name, tuple):
36+
str_groups[", ".join(str(x) for x in name)] = group
37+
else:
38+
str_groups[name] = group
39+
40+
for data in fig.data:
41+
if data.name in str_groups:
42+
if len(data.y) == len(str_groups[data.name]):
43+
continue
44+
else:
45+
return False
46+
return True
47+
48+
49+
def find_grouping(fig, df_data, cols):
50+
51+
if len(cols) == 1:
52+
df_grouped = df_data.groupby(cols)
53+
if not test_groups(fig, df_grouped):
54+
raise ValueError(
55+
"marker_col is mispecified because the dataframe grouping names don't match the names in the plotly figure."
56+
)
57+
58+
elif len(cols) == 2: # color_col and marker_col
59+
60+
df_grouped_x = df_data.groupby(cols)
61+
df_grouped_y = df_data.groupby([cols[1], cols[0]])
62+
63+
if test_groups(fig, df_grouped_x):
64+
df_grouped = df_grouped_x
65+
66+
elif test_groups(fig, df_grouped_y):
67+
df_grouped = df_grouped_y
68+
else:
69+
raise ValueError(
70+
"color_col and marker_col are mispecified because their dataframe grouping names don't match the names in the plotly figure."
71+
)
72+
else:
73+
raise ValueError("Too many columns specified for grouping.")
74+
75+
str_groups = {}
76+
for name, group in df_grouped:
77+
if isinstance(name, tuple):
78+
str_groups[", ".join(str(x) for x in name)] = group
79+
else:
80+
str_groups[name] = group
81+
82+
curve_dict = {index: str_groups[x["name"]] for index, x in enumerate(fig.data)}
83+
return df_grouped, curve_dict
84+
85+
2186
def add_molecules(
2287
fig,
2388
df,
@@ -31,6 +96,7 @@ def add_molecules(
3196
caption_cols=None,
3297
caption_transform={},
3398
color_col=None,
99+
marker_col=None,
34100
wrap=True,
35101
wraplen=20,
36102
width=150,
@@ -77,22 +143,25 @@ def add_molecules(
77143
the font size used in the hover box - the font of the title line is fontsize+2 (default 12)
78144
"""
79145
fig.update_traces(hoverinfo="none", hovertemplate=None)
80-
146+
df_data = df.copy()
147+
if color_col is not None:
148+
df_data[color_col] = df_data[color_col].astype(str)
149+
if marker_col is not None:
150+
df_data[marker_col] = df_data[marker_col].astype(str)
81151
colors = {0: "black"}
152+
82153
if len(fig.data) != 1:
83-
if color_col is not None:
84-
colors = {index: x.marker["color"] for index, x in enumerate(fig.data)}
85-
if df[color_col].dtype == bool:
86-
curve_dict = {
87-
index: str2bool(x["name"]) for index, x in enumerate(fig.data)
88-
}
89-
elif df[color_col].dtype == int:
90-
curve_dict = {index: int(x["name"]) for index, x in enumerate(fig.data)}
91-
else:
92-
curve_dict = {index: x["name"] for index, x in enumerate(fig.data)}
93-
else:
154+
if color_col is None and marker_col is None:
94155
raise ValueError(
95-
"color_col needs to be specified if there is more than one plotly curve in the figure!"
156+
"More than one plotly curve in figure - color_col and/or marker_col needs to be specified."
157+
)
158+
if color_col is None:
159+
df_grouped, curve_dict = find_grouping(fig, df_data, [marker_col])
160+
elif marker_col is None:
161+
df_grouped, curve_dict = find_grouping(fig, df_data, [color_col])
162+
else:
163+
df_grouped, curve_dict = find_grouping(
164+
fig, df_data, [color_col, marker_col]
96165
)
97166

98167
app = JupyterDash(__name__)
@@ -143,8 +212,16 @@ def display_hover(hoverData, value):
143212
num = pt["pointNumber"]
144213
curve_num = pt["curveNumber"]
145214

215+
# print(hoverData)
216+
# print(pt)
217+
146218
if len(fig.data) != 1:
147-
df_curve = df[df[color_col] == curve_dict[curve_num]].reset_index(drop=True)
219+
# TODO replace with query
220+
# df_curve = df_grouped.get_group(curve_dict[curve_num]).reset_index(
221+
# drop=True
222+
# )
223+
df_curve = curve_dict[curve_num].reset_index(drop=True)
224+
# df_curve = df[df[color_col] == curve_dict[curve_num]]
148225
df_row = df_curve.iloc[num]
149226
else:
150227
df_row = df.iloc[num]
@@ -197,6 +274,8 @@ def display_hover(hoverData, value):
197274
title = textwrap.fill(title, width=wraplen)
198275
else:
199276
title = title[:wraplen] + "..."
277+
278+
# TODO colorbar color titles
200279
hoverbox_elements.append(
201280
html.H4(
202281
f"{title}",

0 commit comments

Comments
 (0)