Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit cfaabb5

Browse files
author
Joe C
authored
SPL errors from hashes (#5169)
* SPL errors from hashes * hashed error code is first variant only * add check for collision error codes * address feedback! * stupid `0`!
1 parent 25381b2 commit cfaabb5

File tree

17 files changed

+300
-108
lines changed

17 files changed

+300
-108
lines changed

Cargo.lock

Lines changed: 2 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libraries/program-error/derive/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ proc-macro = true
1313
[dependencies]
1414
proc-macro2 = "1.0"
1515
quote = "1.0"
16+
solana-program = "1.16.3"
1617
syn = { version = "2.0", features = ["full"] }

libraries/program-error/derive/src/lib.rs

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,39 +14,72 @@
1414
extern crate proc_macro;
1515

1616
mod macro_impl;
17+
mod parser;
1718

18-
use macro_impl::MacroType;
19-
use proc_macro::TokenStream;
20-
use syn::{parse_macro_input, ItemEnum};
19+
use {
20+
crate::parser::SplProgramErrorArgs,
21+
macro_impl::MacroType,
22+
proc_macro::TokenStream,
23+
syn::{parse_macro_input, ItemEnum},
24+
};
2125

22-
/// Derive macro to add `Into<solana_program::program_error::ProgramError>` traits
26+
/// Derive macro to add `Into<solana_program::program_error::ProgramError>`
27+
/// trait
2328
#[proc_macro_derive(IntoProgramError)]
2429
pub fn into_program_error(input: TokenStream) -> TokenStream {
25-
MacroType::IntoProgramError
26-
.generate_tokens(parse_macro_input!(input as ItemEnum))
30+
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
31+
MacroType::IntoProgramError { ident }
32+
.generate_tokens()
2733
.into()
2834
}
2935

3036
/// Derive macro to add `solana_program::decode_error::DecodeError` trait
3137
#[proc_macro_derive(DecodeError)]
3238
pub fn decode_error(input: TokenStream) -> TokenStream {
33-
MacroType::DecodeError
34-
.generate_tokens(parse_macro_input!(input as ItemEnum))
35-
.into()
39+
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
40+
MacroType::DecodeError { ident }.generate_tokens().into()
3641
}
3742

3843
/// Derive macro to add `solana_program::program_error::PrintProgramError` trait
3944
#[proc_macro_derive(PrintProgramError)]
4045
pub fn print_program_error(input: TokenStream) -> TokenStream {
41-
MacroType::PrintProgramError
42-
.generate_tokens(parse_macro_input!(input as ItemEnum))
46+
let ItemEnum {
47+
ident, variants, ..
48+
} = parse_macro_input!(input as ItemEnum);
49+
MacroType::PrintProgramError { ident, variants }
50+
.generate_tokens()
4351
.into()
4452
}
4553

4654
/// Proc macro attribute to turn your enum into a Solana Program Error
55+
///
56+
/// Adds:
57+
/// - `Clone`
58+
/// - `Debug`
59+
/// - `Eq`
60+
/// - `PartialEq`
61+
/// - `thiserror::Error`
62+
/// - `num_derive::FromPrimitive`
63+
/// - `Into<solana_program::program_error::ProgramError>`
64+
/// - `solana_program::decode_error::DecodeError`
65+
/// - `solana_program::program_error::PrintProgramError`
66+
///
67+
/// Optionally, you can add `hash_error_code_start: u32` argument to create
68+
/// a unique `u32` _starting_ error codes from the names of the enum variants.
69+
/// Notes:
70+
/// - The _error_ variant will start at this value, and the rest will be
71+
/// incremented by one
72+
/// - The value provided is only for code readability, the actual error code
73+
/// will be a hash of the input string and is checked against your input
74+
///
75+
/// Syntax: `#[spl_program_error(hash_error_code_start = 1275525928)]`
76+
/// Hash Input: `spl_program_error:<enum name>:<variant name>`
77+
/// Value: `u32::from_le_bytes(<hash of input>[13..17])`
4778
#[proc_macro_attribute]
48-
pub fn spl_program_error(_: TokenStream, input: TokenStream) -> TokenStream {
49-
MacroType::SplProgramError
50-
.generate_tokens(parse_macro_input!(input as ItemEnum))
79+
pub fn spl_program_error(attr: TokenStream, input: TokenStream) -> TokenStream {
80+
let args = parse_macro_input!(attr as SplProgramErrorArgs);
81+
let item_enum = parse_macro_input!(input as ItemEnum);
82+
MacroType::SplProgramError { args, item_enum }
83+
.generate_tokens()
5184
.into()
5285
}

libraries/program-error/derive/src/macro_impl.rs

Lines changed: 102 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,52 @@
11
//! The actual token generator for the macro
2-
use quote::quote;
3-
use syn::{punctuated::Punctuated, token::Comma, Ident, ItemEnum, LitStr, Variant};
2+
3+
use {
4+
crate::parser::SplProgramErrorArgs,
5+
proc_macro2::Span,
6+
quote::quote,
7+
syn::{
8+
punctuated::Punctuated, token::Comma, Expr, ExprLit, Ident, ItemEnum, Lit, LitInt, LitStr,
9+
Token, Variant,
10+
},
11+
};
12+
13+
const SPL_ERROR_HASH_NAMESPACE: &str = "spl_program_error";
14+
const SPL_ERROR_HASH_MIN_VALUE: u32 = 7_000;
415

516
/// The type of macro being called, thus directing which tokens to generate
617
#[allow(clippy::enum_variant_names)]
718
pub enum MacroType {
8-
IntoProgramError,
9-
DecodeError,
10-
PrintProgramError,
11-
SplProgramError,
19+
IntoProgramError {
20+
ident: Ident,
21+
},
22+
DecodeError {
23+
ident: Ident,
24+
},
25+
PrintProgramError {
26+
ident: Ident,
27+
variants: Punctuated<Variant, Comma>,
28+
},
29+
SplProgramError {
30+
args: SplProgramErrorArgs,
31+
item_enum: ItemEnum,
32+
},
1233
}
1334

1435
impl MacroType {
1536
/// Generates the corresponding tokens based on variant selection
16-
pub fn generate_tokens(&self, item_enum: ItemEnum) -> proc_macro2::TokenStream {
37+
pub fn generate_tokens(&mut self) -> proc_macro2::TokenStream {
1738
match self {
18-
MacroType::IntoProgramError => into_program_error(&item_enum.ident),
19-
MacroType::DecodeError => decode_error(&item_enum.ident),
20-
MacroType::PrintProgramError => {
21-
print_program_error(&item_enum.ident, &item_enum.variants)
22-
}
23-
MacroType::SplProgramError => spl_program_error(item_enum),
39+
Self::IntoProgramError { ident } => into_program_error(ident),
40+
Self::DecodeError { ident } => decode_error(ident),
41+
Self::PrintProgramError { ident, variants } => print_program_error(ident, variants),
42+
Self::SplProgramError { args, item_enum } => spl_program_error(args, item_enum),
2443
}
2544
}
2645
}
2746

28-
/// Builds the implementation of `Into<solana_program::program_error::ProgramError>`
29-
/// More specifically, implements `From<Self> for solana_program::program_error::ProgramError`
47+
/// Builds the implementation of
48+
/// `Into<solana_program::program_error::ProgramError>` More specifically,
49+
/// implements `From<Self> for solana_program::program_error::ProgramError`
3050
pub fn into_program_error(ident: &Ident) -> proc_macro2::TokenStream {
3151
quote! {
3252
impl From<#ident> for solana_program::program_error::ProgramError {
@@ -48,7 +68,8 @@ pub fn decode_error(ident: &Ident) -> proc_macro2::TokenStream {
4868
}
4969
}
5070

51-
/// Builds the implementation of `solana_program::program_error::PrintProgramError`
71+
/// Builds the implementation of
72+
/// `solana_program::program_error::PrintProgramError`
5273
pub fn print_program_error(
5374
ident: &Ident,
5475
variants: &Punctuated<Variant, Comma>,
@@ -96,16 +117,25 @@ fn get_error_message(variant: &Variant) -> Option<String> {
96117

97118
/// The main function that produces the tokens required to turn your
98119
/// error enum into a Solana Program Error
99-
pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
100-
let ident = &input.ident;
101-
let variants = &input.variants;
120+
pub fn spl_program_error(
121+
args: &SplProgramErrorArgs,
122+
item_enum: &mut ItemEnum,
123+
) -> proc_macro2::TokenStream {
124+
if let Some(error_code_start) = args.hash_error_code_start {
125+
set_first_discriminant(item_enum, error_code_start);
126+
}
127+
128+
let ident = &item_enum.ident;
129+
let variants = &item_enum.variants;
102130
let into_program_error = into_program_error(ident);
103131
let decode_error = decode_error(ident);
104132
let print_program_error = print_program_error(ident, variants);
133+
105134
quote! {
135+
#[repr(u32)]
106136
#[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)]
107137
#[num_traits = "num_traits"]
108-
#input
138+
#item_enum
109139

110140
#into_program_error
111141

@@ -114,3 +144,55 @@ pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
114144
#print_program_error
115145
}
116146
}
147+
148+
/// This function adds a discriminant to the first enum variant based on the
149+
/// hash of the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
150+
/// name.
151+
/// It will then check to make sure the provided `hash_error_code_start` is
152+
/// equal to the hash-produced `u32`.
153+
///
154+
/// See https://docs.rs/syn/latest/syn/struct.Variant.html
155+
fn set_first_discriminant(item_enum: &mut ItemEnum, error_code_start: u32) {
156+
let enum_ident = &item_enum.ident;
157+
if item_enum.variants.is_empty() {
158+
panic!("Enum must have at least one variant");
159+
}
160+
let first_variant = &mut item_enum.variants[0];
161+
let discriminant = u32_from_hash(enum_ident);
162+
if discriminant == error_code_start {
163+
let eq = Token![=](Span::call_site());
164+
let expr = Expr::Lit(ExprLit {
165+
attrs: Vec::new(),
166+
lit: Lit::Int(LitInt::new(&discriminant.to_string(), Span::call_site())),
167+
});
168+
first_variant.discriminant = Some((eq, expr));
169+
} else {
170+
panic!(
171+
"Error code start value from hash must be {0}. Update your macro attribute to \
172+
`#[spl_program_error(hash_error_code_start = {0})]`.",
173+
discriminant
174+
);
175+
}
176+
}
177+
178+
/// Hashes the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
179+
/// name and returns four middle bytes (13 through 16) as a u32.
180+
fn u32_from_hash(enum_ident: &Ident) -> u32 {
181+
let hash_input = format!("{}:{}", SPL_ERROR_HASH_NAMESPACE, enum_ident);
182+
183+
// We don't want our error code to start at any number below
184+
// `SPL_ERROR_HASH_MIN_VALUE`!
185+
let mut nonce: u32 = 0;
186+
loop {
187+
let hash = solana_program::hash::hashv(&[hash_input.as_bytes(), &nonce.to_le_bytes()]);
188+
let d = u32::from_le_bytes(
189+
hash.to_bytes()[13..17]
190+
.try_into()
191+
.expect("Unable to convert hash to u32"),
192+
);
193+
if d >= SPL_ERROR_HASH_MIN_VALUE {
194+
return d;
195+
}
196+
nonce += 1;
197+
}
198+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//! Token parsing
2+
3+
use {
4+
proc_macro2::Ident,
5+
syn::{
6+
parse::{Parse, ParseStream},
7+
token::Comma,
8+
LitInt, Token,
9+
},
10+
};
11+
12+
/// Possible arguments to the `#[spl_program_error]` attribute
13+
pub struct SplProgramErrorArgs {
14+
/// Whether to hash the error codes using `solana_program::hash`
15+
/// or to use the default error code assigned by `num_traits`.
16+
pub hash_error_code_start: Option<u32>,
17+
}
18+
19+
impl Parse for SplProgramErrorArgs {
20+
fn parse(input: ParseStream) -> syn::Result<Self> {
21+
if input.is_empty() {
22+
return Ok(Self {
23+
hash_error_code_start: None,
24+
});
25+
}
26+
match SplProgramErrorArgParser::parse(input)? {
27+
SplProgramErrorArgParser::HashErrorCodes { value, .. } => Ok(Self {
28+
hash_error_code_start: Some(value.base10_parse::<u32>()?),
29+
}),
30+
}
31+
}
32+
}
33+
34+
/// Parser for args to the `#[spl_program_error]` attribute
35+
/// ie. `#[spl_program_error(hash_error_code_start = 1275525928)]`
36+
enum SplProgramErrorArgParser {
37+
HashErrorCodes {
38+
_ident: Ident,
39+
_equals_sign: Token![=],
40+
value: LitInt,
41+
_comma: Option<Comma>,
42+
},
43+
}
44+
45+
impl Parse for SplProgramErrorArgParser {
46+
fn parse(input: ParseStream) -> syn::Result<Self> {
47+
let _ident = {
48+
let ident = input.parse::<Ident>()?;
49+
if ident != "hash_error_code_start" {
50+
return Err(input.error("Expected argument 'hash_error_code_start'"));
51+
}
52+
ident
53+
};
54+
let _equals_sign = input.parse::<Token![=]>()?;
55+
let value = input.parse::<LitInt>()?;
56+
let _comma: Option<Comma> = input.parse().unwrap_or(None);
57+
Ok(Self::HashErrorCodes {
58+
_ident,
59+
_equals_sign,
60+
value,
61+
_comma,
62+
})
63+
}
64+
}

libraries/program-error/src/lib.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ extern crate self as spl_program_error;
88

99
// Make these available downstream for the macro to work without
1010
// additional imports
11-
pub use num_derive;
12-
pub use num_traits;
13-
pub use solana_program;
14-
pub use spl_program_error_derive::{
15-
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
11+
pub use {
12+
num_derive, num_traits, solana_program,
13+
spl_program_error_derive::{
14+
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
15+
},
16+
thiserror,
1617
};
17-
pub use thiserror;

libraries/program-error/tests/decode.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! Tests `#[derive(DecodeError)]`
2-
//!
2+
33
use spl_program_error::*;
44

55
/// Example error

libraries/program-error/tests/into.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! Tests `#[derive(IntoProgramError)]`
2-
//!
2+
33
use spl_program_error::*;
44

55
/// Example error

libraries/program-error/tests/mod.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ pub mod spl;
66

77
#[cfg(test)]
88
mod tests {
9-
use super::*;
10-
use serial_test::serial;
11-
use solana_program::{
12-
decode_error::DecodeError,
13-
program_error::{PrintProgramError, ProgramError},
9+
use {
10+
super::*,
11+
serial_test::serial,
12+
solana_program::{
13+
decode_error::DecodeError,
14+
program_error::{PrintProgramError, ProgramError},
15+
},
16+
std::sync::{Arc, RwLock},
1417
};
15-
use std::sync::{Arc, RwLock};
1618

1719
// Used to capture output for `PrintProgramError` for testing
1820
lazy_static::lazy_static! {

0 commit comments

Comments
 (0)