|
| 1 | +package api |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "regexp" |
| 6 | + "strings" |
| 7 | +) |
| 8 | + |
| 9 | +// EntityColumns defines the valid columns for each entity type |
| 10 | +var EntityColumns = map[string][]string{ |
| 11 | + "blocks": { |
| 12 | + "chain_id", "block_number", "block_timestamp", "hash", "parent_hash", "sha3_uncles", |
| 13 | + "nonce", "mix_hash", "miner", "state_root", "transactions_root", "receipts_root", |
| 14 | + "logs_bloom", "size", "extra_data", "difficulty", "total_difficulty", "transaction_count", |
| 15 | + "gas_limit", "gas_used", "withdrawals_root", "base_fee_per_gas", "insert_timestamp", "sign", |
| 16 | + }, |
| 17 | + "transactions": { |
| 18 | + "chain_id", "hash", "nonce", "block_hash", "block_number", "block_timestamp", |
| 19 | + "transaction_index", "from_address", "to_address", "value", "gas", "gas_price", |
| 20 | + "data", "function_selector", "max_fee_per_gas", "max_priority_fee_per_gas", |
| 21 | + "max_fee_per_blob_gas", "blob_versioned_hashes", "transaction_type", "r", "s", "v", |
| 22 | + "access_list", "authorization_list", "contract_address", "gas_used", "cumulative_gas_used", |
| 23 | + "effective_gas_price", "blob_gas_used", "blob_gas_price", "logs_bloom", "status", |
| 24 | + "insert_timestamp", "sign", |
| 25 | + }, |
| 26 | + "logs": { |
| 27 | + "chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash", |
| 28 | + "transaction_index", "log_index", "address", "data", "topic_0", "topic_1", "topic_2", "topic_3", |
| 29 | + "insert_timestamp", "sign", |
| 30 | + }, |
| 31 | + "transfers": { |
| 32 | + "token_type", "chain_id", "token_address", "from_address", "to_address", "block_number", |
| 33 | + "block_timestamp", "transaction_hash", "token_id", "amount", "log_index", "insert_timestamp", "sign", |
| 34 | + }, |
| 35 | + "balances": { |
| 36 | + "token_type", "chain_id", "owner", "address", "token_id", "balance", |
| 37 | + }, |
| 38 | + "traces": { |
| 39 | + "chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash", |
| 40 | + "transaction_index", "subtraces", "trace_address", "type", "call_type", "error", |
| 41 | + "from_address", "to_address", "gas", "gas_used", "input", "output", "value", |
| 42 | + "author", "reward_type", "refund_address", "insert_timestamp", "sign", |
| 43 | + }, |
| 44 | +} |
| 45 | + |
| 46 | +// ValidateGroupByAndSortBy validates that GroupBy and SortBy fields are valid for the given entity |
| 47 | +// It checks that fields are either: |
| 48 | +// 1. Valid entity columns |
| 49 | +// 2. Valid aggregate function aliases (e.g., "count", "total_amount") |
| 50 | +func ValidateGroupByAndSortBy(entity string, groupBy []string, sortBy string, aggregates []string) error { |
| 51 | + // Get valid columns for the entity |
| 52 | + validColumns, exists := EntityColumns[entity] |
| 53 | + if !exists { |
| 54 | + return fmt.Errorf("unknown entity: %s", entity) |
| 55 | + } |
| 56 | + |
| 57 | + // Create a set of valid fields (entity columns + aggregate aliases) |
| 58 | + validFields := make(map[string]bool) |
| 59 | + for _, col := range validColumns { |
| 60 | + validFields[col] = true |
| 61 | + } |
| 62 | + |
| 63 | + // Add aggregate function aliases |
| 64 | + aggregateAliases := extractAggregateAliases(aggregates) |
| 65 | + for _, alias := range aggregateAliases { |
| 66 | + validFields[alias] = true |
| 67 | + } |
| 68 | + |
| 69 | + // Validate GroupBy fields |
| 70 | + for _, field := range groupBy { |
| 71 | + if !validFields[field] { |
| 72 | + return fmt.Errorf("invalid group_by field '%s' for entity '%s'. Valid fields are: %s", |
| 73 | + field, entity, strings.Join(getValidFieldsList(validFields), ", ")) |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + // Validate SortBy field |
| 78 | + if sortBy != "" && !validFields[sortBy] { |
| 79 | + return fmt.Errorf("invalid sort_by field '%s' for entity '%s'. Valid fields are: %s", |
| 80 | + sortBy, entity, strings.Join(getValidFieldsList(validFields), ", ")) |
| 81 | + } |
| 82 | + |
| 83 | + return nil |
| 84 | +} |
| 85 | + |
| 86 | +// extractAggregateAliases extracts column aliases from aggregate functions |
| 87 | +// Examples: |
| 88 | +// - "COUNT(*) AS count" -> "count" |
| 89 | +// - "SUM(amount) AS total_amount" -> "total_amount" |
| 90 | +// - "AVG(value) as avg_value" -> "avg_value" |
| 91 | +func extractAggregateAliases(aggregates []string) []string { |
| 92 | + var aliases []string |
| 93 | + aliasRegex := regexp.MustCompile(`(?i)\s+AS\s+([a-zA-Z_][a-zA-Z0-9_]*)`) |
| 94 | + |
| 95 | + for _, aggregate := range aggregates { |
| 96 | + matches := aliasRegex.FindStringSubmatch(aggregate) |
| 97 | + if len(matches) > 1 { |
| 98 | + aliases = append(aliases, matches[1]) |
| 99 | + } |
| 100 | + } |
| 101 | + |
| 102 | + return aliases |
| 103 | +} |
| 104 | + |
| 105 | +// getValidFieldsList converts the validFields map to a sorted list for error messages |
| 106 | +func getValidFieldsList(validFields map[string]bool) []string { |
| 107 | + var fields []string |
| 108 | + for field := range validFields { |
| 109 | + fields = append(fields, field) |
| 110 | + } |
| 111 | + // Sort for consistent error messages |
| 112 | + // Note: In a production environment, you might want to use sort.Strings(fields) |
| 113 | + return fields |
| 114 | +} |
0 commit comments