1+ use std:: ffi:: CStr ;
12use std:: ptr;
23
34use vortex:: error:: { VortexResult , vortex_err} ;
45
5- use crate :: duckdb:: Database ;
6+ use crate :: duckdb:: { Database , QueryResult } ;
67use crate :: { cpp, duckdb_try, wrapper} ;
78
89wrapper ! (
@@ -22,8 +23,8 @@ impl Connection {
2223 Ok ( unsafe { Self :: own ( ptr) } )
2324 }
2425
25- /// Execute SQL query and return the row count .
26- pub fn execute_and_get_row_count ( & self , query : & str ) -> VortexResult < usize > {
26+ /// Execute SQL query and return the result .
27+ pub fn query ( & self , query : & str ) -> VortexResult < QueryResult > {
2728 let mut result: cpp:: duckdb_result = unsafe { std:: mem:: zeroed ( ) } ;
2829 let query_cstr =
2930 std:: ffi:: CString :: new ( query) . map_err ( |_| vortex_err ! ( "Invalid query string" ) ) ?;
@@ -36,48 +37,202 @@ impl Connection {
3637 if error_ptr. is_null ( ) {
3738 "Unknown DuckDB error" . to_string ( )
3839 } else {
39- std:: ffi:: CStr :: from_ptr ( error_ptr)
40- . to_string_lossy ( )
41- . into_owned ( )
40+ CStr :: from_ptr ( error_ptr) . to_string_lossy ( ) . into_owned ( )
4241 }
4342 } ;
4443
4544 unsafe { cpp:: duckdb_destroy_result ( & mut result) } ;
4645 return Err ( vortex_err ! ( "Failed to execute query: {}" , error_msg) ) ;
4746 }
4847
49- let row_count = unsafe { cpp:: duckdb_row_count ( & mut result) . try_into ( ) ? } ;
50- unsafe { cpp:: duckdb_destroy_result ( & mut result) } ;
48+ Ok ( unsafe { QueryResult :: new ( result) } )
49+ }
50+ }
5151
52- Ok ( row_count)
52+ #[ cfg( test) ]
53+ mod tests {
54+ use super :: * ;
55+
56+ fn test_connection ( ) -> VortexResult < Connection > {
57+ let db = Database :: open_in_memory ( ) ?;
58+ db. connect ( )
5359 }
5460
55- /// Execute SQL query.
56- pub fn execute ( & self , query : & str ) -> VortexResult < ( ) > {
57- let mut result : cpp :: duckdb_result = unsafe { std :: mem :: zeroed ( ) } ;
58- let query_cstr =
59- std :: ffi :: CString :: new ( query ) . map_err ( |_| vortex_err ! ( "Invalid query string" ) ) ? ;
61+ # [ test ]
62+ fn test_connection_creation ( ) {
63+ let conn = test_connection ( ) ;
64+ assert ! ( conn . is_ok ( ) ) ;
65+ }
6066
61- let status = unsafe { cpp:: duckdb_query ( self . as_ptr ( ) , query_cstr. as_ptr ( ) , & mut result) } ;
67+ #[ test]
68+ fn test_execute_success ( ) {
69+ let conn = test_connection ( ) . unwrap ( ) ;
70+ let result = conn. query ( "SELECT 1" ) ;
71+ assert ! ( result. is_ok( ) ) ;
72+ }
6273
63- if status != cpp:: duckdb_state:: DuckDBSuccess {
64- let error_msg = unsafe {
65- let error_ptr = cpp:: duckdb_result_error ( & mut result) ;
66- if error_ptr. is_null ( ) {
67- "Unknown DuckDB error" . to_string ( )
68- } else {
69- std:: ffi:: CStr :: from_ptr ( error_ptr)
70- . to_string_lossy ( )
71- . into_owned ( )
72- }
73- } ;
74+ #[ test]
75+ fn test_execute_invalid_sql ( ) {
76+ let conn = test_connection ( ) . unwrap ( ) ;
77+ let result = conn. query ( "INVALID SQL STATEMENT" ) ;
78+ assert ! ( result. is_err( ) ) ;
79+ let error_msg = result. unwrap_err ( ) . to_string ( ) ;
80+ assert ! ( error_msg. contains( "Failed to execute query" ) ) ;
81+ }
7482
75- unsafe { cpp:: duckdb_destroy_result ( & mut result) } ;
76- return Err ( vortex_err ! ( "Failed to execute query: {}" , error_msg) ) ;
77- }
83+ #[ test]
84+ fn test_execute_with_null_bytes ( ) {
85+ let conn = test_connection ( ) . unwrap ( ) ;
86+ let result = conn. query ( "SELECT\0 1" ) ;
87+ assert ! ( result. is_err( ) ) ;
88+ let error_msg = result. unwrap_err ( ) . to_string ( ) ;
89+ assert ! ( error_msg. contains( "Invalid query string" ) ) ;
90+ }
91+
92+ #[ test]
93+ fn test_query_and_get_row_count_select ( ) {
94+ let conn = test_connection ( ) . unwrap ( ) ;
95+ let result = conn. query ( "SELECT 1, 2, 3" ) . unwrap ( ) ;
96+ assert_eq ! ( result. row_count( ) . unwrap( ) , 1 ) ;
97+ }
98+
99+ #[ test]
100+ fn test_query_and_get_row_count_create_table ( ) {
101+ let conn = test_connection ( ) . unwrap ( ) ;
102+
103+ // CREATE TABLE should return 0 rows
104+ let result = conn
105+ . query ( "CREATE TABLE test (id INTEGER, name VARCHAR)" )
106+ . unwrap ( ) ;
107+ assert_eq ! ( result. row_count( ) . unwrap( ) , 0 ) ;
108+ }
109+
110+ #[ test]
111+ fn test_query_and_get_row_count_insert ( ) {
112+ let conn = test_connection ( ) . unwrap ( ) ;
113+ conn. query ( "CREATE TABLE test (id INTEGER, name VARCHAR)" )
114+ . unwrap ( ) ;
115+
116+ let result = conn
117+ . query ( "INSERT INTO test VALUES (1, 'Alice'), (2, 'Bob')" )
118+ . unwrap ( ) ;
119+
120+ assert_eq ! ( result. row_count( ) . unwrap( ) , 2 ) ;
121+ }
122+
123+ #[ test]
124+ fn test_query_invalid_sql ( ) {
125+ let conn = test_connection ( ) . unwrap ( ) ;
126+ let result = conn. query ( "INVALID SQL" ) ;
127+ assert ! ( result. is_err( ) ) ;
128+ }
129+
130+ #[ test]
131+ fn test_query_single_value ( ) {
132+ let conn = test_connection ( ) . unwrap ( ) ;
133+ let result = conn. query ( "SELECT 42" ) . unwrap ( ) ;
134+
135+ assert_eq ! ( result. column_count( ) . unwrap( ) , 1 ) ;
136+ assert_eq ! ( result. row_count( ) . unwrap( ) , 1 ) ;
137+ assert_eq ! ( result. get:: <i64 >( 0 , 0 ) . unwrap( ) , 42 ) ;
138+ }
139+
140+ #[ test]
141+ fn test_query_multiple_rows ( ) {
142+ let conn = test_connection ( ) . unwrap ( ) ;
143+ conn. query ( "CREATE TABLE test (id INTEGER)" ) . unwrap ( ) ;
144+ conn. query ( "INSERT INTO test VALUES (1), (2), (3)" ) . unwrap ( ) ;
145+
146+ let result = conn. query ( "SELECT id FROM test ORDER BY id" ) . unwrap ( ) ;
147+
148+ assert_eq ! ( result. column_count( ) . unwrap( ) , 1 ) ;
149+ assert_eq ! ( result. row_count( ) . unwrap( ) , 3 ) ;
150+ assert_eq ! ( result. get:: <i64 >( 0 , 0 ) . unwrap( ) , 1 ) ;
151+ assert_eq ! ( result. get:: <i64 >( 0 , 1 ) . unwrap( ) , 2 ) ;
152+ assert_eq ! ( result. get:: <i64 >( 0 , 2 ) . unwrap( ) , 3 ) ;
153+ }
154+
155+ #[ test]
156+ fn test_query_multiple_columns ( ) {
157+ let conn = test_connection ( ) . unwrap ( ) ;
158+ let result = conn. query ( "SELECT 1 as num, 'hello' as text" ) . unwrap ( ) ;
159+
160+ assert_eq ! ( result. column_count( ) . unwrap( ) , 2 ) ;
161+ assert_eq ! ( result. row_count( ) . unwrap( ) , 1 ) ;
162+ assert_eq ! ( result. column_name( 0 ) . unwrap( ) , "num" ) ;
163+ assert_eq ! ( result. column_name( 1 ) . unwrap( ) , "text" ) ;
164+ assert_eq ! ( result. get:: <i64 >( 0 , 0 ) . unwrap( ) , 1 ) ;
165+ assert_eq ! ( result. get:: <String >( 1 , 0 ) . unwrap( ) , "hello" ) ;
166+ }
167+
168+ #[ test]
169+ fn test_query_bounds_checking ( ) {
170+ let conn = test_connection ( ) . unwrap ( ) ;
171+ let result = conn. query ( "SELECT 1" ) . unwrap ( ) ;
172+
173+ // Test row bounds
174+ assert ! ( result. get:: <i64 >( 0 , 1 ) . is_err( ) ) ;
175+
176+ // Test column bounds
177+ assert ! ( result. get:: <i64 >( 1 , 0 ) . is_err( ) ) ;
178+ }
179+
180+ #[ test]
181+ fn test_query_column_types ( ) {
182+ let conn = test_connection ( ) . unwrap ( ) ;
183+ let result = conn
184+ . query ( "SELECT 1 as int_col, 'text' as str_col" )
185+ . unwrap ( ) ;
186+
187+ assert_eq ! ( result. column_type( 0 ) , cpp:: DUCKDB_TYPE :: DUCKDB_TYPE_INTEGER ) ;
188+ assert_eq ! ( result. column_type( 1 ) , cpp:: DUCKDB_TYPE :: DUCKDB_TYPE_VARCHAR ) ;
189+ }
190+
191+ #[ test]
192+ fn test_null_handling ( ) {
193+ let conn = test_connection ( ) . unwrap ( ) ;
194+ let result = conn
195+ . query ( "SELECT NULL as null_col, 1 as not_null_col" )
196+ . unwrap ( ) ;
197+
198+ assert ! ( result. is_null( 0 , 0 ) . unwrap( ) ) ;
199+ assert ! ( !result. is_null( 1 , 0 ) . unwrap( ) ) ;
200+ }
201+
202+ #[ test]
203+ fn test_type_conversion ( ) {
204+ let conn = test_connection ( ) . unwrap ( ) ;
205+ let result = conn
206+ . query ( "SELECT 42::TINYINT, 42::SMALLINT, 42::INTEGER, 42::BIGINT" )
207+ . unwrap ( ) ;
208+
209+ assert_eq ! ( result. get:: <i64 >( 0 , 0 ) . unwrap( ) , 42 ) ; // TINYINT -> i64
210+ assert_eq ! ( result. get:: <i64 >( 1 , 0 ) . unwrap( ) , 42 ) ; // SMALLINT -> i64
211+ assert_eq ! ( result. get:: <i64 >( 2 , 0 ) . unwrap( ) , 42 ) ; // INTEGER -> i64
212+ assert_eq ! ( result. get:: <i64 >( 3 , 0 ) . unwrap( ) , 42 ) ; // BIGINT -> i64
213+ }
214+
215+ #[ test]
216+ fn test_query_and_get_row_count_update ( ) {
217+ let conn = test_connection ( ) . unwrap ( ) ;
218+ conn. query ( "CREATE TABLE test (id INTEGER, name VARCHAR)" )
219+ . unwrap ( ) ;
220+ conn. query ( "INSERT INTO test VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')" )
221+ . unwrap ( ) ;
222+
223+ let result = conn
224+ . query ( "UPDATE test SET name = 'Updated' WHERE id <= 2" )
225+ . unwrap ( ) ;
226+ assert_eq ! ( result. row_count( ) . unwrap( ) , 2 ) ;
227+ }
78228
79- unsafe { cpp:: duckdb_destroy_result ( & mut result) } ;
229+ #[ test]
230+ fn test_query_and_get_row_count_delete ( ) {
231+ let conn = test_connection ( ) . unwrap ( ) ;
232+ conn. query ( "CREATE TABLE test (id INTEGER)" ) . unwrap ( ) ;
233+ conn. query ( "INSERT INTO test VALUES (1), (2), (3)" ) . unwrap ( ) ;
80234
81- Ok ( ( ) )
235+ let result = conn. query ( "DELETE FROM test WHERE id > 1" ) . unwrap ( ) ;
236+ assert_eq ! ( result. row_count( ) . unwrap( ) , 2 ) ;
82237 }
83238}
0 commit comments