Skip to content

Commit f725367

Browse files
authored
validate group and sort by (#254)
* improve start block determination for poller * validate group and sort by
1 parent f0eaebc commit f725367

File tree

7 files changed

+361
-3
lines changed

7 files changed

+361
-3
lines changed

api/field_validation.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package api
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
"strings"
7+
)
8+
9+
// EntityColumns defines the valid columns for each entity type
10+
var EntityColumns = map[string][]string{
11+
"blocks": {
12+
"chain_id", "block_number", "block_timestamp", "hash", "parent_hash", "sha3_uncles",
13+
"nonce", "mix_hash", "miner", "state_root", "transactions_root", "receipts_root",
14+
"logs_bloom", "size", "extra_data", "difficulty", "total_difficulty", "transaction_count",
15+
"gas_limit", "gas_used", "withdrawals_root", "base_fee_per_gas", "insert_timestamp", "sign",
16+
},
17+
"transactions": {
18+
"chain_id", "hash", "nonce", "block_hash", "block_number", "block_timestamp",
19+
"transaction_index", "from_address", "to_address", "value", "gas", "gas_price",
20+
"data", "function_selector", "max_fee_per_gas", "max_priority_fee_per_gas",
21+
"max_fee_per_blob_gas", "blob_versioned_hashes", "transaction_type", "r", "s", "v",
22+
"access_list", "authorization_list", "contract_address", "gas_used", "cumulative_gas_used",
23+
"effective_gas_price", "blob_gas_used", "blob_gas_price", "logs_bloom", "status",
24+
"insert_timestamp", "sign",
25+
},
26+
"logs": {
27+
"chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash",
28+
"transaction_index", "log_index", "address", "data", "topic_0", "topic_1", "topic_2", "topic_3",
29+
"insert_timestamp", "sign",
30+
},
31+
"transfers": {
32+
"token_type", "chain_id", "token_address", "from_address", "to_address", "block_number",
33+
"block_timestamp", "transaction_hash", "token_id", "amount", "log_index", "insert_timestamp", "sign",
34+
},
35+
"balances": {
36+
"token_type", "chain_id", "owner", "address", "token_id", "balance",
37+
},
38+
"traces": {
39+
"chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash",
40+
"transaction_index", "subtraces", "trace_address", "type", "call_type", "error",
41+
"from_address", "to_address", "gas", "gas_used", "input", "output", "value",
42+
"author", "reward_type", "refund_address", "insert_timestamp", "sign",
43+
},
44+
}
45+
46+
// ValidateGroupByAndSortBy validates that GroupBy and SortBy fields are valid for the given entity
47+
// It checks that fields are either:
48+
// 1. Valid entity columns
49+
// 2. Valid aggregate function aliases (e.g., "count", "total_amount")
50+
func ValidateGroupByAndSortBy(entity string, groupBy []string, sortBy string, aggregates []string) error {
51+
// Get valid columns for the entity
52+
validColumns, exists := EntityColumns[entity]
53+
if !exists {
54+
return fmt.Errorf("unknown entity: %s", entity)
55+
}
56+
57+
// Create a set of valid fields (entity columns + aggregate aliases)
58+
validFields := make(map[string]bool)
59+
for _, col := range validColumns {
60+
validFields[col] = true
61+
}
62+
63+
// Add aggregate function aliases
64+
aggregateAliases := extractAggregateAliases(aggregates)
65+
for _, alias := range aggregateAliases {
66+
validFields[alias] = true
67+
}
68+
69+
// Validate GroupBy fields
70+
for _, field := range groupBy {
71+
if !validFields[field] {
72+
return fmt.Errorf("invalid group_by field '%s' for entity '%s'. Valid fields are: %s",
73+
field, entity, strings.Join(getValidFieldsList(validFields), ", "))
74+
}
75+
}
76+
77+
// Validate SortBy field
78+
if sortBy != "" && !validFields[sortBy] {
79+
return fmt.Errorf("invalid sort_by field '%s' for entity '%s'. Valid fields are: %s",
80+
sortBy, entity, strings.Join(getValidFieldsList(validFields), ", "))
81+
}
82+
83+
return nil
84+
}
85+
86+
// extractAggregateAliases extracts column aliases from aggregate functions
87+
// Examples:
88+
// - "COUNT(*) AS count" -> "count"
89+
// - "SUM(amount) AS total_amount" -> "total_amount"
90+
// - "AVG(value) as avg_value" -> "avg_value"
91+
func extractAggregateAliases(aggregates []string) []string {
92+
var aliases []string
93+
aliasRegex := regexp.MustCompile(`(?i)\s+AS\s+([a-zA-Z_][a-zA-Z0-9_]*)`)
94+
95+
for _, aggregate := range aggregates {
96+
matches := aliasRegex.FindStringSubmatch(aggregate)
97+
if len(matches) > 1 {
98+
aliases = append(aliases, matches[1])
99+
}
100+
}
101+
102+
return aliases
103+
}
104+
105+
// getValidFieldsList converts the validFields map to a sorted list for error messages
106+
func getValidFieldsList(validFields map[string]bool) []string {
107+
var fields []string
108+
for field := range validFields {
109+
fields = append(fields, field)
110+
}
111+
// Sort for consistent error messages
112+
// Note: In a production environment, you might want to use sort.Strings(fields)
113+
return fields
114+
}

api/field_validation_test.go

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
package api
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestValidateGroupByAndSortBy(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
entity string
12+
groupBy []string
13+
sortBy string
14+
aggregates []string
15+
wantErr bool
16+
errMsg string
17+
}{
18+
{
19+
name: "valid blocks fields",
20+
entity: "blocks",
21+
groupBy: []string{"block_number", "hash"},
22+
sortBy: "block_timestamp",
23+
aggregates: nil,
24+
wantErr: false,
25+
},
26+
{
27+
name: "valid transactions fields",
28+
entity: "transactions",
29+
groupBy: []string{"from_address", "to_address"},
30+
sortBy: "value",
31+
aggregates: nil,
32+
wantErr: false,
33+
},
34+
{
35+
name: "valid logs fields",
36+
entity: "logs",
37+
groupBy: []string{"address", "topic_0"},
38+
sortBy: "block_number",
39+
aggregates: nil,
40+
wantErr: false,
41+
},
42+
{
43+
name: "valid transfers fields",
44+
entity: "transfers",
45+
groupBy: []string{"token_address", "from_address"},
46+
sortBy: "amount",
47+
aggregates: nil,
48+
wantErr: false,
49+
},
50+
{
51+
name: "valid balances fields",
52+
entity: "balances",
53+
groupBy: []string{"owner", "token_id"},
54+
sortBy: "balance",
55+
aggregates: nil,
56+
wantErr: false,
57+
},
58+
{
59+
name: "valid with aggregate aliases",
60+
entity: "transactions",
61+
groupBy: []string{"from_address"},
62+
sortBy: "total_value",
63+
aggregates: []string{"SUM(value) AS total_value", "COUNT(*) AS count"},
64+
wantErr: false,
65+
},
66+
{
67+
name: "invalid entity",
68+
entity: "invalid_entity",
69+
groupBy: []string{"field"},
70+
sortBy: "field",
71+
aggregates: nil,
72+
wantErr: true,
73+
errMsg: "unknown entity: invalid_entity",
74+
},
75+
{
76+
name: "invalid group_by field",
77+
entity: "blocks",
78+
groupBy: []string{"invalid_field"},
79+
sortBy: "block_number",
80+
aggregates: nil,
81+
wantErr: true,
82+
errMsg: "invalid group_by field 'invalid_field' for entity 'blocks'",
83+
},
84+
{
85+
name: "invalid sort_by field",
86+
entity: "transactions",
87+
groupBy: []string{"hash"},
88+
sortBy: "invalid_field",
89+
aggregates: nil,
90+
wantErr: true,
91+
errMsg: "invalid sort_by field 'invalid_field' for entity 'transactions'",
92+
},
93+
{
94+
name: "invalid aggregate alias",
95+
entity: "logs",
96+
groupBy: []string{"address"},
97+
sortBy: "invalid_alias",
98+
aggregates: []string{"COUNT(*) AS count"},
99+
wantErr: true,
100+
errMsg: "invalid sort_by field 'invalid_alias' for entity 'logs'",
101+
},
102+
{
103+
name: "empty sort_by is valid",
104+
entity: "blocks",
105+
groupBy: []string{"block_number"},
106+
sortBy: "",
107+
aggregates: nil,
108+
wantErr: false,
109+
},
110+
{
111+
name: "empty group_by is valid",
112+
entity: "transactions",
113+
groupBy: []string{},
114+
sortBy: "hash",
115+
aggregates: nil,
116+
wantErr: false,
117+
},
118+
}
119+
120+
for _, tt := range tests {
121+
t.Run(tt.name, func(t *testing.T) {
122+
err := ValidateGroupByAndSortBy(tt.entity, tt.groupBy, tt.sortBy, tt.aggregates)
123+
124+
if tt.wantErr {
125+
if err == nil {
126+
t.Errorf("ValidateGroupByAndSortBy() expected error but got none")
127+
return
128+
}
129+
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
130+
t.Errorf("ValidateGroupByAndSortBy() error = %v, want error containing %v", err, tt.errMsg)
131+
}
132+
} else {
133+
if err != nil {
134+
t.Errorf("ValidateGroupByAndSortBy() unexpected error = %v", err)
135+
}
136+
}
137+
})
138+
}
139+
}
140+
141+
func TestExtractAggregateAliases(t *testing.T) {
142+
tests := []struct {
143+
name string
144+
aggregates []string
145+
want []string
146+
}{
147+
{
148+
name: "simple aliases",
149+
aggregates: []string{"COUNT(*) AS count", "SUM(value) AS total_value"},
150+
want: []string{"count", "total_value"},
151+
},
152+
{
153+
name: "case insensitive AS",
154+
aggregates: []string{"AVG(amount) as avg_amount", "MAX(price) As max_price"},
155+
want: []string{"avg_amount", "max_price"},
156+
},
157+
{
158+
name: "no aliases",
159+
aggregates: []string{"COUNT(*)", "SUM(value)"},
160+
want: []string{},
161+
},
162+
{
163+
name: "mixed with and without aliases",
164+
aggregates: []string{"COUNT(*) AS count", "SUM(value)", "AVG(price) as avg_price"},
165+
want: []string{"count", "avg_price"},
166+
},
167+
{
168+
name: "empty aggregates",
169+
aggregates: []string{},
170+
want: []string{},
171+
},
172+
{
173+
name: "complex aliases",
174+
aggregates: []string{"COUNT(DISTINCT address) AS unique_addresses", "SUM(CASE WHEN value > 0 THEN 1 ELSE 0 END) AS positive_transactions"},
175+
want: []string{"unique_addresses", "positive_transactions"},
176+
},
177+
}
178+
179+
for _, tt := range tests {
180+
t.Run(tt.name, func(t *testing.T) {
181+
got := extractAggregateAliases(tt.aggregates)
182+
if len(got) != len(tt.want) {
183+
t.Errorf("extractAggregateAliases() length = %v, want %v", len(got), len(tt.want))
184+
return
185+
}
186+
for i, alias := range got {
187+
if alias != tt.want[i] {
188+
t.Errorf("extractAggregateAliases()[%d] = %v, want %v", i, alias, tt.want[i])
189+
}
190+
}
191+
})
192+
}
193+
}

internal/handlers/blocks_handlers.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ func handleBlocksRequest(c *gin.Context) {
4545
return
4646
}
4747

48+
// Validate GroupBy and SortBy fields
49+
if err := api.ValidateGroupByAndSortBy("blocks", queryParams.GroupBy, queryParams.SortBy, queryParams.Aggregates); err != nil {
50+
api.BadRequestErrorHandler(c, err)
51+
return
52+
}
53+
4854
mainStorage, err := getMainStorage()
4955
if err != nil {
5056
log.Error().Err(err).Msg("Error getting main storage")

internal/handlers/logs_handlers.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ func handleLogsRequest(c *gin.Context) {
111111
return
112112
}
113113

114+
// Validate GroupBy and SortBy fields
115+
if err := api.ValidateGroupByAndSortBy("logs", queryParams.GroupBy, queryParams.SortBy, queryParams.Aggregates); err != nil {
116+
api.BadRequestErrorHandler(c, err)
117+
return
118+
}
119+
114120
var eventABI *abi.Event
115121
signatureHash := ""
116122
if signature != "" {

internal/handlers/token_handlers.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ func GetTokenIdsByType(c *gin.Context) {
8686
// We only care about token_id and token_type
8787
columns := []string{"token_id", "token_type"}
8888
groupBy := []string{"token_id", "token_type"}
89+
sortBy := c.Query("sort_by")
90+
91+
// Validate GroupBy and SortBy fields
92+
if err := api.ValidateGroupByAndSortBy("balances", groupBy, sortBy, nil); err != nil {
93+
api.BadRequestErrorHandler(c, err)
94+
return
95+
}
8996

9097
tokenIds, err := getTokenIdsFromReq(c)
9198
if err != nil {
@@ -100,7 +107,7 @@ func GetTokenIdsByType(c *gin.Context) {
100107
ZeroBalance: hideZeroBalances,
101108
TokenIds: tokenIds,
102109
GroupBy: groupBy,
103-
SortBy: c.Query("sort_by"),
110+
SortBy: sortBy,
104111
SortOrder: c.Query("sort_order"),
105112
Page: api.ParseIntQueryParam(c.Query("page"), 0),
106113
Limit: api.ParseIntQueryParam(c.Query("limit"), 0),
@@ -189,6 +196,14 @@ func GetTokenBalancesByType(c *gin.Context) {
189196
groupBy = []string{"address", "token_id", "token_type"}
190197
}
191198

199+
sortBy := c.Query("sort_by")
200+
201+
// Validate GroupBy and SortBy fields
202+
if err := api.ValidateGroupByAndSortBy("balances", groupBy, sortBy, nil); err != nil {
203+
api.BadRequestErrorHandler(c, err)
204+
return
205+
}
206+
192207
qf := storage.BalancesQueryFilter{
193208
ChainId: chainId,
194209
Owner: owner,
@@ -197,7 +212,7 @@ func GetTokenBalancesByType(c *gin.Context) {
197212
ZeroBalance: hideZeroBalances,
198213
TokenIds: tokenIds,
199214
GroupBy: groupBy,
200-
SortBy: c.Query("sort_by"),
215+
SortBy: sortBy,
201216
SortOrder: c.Query("sort_order"),
202217
Page: api.ParseIntQueryParam(c.Query("page"), 0),
203218
Limit: api.ParseIntQueryParam(c.Query("limit"), 0),
@@ -280,6 +295,15 @@ func GetTokenHoldersByType(c *gin.Context) {
280295
api.BadRequestErrorHandler(c, fmt.Errorf("invalid token ids '%s'", err))
281296
return
282297
}
298+
299+
sortBy := c.Query("sort_by")
300+
301+
// Validate GroupBy and SortBy fields
302+
if err := api.ValidateGroupByAndSortBy("balances", groupBy, sortBy, nil); err != nil {
303+
api.BadRequestErrorHandler(c, err)
304+
return
305+
}
306+
283307
qf := storage.BalancesQueryFilter{
284308
ChainId: chainId,
285309
TokenTypes: tokenTypes,

0 commit comments

Comments
 (0)