44 "context"
55
66 "github.com/jackc/pgx/v5"
7+ "github.com/jackc/pgx/v5/pgxpool"
78 "github.com/torfstack/synod/backend/convert/fromdb"
89 "github.com/torfstack/synod/backend/convert/todb"
910 "github.com/torfstack/synod/backend/models"
@@ -12,71 +13,50 @@ import (
1213
1314type database struct {
1415 connStr string
15- conn * pgx. Conn
16+ pool * pgxpool. Pool
1617 tx * transaction
1718}
1819
1920var _ Database = (* database )(nil )
2021
21- func NewDatabase (connStr string ) Database {
22- return & database {connStr : connStr }
22+ func NewDatabase (ctx context.Context , connStr string ) (Database , error ) {
23+ pool , err := pgxpool .New (ctx , connStr )
24+ if err != nil {
25+ return nil , err
26+ }
27+ return & database {connStr : connStr , pool : pool }, nil
2328}
2429
2530func (d * database ) WithTx (ctx context.Context , withTx func (Database ) error ) error {
2631 if d .tx != nil {
2732 return withTx (d )
2833 }
2934
30- var conn * pgx.Conn
31- if d .conn != nil {
32- conn = d .conn
33- } else {
34- var err error
35- conn , err = pgx .Connect (ctx , d .connStr )
36- if err != nil {
37- return err
38- }
39- }
40-
41- tx , err := conn .Begin (ctx )
35+ tx , err := d .pool .Begin (ctx )
4236 if err != nil {
4337 return err
4438 }
45- trans := & transaction {conn : conn , tx : tx }
39+ trans := & transaction {tx : tx }
4640 defer func (tx pgx.Tx , ctx context.Context ) {
4741 _ = tx .Rollback (ctx )
4842 }(tx , ctx )
49- err = withTx (& database {connStr : d .connStr , conn : conn , tx : trans })
43+ err = withTx (& database {connStr : d .connStr , pool : d . pool , tx : trans })
5044 if err != nil {
5145 return err
5246 }
5347 return tx .Commit (ctx )
5448}
5549
56- func (d * database ) CommitTransaction (ctx context.Context ) error {
57- if d .tx == nil {
58- return nil
59- }
60- defer func (context.Context ) {
61- d .tx .Rollback (ctx )
62- _ = (* d .conn ).Close (ctx )
63- }(ctx )
64- d .tx .Commit (ctx )
65- return nil
66- }
67-
6850func (d * database ) DoesUserExist (ctx context.Context , username string ) (bool , error ) {
69- q , err := startQuery (ctx , d )
70- defer endQuery (ctx , d )
51+ q , err := startQuery (d )
7152 if err != nil {
7253 return false , err
7354 }
7455 return q .DoesUserExist (ctx , username )
7556}
7657
7758func (d * database ) InsertUser (ctx context.Context , user models.User ) (models.ExistingUser , error ) {
78- q , err := startQuery (ctx , d )
79- defer endQuery (ctx , d )
59+ q , err := startQuery (d )
8060 if err != nil {
8161 return models.ExistingUser {}, err
8262 }
@@ -86,8 +66,7 @@ func (d *database) InsertUser(ctx context.Context, user models.User) (models.Exi
8666}
8767
8868func (d * database ) SelectUserByName (ctx context.Context , username string ) (models.ExistingUser , error ) {
89- q , err := startQuery (ctx , d )
90- defer endQuery (ctx , d )
69+ q , err := startQuery (d )
9170 if err != nil {
9271 return models.ExistingUser {}, err
9372 }
@@ -100,8 +79,7 @@ func (d *database) UpsertSecret(
10079 secret models.EncryptedSecret ,
10180 userID int64 ,
10281) (models.EncryptedSecret , error ) {
103- q , err := startQuery (ctx , d )
104- defer endQuery (ctx , d )
82+ q , err := startQuery (d )
10583 if err != nil {
10684 return models.EncryptedSecret {}, err
10785 }
@@ -117,8 +95,7 @@ func (d *database) UpsertSecret(
11795}
11896
11997func (d * database ) SelectSecrets (ctx context.Context , userID int64 ) ([]models.EncryptedSecret , error ) {
120- q , err := startQuery (ctx , d )
121- defer endQuery (ctx , d )
98+ q , err := startQuery (d )
12299 if err != nil {
123100 return []models.EncryptedSecret {}, err
124101 }
@@ -127,8 +104,7 @@ func (d *database) SelectSecrets(ctx context.Context, userID int64) ([]models.En
127104}
128105
129106func (d * database ) InsertKeys (ctx context.Context , pair models.UserKeyPair ) (models.UserKeyPair , error ) {
130- q , err := startQuery (ctx , d )
131- defer endQuery (ctx , d )
107+ q , err := startQuery (d )
132108 if err != nil {
133109 return models.UserKeyPair {}, err
134110 }
@@ -141,8 +117,7 @@ func (d *database) InsertKeys(ctx context.Context, pair models.UserKeyPair) (mod
141117}
142118
143119func (d * database ) SelectKeys (ctx context.Context , userID int64 ) (models.UserKeyPair , error ) {
144- q , err := startQuery (ctx , d )
145- defer endQuery (ctx , d )
120+ q , err := startQuery (d )
146121 if err != nil {
147122 return models.UserKeyPair {}, err
148123 }
@@ -154,17 +129,15 @@ func (d *database) SelectKeys(ctx context.Context, userID int64) (models.UserKey
154129}
155130
156131func (d * database ) HasKeys (ctx context.Context , userID int64 ) (bool , error ) {
157- q , err := startQuery (ctx , d )
158- defer endQuery (ctx , d )
132+ q , err := startQuery (d )
159133 if err != nil {
160134 return false , err
161135 }
162136 return q .HasKeys (ctx , userID )
163137}
164138
165139func (d * database ) InsertPassword (ctx context.Context , password models.HashedPassword ) (models.HashedPassword , error ) {
166- q , err := startQuery (ctx , d )
167- defer endQuery (ctx , d )
140+ q , err := startQuery (d )
168141 if err != nil {
169142 return models.HashedPassword {}, err
170143 }
@@ -177,8 +150,7 @@ func (d *database) InsertPassword(ctx context.Context, password models.HashedPas
177150}
178151
179152func (d * database ) SelectPassword (ctx context.Context , passwordID int64 ) (models.HashedPassword , error ) {
180- q , err := startQuery (ctx , d )
181- defer endQuery (ctx , d )
153+ q , err := startQuery (d )
182154 if err != nil {
183155 return models.HashedPassword {}, err
184156 }
@@ -189,23 +161,9 @@ func (d *database) SelectPassword(ctx context.Context, passwordID int64) (models
189161 return fromdb .HashedPassword (dbPassword ), nil
190162}
191163
192- func startQuery (ctx context. Context , d * database ) (* sqlc.Queries , error ) {
164+ func startQuery (d * database ) (* sqlc.Queries , error ) {
193165 if d .tx != nil {
194166 return sqlc .New (d .tx .SqlTx ()), nil
195167 }
196- if d .conn == nil {
197- conn , err := pgx .Connect (ctx , d .connStr )
198- if err != nil {
199- return nil , err
200- }
201- d .conn = conn
202- }
203- return sqlc .New (d .conn ), nil
204- }
205-
206- func endQuery (ctx context.Context , d * database ) {
207- if d .conn != nil && d .tx == nil {
208- _ = (* d .conn ).Close (ctx )
209- d .conn = nil
210- }
168+ return sqlc .New (d .pool ), nil
211169}
0 commit comments