Skip to content

Commit 7698103

Browse files
committed
feat: dynamic resolving of aggrergates results types
1 parent 6a03b6a commit 7698103

File tree

2 files changed

+325
-15
lines changed

2 files changed

+325
-15
lines changed

internal/storage/clickhouse.go

Lines changed: 191 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"fmt"
99
"math/big"
10+
"reflect"
1011
"strings"
1112
"sync"
1213
"time"
@@ -349,24 +350,20 @@ func (c *ClickHouseConnector) GetLogs(qf QueryFilter) (QueryResult[common.Log],
349350

350351
func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (QueryResult[interface{}], error) {
351352
// Build the SELECT clause with aggregates
352-
columns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
353-
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", columns, c.cfg.Database, table)
353+
selectColumns := strings.Join(append(qf.GroupBy, qf.Aggregates...), ", ")
354+
query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE is_deleted = 0", selectColumns, c.cfg.Database, table)
354355

355356
// Apply filters
356357
if qf.ChainId != nil && qf.ChainId.Sign() > 0 {
357358
query = addFilterParams("chain_id", qf.ChainId.String(), query)
358359
}
359360
query = addContractAddress(table, query, qf.ContractAddress)
360-
361361
if qf.Signature != "" {
362362
query += fmt.Sprintf(" AND topic_0 = '%s'", qf.Signature)
363363
}
364-
365364
for key, value := range qf.FilterParams {
366365
query = addFilterParams(key, strings.ToLower(value), query)
367366
}
368-
369-
// Add GROUP BY clause if specified
370367
if len(qf.GroupBy) > 0 {
371368
groupByColumns := strings.Join(qf.GroupBy, ", ")
372369
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
@@ -379,28 +376,45 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
379376
}
380377
defer rows.Close()
381378

379+
columnNames := rows.Columns()
380+
columnTypes := rows.ColumnTypes()
381+
382382
// Collect results
383383
var aggregates []map[string]interface{}
384384
for rows.Next() {
385-
columns := rows.Columns()
386-
values := make([]interface{}, len(columns))
387-
valuePtrs := make([]interface{}, len(columns))
388-
for i := range columns {
389-
valuePtrs[i] = &values[i]
385+
values := make([]interface{}, len(columnNames))
386+
387+
// Assign Go types based on ClickHouse types
388+
for i, colType := range columnTypes {
389+
dbType := colType.DatabaseTypeName()
390+
values[i] = mapClickHouseTypeToGoType(dbType)
390391
}
391392

392-
if err := rows.Scan(valuePtrs...); err != nil {
393-
return QueryResult[interface{}]{}, err
393+
if err := rows.Scan(values...); err != nil {
394+
return QueryResult[interface{}]{}, fmt.Errorf("failed to scan row: %w", err)
394395
}
395396

397+
// Prepare the result map for the current row
396398
result := make(map[string]interface{})
397-
for i, col := range columns {
398-
result[col] = values[i]
399+
for i, colName := range columnNames {
400+
valuePtr := values[i]
401+
value := getUnderlyingValue(valuePtr)
402+
403+
// Convert *big.Int to string
404+
if bigIntValue, ok := value.(big.Int); ok {
405+
result[colName] = BigInt{Int: bigIntValue}
406+
} else {
407+
result[colName] = value
408+
}
399409
}
400410

401411
aggregates = append(aggregates, result)
402412
}
403413

414+
if err := rows.Err(); err != nil {
415+
return QueryResult[interface{}]{}, fmt.Errorf("row iteration error: %w", err)
416+
}
417+
404418
return QueryResult[interface{}]{Data: nil, Aggregates: aggregates}, nil
405419
}
406420

@@ -1056,3 +1070,165 @@ func (c *ClickHouseConnector) InsertBlockData(data *[]common.BlockData) error {
10561070
}
10571071
return nil
10581072
}
1073+
1074+
func mapClickHouseTypeToGoType(dbType string) interface{} {
1075+
// Handle LowCardinality types
1076+
if strings.HasPrefix(dbType, "LowCardinality(") {
1077+
dbType = dbType[len("LowCardinality(") : len(dbType)-1]
1078+
}
1079+
1080+
// Handle Nullable types
1081+
isNullable := false
1082+
if strings.HasPrefix(dbType, "Nullable(") {
1083+
isNullable = true
1084+
dbType = dbType[len("Nullable(") : len(dbType)-1]
1085+
}
1086+
1087+
// Handle Array types
1088+
if strings.HasPrefix(dbType, "Array(") {
1089+
elementType := dbType[len("Array(") : len(dbType)-1]
1090+
// For arrays, we'll use slices of pointers to the element type
1091+
switch elementType {
1092+
case "String", "FixedString":
1093+
return new([]*string)
1094+
case "Int8", "Int16", "Int32", "Int64":
1095+
return new([]*int64)
1096+
case "UInt8", "UInt16", "UInt32", "UInt64":
1097+
return new([]*uint64)
1098+
case "Float32", "Float64":
1099+
return new([]*float64)
1100+
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
1101+
return new([]*big.Float)
1102+
// Add more cases as needed
1103+
default:
1104+
return new([]interface{})
1105+
}
1106+
}
1107+
1108+
// Handle parameterized types by extracting the base type
1109+
baseType := dbType
1110+
if idx := strings.Index(dbType, "("); idx != -1 {
1111+
baseType = dbType[:idx]
1112+
}
1113+
1114+
// Map basic data types
1115+
switch baseType {
1116+
// Signed integers
1117+
case "Int8":
1118+
if isNullable {
1119+
return new(*int8)
1120+
}
1121+
return new(int8)
1122+
case "Int16":
1123+
if isNullable {
1124+
return new(*int16)
1125+
}
1126+
return new(int16)
1127+
case "Int32":
1128+
if isNullable {
1129+
return new(*int32)
1130+
}
1131+
return new(int32)
1132+
case "Int64":
1133+
if isNullable {
1134+
return new(*int64)
1135+
}
1136+
return new(int64)
1137+
// Unsigned integers
1138+
case "UInt8":
1139+
if isNullable {
1140+
return new(*uint8)
1141+
}
1142+
return new(uint8)
1143+
case "UInt16":
1144+
if isNullable {
1145+
return new(*uint16)
1146+
}
1147+
return new(uint16)
1148+
case "UInt32":
1149+
if isNullable {
1150+
return new(*uint32)
1151+
}
1152+
return new(uint32)
1153+
case "UInt64":
1154+
if isNullable {
1155+
return new(*uint64)
1156+
}
1157+
return new(uint64)
1158+
// Floating-point numbers
1159+
case "Float32":
1160+
if isNullable {
1161+
return new(*float32)
1162+
}
1163+
return new(float32)
1164+
case "Float64":
1165+
if isNullable {
1166+
return new(*float64)
1167+
}
1168+
return new(float64)
1169+
// Decimal types
1170+
case "Decimal", "Decimal32", "Decimal64", "Decimal128", "Decimal256":
1171+
if isNullable {
1172+
return new(*big.Float)
1173+
}
1174+
return new(big.Float)
1175+
// String types
1176+
case "String", "FixedString", "UUID", "IPv4", "IPv6":
1177+
if isNullable {
1178+
return new(*string)
1179+
}
1180+
return new(string)
1181+
// Enums
1182+
case "Enum8", "Enum16":
1183+
if isNullable {
1184+
return new(*string)
1185+
}
1186+
return new(string)
1187+
// Date and time types
1188+
case "Date", "Date32", "DateTime", "DateTime64":
1189+
if isNullable {
1190+
return new(*time.Time)
1191+
}
1192+
return new(time.Time)
1193+
// Big integers
1194+
case "Int128", "UInt128", "Int256", "UInt256":
1195+
if isNullable {
1196+
return new(*big.Int)
1197+
}
1198+
return new(big.Int)
1199+
default:
1200+
// For unknown types, use interface{}
1201+
return new(interface{})
1202+
}
1203+
}
1204+
1205+
type BigInt struct {
1206+
big.Int
1207+
}
1208+
1209+
func (b BigInt) MarshalJSON() ([]byte, error) {
1210+
return []byte(`"` + b.String() + `"`), nil
1211+
}
1212+
1213+
func getUnderlyingValue(valuePtr interface{}) interface{} {
1214+
v := reflect.ValueOf(valuePtr)
1215+
1216+
// Handle nil values
1217+
if !v.IsValid() {
1218+
return nil
1219+
}
1220+
1221+
// Handle pointers and interfaces
1222+
for {
1223+
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
1224+
if v.IsNil() {
1225+
return nil
1226+
}
1227+
v = v.Elem()
1228+
continue
1229+
}
1230+
break
1231+
}
1232+
1233+
return v.Interface()
1234+
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package storage
2+
3+
import (
4+
"math/big"
5+
"reflect"
6+
"testing"
7+
"time"
8+
)
9+
10+
// TestMapClickHouseTypeToGoType tests the mapClickHouseTypeToGoType function
11+
func TestMapClickHouseTypeToGoType(t *testing.T) {
12+
testCases := []struct {
13+
dbType string
14+
expectedType interface{}
15+
}{
16+
// Signed integers
17+
{"Int8", int8(0)},
18+
{"Nullable(Int8)", (**int8)(nil)},
19+
{"Int16", int16(0)},
20+
{"Nullable(Int16)", (**int16)(nil)},
21+
{"Int32", int32(0)},
22+
{"Nullable(Int32)", (**int32)(nil)},
23+
{"Int64", int64(0)},
24+
{"Nullable(Int64)", (**int64)(nil)},
25+
// Unsigned integers
26+
{"UInt8", uint8(0)},
27+
{"Nullable(UInt8)", (**uint8)(nil)},
28+
{"UInt16", uint16(0)},
29+
{"Nullable(UInt16)", (**uint16)(nil)},
30+
{"UInt32", uint32(0)},
31+
{"Nullable(UInt32)", (**uint32)(nil)},
32+
{"UInt64", uint64(0)},
33+
{"Nullable(UInt64)", (**uint64)(nil)},
34+
// Big integers
35+
{"Int128", big.NewInt(0)},
36+
{"Nullable(Int128)", (**big.Int)(nil)},
37+
{"UInt128", big.NewInt(0)},
38+
{"Nullable(UInt128)", (**big.Int)(nil)},
39+
{"Int256", big.NewInt(0)},
40+
{"Nullable(Int256)", (**big.Int)(nil)},
41+
{"UInt256", big.NewInt(0)},
42+
{"Nullable(UInt256)", (**big.Int)(nil)},
43+
// Floating-point numbers
44+
{"Float32", float32(0)},
45+
{"Nullable(Float32)", (**float32)(nil)},
46+
{"Float64", float64(0)},
47+
{"Nullable(Float64)", (**float64)(nil)},
48+
// Decimal types
49+
{"Decimal", big.NewFloat(0)},
50+
{"Nullable(Decimal)", (**big.Float)(nil)},
51+
{"Decimal32", big.NewFloat(0)},
52+
{"Nullable(Decimal32)", (**big.Float)(nil)},
53+
{"Decimal64", big.NewFloat(0)},
54+
{"Nullable(Decimal64)", (**big.Float)(nil)},
55+
{"Decimal128", big.NewFloat(0)},
56+
{"Nullable(Decimal128)", (**big.Float)(nil)},
57+
{"Decimal256", big.NewFloat(0)},
58+
{"Nullable(Decimal256)", (**big.Float)(nil)},
59+
// String types
60+
{"String", ""},
61+
{"Nullable(String)", (**string)(nil)},
62+
{"FixedString(42)", ""},
63+
{"Nullable(FixedString(42))", (**string)(nil)},
64+
{"UUID", ""},
65+
{"Nullable(UUID)", (**string)(nil)},
66+
{"IPv4", ""},
67+
{"Nullable(IPv4)", (**string)(nil)},
68+
{"IPv6", ""},
69+
{"Nullable(IPv6)", (**string)(nil)},
70+
// Date and time types
71+
{"Date", time.Time{}},
72+
{"Nullable(Date)", (**time.Time)(nil)},
73+
{"DateTime", time.Time{}},
74+
{"Nullable(DateTime)", (**time.Time)(nil)},
75+
{"DateTime64", time.Time{}},
76+
{"Nullable(DateTime64)", (**time.Time)(nil)},
77+
// Enums
78+
{"Enum8('a' = 1, 'b' = 2)", ""},
79+
{"Nullable(Enum8('a' = 1, 'b' = 2))", (**string)(nil)},
80+
{"Enum16('a' = 1, 'b' = 2)", ""},
81+
{"Nullable(Enum16('a' = 1, 'b' = 2))", (**string)(nil)},
82+
// Arrays
83+
{"Array(Int32)", &[]*int64{}},
84+
{"Array(String)", &[]*string{}},
85+
{"Array(Float64)", &[]*float64{}},
86+
// LowCardinality
87+
{"LowCardinality(String)", ""},
88+
{"LowCardinality(Nullable(String))", (**string)(nil)},
89+
// Unknown type
90+
{"UnknownType", new(interface{})},
91+
{"Nullable(UnknownType)", new(interface{})},
92+
}
93+
94+
for _, tc := range testCases {
95+
t.Run(tc.dbType, func(t *testing.T) {
96+
result := mapClickHouseTypeToGoType(tc.dbType)
97+
98+
expectedType := reflect.TypeOf(tc.expectedType)
99+
resultType := reflect.TypeOf(result)
100+
101+
// Handle pointers
102+
if expectedType.Kind() == reflect.Ptr {
103+
if resultType.Kind() != reflect.Ptr {
104+
t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind())
105+
return
106+
}
107+
expectedElemType := expectedType.Elem()
108+
resultElemType := resultType.Elem()
109+
if expectedElemType.Kind() == reflect.Ptr {
110+
// Expected pointer to pointer
111+
if resultElemType.Kind() != reflect.Ptr {
112+
t.Errorf("Expected pointer to pointer for dbType %s, got %s", tc.dbType, resultElemType.Kind())
113+
return
114+
}
115+
expectedElemType = expectedElemType.Elem()
116+
resultElemType = resultElemType.Elem()
117+
}
118+
if expectedElemType != resultElemType {
119+
t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedElemType, resultElemType)
120+
}
121+
} else {
122+
// Non-pointer types
123+
if resultType.Kind() != reflect.Ptr {
124+
t.Errorf("Expected pointer type for dbType %s, got %s", tc.dbType, resultType.Kind())
125+
return
126+
}
127+
resultElemType := resultType.Elem()
128+
if expectedType != resultElemType {
129+
t.Errorf("Type mismatch for dbType %s: expected %s, got %s", tc.dbType, expectedType, resultElemType)
130+
}
131+
}
132+
})
133+
}
134+
}

0 commit comments

Comments
 (0)