|
17 | 17 | from __future__ import division
|
18 | 18 | from __future__ import print_function
|
19 | 19 |
|
| 20 | +import io |
| 21 | +import logging |
20 | 22 | import os
|
21 | 23 | import sys
|
22 | 24 | import tempfile
|
23 | 25 |
|
24 | 26 | from . import machine_config
|
25 | 27 |
|
| 28 | +logger = logging.getLogger(__name__) |
| 29 | +logging.basicConfig(level=logging.INFO) |
| 30 | + |
26 | 31 | try:
|
27 | 32 | from nbconvert import PythonExporter # pylint: disable=g-import-not-at-top
|
28 | 33 | except ImportError:
|
@@ -164,7 +169,13 @@ def get_preprocessed_entry_point(
|
164 | 169 | 'exec(open("{}").read())\n'.format(entry_point_file_name))
|
165 | 170 | else:
|
166 | 171 | if called_from_notebook:
|
167 |
| - py_content = _get_colab_notebook_content() |
| 172 | + # Kaggle integration |
| 173 | + if os.getenv("KAGGLE_CONTAINER_NAME"): |
| 174 | + logger.info("Preprocessing Kaggle notebook...") |
| 175 | + py_content = _get_kaggle_notebook_content() |
| 176 | + else: |
| 177 | + # Colab integration |
| 178 | + py_content = _get_colab_notebook_content() |
168 | 179 | else:
|
169 | 180 | if PythonExporter is None:
|
170 | 181 | raise RuntimeError(
|
@@ -212,6 +223,25 @@ def _get_colab_notebook_content():
|
212 | 223 | return py_content
|
213 | 224 |
|
214 | 225 |
|
| 226 | +def _get_kaggle_notebook_content(): |
| 227 | + """Returns the kaggle notebook python code contents.""" |
| 228 | + if PythonExporter is None: |
| 229 | + raise RuntimeError( |
| 230 | + # This should never occur. |
| 231 | + # `nbconvert` is always installed on Kaggle. |
| 232 | + "Please make sure you have installed `nbconvert` package." |
| 233 | + ) |
| 234 | + from kaggle_session import UserSessionClient # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
| 235 | + kaggle_session_client = UserSessionClient() |
| 236 | + try: |
| 237 | + response = kaggle_session_client.get_exportable_ipynb() |
| 238 | + ipynb_stream = io.StringIO(response["source"]) |
| 239 | + py_content, _ = PythonExporter().from_file(ipynb_stream) |
| 240 | + return py_content.splitlines(keepends=True) |
| 241 | + except: |
| 242 | + raise RuntimeError("Unable to get the notebook contents.") |
| 243 | + |
| 244 | + |
215 | 245 | def get_tpu_cluster_resolver_fn():
|
216 | 246 | """Returns the fn required for runnning custom container on cloud TPUs.
|
217 | 247 |
|
|
0 commit comments