Skip to content

Commit 190313b

Browse files
committed
feat: dynamic resolving of aggrergates results types
1 parent 6a03b6a commit 190313b

File tree

1 file changed

+137
-15
lines changed

1 file changed

+137
-15
lines changed

internal/storage/clickhouse.go

Lines changed: 137 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"fmt"
99
"math/big"
10+
"reflect"
1011
"strings"
1112
"sync"
1213
"time"
@@ -349,24 +350,20 @@ func (c *ClickHouseConnector) GetLogs(qf QueryFilter) (QueryResult[common.Log],
349350

350351
func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (QueryResult[interface{}], error) {
351352
// Build the SELECT clause with aggregates
352-
columns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
353-
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", columns, c.cfg.Database, table)
353+
selectColumns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
354+
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", selectColumns, c.cfg.Database, table)
354355

355356
// Apply filters
356357
if qf.ChainId != nil && qf.ChainId.Sign() > 0 {
357358
query = addFilterParams("chain_id", qf.ChainId.String(), query)
358359
}
359360
query = addContractAddress(table, query, qf.ContractAddress)
360-
361361
if qf.Signature != "" {
362362
query += fmt.Sprintf(" AND topic_0 = '%s'", qf.Signature)
363363
}
364-
365364
for key, value := range qf.FilterParams {
366365
query = addFilterParams(key, strings.ToLower(value), query)
367366
}
368-
369-
// Add GROUP BY clause if specified
370367
if len(qf.GroupBy) > 0 {
371368
groupByColumns := strings.Join(qf.GroupBy, ", ")
372369
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
@@ -379,28 +376,52 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
379376
}
380377
defer rows.Close()
381378

379+
columnNames := rows.Columns()
380+
columnTypes := rows.ColumnTypes()
381+
382382
// Collect results
383383
var aggregates []map[string]interface{}
384384
for rows.Next() {
385-
columns := rows.Columns()
386-
values := make([]interface{}, len(columns))
387-
valuePtrs := make([]interface{}, len(columns))
388-
for i := range columns {
389-
valuePtrs[i] = &values[i]
385+
values := make([]interface{}, len(columnNames))
386+
387+
// Assign Go types based on ClickHouse types
388+
for i, colType := range columnTypes {
389+
dbType := colType.DatabaseTypeName()
390+
values[i] = mapClickHouseTypeToGoType(dbType)
390391
}
391392

392-
if err := rows.Scan(valuePtrs...); err != nil {
393-
return QueryResult[interface{}]{}, err
393+
if err := rows.Scan(values...); err != nil {
394+
return QueryResult[interface{}]{}, fmt.Errorf("failed to scan row: %w", err)
394395
}
395396

397+
// Prepare the result map for the current row
396398
result := make(map[string]interface{})
397-
for i, col := range columns {
398-
result[col] = values[i]
399+
for i, colName := range columnNames {
400+
valuePtr := values[i]
401+
value := reflect.ValueOf(valuePtr).Elem()
402+
403+
if !value.IsValid() || (value.Kind() == reflect.Ptr && value.IsNil()) {
404+
// Handle nil pointer
405+
result[colName] = nil
406+
continue
407+
}
408+
409+
// Dereference pointers to get the actual value
410+
if value.Kind() == reflect.Ptr {
411+
actualValue := value.Elem().Interface()
412+
result[colName] = actualValue
413+
} else {
414+
result[colName] = value.Interface()
415+
}
399416
}
400417

401418
aggregates = append(aggregates, result)
402419
}
403420

421+
if err := rows.Err(); err != nil {
422+
return QueryResult[interface{}]{}, fmt.Errorf("row iteration error: %w", err)
423+
}
424+
404425
return QueryResult[interface{}]{Data: nil, Aggregates: aggregates}, nil
405426
}
406427

@@ -1056,3 +1077,104 @@ func (c *ClickHouseConnector) InsertBlockData(data *[]common.BlockData) error {
10561077
}
10571078
return nil
10581079
}
1080+
1081+
func mapClickHouseTypeToGoType(dbType string) interface{} {
1082+
// Handle Nullable types
1083+
isNullable := false
1084+
if strings.HasPrefix(dbType, "Nullable(") {
1085+
isNullable = true
1086+
dbType = dbType[len("Nullable(") : len(dbType)-1]
1087+
}
1088+
1089+
// Handle LowCardinality types
1090+
if strings.HasPrefix(dbType, "LowCardinality(") {
1091+
dbType = dbType[len("LowCardinality(") : len(dbType)-1]
1092+
}
1093+
1094+
// Handle Array types
1095+
if strings.HasPrefix(dbType, "Array(") {
1096+
elementType := dbType[len("Array(") : len(dbType)-1]
1097+
// For arrays, we'll use slices of pointers to the element type
1098+
switch elementType {
1099+
case "String", "FixedString":
1100+
return new([]*string)
1101+
case "Int8", "Int16", "Int32", "Int64":
1102+
return new([]*int64)
1103+
case "UInt8", "UInt16", "UInt32", "UInt64":
1104+
return new([]*uint64)
1105+
case "Float32", "Float64":
1106+
return new([]*float64)
1107+
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
1108+
return new([]*big.Float)
1109+
// Add more cases as needed
1110+
default:
1111+
return new([]interface{})
1112+
}
1113+
}
1114+
1115+
// Handle parameterized types by extracting the base type
1116+
baseType := dbType
1117+
if idx := strings.Index(dbType, "("); idx != -1 {
1118+
baseType = dbType[:idx]
1119+
}
1120+
1121+
// Map basic data types
1122+
switch baseType {
1123+
// Signed integers
1124+
case "Int8", "Int16", "Int32", "Int64":
1125+
if isNullable {
1126+
return new(*int64)
1127+
}
1128+
return new(int64)
1129+
// Unsigned integers
1130+
case "UInt8", "UInt16", "UInt32", "UInt64":
1131+
if isNullable {
1132+
return new(*uint64)
1133+
}
1134+
return new(uint64)
1135+
// Floating-point numbers
1136+
case "Float32":
1137+
if isNullable {
1138+
return new(*float32)
1139+
}
1140+
return new(float32)
1141+
case "Float64":
1142+
if isNullable {
1143+
return new(*float64)
1144+
}
1145+
return new(float64)
1146+
// Decimal types
1147+
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
1148+
if isNullable {
1149+
return new(*big.Float)
1150+
}
1151+
return new(big.Float)
1152+
// String types
1153+
case "String", "FixedString", "UUID", "IPv4", "IPv6":
1154+
if isNullable {
1155+
return new(*string)
1156+
}
1157+
return new(string)
1158+
// Enums
1159+
case "Enum8", "Enum16":
1160+
if isNullable {
1161+
return new(*string)
1162+
}
1163+
return new(string)
1164+
// Date and time types
1165+
case "Date", "Date32", "DateTime", "DateTime64":
1166+
if isNullable {
1167+
return new(*time.Time)
1168+
}
1169+
return new(time.Time)
1170+
// Big integers
1171+
case "Int128", "UInt128", "Int256", "UInt256":
1172+
if isNullable {
1173+
return new(*big.Int)
1174+
}
1175+
return new(big.Int)
1176+
default:
1177+
// For unknown types, use interface{}
1178+
return new(interface{})
1179+
}
1180+
}

0 commit comments

Comments
 (0)