Skip to content

Commit cd68a1d

Browse files
authored
fix(sql): pagination remove extra conversions (#1702)
* fix(sql): pagination remove extra conversions * update comment
1 parent dab41c2 commit cd68a1d

File tree

2 files changed

+21
-26
lines changed

2 files changed

+21
-26
lines changed

pandasai/query_builders/paginator.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def apply_pagination(
107107
if not pagination:
108108
return query, params
109109

110+
# Convert query from target dialect to postgres to generate standardized pagination query
111+
query = sqlglot.transpile(query, read=target_dialect, write="postgres")[0]
112+
110113
filtering_query = f"SELECT * FROM ({query}) AS filtered_data"
111114
conditions = []
112115

@@ -118,26 +121,26 @@ def apply_pagination(
118121
column_type = column["type"]
119122

120123
if column_type == "string":
121-
search_conditions.append(f"{column_name} ILIKE %s")
124+
search_conditions.append(f'"{column_name}" ILIKE %s')
122125
params.append(f"%{pagination.search}%")
123126

124127
elif column_type == "float" and DatasetPaginator.is_float(
125128
pagination.search
126129
):
127-
search_conditions.append(f"{column_name} = %s")
130+
search_conditions.append(f'"{column_name}" = %s')
128131
params.append(pagination.search)
129132

130133
elif (
131134
column_type in ["number", "integer"]
132135
and pagination.search.isnumeric()
133136
):
134-
search_conditions.append(f"{column_name} = %s")
137+
search_conditions.append(f'"{column_name}" = %s')
135138
params.append(pagination.search)
136139

137140
elif column_type == "datetime" and DatasetPaginator.is_valid_datetime(
138141
pagination.search
139142
):
140-
search_conditions.append(f"{column_name} = %s")
143+
search_conditions.append(f'"{column_name}" = %s')
141144
params.append(
142145
datetime.datetime.strptime(
143146
pagination.search, "%Y-%m-%d %H:%M:%S"
@@ -147,13 +150,13 @@ def apply_pagination(
147150
elif column_type == "boolean" and DatasetPaginator.is_valid_boolean(
148151
pagination.search
149152
):
150-
search_conditions.append(f"{column_name} = %s")
153+
search_conditions.append(f'"{column_name}" = %s')
151154
params.append(pagination.search)
152155

153156
elif column_type == "uuid" and DatasetPaginator.is_valid_uuid(
154157
pagination.search
155158
):
156-
search_conditions.append(f"{column_name}::TEXT = %s")
159+
search_conditions.append(f'"{column_name}"::TEXT = %s')
157160
params.append(pagination.search)
158161

159162
if search_conditions:
@@ -171,7 +174,7 @@ def apply_pagination(
171174
if not isinstance(values, list):
172175
values = [values]
173176
placeholders = ", ".join(["%s"] * len(values))
174-
conditions.append(f"{column} IN ({placeholders})")
177+
conditions.append(f'"{column}" IN ({placeholders})')
175178
params.extend(values)
176179
except json.JSONDecodeError as e:
177180
raise ValueError(f"Invalid filters format: {e}")
@@ -188,7 +191,7 @@ def apply_pagination(
188191
)
189192

190193
filtering_query += (
191-
f" ORDER BY {pagination.sort_by} {pagination.sort_order.upper()}"
194+
f' ORDER BY "{pagination.sort_by}" {pagination.sort_order.upper()}'
192195
)
193196

194197
# Handle page and page_size
@@ -198,12 +201,4 @@ def apply_pagination(
198201
[pagination.page_size, (pagination.page - 1) * pagination.page_size]
199202
)
200203

201-
# Replace placeholders for target dialect
202-
placeholder = "___PLACEHOLDER___"
203-
temp_query = filtering_query.replace("%s", placeholder)
204-
transpiled_query = sqlglot.transpile(
205-
temp_query, read="postgres", write=target_dialect
206-
)[0]
207-
final_query = transpiled_query.replace(placeholder, "%s")
208-
209-
return final_query, params
204+
return filtering_query, params

tests/unit_tests/query_builders/test_paginator.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_search_string_column(self, sample_query, sample_columns):
8282
query, parameters = DatasetPaginator.apply_pagination(
8383
sample_query, sample_columns, params
8484
)
85-
assert "name ILIKE %s" in query
85+
assert '"name" ILIKE %s' in query
8686
assert parameters[0] == "%John%" # First parameter is search term
8787
assert len(parameters) == 3 # search + LIMIT/OFFSET
8888

@@ -92,8 +92,8 @@ def test_search_numeric_columns(self, sample_query, sample_columns):
9292
query, parameters = DatasetPaginator.apply_pagination(
9393
sample_query, sample_columns, params
9494
)
95-
assert "id = %s" in query
96-
assert "age = %s" in query
95+
assert '"id" = %s' in query
96+
assert '"age" = %s' in query
9797
assert parameters.count("25") >= 2 # At least id and age columns
9898
assert len(parameters) > 2 # search params + LIMIT/OFFSET
9999

@@ -103,7 +103,7 @@ def test_search_datetime(self, sample_query, sample_columns):
103103
query, parameters = DatasetPaginator.apply_pagination(
104104
sample_query, sample_columns, params
105105
)
106-
assert "created_at = %s" in query
106+
assert '"created_at" = %s' in query
107107
# Convert the datetime string to expected format
108108
expected_dt = datetime.datetime.strptime(
109109
"2023-01-01 12:00:00", "%Y-%m-%d %H:%M:%S"
@@ -120,7 +120,7 @@ def test_filters(self, sample_query, sample_columns):
120120
query, parameters = DatasetPaginator.apply_pagination(
121121
sample_query, sample_columns, params
122122
)
123-
assert "age IN (%s, %s, %s)" in query
123+
assert '"age" IN (%s, %s, %s)' in query
124124
assert all(
125125
x in parameters for x in [25, 30, 35]
126126
) # Filter values are in parameters
@@ -134,7 +134,7 @@ def test_sorting(self, sample_query, sample_columns):
134134
query, parameters = DatasetPaginator.apply_pagination(
135135
sample_query, sample_columns, params
136136
)
137-
assert "ORDER BY age DESC" in query
137+
assert 'ORDER BY "age" DESC' in query
138138

139139
def test_invalid_sort_column(self, sample_query, sample_columns):
140140
"""Test error on invalid sort column"""
@@ -183,7 +183,7 @@ def test_boolean_search(self, sample_query, sample_columns):
183183
query, parameters = DatasetPaginator.apply_pagination(
184184
sample_query, sample_columns, params
185185
)
186-
assert "is_active = %s" in query
186+
assert '"is_active" = %s' in query
187187
assert "true" in [str(p).lower() for p in parameters]
188188

189189
def test_uuid_search(self, sample_query, sample_columns):
@@ -193,7 +193,7 @@ def test_uuid_search(self, sample_query, sample_columns):
193193
query, parameters = DatasetPaginator.apply_pagination(
194194
sample_query, sample_columns, params
195195
)
196-
assert "CAST(user_id AS TEXT) = %s" in query
196+
assert '"user_id"::TEXT = %s' in query
197197
assert uuid_value in parameters
198198

199199
def test_filter_single_value(self, sample_query, sample_columns):
@@ -206,7 +206,7 @@ def test_filter_single_value(self, sample_query, sample_columns):
206206
query, parameters = DatasetPaginator.apply_pagination(
207207
sample_query, sample_columns, params
208208
)
209-
assert "age IN (%s)" in query
209+
assert '"age" IN (%s)' in query
210210
assert 25 in parameters
211211

212212
def test_invalid_json_filter(self, sample_query, sample_columns):

0 commit comments

Comments
 (0)