From 5a2d176df9478037f4456d0961b6ed2446b22d66 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 17:44:18 -0700 Subject: [PATCH 001/193] Permit emitting pydantic inputs in snapshot tests --- tests/codegen/snapshot/codegen_snapshot_fixtures.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/codegen/snapshot/codegen_snapshot_fixtures.py index ef74a1fb..2cf43011 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ b/tests/codegen/snapshot/codegen_snapshot_fixtures.py @@ -18,6 +18,7 @@ def validate_codegen( read_schema: Callable[[], TextIO], target_path: str, client_name: str, + typeddict_inputs: bool = True, ) -> None: snapshot.snapshot_dir = "tests/codegen/snapshot/snapshots" files: dict[Path, UnclosableStringIO] = {} @@ -33,7 +34,7 @@ def file_opener(path: Path) -> TextIO: target_path=target_path, client_name=client_name, file_opener=file_opener, - typed_dict_inputs=True, + typed_dict_inputs=typeddict_inputs, method_filter=None, ) for path, file in files.items(): From 1a099fe35c538e002a0d123bab0dddc7c8b74a40 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 17:20:49 -0700 Subject: [PATCH 002/193] Converting TypedDict<->pydantic parity tests over to snapshot testing --- scripts/parity.sh | 40 -- scripts/parity/check_parity.py | 295 --------------- tests/v1/codegen/snapshot/parity/__init__.py | 0 .../codegen/snapshot/parity/check_parity.py | 347 ++++++++++++++++++ .../v1/codegen/snapshot}/parity/gen.py | 0 tests/v1/codegen/snapshot/parity/schema.json | 312 ++++++++++++++++ .../parity/pydantic_inputs/__init__.py | 21 ++ .../agentToolLanguageServer/__init__.py | 47 +++ .../agentToolLanguageServer/openDocument.py | 49 +++ .../parity/pydantic_inputs/aiExec/__init__.py | 50 +++ .../parity/pydantic_inputs/aiExec/exec.py | 56 +++ .../conmanFilesystem/__init__.py | 47 +++ .../conmanFilesystem/persist.py | 43 +++ 
.../pydantic_inputs/replspaceApi/__init__.py | 50 +++ .../pydantic_inputs/replspaceApi/init.py | 106 ++++++ .../pydantic_inputs/shellExec/__init__.py | 47 +++ .../parity/pydantic_inputs/shellExec/spawn.py | 57 +++ .../parity/typeddict_inputs/__init__.py | 21 ++ .../agentToolLanguageServer/__init__.py | 43 +++ .../agentToolLanguageServer/openDocument.py | 58 +++ .../typeddict_inputs/aiExec/__init__.py | 46 +++ .../parity/typeddict_inputs/aiExec/exec.py | 84 +++++ .../conmanFilesystem/__init__.py | 43 +++ .../conmanFilesystem/persist.py | 46 +++ .../typeddict_inputs/replspaceApi/__init__.py | 48 +++ .../typeddict_inputs/replspaceApi/init.py | 247 +++++++++++++ .../typeddict_inputs/shellExec/__init__.py | 44 +++ .../typeddict_inputs/shellExec/spawn.py | 94 +++++ 28 files changed, 2006 insertions(+), 335 deletions(-) delete mode 100644 scripts/parity.sh delete mode 100644 scripts/parity/check_parity.py create mode 100644 tests/v1/codegen/snapshot/parity/__init__.py create mode 100644 tests/v1/codegen/snapshot/parity/check_parity.py rename {scripts => tests/v1/codegen/snapshot}/parity/gen.py (100%) create mode 100644 tests/v1/codegen/snapshot/parity/schema.json create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/openDocument.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/exec.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/persist.py create mode 100644 
tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/init.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/spawn.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/openDocument.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/exec.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/persist.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/init.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/spawn.py diff --git a/scripts/parity.sh b/scripts/parity.sh deleted file mode 100644 index c9beda24..00000000 --- a/scripts/parity.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash -# -# parity.sh: Generate Pydantic and TypedDict models and check for deep equality. -# This script expects that ai-infra is cloned alongside river-python. - -set -e - -scripts="$(dirname "$0")" -cd "${scripts}/.." 
- -root="$(mktemp -d --tmpdir 'river-codegen-parity.XXX')" -mkdir "$root/src" - -echo "Using $root" >&2 - -function cleanup { - if [ -z "${DEBUG}" ]; then - echo "Cleaning up..." >&2 - rm -rfv "${root}" >&2 - fi -} -trap "cleanup" 0 2 3 15 - -gen() { - fname="$1"; shift - name="$1"; shift - poetry run python -m replit_river.codegen \ - client \ - --output "${root}/src/${fname}" \ - --client-name "${name}" \ - ../ai-infra/pkgs/pid2_client/src/schema/schema.json \ - "$@" -} - -gen tyd.py Pid2TypedDict --typed-dict-inputs -gen pyd.py Pid2Pydantic - -PYTHONPATH="${root}/src:${scripts}" -poetry run bash -c "MYPYPATH='$PYTHONPATH' mypy -m parity.check_parity" -poetry run bash -c "PYTHONPATH='$PYTHONPATH' python -m parity.check_parity" diff --git a/scripts/parity/check_parity.py b/scripts/parity/check_parity.py deleted file mode 100644 index c89a6918..00000000 --- a/scripts/parity/check_parity.py +++ /dev/null @@ -1,295 +0,0 @@ -from typing import Any, Callable, Literal, TypedDict, TypeVar - -import pyd -import tyd -from parity.gen import ( - gen_bool, - gen_choice, - gen_dict, - gen_float, - gen_int, - gen_list, - gen_opt, - gen_str, -) -from pydantic import TypeAdapter - -A = TypeVar("A") - -PrimitiveType = ( - bool | str | int | float | dict[str, "PrimitiveType"] | list["PrimitiveType"] -) - - -def deep_equal(a: PrimitiveType, b: PrimitiveType) -> Literal[True]: - if a == b: - return True - elif isinstance(a, dict) and isinstance(b, dict): - a_keys: PrimitiveType = list(a.keys()) - b_keys: PrimitiveType = list(b.keys()) - assert deep_equal(a_keys, b_keys) - - # We do this dance again because Python variance is hard. Feel free to fix it. 
- keys = set(a.keys()) - keys.update(b.keys()) - for k in keys: - aa: PrimitiveType = a[k] - bb: PrimitiveType = b[k] - assert deep_equal(aa, bb) - return True - elif isinstance(a, list) and isinstance(b, list): - assert len(a) == len(b) - for i in range(len(a)): - assert deep_equal(a[i], b[i]) - return True - else: - assert a == b, f"{a} != {b}" - return True - - -def baseTestPattern( - x: A, encode: Callable[[A], Any], adapter: TypeAdapter[Any] -) -> None: - a = encode(x) - m = adapter.validate_python(a) - z = adapter.dump_python(m, by_alias=True, exclude_none=True) - - assert deep_equal(a, z) - - -def testAiexecExecInit() -> None: - x: tyd.AiexecExecInit = { - "args": gen_list(gen_str)(), - "env": gen_opt(gen_dict(gen_str))(), - "cwd": gen_opt(gen_str)(), - "omitStdout": gen_opt(gen_bool)(), - "omitStderr": gen_opt(gen_bool)(), - "useReplitRunEnv": gen_opt(gen_bool)(), - } - - baseTestPattern(x, tyd.encode_AiexecExecInit, TypeAdapter(pyd.AiexecExecInit)) - - -def testAgenttoollanguageserverOpendocumentInput() -> None: - x: tyd.AgenttoollanguageserverOpendocumentInput = { - "uri": gen_str(), - "languageId": gen_str(), - "version": gen_float(), - "text": gen_str(), - } - - baseTestPattern( - x, - tyd.encode_AgenttoollanguageserverOpendocumentInput, - TypeAdapter(pyd.AgenttoollanguageserverOpendocumentInput), - ) - - -kind_type = ( - Literal[ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - ] - | None -) - - -def testAgenttoollanguageserverGetcodesymbolInput() -> None: - x: tyd.AgenttoollanguageserverGetcodesymbolInput = { - "uri": gen_str(), - "position": { - "line": gen_float(), - "character": gen_float(), - }, - "kind": gen_choice( - list[kind_type]( - [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - None, - ] - ) - )(), - } - - baseTestPattern( - x, - 
tyd.encode_AgenttoollanguageserverGetcodesymbolInput, - TypeAdapter(pyd.AgenttoollanguageserverGetcodesymbolInput), - ) - - -class size_type(TypedDict): - rows: int - cols: int - - -def testShellexecSpawnInput() -> None: - x: tyd.ShellexecSpawnInput = { - "cmd": gen_str(), - "args": gen_opt(gen_list(gen_str))(), - "initialCmd": gen_opt(gen_str)(), - "env": gen_opt(gen_dict(gen_str))(), - "cwd": gen_opt(gen_str)(), - "size": gen_opt( - lambda: size_type( - { - "rows": gen_int(), - "cols": gen_int(), - } - ), - )(), - "useReplitRunEnv": gen_opt(gen_bool)(), - "useCgroupMagic": gen_opt(gen_bool)(), - "interactive": gen_opt(gen_bool)(), - "onlySpawnIfNoProcesses": gen_opt(gen_bool)(), - } - - baseTestPattern( - x, - tyd.encode_ShellexecSpawnInput, - TypeAdapter(pyd.ShellexecSpawnInput), - ) - - -def testConmanfilesystemPersistInput() -> None: - x: tyd.ConmanfilesystemPersistInput = {} - - baseTestPattern( - x, - tyd.encode_ConmanfilesystemPersistInput, - TypeAdapter(pyd.ConmanfilesystemPersistInput), - ) - - -closeFile = tyd.ReplspaceapiInitInputOneOf_closeFile -githubToken = tyd.ReplspaceapiInitInputOneOf_githubToken -sshToken0 = tyd.ReplspaceapiInitInputOneOf_sshToken0 -sshToken1 = tyd.ReplspaceapiInitInputOneOf_sshToken1 -allowDefaultBucketAccess = tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccess - -allowDefaultBucketAccessResultOk = ( - tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_ok -) -allowDefaultBucketAccessResultError = ( - tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_error -) - - -def testReplspaceapiInitInput() -> None: - x: tyd.ReplspaceapiInitInput = gen_choice( - list[tyd.ReplspaceapiInitInput]( - [ - closeFile( - {"kind": "closeFile", "filename": gen_str(), "nonce": gen_str()} - ), - githubToken( - {"kind": "githubToken", "token": gen_str(), "nonce": gen_str()} - ), - sshToken0( - { - "kind": "sshToken", - "nonce": gen_str(), - "SSHHostname": gen_str(), - "token": gen_str(), - } - ), - sshToken1({"kind": 
"sshToken", "nonce": gen_str(), "error": gen_str()}), - allowDefaultBucketAccess( - { - "kind": "allowDefaultBucketAccess", - "nonce": gen_str(), - "result": gen_choice( - list[ - tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResult - ]( - [ - allowDefaultBucketAccessResultOk( - { - "bucketId": gen_str(), - "sourceReplId": gen_str(), - "status": "ok", - "targetReplId": gen_str(), - } - ), - allowDefaultBucketAccessResultError( - {"message": gen_str(), "status": "error"} - ), - ] - ) - )(), - } - ), - ] - ) - )() - - baseTestPattern( - x, - tyd.encode_ReplspaceapiInitInput, - TypeAdapter(pyd.ReplspaceapiInitInput), - ) - - -def main() -> None: - testAiexecExecInit() - testAgenttoollanguageserverOpendocumentInput() - testAgenttoollanguageserverGetcodesymbolInput() - testShellexecSpawnInput() - testConmanfilesystemPersistInput() - testReplspaceapiInitInput() - - -if __name__ == "__main__": - print("Starting...") - for _ in range(0, 100): - main() - print("Verified") diff --git a/tests/v1/codegen/snapshot/parity/__init__.py b/tests/v1/codegen/snapshot/parity/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/v1/codegen/snapshot/parity/check_parity.py b/tests/v1/codegen/snapshot/parity/check_parity.py new file mode 100644 index 00000000..68731bd2 --- /dev/null +++ b/tests/v1/codegen/snapshot/parity/check_parity.py @@ -0,0 +1,347 @@ +import importlib +from typing import Any, Callable, Literal, TypedDict + +from pydantic import TypeAdapter +from pytest_snapshot.plugin import Snapshot + +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v1.codegen.snapshot.parity.gen import ( + gen_bool, + gen_choice, + gen_dict, + gen_int, + gen_list, + gen_opt, + gen_str, +) + +PrimitiveType = ( + bool | str | int | float | dict[str, "PrimitiveType"] | list["PrimitiveType"] +) + + +def deep_equal(a: PrimitiveType, b: PrimitiveType) -> Literal[True]: + if a == b: + return True + elif isinstance(a, dict) and isinstance(b, 
dict): + a_keys: PrimitiveType = list(a.keys()) + b_keys: PrimitiveType = list(b.keys()) + assert deep_equal(a_keys, b_keys) + + # We do this dance again because Python variance is hard. Feel free to fix it. + keys = set(a.keys()) + keys.update(b.keys()) + for k in keys: + aa: PrimitiveType = a[k] + bb: PrimitiveType = b[k] + assert deep_equal(aa, bb) + return True + elif isinstance(a, list) and isinstance(b, list): + assert len(a) == len(b) + for i in range(len(a)): + assert deep_equal(a[i], b[i]) + return True + else: + assert a == b, f"{a} != {b}" + return True + + +def baseTestPattern[A]( + x: A, encode: Callable[[A], Any], adapter: TypeAdapter[Any] +) -> None: + a = encode(x) + m = adapter.validate_python(a) + z = adapter.dump_python(m, by_alias=True, exclude_none=True) + + assert deep_equal(a, z) + + +def test_AiexecExecInit(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.aiExec.ExecInit = { + "args": gen_list(gen_str)(), + "env": gen_opt(gen_dict(gen_str))(), + "cwd": gen_opt(gen_str)(), + "omitStdout": gen_opt(gen_bool)(), + "omitStderr": gen_opt(gen_bool)(), + "useReplitRunEnv": gen_opt(gen_bool)(), + } + + 
baseTestPattern(x, tyd.aiExec.encode_ExecInit, pyd.aiExec.ExecInitTypeAdapter) + + +def test_AgenttoollanguageserverOpendocumentInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.agentToolLanguageServer.OpendocumentInput = { + "path": gen_str(), + } + + baseTestPattern( + x, + tyd.agentToolLanguageServer.encode_OpendocumentInput, + pyd.agentToolLanguageServer.OpendocumentInputTypeAdapter, + ) + + +kind_type = ( + Literal[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + | None +) + + +class size_type(TypedDict): + rows: int + cols: int + + +def test_ShellexecSpawnInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: 
open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.shellExec.SpawnInput = { + "cmd": gen_str(), + "args": gen_opt(gen_list(gen_str))(), + "initialCmd": gen_opt(gen_str)(), + "env": gen_opt(gen_dict(gen_str))(), + "cwd": gen_opt(gen_str)(), + "size": gen_opt( + lambda: size_type( + { + "rows": gen_int(), + "cols": gen_int(), + } + ), + )(), + "useReplitRunEnv": gen_opt(gen_bool)(), + "interactive": gen_opt(gen_bool)(), + } + + baseTestPattern( + x, + tyd.shellExec.encode_SpawnInput, + pyd.shellExec.SpawnInputTypeAdapter, + ) + + +def test_ConmanfilesystemPersistInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.conmanFilesystem.PersistInput = {} + + baseTestPattern( + x, + tyd.conmanFilesystem.encode_PersistInput, + 
pyd.conmanFilesystem.PersistInputTypeAdapter, + ) + + +def test_ReplspaceapiInitInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.replspaceApi.init.InitInput = gen_choice( + list[tyd.replspaceApi.init.InitInput]( + [ + tyd.replspaceApi.init.InitInputOneOf_closeFile( + {"kind": "closeFile", "filename": gen_str(), "nonce": gen_str()} + ), + tyd.replspaceApi.init.InitInputOneOf_githubToken( + {"kind": "githubToken", "token": gen_str(), "nonce": gen_str()} + ), + tyd.replspaceApi.init.InitInputOneOf_sshToken0( + { + "kind": "sshToken", + "nonce": gen_str(), + "SSHHostname": gen_str(), + "token": gen_str(), + } + ), + tyd.replspaceApi.init.InitInputOneOf_sshToken1( + {"kind": "sshToken", "nonce": gen_str(), "error": gen_str()} + ), + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccess( + { + "kind": "allowDefaultBucketAccess", + "nonce": gen_str(), + "result": gen_choice( + list[ + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccessResult + ]( + [ + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok( + { + "bucketId": gen_str(), + "sourceReplId": gen_str(), + "status": "ok", + 
"targetReplId": gen_str(), + } + ), + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccessResultOneOf_error( + {"message": gen_str(), "status": "error"} + ), + ] + ) + )(), + } + ), + ] + ) + )() + + baseTestPattern( + x, + tyd.replspaceApi.init.encode_InitInput, + pyd.replspaceApi.init.InitInputTypeAdapter, + ) diff --git a/scripts/parity/gen.py b/tests/v1/codegen/snapshot/parity/gen.py similarity index 100% rename from scripts/parity/gen.py rename to tests/v1/codegen/snapshot/parity/gen.py diff --git a/tests/v1/codegen/snapshot/parity/schema.json b/tests/v1/codegen/snapshot/parity/schema.json new file mode 100644 index 00000000..dbe3a74e --- /dev/null +++ b/tests/v1/codegen/snapshot/parity/schema.json @@ -0,0 +1,312 @@ +{ + "services": { + "aiExec": { + "procedures": { + "exec": { + "init": { + "type": "object", + "properties": { + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "env": { + "type": "object", + "patternProperties": { + "^(.*)$": { + "type": "string" + } + } + }, + "cwd": { + "type": "string" + }, + "omitStdout": { + "type": "boolean" + }, + "omitStderr": { + "type": "boolean" + }, + "useReplitRunEnv": { + "type": "boolean" + } + }, + "required": ["args"] + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "type": "stream", + "input": { + "type": "object", + "properties": {} + } + } + } + }, + "agentToolLanguageServer": { + "procedures": { + "openDocument": { + "input": { + "type": "object", + "properties": { + "path": { + "description": "The path to the file. 
This should be relative to the workspace root", + "type": "string" + } + }, + "required": ["path"] + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "description": "Reports a document as open to the language servers", + "type": "rpc" + } + } + }, + "shellExec": { + "procedures": { + "spawn": { + "input": { + "type": "object", + "properties": { + "cmd": { + "type": "string" + }, + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "initialCmd": { + "type": "string" + }, + "env": { + "type": "object", + "patternProperties": { + "^(.*)$": { + "type": "string" + } + } + }, + "cwd": { + "type": "string" + }, + "size": { + "type": "object", + "properties": { + "rows": { + "type": "integer" + }, + "cols": { + "type": "integer" + } + }, + "required": ["rows", "cols"] + }, + "useReplitRunEnv": { + "type": "boolean" + }, + "interactive": { + "type": "boolean" + } + }, + "required": ["cmd"] + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "description": "Start a new shell process and returns the process handle id which you can use to interact with the process", + "type": "rpc" + } + } + }, + "replspaceApi": { + "procedures": { + "init": { + "init": { + "type": "object", + "properties": {} + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "type": "stream", + "input": { + "anyOf": [ + { + "type": "object", + "properties": { + "$kind": { + "const": "closeFile", + "type": "string" + }, + "filename": { + "type": "string" + }, + "nonce": { + "type": "string" + } + }, + "required": ["$kind", "filename", "nonce"] + }, + { + "type": "object", + "properties": { + "$kind": { + "const": "githubToken", + "type": "string" + }, + "token": { + "type": "string" + }, + "nonce": { + "type": "string" + } + }, + "required": ["$kind", "nonce"] + }, + { + 
"anyOf": [ + { + "type": "object", + "properties": { + "$kind": { + "const": "sshToken", + "type": "string" + }, + "nonce": { + "type": "string" + }, + "token": { + "type": "string" + }, + "SSHHostname": { + "type": "string" + } + }, + "required": ["$kind", "nonce", "token", "SSHHostname"] + }, + { + "type": "object", + "properties": { + "$kind": { + "const": "sshToken", + "type": "string" + }, + "nonce": { + "type": "string" + }, + "error": { + "type": "string" + } + }, + "required": ["$kind", "nonce", "error"] + } + ] + }, + { + "type": "object", + "properties": { + "$kind": { + "const": "allowDefaultBucketAccess", + "type": "string" + }, + "nonce": { + "type": "string" + }, + "result": { + "anyOf": [ + { + "type": "object", + "properties": { + "status": { + "const": "error", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["status", "message"] + }, + { + "type": "object", + "properties": { + "status": { + "const": "ok", + "type": "string" + }, + "targetReplId": { + "type": "string" + }, + "sourceReplId": { + "type": "string" + }, + "bucketId": { + "type": "string" + } + }, + "required": [ + "status", + "targetReplId", + "sourceReplId", + "bucketId" + ] + } + ] + } + }, + "required": ["$kind", "nonce", "result"] + } + ] + } + } + } + }, + "conmanFilesystem": { + "procedures": { + "persist": { + "input": { + "type": "object", + "properties": {} + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "type": "rpc" + } + } + } + }, + "handshakeSchema": null +} diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/__init__.py new file mode 100644 index 00000000..0f5c7dfb --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/__init__.py @@ -0,0 +1,21 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .agentToolLanguageServer import AgenttoollanguageserverService +from .aiExec import AiexecService +from .conmanFilesystem import ConmanfilesystemService +from .replspaceApi import ReplspaceapiService +from .shellExec import ShellexecService + + +class foo: + def __init__(self, client: river.Client[Literal[None]]): + self.aiExec = AiexecService(client) + self.agentToolLanguageServer = AgenttoollanguageserverService(client) + self.shellExec = ShellexecService(client) + self.replspaceApi = ReplspaceapiService(client) + self.conmanFilesystem = ConmanfilesystemService(client) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/__init__.py new file mode 100644 index 00000000..f867f902 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/__init__.py @@ -0,0 +1,47 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .openDocument import ( + OpendocumentErrors, + OpendocumentErrorsTypeAdapter, + OpendocumentInput, + OpendocumentInputTypeAdapter, + OpendocumentOutput, + OpendocumentOutputTypeAdapter, +) + + +class AgenttoollanguageserverService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def openDocument( + self, + input: OpendocumentInput, + timeout: datetime.timedelta, + ) -> OpendocumentOutput: + return await self.client.send_rpc( + "agentToolLanguageServer", + "openDocument", + input, + lambda x: OpendocumentInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: OpendocumentOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: OpendocumentErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/openDocument.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/openDocument.py new file mode 100644 index 00000000..02c814e1 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/openDocument.py @@ -0,0 +1,49 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class OpendocumentInput(BaseModel): + path: str + + +class OpendocumentOutput(BaseModel): + pass + + +OpendocumentOutputTypeAdapter: TypeAdapter[OpendocumentOutput] = TypeAdapter( + OpendocumentOutput +) + + +class OpendocumentErrors(RiverError): + pass + + +OpendocumentErrorsTypeAdapter: TypeAdapter[OpendocumentErrors] = TypeAdapter( + OpendocumentErrors +) + + +OpendocumentInputTypeAdapter: TypeAdapter[OpendocumentInput] = TypeAdapter( + OpendocumentInput +) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/__init__.py new file mode 100644 index 00000000..f8d9671d --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/__init__.py @@ -0,0 +1,50 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .exec import ( + ExecErrors, + ExecErrorsTypeAdapter, + ExecInit, + ExecInitTypeAdapter, + ExecInput, + ExecInputTypeAdapter, + ExecOutput, + ExecOutputTypeAdapter, +) + + +class AiexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def exec( + self, + init: ExecInit, + inputStream: AsyncIterable[ExecInput], + ) -> AsyncIterator[ExecOutput | ExecErrors | RiverError]: + return self.client.send_stream( + "aiExec", + "exec", + init, + inputStream, + lambda x: ExecInitTypeAdapter.validate_python(x), + lambda x: ExecInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: ExecOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: ExecErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/exec.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/exec.py new file mode 100644 index 00000000..57c6c69a --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/exec.py @@ -0,0 +1,56 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class ExecInit(BaseModel): + args: list[str] + cwd: str | None = None + env: dict[str, str] | None = None + omitStderr: bool | None = None + omitStdout: bool | None = None + useReplitRunEnv: bool | None = None + + +class ExecInput(BaseModel): + kind: Annotated[Literal["stdin"], Field(alias="$kind")] = "stdin" + stdin: bytes + + +class ExecOutput(BaseModel): + pass + + +ExecOutputTypeAdapter: TypeAdapter[ExecOutput] = TypeAdapter(ExecOutput) + + +class ExecErrors(RiverError): + pass + + +ExecErrorsTypeAdapter: TypeAdapter[ExecErrors] = TypeAdapter(ExecErrors) + + +ExecInitTypeAdapter: TypeAdapter[ExecInit] = TypeAdapter(ExecInit) + + +ExecInputTypeAdapter: TypeAdapter[ExecInput] = TypeAdapter(ExecInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/__init__.py new file mode 100644 index 00000000..42d92981 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/__init__.py @@ -0,0 +1,47 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .persist import ( + PersistErrors, + PersistErrorsTypeAdapter, + PersistInput, + PersistInputTypeAdapter, + PersistOutput, + PersistOutputTypeAdapter, +) + + +class ConmanfilesystemService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def persist( + self, + input: PersistInput, + timeout: datetime.timedelta, + ) -> PersistOutput: + return await self.client.send_rpc( + "conmanFilesystem", + "persist", + input, + lambda x: PersistInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: PersistOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: PersistErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/persist.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/persist.py new file mode 100644 index 00000000..3eb5fa01 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/persist.py @@ -0,0 +1,43 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class PersistInput(BaseModel): + pass + + +class PersistOutput(BaseModel): + pass + + +PersistOutputTypeAdapter: TypeAdapter[PersistOutput] = TypeAdapter(PersistOutput) + + +class PersistErrors(RiverError): + pass + + +PersistErrorsTypeAdapter: TypeAdapter[PersistErrors] = TypeAdapter(PersistErrors) + + +PersistInputTypeAdapter: TypeAdapter[PersistInput] = TypeAdapter(PersistInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/__init__.py new file mode 100644 index 00000000..8842016c --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/__init__.py @@ -0,0 +1,50 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .init import ( + InitErrors, + InitErrorsTypeAdapter, + InitInit, + InitInitTypeAdapter, + InitInput, + InitInputTypeAdapter, + InitOutput, + InitOutputTypeAdapter, +) + + +class ReplspaceapiService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def init( + self, + init: InitInit, + inputStream: AsyncIterable[InitInput], + ) -> AsyncIterator[InitOutput | InitErrors | RiverError]: + return self.client.send_stream( + "replspaceApi", + "init", + init, + inputStream, + lambda x: InitInitTypeAdapter.validate_python(x), + lambda x: InitInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: InitOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: InitErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/init.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/init.py new file mode 100644 index 00000000..46bc3a96 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/init.py @@ -0,0 +1,106 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class InitInit(BaseModel): + pass + + +class InitInputOneOf_closeFile(BaseModel): + kind: Annotated[Literal["closeFile"], Field(alias="$kind")] = "closeFile" + filename: str + nonce: str + + +class InitInputOneOf_githubToken(BaseModel): + kind: Annotated[Literal["githubToken"], Field(alias="$kind")] = "githubToken" + nonce: str + token: str | None = None + + +class InitInputOneOf_sshToken0(BaseModel): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] = "sshToken" + SSHHostname: str + nonce: str + token: str + + +class InitInputOneOf_sshToken1(BaseModel): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] = "sshToken" + error: str + nonce: str + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_error(BaseModel): + message: str + status: Literal["error"] + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok(BaseModel): + bucketId: str + sourceReplId: str + status: Literal["ok"] + targetReplId: str + + +InitInputOneOf_allowDefaultBucketAccessResult = ( + InitInputOneOf_allowDefaultBucketAccessResultOneOf_error + | InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok +) + + +class InitInputOneOf_allowDefaultBucketAccess(BaseModel): + kind: Annotated[Literal["allowDefaultBucketAccess"], Field(alias="$kind")] = ( + "allowDefaultBucketAccess" + ) + nonce: str + result: InitInputOneOf_allowDefaultBucketAccessResult + + +InitInput = ( + InitInputOneOf_closeFile + | InitInputOneOf_githubToken + | InitInputOneOf_sshToken0 + | InitInputOneOf_sshToken1 + | 
InitInputOneOf_allowDefaultBucketAccess +) + + +class InitOutput(BaseModel): + pass + + +InitOutputTypeAdapter: TypeAdapter[InitOutput] = TypeAdapter(InitOutput) + + +class InitErrors(RiverError): + pass + + +InitErrorsTypeAdapter: TypeAdapter[InitErrors] = TypeAdapter(InitErrors) + + +InitInitTypeAdapter: TypeAdapter[InitInit] = TypeAdapter(InitInit) + + +InitInputTypeAdapter: TypeAdapter[InitInput] = TypeAdapter(InitInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/__init__.py new file mode 100644 index 00000000..9e0013a7 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/__init__.py @@ -0,0 +1,47 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .spawn import ( + SpawnErrors, + SpawnErrorsTypeAdapter, + SpawnInput, + SpawnInputTypeAdapter, + SpawnOutput, + SpawnOutputTypeAdapter, +) + + +class ShellexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def spawn( + self, + input: SpawnInput, + timeout: datetime.timedelta, + ) -> SpawnOutput: + return await self.client.send_rpc( + "shellExec", + "spawn", + input, + lambda x: SpawnInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: SpawnOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: SpawnErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/spawn.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/spawn.py new file mode 100644 index 
00000000..6f4f6473 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/spawn.py @@ -0,0 +1,57 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class SpawnInputSize(BaseModel): + cols: int + rows: int + + +class SpawnInput(BaseModel): + args: list[str] | None = None + autoCleanup: bool | None = None + cmd: str + cwd: str | None = None + env: dict[str, str] | None = None + initialCmd: str | None = None + interactive: bool | None = None + size: SpawnInputSize | None = None + useCgroupMagic: bool | None = None + useReplitRunEnv: bool | None = None + + +class SpawnOutput(BaseModel): + pass + + +SpawnOutputTypeAdapter: TypeAdapter[SpawnOutput] = TypeAdapter(SpawnOutput) + + +class SpawnErrors(RiverError): + pass + + +SpawnErrorsTypeAdapter: TypeAdapter[SpawnErrors] = TypeAdapter(SpawnErrors) + + +SpawnInputTypeAdapter: TypeAdapter[SpawnInput] = TypeAdapter(SpawnInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/__init__.py new file mode 100644 index 00000000..0f5c7dfb --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/__init__.py @@ -0,0 +1,21 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .agentToolLanguageServer import AgenttoollanguageserverService +from .aiExec import AiexecService +from .conmanFilesystem import ConmanfilesystemService +from .replspaceApi import ReplspaceapiService +from .shellExec import ShellexecService + + +class foo: + def __init__(self, client: river.Client[Literal[None]]): + self.aiExec = AiexecService(client) + self.agentToolLanguageServer = AgenttoollanguageserverService(client) + self.shellExec = ShellexecService(client) + self.replspaceApi = ReplspaceapiService(client) + self.conmanFilesystem = ConmanfilesystemService(client) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/__init__.py new file mode 100644 index 00000000..4cc3ebdd --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/__init__.py @@ -0,0 +1,43 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .openDocument import ( + OpendocumentErrors, + OpendocumentErrorsTypeAdapter, + OpendocumentInput, + OpendocumentOutput, + OpendocumentOutputTypeAdapter, + encode_OpendocumentInput, +) + + +class AgenttoollanguageserverService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def openDocument( + self, + input: OpendocumentInput, + timeout: datetime.timedelta, + ) -> OpendocumentOutput: + return await self.client.send_rpc( + "agentToolLanguageServer", + "openDocument", + input, + encode_OpendocumentInput, + lambda x: OpendocumentOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: OpendocumentErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/openDocument.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/openDocument.py new file mode 100644 index 00000000..efbd4d8d --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/openDocument.py @@ -0,0 +1,58 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_OpendocumentInput( + x: "OpendocumentInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "path": x.get("path"), + } + ).items() + if v is not None + } + + +class OpendocumentInput(TypedDict): + path: str + + +class OpendocumentOutput(BaseModel): + pass + + +OpendocumentOutputTypeAdapter: TypeAdapter[OpendocumentOutput] = TypeAdapter( + OpendocumentOutput +) + + +class OpendocumentErrors(RiverError): + pass + + +OpendocumentErrorsTypeAdapter: TypeAdapter[OpendocumentErrors] = TypeAdapter( + OpendocumentErrors +) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/__init__.py new file mode 100644 index 00000000..d66a196b --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/__init__.py @@ -0,0 +1,46 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .exec import ( + ExecErrors, + ExecErrorsTypeAdapter, + ExecInit, + ExecInput, + ExecOutput, + ExecOutputTypeAdapter, + encode_ExecInit, + encode_ExecInput, +) + + +class AiexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def exec( + self, + init: ExecInit, + inputStream: AsyncIterable[ExecInput], + ) -> AsyncIterator[ExecOutput | ExecErrors | RiverError]: + return self.client.send_stream( + "aiExec", + "exec", + init, + inputStream, + encode_ExecInit, + encode_ExecInput, + lambda x: ExecOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: ExecErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/exec.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/exec.py new file mode 100644 index 00000000..8f7bc741 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/exec.py @@ -0,0 +1,84 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_ExecInit( + x: "ExecInit", +) -> Any: + return { + k: v + for (k, v) in ( + { + "args": x.get("args"), + "cwd": x.get("cwd"), + "env": x.get("env"), + "omitStderr": x.get("omitStderr"), + "omitStdout": x.get("omitStdout"), + "useReplitRunEnv": x.get("useReplitRunEnv"), + } + ).items() + if v is not None + } + + +class ExecInit(TypedDict): + args: list[str] + cwd: NotRequired[str | None] + env: NotRequired[dict[str, str] | None] + omitStderr: NotRequired[bool | None] + omitStdout: NotRequired[bool | None] + useReplitRunEnv: NotRequired[bool | None] + + +def encode_ExecInput( + x: "ExecInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "stdin": x.get("stdin"), + } + ).items() + if v is not None + } + + +class ExecInput(TypedDict): + kind: Annotated[Literal["stdin"], Field(alias="$kind")] + stdin: bytes + + +class ExecOutput(BaseModel): + pass + + +ExecOutputTypeAdapter: TypeAdapter[ExecOutput] = TypeAdapter(ExecOutput) + + +class ExecErrors(RiverError): + pass + + +ExecErrorsTypeAdapter: TypeAdapter[ExecErrors] = TypeAdapter(ExecErrors) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/__init__.py new file mode 100644 index 00000000..ef5b401d --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/__init__.py @@ -0,0 +1,43 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .persist import ( + PersistErrors, + PersistErrorsTypeAdapter, + PersistInput, + PersistOutput, + PersistOutputTypeAdapter, + encode_PersistInput, +) + + +class ConmanfilesystemService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def persist( + self, + input: PersistInput, + timeout: datetime.timedelta, + ) -> PersistOutput: + return await self.client.send_rpc( + "conmanFilesystem", + "persist", + input, + encode_PersistInput, + lambda x: PersistOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: PersistErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/persist.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/persist.py new file mode 100644 index 00000000..5ea0df8a --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/persist.py @@ -0,0 +1,46 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_PersistInput( + _: "PersistInput", +) -> Any: + return {} + + +class PersistInput(TypedDict): + pass + + +class PersistOutput(BaseModel): + pass + + +PersistOutputTypeAdapter: TypeAdapter[PersistOutput] = TypeAdapter(PersistOutput) + + +class PersistErrors(RiverError): + pass + + +PersistErrorsTypeAdapter: TypeAdapter[PersistErrors] = TypeAdapter(PersistErrors) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/__init__.py new file mode 100644 index 00000000..72f33c93 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/__init__.py @@ -0,0 +1,48 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .init import ( + InitErrors, + InitErrorsTypeAdapter, + InitInit, + InitInput, + InitOutput, + InitOutputTypeAdapter, + encode_InitInit, + encode_InitInput, + encode_InitInputOneOf_sshToken0, + encode_InitInputOneOf_sshToken1, +) + + +class ReplspaceapiService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def init( + self, + init: InitInit, + inputStream: AsyncIterable[InitInput], + ) -> AsyncIterator[InitOutput | InitErrors | RiverError]: + return self.client.send_stream( + "replspaceApi", + "init", + init, + inputStream, + encode_InitInit, + encode_InitInput, + lambda x: InitOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: InitErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/init.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/init.py new file mode 100644 index 00000000..aecafb6c --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/init.py @@ -0,0 +1,247 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_InitInit( + _: "InitInit", +) -> Any: + return {} + + +class InitInit(TypedDict): + pass + + +def encode_InitInputOneOf_closeFile( + x: "InitInputOneOf_closeFile", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "filename": x.get("filename"), + "nonce": x.get("nonce"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_closeFile(TypedDict): + kind: Annotated[Literal["closeFile"], Field(alias="$kind")] + filename: str + nonce: str + + +def encode_InitInputOneOf_githubToken( + x: "InitInputOneOf_githubToken", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "nonce": x.get("nonce"), + "token": x.get("token"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_githubToken(TypedDict): + kind: Annotated[Literal["githubToken"], Field(alias="$kind")] + nonce: str + token: NotRequired[str | None] + + +def encode_InitInputOneOf_sshToken0( + x: "InitInputOneOf_sshToken0", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "SSHHostname": x.get("SSHHostname"), + "nonce": x.get("nonce"), + "token": x.get("token"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_sshToken0(TypedDict): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] + SSHHostname: str + nonce: str + token: str + + +def encode_InitInputOneOf_sshToken1( + x: "InitInputOneOf_sshToken1", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "error": 
x.get("error"), + "nonce": x.get("nonce"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_sshToken1(TypedDict): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] + error: str + nonce: str + + +def encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_error( + x: "InitInputOneOf_allowDefaultBucketAccessResultOneOf_error", +) -> Any: + return { + k: v + for (k, v) in ( + { + "message": x.get("message"), + "status": x.get("status"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_error(TypedDict): + message: str + status: Literal["error"] + + +def encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok( + x: "InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok", +) -> Any: + return { + k: v + for (k, v) in ( + { + "bucketId": x.get("bucketId"), + "sourceReplId": x.get("sourceReplId"), + "status": x.get("status"), + "targetReplId": x.get("targetReplId"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok(TypedDict): + bucketId: str + sourceReplId: str + status: Literal["ok"] + targetReplId: str + + +InitInputOneOf_allowDefaultBucketAccessResult = ( + InitInputOneOf_allowDefaultBucketAccessResultOneOf_error + | InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok +) + + +def encode_InitInputOneOf_allowDefaultBucketAccessResult( + x: "InitInputOneOf_allowDefaultBucketAccessResult", +) -> Any: + return ( + encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_error(x) + if x["status"] == "error" + else encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok(x) + ) + + +def encode_InitInputOneOf_allowDefaultBucketAccess( + x: "InitInputOneOf_allowDefaultBucketAccess", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "nonce": x.get("nonce"), + "result": encode_InitInputOneOf_allowDefaultBucketAccessResult( + x["result"] + ), + } + ).items() + if v is not None + } + + +class 
InitInputOneOf_allowDefaultBucketAccess(TypedDict): + kind: Annotated[Literal["allowDefaultBucketAccess"], Field(alias="$kind")] + nonce: str + result: InitInputOneOf_allowDefaultBucketAccessResult + + +InitInput = ( + InitInputOneOf_closeFile + | InitInputOneOf_githubToken + | InitInputOneOf_sshToken0 + | InitInputOneOf_sshToken1 + | InitInputOneOf_allowDefaultBucketAccess +) + + +def encode_InitInput( + x: "InitInput", +) -> Any: + return ( + encode_InitInputOneOf_closeFile(x) + if x["kind"] == "closeFile" + else encode_InitInputOneOf_githubToken(x) + if x["kind"] == "githubToken" + else ( + encode_InitInputOneOf_sshToken0(x) # type: ignore[arg-type] + if "token" in x + else encode_InitInputOneOf_sshToken1(x) # type: ignore[arg-type] + ) + if x["kind"] == "sshToken" + else encode_InitInputOneOf_allowDefaultBucketAccess(x) + ) + + +class InitOutput(BaseModel): + pass + + +InitOutputTypeAdapter: TypeAdapter[InitOutput] = TypeAdapter(InitOutput) + + +class InitErrors(RiverError): + pass + + +InitErrorsTypeAdapter: TypeAdapter[InitErrors] = TypeAdapter(InitErrors) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/__init__.py new file mode 100644 index 00000000..41f0d9e9 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/__init__.py @@ -0,0 +1,44 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .spawn import ( + SpawnErrors, + SpawnErrorsTypeAdapter, + SpawnInput, + SpawnOutput, + SpawnOutputTypeAdapter, + encode_SpawnInput, + encode_SpawnInputSize, +) + + +class ShellexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def spawn( + self, + input: SpawnInput, + timeout: datetime.timedelta, + ) -> SpawnOutput: + return await self.client.send_rpc( + "shellExec", + "spawn", + input, + encode_SpawnInput, + lambda x: SpawnOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: SpawnErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/spawn.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/spawn.py new file mode 100644 index 00000000..d4eb64b7 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/spawn.py @@ -0,0 +1,94 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_SpawnInputSize( + x: "SpawnInputSize", +) -> Any: + return { + k: v + for (k, v) in ( + { + "cols": x.get("cols"), + "rows": x.get("rows"), + } + ).items() + if v is not None + } + + +class SpawnInputSize(TypedDict): + cols: int + rows: int + + +def encode_SpawnInput( + x: "SpawnInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "args": x.get("args"), + "autoCleanup": x.get("autoCleanup"), + "cmd": x.get("cmd"), + "cwd": x.get("cwd"), + "env": x.get("env"), + "initialCmd": x.get("initialCmd"), + "interactive": x.get("interactive"), + "size": encode_SpawnInputSize(x["size"]) + if "size" in x and x["size"] is not None + else None, + "useCgroupMagic": x.get("useCgroupMagic"), + "useReplitRunEnv": x.get("useReplitRunEnv"), + } + ).items() + if v is not None + } + + +class SpawnInput(TypedDict): + args: NotRequired[list[str] | None] + autoCleanup: NotRequired[bool | None] + cmd: str + cwd: NotRequired[str | None] + env: NotRequired[dict[str, str] | None] + initialCmd: NotRequired[str | None] + interactive: NotRequired[bool | None] + size: NotRequired[SpawnInputSize | None] + useCgroupMagic: NotRequired[bool | None] + useReplitRunEnv: NotRequired[bool | None] + + +class SpawnOutput(BaseModel): + pass + + +SpawnOutputTypeAdapter: TypeAdapter[SpawnOutput] = TypeAdapter(SpawnOutput) + + +class SpawnErrors(RiverError): + pass + + +SpawnErrorsTypeAdapter: TypeAdapter[SpawnErrors] = TypeAdapter(SpawnErrors) From 8b683c3bb2f28c087dd9f97c34064922675f2981 Mon Sep 17 00:00:00 2001 
From: Devon Stewart Date: Fri, 21 Mar 2025 23:26:16 -0700 Subject: [PATCH 003/193] Address error message --- src/replit_river/codegen/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index d4e69ddf..fa8205a2 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -969,7 +969,7 @@ def __init__(self, client: river.Client[Any]): """ assert init_type is None or render_init_method, ( - f"Unable to derive the init encoder from: {input_type}" + f"Unable to derive the init encoder from: {init_type}" ) # Input renderer From 35ae452f3a41424f02921082a89aa2be08a483f9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 23:52:46 -0700 Subject: [PATCH 004/193] Making input and init types optional, gating on protocol version --- src/replit_river/codegen/client.py | 126 ++++++++++++++--------------- 1 file changed, 60 insertions(+), 66 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index fa8205a2..b01f6462 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -824,8 +824,9 @@ def __init__(self, client: river.Client[Any]): continue module_names = [ModuleName(name)] init_type: TypeExpression | None = None + init_module_info: list[ModuleName] = [] if procedure.init: - init_type, module_info, init_chunks, encoder_names = encode_type( + init_type, init_module_info, init_chunks, encoder_names = encode_type( procedure.init, TypeName(f"{name.title()}Init"), input_base_class, @@ -835,34 +836,29 @@ def __init__(self, client: river.Client[Any]): serdes.append( ( [extract_inner_type(init_type), *encoder_names], - module_info, + init_module_info, init_chunks, ) ) - input_type, module_info, input_chunks, encoder_names = encode_type( - procedure.input, - TypeName(f"{name.title()}Input"), - input_base_class, - module_names, - permit_unknown_members=False, - ) - 
input_type_name = extract_inner_type(input_type) - input_type_type_adapter_name = TypeName( - f"{render_literal_type(input_type_name)}TypeAdapter" - ) - serdes.append( - ( - [extract_inner_type(input_type), *encoder_names], - module_info, - input_chunks, + input_type: TypeExpression | None = None + input_module_info: list[ModuleName] = [] + if procedure.input: + input_type, input_module_info, input_chunks, encoder_names = encode_type( + procedure.input, + TypeName(f"{name.title()}Input"), + input_base_class, + module_names, + permit_unknown_members=False, ) - ) - serdes.append( - _type_adapter_definition( - input_type_type_adapter_name, input_type, module_info + serdes.append( + ( + [extract_inner_type(input_type), *encoder_names], + input_module_info, + input_chunks, + ) ) - ) - output_type, module_info, output_chunks, encoder_names = encode_type( + + output_type, output_module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), "BaseModel", @@ -873,7 +869,7 @@ def __init__(self, client: river.Client[Any]): serdes.append( ( [output_type_name, *encoder_names], - module_info, + output_module_info, output_chunks, ) ) @@ -882,12 +878,12 @@ def __init__(self, client: river.Client[Any]): ) serdes.append( _type_adapter_definition( - output_type_type_adapter_name, output_type, module_info + output_type_type_adapter_name, output_type, output_module_info ) ) - output_module_info = module_info + if procedure.errors: - error_type, module_info, errors_chunks, encoder_names = encode_type( + error_type, error_module_info, errors_chunks, encoder_names = encode_type( procedure.errors, TypeName(f"{name.title()}Errors"), "RiverError", @@ -899,7 +895,7 @@ def __init__(self, client: river.Client[Any]): error_type = error_type_name else: error_type_name = extract_inner_type(error_type) - serdes.append(([error_type_name], module_info, errors_chunks)) + serdes.append(([error_type_name], error_module_info, errors_chunks)) else: 
error_type_name = TypeName("RiverError") @@ -909,11 +905,9 @@ def __init__(self, client: river.Client[Any]): f"{render_literal_type(error_type_name)}TypeAdapter" ) if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": - if len(module_info) == 0: - module_info = output_module_info serdes.append( _type_adapter_definition( - error_type_type_adapter_name, error_type, module_info + error_type_type_adapter_name, error_type, output_module_info ) ) output_or_error_type = UnionTypeExpr([output_type, error_type_name]) @@ -960,7 +954,7 @@ def __init__(self, client: river.Client[Any]): ) serdes.append( _type_adapter_definition( - init_type_type_adapter_name, init_type, module_info + init_type_type_adapter_name, init_type, init_module_info ) ) render_init_method = f"""\ @@ -968,40 +962,42 @@ def __init__(self, client: river.Client[Any]): .validate_python(x) """ - assert init_type is None or render_init_method, ( - f"Unable to derive the init encoder from: {init_type}" - ) - # Input renderer render_input_method: str | None = None - if input_base_class == "TypedDict": - if is_literal(procedure.input): - render_input_method = "lambda x: x" - elif isinstance( - procedure.input, RiverConcreteType - ) and procedure.input.type in ["array"]: - match input_type: - case ListTypeExpr(list_type): - render_input_method = f"""\ - lambda xs: [ - encode_{render_literal_type(list_type)}(x) for x in xs - ] - """ - else: - render_input_method = f"encode_{render_literal_type(input_type)}" - else: - render_input_method = f"""\ - lambda x: {render_type_expr(input_type_type_adapter_name)} - .dump_python( - x, # type: ignore[arg-type] - by_alias=True, - exclude_none=True, - ) + if input_type and procedure.input is not None: + if input_base_class == "TypedDict": + if is_literal(procedure.input): + render_input_method = "lambda x: x" + elif isinstance( + procedure.input, RiverConcreteType + ) and procedure.input.type in ["array"]: + match input_type: + case ListTypeExpr(list_type): + 
render_input_method = f"""\ + lambda xs: [ + encode_{render_literal_type(list_type)}(x) for x in xs + ] """ - - assert render_input_method, ( - f"Unable to derive the input encoder from: {input_type}" - ) + else: + render_input_method = f"encode_{render_literal_type(input_type)}" + else: + input_type_name = extract_inner_type(input_type) + input_type_type_adapter_name = TypeName( + f"{render_literal_type(input_type_name)}TypeAdapter" + ) + serdes.append( + _type_adapter_definition( + input_type_type_adapter_name, input_type, input_module_info + ) + ) + render_input_method = f"""\ + lambda x: {render_type_expr(input_type_type_adapter_name)} + .dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ) + """ if isinstance(output_type, NoneTypeExpr): parse_output_method = "lambda x: None" @@ -1060,7 +1056,6 @@ async def {name}( ) elif procedure.type == "upload": if init_type: - assert render_init_method, "Expected an init renderer!" current_chunks.extend( [ reindent( @@ -1117,7 +1112,6 @@ async def {name}( ] ) if init_type: - assert render_init_method, "Expected an init renderer!" 
current_chunks.extend( [ reindent( From d884d2ee895b10031bc5d086cd485cb88f3a92b3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 00:20:04 -0700 Subject: [PATCH 005/193] Moving library call render method out to reduce local scope --- src/replit_river/codegen/client.py | 417 ++++++++++++++++++----------- 1 file changed, 254 insertions(+), 163 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index b01f6462..00b24026 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -780,6 +780,226 @@ def __init__(self, client: river.Client[{handshake_type}]): return FileContents("\n".join(chunks)) +def render_library_call( + schema_name: str, + name: str, + procedure: RiverProcedure, + init_meta: tuple[RiverType, TypeExpression, str] | None, + input_meta: tuple[RiverType, TypeExpression, str] | None, + output_meta: tuple[RiverType, TypeExpression, str] | None, + error_meta: tuple[RiverType, TypeExpression, str] | None, +) -> list[str]: + """ + This method is only ever called from one place, but it's defensively establishing a + namespace that lets us draw some new boundaries around the parameters, without the + pollution from other intermediatae values. 
+ """ + current_chunks: list[str] = [] + + if procedure.type == "rpc": + assert input_meta + assert output_meta + assert error_meta + _, input_type, render_input_method = input_meta + _, output_type, parse_output_method = output_meta + _, _, parse_error_method = error_meta + + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + input: {render_type_expr(input_type)}, + timeout: datetime.timedelta, + ) -> {render_type_expr(output_type)}: + return await self.client.send_rpc( + {repr(schema_name)}, + {repr(name)}, + input, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + timeout, + ) + """, + ) + ] + ) + elif procedure.type == "subscription": + assert input_meta + assert output_meta + assert error_meta + _, input_type, render_input_method = input_meta + _, output_type, parse_output_method = output_meta + _, error_type, parse_error_method = error_meta + error_type_name = extract_inner_type(error_type) + + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + + output_or_error_type = UnionTypeExpr( + [ + output_or_error_type, + TypeName("RiverError"), + ] + ) + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + input: {render_type_expr(input_type)}, + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_subscription( + {repr(schema_name)}, + {repr(name)}, + input, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + elif procedure.type == "upload": + assert input_meta + assert output_meta + assert error_meta + _, input_type, render_input_method = input_meta + _, output_type, parse_output_method = output_meta + _, error_type, parse_error_method = error_meta + error_type_name = extract_inner_type(error_type) + + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + + if init_meta: 
+ _, init_type, render_init_method = init_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + init: {render_type_expr(init_type)}, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> {render_type_expr(output_type)}: + return await self.client.send_upload( + {repr(schema_name)}, + {repr(name)}, + init, + inputStream, + {reindent(" ", render_init_method)}, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + else: + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> { # TODO(dstewart) This should just be output_type + render_type_expr(output_or_error_type) + }: + return await self.client.send_upload( + {repr(schema_name)}, + {repr(name)}, + None, + inputStream, + None, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + elif procedure.type == "stream": + assert input_meta + assert output_meta + assert error_meta + _, input_type, render_input_method = input_meta + _, output_type, parse_output_method = output_meta + _, error_type, parse_error_method = error_meta + error_type_name = extract_inner_type(error_type) + + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + + output_or_error_type = UnionTypeExpr( + [ + output_or_error_type, + TypeName("RiverError"), + ] + ) + if init_meta: + _, init_type, render_init_method = init_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + init: {render_type_expr(init_type)}, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_stream( + {repr(schema_name)}, + {repr(name)}, + init, + inputStream, + {reindent(" ", render_init_method)}, + 
{reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + else: + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_stream( + {repr(schema_name)}, + {repr(name)}, + None, + inputStream, + None, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + + current_chunks.append("") + return current_chunks + + def generate_individual_service( schema_name: str, schema: RiverService, @@ -910,7 +1130,6 @@ def __init__(self, client: river.Client[Any]): error_type_type_adapter_name, error_type, output_module_info ) ) - output_or_error_type = UnionTypeExpr([output_type, error_type_name]) # NB: These strings must be indented to at least the same level of # the function strings in the branches below, otherwise `dedent` @@ -982,186 +1201,58 @@ def __init__(self, client: river.Client[Any]): render_input_method = f"encode_{render_literal_type(input_type)}" else: input_type_name = extract_inner_type(input_type) - input_type_type_adapter_name = TypeName( + input_type_type_adapter = TypeName( f"{render_literal_type(input_type_name)}TypeAdapter" ) serdes.append( _type_adapter_definition( - input_type_type_adapter_name, input_type, input_module_info + input_type_type_adapter, input_type, input_module_info ) ) render_input_method = f"""\ - lambda x: {render_type_expr(input_type_type_adapter_name)} + lambda x: {render_type_expr(input_type_type_adapter)} .dump_python( x, # type: ignore[arg-type] by_alias=True, exclude_none=True, ) """ - if isinstance(output_type, NoneTypeExpr): parse_output_method = "lambda x: None" - if procedure.type == "rpc": - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - input: 
{render_type_expr(input_type)}, - timeout: datetime.timedelta, - ) -> {render_type_expr(output_type)}: - return await self.client.send_rpc( - {repr(schema_name)}, - {repr(name)}, - input, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - timeout, - ) - """, - ) - ] - ) - elif procedure.type == "subscription": - output_or_error_type = UnionTypeExpr( - [ - output_or_error_type, - TypeName("RiverError"), - ] - ) - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - input: {render_type_expr(input_type)}, - ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return self.client.send_subscription( - {repr(schema_name)}, - {repr(name)}, - input, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - elif procedure.type == "upload": - if init_type: - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - init: {render_type_expr(init_type)}, - inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> {render_type_expr(output_type)}: - return await self.client.send_upload( - {repr(schema_name)}, - {repr(name)}, - init, - inputStream, - {reindent(" ", render_init_method)}, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] + def combine_or_none( + proc_type: RiverType | None, + tpe: TypeExpression | None, + serde_method: str | None, + ) -> tuple[RiverType, TypeExpression, str] | None: + if not proc_type and not tpe and not serde_method: + return None + if not proc_type or not tpe or not serde_method: + raise ValueError( + f"Unable to convert {repr(proc_type)} into either" + f" tpe={tpe} or render_method={serde_method}" ) - else: - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - inputStream: 
AsyncIterable[{render_type_expr(input_type)}], - ) -> {render_type_expr(output_or_error_type)}: - return await self.client.send_upload( - {repr(schema_name)}, - {repr(name)}, - None, - inputStream, - None, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - elif procedure.type == "stream": - output_or_error_type = UnionTypeExpr( - [ - output_or_error_type, - TypeName("RiverError"), - ] + return (proc_type, tpe, serde_method) + + current_chunks.extend( + render_library_call( + schema_name=schema_name, + name=name, + procedure=procedure, + init_meta=combine_or_none( + procedure.init, init_type, render_init_method + ), + input_meta=combine_or_none( + procedure.input, input_type, render_input_method + ), + output_meta=combine_or_none( + procedure.output, output_type, parse_output_method + ), + error_meta=combine_or_none( + procedure.errors, error_type, parse_error_method + ), ) - if init_type: - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - init: {render_type_expr(init_type)}, - inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return self.client.send_stream( - {repr(schema_name)}, - {repr(name)}, - init, - inputStream, - {reindent(" ", render_init_method)}, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - else: - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return self.client.send_stream( - {repr(schema_name)}, - {repr(name)}, - None, - inputStream, - None, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - - 
current_chunks.append("") + ) emitted_files: dict[RenderedPath, FileContents] = {} From bbe78dd99dd3d608946f1ff9d1ab7618bfc7f98d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 23:22:53 -0700 Subject: [PATCH 006/193] v2 client --- src/replit_river/__init__.py | 2 + src/replit_river/v2/__init__.py | 5 + src/replit_river/v2/client.py | 249 ++++++++++++ src/replit_river/v2/client_session.py | 494 ++++++++++++++++++++++++ src/replit_river/v2/client_transport.py | 379 ++++++++++++++++++ 5 files changed, 1129 insertions(+) create mode 100644 src/replit_river/v2/__init__.py create mode 100644 src/replit_river/v2/client.py create mode 100644 src/replit_river/v2/client_session.py create mode 100644 src/replit_river/v2/client_transport.py diff --git a/src/replit_river/__init__.py b/src/replit_river/__init__.py index bc166e3e..d3bf1966 100644 --- a/src/replit_river/__init__.py +++ b/src/replit_river/__init__.py @@ -1,3 +1,4 @@ +from . import v2 from .client import Client from .error_schema import RiverError from .rpc import ( @@ -20,4 +21,5 @@ "subscription_method_handler", "upload_method_handler", "stream_method_handler", + "v2", ] diff --git a/src/replit_river/v2/__init__.py b/src/replit_river/v2/__init__.py new file mode 100644 index 00000000..19790f0e --- /dev/null +++ b/src/replit_river/v2/__init__.py @@ -0,0 +1,5 @@ +from .client import Client + +__all__ = [ + "Client", +] diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py new file mode 100644 index 00000000..efd1bfe5 --- /dev/null +++ b/src/replit_river/v2/client.py @@ -0,0 +1,249 @@ +import logging +from collections.abc import AsyncIterable, Awaitable, Callable +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, AsyncGenerator, Generator, Generic, Literal + +from opentelemetry import trace +from opentelemetry.trace import Span, SpanKind, Status, StatusCode +from pydantic import ( + BaseModel, + 
ValidationInfo, +) + +from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException +from replit_river.rpc import ( + ErrorType, + InitType, + RequestType, + ResponseType, +) +from replit_river.transport_options import ( + HandshakeMetadataType, + TransportOptions, + UriAndMetadata, +) +from replit_river.v2.client_transport import ClientTransport + +logger = logging.getLogger(__name__) +tracer = trace.get_tracer(__name__) + + +@dataclass(frozen=True) +class RiverUnknownValue(BaseModel): + tag: Literal["RiverUnknownValue"] + value: Any + + +class RiverUnknownError(RiverError): + pass + + +def translate_unknown_value( + value: Any, handler: Callable[[Any], Any], info: ValidationInfo +) -> Any | RiverUnknownValue: + try: + return handler(value) + except Exception: + return RiverUnknownValue(tag="RiverUnknownValue", value=value) + + +def translate_unknown_error( + value: Any, handler: Callable[[Any], Any], info: ValidationInfo +) -> Any | RiverUnknownError: + try: + return handler(value) + except Exception: + if isinstance(value, dict) and "code" in value and "message" in value: + return RiverUnknownError( + code=value["code"], + message=value["message"], + ) + else: + return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") + + +class Client(Generic[HandshakeMetadataType]): + def __init__( + self, + uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadataType]] + ], + client_id: str, + server_id: str, + transport_options: TransportOptions, + ) -> None: + self._client_id = client_id + self._server_id = server_id + self._transport = ClientTransport[HandshakeMetadataType]( + uri_and_metadata_factory=uri_and_metadata_factory, + client_id=client_id, + server_id=server_id, + transport_options=transport_options, + ) + + async def close(self) -> None: + logger.info(f"river client {self._client_id} start closing") + await self._transport.close() + logger.info(f"river client {self._client_id} closed") + + 
async def ensure_connected(self) -> None: + await self._transport.get_or_create_session() + + async def send_rpc( + self, + service_name: str, + procedure_name: str, + request: RequestType, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + timeout: timedelta, + ) -> ResponseType: + with _trace_procedure("rpc", service_name, procedure_name) as span_handle: + session = await self._transport.get_or_create_session() + return await session.send_rpc( + service_name, + procedure_name, + request, + request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + timeout, + ) + + async def send_upload( + self, + service_name: str, + procedure_name: str, + init: InitType | None, + request: AsyncIterable[RequestType], + init_serializer: Callable[[InitType], Any] | None, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + ) -> ResponseType: + with _trace_procedure("upload", service_name, procedure_name) as span_handle: + session = await self._transport.get_or_create_session() + return await session.send_upload( + service_name, + procedure_name, + init, + request, + init_serializer, + request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + ) + + async def send_subscription( + self, + service_name: str, + procedure_name: str, + request: RequestType, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + ) -> AsyncGenerator[ResponseType | RiverError, None]: + with _trace_procedure( + "subscription", service_name, procedure_name + ) as span_handle: + session = await self._transport.get_or_create_session() + async for msg in session.send_subscription( + service_name, + procedure_name, + request, + 
request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + ): + if isinstance(msg, RiverError): + _record_river_error(span_handle, msg) + yield msg # type: ignore # https://github.com/python/mypy/issues/10817 + + async def send_stream( + self, + service_name: str, + procedure_name: str, + init: InitType | None, + request: AsyncIterable[RequestType], + init_serializer: Callable[[InitType], Any] | None, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + ) -> AsyncGenerator[ResponseType | RiverError, None]: + with _trace_procedure("stream", service_name, procedure_name) as span_handle: + session = await self._transport.get_or_create_session() + async for msg in session.send_stream( + service_name, + procedure_name, + init, + request, + init_serializer, + request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + ): + if isinstance(msg, RiverError): + _record_river_error(span_handle, msg) + yield msg # type: ignore # https://github.com/python/mypy/issues/10817 + + +@dataclass +class _SpanHandle: + """Wraps a span and keeps track of whether or not a status has been recorded yet.""" + + span: Span + did_set_status: bool = False + + def set_status( + self, + status: Status | StatusCode, + description: str | None = None, + ) -> None: + if self.did_set_status: + return + self.did_set_status = True + self.span.set_status(status, description) + + +@contextmanager +def _trace_procedure( + procedure_type: Literal["rpc", "upload", "subscription", "stream"], + service_name: str, + procedure_name: str, +) -> Generator[_SpanHandle, None, None]: + span = tracer.start_span( + f"river.client.{procedure_type}.{service_name}.{procedure_name}", + kind=SpanKind.CLIENT, + ) + span_handle = _SpanHandle(span) + try: + yield span_handle + except GeneratorExit: + # This error indicates the caller is done with the async 
generator + # but messages are still left. This is okay, we do not consider it an error. + raise + except RiverException as e: + span.record_exception(e, escaped=True) + _record_river_error(span_handle, RiverError(code=e.code, message=e.message)) + raise e + except BaseException as e: + span.record_exception(e, escaped=True) + span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") + raise e + finally: + span_handle.set_status(StatusCode.OK) + span.end() + + +def _record_river_error(span_handle: _SpanHandle, error: RiverError) -> None: + span_handle.set_status(StatusCode.ERROR, error.message) + span_handle.span.record_exception(RiverException(error.code, error.message)) + span_handle.span.set_attribute("river.error_code", error.code) + span_handle.span.set_attribute("river.error_message", error.message) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py new file mode 100644 index 00000000..505e23e2 --- /dev/null +++ b/src/replit_river/v2/client_session.py @@ -0,0 +1,494 @@ +import asyncio +import logging +from collections.abc import AsyncIterable +from datetime import timedelta +from typing import Any, AsyncGenerator, Callable, Coroutine + +import nanoid # type: ignore +import websockets +from aiochannel import Channel +from aiochannel.errors import ChannelClosed +from opentelemetry.trace import Span +from websockets.exceptions import ConnectionClosed + +from replit_river.common_session import add_msg_to_stream +from replit_river.error_schema import ( + ERROR_CODE_CANCEL, + ERROR_CODE_STREAM_CLOSED, + RiverException, + RiverServiceException, + StreamClosedRiverServiceException, + exception_from_message, +) +from replit_river.messages import ( + FailedSendingMessageException, + parse_transport_msg, +) +from replit_river.rpc import ( + ACK_BIT, + STREAM_CLOSED_BIT, + STREAM_OPEN_BIT, + ErrorType, + InitType, + RequestType, + ResponseType, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + 
InvalidMessageException, + OutOfOrderMessageException, +) +from replit_river.session import Session +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions + +logger = logging.getLogger(__name__) + + +class ClientSession(Session): + def __init__( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: websockets.WebSocketCommonProtocol, + transport_options: TransportOptions, + close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]], + retry_connection_callback: ( + Callable[ + [], + Coroutine[Any, Any, Any], + ] + | None + ) = None, + ) -> None: + super().__init__( + transport_id=transport_id, + to_id=to_id, + session_id=session_id, + websocket=websocket, + transport_options=transport_options, + close_session_callback=close_session_callback, + retry_connection_callback=retry_connection_callback, + ) + + async def do_close_websocket() -> None: + await self.close_websocket( + self._ws_wrapper, + should_retry=True, + ) + await self._begin_close_session_countdown() + + self._setup_heartbeats_task(do_close_websocket) + + async def start_serve_responses(self) -> None: + self._task_manager.create_task(self.serve()) + + async def serve(self) -> None: + """Serve messages from the websocket.""" + self._reset_session_close_countdown() + try: + try: + await self._handle_messages_from_ws() + except ConnectionClosed: + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + + await self._begin_close_session_countdown() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. 
+ logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + raise ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + + async def _handle_messages_from_ws(self) -> None: + logger.debug( + "%s start handling messages from ws %s", + "client", + self._ws_wrapper.id, + ) + try: + ws_wrapper = self._ws_wrapper + async for message in ws_wrapper.ws: + try: + if not await ws_wrapper.is_open(): + # We should not process messages if the websocket is closed. + break + msg = parse_transport_msg(message, self._transport_options) + + logger.debug(f"{self._transport_id} got a message %r", msg) + + # Update bookkeeping + await self._seq_manager.check_seq_and_update(msg) + await self._buffer.remove_old_messages( + self._seq_manager.receiver_ack, + ) + self._reset_session_close_countdown() + + if msg.controlFlags & ACK_BIT != 0: + continue + async with self._stream_lock: + stream = self._streams.get(msg.streamId, None) + if msg.controlFlags & STREAM_OPEN_BIT == 0: + if not stream: + logger.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException( + "no stream for message, ignoring" + ) + await add_msg_to_stream(msg, stream) + else: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + async with self._stream_lock: + del self._streams[msg.streamId] + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await ws_wrapper.close() + return + except InvalidMessageException: + logger.exception("Got invalid transport message, closing session") + await self.close() + 
return + except ConnectionClosed as e: + raise e + + async def send_rpc( + self, + service_name: str, + procedure_name: str, + request: RequestType, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + span: Span, + timeout: timedelta, + ) -> ResponseType: + """Sends a single RPC request to the server. + + Expects the input and output be messages that will be msgpacked. + """ + stream_id = nanoid.generate() + output: Channel[Any] = Channel(1) + self._streams[stream_id] = output + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, + payload=request_serializer(request), + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + # Handle potential errors during communication + try: + try: + async with asyncio.timeout(timeout.total_seconds()): + response = await output.get() + except asyncio.TimeoutError as e: + # TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_CLOSED_BIT, + payload={"type": "CLOSE"}, + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + raise RiverException(ERROR_CODE_CANCEL, str(e)) from e + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) + return response_deserializer(response["payload"]) + except RiverException as e: + raise e + except Exception as e: + raise e + + async 
def send_upload( + self, + service_name: str, + procedure_name: str, + init: InitType | None, + request: AsyncIterable[RequestType], + init_serializer: Callable[[InitType], Any] | None, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + span: Span, + ) -> ResponseType: + """Sends an upload request to the server. + + Expects the input and output be messages that will be msgpacked. + """ + + stream_id = nanoid.generate() + output: Channel[Any] = Channel(1) + self._streams[stream_id] = output + first_message = True + try: + if init and init_serializer: + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + service_name=service_name, + procedure_name=procedure_name, + payload=init_serializer(init), + span=span, + ) + first_message = False + # If this request is not closed and the session is killed, we should + # throw exception here + async for item in request: + control_flags = 0 + if first_message: + control_flags = STREAM_OPEN_BIT + first_message = False + await self.send_message( + stream_id=stream_id, + service_name=service_name, + procedure_name=procedure_name, + control_flags=control_flags, + payload=request_serializer(item), + span=span, + ) + except Exception as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + await self.send_close_stream( + service_name, + procedure_name, + stream_id, + extra_control_flags=STREAM_OPEN_BIT if first_message else 0, + ) + + # Handle potential errors during communication + # TODO: throw a error when the transport is hard closed + try: + try: + response = await output.get() + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + 
if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) + + return response_deserializer(response["payload"]) + except RiverException as e: + raise e + except Exception as e: + raise e + + async def send_subscription( + self, + service_name: str, + procedure_name: str, + request: RequestType, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + span: Span, + ) -> AsyncGenerator[ResponseType | ErrorType, None]: + """Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. + """ + stream_id = nanoid.generate() + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) + self._streams[stream_id] = output + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=request_serializer(request), + span=span, + ) + + # Handle potential errors during communication + try: + async for item in output: + if item.get("type", None) == "CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logger.exception( + f"Error during subscription error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed) as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except Exception as e: + raise e + finally: + output.close() + + async def send_stream( + self, + service_name: str, + procedure_name: str, + init: InitType | None, + request: AsyncIterable[RequestType], + init_serializer: 
Callable[[InitType], Any] | None, + request_serializer: Callable[[RequestType], Any], + response_deserializer: Callable[[Any], ResponseType], + error_deserializer: Callable[[Any], ErrorType], + span: Span, + ) -> AsyncGenerator[ResponseType | ErrorType, None]: + """Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. + """ + + stream_id = nanoid.generate() + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) + self._streams[stream_id] = output + empty_stream = False + try: + if init and init_serializer: + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + span=span, + ) + else: + # Get the very first message to open the stream + request_iter = aiter(request) + first = await anext(request_iter) + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=request_serializer(first), + span=span, + ) + + except StopAsyncIteration: + empty_stream = True + + except Exception as e: + raise StreamClosedRiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + + # Create the encoder task + async def _encode_stream() -> None: + if empty_stream: + await self.send_close_stream( + service_name, + procedure_name, + stream_id, + extra_control_flags=STREAM_OPEN_BIT, + ) + return + + async for item in request: + if item is None: + continue + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=0, + payload=request_serializer(item), + ) + await self.send_close_stream(service_name, procedure_name, stream_id) + + self._task_manager.create_task(_encode_stream()) + + # Handle potential errors during communication + try: + async for item in output: + if "type" in item and item["type"] == 
"CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logger.exception( + f"Error during subscription error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed) as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except Exception as e: + raise e + finally: + output.close() + + async def send_close_stream( + self, + service_name: str, + procedure_name: str, + stream_id: str, + extra_control_flags: int = 0, + ) -> None: + # close stream + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_CLOSED_BIT | extra_control_flags, + payload={ + "type": "CLOSE", + }, + ) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py new file mode 100644 index 00000000..8248d7c4 --- /dev/null +++ b/src/replit_river/v2/client_transport.py @@ -0,0 +1,379 @@ +import asyncio +import logging +from collections.abc import Awaitable, Callable +from typing import Generic, assert_never + +import nanoid +import websockets +from pydantic import ValidationError +from websockets import ( + WebSocketCommonProtocol, +) +from websockets.exceptions import ConnectionClosed + +from replit_river.error_schema import ( + ERROR_CODE_STREAM_CLOSED, + ERROR_HANDSHAKE, + ERROR_SESSION, + RiverException, +) +from replit_river.messages import ( + FailedSendingMessageException, + WebsocketClosedException, + parse_transport_msg, + send_transport_message, +) +from replit_river.rate_limiter import LeakyBucketRateLimit +from replit_river.rpc import ( + SESSION_MISMATCH_CODE, + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + ExpectedSessionState, + TransportMessage, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + 
InvalidMessageException, +) +from replit_river.session import Session +from replit_river.transport_options import ( + HandshakeMetadataType, + TransportOptions, + UriAndMetadata, +) +from replit_river.v2.client_session import ClientSession + +logger = logging.getLogger(__name__) + + +class ClientTransport(Generic[HandshakeMetadataType]): + _sessions: dict[str, ClientSession] + + def __init__( + self, + uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]], + client_id: str, + server_id: str, + transport_options: TransportOptions, + ): + self._sessions = {} + self._transport_id = client_id + self._transport_options = transport_options + self._session_lock = asyncio.Lock() + + self._uri_and_metadata_factory = uri_and_metadata_factory + self._client_id = client_id + self._server_id = server_id + self._rate_limiter = LeakyBucketRateLimit( + transport_options.connection_retry_options + ) + # We want to make sure there's only one session creation at a time + self._create_session_lock = asyncio.Lock() + + async def _close_all_sessions(self) -> None: + sessions = self._sessions.values() + logger.info( + f"start closing sessions {self._transport_id}, number sessions : " + f"{len(sessions)}" + ) + sessions_to_close = list(sessions) + + # closing sessions requires access to the session lock, so we need to close + # them one by one to be safe + for session in sessions_to_close: + await session.close() + + logger.info(f"Transport closed {self._transport_id}") + + def generate_nanoid(self) -> str: + return str(nanoid.generate()) + + async def close(self) -> None: + self._rate_limiter.close() + await self._close_all_sessions() + + async def get_or_create_session(self) -> ClientSession: + async with self._create_session_lock: + existing_session = await self._get_existing_session() + if not existing_session: + return await self._create_new_session() + is_session_open = await existing_session.is_session_open() + if not is_session_open: + return await 
self._create_new_session() + is_ws_open = await existing_session.is_websocket_open() + if is_ws_open: + return existing_session + new_ws, _, hs_response = await self._establish_new_connection( + existing_session + ) + if hs_response.status.sessionId == existing_session.session_id: + logger.info( + "Replacing ws connection in session id %s", + existing_session.session_id, + ) + await existing_session.replace_with_new_websocket(new_ws) + return existing_session + else: + logger.info("Closing stale session %s", existing_session.session_id) + await existing_session.close() + return await self._create_new_session() + + async def _get_existing_session(self) -> ClientSession | None: + async with self._session_lock: + if not self._sessions: + return None + if len(self._sessions) > 1: + raise RiverException( + "session_error", + "More than one session found in client, should only be one", + ) + session = list(self._sessions.values())[0] + if isinstance(session, ClientSession): + return session + else: + raise RiverException( + "session_error", f"Client session type wrong, got {type(session)}" + ) + + async def _establish_new_connection( + self, + old_session: ClientSession | None = None, + ) -> tuple[ + WebSocketCommonProtocol, + ControlMessageHandshakeRequest[HandshakeMetadataType], + ControlMessageHandshakeResponse, + ]: + """Build a new websocket connection with retry logic.""" + rate_limit = self._rate_limiter + max_retry = self._transport_options.connection_retry_options.max_retry + client_id = self._client_id + logger.info("Attempting to establish new ws connection") + + last_error: Exception | None = None + for i in range(max_retry): + if i > 0: + logger.info(f"Retrying build handshake number {i} times") + if not rate_limit.has_budget(client_id): + logger.debug("No retry budget for %s.", client_id) + raise RiverException( + ERROR_HANDSHAKE, f"No retry budget for {client_id}" + ) from last_error + + rate_limit.consume_budget(client_id) + + # if the session is closed, 
we shouldn't use it + if old_session and not await old_session.is_session_open(): + old_session = None + + try: + uri_and_metadata = await self._uri_and_metadata_factory() + ws = await websockets.connect(uri_and_metadata["uri"]) + session_id = ( + self.generate_nanoid() + if not old_session + else old_session.session_id + ) + + try: + ( + handshake_request, + handshake_response, + ) = await self._establish_handshake( + self._transport_id, + self._server_id, + session_id, + uri_and_metadata["metadata"], + ws, + old_session, + ) + rate_limit.start_restoring_budget(client_id) + return ws, handshake_request, handshake_response + except RiverException as e: + await ws.close() + raise e + except Exception as e: + last_error = e + backoff_time = rate_limit.get_backoff_ms(client_id) + logger.exception( + f"Error connecting, retrying with {backoff_time}ms backoff" + ) + await asyncio.sleep(backoff_time / 1000) + + raise RiverException( + ERROR_HANDSHAKE, + f"Failed to create ws after retrying {max_retry} number of times", + ) from last_error + + async def _create_new_session( + self, + ) -> ClientSession: + logger.info("Creating new session") + new_ws, hs_request, hs_response = await self._establish_new_connection() + if not hs_response.status.ok: + message = hs_response.status.reason + raise RiverException( + ERROR_SESSION, + f"Server did not return OK status on handshake response: {message}", + ) + new_session = ClientSession( + transport_id=self._transport_id, + to_id=self._server_id, + session_id=hs_request.sessionId, + websocket=new_ws, + transport_options=self._transport_options, + close_session_callback=self._delete_session, + retry_connection_callback=self._retry_connection, + ) + + self._sessions[new_session._to_id] = new_session + await new_session.start_serve_responses() + return new_session + + async def _retry_connection(self) -> ClientSession: + if not self._transport_options.transparent_reconnect: + await self._close_all_sessions() + return await 
self.get_or_create_session() + + async def _send_handshake_request( + self, + transport_id: str, + to_id: str, + session_id: str, + handshake_metadata: HandshakeMetadataType | None, + websocket: WebSocketCommonProtocol, + expected_session_state: ExpectedSessionState, + ) -> ControlMessageHandshakeRequest[HandshakeMetadataType]: + handshake_request = ControlMessageHandshakeRequest[HandshakeMetadataType]( + type="HANDSHAKE_REQ", + protocolVersion="v2.0", + sessionId=session_id, + metadata=handshake_metadata, + expectedSessionState=expected_session_state, + ) + stream_id = self.generate_nanoid() + + async def websocket_closed_callback() -> None: + logger.error("websocket closed before handshake response") + + try: + await send_transport_message( + TransportMessage( + from_=transport_id, # type: ignore + to=to_id, + streamId=stream_id, + controlFlags=0, + id=self.generate_nanoid(), + seq=0, + ack=0, + payload=handshake_request.model_dump(), + ), + ws=websocket, + websocket_closed_callback=websocket_closed_callback, + ) + return handshake_request + except (WebsocketClosedException, FailedSendingMessageException) as e: + raise RiverException( + ERROR_HANDSHAKE, "Handshake failed, conn closed while sending response" + ) from e + + async def _get_handshake_response_msg( + self, websocket: WebSocketCommonProtocol + ) -> TransportMessage: + while True: + try: + data = await websocket.recv() + except ConnectionClosed as e: + logger.debug( + "Connection closed during waiting for handshake response", + exc_info=True, + ) + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while waiting for response", + ) from e + try: + return parse_transport_msg(data, self._transport_options) + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except InvalidMessageException as e: + raise RiverException( + ERROR_HANDSHAKE, + "Got invalid transport message, closing connection", + ) from e + + async def 
_establish_handshake( + self, + transport_id: str, + to_id: str, + session_id: str, + handshake_metadata: HandshakeMetadataType, + websocket: WebSocketCommonProtocol, + old_session: ClientSession | None, + ) -> tuple[ + ControlMessageHandshakeRequest[HandshakeMetadataType], + ControlMessageHandshakeResponse, + ]: + try: + expectedSessionState: ExpectedSessionState + match old_session: + case None: + expectedSessionState = ExpectedSessionState( + nextExpectedSeq=0, + nextSentSeq=0, + ) + case ClientSession(): + expectedSessionState = ExpectedSessionState( + nextExpectedSeq=await old_session.get_next_expected_seq(), + nextSentSeq=await old_session.get_next_sent_seq(), + ) + case other: + assert_never(other) + handshake_request = await self._send_handshake_request( + transport_id=transport_id, + to_id=to_id, + session_id=session_id, + handshake_metadata=handshake_metadata, + websocket=websocket, + expected_session_state=expectedSessionState, + ) + except FailedSendingMessageException as e: + raise RiverException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response, closing connection", + ) from e + + startup_grace_sec = 60 + try: + response_msg = await asyncio.wait_for( + self._get_handshake_response_msg(websocket), startup_grace_sec + ) + handshake_response = ControlMessageHandshakeResponse(**response_msg.payload) + logger.debug("river client waiting for handshake response") + except ValidationError as e: + raise RiverException( + ERROR_HANDSHAKE, "Failed to parse handshake response" + ) from e + except asyncio.TimeoutError as e: + raise RiverException( + ERROR_HANDSHAKE, "Handshake response timeout, closing connection" + ) from e + + logger.debug("river client get handshake response : %r", handshake_response) + if not handshake_response.status.ok: + if old_session and handshake_response.status.code == SESSION_MISMATCH_CODE: + # If the session status is mismatched, we should close the old session + # and let the retry logic to create a new session. 
+ await old_session.close() + + raise RiverException( + ERROR_HANDSHAKE, + f"Handshake failed with code ${handshake_response.status.code}: " + + f"{handshake_response.status.reason}", + ) + return handshake_request, handshake_response + + async def _delete_session(self, session: Session) -> None: + async with self._session_lock: + if session._to_id in self._sessions: + del self._sessions[session._to_id] From 20555620203e5b81820460f5b29060ec19e55b07 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 00:42:19 -0700 Subject: [PATCH 007/193] Adding and threading through protocol_version --- src/replit_river/codegen/client.py | 7 +++++++ src/replit_river/codegen/run.py | 1 + tests/codegen/snapshot/codegen_snapshot_fixtures.py | 1 + tests/codegen/test_rpc.py | 1 + 4 files changed, 10 insertions(+) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 00b24026..49a0d95d 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -784,6 +784,7 @@ def render_library_call( schema_name: str, name: str, procedure: RiverProcedure, + protocol_version: Literal["v1.1", "v2.0"], init_meta: tuple[RiverType, TypeExpression, str] | None, input_meta: tuple[RiverType, TypeExpression, str] | None, output_meta: tuple[RiverType, TypeExpression, str] | None, @@ -1005,6 +1006,7 @@ def generate_individual_service( schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], method_filter: set[str] | None, + protocol_version: Literal["v1.1", "v2.0"], ) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] @@ -1239,6 +1241,7 @@ def combine_or_none( schema_name=schema_name, name=name, procedure=procedure, + protocol_version=protocol_version, init_meta=combine_or_none( procedure.init, init_type, render_init_method ), @@ -1294,6 +1297,7 @@ def generate_river_client_module( schema_root: RiverSchema, 
typed_dict_inputs: bool, method_filter: set[str] | None, + protocol_version: Literal["v1.1", "v2.0"], ) -> dict[RenderedPath, FileContents]: files: dict[RenderedPath, FileContents] = {} @@ -1322,6 +1326,7 @@ def generate_river_client_module( schema, input_base_class, method_filter, + protocol_version, ) if emitted_files: # Short-cut if we didn't actually emit anything @@ -1343,6 +1348,7 @@ def schema_to_river_client_codegen( typed_dict_inputs: bool, file_opener: Callable[[Path], TextIO], method_filter: set[str] | None, + protocol_version: Literal["v1.1", "v2.0"], ) -> None: """Generates the lines of a River module.""" with read_schema() as f: @@ -1352,6 +1358,7 @@ def schema_to_river_client_codegen( schemas.root, typed_dict_inputs, method_filter, + protocol_version, ).items(): module_path = Path(target_path).joinpath(subpath) module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index 16d94a5e..5044780d 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -75,6 +75,7 @@ def file_opener(path: Path) -> TextIO: args.typed_dict_inputs, file_opener, method_filter=method_filter, + protocol_version="v1.1", ) else: raise NotImplementedError(f"Unknown command {args.command}") diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/codegen/snapshot/codegen_snapshot_fixtures.py index 2cf43011..2fdff907 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ b/tests/codegen/snapshot/codegen_snapshot_fixtures.py @@ -36,6 +36,7 @@ def file_opener(path: Path) -> TextIO: file_opener=file_opener, typed_dict_inputs=typeddict_inputs, method_filter=None, + protocol_version="v1.1", ) for path, file in files.items(): file.seek(0) diff --git a/tests/codegen/test_rpc.py b/tests/codegen/test_rpc.py index 8ab82095..450a74f0 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/codegen/test_rpc.py @@ -33,6 +33,7 @@ def file_opener(path: Path) 
-> TextIO: typed_dict_inputs=True, file_opener=file_opener, method_filter=None, + protocol_version="v1.1", ) From 5beda8e70fc99b6ab163b447fd98b6a2ba04e73c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 01:00:08 -0700 Subject: [PATCH 008/193] Conditionally emit v2 Clients --- src/replit_river/codegen/client.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 49a0d95d..d1f4ef6c 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -753,7 +753,14 @@ def generate_common_client( handshake_type: HandshakeType, handshake_chunks: Sequence[str], modules: list[tuple[ModuleName, ClassName]], + protocol_version: Literal["v1.1", "v2.0"], ) -> FileContents: + client_module: str + match protocol_version: + case "v1.1": + client_module = "river" + case "v2.0": + client_module = "river.v2" chunks: list[str] = [ROOT_FILE_HEADER] chunks.extend( [ @@ -767,7 +774,7 @@ def generate_common_client( dedent( f"""\ class {client_name}: - def __init__(self, client: river.Client[{handshake_type}]): + def __init__(self, client: {client_module}.Client[{handshake_type}]): """.rstrip() ) ] @@ -1031,12 +1038,19 @@ def _type_adapter_definition( ], ) + client_module: str + match protocol_version: + case "v1.1": + client_module = "river" + case "v2.0": + client_module = "river.v2" + class_name = ClassName(f"{schema_name.title()}Service") current_chunks: list[str] = [ dedent( f"""\ class {class_name}: - def __init__(self, client: river.Client[Any]): + def __init__(self, client: {client_module}.Client[Any]): self.client = client """ ), @@ -1334,7 +1348,11 @@ def generate_river_client_module( modules.append((module_name, class_name)) main_contents = generate_common_client( - client_name, handshake_type, handshake_chunks, modules + client_name, + handshake_type, + handshake_chunks, + modules, + protocol_version, ) 
files[RenderedPath(str(Path("__init__.py")))] = main_contents From 25d3585cbfbc2f89c813363b559202d37425ea64 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 11:08:29 -0700 Subject: [PATCH 009/193] Expose protocol-version to CLI --- src/replit_river/codegen/run.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index 5044780d..0f0deac0 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -45,6 +45,13 @@ def main() -> None: action="store", type=pathlib.Path, ) + client.add_argument( + "--protocol-version", + help="Generate river v2 clients", + action="store", + default="v1.1", + choices=["v1.1", "v2.0"], + ) client.add_argument("schema", help="schema file") args = parser.parse_args() @@ -75,7 +82,7 @@ def file_opener(path: Path) -> TextIO: args.typed_dict_inputs, file_opener, method_filter=method_filter, - protocol_version="v1.1", + protocol_version=args.protocol_version, ) else: raise NotImplementedError(f"Unknown command {args.command}") From 1b5ae1250a03724939e94973c0b4bc2efc6faf79 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 00:55:51 -0700 Subject: [PATCH 010/193] Pivoting rendered parameters based on protocol_version --- src/replit_river/codegen/client.py | 44 +++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index d1f4ef6c..11c3ac27 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -11,6 +11,7 @@ Sequence, Set, TextIO, + assert_never, cast, ) @@ -804,11 +805,22 @@ def render_library_call( """ current_chunks: list[str] = [] + binding: str if procedure.type == "rpc": - assert input_meta + match protocol_version: + case "v1.1": + assert input_meta + _, tpe, render_method = input_meta + binding = "input" + case "v2.0": + assert init_meta + _, 
tpe, render_method = init_meta + binding = "init" + case other: + assert_never(other) + assert output_meta assert error_meta - _, input_type, render_input_method = input_meta _, output_type, parse_output_method = output_meta _, _, parse_error_method = error_meta @@ -819,14 +831,14 @@ def render_library_call( f"""\ async def {name}( self, - input: {render_type_expr(input_type)}, + {binding}: {render_type_expr(tpe)}, timeout: datetime.timedelta, ) -> {render_type_expr(output_type)}: return await self.client.send_rpc( {repr(schema_name)}, {repr(name)}, - input, - {reindent(" ", render_input_method)}, + {binding}, + {reindent(" ", render_method)}, {reindent(" ", parse_output_method)}, {reindent(" ", parse_error_method)}, timeout, @@ -836,10 +848,20 @@ async def {name}( ] ) elif procedure.type == "subscription": - assert input_meta + match protocol_version: + case "v1.1": + assert input_meta + _, tpe, render_method = input_meta + binding = "input" + case "v2.0": + assert init_meta + _, tpe, render_method = init_meta + binding = "init" + case other: + assert_never(other) + assert output_meta assert error_meta - _, input_type, render_input_method = input_meta _, output_type, parse_output_method = output_meta _, error_type, parse_error_method = error_meta error_type_name = extract_inner_type(error_type) @@ -859,13 +881,13 @@ async def {name}( f"""\ async def {name}( self, - input: {render_type_expr(input_type)}, + {binding}: {render_type_expr(tpe)}, ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: return self.client.send_subscription( {repr(schema_name)}, {repr(name)}, - input, - {reindent(" ", render_input_method)}, + {binding}, + {reindent(" ", render_method)}, {reindent(" ", parse_output_method)}, {reindent(" ", parse_error_method)}, ) @@ -911,6 +933,7 @@ async def {name}( ] ) else: + assert protocol_version == "v1.1", "Protocol v2 requires init to be defined" current_chunks.extend( [ reindent( @@ -980,6 +1003,7 @@ async def {name}( ] ) else: + assert 
protocol_version == "v1.1", "Protocol v2 requires init to be defined" current_chunks.extend( [ reindent( From b3dc162d7de33744382e6d28242d08d3aaa870d8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 12:59:47 -0700 Subject: [PATCH 011/193] Relax the RiverProcedure field requirements so we can parse v2 --- src/replit_river/codegen/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 11c3ac27..553c7132 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -119,7 +119,7 @@ class RiverNotType(BaseModel): class RiverProcedure(BaseModel): init: RiverType | None = Field(default=None) - input: RiverType + input: RiverType | None = Field(default=None) output: RiverType errors: RiverType | None = Field(default=None) type: ( From 21216619bcaff603a611f5c989921ce72619f9e4 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 12:59:58 -0700 Subject: [PATCH 012/193] Missing render --- src/replit_river/codegen/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 553c7132..d05756b6 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -1201,9 +1201,9 @@ def __init__(self, client: {client_module}.Client[Any]): ) and procedure.init.type in ["array"]: match init_type: case ListTypeExpr(init_type_name): - render_init_method = ( - f"lambda xs: [encode_{init_type_name}(x) for x in xs]" - ) + render_init_method = f"lambda xs: [encode_{ + render_literal_type(init_type_name) + }(x) for x in xs]" else: render_init_method = f"encode_{render_literal_type(init_type)}" else: From d241780ca3368c18fb0dad0beef1edef57dd1152 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 13:04:01 -0700 Subject: [PATCH 013/193] Moving type ascriptions down to the method level --- 
src/replit_river/v2/client.py | 62 ++++++++++++--------------- src/replit_river/v2/client_session.py | 61 +++++++++++++------------- 2 files changed, 57 insertions(+), 66 deletions(-) diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index efd1bfe5..8221d3d0 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -13,12 +13,6 @@ ) from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException -from replit_river.rpc import ( - ErrorType, - InitType, - RequestType, - ResponseType, -) from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -91,16 +85,16 @@ async def close(self) -> None: async def ensure_connected(self) -> None: await self._transport.get_or_create_session() - async def send_rpc( + async def send_rpc[R, A]( self, service_name: str, procedure_name: str, - request: RequestType, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], timeout: timedelta, - ) -> ResponseType: + ) -> A: with _trace_procedure("rpc", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() return await session.send_rpc( @@ -114,17 +108,17 @@ async def send_rpc( timeout, ) - async def send_upload( + async def send_upload[I, R, A]( self, service_name: str, procedure_name: str, - init: InitType | None, - request: AsyncIterable[RequestType], - init_serializer: Callable[[InitType], Any] | None, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], - ) -> ResponseType: + init: I | None, + request: AsyncIterable[R], + init_serializer: Callable[[I], Any] | None, + request_serializer: 
Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], + ) -> A: with _trace_procedure("upload", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() return await session.send_upload( @@ -139,15 +133,15 @@ async def send_upload( span_handle.span, ) - async def send_subscription( + async def send_subscription[R, E, A]( self, service_name: str, procedure_name: str, - request: RequestType, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], - ) -> AsyncGenerator[ResponseType | RiverError, None]: + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], + ) -> AsyncGenerator[A | E, None]: with _trace_procedure( "subscription", service_name, procedure_name ) as span_handle: @@ -165,17 +159,17 @@ async def send_subscription( _record_river_error(span_handle, msg) yield msg # type: ignore # https://github.com/python/mypy/issues/10817 - async def send_stream( + async def send_stream[I, R, E, A]( self, service_name: str, procedure_name: str, - init: InitType | None, - request: AsyncIterable[RequestType], - init_serializer: Callable[[InitType], Any] | None, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], - ) -> AsyncGenerator[ResponseType | RiverError, None]: + init: I | None, + request: AsyncIterable[R], + init_serializer: Callable[[I], Any] | None, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], + ) -> AsyncGenerator[A | E, None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() async for msg in 
session.send_stream( diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 505e23e2..3daed6cd 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -15,6 +15,7 @@ from replit_river.error_schema import ( ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, + RiverError, RiverException, RiverServiceException, StreamClosedRiverServiceException, @@ -28,10 +29,6 @@ ACK_BIT, STREAM_CLOSED_BIT, STREAM_OPEN_BIT, - ErrorType, - InitType, - RequestType, - ResponseType, ) from replit_river.seq_manager import ( IgnoreMessageException, @@ -168,17 +165,17 @@ async def _handle_messages_from_ws(self) -> None: except ConnectionClosed as e: raise e - async def send_rpc( + async def send_rpc[R, A]( self, service_name: str, procedure_name: str, - request: RequestType, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], span: Span, timeout: timedelta, - ) -> ResponseType: + ) -> A: """Sends a single RPC request to the server. Expects the input and output be messages that will be msgpacked. 
@@ -233,18 +230,18 @@ async def send_rpc( except Exception as e: raise e - async def send_upload( + async def send_upload[I, R, A]( self, service_name: str, procedure_name: str, - init: InitType | None, - request: AsyncIterable[RequestType], - init_serializer: Callable[[InitType], Any] | None, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + init: I | None, + request: AsyncIterable[R], + init_serializer: Callable[[I], Any] | None, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], span: Span, - ) -> ResponseType: + ) -> A: """Sends an upload request to the server. Expects the input and output be messages that will be msgpacked. @@ -320,16 +317,16 @@ async def send_upload( except Exception as e: raise e - async def send_subscription( + async def send_subscription[R, E, A]( self, service_name: str, procedure_name: str, - request: RequestType, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], span: Span, - ) -> AsyncGenerator[ResponseType | ErrorType, None]: + ) -> AsyncGenerator[A | E, None]: """Sends a subscription request to the server. Expects the input and output be messages that will be msgpacked. 
@@ -372,18 +369,18 @@ async def send_subscription( finally: output.close() - async def send_stream( + async def send_stream[I, R, E, A]( self, service_name: str, procedure_name: str, - init: InitType | None, - request: AsyncIterable[RequestType], - init_serializer: Callable[[InitType], Any] | None, - request_serializer: Callable[[RequestType], Any], - response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + init: I | None, + request: AsyncIterable[R], + init_serializer: Callable[[I], Any] | None, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], span: Span, - ) -> AsyncGenerator[ResponseType | ErrorType, None]: + ) -> AsyncGenerator[A | E, None]: """Sends a subscription request to the server. Expects the input and output be messages that will be msgpacked. From 9fa548d4318148e7ccbea8e8a55f35489c563434 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 13:20:53 -0700 Subject: [PATCH 014/193] Just avoid calling add_msg_to_stream if it should not be called --- src/replit_river/v2/client_session.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 3daed6cd..c7288463 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -11,7 +11,6 @@ from opentelemetry.trace import Span from websockets.exceptions import ConnectionClosed -from replit_river.common_session import add_msg_to_stream from replit_river.error_schema import ( ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, @@ -140,7 +139,22 @@ async def _handle_messages_from_ws(self) -> None: raise IgnoreMessageException( "no stream for message, ignoring" ) - await add_msg_to_stream(msg, stream) + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent 
to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: + raise InvalidMessageException(e) from e else: raise InvalidMessageException( "Client should not receive stream open bit" From 231f1a3978d0f6b1e8f5611872bd0e6820b73625 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 13:22:36 -0700 Subject: [PATCH 015/193] Adding missing STREAM_CLOSED_BIT --- src/replit_river/v2/client_session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index c7288463..5e7a1a08 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -26,7 +26,6 @@ ) from replit_river.rpc import ( ACK_BIT, - STREAM_CLOSED_BIT, STREAM_OPEN_BIT, ) from replit_river.seq_manager import ( @@ -37,6 +36,9 @@ from replit_river.session import Session from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions +STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 + + logger = logging.getLogger(__name__) From 6a71888624633f823002cfb8b4e7912accb2bc22 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 14:38:44 -0700 Subject: [PATCH 016/193] v2 closed bit --- src/replit_river/v2/client_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 5e7a1a08..5bd55b28 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -36,7 +36,7 @@ from replit_river.session import Session from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions -STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2 +STREAM_CLOSED_BIT = 0b01000 # Synonymous with the cancel bit in v2 logger = 
logging.getLogger(__name__) From 8bced2d00c597c962d8b81c76f13efd280c23013 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 14:38:54 -0700 Subject: [PATCH 017/193] v2 cancel bit --- src/replit_river/v2/client_session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 5bd55b28..10a68fb8 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -36,6 +36,7 @@ from replit_river.session import Session from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions +STREAM_CANCEL_BIT = 0b00100 # Synonymous with the cancel bit in v2 STREAM_CLOSED_BIT = 0b01000 # Synonymous with the cancel bit in v2 @@ -213,11 +214,10 @@ async def send_rpc[R, A]( async with asyncio.timeout(timeout.total_seconds()): response = await output.get() except asyncio.TimeoutError as e: - # TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT await self.send_message( stream_id=stream_id, - control_flags=STREAM_CLOSED_BIT, - payload={"type": "CLOSE"}, + control_flags=STREAM_CANCEL_BIT, + payload={"type": "CANCEL"}, service_name=service_name, procedure_name=procedure_name, span=span, From 656da90605f0199004a8fd2713b70535d89b7fa2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 14:39:09 -0700 Subject: [PATCH 018/193] Reflowing v2 send_stream to have modern semantics --- src/replit_river/v2/client.py | 8 ++--- src/replit_river/v2/client_session.py | 46 +++++++++------------------ 2 files changed, 19 insertions(+), 35 deletions(-) diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 8221d3d0..6cf7abe9 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -163,10 +163,10 @@ async def send_stream[I, R, E, A]( self, service_name: str, procedure_name: str, - init: I | None, - request: AsyncIterable[R], - init_serializer: Callable[[I], 
Any] | None, - request_serializer: Callable[[R], Any], + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, response_deserializer: Callable[[Any], A], error_deserializer: Callable[[Any], E], ) -> AsyncGenerator[A | E, None]: diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 10a68fb8..6058ca95 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -389,10 +389,10 @@ async def send_stream[I, R, E, A]( self, service_name: str, procedure_name: str, - init: I | None, - request: AsyncIterable[R], - init_serializer: Callable[[I], Any] | None, - request_serializer: Callable[[R], Any], + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, response_deserializer: Callable[[Any], A], error_deserializer: Callable[[Any], E], span: Span, @@ -405,33 +405,15 @@ async def send_stream[I, R, E, A]( stream_id = nanoid.generate() output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) self._streams[stream_id] = output - empty_stream = False try: - if init and init_serializer: - await self.send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=init_serializer(init), - span=span, - ) - else: - # Get the very first message to open the stream - request_iter = aiter(request) - first = await anext(request_iter) - await self.send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=request_serializer(first), - span=span, - ) - - except StopAsyncIteration: - empty_stream = True - + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + span=span, + ) except 
Exception as e: raise StreamClosedRiverServiceException( ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name @@ -439,7 +421,7 @@ async def send_stream[I, R, E, A]( # Create the encoder task async def _encode_stream() -> None: - if empty_stream: + if not request: await self.send_close_stream( service_name, procedure_name, @@ -448,6 +430,8 @@ async def _encode_stream() -> None: ) return + assert request_serializer, "send_stream missing request_serializer" + async for item in request: if item is None: continue From 8628a3c1b5d83af7a0b2a04bc960014aeb5069be Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 14:28:10 -0700 Subject: [PATCH 019/193] v2 send_stream codegen --- src/replit_river/codegen/client.py | 38 ++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index d05756b6..a35c3ab4 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -960,10 +960,8 @@ async def {name}( ] ) elif procedure.type == "stream": - assert input_meta assert output_meta assert error_meta - _, input_type, render_input_method = input_meta _, output_type, parse_output_method = output_meta _, error_type, parse_error_method = error_meta error_type_name = extract_inner_type(error_type) @@ -976,8 +974,9 @@ async def {name}( TypeName("RiverError"), ] ) - if init_meta: + if init_meta and input_meta: _, init_type, render_init_method = init_meta + _, input_type, render_input_method = input_meta current_chunks.extend( [ reindent( @@ -1002,8 +1001,9 @@ async def {name}( ) ] ) - else: - assert protocol_version == "v1.1", "Protocol v2 requires init to be defined" + elif protocol_version == "v1.1": + assert input_meta, "Protocol v1 requires input to be defined" + _, input_type, render_input_method = input_meta current_chunks.extend( [ reindent( @@ -1027,6 +1027,34 @@ async def {name}( ) ] ) + elif protocol_version == "v2.0": + 
assert init_meta, "Protocol v2 requires init to be defined" + _, init_type, render_init_method = init_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + init: {render_type_expr(init_type)}, + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_stream( + {repr(schema_name)}, + {repr(name)}, + init, + None, + {reindent(" ", render_init_method)}, + None, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + else: + raise ValueError("Precondition failed") current_chunks.append("") return current_chunks From 1efbb3affe1a9f03dc1585c4abe5fd1c667e7b9b Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 14:52:19 -0700 Subject: [PATCH 020/193] Reflowing v2 send_upload to have modern semantics --- src/replit_river/v2/client.py | 8 ++-- src/replit_river/v2/client_session.py | 60 +++++++++++++-------------- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 6cf7abe9..958ff67d 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -112,10 +112,10 @@ async def send_upload[I, R, A]( self, service_name: str, procedure_name: str, - init: I | None, - request: AsyncIterable[R], - init_serializer: Callable[[I], Any] | None, - request_serializer: Callable[[R], Any], + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, response_deserializer: Callable[[Any], A], error_deserializer: Callable[[Any], RiverError], ) -> A: diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 6058ca95..6fa6cd0c 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -250,10 +250,10 @@ async def send_upload[I, R, A]( self, service_name: str, procedure_name: str, - init: I | None, - request: 
AsyncIterable[R], - init_serializer: Callable[[I], Any] | None, - request_serializer: Callable[[R], Any], + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, response_deserializer: Callable[[Any], A], error_deserializer: Callable[[Any], RiverError], span: Span, @@ -266,33 +266,31 @@ async def send_upload[I, R, A]( stream_id = nanoid.generate() output: Channel[Any] = Channel(1) self._streams[stream_id] = output - first_message = True try: - if init and init_serializer: - await self.send_message( - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - service_name=service_name, - procedure_name=procedure_name, - payload=init_serializer(init), - span=span, - ) - first_message = False - # If this request is not closed and the session is killed, we should - # throw exception here - async for item in request: - control_flags = 0 - if first_message: - control_flags = STREAM_OPEN_BIT - first_message = False - await self.send_message( - stream_id=stream_id, - service_name=service_name, - procedure_name=procedure_name, - control_flags=control_flags, - payload=request_serializer(item), - span=span, - ) + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + service_name=service_name, + procedure_name=procedure_name, + payload=init_serializer(init), + span=span, + ) + + if request: + assert request_serializer, "send_stream missing request_serializer" + + # If this request is not closed and the session is killed, we should + # throw exception here + async for item in request: + control_flags = 0 + await self.send_message( + stream_id=stream_id, + service_name=service_name, + procedure_name=procedure_name, + control_flags=control_flags, + payload=request_serializer(item), + span=span, + ) except Exception as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name @@ -301,7 +299,7 @@ async def send_upload[I, R, A]( service_name, 
procedure_name, stream_id, - extra_control_flags=STREAM_OPEN_BIT if first_message else 0, + extra_control_flags=0, ) # Handle potential errors during communication From 6d92145ae11203d19c37d0a14f395f14e10a6556 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 14:52:34 -0700 Subject: [PATCH 021/193] v2 send_upload codegen --- src/replit_river/codegen/client.py | 40 ++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index a35c3ab4..d56856f8 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -896,18 +896,17 @@ async def {name}( ] ) elif procedure.type == "upload": - assert input_meta assert output_meta assert error_meta - _, input_type, render_input_method = input_meta _, output_type, parse_output_method = output_meta _, error_type, parse_error_method = error_meta error_type_name = extract_inner_type(error_type) output_or_error_type = UnionTypeExpr([output_type, error_type_name]) - if init_meta: + if init_meta and input_meta: _, init_type, render_init_method = init_meta + _, input_type, render_input_method = input_meta current_chunks.extend( [ reindent( @@ -932,8 +931,9 @@ async def {name}( ) ] ) - else: - assert protocol_version == "v1.1", "Protocol v2 requires init to be defined" + elif protocol_version == "v1.1": + assert input_meta, "Protocol v1 requires input to be defined" + _, input_type, render_input_method = input_meta current_chunks.extend( [ reindent( @@ -959,6 +959,36 @@ async def {name}( ) ] ) + elif protocol_version == "v2.0": + assert init_meta, "Protocol v2 requires init to be defined" + _, init_type, render_init_method = init_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + init: {render_type_expr(init_type)}, + ) -> { # TODO(dstewart) This should just be output_type + render_type_expr(output_or_error_type) + }: + return await 
self.client.send_upload( + {repr(schema_name)}, + {repr(name)}, + init, + None, + {reindent(" ", render_init_method)}, + None, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + else: + raise ValueError("Precondition failed") elif procedure.type == "stream": assert output_meta assert error_meta From f0f248214082589f3e96b19643973cb09242df71 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 16:01:59 -0700 Subject: [PATCH 022/193] Distribute PROTOCOL_VERSION constant --- src/replit_river/v2/client_transport.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 8248d7c4..54d5a73c 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -43,6 +43,8 @@ ) from replit_river.v2.client_session import ClientSession +PROTOCOL_VERSION = "v2.0" + logger = logging.getLogger(__name__) @@ -245,7 +247,7 @@ async def _send_handshake_request( ) -> ControlMessageHandshakeRequest[HandshakeMetadataType]: handshake_request = ControlMessageHandshakeRequest[HandshakeMetadataType]( type="HANDSHAKE_REQ", - protocolVersion="v2.0", + protocolVersion=PROTOCOL_VERSION, sessionId=session_id, metadata=handshake_metadata, expectedSessionState=expected_session_state, From 41c1c49937ed41ab161fcff160c68175c5cdb9f8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 16:19:10 -0700 Subject: [PATCH 023/193] Describing the control packet structures --- src/replit_river/v2/schema.py | 39 +++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/replit_river/v2/schema.py diff --git a/src/replit_river/v2/schema.py b/src/replit_river/v2/schema.py new file mode 100644 index 00000000..473fd0c9 --- /dev/null +++ b/src/replit_river/v2/schema.py @@ -0,0 +1,39 @@ +from typing import Any, Literal, NotRequired, TypeAlias, TypedDict + +from 
replit_river.rpc import ExpectedSessionState + + +class ControlClose(TypedDict): + type: Literal["CLOSE"] + + +class ControlAck(TypedDict): + type: Literal["ACK"] + + +class ControlHandshakeRequest(TypedDict): + type: Literal["HANDSHAKE_REQ"] + protocolVersion: Literal["v2.0"] + sessionId: str + expectedSessionState: ExpectedSessionState + metdata: NotRequired[Any] + + +class HandshakeOK(TypedDict): + ok: Literal[True] + sessionId: str + + +class HandshakeError(TypedDict): + ok: Literal[False] + reaason: str + + +class ControlHandshakeResponse(TypedDict): + type: Literal["HANDSHAKE_RESP"] + status: HandshakeOK | HandshakeError + + +Control: TypeAlias = ( + ControlClose | ControlAck | ControlHandshakeRequest | ControlHandshakeResponse +) From 66f0e36a7caafbd36bb1db07ca6a381f8eb943a4 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 17:39:16 -0700 Subject: [PATCH 024/193] Translating a little more of PROTOCOL into code --- src/replit_river/rpc.py | 6 +++-- src/replit_river/v2/client_session.py | 8 +++--- src/replit_river/v2/schema.py | 37 ++++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 3459bf7d..678e8bf6 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -47,8 +47,10 @@ GenericRpcHandlerBuilder = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] ] -ACK_BIT = 0b00001 -STREAM_OPEN_BIT = 0b00010 +ACK_BIT_TYPE = Literal[0b00001] +ACK_BIT: ACK_BIT_TYPE = 0b00001 +STREAM_OPEN_BIT_TYPE = Literal[0b00010] +STREAM_OPEN_BIT: STREAM_OPEN_BIT_TYPE = 0b00010 # these codes are retriable # if the server sends a response with one of these codes, diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 6fa6cd0c..b2ebbd42 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -2,7 +2,7 @@ import logging from collections.abc import AsyncIterable 
from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Coroutine +from typing import Any, AsyncGenerator, Callable, Coroutine, Literal import nanoid # type: ignore import websockets @@ -36,8 +36,10 @@ from replit_river.session import Session from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions -STREAM_CANCEL_BIT = 0b00100 # Synonymous with the cancel bit in v2 -STREAM_CLOSED_BIT = 0b01000 # Synonymous with the cancel bit in v2 +STREAM_CANCEL_BIT_TYPE = Literal[0b00100] +STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 +STREAM_CLOSED_BIT_TYPE = Literal[0b01000] +STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 logger = logging.getLogger(__name__) diff --git a/src/replit_river/v2/schema.py b/src/replit_river/v2/schema.py index 473fd0c9..ac08fbfc 100644 --- a/src/replit_river/v2/schema.py +++ b/src/replit_river/v2/schema.py @@ -1,6 +1,9 @@ from typing import Any, Literal, NotRequired, TypeAlias, TypedDict -from replit_river.rpc import ExpectedSessionState +from grpc.aio import BaseError + +from replit_river.rpc import ACK_BIT, ACK_BIT_TYPE, ExpectedSessionState +from replit_river.v2.client_session import STREAM_CANCEL_BIT, STREAM_CANCEL_BIT_TYPE class ControlClose(TypedDict): @@ -33,7 +36,39 @@ class ControlHandshakeResponse(TypedDict): type: Literal["HANDSHAKE_RESP"] status: HandshakeOK | HandshakeError +# This is sent when the server encounters an internal error +# i.e. an invariant has been violated +class BaseErrorStructure(TypedDict): + # This should be a defined literal to make sure errors are easily differentiated + # code: str # Supplied by implementations + # This can be any string + message: str + # Any extra metadata + extra: NotRequired[Any] + +# When a client sends a malformed request. This can be +# for a variety of reasons which would be included +# in the message. 
+class InvalidRequestError(BaseErrorStructure): + code: Literal['INVALID_REQUEST'] + +# This is sent when an exception happens in the handler of a stream. +class UncaughtError(BaseErrorStructure): + code: Literal['UNCAUGHT_ERROR'] + +# This is sent when one side wishes to cancel the stream +# abruptly from user-space. Handling this is up to the procedure +# implementation or the caller. +class CancelError(BaseErrorStructure): + code: Literal['CANCEL'] + +ProtocolError: TypeAlias = UncaughtError | InvalidRequestError | CancelError; Control: TypeAlias = ( ControlClose | ControlAck | ControlHandshakeRequest | ControlHandshakeResponse ) + +ValidPairings = ( + tuple[ACK_BIT_TYPE, ControlAck] | + tuple[STREAM_CANCEL_BIT_TYPE, ProtocolError] +) From 9a46e261a99b97221963724260109d6a08bf7732 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 17:39:37 -0700 Subject: [PATCH 025/193] These are all self calls --- src/replit_river/v2/client_transport.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 54d5a73c..464660f2 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -181,8 +181,6 @@ async def _establish_new_connection( handshake_request, handshake_response, ) = await self._establish_handshake( - self._transport_id, - self._server_id, session_id, uri_and_metadata["metadata"], ws, @@ -238,8 +236,6 @@ async def _retry_connection(self) -> ClientSession: async def _send_handshake_request( self, - transport_id: str, - to_id: str, session_id: str, handshake_metadata: HandshakeMetadataType | None, websocket: WebSocketCommonProtocol, @@ -260,8 +256,8 @@ async def websocket_closed_callback() -> None: try: await send_transport_message( TransportMessage( - from_=transport_id, # type: ignore - to=to_id, + from_=self.transport_id, # type: ignore + to=self._server_id, streamId=stream_id, controlFlags=0, 
id=self.generate_nanoid(), @@ -306,8 +302,6 @@ async def _get_handshake_response_msg( async def _establish_handshake( self, - transport_id: str, - to_id: str, session_id: str, handshake_metadata: HandshakeMetadataType, websocket: WebSocketCommonProtocol, @@ -332,8 +326,6 @@ async def _establish_handshake( case other: assert_never(other) handshake_request = await self._send_handshake_request( - transport_id=transport_id, - to_id=to_id, session_id=session_id, handshake_metadata=handshake_metadata, websocket=websocket, From b43f2e68d3a5e61d9f6a406b5e9c47974f2c5961 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 17:50:13 -0700 Subject: [PATCH 026/193] Typing all self parameters --- src/replit_river/v2/client_session.py | 14 ++++------- src/replit_river/v2/schema.py | 34 ++++++++++++++------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index b2ebbd42..c3da969c 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -2,7 +2,7 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Coroutine, Literal +from typing import Any, AsyncGenerator, Callable, Literal import nanoid # type: ignore import websockets @@ -33,7 +33,7 @@ InvalidMessageException, OutOfOrderMessageException, ) -from replit_river.session import Session +from replit_river.session import CloseSessionCallback, RetryConnectionCallback, Session from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions STREAM_CANCEL_BIT_TYPE = Literal[0b00100] @@ -53,14 +53,8 @@ def __init__( session_id: str, websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, - close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]], - retry_connection_callback: ( - Callable[ - [], - Coroutine[Any, Any, Any], - ] - | None - ) 
= None, + close_session_callback: CloseSessionCallback, + retry_connection_callback: RetryConnectionCallback | None = None, ) -> None: super().__init__( transport_id=transport_id, diff --git a/src/replit_river/v2/schema.py b/src/replit_river/v2/schema.py index ac08fbfc..5fffd164 100644 --- a/src/replit_river/v2/schema.py +++ b/src/replit_river/v2/schema.py @@ -1,9 +1,7 @@ from typing import Any, Literal, NotRequired, TypeAlias, TypedDict -from grpc.aio import BaseError - -from replit_river.rpc import ACK_BIT, ACK_BIT_TYPE, ExpectedSessionState -from replit_river.v2.client_session import STREAM_CANCEL_BIT, STREAM_CANCEL_BIT_TYPE +from replit_river.rpc import ACK_BIT_TYPE, ExpectedSessionState +from replit_river.v2.client_session import STREAM_CANCEL_BIT_TYPE class ControlClose(TypedDict): @@ -36,39 +34,43 @@ class ControlHandshakeResponse(TypedDict): type: Literal["HANDSHAKE_RESP"] status: HandshakeOK | HandshakeError + # This is sent when the server encounters an internal error # i.e. an invariant has been violated class BaseErrorStructure(TypedDict): - # This should be a defined literal to make sure errors are easily differentiated - # code: str # Supplied by implementations - # This can be any string - message: str - # Any extra metadata - extra: NotRequired[Any] + # This should be a defined literal to make sure errors are easily differentiated + # code: str # Supplied by implementations + # This can be any string + message: str + # Any extra metadata + extra: NotRequired[Any] + # When a client sends a malformed request. This can be # for a variety of reasons which would be included # in the message. class InvalidRequestError(BaseErrorStructure): - code: Literal['INVALID_REQUEST'] + code: Literal["INVALID_REQUEST"] + # This is sent when an exception happens in the handler of a stream. 
class UncaughtError(BaseErrorStructure): - code: Literal['UNCAUGHT_ERROR'] + code: Literal["UNCAUGHT_ERROR"] + # This is sent when one side wishes to cancel the stream # abruptly from user-space. Handling this is up to the procedure # implementation or the caller. class CancelError(BaseErrorStructure): - code: Literal['CANCEL'] + code: Literal["CANCEL"] + -ProtocolError: TypeAlias = UncaughtError | InvalidRequestError | CancelError; +ProtocolError: TypeAlias = UncaughtError | InvalidRequestError | CancelError Control: TypeAlias = ( ControlClose | ControlAck | ControlHandshakeRequest | ControlHandshakeResponse ) ValidPairings = ( - tuple[ACK_BIT_TYPE, ControlAck] | - tuple[STREAM_CANCEL_BIT_TYPE, ProtocolError] + tuple[ACK_BIT_TYPE, ControlAck] | tuple[STREAM_CANCEL_BIT_TYPE, ProtocolError] ) From d39aafc1a0ceccaff2c0da0347ab79be05e27f4a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 22 Mar 2025 20:58:17 -0700 Subject: [PATCH 027/193] WIP --- src/replit_river/common_session.py | 3 +- src/replit_river/v2/__init__.py | 3 + src/replit_river/v2/client.py | 21 ++ src/replit_river/v2/client_session.py | 142 +++++++++---- src/replit_river/v2/client_transport.py | 108 +++++----- src/replit_river/v2/session.py | 256 ++++++++++++++++++++++++ 6 files changed, 429 insertions(+), 104 deletions(-) create mode 100644 src/replit_river/v2/session.py diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 2325492e..f697c1ae 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -3,6 +3,7 @@ from typing import Any, Protocol from opentelemetry.trace import Span +from websockets import WebSocketCommonProtocol logger = logging.getLogger(__name__) @@ -36,5 +37,5 @@ class SessionState(enum.Enum): CLOSED = 3 -ConnectingStates = set([SessionState.NO_CONNECTION]) +ConnectingStateta = set([SessionState.NO_CONNECTION]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) diff --git 
a/src/replit_river/v2/__init__.py b/src/replit_river/v2/__init__.py index 19790f0e..142ab2a0 100644 --- a/src/replit_river/v2/__init__.py +++ b/src/replit_river/v2/__init__.py @@ -1,5 +1,8 @@ +from replit_river.v2.session import Session + from .client import Client __all__ = [ "Client", + "Session", ] diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 958ff67d..a7c09f74 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -58,6 +58,27 @@ def translate_unknown_error( return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") +# Client[HandshakeSchema]( +# uri_and_metadata_factory=uri_and_metadata_factory, +# client_id=self.client_id, +# server_id="SERVER", +# transport_options=TransportOptions( +# session_disconnect_grace_ms=settings.RIVER_SESSION_DISCONNECT_GRACE_MS, +# heartbeat_ms=settings.RIVER_HEARTBEAT_MS, +# heartbeats_until_dead=settings.RIVER_HEARTBEATS_UNTIL_DEAD, +# connection_retry_options=ConnectionRetryOptions( +# base_interval_ms=settings.RIVER_CONNECTION_BASE_INTERVAL_MS, +# max_jitter_ms=settings.RIVER_CONNECTION_MAX_JITTER_MS, +# max_backoff_ms=settings.RIVER_CONNECTION_MAX_BACKOFF_MS, +# attempt_budget_capacity=self.attempt_budget_capacity, +# budget_restore_interval_ms= +# settings.RIVER_CONNECTION_BUDGET_RESTORE_INTERVAL_MS, +# max_retry=self.max_retry_count, +# ), +# ), +# ) + + class Client(Generic[HandshakeMetadataType]): def __init__( self, diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index c3da969c..53249c06 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -4,12 +4,13 @@ from datetime import timedelta from typing import Any, AsyncGenerator, Callable, Literal -import nanoid # type: ignore +import nanoid import websockets from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span from websockets.exceptions import 
ConnectionClosed +from websockets.frames import CloseCode from replit_river.error_schema import ( ERROR_CODE_CANCEL, @@ -27,14 +28,19 @@ from replit_river.rpc import ( ACK_BIT, STREAM_OPEN_BIT, + TransportMessage, ) from replit_river.seq_manager import ( IgnoreMessageException, InvalidMessageException, OutOfOrderMessageException, ) -from replit_river.session import CloseSessionCallback, RetryConnectionCallback, Session from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions +from replit_river.v2.session import ( + CloseSessionCallback, + RetryConnectionCallback, + Session, +) STREAM_CANCEL_BIT_TYPE = Literal[0b00100] STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 @@ -67,18 +73,42 @@ def __init__( ) async def do_close_websocket() -> None: - await self.close_websocket( - self._ws_wrapper, - should_retry=True, - ) + if self._ws_unwrapped: + self._task_manager.create_task(self._ws_unwrapped.close()) + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) await self._begin_close_session_countdown() self._setup_heartbeats_task(do_close_websocket) + def commit(msg: TransportMessage) -> None: + pending = self._send_buffer.popleft() + if msg.seq != pending.seq: + logger.error("Out of sequence error") + self._ack_buffer.append(pending) + + # On commit, release pending writers waiting for more buffer space + if self._queue_full_lock.locked(): + self._queue_full_lock.release() + + def get_next_pending() -> TransportMessage | None: + if self._send_buffer: + return self._send_buffer[0] + return None + + self._task_manager.create_task( + buffered_message_sender( + get_ws=lambda: self._ws_unwrapped, + websocket_closed_callback=self._begin_close_session_countdown, + get_next_pending=get_next_pending, + commit=commit, + ) + ) + async def start_serve_responses(self) -> None: - self._task_manager.create_task(self.serve()) + self._task_manager.create_task(self._serve()) - async def serve(self) -> None: 
+ async def _serve(self) -> None: """Serve messages from the websocket.""" self._reset_session_close_countdown() try: @@ -105,16 +135,18 @@ async def serve(self) -> None: ) async def _handle_messages_from_ws(self) -> None: + while self._ws_unwrapped is None: + await asyncio.sleep(1) logger.debug( "%s start handling messages from ws %s", "client", - self._ws_wrapper.id, + self._ws_unwrapped.id, ) try: - ws_wrapper = self._ws_wrapper - async for message in ws_wrapper.ws: + ws = self._ws_unwrapped + async for message in ws: try: - if not await ws_wrapper.is_open(): + if not self._ws_unwrapped: # We should not process messages if the websocket is closed. break msg = parse_transport_msg(message, self._transport_options) @@ -122,54 +154,76 @@ async def _handle_messages_from_ws(self) -> None: logger.debug(f"{self._transport_id} got a message %r", msg) # Update bookkeeping - await self._seq_manager.check_seq_and_update(msg) - await self._buffer.remove_old_messages( - self._seq_manager.receiver_ack, - ) + if msg.seq < self.ack: + raise IgnoreMessageException( + f"{msg.from_} received duplicate msg, got {msg.seq}" + f" expected {self.ack}" + ) + elif msg.seq > self.ack: + logger.warning( + f"Out of order message received got {msg.seq} expected " + f"{self.ack}" + ) + + raise OutOfOrderMessageException( + f"Out of order message received got {msg.seq} expected " + f"{self.ack}" + ) + + assert msg.seq == self.ack, "Safety net, redundant assertion" + + # Set our next expected ack number + self.ack = msg.seq + 1 + + # Discard old messages from the buffer + while self._ack_buffer and self._ack_buffer[0].seq < msg.ack: + self._ack_buffer.popleft() + self._reset_session_close_countdown() if msg.controlFlags & ACK_BIT != 0: continue - async with self._stream_lock: - stream = self._streams.get(msg.streamId, None) - if msg.controlFlags & STREAM_OPEN_BIT == 0: - if not stream: - logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException( - "no stream for message, 
ignoring" - ) - - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - pass - else: - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. - pass - except RuntimeError as e: - raise InvalidMessageException(e) from e - else: + stream = self._streams.get(msg.streamId) + if msg.controlFlags & STREAM_OPEN_BIT != 0: raise InvalidMessageException( "Client should not receive stream open bit" ) + if not stream: + logger.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException("no stream for message, ignoring") + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. 
+ pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + if msg.controlFlags & STREAM_CLOSED_BIT != 0: if stream: stream.close() - async with self._stream_lock: - del self._streams[msg.streamId] + del self._streams[msg.streamId] except IgnoreMessageException: logger.debug("Ignoring transport message", exc_info=True) continue except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") - await ws_wrapper.close() + self._task_manager.create_task( + self._ws_unwrapped.close( + code=CloseCode.INVALID_DATA, + reason="Out of order message", + ) + ) return except InvalidMessageException: logger.exception("Got invalid transport message, closing session") diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 464660f2..50002cc1 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -35,7 +35,6 @@ IgnoreMessageException, InvalidMessageException, ) -from replit_river.session import Session from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -43,13 +42,21 @@ ) from replit_river.v2.client_session import ClientSession +from .session import Session + PROTOCOL_VERSION = "v2.0" logger = logging.getLogger(__name__) +class HandshakeBudgetExhaustedException(RiverException): + def __init__(self, code: str, message: str, client_id: str) -> None: + super().__init__(code, message) + self.client_id = client_id + + class ClientTransport(Generic[HandshakeMetadataType]): - _sessions: dict[str, ClientSession] + _session: ClientSession | None def __init__( self, @@ -58,10 +65,9 @@ def __init__( server_id: str, transport_options: TransportOptions, ): - self._sessions = {} + self._session = None self._transport_id = client_id self._transport_options = transport_options - self._session_lock = asyncio.Lock() self._uri_and_metadata_factory = uri_and_metadata_factory self._client_id = client_id @@ -72,19 +78,11 @@ 
def __init__( # We want to make sure there's only one session creation at a time self._create_session_lock = asyncio.Lock() - async def _close_all_sessions(self) -> None: - sessions = self._sessions.values() - logger.info( - f"start closing sessions {self._transport_id}, number sessions : " - f"{len(sessions)}" - ) - sessions_to_close = list(sessions) - - # closing sessions requires access to the session lock, so we need to close - # them one by one to be safe - for session in sessions_to_close: - await session.close() - + async def _close_session(self) -> None: + logger.info(f"start closing session {self._transport_id}") + if not self._session: + return + await self._session.close() logger.info(f"Transport closed {self._transport_id}") def generate_nanoid(self) -> str: @@ -92,18 +90,23 @@ def generate_nanoid(self) -> str: async def close(self) -> None: self._rate_limiter.close() - await self._close_all_sessions() + await self._close_session() async def get_or_create_session(self) -> ClientSession: + """ + If we have an active session, return it. + If we have a "closed" session, mint a whole new session. + If we have a disconnected session, attempt to start a new WS and use it. 
+ """ async with self._create_session_lock: - existing_session = await self._get_existing_session() - if not existing_session: - return await self._create_new_session() - is_session_open = await existing_session.is_session_open() - if not is_session_open: + existing_session = ( + self._session + if self._session and self._session.is_session_open() + else None + ) + if existing_session is None: return await self._create_new_session() - is_ws_open = await existing_session.is_websocket_open() - if is_ws_open: + if existing_session.is_websocket_open(): return existing_session new_ws, _, hs_response = await self._establish_new_connection( existing_session @@ -117,26 +120,11 @@ async def get_or_create_session(self) -> ClientSession: return existing_session else: logger.info("Closing stale session %s", existing_session.session_id) + await new_ws.close() # NB(dstewart): This wasn't there in the + # v1 transport, were we just leaking WS? await existing_session.close() return await self._create_new_session() - async def _get_existing_session(self) -> ClientSession | None: - async with self._session_lock: - if not self._sessions: - return None - if len(self._sessions) > 1: - raise RiverException( - "session_error", - "More than one session found in client, should only be one", - ) - session = list(self._sessions.values())[0] - if isinstance(session, ClientSession): - return session - else: - raise RiverException( - "session_error", f"Client session type wrong, got {type(session)}" - ) - async def _establish_new_connection( self, old_session: ClientSession | None = None, @@ -157,24 +145,26 @@ async def _establish_new_connection( logger.info(f"Retrying build handshake number {i} times") if not rate_limit.has_budget(client_id): logger.debug("No retry budget for %s.", client_id) - raise RiverException( - ERROR_HANDSHAKE, f"No retry budget for {client_id}" + raise HandshakeBudgetExhaustedException( + ERROR_HANDSHAKE, + "No retry budget", + client_id=client_id, ) from last_error 
rate_limit.consume_budget(client_id) # if the session is closed, we shouldn't use it - if old_session and not await old_session.is_session_open(): + if old_session and not old_session.is_session_open(): old_session = None try: uri_and_metadata = await self._uri_and_metadata_factory() ws = await websockets.connect(uri_and_metadata["uri"]) - session_id = ( - self.generate_nanoid() - if not old_session - else old_session.session_id - ) + session_id: str + if old_session: + session_id = old_session.session_id + else: + session_id = self.generate_nanoid() try: ( @@ -225,13 +215,13 @@ async def _create_new_session( retry_connection_callback=self._retry_connection, ) - self._sessions[new_session._to_id] = new_session + self._session = new_session await new_session.start_serve_responses() return new_session async def _retry_connection(self) -> ClientSession: if not self._transport_options.transparent_reconnect: - await self._close_all_sessions() + await self._close_session() return await self.get_or_create_session() async def _send_handshake_request( @@ -254,16 +244,17 @@ async def websocket_closed_callback() -> None: logger.error("websocket closed before handshake response") try: + payload = handshake_request.model_dump() await send_transport_message( TransportMessage( - from_=self.transport_id, # type: ignore + from_=self._transport_id, to=self._server_id, streamId=stream_id, controlFlags=0, id=self.generate_nanoid(), seq=0, ack=0, - payload=handshake_request.model_dump(), + payload=payload, ), ws=websocket, websocket_closed_callback=websocket_closed_callback, @@ -320,8 +311,8 @@ async def _establish_handshake( ) case ClientSession(): expectedSessionState = ExpectedSessionState( - nextExpectedSeq=await old_session.get_next_expected_seq(), - nextSentSeq=await old_session.get_next_sent_seq(), + nextExpectedSeq=old_session.ack, + nextSentSeq=old_session.seq, ) case other: assert_never(other) @@ -368,6 +359,5 @@ async def _establish_handshake( return handshake_request, 
handshake_response async def _delete_session(self, session: Session) -> None: - async with self._session_lock: - if session._to_id in self._sessions: - del self._sessions[session._to_id] + if self._session and session._to_id == self._session._to_id: + self._session = None diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py new file mode 100644 index 00000000..6648ba5d --- /dev/null +++ b/src/replit_river/v2/session.py @@ -0,0 +1,256 @@ +import asyncio +import logging +from collections import deque +from typing import Any, Awaitable, Callable, Coroutine, TypeAlias + +import nanoid # type: ignore +import websockets +from aiochannel import Channel +from opentelemetry.trace import Span, use_span +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from websockets.frames import CloseCode + +from replit_river.common_session import ( + SessionState, + check_to_close_session, + setup_heartbeat, +) +from replit_river.rpc import ( + TransportMessage, + TransportMessageTracingSetter, +) +from replit_river.task_manager import BackgroundTaskManager +from replit_river.transport_options import TransportOptions + +logger = logging.getLogger(__name__) + +trace_propagator = TraceContextTextMapPropagator() +trace_setter = TransportMessageTracingSetter() + +CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]] +RetryConnectionCallback: TypeAlias = Callable[ + [], + Coroutine[Any, Any, Any], +] + + +class Session: + _transport_id: str + _to_id: str + session_id: str + _transport_options: TransportOptions + + # session state, only modified during closing + _state: SessionState + _close_session_callback: CloseSessionCallback + _close_session_after_time_secs: float | None + + # ws state + _ws_connected: bool + _ws_unwrapped: websockets.WebSocketCommonProtocol | None + _heartbeat_misses: int + _retry_connection_callback: RetryConnectionCallback | None + + # stream for tasks + _streams: dict[str, 
Channel[Any]] + + # book keeping + _ack_buffer: deque[TransportMessage] + _send_buffer: deque[TransportMessage] + _task_manager: BackgroundTaskManager + ack: int # Most recently acknowledged seq + seq: int # Last sent sequence number + + def __init__( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: websockets.WebSocketCommonProtocol, + transport_options: TransportOptions, + close_session_callback: CloseSessionCallback, + retry_connection_callback: RetryConnectionCallback | None = None, + ) -> None: + self._transport_id = transport_id + self._to_id = to_id + self.session_id = session_id + self._transport_options = transport_options + + # session state, only modified during closing + self._state = SessionState.ACTIVE + self._close_session_callback = close_session_callback + self._close_session_after_time_secs: float | None = None + + # ws state + self._ws_connected = True + self._ws_unwrapped = websocket + self._heartbeat_misses = 0 + self._retry_connection_callback = retry_connection_callback + + # message state + self._space_available_cond = asyncio.Condition() + self._queue_full_lock = asyncio.Lock() + + # stream for tasks + self._streams: dict[str, Channel[Any]] = {} + + # book keeping + self._ack_buffer = deque() + self._send_buffer = deque() + self._task_manager = BackgroundTaskManager() + self.ack = 0 + self.seq = 0 + + def _setup_heartbeats_task( + self, + do_close_websocket: Callable[[], Awaitable[None]], + ) -> None: + def increment_and_get_heartbeat_misses() -> int: + self._heartbeat_misses += 1 + return self._heartbeat_misses + + self._task_manager.create_task( + setup_heartbeat( + self.session_id, + self._transport_options.heartbeat_ms, + self._transport_options.heartbeats_until_dead, + lambda: self._state, + lambda: self._close_session_after_time_secs, + close_websocket=do_close_websocket, + send_message=self.send_message, + increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, + ) + ) + 
self._task_manager.create_task( + check_to_close_session( + self._transport_id, + self._transport_options.close_session_check_interval_ms, + lambda: self._state, + self._get_current_time, + lambda: self._close_session_after_time_secs, + self.close, + ) + ) + + def is_session_open(self) -> bool: + return self._state == SessionState.ACTIVE + + def is_websocket_open(self) -> bool: + return self._ws_connected + + async def _begin_close_session_countdown(self) -> None: + """Begin the countdown to close session, this should be called when + websocket is closed. + """ + # calculate the value now before establishing it so that there are no + # await points between the check and the assignment to avoid a TOCTOU + # race. + grace_period_ms = self._transport_options.session_disconnect_grace_ms + close_session_after_time_secs = ( + await self._get_current_time() + grace_period_ms / 1000 + ) + if self._close_session_after_time_secs is not None: + # already in grace period, no need to set again + return + logger.info( + "websocket closed from %s to %s begin grace period", + self._transport_id, + self._to_id, + ) + self._close_session_after_time_secs = close_session_after_time_secs + + async def replace_with_new_websocket( + self, new_ws: websockets.WebSocketCommonProtocol + ) -> None: + if self._ws_unwrapped and new_ws.id != self._ws_unwrapped.id: + self._task_manager.create_task( + self._ws_unwrapped.close( + CloseCode.PROTOCOL_ERROR, "Transparent reconnect" + ) + ) + self._ws_unwrapped = new_ws + + async def _get_current_time(self) -> float: + return asyncio.get_event_loop().time() + + def _reset_session_close_countdown(self) -> None: + self._heartbeat_misses = 0 + self._close_session_after_time_secs = None + + async def send_message( + self, + stream_id: str, + payload: dict[Any, Any] | str, + control_flags: int = 0, + service_name: str | None = None, + procedure_name: str | None = None, + span: Span | None = None, + ) -> None: + """Send serialized messages to the 
websockets.""" + # if the session is not active, we should not do anything + if self._state != SessionState.ACTIVE: + return + msg = TransportMessage( + streamId=stream_id, + id=nanoid.generate(), + from_=self._transport_id, + to=self._to_id, + seq=self.seq, + ack=self.ack, + controlFlags=control_flags, + payload=payload, + serviceName=service_name, + procedureName=procedure_name, + ) + + if span: + with use_span(span): + trace_propagator.inject(msg, None, trace_setter) + + # As we prepare to push onto the buffer, if the buffer is full, we lock. + # This lock will be released by the buffered_message_sender task, so it's + # important that we don't release it here. + # + # The reason for this is that in Python, asyncio.Lock is "fair", first + # come, first served. + # + # If somebody else is already waiting or we've filled the buffer, we + # should get in line. + if ( + self._queue_full_lock.locked() + or len(self._send_buffer) >= self._transport_options.buffer_size + ): + logger.warning("LOCK ACQUIRED %r", repr(payload)) + await self._queue_full_lock.acquire() + logger.warning("LOCK RELEASED %r", repr(payload)) + self._send_buffer.append(msg) + self.seq += 1 + + async def close(self) -> None: + """Close the session and all associated streams.""" + logger.info( + f"{self._transport_id} closing session " + f"to {self._to_id}, ws: {self._ws_unwrapped}" + ) + if self._state != SessionState.ACTIVE: + # already closing + return + self._state = SessionState.CLOSING + self._reset_session_close_countdown() + await self._task_manager.cancel_all_tasks() + + if self._ws_unwrapped: + # The Session isn't guaranteed to live much longer than this close() + # invocation, so let's await this close to avoid dropping the socket. + await self._ws_unwrapped.close() + + # Clear the session in transports + await self._close_session_callback(self) + + # TODO: unexpected_close should close stream differently here to + # throw exception correctly. 
+ for stream in self._streams.values(): + stream.close() + self._streams.clear() + + self._state = SessionState.CLOSED From 3e7bc4ea403233014bb5dc0645de432b5a708b3c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 23 Mar 2025 21:52:05 -0700 Subject: [PATCH 028/193] :thinking: Avoid sending more messages when we know we are closed --- src/replit_river/v2/client_session.py | 3 ++- src/replit_river/v2/session.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 53249c06..7546a9bf 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -74,6 +74,7 @@ def __init__( async def do_close_websocket() -> None: if self._ws_unwrapped: + self._ws_connected = False self._task_manager.create_task(self._ws_unwrapped.close()) if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -98,7 +99,7 @@ def get_next_pending() -> TransportMessage | None: self._task_manager.create_task( buffered_message_sender( - get_ws=lambda: self._ws_unwrapped, + get_ws=lambda: self._ws_unwrapped if self.is_websocket_open() else None, websocket_closed_callback=self._begin_close_session_countdown, get_next_pending=get_next_pending, commit=commit, diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 6648ba5d..b6e0de35 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -169,6 +169,7 @@ async def replace_with_new_websocket( ) ) self._ws_unwrapped = new_ws + self._ws_connected = True async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() From 7a577c87a7ce10f92d757ebd56a4996ada8d5458 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 23 Mar 2025 21:59:58 -0700 Subject: [PATCH 029/193] Skipping heartbeats when we know we're closed --- src/replit_river/v2/session.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index b6e0de35..7d86aca4 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -115,6 +115,7 @@ def increment_and_get_heartbeat_misses() -> int: self._transport_options.heartbeat_ms, self._transport_options.heartbeats_until_dead, lambda: self._state, + lambda: self._ws_connected, lambda: self._close_session_after_time_secs, close_websocket=do_close_websocket, send_message=self.send_message, From bca134fc7941079ec947d80421b2b234cd349780 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 23 Mar 2025 22:07:31 -0700 Subject: [PATCH 030/193] One more I think --- src/replit_river/v2/session.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 7d86aca4..12b6d370 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -159,6 +159,7 @@ async def _begin_close_session_countdown(self) -> None: self._to_id, ) self._close_session_after_time_secs = close_session_after_time_secs + self._ws_connected = False async def replace_with_new_websocket( self, new_ws: websockets.WebSocketCommonProtocol From 61b8b7f987dd8dc983c7db3fe174531eff50622a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 23 Mar 2025 22:11:51 -0700 Subject: [PATCH 031/193] Unsure, was ws unset by this point? 
--- src/replit_river/v2/client_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 7546a9bf..6d057d34 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -73,8 +73,9 @@ def __init__( ) async def do_close_websocket() -> None: + logger.debug("do_close called, _ws_connected=%r, _ws_unwrapped=%r", self._ws_connected, self._ws_unwrapped) + self._ws_connected = False if self._ws_unwrapped: - self._ws_connected = False self._task_manager.create_task(self._ws_unwrapped.close()) if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) From 0fba294ccf4e01b180c52c8c334bf60c45143846 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 12:03:49 -0700 Subject: [PATCH 032/193] Missing parameter in v1 client --- src/replit_river/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index ac01ffba..517b53f6 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -24,7 +24,7 @@ ) from replit_river.task_manager import BackgroundTaskManager from replit_river.transport_options import TransportOptions -from replit_river.websocket_wrapper import WebsocketWrapper +from replit_river.websocket_wrapper import WebsocketWrapper, WsState from .rpc import ( ACK_BIT, @@ -121,6 +121,7 @@ def increment_and_get_heartbeat_misses() -> int: self._transport_options.heartbeat_ms, self._transport_options.heartbeats_until_dead, lambda: self._state, + lambda: self._ws_wrapper.ws_state == WsState.OPEN, lambda: self._close_session_after_time_secs, close_websocket=do_close_websocket, send_message=self.send_message, From d1702ace9ed52274c12aa975e208915b06101250 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 12:04:07 -0700 Subject: [PATCH 033/193] Upgrade to the new websocket 
impl --- src/replit_river/common_session.py | 1 + src/replit_river/messages.py | 3 ++- src/replit_river/v2/client.py | 21 --------------------- src/replit_river/v2/client_session.py | 12 +++++++++--- src/replit_river/v2/client_transport.py | 16 ++++++++-------- src/replit_river/v2/session.py | 7 ++++--- 6 files changed, 24 insertions(+), 36 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index f697c1ae..f1f9ccd5 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -4,6 +4,7 @@ from opentelemetry.trace import Span from websockets import WebSocketCommonProtocol +from websockets.asyncio.client import ClientConnection logger = logging.getLogger(__name__) diff --git a/src/replit_river/messages.py b/src/replit_river/messages.py index 9cdf324a..653929a6 100644 --- a/src/replit_river/messages.py +++ b/src/replit_river/messages.py @@ -8,6 +8,7 @@ from websockets import ( WebSocketCommonProtocol, ) +from websockets.asyncio.client import ClientConnection from replit_river.rpc import ( TransportMessage, @@ -29,7 +30,7 @@ class FailedSendingMessageException(Exception): async def send_transport_message( msg: TransportMessage, - ws: WebSocketCommonProtocol, + ws: WebSocketCommonProtocol | ClientConnection, # legacy | asyncio websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], ) -> None: logger.debug("sending a message %r to ws %s", msg, ws) diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index a7c09f74..958ff67d 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -58,27 +58,6 @@ def translate_unknown_error( return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") -# Client[HandshakeSchema]( -# uri_and_metadata_factory=uri_and_metadata_factory, -# client_id=self.client_id, -# server_id="SERVER", -# transport_options=TransportOptions( -# 
session_disconnect_grace_ms=settings.RIVER_SESSION_DISCONNECT_GRACE_MS, -# heartbeat_ms=settings.RIVER_HEARTBEAT_MS, -# heartbeats_until_dead=settings.RIVER_HEARTBEATS_UNTIL_DEAD, -# connection_retry_options=ConnectionRetryOptions( -# base_interval_ms=settings.RIVER_CONNECTION_BASE_INTERVAL_MS, -# max_jitter_ms=settings.RIVER_CONNECTION_MAX_JITTER_MS, -# max_backoff_ms=settings.RIVER_CONNECTION_MAX_BACKOFF_MS, -# attempt_budget_capacity=self.attempt_budget_capacity, -# budget_restore_interval_ms= -# settings.RIVER_CONNECTION_BUDGET_RESTORE_INTERVAL_MS, -# max_retry=self.max_retry_count, -# ), -# ), -# ) - - class Client(Generic[HandshakeMetadataType]): def __init__( self, diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 6d057d34..989e9dc9 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -2,15 +2,17 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Literal +from typing import Any, AsyncGenerator, Callable, Literal, cast import nanoid import websockets from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span +from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed from websockets.frames import CloseCode +from websockets.legacy.protocol import WebSocketCommonProtocol from replit_river.error_schema import ( ERROR_CODE_CANCEL, @@ -57,7 +59,7 @@ def __init__( transport_id: str, to_id: str, session_id: str, - websocket: websockets.WebSocketCommonProtocol, + websocket: ClientConnection, transport_options: TransportOptions, close_session_callback: CloseSessionCallback, retry_connection_callback: RetryConnectionCallback | None = None, @@ -100,7 +102,11 @@ def get_next_pending() -> TransportMessage | None: self._task_manager.create_task( buffered_message_sender( - get_ws=lambda: 
self._ws_unwrapped if self.is_websocket_open() else None, + get_ws=lambda: ( + cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped) + if self.is_websocket_open() + else None + ), websocket_closed_callback=self._begin_close_session_countdown, get_next_pending=get_next_pending, commit=commit, diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 50002cc1..f2056d62 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -6,9 +6,9 @@ import nanoid import websockets from pydantic import ValidationError -from websockets import ( - WebSocketCommonProtocol, -) +import websockets.asyncio.client +from websockets import WebSocketCommonProtocol +from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed from replit_river.error_schema import ( @@ -129,7 +129,7 @@ async def _establish_new_connection( self, old_session: ClientSession | None = None, ) -> tuple[ - WebSocketCommonProtocol, + ClientConnection, ControlMessageHandshakeRequest[HandshakeMetadataType], ControlMessageHandshakeResponse, ]: @@ -159,7 +159,7 @@ async def _establish_new_connection( try: uri_and_metadata = await self._uri_and_metadata_factory() - ws = await websockets.connect(uri_and_metadata["uri"]) + ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) session_id: str if old_session: session_id = old_session.session_id @@ -228,7 +228,7 @@ async def _send_handshake_request( self, session_id: str, handshake_metadata: HandshakeMetadataType | None, - websocket: WebSocketCommonProtocol, + websocket: ClientConnection, expected_session_state: ExpectedSessionState, ) -> ControlMessageHandshakeRequest[HandshakeMetadataType]: handshake_request = ControlMessageHandshakeRequest[HandshakeMetadataType]( @@ -266,7 +266,7 @@ async def websocket_closed_callback() -> None: ) from e async def _get_handshake_response_msg( - self, websocket: 
WebSocketCommonProtocol + self, websocket: ClientConnection ) -> TransportMessage: while True: try: @@ -295,7 +295,7 @@ async def _establish_handshake( self, session_id: str, handshake_metadata: HandshakeMetadataType, - websocket: WebSocketCommonProtocol, + websocket: ClientConnection, old_session: ClientSession | None, ) -> tuple[ ControlMessageHandshakeRequest[HandshakeMetadataType], diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 12b6d370..74da65a6 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -8,6 +8,7 @@ from aiochannel import Channel from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from websockets.asyncio.client import ClientConnection from websockets.frames import CloseCode from replit_river.common_session import ( @@ -47,7 +48,7 @@ class Session: # ws state _ws_connected: bool - _ws_unwrapped: websockets.WebSocketCommonProtocol | None + _ws_unwrapped: ClientConnection | None _heartbeat_misses: int _retry_connection_callback: RetryConnectionCallback | None @@ -66,7 +67,7 @@ def __init__( transport_id: str, to_id: str, session_id: str, - websocket: websockets.WebSocketCommonProtocol, + websocket: ClientConnection, transport_options: TransportOptions, close_session_callback: CloseSessionCallback, retry_connection_callback: RetryConnectionCallback | None = None, @@ -162,7 +163,7 @@ async def _begin_close_session_countdown(self) -> None: self._ws_connected = False async def replace_with_new_websocket( - self, new_ws: websockets.WebSocketCommonProtocol + self, new_ws: ClientConnection ) -> None: if self._ws_unwrapped and new_ws.id != self._ws_unwrapped.id: self._task_manager.create_task( From 5cb47e2c5d719b1a1ec2533d677552010a4b7303 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 12:08:50 -0700 Subject: [PATCH 034/193] Removing transport_options parameter --- 
src/replit_river/v2/client_session.py | 2 +- src/replit_river/v2/client_transport.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 989e9dc9..f203209f 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -157,7 +157,7 @@ async def _handle_messages_from_ws(self) -> None: if not self._ws_unwrapped: # We should not process messages if the websocket is closed. break - msg = parse_transport_msg(message, self._transport_options) + msg = parse_transport_msg(message) logger.debug(f"{self._transport_id} got a message %r", msg) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index f2056d62..68f105c2 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -281,10 +281,10 @@ async def _get_handshake_response_msg( "Handshake failed, conn closed while waiting for response", ) from e try: - return parse_transport_msg(data, self._transport_options) - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue + msg = parse_transport_msg(data) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue except InvalidMessageException as e: raise RiverException( ERROR_HANDSHAKE, From f9aba7a5ff56edcd16c9c8a772b7504c6468f76a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 12:10:25 -0700 Subject: [PATCH 035/193] Formatting --- src/replit_river/v2/client_session.py | 7 +++++-- src/replit_river/v2/client_transport.py | 3 +-- src/replit_river/v2/session.py | 5 +---- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index f203209f..b2d98348 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -5,7 +5,6 @@ from 
typing import Any, AsyncGenerator, Callable, Literal, cast import nanoid -import websockets from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span @@ -75,7 +74,11 @@ def __init__( ) async def do_close_websocket() -> None: - logger.debug("do_close called, _ws_connected=%r, _ws_unwrapped=%r", self._ws_connected, self._ws_unwrapped) + logger.debug( + "do_close called, _ws_connected=%r, _ws_unwrapped=%r", + self._ws_connected, + self._ws_unwrapped, + ) self._ws_connected = False if self._ws_unwrapped: self._task_manager.create_task(self._ws_unwrapped.close()) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 68f105c2..18240b09 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -5,9 +5,8 @@ import nanoid import websockets -from pydantic import ValidationError import websockets.asyncio.client -from websockets import WebSocketCommonProtocol +from pydantic import ValidationError from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 74da65a6..dc8f91b6 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -4,7 +4,6 @@ from typing import Any, Awaitable, Callable, Coroutine, TypeAlias import nanoid # type: ignore -import websockets from aiochannel import Channel from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -162,9 +161,7 @@ async def _begin_close_session_countdown(self) -> None: self._close_session_after_time_secs = close_session_after_time_secs self._ws_connected = False - async def replace_with_new_websocket( - self, new_ws: ClientConnection - ) -> None: + async def replace_with_new_websocket(self, new_ws: ClientConnection) -> None: if self._ws_unwrapped and new_ws.id 
!= self._ws_unwrapped.id: self._task_manager.create_task( self._ws_unwrapped.close( From 736c3cdb42a3847325840c894e7309c9e4c7b0ae Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 12:17:28 -0700 Subject: [PATCH 036/193] Switch to raw bytes recv() call to avoid round-tripping through str --- src/replit_river/v2/client_session.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index b2d98348..8655e2ee 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -9,7 +9,7 @@ from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span from websockets.asyncio.client import ClientConnection -from websockets.exceptions import ConnectionClosed +from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.frames import CloseCode from websockets.legacy.protocol import WebSocketCommonProtocol @@ -155,7 +155,11 @@ async def _handle_messages_from_ws(self) -> None: ) try: ws = self._ws_unwrapped - async for message in ws: + while True: + # decode=False: Avoiding an unnecessary round-trip through str + # Ideally this should be type-ascripted to : bytes, but there is no + # @overrides in `websockets` to hint this. + message = await ws.recv(decode=False) try: if not self._ws_unwrapped: # We should not process messages if the websocket is closed. 
@@ -240,6 +244,8 @@ async def _handle_messages_from_ws(self) -> None: logger.exception("Got invalid transport message, closing session") await self.close() return + except ConnectionClosedOK: + pass # Exited normally except ConnectionClosed as e: raise e From b2f6582ac13dbc280a533473f726f30f17c047fc Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 12:18:26 -0700 Subject: [PATCH 037/193] Check websocket liveness before recv() in the message loop --- src/replit_river/v2/client_session.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py index 8655e2ee..dcda8ca8 100644 --- a/src/replit_river/v2/client_session.py +++ b/src/replit_river/v2/client_session.py @@ -156,14 +156,15 @@ async def _handle_messages_from_ws(self) -> None: try: ws = self._ws_unwrapped while True: + if not self._ws_unwrapped: + # We should not process messages if the websocket is closed. + break + # decode=False: Avoiding an unnecessary round-trip through str # Ideally this should be type-ascripted to : bytes, but there is no # @overrides in `websockets` to hint this. message = await ws.recv(decode=False) try: - if not self._ws_unwrapped: - # We should not process messages if the websocket is closed.
- break msg = parse_transport_msg(message) logger.debug(f"{self._transport_id} got a message %r", msg) From 1068f1295f692ef38e2404f8a8950ec764852ecb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 12:54:11 -0700 Subject: [PATCH 038/193] Bite the bullet and merge client_session and session together --- src/replit_river/v2/__init__.py | 3 +- src/replit_river/v2/client_session.py | 557 ------------------------ src/replit_river/v2/client_transport.py | 20 +- src/replit_river/v2/schema.py | 2 +- src/replit_river/v2/session.py | 528 +++++++++++++++++++++- 5 files changed, 537 insertions(+), 573 deletions(-) delete mode 100644 src/replit_river/v2/client_session.py diff --git a/src/replit_river/v2/__init__.py b/src/replit_river/v2/__init__.py index 142ab2a0..a9b0c7ee 100644 --- a/src/replit_river/v2/__init__.py +++ b/src/replit_river/v2/__init__.py @@ -1,6 +1,5 @@ -from replit_river.v2.session import Session - from .client import Client +from .session import Session __all__ = [ "Client", diff --git a/src/replit_river/v2/client_session.py b/src/replit_river/v2/client_session.py deleted file mode 100644 index dcda8ca8..00000000 --- a/src/replit_river/v2/client_session.py +++ /dev/null @@ -1,557 +0,0 @@ -import asyncio -import logging -from collections.abc import AsyncIterable -from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Literal, cast - -import nanoid -from aiochannel import Channel -from aiochannel.errors import ChannelClosed -from opentelemetry.trace import Span -from websockets.asyncio.client import ClientConnection -from websockets.exceptions import ConnectionClosed, ConnectionClosedOK -from websockets.frames import CloseCode -from websockets.legacy.protocol import WebSocketCommonProtocol - -from replit_river.error_schema import ( - ERROR_CODE_CANCEL, - ERROR_CODE_STREAM_CLOSED, - RiverError, - RiverException, - RiverServiceException, - StreamClosedRiverServiceException, - exception_from_message, -) -from 
replit_river.messages import ( - FailedSendingMessageException, - parse_transport_msg, -) -from replit_river.rpc import ( - ACK_BIT, - STREAM_OPEN_BIT, - TransportMessage, -) -from replit_river.seq_manager import ( - IgnoreMessageException, - InvalidMessageException, - OutOfOrderMessageException, -) -from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions -from replit_river.v2.session import ( - CloseSessionCallback, - RetryConnectionCallback, - Session, -) - -STREAM_CANCEL_BIT_TYPE = Literal[0b00100] -STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 -STREAM_CLOSED_BIT_TYPE = Literal[0b01000] -STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 - - -logger = logging.getLogger(__name__) - - -class ClientSession(Session): - def __init__( - self, - transport_id: str, - to_id: str, - session_id: str, - websocket: ClientConnection, - transport_options: TransportOptions, - close_session_callback: CloseSessionCallback, - retry_connection_callback: RetryConnectionCallback | None = None, - ) -> None: - super().__init__( - transport_id=transport_id, - to_id=to_id, - session_id=session_id, - websocket=websocket, - transport_options=transport_options, - close_session_callback=close_session_callback, - retry_connection_callback=retry_connection_callback, - ) - - async def do_close_websocket() -> None: - logger.debug( - "do_close called, _ws_connected=%r, _ws_unwrapped=%r", - self._ws_connected, - self._ws_unwrapped, - ) - self._ws_connected = False - if self._ws_unwrapped: - self._task_manager.create_task(self._ws_unwrapped.close()) - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) - await self._begin_close_session_countdown() - - self._setup_heartbeats_task(do_close_websocket) - - def commit(msg: TransportMessage) -> None: - pending = self._send_buffer.popleft() - if msg.seq != pending.seq: - logger.error("Out of sequence error") - self._ack_buffer.append(pending) - - # On commit, release 
pending writers waiting for more buffer space - if self._queue_full_lock.locked(): - self._queue_full_lock.release() - - def get_next_pending() -> TransportMessage | None: - if self._send_buffer: - return self._send_buffer[0] - return None - - self._task_manager.create_task( - buffered_message_sender( - get_ws=lambda: ( - cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped) - if self.is_websocket_open() - else None - ), - websocket_closed_callback=self._begin_close_session_countdown, - get_next_pending=get_next_pending, - commit=commit, - ) - ) - - async def start_serve_responses(self) -> None: - self._task_manager.create_task(self._serve()) - - async def _serve(self) -> None: - """Serve messages from the websocket.""" - self._reset_session_close_countdown() - try: - try: - await self._handle_messages_from_ws() - except ConnectionClosed: - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) - - await self._begin_close_session_countdown() - logger.debug("ConnectionClosed while serving", exc_info=True) - except FailedSendingMessageException: - # Expected error if the connection is closed. - logger.debug( - "FailedSendingMessageException while serving", exc_info=True - ) - except Exception: - logger.exception("caught exception at message iterator") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - raise ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - - async def _handle_messages_from_ws(self) -> None: - while self._ws_unwrapped is None: - await asyncio.sleep(1) - logger.debug( - "%s start handling messages from ws %s", - "client", - self._ws_unwrapped.id, - ) - try: - ws = self._ws_unwrapped - while True: - if not self._ws_unwrapped: - # We should not process messages if the websocket is closed. 
- break - - # decode=False: Avoiding an unnecessary round-trip through str - # Ideally this should be type-ascripted to : bytes, but there is no - # @overrides in `websockets` to hint this. - message = await ws.recv(decode=False) - try: - msg = parse_transport_msg(message) - - logger.debug(f"{self._transport_id} got a message %r", msg) - - # Update bookkeeping - if msg.seq < self.ack: - raise IgnoreMessageException( - f"{msg.from_} received duplicate msg, got {msg.seq}" - f" expected {self.ack}" - ) - elif msg.seq > self.ack: - logger.warning( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - - raise OutOfOrderMessageException( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - - assert msg.seq == self.ack, "Safety net, redundant assertion" - - # Set our next expected ack number - self.ack = msg.seq + 1 - - # Discard old messages from the buffer - while self._ack_buffer and self._ack_buffer[0].seq < msg.ack: - self._ack_buffer.popleft() - - self._reset_session_close_countdown() - - if msg.controlFlags & ACK_BIT != 0: - continue - stream = self._streams.get(msg.streamId) - if msg.controlFlags & STREAM_OPEN_BIT != 0: - raise InvalidMessageException( - "Client should not receive stream open bit" - ) - - if not stream: - logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException("no stream for message, ignoring") - - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - pass - else: - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. 
- pass - except RuntimeError as e: - raise InvalidMessageException(e) from e - - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue - except OutOfOrderMessageException: - logger.exception("Out of order message, closing connection") - self._task_manager.create_task( - self._ws_unwrapped.close( - code=CloseCode.INVALID_DATA, - reason="Out of order message", - ) - ) - return - except InvalidMessageException: - logger.exception("Got invalid transport message, closing session") - await self.close() - return - except ConnectionClosedOK: - pass # Exited normally - except ConnectionClosed as e: - raise e - - async def send_rpc[R, A]( - self, - service_name: str, - procedure_name: str, - request: R, - request_serializer: Callable[[R], Any], - response_deserializer: Callable[[Any], A], - error_deserializer: Callable[[Any], RiverError], - span: Span, - timeout: timedelta, - ) -> A: - """Sends a single RPC request to the server. - - Expects the input and output be messages that will be msgpacked. 
- """ - stream_id = nanoid.generate() - output: Channel[Any] = Channel(1) - self._streams[stream_id] = output - await self.send_message( - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, - payload=request_serializer(request), - service_name=service_name, - procedure_name=procedure_name, - span=span, - ) - # Handle potential errors during communication - try: - try: - async with asyncio.timeout(timeout.total_seconds()): - response = await output.get() - except asyncio.TimeoutError as e: - await self.send_message( - stream_id=stream_id, - control_flags=STREAM_CANCEL_BIT, - payload={"type": "CANCEL"}, - service_name=service_name, - procedure_name=procedure_name, - span=span, - ) - raise RiverException(ERROR_CODE_CANCEL, str(e)) from e - except ChannelClosed as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except RuntimeError as e: - raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): - try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) from e - raise exception_from_message(error.code)( - error.code, error.message, service_name, procedure_name - ) - return response_deserializer(response["payload"]) - except RiverException as e: - raise e - except Exception as e: - raise e - - async def send_upload[I, R, A]( - self, - service_name: str, - procedure_name: str, - init: I, - request: AsyncIterable[R] | None, - init_serializer: Callable[[I], Any], - request_serializer: Callable[[R], Any] | None, - response_deserializer: Callable[[Any], A], - error_deserializer: Callable[[Any], RiverError], - span: Span, - ) -> A: - """Sends an upload request to the server. - - Expects the input and output be messages that will be msgpacked. 
- """ - - stream_id = nanoid.generate() - output: Channel[Any] = Channel(1) - self._streams[stream_id] = output - try: - await self.send_message( - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - service_name=service_name, - procedure_name=procedure_name, - payload=init_serializer(init), - span=span, - ) - - if request: - assert request_serializer, "send_stream missing request_serializer" - - # If this request is not closed and the session is killed, we should - # throw exception here - async for item in request: - control_flags = 0 - await self.send_message( - stream_id=stream_id, - service_name=service_name, - procedure_name=procedure_name, - control_flags=control_flags, - payload=request_serializer(item), - span=span, - ) - except Exception as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name - ) from e - await self.send_close_stream( - service_name, - procedure_name, - stream_id, - extra_control_flags=0, - ) - - # Handle potential errors during communication - # TODO: throw a error when the transport is hard closed - try: - try: - response = await output.get() - except ChannelClosed as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except RuntimeError as e: - raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): - try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) from e - raise exception_from_message(error.code)( - error.code, error.message, service_name, procedure_name - ) - - return response_deserializer(response["payload"]) - except RiverException as e: - raise e - except Exception as e: - raise e - - async def send_subscription[R, E, A]( - self, - service_name: str, - procedure_name: str, - request: R, - request_serializer: Callable[[R], Any], - response_deserializer: 
Callable[[Any], A], - error_deserializer: Callable[[Any], E], - span: Span, - ) -> AsyncGenerator[A | E, None]: - """Sends a subscription request to the server. - - Expects the input and output be messages that will be msgpacked. - """ - stream_id = nanoid.generate() - output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) - self._streams[stream_id] = output - await self.send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=request_serializer(request), - span=span, - ) - - # Handle potential errors during communication - try: - async for item in output: - if item.get("type", None) == "CLOSE": - break - if not item.get("ok", False): - try: - yield error_deserializer(item["payload"]) - except Exception: - logger.exception( - f"Error during subscription error deserialization: {item}" - ) - continue - yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except Exception as e: - raise e - finally: - output.close() - - async def send_stream[I, R, E, A]( - self, - service_name: str, - procedure_name: str, - init: I, - request: AsyncIterable[R] | None, - init_serializer: Callable[[I], Any], - request_serializer: Callable[[R], Any] | None, - response_deserializer: Callable[[Any], A], - error_deserializer: Callable[[Any], E], - span: Span, - ) -> AsyncGenerator[A | E, None]: - """Sends a subscription request to the server. - - Expects the input and output be messages that will be msgpacked. 
- """ - - stream_id = nanoid.generate() - output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) - self._streams[stream_id] = output - try: - await self.send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=init_serializer(init), - span=span, - ) - except Exception as e: - raise StreamClosedRiverServiceException( - ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name - ) from e - - # Create the encoder task - async def _encode_stream() -> None: - if not request: - await self.send_close_stream( - service_name, - procedure_name, - stream_id, - extra_control_flags=STREAM_OPEN_BIT, - ) - return - - assert request_serializer, "send_stream missing request_serializer" - - async for item in request: - if item is None: - continue - await self.send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=0, - payload=request_serializer(item), - ) - await self.send_close_stream(service_name, procedure_name, stream_id) - - self._task_manager.create_task(_encode_stream()) - - # Handle potential errors during communication - try: - async for item in output: - if "type" in item and item["type"] == "CLOSE": - break - if not item.get("ok", False): - try: - yield error_deserializer(item["payload"]) - except Exception: - logger.exception( - f"Error during subscription error deserialization: {item}" - ) - continue - yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except Exception as e: - raise e - finally: - output.close() - - async def send_close_stream( - self, - service_name: str, - procedure_name: str, - stream_id: str, - extra_control_flags: int = 0, - ) -> None: - # close stream - await self.send_message( - service_name=service_name, - 
procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_CLOSED_BIT | extra_control_flags, - payload={ - "type": "CLOSE", - }, - ) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 18240b09..3b0a7435 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -39,9 +39,7 @@ TransportOptions, UriAndMetadata, ) -from replit_river.v2.client_session import ClientSession - -from .session import Session +from replit_river.v2.session import Session PROTOCOL_VERSION = "v2.0" @@ -55,7 +53,7 @@ def __init__(self, code: str, message: str, client_id: str) -> None: class ClientTransport(Generic[HandshakeMetadataType]): - _session: ClientSession | None + _session: Session | None def __init__( self, @@ -91,7 +89,7 @@ async def close(self) -> None: self._rate_limiter.close() await self._close_session() - async def get_or_create_session(self) -> ClientSession: + async def get_or_create_session(self) -> Session: """ If we have an active session, return it. If we have a "closed" session, mint a whole new session. 
@@ -126,7 +124,7 @@ async def get_or_create_session(self) -> ClientSession: async def _establish_new_connection( self, - old_session: ClientSession | None = None, + old_session: Session | None = None, ) -> tuple[ ClientConnection, ControlMessageHandshakeRequest[HandshakeMetadataType], @@ -195,7 +193,7 @@ async def _establish_new_connection( async def _create_new_session( self, - ) -> ClientSession: + ) -> Session: logger.info("Creating new session") new_ws, hs_request, hs_response = await self._establish_new_connection() if not hs_response.status.ok: @@ -204,7 +202,7 @@ async def _create_new_session( ERROR_SESSION, f"Server did not return OK status on handshake response: {message}", ) - new_session = ClientSession( + new_session = Session( transport_id=self._transport_id, to_id=self._server_id, session_id=hs_request.sessionId, @@ -218,7 +216,7 @@ async def _create_new_session( await new_session.start_serve_responses() return new_session - async def _retry_connection(self) -> ClientSession: + async def _retry_connection(self) -> Session: if not self._transport_options.transparent_reconnect: await self._close_session() return await self.get_or_create_session() @@ -295,7 +293,7 @@ async def _establish_handshake( session_id: str, handshake_metadata: HandshakeMetadataType, websocket: ClientConnection, - old_session: ClientSession | None, + old_session: Session | None, ) -> tuple[ ControlMessageHandshakeRequest[HandshakeMetadataType], ControlMessageHandshakeResponse, @@ -308,7 +306,7 @@ async def _establish_handshake( nextExpectedSeq=0, nextSentSeq=0, ) - case ClientSession(): + case Session(): expectedSessionState = ExpectedSessionState( nextExpectedSeq=old_session.ack, nextSentSeq=old_session.seq, diff --git a/src/replit_river/v2/schema.py b/src/replit_river/v2/schema.py index 5fffd164..3c9792b8 100644 --- a/src/replit_river/v2/schema.py +++ b/src/replit_river/v2/schema.py @@ -1,7 +1,7 @@ from typing import Any, Literal, NotRequired, TypeAlias, TypedDict from 
replit_river.rpc import ACK_BIT_TYPE, ExpectedSessionState -from replit_river.v2.client_session import STREAM_CANCEL_BIT_TYPE +from replit_river.v2.session import STREAM_CANCEL_BIT_TYPE class ControlClose(TypedDict): diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index dc8f91b6..a0d93ca3 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1,26 +1,67 @@ import asyncio import logging from collections import deque -from typing import Any, Awaitable, Callable, Coroutine, TypeAlias +from collections.abc import AsyncIterable +from datetime import timedelta +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Coroutine, + Literal, + TypeAlias, + cast, +) import nanoid # type: ignore from aiochannel import Channel +from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.asyncio.client import ClientConnection +from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.frames import CloseCode +from websockets.legacy.protocol import WebSocketCommonProtocol from replit_river.common_session import ( SessionState, + buffered_message_sender, check_to_close_session, setup_heartbeat, ) +from replit_river.error_schema import ( + ERROR_CODE_CANCEL, + ERROR_CODE_STREAM_CLOSED, + RiverError, + RiverException, + RiverServiceException, + StreamClosedRiverServiceException, + exception_from_message, +) +from replit_river.messages import ( + FailedSendingMessageException, + parse_transport_msg, +) from replit_river.rpc import ( + ACK_BIT, + STREAM_OPEN_BIT, TransportMessage, TransportMessageTracingSetter, ) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, + OutOfOrderMessageException, +) from replit_river.task_manager import BackgroundTaskManager -from replit_river.transport_options import 
TransportOptions +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions + +STREAM_CANCEL_BIT_TYPE = Literal[0b00100] +STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 +STREAM_CLOSED_BIT_TYPE = Literal[0b01000] +STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 + logger = logging.getLogger(__name__) @@ -101,6 +142,49 @@ def __init__( self.ack = 0 self.seq = 0 + async def do_close_websocket() -> None: + logger.debug( + "do_close called, _ws_connected=%r, _ws_unwrapped=%r", + self._ws_connected, + self._ws_unwrapped, + ) + self._ws_connected = False + if self._ws_unwrapped: + self._task_manager.create_task(self._ws_unwrapped.close()) + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + await self._begin_close_session_countdown() + + self._setup_heartbeats_task(do_close_websocket) + + def commit(msg: TransportMessage) -> None: + pending = self._send_buffer.popleft() + if msg.seq != pending.seq: + logger.error("Out of sequence error") + self._ack_buffer.append(pending) + + # On commit, release pending writers waiting for more buffer space + if self._queue_full_lock.locked(): + self._queue_full_lock.release() + + def get_next_pending() -> TransportMessage | None: + if self._send_buffer: + return self._send_buffer[0] + return None + + self._task_manager.create_task( + buffered_message_sender( + get_ws=lambda: ( + cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped) + if self.is_websocket_open() + else None + ), + websocket_closed_callback=self._begin_close_session_countdown, + get_next_pending=get_next_pending, + commit=commit, + ) + ) + def _setup_heartbeats_task( self, do_close_websocket: Callable[[], Awaitable[None]], @@ -255,3 +339,443 @@ async def close(self) -> None: self._streams.clear() self._state = SessionState.CLOSED + + async def start_serve_responses(self) -> None: + self._task_manager.create_task(self._serve()) + + async def _serve(self) -> None: + 
"""Serve messages from the websocket.""" + self._reset_session_close_countdown() + try: + try: + await self._handle_messages_from_ws() + except ConnectionClosed: + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + + await self._begin_close_session_countdown() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + raise ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + + async def _handle_messages_from_ws(self) -> None: + while self._ws_unwrapped is None: + await asyncio.sleep(1) + logger.debug( + "%s start handling messages from ws %s", + "client", + self._ws_unwrapped.id, + ) + try: + ws = self._ws_unwrapped + while True: + if not self._ws_unwrapped: + # We should not process messages if the websocket is closed. + break + + # decode=False: Avoiding an unnecessary round-trip through str + # Ideally this should be type-ascripted to : bytes, but there is no + # @overrides in `websockets` to hint this. 
+ message = await ws.recv(decode=False) + try: + msg = parse_transport_msg(message) + + logger.debug(f"{self._transport_id} got a message %r", msg) + + # Update bookkeeping + if msg.seq < self.ack: + raise IgnoreMessageException( + f"{msg.from_} received duplicate msg, got {msg.seq}" + f" expected {self.ack}" + ) + elif msg.seq > self.ack: + logger.warning( + f"Out of order message received got {msg.seq} expected " + f"{self.ack}" + ) + + raise OutOfOrderMessageException( + f"Out of order message received got {msg.seq} expected " + f"{self.ack}" + ) + + assert msg.seq == self.ack, "Safety net, redundant assertion" + + # Set our next expected ack number + self.ack = msg.seq + 1 + + # Discard old messages from the buffer + while self._ack_buffer and self._ack_buffer[0].seq < msg.ack: + self._ack_buffer.popleft() + + self._reset_session_close_countdown() + + if msg.controlFlags & ACK_BIT != 0: + continue + stream = self._streams.get(msg.streamId, None) + if msg.controlFlags & STREAM_OPEN_BIT != 0: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + if not stream: + logger.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException("no stream for message, ignoring") + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. 
+ pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + del self._streams[msg.streamId] + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + self._task_manager.create_task( + self._ws_unwrapped.close( + code=CloseCode.INVALID_DATA, + reason="Out of order message", + ) + ) + return + except InvalidMessageException: + logger.exception("Got invalid transport message, closing session") + await self.close() + return + except ConnectionClosedOK: + pass # Exited normally + except ConnectionClosed as e: + raise e + + async def send_rpc[R, A]( + self, + service_name: str, + procedure_name: str, + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], + span: Span, + timeout: timedelta, + ) -> A: + """Sends a single RPC request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + stream_id = nanoid.generate() + output: Channel[Any] = Channel(1) + self._streams[stream_id] = output + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, + payload=request_serializer(request), + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + # Handle potential errors during communication + try: + try: + async with asyncio.timeout(timeout.total_seconds()): + response = await output.get() + except asyncio.TimeoutError as e: + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_CANCEL_BIT, + payload={"type": "CANCEL"}, + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + raise RiverException(ERROR_CODE_CANCEL, str(e)) from e + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) + return response_deserializer(response["payload"]) + except RiverException as e: + raise e + except Exception as e: + raise e + + async def send_upload[I, R, A]( + self, + service_name: str, + procedure_name: str, + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], + span: Span, + ) -> A: + """Sends an upload request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + + stream_id = nanoid.generate() + output: Channel[Any] = Channel(1) + self._streams[stream_id] = output + try: + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + service_name=service_name, + procedure_name=procedure_name, + payload=init_serializer(init), + span=span, + ) + + if request: + assert request_serializer, "send_stream missing request_serializer" + + # If this request is not closed and the session is killed, we should + # throw exception here + async for item in request: + control_flags = 0 + await self.send_message( + stream_id=stream_id, + service_name=service_name, + procedure_name=procedure_name, + control_flags=control_flags, + payload=request_serializer(item), + span=span, + ) + except Exception as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + await self.send_close_stream( + service_name, + procedure_name, + stream_id, + extra_control_flags=0, + ) + + # Handle potential errors during communication + # TODO: throw a error when the transport is hard closed + try: + try: + response = await output.get() + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) + + return response_deserializer(response["payload"]) + except RiverException as e: + raise e + except Exception as e: + raise e + + async def send_subscription[R, E, A]( + self, + service_name: str, + procedure_name: str, + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: 
Callable[[Any], A], + error_deserializer: Callable[[Any], E], + span: Span, + ) -> AsyncGenerator[A | E, None]: + """Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. + """ + stream_id = nanoid.generate() + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) + self._streams[stream_id] = output + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=request_serializer(request), + span=span, + ) + + # Handle potential errors during communication + try: + async for item in output: + if item.get("type", None) == "CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logger.exception( + f"Error during subscription error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed) as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except Exception as e: + raise e + finally: + output.close() + + async def send_stream[I, R, E, A]( + self, + service_name: str, + procedure_name: str, + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], + span: Span, + ) -> AsyncGenerator[A | E, None]: + """Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + + stream_id = nanoid.generate() + output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) + self._streams[stream_id] = output + try: + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + span=span, + ) + except Exception as e: + raise StreamClosedRiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + + # Create the encoder task + async def _encode_stream() -> None: + if not request: + await self.send_close_stream( + service_name, + procedure_name, + stream_id, + extra_control_flags=STREAM_OPEN_BIT, + ) + return + + assert request_serializer, "send_stream missing request_serializer" + + async for item in request: + if item is None: + continue + await self.send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=0, + payload=request_serializer(item), + ) + await self.send_close_stream(service_name, procedure_name, stream_id) + + self._task_manager.create_task(_encode_stream()) + + # Handle potential errors during communication + try: + async for item in output: + if "type" in item and item["type"] == "CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logger.exception( + f"Error during subscription error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed) as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except Exception as e: + raise e + finally: + output.close() + + async def send_close_stream( + self, + service_name: str, + procedure_name: str, + stream_id: str, + extra_control_flags: int = 0, + ) -> None: + # close stream + await self.send_message( + service_name=service_name, + 
procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_CLOSED_BIT | extra_control_flags, + payload={ + "type": "CLOSE", + }, + ) From ebf3bf82bd5e8b7fcd0925f66efd7024ad7a2a3d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 13:20:46 -0700 Subject: [PATCH 039/193] This is what Semaphores are for --- src/replit_river/v2/session.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index a0d93ca3..7638649c 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -129,6 +129,7 @@ def __init__( self._retry_connection_callback = retry_connection_callback # message state + self._message_enqueued = asyncio.Semaphore() self._space_available_cond = asyncio.Condition() self._queue_full_lock = asyncio.Lock() @@ -174,6 +175,7 @@ def get_next_pending() -> TransportMessage | None: self._task_manager.create_task( buffered_message_sender( + self._message_enqueued, get_ws=lambda: ( cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped) if self.is_websocket_open() @@ -309,6 +311,8 @@ async def send_message( await self._queue_full_lock.acquire() logger.warning("LOCK RELEASED %r", repr(payload)) self._send_buffer.append(msg) + # Wake up buffered_message_sender + self._message_enqueued.release() self.seq += 1 async def close(self) -> None: From 62143d32a754aaae7e5edaacf46f92a5975c5242 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 14:42:16 -0700 Subject: [PATCH 040/193] Inline ws creation --- src/replit_river/v2/client_transport.py | 71 +++------- src/replit_river/v2/session.py | 181 +++++++++++++++++++++++- 2 files changed, 198 insertions(+), 54 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 3b0a7435..9751eb43 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -13,7 +13,6 @@ from 
replit_river.error_schema import ( ERROR_CODE_STREAM_CLOSED, ERROR_HANDSHAKE, - ERROR_SESSION, RiverException, ) from replit_river.messages import ( @@ -96,31 +95,28 @@ async def get_or_create_session(self) -> Session: If we have a disconnected session, attempt to start a new WS and use it. """ async with self._create_session_lock: - existing_session = ( - self._session - if self._session and self._session.is_session_open() - else None - ) - if existing_session is None: - return await self._create_new_session() - if existing_session.is_websocket_open(): - return existing_session - new_ws, _, hs_response = await self._establish_new_connection( - existing_session - ) - if hs_response.status.sessionId == existing_session.session_id: - logger.info( - "Replacing ws connection in session id %s", - existing_session.session_id, + existing_session = self._session + if not existing_session: + logger.info("Creating new session") + new_session = Session( + transport_id=self._transport_id, + to_id=self._server_id, + session_id=self.generate_nanoid(), + transport_options=self._transport_options, + close_session_callback=self._delete_session, + retry_connection_callback=self._retry_connection, ) - await existing_session.replace_with_new_websocket(new_ws) - return existing_session - else: - logger.info("Closing stale session %s", existing_session.session_id) - await new_ws.close() # NB(dstewart): This wasn't there in the - # v1 transport, were we just leaking WS? 
- await existing_session.close() - return await self._create_new_session() + + self._session = new_session + existing_session = new_session + await existing_session.start_serve_responses() + + await existing_session.ensure_connected( + client_id=self._client_id, + rate_limiter=self._rate_limiter, + uri_and_metadata_factory=self._uri_and_metadata_factory, + ) + return existing_session async def _establish_new_connection( self, @@ -191,31 +187,6 @@ async def _establish_new_connection( f"Failed to create ws after retrying {max_retry} number of times", ) from last_error - async def _create_new_session( - self, - ) -> Session: - logger.info("Creating new session") - new_ws, hs_request, hs_response = await self._establish_new_connection() - if not hs_response.status.ok: - message = hs_response.status.reason - raise RiverException( - ERROR_SESSION, - f"Server did not return OK status on handshake response: {message}", - ) - new_session = Session( - transport_id=self._transport_id, - to_id=self._server_id, - session_id=hs_request.sessionId, - websocket=new_ws, - transport_options=self._transport_options, - close_session_callback=self._delete_session, - retry_connection_callback=self._retry_connection, - ) - - self._session = new_session - await new_session.start_serve_responses() - return new_session - async def _retry_connection(self) -> Session: if not self._transport_options.transparent_reconnect: await self._close_session() diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 7638649c..758d3bf2 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -15,10 +15,12 @@ ) import nanoid # type: ignore +import websockets.asyncio.client from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from pydantic import ValidationError from websockets.asyncio.client import 
ClientConnection from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.frames import CloseCode @@ -33,6 +35,7 @@ from replit_river.error_schema import ( ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, + ERROR_HANDSHAKE, RiverError, RiverException, RiverServiceException, @@ -41,11 +44,17 @@ ) from replit_river.messages import ( FailedSendingMessageException, + WebsocketClosedException, parse_transport_msg, + send_transport_message, ) +from replit_river.rate_limiter import LeakyBucketRateLimit from replit_river.rpc import ( ACK_BIT, STREAM_OPEN_BIT, + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + ExpectedSessionState, TransportMessage, TransportMessageTracingSetter, ) @@ -55,7 +64,15 @@ OutOfOrderMessageException, ) from replit_river.task_manager import BackgroundTaskManager -from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions +from replit_river.transport_options import ( + MAX_MESSAGE_BUFFER_SIZE, + TransportOptions, + UriAndMetadata, +) +from replit_river.v2.client_transport import ( + PROTOCOL_VERSION, + HandshakeBudgetExhaustedException, +) STREAM_CANCEL_BIT_TYPE = Literal[0b00100] STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 @@ -107,7 +124,6 @@ def __init__( transport_id: str, to_id: str, session_id: str, - websocket: ClientConnection, transport_options: TransportOptions, close_session_callback: CloseSessionCallback, retry_connection_callback: RetryConnectionCallback | None = None, @@ -123,8 +139,7 @@ def __init__( self._close_session_after_time_secs: float | None = None # ws state - self._ws_connected = True - self._ws_unwrapped = websocket + self._ws_connected = False self._heartbeat_misses = 0 self._retry_connection_callback = retry_connection_callback @@ -187,6 +202,164 @@ def get_next_pending() -> TransportMessage | None: ) ) + async def ensure_connected( + self, + client_id: str, + rate_limiter: LeakyBucketRateLimit, + uri_and_metadata_factory: Callable[[], 
Awaitable[UriAndMetadata]], + ) -> None: + """ + Either return immediately or establish a websocket connection and return + once we can accept messages + """ + if self._ws_unwrapped and self._ws_connected: + return + max_retry = self._transport_options.connection_retry_options.max_retry + logger.info("Attempting to establish new ws connection") + + last_error: Exception | None = None + for i in range(max_retry): + if i > 0: + logger.info(f"Retrying build handshake number {i} times") + if not rate_limiter.has_budget(client_id): + logger.debug("No retry budget for %s.", client_id) + raise HandshakeBudgetExhaustedException( + ERROR_HANDSHAKE, + "No retry budget", + client_id=client_id, + ) from last_error + + rate_limiter.consume_budget(client_id) + + try: + uri_and_metadata = await uri_and_metadata_factory() + ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) + + try: + try: + expectedSessionState = ExpectedSessionState( + nextExpectedSeq=0, + nextSentSeq=0, + ) + handshake_request = ControlMessageHandshakeRequest[Any]( + type="HANDSHAKE_REQ", + protocolVersion=PROTOCOL_VERSION, + sessionId=self.session_id, + metadata=uri_and_metadata["metadata"], + expectedSessionState=expectedSessionState, + ) + stream_id = nanoid.generate() + + async def websocket_closed_callback() -> None: + logger.error("websocket closed before handshake response") + + try: + payload = handshake_request.model_dump() + await send_transport_message( + TransportMessage( + from_=self._transport_id, + to=self._to_id, + streamId=stream_id, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=payload, + ), + ws=ws, + websocket_closed_callback=websocket_closed_callback, + ) + except ( + WebsocketClosedException, + FailedSendingMessageException, + ) as e: # noqa: E501 + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while sending response", # noqa: E501 + ) from e + except FailedSendingMessageException as e: + raise RiverException( + 
ERROR_CODE_STREAM_CLOSED, + "Stream closed before response, closing connection", + ) from e + + startup_grace_deadline_ms = await self._get_current_time() + 60_000 + try: + while True: + if ( + await self._get_current_time() + >= startup_grace_deadline_ms + ): # noqa: E501 + raise RiverException( + ERROR_HANDSHAKE, + "Handshake response timeout, closing connection", # noqa: E501 + ) + try: + data = await ws.recv() + except ConnectionClosed as e: + logger.debug( + "Connection closed during waiting for handshake response", # noqa: E501 + exc_info=True, + ) + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while waiting for response", # noqa: E501 + ) from e + try: + response_msg = parse_transport_msg(data) + break + except IgnoreMessageException: + logger.debug( + "Ignoring transport message", exc_info=True + ) # noqa: E501 + continue + except InvalidMessageException as e: + raise RiverException( + ERROR_HANDSHAKE, + "Got invalid transport message, closing connection", + ) from e + + handshake_response = ControlMessageHandshakeResponse( + **response_msg.payload + ) # noqa: E501 + logger.debug("river client waiting for handshake response") + except ValidationError as e: + raise RiverException( + ERROR_HANDSHAKE, "Failed to parse handshake response" + ) from e + except asyncio.TimeoutError as e: + raise RiverException( + ERROR_HANDSHAKE, + "Handshake response timeout, closing connection", # noqa: E501 + ) from e + + logger.debug( + "river client get handshake response : %r", handshake_response + ) # noqa: E501 + if not handshake_response.status.ok: + raise RiverException( + ERROR_HANDSHAKE, + f"Handshake failed with code {handshake_response.status.code}: " # noqa: E501 + + f"{handshake_response.status.reason}", + ) + + rate_limiter.start_restoring_budget(client_id) + except RiverException as e: + await ws.close() + raise e + except Exception as e: + last_error = e + backoff_time = rate_limiter.get_backoff_ms(client_id) + logger.exception( + 
f"Error connecting, retrying with {backoff_time}ms backoff" + ) + await asyncio.sleep(backoff_time / 1000) + + raise RiverException( + ERROR_HANDSHAKE, + f"Failed to create ws after retrying {max_retry} number of times", + ) from last_error + def _setup_heartbeats_task( self, do_close_websocket: Callable[[], Awaitable[None]], From e9054991d387b2d1098a0b93c6b93f1f9ddb33d2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 14:45:02 -0700 Subject: [PATCH 041/193] Goodbye session_lock --- src/replit_river/v2/client_transport.py | 43 ++++++++++++------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 9751eb43..1f846338 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -71,8 +71,6 @@ def __init__( self._rate_limiter = LeakyBucketRateLimit( transport_options.connection_retry_options ) - # We want to make sure there's only one session creation at a time - self._create_session_lock = asyncio.Lock() async def _close_session(self) -> None: logger.info(f"start closing session {self._transport_id}") @@ -94,29 +92,28 @@ async def get_or_create_session(self) -> Session: If we have a "closed" session, mint a whole new session. If we have a disconnected session, attempt to start a new WS and use it. 
""" - async with self._create_session_lock: - existing_session = self._session - if not existing_session: - logger.info("Creating new session") - new_session = Session( - transport_id=self._transport_id, - to_id=self._server_id, - session_id=self.generate_nanoid(), - transport_options=self._transport_options, - close_session_callback=self._delete_session, - retry_connection_callback=self._retry_connection, - ) + existing_session = self._session + if not existing_session: + logger.info("Creating new session") + new_session = Session( + transport_id=self._transport_id, + to_id=self._server_id, + session_id=self.generate_nanoid(), + transport_options=self._transport_options, + close_session_callback=self._delete_session, + retry_connection_callback=self._retry_connection, + ) - self._session = new_session - existing_session = new_session - await existing_session.start_serve_responses() + self._session = new_session + existing_session = new_session + await existing_session.start_serve_responses() - await existing_session.ensure_connected( - client_id=self._client_id, - rate_limiter=self._rate_limiter, - uri_and_metadata_factory=self._uri_and_metadata_factory, - ) - return existing_session + await existing_session.ensure_connected( + client_id=self._client_id, + rate_limiter=self._rate_limiter, + uri_and_metadata_factory=self._uri_and_metadata_factory, + ) + return existing_session async def _establish_new_connection( self, From c0a4f4498b6788921dce160a29708f88e5205e02 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 15:08:36 -0700 Subject: [PATCH 042/193] Moving throw over into rate_limiter --- src/replit_river/rate_limiter.py | 30 +++++ src/replit_river/v2/client_transport.py | 6 + src/replit_river/v2/session.py | 139 ++++++++++-------------- 3 files changed, 96 insertions(+), 79 deletions(-) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index b9265eee..37bf32aa 100644 --- a/src/replit_river/rate_limiter.py +++ 
b/src/replit_river/rate_limiter.py @@ -1,8 +1,13 @@ import asyncio +import logging import random from contextvars import Context +from typing import Literal from replit_river.transport_options import ConnectionRetryOptions +from replit_river.v2.client_transport import BudgetExhaustedException + +logger = logging.getLogger(__name__) class LeakyBucketRateLimit: @@ -64,6 +69,31 @@ def has_budget(self, user: str) -> bool: """ return self.get_budget_consumed(user) < self.options.attempt_budget_capacity + def has_budget_or_throw( + self, + user: str, + error_code: str, + last_error: Exception | None, + ) -> Literal[True]: + """ + Check if the user has remaining budget to make a retry. + If they do not, explode. + + Args: + user (str): The identifier for the user. + + Returns: + bool: True if budget is available, False otherwise. + """ + if self.get_budget_consumed(user) < self.options.attempt_budget_capacity: + logger.debug("No retry budget for %s.", user) + raise BudgetExhaustedException( + error_code, + "No retry budget", + client_id=user, + ) from last_error + return True + def consume_budget(self, user: str) -> None: """Increment the budget consumed for the user by 1, indicating a retry attempt. 
diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 1f846338..7fa365bf 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -51,6 +51,12 @@ def __init__(self, code: str, message: str, client_id: str) -> None: self.client_id = client_id +class BudgetExhaustedException(RiverException): + def __init__(self, code: str, message: str, client_id: str) -> None: + super().__init__(code, message) + self.client_id = client_id + + class ClientTransport(Generic[HandshakeMetadataType]): _session: Session | None diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 758d3bf2..1681b93f 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -71,7 +71,6 @@ ) from replit_river.v2.client_transport import ( PROTOCOL_VERSION, - HandshakeBudgetExhaustedException, ) STREAM_CANCEL_BIT_TYPE = Literal[0b00100] @@ -202,11 +201,13 @@ def get_next_pending() -> TransportMessage | None: ) ) - async def ensure_connected( + async def ensure_connected[HandshakeMetadata]( self, client_id: str, rate_limiter: LeakyBucketRateLimit, - uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]], + uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadata]] + ], # noqa: E501 ) -> None: """ Either return immediately or establish a websocket connection and return @@ -218,16 +219,11 @@ async def ensure_connected( logger.info("Attempting to establish new ws connection") last_error: Exception | None = None - for i in range(max_retry): + i = 0 + while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error): if i > 0: logger.info(f"Retrying build handshake number {i} times") - if not rate_limiter.has_budget(client_id): - logger.debug("No retry budget for %s.", client_id) - raise HandshakeBudgetExhaustedException( - ERROR_HANDSHAKE, - "No retry budget", - client_id=client_id, - ) from last_error + i += 1 
rate_limiter.consume_budget(client_id) @@ -238,10 +234,12 @@ async def ensure_connected( try: try: expectedSessionState = ExpectedSessionState( - nextExpectedSeq=0, - nextSentSeq=0, + nextExpectedSeq=self.ack, + nextSentSeq=self.seq, ) - handshake_request = ControlMessageHandshakeRequest[Any]( + handshake_request = ControlMessageHandshakeRequest[ + HandshakeMetadata + ]( # noqa: E501 type="HANDSHAKE_REQ", protocolVersion=PROTOCOL_VERSION, sessionId=self.session_id, @@ -253,85 +251,68 @@ async def ensure_connected( async def websocket_closed_callback() -> None: logger.error("websocket closed before handshake response") + await send_transport_message( + TransportMessage( + from_=self._transport_id, + to=self._to_id, + streamId=stream_id, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_request.model_dump(), + ), + ws=ws, + websocket_closed_callback=websocket_closed_callback, + ) + except ( + WebsocketClosedException, + FailedSendingMessageException, + ) as e: # noqa: E501 + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while sending response", # noqa: E501 + ) from e + + startup_grace_deadline_ms = await self._get_current_time() + 60_000 + while True: + if await self._get_current_time() >= startup_grace_deadline_ms: # noqa: E501 + raise RiverException( + ERROR_HANDSHAKE, + "Handshake response timeout, closing connection", # noqa: E501 + ) try: - payload = handshake_request.model_dump() - await send_transport_message( - TransportMessage( - from_=self._transport_id, - to=self._to_id, - streamId=stream_id, - controlFlags=0, - id=nanoid.generate(), - seq=0, - ack=0, - payload=payload, - ), - ws=ws, - websocket_closed_callback=websocket_closed_callback, + data = await ws.recv() + except ConnectionClosed as e: + logger.debug( + "Connection closed during waiting for handshake response", # noqa: E501 + exc_info=True, ) - except ( - WebsocketClosedException, - FailedSendingMessageException, - ) as e: # noqa: E501 
raise RiverException( ERROR_HANDSHAKE, - "Handshake failed, conn closed while sending response", # noqa: E501 + "Handshake failed, conn closed while waiting for response", # noqa: E501 + ) from e + try: + response_msg = parse_transport_msg(data) + break + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) # noqa: E501 + continue + except InvalidMessageException as e: + raise RiverException( + ERROR_HANDSHAKE, + "Got invalid transport message, closing connection", ) from e - except FailedSendingMessageException as e: - raise RiverException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response, closing connection", - ) from e - startup_grace_deadline_ms = await self._get_current_time() + 60_000 try: - while True: - if ( - await self._get_current_time() - >= startup_grace_deadline_ms - ): # noqa: E501 - raise RiverException( - ERROR_HANDSHAKE, - "Handshake response timeout, closing connection", # noqa: E501 - ) - try: - data = await ws.recv() - except ConnectionClosed as e: - logger.debug( - "Connection closed during waiting for handshake response", # noqa: E501 - exc_info=True, - ) - raise RiverException( - ERROR_HANDSHAKE, - "Handshake failed, conn closed while waiting for response", # noqa: E501 - ) from e - try: - response_msg = parse_transport_msg(data) - break - except IgnoreMessageException: - logger.debug( - "Ignoring transport message", exc_info=True - ) # noqa: E501 - continue - except InvalidMessageException as e: - raise RiverException( - ERROR_HANDSHAKE, - "Got invalid transport message, closing connection", - ) from e - handshake_response = ControlMessageHandshakeResponse( **response_msg.payload - ) # noqa: E501 + ) logger.debug("river client waiting for handshake response") except ValidationError as e: raise RiverException( ERROR_HANDSHAKE, "Failed to parse handshake response" ) from e - except asyncio.TimeoutError as e: - raise RiverException( - ERROR_HANDSHAKE, - "Handshake response timeout, closing 
connection", # noqa: E501 - ) from e logger.debug( "river client get handshake response : %r", handshake_response From 736c4e86d504ee364f980f45bfd1010c2a655093 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 15:40:51 -0700 Subject: [PATCH 043/193] More lifecycle management --- src/replit_river/v2/client_transport.py | 259 ++---------------------- src/replit_river/v2/session.py | 22 +- 2 files changed, 23 insertions(+), 258 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 7fa365bf..a46adeea 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -1,38 +1,13 @@ -import asyncio import logging from collections.abc import Awaitable, Callable -from typing import Generic, assert_never +from typing import Generic import nanoid -import websockets -import websockets.asyncio.client -from pydantic import ValidationError -from websockets.asyncio.client import ClientConnection -from websockets.exceptions import ConnectionClosed from replit_river.error_schema import ( - ERROR_CODE_STREAM_CLOSED, - ERROR_HANDSHAKE, RiverException, ) -from replit_river.messages import ( - FailedSendingMessageException, - WebsocketClosedException, - parse_transport_msg, - send_transport_message, -) from replit_river.rate_limiter import LeakyBucketRateLimit -from replit_river.rpc import ( - SESSION_MISMATCH_CODE, - ControlMessageHandshakeRequest, - ControlMessageHandshakeResponse, - ExpectedSessionState, - TransportMessage, -) -from replit_river.seq_manager import ( - IgnoreMessageException, - InvalidMessageException, -) from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -78,33 +53,25 @@ def __init__( transport_options.connection_retry_options ) - async def _close_session(self) -> None: - logger.info(f"start closing session {self._transport_id}") - if not self._session: - return - await self._session.close() - logger.info(f"Transport 
closed {self._transport_id}") - - def generate_nanoid(self) -> str: - return str(nanoid.generate()) - async def close(self) -> None: self._rate_limiter.close() - await self._close_session() + if self._session: + logger.info(f"start closing session {self._transport_id}") + await self._session.close() + logger.info(f"Transport closed {self._transport_id}") async def get_or_create_session(self) -> Session: """ - If we have an active session, return it. - If we have a "closed" session, mint a whole new session. - If we have a disconnected session, attempt to start a new WS and use it. + Create a session if it does not exist, + call ensure_connected on whatever session is active. """ existing_session = self._session - if not existing_session: + if not existing_session or not existing_session.is_session_open(): logger.info("Creating new session") new_session = Session( transport_id=self._transport_id, to_id=self._server_id, - session_id=self.generate_nanoid(), + session_id=nanoid.generate(), transport_options=self._transport_options, close_session_callback=self._delete_session, retry_connection_callback=self._retry_connection, @@ -121,214 +88,12 @@ async def get_or_create_session(self) -> Session: ) return existing_session - async def _establish_new_connection( - self, - old_session: Session | None = None, - ) -> tuple[ - ClientConnection, - ControlMessageHandshakeRequest[HandshakeMetadataType], - ControlMessageHandshakeResponse, - ]: - """Build a new websocket connection with retry logic.""" - rate_limit = self._rate_limiter - max_retry = self._transport_options.connection_retry_options.max_retry - client_id = self._client_id - logger.info("Attempting to establish new ws connection") - - last_error: Exception | None = None - for i in range(max_retry): - if i > 0: - logger.info(f"Retrying build handshake number {i} times") - if not rate_limit.has_budget(client_id): - logger.debug("No retry budget for %s.", client_id) - raise HandshakeBudgetExhaustedException( - 
ERROR_HANDSHAKE, - "No retry budget", - client_id=client_id, - ) from last_error - - rate_limit.consume_budget(client_id) - - # if the session is closed, we shouldn't use it - if old_session and not old_session.is_session_open(): - old_session = None - - try: - uri_and_metadata = await self._uri_and_metadata_factory() - ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) - session_id: str - if old_session: - session_id = old_session.session_id - else: - session_id = self.generate_nanoid() - - try: - ( - handshake_request, - handshake_response, - ) = await self._establish_handshake( - session_id, - uri_and_metadata["metadata"], - ws, - old_session, - ) - rate_limit.start_restoring_budget(client_id) - return ws, handshake_request, handshake_response - except RiverException as e: - await ws.close() - raise e - except Exception as e: - last_error = e - backoff_time = rate_limit.get_backoff_ms(client_id) - logger.exception( - f"Error connecting, retrying with {backoff_time}ms backoff" - ) - await asyncio.sleep(backoff_time / 1000) - - raise RiverException( - ERROR_HANDSHAKE, - f"Failed to create ws after retrying {max_retry} number of times", - ) from last_error - async def _retry_connection(self) -> Session: - if not self._transport_options.transparent_reconnect: - await self._close_session() + if not self._transport_options.transparent_reconnect and self._session: + logger.info("transparent_reconnect not set, closing {self._transport_id}") + await self._session.close() return await self.get_or_create_session() - async def _send_handshake_request( - self, - session_id: str, - handshake_metadata: HandshakeMetadataType | None, - websocket: ClientConnection, - expected_session_state: ExpectedSessionState, - ) -> ControlMessageHandshakeRequest[HandshakeMetadataType]: - handshake_request = ControlMessageHandshakeRequest[HandshakeMetadataType]( - type="HANDSHAKE_REQ", - protocolVersion=PROTOCOL_VERSION, - sessionId=session_id, - metadata=handshake_metadata, 
- expectedSessionState=expected_session_state, - ) - stream_id = self.generate_nanoid() - - async def websocket_closed_callback() -> None: - logger.error("websocket closed before handshake response") - - try: - payload = handshake_request.model_dump() - await send_transport_message( - TransportMessage( - from_=self._transport_id, - to=self._server_id, - streamId=stream_id, - controlFlags=0, - id=self.generate_nanoid(), - seq=0, - ack=0, - payload=payload, - ), - ws=websocket, - websocket_closed_callback=websocket_closed_callback, - ) - return handshake_request - except (WebsocketClosedException, FailedSendingMessageException) as e: - raise RiverException( - ERROR_HANDSHAKE, "Handshake failed, conn closed while sending response" - ) from e - - async def _get_handshake_response_msg( - self, websocket: ClientConnection - ) -> TransportMessage: - while True: - try: - data = await websocket.recv() - except ConnectionClosed as e: - logger.debug( - "Connection closed during waiting for handshake response", - exc_info=True, - ) - raise RiverException( - ERROR_HANDSHAKE, - "Handshake failed, conn closed while waiting for response", - ) from e - try: - msg = parse_transport_msg(data) - if isinstance(msg, str): - logger.debug("Ignoring transport message", exc_info=True) - continue - except InvalidMessageException as e: - raise RiverException( - ERROR_HANDSHAKE, - "Got invalid transport message, closing connection", - ) from e - - async def _establish_handshake( - self, - session_id: str, - handshake_metadata: HandshakeMetadataType, - websocket: ClientConnection, - old_session: Session | None, - ) -> tuple[ - ControlMessageHandshakeRequest[HandshakeMetadataType], - ControlMessageHandshakeResponse, - ]: - try: - expectedSessionState: ExpectedSessionState - match old_session: - case None: - expectedSessionState = ExpectedSessionState( - nextExpectedSeq=0, - nextSentSeq=0, - ) - case Session(): - expectedSessionState = ExpectedSessionState( - nextExpectedSeq=old_session.ack, - 
nextSentSeq=old_session.seq, - ) - case other: - assert_never(other) - handshake_request = await self._send_handshake_request( - session_id=session_id, - handshake_metadata=handshake_metadata, - websocket=websocket, - expected_session_state=expectedSessionState, - ) - except FailedSendingMessageException as e: - raise RiverException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response, closing connection", - ) from e - - startup_grace_sec = 60 - try: - response_msg = await asyncio.wait_for( - self._get_handshake_response_msg(websocket), startup_grace_sec - ) - handshake_response = ControlMessageHandshakeResponse(**response_msg.payload) - logger.debug("river client waiting for handshake response") - except ValidationError as e: - raise RiverException( - ERROR_HANDSHAKE, "Failed to parse handshake response" - ) from e - except asyncio.TimeoutError as e: - raise RiverException( - ERROR_HANDSHAKE, "Handshake response timeout, closing connection" - ) from e - - logger.debug("river client get handshake response : %r", handshake_response) - if not handshake_response.status.ok: - if old_session and handshake_response.status.code == SESSION_MISMATCH_CODE: - # If the session status is mismatched, we should close the old session - # and let the retry logic to create a new session. 
- await old_session.close() - - raise RiverException( - ERROR_HANDSHAKE, - f"Handshake failed with code ${handshake_response.status.code}: " - + f"{handshake_response.status.reason}", - ) - return handshake_request, handshake_response - async def _delete_session(self, session: Session) -> None: if self._session and session._to_id == self._session._to_id: self._session = None diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 1681b93f..b6bc6789 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -139,6 +139,7 @@ def __init__( # ws state self._ws_connected = False + self._ws_unwrapped = None self._heartbeat_misses = 0 self._retry_connection_callback = retry_connection_callback @@ -487,9 +488,6 @@ async def close(self) -> None: # invocation, so let's await this close to avoid dropping the socket. await self._ws_unwrapped.close() - # Clear the session in transports - await self._close_session_callback(self) - # TODO: unexpected_close should close stream differently here to # throw exception correctly. for stream in self._streams.values(): @@ -498,6 +496,10 @@ async def close(self) -> None: self._state = SessionState.CLOSED + # Clear the session in transports + # This will get us GC'd, so this should be the last thing. + await self._close_session_callback(self) + async def start_serve_responses(self) -> None: self._task_manager.create_task(self._serve()) @@ -528,7 +530,7 @@ async def _serve(self) -> None: ) async def _handle_messages_from_ws(self) -> None: - while self._ws_unwrapped is None: + while self._ws_unwrapped is None or not self._ws_connected: await asyncio.sleep(1) logger.debug( "%s start handling messages from ws %s", @@ -536,12 +538,8 @@ async def _handle_messages_from_ws(self) -> None: self._ws_unwrapped.id, ) try: - ws = self._ws_unwrapped - while True: - if not self._ws_unwrapped: - # We should not process messages if the websocket is closed. 
- break - + # We should not process messages if the websocket is closed. + while ws := self._ws_unwrapped: # decode=False: Avoiding an unnecessary round-trip through str # Ideally this should be type-ascripted to : bytes, but there is no # @overrides in `websockets` to hint this. @@ -573,14 +571,16 @@ async def _handle_messages_from_ws(self) -> None: # Set our next expected ack number self.ack = msg.seq + 1 - # Discard old messages from the buffer + # Discard old server-ack'd messages from the ack buffer while self._ack_buffer and self._ack_buffer[0].seq < msg.ack: self._ack_buffer.popleft() self._reset_session_close_countdown() + # Shortcut to avoid processing ack packets if msg.controlFlags & ACK_BIT != 0: continue + stream = self._streams.get(msg.streamId, None) if msg.controlFlags & STREAM_OPEN_BIT != 0: raise InvalidMessageException( From 046e770b62f2ce4570fc066e8aded82ee4536c66 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 16:23:31 -0700 Subject: [PATCH 044/193] Just use _state instead of having two --- src/replit_river/common_session.py | 14 +-- src/replit_river/session.py | 7 +- src/replit_river/v2/client_transport.py | 2 +- src/replit_river/v2/session.py | 109 +++++++++++------------- 4 files changed, 64 insertions(+), 68 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index f1f9ccd5..e1444311 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -26,17 +26,19 @@ class SessionState(enum.Enum): """The state a session can be in. 
Valid transitions: - - NO_CONNECTION -> {ACTIVE} - - ACTIVE -> {NO_CONNECTION, CLOSING} + - NO_CONNECTION -> {CONNECTING} + - CONNECTING -> {ACTIVE, CLOSING} + - ACTIVE -> {NO_CONNECTION, CONNECTING, CLOSING} - CLOSING -> {CLOSED} - CLOSED -> {} """ NO_CONNECTION = 0 - ACTIVE = 1 - CLOSING = 2 - CLOSED = 3 + CONNECTING = 1 + ACTIVE = 2 + CLOSING = 3 + CLOSED = 4 -ConnectingStateta = set([SessionState.NO_CONNECTION]) +ConnectingStateta = set([SessionState.NO_CONNECTION, SessionState.CONNECTING]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 517b53f6..b04463a3 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -120,8 +120,11 @@ def increment_and_get_heartbeat_misses() -> int: self.session_id, self._transport_options.heartbeat_ms, self._transport_options.heartbeats_until_dead, - lambda: self._state, - lambda: self._ws_wrapper.ws_state == WsState.OPEN, + lambda: ( + self._state + if self._ws_wrapper.ws_state == WsState.OPEN + else SessionState.CONNECTING + ), lambda: self._close_session_after_time_secs, close_websocket=do_close_websocket, send_message=self.send_message, diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index a46adeea..d26e9998 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -66,7 +66,7 @@ async def get_or_create_session(self) -> Session: call ensure_connected on whatever session is active. 
""" existing_session = self._session - if not existing_session or not existing_session.is_session_open(): + if not existing_session or existing_session.is_closed(): logger.info("Creating new session") new_session = Session( transport_id=self._transport_id, diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index b6bc6789..28bbf4d3 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -28,6 +28,7 @@ from replit_river.common_session import ( SessionState, + TerminalStates, buffered_message_sender, check_to_close_session, setup_heartbeat, @@ -103,7 +104,6 @@ class Session: _close_session_after_time_secs: float | None # ws state - _ws_connected: bool _ws_unwrapped: ClientConnection | None _heartbeat_misses: int _retry_connection_callback: RetryConnectionCallback | None @@ -133,12 +133,11 @@ def __init__( self._transport_options = transport_options # session state, only modified during closing - self._state = SessionState.ACTIVE + self._state = SessionState.CONNECTING self._close_session_callback = close_session_callback self._close_session_after_time_secs: float | None = None # ws state - self._ws_connected = False self._ws_unwrapped = None self._heartbeat_misses = 0 self._retry_connection_callback = retry_connection_callback @@ -160,18 +159,43 @@ def __init__( async def do_close_websocket() -> None: logger.debug( - "do_close called, _ws_connected=%r, _ws_unwrapped=%r", - self._ws_connected, + "do_close called, _state=%r, _ws_unwrapped=%r", + self._state, self._ws_unwrapped, ) - self._ws_connected = False + self._state = SessionState.CLOSING if self._ws_unwrapped: self._task_manager.create_task(self._ws_unwrapped.close()) if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) await self._begin_close_session_countdown() - self._setup_heartbeats_task(do_close_websocket) + def increment_and_get_heartbeat_misses() -> int: + self._heartbeat_misses += 1 + return 
self._heartbeat_misses + + self._task_manager.create_task( + setup_heartbeat( + self.session_id, + self._transport_options.heartbeat_ms, + self._transport_options.heartbeats_until_dead, + lambda: self._state, + lambda: self._close_session_after_time_secs, + close_websocket=do_close_websocket, + send_message=self.send_message, + increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, + ) + ) + self._task_manager.create_task( + check_to_close_session( + self._transport_id, + self._transport_options.close_session_check_interval_ms, + lambda: self._state, + self._get_current_time, + lambda: self._close_session_after_time_secs, + self.close, + ) + ) def commit(msg: TransportMessage) -> None: pending = self._send_buffer.popleft() @@ -193,7 +217,7 @@ def get_next_pending() -> TransportMessage | None: self._message_enqueued, get_ws=lambda: ( cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped) - if self.is_websocket_open() + if self.is_connected() else None ), websocket_closed_callback=self._begin_close_session_countdown, @@ -214,7 +238,7 @@ async def ensure_connected[HandshakeMetadata]( Either return immediately or establish a websocket connection and return once we can accept messages """ - if self._ws_unwrapped and self._ws_connected: + if self._ws_unwrapped and self._state == SessionState.ACTIVE: return max_retry = self._transport_options.connection_retry_options.max_retry logger.info("Attempting to establish new ws connection") @@ -326,6 +350,7 @@ async def websocket_closed_callback() -> None: ) rate_limiter.start_restoring_budget(client_id) + self._state = SessionState.ACTIVE except RiverException as e: await ws.close() raise e @@ -342,44 +367,17 @@ async def websocket_closed_callback() -> None: f"Failed to create ws after retrying {max_retry} number of times", ) from last_error - def _setup_heartbeats_task( - self, - do_close_websocket: Callable[[], Awaitable[None]], - ) -> None: - def increment_and_get_heartbeat_misses() -> int: - 
self._heartbeat_misses += 1 - return self._heartbeat_misses - - self._task_manager.create_task( - setup_heartbeat( - self.session_id, - self._transport_options.heartbeat_ms, - self._transport_options.heartbeats_until_dead, - lambda: self._state, - lambda: self._ws_connected, - lambda: self._close_session_after_time_secs, - close_websocket=do_close_websocket, - send_message=self.send_message, - increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, - ) - ) - self._task_manager.create_task( - check_to_close_session( - self._transport_id, - self._transport_options.close_session_check_interval_ms, - lambda: self._state, - self._get_current_time, - lambda: self._close_session_after_time_secs, - self.close, - ) - ) + def is_closed(self) -> bool: + """ + If the session is in a terminal state. + Do not send messages, do not expect any more messages to be emitted, + the state is expected to be stale. + """ + return self._state not in TerminalStates - def is_session_open(self) -> bool: + def is_connected(self) -> bool: return self._state == SessionState.ACTIVE - def is_websocket_open(self) -> bool: - return self._ws_connected - async def _begin_close_session_countdown(self) -> None: """Begin the countdown to close session, this should be called when websocket is closed. 
@@ -400,17 +398,6 @@ async def _begin_close_session_countdown(self) -> None: self._to_id, ) self._close_session_after_time_secs = close_session_after_time_secs - self._ws_connected = False - - async def replace_with_new_websocket(self, new_ws: ClientConnection) -> None: - if self._ws_unwrapped and new_ws.id != self._ws_unwrapped.id: - self._task_manager.create_task( - self._ws_unwrapped.close( - CloseCode.PROTOCOL_ERROR, "Transparent reconnect" - ) - ) - self._ws_unwrapped = new_ws - self._ws_connected = True async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() @@ -430,7 +417,7 @@ async def send_message( ) -> None: """Send serialized messages to the websockets.""" # if the session is not active, we should not do anything - if self._state != SessionState.ACTIVE: + if self._state in TerminalStates: return msg = TransportMessage( streamId=stream_id, @@ -476,7 +463,7 @@ async def close(self) -> None: f"{self._transport_id} closing session " f"to {self._to_id}, ws: {self._ws_unwrapped}" ) - if self._state != SessionState.ACTIVE: + if self._state in TerminalStates: # already closing return self._state = SessionState.CLOSING @@ -510,6 +497,8 @@ async def _serve(self) -> None: try: await self._handle_messages_from_ws() except ConnectionClosed: + # Set ourselves to closed as soon as we get the signal + self._state = SessionState.CONNECTING if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -530,7 +519,7 @@ async def _serve(self) -> None: ) async def _handle_messages_from_ws(self) -> None: - while self._ws_unwrapped is None or not self._ws_connected: + while self._ws_unwrapped is None or self._state == SessionState.CONNECTING: await asyncio.sleep(1) logger.debug( "%s start handling messages from ws %s", @@ -628,8 +617,10 @@ async def _handle_messages_from_ws(self) -> None: await self.close() return except ConnectionClosedOK: - pass # Exited normally + # Exited normally + self._state = 
SessionState.CONNECTING except ConnectionClosed as e: + self._state = SessionState.CONNECTING raise e async def send_rpc[R, A]( From c457944ea42b682de020e1430e5f0bf75ba0a48f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 16:32:53 -0700 Subject: [PATCH 045/193] Patch circular import --- src/replit_river/rate_limiter.py | 8 +++++++- src/replit_river/v2/client_transport.py | 7 +------ src/replit_river/v2/session.py | 6 ++---- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index 37bf32aa..a7d0f333 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -4,12 +4,18 @@ from contextvars import Context from typing import Literal +from replit_river.error_schema import RiverException from replit_river.transport_options import ConnectionRetryOptions -from replit_river.v2.client_transport import BudgetExhaustedException logger = logging.getLogger(__name__) +class BudgetExhaustedException(RiverException): + def __init__(self, code: str, message: str, client_id: str) -> None: + super().__init__(code, message) + self.client_id = client_id + + class LeakyBucketRateLimit: """Asynchronous leaky bucket rate limiter. 
diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index d26e9998..377931d7 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -26,12 +26,6 @@ def __init__(self, code: str, message: str, client_id: str) -> None: self.client_id = client_id -class BudgetExhaustedException(RiverException): - def __init__(self, code: str, message: str, client_id: str) -> None: - super().__init__(code, message) - self.client_id = client_id - - class ClientTransport(Generic[HandshakeMetadataType]): _session: Session | None @@ -85,6 +79,7 @@ async def get_or_create_session(self) -> Session: client_id=self._client_id, rate_limiter=self._rate_limiter, uri_and_metadata_factory=self._uri_and_metadata_factory, + protocol_version=PROTOCOL_VERSION, ) return existing_session diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 28bbf4d3..3e749225 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -70,9 +70,6 @@ TransportOptions, UriAndMetadata, ) -from replit_river.v2.client_transport import ( - PROTOCOL_VERSION, -) STREAM_CANCEL_BIT_TYPE = Literal[0b00100] STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 @@ -233,6 +230,7 @@ async def ensure_connected[HandshakeMetadata]( uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ], # noqa: E501 + protocol_version: str, ) -> None: """ Either return immediately or establish a websocket connection and return @@ -266,7 +264,7 @@ async def ensure_connected[HandshakeMetadata]( HandshakeMetadata ]( # noqa: E501 type="HANDSHAKE_REQ", - protocolVersion=PROTOCOL_VERSION, + protocolVersion=protocol_version, sessionId=self.session_id, metadata=uri_and_metadata["metadata"], expectedSessionState=expectedSessionState, From 37b01877b6e9ee157c55394cadb55bc79a44d543 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 17:36:46 -0700 Subject: [PATCH 046/193] Block 
ensure_connected until connected --- src/replit_river/v2/session.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 3e749225..b297c571 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -102,6 +102,7 @@ class Session: # ws state _ws_unwrapped: ClientConnection | None + _ensure_connected_condition: asyncio.Condition _heartbeat_misses: int _retry_connection_callback: RetryConnectionCallback | None @@ -129,13 +130,14 @@ def __init__( self.session_id = session_id self._transport_options = transport_options - # session state, only modified during closing - self._state = SessionState.CONNECTING + # session state + self._state = SessionState.NO_CONNECTION self._close_session_callback = close_session_callback self._close_session_after_time_secs: float | None = None # ws state self._ws_unwrapped = None + self._ensure_connected_condition = asyncio.Condition() self._heartbeat_misses = 0 self._retry_connection_callback = retry_connection_callback @@ -236,11 +238,22 @@ async def ensure_connected[HandshakeMetadata]( Either return immediately or establish a websocket connection and return once we can accept messages """ - if self._ws_unwrapped and self._state == SessionState.ACTIVE: - return max_retry = self._transport_options.connection_retry_options.max_retry logger.info("Attempting to establish new ws connection") + if self.is_connected(): + return + + while True: + await self._ensure_connected_condition.acquire() + if self._state == SessionState.ACTIVE: + return + elif self._state == SessionState.NO_CONNECTION: + self._state = SessionState.CONNECTING + break + elif self._state in TerminalStates: + raise RiverException("SESSION_CLOSING", "Going away") + last_error: Exception | None = None i = 0 while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error): @@ -349,6 +362,7 @@ async def websocket_closed_callback() -> 
None: rate_limiter.start_restoring_budget(client_id) self._state = SessionState.ACTIVE + self._ensure_connected_condition.notify_all() except RiverException as e: await ws.close() raise e From c3e8b8661a2501e96e9ffc3151dd66be4527c46c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 18:12:12 -0700 Subject: [PATCH 047/193] Do our best to avoid contention on ensure_connected --- src/replit_river/v2/session.py | 72 +++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 19 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index b297c571..6951de3b 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -99,10 +99,10 @@ class Session: _state: SessionState _close_session_callback: CloseSessionCallback _close_session_after_time_secs: float | None + _connecting_task: asyncio.Task[Literal[True]] | None # ws state _ws_unwrapped: ClientConnection | None - _ensure_connected_condition: asyncio.Condition _heartbeat_misses: int _retry_connection_callback: RetryConnectionCallback | None @@ -134,10 +134,10 @@ def __init__( self._state = SessionState.NO_CONNECTION self._close_session_callback = close_session_callback self._close_session_after_time_secs: float | None = None + self._connecting_task = None # ws state self._ws_unwrapped = None - self._ensure_connected_condition = asyncio.Condition() self._heartbeat_misses = 0 self._retry_connection_callback = retry_connection_callback @@ -236,23 +236,38 @@ async def ensure_connected[HandshakeMetadata]( ) -> None: """ Either return immediately or establish a websocket connection and return - once we can accept messages + once we can accept messages. + + One of the goals of this function is to gate exactly one call to the + logic that actually establishes the connection. 
""" - max_retry = self._transport_options.connection_retry_options.max_retry - logger.info("Attempting to establish new ws connection") if self.is_connected(): return - while True: - await self._ensure_connected_condition.acquire() - if self._state == SessionState.ACTIVE: - return - elif self._state == SessionState.NO_CONNECTION: - self._state = SessionState.CONNECTING - break - elif self._state in TerminalStates: - raise RiverException("SESSION_CLOSING", "Going away") + if not self._connecting_task: + self._connecting_task = asyncio.create_task( + self._do_ensure_connected( + client_id, + rate_limiter, + uri_and_metadata_factory, + protocol_version, + ) + ) + + await self._connecting_task + + async def _do_ensure_connected[HandshakeMetadata]( + self, + client_id: str, + rate_limiter: LeakyBucketRateLimit, + uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadata]] + ], # noqa: E501 + protocol_version: str, + ) -> Literal[True]: + max_retry = self._transport_options.connection_retry_options.max_retry + logger.info("Attempting to establish new ws connection") last_error: Exception | None = None i = 0 @@ -360,9 +375,9 @@ async def websocket_closed_callback() -> None: + f"{handshake_response.status.reason}", ) + last_error = None rate_limiter.start_restoring_budget(client_id) self._state = SessionState.ACTIVE - self._ensure_connected_condition.notify_all() except RiverException as e: await ws.close() raise e @@ -374,10 +389,29 @@ async def websocket_closed_callback() -> None: ) await asyncio.sleep(backoff_time / 1000) - raise RiverException( - ERROR_HANDSHAKE, - f"Failed to create ws after retrying {max_retry} number of times", - ) from last_error + # We are in a state where we may throw an exception. + # + # To permit subsequent calls to ensure_connected to pass, we clear ourselves. 
+ # This is safe because each individual function that is waiting on this + # function completeing already has a reference, so we'll last a few ticks + # before GC. + # + # Let's do our best to avoid clobbering other tasks by comparing the .name + current_task = asyncio.current_task() + if ( + self._connecting_task + and current_task + and self._connecting_task.get_name() == current_task.get_name() + ): + self._connecting_task = None + + if last_error is not None: + raise RiverException( + ERROR_HANDSHAKE, + f"Failed to create ws after retrying {max_retry} number of times", + ) from last_error + + return True def is_closed(self) -> bool: """ From 6a3d63eefc6a10a822ca56fa61ca1b9bf7b6d83e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 24 Mar 2025 23:47:56 -0700 Subject: [PATCH 048/193] Fix bugs --- src/replit_river/rate_limiter.py | 3 +- src/replit_river/v2/client_transport.py | 1 + src/replit_river/v2/session.py | 43 ++++++++++++++++++------- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index a7d0f333..87de217e 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -91,7 +91,8 @@ def has_budget_or_throw( Returns: bool: True if budget is available, False otherwise. 
""" - if self.get_budget_consumed(user) < self.options.attempt_budget_capacity: + logger.debug("self.get_budget_consumed(user)=%r < self.options.attempt_budget_capacity=%r", self.get_budget_consumed(user), self.options.attempt_budget_capacity) + if self.get_budget_consumed(user) > self.options.attempt_budget_capacity: logger.debug("No retry budget for %s.", user) raise BudgetExhaustedException( error_code, diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 377931d7..7d033bac 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -60,6 +60,7 @@ async def get_or_create_session(self) -> Session: call ensure_connected on whatever session is active. """ existing_session = self._session + logger.debug(f"if not existing_session={existing_session} or existing_session.is_closed()={existing_session and existing_session.is_closed()}:") if not existing_session or existing_session.is_closed(): logger.info("Creating new session") new_session = Session( diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 6951de3b..7d51adb0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -100,6 +100,7 @@ class Session: _close_session_callback: CloseSessionCallback _close_session_after_time_secs: float | None _connecting_task: asyncio.Task[Literal[True]] | None + _connection_condition: asyncio.Condition # ws state _ws_unwrapped: ClientConnection | None @@ -135,6 +136,7 @@ def __init__( self._close_session_callback = close_session_callback self._close_session_after_time_secs: float | None = None self._connecting_task = None + self._connection_condition = asyncio.Condition() # ws state self._ws_unwrapped = None @@ -162,11 +164,13 @@ async def do_close_websocket() -> None: self._state, self._ws_unwrapped, ) - self._state = SessionState.CLOSING if self._ws_unwrapped: self._task_manager.create_task(self._ws_unwrapped.close()) if 
self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) + self._ws_unwrapped = None + else: + self._state = SessionState.CLOSING await self._begin_close_session_countdown() def increment_and_get_heartbeat_misses() -> int: @@ -211,14 +215,18 @@ def get_next_pending() -> TransportMessage | None: return self._send_buffer[0] return None + # TODO: Just return _ws_unwrapped once we are no longer using the legacy client + def get_ws() -> WebSocketCommonProtocol | ClientConnection | None: + logger.debug("get_ws: %r %r", self.is_connected(), self._ws_unwrapped) + if self.is_connected(): + return self._ws_unwrapped + return None + self._task_manager.create_task( buffered_message_sender( + self._connection_condition, self._message_enqueued, - get_ws=lambda: ( - cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped) - if self.is_connected() - else None - ), + get_ws=get_ws, websocket_closed_callback=self._begin_close_session_countdown, get_next_pending=get_next_pending, commit=commit, @@ -242,6 +250,7 @@ async def ensure_connected[HandshakeMetadata]( logic that actually establishes the connection. 
""" + logger.debug("ensure_connected: %r", self.is_connected()) if self.is_connected(): return @@ -255,7 +264,9 @@ async def ensure_connected[HandshakeMetadata]( ) ) + logger.debug("BEFORE await _do_ensure_connected") await self._connecting_task + logger.debug("AFTER await _do_ensure_connected") async def _do_ensure_connected[HandshakeMetadata]( self, @@ -271,6 +282,7 @@ async def _do_ensure_connected[HandshakeMetadata]( last_error: Exception | None = None i = 0 + await self._connection_condition.acquire() while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error): if i > 0: logger.info(f"Retrying build handshake number {i} times") @@ -378,6 +390,11 @@ async def websocket_closed_callback() -> None: last_error = None rate_limiter.start_restoring_budget(client_id) self._state = SessionState.ACTIVE + self._ws_unwrapped = ws + logger.debug("Before notify_all: %r %r %r", self._state, self._ws_unwrapped, self._connection_condition) + self._connection_condition.notify_all() + self._connection_condition.release() + break except RiverException as e: await ws.close() raise e @@ -411,6 +428,7 @@ async def websocket_closed_callback() -> None: f"Failed to create ws after retrying {max_retry} number of times", ) from last_error + logger.debug("EXITING _do_ensure_connected") return True def is_closed(self) -> bool: @@ -419,7 +437,7 @@ def is_closed(self) -> bool: Do not send messages, do not expect any more messages to be emitted, the state is expected to be stale. 
""" - return self._state not in TerminalStates + return self._state in TerminalStates def is_connected(self) -> bool: return self._state == SessionState.ACTIVE @@ -477,6 +495,7 @@ async def send_message( serviceName=service_name, procedureName=procedure_name, ) + logger.debug("SENDING MESSAGE: %r", msg) if span: with use_span(span): @@ -516,17 +535,17 @@ async def close(self) -> None: self._reset_session_close_countdown() await self._task_manager.cancel_all_tasks() - if self._ws_unwrapped: - # The Session isn't guaranteed to live much longer than this close() - # invocation, so let's await this close to avoid dropping the socket. - await self._ws_unwrapped.close() - # TODO: unexpected_close should close stream differently here to # throw exception correctly. for stream in self._streams.values(): stream.close() self._streams.clear() + if self._ws_unwrapped: + # The Session isn't guaranteed to live much longer than this close() + # invocation, so let's await this close to avoid dropping the socket. + await self._ws_unwrapped.close() + self._state = SessionState.CLOSED # Clear the session in transports From d0c083502a129a8d408ab44276b90a7f54d110ed Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 00:08:32 -0700 Subject: [PATCH 049/193] Stripping WIP logging --- src/replit_river/rate_limiter.py | 1 - src/replit_river/v2/client_transport.py | 2 -- src/replit_river/v2/session.py | 16 +++++----------- 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index 87de217e..384288be 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -91,7 +91,6 @@ def has_budget_or_throw( Returns: bool: True if budget is available, False otherwise. 
""" - logger.debug("self.get_budget_consumed(user)=%r < self.options.attempt_budget_capacity=%r", self.get_budget_consumed(user), self.options.attempt_budget_capacity) if self.get_budget_consumed(user) > self.options.attempt_budget_capacity: logger.debug("No retry budget for %s.", user) raise BudgetExhaustedException( diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 7d033bac..98f6a51e 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -50,7 +50,6 @@ def __init__( async def close(self) -> None: self._rate_limiter.close() if self._session: - logger.info(f"start closing session {self._transport_id}") await self._session.close() logger.info(f"Transport closed {self._transport_id}") @@ -60,7 +59,6 @@ async def get_or_create_session(self) -> Session: call ensure_connected on whatever session is active. """ existing_session = self._session - logger.debug(f"if not existing_session={existing_session} or existing_session.is_closed()={existing_session and existing_session.is_closed()}:") if not existing_session or existing_session.is_closed(): logger.info("Creating new session") new_session = Session( diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 7d51adb0..6ea03367 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -11,7 +11,6 @@ Coroutine, Literal, TypeAlias, - cast, ) import nanoid # type: ignore @@ -217,7 +216,6 @@ def get_next_pending() -> TransportMessage | None: # TODO: Just return _ws_unwrapped once we are no longer using the legacy client def get_ws() -> WebSocketCommonProtocol | ClientConnection | None: - logger.debug("get_ws: %r %r", self.is_connected(), self._ws_unwrapped) if self.is_connected(): return self._ws_unwrapped return None @@ -250,7 +248,7 @@ async def ensure_connected[HandshakeMetadata]( logic that actually establishes the connection. 
""" - logger.debug("ensure_connected: %r", self.is_connected()) + logger.debug("ensure_connected: is_connected=%r", self.is_connected()) if self.is_connected(): return @@ -264,9 +262,7 @@ async def ensure_connected[HandshakeMetadata]( ) ) - logger.debug("BEFORE await _do_ensure_connected") await self._connecting_task - logger.debug("AFTER await _do_ensure_connected") async def _do_ensure_connected[HandshakeMetadata]( self, @@ -391,9 +387,7 @@ async def websocket_closed_callback() -> None: rate_limiter.start_restoring_budget(client_id) self._state = SessionState.ACTIVE self._ws_unwrapped = ws - logger.debug("Before notify_all: %r %r %r", self._state, self._ws_unwrapped, self._connection_condition) self._connection_condition.notify_all() - self._connection_condition.release() break except RiverException as e: await ws.close() @@ -422,13 +416,16 @@ async def websocket_closed_callback() -> None: ): self._connecting_task = None + # Release the lock we took earlier so we can use it again in the next + # connection attempt + self._connection_condition.release() + if last_error is not None: raise RiverException( ERROR_HANDSHAKE, f"Failed to create ws after retrying {max_retry} number of times", ) from last_error - logger.debug("EXITING _do_ensure_connected") return True def is_closed(self) -> bool: @@ -495,7 +492,6 @@ async def send_message( serviceName=service_name, procedureName=procedure_name, ) - logger.debug("SENDING MESSAGE: %r", msg) if span: with use_span(span): @@ -514,9 +510,7 @@ async def send_message( self._queue_full_lock.locked() or len(self._send_buffer) >= self._transport_options.buffer_size ): - logger.warning("LOCK ACQUIRED %r", repr(payload)) await self._queue_full_lock.acquire() - logger.warning("LOCK RELEASED %r", repr(payload)) self._send_buffer.append(msg) # Wake up buffered_message_sender self._message_enqueued.release() From 4ed787a97dacf25341e64d563a6bfed8bc35c240 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 00:49:59 -0700 
Subject: [PATCH 050/193] Adding a missing debug log --- src/replit_river/v2/session.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 6ea03367..e67bacab 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -480,6 +480,15 @@ async def send_message( # if the session is not active, we should not do anything if self._state in TerminalStates: return + logger.debug( + "send_message(stream_id=%r, payload=%r, control_flags=%r, " + "service_name=%r, procedure_name=%r)", + stream_id, + payload, + bin(control_flags), + service_name, + procedure_name, + ) msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), From e8c241797083d897b4ce1153c14618aabfdd59ef Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 09:32:34 -0700 Subject: [PATCH 051/193] More logging around message receipt --- src/replit_river/v2/session.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e67bacab..40973586 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -519,6 +519,7 @@ async def send_message( self._queue_full_lock.locked() or len(self._send_buffer) >= self._transport_options.buffer_size ): + logging.debug("send_message: queue full, waiting") await self._queue_full_lock.acquire() self._send_buffer.append(msg) # Wake up buffered_message_sender @@ -587,7 +588,9 @@ async def _serve(self) -> None: ) async def _handle_messages_from_ws(self) -> None: + logging.debug("_handle_messages_from_ws started") while self._ws_unwrapped is None or self._state == SessionState.CONNECTING: + logging.debug("_handle_messages_from_ws started") await asyncio.sleep(1) logger.debug( "%s start handling messages from ws %s", @@ -690,6 +693,7 @@ async def _handle_messages_from_ws(self) -> None: except ConnectionClosed as e: self._state = SessionState.CONNECTING raise e + 
logging.debug("_handle_messages_from_ws exiting") # When the network disconnects this Task exits and then we don't restart it. async def send_rpc[R, A]( self, From 8107ab654bd9dd4de0c91762a8addac044b98887 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 10:20:57 -0700 Subject: [PATCH 052/193] Prevent _handle_messages_from_ws from terminating early --- src/replit_river/v2/session.py | 75 ++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 40973586..70fe0aa3 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -24,6 +24,7 @@ from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.frames import CloseCode from websockets.legacy.protocol import WebSocketCommonProtocol +from websockets.protocol import CONNECTING from replit_river.common_session import ( SessionState, @@ -557,40 +558,60 @@ async def close(self) -> None: await self._close_session_callback(self) async def start_serve_responses(self) -> None: - self._task_manager.create_task(self._serve()) + async def transition_closed() -> None: + self._state = SessionState.CONNECTING + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + + await self._begin_close_session_countdown() + self._task_manager.create_task(self._serve( + get_state=lambda: self._state, + transition_closed=transition_closed, + reset_session_close_countdown=self._reset_session_close_countdown, + )) - async def _serve(self) -> None: + async def _serve( + self, + get_state: Callable[[], SessionState], + transition_closed: Callable[[], Awaitable[None]], + reset_session_close_countdown: Callable[[], None], + ) -> None: """Serve messages from the websocket.""" - self._reset_session_close_countdown() - try: + reset_session_close_countdown() + our_task = asyncio.current_task() + idx = 0 + while our_task and not 
our_task.cancelling() and not our_task.cancelled(): + logging.debug(f"_serve loop count={idx}") + idx += 1 try: - await self._handle_messages_from_ws() - except ConnectionClosed: - # Set ourselves to closed as soon as we get the signal - self._state = SessionState.CONNECTING - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) - - await self._begin_close_session_countdown() - logger.debug("ConnectionClosed while serving", exc_info=True) - except FailedSendingMessageException: - # Expected error if the connection is closed. - logger.debug( - "FailedSendingMessageException while serving", exc_info=True - ) - except Exception: - logger.exception("caught exception at message iterator") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - raise ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) + try: + await self._handle_messages_from_ws() + except ConnectionClosed: + # Set ourselves to closed as soon as we get the signal + await transition_closed() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + # We're in a task, there's not that much that can be done. 
+ unhandled = ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + logger.exception("caught exception at message iterator", exc_info=unhandled) + raise unhandled + logging.debug(f"_serve exiting normally after {idx} loops") async def _handle_messages_from_ws(self) -> None: logging.debug("_handle_messages_from_ws started") while self._ws_unwrapped is None or self._state == SessionState.CONNECTING: - logging.debug("_handle_messages_from_ws started") + logging.debug("_handle_messages_from_ws spinning while connecting") await asyncio.sleep(1) logger.debug( "%s start handling messages from ws %s", From 2bc50c36a64b6f8c4d1c20ef8ee280b7c50cdfec Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:15:55 -0700 Subject: [PATCH 053/193] Break out _serve from Session --- src/replit_river/v2/session.py | 353 +++++++++++++++++++-------------- 1 file changed, 208 insertions(+), 145 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 70fe0aa3..dc2c95bb 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -2,6 +2,7 @@ import logging from collections import deque from collections.abc import AsyncIterable +from dataclasses import dataclass from datetime import timedelta from typing import ( Any, @@ -11,6 +12,7 @@ Coroutine, Literal, TypeAlias, + assert_never, ) import nanoid # type: ignore @@ -22,9 +24,7 @@ from pydantic import ValidationError from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed, ConnectionClosedOK -from websockets.frames import CloseCode from websockets.legacy.protocol import WebSocketCommonProtocol -from websockets.protocol import CONNECTING from replit_river.common_session import ( SessionState, @@ -89,6 +89,11 @@ ] +@dataclass +class _IgnoreMessage: + pass + + class Session: _transport_id: str _to_id: str @@ -558,163 +563,65 @@ async def close(self) -> None: await 
self._close_session_callback(self) async def start_serve_responses(self) -> None: + async def transition_connecting() -> None: + self._state = SessionState.CONNECTING + async def transition_closed() -> None: self._state = SessionState.CONNECTING if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) await self._begin_close_session_countdown() - self._task_manager.create_task(self._serve( - get_state=lambda: self._state, - transition_closed=transition_closed, - reset_session_close_countdown=self._reset_session_close_countdown, - )) - - async def _serve( - self, - get_state: Callable[[], SessionState], - transition_closed: Callable[[], Awaitable[None]], - reset_session_close_countdown: Callable[[], None], - ) -> None: - """Serve messages from the websocket.""" - reset_session_close_countdown() - our_task = asyncio.current_task() - idx = 0 - while our_task and not our_task.cancelling() and not our_task.cancelled(): - logging.debug(f"_serve loop count={idx}") - idx += 1 - try: - try: - await self._handle_messages_from_ws() - except ConnectionClosed: - # Set ourselves to closed as soon as we get the signal - await transition_closed() - logger.debug("ConnectionClosed while serving", exc_info=True) - except FailedSendingMessageException: - # Expected error if the connection is closed. - logger.debug( - "FailedSendingMessageException while serving", exc_info=True - ) - except Exception: - logger.exception("caught exception at message iterator") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - # We're in a task, there's not that much that can be done. 
- unhandled = ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - logger.exception("caught exception at message iterator", exc_info=unhandled) - raise unhandled - logging.debug(f"_serve exiting normally after {idx} loops") - - async def _handle_messages_from_ws(self) -> None: - logging.debug("_handle_messages_from_ws started") - while self._ws_unwrapped is None or self._state == SessionState.CONNECTING: - logging.debug("_handle_messages_from_ws spinning while connecting") - await asyncio.sleep(1) - logger.debug( - "%s start handling messages from ws %s", - "client", - self._ws_unwrapped.id, - ) - try: - # We should not process messages if the websocket is closed. - while ws := self._ws_unwrapped: - # decode=False: Avoiding an unnecessary round-trip through str - # Ideally this should be type-ascripted to : bytes, but there is no - # @overrides in `websockets` to hint this. - message = await ws.recv(decode=False) - try: - msg = parse_transport_msg(message) - - logger.debug(f"{self._transport_id} got a message %r", msg) - # Update bookkeeping - if msg.seq < self.ack: - raise IgnoreMessageException( - f"{msg.from_} received duplicate msg, got {msg.seq}" - f" expected {self.ack}" - ) - elif msg.seq > self.ack: - logger.warning( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - - raise OutOfOrderMessageException( - f"Out of order message received got {msg.seq} expected " - f"{self.ack}" - ) - - assert msg.seq == self.ack, "Safety net, redundant assertion" + def assert_incoming_seq_bookkeeping( + msg_from: str, + msg_seq: int, + msg_ack: int, + ) -> Literal[True] | _IgnoreMessage: + # Update bookkeeping + if msg_seq < self.ack: + logging.info( + f"{msg_from} received duplicate msg, got {msg_seq}" + f" expected {self.ack}" + ) + return _IgnoreMessage() + elif msg_seq > self.ack: + logger.warning( + f"Out of order message received got {msg_seq} expected {self.ack}" + ) - # Set our next expected ack number - 
self.ack = msg.seq + 1 + raise OutOfOrderMessageException( + f"Out of order message received got {msg_seq} expected {self.ack}" + ) - # Discard old server-ack'd messages from the ack buffer - while self._ack_buffer and self._ack_buffer[0].seq < msg.ack: - self._ack_buffer.popleft() + assert msg_seq == self.ack, "Safety net, redundant assertion" - self._reset_session_close_countdown() + # Set our next expected ack number + self.ack = msg_seq + 1 - # Shortcut to avoid processing ack packets - if msg.controlFlags & ACK_BIT != 0: - continue + # Discard old server-ack'd messages from the ack buffer + while self._ack_buffer and self._ack_buffer[0].seq < msg_ack: + self._ack_buffer.popleft() - stream = self._streams.get(msg.streamId, None) - if msg.controlFlags & STREAM_OPEN_BIT != 0: - raise InvalidMessageException( - "Client should not receive stream open bit" - ) + return True - if not stream: - logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException("no stream for message, ignoring") + def close_stream(stream_id: str) -> None: + del self._streams[stream_id] - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - pass - else: - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. 
- pass - except RuntimeError as e: - raise InvalidMessageException(e) from e - - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue - except OutOfOrderMessageException: - logger.exception("Out of order message, closing connection") - self._task_manager.create_task( - self._ws_unwrapped.close( - code=CloseCode.INVALID_DATA, - reason="Out of order message", - ) - ) - return - except InvalidMessageException: - logger.exception("Got invalid transport message, closing session") - await self.close() - return - except ConnectionClosedOK: - # Exited normally - self._state = SessionState.CONNECTING - except ConnectionClosed as e: - self._state = SessionState.CONNECTING - raise e - logging.debug("_handle_messages_from_ws exiting") # When the network disconnects this Task exits and then we don't restart it. + self._task_manager.create_task( + _serve( + self._transport_id, + get_state=lambda: self._state, + get_ws=lambda: self._ws_unwrapped, + transition_connecting=transition_connecting, + transition_closed=transition_closed, + reset_session_close_countdown=self._reset_session_close_countdown, + close_session=self.close, + assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, + get_stream=lambda stream_id: self._streams.get(stream_id), + close_stream=close_stream, + ) + ) async def send_rpc[R, A]( self, @@ -1021,3 +928,159 @@ async def send_close_stream( "type": "CLOSE", }, ) + + +async def _serve( + transport_id: str, + get_state: Callable[[], SessionState], + get_ws: Callable[[], ClientConnection | None], + transition_connecting: Callable[[], Awaitable[None]], + transition_closed: Callable[[], Awaitable[None]], + reset_session_close_countdown: Callable[[], None], + close_session: Callable[[], Awaitable[None]], + assert_incoming_seq_bookkeeping: Callable[ + [str, int, int], Literal[True] | _IgnoreMessage 
+ ], # noqa: E501 + get_stream: Callable[[str], Channel[Any] | None], + close_stream: Callable[[str], None], +) -> None: + """Serve messages from the websocket.""" + reset_session_close_countdown() + our_task = asyncio.current_task() + idx = 0 + while our_task and not our_task.cancelling() and not our_task.cancelled(): + logging.debug(f"_serve loop count={idx}") + idx += 1 + try: + try: + await _handle_messages_from_ws( + transport_id=transport_id, + get_state=get_state, + get_ws=get_ws, + transition_connecting=transition_connecting, + close_session=close_session, + assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, + reset_session_close_countdown=reset_session_close_countdown, + get_stream=get_stream, + close_stream=close_stream, + ) + except ConnectionClosed: + # Set ourselves to closed as soon as we get the signal + await transition_closed() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + # We're in a task, there's not that much that can be done. 
+ unhandled = ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + logger.exception( + "caught exception at message iterator", + exc_info=unhandled, + ) + raise unhandled + logging.debug(f"_serve exiting normally after {idx} loops") + + +async def _handle_messages_from_ws( + transport_id: str, + get_state: Callable[[], SessionState], + get_ws: Callable[[], ClientConnection | None], + transition_connecting: Callable[[], Awaitable[None]], + close_session: Callable[[], Awaitable[None]], + assert_incoming_seq_bookkeeping: Callable[ + [str, int, int], Literal[True] | _IgnoreMessage + ], # noqa: E501 + reset_session_close_countdown: Callable[[], None], + get_stream: Callable[[str], Channel[Any] | None], + close_stream: Callable[[str], None], +) -> None: + logging.debug("_handle_messages_from_ws started") + while (ws := get_ws()) is None or get_state() == SessionState.CONNECTING: + logging.debug("_handle_messages_from_ws spinning while connecting") + await asyncio.sleep(1) + logger.debug( + "%s start handling messages from ws %s", + "client", + ws.id, + ) + try: + # We should not process messages if the websocket is closed. + while ws := get_ws(): + # decode=False: Avoiding an unnecessary round-trip through str + # Ideally this should be type-ascripted to : bytes, but there is no + # @overrides in `websockets` to hint this. 
+ message = await ws.recv(decode=False) + try: + msg = parse_transport_msg(message) + logger.debug("[%s] got a message %r", transport_id, msg) + + if msg.controlFlags & STREAM_OPEN_BIT != 0: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + match assert_incoming_seq_bookkeeping(msg.from_, msg.seq, msg.ack): + case _IgnoreMessage(): + logger.debug("Ignoring transport message", exc_info=True) + continue + case True: + pass + case other: + assert_never(other) + + reset_session_close_countdown() + + # Shortcut to avoid processing ack packets + if msg.controlFlags & ACK_BIT != 0: + continue + + stream = get_stream(msg.streamId) + + if not stream: + logger.warning("no stream for %s, ignoring message", msg.streamId) + continue + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. 
+ pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + close_stream(msg.streamId) + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await close_session() + return + except InvalidMessageException: + logger.exception("Got invalid transport message, closing session") + await close_session() + return + except ConnectionClosedOK: + # Exited normally + transition_connecting() + except ConnectionClosed as e: + transition_connecting() + raise e + logging.debug("_handle_messages_from_ws exiting") From 4fd76e73f9de176983f0f109063f39ae9f04d614 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:18:43 -0700 Subject: [PATCH 054/193] Merging _handle_messages_from_ws --- src/replit_river/v2/session.py | 204 ++++++++++++++++----------------- 1 file changed, 98 insertions(+), 106 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index dc2c95bb..dd55fccf 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -953,17 +953,105 @@ async def _serve( idx += 1 try: try: - await _handle_messages_from_ws( - transport_id=transport_id, - get_state=get_state, - get_ws=get_ws, - transition_connecting=transition_connecting, - close_session=close_session, - assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, - reset_session_close_countdown=reset_session_close_countdown, - get_stream=get_stream, - close_stream=close_stream, + logging.debug("_handle_messages_from_ws started") + while ( + ws := get_ws() + ) is None or get_state() == SessionState.CONNECTING: + logging.debug("_handle_messages_from_ws spinning while connecting") + await asyncio.sleep(1) + logger.debug( + "%s start handling messages from ws %s", + "client", + ws.id, ) + try: + # We should not process messages if the websocket is closed. 
+ while ws := get_ws(): + # decode=False: Avoiding an unnecessary round-trip through str + # Ideally this should be type-ascripted to : bytes, but there + # is no @overrides in `websockets` to hint this. + message = await ws.recv(decode=False) + try: + msg = parse_transport_msg(message) + logger.debug( + "[%s] got a message %r", + transport_id, + msg, + ) + + if msg.controlFlags & STREAM_OPEN_BIT != 0: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + match assert_incoming_seq_bookkeeping( + msg.from_, + msg.seq, + msg.ack, + ): + case _IgnoreMessage(): + logger.debug( + "Ignoring transport message", + exc_info=True, + ) + continue + case True: + pass + case other: + assert_never(other) + + reset_session_close_countdown() + + # Shortcut to avoid processing ack packets + if msg.controlFlags & ACK_BIT != 0: + continue + + stream = get_stream(msg.streamId) + + if not stream: + logger.warning( + "no stream for %s, ignoring message", + msg.streamId, + ) + continue + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. 
+ pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + close_stream(msg.streamId) + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await close_session() + return + except InvalidMessageException: + logger.exception( + "Got invalid transport message, closing session", + ) + await close_session() + return + except ConnectionClosedOK: + # Exited normally + transition_connecting() + except ConnectionClosed as e: + transition_connecting() + raise e + logging.debug("_handle_messages_from_ws exiting") except ConnectionClosed: # Set ourselves to closed as soon as we get the signal await transition_closed() @@ -988,99 +1076,3 @@ async def _serve( ) raise unhandled logging.debug(f"_serve exiting normally after {idx} loops") - - -async def _handle_messages_from_ws( - transport_id: str, - get_state: Callable[[], SessionState], - get_ws: Callable[[], ClientConnection | None], - transition_connecting: Callable[[], Awaitable[None]], - close_session: Callable[[], Awaitable[None]], - assert_incoming_seq_bookkeeping: Callable[ - [str, int, int], Literal[True] | _IgnoreMessage - ], # noqa: E501 - reset_session_close_countdown: Callable[[], None], - get_stream: Callable[[str], Channel[Any] | None], - close_stream: Callable[[str], None], -) -> None: - logging.debug("_handle_messages_from_ws started") - while (ws := get_ws()) is None or get_state() == SessionState.CONNECTING: - logging.debug("_handle_messages_from_ws spinning while connecting") - await asyncio.sleep(1) - logger.debug( - "%s start handling messages from ws %s", - "client", - ws.id, - ) - try: - # We should not process messages if the websocket is closed. - while ws := get_ws(): - # decode=False: Avoiding an unnecessary round-trip through str - # Ideally this should be type-ascripted to : bytes, but there is no - # @overrides in `websockets` to hint this. 
- message = await ws.recv(decode=False) - try: - msg = parse_transport_msg(message) - logger.debug("[%s] got a message %r", transport_id, msg) - - if msg.controlFlags & STREAM_OPEN_BIT != 0: - raise InvalidMessageException( - "Client should not receive stream open bit" - ) - - match assert_incoming_seq_bookkeeping(msg.from_, msg.seq, msg.ack): - case _IgnoreMessage(): - logger.debug("Ignoring transport message", exc_info=True) - continue - case True: - pass - case other: - assert_never(other) - - reset_session_close_countdown() - - # Shortcut to avoid processing ack packets - if msg.controlFlags & ACK_BIT != 0: - continue - - stream = get_stream(msg.streamId) - - if not stream: - logger.warning("no stream for %s, ignoring message", msg.streamId) - continue - - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - pass - else: - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. 
- pass - except RuntimeError as e: - raise InvalidMessageException(e) from e - - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - close_stream(msg.streamId) - except OutOfOrderMessageException: - logger.exception("Out of order message, closing connection") - await close_session() - return - except InvalidMessageException: - logger.exception("Got invalid transport message, closing session") - await close_session() - return - except ConnectionClosedOK: - # Exited normally - transition_connecting() - except ConnectionClosed as e: - transition_connecting() - raise e - logging.debug("_handle_messages_from_ws exiting") From 0be92676717b693f06bbb598cdbb8d559cfb5c2a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:35:08 -0700 Subject: [PATCH 055/193] Flattening try:while:try: --- src/replit_river/v2/session.py | 212 ++++++++++++++++----------------- 1 file changed, 105 insertions(+), 107 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index dd55fccf..1aaa97c3 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -566,7 +566,7 @@ async def start_serve_responses(self) -> None: async def transition_connecting() -> None: self._state = SessionState.CONNECTING - async def transition_closed() -> None: + async def connection_interrupted() -> None: self._state = SessionState.CONNECTING if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -614,7 +614,7 @@ def close_stream(stream_id: str) -> None: get_state=lambda: self._state, get_ws=lambda: self._ws_unwrapped, transition_connecting=transition_connecting, - transition_closed=transition_closed, + connection_interrupted=connection_interrupted, reset_session_close_countdown=self._reset_session_close_countdown, close_session=self.close, assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, @@ -935,7 +935,7 @@ async def _serve( get_state: Callable[[], 
SessionState], get_ws: Callable[[], ClientConnection | None], transition_connecting: Callable[[], Awaitable[None]], - transition_closed: Callable[[], Awaitable[None]], + connection_interrupted: Callable[[], Awaitable[None]], reset_session_close_countdown: Callable[[], None], close_session: Callable[[], Awaitable[None]], assert_incoming_seq_bookkeeping: Callable[ @@ -952,117 +952,115 @@ async def _serve( logging.debug(f"_serve loop count={idx}") idx += 1 try: - try: - logging.debug("_handle_messages_from_ws started") - while ( - ws := get_ws() - ) is None or get_state() == SessionState.CONNECTING: - logging.debug("_handle_messages_from_ws spinning while connecting") - await asyncio.sleep(1) - logger.debug( - "%s start handling messages from ws %s", - "client", - ws.id, - ) + logging.debug("_handle_messages_from_ws started") + while (ws := get_ws()) is None or get_state() == SessionState.CONNECTING: + logging.debug("_handle_messages_from_ws spinning while connecting") + await asyncio.sleep(1) + logger.debug( + "%s start handling messages from ws %s", + "client", + ws.id, + ) + # We should not process messages if the websocket is closed. + while (ws := get_ws()) and get_state() == SessionState.ACTIVE: + # decode=False: Avoiding an unnecessary round-trip through str + # Ideally this should be type-ascripted to : bytes, but there + # is no @overrides in `websockets` to hint this. + message = await ws.recv(decode=False) try: - # We should not process messages if the websocket is closed. - while ws := get_ws(): - # decode=False: Avoiding an unnecessary round-trip through str - # Ideally this should be type-ascripted to : bytes, but there - # is no @overrides in `websockets` to hint this. 
- message = await ws.recv(decode=False) - try: - msg = parse_transport_msg(message) + msg = parse_transport_msg(message) + logger.debug( + "[%s] got a message %r", + transport_id, + msg, + ) + + if msg.controlFlags & STREAM_OPEN_BIT != 0: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + match assert_incoming_seq_bookkeeping( + msg.from_, + msg.seq, + msg.ack, + ): + case _IgnoreMessage(): logger.debug( - "[%s] got a message %r", - transport_id, - msg, + "Ignoring transport message", + exc_info=True, ) + continue + case True: + pass + case other: + assert_never(other) - if msg.controlFlags & STREAM_OPEN_BIT != 0: - raise InvalidMessageException( - "Client should not receive stream open bit" - ) - - match assert_incoming_seq_bookkeeping( - msg.from_, - msg.seq, - msg.ack, - ): - case _IgnoreMessage(): - logger.debug( - "Ignoring transport message", - exc_info=True, - ) - continue - case True: - pass - case other: - assert_never(other) - - reset_session_close_countdown() - - # Shortcut to avoid processing ack packets - if msg.controlFlags & ACK_BIT != 0: - continue - - stream = get_stream(msg.streamId) - - if not stream: - logger.warning( - "no stream for %s, ignoring message", - msg.streamId, - ) - continue - - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - pass - else: - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. 
- pass - except RuntimeError as e: - raise InvalidMessageException(e) from e - - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - close_stream(msg.streamId) - except OutOfOrderMessageException: - logger.exception("Out of order message, closing connection") - await close_session() - return - except InvalidMessageException: - logger.exception( - "Got invalid transport message, closing session", - ) - await close_session() - return + reset_session_close_countdown() + + # Shortcut to avoid processing ack packets + if msg.controlFlags & ACK_BIT != 0: + continue + + stream = get_stream(msg.streamId) + + if not stream: + logger.warning( + "no stream for %s, ignoring message", + msg.streamId, + ) + continue + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + pass + else: + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: + raise InvalidMessageException(e) from e + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + close_stream(msg.streamId) + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await close_session() + continue + except InvalidMessageException: + logger.exception( + "Got invalid transport message, closing session", + ) + await close_session() + continue except ConnectionClosedOK: # Exited normally transition_connecting() - except ConnectionClosed as e: - transition_connecting() - raise e - logging.debug("_handle_messages_from_ws exiting") - except ConnectionClosed: - # Set ourselves to closed as soon as we get the signal - await transition_closed() - logger.debug("ConnectionClosed while serving", exc_info=True) - except FailedSendingMessageException: - # Expected error if the connection is closed. 
- logger.debug( - "FailedSendingMessageException while serving", exc_info=True - ) - except Exception: - logger.exception("caught exception at message iterator") + break + except ConnectionClosed: + # Set ourselves to closed as soon as we get the signal + await connection_interrupted() + logger.debug("ConnectionClosed while serving", exc_info=True) + break + except FailedSendingMessageException: + # Expected error if the connection is closed. + await connection_interrupted() + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + break + except Exception: + logger.exception("caught exception at message iterator") + break + logging.debug("_handle_messages_from_ws exiting") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) if unhandled: From cb216643ead7fda812243f1c160c26c978913987 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:35:18 -0700 Subject: [PATCH 056/193] Colocating functions from common_session --- src/replit_river/v2/session.py | 141 +++++++++++++++++++++++++++++++-- 1 file changed, 135 insertions(+), 6 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 1aaa97c3..5a804474 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -27,11 +27,9 @@ from websockets.legacy.protocol import WebSocketCommonProtocol from replit_river.common_session import ( + SendMessage, SessionState, TerminalStates, - buffered_message_sender, - check_to_close_session, - setup_heartbeat, ) from replit_river.error_schema import ( ERROR_CODE_CANCEL, @@ -183,7 +181,7 @@ def increment_and_get_heartbeat_misses() -> int: return self._heartbeat_misses self._task_manager.create_task( - setup_heartbeat( + _setup_heartbeat( self.session_id, self._transport_options.heartbeat_ms, self._transport_options.heartbeats_until_dead, @@ -195,7 +193,7 @@ def increment_and_get_heartbeat_misses() -> int: ) ) 
self._task_manager.create_task( - check_to_close_session( + _check_to_close_session( self._transport_id, self._transport_options.close_session_check_interval_ms, lambda: self._state, @@ -227,7 +225,7 @@ def get_ws() -> WebSocketCommonProtocol | ClientConnection | None: return None self._task_manager.create_task( - buffered_message_sender( + _buffered_message_sender( self._connection_condition, self._message_enqueued, get_ws=get_ws, @@ -930,6 +928,137 @@ async def send_close_stream( ) +async def _check_to_close_session( + transport_id: str, + close_session_check_interval_ms: float, + get_state: Callable[[], SessionState], + get_current_time: Callable[[], Awaitable[float]], + get_close_session_after_time_secs: Callable[[], float | None], + do_close: Callable[[], Awaitable[None]], +) -> None: + our_task = asyncio.current_task() + while our_task and not our_task.cancelling() and not our_task.cancelled(): + await asyncio.sleep(close_session_check_interval_ms / 1000) + if get_state() in TerminalStates: + # already closing + return + # calculate the value now before comparing it so that there are no + # await points between the check and the comparison to avoid a TOCTOU + # race. 
+ current_time = await get_current_time() + close_session_after_time_secs = get_close_session_after_time_secs() + if not close_session_after_time_secs: + continue + if current_time > close_session_after_time_secs: + logger.info("Grace period ended for %s, closing session", transport_id) + await do_close() + return + + +async def _buffered_message_sender( + connection_condition: asyncio.Condition, + message_enqueued: asyncio.Semaphore, + get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None], + websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], + get_next_pending: Callable[[], TransportMessage | None], + commit: Callable[[TransportMessage], None], +) -> None: + while True: + await message_enqueued.acquire() + while (ws := get_ws()) is None: + # Block until we have a handle + logger.debug( + "buffered_message_sender: Waiting until ws is connected (condition=%r)", + connection_condition, + ) + async with connection_condition: + await connection_condition.wait() + if msg := get_next_pending(): + logger.debug( + "buffered_message_sender: Dequeued %r to send over %r", + msg, + ws, + ) + try: + await send_transport_message(msg, ws, websocket_closed_callback) + commit(msg) + except WebsocketClosedException as e: + logger.debug( + "Connection closed while sending message %r, waiting for " + "retry from buffer", + type(e), + exc_info=e, + ) + message_enqueued.release() + break + except FailedSendingMessageException: + logger.error( + "Failed sending message, waiting for retry from buffer", + exc_info=True, + ) + message_enqueued.release() + break + except Exception: + logger.exception("Error attempting to send buffered messages") + message_enqueued.release() + break + + +async def _setup_heartbeat( + session_id: str, + heartbeat_ms: float, + heartbeats_until_dead: int, + get_state: Callable[[], SessionState], + get_closing_grace_period: Callable[[], float | None], + close_websocket: Callable[[], Awaitable[None]], + send_message: SendMessage, 
+ increment_and_get_heartbeat_misses: Callable[[], int], +) -> None: + while True: + await asyncio.sleep(heartbeat_ms / 1000) + state = get_state() + if state == SessionState.CONNECTING: + logger.debug("Websocket is not connected, not sending heartbeat") + continue + if state in TerminalStates: + logger.debug( + "Session is closed, no need to send heartbeat, state : " + "%r close_session_after_this: %r", + {state}, + {get_closing_grace_period()}, + ) + # session is closing / closed, no need to send heartbeat anymore + return + try: + await send_message( + stream_id="heartbeat", + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 + payload={ + "type": "ACK", + "ack": 0, + }, + control_flags=ACK_BIT, + procedure_name=None, + service_name=None, + span=None, + ) + + if increment_and_get_heartbeat_misses() > heartbeats_until_dead: + if get_closing_grace_period() is not None: + # already in grace period, no need to set again + continue + logger.info( + "%r closing websocket because of heartbeat misses", + session_id, + ) + await close_websocket() + continue + except FailedSendingMessageException: + # this is expected during websocket closed period + continue + + async def _serve( transport_id: str, get_state: Callable[[], SessionState], From bb9edda0d120e961f90b032ce3afdddcac2d5af5 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:40:07 -0700 Subject: [PATCH 057/193] Flip while:try: to try:while: --- src/replit_river/v2/session.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 5a804474..19396145 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1077,11 +1077,10 @@ async def _serve( reset_session_close_countdown() our_task = asyncio.current_task() idx = 0 - while our_task and not our_task.cancelling() and 
not our_task.cancelled(): - logging.debug(f"_serve loop count={idx}") - idx += 1 - try: - logging.debug("_handle_messages_from_ws started") + try: + while our_task and not our_task.cancelling() and not our_task.cancelled(): + logging.debug(f"_serve loop count={idx}") + idx += 1 while (ws := get_ws()) is None or get_state() == SessionState.CONNECTING: logging.debug("_handle_messages_from_ws spinning while connecting") await asyncio.sleep(1) @@ -1190,16 +1189,16 @@ async def _serve( logger.exception("caught exception at message iterator") break logging.debug("_handle_messages_from_ws exiting") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - # We're in a task, there's not that much that can be done. - unhandled = ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - logger.exception( - "caught exception at message iterator", - exc_info=unhandled, - ) - raise unhandled + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + # We're in a task, there's not that much that can be done. 
+ unhandled = ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + logger.exception( + "caught exception at message iterator", + exc_info=unhandled, + ) + raise unhandled logging.debug(f"_serve exiting normally after {idx} loops") From 551f908af41e422d8e0d9be923a4844ea401f865 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:59:02 -0700 Subject: [PATCH 058/193] Compartmentalize initializers --- src/replit_river/v2/client_transport.py | 1 - src/replit_river/v2/session.py | 159 +++++++++++++----------- 2 files changed, 85 insertions(+), 75 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 98f6a51e..a9378f92 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -72,7 +72,6 @@ async def get_or_create_session(self) -> Session: self._session = new_session existing_session = new_session - await existing_session.start_serve_responses() await existing_session.ensure_connected( client_id=self._client_id, diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 19396145..6f6946fc 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -161,79 +161,10 @@ def __init__( self.ack = 0 self.seq = 0 - async def do_close_websocket() -> None: - logger.debug( - "do_close called, _state=%r, _ws_unwrapped=%r", - self._state, - self._ws_unwrapped, - ) - if self._ws_unwrapped: - self._task_manager.create_task(self._ws_unwrapped.close()) - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) - self._ws_unwrapped = None - else: - self._state = SessionState.CLOSING - await self._begin_close_session_countdown() - - def increment_and_get_heartbeat_misses() -> int: - self._heartbeat_misses += 1 - return self._heartbeat_misses - - self._task_manager.create_task( - _setup_heartbeat( - self.session_id, - 
self._transport_options.heartbeat_ms, - self._transport_options.heartbeats_until_dead, - lambda: self._state, - lambda: self._close_session_after_time_secs, - close_websocket=do_close_websocket, - send_message=self.send_message, - increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, - ) - ) - self._task_manager.create_task( - _check_to_close_session( - self._transport_id, - self._transport_options.close_session_check_interval_ms, - lambda: self._state, - self._get_current_time, - lambda: self._close_session_after_time_secs, - self.close, - ) - ) - - def commit(msg: TransportMessage) -> None: - pending = self._send_buffer.popleft() - if msg.seq != pending.seq: - logger.error("Out of sequence error") - self._ack_buffer.append(pending) - - # On commit, release pending writers waiting for more buffer space - if self._queue_full_lock.locked(): - self._queue_full_lock.release() - - def get_next_pending() -> TransportMessage | None: - if self._send_buffer: - return self._send_buffer[0] - return None - - # TODO: Just return _ws_unwrapped once we are no longer using the legacy client - def get_ws() -> WebSocketCommonProtocol | ClientConnection | None: - if self.is_connected(): - return self._ws_unwrapped - return None - - self._task_manager.create_task( - _buffered_message_sender( - self._connection_condition, - self._message_enqueued, - get_ws=get_ws, - websocket_closed_callback=self._begin_close_session_countdown, - get_next_pending=get_next_pending, - commit=commit, - ) - ) + self._start_heartbeat() + self._start_serve_responses() + self._start_close_session_checker() + self._start_buffered_message_sender() async def ensure_connected[HandshakeMetadata]( self, @@ -560,7 +491,87 @@ async def close(self) -> None: # This will get us GC'd, so this should be the last thing. 
await self._close_session_callback(self) - async def start_serve_responses(self) -> None: + def _start_buffered_message_sender(self) -> None: + def commit(msg: TransportMessage) -> None: + pending = self._send_buffer.popleft() + if msg.seq != pending.seq: + logger.error("Out of sequence error") + self._ack_buffer.append(pending) + + # On commit, release pending writers waiting for more buffer space + if self._queue_full_lock.locked(): + self._queue_full_lock.release() + + def get_next_pending() -> TransportMessage | None: + if self._send_buffer: + return self._send_buffer[0] + return None + + # TODO: Just return _ws_unwrapped once we are no longer using the legacy client + def get_ws() -> WebSocketCommonProtocol | ClientConnection | None: + if self.is_connected(): + return self._ws_unwrapped + return None + + self._task_manager.create_task( + _buffered_message_sender( + self._connection_condition, + self._message_enqueued, + get_ws=get_ws, + websocket_closed_callback=self._begin_close_session_countdown, + get_next_pending=get_next_pending, + commit=commit, + ) + ) + + + def _start_close_session_checker(self) -> None: + self._task_manager.create_task( + _check_to_close_session( + self._transport_id, + self._transport_options.close_session_check_interval_ms, + lambda: self._state, + self._get_current_time, + lambda: self._close_session_after_time_secs, + self.close, + ) + ) + + + def _start_heartbeat(self) -> None: + async def do_close_websocket() -> None: + logger.debug( + "do_close called, _state=%r, _ws_unwrapped=%r", + self._state, + self._ws_unwrapped, + ) + if self._ws_unwrapped: + self._task_manager.create_task(self._ws_unwrapped.close()) + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + self._ws_unwrapped = None + else: + self._state = SessionState.CLOSING + await self._begin_close_session_countdown() + + def increment_and_get_heartbeat_misses() -> int: + self._heartbeat_misses += 1 + return 
self._heartbeat_misses + + self._task_manager.create_task( + _setup_heartbeat( + self.session_id, + self._transport_options.heartbeat_ms, + self._transport_options.heartbeats_until_dead, + lambda: self._state, + lambda: self._close_session_after_time_secs, + close_websocket=do_close_websocket, + send_message=self.send_message, + increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, + ) + ) + + def _start_serve_responses(self) -> None: async def transition_connecting() -> None: self._state = SessionState.CONNECTING From 49ae2a072dfa2c4db2305eca4fce9e4680ee7c88 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:59:28 -0700 Subject: [PATCH 059/193] Boom --- src/replit_river/v2/session.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 6f6946fc..8251cfe4 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -507,8 +507,7 @@ def get_next_pending() -> TransportMessage | None: return self._send_buffer[0] return None - # TODO: Just return _ws_unwrapped once we are no longer using the legacy client - def get_ws() -> WebSocketCommonProtocol | ClientConnection | None: + def get_ws() -> ClientConnection | None: if self.is_connected(): return self._ws_unwrapped return None @@ -969,7 +968,7 @@ async def _check_to_close_session( async def _buffered_message_sender( connection_condition: asyncio.Condition, message_enqueued: asyncio.Semaphore, - get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None], + get_ws: Callable[[], ClientConnection | None], websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], get_next_pending: Callable[[], TransportMessage | None], commit: Callable[[TransportMessage], None], From 877ea1296eb7d8ff21482833c758d72c58258775 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 11:59:42 -0700 Subject: [PATCH 060/193] Deprecated --- 
src/replit_river/v2/session.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 8251cfe4..8889fc9e 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -24,7 +24,6 @@ from pydantic import ValidationError from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed, ConnectionClosedOK -from websockets.legacy.protocol import WebSocketCommonProtocol from replit_river.common_session import ( SendMessage, @@ -523,7 +522,6 @@ def get_ws() -> ClientConnection | None: ) ) - def _start_close_session_checker(self) -> None: self._task_manager.create_task( _check_to_close_session( @@ -536,7 +534,6 @@ def _start_close_session_checker(self) -> None: ) ) - def _start_heartbeat(self) -> None: async def do_close_websocket() -> None: logger.debug( @@ -889,7 +886,12 @@ async def _encode_stream() -> None: control_flags=0, payload=request_serializer(item), ) - await self.send_close_stream(service_name, procedure_name, stream_id) + await self.send_close_stream( + service_name, + procedure_name, + stream_id, + extra_control_flags=0, + ) self._task_manager.create_task(_encode_stream()) @@ -924,7 +926,7 @@ async def send_close_stream( service_name: str, procedure_name: str, stream_id: str, - extra_control_flags: int = 0, + extra_control_flags: int, ) -> None: # close stream await self.send_message( From 3e691bebb589e517a07ce88dc779b8dd6c57f399 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 12:03:39 -0700 Subject: [PATCH 061/193] try:try: --- src/replit_river/v2/session.py | 104 +++++++++++++++------------------ 1 file changed, 47 insertions(+), 57 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 8889fc9e..80469ed1 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -656,41 +656,36 @@ async def send_rpc[R, A]( ) # 
Handle potential errors during communication try: + async with asyncio.timeout(timeout.total_seconds()): + response = await output.get() + except asyncio.TimeoutError as e: + await self.send_message( + stream_id=stream_id, + control_flags=STREAM_CANCEL_BIT, + payload={"type": "CANCEL"}, + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + raise RiverException(ERROR_CODE_CANCEL, str(e)) from e + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if not response.get("ok", False): try: - async with asyncio.timeout(timeout.total_seconds()): - response = await output.get() - except asyncio.TimeoutError as e: - await self.send_message( - stream_id=stream_id, - control_flags=STREAM_CANCEL_BIT, - payload={"type": "CANCEL"}, - service_name=service_name, - procedure_name=procedure_name, - span=span, - ) - raise RiverException(ERROR_CODE_CANCEL, str(e)) from e - except ChannelClosed as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except RuntimeError as e: - raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): - try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) from e - raise exception_from_message(error.code)( - error.code, error.message, service_name, procedure_name - ) - return response_deserializer(response["payload"]) - except RiverException as e: - raise e - except Exception as e: - raise e + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, 
procedure_name + ) + return response_deserializer(response["payload"]) async def send_upload[I, R, A]( self, @@ -751,31 +746,26 @@ async def send_upload[I, R, A]( # Handle potential errors during communication # TODO: throw a error when the transport is hard closed try: + response = await output.get() + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if not response.get("ok", False): try: - response = await output.get() - except ChannelClosed as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except RuntimeError as e: - raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): - try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) from e - raise exception_from_message(error.code)( - error.code, error.message, service_name, procedure_name - ) + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) - return response_deserializer(response["payload"]) - except RiverException as e: - raise e - except Exception as e: - raise e + return response_deserializer(response["payload"]) async def send_subscription[R, E, A]( self, From 85b778a19280bd9c22b02809b3646481a10adda5 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 12:18:32 -0700 Subject: [PATCH 062/193] Fix state transition logic --- src/replit_river/v2/session.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/session.py 
b/src/replit_river/v2/session.py index 80469ed1..819d5f5f 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -470,7 +470,6 @@ async def close(self) -> None: # already closing return self._state = SessionState.CLOSING - self._reset_session_close_countdown() await self._task_manager.cancel_all_tasks() # TODO: unexpected_close should close stream differently here to @@ -535,7 +534,7 @@ def _start_close_session_checker(self) -> None: ) def _start_heartbeat(self) -> None: - async def do_close_websocket() -> None: + async def close_websocket() -> None: logger.debug( "do_close called, _state=%r, _ws_unwrapped=%r", self._state, @@ -543,11 +542,13 @@ async def do_close_websocket() -> None: ) if self._ws_unwrapped: self._task_manager.create_task(self._ws_unwrapped.close()) - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) self._ws_unwrapped = None + + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) else: self._state = SessionState.CLOSING + await self._begin_close_session_countdown() def increment_and_get_heartbeat_misses() -> int: @@ -561,7 +562,7 @@ def increment_and_get_heartbeat_misses() -> int: self._transport_options.heartbeats_until_dead, lambda: self._state, lambda: self._close_session_after_time_secs, - close_websocket=do_close_websocket, + close_websocket=close_websocket, send_message=self.send_message, increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, ) @@ -573,6 +574,10 @@ async def transition_connecting() -> None: async def connection_interrupted() -> None: self._state = SessionState.CONNECTING + if self._ws_unwrapped: + self._task_manager.create_task(self._ws_unwrapped.close()) + self._ws_unwrapped = None + if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) From df36d3522b5d8c8faa97f1ad1047c597bde82378 Mon Sep 17 00:00:00 2001 From: Devon Stewart 
Date: Tue, 25 Mar 2025 12:21:45 -0700 Subject: [PATCH 063/193] REVERTME --- src/replit_river/v2/session.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 819d5f5f..335fb228 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1104,11 +1104,11 @@ async def _serve( message = await ws.recv(decode=False) try: msg = parse_transport_msg(message) - logger.debug( - "[%s] got a message %r", - transport_id, - msg, - ) + # logger.debug( + # "[%s] got a message %r", + # transport_id, + # msg, + # ) if msg.controlFlags & STREAM_OPEN_BIT != 0: raise InvalidMessageException( @@ -1131,6 +1131,13 @@ async def _serve( case other: assert_never(other) + # TODO: Delete me + logger.debug( + "[%s] got a message %r", + transport_id, + msg, + ) + reset_session_close_countdown() # Shortcut to avoid processing ack packets From 0184369810f4de26741bacfe1bd511797c618af2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 12:57:31 -0700 Subject: [PATCH 064/193] Missing state transition --- src/replit_river/v2/session.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 335fb228..579efd5c 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -392,6 +392,7 @@ async def _begin_close_session_countdown(self) -> None: self._transport_id, self._to_id, ) + self._state = SessionState.CONNECTING self._close_session_after_time_secs = close_session_after_time_secs async def _get_current_time(self) -> float: @@ -956,6 +957,9 @@ async def _check_to_close_session( close_session_after_time_secs = get_close_session_after_time_secs() if not close_session_after_time_secs: continue + logging.debug( + "_check_to_close_session: Preparing to close session if not interrupted" + ) if current_time > close_session_after_time_secs: logger.info("Grace period ended for 
%s, closing session", transport_id) await do_close() From 3ddff3405b68339d0ade37ecd72481a67c838dce Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 13:01:47 -0700 Subject: [PATCH 065/193] We have a new enumeration state for this! --- src/replit_river/v2/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 579efd5c..e7588970 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -392,7 +392,7 @@ async def _begin_close_session_countdown(self) -> None: self._transport_id, self._to_id, ) - self._state = SessionState.CONNECTING + self._state = SessionState.PENDING self._close_session_after_time_secs = close_session_after_time_secs async def _get_current_time(self) -> float: @@ -574,7 +574,7 @@ async def transition_connecting() -> None: self._state = SessionState.CONNECTING async def connection_interrupted() -> None: - self._state = SessionState.CONNECTING + self._state = SessionState.PENDING if self._ws_unwrapped: self._task_manager.create_task(self._ws_unwrapped.close()) self._ws_unwrapped = None From eea25bd7ad68cdfbe408241bc7b46c18f6a5776a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 13:06:13 -0700 Subject: [PATCH 066/193] Gotcha --- src/replit_river/v2/session.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e7588970..d500b500 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -226,10 +226,9 @@ async def _do_ensure_connected[HandshakeMetadata]( try: try: - expectedSessionState = ExpectedSessionState( - nextExpectedSeq=self.ack, - nextSentSeq=self.seq, - ) + next_seq = 0 + if self._send_buffer: + next_seq = self._send_buffer[0].seq handshake_request = ControlMessageHandshakeRequest[ HandshakeMetadata ]( # noqa: E501 @@ -237,7 +236,10 @@ async def 
_do_ensure_connected[HandshakeMetadata]( protocolVersion=protocol_version, sessionId=self.session_id, metadata=uri_and_metadata["metadata"], - expectedSessionState=expectedSessionState, + expectedSessionState=ExpectedSessionState( + nextExpectedSeq=self.ack, + nextSentSeq=next_seq, + ), ) stream_id = nanoid.generate() From 8ddf5915549c3c6e26658fd22f20bd3738b2fe0c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 15:36:26 -0700 Subject: [PATCH 067/193] Readability reordering --- src/replit_river/v2/client_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index a9378f92..2a751ce3 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -82,7 +82,7 @@ async def get_or_create_session(self) -> Session: return existing_session async def _retry_connection(self) -> Session: - if not self._transport_options.transparent_reconnect and self._session: + if self._session and not self._transport_options.transparent_reconnect: logger.info("transparent_reconnect not set, closing {self._transport_id}") await self._session.close() return await self.get_or_create_session() From e7bd18be68f693ae32174e0f7ca3a2f713fc712c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 15:36:47 -0700 Subject: [PATCH 068/193] c-style strings work with format strings as well --- src/replit_river/v2/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index d500b500..28d1422f 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -316,7 +316,7 @@ async def websocket_closed_callback() -> None: raise RiverException( ERROR_HANDSHAKE, f"Handshake failed with code {handshake_response.status.code}: " # noqa: E501 - + f"{handshake_response.status.reason}", + f"{handshake_response.status.reason}", ) last_error 
= None From b21a99bde0c3dd4ac85343a400ce7900d42dcc79 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 15:37:29 -0700 Subject: [PATCH 069/193] Wake up tasks pending connected state so they can exit cleanly --- src/replit_river/v2/session.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 28d1422f..38000a55 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -473,6 +473,13 @@ async def close(self) -> None: # already closing return self._state = SessionState.CLOSING + + # We need to wake up all tasks waiting for connection to be established + assert not self._connection_condition.locked() + await self._connection_condition.acquire() + self._connection_condition.notify_all() + self._connection_condition.release() + await self._task_manager.cancel_all_tasks() # TODO: unexpected_close should close stream differently here to From 7d4912f9a2049599f68a14c756c155b2d984c23d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 15:38:28 -0700 Subject: [PATCH 070/193] Migrating all background tasks to just use block_until_connected --- src/replit_river/v2/session.py | 63 +++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 38000a55..eabc5f84 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -26,6 +26,7 @@ from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from replit_river.common_session import ( + ConnectingStates, SendMessage, SessionState, TerminalStates, @@ -520,14 +521,19 @@ def get_ws() -> ClientConnection | None: return self._ws_unwrapped return None + async def block_until_connected() -> None: + async with self._connection_condition: + await self._connection_condition.wait() + self._task_manager.create_task( _buffered_message_sender( - self._connection_condition, 
- self._message_enqueued, + block_until_connected=block_until_connected, + message_enqueued=self._message_enqueued, get_ws=get_ws, websocket_closed_callback=self._begin_close_session_countdown, get_next_pending=get_next_pending, commit=commit, + get_state=lambda: self._state, ) ) @@ -565,8 +571,13 @@ def increment_and_get_heartbeat_misses() -> int: self._heartbeat_misses += 1 return self._heartbeat_misses + async def block_until_connected() -> None: + async with self._connection_condition: + await self._connection_condition.wait() + self._task_manager.create_task( _setup_heartbeat( + block_until_connected, self.session_id, self._transport_options.heartbeat_ms, self._transport_options.heartbeats_until_dead, @@ -628,9 +639,15 @@ def assert_incoming_seq_bookkeeping( def close_stream(stream_id: str) -> None: del self._streams[stream_id] + async def block_until_connected() -> None: + async with self._connection_condition: + await self._connection_condition.wait() + + self._task_manager.create_task( _serve( - self._transport_id, + block_until_connected=block_until_connected, + transport_id=self._transport_id, get_state=lambda: self._state, get_ws=lambda: self._ws_unwrapped, transition_connecting=transition_connecting, @@ -976,23 +993,32 @@ async def _check_to_close_session( async def _buffered_message_sender( - connection_condition: asyncio.Condition, + block_until_connected: Callable[[], Awaitable[None]], message_enqueued: asyncio.Semaphore, get_ws: Callable[[], ClientConnection | None], websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], get_next_pending: Callable[[], TransportMessage | None], commit: Callable[[TransportMessage], None], + get_state: Callable[[], SessionState], ) -> None: - while True: + our_task = asyncio.current_task() + while our_task and not our_task.cancelling() and not our_task.cancelled(): await message_enqueued.acquire() while (ws := get_ws()) is None: # Block until we have a handle logger.debug( - "buffered_message_sender: 
Waiting until ws is connected (condition=%r)", - connection_condition, + "buffered_message_sender: Waiting until ws is connected", ) - async with connection_condition: - await connection_condition.wait() + await block_until_connected() + + if get_state() in TerminalStates: + logger.debug("We're going away!") + return + + if not ws: + logger.debug("ws is not connected, loop") + continue + if msg := get_next_pending(): logger.debug( "buffered_message_sender: Dequeued %r to send over %r", @@ -1025,6 +1051,7 @@ async def _buffered_message_sender( async def _setup_heartbeat( + block_until_connected: Callable[[], Awaitable[None]], session_id: str, heartbeat_ms: float, heartbeats_until_dead: int, @@ -1035,11 +1062,8 @@ async def _setup_heartbeat( increment_and_get_heartbeat_misses: Callable[[], int], ) -> None: while True: - await asyncio.sleep(heartbeat_ms / 1000) - state = get_state() - if state == SessionState.CONNECTING: - logger.debug("Websocket is not connected, not sending heartbeat") - continue + while (state := get_state()) in ConnectingStates: + await block_until_connected() if state in TerminalStates: logger.debug( "Session is closed, no need to send heartbeat, state : " @@ -1048,7 +1072,13 @@ async def _setup_heartbeat( {get_closing_grace_period()}, ) # session is closing / closed, no need to send heartbeat anymore - return + break + + await asyncio.sleep(heartbeat_ms / 1000) + state = get_state() + if state == SessionState.CONNECTING: + logger.debug("Websocket is not connected, not sending heartbeat") + continue try: await send_message( stream_id="heartbeat", @@ -1080,6 +1110,7 @@ async def _setup_heartbeat( async def _serve( + block_until_connected: Callable[[], Awaitable[None]], transport_id: str, get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], @@ -1103,7 +1134,7 @@ async def _serve( idx += 1 while (ws := get_ws()) is None or get_state() == SessionState.CONNECTING: logging.debug("_handle_messages_from_ws spinning while 
connecting") - await asyncio.sleep(1) + await block_until_connected() logger.debug( "%s start handling messages from ws %s", "client", From ee8d0652e2009a1e44de6df741aac931fecc3bf3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 15:39:25 -0700 Subject: [PATCH 071/193] Redundant --- src/replit_river/v2/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index eabc5f84..34d79216 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -828,7 +828,7 @@ async def send_subscription[R, E, A]( # Handle potential errors during communication try: async for item in output: - if item.get("type", None) == "CLOSE": + if item.get("type") == "CLOSE": break if not item.get("ok", False): try: @@ -920,7 +920,7 @@ async def _encode_stream() -> None: # Handle potential errors during communication try: async for item in output: - if "type" in item and item["type"] == "CLOSE": + if item.get("type") == "CLOSE": break if not item.get("ok", False): try: From 5735ce99889c96a7620fc416751db4aaefed40f8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 15:40:25 -0700 Subject: [PATCH 072/193] Allow _check_to_close_session to exit normally if we are done --- src/replit_river/v2/session.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 34d79216..83851a65 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -975,7 +975,7 @@ async def _check_to_close_session( await asyncio.sleep(close_session_check_interval_ms / 1000) if get_state() in TerminalStates: # already closing - return + break # calculate the value now before comparing it so that there are no # await points between the check and the comparison to avoid a TOCTOU # race. 
@@ -983,9 +983,6 @@ async def _check_to_close_session( close_session_after_time_secs = get_close_session_after_time_secs() if not close_session_after_time_secs: continue - logging.debug( - "_check_to_close_session: Preparing to close session if not interrupted" - ) if current_time > close_session_after_time_secs: logger.info("Grace period ended for %s, closing session", transport_id) await do_close() From e18b544ab96c17c62a63859c01d97e47a1fd034b Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 15:53:18 -0700 Subject: [PATCH 073/193] Abort if the session is rejected by the server --- src/replit_river/error_schema.py | 3 +++ src/replit_river/v2/session.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index af5837dd..fab1041b 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -20,6 +20,9 @@ # ERROR_CODE_UNKNOWN is the code for the RiverUnknownError ERROR_CODE_UNKNOWN = "UNKNOWN" +# SESSION_STATE_MISMATCH is the code when the remote server rejects the session's state +ERROR_CODE_SESSION_STATE_MISMATCH = "SESSION_STATE_MISMATCH" + class RiverError(BaseModel): """Error message from the server.""" diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 83851a65..bfd72355 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -33,6 +33,7 @@ ) from replit_river.error_schema import ( ERROR_CODE_CANCEL, + ERROR_CODE_SESSION_STATE_MISMATCH, ERROR_CODE_STREAM_CLOSED, ERROR_HANDSHAKE, RiverError, @@ -314,6 +315,8 @@ async def websocket_closed_callback() -> None: "river client get handshake response : %r", handshake_response ) # noqa: E501 if not handshake_response.status.ok: + if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: + await self.close() raise RiverException( ERROR_HANDSHAKE, f"Handshake failed with code {handshake_response.status.code}: " # noqa: E501 From 
fbe8b30f7355ce77e5136fc30b4abc1ab458e261 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 16:00:39 -0700 Subject: [PATCH 074/193] Expand CONNECTING to the various ConnectingStates --- src/replit_river/v2/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index bfd72355..92b6d105 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1076,7 +1076,7 @@ async def _setup_heartbeat( await asyncio.sleep(heartbeat_ms / 1000) state = get_state() - if state == SessionState.CONNECTING: + if state in ConnectingStates: logger.debug("Websocket is not connected, not sending heartbeat") continue try: @@ -1132,7 +1132,7 @@ async def _serve( while our_task and not our_task.cancelling() and not our_task.cancelled(): logging.debug(f"_serve loop count={idx}") idx += 1 - while (ws := get_ws()) is None or get_state() == SessionState.CONNECTING: + while (ws := get_ws()) is None or get_state() in ConnectingStates: logging.debug("_handle_messages_from_ws spinning while connecting") await block_until_connected() logger.debug( From d18c6f036bbdfb01af30d3760a1204bab3478d26 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 16:06:46 -0700 Subject: [PATCH 075/193] More lifecycle tweaks during shutdown so we exit cleanly. 
--- src/replit_river/v2/session.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 92b6d105..8fd40a5c 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -315,7 +315,10 @@ async def websocket_closed_callback() -> None: "river client get handshake response : %r", handshake_response ) # noqa: E501 if not handshake_response.status.ok: - if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: + if ( + handshake_response.status.code + == ERROR_CODE_SESSION_STATE_MISMATCH + ): # noqa: E501 await self.close() raise RiverException( ERROR_HANDSHAKE, @@ -479,8 +482,8 @@ async def close(self) -> None: self._state = SessionState.CLOSING # We need to wake up all tasks waiting for connection to be established - assert not self._connection_condition.locked() - await self._connection_condition.acquire() + if not self._connection_condition.locked(): + await self._connection_condition.acquire() self._connection_condition.notify_all() self._connection_condition.release() @@ -490,6 +493,8 @@ async def close(self) -> None: # throw exception correctly. for stream in self._streams.values(): stream.close() + # Before we GC the streams, let's wait for all tasks to be closed gracefully. 
+ await asyncio.gather(*[x.join() for x in self._streams.values()]) self._streams.clear() if self._ws_unwrapped: @@ -646,7 +651,6 @@ async def block_until_connected() -> None: async with self._connection_condition: await self._connection_condition.wait() - self._task_manager.create_task( _serve( block_until_connected=block_until_connected, From 9fd39e17872c90d5c835219862771fdfb036e2e0 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 16:56:23 -0700 Subject: [PATCH 076/193] More logging --- src/replit_river/v2/session.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 8fd40a5c..f3388128 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1152,11 +1152,11 @@ async def _serve( message = await ws.recv(decode=False) try: msg = parse_transport_msg(message) - # logger.debug( - # "[%s] got a message %r", - # transport_id, - # msg, - # ) + logger.debug( + "[%s] got a message %r", + transport_id, + msg, + ) if msg.controlFlags & STREAM_OPEN_BIT != 0: raise InvalidMessageException( @@ -1179,13 +1179,6 @@ async def _serve( case other: assert_never(other) - # TODO: Delete me - logger.debug( - "[%s] got a message %r", - transport_id, - msg, - ) - reset_session_close_countdown() # Shortcut to avoid processing ack packets From 424baaed9ea711fb7c6bab3a6f5ddb09829f2e21 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 17:04:25 -0700 Subject: [PATCH 077/193] try:try: --- src/replit_river/v2/session.py | 194 ++++++++++++++++----------------- 1 file changed, 97 insertions(+), 97 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index f3388128..ecbfce9d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -222,120 +222,120 @@ async def _do_ensure_connected[HandshakeMetadata]( rate_limiter.consume_budget(client_id) + ws = None try: 
uri_and_metadata = await uri_and_metadata_factory() ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) try: - try: - next_seq = 0 - if self._send_buffer: - next_seq = self._send_buffer[0].seq - handshake_request = ControlMessageHandshakeRequest[ - HandshakeMetadata - ]( # noqa: E501 - type="HANDSHAKE_REQ", - protocolVersion=protocol_version, - sessionId=self.session_id, - metadata=uri_and_metadata["metadata"], - expectedSessionState=ExpectedSessionState( - nextExpectedSeq=self.ack, - nextSentSeq=next_seq, - ), + next_seq = 0 + if self._send_buffer: + next_seq = self._send_buffer[0].seq + handshake_request = ControlMessageHandshakeRequest[ + HandshakeMetadata + ]( # noqa: E501 + type="HANDSHAKE_REQ", + protocolVersion=protocol_version, + sessionId=self.session_id, + metadata=uri_and_metadata["metadata"], + expectedSessionState=ExpectedSessionState( + nextExpectedSeq=self.ack, + nextSentSeq=next_seq, + ), + ) + stream_id = nanoid.generate() + + async def websocket_closed_callback() -> None: + logger.error("websocket closed before handshake response") + + await send_transport_message( + TransportMessage( + from_=self._transport_id, + to=self._to_id, + streamId=stream_id, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_request.model_dump(), + ), + ws=ws, + websocket_closed_callback=websocket_closed_callback, + ) + except ( + WebsocketClosedException, + FailedSendingMessageException, + ) as e: # noqa: E501 + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while sending response", # noqa: E501 + ) from e + + startup_grace_deadline_ms = await self._get_current_time() + 60_000 + while True: + if await self._get_current_time() >= startup_grace_deadline_ms: # noqa: E501 + raise RiverException( + ERROR_HANDSHAKE, + "Handshake response timeout, closing connection", # noqa: E501 ) - stream_id = nanoid.generate() - - async def websocket_closed_callback() -> None: - logger.error("websocket closed before 
handshake response") - - await send_transport_message( - TransportMessage( - from_=self._transport_id, - to=self._to_id, - streamId=stream_id, - controlFlags=0, - id=nanoid.generate(), - seq=0, - ack=0, - payload=handshake_request.model_dump(), - ), - ws=ws, - websocket_closed_callback=websocket_closed_callback, + try: + data = await ws.recv(decode=False) + except ConnectionClosed as e: + logger.debug( + "Connection closed during waiting for handshake response", # noqa: E501 + exc_info=True, ) - except ( - WebsocketClosedException, - FailedSendingMessageException, - ) as e: # noqa: E501 raise RiverException( ERROR_HANDSHAKE, - "Handshake failed, conn closed while sending response", # noqa: E501 + "Handshake failed, conn closed while waiting for response", # noqa: E501 ) from e - startup_grace_deadline_ms = await self._get_current_time() + 60_000 - while True: - if await self._get_current_time() >= startup_grace_deadline_ms: # noqa: E501 - raise RiverException( - ERROR_HANDSHAKE, - "Handshake response timeout, closing connection", # noqa: E501 - ) - try: - data = await ws.recv() - except ConnectionClosed as e: - logger.debug( - "Connection closed during waiting for handshake response", # noqa: E501 - exc_info=True, - ) - raise RiverException( - ERROR_HANDSHAKE, - "Handshake failed, conn closed while waiting for response", # noqa: E501 - ) from e - try: - response_msg = parse_transport_msg(data) - break - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) # noqa: E501 - continue - except InvalidMessageException as e: - raise RiverException( - ERROR_HANDSHAKE, - "Got invalid transport message, closing connection", - ) from e - try: - handshake_response = ControlMessageHandshakeResponse( - **response_msg.payload - ) - logger.debug("river client waiting for handshake response") - except ValidationError as e: + response_msg = parse_transport_msg(data) + break + except IgnoreMessageException: + logger.debug("Ignoring transport 
message", exc_info=True) # noqa: E501 + continue + except InvalidMessageException as e: raise RiverException( - ERROR_HANDSHAKE, "Failed to parse handshake response" + ERROR_HANDSHAKE, + "Got invalid transport message, closing connection", ) from e - logger.debug( - "river client get handshake response : %r", handshake_response - ) # noqa: E501 - if not handshake_response.status.ok: - if ( - handshake_response.status.code - == ERROR_CODE_SESSION_STATE_MISMATCH - ): # noqa: E501 - await self.close() - raise RiverException( - ERROR_HANDSHAKE, - f"Handshake failed with code {handshake_response.status.code}: " # noqa: E501 - f"{handshake_response.status.reason}", - ) + try: + handshake_response = ControlMessageHandshakeResponse( + **response_msg.payload + ) + logger.debug("river client waiting for handshake response") + except ValidationError as e: + raise RiverException( + ERROR_HANDSHAKE, "Failed to parse handshake response" + ) from e - last_error = None - rate_limiter.start_restoring_budget(client_id) - self._state = SessionState.ACTIVE - self._ws_unwrapped = ws - self._connection_condition.notify_all() - break - except RiverException as e: - await ws.close() - raise e + logger.debug( + "river client get handshake response : %r", handshake_response + ) # noqa: E501 + if not handshake_response.status.ok: + if ( + handshake_response.status.code + == ERROR_CODE_SESSION_STATE_MISMATCH + ): # noqa: E501 + await self.close() + raise RiverException( + ERROR_HANDSHAKE, + f"Handshake failed with code {handshake_response.status.code}: " # noqa: E501 + f"{handshake_response.status.reason}", + ) + + last_error = None + rate_limiter.start_restoring_budget(client_id) + self._state = SessionState.ACTIVE + self._ws_unwrapped = ws + self._connection_condition.notify_all() + break except Exception as e: + if ws: + await ws.close() last_error = e backoff_time = rate_limiter.get_backoff_ms(client_id) logger.exception( From 2abe7304ed5c4f19320a2554af24e268ab9edded Mon Sep 17 00:00:00 
2001 From: Devon Stewart Date: Tue, 25 Mar 2025 17:11:54 -0700 Subject: [PATCH 078/193] more state logging --- src/replit_river/v2/session.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index ecbfce9d..43a419f0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -243,6 +243,16 @@ async def _do_ensure_connected[HandshakeMetadata]( nextSentSeq=next_seq, ), ) + sb_state = None + ab_state = None + if self._send_buffer: + sb_state = [self._send_buffer[0].seq, self._send_buffer[0].ack] + if self._ack_buffer: + ab_state = [self._ack_buffer[0].seq, self._ack_buffer[0].ack] + logger.debug( + f"STATE{{seq={self.seq}, ack={self.ack}, next_seq={next_seq}, " + f"sb_state={sb_state}, ab_state={ab_state} }}" + ) stream_id = nanoid.generate() async def websocket_closed_callback() -> None: @@ -268,33 +278,33 @@ async def websocket_closed_callback() -> None: ) as e: # noqa: E501 raise RiverException( ERROR_HANDSHAKE, - "Handshake failed, conn closed while sending response", # noqa: E501 + "Handshake failed, conn closed while sending response", ) from e startup_grace_deadline_ms = await self._get_current_time() + 60_000 while True: - if await self._get_current_time() >= startup_grace_deadline_ms: # noqa: E501 + if await self._get_current_time() >= startup_grace_deadline_ms: raise RiverException( ERROR_HANDSHAKE, - "Handshake response timeout, closing connection", # noqa: E501 + "Handshake response timeout, closing connection", ) try: data = await ws.recv(decode=False) except ConnectionClosed as e: logger.debug( - "Connection closed during waiting for handshake response", # noqa: E501 + "Connection closed during waiting for handshake response", exc_info=True, ) raise RiverException( ERROR_HANDSHAKE, - "Handshake failed, conn closed while waiting for response", # noqa: E501 + "Handshake failed, conn closed while waiting for response", ) 
from e try: response_msg = parse_transport_msg(data) break except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) # noqa: E501 + logger.debug("Ignoring transport message", exc_info=True) continue except InvalidMessageException as e: raise RiverException( @@ -319,12 +329,13 @@ async def websocket_closed_callback() -> None: if ( handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH - ): # noqa: E501 + ): await self.close() raise RiverException( ERROR_HANDSHAKE, - f"Handshake failed with code {handshake_response.status.code}: " # noqa: E501 - f"{handshake_response.status.reason}", + f"Handshake failed with code {handshake_response.status.code}: { + handshake_response.status.reason + }", ) last_error = None From 9ff6acaa7f7e5db82fcbbd88e417ebab80d8f8ae Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 18:21:42 -0700 Subject: [PATCH 079/193] Switch from active heartbeat to passive heartbeat --- src/replit_river/v2/session.py | 60 ++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 43a419f0..557c364d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -603,7 +603,6 @@ async def block_until_connected() -> None: lambda: self._state, lambda: self._close_session_after_time_secs, close_websocket=close_websocket, - send_message=self.send_message, increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, ) ) @@ -662,6 +661,9 @@ async def block_until_connected() -> None: async with self._connection_condition: await self._connection_condition.wait() + def received_message(message: TransportMessage) -> None: + pass + self._task_manager.create_task( _serve( block_until_connected=block_until_connected, @@ -675,6 +677,8 @@ async def block_until_connected() -> None: assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: 
self._streams.get(stream_id), close_stream=close_stream, + received_message=received_message, + send_message=self.send_message, ) ) @@ -1073,7 +1077,6 @@ async def _setup_heartbeat( get_state: Callable[[], SessionState], get_closing_grace_period: Callable[[], float | None], close_websocket: Callable[[], Awaitable[None]], - send_message: SendMessage, increment_and_get_heartbeat_misses: Callable[[], int], ) -> None: while True: @@ -1092,36 +1095,18 @@ async def _setup_heartbeat( await asyncio.sleep(heartbeat_ms / 1000) state = get_state() if state in ConnectingStates: - logger.debug("Websocket is not connected, not sending heartbeat") + logger.debug("Websocket is not connected, don't expect heartbeat") continue - try: - await send_message( - stream_id="heartbeat", - # TODO: make this a message class - # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 - payload={ - "type": "ACK", - "ack": 0, - }, - control_flags=ACK_BIT, - procedure_name=None, - service_name=None, - span=None, - ) - if increment_and_get_heartbeat_misses() > heartbeats_until_dead: - if get_closing_grace_period() is not None: - # already in grace period, no need to set again - continue - logger.info( - "%r closing websocket because of heartbeat misses", - session_id, - ) - await close_websocket() + if increment_and_get_heartbeat_misses() > heartbeats_until_dead: + if get_closing_grace_period() is not None: + # already in grace period, no need to set again continue - except FailedSendingMessageException: - # this is expected during websocket closed period - continue + logger.info( + "%r closing websocket because of heartbeat misses", + session_id, + ) + await close_websocket() async def _serve( @@ -1138,6 +1123,8 @@ async def _serve( ], # noqa: E501 get_stream: Callable[[str], Channel[Any] | None], close_stream: Callable[[str], None], + received_message: Callable[[TransportMessage], None], + send_message: SendMessage, ) -> None: """Serve messages 
from the websocket.""" reset_session_close_countdown() @@ -1169,6 +1156,8 @@ async def _serve( msg, ) + received_message(msg) + if msg.controlFlags & STREAM_OPEN_BIT != 0: raise InvalidMessageException( "Client should not receive stream open bit" @@ -1194,6 +1183,19 @@ async def _serve( # Shortcut to avoid processing ack packets if msg.controlFlags & ACK_BIT != 0: + await send_message( + stream_id="heartbeat", + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 + payload={ + "type": "ACK", + "ack": 0, + }, + control_flags=ACK_BIT, + procedure_name=None, + service_name=None, + span=None, + ) continue stream = get_stream(msg.streamId) From 74b713345771aa3cacc5219784d74162b72d431d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 21:27:20 -0700 Subject: [PATCH 080/193] Whoops --- src/replit_river/v2/session.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 557c364d..df3f4a0b 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -214,7 +214,6 @@ async def _do_ensure_connected[HandshakeMetadata]( last_error: Exception | None = None i = 0 - await self._connection_condition.acquire() while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error): if i > 0: logger.info(f"Retrying build handshake number {i} times") @@ -342,7 +341,10 @@ async def websocket_closed_callback() -> None: rate_limiter.start_restoring_budget(client_id) self._state = SessionState.ACTIVE self._ws_unwrapped = ws - self._connection_condition.notify_all() + + # We're connected, wake everybody up + async with self._connection_condition: + self._connection_condition.notify_all() break except Exception as e: if ws: @@ -370,10 +372,6 @@ async def websocket_closed_callback() -> None: ): self._connecting_task = None - # Release the lock we took earlier 
so we can use it again in the next - # connection attempt - self._connection_condition.release() - if last_error is not None: raise RiverException( ERROR_HANDSHAKE, @@ -493,10 +491,8 @@ async def close(self) -> None: self._state = SessionState.CLOSING # We need to wake up all tasks waiting for connection to be established - if not self._connection_condition.locked(): - await self._connection_condition.acquire() - self._connection_condition.notify_all() - self._connection_condition.release() + async with self._connection_condition: + self._connection_condition.notify_all() await self._task_manager.cancel_all_tasks() From 80dcd7212041710e8a52a0151b63802b198da526 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 21:43:17 -0700 Subject: [PATCH 081/193] Avoid circular awaits --- src/replit_river/v2/session.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index df3f4a0b..0842341e 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -121,6 +121,9 @@ class Session: ack: int # Most recently acknowledged seq seq: int # Last sent sequence number + # Terminating + _terminating_task: asyncio.Task[None] + def __init__( self, transport_id: str, @@ -188,6 +191,12 @@ async def ensure_connected[HandshakeMetadata]( if self.is_connected(): return + def do_close() -> None: + # We can't just call self.close() directly because + # we're inside a thread that will eventually be awaited + # during the cleanup procedure. 
+ self._terminating_task = asyncio.create_task(self.close()) + if not self._connecting_task: self._connecting_task = asyncio.create_task( self._do_ensure_connected( @@ -195,6 +204,7 @@ async def ensure_connected[HandshakeMetadata]( rate_limiter, uri_and_metadata_factory, protocol_version, + do_close, ) ) @@ -208,6 +218,7 @@ async def _do_ensure_connected[HandshakeMetadata]( [], Awaitable[UriAndMetadata[HandshakeMetadata]] ], # noqa: E501 protocol_version: str, + do_close: Callable[[], None], ) -> Literal[True]: max_retry = self._transport_options.connection_retry_options.max_retry logger.info("Attempting to establish new ws connection") @@ -329,7 +340,8 @@ async def websocket_closed_callback() -> None: handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH ): - await self.close() + do_close() + raise RiverException( ERROR_HANDSHAKE, f"Handshake failed with code {handshake_response.status.code}: { @@ -553,6 +565,12 @@ async def block_until_connected() -> None: ) def _start_close_session_checker(self) -> None: + def do_close() -> None: + # We can't just call self.close() directly because + # we're inside a thread that will eventually be awaited + # during the cleanup procedure. 
+ self._terminating_task = asyncio.create_task(self.close()) + self._task_manager.create_task( _check_to_close_session( self._transport_id, @@ -560,7 +578,7 @@ def _start_close_session_checker(self) -> None: lambda: self._state, self._get_current_time, lambda: self._close_session_after_time_secs, - self.close, + do_close=do_close, ) ) @@ -986,7 +1004,7 @@ async def _check_to_close_session( get_state: Callable[[], SessionState], get_current_time: Callable[[], Awaitable[float]], get_close_session_after_time_secs: Callable[[], float | None], - do_close: Callable[[], Awaitable[None]], + do_close: Callable[[], None], ) -> None: our_task = asyncio.current_task() while our_task and not our_task.cancelling() and not our_task.cancelled(): @@ -1003,8 +1021,8 @@ async def _check_to_close_session( continue if current_time > close_session_after_time_secs: logger.info("Grace period ended for %s, closing session", transport_id) - await do_close() - return + do_close() + our_task.cancel() async def _buffered_message_sender( From 54bfe47281d55f7b8d984da444cb542f094cbcb2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 22:31:51 -0700 Subject: [PATCH 082/193] Permit method filtering based on supplied file --- src/replit_river/codegen/client.py | 5 +++++ src/replit_river/codegen/run.py | 7 +++++++ tests/codegen/snapshot/codegen_snapshot_fixtures.py | 1 + tests/codegen/test_rpc.py | 1 + 4 files changed, 14 insertions(+) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index d56856f8..66aa618d 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -1096,6 +1096,7 @@ def generate_individual_service( input_base_class: Literal["TypedDict"] | Literal["BaseModel"], method_filter: set[str] | None, protocol_version: Literal["v1.1", "v2.0"], + method_filter: set[str] | None, ) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[tuple[list[TypeName], list[ModuleName], 
list[FileContents]]] = [] @@ -1394,6 +1395,7 @@ def generate_river_client_module( typed_dict_inputs: bool, method_filter: set[str] | None, protocol_version: Literal["v1.1", "v2.0"], + method_filter: set[str] | None, ) -> dict[RenderedPath, FileContents]: files: dict[RenderedPath, FileContents] = {} @@ -1423,6 +1425,7 @@ def generate_river_client_module( input_base_class, method_filter, protocol_version, + method_filter, ) if emitted_files: # Short-cut if we didn't actually emit anything @@ -1449,6 +1452,7 @@ def schema_to_river_client_codegen( file_opener: Callable[[Path], TextIO], method_filter: set[str] | None, protocol_version: Literal["v1.1", "v2.0"], + method_filter: set[str] | None, ) -> None: """Generates the lines of a River module.""" with read_schema() as f: @@ -1459,6 +1463,7 @@ def schema_to_river_client_codegen( typed_dict_inputs, method_filter, protocol_version, + method_filter, ).items(): module_path = Path(target_path).joinpath(subpath) module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index 0f0deac0..1a6801f5 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -52,6 +52,12 @@ def main() -> None: default="v1.1", choices=["v1.1", "v2.0"], ) + client.add_argument( + "--method-filter", + help="Only generate a subset of the specified methods", + action="store", + type=pathlib.Path, + ) client.add_argument("schema", help="schema file") args = parser.parse_args() @@ -83,6 +89,7 @@ def file_opener(path: Path) -> TextIO: file_opener, method_filter=method_filter, protocol_version=args.protocol_version, + method_filter=method_filter, ) else: raise NotImplementedError(f"Unknown command {args.command}") diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/codegen/snapshot/codegen_snapshot_fixtures.py index 2fdff907..9b2e73e2 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ 
b/tests/codegen/snapshot/codegen_snapshot_fixtures.py @@ -37,6 +37,7 @@ def file_opener(path: Path) -> TextIO: typed_dict_inputs=typeddict_inputs, method_filter=None, protocol_version="v1.1", + method_filter=None, ) for path, file in files.items(): file.seek(0) diff --git a/tests/codegen/test_rpc.py b/tests/codegen/test_rpc.py index 450a74f0..9d24c99f 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/codegen/test_rpc.py @@ -34,6 +34,7 @@ def file_opener(path: Path) -> TextIO: file_opener=file_opener, method_filter=None, protocol_version="v1.1", + method_filter=None, ) From 1a5a3e94f7041096750f233d9243aef2bc8878c6 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 22:52:12 -0700 Subject: [PATCH 083/193] Updating snapshots --- src/replit_river/v2/session.py | 4 ++-- tests/codegen/rpc/generated/test_service/__init__.py | 1 - tests/codegen/rpc/generated/test_service/rpc_method.py | 3 --- .../snapshots/test_basic_stream/test_service/__init__.py | 9 +++++---- .../test_basic_stream/test_service/emit_error.py | 5 ----- .../test_basic_stream/test_service/stream_method.py | 5 ----- .../test_pathological_types/test_service/__init__.py | 1 - .../test_service/pathological_method.py | 5 ----- .../snapshots/test_unknown_enum/enumService/__init__.py | 2 -- .../snapshots/test_unknown_enum/enumService/needsEnum.py | 2 -- .../test_unknown_enum/enumService/needsEnumObject.py | 5 ----- 11 files changed, 7 insertions(+), 35 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 0842341e..6dd484e8 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -422,7 +422,7 @@ async def _begin_close_session_countdown(self) -> None: self._transport_id, self._to_id, ) - self._state = SessionState.PENDING + self._state = SessionState.NO_CONNECTION self._close_session_after_time_secs = close_session_after_time_secs async def _get_current_time(self) -> float: @@ -626,7 +626,7 @@ async def transition_connecting() -> 
None: self._state = SessionState.CONNECTING async def connection_interrupted() -> None: - self._state = SessionState.PENDING + self._state = SessionState.NO_CONNECTION if self._ws_unwrapped: self._task_manager.create_task(self._ws_unwrapped.close()) self._ws_unwrapped = None diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/codegen/rpc/generated/test_service/__init__.py index 24545e00..3d9bc86a 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/codegen/rpc/generated/test_service/__init__.py @@ -11,7 +11,6 @@ from .rpc_method import ( Rpc_MethodInput, - Rpc_MethodInputTypeAdapter, Rpc_MethodOutput, Rpc_MethodOutputTypeAdapter, encode_Rpc_MethodInput, diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index dfe8a47c..1e40411f 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -40,9 +40,6 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) - - class Rpc_MethodOutput(BaseModel): data: str diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py index 44d6c18c..0b106014 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py @@ -11,17 +11,18 @@ from .stream_method import ( Stream_MethodInput, - Stream_MethodInputTypeAdapter, Stream_MethodOutput, Stream_MethodOutputTypeAdapter, encode_Stream_MethodInput, ) -from .emit_error import Emit_ErrorErrors, Emit_ErrorErrorsTypeAdapter +from .emit_error import Emit_ErrorErrors -intTypeAdapter: TypeAdapter[int] = TypeAdapter(int) +boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool) -boolTypeAdapter: TypeAdapter[bool] = 
TypeAdapter(bool) +Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter( + Emit_ErrorErrors +) class Test_ServiceService: diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py index ddba3a38..e7005c29 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py @@ -38,8 +38,3 @@ class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError): | RiverUnknownError, WrapValidator(translate_unknown_error), ] - - -Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter( - Emit_ErrorErrors -) diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 5baa9c40..1914aefc 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -40,11 +40,6 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( - Stream_MethodInput -) - - class Stream_MethodOutput(BaseModel): data: str diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py index 3a578118..dd7f15e4 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py @@ -11,7 +11,6 @@ from .pathological_method import ( Pathological_MethodInput, - Pathological_MethodInputTypeAdapter, encode_Pathological_MethodInput, encode_Pathological_MethodInputObj_Boolean, encode_Pathological_MethodInputObj_Date, diff --git 
a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py index 137add7b..2c325e64 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py @@ -476,8 +476,3 @@ class Pathological_MethodInput(TypedDict): req_undefined: None string: NotRequired[str | None] undefined: NotRequired[None] - - -Pathological_MethodInputTypeAdapter: TypeAdapter[Pathological_MethodInput] = ( - TypeAdapter(Pathological_MethodInput) -) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py index ab9eaa08..e1067475 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py @@ -13,7 +13,6 @@ NeedsenumErrors, NeedsenumErrorsTypeAdapter, NeedsenumInput, - NeedsenumInputTypeAdapter, NeedsenumOutput, NeedsenumOutputTypeAdapter, encode_NeedsenumInput, @@ -22,7 +21,6 @@ NeedsenumobjectErrors, NeedsenumobjectErrorsTypeAdapter, NeedsenumobjectInput, - NeedsenumobjectInputTypeAdapter, NeedsenumobjectOutput, NeedsenumobjectOutputTypeAdapter, encode_NeedsenumobjectInput, diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 8f325775..69b976b5 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -29,8 +29,6 @@ def encode_NeedsenumInput(x: "NeedsenumInput") -> Any: return x -NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) - NeedsenumOutput = 
Annotated[ Literal["out_first", "out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 4e1243a3..dd61a2d7 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -77,11 +77,6 @@ def encode_NeedsenumobjectInput( ) -NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( - NeedsenumobjectInput -) - - class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): kind: Annotated[Literal["out_first"], Field(alias="$kind")] = "out_first" foo: int From e9aded6213969dafce0518b3def28cfec00f1df2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 25 Mar 2025 23:30:26 -0700 Subject: [PATCH 084/193] nit --- src/replit_river/v2/session.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 6dd484e8..4dc925e5 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -791,12 +791,11 @@ async def send_upload[I, R, A]( # If this request is not closed and the session is killed, we should # throw exception here async for item in request: - control_flags = 0 await self.send_message( stream_id=stream_id, service_name=service_name, procedure_name=procedure_name, - control_flags=control_flags, + control_flags=0, payload=request_serializer(item), span=span, ) From 0139bbe38bda944f506840e8dd0a9717148c3d83 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 09:49:58 -0700 Subject: [PATCH 085/193] Type error instead of runtime error --- src/replit_river/codegen/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py 
index 66aa618d..13282d48 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -988,7 +988,7 @@ async def {name}( ] ) else: - raise ValueError("Precondition failed") + assert_never(protocol_version) elif procedure.type == "stream": assert output_meta assert error_meta @@ -1084,7 +1084,7 @@ async def {name}( ] ) else: - raise ValueError("Precondition failed") + assert_never(protocol_version) current_chunks.append("") return current_chunks From e7069e32aeb039450ee15014e77d91c9319a6b00 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 09:50:18 -0700 Subject: [PATCH 086/193] Avoid explosions when given incorrect input --- src/replit_river/codegen/typing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index 53c028ff..a3272f0d 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -210,19 +210,19 @@ def extract_inner_type(value: TypeExpression) -> TypeName: case ListTypeExpr(nested): return extract_inner_type(nested) case LiteralTypeExpr(_): - raise ValueError(f"Unexpected literal type: {value}") + raise ValueError(f"Unexpected literal type: {repr(value)}") case UnionTypeExpr(_): raise ValueError( - f"Attempting to extract from a union, currently not possible: {value}" + f"Attempting to extract from a union, currently not possible: {repr(value)}" ) case OpenUnionTypeExpr(_): raise ValueError( - f"Attempting to extract from a union, currently not possible: {value}" + f"Attempting to extract from a union, currently not possible: {repr(value)}" ) case TypeName(name): return TypeName(name) case NoneTypeExpr(): - raise ValueError(f"Attempting to extract from a literal 'None': {value}") + raise ValueError(f"Attempting to extract from a literal 'None': {repr(value)}") case other: assert_never(other) @@ -233,5 +233,5 @@ def ensure_literal_type(value: TypeExpression) -> TypeName: return 
TypeName(name) case other: raise ValueError( - f"Unexpected expression when expecting a type name: {other}" + f"Unexpected expression when expecting a type name: {repr(other)}" ) From 278b274a5415b1a49df217bff91902d83d433cfb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 10:25:01 -0700 Subject: [PATCH 087/193] Document --- src/replit_river/common_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index e1444311..5032c09b 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -27,7 +27,7 @@ class SessionState(enum.Enum): Valid transitions: - NO_CONNECTION -> {CONNECTING} - - CONNECTING -> {ACTIVE, CLOSING} + - CONNECTING -> {NO_CONNECTION, ACTIVE, CLOSING} - ACTIVE -> {NO_CONNECTION, CONNECTING, CLOSING} - CLOSING -> {CLOSED} - CLOSED -> {} From e6fa6b283f3ed929702957a2cd5416cfc3597b8c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 10:43:26 -0700 Subject: [PATCH 088/193] Inline add_msg_to_stream --- src/replit_river/client_session.py | 1 - src/replit_river/codegen/typing.py | 10 +++++++--- src/replit_river/server_session.py | 1 - 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 2d1e847a..c37768b7 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -175,7 +175,6 @@ async def _handle_messages_from_ws(self) -> None: pass except RuntimeError as e: raise InvalidMessageException(e) from e - else: raise InvalidMessageException( "Client should not receive stream open bit" diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index a3272f0d..68443ffa 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -213,16 +213,20 @@ def extract_inner_type(value: TypeExpression) -> TypeName: raise 
ValueError(f"Unexpected literal type: {repr(value)}") case UnionTypeExpr(_): raise ValueError( - f"Attempting to extract from a union, currently not possible: {repr(value)}" + "Attempting to extract from a union, " + f"currently not possible: {repr(value)}" ) case OpenUnionTypeExpr(_): raise ValueError( - f"Attempting to extract from a union, currently not possible: {repr(value)}" + "Attempting to extract from a union, " + f"currently not possible: {repr(value)}" ) case TypeName(name): return TypeName(name) case NoneTypeExpr(): - raise ValueError(f"Attempting to extract from a literal 'None': {repr(value)}") + raise ValueError( + f"Attempting to extract from a literal 'None': {repr(value)}", + ) case other: assert_never(other) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 3a931274..c397e900 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -164,7 +164,6 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: pass except RuntimeError as e: raise InvalidMessageException(e) from e - else: _stream = await self._open_stream_and_call_handler(msg, tg) if isinstance(_stream, IgnoreMessage): From e29c8fd2ad0e907d85b0e05449602d57c47da0cb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 23:43:03 -0700 Subject: [PATCH 089/193] 4da3dbb192b7476fae8e930c0544ecfb1d523d72 --- src/replit_river/common_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 5032c09b..551f5545 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -40,5 +40,5 @@ class SessionState(enum.Enum): CLOSED = 4 -ConnectingStateta = set([SessionState.NO_CONNECTION, SessionState.CONNECTING]) +ConnectingStates = set([SessionState.NO_CONNECTION, SessionState.CONNECTING]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) From 
6ae70d8c45936e08db473d11ee0ead22dc4a92a3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 11:05:22 -0700 Subject: [PATCH 090/193] These are part of the codegen serdes that point at the v1 client, they are unchanged --- src/replit_river/v2/client.py | 40 +---------------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 958ff67d..4edf299b 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -7,12 +7,8 @@ from opentelemetry import trace from opentelemetry.trace import Span, SpanKind, Status, StatusCode -from pydantic import ( - BaseModel, - ValidationInfo, -) -from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException +from replit_river.error_schema import RiverError, RiverException from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -24,40 +20,6 @@ tracer = trace.get_tracer(__name__) -@dataclass(frozen=True) -class RiverUnknownValue(BaseModel): - tag: Literal["RiverUnknownValue"] - value: Any - - -class RiverUnknownError(RiverError): - pass - - -def translate_unknown_value( - value: Any, handler: Callable[[Any], Any], info: ValidationInfo -) -> Any | RiverUnknownValue: - try: - return handler(value) - except Exception: - return RiverUnknownValue(tag="RiverUnknownValue", value=value) - - -def translate_unknown_error( - value: Any, handler: Callable[[Any], Any], info: ValidationInfo -) -> Any | RiverUnknownError: - try: - return handler(value) - except Exception: - if isinstance(value, dict) and "code" in value and "message" in value: - return RiverUnknownError( - code=value["code"], - message=value["message"], - ) - else: - return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") - - class Client(Generic[HandshakeMetadataType]): def __init__( self, From 7de928347a7efb7355fc34474a91606900753795 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: 
Wed, 26 Mar 2025 11:15:08 -0700 Subject: [PATCH 091/193] Unused --- src/replit_river/v2/client_transport.py | 6 -- src/replit_river/v2/schema.py | 76 ------------------------- 2 files changed, 82 deletions(-) delete mode 100644 src/replit_river/v2/schema.py diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 2a751ce3..c4bde024 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -20,12 +20,6 @@ logger = logging.getLogger(__name__) -class HandshakeBudgetExhaustedException(RiverException): - def __init__(self, code: str, message: str, client_id: str) -> None: - super().__init__(code, message) - self.client_id = client_id - - class ClientTransport(Generic[HandshakeMetadataType]): _session: Session | None diff --git a/src/replit_river/v2/schema.py b/src/replit_river/v2/schema.py deleted file mode 100644 index 3c9792b8..00000000 --- a/src/replit_river/v2/schema.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Any, Literal, NotRequired, TypeAlias, TypedDict - -from replit_river.rpc import ACK_BIT_TYPE, ExpectedSessionState -from replit_river.v2.session import STREAM_CANCEL_BIT_TYPE - - -class ControlClose(TypedDict): - type: Literal["CLOSE"] - - -class ControlAck(TypedDict): - type: Literal["ACK"] - - -class ControlHandshakeRequest(TypedDict): - type: Literal["HANDSHAKE_REQ"] - protocolVersion: Literal["v2.0"] - sessionId: str - expectedSessionState: ExpectedSessionState - metdata: NotRequired[Any] - - -class HandshakeOK(TypedDict): - ok: Literal[True] - sessionId: str - - -class HandshakeError(TypedDict): - ok: Literal[False] - reaason: str - - -class ControlHandshakeResponse(TypedDict): - type: Literal["HANDSHAKE_RESP"] - status: HandshakeOK | HandshakeError - - -# This is sent when the server encounters an internal error -# i.e. 
an invariant has been violated -class BaseErrorStructure(TypedDict): - # This should be a defined literal to make sure errors are easily differentiated - # code: str # Supplied by implementations - # This can be any string - message: str - # Any extra metadata - extra: NotRequired[Any] - - -# When a client sends a malformed request. This can be -# for a variety of reasons which would be included -# in the message. -class InvalidRequestError(BaseErrorStructure): - code: Literal["INVALID_REQUEST"] - - -# This is sent when an exception happens in the handler of a stream. -class UncaughtError(BaseErrorStructure): - code: Literal["UNCAUGHT_ERROR"] - - -# This is sent when one side wishes to cancel the stream -# abruptly from user-space. Handling this is up to the procedure -# implementation or the caller. -class CancelError(BaseErrorStructure): - code: Literal["CANCEL"] - - -ProtocolError: TypeAlias = UncaughtError | InvalidRequestError | CancelError - -Control: TypeAlias = ( - ControlClose | ControlAck | ControlHandshakeRequest | ControlHandshakeResponse -) - -ValidPairings = ( - tuple[ACK_BIT_TYPE, ControlAck] | tuple[STREAM_CANCEL_BIT_TYPE, ProtocolError] -) From 864090c9ed1cedbffbe436a2e7060e0df6f513d8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 11:19:10 -0700 Subject: [PATCH 092/193] _ws_unwrapped -> _ws, we never wrap it --- src/replit_river/v2/session.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 4dc925e5..ea9f619b 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -107,7 +107,7 @@ class Session: _connection_condition: asyncio.Condition # ws state - _ws_unwrapped: ClientConnection | None + _ws: ClientConnection | None _heartbeat_misses: int _retry_connection_callback: RetryConnectionCallback | None @@ -146,7 +146,7 @@ def __init__( self._connection_condition = asyncio.Condition() # 
ws state - self._ws_unwrapped = None + self._ws = None self._heartbeat_misses = 0 self._retry_connection_callback = retry_connection_callback @@ -352,7 +352,7 @@ async def websocket_closed_callback() -> None: last_error = None rate_limiter.start_restoring_budget(client_id) self._state = SessionState.ACTIVE - self._ws_unwrapped = ws + self._ws = ws # We're connected, wake everybody up async with self._connection_condition: @@ -495,7 +495,7 @@ async def close(self) -> None: """Close the session and all associated streams.""" logger.info( f"{self._transport_id} closing session " - f"to {self._to_id}, ws: {self._ws_unwrapped}" + f"to {self._to_id}, ws: {self._ws}" ) if self._state in TerminalStates: # already closing @@ -516,10 +516,10 @@ async def close(self) -> None: await asyncio.gather(*[x.join() for x in self._streams.values()]) self._streams.clear() - if self._ws_unwrapped: + if self._ws: # The Session isn't guaranteed to live much longer than this close() # invocation, so let's await this close to avoid dropping the socket. 
- await self._ws_unwrapped.close() + await self._ws.close() self._state = SessionState.CLOSED @@ -545,7 +545,7 @@ def get_next_pending() -> TransportMessage | None: def get_ws() -> ClientConnection | None: if self.is_connected(): - return self._ws_unwrapped + return self._ws return None async def block_until_connected() -> None: @@ -585,13 +585,13 @@ def do_close() -> None: def _start_heartbeat(self) -> None: async def close_websocket() -> None: logger.debug( - "do_close called, _state=%r, _ws_unwrapped=%r", + "do_close called, _state=%r, _ws=%r", self._state, - self._ws_unwrapped, + self._ws, ) - if self._ws_unwrapped: - self._task_manager.create_task(self._ws_unwrapped.close()) - self._ws_unwrapped = None + if self._ws: + self._task_manager.create_task(self._ws.close()) + self._ws = None if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -627,9 +627,9 @@ async def transition_connecting() -> None: async def connection_interrupted() -> None: self._state = SessionState.NO_CONNECTION - if self._ws_unwrapped: - self._task_manager.create_task(self._ws_unwrapped.close()) - self._ws_unwrapped = None + if self._ws: + self._task_manager.create_task(self._ws.close()) + self._ws = None if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -683,7 +683,7 @@ def received_message(message: TransportMessage) -> None: block_until_connected=block_until_connected, transport_id=self._transport_id, get_state=lambda: self._state, - get_ws=lambda: self._ws_unwrapped, + get_ws=lambda: self._ws, transition_connecting=transition_connecting, connection_interrupted=connection_interrupted, reset_session_close_countdown=self._reset_session_close_countdown, From cd735a13c924983ec57557b5aa9472e597e15b7c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 11:19:16 -0700 Subject: [PATCH 093/193] Dead code --- src/replit_river/v2/session.py | 7 ------- 1 file changed, 7 
deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index ea9f619b..54c11c17 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -675,9 +675,6 @@ async def block_until_connected() -> None: async with self._connection_condition: await self._connection_condition.wait() - def received_message(message: TransportMessage) -> None: - pass - self._task_manager.create_task( _serve( block_until_connected=block_until_connected, @@ -691,7 +688,6 @@ def received_message(message: TransportMessage) -> None: assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: self._streams.get(stream_id), close_stream=close_stream, - received_message=received_message, send_message=self.send_message, ) ) @@ -1136,7 +1132,6 @@ async def _serve( ], # noqa: E501 get_stream: Callable[[str], Channel[Any] | None], close_stream: Callable[[str], None], - received_message: Callable[[TransportMessage], None], send_message: SendMessage, ) -> None: """Serve messages from the websocket.""" @@ -1169,8 +1164,6 @@ async def _serve( msg, ) - received_message(msg) - if msg.controlFlags & STREAM_OPEN_BIT != 0: raise InvalidMessageException( "Client should not receive stream open bit" From 1a9ab87bd41fd37bc66896d6873cae84dfded9e4 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 11:19:49 -0700 Subject: [PATCH 094/193] We have nanoid types --- src/replit_river/v2/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 54c11c17..abe6f6f3 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -15,7 +15,7 @@ assert_never, ) -import nanoid # type: ignore +import nanoid import websockets.asyncio.client from aiochannel import Channel from aiochannel.errors import ChannelClosed From 399348af82398d0f019921baa1dda7225888b2fc Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 
Mar 2025 11:23:27 -0700 Subject: [PATCH 095/193] Avoid double close --- src/replit_river/v2/session.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index abe6f6f3..25967ec2 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -122,7 +122,7 @@ class Session: seq: int # Last sent sequence number # Terminating - _terminating_task: asyncio.Task[None] + _terminating_task: asyncio.Task[None] | None def __init__( self, @@ -165,6 +165,9 @@ def __init__( self.ack = 0 self.seq = 0 + # Terminating + self._terminating_task = None + self._start_heartbeat() self._start_serve_responses() self._start_close_session_checker() @@ -192,10 +195,12 @@ async def ensure_connected[HandshakeMetadata]( return def do_close() -> None: - # We can't just call self.close() directly because - # we're inside a thread that will eventually be awaited - # during the cleanup procedure. - self._terminating_task = asyncio.create_task(self.close()) + # Avoid closing twice + if self._terminating_task is None: + # We can't just call self.close() directly because + # we're inside a thread that will eventually be awaited + # during the cleanup procedure. + self._terminating_task = asyncio.create_task(self.close()) if not self._connecting_task: self._connecting_task = asyncio.create_task( @@ -566,10 +571,12 @@ async def block_until_connected() -> None: def _start_close_session_checker(self) -> None: def do_close() -> None: + # Avoid closing twice + if self._terminating_task is None: # We can't just call self.close() directly because # we're inside a thread that will eventually be awaited # during the cleanup procedure. 
- self._terminating_task = asyncio.create_task(self.close()) + self._terminating_task = asyncio.create_task(self.close()) self._task_manager.create_task( _check_to_close_session( From ee6ac68581ab1a59e0b264da24fe5dadb0b808ac Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 11:24:42 -0700 Subject: [PATCH 096/193] Debugging code --- src/replit_river/v2/client_transport.py | 3 --- src/replit_river/v2/session.py | 22 +++++----------------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index c4bde024..9690174e 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -4,9 +4,6 @@ import nanoid -from replit_river.error_schema import ( - RiverException, -) from replit_river.rate_limiter import LeakyBucketRateLimit from replit_river.transport_options import ( HandshakeMetadataType, diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 25967ec2..87ad91d6 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -258,17 +258,6 @@ async def _do_ensure_connected[HandshakeMetadata]( nextSentSeq=next_seq, ), ) - sb_state = None - ab_state = None - if self._send_buffer: - sb_state = [self._send_buffer[0].seq, self._send_buffer[0].ack] - if self._ack_buffer: - ab_state = [self._ack_buffer[0].seq, self._ack_buffer[0].ack] - logger.debug( - f"STATE{{seq={self.seq}, ack={self.ack}, next_seq={next_seq}, " - f"sb_state={sb_state}, ab_state={ab_state} }}" - ) - stream_id = nanoid.generate() async def websocket_closed_callback() -> None: logger.error("websocket closed before handshake response") @@ -277,7 +266,7 @@ async def websocket_closed_callback() -> None: TransportMessage( from_=self._transport_id, to=self._to_id, - streamId=stream_id, + streamId=nanoid.generate(), controlFlags=0, id=nanoid.generate(), seq=0, @@ -499,8 +488,7 @@ async def send_message( async def 
close(self) -> None: """Close the session and all associated streams.""" logger.info( - f"{self._transport_id} closing session " - f"to {self._to_id}, ws: {self._ws}" + f"{self._transport_id} closing session to {self._to_id}, ws: {self._ws}" ) if self._state in TerminalStates: # already closing @@ -573,9 +561,9 @@ def _start_close_session_checker(self) -> None: def do_close() -> None: # Avoid closing twice if self._terminating_task is None: - # We can't just call self.close() directly because - # we're inside a thread that will eventually be awaited - # during the cleanup procedure. + # We can't just call self.close() directly because + # we're inside a thread that will eventually be awaited + # during the cleanup procedure. self._terminating_task = asyncio.create_task(self.close()) self._task_manager.create_task( From f9d34ed7bb61924ec4b52899bcb0cfb6e8647743 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 11:26:26 -0700 Subject: [PATCH 097/193] Avoid using logging directly --- src/replit_river/v2/session.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 87ad91d6..9683a4f9 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -248,7 +248,7 @@ async def _do_ensure_connected[HandshakeMetadata]( next_seq = self._send_buffer[0].seq handshake_request = ControlMessageHandshakeRequest[ HandshakeMetadata - ]( # noqa: E501 + ]( type="HANDSHAKE_REQ", protocolVersion=protocol_version, sessionId=self.session_id, @@ -478,7 +478,7 @@ async def send_message( self._queue_full_lock.locked() or len(self._send_buffer) >= self._transport_options.buffer_size ): - logging.debug("send_message: queue full, waiting") + logger.debug("send_message: queue full, waiting") await self._queue_full_lock.acquire() self._send_buffer.append(msg) # Wake up buffered_message_sender @@ -638,7 +638,7 @@ def assert_incoming_seq_bookkeeping( ) -> 
Literal[True] | _IgnoreMessage: # Update bookkeeping if msg_seq < self.ack: - logging.info( + logger.info( f"{msg_from} received duplicate msg, got {msg_seq}" f" expected {self.ack}" ) @@ -1135,10 +1135,10 @@ async def _serve( idx = 0 try: while our_task and not our_task.cancelling() and not our_task.cancelled(): - logging.debug(f"_serve loop count={idx}") + logger.debug(f"_serve loop count={idx}") idx += 1 while (ws := get_ws()) is None or get_state() in ConnectingStates: - logging.debug("_handle_messages_from_ws spinning while connecting") + logger.debug("_handle_messages_from_ws spinning while connecting") await block_until_connected() logger.debug( "%s start handling messages from ws %s", @@ -1257,7 +1257,7 @@ async def _serve( except Exception: logger.exception("caught exception at message iterator") break - logging.debug("_handle_messages_from_ws exiting") + logger.debug("_handle_messages_from_ws exiting") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) if unhandled: @@ -1270,4 +1270,4 @@ async def _serve( exc_info=unhandled, ) raise unhandled - logging.debug(f"_serve exiting normally after {idx} loops") + logger.debug(f"_serve exiting normally after {idx} loops") From c26812f68c08b790bc82482e0c23987f09b3634b Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 11:51:50 -0700 Subject: [PATCH 098/193] No sense threading this all the way through --- src/replit_river/v2/client_transport.py | 3 --- src/replit_river/v2/session.py | 9 ++++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 9690174e..ebafefbe 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -12,8 +12,6 @@ ) from replit_river.v2.session import Session -PROTOCOL_VERSION = "v2.0" - logger = logging.getLogger(__name__) @@ -68,7 +66,6 @@ async def get_or_create_session(self) -> Session: 
client_id=self._client_id, rate_limiter=self._rate_limiter, uri_and_metadata_factory=self._uri_and_metadata_factory, - protocol_version=PROTOCOL_VERSION, ) return existing_session diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 9683a4f9..698de0c6 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -70,6 +70,8 @@ UriAndMetadata, ) +PROTOCOL_VERSION = "v2.0" + STREAM_CANCEL_BIT_TYPE = Literal[0b00100] STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 STREAM_CLOSED_BIT_TYPE = Literal[0b01000] @@ -180,7 +182,6 @@ async def ensure_connected[HandshakeMetadata]( uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ], # noqa: E501 - protocol_version: str, ) -> None: """ Either return immediately or establish a websocket connection and return @@ -208,7 +209,6 @@ def do_close() -> None: client_id, rate_limiter, uri_and_metadata_factory, - protocol_version, do_close, ) ) @@ -221,8 +221,7 @@ async def _do_ensure_connected[HandshakeMetadata]( rate_limiter: LeakyBucketRateLimit, uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] - ], # noqa: E501 - protocol_version: str, + ], do_close: Callable[[], None], ) -> Literal[True]: max_retry = self._transport_options.connection_retry_options.max_retry @@ -250,7 +249,7 @@ async def _do_ensure_connected[HandshakeMetadata]( HandshakeMetadata ]( type="HANDSHAKE_REQ", - protocolVersion=protocol_version, + protocolVersion=PROTOCOL_VERSION, sessionId=self.session_id, metadata=uri_and_metadata["metadata"], expectedSessionState=ExpectedSessionState( From 3401587adb5d6ae2eb138a2a12233aa28040960d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 12:06:09 -0700 Subject: [PATCH 099/193] Break out _do_ensure_connected --- src/replit_river/v2/session.py | 422 +++++++++++++++++---------------- 1 file changed, 220 insertions(+), 202 deletions(-) diff --git a/src/replit_river/v2/session.py 
b/src/replit_river/v2/session.py index 698de0c6..9c4856ac 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -195,6 +195,11 @@ async def ensure_connected[HandshakeMetadata]( if self.is_connected(): return + def get_next_sent_seq() -> int: + if self._send_buffer: + return self._send_buffer[0].seq + return self.seq + def do_close() -> None: # Avoid closing twice if self._terminating_task is None: @@ -203,188 +208,52 @@ def do_close() -> None: # during the cleanup procedure. self._terminating_task = asyncio.create_task(self.close()) + async def transition_connected(ws: ClientConnection) -> None: + self._state = SessionState.ACTIVE + self._ws = ws + + # We're connected, wake everybody up + async with self._connection_condition: + self._connection_condition.notify_all() + + def finalize_attempt() -> None: + # We are in a state where we may throw an exception. + # + # To allow subsequent calls to ensure_connected to pass, we clear ourselves. + # This is safe because each individual function that is waiting on this + # function completeing already has a reference, so we'll last a few ticks + # before GC. 
+ # + # Let's do our best to avoid clobbering other tasks by comparing the .name + current_task = asyncio.current_task() + if ( + self._connecting_task + and current_task + and self._connecting_task.get_name() == current_task.get_name() + ): + self._connecting_task = None + if not self._connecting_task: self._connecting_task = asyncio.create_task( - self._do_ensure_connected( - client_id, - rate_limiter, - uri_and_metadata_factory, - do_close, + _do_ensure_connected( + transport_id=self._transport_id, + client_id=client_id, + to_id=self._to_id, + session_id=self.session_id, + max_retry=self._transport_options.connection_retry_options.max_retry, + rate_limiter=rate_limiter, + uri_and_metadata_factory=uri_and_metadata_factory, + get_next_sent_seq=get_next_sent_seq, + get_current_ack=lambda: self.ack, + get_current_time=self._get_current_time, + transition_connected=transition_connected, + finalize_attempt=finalize_attempt, + do_close=do_close, ) ) await self._connecting_task - async def _do_ensure_connected[HandshakeMetadata]( - self, - client_id: str, - rate_limiter: LeakyBucketRateLimit, - uri_and_metadata_factory: Callable[ - [], Awaitable[UriAndMetadata[HandshakeMetadata]] - ], - do_close: Callable[[], None], - ) -> Literal[True]: - max_retry = self._transport_options.connection_retry_options.max_retry - logger.info("Attempting to establish new ws connection") - - last_error: Exception | None = None - i = 0 - while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error): - if i > 0: - logger.info(f"Retrying build handshake number {i} times") - i += 1 - - rate_limiter.consume_budget(client_id) - - ws = None - try: - uri_and_metadata = await uri_and_metadata_factory() - ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) - - try: - next_seq = 0 - if self._send_buffer: - next_seq = self._send_buffer[0].seq - handshake_request = ControlMessageHandshakeRequest[ - HandshakeMetadata - ]( - type="HANDSHAKE_REQ", - 
protocolVersion=PROTOCOL_VERSION, - sessionId=self.session_id, - metadata=uri_and_metadata["metadata"], - expectedSessionState=ExpectedSessionState( - nextExpectedSeq=self.ack, - nextSentSeq=next_seq, - ), - ) - - async def websocket_closed_callback() -> None: - logger.error("websocket closed before handshake response") - - await send_transport_message( - TransportMessage( - from_=self._transport_id, - to=self._to_id, - streamId=nanoid.generate(), - controlFlags=0, - id=nanoid.generate(), - seq=0, - ack=0, - payload=handshake_request.model_dump(), - ), - ws=ws, - websocket_closed_callback=websocket_closed_callback, - ) - except ( - WebsocketClosedException, - FailedSendingMessageException, - ) as e: # noqa: E501 - raise RiverException( - ERROR_HANDSHAKE, - "Handshake failed, conn closed while sending response", - ) from e - - startup_grace_deadline_ms = await self._get_current_time() + 60_000 - while True: - if await self._get_current_time() >= startup_grace_deadline_ms: - raise RiverException( - ERROR_HANDSHAKE, - "Handshake response timeout, closing connection", - ) - try: - data = await ws.recv(decode=False) - except ConnectionClosed as e: - logger.debug( - "Connection closed during waiting for handshake response", - exc_info=True, - ) - raise RiverException( - ERROR_HANDSHAKE, - "Handshake failed, conn closed while waiting for response", - ) from e - - try: - response_msg = parse_transport_msg(data) - break - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue - except InvalidMessageException as e: - raise RiverException( - ERROR_HANDSHAKE, - "Got invalid transport message, closing connection", - ) from e - - try: - handshake_response = ControlMessageHandshakeResponse( - **response_msg.payload - ) - logger.debug("river client waiting for handshake response") - except ValidationError as e: - raise RiverException( - ERROR_HANDSHAKE, "Failed to parse handshake response" - ) from e - - logger.debug( - "river client 
get handshake response : %r", handshake_response - ) # noqa: E501 - if not handshake_response.status.ok: - if ( - handshake_response.status.code - == ERROR_CODE_SESSION_STATE_MISMATCH - ): - do_close() - - raise RiverException( - ERROR_HANDSHAKE, - f"Handshake failed with code {handshake_response.status.code}: { - handshake_response.status.reason - }", - ) - - last_error = None - rate_limiter.start_restoring_budget(client_id) - self._state = SessionState.ACTIVE - self._ws = ws - - # We're connected, wake everybody up - async with self._connection_condition: - self._connection_condition.notify_all() - break - except Exception as e: - if ws: - await ws.close() - last_error = e - backoff_time = rate_limiter.get_backoff_ms(client_id) - logger.exception( - f"Error connecting, retrying with {backoff_time}ms backoff" - ) - await asyncio.sleep(backoff_time / 1000) - - # We are in a state where we may throw an exception. - # - # To permit subsequent calls to ensure_connected to pass, we clear ourselves. - # This is safe because each individual function that is waiting on this - # function completeing already has a reference, so we'll last a few ticks - # before GC. - # - # Let's do our best to avoid clobbering other tasks by comparing the .name - current_task = asyncio.current_task() - if ( - self._connecting_task - and current_task - and self._connecting_task.get_name() == current_task.get_name() - ): - self._connecting_task = None - - if last_error is not None: - raise RiverException( - ERROR_HANDSHAKE, - f"Failed to create ws after retrying {max_retry} number of times", - ) from last_error - - return True - def is_closed(self) -> bool: """ If the session is in a terminal state. 
@@ -987,33 +856,6 @@ async def send_close_stream( ) -async def _check_to_close_session( - transport_id: str, - close_session_check_interval_ms: float, - get_state: Callable[[], SessionState], - get_current_time: Callable[[], Awaitable[float]], - get_close_session_after_time_secs: Callable[[], float | None], - do_close: Callable[[], None], -) -> None: - our_task = asyncio.current_task() - while our_task and not our_task.cancelling() and not our_task.cancelled(): - await asyncio.sleep(close_session_check_interval_ms / 1000) - if get_state() in TerminalStates: - # already closing - break - # calculate the value now before comparing it so that there are no - # await points between the check and the comparison to avoid a TOCTOU - # race. - current_time = await get_current_time() - close_session_after_time_secs = get_close_session_after_time_secs() - if not close_session_after_time_secs: - continue - if current_time > close_session_after_time_secs: - logger.info("Grace period ended for %s, closing session", transport_id) - do_close() - our_task.cancel() - - async def _buffered_message_sender( block_until_connected: Callable[[], Awaitable[None]], message_enqueued: asyncio.Semaphore, @@ -1072,6 +914,182 @@ async def _buffered_message_sender( break +async def _check_to_close_session( + transport_id: str, + close_session_check_interval_ms: float, + get_state: Callable[[], SessionState], + get_current_time: Callable[[], Awaitable[float]], + get_close_session_after_time_secs: Callable[[], float | None], + do_close: Callable[[], None], +) -> None: + our_task = asyncio.current_task() + while our_task and not our_task.cancelling() and not our_task.cancelled(): + await asyncio.sleep(close_session_check_interval_ms / 1000) + if get_state() in TerminalStates: + # already closing + break + # calculate the value now before comparing it so that there are no + # await points between the check and the comparison to avoid a TOCTOU + # race. 
+ current_time = await get_current_time() + close_session_after_time_secs = get_close_session_after_time_secs() + if not close_session_after_time_secs: + continue + if current_time > close_session_after_time_secs: + logger.info("Grace period ended for %s, closing session", transport_id) + do_close() + our_task.cancel() + + +async def _do_ensure_connected[HandshakeMetadata]( + transport_id: str, + client_id: str, + to_id: str, + session_id: str, + max_retry: int, + rate_limiter: LeakyBucketRateLimit, + uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadata]] + ], + get_current_time: Callable[[], Awaitable[float]], + get_next_sent_seq: Callable[[], int], + get_current_ack: Callable[[], int], + transition_connected: Callable[[ClientConnection], Awaitable[None]], + finalize_attempt: Callable[[], None], + do_close: Callable[[], None], +) -> Literal[True]: + logger.info("Attempting to establish new ws connection") + + last_error: Exception | None = None + i = 0 + while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error): + if i > 0: + logger.info(f"Retrying build handshake number {i} times") + i += 1 + + rate_limiter.consume_budget(client_id) + + ws = None + try: + uri_and_metadata = await uri_and_metadata_factory() + ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) + + try: + handshake_request = ControlMessageHandshakeRequest[HandshakeMetadata]( + type="HANDSHAKE_REQ", + protocolVersion=PROTOCOL_VERSION, + sessionId=session_id, + metadata=uri_and_metadata["metadata"], + expectedSessionState=ExpectedSessionState( + nextExpectedSeq=get_current_ack(), + nextSentSeq=get_next_sent_seq(), + ), + ) + + async def websocket_closed_callback() -> None: + logger.error("websocket closed before handshake response") + + await send_transport_message( + TransportMessage( + from_=transport_id, + to=to_id, + streamId=nanoid.generate(), + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + 
payload=handshake_request.model_dump(), + ), + ws=ws, + websocket_closed_callback=websocket_closed_callback, + ) + except ( + WebsocketClosedException, + FailedSendingMessageException, + ) as e: # noqa: E501 + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while sending response", + ) from e + + startup_grace_deadline_ms = await get_current_time() + 60_000 + while True: + if await get_current_time() >= startup_grace_deadline_ms: + raise RiverException( + ERROR_HANDSHAKE, + "Handshake response timeout, closing connection", + ) + try: + data = await ws.recv(decode=False) + except ConnectionClosed as e: + logger.debug( + "Connection closed during waiting for handshake response", + exc_info=True, + ) + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while waiting for response", + ) from e + + try: + response_msg = parse_transport_msg(data) + break + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except InvalidMessageException as e: + raise RiverException( + ERROR_HANDSHAKE, + "Got invalid transport message, closing connection", + ) from e + + try: + handshake_response = ControlMessageHandshakeResponse( + **response_msg.payload + ) + logger.debug("river client waiting for handshake response") + except ValidationError as e: + raise RiverException( + ERROR_HANDSHAKE, "Failed to parse handshake response" + ) from e + + logger.debug("river client get handshake response : %r", handshake_response) # noqa: E501 + if not handshake_response.status.ok: + if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: + do_close() + + raise RiverException( + ERROR_HANDSHAKE, + f"Handshake failed with code {handshake_response.status.code}: { + handshake_response.status.reason + }", + ) + + last_error = None + rate_limiter.start_restoring_budget(client_id) + transition_connected(ws) + break + except Exception as e: + if ws: + await ws.close() + last_error = e + 
backoff_time = rate_limiter.get_backoff_ms(client_id) + logger.exception( + f"Error connecting, retrying with {backoff_time}ms backoff" + ) + await asyncio.sleep(backoff_time / 1000) + + finalize_attempt() + + if last_error is not None: + raise RiverException( + ERROR_HANDSHAKE, + f"Failed to create ws after retrying {max_retry} number of times", + ) from last_error + + return True + + async def _setup_heartbeat( block_until_connected: Callable[[], Awaitable[None]], session_id: str, From 3c655d722341cce539517592cf5a961a3c647704 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 12:07:28 -0700 Subject: [PATCH 100/193] send_message -> _send_message --- src/replit_river/v2/session.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 9c4856ac..8e0d110d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -294,7 +294,7 @@ def _reset_session_close_countdown(self) -> None: self._heartbeat_misses = 0 self._close_session_after_time_secs = None - async def send_message( + async def _send_message( self, stream_id: str, payload: dict[Any, Any] | str, @@ -308,7 +308,7 @@ async def send_message( if self._state in TerminalStates: return logger.debug( - "send_message(stream_id=%r, payload=%r, control_flags=%r, " + "_send_message(stream_id=%r, payload=%r, control_flags=%r, " "service_name=%r, procedure_name=%r)", stream_id, payload, @@ -346,7 +346,7 @@ async def send_message( self._queue_full_lock.locked() or len(self._send_buffer) >= self._transport_options.buffer_size ): - logger.debug("send_message: queue full, waiting") + logger.debug("_send_message: queue full, waiting") await self._queue_full_lock.acquire() self._send_buffer.append(msg) # Wake up buffered_message_sender @@ -551,7 +551,7 @@ async def block_until_connected() -> None: assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda 
stream_id: self._streams.get(stream_id), close_stream=close_stream, - send_message=self.send_message, + send_message=self._send_message, ) ) @@ -573,7 +573,7 @@ async def send_rpc[R, A]( stream_id = nanoid.generate() output: Channel[Any] = Channel(1) self._streams[stream_id] = output - await self.send_message( + await self._send_message( stream_id=stream_id, control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, payload=request_serializer(request), @@ -586,7 +586,7 @@ async def send_rpc[R, A]( async with asyncio.timeout(timeout.total_seconds()): response = await output.get() except asyncio.TimeoutError as e: - await self.send_message( + await self._send_message( stream_id=stream_id, control_flags=STREAM_CANCEL_BIT, payload={"type": "CANCEL"}, @@ -635,7 +635,7 @@ async def send_upload[I, R, A]( output: Channel[Any] = Channel(1) self._streams[stream_id] = output try: - await self.send_message( + await self._send_message( stream_id=stream_id, control_flags=STREAM_OPEN_BIT, service_name=service_name, @@ -650,7 +650,7 @@ async def send_upload[I, R, A]( # If this request is not closed and the session is killed, we should # throw exception here async for item in request: - await self.send_message( + await self._send_message( stream_id=stream_id, service_name=service_name, procedure_name=procedure_name, @@ -710,7 +710,7 @@ async def send_subscription[R, E, A]( stream_id = nanoid.generate() output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) self._streams[stream_id] = output - await self.send_message( + await self._send_message( service_name=service_name, procedure_name=procedure_name, stream_id=stream_id, @@ -766,7 +766,7 @@ async def send_stream[I, R, E, A]( output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) self._streams[stream_id] = output try: - await self.send_message( + await self._send_message( service_name=service_name, procedure_name=procedure_name, stream_id=stream_id, @@ -795,7 +795,7 @@ async def _encode_stream() -> None: async for item in request: if item 
is None: continue - await self.send_message( + await self._send_message( service_name=service_name, procedure_name=procedure_name, stream_id=stream_id, @@ -845,7 +845,7 @@ async def send_close_stream( extra_control_flags: int, ) -> None: # close stream - await self.send_message( + await self._send_message( service_name=service_name, procedure_name=procedure_name, stream_id=stream_id, From 39630e99d055f9adf3c241a715606339d1518b0b Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 26 Mar 2025 12:08:23 -0700 Subject: [PATCH 101/193] noqa --- src/replit_river/v2/session.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 8e0d110d..f67facff 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -181,7 +181,7 @@ async def ensure_connected[HandshakeMetadata]( rate_limiter: LeakyBucketRateLimit, uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] - ], # noqa: E501 + ], ) -> None: """ Either return immediately or establish a websocket connection and return @@ -1006,7 +1006,7 @@ async def websocket_closed_callback() -> None: except ( WebsocketClosedException, FailedSendingMessageException, - ) as e: # noqa: E501 + ) as e: raise RiverException( ERROR_HANDSHAKE, "Handshake failed, conn closed while sending response", @@ -1053,7 +1053,7 @@ async def websocket_closed_callback() -> None: ERROR_HANDSHAKE, "Failed to parse handshake response" ) from e - logger.debug("river client get handshake response : %r", handshake_response) # noqa: E501 + logger.debug("river client get handshake response : %r", handshake_response) if not handshake_response.status.ok: if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: do_close() @@ -1141,7 +1141,7 @@ async def _serve( close_session: Callable[[], Awaitable[None]], assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage - ], # noqa: 
E501 + ], get_stream: Callable[[str], Channel[Any] | None], close_stream: Callable[[str], None], send_message: SendMessage, From 312e20724ae016b8848b757a8ec88683013f93d3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 09:53:30 -0700 Subject: [PATCH 102/193] Just use asyncio.Event to represent "connected" --- src/replit_river/v2/session.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index f67facff..48140071 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -106,7 +106,7 @@ class Session: _close_session_callback: CloseSessionCallback _close_session_after_time_secs: float | None _connecting_task: asyncio.Task[Literal[True]] | None - _connection_condition: asyncio.Condition + _wait_for_connected: asyncio.Event # ws state _ws: ClientConnection | None @@ -145,7 +145,7 @@ def __init__( self._close_session_callback = close_session_callback self._close_session_after_time_secs: float | None = None self._connecting_task = None - self._connection_condition = asyncio.Condition() + self._wait_for_connected = asyncio.Event() # ws state self._ws = None @@ -208,13 +208,16 @@ def do_close() -> None: # during the cleanup procedure. self._terminating_task = asyncio.create_task(self.close()) - async def transition_connected(ws: ClientConnection) -> None: + def transition_connecting() -> None: + # "Clear" here means observers should wait until we are connected. + self._wait_for_connected.clear() + + def transition_connected(ws: ClientConnection) -> None: self._state = SessionState.ACTIVE self._ws = ws - # We're connected, wake everybody up - async with self._connection_condition: - self._connection_condition.notify_all() + # We're connected, wake everybody up using set() + self._wait_for_connected.set() def finalize_attempt() -> None: # We are in a state where we may throw an exception. 
@@ -246,6 +249,7 @@ def finalize_attempt() -> None: get_next_sent_seq=get_next_sent_seq, get_current_ack=lambda: self.ack, get_current_time=self._get_current_time, + transition_connecting=transition_connecting, transition_connected=transition_connected, finalize_attempt=finalize_attempt, do_close=do_close, @@ -364,8 +368,7 @@ async def close(self) -> None: self._state = SessionState.CLOSING # We need to wake up all tasks waiting for connection to be established - async with self._connection_condition: - self._connection_condition.notify_all() + self._wait_for_connected.clear() await self._task_manager.cancel_all_tasks() @@ -410,8 +413,7 @@ def get_ws() -> ClientConnection | None: return None async def block_until_connected() -> None: - async with self._connection_condition: - await self._connection_condition.wait() + await self._wait_for_connected.wait() self._task_manager.create_task( _buffered_message_sender( @@ -468,8 +470,7 @@ def increment_and_get_heartbeat_misses() -> int: return self._heartbeat_misses async def block_until_connected() -> None: - async with self._connection_condition: - await self._connection_condition.wait() + await self._wait_for_connected.wait() self._task_manager.create_task( _setup_heartbeat( @@ -535,8 +536,7 @@ def close_stream(stream_id: str) -> None: del self._streams[stream_id] async def block_until_connected() -> None: - async with self._connection_condition: - await self._connection_condition.wait() + await self._wait_for_connected.wait() self._task_manager.create_task( _serve( @@ -954,7 +954,8 @@ async def _do_ensure_connected[HandshakeMetadata]( get_current_time: Callable[[], Awaitable[float]], get_next_sent_seq: Callable[[], int], get_current_ack: Callable[[], int], - transition_connected: Callable[[ClientConnection], Awaitable[None]], + transition_connecting: Callable[[], None], + transition_connected: Callable[[ClientConnection], None], finalize_attempt: Callable[[], None], do_close: Callable[[], None], ) -> Literal[True]: @@ 
-968,6 +969,7 @@ async def _do_ensure_connected[HandshakeMetadata]( i += 1 rate_limiter.consume_budget(client_id) + transition_connecting() ws = None try: From 215ec79cdeff9627094cd1f4597b625ea3fdb6e2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 09:57:30 -0700 Subject: [PATCH 103/193] Terminate _serve early if we are terminal --- src/replit_river/v2/session.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 48140071..a9368514 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1105,6 +1105,7 @@ async def _setup_heartbeat( while True: while (state := get_state()) in ConnectingStates: await block_until_connected() + if state in TerminalStates: logger.debug( "Session is closed, no need to send heartbeat, state : " @@ -1156,9 +1157,17 @@ async def _serve( while our_task and not our_task.cancelling() and not our_task.cancelled(): logger.debug(f"_serve loop count={idx}") idx += 1 - while (ws := get_ws()) is None or get_state() in ConnectingStates: + while (ws := get_ws()) is None or (state := get_state()) in ConnectingStates: logger.debug("_handle_messages_from_ws spinning while connecting") await block_until_connected() + + if state in TerminalStates: + logger.debug( + f"Session is {state}, shut down _serve", + ) + # session is closing / closed, no need to serve anymore + break + logger.debug( "%s start handling messages from ws %s", "client", From dd371bb29062f8501fab326b7c14c1c3e161eaa7 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 10:23:25 -0700 Subject: [PATCH 104/193] Switch from "queue_full Lock to space_available Event --- src/replit_river/error_schema.py | 14 +++++++++ src/replit_river/v2/session.py | 51 +++++++++++++++++++------------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index 
fab1041b..0dfee049 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -17,6 +17,10 @@ # ERROR_CODE_CANCEL is the code used when either server or client cancels the stream. ERROR_CODE_CANCEL = "CANCEL" +# ERROR_CODE_SESSION_CLOSED is the code used when either server or client closes +# the session. +ERROR_CODE_SESSION_CLOSED = "CLOSED" + # ERROR_CODE_UNKNOWN is the code for the RiverUnknownError ERROR_CODE_UNKNOWN = "UNKNOWN" @@ -78,6 +82,16 @@ class StreamClosedRiverServiceException(RiverServiceException): pass +class SessionClosedRiverServiceException(RiverServiceException): + def __init__( + self, + message: str, + service: str | None, + procedure: str | None, + ) -> None: + super().__init__(ERROR_CODE_SESSION_CLOSED, message, service, procedure) + + def exception_from_message(code: str) -> type[RiverServiceException]: """Return the error class for a given error code.""" if code == ERROR_CODE_STREAM_CLOSED: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index a9368514..184198cb 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -39,6 +39,7 @@ RiverError, RiverException, RiverServiceException, + SessionClosedRiverServiceException, StreamClosedRiverServiceException, exception_from_message, ) @@ -154,8 +155,10 @@ def __init__( # message state self._message_enqueued = asyncio.Semaphore() - self._space_available_cond = asyncio.Condition() - self._queue_full_lock = asyncio.Lock() + self._space_available = asyncio.Event() + # Ensure we initialize the above Event to "set" to avoid being blocked from + # the beginning. + self._space_available.set() # stream for tasks self._streams: dict[str, Channel[Any]] = {} @@ -337,22 +340,24 @@ async def _send_message( with use_span(span): trace_propagator.inject(msg, None, trace_setter) - # As we prepare to push onto the buffer, if the buffer is full, we lock. 
- # This lock will be released by the buffered_message_sender task, so it's - # important that we don't release it here. - # - # The reason for this is that in Python, asyncio.Lock is "fair", first - # come, first served. - # - # If somebody else is already waiting or we've filled the buffer, we - # should get in line. - if ( - self._queue_full_lock.locked() - or len(self._send_buffer) >= self._transport_options.buffer_size - ): - logger.debug("_send_message: queue full, waiting") - await self._queue_full_lock.acquire() + # Ensure the buffer isn't full before we enqueue + await self._space_available.wait() + + # Before we append, do an important check + if self._state in TerminalStates: + # session is closing / closed, raise + raise SessionClosedRiverServiceException( + "river session is closed, dropping message", + service_name, + procedure_name, + ) + self._send_buffer.append(msg) + + # If the buffer is now full, reset the block + if len(self._send_buffer) >= self._transport_options.buffer_size: + self._space_available.clear() + # Wake up buffered_message_sender self._message_enqueued.release() self.seq += 1 @@ -368,7 +373,10 @@ async def close(self) -> None: self._state = SessionState.CLOSING # We need to wake up all tasks waiting for connection to be established - self._wait_for_connected.clear() + self._wait_for_connected.set() + + # We also need to wake up consumers waiting to enqueue messages + self._space_available.set() await self._task_manager.cancel_all_tasks() @@ -399,8 +407,7 @@ def commit(msg: TransportMessage) -> None: self._ack_buffer.append(pending) # On commit, release pending writers waiting for more buffer space - if self._queue_full_lock.locked(): - self._queue_full_lock.release() + self._space_available.set() def get_next_pending() -> TransportMessage | None: if self._send_buffer: @@ -1157,7 +1164,9 @@ async def _serve( while our_task and not our_task.cancelling() and not our_task.cancelled(): logger.debug(f"_serve loop count={idx}") idx += 1 
- while (ws := get_ws()) is None or (state := get_state()) in ConnectingStates: + while (ws := get_ws()) is None or ( + state := get_state() + ) in ConnectingStates: logger.debug("_handle_messages_from_ws spinning while connecting") await block_until_connected() From 05a1f3c115ab7750af5d2518f68ef54e6c5693d6 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 10:36:26 -0700 Subject: [PATCH 105/193] Switch message_enqueued semaphore to Event to avoid out-of-sync bugs Semaphore length and _send_buffer were maintained 1:1, but it still left the opportunity for bugs in the future. Switching to an Event lets us only ever care about the length of the _send_buffer. --- src/replit_river/v2/session.py | 43 +++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 184198cb..2824660c 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -114,6 +114,10 @@ class Session: _heartbeat_misses: int _retry_connection_callback: RetryConnectionCallback | None + # message state + _process_messages: asyncio.Event + _space_available: asyncio.Event + # stream for tasks _streams: dict[str, Channel[Any]] @@ -154,7 +158,7 @@ def __init__( self._retry_connection_callback = retry_connection_callback # message state - self._message_enqueued = asyncio.Semaphore() + self._process_messages = asyncio.Event() self._space_available = asyncio.Event() # Ensure we initialize the above Event to "set" to avoid being blocked from # the beginning. 
@@ -359,7 +363,7 @@ async def _send_message( self._space_available.clear() # Wake up buffered_message_sender - self._message_enqueued.release() + self._process_messages.set() self.seq += 1 async def close(self) -> None: @@ -372,11 +376,13 @@ async def close(self) -> None: return self._state = SessionState.CLOSING - # We need to wake up all tasks waiting for connection to be established + # We're closing, so we need to wake up... + # ... tasks waiting for connection to be established self._wait_for_connected.set() - - # We also need to wake up consumers waiting to enqueue messages + # ... consumers waiting to enqueue messages self._space_available.set() + # ... message processor so it can exit cleanly + self._process_messages.set() await self._task_manager.cancel_all_tasks() @@ -406,8 +412,12 @@ def commit(msg: TransportMessage) -> None: logger.error("Out of sequence error") self._ack_buffer.append(pending) - # On commit, release pending writers waiting for more buffer space + # On commit... + # ... release pending writers waiting for more buffer space self._space_available.set() + # ... 
tell the message sender to back off if there are no pending messages + if not self._send_buffer: + self._process_messages.clear() def get_next_pending() -> TransportMessage | None: if self._send_buffer: @@ -422,10 +432,13 @@ def get_ws() -> ClientConnection | None: async def block_until_connected() -> None: await self._wait_for_connected.wait() + async def block_until_message_available() -> None: + await self._process_messages.wait() + self._task_manager.create_task( _buffered_message_sender( block_until_connected=block_until_connected, - message_enqueued=self._message_enqueued, + block_until_message_available=block_until_message_available, get_ws=get_ws, websocket_closed_callback=self._begin_close_session_countdown, get_next_pending=get_next_pending, @@ -865,7 +878,7 @@ async def send_close_stream( async def _buffered_message_sender( block_until_connected: Callable[[], Awaitable[None]], - message_enqueued: asyncio.Semaphore, + block_until_message_available: Callable[[], Awaitable[None]], get_ws: Callable[[], ClientConnection | None], websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], get_next_pending: Callable[[], TransportMessage | None], @@ -874,7 +887,12 @@ async def _buffered_message_sender( ) -> None: our_task = asyncio.current_task() while our_task and not our_task.cancelling() and not our_task.cancelled(): - await message_enqueued.acquire() + await block_until_message_available() + + if get_state() in TerminalStates: + logger.debug("buffered_message_sender: closing") + return + while (ws := get_ws()) is None: # Block until we have a handle logger.debug( @@ -882,10 +900,6 @@ async def _buffered_message_sender( ) await block_until_connected() - if get_state() in TerminalStates: - logger.debug("We're going away!") - return - if not ws: logger.debug("ws is not connected, loop") continue @@ -906,18 +920,15 @@ async def _buffered_message_sender( type(e), exc_info=e, ) - message_enqueued.release() break except FailedSendingMessageException: 
logger.error( "Failed sending message, waiting for retry from buffer", exc_info=True, ) - message_enqueued.release() break except Exception: logger.exception("Error attempting to send buffered messages") - message_enqueued.release() break From 14df5c27d725b3c4989ef755539316c97b330348 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 10:51:19 -0700 Subject: [PATCH 106/193] Better background task management --- src/replit_river/v2/session.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 2824660c..c65ed5f2 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -226,6 +226,9 @@ def transition_connected(ws: ClientConnection) -> None: # We're connected, wake everybody up using set() self._wait_for_connected.set() + def close_ws_in_background(ws: ClientConnection) -> None: + self._task_manager.create_task(ws.close()) + def finalize_attempt() -> None: # We are in a state where we may throw an exception. 
# @@ -239,7 +242,7 @@ def finalize_attempt() -> None: if ( self._connecting_task and current_task - and self._connecting_task.get_name() == current_task.get_name() + and self._connecting_task is current_task ): self._connecting_task = None @@ -257,6 +260,7 @@ def finalize_attempt() -> None: get_current_ack=lambda: self.ack, get_current_time=self._get_current_time, transition_connecting=transition_connecting, + close_ws_in_background=close_ws_in_background, transition_connected=transition_connected, finalize_attempt=finalize_attempt, do_close=do_close, @@ -973,6 +977,7 @@ async def _do_ensure_connected[HandshakeMetadata]( get_next_sent_seq: Callable[[], int], get_current_ack: Callable[[], int], transition_connecting: Callable[[], None], + close_ws_in_background: Callable[[ClientConnection], None], transition_connected: Callable[[ClientConnection], None], finalize_attempt: Callable[[], None], do_close: Callable[[], None], @@ -989,7 +994,7 @@ async def _do_ensure_connected[HandshakeMetadata]( rate_limiter.consume_budget(client_id) transition_connecting() - ws = None + ws: ClientConnection | None = None try: uri_and_metadata = await uri_and_metadata_factory() ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) @@ -1085,13 +1090,15 @@ async def websocket_closed_callback() -> None: }", ) + # We did it! We're connected! 
last_error = None rate_limiter.start_restoring_budget(client_id) transition_connected(ws) break except Exception as e: if ws: - await ws.close() + close_ws_in_background(ws) + ws = None last_error = e backoff_time = rate_limiter.get_backoff_ms(client_id) logger.exception( From c0e835bfcf3f80086b4198fb39aa09d22fc92cfc Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 10:51:56 -0700 Subject: [PATCH 107/193] v2 ack payload type --- src/replit_river/v2/session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index c65ed5f2..d209fd45 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1245,7 +1245,6 @@ async def _serve( # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 payload={ "type": "ACK", - "ack": 0, }, control_flags=ACK_BIT, procedure_name=None, From 480e6da976c37db54121ed724909be1da541555c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 11:09:30 -0700 Subject: [PATCH 108/193] Tweaking debug logs --- src/replit_river/v2/session.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index d209fd45..3b100625 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -474,7 +474,7 @@ def do_close() -> None: def _start_heartbeat(self) -> None: async def close_websocket() -> None: logger.debug( - "do_close called, _state=%r, _ws=%r", + "close_websocket called, _state=%r, _ws=%r", self._state, self._ws, ) @@ -894,23 +894,23 @@ async def _buffered_message_sender( await block_until_message_available() if get_state() in TerminalStates: - logger.debug("buffered_message_sender: closing") + logger.debug("_buffered_message_sender: closing") return while (ws := get_ws()) is None: # Block until we have a handle logger.debug( - "buffered_message_sender: Waiting 
until ws is connected", + "_buffered_message_sender: Waiting until ws is connected", ) await block_until_connected() if not ws: - logger.debug("ws is not connected, loop") + logger.debug("_buffered_message_sender: ws is not connected, loop") continue if msg := get_next_pending(): logger.debug( - "buffered_message_sender: Dequeued %r to send over %r", + "_buffered_message_sender: Dequeued %r to send over %r", msg, ws, ) @@ -919,8 +919,8 @@ async def _buffered_message_sender( commit(msg) except WebsocketClosedException as e: logger.debug( - "Connection closed while sending message %r, waiting for " - "retry from buffer", + "_buffered_message_sender: Connection closed while sending " + "message %r, waiting for retry from buffer", type(e), exc_info=e, ) @@ -1048,7 +1048,8 @@ async def websocket_closed_callback() -> None: data = await ws.recv(decode=False) except ConnectionClosed as e: logger.debug( - "Connection closed during waiting for handshake response", + "_do_ensure_connected: Connection closed during waiting " + "for handshake response", exc_info=True, ) raise RiverException( @@ -1060,7 +1061,10 @@ async def websocket_closed_callback() -> None: response_msg = parse_transport_msg(data) break except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) + logger.debug( + "_do_ensure_connected: Ignoring transport message", + exc_info=True, + ) continue except InvalidMessageException as e: raise RiverException( From 817b4c967c437c134cd7fa6d142776a98393e3b9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 11:19:16 -0700 Subject: [PATCH 109/193] Avoid exceptions for flow control --- src/replit_river/v2/session.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 3b100625..a4a50a5c 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -60,7 +60,6 @@ TransportMessageTracingSetter, ) from 
replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, OutOfOrderMessageException, ) @@ -1059,13 +1058,14 @@ async def websocket_closed_callback() -> None: try: response_msg = parse_transport_msg(data) + if isinstance(response_msg, str): + logger.debug( + "_do_ensure_connected: Ignoring transport message", + exc_info=True, + ) + continue + break - except IgnoreMessageException: - logger.debug( - "_do_ensure_connected: Ignoring transport message", - exc_info=True, - ) - continue except InvalidMessageException as e: raise RiverException( ERROR_HANDSHAKE, @@ -1217,6 +1217,9 @@ async def _serve( transport_id, msg, ) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue if msg.controlFlags & STREAM_OPEN_BIT != 0: raise InvalidMessageException( From 4fdbdbca8e05f63cbb98ac0dd0545ec87a474d77 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 13:10:00 -0700 Subject: [PATCH 110/193] Renaming PENDING to NO_CONNECTION to match TS --- src/replit_river/common_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 551f5545..051cc923 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -26,7 +26,7 @@ class SessionState(enum.Enum): """The state a session can be in. 
Valid transitions: - - NO_CONNECTION -> {CONNECTING} + - NO_CONNECTION -> {CONNECTING, CLOSING} - CONNECTING -> {NO_CONNECTION, ACTIVE, CLOSING} - ACTIVE -> {NO_CONNECTION, CONNECTING, CLOSING} - CLOSING -> {CLOSED} From 096200ec2c4dd09a2914aef653afd8410f0549ad Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 13:26:48 -0700 Subject: [PATCH 111/193] Update dead code to be the place where this method lives --- src/replit_river/codegen/client.py | 5 -- src/replit_river/codegen/run.py | 1 - src/replit_river/common_session.py | 66 ++++++++++++++++++- src/replit_river/v2/session.py | 59 +---------------- .../snapshot/codegen_snapshot_fixtures.py | 1 - tests/codegen/test_rpc.py | 1 - 6 files changed, 67 insertions(+), 66 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 13282d48..42e87a0f 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -1096,7 +1096,6 @@ def generate_individual_service( input_base_class: Literal["TypedDict"] | Literal["BaseModel"], method_filter: set[str] | None, protocol_version: Literal["v1.1", "v2.0"], - method_filter: set[str] | None, ) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] @@ -1395,7 +1394,6 @@ def generate_river_client_module( typed_dict_inputs: bool, method_filter: set[str] | None, protocol_version: Literal["v1.1", "v2.0"], - method_filter: set[str] | None, ) -> dict[RenderedPath, FileContents]: files: dict[RenderedPath, FileContents] = {} @@ -1425,7 +1423,6 @@ def generate_river_client_module( input_base_class, method_filter, protocol_version, - method_filter, ) if emitted_files: # Short-cut if we didn't actually emit anything @@ -1452,7 +1449,6 @@ def schema_to_river_client_codegen( file_opener: Callable[[Path], TextIO], method_filter: set[str] | None, protocol_version: Literal["v1.1", "v2.0"], - method_filter: set[str] | 
None, ) -> None: """Generates the lines of a River module.""" with read_schema() as f: @@ -1463,7 +1459,6 @@ def schema_to_river_client_codegen( typed_dict_inputs, method_filter, protocol_version, - method_filter, ).items(): module_path = Path(target_path).joinpath(subpath) module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index 1a6801f5..3fbc8d78 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -89,7 +89,6 @@ def file_opener(path: Path) -> TextIO: file_opener, method_filter=method_filter, protocol_version=args.protocol_version, - method_filter=method_filter, ) else: raise NotImplementedError(f"Unknown command {args.command}") diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 051cc923..e3824206 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -1,11 +1,19 @@ +import asyncio import enum import logging -from typing import Any, Protocol +from typing import Any, Awaitable, Callable, Coroutine, Protocol from opentelemetry.trace import Span from websockets import WebSocketCommonProtocol from websockets.asyncio.client import ClientConnection +from replit_river.messages import ( + FailedSendingMessageException, + WebsocketClosedException, + send_transport_message, +) +from replit_river.rpc import TransportMessage + logger = logging.getLogger(__name__) @@ -42,3 +50,59 @@ class SessionState(enum.Enum): ConnectingStates = set([SessionState.NO_CONNECTION, SessionState.CONNECTING]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) + + +async def buffered_message_sender( + block_until_connected: Callable[[], Awaitable[None]], + block_until_message_available: Callable[[], Awaitable[None]], + get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None], + websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], + get_next_pending: 
Callable[[], TransportMessage | None], + commit: Callable[[TransportMessage], None], + get_state: Callable[[], SessionState], +) -> None: + our_task = asyncio.current_task() + while our_task and not our_task.cancelling() and not our_task.cancelled(): + await block_until_message_available() + + if get_state() in TerminalStates: + logger.debug("_buffered_message_sender: closing") + return + + while (ws := get_ws()) is None: + # Block until we have a handle + logger.debug( + "_buffered_message_sender: Waiting until ws is connected", + ) + await block_until_connected() + + if not ws: + logger.debug("_buffered_message_sender: ws is not connected, loop") + continue + + if msg := get_next_pending(): + logger.debug( + "_buffered_message_sender: Dequeued %r to send over %r", + msg, + ws, + ) + try: + await send_transport_message(msg, ws, websocket_closed_callback) + commit(msg) + except WebsocketClosedException as e: + logger.debug( + "_buffered_message_sender: Connection closed while sending " + "message %r, waiting for retry from buffer", + type(e), + exc_info=e, + ) + break + except FailedSendingMessageException: + logger.error( + "Failed sending message, waiting for retry from buffer", + exc_info=True, + ) + break + except Exception: + logger.exception("Error attempting to send buffered messages") + break diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index a4a50a5c..5d4057fb 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -30,6 +30,7 @@ SendMessage, SessionState, TerminalStates, + buffered_message_sender, ) from replit_river.error_schema import ( ERROR_CODE_CANCEL, @@ -439,7 +440,7 @@ async def block_until_message_available() -> None: await self._process_messages.wait() self._task_manager.create_task( - _buffered_message_sender( + buffered_message_sender( block_until_connected=block_until_connected, block_until_message_available=block_until_message_available, get_ws=get_ws, @@ -879,62 +880,6 @@ async def 
send_close_stream( ) -async def _buffered_message_sender( - block_until_connected: Callable[[], Awaitable[None]], - block_until_message_available: Callable[[], Awaitable[None]], - get_ws: Callable[[], ClientConnection | None], - websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], - get_next_pending: Callable[[], TransportMessage | None], - commit: Callable[[TransportMessage], None], - get_state: Callable[[], SessionState], -) -> None: - our_task = asyncio.current_task() - while our_task and not our_task.cancelling() and not our_task.cancelled(): - await block_until_message_available() - - if get_state() in TerminalStates: - logger.debug("_buffered_message_sender: closing") - return - - while (ws := get_ws()) is None: - # Block until we have a handle - logger.debug( - "_buffered_message_sender: Waiting until ws is connected", - ) - await block_until_connected() - - if not ws: - logger.debug("_buffered_message_sender: ws is not connected, loop") - continue - - if msg := get_next_pending(): - logger.debug( - "_buffered_message_sender: Dequeued %r to send over %r", - msg, - ws, - ) - try: - await send_transport_message(msg, ws, websocket_closed_callback) - commit(msg) - except WebsocketClosedException as e: - logger.debug( - "_buffered_message_sender: Connection closed while sending " - "message %r, waiting for retry from buffer", - type(e), - exc_info=e, - ) - break - except FailedSendingMessageException: - logger.error( - "Failed sending message, waiting for retry from buffer", - exc_info=True, - ) - break - except Exception: - logger.exception("Error attempting to send buffered messages") - break - - async def _check_to_close_session( transport_id: str, close_session_check_interval_ms: float, diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/codegen/snapshot/codegen_snapshot_fixtures.py index 9b2e73e2..2fdff907 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ 
b/tests/codegen/snapshot/codegen_snapshot_fixtures.py @@ -37,7 +37,6 @@ def file_opener(path: Path) -> TextIO: typed_dict_inputs=typeddict_inputs, method_filter=None, protocol_version="v1.1", - method_filter=None, ) for path, file in files.items(): file.seek(0) diff --git a/tests/codegen/test_rpc.py b/tests/codegen/test_rpc.py index 9d24c99f..450a74f0 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/codegen/test_rpc.py @@ -34,7 +34,6 @@ def file_opener(path: Path) -> TextIO: file_opener=file_opener, method_filter=None, protocol_version="v1.1", - method_filter=None, ) From 290393c2f6bfdf925d73e6efd150a4041fbc920b Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 13:33:03 -0700 Subject: [PATCH 112/193] None of these were async either --- src/replit_river/client_session.py | 2 +- src/replit_river/message_buffer.py | 4 ++-- src/replit_river/server_session.py | 2 +- src/replit_river/session.py | 2 +- tests/test_message_buffer.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index c37768b7..661bd966 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -147,7 +147,7 @@ async def _handle_messages_from_ws(self) -> None: case other: assert_never(other) - await self._buffer.remove_old_messages( + self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) self._reset_session_close_countdown() diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index 6e1fdad7..d5c434ed 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -47,13 +47,13 @@ def peek(self) -> TransportMessage | None: return None return self.buffer[0] - async def remove_old_messages(self, min_seq: int) -> None: + def remove_old_messages(self, min_seq: int) -> None: """Remove messages in the buffer with a seq number less than min_seq.""" self.buffer = [msg for msg in self.buffer if 
msg.seq >= min_seq] async with self._space_available_cond: self._space_available_cond.notify_all() - async def close(self) -> None: + def close(self) -> None: """ Closes the message buffer and rejects any pending put operations. """ diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index c397e900..f0eb70e6 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -136,7 +136,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: pass case other: assert_never(other) - await self._buffer.remove_old_messages( + self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) self._reset_session_close_countdown() diff --git a/src/replit_river/session.py b/src/replit_river/session.py index b04463a3..140b7066 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -303,7 +303,7 @@ async def close(self) -> None: await self.close_websocket(self._ws_wrapper, should_retry=False) - await self._buffer.close() + self._buffer.close() # Clear the session in transports await self._close_session_callback(self) diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py index 02a21ccb..a8fb9b57 100644 --- a/tests/test_message_buffer.py +++ b/tests/test_message_buffer.py @@ -45,7 +45,7 @@ async def put_messages() -> None: # Wait for the put call to return. 
await sync_events.get() assert len(buffer.buffer) == 1 - await buffer.remove_old_messages(i) + buffer.remove_old_messages(i) await background_puts From d8a9887f8c2e9fe9d1f2fe960ac54596dd7f6be7 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 13:52:51 -0700 Subject: [PATCH 113/193] v2 avoid message ordering issues --- src/replit_river/v2/session.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 5d4057fb..2d28c52e 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -331,6 +331,19 @@ async def _send_message( service_name, procedure_name, ) + # Ensure the buffer isn't full before we enqueue + await self._space_available.wait() + + # Before we append, do an important check + if self._state in TerminalStates: + # session is closing / closed, raise + raise SessionClosedRiverServiceException( + "river session is closed, dropping message", + service_name, + procedure_name, + ) + + # Begin critical section: Avoid any await between here and _send_buffer.append msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), @@ -348,18 +361,6 @@ async def _send_message( with use_span(span): trace_propagator.inject(msg, None, trace_setter) - # Ensure the buffer isn't full before we enqueue - await self._space_available.wait() - - # Before we append, do an important check - if self._state in TerminalStates: - # session is closing / closed, raise - raise SessionClosedRiverServiceException( - "river session is closed, dropping message", - service_name, - procedure_name, - ) - self._send_buffer.append(msg) # If the buffer is now full, reset the block From 7bbc2a9c6ea8289102219076c2c427a7a60dc208 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 27 Mar 2025 13:59:19 -0700 Subject: [PATCH 114/193] Permit sensible message_buffer state transitions --- src/replit_river/client_session.py | 2 +- 
src/replit_river/codegen/run.py | 16 +++++----------- src/replit_river/message_buffer.py | 15 +++++++++++++-- src/replit_river/server_session.py | 2 +- src/replit_river/session.py | 2 +- tests/test_message_buffer.py | 2 +- 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 661bd966..c37768b7 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -147,7 +147,7 @@ async def _handle_messages_from_ws(self) -> None: case other: assert_never(other) - self._buffer.remove_old_messages( + await self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) self._reset_session_close_countdown() diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index 3fbc8d78..5a69f1dd 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -52,12 +52,6 @@ def main() -> None: default="v1.1", choices=["v1.1", "v2.0"], ) - client.add_argument( - "--method-filter", - help="Only generate a subset of the specified methods", - action="store", - type=pathlib.Path, - ) client.add_argument("schema", help="schema file") args = parser.parse_args() @@ -82,11 +76,11 @@ def file_opener(path: Path) -> TextIO: method_filter = set(x.strip() for x in handle.readlines()) schema_to_river_client_codegen( - lambda: open(schema_path), - target_path, - args.client_name, - args.typed_dict_inputs, - file_opener, + read_schema=lambda: open(schema_path), + target_path=target_path, + client_name=args.client_name, + typed_dict_inputs=args.typed_dict_inputs, + file_opener=file_opener, method_filter=method_filter, protocol_version=args.protocol_version, ) diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index d5c434ed..5c3b8f44 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -17,6 +17,7 @@ class MessageBuffer: def __init__(self, max_num_messages: int = 
MAX_MESSAGE_BUFFER_SIZE): self.max_size = max_num_messages self.buffer: list[TransportMessage] = [] + self._has_messages = asyncio.Event() self._space_available_cond = asyncio.Condition() self._closed = False @@ -35,6 +36,7 @@ def put(self, message: TransportMessage) -> None: if self._closed: raise MessageBufferClosedError("message buffer is closed") self.buffer.append(message) + self._has_messages.set() def get_next_sent_seq(self) -> int | None: if self.buffer: @@ -47,16 +49,25 @@ def peek(self) -> TransportMessage | None: return None return self.buffer[0] - def remove_old_messages(self, min_seq: int) -> None: + async def remove_old_messages(self, min_seq: int) -> None: """Remove messages in the buffer with a seq number less than min_seq.""" self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] + if self.buffer: + self._has_messages.set() + else: + self._has_messages.clear() async with self._space_available_cond: self._space_available_cond.notify_all() - def close(self) -> None: + async def block_until_message_available(self) -> None: + """Allow consumers to avoid spinning unnecessarily""" + await self._has_messages.wait() + + async def close(self) -> None: """ Closes the message buffer and rejects any pending put operations. 
""" self._closed = True + self._has_messages.set() async with self._space_available_cond: self._space_available_cond.notify_all() diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index f0eb70e6..c397e900 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -136,7 +136,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: pass case other: assert_never(other) - self._buffer.remove_old_messages( + await self._buffer.remove_old_messages( self._seq_manager.receiver_ack, ) self._reset_session_close_countdown() diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 140b7066..b04463a3 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -303,7 +303,7 @@ async def close(self) -> None: await self.close_websocket(self._ws_wrapper, should_retry=False) - self._buffer.close() + await self._buffer.close() # Clear the session in transports await self._close_session_callback(self) diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py index a8fb9b57..02a21ccb 100644 --- a/tests/test_message_buffer.py +++ b/tests/test_message_buffer.py @@ -45,7 +45,7 @@ async def put_messages() -> None: # Wait for the put call to return. 
await sync_events.get() assert len(buffer.buffer) == 1 - buffer.remove_old_messages(i) + await buffer.remove_old_messages(i) await background_puts From ef53cc81f6fa2b770d93532b2d2ac86774866d30 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 30 Mar 2025 13:54:03 -0700 Subject: [PATCH 115/193] Clarify SessionClosedRiverServiceException --- src/replit_river/error_schema.py | 12 +++++------- src/replit_river/v2/session.py | 2 -- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index 0dfee049..5bff801a 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -17,9 +17,9 @@ # ERROR_CODE_CANCEL is the code used when either server or client cancels the stream. ERROR_CODE_CANCEL = "CANCEL" -# ERROR_CODE_SESSION_CLOSED is the code used when either server or client closes -# the session. -ERROR_CODE_SESSION_CLOSED = "CLOSED" +# SYNTHETIC_ERROR_CODE_SESSION_CLOSED is a synthetic code emitted exclusively by the +# client's session. It is not sent over the wire. 
+SYNTHETIC_ERROR_CODE_SESSION_CLOSED = "SESSION_CLOSED" # ERROR_CODE_UNKNOWN is the code for the RiverUnknownError ERROR_CODE_UNKNOWN = "UNKNOWN" @@ -82,14 +82,12 @@ class StreamClosedRiverServiceException(RiverServiceException): pass -class SessionClosedRiverServiceException(RiverServiceException): +class SessionClosedRiverServiceException(RiverException): def __init__( self, message: str, - service: str | None, - procedure: str | None, ) -> None: - super().__init__(ERROR_CODE_SESSION_CLOSED, message, service, procedure) + super().__init__(SYNTHETIC_ERROR_CODE_SESSION_CLOSED, message) def exception_from_message(code: str) -> type[RiverServiceException]: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 2d28c52e..e01b820d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -339,8 +339,6 @@ async def _send_message( # session is closing / closed, raise raise SessionClosedRiverServiceException( "river session is closed, dropping message", - service_name, - procedure_name, ) # Begin critical section: Avoid any await between here and _send_buffer.append From a9cf61e4339a67db2a185f83e3500f8f9f81c100 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 10:47:03 -0700 Subject: [PATCH 116/193] Private method --- src/replit_river/v2/session.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e01b820d..cd535706 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -685,7 +685,7 @@ async def send_upload[I, R, A]( raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name ) from e - await self.send_close_stream( + await self._send_close_stream( service_name, procedure_name, stream_id, @@ -805,7 +805,7 @@ async def send_stream[I, R, E, A]( # Create the encoder task async def _encode_stream() -> None: if not request: - await 
self.send_close_stream( + await self._send_close_stream( service_name, procedure_name, stream_id, @@ -825,7 +825,7 @@ async def _encode_stream() -> None: control_flags=0, payload=request_serializer(item), ) - await self.send_close_stream( + await self._send_close_stream( service_name, procedure_name, stream_id, @@ -860,7 +860,7 @@ async def _encode_stream() -> None: finally: output.close() - async def send_close_stream( + async def _send_close_stream( self, service_name: str, procedure_name: str, From 0d900209f1dc37b86592136a732cd46374d6a2bd Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 10:50:39 -0700 Subject: [PATCH 117/193] Add span to send_close_stream --- src/replit_river/v2/session.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index cd535706..e0f5a899 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -690,6 +690,7 @@ async def send_upload[I, R, A]( procedure_name, stream_id, extra_control_flags=0, + span=span, ) # Handle potential errors during communication @@ -810,6 +811,7 @@ async def _encode_stream() -> None: procedure_name, stream_id, extra_control_flags=STREAM_OPEN_BIT, + span=span, ) return @@ -830,6 +832,7 @@ async def _encode_stream() -> None: procedure_name, stream_id, extra_control_flags=0, + span=span, ) self._task_manager.create_task(_encode_stream()) @@ -866,6 +869,7 @@ async def _send_close_stream( procedure_name: str, stream_id: str, extra_control_flags: int, + span: Span, ) -> None: # close stream await self._send_message( @@ -873,9 +877,8 @@ async def _send_close_stream( procedure_name=procedure_name, stream_id=stream_id, control_flags=STREAM_CLOSED_BIT | extra_control_flags, - payload={ - "type": "CLOSE", - }, + payload={"type": "CLOSE"}, + span=span, ) From 6798a94994c6aea475ce901a816ffc76d71077bf Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 10:50:51 -0700 Subject: 
[PATCH 118/193] Break out send_cancel_stream --- src/replit_river/v2/session.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e0f5a899..254dcf03 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -609,12 +609,9 @@ async def send_rpc[R, A]( async with asyncio.timeout(timeout.total_seconds()): response = await output.get() except asyncio.TimeoutError as e: - await self._send_message( + await self._send_cancel_stream( stream_id=stream_id, - control_flags=STREAM_CANCEL_BIT, - payload={"type": "CANCEL"}, - service_name=service_name, - procedure_name=procedure_name, + extra_control_flags=0, span=span, ) raise RiverException(ERROR_CODE_CANCEL, str(e)) from e @@ -863,6 +860,20 @@ async def _encode_stream() -> None: finally: output.close() + async def _send_cancel_stream( + self, + stream_id: str, + extra_control_flags: int, + span: Span, + ) -> None: + # close stream + await self._send_message( + stream_id=stream_id, + control_flags=STREAM_CANCEL_BIT | extra_control_flags, + payload={"type": "CANCEL"}, + span=span, + ) + async def _send_close_stream( self, service_name: str, From fa66288bf11f6e5a5d9d03ec323e161094730c8e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 12:20:31 -0700 Subject: [PATCH 119/193] Increment seq immediately --- src/replit_river/v2/session.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 254dcf03..13f028bb 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -359,15 +359,18 @@ async def _send_message( with use_span(span): trace_propagator.inject(msg, None, trace_setter) + # We're clear to add to the send buffer self._send_buffer.append(msg) + # Increment immediately so we maintain consistency + self.seq += 1 + # If the buffer is now full, reset the block if 
len(self._send_buffer) >= self._transport_options.buffer_size: self._space_available.clear() # Wake up buffered_message_sender self._process_messages.set() - self.seq += 1 async def close(self) -> None: """Close the session and all associated streams.""" From c2cbe70f3fadcf875a55fd61c923b17936102f0a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 12:32:06 -0700 Subject: [PATCH 120/193] Useless comments --- src/replit_river/v2/session.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 13f028bb..57efc529 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -869,7 +869,6 @@ async def _send_cancel_stream( extra_control_flags: int, span: Span, ) -> None: - # close stream await self._send_message( stream_id=stream_id, control_flags=STREAM_CANCEL_BIT | extra_control_flags, @@ -885,7 +884,6 @@ async def _send_close_stream( extra_control_flags: int, span: Span, ) -> None: - # close stream await self._send_message( service_name=service_name, procedure_name=procedure_name, From 345b7b2ed0eb88ad15ebc19f30ae9258925c1021 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 11:30:06 -0700 Subject: [PATCH 121/193] Cancel streams on exception --- src/replit_river/v2/session.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 57efc529..f4ed667b 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -681,7 +681,18 @@ async def send_upload[I, R, A]( payload=request_serializer(item), span=span, ) + except WebsocketClosedException as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e except Exception as e: + # If we get any exception other than WebsocketClosedException, + # cancel the stream. 
+ await self._send_cancel_stream( + stream_id=stream_id, + extra_control_flags=0, + span=span, + ) raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name ) from e From 03a0d6767f17e24f04dde246c016edaa9aaf6798 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 12:31:53 -0700 Subject: [PATCH 122/193] Alter send_message to be generic over return value --- src/replit_river/common_session.py | 4 ++-- src/replit_river/session.py | 2 +- src/replit_river/v2/session.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index e3824206..2d670efd 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -class SendMessage(Protocol): +class SendMessage[Result](Protocol): async def __call__( self, *, @@ -27,7 +27,7 @@ async def __call__( service_name: str | None, procedure_name: str | None, span: Span | None, - ) -> None: ... + ) -> Result: ... 
class SessionState(enum.Enum): diff --git a/src/replit_river/session.py b/src/replit_river/session.py index b04463a3..8bb745b7 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -350,7 +350,7 @@ async def setup_heartbeat( get_state: Callable[[], SessionState], get_closing_grace_period: Callable[[], float | None], close_websocket: Callable[[], Awaitable[None]], - send_message: SendMessage, + send_message: SendMessage[None], increment_and_get_heartbeat_misses: Callable[[], int], ) -> None: while True: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index f4ed667b..3899cf7f 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1146,7 +1146,7 @@ async def _serve( ], get_stream: Callable[[str], Channel[Any] | None], close_stream: Callable[[str], None], - send_message: SendMessage, + send_message: SendMessage[None], ) -> None: """Serve messages from the websocket.""" reset_session_close_countdown() From ab223e1bc9ab45b9569b44bfc4f76aa078f621f7 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 14:55:57 -0700 Subject: [PATCH 123/193] Introduce backpressure for emitters --- src/replit_river/v2/session.py | 466 ++++++++++++++++++--------------- 1 file changed, 250 insertions(+), 216 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 3899cf7f..5bcfea8b 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -2,11 +2,13 @@ import logging from collections import deque from collections.abc import AsyncIterable +from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta from typing import ( Any, AsyncGenerator, + AsyncIterator, Awaitable, Callable, Coroutine, @@ -119,7 +121,7 @@ class Session: _space_available: asyncio.Event # stream for tasks - _streams: dict[str, Channel[Any]] + _streams: dict[str, tuple[asyncio.Event, Channel[Any]]] # book keeping 
_ack_buffer: deque[TransportMessage] @@ -165,7 +167,7 @@ def __init__( self._space_available.set() # stream for tasks - self._streams: dict[str, Channel[Any]] = {} + self._streams: dict[str, tuple[asyncio.Event, Channel[Any]]] = {} # book keeping self._ack_buffer = deque() @@ -394,10 +396,12 @@ async def close(self) -> None: # TODO: unexpected_close should close stream differently here to # throw exception correctly. - for stream in self._streams.values(): + for event, stream in self._streams.values(): stream.close() + # Wake up backpressured writers + event.set() # Before we GC the streams, let's wait for all tasks to be closed gracefully. - await asyncio.gather(*[x.join() for x in self._streams.values()]) + await asyncio.gather(*[stream.join() for _, stream in self._streams.values()]) self._streams.clear() if self._ws: @@ -558,9 +562,6 @@ def assert_incoming_seq_bookkeeping( return True - def close_stream(stream_id: str) -> None: - del self._streams[stream_id] - async def block_until_connected() -> None: await self._wait_for_connected.wait() @@ -576,11 +577,24 @@ async def block_until_connected() -> None: close_session=self.close, assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: self._streams.get(stream_id), - close_stream=close_stream, send_message=self._send_message, ) ) + @asynccontextmanager + async def _with_stream( + self, + session_id: str, + maxsize: int, + ) -> AsyncIterator[tuple[asyncio.Event, Channel[Any]]]: + output: Channel[Any] = Channel(maxsize=maxsize) + event = asyncio.Event() + self._streams[session_id] = (event, output) + try: + yield (event, output) + finally: + del self._streams[session_id] + async def send_rpc[R, A]( self, service_name: str, @@ -597,45 +611,48 @@ async def send_rpc[R, A]( Expects the input and output be messages that will be msgpacked. 
""" stream_id = nanoid.generate() - output: Channel[Any] = Channel(1) - self._streams[stream_id] = output - await self._send_message( - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, - payload=request_serializer(request), - service_name=service_name, - procedure_name=procedure_name, - span=span, - ) - # Handle potential errors during communication - try: - async with asyncio.timeout(timeout.total_seconds()): - response = await output.get() - except asyncio.TimeoutError as e: - await self._send_cancel_stream( + async with self._with_stream(stream_id, 1) as (event, output): + await self._send_message( stream_id=stream_id, - extra_control_flags=0, + control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, + payload=request_serializer(request), + service_name=service_name, + procedure_name=procedure_name, span=span, ) - raise RiverException(ERROR_CODE_CANCEL, str(e)) from e - except ChannelClosed as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except RuntimeError as e: - raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): + # Handle potential errors during communication try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) from e - raise exception_from_message(error.code)( - error.code, error.message, service_name, procedure_name - ) - return response_deserializer(response["payload"]) + async with asyncio.timeout(timeout.total_seconds()): + # Block for event for symmetry with backpressured producers + # Here this should be trivially true. 
+ await event.wait() + response = await output.get() + except asyncio.TimeoutError as e: + await self._send_cancel_stream( + stream_id=stream_id, + extra_control_flags=0, + span=span, + ) + raise RiverException(ERROR_CODE_CANCEL, str(e)) from e + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) + + return response_deserializer(response["payload"]) async def send_upload[I, R, A]( self, @@ -655,78 +672,82 @@ async def send_upload[I, R, A]( """ stream_id = nanoid.generate() - output: Channel[Any] = Channel(1) - self._streams[stream_id] = output - try: - await self._send_message( - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - service_name=service_name, - procedure_name=procedure_name, - payload=init_serializer(init), - span=span, - ) - - if request: - assert request_serializer, "send_stream missing request_serializer" + async with self._with_stream(stream_id, 1) as (event, output): + try: + await self._send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + service_name=service_name, + procedure_name=procedure_name, + payload=init_serializer(init), + span=span, + ) - # If this request is not closed and the session is killed, we should - # throw exception here - async for item in request: - await self._send_message( - stream_id=stream_id, - service_name=service_name, - procedure_name=procedure_name, - control_flags=0, - payload=request_serializer(item), - span=span, - ) - except WebsocketClosedException as e: - raise RiverServiceException( - 
ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name - ) from e - except Exception as e: - # If we get any exception other than WebsocketClosedException, - # cancel the stream. - await self._send_cancel_stream( - stream_id=stream_id, + if request: + assert request_serializer, "send_stream missing request_serializer" + + # If this request is not closed and the session is killed, we should + # throw exception here + async for item in request: + # Block for backpressure + await event.wait() + if output.closed(): + logger.debug("Stream is closed, avoid sending the rest") + break + await self._send_message( + stream_id=stream_id, + service_name=service_name, + procedure_name=procedure_name, + control_flags=0, + payload=request_serializer(item), + span=span, + ) + except WebsocketClosedException as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + except Exception as e: + # If we get any exception other than WebsocketClosedException, + # cancel the stream. 
+ await self._send_cancel_stream( + stream_id=stream_id, + extra_control_flags=0, + span=span, + ) + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + await self._send_close_stream( + service_name, + procedure_name, + stream_id, extra_control_flags=0, span=span, ) - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name - ) from e - await self._send_close_stream( - service_name, - procedure_name, - stream_id, - extra_control_flags=0, - span=span, - ) - # Handle potential errors during communication - # TODO: throw a error when the transport is hard closed - try: - response = await output.get() - except ChannelClosed as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except RuntimeError as e: - raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): + # Handle potential errors during communication + # TODO: throw a error when the transport is hard closed try: - error = error_deserializer(response["payload"]) - except Exception as e: - raise RiverException("error_deserializer", str(e)) from e - raise exception_from_message(error.code)( - error.code, error.message, service_name, procedure_name - ) + response = await output.get() + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except RuntimeError as e: + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if not response.get("ok", False): + try: + error = error_deserializer(response["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) - return response_deserializer(response["payload"]) + return 
response_deserializer(response["payload"]) async def send_subscription[R, E, A]( self, @@ -743,42 +764,42 @@ async def send_subscription[R, E, A]( Expects the input and output be messages that will be msgpacked. """ stream_id = nanoid.generate() - output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) - self._streams[stream_id] = output - await self._send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=request_serializer(request), - span=span, - ) + async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output): + await self._send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=request_serializer(request), + span=span, + ) - # Handle potential errors during communication - try: - async for item in output: - if item.get("type") == "CLOSE": - break - if not item.get("ok", False): - try: - yield error_deserializer(item["payload"]) - except Exception: - logger.exception( - f"Error during subscription error deserialization: {item}" - ) - continue - yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except Exception as e: - raise e - finally: - output.close() + # Handle potential errors during communication + try: + async for item in output: + if item.get("type") == "CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logger.exception( + "Error during subscription " + f"error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed) as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) 
from e + except Exception as e: + raise e + finally: + output.close() async def send_stream[I, R, E, A]( self, @@ -798,81 +819,89 @@ async def send_stream[I, R, E, A]( """ stream_id = nanoid.generate() - output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE) - self._streams[stream_id] = output - try: - await self._send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=init_serializer(init), - span=span, - ) - except Exception as e: - raise StreamClosedRiverServiceException( - ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name - ) from e + async with self._with_stream( + stream_id, + MAX_MESSAGE_BUFFER_SIZE, + ) as (event, output): + try: + await self._send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + span=span, + ) + except Exception as e: + raise StreamClosedRiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + + # Create the encoder task + async def _encode_stream() -> None: + if not request: + await self._send_close_stream( + service_name, + procedure_name, + stream_id, + extra_control_flags=STREAM_OPEN_BIT, + span=span, + ) + return + + assert request_serializer, "send_stream missing request_serializer" - # Create the encoder task - async def _encode_stream() -> None: - if not request: + async for item in request: + if item is None: + continue + await event.wait() + if output.closed(): + logger.debug("Stream is closed, avoid sending the rest") + break + await self._send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=0, + payload=request_serializer(item), + ) await self._send_close_stream( service_name, procedure_name, stream_id, - extra_control_flags=STREAM_OPEN_BIT, + extra_control_flags=0, span=span, ) - return - assert 
request_serializer, "send_stream missing request_serializer" + self._task_manager.create_task(_encode_stream()) - async for item in request: - if item is None: - continue - await self._send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=0, - payload=request_serializer(item), - ) - await self._send_close_stream( - service_name, - procedure_name, - stream_id, - extra_control_flags=0, - span=span, - ) - - self._task_manager.create_task(_encode_stream()) - - # Handle potential errors during communication - try: - async for item in output: - if item.get("type") == "CLOSE": - break - if not item.get("ok", False): - try: - yield error_deserializer(item["payload"]) - except Exception: - logger.exception( - f"Error during subscription error deserialization: {item}" - ) - continue - yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, - "Stream closed before response", - service_name, - procedure_name, - ) from e - except Exception as e: - raise e - finally: - output.close() + # Handle potential errors during communication + try: + async for item in output: + if item.get("type") == "CLOSE": + break + if not item.get("ok", False): + try: + yield error_deserializer(item["payload"]) + except Exception: + logger.exception( + "Error during subscription " + f"error deserialization: {item}" + ) + continue + yield response_deserializer(item["payload"]) + except (RuntimeError, ChannelClosed) as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except Exception as e: + raise e + finally: + output.close() + event.set() async def _send_cancel_stream( self, @@ -1144,8 +1173,7 @@ async def _serve( assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage ], - get_stream: Callable[[str], Channel[Any] | None], 
- close_stream: Callable[[str], None], + get_stream: Callable[[str], tuple[asyncio.Event, Channel[Any]] | None], send_message: SendMessage[None], ) -> None: """Serve messages from the websocket.""" @@ -1230,24 +1258,29 @@ async def _serve( ) continue - stream = get_stream(msg.streamId) + event_stream = get_stream(msg.streamId) - if not stream: + if not event_stream: logger.warning( "no stream for %s, ignoring message", msg.streamId, ) continue + event, stream = event_stream + if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 and msg.payload.get("type", None) == "CLOSE" ): # close message is not sent to the stream + # event is set during cleanup down below pass else: try: await stream.put(msg.payload) + # Wake up backpressured writer + event.set() except ChannelClosed: # The client is no longer interested in this stream, # just drop the message. @@ -1256,9 +1289,10 @@ async def _serve( raise InvalidMessageException(e) from e if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - close_stream(msg.streamId) + # Communicate that we're going down + stream.close() + # Wake up backpressured writer + event.set() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await close_session() From ce934af4044d0a84dc3c4fceeea42c0976efac38 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 15:07:06 -0700 Subject: [PATCH 124/193] Making space for v2 tests --- tests/conftest.py | 5 ++++- .../codegen_snapshot_fixtures.py | 8 +++++--- .../{ => v1}/codegen/rpc/generated/__init__.py | 0 .../rpc/generated/test_service/__init__.py | 0 .../rpc/generated/test_service/rpc_method.py | 0 tests/{ => v1}/codegen/rpc/schema.json | 0 .../snapshots/test_basic_stream/__init__.py | 0 .../test_basic_stream/test_service/__init__.py | 0 .../test_service/emit_error.py | 0 .../test_service/stream_method.py | 0 .../test_pathological_types/__init__.py | 0 .../test_service/__init__.py | 0 .../test_service/pathological_method.py | 
0 .../snapshots/test_unknown_enum/__init__.py | 0 .../test_unknown_enum/enumService/__init__.py | 0 .../test_unknown_enum/enumService/needsEnum.py | 0 .../enumService/needsEnumObject.py | 0 tests/{ => v1}/codegen/snapshot/test_enum.py | 18 +++++++++++------- .../codegen/snapshot/test_pathological.py | 6 ++++-- tests/{ => v1}/codegen/stream/schema.json | 0 tests/{ => v1}/codegen/stream/test_stream.py | 18 ++++++++++-------- tests/{ => v1}/codegen/test_rpc.py | 18 +++++++++--------- tests/{ => v1}/codegen/types/schema.json | 0 tests/{ => v1}/common_handlers.py | 0 tests/{ => v1}/river_fixtures/clientserver.py | 2 +- tests/{ => v1}/river_fixtures/logging.py | 0 tests/{ => v1}/test_communication.py | 12 ++++++------ tests/{ => v1}/test_handshake.py | 0 tests/{ => v1}/test_message_buffer.py | 0 tests/{ => v1}/test_opentelemetry.py | 14 +++++++------- tests/{ => v1}/test_rate_limiter.py | 0 tests/{ => v1}/test_seq_manager.py | 2 +- tests/{ => v1}/test_timeout.py | 6 +++--- 33 files changed, 61 insertions(+), 48 deletions(-) rename tests/{codegen/snapshot => fixtures}/codegen_snapshot_fixtures.py (84%) rename tests/{ => v1}/codegen/rpc/generated/__init__.py (100%) rename tests/{ => v1}/codegen/rpc/generated/test_service/__init__.py (100%) rename tests/{ => v1}/codegen/rpc/generated/test_service/rpc_method.py (100%) rename tests/{ => v1}/codegen/rpc/schema.json (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_basic_stream/__init__.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_pathological_types/__init__.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py (100%) rename tests/{ => 
v1}/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_unknown_enum/__init__.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py (100%) rename tests/{ => v1}/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py (100%) rename tests/{ => v1}/codegen/snapshot/test_enum.py (89%) rename tests/{ => v1}/codegen/snapshot/test_pathological.py (51%) rename tests/{ => v1}/codegen/stream/schema.json (100%) rename tests/{ => v1}/codegen/stream/test_stream.py (74%) rename tests/{ => v1}/codegen/test_rpc.py (79%) rename tests/{ => v1}/codegen/types/schema.json (100%) rename tests/{ => v1}/common_handlers.py (100%) rename tests/{ => v1}/river_fixtures/clientserver.py (97%) rename tests/{ => v1}/river_fixtures/logging.py (100%) rename tests/{ => v1}/test_communication.py (99%) rename tests/{ => v1}/test_handshake.py (100%) rename tests/{ => v1}/test_message_buffer.py (100%) rename tests/{ => v1}/test_opentelemetry.py (98%) rename tests/{ => v1}/test_rate_limiter.py (100%) rename tests/{ => v1}/test_seq_manager.py (97%) rename tests/{ => v1}/test_timeout.py (97%) diff --git a/tests/conftest.py b/tests/conftest.py index b9b8cdf6..c52bb7f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,10 @@ ) # Modular fixtures -pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"] +pytest_plugins = [ + "tests.v1.river_fixtures.logging", + "tests.v1.river_fixtures.clientserver", +] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandlerBuilder]] diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/fixtures/codegen_snapshot_fixtures.py similarity index 84% rename 
from tests/codegen/snapshot/codegen_snapshot_fixtures.py rename to tests/fixtures/codegen_snapshot_fixtures.py index 2fdff907..d4a88f54 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ b/tests/fixtures/codegen_snapshot_fixtures.py @@ -1,6 +1,6 @@ from io import StringIO from pathlib import Path -from typing import Callable, TextIO +from typing import Callable, Literal, TextIO from pytest_snapshot.plugin import Snapshot @@ -15,12 +15,14 @@ def close(self) -> None: def validate_codegen( *, snapshot: Snapshot, + snapshot_dir: str, read_schema: Callable[[], TextIO], target_path: str, client_name: str, + protocol_version: Literal["v1.1", "v2.0"], typeddict_inputs: bool = True, ) -> None: - snapshot.snapshot_dir = "tests/codegen/snapshot/snapshots" + snapshot.snapshot_dir = snapshot_dir files: dict[Path, UnclosableStringIO] = {} def file_opener(path: Path) -> TextIO: @@ -36,7 +38,7 @@ def file_opener(path: Path) -> TextIO: file_opener=file_opener, typed_dict_inputs=typeddict_inputs, method_filter=None, - protocol_version="v1.1", + protocol_version=protocol_version, ) for path, file in files.items(): file.seek(0) diff --git a/tests/codegen/rpc/generated/__init__.py b/tests/v1/codegen/rpc/generated/__init__.py similarity index 100% rename from tests/codegen/rpc/generated/__init__.py rename to tests/v1/codegen/rpc/generated/__init__.py diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/v1/codegen/rpc/generated/test_service/__init__.py similarity index 100% rename from tests/codegen/rpc/generated/test_service/__init__.py rename to tests/v1/codegen/rpc/generated/test_service/__init__.py diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/v1/codegen/rpc/generated/test_service/rpc_method.py similarity index 100% rename from tests/codegen/rpc/generated/test_service/rpc_method.py rename to tests/v1/codegen/rpc/generated/test_service/rpc_method.py diff --git a/tests/codegen/rpc/schema.json 
b/tests/v1/codegen/rpc/schema.json similarity index 100% rename from tests/codegen/rpc/schema.json rename to tests/v1/codegen/rpc/schema.json diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_basic_stream/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_pathological_types/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_pathological_types/__init__.py diff --git 
a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py rename to tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py 
b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py diff --git a/tests/codegen/snapshot/test_enum.py b/tests/v1/codegen/snapshot/test_enum.py similarity index 89% rename from tests/codegen/snapshot/test_enum.py rename to tests/v1/codegen/snapshot/test_enum.py index e45509d2..72a3f2e2 100644 --- a/tests/codegen/snapshot/test_enum.py +++ b/tests/v1/codegen/snapshot/test_enum.py @@ -4,7 +4,7 @@ from pytest_snapshot.plugin import Snapshot -from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen test_unknown_enum_schema = """ { @@ -173,15 +173,17 @@ def test_unknown_enum(snapshot: Snapshot) -> None: validate_codegen( snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", read_schema=lambda: StringIO(test_unknown_enum_schema), target_path="test_unknown_enum", client_name="foo", + protocol_version="v1.1", ) - import tests.codegen.snapshot.snapshots.test_unknown_enum + import tests.v1.codegen.snapshot.snapshots.test_unknown_enum - importlib.reload(tests.codegen.snapshot.snapshots.test_unknown_enum) - from tests.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnum import ( # noqa + importlib.reload(tests.v1.codegen.snapshot.snapshots.test_unknown_enum) + from tests.v1.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnum import ( # noqa NeedsenumErrorsTypeAdapter, ) @@ -209,15 +211,17 @@ def test_unknown_enum(snapshot: Snapshot) -> None: def test_unknown_enum_field_aliases(snapshot: Snapshot) -> None: validate_codegen( snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", read_schema=lambda: StringIO(test_unknown_enum_schema), target_path="test_unknown_enum", client_name="foo", 
+ protocol_version="v1.1", ) - import tests.codegen.snapshot.snapshots.test_unknown_enum + import tests.v1.codegen.snapshot.snapshots.test_unknown_enum - importlib.reload(tests.codegen.snapshot.snapshots.test_unknown_enum) - from tests.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnumObject import ( # noqa + importlib.reload(tests.v1.codegen.snapshot.snapshots.test_unknown_enum) + from tests.v1.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnumObject import ( # noqa NeedsenumobjectOutputTypeAdapter, NeedsenumobjectOutput, NeedsenumobjectOutputFooOneOf_out_first, diff --git a/tests/codegen/snapshot/test_pathological.py b/tests/v1/codegen/snapshot/test_pathological.py similarity index 51% rename from tests/codegen/snapshot/test_pathological.py rename to tests/v1/codegen/snapshot/test_pathological.py index 68775c8f..789a2d4c 100644 --- a/tests/codegen/snapshot/test_pathological.py +++ b/tests/v1/codegen/snapshot/test_pathological.py @@ -1,12 +1,14 @@ from pytest_snapshot.plugin import Snapshot -from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen async def test_pathological_types(snapshot: Snapshot) -> None: validate_codegen( snapshot=snapshot, - read_schema=lambda: open("tests/codegen/types/schema.json"), + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/types/schema.json"), target_path="test_pathological_types", client_name="PathologicalClient", + protocol_version="v1.1", ) diff --git a/tests/codegen/stream/schema.json b/tests/v1/codegen/stream/schema.json similarity index 100% rename from tests/codegen/stream/schema.json rename to tests/v1/codegen/stream/schema.json diff --git a/tests/codegen/stream/test_stream.py b/tests/v1/codegen/stream/test_stream.py similarity index 74% rename from tests/codegen/stream/test_stream.py rename to tests/v1/codegen/stream/test_stream.py index 
f3966043..2ac02a76 100644 --- a/tests/codegen/stream/test_stream.py +++ b/tests/v1/codegen/stream/test_stream.py @@ -5,8 +5,8 @@ from pytest_snapshot.plugin import Snapshot from replit_river.client import Client, RiverUnknownError -from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen -from tests.common_handlers import basic_stream, error_stream +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v1.common_handlers import basic_stream, error_stream _AlreadyGenerated = False @@ -17,15 +17,17 @@ def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: if not _AlreadyGenerated: validate_codegen( snapshot=snapshot, - read_schema=lambda: open("tests/codegen/stream/schema.json"), + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/stream/schema.json"), target_path="test_basic_stream", client_name="StreamClient", + protocol_version="v1.1", ) _AlreadyGenerated = True - import tests.codegen.snapshot.snapshots.test_basic_stream + import tests.v1.codegen.snapshot.snapshots.test_basic_stream - importlib.reload(tests.codegen.snapshot.snapshots.test_basic_stream) + importlib.reload(tests.v1.codegen.snapshot.snapshots.test_basic_stream) return True @@ -34,10 +36,10 @@ async def test_basic_stream( stream_client_codegen: Literal[True], client: Client, ) -> None: - from tests.codegen.snapshot.snapshots.test_basic_stream import ( + from tests.v1.codegen.snapshot.snapshots.test_basic_stream import ( StreamClient, # noqa: E501 ) - from tests.codegen.snapshot.snapshots.test_basic_stream.test_service.stream_method import ( # noqa: E501 + from tests.v1.codegen.snapshot.snapshots.test_basic_stream.test_service.stream_method import ( # noqa: E501 Stream_MethodInput, Stream_MethodOutput, ) @@ -63,7 +65,7 @@ async def test_error_stream( erroringClient: Client, phase: int, ) -> None: - from tests.codegen.snapshot.snapshots.test_basic_stream import ( + from 
tests.v1.codegen.snapshot.snapshots.test_basic_stream import ( StreamClient, # noqa: E501 ) diff --git a/tests/codegen/test_rpc.py b/tests/v1/codegen/test_rpc.py similarity index 79% rename from tests/codegen/test_rpc.py rename to tests/v1/codegen/test_rpc.py index 450a74f0..55837190 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/v1/codegen/test_rpc.py @@ -14,21 +14,21 @@ from replit_river.codegen.client import schema_to_river_client_codegen from replit_river.error_schema import RiverException from replit_river.rpc import rpc_method_handler -from tests.common_handlers import basic_rpc_method from tests.conftest import HandlerMapping, deserialize_request, serialize_response +from tests.v1.common_handlers import basic_rpc_method @pytest.fixture(scope="session", autouse=True) def generate_rpc_client() -> None: - shutil.rmtree("tests/codegen/rpc/generated", ignore_errors=True) - os.makedirs("tests/codegen/rpc/generated") + shutil.rmtree("tests/v1/codegen/rpc/generated", ignore_errors=True) + os.makedirs("tests/v1/codegen/rpc/generated") def file_opener(path: Path) -> TextIO: return open(path, "w") schema_to_river_client_codegen( - read_schema=lambda: open("tests/codegen/rpc/schema.json"), - target_path="tests/codegen/rpc/generated", + read_schema=lambda: open("tests/v1/codegen/rpc/schema.json"), + target_path="tests/v1/codegen/rpc/generated", client_name="RpcClient", typed_dict_inputs=True, file_opener=file_opener, @@ -39,15 +39,15 @@ def file_opener(path: Path) -> TextIO: @pytest.fixture(scope="session", autouse=True) def reload_rpc_import(generate_rpc_client: None) -> None: - import tests.codegen.rpc.generated + import tests.v1.codegen.rpc.generated - importlib.reload(tests.codegen.rpc.generated) + importlib.reload(tests.v1.codegen.rpc.generated) @pytest.mark.asyncio @pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) async def test_basic_rpc(client: Client) -> None: - from tests.codegen.rpc.generated import RpcClient + from 
tests.v1.codegen.rpc.generated import RpcClient res = await RpcClient(client).test_service.rpc_method( { @@ -76,7 +76,7 @@ async def rpc_timeout_handler(request: str, context: grpc.aio.ServicerContext) - @pytest.mark.asyncio @pytest.mark.parametrize("handlers", [{**rpc_timeout_method}]) async def test_rpc_timeout(client: Client) -> None: - from tests.codegen.rpc.generated import RpcClient + from tests.v1.codegen.rpc.generated import RpcClient with pytest.raises(RiverException): await RpcClient(client).test_service.rpc_method( diff --git a/tests/codegen/types/schema.json b/tests/v1/codegen/types/schema.json similarity index 100% rename from tests/codegen/types/schema.json rename to tests/v1/codegen/types/schema.json diff --git a/tests/common_handlers.py b/tests/v1/common_handlers.py similarity index 100% rename from tests/common_handlers.py rename to tests/v1/common_handlers.py diff --git a/tests/river_fixtures/clientserver.py b/tests/v1/river_fixtures/clientserver.py similarity index 97% rename from tests/river_fixtures/clientserver.py rename to tests/v1/river_fixtures/clientserver.py index bf576b5c..c01a9ee3 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/v1/river_fixtures/clientserver.py @@ -9,7 +9,7 @@ from replit_river.server import Server from replit_river.transport_options import TransportOptions from tests.conftest import HandlerMapping -from tests.river_fixtures.logging import NoErrors # noqa: E402 +from tests.v1.river_fixtures.logging import NoErrors # noqa: E402 @pytest.fixture diff --git a/tests/river_fixtures/logging.py b/tests/v1/river_fixtures/logging.py similarity index 100% rename from tests/river_fixtures/logging.py rename to tests/v1/river_fixtures/logging.py diff --git a/tests/test_communication.py b/tests/v1/test_communication.py similarity index 99% rename from tests/test_communication.py rename to tests/v1/test_communication.py index 87de1811..470b368b 100644 --- a/tests/test_communication.py +++ b/tests/v1/test_communication.py 
@@ -9,12 +9,6 @@ from replit_river.error_schema import RiverError from replit_river.rpc import subscription_method_handler from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE -from tests.common_handlers import ( - basic_rpc_method, - basic_stream, - basic_subscription, - basic_upload, -) from tests.conftest import ( HandlerMapping, deserialize_error, @@ -23,6 +17,12 @@ serialize_request, serialize_response, ) +from tests.v1.common_handlers import ( + basic_rpc_method, + basic_stream, + basic_subscription, + basic_upload, +) @pytest.mark.asyncio diff --git a/tests/test_handshake.py b/tests/v1/test_handshake.py similarity index 100% rename from tests/test_handshake.py rename to tests/v1/test_handshake.py diff --git a/tests/test_message_buffer.py b/tests/v1/test_message_buffer.py similarity index 100% rename from tests/test_message_buffer.py rename to tests/v1/test_message_buffer.py diff --git a/tests/test_opentelemetry.py b/tests/v1/test_opentelemetry.py similarity index 98% rename from tests/test_opentelemetry.py rename to tests/v1/test_opentelemetry.py index 801b133d..c47b5418 100644 --- a/tests/test_opentelemetry.py +++ b/tests/v1/test_opentelemetry.py @@ -11,12 +11,6 @@ from replit_river.client import Client from replit_river.error_schema import RiverError, RiverException from replit_river.rpc import stream_method_handler -from tests.common_handlers import ( - basic_rpc_method, - basic_stream, - basic_subscription, - basic_upload, -) from tests.conftest import ( HandlerMapping, deserialize_error, @@ -25,7 +19,13 @@ serialize_request, serialize_response, ) -from tests.river_fixtures.logging import NoErrors +from tests.v1.common_handlers import ( + basic_rpc_method, + basic_stream, + basic_subscription, + basic_upload, +) +from tests.v1.river_fixtures.logging import NoErrors @pytest.mark.asyncio diff --git a/tests/test_rate_limiter.py b/tests/v1/test_rate_limiter.py similarity index 100% rename from tests/test_rate_limiter.py rename to 
tests/v1/test_rate_limiter.py diff --git a/tests/test_seq_manager.py b/tests/v1/test_seq_manager.py similarity index 97% rename from tests/test_seq_manager.py rename to tests/v1/test_seq_manager.py index fba7f58a..1373e0f9 100644 --- a/tests/test_seq_manager.py +++ b/tests/v1/test_seq_manager.py @@ -6,7 +6,7 @@ SeqManager, ) from tests.conftest import transport_message -from tests.river_fixtures.logging import NoErrors +from tests.v1.river_fixtures.logging import NoErrors @pytest.mark.asyncio diff --git a/tests/test_timeout.py b/tests/v1/test_timeout.py similarity index 97% rename from tests/test_timeout.py rename to tests/v1/test_timeout.py index 94e033fd..1ac3db97 100644 --- a/tests/test_timeout.py +++ b/tests/v1/test_timeout.py @@ -7,15 +7,15 @@ from replit_river.client import Client from replit_river.error_schema import ERROR_CODE_CANCEL, RiverException -from tests.common_handlers import ( - rpc_method_handler, -) from tests.conftest import ( HandlerMapping, deserialize_error, deserialize_response, serialize_response, ) +from tests.v1.common_handlers import ( + rpc_method_handler, +) async def rpc_slow_handler(duration: float, context: grpc.aio.ServicerContext) -> str: From f96d61b3664592e94f6f3cda7cd2deb07488472a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 16:40:25 -0700 Subject: [PATCH 125/193] Cleaning up budget exhausted errors in do_ensure_connected --- src/replit_river/rate_limiter.py | 26 -------------------------- src/replit_river/v2/session.py | 11 ++++++----- 2 files changed, 6 insertions(+), 31 deletions(-) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index 384288be..5e742ce9 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -2,7 +2,6 @@ import logging import random from contextvars import Context -from typing import Literal from replit_river.error_schema import RiverException from replit_river.transport_options import ConnectionRetryOptions @@ -75,31 
+74,6 @@ def has_budget(self, user: str) -> bool: """ return self.get_budget_consumed(user) < self.options.attempt_budget_capacity - def has_budget_or_throw( - self, - user: str, - error_code: str, - last_error: Exception | None, - ) -> Literal[True]: - """ - Check if the user has remaining budget to make a retry. - If they do not, explode. - - Args: - user (str): The identifier for the user. - - Returns: - bool: True if budget is available, False otherwise. - """ - if self.get_budget_consumed(user) > self.options.attempt_budget_capacity: - logger.debug("No retry budget for %s.", user) - raise BudgetExhaustedException( - error_code, - "No retry budget", - client_id=user, - ) from last_error - return True - def consume_budget(self, user: str) -> None: """Increment the budget consumed for the user by 1, indicating a retry attempt. diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 5bcfea8b..cc37c4c5 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -108,7 +108,7 @@ class Session: _state: SessionState _close_session_callback: CloseSessionCallback _close_session_after_time_secs: float | None - _connecting_task: asyncio.Task[Literal[True]] | None + _connecting_task: asyncio.Task[None] | None _wait_for_connected: asyncio.Event # ws state @@ -979,12 +979,12 @@ async def _do_ensure_connected[HandshakeMetadata]( transition_connected: Callable[[ClientConnection], None], finalize_attempt: Callable[[], None], do_close: Callable[[], None], -) -> Literal[True]: +) -> None: logger.info("Attempting to establish new ws connection") last_error: Exception | None = None i = 0 - while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error): + while rate_limiter.has_budget(client_id): if i > 0: logger.info(f"Retrying build handshake number {i} times") i += 1 @@ -1108,16 +1108,17 @@ async def websocket_closed_callback() -> None: f"Error connecting, retrying with {backoff_time}ms backoff" ) await 
asyncio.sleep(backoff_time / 1000) - finalize_attempt() if last_error is not None: + logger.debug("Handshake attempts exhausted, terminating") + do_close() raise RiverException( ERROR_HANDSHAKE, f"Failed to create ws after retrying {max_retry} number of times", ) from last_error - return True + return None async def _setup_heartbeat( From c7b7f8eddf633d15e94c844d885ad6df10288788 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 17:36:45 -0700 Subject: [PATCH 126/193] Transitioning back to NO_CONNECTION needs to start blocking connection waiters --- src/replit_river/v2/session.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index cc37c4c5..263347fc 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -303,6 +303,7 @@ async def _begin_close_session_countdown(self) -> None: ) self._state = SessionState.NO_CONNECTION self._close_session_after_time_secs = close_session_after_time_secs + self._wait_for_connected.clear() async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() From 6231e24187372f1141d62940cda941aeaa570893 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 17:37:47 -0700 Subject: [PATCH 127/193] Adding more debug messages --- src/replit_river/v2/client_transport.py | 1 + src/replit_river/v2/session.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index ebafefbe..854eba3c 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -73,6 +73,7 @@ async def _retry_connection(self) -> Session: if self._session and not self._transport_options.transparent_reconnect: logger.info("transparent_reconnect not set, closing {self._transport_id}") await self._session.close() + logger.debug("Triggering get_or_create_session") return await 
self.get_or_create_session() async def _delete_session(self, session: Session) -> None: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 263347fc..9ec293db 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -309,6 +309,7 @@ async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() def _reset_session_close_countdown(self) -> None: + logger.debug('_reset_session_close_countdown') self._heartbeat_misses = 0 self._close_session_after_time_secs = None @@ -945,6 +946,7 @@ async def _check_to_close_session( ) -> None: our_task = asyncio.current_task() while our_task and not our_task.cancelling() and not our_task.cancelled(): + logger.debug('_check_to_close_session: Checking') await asyncio.sleep(close_session_check_interval_ms / 1000) if get_state() in TerminalStates: # already closing @@ -955,6 +957,7 @@ async def _check_to_close_session( current_time = await get_current_time() close_session_after_time_secs = get_close_session_after_time_secs() if not close_session_after_time_secs: + logger.debug(f'_check_to_close_session: Not reached: {close_session_after_time_secs}') continue if current_time > close_session_after_time_secs: logger.info("Grace period ended for %s, closing session", transport_id) @@ -1134,14 +1137,18 @@ async def _setup_heartbeat( ) -> None: while True: while (state := get_state()) in ConnectingStates: + logger.debug( + "Heartbeat: block_until_connected: %r", + state, + ) await block_until_connected() if state in TerminalStates: logger.debug( "Session is closed, no need to send heartbeat, state : " "%r close_session_after_this: %r", - {state}, - {get_closing_grace_period()}, + state, + get_closing_grace_period(), ) # session is closing / closed, no need to send heartbeat anymore break From b0646f92f61310e1ae25c82cc501ecf6fe7e1958 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 1 Apr 2025 13:16:32 -0700 Subject: [PATCH 128/193] Clarifying "output" channel 
type --- src/replit_river/v2/session.py | 64 +++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 9ec293db..53453706 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -13,7 +13,9 @@ Callable, Coroutine, Literal, + NotRequired, TypeAlias, + TypedDict, assert_never, ) @@ -81,6 +83,24 @@ STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 +class ResultOk(TypedDict): + ok: Literal[True] + payload: Any + + +class ErrorPayload(TypedDict): + code: str + message: str + + +class ResultError(TypedDict): + # Account for structurally incoherent payloads + ok: NotRequired[Literal[False]] + payload: ErrorPayload + + +ResultType: TypeAlias = ResultOk | ResultError + logger = logging.getLogger(__name__) trace_propagator = TraceContextTextMapPropagator() @@ -309,7 +329,7 @@ async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() def _reset_session_close_countdown(self) -> None: - logger.debug('_reset_session_close_countdown') + logger.debug("_reset_session_close_countdown") self._heartbeat_misses = 0 self._close_session_after_time_secs = None @@ -588,7 +608,7 @@ async def _with_stream( self, session_id: str, maxsize: int, - ) -> AsyncIterator[tuple[asyncio.Event, Channel[Any]]]: + ) -> AsyncIterator[tuple[asyncio.Event, Channel[ResultType]]]: output: Channel[Any] = Channel(maxsize=maxsize) event = asyncio.Event() self._streams[session_id] = (event, output) @@ -628,7 +648,7 @@ async def send_rpc[R, A]( # Block for event for symmetry with backpressured producers # Here this should be trivially true. 
await event.wait() - response = await output.get() + result = await output.get() except asyncio.TimeoutError as e: await self._send_cancel_stream( stream_id=stream_id, @@ -645,16 +665,16 @@ async def send_rpc[R, A]( ) from e except RuntimeError as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): + if "ok" not in result or not result["ok"]: try: - error = error_deserializer(response["payload"]) + error = error_deserializer(result["payload"]) except Exception as e: raise RiverException("error_deserializer", str(e)) from e raise exception_from_message(error.code)( error.code, error.message, service_name, procedure_name ) - return response_deserializer(response["payload"]) + return response_deserializer(result["payload"]) async def send_upload[I, R, A]( self, @@ -730,7 +750,7 @@ async def send_upload[I, R, A]( # Handle potential errors during communication # TODO: throw a error when the transport is hard closed try: - response = await output.get() + result = await output.get() except ChannelClosed as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, @@ -740,16 +760,16 @@ async def send_upload[I, R, A]( ) from e except RuntimeError as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e - if not response.get("ok", False): + if "ok" not in result or not result["ok"]: try: - error = error_deserializer(response["payload"]) + error = error_deserializer(result["payload"]) except Exception as e: raise RiverException("error_deserializer", str(e)) from e raise exception_from_message(error.code)( error.code, error.message, service_name, procedure_name ) - return response_deserializer(response["payload"]) + return response_deserializer(result["payload"]) async def send_subscription[R, E, A]( self, @@ -879,19 +899,19 @@ async def _encode_stream() -> None: # Handle potential errors during communication try: - async for item in output: - if item.get("type") == "CLOSE": + async for result in output: + if 
result.get("type") == "CLOSE": break - if not item.get("ok", False): + if "ok" not in result or not result["ok"]: try: - yield error_deserializer(item["payload"]) + yield error_deserializer(result["payload"]) except Exception: logger.exception( - "Error during subscription " - f"error deserialization: {item}" + "Error during stream " + f"error deserialization: {result}" ) continue - yield response_deserializer(item["payload"]) + yield response_deserializer(result["payload"]) except (RuntimeError, ChannelClosed) as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, @@ -946,7 +966,7 @@ async def _check_to_close_session( ) -> None: our_task = asyncio.current_task() while our_task and not our_task.cancelling() and not our_task.cancelled(): - logger.debug('_check_to_close_session: Checking') + logger.debug("_check_to_close_session: Checking") await asyncio.sleep(close_session_check_interval_ms / 1000) if get_state() in TerminalStates: # already closing @@ -957,7 +977,9 @@ async def _check_to_close_session( current_time = await get_current_time() close_session_after_time_secs = get_close_session_after_time_secs() if not close_session_after_time_secs: - logger.debug(f'_check_to_close_session: Not reached: {close_session_after_time_secs}') + logger.debug( + f"_check_to_close_session: Not reached: {close_session_after_time_secs}" + ) continue if current_time > close_session_after_time_secs: logger.info("Grace period ended for %s, closing session", transport_id) @@ -1138,8 +1160,8 @@ async def _setup_heartbeat( while True: while (state := get_state()) in ConnectingStates: logger.debug( - "Heartbeat: block_until_connected: %r", - state, + "Heartbeat: block_until_connected: %r", + state, ) await block_until_connected() From ad0d6839b2b7d15ae99c770385fb31dcc5688949 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 1 Apr 2025 13:38:17 -0700 Subject: [PATCH 129/193] Avoid sending service/procedure names on CLOSE --- src/replit_river/v2/session.py | 18 
+++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 53453706..40c014e9 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -740,9 +740,7 @@ async def send_upload[I, R, A]( ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name ) from e await self._send_close_stream( - service_name, - procedure_name, - stream_id, + stream_id=stream_id, extra_control_flags=0, span=span, ) @@ -863,9 +861,7 @@ async def send_stream[I, R, E, A]( async def _encode_stream() -> None: if not request: await self._send_close_stream( - service_name, - procedure_name, - stream_id, + stream_id=stream_id, extra_control_flags=STREAM_OPEN_BIT, span=span, ) @@ -881,16 +877,12 @@ async def _encode_stream() -> None: logger.debug("Stream is closed, avoid sending the rest") break await self._send_message( - service_name=service_name, - procedure_name=procedure_name, stream_id=stream_id, control_flags=0, payload=request_serializer(item), ) await self._send_close_stream( - service_name, - procedure_name, - stream_id, + stream_id=stream_id, extra_control_flags=0, span=span, ) @@ -940,15 +932,11 @@ async def _send_cancel_stream( async def _send_close_stream( self, - service_name: str, - procedure_name: str, stream_id: str, extra_control_flags: int, span: Span, ) -> None: await self._send_message( - service_name=service_name, - procedure_name=procedure_name, stream_id=stream_id, control_flags=STREAM_CLOSED_BIT | extra_control_flags, payload={"type": "CLOSE"}, From 862eca9db2ed6a2bcafa2f6d7bfd709a70558cb5 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 1 Apr 2025 13:39:21 -0700 Subject: [PATCH 130/193] event -> backpressure_waiter name clarification --- src/replit_river/v2/session.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 
40c014e9..e6dac0a8 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -418,10 +418,10 @@ async def close(self) -> None: # TODO: unexpected_close should close stream differently here to # throw exception correctly. - for event, stream in self._streams.values(): + for backpressure_waiter, stream in self._streams.values(): stream.close() # Wake up backpressured writers - event.set() + backpressure_waiter.set() # Before we GC the streams, let's wait for all tasks to be closed gracefully. await asyncio.gather(*[stream.join() for _, stream in self._streams.values()]) self._streams.clear() @@ -610,10 +610,10 @@ async def _with_stream( maxsize: int, ) -> AsyncIterator[tuple[asyncio.Event, Channel[ResultType]]]: output: Channel[Any] = Channel(maxsize=maxsize) - event = asyncio.Event() - self._streams[session_id] = (event, output) + backpressure_waiter = asyncio.Event() + self._streams[session_id] = (backpressure_waiter, output) try: - yield (event, output) + yield (backpressure_waiter, output) finally: del self._streams[session_id] @@ -633,7 +633,7 @@ async def send_rpc[R, A]( Expects the input and output be messages that will be msgpacked. """ stream_id = nanoid.generate() - async with self._with_stream(stream_id, 1) as (event, output): + async with self._with_stream(stream_id, 1) as (backpressure_waiter, output): await self._send_message( stream_id=stream_id, control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, @@ -647,7 +647,7 @@ async def send_rpc[R, A]( async with asyncio.timeout(timeout.total_seconds()): # Block for event for symmetry with backpressured producers # Here this should be trivially true. 
- await event.wait() + await backpressure_waiter.wait() result = await output.get() except asyncio.TimeoutError as e: await self._send_cancel_stream( @@ -694,7 +694,7 @@ async def send_upload[I, R, A]( """ stream_id = nanoid.generate() - async with self._with_stream(stream_id, 1) as (event, output): + async with self._with_stream(stream_id, 1) as (backpressure_waiter, output): try: await self._send_message( stream_id=stream_id, @@ -712,7 +712,7 @@ async def send_upload[I, R, A]( # throw exception here async for item in request: # Block for backpressure - await event.wait() + await backpressure_waiter.wait() if output.closed(): logger.debug("Stream is closed, avoid sending the rest") break @@ -842,7 +842,7 @@ async def send_stream[I, R, E, A]( async with self._with_stream( stream_id, MAX_MESSAGE_BUFFER_SIZE, - ) as (event, output): + ) as (backpressure_waiter, output): try: await self._send_message( service_name=service_name, @@ -872,7 +872,7 @@ async def _encode_stream() -> None: async for item in request: if item is None: continue - await event.wait() + await backpressure_waiter.wait() if output.closed(): logger.debug("Stream is closed, avoid sending the rest") break @@ -915,7 +915,7 @@ async def _encode_stream() -> None: raise e finally: output.close() - event.set() + backpressure_waiter.set() async def _send_cancel_stream( self, @@ -1286,7 +1286,7 @@ async def _serve( ) continue - event, stream = event_stream + backpressure_waiter, stream = event_stream if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 @@ -1299,7 +1299,7 @@ async def _serve( try: await stream.put(msg.payload) # Wake up backpressured writer - event.set() + backpressure_waiter.set() except ChannelClosed: # The client is no longer interested in this stream, # just drop the message. 
@@ -1311,7 +1311,7 @@ async def _serve( # Communicate that we're going down stream.close() # Wake up backpressured writer - event.set() + backpressure_waiter.set() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await close_session() From 682de58717e78a1f04856e5682396f7432cf8094 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 1 Apr 2025 13:39:40 -0700 Subject: [PATCH 131/193] Wake up backpressured writers on commit() --- src/replit_river/v2/session.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e6dac0a8..e758b25a 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -451,6 +451,9 @@ def commit(msg: TransportMessage) -> None: if not self._send_buffer: self._process_messages.clear() + # Wake up backpressured writer + backpressure_waiter, _ = self._streams[pending.streamId] + backpressure_waiter.set() def get_next_pending() -> TransportMessage | None: if self._send_buffer: return self._send_buffer[0] @@ -1298,8 +1301,6 @@ async def _serve( else: try: await stream.put(msg.payload) - # Wake up backpressured writer - backpressure_waiter.set() except ChannelClosed: # The client is no longer interested in this stream, # just drop the message. 
From 2a7a0948dfdcb83a3255c23c5b9b2ac5224875da Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 1 Apr 2025 15:52:50 -0700 Subject: [PATCH 132/193] Tighter controls around lifecycle management --- src/replit_river/v2/session.py | 48 +++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e758b25a..09a73b88 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -238,10 +238,17 @@ def do_close() -> None: self._terminating_task = asyncio.create_task(self.close()) def transition_connecting() -> None: + if self._state in TerminalStates: + return + logger.debug("transition_connecting") + self._state = SessionState.CONNECTING # "Clear" here means observers should wait until we are connected. self._wait_for_connected.clear() def transition_connected(ws: ClientConnection) -> None: + if self._state in TerminalStates: + return + logger.debug("transition_connected") self._state = SessionState.ACTIVE self._ws = ws @@ -465,7 +472,9 @@ def get_ws() -> ClientConnection | None: return None async def block_until_connected() -> None: + logger.debug("block_until_connected") await self._wait_for_connected.wait() + logger.debug("block_until_connected released!") async def block_until_message_available() -> None: await self._process_messages.wait() @@ -542,9 +551,14 @@ async def block_until_connected() -> None: def _start_serve_responses(self) -> None: async def transition_connecting() -> None: + if self._state in TerminalStates: + return self._state = SessionState.CONNECTING + self._wait_for_connected.clear() - async def connection_interrupted() -> None: + async def transition_no_connection() -> None: + if self._state in TerminalStates: + return self._state = SessionState.NO_CONNECTION if self._ws: self._task_manager.create_task(self._ws.close()) @@ -588,7 +602,11 @@ def assert_incoming_seq_bookkeeping( return True async def 
block_until_connected() -> None: + if self._state in TerminalStates: + return + logger.debug("block_until_connected") await self._wait_for_connected.wait() + logger.debug("block_until_connected released!") self._task_manager.create_task( _serve( @@ -597,7 +615,7 @@ async def block_until_connected() -> None: get_state=lambda: self._state, get_ws=lambda: self._ws, transition_connecting=transition_connecting, - connection_interrupted=connection_interrupted, + transition_no_connection=transition_no_connection, reset_session_close_countdown=self._reset_session_close_countdown, close_session=self.close, assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, @@ -1189,7 +1207,7 @@ async def _serve( get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], transition_connecting: Callable[[], Awaitable[None]], - connection_interrupted: Callable[[], Awaitable[None]], + transition_no_connection: Callable[[], Awaitable[None]], reset_session_close_countdown: Callable[[], None], close_session: Callable[[], Awaitable[None]], assert_incoming_seq_bookkeeping: Callable[ @@ -1206,11 +1224,14 @@ async def _serve( while our_task and not our_task.cancelling() and not our_task.cancelled(): logger.debug(f"_serve loop count={idx}") idx += 1 - while (ws := get_ws()) is None or ( + ws = None + while ( state := get_state() - ) in ConnectingStates: - logger.debug("_handle_messages_from_ws spinning while connecting") + ) in ConnectingStates or (ws := get_ws()) is None: + logger.debug("_handle_messages_from_ws spinning while connecting, %r %r", ws, state) await block_until_connected() + if state in TerminalStates: + break if state in TerminalStates: logger.debug( @@ -1219,6 +1240,11 @@ async def _serve( # session is closing / closed, no need to serve anymore break + # This should not happen, but due to the complex logic around TerminalStates + # above, pyright is not convinced we've caught all the states. 
+ if not ws: + continue + logger.debug( "%s start handling messages from ws %s", "client", @@ -1229,7 +1255,11 @@ async def _serve( # decode=False: Avoiding an unnecessary round-trip through str # Ideally this should be type-ascripted to : bytes, but there # is no @overrides in `websockets` to hint this. - message = await ws.recv(decode=False) + try: + message = await ws.recv(decode=False) + except ConnectionClosed: + await transition_connecting() + continue try: msg = parse_transport_msg(message) logger.debug( @@ -1329,12 +1359,12 @@ async def _serve( break except ConnectionClosed: # Set ourselves to closed as soon as we get the signal - await connection_interrupted() + await transition_no_connection() logger.debug("ConnectionClosed while serving", exc_info=True) break except FailedSendingMessageException: # Expected error if the connection is closed. - await connection_interrupted() + await transition_no_connection() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) From 3591e49ebcf7ef7c95eb8250f265c5e5b898f52e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 1 Apr 2025 22:59:05 -0700 Subject: [PATCH 133/193] Wait for graceful shutdown --- src/replit_river/v2/session.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 09a73b88..51c023b5 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -421,6 +421,9 @@ async def close(self) -> None: # ... 
message processor so it can exit cleanly self._process_messages.set() + # Wait a tick to permit the waiting tasks to shut down gracefully + await asyncio.sleep(0.01) + await self._task_manager.cancel_all_tasks() # TODO: unexpected_close should close stream differently here to From 79f5c6f505c7b3832762af0ee2ccb6bfc8575ffb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 31 Mar 2025 16:11:45 -0700 Subject: [PATCH 134/193] Add v2 stream tests --- src/replit_river/v2/session.py | 16 +- tests/conftest.py | 1 + .../snapshots/test_basic_stream/__init__.py | 13 ++ .../test_service/__init__.py | 44 ++++ .../test_service/emit_error.py | 40 ++++ .../test_service/stream_method.py | 59 +++++ tests/v2/datagrams.py | 94 ++++++++ tests/v2/fixtures.py | 211 ++++++++++++++++++ tests/v2/interpreter.py | 190 ++++++++++++++++ tests/v2/test_stream.py | 194 ++++++++++++++++ tests/v2/test_stream.schema.json | 36 +++ 11 files changed, 892 insertions(+), 6 deletions(-) create mode 100644 tests/v2/codegen/snapshot/snapshots/test_basic_stream/__init__.py create mode 100644 tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py create mode 100644 tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py create mode 100644 tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py create mode 100644 tests/v2/datagrams.py create mode 100644 tests/v2/fixtures.py create mode 100644 tests/v2/interpreter.py create mode 100644 tests/v2/test_stream.py create mode 100644 tests/v2/test_stream.schema.json diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 51c023b5..2fe4b9a2 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -464,6 +464,7 @@ def commit(msg: TransportMessage) -> None: # Wake up backpressured writer backpressure_waiter, _ = self._streams[pending.streamId] backpressure_waiter.set() + def get_next_pending() -> TransportMessage | None: if 
self._send_buffer: return self._send_buffer[0] @@ -923,8 +924,7 @@ async def _encode_stream() -> None: yield error_deserializer(result["payload"]) except Exception: logger.exception( - "Error during stream " - f"error deserialization: {result}" + f"Error during stream error deserialization: {result}" ) continue yield response_deserializer(result["payload"]) @@ -1228,10 +1228,14 @@ async def _serve( logger.debug(f"_serve loop count={idx}") idx += 1 ws = None - while ( - state := get_state() - ) in ConnectingStates or (ws := get_ws()) is None: - logger.debug("_handle_messages_from_ws spinning while connecting, %r %r", ws, state) + while (state := get_state()) in ConnectingStates or ( + ws := get_ws() + ) is None: + logger.debug( + "_handle_messages_from_ws spinning while connecting, %r %r", + ws, + state, + ) await block_until_connected() if state in TerminalStates: break diff --git a/tests/conftest.py b/tests/conftest.py index c52bb7f7..9ab2f9f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ pytest_plugins = [ "tests.v1.river_fixtures.logging", "tests.v1.river_fixtures.clientserver", + "tests.v2.fixtures", ] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/__init__.py new file mode 100644 index 00000000..58d21c8a --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/__init__.py @@ -0,0 +1,13 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .test_service import Test_ServiceService + + +class StreamClient: + def __init__(self, client: river.v2.Client[Literal[None]]): + self.test_service = Test_ServiceService(client) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py new file mode 100644 index 00000000..90080831 --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py @@ -0,0 +1,44 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .stream_method import ( + Stream_MethodInit, + Stream_MethodInput, + Stream_MethodOutput, + Stream_MethodOutputTypeAdapter, + encode_Stream_MethodInit, + encode_Stream_MethodInput, +) + + +class Test_ServiceService: + def __init__(self, client: river.v2.Client[Any]): + self.client = client + + async def stream_method( + self, + init: Stream_MethodInit, + inputStream: AsyncIterable[Stream_MethodInput], + ) -> AsyncIterator[Stream_MethodOutput | RiverError | RiverError]: + return self.client.send_stream( + "test_service", + "stream_method", + init, + inputStream, + encode_Stream_MethodInit, + encode_Stream_MethodInput, + lambda x: Stream_MethodOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: RiverErrorTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py new file mode 100644 index 00000000..e7005c29 --- /dev/null +++ 
b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py @@ -0,0 +1,40 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class Emit_ErrorErrorsOneOf_DATA_LOSS(RiverError): + code: Literal["DATA_LOSS"] + message: str + + +class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError): + code: Literal["UNEXPECTED_DISCONNECT"] + message: str + + +Emit_ErrorErrors = Annotated[ + Emit_ErrorErrorsOneOf_DATA_LOSS + | Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT + | RiverUnknownError, + WrapValidator(translate_unknown_error), +] diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py new file mode 100644 index 00000000..9a0fc5d0 --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -0,0 +1,59 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_Stream_MethodInit( + _: "Stream_MethodInit", +) -> Any: + return {} + + +class Stream_MethodInit(TypedDict): + pass + + +def encode_Stream_MethodInput( + x: "Stream_MethodInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "data": x.get("data"), + } + ).items() + if v is not None + } + + +class Stream_MethodInput(TypedDict): + data: str + + +class Stream_MethodOutput(BaseModel): + data: str + + +Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( + Stream_MethodOutput +) diff --git a/tests/v2/datagrams.py b/tests/v2/datagrams.py new file mode 100644 index 00000000..7a1bab06 --- /dev/null +++ b/tests/v2/datagrams.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +from typing import ( + Any, + NewType, + TypeAlias, +) + +Datagram = dict[str, Any] +TestTransport: TypeAlias = "FromClient | ToClient | WaitForClosed" + +StreamId = NewType("StreamId", str) +ClientId = NewType("ClientId", str) +ServerId = NewType("ServerId", str) +SessionId = NewType("SessionId", str) + + +@dataclass(frozen=True) +class StreamAlias: + alias_id: int + + +@dataclass(frozen=True) +class ValueSet: + seq: int + ack: int + from_: ServerId | None = None + to: ClientId | None = None + procedureName: str | None = None + serviceName: str | None = None + create_alias: StreamAlias | None = None + stream_alias: StreamAlias | None = None + payload: Datagram | None = None + + +@dataclass(frozen=True) +class FromClient: + handshake_request: tuple[ClientId, ServerId, 
SessionId] | ValueSet | None = None + stream_open: tuple[ClientId, ServerId, str, str, StreamId] | ValueSet | None = None + stream_frame: tuple[ClientId, ServerId, int, int, Datagram] | ValueSet | None = None + + +@dataclass(frozen=True) +class ToClient: + seq: int + ack: int + control_flags: int = 0 + handshake_response: bool | None = None + stream_frame: tuple[StreamAlias, Datagram] | None = None + stream_close: StreamAlias | None = None + + +@dataclass(frozen=True) +class WaitForClosed: + pass + + +def decode_FromClient(datagram: dict[str, Any]) -> FromClient: + assert "from" in datagram + assert "to" in datagram + if datagram.get("payload", {}).get("type") == "HANDSHAKE_REQ": + assert "payload" in datagram + assert "sessionId" in datagram["payload"] + return FromClient( + handshake_request=( + ClientId(datagram["from"]), + ServerId(datagram["to"]), + SessionId(datagram["payload"]["sessionId"]), + ) + ) + elif datagram.get("controlFlags", 0) & 0b00010 > 0: # STREAM_OPEN_BIT + return FromClient( + stream_open=( + ClientId(datagram["from"]), + ServerId(datagram["to"]), + datagram["serviceName"], + datagram["procedureName"], + StreamId(datagram["streamId"]), + ) + ) + elif datagram: + return FromClient( + stream_frame=( + ClientId(datagram["from"]), + ServerId(datagram["to"]), + datagram["seq"], + datagram["ack"], + datagram["payload"], + ) + ) + raise ValueError("Unexpected datagram: %r", datagram) + + +def parser(datagram: dict[str, Any]) -> FromClient: + return decode_FromClient(datagram) diff --git a/tests/v2/fixtures.py b/tests/v2/fixtures.py new file mode 100644 index 00000000..90bee603 --- /dev/null +++ b/tests/v2/fixtures.py @@ -0,0 +1,211 @@ +import asyncio +import copy +from collections import deque +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + TypeAlias, + assert_never, +) + +import msgpack +import pytest +from aiochannel import Channel +from websockets.asyncio.server import ServerConnection, serve + +from 
replit_river.transport_options import TransportOptions, UriAndMetadata +from replit_river.v2.client import Client +from tests.v2.datagrams import ( + FromClient, + TestTransport, + ToClient, + WaitForClosed, + parser, +) +from tests.v2.interpreter import make_interpreters + +# Client -> Server +ClientToServerChannel: TypeAlias = Channel[Any] +# Server -> Client +ServerToClientChannel: TypeAlias = Channel[Any] + + +@pytest.fixture +async def raw_websocket_meta() -> AsyncIterator[ + tuple[ServerToClientChannel, ClientToServerChannel, UriAndMetadata[None]] +]: + # server -> client + server_to_client = Channel[dict[str, Any]](maxsize=1) + # client -> server + client_to_server = Channel[dict[str, Any]](maxsize=1) + + # Service websocket connection + async def handle(websocket: ServerConnection) -> None: + async def emit() -> None: + async for msg in client_to_server: + await websocket.send(msgpack.packb(msg)) + + emit_task = asyncio.create_task(emit()) + try: + while message := await websocket.recv(decode=False): + assert isinstance(message, bytes) + unpacked = msgpack.unpackb(message, timestamp=3) + await server_to_client.put(unpacked) + finally: + server_to_client.close() + client_to_server.close() + emit_task.cancel() + await emit_task + return None + + ipv4_laddr: str | None = None + async with serve(handle, "localhost") as server: + for sock in server.sockets: + if (pair := sock.getsockname())[0] == "127.0.0.1": + ipv4_laddr = "ws://%s:%d" % pair + serve_forever = asyncio.create_task(server.serve_forever()) + assert ipv4_laddr + yield ( + server_to_client, + client_to_server, + UriAndMetadata(uri=ipv4_laddr, metadata=None), + ) + + serve_forever.cancel() + + +@pytest.fixture +async def bound_client( + raw_websocket_meta: tuple[ + ClientToServerChannel, + ServerToClientChannel, + UriAndMetadata[None], + ], + expected: deque[TestTransport], +) -> AsyncGenerator[Client[None], None]: + # Do our best to not modify the test data + _expected = expected + expected = 
copy.deepcopy(_expected) + + # client-to-server handler + # + # Consume FromClient events, optionally emitting datagrams to be written directly + # back to the client. + # + # This represents the direct request-response flow, but does not handle + # server-emitted events. + + client_to_server, server_to_client, uri_and_metadata = raw_websocket_meta + + async def messages_from_client() -> None: + async for msg in client_to_server: + parsed = parser(msg) + try: + next_expected = await anext(messages_from_client_channel) + except StopAsyncIteration: + break + response = from_client_interpreter(received=parsed, expected=next_expected) + if response is not None: + await server_to_client.put(response) + processing_finished.set() + + server_task = asyncio.create_task(messages_from_client()) + + # server-to-client handler + # + # Consume ToClient events, optionally emitting datagrams to be written directly to + # the client. + # + # This represents the other half of the "server" lifecycle, where the server can + # choose to emit events directly without a request. 
+ + messages_to_client_channel = Channel[ToClient]() + + async def messages_to_client() -> None: + our_task = asyncio.current_task() + while our_task and not our_task.cancelled(): + try: + next_expected = await anext(messages_to_client_channel) + except StopAsyncIteration: + break + + response = to_client_interpreter(next_expected) + if response is not None: + await server_to_client.put(response) + processing_finished.set() + + processor_task = asyncio.create_task(messages_to_client()) + + # This consumes from the "expected" queue and routes messages to waiting channels + # + # This also handles shutdown + + async def driver() -> None: + while expected: + next_expected = expected.popleft() + if isinstance(next_expected, FromClient): + await messages_from_client_channel.put(next_expected) + elif isinstance(next_expected, ToClient): + await messages_to_client_channel.put(next_expected) + elif isinstance(next_expected, WaitForClosed): + countdown = 100 + messages_to_client_channel.close() + while not processor_task.done(): + if countdown <= 0: + break + countdown -= 1 + await asyncio.sleep(0.1) + client_to_server.close() + await client.close() + break + else: + assert_never(next_expected) + await asyncio.sleep(0.1) + await processing_finished.wait() + + driver_task = asyncio.create_task(driver()) + + # Watchdog keeps track of the above tasks + # async def watchdog() -> None: + # while True: + # print(repr(server_task)) + # print(repr(processor_task)) + # print(repr(driver_task)) + # await asyncio.sleep(1) + # + # watchdog_task = asyncio.create_task(watchdog()) + + async def uri_and_metadata_factory() -> UriAndMetadata[None]: + return uri_and_metadata + + client = Client( + uri_and_metadata_factory=uri_and_metadata_factory, + client_id="client-001", + server_id="server-001", + transport_options=TransportOptions( + close_session_check_interval_ms=500, + ), + ) + + from_client_interpreter, to_client_interpreter = make_interpreters() + + processing_finished = 
asyncio.Event() + + messages_from_client_channel = Channel[FromClient]() + + yield client + + await driver_task + + messages_to_client_channel.close() + messages_from_client_channel.close() + + processing_finished.set() + + server_task.cancel() + processor_task.cancel() + + await client.close() + await server_task + await processor_task diff --git a/tests/v2/interpreter.py b/tests/v2/interpreter.py new file mode 100644 index 00000000..b7cd254d --- /dev/null +++ b/tests/v2/interpreter.py @@ -0,0 +1,190 @@ +from typing import ( + Any, + Callable, + NotRequired, + Protocol, + TypedDict, +) + +import nanoid + +from tests.v2.datagrams import ( + ClientId, + FromClient, + ServerId, + StreamAlias, + StreamId, + ToClient, + ValueSet, +) + +Datagram = dict[str, Any] + + +class FromClientInterpreter(Protocol): + def __call__( + self, *, received: FromClient, expected: FromClient + ) -> Datagram | None: ... + + +class _TestTransportState(TypedDict): + streams: dict[StreamAlias, tuple[ClientId, ServerId, StreamId]] + client_id: NotRequired[str] + server_id: NotRequired[str] + + +def make_interpreters() -> tuple[ + FromClientInterpreter, + Callable[[ToClient], Datagram | None], +]: + state: _TestTransportState = { + "streams": {}, + } + + def from_client_interpreter( + received: FromClient, + expected: FromClient, + ) -> Datagram | None: + if isinstance(received.handshake_request, tuple): + (from_, to, session_id) = received.handshake_request + assert isinstance(expected.handshake_request, ValueSet), ( + f"Expected ValueSet 1: {repr(received)}, {repr(expected)}" + ) + assert expected.handshake_request.from_, ( + f"Expected {expected.handshake_request.from_}" + ) + assert expected.handshake_request.to, ( + "Expected {expected.handshake_request.to}" + ) + assert expected.handshake_request.from_ == to, ( + f"Expected {expected.handshake_request.from_} == {to}" + ) + assert expected.handshake_request.to == from_, ( + "Expected {expected.handshake_request.to} == {from_}" + ) + 
return _build_datagram( + seq=expected.handshake_request.seq, + packet_id=nanoid.generate(), + stream_id=nanoid.generate(), + control_flags=0, + ack=expected.handshake_request.ack, + client_id=expected.handshake_request.from_, + server_id=expected.handshake_request.to, + payload=_build_handshake_resp(session_id), + ) + elif isinstance(received.stream_open, tuple): + (from_, to, service_name, procedure_name, stream_id) = received.stream_open + assert isinstance(expected.stream_open, ValueSet), "Expected ValueSet 2" + assert expected.stream_open.from_ == to, ( + f"Expected {expected.stream_open.from_} == {to}" + ) + assert expected.stream_open.to == from_, ( + "Expected {expected.stream_open.to} == {from_}" + ) + assert expected.stream_open.serviceName == service_name, ( + f"Expected {expected.stream_open.serviceName} == {service_name}" + ) + assert expected.stream_open.procedureName == procedure_name, ( + f"Expected {expected.stream_open.procedureName} == {procedure_name}" + ) + assert expected.stream_open.create_alias, ( + "Expected create_alias to be a StreamAlias" + ) + # Do it all again because mypy can't infer correctly + alias_mapping: tuple[ClientId, ServerId, StreamId] = ( + ClientId(from_), + ServerId(to), + StreamId(stream_id), + ) + state["streams"][expected.stream_open.create_alias] = alias_mapping + return None + elif isinstance(received.stream_frame, tuple): + (from_, to, seq, ack, payload) = received.stream_frame + assert isinstance(expected.stream_frame, ValueSet), "Expected ValueSet 3" + assert seq == expected.stream_frame.seq, ( + f"Expected seq {seq} == {expected.stream_frame.seq}" + ) + assert ack == expected.stream_frame.ack, ( + f"Expected ack {ack} == {expected.stream_frame.ack}" + ) + assert expected.stream_frame.stream_alias, "Expected stream_alias" + (from_, to, stream_id) = state["streams"][ + expected.stream_frame.stream_alias + ] + assert expected.stream_frame.payload == payload, ( + f"Expected {expected.stream_frame.payload} == 
{payload}" + ) + return None + raise ValueError("Unexpected from_client case: %r", received) + + def to_client_interpreter(expected: ToClient) -> Datagram | None: + if expected.stream_frame is not None: + (stream_alias, payload) = expected.stream_frame + (from_, to, stream_id) = state["streams"][stream_alias] + return _build_datagram( + seq=expected.seq, + ack=expected.ack, + control_flags=expected.control_flags, + packet_id=nanoid.generate(), + stream_id=stream_id, + client_id=from_, + server_id=to, + payload=payload, + ) + elif expected.stream_close is not None: + stream_alias = expected.stream_close + (from_, to, stream_id) = state["streams"][stream_alias] + return _build_datagram( + seq=expected.seq, + ack=expected.ack, + control_flags=0b01000, + packet_id=nanoid.generate(), + stream_id=stream_id, + client_id=from_, + server_id=to, + payload={"type": "CLOSE"}, + ) + raise ValueError("Unexpected to_client case: %r", expected) + + return from_client_interpreter, to_client_interpreter + + +def _strip_none(datagram: Datagram) -> Datagram: + return {k: v for k, v in datagram.items() if v is not None} + + +def _build_handshake_resp(session_id: str) -> Datagram: + return _strip_none( + { + "type": "HANDSHAKE_RESP", + "status": { + "ok": True, + "sessionId": session_id, + }, + } + ) + + +def _build_datagram( + *, + packet_id: str, + stream_id: str, + server_id: str, + client_id: str, + control_flags: int, + seq: int, + ack: int, + payload: Datagram, +) -> Datagram: + return _strip_none( + { + "id": packet_id, + "from": server_id, + "to": client_id, + "seq": seq, + "ack": ack, + "streamId": stream_id, + "controlFlags": control_flags, + "payload": payload, + } + ) diff --git a/tests/v2/test_stream.py b/tests/v2/test_stream.py new file mode 100644 index 00000000..b76ff906 --- /dev/null +++ b/tests/v2/test_stream.py @@ -0,0 +1,194 @@ +import importlib +import logging +from collections import deque +from typing import ( + AsyncIterable, + Literal, +) + +import pytest +from 
pytest_snapshot.plugin import Snapshot + +from replit_river.v2.client import Client +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v2.datagrams import ( + ClientId, + FromClient, + ServerId, + StreamAlias, + TestTransport, + ToClient, + ValueSet, + WaitForClosed, +) + +logger = logging.getLogger(__name__) + +_AlreadyGenerated = False + + +@pytest.fixture(scope="function", autouse=True) +def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: + global _AlreadyGenerated + if not _AlreadyGenerated: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v2/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v2/test_stream.schema.json"), + target_path="test_basic_stream", + client_name="StreamClient", + protocol_version="v2.0", + ) + _AlreadyGenerated = True + + import tests.v2.codegen.snapshot.snapshots.test_basic_stream + + importlib.reload(tests.v2.codegen.snapshot.snapshots.test_basic_stream) + return True + + +stream_expected: deque[TestTransport] = deque( + [ + FromClient( + handshake_request=ValueSet( + seq=0, # These don't count due to being during a handshake + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + ) + ), + FromClient( + stream_open=ValueSet( + seq=0, + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + serviceName="test_service", + procedureName="stream_method", + create_alias=StreamAlias(1), + ) + ), + FromClient( + stream_frame=ValueSet( + seq=1, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "0"}, + ) + ), + ToClient( + seq=0, + ack=1, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 0"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=2, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "1"}, + ) + ), + ToClient( + seq=1, + ack=2, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 1"}}, + ), + ), + FromClient( + stream_frame=ValueSet( 
+ seq=3, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "2"}, + ) + ), + ToClient( + seq=2, + ack=0, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 2"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=4, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "3"}, + ) + ), + ToClient( + seq=3, + ack=4, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 3"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=5, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "4"}, + ) + ), + ToClient( + seq=4, + ack=5, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 4"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=6, + ack=0, + stream_alias=StreamAlias(1), + payload={"type": "CLOSE"}, + ) + ), + ToClient( + seq=5, + ack=0, + stream_close=StreamAlias(1), + ), + WaitForClosed(), + ] +) + + +@pytest.mark.parametrize("expected", [stream_expected]) +async def test_stream(bound_client: Client) -> None: + from tests.v2.codegen.snapshot.snapshots.test_basic_stream import ( + StreamClient, + ) + from tests.v2.codegen.snapshot.snapshots.test_basic_stream.test_service.stream_method import ( # noqa: E501 + Stream_MethodInput, + Stream_MethodOutput, + ) + + async def emit() -> AsyncIterable[Stream_MethodInput]: + for i in range(5): + data: Stream_MethodInput = {"data": str(i)} + yield data + + res = await StreamClient(bound_client).test_service.stream_method( + init={}, + inputStream=emit(), + ) + + i = 0 + async for datum in res: + assert isinstance(datum, Stream_MethodOutput) + assert f"Stream response for {i}" == datum.data, f"{i} == {datum.data}" + i = i + 1 + assert i == 5 diff --git a/tests/v2/test_stream.schema.json b/tests/v2/test_stream.schema.json new file mode 100644 index 00000000..5b1f34f8 --- /dev/null +++ b/tests/v2/test_stream.schema.json @@ -0,0 +1,36 @@ +{ + "services": { + "test_service": { + 
"procedures": { + "stream_method": { + "init": { + "type": "object", + "properties": {} + }, + "input": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "output": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "errors": { + "not": {} + }, + "type": "stream" + } + } + } + } +} From 2ec2cb543bb345f8a3c3f2e88a4e88695edd37ab Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 15:58:50 -0700 Subject: [PATCH 135/193] Codegen should provide useful errors --- src/replit_river/codegen/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 42e87a0f..d7996500 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -809,11 +809,11 @@ def render_library_call( if procedure.type == "rpc": match protocol_version: case "v1.1": - assert input_meta + assert input_meta, "rpc expects input to be required" _, tpe, render_method = input_meta binding = "input" case "v2.0": - assert init_meta + assert init_meta, "rpc expects init to be required" _, tpe, render_method = init_meta binding = "init" case other: @@ -850,11 +850,11 @@ async def {name}( elif procedure.type == "subscription": match protocol_version: case "v1.1": - assert input_meta + assert input_meta, "rpc expects input to be required" _, tpe, render_method = input_meta binding = "input" case "v2.0": - assert init_meta + assert init_meta, "rpc expects init to be required" _, tpe, render_method = init_meta binding = "init" case other: From 5ca64c3f7eb1471f7bc168a76d78dd171e8dc436 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 16:00:07 -0700 Subject: [PATCH 136/193] Ensure that the whole test state has been consumed --- tests/v2/fixtures.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/v2/fixtures.py 
b/tests/v2/fixtures.py index 90bee603..eae41605 100644 --- a/tests/v2/fixtures.py +++ b/tests/v2/fixtures.py @@ -198,6 +198,14 @@ async def uri_and_metadata_factory() -> UriAndMetadata[None]: await driver_task + assert len(expected) == 0, "Unconsumed messages from 'expected'" + assert messages_to_client_channel.qsize() == 0, ( + "Dangling messages the client has not consumed" + ) + assert messages_from_client_channel.qsize() == 0, ( + "Dangling messages the processor has not consumed" + ) + messages_to_client_channel.close() messages_from_client_channel.close() From 9e803d676d63aa02dfd205aa2d34c2e9bfc339af Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 16:01:10 -0700 Subject: [PATCH 137/193] Add v2 rpc tests --- .../snapshots/test_basic_rpc/__init__.py | 13 +++ .../test_basic_rpc/test_service/__init__.py | 41 ++++++++ .../test_basic_rpc/test_service/rpc_method.py | 49 ++++++++++ tests/v2/test_rpc.py | 97 +++++++++++++++++++ tests/v2/test_rpc.schema.json | 32 ++++++ 5 files changed, 232 insertions(+) create mode 100644 tests/v2/codegen/snapshot/snapshots/test_basic_rpc/__init__.py create mode 100644 tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/__init__.py create mode 100644 tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py create mode 100644 tests/v2/test_rpc.py create mode 100644 tests/v2/test_rpc.schema.json diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/__init__.py new file mode 100644 index 00000000..58d21c8a --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/__init__.py @@ -0,0 +1,13 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .test_service import Test_ServiceService + + +class StreamClient: + def __init__(self, client: river.v2.Client[Literal[None]]): + self.test_service = Test_ServiceService(client) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/__init__.py new file mode 100644 index 00000000..f7c04990 --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/__init__.py @@ -0,0 +1,41 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .rpc_method import ( + Rpc_MethodInit, + Rpc_MethodOutput, + Rpc_MethodOutputTypeAdapter, + encode_Rpc_MethodInit, +) + + +class Test_ServiceService: + def __init__(self, client: river.v2.Client[Any]): + self.client = client + + async def rpc_method( + self, + init: Rpc_MethodInit, + timeout: datetime.timedelta, + ) -> Rpc_MethodOutput: + return await self.client.send_rpc( + "test_service", + "rpc_method", + init, + encode_Rpc_MethodInit, + lambda x: Rpc_MethodOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: RiverErrorTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py new file mode 100644 index 00000000..5fe764eb --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py @@ -0,0 +1,49 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_Rpc_MethodInit( + x: "Rpc_MethodInit", +) -> Any: + return { + k: v + for (k, v) in ( + { + "data": x.get("data"), + } + ).items() + if v is not None + } + + +class Rpc_MethodInit(TypedDict): + data: str + + +class Rpc_MethodOutput(BaseModel): + data: str + + +Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( + Rpc_MethodOutput +) diff --git a/tests/v2/test_rpc.py b/tests/v2/test_rpc.py new file mode 100644 index 00000000..3f6fe645 --- /dev/null +++ b/tests/v2/test_rpc.py @@ -0,0 +1,97 @@ +import importlib +import logging +from collections import deque +from datetime import timedelta +from typing import ( + Literal, +) + +import pytest +from pytest_snapshot.plugin import Snapshot + +from replit_river.v2.client import Client +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v2.datagrams import ( + ClientId, + FromClient, + ServerId, + StreamAlias, + TestTransport, + ToClient, + ValueSet, + WaitForClosed, +) + +logger = logging.getLogger(__name__) + +_AlreadyGenerated = False + + +@pytest.fixture(scope="function", autouse=True) +def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: + global _AlreadyGenerated + if not _AlreadyGenerated: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v2/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v2/test_rpc.schema.json"), + target_path="test_basic_rpc", + client_name="StreamClient", + protocol_version="v2.0", + ) + _AlreadyGenerated = True + + 
import tests.v2.codegen.snapshot.snapshots.test_basic_stream + + importlib.reload(tests.v2.codegen.snapshot.snapshots.test_basic_stream) + return True + + +rpc_expected: deque[TestTransport] = deque( + [ + FromClient( + handshake_request=ValueSet( + seq=0, # These don't count due to being during a handshake + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + ) + ), + FromClient( + stream_open=ValueSet( + seq=0, + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + serviceName="test_service", + procedureName="rpc_method", + create_alias=StreamAlias(1), + payload={"data": "foo"}, + stream_closed=True, + ) + ), + ToClient( + seq=0, + ack=1, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Hello, foo!"}}, + ), + ), + WaitForClosed(), + ] +) + + +@pytest.mark.parametrize("expected", [rpc_expected]) +async def test_rpc(bound_client: Client) -> None: + from tests.v2.codegen.snapshot.snapshots.test_basic_rpc import ( + StreamClient, + ) + + res = await StreamClient(bound_client).test_service.rpc_method( + init={"data": "foo"}, + timeout=timedelta(seconds=5), + ) + + assert res.data == "Hello, foo!" 
diff --git a/tests/v2/test_rpc.schema.json b/tests/v2/test_rpc.schema.json new file mode 100644 index 00000000..d53a1878 --- /dev/null +++ b/tests/v2/test_rpc.schema.json @@ -0,0 +1,32 @@ +{ + "services": { + "test_service": { + "procedures": { + "rpc_method": { + "init": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "output": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "errors": { + "not": {} + }, + "type": "rpc" + } + } + } + } +} From 9e5ded0866d8e477ce89d97dcf2665d7b9cdac3c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 16:11:53 -0700 Subject: [PATCH 138/193] Add capacity for self-closing messages --- tests/v2/datagrams.py | 12 +++++++++++- tests/v2/interpreter.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/v2/datagrams.py b/tests/v2/datagrams.py index 7a1bab06..78feb447 100644 --- a/tests/v2/datagrams.py +++ b/tests/v2/datagrams.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import ( Any, + Literal, NewType, TypeAlias, ) @@ -30,13 +31,17 @@ class ValueSet: create_alias: StreamAlias | None = None stream_alias: StreamAlias | None = None payload: Datagram | None = None + stream_closed: Literal[True] | None = None @dataclass(frozen=True) class FromClient: handshake_request: tuple[ClientId, ServerId, SessionId] | ValueSet | None = None - stream_open: tuple[ClientId, ServerId, str, str, StreamId] | ValueSet | None = None + stream_open: ( + tuple[ClientId, ServerId, str, str, StreamId, Datagram] | ValueSet | None + ) = None stream_frame: tuple[ClientId, ServerId, int, int, Datagram] | ValueSet | None = None + stream_closed: Literal[True] | None = None @dataclass(frozen=True) @@ -75,7 +80,12 @@ def decode_FromClient(datagram: dict[str, Any]) -> FromClient: datagram["serviceName"], datagram["procedureName"], StreamId(datagram["streamId"]), + datagram["payload"], 
+ ), + stream_closed=( + datagram["controlFlags"] & 0b01000 > 0 # STREAM_CLOSED_BIT ) + or None, ) elif datagram: return FromClient( diff --git a/tests/v2/interpreter.py b/tests/v2/interpreter.py index b7cd254d..d706316c 100644 --- a/tests/v2/interpreter.py +++ b/tests/v2/interpreter.py @@ -73,7 +73,9 @@ def from_client_interpreter( payload=_build_handshake_resp(session_id), ) elif isinstance(received.stream_open, tuple): - (from_, to, service_name, procedure_name, stream_id) = received.stream_open + (from_, to, service_name, procedure_name, stream_id, payload) = ( + received.stream_open + ) assert isinstance(expected.stream_open, ValueSet), "Expected ValueSet 2" assert expected.stream_open.from_ == to, ( f"Expected {expected.stream_open.from_} == {to}" @@ -90,6 +92,14 @@ def from_client_interpreter( assert expected.stream_open.create_alias, ( "Expected create_alias to be a StreamAlias" ) + if expected.stream_open.payload is not None and payload is not None: + assert expected.stream_open.payload == payload, ( + f"Expected {expected.stream_open.payload} == {payload}" + ) + assert expected.stream_open.stream_closed or not received.stream_closed, ( + f"Are we self-closing? 
{expected.stream_open.stream_closed} " + f"or not {received.stream_closed}" + ) # Do it all again because mypy can't infer correctly alias_mapping: tuple[ClientId, ServerId, StreamId] = ( ClientId(from_), From 4b52339d30188dd552ca20c3a232dae7bade9850 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 16:24:52 -0700 Subject: [PATCH 139/193] Unique names for tests --- tests/v2/{test_rpc.py => test_v2_rpc.py} | 0 tests/v2/{test_stream.py => test_v2_stream.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/v2/{test_rpc.py => test_v2_rpc.py} (100%) rename tests/v2/{test_stream.py => test_v2_stream.py} (100%) diff --git a/tests/v2/test_rpc.py b/tests/v2/test_v2_rpc.py similarity index 100% rename from tests/v2/test_rpc.py rename to tests/v2/test_v2_rpc.py diff --git a/tests/v2/test_stream.py b/tests/v2/test_v2_stream.py similarity index 100% rename from tests/v2/test_stream.py rename to tests/v2/test_v2_stream.py From 7d16324ba98cef1a6ce6d220dc2ff3258fcaf803 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 16:35:34 -0700 Subject: [PATCH 140/193] Avoid waking up streams that are not waiting Presumably this is limited to heartbeats. 
--- src/replit_river/v2/session.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 2fe4b9a2..8102b5e4 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -462,8 +462,9 @@ def commit(msg: TransportMessage) -> None: self._process_messages.clear() # Wake up backpressured writer - backpressure_waiter, _ = self._streams[pending.streamId] - backpressure_waiter.set() + stream_meta = self._streams.get(pending.streamId) + if stream_meta: + stream_meta[0].set() def get_next_pending() -> TransportMessage | None: if self._send_buffer: From 91bdf2e21c80ffafa22192d118a8de697d80d2eb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 17:16:04 -0700 Subject: [PATCH 141/193] We still need to be awake to catch heartbeat timeouts even if we know we are disconnected --- src/replit_river/v2/session.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 8102b5e4..802bf511 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1189,10 +1189,6 @@ async def _setup_heartbeat( break await asyncio.sleep(heartbeat_ms / 1000) - state = get_state() - if state in ConnectingStates: - logger.debug("Websocket is not connected, don't expect heartbeat") - continue if increment_and_get_heartbeat_misses() > heartbeats_until_dead: if get_closing_grace_period() is not None: From c94e0df30accce7870f01fa537a154afeea2ba8d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 17:38:27 -0700 Subject: [PATCH 142/193] Ripping out all the heartbeat stuff in favor of server-directed signaling --- src/replit_river/v2/session.py | 131 +++++---------------------------- 1 file changed, 17 insertions(+), 114 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 802bf511..b2381445 100644 --- a/src/replit_river/v2/session.py +++ 
b/src/replit_river/v2/session.py @@ -28,6 +28,7 @@ from pydantic import ValidationError from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed, ConnectionClosedOK +from websockets.protocol import CLOSED from replit_river.common_session import ( ConnectingStates, @@ -199,7 +200,6 @@ def __init__( # Terminating self._terminating_task = None - self._start_heartbeat() self._start_serve_responses() self._start_close_session_checker() self._start_buffered_message_sender() @@ -497,65 +497,24 @@ async def block_until_message_available() -> None: ) def _start_close_session_checker(self) -> None: - def do_close() -> None: - # Avoid closing twice - if self._terminating_task is None: - # We can't just call self.close() directly because - # we're inside a thread that will eventually be awaited - # during the cleanup procedure. - self._terminating_task = asyncio.create_task(self.close()) + def transition_connecting() -> None: + if self._state in TerminalStates: + return + self._state = SessionState.CONNECTING + self._wait_for_connected.clear() self._task_manager.create_task( _check_to_close_session( self._transport_id, self._transport_options.close_session_check_interval_ms, lambda: self._state, - self._get_current_time, - lambda: self._close_session_after_time_secs, - do_close=do_close, - ) - ) - - def _start_heartbeat(self) -> None: - async def close_websocket() -> None: - logger.debug( - "close_websocket called, _state=%r, _ws=%r", - self._state, - self._ws, - ) - if self._ws: - self._task_manager.create_task(self._ws.close()) - self._ws = None - - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) - else: - self._state = SessionState.CLOSING - - await self._begin_close_session_countdown() - - def increment_and_get_heartbeat_misses() -> int: - self._heartbeat_misses += 1 - return self._heartbeat_misses - - async def block_until_connected() -> None: - await 
self._wait_for_connected.wait() - - self._task_manager.create_task( - _setup_heartbeat( - block_until_connected, - self.session_id, - self._transport_options.heartbeat_ms, - self._transport_options.heartbeats_until_dead, - lambda: self._state, - lambda: self._close_session_after_time_secs, - close_websocket=close_websocket, - increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, + lambda: self._ws, + transition_connecting=transition_connecting, ) ) def _start_serve_responses(self) -> None: - async def transition_connecting() -> None: + def transition_connecting() -> None: if self._state in TerminalStates: return self._state = SessionState.CONNECTING @@ -973,31 +932,16 @@ async def _check_to_close_session( transport_id: str, close_session_check_interval_ms: float, get_state: Callable[[], SessionState], - get_current_time: Callable[[], Awaitable[float]], - get_close_session_after_time_secs: Callable[[], float | None], - do_close: Callable[[], None], + get_ws: Callable[[], ClientConnection | None], + transition_connecting: Callable[[], None], ) -> None: - our_task = asyncio.current_task() - while our_task and not our_task.cancelling() and not our_task.cancelled(): + while get_state() not in TerminalStates: logger.debug("_check_to_close_session: Checking") await asyncio.sleep(close_session_check_interval_ms / 1000) - if get_state() in TerminalStates: - # already closing - break - # calculate the value now before comparing it so that there are no - # await points between the check and the comparison to avoid a TOCTOU - # race. 
- current_time = await get_current_time() - close_session_after_time_secs = get_close_session_after_time_secs() - if not close_session_after_time_secs: - logger.debug( - f"_check_to_close_session: Not reached: {close_session_after_time_secs}" - ) - continue - if current_time > close_session_after_time_secs: + + if not (ws := get_ws()) or ws.protocol.state is CLOSED: logger.info("Grace period ended for %s, closing session", transport_id) - do_close() - our_task.cancel() + transition_connecting() async def _do_ensure_connected[HandshakeMetadata]( @@ -1160,53 +1104,12 @@ async def websocket_closed_callback() -> None: return None -async def _setup_heartbeat( - block_until_connected: Callable[[], Awaitable[None]], - session_id: str, - heartbeat_ms: float, - heartbeats_until_dead: int, - get_state: Callable[[], SessionState], - get_closing_grace_period: Callable[[], float | None], - close_websocket: Callable[[], Awaitable[None]], - increment_and_get_heartbeat_misses: Callable[[], int], -) -> None: - while True: - while (state := get_state()) in ConnectingStates: - logger.debug( - "Heartbeat: block_until_connected: %r", - state, - ) - await block_until_connected() - - if state in TerminalStates: - logger.debug( - "Session is closed, no need to send heartbeat, state : " - "%r close_session_after_this: %r", - state, - get_closing_grace_period(), - ) - # session is closing / closed, no need to send heartbeat anymore - break - - await asyncio.sleep(heartbeat_ms / 1000) - - if increment_and_get_heartbeat_misses() > heartbeats_until_dead: - if get_closing_grace_period() is not None: - # already in grace period, no need to set again - continue - logger.info( - "%r closing websocket because of heartbeat misses", - session_id, - ) - await close_websocket() - - async def _serve( block_until_connected: Callable[[], Awaitable[None]], transport_id: str, get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], - transition_connecting: Callable[[], 
Awaitable[None]], + transition_connecting: Callable[[], None], transition_no_connection: Callable[[], Awaitable[None]], reset_session_close_countdown: Callable[[], None], close_session: Callable[[], Awaitable[None]], @@ -1262,7 +1165,7 @@ async def _serve( try: message = await ws.recv(decode=False) except ConnectionClosed: - await transition_connecting() + transition_connecting() continue try: msg = parse_transport_msg(message) From caa2f4ab67611bc8ad16577eed38994cb820c0c6 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 17:46:06 -0700 Subject: [PATCH 143/193] Transition connecting --- src/replit_river/v2/session.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index b2381445..21cef0ba 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -505,7 +505,6 @@ def transition_connecting() -> None: self._task_manager.create_task( _check_to_close_session( - self._transport_id, self._transport_options.close_session_check_interval_ms, lambda: self._state, lambda: self._ws, @@ -929,7 +928,6 @@ async def _send_close_stream( async def _check_to_close_session( - transport_id: str, close_session_check_interval_ms: float, get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], @@ -939,8 +937,8 @@ async def _check_to_close_session( logger.debug("_check_to_close_session: Checking") await asyncio.sleep(close_session_check_interval_ms / 1000) - if not (ws := get_ws()) or ws.protocol.state is CLOSED: - logger.info("Grace period ended for %s, closing session", transport_id) + if (ws := get_ws()) and ws.protocol.state is CLOSED: + logger.info("Websocket is closed, transitioning to connecting") transition_connecting() From a55b07c2a7b615676d69ba23c2fa293d43884067 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 17:54:09 -0700 Subject: [PATCH 144/193] Make ensure_connected callable based on our own internal 
state --- src/replit_river/v2/client_transport.py | 9 ++++--- src/replit_river/v2/session.py | 32 +++++++++++++++---------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 854eba3c..d74a2558 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -57,16 +57,15 @@ async def get_or_create_session(self) -> Session: transport_options=self._transport_options, close_session_callback=self._delete_session, retry_connection_callback=self._retry_connection, + uri_and_metadata_factory=self._uri_and_metadata_factory, + rate_limiter=self._rate_limiter, + client_id=self._client_id, ) self._session = new_session existing_session = new_session - await existing_session.ensure_connected( - client_id=self._client_id, - rate_limiter=self._rate_limiter, - uri_and_metadata_factory=self._uri_and_metadata_factory, - ) + await existing_session.ensure_connected() return existing_session async def _retry_connection(self) -> Session: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 21cef0ba..85446042 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -119,7 +119,7 @@ class _IgnoreMessage: pass -class Session: +class Session[HandshakeMetadata]: _transport_id: str _to_id: str session_id: str @@ -132,6 +132,12 @@ class Session: _connecting_task: asyncio.Task[None] | None _wait_for_connected: asyncio.Event + _client_id: str + _rate_limiter: LeakyBucketRateLimit + _uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadata]] + ] + # ws state _ws: ClientConnection | None _heartbeat_misses: int @@ -161,6 +167,11 @@ def __init__( session_id: str, transport_options: TransportOptions, close_session_callback: CloseSessionCallback, + client_id: str, + rate_limiter: LeakyBucketRateLimit, + uri_and_metadata_factory: Callable[ + [], 
Awaitable[UriAndMetadata[HandshakeMetadata]] + ], retry_connection_callback: RetryConnectionCallback | None = None, ) -> None: self._transport_id = transport_id @@ -175,6 +186,10 @@ def __init__( self._connecting_task = None self._wait_for_connected = asyncio.Event() + self._client_id = client_id + self._rate_limiter = rate_limiter + self._uri_and_metadata_factory = uri_and_metadata_factory + # ws state self._ws = None self._heartbeat_misses = 0 @@ -204,14 +219,7 @@ def __init__( self._start_close_session_checker() self._start_buffered_message_sender() - async def ensure_connected[HandshakeMetadata]( - self, - client_id: str, - rate_limiter: LeakyBucketRateLimit, - uri_and_metadata_factory: Callable[ - [], Awaitable[UriAndMetadata[HandshakeMetadata]] - ], - ) -> None: + async def ensure_connected(self) -> None: """ Either return immediately or establish a websocket connection and return once we can accept messages. @@ -279,12 +287,12 @@ def finalize_attempt() -> None: self._connecting_task = asyncio.create_task( _do_ensure_connected( transport_id=self._transport_id, - client_id=client_id, + client_id=self._client_id, to_id=self._to_id, session_id=self.session_id, max_retry=self._transport_options.connection_retry_options.max_retry, - rate_limiter=rate_limiter, - uri_and_metadata_factory=uri_and_metadata_factory, + rate_limiter=self._rate_limiter, + uri_and_metadata_factory=self._uri_and_metadata_factory, get_next_sent_seq=get_next_sent_seq, get_current_ack=lambda: self.ack, get_current_time=self._get_current_time, From 8bf1de46b6927d3da82a15b1ff0290c38c42b5b7 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 2 Apr 2025 18:02:46 -0700 Subject: [PATCH 145/193] Exploratory, just transition_no_connection from close checker --- src/replit_river/v2/session.py | 66 ++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 85446042..6b98b763 100644 --- 
a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -215,8 +215,25 @@ def __init__( # Terminating self._terminating_task = None - self._start_serve_responses() - self._start_close_session_checker() + async def transition_no_connection() -> None: + if self._state in TerminalStates: + return + self._state = SessionState.NO_CONNECTION + if self._ws: + self._task_manager.create_task(self._ws.close()) + self._ws = None + + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + + await self._begin_close_session_countdown() + + self._start_serve_responses( + transition_no_connection=transition_no_connection, + ) + self._start_close_session_checker( + transition_no_connection=transition_no_connection, + ) self._start_buffered_message_sender() async def ensure_connected(self) -> None: @@ -455,7 +472,9 @@ async def close(self) -> None: # This will get us GC'd, so this should be the last thing. await self._close_session_callback(self) - def _start_buffered_message_sender(self) -> None: + def _start_buffered_message_sender( + self, + ) -> None: def commit(msg: TransportMessage) -> None: pending = self._send_buffer.popleft() if msg.seq != pending.seq: @@ -504,42 +523,29 @@ async def block_until_message_available() -> None: ) ) - def _start_close_session_checker(self) -> None: - def transition_connecting() -> None: - if self._state in TerminalStates: - return - self._state = SessionState.CONNECTING - self._wait_for_connected.clear() - + def _start_close_session_checker( + self, + transition_no_connection: Callable[[], Awaitable[None]], + ) -> None: self._task_manager.create_task( _check_to_close_session( - self._transport_options.close_session_check_interval_ms, - lambda: self._state, - lambda: self._ws, - transition_connecting=transition_connecting, + close_session_check_interval_ms=self._transport_options.close_session_check_interval_ms, + get_state=lambda: self._state, + get_ws=lambda: self._ws, + 
transition_no_connection=transition_no_connection, ) ) - def _start_serve_responses(self) -> None: + def _start_serve_responses( + self, + transition_no_connection: Callable[[], Awaitable[None]], + ) -> None: def transition_connecting() -> None: if self._state in TerminalStates: return self._state = SessionState.CONNECTING self._wait_for_connected.clear() - async def transition_no_connection() -> None: - if self._state in TerminalStates: - return - self._state = SessionState.NO_CONNECTION - if self._ws: - self._task_manager.create_task(self._ws.close()) - self._ws = None - - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) - - await self._begin_close_session_countdown() - def assert_incoming_seq_bookkeeping( msg_from: str, msg_seq: int, @@ -939,7 +945,7 @@ async def _check_to_close_session( close_session_check_interval_ms: float, get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], - transition_connecting: Callable[[], None], + transition_no_connection: Callable[[], Awaitable[None]], ) -> None: while get_state() not in TerminalStates: logger.debug("_check_to_close_session: Checking") @@ -947,7 +953,7 @@ async def _check_to_close_session( if (ws := get_ws()) and ws.protocol.state is CLOSED: logger.info("Websocket is closed, transitioning to connecting") - transition_connecting() + await transition_no_connection() async def _do_ensure_connected[HandshakeMetadata]( From c5633aed09ef3984ac8044311b96d233ef011c82 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:02:54 -0700 Subject: [PATCH 146/193] Renaming _serve to _recv_from_ws --- src/replit_river/v2/session.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 6b98b763..dc175b41 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -228,7 +228,7 @@ async def 
transition_no_connection() -> None: await self._begin_close_session_countdown() - self._start_serve_responses( + self._start_recv_from_ws( transition_no_connection=transition_no_connection, ) self._start_close_session_checker( @@ -536,7 +536,7 @@ def _start_close_session_checker( ) ) - def _start_serve_responses( + def _start_recv_from_ws( self, transition_no_connection: Callable[[], Awaitable[None]], ) -> None: @@ -586,7 +586,7 @@ async def block_until_connected() -> None: logger.debug("block_until_connected released!") self._task_manager.create_task( - _serve( + _recv_from_ws( block_until_connected=block_until_connected, transport_id=self._transport_id, get_state=lambda: self._state, @@ -1116,7 +1116,7 @@ async def websocket_closed_callback() -> None: return None -async def _serve( +async def _recv_from_ws( block_until_connected: Callable[[], Awaitable[None]], transport_id: str, get_state: Callable[[], SessionState], @@ -1131,13 +1131,16 @@ async def _serve( get_stream: Callable[[str], tuple[asyncio.Event, Channel[Any]] | None], send_message: SendMessage[None], ) -> None: - """Serve messages from the websocket.""" + """Serve messages from the websocket. 
+ + + """ reset_session_close_countdown() our_task = asyncio.current_task() idx = 0 try: while our_task and not our_task.cancelling() and not our_task.cancelled(): - logger.debug(f"_serve loop count={idx}") + logger.debug(f"_recv_from_ws loop count={idx}") idx += 1 ws = None while (state := get_state()) in ConnectingStates or ( @@ -1154,9 +1157,9 @@ async def _serve( if state in TerminalStates: logger.debug( - f"Session is {state}, shut down _serve", + f"Session is {state}, shut down _recv_from_ws", ) - # session is closing / closed, no need to serve anymore + # session is closing / closed, no need to _recv_from_ws anymore break # This should not happen, but due to the complex logic around TerminalStates @@ -1278,7 +1281,7 @@ async def _serve( break except ConnectionClosed: # Set ourselves to closed as soon as we get the signal - await transition_no_connection() + transition_connecting() logger.debug("ConnectionClosed while serving", exc_info=True) break except FailedSendingMessageException: @@ -1304,4 +1307,4 @@ async def _serve( exc_info=unhandled, ) raise unhandled - logger.debug(f"_serve exiting normally after {idx} loops") + logger.debug(f"_recv_from_ws exiting normally after {idx} loops") From 72de1e99e260ee47e436b49b21ad0ce20d858d38 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:34:11 -0700 Subject: [PATCH 147/193] # type: ignore is no longer necessary --- src/replit_river/client_transport.py | 2 +- src/replit_river/server_transport.py | 2 +- src/replit_river/session.py | 2 +- tests/conftest.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 04215811..1e8fdcf1 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -261,7 +261,7 @@ async def websocket_closed_callback() -> None: try: await send_transport_message( TransportMessage( - from_=transport_id, # type: ignore + from_=transport_id, 
to=to_id, streamId=stream_id, controlFlags=0, diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 7b15e36e..3f743e51 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -186,7 +186,7 @@ async def _send_handshake_response( response_message = TransportMessage( streamId=request_message.streamId, id=nanoid.generate(), - from_=request_message.to, # type: ignore + from_=request_message.to, to=request_message.from_, seq=0, ack=0, diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 8bb745b7..465a6672 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -236,7 +236,7 @@ async def send_message( msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), - from_=self._transport_id, # type: ignore + from_=self._transport_id, to=self._to_id, seq=self._seq_manager.get_seq_and_increment(), ack=self._seq_manager.get_ack(), diff --git a/tests/conftest.py b/tests/conftest.py index 9ab2f9f1..3866fdd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,7 +35,7 @@ def transport_message( ) -> TransportMessage: return TransportMessage( id=str(nanoid.generate()), - from_=from_, # type: ignore + from_=from_, to=to, streamId=streamId, seq=seq, From 6fb17963aaa874e5da865807168b71e8a3091d14 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:35:23 -0700 Subject: [PATCH 148/193] Establish symmetry between ConnectingStates and TerminalStates for ActiveStates --- src/replit_river/common_session.py | 1 + src/replit_river/v2/session.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 2d670efd..d54cbbd0 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -49,6 +49,7 @@ class SessionState(enum.Enum): ConnectingStates = set([SessionState.NO_CONNECTION, SessionState.CONNECTING]) +ActiveStates = 
set([SessionState.ACTIVE]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index dc175b41..7aa8aa6e 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -31,6 +31,7 @@ from websockets.protocol import CLOSED from replit_river.common_session import ( + ActiveStates, ConnectingStates, SendMessage, SessionState, @@ -332,7 +333,7 @@ def is_closed(self) -> bool: return self._state in TerminalStates def is_connected(self) -> bool: - return self._state == SessionState.ACTIVE + return self._state in ActiveStates async def _begin_close_session_countdown(self) -> None: """Begin the countdown to close session, this should be called when @@ -1173,7 +1174,7 @@ async def _recv_from_ws( ws.id, ) # We should not process messages if the websocket is closed. - while (ws := get_ws()) and get_state() == SessionState.ACTIVE: + while (ws := get_ws()) and get_state() in ActiveStates: # decode=False: Avoiding an unnecessary round-trip through str # Ideally this should be type-ascripted to : bytes, but there # is no @overrides in `websockets` to hint this. From e21e5c4dadc4e36bf2f0d2446747e0f2aab7333d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:36:02 -0700 Subject: [PATCH 149/193] Shoring up state transitions --- src/replit_river/v2/session.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 7aa8aa6e..c5f53d67 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1134,7 +1134,7 @@ async def _recv_from_ws( ) -> None: """Serve messages from the websocket. - + Process incoming packets from the connected websocket. 
""" reset_session_close_countdown() our_task = asyncio.current_task() @@ -1144,17 +1144,15 @@ async def _recv_from_ws( logger.debug(f"_recv_from_ws loop count={idx}") idx += 1 ws = None - while (state := get_state()) in ConnectingStates or ( - ws := get_ws() - ) is None: + while ((state := get_state()) in ConnectingStates) and ( + state not in TerminalStates + ): logger.debug( "_handle_messages_from_ws spinning while connecting, %r %r", ws, state, ) await block_until_connected() - if state in TerminalStates: - break if state in TerminalStates: logger.debug( @@ -1163,16 +1161,8 @@ async def _recv_from_ws( # session is closing / closed, no need to _recv_from_ws anymore break - # This should not happen, but due to the complex logic around TerminalStates - # above, pyright is not convinced we've caught all the states. - if not ws: - continue + logger.debug("client start handling messages from ws %r", ws) - logger.debug( - "%s start handling messages from ws %s", - "client", - ws.id, - ) # We should not process messages if the websocket is closed. while (ws := get_ws()) and get_state() in ActiveStates: # decode=False: Avoiding an unnecessary round-trip through str @@ -1181,8 +1171,10 @@ async def _recv_from_ws( try: message = await ws.recv(decode=False) except ConnectionClosed: + # This triggers a break in the inner loop so we can get back to + # the outer loop. 
transition_connecting() - continue + break try: msg = parse_transport_msg(message) logger.debug( From d4d0c8229d8927243fb1aaee90f16637c1c2117d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:44:29 -0700 Subject: [PATCH 150/193] Avoid renaming parameters --- src/replit_river/v2/client_transport.py | 28 ++++++++++++++++++----- src/replit_river/v2/session.py | 30 +++++++++++-------------- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index d74a2558..3dc96522 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -26,7 +26,7 @@ def __init__( transport_options: TransportOptions, ): self._session = None - self._transport_id = client_id + self._transport_id = nanoid.generate() self._transport_options = transport_options self._uri_and_metadata_factory = uri_and_metadata_factory @@ -40,7 +40,13 @@ async def close(self) -> None: self._rate_limiter.close() if self._session: await self._session.close() - logger.info(f"Transport closed {self._transport_id}") + logger.info( + "Transport closed", + extra={ + "client_id": self._client_id, + "transport_id": self._transport_id, + }, + ) async def get_or_create_session(self) -> Session: """ @@ -51,15 +57,14 @@ async def get_or_create_session(self) -> Session: if not existing_session or existing_session.is_closed(): logger.info("Creating new session") new_session = Session( - transport_id=self._transport_id, - to_id=self._server_id, + client_id=self._client_id, + server_id=self._server_id, session_id=nanoid.generate(), transport_options=self._transport_options, close_session_callback=self._delete_session, retry_connection_callback=self._retry_connection, uri_and_metadata_factory=self._uri_and_metadata_factory, rate_limiter=self._rate_limiter, - client_id=self._client_id, ) self._session = new_session @@ -76,5 +81,16 @@ async def _retry_connection(self) -> Session: 
return await self.get_or_create_session() async def _delete_session(self, session: Session) -> None: - if self._session and session._to_id == self._session._to_id: + if self._session is session: self._session = None + else: + logger.warning( + "Session attempted to close itself but it was not the " + "active session, doing nothing", + extra={ + "client_id": self._client_id, + "transport_id": self._transport_id, + "active_session_id": self._session and self._session.session_id, + "orphan_session_id": session.session_id, + }, + ) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index c5f53d67..2f147791 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -122,7 +122,7 @@ class _IgnoreMessage: class Session[HandshakeMetadata]: _transport_id: str - _to_id: str + _server_id: str session_id: str _transport_options: TransportOptions @@ -163,8 +163,7 @@ class Session[HandshakeMetadata]: def __init__( self, - transport_id: str, - to_id: str, + server_id: str, session_id: str, transport_options: TransportOptions, close_session_callback: CloseSessionCallback, @@ -175,8 +174,7 @@ def __init__( ], retry_connection_callback: RetryConnectionCallback | None = None, ) -> None: - self._transport_id = transport_id - self._to_id = to_id + self._server_id = server_id self.session_id = session_id self._transport_options = transport_options @@ -304,9 +302,8 @@ def finalize_attempt() -> None: if not self._connecting_task: self._connecting_task = asyncio.create_task( _do_ensure_connected( - transport_id=self._transport_id, client_id=self._client_id, - to_id=self._to_id, + server_id=self._server_id, session_id=self.session_id, max_retry=self._transport_options.connection_retry_options.max_retry, rate_limiter=self._rate_limiter, @@ -352,7 +349,7 @@ async def _begin_close_session_countdown(self) -> None: logger.info( "websocket closed from %s to %s begin grace period", self._transport_id, - self._to_id, + self._server_id, ) 
self._state = SessionState.NO_CONNECTION self._close_session_after_time_secs = close_session_after_time_secs @@ -403,7 +400,7 @@ async def _send_message( streamId=stream_id, id=nanoid.generate(), from_=self._transport_id, - to=self._to_id, + to=self._server_id, seq=self.seq, ack=self.ack, controlFlags=control_flags, @@ -432,7 +429,7 @@ async def _send_message( async def close(self) -> None: """Close the session and all associated streams.""" logger.info( - f"{self._transport_id} closing session to {self._to_id}, ws: {self._ws}" + f"{self._transport_id} closing session to {self._server_id}, ws: {self._ws}" ) if self._state in TerminalStates: # already closing @@ -589,7 +586,7 @@ async def block_until_connected() -> None: self._task_manager.create_task( _recv_from_ws( block_until_connected=block_until_connected, - transport_id=self._transport_id, + client_id=self._transport_id, get_state=lambda: self._state, get_ws=lambda: self._ws, transition_connecting=transition_connecting, @@ -958,9 +955,8 @@ async def _check_to_close_session( async def _do_ensure_connected[HandshakeMetadata]( - transport_id: str, client_id: str, - to_id: str, + server_id: str, session_id: str, max_retry: int, rate_limiter: LeakyBucketRateLimit, @@ -1010,8 +1006,8 @@ async def websocket_closed_callback() -> None: await send_transport_message( TransportMessage( - from_=transport_id, - to=to_id, + from_=client_id, + to=server_id, streamId=nanoid.generate(), controlFlags=0, id=nanoid.generate(), @@ -1119,7 +1115,7 @@ async def websocket_closed_callback() -> None: async def _recv_from_ws( block_until_connected: Callable[[], Awaitable[None]], - transport_id: str, + client_id: str, get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], transition_connecting: Callable[[], None], @@ -1179,7 +1175,7 @@ async def _recv_from_ws( msg = parse_transport_msg(message) logger.debug( "[%s] got a message %r", - transport_id, + client_id, msg, ) if isinstance(msg, str): From 
bfc2392ac7e38a527bf15747b90f65ecb9f0976b Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:48:05 -0700 Subject: [PATCH 151/193] Avoid redundant exceptions --- src/replit_river/v2/session.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 2f147791..edf02f58 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -27,7 +27,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import ValidationError from websockets.asyncio.client import ClientConnection -from websockets.exceptions import ConnectionClosed, ConnectionClosedOK +from websockets.exceptions import ConnectionClosed from websockets.protocol import CLOSED from replit_river.common_session import ( @@ -1264,15 +1264,6 @@ async def _recv_from_ws( ) await close_session() continue - except ConnectionClosedOK: - # Exited normally - transition_connecting() - break - except ConnectionClosed: - # Set ourselves to closed as soon as we get the signal - transition_connecting() - logger.debug("ConnectionClosed while serving", exc_info=True) - break except FailedSendingMessageException: # Expected error if the connection is closed. 
await transition_no_connection() From 14ae1e218f478a2f5eeb875a6615555488a49a1f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:53:00 -0700 Subject: [PATCH 152/193] check_to_close_connection is irrelevant --- src/replit_river/v2/session.py | 67 ++++++++-------------------------- 1 file changed, 15 insertions(+), 52 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index edf02f58..18fd67a8 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -28,7 +28,6 @@ from pydantic import ValidationError from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed -from websockets.protocol import CLOSED from replit_river.common_session import ( ActiveStates, @@ -214,25 +213,7 @@ def __init__( # Terminating self._terminating_task = None - async def transition_no_connection() -> None: - if self._state in TerminalStates: - return - self._state = SessionState.NO_CONNECTION - if self._ws: - self._task_manager.create_task(self._ws.close()) - self._ws = None - - if self._retry_connection_callback: - self._task_manager.create_task(self._retry_connection_callback()) - - await self._begin_close_session_countdown() - - self._start_recv_from_ws( - transition_no_connection=transition_no_connection, - ) - self._start_close_session_checker( - transition_no_connection=transition_no_connection, - ) + self._start_recv_from_ws() self._start_buffered_message_sender() async def ensure_connected(self) -> None: @@ -521,29 +502,26 @@ async def block_until_message_available() -> None: ) ) - def _start_close_session_checker( - self, - transition_no_connection: Callable[[], Awaitable[None]], - ) -> None: - self._task_manager.create_task( - _check_to_close_session( - close_session_check_interval_ms=self._transport_options.close_session_check_interval_ms, - get_state=lambda: self._state, - get_ws=lambda: self._ws, - transition_no_connection=transition_no_connection, - ) - 
) - - def _start_recv_from_ws( - self, - transition_no_connection: Callable[[], Awaitable[None]], - ) -> None: + def _start_recv_from_ws(self) -> None: def transition_connecting() -> None: if self._state in TerminalStates: return self._state = SessionState.CONNECTING self._wait_for_connected.clear() + async def transition_no_connection() -> None: + if self._state in TerminalStates: + return + self._state = SessionState.NO_CONNECTION + if self._ws: + self._task_manager.create_task(self._ws.close()) + self._ws = None + + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + + await self._begin_close_session_countdown() + def assert_incoming_seq_bookkeeping( msg_from: str, msg_seq: int, @@ -939,21 +917,6 @@ async def _send_close_stream( ) -async def _check_to_close_session( - close_session_check_interval_ms: float, - get_state: Callable[[], SessionState], - get_ws: Callable[[], ClientConnection | None], - transition_no_connection: Callable[[], Awaitable[None]], -) -> None: - while get_state() not in TerminalStates: - logger.debug("_check_to_close_session: Checking") - await asyncio.sleep(close_session_check_interval_ms / 1000) - - if (ws := get_ws()) and ws.protocol.state is CLOSED: - logger.info("Websocket is closed, transitioning to connecting") - await transition_no_connection() - - async def _do_ensure_connected[HandshakeMetadata]( client_id: str, server_id: str, From 6b12256bf64b4e05ad13da1a04947cc4f35d58a6 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 10:55:45 -0700 Subject: [PATCH 153/193] Add missing "else" for clarity --- src/replit_river/v2/session.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 18fd67a8..637f00f3 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -542,15 +542,13 @@ def assert_incoming_seq_bookkeeping( raise 
OutOfOrderMessageException( f"Out of order message received got {msg_seq} expected {self.ack}" ) + else: + # Set our next expected ack number + self.ack = msg_seq + 1 - assert msg_seq == self.ack, "Safety net, redundant assertion" - - # Set our next expected ack number - self.ack = msg_seq + 1 - - # Discard old server-ack'd messages from the ack buffer - while self._ack_buffer and self._ack_buffer[0].seq < msg_ack: - self._ack_buffer.popleft() + # Discard old server-ack'd messages from the ack buffer + while self._ack_buffer and self._ack_buffer[0].seq < msg_ack: + self._ack_buffer.popleft() return True From db6ff5c0f84ae5a91217a226ed3e04ce05ea4cc4 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 11:22:15 -0700 Subject: [PATCH 154/193] Pushing error metadata into the exception --- src/replit_river/seq_manager.py | 10 ++++++++-- src/replit_river/v2/session.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/replit_river/seq_manager.py b/src/replit_river/seq_manager.py index 8a2f6798..a1baa0f3 100644 --- a/src/replit_river/seq_manager.py +++ b/src/replit_river/seq_manager.py @@ -18,7 +18,12 @@ class OutOfOrderMessageException(Exception): we close the session. 
""" - pass + def __init__(self, *, received_seq: int, expected_ack: int) -> None: + super().__init__( + "Out of order message received: " + f"got={received_seq}, " + f"expected={expected_ack}" + ) class SessionStateMismatchException(Exception): @@ -71,7 +76,8 @@ def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None: ) raise OutOfOrderMessageException( - f"Out of order message received got {msg.seq} expected {self.ack}" + received_seq=msg.seq, + expected_ack=self.ack, ) self.receiver_ack = msg.ack self.ack = msg.seq + 1 diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 637f00f3..41c7b063 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -530,8 +530,12 @@ def assert_incoming_seq_bookkeeping( # Update bookkeeping if msg_seq < self.ack: logger.info( - f"{msg_from} received duplicate msg, got {msg_seq}" - f" expected {self.ack}" + "Received duplicate msg", + extra={ + "from": msg_from, + "got_seq": msg_seq, + "expected_ack": self.ack, + }, ) return _IgnoreMessage() elif msg_seq > self.ack: @@ -540,7 +544,8 @@ def assert_incoming_seq_bookkeeping( ) raise OutOfOrderMessageException( - f"Out of order message received got {msg_seq} expected {self.ack}" + received_seq=msg_seq, + expected_ack=self.ack, ) else: # Set our next expected ack number From 36d9b2b738771683d87920757c95557c2c73d0fc Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 11:43:43 -0700 Subject: [PATCH 155/193] Clarify finalize_attempt() --- src/replit_river/v2/session.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 41c7b063..e904c665 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -263,7 +263,7 @@ def transition_connected(ws: ClientConnection) -> None: def close_ws_in_background(ws: ClientConnection) -> None: self._task_manager.create_task(ws.close()) - def 
finalize_attempt() -> None: + def unbind_connecting_task() -> None: # We are in a state where we may throw an exception. # # To allow subsequent calls to ensure_connected to pass, we clear ourselves. @@ -273,11 +273,7 @@ def finalize_attempt() -> None: # # Let's do our best to avoid clobbering other tasks by comparing the .name current_task = asyncio.current_task() - if ( - self._connecting_task - and current_task - and self._connecting_task is current_task - ): + if self._connecting_task is current_task: self._connecting_task = None if not self._connecting_task: @@ -295,7 +291,7 @@ def finalize_attempt() -> None: transition_connecting=transition_connecting, close_ws_in_background=close_ws_in_background, transition_connected=transition_connected, - finalize_attempt=finalize_attempt, + unbind_connecting_task=unbind_connecting_task, do_close=do_close, ) ) @@ -935,7 +931,7 @@ async def _do_ensure_connected[HandshakeMetadata]( transition_connecting: Callable[[], None], close_ws_in_background: Callable[[ClientConnection], None], transition_connected: Callable[[ClientConnection], None], - finalize_attempt: Callable[[], None], + unbind_connecting_task: Callable[[], None], do_close: Callable[[], None], ) -> None: logger.info("Attempting to establish new ws connection") @@ -1066,7 +1062,7 @@ async def websocket_closed_callback() -> None: f"Error connecting, retrying with {backoff_time}ms backoff" ) await asyncio.sleep(backoff_time / 1000) - finalize_attempt() + unbind_connecting_task() if last_error is not None: logger.debug("Handshake attempts exhausted, terminating") From 79fe8ef659908fa4843c15332e3d00e82746a8f8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 11:45:26 -0700 Subject: [PATCH 156/193] Only one exception can be raised from putting on an aiochannel --- src/replit_river/v2/session.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e904c665..93876c01 100644 --- 
a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1208,8 +1208,6 @@ async def _recv_from_ws( # The client is no longer interested in this stream, # just drop the message. pass - except RuntimeError as e: - raise InvalidMessageException(e) from e if msg.controlFlags & STREAM_CLOSED_BIT != 0: # Communicate that we're going down From 75399d636a5f76f10d0b881f084d6a914e5531e9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 11:46:56 -0700 Subject: [PATCH 157/193] naming --- src/replit_river/v2/session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 93876c01..3b7f356b 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1183,16 +1183,16 @@ async def _recv_from_ws( ) continue - event_stream = get_stream(msg.streamId) + waiter_and_stream = get_stream(msg.streamId) - if not event_stream: + if not waiter_and_stream: logger.warning( "no stream for %s, ignoring message", msg.streamId, ) continue - backpressure_waiter, stream = event_stream + backpressure_waiter, stream = waiter_and_stream if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 From 17c149e6c328c2f54f1fe8503cded76f2c0226b3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 11:52:26 -0700 Subject: [PATCH 158/193] Unused extra_control_flags --- src/replit_river/v2/session.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 3b7f356b..ebe50981 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -714,7 +714,6 @@ async def send_upload[I, R, A]( ) from e await self._send_close_stream( stream_id=stream_id, - extra_control_flags=0, span=span, ) @@ -835,7 +834,6 @@ async def _encode_stream() -> None: if not request: await self._send_close_stream( stream_id=stream_id, - extra_control_flags=STREAM_OPEN_BIT, span=span, ) return 
@@ -856,7 +854,6 @@ async def _encode_stream() -> None: ) await self._send_close_stream( stream_id=stream_id, - extra_control_flags=0, span=span, ) @@ -905,12 +902,11 @@ async def _send_cancel_stream( async def _send_close_stream( self, stream_id: str, - extra_control_flags: int, span: Span, ) -> None: await self._send_message( stream_id=stream_id, - control_flags=STREAM_CLOSED_BIT | extra_control_flags, + control_flags=STREAM_CLOSED_BIT, payload={"type": "CLOSE"}, span=span, ) From 684e7d6f9f625bcbf6a5205780f6432529f80e55 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 12:51:09 -0700 Subject: [PATCH 159/193] Clarify _with_stream semantics --- src/replit_river/v2/session.py | 97 +++++++++++++++++----------------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index ebe50981..e7f4ecc0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -46,7 +46,6 @@ RiverException, RiverServiceException, SessionClosedRiverServiceException, - StreamClosedRiverServiceException, exception_from_message, ) from replit_river.messages import ( @@ -350,9 +349,6 @@ async def _send_message( span: Span | None = None, ) -> None: """Send serialized messages to the websockets.""" - # if the session is not active, we should not do anything - if self._state in TerminalStates: - return logger.debug( "_send_message(stream_id=%r, payload=%r, control_flags=%r, " "service_name=%r, procedure_name=%r)", @@ -582,6 +578,16 @@ async def _with_stream( session_id: str, maxsize: int, ) -> AsyncIterator[tuple[asyncio.Event, Channel[ResultType]]]: + """ + _with_stream + + An async context that exposes a managed stream and an event that permits + producers to respond to backpressure. 
+ + It is expected that the first message emitted ignores this backpressure_waiter, + since the first event does not care about backpressure, but subsequent events + emitted should call await backpressure_waiter.wait() prior to emission. + """ output: Channel[Any] = Channel(maxsize=maxsize) backpressure_waiter = asyncio.Event() self._streams[session_id] = (backpressure_waiter, output) @@ -606,15 +612,16 @@ async def send_rpc[R, A]( Expects the input and output be messages that will be msgpacked. """ stream_id = nanoid.generate() + await self._send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, + payload=request_serializer(request), + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + async with self._with_stream(stream_id, 1) as (backpressure_waiter, output): - await self._send_message( - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, - payload=request_serializer(request), - service_name=service_name, - procedure_name=procedure_name, - span=span, - ) # Handle potential errors during communication try: async with asyncio.timeout(timeout.total_seconds()): @@ -665,19 +672,18 @@ async def send_upload[I, R, A]( Expects the input and output be messages that will be msgpacked. 
""" - stream_id = nanoid.generate() + await self._send_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + service_name=service_name, + procedure_name=procedure_name, + payload=init_serializer(init), + span=span, + ) + async with self._with_stream(stream_id, 1) as (backpressure_waiter, output): try: - await self._send_message( - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - service_name=service_name, - procedure_name=procedure_name, - payload=init_serializer(init), - span=span, - ) - if request: assert request_serializer, "send_stream missing request_serializer" @@ -756,16 +762,16 @@ async def send_subscription[R, E, A]( Expects the input and output be messages that will be msgpacked. """ stream_id = nanoid.generate() - async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output): - await self._send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=request_serializer(request), - span=span, - ) + await self._send_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=request_serializer(request), + span=span, + ) + async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output): # Handle potential errors during communication try: async for item in output: @@ -811,24 +817,19 @@ async def send_stream[I, R, E, A]( """ stream_id = nanoid.generate() - async with self._with_stream( - stream_id, - MAX_MESSAGE_BUFFER_SIZE, - ) as (backpressure_waiter, output): - try: - await self._send_message( - service_name=service_name, - procedure_name=procedure_name, - stream_id=stream_id, - control_flags=STREAM_OPEN_BIT, - payload=init_serializer(init), - span=span, - ) - except Exception as e: - raise StreamClosedRiverServiceException( - ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name - ) from e + await self._send_message( + 
service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + span=span, + ) + async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( + backpressure_waiter, + output, + ): # Create the encoder task async def _encode_stream() -> None: if not request: From 27bb8e6d8cbaa6833596ccb942479370109d1760 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 12:52:01 -0700 Subject: [PATCH 160/193] Taking a stand on request nullability in river v2 --- src/replit_river/codegen/client.py | 29 +++------------------- src/replit_river/v2/client.py | 4 +-- src/replit_river/v2/session.py | 39 ++++++++++++++---------------- 3 files changed, 23 insertions(+), 49 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index d7996500..0fc226d2 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -960,32 +960,9 @@ async def {name}( ] ) elif protocol_version == "v2.0": - assert init_meta, "Protocol v2 requires init to be defined" - _, init_type, render_init_method = init_meta - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - init: {render_type_expr(init_type)}, - ) -> { # TODO(dstewart) This should just be output_type - render_type_expr(output_or_error_type) - }: - return await self.client.send_upload( - {repr(schema_name)}, - {repr(name)}, - init, - None, - {reindent(" ", render_init_method)}, - None, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] + raise ValueError( + "It is expected that protocol v2 uploads have both init and input " + "defined, otherwise it's no different than rpc", ) else: assert_never(protocol_version) diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 4edf299b..09acc476 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -75,9 
+75,9 @@ async def send_upload[I, R, A]( service_name: str, procedure_name: str, init: I, - request: AsyncIterable[R] | None, + request: AsyncIterable[R], init_serializer: Callable[[I], Any], - request_serializer: Callable[[R], Any] | None, + request_serializer: Callable[[R], Any], response_deserializer: Callable[[Any], A], error_deserializer: Callable[[Any], RiverError], ) -> A: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e7f4ecc0..1ba75ba6 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -661,9 +661,9 @@ async def send_upload[I, R, A]( service_name: str, procedure_name: str, init: I, - request: AsyncIterable[R] | None, + request: AsyncIterable[R], init_serializer: Callable[[I], Any], - request_serializer: Callable[[R], Any] | None, + request_serializer: Callable[[R], Any], response_deserializer: Callable[[Any], A], error_deserializer: Callable[[Any], RiverError], span: Span, @@ -684,25 +684,22 @@ async def send_upload[I, R, A]( async with self._with_stream(stream_id, 1) as (backpressure_waiter, output): try: - if request: - assert request_serializer, "send_stream missing request_serializer" - - # If this request is not closed and the session is killed, we should - # throw exception here - async for item in request: - # Block for backpressure - await backpressure_waiter.wait() - if output.closed(): - logger.debug("Stream is closed, avoid sending the rest") - break - await self._send_message( - stream_id=stream_id, - service_name=service_name, - procedure_name=procedure_name, - control_flags=0, - payload=request_serializer(item), - span=span, - ) + # If this request is not closed and the session is killed, we should + # throw exception here + async for item in request: + # Block for backpressure + await backpressure_waiter.wait() + if output.closed(): + logger.debug("Stream is closed, avoid sending the rest") + break + await self._send_message( + stream_id=stream_id, + 
service_name=service_name, + procedure_name=procedure_name, + control_flags=0, + payload=request_serializer(item), + span=span, + ) except WebsocketClosedException as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name From 0c6301f12210af20e057ae6ad69b9c2f818d85a3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 15:02:24 -0700 Subject: [PATCH 161/193] Pretty sure this is handled now --- src/replit_river/v2/session.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 1ba75ba6..4646e2d0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -720,8 +720,6 @@ async def send_upload[I, R, A]( span=span, ) - # Handle potential errors during communication - # TODO: throw a error when the transport is hard closed try: result = await output.get() except ChannelClosed as e: From a8bb9d314db1105fb5bec5a1c2e4e442d02d1c72 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 19:32:24 -0700 Subject: [PATCH 162/193] stream -> output for consistency --- src/replit_river/v2/session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 4646e2d0..b9281b0d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1184,7 +1184,7 @@ async def _recv_from_ws( ) continue - backpressure_waiter, stream = waiter_and_stream + backpressure_waiter, output = waiter_and_stream if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 @@ -1195,7 +1195,7 @@ async def _recv_from_ws( pass else: try: - await stream.put(msg.payload) + await output.put(msg.payload) except ChannelClosed: # The client is no longer interested in this stream, # just drop the message. 
@@ -1203,7 +1203,7 @@ async def _recv_from_ws( if msg.controlFlags & STREAM_CLOSED_BIT != 0: # Communicate that we're going down - stream.close() + output.close() # Wake up backpressured writer backpressure_waiter.set() except OutOfOrderMessageException: From 9ab4d190b34aa41a334afd62779ece17336832ed Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 19:36:38 -0700 Subject: [PATCH 163/193] Clarify enqueue vs send semantics --- src/replit_river/v2/session.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index b9281b0d..08519246 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -339,7 +339,7 @@ def _reset_session_close_countdown(self) -> None: self._heartbeat_misses = 0 self._close_session_after_time_secs = None - async def _send_message( + async def _enqueue_message( self, stream_id: str, payload: dict[Any, Any] | str, @@ -350,7 +350,7 @@ async def _send_message( ) -> None: """Send serialized messages to the websockets.""" logger.debug( - "_send_message(stream_id=%r, payload=%r, control_flags=%r, " + "_enqueue_message(stream_id=%r, payload=%r, control_flags=%r, " "service_name=%r, procedure_name=%r)", stream_id, payload, @@ -568,7 +568,7 @@ async def block_until_connected() -> None: close_session=self.close, assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: self._streams.get(stream_id), - send_message=self._send_message, + enqueue_message=self._enqueue_message, ) ) @@ -612,7 +612,7 @@ async def send_rpc[R, A]( Expects the input and output be messages that will be msgpacked. 
""" stream_id = nanoid.generate() - await self._send_message( + await self._enqueue_message( stream_id=stream_id, control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, payload=request_serializer(request), @@ -673,7 +673,7 @@ async def send_upload[I, R, A]( Expects the input and output be messages that will be msgpacked. """ stream_id = nanoid.generate() - await self._send_message( + await self._enqueue_message( stream_id=stream_id, control_flags=STREAM_OPEN_BIT, service_name=service_name, @@ -692,7 +692,7 @@ async def send_upload[I, R, A]( if output.closed(): logger.debug("Stream is closed, avoid sending the rest") break - await self._send_message( + await self._enqueue_message( stream_id=stream_id, service_name=service_name, procedure_name=procedure_name, @@ -757,7 +757,7 @@ async def send_subscription[R, E, A]( Expects the input and output be messages that will be msgpacked. """ stream_id = nanoid.generate() - await self._send_message( + await self._enqueue_message( service_name=service_name, procedure_name=procedure_name, stream_id=stream_id, @@ -812,7 +812,7 @@ async def send_stream[I, R, E, A]( """ stream_id = nanoid.generate() - await self._send_message( + await self._enqueue_message( service_name=service_name, procedure_name=procedure_name, stream_id=stream_id, @@ -843,7 +843,7 @@ async def _encode_stream() -> None: if output.closed(): logger.debug("Stream is closed, avoid sending the rest") break - await self._send_message( + await self._enqueue_message( stream_id=stream_id, control_flags=0, payload=request_serializer(item), @@ -888,7 +888,7 @@ async def _send_cancel_stream( extra_control_flags: int, span: Span, ) -> None: - await self._send_message( + await self._enqueue_message( stream_id=stream_id, control_flags=STREAM_CANCEL_BIT | extra_control_flags, payload={"type": "CANCEL"}, @@ -900,7 +900,7 @@ async def _send_close_stream( stream_id: str, span: Span, ) -> None: - await self._send_message( + await self._enqueue_message( stream_id=stream_id, 
control_flags=STREAM_CLOSED_BIT, payload={"type": "CLOSE"}, @@ -1080,7 +1080,7 @@ async def _recv_from_ws( [str, int, int], Literal[True] | _IgnoreMessage ], get_stream: Callable[[str], tuple[asyncio.Event, Channel[Any]] | None], - send_message: SendMessage[None], + enqueue_message: SendMessage[None], ) -> None: """Serve messages from the websocket. @@ -1161,7 +1161,7 @@ async def _recv_from_ws( # Shortcut to avoid processing ack packets if msg.controlFlags & ACK_BIT != 0: - await send_message( + await enqueue_message( stream_id="heartbeat", # TODO: make this a message class # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 From bb4b652ff4d15b134a7b9741164d36ae60b86f3c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 19:48:34 -0700 Subject: [PATCH 164/193] Avoid breaking in buffered_message_sender --- src/replit_river/common_session.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index d54cbbd0..b47bae82 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -64,20 +64,20 @@ async def buffered_message_sender( ) -> None: our_task = asyncio.current_task() while our_task and not our_task.cancelling() and not our_task.cancelled(): - await block_until_message_available() - - if get_state() in TerminalStates: - logger.debug("_buffered_message_sender: closing") - return - - while (ws := get_ws()) is None: + while get_state() in ConnectingStates: # Block until we have a handle logger.debug( "_buffered_message_sender: Waiting until ws is connected", ) await block_until_connected() - if not ws: + if get_state() in TerminalStates: + logger.debug("_buffered_message_sender: closing") + return + + await block_until_message_available() + + if not (ws := get_ws()): logger.debug("_buffered_message_sender: ws is not connected, loop") continue @@ -97,13 +97,10 @@ 
async def buffered_message_sender( type(e), exc_info=e, ) - break except FailedSendingMessageException: logger.error( "Failed sending message, waiting for retry from buffer", exc_info=True, ) - break except Exception: logger.exception("Error attempting to send buffered messages") - break From bf2c7e6fe89e36b1500d12ed3d29f955833ce573 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 20:55:32 -0700 Subject: [PATCH 165/193] Permit errors from buffered_message_sender to bubble back to the calling thread --- src/replit_river/common_session.py | 4 +- src/replit_river/v2/session.py | 175 +++++++++++++++-------------- 2 files changed, 95 insertions(+), 84 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index b47bae82..c401d269 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -59,7 +59,7 @@ async def buffered_message_sender( get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None], websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], get_next_pending: Callable[[], TransportMessage | None], - commit: Callable[[TransportMessage], None], + commit: Callable[[TransportMessage], Awaitable[None]], get_state: Callable[[], SessionState], ) -> None: our_task = asyncio.current_task() @@ -89,7 +89,7 @@ async def buffered_message_sender( ) try: await send_transport_message(msg, ws, websocket_closed_callback) - commit(msg) + await commit(msg) except WebsocketClosedException as e: logger.debug( "_buffered_message_sender: Connection closed while sending " diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 08519246..455b4e91 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -147,7 +147,7 @@ class Session[HandshakeMetadata]: _space_available: asyncio.Event # stream for tasks - _streams: dict[str, tuple[asyncio.Event, Channel[Any]]] + _streams: dict[str, tuple[Channel[Exception | None], 
Channel[Any]]] # book keeping _ack_buffer: deque[TransportMessage] @@ -200,7 +200,7 @@ def __init__( self._space_available.set() # stream for tasks - self._streams: dict[str, tuple[asyncio.Event, Channel[Any]]] = {} + self._streams: dict[str, tuple[Channel[Exception | None], Channel[Any]]] = {} # book keeping self._ack_buffer = deque() @@ -424,10 +424,14 @@ async def close(self) -> None: # TODO: unexpected_close should close stream differently here to # throw exception correctly. - for backpressure_waiter, stream in self._streams.values(): + for error_channel, stream in self._streams.values(): stream.close() # Wake up backpressured writers - backpressure_waiter.set() + await error_channel.put( + SessionClosedRiverServiceException( + "river session is closed", + ) + ) # Before we GC the streams, let's wait for all tasks to be closed gracefully. await asyncio.gather(*[stream.join() for _, stream in self._streams.values()]) self._streams.clear() @@ -446,7 +450,7 @@ async def close(self) -> None: def _start_buffered_message_sender( self, ) -> None: - def commit(msg: TransportMessage) -> None: + async def commit(msg: TransportMessage) -> None: pending = self._send_buffer.popleft() if msg.seq != pending.seq: logger.error("Out of sequence error") @@ -462,7 +466,7 @@ def commit(msg: TransportMessage) -> None: # Wake up backpressured writer stream_meta = self._streams.get(pending.streamId) if stream_meta: - stream_meta[0].set() + await stream_meta[0].put(None) def get_next_pending() -> TransportMessage | None: if self._send_buffer: @@ -577,22 +581,22 @@ async def _with_stream( self, session_id: str, maxsize: int, - ) -> AsyncIterator[tuple[asyncio.Event, Channel[ResultType]]]: + ) -> AsyncIterator[tuple[Channel[Exception | None], Channel[ResultType]]]: """ _with_stream An async context that exposes a managed stream and an event that permits producers to respond to backpressure. 
- It is expected that the first message emitted ignores this backpressure_waiter, + It is expected that the first message emitted ignores this error_channel, since the first event does not care about backpressure, but subsequent events - emitted should call await backpressure_waiter.wait() prior to emission. + emitted should call await error_channel.wait() prior to emission. """ output: Channel[Any] = Channel(maxsize=maxsize) - backpressure_waiter = asyncio.Event() - self._streams[session_id] = (backpressure_waiter, output) + error_channel: Channel[Exception | None] = Channel(maxsize=1) + self._streams[session_id] = (error_channel, output) try: - yield (backpressure_waiter, output) + yield (error_channel, output) finally: del self._streams[session_id] @@ -621,18 +625,17 @@ async def send_rpc[R, A]( span=span, ) - async with self._with_stream(stream_id, 1) as (backpressure_waiter, output): + async with self._with_stream(stream_id, 1) as (error_channel, output): # Handle potential errors during communication try: async with asyncio.timeout(timeout.total_seconds()): - # Block for event for symmetry with backpressured producers - # Here this should be trivially true. 
- await backpressure_waiter.wait() + # Block for backpressure and emission errors from the ws + if err := await error_channel.get(): + raise err result = await output.get() except asyncio.TimeoutError as e: await self._send_cancel_stream( stream_id=stream_id, - extra_control_flags=0, span=span, ) raise RiverException(ERROR_CODE_CANCEL, str(e)) from e @@ -643,7 +646,7 @@ async def send_rpc[R, A]( service_name, procedure_name, ) from e - except RuntimeError as e: + except Exception as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e if "ok" not in result or not result["ok"]: try: @@ -682,34 +685,41 @@ async def send_upload[I, R, A]( span=span, ) - async with self._with_stream(stream_id, 1) as (backpressure_waiter, output): + async with self._with_stream(stream_id, 1) as (error_channel, output): try: # If this request is not closed and the session is killed, we should # throw exception here async for item in request: - # Block for backpressure - await backpressure_waiter.wait() - if output.closed(): - logger.debug("Stream is closed, avoid sending the rest") - break + # Block for backpressure and emission errors from the ws + if err := await error_channel.get(): + raise err + + try: + payload = request_serializer(item) + except Exception as e: + await self._send_cancel_stream( + stream_id=stream_id, + span=span, + ) + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + str(e), + service_name, + procedure_name, + ) from e await self._enqueue_message( stream_id=stream_id, service_name=service_name, procedure_name=procedure_name, control_flags=0, - payload=request_serializer(item), + payload=payload, span=span, ) - except WebsocketClosedException as e: - raise RiverServiceException( - ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name - ) from e except Exception as e: # If we get any exception other than WebsocketClosedException, # cancel the stream. 
await self._send_cancel_stream( stream_id=stream_id, - extra_control_flags=0, span=span, ) raise RiverServiceException( @@ -729,8 +739,10 @@ async def send_upload[I, R, A]( service_name, procedure_name, ) from e - except RuntimeError as e: + except Exception as e: + await self._send_cancel_stream(stream_id, span) raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if "ok" not in result or not result["ok"]: try: error = error_deserializer(result["payload"]) @@ -742,12 +754,12 @@ async def send_upload[I, R, A]( return response_deserializer(result["payload"]) - async def send_subscription[R, E, A]( + async def send_subscription[I, E, A]( self, service_name: str, procedure_name: str, - request: R, - request_serializer: Callable[[R], Any], + init: I, + init_serializer: Callable[[I], Any], response_deserializer: Callable[[Any], A], error_deserializer: Callable[[Any], E], span: Span, @@ -762,37 +774,26 @@ async def send_subscription[R, E, A]( procedure_name=procedure_name, stream_id=stream_id, control_flags=STREAM_OPEN_BIT, - payload=request_serializer(request), + payload=init_serializer(init), span=span, ) async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output): - # Handle potential errors during communication try: async for item in output: if item.get("type") == "CLOSE": break if not item.get("ok", False): - try: - yield error_deserializer(item["payload"]) - except Exception: - logger.exception( - "Error during subscription " - f"error deserialization: {item}" - ) - continue + yield error_deserializer(item["payload"]) yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: + except Exception as e: + await self._send_cancel_stream(stream_id, span) raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", service_name, procedure_name, ) from e - except Exception as e: - raise e - finally: - output.close() async def send_stream[I, R, E, A]( self, @@ -822,7 +823,7 @@ async 
def send_stream[I, R, E, A]( ) async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( - backpressure_waiter, + error_channel, output, ): # Create the encoder task @@ -837,12 +838,13 @@ async def _encode_stream() -> None: assert request_serializer, "send_stream missing request_serializer" async for item in request: - if item is None: - continue - await backpressure_waiter.wait() - if output.closed(): - logger.debug("Stream is closed, avoid sending the rest") - break + # Block for backpressure and emission errors from the ws + if err := await error_channel.get(): + await self._send_close_stream( + stream_id=stream_id, + span=span, + ) + raise err await self._enqueue_message( stream_id=stream_id, control_flags=0, @@ -853,44 +855,38 @@ async def _encode_stream() -> None: span=span, ) - self._task_manager.create_task(_encode_stream()) + emitter_task = self._task_manager.create_task(_encode_stream()) # Handle potential errors during communication try: async for result in output: + # Raise as early as we possibly can in case of an emission error + if err := emitter_task.exception(): + raise err if result.get("type") == "CLOSE": break if "ok" not in result or not result["ok"]: - try: - yield error_deserializer(result["payload"]) - except Exception: - logger.exception( - f"Error during stream error deserialization: {result}" - ) - continue + yield error_deserializer(result["payload"]) yield response_deserializer(result["payload"]) - except (RuntimeError, ChannelClosed) as e: + # ... block the outer function until the emitter is finished emitting. 
+ await emitter_task + except Exception as e: + await self._send_cancel_stream(stream_id, span) raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", service_name, procedure_name, ) from e - except Exception as e: - raise e - finally: - output.close() - backpressure_waiter.set() async def _send_cancel_stream( self, stream_id: str, - extra_control_flags: int, span: Span, ) -> None: await self._enqueue_message( stream_id=stream_id, - control_flags=STREAM_CANCEL_BIT | extra_control_flags, + control_flags=STREAM_CANCEL_BIT, payload={"type": "CANCEL"}, span=span, ) @@ -1079,7 +1075,7 @@ async def _recv_from_ws( assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage ], - get_stream: Callable[[str], tuple[asyncio.Event, Channel[Any]] | None], + get_stream: Callable[[str], tuple[Channel[Exception | None], Channel[Any]] | None], enqueue_message: SendMessage[None], ) -> None: """Serve messages from the websocket. @@ -1088,11 +1084,11 @@ async def _recv_from_ws( """ reset_session_close_countdown() our_task = asyncio.current_task() - idx = 0 + connection_attempts = 0 try: while our_task and not our_task.cancelling() and not our_task.cancelled(): - logger.debug(f"_recv_from_ws loop count={idx}") - idx += 1 + logger.debug(f"_recv_from_ws loop count={connection_attempts}") + connection_attempts += 1 ws = None while ((state := get_state()) in ConnectingStates) and ( state not in TerminalStates @@ -1113,8 +1109,11 @@ async def _recv_from_ws( logger.debug("client start handling messages from ws %r", ws) + error_channel: Channel[Exception | None] | None = None + # We should not process messages if the websocket is closed. while (ws := get_ws()) and get_state() in ActiveStates: + connection_attempts = 0 # decode=False: Avoiding an unnecessary round-trip through str # Ideally this should be type-ascripted to : bytes, but there # is no @overrides in `websockets` to hint this. 
@@ -1175,16 +1174,16 @@ async def _recv_from_ws( ) continue - waiter_and_stream = get_stream(msg.streamId) + errors_and_stream = get_stream(msg.streamId) - if not waiter_and_stream: + if not errors_and_stream: logger.warning( "no stream for %s, ignoring message", msg.streamId, ) continue - backpressure_waiter, output = waiter_and_stream + error_channel, output = errors_and_stream if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 @@ -1203,18 +1202,30 @@ async def _recv_from_ws( if msg.controlFlags & STREAM_CLOSED_BIT != 0: # Communicate that we're going down + # + # This implements the receive side of the half-closed strategy. output.close() - # Wake up backpressured writer - backpressure_waiter.set() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await close_session() + if error_channel: + await error_channel.put( + SessionClosedRiverServiceException( + "Out of order message, closing connection" + ) + ) continue except InvalidMessageException: logger.exception( "Got invalid transport message, closing session", ) await close_session() + if error_channel: + await error_channel.put( + SessionClosedRiverServiceException( + "Out of order message, closing connection" + ) + ) continue except FailedSendingMessageException: # Expected error if the connection is closed. 
@@ -1239,4 +1250,4 @@ async def _recv_from_ws( exc_info=unhandled, ) raise unhandled - logger.debug(f"_recv_from_ws exiting normally after {idx} loops") + logger.debug(f"_recv_from_ws exiting normally after {connection_attempts} loops") From 2a8a32fedac14ba4fcb254a02412263210e7a4a7 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 21:13:55 -0700 Subject: [PATCH 166/193] clarify client/transport/session_id parameters --- src/replit_river/v2/session.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 455b4e91..c766c17e 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -119,7 +119,6 @@ class _IgnoreMessage: class Session[HandshakeMetadata]: - _transport_id: str _server_id: str session_id: str _transport_options: TransportOptions @@ -184,6 +183,13 @@ def __init__( self._wait_for_connected = asyncio.Event() self._client_id = client_id + # TODO: LeakyBucketRateLimit accepts "user" for all methods, which has + # historically been and continues to be "client_id". + # + # There's 1:1 client <-> transport, which means LeakyBucketRateLimit is only + # tracking exactly one rate limit. + # + # The "user" parameter is YAGNI, dethread client_id after v1 is deleted. 
self._rate_limiter = rate_limiter self._uri_and_metadata_factory = uri_and_metadata_factory @@ -324,7 +330,7 @@ async def _begin_close_session_countdown(self) -> None: return logger.info( "websocket closed from %s to %s begin grace period", - self._transport_id, + self.session_id, self._server_id, ) self._state = SessionState.NO_CONNECTION @@ -372,7 +378,7 @@ async def _enqueue_message( msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), - from_=self._transport_id, + from_=self._client_id, to=self._server_id, seq=self.seq, ack=self.ack, @@ -402,7 +408,7 @@ async def _enqueue_message( async def close(self) -> None: """Close the session and all associated streams.""" logger.info( - f"{self._transport_id} closing session to {self._server_id}, ws: {self._ws}" + f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" ) if self._state in TerminalStates: # already closing @@ -563,7 +569,7 @@ async def block_until_connected() -> None: self._task_manager.create_task( _recv_from_ws( block_until_connected=block_until_connected, - client_id=self._transport_id, + client_id=self._client_id, get_state=lambda: self._state, get_ws=lambda: self._ws, transition_connecting=transition_connecting, @@ -906,8 +912,8 @@ async def _send_close_stream( async def _do_ensure_connected[HandshakeMetadata]( client_id: str, - server_id: str, session_id: str, + server_id: str, max_retry: int, rate_limiter: LeakyBucketRateLimit, uri_and_metadata_factory: Callable[ From 2e43b74dc4e67588558c509fbec899bcb171d700 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 22:06:41 -0700 Subject: [PATCH 167/193] Send CANCEL correctly --- src/replit_river/v2/session.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index c766c17e..7055ba8a 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -642,6 +642,7 @@ async def 
send_rpc[R, A]( except asyncio.TimeoutError as e: await self._send_cancel_stream( stream_id=stream_id, + message="Timeout, abandoning request", span=span, ) raise RiverException(ERROR_CODE_CANCEL, str(e)) from e @@ -705,6 +706,7 @@ async def send_upload[I, R, A]( except Exception as e: await self._send_cancel_stream( stream_id=stream_id, + message="Request serialization error", span=span, ) raise RiverServiceException( @@ -726,6 +728,7 @@ async def send_upload[I, R, A]( # cancel the stream. await self._send_cancel_stream( stream_id=stream_id, + message="Unspecified error", span=span, ) raise RiverServiceException( @@ -746,7 +749,11 @@ async def send_upload[I, R, A]( procedure_name, ) from e except Exception as e: - await self._send_cancel_stream(stream_id, span) + await self._send_cancel_stream( + stream_id=stream_id, + message="Unspecified error", + span=span, + ) raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e if "ok" not in result or not result["ok"]: @@ -793,7 +800,11 @@ async def send_subscription[I, E, A]( yield error_deserializer(item["payload"]) yield response_deserializer(item["payload"]) except Exception as e: - await self._send_cancel_stream(stream_id, span) + await self._send_cancel_stream( + stream_id=stream_id, + message="Unspecified error", + span=span, + ) raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", @@ -877,7 +888,11 @@ async def _encode_stream() -> None: # ... block the outer function until the emitter is finished emitting. 
await emitter_task except Exception as e: - await self._send_cancel_stream(stream_id, span) + await self._send_cancel_stream( + stream_id=stream_id, + message="Unspecified error", + span=span, + ) raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", @@ -888,12 +903,19 @@ async def _encode_stream() -> None: async def _send_cancel_stream( self, stream_id: str, + message: str, span: Span, ) -> None: await self._enqueue_message( stream_id=stream_id, control_flags=STREAM_CANCEL_BIT, - payload={"type": "CANCEL"}, + payload={ + "ok": False, + "payload": { + "code": "CANCEL", + "message": message, + }, + }, span=span, ) From 919a5faea8cd695e7796c9820eaf1315145043ca Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 22:19:20 -0700 Subject: [PATCH 168/193] Dunno --- src/replit_river/client_session.py | 17 ++++++----------- src/replit_river/v2/session.py | 2 +- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index c37768b7..ef535ee3 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -173,7 +173,7 @@ async def _handle_messages_from_ws(self) -> None: # The client is no longer interested in this stream, # just drop the message. 
pass - except RuntimeError as e: + except Exception as e: raise InvalidMessageException(e) from e else: raise InvalidMessageException( @@ -244,7 +244,7 @@ async def send_rpc( service_name, procedure_name, ) from e - except RuntimeError as e: + except Exception as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e if not response.get("ok", False): try: @@ -330,7 +330,7 @@ async def send_upload( service_name, procedure_name, ) from e - except RuntimeError as e: + except Exception as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e if not response.get("ok", False): try: @@ -387,15 +387,13 @@ async def send_subscription( ) continue yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: + except Exception as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", service_name, procedure_name, ) from e - except Exception as e: - raise e finally: output.close() @@ -490,17 +488,14 @@ async def _encode_stream() -> None: ) continue yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: + except Exception as e: + logger.exception("There was a problem") raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", service_name, procedure_name, ) from e - except Exception as e: - raise e - finally: - output.close() async def send_close_stream( self, diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 7055ba8a..cde44ac1 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -878,7 +878,7 @@ async def _encode_stream() -> None: try: async for result in output: # Raise as early as we possibly can in case of an emission error - if err := emitter_task.exception(): + if err := emitter_task.done() and emitter_task.exception(): raise err if result.get("type") == "CLOSE": break From f73237d04c044eb610573782464388bd036df2a2 Mon Sep 17 00:00:00 2001 From: Devon Stewart 
Date: Thu, 3 Apr 2025 22:53:07 -0700 Subject: [PATCH 169/193] Forgot the last part, which is reconnect immediately instead of waiting. --- src/replit_river/v2/session.py | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index cde44ac1..6a31965d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -314,29 +314,6 @@ def is_closed(self) -> bool: def is_connected(self) -> bool: return self._state in ActiveStates - async def _begin_close_session_countdown(self) -> None: - """Begin the countdown to close session, this should be called when - websocket is closed. - """ - # calculate the value now before establishing it so that there are no - # await points between the check and the assignment to avoid a TOCTOU - # race. - grace_period_ms = self._transport_options.session_disconnect_grace_ms - close_session_after_time_secs = ( - await self._get_current_time() + grace_period_ms / 1000 - ) - if self._close_session_after_time_secs is not None: - # already in grace period, no need to set again - return - logger.info( - "websocket closed from %s to %s begin grace period", - self.session_id, - self._server_id, - ) - self._state = SessionState.NO_CONNECTION - self._close_session_after_time_secs = close_session_after_time_secs - self._wait_for_connected.clear() - async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() @@ -497,7 +474,7 @@ async def block_until_message_available() -> None: block_until_connected=block_until_connected, block_until_message_available=block_until_message_available, get_ws=get_ws, - websocket_closed_callback=self._begin_close_session_countdown, + websocket_closed_callback=self.ensure_connected, get_next_pending=get_next_pending, commit=commit, get_state=lambda: self._state, @@ -521,8 +498,8 @@ async def transition_no_connection() -> None: if self._retry_connection_callback: 
self._task_manager.create_task(self._retry_connection_callback()) - - await self._begin_close_session_countdown() + else: + await self.ensure_connected() def assert_incoming_seq_bookkeeping( msg_from: str, From 1af291514661c7e1603ed593b8e565eced6a332e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 23:07:08 -0700 Subject: [PATCH 170/193] Whoops. --- src/replit_river/v2/session.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 6a31965d..9428fa00 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -482,12 +482,6 @@ async def block_until_message_available() -> None: ) def _start_recv_from_ws(self) -> None: - def transition_connecting() -> None: - if self._state in TerminalStates: - return - self._state = SessionState.CONNECTING - self._wait_for_connected.clear() - async def transition_no_connection() -> None: if self._state in TerminalStates: return @@ -549,7 +543,6 @@ async def block_until_connected() -> None: client_id=self._client_id, get_state=lambda: self._state, get_ws=lambda: self._ws, - transition_connecting=transition_connecting, transition_no_connection=transition_no_connection, reset_session_close_countdown=self._reset_session_close_countdown, close_session=self.close, @@ -1073,7 +1066,6 @@ async def _recv_from_ws( client_id: str, get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], - transition_connecting: Callable[[], None], transition_no_connection: Callable[[], Awaitable[None]], reset_session_close_countdown: Callable[[], None], close_session: Callable[[], Awaitable[None]], @@ -1127,7 +1119,7 @@ async def _recv_from_ws( except ConnectionClosed: # This triggers a break in the inner loop so we can get back to # the outer loop. 
- transition_connecting() + await transition_no_connection() break try: msg = parse_transport_msg(message) From 2813fe62a2b95402d293aa344aa049f75f46de26 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 23:22:02 -0700 Subject: [PATCH 171/193] Missing wait_for_connected.clear() --- src/replit_river/v2/session.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 9428fa00..5c0aa67d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -462,6 +462,8 @@ def get_ws() -> ClientConnection | None: return None async def block_until_connected() -> None: + if self._state in TerminalStates: + return logger.debug("block_until_connected") await self._wait_for_connected.wait() logger.debug("block_until_connected released!") @@ -486,6 +488,7 @@ async def transition_no_connection() -> None: if self._state in TerminalStates: return self._state = SessionState.NO_CONNECTION + self._wait_for_connected.clear() if self._ws: self._task_manager.create_task(self._ws.close()) self._ws = None From 014248d1e30943336b3cdd13bab78aaf382bee75 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 3 Apr 2025 23:37:50 -0700 Subject: [PATCH 172/193] Turns out the new hotness is not available in 12.0 --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 639f2399..2f049a3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "nanoid>=2.0.0", "protobuf>=5.28.3", "pydantic-core>=2.20.1", - "websockets>=12.0", + "websockets>=13.0,<14", "opentelemetry-sdk>=1.28.2", "opentelemetry-api>=1.28.2", ] diff --git a/uv.lock b/uv.lock index fcd7277e..4833a89f 100644 --- a/uv.lock +++ b/uv.lock @@ -627,7 +627,7 @@ requires-dist = [ { name = "protobuf", specifier = ">=5.28.3" }, { name = "pydantic", specifier = "==2.9.2" }, { name = "pydantic-core", specifier = ">=2.20.1" }, - { name 
= "websockets", specifier = ">=12.0" }, + { name = "websockets", specifier = ">=13.0,<14" }, ] [package.metadata.requires-dev] From b872d21caccc32d2a7f8008c0c75f010c5630f80 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 17:10:43 -0700 Subject: [PATCH 173/193] Deleting dead code --- src/replit_river/v2/session.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 5c0aa67d..2ebb7403 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -138,7 +138,6 @@ class Session[HandshakeMetadata]: # ws state _ws: ClientConnection | None - _heartbeat_misses: int _retry_connection_callback: RetryConnectionCallback | None # message state @@ -195,7 +194,6 @@ def __init__( # ws state self._ws = None - self._heartbeat_misses = 0 self._retry_connection_callback = retry_connection_callback # message state @@ -317,11 +315,6 @@ def is_connected(self) -> bool: async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() - def _reset_session_close_countdown(self) -> None: - logger.debug("_reset_session_close_countdown") - self._heartbeat_misses = 0 - self._close_session_after_time_secs = None - async def _enqueue_message( self, stream_id: str, @@ -547,7 +540,6 @@ async def block_until_connected() -> None: get_state=lambda: self._state, get_ws=lambda: self._ws, transition_no_connection=transition_no_connection, - reset_session_close_countdown=self._reset_session_close_countdown, close_session=self.close, assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: self._streams.get(stream_id), @@ -1070,7 +1062,6 @@ async def _recv_from_ws( get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], transition_no_connection: Callable[[], Awaitable[None]], - reset_session_close_countdown: Callable[[], None], close_session: Callable[[], Awaitable[None]], 
assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage @@ -1082,7 +1073,6 @@ async def _recv_from_ws( Process incoming packets from the connected websocket. """ - reset_session_close_countdown() our_task = asyncio.current_task() connection_attempts = 0 try: @@ -1156,8 +1146,6 @@ async def _recv_from_ws( case other: assert_never(other) - reset_session_close_countdown() - # Shortcut to avoid processing ack packets if msg.controlFlags & ACK_BIT != 0: await enqueue_message( From 8b7d2a777107e893bdcaf3a6f3fd33f88d2c1038 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 17:13:18 -0700 Subject: [PATCH 174/193] Delete dead parameters --- src/replit_river/v2/session.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 2ebb7403..89e2b603 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -285,7 +285,6 @@ def unbind_connecting_task() -> None: client_id=self._client_id, server_id=self._server_id, session_id=self.session_id, - max_retry=self._transport_options.connection_retry_options.max_retry, rate_limiter=self._rate_limiter, uri_and_metadata_factory=self._uri_and_metadata_factory, get_next_sent_seq=get_next_sent_seq, @@ -901,7 +900,6 @@ async def _do_ensure_connected[HandshakeMetadata]( client_id: str, session_id: str, server_id: str, - max_retry: int, rate_limiter: LeakyBucketRateLimit, uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] @@ -918,11 +916,11 @@ async def _do_ensure_connected[HandshakeMetadata]( logger.info("Attempting to establish new ws connection") last_error: Exception | None = None - i = 0 + attempt_count = 0 while rate_limiter.has_budget(client_id): - if i > 0: - logger.info(f"Retrying build handshake number {i} times") - i += 1 + if attempt_count > 0: + logger.info(f"Retrying build handshake number {attempt_count} times") + attempt_count 
+= 1 rate_limiter.consume_budget(client_id) transition_connecting() @@ -1050,7 +1048,7 @@ async def websocket_closed_callback() -> None: do_close() raise RiverException( ERROR_HANDSHAKE, - f"Failed to create ws after retrying {max_retry} number of times", + f"Failed to create ws after retrying {attempt_count} number of times", ) from last_error return None From e3daed355dbc8331e309b3445204541828c31b0c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 17:34:35 -0700 Subject: [PATCH 175/193] Didn't actually use watchdog --- tests/v2/fixtures.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/v2/fixtures.py b/tests/v2/fixtures.py index eae41605..dda934fc 100644 --- a/tests/v2/fixtures.py +++ b/tests/v2/fixtures.py @@ -166,16 +166,6 @@ async def driver() -> None: driver_task = asyncio.create_task(driver()) - # Watchdog keeps track of the above tasks - # async def watchdog() -> None: - # while True: - # print(repr(server_task)) - # print(repr(processor_task)) - # print(repr(driver_task)) - # await asyncio.sleep(1) - # - # watchdog_task = asyncio.create_task(watchdog()) - async def uri_and_metadata_factory() -> UriAndMetadata[None]: return uri_and_metadata From 58affb0898c5ceea48707970c797d3caa75cdd1e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 19:50:15 -0700 Subject: [PATCH 176/193] Dangling unnecessary #ignore --- tests/v1/test_message_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/test_message_buffer.py b/tests/v1/test_message_buffer.py index 02a21ccb..d5d1bda4 100644 --- a/tests/v1/test_message_buffer.py +++ b/tests/v1/test_message_buffer.py @@ -11,7 +11,7 @@ def mock_transport_message(seq: int) -> TransportMessage: seq=seq, id="test", ack=0, - from_="test", # type: ignore + from_="test", to="test", streamId="test", controlFlags=0, From 3758ce9bc5c5e2cf9477495663047de18d30001e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 19:59:38 -0700 Subject: 
[PATCH 177/193] session_id -> stream_id, as well as close error_channel --- src/replit_river/v2/session.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 89e2b603..91849f07 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -549,7 +549,7 @@ async def block_until_connected() -> None: @asynccontextmanager async def _with_stream( self, - session_id: str, + stream_id: str, maxsize: int, ) -> AsyncIterator[tuple[Channel[Exception | None], Channel[ResultType]]]: """ @@ -564,11 +564,20 @@ async def _with_stream( """ output: Channel[Any] = Channel(maxsize=maxsize) error_channel: Channel[Exception | None] = Channel(maxsize=1) - self._streams[session_id] = (error_channel, output) + self._streams[stream_id] = (error_channel, output) try: yield (error_channel, output) finally: - del self._streams[session_id] + stream_meta = self._streams.get(stream_id) + if not stream_meta: + logger.warning("_with_stream had an entry deleted out from under it", extra={ + "session_id": self.session_id, + "stream_id": stream_id, + }) + return + # We need to signal back to all emitters or waiters that we're gone + stream_meta[0].close() + del self._streams[stream_id] async def send_rpc[R, A]( self, From dd76d11ad03ea3ed1b1a05b2efb1b49ed1cde340 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 20:13:14 -0700 Subject: [PATCH 178/193] Adding Span to _streams --- src/replit_river/v2/session.py | 52 +++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 91849f07..3b9c29a0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -145,7 +145,7 @@ class Session[HandshakeMetadata]: _space_available: asyncio.Event # stream for tasks - _streams: dict[str, tuple[Channel[Exception | None], Channel[Any]]] + _streams: 
dict[str, tuple[Span, Channel[Exception | None], Channel[Any]]] # book keeping _ack_buffer: deque[TransportMessage] @@ -204,7 +204,9 @@ def __init__( self._space_available.set() # stream for tasks - self._streams: dict[str, tuple[Channel[Exception | None], Channel[Any]]] = {} + self._streams: dict[ + str, tuple[Span, Channel[Exception | None], Channel[Any]] + ] = {} # book keeping self._ack_buffer = deque() @@ -399,7 +401,7 @@ async def close(self) -> None: # TODO: unexpected_close should close stream differently here to # throw exception correctly. - for error_channel, stream in self._streams.values(): + for _, error_channel, stream in self._streams.values(): stream.close() # Wake up backpressured writers await error_channel.put( @@ -408,7 +410,9 @@ async def close(self) -> None: ) ) # Before we GC the streams, let's wait for all tasks to be closed gracefully. - await asyncio.gather(*[stream.join() for _, stream in self._streams.values()]) + await asyncio.gather( + *[stream.join() for _, _, stream in self._streams.values()] + ) self._streams.clear() if self._ws: @@ -441,7 +445,7 @@ async def commit(msg: TransportMessage) -> None: # Wake up backpressured writer stream_meta = self._streams.get(pending.streamId) if stream_meta: - await stream_meta[0].put(None) + await stream_meta[1].put(None) def get_next_pending() -> TransportMessage | None: if self._send_buffer: @@ -549,6 +553,7 @@ async def block_until_connected() -> None: @asynccontextmanager async def _with_stream( self, + span: Span, stream_id: str, maxsize: int, ) -> AsyncIterator[tuple[Channel[Exception | None], Channel[ResultType]]]: @@ -564,19 +569,22 @@ async def _with_stream( """ output: Channel[Any] = Channel(maxsize=maxsize) error_channel: Channel[Exception | None] = Channel(maxsize=1) - self._streams[stream_id] = (error_channel, output) + self._streams[stream_id] = (span, error_channel, output) try: yield (error_channel, output) finally: stream_meta = self._streams.get(stream_id) if not stream_meta: - 
logger.warning("_with_stream had an entry deleted out from under it", extra={ - "session_id": self.session_id, - "stream_id": stream_id, - }) + logger.warning( + "_with_stream had an entry deleted out from under it", + extra={ + "session_id": self.session_id, + "stream_id": stream_id, + }, + ) return # We need to signal back to all emitters or waiters that we're gone - stream_meta[0].close() + stream_meta[1].close() del self._streams[stream_id] async def send_rpc[R, A]( @@ -604,7 +612,7 @@ async def send_rpc[R, A]( span=span, ) - async with self._with_stream(stream_id, 1) as (error_channel, output): + async with self._with_stream(span, stream_id, 1) as (error_channel, output): # Handle potential errors during communication try: async with asyncio.timeout(timeout.total_seconds()): @@ -665,7 +673,7 @@ async def send_upload[I, R, A]( span=span, ) - async with self._with_stream(stream_id, 1) as (error_channel, output): + async with self._with_stream(span, stream_id, 1) as (error_channel, output): try: # If this request is not closed and the session is killed, we should # throw exception here @@ -764,7 +772,10 @@ async def send_subscription[I, E, A]( span=span, ) - async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output): + async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( + _, + output, + ): try: async for item in output: if item.get("type") == "CLOSE": @@ -812,7 +823,7 @@ async def send_stream[I, R, E, A]( span=span, ) - async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( + async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( error_channel, output, ): @@ -1073,7 +1084,10 @@ async def _recv_from_ws( assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage ], - get_stream: Callable[[str], tuple[Channel[Exception | None], Channel[Any]] | None], + get_stream: Callable[ + [str], + tuple[Span, Channel[Exception | None], Channel[Any]] | None, + ], 
enqueue_message: SendMessage[None], ) -> None: """Serve messages from the websocket. @@ -1169,16 +1183,16 @@ async def _recv_from_ws( ) continue - errors_and_stream = get_stream(msg.streamId) + stream_meta = get_stream(msg.streamId) - if not errors_and_stream: + if not stream_meta: logger.warning( "no stream for %s, ignoring message", msg.streamId, ) continue - error_channel, output = errors_and_stream + _, error_channel, output = stream_meta if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 From 8b9d6cd75bf8cf39cd86242ba6d9e432570e014a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 20:23:21 -0700 Subject: [PATCH 179/193] Push exception emission up into close_session directly --- src/replit_river/v2/session.py | 46 ++++++++++++++++------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 3b9c29a0..5ff97c99 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -21,7 +21,7 @@ import nanoid import websockets.asyncio.client -from aiochannel import Channel +from aiochannel import Channel, ChannelFull from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -376,7 +376,7 @@ async def _enqueue_message( # Wake up buffered_message_sender self._process_messages.set() - async def close(self) -> None: + async def close(self, reason: Exception | None = None) -> None: """Close the session and all associated streams.""" logger.info( f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" @@ -399,16 +399,20 @@ async def close(self) -> None: await self._task_manager.cancel_all_tasks() - # TODO: unexpected_close should close stream differently here to - # throw exception correctly. 
for _, error_channel, stream in self._streams.values(): stream.close() # Wake up backpressured writers - await error_channel.put( - SessionClosedRiverServiceException( - "river session is closed", + try: + error_channel.put_nowait( + reason + or SessionClosedRiverServiceException( + "river session is closed", + ) + ) + except ChannelFull: + logger.exception( + "Unable to tell the caller that the session is going away", ) - ) # Before we GC the streams, let's wait for all tasks to be closed gracefully. await asyncio.gather( *[stream.join() for _, _, stream in self._streams.values()] @@ -1080,7 +1084,7 @@ async def _recv_from_ws( get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], transition_no_connection: Callable[[], Awaitable[None]], - close_session: Callable[[], Awaitable[None]], + close_session: Callable[[Exception | None], Awaitable[None]], assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage ], @@ -1120,8 +1124,6 @@ async def _recv_from_ws( logger.debug("client start handling messages from ws %r", ws) - error_channel: Channel[Exception | None] | None = None - # We should not process messages if the websocket is closed. 
while (ws := get_ws()) and get_state() in ActiveStates: connection_attempts = 0 @@ -1192,7 +1194,7 @@ async def _recv_from_ws( ) continue - _, error_channel, output = stream_meta + _, _, output = stream_meta if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 @@ -1216,25 +1218,21 @@ async def _recv_from_ws( output.close() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") - await close_session() - if error_channel: - await error_channel.put( - SessionClosedRiverServiceException( - "Out of order message, closing connection" - ) + await close_session( + SessionClosedRiverServiceException( + "Out of order message, closing connection" ) + ) continue except InvalidMessageException: logger.exception( "Got invalid transport message, closing session", ) - await close_session() - if error_channel: - await error_channel.put( - SessionClosedRiverServiceException( - "Out of order message, closing connection" - ) + await close_session( + SessionClosedRiverServiceException( + "Out of order message, closing connection" ) + ) continue except FailedSendingMessageException: # Expected error if the connection is closed. 
From 0c4435729b76d98fac61b3cb1d66183ce63b3e6d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 20:27:54 -0700 Subject: [PATCH 180/193] Communicate handshake errors back to callers as well --- src/replit_river/common_session.py | 9 +++++++ src/replit_river/v2/session.py | 39 +++++++++++++++++++++++------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index c401d269..df4c3549 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -62,6 +62,15 @@ async def buffered_message_sender( commit: Callable[[TransportMessage], Awaitable[None]], get_state: Callable[[], SessionState], ) -> None: + """ + buffered_message_sender runs in a task and consumes from a queue, emitting + messages over the websocket as quickly as it can. + + One of the design goals is to keep the message queue as short as possible to permit + quickly cancelling streams or acking heartbeats, so to that end it is wise to + incorporate backpressure into the lifecycle of get_next_pending/commit. + """ + our_task = asyncio.current_task() while our_task and not our_task.cancelling() and not our_task.cancelled(): while get_state() in ConnectingStates: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 5ff97c99..62e566c0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -239,13 +239,13 @@ def get_next_sent_seq() -> int: return self._send_buffer[0].seq return self.seq - def do_close() -> None: + def close_session(reason: Exception | None) -> None: # Avoid closing twice if self._terminating_task is None: # We can't just call self.close() directly because # we're inside a thread that will eventually be awaited # during the cleanup procedure. 
- self._terminating_task = asyncio.create_task(self.close()) + self._terminating_task = asyncio.create_task(self.close(reason)) def transition_connecting() -> None: if self._state in TerminalStates: @@ -296,7 +296,7 @@ def unbind_connecting_task() -> None: close_ws_in_background=close_ws_in_background, transition_connected=transition_connected, unbind_connecting_task=unbind_connecting_task, - do_close=do_close, + close_session=close_session, ) ) @@ -433,6 +433,26 @@ async def close(self, reason: Exception | None = None) -> None: def _start_buffered_message_sender( self, ) -> None: + """ + Building on buffered_message_sender's documentation, we implement backpressure + per-stream by way of self._streams' + + error_channel: Channel[Exception | None] + + This is accomplished via the following strategy: + - If buffered_message_sender encounters an error, we transition back to + connecting and attempt to handshake. + + If the handshake fails, we close the session with an informative error that + gets emitted to all backpressured client methods. + + - Alternately, if buffered_message_sender successfully writes back to the + + - Finally, if _recv_from_ws encounters an error (transport or deserialization), + we emit an informative error to close_session which gets emitted to all + backpressured client methods. 
+ """ + async def commit(msg: TransportMessage) -> None: pending = self._send_buffer.popleft() if msg.seq != pending.seq: @@ -935,7 +955,7 @@ async def _do_ensure_connected[HandshakeMetadata]( close_ws_in_background: Callable[[ClientConnection], None], transition_connected: Callable[[ClientConnection], None], unbind_connecting_task: Callable[[], None], - do_close: Callable[[], None], + close_session: Callable[[Exception | None], None], ) -> None: logger.info("Attempting to establish new ws connection") @@ -1040,15 +1060,16 @@ async def websocket_closed_callback() -> None: logger.debug("river client get handshake response : %r", handshake_response) if not handshake_response.status.ok: - if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: - do_close() - - raise RiverException( + err = RiverException( ERROR_HANDSHAKE, f"Handshake failed with code {handshake_response.status.code}: { handshake_response.status.reason }", ) + if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: + close_session(err) + + raise err # We did it! We're connected! 
last_error = None @@ -1069,7 +1090,7 @@ async def websocket_closed_callback() -> None: if last_error is not None: logger.debug("Handshake attempts exhausted, terminating") - do_close() + close_session(last_error) raise RiverException( ERROR_HANDSHAKE, f"Failed to create ws after retrying {attempt_count} number of times", From 7acb9a1d55c344b67bf6831618bc95de88662d74 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 4 Apr 2025 20:43:03 -0700 Subject: [PATCH 181/193] Just describe StreamMeta instead of ever-embiggening tuples --- src/replit_river/v2/session.py | 42 +++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 62e566c0..c4ad009c 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -118,6 +118,12 @@ class _IgnoreMessage: pass +class StreamMeta(TypedDict): + span: Span + error_channel: Channel[None | Exception] + output: Channel[Any] + + class Session[HandshakeMetadata]: _server_id: str session_id: str @@ -145,7 +151,7 @@ class Session[HandshakeMetadata]: _space_available: asyncio.Event # stream for tasks - _streams: dict[str, tuple[Span, Channel[Exception | None], Channel[Any]]] + _streams: dict[str, StreamMeta] # book keeping _ack_buffer: deque[TransportMessage] @@ -204,9 +210,7 @@ def __init__( self._space_available.set() # stream for tasks - self._streams: dict[ - str, tuple[Span, Channel[Exception | None], Channel[Any]] - ] = {} + self._streams: dict[str, StreamMeta] = {} # book keeping self._ack_buffer = deque() @@ -399,11 +403,11 @@ async def close(self, reason: Exception | None = None) -> None: await self._task_manager.cancel_all_tasks() - for _, error_channel, stream in self._streams.values(): - stream.close() + for stream_meta in self._streams.values(): + stream_meta["output"].close() # Wake up backpressured writers try: - error_channel.put_nowait( + stream_meta["error_channel"].put_nowait( reason or 
SessionClosedRiverServiceException( "river session is closed", @@ -415,7 +419,7 @@ async def close(self, reason: Exception | None = None) -> None: ) # Before we GC the streams, let's wait for all tasks to be closed gracefully. await asyncio.gather( - *[stream.join() for _, _, stream in self._streams.values()] + *[stream_meta["output"].join() for stream_meta in self._streams.values()] ) self._streams.clear() @@ -469,7 +473,7 @@ async def commit(msg: TransportMessage) -> None: # Wake up backpressured writer stream_meta = self._streams.get(pending.streamId) if stream_meta: - await stream_meta[1].put(None) + await stream_meta["error_channel"].put(None) def get_next_pending() -> TransportMessage | None: if self._send_buffer: @@ -580,7 +584,7 @@ async def _with_stream( span: Span, stream_id: str, maxsize: int, - ) -> AsyncIterator[tuple[Channel[Exception | None], Channel[ResultType]]]: + ) -> AsyncIterator[tuple[Channel[None | Exception], Channel[ResultType]]]: """ _with_stream @@ -592,8 +596,12 @@ async def _with_stream( emitted should call await error_channel.wait() prior to emission. 
""" output: Channel[Any] = Channel(maxsize=maxsize) - error_channel: Channel[Exception | None] = Channel(maxsize=1) - self._streams[stream_id] = (span, error_channel, output) + error_channel: Channel[None | Exception] = Channel(maxsize=1) + self._streams[stream_id] = { + "span": span, + "error_channel": error_channel, + "output": output, + } try: yield (error_channel, output) finally: @@ -608,7 +616,7 @@ async def _with_stream( ) return # We need to signal back to all emitters or waiters that we're gone - stream_meta[1].close() + output.close() del self._streams[stream_id] async def send_rpc[R, A]( @@ -1111,7 +1119,7 @@ async def _recv_from_ws( ], get_stream: Callable[ [str], - tuple[Span, Channel[Exception | None], Channel[Any]] | None, + StreamMeta | None, ], enqueue_message: SendMessage[None], ) -> None: @@ -1215,8 +1223,6 @@ async def _recv_from_ws( ) continue - _, _, output = stream_meta - if ( msg.controlFlags & STREAM_CLOSED_BIT != 0 and msg.payload.get("type", None) == "CLOSE" @@ -1226,7 +1232,7 @@ async def _recv_from_ws( pass else: try: - await output.put(msg.payload) + await stream_meta["output"].put(msg.payload) except ChannelClosed: # The client is no longer interested in this stream, # just drop the message. @@ -1236,7 +1242,7 @@ async def _recv_from_ws( # Communicate that we're going down # # This implements the receive side of the half-closed strategy. - output.close() + stream_meta["output"].close() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") await close_session( From 230f6a62387a13c2c337c36e4364394b0e2c960f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 15:46:22 -0700 Subject: [PATCH 182/193] Decouple error channel from backpressure channel Discovered that it was overloaded, written to by multiple different sources with different semantics. 
--- src/replit_river/v2/session.py | 77 +++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 25 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index c4ad009c..79947aba 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -21,7 +21,7 @@ import nanoid import websockets.asyncio.client -from aiochannel import Channel, ChannelFull +from aiochannel import Channel, ChannelEmpty, ChannelFull from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -83,6 +83,9 @@ STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 +_BackpressuredWaiter: TypeAlias = Callable[[], Awaitable[None]] + + class ResultOk(TypedDict): ok: Literal[True] payload: Any @@ -120,7 +123,8 @@ class _IgnoreMessage: class StreamMeta(TypedDict): span: Span - error_channel: Channel[None | Exception] + release_backpressured_waiter: Callable[[], None] + error_channel: Channel[Exception] output: Channel[Any] @@ -417,6 +421,7 @@ async def close(self, reason: Exception | None = None) -> None: logger.exception( "Unable to tell the caller that the session is going away", ) + stream_meta["release_backpressured_waiter"]() # Before we GC the streams, let's wait for all tasks to be closed gracefully. 
await asyncio.gather( *[stream_meta["output"].join() for stream_meta in self._streams.values()] @@ -473,7 +478,7 @@ async def commit(msg: TransportMessage) -> None: # Wake up backpressured writer stream_meta = self._streams.get(pending.streamId) if stream_meta: - await stream_meta["error_channel"].put(None) + stream_meta["release_backpressured_waiter"]() def get_next_pending() -> TransportMessage | None: if self._send_buffer: @@ -584,7 +589,7 @@ async def _with_stream( span: Span, stream_id: str, maxsize: int, - ) -> AsyncIterator[tuple[Channel[None | Exception], Channel[ResultType]]]: + ) -> AsyncIterator[tuple[_BackpressuredWaiter, AsyncIterator[ResultType]]]: """ _with_stream @@ -596,14 +601,36 @@ async def _with_stream( emitted should call await error_channel.wait() prior to emission. """ output: Channel[Any] = Channel(maxsize=maxsize) - error_channel: Channel[None | Exception] = Channel(maxsize=1) + backpressured_waiter_event: asyncio.Event = asyncio.Event() + error_channel: Channel[Exception] = Channel(maxsize=1) self._streams[stream_id] = { "span": span, "error_channel": error_channel, + "release_backpressured_waiter": backpressured_waiter_event.set, "output": output, } + + async def backpressured_waiter() -> None: + await backpressured_waiter_event.wait() + try: + err = error_channel.get_nowait() + raise err + except (ChannelClosed, ChannelEmpty): + # No errors, off to the next message + pass + + async def error_checking_output() -> AsyncIterator[ResultType]: + async for elem in output: + try: + err = error_channel.get_nowait() + raise err + except (ChannelClosed, ChannelEmpty): + # No errors, off to the next message + pass + yield elem + try: - yield (error_channel, output) + yield (backpressured_waiter, error_checking_output()) finally: stream_meta = self._streams.get(stream_id) if not stream_meta: @@ -644,14 +671,16 @@ async def send_rpc[R, A]( span=span, ) - async with self._with_stream(span, stream_id, 1) as (error_channel, output): + async with 
self._with_stream(span, stream_id, 1) as ( + backpressured_waiter, + output, + ): # Handle potential errors during communication try: async with asyncio.timeout(timeout.total_seconds()): # Block for backpressure and emission errors from the ws - if err := await error_channel.get(): - raise err - result = await output.get() + await backpressured_waiter() + result = await anext(output) except asyncio.TimeoutError as e: await self._send_cancel_stream( stream_id=stream_id, @@ -705,15 +734,16 @@ async def send_upload[I, R, A]( span=span, ) - async with self._with_stream(span, stream_id, 1) as (error_channel, output): + async with self._with_stream(span, stream_id, 1) as ( + backpressured_waiter, + output, + ): try: # If this request is not closed and the session is killed, we should # throw exception here async for item in request: # Block for backpressure and emission errors from the ws - if err := await error_channel.get(): - raise err - + await backpressured_waiter() try: payload = request_serializer(item) except Exception as e: @@ -753,7 +783,7 @@ async def send_upload[I, R, A]( ) try: - result = await output.get() + result = await anext(output) except ChannelClosed as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, @@ -856,7 +886,7 @@ async def send_stream[I, R, E, A]( ) async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( - error_channel, + backpressured_waiter, output, ): # Create the encoder task @@ -871,13 +901,9 @@ async def _encode_stream() -> None: assert request_serializer, "send_stream missing request_serializer" async for item in request: - # Block for backpressure and emission errors from the ws - if err := await error_channel.get(): - await self._send_close_stream( - stream_id=stream_id, - span=span, - ) - raise err + # Block for backpressure (or errors) + await backpressured_waiter() + # If there are any errors so far, raise them await self._enqueue_message( stream_id=stream_id, control_flags=0, @@ -894,14 +920,15 @@ 
async def _encode_stream() -> None: try: async for result in output: # Raise as early as we possibly can in case of an emission error - if err := emitter_task.done() and emitter_task.exception(): + if emitter_task.done() and (err := emitter_task.exception()): raise err if result.get("type") == "CLOSE": break if "ok" not in result or not result["ok"]: yield error_deserializer(result["payload"]) yield response_deserializer(result["payload"]) - # ... block the outer function until the emitter is finished emitting. + # ... block the outer function until the emitter is finished emitting, + # possibly raising a terminal exception. await emitter_task except Exception as e: await self._send_cancel_stream( From 08a0acd2332c0978662e9f70e283298d1c65000a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 17:01:28 -0700 Subject: [PATCH 183/193] PR feedback --- src/replit_river/client.py | 3 ++- src/replit_river/codegen/client.py | 8 ++++---- src/replit_river/common_session.py | 4 ++-- src/replit_river/message_buffer.py | 1 + src/replit_river/v2/client.py | 3 ++- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 3f852d64..db4608ec 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -226,9 +226,11 @@ def _trace_procedure( span_handle = _SpanHandle(span) try: yield span_handle + span_handle.set_status(StatusCode.OK) except GeneratorExit: # This error indicates the caller is done with the async generator # but messages are still left. This is okay, we do not consider it an error. 
+ span_handle.set_status(StatusCode.OK) raise except RiverException as e: span.record_exception(e, escaped=True) @@ -239,7 +241,6 @@ def _trace_procedure( span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") raise e finally: - span_handle.set_status(StatusCode.OK) span.end() diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 0fc226d2..8caaf0b5 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -850,11 +850,11 @@ async def {name}( elif procedure.type == "subscription": match protocol_version: case "v1.1": - assert input_meta, "rpc expects input to be required" + assert input_meta, "subscription expects input to be required" _, tpe, render_method = input_meta binding = "input" case "v2.0": - assert init_meta, "rpc expects init to be required" + assert init_meta, "subscription expects init to be required" _, tpe, render_method = init_meta binding = "init" case other: @@ -932,7 +932,7 @@ async def {name}( ] ) elif protocol_version == "v1.1": - assert input_meta, "Protocol v1 requires input to be defined" + assert input_meta, "upload requires input to be defined" _, input_type, render_input_method = input_meta current_chunks.extend( [ @@ -1009,7 +1009,7 @@ async def {name}( ] ) elif protocol_version == "v1.1": - assert input_meta, "Protocol v1 requires input to be defined" + assert input_meta, "stream requires input to be defined" _, input_type, render_input_method = input_meta current_chunks.extend( [ diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index df4c3549..d47bfe55 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -102,13 +102,13 @@ async def buffered_message_sender( except WebsocketClosedException as e: logger.debug( "_buffered_message_sender: Connection closed while sending " - "message %r, waiting for retry from buffer", + "message %r, waiting for reconnect and retry from buffer", 
type(e), exc_info=e, ) except FailedSendingMessageException: logger.error( - "Failed sending message, waiting for retry from buffer", + "Failed sending message, waiting for reconnect and retry from buffer", exc_info=True, ) except Exception: diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index 5c3b8f44..1ced8582 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -68,6 +68,7 @@ async def close(self) -> None: Closes the message buffer and rejects any pending put operations. """ self._closed = True + # Wake up block_until_message_available to permit graceful cleanup self._has_messages.set() async with self._space_available_cond: self._space_available_cond.notify_all() diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 09acc476..1b900d38 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -181,9 +181,11 @@ def _trace_procedure( span_handle = _SpanHandle(span) try: yield span_handle + span_handle.set_status(StatusCode.OK) except GeneratorExit: # This error indicates the caller is done with the async generator # but messages are still left. This is okay, we do not consider it an error. 
+ span_handle.set_status(StatusCode.OK) raise except RiverException as e: span.record_exception(e, escaped=True) @@ -194,7 +196,6 @@ def _trace_procedure( span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") raise e finally: - span_handle.set_status(StatusCode.OK) span.end() From 502c055a5f1eb37bf07fe42d4e4a672e65bcbebc Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 17:43:28 -0700 Subject: [PATCH 184/193] Fix a bug where we were over-emitting encoder_ members for pydantic --- src/replit_river/codegen/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 8caaf0b5..f04b5427 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -275,7 +275,9 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: encoder_name = TypeName( f"encode_{render_literal_type(type_name)}" ) - encoder_names.add(encoder_name) + if base_model == "TypedDict": + # "encoder_names" is only a TypedDict thing + encoder_names.add(encoder_name) _field_name = render_literal_type(encoder_name) typeddict_encoder.append( f"""\ From 9b40d6ae96c2b0d24cd1ae6eb87a596108f1596c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 17:43:38 -0700 Subject: [PATCH 185/193] Line length lint --- src/replit_river/common_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index d47bfe55..07897733 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -108,7 +108,8 @@ async def buffered_message_sender( ) except FailedSendingMessageException: logger.error( - "Failed sending message, waiting for reconnect and retry from buffer", + "Failed sending message, " + "waiting for reconnect and retry from buffer", exc_info=True, ) except Exception: From 59e8d0f9bb7785b759c4498a9cc2d975b4bf4d8b Mon Sep 17 
00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 17:44:01 -0700 Subject: [PATCH 186/193] Typing the output channel slightly --- src/replit_river/v2/session.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 79947aba..737a71e8 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -125,7 +125,7 @@ class StreamMeta(TypedDict): span: Span release_backpressured_waiter: Callable[[], None] error_channel: Channel[Exception] - output: Channel[Any] + output: Channel[ResultType] class Session[HandshakeMetadata]: @@ -600,7 +600,7 @@ async def _with_stream( since the first event does not care about backpressure, but subsequent events emitted should call await error_channel.wait() prior to emission. """ - output: Channel[Any] = Channel(maxsize=maxsize) + output: Channel[ResultType] = Channel(maxsize=maxsize) backpressured_waiter_event: asyncio.Event = asyncio.Event() error_channel: Channel[Exception] = Channel(maxsize=1) self._streams[stream_id] = { @@ -697,6 +697,7 @@ async def send_rpc[R, A]( ) from e except Exception as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + if "ok" not in result or not result["ok"]: try: error = error_deserializer(result["payload"]) From 0c030b55d899e849a591774276b4eee60870c5b9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 18:23:09 -0700 Subject: [PATCH 187/193] Use handshake_timeout_ms --- src/replit_river/v2/session.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 737a71e8..5f50e737 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -292,6 +292,7 @@ def unbind_connecting_task() -> None: if not self._connecting_task: self._connecting_task = asyncio.create_task( _do_ensure_connected( + transport_options=self._transport_options, client_id=self._client_id, 
server_id=self._server_id, session_id=self.session_id, @@ -977,6 +978,7 @@ async def _send_close_stream( async def _do_ensure_connected[HandshakeMetadata]( + transport_options: TransportOptions, client_id: str, session_id: str, server_id: str, @@ -1048,7 +1050,9 @@ async def websocket_closed_callback() -> None: "Handshake failed, conn closed while sending response", ) from e - startup_grace_deadline_ms = await get_current_time() + 60_000 + startup_grace_deadline_ms = ( + await get_current_time() + transport_options.handshake_timeout_ms + ) while True: if await get_current_time() >= startup_grace_deadline_ms: raise RiverException( From 395f7bbbd3aabc1d87a7c0f4e0a2a4df6c94b8cc Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 18:25:40 -0700 Subject: [PATCH 188/193] Removing overly-broad "except" --- src/replit_river/v2/session.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 5f50e737..a44a0b9c 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -696,8 +696,6 @@ async def send_rpc[R, A]( service_name, procedure_name, ) from e - except Exception as e: - raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e if "ok" not in result or not result["ok"]: try: @@ -1050,11 +1048,11 @@ async def websocket_closed_callback() -> None: "Handshake failed, conn closed while sending response", ) from e - startup_grace_deadline_ms = ( + handshake_deadline_ms = ( await get_current_time() + transport_options.handshake_timeout_ms ) while True: - if await get_current_time() >= startup_grace_deadline_ms: + if await get_current_time() >= handshake_deadline_ms: raise RiverException( ERROR_HANDSHAKE, "Handshake response timeout, closing connection", From efe17ca83327e5816c576b769ecd4871c43bf874 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 18:42:57 -0700 Subject: [PATCH 189/193] PR feedback --- 
src/replit_river/v2/session.py | 87 ++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index a44a0b9c..7fece7b0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -248,12 +248,22 @@ def get_next_sent_seq() -> int: return self.seq def close_session(reason: Exception | None) -> None: + # If we're already closing, just let whoever's currently doing it handle it. + if self._state in TerminalStates: + return + # Avoid closing twice if self._terminating_task is None: + current_state = self._state + self._state = SessionState.CLOSING + # We can't just call self.close() directly because # we're inside a thread that will eventually be awaited # during the cleanup procedure. - self._terminating_task = asyncio.create_task(self.close(reason)) + + self._terminating_task = asyncio.create_task( + self.close(reason, current_state=current_state), + ) def transition_connecting() -> None: if self._state in TerminalStates: @@ -301,6 +311,7 @@ def unbind_connecting_task() -> None: get_next_sent_seq=get_next_sent_seq, get_current_ack=lambda: self.ack, get_current_time=self._get_current_time, + get_state=lambda: self._state, transition_connecting=transition_connecting, close_ws_in_background=close_ws_in_background, transition_connected=transition_connected, @@ -385,12 +396,12 @@ async def _enqueue_message( # Wake up buffered_message_sender self._process_messages.set() - async def close(self, reason: Exception | None = None) -> None: + async def close(self, reason: Exception | None = None, current_state: SessionState | None = None ) -> None: """Close the session and all associated streams.""" logger.info( f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" ) - if self._state in TerminalStates: + if (current_state or self._state) in TerminalStates: # already closing return self._state = SessionState.CLOSING @@ -987,6 +998,7 @@ 
async def _do_ensure_connected[HandshakeMetadata]( get_current_time: Callable[[], Awaitable[float]], get_next_sent_seq: Callable[[], int], get_current_ack: Callable[[], int], + get_state: Callable[[], SessionState], transition_connecting: Callable[[], None], close_ws_in_background: Callable[[ClientConnection], None], transition_connected: Callable[[ClientConnection], None], @@ -998,6 +1010,10 @@ async def _do_ensure_connected[HandshakeMetadata]( last_error: Exception | None = None attempt_count = 0 while rate_limiter.has_budget(client_id): + if (state := get_state()) in TerminalStates or state in ActiveStates: + logger.info(f"_do_ensure_connected stopping due to state={state}") + break + if attempt_count > 0: logger.info(f"Retrying build handshake number {attempt_count} times") attempt_count += 1 @@ -1051,40 +1067,40 @@ async def websocket_closed_callback() -> None: handshake_deadline_ms = ( await get_current_time() + transport_options.handshake_timeout_ms ) - while True: - if await get_current_time() >= handshake_deadline_ms: - raise RiverException( - ERROR_HANDSHAKE, - "Handshake response timeout, closing connection", - ) - try: - data = await ws.recv(decode=False) - except ConnectionClosed as e: - logger.debug( - "_do_ensure_connected: Connection closed during waiting " - "for handshake response", - exc_info=True, - ) - raise RiverException( - ERROR_HANDSHAKE, - "Handshake failed, conn closed while waiting for response", - ) from e - try: - response_msg = parse_transport_msg(data) - if isinstance(response_msg, str): - logger.debug( - "_do_ensure_connected: Ignoring transport message", - exc_info=True, - ) - continue + if await get_current_time() >= handshake_deadline_ms: + raise RiverException( + ERROR_HANDSHAKE, + "Handshake response timeout, closing connection", + ) - break - except InvalidMessageException as e: - raise RiverException( - ERROR_HANDSHAKE, - "Got invalid transport message, closing connection", - ) from e + try: + data = await 
ws.recv(decode=False) + except ConnectionClosed as e: + logger.debug( + "_do_ensure_connected: Connection closed during waiting " + "for handshake response", + exc_info=True, + ) + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while waiting for response", + ) from e + + try: + response_msg = parse_transport_msg(data) + except InvalidMessageException as e: + raise RiverException( + ERROR_HANDSHAKE, + "Got invalid transport message, closing connection", + ) from e + + if isinstance(response_msg, str): + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, received a raw string message while waiting " + "for a handshake response", + ) try: handshake_response = ControlMessageHandshakeResponse( @@ -1105,6 +1121,7 @@ async def websocket_closed_callback() -> None: }", ) if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: + # A session state mismatch is unrecoverable. Terminate immediately. close_session(err) raise err From e6812a79f3ae9e49502e678b9c0881b0254535b3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 19:27:15 -0700 Subject: [PATCH 190/193] Avoid deadlocking client if streams don't clean up after themselves --- src/replit_river/transport_options.py | 6 +++++ src/replit_river/v2/session.py | 33 ++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/replit_river/transport_options.py b/src/replit_river/transport_options.py index 47032bac..09f000f8 100644 --- a/src/replit_river/transport_options.py +++ b/src/replit_river/transport_options.py @@ -27,6 +27,7 @@ class TransportOptions(BaseModel): connection_retry_options: ConnectionRetryOptions = ConnectionRetryOptions() buffer_size: int = 1_000 transparent_reconnect: bool = True + shutdown_all_streams_timeout_ms: float = 10_000 def websocket_disconnect_grace_ms(self) -> float: return self.heartbeat_ms * self.heartbeats_until_dead @@ -39,11 +40,16 @@ def create_from_env(cls) -> "TransportOptions": ) 
heartbeat_ms = float(os.getenv("HEARTBEAT_MS", 2_000)) heartbeats_to_dead = int(os.getenv("HEARTBEATS_UNTIL_DEAD", 2)) + shutdown_all_streams_timeout_ms = float( + os.getenv("SHUTDOWN_STREAMS_TIMEOUT_MS", 10_000) + ) + return TransportOptions( handshake_timeout_ms=handshake_timeout_ms, session_disconnect_grace_ms=session_disconnect_grace_ms, heartbeat_ms=heartbeat_ms, heartbeats_until_dead=heartbeats_to_dead, + shutdown_all_streams_timeout_ms=shutdown_all_streams_timeout_ms, ) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 7fece7b0..a8cbc1a3 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -262,8 +262,8 @@ def close_session(reason: Exception | None) -> None: # during the cleanup procedure. self._terminating_task = asyncio.create_task( - self.close(reason, current_state=current_state), - ) + self.close(reason, current_state=current_state), + ) def transition_connecting() -> None: if self._state in TerminalStates: @@ -396,7 +396,9 @@ async def _enqueue_message( # Wake up buffered_message_sender self._process_messages.set() - async def close(self, reason: Exception | None = None, current_state: SessionState | None = None ) -> None: + async def close( + self, reason: Exception | None = None, current_state: SessionState | None = None + ) -> None: """Close the session and all associated streams.""" logger.info( f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" @@ -435,9 +437,28 @@ async def close(self, reason: Exception | None = None, current_state: SessionSta ) stream_meta["release_backpressured_waiter"]() # Before we GC the streams, let's wait for all tasks to be closed gracefully. 
- await asyncio.gather( - *[stream_meta["output"].join() for stream_meta in self._streams.values()] - ) + try: + async with asyncio.timeout( + self._transport_options.shutdown_all_streams_timeout_ms + ): + # Block for backpressure and emission errors from the ws + await asyncio.gather( + *[ + stream_meta["output"].join() + for stream_meta in self._streams.values() + ] + ) + except asyncio.TimeoutError: + spans: list[Span] = [ + stream_meta["span"] + for stream_meta in self._streams.values() + if not stream_meta["output"].closed() + ] + span_ids = [span.get_span_context().span_id for span in spans] + logger.exception( + "Timeout waiting for output streams to finallize", + extra={"span_ids": span_ids}, + ) self._streams.clear() if self._ws: From 7123e823ecdcfcb5ca483e38b3b344e1525b4d0f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 7 Apr 2025 19:34:23 -0700 Subject: [PATCH 191/193] Missed that sessions should send close at the end of input --- src/replit_river/v2/session.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index a8cbc1a3..fca4eed0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -877,6 +877,7 @@ async def send_subscription[I, E, A]( if not item.get("ok", False): yield error_deserializer(item["payload"]) yield response_deserializer(item["payload"]) + await self._send_close_stream(stream_id, span) except Exception as e: await self._send_cancel_stream( stream_id=stream_id, From e0fac3ec58307aef668b8ce73b2af4b785646be8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 8 Apr 2025 10:22:26 -0700 Subject: [PATCH 192/193] Missing "continue"s --- src/replit_river/v2/session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index fca4eed0..46ab151d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -876,6 +876,7 @@ async def send_subscription[I, E, 
A]( break if not item.get("ok", False): yield error_deserializer(item["payload"]) + continue yield response_deserializer(item["payload"]) await self._send_close_stream(stream_id, span) except Exception as e: @@ -959,6 +960,7 @@ async def _encode_stream() -> None: break if "ok" not in result or not result["ok"]: yield error_deserializer(result["payload"]) + continue yield response_deserializer(result["payload"]) # ... block the outer function until the emitter is finished emitting, # possibly raising a terminal exception. From 4820ef86c33350773dcb965bfb34267110637c5c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 8 Apr 2025 10:40:43 -0700 Subject: [PATCH 193/193] Trying to debug hanging messages during disconnect --- src/replit_river/common_session.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 07897733..3dce6138 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -99,6 +99,7 @@ async def buffered_message_sender( try: await send_transport_message(msg, ws, websocket_closed_callback) await commit(msg) + logger.debug("_buffered_message_sender: Sent %r", msg.id) except WebsocketClosedException as e: logger.debug( "_buffered_message_sender: Connection closed while sending "