3535 UpdateTokenRequest ,
3636 UpdateTokenResponse ,
3737 StreamWriteMessage ,
38+ TransactionIdentity ,
3839 WriterMessagesFromServerToClient ,
3940)
4041from .._grpc .grpcwrapper .common_utils import (
4344 GrpcWrapperAsyncIO ,
4445)
4546
47+ if typing .TYPE_CHECKING :
48+ from ..query .transaction import BaseQueryTxContext
49+
4650logger = logging .getLogger (__name__ )
4751
4852
@@ -165,7 +169,20 @@ async def wait_init(self) -> PublicWriterInitInfo:
165169
166170
167171class TxWriterAsyncIO (WriterAsyncIO ):
168- ...
172+ _tx : object
173+
174+ def __init__ (
175+ self ,
176+ tx ,
177+ driver : SupportedDriverType ,
178+ settings : PublicWriterSettings ,
179+ _client = None ,
180+ ):
181+ self ._tx = tx
182+ self ._loop = asyncio .get_running_loop ()
183+ self ._closed = False
184+ self ._reconnector = WriterAsyncIOReconnector (driver = driver , settings = WriterSettings (settings ), tx = self ._tx )
185+ self ._parent = _client
169186
170187
171188class WriterAsyncIOReconnector :
@@ -182,6 +199,7 @@ class WriterAsyncIOReconnector:
182199 _codec_selector_batch_num : int
183200 _codec_selector_last_codec : Optional [PublicCodec ]
184201 _codec_selector_check_batches_interval : int
202+ _tx : Optional ["BaseQueryTxContext" ]
185203
186204 if typing .TYPE_CHECKING :
187205 _messages_for_encode : asyncio .Queue [List [InternalMessage ]]
@@ -199,7 +217,9 @@ class WriterAsyncIOReconnector:
199217 _stop_reason : asyncio .Future
200218 _init_info : Optional [PublicWriterInitInfo ]
201219
202- def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
220+ def __init__ (
221+ self , driver : SupportedDriverType , settings : WriterSettings , tx : Optional ["BaseQueryTxContext" ] = None
222+ ):
203223 self ._closed = False
204224 self ._loop = asyncio .get_running_loop ()
205225 self ._driver = driver
@@ -209,6 +229,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
209229 self ._init_info = None
210230 self ._stream_connected = asyncio .Event ()
211231 self ._settings = settings
232+ self ._tx = tx
212233
213234 self ._codec_functions = {
214235 PublicCodec .RAW : lambda data : data ,
@@ -358,10 +379,12 @@ async def _connection_loop(self):
358379 # noinspection PyBroadException
359380 stream_writer = None
360381 try :
382+ tx_identity = None if self ._tx is None else self ._tx ._tx_identity ()
361383 stream_writer = await WriterAsyncIOStream .create (
362384 self ._driver ,
363385 self ._init_message ,
364386 self ._settings .update_token_interval ,
387+ tx_identity = tx_identity ,
365388 )
366389 try :
367390 if self ._init_info is None :
@@ -601,10 +624,13 @@ class WriterAsyncIOStream:
601624 _update_token_event : asyncio .Event
602625 _get_token_function : Optional [Callable [[], str ]]
603626
627+ _tx_identity : Optional [TransactionIdentity ]
628+
604629 def __init__ (
605630 self ,
606631 update_token_interval : Optional [Union [int , float ]] = None ,
607632 get_token_function : Optional [Callable [[], str ]] = None ,
633+ tx_identity : Optional [TransactionIdentity ] = None ,
608634 ):
609635 self ._closed = False
610636
@@ -613,6 +639,8 @@ def __init__(
613639 self ._update_token_event = asyncio .Event ()
614640 self ._update_token_task = None
615641
642+ self ._tx_identity = tx_identity
643+
616644 async def close (self ):
617645 if self ._closed :
618646 return
@@ -629,6 +657,7 @@ async def create(
629657 driver : SupportedDriverType ,
630658 init_request : StreamWriteMessage .InitRequest ,
631659 update_token_interval : Optional [Union [int , float ]] = None ,
660+ tx_identity : Optional [TransactionIdentity ] = None ,
632661 ) -> "WriterAsyncIOStream" :
633662 stream = GrpcWrapperAsyncIO (StreamWriteMessage .FromServer .from_proto )
634663
@@ -638,6 +667,7 @@ async def create(
638667 writer = WriterAsyncIOStream (
639668 update_token_interval = update_token_interval ,
640669 get_token_function = creds .get_auth_token if creds else lambda : "" ,
670+ tx_identity = tx_identity ,
641671 )
642672 await writer ._start (stream , init_request )
643673 return writer
@@ -684,7 +714,7 @@ def write(self, messages: List[InternalMessage]):
684714 if self ._closed :
685715 raise RuntimeError ("Can not write on closed stream." )
686716
687- for request in messages_to_proto_requests (messages ):
717+ for request in messages_to_proto_requests (messages , self . _tx_identity ):
688718 self ._stream .write (request )
689719
690720 async def _update_token_loop (self ):
0 commit comments