
Commit ff89242

Fix GDrive URLs
1 parent 1a8fed7 commit ff89242

File tree: 1 file changed (+47 -11 lines changed)

tensorflow_datasets/core/download/downloader.py

Lines changed: 47 additions & 11 deletions
@@ -33,6 +33,7 @@
 from etils import epath
 from tensorflow_datasets.core import units
 from tensorflow_datasets.core import utils
+from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core.download import checksums as checksums_lib
 from tensorflow_datasets.core.download import resource as resource_lib
 from tensorflow_datasets.core.download import util as download_utils_lib
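For context, the new lazy_imports_lib import is what later lets the downloader pull in BeautifulSoup (bs4) only when a Drive confirmation page actually has to be parsed, instead of making bs4 a hard dependency. A rough sketch of the general lazy-import pattern, not the TFDS implementation itself:

import importlib


class _LazyImports:
  """Resolves optional modules on first attribute access."""

  def __getattr__(self, module_name: str):
    try:
      return importlib.import_module(module_name)
    except ImportError as e:
      raise ImportError(
          f'Optional dependency {module_name!r} is missing; install it to '
          'use this feature.'
      ) from e


lazy_imports = _LazyImports()

# Only evaluated when a Drive confirmation page is actually parsed:
# bs4 = lazy_imports.bs4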
@@ -130,6 +131,44 @@ def _get_filename(response: Response) -> str:
   return _basename_from_url(response.url)
 
 
+def _process_gdrive_confirmation(original_url: str, contents: str) -> str:
+  """Process Google Drive confirmation page.
+
+  Extracts the download link from a Google Drive confirmation page.
+
+  Args:
+    original_url: The URL the confirmation page was originally
+      retrieved from.
+    contents: The confirmation page's HTML.
+
+  Returns:
+    download_url: The URL for downloading the file.
+  """
+  bs4 = lazy_imports_lib.lazy_imports.bs4
+  soup = bs4.BeautifulSoup(contents, 'html.parser')
+  form = soup.find('form')
+  if not form:
+    raise ValueError(
+        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
+    )
+  action = form.get('action', '')
+  if not action:
+    raise ValueError(
+        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
+    )
+  # Find the <input>s named 'uuid', 'export', 'id' and 'confirm'
+  input_names = ['uuid', 'export', 'id', 'confirm']
+  params = {}
+  for name in input_names:
+    input_tag = form.find('input', {'name': name})
+    if input_tag:
+      params[name] = input_tag.get('value', '')
+  query_string = urllib.parse.urlencode(params)
+  download_url = f'{action}?{query_string}' if query_string else action
+  download_url = urllib.parse.urljoin(original_url, download_url)
+  return download_url
+
+
 class _Downloader:
   """Class providing async download API with checksum validation.
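The helper above reads the HTML interstitial that Drive serves for files it cannot virus-scan and rebuilds the real download URL from the hidden form fields. A minimal standalone sketch of the same parsing logic, using bs4 and urllib directly and a made-up sample page (the form layout is an assumption about what Drive currently serves):

import urllib.parse

import bs4


SAMPLE_PAGE = """
<html><body>
  <form action="https://drive.usercontent.google.com/download" method="get">
    <input type="hidden" name="id" value="FILE_ID">
    <input type="hidden" name="export" value="download">
    <input type="hidden" name="confirm" value="t">
    <input type="hidden" name="uuid" value="1234">
  </form>
</body></html>
"""


def extract_download_url(original_url: str, contents: str) -> str:
  """Rebuilds the direct download URL from a confirmation form."""
  soup = bs4.BeautifulSoup(contents, 'html.parser')
  form = soup.find('form')
  if form is None or not form.get('action'):
    raise ValueError(f'No confirmation form found for {original_url}.')
  params = {}
  for name in ('uuid', 'export', 'id', 'confirm'):
    input_tag = form.find('input', {'name': name})
    if input_tag:
      params[name] = input_tag.get('value', '')
  query = urllib.parse.urlencode(params)
  url = f"{form['action']}?{query}" if query else form['action']
  # urljoin keeps relative form actions working against the original URL.
  return urllib.parse.urljoin(original_url, url)


print(extract_download_url('https://drive.google.com/uc?id=FILE_ID', SAMPLE_PAGE))
# https://drive.usercontent.google.com/download?uuid=1234&export=download&id=FILE_ID&confirm=t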
@@ -318,11 +357,15 @@ def _open_with_requests(
     session.mount(
         'https://', requests.adapters.HTTPAdapter(max_retries=retries)
     )
-    if _DRIVE_URL.match(url):
-      url = _normalize_drive_url(url)
     with session.get(url, stream=True, **kwargs) as response:
-      _assert_status(response)
-      yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))
+      if _DRIVE_URL.match(url) and 'Content-Disposition' not in response.headers:
+        download_url = _process_gdrive_confirmation(url, response.text)
+        with session.get(download_url, stream=True, **kwargs) as download_response:
+          _assert_status(download_response)
+          yield (download_response, download_response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))
+      else:
+        _assert_status(response)
+        yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))
 
 
 @contextlib.contextmanager
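The Content-Disposition check above is the switch between the two paths: the patch relies on Drive attaching a Content-Disposition header when it serves the file directly, while the confirmation interstitial is a plain HTML page without one, so only then is a second request made. A hedged sketch of that two-request flow with requests, reusing the hypothetical extract_download_url helper from the sketch above:

import io

import requests


def open_drive_url(url: str) -> requests.Response:
  """Follows the Drive confirmation page once, if one is served."""
  session = requests.Session()
  response = session.get(url, stream=True)
  if 'Content-Disposition' in response.headers:
    return response  # Drive served the file directly.
  # Otherwise this is the HTML confirmation page: parse it and retry once.
  download_url = extract_download_url(url, response.text)
  return session.get(download_url, stream=True)


# Usage (FILE_ID is a placeholder); io.DEFAULT_BUFFER_SIZE mirrors the chunk
# size used in downloader.py:
# resp = open_drive_url('https://drive.google.com/uc?export=download&id=FILE_ID')
# for chunk in resp.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE):
#   ...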
@@ -338,13 +381,6 @@ def _open_with_urllib(
     )
 
 
-def _normalize_drive_url(url: str) -> str:
-  """Returns Google Drive url with confirmation token."""
-  # This bypasses the "Google Drive can't scan this file for viruses" warning
-  # when dowloading large files.
-  return url + '&confirm=t'
-
-
 def _assert_status(response: requests.Response) -> None:
   """Ensure the URL response is 200."""
   if response.status_code != 200:
