Skip to content

Commit eb4e004

Browse files
committed
Adding more code repair
1 parent 2467355 commit eb4e004

File tree

2 files changed

+134
-15
lines changed

2 files changed

+134
-15
lines changed

sql/enumtype.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,20 +250,38 @@ func (t enumType) SQL(ctx *Context, dest []byte, v interface{}) (sqltypes.Value,
250250
if v == nil {
251251
return sqltypes.NULL, nil
252252
}
253+
253254
convertedValue, err := t.Convert(v)
254255
if err != nil {
255256
return sqltypes.Value{}, err
256257
}
257-
value, _ := t.At(int(convertedValue.(uint16)))
258+
259+
// Handle the case where Convert returns nil
260+
if convertedValue == nil {
261+
return sqltypes.NULL, nil
262+
}
263+
264+
// Safe type assertion with validation
265+
enumVal, ok := convertedValue.(uint16)
266+
if !ok {
267+
return sqltypes.Value{}, ErrConvertingToEnum.New(v)
268+
}
269+
270+
value, found := t.At(int(enumVal))
271+
if !found {
272+
return sqltypes.Value{}, ErrConvertingToEnum.New(v)
273+
}
258274

259275
resultCharset := ctx.GetCharacterSetResults()
260276
if resultCharset == CharacterSet_Unspecified || resultCharset == CharacterSet_binary {
261277
resultCharset = t.collation.CharacterSet()
262278
}
279+
263280
encodedBytes, ok := resultCharset.Encoder().Encode(encodings.StringToBytes(value))
264281
if !ok {
265282
return sqltypes.Value{}, ErrCharSetFailedToEncode.New(t.collation.CharacterSet().Name())
266283
}
284+
267285
val := appendAndSliceBytes(dest, encodedBytes)
268286

269287
return sqltypes.MakeTrusted(sqltypes.Enum, val), nil
@@ -319,19 +337,28 @@ func (t enumType) Collation() CollationID {
319337

320338
// IndexOf implements EnumType interface.
321339
func (t enumType) IndexOf(v string) int {
340+
if v == "" {
341+
return -1
342+
}
343+
322344
hashedVal, err := t.collation.HashToUint(v)
323345
if err == nil {
324346
if index, ok := t.hashedValToIndex[hashedVal]; ok {
325347
return index
326348
}
327349
}
350+
328351
/// ENUM('0','1','2')
329352
/// If you store '3', it does not match any enumeration value, so it is treated as an index and becomes '2' (the value with index 3).
330353
if parsedIndex, err := strconv.ParseInt(v, 10, 32); err == nil {
354+
if parsedIndex <= 0 || parsedIndex > int64(len(t.indexToVal)) {
355+
return -1
356+
}
331357
if _, ok := t.At(int(parsedIndex)); ok {
332358
return int(parsedIndex)
333359
}
334360
}
361+
335362
return -1
336363
}
337364

sql/system_enumtype.go

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ package sql
1616

1717
import (
1818
"math"
19-
_ "math"
2019
"reflect"
20+
"strconv"
2121
"strings"
2222

2323
"github.com/shopspring/decimal"
@@ -59,8 +59,26 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
5959
if err != nil {
6060
return 0, err
6161
}
62-
ai := as.(string)
63-
bi := bs.(string)
62+
63+
// Handle nil values that might be returned by Convert
64+
if as == nil {
65+
if bs == nil {
66+
return 0, nil
67+
}
68+
return -1, nil
69+
} else if bs == nil {
70+
return 1, nil
71+
}
72+
73+
// Safe type assertion with validation
74+
ai, ok := as.(string)
75+
if !ok {
76+
return 0, ErrInvalidSystemVariableValue.New(t.varName, a)
77+
}
78+
bi, ok := bs.(string)
79+
if !ok {
80+
return 0, ErrInvalidSystemVariableValue.New(t.varName, b)
81+
}
6482

6583
if ai == bi {
6684
return 0, nil
@@ -73,14 +91,21 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
7391

7492
// Convert implements Type interface.
7593
func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
94+
if v == nil {
95+
return nil, ErrInvalidSystemVariableValue.New(t.varName, v)
96+
}
97+
7698
// Nil values are not accepted
7799
switch value := v.(type) {
78100
case int:
79101
if value >= 0 && value < len(t.indexToVal) {
80102
return t.indexToVal[value], nil
81103
}
82104
case uint:
83-
return t.Convert(int(value))
105+
if value <= math.MaxInt {
106+
return t.Convert(int(value))
107+
}
108+
return nil, ErrInvalidSystemVariableValue.New(t.varName, value)
84109
case int8:
85110
return t.Convert(int(value))
86111
case uint8:
@@ -92,6 +117,7 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
92117
case int32:
93118
return t.Convert(int(value))
94119
case uint32:
120+
// uint32 max value is less than MaxInt, so no overflow possible
95121
return t.Convert(int(value))
96122
case int64:
97123
if value >= math.MinInt && value <= math.MaxInt {
@@ -108,7 +134,7 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
108134
case float64:
109135
// Float values aren't truly accepted, but the engine will give them when it should give ints.
110136
// Therefore, if the float doesn't have a fractional portion, we treat it as an int.
111-
if value >= 0 && value <= float64(math.MaxInt) && value == float64(int(value)) {
137+
if value >= 0 && value <= float64(math.MaxInt) && value == math.Trunc(value) {
112138
return t.Convert(int(value))
113139
}
114140
return nil, ErrInvalidSystemVariableValue.New(t.varName, value)
@@ -120,10 +146,18 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
120146
f, _ := value.Decimal.Float64()
121147
return t.Convert(f)
122148
}
149+
return nil, ErrInvalidSystemVariableValue.New(t.varName, v)
123150
case string:
124151
if idx, ok := t.valToIndex[strings.ToLower(value)]; ok {
125152
return t.indexToVal[idx], nil
126153
}
154+
155+
// Check if the string represents a numeric index
156+
if parsedIndex, err := strconv.ParseInt(value, 10, 32); err == nil {
157+
if parsedIndex >= 0 && parsedIndex < int64(len(t.indexToVal)) {
158+
return t.indexToVal[parsedIndex], nil
159+
}
160+
}
127161
}
128162

129163
return nil, ErrInvalidSystemVariableValue.New(t.varName, v)
@@ -135,20 +169,35 @@ func (t systemEnumType) MustConvert(v interface{}) interface{} {
135169
if err != nil {
136170
panic(err)
137171
}
172+
// Even with a nil error, Convert might return nil for invalid values
173+
if value == nil {
174+
panic(ErrInvalidSystemVariableValue.New(t.varName, v))
175+
}
138176
return value
139177
}
140178

141179
// Equals implements the Type interface.
142180
func (t systemEnumType) Equals(otherType Type) bool {
143-
if ot, ok := otherType.(systemEnumType); ok && t.varName == ot.varName && len(t.indexToVal) == len(ot.indexToVal) {
144-
for i, val := range t.indexToVal {
145-
if ot.indexToVal[i] != val {
146-
return false
147-
}
181+
if otherType == nil {
182+
return false
183+
}
184+
185+
ot, ok := otherType.(systemEnumType)
186+
if !ok {
187+
return false
188+
}
189+
190+
if t.varName != ot.varName || len(t.indexToVal) != len(ot.indexToVal) {
191+
return false
192+
}
193+
194+
for i, val := range t.indexToVal {
195+
if i >= len(ot.indexToVal) || ot.indexToVal[i] != val {
196+
return false
148197
}
149-
return true
150198
}
151-
return false
199+
200+
return true
152201
}
153202

154203
// MaxTextResponseByteLength implements the Type interface
@@ -173,7 +222,18 @@ func (t systemEnumType) SQL(ctx *Context, dest []byte, v interface{}) (sqltypes.
173222
return sqltypes.Value{}, err
174223
}
175224

176-
val := appendAndSliceString(dest, v.(string))
225+
// Check if conversion returned nil
226+
if v == nil {
227+
return sqltypes.NULL, nil
228+
}
229+
230+
// Safe type assertion with validation
231+
strValue, ok := v.(string)
232+
if !ok {
233+
return sqltypes.Value{}, ErrInvalidSystemVariableValue.New(t.varName, v)
234+
}
235+
236+
val := appendAndSliceString(dest, strValue)
177237

178238
return sqltypes.MakeTrusted(t.Type(), val), nil
179239
}
@@ -200,18 +260,50 @@ func (t systemEnumType) Zero() interface{} {
200260

201261
// EncodeValue implements SystemVariableType interface.
202262
func (t systemEnumType) EncodeValue(val interface{}) (string, error) {
203-
expectedVal, ok := val.(string)
263+
if val == nil {
264+
return "", ErrSystemVariableCodeFail.New(val, t.String())
265+
}
266+
267+
// Try to convert value to ensure it's valid for this enum
268+
convertedVal, err := t.Convert(val)
269+
if err != nil {
270+
return "", err
271+
}
272+
273+
// Ensure conversion returned a valid value
274+
if convertedVal == nil {
275+
return "", ErrSystemVariableCodeFail.New(val, t.String())
276+
}
277+
278+
expectedVal, ok := convertedVal.(string)
204279
if !ok {
205280
return "", ErrSystemVariableCodeFail.New(val, t.String())
206281
}
282+
207283
return expectedVal, nil
208284
}
209285

210286
// DecodeValue implements SystemVariableType interface.
211287
func (t systemEnumType) DecodeValue(val string) (interface{}, error) {
288+
if val == "" {
289+
return nil, ErrSystemVariableCodeFail.New(val, t.String())
290+
}
291+
212292
outVal, err := t.Convert(val)
213293
if err != nil {
214294
return nil, ErrSystemVariableCodeFail.New(val, t.String())
215295
}
296+
297+
// Ensure conversion returned a valid value
298+
if outVal == nil {
299+
return nil, ErrSystemVariableCodeFail.New(val, t.String())
300+
}
301+
302+
// Validate that the returned value is a string
303+
_, ok := outVal.(string)
304+
if !ok {
305+
return nil, ErrSystemVariableCodeFail.New(val, t.String())
306+
}
307+
216308
return outVal, nil
217309
}

0 commit comments

Comments
 (0)