Skip to content

Commit 7a6b3bf

Browse files
committed
fix: sql query binding building
1 parent 4ed7b8c commit 7a6b3bf

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

storage/sql.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ func (s *SQLAdapter) Get(dest any, filter map[string]any, params ...map[string]a
176176
if len(filter) == 0 {
177177
return errors.New("filtering is required when getting a resource")
178178
}
179-
result := s.DB.Where(s.buildQuery(filter), filter).Find(dest)
179+
query, bindings := s.buildQuery(filter)
180+
result := s.DB.Where(query, bindings).Find(dest)
180181
if result.RowsAffected == 0 {
181182
return ErrNotFound
182183
}
@@ -187,15 +188,17 @@ func (s *SQLAdapter) Update(item any, filter map[string]any, params ...map[strin
187188
if len(filter) == 0 {
188189
return errors.New("filtering is required when updating a resource")
189190
}
190-
result := s.DB.Where(s.buildQuery(filter), filter).Save(item)
191+
query, bindings := s.buildQuery(filter)
192+
result := s.DB.Where(query, bindings).Save(item)
191193
return result.Error
192194
}
193195

194196
func (s *SQLAdapter) Delete(item any, filter map[string]any, params ...map[string]any) error {
195197
if len(filter) == 0 {
196198
return errors.New("filtering is required when deleting a resource")
197199
}
198-
result := s.DB.Where(s.buildQuery(filter), filter).Delete(item)
200+
query, bindings := s.buildQuery(filter)
201+
result := s.DB.Where(query, bindings).Delete(item)
199202
return result.Error
200203
}
201204

@@ -250,7 +253,8 @@ func (s *SQLAdapter) executePaginatedQuery(
250253
func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) {
251254
return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB {
252255
if len(filter) > 0 {
253-
return q.Where(s.buildQuery(filter), filter)
256+
query, bindings := s.buildQuery(filter)
257+
return q.Where(query, bindings)
254258
}
255259
return q
256260
})
@@ -292,7 +296,8 @@ func (s *SQLAdapter) Count(dest any, filter map[string]any, params ...map[string
292296
q := s.DB.Model(dest)
293297

294298
if len(filter) > 0 {
295-
q = q.Where(s.buildQuery(filter), filter)
299+
query, bindings := s.buildQuery(filter)
300+
q = q.Where(query, bindings)
296301
}
297302

298303
var total int64
@@ -307,10 +312,19 @@ func (s *SQLAdapter) Query(dest any, statement string, limit int, cursor string,
307312
return "", fmt.Errorf("not implemented yet")
308313
}
309314

310-
func (s *SQLAdapter) buildQuery(filter map[string]any) string {
315+
func (s *SQLAdapter) buildQuery(filter map[string]any) (string, map[string]any) {
311316
clauses := []string{}
312-
for key := range filter {
313-
clauses = append(clauses, fmt.Sprintf("%s = @%s", key, key))
317+
bindings := make(map[string]any)
318+
319+
for key, value := range filter {
320+
if value == nil {
321+
// For nil values, use IS NULL instead of = @key
322+
clauses = append(clauses, fmt.Sprintf("%s IS NULL", key))
323+
} else {
324+
// For non-nil values, use = @key and include in bindings
325+
clauses = append(clauses, fmt.Sprintf("%s = @%s", key, key))
326+
bindings[key] = value
327+
}
314328
}
315-
return strings.Join(clauses, " AND ")
329+
return strings.Join(clauses, " AND "), bindings
316330
}

0 commit comments

Comments
 (0)