Skip to content

Commit b474eee

Browse files
Add streaming support fo JsonRpcServer
1 parent 660b841 commit b474eee

File tree

2 files changed

+41
-20
lines changed

2 files changed

+41
-20
lines changed

pyk/src/pyk/rpc/rpc.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from dataclasses import dataclass
77
from functools import partial
88
from http.server import BaseHTTPRequestHandler, HTTPServer
9-
from typing import TYPE_CHECKING, Any, Final, NamedTuple
10-
9+
from typing import TYPE_CHECKING, NamedTuple
1110
from typing_extensions import Protocol
1211

1312
from ..cli.cli import Options
1413

1514
if TYPE_CHECKING:
1615
from collections.abc import Callable
1716
from pathlib import Path
17+
from typing import Any, Final, Iterator
1818

1919

2020
_LOGGER: Final = logging.getLogger(__name__)
@@ -86,7 +86,8 @@ class JsonRpcBatchRequest(NamedTuple):
8686
class JsonRpcResult(ABC):
8787

8888
@abstractmethod
89-
def encode(self) -> bytes: ...
89+
def encode(self) -> Iterator[bytes]:
90+
...
9091

9192

9293
@dataclass(frozen=True)
@@ -96,7 +97,7 @@ class JsonRpcError(JsonRpcResult):
9697
message: str
9798
id: str | int | None
9899

99-
def to_json(self) -> dict[str, Any]:
100+
def wrap_response(self) -> dict[str, Any]:
100101
return {
101102
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
102103
'error': {
@@ -106,32 +107,40 @@ def to_json(self) -> dict[str, Any]:
106107
'id': self.id,
107108
}
108109

109-
def encode(self) -> bytes:
110-
return json.dumps(self.to_json()).encode('ascii')
110+
def encode(self) -> Iterator[bytes]:
111+
yield json.dumps(self.wrap_response()).encode('ascii')
111112

112113

113114
@dataclass(frozen=True)
114115
class JsonRpcSuccess(JsonRpcResult):
115116
payload: Any
116117
id: Any
117118

118-
def to_json(self) -> dict[str, Any]:
119-
return {
120-
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
121-
'result': self.payload,
122-
'id': self.id,
123-
}
124-
125-
def encode(self) -> bytes:
126-
return json.dumps(self.to_json()).encode('ascii')
119+
def encode(self) -> Iterator[bytes]:
120+
yield f'{{"jsonrpc":"2.0", "id": {self.id}, "result": '.encode('ascii')
121+
if isinstance(self.payload, Iterator):
122+
for chunk in self.payload:
123+
yield chunk.encode('ascii')
124+
else:
125+
yield json.dumps(self.payload).encode('ascii')
126+
yield b'}'
127127

128128

129129
@dataclass(frozen=True)
130130
class JsonRpcBatchResult(JsonRpcResult):
131131
results: tuple[JsonRpcError | JsonRpcSuccess, ...]
132132

133-
def encode(self) -> bytes:
134-
return json.dumps([result.to_json() for result in self.results]).encode('ascii')
133+
def encode(self) -> Iterator[bytes]:
134+
yield b'['
135+
first = True
136+
for result in self.results:
137+
if not first:
138+
yield b','
139+
else:
140+
first = False
141+
for chunk in result.encode():
142+
yield chunk
143+
yield b']'
135144

136145

137146
class JsonRpcRequestHandler(BaseHTTPRequestHandler):
@@ -143,8 +152,10 @@ def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any)
143152

144153
def _send_response(self, response: JsonRpcResult) -> None:
145154
self.send_response_headers()
146-
response_bytes = response.encode()
147-
self.wfile.write(response_bytes)
155+
response_body = response.encode()
156+
for chunk in response_body:
157+
self.wfile.write(chunk)
158+
self.wfile.flush()
148159

149160
def send_response_headers(self) -> None:
150161
self.send_response(200)

pyk/src/tests/integration/test_json_rpc.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

3+
from collections.abc import Iterator
34
import json
45
from http.client import HTTPConnection
56
from threading import Thread
67
from time import sleep
7-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Iterator
89

910
from pyk.cterm import CTerm
1011
from pyk.kast.inner import KApply, KSequence, KSort, KToken
@@ -154,6 +155,7 @@ def __init__(self, options: ServeRpcOptions) -> None:
154155
self.register_method('set_x', self.exec_set_x)
155156
self.register_method('set_y', self.exec_set_y)
156157
self.register_method('add', self.exec_add)
158+
self.register_method('streaming', self.exec_streaming)
157159

158160
def exec_get_x(self) -> int:
159161
return self.x
@@ -169,6 +171,11 @@ def exec_set_y(self, n: int) -> None:
169171

170172
def exec_add(self) -> int:
171173
return self.x + self.y
174+
175+
def exec_streaming(self) -> Iterator[bytes]:
176+
yield b'{'
177+
yield b'"foo": "bar"'
178+
yield b'}'
172179

173180

174181
class TestJsonRPCServer(KRunTest):
@@ -221,6 +228,9 @@ def wait_until_ready() -> None:
221228
assert len(res) == 3
222229
assert res[2]['result'] == 1 + 2
223230

231+
res = rpc_client.request('streaming', [])
232+
assert res == {'foo': 'bar'}
233+
224234
server.shutdown()
225235
thread.join()
226236

0 commit comments

Comments
 (0)