11import bz2
2- import contextlib
32import gzip
43import hashlib
5- import itertools
64import lzma
75import os
86import os .path
1311import urllib
1412import urllib .error
1513import urllib .request
16- import warnings
1714import zipfile
1815from typing import Any , Callable , Dict , IO , Iterable , Iterator , List , Optional , Tuple , TypeVar
1916from urllib .parse import urlparse
2017
2118import numpy as np
22- import requests
2319import torch
2420from torch .utils .model_zoo import tqdm
2521
@@ -187,22 +183,6 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
187183 return files
188184
189185
190- def _extract_gdrive_api_response (response , chunk_size : int = 32 * 1024 ) -> Tuple [bytes , Iterator [bytes ]]:
191- content = response .iter_content (chunk_size )
192- first_chunk = None
193- # filter out keep-alive new chunks
194- while not first_chunk :
195- first_chunk = next (content )
196- content = itertools .chain ([first_chunk ], content )
197-
198- try :
199- match = re .search ("<title>Google Drive - (?P<api_response>.+?)</title>" , first_chunk .decode ())
200- api_response = match ["api_response" ] if match is not None else None
201- except UnicodeDecodeError :
202- api_response = None
203- return api_response , content
204-
205-
206186def download_file_from_google_drive (file_id : str , root : str , filename : Optional [str ] = None , md5 : Optional [str ] = None ):
207187 """Download a Google Drive file from and place it in root.
208188
@@ -212,7 +192,12 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
212192 filename (str, optional): Name to save the file under. If None, use the id of the file.
213193 md5 (str, optional): MD5 checksum of the download. If None, do not check
214194 """
215- # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
195+ try :
196+ import gdown
197+ except ModuleNotFoundError :
198+ raise RuntimeError (
199+ "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
200+ )
216201
217202 root = os .path .expanduser (root )
218203 if not filename :
@@ -225,51 +210,10 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
225210 print (f"Using downloaded { 'and verified ' if md5 else '' } file: { fpath } " )
226211 return
227212
228- url = "https://drive.google.com/uc"
229- params = dict (id = file_id , export = "download" )
230- with requests .Session () as session :
231- response = session .get (url , params = params , stream = True )
213+ gdown .download (id = file_id , output = fpath , quiet = False , user_agent = USER_AGENT )
232214
233- for key , value in response .cookies .items ():
234- if key .startswith ("download_warning" ):
235- token = value
236- break
237- else :
238- api_response , content = _extract_gdrive_api_response (response )
239- token = "t" if api_response == "Virus scan warning" else None
240-
241- if token is not None :
242- response = session .get (url , params = dict (params , confirm = token ), stream = True )
243- api_response , content = _extract_gdrive_api_response (response )
244-
245- if api_response == "Quota exceeded" :
246- raise RuntimeError (
247- f"The daily quota of the file { filename } is exceeded and it "
248- f"can't be downloaded. This is a limitation of Google Drive "
249- f"and can only be overcome by trying again later."
250- )
251-
252- _save_response_content (content , fpath )
253-
254- # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
255- if os .stat (fpath ).st_size < 10 * 1024 :
256- with contextlib .suppress (UnicodeDecodeError ), open (fpath ) as fh :
257- text = fh .read ()
258- # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
259- if re .search (r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)" , text ):
260- warnings .warn (
261- f"We detected some HTML elements in the downloaded file. "
262- f"This most likely means that the download triggered an unhandled API response by GDrive. "
263- f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
264- f"the response:\n \n { text } "
265- )
266-
267- if md5 and not check_md5 (fpath , md5 ):
268- raise RuntimeError (
269- f"The MD5 checksum of the download file { fpath } does not match the one on record."
270- f"Please delete the file and try again. "
271- f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
272- )
215+ if not check_integrity (fpath , md5 ):
216+ raise RuntimeError ("File not found or corrupted." )
273217
274218
275219def _extract_tar (from_path : str , to_path : str , compression : Optional [str ]) -> None :
0 commit comments