@@ -125,7 +125,6 @@ def test_percentage(tmp_path):
125125 func_callback (1 )
126126
127127
128- @pytest .mark .skipolddriver
129128def test_upload_file_with_azure_upload_failed_error (tmp_path ):
130129 """Tests Upload file with expired Azure storage token."""
131130 file1 = tmp_path / "file1"
@@ -166,3 +165,94 @@ def test_upload_file_with_azure_upload_failed_error(tmp_path):
166165 rest_client .execute ()
167166 assert mock_update .called
168167 assert rest_client ._results [0 ].error_details is exc
168+
169+
170+ def test_iobound_limit (tmp_path ):
171+ file1 = tmp_path / "file1"
172+ file2 = tmp_path / "file2"
173+ file3 = tmp_path / "file3"
174+ file1 .touch ()
175+ file2 .touch ()
176+ file3 .touch ()
177+ # Positive case
178+ rest_client = SnowflakeFileTransferAgent (
179+ mock .MagicMock (autospec = SnowflakeCursor ),
180+ "PUT some_file.txt" ,
181+ {
182+ "data" : {
183+ "command" : "UPLOAD" ,
184+ "src_locations" : [file1 , file2 , file3 ],
185+ "sourceCompression" : "none" ,
186+ "stageInfo" : {
187+ "creds" : {
188+ "AZURE_SAS_TOKEN" : "sas_token" ,
189+ },
190+ "location" : "some_bucket" ,
191+ "region" : "no_region" ,
192+ "locationType" : "AZURE" ,
193+ "path" : "remote_loc" ,
194+ "endPoint" : "" ,
195+ "storageAccount" : "storage_account" ,
196+ },
197+ },
198+ "success" : True ,
199+ },
200+ )
201+ with mock .patch (
202+ "snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
203+ ) as tpe :
204+ with mock .patch ("snowflake.connector.file_transfer_agent.threading.Condition" ):
205+ with mock .patch (
206+ "snowflake.connector.file_transfer_agent.TransferMetadata" ,
207+ return_value = mock .Mock (
208+ num_files_started = 0 ,
209+ num_files_completed = 3 ,
210+ ),
211+ ):
212+ try :
213+ rest_client .execute ()
214+ except AttributeError :
215+ pass
216+ # 2 IObound TPEs should be created for 3 files unlimited
217+ rest_client = SnowflakeFileTransferAgent (
218+ mock .MagicMock (autospec = SnowflakeCursor ),
219+ "PUT some_file.txt" ,
220+ {
221+ "data" : {
222+ "command" : "UPLOAD" ,
223+ "src_locations" : [file1 , file2 , file3 ],
224+ "sourceCompression" : "none" ,
225+ "stageInfo" : {
226+ "creds" : {
227+ "AZURE_SAS_TOKEN" : "sas_token" ,
228+ },
229+ "location" : "some_bucket" ,
230+ "region" : "no_region" ,
231+ "locationType" : "AZURE" ,
232+ "path" : "remote_loc" ,
233+ "endPoint" : "" ,
234+ "storageAccount" : "storage_account" ,
235+ },
236+ },
237+ "success" : True ,
238+ },
239+ iobound_tpe_limit = 2 ,
240+ )
241+ assert len (list (filter (lambda e : e .args == (3 ,), tpe .call_args_list ))) == 2
242+ with mock .patch (
243+ "snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
244+ ) as tpe :
245+ with mock .patch ("snowflake.connector.file_transfer_agent.threading.Condition" ):
246+ with mock .patch (
247+ "snowflake.connector.file_transfer_agent.TransferMetadata" ,
248+ return_value = mock .Mock (
249+ num_files_started = 0 ,
250+ num_files_completed = 3 ,
251+ ),
252+ ):
253+ try :
254+ rest_client .execute ()
255+ except AttributeError :
256+ pass
257+ # 2 IObound TPEs should be created for 3 files limited to 2
258+ assert len (list (filter (lambda e : e .args == (2 ,), tpe .call_args_list ))) == 2
0 commit comments