Skip to content

Commit ae12ecb

Browse files
committed
allocation checks: check branch.status instead of making call to k8s
1 parent 8cdf90e commit ae12ecb

File tree

1 file changed

+33
-41
lines changed

1 file changed

+33
-41
lines changed

src/api/_util/resourcelimit.py

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Iterable, Sequence
22
from datetime import UTC, datetime
33

4-
from sqlalchemy import delete, func
4+
from sqlalchemy import delete, func, not_
55
from sqlalchemy.dialects.mysql import insert
66
from sqlalchemy.ext.asyncio import AsyncConnection
77
from sqlmodel import col, select
@@ -459,17 +459,25 @@ async def get_current_organization_allocations(
459459
*,
460460
exclude_branch_ids: Sequence[Identifier] | None = None,
461461
) -> dict[ResourceType, int]:
462-
result = await session.execute(
463-
select(BranchProvisioning).join(Branch).join(Project).where(Project.organization_id == organization_id)
462+
status_column = col(Branch.status)
463+
branch_id_column = col(BranchProvisioning.branch_id)
464+
465+
stmt = (
466+
select(BranchProvisioning)
467+
.join(Branch)
468+
.join(Project)
469+
.where(
470+
Project.organization_id == organization_id,
471+
not_(status_column.in_([BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING])),
472+
)
464473
)
465-
rows = list(result.scalars().all())
466474
if exclude_branch_ids:
467-
excluded = set(exclude_branch_ids)
468-
rows = [row for row in rows if row.branch_id not in excluded]
475+
stmt = stmt.where(not_(branch_id_column.in_(set(exclude_branch_ids))))
469476

477+
result = await session.execute(stmt)
478+
rows = list(result.scalars().all())
470479
grouped = _group_by_resource_type(rows)
471-
branch_statuses = await _collect_branch_statuses(session, rows)
472-
return _aggregate_group_by_resource_type(grouped, branch_statuses)
480+
return _aggregate_group_by_resource_type(grouped)
473481

474482

475483
async def get_current_project_allocations(
@@ -478,49 +486,33 @@ async def get_current_project_allocations(
478486
*,
479487
exclude_branch_ids: Sequence[Identifier] | None = None,
480488
) -> dict[ResourceType, int]:
481-
result = await session.execute(select(BranchProvisioning).join(Branch).where(Branch.project_id == project_id))
482-
rows = list(result.scalars().all())
489+
status_column = col(Branch.status)
490+
branch_id_column = col(BranchProvisioning.branch_id)
491+
492+
stmt = (
493+
select(BranchProvisioning)
494+
.join(Branch)
495+
.where(
496+
Branch.project_id == project_id,
497+
not_(status_column.in_([BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING])),
498+
)
499+
)
483500
if exclude_branch_ids:
484-
excluded = set(exclude_branch_ids)
485-
rows = [row for row in rows if row.branch_id not in excluded]
501+
stmt = stmt.where(not_(branch_id_column.in_(set(exclude_branch_ids))))
486502

503+
result = await session.execute(stmt)
504+
rows = list(result.scalars().all())
487505
grouped = _group_by_resource_type(rows)
488-
branch_statuses = await _collect_branch_statuses(session, rows)
489-
return _aggregate_group_by_resource_type(grouped, branch_statuses)
506+
return _aggregate_group_by_resource_type(grouped)
490507

491508

492-
def _aggregate_group_by_resource_type(
493-
grouped: dict[ResourceType, list[BranchProvisioning]], branch_statuses: dict[Identifier, BranchServiceStatus]
494-
) -> dict[ResourceType, int]:
509+
def _aggregate_group_by_resource_type(grouped: dict[ResourceType, list[BranchProvisioning]]) -> dict[ResourceType, int]:
495510
return {
496-
resource_type: sum(
497-
allocation.amount
498-
for allocation in allocations
499-
if (allocation.branch_id is not None)
500-
and (
501-
branch_statuses.get(allocation.branch_id)
502-
not in {BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING}
503-
)
504-
)
511+
resource_type: sum(allocation.amount for allocation in allocations if allocation.branch_id is not None)
505512
for resource_type, allocations in grouped.items()
506513
}
507514

508515

509-
async def _collect_branch_statuses(
510-
_session: SessionDep, rows: list[BranchProvisioning]
511-
) -> dict[Identifier, BranchServiceStatus]:
512-
branch_ids = {row.branch_id for row in rows if row.branch_id is not None}
513-
if not branch_ids:
514-
return {}
515-
516-
from ..organization.project import branch as branch_module
517-
518-
statuses: dict[Identifier, BranchServiceStatus] = {}
519-
for branch_id in branch_ids:
520-
statuses[branch_id] = await branch_module.refresh_branch_status(branch_id)
521-
return statuses
522-
523-
524516
def _group_by_resource_type(allocations: list[BranchProvisioning]) -> dict[ResourceType, list[BranchProvisioning]]:
525517
result: dict[ResourceType, list[BranchProvisioning]] = {}
526518
for allocation in allocations:

0 commit comments

Comments
 (0)