33
33
from etils import epath
34
34
from tensorflow_datasets .core import units
35
35
from tensorflow_datasets .core import utils
36
+ from tensorflow_datasets .core import lazy_imports_lib
36
37
from tensorflow_datasets .core .download import checksums as checksums_lib
37
38
from tensorflow_datasets .core .download import resource as resource_lib
38
39
from tensorflow_datasets .core .download import util as download_utils_lib
@@ -130,6 +131,44 @@ def _get_filename(response: Response) -> str:
130
131
return _basename_from_url (response .url )
131
132
132
133
134
+ def _process_gdrive_confirmation (original_url : str , contents : str ) -> str :
135
+ """Process Google Drive confirmation page.
136
+
137
+ Extracts the download link from a Google Drive confirmation page.
138
+
139
+ Args:
140
+ original_url: The URL the confirmation page was originally
141
+ retrieved from.
142
+ contents: The confirmation page's HTML.
143
+
144
+ Returns:
145
+ download_url: The URL for downloading the file.
146
+ """
147
+ bs4 = lazy_imports_lib .lazy_imports .bs4
148
+ soup = bs4 .BeautifulSoup (contents , 'html.parser' )
149
+ form = soup .find ('form' )
150
+ if not form :
151
+ raise ValueError (
152
+ f'Failed to obtain confirmation link for GDrive URL { original_url } .'
153
+ )
154
+ action = form .get ('action' , '' )
155
+ if not action :
156
+ raise ValueError (
157
+ f'Failed to obtain confirmation link for GDrive URL { original_url } .'
158
+ )
159
+ # Find the <input>s named 'uuid', 'export', 'id' and 'confirm'
160
+ input_names = ['uuid' , 'export' , 'id' , 'confirm' ]
161
+ params = {}
162
+ for name in input_names :
163
+ input_tag = form .find ('input' , {'name' : name })
164
+ if input_tag :
165
+ params [name ] = input_tag .get ('value' , '' )
166
+ query_string = urllib .parse .urlencode (params )
167
+ download_url = f'{ action } ?{ query_string } ' if query_string else action
168
+ download_url = urllib .parse .urljoin (original_url , download_url )
169
+ return download_url
170
+
171
+
133
172
class _Downloader :
134
173
"""Class providing async download API with checksum validation.
135
174
@@ -318,11 +357,15 @@ def _open_with_requests(
318
357
session .mount (
319
358
'https://' , requests .adapters .HTTPAdapter (max_retries = retries )
320
359
)
321
- if _DRIVE_URL .match (url ):
322
- url = _normalize_drive_url (url )
323
360
with session .get (url , stream = True , ** kwargs ) as response :
324
- _assert_status (response )
325
- yield (response , response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
361
+ if _DRIVE_URL .match (url ) and 'Content-Disposition' not in response .headers :
362
+ download_url = _process_gdrive_confirmation (url , response .text )
363
+ with session .get (download_url , stream = True , ** kwargs ) as download_response :
364
+ _assert_status (download_response )
365
+ yield (download_response , download_response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
366
+ else :
367
+ _assert_status (response )
368
+ yield (response , response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
326
369
327
370
328
371
@contextlib .contextmanager
@@ -338,13 +381,6 @@ def _open_with_urllib(
338
381
)
339
382
340
383
341
- def _normalize_drive_url (url : str ) -> str :
342
- """Returns Google Drive url with confirmation token."""
343
- # This bypasses the "Google Drive can't scan this file for viruses" warning
344
- # when dowloading large files.
345
- return url + '&confirm=t'
346
-
347
-
348
384
def _assert_status (response : requests .Response ) -> None :
349
385
"""Ensure the URL response is 200."""
350
386
if response .status_code != 200 :
0 commit comments