Skip to content

Commit c9b6747

Browse files
authored
Merge pull request #31 from softwareone-platform/mpt-10795_improvements
[MPT-10795] Add TrialInfo dataclass and code improvements
2 parents 0a26392 + 809c62a commit c9b6747

File tree

7 files changed

+135
-119
lines changed

7 files changed

+135
-119
lines changed

ffc/billing/dataclasses.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ class Refund:
2525
description: str
2626

2727

28+
@dataclass
29+
class TrialInfo:
30+
trial_days: set[int]
31+
refund_from: date
32+
refund_to: date
33+
34+
2835
@dataclass
2936
class CurrencyConversionInfo:
3037
base_currency: str
Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datetime import UTC, date, datetime
88
from io import BytesIO
99
from pathlib import Path
10-
from typing import Any
10+
from typing import Any, AsyncGenerator, Coroutine
1111

1212
import aiofiles
1313
import aiofiles.os
@@ -22,6 +22,7 @@
2222
CurrencyConversionInfo,
2323
Datasource,
2424
Refund,
25+
TrialInfo,
2526
)
2627
from ffc.billing.exceptions import ExchangeRatesClientError, JournalStatusError
2728
from ffc.clients.exchage_rates import ExchangeRatesAsyncClient
@@ -105,14 +106,14 @@ def __init__(
105106
self.billing_end_date = self.billing_start_date + relativedelta(months=1, days=-1)
106107
self.DECIMAL_DIGITS = 4
107108
self.DECIMAL_PRECISION = Decimal("10") ** -self.DECIMAL_DIGITS
108-
self.exchange_rates = {}
109-
self.invalid_organizations = []
109+
self.exchange_rates: dict[str, Any] = {}
110+
self.invalid_organizations: list[Any] = []
110111
self.logger = PrefixAdapter(
111112
logging.getLogger(__name__), {"prefix": self.authorization.get("id")}
112113
)
113114

114115
@asynccontextmanager
115-
async def acquire_semaphore(self):
116+
async def acquire_semaphore(self) -> AsyncGenerator[None, Any]:
116117
"""
117118
This method acquires and releases a semaphore.
118119
"""
@@ -129,7 +130,7 @@ async def maybe_call(
129130
func,
130131
*args,
131132
**kwargs,
132-
):
133+
) -> Any | None:
133134
"""
134135
Conditionally calls and awaits the given asynchronous function.
135136
@@ -149,7 +150,7 @@ async def maybe_call(
149150
if not self.dry_run:
150151
return await func(*args, **kwargs)
151152

152-
def build_filepath(self):
153+
def build_filepath(self) -> str:
153154
"""
154155
Constructs and returns the file path for a charges JSONL file.
155156
@@ -166,7 +167,7 @@ def build_filepath(self):
166167
filepath = f"{tempfile.gettempdir()}/{filepath}" if not self.dry_run else filepath
167168
return filepath
168169

169-
async def evaluate_journal_status(self, journal_external_id):
170+
async def evaluate_journal_status(self, journal_external_id) -> dict[str, Any] | None:
170171
"""
171172
Evaluates the status of a journal.
172173
@@ -195,7 +196,7 @@ async def evaluate_journal_status(self, journal_external_id):
195196
self.logger.warning(f"Found the journal {journal_id} with status {journal_status}")
196197
raise JournalStatusError()
197198

198-
async def process(self):
199+
async def process(self) -> AuthorizationProcessResult | None:
199200
"""
200201
This method is responsible for passing a journal with status VALIDATED to
201202
the function that writes the charges into a file and then to complete the
@@ -204,7 +205,6 @@ async def process(self):
204205
result = AuthorizationProcessResult(authorization_id=self.authorization_id)
205206
async with self.acquire_semaphore():
206207
try:
207-
# double check with production
208208
if not await self.mpt_client.count_active_agreements(
209209
self.authorization_id,
210210
self.billing_start_date,
@@ -248,7 +248,7 @@ async def process(self):
248248
except Exception as error:
249249
self.logger.error(f"An error occurred: {error}", exc_info=error)
250250

251-
async def write_charges_file(self, filepath):
251+
async def write_charges_file(self, filepath) -> bool:
252252
"""
253253
This method writes the charges file to the given filepath.
254254
If there is more than one agreement for an organization, it won't be processed.
@@ -308,7 +308,9 @@ async def write_charges_file(self, filepath):
308308
return False
309309
return True
310310

311-
async def complete_journal_process(self, filepath, journal, journal_external_id):
311+
async def complete_journal_process(
312+
self, filepath, journal, journal_external_id
313+
) -> Coroutine | None:
312314
"""
313315
This method uploads and submits the given journal, attaching also the exchange rates
314316
files.
@@ -377,7 +379,7 @@ async def get_currency_conversion_info(
377379

378380
async def attach_exchange_rates(
379381
self, journal_id: str, currency: str, exchange_rates: dict[str, Any]
380-
):
382+
) -> Coroutine | None:
381383
"""
382384
This method checks if an attachment already exists for the given journal.
383385
If it exists, it will be deleted and a new one will be created with the
@@ -392,7 +394,7 @@ async def attach_exchange_rates(
392394
attachment = await self.mpt_client.fetch_journal_attachment(journal_id, f"{currency}_")
393395
if attachment: # pragma no cover
394396
if attachment["name"] == filename:
395-
return
397+
return None
396398
await self.mpt_client.delete_journal_attachment(journal_id, attachment["id"])
397399

398400
return await self.mpt_client.create_journal_attachment(journal_id, filename, serialized)
@@ -401,8 +403,8 @@ async def dump_organization_charges(
401403
self,
402404
charges_file: Any,
403405
organization: dict[str, Any],
404-
agreement: dict[str, Any] | None = None,
405-
):
406+
agreement: dict[str, Any],
407+
) -> Coroutine | None:
406408
organization_id = organization["id"]
407409
async for datasource_info, expenses in async_groupby(
408410
self.ffc_client.fetch_organization_expenses(organization_id, self.year, self.month),
@@ -441,7 +443,7 @@ async def generate_datasource_charges(
441443
datasource_id: str,
442444
datasource_name: str,
443445
daily_expenses: dict[int, Decimal],
444-
):
446+
) -> list[str]:
445447
"""
446448
This method generates all the charges for the given datasource and
447449
calculates the refund for the Trials and Entitlements periods.
@@ -539,7 +541,7 @@ async def generate_datasource_charges(
539541
)
540542
return charges
541543

542-
async def is_journal_status_validated(self, journal_id, max_attempts=5):
544+
async def is_journal_status_validated(self, journal_id, max_attempts=5) -> bool:
543545
backoff_times = [0.15, 0.45, 1.05, 2.25, 4.65]
544546

545547
for attempt in range(min(max_attempts, len(backoff_times))):
@@ -563,20 +565,22 @@ def generate_refunds(
563565
the trials and entitlements period. Trials get priority over Entitlements
564566
"""
565567
refund_lines = []
566-
trial_days = set()
568+
trial_days: set[int] = set()
567569
trial_start_date, trial_end_date = get_trial_dates(agreement=agreement)
568570
if trial_start_date and trial_end_date:
569-
trial_days, trial_refund_from, trial_refund_to = self.get_trial_days(
571+
trial_info = self.get_trial_info(
570572
trial_start_date,
571573
trial_end_date,
572574
)
573-
trial_amount = sum(daily_expenses.get(day, Decimal("0")) for day in trial_days)
574-
575+
trial_amount = Decimal(
576+
sum(daily_expenses.get(day, Decimal("0")) for day in trial_info.trial_days)
577+
)
578+
trial_days = trial_info.trial_days
575579
refund_lines.append(
576580
Refund(
577581
trial_amount,
578-
trial_refund_from,
579-
trial_refund_to,
582+
trial_info.refund_from,
583+
trial_info.refund_to,
580584
(
581585
"Refund due to trial period "
582586
f"(from {trial_start_date.strftime("%d %b %Y")} " # type: ignore
@@ -587,12 +591,14 @@ def generate_refunds(
587591

588592
if entitlement_start_date:
589593
entitlement_days = self.get_entitlement_days(
590-
entitlement_start_date,
591-
entitlement_termination_date,
592-
trial_days,
594+
trial_days=trial_days,
595+
entitlement_start_date=entitlement_start_date,
596+
entitlement_end_date=entitlement_termination_date,
593597
)
594598
for r_start, r_end in split_entitlement_days_into_ranges(entitlement_days):
595-
ent_amount = sum(daily_expenses.get(day, 0) for day in range(r_start, r_end + 1))
599+
ent_amount = Decimal(
600+
sum(daily_expenses.get(day, 0) for day in range(r_start, r_end + 1))
601+
)
596602

597603
refund_lines.append(
598604
Refund(
@@ -614,7 +620,7 @@ async def generate_refund_lines(
614620
linked_datasource_type: str,
615621
datasource_id: str,
616622
datasource_name: str,
617-
):
623+
) -> list[str]:
618624
"""
619625
This function calculates the refund lines for a billing period
620626
"""
@@ -676,28 +682,41 @@ async def generate_refund_lines(
676682
idx += 1
677683
return charges
678684

679-
def get_trial_days(
685+
def get_trial_info(
680686
self,
681-
trial_start_date: date | None,
682-
trial_end_date: date | None,
683-
):
687+
trial_start_date: date,
688+
trial_end_date: date,
689+
) -> TrialInfo:
684690
# Trial period can start or end on month other than billing month period.
685691
# In this situation, we need to limit refunded expenses
686692
# to a period overlapping with billing month.
687693
# Example, billing month is June 1-30, a trial period is May 17 - June 17
688694
# We need to refund expenses from June 1st to June 17th
689-
if not trial_start_date and not trial_end_date:
690-
return None, None, None
691695
trial_refund_from = max(trial_start_date, self.billing_start_date.date())
692696
trial_refund_to = min(trial_end_date, self.billing_end_date.date())
693697
trial_days = {
694698
dt.date().day for dt in rrule(DAILY, dtstart=trial_refund_from, until=trial_refund_to)
695699
}
696-
return trial_days, trial_refund_from, trial_refund_to
700+
return TrialInfo(trial_days, trial_refund_from, trial_refund_to)
697701

698702
def get_entitlement_days(
699-
self, entitlement_start_date: str, entitlement_end_date: str, trial_days: set[int]
700-
):
703+
self,
704+
trial_days: set[int],
705+
entitlement_start_date: str,
706+
entitlement_end_date: str | None = None,
707+
) -> set[int]:
708+
"""
709+
Calculate the set of entitlement day numbers within a date range,
710+
excluding the given trial day numbers.
711+
712+
Args:
713+
entitlement_start_date (str): ISO format start date string.
714+
entitlement_end_date (Optional[str]): ISO format end date string, or None.
715+
trial_days (Optional[Set[int]]): Set of trial day numbers (1–31)
716+
717+
Returns:
718+
Set[int]: Set of day numbers (1–31) representing entitlement days.
719+
"""
701720
start_date = max(datetime.fromisoformat(entitlement_start_date), self.billing_start_date)
702721
end_date = min(
703722
datetime.fromisoformat(entitlement_end_date)
@@ -722,7 +741,7 @@ def generate_charge_line(
722741
price: Decimal,
723742
datasource_name: str,
724743
description: str = "",
725-
):
744+
) -> str:
726745
"""
727746
This function generates a charge line for a vendor and datasource.
728747
"""
@@ -762,7 +781,7 @@ def generate_charge_line(
762781
return f"{line}\n"
763782

764783

765-
def get_trial_dates(agreement: dict[str, Any]) -> tuple[Any | None, Decimal]:
784+
def get_trial_dates(agreement: dict[str, Any]) -> tuple[Any, Any]:
766785
trial_start = get_trial_start_date(agreement)
767786
trial_end = get_trial_end_date(agreement)
768787
return trial_start, trial_end

ffc/management/commands/process_billing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dateutil.relativedelta import relativedelta
55
from django.core.management.base import BaseCommand
66

7-
from ffc.process_billing import (
7+
from ffc.billing.process_billing import (
88
process_billing,
99
)
1010

pyproject.toml

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ pytest-randomly = "3.15.*"
4646
pytest-xdist = "3.5.*"
4747
responses = "0.24.*"
4848
ruff = "0.3.*"
49+
types-aiofiles = "^24.1.0.20250822"
4950
types-openpyxl = "3.1.*"
51+
types-python-dateutil = "^2.9.0.20250822"
5052
types-requests = "2.31.*"
5153

5254
[tool.poetry.group.sdk.dependencies] # Move to the SDK when splitting
@@ -106,21 +108,21 @@ line-length = 100
106108
[tool.ruff.lint]
107109

108110
select = [
109-
"E", # w errors
110-
"W", # pycodestyle warnings
111-
"F", # pyflakes
112-
"I", # isort
113-
"B", # flake8-bugbear
114-
"C4", # flake8-comprehensions
115-
"UP", # pyupgrade,
116-
"PT", # flake8-pytest-style
117-
"T10", # flake8-pytest-style
111+
"E", # w errors
112+
"W", # pycodestyle warnings
113+
"F", # pyflakes
114+
"I", # isort
115+
"B", # flake8-bugbear
116+
"C4", # flake8-comprehensions
117+
"UP", # pyupgrade,
118+
"PT", # flake8-pytest-style
119+
"T10", # flake8-pytest-style
118120
]
119121
ignore = [
120122
"PT004", # fixture '{name}' does not return anything, add leading underscore
121123
"PT011", # pytest.raises({exception}) is too broad, set the match parameter or use a more specific exception
122-
"B008", # do not perform function calls in argument defaults
123-
"B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling
124+
"B008", # do not perform function calls in argument defaults
125+
"B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling
124126
]
125127

126128
[tool.ruff.lint.isort]

tests/conftest.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import copy
22
from datetime import UTC, datetime
33
from decimal import Decimal
4+
from typing import Generator
45

6+
import httpx
57
import jwt
68
import pytest
79
import responses
810
from swo.mpt.extensions.runtime.djapp.conf import get_for_product
911

1012
from ffc.billing.dataclasses import AuthorizationProcessResult
11-
from ffc.process_billing import AuthorizationProcessor
13+
from ffc.billing.process_billing import AuthorizationProcessor
14+
from ffc.clients.base import BaseAsyncAPIClient
1215

1316

1417
@pytest.fixture()
@@ -2665,3 +2668,30 @@ def patch_fetch_organizations_agr_000(
26652668
"fetch_organizations",
26662669
return_value=org_mock_generator_agr_000,
26672670
)
2671+
2672+
2673+
class TestClientAuth(httpx.Auth):
2674+
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
2675+
request.headers["Authorization"] = "Bearer fake token"
2676+
yield request
2677+
2678+
2679+
class FakeAPIClient(BaseAsyncAPIClient):
2680+
@property
2681+
def base_url(self) -> str:
2682+
return "https://local.local/v1"
2683+
2684+
@property
2685+
def auth(self):
2686+
return TestClientAuth()
2687+
2688+
def get_pagination_meta(self, response):
2689+
return response["meta"]["pagination"]
2690+
2691+
def get_page_data(self, response):
2692+
return response["data"]
2693+
2694+
2695+
@pytest.fixture()
2696+
def fake_apiclient():
2697+
return FakeAPIClient(limit=2)

0 commit comments

Comments
 (0)