Skip to content

Commit 9bf5266

Browse files
Refactor PseudoClient to use async for SID endpoint and update validation logic (#480)
* Refactor PseudoClient to use async for SID endpoint and update validation logic
* Add arguments to docstrings
* Bump version

Co-authored-by: Nicholas <3789764+skykanin@users.noreply.github.com>
1 parent 2c0b6ab commit 9bf5266

File tree

7 files changed

+166
-95
lines changed

7 files changed

+166
-95
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "dapla-toolbelt-pseudo"
3-
version = "6.0.1"
3+
version = "6.0.2"
44
description = "Pseudonymization extensions for Dapla"
55
authors = [{ name = "Dapla Developers", email = "dapla-platform-developers@ssb.no" }]
66
requires-python = ">=3.11,<4.0"

src/dapla_pseudo/v1/client.py

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -287,29 +287,77 @@ def _handle_response_error_sync(response: requests.Response) -> None:
287287
print(response.text)
288288
response.raise_for_status()
289289

290-
def _post_to_sid_endpoint(
290+
async def _post_to_sid_endpoint(
291291
self,
292292
path: str,
293293
values: list[str],
294294
sid_snapshot_date: date | None = None,
295-
stream: bool = True,
296-
) -> requests.Response:
297-
request: dict[str, t.Collection[str]] = {"fnrList": values}
298-
response = requests.post(
299-
url=f"{self.pseudo_service_url}/{path}",
300-
params={"snapshot": str(sid_snapshot_date)} if sid_snapshot_date else None,
301-
# Do not set content-type, as this will cause the json to serialize incorrectly
302-
headers={
303-
"Authorization": f"Bearer {self.__auth_token()}",
304-
"X-Correlation-Id": PseudoClient._generate_new_correlation_id(),
305-
},
306-
json=request,
307-
stream=stream,
308-
timeout=TIMEOUT_DEFAULT, # seconds
309-
)
295+
) -> tuple[list[str], str | None]:
296+
"""Post SID lookup in batches concurrently and merge responses.
297+
298+
Args:
299+
path: Endpoint path to append to the base pseudo_service_url
300+
values: List of FNR to look up in the SID-catalogue.
301+
sid_snapshot_date: Date representing SID-catalogue version to use. Latest if unspecified. Format: YYYY-MM-DD
302+
303+
Returns:
304+
tuple[list[str], str | None]: (missing_values, datasetExtractionSnapshotTime)
305+
"""
306+
total_rows = len(values)
307+
batch_size = self.rows_per_partition
308+
309+
# Do not split if total rows is less than batch size (default 10k)
310+
if total_rows <= batch_size:
311+
batches = [values]
312+
else:
313+
batches = [
314+
values[i : i + batch_size] for i in range(0, total_rows, batch_size)
315+
]
316+
317+
async with ClientSession(
318+
connector=TCPConnector(limit=100, enable_cleanup_closed=True),
319+
timeout=ClientTimeout(total=TIMEOUT_DEFAULT),
320+
) as session:
310321

311-
PseudoClient._handle_response_error_sync(response)
312-
return response
322+
async def _post_batch(
323+
batch: list[str],
324+
path: str,
325+
) -> dict[str, list[str] | str]:
326+
resp_cm = await session.post(
327+
url=f"{self.pseudo_service_url}/{path}",
328+
params=(
329+
{"snapshot": str(sid_snapshot_date)}
330+
if sid_snapshot_date
331+
else None
332+
),
333+
# Do not set content-type, as this will cause the json to serialize incorrectly
334+
headers={
335+
"Authorization": f"Bearer {self.__auth_token()}",
336+
"X-Correlation-Id": PseudoClient._generate_new_correlation_id(),
337+
},
338+
json={"fnrList": batch},
339+
)
340+
async with resp_cm as response:
341+
await PseudoClient._handle_response_error(response)
342+
payload = await response.json()
343+
return t.cast(
344+
dict[str, list[str] | str],
345+
(
346+
payload[0]
347+
if isinstance(payload, list) and payload
348+
else payload
349+
),
350+
)
351+
352+
results = await asyncio.gather(*[_post_batch(b, path) for b in batches])
353+
354+
# Merge results
355+
all_missing = [m for r in results for m in r.get("missing", [])]
356+
raw_snapshot = (
357+
results[0].get("datasetExtractionSnapshotTime") if results else None
358+
)
359+
snapshot_time = t.cast(str | None, raw_snapshot)
360+
return all_missing, snapshot_time
313361

314362

315363
def _client() -> PseudoClient:

src/dapla_pseudo/v1/validation.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
"""Builder for submitting a validation request."""
22

3-
import json
4-
from collections.abc import Sequence
3+
import asyncio
54
from datetime import date
65
from pathlib import Path
76
from typing import Any
87

98
import pandas as pd
109
import polars as pl
11-
import requests
1210

1311
from dapla_pseudo.utils import convert_to_date
1412
from dapla_pseudo.utils import get_file_format_from_file_name
@@ -98,7 +96,7 @@ def __init__(
9896
self._dataframe: pl.DataFrame = dataframe
9997
self._field: str = field
10098

101-
def validate_map_to_stable_id(
99+
async def _validate_map_to_stable_id_async(
102100
self, sid_snapshot_date: str | date | None = None
103101
) -> Result:
104102
"""Checks if all the selected fields can be mapped to a stable ID.
@@ -112,31 +110,26 @@ def validate_map_to_stable_id(
112110
"""
113111
Validator._ensure_field_valid(self._field, self._dataframe)
114112

115-
response: requests.Response = _client()._post_to_sid_endpoint(
116-
"sid/lookup/batch",
117-
self._dataframe[self._field].to_list(),
118-
convert_to_date(sid_snapshot_date),
119-
stream=True,
113+
client = _client()
114+
all_values = self._dataframe[self._field].to_list()
115+
snapshot = convert_to_date(sid_snapshot_date)
116+
117+
missing, extraction_time = await client._post_to_sid_endpoint(
118+
path="sid/lookup/batch",
119+
values=all_values,
120+
sid_snapshot_date=snapshot,
121+
)
122+
123+
result_df = pl.Series(self._field, missing).to_frame()
124+
metadata_logs: list[str] = (
125+
[f"SID snapshot time {extraction_time}"] if extraction_time else []
120126
)
121-
# The response content is received as a buffered byte stream from the server.
122-
# We decode the content using UTF-8, which gives us a List[Dict[str]] structure.
123-
result_json = json.loads(response.content.decode("utf-8"))[0]
124-
result: Sequence[str] = []
125-
metadata: list[str] = []
126-
if "missing" in result_json:
127-
result = result_json["missing"]
128-
if "datasetExtractionSnapshotTime" in result_json:
129-
extraction_time = result_json["datasetExtractionSnapshotTime"]
130-
metadata = [f"SID snapshot time {extraction_time}"]
131-
132-
result_df = pl.Series(self._field, result).to_frame()
133-
# TODO - make the Validator fit the Result() class better
134127
return Result(
135128
PseudoFieldResponse(
136129
data=result_df,
137130
raw_metadata=[
138131
RawPseudoMetadata(
139-
logs=metadata,
132+
logs=metadata_logs,
140133
metrics=[],
141134
datadoc=[],
142135
field_name=self._field,
@@ -145,6 +138,12 @@ def validate_map_to_stable_id(
145138
)
146139
)
147140

141+
def validate_map_to_stable_id(
142+
self, sid_snapshot_date: str | date | None = None
143+
) -> Result:
144+
"""Validate mapping to SID (sync wrapper around the async batched version)."""
145+
return asyncio.run(self._validate_map_to_stable_id_async(sid_snapshot_date))
146+
148147
@staticmethod
149148
def _ensure_field_valid(field: str, dataframe: pl.DataFrame) -> None:
150149
"""Ensure that all values are numeric and valid.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import polars as pl
2+
import pytest
3+
import requests
4+
from tests.v1.integration.utils import integration_test
5+
6+
from dapla_pseudo import Validator
7+
from dapla_pseudo.v1.result import Result
8+
9+
10+
@pytest.mark.usefixtures("setup")
11+
@integration_test()
12+
def test_sid_lookup_batch_payload_too_large() -> None:
13+
n_rows = 1_400_000
14+
df = pl.DataFrame({"fnr": [f"{i:011d}" for i in range(n_rows)]})
15+
16+
try:
17+
result = Validator.from_polars(df).on_field("fnr").validate_map_to_stable_id()
18+
except requests.HTTPError as err:
19+
# Explicitly fail on HTTP 413
20+
if err.response is not None and err.response.status_code == 413:
21+
pytest.fail(f"Unexpected 413 Payload Too Large: {err.response.text}")
22+
# Re-raise other HTTP errors to surface genuine failures
23+
raise
24+
except requests.ConnectionError as err:
25+
msg = str(err)
26+
# Some servers may close connection with a payload-too-large reason
27+
if ("Payload Too Large" in msg) or ("Request Entity Too Large" in msg):
28+
pytest.fail(f"Unexpected payload-too-large connection error: {msg}")
29+
raise
30+
31+
# Sanity check: we got a Result back
32+
assert isinstance(result, Result)

tests/v1/unit/test_client.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from unittest.mock import patch
55

66
import pytest
7-
import requests
87
from aiohttp import ClientResponse
98
from aiohttp import ClientResponseError
109
from aiohttp import RequestInfo
@@ -392,30 +391,31 @@ async def test_post_to_field_endpoint_test_splits_multiple_fields() -> None:
392391
assert results == expected_data
393392

394393

395-
@patch("requests.post")
396-
def test_successful_post_to_sid_endpoint(
397-
mock_post: Mock, test_client: PseudoClient
398-
) -> None:
399-
mock_response = Mock(spec=requests.Response)
400-
mock_response.status_code = 200
401-
mock_response.raise_for_status.return_value = None
394+
@pytest.mark.asyncio
395+
async def test_successful_post_to_sid_endpoint(test_client: PseudoClient) -> None:
396+
payload = [
397+
{
398+
"missing": ["magic", "sorcery"],
399+
"datasetExtractionSnapshotTime": "2024-01-01T00:00:00Z",
400+
}
401+
]
402402

403-
mock_post.return_value = mock_response
404-
response = test_client._post_to_sid_endpoint(
405-
path="test_path",
406-
values=["value1", "value2"],
407-
)
403+
with patch("aiohttp.ClientSession.post", new_callable=AsyncMock) as mock_post:
404+
mock_response = AsyncMock()
405+
mock_response.__aenter__.return_value = mock_response
406+
mock_response.json.return_value = payload
407+
mock_post.return_value = mock_response
408408

409-
expected_json = {"fnrList": ["value1", "value2"]}
410-
assert response == mock_response
411-
mock_post.assert_called_once_with(
412-
url="https://mocked.dapla-pseudo-service/test_path",
413-
params=None,
414-
headers={
415-
"Authorization": "Bearer some-auth-token",
416-
"X-Correlation-Id": ANY,
417-
},
418-
json=expected_json,
419-
stream=True,
420-
timeout=TIMEOUT_DEFAULT,
421-
)
409+
missing, snapshot_time = await test_client._post_to_sid_endpoint(
410+
path="test_path",
411+
values=["value1", "value2"],
412+
sid_snapshot_date=None,
413+
)
414+
415+
assert missing == ["magic", "sorcery"]
416+
assert snapshot_time == "2024-01-01T00:00:00Z"
417+
418+
mock_post.assert_called_once()
419+
_args, kwargs = mock_post.call_args
420+
assert kwargs["json"] == {"fnrList": ["value1", "value2"]}
421+
assert "X-Correlation-Id" in kwargs["headers"]

tests/v1/unit/test_validation.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import date
2-
from unittest.mock import MagicMock
2+
from unittest.mock import AsyncMock
33
from unittest.mock import Mock
44
from unittest.mock import patch
55

@@ -17,26 +17,20 @@
1717

1818

1919
@pytest_cases.fixture()
20-
def sid_lookup_missing_response() -> MagicMock:
21-
mock_response = MagicMock()
22-
mock_response.status_code = 200
23-
mock_response.content = b'[{"missing": ["20859374701","01234567890"], "datasetExtractionSnapshotTime": "2023-08-31"}]'
24-
return mock_response
20+
def sid_lookup_missing_response() -> tuple[list[str], str]:
21+
return (["20859374701", "01234567890"], "2023-08-31")
2522

2623

2724
@pytest_cases.fixture()
28-
def sid_lookup_empty_response() -> MagicMock:
29-
mock_response = MagicMock()
30-
mock_response.status_code = 200
31-
mock_response.content = b'[{"datasetExtractionSnapshotTime": "2023-08-31"}]'
32-
return mock_response
25+
def sid_lookup_empty_response() -> tuple[list[str], str]:
26+
return ([], "2023-08-31")
3327

3428

35-
@patch("dapla_pseudo.v1.PseudoClient._post_to_sid_endpoint")
29+
@patch("dapla_pseudo.v1.PseudoClient._post_to_sid_endpoint", new_callable=AsyncMock)
3630
def test_validate_with_full_response(
37-
patched_post_to_sid_endpoint: Mock,
31+
patched_post_to_sid_endpoint: AsyncMock,
3832
df_personer: pl.DataFrame,
39-
sid_lookup_missing_response: MagicMock,
33+
sid_lookup_missing_response: tuple[list[str], str],
4034
) -> None:
4135
field_name = "fnr"
4236

@@ -50,21 +44,20 @@ def test_validate_with_full_response(
5044
validation_df = validation_result.to_pandas()
5145
validation_metadata = validation_result.metadata_details
5246

53-
patched_post_to_sid_endpoint.assert_called_once_with(
54-
"sid/lookup/batch",
55-
["11854898347", "01839899544", "16910599481"],
56-
None,
57-
stream=True,
47+
patched_post_to_sid_endpoint.assert_awaited_once_with(
48+
path="sid/lookup/batch",
49+
values=["11854898347", "01839899544", "16910599481"],
50+
sid_snapshot_date=None,
5851
)
5952
assert validation_df[field_name].tolist() == ["20859374701", "01234567890"]
6053
assert validation_metadata[field_name]["logs"] == ["SID snapshot time 2023-08-31"]
6154

6255

63-
@patch("dapla_pseudo.v1.PseudoClient._post_to_sid_endpoint")
56+
@patch("dapla_pseudo.v1.PseudoClient._post_to_sid_endpoint", new_callable=AsyncMock)
6457
def test_validate_with_empty_response(
65-
patched_post_to_sid_endpoint: Mock,
58+
patched_post_to_sid_endpoint: AsyncMock,
6659
df_personer: pl.DataFrame,
67-
sid_lookup_empty_response: MagicMock,
60+
sid_lookup_empty_response: tuple[list[str], str],
6861
) -> None:
6962
field_name = "fnr"
7063

@@ -79,10 +72,9 @@ def test_validate_with_empty_response(
7972
validation_metadata = validation_result.metadata_details
8073

8174
patched_post_to_sid_endpoint.assert_called_once_with(
82-
"sid/lookup/batch",
83-
["11854898347", "01839899544", "16910599481"],
84-
date(2023, 8, 31),
85-
stream=True,
75+
path="sid/lookup/batch",
76+
values=["11854898347", "01839899544", "16910599481"],
77+
sid_snapshot_date=date(2023, 8, 31),
8678
)
8779
assert validation_df[field_name].tolist() == []
8880
assert validation_metadata[field_name]["logs"] == ["SID snapshot time 2023-08-31"]

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)