Skip to content

Commit d5b806a

Browse files
committed
accept messages with type arguments
1 parent 12dc94e commit d5b806a

File tree

2 files changed

+140
-30
lines changed

2 files changed

+140
-30
lines changed

src/lib.rs

Lines changed: 120 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -408,17 +408,19 @@ extern crate syn;
408408
#[macro_use]
409409
extern crate quote;
410410

411-
use std::collections::HashMap;
411+
use std::collections::{HashMap, HashSet};
412412
use std::fs::{File, OpenOptions, create_dir};
413413
use std::io::{Seek, Write};
414414

415415
use case::CaseExt;
416416
use syn::export::Span;
417+
use syn::punctuated::Pair;
417418
use syn::parse::{Parse, ParseStream, Result};
418419
use syn::{
419420
Abi, Attribute, Expr, FnArg, FnDecl, Generics, Ident, ItemEnum, MethodSig, ReturnType, Type,
420-
WhereClause,
421+
WhereClause, PathArguments, GenericArgument,
421422
};
423+
use quote::ToTokens;
422424

423425
struct Machine {
424426
attributes: Vec<Attribute>,
@@ -561,7 +563,7 @@ struct Transitions {
561563
#[derive(Debug)]
562564
struct Transition {
563565
pub start: Ident,
564-
pub message: Ident,
566+
pub message: Type,
565567
pub end: Vec<Ident>,
566568
}
567569

@@ -604,7 +606,7 @@ impl Parse for Transition {
604606

605607
let start: Ident = left.parse()?;
606608
let _: Token![,] = left.parse()?;
607-
let message: Ident = left.parse()?;
609+
let message: Type = left.parse()?;
608610

609611
let _: Token![=>] = input.parse()?;
610612

@@ -664,7 +666,7 @@ impl Transitions {
664666

665667
for edge in edges.iter() {
666668
file.write_all(
667-
&format!("{} -> {} [ label = \"{}\" ];\n", edge.0, edge.2, edge.1).as_bytes(),
669+
&format!("{} -> {} [ label = \"{}\" ];\n", edge.0, edge.2, edge.1.into_token_stream()).as_bytes(),
668670
)
669671
.expect("error writing to dot file");
670672
}
@@ -693,29 +695,47 @@ pub fn transitions(input: proc_macro::TokenStream) -> syn::export::TokenStream {
693695
entry.push((&t.start, &t.end));
694696
}
695697

698+
//let mut message_types = transitions.transitions.iter().map(|t| &t.message).collect::<Vec<_>>();
699+
700+
let mut type_arguments = HashSet::new();
701+
for t in transitions.transitions.iter() {
702+
type_arguments.extend(type_args(&t.message).drain());
703+
}
704+
705+
let type_arguments = reorder_type_arguments(type_arguments);
706+
696707
// create an enum from the messages
697708
let message_enum_ident = Ident::new(
698709
&format!("{}Messages", &machine_name.to_string()),
699710
Span::call_site(),
700711
);
701-
let variants_names = &messages.keys().collect::<Vec<_>>();
702-
let structs_names = variants_names.clone();
712+
let structs_names = messages.keys().collect::<Vec<_>>();
713+
let variants_names = structs_names.iter().map(|t| type_last_ident(*t)).collect::<Vec<_>>();
714+
715+
716+
let type_arg_toks = if type_arguments.is_empty() {
717+
quote!{}
718+
} else {
719+
quote!{
720+
< #(#type_arguments),* >
721+
}
722+
};
703723

704724
// define the state enum
705725
let toks = quote! {
706726
#[derive(Clone,Debug,PartialEq)]
707-
pub enum #message_enum_ident {
727+
pub enum #message_enum_ident #type_arg_toks {
708728
#(#variants_names(#structs_names)),*
709729
}
710730
};
711731

712732
stream.extend(proc_macro::TokenStream::from(toks));
713-
714733
let functions = messages
715734
.iter()
716735
.map(|(msg, moves)| {
717736
let fn_ident = Ident::new(
718-
&format!("on_{}", &msg.to_string().to_snake()),
737+
//&format!("on_{}", &msg.to_string().to_snake()),
738+
&format!("on_{}", type_to_snake(msg)),
719739
Span::call_site(),
720740
);
721741
let mv = moves.iter().map(|(start, end)| {
@@ -731,8 +751,17 @@ pub fn transitions(input: proc_macro::TokenStream) -> syn::export::TokenStream {
731751
}
732752
}).collect::<Vec<_>>();
733753

754+
let type_arguments = reorder_type_arguments(type_args(msg));
755+
let type_arg_toks = if type_arguments.is_empty() {
756+
quote!{}
757+
} else {
758+
quote!{
759+
< #(#type_arguments),* >
760+
}
761+
};
762+
734763
quote! {
735-
pub fn #fn_ident(self, input: #msg) -> #machine_name {
764+
pub fn #fn_ident #type_arg_toks(self, input: #msg) -> #machine_name {
736765
match self {
737766
#(#mv)*
738767
_ => #machine_name::Error,
@@ -746,18 +775,30 @@ pub fn transitions(input: proc_macro::TokenStream) -> syn::export::TokenStream {
746775
.keys()
747776
.map(|msg| {
748777
let fn_ident = Ident::new(
749-
&format!("on_{}", &msg.to_string().to_snake()),
778+
//&format!("on_{}", &msg.to_string().to_snake()),
779+
&format!("on_{}", type_to_snake(msg)),
750780
Span::call_site(),
751781
);
782+
783+
let id = type_last_ident(msg);
784+
752785
quote!{
753-
#message_enum_ident::#msg(message) => self.#fn_ident(message),
786+
#message_enum_ident::#id(message) => self.#fn_ident(message),
754787
}
755788

756789
})
757790
.collect::<Vec<_>>();
758791

792+
/*let type_arg_toks = if type_arguments.is_empty() {
793+
quote!{}
794+
} else {
795+
quote!{
796+
< #(#type_arguments),* >
797+
}
798+
};*/
799+
759800
let execute = quote! {
760-
pub fn execute(self, input: #message_enum_ident) -> #machine_name {
801+
pub fn execute #type_arg_toks(self, input: #message_enum_ident #type_arg_toks) -> #machine_name {
761802
match input {
762803
#(#matches)*
763804
_ => #machine_name::Error,
@@ -1175,3 +1216,68 @@ fn parse_method_sig(input: ParseStream) -> Result<MethodSig> {
11751216
},
11761217
})
11771218
}
1219+
1220+
fn type_to_snake(t: &Type) -> String {
1221+
match t {
1222+
Type::Path(ref p) => {
1223+
match p.path.segments.last() {
1224+
Some(Pair::End(segment)) => {
1225+
segment.ident.to_string().to_snake()
1226+
},
1227+
_ => panic!("expected a path segment"),
1228+
}
1229+
},
1230+
t => panic!("expected a Type::Path, got {:?}", t),
1231+
}
1232+
}
1233+
1234+
fn type_last_ident(t: &Type) -> &Ident {
1235+
match t {
1236+
Type::Path(ref p) => {
1237+
match p.path.segments.last() {
1238+
Some(Pair::End(segment)) => {
1239+
&segment.ident
1240+
},
1241+
_ => panic!("expected a path segment"),
1242+
}
1243+
},
1244+
t => panic!("expected a Type::Path, got {:?}", t),
1245+
}
1246+
}
1247+
1248+
fn type_args(t: &Type) -> HashSet<GenericArgument> {
1249+
match t {
1250+
Type::Path(ref p) => {
1251+
match p.path.segments.last() {
1252+
Some(Pair::End(segment)) => {
1253+
match &segment.arguments {
1254+
PathArguments::AngleBracketed(a) => {
1255+
a.args.iter().cloned().collect()
1256+
},
1257+
PathArguments::None => HashSet::new(),
1258+
a => panic!("expected angle bracketed arguments, got {:?}", a),
1259+
}
1260+
},
1261+
_ => panic!("expected a path segment"),
1262+
}
1263+
},
1264+
t => panic!("expected a Type::Path, got {:?}", t),
1265+
}
1266+
}
1267+
1268+
// lifetimes must appear before other type arguments
1269+
fn reorder_type_arguments(mut t: HashSet<GenericArgument>) -> Vec<GenericArgument> {
1270+
let mut lifetimes = Vec::new();
1271+
let mut others = Vec::new();
1272+
1273+
for arg in t.drain() {
1274+
if let GenericArgument::Lifetime(_) = arg {
1275+
lifetimes.push(arg);
1276+
} else {
1277+
others.push(arg);
1278+
}
1279+
}
1280+
1281+
lifetimes.extend(others.drain(..));
1282+
lifetimes
1283+
}

tests/traffic_light.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,23 @@ machine!(
1111
}
1212
);
1313

14-
#[derive(Clone,Debug,PartialEq)]
15-
pub struct Advance;
14+
pub mod prefix {
15+
#[derive(Clone,Debug,PartialEq)]
16+
pub struct Advance;
17+
}
1618

1719
#[derive(Clone,Debug,PartialEq)]
18-
pub struct PassCar { count: u8 }
20+
pub struct PassCar<'a, T> { count: u8, name: &'a T }
1921

2022
#[derive(Clone,Debug,PartialEq)]
2123
pub struct Toggle;
2224

2325
transitions!(TrafficLight,
2426
[
25-
(Green, Advance) => Orange,
26-
(Orange, Advance) => Red,
27-
(Red, Advance) => Green,
28-
(Green, PassCar) => [Green, Orange],
27+
(Green, prefix::Advance) => Orange,
28+
(Orange, prefix::Advance) => Red,
29+
(Red, prefix::Advance) => Green,
30+
(Green, PassCar<'a, T>) => [Green, Orange],
2931
(Green, Toggle) => BlinkingOrange,
3032
(Orange, Toggle) => BlinkingOrange,
3133
(Red, Toggle) => BlinkingOrange,
@@ -42,11 +44,11 @@ methods!(TrafficLight,
4244
);
4345

4446
impl Green {
45-
pub fn on_advance(self, _: Advance) -> Orange {
47+
pub fn on_advance(self, _: prefix::Advance) -> Orange {
4648
Orange {}
4749
}
4850

49-
pub fn on_pass_car(self, input: PassCar) -> TrafficLight {
51+
pub fn on_pass_car<'a, T>(self, input: PassCar<'a, T>) -> TrafficLight {
5052
let count = self.count + input.count;
5153
if count >= 10 {
5254
println!("reached max cars count: {}", count);
@@ -66,7 +68,7 @@ impl Green {
6668
}
6769

6870
impl Orange {
69-
pub fn on_advance(self, _: Advance) -> Red {
71+
pub fn on_advance(self, _: prefix::Advance) -> Red {
7072
Red {}
7173
}
7274

@@ -80,7 +82,7 @@ impl Orange {
8082
}
8183

8284
impl Red {
83-
pub fn on_advance(self, _: Advance) -> Green {
85+
pub fn on_advance(self, _: prefix::Advance) -> Green {
8486
Green {
8587
count: 0
8688
}
@@ -107,9 +109,11 @@ impl BlinkingOrange {
107109

108110
#[test]
109111
fn test() {
112+
use prefix::Advance;
113+
110114
let mut t = TrafficLight::Green(Green { count: 0 });
111-
t = t.on_pass_car(PassCar { count: 1});
112-
t = t.on_pass_car(PassCar { count: 2});
115+
t = t.on_pass_car(PassCar { count: 1, name: &"test".to_string() });
116+
t = t.on_pass_car(PassCar { count: 2, name: &"test".to_string() });
113117
assert_eq!(t, TrafficLight::green(3));
114118
t = t.on_advance(Advance);
115119
//println!("trace: {}", t.print_trace());
@@ -120,13 +124,13 @@ fn test() {
120124

121125
t = t.on_advance(Advance);
122126
assert_eq!(t, TrafficLight::green(0));
123-
t = t.on_pass_car(PassCar { count: 5 });
127+
t = t.on_pass_car(PassCar { count: 5, name: &"test".to_string() });
124128
assert_eq!(t, TrafficLight::green(5));
125-
t = t.on_pass_car(PassCar { count: 7 });
129+
t = t.on_pass_car(PassCar { count: 7, name: &"test".to_string() });
126130
assert_eq!(t, TrafficLight::orange());
127131
t = t.on_advance(Advance);
128132
assert_eq!(t, TrafficLight::red());
129-
t = t.on_pass_car(PassCar { count: 7 });
133+
t = t.on_pass_car(PassCar { count: 7, name: &"test".to_string() });
130134
assert_eq!(t, TrafficLight::error());
131135
t = t.on_advance(Advance);
132136
assert_eq!(t, TrafficLight::error());

0 commit comments

Comments
 (0)