Skip to content

Commit 9e3a5fd

Browse files
committed
response: Add decode methods for all types
1 parent 077b7ff commit 9e3a5fd

File tree

6 files changed

+606
-53
lines changed

6 files changed

+606
-53
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
env
22
dist
3-
**/__pycache__
3+
**/__pycache__
4+
mock*
5+
.idea

src/skytable_py/connection.py

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from asyncio import StreamReader, StreamWriter
1717
from .query import Query
18-
from .exception import ProtocolException
18+
from .protocol import Protocol
1919

2020

2121
class Connection:
@@ -26,8 +26,7 @@ class Connection:
2626
def __init__(self, reader: StreamReader, writer: StreamWriter) -> None:
2727
self._reader = reader
2828
self._writer = writer
29-
self._cursor = 0
30-
self.buffer = bytes()
29+
self._protocol = Protocol()
3130

3231
async def _write_all(self, bytes: bytes):
3332
self._write(bytes)
@@ -52,59 +51,11 @@ async def close(self):
5251
self._writer.close()
5352
await self._writer.wait_closed()
5453

55-
def __parse_string(self) -> None | str:
56-
strlen = self.__parse_int()
57-
if strlen:
58-
if len(self.__buffer()) >= strlen:
59-
strlen = self.__buffer()[:strlen].decode()
60-
self._cursor += strlen
61-
return strlen
62-
63-
def __parse_binary(self) -> None | bytes:
64-
binlen = self.__parse_int()
65-
if binlen:
66-
if len(self.__buffer()) >= binlen:
67-
binlen = self.__buffer()[:binlen].decode()
68-
self._cursor += binlen
69-
return binlen
70-
71-
def __parse_int(self) -> None | int:
72-
i = 0
73-
strlen = 0
74-
stop = False
75-
buffer = self.__buffer()
76-
77-
while i < len(buffer) and not stop:
78-
digit = None
79-
if 48 <= buffer[i] <= 57:
80-
digit = buffer[i] - 48
81-
82-
if digit is not None:
83-
strlen = (10 * strlen) + digit
84-
i += 1
85-
else:
86-
raise ProtocolException("invalid response from server")
87-
88-
if i < len(buffer) and buffer[i] == ord(b'\n'):
89-
stop = True
90-
i += 1
91-
92-
if stop:
93-
self._cursor += i
94-
self._cursor += 1 # for LF
95-
return strlen
96-
9754
async def run_simple_query(self, query: Query):
98-
query_window_str = str(len(query._q_window))
55+
query_window_str = str(query._q_window)
9956
total_packet_size = len(query_window_str) + 1 + len(query._buffer)
10057
# write metaframe
10158
metaframe = f"S{str(total_packet_size)}\n{query_window_str}\n"
10259
await self._write_all(metaframe.encode())
10360
# write dataframe
10461
await self._write_all(query._buffer)
105-
# now enter read loop
106-
while True:
107-
read = await self._reader.read(1024)
108-
if len(read) == 0:
109-
raise ConnectionResetError
110-
self.buffer = self.buffer + read

src/skytable_py/protocol.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
# Copyright 2024, Sayan Nandan <[email protected]>
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Union
17+
from .exception import ProtocolException
18+
from .response import Value, UInt8, UInt16, UInt32, UInt64, SInt8, SInt16, SInt32, SInt64, Float32, Float64, Empty, \
19+
ErrorCode, Row
20+
21+
22+
class Protocol:
23+
def __init__(self, buffer=bytes()) -> None:
24+
self._buffer = buffer
25+
self._cursor = 0
26+
27+
def push_additional_bytes(self, additional_bytes: bytes) -> None:
28+
self._buffer = self._buffer + additional_bytes
29+
30+
def __step(self) -> int:
31+
ret = self.__buf()[0]
32+
self.__increment_cursor()
33+
return ret
34+
35+
def __decrement(self) -> None:
36+
self._cursor -= 1
37+
38+
def __increment_cursor_by(self, by: int) -> None:
39+
self._cursor += by
40+
41+
def __increment_cursor(self) -> None:
42+
self.__increment_cursor_by(1)
43+
44+
def __buf(self) -> bytes:
45+
return self._buffer[self._cursor:]
46+
47+
def __remaining(self) -> int:
48+
return len(self.__buf())
49+
50+
def __is_eof(self) -> bool:
51+
return self.__remaining() == 0
52+
53+
def parse_next_int(self, stop_symbol='\n') -> Union[None, int]:
54+
i = 0
55+
integer = 0
56+
stop = False
57+
buffer = self.__buf()
58+
59+
while i < len(buffer) and not stop:
60+
digit = None
61+
if 48 <= buffer[i] <= 57:
62+
digit = buffer[i] - 48
63+
64+
if digit is not None:
65+
integer = (10 * integer) + digit
66+
i += 1
67+
else:
68+
raise ProtocolException("invalid response from server")
69+
70+
if i < len(buffer) and buffer[i] == ord(stop_symbol):
71+
stop = True
72+
73+
if stop:
74+
self.__increment_cursor_by(i)
75+
self.__increment_cursor() # for LF
76+
return integer
77+
78+
def parse_next_string(self) -> Union[None, Value]:
79+
strlen = self.parse_next_int()
80+
if strlen:
81+
if self.__remaining() >= strlen:
82+
string = self.__buf()[:strlen].decode()
83+
self.__increment_cursor_by(strlen)
84+
return Value(string)
85+
86+
def parse_next_binary(self) -> Union[None, Value]:
87+
binlen = self.parse_next_int()
88+
if binlen:
89+
if self.__remaining() >= binlen:
90+
blob = self.__buf()[:binlen]
91+
self.__increment_cursor_by(binlen)
92+
return Value(blob)
93+
94+
def parse_boolean(self) -> Union[None, Value]:
95+
# boolean
96+
if self.__is_eof():
97+
self.__decrement() # move back to type symbol
98+
return None
99+
else:
100+
byte = self.__step()
101+
if byte > 1:
102+
raise ProtocolException("received invalid data")
103+
return Value(True) if byte == 1 else Value(False)
104+
105+
def parse_uint(self, type_symbol: int) -> Union[None, Value]:
106+
# uint
107+
integer = self.parse_next_int()
108+
if integer:
109+
if type_symbol == 2:
110+
return Value(UInt8(integer))
111+
elif type_symbol == 3:
112+
return Value(UInt16(integer))
113+
elif type_symbol == 4:
114+
return Value(UInt32(integer))
115+
else:
116+
return Value(UInt64(integer))
117+
else:
118+
self.__decrement() # move back to type symbol
119+
120+
def parse_sint(self, type_symbol: int) -> Union[None, Value]:
121+
# sint
122+
if self.__is_eof():
123+
self.__decrement() # move back to type symbol
124+
return None
125+
is_negative = False
126+
if self.__step() == ord('-'):
127+
is_negative = True
128+
else:
129+
self.__decrement() # move back to integer starting position since there is no '-'
130+
integer = self.parse_next_int()
131+
if integer:
132+
if is_negative:
133+
integer = -integer
134+
if type_symbol == 6:
135+
return Value(SInt8(integer))
136+
elif type_symbol == 7:
137+
return Value(SInt16(integer))
138+
elif type_symbol == 8:
139+
return Value(SInt32(integer))
140+
else:
141+
return Value(SInt64(integer))
142+
else:
143+
self.__decrement() # move back to type symbol
144+
if is_negative:
145+
self.__decrement() # move back to starting position of this integer
146+
147+
def parse_float(self, type_symbol: int) -> Union[None, Value]:
148+
whole = self.parse_next_int(stop_symbol='.')
149+
if whole:
150+
decimal = self.parse_next_int()
151+
if decimal:
152+
full_float = float(f"{whole}.{decimal}")
153+
if type_symbol == 10:
154+
return Value(Float32(full_float))
155+
else:
156+
return Value(Float64(full_float))
157+
self.__decrement() # type symbol
158+
159+
def parse_error_code(self) -> Union[None, ErrorCode]:
160+
if self.__remaining() < 2:
161+
self.__decrement() # type symbol
162+
else:
163+
a, b = self.__buf()
164+
self.__increment_cursor_by(2)
165+
return ErrorCode(int.from_bytes([a, b], byteorder="little", signed=False))
166+
167+
def parse_list(self) -> Union[None, Value]:
168+
cursor_start = self._cursor - 1
169+
list_len = self.parse_next_int()
170+
if list_len is None:
171+
self.__decrement() # type symbol
172+
return None
173+
items = []
174+
while len(items) != list_len:
175+
element = self.parse_next_element()
176+
if element:
177+
items.append(element)
178+
else:
179+
self._cursor = cursor_start
180+
return None
181+
return Value(items)
182+
183+
def parse_row(self) -> Union[None, Row]:
184+
cursor_start = self._cursor - 1
185+
column_count = self.parse_next_int()
186+
if column_count is None:
187+
self.__decrement() # type symbol
188+
return None
189+
columns = []
190+
while len(columns) != column_count:
191+
column = self.parse_next_element()
192+
if column:
193+
columns.append(column)
194+
else:
195+
self._cursor = cursor_start
196+
return None
197+
return Row(columns)
198+
199+
def parse_rows(self) -> Union[None, list[Row]]:
200+
cursor_start = self._cursor - 1
201+
row_count = self.parse_next_int()
202+
rows = []
203+
while len(rows) != row_count:
204+
row = self.parse_row()
205+
if row:
206+
rows.append(row)
207+
else:
208+
self._cursor = cursor_start
209+
return None
210+
return rows
211+
212+
def parse_next_element(self) -> Union[None, Value, Empty, ErrorCode]:
213+
if self.__is_eof():
214+
return None
215+
type_symbol = self.__step()
216+
if type_symbol == 0:
217+
# null
218+
return Value(None)
219+
elif type_symbol == 1:
220+
return self.parse_boolean()
221+
elif 2 <= type_symbol <= 5:
222+
return self.parse_uint(type_symbol)
223+
elif 6 <= type_symbol <= 9:
224+
return self.parse_sint(type_symbol)
225+
elif 10 <= type_symbol <= 11:
226+
return self.parse_float(type_symbol)
227+
elif type_symbol == 12:
228+
return self.parse_next_binary()
229+
elif type_symbol == 13:
230+
return self.parse_next_string()
231+
elif type_symbol == 14:
232+
return self.parse_list()
233+
elif type_symbol == 15:
234+
raise ProtocolException("dictionaries are not supported yet")
235+
elif type_symbol == 16:
236+
return self.parse_error_code()
237+
elif type_symbol == 17:
238+
return self.parse_row()
239+
elif type_symbol == 18:
240+
return Empty()
241+
elif type_symbol == 19:
242+
return self.parse_rows()
243+
else:
244+
raise ProtocolException(
245+
f"unknown type with code {type_symbol} sent by server")

0 commit comments

Comments
 (0)