@@ -16,8 +16,8 @@ package sql
1616
1717import (
1818 "math"
19- _ "math"
2019 "reflect"
20+ "strconv"
2121 "strings"
2222
2323 "github.com/shopspring/decimal"
@@ -59,8 +59,26 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
5959 if err != nil {
6060 return 0 , err
6161 }
62- ai := as .(string )
63- bi := bs .(string )
62+
63+ // Handle nil values that might be returned by Convert
64+ if as == nil {
65+ if bs == nil {
66+ return 0 , nil
67+ }
68+ return - 1 , nil
69+ } else if bs == nil {
70+ return 1 , nil
71+ }
72+
73+ // Safe type assertion with validation
74+ ai , ok := as .(string )
75+ if ! ok {
76+ return 0 , ErrInvalidSystemVariableValue .New (t .varName , a )
77+ }
78+ bi , ok := bs .(string )
79+ if ! ok {
80+ return 0 , ErrInvalidSystemVariableValue .New (t .varName , b )
81+ }
6482
6583 if ai == bi {
6684 return 0 , nil
@@ -73,14 +91,21 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
7391
7492// Convert implements Type interface.
7593func (t systemEnumType ) Convert (v interface {}) (interface {}, error ) {
94+ if v == nil {
95+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
96+ }
97+
7698 // Nil values are not accepted
7799 switch value := v .(type ) {
78100 case int :
79101 if value >= 0 && value < len (t .indexToVal ) {
80102 return t .indexToVal [value ], nil
81103 }
82104 case uint :
83- return t .Convert (int (value ))
105+ if value <= math .MaxInt {
106+ return t .Convert (int (value ))
107+ }
108+ return nil , ErrInvalidSystemVariableValue .New (t .varName , value )
84109 case int8 :
85110 return t .Convert (int (value ))
86111 case uint8 :
@@ -92,6 +117,7 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
92117 case int32 :
93118 return t .Convert (int (value ))
94119 case uint32 :
120+ // uint32 max value is less than MaxInt, so no overflow possible
95121 return t .Convert (int (value ))
96122 case int64 :
97123 if value >= math .MinInt && value <= math .MaxInt {
@@ -108,7 +134,7 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
108134 case float64 :
109135 // Float values aren't truly accepted, but the engine will give them when it should give ints.
110136 // Therefore, if the float doesn't have a fractional portion, we treat it as an int.
111- if value >= 0 && value <= float64 (math .MaxInt ) && value == float64 ( int ( value ) ) {
137+ if value >= 0 && value <= float64 (math .MaxInt ) && value == math . Trunc ( value ) {
112138 return t .Convert (int (value ))
113139 }
114140 return nil , ErrInvalidSystemVariableValue .New (t .varName , value )
@@ -120,10 +146,18 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
120146 f , _ := value .Decimal .Float64 ()
121147 return t .Convert (f )
122148 }
149+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
123150 case string :
124151 if idx , ok := t .valToIndex [strings .ToLower (value )]; ok {
125152 return t .indexToVal [idx ], nil
126153 }
154+
155+ // Check if the string represents a numeric index
156+ if parsedIndex , err := strconv .ParseInt (value , 10 , 32 ); err == nil {
157+ if parsedIndex >= 0 && parsedIndex < int64 (len (t .indexToVal )) {
158+ return t .indexToVal [parsedIndex ], nil
159+ }
160+ }
127161 }
128162
129163 return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
@@ -135,20 +169,35 @@ func (t systemEnumType) MustConvert(v interface{}) interface{} {
135169 if err != nil {
136170 panic (err )
137171 }
172+ // Even with a nil error, Convert might return nil for invalid values
173+ if value == nil {
174+ panic (ErrInvalidSystemVariableValue .New (t .varName , v ))
175+ }
138176 return value
139177}
140178
141179// Equals implements the Type interface.
142180func (t systemEnumType ) Equals (otherType Type ) bool {
143- if ot , ok := otherType .(systemEnumType ); ok && t .varName == ot .varName && len (t .indexToVal ) == len (ot .indexToVal ) {
144- for i , val := range t .indexToVal {
145- if ot .indexToVal [i ] != val {
146- return false
147- }
181+ if otherType == nil {
182+ return false
183+ }
184+
185+ ot , ok := otherType .(systemEnumType )
186+ if ! ok {
187+ return false
188+ }
189+
190+ if t .varName != ot .varName || len (t .indexToVal ) != len (ot .indexToVal ) {
191+ return false
192+ }
193+
194+ for i , val := range t .indexToVal {
195+ if i >= len (ot .indexToVal ) || ot .indexToVal [i ] != val {
196+ return false
148197 }
149- return true
150198 }
151- return false
199+
200+ return true
152201}
153202
154203// MaxTextResponseByteLength implements the Type interface
@@ -173,7 +222,18 @@ func (t systemEnumType) SQL(ctx *Context, dest []byte, v interface{}) (sqltypes.
173222 return sqltypes.Value {}, err
174223 }
175224
176- val := appendAndSliceString (dest , v .(string ))
225+ // Check if conversion returned nil
226+ if v == nil {
227+ return sqltypes .NULL , nil
228+ }
229+
230+ // Safe type assertion with validation
231+ strValue , ok := v .(string )
232+ if ! ok {
233+ return sqltypes.Value {}, ErrInvalidSystemVariableValue .New (t .varName , v )
234+ }
235+
236+ val := appendAndSliceString (dest , strValue )
177237
178238 return sqltypes .MakeTrusted (t .Type (), val ), nil
179239}
@@ -200,18 +260,50 @@ func (t systemEnumType) Zero() interface{} {
200260
201261// EncodeValue implements SystemVariableType interface.
202262func (t systemEnumType ) EncodeValue (val interface {}) (string , error ) {
203- expectedVal , ok := val .(string )
263+ if val == nil {
264+ return "" , ErrSystemVariableCodeFail .New (val , t .String ())
265+ }
266+
267+ // Try to convert value to ensure it's valid for this enum
268+ convertedVal , err := t .Convert (val )
269+ if err != nil {
270+ return "" , err
271+ }
272+
273+ // Ensure conversion returned a valid value
274+ if convertedVal == nil {
275+ return "" , ErrSystemVariableCodeFail .New (val , t .String ())
276+ }
277+
278+ expectedVal , ok := convertedVal .(string )
204279 if ! ok {
205280 return "" , ErrSystemVariableCodeFail .New (val , t .String ())
206281 }
282+
207283 return expectedVal , nil
208284}
209285
210286// DecodeValue implements SystemVariableType interface.
211287func (t systemEnumType ) DecodeValue (val string ) (interface {}, error ) {
288+ if val == "" {
289+ return nil , ErrSystemVariableCodeFail .New (val , t .String ())
290+ }
291+
212292 outVal , err := t .Convert (val )
213293 if err != nil {
214294 return nil , ErrSystemVariableCodeFail .New (val , t .String ())
215295 }
296+
297+ // Ensure conversion returned a valid value
298+ if outVal == nil {
299+ return nil , ErrSystemVariableCodeFail .New (val , t .String ())
300+ }
301+
302+ // Validate that the returned value is a string
303+ _ , ok := outVal .(string )
304+ if ! ok {
305+ return nil , ErrSystemVariableCodeFail .New (val , t .String ())
306+ }
307+
216308 return outVal , nil
217309}
0 commit comments