Skip to content

Commit 3a6a766

Browse files
authored
fix(sqlite): implement PARSE_COLNAMES column name parsing (RustPython#5923)
1 parent e6fdef4 commit 3a6a766

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

Lib/test/test_sqlite3/test_types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,6 @@ def test_none(self):
343343
val = self.cur.fetchone()[0]
344344
self.assertEqual(val, None)
345345

346-
# TODO: RUSTPYTHON
347-
@unittest.expectedFailure
348346
def test_col_name(self):
349347
self.cur.execute("insert into test(x) values (?)", ("xxx",))
350348
self.cur.execute('select x as "x y [bar]" from test')

stdlib/src/sqlite.rs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,7 +1511,7 @@ mod _sqlite {
15111511

15121512
inner.row_cast_map = zelf.build_row_cast_map(&st, vm)?;
15131513

1514-
inner.description = st.columns_description(vm)?;
1514+
inner.description = st.columns_description(zelf.connection.detect_types, vm)?;
15151515

15161516
if ret == SQLITE_ROW {
15171517
drop(st);
@@ -1559,7 +1559,7 @@ mod _sqlite {
15591559
));
15601560
}
15611561

1562-
inner.description = st.columns_description(vm)?;
1562+
inner.description = st.columns_description(zelf.connection.detect_types, vm)?;
15631563

15641564
inner.rowcount = if stmt.is_dml { 0 } else { -1 };
15651565

@@ -2744,22 +2744,46 @@ mod _sqlite {
27442744
unsafe { sqlite3_column_name(self.st, pos) }
27452745
}
27462746

2747-
fn columns_name(self, vm: &VirtualMachine) -> PyResult<Vec<PyStrRef>> {
2747+
fn columns_name(self, detect_types: i32, vm: &VirtualMachine) -> PyResult<Vec<PyStrRef>> {
27482748
let count = self.column_count();
27492749
(0..count)
27502750
.map(|i| {
27512751
let name = self.column_name(i);
2752-
ptr_to_str(name, vm).map(|x| vm.ctx.new_str(x))
2752+
let name_str = ptr_to_str(name, vm)?;
2753+
2754+
// If PARSE_COLNAMES is enabled, strip everything after the first '[' (and preceding space)
2755+
let processed_name = if detect_types & PARSE_COLNAMES != 0
2756+
&& let Some(bracket_pos) = name_str.find('[')
2757+
{
2758+
// Check if there's a single space before '[' and remove it (CPython compatibility)
2759+
let end_pos = if bracket_pos > 0
2760+
&& name_str.chars().nth(bracket_pos - 1) == Some(' ')
2761+
{
2762+
bracket_pos - 1
2763+
} else {
2764+
bracket_pos
2765+
};
2766+
2767+
&name_str[..end_pos]
2768+
} else {
2769+
name_str
2770+
};
2771+
2772+
Ok(vm.ctx.new_str(processed_name))
27532773
})
27542774
.collect()
27552775
}
27562776

2757-
fn columns_description(self, vm: &VirtualMachine) -> PyResult<Option<PyTupleRef>> {
2777+
fn columns_description(
2778+
self,
2779+
detect_types: i32,
2780+
vm: &VirtualMachine,
2781+
) -> PyResult<Option<PyTupleRef>> {
27582782
if self.column_count() == 0 {
27592783
return Ok(None);
27602784
}
27612785
let columns = self
2762-
.columns_name(vm)?
2786+
.columns_name(detect_types, vm)?
27632787
.into_iter()
27642788
.map(|s| {
27652789
vm.ctx

0 commit comments

Comments
 (0)