|
| 1 | +#!/usr/bin/env python |
| 2 | +# encoding: utf-8 |
| 3 | +# |
| 4 | +# Copyright SAS Institute |
| 5 | +# |
| 6 | +# Licensed under the Apache License, Version 2.0 (the License); |
| 7 | +# you may not use this file except in compliance with the License. |
| 8 | +# You may obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, software |
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | +# See the License for the specific language governing permissions and |
| 16 | +# limitations under the License. |
| 17 | +# |
| 18 | + |
| 19 | +''' |
| 20 | +Utilities for Zeppelin Notebook integration |
| 21 | +
|
| 22 | +''' |
| 23 | + |
| 24 | +from __future__ import print_function, division, absolute_import, unicode_literals |
| 25 | + |
| 26 | +import base64 |
| 27 | +import cgi |
| 28 | +import pandas as pd |
| 29 | +import pprint |
| 30 | +import six |
| 31 | +import sys |
| 32 | +from ..utils.compat import a2b |
| 33 | + |
def img2tag(img, fmt='png', **kwargs):
    '''
    Convert image data into HTML tag with data URL

    Parameters
    ----------
    img : bytes
        The image data.  Leading and trailing whitespace bytes are
        stripped before the data is base64-encoded.
    fmt : string, optional
        The image subtype used in the ``data:image/<fmt>`` URL.
        ``'jpg'`` is accepted as an alias for ``'jpeg'``.
    **kwargs : keyword arguments
        CSS attributes as keyword arguments

    Returns
    -------
    HTML string

    '''
    # 'image/jpg' is not a registered MIME type; use the canonical subtype.
    if fmt in ('jpg', b'jpg'):
        fmt = 'jpeg'

    data_url = (b'data:image/' + a2b(fmt) + b';base64,' +
                base64.b64encode(img.strip()))

    declarations = ['%s:%s' % (key, value) for key, value in kwargs.items()]
    # Only emit a style attribute when CSS was actually supplied.
    style = ("style='%s' " % '; '.join(declarations)) if declarations else ''

    return "<img src='%s' %s/>" % (data_url.decode('ascii'), style)
| 56 | + |
def show(obj, **kwargs):
    '''
    Display object using the Zeppelin Display System

    The display method is chosen by duck typing, in this order:
    a ``_z_show_`` method, a callable ``head`` (DataFrame-like), a callable
    ``savefig`` (Matplotlib-like), ``_repr_png_``, ``_repr_jpeg_``,
    ``_repr_svg_``.  Anything else is pretty-printed inside an HTML
    ``<pre>`` element.

    Parameters
    ----------
    obj : any
        The object to display
    **kwargs : keyword arguments
        Passed through to the selected display function

    '''
    if hasattr(obj, '_z_show_'):
        obj._z_show_(**kwargs)

    elif hasattr(obj, 'head') and callable(obj.head):
        show_dataframe(obj, **kwargs)

    elif hasattr(obj, 'savefig') and callable(obj.savefig):
        show_matplotlib(obj, **kwargs)

    elif hasattr(obj, '_repr_png_'):
        show_image(obj, fmt='png', **kwargs)

    elif hasattr(obj, '_repr_jpeg_'):
        show_image(obj, fmt='jpeg', **kwargs)

    elif hasattr(obj, '_repr_svg_'):
        show_svg(obj, **kwargs)

    else:
        # cgi.escape was removed in Python 3.8; html.escape is its
        # replacement.  quote=False keeps the old cgi.escape default of
        # leaving quote characters untouched.
        try:
            from html import escape
        except ImportError:  # Python 2
            from cgi import escape
        print('%%html <pre>%s</pre>' % escape(pprint.pformat(obj), quote=False))
| 79 | + |
def show_image(img, fmt='png', width='auto', height='auto'):
    '''
    Display an Image object

    Parameters
    ----------
    img : object
        Object providing ``_repr_png_`` (for ``fmt='png'``) or
        ``_repr_jpeg_`` (for ``fmt='jpeg'`` / ``'jpg'``)
    fmt : string, optional
        The image format: 'png', 'jpeg', or 'jpg'
    width : string, optional
        CSS width of the enclosing div
    height : string, optional
        CSS height of the enclosing div

    Raises
    ------
    ValueError
        If `fmt` is not one of the supported formats

    '''
    if fmt == 'png':
        tag = img2tag(img._repr_png_(), fmt='png')

    elif fmt in ('jpeg', 'jpg'):
        # Pass the format explicitly; img2tag defaults to 'png', which
        # would label JPEG data as image/png.
        tag = img2tag(img._repr_jpeg_(), fmt='jpeg')

    else:
        raise ValueError("Image format must be 'png' or 'jpeg'.")

    out = "%html <div style='width:{width}; height:{height}'>{img}</div>"

    print(out.format(width=width, height=height, img=tag))
| 94 | + |
def show_svg(img, width='auto', height='auto'):
    '''
    Display an SVG object

    Parameters
    ----------
    img : object
        Object whose ``_repr_svg_`` method returns the SVG markup
    width : string, optional
        CSS width of the enclosing div
    height : string, optional
        CSS height of the enclosing div

    '''
    template = "%html <div style='width:{width}; height:{height}'>{img}</div>"
    markup = img._repr_svg_()
    print(template.format(img=markup, height=height, width=width))
| 102 | + |
def show_matplotlib(plt, fmt='png', width='auto', height='auto'):
    '''
    Display a Matplotlib plot

    Parameters
    ----------
    plt : object
        Object with a ``savefig(file, format=...)`` method
    fmt : string, optional
        The image format: 'png', 'jpeg', 'jpg', or 'svg'
    width : string, optional
        CSS width applied to the enclosing div (and to the img tag for
        raster formats)
    height : string, optional
        CSS height applied to the enclosing div (and to the img tag for
        raster formats)

    Raises
    ------
    ValueError
        If `fmt` is not one of the supported formats

    '''
    if fmt in ('png', 'jpeg', 'jpg'):
        buf = six.BytesIO()
        try:
            # Render into the buffer (the original passed the not-yet-defined
            # name `img` here, which raised UnboundLocalError).
            plt.savefig(buf, format=fmt)
            # Forward the real format so the data URL is not always
            # labelled image/png; 'jpg' maps to the 'jpeg' MIME subtype.
            img = img2tag(buf.getvalue(),
                          fmt='jpeg' if fmt == 'jpg' else fmt,
                          width=width, height=height)
        finally:
            buf.close()

    elif fmt == 'svg':
        buf = six.StringIO()
        try:
            plt.savefig(buf, format=fmt)
            img = buf.getvalue()
        finally:
            buf.close()

    else:
        raise ValueError("Image format must be 'png', 'jpeg', or 'svg'.")

    out = "%html <div style='width:{width}; height:{height}'>{img}</div>"

    print(out.format(width=width, height=height, img=img))
| 123 | + |
def show_dataframe(df, show_index=None, max_result=None, **kwargs):
    '''
    Display a DataFrame-like object in a Zeppelin notebook

    The output is written to ``sys.stdout`` using the Zeppelin ``%table``
    display: tab-separated cells, one row per line, preceded by an optional
    ``%html`` title taken from the object's ``title`` or ``label``
    attribute.

    Parameters
    ----------
    df : DataFrame-like
        The object to display.  It must provide ``head(n=...)`` returning
        an object with ``index``, ``columns`` and ``values``.
    show_index : bool, optional
        Should the index be displayed?  By default, If the index appears to
        simply be a row number (name is None, type is int), the index is
        not displayed.  Otherwise, it is displayed.
    max_result : int, optional
        The maximum number of rows to display.  Defaults to the Pandas option
        ``display.max_rows``.
    **kwargs : keyword arguments
        Accepted for compatibility with :func:`show`; not used here.

    '''
    title = getattr(df, 'title', getattr(df, 'label', None))
    # On a plain DataFrame, attribute access resolves to a *column* named
    # 'title' / 'label'; that is data, not a caption, and its truth value
    # is ambiguous.
    if isinstance(title, (pd.Series, pd.DataFrame)):
        title = None
    if title:
        sys.stdout.write('%%html <div>%s</div>\n\n' % title)

    sys.stdout.write('%table ')

    rows = df.head(n=max_result or pd.get_option('display.max_rows'))
    index = rows.index

    if show_index is None:
        # Hide an unnamed integer index; it is just a row number.
        show_index = not (index.names == [None] and
                          str(index.dtype).startswith('int'))

    # Header line.  Labels are formatted with '%s' % (x,) so that
    # non-string and tuple labels do not break str.join or %-formatting.
    if show_index and index.names:
        sys.stdout.write('\t'.join(['' if name is None else '%s' % (name,)
                                    for name in index.names]))
        sys.stdout.write('\t')

    sys.stdout.write('\t'.join(['%s' % (col,) for col in rows.columns]))
    sys.stdout.write('\n')

    # Data lines
    for idx, row in zip(index.values, rows.values):
        if show_index:
            if isinstance(idx, (list, tuple)):
                # Multi-level index: one cell per level
                sys.stdout.write('\t'.join(['%s' % (item,) for item in idx]))
            else:
                sys.stdout.write('%s' % (idx,))
            sys.stdout.write('\t')
        sys.stdout.write('\t'.join(['%s' % (item,) for item in row]))
        sys.stdout.write('\n')
0 commit comments