Skip to content

Commit 3d57ecf

Browse files
committed
feat: adding mysql
Signed-off-by: Andrew Steurer <[email protected]>
1 parent a65f13a commit 3d57ecf

File tree

2 files changed

+344
-0
lines changed

2 files changed

+344
-0
lines changed

v3/internal/db/driver.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package db
2+
3+
import "database/sql/driver"
4+
5+
// GlobalParameterConverter is a global valueConverter instance to convert parameters.
6+
var GlobalParameterConverter = &valueConverter{}
7+
8+
var _ driver.ValueConverter = (*valueConverter)(nil)
9+
10+
// valueConverter is a no-op value converter.
11+
type valueConverter struct{}
12+
13+
func (c *valueConverter) ConvertValue(v any) (driver.Value, error) {
14+
return driver.Value(v), nil
15+
}

v3/mysql/mysql.go

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
package mysql
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"errors"
8+
"io"
9+
"reflect"
10+
11+
spindb "github.com/spinframework/spin-go-sdk/v3/internal/db"
12+
"github.com/spinframework/spin-go-sdk/v3/internal/fermyon/spin/mysql"
13+
rdbmstypes "github.com/spinframework/spin-go-sdk/v3/internal/fermyon/spin/rdbms-types"
14+
_ "github.com/spinframework/spin-go-sdk/v3/internal/fermyon/spin/v2.0.0/sqlite"
15+
"go.bytecodealliance.org/cm"
16+
_ "go.bytecodealliance.org/cm"
17+
)
18+
19+
// Open returns a new connection to the database.
20+
func Open(address string) *sql.DB {
21+
return sql.OpenDB(&connector{address})
22+
}
23+
24+
// connector implements driver.Connector.
25+
type connector struct {
26+
address string
27+
}
28+
29+
// Connect returns a connection to the database.
30+
func (d *connector) Connect(_ context.Context) (driver.Conn, error) {
31+
return d.Open(d.address)
32+
}
33+
34+
// Driver returns the underlying Driver of the Connector.
35+
func (d *connector) Driver() driver.Driver {
36+
return d
37+
}
38+
39+
// Open returns a new connection to the database.
40+
func (d *connector) Open(address string) (driver.Conn, error) {
41+
return &conn{address: address}, nil
42+
}
43+
44+
// conn implements driver.Conn
45+
type conn struct {
46+
address string
47+
}
48+
49+
var _ driver.Conn = (*conn)(nil)
50+
51+
// Prepare returns a prepared statement, bound to this connection.
52+
func (c *conn) Prepare(query string) (driver.Stmt, error) {
53+
return &stmt{c: c, query: query}, nil
54+
}
55+
56+
func (c *conn) Close() error {
57+
return nil
58+
}
59+
60+
func (c *conn) Begin() (driver.Tx, error) {
61+
return nil, errors.New("transactions are unsupported by this driver")
62+
}
63+
64+
type stmt struct {
65+
c *conn
66+
query string
67+
}
68+
69+
var _ driver.Stmt = (*stmt)(nil)
70+
var _ driver.ColumnConverter = (*stmt)(nil) // TODO: remove deprecated?
71+
var _ driver.NamedValueChecker = (*stmt)(nil) // TODO: implement?
72+
73+
// Close closes the statement.
74+
func (s *stmt) Close() error {
75+
return nil
76+
}
77+
78+
// NumInput returns the number of placeholder parameters.
79+
func (s *stmt) NumInput() int {
80+
// Golang sql won't sanity check argument counts before Query.
81+
return -1
82+
}
83+
84+
// Query executes a query that may return rows, such as a SELECT.
85+
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
86+
params := make([]mysql.ParameterValue, len(args))
87+
for i := range args {
88+
params[i] = toWasiValue(args[i])
89+
}
90+
91+
results, err, isErr := mysql.Query(s.c.address, s.query, cm.ToList(params)).Result()
92+
if isErr {
93+
return nil, toError(&err)
94+
}
95+
96+
cols := results.Columns.Slice()
97+
colNames := make([]string, len(cols))
98+
colTypes := make([]uint8, len(cols))
99+
for _, col := range cols {
100+
colNames = append(colNames, col.Name)
101+
colTypes = append(colTypes, uint8(col.DataType))
102+
}
103+
104+
rowLen := int(results.Rows.Len())
105+
allRows := make([][]any, rowLen)
106+
for rowNum, row := range results.Rows.Slice() {
107+
allRows[rowNum] = toRow(row.Slice())
108+
}
109+
110+
rows := &rows{
111+
columns: colNames,
112+
columnType: colTypes,
113+
len: rowLen,
114+
rows: allRows,
115+
}
116+
117+
return rows, nil
118+
}
119+
120+
// Exec executes a query that doesn't return rows, such as an INSERT or
121+
// UPDATE.
122+
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
123+
params := make([]mysql.ParameterValue, len(args))
124+
for i := range args {
125+
params[i] = toWasiValue(args[i])
126+
}
127+
_, err, isErr := mysql.Execute(s.c.address, s.query, cm.ToList(params)).Result()
128+
if isErr {
129+
return nil, toError(&err)
130+
}
131+
132+
return &result{}, nil
133+
}
134+
135+
// ColumnConverter returns GlobalParameterConverter to prevent using driver.DefaultParameterConverter.
136+
func (s *stmt) ColumnConverter(_ int) driver.ValueConverter {
137+
return spindb.GlobalParameterConverter
138+
}
139+
140+
func (s *stmt) NamedValueChecker(v *driver.NamedValue) error {
141+
// TODO: implement?
142+
}
143+
144+
type result struct{}
145+
146+
func (r result) LastInsertId() (int64, error) {
147+
return -1, errors.New("LastInsertId is unsupported by this driver")
148+
}
149+
150+
func (r result) RowsAffected() (int64, error) {
151+
return -1, errors.New("RowsAffected is unsupported by this driver")
152+
}
153+
154+
type rows struct {
155+
columns []string
156+
columnType []uint8
157+
pos int
158+
len int
159+
rows [][]any
160+
closed bool
161+
}
162+
163+
var _ driver.Rows = (*rows)(nil)
164+
var _ driver.RowsColumnTypeScanType = (*rows)(nil)
165+
var _ driver.RowsNextResultSet = (*rows)(nil)
166+
167+
// Columns return column names.
168+
func (r *rows) Columns() []string {
169+
return r.columns
170+
}
171+
172+
// Close closes the rows iterator.
173+
func (r *rows) Close() error {
174+
r.rows = nil
175+
r.pos = 0
176+
r.len = 0
177+
r.closed = true
178+
return nil
179+
}
180+
181+
// Next moves the cursor to the next row.
182+
func (r *rows) Next(dest []driver.Value) error {
183+
if !r.HasNextResultSet() {
184+
return io.EOF
185+
}
186+
for i := 0; i != len(r.columns); i++ {
187+
dest[i] = driver.Value(r.rows[r.pos][i])
188+
}
189+
r.pos++
190+
return nil
191+
}
192+
193+
// HasNextResultSet is called at the end of the current result set and
194+
// reports whether there is another result set after the current one.
195+
func (r *rows) HasNextResultSet() bool {
196+
return r.pos < r.len
197+
}
198+
199+
// NextResultSet advances the driver to the next result set even
200+
// if there are remaining rows in the current result set.
201+
//
202+
// NextResultSet should return io.EOF when there are no more result sets.
203+
func (r *rows) NextResultSet() error {
204+
if r.HasNextResultSet() {
205+
r.pos++
206+
return nil
207+
}
208+
return io.EOF // Per interface spec.
209+
}
210+
211+
// ColumnTypeScanType return the value type that can be used to scan types into.
212+
func (r *rows) ColumnTypeScanType(index int) reflect.Type {
213+
return colTypeToReflectType(r.columnType[index])
214+
}
215+
216+
func colTypeToReflectType(typ uint8) reflect.Type {
217+
switch rdbmstypes.DbDataType(typ) {
218+
case rdbmstypes.DbDataTypeBoolean:
219+
return reflect.TypeOf(false)
220+
case rdbmstypes.DbDataTypeInt8:
221+
return reflect.TypeOf(int8(0))
222+
case rdbmstypes.DbDataTypeInt16:
223+
return reflect.TypeOf(int16(0))
224+
case rdbmstypes.DbDataTypeInt32:
225+
return reflect.TypeOf(int32(0))
226+
case rdbmstypes.DbDataTypeInt64:
227+
return reflect.TypeOf(int64(0))
228+
case rdbmstypes.DbDataTypeUint8:
229+
return reflect.TypeOf(uint8(0))
230+
case rdbmstypes.DbDataTypeUint16:
231+
return reflect.TypeOf(uint16(0))
232+
case rdbmstypes.DbDataTypeUint32:
233+
return reflect.TypeOf(uint32(0))
234+
case rdbmstypes.DbDataTypeUint64:
235+
return reflect.TypeOf(uint64(0))
236+
case rdbmstypes.DbDataTypeStr:
237+
return reflect.TypeOf("")
238+
case rdbmstypes.DbDataTypeBinary:
239+
return reflect.TypeOf(new([]byte))
240+
case rdbmstypes.DbDataTypeOther:
241+
return reflect.TypeOf(new(any)).Elem()
242+
}
243+
panic("invalid db column type of " + string(typ))
244+
}
245+
246+
func toWasiValue(x any) mysql.ParameterValue {
247+
switch v := x.(type) {
248+
case bool:
249+
return rdbmstypes.ParameterValueBoolean(v)
250+
case int8:
251+
return rdbmstypes.ParameterValueInt8(v)
252+
case int16:
253+
return rdbmstypes.ParameterValueInt16(v)
254+
case int32:
255+
return rdbmstypes.ParameterValueInt32(v)
256+
case int64:
257+
return rdbmstypes.ParameterValueInt64(v)
258+
case int:
259+
return rdbmstypes.ParameterValueInt64(int64(v))
260+
case uint8:
261+
return rdbmstypes.ParameterValueUint8(v)
262+
case uint16:
263+
return rdbmstypes.ParameterValueUint16(v)
264+
case uint32:
265+
return rdbmstypes.ParameterValueUint32(v)
266+
case uint64:
267+
return rdbmstypes.ParameterValueUint64(v)
268+
case float32:
269+
return rdbmstypes.ParameterValueFloating32(v)
270+
case float64:
271+
return rdbmstypes.ParameterValueFloating64(v)
272+
case string:
273+
return rdbmstypes.ParameterValueStr(v)
274+
case []byte:
275+
return rdbmstypes.ParameterValueBinary(cm.ToList([]uint8(v)))
276+
case nil:
277+
return rdbmstypes.ParameterValueDbNull()
278+
default:
279+
panic("unknown value type")
280+
}
281+
}
282+
283+
func toError(err *mysql.MysqlError) error {
284+
if err == nil {
285+
return nil
286+
}
287+
288+
return errors.New(err.String())
289+
}
290+
291+
func toRow(row []rdbmstypes.DbValue) []any {
292+
result := make([]any, len(row))
293+
for i, v := range row {
294+
switch v.String() {
295+
case "boolean":
296+
result[i] = *v.Boolean()
297+
case "int8":
298+
result[i] = *v.Int8()
299+
case "int16":
300+
result[i] = *v.Int16()
301+
case "int32":
302+
result[i] = *v.Int32()
303+
case "int64":
304+
result[i] = *v.Int64()
305+
case "uint8":
306+
result[i] = *v.Uint8()
307+
case "uint16":
308+
result[i] = *v.Uint16()
309+
case "uint32":
310+
result[i] = *v.Uint32()
311+
case "uint64":
312+
result[i] = *v.Uint64()
313+
case "floating32":
314+
result[i] = *v.Floating32()
315+
case "floating64":
316+
result[i] = *v.Floating64()
317+
case "str":
318+
result[i] = *v.Str()
319+
case "binary":
320+
result[i] = *v.Binary()
321+
case "db-null":
322+
result[i] = nil
323+
default:
324+
panic("unknown value type")
325+
}
326+
}
327+
328+
return result
329+
}

0 commit comments

Comments
 (0)