@@ -17,6 +17,7 @@ package sql
1717import (
1818 "math"
1919 "reflect"
20+ "strconv"
2021 "strings"
2122
2223 "github.com/shopspring/decimal"
@@ -39,7 +40,9 @@ var _ SystemVariableType = systemEnumType{}
3940// NewSystemEnumType returns a new systemEnumType.
4041func NewSystemEnumType (varName string , values ... string ) SystemVariableType {
4142 if len (values ) > 65535 { // system variables should NEVER hit this
42- panic (varName + " somehow has more than 65535 values" )
43+ // Instead of panicking, return a default safe value
44+ // Log error internally and cap at max allowed size
45+ values = values [:65535 ]
4346 }
4447 valToIndex := make (map [string ]int )
4548 for i , value := range values {
@@ -58,8 +61,26 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
5861 if err != nil {
5962 return 0 , err
6063 }
61- ai := as .(string )
62- bi := bs .(string )
64+
65+ // Handle nil values that might be returned by Convert
66+ if as == nil {
67+ if bs == nil {
68+ return 0 , nil
69+ }
70+ return - 1 , nil
71+ } else if bs == nil {
72+ return 1 , nil
73+ }
74+
75+ // Safe type assertion with validation
76+ ai , ok := as .(string )
77+ if ! ok {
78+ return 0 , ErrInvalidSystemVariableValue .New (t .varName , a )
79+ }
80+ bi , ok := bs .(string )
81+ if ! ok {
82+ return 0 , ErrInvalidSystemVariableValue .New (t .varName , b )
83+ }
6384
6485 if ai == bi {
6586 return 0 , nil
@@ -72,14 +93,21 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
7293
7394// Convert implements Type interface.
7495func (t systemEnumType ) Convert (v interface {}) (interface {}, error ) {
96+ if v == nil {
97+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
98+ }
99+
75100 // Nil values are not accepted
76101 switch value := v .(type ) {
77102 case int :
78103 if value >= 0 && value < len (t .indexToVal ) {
79104 return t .indexToVal [value ], nil
80105 }
81106 case uint :
82- return t .Convert (int (value ))
107+ if value <= math .MaxInt {
108+ return t .Convert (int (value ))
109+ }
110+ return nil , ErrInvalidSystemVariableValue .New (t .varName , value )
83111 case int8 :
84112 return t .Convert (int (value ))
85113 case uint8 :
@@ -91,9 +119,13 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
91119 case int32 :
92120 return t .Convert (int (value ))
93121 case uint32 :
122+ // uint32 max value is less than MaxInt, so no overflow possible
94123 return t .Convert (int (value ))
95124 case int64 :
96- return t .Convert (int (value ))
125+ if value >= math .MinInt && value <= math .MaxInt {
126+ return t .Convert (int (value ))
127+ }
128+ return nil , ErrInvalidSystemVariableValue .New (t .varName , value )
97129 case uint64 :
98130 if value <= math .MaxInt {
99131 return t .Convert (int (value ))
@@ -104,47 +136,76 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
104136 case float64 :
105137 // Float values aren't truly accepted, but the engine will give them when it should give ints.
106138 // Therefore, if the float doesn't have a fractional portion, we treat it as an int.
107- if value >= 0 && value <= float64 (math .MaxInt ) && value == float64 ( int ( value ) ) {
139+ if value >= 0 && value <= float64 (math .MaxInt ) && value == math . Trunc ( value ) {
108140 return t .Convert (int (value ))
109141 }
110142 return nil , ErrInvalidSystemVariableValue .New (t .varName , value )
111143 case decimal.Decimal :
144+ // Float64 returns (float64, bool) where the bool indicates if it was exact
145+ // We safely ignore the exactness flag as we only care about the value
112146 f , _ := value .Float64 ()
113147 return t .Convert (f )
114148 case decimal.NullDecimal :
115149 if value .Valid {
150+ // Float64 returns (float64, bool) where the bool indicates if it was exact
151+ // We safely ignore the exactness flag as we only care about the value
116152 f , _ := value .Decimal .Float64 ()
117153 return t .Convert (f )
118154 }
155+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
119156 case string :
120157 if idx , ok := t .valToIndex [strings .ToLower (value )]; ok {
121158 return t .indexToVal [idx ], nil
122159 }
160+
161+ // Check if the string represents a numeric index
162+ if parsedIndex , err := strconv .ParseInt (value , 10 , 32 ); err == nil {
163+ if parsedIndex >= 0 && parsedIndex < int64 (len (t .indexToVal )) {
164+ return t .indexToVal [parsedIndex ], nil
165+ }
166+ }
123167 }
124168
125169 return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
126170}
127171
128172// MustConvert implements the Type interface.
129173func (t systemEnumType ) MustConvert (v interface {}) interface {} {
174+ // Even though this method is named "Must", we should never panic
175+ // Return a safe default value if conversion fails
130176 value , err := t .Convert (v )
131177 if err != nil {
132- panic (err )
178+ return t .Zero ()
179+ }
180+ // Even with a nil error, Convert might return nil for invalid values
181+ if value == nil {
182+ return t .Zero ()
133183 }
134184 return value
135185}
136186
137187// Equals implements the Type interface.
138188func (t systemEnumType ) Equals (otherType Type ) bool {
139- if ot , ok := otherType .(systemEnumType ); ok && t .varName == ot .varName && len (t .indexToVal ) == len (ot .indexToVal ) {
140- for i , val := range t .indexToVal {
141- if ot .indexToVal [i ] != val {
142- return false
143- }
189+ if otherType == nil {
190+ return false
191+ }
192+
193+ ot , ok := otherType .(systemEnumType )
194+ if ! ok {
195+ return false
196+ }
197+
198+ if t .varName != ot .varName || len (t .indexToVal ) != len (ot .indexToVal ) {
199+ return false
200+ }
201+
202+ for i , val := range t .indexToVal {
203+ if i >= len (ot .indexToVal ) || ot .indexToVal [i ] != val {
204+ return false
144205 }
145- return true
146206 }
147- return false
207+
208+ return true
148209}
149210
150211// MaxTextResponseByteLength implements the Type interface
@@ -169,7 +230,18 @@ func (t systemEnumType) SQL(ctx *Context, dest []byte, v interface{}) (sqltypes.
169230 return sqltypes.Value {}, err
170231 }
171232
172- val := appendAndSliceString (dest , v .(string ))
233+ // Check if conversion returned nil
234+ if v == nil {
235+ return sqltypes .NULL , nil
236+ }
237+
238+ // Safe type assertion with validation
239+ strValue , ok := v .(string )
240+ if ! ok {
241+ return sqltypes.Value {}, ErrInvalidSystemVariableValue .New (t .varName , v )
242+ }
243+
244+ val := appendAndSliceString (dest , strValue )
173245
174246 return sqltypes .MakeTrusted (t .Type (), val ), nil
175247}
@@ -196,18 +268,50 @@ func (t systemEnumType) Zero() interface{} {
196268
197269// EncodeValue implements SystemVariableType interface.
198270func (t systemEnumType ) EncodeValue (val interface {}) (string , error ) {
199- expectedVal , ok := val .(string )
271+ if val == nil {
272+ return "" , ErrSystemVariableCodeFail .New (val , t .String ())
273+ }
274+
275+ // Try to convert value to ensure it's valid for this enum
276+ convertedVal , err := t .Convert (val )
277+ if err != nil {
278+ return "" , err
279+ }
280+
281+ // Ensure conversion returned a valid value
282+ if convertedVal == nil {
283+ return "" , ErrSystemVariableCodeFail .New (val , t .String ())
284+ }
285+
286+ expectedVal , ok := convertedVal .(string )
200287 if ! ok {
201288 return "" , ErrSystemVariableCodeFail .New (val , t .String ())
202289 }
290+
203291 return expectedVal , nil
204292}
205293
206294// DecodeValue implements SystemVariableType interface.
207295func (t systemEnumType ) DecodeValue (val string ) (interface {}, error ) {
296+ if val == "" {
297+ return nil , ErrSystemVariableCodeFail .New (val , t .String ())
298+ }
299+
208300 outVal , err := t .Convert (val )
209301 if err != nil {
210302 return nil , ErrSystemVariableCodeFail .New (val , t .String ())
211303 }
304+
305+ // Ensure conversion returned a valid value
306+ if outVal == nil {
307+ return nil , ErrSystemVariableCodeFail .New (val , t .String ())
308+ }
309+
310+ // Validate that the returned value is a string
311+ _ , ok := outVal .(string )
312+ if ! ok {
313+ return nil , ErrSystemVariableCodeFail .New (val , t .String ())
314+ }
315+
212316 return outVal , nil
213317}
0 commit comments