Skip to content

Commit 3ef10b8

Browse files
authored
Merge pull request #40 from rulego/dev
Dev
2 parents 5f7076d + e435ef2 commit 3ef10b8

24 files changed

+1377
-556
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ name: CI
22

33
on:
44
push:
5-
branches: [ main, master, develop ]
5+
branches: [ main, dev ]
66
pull_request:
7-
branches: [ main, master, develop ]
7+
branches: [ main, dev ]
88

99
jobs:
1010
test:

rsql/ast.go

Lines changed: 121 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/rulego/streamsql/functions"
1010
"github.com/rulego/streamsql/types"
11+
"github.com/rulego/streamsql/utils/cast"
1112
"github.com/rulego/streamsql/window"
1213

1314
"github.com/rulego/streamsql/aggregator"
@@ -58,9 +59,16 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
5859
windowType = window.TypeSession
5960
}
6061

61-
params, err := parseWindowParamsWithType(s.Window.Params, windowType)
62-
if err != nil {
63-
return nil, "", fmt.Errorf("failed to parse window parameters: %w", err)
62+
// Parse window parameters - now returns array directly
63+
params := s.Window.Params
64+
65+
// Validate and convert parameters based on window type
66+
if len(params) > 0 {
67+
var err error
68+
params, err = validateWindowParams(params, windowType)
69+
if err != nil {
70+
return nil, "", fmt.Errorf("failed to validate window parameters: %w", err)
71+
}
6472
}
6573

6674
// Check if window processing is needed
@@ -80,16 +88,7 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
8088
if !needWindow && hasAggregation {
8189
needWindow = true
8290
windowType = window.TypeTumbling
83-
params = map[string]interface{}{
84-
"size": 10 * time.Second, // Default 10-second window
85-
}
86-
}
87-
88-
// Handle special configuration for SessionWindow
89-
var groupByKey string
90-
if windowType == window.TypeSession && len(s.GroupBy) > 0 {
91-
// For session window, use the first GROUP BY field as session key
92-
groupByKey = s.GroupBy[0]
91+
params = []interface{}{10 * time.Second} // Default 10-second window
9392
}
9493

9594
// If no aggregation functions, collect simple fields
@@ -105,10 +104,10 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
105104
simpleFields = append(simpleFields, fieldName+":"+field.Alias)
106105
} else {
107106
// For fields without alias, check if it's a string literal
108-
_, n, _, _, err := ParseAggregateTypeWithExpression(fieldName)
109-
if err != nil {
110-
return nil, "", err
111-
}
107+
_, n, _, _, err := ParseAggregateTypeWithExpression(fieldName)
108+
if err != nil {
109+
return nil, "", err
110+
}
112111
if n != "" {
113112
// If string literal, use parsed field name (remove quotes)
114113
simpleFields = append(simpleFields, n)
@@ -137,11 +136,11 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
137136
// Build Stream configuration
138137
config := types.Config{
139138
WindowConfig: types.WindowConfig{
140-
Type: windowType,
141-
Params: params,
142-
TsProp: s.Window.TsProp,
143-
TimeUnit: s.Window.TimeUnit,
144-
GroupByKey: groupByKey,
139+
Type: windowType,
140+
Params: params,
141+
TsProp: s.Window.TsProp,
142+
TimeUnit: s.Window.TimeUnit,
143+
GroupByKeys: extractGroupFields(s),
145144
},
146145
GroupFields: extractGroupFields(s),
147146
SelectFields: aggs,
@@ -245,9 +244,9 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy
245244
for _, f := range fields {
246245
if alias := f.Alias; alias != "" {
247246
t, n, _, _, parseErr := ParseAggregateTypeWithExpression(f.Expression)
248-
if parseErr != nil {
249-
return nil, nil, parseErr
250-
}
247+
if parseErr != nil {
248+
return nil, nil, parseErr
249+
}
251250
if t != "" {
252251
// Use alias as key for aggregator, not field name
253252
selectFields[alias] = t
@@ -287,11 +286,11 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error {
287286
// 使用正则表达式匹配函数调用模式
288287
pattern := regexp.MustCompile(`(?i)([a-z_]+)\s*\(`)
289288
matches := pattern.FindAllStringSubmatchIndex(expr, -1)
290-
289+
291290
for _, match := range matches {
292291
funcStart := match[0]
293292
funcName := strings.ToLower(expr[match[2]:match[3]])
294-
293+
295294
// 检查函数是否为聚合函数
296295
if fn, exists := functions.Get(funcName); exists {
297296
switch fn.GetType() {
@@ -300,14 +299,14 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error {
300299
if inAggregation {
301300
return fmt.Errorf("aggregate function calls cannot be nested")
302301
}
303-
302+
304303
// 找到该函数的参数部分
305304
funcEnd := findMatchingParenInternal(expr, funcStart+len(funcName))
306305
if funcEnd > funcStart {
307306
// 提取函数参数
308307
paramStart := funcStart + len(funcName) + 1
309308
params := expr[paramStart:funcEnd]
310-
309+
311310
// 在聚合函数参数内部递归检查
312311
if err := detectNestedAggregationRecursive(params, true); err != nil {
313312
return err
@@ -316,7 +315,7 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error {
316315
}
317316
}
318317
}
319-
318+
320319
return nil
321320
}
322321

@@ -697,43 +696,103 @@ func extractSimpleField(fieldExpr string) string {
697696
return fieldExpr
698697
}
699698

700-
func parseWindowParams(params []interface{}) (map[string]interface{}, error) {
701-
return parseWindowParamsWithType(params, "")
702-
}
699+
// validateWindowParams validates and converts window parameters based on window type
700+
// Returns validated parameters array with proper types
701+
func validateWindowParams(params []interface{}, windowType string) ([]interface{}, error) {
702+
if len(params) == 0 {
703+
return params, nil
704+
}
703705

704-
func parseWindowParamsWithType(params []interface{}, windowType string) (map[string]interface{}, error) {
705-
result := make(map[string]interface{})
706-
var key string
706+
validated := make([]interface{}, 0, len(params))
707+
708+
if windowType == window.TypeCounting {
709+
// CountingWindow expects integer count as first parameter
710+
if len(params) == 0 {
711+
return nil, fmt.Errorf("counting window requires at least one parameter")
712+
}
713+
714+
// Convert first parameter to int using cast utility
715+
count, err := cast.ToIntE(params[0])
716+
if err != nil {
717+
return nil, fmt.Errorf("invalid count parameter: %w", err)
718+
}
719+
720+
if count <= 0 {
721+
return nil, fmt.Errorf("counting window count must be positive, got: %d", count)
722+
}
723+
724+
validated = append(validated, count)
725+
726+
// Add any additional parameters
727+
if len(params) > 1 {
728+
validated = append(validated, params[1:]...)
729+
}
730+
731+
return validated, nil
732+
}
733+
734+
// Helper function to convert a value to time.Duration
735+
// For numeric types, treats them as seconds
736+
// For strings, uses time.ParseDuration
737+
convertToDuration := func(val interface{}) (time.Duration, error) {
738+
switch v := val.(type) {
739+
case time.Duration:
740+
return v, nil
741+
case string:
742+
// Use ToDurationE which handles string parsing
743+
return cast.ToDurationE(v)
744+
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
745+
// Treat numeric integers as seconds
746+
return time.Duration(cast.ToInt(v)) * time.Second, nil
747+
case float32, float64:
748+
// Treat numeric floats as seconds
749+
return time.Duration(int(cast.ToFloat64(v))) * time.Second, nil
750+
default:
751+
// Try ToDurationE as fallback
752+
return cast.ToDurationE(v)
753+
}
754+
}
755+
756+
if windowType == window.TypeSession {
757+
// SessionWindow expects timeout duration as first parameter
758+
if len(params) == 0 {
759+
return nil, fmt.Errorf("session window requires at least one parameter")
760+
}
761+
762+
timeout, err := convertToDuration(params[0])
763+
if err != nil {
764+
return nil, fmt.Errorf("invalid timeout duration: %w", err)
765+
}
766+
767+
if timeout <= 0 {
768+
return nil, fmt.Errorf("session window timeout must be positive, got: %v", timeout)
769+
}
770+
771+
validated = append(validated, timeout)
772+
773+
// Add any additional parameters
774+
if len(params) > 1 {
775+
validated = append(validated, params[1:]...)
776+
}
777+
778+
return validated, nil
779+
}
780+
781+
// For TumblingWindow and SlidingWindow, convert parameters to time.Duration
707782
for index, v := range params {
708-
if windowType == window.TypeSession {
709-
// First parameter for SessionWindow is timeout
710-
if index == 0 {
711-
key = "timeout"
712-
} else {
713-
key = fmt.Sprintf("param%d", index)
714-
}
715-
} else {
716-
// Parameters for other window types
717-
if index == 0 {
718-
key = "size"
719-
} else if index == 1 {
720-
key = "slide"
721-
} else {
722-
key = "offset"
723-
}
783+
dur, err := convertToDuration(v)
784+
if err != nil {
785+
return nil, fmt.Errorf("invalid duration parameter at index %d: %w", index, err)
724786
}
725-
if s, ok := v.(string); ok {
726-
dur, err := time.ParseDuration(s)
727-
if err != nil {
728-
return nil, fmt.Errorf("invalid %s duration: %w", s, err)
729-
}
730-
result[key] = dur
731-
} else {
732-
return nil, fmt.Errorf("%s parameter must be string format (like '5s')", s)
787+
788+
if dur <= 0 {
789+
return nil, fmt.Errorf("duration parameter at index %d must be positive, got: %v", index, dur)
733790
}
791+
792+
validated = append(validated, dur)
734793
}
735794

736-
return result, nil
795+
return validated, nil
737796
}
738797

739798
func parseAggregateExpression(expr string) string {
@@ -958,7 +1017,7 @@ func parseComplexAggExpressionInternal(expr string) ([]types.AggregationFieldInf
9581017
if err := detectNestedAggregation(expr); err != nil {
9591018
return nil, "", err
9601019
}
961-
1020+
9621021
// 使用改进的递归解析方法
9631022
aggFields, exprTemplate := parseNestedFunctionsInternal(expr, make([]types.AggregationFieldInfo, 0))
9641023
return aggFields, exprTemplate, nil

rsql/ast_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@ func TestSelectStatementEdgeCases(t *testing.T) {
251251
if config2.WindowConfig.Type != window.TypeSession {
252252
t.Errorf("Expected session window, got %v", config2.WindowConfig.Type)
253253
}
254-
if config2.WindowConfig.GroupByKey != "user_id" {
255-
t.Errorf("Expected GroupByKey to be 'user_id', got %s", config2.WindowConfig.GroupByKey)
254+
if len(config2.WindowConfig.GroupByKeys) == 0 || config2.WindowConfig.GroupByKeys[0] != "user_id" {
255+
t.Errorf("Expected GroupByKeys to contain 'user_id', got %v", config2.WindowConfig.GroupByKeys)
256256
}
257257
}
258258

rsql/coverage_test.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/rulego/streamsql/aggregator"
88
"github.com/rulego/streamsql/types"
9+
"github.com/rulego/streamsql/window"
910
)
1011

1112
// TestParseSmartParameters 测试智能参数解析函数
@@ -202,6 +203,12 @@ func TestParseWindowParams(t *testing.T) {
202203
windowType: "SLIDINGWINDOW",
203204
expectError: false,
204205
},
206+
{
207+
name: "计数窗口参数",
208+
params: []interface{}{100},
209+
windowType: "COUNTINGWINDOW",
210+
expectError: false,
211+
},
205212
{
206213
name: "无效持续时间",
207214
params: []interface{}{"invalid"},
@@ -212,7 +219,7 @@ func TestParseWindowParams(t *testing.T) {
212219
name: "非字符串参数",
213220
params: []interface{}{123},
214221
windowType: "TUMBLINGWINDOW",
215-
expectError: true,
222+
expectError: false, // 整数参数会被视为秒数,这是有效的
216223
},
217224
{
218225
name: "空参数",
@@ -224,15 +231,24 @@ func TestParseWindowParams(t *testing.T) {
224231

225232
for _, tt := range tests {
226233
t.Run(tt.name, func(t *testing.T) {
227-
var result map[string]interface{}
234+
var result []interface{}
228235
var err error
229236

230-
if tt.windowType == "SESSIONWINDOW" {
231-
result, err = parseWindowParamsWithType(tt.params, "SESSIONWINDOW")
232-
} else {
233-
result, err = parseWindowParams(tt.params)
237+
// Convert window type to internal format
238+
windowType := ""
239+
switch tt.windowType {
240+
case "SESSIONWINDOW":
241+
windowType = window.TypeSession
242+
case "TUMBLINGWINDOW":
243+
windowType = window.TypeTumbling
244+
case "SLIDINGWINDOW":
245+
windowType = window.TypeSliding
246+
case "COUNTINGWINDOW":
247+
windowType = window.TypeCounting
234248
}
235249

250+
result, err = validateWindowParams(tt.params, windowType)
251+
236252
if tt.expectError {
237253
if err == nil {
238254
t.Errorf("Expected error but got none")

0 commit comments

Comments
 (0)