Skip to content

Commit 71c384f

Browse files
authored
Merge pull request #95 from nikomatsakis/main
refactor(sacp): Session API improvements and proxy race condition fix
2 parents 706e739 + a4e7da3 commit 71c384f

25 files changed

+453
-203
lines changed

src/elizacp/tests/mcp_tool_invocation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async fn test_elizacp_mcp_tool_call() -> Result<(), sacp::Error> {
5959
))
6060
.await
6161
})
62-
.with_client(transport, async |client_cx| {
62+
.run_until(transport, async |client_cx| {
6363
// Initialize
6464
let _init_response = recv(client_cx.send_request(InitializeRequest {
6565
protocol_version: Default::default(),

src/sacp-conductor/src/conductor.rs

Lines changed: 66 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ use sacp::schema::{
122122
};
123123
use sacp::{Agent, Client, Component, Error, JrMessage};
124124
use sacp::{
125-
HasDefaultEndpoint, JrConnectionBuilder, JrConnectionCx, JrNotification, JrRequest,
126-
JrRequestCx, JrResponse, JrRole, MessageCx, UntypedMessage,
125+
JrConnectionBuilder, JrConnectionCx, JrEndpoint, JrNotification, JrRequest, JrRequestCx,
126+
JrResponse, JrRole, MessageCx, UntypedMessage,
127127
};
128128
use sacp::{
129129
JrMessageHandler, JrResponsePayload,
@@ -751,7 +751,7 @@ impl ConductorResponder {
751751
request: Req,
752752
) -> JrResponse<Req::Response> {
753753
if source_component_index == 0 {
754-
client.send_request(request)
754+
client.send_request_to(Client, request)
755755
} else {
756756
self.proxies[source_component_index - 1].send_request(SuccessorMessage {
757757
message: request,
@@ -783,7 +783,7 @@ impl ConductorResponder {
783783
);
784784
if source_component_index == 0 {
785785
tracing::debug!("Sending notification directly to client");
786-
client.send_notification(notification)
786+
client.send_notification_to(Client, notification)
787787
} else {
788788
tracing::debug!(
789789
target_proxy = source_component_index - 1,
@@ -829,43 +829,30 @@ impl ConductorResponder {
829829
proxies_count = self.proxies.len(),
830830
"Proxy mode: forwarding successor message to conductor's successor"
831831
);
832-
// Wrap the message as a successor message before sending
833-
let to_successor_message = message.map(
834-
|request, request_cx| {
835-
(
836-
SuccessorMessage {
837-
message: request,
838-
meta: None,
839-
},
840-
request_cx,
841-
)
842-
},
843-
|notification| SuccessorMessage {
844-
message: notification,
845-
meta: None,
846-
},
847-
);
848-
return connection_cx.send_proxied_message_to(Client, to_successor_message);
832+
return connection_cx.send_proxied_message_to_via(Agent, conductor_tx, message);
849833
}
850834

851835
tracing::debug!(?message, "forward_client_to_agent_message");
852836

853837
MatchMessageFrom::new(message, connection_cx)
854-
.if_request(async |request: InitializeProxyRequest, request_cx| {
855-
// Proxy forwarding InitializeProxyRequest to its successor
856-
tracing::debug!("forward_client_to_agent_message: InitializeProxyRequest");
857-
// Wrap the request_cx to convert InitializeResponse back to InitializeProxyResponse
858-
self.forward_initialize_request(
859-
target_component_index,
860-
conductor_tx,
861-
connection_cx,
862-
request.initialize,
863-
request_cx,
864-
)
865-
.await
866-
})
838+
.if_request_from(
839+
Client,
840+
async |request: InitializeProxyRequest, request_cx| {
841+
// Proxy forwarding InitializeProxyRequest to its successor
842+
tracing::debug!("forward_client_to_agent_message: InitializeProxyRequest");
843+
// Wrap the request_cx to convert InitializeResponse back to InitializeProxyResponse
844+
self.forward_initialize_request(
845+
target_component_index,
846+
conductor_tx,
847+
connection_cx,
848+
request.initialize,
849+
request_cx,
850+
)
851+
.await
852+
},
853+
)
867854
.await
868-
.if_request(async |request: InitializeRequest, request_cx| {
855+
.if_request_from(Client, async |request: InitializeRequest, request_cx| {
869856
// Direct InitializeRequest (shouldn't happen after initialization, but handle it)
870857
tracing::debug!("forward_client_to_agent_message: InitializeRequest");
871858
self.forward_initialize_request(
@@ -878,7 +865,7 @@ impl ConductorResponder {
878865
.await
879866
})
880867
.await
881-
.if_request(async |request: NewSessionRequest, request_cx| {
868+
.if_request_from(Client, async |request: NewSessionRequest, request_cx| {
882869
// When forwarding "session/new", we adjust MCP servers to manage "acp:" URLs.
883870
self.forward_session_new_request(
884871
target_component_index,
@@ -890,7 +877,8 @@ impl ConductorResponder {
890877
.await
891878
})
892879
.await
893-
.if_request(
880+
.if_request_from(
881+
Client,
894882
async |request: McpOverAcpMessage<UntypedMessage>, request_cx| {
895883
let McpOverAcpMessage {
896884
connection_id,
@@ -910,34 +898,40 @@ impl ConductorResponder {
910898
},
911899
)
912900
.await
913-
.if_notification(async |notification: McpOverAcpMessage<UntypedMessage>| {
914-
let McpOverAcpMessage {
915-
connection_id,
916-
message: mcp_notification,
917-
..
918-
} = notification;
919-
self.bridge_connections
920-
.get_mut(&connection_id)
921-
.ok_or_else(|| {
922-
sacp::util::internal_error(format!(
923-
"unknown connection id: {}",
924-
connection_id
925-
))
926-
})?
927-
.send(MessageCx::Notification(mcp_notification))
928-
.await
929-
})
901+
.if_notification_from(
902+
Client,
903+
async |notification: McpOverAcpMessage<UntypedMessage>| {
904+
let McpOverAcpMessage {
905+
connection_id,
906+
message: mcp_notification,
907+
..
908+
} = notification;
909+
self.bridge_connections
910+
.get_mut(&connection_id)
911+
.ok_or_else(|| {
912+
sacp::util::internal_error(format!(
913+
"unknown connection id: {}",
914+
connection_id
915+
))
916+
})?
917+
.send(MessageCx::Notification(mcp_notification))
918+
.await
919+
},
920+
)
930921
.await
931922
.otherwise(async |message| {
932923
// Otherwise, just send the message along "as is".
933924
if target_component_index == self.proxies.len() {
934925
self.agent
935926
.as_ref()
936927
.expect("targeting agent")
937-
.send_proxied_message_via(conductor_tx, message)
928+
.send_proxied_message_to_via(Agent, conductor_tx, message)
938929
} else {
939-
self.proxies[target_component_index]
940-
.send_proxied_message_via(conductor_tx, message)
930+
self.proxies[target_component_index].send_proxied_message_to_via(
931+
Agent,
932+
conductor_tx,
933+
message,
934+
)
941935
}
942936
})
943937
.await
@@ -1485,27 +1479,32 @@ pub enum ConductorMessage {
14851479
},
14861480
}
14871481

1488-
trait JrConnectionCxExt {
1489-
fn send_proxied_message_via(
1482+
trait JrConnectionCxExt<Role: JrRole> {
1483+
fn send_proxied_message_to_via<End: JrEndpoint>(
14901484
&self,
1485+
end: End,
14911486
conductor_tx: &mpsc::Sender<ConductorMessage>,
14921487
message: MessageCx,
1493-
) -> Result<(), sacp::Error>;
1488+
) -> Result<(), sacp::Error>
1489+
where
1490+
Role: sacp::HasEndpoint<End>;
14941491
}
14951492

1496-
impl<Role: HasDefaultEndpoint + sacp::HasEndpoint<<Role as JrRole>::HandlerEndpoint>>
1497-
JrConnectionCxExt for JrConnectionCx<Role>
1498-
{
1499-
fn send_proxied_message_via(
1493+
impl<Role: JrRole> JrConnectionCxExt<Role> for JrConnectionCx<Role> {
1494+
fn send_proxied_message_to_via<End: JrEndpoint>(
15001495
&self,
1496+
end: End,
15011497
conductor_tx: &mpsc::Sender<ConductorMessage>,
15021498
message: MessageCx,
1503-
) -> Result<(), sacp::Error> {
1499+
) -> Result<(), sacp::Error>
1500+
where
1501+
Role: sacp::HasEndpoint<End>,
1502+
{
15041503
match message {
15051504
MessageCx::Request(request, request_cx) => self
1506-
.send_request(request)
1505+
.send_request_to(end, request)
15071506
.forward_response_via(conductor_tx, request_cx),
1508-
MessageCx::Notification(notification) => self.send_notification(notification),
1507+
MessageCx::Notification(notification) => self.send_notification_to(end, notification),
15091508
}
15101509
}
15111510
}

src/sacp-conductor/src/conductor/mcp_bridge/actor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ impl McpBridgeConnectionActor {
6464
)
6565
// When we receive messages from the conductor, forward them to the MCP client
6666
.connect_to(transport)?
67-
.with_client(async move |mcp_client_cx| {
67+
.run_until(async move |mcp_client_cx| {
6868
let mut to_mcp_client_rx = to_mcp_client_rx;
6969
while let Some(message) = to_mcp_client_rx.next().await {
7070
mcp_client_cx.send_proxied_message_to(McpServerEnd, message)?;

src/sacp-conductor/tests/initialization_sequence.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async fn run_test_with_components(
136136
))
137137
.await
138138
})
139-
.with_client(transport, editor_task)
139+
.run_until(transport, editor_task)
140140
.await
141141
}
142142

src/sacp-conductor/tests/mcp-integration.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ async fn run_test_with_mode(
6363
))
6464
.await
6565
})
66-
.with_client(transport, editor_task)
66+
.run_until(transport, editor_task)
6767
.await
6868
}
6969

@@ -206,7 +206,7 @@ async fn test_agent_handles_prompt() -> Result<(), sacp::Error> {
206206
],
207207
Default::default(),
208208
))?
209-
.with_client(async |editor_cx| {
209+
.run_until(async |editor_cx| {
210210
// Initialize
211211
recv(editor_cx.send_request(InitializeRequest {
212212
protocol_version: Default::default(),

src/sacp-conductor/tests/mcp_server_handler_chain.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ async fn run_test(
167167
))
168168
.await
169169
})
170-
.with_client(transport, editor_task)
170+
.run_until(transport, editor_task)
171171
.await
172172
}
173173

src/sacp-conductor/tests/scoped_mcp_server.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async fn test_scoped_mcp_server_through_proxy() -> Result<(), sacp::Error> {
5252
async fn test_scoped_mcp_server_through_session() -> Result<(), sacp::Error> {
5353
ClientToAgent::builder()
5454
.connect_to(Conductor::new("conductor".to_string(), vec![ElizaAgent::new()], McpBridgeMode::default()))?
55-
.with_client(async |cx| {
55+
.run_until(async |cx| {
5656
// Initialize first
5757
cx.send_request(sacp::schema::InitializeRequest {
5858
protocol_version: Default::default(),
@@ -67,7 +67,8 @@ async fn test_scoped_mcp_server_through_session() -> Result<(), sacp::Error> {
6767
let result = cx
6868
.build_session(".")
6969
.with_mcp_server(make_mcp_server(&collected_values))?
70-
.run_session(async |mut active_session| {
70+
.block_task()
71+
.run_until(async |mut active_session| {
7172
active_session
7273
.send_prompt(r#"Use tool test::push with {"elements": ["Hello", "world"]}"#)?;
7374
active_session.read_to_string().await

src/sacp-tokio/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ JrConnection::to_agent(agent)?
2424
println!("Agent update: {:?}", notif);
2525
Ok(())
2626
})
27-
.with_client(|cx| async move {
27+
.run_until(|cx| async move {
2828
// Initialize and interact with the agent
2929
let response = cx.send_request(InitializeRequest { ... })
3030
.block_task()

src/sacp-tokio/src/acp_agent.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ pub enum LineDirection {
6666
/// // The agent process will be spawned automatically when served
6767
/// UntypedRole::builder()
6868
/// .connect_to(agent)?
69-
/// .with_client(|cx| async move {
69+
/// .run_until(|cx| async move {
7070
/// // Use the connection to communicate with the agent process
7171
/// Ok(())
7272
/// })

src/sacp-tokio/tests/debug_logging.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async fn test_acp_agent_debug_callback() -> Result<(), Box<dyn std::error::Error
6868
))
6969
.await
7070
})
71-
.with_client(transport, async |client_cx| {
71+
.run_until(transport, async |client_cx| {
7272
// Send an initialize request
7373
let _init_response = recv(client_cx.send_request(InitializeRequest {
7474
protocol_version: Default::default(),

0 commit comments

Comments
 (0)