 import tarfile
 import urllib.request
 from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union
 from urllib.request import build_opener, install_opener
-
+import boto3
+import botocore.exceptions
 from benchmark import DATASETS_DIR
 from dataset_reader.ann_compound_reader import AnnCompoundReader
 from dataset_reader.ann_h5_reader import AnnH5Reader
+from dataset_reader.ann_h5_multi_reader import AnnH5MultiReader
 from dataset_reader.base_reader import BaseReader
 from dataset_reader.json_reader import JSONReader
-from dataset_reader.sparse_reader import SparseReader
+from tqdm import tqdm
+from pathlib import Path
 
 # Needed for Cloudflare's firewall in ann-benchmarks
 # See https://github.com/erikbern/ann-benchmarks/pull/561
 class DatasetConfig:
     name: str
     type: str
-    path: str
-
-    link: Optional[str] = None
+    # Can be a plain string or a dict describing a multi-file structure
+    path: Union[str, Dict[str, List[Dict[str, str]]]]
+    link: Optional[Union[str, Dict[str, List[Dict[str, str]]]]] = None
     schema: Optional[Dict[str, str]] = field(default_factory=dict)
     # None in case of sparse vectors:
     vector_size: Optional[int] = None
@@ -35,57 +37,227 @@ class DatasetConfig:
 
 READER_TYPE = {
     "h5": AnnH5Reader,
+    "h5-multi": AnnH5MultiReader,
     "jsonl": JSONReader,
     "tar": AnnCompoundReader,
-    "sparse": SparseReader,
 }
 
 
+# Progress bar for urllib downloads
+def show_progress(block_num, block_size, total_size):
+    percent = round(block_num * block_size / total_size * 100, 2)
+    print(f"{percent}%", end="\r")
+
+
+# Progress handler for S3 downloads
+class S3Progress(tqdm):
+    def __init__(self, total_size):
+        super().__init__(
+            total=total_size, unit="B", unit_scale=True, desc="Downloading from S3"
+        )
+
+    def __call__(self, bytes_amount):
+        self.update(bytes_amount)
+
+
 class Dataset:
-    def __init__(self, config: dict):
+    def __init__(
+        self,
+        config: dict,
+        skip_upload: bool,
+        skip_search: bool,
+        upload_start_idx: int,
+        upload_end_idx: int,
+    ):
         self.config = DatasetConfig(**config)
+        self.skip_upload = skip_upload
+        self.skip_search = skip_search
+        self.upload_start_idx = upload_start_idx
+        self.upload_end_idx = upload_end_idx
 
     def download(self):
-        target_path = DATASETS_DIR / self.config.path
+        if isinstance(self.config.path, dict):  # Handle multi-file datasets
+            if self.skip_search is False:
+                # Download query files
+                for query in self.config.path.get("queries", []):
+                    self._download_file(query["path"], query["link"])
+            else:
+                print(
+                    f"skipping query file download given skip_search={self.skip_search}"
+                )
+            if self.skip_upload is False:
+                # Download data files
+                for data in self.config.path.get("data", []):
+                    start_idx = data["start_idx"]
+                    end_idx = data["end_idx"]
+                    data_path = data["path"]
+                    data_link = data["link"]
+                    if self.upload_start_idx >= end_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} given {self.upload_start_idx}>={end_idx}"
+                        )
+                        continue
+                    if self.upload_end_idx < start_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} given {self.upload_end_idx}<{start_idx}"
+                        )
+                        continue
+                    self._download_file(data_path, data_link)
+            else:
+                print(
+                    f"skipping data file download given skip_upload={self.skip_upload}"
+                )
+
+        else:  # Handle single-file datasets
+            target_path = DATASETS_DIR / self.config.path
+
+            if target_path.exists():
+                print(f"{target_path} already exists")
+                return
+
+            if self.config.link:
+                downloaded_with_boto = False
+                if is_s3_link(self.config.link):
+                    print("Using boto3 to download from S3 (faster)")
+                    try:
+                        self._download_from_s3(self.config.link, target_path)
+                        downloaded_with_boto = True
+                    except botocore.exceptions.NoCredentialsError:
+                        print("Credentials not found, downloading without boto3")
+                if not downloaded_with_boto:
+                    print(f"Downloading from URL {self.config.link}...")
+                    tmp_path, _ = urllib.request.urlretrieve(
+                        self.config.link, None, show_progress
+                    )
+                    self._extract_or_move_file(tmp_path, target_path)
 
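As a worked illustration of the range filter above (not part of the module; `overlaps` is a hypothetical helper used only for this sketch): with an upload range of [0, 10_000_000), a data file covering vectors 0-10,000,000 is fetched, while one covering 90,000,000-100,000,000 is skipped.

def overlaps(file_start, file_end, upload_start, upload_end):
    # mirrors the two `continue` conditions in download() above
    return not (upload_start >= file_end or upload_end < file_start)

assert overlaps(0, 10_000_000, 0, 10_000_000)                # part 1: downloaded
assert not overlaps(90_000_000, 100_000_000, 0, 10_000_000)  # part 10: skipped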
+    def _download_file(self, relative_path: str, url: str):
+        target_path = DATASETS_DIR / relative_path
         if target_path.exists():
             print(f"{target_path} already exists")
             return
 
-        if self.config.link:
-            print(f"Downloading {self.config.link}...")
-            tmp_path, _ = urllib.request.urlretrieve(self.config.link)
+        print(f"Downloading from {url} to {target_path}")
+        tmp_path, _ = urllib.request.urlretrieve(url, None, show_progress)
+        self._extract_or_move_file(tmp_path, target_path)
 
-        if self.config.link.endswith(".tgz") or self.config.link.endswith(
-            ".tar.gz"
-        ):
-            print(f"Extracting: {tmp_path} -> {target_path}")
-            (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
-            file = tarfile.open(tmp_path)
+    def _extract_or_move_file(self, tmp_path, target_path):
+        if tmp_path.endswith(".tgz") or tmp_path.endswith(".tar.gz"):
+            print(f"Extracting: {tmp_path} -> {target_path}")
+            Path(target_path).mkdir(exist_ok=True, parents=True)
+            with tarfile.open(tmp_path) as file:
                 file.extractall(target_path)
-            file.close()
-            os.remove(tmp_path)
-        else:
-            print(f"Moving: {tmp_path} -> {target_path}")
-            (DATASETS_DIR / self.config.path).parent.mkdir(exist_ok=True)
-            shutil.copy2(tmp_path, target_path)
-            os.remove(tmp_path)
+            os.remove(tmp_path)
+        else:
+            print(f"Moving: {tmp_path} -> {target_path}")
+            Path(target_path).parent.mkdir(exist_ok=True, parents=True)
+            shutil.copy2(tmp_path, target_path)
+            os.remove(tmp_path)
+
+    def _download_from_s3(self, link, target_path):
+        s3 = boto3.client("s3")
+        bucket_name, s3_key = parse_s3_url(link)
+        tmp_path = f"/tmp/{os.path.basename(s3_key)}"
+
+        print(
+            f"Downloading from S3: {link} ... bucket_name={bucket_name}, s3_key={s3_key}"
+        )
+        object_info = s3.head_object(Bucket=bucket_name, Key=s3_key)
+        total_size = object_info["ContentLength"]
+
+        with open(tmp_path, "wb") as f:
+            progress = S3Progress(total_size)
+            s3.download_fileobj(bucket_name, s3_key, f, Callback=progress)
+
+        self._extract_or_move_file(tmp_path, target_path)
 
     def get_reader(self, normalize: bool) -> BaseReader:
         reader_class = READER_TYPE[self.config.type]
-        return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+        if self.config.type == "h5-multi":
+            # For h5-multi, we need to pass both data files and query file
+            data_files = self.config.path["data"]
+            for data_file_dict in data_files:
+                data_file_dict["path"] = DATASETS_DIR / data_file_dict["path"]
+            query_file = DATASETS_DIR / self.config.path["queries"][0]["path"]
+            return reader_class(
+                data_files=data_files,
+                query_file=query_file,
+                normalize=normalize,
+                skip_upload=self.skip_upload,
+                skip_search=self.skip_search,
+            )
+        else:
+            # For single-file datasets
+            return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+
+def is_s3_link(link):
+    return link.startswith("s3://") or "s3.amazonaws.com" in link
+
+
+def parse_s3_url(s3_url):
+    if s3_url.startswith("s3://"):
+        s3_parts = s3_url.replace("s3://", "").split("/", 1)
+        bucket_name = s3_parts[0]
+        s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+    else:
+        s3_parts = s3_url.replace("http://", "").replace("https://", "").split("/", 1)
+
+        if ".s3.amazonaws.com" in s3_parts[0]:
+            bucket_name = s3_parts[0].split(".s3.amazonaws.com")[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+        else:
+            bucket_name = s3_parts[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+
+    return bucket_name, s3_key
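Both accepted URL shapes reduce to the same (bucket, key) pair. A quick illustrative check of parse_s3_url, tracing the function above (the s3:// values are made up; the HTTP URL is the queries link from the example below):

# illustrative only; assumes these helpers are importable from this module
assert parse_s3_url("s3://my-bucket/vecsim/queries.hdf5") == (
    "my-bucket",
    "vecsim/queries.hdf5",
)
assert parse_s3_url(
    "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5"
) == (
    "benchmarks.redislabs",
    "vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
)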
 
 
 if __name__ == "__main__":
-    dataset = Dataset(
+    dataset_s3_split = Dataset(
         {
-            "name": "glove-25-angular",
-            "vector_size": 25,
-            "distance": "Cosine",
-            "type": "h5",
-            "path": "glove-25-angular/glove-25-angular.hdf5",
-            "link": "http://ann-benchmarks.com/glove-25-angular.hdf5",
-        }
+            "name": "laion-img-emb-768d-1Billion-cosine",
+            "vector_size": 768,
+            "distance": "cosine",
+            "type": "h5-multi",
+            "path": {
+                "data": [
+                    {
+                        "file_number": 1,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "vector_range": "0-10000000",
+                        "start_idx": 0,
+                        "end_idx": 10000000,
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 2,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "vector_range": "90000000-100000000",
+                        "start_idx": 90000000,
+                        "end_idx": 100000000,
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 3,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "vector_range": "990000000-1000000000",
+                        "start_idx": 990000000,
+                        "end_idx": 1000000000,
+                        "file_size": "30.7 GB",
+                    },
+                ],
+                "queries": [
+                    {
+                        "path": "laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "file_size": "38.7 MB",
+                    },
+                ],
+            },
+        },
+        skip_upload=True,
+        skip_search=False,
+        # required by the new constructor signature; not read when skip_upload=True
+        upload_start_idx=0,
+        upload_end_idx=0,
     )
 
-    dataset.download()
+    dataset_s3_split.download()
+    reader = dataset_s3_split.get_reader(normalize=False)
+    print(reader)  # prints the AnnH5MultiReader instance
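For comparison, a single-file dataset still goes through the `else` branches above. A sketch reusing the glove-25-angular config that this diff removes from __main__, adapted to the new constructor arguments (the upload index values are placeholders and are not read for single-file downloads):

dataset_single = Dataset(
    {
        "name": "glove-25-angular",
        "vector_size": 25,
        "distance": "Cosine",
        "type": "h5",
        "path": "glove-25-angular/glove-25-angular.hdf5",
        "link": "http://ann-benchmarks.com/glove-25-angular.hdf5",
    },
    skip_upload=False,
    skip_search=False,
    upload_start_idx=0,
    upload_end_idx=0,
)
dataset_single.download()
reader = dataset_single.get_reader(normalize=False)  # returns an AnnH5Reader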