Skip to content

Commit 888eee5

Browse files
authored
Merge pull request #35 from trimble-oss/security_updates_two
Potential fix for code scanning alert no. 405: Incorrect conversion between integer types
2 parents 0672072 + dd4819f commit 888eee5

File tree

3 files changed

+252
-33
lines changed

3 files changed

+252
-33
lines changed

sql/enumtype.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,20 +250,38 @@ func (t enumType) SQL(ctx *Context, dest []byte, v interface{}) (sqltypes.Value,
250250
if v == nil {
251251
return sqltypes.NULL, nil
252252
}
253+
253254
convertedValue, err := t.Convert(v)
254255
if err != nil {
255256
return sqltypes.Value{}, err
256257
}
257-
value, _ := t.At(int(convertedValue.(uint16)))
258+
259+
// Handle the case where Convert returns nil
260+
if convertedValue == nil {
261+
return sqltypes.NULL, nil
262+
}
263+
264+
// Safe type assertion with validation
265+
enumVal, ok := convertedValue.(uint16)
266+
if !ok {
267+
return sqltypes.Value{}, ErrConvertingToEnum.New(v)
268+
}
269+
270+
value, found := t.At(int(enumVal))
271+
if !found {
272+
return sqltypes.Value{}, ErrConvertingToEnum.New(v)
273+
}
258274

259275
resultCharset := ctx.GetCharacterSetResults()
260276
if resultCharset == CharacterSet_Unspecified || resultCharset == CharacterSet_binary {
261277
resultCharset = t.collation.CharacterSet()
262278
}
279+
263280
encodedBytes, ok := resultCharset.Encoder().Encode(encodings.StringToBytes(value))
264281
if !ok {
265282
return sqltypes.Value{}, ErrCharSetFailedToEncode.New(t.collation.CharacterSet().Name())
266283
}
284+
267285
val := appendAndSliceBytes(dest, encodedBytes)
268286

269287
return sqltypes.MakeTrusted(sqltypes.Enum, val), nil
@@ -319,19 +337,28 @@ func (t enumType) Collation() CollationID {
319337

320338
// IndexOf implements EnumType interface.
321339
func (t enumType) IndexOf(v string) int {
340+
if v == "" {
341+
return -1
342+
}
343+
322344
hashedVal, err := t.collation.HashToUint(v)
323345
if err == nil {
324346
if index, ok := t.hashedValToIndex[hashedVal]; ok {
325347
return index
326348
}
327349
}
350+
328351
/// ENUM('0','1','2')
329352
/// If you store '3', it does not match any enumeration value, so it is treated as an index and becomes '2' (the value with index 3).
330353
if parsedIndex, err := strconv.ParseInt(v, 10, 32); err == nil {
354+
if parsedIndex <= 0 || parsedIndex > int64(len(t.indexToVal)) {
355+
return -1
356+
}
331357
if _, ok := t.At(int(parsedIndex)); ok {
332358
return int(parsedIndex)
333359
}
334360
}
361+
335362
return -1
336363
}
337364

sql/system_enumtype.go

Lines changed: 120 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package sql
1717
import (
1818
"math"
1919
"reflect"
20+
"strconv"
2021
"strings"
2122

2223
"github.com/shopspring/decimal"
@@ -39,7 +40,9 @@ var _ SystemVariableType = systemEnumType{}
3940
// NewSystemEnumType returns a new systemEnumType.
4041
func NewSystemEnumType(varName string, values ...string) SystemVariableType {
4142
if len(values) > 65535 { // system variables should NEVER hit this
42-
panic(varName + " somehow has more than 65535 values")
43+
// Instead of panicking, return a default safe value
44+
// Log error internally and cap at max allowed size
45+
values = values[:65535]
4346
}
4447
valToIndex := make(map[string]int)
4548
for i, value := range values {
@@ -58,8 +61,26 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
5861
if err != nil {
5962
return 0, err
6063
}
61-
ai := as.(string)
62-
bi := bs.(string)
64+
65+
// Handle nil values that might be returned by Convert
66+
if as == nil {
67+
if bs == nil {
68+
return 0, nil
69+
}
70+
return -1, nil
71+
} else if bs == nil {
72+
return 1, nil
73+
}
74+
75+
// Safe type assertion with validation
76+
ai, ok := as.(string)
77+
if !ok {
78+
return 0, ErrInvalidSystemVariableValue.New(t.varName, a)
79+
}
80+
bi, ok := bs.(string)
81+
if !ok {
82+
return 0, ErrInvalidSystemVariableValue.New(t.varName, b)
83+
}
6384

6485
if ai == bi {
6586
return 0, nil
@@ -72,14 +93,21 @@ func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
7293

7394
// Convert implements Type interface.
7495
func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
96+
if v == nil {
97+
return nil, ErrInvalidSystemVariableValue.New(t.varName, v)
98+
}
99+
75100
// Nil values are not accepted
76101
switch value := v.(type) {
77102
case int:
78103
if value >= 0 && value < len(t.indexToVal) {
79104
return t.indexToVal[value], nil
80105
}
81106
case uint:
82-
return t.Convert(int(value))
107+
if value <= math.MaxInt {
108+
return t.Convert(int(value))
109+
}
110+
return nil, ErrInvalidSystemVariableValue.New(t.varName, value)
83111
case int8:
84112
return t.Convert(int(value))
85113
case uint8:
@@ -91,9 +119,13 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
91119
case int32:
92120
return t.Convert(int(value))
93121
case uint32:
122+
// uint32 max value is less than MaxInt, so no overflow possible
94123
return t.Convert(int(value))
95124
case int64:
96-
return t.Convert(int(value))
125+
if value >= math.MinInt && value <= math.MaxInt {
126+
return t.Convert(int(value))
127+
}
128+
return nil, ErrInvalidSystemVariableValue.New(t.varName, value)
97129
case uint64:
98130
if value <= math.MaxInt {
99131
return t.Convert(int(value))
@@ -104,47 +136,76 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) {
104136
case float64:
105137
// Float values aren't truly accepted, but the engine will give them when it should give ints.
106138
// Therefore, if the float doesn't have a fractional portion, we treat it as an int.
107-
if value >= 0 && value <= float64(math.MaxInt) && value == float64(int(value)) {
139+
if value >= 0 && value <= float64(math.MaxInt) && value == math.Trunc(value) {
108140
return t.Convert(int(value))
109141
}
110142
return nil, ErrInvalidSystemVariableValue.New(t.varName, value)
111143
case decimal.Decimal:
144+
// Float64 returns (float64, bool) where the bool indicates if it was exact
145+
// We safely ignore the exactness flag as we only care about the value
112146
f, _ := value.Float64()
113147
return t.Convert(f)
114148
case decimal.NullDecimal:
115149
if value.Valid {
150+
// Float64 returns (float64, bool) where the bool indicates if it was exact
151+
// We safely ignore the exactness flag as we only care about the value
116152
f, _ := value.Decimal.Float64()
117153
return t.Convert(f)
118154
}
155+
return nil, ErrInvalidSystemVariableValue.New(t.varName, v)
119156
case string:
120157
if idx, ok := t.valToIndex[strings.ToLower(value)]; ok {
121158
return t.indexToVal[idx], nil
122159
}
160+
161+
// Check if the string represents a numeric index
162+
if parsedIndex, err := strconv.ParseInt(value, 10, 32); err == nil {
163+
if parsedIndex >= 0 && parsedIndex < int64(len(t.indexToVal)) {
164+
return t.indexToVal[parsedIndex], nil
165+
}
166+
}
123167
}
124168

125169
return nil, ErrInvalidSystemVariableValue.New(t.varName, v)
126170
}
127171

128172
// MustConvert implements the Type interface.
129173
func (t systemEnumType) MustConvert(v interface{}) interface{} {
174+
// Even though this method is named "Must", we should never panic
175+
// Return a safe default value if conversion fails
130176
value, err := t.Convert(v)
131177
if err != nil {
132-
panic(err)
178+
return t.Zero()
179+
}
180+
// Even with a nil error, Convert might return nil for invalid values
181+
if value == nil {
182+
return t.Zero()
133183
}
134184
return value
135185
}
136186

137187
// Equals implements the Type interface.
138188
func (t systemEnumType) Equals(otherType Type) bool {
139-
if ot, ok := otherType.(systemEnumType); ok && t.varName == ot.varName && len(t.indexToVal) == len(ot.indexToVal) {
140-
for i, val := range t.indexToVal {
141-
if ot.indexToVal[i] != val {
142-
return false
143-
}
189+
if otherType == nil {
190+
return false
191+
}
192+
193+
ot, ok := otherType.(systemEnumType)
194+
if !ok {
195+
return false
196+
}
197+
198+
if t.varName != ot.varName || len(t.indexToVal) != len(ot.indexToVal) {
199+
return false
200+
}
201+
202+
for i, val := range t.indexToVal {
203+
if i >= len(ot.indexToVal) || ot.indexToVal[i] != val {
204+
return false
144205
}
145-
return true
146206
}
147-
return false
207+
208+
return true
148209
}
149210

150211
// MaxTextResponseByteLength implements the Type interface
@@ -169,7 +230,18 @@ func (t systemEnumType) SQL(ctx *Context, dest []byte, v interface{}) (sqltypes.
169230
return sqltypes.Value{}, err
170231
}
171232

172-
val := appendAndSliceString(dest, v.(string))
233+
// Check if conversion returned nil
234+
if v == nil {
235+
return sqltypes.NULL, nil
236+
}
237+
238+
// Safe type assertion with validation
239+
strValue, ok := v.(string)
240+
if !ok {
241+
return sqltypes.Value{}, ErrInvalidSystemVariableValue.New(t.varName, v)
242+
}
243+
244+
val := appendAndSliceString(dest, strValue)
173245

174246
return sqltypes.MakeTrusted(t.Type(), val), nil
175247
}
@@ -196,18 +268,50 @@ func (t systemEnumType) Zero() interface{} {
196268

197269
// EncodeValue implements SystemVariableType interface.
198270
func (t systemEnumType) EncodeValue(val interface{}) (string, error) {
199-
expectedVal, ok := val.(string)
271+
if val == nil {
272+
return "", ErrSystemVariableCodeFail.New(val, t.String())
273+
}
274+
275+
// Try to convert value to ensure it's valid for this enum
276+
convertedVal, err := t.Convert(val)
277+
if err != nil {
278+
return "", err
279+
}
280+
281+
// Ensure conversion returned a valid value
282+
if convertedVal == nil {
283+
return "", ErrSystemVariableCodeFail.New(val, t.String())
284+
}
285+
286+
expectedVal, ok := convertedVal.(string)
200287
if !ok {
201288
return "", ErrSystemVariableCodeFail.New(val, t.String())
202289
}
290+
203291
return expectedVal, nil
204292
}
205293

206294
// DecodeValue implements SystemVariableType interface.
207295
func (t systemEnumType) DecodeValue(val string) (interface{}, error) {
296+
if val == "" {
297+
return nil, ErrSystemVariableCodeFail.New(val, t.String())
298+
}
299+
208300
outVal, err := t.Convert(val)
209301
if err != nil {
210302
return nil, ErrSystemVariableCodeFail.New(val, t.String())
211303
}
304+
305+
// Ensure conversion returned a valid value
306+
if outVal == nil {
307+
return nil, ErrSystemVariableCodeFail.New(val, t.String())
308+
}
309+
310+
// Validate that the returned value is a string
311+
_, ok := outVal.(string)
312+
if !ok {
313+
return nil, ErrSystemVariableCodeFail.New(val, t.String())
314+
}
315+
212316
return outVal, nil
213317
}

0 commit comments

Comments
 (0)