Skip to content

Commit d9cb3d0

Browse files
pavithrasv authored and Tensorflow Cloud maintainers committed
Add back Kaggle notebook integration (#224). The change got reverted because of some copybara issue in (e9c5e87#diff-b5e4b83c53a04cd75cd395f4b75352333339f0f0c58fdb720a37910d189ae0f9)
PiperOrigin-RevId: 340508010
1 parent 0b594b5 commit d9cb3d0

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

src/python/tensorflow_cloud/core/preprocess.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import io
21+
import logging
2022
import os
2123
import sys
2224
import tempfile
2325

2426
from . import machine_config
2527

28+
logger = logging.getLogger(__name__)
29+
logging.basicConfig(level=logging.INFO)
30+
2631
try:
2732
from nbconvert import PythonExporter # pylint: disable=g-import-not-at-top
2833
except ImportError:
@@ -164,7 +169,13 @@ def get_preprocessed_entry_point(
164169
'exec(open("{}").read())\n'.format(entry_point_file_name))
165170
else:
166171
if called_from_notebook:
167-
py_content = _get_colab_notebook_content()
172+
# Kaggle integration
173+
if os.getenv("KAGGLE_CONTAINER_NAME"):
174+
logger.info("Preprocessing Kaggle notebook...")
175+
py_content = _get_kaggle_notebook_content()
176+
else:
177+
# Colab integration
178+
py_content = _get_colab_notebook_content()
168179
else:
169180
if PythonExporter is None:
170181
raise RuntimeError(
@@ -212,6 +223,25 @@ def _get_colab_notebook_content():
212223
return py_content
213224

214225

226+
def _get_kaggle_notebook_content():
    """Returns the Kaggle notebook's Python code contents.

    Requests the current notebook as an exportable ipynb document from the
    Kaggle `UserSessionClient`, converts it to Python source with
    `PythonExporter`, and splits the result into lines.

    Returns:
        A list of source lines; each line keeps its line ending.

    Raises:
        RuntimeError: If `PythonExporter` is unavailable (None), or if the
            notebook could not be fetched or converted. In the latter case
            the underlying error is chained as `__cause__`.
    """
    if PythonExporter is None:
        # This should never occur.
        # `nbconvert` is always installed on Kaggle.
        raise RuntimeError(
            "Please make sure you have installed `nbconvert` package."
        )
    # Imported at call time rather than at module top; presumably
    # `kaggle_session` only exists inside the Kaggle environment - confirm.
    from kaggle_session import UserSessionClient  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
    kaggle_session_client = UserSessionClient()
    try:
        response = kaggle_session_client.get_exportable_ipynb()
        ipynb_stream = io.StringIO(response["source"])
        py_content, _ = PythonExporter().from_file(ipynb_stream)
        return py_content.splitlines(keepends=True)
    except Exception as err:
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt/SystemExit still propagate, and chain the
        # original error so the real failure stays visible in tracebacks.
        raise RuntimeError("Unable to get the notebook contents.") from err
243+
244+
215245
def get_tpu_cluster_resolver_fn():
216246
"""Returns the fn required for runnning custom container on cloud TPUs.
217247

0 commit comments

Comments
 (0)