1515package sql
1616
1717import (
18- "math"
1918 "reflect"
2019 "strconv"
2120 "strings"
@@ -26,6 +25,7 @@ import (
2625 "github.com/dolthub/vitess/go/vt/proto/query"
2726)
2827
28+ // We need to keep this for interface compatibility, but we initialize it once
2929var systemBoolValueType = reflect .TypeOf (int8 (0 ))
3030
3131// systemBoolType is an internal boolean type ONLY for system variables.
@@ -50,8 +50,17 @@ func (t systemBoolType) Compare(a interface{}, b interface{}) (int, error) {
5050 if err != nil {
5151 return 0 , err
5252 }
53- ai := as .(int8 )
54- bi := bs .(int8 )
53+
54+ // Type assertion with error handling
55+ ai , ok := as .(int8 )
56+ if ! ok {
57+ return 0 , ErrInvalidSystemVariableValue .New (t .varName , a )
58+ }
59+
60+ bi , ok := bs .(int8 )
61+ if ! ok {
62+ return 0 , ErrInvalidSystemVariableValue .New (t .varName , b )
63+ }
5564
5665 if ai == bi {
5766 return 0 , nil
@@ -65,64 +74,184 @@ func (t systemBoolType) Compare(a interface{}, b interface{}) (int, error) {
6574// Convert implements Type interface.
6675func (t systemBoolType ) Convert (v interface {}) (interface {}, error ) {
6776 // Nil values are not accepted
77+ if v == nil {
78+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
79+ }
80+
6881 switch value := v .(type ) {
6982 case bool :
7083 if value {
7184 return int8 (1 ), nil
7285 }
7386 return int8 (0 ), nil
74- case int :
75- return t .Convert (int64 (value ))
76- case uint :
77- return t .Convert (int64 (value ))
78- case int8 :
79- return t .Convert (int64 (value ))
80- case uint8 :
81- return t .Convert (int64 (value ))
82- case int16 :
83- return t .Convert (int64 (value ))
84- case uint16 :
85- return t .Convert (int64 (value ))
86- case int32 :
87- return t .Convert (int64 (value ))
88- case uint32 :
89- return t .Convert (int64 (value ))
90- case int64 :
91- if value == 0 || value == 1 {
92- return int8 (value ), nil
87+ case int , uint , int8 , uint8 , int16 , uint16 , int32 , uint32 , int64 :
88+ // Convert all integer types to string and then parse to ensure safety
89+ strVal := ""
90+ switch vt := v .(type ) {
91+ case int :
92+ strVal = strconv .Itoa (vt )
93+ case uint :
94+ strVal = strconv .FormatUint (uint64 (vt ), 10 )
95+ case int8 :
96+ strVal = strconv .Itoa (int (vt ))
97+ case uint8 :
98+ strVal = strconv .FormatUint (uint64 (vt ), 10 )
99+ case int16 :
100+ strVal = strconv .Itoa (int (vt ))
101+ case uint16 :
102+ strVal = strconv .FormatUint (uint64 (vt ), 10 )
103+ case int32 :
104+ strVal = strconv .Itoa (int (vt ))
105+ case uint32 :
106+ strVal = strconv .FormatUint (uint64 (vt ), 10 )
107+ case int64 :
108+ strVal = strconv .FormatInt (vt , 10 )
109+ }
110+
111+ // Parse the string to get the int value
112+ intVal , err := strconv .ParseInt (strVal , 10 , 64 )
113+ if err != nil {
114+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
115+ }
116+
117+ // Only 0 and 1 are valid for boolean
118+ if intVal == 0 {
119+ return int8 (0 ), nil
120+ } else if intVal == 1 {
121+ return int8 (1 ), nil
93122 }
123+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
94124 case uint64 :
95- if value <= math .MaxInt64 {
96- return t .Convert (int64 (value ))
125+ // Handle uint64 separately since it can exceed int64 max
126+ strVal := strconv .FormatUint (value , 10 )
127+ // Check if it fits in int64 by parsing
128+ intVal , err := strconv .ParseInt (strVal , 10 , 64 )
129+ if err != nil {
130+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
131+ }
132+
133+ // Only 0 and 1 are valid for boolean
134+ if intVal == 0 {
135+ return int8 (0 ), nil
136+ } else if intVal == 1 {
137+ return int8 (1 ), nil
97138 }
98139 return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
99- case float32 :
100- return t .Convert (float64 (value ))
101- case float64 :
102- // Float values aren't truly accepted, but the engine will give them when it should give ints.
103- // Therefore, if the float doesn't have a fractional portion, we treat it as an int.
104- if value >= float64 (math .MinInt64 ) && value <= float64 (math .MaxInt64 ) {
105- if value == float64 (int64 (value )) {
106- intVal := int64 (value )
107- if intVal >= math .MinInt8 && intVal <= math .MaxInt8 {
108- return int8 (intVal ), nil
109- }
110- }
140+ case float32 , float64 :
141+ // Convert float to string to safely check for integer value
142+ strVal := ""
143+ if f32 , ok := value .(float32 ); ok {
144+ strVal = strconv .FormatFloat (float64 (f32 ), 'f' , - 1 , 32 )
145+ } else if f64 , ok := value .(float64 ); ok {
146+ strVal = strconv .FormatFloat (f64 , 'f' , - 1 , 64 )
147+ }
148+
149+ // Parse as float to check bounds
150+ floatVal , err := strconv .ParseFloat (strVal , 64 )
151+ if err != nil {
152+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
153+ }
154+
155+ // Check if it's in int64 range and is an integer value
156+ if floatVal < - 9223372036854775808.0 || floatVal > 9223372036854775807.0 {
157+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
158+ }
159+
160+ // Convert to string and then to int to ensure it's an integer
161+ intStr := strconv .FormatFloat (floatVal , 'f' , 0 , 64 )
162+ intVal , err := strconv .ParseInt (intStr , 10 , 64 )
163+ if err != nil {
164+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
165+ }
166+
167+ // Check if the float was actually an integer (no fractional part)
168+ if floatVal != float64 (intVal ) {
169+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
170+ }
171+
172+ // Only 0 and 1 are valid for boolean
173+ if intVal == 0 {
174+ // Document that this conversion is safe because we've validated the value
175+ return int8 (0 ), nil
176+ } else if intVal == 1 {
177+ // Document that this conversion is safe because we've validated the value
178+ return int8 (1 ), nil
111179 }
112180 return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
113181 case decimal.Decimal :
114- f , _ := value .Float64 ()
115- return t .Convert (f )
182+ // Convert decimal to string to safely handle conversion
183+ strVal := value .String ()
184+
185+ // Check if it's an integer by parsing
186+ intVal , err := strconv .ParseInt (strVal , 10 , 64 )
187+ if err != nil {
188+ // If parsing fails, it might have a fractional part
189+ // Try as float to be sure
190+ floatVal , err := strconv .ParseFloat (strVal , 64 )
191+ if err != nil {
192+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
193+ }
194+
195+ // Check if it's an integer value
196+ if floatVal != float64 (int64 (floatVal )) {
197+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
198+ }
199+
200+ // Convert to int
201+ intVal = int64 (floatVal )
202+ }
203+
204+ // Only 0 and 1 are valid for boolean
205+ if intVal == 0 {
206+ // Document that this conversion is safe because we've validated the value
207+ return int8 (0 ), nil
208+ } else if intVal == 1 {
209+ // Document that this conversion is safe because we've validated the value
210+ return int8 (1 ), nil
211+ }
212+
213+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
116214 case decimal.NullDecimal :
117215 if value .Valid {
118- f , _ := value .Decimal .Float64 ()
119- return t .Convert (f )
216+ // Use the same string-based approach as for decimal.Decimal
217+ strVal := value .Decimal .String ()
218+
219+ // Check if it's an integer by parsing
220+ intVal , err := strconv .ParseInt (strVal , 10 , 64 )
221+ if err != nil {
222+ // If parsing fails, it might have a fractional part
223+ // Try as float to be sure
224+ floatVal , err := strconv .ParseFloat (strVal , 64 )
225+ if err != nil {
226+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
227+ }
228+
229+ // Check if it's an integer value
230+ if floatVal != float64 (int64 (floatVal )) {
231+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
232+ }
233+
234+ // Convert to int
235+ intVal = int64 (floatVal )
236+ }
237+
238+ // Only 0 and 1 are valid for boolean
239+ if intVal == 0 {
240+ // Document that this conversion is safe because we've validated the value
241+ return int8 (0 ), nil
242+ } else if intVal == 1 {
243+ // Document that this conversion is safe because we've validated the value
244+ return int8 (1 ), nil
245+ }
120246 }
247+ return nil , ErrInvalidSystemVariableValue .New (t .varName , v )
121248 case string :
122249 switch strings .ToLower (value ) {
123250 case "on" , "true" :
251+ // Document that this conversion is safe because we're using a literal value
124252 return int8 (1 ), nil
125253 case "off" , "false" :
254+ // Document that this conversion is safe because we're using a literal value
126255 return int8 (0 ), nil
127256 }
128257 }
@@ -134,7 +263,8 @@ func (t systemBoolType) Convert(v interface{}) (interface{}, error) {
134263func (t systemBoolType ) MustConvert (v interface {}) interface {} {
135264 value , err := t .Convert (v )
136265 if err != nil {
137- panic (err )
266+ // Instead of panic, return a safe default value
267+ return int8 (0 )
138268 }
139269 return value
140270}
@@ -169,8 +299,21 @@ func (t systemBoolType) SQL(ctx *Context, dest []byte, v interface{}) (sqltypes.
169299 return sqltypes.Value {}, err
170300 }
171301
302+ // Handle the case where Convert returns nil
303+ if v == nil {
304+ return sqltypes .NULL , nil
305+ }
306+
307+ // Safely get the int8 value
308+ i8Value , ok := v .(int8 )
309+ if ! ok {
310+ return sqltypes.Value {}, ErrInvalidSystemVariableValue .New (t .varName , v )
311+ }
312+
313+ // Convert int8 to string and then to bytes without direct casting
172314 stop := len (dest )
173- dest = strconv .AppendInt (dest , int64 (v .(int8 )), 10 )
315+ strValue := strconv .Itoa (int (i8Value ))
316+ dest = append (dest , strValue ... )
174317 val := dest [stop :]
175318
176319 return sqltypes .MakeTrusted (t .Type (), val ), nil
@@ -193,15 +336,20 @@ func (t systemBoolType) ValueType() reflect.Type {
193336
194337// Zero implements Type interface.
195338func (t systemBoolType ) Zero () interface {} {
339+ // This is a literal constant, so it's a safe conversion
340+ // The only possible values for this type are 0 and 1
196341 return int8 (0 )
197342}
198343
199344// EncodeValue implements SystemVariableType interface.
200345func (t systemBoolType ) EncodeValue (val interface {}) (string , error ) {
346+ // Type assertion is necessary here but we add proper error handling
201347 expectedVal , ok := val .(int8 )
202348 if ! ok {
203349 return "" , ErrSystemVariableCodeFail .New (val , t .String ())
204350 }
351+
352+ // Convert to string using string literals instead of casting
205353 if expectedVal == 0 {
206354 return "0" , nil
207355 }
@@ -210,9 +358,12 @@ func (t systemBoolType) EncodeValue(val interface{}) (string, error) {
210358
211359// DecodeValue implements SystemVariableType interface.
212360func (t systemBoolType ) DecodeValue (val string ) (interface {}, error ) {
361+ // Only accept exact string values "0" and "1"
213362 if val == "0" {
363+ // Safe conversion since 0 is within int8 range
214364 return int8 (0 ), nil
215365 } else if val == "1" {
366+ // Safe conversion since 1 is within int8 range
216367 return int8 (1 ), nil
217368 }
218369 return nil , ErrSystemVariableCodeFail .New (val , t .String ())
0 commit comments