Skip to content
Merged
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

306 changes: 278 additions & 28 deletions soroban-sdk-macros/src/syn_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl Parse for HasFnsItem {
Ok(HasFnsItem::Trait(t))
} else if lookahead.peek(Token![impl]) {
let mut imp = input.parse()?;
flatten_associated_items_in_impl_fns(&mut imp);
flatten_associated_items_in_impl_fns(&mut imp)?;
Ok(HasFnsItem::Impl(imp))
} else {
Err(lookahead.error())
Expand Down Expand Up @@ -306,48 +306,121 @@ fn unpack_result(typ: &Type) -> Option<(Type, Type)> {
}
}

fn flatten_associated_items_in_impl_fns(imp: &mut ItemImpl) {
fn flatten_associated_items_in_impl_fns(imp: &mut ItemImpl) -> Result<(), Error> {
// TODO: Flatten associated consts used in functions.
// Flatten associated types used in functions.
let associated_types = imp
let associated_types: HashMap<Ident, Type> = imp
.items
.iter()
.filter_map(|item| match item {
ImplItem::Type(i) => Some((i.ident.clone(), i.ty.clone())),
_ => None,
})
.collect::<HashMap<_, _>>();
let fn_input_types = imp
.items
.iter_mut()
.filter_map(|item| match item {
ImplItem::Fn(f) => Some(f.sig.inputs.iter_mut().filter_map(|input| match input {
FnArg::Typed(t) => Some(&mut t.ty),
_ => None,
})),
_ => None,
})
.flatten();
for t in fn_input_types {
if let Type::Path(TypePath { qself: None, path }) = t.as_mut() {
let segments = &path.segments;
if segments.len() == 2
&& segments.first() == Some(&PathSegment::from(format_ident!("Self")))
{
if let Some(PathSegment {
arguments: PathArguments::None,
ident,
}) = segments.get(1)
{
if let Some(resolved_ty) = associated_types.get(ident) {
*t.as_mut() = resolved_ty.clone();
.collect();

// Resolve Self::* in function input types and return types, including
// inside generic arguments like Vec<Self::Val>, Result<Self::Val, Error>, or &Self::Val.
// Uses default 128 depth limit for types. This is a somewhat arbitrary limit if it needs
// to be increased in the future.
for item in imp.items.iter_mut() {
if let ImplItem::Fn(f) = item {
for input in f.sig.inputs.iter_mut() {
if let FnArg::Typed(t) = input {
resolve_self_types(&mut t.ty, &associated_types, 128)?;
}
}
if let ReturnType::Type(_, ty) = &mut f.sig.output {
resolve_self_types(ty, &associated_types, 128)?;
}
Comment thread
mootz12 marked this conversation as resolved.
}
}

Ok(())
}

/// Recursively resolve `Self::Ident` types within a type, including inside
/// generic arguments like `Vec<Self::Val>`, `Result<Self::Val, Error>`, or `&Self::Val`.
///
/// ### Errors
/// If we cannot resolve the type or any unresolved `Self::Ident` remains after resolution.
fn resolve_self_types(
ty: &mut Type,
associated_types: &HashMap<Ident, Type>,
depth: usize,
) -> Result<(), Error> {
if depth == 0 {
return Err(Error::new(
ty.span(),
"unable to resolve type; type depth limit exceeded",
));
}

if let Some(ident) = self_type_ident(ty)? {
if let Some(resolved) = associated_types.get(ident).cloned() {
*ty = resolved;
return resolve_self_types(ty, associated_types, depth - 1);
}
return Err(Error::new(
ty.span(),
format!("unresolved associated type `Self::{ident}`; use a concrete type instead"),
));
}

match ty {
// Reject qualified Self paths like `<Self as Trait>::Foo`.
Type::Path(TypePath { qself: Some(qself), .. })
if matches!(qself.ty.as_ref(), Type::Path(TypePath { qself: None, path }) if path.is_ident("Self")) =>
{
Err(Error::new(
ty.span(),
"qualified associated types like `<Self as Trait>::Type` are not supported; use a concrete type instead",
))
}
Comment thread
mootz12 marked this conversation as resolved.
Outdated
// Recurse into generic arguments of path types.
Type::Path(TypePath { path, .. }) => {
for segment in path.segments.iter_mut() {
if let PathArguments::AngleBracketed(args) = &mut segment.arguments {
for arg in args.args.iter_mut() {
if let GenericArgument::Type(inner_ty) = arg {
resolve_self_types(inner_ty, associated_types, depth - 1)?;
}
}
}
}
Ok(())
}
// Recurse into reference types like &Self::Val.
Type::Reference(TypeReference { elem, .. }) => {
resolve_self_types(elem, associated_types, depth - 1)
}
_ => Ok(()),
Comment thread
mootz12 marked this conversation as resolved.
Outdated
}
}

/// If the type is `Self::Ident`, return the `Ident`. Otherwise return `None`.
///
/// ### Errors
/// If the type is a generic associated type like `Self::Foo<T>`.
fn self_type_ident(ty: &Type) -> Result<Option<&Ident>, Error> {
if let Type::Path(TypePath { qself: None, path }) = ty {
let segments = &path.segments;
if segments.len() == 2
&& segments.first() == Some(&PathSegment::from(format_ident!("Self")))
{
if let Some(seg) = segments.get(1) {
return match seg.arguments {
PathArguments::None => Ok(Some(&seg.ident)),
_ => Err(Error::new(
path.span(),
format!("generic associated types like `Self::{}<..>` are not supported; use a concrete type instead", seg.ident),
)),
};
}
}
}
Ok(None)
}

pub fn ty_to_safe_ident_str(ty: &Type) -> String {
quote!(#ty).to_string().replace(' ', "").replace(':', "_")
}
Expand Down Expand Up @@ -438,3 +511,180 @@ mod test_path_in_macro_rules {
assert_paths_eq(input, expected);
}
}

#[cfg(test)]
mod test_fns_parse {
use super::*;
use quote::quote;
use syn::parse2;

/// Parse an impl block through HasFnsItem and return the resolved fns.
fn parse_fns(input: TokenStream) -> syn::Result<Vec<Fn>> {
parse2::<HasFnsItem>(input).map(|item| item.fns())
}

/// Parse an impl block and return the string representation of the nth
/// fn's input types (excluding self) and return type.
fn parsed_fn_sig(input: TokenStream, n: usize) -> (Vec<String>, String) {
let fns = parse_fns(input).expect("parse failed");
let f = &fns[n];
let inputs: Vec<String> = f
.inputs
.iter()
.filter_map(|arg| match arg {
FnArg::Typed(t) => Some(quote!(#t).to_string()),
_ => None,
})
.collect();
let output = match &f.output {
ReturnType::Default => "()".to_string(),
ReturnType::Type(_, ty) => quote!(#ty).to_string(),
};
(inputs, output)
}

#[test]
fn test_no_associated_types() {
let input = quote! {
impl MyContract {
pub fn hello(x: u32) -> u64 {}
}
};
let (inputs, output) = parsed_fn_sig(input, 0);
assert_eq!(inputs, vec!["x : u32"]);
assert_eq!(output, "u64");
}

#[test]
fn test_basic_param_and_return() {
let input = quote! {
impl MyContract {
type Val = u64;
pub fn get(x: Self::Val) -> Self::Val {}
}
};
let (inputs, output) = parsed_fn_sig(input, 0);
assert_eq!(inputs, vec!["x : u64"]);
assert_eq!(output, "u64");
}

#[test]
fn test_chained_two_step() {
let input = quote! {
impl MyContract {
type A = u32;
type B = Self::A;
pub fn get(x: Self::B) {}
}
};
let (inputs, _) = parsed_fn_sig(input, 0);
assert_eq!(inputs, vec!["x : u32"]);
}

#[test]
fn test_wrapped_option() {
let input = quote! {
impl MyContract {
type A = u64;
pub fn get(x: Option<Self::A>) {}
}
};
let (inputs, _) = parsed_fn_sig(input, 0);
assert_eq!(inputs, vec!["x : Option < u64 >"]);
}

#[test]
fn test_double_wrapped_result_vec() {
let input = quote! {
impl MyContract {
type A = u64;
pub fn get(x: Result<Vec<Self::A>, Error>) {}
}
};
let (inputs, _) = parsed_fn_sig(input, 0);
assert_eq!(inputs, vec!["x : Result < Vec < u64 > , Error >"]);
}

#[test]
fn test_reject_qualified_self_path() {
let input = quote! {
impl MyContract {
pub fn get(x: <Self as Trait>::A) {}
}
};
let Err(err) = parse_fns(input) else {
panic!("expected error");
};
assert!(
err.to_string().contains("qualified associated types"),
"unexpected error: {err}"
);
}

#[test]
fn test_reject_generic_associated_type() {
let input = quote! {
impl MyContract {
pub fn get(x: Self::Foo<u32>) {}
}
};
let Err(err) = parse_fns(input) else {
panic!("expected error");
};
assert!(
err.to_string().contains("generic associated types"),
"unexpected error: {err}"
);
}

#[test]
fn test_reject_buried_qualified_self_path() {
let input = quote! {
impl MyContract {
pub fn get(x: Result<Vec<<Self as Trait>::A>, Error>) {}
}
};
let Err(err) = parse_fns(input) else {
panic!("expected error");
};
assert!(
err.to_string().contains("qualified associated types"),
"unexpected error: {err}"
);
}

#[test]
fn test_reject_unresolved_type() {
let input = quote! {
impl MyContract {
pub fn get(x: Self::Elsewhere) {}
}
};
let Err(err) = parse_fns(input) else {
panic!("expected error");
};
assert!(
err.to_string()
.contains("unresolved associated type `Self::Elsewhere`"),
"unexpected error: {err}"
);
}

#[test]
fn test_reject_recursive_cycle() {
let input = quote! {
impl MyContract {
type A = Self::B;
type B = Self::A;
pub fn get(x: Self::A) {}
}
};
let Err(err) = parse_fns(input) else {
panic!("expected error");
};
assert!(
err.to_string().contains("depth limit exceeded"),
"unexpected error: {err}"
);
}
}
Loading
Loading