Skip to content

Commit 26ea421

Browse files
feat(datastore): added newTransaction option for runQuery (#925)
1 parent 84da6b7 commit 26ea421

File tree

2 files changed

+60
-5
lines changed

2 files changed

+60
-5
lines changed

datastore/gcloud/aio/datastore/query.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,11 @@ class QueryResult:
282282
query_result_batch_kind = QueryResultBatch
283283

284284
def __init__(self, result_batch: Optional[QueryResultBatch] = None,
285-
explain_metrics: Optional[ExplainMetrics] = None):
285+
explain_metrics: Optional[ExplainMetrics] = None,
286+
transaction: Optional[str] = None):
286287
self.result_batch = result_batch
287288
self.explain_metrics = explain_metrics
289+
self.transaction = transaction
288290

289291
def __repr__(self) -> str:
290292
return str(self.to_repr())
@@ -293,27 +295,33 @@ def __eq__(self, other: object) -> bool:
293295
if not isinstance(other, QueryResult):
294296
return False
295297
return (self.result_batch == other.result_batch
296-
and self.explain_metrics == other.explain_metrics)
298+
and self.explain_metrics == other.explain_metrics
299+
and self.transaction == other.transaction)
297300

298301
@classmethod
299302
def from_repr(cls, data: Dict[str, Any]) -> 'QueryResult':
300303
result_batch = None
301304
explain_metrics = None
305+
transaction = None
302306

303307
if 'batch' in data:
304308
result_batch = cls.query_result_batch_kind.from_repr(data['batch'])
305309
if 'explainMetrics' in data:
306310
explain_metrics = ExplainMetrics.from_repr(data['explainMetrics'])
311+
if 'transaction' in data:
312+
transaction = data['transaction']
307313

308-
return cls(result_batch=result_batch, explain_metrics=explain_metrics)
314+
return cls(result_batch=result_batch,
315+
explain_metrics=explain_metrics, transaction=transaction)
309316

310317
def to_repr(self) -> Dict[str, Any]:
311-
result = {}
318+
result: Dict[str, Any] = {}
312319
if self.result_batch:
313320
result['batch'] = self.result_batch.to_repr()
314321
if self.explain_metrics:
315322
result['explainMetrics'] = self.explain_metrics.to_repr()
316-
323+
if self.transaction:
324+
result['transaction'] = self.transaction
317325
return result
318326

319327
def get_explain_metrics(self) -> Optional[ExplainMetrics]:

datastore/tests/integration/smoke_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,53 @@ async def test_start_transaction_on_lookup(creds: str,
135135
actual = await ds.lookup([key], session=s)
136136
assert actual['found'][0].entity.properties == {'animal': 'aardvark'}
137137

138+
# Clean up test data
139+
await ds.delete(key, s)
140+
141+
142+
@pytest.mark.asyncio
143+
async def test_start_transaction_on_query(
144+
creds: str, kind: str, project: str,
145+
) -> None:
146+
key = Key(project, [PathElement(kind, name=f'test_record_{uuid.uuid4()}')])
147+
148+
async with Session() as s:
149+
ds = Datastore(project=project, service_file=creds, session=s)
150+
151+
# Test query with newTransaction parameter
152+
property_filter = PropertyFilter(
153+
prop='animal',
154+
operator=PropertyFilterOperator.EQUAL,
155+
value=Value('three-toed sloth'),
156+
)
157+
query = Query(kind=kind, query_filter=Filter(property_filter))
158+
159+
# Use newTransaction parameter
160+
options = TransactionOptions(ReadWrite())
161+
result = await ds.runQuery(query, newTransaction=options, session=s)
162+
assert result.transaction is not None and result.transaction
163+
164+
mutations = [
165+
ds.make_mutation(
166+
Operation.INSERT, key,
167+
properties={'animal': 'three-toed sloth'},
168+
),
169+
ds.make_mutation(
170+
Operation.UPDATE, key,
171+
properties={'animal': 'aardvark'},
172+
),
173+
]
174+
await ds.commit(
175+
mutations,
176+
transaction=result.transaction,
177+
session=s)
178+
179+
actual = await ds.lookup([key], session=s)
180+
assert actual['found'][0].entity.properties == {'animal': 'aardvark'}
181+
182+
# Clean up test data
183+
await ds.delete(key, s)
184+
138185

139186
@pytest.mark.asyncio
140187
async def test_transaction(creds: str, kind: str, project: str) -> None:

0 commit comments

Comments
 (0)