@@ -5,7 +5,11 @@ use futures_util::StreamExt;
 use sqlparser::ast::{
     CopyLegacyCsvOption, CopyLegacyOption, CopyOption, CopySource, CopyTarget, Statement,
 };
-use sqlx::{any::AnyArguments, AnyConnection, Arguments, Executor};
+use sqlx::{
+    any::{AnyArguments, AnyKind},
+    AnyConnection, Arguments, Executor,
+};
+use tokio::io::AsyncRead;
 
 use crate::webserver::http_request_info::RequestInfo;
 
@@ -152,7 +156,16 @@ pub(super) async fn run_csv_import(
     let file = tokio::fs::File::open(file_path)
         .await
         .with_context(|| "opening csv")?;
-    let insert_stmt = create_insert_stmt(db, csv_import);
+    let buffered = tokio::io::BufReader::new(file);
+    run_csv_import_on_path(db, csv_import, buffered).await
+}
+
+async fn run_csv_import_on_path(
+    db: &mut AnyConnection,
+    csv_import: &CsvImport,
+    file: impl AsyncRead + Unpin + Send,
+) -> anyhow::Result<()> {
+    let insert_stmt = create_insert_stmt(db.kind(), csv_import);
     log::debug!("CSV data insert statement: {insert_stmt}");
     let mut reader = make_csv_reader(csv_import, file);
     let col_idxs = compute_column_indices(&mut reader, csv_import).await?;
@@ -164,8 +177,8 @@ pub(super) async fn run_csv_import(
     Ok(())
 }
 
-async fn compute_column_indices(
-    reader: &mut csv_async::AsyncReader<tokio::fs::File>,
+async fn compute_column_indices<R: AsyncRead + Unpin + Send>(
+    reader: &mut csv_async::AsyncReader<R>,
     csv_import: &CsvImport,
 ) -> anyhow::Result<Vec<usize>> {
     let mut col_idxs = Vec::with_capacity(csv_import.columns.len());
@@ -189,16 +202,17 @@ async fn compute_column_indices(
     Ok(col_idxs)
 }
 
-fn create_insert_stmt(db: &mut AnyConnection, csv_import: &CsvImport) -> String {
-    let kind = db.kind();
+fn create_insert_stmt(kind: AnyKind, csv_import: &CsvImport) -> String {
     let columns = csv_import.columns.join(", ");
     let placeholders = csv_import
         .columns
         .iter()
         .enumerate()
-        .map(|(i, _)| make_placeholder(kind, i))
+        .map(|(i, _)| make_placeholder(kind, i + 1))
         .fold(String::new(), |mut acc, f| {
-            acc.push_str(", ");
+            if !acc.is_empty() {
+                acc.push_str(", ");
+            }
             acc.push_str(&f);
             acc
         });
@@ -225,10 +239,10 @@ async fn process_csv_record(
     Ok(())
 }
 
-fn make_csv_reader(
+fn make_csv_reader<R: AsyncRead + Unpin + Send>(
     csv_import: &CsvImport,
-    file: tokio::fs::File,
-) -> csv_async::AsyncReader<tokio::fs::File> {
+    file: R,
+) -> csv_async::AsyncReader<R> {
     let delimiter = csv_import
         .delimiter
         .and_then(|c| u8::try_from(c).ok())
@@ -246,3 +260,60 @@ fn make_csv_reader(
         .escape(escape)
         .create_reader(file)
 }
+
+#[test]
+fn test_make_statement() {
+    let csv_import = CsvImport {
+        query: "COPY my_table (col1, col2) FROM 'my_file.csv' WITH (DELIMITER ';', HEADER)".into(),
+        table_name: "my_table".into(),
+        columns: vec!["col1".into(), "col2".into()],
+        delimiter: Some(';'),
+        quote: None,
+        header: Some(true),
+        null_str: None,
+        escape: None,
+        uploaded_file: "my_file.csv".into(),
+    };
+    let insert_stmt = create_insert_stmt(AnyKind::Postgres, &csv_import);
+    assert_eq!(
+        insert_stmt,
+        "INSERT INTO my_table (col1, col2) VALUES ($1, $2)"
+    );
+}
+
+#[actix_web::test]
+async fn test_end_to_end() {
+    use sqlx::ConnectOptions;
+
+    let mut copy_stmt = sqlparser::parser::Parser::parse_sql(
+        &sqlparser::dialect::GenericDialect {},
+        "COPY my_table (col1, col2) FROM 'my_file.csv' WITH (DELIMITER ';', HEADER)",
+    )
+    .unwrap()
+    .into_iter()
+    .next()
+    .unwrap();
+    let csv_import = extract_csv_copy_statement(&mut copy_stmt).unwrap();
+    let mut conn = "sqlite::memory:"
+        .parse::<sqlx::any::AnyConnectOptions>()
+        .unwrap()
+        .connect()
+        .await
+        .unwrap();
+    conn.execute("CREATE TABLE my_table (col1 TEXT, col2 TEXT)")
+        .await
+        .unwrap();
+    let csv = "col2;col1\na;b\nc;d"; // column order differs from the table
+    let file = csv.as_bytes();
+    run_csv_import_on_path(&mut conn, &csv_import, file)
+        .await
+        .unwrap();
+    let rows: Vec<(String, String)> = sqlx::query_as("SELECT * FROM my_table")
+        .fetch_all(&mut conn)
+        .await
+        .unwrap();
+    assert_eq!(
+        rows,
+        vec![("b".into(), "a".into()), ("d".into(), "c".into())]
+    );
+}