@@ -3,45 +3,43 @@ package main
33import (
44 "context"
55
6- "github.com/jackc/pgx/v5/pgconn"
6+ trmpgx "github.com/avito-tech/go-transaction-manager/drivers/pgxv5/v2"
7+ "github.com/avito-tech/go-transaction-manager/trm/v2/manager"
78 "go.uber.org/zap"
89
910 "github.com/stroppy-io/stroppy-core/pkg/logger"
1011 "github.com/stroppy-io/stroppy-core/pkg/plugins/driver"
1112 stroppy "github.com/stroppy-io/stroppy-core/pkg/proto"
13+ "github.com/stroppy-io/stroppy-core/pkg/utils/errchan"
1214
1315 "github.com/stroppy-io/stroppy-postgres/internal/pool"
1416 "github.com/stroppy-io/stroppy-postgres/internal/queries"
1517)
1618
17- type Connection interface {
18- // Exec executes the given SQL statement with the provided arguments in the context of the Executor.
19- //
20- // Parameters:
21- // - ctx: The context.Context object.
22- // - sql: The SQL statement to execute.
23- // - arguments: The arguments to be passed to the SQL statement.
24- //
25- // Returns:
26- // - pgconn.CommandTag: The command tag returned by the execution.
27- // - error: An error if the execution fails.
28- Exec (ctx context.Context , sql string , arguments ... interface {}) (pgconn.CommandTag , error )
29- Close ()
30- }
31-
3219type QueryBuilder interface {
3320 Build (
3421 ctx context.Context ,
3522 logger * zap.Logger ,
36- buildQueriesContext * stroppy.BuildQueriesContext ,
37- ) (* stroppy.DriverQueriesList , error )
23+ buildQueriesContext * stroppy.UnitBuildContext ,
24+ ) (* stroppy.DriverTransactionList , error )
25+ BuildStream (
26+ ctx context.Context ,
27+ logger * zap.Logger ,
28+ buildQueriesContext * stroppy.UnitBuildContext ,
29+ channel errchan.Chan [stroppy.DriverTransaction ],
30+ )
3831 ValueToPgxValue (value * stroppy.Value ) (any , error )
3932}
4033
4134type Driver struct {
42- logger * zap.Logger
43- connPool Connection
44- builder QueryBuilder
35+ logger * zap.Logger
36+ pgxPool interface {
37+ Executor
38+ Close ()
39+ }
40+ txManager * manager.Manager
41+ txExecutor * TxExecutor
42+ builder QueryBuilder
4543}
4644
4745func NewDriver () driver.Plugin { //nolint: ireturn // allow
@@ -55,56 +53,89 @@ func NewDriver() driver.Plugin { //nolint: ireturn // allow
5553func (d * Driver ) Initialize (ctx context.Context , runContext * stroppy.StepContext ) error {
5654 connPool , err := pool .NewPool (
5755 ctx ,
58- runContext .GetConfig ().GetDriver (),
56+ runContext .GetGlobalConfig (). GetRun ().GetDriver (),
5957 d .logger .Named (pool .LoggerName ),
6058 )
6159 if err != nil {
6260 return err
6361 }
6462
65- d .connPool = connPool
63+ d .pgxPool = connPool
6664
6765 d .builder , err = queries .NewQueryBuilder (runContext )
6866 if err != nil {
6967 return err
7068 }
7169
70+ d .txManager = manager .Must (trmpgx .NewDefaultFactory (connPool ))
71+ d .txExecutor = NewTxExecutor (connPool )
72+
7273 return nil
7374}
7475
75- func (d * Driver ) BuildQueries (
76+ func (d * Driver ) BuildTransactionsFromUnit (
7677 ctx context.Context ,
77- buildQueriesContext * stroppy.BuildQueriesContext ,
78- ) (* stroppy.DriverQueriesList , error ) {
79- return d .builder .Build (ctx , d .logger , buildQueriesContext )
78+ buildUnitContext * stroppy.UnitBuildContext ,
79+ ) (* stroppy.DriverTransactionList , error ) {
80+ return d .builder .Build (ctx , d .logger , buildUnitContext )
8081}
8182
82- func (d * Driver ) RunQuery (ctx context.Context , query * stroppy.DriverQuery ) error {
83- d .logger .Debug (
84- "run query" ,
85- zap .String ("name" , query .GetName ()),
86- zap .String ("sql" , query .GetRequest ()),
87- zap .Any ("args" , query .GetParams ()),
88- )
83+ func (d * Driver ) BuildTransactionsFromUnitStream (
84+ ctx context.Context ,
85+ buildUnitContext * stroppy.UnitBuildContext ,
86+ ) (errchan.Chan [stroppy.DriverTransaction ], error ) {
87+ channel := make (errchan.Chan [stroppy.DriverTransaction ])
88+ go func () {
89+ d .builder .BuildStream (ctx , d .logger , buildUnitContext , channel )
90+ }()
91+
92+ return channel , nil
93+ }
8994
90- values := make ([]any , len (query .GetParams ()))
95+ func (d * Driver ) RunTransaction (
96+ ctx context.Context ,
97+ transaction * stroppy.DriverTransaction ,
98+ ) error {
99+ if transaction .GetIsolationLevel () == stroppy .TxIsolationLevel_TX_ISOLATION_LEVEL_UNSPECIFIED {
100+ return d .runTransactionInternal (ctx , transaction , d .pgxPool )
101+ }
102+
103+ return d .txManager .DoWithSettings (
104+ ctx ,
105+ NewStroppyIsolationSettings (transaction ),
106+ func (ctx context.Context ) error {
107+ return d .runTransactionInternal (ctx , transaction , d .txExecutor )
108+ })
109+ }
110+
111+ func (d * Driver ) runTransactionInternal (
112+ ctx context.Context ,
113+ transaction * stroppy.DriverTransaction ,
114+ executor Executor ,
115+ ) error {
116+ for _ , query := range transaction .GetQueries () {
117+ values := make ([]any , len (query .GetParams ()))
118+
119+ for i , v := range query .GetParams () {
120+ val , err := d .builder .ValueToPgxValue (v )
121+ if err != nil {
122+ return err
123+ }
124+
125+ values [i ] = val
126+ }
91127
92- for i , v := range query .GetParams () {
93- val , err := d .builder .ValueToPgxValue (v )
128+ _ , err := executor .Exec (ctx , query .GetRequest (), values ... )
94129 if err != nil {
95130 return err
96131 }
97-
98- values [i ] = val
99132 }
100133
101- _ , err := d .connPool .Exec (ctx , query .GetRequest (), values ... )
102-
103- return err
134+ return nil
104135}
105136
106137func (d * Driver ) Teardown (_ context.Context ) error {
107- d .connPool .Close ()
138+ d .pgxPool .Close ()
108139
109140 return nil
110141}
0 commit comments