diff --git a/internal/gen.go b/internal/gen.go index 6e50fae..e51c626 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -193,8 +193,10 @@ func pyInnerType(req *plugin.GenerateRequest, col *plugin.Column) string { switch req.Settings.Engine { case "postgresql": return postgresType(req, col) + case "mysql": + return mysqlType(req, col) default: - log.Println("unsupported engine type") + log.Printf("unsupported engine type: %s\n", req.Settings.Engine) return "Any" } } @@ -360,6 +362,14 @@ func sqlalchemySQL(s, engine string) string { if engine == "postgresql" { return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1") } + if engine == "mysql" { + // Convert MySQL ? placeholders to named parameters for SQLAlchemy compatibility + i := 1 + for strings.Contains(s, "?") { + s = strings.Replace(s, "?", fmt.Sprintf(":p%d", i), 1) + i++ + } + } return s } diff --git a/internal/mysql_type.go b/internal/mysql_type.go new file mode 100644 index 0000000..0b2ff20 --- /dev/null +++ b/internal/mysql_type.go @@ -0,0 +1,53 @@ +package python + +import ( + "log" + + "github.com/sqlc-dev/plugin-sdk-go/plugin" + "github.com/sqlc-dev/plugin-sdk-go/sdk" +) + +func mysqlType(req *plugin.GenerateRequest, col *plugin.Column) string { + columnType := sdk.DataType(col.Type) + + switch columnType { + case "tinyint", "smallint", "mediumint", "int", "integer", "bigint": + return "int" + case "float", "double", "real": + return "float" + case "decimal", "numeric": + return "decimal.Decimal" + case "bit", "boolean", "bool": + return "bool" + case "json": + return "Any" + case "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob": + return "memoryview" + case "date": + return "datetime.date" + case "time": + return "datetime.time" + case "datetime", "timestamp": + return "datetime.datetime" + case "char", "varchar", "text", "tinytext", "mediumtext", "longtext": + return "str" + case "enum", "set": + return "str" + default: + for _, schema := range req.Catalog.Schemas { + if schema.Name == "information_schema" || schema.Name == "mysql" { + continue + } + for _, enum := range schema.Enums { + if columnType == enum.Name { + if schema.Name == req.Catalog.DefaultSchema { + return "models." + modelName(enum.Name, req.Settings) + } + return "models." + modelName(schema.Name+"_"+enum.Name, req.Settings) + } + } + } + log.Printf("unknown MySQL type: %s\n", columnType) + return "Any" + } +}