1010
1111
1212UNSAFE_TYPES = frozenset (["joblib" ])
13- REQUIRES_SINGLE_FILE = frozenset (["csv" , "joblib" , "file" ])
13+ REQUIRES_SINGLE_FILE = frozenset (["csv" , "joblib" ])
1414
1515
1616def _assert_is_pandas_df (x , file_type : str ) -> None :
@@ -22,35 +22,24 @@ def _assert_is_pandas_df(x, file_type: str) -> None:
2222 )
2323
2424
25- def load_path (meta , path_to_version ):
26- # Check that only a single file name was given
27- fnames = [meta .file ] if isinstance (meta .file , str ) else meta .file
28-
29- _type = meta .type
30-
31- if len (fnames ) > 1 and _type in REQUIRES_SINGLE_FILE :
32- raise ValueError ("Cannot load data when more than 1 file" )
33-
25+ def load_path (filename : str , path_to_version , pin_type = None ):
3426 # file path creation ------------------------------------------------------
35-
36- if _type == "table" :
27+ if pin_type == "table" :
3728 # this type contains an rds and csv files named data.{ext}, so we match
3829 # R pins behavior and hardcode the name
39- target_fname = "data.csv"
40- else :
41- target_fname = fnames [0 ]
30+ filename = "data.csv"
4231
4332 if path_to_version is not None :
44- path_to_file = f"{ path_to_version } /{ target_fname } "
33+ path_to_file = f"{ path_to_version } /{ filename } "
4534 else :
4635 # BoardUrl doesn't have versions, and the file is the full url
47- path_to_file = target_fname
36+ path_to_file = filename
4837
4938 return path_to_file
5039
5140
52- def load_file (meta : Meta , fs , path_to_version ):
53- return fs .open (load_path (meta , path_to_version ))
41+ def load_file (filename : str , fs , path_to_version , pin_type ):
42+ return fs .open (load_path (filename , path_to_version , pin_type ))
5443
5544
5645def load_data (
@@ -81,7 +70,7 @@ def load_data(
8170 " * https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations"
8271 )
8372
84- with load_file (meta , fs , path_to_version ) as f :
73+ with load_file (meta . file , fs , path_to_version , meta . type ) as f :
8574 if meta .type == "csv" :
8675 import pandas as pd
8776
@@ -136,7 +125,9 @@ def load_data(
136125 raise NotImplementedError (f"No driver for type { meta .type } " )
137126
138127
139- def save_data (obj , fname , type = None , apply_suffix : bool = True ) -> "str | Sequence[str]" :
128+ def save_data (
129+ obj , fname , pin_type = None , apply_suffix : bool = True
130+ ) -> "str | Sequence[str]" :
140131 # TODO: extensible saving with deferred importing
141132 # TODO: how to encode arguments to saving / loading drivers?
142133 # e.g. pandas index options
@@ -145,59 +136,68 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
145136 # of saving / loading objects different ways.
146137
147138 if apply_suffix :
148- if type == "file" :
139+ if pin_type == "file" :
149140 suffix = "" .join (Path (obj ).suffixes )
150141 else :
151- suffix = f".{ type } "
142+ suffix = f".{ pin_type } "
152143 else :
153144 suffix = ""
154145
155- final_name = f"{ fname } { suffix } "
146+ if isinstance (fname , list ):
147+ final_name = fname
148+ else :
149+ final_name = f"{ fname } { suffix } "
156150
157- if type == "csv" :
151+ if pin_type == "csv" :
158152 _assert_is_pandas_df (obj , file_type = type )
159153
160154 obj .to_csv (final_name , index = False )
161155
162- elif type == "arrow" :
156+ elif pin_type == "arrow" :
163157 # NOTE: R pins accepts the type arrow, and saves it as feather.
164158 # we allow reading this type, but raise an error for writing.
165159 _assert_is_pandas_df (obj , file_type = type )
166160
167161 obj .to_feather (final_name )
168162
169- elif type == "feather" :
163+ elif pin_type == "feather" :
170164 _assert_is_pandas_df (obj , file_type = type )
171165
172166 raise NotImplementedError (
173167 'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
174168 )
175169
176- elif type == "parquet" :
170+ elif pin_type == "parquet" :
177171 _assert_is_pandas_df (obj , file_type = type )
178172
179173 obj .to_parquet (final_name )
180174
181- elif type == "joblib" :
175+ elif pin_type == "joblib" :
182176 import joblib
183177
184178 joblib .dump (obj , final_name )
185179
186- elif type == "json" :
180+ elif pin_type == "json" :
187181 import json
188182
189183 json .dump (obj , open (final_name , "w" ))
190184
191- elif type == "file" :
185+ elif pin_type == "file" :
192186 import contextlib
193187 import shutil
194188
189+ if isinstance (obj , list ):
190+ for file , final in zip (obj , final_name ):
191+ with contextlib .suppress (shutil .SameFileError ):
192+ shutil .copyfile (str (file ), final )
193+ return obj
195194 # ignore the case where the source is the same as the target
196- with contextlib .suppress (shutil .SameFileError ):
197- shutil .copyfile (str (obj ), final_name )
195+ else :
196+ with contextlib .suppress (shutil .SameFileError ):
197+ shutil .copyfile (str (obj ), final_name )
198198
199199 else :
200- raise NotImplementedError (f"Cannot save type: { type } " )
200+ raise NotImplementedError (f"Cannot save type: { pin_type } " )
201201
202202 return final_name
203203
0 commit comments