Skip to content

Commit 8d56ecc

Browse files
committed
Make the client own key generation
1 parent 0c71c21 commit 8d56ecc

3 files changed

Lines changed: 58 additions & 46 deletions

File tree

ipykernel/kernelapp.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -188,19 +188,6 @@ def abs_connection_file(self):
188188
""",
189189
).tag(config=True)
190190

191-
enable_curve = Bool(
192-
bool(int(os.environ.get("JUPYTER_ENABLE_CURVE", "0"))),
193-
help="Enable CurveZMQ transport encryption and authentication. "
194-
"When True, a keypair is generated at startup and stored in the "
195-
"connection file so that clients can authenticate and encrypt "
196-
"all ZMQ channels.",
197-
).tag(config=True)
198-
199-
# Internal CurveZMQ keypair (Z85-encoded bytes); populated in init_sockets
200-
# when enable_curve is True.
201-
_curve_publickey: bytes | None = None
202-
_curve_secretkey: bytes | None = None
203-
204191
# polling
205192
parent_handle = Integer(
206193
int(os.environ.get("JPY_PARENT_PID") or 0),
@@ -227,12 +214,12 @@ def excepthook(self, etype, evalue, tb):
227214
def _apply_curve_server_options(self, socket: zmq.Socket[t.Any]) -> None:
228215
"""Set CurveZMQ server-side options on *socket* before it is bound.
229216
230-
This is a no-op when enable_curve is False or keys have not been
231-
generated yet, so it is safe to call unconditionally.
217+
This is a no-op when Curve keys are not available yet, so it is safe
218+
to call unconditionally.
232219
"""
233-
if self.enable_curve and self._curve_secretkey is not None:
234-
socket.curve_secretkey = self._curve_secretkey
235-
socket.curve_publickey = self._curve_publickey
220+
if self.curve_secretkey is not None:
221+
socket.curve_secretkey = self.curve_secretkey
222+
socket.curve_publickey = self.curve_publickey
236223
socket.curve_server = True
237224

238225
def init_poller(self):
@@ -298,10 +285,9 @@ def write_connection_file(self, **kwargs: Any) -> None:
298285
iopub_port=self.iopub_port,
299286
control_port=self.control_port,
300287
)
301-
if self.enable_curve and self._curve_publickey is not None:
302-
# write_connection_file() in jupyter-client handles JSON-safe key serialization
303-
connection_info["curve_publickey"] = self._curve_publickey
304-
connection_info["curve_secretkey"] = self._curve_secretkey
288+
if self.curve_publickey is not None:
289+
connection_info["curve_publickey"] = self.curve_publickey
290+
connection_info["curve_secretkey"] = self.curve_secretkey
305291
if Path(cf).exists():
306292
# If the file exists, merge our info into it. For example, if the
307293
# original file had port number 0, we update with the actual port
@@ -356,16 +342,15 @@ def init_sockets(self):
356342
self.context = context = zmq.Context()
357343
atexit.register(self.close)
358344

359-
if self.enable_curve:
360-
self._curve_publickey, self._curve_secretkey = zmq.curve_keypair()
361-
self.log.debug("CurveZMQ enabled; generated server keypair")
345+
if self.curve_secretkey is not None:
346+
self.log.debug("Detected CurveZMQ secret key; using transport encryption")
362347
elif self.transport == "tcp":
363348
self.log.warning(
364349
"Kernel is running over TCP without encryption."
365350
" All communication (including code and outputs) is sent in plain text"
366351
" and is susceptible to eavesdropping."
367-
" Use IPC transport or set IPKernelApp.enable_curve=True to enable"
368-
" CurveZMQ encryption."
352+
" Use IPC transport or launch with kernel manager-provisioned"
353+
" CurveZMQ keys to enable transport encryption."
369354
)
370355

371356
self.shell_socket = context.socket(zmq.ROUTER)
@@ -439,8 +424,8 @@ def init_heartbeat(self):
439424
self.heartbeat = Heartbeat(
440425
hb_ctx,
441426
(self.transport, self.ip, self.hb_port),
442-
curve_publickey=self._curve_publickey if self.enable_curve else None,
443-
curve_secretkey=self._curve_secretkey if self.enable_curve else None,
427+
curve_publickey=self.curve_publickey,
428+
curve_secretkey=self.curve_secretkey,
444429
)
445430
self.hb_port = self.heartbeat.port
446431
self.log.debug("Heartbeat REP Channel on port: %i", self.hb_port)

tests/test_curve.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88
import zmq
9+
from jupyter_client import KernelManager
910

1011
from ipykernel.kernelapp import IPKernelApp
1112

@@ -28,12 +29,6 @@ def curve_enabled_kernel_app(tmp_path):
2829
app.close()
2930

3031

31-
def test_curve_disabled_by_default():
32-
"""CurveZMQ must be off by default so existing kernels are unaffected."""
33-
app = IPKernelApp()
34-
assert app.enable_curve is False
35-
36-
3732
def test_connection_file_no_curve_keys_by_default(curve_disabled_kernel_app):
3833
"""Connection file must not contain curve keys when Curve is disabled."""
3934
app, connection_file_path = curve_disabled_kernel_app
@@ -65,14 +60,14 @@ def test_curve_connection_file_has_keys(curve_enabled_kernel_app):
6560

6661

6762
def test_curve_keys_are_stable_per_startup(curve_enabled_kernel_app):
68-
"""Keys generated at startup stay the same throughout the process lifetime."""
63+
"""Provisioned keys stay unchanged throughout the kernel process lifetime."""
6964
app, _connection_file_path = curve_enabled_kernel_app
7065
app.init_sockets()
71-
pub1 = app._curve_publickey
66+
pub1 = app.curve_publickey
7267
# Writing the file twice should not regenerate keys.
7368
app.init_heartbeat()
7469
app.write_connection_file()
75-
assert app._curve_publickey == pub1
70+
assert app.curve_publickey == pub1
7671

7772

7873
def test_curve_socket_server_options(curve_enabled_kernel_app):
@@ -134,7 +129,7 @@ def test_curve_authenticated_socket_can_communicate(curve_enabled_kernel_app):
134129
app.init_sockets()
135130

136131
endpoint = f"tcp://{app.ip}:{app.shell_port}"
137-
server_public = app._curve_publickey
132+
server_public = app.curve_publickey
138133

139134
ctx = zmq.Context()
140135
auth_client = ctx.socket(zmq.DEALER)
@@ -161,9 +156,33 @@ def test_curve_authenticated_socket_can_communicate(curve_enabled_kernel_app):
161156
ctx.term()
162157

163158

164-
def _make_app(tmp_path, **kwargs):
159+
def test_manager_provisioned_curve_keys_are_used(curve_enabled_kernel_app):
160+
"""Kernel uses manager-provisioned Curve keys exactly as provided."""
161+
app, _connection_file_path = curve_enabled_kernel_app
162+
try:
163+
with open(_connection_file_path) as f:
164+
info = json.load(f)
165+
166+
app.init_sockets()
167+
168+
assert app.curve_publickey == info["curve_publickey"].encode()
169+
assert app.curve_secretkey == info["curve_secretkey"].encode()
170+
assert app.shell_socket.curve_server
171+
assert app.stdin_socket.curve_server
172+
assert app.control_socket.curve_server
173+
finally:
174+
app.close()
175+
176+
177+
def _make_app(tmp_path, *, enable_curve=False, **kwargs):
165178
"""Return a minimal IPKernelApp rooted in temporary directory *tmp_path*."""
166179
connection_file_path = str(tmp_path / "kernel.json")
180+
if enable_curve:
181+
# Populate the Curve keys into the connection file
182+
km = KernelManager(connection_file=connection_file_path)
183+
km.transport_encryption = True
184+
km.pre_start_kernel()
185+
167186
app = IPKernelApp(connection_file=connection_file_path, **kwargs)
168187
# Replicate the subset of initialize() that sets up connection info
169188
# without importing IPython shell machinery.

tests/test_kernelapp.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from unittest.mock import patch
66

77
import pytest
8+
from jupyter_client import KernelManager
89
from jupyter_core.paths import secure_write
910
from traitlets.config.loader import Config
1011

@@ -131,25 +132,32 @@ def test_trio_loop():
131132
app.close()
132133

133134

134-
def test_init_sockets_curve_enabled_logs_debug():
135-
app = IPKernelApp(enable_curve=True)
135+
def test_init_sockets_curve_enabled_logs_debug(tmp_path):
136+
connection_file = str(tmp_path / "kernel.json")
137+
km = KernelManager(connection_file=connection_file)
138+
km.transport_encryption = True
139+
km.pre_start_kernel()
140+
141+
app = IPKernelApp(connection_file=connection_file)
142+
super(IPKernelApp, app).initialize(argv=[""])
143+
app.init_connection_file()
136144
with patch.object(app.log, "debug") as mock_debug:
137145
app.init_sockets()
138146
app.cleanup_connection_file()
139147
app.close()
140148
messages = [str(call) for call in mock_debug.call_args_list]
141-
assert any("CurveZMQ enabled" in m for m in messages), (
142-
"Expected a debug log mentioning CurveZMQ when enable_curve=True"
149+
assert any("Detected CurveZMQ secret key; using transport encryption" in m for m in messages), (
150+
"Expected a debug log mentioning CurveZMQ when keys are provided"
143151
)
144152

145153

146154
def test_init_sockets_tcp_without_curve_logs_warning():
147-
app = IPKernelApp(transport="tcp", enable_curve=False)
155+
app = IPKernelApp(transport="tcp")
148156
with patch.object(app.log, "warning") as mock_warning:
149157
app.init_sockets()
150158
app.cleanup_connection_file()
151159
app.close()
152160
messages = [str(call) for call in mock_warning.call_args_list]
153161
assert any("Kernel is running over TCP without encryption" in m for m in messages), (
154-
"Expected a warning about missing encryption when transport=tcp and enable_curve=False"
162+
"Expected a warning about missing encryption when transport=tcp without curve keys"
155163
)

0 commit comments

Comments
 (0)