forked from y-scope/spider
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_mariadb.py
More file actions
108 lines (85 loc) · 3.88 KB
/
test_mariadb.py
File metadata and controls
108 lines (85 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Tests for the MariaDB storage backend."""
import os
from uuid import uuid4
import msgpack
import pytest
from spider_py import chain, group, Int8, TaskContext
from spider_py.core import Data, DriverId, Job, JobStatus, TaskInputValue
from spider_py.storage import MariaDBStorage, parse_jdbc_url, StorageError
# Fallback JDBC URL for the MariaDB instance the storage tests talk to; the
# `mariadb_storage` fixture uses it when `SPIDER_STORAGE_URL` is not set.
MariaDBTestUrl = (
    "jdbc:mariadb://127.0.0.1:3306/spider-storage"
    "?user=spider&password=password"
)
@pytest.fixture(scope="session")
def mariadb_storage() -> MariaDBStorage:
    """
    Session-wide MariaDB storage backend shared by every storage test.

    The JDBC URL comes from the `SPIDER_STORAGE_URL` environment variable, falling back to
    `MariaDBTestUrl` when that variable is unset.

    :return: A `MariaDBStorage` built from the parsed connection parameters.
    """
    jdbc_url = os.getenv("SPIDER_STORAGE_URL", MariaDBTestUrl)
    return MariaDBStorage(parse_jdbc_url(jdbc_url))
def double(_: TaskContext, x: Int8) -> Int8:
    """
    Task that doubles its input.

    :param _: Task context (unused).
    :param x: The value to double.
    :return: `x * 2`, wrapped in `Int8`.
    """
    product = x * 2
    return Int8(product)
def swap(_: TaskContext, x: Int8, y: Int8) -> tuple[Int8, Int8]:
    """
    Task that returns its two inputs in reverse order.

    :param _: Task context (unused).
    :param x: The first value.
    :param y: The second value.
    :return: The tuple `(y, x)`.
    """
    swapped = (y, x)
    return swapped
@pytest.fixture
def submit_job(mariadb_storage: MariaDBStorage) -> Job:
    """
    Fixture that submits a small job to the MariaDB storage backend.

    The job is composed of two parent `double` tasks and one child `swap` task, linked with
    `chain`. The i-th input task gets the msgpack-encoded integer `i` as its first input.

    :param mariadb_storage: The storage backend the job is submitted to.
    :return: The submitted job.
    """
    task_graph = chain(group([double, double]), group([swap]))._impl

    # Seed each input task's first argument with its position among the input tasks.
    for position, input_index in enumerate(task_graph.input_task_indices):
        first_input = task_graph.tasks[input_index].task_inputs[0]
        first_input.value = TaskInputValue(msgpack.packb(position))

    submitted = mariadb_storage.submit_jobs(uuid4(), [task_graph])
    return submitted[0]
@pytest.fixture
def driver(mariadb_storage: MariaDBStorage) -> DriverId:
    """
    Fixture that registers a fresh driver with the MariaDB storage backend.

    :param mariadb_storage: The storage backend the driver is registered with.
    :return: The ID of the newly created driver.
    """
    new_driver_id = uuid4()
    mariadb_storage.create_driver(new_driver_id)
    return new_driver_id
class TestMariaDBStorage:
    """Test suite exercising the MariaDB storage backend through `MariaDBStorage`."""

    @pytest.mark.storage
    def test_job_submission(self, mariadb_storage: MariaDBStorage) -> None:
        """Submitting one task graph (four `double` parents, two `swap` children) yields one job."""
        task_graph = chain(
            group([double, double, double, double]),
            group([swap, swap]),
        )._impl

        # Seed each input task's first argument with its position among the input tasks.
        for position, input_index in enumerate(task_graph.input_task_indices):
            first_input = task_graph.tasks[input_index].task_inputs[0]
            first_input.value = TaskInputValue(msgpack.packb(position))

        submitted = mariadb_storage.submit_jobs(uuid4(), [task_graph])
        assert len(submitted) == 1

    @pytest.mark.storage
    def test_running_job_status(self, mariadb_storage: MariaDBStorage, submit_job: Job) -> None:
        """A freshly submitted job reports `JobStatus.Running`."""
        assert mariadb_storage.get_job_status(submit_job) == JobStatus.Running

    @pytest.mark.storage
    def test_running_job_result(self, mariadb_storage: MariaDBStorage, submit_job: Job) -> None:
        """Fetching results of a job that is still running returns `None`."""
        assert mariadb_storage.get_job_results(submit_job) is None

    @pytest.mark.storage
    def test_data(self, mariadb_storage: MariaDBStorage, driver: DriverId) -> None:
        """Data created with a driver reference can be read back with all fields intact."""
        payload = b"test data"
        original = Data(id=uuid4(), value=payload, localities=["localhost"])
        mariadb_storage.create_data_with_driver_ref(driver, original)

        fetched = mariadb_storage.get_data(original.id)
        assert fetched is not None
        assert fetched.id == original.id
        assert fetched.value == payload
        assert fetched.hard_locality == original.hard_locality
        assert fetched.localities == original.localities

    @pytest.mark.storage
    def test_create_data_fail(self, mariadb_storage: MariaDBStorage) -> None:
        """Creating data that references an unregistered driver raises `StorageError`."""
        orphan = Data(id=uuid4(), value=b"test data", localities=["localhost"])
        # A random UUID that was never passed to `create_driver`.
        unknown_driver = uuid4()
        with pytest.raises(StorageError):
            mariadb_storage.create_data_with_driver_ref(unknown_driver, orphan)