Skip to content

Commit 5bff52a

Browse files
sampan-s-nayaksampan
andauthored
[core] Configure an interceptor to pass auth token in python direct g… (#58395)
## Description there are places in the python code where we use the raw grpc library to make grpc calls (eg: pub-sub, some calls to gcs etc). In the long term we want to fully deprecate grpc library usage in our python code base but as that can take more effort and testing, in this pr I am introducing an interceptor to add auth headers (this will take effect for all grpc calls made from python). ## Testing ### case 1: submitting a job using CLI ``` export RAY_auth_mode="token" export RAY_AUTH_TOKEN="abcdef1234567890abcdef123456789" ray start --head ray job submit -- echo "hi" ``` output ``` ray job submit -- echo "hi" 2025-11-04 06:28:09,122 - INFO - NumExpr defaulting to 4 threads. Job submission server address: http://127.0.0.1:8265 ------------------------------------------------------- Job 'raysubmit_1EV8q86uKM24nHmH' submitted successfully ------------------------------------------------------- Next steps Query the logs of the job: ray job logs raysubmit_1EV8q86uKM24nHmH Query the status of the job: ray job status raysubmit_1EV8q86uKM24nHmH Request the job to be stopped: ray job stop raysubmit_1EV8q86uKM24nHmH Tailing logs until the job exits (disable with --no-wait): 2025-11-04 06:28:10,363 INFO job_manager.py:568 -- Runtime env is setting up. hi Running entrypoint for job raysubmit_1EV8q86uKM24nHmH: echo hi ------------------------------------------ Job 'raysubmit_1EV8q86uKM24nHmH' succeeded ------------------------------------------ ``` ### case 2: submitting a job with actors and tasks and verifying on dashboard test.py ```python import time import ray from ray._raylet import Config ray.init() @ray.remote def print_hi(): print("Hi") time.sleep(2) @ray.remote class SimpleActor: def __init__(self): self.value = 0 def increment(self): self.value += 1 return self.value actor = SimpleActor.remote() result = ray.get(actor.increment.remote()) for i in range(100): ray.get(print_hi.remote()) time.sleep(20) ray.shutdown() ``` ``` export RAY_auth_mode="token" export RAY_AUTH_TOKEN="abcdef1234567890abcdef123456789" python test.py ``` ### dashboard screenshots: #### promts user to input the token <img width="1720" height="1073" alt="image" src="https://github.com/user-attachments/assets/008829d8-51b6-445a-b135-5f76b6ccf292" /> ### on passing the right token: overview page <img width="1720" height="1073" alt="image" src="https://github.com/user-attachments/assets/cece0da7-0edd-4438-9d60-776526b49762" /> job page: tasks are listed <img width="1720" height="1073" alt="image" src="https://github.com/user-attachments/assets/b98eb1d9-cacc-45ea-b0e2-07ce8922202a" /> task page <img width="1720" height="1073" alt="image" src="https://github.com/user-attachments/assets/09ff38e1-e151-4e34-8651-d206eb8b5136" /> actors page <img width="1720" height="1073" alt="image" src="https://github.com/user-attachments/assets/10a30b3d-3f7e-4f3d-b669-962056579459" /> specific actor page <img width="1720" height="1073" alt="image" src="https://github.com/user-attachments/assets/ab1915bd-3d1b-4813-8101-a219432a55c0" /> --------- Signed-off-by: sampan <[email protected]> Co-authored-by: sampan <[email protected]>
1 parent 71c7bd0 commit 5bff52a

File tree

5 files changed

+213
-5
lines changed

5 files changed

+213
-5
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""gRPC client interceptor for token-based authentication."""
2+
3+
import logging
4+
from collections import namedtuple
5+
from typing import Tuple
6+
7+
import grpc
8+
from grpc import aio as aiogrpc
9+
10+
from ray._raylet import AuthenticationTokenLoader
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
# Named tuple to hold client call details
16+
_ClientCallDetails = namedtuple(
17+
"_ClientCallDetails",
18+
("method", "timeout", "metadata", "credentials", "wait_for_ready", "compression"),
19+
)
20+
21+
22+
def _get_authentication_metadata_tuple() -> Tuple[Tuple[str, str], ...]:
23+
"""Get gRPC metadata tuple for authentication. Currently only supported for token authentication.
24+
25+
Returns:
26+
tuple: Empty tuple or ((AUTHORIZATION_HEADER_NAME, "Bearer <token>"),)
27+
"""
28+
token_loader = AuthenticationTokenLoader.instance()
29+
if not token_loader.has_token():
30+
return ()
31+
32+
headers = token_loader.get_token_for_http_header()
33+
34+
# Convert HTTP header dict to gRPC metadata tuple
35+
# gRPC expects: (("key", "value"), ...)
36+
return tuple((k, v) for k, v in headers.items())
37+
38+
39+
class AuthenticationMetadataClientInterceptor(
40+
grpc.UnaryUnaryClientInterceptor,
41+
grpc.UnaryStreamClientInterceptor,
42+
grpc.StreamUnaryClientInterceptor,
43+
grpc.StreamStreamClientInterceptor,
44+
):
45+
"""Synchronous gRPC client interceptor that adds authentication metadata."""
46+
47+
def _intercept_call_details(self, client_call_details):
48+
"""Helper method to add authentication metadata to client call details."""
49+
metadata = list(client_call_details.metadata or [])
50+
metadata.extend(_get_authentication_metadata_tuple())
51+
52+
return _ClientCallDetails(
53+
method=client_call_details.method,
54+
timeout=client_call_details.timeout,
55+
metadata=metadata,
56+
credentials=client_call_details.credentials,
57+
wait_for_ready=getattr(client_call_details, "wait_for_ready", None),
58+
compression=getattr(client_call_details, "compression", None),
59+
)
60+
61+
def intercept_unary_unary(self, continuation, client_call_details, request):
62+
new_details = self._intercept_call_details(client_call_details)
63+
return continuation(new_details, request)
64+
65+
def intercept_unary_stream(self, continuation, client_call_details, request):
66+
new_details = self._intercept_call_details(client_call_details)
67+
return continuation(new_details, request)
68+
69+
def intercept_stream_unary(
70+
self, continuation, client_call_details, request_iterator
71+
):
72+
new_details = self._intercept_call_details(client_call_details)
73+
return continuation(new_details, request_iterator)
74+
75+
def intercept_stream_stream(
76+
self, continuation, client_call_details, request_iterator
77+
):
78+
new_details = self._intercept_call_details(client_call_details)
79+
return continuation(new_details, request_iterator)
80+
81+
82+
class AsyncAuthenticationMetadataClientInterceptor(
83+
aiogrpc.UnaryUnaryClientInterceptor,
84+
aiogrpc.UnaryStreamClientInterceptor,
85+
aiogrpc.StreamUnaryClientInterceptor,
86+
aiogrpc.StreamStreamClientInterceptor,
87+
):
88+
"""Async gRPC client interceptor that adds authentication metadata."""
89+
90+
def _intercept_call_details(self, client_call_details):
91+
"""Helper method to add authentication metadata to client call details."""
92+
metadata = list(client_call_details.metadata or [])
93+
metadata.extend(_get_authentication_metadata_tuple())
94+
95+
return _ClientCallDetails(
96+
method=client_call_details.method,
97+
timeout=client_call_details.timeout,
98+
metadata=metadata,
99+
credentials=client_call_details.credentials,
100+
wait_for_ready=getattr(client_call_details, "wait_for_ready", None),
101+
compression=getattr(client_call_details, "compression", None),
102+
)
103+
104+
async def intercept_unary_unary(self, continuation, client_call_details, request):
105+
new_details = self._intercept_call_details(client_call_details)
106+
return await continuation(new_details, request)
107+
108+
async def intercept_unary_stream(self, continuation, client_call_details, request):
109+
new_details = self._intercept_call_details(client_call_details)
110+
return await continuation(new_details, request)
111+
112+
async def intercept_stream_unary(
113+
self, continuation, client_call_details, request_iterator
114+
):
115+
new_details = self._intercept_call_details(client_call_details)
116+
return await continuation(new_details, request_iterator)
117+
118+
async def intercept_stream_stream(
119+
self, continuation, client_call_details, request_iterator
120+
):
121+
new_details = self._intercept_call_details(client_call_details)
122+
return await continuation(new_details, request_iterator)

python/ray/_private/authentication/http_token_authentication.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from types import ModuleType
33
from typing import Dict, List, Optional
44

5-
from ray._private.authentication import authentication_constants
6-
from ray.dashboard import authentication_utils as auth_utils
5+
from ray._private.authentication import (
6+
authentication_constants,
7+
authentication_utils as auth_utils,
8+
)
79

810
logger = logging.getLogger(__name__)
911

python/ray/_private/utils.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,7 @@ def init_grpc_channel(
10261026
import grpc
10271027
from grpc import aio as aiogrpc
10281028

1029+
from ray._private.authentication import authentication_utils
10291030
from ray._private.tls_utils import load_certs_from_env
10301031

10311032
grpc_module = aiogrpc if asynchronous else grpc
@@ -1040,16 +1041,43 @@ def init_grpc_channel(
10401041
)
10411042
options = options_dict.items()
10421043

1043-
if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
1044+
# Build interceptors list
1045+
interceptors = []
1046+
if authentication_utils.is_token_auth_enabled():
1047+
from ray._private.authentication.grpc_authentication_client_interceptor import (
1048+
AsyncAuthenticationMetadataClientInterceptor,
1049+
AuthenticationMetadataClientInterceptor,
1050+
)
1051+
1052+
if asynchronous:
1053+
interceptors.append(AsyncAuthenticationMetadataClientInterceptor())
1054+
else:
1055+
interceptors.append(AuthenticationMetadataClientInterceptor())
1056+
1057+
# Create channel with TLS if enabled
1058+
use_tls = os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true")
1059+
if use_tls:
10441060
server_cert_chain, private_key, ca_cert = load_certs_from_env()
10451061
credentials = grpc.ssl_channel_credentials(
10461062
certificate_chain=server_cert_chain,
10471063
private_key=private_key,
10481064
root_certificates=ca_cert,
10491065
)
1050-
channel = grpc_module.secure_channel(address, credentials, options=options)
1066+
channel_creator = grpc_module.secure_channel
1067+
base_args = (address, credentials)
1068+
else:
1069+
channel_creator = grpc_module.insecure_channel
1070+
base_args = (address,)
1071+
1072+
# Create channel (async channels get interceptors in constructor, sync via intercept_channel)
1073+
if asynchronous:
1074+
channel = channel_creator(
1075+
*base_args, options=options, interceptors=interceptors
1076+
)
10511077
else:
1052-
channel = grpc_module.insecure_channel(address, options=options)
1078+
channel = channel_creator(*base_args, options=options)
1079+
if interceptors:
1080+
channel = grpc.intercept_channel(channel, *interceptors)
10531081

10541082
return channel
10551083

python/ray/tests/test_token_auth_integration.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,5 +342,61 @@ def worker_joined():
342342
_cleanup_ray_start(env)
343343

344344

345+
def test_e2e_operations_with_token_auth(setup_cluster_with_token_auth):
346+
"""Test that e2e operations work with token authentication enabled.
347+
348+
This verifies that with token auth enabled:
349+
1. Job submission works
350+
2. Tasks execute successfully
351+
3. Actors can be created and called
352+
"""
353+
cluster_info = setup_cluster_with_token_auth
354+
355+
# Test 1: Submit a simple task
356+
@ray.remote
357+
def simple_task(x):
358+
return x + 1
359+
360+
result = ray.get(simple_task.remote(41))
361+
assert result == 42, f"Task should return 42, got {result}"
362+
363+
# Test 2: Create and use an actor
364+
@ray.remote
365+
class SimpleActor:
366+
def __init__(self):
367+
self.value = 0
368+
369+
def increment(self):
370+
self.value += 1
371+
return self.value
372+
373+
actor = SimpleActor.remote()
374+
result = ray.get(actor.increment.remote())
375+
assert result == 1, f"Actor method should return 1, got {result}"
376+
377+
# Test 3: Submit a job and wait for completion
378+
from ray.job_submission import JobSubmissionClient
379+
380+
# Create job submission client (uses HTTP with auth headers)
381+
client = JobSubmissionClient(address=cluster_info["dashboard_url"])
382+
383+
# Submit a simple job
384+
job_id = client.submit_job(
385+
entrypoint="echo 'Hello from job'",
386+
)
387+
388+
# Wait for job to complete
389+
def job_finished():
390+
status = client.get_job_status(job_id)
391+
return status in ["SUCCEEDED", "FAILED", "STOPPED"]
392+
393+
wait_for_condition(job_finished, timeout=30)
394+
395+
final_status = client.get_job_status(job_id)
396+
assert (
397+
final_status == "SUCCEEDED"
398+
), f"Job should succeed, got status: {final_status}"
399+
400+
345401
if __name__ == "__main__":
346402
sys.exit(pytest.main(["-vv", __file__]))

0 commit comments

Comments
 (0)