Skip to content

Commit 3b0b1f9

Browse files
tconley1428claude
andauthored
Add random seed access and callback methods to temporalio.workflow (#1320)
* Add random seed access and callback methods to temporalio.workflow This commit adds three new methods to temporalio.workflow to provide access to and control over the workflow's random seed: 1. workflow.random_seed() - Returns the current random seed value from core 2. workflow.register_random_seed_callback() - Registers callbacks for seed changes 3. workflow.new_random() - Creates an auto-reseeded Random instance These methods enable workflows to: - Access the current deterministic random seed - React to seed changes during workflow resets/replays - Create additional random number generators that stay synchronized The implementation includes: - Abstract methods in _Runtime class - Concrete implementation in _WorkflowInstanceImpl - Proper callback invocation during seed updates - Comprehensive test coverage with workflow reset scenarios - Full type safety and documentation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Skip test on timeskipping * Remove accidental CLAUDE file --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent a9769d7 commit 3b0b1f9

File tree

3 files changed

+188
-0
lines changed

3 files changed

+188
-0
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
272272
self._object: Any = None
273273
self._is_replaying: bool = False
274274
self._random = random.Random(det.randomness_seed)
275+
self._current_seed = det.randomness_seed
276+
self._seed_callbacks: list[Callable[[int], None]] = []
275277
self._read_only = False
276278
self._in_query_or_validator = False
277279

@@ -1075,6 +1077,14 @@ def _apply_update_random_seed(
10751077
self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
10761078
) -> None:
10771079
self._random.seed(job.randomness_seed)
1080+
self._current_seed = job.randomness_seed
1081+
# Notify all registered callbacks
1082+
for callback in self._seed_callbacks:
1083+
try:
1084+
callback(job.randomness_seed)
1085+
except Exception:
1086+
# Ignore callback errors to avoid disrupting workflow execution
1087+
pass
10781088

10791089
def _make_workflow_input(
10801090
self, init_job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow
@@ -1808,6 +1818,14 @@ def workflow_last_failure(self) -> BaseException | None:
18081818

18091819
return None
18101820

1821+
def workflow_random_seed(self) -> int:
1822+
return self._current_seed
1823+
1824+
def workflow_register_random_seed_callback(
1825+
self, callback: Callable[[int], None]
1826+
) -> None:
1827+
self._seed_callbacks.append(callback)
1828+
18111829
#### Calls from outbound impl ####
18121830
# These are in alphabetical order and all start with "_outbound_".
18131831

temporalio/workflow.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,14 @@ def workflow_last_completion_result(self, type_hint: type | None) -> Any | None:
902902
@abstractmethod
903903
def workflow_last_failure(self) -> BaseException | None: ...
904904

905+
@abstractmethod
906+
def workflow_random_seed(self) -> int: ...
907+
908+
@abstractmethod
909+
def workflow_register_random_seed_callback(
910+
self, callback: Callable[[int], None]
911+
) -> None: ...
912+
905913

906914
_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
907915
"__temporal_current_update_info"
@@ -1156,6 +1164,51 @@ def random() -> Random:
11561164
return _Runtime.current().workflow_random()
11571165

11581166

1167+
def random_seed() -> int:
1168+
"""Get the current random seed value from core.
1169+
1170+
This returns the seed value currently being used by the workflow's
1171+
deterministic random number generator.
1172+
1173+
Returns:
1174+
The current random seed as an integer.
1175+
"""
1176+
return _Runtime.current().workflow_random_seed()
1177+
1178+
1179+
def register_random_seed_callback(callback: Callable[[int], None]) -> None:
1180+
"""Register a callback to be notified when the random seed changes.
1181+
1182+
The callback will be invoked whenever the workflow receives a new random
1183+
seed from the core. This is useful for maintaining external random number
1184+
generators that need to stay in sync with the workflow's randomness.
1185+
1186+
Args:
1187+
callback: Function to be called with the new seed value when it changes.
1188+
"""
1189+
return _Runtime.current().workflow_register_random_seed_callback(callback)
1190+
1191+
1192+
def new_random() -> Random:
1193+
"""Create a Random instance that automatically reseeds when the workflow seed changes.
1194+
1195+
This creates a new Random instance that is initially seeded with the current
1196+
workflow seed, and automatically registers a callback to reseed itself
1197+
whenever the workflow receives a new seed from core.
1198+
1199+
Returns:
1200+
A Random instance that stays synchronized with the workflow's randomness.
1201+
"""
1202+
current_seed = random_seed()
1203+
auto_random = Random(current_seed)
1204+
1205+
def reseed_callback(new_seed: int) -> None:
1206+
auto_random.seed(new_seed)
1207+
1208+
register_random_seed_callback(reseed_callback)
1209+
return auto_random
1210+
1211+
11591212
def time() -> float:
11601213
"""Current seconds since the epoch from the workflow perspective.
11611214

tests/worker/test_workflow.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8485,3 +8485,120 @@ async def test_disable_logger_sandbox(
84858485
run_timeout=timedelta(seconds=1),
84868486
retry_policy=RetryPolicy(maximum_attempts=1),
84878487
)
8488+
8489+
8490+
@workflow.defn
8491+
class RandomSeedTestWorkflow:
8492+
def __init__(self) -> None:
8493+
self.seed_changes: list[int] = []
8494+
self.continue_signal_received = False
8495+
self._ready = False
8496+
8497+
@workflow.run
8498+
async def run(self) -> dict[str, Any]:
8499+
# Get the initial seed
8500+
initial_seed = workflow.random_seed()
8501+
8502+
# Register callback to track seed changes
8503+
workflow.register_random_seed_callback(self._on_seed_change)
8504+
8505+
# Create a new random instance that auto-reseeds
8506+
auto_random = workflow.new_random()
8507+
8508+
# Generate random values before waiting
8509+
auto_value1 = auto_random.randint(1, 1000000)
8510+
8511+
# Do an activity to give a reset point
8512+
await workflow.execute_activity(
8513+
say_hello,
8514+
"Hi",
8515+
schedule_to_close_timeout=timedelta(seconds=5),
8516+
)
8517+
8518+
self._ready = True
8519+
8520+
# Wait for signal to continue - this allows for workflow reset
8521+
await workflow.wait_condition(lambda: self.continue_signal_received)
8522+
8523+
# Generate more random values after reset might have occurred
8524+
auto_value2 = auto_random.randint(1, 1000000)
8525+
8526+
# Get final seed
8527+
final_seed = workflow.random_seed()
8528+
8529+
return {
8530+
"initial_seed": initial_seed,
8531+
"final_seed": final_seed,
8532+
"seed_changes": self.seed_changes.copy(),
8533+
"auto_values": [auto_value1, auto_value2],
8534+
}
8535+
8536+
def _on_seed_change(self, new_seed: int) -> None:
8537+
self.seed_changes.append(new_seed)
8538+
8539+
@workflow.signal
8540+
def continue_workflow(self) -> None:
8541+
self.continue_signal_received = True
8542+
8543+
@workflow.query
8544+
def ready(self) -> bool:
8545+
return self._ready
8546+
8547+
8548+
async def test_random_seed_functionality(
8549+
client: Client, worker: Worker, env: WorkflowEnvironment
8550+
):
8551+
if env.supports_time_skipping:
8552+
pytest.skip("Java test server doesn't support reset")
8553+
async with new_worker(
8554+
client, RandomSeedTestWorkflow, activities=[say_hello], max_cached_workflows=0
8555+
) as worker:
8556+
workflow_id = f"test-random-seed-{uuid.uuid4()}"
8557+
handle = await client.start_workflow(
8558+
RandomSeedTestWorkflow.run,
8559+
id=workflow_id,
8560+
task_queue=worker.task_queue,
8561+
)
8562+
8563+
# Let workflow generate some random values
8564+
# Wait for workflow to be ready
8565+
async def ready() -> bool:
8566+
return await handle.query(RandomSeedTestWorkflow.ready)
8567+
8568+
await assert_eq_eventually(True, ready)
8569+
8570+
# Reset workflow using raw gRPC call to trigger seed change
8571+
from temporalio.api.common.v1.message_pb2 import WorkflowExecution
8572+
from temporalio.api.enums.v1.reset_pb2 import ResetReapplyType
8573+
from temporalio.api.workflowservice.v1 import ResetWorkflowExecutionRequest
8574+
8575+
await client.workflow_service.reset_workflow_execution(
8576+
ResetWorkflowExecutionRequest(
8577+
namespace=client.namespace,
8578+
workflow_execution=WorkflowExecution(
8579+
workflow_id=handle.id,
8580+
run_id="",
8581+
),
8582+
reason="Test seed change",
8583+
reset_reapply_type=ResetReapplyType.RESET_REAPPLY_TYPE_UNSPECIFIED,
8584+
request_id=str(uuid.uuid4()),
8585+
workflow_task_finish_event_id=9, # Reset to after activity completion
8586+
)
8587+
)
8588+
8589+
# Get handle to the reset workflow using the new run ID
8590+
reset_handle = client.get_workflow_handle(
8591+
workflow_id,
8592+
)
8593+
8594+
# Continue the workflow
8595+
await reset_handle.signal(RandomSeedTestWorkflow.continue_workflow)
8596+
8597+
result = await reset_handle.result()
8598+
8599+
# Verify basic functionality
8600+
assert isinstance(result["initial_seed"], int)
8601+
assert isinstance(result["final_seed"], int)
8602+
assert isinstance(result["seed_changes"], list)
8603+
assert len(result["auto_values"]) == 2
8604+
assert len(result["seed_changes"]) == 1

0 commit comments

Comments
 (0)