Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit 6f18171

Browse files
author
Joe C
authored
spl-discriminator generic arg support (#4601)
`spl discriminator` lifetimes & generics
1 parent d5786c6 commit 6f18171

File tree

5 files changed

+134
-21
lines changed

5 files changed

+134
-21
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libraries/discriminator/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ repository = "https://github.com/solana-labs/solana-program-library"
77
license = "Apache-2.0"
88
edition = "2021"
99

10+
[features]
11+
borsh = ["dep:borsh"]
12+
1013
[dependencies]
14+
borsh = { version = "0.10", optional = true }
1115
bytemuck = { version = "1.13.1", features = ["derive"] }
1216
solana-program = "1.16.1"
1317
spl-discriminator-derive = { version = "0.1.0", path = "./derive" }

libraries/discriminator/src/discriminator.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ pub trait SplDiscriminate {
1414
}
1515

1616
/// Array Discriminator type
17+
#[cfg_attr(
18+
feature = "borsh",
19+
derive(borsh::BorshSerialize, borsh::BorshDeserialize)
20+
)]
1721
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
1822
#[repr(transparent)]
1923
pub struct ArrayDiscriminator([u8; ArrayDiscriminator::LENGTH]);
@@ -66,3 +70,13 @@ impl TryFrom<&[u8]> for ArrayDiscriminator {
6670
.map_err(|_| ProgramError::InvalidAccountData)
6771
}
6872
}
73+
impl From<ArrayDiscriminator> for [u8; 8] {
74+
fn from(from: ArrayDiscriminator) -> Self {
75+
from.0
76+
}
77+
}
78+
impl From<ArrayDiscriminator> for u64 {
79+
fn from(from: ArrayDiscriminator) -> Self {
80+
u64::from_le_bytes(from.0)
81+
}
82+
}

libraries/discriminator/src/lib.rs

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,34 +21,80 @@ mod tests {
2121

2222
#[allow(dead_code)]
2323
#[derive(SplDiscriminate)]
24-
#[discriminator_hash_input("some_discriminator_hash_input")]
24+
#[discriminator_hash_input("my_first_instruction")]
2525
pub struct MyInstruction1 {
2626
arg1: String,
2727
arg2: u8,
2828
}
2929

3030
#[allow(dead_code)]
3131
#[derive(SplDiscriminate)]
32-
#[discriminator_hash_input("yet_another_discriminator_hash_input")]
33-
pub struct MyInstruction2 {
34-
arg1: u64,
32+
#[discriminator_hash_input("global:my_second_instruction")]
33+
pub enum MyInstruction2 {
34+
One,
35+
Two,
36+
Three,
3537
}
3638

3739
#[allow(dead_code)]
3840
#[derive(SplDiscriminate)]
39-
#[discriminator_hash_input("global:my_instruction_3")]
40-
pub enum MyInstruction3 {
41-
One,
42-
Two,
43-
Three,
41+
#[discriminator_hash_input("global:my_instruction_with_lifetime")]
42+
pub struct MyInstruction3<'a> {
43+
data: &'a [u8],
44+
}
45+
46+
#[allow(dead_code)]
47+
#[derive(SplDiscriminate)]
48+
#[discriminator_hash_input("global:my_instruction_with_one_generic")]
49+
pub struct MyInstruction4<T> {
50+
data: T,
51+
}
52+
53+
#[allow(dead_code)]
54+
#[derive(SplDiscriminate)]
55+
#[discriminator_hash_input("global:my_instruction_with_one_generic_and_lifetime")]
56+
pub struct MyInstruction5<'b, T> {
57+
data: &'b [T],
58+
}
59+
60+
#[allow(dead_code)]
61+
#[derive(SplDiscriminate)]
62+
#[discriminator_hash_input("global:my_instruction_with_multiple_generics_and_lifetime")]
63+
pub struct MyInstruction6<'c, U, V> {
64+
data1: &'c [U],
65+
data2: &'c [V],
66+
}
67+
68+
#[allow(dead_code)]
69+
#[derive(SplDiscriminate)]
70+
#[discriminator_hash_input(
71+
"global:my_instruction_with_multiple_generics_and_lifetime_and_where"
72+
)]
73+
pub struct MyInstruction7<'c, U, V>
74+
where
75+
U: Clone + Copy,
76+
V: Clone + Copy,
77+
{
78+
data1: &'c [U],
79+
data2: &'c [V],
4480
}
4581

4682
fn assert_discriminator<T: spl_discriminator::discriminator::SplDiscriminate>(
4783
hash_input: &str,
4884
) {
4985
let discriminator = build_discriminator(hash_input);
50-
assert_eq!(T::SPL_DISCRIMINATOR, discriminator);
51-
assert_eq!(T::SPL_DISCRIMINATOR_SLICE, discriminator.as_slice());
86+
assert_eq!(
87+
T::SPL_DISCRIMINATOR,
88+
discriminator,
89+
"Discriminator mismatch: case: {}",
90+
hash_input
91+
);
92+
assert_eq!(
93+
T::SPL_DISCRIMINATOR_SLICE,
94+
discriminator.as_slice(),
95+
"Discriminator mismatch: case: {}",
96+
hash_input
97+
);
5298
}
5399

54100
fn build_discriminator(hash_input: &str) -> ArrayDiscriminator {
@@ -60,10 +106,39 @@ mod tests {
60106

61107
#[test]
62108
fn test_discrminators() {
63-
assert_discriminator::<MyInstruction1>("some_discriminator_hash_input");
64-
assert_discriminator::<MyInstruction2>("yet_another_discriminator_hash_input");
65-
assert_discriminator::<MyInstruction3>("global:my_instruction_3");
66-
let runtime_discrim = ArrayDiscriminator::new_with_hash_input("my_new_hash_input");
67-
assert_eq!(runtime_discrim, build_discriminator("my_new_hash_input"),);
109+
let runtime_discrim = ArrayDiscriminator::new_with_hash_input("my_runtime_hash_input");
110+
assert_eq!(
111+
runtime_discrim,
112+
build_discriminator("my_runtime_hash_input"),
113+
);
114+
115+
assert_discriminator::<MyInstruction1>("my_first_instruction");
116+
assert_discriminator::<MyInstruction2>("global:my_second_instruction");
117+
assert_discriminator::<MyInstruction3<'_>>("global:my_instruction_with_lifetime");
118+
assert_discriminator::<MyInstruction4<u8>>("global:my_instruction_with_one_generic");
119+
assert_discriminator::<MyInstruction5<'_, u8>>(
120+
"global:my_instruction_with_one_generic_and_lifetime",
121+
);
122+
assert_discriminator::<MyInstruction6<'_, u8, u8>>(
123+
"global:my_instruction_with_multiple_generics_and_lifetime",
124+
);
125+
assert_discriminator::<MyInstruction7<'_, u8, u8>>(
126+
"global:my_instruction_with_multiple_generics_and_lifetime_and_where",
127+
);
128+
}
129+
}
130+
131+
#[cfg(all(test, feature = "borsh"))]
132+
mod borsh_test {
133+
use super::*;
134+
135+
#[test]
136+
fn borsh_test() {
137+
let my_discrim = ArrayDiscriminator::new_with_hash_input("my_discrim");
138+
let mut buffer = [0u8; 8];
139+
my_discrim.serialize(&mut buffer[..]).unwrap();
140+
let my_discrim_again = ArrayDiscriminator::try_from_slice(&buffer).unwrap();
141+
assert_eq!(my_discrim, my_discrim_again);
142+
assert_eq!(buf, my_discrim.into());
68143
}
69144
}

libraries/discriminator/syn/src/lib.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@ use {
1111
proc_macro2::{Span, TokenStream},
1212
quote::{quote, ToTokens},
1313
solana_program::hash,
14-
syn::{parse::Parse, Ident, Item, ItemEnum, ItemStruct, LitByteStr},
14+
syn::{parse::Parse, Generics, Ident, Item, ItemEnum, ItemStruct, LitByteStr, WhereClause},
1515
};
1616

1717
/// "Builder" struct to implement the `SplDiscriminate` trait
1818
/// on an enum or struct
19-
#[derive(Debug)]
2019
pub struct SplDiscriminateBuilder {
2120
/// The struct/enum identifier
2221
pub ident: Ident,
22+
/// The item's generic arguments (if any)
23+
pub generics: Generics,
24+
/// The item's where clause for generics (if any)
25+
pub where_clause: Option<WhereClause>,
2326
/// The TLV hash_input
2427
pub hash_input: String,
2528
}
@@ -29,8 +32,15 @@ impl TryFrom<ItemEnum> for SplDiscriminateBuilder {
2932

3033
fn try_from(item_enum: ItemEnum) -> Result<Self, Self::Error> {
3134
let ident = item_enum.ident;
35+
let where_clause = item_enum.generics.where_clause.clone();
36+
let generics = item_enum.generics;
3237
let hash_input = parse_hash_input(&item_enum.attrs)?;
33-
Ok(Self { ident, hash_input })
38+
Ok(Self {
39+
ident,
40+
generics,
41+
where_clause,
42+
hash_input,
43+
})
3444
}
3545
}
3646

@@ -39,8 +49,15 @@ impl TryFrom<ItemStruct> for SplDiscriminateBuilder {
3949

4050
fn try_from(item_struct: ItemStruct) -> Result<Self, Self::Error> {
4151
let ident = item_struct.ident;
52+
let where_clause = item_struct.generics.where_clause.clone();
53+
let generics = item_struct.generics;
4254
let hash_input = parse_hash_input(&item_struct.attrs)?;
43-
Ok(Self { ident, hash_input })
55+
Ok(Self {
56+
ident,
57+
generics,
58+
where_clause,
59+
hash_input,
60+
})
4461
}
4562
}
4663

@@ -70,9 +87,11 @@ impl ToTokens for SplDiscriminateBuilder {
7087
impl From<&SplDiscriminateBuilder> for TokenStream {
7188
fn from(builder: &SplDiscriminateBuilder) -> Self {
7289
let ident = &builder.ident;
90+
let generics = &builder.generics;
91+
let where_clause = &builder.where_clause;
7392
let bytes = get_discriminator_bytes(&builder.hash_input);
7493
quote! {
75-
impl spl_discriminator::discriminator::SplDiscriminate for #ident {
94+
impl #generics spl_discriminator::discriminator::SplDiscriminate for #ident #generics #where_clause {
7695
const SPL_DISCRIMINATOR: spl_discriminator::discriminator::ArrayDiscriminator
7796
= spl_discriminator::discriminator::ArrayDiscriminator::new(*#bytes);
7897
}

0 commit comments

Comments
 (0)