1515from .._grpc .grpcwrapper .ydb_topic_public_types import PublicCodec
1616from .. import connection
1717
18- Message = typing .Union ["PublicMessage" , "PublicMessage.SimpleMessageSourceType " ]
18+ Message = typing .Union ["PublicMessage" , "PublicMessage.SimpleSourceType " ]
1919
2020
2121@dataclass
@@ -91,20 +91,23 @@ class PublicWriterInitInfo:
9191class PublicMessage :
9292 seqno : Optional [int ]
9393 created_at : Optional [datetime .datetime ]
94- data : "PublicMessage.SimpleMessageSourceType"
94+ data : "PublicMessage.SimpleSourceType"
95+ metadata_items : Optional [Dict [str , "PublicMessage.SimpleSourceType" ]]
9596
96- SimpleMessageSourceType = Union [str , bytes ] # Will be extend
97+ SimpleSourceType = Union [str , bytes ] # Will be extend
9798
9899 def __init__ (
99100 self ,
100- data : SimpleMessageSourceType ,
101+ data : SimpleSourceType ,
101102 * ,
103+ metadata_items : Optional [Dict [str , "PublicMessage.SimpleSourceType" ]] = None ,
102104 seqno : Optional [int ] = None ,
103105 created_at : Optional [datetime .datetime ] = None ,
104106 ):
105107 self .seqno = seqno
106108 self .created_at = created_at
107109 self .data = data
110+ self .metadata_items = metadata_items
108111
109112 @staticmethod
110113 def _create_message (data : Message ) -> "PublicMessage" :
@@ -121,26 +124,32 @@ def __init__(self, mess: PublicMessage):
121124 seq_no = mess .seqno ,
122125 created_at = mess .created_at ,
123126 data = mess .data ,
127+ metadata_items = mess .metadata_items ,
124128 uncompressed_size = len (mess .data ),
125129 partitioning = None ,
126130 )
127131 self .codec = PublicCodec .RAW
128132
129- def get_bytes (self ) -> bytes :
130- if self . data is None :
133+ def _get_bytes (self , obj : Optional [ PublicMessage . SimpleSourceType ] ) -> bytes :
134+ if obj is None :
131135 return bytes ()
132- if isinstance (self . data , bytes ):
133- return self . data
134- if isinstance (self . data , str ):
135- return self . data .encode ("utf-8" )
136+ if isinstance (obj , bytes ):
137+ return obj
138+ if isinstance (obj , str ):
139+ return obj .encode ("utf-8" )
136140 raise ValueError ("Bad data type" )
137141
142+ def get_data_bytes (self ) -> bytes :
143+ return self ._get_bytes (self .data )
144+
138145 def to_message_data (self ) -> StreamWriteMessage .WriteRequest .MessageData :
139- data = self .get_bytes ()
146+ data = self .get_data_bytes ()
147+ metadata_items = {key : self ._get_bytes (value ) for key , value in self .metadata_items .items ()}
140148 return StreamWriteMessage .WriteRequest .MessageData (
141149 seq_no = self .seq_no ,
142150 created_at = self .created_at ,
143151 data = data ,
152+ metadata_items = metadata_items ,
144153 uncompressed_size = len (data ),
145154 partitioning = None , # unsupported by server now
146155 )
@@ -221,6 +230,7 @@ def messages_to_proto_requests(
221230 seq_no = _max_int ,
222231 created_at = datetime .datetime (3000 , 1 , 1 , 1 , 1 , 1 , 1 ),
223232 data = bytes (1 ),
233+ metadata_items = {},
224234 uncompressed_size = _max_int ,
225235 partitioning = StreamWriteMessage .PartitioningMessageGroupID (
226236 message_group_id = "a" * 100 ,
0 commit comments