Skip to content
This repository was archived by the owner on Sep 21, 2021. It is now read-only.

Commit cddf355

Browse files
Merge pull request #10 from gettaxi/master
Add gorm and sqlx support, make easy add new other ORM
2 parents 452e37e + a7e6848 commit cddf355

File tree

2 files changed

+99
-19
lines changed

2 files changed

+99
-19
lines changed

README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ How does it work?
3131
-----------------
3232

3333
SafeSQL uses the static analysis utilities in [go/tools][tools] to search for
34-
all call sites of each of the `query` functions in package [database/sql][sql]
35-
(i.e., functions which accept a `string` parameter named `query`). It then makes
34+
all call sites of each of the `query` functions in packages ([database/sql][sql],[github.com/jinzhu/gorm][gorm],[github.com/jmoiron/sqlx][sqlx])
35+
(i.e., functions which accept a parameter named `query`,`sql`). It then makes
3636
sure that every such call site uses a query that is a compile-time constant.
3737

3838
The principle behind SafeSQL's safety guarantees is that queries that are
@@ -44,6 +44,8 @@ will not be allowed.
4444

4545
[tools]: https://godoc.org/golang.org/x/tools/go
4646
[sql]: http://golang.org/pkg/database/sql/
47+
[sqlx]: https://github.com/jmoiron/sqlx
48+
[gorm]: https://github.com/jinzhu/gorm
4749

4850
False positives
4951
---------------
@@ -66,8 +68,6 @@ a fundamental limitation: SafeSQL could recursively trace the `query` argument
6668
through every intervening helper function to ensure that its argument is always
6769
constant, but this code has yet to be written.
6870

69-
If you use a wrapper for `database/sql` (e.g., [`sqlx`][sqlx]), it's likely
70-
SafeSQL will not work for you because of this.
7171

7272
The second sort of false positive is based on a limitation in the sort of
7373
analysis SafeSQL performs: there are many safe SQL statements which are not
@@ -76,4 +76,3 @@ static analysis techniques (such as taint analysis) or user-provided safety
7676
annotations would be able to reduce the number of false positives, but this is
7777
expected to be a significant undertaking.
7878

79-
[sqlx]: https://github.com/jmoiron/sqlx

safesql.go

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"go/build"
1010
"go/types"
1111
"os"
12+
1213
"path/filepath"
1314
"strings"
1415

@@ -19,6 +20,27 @@ import (
1920
"golang.org/x/tools/go/ssa/ssautil"
2021
)
2122

23+
type sqlPackage struct {
24+
packageName string
25+
paramNames []string
26+
enable bool
27+
}
28+
29+
var sqlPackages = []sqlPackage{
30+
{
31+
packageName: "database/sql",
32+
paramNames: []string{"query"},
33+
},
34+
{
35+
packageName: "github.com/jinzhu/gorm",
36+
paramNames: []string{"sql", "query"},
37+
},
38+
{
39+
packageName: "github.com/jmoiron/sqlx",
40+
paramNames: []string{"query"},
41+
},
42+
}
43+
2244
func main() {
2345
var verbose, quiet bool
2446
flag.BoolVar(&verbose, "v", false, "Verbose mode")
@@ -38,21 +60,45 @@ func main() {
3860
c := loader.Config{
3961
FindPackage: FindPackage,
4062
}
41-
c.Import("database/sql")
4263
for _, pkg := range pkgs {
4364
c.Import(pkg)
4465
}
4566
p, err := c.Load()
67+
4668
if err != nil {
4769
fmt.Printf("error loading packages %v: %v\n", pkgs, err)
4870
os.Exit(2)
4971
}
72+
73+
imports := getImports(p)
74+
existOne := false
75+
for i := range sqlPackages {
76+
if _, exist := imports[sqlPackages[i].packageName]; exist {
77+
if verbose {
78+
fmt.Printf("Enabling support for %s\n", sqlPackages[i].packageName)
79+
}
80+
sqlPackages[i].enable = true
81+
existOne = true
82+
}
83+
}
84+
if !existOne {
85+
fmt.Printf("No packages in %v include a supported database driver", pkgs)
86+
os.Exit(2)
87+
}
88+
5089
s := ssautil.CreateProgram(p, 0)
5190
s.Build()
5291

53-
qms := FindQueryMethods(p.Package("database/sql").Pkg, s)
92+
qms := make([]*QueryMethod, 0)
93+
94+
for i := range sqlPackages {
95+
if sqlPackages[i].enable {
96+
qms = append(qms, FindQueryMethods(sqlPackages[i], p.Package(sqlPackages[i].packageName).Pkg, s)...)
97+
}
98+
}
99+
54100
if verbose {
55-
fmt.Println("database/sql functions that accept queries:")
101+
fmt.Println("database driver functions that accept queries:")
56102
for _, m := range qms {
57103
fmt.Printf("- %s (param %d)\n", m.Func, m.Param)
58104
}
@@ -75,21 +121,27 @@ func main() {
75121
}
76122

77123
bad := FindNonConstCalls(res.CallGraph, qms)
124+
78125
if len(bad) == 0 {
79126
if !quiet {
80127
fmt.Println(`You're safe from SQL injection! Yay \o/`)
81128
}
82129
return
83130
}
84131

85-
fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad))
132+
if verbose {
133+
fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad))
134+
}
135+
86136
for _, ci := range bad {
87137
pos := p.Fset.Position(ci.Pos())
88138
fmt.Printf("- %s\n", pos)
89139
}
90-
fmt.Println("Please ensure that all SQL queries you use are compile-time constants.")
91-
fmt.Println("You should always use parameterized queries or prepared statements")
92-
fmt.Println("instead of building queries from strings.")
140+
if verbose {
141+
fmt.Println("Please ensure that all SQL queries you use are compile-time constants.")
142+
fmt.Println("You should always use parameterized queries or prepared statements")
143+
fmt.Println("instead of building queries from strings.")
144+
}
93145
os.Exit(1)
94146
}
95147

@@ -104,7 +156,7 @@ type QueryMethod struct {
104156

105157
// FindQueryMethods locates all methods in the given package (assumed to be
106158
// package database/sql) with a string parameter named "query".
107-
func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
159+
func FindQueryMethods(sqlPackages sqlPackage, sql *types.Package, ssa *ssa.Program) []*QueryMethod {
108160
methods := make([]*QueryMethod, 0)
109161
scope := sql.Scope()
110162
for _, name := range scope.Names() {
@@ -122,7 +174,7 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
122174
continue
123175
}
124176
s := m.Type().(*types.Signature)
125-
if num, ok := FuncHasQuery(s); ok {
177+
if num, ok := FuncHasQuery(sqlPackages, s); ok {
126178
methods = append(methods, &QueryMethod{
127179
Func: m,
128180
SSA: ssa.FuncValue(m),
@@ -135,16 +187,16 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
135187
return methods
136188
}
137189

138-
var stringType types.Type = types.Typ[types.String]
139-
140190
// FuncHasQuery returns the offset of the string parameter named "query", or
141191
// none if no such parameter exists.
142-
func FuncHasQuery(s *types.Signature) (offset int, ok bool) {
192+
func FuncHasQuery(sqlPackages sqlPackage, s *types.Signature) (offset int, ok bool) {
143193
params := s.Params()
144194
for i := 0; i < params.Len(); i++ {
145195
v := params.At(i)
146-
if v.Name() == "query" && v.Type() == stringType {
147-
return i, true
196+
for _, paramName := range sqlPackages.paramNames {
197+
if v.Name() == paramName {
198+
return i, true
199+
}
148200
}
149201
}
150202
return 0, false
@@ -164,6 +216,16 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
164216
return mains
165217
}
166218

219+
func getImports(p *loader.Program) map[string]interface{} {
220+
pkgs := make(map[string]interface{})
221+
for _, pkg := range p.AllPackages {
222+
if pkg.Importable {
223+
pkgs[pkg.Pkg.Path()] = nil
224+
}
225+
}
226+
return pkgs
227+
}
228+
167229
// FindNonConstCalls returns the set of callsites of the given set of methods
168230
// for which the "query" parameter is not a compile-time constant.
169231
func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction {
@@ -186,6 +248,18 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
186248
if _, ok := okFuncs[edge.Site.Parent()]; ok {
187249
continue
188250
}
251+
252+
isInternalSQLPkg := false
253+
for _, pkg := range sqlPackages {
254+
if pkg.packageName == edge.Caller.Func.Pkg.Pkg.Path() {
255+
isInternalSQLPkg = true
256+
break
257+
}
258+
}
259+
if isInternalSQLPkg {
260+
continue
261+
}
262+
189263
cc := edge.Site.Common()
190264
args := cc.Args
191265
// The first parameter is occasionally the receiver.
@@ -195,7 +269,14 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
195269
panic("arg count mismatch")
196270
}
197271
v := args[m.Param]
272+
198273
if _, ok := v.(*ssa.Const); !ok {
274+
if inter, ok := v.(*ssa.MakeInterface); ok && types.IsInterface(v.(*ssa.MakeInterface).Type()) {
275+
if inter.X.Referrers() == nil || inter.X.Type() != types.Typ[types.String] {
276+
continue
277+
}
278+
}
279+
199280
bad = append(bad, edge.Site)
200281
}
201282
}

0 commit comments

Comments
 (0)