1515from .._grpc .grpcwrapper .ydb_topic_public_types import PublicCodec
1616from .. import connection
1717
18- Message = typing .Union ["PublicMessage" , "PublicMessage.SimpleMessageSourceType " ]
18+ Message = typing .Union ["PublicMessage" , "PublicMessage.SimpleSourceType " ]
1919
2020
2121@dataclass
@@ -91,20 +91,23 @@ class PublicWriterInitInfo:
9191class PublicMessage :
9292 seqno : Optional [int ]
9393 created_at : Optional [datetime .datetime ]
94- data : "PublicMessage.SimpleMessageSourceType"
94+ data : "PublicMessage.SimpleSourceType"
95+ metadata_items : Optional [Dict [str , "PublicMessage.SimpleSourceType" ]]
9596
96- SimpleMessageSourceType = Union [str , bytes ] # Will be extend
97+ SimpleSourceType = Union [str , bytes ] # Will be extend
9798
9899 def __init__ (
99100 self ,
100- data : SimpleMessageSourceType ,
101+ data : SimpleSourceType ,
101102 * ,
103+ metadata_items : Optional [Dict [str , "PublicMessage.SimpleSourceType" ]] = None ,
102104 seqno : Optional [int ] = None ,
103105 created_at : Optional [datetime .datetime ] = None ,
104106 ):
105107 self .seqno = seqno
106108 self .created_at = created_at
107109 self .data = data
110+ self .metadata_items = metadata_items
108111
109112 @staticmethod
110113 def _create_message (data : Message ) -> "PublicMessage" :
@@ -121,26 +124,29 @@ def __init__(self, mess: PublicMessage):
121124 seq_no = mess .seqno ,
122125 created_at = mess .created_at ,
123126 data = mess .data ,
127+ metadata_items = mess .metadata_items ,
124128 uncompressed_size = len (mess .data ),
125129 partitioning = None ,
126130 )
127131 self .codec = PublicCodec .RAW
128132
129- def get_bytes (self ) -> bytes :
130- if self . data is None :
133+ def get_bytes (self , obj : Optional [ PublicMessage . SimpleSourceType ] ) -> bytes :
134+ if obj is None :
131135 return bytes ()
132- if isinstance (self . data , bytes ):
133- return self . data
134- if isinstance (self . data , str ):
135- return self . data .encode ("utf-8" )
136+ if isinstance (obj , bytes ):
137+ return obj
138+ if isinstance (obj , str ):
139+ return obj .encode ("utf-8" )
136140 raise ValueError ("Bad data type" )
137141
138142 def to_message_data (self ) -> StreamWriteMessage .WriteRequest .MessageData :
139- data = self .get_bytes ()
143+ data = self .get_bytes (self .data )
144+ metadata_items = {key : self .get_bytes (value ) for key , value in self .metadata_items .items ()}
140145 return StreamWriteMessage .WriteRequest .MessageData (
141146 seq_no = self .seq_no ,
142147 created_at = self .created_at ,
143148 data = data ,
149+ metadata_items = metadata_items ,
144150 uncompressed_size = len (data ),
145151 partitioning = None , # unsupported by server now
146152 )
@@ -221,6 +227,7 @@ def messages_to_proto_requests(
221227 seq_no = _max_int ,
222228 created_at = datetime .datetime (3000 , 1 , 1 , 1 , 1 , 1 , 1 ),
223229 data = bytes (1 ),
230+ metadata_items = {},
224231 uncompressed_size = _max_int ,
225232 partitioning = StreamWriteMessage .PartitioningMessageGroupID (
226233 message_group_id = "a" * 100 ,
0 commit comments