11from collections .abc import Iterable , Sequence
22from datetime import UTC , datetime
33
4- from sqlalchemy import delete , func
4+ from sqlalchemy import delete , func , not_
55from sqlalchemy .dialects .mysql import insert
66from sqlalchemy .ext .asyncio import AsyncConnection
77from sqlmodel import col , select
@@ -459,17 +459,25 @@ async def get_current_organization_allocations(
459459 * ,
460460 exclude_branch_ids : Sequence [Identifier ] | None = None ,
461461) -> dict [ResourceType , int ]:
462- result = await session .execute (
463- select (BranchProvisioning ).join (Branch ).join (Project ).where (Project .organization_id == organization_id )
462+ status_column = col (Branch .status )
463+ branch_id_column = col (BranchProvisioning .branch_id )
464+
465+ stmt = (
466+ select (BranchProvisioning )
467+ .join (Branch )
468+ .join (Project )
469+ .where (
470+ Project .organization_id == organization_id ,
471+ not_ (status_column .in_ ([BranchServiceStatus .STOPPED , BranchServiceStatus .DELETING ])),
472+ )
464473 )
465- rows = list (result .scalars ().all ())
466474 if exclude_branch_ids :
467- excluded = set (exclude_branch_ids )
468- rows = [row for row in rows if row .branch_id not in excluded ]
475+ stmt = stmt .where (not_ (branch_id_column .in_ (set (exclude_branch_ids ))))
469476
477+ result = await session .execute (stmt )
478+ rows = list (result .scalars ().all ())
470479 grouped = _group_by_resource_type (rows )
471- branch_statuses = await _collect_branch_statuses (session , rows )
472- return _aggregate_group_by_resource_type (grouped , branch_statuses )
480+ return _aggregate_group_by_resource_type (grouped )
473481
474482
475483async def get_current_project_allocations (
@@ -478,49 +486,33 @@ async def get_current_project_allocations(
478486 * ,
479487 exclude_branch_ids : Sequence [Identifier ] | None = None ,
480488) -> dict [ResourceType , int ]:
481- result = await session .execute (select (BranchProvisioning ).join (Branch ).where (Branch .project_id == project_id ))
482- rows = list (result .scalars ().all ())
489+ status_column = col (Branch .status )
490+ branch_id_column = col (BranchProvisioning .branch_id )
491+
492+ stmt = (
493+ select (BranchProvisioning )
494+ .join (Branch )
495+ .where (
496+ Branch .project_id == project_id ,
497+ not_ (status_column .in_ ([BranchServiceStatus .STOPPED , BranchServiceStatus .DELETING ])),
498+ )
499+ )
483500 if exclude_branch_ids :
484- excluded = set (exclude_branch_ids )
485- rows = [row for row in rows if row .branch_id not in excluded ]
501+ stmt = stmt .where (not_ (branch_id_column .in_ (set (exclude_branch_ids ))))
486502
503+ result = await session .execute (stmt )
504+ rows = list (result .scalars ().all ())
487505 grouped = _group_by_resource_type (rows )
488- branch_statuses = await _collect_branch_statuses (session , rows )
489- return _aggregate_group_by_resource_type (grouped , branch_statuses )
506+ return _aggregate_group_by_resource_type (grouped )
490507
491508
492- def _aggregate_group_by_resource_type (
493- grouped : dict [ResourceType , list [BranchProvisioning ]], branch_statuses : dict [Identifier , BranchServiceStatus ]
494- ) -> dict [ResourceType , int ]:
509+ def _aggregate_group_by_resource_type (grouped : dict [ResourceType , list [BranchProvisioning ]]) -> dict [ResourceType , int ]:
495510 return {
496- resource_type : sum (
497- allocation .amount
498- for allocation in allocations
499- if (allocation .branch_id is not None )
500- and (
501- branch_statuses .get (allocation .branch_id )
502- not in {BranchServiceStatus .STOPPED , BranchServiceStatus .DELETING }
503- )
504- )
511+ resource_type : sum (allocation .amount for allocation in allocations if allocation .branch_id is not None )
505512 for resource_type , allocations in grouped .items ()
506513 }
507514
508515
509- async def _collect_branch_statuses (
510- _session : SessionDep , rows : list [BranchProvisioning ]
511- ) -> dict [Identifier , BranchServiceStatus ]:
512- branch_ids = {row .branch_id for row in rows if row .branch_id is not None }
513- if not branch_ids :
514- return {}
515-
516- from ..organization .project import branch as branch_module
517-
518- statuses : dict [Identifier , BranchServiceStatus ] = {}
519- for branch_id in branch_ids :
520- statuses [branch_id ] = await branch_module .refresh_branch_status (branch_id )
521- return statuses
522-
523-
524516def _group_by_resource_type (allocations : list [BranchProvisioning ]) -> dict [ResourceType , list [BranchProvisioning ]]:
525517 result : dict [ResourceType , list [BranchProvisioning ]] = {}
526518 for allocation in allocations :
0 commit comments