11package benchmark
22
33import (
4+ "context"
45 "database/sql"
56 "fmt"
67 "math/rand"
78 "strings"
89
910 _ "github.com/go-sql-driver/mysql"
11+ _ "github.com/lib/pq"
12+ _ "github.com/samitani/go-sysbench/driver"
1013)
1114
1215const (
@@ -63,19 +66,43 @@ type (
6366 MySQLDB string `long:"mysql-db" description:"MySQL database name" default:"sbtest"`
6467 }
6568
66- MySQLOLTP struct {
69+ PgSQLOpts struct {
70+ PgSQLHost string `long:"pgsql-host" description:"PostgreSQL server host" default:"localhost"`
71+ PgSQLPort int `long:"pgsql-port" description:"PostgreSQL server port" default:"5432"`
72+ PgSQLUser string `long:"pgsql-user" description:"PostgreSQL user" default:"sbtest"`
73+ PgSQLPassword string `long:"pgsql-password" env:"PGPASSWORD" description:"PostgreSQL password" default:""`
74+ PgSQLDB string `long:"pgsql-db" description:"PostgreSQL database name" default:"sbtest"`
75+ }
76+
77+ OLTPBench struct {
6778 opts * BenchmarkOpts
6879
6980 db * sql.DB
7081 }
7182)
7283
73- func newMySQLOLTP (option * BenchmarkOpts ) * MySQLOLTP {
74- return & MySQLOLTP {opts : option }
84+ func newOLTPBench (option * BenchmarkOpts ) * OLTPBench {
85+ return & OLTPBench {opts : option }
7586}
7687
77- func (o * MySQLOLTP ) Init () error {
78- db , err := sql .Open ("mysql" , o .dsn ())
88+ func (o * OLTPBench ) Init (ctx context.Context ) error {
89+ var drvName string
90+ var dsn string
91+
92+ if o .opts .DBDriver == DBDriverMySQL {
93+ drvName = "mysql"
94+ dsn = o .dsnMySQL ()
95+ } else if o .opts .DBDriver == DBDriverPgSQL {
96+ drvName = "postgres"
97+ dsn = o .dsnPgSQL ()
98+ }
99+
100+ db , err := sql .Open (drvName , dsn )
101+ if err != nil {
102+ return err
103+ }
104+
105+ err = db .Ping ()
79106 if err != nil {
80107 return err
81108 }
@@ -85,15 +112,15 @@ func (o *MySQLOLTP) Init() error {
85112 return nil
86113}
87114
88- func (o * MySQLOLTP ) Prepare () error {
115+ func (o * OLTPBench ) Prepare (ctx context. Context ) error {
89116 err := o .createTable ()
90117 if err != nil {
91118 return err
92119 }
93120 return nil
94121}
95122
96- func (o * MySQLOLTP ) Event () (reads uint64 , writes uint64 , others uint64 , errors uint64 , e error ) {
123+ func (o * OLTPBench ) Event (ctx context. Context ) (reads uint64 , writes uint64 , others uint64 , errors uint64 , e error ) {
97124 var numReads , numWrites , numOthers uint64
98125 var tableNum = o .getRandTableNum ()
99126 var numRowReturn = 0
@@ -106,7 +133,7 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
106133 numOthers += 1
107134
108135 for i := 0 ; i < numPointSelects ; i ++ {
109- rows , err := tx .Query ( fmt .Sprintf (stmtPointSelects , tableNum , sbRand (0 , o .opts .TableSize )))
136+ rows , err := tx .QueryContext ( ctx , fmt .Sprintf (stmtPointSelects , tableNum , sbRand (0 , o .opts .TableSize )))
110137 if err != nil {
111138 tx .Rollback ()
112139 return numReads , numWrites , numOthers , 1 , err
@@ -119,7 +146,7 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
119146
120147 for i := 0 ; i < numSimpleRanges ; i ++ {
121148 begin := sbRand (0 , o .opts .TableSize )
122- rows , err := tx .Query ( fmt .Sprintf (stmtSimpleRanges , tableNum , begin , begin + rangeSize - 1 ))
149+ rows , err := tx .QueryContext ( ctx , fmt .Sprintf (stmtSimpleRanges , tableNum , begin , begin + rangeSize - 1 ))
123150 if err != nil {
124151 tx .Rollback ()
125152 return numReads , numWrites , numOthers , 1 , err
@@ -132,7 +159,7 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
132159
133160 for i := 0 ; i < numSumRanges ; i ++ {
134161 begin := sbRand (0 , o .opts .TableSize )
135- rows , err := tx .Query ( fmt .Sprintf (stmtSumRanges , tableNum , begin , begin + rangeSize - 1 ))
162+ rows , err := tx .QueryContext ( ctx , fmt .Sprintf (stmtSumRanges , tableNum , begin , begin + rangeSize - 1 ))
136163 if err != nil {
137164 tx .Rollback ()
138165 return numReads , numWrites , numOthers , 1 , err
@@ -145,7 +172,7 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
145172
146173 for i := 0 ; i < numOrderRanges ; i ++ {
147174 begin := sbRand (0 , o .opts .TableSize )
148- rows , err := tx .Query ( fmt .Sprintf (stmtOrderRanges , tableNum , begin , begin + rangeSize - 1 ))
175+ rows , err := tx .QueryContext ( ctx , fmt .Sprintf (stmtOrderRanges , tableNum , begin , begin + rangeSize - 1 ))
149176 if err != nil {
150177 tx .Rollback ()
151178 return numReads , numWrites , numOthers , 1 , err
@@ -158,7 +185,7 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
158185
159186 for i := 0 ; i < numDistinctRanges ; i ++ {
160187 begin := sbRand (0 , o .opts .TableSize )
161- rows , err := tx .Query ( fmt .Sprintf (stmtDistinctRanges , tableNum , begin , begin + rangeSize - 1 ))
188+ rows , err := tx .QueryContext ( ctx , fmt .Sprintf (stmtDistinctRanges , tableNum , begin , begin + rangeSize - 1 ))
162189 if err != nil {
163190 tx .Rollback ()
164191 return numReads , numWrites , numOthers , 1 , err
@@ -171,15 +198,15 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
171198
172199 if o .opts .ReadWrite {
173200 for i := 0 ; i < numIndexUpdates ; i ++ {
174- _ , err := tx .Exec ( fmt .Sprintf (stmtIndexUpdates , tableNum , sbRand (0 , o .opts .TableSize )))
201+ _ , err := tx .ExecContext ( ctx , fmt .Sprintf (stmtIndexUpdates , tableNum , sbRand (0 , o .opts .TableSize )))
175202 if err != nil {
176203 tx .Rollback ()
177204 return numReads , numWrites , numOthers , 1 , err
178205 }
179206 numWrites += 1
180207 }
181208 for i := 0 ; i < numNonIndexUpdates ; i ++ {
182- _ , err := tx .Exec ( fmt .Sprintf (stmtNonIndex_updates , tableNum , getCValue (), sbRand (0 , o .opts .TableSize )))
209+ _ , err := tx .ExecContext ( ctx , fmt .Sprintf (stmtNonIndex_updates , tableNum , getCValue (), sbRand (0 , o .opts .TableSize )))
183210 if err != nil {
184211 tx .Rollback ()
185212 return numReads , numWrites , numOthers , 1 , err
@@ -189,14 +216,14 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
189216 for i := 0 ; i < numDeleteInserts ; i ++ {
190217 id := sbRand (0 , o .opts .TableSize )
191218
192- _ , err := tx .Exec ( fmt .Sprintf (stmtDeletes , tableNum , id ))
219+ _ , err := tx .ExecContext ( ctx , fmt .Sprintf (stmtDeletes , tableNum , id ))
193220 if err != nil {
194221 tx .Rollback ()
195222 return numReads , numWrites , numOthers , 1 , err
196223 }
197224 numWrites += 1
198225
199- _ , err = tx .Exec ( fmt .Sprintf (stmtInserts , tableNum , id , sbRand (0 , o .opts .TableSize ), getCValue (), getPadValue ()))
226+ _ , err = tx .ExecContext ( ctx , fmt .Sprintf (stmtInserts , tableNum , id , sbRand (0 , o .opts .TableSize ), getCValue (), getPadValue ()))
200227 if err != nil {
201228 tx .Rollback ()
202229 return numReads , numWrites , numOthers , 1 , err
@@ -215,16 +242,20 @@ func (o *MySQLOLTP) Event() (reads uint64, writes uint64, others uint64, errors
215242 return numReads , numWrites , numOthers , 0 , nil
216243}
217244
218- func (o * MySQLOLTP ) Done () error {
245+ func (o * OLTPBench ) Done () error {
219246 o .db .Close ()
220247 return nil
221248}
222249
223- func (o * MySQLOLTP ) dsn () string {
250+ func (o * OLTPBench ) dsnMySQL () string {
224251 return fmt .Sprintf ("%s:%s@tcp(%s:%d)/%s" , o .opts .MySQLUser , o .opts .MySQLPassword , o .opts .MySQLHost , o .opts .MySQLPort , o .opts .MySQLDB )
225252}
226253
227- func (o * MySQLOLTP ) getRandTableNum () int {
254+ func (o * OLTPBench ) dsnPgSQL () string {
255+ return fmt .Sprintf ("postgres://%s:%s@%s:%d/%s?sslmode=disable" , o .opts .PgSQLUser , o .opts .PgSQLPassword , o .opts .PgSQLHost , o .opts .PgSQLPort , o .opts .PgSQLDB )
256+ }
257+
258+ func (o * OLTPBench ) getRandTableNum () int {
228259 return sbRand (1 , o .opts .Tables )
229260}
230261
@@ -256,8 +287,15 @@ func sbRandStr(format string) string {
256287 return string (buf )
257288}
258289
259- func (o * MySQLOLTP ) createTable () error {
260- idDef := "INT NOT NULL AUTO_INCREMENT"
290+ func (o * OLTPBench ) createTable () error {
291+ var idDef string
292+
293+ if o .opts .DBDriver == DBDriverPgSQL {
294+ idDef = "INT NOT NULL"
295+ } else {
296+ idDef = "INT NOT NULL AUTO_INCREMENT"
297+ }
298+
261299 idIndexDef := "PRIMARY KEY"
262300 engineDef := ""
263301 extraTableOptions := ""
@@ -278,10 +316,10 @@ func (o *MySQLOLTP) createTable() error {
278316
279317 fmt .Printf ("Inserting %d records into 'sbtest%d'\n " , o .opts .TableSize , tableNum )
280318 insertValues := []string {}
281- for i := 0 ; i < o .opts .TableSize ; i ++ {
282- insertValues = append (insertValues , fmt .Sprintf (`(%d, "%s", "%s" ) ` , sbRand (0 , o .opts .TableSize ), getCValue (), getPadValue ()))
319+ for i := 1 ; i <= o .opts .TableSize ; i ++ {
320+ insertValues = append (insertValues , fmt .Sprintf (`(%d, %d, '%s', '%s' ) ` , i , sbRand (0 , o .opts .TableSize ), getCValue (), getPadValue ()))
283321 }
284- query = fmt .Sprintf ("INSERT INTO sbtest%d (k, c, pad) VALUES" , tableNum ) + strings .Join (insertValues , "," )
322+ query = fmt .Sprintf ("INSERT INTO sbtest%d (id, k, c, pad) VALUES" , tableNum ) + strings .Join (insertValues , "," )
285323 _ , err = o .db .Exec (query )
286324 if err != nil {
287325 return err
0 commit comments