diff --git a/pyproject.toml b/pyproject.toml index 639f2399..2f049a3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "nanoid>=2.0.0", "protobuf>=5.28.3", "pydantic-core>=2.20.1", - "websockets>=12.0", + "websockets>=13.0,<14", "opentelemetry-sdk>=1.28.2", "opentelemetry-api>=1.28.2", ] diff --git a/scripts/parity.sh b/scripts/parity.sh deleted file mode 100644 index c9beda24..00000000 --- a/scripts/parity.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash -# -# parity.sh: Generate Pydantic and TypedDict models and check for deep equality. -# This script expects that ai-infra is cloned alongside river-python. - -set -e - -scripts="$(dirname "$0")" -cd "${scripts}/.." - -root="$(mktemp -d --tmpdir 'river-codegen-parity.XXX')" -mkdir "$root/src" - -echo "Using $root" >&2 - -function cleanup { - if [ -z "${DEBUG}" ]; then - echo "Cleaning up..." >&2 - rm -rfv "${root}" >&2 - fi -} -trap "cleanup" 0 2 3 15 - -gen() { - fname="$1"; shift - name="$1"; shift - poetry run python -m replit_river.codegen \ - client \ - --output "${root}/src/${fname}" \ - --client-name "${name}" \ - ../ai-infra/pkgs/pid2_client/src/schema/schema.json \ - "$@" -} - -gen tyd.py Pid2TypedDict --typed-dict-inputs -gen pyd.py Pid2Pydantic - -PYTHONPATH="${root}/src:${scripts}" -poetry run bash -c "MYPYPATH='$PYTHONPATH' mypy -m parity.check_parity" -poetry run bash -c "PYTHONPATH='$PYTHONPATH' python -m parity.check_parity" diff --git a/scripts/parity/check_parity.py b/scripts/parity/check_parity.py deleted file mode 100644 index c89a6918..00000000 --- a/scripts/parity/check_parity.py +++ /dev/null @@ -1,295 +0,0 @@ -from typing import Any, Callable, Literal, TypedDict, TypeVar - -import pyd -import tyd -from parity.gen import ( - gen_bool, - gen_choice, - gen_dict, - gen_float, - gen_int, - gen_list, - gen_opt, - gen_str, -) -from pydantic import TypeAdapter - -A = TypeVar("A") - -PrimitiveType = ( - bool | str | int | float | dict[str, "PrimitiveType"] | 
list["PrimitiveType"] -) - - -def deep_equal(a: PrimitiveType, b: PrimitiveType) -> Literal[True]: - if a == b: - return True - elif isinstance(a, dict) and isinstance(b, dict): - a_keys: PrimitiveType = list(a.keys()) - b_keys: PrimitiveType = list(b.keys()) - assert deep_equal(a_keys, b_keys) - - # We do this dance again because Python variance is hard. Feel free to fix it. - keys = set(a.keys()) - keys.update(b.keys()) - for k in keys: - aa: PrimitiveType = a[k] - bb: PrimitiveType = b[k] - assert deep_equal(aa, bb) - return True - elif isinstance(a, list) and isinstance(b, list): - assert len(a) == len(b) - for i in range(len(a)): - assert deep_equal(a[i], b[i]) - return True - else: - assert a == b, f"{a} != {b}" - return True - - -def baseTestPattern( - x: A, encode: Callable[[A], Any], adapter: TypeAdapter[Any] -) -> None: - a = encode(x) - m = adapter.validate_python(a) - z = adapter.dump_python(m, by_alias=True, exclude_none=True) - - assert deep_equal(a, z) - - -def testAiexecExecInit() -> None: - x: tyd.AiexecExecInit = { - "args": gen_list(gen_str)(), - "env": gen_opt(gen_dict(gen_str))(), - "cwd": gen_opt(gen_str)(), - "omitStdout": gen_opt(gen_bool)(), - "omitStderr": gen_opt(gen_bool)(), - "useReplitRunEnv": gen_opt(gen_bool)(), - } - - baseTestPattern(x, tyd.encode_AiexecExecInit, TypeAdapter(pyd.AiexecExecInit)) - - -def testAgenttoollanguageserverOpendocumentInput() -> None: - x: tyd.AgenttoollanguageserverOpendocumentInput = { - "uri": gen_str(), - "languageId": gen_str(), - "version": gen_float(), - "text": gen_str(), - } - - baseTestPattern( - x, - tyd.encode_AgenttoollanguageserverOpendocumentInput, - TypeAdapter(pyd.AgenttoollanguageserverOpendocumentInput), - ) - - -kind_type = ( - Literal[ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - ] - | None -) - - -def testAgenttoollanguageserverGetcodesymbolInput() -> None: - x: 
tyd.AgenttoollanguageserverGetcodesymbolInput = { - "uri": gen_str(), - "position": { - "line": gen_float(), - "character": gen_float(), - }, - "kind": gen_choice( - list[kind_type]( - [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - None, - ] - ) - )(), - } - - baseTestPattern( - x, - tyd.encode_AgenttoollanguageserverGetcodesymbolInput, - TypeAdapter(pyd.AgenttoollanguageserverGetcodesymbolInput), - ) - - -class size_type(TypedDict): - rows: int - cols: int - - -def testShellexecSpawnInput() -> None: - x: tyd.ShellexecSpawnInput = { - "cmd": gen_str(), - "args": gen_opt(gen_list(gen_str))(), - "initialCmd": gen_opt(gen_str)(), - "env": gen_opt(gen_dict(gen_str))(), - "cwd": gen_opt(gen_str)(), - "size": gen_opt( - lambda: size_type( - { - "rows": gen_int(), - "cols": gen_int(), - } - ), - )(), - "useReplitRunEnv": gen_opt(gen_bool)(), - "useCgroupMagic": gen_opt(gen_bool)(), - "interactive": gen_opt(gen_bool)(), - "onlySpawnIfNoProcesses": gen_opt(gen_bool)(), - } - - baseTestPattern( - x, - tyd.encode_ShellexecSpawnInput, - TypeAdapter(pyd.ShellexecSpawnInput), - ) - - -def testConmanfilesystemPersistInput() -> None: - x: tyd.ConmanfilesystemPersistInput = {} - - baseTestPattern( - x, - tyd.encode_ConmanfilesystemPersistInput, - TypeAdapter(pyd.ConmanfilesystemPersistInput), - ) - - -closeFile = tyd.ReplspaceapiInitInputOneOf_closeFile -githubToken = tyd.ReplspaceapiInitInputOneOf_githubToken -sshToken0 = tyd.ReplspaceapiInitInputOneOf_sshToken0 -sshToken1 = tyd.ReplspaceapiInitInputOneOf_sshToken1 -allowDefaultBucketAccess = tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccess - -allowDefaultBucketAccessResultOk = ( - tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_ok -) -allowDefaultBucketAccessResultError = ( - tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_error -) - - -def testReplspaceapiInitInput() -> None: - x: 
tyd.ReplspaceapiInitInput = gen_choice( - list[tyd.ReplspaceapiInitInput]( - [ - closeFile( - {"kind": "closeFile", "filename": gen_str(), "nonce": gen_str()} - ), - githubToken( - {"kind": "githubToken", "token": gen_str(), "nonce": gen_str()} - ), - sshToken0( - { - "kind": "sshToken", - "nonce": gen_str(), - "SSHHostname": gen_str(), - "token": gen_str(), - } - ), - sshToken1({"kind": "sshToken", "nonce": gen_str(), "error": gen_str()}), - allowDefaultBucketAccess( - { - "kind": "allowDefaultBucketAccess", - "nonce": gen_str(), - "result": gen_choice( - list[ - tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResult - ]( - [ - allowDefaultBucketAccessResultOk( - { - "bucketId": gen_str(), - "sourceReplId": gen_str(), - "status": "ok", - "targetReplId": gen_str(), - } - ), - allowDefaultBucketAccessResultError( - {"message": gen_str(), "status": "error"} - ), - ] - ) - )(), - } - ), - ] - ) - )() - - baseTestPattern( - x, - tyd.encode_ReplspaceapiInitInput, - TypeAdapter(pyd.ReplspaceapiInitInput), - ) - - -def main() -> None: - testAiexecExecInit() - testAgenttoollanguageserverOpendocumentInput() - testAgenttoollanguageserverGetcodesymbolInput() - testShellexecSpawnInput() - testConmanfilesystemPersistInput() - testReplspaceapiInitInput() - - -if __name__ == "__main__": - print("Starting...") - for _ in range(0, 100): - main() - print("Verified") diff --git a/src/replit_river/__init__.py b/src/replit_river/__init__.py index bc166e3e..d3bf1966 100644 --- a/src/replit_river/__init__.py +++ b/src/replit_river/__init__.py @@ -1,3 +1,4 @@ +from . 
import v2 from .client import Client from .error_schema import RiverError from .rpc import ( @@ -20,4 +21,5 @@ "subscription_method_handler", "upload_method_handler", "stream_method_handler", + "v2", ] diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 3f852d64..db4608ec 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -226,9 +226,11 @@ def _trace_procedure( span_handle = _SpanHandle(span) try: yield span_handle + span_handle.set_status(StatusCode.OK) except GeneratorExit: # This error indicates the caller is done with the async generator # but messages are still left. This is okay, we do not consider it an error. + span_handle.set_status(StatusCode.OK) raise except RiverException as e: span.record_exception(e, escaped=True) @@ -239,7 +241,6 @@ def _trace_procedure( span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") raise e finally: - span_handle.set_status(StatusCode.OK) span.end() diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 2d1e847a..ef535ee3 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -173,9 +173,8 @@ async def _handle_messages_from_ws(self) -> None: # The client is no longer interested in this stream, # just drop the message. 
pass - except RuntimeError as e: + except Exception as e: raise InvalidMessageException(e) from e - else: raise InvalidMessageException( "Client should not receive stream open bit" @@ -245,7 +244,7 @@ async def send_rpc( service_name, procedure_name, ) from e - except RuntimeError as e: + except Exception as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e if not response.get("ok", False): try: @@ -331,7 +330,7 @@ async def send_upload( service_name, procedure_name, ) from e - except RuntimeError as e: + except Exception as e: raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e if not response.get("ok", False): try: @@ -388,15 +387,13 @@ async def send_subscription( ) continue yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: + except Exception as e: raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", service_name, procedure_name, ) from e - except Exception as e: - raise e finally: output.close() @@ -491,17 +488,14 @@ async def _encode_stream() -> None: ) continue yield response_deserializer(item["payload"]) - except (RuntimeError, ChannelClosed) as e: + except Exception as e: + logger.exception("There was a problem") raise RiverServiceException( ERROR_CODE_STREAM_CLOSED, "Stream closed before response", service_name, procedure_name, ) from e - except Exception as e: - raise e - finally: - output.close() async def send_close_stream( self, diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 04215811..1e8fdcf1 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -261,7 +261,7 @@ async def websocket_closed_callback() -> None: try: await send_transport_message( TransportMessage( - from_=transport_id, # type: ignore + from_=transport_id, to=to_id, streamId=stream_id, controlFlags=0, diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 
d4e69ddf..f04b5427 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -11,6 +11,7 @@ Sequence, Set, TextIO, + assert_never, cast, ) @@ -118,7 +119,7 @@ class RiverNotType(BaseModel): class RiverProcedure(BaseModel): init: RiverType | None = Field(default=None) - input: RiverType + input: RiverType | None = Field(default=None) output: RiverType errors: RiverType | None = Field(default=None) type: ( @@ -274,7 +275,9 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: encoder_name = TypeName( f"encode_{render_literal_type(type_name)}" ) - encoder_names.add(encoder_name) + if base_model == "TypedDict": + # "encoder_names" is only a TypedDict thing + encoder_names.add(encoder_name) _field_name = render_literal_type(encoder_name) typeddict_encoder.append( f"""\ @@ -753,7 +756,14 @@ def generate_common_client( handshake_type: HandshakeType, handshake_chunks: Sequence[str], modules: list[tuple[ModuleName, ClassName]], + protocol_version: Literal["v1.1", "v2.0"], ) -> FileContents: + client_module: str + match protocol_version: + case "v1.1": + client_module = "river" + case "v2.0": + client_module = "river.v2" chunks: list[str] = [ROOT_FILE_HEADER] chunks.extend( [ @@ -767,7 +777,7 @@ def generate_common_client( dedent( f"""\ class {client_name}: - def __init__(self, client: river.Client[{handshake_type}]): + def __init__(self, client: {client_module}.Client[{handshake_type}]): """.rstrip() ) ] @@ -780,11 +790,291 @@ def __init__(self, client: river.Client[{handshake_type}]): return FileContents("\n".join(chunks)) +def render_library_call( + schema_name: str, + name: str, + procedure: RiverProcedure, + protocol_version: Literal["v1.1", "v2.0"], + init_meta: tuple[RiverType, TypeExpression, str] | None, + input_meta: tuple[RiverType, TypeExpression, str] | None, + output_meta: tuple[RiverType, TypeExpression, str] | None, + error_meta: tuple[RiverType, TypeExpression, str] | None, +) -> list[str]: + """ + This method is only 
ever called from one place, but it's defensively establishing a + namespace that lets us draw some new boundaries around the parameters, without the + pollution from other intermediate values. + """ + current_chunks: list[str] = [] + + binding: str + if procedure.type == "rpc": + match protocol_version: + case "v1.1": + assert input_meta, "rpc expects input to be required" + _, tpe, render_method = input_meta + binding = "input" + case "v2.0": + assert init_meta, "rpc expects init to be required" + _, tpe, render_method = init_meta + binding = "init" + case other: + assert_never(other) + + assert output_meta + assert error_meta + _, output_type, parse_output_method = output_meta + _, _, parse_error_method = error_meta + + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + {binding}: {render_type_expr(tpe)}, + timeout: datetime.timedelta, + ) -> {render_type_expr(output_type)}: + return await self.client.send_rpc( + {repr(schema_name)}, + {repr(name)}, + {binding}, + {reindent(" ", render_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + timeout, + ) + """, + ) + ] + ) + elif procedure.type == "subscription": + match protocol_version: + case "v1.1": + assert input_meta, "subscription expects input to be required" + _, tpe, render_method = input_meta + binding = "input" + case "v2.0": + assert init_meta, "subscription expects init to be required" + _, tpe, render_method = init_meta + binding = "init" + case other: + assert_never(other) + + assert output_meta + assert error_meta + _, output_type, parse_output_method = output_meta + _, error_type, parse_error_method = error_meta + error_type_name = extract_inner_type(error_type) + + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + + output_or_error_type = UnionTypeExpr( + [ + output_or_error_type, + TypeName("RiverError"), + ] + ) + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + 
{binding}: {render_type_expr(tpe)}, + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_subscription( + {repr(schema_name)}, + {repr(name)}, + {binding}, + {reindent(" ", render_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + elif procedure.type == "upload": + assert output_meta + assert error_meta + _, output_type, parse_output_method = output_meta + _, error_type, parse_error_method = error_meta + error_type_name = extract_inner_type(error_type) + + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + + if init_meta and input_meta: + _, init_type, render_init_method = init_meta + _, input_type, render_input_method = input_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + init: {render_type_expr(init_type)}, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> {render_type_expr(output_type)}: + return await self.client.send_upload( + {repr(schema_name)}, + {repr(name)}, + init, + inputStream, + {reindent(" ", render_init_method)}, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + elif protocol_version == "v1.1": + assert input_meta, "upload requires input to be defined" + _, input_type, render_input_method = input_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> { # TODO(dstewart) This should just be output_type + render_type_expr(output_or_error_type) + }: + return await self.client.send_upload( + {repr(schema_name)}, + {repr(name)}, + None, + inputStream, + None, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + elif protocol_version == "v2.0": + raise ValueError( + "It is expected that protocol 
v2 uploads have both init and input " + "defined, otherwise it's no different than rpc", + ) + else: + assert_never(protocol_version) + elif procedure.type == "stream": + assert output_meta + assert error_meta + _, output_type, parse_output_method = output_meta + _, error_type, parse_error_method = error_meta + error_type_name = extract_inner_type(error_type) + + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + + output_or_error_type = UnionTypeExpr( + [ + output_or_error_type, + TypeName("RiverError"), + ] + ) + if init_meta and input_meta: + _, init_type, render_init_method = init_meta + _, input_type, render_input_method = input_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + init: {render_type_expr(init_type)}, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_stream( + {repr(schema_name)}, + {repr(name)}, + init, + inputStream, + {reindent(" ", render_init_method)}, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + elif protocol_version == "v1.1": + assert input_meta, "stream requires input to be defined" + _, input_type, render_input_method = input_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + async def {name}( + self, + inputStream: AsyncIterable[{render_type_expr(input_type)}], + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_stream( + {repr(schema_name)}, + {repr(name)}, + None, + inputStream, + None, + {reindent(" ", render_input_method)}, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + elif protocol_version == "v2.0": + assert init_meta, "Protocol v2 requires init to be defined" + _, init_type, render_init_method = init_meta + current_chunks.extend( + [ + reindent( + " ", + f"""\ + 
async def {name}( + self, + init: {render_type_expr(init_type)}, + ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: + return self.client.send_stream( + {repr(schema_name)}, + {repr(name)}, + init, + None, + {reindent(" ", render_init_method)}, + None, + {reindent(" ", parse_output_method)}, + {reindent(" ", parse_error_method)}, + ) + """, + ) + ] + ) + else: + assert_never(protocol_version) + + current_chunks.append("") + return current_chunks + + def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], method_filter: set[str] | None, + protocol_version: Literal["v1.1", "v2.0"], ) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] @@ -809,12 +1099,19 @@ def _type_adapter_definition( ], ) + client_module: str + match protocol_version: + case "v1.1": + client_module = "river" + case "v2.0": + client_module = "river.v2" + class_name = ClassName(f"{schema_name.title()}Service") current_chunks: list[str] = [ dedent( f"""\ class {class_name}: - def __init__(self, client: river.Client[Any]): + def __init__(self, client: {client_module}.Client[Any]): self.client = client """ ), @@ -824,8 +1121,9 @@ def __init__(self, client: river.Client[Any]): continue module_names = [ModuleName(name)] init_type: TypeExpression | None = None + init_module_info: list[ModuleName] = [] if procedure.init: - init_type, module_info, init_chunks, encoder_names = encode_type( + init_type, init_module_info, init_chunks, encoder_names = encode_type( procedure.init, TypeName(f"{name.title()}Init"), input_base_class, @@ -835,34 +1133,29 @@ def __init__(self, client: river.Client[Any]): serdes.append( ( [extract_inner_type(init_type), *encoder_names], - module_info, + init_module_info, init_chunks, ) ) - input_type, module_info, input_chunks, encoder_names = encode_type( - procedure.input, - 
TypeName(f"{name.title()}Input"), - input_base_class, - module_names, - permit_unknown_members=False, - ) - input_type_name = extract_inner_type(input_type) - input_type_type_adapter_name = TypeName( - f"{render_literal_type(input_type_name)}TypeAdapter" - ) - serdes.append( - ( - [extract_inner_type(input_type), *encoder_names], - module_info, - input_chunks, + input_type: TypeExpression | None = None + input_module_info: list[ModuleName] = [] + if procedure.input: + input_type, input_module_info, input_chunks, encoder_names = encode_type( + procedure.input, + TypeName(f"{name.title()}Input"), + input_base_class, + module_names, + permit_unknown_members=False, ) - ) - serdes.append( - _type_adapter_definition( - input_type_type_adapter_name, input_type, module_info + serdes.append( + ( + [extract_inner_type(input_type), *encoder_names], + input_module_info, + input_chunks, + ) ) - ) - output_type, module_info, output_chunks, encoder_names = encode_type( + + output_type, output_module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), "BaseModel", @@ -873,7 +1166,7 @@ def __init__(self, client: river.Client[Any]): serdes.append( ( [output_type_name, *encoder_names], - module_info, + output_module_info, output_chunks, ) ) @@ -882,12 +1175,12 @@ def __init__(self, client: river.Client[Any]): ) serdes.append( _type_adapter_definition( - output_type_type_adapter_name, output_type, module_info + output_type_type_adapter_name, output_type, output_module_info ) ) - output_module_info = module_info + if procedure.errors: - error_type, module_info, errors_chunks, encoder_names = encode_type( + error_type, error_module_info, errors_chunks, encoder_names = encode_type( procedure.errors, TypeName(f"{name.title()}Errors"), "RiverError", @@ -899,7 +1192,7 @@ def __init__(self, client: river.Client[Any]): error_type = error_type_name else: error_type_name = extract_inner_type(error_type) - serdes.append(([error_type_name], 
module_info, errors_chunks)) + serdes.append(([error_type_name], error_module_info, errors_chunks)) else: error_type_name = TypeName("RiverError") @@ -909,14 +1202,11 @@ def __init__(self, client: river.Client[Any]): f"{render_literal_type(error_type_name)}TypeAdapter" ) if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": - if len(module_info) == 0: - module_info = output_module_info serdes.append( _type_adapter_definition( - error_type_type_adapter_name, error_type, module_info + error_type_type_adapter_name, error_type, output_module_info ) ) - output_or_error_type = UnionTypeExpr([output_type, error_type_name]) # NB: These strings must be indented to at least the same level of # the function strings in the branches below, otherwise `dedent` @@ -948,9 +1238,9 @@ def __init__(self, client: river.Client[Any]): ) and procedure.init.type in ["array"]: match init_type: case ListTypeExpr(init_type_name): - render_init_method = ( - f"lambda xs: [encode_{init_type_name}(x) for x in xs]" - ) + render_init_method = f"lambda xs: [encode_{ + render_literal_type(init_type_name) + }(x) for x in xs]" else: render_init_method = f"encode_{render_literal_type(init_type)}" else: @@ -960,7 +1250,7 @@ def __init__(self, client: river.Client[Any]): ) serdes.append( _type_adapter_definition( - init_type_type_adapter_name, init_type, module_info + init_type_type_adapter_name, init_type, init_module_info ) ) render_init_method = f"""\ @@ -968,206 +1258,79 @@ def __init__(self, client: river.Client[Any]): .validate_python(x) """ - assert init_type is None or render_init_method, ( - f"Unable to derive the init encoder from: {input_type}" - ) - # Input renderer render_input_method: str | None = None - if input_base_class == "TypedDict": - if is_literal(procedure.input): - render_input_method = "lambda x: x" - elif isinstance( - procedure.input, RiverConcreteType - ) and procedure.input.type in ["array"]: - match input_type: - case ListTypeExpr(list_type): - render_input_method 
= f"""\ - lambda xs: [ - encode_{render_literal_type(list_type)}(x) for x in xs - ] - """ - else: - render_input_method = f"encode_{render_literal_type(input_type)}" - else: - render_input_method = f"""\ - lambda x: {render_type_expr(input_type_type_adapter_name)} - .dump_python( - x, # type: ignore[arg-type] - by_alias=True, - exclude_none=True, - ) + if input_type and procedure.input is not None: + if input_base_class == "TypedDict": + if is_literal(procedure.input): + render_input_method = "lambda x: x" + elif isinstance( + procedure.input, RiverConcreteType + ) and procedure.input.type in ["array"]: + match input_type: + case ListTypeExpr(list_type): + render_input_method = f"""\ + lambda xs: [ + encode_{render_literal_type(list_type)}(x) for x in xs + ] """ - - assert render_input_method, ( - f"Unable to derive the input encoder from: {input_type}" - ) - + else: + render_input_method = f"encode_{render_literal_type(input_type)}" + else: + input_type_name = extract_inner_type(input_type) + input_type_type_adapter = TypeName( + f"{render_literal_type(input_type_name)}TypeAdapter" + ) + serdes.append( + _type_adapter_definition( + input_type_type_adapter, input_type, input_module_info + ) + ) + render_input_method = f"""\ + lambda x: {render_type_expr(input_type_type_adapter)} + .dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ) + """ if isinstance(output_type, NoneTypeExpr): parse_output_method = "lambda x: None" - if procedure.type == "rpc": - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - input: {render_type_expr(input_type)}, - timeout: datetime.timedelta, - ) -> {render_type_expr(output_type)}: - return await self.client.send_rpc( - {repr(schema_name)}, - {repr(name)}, - input, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - timeout, - ) - """, - ) - ] - ) - elif procedure.type == "subscription": - 
output_or_error_type = UnionTypeExpr( - [ - output_or_error_type, - TypeName("RiverError"), - ] - ) - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - input: {render_type_expr(input_type)}, - ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return self.client.send_subscription( - {repr(schema_name)}, - {repr(name)}, - input, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - elif procedure.type == "upload": - if init_type: - assert render_init_method, "Expected an init renderer!" - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - init: {render_type_expr(init_type)}, - inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> {render_type_expr(output_type)}: - return await self.client.send_upload( - {repr(schema_name)}, - {repr(name)}, - init, - inputStream, - {reindent(" ", render_init_method)}, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] + def combine_or_none( + proc_type: RiverType | None, + tpe: TypeExpression | None, + serde_method: str | None, + ) -> tuple[RiverType, TypeExpression, str] | None: + if not proc_type and not tpe and not serde_method: + return None + if not proc_type or not tpe or not serde_method: + raise ValueError( + f"Unable to convert {repr(proc_type)} into either" + f" tpe={tpe} or render_method={serde_method}" ) - else: - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> {render_type_expr(output_or_error_type)}: - return await self.client.send_upload( - {repr(schema_name)}, - {repr(name)}, - None, - inputStream, - None, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - elif 
procedure.type == "stream": - output_or_error_type = UnionTypeExpr( - [ - output_or_error_type, - TypeName("RiverError"), - ] + return (proc_type, tpe, serde_method) + + current_chunks.extend( + render_library_call( + schema_name=schema_name, + name=name, + procedure=procedure, + protocol_version=protocol_version, + init_meta=combine_or_none( + procedure.init, init_type, render_init_method + ), + input_meta=combine_or_none( + procedure.input, input_type, render_input_method + ), + output_meta=combine_or_none( + procedure.output, output_type, parse_output_method + ), + error_meta=combine_or_none( + procedure.errors, error_type, parse_error_method + ), ) - if init_type: - assert render_init_method, "Expected an init renderer!" - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - init: {render_type_expr(init_type)}, - inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return self.client.send_stream( - {repr(schema_name)}, - {repr(name)}, - init, - inputStream, - {reindent(" ", render_init_method)}, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - else: - current_chunks.extend( - [ - reindent( - " ", - f"""\ - async def {name}( - self, - inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return self.client.send_stream( - {repr(schema_name)}, - {repr(name)}, - None, - inputStream, - None, - {reindent(" ", render_input_method)}, - {reindent(" ", parse_output_method)}, - {reindent(" ", parse_error_method)}, - ) - """, - ) - ] - ) - - current_chunks.append("") + ) emitted_files: dict[RenderedPath, FileContents] = {} @@ -1209,6 +1372,7 @@ def generate_river_client_module( schema_root: RiverSchema, typed_dict_inputs: bool, method_filter: set[str] | None, + protocol_version: Literal["v1.1", "v2.0"], ) 
-> dict[RenderedPath, FileContents]: files: dict[RenderedPath, FileContents] = {} @@ -1237,6 +1401,7 @@ def generate_river_client_module( schema, input_base_class, method_filter, + protocol_version, ) if emitted_files: # Short-cut if we didn't actually emit anything @@ -1244,7 +1409,11 @@ def generate_river_client_module( modules.append((module_name, class_name)) main_contents = generate_common_client( - client_name, handshake_type, handshake_chunks, modules + client_name, + handshake_type, + handshake_chunks, + modules, + protocol_version, ) files[RenderedPath(str(Path("__init__.py")))] = main_contents @@ -1258,6 +1427,7 @@ def schema_to_river_client_codegen( typed_dict_inputs: bool, file_opener: Callable[[Path], TextIO], method_filter: set[str] | None, + protocol_version: Literal["v1.1", "v2.0"], ) -> None: """Generates the lines of a River module.""" with read_schema() as f: @@ -1267,6 +1437,7 @@ def schema_to_river_client_codegen( schemas.root, typed_dict_inputs, method_filter, + protocol_version, ).items(): module_path = Path(target_path).joinpath(subpath) module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index 16d94a5e..5a69f1dd 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -45,6 +45,13 @@ def main() -> None: action="store", type=pathlib.Path, ) + client.add_argument( + "--protocol-version", + help="Generate river v2 clients", + action="store", + default="v1.1", + choices=["v1.1", "v2.0"], + ) client.add_argument("schema", help="schema file") args = parser.parse_args() @@ -69,12 +76,13 @@ def file_opener(path: Path) -> TextIO: method_filter = set(x.strip() for x in handle.readlines()) schema_to_river_client_codegen( - lambda: open(schema_path), - target_path, - args.client_name, - args.typed_dict_inputs, - file_opener, + read_schema=lambda: open(schema_path), + target_path=target_path, + client_name=args.client_name, + 
typed_dict_inputs=args.typed_dict_inputs, + file_opener=file_opener, method_filter=method_filter, + protocol_version=args.protocol_version, ) else: raise NotImplementedError(f"Unknown command {args.command}") diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index 53c028ff..68443ffa 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -210,19 +210,23 @@ def extract_inner_type(value: TypeExpression) -> TypeName: case ListTypeExpr(nested): return extract_inner_type(nested) case LiteralTypeExpr(_): - raise ValueError(f"Unexpected literal type: {value}") + raise ValueError(f"Unexpected literal type: {repr(value)}") case UnionTypeExpr(_): raise ValueError( - f"Attempting to extract from a union, currently not possible: {value}" + "Attempting to extract from a union, " + f"currently not possible: {repr(value)}" ) case OpenUnionTypeExpr(_): raise ValueError( - f"Attempting to extract from a union, currently not possible: {value}" + "Attempting to extract from a union, " + f"currently not possible: {repr(value)}" ) case TypeName(name): return TypeName(name) case NoneTypeExpr(): - raise ValueError(f"Attempting to extract from a literal 'None': {value}") + raise ValueError( + f"Attempting to extract from a literal 'None': {repr(value)}", + ) case other: assert_never(other) @@ -233,5 +237,5 @@ def ensure_literal_type(value: TypeExpression) -> TypeName: return TypeName(name) case other: raise ValueError( - f"Unexpected expression when expecting a type name: {other}" + f"Unexpected expression when expecting a type name: {repr(other)}" ) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 2325492e..3dce6138 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -1,13 +1,23 @@ +import asyncio import enum import logging -from typing import Any, Protocol +from typing import Any, Awaitable, Callable, Coroutine, Protocol from 
opentelemetry.trace import Span +from websockets import WebSocketCommonProtocol +from websockets.asyncio.client import ClientConnection + +from replit_river.messages import ( + FailedSendingMessageException, + WebsocketClosedException, + send_transport_message, +) +from replit_river.rpc import TransportMessage logger = logging.getLogger(__name__) -class SendMessage(Protocol): +class SendMessage[Result](Protocol): async def __call__( self, *, @@ -17,24 +27,91 @@ async def __call__( service_name: str | None, procedure_name: str | None, span: Span | None, - ) -> None: ... + ) -> Result: ... class SessionState(enum.Enum): """The state a session can be in. Valid transitions: - - NO_CONNECTION -> {ACTIVE} - - ACTIVE -> {NO_CONNECTION, CLOSING} + - NO_CONNECTION -> {CONNECTING, CLOSING} + - CONNECTING -> {NO_CONNECTION, ACTIVE, CLOSING} + - ACTIVE -> {NO_CONNECTION, CONNECTING, CLOSING} - CLOSING -> {CLOSED} - CLOSED -> {} """ NO_CONNECTION = 0 - ACTIVE = 1 - CLOSING = 2 - CLOSED = 3 + CONNECTING = 1 + ACTIVE = 2 + CLOSING = 3 + CLOSED = 4 -ConnectingStates = set([SessionState.NO_CONNECTION]) +ConnectingStates = set([SessionState.NO_CONNECTION, SessionState.CONNECTING]) +ActiveStates = set([SessionState.ACTIVE]) TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED]) + + +async def buffered_message_sender( + block_until_connected: Callable[[], Awaitable[None]], + block_until_message_available: Callable[[], Awaitable[None]], + get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None], + websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], + get_next_pending: Callable[[], TransportMessage | None], + commit: Callable[[TransportMessage], Awaitable[None]], + get_state: Callable[[], SessionState], +) -> None: + """ + buffered_message_sender runs in a task and consumes from a queue, emitting + messages over the websocket as quickly as it can. 
+ + One of the design goals is to keep the message queue as short as possible to permit + quickly cancelling streams or acking heartbeats, so to that end it is wise to + incorporate backpressure into the lifecycle of get_next_pending/commit. + """ + + our_task = asyncio.current_task() + while our_task and not our_task.cancelling() and not our_task.cancelled(): + while get_state() in ConnectingStates: + # Block until we have a handle + logger.debug( + "_buffered_message_sender: Waiting until ws is connected", + ) + await block_until_connected() + + if get_state() in TerminalStates: + logger.debug("_buffered_message_sender: closing") + return + + await block_until_message_available() + + if not (ws := get_ws()): + logger.debug("_buffered_message_sender: ws is not connected, loop") + continue + + if msg := get_next_pending(): + logger.debug( + "_buffered_message_sender: Dequeued %r to send over %r", + msg, + ws, + ) + try: + await send_transport_message(msg, ws, websocket_closed_callback) + await commit(msg) + logger.debug("_buffered_message_sender: Sent %r", msg.id) + except WebsocketClosedException as e: + logger.debug( + "_buffered_message_sender: Connection closed while sending " + "message %r, waiting for reconnect and retry from buffer", + type(e), + exc_info=e, + ) + except FailedSendingMessageException: + logger.error( + "Failed sending message, " + "waiting for reconnect and retry from buffer", + exc_info=True, + ) + except Exception: + logger.exception("Error attempting to send buffered messages") diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index af5837dd..5bff801a 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -17,9 +17,16 @@ # ERROR_CODE_CANCEL is the code used when either server or client cancels the stream. ERROR_CODE_CANCEL = "CANCEL" +# SYNTHETIC_ERROR_CODE_SESSION_CLOSED is a synthetic code emitted exclusively by the +# client's session. It is not sent over the wire. 
+SYNTHETIC_ERROR_CODE_SESSION_CLOSED = "SESSION_CLOSED" + # ERROR_CODE_UNKNOWN is the code for the RiverUnknownError ERROR_CODE_UNKNOWN = "UNKNOWN" +# SESSION_STATE_MISMATCH is the code when the remote server rejects the session's state +ERROR_CODE_SESSION_STATE_MISMATCH = "SESSION_STATE_MISMATCH" + class RiverError(BaseModel): """Error message from the server.""" @@ -75,6 +82,14 @@ class StreamClosedRiverServiceException(RiverServiceException): pass +class SessionClosedRiverServiceException(RiverException): + def __init__( + self, + message: str, + ) -> None: + super().__init__(SYNTHETIC_ERROR_CODE_SESSION_CLOSED, message) + + def exception_from_message(code: str) -> type[RiverServiceException]: """Return the error class for a given error code.""" if code == ERROR_CODE_STREAM_CLOSED: diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index 6e1fdad7..1ced8582 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -17,6 +17,7 @@ class MessageBuffer: def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): self.max_size = max_num_messages self.buffer: list[TransportMessage] = [] + self._has_messages = asyncio.Event() self._space_available_cond = asyncio.Condition() self._closed = False @@ -35,6 +36,7 @@ def put(self, message: TransportMessage) -> None: if self._closed: raise MessageBufferClosedError("message buffer is closed") self.buffer.append(message) + self._has_messages.set() def get_next_sent_seq(self) -> int | None: if self.buffer: @@ -50,13 +52,23 @@ def peek(self) -> TransportMessage | None: async def remove_old_messages(self, min_seq: int) -> None: """Remove messages in the buffer with a seq number less than min_seq.""" self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] + if self.buffer: + self._has_messages.set() + else: + self._has_messages.clear() async with self._space_available_cond: self._space_available_cond.notify_all() + async def 
block_until_message_available(self) -> None: + """Allow consumers to avoid spinning unnecessarily""" + await self._has_messages.wait() + async def close(self) -> None: """ Closes the message buffer and rejects any pending put operations. """ self._closed = True + # Wake up block_until_message_available to permit graceful cleanup + self._has_messages.set() async with self._space_available_cond: self._space_available_cond.notify_all() diff --git a/src/replit_river/messages.py b/src/replit_river/messages.py index 9cdf324a..653929a6 100644 --- a/src/replit_river/messages.py +++ b/src/replit_river/messages.py @@ -8,6 +8,7 @@ from websockets import ( WebSocketCommonProtocol, ) +from websockets.asyncio.client import ClientConnection from replit_river.rpc import ( TransportMessage, @@ -29,7 +30,7 @@ class FailedSendingMessageException(Exception): async def send_transport_message( msg: TransportMessage, - ws: WebSocketCommonProtocol, + ws: WebSocketCommonProtocol | ClientConnection, # legacy | asyncio websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]], ) -> None: logger.debug("sending a message %r to ws %s", msg, ws) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index b9265eee..5e742ce9 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -1,9 +1,19 @@ import asyncio +import logging import random from contextvars import Context +from replit_river.error_schema import RiverException from replit_river.transport_options import ConnectionRetryOptions +logger = logging.getLogger(__name__) + + +class BudgetExhaustedException(RiverException): + def __init__(self, code: str, message: str, client_id: str) -> None: + super().__init__(code, message) + self.client_id = client_id + class LeakyBucketRateLimit: """Asynchronous leaky bucket rate limiter. 
diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 3459bf7d..678e8bf6 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -47,8 +47,10 @@ GenericRpcHandlerBuilder = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] ] -ACK_BIT = 0b00001 -STREAM_OPEN_BIT = 0b00010 +ACK_BIT_TYPE = Literal[0b00001] +ACK_BIT: ACK_BIT_TYPE = 0b00001 +STREAM_OPEN_BIT_TYPE = Literal[0b00010] +STREAM_OPEN_BIT: STREAM_OPEN_BIT_TYPE = 0b00010 # these codes are retriable # if the server sends a response with one of these codes, diff --git a/src/replit_river/seq_manager.py b/src/replit_river/seq_manager.py index 8a2f6798..a1baa0f3 100644 --- a/src/replit_river/seq_manager.py +++ b/src/replit_river/seq_manager.py @@ -18,7 +18,12 @@ class OutOfOrderMessageException(Exception): we close the session. """ - pass + def __init__(self, *, received_seq: int, expected_ack: int) -> None: + super().__init__( + "Out of order message received: " + f"got={received_seq}, " + f"expected={expected_ack}" + ) class SessionStateMismatchException(Exception): @@ -71,7 +76,8 @@ def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None: ) raise OutOfOrderMessageException( - f"Out of order message received got {msg.seq} expected {self.ack}" + received_seq=msg.seq, + expected_ack=self.ack, ) self.receiver_ack = msg.ack self.ack = msg.seq + 1 diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 3a931274..c397e900 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -164,7 +164,6 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: pass except RuntimeError as e: raise InvalidMessageException(e) from e - else: _stream = await self._open_stream_and_call_handler(msg, tg) if isinstance(_stream, IgnoreMessage): diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 7b15e36e..3f743e51 100644 --- 
a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -186,7 +186,7 @@ async def _send_handshake_response( response_message = TransportMessage( streamId=request_message.streamId, id=nanoid.generate(), - from_=request_message.to, # type: ignore + from_=request_message.to, to=request_message.from_, seq=0, ack=0, diff --git a/src/replit_river/session.py b/src/replit_river/session.py index ac01ffba..465a6672 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -24,7 +24,7 @@ ) from replit_river.task_manager import BackgroundTaskManager from replit_river.transport_options import TransportOptions -from replit_river.websocket_wrapper import WebsocketWrapper +from replit_river.websocket_wrapper import WebsocketWrapper, WsState from .rpc import ( ACK_BIT, @@ -120,7 +120,11 @@ def increment_and_get_heartbeat_misses() -> int: self.session_id, self._transport_options.heartbeat_ms, self._transport_options.heartbeats_until_dead, - lambda: self._state, + lambda: ( + self._state + if self._ws_wrapper.ws_state == WsState.OPEN + else SessionState.CONNECTING + ), lambda: self._close_session_after_time_secs, close_websocket=do_close_websocket, send_message=self.send_message, @@ -232,7 +236,7 @@ async def send_message( msg = TransportMessage( streamId=stream_id, id=nanoid.generate(), - from_=self._transport_id, # type: ignore + from_=self._transport_id, to=self._to_id, seq=self._seq_manager.get_seq_and_increment(), ack=self._seq_manager.get_ack(), @@ -346,7 +350,7 @@ async def setup_heartbeat( get_state: Callable[[], SessionState], get_closing_grace_period: Callable[[], float | None], close_websocket: Callable[[], Awaitable[None]], - send_message: SendMessage, + send_message: SendMessage[None], increment_and_get_heartbeat_misses: Callable[[], int], ) -> None: while True: diff --git a/src/replit_river/transport_options.py b/src/replit_river/transport_options.py index 47032bac..09f000f8 100644 --- 
a/src/replit_river/transport_options.py +++ b/src/replit_river/transport_options.py @@ -27,6 +27,7 @@ class TransportOptions(BaseModel): connection_retry_options: ConnectionRetryOptions = ConnectionRetryOptions() buffer_size: int = 1_000 transparent_reconnect: bool = True + shutdown_all_streams_timeout_ms: float = 10_000 def websocket_disconnect_grace_ms(self) -> float: return self.heartbeat_ms * self.heartbeats_until_dead @@ -39,11 +40,16 @@ def create_from_env(cls) -> "TransportOptions": ) heartbeat_ms = float(os.getenv("HEARTBEAT_MS", 2_000)) heartbeats_to_dead = int(os.getenv("HEARTBEATS_UNTIL_DEAD", 2)) + shutdown_all_streams_timeout_ms = float( + os.getenv("SHUTDOWN_STREAMS_TIMEOUT_MS", 10_000) + ) + return TransportOptions( handshake_timeout_ms=handshake_timeout_ms, session_disconnect_grace_ms=session_disconnect_grace_ms, heartbeat_ms=heartbeat_ms, heartbeats_until_dead=heartbeats_to_dead, + shutdown_all_streams_timeout_ms=shutdown_all_streams_timeout_ms, ) diff --git a/src/replit_river/v2/__init__.py b/src/replit_river/v2/__init__.py new file mode 100644 index 00000000..a9b0c7ee --- /dev/null +++ b/src/replit_river/v2/__init__.py @@ -0,0 +1,7 @@ +from .client import Client +from .session import Session + +__all__ = [ + "Client", + "Session", +] diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py new file mode 100644 index 00000000..1b900d38 --- /dev/null +++ b/src/replit_river/v2/client.py @@ -0,0 +1,206 @@ +import logging +from collections.abc import AsyncIterable, Awaitable, Callable +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, AsyncGenerator, Generator, Generic, Literal + +from opentelemetry import trace +from opentelemetry.trace import Span, SpanKind, Status, StatusCode + +from replit_river.error_schema import RiverError, RiverException +from replit_river.transport_options import ( + HandshakeMetadataType, + TransportOptions, + UriAndMetadata, +) 
+from replit_river.v2.client_transport import ClientTransport + +logger = logging.getLogger(__name__) +tracer = trace.get_tracer(__name__) + + +class Client(Generic[HandshakeMetadataType]): + def __init__( + self, + uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadataType]] + ], + client_id: str, + server_id: str, + transport_options: TransportOptions, + ) -> None: + self._client_id = client_id + self._server_id = server_id + self._transport = ClientTransport[HandshakeMetadataType]( + uri_and_metadata_factory=uri_and_metadata_factory, + client_id=client_id, + server_id=server_id, + transport_options=transport_options, + ) + + async def close(self) -> None: + logger.info(f"river client {self._client_id} start closing") + await self._transport.close() + logger.info(f"river client {self._client_id} closed") + + async def ensure_connected(self) -> None: + await self._transport.get_or_create_session() + + async def send_rpc[R, A]( + self, + service_name: str, + procedure_name: str, + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], + timeout: timedelta, + ) -> A: + with _trace_procedure("rpc", service_name, procedure_name) as span_handle: + session = await self._transport.get_or_create_session() + return await session.send_rpc( + service_name, + procedure_name, + request, + request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + timeout, + ) + + async def send_upload[I, R, A]( + self, + service_name: str, + procedure_name: str, + init: I, + request: AsyncIterable[R], + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], + ) -> A: + with _trace_procedure("upload", service_name, procedure_name) as span_handle: + session = await self._transport.get_or_create_session() + return await 
session.send_upload( + service_name, + procedure_name, + init, + request, + init_serializer, + request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + ) + + async def send_subscription[R, E, A]( + self, + service_name: str, + procedure_name: str, + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], + ) -> AsyncGenerator[A | E, None]: + with _trace_procedure( + "subscription", service_name, procedure_name + ) as span_handle: + session = await self._transport.get_or_create_session() + async for msg in session.send_subscription( + service_name, + procedure_name, + request, + request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + ): + if isinstance(msg, RiverError): + _record_river_error(span_handle, msg) + yield msg # type: ignore # https://github.com/python/mypy/issues/10817 + + async def send_stream[I, R, E, A]( + self, + service_name: str, + procedure_name: str, + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], + ) -> AsyncGenerator[A | E, None]: + with _trace_procedure("stream", service_name, procedure_name) as span_handle: + session = await self._transport.get_or_create_session() + async for msg in session.send_stream( + service_name, + procedure_name, + init, + request, + init_serializer, + request_serializer, + response_deserializer, + error_deserializer, + span_handle.span, + ): + if isinstance(msg, RiverError): + _record_river_error(span_handle, msg) + yield msg # type: ignore # https://github.com/python/mypy/issues/10817 + + +@dataclass +class _SpanHandle: + """Wraps a span and keeps track of whether or not a status has been recorded yet.""" + + span: Span + did_set_status: bool = False + + def set_status( + self, + status: Status | 
StatusCode, + description: str | None = None, + ) -> None: + if self.did_set_status: + return + self.did_set_status = True + self.span.set_status(status, description) + + +@contextmanager +def _trace_procedure( + procedure_type: Literal["rpc", "upload", "subscription", "stream"], + service_name: str, + procedure_name: str, +) -> Generator[_SpanHandle, None, None]: + span = tracer.start_span( + f"river.client.{procedure_type}.{service_name}.{procedure_name}", + kind=SpanKind.CLIENT, + ) + span_handle = _SpanHandle(span) + try: + yield span_handle + span_handle.set_status(StatusCode.OK) + except GeneratorExit: + # This error indicates the caller is done with the async generator + # but messages are still left. This is okay, we do not consider it an error. + span_handle.set_status(StatusCode.OK) + raise + except RiverException as e: + span.record_exception(e, escaped=True) + _record_river_error(span_handle, RiverError(code=e.code, message=e.message)) + raise e + except BaseException as e: + span.record_exception(e, escaped=True) + span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") + raise e + finally: + span.end() + + +def _record_river_error(span_handle: _SpanHandle, error: RiverError) -> None: + span_handle.set_status(StatusCode.ERROR, error.message) + span_handle.span.record_exception(RiverException(error.code, error.message)) + span_handle.span.set_attribute("river.error_code", error.code) + span_handle.span.set_attribute("river.error_message", error.message) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py new file mode 100644 index 00000000..3dc96522 --- /dev/null +++ b/src/replit_river/v2/client_transport.py @@ -0,0 +1,96 @@ +import logging +from collections.abc import Awaitable, Callable +from typing import Generic + +import nanoid + +from replit_river.rate_limiter import LeakyBucketRateLimit +from replit_river.transport_options import ( + HandshakeMetadataType, + TransportOptions, + 
UriAndMetadata, +) +from replit_river.v2.session import Session + +logger = logging.getLogger(__name__) + + +class ClientTransport(Generic[HandshakeMetadataType]): + _session: Session | None + + def __init__( + self, + uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]], + client_id: str, + server_id: str, + transport_options: TransportOptions, + ): + self._session = None + self._transport_id = nanoid.generate() + self._transport_options = transport_options + + self._uri_and_metadata_factory = uri_and_metadata_factory + self._client_id = client_id + self._server_id = server_id + self._rate_limiter = LeakyBucketRateLimit( + transport_options.connection_retry_options + ) + + async def close(self) -> None: + self._rate_limiter.close() + if self._session: + await self._session.close() + logger.info( + "Transport closed", + extra={ + "client_id": self._client_id, + "transport_id": self._transport_id, + }, + ) + + async def get_or_create_session(self) -> Session: + """ + Create a session if it does not exist, + call ensure_connected on whatever session is active. 
+ """ + existing_session = self._session + if not existing_session or existing_session.is_closed(): + logger.info("Creating new session") + new_session = Session( + client_id=self._client_id, + server_id=self._server_id, + session_id=nanoid.generate(), + transport_options=self._transport_options, + close_session_callback=self._delete_session, + retry_connection_callback=self._retry_connection, + uri_and_metadata_factory=self._uri_and_metadata_factory, + rate_limiter=self._rate_limiter, + ) + + self._session = new_session + existing_session = new_session + + await existing_session.ensure_connected() + return existing_session + + async def _retry_connection(self) -> Session: + if self._session and not self._transport_options.transparent_reconnect: + logger.info("transparent_reconnect not set, closing {self._transport_id}") + await self._session.close() + logger.debug("Triggering get_or_create_session") + return await self.get_or_create_session() + + async def _delete_session(self, session: Session) -> None: + if self._session is session: + self._session = None + else: + logger.warning( + "Session attempted to close itself but it was not the " + "active session, doing nothing", + extra={ + "client_id": self._client_id, + "transport_id": self._transport_id, + "active_session_id": self._session and self._session.session_id, + "orphan_session_id": session.session_id, + }, + ) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py new file mode 100644 index 00000000..46ab151d --- /dev/null +++ b/src/replit_river/v2/session.py @@ -0,0 +1,1358 @@ +import asyncio +import logging +from collections import deque +from collections.abc import AsyncIterable +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Coroutine, + Literal, + NotRequired, + TypeAlias, + TypedDict, + assert_never, +) + +import nanoid 
+import websockets.asyncio.client +from aiochannel import Channel, ChannelEmpty, ChannelFull +from aiochannel.errors import ChannelClosed +from opentelemetry.trace import Span, use_span +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from pydantic import ValidationError +from websockets.asyncio.client import ClientConnection +from websockets.exceptions import ConnectionClosed + +from replit_river.common_session import ( + ActiveStates, + ConnectingStates, + SendMessage, + SessionState, + TerminalStates, + buffered_message_sender, +) +from replit_river.error_schema import ( + ERROR_CODE_CANCEL, + ERROR_CODE_SESSION_STATE_MISMATCH, + ERROR_CODE_STREAM_CLOSED, + ERROR_HANDSHAKE, + RiverError, + RiverException, + RiverServiceException, + SessionClosedRiverServiceException, + exception_from_message, +) +from replit_river.messages import ( + FailedSendingMessageException, + WebsocketClosedException, + parse_transport_msg, + send_transport_message, +) +from replit_river.rate_limiter import LeakyBucketRateLimit +from replit_river.rpc import ( + ACK_BIT, + STREAM_OPEN_BIT, + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + ExpectedSessionState, + TransportMessage, + TransportMessageTracingSetter, +) +from replit_river.seq_manager import ( + InvalidMessageException, + OutOfOrderMessageException, +) +from replit_river.task_manager import BackgroundTaskManager +from replit_river.transport_options import ( + MAX_MESSAGE_BUFFER_SIZE, + TransportOptions, + UriAndMetadata, +) + +PROTOCOL_VERSION = "v2.0" + +STREAM_CANCEL_BIT_TYPE = Literal[0b00100] +STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100 +STREAM_CLOSED_BIT_TYPE = Literal[0b01000] +STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 + + +_BackpressuredWaiter: TypeAlias = Callable[[], Awaitable[None]] + + +class ResultOk(TypedDict): + ok: Literal[True] + payload: Any + + +class ErrorPayload(TypedDict): + code: str + message: str + + +class 
ResultError(TypedDict): + # Account for structurally incoherent payloads + ok: NotRequired[Literal[False]] + payload: ErrorPayload + + +ResultType: TypeAlias = ResultOk | ResultError + +logger = logging.getLogger(__name__) + +trace_propagator = TraceContextTextMapPropagator() +trace_setter = TransportMessageTracingSetter() + +CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]] +RetryConnectionCallback: TypeAlias = Callable[ + [], + Coroutine[Any, Any, Any], +] + + +@dataclass +class _IgnoreMessage: + pass + + +class StreamMeta(TypedDict): + span: Span + release_backpressured_waiter: Callable[[], None] + error_channel: Channel[Exception] + output: Channel[ResultType] + + +class Session[HandshakeMetadata]: + _server_id: str + session_id: str + _transport_options: TransportOptions + + # session state, only modified during closing + _state: SessionState + _close_session_callback: CloseSessionCallback + _close_session_after_time_secs: float | None + _connecting_task: asyncio.Task[None] | None + _wait_for_connected: asyncio.Event + + _client_id: str + _rate_limiter: LeakyBucketRateLimit + _uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadata]] + ] + + # ws state + _ws: ClientConnection | None + _retry_connection_callback: RetryConnectionCallback | None + + # message state + _process_messages: asyncio.Event + _space_available: asyncio.Event + + # stream for tasks + _streams: dict[str, StreamMeta] + + # book keeping + _ack_buffer: deque[TransportMessage] + _send_buffer: deque[TransportMessage] + _task_manager: BackgroundTaskManager + ack: int # Most recently acknowledged seq + seq: int # Last sent sequence number + + # Terminating + _terminating_task: asyncio.Task[None] | None + + def __init__( + self, + server_id: str, + session_id: str, + transport_options: TransportOptions, + close_session_callback: CloseSessionCallback, + client_id: str, + rate_limiter: LeakyBucketRateLimit, + uri_and_metadata_factory: 
Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadata]] + ], + retry_connection_callback: RetryConnectionCallback | None = None, + ) -> None: + self._server_id = server_id + self.session_id = session_id + self._transport_options = transport_options + + # session state + self._state = SessionState.NO_CONNECTION + self._close_session_callback = close_session_callback + self._close_session_after_time_secs: float | None = None + self._connecting_task = None + self._wait_for_connected = asyncio.Event() + + self._client_id = client_id + # TODO: LeakyBucketRateLimit accepts "user" for all methods, which has + # historically been and continues to be "client_id". + # + # There's 1:1 client <-> transport, which means LeakyBucketRateLimit is only + # tracking exactly one rate limit. + # + # The "user" parameter is YAGNI, dethread client_id after v1 is deleted. + self._rate_limiter = rate_limiter + self._uri_and_metadata_factory = uri_and_metadata_factory + + # ws state + self._ws = None + self._retry_connection_callback = retry_connection_callback + + # message state + self._process_messages = asyncio.Event() + self._space_available = asyncio.Event() + # Ensure we initialize the above Event to "set" to avoid being blocked from + # the beginning. + self._space_available.set() + + # stream for tasks + self._streams: dict[str, StreamMeta] = {} + + # book keeping + self._ack_buffer = deque() + self._send_buffer = deque() + self._task_manager = BackgroundTaskManager() + self.ack = 0 + self.seq = 0 + + # Terminating + self._terminating_task = None + + self._start_recv_from_ws() + self._start_buffered_message_sender() + + async def ensure_connected(self) -> None: + """ + Either return immediately or establish a websocket connection and return + once we can accept messages. + + One of the goals of this function is to gate exactly one call to the + logic that actually establishes the connection. 
+ """ + + logger.debug("ensure_connected: is_connected=%r", self.is_connected()) + if self.is_connected(): + return + + def get_next_sent_seq() -> int: + if self._send_buffer: + return self._send_buffer[0].seq + return self.seq + + def close_session(reason: Exception | None) -> None: + # If we're already closing, just let whoever's currently doing it handle it. + if self._state in TerminalStates: + return + + # Avoid closing twice + if self._terminating_task is None: + current_state = self._state + self._state = SessionState.CLOSING + + # We can't just call self.close() directly because + # we're inside a thread that will eventually be awaited + # during the cleanup procedure. + + self._terminating_task = asyncio.create_task( + self.close(reason, current_state=current_state), + ) + + def transition_connecting() -> None: + if self._state in TerminalStates: + return + logger.debug("transition_connecting") + self._state = SessionState.CONNECTING + # "Clear" here means observers should wait until we are connected. + self._wait_for_connected.clear() + + def transition_connected(ws: ClientConnection) -> None: + if self._state in TerminalStates: + return + logger.debug("transition_connected") + self._state = SessionState.ACTIVE + self._ws = ws + + # We're connected, wake everybody up using set() + self._wait_for_connected.set() + + def close_ws_in_background(ws: ClientConnection) -> None: + self._task_manager.create_task(ws.close()) + + def unbind_connecting_task() -> None: + # We are in a state where we may throw an exception. + # + # To allow subsequent calls to ensure_connected to pass, we clear ourselves. + # This is safe because each individual function that is waiting on this + # function completeing already has a reference, so we'll last a few ticks + # before GC. 
+ # + # Let's do our best to avoid clobbering other tasks by comparing the .name + current_task = asyncio.current_task() + if self._connecting_task is current_task: + self._connecting_task = None + + if not self._connecting_task: + self._connecting_task = asyncio.create_task( + _do_ensure_connected( + transport_options=self._transport_options, + client_id=self._client_id, + server_id=self._server_id, + session_id=self.session_id, + rate_limiter=self._rate_limiter, + uri_and_metadata_factory=self._uri_and_metadata_factory, + get_next_sent_seq=get_next_sent_seq, + get_current_ack=lambda: self.ack, + get_current_time=self._get_current_time, + get_state=lambda: self._state, + transition_connecting=transition_connecting, + close_ws_in_background=close_ws_in_background, + transition_connected=transition_connected, + unbind_connecting_task=unbind_connecting_task, + close_session=close_session, + ) + ) + + await self._connecting_task + + def is_closed(self) -> bool: + """ + If the session is in a terminal state. + Do not send messages, do not expect any more messages to be emitted, + the state is expected to be stale. 
+ """ + return self._state in TerminalStates + + def is_connected(self) -> bool: + return self._state in ActiveStates + + async def _get_current_time(self) -> float: + return asyncio.get_event_loop().time() + + async def _enqueue_message( + self, + stream_id: str, + payload: dict[Any, Any] | str, + control_flags: int = 0, + service_name: str | None = None, + procedure_name: str | None = None, + span: Span | None = None, + ) -> None: + """Send serialized messages to the websockets.""" + logger.debug( + "_enqueue_message(stream_id=%r, payload=%r, control_flags=%r, " + "service_name=%r, procedure_name=%r)", + stream_id, + payload, + bin(control_flags), + service_name, + procedure_name, + ) + # Ensure the buffer isn't full before we enqueue + await self._space_available.wait() + + # Before we append, do an important check + if self._state in TerminalStates: + # session is closing / closed, raise + raise SessionClosedRiverServiceException( + "river session is closed, dropping message", + ) + + # Begin critical section: Avoid any await between here and _send_buffer.append + msg = TransportMessage( + streamId=stream_id, + id=nanoid.generate(), + from_=self._client_id, + to=self._server_id, + seq=self.seq, + ack=self.ack, + controlFlags=control_flags, + payload=payload, + serviceName=service_name, + procedureName=procedure_name, + ) + + if span: + with use_span(span): + trace_propagator.inject(msg, None, trace_setter) + + # We're clear to add to the send buffer + self._send_buffer.append(msg) + + # Increment immediately so we maintain consistency + self.seq += 1 + + # If the buffer is now full, reset the block + if len(self._send_buffer) >= self._transport_options.buffer_size: + self._space_available.clear() + + # Wake up buffered_message_sender + self._process_messages.set() + + async def close( + self, reason: Exception | None = None, current_state: SessionState | None = None + ) -> None: + """Close the session and all associated streams.""" + logger.info( + 
f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" + ) + if (current_state or self._state) in TerminalStates: + # already closing + return + self._state = SessionState.CLOSING + + # We're closing, so we need to wake up... + # ... tasks waiting for connection to be established + self._wait_for_connected.set() + # ... consumers waiting to enqueue messages + self._space_available.set() + # ... message processor so it can exit cleanly + self._process_messages.set() + + # Wait a tick to permit the waiting tasks to shut down gracefully + await asyncio.sleep(0.01) + + await self._task_manager.cancel_all_tasks() + + for stream_meta in self._streams.values(): + stream_meta["output"].close() + # Wake up backpressured writers + try: + stream_meta["error_channel"].put_nowait( + reason + or SessionClosedRiverServiceException( + "river session is closed", + ) + ) + except ChannelFull: + logger.exception( + "Unable to tell the caller that the session is going away", + ) + stream_meta["release_backpressured_waiter"]() + # Before we GC the streams, let's wait for all tasks to be closed gracefully. + try: + async with asyncio.timeout( + self._transport_options.shutdown_all_streams_timeout_ms + ): + # Block for backpressure and emission errors from the ws + await asyncio.gather( + *[ + stream_meta["output"].join() + for stream_meta in self._streams.values() + ] + ) + except asyncio.TimeoutError: + spans: list[Span] = [ + stream_meta["span"] + for stream_meta in self._streams.values() + if not stream_meta["output"].closed() + ] + span_ids = [span.get_span_context().span_id for span in spans] + logger.exception( + "Timeout waiting for output streams to finallize", + extra={"span_ids": span_ids}, + ) + self._streams.clear() + + if self._ws: + # The Session isn't guaranteed to live much longer than this close() + # invocation, so let's await this close to avoid dropping the socket. 
+ await self._ws.close() + + self._state = SessionState.CLOSED + + # Clear the session in transports + # This will get us GC'd, so this should be the last thing. + await self._close_session_callback(self) + + def _start_buffered_message_sender( + self, + ) -> None: + """ + Building on buffered_message_sender's documentation, we implement backpressure + per-stream by way of self._streams' + + error_channel: Channel[Exception | None] + + This is accomplished via the following strategy: + - If buffered_message_sender encounters an error, we transition back to + connecting and attempt to handshake. + + If the handshake fails, we close the session with an informative error that + gets emitted to all backpressured client methods. + + - Alternately, if buffered_message_sender successfully writes back to the + + - Finally, if _recv_from_ws encounters an error (transport or deserialization), + we emit an informative error to close_session which gets emitted to all + backpressured client methods. + """ + + async def commit(msg: TransportMessage) -> None: + pending = self._send_buffer.popleft() + if msg.seq != pending.seq: + logger.error("Out of sequence error") + self._ack_buffer.append(pending) + + # On commit... + # ... release pending writers waiting for more buffer space + self._space_available.set() + # ... 
tell the message sender to back off if there are no pending messages + if not self._send_buffer: + self._process_messages.clear() + + # Wake up backpressured writer + stream_meta = self._streams.get(pending.streamId) + if stream_meta: + stream_meta["release_backpressured_waiter"]() + + def get_next_pending() -> TransportMessage | None: + if self._send_buffer: + return self._send_buffer[0] + return None + + def get_ws() -> ClientConnection | None: + if self.is_connected(): + return self._ws + return None + + async def block_until_connected() -> None: + if self._state in TerminalStates: + return + logger.debug("block_until_connected") + await self._wait_for_connected.wait() + logger.debug("block_until_connected released!") + + async def block_until_message_available() -> None: + await self._process_messages.wait() + + self._task_manager.create_task( + buffered_message_sender( + block_until_connected=block_until_connected, + block_until_message_available=block_until_message_available, + get_ws=get_ws, + websocket_closed_callback=self.ensure_connected, + get_next_pending=get_next_pending, + commit=commit, + get_state=lambda: self._state, + ) + ) + + def _start_recv_from_ws(self) -> None: + async def transition_no_connection() -> None: + if self._state in TerminalStates: + return + self._state = SessionState.NO_CONNECTION + self._wait_for_connected.clear() + if self._ws: + self._task_manager.create_task(self._ws.close()) + self._ws = None + + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + else: + await self.ensure_connected() + + def assert_incoming_seq_bookkeeping( + msg_from: str, + msg_seq: int, + msg_ack: int, + ) -> Literal[True] | _IgnoreMessage: + # Update bookkeeping + if msg_seq < self.ack: + logger.info( + "Received duplicate msg", + extra={ + "from": msg_from, + "got_seq": msg_seq, + "expected_ack": self.ack, + }, + ) + return _IgnoreMessage() + elif msg_seq > self.ack: + logger.warning( + f"Out of 
order message received got {msg_seq} expected {self.ack}" + ) + + raise OutOfOrderMessageException( + received_seq=msg_seq, + expected_ack=self.ack, + ) + else: + # Set our next expected ack number + self.ack = msg_seq + 1 + + # Discard old server-ack'd messages from the ack buffer + while self._ack_buffer and self._ack_buffer[0].seq < msg_ack: + self._ack_buffer.popleft() + + return True + + async def block_until_connected() -> None: + if self._state in TerminalStates: + return + logger.debug("block_until_connected") + await self._wait_for_connected.wait() + logger.debug("block_until_connected released!") + + self._task_manager.create_task( + _recv_from_ws( + block_until_connected=block_until_connected, + client_id=self._client_id, + get_state=lambda: self._state, + get_ws=lambda: self._ws, + transition_no_connection=transition_no_connection, + close_session=self.close, + assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, + get_stream=lambda stream_id: self._streams.get(stream_id), + enqueue_message=self._enqueue_message, + ) + ) + + @asynccontextmanager + async def _with_stream( + self, + span: Span, + stream_id: str, + maxsize: int, + ) -> AsyncIterator[tuple[_BackpressuredWaiter, AsyncIterator[ResultType]]]: + """ + _with_stream + + An async context that exposes a managed stream and an event that permits + producers to respond to backpressure. + + It is expected that the first message emitted ignores this error_channel, + since the first event does not care about backpressure, but subsequent events + emitted should call await error_channel.wait() prior to emission. 
+ """ + output: Channel[ResultType] = Channel(maxsize=maxsize) + backpressured_waiter_event: asyncio.Event = asyncio.Event() + error_channel: Channel[Exception] = Channel(maxsize=1) + self._streams[stream_id] = { + "span": span, + "error_channel": error_channel, + "release_backpressured_waiter": backpressured_waiter_event.set, + "output": output, + } + + async def backpressured_waiter() -> None: + await backpressured_waiter_event.wait() + try: + err = error_channel.get_nowait() + raise err + except (ChannelClosed, ChannelEmpty): + # No errors, off to the next message + pass + + async def error_checking_output() -> AsyncIterator[ResultType]: + async for elem in output: + try: + err = error_channel.get_nowait() + raise err + except (ChannelClosed, ChannelEmpty): + # No errors, off to the next message + pass + yield elem + + try: + yield (backpressured_waiter, error_checking_output()) + finally: + stream_meta = self._streams.get(stream_id) + if not stream_meta: + logger.warning( + "_with_stream had an entry deleted out from under it", + extra={ + "session_id": self.session_id, + "stream_id": stream_id, + }, + ) + return + # We need to signal back to all emitters or waiters that we're gone + output.close() + del self._streams[stream_id] + + async def send_rpc[R, A]( + self, + service_name: str, + procedure_name: str, + request: R, + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], + span: Span, + timeout: timedelta, + ) -> A: + """Sends a single RPC request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + stream_id = nanoid.generate() + await self._enqueue_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT | STREAM_CLOSED_BIT, + payload=request_serializer(request), + service_name=service_name, + procedure_name=procedure_name, + span=span, + ) + + async with self._with_stream(span, stream_id, 1) as ( + backpressured_waiter, + output, + ): + # Handle potential errors during communication + try: + async with asyncio.timeout(timeout.total_seconds()): + # Block for backpressure and emission errors from the ws + await backpressured_waiter() + result = await anext(output) + except asyncio.TimeoutError as e: + await self._send_cancel_stream( + stream_id=stream_id, + message="Timeout, abandoning request", + span=span, + ) + raise RiverException(ERROR_CODE_CANCEL, str(e)) from e + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + + if "ok" not in result or not result["ok"]: + try: + error = error_deserializer(result["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) + + return response_deserializer(result["payload"]) + + async def send_upload[I, R, A]( + self, + service_name: str, + procedure_name: str, + init: I, + request: AsyncIterable[R], + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], RiverError], + span: Span, + ) -> A: + """Sends an upload request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + stream_id = nanoid.generate() + await self._enqueue_message( + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + service_name=service_name, + procedure_name=procedure_name, + payload=init_serializer(init), + span=span, + ) + + async with self._with_stream(span, stream_id, 1) as ( + backpressured_waiter, + output, + ): + try: + # If this request is not closed and the session is killed, we should + # throw exception here + async for item in request: + # Block for backpressure and emission errors from the ws + await backpressured_waiter() + try: + payload = request_serializer(item) + except Exception as e: + await self._send_cancel_stream( + stream_id=stream_id, + message="Request serialization error", + span=span, + ) + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + str(e), + service_name, + procedure_name, + ) from e + await self._enqueue_message( + stream_id=stream_id, + service_name=service_name, + procedure_name=procedure_name, + control_flags=0, + payload=payload, + span=span, + ) + except Exception as e: + # If we get any exception other than WebsocketClosedException, + # cancel the stream. 
+ await self._send_cancel_stream( + stream_id=stream_id, + message="Unspecified error", + span=span, + ) + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name + ) from e + await self._send_close_stream( + stream_id=stream_id, + span=span, + ) + + try: + result = await anext(output) + except ChannelClosed as e: + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + except Exception as e: + await self._send_cancel_stream( + stream_id=stream_id, + message="Unspecified error", + span=span, + ) + raise RiverException(ERROR_CODE_STREAM_CLOSED, str(e)) from e + + if "ok" not in result or not result["ok"]: + try: + error = error_deserializer(result["payload"]) + except Exception as e: + raise RiverException("error_deserializer", str(e)) from e + raise exception_from_message(error.code)( + error.code, error.message, service_name, procedure_name + ) + + return response_deserializer(result["payload"]) + + async def send_subscription[I, E, A]( + self, + service_name: str, + procedure_name: str, + init: I, + init_serializer: Callable[[I], Any], + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], + span: Span, + ) -> AsyncGenerator[A | E, None]: + """Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + stream_id = nanoid.generate() + await self._enqueue_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + span=span, + ) + + async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( + _, + output, + ): + try: + async for item in output: + if item.get("type") == "CLOSE": + break + if not item.get("ok", False): + yield error_deserializer(item["payload"]) + continue + yield response_deserializer(item["payload"]) + await self._send_close_stream(stream_id, span) + except Exception as e: + await self._send_cancel_stream( + stream_id=stream_id, + message="Unspecified error", + span=span, + ) + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + + async def send_stream[I, R, E, A]( + self, + service_name: str, + procedure_name: str, + init: I, + request: AsyncIterable[R] | None, + init_serializer: Callable[[I], Any], + request_serializer: Callable[[R], Any] | None, + response_deserializer: Callable[[Any], A], + error_deserializer: Callable[[Any], E], + span: Span, + ) -> AsyncGenerator[A | E, None]: + """Sends a subscription request to the server. + + Expects the input and output be messages that will be msgpacked. 
+ """ + + stream_id = nanoid.generate() + await self._enqueue_message( + service_name=service_name, + procedure_name=procedure_name, + stream_id=stream_id, + control_flags=STREAM_OPEN_BIT, + payload=init_serializer(init), + span=span, + ) + + async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as ( + backpressured_waiter, + output, + ): + # Create the encoder task + async def _encode_stream() -> None: + if not request: + await self._send_close_stream( + stream_id=stream_id, + span=span, + ) + return + + assert request_serializer, "send_stream missing request_serializer" + + async for item in request: + # Block for backpressure (or errors) + await backpressured_waiter() + # If there are any errors so far, raise them + await self._enqueue_message( + stream_id=stream_id, + control_flags=0, + payload=request_serializer(item), + ) + await self._send_close_stream( + stream_id=stream_id, + span=span, + ) + + emitter_task = self._task_manager.create_task(_encode_stream()) + + # Handle potential errors during communication + try: + async for result in output: + # Raise as early as we possibly can in case of an emission error + if emitter_task.done() and (err := emitter_task.exception()): + raise err + if result.get("type") == "CLOSE": + break + if "ok" not in result or not result["ok"]: + yield error_deserializer(result["payload"]) + continue + yield response_deserializer(result["payload"]) + # ... block the outer function until the emitter is finished emitting, + # possibly raising a terminal exception. 
+ await emitter_task + except Exception as e: + await self._send_cancel_stream( + stream_id=stream_id, + message="Unspecified error", + span=span, + ) + raise RiverServiceException( + ERROR_CODE_STREAM_CLOSED, + "Stream closed before response", + service_name, + procedure_name, + ) from e + + async def _send_cancel_stream( + self, + stream_id: str, + message: str, + span: Span, + ) -> None: + await self._enqueue_message( + stream_id=stream_id, + control_flags=STREAM_CANCEL_BIT, + payload={ + "ok": False, + "payload": { + "code": "CANCEL", + "message": message, + }, + }, + span=span, + ) + + async def _send_close_stream( + self, + stream_id: str, + span: Span, + ) -> None: + await self._enqueue_message( + stream_id=stream_id, + control_flags=STREAM_CLOSED_BIT, + payload={"type": "CLOSE"}, + span=span, + ) + + +async def _do_ensure_connected[HandshakeMetadata]( + transport_options: TransportOptions, + client_id: str, + session_id: str, + server_id: str, + rate_limiter: LeakyBucketRateLimit, + uri_and_metadata_factory: Callable[ + [], Awaitable[UriAndMetadata[HandshakeMetadata]] + ], + get_current_time: Callable[[], Awaitable[float]], + get_next_sent_seq: Callable[[], int], + get_current_ack: Callable[[], int], + get_state: Callable[[], SessionState], + transition_connecting: Callable[[], None], + close_ws_in_background: Callable[[ClientConnection], None], + transition_connected: Callable[[ClientConnection], None], + unbind_connecting_task: Callable[[], None], + close_session: Callable[[Exception | None], None], +) -> None: + logger.info("Attempting to establish new ws connection") + + last_error: Exception | None = None + attempt_count = 0 + while rate_limiter.has_budget(client_id): + if (state := get_state()) in TerminalStates or state in ActiveStates: + logger.info(f"_do_ensure_connected stopping due to state={state}") + break + + if attempt_count > 0: + logger.info(f"Retrying build handshake number {attempt_count} times") + attempt_count += 1 + + 
rate_limiter.consume_budget(client_id) + transition_connecting() + + ws: ClientConnection | None = None + try: + uri_and_metadata = await uri_and_metadata_factory() + ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) + + try: + handshake_request = ControlMessageHandshakeRequest[HandshakeMetadata]( + type="HANDSHAKE_REQ", + protocolVersion=PROTOCOL_VERSION, + sessionId=session_id, + metadata=uri_and_metadata["metadata"], + expectedSessionState=ExpectedSessionState( + nextExpectedSeq=get_current_ack(), + nextSentSeq=get_next_sent_seq(), + ), + ) + + async def websocket_closed_callback() -> None: + logger.error("websocket closed before handshake response") + + await send_transport_message( + TransportMessage( + from_=client_id, + to=server_id, + streamId=nanoid.generate(), + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_request.model_dump(), + ), + ws=ws, + websocket_closed_callback=websocket_closed_callback, + ) + except ( + WebsocketClosedException, + FailedSendingMessageException, + ) as e: + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while sending response", + ) from e + + handshake_deadline_ms = ( + await get_current_time() + transport_options.handshake_timeout_ms + ) + + if await get_current_time() >= handshake_deadline_ms: + raise RiverException( + ERROR_HANDSHAKE, + "Handshake response timeout, closing connection", + ) + + try: + data = await ws.recv(decode=False) + except ConnectionClosed as e: + logger.debug( + "_do_ensure_connected: Connection closed during waiting " + "for handshake response", + exc_info=True, + ) + raise RiverException( + ERROR_HANDSHAKE, + "Handshake failed, conn closed while waiting for response", + ) from e + + try: + response_msg = parse_transport_msg(data) + except InvalidMessageException as e: + raise RiverException( + ERROR_HANDSHAKE, + "Got invalid transport message, closing connection", + ) from e + + if isinstance(response_msg, str): + raise 
RiverException( + ERROR_HANDSHAKE, + "Handshake failed, received a raw string message while waiting " + "for a handshake response", + ) + + try: + handshake_response = ControlMessageHandshakeResponse( + **response_msg.payload + ) + logger.debug("river client waiting for handshake response") + except ValidationError as e: + raise RiverException( + ERROR_HANDSHAKE, "Failed to parse handshake response" + ) from e + + logger.debug("river client get handshake response : %r", handshake_response) + if not handshake_response.status.ok: + err = RiverException( + ERROR_HANDSHAKE, + f"Handshake failed with code {handshake_response.status.code}: { + handshake_response.status.reason + }", + ) + if handshake_response.status.code == ERROR_CODE_SESSION_STATE_MISMATCH: + # A session state mismatch is unrecoverable. Terminate immediately. + close_session(err) + + raise err + + # We did it! We're connected! + last_error = None + rate_limiter.start_restoring_budget(client_id) + transition_connected(ws) + break + except Exception as e: + if ws: + close_ws_in_background(ws) + ws = None + last_error = e + backoff_time = rate_limiter.get_backoff_ms(client_id) + logger.exception( + f"Error connecting, retrying with {backoff_time}ms backoff" + ) + await asyncio.sleep(backoff_time / 1000) + unbind_connecting_task() + + if last_error is not None: + logger.debug("Handshake attempts exhausted, terminating") + close_session(last_error) + raise RiverException( + ERROR_HANDSHAKE, + f"Failed to create ws after retrying {attempt_count} number of times", + ) from last_error + + return None + + +async def _recv_from_ws( + block_until_connected: Callable[[], Awaitable[None]], + client_id: str, + get_state: Callable[[], SessionState], + get_ws: Callable[[], ClientConnection | None], + transition_no_connection: Callable[[], Awaitable[None]], + close_session: Callable[[Exception | None], Awaitable[None]], + assert_incoming_seq_bookkeeping: Callable[ + [str, int, int], Literal[True] | _IgnoreMessage + ], + 
get_stream: Callable[ + [str], + StreamMeta | None, + ], + enqueue_message: SendMessage[None], +) -> None: + """Serve messages from the websocket. + + Process incoming packets from the connected websocket. + """ + our_task = asyncio.current_task() + connection_attempts = 0 + try: + while our_task and not our_task.cancelling() and not our_task.cancelled(): + logger.debug(f"_recv_from_ws loop count={connection_attempts}") + connection_attempts += 1 + ws = None + while ((state := get_state()) in ConnectingStates) and ( + state not in TerminalStates + ): + logger.debug( + "_handle_messages_from_ws spinning while connecting, %r %r", + ws, + state, + ) + await block_until_connected() + + if state in TerminalStates: + logger.debug( + f"Session is {state}, shut down _recv_from_ws", + ) + # session is closing / closed, no need to _recv_from_ws anymore + break + + logger.debug("client start handling messages from ws %r", ws) + + # We should not process messages if the websocket is closed. + while (ws := get_ws()) and get_state() in ActiveStates: + connection_attempts = 0 + # decode=False: Avoiding an unnecessary round-trip through str + # Ideally this should be type-ascripted to : bytes, but there + # is no @overrides in `websockets` to hint this. + try: + message = await ws.recv(decode=False) + except ConnectionClosed: + # This triggers a break in the inner loop so we can get back to + # the outer loop. 
+ await transition_no_connection() + break + try: + msg = parse_transport_msg(message) + logger.debug( + "[%s] got a message %r", + client_id, + msg, + ) + if isinstance(msg, str): + logger.debug("Ignoring transport message", exc_info=True) + continue + + if msg.controlFlags & STREAM_OPEN_BIT != 0: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + match assert_incoming_seq_bookkeeping( + msg.from_, + msg.seq, + msg.ack, + ): + case _IgnoreMessage(): + logger.debug( + "Ignoring transport message", + exc_info=True, + ) + continue + case True: + pass + case other: + assert_never(other) + + # Shortcut to avoid processing ack packets + if msg.controlFlags & ACK_BIT != 0: + await enqueue_message( + stream_id="heartbeat", + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 + payload={ + "type": "ACK", + }, + control_flags=ACK_BIT, + procedure_name=None, + service_name=None, + span=None, + ) + continue + + stream_meta = get_stream(msg.streamId) + + if not stream_meta: + logger.warning( + "no stream for %s, ignoring message", + msg.streamId, + ) + continue + + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + # event is set during cleanup down below + pass + else: + try: + await stream_meta["output"].put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + # Communicate that we're going down + # + # This implements the receive side of the half-closed strategy. 
+ stream_meta["output"].close() + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await close_session( + SessionClosedRiverServiceException( + "Out of order message, closing connection" + ) + ) + continue + except InvalidMessageException: + logger.exception( + "Got invalid transport message, closing session", + ) + await close_session( + SessionClosedRiverServiceException( + "Out of order message, closing connection" + ) + ) + continue + except FailedSendingMessageException: + # Expected error if the connection is closed. + await transition_no_connection() + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + break + except Exception: + logger.exception("caught exception at message iterator") + break + logger.debug("_handle_messages_from_ws exiting") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + # We're in a task, there's not that much that can be done. 
+ unhandled = ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + logger.exception( + "caught exception at message iterator", + exc_info=unhandled, + ) + raise unhandled + logger.debug(f"_recv_from_ws exiting normally after {connection_attempts} loops") diff --git a/tests/conftest.py b/tests/conftest.py index b9b8cdf6..3866fdd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,11 @@ ) # Modular fixtures -pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"] +pytest_plugins = [ + "tests.v1.river_fixtures.logging", + "tests.v1.river_fixtures.clientserver", + "tests.v2.fixtures", +] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandlerBuilder]] @@ -31,7 +35,7 @@ def transport_message( ) -> TransportMessage: return TransportMessage( id=str(nanoid.generate()), - from_=from_, # type: ignore + from_=from_, to=to, streamId=streamId, seq=seq, diff --git a/tests/codegen/snapshot/codegen_snapshot_fixtures.py b/tests/fixtures/codegen_snapshot_fixtures.py similarity index 78% rename from tests/codegen/snapshot/codegen_snapshot_fixtures.py rename to tests/fixtures/codegen_snapshot_fixtures.py index ef74a1fb..d4a88f54 100644 --- a/tests/codegen/snapshot/codegen_snapshot_fixtures.py +++ b/tests/fixtures/codegen_snapshot_fixtures.py @@ -1,6 +1,6 @@ from io import StringIO from pathlib import Path -from typing import Callable, TextIO +from typing import Callable, Literal, TextIO from pytest_snapshot.plugin import Snapshot @@ -15,11 +15,14 @@ def close(self) -> None: def validate_codegen( *, snapshot: Snapshot, + snapshot_dir: str, read_schema: Callable[[], TextIO], target_path: str, client_name: str, + protocol_version: Literal["v1.1", "v2.0"], + typeddict_inputs: bool = True, ) -> None: - snapshot.snapshot_dir = "tests/codegen/snapshot/snapshots" + snapshot.snapshot_dir = snapshot_dir files: 
dict[Path, UnclosableStringIO] = {} def file_opener(path: Path) -> TextIO: @@ -33,8 +36,9 @@ def file_opener(path: Path) -> TextIO: target_path=target_path, client_name=client_name, file_opener=file_opener, - typed_dict_inputs=True, + typed_dict_inputs=typeddict_inputs, method_filter=None, + protocol_version=protocol_version, ) for path, file in files.items(): file.seek(0) diff --git a/tests/codegen/rpc/generated/__init__.py b/tests/v1/codegen/rpc/generated/__init__.py similarity index 100% rename from tests/codegen/rpc/generated/__init__.py rename to tests/v1/codegen/rpc/generated/__init__.py diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/v1/codegen/rpc/generated/test_service/__init__.py similarity index 97% rename from tests/codegen/rpc/generated/test_service/__init__.py rename to tests/v1/codegen/rpc/generated/test_service/__init__.py index 24545e00..3d9bc86a 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/v1/codegen/rpc/generated/test_service/__init__.py @@ -11,7 +11,6 @@ from .rpc_method import ( Rpc_MethodInput, - Rpc_MethodInputTypeAdapter, Rpc_MethodOutput, Rpc_MethodOutputTypeAdapter, encode_Rpc_MethodInput, diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/v1/codegen/rpc/generated/test_service/rpc_method.py similarity index 91% rename from tests/codegen/rpc/generated/test_service/rpc_method.py rename to tests/v1/codegen/rpc/generated/test_service/rpc_method.py index dfe8a47c..1e40411f 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/v1/codegen/rpc/generated/test_service/rpc_method.py @@ -40,9 +40,6 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) - - class Rpc_MethodOutput(BaseModel): data: str diff --git a/tests/codegen/rpc/schema.json b/tests/v1/codegen/rpc/schema.json similarity index 100% rename from tests/codegen/rpc/schema.json rename to 
tests/v1/codegen/rpc/schema.json diff --git a/tests/v1/codegen/snapshot/parity/__init__.py b/tests/v1/codegen/snapshot/parity/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/v1/codegen/snapshot/parity/check_parity.py b/tests/v1/codegen/snapshot/parity/check_parity.py new file mode 100644 index 00000000..68731bd2 --- /dev/null +++ b/tests/v1/codegen/snapshot/parity/check_parity.py @@ -0,0 +1,347 @@ +import importlib +from typing import Any, Callable, Literal, TypedDict + +from pydantic import TypeAdapter +from pytest_snapshot.plugin import Snapshot + +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v1.codegen.snapshot.parity.gen import ( + gen_bool, + gen_choice, + gen_dict, + gen_int, + gen_list, + gen_opt, + gen_str, +) + +PrimitiveType = ( + bool | str | int | float | dict[str, "PrimitiveType"] | list["PrimitiveType"] +) + + +def deep_equal(a: PrimitiveType, b: PrimitiveType) -> Literal[True]: + if a == b: + return True + elif isinstance(a, dict) and isinstance(b, dict): + a_keys: PrimitiveType = list(a.keys()) + b_keys: PrimitiveType = list(b.keys()) + assert deep_equal(a_keys, b_keys) + + # We do this dance again because Python variance is hard. Feel free to fix it. 
+ keys = set(a.keys()) + keys.update(b.keys()) + for k in keys: + aa: PrimitiveType = a[k] + bb: PrimitiveType = b[k] + assert deep_equal(aa, bb) + return True + elif isinstance(a, list) and isinstance(b, list): + assert len(a) == len(b) + for i in range(len(a)): + assert deep_equal(a[i], b[i]) + return True + else: + assert a == b, f"{a} != {b}" + return True + + +def baseTestPattern[A]( + x: A, encode: Callable[[A], Any], adapter: TypeAdapter[Any] +) -> None: + a = encode(x) + m = adapter.validate_python(a) + z = adapter.dump_python(m, by_alias=True, exclude_none=True) + + assert deep_equal(a, z) + + +def test_AiexecExecInit(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.aiExec.ExecInit = { + "args": gen_list(gen_str)(), + "env": gen_opt(gen_dict(gen_str))(), + "cwd": gen_opt(gen_str)(), + "omitStdout": gen_opt(gen_bool)(), + "omitStderr": gen_opt(gen_bool)(), + "useReplitRunEnv": gen_opt(gen_bool)(), + } + + baseTestPattern(x, tyd.aiExec.encode_ExecInit, pyd.aiExec.ExecInitTypeAdapter) + + +def test_AgenttoollanguageserverOpendocumentInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + 
snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.agentToolLanguageServer.OpendocumentInput = { + "path": gen_str(), + } + + baseTestPattern( + x, + tyd.agentToolLanguageServer.encode_OpendocumentInput, + pyd.agentToolLanguageServer.OpendocumentInputTypeAdapter, + ) + + +kind_type = ( + Literal[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + | None +) + + +class size_type(TypedDict): + rows: int + cols: int + + +def test_ShellexecSpawnInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + 
importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.shellExec.SpawnInput = { + "cmd": gen_str(), + "args": gen_opt(gen_list(gen_str))(), + "initialCmd": gen_opt(gen_str)(), + "env": gen_opt(gen_dict(gen_str))(), + "cwd": gen_opt(gen_str)(), + "size": gen_opt( + lambda: size_type( + { + "rows": gen_int(), + "cols": gen_int(), + } + ), + )(), + "useReplitRunEnv": gen_opt(gen_bool)(), + "interactive": gen_opt(gen_bool)(), + } + + baseTestPattern( + x, + tyd.shellExec.encode_SpawnInput, + pyd.shellExec.SpawnInputTypeAdapter, + ) + + +def test_ConmanfilesystemPersistInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.conmanFilesystem.PersistInput = {} + + baseTestPattern( + x, + tyd.conmanFilesystem.encode_PersistInput, + pyd.conmanFilesystem.PersistInputTypeAdapter, + ) + + +def test_ReplspaceapiInitInput(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: 
open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/typeddict_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=True, + ) + + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/snapshot/parity/schema.json"), + target_path="parity/pydantic_inputs", + client_name="foo", + protocol_version="v1.1", + typeddict_inputs=False, + ) + + import tests.v1.codegen.snapshot.snapshots.parity + + importlib.reload(tests.v1.codegen.snapshot.snapshots.parity) + + from tests.v1.codegen.snapshot.snapshots.parity import pydantic_inputs as pyd + from tests.v1.codegen.snapshot.snapshots.parity import typeddict_inputs as tyd + + x: tyd.replspaceApi.init.InitInput = gen_choice( + list[tyd.replspaceApi.init.InitInput]( + [ + tyd.replspaceApi.init.InitInputOneOf_closeFile( + {"kind": "closeFile", "filename": gen_str(), "nonce": gen_str()} + ), + tyd.replspaceApi.init.InitInputOneOf_githubToken( + {"kind": "githubToken", "token": gen_str(), "nonce": gen_str()} + ), + tyd.replspaceApi.init.InitInputOneOf_sshToken0( + { + "kind": "sshToken", + "nonce": gen_str(), + "SSHHostname": gen_str(), + "token": gen_str(), + } + ), + tyd.replspaceApi.init.InitInputOneOf_sshToken1( + {"kind": "sshToken", "nonce": gen_str(), "error": gen_str()} + ), + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccess( + { + "kind": "allowDefaultBucketAccess", + "nonce": gen_str(), + "result": gen_choice( + list[ + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccessResult + ]( + [ + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok( + { + "bucketId": gen_str(), + "sourceReplId": gen_str(), + "status": "ok", + "targetReplId": gen_str(), + } + ), + tyd.replspaceApi.init.InitInputOneOf_allowDefaultBucketAccessResultOneOf_error( + {"message": gen_str(), "status": "error"} + ), + ] + ) + )(), + } + ), + ] + ) + )() + + baseTestPattern( + x, + 
tyd.replspaceApi.init.encode_InitInput, + pyd.replspaceApi.init.InitInputTypeAdapter, + ) diff --git a/scripts/parity/gen.py b/tests/v1/codegen/snapshot/parity/gen.py similarity index 100% rename from scripts/parity/gen.py rename to tests/v1/codegen/snapshot/parity/gen.py diff --git a/tests/v1/codegen/snapshot/parity/schema.json b/tests/v1/codegen/snapshot/parity/schema.json new file mode 100644 index 00000000..dbe3a74e --- /dev/null +++ b/tests/v1/codegen/snapshot/parity/schema.json @@ -0,0 +1,312 @@ +{ + "services": { + "aiExec": { + "procedures": { + "exec": { + "init": { + "type": "object", + "properties": { + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "env": { + "type": "object", + "patternProperties": { + "^(.*)$": { + "type": "string" + } + } + }, + "cwd": { + "type": "string" + }, + "omitStdout": { + "type": "boolean" + }, + "omitStderr": { + "type": "boolean" + }, + "useReplitRunEnv": { + "type": "boolean" + } + }, + "required": ["args"] + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "type": "stream", + "input": { + "type": "object", + "properties": {} + } + } + } + }, + "agentToolLanguageServer": { + "procedures": { + "openDocument": { + "input": { + "type": "object", + "properties": { + "path": { + "description": "The path to the file. 
This should be relative to the workspace root", + "type": "string" + } + }, + "required": ["path"] + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "description": "Reports a document as open to the language servers", + "type": "rpc" + } + } + }, + "shellExec": { + "procedures": { + "spawn": { + "input": { + "type": "object", + "properties": { + "cmd": { + "type": "string" + }, + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "initialCmd": { + "type": "string" + }, + "env": { + "type": "object", + "patternProperties": { + "^(.*)$": { + "type": "string" + } + } + }, + "cwd": { + "type": "string" + }, + "size": { + "type": "object", + "properties": { + "rows": { + "type": "integer" + }, + "cols": { + "type": "integer" + } + }, + "required": ["rows", "cols"] + }, + "useReplitRunEnv": { + "type": "boolean" + }, + "interactive": { + "type": "boolean" + } + }, + "required": ["cmd"] + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "description": "Start a new shell process and returns the process handle id which you can use to interact with the process", + "type": "rpc" + } + } + }, + "replspaceApi": { + "procedures": { + "init": { + "init": { + "type": "object", + "properties": {} + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "type": "stream", + "input": { + "anyOf": [ + { + "type": "object", + "properties": { + "$kind": { + "const": "closeFile", + "type": "string" + }, + "filename": { + "type": "string" + }, + "nonce": { + "type": "string" + } + }, + "required": ["$kind", "filename", "nonce"] + }, + { + "type": "object", + "properties": { + "$kind": { + "const": "githubToken", + "type": "string" + }, + "token": { + "type": "string" + }, + "nonce": { + "type": "string" + } + }, + "required": ["$kind", "nonce"] + }, + { + 
"anyOf": [ + { + "type": "object", + "properties": { + "$kind": { + "const": "sshToken", + "type": "string" + }, + "nonce": { + "type": "string" + }, + "token": { + "type": "string" + }, + "SSHHostname": { + "type": "string" + } + }, + "required": ["$kind", "nonce", "token", "SSHHostname"] + }, + { + "type": "object", + "properties": { + "$kind": { + "const": "sshToken", + "type": "string" + }, + "nonce": { + "type": "string" + }, + "error": { + "type": "string" + } + }, + "required": ["$kind", "nonce", "error"] + } + ] + }, + { + "type": "object", + "properties": { + "$kind": { + "const": "allowDefaultBucketAccess", + "type": "string" + }, + "nonce": { + "type": "string" + }, + "result": { + "anyOf": [ + { + "type": "object", + "properties": { + "status": { + "const": "error", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["status", "message"] + }, + { + "type": "object", + "properties": { + "status": { + "const": "ok", + "type": "string" + }, + "targetReplId": { + "type": "string" + }, + "sourceReplId": { + "type": "string" + }, + "bucketId": { + "type": "string" + } + }, + "required": [ + "status", + "targetReplId", + "sourceReplId", + "bucketId" + ] + } + ] + } + }, + "required": ["$kind", "nonce", "result"] + } + ] + } + } + } + }, + "conmanFilesystem": { + "procedures": { + "persist": { + "input": { + "type": "object", + "properties": {} + }, + "output": { + "type": "object", + "properties": {} + }, + "errors": { + "type": "object", + "properties": {} + }, + "type": "rpc" + } + } + } + }, + "handshakeSchema": null +} diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/__init__.py new file mode 100644 index 00000000..0f5c7dfb --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/__init__.py @@ -0,0 +1,21 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .agentToolLanguageServer import AgenttoollanguageserverService +from .aiExec import AiexecService +from .conmanFilesystem import ConmanfilesystemService +from .replspaceApi import ReplspaceapiService +from .shellExec import ShellexecService + + +class foo: + def __init__(self, client: river.Client[Literal[None]]): + self.aiExec = AiexecService(client) + self.agentToolLanguageServer = AgenttoollanguageserverService(client) + self.shellExec = ShellexecService(client) + self.replspaceApi = ReplspaceapiService(client) + self.conmanFilesystem = ConmanfilesystemService(client) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/__init__.py new file mode 100644 index 00000000..f867f902 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/__init__.py @@ -0,0 +1,47 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .openDocument import ( + OpendocumentErrors, + OpendocumentErrorsTypeAdapter, + OpendocumentInput, + OpendocumentInputTypeAdapter, + OpendocumentOutput, + OpendocumentOutputTypeAdapter, +) + + +class AgenttoollanguageserverService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def openDocument( + self, + input: OpendocumentInput, + timeout: datetime.timedelta, + ) -> OpendocumentOutput: + return await self.client.send_rpc( + "agentToolLanguageServer", + "openDocument", + input, + lambda x: OpendocumentInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: OpendocumentOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: OpendocumentErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/openDocument.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/openDocument.py new file mode 100644 index 00000000..02c814e1 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/agentToolLanguageServer/openDocument.py @@ -0,0 +1,49 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class OpendocumentInput(BaseModel): + path: str + + +class OpendocumentOutput(BaseModel): + pass + + +OpendocumentOutputTypeAdapter: TypeAdapter[OpendocumentOutput] = TypeAdapter( + OpendocumentOutput +) + + +class OpendocumentErrors(RiverError): + pass + + +OpendocumentErrorsTypeAdapter: TypeAdapter[OpendocumentErrors] = TypeAdapter( + OpendocumentErrors +) + + +OpendocumentInputTypeAdapter: TypeAdapter[OpendocumentInput] = TypeAdapter( + OpendocumentInput +) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/__init__.py new file mode 100644 index 00000000..f8d9671d --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/__init__.py @@ -0,0 +1,50 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .exec import ( + ExecErrors, + ExecErrorsTypeAdapter, + ExecInit, + ExecInitTypeAdapter, + ExecInput, + ExecInputTypeAdapter, + ExecOutput, + ExecOutputTypeAdapter, +) + + +class AiexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def exec( + self, + init: ExecInit, + inputStream: AsyncIterable[ExecInput], + ) -> AsyncIterator[ExecOutput | ExecErrors | RiverError]: + return self.client.send_stream( + "aiExec", + "exec", + init, + inputStream, + lambda x: ExecInitTypeAdapter.validate_python(x), + lambda x: ExecInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: ExecOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: ExecErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/exec.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/exec.py new file mode 100644 index 00000000..57c6c69a --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/aiExec/exec.py @@ -0,0 +1,56 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class ExecInit(BaseModel): + args: list[str] + cwd: str | None = None + env: dict[str, str] | None = None + omitStderr: bool | None = None + omitStdout: bool | None = None + useReplitRunEnv: bool | None = None + + +class ExecInput(BaseModel): + kind: Annotated[Literal["stdin"], Field(alias="$kind")] = "stdin" + stdin: bytes + + +class ExecOutput(BaseModel): + pass + + +ExecOutputTypeAdapter: TypeAdapter[ExecOutput] = TypeAdapter(ExecOutput) + + +class ExecErrors(RiverError): + pass + + +ExecErrorsTypeAdapter: TypeAdapter[ExecErrors] = TypeAdapter(ExecErrors) + + +ExecInitTypeAdapter: TypeAdapter[ExecInit] = TypeAdapter(ExecInit) + + +ExecInputTypeAdapter: TypeAdapter[ExecInput] = TypeAdapter(ExecInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/__init__.py new file mode 100644 index 00000000..42d92981 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/__init__.py @@ -0,0 +1,47 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .persist import ( + PersistErrors, + PersistErrorsTypeAdapter, + PersistInput, + PersistInputTypeAdapter, + PersistOutput, + PersistOutputTypeAdapter, +) + + +class ConmanfilesystemService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def persist( + self, + input: PersistInput, + timeout: datetime.timedelta, + ) -> PersistOutput: + return await self.client.send_rpc( + "conmanFilesystem", + "persist", + input, + lambda x: PersistInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: PersistOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: PersistErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/persist.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/persist.py new file mode 100644 index 00000000..3eb5fa01 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/conmanFilesystem/persist.py @@ -0,0 +1,43 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class PersistInput(BaseModel): + pass + + +class PersistOutput(BaseModel): + pass + + +PersistOutputTypeAdapter: TypeAdapter[PersistOutput] = TypeAdapter(PersistOutput) + + +class PersistErrors(RiverError): + pass + + +PersistErrorsTypeAdapter: TypeAdapter[PersistErrors] = TypeAdapter(PersistErrors) + + +PersistInputTypeAdapter: TypeAdapter[PersistInput] = TypeAdapter(PersistInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/__init__.py new file mode 100644 index 00000000..8842016c --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/__init__.py @@ -0,0 +1,50 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .init import ( + InitErrors, + InitErrorsTypeAdapter, + InitInit, + InitInitTypeAdapter, + InitInput, + InitInputTypeAdapter, + InitOutput, + InitOutputTypeAdapter, +) + + +class ReplspaceapiService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def init( + self, + init: InitInit, + inputStream: AsyncIterable[InitInput], + ) -> AsyncIterator[InitOutput | InitErrors | RiverError]: + return self.client.send_stream( + "replspaceApi", + "init", + init, + inputStream, + lambda x: InitInitTypeAdapter.validate_python(x), + lambda x: InitInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: InitOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: InitErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/init.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/init.py new file mode 100644 index 00000000..46bc3a96 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/replspaceApi/init.py @@ -0,0 +1,106 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class InitInit(BaseModel): + pass + + +class InitInputOneOf_closeFile(BaseModel): + kind: Annotated[Literal["closeFile"], Field(alias="$kind")] = "closeFile" + filename: str + nonce: str + + +class InitInputOneOf_githubToken(BaseModel): + kind: Annotated[Literal["githubToken"], Field(alias="$kind")] = "githubToken" + nonce: str + token: str | None = None + + +class InitInputOneOf_sshToken0(BaseModel): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] = "sshToken" + SSHHostname: str + nonce: str + token: str + + +class InitInputOneOf_sshToken1(BaseModel): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] = "sshToken" + error: str + nonce: str + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_error(BaseModel): + message: str + status: Literal["error"] + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok(BaseModel): + bucketId: str + sourceReplId: str + status: Literal["ok"] + targetReplId: str + + +InitInputOneOf_allowDefaultBucketAccessResult = ( + InitInputOneOf_allowDefaultBucketAccessResultOneOf_error + | InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok +) + + +class InitInputOneOf_allowDefaultBucketAccess(BaseModel): + kind: Annotated[Literal["allowDefaultBucketAccess"], Field(alias="$kind")] = ( + "allowDefaultBucketAccess" + ) + nonce: str + result: InitInputOneOf_allowDefaultBucketAccessResult + + +InitInput = ( + InitInputOneOf_closeFile + | InitInputOneOf_githubToken + | InitInputOneOf_sshToken0 + | InitInputOneOf_sshToken1 + | 
InitInputOneOf_allowDefaultBucketAccess +) + + +class InitOutput(BaseModel): + pass + + +InitOutputTypeAdapter: TypeAdapter[InitOutput] = TypeAdapter(InitOutput) + + +class InitErrors(RiverError): + pass + + +InitErrorsTypeAdapter: TypeAdapter[InitErrors] = TypeAdapter(InitErrors) + + +InitInitTypeAdapter: TypeAdapter[InitInit] = TypeAdapter(InitInit) + + +InitInputTypeAdapter: TypeAdapter[InitInput] = TypeAdapter(InitInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/__init__.py new file mode 100644 index 00000000..9e0013a7 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/__init__.py @@ -0,0 +1,47 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .spawn import ( + SpawnErrors, + SpawnErrorsTypeAdapter, + SpawnInput, + SpawnInputTypeAdapter, + SpawnOutput, + SpawnOutputTypeAdapter, +) + + +class ShellexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def spawn( + self, + input: SpawnInput, + timeout: datetime.timedelta, + ) -> SpawnOutput: + return await self.client.send_rpc( + "shellExec", + "spawn", + input, + lambda x: SpawnInputTypeAdapter.dump_python( + x, # type: ignore[arg-type] + by_alias=True, + exclude_none=True, + ), + lambda x: SpawnOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: SpawnErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/spawn.py b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/spawn.py new file mode 100644 index 
00000000..6f4f6473 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/pydantic_inputs/shellExec/spawn.py @@ -0,0 +1,57 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class SpawnInputSize(BaseModel): + cols: int + rows: int + + +class SpawnInput(BaseModel): + args: list[str] | None = None + autoCleanup: bool | None = None + cmd: str + cwd: str | None = None + env: dict[str, str] | None = None + initialCmd: str | None = None + interactive: bool | None = None + size: SpawnInputSize | None = None + useCgroupMagic: bool | None = None + useReplitRunEnv: bool | None = None + + +class SpawnOutput(BaseModel): + pass + + +SpawnOutputTypeAdapter: TypeAdapter[SpawnOutput] = TypeAdapter(SpawnOutput) + + +class SpawnErrors(RiverError): + pass + + +SpawnErrorsTypeAdapter: TypeAdapter[SpawnErrors] = TypeAdapter(SpawnErrors) + + +SpawnInputTypeAdapter: TypeAdapter[SpawnInput] = TypeAdapter(SpawnInput) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/__init__.py new file mode 100644 index 00000000..0f5c7dfb --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/__init__.py @@ -0,0 +1,21 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .agentToolLanguageServer import AgenttoollanguageserverService +from .aiExec import AiexecService +from .conmanFilesystem import ConmanfilesystemService +from .replspaceApi import ReplspaceapiService +from .shellExec import ShellexecService + + +class foo: + def __init__(self, client: river.Client[Literal[None]]): + self.aiExec = AiexecService(client) + self.agentToolLanguageServer = AgenttoollanguageserverService(client) + self.shellExec = ShellexecService(client) + self.replspaceApi = ReplspaceapiService(client) + self.conmanFilesystem = ConmanfilesystemService(client) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/__init__.py new file mode 100644 index 00000000..4cc3ebdd --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/__init__.py @@ -0,0 +1,43 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .openDocument import ( + OpendocumentErrors, + OpendocumentErrorsTypeAdapter, + OpendocumentInput, + OpendocumentOutput, + OpendocumentOutputTypeAdapter, + encode_OpendocumentInput, +) + + +class AgenttoollanguageserverService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def openDocument( + self, + input: OpendocumentInput, + timeout: datetime.timedelta, + ) -> OpendocumentOutput: + return await self.client.send_rpc( + "agentToolLanguageServer", + "openDocument", + input, + encode_OpendocumentInput, + lambda x: OpendocumentOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: OpendocumentErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/openDocument.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/openDocument.py new file mode 100644 index 00000000..efbd4d8d --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/agentToolLanguageServer/openDocument.py @@ -0,0 +1,58 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_OpendocumentInput( + x: "OpendocumentInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "path": x.get("path"), + } + ).items() + if v is not None + } + + +class OpendocumentInput(TypedDict): + path: str + + +class OpendocumentOutput(BaseModel): + pass + + +OpendocumentOutputTypeAdapter: TypeAdapter[OpendocumentOutput] = TypeAdapter( + OpendocumentOutput +) + + +class OpendocumentErrors(RiverError): + pass + + +OpendocumentErrorsTypeAdapter: TypeAdapter[OpendocumentErrors] = TypeAdapter( + OpendocumentErrors +) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/__init__.py new file mode 100644 index 00000000..d66a196b --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/__init__.py @@ -0,0 +1,46 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .exec import ( + ExecErrors, + ExecErrorsTypeAdapter, + ExecInit, + ExecInput, + ExecOutput, + ExecOutputTypeAdapter, + encode_ExecInit, + encode_ExecInput, +) + + +class AiexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def exec( + self, + init: ExecInit, + inputStream: AsyncIterable[ExecInput], + ) -> AsyncIterator[ExecOutput | ExecErrors | RiverError]: + return self.client.send_stream( + "aiExec", + "exec", + init, + inputStream, + encode_ExecInit, + encode_ExecInput, + lambda x: ExecOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: ExecErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/exec.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/exec.py new file mode 100644 index 00000000..8f7bc741 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/aiExec/exec.py @@ -0,0 +1,84 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_ExecInit( + x: "ExecInit", +) -> Any: + return { + k: v + for (k, v) in ( + { + "args": x.get("args"), + "cwd": x.get("cwd"), + "env": x.get("env"), + "omitStderr": x.get("omitStderr"), + "omitStdout": x.get("omitStdout"), + "useReplitRunEnv": x.get("useReplitRunEnv"), + } + ).items() + if v is not None + } + + +class ExecInit(TypedDict): + args: list[str] + cwd: NotRequired[str | None] + env: NotRequired[dict[str, str] | None] + omitStderr: NotRequired[bool | None] + omitStdout: NotRequired[bool | None] + useReplitRunEnv: NotRequired[bool | None] + + +def encode_ExecInput( + x: "ExecInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "stdin": x.get("stdin"), + } + ).items() + if v is not None + } + + +class ExecInput(TypedDict): + kind: Annotated[Literal["stdin"], Field(alias="$kind")] + stdin: bytes + + +class ExecOutput(BaseModel): + pass + + +ExecOutputTypeAdapter: TypeAdapter[ExecOutput] = TypeAdapter(ExecOutput) + + +class ExecErrors(RiverError): + pass + + +ExecErrorsTypeAdapter: TypeAdapter[ExecErrors] = TypeAdapter(ExecErrors) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/__init__.py new file mode 100644 index 00000000..ef5b401d --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/__init__.py @@ -0,0 +1,43 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .persist import ( + PersistErrors, + PersistErrorsTypeAdapter, + PersistInput, + PersistOutput, + PersistOutputTypeAdapter, + encode_PersistInput, +) + + +class ConmanfilesystemService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def persist( + self, + input: PersistInput, + timeout: datetime.timedelta, + ) -> PersistOutput: + return await self.client.send_rpc( + "conmanFilesystem", + "persist", + input, + encode_PersistInput, + lambda x: PersistOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: PersistErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/persist.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/persist.py new file mode 100644 index 00000000..5ea0df8a --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/conmanFilesystem/persist.py @@ -0,0 +1,46 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_PersistInput( + _: "PersistInput", +) -> Any: + return {} + + +class PersistInput(TypedDict): + pass + + +class PersistOutput(BaseModel): + pass + + +PersistOutputTypeAdapter: TypeAdapter[PersistOutput] = TypeAdapter(PersistOutput) + + +class PersistErrors(RiverError): + pass + + +PersistErrorsTypeAdapter: TypeAdapter[PersistErrors] = TypeAdapter(PersistErrors) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/__init__.py new file mode 100644 index 00000000..72f33c93 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/__init__.py @@ -0,0 +1,48 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .init import ( + InitErrors, + InitErrorsTypeAdapter, + InitInit, + InitInput, + InitOutput, + InitOutputTypeAdapter, + encode_InitInit, + encode_InitInput, + encode_InitInputOneOf_sshToken0, + encode_InitInputOneOf_sshToken1, +) + + +class ReplspaceapiService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def init( + self, + init: InitInit, + inputStream: AsyncIterable[InitInput], + ) -> AsyncIterator[InitOutput | InitErrors | RiverError]: + return self.client.send_stream( + "replspaceApi", + "init", + init, + inputStream, + encode_InitInit, + encode_InitInput, + lambda x: InitOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: InitErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/init.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/init.py new file mode 100644 index 00000000..aecafb6c --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/replspaceApi/init.py @@ -0,0 +1,247 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_InitInit( + _: "InitInit", +) -> Any: + return {} + + +class InitInit(TypedDict): + pass + + +def encode_InitInputOneOf_closeFile( + x: "InitInputOneOf_closeFile", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "filename": x.get("filename"), + "nonce": x.get("nonce"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_closeFile(TypedDict): + kind: Annotated[Literal["closeFile"], Field(alias="$kind")] + filename: str + nonce: str + + +def encode_InitInputOneOf_githubToken( + x: "InitInputOneOf_githubToken", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "nonce": x.get("nonce"), + "token": x.get("token"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_githubToken(TypedDict): + kind: Annotated[Literal["githubToken"], Field(alias="$kind")] + nonce: str + token: NotRequired[str | None] + + +def encode_InitInputOneOf_sshToken0( + x: "InitInputOneOf_sshToken0", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "SSHHostname": x.get("SSHHostname"), + "nonce": x.get("nonce"), + "token": x.get("token"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_sshToken0(TypedDict): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] + SSHHostname: str + nonce: str + token: str + + +def encode_InitInputOneOf_sshToken1( + x: "InitInputOneOf_sshToken1", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "error": 
x.get("error"), + "nonce": x.get("nonce"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_sshToken1(TypedDict): + kind: Annotated[Literal["sshToken"], Field(alias="$kind")] + error: str + nonce: str + + +def encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_error( + x: "InitInputOneOf_allowDefaultBucketAccessResultOneOf_error", +) -> Any: + return { + k: v + for (k, v) in ( + { + "message": x.get("message"), + "status": x.get("status"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_error(TypedDict): + message: str + status: Literal["error"] + + +def encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok( + x: "InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok", +) -> Any: + return { + k: v + for (k, v) in ( + { + "bucketId": x.get("bucketId"), + "sourceReplId": x.get("sourceReplId"), + "status": x.get("status"), + "targetReplId": x.get("targetReplId"), + } + ).items() + if v is not None + } + + +class InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok(TypedDict): + bucketId: str + sourceReplId: str + status: Literal["ok"] + targetReplId: str + + +InitInputOneOf_allowDefaultBucketAccessResult = ( + InitInputOneOf_allowDefaultBucketAccessResultOneOf_error + | InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok +) + + +def encode_InitInputOneOf_allowDefaultBucketAccessResult( + x: "InitInputOneOf_allowDefaultBucketAccessResult", +) -> Any: + return ( + encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_error(x) + if x["status"] == "error" + else encode_InitInputOneOf_allowDefaultBucketAccessResultOneOf_ok(x) + ) + + +def encode_InitInputOneOf_allowDefaultBucketAccess( + x: "InitInputOneOf_allowDefaultBucketAccess", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "nonce": x.get("nonce"), + "result": encode_InitInputOneOf_allowDefaultBucketAccessResult( + x["result"] + ), + } + ).items() + if v is not None + } + + +class 
InitInputOneOf_allowDefaultBucketAccess(TypedDict): + kind: Annotated[Literal["allowDefaultBucketAccess"], Field(alias="$kind")] + nonce: str + result: InitInputOneOf_allowDefaultBucketAccessResult + + +InitInput = ( + InitInputOneOf_closeFile + | InitInputOneOf_githubToken + | InitInputOneOf_sshToken0 + | InitInputOneOf_sshToken1 + | InitInputOneOf_allowDefaultBucketAccess +) + + +def encode_InitInput( + x: "InitInput", +) -> Any: + return ( + encode_InitInputOneOf_closeFile(x) + if x["kind"] == "closeFile" + else encode_InitInputOneOf_githubToken(x) + if x["kind"] == "githubToken" + else ( + encode_InitInputOneOf_sshToken0(x) # type: ignore[arg-type] + if "token" in x + else encode_InitInputOneOf_sshToken1(x) # type: ignore[arg-type] + ) + if x["kind"] == "sshToken" + else encode_InitInputOneOf_allowDefaultBucketAccess(x) + ) + + +class InitOutput(BaseModel): + pass + + +InitOutputTypeAdapter: TypeAdapter[InitOutput] = TypeAdapter(InitOutput) + + +class InitErrors(RiverError): + pass + + +InitErrorsTypeAdapter: TypeAdapter[InitErrors] = TypeAdapter(InitErrors) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/__init__.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/__init__.py new file mode 100644 index 00000000..41f0d9e9 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/__init__.py @@ -0,0 +1,44 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .spawn import ( + SpawnErrors, + SpawnErrorsTypeAdapter, + SpawnInput, + SpawnOutput, + SpawnOutputTypeAdapter, + encode_SpawnInput, + encode_SpawnInputSize, +) + + +class ShellexecService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def spawn( + self, + input: SpawnInput, + timeout: datetime.timedelta, + ) -> SpawnOutput: + return await self.client.send_rpc( + "shellExec", + "spawn", + input, + encode_SpawnInput, + lambda x: SpawnOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: SpawnErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/spawn.py b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/spawn.py new file mode 100644 index 00000000..d4eb64b7 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/parity/typeddict_inputs/shellExec/spawn.py @@ -0,0 +1,94 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_SpawnInputSize( + x: "SpawnInputSize", +) -> Any: + return { + k: v + for (k, v) in ( + { + "cols": x.get("cols"), + "rows": x.get("rows"), + } + ).items() + if v is not None + } + + +class SpawnInputSize(TypedDict): + cols: int + rows: int + + +def encode_SpawnInput( + x: "SpawnInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "args": x.get("args"), + "autoCleanup": x.get("autoCleanup"), + "cmd": x.get("cmd"), + "cwd": x.get("cwd"), + "env": x.get("env"), + "initialCmd": x.get("initialCmd"), + "interactive": x.get("interactive"), + "size": encode_SpawnInputSize(x["size"]) + if "size" in x and x["size"] is not None + else None, + "useCgroupMagic": x.get("useCgroupMagic"), + "useReplitRunEnv": x.get("useReplitRunEnv"), + } + ).items() + if v is not None + } + + +class SpawnInput(TypedDict): + args: NotRequired[list[str] | None] + autoCleanup: NotRequired[bool | None] + cmd: str + cwd: NotRequired[str | None] + env: NotRequired[dict[str, str] | None] + initialCmd: NotRequired[str | None] + interactive: NotRequired[bool | None] + size: NotRequired[SpawnInputSize | None] + useCgroupMagic: NotRequired[bool | None] + useReplitRunEnv: NotRequired[bool | None] + + +class SpawnOutput(BaseModel): + pass + + +SpawnOutputTypeAdapter: TypeAdapter[SpawnOutput] = TypeAdapter(SpawnOutput) + + +class SpawnErrors(RiverError): + pass + + +SpawnErrorsTypeAdapter: TypeAdapter[SpawnErrors] = TypeAdapter(SpawnErrors) diff --git 
a/tests/codegen/snapshot/snapshots/test_basic_stream/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_basic_stream/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py similarity index 91% rename from tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py index 44d6c18c..0b106014 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py +++ b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py @@ -11,17 +11,18 @@ from .stream_method import ( Stream_MethodInput, - Stream_MethodInputTypeAdapter, Stream_MethodOutput, Stream_MethodOutputTypeAdapter, encode_Stream_MethodInput, ) -from .emit_error import Emit_ErrorErrors, Emit_ErrorErrorsTypeAdapter +from .emit_error import Emit_ErrorErrors -intTypeAdapter: TypeAdapter[int] = TypeAdapter(int) +boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool) -boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool) +Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter( + Emit_ErrorErrors +) class Test_ServiceService: diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py similarity index 90% rename from tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py index ddba3a38..e7005c29 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py +++ 
b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py @@ -38,8 +38,3 @@ class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError): | RiverUnknownError, WrapValidator(translate_unknown_error), ] - - -Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter( - Emit_ErrorErrors -) diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py similarity index 90% rename from tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py rename to tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 5baa9c40..1914aefc 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -40,11 +40,6 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( - Stream_MethodInput -) - - class Stream_MethodOutput(BaseModel): data: str diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_pathological_types/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_pathological_types/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py similarity index 97% rename from tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py index 3a578118..dd7f15e4 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py 
+++ b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py @@ -11,7 +11,6 @@ from .pathological_method import ( Pathological_MethodInput, - Pathological_MethodInputTypeAdapter, encode_Pathological_MethodInput, encode_Pathological_MethodInputObj_Boolean, encode_Pathological_MethodInputObj_Date, diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py similarity index 99% rename from tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py rename to tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py index 137add7b..2c325e64 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py +++ b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py @@ -476,8 +476,3 @@ class Pathological_MethodInput(TypedDict): req_undefined: None string: NotRequired[str | None] undefined: NotRequired[None] - - -Pathological_MethodInputTypeAdapter: TypeAdapter[Pathological_MethodInput] = ( - TypeAdapter(Pathological_MethodInput) -) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/__init__.py similarity index 100% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/__init__.py diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py similarity index 96% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py index ab9eaa08..e1067475 100644 --- 
a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py +++ b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py @@ -13,7 +13,6 @@ NeedsenumErrors, NeedsenumErrorsTypeAdapter, NeedsenumInput, - NeedsenumInputTypeAdapter, NeedsenumOutput, NeedsenumOutputTypeAdapter, encode_NeedsenumInput, @@ -22,7 +21,6 @@ NeedsenumobjectErrors, NeedsenumobjectErrorsTypeAdapter, NeedsenumobjectInput, - NeedsenumobjectInputTypeAdapter, NeedsenumobjectOutput, NeedsenumobjectOutputTypeAdapter, encode_NeedsenumobjectInput, diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py similarity index 94% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 8f325775..69b976b5 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -29,8 +29,6 @@ def encode_NeedsenumInput(x: "NeedsenumInput") -> Any: return x -NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) - NeedsenumOutput = Annotated[ Literal["out_first", "out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py similarity index 96% rename from tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py rename to tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 4e1243a3..dd61a2d7 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ 
b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -77,11 +77,6 @@ def encode_NeedsenumobjectInput( ) -NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( - NeedsenumobjectInput -) - - class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): kind: Annotated[Literal["out_first"], Field(alias="$kind")] = "out_first" foo: int diff --git a/tests/codegen/snapshot/test_enum.py b/tests/v1/codegen/snapshot/test_enum.py similarity index 89% rename from tests/codegen/snapshot/test_enum.py rename to tests/v1/codegen/snapshot/test_enum.py index e45509d2..72a3f2e2 100644 --- a/tests/codegen/snapshot/test_enum.py +++ b/tests/v1/codegen/snapshot/test_enum.py @@ -4,7 +4,7 @@ from pytest_snapshot.plugin import Snapshot -from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen test_unknown_enum_schema = """ { @@ -173,15 +173,17 @@ def test_unknown_enum(snapshot: Snapshot) -> None: validate_codegen( snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", read_schema=lambda: StringIO(test_unknown_enum_schema), target_path="test_unknown_enum", client_name="foo", + protocol_version="v1.1", ) - import tests.codegen.snapshot.snapshots.test_unknown_enum + import tests.v1.codegen.snapshot.snapshots.test_unknown_enum - importlib.reload(tests.codegen.snapshot.snapshots.test_unknown_enum) - from tests.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnum import ( # noqa + importlib.reload(tests.v1.codegen.snapshot.snapshots.test_unknown_enum) + from tests.v1.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnum import ( # noqa NeedsenumErrorsTypeAdapter, ) @@ -209,15 +211,17 @@ def test_unknown_enum(snapshot: Snapshot) -> None: def test_unknown_enum_field_aliases(snapshot: Snapshot) -> None: validate_codegen( snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", 
read_schema=lambda: StringIO(test_unknown_enum_schema), target_path="test_unknown_enum", client_name="foo", + protocol_version="v1.1", ) - import tests.codegen.snapshot.snapshots.test_unknown_enum + import tests.v1.codegen.snapshot.snapshots.test_unknown_enum - importlib.reload(tests.codegen.snapshot.snapshots.test_unknown_enum) - from tests.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnumObject import ( # noqa + importlib.reload(tests.v1.codegen.snapshot.snapshots.test_unknown_enum) + from tests.v1.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnumObject import ( # noqa NeedsenumobjectOutputTypeAdapter, NeedsenumobjectOutput, NeedsenumobjectOutputFooOneOf_out_first, diff --git a/tests/codegen/snapshot/test_pathological.py b/tests/v1/codegen/snapshot/test_pathological.py similarity index 51% rename from tests/codegen/snapshot/test_pathological.py rename to tests/v1/codegen/snapshot/test_pathological.py index 68775c8f..789a2d4c 100644 --- a/tests/codegen/snapshot/test_pathological.py +++ b/tests/v1/codegen/snapshot/test_pathological.py @@ -1,12 +1,14 @@ from pytest_snapshot.plugin import Snapshot -from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen async def test_pathological_types(snapshot: Snapshot) -> None: validate_codegen( snapshot=snapshot, - read_schema=lambda: open("tests/codegen/types/schema.json"), + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/types/schema.json"), target_path="test_pathological_types", client_name="PathologicalClient", + protocol_version="v1.1", ) diff --git a/tests/codegen/stream/schema.json b/tests/v1/codegen/stream/schema.json similarity index 100% rename from tests/codegen/stream/schema.json rename to tests/v1/codegen/stream/schema.json diff --git a/tests/codegen/stream/test_stream.py b/tests/v1/codegen/stream/test_stream.py similarity index 74% rename 
from tests/codegen/stream/test_stream.py rename to tests/v1/codegen/stream/test_stream.py index f3966043..2ac02a76 100644 --- a/tests/codegen/stream/test_stream.py +++ b/tests/v1/codegen/stream/test_stream.py @@ -5,8 +5,8 @@ from pytest_snapshot.plugin import Snapshot from replit_river.client import Client, RiverUnknownError -from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen -from tests.common_handlers import basic_stream, error_stream +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v1.common_handlers import basic_stream, error_stream _AlreadyGenerated = False @@ -17,15 +17,17 @@ def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: if not _AlreadyGenerated: validate_codegen( snapshot=snapshot, - read_schema=lambda: open("tests/codegen/stream/schema.json"), + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/stream/schema.json"), target_path="test_basic_stream", client_name="StreamClient", + protocol_version="v1.1", ) _AlreadyGenerated = True - import tests.codegen.snapshot.snapshots.test_basic_stream + import tests.v1.codegen.snapshot.snapshots.test_basic_stream - importlib.reload(tests.codegen.snapshot.snapshots.test_basic_stream) + importlib.reload(tests.v1.codegen.snapshot.snapshots.test_basic_stream) return True @@ -34,10 +36,10 @@ async def test_basic_stream( stream_client_codegen: Literal[True], client: Client, ) -> None: - from tests.codegen.snapshot.snapshots.test_basic_stream import ( + from tests.v1.codegen.snapshot.snapshots.test_basic_stream import ( StreamClient, # noqa: E501 ) - from tests.codegen.snapshot.snapshots.test_basic_stream.test_service.stream_method import ( # noqa: E501 + from tests.v1.codegen.snapshot.snapshots.test_basic_stream.test_service.stream_method import ( # noqa: E501 Stream_MethodInput, Stream_MethodOutput, ) @@ -63,7 +65,7 @@ async def test_error_stream( erroringClient: Client, phase: int, ) -> None: - from 
tests.codegen.snapshot.snapshots.test_basic_stream import ( + from tests.v1.codegen.snapshot.snapshots.test_basic_stream import ( StreamClient, # noqa: E501 ) diff --git a/tests/codegen/test_rpc.py b/tests/v1/codegen/test_rpc.py similarity index 78% rename from tests/codegen/test_rpc.py rename to tests/v1/codegen/test_rpc.py index 8ab82095..55837190 100644 --- a/tests/codegen/test_rpc.py +++ b/tests/v1/codegen/test_rpc.py @@ -14,39 +14,40 @@ from replit_river.codegen.client import schema_to_river_client_codegen from replit_river.error_schema import RiverException from replit_river.rpc import rpc_method_handler -from tests.common_handlers import basic_rpc_method from tests.conftest import HandlerMapping, deserialize_request, serialize_response +from tests.v1.common_handlers import basic_rpc_method @pytest.fixture(scope="session", autouse=True) def generate_rpc_client() -> None: - shutil.rmtree("tests/codegen/rpc/generated", ignore_errors=True) - os.makedirs("tests/codegen/rpc/generated") + shutil.rmtree("tests/v1/codegen/rpc/generated", ignore_errors=True) + os.makedirs("tests/v1/codegen/rpc/generated") def file_opener(path: Path) -> TextIO: return open(path, "w") schema_to_river_client_codegen( - read_schema=lambda: open("tests/codegen/rpc/schema.json"), - target_path="tests/codegen/rpc/generated", + read_schema=lambda: open("tests/v1/codegen/rpc/schema.json"), + target_path="tests/v1/codegen/rpc/generated", client_name="RpcClient", typed_dict_inputs=True, file_opener=file_opener, method_filter=None, + protocol_version="v1.1", ) @pytest.fixture(scope="session", autouse=True) def reload_rpc_import(generate_rpc_client: None) -> None: - import tests.codegen.rpc.generated + import tests.v1.codegen.rpc.generated - importlib.reload(tests.codegen.rpc.generated) + importlib.reload(tests.v1.codegen.rpc.generated) @pytest.mark.asyncio @pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) async def test_basic_rpc(client: Client) -> None: - from 
tests.codegen.rpc.generated import RpcClient + from tests.v1.codegen.rpc.generated import RpcClient res = await RpcClient(client).test_service.rpc_method( { @@ -75,7 +76,7 @@ async def rpc_timeout_handler(request: str, context: grpc.aio.ServicerContext) - @pytest.mark.asyncio @pytest.mark.parametrize("handlers", [{**rpc_timeout_method}]) async def test_rpc_timeout(client: Client) -> None: - from tests.codegen.rpc.generated import RpcClient + from tests.v1.codegen.rpc.generated import RpcClient with pytest.raises(RiverException): await RpcClient(client).test_service.rpc_method( diff --git a/tests/codegen/types/schema.json b/tests/v1/codegen/types/schema.json similarity index 100% rename from tests/codegen/types/schema.json rename to tests/v1/codegen/types/schema.json diff --git a/tests/common_handlers.py b/tests/v1/common_handlers.py similarity index 100% rename from tests/common_handlers.py rename to tests/v1/common_handlers.py diff --git a/tests/river_fixtures/clientserver.py b/tests/v1/river_fixtures/clientserver.py similarity index 97% rename from tests/river_fixtures/clientserver.py rename to tests/v1/river_fixtures/clientserver.py index bf576b5c..c01a9ee3 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/v1/river_fixtures/clientserver.py @@ -9,7 +9,7 @@ from replit_river.server import Server from replit_river.transport_options import TransportOptions from tests.conftest import HandlerMapping -from tests.river_fixtures.logging import NoErrors # noqa: E402 +from tests.v1.river_fixtures.logging import NoErrors # noqa: E402 @pytest.fixture diff --git a/tests/river_fixtures/logging.py b/tests/v1/river_fixtures/logging.py similarity index 100% rename from tests/river_fixtures/logging.py rename to tests/v1/river_fixtures/logging.py diff --git a/tests/test_communication.py b/tests/v1/test_communication.py similarity index 99% rename from tests/test_communication.py rename to tests/v1/test_communication.py index 87de1811..470b368b 100644 --- 
a/tests/test_communication.py +++ b/tests/v1/test_communication.py @@ -9,12 +9,6 @@ from replit_river.error_schema import RiverError from replit_river.rpc import subscription_method_handler from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE -from tests.common_handlers import ( - basic_rpc_method, - basic_stream, - basic_subscription, - basic_upload, -) from tests.conftest import ( HandlerMapping, deserialize_error, @@ -23,6 +17,12 @@ serialize_request, serialize_response, ) +from tests.v1.common_handlers import ( + basic_rpc_method, + basic_stream, + basic_subscription, + basic_upload, +) @pytest.mark.asyncio diff --git a/tests/test_handshake.py b/tests/v1/test_handshake.py similarity index 100% rename from tests/test_handshake.py rename to tests/v1/test_handshake.py diff --git a/tests/test_message_buffer.py b/tests/v1/test_message_buffer.py similarity index 98% rename from tests/test_message_buffer.py rename to tests/v1/test_message_buffer.py index 02a21ccb..d5d1bda4 100644 --- a/tests/test_message_buffer.py +++ b/tests/v1/test_message_buffer.py @@ -11,7 +11,7 @@ def mock_transport_message(seq: int) -> TransportMessage: seq=seq, id="test", ack=0, - from_="test", # type: ignore + from_="test", to="test", streamId="test", controlFlags=0, diff --git a/tests/test_opentelemetry.py b/tests/v1/test_opentelemetry.py similarity index 98% rename from tests/test_opentelemetry.py rename to tests/v1/test_opentelemetry.py index 801b133d..c47b5418 100644 --- a/tests/test_opentelemetry.py +++ b/tests/v1/test_opentelemetry.py @@ -11,12 +11,6 @@ from replit_river.client import Client from replit_river.error_schema import RiverError, RiverException from replit_river.rpc import stream_method_handler -from tests.common_handlers import ( - basic_rpc_method, - basic_stream, - basic_subscription, - basic_upload, -) from tests.conftest import ( HandlerMapping, deserialize_error, @@ -25,7 +19,13 @@ serialize_request, serialize_response, ) -from tests.river_fixtures.logging 
import NoErrors +from tests.v1.common_handlers import ( + basic_rpc_method, + basic_stream, + basic_subscription, + basic_upload, +) +from tests.v1.river_fixtures.logging import NoErrors @pytest.mark.asyncio diff --git a/tests/test_rate_limiter.py b/tests/v1/test_rate_limiter.py similarity index 100% rename from tests/test_rate_limiter.py rename to tests/v1/test_rate_limiter.py diff --git a/tests/test_seq_manager.py b/tests/v1/test_seq_manager.py similarity index 97% rename from tests/test_seq_manager.py rename to tests/v1/test_seq_manager.py index fba7f58a..1373e0f9 100644 --- a/tests/test_seq_manager.py +++ b/tests/v1/test_seq_manager.py @@ -6,7 +6,7 @@ SeqManager, ) from tests.conftest import transport_message -from tests.river_fixtures.logging import NoErrors +from tests.v1.river_fixtures.logging import NoErrors @pytest.mark.asyncio diff --git a/tests/test_timeout.py b/tests/v1/test_timeout.py similarity index 97% rename from tests/test_timeout.py rename to tests/v1/test_timeout.py index 94e033fd..1ac3db97 100644 --- a/tests/test_timeout.py +++ b/tests/v1/test_timeout.py @@ -7,15 +7,15 @@ from replit_river.client import Client from replit_river.error_schema import ERROR_CODE_CANCEL, RiverException -from tests.common_handlers import ( - rpc_method_handler, -) from tests.conftest import ( HandlerMapping, deserialize_error, deserialize_response, serialize_response, ) +from tests.v1.common_handlers import ( + rpc_method_handler, +) async def rpc_slow_handler(duration: float, context: grpc.aio.ServicerContext) -> str: diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/__init__.py new file mode 100644 index 00000000..58d21c8a --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/__init__.py @@ -0,0 +1,13 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .test_service import Test_ServiceService + + +class StreamClient: + def __init__(self, client: river.v2.Client[Literal[None]]): + self.test_service = Test_ServiceService(client) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/__init__.py new file mode 100644 index 00000000..f7c04990 --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/__init__.py @@ -0,0 +1,41 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .rpc_method import ( + Rpc_MethodInit, + Rpc_MethodOutput, + Rpc_MethodOutputTypeAdapter, + encode_Rpc_MethodInit, +) + + +class Test_ServiceService: + def __init__(self, client: river.v2.Client[Any]): + self.client = client + + async def rpc_method( + self, + init: Rpc_MethodInit, + timeout: datetime.timedelta, + ) -> Rpc_MethodOutput: + return await self.client.send_rpc( + "test_service", + "rpc_method", + init, + encode_Rpc_MethodInit, + lambda x: Rpc_MethodOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: RiverErrorTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py new file mode 100644 index 00000000..5fe764eb --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py @@ -0,0 +1,49 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_Rpc_MethodInit( + x: "Rpc_MethodInit", +) -> Any: + return { + k: v + for (k, v) in ( + { + "data": x.get("data"), + } + ).items() + if v is not None + } + + +class Rpc_MethodInit(TypedDict): + data: str + + +class Rpc_MethodOutput(BaseModel): + data: str + + +Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( + Rpc_MethodOutput +) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/__init__.py new file mode 100644 index 00000000..58d21c8a --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/__init__.py @@ -0,0 +1,13 @@ +# Code generated by river.codegen. DO NOT EDIT. +from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .test_service import Test_ServiceService + + +class StreamClient: + def __init__(self, client: river.v2.Client[Literal[None]]): + self.test_service = Test_ServiceService(client) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py new file mode 100644 index 00000000..90080831 --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py @@ -0,0 +1,44 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .stream_method import ( + Stream_MethodInit, + Stream_MethodInput, + Stream_MethodOutput, + Stream_MethodOutputTypeAdapter, + encode_Stream_MethodInit, + encode_Stream_MethodInput, +) + + +class Test_ServiceService: + def __init__(self, client: river.v2.Client[Any]): + self.client = client + + async def stream_method( + self, + init: Stream_MethodInit, + inputStream: AsyncIterable[Stream_MethodInput], + ) -> AsyncIterator[Stream_MethodOutput | RiverError | RiverError]: + return self.client.send_stream( + "test_service", + "stream_method", + init, + inputStream, + encode_Stream_MethodInit, + encode_Stream_MethodInput, + lambda x: Stream_MethodOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: RiverErrorTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py new file mode 100644 index 00000000..e7005c29 --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py @@ -0,0 +1,40 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class Emit_ErrorErrorsOneOf_DATA_LOSS(RiverError): + code: Literal["DATA_LOSS"] + message: str + + +class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError): + code: Literal["UNEXPECTED_DISCONNECT"] + message: str + + +Emit_ErrorErrors = Annotated[ + Emit_ErrorErrorsOneOf_DATA_LOSS + | Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT + | RiverUnknownError, + WrapValidator(translate_unknown_error), +] diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py new file mode 100644 index 00000000..9a0fc5d0 --- /dev/null +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -0,0 +1,59 @@ +# Code generated by river.codegen. DO NOT EDIT. 
+from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +def encode_Stream_MethodInit( + _: "Stream_MethodInit", +) -> Any: + return {} + + +class Stream_MethodInit(TypedDict): + pass + + +def encode_Stream_MethodInput( + x: "Stream_MethodInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "data": x.get("data"), + } + ).items() + if v is not None + } + + +class Stream_MethodInput(TypedDict): + data: str + + +class Stream_MethodOutput(BaseModel): + data: str + + +Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( + Stream_MethodOutput +) diff --git a/tests/v2/datagrams.py b/tests/v2/datagrams.py new file mode 100644 index 00000000..78feb447 --- /dev/null +++ b/tests/v2/datagrams.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass +from typing import ( + Any, + Literal, + NewType, + TypeAlias, +) + +Datagram = dict[str, Any] +TestTransport: TypeAlias = "FromClient | ToClient | WaitForClosed" + +StreamId = NewType("StreamId", str) +ClientId = NewType("ClientId", str) +ServerId = NewType("ServerId", str) +SessionId = NewType("SessionId", str) + + +@dataclass(frozen=True) +class StreamAlias: + alias_id: int + + +@dataclass(frozen=True) +class ValueSet: + seq: int + ack: int + from_: ServerId | None = None + to: ClientId | None = None + procedureName: str | None = None + serviceName: str | None = None + create_alias: StreamAlias | None = None + stream_alias: StreamAlias | None = None + payload: Datagram | None = None + stream_closed: Literal[True] | None = None + + +@dataclass(frozen=True) +class 
FromClient: + handshake_request: tuple[ClientId, ServerId, SessionId] | ValueSet | None = None + stream_open: ( + tuple[ClientId, ServerId, str, str, StreamId, Datagram] | ValueSet | None + ) = None + stream_frame: tuple[ClientId, ServerId, int, int, Datagram] | ValueSet | None = None + stream_closed: Literal[True] | None = None + + +@dataclass(frozen=True) +class ToClient: + seq: int + ack: int + control_flags: int = 0 + handshake_response: bool | None = None + stream_frame: tuple[StreamAlias, Datagram] | None = None + stream_close: StreamAlias | None = None + + +@dataclass(frozen=True) +class WaitForClosed: + pass + + +def decode_FromClient(datagram: dict[str, Any]) -> FromClient: + assert "from" in datagram + assert "to" in datagram + if datagram.get("payload", {}).get("type") == "HANDSHAKE_REQ": + assert "payload" in datagram + assert "sessionId" in datagram["payload"] + return FromClient( + handshake_request=( + ClientId(datagram["from"]), + ServerId(datagram["to"]), + SessionId(datagram["payload"]["sessionId"]), + ) + ) + elif datagram.get("controlFlags", 0) & 0b00010 > 0: # STREAM_OPEN_BIT + return FromClient( + stream_open=( + ClientId(datagram["from"]), + ServerId(datagram["to"]), + datagram["serviceName"], + datagram["procedureName"], + StreamId(datagram["streamId"]), + datagram["payload"], + ), + stream_closed=( + datagram["controlFlags"] & 0b01000 > 0 # STREAM_CLOSED_BIT + ) + or None, + ) + elif datagram: + return FromClient( + stream_frame=( + ClientId(datagram["from"]), + ServerId(datagram["to"]), + datagram["seq"], + datagram["ack"], + datagram["payload"], + ) + ) + raise ValueError("Unexpected datagram: %r", datagram) + + +def parser(datagram: dict[str, Any]) -> FromClient: + return decode_FromClient(datagram) diff --git a/tests/v2/fixtures.py b/tests/v2/fixtures.py new file mode 100644 index 00000000..dda934fc --- /dev/null +++ b/tests/v2/fixtures.py @@ -0,0 +1,209 @@ +import asyncio +import copy +from collections import deque +from typing import 
( + Any, + AsyncGenerator, + AsyncIterator, + TypeAlias, + assert_never, +) + +import msgpack +import pytest +from aiochannel import Channel +from websockets.asyncio.server import ServerConnection, serve + +from replit_river.transport_options import TransportOptions, UriAndMetadata +from replit_river.v2.client import Client +from tests.v2.datagrams import ( + FromClient, + TestTransport, + ToClient, + WaitForClosed, + parser, +) +from tests.v2.interpreter import make_interpreters + +# Client -> Server +ClientToServerChannel: TypeAlias = Channel[Any] +# Server -> Client +ServerToClientChannel: TypeAlias = Channel[Any] + + +@pytest.fixture +async def raw_websocket_meta() -> AsyncIterator[ + tuple[ServerToClientChannel, ClientToServerChannel, UriAndMetadata[None]] +]: + # server -> client + server_to_client = Channel[dict[str, Any]](maxsize=1) + # client -> server + client_to_server = Channel[dict[str, Any]](maxsize=1) + + # Service websocket connection + async def handle(websocket: ServerConnection) -> None: + async def emit() -> None: + async for msg in client_to_server: + await websocket.send(msgpack.packb(msg)) + + emit_task = asyncio.create_task(emit()) + try: + while message := await websocket.recv(decode=False): + assert isinstance(message, bytes) + unpacked = msgpack.unpackb(message, timestamp=3) + await server_to_client.put(unpacked) + finally: + server_to_client.close() + client_to_server.close() + emit_task.cancel() + await emit_task + return None + + ipv4_laddr: str | None = None + async with serve(handle, "localhost") as server: + for sock in server.sockets: + if (pair := sock.getsockname())[0] == "127.0.0.1": + ipv4_laddr = "ws://%s:%d" % pair + serve_forever = asyncio.create_task(server.serve_forever()) + assert ipv4_laddr + yield ( + server_to_client, + client_to_server, + UriAndMetadata(uri=ipv4_laddr, metadata=None), + ) + + serve_forever.cancel() + + +@pytest.fixture +async def bound_client( + raw_websocket_meta: tuple[ + ClientToServerChannel, + 
ServerToClientChannel, + UriAndMetadata[None], + ], + expected: deque[TestTransport], +) -> AsyncGenerator[Client[None], None]: + # Do our best to not modify the test data + _expected = expected + expected = copy.deepcopy(_expected) + + # client-to-server handler + # + # Consume FromClient events, optionally emitting datagrams to be written directly + # back to the client. + # + # This represents the direct request-response flow, but does not handle + # server-emitted events. + + client_to_server, server_to_client, uri_and_metadata = raw_websocket_meta + + async def messages_from_client() -> None: + async for msg in client_to_server: + parsed = parser(msg) + try: + next_expected = await anext(messages_from_client_channel) + except StopAsyncIteration: + break + response = from_client_interpreter(received=parsed, expected=next_expected) + if response is not None: + await server_to_client.put(response) + processing_finished.set() + + server_task = asyncio.create_task(messages_from_client()) + + # server-to-client handler + # + # Consume ToClient events, optionally emitting datagrams to be written directly to + # the client. + # + # This represents the other half of the "server" lifecycle, where the server can + # choose to emit events directly without a request. 
+ + messages_to_client_channel = Channel[ToClient]() + + async def messages_to_client() -> None: + our_task = asyncio.current_task() + while our_task and not our_task.cancelled(): + try: + next_expected = await anext(messages_to_client_channel) + except StopAsyncIteration: + break + + response = to_client_interpreter(next_expected) + if response is not None: + await server_to_client.put(response) + processing_finished.set() + + processor_task = asyncio.create_task(messages_to_client()) + + # This consumes from the "expected" queue and routes messages to waiting channels + # + # This also handles shutdown + + async def driver() -> None: + while expected: + next_expected = expected.popleft() + if isinstance(next_expected, FromClient): + await messages_from_client_channel.put(next_expected) + elif isinstance(next_expected, ToClient): + await messages_to_client_channel.put(next_expected) + elif isinstance(next_expected, WaitForClosed): + countdown = 100 + messages_to_client_channel.close() + while not processor_task.done(): + if countdown <= 0: + break + countdown -= 1 + await asyncio.sleep(0.1) + client_to_server.close() + await client.close() + break + else: + assert_never(next_expected) + await asyncio.sleep(0.1) + await processing_finished.wait() + + driver_task = asyncio.create_task(driver()) + + async def uri_and_metadata_factory() -> UriAndMetadata[None]: + return uri_and_metadata + + client = Client( + uri_and_metadata_factory=uri_and_metadata_factory, + client_id="client-001", + server_id="server-001", + transport_options=TransportOptions( + close_session_check_interval_ms=500, + ), + ) + + from_client_interpreter, to_client_interpreter = make_interpreters() + + processing_finished = asyncio.Event() + + messages_from_client_channel = Channel[FromClient]() + + yield client + + await driver_task + + assert len(expected) == 0, "Unconsumed messages from 'expected'" + assert messages_to_client_channel.qsize() == 0, ( + "Dangling messages the client has not 
consumed" + ) + assert messages_from_client_channel.qsize() == 0, ( + "Dangling messages the processor has not consumed" + ) + + messages_to_client_channel.close() + messages_from_client_channel.close() + + processing_finished.set() + + server_task.cancel() + processor_task.cancel() + + await client.close() + await server_task + await processor_task diff --git a/tests/v2/interpreter.py b/tests/v2/interpreter.py new file mode 100644 index 00000000..d706316c --- /dev/null +++ b/tests/v2/interpreter.py @@ -0,0 +1,200 @@ +from typing import ( + Any, + Callable, + NotRequired, + Protocol, + TypedDict, +) + +import nanoid + +from tests.v2.datagrams import ( + ClientId, + FromClient, + ServerId, + StreamAlias, + StreamId, + ToClient, + ValueSet, +) + +Datagram = dict[str, Any] + + +class FromClientInterpreter(Protocol): + def __call__( + self, *, received: FromClient, expected: FromClient + ) -> Datagram | None: ... + + +class _TestTransportState(TypedDict): + streams: dict[StreamAlias, tuple[ClientId, ServerId, StreamId]] + client_id: NotRequired[str] + server_id: NotRequired[str] + + +def make_interpreters() -> tuple[ + FromClientInterpreter, + Callable[[ToClient], Datagram | None], +]: + state: _TestTransportState = { + "streams": {}, + } + + def from_client_interpreter( + received: FromClient, + expected: FromClient, + ) -> Datagram | None: + if isinstance(received.handshake_request, tuple): + (from_, to, session_id) = received.handshake_request + assert isinstance(expected.handshake_request, ValueSet), ( + f"Expected ValueSet 1: {repr(received)}, {repr(expected)}" + ) + assert expected.handshake_request.from_, ( + f"Expected {expected.handshake_request.from_}" + ) + assert expected.handshake_request.to, ( + "Expected {expected.handshake_request.to}" + ) + assert expected.handshake_request.from_ == to, ( + f"Expected {expected.handshake_request.from_} == {to}" + ) + assert expected.handshake_request.to == from_, ( + "Expected {expected.handshake_request.to} == 
{from_}" + ) + return _build_datagram( + seq=expected.handshake_request.seq, + packet_id=nanoid.generate(), + stream_id=nanoid.generate(), + control_flags=0, + ack=expected.handshake_request.ack, + client_id=expected.handshake_request.from_, + server_id=expected.handshake_request.to, + payload=_build_handshake_resp(session_id), + ) + elif isinstance(received.stream_open, tuple): + (from_, to, service_name, procedure_name, stream_id, payload) = ( + received.stream_open + ) + assert isinstance(expected.stream_open, ValueSet), "Expected ValueSet 2" + assert expected.stream_open.from_ == to, ( + f"Expected {expected.stream_open.from_} == {to}" + ) + assert expected.stream_open.to == from_, ( + f"Expected {expected.stream_open.to} == {from_}" + ) + assert expected.stream_open.serviceName == service_name, ( + f"Expected {expected.stream_open.serviceName} == {service_name}" + ) + assert expected.stream_open.procedureName == procedure_name, ( + f"Expected {expected.stream_open.procedureName} == {procedure_name}" + ) + assert expected.stream_open.create_alias, ( + "Expected create_alias to be a StreamAlias" + ) + if expected.stream_open.payload is not None and payload is not None: + assert expected.stream_open.payload == payload, ( + f"Expected {expected.stream_open.payload} == {payload}" + ) + assert expected.stream_open.stream_closed or not received.stream_closed, ( + f"Are we self-closing? 
{expected.stream_open.stream_closed} " + f"or not {received.stream_closed}" + ) + # Do it all again because mypy can't infer correctly + alias_mapping: tuple[ClientId, ServerId, StreamId] = ( + ClientId(from_), + ServerId(to), + StreamId(stream_id), + ) + state["streams"][expected.stream_open.create_alias] = alias_mapping + return None + elif isinstance(received.stream_frame, tuple): + (from_, to, seq, ack, payload) = received.stream_frame + assert isinstance(expected.stream_frame, ValueSet), "Expected ValueSet 3" + assert seq == expected.stream_frame.seq, ( + f"Expected seq {seq} == {expected.stream_frame.seq}" + ) + assert ack == expected.stream_frame.ack, ( + f"Expected ack {ack} == {expected.stream_frame.ack}" + ) + assert expected.stream_frame.stream_alias, "Expected stream_alias" + (from_, to, stream_id) = state["streams"][ + expected.stream_frame.stream_alias + ] + assert expected.stream_frame.payload == payload, ( + f"Expected {expected.stream_frame.payload} == {payload}" + ) + return None + raise ValueError("Unexpected from_client case: %r", received) + + def to_client_interpreter(expected: ToClient) -> Datagram | None: + if expected.stream_frame is not None: + (stream_alias, payload) = expected.stream_frame + (from_, to, stream_id) = state["streams"][stream_alias] + return _build_datagram( + seq=expected.seq, + ack=expected.ack, + control_flags=expected.control_flags, + packet_id=nanoid.generate(), + stream_id=stream_id, + client_id=from_, + server_id=to, + payload=payload, + ) + elif expected.stream_close is not None: + stream_alias = expected.stream_close + (from_, to, stream_id) = state["streams"][stream_alias] + return _build_datagram( + seq=expected.seq, + ack=expected.ack, + control_flags=0b01000, + packet_id=nanoid.generate(), + stream_id=stream_id, + client_id=from_, + server_id=to, + payload={"type": "CLOSE"}, + ) + raise ValueError("Unexpected to_client case: %r", expected) + + return from_client_interpreter, to_client_interpreter + + +def 
_strip_none(datagram: Datagram) -> Datagram: + return {k: v for k, v in datagram.items() if v is not None} + + +def _build_handshake_resp(session_id: str) -> Datagram: + return _strip_none( + { + "type": "HANDSHAKE_RESP", + "status": { + "ok": True, + "sessionId": session_id, + }, + } + ) + + +def _build_datagram( + *, + packet_id: str, + stream_id: str, + server_id: str, + client_id: str, + control_flags: int, + seq: int, + ack: int, + payload: Datagram, +) -> Datagram: + return _strip_none( + { + "id": packet_id, + "from": server_id, + "to": client_id, + "seq": seq, + "ack": ack, + "streamId": stream_id, + "controlFlags": control_flags, + "payload": payload, + } + ) diff --git a/tests/v2/test_rpc.schema.json b/tests/v2/test_rpc.schema.json new file mode 100644 index 00000000..d53a1878 --- /dev/null +++ b/tests/v2/test_rpc.schema.json @@ -0,0 +1,32 @@ +{ + "services": { + "test_service": { + "procedures": { + "rpc_method": { + "init": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "output": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "errors": { + "not": {} + }, + "type": "rpc" + } + } + } + } +} diff --git a/tests/v2/test_stream.schema.json b/tests/v2/test_stream.schema.json new file mode 100644 index 00000000..5b1f34f8 --- /dev/null +++ b/tests/v2/test_stream.schema.json @@ -0,0 +1,36 @@ +{ + "services": { + "test_service": { + "procedures": { + "stream_method": { + "init": { + "type": "object", + "properties": {} + }, + "input": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "output": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + }, + "errors": { + "not": {} + }, + "type": "stream" + } + } + } + } +} diff --git a/tests/v2/test_v2_rpc.py b/tests/v2/test_v2_rpc.py new file mode 100644 index 00000000..3f6fe645 --- 
/dev/null +++ b/tests/v2/test_v2_rpc.py @@ -0,0 +1,97 @@ +import importlib +import logging +from collections import deque +from datetime import timedelta +from typing import ( + Literal, +) + +import pytest +from pytest_snapshot.plugin import Snapshot + +from replit_river.v2.client import Client +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v2.datagrams import ( + ClientId, + FromClient, + ServerId, + StreamAlias, + TestTransport, + ToClient, + ValueSet, + WaitForClosed, +) + +logger = logging.getLogger(__name__) + +_AlreadyGenerated = False + + +@pytest.fixture(scope="function", autouse=True) +def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: + global _AlreadyGenerated + if not _AlreadyGenerated: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v2/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v2/test_rpc.schema.json"), + target_path="test_basic_rpc", + client_name="StreamClient", + protocol_version="v2.0", + ) + _AlreadyGenerated = True + + import tests.v2.codegen.snapshot.snapshots.test_basic_stream + + importlib.reload(tests.v2.codegen.snapshot.snapshots.test_basic_stream) + return True + + +rpc_expected: deque[TestTransport] = deque( + [ + FromClient( + handshake_request=ValueSet( + seq=0, # These don't count due to being during a handshake + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + ) + ), + FromClient( + stream_open=ValueSet( + seq=0, + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + serviceName="test_service", + procedureName="rpc_method", + create_alias=StreamAlias(1), + payload={"data": "foo"}, + stream_closed=True, + ) + ), + ToClient( + seq=0, + ack=1, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Hello, foo!"}}, + ), + ), + WaitForClosed(), + ] +) + + +@pytest.mark.parametrize("expected", [rpc_expected]) +async def test_rpc(bound_client: Client) -> None: + from 
tests.v2.codegen.snapshot.snapshots.test_basic_rpc import ( + StreamClient, + ) + + res = await StreamClient(bound_client).test_service.rpc_method( + init={"data": "foo"}, + timeout=timedelta(seconds=5), + ) + + assert res.data == "Hello, foo!" diff --git a/tests/v2/test_v2_stream.py b/tests/v2/test_v2_stream.py new file mode 100644 index 00000000..b76ff906 --- /dev/null +++ b/tests/v2/test_v2_stream.py @@ -0,0 +1,194 @@ +import importlib +import logging +from collections import deque +from typing import ( + AsyncIterable, + Literal, +) + +import pytest +from pytest_snapshot.plugin import Snapshot + +from replit_river.v2.client import Client +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen +from tests.v2.datagrams import ( + ClientId, + FromClient, + ServerId, + StreamAlias, + TestTransport, + ToClient, + ValueSet, + WaitForClosed, +) + +logger = logging.getLogger(__name__) + +_AlreadyGenerated = False + + +@pytest.fixture(scope="function", autouse=True) +def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: + global _AlreadyGenerated + if not _AlreadyGenerated: + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v2/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v2/test_stream.schema.json"), + target_path="test_basic_stream", + client_name="StreamClient", + protocol_version="v2.0", + ) + _AlreadyGenerated = True + + import tests.v2.codegen.snapshot.snapshots.test_basic_stream + + importlib.reload(tests.v2.codegen.snapshot.snapshots.test_basic_stream) + return True + + +stream_expected: deque[TestTransport] = deque( + [ + FromClient( + handshake_request=ValueSet( + seq=0, # These don't count due to being during a handshake + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + ) + ), + FromClient( + stream_open=ValueSet( + seq=0, + ack=0, + from_=ServerId("server-001"), + to=ClientId("client-001"), + serviceName="test_service", + procedureName="stream_method", + create_alias=StreamAlias(1), 
+ ) + ), + FromClient( + stream_frame=ValueSet( + seq=1, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "0"}, + ) + ), + ToClient( + seq=0, + ack=1, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 0"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=2, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "1"}, + ) + ), + ToClient( + seq=1, + ack=2, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 1"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=3, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "2"}, + ) + ), + ToClient( + seq=2, + ack=0, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 2"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=4, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "3"}, + ) + ), + ToClient( + seq=3, + ack=4, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 3"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=5, + ack=0, + stream_alias=StreamAlias(1), + payload={"data": "4"}, + ) + ), + ToClient( + seq=4, + ack=5, + stream_frame=( + StreamAlias(1), + {"ok": True, "payload": {"data": "Stream response for 4"}}, + ), + ), + FromClient( + stream_frame=ValueSet( + seq=6, + ack=0, + stream_alias=StreamAlias(1), + payload={"type": "CLOSE"}, + ) + ), + ToClient( + seq=5, + ack=0, + stream_close=StreamAlias(1), + ), + WaitForClosed(), + ] +) + + +@pytest.mark.parametrize("expected", [stream_expected]) +async def test_stream(bound_client: Client) -> None: + from tests.v2.codegen.snapshot.snapshots.test_basic_stream import ( + StreamClient, + ) + from tests.v2.codegen.snapshot.snapshots.test_basic_stream.test_service.stream_method import ( # noqa: E501 + Stream_MethodInput, + Stream_MethodOutput, + ) + + async def emit() -> AsyncIterable[Stream_MethodInput]: + for i in range(5): + data: Stream_MethodInput = 
{"data": str(i)} + yield data + + res = await StreamClient(bound_client).test_service.stream_method( + init={}, + inputStream=emit(), + ) + + i = 0 + async for datum in res: + assert isinstance(datum, Stream_MethodOutput) + assert f"Stream response for {i}" == datum.data, f"{i} == {datum.data}" + i = i + 1 + assert i == 5 diff --git a/uv.lock b/uv.lock index fcd7277e..4833a89f 100644 --- a/uv.lock +++ b/uv.lock @@ -627,7 +627,7 @@ requires-dist = [ { name = "protobuf", specifier = ">=5.28.3" }, { name = "pydantic", specifier = "==2.9.2" }, { name = "pydantic-core", specifier = ">=2.20.1" }, - { name = "websockets", specifier = ">=12.0" }, + { name = "websockets", specifier = ">=13.0,<14" }, ] [package.metadata.requires-dev]