
Commit 69fadcb

Authored by sfc-gh-aalam, sfc-gh-mkeller, sfc-gh-lqin, and github-actions
SNOW-678541: HTAP Driver Changes (#1416)
* SNOW-678541: snowflake-connector-python all integ tests successful (#1396)
  * SNOW-678541: snowflake-connector-python all integ tests successful
  * SNOW-678541: add explicit not None checks and cover _auth.py
  * SNOW-678541: add explicit not None checks
  * SNOW-678541: add htap tests
  * SNOW-678541: incorporate feedback
  * update import statement
  * Update test/integ/test_session_parameters.py

  Co-authored-by: Mark Keller <[email protected]>
  Co-authored-by: Mark Keller <[email protected]>

* SNOW-677788: query context caching for python connector (#1405)
  * SNOW-677788: added QueryContextCache class
  * SNOW-677788: add lock; remove unnecessary comments
  * SNOW-677788: integrate qcc with connection.py
  * SNOW-677788: add unit tests; support context to be None
  * SNOW-677788: add sortedcontainers requirement
  * SNOW-730503: use QueryContextCache only if pyarrow installed; minor refactor _tree_set
  * SNOW-677788: add disable query context cache

* SNOW-754627: python connector query context serialization/deserialization format from Apache Arrow to JSON (#1507)
  * record progress. unit testcases to be fixed
  * fix unit testcases. TODO: end-to-end test with GS
  * checked end-to-end test with GS especially for multi-database query
  * fix comments
  * add type error unit test case
  * deserialize from the python dict directly

* SNOW-754627: fix priority switch corner case error (#1532)
  * update implementation
  * add randomized test
  * fix comments

* Update requirements files
* make sync_priority_map private
* remove unused json import
* update license
* skip qcc tests for old driver
* fix old driver test
* fix old driver test attemp#3
* fix old driver test attemp#4
* address comments
* remove deepcopy and rename merge to insert
* lower bound sortedcontainers module
* remove pytest mark
* document the need to sync after add

---------

Co-authored-by: Mark Keller <[email protected]>
Co-authored-by: Lianke Qin <[email protected]>
Co-authored-by: github-actions <[email protected]>
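
As a quick orientation before the diffs, here is a minimal sketch (not part of the commit; values illustrative only) of the round-trip flow the new query context cache implements. In the connector itself, deserialize_json_dict() drives this exact sequence internally; calling the private helpers directly here just makes the ordering visible:

cache = QueryContextCache(capacity=5)

# A round of inserts (e.g. entries taken from a server response)...
cache.insert(0, 123456789, 0, "main context")
cache.insert(1, 123456789, 1, "child context")

# ...must be followed by a sync so the merged round becomes visible in the
# priority map, then a trim to enforce the capacity bound.
cache._sync_priority_map()
cache.trim_cache()

# Entries serialize to the JSON wire format that replaced the Arrow format:
# {"entries": [{"id": 0, "timestamp": 123456789, "priority": 0, "context": "main context"}, ...]}
print(cache.serialize_to_json())
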
1 parent d6a4749 commit 69fadcb

File tree

10 files changed: +817 -14 lines changed


setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ install_requires =
     certifi>=2017.4.17
     typing_extensions>=4.3,<5
     filelock>=3.5,<4
+    sortedcontainers>=2.4.0
 include_package_data = True
 package_dir =
     =src
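
The new sortedcontainers dependency provides the SortedSet that backs the cache added in the next file. A minimal sketch (assuming nothing beyond the library itself) of the two behaviors the cache relies on, sorted iteration and positional indexing, which make evicting the last element cheap:

from sortedcontainers import SortedSet

s = SortedSet([3, 1, 2])
assert list(s) == [1, 2, 3]  # members are kept in sorted order
assert s[-1] == 3            # indexable: the largest element sits at the end
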
Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from __future__ import annotations

import json
from functools import total_ordering
from hashlib import md5
from logging import getLogger
from threading import Lock
from typing import Any, Iterable

from sortedcontainers import SortedSet

logger = getLogger(__name__)


@total_ordering
class QueryContextElement:
    def __init__(
        self, id: int, read_timestamp: int, priority: int, context: str
    ) -> None:
        # the entry with id = 0 is the main entry
        self.id = id
        self.read_timestamp = read_timestamp
        # priority values are 0..N, with 0 being the highest priority
        self.priority = priority
        # The context field is base64-encoded by GS, but it is opaque to the
        # client side; the client should not decode or encode it and should
        # just store the raw data.
        self.context = context

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, QueryContextElement):
            return False
        return (
            self.id == other.id
            and self.read_timestamp == other.read_timestamp
            and self.priority == other.priority
            and self.context == other.context
        )

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, QueryContextElement):
            raise TypeError(
                f"cannot compare QueryContextElement with object of type {type(other)}"
            )
        return self.priority < other.priority

    def __hash__(self) -> int:
        _hash = 31

        _hash = _hash * 31 + self.id
        _hash += (_hash * 31) + self.read_timestamp
        _hash += (_hash * 31) + self.priority
        if self.context:
            _hash += (_hash * 31) + int.from_bytes(
                md5(self.context.encode("utf-8")).digest(), "big"
            )
        return _hash

    def __str__(self) -> str:
        return f"({self.id}, {self.read_timestamp}, {self.priority})"


class QueryContextCache:
    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._id_map: dict[int, QueryContextElement] = {}
        self._priority_map: dict[int, QueryContextElement] = {}
        self._intermediate_priority_map: dict[int, QueryContextElement] = {}

        # stores elements sorted by priority; the element with the
        # lowest priority value has the highest priority
        self._tree_set: set[QueryContextElement] = SortedSet()
        self._lock = Lock()
        self._data: str | None = None

    def _add_qce(self, qce: QueryContextElement) -> None:
        """Adds a qce to the tree_set, id_map, and intermediate_priority_map.

        We still need to call _sync_priority_map after all the new qce elements
        have been merged into the cache.
        """
        self._tree_set.add(qce)
        self._id_map[qce.id] = qce
        self._intermediate_priority_map[qce.priority] = qce

    def _remove_qce(self, qce: QueryContextElement) -> None:
        self._id_map.pop(qce.id)
        self._priority_map.pop(qce.priority)
        self._tree_set.remove(qce)

    def _replace_qce(
        self, old_qce: QueryContextElement, new_qce: QueryContextElement
    ) -> None:
        """A convenience function that calls a remove and an add operation back-to-back."""
        self._remove_qce(old_qce)
        self._add_qce(new_qce)

    def _sync_priority_map(self):
        """Sync the _intermediate_priority_map with the _priority_map at the end of the current round of inserts."""
        logger.debug(
            f"sync_priority_map called priority_map size = {len(self._priority_map)}, new_priority_map size = {len(self._intermediate_priority_map)}"
        )

        self._priority_map.update(self._intermediate_priority_map)
        # Clear the _intermediate_priority_map for the next round of QCC inserts
        # (a round consists of multiple entries).
        self._intermediate_priority_map.clear()

    def insert(self, id: int, read_timestamp: int, priority: int, context: str) -> None:
        if id in self._id_map:
            qce = self._id_map[id]
            if (read_timestamp > qce.read_timestamp) or (
                read_timestamp == qce.read_timestamp and priority != qce.priority
            ):
                # The id was found in the cache and we are operating on a more
                # recent timestamp. We do not update in place here.
                new_qce = QueryContextElement(id, read_timestamp, priority, context)
                self._replace_qce(qce, new_qce)
        else:
            new_qce = QueryContextElement(id, read_timestamp, priority, context)
            if priority in self._priority_map:
                old_qce = self._priority_map[priority]
                self._replace_qce(old_qce, new_qce)
            else:
                self._add_qce(new_qce)

    def trim_cache(self) -> None:
        logger.debug(
            f"trim_cache() called. treeSet size is {len(self._tree_set)} and cache capacity is {self.capacity}"
        )

        while len(self) > self.capacity:
            # remove the qce with the highest priority value => the element with the least priority
            qce = self._last()
            self._remove_qce(qce)

        logger.debug(
            f"trim_cache() returns. treeSet size is {len(self._tree_set)} and cache capacity is {self.capacity}"
        )

    def clear_cache(self) -> None:
        logger.debug("clear_cache() called")
        self._id_map.clear()
        self._priority_map.clear()
        self._tree_set.clear()
        self._intermediate_priority_map.clear()

    def _get_elements(self) -> Iterable[QueryContextElement]:
        return self._tree_set

    def _last(self) -> QueryContextElement:
        return self._tree_set[-1]

    def serialize_to_json(self) -> str:
        with self._lock:
            logger.debug("serialize_to_json() called")
            self.log_cache_entries()

            if len(self._tree_set) == 0:
                return ""

            try:
                data = {
                    "entries": [
                        {
                            "id": qce.id,
                            "timestamp": qce.read_timestamp,
                            "priority": qce.priority,
                            "context": qce.context,
                        }
                        for qce in self._tree_set
                    ]
                }
                # Serialize the data to JSON
                serialized_data = json.dumps(data)

                logger.debug(
                    f"serialize_to_json(): data to send to server {serialized_data}"
                )

                return serialized_data
            except Exception as e:
                logger.debug(f"serialize_to_json(): Exception {e}")
                return ""

    def deserialize_json_dict(self, data: dict) -> None:
        with self._lock:
            logger.debug(f"deserialize_json_dict() called: data from server: {data}")
            self.log_cache_entries()

            if data is None or len(data) == 0:
                self.clear_cache()
                logger.debug("deserialize_json_dict() returns")
                self.log_cache_entries()
                return

            try:
                # Deserialize the entries. The first entry, with priority 0, is
                # the main entry. On the Python connector side, we save all
                # entries in one list to simplify the logic. When the Python
                # connector receives the HTTP response, the data["queryContext"]
                # field has already been converted from JSON to a dict
                # automatically, so this function deserializes from the Python
                # dict directly. Below is an example QueryContext dict.
                # {
                #     "entries": [
                #         {
                #             "id": 0,
                #             "timestamp": 123456789,
                #             "priority": 0,
                #             "context": "base64 encoded context"
                #         },
                #         {
                #             "id": 1,
                #             "timestamp": 123456789,
                #             "priority": 1,
                #             "context": "base64 encoded context"
                #         },
                #         {
                #             "id": 2,
                #             "timestamp": 123456789,
                #             "priority": 2,
                #             "context": "base64 encoded context"
                #         }
                #     ]
                # }

                # Deserialize entries
                entries = data.get("entries", list())
                for entry in entries:
                    logger.debug(f"deserialize {entry}")
                    if not isinstance(entry.get("id"), int):
                        logger.debug("id type error")
                        raise TypeError(
                            f"Invalid type for 'id' field: Expected int, got {type(entry['id'])}"
                        )
                    if not isinstance(entry.get("timestamp"), int):
                        logger.debug("timestamp type error")
                        raise TypeError(
                            f"Invalid type for 'timestamp' field: Expected int, got {type(entry['timestamp'])}"
                        )
                    if not isinstance(entry.get("priority"), int):
                        logger.debug("priority type error")
                        raise TypeError(
                            f"Invalid type for 'priority' field: Expected int, got {type(entry['priority'])}"
                        )

                    # The OpaqueContext field is currently empty from the GS side.
                    context = entry.get("context", None)
                    if context and not isinstance(context, str):
                        logger.debug("context type error")
                        raise TypeError(
                            f"Invalid type for 'context' field: Expected str, got {type(entry['context'])}"
                        )
                    self.insert(
                        entry.get("id"),
                        entry.get("timestamp"),
                        entry.get("priority"),
                        context,
                    )

                # Sync the priority map at the end of the insert loop.
                self._sync_priority_map()
            except Exception as e:
                logger.debug(f"deserialize_json_dict: Exception = {e}")
                # clear the cache due to an incomplete insert
                self.clear_cache()

            self.trim_cache()
            logger.debug("deserialize_json_dict() returns")
            self.log_cache_entries()

    def log_cache_entries(self) -> None:
        for qce in self._tree_set:
            logger.debug(f"Cache Entry: {str(qce)}")

    def __len__(self) -> int:
        return len(self._tree_set)
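
The interplay of _intermediate_priority_map and _sync_priority_map above is what addresses the priority switch corner case called out in the commit message (SNOW-754627, #1532). A hedged sketch of the kind of round it guards against, with illustrative values: if two cached entries swap priorities within a single server response, writing each insert straight into _priority_map could clobber a slot still owned by the other live entry mid-round; staging the round in the intermediate map and applying it once at sync time avoids that:

cache = QueryContextCache(capacity=5)

# Initial round: two entries at priorities 0 and 1.
cache.insert(0, 100, 0, "ctx-a")
cache.insert(1, 100, 1, "ctx-b")
cache._sync_priority_map()

# Next round swaps the two priorities. Each insert lands in
# _intermediate_priority_map, so moving entry 0 onto priority 1 does not
# overwrite the _priority_map slot still owned by entry 1 mid-round.
cache.insert(0, 200, 1, "ctx-a")
cache.insert(1, 200, 0, "ctx-b")
cache._sync_priority_map()  # the swap becomes visible only here
cache.trim_cache()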
