Skip to content

Commit c50f40a

Browse files
authored
Merge pull request #72 from tiqi-group/7-hardware-worker-prevent-processing-stale-json-after-parameter-updates
7 hardware worker prevent processing stale json after parameter updates
2 parents bb73494 + 66bc312 commit c50f40a

File tree

6 files changed

+120
-9
lines changed

6 files changed

+120
-9
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""JobRun: add parameter_update_timestamp column
2+
3+
Revision ID: b18002636885
4+
Revises: fc9af856df20
5+
Create Date: 2025-11-20 10:28:24.835587
6+
7+
"""
8+
9+
from collections.abc import Sequence
10+
11+
from alembic import op
12+
import sqlalchemy as sa
13+
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = 'b18002636885'
17+
down_revision: str | None = 'fc9af856df20'
18+
branch_labels: str | Sequence[str] | None = None
19+
depends_on: str | Sequence[str] | None = None
20+
21+
22+
def upgrade() -> None:
    """Apply the migration: add a nullable ``parameter_update_timestamp`` column.

    Batch mode is used so the ALTER TABLE also works on SQLite, which
    cannot alter tables in place.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    new_column = sa.Column(
        'parameter_update_timestamp',
        sa.TIMESTAMP(timezone=True),
        nullable=True,
    )
    with op.batch_alter_table('job_runs', schema=None) as batch:
        batch.add_column(new_column)
    # ### end Alembic commands ###
28+
29+
30+
def downgrade() -> None:
    """Revert the migration: drop ``parameter_update_timestamp`` from ``job_runs``."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('job_runs', schema=None) as batch:
        batch.drop_column('parameter_update_timestamp')
    # ### end Alembic commands ###

src/icon/server/data_access/models/sqlite/job_run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ class JobRun(Base):
6565
log: sqlalchemy.orm.Mapped[str | None] = sqlalchemy.orm.mapped_column(default=None)
6666
"""Optional log message for this run (e.g., cancellation reason)."""
6767

68+
parameter_update_timestamp: sqlalchemy.orm.Mapped[datetime.datetime | None] = (
69+
sqlalchemy.orm.mapped_column(default=None)
70+
)
71+
"""Timestamp of the last parameter update."""
72+
6873
def __repr__(self) -> str:
6974
return (
7075
f"<JobRun id={self.id} "

src/icon/server/data_access/repositories/job_run_repository.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from collections.abc import Sequence
3-
from datetime import datetime
3+
from datetime import UTC, datetime
44

55
import sqlalchemy.orm
66
from sqlalchemy import select, update
@@ -185,3 +185,43 @@ def get_scheduled_time_by_job_id(*, job_id: int) -> datetime:
185185
scheduled_time = session.execute(stmt).scalar_one()
186186
logger.debug("Got scheduled time for job_id %s", job_id)
187187
return scheduled_time
188+
189+
@staticmethod
def set_parameter_update_timestamp(*, run_id: int, timestamp: datetime) -> None:
    """Set the parameter update timestamp of a job run.

    The timestamp is normalised to UTC before being persisted, so values
    read back can safely be re-interpreted as UTC (see
    ``get_parameter_update_timestamp``).

    Args:
        run_id: ID of the job run to update.
        timestamp: New parameter update timestamp.
    """

    with sqlalchemy.orm.Session(engine) as session:
        stmt = (
            update(JobRun)
            .where(JobRun.id == run_id)
            .values(parameter_update_timestamp=timestamp.astimezone(UTC))
            .returning(JobRun)
        )

        # scalar_one() raises if the run does not exist, surfacing bad
        # run_ids instead of silently updating nothing.
        run = session.execute(stmt).scalar_one()
        session.commit()

        logger.debug("Updated parameter update timestamp for run %s", run)
210+
211+
@staticmethod
def get_parameter_update_timestamp(*, run_id: int) -> datetime:
    """Get the parameter update timestamp of a job run.

    Args:
        run_id: ID of the job run.

    Returns:
        The parameter update timestamp as a UTC-aware datetime.
    """

    with sqlalchemy.orm.Session(engine) as session:
        stmt = select(JobRun.parameter_update_timestamp).where(JobRun.id == run_id)

        timestamp = session.execute(stmt).scalar_one()
        logger.debug("Got parameter update timestamp for run %s", run_id)

        # The database stores naive datetimes; writes normalise to UTC, so
        # attaching UTC here restores the original instant.
        # NOTE(review): the column is nullable — if no parameter update was
        # ever recorded this raises AttributeError on None. Confirm callers
        # always set the timestamp before reading it.
        return timestamp.replace(tzinfo=UTC)

src/icon/server/hardware_processing/task.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from datetime import datetime
5-
from queue import Queue
5+
from queue import PriorityQueue, Queue
66
from typing import TYPE_CHECKING, Any
77

88
import pydantic
@@ -22,15 +22,16 @@ class HardwareProcessingTask(pydantic.BaseModel):
2222
sequence_json: str
2323
src_dir: str
2424
created: datetime
25-
2625
if TYPE_CHECKING:
2726
processed_data_points: Queue[HardwareProcessingTask]
2827
data_points_to_process: Queue[tuple[int, dict[str, DatabaseValueType]]]
28+
outdated_tasks: PriorityQueue[HardwareProcessingTask]
2929
else:
3030
# must be Any as the queues are AutoProxy instances, which I didn't figure out
3131
# how to type
3232
processed_data_points: Any
3333
data_points_to_process: Any
34+
outdated_tasks: Any
3435

3536
def __lt__(self, other: HardwareProcessingTask) -> bool:
3637
return self.priority < other.priority or self.created < other.created

src/icon/server/hardware_processing/worker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,14 @@ def run(self) -> None:
160160
):
161161
continue
162162

163+
parameter_update_timestamp = (
164+
JobRunRepository.get_parameter_update_timestamp(
165+
run_id=task.pre_processing_task.job_run.id,
166+
)
167+
)
168+
if task.created < parameter_update_timestamp:
169+
task.outdated_tasks.put(task)
170+
continue
163171
try:
164172
self._set_pydase_service_values(scanned_params=task.scanned_params)
165173

src/icon/server/pre_processing/worker.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from icon.server.hardware_processing.task import HardwareProcessingTask
4242

4343
if TYPE_CHECKING:
44-
from collections.abc import Iterator
44+
from collections.abc import Iterable, Iterator
4545

4646
from icon.server.data_access.models.sqlite.job import Job
4747
from icon.server.pre_processing.task import PreProcessingTask
@@ -192,6 +192,9 @@ def __init__(
192192
]
193193
self._processed_data_points: queue.Queue[HardwareProcessingTask]
194194
self._parameter_dict: dict[str, DatabaseValueType] = {}
195+
self._outdated_tasks: queue.PriorityQueue[HardwareProcessingTask] = (
196+
manager.PriorityQueue()
197+
)
195198

196199
def run(self) -> None:
197200
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -289,14 +292,17 @@ def _process_task(
289292
readout_metadata=readout_metadata,
290293
)
291294

292-
if contains_realtime_parameter(job.scan_parameters):
295+
jobs = (
293296
self._handle_realtime_scan(
294297
pre_processing_task, src_dir=src_dir, namespace=namespace
295298
)
296-
else:
297-
self._handle_regular_scan(
299+
if contains_realtime_parameter(job.scan_parameters)
300+
else self._handle_regular_scan(
298301
pre_processing_task, src_dir=src_dir, namespace=namespace
299302
)
303+
)
304+
for _ in jobs:
305+
self._regenerate_outdated_jobs(namespace)
300306

301307
def _update_parameter_dict(
302308
self,
@@ -321,6 +327,10 @@ def _update_parameter_dict(
321327

322328
self._global_parameter_timestamp = datetime.now(timezone)
323329

330+
JobRunRepository.set_parameter_update_timestamp(
331+
run_id=pre_processing_task.job_run.id,
332+
timestamp=self._global_parameter_timestamp,
333+
)
324334
if mode == ParamUpdateMode.ONLY_NEW_PARAMETERS:
325335
if new_parameters:
326336
self._parameter_dict.update(new_parameters)
@@ -396,7 +406,7 @@ def _handle_regular_scan(
396406
pre_processing_task: PreProcessingTask,
397407
namespace: ExperimentIdentifier,
398408
src_dir: str,
399-
) -> None:
409+
) -> Iterable[None]:
400410
scan_parameter_value_combinations = get_scan_combinations(
401411
pre_processing_task.job
402412
)
@@ -421,6 +431,7 @@ def _handle_regular_scan(
421431
):
422432
break
423433

434+
yield
424435
self._submit_task_to_hw_worker(
425436
task=self._create_hardware_task(
426437
pre_processing_task=pre_processing_task,
@@ -454,15 +465,25 @@ def _create_hardware_task(
454465
sequence_json=sequence_json,
455466
processed_data_points=self._processed_data_points,
456467
data_points_to_process=self._data_points_to_process,
468+
outdated_tasks=self._outdated_tasks,
457469
created=datetime.now(timezone),
458470
)
459471

472+
def _regenerate_outdated_jobs(self, namespace: ExperimentIdentifier) -> None:
    """Re-build and resubmit every task queued as outdated.

    Each task drained from the outdated-tasks queue gets a fresh sequence
    JSON rendered from the current parameter dictionary (overlaid with the
    task's own scanned parameters) before being handed back to the
    hardware worker.
    """
    for stale_task in consume_queue(self._outdated_tasks):
        merged_params = {**self._parameter_dict, **stale_task.scanned_params}
        stale_task.sequence_json = generate_sequence_json(
            n_shots=stale_task.pre_processing_task.job.number_of_shots,
            parameter_dict=merged_params,
            namespace=namespace,
        )
        self._submit_task_to_hw_worker(task=stale_task)
480+
460481
def _handle_realtime_scan(
461482
self,
462483
pre_processing_task: PreProcessingTask,
463484
namespace: ExperimentIdentifier,
464485
src_dir: str,
465-
) -> None:
486+
) -> Iterable[None]:
466487
params = pre_processing_task.job.scan_parameters
467488
realtime_param = next(p for p in params if p.realtime)
468489
n_scan_values = len(realtime_param.scan_values)
@@ -507,6 +528,7 @@ def _handle_realtime_scan(
507528
hardware_tasks[frozen_data_point] = hardware_task
508529
hardware_task.created = datetime.now(timezone)
509530
hardware_task.data_point_index = index
531+
yield
510532
self._submit_task_to_hw_worker(task=hardware_task)
511533

512534

0 commit comments

Comments
 (0)