 import tarfile
 import urllib.request
 from dataclasses import dataclass, field
-from typing import Dict, Optional
-
+from typing import Dict, List, Optional, Union
+import boto3
+import botocore.exceptions
 from benchmark import DATASETS_DIR
 from dataset_reader.ann_compound_reader import AnnCompoundReader
 from dataset_reader.ann_h5_reader import AnnH5Reader
+from dataset_reader.ann_h5_multi_reader import AnnH5MultiReader
 from dataset_reader.base_reader import BaseReader
 from dataset_reader.json_reader import JSONReader
+from tqdm import tqdm
+from pathlib import Path
 
 
 @dataclass
@@ -18,59 +22,236 @@ class DatasetConfig:
     distance: str
     name: str
     type: str
-    path: str
-    link: Optional[str] = None
+    path: Union[
+        str, Dict[str, List[Dict[str, str]]]
+    ]  # either a single relative file path, or a {"data": [...], "queries": [...]} mapping for h5-multi
+    link: Optional[Union[str, Dict[str, List[Dict[str, str]]]]] = None
     schema: Optional[Dict[str, str]] = field(default_factory=dict)


-READER_TYPE = {"h5": AnnH5Reader, "jsonl": JSONReader, "tar": AnnCompoundReader}
+READER_TYPE = {
+    "h5": AnnH5Reader,
+    "h5-multi": AnnH5MultiReader,
+    "jsonl": JSONReader,
+    "tar": AnnCompoundReader,
+}
+
+
+# Progress bar for urllib downloads
+def show_progress(block_num, block_size, total_size):
+    percent = round(block_num * block_size / total_size * 100, 2)
+    print(f"{percent}%", end="\r")
+
+
+# Progress handler for S3 downloads
+class S3Progress(tqdm):
+    def __init__(self, total_size):
+        super().__init__(
+            total=total_size, unit="B", unit_scale=True, desc="Downloading from S3"
+        )
+
+    def __call__(self, bytes_amount):
+        self.update(bytes_amount)


 class Dataset:
-    def __init__(self, config: dict):
+    def __init__(
+        self,
+        config: dict,
+        skip_upload: bool,
+        skip_search: bool,
+        upload_start_idx: int,
+        upload_end_idx: int,
+    ):
         self.config = DatasetConfig(**config)
+        self.skip_upload = skip_upload
+        self.skip_search = skip_search
+        self.upload_start_idx = upload_start_idx
+        self.upload_end_idx = upload_end_idx

     def download(self):
-        target_path = DATASETS_DIR / self.config.path
+        if isinstance(self.config.path, dict):  # Handle multi-file datasets
+            if self.skip_search is False:
+                # Download query files
+                for query in self.config.path.get("queries", []):
+                    self._download_file(query["path"], query["link"])
+            else:
+                print(
+                    f"skipping query file download given skip_search={self.skip_search}"
+                )
+            if self.skip_upload is False:
+                # Download only the data files that overlap the requested upload index window
+                for data in self.config.path.get("data", []):
+                    start_idx = data["start_idx"]
+                    end_idx = data["end_idx"]
+                    data_path = data["path"]
+                    data_link = data["link"]
+                    if self.upload_start_idx >= end_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} given {self.upload_start_idx} >= {end_idx}"
+                        )
+                        continue
+                    if self.upload_end_idx < start_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} given {self.upload_end_idx} < {start_idx}"
+                        )
+                        continue
+                    self._download_file(data_path, data_link)
+            else:
+                print(
+                    f"skipping data/upload file download given skip_upload={self.skip_upload}"
+                )
+
+        else:  # Handle single-file datasets
+            target_path = DATASETS_DIR / self.config.path

+            if target_path.exists():
+                print(f"{target_path} already exists")
+                return
+
+            if self.config.link:
+                downloaded_with_boto = False
+                if is_s3_link(self.config.link):
+                    print("Using boto3 to download from S3. Faster!")
+                    try:
+                        self._download_from_s3(self.config.link, target_path)
+                        downloaded_with_boto = True
+                    except botocore.exceptions.NoCredentialsError:
+                        print("Credentials not found, downloading without boto3")
+                if not downloaded_with_boto:
+                    print(f"Downloading from URL {self.config.link}...")
+                    tmp_path, _ = urllib.request.urlretrieve(
+                        self.config.link, None, show_progress
+                    )
+                    self._extract_or_move_file(tmp_path, target_path)
+
+    def _download_file(self, relative_path: str, url: str):
+        target_path = DATASETS_DIR / relative_path
         if target_path.exists():
             print(f"{target_path} already exists")
             return
 
-        if self.config.link:
-            print(f"Downloading {self.config.link}...")
-            tmp_path, _ = urllib.request.urlretrieve(self.config.link)
+        print(f"Downloading from {url} to {target_path}")
+        tmp_path, _ = urllib.request.urlretrieve(url, None, show_progress)
+        self._extract_or_move_file(tmp_path, target_path)
 
-            if self.config.link.endswith(".tgz") or self.config.link.endswith(
-                ".tar.gz"
-            ):
-                print(f"Extracting: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
-                file = tarfile.open(tmp_path)
+    def _extract_or_move_file(self, tmp_path, target_path):
+        if tmp_path.endswith(".tgz") or tmp_path.endswith(".tar.gz"):
+            print(f"Extracting: {tmp_path} -> {target_path}")
+            Path(target_path).mkdir(exist_ok=True, parents=True)
+            with tarfile.open(tmp_path) as file:
                 file.extractall(target_path)
-                file.close()
-                os.remove(tmp_path)
-            else:
-                print(f"Moving: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).parent.mkdir(exist_ok=True)
-                shutil.copy2(tmp_path, target_path)
-                os.remove(tmp_path)
+            os.remove(tmp_path)
+        else:
+            print(f"Moving: {tmp_path} -> {target_path}")
+            Path(target_path).parent.mkdir(exist_ok=True, parents=True)
+            shutil.copy2(tmp_path, target_path)
+            os.remove(tmp_path)
+
+    def _download_from_s3(self, link, target_path):
+        s3 = boto3.client("s3")
+        bucket_name, s3_key = parse_s3_url(link)
+        tmp_path = f"/tmp/{os.path.basename(s3_key)}"
+
+        print(
+            f"Downloading from S3: {link}... bucket_name={bucket_name}, s3_key={s3_key}"
+        )
+        object_info = s3.head_object(Bucket=bucket_name, Key=s3_key)
+        total_size = object_info["ContentLength"]
+
+        with open(tmp_path, "wb") as f:
+            progress = S3Progress(total_size)
+            s3.download_fileobj(bucket_name, s3_key, f, Callback=progress)
+
+        self._extract_or_move_file(tmp_path, target_path)
 
     def get_reader(self, normalize: bool) -> BaseReader:
         reader_class = READER_TYPE[self.config.type]
-        return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+        if self.config.type == "h5-multi":
+            # For h5-multi, we need to pass both data files and the query file
+            data_files = self.config.path["data"]
+            for data_file_dict in data_files:
+                data_file_dict["path"] = DATASETS_DIR / data_file_dict["path"]
+            query_file = DATASETS_DIR / self.config.path["queries"][0]["path"]
+            return reader_class(
+                data_files=data_files,
+                query_file=query_file,
+                normalize=normalize,
+                skip_upload=self.skip_upload,
+                skip_search=self.skip_search,
+            )
+        else:
+            # For single-file datasets
+            return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+
+def is_s3_link(link):
+    return link.startswith("s3://") or "s3.amazonaws.com" in link
+
+
+def parse_s3_url(s3_url):
+    if s3_url.startswith("s3://"):
+        s3_parts = s3_url.replace("s3://", "").split("/", 1)
+        bucket_name = s3_parts[0]
+        s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+    else:
+        s3_parts = s3_url.replace("http://", "").replace("https://", "").split("/", 1)
+
+        if ".s3.amazonaws.com" in s3_parts[0]:
+            bucket_name = s3_parts[0].split(".s3.amazonaws.com")[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+        else:
+            bucket_name = s3_parts[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+
+    return bucket_name, s3_key
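+
+
+# Illustrative example (not part of the original change): both accepted URL styles
+# resolve to the same (bucket, key) pair; "my-bucket" is a placeholder bucket name.
+#   parse_s3_url("s3://my-bucket/vecsim/queries.hdf5")
+#   parse_s3_url("https://my-bucket.s3.amazonaws.com/vecsim/queries.hdf5")
+# both return ("my-bucket", "vecsim/queries.hdf5").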


 if __name__ == "__main__":
-    dataset = Dataset(
+    dataset_s3_split = Dataset(
         {
-            "name": "glove-25-angular",
-            "vector_size": 25,
-            "distance": "Cosine",
-            "type": "h5",
-            "path": "glove-25-angular/glove-25-angular.hdf5",
-            "link": "http://ann-benchmarks.com/glove-25-angular.hdf5",
-        }
+            "name": "laion-img-emb-768d-1Billion-cosine",
+            "vector_size": 768,
+            "distance": "cosine",
+            "type": "h5-multi",
+            "path": {
+                "data": [
+                    {
+                        "file_number": 1,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "vector_range": "0-10000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 2,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "vector_range": "90000000-100000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 3,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "vector_range": "990000000-1000000000",
+                        "file_size": "30.7 GB",
+                    },
+                ],
+                "queries": [
+                    {
+                        "path": "laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "file_size": "38.7 MB",
+                    },
+                ],
+            },
+        },
+        skip_upload=True,
+        skip_search=False,
+        # required by __init__, but unused here since skip_upload=True
+        upload_start_idx=0,
+        upload_end_idx=0,
     )
 
-    dataset.download()
+    dataset_s3_split.download()
+    reader = dataset_s3_split.get_reader(normalize=False)
+    print(reader)  # Outputs the AnnH5MultiReader instance
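+
+    # Illustrative only (not part of the original change): an upload-side run would
+    # instead pass something like
+    #   Dataset(config, skip_upload=False, skip_search=True,
+    #           upload_start_idx=0, upload_end_idx=10_000_000)
+    # so that download() fetches only the data files whose index range overlaps that
+    # window, assuming each entry in path["data"] also carries the "start_idx" and
+    # "end_idx" keys that download() reads.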