99 "sort"
1010 "strings"
1111
12+ easyjson "github.com/mailru/easyjson"
1213 plugin "github.com/tabbed/sqlc-go/codegen"
1314 "github.com/tabbed/sqlc-go/metadata"
1415 "github.com/tabbed/sqlc-go/sdk"
@@ -379,7 +380,7 @@ func sqlalchemySQL(s, engine string) string {
379380 return s
380381}
381382
382- func buildQueries (req * plugin.CodeGenRequest , structs []Struct ) ([]Query , error ) {
383+ func buildQueries (conf Config , req * plugin.CodeGenRequest , structs []Struct ) ([]Query , error ) {
383384 qs := make ([]Query , 0 , len (req .Queries ))
384385 for _ , query := range req .Queries {
385386 if query .Name == "" {
@@ -405,8 +406,8 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
405406 }
406407
407408 qpl := 4
408- if req . Settings . Python .QueryParameterLimit != nil {
409- qpl = int (* req . Settings . Python .QueryParameterLimit )
409+ if conf .QueryParameterLimit != nil {
410+ qpl = int (* conf .QueryParameterLimit )
410411 }
411412 if len (query .Params ) > qpl || qpl == 0 {
412413 var cols []pyColumn
@@ -722,7 +723,7 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
722723
723724 for _ , m := range ctx .Models {
724725 var def * pyast.ClassDef
725- if ctx .EmitPydanticModels {
726+ if ctx .C . EmitPydanticModels {
726727 def = pydanticNode (m .Name )
727728 } else {
728729 def = dataclassNode (m .Name )
@@ -836,7 +837,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
836837 {
837838 Node : & pyast.Node_ImportFrom {
838839 ImportFrom : & pyast.ImportFrom {
839- Module : i . Settings . Python .Package ,
840+ Module : ctx . C .Package ,
840841 Names : []* pyast.Node {
841842 poet .Alias ("models" ),
842843 },
@@ -857,7 +858,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
857858 for _ , arg := range q .Args {
858859 if arg .EmitStruct () {
859860 var def * pyast.ClassDef
860- if ctx .EmitPydanticModels {
861+ if ctx .C . EmitPydanticModels {
861862 def = pydanticNode (arg .Struct .Name )
862863 } else {
863864 def = dataclassNode (arg .Struct .Name )
@@ -870,7 +871,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
870871 }
871872 if q .Ret .EmitStruct () {
872873 var def * pyast.ClassDef
873- if ctx .EmitPydanticModels {
874+ if ctx .C . EmitPydanticModels {
874875 def = pydanticNode (q .Ret .Struct .Name )
875876 } else {
876877 def = dataclassNode (q .Ret .Struct .Name )
@@ -882,7 +883,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
882883 }
883884 }
884885
885- if ctx .EmitSync {
886+ if ctx .C . EmitSyncQuerier {
886887 cls := querierClassDef ()
887888 for _ , q := range ctx .Queries {
888889 if ! ctx .OutputQuery (q .SourceName ) {
@@ -974,7 +975,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
974975 mod .Body = append (mod .Body , poet .Node (cls ))
975976 }
976977
977- if ctx .EmitAsync {
978+ if ctx .C . EmitAsyncQuerier {
978979 cls := asyncQuerierClassDef ()
979980 for _ , q := range ctx .Queries {
980981 if ! ctx .OutputQuery (q .SourceName ) {
@@ -1071,14 +1072,12 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
10711072}
10721073
10731074type pyTmplCtx struct {
1074- Models []Struct
1075- Queries []Query
1076- Enums []Enum
1077- EmitSync bool
1078- EmitAsync bool
1079- SourceName string
1080- SqlcVersion string
1081- EmitPydanticModels bool
1075+ SqlcVersion string
1076+ Models []Struct
1077+ Queries []Query
1078+ Enums []Enum
1079+ SourceName string
1080+ C Config
10821081}
10831082
10841083func (t * pyTmplCtx ) OutputQuery (sourceName string ) bool {
@@ -1090,9 +1089,16 @@ func HashComment(s string) string {
10901089}
10911090
10921091func Generate (_ context.Context , req * plugin.Request ) (* plugin.Response , error ) {
1092+ var conf Config
1093+ if len (req .PluginOptions ) > 0 {
1094+ if err := easyjson .Unmarshal (req .PluginOptions , & conf ); err != nil {
1095+ return nil , err
1096+ }
1097+ }
1098+
10931099 enums := buildEnums (req )
10941100 models := buildModels (req )
1095- queries , err := buildQueries (req , models )
1101+ queries , err := buildQueries (conf , req , models )
10961102 if err != nil {
10971103 return nil , err
10981104 }
@@ -1102,16 +1108,15 @@ func Generate(_ context.Context, req *plugin.Request) (*plugin.Response, error)
11021108 Models : models ,
11031109 Queries : queries ,
11041110 Enums : enums ,
1111+ C : conf ,
11051112 }
11061113
11071114 tctx := pyTmplCtx {
1108- Models : models ,
1109- Queries : queries ,
1110- Enums : enums ,
1111- EmitSync : req .Settings .Python .EmitSyncQuerier ,
1112- EmitAsync : req .Settings .Python .EmitAsyncQuerier ,
1113- SqlcVersion : req .SqlcVersion ,
1114- EmitPydanticModels : req .Settings .Python .EmitPydanticModels ,
1115+ Models : models ,
1116+ Queries : queries ,
1117+ Enums : enums ,
1118+ SqlcVersion : req .SqlcVersion ,
1119+ C : conf ,
11151120 }
11161121
11171122 output := map [string ]string {}
0 commit comments