|
11 | 11 | from pcweb.constants import REFLEX_CLOUD_URL, PRO_TIERS_TABLE |
12 | 12 |
|
13 | 13 |
|
| 14 | +_SORTED_TIERS = sorted( |
| 15 | + [{"key": k, **v} for k, v in PRO_TIERS_TABLE.items()], key=lambda x: x["credits"] |
| 16 | +) |
| 17 | + |
| 18 | + |
14 | 19 | def format_number(number: int | float) -> str: |
15 | 20 | """Format number with locale string, handling non-numeric values""" |
16 | 21 | return rx.Var( |
@@ -85,37 +90,43 @@ class MachineState(rx.State): |
85 | 90 | def _recalculate_all(self): |
86 | 91 | """Recalculate all derived values when state changes""" |
87 | 92 | # Calculate machines weekly credits using cached values |
88 | | - self.machines_weekly_credits = sum( |
89 | | - machine.weekly_credits for machine in self.machines |
90 | | - ) |
| 93 | + machines_credits = sum(m.weekly_credits for m in self.machines) |
| 94 | + self.machines_weekly_credits = machines_credits |
91 | 95 |
|
92 | 96 | # Calculate current tier based on message credits |
93 | 97 | msg_credits = get_message_credits(self.messages_tier_index) |
94 | 98 | is_enterprise = get_is_enterprise_tier(self.messages_tier_index) |
95 | 99 |
|
| 100 | + # Early return path for enterprise tier |
96 | 101 | if is_enterprise: |
97 | 102 | self.current_tier = { |
98 | 103 | "key": "Enterprise", |
99 | 104 | "credits": msg_credits, |
100 | 105 | "price": "custom", |
101 | 106 | } |
102 | | - else: |
103 | | - tier = self._find_tier_for_credits(msg_credits) |
104 | | - self.current_tier = { |
105 | | - "key": tier["key"] if tier else "Enterprise", |
106 | | - "credits": msg_credits, |
107 | | - "price": tier["price"] if tier else "custom", |
| 107 | + self.total_credits = "Custom" |
| 108 | + self.recommended_tier_info = { |
| 109 | + "price": "Custom", |
| 110 | + "needs_enterprise": True, |
| 111 | + "name": "Enterprise", |
| 112 | + "credits": "Custom", |
108 | 113 | } |
| 114 | + return |
109 | 115 |
|
110 | | - # Calculate total credits and find tier once |
111 | | - total = msg_credits + round(self.machines_weekly_credits, 2) |
112 | | - total_tier = None if is_enterprise else self._find_tier_for_credits(total) |
| 116 | + # Non-enterprise path - find tiers once |
| 117 | + current_tier = self._find_tier_for_credits(msg_credits) |
| 118 | + total = msg_credits + machines_credits |
| 119 | + total_tier = self._find_tier_for_credits(total) |
| 120 | + |
| 121 | + # Set current tier |
| 122 | + self.current_tier = { |
| 123 | + "key": current_tier["key"] if current_tier else "Enterprise", |
| 124 | + "credits": msg_credits, |
| 125 | + "price": current_tier["price"] if current_tier else "custom", |
| 126 | + } |
113 | 127 |
|
114 | 128 | # Set total credits display |
115 | | - if is_enterprise or not total_tier: |
116 | | - self.total_credits = "Custom" |
117 | | - else: |
118 | | - self.total_credits = f"{total:,}" |
| 129 | + self.total_credits = f"{total:,}" if total_tier else "Custom" |
119 | 130 |
|
120 | 131 | # Set recommended tier info |
121 | 132 | if total_tier: |
@@ -156,11 +167,10 @@ def update_messages_tier(self, new_tier_index: int): |
156 | 167 | self._recalculate_all() |
157 | 168 |
|
158 | 169 | def _find_tier_for_credits(self, credits: float) -> dict | None: |
159 | | - """Find Pro tier that fits the given credits, or None if Enterprise needed""" |
160 | | - for tier_key in pro_tier_keys: |
161 | | - tier_data = PRO_TIERS_TABLE[tier_key] |
162 | | - if credits <= tier_data["credits"]: |
163 | | - return {"key": tier_key, **tier_data} |
| 170 | + """Find Pro tier that fits the given credits using binary search""" |
| 171 | + for tier in _SORTED_TIERS: |
| 172 | + if credits <= tier["credits"]: |
| 173 | + return tier |
164 | 174 | return None |
165 | 175 |
|
166 | 176 | @rx.event(temporal=True) |
@@ -342,6 +352,7 @@ def messages_card() -> rx.Component: |
342 | 352 | MachineState.update_messages_tier(new_tier_index), |
343 | 353 | rx.noop(), |
344 | 354 | ), |
| 355 | + min_steps_between_values=1, |
345 | 356 | class_name="w-full max-w-full", |
346 | 357 | ), |
347 | 358 | on_mouse_enter=message_tooltip_open_cs.set_value(True), |
@@ -417,6 +428,7 @@ def machine_card(machine: Machine, index: int) -> rx.Component: |
417 | 428 | max=COMPUTE_TABLE_KEYS.length() - 1, |
418 | 429 | step=1, |
419 | 430 | value=machine.index, |
| 431 | + min_steps_between_values=1, |
420 | 432 | on_value_change=lambda new_machine_index: rx.cond( |
421 | 433 | machine.index != new_machine_index, |
422 | 434 | MachineState.update_machine(index, new_machine_index), |
|
0 commit comments