|
| 1 | +# Copyright 2024, Sayan Nandan <[email protected]> |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +from typing import Union |
| 17 | +from .exception import ProtocolException |
| 18 | +from .response import Value, UInt8, UInt16, UInt32, UInt64, SInt8, SInt16, SInt32, SInt64, Float32, Float64, Empty, \ |
| 19 | + ErrorCode, Row |
| 20 | + |
| 21 | + |
| 22 | +class Protocol: |
| 23 | + def __init__(self, buffer=bytes()) -> None: |
| 24 | + self._buffer = buffer |
| 25 | + self._cursor = 0 |
| 26 | + |
| 27 | + def push_additional_bytes(self, additional_bytes: bytes) -> None: |
| 28 | + self._buffer = self._buffer + additional_bytes |
| 29 | + |
| 30 | + def __step(self) -> int: |
| 31 | + ret = self.__buf()[0] |
| 32 | + self.__increment_cursor() |
| 33 | + return ret |
| 34 | + |
| 35 | + def __decrement(self) -> None: |
| 36 | + self._cursor -= 1 |
| 37 | + |
| 38 | + def __increment_cursor_by(self, by: int) -> None: |
| 39 | + self._cursor += by |
| 40 | + |
| 41 | + def __increment_cursor(self) -> None: |
| 42 | + self.__increment_cursor_by(1) |
| 43 | + |
| 44 | + def __buf(self) -> bytes: |
| 45 | + return self._buffer[self._cursor:] |
| 46 | + |
| 47 | + def __remaining(self) -> int: |
| 48 | + return len(self.__buf()) |
| 49 | + |
| 50 | + def __is_eof(self) -> bool: |
| 51 | + return self.__remaining() == 0 |
| 52 | + |
| 53 | + def parse_next_int(self, stop_symbol='\n') -> Union[None, int]: |
| 54 | + i = 0 |
| 55 | + integer = 0 |
| 56 | + stop = False |
| 57 | + buffer = self.__buf() |
| 58 | + |
| 59 | + while i < len(buffer) and not stop: |
| 60 | + digit = None |
| 61 | + if 48 <= buffer[i] <= 57: |
| 62 | + digit = buffer[i] - 48 |
| 63 | + |
| 64 | + if digit is not None: |
| 65 | + integer = (10 * integer) + digit |
| 66 | + i += 1 |
| 67 | + else: |
| 68 | + raise ProtocolException("invalid response from server") |
| 69 | + |
| 70 | + if i < len(buffer) and buffer[i] == ord(stop_symbol): |
| 71 | + stop = True |
| 72 | + |
| 73 | + if stop: |
| 74 | + self.__increment_cursor_by(i) |
| 75 | + self.__increment_cursor() # for LF |
| 76 | + return integer |
| 77 | + |
| 78 | + def parse_next_string(self) -> Union[None, Value]: |
| 79 | + strlen = self.parse_next_int() |
| 80 | + if strlen: |
| 81 | + if self.__remaining() >= strlen: |
| 82 | + string = self.__buf()[:strlen].decode() |
| 83 | + self.__increment_cursor_by(strlen) |
| 84 | + return Value(string) |
| 85 | + |
| 86 | + def parse_next_binary(self) -> Union[None, Value]: |
| 87 | + binlen = self.parse_next_int() |
| 88 | + if binlen: |
| 89 | + if self.__remaining() >= binlen: |
| 90 | + blob = self.__buf()[:binlen] |
| 91 | + self.__increment_cursor_by(binlen) |
| 92 | + return Value(blob) |
| 93 | + |
| 94 | + def parse_boolean(self) -> Union[None, Value]: |
| 95 | + # boolean |
| 96 | + if self.__is_eof(): |
| 97 | + self.__decrement() # move back to type symbol |
| 98 | + return None |
| 99 | + else: |
| 100 | + byte = self.__step() |
| 101 | + if byte > 1: |
| 102 | + raise ProtocolException("received invalid data") |
| 103 | + return Value(True) if byte == 1 else Value(False) |
| 104 | + |
| 105 | + def parse_uint(self, type_symbol: int) -> Union[None, Value]: |
| 106 | + # uint |
| 107 | + integer = self.parse_next_int() |
| 108 | + if integer: |
| 109 | + if type_symbol == 2: |
| 110 | + return Value(UInt8(integer)) |
| 111 | + elif type_symbol == 3: |
| 112 | + return Value(UInt16(integer)) |
| 113 | + elif type_symbol == 4: |
| 114 | + return Value(UInt32(integer)) |
| 115 | + else: |
| 116 | + return Value(UInt64(integer)) |
| 117 | + else: |
| 118 | + self.__decrement() # move back to type symbol |
| 119 | + |
| 120 | + def parse_sint(self, type_symbol: int) -> Union[None, Value]: |
| 121 | + # sint |
| 122 | + if self.__is_eof(): |
| 123 | + self.__decrement() # move back to type symbol |
| 124 | + return None |
| 125 | + is_negative = False |
| 126 | + if self.__step() == ord('-'): |
| 127 | + is_negative = True |
| 128 | + else: |
| 129 | + self.__decrement() # move back to integer starting position since there is no '-' |
| 130 | + integer = self.parse_next_int() |
| 131 | + if integer: |
| 132 | + if is_negative: |
| 133 | + integer = -integer |
| 134 | + if type_symbol == 6: |
| 135 | + return Value(SInt8(integer)) |
| 136 | + elif type_symbol == 7: |
| 137 | + return Value(SInt16(integer)) |
| 138 | + elif type_symbol == 8: |
| 139 | + return Value(SInt32(integer)) |
| 140 | + else: |
| 141 | + return Value(SInt64(integer)) |
| 142 | + else: |
| 143 | + self.__decrement() # move back to type symbol |
| 144 | + if is_negative: |
| 145 | + self.__decrement() # move back to starting position of this integer |
| 146 | + |
| 147 | + def parse_float(self, type_symbol: int) -> Union[None, Value]: |
| 148 | + whole = self.parse_next_int(stop_symbol='.') |
| 149 | + if whole: |
| 150 | + decimal = self.parse_next_int() |
| 151 | + if decimal: |
| 152 | + full_float = float(f"{whole}.{decimal}") |
| 153 | + if type_symbol == 10: |
| 154 | + return Value(Float32(full_float)) |
| 155 | + else: |
| 156 | + return Value(Float64(full_float)) |
| 157 | + self.__decrement() # type symbol |
| 158 | + |
| 159 | + def parse_error_code(self) -> Union[None, ErrorCode]: |
| 160 | + if self.__remaining() < 2: |
| 161 | + self.__decrement() # type symbol |
| 162 | + else: |
| 163 | + a, b = self.__buf() |
| 164 | + self.__increment_cursor_by(2) |
| 165 | + return ErrorCode(int.from_bytes([a, b], byteorder="little", signed=False)) |
| 166 | + |
| 167 | + def parse_list(self) -> Union[None, Value]: |
| 168 | + cursor_start = self._cursor - 1 |
| 169 | + list_len = self.parse_next_int() |
| 170 | + if list_len is None: |
| 171 | + self.__decrement() # type symbol |
| 172 | + return None |
| 173 | + items = [] |
| 174 | + while len(items) != list_len: |
| 175 | + element = self.parse_next_element() |
| 176 | + if element: |
| 177 | + items.append(element) |
| 178 | + else: |
| 179 | + self._cursor = cursor_start |
| 180 | + return None |
| 181 | + return Value(items) |
| 182 | + |
| 183 | + def parse_row(self) -> Union[None, Row]: |
| 184 | + cursor_start = self._cursor - 1 |
| 185 | + column_count = self.parse_next_int() |
| 186 | + if column_count is None: |
| 187 | + self.__decrement() # type symbol |
| 188 | + return None |
| 189 | + columns = [] |
| 190 | + while len(columns) != column_count: |
| 191 | + column = self.parse_next_element() |
| 192 | + if column: |
| 193 | + columns.append(column) |
| 194 | + else: |
| 195 | + self._cursor = cursor_start |
| 196 | + return None |
| 197 | + return Row(columns) |
| 198 | + |
| 199 | + def parse_rows(self) -> Union[None, list[Row]]: |
| 200 | + cursor_start = self._cursor - 1 |
| 201 | + row_count = self.parse_next_int() |
| 202 | + rows = [] |
| 203 | + while len(rows) != row_count: |
| 204 | + row = self.parse_row() |
| 205 | + if row: |
| 206 | + rows.append(row) |
| 207 | + else: |
| 208 | + self._cursor = cursor_start |
| 209 | + return None |
| 210 | + return rows |
| 211 | + |
| 212 | + def parse_next_element(self) -> Union[None, Value, Empty, ErrorCode]: |
| 213 | + if self.__is_eof(): |
| 214 | + return None |
| 215 | + type_symbol = self.__step() |
| 216 | + if type_symbol == 0: |
| 217 | + # null |
| 218 | + return Value(None) |
| 219 | + elif type_symbol == 1: |
| 220 | + return self.parse_boolean() |
| 221 | + elif 2 <= type_symbol <= 5: |
| 222 | + return self.parse_uint(type_symbol) |
| 223 | + elif 6 <= type_symbol <= 9: |
| 224 | + return self.parse_sint(type_symbol) |
| 225 | + elif 10 <= type_symbol <= 11: |
| 226 | + return self.parse_float(type_symbol) |
| 227 | + elif type_symbol == 12: |
| 228 | + return self.parse_next_binary() |
| 229 | + elif type_symbol == 13: |
| 230 | + return self.parse_next_string() |
| 231 | + elif type_symbol == 14: |
| 232 | + return self.parse_list() |
| 233 | + elif type_symbol == 15: |
| 234 | + raise ProtocolException("dictionaries are not supported yet") |
| 235 | + elif type_symbol == 16: |
| 236 | + return self.parse_error_code() |
| 237 | + elif type_symbol == 17: |
| 238 | + return self.parse_row() |
| 239 | + elif type_symbol == 18: |
| 240 | + return Empty() |
| 241 | + elif type_symbol == 19: |
| 242 | + return self.parse_rows() |
| 243 | + else: |
| 244 | + raise ProtocolException( |
| 245 | + f"unknown type with code {type_symbol} sent by server") |
0 commit comments