Skip to content

Commit cb0ac8b

Browse files
committed
refactor: WIP, add add more methods
1 parent a56fcc4 commit cb0ac8b

File tree

8 files changed

+158
-154
lines changed

8 files changed

+158
-154
lines changed

binaries/cli/src/command/build/distributed.rs

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use communication_layer_request_reply::{TcpConnection, TcpRequestReplyConnection
22
use dora_core::descriptor::Descriptor;
33
use dora_message::{
44
BuildId,
5-
cli_to_coordinator::ControlRequest,
5+
cli_to_coordinator::{CliToCoordinatorClient, ControlRequest},
66
common::{GitSource, LogMessage},
77
coordinator_to_cli::ControlRequestReply,
88
id::NodeId,
@@ -23,33 +23,17 @@ pub fn build_distributed_dataflow(
2323
local_working_dir: Option<std::path::PathBuf>,
2424
uv: bool,
2525
) -> eyre::Result<BuildId> {
26-
let build_id = {
27-
let reply_raw = session
28-
.request(
29-
&serde_json::to_vec(&ControlRequest::Build {
30-
session_id: dataflow_session.session_id,
31-
dataflow,
32-
git_sources: git_sources.clone(),
33-
prev_git_sources: dataflow_session.git_sources.clone(),
34-
local_working_dir,
35-
uv,
36-
})
37-
.unwrap(),
38-
)
39-
.wrap_err("failed to send start dataflow message")?;
40-
41-
let result: ControlRequestReply =
42-
serde_json::from_slice(&reply_raw).wrap_err("failed to parse reply")?;
43-
match result {
44-
ControlRequestReply::DataflowBuildTriggered { build_id } => {
45-
eprintln!("dataflow build triggered: {build_id}");
46-
build_id
47-
}
48-
ControlRequestReply::Error(err) => bail!("{err}"),
49-
other => bail!("unexpected start dataflow reply: {other:?}"),
50-
}
51-
};
52-
Ok(build_id)
26+
let mut client: CliToCoordinatorClient<_, Vec<u8>, Vec<u8>, std::io::Error> =
27+
CliToCoordinatorClient::new(session);
28+
let build_id = client.build(dora_message::cli_to_coordinator::BuildReq {
29+
session_id: dataflow_session.session_id,
30+
dataflow,
31+
git_sources: git_sources.clone(),
32+
prev_git_sources: dataflow_session.git_sources.clone(),
33+
local_working_dir,
34+
uv,
35+
})?;
36+
Ok(build_id.build_id)
5337
}
5438

5539
pub fn wait_until_dataflow_built(

libraries/communication-layer/request-reply/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,13 @@ pub trait RequestReplyConnection: Send + Sync {
7575

7676
fn request(&mut self, request: &Self::RequestData) -> Result<Self::ReplyData, Self::Error>;
7777
}
78+
79+
impl<T: RequestReplyConnection + ?Sized> RequestReplyConnection for &mut T {
80+
type RequestData = T::RequestData;
81+
type ReplyData = T::ReplyData;
82+
type Error = T::Error;
83+
fn request(&mut self, request: &Self::RequestData) -> Result<Self::ReplyData, Self::Error> {
84+
(**self).request(request)
85+
}
86+
}
87+

libraries/dora-schema-macro/src/client.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pub fn generate_client(schema: &SchemaInput) -> proc_macro2::TokenStream {
3737

3838
match resp_enum {
3939
#response_enum::#variant(resp) => Ok(resp),
40+
#response_enum::Error(err) => Err(::eyre::eyre!("Server returned error: {}", err.msg)),
4041
_ => ::eyre::bail!("Unexpected response type"),
4142
}
4243
}
@@ -71,7 +72,7 @@ pub fn generate_client(schema: &SchemaInput) -> proc_macro2::TokenStream {
7172
ReplyData = std::vec::Vec<u8>,
7273
Error = E,
7374
>,
74-
E: std::marker::Send + std::marker::Sync + std::error::Error + 'static,
75+
E: std::marker::Send + std::marker::Sync + std::error::Error,
7576
{
7677
#(#client_methods)*
7778
}

libraries/dora-schema-macro/src/lib.rs

Lines changed: 2 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod syntax;
55

66
// dora-schema-macro/src/lib.rs
77
use proc_macro::TokenStream;
8-
use quote::{format_ident, quote};
8+
use quote::quote;
99

1010
use crate::syntax::SchemaInput;
1111

@@ -15,8 +15,7 @@ pub fn dora_schema(input: TokenStream) -> TokenStream {
1515

1616
let protocol_code = protocol::generate_protocol(&schema);
1717
let client_code = client::generate_client(&schema);
18-
// let server_trait_code = server::generate_server_trait(&schema);
19-
let server_trait_code = quote! {};
18+
let server_trait_code = server::generate_server_trait(&schema);
2019

2120
let expanded = quote! {
2221
#protocol_code
@@ -26,116 +25,3 @@ pub fn dora_schema(input: TokenStream) -> TokenStream {
2625

2726
TokenStream::from(expanded)
2827
}
29-
30-
fn generate_server(schema: &SchemaInput) -> proc_macro2::TokenStream {
31-
let client_name = &schema.client_name;
32-
let server_name = &schema.server_name;
33-
let handler_trait_name = format_ident!("{}Handler", server_name);
34-
let request_enum = format_ident!("{}To{}Request", client_name, server_name);
35-
let response_enum = format_ident!("{}To{}Response", client_name, server_name);
36-
37-
let trait_methods: Vec<_> = schema
38-
.methods
39-
.iter()
40-
.map(|m| {
41-
let handler_name = format_ident!("{}_handler", m.name);
42-
let request_type = &m.request;
43-
let response_type = &m.response;
44-
45-
// TODO: better
46-
quote! {
47-
fn #handler_name(
48-
&self,
49-
request: #request_type
50-
) -> std::pin::Pin<
51-
Box<dyn std::future::Future<Output = ::eyre::Result<#response_type>> + Send>
52-
>;
53-
}
54-
})
55-
.collect();
56-
57-
let dispatch_arms: Vec<_> = schema
58-
.methods
59-
.iter()
60-
.map(|m| {
61-
let handler_name = format_ident!("{}_handler", m.name);
62-
let request_variant = format_ident!("{}", capitalize(&m.name.to_string()));
63-
let response_variant = request_variant.clone();
64-
65-
quote! {
66-
#request_enum::#request_variant(req) => {
67-
match self.#handler_name(req).await {
68-
Ok(resp) => {
69-
let resp_enum = #response_enum::#response_variant(resp);
70-
::bincode::serialize(&resp_enum)
71-
.wrap_err("Failed to serialize response")
72-
}
73-
Err(e) => {
74-
::eyre::bail!("Handler error: {:?}", e)
75-
}
76-
}
77-
}
78-
}
79-
})
80-
.collect();
81-
82-
quote! {
83-
pub trait #handler_trait_name: Send + Sync {
84-
#(#trait_methods)*
85-
}
86-
87-
pub struct #server_name<H, L> {
88-
handler: H,
89-
listener: L,
90-
}
91-
92-
impl<H, L> #server_name<H, L>
93-
where
94-
H: #handler_trait_name + 'static,
95-
L: ListenConnection,
96-
{
97-
pub fn new(handler: H, listener: L) -> Self {
98-
Self { handler, listener }
99-
}
100-
101-
pub async fn serve(self) -> ::eyre::Result<()> {
102-
loop {
103-
let req_bytes = self.listener.receive()
104-
.await
105-
.wrap_err("Failed to receive request")?;
106-
107-
let request: #request_enum = match ::bincode::deserialize(&req_bytes) {
108-
Ok(req) => req,
109-
Err(e) => {
110-
eprintln!("Failed to deserialize request: {:?}", e);
111-
continue;
112-
}
113-
};
114-
115-
let response_bytes = match request {
116-
#(#dispatch_arms,)*
117-
};
118-
119-
match response_bytes {
120-
Ok(bytes) => {
121-
if let Err(e) = self.listener.send(bytes).await {
122-
eprintln!("Failed to send response: {:?}", e);
123-
}
124-
}
125-
Err(e) => {
126-
eprintln!("Handler error: {:?}", e);
127-
}
128-
}
129-
}
130-
}
131-
}
132-
}
133-
}
134-
135-
fn capitalize(s: &str) -> String {
136-
let mut chars = s.chars();
137-
match chars.next() {
138-
Some(first) => first.to_uppercase().chain(chars).collect(),
139-
None => String::new(),
140-
}
141-
}

libraries/dora-schema-macro/src/protocol.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,18 @@ pub fn generate_protocol(schema: &SchemaInput) -> proc_macro2::TokenStream {
6060
#[derive(Debug, ::serde::Serialize, ::serde::Deserialize)]
6161
pub enum #response_enum {
6262
#(#response_variants,)*
63+
Error(#error_struct),
6364
}
6465

6566
#[derive(Debug, ::serde::Serialize, ::serde::Deserialize)]
6667
pub struct #error_struct {
67-
pub message: String,
68-
pub source: Option<String>,
68+
pub msg: String,
69+
}
70+
71+
impl #error_struct {
72+
pub fn new(msg: String) -> Self {
73+
Self { msg }
74+
}
6975
}
7076
}
7177
}
Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,100 @@
1-
use crate::SchemaInput;
1+
use convert_case::{Case, Casing};
2+
use proc_macro2::Ident;
3+
use quote::{format_ident, quote};
4+
5+
use crate::{
6+
SchemaInput, protocol::{enum_variant_ident, error_struct_ident, request_enum_ident, response_enum_ident}
7+
};
8+
9+
pub fn server_trait_ident(schema: &SchemaInput) -> Ident {
10+
format_ident!("{}Handler", schema.protocol_name())
11+
}
12+
13+
pub fn handle_func_ident(schema: &SchemaInput) -> Ident {
14+
let snake_case_protocol_name =
15+
Casing::from_case(&schema.protocol_name(), Case::Pascal).to_case(Case::Snake);
16+
format_ident!("handle_{}", snake_case_protocol_name)
17+
}
18+
19+
pub fn method_handler_ident(method: &crate::syntax::MethodDef) -> Ident {
20+
format_ident!("{}_handler", method.name)
21+
}
222

323
pub fn generate_server_trait(schema: &SchemaInput) -> proc_macro2::TokenStream {
4-
todo!()
24+
let trait_name = server_trait_ident(schema);
25+
let request_enum = request_enum_ident(schema);
26+
let response_enum = response_enum_ident(schema);
27+
let error_type = error_struct_ident(schema);
28+
29+
let trait_methods: Vec<_> = schema
30+
.methods
31+
.iter()
32+
.map(|m| {
33+
let handler_name = method_handler_ident(m);
34+
let request_type = &m.request;
35+
let response_type = &m.response;
36+
37+
// TODO: better
38+
quote! {
39+
fn #handler_name(
40+
&self,
41+
request: #request_type
42+
) -> ::std::pin::Pin<
43+
Box<dyn ::std::future::Future<Output = ::std::result::Result<#response_type, #error_type>> + Send>
44+
>;
45+
}
46+
})
47+
.collect();
48+
49+
let dispatch_arms = schema
50+
.methods
51+
.iter()
52+
.map(|m| {
53+
let handler_name = method_handler_ident(m);
54+
let method_variant = enum_variant_ident(m);
55+
quote! {
56+
#request_enum::#method_variant(request) => {
57+
let response = handler.#handler_name(request).await;
58+
match response {
59+
Ok(resp) => #response_enum::#method_variant(resp),
60+
Err(err) => #response_enum::Error(err),
61+
}
62+
}
63+
}
64+
})
65+
.collect::<Vec<_>>();
66+
67+
let handle_func_name = handle_func_ident(schema);
68+
69+
// TODO: remove unwraps and handle errors properly
70+
let handle_func = quote! {
71+
pub async fn #handle_func_name(
72+
handler: impl #trait_name,
73+
request: ::std::vec::Vec<u8>
74+
) -> ::std::vec::Vec<u8> {
75+
let request_enum: #request_enum = match ::serde_json::from_slice(&request) {
76+
Ok(req) => req,
77+
Err(err) => {
78+
let error_response = #response_enum::Error(#error_type {
79+
msg: format!("Failed to deserialize request: {}", err),
80+
});
81+
return ::serde_json::to_vec(&error_response).unwrap();
82+
}
83+
};
84+
85+
let response_enum = match request_enum {
86+
#(#dispatch_arms)*
87+
};
88+
89+
::serde_json::to_vec(&response_enum).unwrap()
90+
}
91+
};
92+
93+
quote! {
94+
pub trait #trait_name: std::marker::Send + std::marker::Sync {
95+
#(#trait_methods)*
96+
}
97+
98+
#handle_func
99+
}
5100
}

libraries/message/src/cli_to_coordinator.rs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ dora_schema_macro::dora_schema! {
1414
Cli => Coordinator:
1515

1616
build: BuildReq => BuildResp;
17-
// wait_for_build: ControlRequestWaitForBuild => ControlRequestWaitForBuildReply;
18-
// start: ControlRequestStart => ControlRequestStartReply;
19-
// wait_for_spawn: ControlRequestWaitForSpawn => ControlRequestWaitForSpawnReply;
20-
// reload: ControlRequestReload => ControlRequestReloadReply;
17+
start: StartReq => StartResp;
2118
}
2219

2320
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -42,7 +39,30 @@ pub struct BuildResp {
4239
pub build_id: BuildId,
4340
}
4441

45-
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
42+
#[derive(Debug, Clone, Serialize, Deserialize)]
43+
pub struct StartReq {
44+
pub build_id: Option<BuildId>,
45+
pub session_id: SessionId,
46+
pub dataflow: Descriptor,
47+
pub name: Option<String>,
48+
/// Allows overwriting the base working dir when CLI and daemon are
49+
/// running on the same machine.
50+
///
51+
/// Must not be used for multi-machine dataflows.
52+
///
53+
/// Note that nodes with git sources still use a subdirectory of
54+
/// the base working dir.
55+
pub local_working_dir: Option<PathBuf>,
56+
pub uv: bool,
57+
pub write_events_to: Option<PathBuf>,
58+
}
59+
60+
#[derive(Debug, Clone, Serialize, Deserialize)]
61+
pub struct StartResp {
62+
pub uuid: Uuid,
63+
}
64+
65+
#[derive(Debug, Clone, Serialize, Deserialize)]
4666
pub enum ControlRequest {
4767
Build {
4868
session_id: SessionId,

0 commit comments

Comments
 (0)