Skip to content

Commit 9cb0aac

Browse files
committed
feat: implement whitelist
1 parent 6d744ac commit 9cb0aac

File tree

4 files changed

+66
-1
lines changed

4 files changed

+66
-1
lines changed

internal/common/utils.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package common
33
import (
44
"fmt"
55
"math/big"
6+
"regexp"
67
"strings"
78
"unicode"
89
)
@@ -169,3 +170,51 @@ func isType(word string) bool {
169170

170171
return types[word]
171172
}
173+
174+
var allowedFunctions = map[string]struct{}{
175+
"sum": {},
176+
"count": {},
177+
"reinterpretAsUInt256": {},
178+
"reverse": {},
179+
"unhex": {},
180+
"substring": {},
181+
"length": {},
182+
"toUInt256": {},
183+
"if": {},
184+
}
185+
186+
var disallowedPatterns = []string{
187+
`(?i)\b(UNION|INSERT|DELETE|UPDATE|DROP|CREATE|ALTER|TRUNCATE|EXEC|;|--)`,
188+
}
189+
190+
// validateQuery checks the query for disallowed patterns and ensures only allowed functions are used.
191+
func ValidateQuery(query string) error {
192+
// Check for disallowed patterns
193+
for _, pattern := range disallowedPatterns {
194+
matched, err := regexp.MatchString(pattern, query)
195+
if err != nil {
196+
return fmt.Errorf("error checking disallowed patterns: %v", err)
197+
}
198+
if matched {
199+
return fmt.Errorf("query contains disallowed keywords or patterns")
200+
}
201+
}
202+
203+
// Ensure the query is a SELECT statement
204+
trimmedQuery := strings.TrimSpace(strings.ToUpper(query))
205+
if !strings.HasPrefix(trimmedQuery, "SELECT") {
206+
return fmt.Errorf("only SELECT queries are allowed")
207+
}
208+
209+
// Extract function names and validate them
210+
functionPattern := regexp.MustCompile(`(?i)(\b\w+\b)\s*\(`)
211+
matches := functionPattern.FindAllStringSubmatch(query, -1)
212+
for _, match := range matches {
213+
funcName := match[1]
214+
if _, ok := allowedFunctions[funcName]; !ok {
215+
return fmt.Errorf("function '%s' is not allowed", funcName)
216+
}
217+
}
218+
219+
return nil
220+
}

internal/handlers/logs_handlers.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
170170
aggregatesResult, err := mainStorage.GetAggregations("logs", qf)
171171
if err != nil {
172172
log.Error().Err(err).Msg("Error querying aggregates")
173+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
173174
api.InternalErrorHandler(c)
174175
return
175176
}
@@ -180,6 +181,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
180181
logsResult, err := mainStorage.GetLogs(qf)
181182
if err != nil {
182183
log.Error().Err(err).Msg("Error querying logs")
184+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
183185
api.InternalErrorHandler(c)
184186
return
185187
}

internal/handlers/transactions_handlers.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
172172
aggregatesResult, err := mainStorage.GetAggregations("transactions", qf)
173173
if err != nil {
174174
log.Error().Err(err).Msg("Error querying aggregates")
175+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
175176
api.InternalErrorHandler(c)
176177
return
177178
}
@@ -181,7 +182,8 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
181182
// Retrieve logs data
182183
transactionsResult, err := mainStorage.GetTransactions(qf)
183184
if err != nil {
184-
log.Error().Err(err).Msg("Error querying tran")
185+
log.Error().Err(err).Msg("Error querying transactions")
186+
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
185187
api.InternalErrorHandler(c)
186188
return
187189
}

internal/storage/clickhouse.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ func (c *ClickHouseConnector) GetBlocks(qf QueryFilter) (blocks []common.Block,
301301

302302
query += getLimitClause(int(qf.Limit))
303303

304+
if err := common.ValidateQuery(query); err != nil {
305+
return nil, err
306+
}
304307
rows, err := c.conn.Query(context.Background(), query)
305308
if err != nil {
306309
return nil, err
@@ -369,6 +372,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
369372
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
370373
}
371374

375+
if err := common.ValidateQuery(query); err != nil {
376+
return QueryResult[interface{}]{}, err
377+
}
372378
// Execute the query
373379
rows, err := c.conn.Query(context.Background(), query)
374380
if err != nil {
@@ -421,6 +427,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
421427
func executeQuery[T any](c *ClickHouseConnector, table, columns string, qf QueryFilter, scanFunc func(driver.Rows) (T, error)) (QueryResult[T], error) {
422428
query := c.buildQuery(table, columns, qf)
423429

430+
if err := common.ValidateQuery(query); err != nil {
431+
return QueryResult[T]{}, err
432+
}
424433
rows, err := c.conn.Query(context.Background(), query)
425434
if err != nil {
426435
return QueryResult[T]{}, err
@@ -856,6 +865,9 @@ func (c *ClickHouseConnector) GetTraces(qf QueryFilter) (traces []common.Trace,
856865

857866
query += getLimitClause(int(qf.Limit))
858867

868+
if err := common.ValidateQuery(query); err != nil {
869+
return nil, err
870+
}
859871
rows, err := c.conn.Query(context.Background(), query)
860872
if err != nil {
861873
return nil, err

0 commit comments

Comments
 (0)