77from datetime import UTC , date , datetime
88from io import BytesIO
99from pathlib import Path
10- from typing import Any
10+ from typing import Any , AsyncGenerator , Coroutine
1111
1212import aiofiles
1313import aiofiles .os
2222 CurrencyConversionInfo ,
2323 Datasource ,
2424 Refund ,
25+ TrialInfo ,
2526)
2627from ffc .billing .exceptions import ExchangeRatesClientError , JournalStatusError
2728from ffc .clients .exchage_rates import ExchangeRatesAsyncClient
@@ -112,7 +113,7 @@ def __init__(
112113 )
113114
114115 @asynccontextmanager
115- async def acquire_semaphore (self ):
116+ async def acquire_semaphore (self ) -> AsyncGenerator [ None , Any ] :
116117 """
117118 This method acquires and releases a semaphore.
118119 """
@@ -129,7 +130,7 @@ async def maybe_call(
129130 func ,
130131 * args ,
131132 ** kwargs ,
132- ):
133+ ) -> Any | None :
133134 """
134135 Conditionally calls and awaits the given asynchronous function.
135136
@@ -149,7 +150,7 @@ async def maybe_call(
149150 if not self .dry_run :
150151 return await func (* args , ** kwargs )
151152
152- def build_filepath (self ):
153+ def build_filepath (self ) -> str :
153154 """
154155 Constructs and returns the file path for a charges JSONL file.
155156
@@ -166,7 +167,7 @@ def build_filepath(self):
166167 filepath = f"{ tempfile .gettempdir ()} /{ filepath } " if not self .dry_run else filepath
167168 return filepath
168169
169- async def evaluate_journal_status (self , journal_external_id ):
170+ async def evaluate_journal_status (self , journal_external_id ) -> dict [ str , Any ] | None :
170171 """
171172 Evaluates the status of a journal.
172173
@@ -195,7 +196,7 @@ async def evaluate_journal_status(self, journal_external_id):
195196 self .logger .warning (f"Found the journal { journal_id } with status { journal_status } " )
196197 raise JournalStatusError ()
197198
198- async def process (self ):
199+ async def process (self ) -> None :
199200 """
200201 This method is responsible for passing a journal with status VALIDATED to
201202 the function that writes the charges into a file and then to complete the
@@ -204,7 +205,6 @@ async def process(self):
204205 result = AuthorizationProcessResult (authorization_id = self .authorization_id )
205206 async with self .acquire_semaphore ():
206207 try :
207- # double check with production
208208 if not await self .mpt_client .count_active_agreements (
209209 self .authorization_id ,
210210 self .billing_start_date ,
@@ -248,7 +248,7 @@ async def process(self):
248248 except Exception as error :
249249 self .logger .error (f"An error occurred: { error } " , exc_info = error )
250250
251- async def write_charges_file (self , filepath ):
251+ async def write_charges_file (self , filepath ) -> bool :
252252 """
253253 This method writes the charges file to the given filepath.
254254 If there is more than one agreement for an organization, it won't be processed.
@@ -308,7 +308,9 @@ async def write_charges_file(self, filepath):
308308 return False
309309 return True
310310
311- async def complete_journal_process (self , filepath , journal , journal_external_id ):
311+ async def complete_journal_process (
312+ self , filepath , journal , journal_external_id
313+ ) -> Coroutine | None :
312314 """
313315 This method uploads and submits the given journal, attaching also the exchange rates
314316 files.
@@ -377,7 +379,7 @@ async def get_currency_conversion_info(
377379
378380 async def attach_exchange_rates (
379381 self , journal_id : str , currency : str , exchange_rates : dict [str , Any ]
380- ):
382+ ) -> Coroutine | None :
381383 """
382384 This method checks if an attachment already exists for the given journal.
383385 If it exists, it will be deleted and a new one will be created with the
@@ -392,7 +394,7 @@ async def attach_exchange_rates(
392394 attachment = await self .mpt_client .fetch_journal_attachment (journal_id , f"{ currency } _" )
393395 if attachment : # pragma no cover
394396 if attachment ["name" ] == filename :
395- return
397+ return None
396398 await self .mpt_client .delete_journal_attachment (journal_id , attachment ["id" ])
397399
398400 return await self .mpt_client .create_journal_attachment (journal_id , filename , serialized )
@@ -402,7 +404,7 @@ async def dump_organization_charges(
402404 charges_file : Any ,
403405 organization : dict [str , Any ],
404406 agreement : dict [str , Any ] | None = None ,
405- ):
407+ ) -> Coroutine | None :
406408 organization_id = organization ["id" ]
407409 async for datasource_info , expenses in async_groupby (
408410 self .ffc_client .fetch_organization_expenses (organization_id , self .year , self .month ),
@@ -441,7 +443,7 @@ async def generate_datasource_charges(
441443 datasource_id : str ,
442444 datasource_name : str ,
443445 daily_expenses : dict [int , Decimal ],
444- ):
446+ ) -> list [ str ] :
445447 """
446448 This method generates all the charges for the given datasource and
447449 calculates the refund for the Trials and Entitlements periods.
@@ -539,7 +541,7 @@ async def generate_datasource_charges(
539541 )
540542 return charges
541543
542- async def is_journal_status_validated (self , journal_id , max_attempts = 5 ):
544+ async def is_journal_status_validated (self , journal_id , max_attempts = 5 ) -> bool :
543545 backoff_times = [0.15 , 0.45 , 1.05 , 2.25 , 4.65 ]
544546
545547 for attempt in range (min (max_attempts , len (backoff_times ))):
@@ -566,17 +568,19 @@ def generate_refunds(
566568 trial_days = set ()
567569 trial_start_date , trial_end_date = get_trial_dates (agreement = agreement )
568570 if trial_start_date and trial_end_date :
569- trial_days , trial_refund_from , trial_refund_to = self .get_trial_days (
571+ trial_info = self .get_trial_days (
570572 trial_start_date ,
571573 trial_end_date ,
572574 )
573- trial_amount = sum (daily_expenses .get (day , Decimal ("0" )) for day in trial_days )
574-
575+ trial_amount = sum (
576+ daily_expenses .get (day , Decimal ("0" )) for day in trial_info .trial_days
577+ )
578+ trial_days = trial_info .trial_days
575579 refund_lines .append (
576580 Refund (
577581 trial_amount ,
578- trial_refund_from ,
579- trial_refund_to ,
582+ trial_info . refund_from ,
583+ trial_info . refund_to ,
580584 (
581585 "Refund due to trial period "
582586 f"(from { trial_start_date .strftime ("%d %b %Y" )} " # type: ignore
@@ -614,7 +618,7 @@ async def generate_refund_lines(
614618 linked_datasource_type : str ,
615619 datasource_id : str ,
616620 datasource_name : str ,
617- ):
621+ ) -> list [ str ] :
618622 """
619623 This function calculates the refund lines for a billing period
620624 """
@@ -680,24 +684,39 @@ def get_trial_days(
680684 self ,
681685 trial_start_date : date | None ,
682686 trial_end_date : date | None ,
683- ):
687+ ) -> TrialInfo :
684688 # Trial period can start or end on month other than billing month period.
685689 # In this situation, we need to limit refunded expenses
686690 # to a period overlapping with billing month.
687691 # Example, billing month is June 1-30, a trial period is May 17 - June 17
688692 # We need to refund expenses from June 1st to June 17th
689693 if not trial_start_date and not trial_end_date :
690- return None , None , None
694+ return TrialInfo ()
691695 trial_refund_from = max (trial_start_date , self .billing_start_date .date ())
692696 trial_refund_to = min (trial_end_date , self .billing_end_date .date ())
693697 trial_days = {
694698 dt .date ().day for dt in rrule (DAILY , dtstart = trial_refund_from , until = trial_refund_to )
695699 }
696- return trial_days , trial_refund_from , trial_refund_to
700+ return TrialInfo ( trial_days , trial_refund_from , trial_refund_to )
697701
698702 def get_entitlement_days (
699- self , entitlement_start_date : str , entitlement_end_date : str , trial_days : set [int ]
700- ):
703+ self ,
704+ entitlement_start_date : str ,
705+ entitlement_end_date : str | None = None ,
706+ trial_days : set [int ] | None = None ,
707+ ) -> set [int ]:
708+ """
709+ Calculate the set of entitlement day numbers within a date range,
710+ excluding the given trial day numbers.
711+
712+ Args:
713+ entitlement_start_date (str): ISO format start date string.
714+ entitlement_end_date (Optional[str]): ISO format end date string, or None.
715+ trial_days (Optional[Set[int]]): Set of trial day numbers (1–31)
716+
717+ Returns:
718+ Set[int]: Set of day numbers (1–31) representing entitlement days.
719+ """
701720 start_date = max (datetime .fromisoformat (entitlement_start_date ), self .billing_start_date )
702721 end_date = min (
703722 datetime .fromisoformat (entitlement_end_date )
@@ -722,7 +741,7 @@ def generate_charge_line(
722741 price : Decimal ,
723742 datasource_name : str ,
724743 description : str = "" ,
725- ):
744+ ) -> str :
726745 """
727746 This function generates a charge line for a vendor and datasource.
728747 """
0 commit comments