Skip to content

Commit 66b44c4

Browse files
committed
feat: joblib driver support
1 parent 1010320 commit 66b44c4

File tree

4 files changed

+295
-30
lines changed

4 files changed

+295
-30
lines changed

pins/drivers.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,37 @@
44
# from .boards import IFileSystem
55

66

7-
def load_data(meta: Meta, fs, path_to_version):
7+
def load_data(meta: Meta, fs, path_to_version: str):
8+
"""Return loaded data, based on meta type.
9+
Parameters
10+
----------
11+
meta: Meta
12+
Information about the stored data (e.g. its type).
13+
fs: IFileSystem
14+
An abstract filesystem with a method to .open() files.
15+
path_to_version:
16+
A filepath used as the parent directory the data to-be-loaded lives in.
17+
"""
818
# TODO: extandable loading with deferred importing
19+
20+
# Check that only a single file name was given
21+
fnames = [meta.file] if isinstance(meta.file, str) else meta.file
22+
if len(fnames) > 1:
23+
raise ValueError("Cannot load data when more than 1 file")
24+
25+
target_fname = fnames[0]
26+
path_to_file = f"{path_to_version}/{target_fname}"
27+
928
if meta.type == "csv":
1029
import pandas as pd
1130

12-
fnames = [meta.file] if isinstance(meta.file, str) else meta.file
13-
if len(fnames) > 1:
14-
raise ValueError("Cannot load CSV when more than 1 file")
15-
16-
target_fname = fnames[0]
17-
path_to_file = f"{path_to_version}/{target_fname}"
1831
return pd.read_csv(fs.open(path_to_file), index_col=0)
1932

33+
elif meta.type == "joblib":
34+
import joblib
35+
36+
return joblib.load(fs.open(path_to_file))
37+
2038
raise NotImplementedError(f"No driver for type {meta.type}")
2139

2240

@@ -35,6 +53,10 @@ def save_data(obj, fname, type=None):
3553
"Currently only pandas.DataFrame can be saved to a CSV."
3654
)
3755
obj.to_csv(fname)
56+
elif type == "joblib":
57+
import joblib
58+
59+
joblib.dump(obj, fname)
3860
else:
3961
raise NotImplementedError(f"Cannot save type: {type}")
4062

pins/tests/test_boards.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,34 @@
11
import pytest
2+
import pandas as pd
23

34
from pins.tests.helpers import DEFAULT_CREATION_DATE
45

6+
# using pytest cases, so that we can pass in fixtures as parameters
7+
from pytest_cases import fixture, parametrize
58

6-
@pytest.fixture
9+
10+
@fixture
711
def board(backend):
812
yield backend.create_tmp_board()
913
backend.teardown()
1014

1115

16+
@fixture
17+
def df():
18+
return pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
19+
20+
21+
# High level pins functionality -----------------------------------------------
22+
23+
1224
def test_board_pin_write_default_title(board):
13-
import pandas as pd
1425

1526
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
1627
meta = board.pin_write(df, "df_csv", title=None, type="csv")
1728
assert meta.title == "A pinned 3 x 2 CSV"
1829

1930

2031
def test_board_pin_write_prepare_pin(board, tmp_dir2):
21-
import pandas as pd
2232

2333
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
2434

@@ -32,7 +42,6 @@ def test_board_pin_write_prepare_pin(board, tmp_dir2):
3242

3343

3444
def test_board_pin_write_roundtrip(board):
35-
import pandas as pd
3645

3746
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
3847

@@ -68,8 +77,6 @@ def test_board_pin_write_rsc_index_html(board, tmp_dir2, snapshot):
6877
if board.fs.protocol != "rsc":
6978
pytest.skip()
7079

71-
import pandas as pd
72-
7380
df = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
7481

7582
pin_name = "test_rsc_pin"
@@ -85,3 +92,21 @@ def test_board_pin_write_rsc_index_html(board, tmp_dir2, snapshot):
8592
)
8693

8794
snapshot.assert_equal_dir(tmp_dir2)
95+
96+
97+
# pin_write against different types -------------------------------------------
98+
99+
100+
@parametrize(
101+
"obj, type_", [(df, "csv"), (df, "joblib"), ({"a": 1, "b": [2, 3]}, "joblib")]
102+
)
103+
def test_board_pin_write_type(board, obj, type_, request):
104+
meta = board.pin_write(obj, "test_pin", type=type_, title="some title")
105+
dst_obj = board.pin_read("test_pin")
106+
107+
assert meta.type == type_
108+
109+
if isinstance(obj, pd.DataFrame):
110+
assert obj.equals(dst_obj)
111+
112+
obj == dst_obj

0 commit comments

Comments
 (0)