Skip to content

Commit 9292a1a

Browse files
authored
Merge pull request #202 Add flag for deny split transaction
2 parents 9c456d4 + 341ec3c commit 9292a1a

File tree

6 files changed

+227
-7
lines changed

6 files changed

+227
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Flag for deny split transaction
2+
13
## 2.12.3 ##
24
* Add six package to requirements
35
* Fixed error while passing date parameter in execute

tests/aio/test_tx.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,62 @@ async def test_tx_snapshot_ro(driver, database):
9494
commit_tx=True,
9595
)
9696
assert data[0].rows == [{"value": 2}]
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_split_transactions_deny_split(driver, table_name):
101+
async with ydb.aio.SessionPool(driver, 1) as pool:
102+
103+
async def check_transaction(s: ydb.aio.table.Session):
104+
async with s.transaction(deny_split_transactions=True) as tx:
105+
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
106+
await tx.commit()
107+
108+
with pytest.raises(RuntimeError):
109+
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
110+
111+
await tx.commit()
112+
113+
async with s.transaction() as tx:
114+
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
115+
assert rs[0].rows[0].cnt == 1
116+
117+
await pool.retry_operation(check_transaction)
118+
119+
120+
@pytest.mark.asyncio
121+
async def test_split_transactions_allow_split(driver, table_name):
122+
async with ydb.aio.SessionPool(driver, 1) as pool:
123+
124+
async def check_transaction(s: ydb.aio.table.Session):
125+
async with s.transaction(deny_split_transactions=False) as tx:
126+
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
127+
await tx.commit()
128+
129+
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
130+
await tx.commit()
131+
132+
async with s.transaction() as tx:
133+
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
134+
assert rs[0].rows[0].cnt == 2
135+
136+
await pool.retry_operation(check_transaction)
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_split_transactions_default(driver, table_name):
141+
async with ydb.aio.SessionPool(driver, 1) as pool:
142+
143+
async def check_transaction(s: ydb.aio.table.Session):
144+
async with s.transaction() as tx:
145+
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
146+
await tx.commit()
147+
148+
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
149+
await tx.commit()
150+
151+
async with s.transaction() as tx:
152+
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
153+
assert rs[0].rows[0].cnt == 2
154+
155+
await pool.retry_operation(check_transaction)

tests/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,30 @@ async def driver_sync(endpoint, database, event_loop):
111111
yield driver
112112

113113
driver.stop(timeout=10)
114+
115+
116+
@pytest.fixture()
117+
def table_name(driver_sync, database):
118+
table_name = "table"
119+
120+
with ydb.SessionPool(driver_sync) as pool:
121+
122+
def create_table(s):
123+
try:
124+
s.drop_table(database + "/" + table_name)
125+
except ydb.SchemeError:
126+
pass
127+
128+
s.execute_scheme(
129+
"""
130+
CREATE TABLE %s (
131+
id Int64 NOT NULL,
132+
i64Val Int64,
133+
PRIMARY KEY(id)
134+
)
135+
"""
136+
% table_name
137+
)
138+
139+
pool.retry_operation_sync(create_table)
140+
return table_name

tests/table/test_tx.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,59 @@ def test_tx_snapshot_ro(driver_sync, database):
8989
commit_tx=True,
9090
)
9191
assert data[0].rows == [{"value": 2}]
92+
93+
94+
def test_split_transactions_deny_split(driver_sync, table_name):
95+
with ydb.SessionPool(driver_sync, 1) as pool:
96+
97+
def check_transaction(s: ydb.table.Session):
98+
with s.transaction(deny_split_transactions=True) as tx:
99+
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
100+
tx.commit()
101+
102+
with pytest.raises(RuntimeError):
103+
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
104+
105+
tx.commit()
106+
107+
with s.transaction() as tx:
108+
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
109+
assert rs[0].rows[0].cnt == 1
110+
111+
pool.retry_operation_sync(check_transaction)
112+
113+
114+
def test_split_transactions_allow_split(driver_sync, table_name):
115+
with ydb.SessionPool(driver_sync, 1) as pool:
116+
117+
def check_transaction(s: ydb.table.Session):
118+
with s.transaction(deny_split_transactions=False) as tx:
119+
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
120+
tx.commit()
121+
122+
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
123+
tx.commit()
124+
125+
with s.transaction() as tx:
126+
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
127+
assert rs[0].rows[0].cnt == 2
128+
129+
pool.retry_operation_sync(check_transaction)
130+
131+
132+
def test_split_transactions_default(driver_sync, table_name):
133+
with ydb.SessionPool(driver_sync, 1) as pool:
134+
135+
def check_transaction(s: ydb.table.Session):
136+
with s.transaction() as tx:
137+
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
138+
tx.commit()
139+
140+
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
141+
tx.commit()
142+
143+
with s.transaction() as tx:
144+
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
145+
assert rs[0].rows[0].cnt == 2
146+
147+
pool.retry_operation_sync(check_transaction)

ydb/aio/table.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,14 @@ async def alter_table(
120120
set_read_replicas_settings,
121121
)
122122

123-
def transaction(self, tx_mode=None):
124-
return TxContext(self._driver, self._state, self, tx_mode)
123+
def transaction(self, tx_mode=None, *, deny_split_transactions=False):
124+
return TxContext(
125+
self._driver,
126+
self._state,
127+
self,
128+
tx_mode,
129+
deny_split_transactions=deny_split_transactions,
130+
)
125131

126132
async def describe_table(self, path, settings=None): # pylint: disable=W0236
127133
return await super().describe_table(path, settings)
@@ -184,6 +190,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
184190
async def execute(
185191
self, query, parameters=None, commit_tx=False, settings=None
186192
): # pylint: disable=W0236
193+
194+
self._check_split()
195+
187196
return await super().execute(query, parameters, commit_tx, settings)
188197

189198
async def commit(self, settings=None): # pylint: disable=W0236

ydb/table.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,7 @@ def execute_scheme(self, yql_text, settings=None):
11761176
pass
11771177

11781178
@abstractmethod
1179-
def transaction(self, tx_mode=None):
1179+
def transaction(self, tx_mode=None, deny_split_transactions=False):
11801180
pass
11811181

11821182
@abstractmethod
@@ -1681,8 +1681,14 @@ def execute_scheme(self, yql_text, settings=None):
16811681
self._state.endpoint,
16821682
)
16831683

1684-
def transaction(self, tx_mode=None):
1685-
return TxContext(self._driver, self._state, self, tx_mode)
1684+
def transaction(self, tx_mode=None, deny_split_transactions=False):
1685+
return TxContext(
1686+
self._driver,
1687+
self._state,
1688+
self,
1689+
tx_mode,
1690+
deny_split_transactions=deny_split_transactions,
1691+
)
16861692

16871693
def has_prepared(self, query):
16881694
return query in self._state
@@ -2194,9 +2200,27 @@ def begin(self, settings=None):
21942200

21952201

21962202
class BaseTxContext(ITxContext):
2197-
__slots__ = ("_tx_state", "_session_state", "_driver", "session")
2203+
__slots__ = (
2204+
"_tx_state",
2205+
"_session_state",
2206+
"_driver",
2207+
"session",
2208+
"_finished",
2209+
"_deny_split_transactions",
2210+
)
21982211

2199-
def __init__(self, driver, session_state, session, tx_mode=None):
2212+
_COMMIT = "commit"
2213+
_ROLLBACK = "rollback"
2214+
2215+
def __init__(
2216+
self,
2217+
driver,
2218+
session_state,
2219+
session,
2220+
tx_mode=None,
2221+
*,
2222+
deny_split_transactions=False
2223+
):
22002224
"""
22012225
An object that provides a simple transaction context manager that allows statements execution
22022226
in a transaction. You don't have to open transaction explicitly, because context manager encapsulates
@@ -2219,6 +2243,8 @@ def __init__(self, driver, session_state, session, tx_mode=None):
22192243
self._tx_state = _tx_ctx_impl.TxState(tx_mode)
22202244
self._session_state = session_state
22212245
self.session = session
2246+
self._finished = ""
2247+
self._deny_split_transactions = deny_split_transactions
22222248

22232249
def __enter__(self):
22242250
"""
@@ -2276,6 +2302,9 @@ def execute(self, query, parameters=None, commit_tx=False, settings=None):
22762302
22772303
:return: A result sets or exception in case of execution errors
22782304
"""
2305+
2306+
self._check_split()
2307+
22792308
return self._driver(
22802309
_tx_ctx_impl.execute_request_factory(
22812310
self._session_state,
@@ -2302,8 +2331,12 @@ def commit(self, settings=None):
23022331
23032332
:return: A committed transaction or exception if commit is failed
23042333
"""
2334+
2335+
self._set_finish(self._COMMIT)
2336+
23052337
if self._tx_state.tx_id is None and not self._tx_state.dead:
23062338
return self
2339+
23072340
return self._driver(
23082341
_tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state),
23092342
_apis.TableService.Stub,
@@ -2323,8 +2356,12 @@ def rollback(self, settings=None):
23232356
23242357
:return: A rolled back transaction or exception if rollback is failed
23252358
"""
2359+
2360+
self._set_finish(self._ROLLBACK)
2361+
23262362
if self._tx_state.tx_id is None and not self._tx_state.dead:
23272363
return self
2364+
23282365
return self._driver(
23292366
_tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state),
23302367
_apis.TableService.Stub,
@@ -2345,6 +2382,9 @@ def begin(self, settings=None):
23452382
"""
23462383
if self._tx_state.tx_id is not None:
23472384
return self
2385+
2386+
self._check_split()
2387+
23482388
return self._driver(
23492389
_tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state),
23502390
_apis.TableService.Stub,
@@ -2355,6 +2395,21 @@ def begin(self, settings=None):
23552395
self._session_state.endpoint,
23562396
)
23572397

2398+
def _set_finish(self, val):
2399+
self._check_split(val)
2400+
self._finished = val
2401+
2402+
def _check_split(self, allow=""):
2403+
"""
2404+
Deny all operaions with transaction after commit/rollback.
2405+
Exception: double commit and double rollbacks, because it is safe
2406+
"""
2407+
if not self._deny_split_transactions:
2408+
return
2409+
2410+
if self._finished != "" and self._finished != allow:
2411+
raise RuntimeError("Any operation with finished transaction is denied")
2412+
23582413

23592414
class TxContext(BaseTxContext):
23602415
@_utilities.wrap_async_call_exceptions
@@ -2370,6 +2425,9 @@ def async_execute(self, query, parameters=None, commit_tx=False, settings=None):
23702425
23712426
:return: A future of query execution
23722427
"""
2428+
2429+
self._check_split()
2430+
23732431
return self._driver.future(
23742432
_tx_ctx_impl.execute_request_factory(
23752433
self._session_state,
@@ -2401,8 +2459,11 @@ def async_commit(self, settings=None):
24012459
24022460
:return: A future of commit call
24032461
"""
2462+
self._set_finish(self._COMMIT)
2463+
24042464
if self._tx_state.tx_id is None and not self._tx_state.dead:
24052465
return _utilities.wrap_result_in_future(self)
2466+
24062467
return self._driver.future(
24072468
_tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state),
24082469
_apis.TableService.Stub,
@@ -2423,8 +2484,11 @@ def async_rollback(self, settings=None):
24232484
24242485
:return: A future of rollback call
24252486
"""
2487+
self._set_finish(self._ROLLBACK)
2488+
24262489
if self._tx_state.tx_id is None and not self._tx_state.dead:
24272490
return _utilities.wrap_result_in_future(self)
2491+
24282492
return self._driver.future(
24292493
_tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state),
24302494
_apis.TableService.Stub,
@@ -2446,6 +2510,9 @@ def async_begin(self, settings=None):
24462510
"""
24472511
if self._tx_state.tx_id is not None:
24482512
return _utilities.wrap_result_in_future(self)
2513+
2514+
self._check_split()
2515+
24492516
return self._driver.future(
24502517
_tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state),
24512518
_apis.TableService.Stub,

0 commit comments

Comments
 (0)