99 "go/build"
1010 "go/types"
1111 "os"
12+
1213 "path/filepath"
1314 "strings"
1415
@@ -19,6 +20,27 @@ import (
1920 "golang.org/x/tools/go/ssa/ssautil"
2021)
2122
23+ type sqlPackage struct {
24+ packageName string
25+ paramNames []string
26+ enable bool
27+ }
28+
29+ var sqlPackages = []sqlPackage {
30+ {
31+ packageName : "database/sql" ,
32+ paramNames : []string {"query" },
33+ },
34+ {
35+ packageName : "github.com/jinzhu/gorm" ,
36+ paramNames : []string {"sql" , "query" },
37+ },
38+ {
39+ packageName : "github.com/jmoiron/sqlx" ,
40+ paramNames : []string {"query" },
41+ },
42+ }
43+
2244func main () {
2345 var verbose , quiet bool
2446 flag .BoolVar (& verbose , "v" , false , "Verbose mode" )
@@ -38,21 +60,45 @@ func main() {
3860 c := loader.Config {
3961 FindPackage : FindPackage ,
4062 }
41- c .Import ("database/sql" )
4263 for _ , pkg := range pkgs {
4364 c .Import (pkg )
4465 }
4566 p , err := c .Load ()
67+
4668 if err != nil {
4769 fmt .Printf ("error loading packages %v: %v\n " , pkgs , err )
4870 os .Exit (2 )
4971 }
72+
73+ imports := getImports (p )
74+ existOne := false
75+ for i := range sqlPackages {
76+ if _ , exist := imports [sqlPackages [i ].packageName ]; exist {
77+ if verbose {
78+ fmt .Printf ("Enabling support for %s\n " , sqlPackages [i ].packageName )
79+ }
80+ sqlPackages [i ].enable = true
81+ existOne = true
82+ }
83+ }
84+ if ! existOne {
85+ fmt .Printf ("No packages in %v include a supported database driver" , pkgs )
86+ os .Exit (2 )
87+ }
88+
5089 s := ssautil .CreateProgram (p , 0 )
5190 s .Build ()
5291
53- qms := FindQueryMethods (p .Package ("database/sql" ).Pkg , s )
92+ qms := make ([]* QueryMethod , 0 )
93+
94+ for i := range sqlPackages {
95+ if sqlPackages [i ].enable {
96+ qms = append (qms , FindQueryMethods (sqlPackages [i ], p .Package (sqlPackages [i ].packageName ).Pkg , s )... )
97+ }
98+ }
99+
54100 if verbose {
55- fmt .Println ("database/sql functions that accept queries:" )
101+ fmt .Println ("database driver functions that accept queries:" )
56102 for _ , m := range qms {
57103 fmt .Printf ("- %s (param %d)\n " , m .Func , m .Param )
58104 }
@@ -75,21 +121,27 @@ func main() {
75121 }
76122
77123 bad := FindNonConstCalls (res .CallGraph , qms )
124+
78125 if len (bad ) == 0 {
79126 if ! quiet {
80127 fmt .Println (`You're safe from SQL injection! Yay \o/` )
81128 }
82129 return
83130 }
84131
85- fmt .Printf ("Found %d potentially unsafe SQL statements:\n " , len (bad ))
132+ if verbose {
133+ fmt .Printf ("Found %d potentially unsafe SQL statements:\n " , len (bad ))
134+ }
135+
86136 for _ , ci := range bad {
87137 pos := p .Fset .Position (ci .Pos ())
88138 fmt .Printf ("- %s\n " , pos )
89139 }
90- fmt .Println ("Please ensure that all SQL queries you use are compile-time constants." )
91- fmt .Println ("You should always use parameterized queries or prepared statements" )
92- fmt .Println ("instead of building queries from strings." )
140+ if verbose {
141+ fmt .Println ("Please ensure that all SQL queries you use are compile-time constants." )
142+ fmt .Println ("You should always use parameterized queries or prepared statements" )
143+ fmt .Println ("instead of building queries from strings." )
144+ }
93145 os .Exit (1 )
94146}
95147
@@ -104,7 +156,7 @@ type QueryMethod struct {
104156
105157// FindQueryMethods locates all methods in the given package (assumed to be
106158// package database/sql) with a string parameter named "query".
107- func FindQueryMethods (sql * types.Package , ssa * ssa.Program ) []* QueryMethod {
159+ func FindQueryMethods (sqlPackages sqlPackage , sql * types.Package , ssa * ssa.Program ) []* QueryMethod {
108160 methods := make ([]* QueryMethod , 0 )
109161 scope := sql .Scope ()
110162 for _ , name := range scope .Names () {
@@ -122,7 +174,7 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
122174 continue
123175 }
124176 s := m .Type ().(* types.Signature )
125- if num , ok := FuncHasQuery (s ); ok {
177+ if num , ok := FuncHasQuery (sqlPackages , s ); ok {
126178 methods = append (methods , & QueryMethod {
127179 Func : m ,
128180 SSA : ssa .FuncValue (m ),
@@ -135,16 +187,16 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
135187 return methods
136188}
137189
138- var stringType types.Type = types .Typ [types .String ]
139-
140190// FuncHasQuery returns the offset of the string parameter named "query", or
141191// none if no such parameter exists.
142- func FuncHasQuery (s * types.Signature ) (offset int , ok bool ) {
192+ func FuncHasQuery (sqlPackages sqlPackage , s * types.Signature ) (offset int , ok bool ) {
143193 params := s .Params ()
144194 for i := 0 ; i < params .Len (); i ++ {
145195 v := params .At (i )
146- if v .Name () == "query" && v .Type () == stringType {
147- return i , true
196+ for _ , paramName := range sqlPackages .paramNames {
197+ if v .Name () == paramName {
198+ return i , true
199+ }
148200 }
149201 }
150202 return 0 , false
@@ -164,6 +216,16 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
164216 return mains
165217}
166218
219+ func getImports (p * loader.Program ) map [string ]interface {} {
220+ pkgs := make (map [string ]interface {})
221+ for _ , pkg := range p .AllPackages {
222+ if pkg .Importable {
223+ pkgs [pkg .Pkg .Path ()] = nil
224+ }
225+ }
226+ return pkgs
227+ }
228+
167229// FindNonConstCalls returns the set of callsites of the given set of methods
168230// for which the "query" parameter is not a compile-time constant.
169231func FindNonConstCalls (cg * callgraph.Graph , qms []* QueryMethod ) []ssa.CallInstruction {
@@ -186,6 +248,18 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
186248 if _ , ok := okFuncs [edge .Site .Parent ()]; ok {
187249 continue
188250 }
251+
252+ isInternalSQLPkg := false
253+ for _ , pkg := range sqlPackages {
254+ if pkg .packageName == edge .Caller .Func .Pkg .Pkg .Path () {
255+ isInternalSQLPkg = true
256+ break
257+ }
258+ }
259+ if isInternalSQLPkg {
260+ continue
261+ }
262+
189263 cc := edge .Site .Common ()
190264 args := cc .Args
191265 // The first parameter is occasionally the receiver.
@@ -195,7 +269,14 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
195269 panic ("arg count mismatch" )
196270 }
197271 v := args [m .Param ]
272+
198273 if _ , ok := v .(* ssa.Const ); ! ok {
274+ if inter , ok := v .(* ssa.MakeInterface ); ok && types .IsInterface (v .(* ssa.MakeInterface ).Type ()) {
275+ if inter .X .Referrers () == nil || inter .X .Type () != types .Typ [types .String ] {
276+ continue
277+ }
278+ }
279+
199280 bad = append (bad , edge .Site )
200281 }
201282 }
0 commit comments