1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import asyncio
1716from asyncio import StreamReader , StreamWriter
18-
19-
20- class ClientException (Exception ):
21- """
22- An exception thrown by this client library
23- """
24- pass
17+ from .query import Query
18+ from .exception import ProtocolException
2519
2620
2721class Connection :
@@ -32,6 +26,8 @@ class Connection:
3226 def __init__ (self , reader : StreamReader , writer : StreamWriter ) -> None :
3327 self ._reader = reader
3428 self ._writer = writer
29+ self ._cursor = 0
30+ self .buffer = bytes ()
3531
3632 async def _write_all (self , bytes : bytes ):
3733 self ._write (bytes )
@@ -40,6 +36,9 @@ async def _write_all(self, bytes: bytes):
4036 def _write (self , bytes : bytes ) -> None :
4137 self ._writer .write (bytes )
4238
39+ def __buffer (self ) -> bytes :
40+ return self .buffer [:self ._cursor ]
41+
4342 async def _flush (self ):
4443 await self ._writer .drain ()
4544
@@ -53,46 +52,59 @@ async def close(self):
5352 self ._writer .close ()
5453 await self ._writer .wait_closed ()
5554
56-
57- class Config :
58- def __init__ (self , username : str , password : str , host : str = "127.0.0.1" , port : int = 2003 ) -> None :
59- self ._username = username
60- self ._password = password
61- self ._host = host
62- self ._port = port
63-
64- def get_username (self ) -> str :
65- return self ._username
66-
67- def get_password (self ) -> str :
68- return self ._password
69-
70- def get_host (self ) -> str :
71- return self ._host
72-
73- def get_port (self ) -> int :
74- return self ._port
75-
76- def __hs (self ) -> bytes :
77- return f"H\0 \0 \0 \0 \0 { len (self .get_username ())} \n { len (self .get_password ())} \n { self .get_username ()} { self .get_password ()} " .encode ()
78-
79- async def connect (self ) -> Connection :
80- """
81- Establish a connection to the database instance using the set configuration.
82-
83- ## Exceptions
84- Exceptions are raised in the following scenarios:
85- - If the server responds with a handshake error
86- - If the server sends an unknown handshake (usually caused by version incompatibility)
87- """
88- reader , writer = await asyncio .open_connection (self .get_host (), self .get_port ())
89- con = Connection (reader , writer )
90- await con ._write_all (self .__hs ())
91- resp = await con ._read_exact (4 )
92- a , b , c , d = resp [0 ], resp [1 ], resp [2 ], resp [3 ]
93- if resp == b"H\0 \0 \0 " :
94- return con
95- elif a == ord (b'H' ) and b == 0 and c == 1 :
96- raise ClientException (f"handshake error { d } " )
97- else :
98- raise ClientException ("unknown handshake" )
55+ def __parse_string (self ) -> None | str :
56+ strlen = self .__parse_int ()
57+ if strlen :
58+ if len (self .__buffer ()) >= strlen :
59+ strlen = self .__buffer ()[:strlen ].decode ()
60+ self ._cursor += strlen
61+ return strlen
62+
63+ def __parse_binary (self ) -> None | bytes :
64+ binlen = self .__parse_int ()
65+ if binlen :
66+ if len (self .__buffer ()) >= binlen :
67+ binlen = self .__buffer ()[:binlen ].decode ()
68+ self ._cursor += binlen
69+ return binlen
70+
71+ def __parse_int (self ) -> None | int :
72+ i = 0
73+ strlen = 0
74+ stop = False
75+ buffer = self .__buffer ()
76+
77+ while i < len (buffer ) and not stop :
78+ digit = None
79+ if 48 <= buffer [i ] <= 57 :
80+ digit = buffer [i ] - 48
81+
82+ if digit is not None :
83+ strlen = (10 * strlen ) + digit
84+ i += 1
85+ else :
86+ raise ProtocolException ("invalid response from server" )
87+
88+ if i < len (buffer ) and buffer [i ] == ord (b'\n ' ):
89+ stop = True
90+ i += 1
91+
92+ if stop :
93+ self ._cursor += i
94+ self ._cursor += 1 # for LF
95+ return strlen
96+
97+ async def run_simple_query (self , query : Query ):
98+ query_window_str = str (len (query ._q_window ))
99+ total_packet_size = len (query_window_str ) + 1 + len (query ._buffer )
100+ # write metaframe
101+ metaframe = f"S{ str (total_packet_size )} \n { query_window_str } \n "
102+ await self ._write_all (metaframe .encode ())
103+ # write dataframe
104+ await self ._write_all (query ._buffer )
105+ # now enter read loop
106+ while True :
107+ read = await self ._reader .read (1024 )
108+ if len (read ) == 0 :
109+ raise ConnectionResetError
110+ self .buffer = self .buffer + read
0 commit comments