Skip to content

Commit 2922b7c

Browse files
committed
feat: dynamic resolving of aggrergates results types
1 parent 6a03b6a commit 2922b7c

File tree

2 files changed

+301
-15
lines changed

2 files changed

+301
-15
lines changed

internal/storage/clickhouse.go

Lines changed: 167 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"fmt"
99
"math/big"
10+
"reflect"
1011
"strings"
1112
"sync"
1213
"time"
@@ -349,24 +350,20 @@ func (c *ClickHouseConnector) GetLogs(qf QueryFilter) (QueryResult[common.Log],
349350

350351
func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (QueryResult[interface{}], error) {
351352
// Build the SELECT clause with aggregates
352-
columns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
353-
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", columns, c.cfg.Database, table)
353+
selectColumns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
354+
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", selectColumns, c.cfg.Database, table)
354355

355356
// Apply filters
356357
if qf.ChainId != nil && qf.ChainId.Sign() > 0 {
357358
query = addFilterParams("chain_id", qf.ChainId.String(), query)
358359
}
359360
query = addContractAddress(table, query, qf.ContractAddress)
360-
361361
if qf.Signature != "" {
362362
query += fmt.Sprintf(" AND topic_0 = '%s'", qf.Signature)
363363
}
364-
365364
for key, value := range qf.FilterParams {
366365
query = addFilterParams(key, strings.ToLower(value), query)
367366
}
368-
369-
// Add GROUP BY clause if specified
370367
if len(qf.GroupBy) > 0 {
371368
groupByColumns := strings.Join(qf.GroupBy, ", ")
372369
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
@@ -379,28 +376,52 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
379376
}
380377
defer rows.Close()
381378

379+
columnNames := rows.Columns()
380+
columnTypes := rows.ColumnTypes()
381+
382382
// Collect results
383383
var aggregates []map[string]interface{}
384384
for rows.Next() {
385-
columns := rows.Columns()
386-
values := make([]interface{}, len(columns))
387-
valuePtrs := make([]interface{}, len(columns))
388-
for i := range columns {
389-
valuePtrs[i] = &values[i]
385+
values := make([]interface{}, len(columnNames))
386+
387+
// Assign Go types based on ClickHouse types
388+
for i, colType := range columnTypes {
389+
dbType := colType.DatabaseTypeName()
390+
values[i] = mapClickHouseTypeToGoType(dbType)
390391
}
391392

392-
if err := rows.Scan(valuePtrs...); err != nil {
393-
return QueryResult[interface{}]{}, err
393+
if err := rows.Scan(values...); err != nil {
394+
return QueryResult[interface{}]{}, fmt.Errorf("failed to scan row: %w", err)
394395
}
395396

397+
// Prepare the result map for the current row
396398
result := make(map[string]interface{})
397-
for i, col := range columns {
398-
result[col] = values[i]
399+
for i, colName := range columnNames {
400+
valuePtr := values[i]
401+
value := reflect.ValueOf(valuePtr).Elem()
402+
403+
if !value.IsValid() || (value.Kind() == reflect.Ptr && value.IsNil()) {
404+
// Handle nil pointer
405+
result[colName] = nil
406+
continue
407+
}
408+
409+
// Dereference pointers to get the actual value
410+
if value.Kind() == reflect.Ptr {
411+
actualValue := value.Elem().Interface()
412+
result[colName] = actualValue
413+
} else {
414+
result[colName] = value.Interface()
415+
}
399416
}
400417

401418
aggregates = append(aggregates, result)
402419
}
403420

421+
if err := rows.Err(); err != nil {
422+
return QueryResult[interface{}]{}, fmt.Errorf("row iteration error: %w", err)
423+
}
424+
404425
return QueryResult[interface{}]{Data: nil, Aggregates: aggregates}, nil
405426
}
406427

@@ -1056,3 +1077,134 @@ func (c *ClickHouseConnector) InsertBlockData(data *[]common.BlockData) error {
10561077
}
10571078
return nil
10581079
}
1080+
1081+
func mapClickHouseTypeToGoType(dbType string) interface{} {
1082+
// Handle LowCardinality types
1083+
if strings.HasPrefix(dbType, "LowCardinality(") {
1084+
dbType = dbType[len("LowCardinality(") : len(dbType)-1]
1085+
}
1086+
1087+
// Handle Nullable types
1088+
isNullable := false
1089+
if strings.HasPrefix(dbType, "Nullable(") {
1090+
isNullable = true
1091+
dbType = dbType[len("Nullable(") : len(dbType)-1]
1092+
}
1093+
1094+
// Handle Array types
1095+
if strings.HasPrefix(dbType, "Array(") {
1096+
elementType := dbType[len("Array(") : len(dbType)-1]
1097+
// For arrays, we'll use slices of pointers to the element type
1098+
switch elementType {
1099+
case "String", "FixedString":
1100+
return new([]*string)
1101+
case "Int8", "Int16", "Int32", "Int64":
1102+
return new([]*int64)
1103+
case "UInt8", "UInt16", "UInt32", "UInt64":
1104+
return new([]*uint64)
1105+
case "Float32", "Float64":
1106+
return new([]*float64)
1107+
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
1108+
return new([]*big.Float)
1109+
// Add more cases as needed
1110+
default:
1111+
return new([]interface{})
1112+
}
1113+
}
1114+
1115+
// Handle parameterized types by extracting the base type
1116+
baseType := dbType
1117+
if idx := strings.Index(dbType, "("); idx != -1 {
1118+
baseType = dbType[:idx]
1119+
}
1120+
1121+
// Map basic data types
1122+
switch baseType {
1123+
// Signed integers
1124+
case "Int8":
1125+
if isNullable {
1126+
return new(*int8)
1127+
}
1128+
return new(int8)
1129+
case "Int16":
1130+
if isNullable {
1131+
return new(*int16)
1132+
}
1133+
return new(int16)
1134+
case "Int32":
1135+
if isNullable {
1136+
return new(*int32)
1137+
}
1138+
return new(int32)
1139+
case "Int64":
1140+
if isNullable {
1141+
return new(*int64)
1142+
}
1143+
return new(int64)
1144+
// Unsigned integers
1145+
case "UInt8":
1146+
if isNullable {
1147+
return new(*uint8)
1148+
}
1149+
return new(uint8)
1150+
case "UInt16":
1151+
if isNullable {
1152+
return new(*uint16)
1153+
}
1154+
return new(uint16)
1155+
case "UInt32":
1156+
if isNullable {
1157+
return new(*uint32)
1158+
}
1159+
return new(uint32)
1160+
case "UInt64":
1161+
if isNullable {
1162+
return new(*uint64)
1163+
}
1164+
return new(uint64)
1165+
// Floating-point numbers
1166+
case "Float32":
1167+
if isNullable {
1168+
return new(*float32)
1169+
}
1170+
return new(float32)
1171+
case "Float64":
1172+
if isNullable {
1173+
return new(*float64)
1174+
}
1175+
return new(float64)
1176+
// Decimal types
1177+
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
1178+
if isNullable {
1179+
return new(*big.Float)
1180+
}
1181+
return new(big.Float)
1182+
// String types
1183+
case "String", "FixedString", "UUID", "IPv4", "IPv6":
1184+
if isNullable {
1185+
return new(*string)
1186+
}
1187+
return new(string)
1188+
// Enums
1189+
case "Enum8", "Enum16":
1190+
if isNullable {
1191+
return new(*string)
1192+
}
1193+
return new(string)
1194+
// Date and time types
1195+
case "Date", "Date32", "DateTime", "DateTime64":
1196+
if isNullable {
1197+
return new(*time.Time)
1198+
}
1199+
return new(time.Time)
1200+
// Big integers
1201+
case "Int128", "UInt128", "Int256", "UInt256":
1202+
if isNullable {
1203+
return new(*big.Int)
1204+
}
1205+
return new(big.Int)
1206+
default:
1207+
// For unknown types, use interface{}
1208+
return new(interface{})
1209+
}
1210+
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package storage
2+
3+
import (
4+
"math/big"
5+
"reflect"
6+
"testing"
7+
"time"
8+
)
9+
10+
// TestMapClickHouseTypeToGoType tests the mapClickHouseTypeToGoType function
11+
func TestMapClickHouseTypeToGoType(t *testing.T) {
12+
testCases := []struct {
13+
dbType string
14+
expectedType interface{}
15+
}{
16+
// Signed integers
17+
{"Int8", int8(0)},
18+
{"Nullable(Int8)", (**int8)(nil)},
19+
{"Int16", int16(0)},
20+
{"Nullable(Int16)", (**int16)(nil)},
21+
{"Int32", int32(0)},
22+
{"Nullable(Int32)", (**int32)(nil)},
23+
{"Int64", int64(0)},
24+
{"Nullable(Int64)", (**int64)(nil)},
25+
// Unsigned integers
26+
{"UInt8", uint8(0)},
27+
{"Nullable(UInt8)", (**uint8)(nil)},
28+
{"UInt16", uint16(0)},
29+
{"Nullable(UInt16)", (**uint16)(nil)},
30+
{"UInt32", uint32(0)},
31+
{"Nullable(UInt32)", (**uint32)(nil)},
32+
{"UInt64", uint64(0)},
33+
{"Nullable(UInt64)", (**uint64)(nil)},
34+
// Big integers
35+
{"Int128", big.NewInt(0)},
36+
{"Nullable(Int128)", (**big.Int)(nil)},
37+
{"UInt128", big.NewInt(0)},
38+
{"Nullable(UInt128)", (**big.Int)(nil)},
39+
{"Int256", big.NewInt(0)},
40+
{"Nullable(Int256)", (**big.Int)(nil)},
41+
{"UInt256", big.NewInt(0)},
42+
{"Nullable(UInt256)", (**big.Int)(nil)},
43+
// Floating-point numbers
44+
{"Float32", float32(0)},
45+
{"Nullable(Float32)", (**float32)(nil)},
46+
{"Float64", float64(0)},
47+
{"Nullable(Float64)", (**float64)(nil)},
48+
// Decimal types
49+
{"Decimal", big.NewFloat(0)},
50+
{"Nullable(Decimal)", (**big.Float)(nil)},
51+
{"Decimal32", big.NewFloat(0)},
52+
{"Nullable(Decimal32)", (**big.Float)(nil)},
53+
{"Decimal64", big.NewFloat(0)},
54+
{"Nullable(Decimal64)", (**big.Float)(nil)},
55+
{"Decimal128", big.NewFloat(0)},
56+
{"Nullable(Decimal128)", (**big.Float)(nil)},
57+
{"Decimal256", big.NewFloat(0)},
58+
{"Nullable(Decimal256)", (**big.Float)(nil)},
59+
// String types
60+
{"String", ""},
61+
{"Nullable(String)", (**string)(nil)},
62+
{"FixedString(42)", ""},
63+
{"Nullable(FixedString(42))", (**string)(nil)},
64+
{"UUID", ""},
65+
{"Nullable(UUID)", (**string)(nil)},
66+
{"IPv4", ""},
67+
{"Nullable(IPv4)", (**string)(nil)},
68+
{"IPv6", ""},
69+
{"Nullable(IPv6)", (**string)(nil)},
70+
// Date and time types
71+
{"Date", time.Time{}},
72+
{"Nullable(Date)", (**time.Time)(nil)},
73+
{"DateTime", time.Time{}},
74+
{"Nullable(DateTime)", (**time.Time)(nil)},
75+
{"DateTime64", time.Time{}},
76+
{"Nullable(DateTime64)", (**time.Time)(nil)},
77+
// Enums
78+
{"Enum8('a' = 1, 'b' = 2)", ""},
79+
{"Nullable(Enum8('a' = 1, 'b' = 2))", (**string)(nil)},
80+
{"Enum16('a' = 1, 'b' = 2)", ""},
81+
{"Nullable(Enum16('a' = 1, 'b' = 2))", (**string)(nil)},
82+
// Arrays
83+
{"Array(Int32)", &[]*int64{}},
84+
{"Array(String)", &[]*string{}},
85+
{"Array(Float64)", &[]*float64{}},
86+
// LowCardinality
87+
{"LowCardinality(String)", ""},
88+
{"LowCardinality(Nullable(String))", (**string)(nil)},
89+
// Unknown type
90+
{"UnknownType", new(interface{})},
91+
{"Nullable(UnknownType)", new(interface{})},
92+
}
93+
94+
for _, tc := range testCases {
95+
t.Run(tc.dbType, func(t *testing.T) {
96+
result := mapClickHouseTypeToGoType(tc.dbType)
97+
98+
expectedType := reflect.TypeOf(tc.expectedType)
99+
resultType := reflect.TypeOf(result)
100+
101+
// Handle pointers
102+
if expectedType.Kind() == reflect.Ptr {
103+
if resultType.Kind() != reflect.Ptr {
104+
t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind())
105+
return
106+
}
107+
expectedElemType := expectedType.Elem()
108+
resultElemType := resultType.Elem()
109+
if expectedElemType.Kind() == reflect.Ptr {
110+
// Expected pointer to pointer
111+
if resultElemType.Kind() != reflect.Ptr {
112+
t.Errorf("Expected pointer to pointer for dbType %s, got %s", tc.dbType, resultElemType.Kind())
113+
return
114+
}
115+
expectedElemType = expectedElemType.Elem()
116+
resultElemType = resultElemType.Elem()
117+
}
118+
if expectedElemType != resultElemType {
119+
t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedElemType, resultElemType)
120+
}
121+
} else {
122+
// Non-pointer types
123+
if resultType.Kind() != reflect.Ptr {
124+
t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind())
125+
return
126+
}
127+
resultElemType := resultType.Elem()
128+
if expectedType != resultElemType {
129+
t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedType, resultElemType)
130+
}
131+
}
132+
})
133+
}
134+
}

0 commit comments

Comments
 (0)