Skip to content

Commit b41a388

Browse files
authored
Ensure custom types support pickling
Before this change `NamedRowTuple` did not roundtrip successfully. This also adds tests for all custom types.
1 parent fe1022b commit b41a388

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

tests/unit/test_types.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import pickle
14+
from datetime import datetime, time
15+
from decimal import Decimal
16+
17+
import pytest
18+
19+
from trino import types
20+
21+
22+
def identity(x):
23+
return x
24+
25+
26+
type_instances = [
27+
(types.Time(time(11, 47, 23), Decimal(0.314)), lambda v: v.to_python_type()),
28+
(types.TimeWithTimeZone(time(11, 47, 23), Decimal(0.314)), lambda v: v.to_python_type()),
29+
(types.Timestamp(datetime(2024, 10, 15, 11, 47, 23), Decimal(0.314)), lambda v: v.to_python_type()),
30+
(types.TimestampWithTimeZone(datetime(2024, 10, 15, 11, 47, 23), Decimal(0.314)), lambda v: v.to_python_type()),
31+
(types.NamedRowTuple(["Alice", 38], ["name", "age"], ["varchar", "integer"]), identity),
32+
]
33+
34+
35+
@pytest.mark.parametrize("value,fn", type_instances)
36+
def test_pickle_roundtripping(value, fn):
37+
bytes = pickle.dumps(value)
38+
unpickled_value = pickle.loads(bytes)
39+
assert fn(value) == fn(unpickled_value)

trino/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,14 @@ def __getattr__(self, name: str) -> Any:
111111
if self._names.count(name):
112112
raise ValueError("Ambiguous row field reference: " + name)
113113

114+
def __getnewargs__(self) -> Any:
115+
return (tuple(self), (), ())
116+
117+
def __getstate__(self) -> Any:
118+
return vars(self)
119+
120+
def __setstate__(self, state: Any) -> None:
121+
vars(self).update(state)
122+
114123
def __repr__(self) -> str:
115124
return self._repr

0 commit comments

Comments
 (0)