Skip to content

Commit 2d60c3c

Browse files
[MPT-10795] Add TrialInfo dataclass and code improvements
1 parent 0a26392 commit 2d60c3c

File tree

6 files changed

+113
-91
lines changed

6 files changed

+113
-91
lines changed

ffc/billing/dataclasses.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ class Refund:
2525
description: str
2626

2727

28+
@dataclass
29+
class TrialInfo:
30+
trial_days: set[int] | None = None
31+
refund_from: date | None = None
32+
refund_to: date | None = None
33+
34+
2835
@dataclass
2936
class CurrencyConversionInfo:
3037
base_currency: str
Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datetime import UTC, date, datetime
88
from io import BytesIO
99
from pathlib import Path
10-
from typing import Any
10+
from typing import Any, AsyncGenerator, Coroutine
1111

1212
import aiofiles
1313
import aiofiles.os
@@ -22,6 +22,7 @@
2222
CurrencyConversionInfo,
2323
Datasource,
2424
Refund,
25+
TrialInfo,
2526
)
2627
from ffc.billing.exceptions import ExchangeRatesClientError, JournalStatusError
2728
from ffc.clients.exchage_rates import ExchangeRatesAsyncClient
@@ -112,7 +113,7 @@ def __init__(
112113
)
113114

114115
@asynccontextmanager
115-
async def acquire_semaphore(self):
116+
async def acquire_semaphore(self) -> AsyncGenerator[None, Any]:
116117
"""
117118
This method acquires and releases a semaphore.
118119
"""
@@ -129,7 +130,7 @@ async def maybe_call(
129130
func,
130131
*args,
131132
**kwargs,
132-
):
133+
) -> Any | None:
133134
"""
134135
Conditionally calls and awaits the given asynchronous function.
135136
@@ -149,7 +150,7 @@ async def maybe_call(
149150
if not self.dry_run:
150151
return await func(*args, **kwargs)
151152

152-
def build_filepath(self):
153+
def build_filepath(self) -> str:
153154
"""
154155
Constructs and returns the file path for a charges JSONL file.
155156
@@ -166,7 +167,7 @@ def build_filepath(self):
166167
filepath = f"{tempfile.gettempdir()}/{filepath}" if not self.dry_run else filepath
167168
return filepath
168169

169-
async def evaluate_journal_status(self, journal_external_id):
170+
async def evaluate_journal_status(self, journal_external_id) -> dict[str, Any] | None:
170171
"""
171172
Evaluates the status of a journal.
172173
@@ -195,7 +196,7 @@ async def evaluate_journal_status(self, journal_external_id):
195196
self.logger.warning(f"Found the journal {journal_id} with status {journal_status}")
196197
raise JournalStatusError()
197198

198-
async def process(self):
199+
async def process(self) -> None:
199200
"""
200201
This method is responsible for passing a journal with status VALIDATED to
201202
the function that writes the charges into a file and then to complete the
@@ -204,7 +205,6 @@ async def process(self):
204205
result = AuthorizationProcessResult(authorization_id=self.authorization_id)
205206
async with self.acquire_semaphore():
206207
try:
207-
# double check with production
208208
if not await self.mpt_client.count_active_agreements(
209209
self.authorization_id,
210210
self.billing_start_date,
@@ -248,7 +248,7 @@ async def process(self):
248248
except Exception as error:
249249
self.logger.error(f"An error occurred: {error}", exc_info=error)
250250

251-
async def write_charges_file(self, filepath):
251+
async def write_charges_file(self, filepath) -> bool:
252252
"""
253253
This method writes the charges file to the given filepath.
254254
If there is more than one agreement for an organization, it won't be processed.
@@ -308,7 +308,9 @@ async def write_charges_file(self, filepath):
308308
return False
309309
return True
310310

311-
async def complete_journal_process(self, filepath, journal, journal_external_id):
311+
async def complete_journal_process(
312+
self, filepath, journal, journal_external_id
313+
) -> Coroutine | None:
312314
"""
313315
This method uploads and submits the given journal, attaching also the exchange rates
314316
files.
@@ -377,7 +379,7 @@ async def get_currency_conversion_info(
377379

378380
async def attach_exchange_rates(
379381
self, journal_id: str, currency: str, exchange_rates: dict[str, Any]
380-
):
382+
) -> Coroutine | None:
381383
"""
382384
This method checks if an attachment already exists for the given journal.
383385
If it exists, it will be deleted and a new one will be created with the
@@ -392,7 +394,7 @@ async def attach_exchange_rates(
392394
attachment = await self.mpt_client.fetch_journal_attachment(journal_id, f"{currency}_")
393395
if attachment: # pragma no cover
394396
if attachment["name"] == filename:
395-
return
397+
return None
396398
await self.mpt_client.delete_journal_attachment(journal_id, attachment["id"])
397399

398400
return await self.mpt_client.create_journal_attachment(journal_id, filename, serialized)
@@ -402,7 +404,7 @@ async def dump_organization_charges(
402404
charges_file: Any,
403405
organization: dict[str, Any],
404406
agreement: dict[str, Any] | None = None,
405-
):
407+
) -> Coroutine | None:
406408
organization_id = organization["id"]
407409
async for datasource_info, expenses in async_groupby(
408410
self.ffc_client.fetch_organization_expenses(organization_id, self.year, self.month),
@@ -441,7 +443,7 @@ async def generate_datasource_charges(
441443
datasource_id: str,
442444
datasource_name: str,
443445
daily_expenses: dict[int, Decimal],
444-
):
446+
) -> list[str]:
445447
"""
446448
This method generates all the charges for the given datasource and
447449
calculates the refund for the Trials and Entitlements periods.
@@ -539,7 +541,7 @@ async def generate_datasource_charges(
539541
)
540542
return charges
541543

542-
async def is_journal_status_validated(self, journal_id, max_attempts=5):
544+
async def is_journal_status_validated(self, journal_id, max_attempts=5) -> bool:
543545
backoff_times = [0.15, 0.45, 1.05, 2.25, 4.65]
544546

545547
for attempt in range(min(max_attempts, len(backoff_times))):
@@ -566,17 +568,19 @@ def generate_refunds(
566568
trial_days = set()
567569
trial_start_date, trial_end_date = get_trial_dates(agreement=agreement)
568570
if trial_start_date and trial_end_date:
569-
trial_days, trial_refund_from, trial_refund_to = self.get_trial_days(
571+
trial_info = self.get_trial_days(
570572
trial_start_date,
571573
trial_end_date,
572574
)
573-
trial_amount = sum(daily_expenses.get(day, Decimal("0")) for day in trial_days)
574-
575+
trial_amount = sum(
576+
daily_expenses.get(day, Decimal("0")) for day in trial_info.trial_days
577+
)
578+
trial_days = trial_info.trial_days
575579
refund_lines.append(
576580
Refund(
577581
trial_amount,
578-
trial_refund_from,
579-
trial_refund_to,
582+
trial_info.refund_from,
583+
trial_info.refund_to,
580584
(
581585
"Refund due to trial period "
582586
f"(from {trial_start_date.strftime("%d %b %Y")} " # type: ignore
@@ -614,7 +618,7 @@ async def generate_refund_lines(
614618
linked_datasource_type: str,
615619
datasource_id: str,
616620
datasource_name: str,
617-
):
621+
) -> list[str]:
618622
"""
619623
This function calculates the refund lines for a billing period
620624
"""
@@ -680,24 +684,39 @@ def get_trial_days(
680684
self,
681685
trial_start_date: date | None,
682686
trial_end_date: date | None,
683-
):
687+
) -> TrialInfo:
684688
# Trial period can start or end on month other than billing month period.
685689
# In this situation, we need to limit refunded expenses
686690
# to a period overlapping with billing month.
687691
# Example, billing month is June 1-30, a trial period is May 17 - June 17
688692
# We need to refund expenses from June 1st to June 17th
689693
if not trial_start_date and not trial_end_date:
690-
return None, None, None
694+
return TrialInfo()
691695
trial_refund_from = max(trial_start_date, self.billing_start_date.date())
692696
trial_refund_to = min(trial_end_date, self.billing_end_date.date())
693697
trial_days = {
694698
dt.date().day for dt in rrule(DAILY, dtstart=trial_refund_from, until=trial_refund_to)
695699
}
696-
return trial_days, trial_refund_from, trial_refund_to
700+
return TrialInfo(trial_days, trial_refund_from, trial_refund_to)
697701

698702
def get_entitlement_days(
699-
self, entitlement_start_date: str, entitlement_end_date: str, trial_days: set[int]
700-
):
703+
self,
704+
entitlement_start_date: str,
705+
entitlement_end_date: str | None = None,
706+
trial_days: set[int] | None = None,
707+
) -> set[int]:
708+
"""
709+
Calculate the set of entitlement day numbers within a date range,
710+
excluding the given trial day numbers.
711+
712+
Args:
713+
entitlement_start_date (str): ISO format start date string.
714+
entitlement_end_date (Optional[str]): ISO format end date string, or None.
715+
trial_days (Optional[Set[int]]): Set of trial day numbers (1–31)
716+
717+
Returns:
718+
Set[int]: Set of day numbers (1–31) representing entitlement days.
719+
"""
701720
start_date = max(datetime.fromisoformat(entitlement_start_date), self.billing_start_date)
702721
end_date = min(
703722
datetime.fromisoformat(entitlement_end_date)
@@ -722,7 +741,7 @@ def generate_charge_line(
722741
price: Decimal,
723742
datasource_name: str,
724743
description: str = "",
725-
):
744+
) -> str:
726745
"""
727746
This function generates a charge line for a vendor and datasource.
728747
"""

ffc/management/commands/process_billing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dateutil.relativedelta import relativedelta
55
from django.core.management.base import BaseCommand
66

7-
from ffc.process_billing import (
7+
from ffc.billing.process_billing import (
88
process_billing,
99
)
1010

tests/conftest.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import copy
22
from datetime import UTC, datetime
33
from decimal import Decimal
4+
from typing import Generator
45

6+
import httpx
57
import jwt
68
import pytest
79
import responses
810
from swo.mpt.extensions.runtime.djapp.conf import get_for_product
911

1012
from ffc.billing.dataclasses import AuthorizationProcessResult
11-
from ffc.process_billing import AuthorizationProcessor
13+
from ffc.billing.process_billing import AuthorizationProcessor
14+
from ffc.clients.base import BaseAsyncAPIClient
1215

1316

1417
@pytest.fixture()
@@ -2665,3 +2668,30 @@ def patch_fetch_organizations_agr_000(
26652668
"fetch_organizations",
26662669
return_value=org_mock_generator_agr_000,
26672670
)
2671+
2672+
2673+
class TestClientAuth(httpx.Auth):
2674+
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
2675+
request.headers["Authorization"] = "Bearer fake token"
2676+
yield request
2677+
2678+
2679+
class FakeAPIClient(BaseAsyncAPIClient):
2680+
@property
2681+
def base_url(self) -> str:
2682+
return "https://local.local/v1"
2683+
2684+
@property
2685+
def auth(self):
2686+
return TestClientAuth()
2687+
2688+
def get_pagination_meta(self, response):
2689+
return response["meta"]["pagination"]
2690+
2691+
def get_page_data(self, response):
2692+
return response["data"]
2693+
2694+
2695+
@pytest.fixture()
2696+
def fake_apiclient():
2697+
return FakeAPIClient(limit=2)

tests/ffc/test_base_client.py

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,11 @@
1-
from collections.abc import Generator
2-
31
import httpx
42
import pytest
53

64
from ffc.clients.base import BaseAsyncAPIClient, PaginationSupportMixin
75

86

9-
class TestClientAuth(httpx.Auth):
10-
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
11-
request.headers["Authorization"] = "Bearer fake token"
12-
yield request
13-
14-
15-
class TestClient(BaseAsyncAPIClient):
16-
@property
17-
def base_url(self) -> str:
18-
return "https://local.local/v1"
19-
20-
@property
21-
def auth(self):
22-
return TestClientAuth()
23-
24-
def get_pagination_meta(self, response):
25-
return response["meta"]["pagination"]
26-
27-
def get_page_data(self, response):
28-
return response["data"]
29-
30-
317
@pytest.mark.asyncio()
32-
async def test_collection_iterator_paginates(httpx_mock):
33-
client = TestClient(limit=2)
34-
8+
async def test_collection_iterator_paginates(httpx_mock, fake_apiclient):
359
endpoint = "/catalog/authorizations"
3610
rql = "eq(mytestfield,'value')"
3711

@@ -60,7 +34,7 @@ async def test_collection_iterator_paginates(httpx_mock):
6034
)
6135

6236
items = []
63-
async for item in client.collection_iterator(endpoint, rql=rql):
37+
async for item in fake_apiclient.collection_iterator(endpoint, rql=rql):
6438
items.append(item)
6539

6640
assert [item["id"] for item in items] == ["AUT-1111-1111", "AUT-2222-2222", "AUT-3333-3333"]
@@ -74,14 +48,12 @@ async def test_collection_iterator_paginates(httpx_mock):
7448
for r in reqs:
7549
assert r.headers["Authorization"] == "Bearer fake token"
7650

77-
await client.close()
78-
assert client.httpx_client.is_closed
51+
await fake_apiclient.close()
52+
assert fake_apiclient.httpx_client.is_closed
7953

8054

8155
@pytest.mark.asyncio()
82-
async def test_collection_iterator_paginates_404(httpx_mock):
83-
client = TestClient(limit=2)
84-
56+
async def test_collection_iterator_paginates_404(httpx_mock, fake_apiclient):
8557
endpoint = "/catalog/authorizations"
8658
rql = "eq(mytestfield,'value')"
8759

@@ -92,14 +64,14 @@ async def test_collection_iterator_paginates_404(httpx_mock):
9264
)
9365

9466
with pytest.raises(httpx.HTTPStatusError):
95-
await anext(client.collection_iterator(endpoint, rql=rql))
67+
await anext(fake_apiclient.collection_iterator(endpoint, rql=rql))
9668
[req] = httpx_mock.get_requests()
9769
assert req.method == "GET"
9870
assert str(req.url) == f"https://local.local/v1/catalog/authorizations?{rql}&limit=2&offset=0"
9971
assert req.headers["Authorization"] == "Bearer fake token"
10072

101-
await client.close()
102-
assert client.httpx_client.is_closed
73+
await fake_apiclient.close()
74+
assert fake_apiclient.httpx_client.is_closed
10375

10476

10577
def test_cannot_instantiate_paginationsupportmixin_directly():

0 commit comments

Comments
 (0)