diff --git a/discriminator/Cargo.toml b/discriminator/Cargo.toml new file mode 100644 index 00000000..7579c29d --- /dev/null +++ b/discriminator/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "spl-discriminator" +version = "0.4.0" +description = "Solana Program Library 8-Byte Discriminator Management" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" + +[features] +borsh = ["dep:borsh"] + +[dependencies] +borsh = { version = "1", optional = true } +bytemuck = { version = "1.20.0", features = ["derive"] } +solana-program-error = "2.1.0" +solana-sha256-hasher = "2.1.0" +spl-discriminator-derive = { version = "0.2.0", path = "./derive" } + +[lib] +crate-type = ["cdylib", "lib"] + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] diff --git a/discriminator/README.md b/discriminator/README.md new file mode 100644 index 00000000..8d31cc70 --- /dev/null +++ b/discriminator/README.md @@ -0,0 +1,57 @@ +# SPL Discriminator + +This library allows for easy management of 8-byte discriminators. + +### The `ArrayDiscriminator` Struct + +With this crate, you can leverage the `ArrayDiscriminator` type to manage an 8-byte discriminator for generic purposes. + +```rust +let my_discriminator = ArrayDiscriminator::new([8, 5, 1, 56, 10, 53, 9, 198]); +``` + +The `new(..)` function is also a **constant function**, so you can use `ArrayDiscriminator` in constants as well. + +```rust +const MY_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([8, 5, 1, 56, 10, 53, 9, 198]); +``` + +The `ArrayDiscriminator` struct also offers another constant function `as_slice(&self)`, so you can use `as_slice()` in constants as well. + +```rust +const MY_DISCRIMINATOR_SLICE: &[u8] = MY_DISCRIMINATOR.as_slice(); +``` + +### The `SplDiscriminate` Trait + +A trait, `SplDiscriminate` is also available, which will give you the `ArrayDiscriminator` constant type and also a slice representation of the discriminator. This can be particularly handy with match statements. + +```rust +/// A trait for managing 8-byte discriminators in a slab of bytes +pub trait SplDiscriminate { + /// The 8-byte discriminator as a `[u8; 8]` + const SPL_DISCRIMINATOR: ArrayDiscriminator; + /// The 8-byte discriminator as a slice (`&[u8]`) + const SPL_DISCRIMINATOR_SLICE: &'static [u8] = Self::SPL_DISCRIMINATOR.as_slice(); +} +``` + +### The `SplDiscriminate` Derive Macro + +The `SplDiscriminate` derive macro is a particularly useful tool for those who wish to derive their 8-byte discriminator from a particular string literal. Typically, you would have to run a hash function against the string literal, then copy the first 8 bytes, and then hard-code those bytes into a statement like the one above. + +Instead, you can simply annotate a struct or enum with `SplDiscriminate` and provide a **hash input** via the `discriminator_hash_input` attribute, and the macro will automatically derive the 8-byte discriminator for you! 
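+
+For reference, this is roughly what the manual approach looks like. This is only a sketch: it uses the `sha2` crate (the same hashing the derive internals use) and an illustrative input string.
+
+```rust
+use sha2::{Digest, Sha256};
+
+// Hash the chosen input string and keep only the first 8 bytes.
+let hash = Sha256::digest(b"some_discriminator_hash_input");
+let mut bytes = [0u8; 8];
+bytes.copy_from_slice(&hash[..8]);
+// These 8 bytes would then be hard-coded into `ArrayDiscriminator::new([...])`.
+```
+
+With the derive macro, none of that is necessary: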
+ +```rust +#[derive(SplDiscriminate)] // Implements `SplDiscriminate` for your struct/enum using your declared string literal hash_input +#[discriminator_hash_input("some_discriminator_hash_input")] +pub struct MyInstruction1 { + arg1: String, + arg2: u8, +} + +let my_discriminator: ArrayDiscriminator = MyInstruction1::SPL_DISCRIMINATOR; +let my_discriminator_slice: &[u8] = MyInstruction1::SPL_DISCRIMINATOR_SLICE; +``` + +Note: the 8-byte discriminator derived using the macro is always the **first 8 bytes** of the resulting hashed bytes. diff --git a/discriminator/derive/Cargo.toml b/discriminator/derive/Cargo.toml new file mode 100644 index 00000000..32ce6085 --- /dev/null +++ b/discriminator/derive/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "spl-discriminator-derive" +version = "0.2.0" +description = "Derive macro library for the `spl-discriminator` library" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" + +[dependencies] +quote = "1.0" +spl-discriminator-syn = { version = "0.2.0", path = "../syn" } +syn = { version = "2.0", features = ["full"] } + +[lib] +proc-macro = true + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] diff --git a/discriminator/derive/src/lib.rs b/discriminator/derive/src/lib.rs new file mode 100644 index 00000000..6080b80b --- /dev/null +++ b/discriminator/derive/src/lib.rs @@ -0,0 +1,20 @@ +//! Derive macro library for the `spl-discriminator` library + +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +extern crate proc_macro; + +use { + proc_macro::TokenStream, quote::ToTokens, spl_discriminator_syn::SplDiscriminateBuilder, + syn::parse_macro_input, +}; + +/// Derive macro library to implement the `SplDiscriminate` trait +/// on an enum or struct +#[proc_macro_derive(SplDiscriminate, attributes(discriminator_hash_input))] +pub fn spl_discriminator(input: TokenStream) -> TokenStream { + parse_macro_input!(input as SplDiscriminateBuilder) + .to_token_stream() + .into() +} diff --git a/discriminator/src/discriminator.rs b/discriminator/src/discriminator.rs new file mode 100644 index 00000000..aef70065 --- /dev/null +++ b/discriminator/src/discriminator.rs @@ -0,0 +1,83 @@ +//! 
The traits and types used to create a discriminator for a type + +use { + bytemuck::{Pod, Zeroable}, + solana_program_error::ProgramError, + solana_sha256_hasher::hashv, +}; + +/// A trait for managing 8-byte discriminators in a slab of bytes +pub trait SplDiscriminate { + /// The 8-byte discriminator as a `[u8; 8]` + const SPL_DISCRIMINATOR: ArrayDiscriminator; + /// The 8-byte discriminator as a slice (`&[u8]`) + const SPL_DISCRIMINATOR_SLICE: &'static [u8] = Self::SPL_DISCRIMINATOR.as_slice(); +} + +/// Array Discriminator type +#[cfg_attr( + feature = "borsh", + derive(borsh::BorshSerialize, borsh::BorshDeserialize) +)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct ArrayDiscriminator([u8; ArrayDiscriminator::LENGTH]); +impl ArrayDiscriminator { + /// Size for discriminator in account data + pub const LENGTH: usize = 8; + /// Uninitialized variant of a discriminator + pub const UNINITIALIZED: Self = Self::new([0; Self::LENGTH]); + /// Creates a discriminator from an array + pub const fn new(value: [u8; Self::LENGTH]) -> Self { + Self(value) + } + /// Get the array as a const slice + pub const fn as_slice(&self) -> &[u8] { + self.0.as_slice() + } + /// Creates a new `ArrayDiscriminator` from some hash input string literal + pub fn new_with_hash_input(hash_input: &str) -> Self { + let hash_bytes = hashv(&[hash_input.as_bytes()]).to_bytes(); + let mut discriminator_bytes = [0u8; 8]; + discriminator_bytes.copy_from_slice(&hash_bytes[..8]); + Self(discriminator_bytes) + } +} +impl AsRef<[u8]> for ArrayDiscriminator { + fn as_ref(&self) -> &[u8] { + &self.0[..] + } +} +impl AsRef<[u8; ArrayDiscriminator::LENGTH]> for ArrayDiscriminator { + fn as_ref(&self) -> &[u8; ArrayDiscriminator::LENGTH] { + &self.0 + } +} +impl From for ArrayDiscriminator { + fn from(from: u64) -> Self { + Self(from.to_le_bytes()) + } +} +impl From<[u8; Self::LENGTH]> for ArrayDiscriminator { + fn from(from: [u8; Self::LENGTH]) -> Self { + Self(from) + } +} +impl TryFrom<&[u8]> for ArrayDiscriminator { + type Error = ProgramError; + fn try_from(a: &[u8]) -> Result { + <[u8; Self::LENGTH]>::try_from(a) + .map(Self::from) + .map_err(|_| ProgramError::InvalidAccountData) + } +} +impl From for [u8; 8] { + fn from(from: ArrayDiscriminator) -> Self { + from.0 + } +} +impl From for u64 { + fn from(from: ArrayDiscriminator) -> Self { + u64::from_le_bytes(from.0) + } +} diff --git a/discriminator/src/lib.rs b/discriminator/src/lib.rs new file mode 100644 index 00000000..5f7ddc29 --- /dev/null +++ b/discriminator/src/lib.rs @@ -0,0 +1,144 @@ +//! Crate defining a discriminator type, which creates a set of bytes +//! 
meant to be unique for instructions or struct types + +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +extern crate self as spl_discriminator; + +/// Exports the discriminator module +pub mod discriminator; + +// Export for downstream +pub use { + discriminator::{ArrayDiscriminator, SplDiscriminate}, + spl_discriminator_derive::SplDiscriminate, +}; + +#[cfg(test)] +mod tests { + use {super::*, crate::discriminator::ArrayDiscriminator}; + + #[allow(dead_code)] + #[derive(SplDiscriminate)] + #[discriminator_hash_input("my_first_instruction")] + pub struct MyInstruction1 { + arg1: String, + arg2: u8, + } + + #[allow(dead_code)] + #[derive(SplDiscriminate)] + #[discriminator_hash_input("global:my_second_instruction")] + pub enum MyInstruction2 { + One, + Two, + Three, + } + + #[allow(dead_code)] + #[derive(SplDiscriminate)] + #[discriminator_hash_input("global:my_instruction_with_lifetime")] + pub struct MyInstruction3<'a> { + data: &'a [u8], + } + + #[allow(dead_code)] + #[derive(SplDiscriminate)] + #[discriminator_hash_input("global:my_instruction_with_one_generic")] + pub struct MyInstruction4 { + data: T, + } + + #[allow(dead_code)] + #[derive(SplDiscriminate)] + #[discriminator_hash_input("global:my_instruction_with_one_generic_and_lifetime")] + pub struct MyInstruction5<'b, T> { + data: &'b [T], + } + + #[allow(dead_code)] + #[derive(SplDiscriminate)] + #[discriminator_hash_input("global:my_instruction_with_multiple_generics_and_lifetime")] + pub struct MyInstruction6<'c, U, V> { + data1: &'c [U], + data2: &'c [V], + } + + #[allow(dead_code)] + #[derive(SplDiscriminate)] + #[discriminator_hash_input( + "global:my_instruction_with_multiple_generics_and_lifetime_and_where" + )] + pub struct MyInstruction7<'c, U, V> + where + U: Clone + Copy, + V: Clone + Copy, + { + data1: &'c [U], + data2: &'c [V], + } + + fn assert_discriminator( + hash_input: &str, + ) { + let discriminator = build_discriminator(hash_input); + assert_eq!( + T::SPL_DISCRIMINATOR, + discriminator, + "Discriminator mismatch: case: {}", + hash_input + ); + assert_eq!( + T::SPL_DISCRIMINATOR_SLICE, + discriminator.as_slice(), + "Discriminator mismatch: case: {}", + hash_input + ); + } + + fn build_discriminator(hash_input: &str) -> ArrayDiscriminator { + let preimage = solana_sha256_hasher::hashv(&[hash_input.as_bytes()]); + let mut bytes = [0u8; 8]; + bytes.copy_from_slice(&preimage.to_bytes()[..8]); + ArrayDiscriminator::new(bytes) + } + + #[test] + fn test_discrminators() { + let runtime_discrim = ArrayDiscriminator::new_with_hash_input("my_runtime_hash_input"); + assert_eq!( + runtime_discrim, + build_discriminator("my_runtime_hash_input"), + ); + + assert_discriminator::("my_first_instruction"); + assert_discriminator::("global:my_second_instruction"); + assert_discriminator::>("global:my_instruction_with_lifetime"); + assert_discriminator::>("global:my_instruction_with_one_generic"); + assert_discriminator::>( + "global:my_instruction_with_one_generic_and_lifetime", + ); + assert_discriminator::>( + "global:my_instruction_with_multiple_generics_and_lifetime", + ); + assert_discriminator::>( + "global:my_instruction_with_multiple_generics_and_lifetime_and_where", + ); + } +} + +#[cfg(all(test, feature = "borsh"))] +mod borsh_test { + use {super::*, borsh::BorshDeserialize}; + + #[test] + fn borsh_test() { + let my_discrim = ArrayDiscriminator::new_with_hash_input("my_discrim"); + let mut buffer = [0u8; 8]; + borsh::to_writer(&mut buffer[..], &my_discrim).unwrap(); + let my_discrim_again = 
ArrayDiscriminator::try_from_slice(&buffer).unwrap(); + assert_eq!(my_discrim, my_discrim_again); + assert_eq!(buffer, <[u8; 8]>::from(my_discrim)); + } +} diff --git a/discriminator/syn/Cargo.toml b/discriminator/syn/Cargo.toml new file mode 100644 index 00000000..5bd4f59c --- /dev/null +++ b/discriminator/syn/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "spl-discriminator-syn" +version = "0.2.0" +description = "Token parsing and generating library for the `spl-discriminator` library" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +sha2 = "0.10" +syn = { version = "2.0", features = ["full"] } +thiserror = "1.0" + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] diff --git a/discriminator/syn/src/error.rs b/discriminator/syn/src/error.rs new file mode 100644 index 00000000..5dd4c30d --- /dev/null +++ b/discriminator/syn/src/error.rs @@ -0,0 +1,12 @@ +//! Error types for the `hash_input` parser + +/// Error types for the `hash_input` parser +#[derive(Clone, Debug, Eq, thiserror::Error, PartialEq)] +pub enum SplDiscriminateError { + /// Discriminator hash_input attribute not provided + #[error("Discriminator `hash_input` attribute not provided")] + HashInputAttributeNotProvided, + /// Error parsing discriminator hash_input attribute + #[error("Error parsing discriminator `hash_input` attribute")] + HashInputAttributeParseError, +} diff --git a/discriminator/syn/src/lib.rs b/discriminator/syn/src/lib.rs new file mode 100644 index 00000000..db555916 --- /dev/null +++ b/discriminator/syn/src/lib.rs @@ -0,0 +1,108 @@ +//! Token parsing and generating library for the `spl-discriminator` library + +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +mod error; +pub mod parser; + +use { + crate::{error::SplDiscriminateError, parser::parse_hash_input}, + proc_macro2::{Span, TokenStream}, + quote::{quote, ToTokens}, + sha2::{Digest, Sha256}, + syn::{parse::Parse, Generics, Ident, Item, ItemEnum, ItemStruct, LitByteStr, WhereClause}, +}; + +/// "Builder" struct to implement the `SplDiscriminate` trait +/// on an enum or struct +pub struct SplDiscriminateBuilder { + /// The struct/enum identifier + pub ident: Ident, + /// The item's generic arguments (if any) + pub generics: Generics, + /// The item's where clause for generics (if any) + pub where_clause: Option, + /// The TLV hash_input + pub hash_input: String, +} + +impl TryFrom for SplDiscriminateBuilder { + type Error = SplDiscriminateError; + + fn try_from(item_enum: ItemEnum) -> Result { + let ident = item_enum.ident; + let where_clause = item_enum.generics.where_clause.clone(); + let generics = item_enum.generics; + let hash_input = parse_hash_input(&item_enum.attrs)?; + Ok(Self { + ident, + generics, + where_clause, + hash_input, + }) + } +} + +impl TryFrom for SplDiscriminateBuilder { + type Error = SplDiscriminateError; + + fn try_from(item_struct: ItemStruct) -> Result { + let ident = item_struct.ident; + let where_clause = item_struct.generics.where_clause.clone(); + let generics = item_struct.generics; + let hash_input = parse_hash_input(&item_struct.attrs)?; + Ok(Self { + ident, + generics, + where_clause, + hash_input, + }) + } +} + +impl Parse for SplDiscriminateBuilder { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let item = Item::parse(input)?; + match item { + Item::Enum(item_enum) => item_enum.try_into(), + 
Item::Struct(item_struct) => item_struct.try_into(), + _ => { + return Err(syn::Error::new( + Span::call_site(), + "Only enums and structs are supported", + )) + } + } + .map_err(|e| syn::Error::new(input.span(), format!("Failed to parse item: {}", e))) + } +} + +impl ToTokens for SplDiscriminateBuilder { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + tokens.extend::(self.into()); + } +} + +impl From<&SplDiscriminateBuilder> for TokenStream { + fn from(builder: &SplDiscriminateBuilder) -> Self { + let ident = &builder.ident; + let generics = &builder.generics; + let where_clause = &builder.where_clause; + let bytes = get_discriminator_bytes(&builder.hash_input); + quote! { + impl #generics spl_discriminator::discriminator::SplDiscriminate for #ident #generics #where_clause { + const SPL_DISCRIMINATOR: spl_discriminator::discriminator::ArrayDiscriminator + = spl_discriminator::discriminator::ArrayDiscriminator::new(*#bytes); + } + } + } +} + +/// Returns the bytes for the TLV hash_input discriminator +fn get_discriminator_bytes(hash_input: &str) -> LitByteStr { + LitByteStr::new( + &Sha256::digest(hash_input.as_bytes())[..8], + Span::call_site(), + ) +} diff --git a/discriminator/syn/src/parser.rs b/discriminator/syn/src/parser.rs new file mode 100644 index 00000000..1d70c3ab --- /dev/null +++ b/discriminator/syn/src/parser.rs @@ -0,0 +1,43 @@ +//! Parser for the `syn` crate to parse the +//! `#[discriminator_hash_input("...")]` attribute + +use { + crate::error::SplDiscriminateError, + syn::{ + parse::{Parse, ParseStream}, + token::Comma, + Attribute, LitStr, + }, +}; + +/// Struct used for `syn` parsing of the hash_input attribute +/// #[discriminator_hash_input("...")] +struct HashInputValueParser { + value: LitStr, + _comma: Option, +} + +impl Parse for HashInputValueParser { + fn parse(input: ParseStream) -> syn::Result { + let value: LitStr = input.parse()?; + let _comma: Option = input.parse().unwrap_or(None); + Ok(HashInputValueParser { value, _comma }) + } +} + +/// Parses the hash_input from the `#[discriminator_hash_input("...")]` +/// attribute +pub fn parse_hash_input(attrs: &[Attribute]) -> Result { + match attrs + .iter() + .find(|a| a.path().is_ident("discriminator_hash_input")) + { + Some(attr) => { + let parsed_args = attr + .parse_args::() + .map_err(|_| SplDiscriminateError::HashInputAttributeParseError)?; + Ok(parsed_args.value.value()) + } + None => Err(SplDiscriminateError::HashInputAttributeNotProvided), + } +} diff --git a/pod/Cargo.toml b/pod/Cargo.toml new file mode 100644 index 00000000..a6363a93 --- /dev/null +++ b/pod/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "spl-pod" +version = "0.5.0" +description = "Solana Program Library Plain Old Data (Pod)" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" + +[features] +serde-traits = ["dep:serde"] +borsh = ["dep:borsh"] + +[dependencies] +borsh = { version = "1.5.3", optional = true } +bytemuck = { version = "1.20.0" } +bytemuck_derive = { version = "1.8.0" } +num-derive = "0.4" +num-traits = "0.2" +serde = { version = "1.0.216", optional = true } +solana-decode-error = "2.1.0" +solana-msg = "2.1.0" +solana-program-error = "2.1.0" +solana-program-option = "2.1.0" +solana-pubkey = "2.1.0" +solana-zk-sdk = "2.1.0" +thiserror = "2.0" + +[dev-dependencies] +serde_json = "1.0.133" +base64 = { version = "0.22.1" } + +[lib] +crate-type = ["cdylib", "lib"] + +[package.metadata.docs.rs] +targets = 
["x86_64-unknown-linux-gnu"] diff --git a/pod/README.md b/pod/README.md new file mode 100644 index 00000000..984718b5 --- /dev/null +++ b/pod/README.md @@ -0,0 +1,3 @@ +# Pod + +This library contains types shared by SPL libraries that implement the `Pod` trait from `bytemuck` and utils for working with these types. diff --git a/pod/src/bytemuck.rs b/pod/src/bytemuck.rs new file mode 100644 index 00000000..f0039e35 --- /dev/null +++ b/pod/src/bytemuck.rs @@ -0,0 +1,53 @@ +//! wrappers for bytemuck functions + +use {bytemuck::Pod, solana_program_error::ProgramError}; + +/// On-chain size of a `Pod` type +pub const fn pod_get_packed_len() -> usize { + std::mem::size_of::() +} + +/// Convert a `Pod` into a slice of bytes (zero copy) +pub fn pod_bytes_of(t: &T) -> &[u8] { + bytemuck::bytes_of(t) +} + +/// Convert a slice of bytes into a `Pod` (zero copy) +pub fn pod_from_bytes(bytes: &[u8]) -> Result<&T, ProgramError> { + bytemuck::try_from_bytes(bytes).map_err(|_| ProgramError::InvalidArgument) +} + +/// Maybe convert a slice of bytes into a `Pod` (zero copy) +/// +/// Returns `None` if the slice is empty, or else `Err` if input length is not +/// equal to `pod_get_packed_len::()`. +/// This function exists primarily because `Option` is not a `Pod`. +pub fn pod_maybe_from_bytes(bytes: &[u8]) -> Result, ProgramError> { + if bytes.is_empty() { + Ok(None) + } else { + bytemuck::try_from_bytes(bytes) + .map(Some) + .map_err(|_| ProgramError::InvalidArgument) + } +} + +/// Convert a slice of bytes into a mutable `Pod` (zero copy) +pub fn pod_from_bytes_mut(bytes: &mut [u8]) -> Result<&mut T, ProgramError> { + bytemuck::try_from_bytes_mut(bytes).map_err(|_| ProgramError::InvalidArgument) +} + +/// Convert a slice of bytes into a `Pod` slice (zero copy) +pub fn pod_slice_from_bytes(bytes: &[u8]) -> Result<&[T], ProgramError> { + bytemuck::try_cast_slice(bytes).map_err(|_| ProgramError::InvalidArgument) +} + +/// Convert a slice of bytes into a mutable `Pod` slice (zero copy) +pub fn pod_slice_from_bytes_mut(bytes: &mut [u8]) -> Result<&mut [T], ProgramError> { + bytemuck::try_cast_slice_mut(bytes).map_err(|_| ProgramError::InvalidArgument) +} + +/// Convert a `Pod` slice into a single slice of bytes +pub fn pod_slice_to_bytes(slice: &[T]) -> &[u8] { + bytemuck::cast_slice(slice) +} diff --git a/pod/src/error.rs b/pod/src/error.rs new file mode 100644 index 00000000..f67f4eab --- /dev/null +++ b/pod/src/error.rs @@ -0,0 +1,56 @@ +//! Error types +use { + solana_decode_error::DecodeError, + solana_msg::msg, + solana_program_error::{PrintProgramError, ProgramError}, +}; + +/// Errors that may be returned by the spl-pod library. 
+#[repr(u32)] +#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error, num_derive::FromPrimitive)] +pub enum PodSliceError { + /// Error in checked math operation + #[error("Error in checked math operation")] + CalculationFailure, + /// Provided byte buffer too small for expected type + #[error("Provided byte buffer too small for expected type")] + BufferTooSmall, + /// Provided byte buffer too large for expected type + #[error("Provided byte buffer too large for expected type")] + BufferTooLarge, +} + +impl From for ProgramError { + fn from(e: PodSliceError) -> Self { + ProgramError::Custom(e as u32) + } +} + +impl solana_decode_error::DecodeError for PodSliceError { + fn type_of() -> &'static str { + "PodSliceError" + } +} + +impl PrintProgramError for PodSliceError { + fn print(&self) + where + E: 'static + + std::error::Error + + DecodeError + + PrintProgramError + + num_traits::FromPrimitive, + { + match self { + PodSliceError::CalculationFailure => { + msg!("Error in checked math operation") + } + PodSliceError::BufferTooSmall => { + msg!("Provided byte buffer too small for expected type") + } + PodSliceError::BufferTooLarge => { + msg!("Provided byte buffer too large for expected type") + } + } + } +} diff --git a/pod/src/lib.rs b/pod/src/lib.rs new file mode 100644 index 00000000..5be44e71 --- /dev/null +++ b/pod/src/lib.rs @@ -0,0 +1,14 @@ +//! Crate containing `Pod` types and `bytemuck` utils used in SPL + +pub mod bytemuck; +pub mod error; +pub mod option; +pub mod optional_keys; +pub mod primitives; +pub mod slice; + +// Export current sdk types for downstream users building with a different sdk +// version +pub use { + solana_decode_error, solana_msg, solana_program_error, solana_program_option, solana_pubkey, +}; diff --git a/pod/src/option.rs b/pod/src/option.rs new file mode 100644 index 00000000..02d7edd0 --- /dev/null +++ b/pod/src/option.rs @@ -0,0 +1,188 @@ +//! Generic `Option` that can be used as a `Pod` for types that can have +//! a designated `None` value. +//! +//! For example, a 64-bit unsigned integer can designate `0` as a `None` value. +//! This would be equivalent to +//! [`Option`](https://doc.rust-lang.org/std/num/type.NonZeroU64.html) +//! and provide the same memory layout optimization. + +use { + bytemuck::{Pod, Zeroable}, + solana_program_error::ProgramError, + solana_program_option::COption, + solana_pubkey::{Pubkey, PUBKEY_BYTES}, +}; + +/// Trait for types that can be `None`. +/// +/// This trait is used to indicate that a type can be `None` according to a +/// specific value. +pub trait Nullable: PartialEq + Pod + Sized { + /// Value that represents `None` for the type. + const NONE: Self; + + /// Indicates whether the value is `None` or not. + fn is_none(&self) -> bool { + self == &Self::NONE + } + + /// Indicates whether the value is `Some`` value of type `T`` or not. + fn is_some(&self) -> bool { + !self.is_none() + } +} + +/// A "pod-enabled" type that can be used as an `Option` without +/// requiring extra space to indicate if the value is `Some` or `None`. +/// +/// This can be used when a specific value of `T` indicates that its +/// value is `None`. +#[repr(transparent)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct PodOption(T); + +impl Default for PodOption { + fn default() -> Self { + Self(T::NONE) + } +} + +impl PodOption { + /// Returns the contained value as an `Option`. 
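+    ///
+    /// A usage sketch (it relies on the `Nullable` impl for `Pubkey`
+    /// defined at the bottom of this module, where an all-zero key is
+    /// the `None` value):
+    ///
+    /// ```ignore
+    /// use {solana_pubkey::Pubkey, spl_pod::option::PodOption};
+    ///
+    /// let key = Pubkey::new_unique();
+    /// assert_eq!(PodOption::from(key).get(), Some(key));
+    /// assert_eq!(PodOption::from(Pubkey::default()).get(), None);
+    /// ```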
+ #[inline] + pub fn get(self) -> Option { + if self.0.is_none() { + None + } else { + Some(self.0) + } + } + + /// Returns the contained value as an `Option`. + #[inline] + pub fn as_ref(&self) -> Option<&T> { + if self.0.is_none() { + None + } else { + Some(&self.0) + } + } + + /// Returns the contained value as a mutable `Option`. + #[inline] + pub fn as_mut(&mut self) -> Option<&mut T> { + if self.0.is_none() { + None + } else { + Some(&mut self.0) + } + } +} + +/// ## Safety +/// +/// `PodOption` is a transparent wrapper around a `Pod` type `T` with identical +/// data representation. +unsafe impl Pod for PodOption {} + +/// ## Safety +/// +/// `PodOption` is a transparent wrapper around a `Pod` type `T` with identical +/// data representation. +unsafe impl Zeroable for PodOption {} + +impl From for PodOption { + fn from(value: T) -> Self { + PodOption(value) + } +} + +impl TryFrom> for PodOption { + type Error = ProgramError; + + fn try_from(value: Option) -> Result { + match value { + Some(value) if value.is_none() => Err(ProgramError::InvalidArgument), + Some(value) => Ok(PodOption(value)), + None => Ok(PodOption(T::NONE)), + } + } +} + +impl TryFrom> for PodOption { + type Error = ProgramError; + + fn try_from(value: COption) -> Result { + match value { + COption::Some(value) if value.is_none() => Err(ProgramError::InvalidArgument), + COption::Some(value) => Ok(PodOption(value)), + COption::None => Ok(PodOption(T::NONE)), + } + } +} + +/// Implementation of `Nullable` for `Pubkey`. +impl Nullable for Pubkey { + const NONE: Self = Pubkey::new_from_array([0u8; PUBKEY_BYTES]); +} + +#[cfg(test)] +mod tests { + use {super::*, crate::bytemuck::pod_slice_from_bytes}; + const ID: Pubkey = Pubkey::from_str_const("TestSysvar111111111111111111111111111111111"); + + #[test] + fn test_pod_option_pubkey() { + let some_pubkey = PodOption::from(ID); + assert_eq!(some_pubkey.get(), Some(ID)); + + let none_pubkey = PodOption::from(Pubkey::default()); + assert_eq!(none_pubkey.get(), None); + + let mut data = Vec::with_capacity(64); + data.extend_from_slice(ID.as_ref()); + data.extend_from_slice(&[0u8; 32]); + + let values = pod_slice_from_bytes::>(&data).unwrap(); + assert_eq!(values[0], PodOption::from(ID)); + assert_eq!(values[1], PodOption::from(Pubkey::default())); + + let option_pubkey = Some(ID); + let pod_option_pubkey: PodOption = option_pubkey.try_into().unwrap(); + assert_eq!(pod_option_pubkey, PodOption::from(ID)); + assert_eq!( + pod_option_pubkey, + PodOption::try_from(option_pubkey).unwrap() + ); + + let coption_pubkey = COption::Some(ID); + let pod_option_pubkey: PodOption = coption_pubkey.try_into().unwrap(); + assert_eq!(pod_option_pubkey, PodOption::from(ID)); + assert_eq!( + pod_option_pubkey, + PodOption::try_from(coption_pubkey).unwrap() + ); + } + + #[test] + fn test_try_from_option() { + let some_pubkey = Some(ID); + assert_eq!(PodOption::try_from(some_pubkey).unwrap(), PodOption(ID)); + + let none_pubkey = None; + assert_eq!( + PodOption::try_from(none_pubkey).unwrap(), + PodOption::from(Pubkey::NONE) + ); + + let invalid_option = Some(Pubkey::NONE); + let err = PodOption::try_from(invalid_option).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_default() { + let def = PodOption::::default(); + assert_eq!(def, None.try_into().unwrap()); + } +} diff --git a/pod/src/optional_keys.rs b/pod/src/optional_keys.rs new file mode 100644 index 00000000..501287f9 --- /dev/null +++ b/pod/src/optional_keys.rs @@ -0,0 +1,360 @@ +//! 
Optional pubkeys that can be used a `Pod`s +#[cfg(feature = "borsh")] +use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; +use { + bytemuck_derive::{Pod, Zeroable}, + solana_program_error::ProgramError, + solana_program_option::COption, + solana_pubkey::Pubkey, + solana_zk_sdk::encryption::pod::elgamal::PodElGamalPubkey, +}; +#[cfg(feature = "serde-traits")] +use { + serde::de::{Error, Unexpected, Visitor}, + serde::{Deserialize, Deserializer, Serialize, Serializer}, + std::{convert::TryFrom, fmt, str::FromStr}, +}; + +/// A Pubkey that encodes `None` as all `0`, meant to be usable as a Pod type, +/// similar to all NonZero* number types from the bytemuck library. +#[cfg_attr( + feature = "borsh", + derive(BorshDeserialize, BorshSerialize, BorshSchema) +)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct OptionalNonZeroPubkey(pub Pubkey); +impl TryFrom> for OptionalNonZeroPubkey { + type Error = ProgramError; + fn try_from(p: Option) -> Result { + match p { + None => Ok(Self(Pubkey::default())), + Some(pubkey) => { + if pubkey == Pubkey::default() { + Err(ProgramError::InvalidArgument) + } else { + Ok(Self(pubkey)) + } + } + } + } +} +impl TryFrom> for OptionalNonZeroPubkey { + type Error = ProgramError; + fn try_from(p: COption) -> Result { + match p { + COption::None => Ok(Self(Pubkey::default())), + COption::Some(pubkey) => { + if pubkey == Pubkey::default() { + Err(ProgramError::InvalidArgument) + } else { + Ok(Self(pubkey)) + } + } + } + } +} +impl From for Option { + fn from(p: OptionalNonZeroPubkey) -> Self { + if p.0 == Pubkey::default() { + None + } else { + Some(p.0) + } + } +} +impl From for COption { + fn from(p: OptionalNonZeroPubkey) -> Self { + if p.0 == Pubkey::default() { + COption::None + } else { + COption::Some(p.0) + } + } +} + +#[cfg(feature = "serde-traits")] +impl Serialize for OptionalNonZeroPubkey { + fn serialize(&self, s: S) -> Result + where + S: Serializer, + { + if self.0 == Pubkey::default() { + s.serialize_none() + } else { + s.serialize_some(&self.0.to_string()) + } + } +} + +#[cfg(feature = "serde-traits")] +/// Visitor for deserializing OptionalNonZeroPubkey +struct OptionalNonZeroPubkeyVisitor; + +#[cfg(feature = "serde-traits")] +impl<'de> Visitor<'de> for OptionalNonZeroPubkeyVisitor { + type Value = OptionalNonZeroPubkey; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Pubkey in base58 or `null`") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + let pkey = Pubkey::from_str(v) + .map_err(|_| Error::invalid_value(Unexpected::Str(v), &"value string"))?; + + OptionalNonZeroPubkey::try_from(Some(pkey)) + .map_err(|_| Error::custom("Failed to convert from pubkey")) + } + + fn visit_unit(self) -> Result + where + E: Error, + { + OptionalNonZeroPubkey::try_from(None).map_err(|e| Error::custom(e.to_string())) + } +} + +#[cfg(feature = "serde-traits")] +impl<'de> Deserialize<'de> for OptionalNonZeroPubkey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(OptionalNonZeroPubkeyVisitor) + } +} + +/// An ElGamalPubkey that encodes `None` as all `0`, meant to be usable as a Pod +/// type. 
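+///
+/// A hedged sketch of reading the value back out (it mirrors the tests at
+/// the bottom of this module):
+///
+/// ```ignore
+/// use {
+///     solana_zk_sdk::encryption::pod::elgamal::PodElGamalPubkey,
+///     spl_pod::optional_keys::OptionalNonZeroElGamalPubkey,
+/// };
+///
+/// let none = OptionalNonZeroElGamalPubkey::default();
+/// assert_eq!(Option::<PodElGamalPubkey>::from(none), None);
+/// ```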
+#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct OptionalNonZeroElGamalPubkey(PodElGamalPubkey); +impl OptionalNonZeroElGamalPubkey { + /// Checks equality between an OptionalNonZeroElGamalPubkey and an + /// ElGamalPubkey when interpreted as bytes. + pub fn equals(&self, other: &PodElGamalPubkey) -> bool { + &self.0 == other + } +} +impl TryFrom> for OptionalNonZeroElGamalPubkey { + type Error = ProgramError; + fn try_from(p: Option) -> Result { + match p { + None => Ok(Self(PodElGamalPubkey::default())), + Some(elgamal_pubkey) => { + if elgamal_pubkey == PodElGamalPubkey::default() { + Err(ProgramError::InvalidArgument) + } else { + Ok(Self(elgamal_pubkey)) + } + } + } + } +} +impl From for Option { + fn from(p: OptionalNonZeroElGamalPubkey) -> Self { + if p.0 == PodElGamalPubkey::default() { + None + } else { + Some(p.0) + } + } +} + +#[cfg(feature = "serde-traits")] +impl Serialize for OptionalNonZeroElGamalPubkey { + fn serialize(&self, s: S) -> Result + where + S: Serializer, + { + if self.0 == PodElGamalPubkey::default() { + s.serialize_none() + } else { + s.serialize_some(&self.0.to_string()) + } + } +} + +#[cfg(feature = "serde-traits")] +struct OptionalNonZeroElGamalPubkeyVisitor; + +#[cfg(feature = "serde-traits")] +impl<'de> Visitor<'de> for OptionalNonZeroElGamalPubkeyVisitor { + type Value = OptionalNonZeroElGamalPubkey; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an ElGamal public key as base64 or `null`") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + let elgamal_pubkey: PodElGamalPubkey = FromStr::from_str(v).map_err(Error::custom)?; + OptionalNonZeroElGamalPubkey::try_from(Some(elgamal_pubkey)).map_err(Error::custom) + } + + fn visit_unit(self) -> Result + where + E: Error, + { + Ok(OptionalNonZeroElGamalPubkey::default()) + } +} + +#[cfg(feature = "serde-traits")] +impl<'de> Deserialize<'de> for OptionalNonZeroElGamalPubkey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(OptionalNonZeroElGamalPubkeyVisitor) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::bytemuck::pod_from_bytes, + base64::{prelude::BASE64_STANDARD, Engine}, + solana_pubkey::PUBKEY_BYTES, + }; + + #[test] + fn test_pod_non_zero_option() { + assert_eq!( + Some(Pubkey::new_from_array([1; PUBKEY_BYTES])), + Option::::from( + *pod_from_bytes::(&[1; PUBKEY_BYTES]).unwrap() + ) + ); + assert_eq!( + None, + Option::::from( + *pod_from_bytes::(&[0; PUBKEY_BYTES]).unwrap() + ) + ); + assert_eq!( + pod_from_bytes::(&[]).unwrap_err(), + ProgramError::InvalidArgument + ); + assert_eq!( + pod_from_bytes::(&[0; 1]).unwrap_err(), + ProgramError::InvalidArgument + ); + assert_eq!( + pod_from_bytes::(&[1; 1]).unwrap_err(), + ProgramError::InvalidArgument + ); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_option_serde_some() { + let optional_non_zero_pubkey_some = + OptionalNonZeroPubkey(Pubkey::new_from_array([1; PUBKEY_BYTES])); + let serialized_some = serde_json::to_string(&optional_non_zero_pubkey_some).unwrap(); + assert_eq!( + &serialized_some, + "\"4vJ9JU1bJJE96FWSJKvHsmmFADCg4gpZQff4P3bkLKi\"" + ); + + let deserialized_some = + serde_json::from_str::(&serialized_some).unwrap(); + assert_eq!(optional_non_zero_pubkey_some, deserialized_some); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_option_serde_none() { + let optional_non_zero_pubkey_none = + 
OptionalNonZeroPubkey(Pubkey::new_from_array([0; PUBKEY_BYTES])); + let serialized_none = serde_json::to_string(&optional_non_zero_pubkey_none).unwrap(); + assert_eq!(&serialized_none, "null"); + + let deserialized_none = + serde_json::from_str::(&serialized_none).unwrap(); + assert_eq!(optional_non_zero_pubkey_none, deserialized_none); + } + + const OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN: usize = 32; + + // Unfortunately, the `solana-zk-sdk` does not expose a constructor interface + // to construct `PodRistrettoPoint` from bytes. As a work-around, encode the + // bytes as base64 string and then convert the string to a + // `PodElGamalCiphertext`. + // + // The constructor will be added (and this function removed) with + // `solana-zk-sdk` 2.1. + fn elgamal_pubkey_from_bytes(bytes: &[u8]) -> PodElGamalPubkey { + let string = BASE64_STANDARD.encode(bytes); + std::str::FromStr::from_str(&string).unwrap() + } + + #[test] + fn test_pod_non_zero_elgamal_option() { + assert_eq!( + Some(elgamal_pubkey_from_bytes( + &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN] + )), + Option::::from(OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]) + )) + ); + assert_eq!( + None, + Option::::from(OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[0; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]) + )) + ); + + assert_eq!( + OptionalNonZeroElGamalPubkey(elgamal_pubkey_from_bytes( + &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN] + )), + *pod_from_bytes::( + &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN] + ) + .unwrap() + ); + assert!(pod_from_bytes::(&[]).is_err()); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_elgamal_option_serde_some() { + let optional_non_zero_elgamal_pubkey_some = OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]), + ); + let serialized_some = + serde_json::to_string(&optional_non_zero_elgamal_pubkey_some).unwrap(); + assert_eq!( + &serialized_some, + "\"AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE=\"" + ); + + let deserialized_some = + serde_json::from_str::(&serialized_some).unwrap(); + assert_eq!(optional_non_zero_elgamal_pubkey_some, deserialized_some); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_elgamal_option_serde_none() { + let optional_non_zero_elgamal_pubkey_none = OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[0; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]), + ); + let serialized_none = + serde_json::to_string(&optional_non_zero_elgamal_pubkey_none).unwrap(); + assert_eq!(&serialized_none, "null"); + + let deserialized_none = + serde_json::from_str::(&serialized_none).unwrap(); + assert_eq!(optional_non_zero_elgamal_pubkey_none, deserialized_none); + } +} diff --git a/pod/src/primitives.rs b/pod/src/primitives.rs new file mode 100644 index 00000000..f6759cf5 --- /dev/null +++ b/pod/src/primitives.rs @@ -0,0 +1,269 @@ +//! 
primitive types that can be used in `Pod`s +#[cfg(feature = "borsh")] +use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; +use bytemuck_derive::{Pod, Zeroable}; +#[cfg(feature = "serde-traits")] +use serde::{Deserialize, Serialize}; + +/// The standard `bool` is not a `Pod`, define a replacement that is +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "bool", into = "bool"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodBool(pub u8); +impl PodBool { + pub const fn from_bool(b: bool) -> Self { + Self(if b { 1 } else { 0 }) + } +} + +impl From for PodBool { + fn from(b: bool) -> Self { + Self::from_bool(b) + } +} + +impl From<&bool> for PodBool { + fn from(b: &bool) -> Self { + Self(if *b { 1 } else { 0 }) + } +} + +impl From<&PodBool> for bool { + fn from(b: &PodBool) -> Self { + b.0 != 0 + } +} + +impl From for bool { + fn from(b: PodBool) -> Self { + b.0 != 0 + } +} + +/// Simple macro for implementing conversion functions between Pod* ints and +/// standard ints. +/// +/// The standard int types can cause alignment issues when placed in a `Pod`, +/// so these replacements are usable in all `Pod`s. +#[macro_export] +macro_rules! impl_int_conversion { + ($P:ty, $I:ty) => { + impl $P { + pub const fn from_primitive(n: $I) -> Self { + Self(n.to_le_bytes()) + } + } + impl From<$I> for $P { + fn from(n: $I) -> Self { + Self::from_primitive(n) + } + } + impl From<$P> for $I { + fn from(pod: $P) -> Self { + Self::from_le_bytes(pod.0) + } + } + }; +} + +/// `u16` type that can be used in `Pod`s +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u16", into = "u16"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodU16(pub [u8; 2]); +impl_int_conversion!(PodU16, u16); + +/// `i16` type that can be used in Pods +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "i16", into = "i16"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodI16(pub [u8; 2]); +impl_int_conversion!(PodI16, i16); + +/// `u32` type that can be used in `Pod`s +#[cfg_attr( + feature = "borsh", + derive(BorshDeserialize, BorshSerialize, BorshSchema) +)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u32", into = "u32"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodU32(pub [u8; 4]); +impl_int_conversion!(PodU32, u32); + +/// `u64` type that can be used in Pods +#[cfg_attr( + feature = "borsh", + derive(BorshDeserialize, BorshSerialize, BorshSchema) +)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u64", into = "u64"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodU64(pub [u8; 8]); +impl_int_conversion!(PodU64, u64); + +/// `i64` type that can be used in Pods +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "i64", into = "i64"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodI64([u8; 8]); +impl_int_conversion!(PodI64, i64); + +/// `u128` type that can be 
used in Pods +#[cfg_attr( + feature = "borsh", + derive(BorshDeserialize, BorshSerialize, BorshSchema) +)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u128", into = "u128"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodU128(pub [u8; 16]); +impl_int_conversion!(PodU128, u128); + +#[cfg(test)] +mod tests { + use {super::*, crate::bytemuck::pod_from_bytes}; + + #[test] + fn test_pod_bool() { + assert!(pod_from_bytes::(&[]).is_err()); + assert!(pod_from_bytes::(&[0, 0]).is_err()); + + for i in 0..=u8::MAX { + assert_eq!(i != 0, bool::from(pod_from_bytes::(&[i]).unwrap())); + } + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_bool_serde() { + let pod_false: PodBool = false.into(); + let pod_true: PodBool = true.into(); + + let serialized_false = serde_json::to_string(&pod_false).unwrap(); + let serialized_true = serde_json::to_string(&pod_true).unwrap(); + assert_eq!(&serialized_false, "false"); + assert_eq!(&serialized_true, "true"); + + let deserialized_false = serde_json::from_str::(&serialized_false).unwrap(); + let deserialized_true = serde_json::from_str::(&serialized_true).unwrap(); + assert_eq!(pod_false, deserialized_false); + assert_eq!(pod_true, deserialized_true); + } + + #[test] + fn test_pod_u16() { + assert!(pod_from_bytes::(&[]).is_err()); + assert_eq!(1u16, u16::from(*pod_from_bytes::(&[1, 0]).unwrap())); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_u16_serde() { + let pod_u16: PodU16 = u16::MAX.into(); + + let serialized = serde_json::to_string(&pod_u16).unwrap(); + assert_eq!(&serialized, "65535"); + + let deserialized = serde_json::from_str::(&serialized).unwrap(); + assert_eq!(pod_u16, deserialized); + } + + #[test] + fn test_pod_i16() { + assert!(pod_from_bytes::(&[]).is_err()); + assert_eq!( + -1i16, + i16::from(*pod_from_bytes::(&[255, 255]).unwrap()) + ); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_i16_serde() { + let pod_i16: PodI16 = i16::MAX.into(); + + println!("pod_i16 {:?}", pod_i16); + + let serialized = serde_json::to_string(&pod_i16).unwrap(); + assert_eq!(&serialized, "32767"); + + let deserialized = serde_json::from_str::(&serialized).unwrap(); + assert_eq!(pod_i16, deserialized); + } + + #[test] + fn test_pod_u64() { + assert!(pod_from_bytes::(&[]).is_err()); + assert_eq!( + 1u64, + u64::from(*pod_from_bytes::(&[1, 0, 0, 0, 0, 0, 0, 0]).unwrap()) + ); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_u64_serde() { + let pod_u64: PodU64 = u64::MAX.into(); + + let serialized = serde_json::to_string(&pod_u64).unwrap(); + assert_eq!(&serialized, "18446744073709551615"); + + let deserialized = serde_json::from_str::(&serialized).unwrap(); + assert_eq!(pod_u64, deserialized); + } + + #[test] + fn test_pod_i64() { + assert!(pod_from_bytes::(&[]).is_err()); + assert_eq!( + -1i64, + i64::from( + *pod_from_bytes::(&[255, 255, 255, 255, 255, 255, 255, 255]).unwrap() + ) + ); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_i64_serde() { + let pod_i64: PodI64 = i64::MAX.into(); + + let serialized = serde_json::to_string(&pod_i64).unwrap(); + assert_eq!(&serialized, "9223372036854775807"); + + let deserialized = serde_json::from_str::(&serialized).unwrap(); + assert_eq!(pod_i64, deserialized); + } + + #[test] + fn test_pod_u128() { + assert!(pod_from_bytes::(&[]).is_err()); + assert_eq!( + 1u128, + u128::from( + *pod_from_bytes::(&[1, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + .unwrap() + ) + ); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_u128_serde() { + let pod_u128: PodU128 = u128::MAX.into(); + + let serialized = serde_json::to_string(&pod_u128).unwrap(); + assert_eq!(&serialized, "340282366920938463463374607431768211455"); + + let deserialized = serde_json::from_str::(&serialized).unwrap(); + assert_eq!(pod_u128, deserialized); + } +} diff --git a/pod/src/slice.rs b/pod/src/slice.rs new file mode 100644 index 00000000..043f408d --- /dev/null +++ b/pod/src/slice.rs @@ -0,0 +1,220 @@ +//! Special types for working with slices of `Pod`s + +use { + crate::{ + bytemuck::{ + pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, + }, + error::PodSliceError, + primitives::PodU32, + }, + bytemuck::Pod, + solana_program_error::ProgramError, +}; + +const LENGTH_SIZE: usize = std::mem::size_of::(); +/// Special type for using a slice of `Pod`s in a zero-copy way +pub struct PodSlice<'data, T: Pod> { + length: &'data PodU32, + data: &'data [T], +} +impl<'data, T: Pod> PodSlice<'data, T> { + /// Unpack the buffer into a slice + pub fn unpack<'a>(data: &'a [u8]) -> Result + where + 'a: 'data, + { + if data.len() < LENGTH_SIZE { + return Err(PodSliceError::BufferTooSmall.into()); + } + let (length, data) = data.split_at(LENGTH_SIZE); + let length = pod_from_bytes::(length)?; + let _max_length = max_len_for_type::(data.len())?; + let data = pod_slice_from_bytes(data)?; + Ok(Self { length, data }) + } + + /// Get the slice data + pub fn data(&self) -> &[T] { + let length = u32::from(*self.length) as usize; + &self.data[..length] + } + + /// Get the amount of bytes used by `num_items` + pub fn size_of(num_items: usize) -> Result { + std::mem::size_of::() + .checked_mul(num_items) + .and_then(|len| len.checked_add(LENGTH_SIZE)) + .ok_or_else(|| PodSliceError::CalculationFailure.into()) + } +} + +/// Special type for using a slice of mutable `Pod`s in a zero-copy way +pub struct PodSliceMut<'data, T: Pod> { + length: &'data mut PodU32, + data: &'data mut [T], + max_length: usize, +} +impl<'data, T: Pod> PodSliceMut<'data, T> { + /// Unpack the mutable buffer into a mutable slice, with the option to + /// initialize the data + fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result + where + 'a: 'data, + { + if data.len() < LENGTH_SIZE { + return Err(PodSliceError::BufferTooSmall.into()); + } + let (length, data) = data.split_at_mut(LENGTH_SIZE); + let length = pod_from_bytes_mut::(length)?; + if init { + *length = 0.into(); + } + let max_length = max_len_for_type::(data.len())?; + let data = pod_slice_from_bytes_mut(data)?; + Ok(Self { + length, + data, + max_length, + }) + } + + /// Unpack the mutable buffer into a mutable slice + pub fn unpack<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, /* init */ false) + } + + /// Unpack the mutable buffer into a mutable slice, and initialize the + /// slice to 0-length + pub fn init<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, /* init */ true) + } + + /// Add another item to the slice + pub fn push(&mut self, t: T) -> Result<(), ProgramError> { + let length = u32::from(*self.length); + if length as usize == self.max_length { + Err(PodSliceError::BufferTooSmall.into()) + } else { + self.data[length as usize] = t; + *self.length = length.saturating_add(1).into(); + Ok(()) + } + } +} + +fn max_len_for_type(data_len: usize) -> Result { + let size: usize = 
std::mem::size_of::(); + let max_len = data_len + .checked_div(size) + .ok_or(PodSliceError::CalculationFailure)?; + // check that it isn't over or under allocated + if max_len.saturating_mul(size) != data_len { + if max_len == 0 { + // Size of T is greater than buffer size + Err(PodSliceError::BufferTooSmall.into()) + } else { + Err(PodSliceError::BufferTooLarge.into()) + } + } else { + Ok(max_len) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::bytemuck::pod_slice_to_bytes, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + struct TestStruct { + test_field: u8, + test_pubkey: [u8; 32], + } + + #[test] + fn test_pod_slice() { + let test_field_bytes = [0]; + let test_pubkey_bytes = [1; 32]; + let len_bytes = [2, 0, 0, 0]; + + // Slice will contain 2 `TestStruct` + let mut data_bytes = [0; 66]; + data_bytes[0..1].copy_from_slice(&test_field_bytes); + data_bytes[1..33].copy_from_slice(&test_pubkey_bytes); + data_bytes[33..34].copy_from_slice(&test_field_bytes); + data_bytes[34..66].copy_from_slice(&test_pubkey_bytes); + + let mut pod_slice_bytes = [0; 70]; + pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + pod_slice_bytes[4..70].copy_from_slice(&data_bytes); + + let pod_slice = PodSlice::::unpack(&pod_slice_bytes).unwrap(); + let pod_slice_data = pod_slice.data(); + + assert_eq!(*pod_slice.length, PodU32::from(2)); + assert_eq!(pod_slice_to_bytes(pod_slice.data()), data_bytes); + assert_eq!(pod_slice_data[0].test_field, test_field_bytes[0]); + assert_eq!(pod_slice_data[0].test_pubkey, test_pubkey_bytes); + assert_eq!(PodSlice::::size_of(1).unwrap(), 37); + } + + #[test] + fn test_pod_slice_buffer_too_large() { + // 1 `TestStruct` + length = 37 bytes + // we pass 38 to trigger BufferTooLarge + let pod_slice_bytes = [1; 38]; + let err = PodSlice::::unpack(&pod_slice_bytes) + .err() + .unwrap(); + assert_eq!( + err, + PodSliceError::BufferTooLarge.into(), + "Expected an `PodSliceError::BufferTooLarge` error" + ); + } + + #[test] + fn test_pod_slice_buffer_too_small() { + // 1 `TestStruct` + length = 37 bytes + // we pass 36 to trigger BufferTooSmall + let pod_slice_bytes = [1; 36]; + let err = PodSlice::::unpack(&pod_slice_bytes) + .err() + .unwrap(); + assert_eq!( + err, + PodSliceError::BufferTooSmall.into(), + "Expected an `PodSliceError::BufferTooSmall` error" + ); + } + + #[test] + fn test_pod_slice_mut() { + // slice can fit 2 `TestStruct` + let mut pod_slice_bytes = [0; 70]; + // set length to 1, so we have room to push 1 more item + let len_bytes = [1, 0, 0, 0]; + pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); + + assert_eq!(*pod_slice.length, PodU32::from(1)); + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(*pod_slice.length, PodU32::from(2)); + let err = pod_slice + .push(TestStruct::default()) + .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } +} diff --git a/program-error/Cargo.toml b/program-error/Cargo.toml new file mode 100644 index 00000000..352655dc --- /dev/null +++ b/program-error/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "spl-program-error" +version = "0.6.0" +description = "Library for Solana Program error attributes and derive macro for creating them" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" + 
+[dependencies] +num-derive = "0.4" +num-traits = "0.2" +solana-program = "2.1.0" +spl-program-error-derive = { version = "0.4.1", path = "./derive" } +thiserror = "2.0" + +[dev-dependencies] +lazy_static = "1.5" +serial_test = "3.2" +solana-sdk = "2.1.0" + +[lib] +crate-type = ["cdylib", "lib"] + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] diff --git a/program-error/README.md b/program-error/README.md new file mode 100644 index 00000000..5ada5934 --- /dev/null +++ b/program-error/README.md @@ -0,0 +1,293 @@ +# SPL Program Error + +Macros for implementing error-based traits on enums. + +- `#[derive(IntoProgramError)]`: automatically derives the trait `From for solana_program::program_error::ProgramError`. +- `#[derive(DecodeError)]`: automatically derives the trait `solana_program::decode_error::DecodeError`. +- `#[derive(PrintProgramError)]`: automatically derives the trait `solana_program::program_error::PrintProgramError`. +- `#[spl_program_error]`: Automatically derives all below traits: + - `Clone` + - `Debug` + - `Eq` + - `DecodeError` + - `IntoProgramError` + - `PrintProgramError` + - `thiserror::Error` + - `num_derive::FromPrimitive` + - `PartialEq` + +### `#[derive(IntoProgramError)]` + +This derive macro automatically derives the trait `From for solana_program::program_error::ProgramError`. + +Your enum must implement the following traits in order for this macro to work: + +- `Clone` +- `Debug` +- `Eq` +- `thiserror::Error` +- `num_derive::FromPrimitive` +- `PartialEq` + +Sample code: + +```rust +/// Example error +#[derive( + Clone, Debug, Eq, IntoProgramError, thiserror::Error, num_derive::FromPrimitive, PartialEq, +)] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} +``` + +### `#[derive(DecodeError)]` + +This derive macro automatically derives the trait `solana_program::decode_error::DecodeError`. + +Your enum must implement the following traits in order for this macro to work: + +- `Clone` +- `Debug` +- `Eq` +- `IntoProgramError` (above) +- `thiserror::Error` +- `num_derive::FromPrimitive` +- `PartialEq` + +Sample code: + +```rust +/// Example error +#[derive( + Clone, + Debug, + DecodeError, + Eq, + IntoProgramError, + thiserror::Error, + num_derive::FromPrimitive, + PartialEq, +)] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} +``` + +### `#[derive(PrintProgramError)]` + +This derive macro automatically derives the trait `solana_program::program_error::PrintProgramError`. 
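+This is what lets a program log a human-readable message when an instruction
+fails. A typical entrypoint sketch is shown below; `process_instruction` and
+`ExampleError` are illustrative names, not part of this library.
+
+```rust
+use solana_program::{
+    account_info::AccountInfo, entrypoint::ProgramResult,
+    program_error::PrintProgramError, pubkey::Pubkey,
+};
+
+fn process(program_id: &Pubkey, accounts: &[AccountInfo], input: &[u8]) -> ProgramResult {
+    if let Err(error) = process_instruction(program_id, accounts, input) {
+        // Decode the custom error code and log the variant's `#[error("...")]` message.
+        error.print::<ExampleError>();
+        return Err(error);
+    }
+    Ok(())
+}
+```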
+ +Your enum must implement the following traits in order for this macro to work: + +- `Clone` +- `Debug` +- `DecodeError` (above) +- `Eq` +- `IntoProgramError` (above) +- `thiserror::Error` +- `num_derive::FromPrimitive` +- `PartialEq` + +Sample code: + +```rust +/// Example error +#[derive( + Clone, + Debug, + DecodeError, + Eq, + IntoProgramError, + thiserror::Error, + num_derive::FromPrimitive, + PartialEq, +)] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} +``` + +### `#[spl_program_error]` + +It can be cumbersome to ensure your program's defined errors - typically represented +in an enum - implement the required traits and will print to the program's logs when they're +invoked. + +This procedural macro will give you all of the required implementations out of the box: + +- `Clone` +- `Debug` +- `Eq` +- `thiserror::Error` +- `num_derive::FromPrimitive` +- `PartialEq` + +It also imports the required crates so you don't have to in your program: + +- `num_derive` +- `num_traits` +- `thiserror` + +--- + +Just annotate your enum... + +```rust +use solana_program_error_derive::*; + +/// Example error +#[solana_program_error] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} +``` + +...and get: + +```rust +/// Example error +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} +#[automatically_derived] +impl ::core::clone::Clone for ExampleError { + #[inline] + fn clone(&self) -> ExampleError { + match self { + ExampleError::MintHasNoMintAuthority => ExampleError::MintHasNoMintAuthority, + ExampleError::IncorrectMintAuthority => ExampleError::IncorrectMintAuthority, + } + } +} +#[automatically_derived] +impl ::core::fmt::Debug for ExampleError { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + ::core::fmt::Formatter::write_str( + f, + match self { + ExampleError::MintHasNoMintAuthority => "MintHasNoMintAuthority", + ExampleError::IncorrectMintAuthority => "IncorrectMintAuthority", + }, + ) + } +} +#[automatically_derived] +impl ::core::marker::StructuralEq for ExampleError {} +#[automatically_derived] +impl ::core::cmp::Eq for ExampleError { + #[inline] + #[doc(hidden)] + #[no_coverage] + fn assert_receiver_is_total_eq(&self) -> () {} +} +#[allow(unused_qualifications)] +impl std::error::Error for ExampleError {} +#[allow(unused_qualifications)] +impl std::fmt::Display for ExampleError { + fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + #[allow(unused_variables, deprecated, clippy::used_underscore_binding)] + match self { + ExampleError::MintHasNoMintAuthority {} => { + __formatter.write_fmt(format_args!("Mint has no mint authority")) + } + ExampleError::IncorrectMintAuthority {} => { + __formatter + .write_fmt( + format_args!( + "Incorrect mint authority has signed the instruction" + ), + ) + } + } + } +} +#[allow(non_upper_case_globals, unused_qualifications)] +const 
_IMPL_NUM_FromPrimitive_FOR_ExampleError: () = {
+    #[allow(clippy::useless_attribute)]
+    #[allow(rust_2018_idioms)]
+    extern crate num_traits as _num_traits;
+    impl _num_traits::FromPrimitive for ExampleError {
+        #[allow(trivial_numeric_casts)]
+        #[inline]
+        fn from_i64(n: i64) -> Option<Self> {
+            if n == ExampleError::MintHasNoMintAuthority as i64 {
+                Some(ExampleError::MintHasNoMintAuthority)
+            } else if n == ExampleError::IncorrectMintAuthority as i64 {
+                Some(ExampleError::IncorrectMintAuthority)
+            } else {
+                None
+            }
+        }
+        #[inline]
+        fn from_u64(n: u64) -> Option<Self> {
+            Self::from_i64(n as i64)
+        }
+    }
+};
+#[automatically_derived]
+impl ::core::marker::StructuralPartialEq for ExampleError {}
+#[automatically_derived]
+impl ::core::cmp::PartialEq for ExampleError {
+    #[inline]
+    fn eq(&self, other: &ExampleError) -> bool {
+        let __self_tag = ::core::intrinsics::discriminant_value(self);
+        let __arg1_tag = ::core::intrinsics::discriminant_value(other);
+        __self_tag == __arg1_tag
+    }
+}
+impl From<ExampleError> for solana_program::program_error::ProgramError {
+    fn from(e: ExampleError) -> Self {
+        solana_program::program_error::ProgramError::Custom(e as u32)
+    }
+}
+impl<T> solana_program::decode_error::DecodeError<T> for ExampleError {
+    fn type_of() -> &'static str {
+        "ExampleError"
+    }
+}
+impl solana_program::program_error::PrintProgramError for ExampleError {
+    fn print<E>(&self)
+    where
+        E: 'static
+            + std::error::Error
+            + solana_program::decode_error::DecodeError<E>
+            + solana_program::program_error::PrintProgramError
+            + num_traits::FromPrimitive,
+    {
+        match self {
+            ExampleError::MintHasNoMintAuthority => {
+                ::solana_program::log::sol_log("Mint has no mint authority")
+            }
+            ExampleError::IncorrectMintAuthority => {
+                ::solana_program::log::sol_log(
+                    "Incorrect mint authority has signed the instruction",
+                )
+            }
+        }
+    }
+}
+```
diff --git a/program-error/derive/Cargo.toml b/program-error/derive/Cargo.toml
new file mode 100644
index 00000000..74aec23b
--- /dev/null
+++ b/program-error/derive/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "spl-program-error-derive"
+version = "0.4.1"
+description = "Proc-Macro Library for Solana Program error attributes and derive macro"
+authors = ["Solana Labs Maintainers <maintainers@solanalabs.com>"]
+repository = "https://github.com/solana-labs/solana-program-library"
+license = "Apache-2.0"
+edition = "2021"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+proc-macro2 = "1.0"
+quote = "1.0"
+sha2 = "0.10"
+syn = { version = "2.0", features = ["full"] }
diff --git a/program-error/derive/src/lib.rs b/program-error/derive/src/lib.rs
new file mode 100644
index 00000000..026b0eb7
--- /dev/null
+++ b/program-error/derive/src/lib.rs
@@ -0,0 +1,85 @@
+//! 
Crate defining a procedural macro for building Solana program errors + +// Required to include `#[allow(clippy::integer_arithmetic)]` +// below since the tokens generated by `quote!` in the implementation +// for `MacroType::PrintProgramError` and `MacroType::SplProgramError` +// trigger the lint upstream through `quote_token_with_context` within the +// `quote` crate +// +// Culprit is `macro_impl.rs:66` +#![allow(clippy::arithmetic_side_effects)] +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +extern crate proc_macro; + +mod macro_impl; +mod parser; + +use { + crate::parser::SplProgramErrorArgs, + macro_impl::MacroType, + proc_macro::TokenStream, + syn::{parse_macro_input, ItemEnum}, +}; + +/// Derive macro to add `Into` +/// trait +#[proc_macro_derive(IntoProgramError)] +pub fn into_program_error(input: TokenStream) -> TokenStream { + let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum); + MacroType::IntoProgramError { ident } + .generate_tokens() + .into() +} + +/// Derive macro to add `solana_program::decode_error::DecodeError` trait +#[proc_macro_derive(DecodeError)] +pub fn decode_error(input: TokenStream) -> TokenStream { + let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum); + MacroType::DecodeError { ident }.generate_tokens().into() +} + +/// Derive macro to add `solana_program::program_error::PrintProgramError` trait +#[proc_macro_derive(PrintProgramError)] +pub fn print_program_error(input: TokenStream) -> TokenStream { + let ItemEnum { + ident, variants, .. + } = parse_macro_input!(input as ItemEnum); + MacroType::PrintProgramError { ident, variants } + .generate_tokens() + .into() +} + +/// Proc macro attribute to turn your enum into a Solana Program Error +/// +/// Adds: +/// - `Clone` +/// - `Debug` +/// - `Eq` +/// - `PartialEq` +/// - `thiserror::Error` +/// - `num_derive::FromPrimitive` +/// - `Into` +/// - `solana_program::decode_error::DecodeError` +/// - `solana_program::program_error::PrintProgramError` +/// +/// Optionally, you can add `hash_error_code_start: u32` argument to create +/// a unique `u32` _starting_ error codes from the names of the enum variants. +/// Notes: +/// - The _error_ variant will start at this value, and the rest will be +/// incremented by one +/// - The value provided is only for code readability, the actual error code +/// will be a hash of the input string and is checked against your input +/// +/// Syntax: `#[spl_program_error(hash_error_code_start = 1275525928)]` +/// Hash Input: `spl_program_error::` +/// Value: `u32::from_le_bytes([13..17])` +#[proc_macro_attribute] +pub fn spl_program_error(attr: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(attr as SplProgramErrorArgs); + let item_enum = parse_macro_input!(input as ItemEnum); + MacroType::SplProgramError { args, item_enum } + .generate_tokens() + .into() +} diff --git a/program-error/derive/src/macro_impl.rs b/program-error/derive/src/macro_impl.rs new file mode 100644 index 00000000..1f90cd55 --- /dev/null +++ b/program-error/derive/src/macro_impl.rs @@ -0,0 +1,207 @@ +//! 
The actual token generator for the macro + +use { + crate::parser::{SolanaProgram, SplProgramErrorArgs}, + proc_macro2::Span, + quote::quote, + sha2::{Digest, Sha256}, + syn::{ + punctuated::Punctuated, token::Comma, Expr, ExprLit, Ident, ItemEnum, Lit, LitInt, LitStr, + Token, Variant, + }, +}; + +const SPL_ERROR_HASH_NAMESPACE: &str = "spl_program_error"; +const SPL_ERROR_HASH_MIN_VALUE: u32 = 7_000; + +/// The type of macro being called, thus directing which tokens to generate +#[allow(clippy::enum_variant_names)] +pub enum MacroType { + IntoProgramError { + ident: Ident, + }, + DecodeError { + ident: Ident, + }, + PrintProgramError { + ident: Ident, + variants: Punctuated, + }, + SplProgramError { + args: SplProgramErrorArgs, + item_enum: ItemEnum, + }, +} + +impl MacroType { + /// Generates the corresponding tokens based on variant selection + pub fn generate_tokens(&mut self) -> proc_macro2::TokenStream { + let default_solana_program = SolanaProgram::default(); + match self { + Self::IntoProgramError { ident } => into_program_error(ident, &default_solana_program), + Self::DecodeError { ident } => decode_error(ident, &default_solana_program), + Self::PrintProgramError { ident, variants } => { + print_program_error(ident, variants, &default_solana_program) + } + Self::SplProgramError { args, item_enum } => spl_program_error(args, item_enum), + } + } +} + +/// Builds the implementation of +/// `Into` More specifically, +/// implements `From for solana_program::program_error::ProgramError` +pub fn into_program_error(ident: &Ident, import: &SolanaProgram) -> proc_macro2::TokenStream { + let this_impl = quote! { + impl From<#ident> for #import::program_error::ProgramError { + fn from(e: #ident) -> Self { + #import::program_error::ProgramError::Custom(e as u32) + } + } + }; + import.wrap(this_impl) +} + +/// Builds the implementation of `solana_program::decode_error::DecodeError` +pub fn decode_error(ident: &Ident, import: &SolanaProgram) -> proc_macro2::TokenStream { + let this_impl = quote! { + impl #import::decode_error::DecodeError for #ident { + fn type_of() -> &'static str { + stringify!(#ident) + } + } + }; + import.wrap(this_impl) +} + +/// Builds the implementation of +/// `solana_program::program_error::PrintProgramError` +pub fn print_program_error( + ident: &Ident, + variants: &Punctuated, + import: &SolanaProgram, +) -> proc_macro2::TokenStream { + let ppe_match_arms = variants.iter().map(|variant| { + let variant_ident = &variant.ident; + let error_msg = get_error_message(variant) + .unwrap_or_else(|| String::from("Unknown custom program error")); + quote! { + #ident::#variant_ident => { + #import::msg!(#error_msg) + } + } + }); + let this_impl = quote! 
{ + impl #import::program_error::PrintProgramError for #ident { + fn print(&self) + where + E: 'static + + std::error::Error + + #import::decode_error::DecodeError + + #import::program_error::PrintProgramError + + num_traits::FromPrimitive, + { + match self { + #(#ppe_match_arms),* + } + } + } + }; + import.wrap(this_impl) +} + +/// Helper to parse out the string literal from the `#[error(..)]` attribute +fn get_error_message(variant: &Variant) -> Option { + let attrs = &variant.attrs; + for attr in attrs { + if attr.path().is_ident("error") { + if let Ok(lit_str) = attr.parse_args::() { + return Some(lit_str.value()); + } + } + } + None +} + +/// The main function that produces the tokens required to turn your +/// error enum into a Solana Program Error +pub fn spl_program_error( + args: &SplProgramErrorArgs, + item_enum: &mut ItemEnum, +) -> proc_macro2::TokenStream { + if let Some(error_code_start) = args.hash_error_code_start { + set_first_discriminant(item_enum, error_code_start); + } + + let ident = &item_enum.ident; + let variants = &item_enum.variants; + let into_program_error = into_program_error(ident, &args.import); + let decode_error = decode_error(ident, &args.import); + let print_program_error = print_program_error(ident, variants, &args.import); + + quote! { + #[repr(u32)] + #[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)] + #[num_traits = "num_traits"] + #item_enum + + #into_program_error + + #decode_error + + #print_program_error + } +} + +/// This function adds a discriminant to the first enum variant based on the +/// hash of the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant +/// name. +/// It will then check to make sure the provided `hash_error_code_start` is +/// equal to the hash-produced `u32`. +/// +/// See https://docs.rs/syn/latest/syn/struct.Variant.html +fn set_first_discriminant(item_enum: &mut ItemEnum, error_code_start: u32) { + let enum_ident = &item_enum.ident; + if item_enum.variants.is_empty() { + panic!("Enum must have at least one variant"); + } + let first_variant = &mut item_enum.variants[0]; + let discriminant = u32_from_hash(enum_ident); + if discriminant == error_code_start { + let eq = Token![=](Span::call_site()); + let expr = Expr::Lit(ExprLit { + attrs: Vec::new(), + lit: Lit::Int(LitInt::new(&discriminant.to_string(), Span::call_site())), + }); + first_variant.discriminant = Some((eq, expr)); + } else { + panic!( + "Error code start value from hash must be {0}. Update your macro attribute to \ + `#[spl_program_error(hash_error_code_start = {0})]`.", + discriminant + ); + } +} + +/// Hashes the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant +/// name and returns four middle bytes (13 through 16) as a u32. +fn u32_from_hash(enum_ident: &Ident) -> u32 { + let hash_input = format!("{}:{}", SPL_ERROR_HASH_NAMESPACE, enum_ident); + + // We don't want our error code to start at any number below + // `SPL_ERROR_HASH_MIN_VALUE`! + let mut nonce: u32 = 0; + loop { + let mut hasher = Sha256::new_with_prefix(hash_input.as_bytes()); + hasher.update(nonce.to_le_bytes()); + let d = u32::from_le_bytes( + hasher.finalize()[13..17] + .try_into() + .expect("Unable to convert hash to u32"), + ); + if d >= SPL_ERROR_HASH_MIN_VALUE { + return d; + } + nonce += 1; + } +} diff --git a/program-error/derive/src/parser.rs b/program-error/derive/src/parser.rs new file mode 100644 index 00000000..7e01cb72 --- /dev/null +++ b/program-error/derive/src/parser.rs @@ -0,0 +1,145 @@ +//! 
Token parsing + +use { + proc_macro2::{Ident, Span, TokenStream}, + quote::quote, + syn::{ + parse::{Parse, ParseStream}, + token::Comma, + LitInt, LitStr, Token, + }, +}; + +/// Possible arguments to the `#[spl_program_error]` attribute +pub struct SplProgramErrorArgs { + /// Whether to hash the error codes using `solana_program::hash` + /// or to use the default error code assigned by `num_traits`. + pub hash_error_code_start: Option, + /// Crate to use for solana_program + pub import: SolanaProgram, +} + +/// Struct representing the path to a `solana_program` crate, which may be +/// renamed or otherwise. +pub struct SolanaProgram { + import: Ident, + explicit: bool, +} +impl quote::ToTokens for SolanaProgram { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.import.to_tokens(tokens); + } +} +impl SolanaProgram { + pub fn wrap(&self, output: TokenStream) -> TokenStream { + if self.explicit { + output + } else { + anon_const_trick(output) + } + } +} +impl Default for SolanaProgram { + fn default() -> Self { + Self { + import: Ident::new("_solana_program", Span::call_site()), + explicit: false, + } + } +} + +impl Parse for SplProgramErrorArgs { + fn parse(input: ParseStream) -> syn::Result { + let mut hash_error_code_start = None; + let mut import = None; + while !input.is_empty() { + match SplProgramErrorArgParser::parse(input)? { + SplProgramErrorArgParser::HashErrorCodes { value, .. } => { + hash_error_code_start = Some(value.base10_parse::()?); + } + SplProgramErrorArgParser::SolanaProgramCrate { value, .. } => { + import = Some(SolanaProgram { + import: value.parse()?, + explicit: true, + }); + } + } + } + Ok(Self { + hash_error_code_start, + import: import.unwrap_or(SolanaProgram::default()), + }) + } +} + +/// Parser for args to the `#[spl_program_error]` attribute +/// ie. `#[spl_program_error(hash_error_code_start = 1275525928)]` +enum SplProgramErrorArgParser { + HashErrorCodes { + _equals_sign: Token![=], + value: LitInt, + _comma: Option, + }, + SolanaProgramCrate { + _equals_sign: Token![=], + value: LitStr, + _comma: Option, + }, +} + +impl Parse for SplProgramErrorArgParser { + fn parse(input: ParseStream) -> syn::Result { + let ident = input.parse::()?; + match ident.to_string().as_str() { + "hash_error_code_start" => { + let _equals_sign = input.parse::()?; + let value = input.parse::()?; + let _comma: Option = input.parse().unwrap_or(None); + Ok(Self::HashErrorCodes { + _equals_sign, + value, + _comma, + }) + } + "solana_program" => { + let _equals_sign = input.parse::()?; + let value = input.parse::()?; + let _comma: Option = input.parse().unwrap_or(None); + Ok(Self::SolanaProgramCrate { + _equals_sign, + value, + _comma, + }) + } + _ => Err(input.error("Expected argument 'hash_error_code_start' or 'solana_program'")), + } + } +} + +// Within `exp`, you can bring things into scope with `extern crate`. +// +// We don't want to assume that `solana_program::` is in scope - the user may +// have imported it under a different name, or may have imported it in a +// non-toplevel module (common when putting impls behind a feature gate). +// +// Solution: let's just generate `extern crate solana_program as +// _solana_program` and then refer to `_solana_program` in the derived code. +// However, macros are not allowed to produce `extern crate` statements at the +// toplevel. +// +// Solution: let's generate `mod _impl_foo` and import solana_program within +// that. However, now we lose access to private members of the surrounding +// module. 
This is a problem if, for example, we're deriving for a newtype, +// where the inner type is defined in the same module, but not exported. +// +// Solution: use the anonymous const trick. For some reason, `extern crate` +// statements are allowed here, but everything from the surrounding module is in +// scope. This trick is taken from serde and num_traits. +fn anon_const_trick(exp: TokenStream) -> TokenStream { + quote! { + const _: () = { + extern crate solana_program as _solana_program; + #exp + }; + } +} diff --git a/program-error/src/lib.rs b/program-error/src/lib.rs new file mode 100644 index 00000000..739995dd --- /dev/null +++ b/program-error/src/lib.rs @@ -0,0 +1,17 @@ +//! Crate defining a library with a procedural macro and other +//! dependencies for building Solana program errors + +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +extern crate self as spl_program_error; + +// Make these available downstream for the macro to work without +// additional imports +pub use { + num_derive, num_traits, solana_program, + spl_program_error_derive::{ + spl_program_error, DecodeError, IntoProgramError, PrintProgramError, + }, + thiserror, +}; diff --git a/program-error/tests/bench.rs b/program-error/tests/bench.rs new file mode 100644 index 00000000..40f01bb7 --- /dev/null +++ b/program-error/tests/bench.rs @@ -0,0 +1,50 @@ +//! Bench case with manual implementations +use spl_program_error::*; + +/// Example error +#[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} + +impl From for solana_program::program_error::ProgramError { + fn from(e: ExampleError) -> Self { + solana_program::program_error::ProgramError::Custom(e as u32) + } +} +impl solana_program::decode_error::DecodeError for ExampleError { + fn type_of() -> &'static str { + "ExampleError" + } +} + +impl solana_program::program_error::PrintProgramError for ExampleError { + fn print(&self) + where + E: 'static + + std::error::Error + + solana_program::decode_error::DecodeError + + solana_program::program_error::PrintProgramError + + num_traits::FromPrimitive, + { + match self { + ExampleError::MintHasNoMintAuthority => { + solana_program::msg!("Mint has no mint authority") + } + ExampleError::IncorrectMintAuthority => { + solana_program::msg!("Incorrect mint authority has signed the instruction") + } + } + } +} + +/// Tests that all macros compile +#[test] +fn test_macros_compile() { + let _ = ExampleError::MintHasNoMintAuthority; +} diff --git a/program-error/tests/decode.rs b/program-error/tests/decode.rs new file mode 100644 index 00000000..0c7c209b --- /dev/null +++ b/program-error/tests/decode.rs @@ -0,0 +1,29 @@ +//! 
Tests `#[derive(DecodeError)]` + +use spl_program_error::*; + +/// Example error +#[derive( + Clone, + Debug, + DecodeError, + Eq, + IntoProgramError, + thiserror::Error, + num_derive::FromPrimitive, + PartialEq, +)] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} + +/// Tests that all macros compile +#[test] +fn test_macros_compile() { + let _ = ExampleError::MintHasNoMintAuthority; +} diff --git a/program-error/tests/into.rs b/program-error/tests/into.rs new file mode 100644 index 00000000..0f32b8f4 --- /dev/null +++ b/program-error/tests/into.rs @@ -0,0 +1,22 @@ +//! Tests `#[derive(IntoProgramError)]` + +use spl_program_error::*; + +/// Example error +#[derive( + Clone, Debug, Eq, IntoProgramError, thiserror::Error, num_derive::FromPrimitive, PartialEq, +)] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} + +/// Tests that all macros compile +#[test] +fn test_macros_compile() { + let _ = ExampleError::MintHasNoMintAuthority; +} diff --git a/program-error/tests/mod.rs b/program-error/tests/mod.rs new file mode 100644 index 00000000..41358775 --- /dev/null +++ b/program-error/tests/mod.rs @@ -0,0 +1,141 @@ +pub mod bench; +pub mod decode; +pub mod into; +pub mod print; +pub mod spl; + +#[cfg(test)] +mod tests { + use { + super::*, + serial_test::serial, + solana_program::{ + decode_error::DecodeError, + program_error::{PrintProgramError, ProgramError}, + }, + std::sync::{Arc, RwLock}, + }; + + // Used to capture output for `PrintProgramError` for testing + lazy_static::lazy_static! 
{ + static ref EXPECTED_DATA: Arc>> = Arc::new(RwLock::new(Vec::new())); + } + fn set_expected_data(expected_data: Vec) { + *EXPECTED_DATA.write().unwrap() = expected_data; + } + pub struct SyscallStubs {} + impl solana_sdk::program_stubs::SyscallStubs for SyscallStubs { + fn sol_log(&self, message: &str) { + assert_eq!( + message, + String::from_utf8_lossy(&EXPECTED_DATA.read().unwrap()) + ); + } + } + + // `#[derive(IntoProgramError)]` + #[test] + fn test_derive_into_program_error() { + // `Into` + assert_eq!( + Into::::into(bench::ExampleError::MintHasNoMintAuthority), + Into::::into(into::ExampleError::MintHasNoMintAuthority), + ); + assert_eq!( + Into::::into(bench::ExampleError::IncorrectMintAuthority), + Into::::into(into::ExampleError::IncorrectMintAuthority), + ); + } + + // `#[derive(DecodeError)]` + #[test] + fn test_derive_decode_error() { + // `Into` + assert_eq!( + Into::::into(bench::ExampleError::MintHasNoMintAuthority), + Into::::into(decode::ExampleError::MintHasNoMintAuthority), + ); + assert_eq!( + Into::::into(bench::ExampleError::IncorrectMintAuthority), + Into::::into(decode::ExampleError::IncorrectMintAuthority), + ); + // `DecodeError` + assert_eq!( + >::type_of(), + >::type_of(), + ); + } + // `#[derive(PrintProgramError)]` + #[test] + #[serial] + fn test_derive_print_program_error() { + use std::sync::Once; + static ONCE: Once = Once::new(); + + ONCE.call_once(|| { + solana_sdk::program_stubs::set_syscall_stubs(Box::new(SyscallStubs {})); + }); + // `Into` + assert_eq!( + Into::::into(bench::ExampleError::MintHasNoMintAuthority), + Into::::into(print::ExampleError::MintHasNoMintAuthority), + ); + assert_eq!( + Into::::into(bench::ExampleError::IncorrectMintAuthority), + Into::::into(print::ExampleError::IncorrectMintAuthority), + ); + // `DecodeError` + assert_eq!( + >::type_of(), + >::type_of(), + ); + // `PrintProgramError` + set_expected_data("Mint has no mint authority".as_bytes().to_vec()); + PrintProgramError::print::( + &print::ExampleError::MintHasNoMintAuthority, + ); + set_expected_data( + "Incorrect mint authority has signed the instruction" + .as_bytes() + .to_vec(), + ); + PrintProgramError::print::( + &print::ExampleError::IncorrectMintAuthority, + ); + } + + // `#[spl_program_error]` + #[test] + #[serial] + fn test_spl_program_error() { + use std::sync::Once; + static ONCE: Once = Once::new(); + + ONCE.call_once(|| { + solana_sdk::program_stubs::set_syscall_stubs(Box::new(SyscallStubs {})); + }); + // `Into` + assert_eq!( + Into::::into(bench::ExampleError::MintHasNoMintAuthority), + Into::::into(spl::ExampleError::MintHasNoMintAuthority), + ); + assert_eq!( + Into::::into(bench::ExampleError::IncorrectMintAuthority), + Into::::into(spl::ExampleError::IncorrectMintAuthority), + ); + // `DecodeError` + assert_eq!( + >::type_of(), + >::type_of(), + ); + // `PrintProgramError` + set_expected_data("Mint has no mint authority".as_bytes().to_vec()); + PrintProgramError::print::(&spl::ExampleError::MintHasNoMintAuthority); + set_expected_data( + "Incorrect mint authority has signed the instruction" + .as_bytes() + .to_vec(), + ); + PrintProgramError::print::(&spl::ExampleError::IncorrectMintAuthority); + } +} diff --git a/program-error/tests/print.rs b/program-error/tests/print.rs new file mode 100644 index 00000000..8b68f66a --- /dev/null +++ b/program-error/tests/print.rs @@ -0,0 +1,30 @@ +//! 
Tests `#[derive(PrintProgramError)]` + +use spl_program_error::*; + +/// Example error +#[derive( + Clone, + Debug, + DecodeError, + Eq, + IntoProgramError, + PrintProgramError, + thiserror::Error, + num_derive::FromPrimitive, + PartialEq, +)] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} + +/// Tests that all macros compile +#[test] +fn test_macros_compile() { + let _ = ExampleError::MintHasNoMintAuthority; +} diff --git a/program-error/tests/spl.rs b/program-error/tests/spl.rs new file mode 100644 index 00000000..d772d24f --- /dev/null +++ b/program-error/tests/spl.rs @@ -0,0 +1,91 @@ +//! Tests `#[spl_program_error]` + +use spl_program_error::*; + +/// Example error +#[spl_program_error] +pub enum ExampleError { + /// Mint has no mint authority + #[error("Mint has no mint authority")] + MintHasNoMintAuthority, + /// Incorrect mint authority has signed the instruction + #[error("Incorrect mint authority has signed the instruction")] + IncorrectMintAuthority, +} + +/// Tests that all macros compile +#[test] +fn test_macros_compile() { + let _ = ExampleError::MintHasNoMintAuthority; +} + +/// Example library error with namespace +#[spl_program_error(hash_error_code_start = 2_056_342_880)] +enum ExampleLibraryError { + /// This is a very informative error + #[error("This is a very informative error")] + VeryInformativeError, + /// This is a super important error + #[error("This is a super important error")] + SuperImportantError, + /// This is a mega serious error + #[error("This is a mega serious error")] + MegaSeriousError, + /// You are toast + #[error("You are toast")] + YouAreToast, +} + +/// Tests hashing of error codes into unique `u32` values +#[test] +fn test_library_error_codes() { + fn get_error_code_check(hash_input: &str) -> u32 { + let mut nonce: u32 = 0; + loop { + let hash = solana_program::hash::hashv(&[hash_input.as_bytes(), &nonce.to_le_bytes()]); + let mut bytes = [0u8; 4]; + bytes.copy_from_slice(&hash.to_bytes()[13..17]); + let error_code = u32::from_le_bytes(bytes); + if error_code >= 10_000 { + return error_code; + } + nonce += 1; + } + } + + let first_error_as_u32 = ExampleLibraryError::VeryInformativeError as u32; + + assert_eq!( + ExampleLibraryError::VeryInformativeError as u32, + get_error_code_check("spl_program_error:ExampleLibraryError"), + ); + assert_eq!( + ExampleLibraryError::SuperImportantError as u32, + first_error_as_u32 + 1, + ); + assert_eq!( + ExampleLibraryError::MegaSeriousError as u32, + first_error_as_u32 + 2, + ); + assert_eq!( + ExampleLibraryError::YouAreToast as u32, + first_error_as_u32 + 3, + ); +} + +/// Example error with solana_program crate set +#[spl_program_error(solana_program = "solana_program")] +enum ExampleSolanaProgramCrateError { + /// This is a very informative error + #[error("This is a very informative error")] + VeryInformativeError, + /// This is a super important error + #[error("This is a super important error")] + SuperImportantError, +} + +/// Tests that all macros compile +#[test] +fn test_macros_compile_with_solana_program_crate() { + let _ = ExampleSolanaProgramCrateError::VeryInformativeError; +} diff --git a/tlv-account-resolution/Cargo.toml b/tlv-account-resolution/Cargo.toml new file mode 100644 index 00000000..fffd7b83 --- /dev/null +++ b/tlv-account-resolution/Cargo.toml @@ -0,0 
+1,42 @@
+[package]
+name = "spl-tlv-account-resolution"
+version = "0.9.0"
+description = "Solana Program Library TLV Account Resolution Interface"
+authors = ["Solana Labs Maintainers <maintainers@solanalabs.com>"]
+repository = "https://github.com/solana-labs/solana-program-library"
+license = "Apache-2.0"
+edition = "2021"
+
+[features]
+serde-traits = ["dep:serde"]
+test-sbf = []
+
+[dependencies]
+bytemuck = { version = "1.20.0", features = ["derive"] }
+num-derive = "0.4"
+num-traits = "0.2"
+serde = { version = "1.0.216", optional = true }
+solana-account-info = "2.1.0"
+solana-decode-error = "2.1.0"
+solana-instruction = { version = "2.1.0", features = ["std"] }
+solana-program-error = "2.1.0"
+solana-msg = "2.1.0"
+solana-pubkey = "2.1.0"
+spl-discriminator = { version = "0.4.0", path = "../discriminator" }
+spl-program-error = { version = "0.6.0", path = "../program-error" }
+spl-pod = { version = "0.5.0", path = "../pod" }
+spl-type-length-value = { version = "0.7.0", path = "../type-length-value" }
+thiserror = "2.0"
+
+[dev-dependencies]
+futures = "0.3.31"
+futures-util = "0.3"
+solana-client = "2.1.0"
+solana-program-test = "2.1.0"
+solana-sdk = "2.1.0"
+
+[lib]
+crate-type = ["cdylib", "lib"]
+
+[package.metadata.docs.rs]
+targets = ["x86_64-unknown-linux-gnu"]
diff --git a/tlv-account-resolution/README.md b/tlv-account-resolution/README.md
new file mode 100644
index 00000000..2b7dd2d0
--- /dev/null
+++ b/tlv-account-resolution/README.md
@@ -0,0 +1,198 @@
+# TLV Account Resolution
+
+Library defining a generic state interface to encode additional required accounts
+for an instruction, using Type-Length-Value structures.
+
+## Example usage
+
+If you want to encode the additional required accounts for your instruction
+into a TLV entry in an account, you can do the following:
+
+```rust
+use {
+    solana_account_info::AccountInfo,
+    solana_instruction::{AccountMeta, Instruction},
+    solana_pubkey::Pubkey,
+    spl_discriminator::{ArrayDiscriminator, SplDiscriminate},
+    spl_tlv_account_resolution::{
+        account::ExtraAccountMeta,
+        seeds::Seed,
+        state::ExtraAccountMetaList
+    },
+};
+
+struct MyInstruction;
+impl SplDiscriminate for MyInstruction {
+    // For ease of use, give it the same discriminator as its instruction definition
+    const SPL_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([1; ArrayDiscriminator::LENGTH]);
+}
+
+// Prepare the additional required account keys and signer / writable
+let extra_metas = [
+    AccountMeta::new(Pubkey::new_unique(), false).into(),
+    AccountMeta::new_readonly(Pubkey::new_unique(), true).into(),
+    ExtraAccountMeta::new_with_seeds(
+        &[
+            Seed::Literal {
+                bytes: b"some_string".to_vec(),
+            },
+            Seed::InstructionData {
+                index: 1,
+                length: 1, // u8
+            },
+            Seed::AccountKey { index: 1 },
+        ],
+        false,
+        true,
+    ).unwrap(),
+    ExtraAccountMeta::new_external_pda_with_seeds(
+        0,
+        &[Seed::AccountKey { index: 2 }],
+        false,
+        false,
+    ).unwrap(),
+];
+
+// Allocate a new buffer with the proper `account_size`
+let account_size = ExtraAccountMetaList::size_of(extra_metas.len()).unwrap();
+let mut buffer = vec![0; account_size];
+
+// Initialize the structure for your instruction
+ExtraAccountMetaList::init::<MyInstruction>(&mut buffer, &extra_metas).unwrap();
+
+// Off-chain, you can add the additional accounts directly from the account data.
+// You need to provide the resolver a way to fetch account data off-chain
+let client = RpcClient::new_mock("succeeds".to_string());
+let program_id = Pubkey::new_unique();
+let mut instruction = Instruction::new_with_bytes(program_id, &[0, 1, 2], vec![]);
+ExtraAccountMetaList::add_to_instruction::<_, _, MyInstruction>(
+    &mut instruction,
+    |address: &Pubkey| {
+        client
+            .get_account(address)
+            .map_ok(|acct| Some(acct.data))
+    },
+    &buffer,
+)
+.await
+.unwrap();
+
+// On-chain, you can add the additional accounts *and* account infos
+let mut cpi_instruction = Instruction::new_with_bytes(program_id, &[0, 1, 2], vec![]);
+
+// Include all of the well-known required account infos here first
+let mut cpi_account_infos = vec![];
+
+// Provide all "remaining_account_infos" that are *not* part of any other known interface
+let remaining_account_infos = &[];
+ExtraAccountMetaList::add_to_cpi_instruction::<MyInstruction>(
+    &mut cpi_instruction,
+    &mut cpi_account_infos,
+    &buffer,
+    &remaining_account_infos,
+).unwrap();
+```
+
+For ease of use on-chain, `ExtraAccountMetaList::init` is also
+provided to initialize directly from a set of given accounts.
+
+## Motivation
+
+The Solana account model presents unique challenges for program interfaces.
+Since it's impossible to load additional accounts on-chain, if a program requires
+additional accounts to properly implement an instruction, there's no clear way
+for clients to fetch these accounts.
+
+There are two main ways to fetch additional accounts: dynamically, through
+program simulation, or statically, by fetching account data. This library
+implements additional account resolution statically. You can find more
+information about dynamic account resolution in the Appendix.
+
+### Static Account Resolution
+
+It's possible for programs to write the additional required account infos
+into account data, so that on-chain and off-chain clients simply need to read
+the data to figure out the additional required accounts.
+
+Rather than exposing this data dynamically through program execution, this method
+uses static account data.
+
+For example, let's imagine there's a `Transferable` interface, along with a
+`transfer` instruction. Some programs that implement `transfer` may need more
+accounts than just the ones defined in the interface. How does an on-chain or
+off-chain client figure out the additional required accounts?
+
+The "static" approach requires programs to write the extra required accounts to
+an account defined at a given address. This could be directly in the `mint`, or
+some address derivable from the mint address.
+
+Off-chain, a client must fetch this additional account and read its data to find
+out the additional required accounts, and then include them in the instruction.
+
+On-chain, a program must have access to "remaining account infos" containing the
+special account and all other required accounts to properly create the CPI
+instruction and give the correct account infos.
+
+This approach could also be called a "state interface".
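+
+As a minimal off-chain sketch of this approach, suppose (purely for
+illustration) that a program stores its `ExtraAccountMetaList` in a PDA derived
+from a literal seed and the mint address. The seed literal and the
+`derive_validation_address` helper below are assumptions made for this sketch,
+not part of this library; each program defines its own validation-account
+layout.
+
+```rust
+use solana_pubkey::Pubkey;
+
+// Hypothetical layout: the validation account is a PDA of a literal seed and
+// the mint. An off-chain client derives the address, fetches the account, and
+// passes its data to `ExtraAccountMetaList::add_to_instruction` as shown in
+// the example above.
+fn derive_validation_address(program_id: &Pubkey, mint: &Pubkey) -> Pubkey {
+    Pubkey::find_program_address(&[b"extra-account-metas", mint.as_ref()], program_id).0
+}
+```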
+
+### Types of Required Accounts
+
+This library is capable of storing two types of configurations for additional
+required accounts:
+
+- Accounts with a fixed address
+- Accounts with a **dynamic program-derived address** derived from seeds that
+may come from any combination of the following:
+  - Hard-coded values, such as string literals or integers
+  - A slice of the instruction data provided to the transfer-hook program
+  - The address of another account in the total list of accounts
+  - A program id from another account in the instruction
+
+When you store a configuration for a dynamic program-derived address among the
+additional required accounts, the PDA itself is evaluated (or resolved) at
+instruction-invocation time, using the instruction's data and accounts. This
+resolution happens in the offchain and onchain helpers shown in the example
+above (`add_to_instruction` and `add_to_cpi_instruction`), which use this
+library to perform it automatically.
+
+## How it Works
+
+This library uses `spl-type-length-value` to read and write required instruction
+accounts from account data.
+
+Interface instructions must have an 8-byte discriminator, so that the exposed
+`ExtraAccountMetaList` type can reuse the instruction discriminator as an
+`ArrayDiscriminator`: it serves as the unique TLV discriminator identifying the
+entry that corresponds to that particular instruction.
+
+This can be confusing. Typically, a type implements `SplDiscriminate`, so that
+the type can be written into TLV data. In this case, `ExtraAccountMetaList` is
+generic over `SplDiscriminate`, meaning that a program can write many different
+instances of `ExtraAccountMetaList` into one account, using different
+`ArrayDiscriminator`s.
+
+Also, it's reusing an instruction discriminator as a TLV discriminator. For example,
+if the `transfer` instruction has a discriminator of `[1, 2, 3, 4, 5, 6, 7, 8]`,
+then the account uses a TLV discriminator of `[1, 2, 3, 4, 5, 6, 7, 8]` to denote
+where the additional account metas are stored.
+
+This isn't required, but makes it easier for clients to find the additional
+required accounts for an instruction.
+
+## Appendix
+
+### Dynamic Account Resolution
+
+To expose the additional required accounts, instruction interfaces can include
+supplemental instructions that return the required accounts.
+
+For example, the `Transferable` interface, along with its `transfer`
+instruction, could also require implementations to expose a
+`get_additional_accounts_for_transfer` instruction.
+
+In the program implementation, this instruction writes the additional accounts
+into return data, making it easy for on-chain and off-chain clients to consume.
+
+See the
+[relevant sRFC](https://forum.solana.com/t/srfc-00010-additional-accounts-request-transfer-spec/122)
+for more information about the dynamic approach.
diff --git a/tlv-account-resolution/src/account.rs b/tlv-account-resolution/src/account.rs
new file mode 100644
index 00000000..91305680
--- /dev/null
+++ b/tlv-account-resolution/src/account.rs
@@ -0,0 +1,296 @@
+//! Struct for managing extra required account configs, ie. defining accounts
+//! required for your interface program, which can be `AccountMeta`s - which
+//! have fixed addresses - or PDAs - which have addresses derived from a
+//! 
collection of seeds + +use { + crate::{error::AccountResolutionError, pubkey_data::PubkeyData, seeds::Seed}, + bytemuck::{Pod, Zeroable}, + solana_account_info::AccountInfo, + solana_instruction::AccountMeta, + solana_program_error::ProgramError, + solana_pubkey::{Pubkey, PUBKEY_BYTES}, + spl_pod::primitives::PodBool, +}; + +/// Resolve a program-derived address (PDA) from the instruction data +/// and the accounts that have already been resolved +fn resolve_pda<'a, F>( + seeds: &[Seed], + instruction_data: &[u8], + program_id: &Pubkey, + get_account_key_data_fn: F, +) -> Result +where + F: Fn(usize) -> Option<(&'a Pubkey, Option<&'a [u8]>)>, +{ + let mut pda_seeds: Vec<&[u8]> = vec![]; + for config in seeds { + match config { + Seed::Uninitialized => (), + Seed::Literal { bytes } => pda_seeds.push(bytes), + Seed::InstructionData { index, length } => { + let arg_start = *index as usize; + let arg_end = arg_start + *length as usize; + if arg_end > instruction_data.len() { + return Err(AccountResolutionError::InstructionDataTooSmall.into()); + } + pda_seeds.push(&instruction_data[arg_start..arg_end]); + } + Seed::AccountKey { index } => { + let account_index = *index as usize; + let address = get_account_key_data_fn(account_index) + .ok_or::(AccountResolutionError::AccountNotFound.into())? + .0; + pda_seeds.push(address.as_ref()); + } + Seed::AccountData { + account_index, + data_index, + length, + } => { + let account_index = *account_index as usize; + let account_data = get_account_key_data_fn(account_index) + .ok_or::(AccountResolutionError::AccountNotFound.into())? + .1 + .ok_or::(AccountResolutionError::AccountDataNotFound.into())?; + let arg_start = *data_index as usize; + let arg_end = arg_start + *length as usize; + if account_data.len() < arg_end { + return Err(AccountResolutionError::AccountDataTooSmall.into()); + } + pda_seeds.push(&account_data[arg_start..arg_end]); + } + } + } + Ok(Pubkey::find_program_address(&pda_seeds, program_id).0) +} + +/// Resolve a pubkey from a pubkey data configuration. +fn resolve_key_data<'a, F>( + key_data: &PubkeyData, + instruction_data: &[u8], + get_account_key_data_fn: F, +) -> Result +where + F: Fn(usize) -> Option<(&'a Pubkey, Option<&'a [u8]>)>, +{ + match key_data { + PubkeyData::Uninitialized => Err(ProgramError::InvalidAccountData), + PubkeyData::InstructionData { index } => { + let key_start = *index as usize; + let key_end = key_start + PUBKEY_BYTES; + if key_end > instruction_data.len() { + return Err(AccountResolutionError::InstructionDataTooSmall.into()); + } + Ok(Pubkey::new_from_array( + instruction_data[key_start..key_end].try_into().unwrap(), + )) + } + PubkeyData::AccountData { + account_index, + data_index, + } => { + let account_index = *account_index as usize; + let account_data = get_account_key_data_fn(account_index) + .ok_or::(AccountResolutionError::AccountNotFound.into())? + .1 + .ok_or::(AccountResolutionError::AccountDataNotFound.into())?; + let arg_start = *data_index as usize; + let arg_end = arg_start + PUBKEY_BYTES; + if account_data.len() < arg_end { + return Err(AccountResolutionError::AccountDataTooSmall.into()); + } + Ok(Pubkey::new_from_array( + account_data[arg_start..arg_end].try_into().unwrap(), + )) + } + } +} + +/// `Pod` type for defining a required account in a validation account. 
+/// +/// This can be any of the following: +/// +/// * A standard `AccountMeta` +/// * A PDA (with seed configurations) +/// * A pubkey stored in some data (account or instruction data) +/// +/// Can be used in TLV-encoded data. +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +pub struct ExtraAccountMeta { + /// Discriminator to tell whether this represents a standard + /// `AccountMeta`, PDA, or pubkey data. + pub discriminator: u8, + /// This `address_config` field can either be the pubkey of the account, + /// the seeds used to derive the pubkey from provided inputs (PDA), or the + /// data used to derive the pubkey (account or instruction data). + pub address_config: [u8; 32], + /// Whether the account should sign + pub is_signer: PodBool, + /// Whether the account should be writable + pub is_writable: PodBool, +} +/// Helper used to know when the top bit is set, to interpret the +/// discriminator as an index rather than as a type +const U8_TOP_BIT: u8 = 1 << 7; +impl ExtraAccountMeta { + /// Create a `ExtraAccountMeta` from a public key, + /// thus representing a standard `AccountMeta` + pub fn new_with_pubkey( + pubkey: &Pubkey, + is_signer: bool, + is_writable: bool, + ) -> Result { + Ok(Self { + discriminator: 0, + address_config: pubkey.to_bytes(), + is_signer: is_signer.into(), + is_writable: is_writable.into(), + }) + } + + /// Create a `ExtraAccountMeta` from a list of seed configurations, + /// thus representing a PDA + pub fn new_with_seeds( + seeds: &[Seed], + is_signer: bool, + is_writable: bool, + ) -> Result { + Ok(Self { + discriminator: 1, + address_config: Seed::pack_into_address_config(seeds)?, + is_signer: is_signer.into(), + is_writable: is_writable.into(), + }) + } + + /// Create a `ExtraAccountMeta` from a pubkey data configuration. + pub fn new_with_pubkey_data( + key_data: &PubkeyData, + is_signer: bool, + is_writable: bool, + ) -> Result { + Ok(Self { + discriminator: 2, + address_config: PubkeyData::pack_into_address_config(key_data)?, + is_signer: is_signer.into(), + is_writable: is_writable.into(), + }) + } + + /// Create a `ExtraAccountMeta` from a list of seed configurations, + /// representing a PDA for an external program + /// + /// This PDA belongs to a program elsewhere in the account list, rather + /// than the executing program. For a PDA on the executing program, use + /// `ExtraAccountMeta::new_with_seeds`. + pub fn new_external_pda_with_seeds( + program_index: u8, + seeds: &[Seed], + is_signer: bool, + is_writable: bool, + ) -> Result { + Ok(Self { + discriminator: program_index + .checked_add(U8_TOP_BIT) + .ok_or(AccountResolutionError::InvalidSeedConfig)?, + address_config: Seed::pack_into_address_config(seeds)?, + is_signer: is_signer.into(), + is_writable: is_writable.into(), + }) + } + + /// Resolve an `ExtraAccountMeta` into an `AccountMeta`, potentially + /// resolving a program-derived address (PDA) if necessary + pub fn resolve<'a, F>( + &self, + instruction_data: &[u8], + program_id: &Pubkey, + get_account_key_data_fn: F, + ) -> Result + where + F: Fn(usize) -> Option<(&'a Pubkey, Option<&'a [u8]>)>, + { + match self.discriminator { + 0 => AccountMeta::try_from(self), + x if x == 1 || x >= U8_TOP_BIT => { + let program_id = if x == 1 { + program_id + } else { + get_account_key_data_fn(x.saturating_sub(U8_TOP_BIT) as usize) + .ok_or::(AccountResolutionError::AccountNotFound.into())? 
+ .0 + }; + let seeds = Seed::unpack_address_config(&self.address_config)?; + Ok(AccountMeta { + pubkey: resolve_pda( + &seeds, + instruction_data, + program_id, + get_account_key_data_fn, + )?, + is_signer: self.is_signer.into(), + is_writable: self.is_writable.into(), + }) + } + 2 => { + let key_data = PubkeyData::unpack(&self.address_config)?; + Ok(AccountMeta { + pubkey: resolve_key_data(&key_data, instruction_data, get_account_key_data_fn)?, + is_signer: self.is_signer.into(), + is_writable: self.is_writable.into(), + }) + } + _ => Err(ProgramError::InvalidAccountData), + } + } +} + +impl From<&AccountMeta> for ExtraAccountMeta { + fn from(meta: &AccountMeta) -> Self { + Self { + discriminator: 0, + address_config: meta.pubkey.to_bytes(), + is_signer: meta.is_signer.into(), + is_writable: meta.is_writable.into(), + } + } +} +impl From for ExtraAccountMeta { + fn from(meta: AccountMeta) -> Self { + ExtraAccountMeta::from(&meta) + } +} +impl From<&AccountInfo<'_>> for ExtraAccountMeta { + fn from(account_info: &AccountInfo) -> Self { + Self { + discriminator: 0, + address_config: account_info.key.to_bytes(), + is_signer: account_info.is_signer.into(), + is_writable: account_info.is_writable.into(), + } + } +} +impl From> for ExtraAccountMeta { + fn from(account_info: AccountInfo) -> Self { + ExtraAccountMeta::from(&account_info) + } +} + +impl TryFrom<&ExtraAccountMeta> for AccountMeta { + type Error = ProgramError; + + fn try_from(pod: &ExtraAccountMeta) -> Result { + if pod.discriminator == 0 { + Ok(AccountMeta { + pubkey: Pubkey::from(pod.address_config), + is_signer: pod.is_signer.into(), + is_writable: pod.is_writable.into(), + }) + } else { + Err(AccountResolutionError::AccountTypeNotAccountMeta.into()) + } + } +} diff --git a/tlv-account-resolution/src/error.rs b/tlv-account-resolution/src/error.rs new file mode 100644 index 00000000..919fe603 --- /dev/null +++ b/tlv-account-resolution/src/error.rs @@ -0,0 +1,165 @@ +//! Error types + +use { + solana_decode_error::DecodeError, + solana_msg::msg, + solana_program_error::{PrintProgramError, ProgramError}, +}; + +/// Errors that may be returned by the Account Resolution library. 
+#[repr(u32)] +#[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)] +pub enum AccountResolutionError { + /// Incorrect account provided + #[error("Incorrect account provided")] + IncorrectAccount = 2_724_315_840, + /// Not enough accounts provided + #[error("Not enough accounts provided")] + NotEnoughAccounts, + /// No value initialized in TLV data + #[error("No value initialized in TLV data")] + TlvUninitialized, + /// Some value initialized in TLV data + #[error("Some value initialized in TLV data")] + TlvInitialized, + /// Too many pubkeys provided + #[error("Too many pubkeys provided")] + TooManyPubkeys, + /// Failed to parse `Pubkey` from bytes + #[error("Failed to parse `Pubkey` from bytes")] + InvalidPubkey, + /// Attempted to deserialize an `AccountMeta` but the underlying type has + /// PDA configs rather than a fixed address + #[error( + "Attempted to deserialize an `AccountMeta` but the underlying type has PDA configs rather \ + than a fixed address" + )] + AccountTypeNotAccountMeta, + /// Provided list of seed configurations too large for a validation account + #[error("Provided list of seed configurations too large for a validation account")] + SeedConfigsTooLarge, + /// Not enough bytes available to pack seed configuration + #[error("Not enough bytes available to pack seed configuration")] + NotEnoughBytesForSeed, + /// The provided bytes are not valid for a seed configuration + #[error("The provided bytes are not valid for a seed configuration")] + InvalidBytesForSeed, + /// Tried to pack an invalid seed configuration + #[error("Tried to pack an invalid seed configuration")] + InvalidSeedConfig, + /// Instruction data too small for seed configuration + #[error("Instruction data too small for seed configuration")] + InstructionDataTooSmall, + /// Could not find account at specified index + #[error("Could not find account at specified index")] + AccountNotFound, + /// Error in checked math operation + #[error("Error in checked math operation")] + CalculationFailure, + /// Could not find account data at specified index + #[error("Could not find account data at specified index")] + AccountDataNotFound, + /// Account data too small for requested seed configuration + #[error("Account data too small for requested seed configuration")] + AccountDataTooSmall, + /// Failed to fetch account + #[error("Failed to fetch account")] + AccountFetchFailed, + /// Not enough bytes available to pack pubkey data configuration. 
+ #[error("Not enough bytes available to pack pubkey data configuration")] + NotEnoughBytesForPubkeyData, + /// The provided bytes are not valid for a pubkey data configuration + #[error("The provided bytes are not valid for a pubkey data configuration")] + InvalidBytesForPubkeyData, + /// Tried to pack an invalid pubkey data configuration + #[error("Tried to pack an invalid pubkey data configuration")] + InvalidPubkeyDataConfig, +} + +impl From for ProgramError { + fn from(e: AccountResolutionError) -> Self { + ProgramError::Custom(e as u32) + } +} + +impl DecodeError for AccountResolutionError { + fn type_of() -> &'static str { + "AccountResolutionError" + } +} + +impl PrintProgramError for AccountResolutionError { + fn print(&self) + where + E: 'static + + std::error::Error + + DecodeError + + PrintProgramError + + num_traits::FromPrimitive, + { + match self { + AccountResolutionError::IncorrectAccount => { + msg!("Incorrect account provided") + } + AccountResolutionError::NotEnoughAccounts => { + msg!("Not enough accounts provided") + } + AccountResolutionError::TlvUninitialized => { + msg!("No value initialized in TLV data") + } + AccountResolutionError::TlvInitialized => { + msg!("Some value initialized in TLV data") + } + AccountResolutionError::TooManyPubkeys => { + msg!("Too many pubkeys provided") + } + AccountResolutionError::InvalidPubkey => { + msg!("Failed to parse `Pubkey` from bytes") + } + AccountResolutionError::AccountTypeNotAccountMeta => { + msg!( + "Attempted to deserialize an `AccountMeta` but the underlying type has PDA configs rather than a fixed address", + ) + } + AccountResolutionError::SeedConfigsTooLarge => { + msg!("Provided list of seed configurations too large for a validation account",) + } + AccountResolutionError::NotEnoughBytesForSeed => { + msg!("Not enough bytes available to pack seed configuration",) + } + AccountResolutionError::InvalidBytesForSeed => { + msg!("The provided bytes are not valid for a seed configuration",) + } + AccountResolutionError::InvalidSeedConfig => { + msg!("Tried to pack an invalid seed configuration",) + } + AccountResolutionError::InstructionDataTooSmall => { + msg!("Instruction data too small for seed configuration",) + } + AccountResolutionError::AccountNotFound => { + msg!("Could not find account at specified index",) + } + AccountResolutionError::CalculationFailure => { + msg!("Error in checked math operation") + } + AccountResolutionError::AccountDataNotFound => { + msg!("Could not find account data at specified index",) + } + AccountResolutionError::AccountDataTooSmall => { + msg!("Account data too small for requested seed configuration",) + } + AccountResolutionError::AccountFetchFailed => { + msg!("Failed to fetch account") + } + AccountResolutionError::NotEnoughBytesForPubkeyData => { + msg!("Not enough bytes available to pack pubkey data configuration",) + } + AccountResolutionError::InvalidBytesForPubkeyData => { + msg!("The provided bytes are not valid for a pubkey data configuration",) + } + AccountResolutionError::InvalidPubkeyDataConfig => { + msg!("Tried to pack an invalid pubkey data configuration",) + } + } + } +} diff --git a/tlv-account-resolution/src/lib.rs b/tlv-account-resolution/src/lib.rs new file mode 100644 index 00000000..e015fc97 --- /dev/null +++ b/tlv-account-resolution/src/lib.rs @@ -0,0 +1,21 @@ +//! Crate defining a state interface for offchain account resolution. If a +//! program writes the proper state information into one of their accounts, any +//! 
offchain and onchain client can fetch any additional required accounts for +//! an instruction. + +#![allow(clippy::arithmetic_side_effects)] +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +pub mod account; +pub mod error; +pub mod pubkey_data; +pub mod seeds; +pub mod state; + +// Export current sdk types for downstream users building with a different sdk +// version +pub use { + solana_account_info, solana_decode_error, solana_instruction, solana_msg, solana_program_error, + solana_pubkey, +}; diff --git a/tlv-account-resolution/src/pubkey_data.rs b/tlv-account-resolution/src/pubkey_data.rs new file mode 100644 index 00000000..033d72f3 --- /dev/null +++ b/tlv-account-resolution/src/pubkey_data.rs @@ -0,0 +1,185 @@ +//! Types for managing extra account meta keys that may be extracted from some +//! data. +//! +//! This can be either account data from some account in the list of accounts +//! or from the instruction data itself. + +#[cfg(feature = "serde-traits")] +use serde::{Deserialize, Serialize}; +use {crate::error::AccountResolutionError, solana_program_error::ProgramError}; + +/// Enum to describe a required key stored in some data. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))] +pub enum PubkeyData { + /// Uninitialized configuration byte space. + Uninitialized, + /// A pubkey to be resolved from the instruction data. + /// + /// Packed as: + /// * 1 - Discriminator + /// * 1 - Start index of instruction data + /// + /// Note: Length is always 32 bytes. + InstructionData { + /// The index where the address bytes begin in the instruction data. + index: u8, + }, + /// A pubkey to be resolved from the inner data of some account. + /// + /// Packed as: + /// * 1 - Discriminator + /// * 1 - Index of account in accounts list + /// * 1 - Start index of account data + /// + /// Note: Length is always 32 bytes. + AccountData { + /// The index of the account in the entire accounts list. + account_index: u8, + /// The index where the address bytes begin in the account data. + data_index: u8, + }, +} +impl PubkeyData { + /// Get the size of a pubkey data configuration. + pub fn tlv_size(&self) -> u8 { + match self { + Self::Uninitialized => 0, + // 1 byte for the discriminator, 1 byte for the index. + Self::InstructionData { .. } => 1 + 1, + // 1 byte for the discriminator, 1 byte for the account index, + // 1 byte for the data index. + Self::AccountData { .. } => 1 + 1 + 1, + } + } + + /// Packs a pubkey data configuration into a slice. + pub fn pack(&self, dst: &mut [u8]) -> Result<(), ProgramError> { + // Because no `PubkeyData` variant is larger than 3 bytes, this check + // is sufficient for the data length. + if dst.len() != self.tlv_size() as usize { + return Err(AccountResolutionError::NotEnoughBytesForPubkeyData.into()); + } + match &self { + Self::Uninitialized => { + return Err(AccountResolutionError::InvalidPubkeyDataConfig.into()) + } + Self::InstructionData { index } => { + dst[0] = 1; + dst[1] = *index; + } + Self::AccountData { + account_index, + data_index, + } => { + dst[0] = 2; + dst[1] = *account_index; + dst[2] = *data_index; + } + } + Ok(()) + } + + /// Packs a pubkey data configuration into a 32-byte array, filling the + /// rest with 0s. 
+ pub fn pack_into_address_config(key_data: &Self) -> Result<[u8; 32], ProgramError> { + let mut packed = [0u8; 32]; + let tlv_size = key_data.tlv_size() as usize; + key_data.pack(&mut packed[..tlv_size])?; + Ok(packed) + } + + /// Unpacks a pubkey data configuration from a slice. + pub fn unpack(bytes: &[u8]) -> Result { + let (discrim, rest) = bytes + .split_first() + .ok_or::(ProgramError::InvalidAccountData)?; + match discrim { + 0 => Ok(Self::Uninitialized), + 1 => { + if rest.is_empty() { + return Err(AccountResolutionError::InvalidBytesForPubkeyData.into()); + } + Ok(Self::InstructionData { index: rest[0] }) + } + 2 => { + if rest.len() < 2 { + return Err(AccountResolutionError::InvalidBytesForPubkeyData.into()); + } + Ok(Self::AccountData { + account_index: rest[0], + data_index: rest[1], + }) + } + _ => Err(ProgramError::InvalidAccountData), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pack() { + // Should fail if the length is too short. + let key = PubkeyData::InstructionData { index: 0 }; + let mut packed = vec![0u8; key.tlv_size() as usize - 1]; + assert_eq!( + key.pack(&mut packed).unwrap_err(), + AccountResolutionError::NotEnoughBytesForPubkeyData.into(), + ); + + // Should fail if the length is too long. + let key = PubkeyData::InstructionData { index: 0 }; + let mut packed = vec![0u8; key.tlv_size() as usize + 1]; + assert_eq!( + key.pack(&mut packed).unwrap_err(), + AccountResolutionError::NotEnoughBytesForPubkeyData.into(), + ); + + // Can't pack a `PubkeyData::Uninitialized`. + let key = PubkeyData::Uninitialized; + let mut packed = vec![0u8; key.tlv_size() as usize]; + assert_eq!( + key.pack(&mut packed).unwrap_err(), + AccountResolutionError::InvalidPubkeyDataConfig.into(), + ); + } + + #[test] + fn test_unpack() { + // Can unpack zeroes. + let zeroes = [0u8; 32]; + let key = PubkeyData::unpack(&zeroes).unwrap(); + assert_eq!(key, PubkeyData::Uninitialized); + + // Should fail for empty bytes. + let bytes = []; + assert_eq!( + PubkeyData::unpack(&bytes).unwrap_err(), + ProgramError::InvalidAccountData + ); + } + + fn test_pack_unpack_key(key: PubkeyData) { + let tlv_size = key.tlv_size() as usize; + let mut packed = vec![0u8; tlv_size]; + key.pack(&mut packed).unwrap(); + let unpacked = PubkeyData::unpack(&packed).unwrap(); + assert_eq!(key, unpacked); + } + + #[test] + fn test_pack_unpack() { + // Instruction data. + test_pack_unpack_key(PubkeyData::InstructionData { index: 0 }); + + // Account data. + test_pack_unpack_key(PubkeyData::AccountData { + account_index: 0, + data_index: 0, + }); + } +} diff --git a/tlv-account-resolution/src/seeds.rs b/tlv-account-resolution/src/seeds.rs new file mode 100644 index 00000000..1f51128e --- /dev/null +++ b/tlv-account-resolution/src/seeds.rs @@ -0,0 +1,524 @@ +//! Types for managing seed configurations in TLV Account Resolution +//! +//! As determined by the `address_config` field of `ExtraAccountMeta`, +//! seed configurations are limited to a maximum of 32 bytes. +//! This means that the maximum number of seed configurations that can be +//! packed into a single `ExtraAccountMeta` will depend directly on the size +//! of the seed configurations themselves. +//! +//! Sizes are as follows: +//! * `Seed::Literal`: 1 + 1 + N +//! * 1 - Discriminator +//! * 1 - Length of literal +//! * N - Literal bytes themselves +//! * `Seed::InstructionData`: 1 + 1 + 1 = 3 +//! * 1 - Discriminator +//! * 1 - Start index of instruction data +//! * 1 - Length of instruction data starting at index +//! 
* `Seed::AccountKey` - 1 + 1 = 2 +//! * 1 - Discriminator +//! * 1 - Index of account in accounts list +//! * `Seed::AccountData`: 1 + 1 + 1 + 1 = 4 +//! * 1 - Discriminator +//! * 1 - Index of account in accounts list +//! * 1 - Start index of account data +//! * 1 - Length of account data starting at index +//! +//! No matter which types of seeds you choose, the total size of all seed +//! configurations must be less than or equal to 32 bytes. + +#[cfg(feature = "serde-traits")] +use serde::{Deserialize, Serialize}; +use {crate::error::AccountResolutionError, solana_program_error::ProgramError}; + +/// Enum to describe a required seed for a Program-Derived Address +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))] +pub enum Seed { + /// Uninitialized configuration byte space + Uninitialized, + /// A literal hard-coded argument + /// Packed as: + /// * 1 - Discriminator + /// * 1 - Length of literal + /// * N - Literal bytes themselves + Literal { + /// The literal value represented as a vector of bytes. + /// + /// For example, if a literal value is a string literal, + /// such as "my-seed", this value would be + /// `"my-seed".as_bytes().to_vec()`. + bytes: Vec, + }, + /// An instruction-provided argument, to be resolved from the instruction + /// data + /// Packed as: + /// * 1 - Discriminator + /// * 1 - Start index of instruction data + /// * 1 - Length of instruction data starting at index + InstructionData { + /// The index where the bytes of an instruction argument begin + index: u8, + /// The length of the instruction argument (number of bytes) + /// + /// Note: Max seed length is 32 bytes, so `u8` is appropriate here + length: u8, + }, + /// The public key of an account from the entire accounts list. + /// Note: This includes an extra accounts required. + /// + /// Packed as: + /// * 1 - Discriminator + /// * 1 - Index of account in accounts list + AccountKey { + /// The index of the account in the entire accounts list + index: u8, + }, + /// An argument to be resolved from the inner data of some account + /// Packed as: + /// * 1 - Discriminator + /// * 1 - Index of account in accounts list + /// * 1 - Start index of account data + /// * 1 - Length of account data starting at index + #[cfg_attr( + feature = "serde-traits", + serde(rename_all = "camelCase", alias = "account_data") + )] + AccountData { + /// The index of the account in the entire accounts list + account_index: u8, + /// The index where the bytes of an account data argument begin + data_index: u8, + /// The length of the argument (number of bytes) + /// + /// Note: Max seed length is 32 bytes, so `u8` is appropriate here + length: u8, + }, +} +impl Seed { + /// Get the size of a seed configuration + pub fn tlv_size(&self) -> u8 { + match &self { + // 1 byte for the discriminator + Self::Uninitialized => 0, + // 1 byte for the discriminator, 1 byte for the length of the bytes, then the raw bytes + Self::Literal { bytes } => 1 + 1 + bytes.len() as u8, + // 1 byte for the discriminator, 1 byte for the index, 1 byte for the length + Self::InstructionData { .. } => 1 + 1 + 1, + // 1 byte for the discriminator, 1 byte for the index + Self::AccountKey { .. } => 1 + 1, + // 1 byte for the discriminator, 1 byte for the account index, + // 1 byte for the data index 1 byte for the length + Self::AccountData { .. 
} => 1 + 1 + 1 + 1, + } + } + + /// Packs a seed configuration into a slice + pub fn pack(&self, dst: &mut [u8]) -> Result<(), ProgramError> { + if dst.len() != self.tlv_size() as usize { + return Err(AccountResolutionError::NotEnoughBytesForSeed.into()); + } + if dst.len() > 32 { + return Err(AccountResolutionError::SeedConfigsTooLarge.into()); + } + match &self { + Self::Uninitialized => return Err(AccountResolutionError::InvalidSeedConfig.into()), + Self::Literal { bytes } => { + dst[0] = 1; + dst[1] = bytes.len() as u8; + dst[2..].copy_from_slice(bytes); + } + Self::InstructionData { index, length } => { + dst[0] = 2; + dst[1] = *index; + dst[2] = *length; + } + Self::AccountKey { index } => { + dst[0] = 3; + dst[1] = *index; + } + Self::AccountData { + account_index, + data_index, + length, + } => { + dst[0] = 4; + dst[1] = *account_index; + dst[2] = *data_index; + dst[3] = *length; + } + } + Ok(()) + } + + /// Packs a vector of seed configurations into a 32-byte array, + /// filling the rest with 0s. Errors if it overflows. + pub fn pack_into_address_config(seeds: &[Self]) -> Result<[u8; 32], ProgramError> { + let mut packed = [0u8; 32]; + let mut i: usize = 0; + for seed in seeds { + let seed_size = seed.tlv_size() as usize; + let slice_end = i + seed_size; + if slice_end > 32 { + return Err(AccountResolutionError::SeedConfigsTooLarge.into()); + } + seed.pack(&mut packed[i..slice_end])?; + i = slice_end; + } + Ok(packed) + } + + /// Unpacks a seed configuration from a slice + pub fn unpack(bytes: &[u8]) -> Result { + let (discrim, rest) = bytes + .split_first() + .ok_or::(ProgramError::InvalidAccountData)?; + match discrim { + 0 => Ok(Self::Uninitialized), + 1 => unpack_seed_literal(rest), + 2 => unpack_seed_instruction_arg(rest), + 3 => unpack_seed_account_key(rest), + 4 => unpack_seed_account_data(rest), + _ => Err(ProgramError::InvalidAccountData), + } + } + + /// Unpacks all seed configurations from a 32-byte array. + /// Stops when it hits uninitialized data (0s). 
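To illustrate the 32-byte `address_config` described above, here is a small sketch, again assuming only the `seeds` module from this diff, that packs a seed list and unpacks it again:

```rust
use spl_tlv_account_resolution::seeds::Seed;

// Three seed configs: a literal, an instruction-data slice, and an account key.
let seeds = vec![
    Seed::Literal { bytes: b"prefix".to_vec() },   // 1 + 1 + 6 = 8 bytes
    Seed::InstructionData { index: 1, length: 8 }, // 3 bytes
    Seed::AccountKey { index: 0 },                 // 2 bytes
];

// 13 bytes are used; the remaining 19 stay zeroed.
let address_config: [u8; 32] = Seed::pack_into_address_config(&seeds).unwrap();
assert_eq!(&address_config[..2], &[1, 6]); // Literal discriminator + length

// Unpacking stops at the first uninitialized (zero) discriminator byte.
assert_eq!(Seed::unpack_address_config(&address_config).unwrap(), seeds);
```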
+ pub fn unpack_address_config(address_config: &[u8; 32]) -> Result, ProgramError> { + let mut seeds = vec![]; + let mut i = 0; + while i < 32 { + let seed = Self::unpack(&address_config[i..])?; + let seed_size = seed.tlv_size() as usize; + i += seed_size; + if seed == Self::Uninitialized { + break; + } + seeds.push(seed); + } + Ok(seeds) + } +} + +fn unpack_seed_literal(bytes: &[u8]) -> Result { + let (length, rest) = bytes + .split_first() + // Should be at least 1 byte + .ok_or::(AccountResolutionError::InvalidBytesForSeed.into())?; + let length = *length as usize; + if rest.len() < length { + // Should be at least `length` bytes + return Err(AccountResolutionError::InvalidBytesForSeed.into()); + } + Ok(Seed::Literal { + bytes: rest[..length].to_vec(), + }) +} + +fn unpack_seed_instruction_arg(bytes: &[u8]) -> Result { + if bytes.len() < 2 { + // Should be at least 2 bytes + return Err(AccountResolutionError::InvalidBytesForSeed.into()); + } + Ok(Seed::InstructionData { + index: bytes[0], + length: bytes[1], + }) +} + +fn unpack_seed_account_key(bytes: &[u8]) -> Result { + if bytes.is_empty() { + // Should be at least 1 byte + return Err(AccountResolutionError::InvalidBytesForSeed.into()); + } + Ok(Seed::AccountKey { index: bytes[0] }) +} + +fn unpack_seed_account_data(bytes: &[u8]) -> Result { + if bytes.len() < 3 { + // Should be at least 3 bytes + return Err(AccountResolutionError::InvalidBytesForSeed.into()); + } + Ok(Seed::AccountData { + account_index: bytes[0], + data_index: bytes[1], + length: bytes[2], + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pack() { + // Seed too large + let seed = Seed::Literal { bytes: vec![1; 33] }; + let mut packed = vec![0u8; seed.tlv_size() as usize]; + assert_eq!( + seed.pack(&mut packed).unwrap_err(), + AccountResolutionError::SeedConfigsTooLarge.into() + ); + assert_eq!( + Seed::pack_into_address_config(&[seed]).unwrap_err(), + AccountResolutionError::SeedConfigsTooLarge.into() + ); + + // Should fail if the length is wrong + let seed = Seed::Literal { bytes: vec![1; 12] }; + let mut packed = vec![0u8; seed.tlv_size() as usize - 1]; + assert_eq!( + seed.pack(&mut packed).unwrap_err(), + AccountResolutionError::NotEnoughBytesForSeed.into() + ); + + // Can't pack a `Seed::Uninitialized` + let seed = Seed::Uninitialized; + let mut packed = vec![0u8; seed.tlv_size() as usize]; + assert_eq!( + seed.pack(&mut packed).unwrap_err(), + AccountResolutionError::InvalidSeedConfig.into() + ); + } + + #[test] + fn test_pack_address_config() { + // Should fail if one seed is too large + let seed = Seed::Literal { bytes: vec![1; 36] }; + assert_eq!( + Seed::pack_into_address_config(&[seed]).unwrap_err(), + AccountResolutionError::SeedConfigsTooLarge.into() + ); + + // Should fail if the combination of all seeds is too large + let seed1 = Seed::Literal { bytes: vec![1; 30] }; // 30 bytes + let seed2 = Seed::InstructionData { + index: 0, + length: 4, + }; // 3 bytes + assert_eq!( + Seed::pack_into_address_config(&[seed1, seed2]).unwrap_err(), + AccountResolutionError::SeedConfigsTooLarge.into() + ); + } + + #[test] + fn test_unpack() { + // Can unpack zeroes + let zeroes = [0u8; 32]; + let seeds = Seed::unpack_address_config(&zeroes).unwrap(); + assert_eq!(seeds, vec![]); + + // Should fail for empty bytes + let bytes = []; + assert_eq!( + Seed::unpack(&bytes).unwrap_err(), + ProgramError::InvalidAccountData + ); + + // Should fail if bytes are malformed for literal seed + let bytes = [ + 1, // Discrim (Literal) + 4, // Length + 1, 1, 
1, // Incorrect length + ]; + assert_eq!( + Seed::unpack(&bytes).unwrap_err(), + AccountResolutionError::InvalidBytesForSeed.into() + ); + + // Should fail if bytes are malformed for literal seed + let bytes = [ + 2, // Discrim (InstructionData) + 2, // Index (Length missing) + ]; + assert_eq!( + Seed::unpack(&bytes).unwrap_err(), + AccountResolutionError::InvalidBytesForSeed.into() + ); + + // Should fail if bytes are malformed for literal seed + let bytes = [ + 3, // Discrim (AccountKey, Index missing) + ]; + assert_eq!( + Seed::unpack(&bytes).unwrap_err(), + AccountResolutionError::InvalidBytesForSeed.into() + ); + } + + #[test] + fn test_unpack_address_config() { + // Should fail if bytes are malformed + let bytes = [ + 1, // Discrim (Literal) + 4, // Length + 1, 1, 1, 1, // 4 + 6, // Discrim (Invalid) + 2, // Index + 1, // Length + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]; + assert_eq!( + Seed::unpack_address_config(&bytes).unwrap_err(), + ProgramError::InvalidAccountData + ); + + // Should fail if 32nd byte is not zero, but it would be the + // start of a config + // + // Namely, if a seed config is unpacked and leaves 1 byte remaining, + // it has to be 0, since no valid seed config can be 1 byte long + let bytes = [ + 1, // Discrim (Literal) + 16, // Length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 16 + 1, // Discrim (Literal) + 11, // Length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 11 + 2, // Non-zero byte + ]; + assert_eq!( + Seed::unpack_address_config(&bytes).unwrap_err(), + AccountResolutionError::InvalidBytesForSeed.into(), + ); + + // Should pass if 31st byte is not zero, but it would be + // the start of a config + // + // Similar to above, however we now have 2 bytes to work with, + // which could be a valid seed config + let bytes = [ + 1, // Discrim (Literal) + 16, // Length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 16 + 1, // Discrim (Literal) + 10, // Length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 10 + 3, // Non-zero byte - Discrim (AccountKey) + 0, // Index + ]; + assert_eq!( + Seed::unpack_address_config(&bytes).unwrap(), + vec![ + Seed::Literal { + bytes: vec![1u8; 16] + }, + Seed::Literal { + bytes: vec![1u8; 10] + }, + Seed::AccountKey { index: 0 } + ], + ); + + // Should fail if 31st byte is not zero and a valid seed config + // discriminator, but the seed config requires more than 2 bytes + let bytes = [ + 1, // Discrim (Literal) + 16, // Length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 16 + 1, // Discrim (Literal) + 10, // Length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 10 + 2, // Non-zero byte - Discrim (InstructionData) + 0, // Index (Length missing) + ]; + assert_eq!( + Seed::unpack_address_config(&bytes).unwrap_err(), + AccountResolutionError::InvalidBytesForSeed.into(), + ); + } + + fn test_pack_unpack_seed(seed: Seed) { + let tlv_size = seed.tlv_size() as usize; + let mut packed = vec![0u8; tlv_size]; + seed.pack(&mut packed).unwrap(); + let unpacked = Seed::unpack(&packed).unwrap(); + assert_eq!(seed, unpacked); + } + + #[test] + fn test_pack_unpack() { + let mut mixed = vec![]; + + // Literals + + let bytes = b"hello"; + let seed = Seed::Literal { + bytes: bytes.to_vec(), + }; + test_pack_unpack_seed(seed); + + let bytes = 8u8.to_le_bytes(); + let seed = Seed::Literal { + bytes: bytes.to_vec(), + }; + test_pack_unpack_seed(seed.clone()); + mixed.push(seed); + + let bytes = 32u32.to_le_bytes(); + let seed = Seed::Literal { + bytes: bytes.to_vec(), + }; + test_pack_unpack_seed(seed.clone()); + 
mixed.push(seed); + + // Instruction args + + let seed = Seed::InstructionData { + index: 0, + length: 0, + }; + test_pack_unpack_seed(seed); + + let seed = Seed::InstructionData { + index: 6, + length: 9, + }; + test_pack_unpack_seed(seed.clone()); + mixed.push(seed); + + // Account keys + + let seed = Seed::AccountKey { index: 0 }; + test_pack_unpack_seed(seed); + + let seed = Seed::AccountKey { index: 9 }; + test_pack_unpack_seed(seed.clone()); + mixed.push(seed); + + // Account data + + let seed = Seed::AccountData { + account_index: 0, + data_index: 0, + length: 0, + }; + test_pack_unpack_seed(seed); + + let seed = Seed::AccountData { + account_index: 0, + data_index: 0, + length: 9, + }; + test_pack_unpack_seed(seed.clone()); + mixed.push(seed); + + // Arrays + + let packed_array = Seed::pack_into_address_config(&mixed).unwrap(); + let unpacked_array = Seed::unpack_address_config(&packed_array).unwrap(); + assert_eq!(mixed, unpacked_array); + + let mut shuffled_mixed = mixed.clone(); + shuffled_mixed.swap(0, 1); + shuffled_mixed.swap(1, 4); + shuffled_mixed.swap(3, 0); + + let packed_array = Seed::pack_into_address_config(&shuffled_mixed).unwrap(); + let unpacked_array = Seed::unpack_address_config(&packed_array).unwrap(); + assert_eq!(shuffled_mixed, unpacked_array); + } +} diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs new file mode 100644 index 00000000..f17ee3d0 --- /dev/null +++ b/tlv-account-resolution/src/state.rs @@ -0,0 +1,1726 @@ +//! State transition types + +use { + crate::{account::ExtraAccountMeta, error::AccountResolutionError}, + solana_account_info::AccountInfo, + solana_instruction::{AccountMeta, Instruction}, + solana_program_error::ProgramError, + solana_pubkey::Pubkey, + spl_discriminator::SplDiscriminate, + spl_pod::slice::{PodSlice, PodSliceMut}, + spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, + std::future::Future, +}; + +/// Type representing the output of an account fetching function, for easy +/// chaining between APIs +pub type AccountDataResult = Result>, AccountFetchError>; +/// Generic error type that can come out of any client while fetching account +/// data +pub type AccountFetchError = Box; + +/// Helper to convert an `AccountInfo` to an `AccountMeta` +fn account_info_to_meta(account_info: &AccountInfo) -> AccountMeta { + AccountMeta { + pubkey: *account_info.key, + is_signer: account_info.is_signer, + is_writable: account_info.is_writable, + } +} + +/// De-escalate an account meta if necessary +fn de_escalate_account_meta(account_meta: &mut AccountMeta, account_metas: &[AccountMeta]) { + // This is a little tricky to read, but the idea is to see if + // this account is marked as writable or signer anywhere in + // the instruction at the start. 
If so, DON'T escalate it to + // be a writer or signer in the CPI + let maybe_highest_privileges = account_metas + .iter() + .filter(|&x| x.pubkey == account_meta.pubkey) + .map(|x| (x.is_signer, x.is_writable)) + .reduce(|acc, x| (acc.0 || x.0, acc.1 || x.1)); + // If `Some`, then the account was found somewhere in the instruction + if let Some((is_signer, is_writable)) = maybe_highest_privileges { + if !is_signer && is_signer != account_meta.is_signer { + // Existing account is *NOT* a signer already, but the CPI + // wants it to be, so de-escalate to not be a signer + account_meta.is_signer = false; + } + if !is_writable && is_writable != account_meta.is_writable { + // Existing account is *NOT* writable already, but the CPI + // wants it to be, so de-escalate to not be writable + account_meta.is_writable = false; + } + } +} + +/// Stateless helper for storing additional accounts required for an +/// instruction. +/// +/// This struct works with any `SplDiscriminate`, and stores the extra accounts +/// needed for that specific instruction, using the given `ArrayDiscriminator` +/// as the type-length-value `ArrayDiscriminator`, and then storing all of the +/// given `AccountMeta`s as a zero-copy slice. +/// +/// Sample usage: +/// +/// ```rust +/// use { +/// futures_util::TryFutureExt, +/// solana_client::nonblocking::rpc_client::RpcClient, +/// solana_account_info::AccountInfo, +/// solana_instruction::{AccountMeta, Instruction}, +/// solana_pubkey::Pubkey, +/// spl_discriminator::{ArrayDiscriminator, SplDiscriminate}, +/// spl_tlv_account_resolution::{ +/// account::ExtraAccountMeta, +/// seeds::Seed, +/// state::{AccountDataResult, AccountFetchError, ExtraAccountMetaList} +/// }, +/// }; +/// +/// struct MyInstruction; +/// impl SplDiscriminate for MyInstruction { +/// // Give it a unique discriminator, can also be generated using a hash function +/// const SPL_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([1; ArrayDiscriminator::LENGTH]); +/// } +/// +/// // actually put it in the additional required account keys and signer / writable +/// let extra_metas = [ +/// AccountMeta::new(Pubkey::new_unique(), false).into(), +/// AccountMeta::new_readonly(Pubkey::new_unique(), false).into(), +/// ExtraAccountMeta::new_with_seeds( +/// &[ +/// Seed::Literal { +/// bytes: b"some_string".to_vec(), +/// }, +/// Seed::InstructionData { +/// index: 1, +/// length: 1, // u8 +/// }, +/// Seed::AccountKey { index: 1 }, +/// ], +/// false, +/// true, +/// ).unwrap(), +/// ExtraAccountMeta::new_external_pda_with_seeds( +/// 0, +/// &[Seed::AccountKey { index: 2 }], +/// false, +/// false, +/// ).unwrap(), +/// ]; +/// +/// // assume that this buffer is actually account data, already allocated to `account_size` +/// let account_size = ExtraAccountMetaList::size_of(extra_metas.len()).unwrap(); +/// let mut buffer = vec![0; account_size]; +/// +/// // Initialize the structure for your instruction +/// ExtraAccountMetaList::init::(&mut buffer, &extra_metas).unwrap(); +/// +/// // Off-chain, you can add the additional accounts directly from the account data +/// // You need to provide the resolver a way to fetch account data off-chain +/// struct MyClient { +/// client: RpcClient, +/// } +/// impl MyClient { +/// pub fn new() -> Self { +/// Self { +/// client: RpcClient::new_mock("succeeds".to_string()), +/// } +/// } +/// pub async fn get_account_data(&self, address: Pubkey) -> AccountDataResult { +/// self.client.get_account(&address) +/// .await +/// .map(|acct| Some(acct.data)) +/// 
.map_err(|e| Box::new(e) as AccountFetchError) +/// } +/// } +/// +/// let client = MyClient::new(); +/// let program_id = Pubkey::new_unique(); +/// let mut instruction = Instruction::new_with_bytes(program_id, &[0, 1, 2], vec![]); +/// # futures::executor::block_on(async { +/// // Now use the resolver to add the additional accounts off-chain +/// ExtraAccountMetaList::add_to_instruction::( +/// &mut instruction, +/// |address: Pubkey| client.get_account_data(address), +/// &buffer, +/// ) +/// .await; +/// # }); +/// +/// // On-chain, you can add the additional accounts *and* account infos +/// let mut cpi_instruction = Instruction::new_with_bytes(program_id, &[0, 1, 2], vec![]); +/// let mut cpi_account_infos = vec![]; // assume the other required account infos are already included +/// let remaining_account_infos: &[AccountInfo<'_>] = &[]; // these are the account infos provided to the instruction that are *not* part of any other known interface +/// ExtraAccountMetaList::add_to_cpi_instruction::( +/// &mut cpi_instruction, +/// &mut cpi_account_infos, +/// &buffer, +/// &remaining_account_infos, +/// ); +/// ``` +pub struct ExtraAccountMetaList; +impl ExtraAccountMetaList { + /// Initialize pod slice data for the given instruction and its required + /// list of `ExtraAccountMeta`s + pub fn init( + data: &mut [u8], + extra_account_metas: &[ExtraAccountMeta], + ) -> Result<(), ProgramError> { + let mut state = TlvStateMut::unpack(data).unwrap(); + let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; + let (bytes, _) = state.alloc::(tlv_size, false)?; + let mut validation_data = PodSliceMut::init(bytes)?; + for meta in extra_account_metas { + validation_data.push(*meta)?; + } + Ok(()) + } + + /// Update pod slice data for the given instruction and its required + /// list of `ExtraAccountMeta`s + pub fn update( + data: &mut [u8], + extra_account_metas: &[ExtraAccountMeta], + ) -> Result<(), ProgramError> { + let mut state = TlvStateMut::unpack(data).unwrap(); + let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; + let bytes = state.realloc_first::(tlv_size)?; + let mut validation_data = PodSliceMut::init(bytes)?; + for meta in extra_account_metas { + validation_data.push(*meta)?; + } + Ok(()) + } + + /// Get the underlying `PodSlice` from an unpacked TLV + /// + /// Due to lifetime annoyances, this function can't just take in the bytes, + /// since then we would be returning a reference to a locally created + /// `TlvStateBorrowed`. I hope there's a better way to do this! + pub fn unpack_with_tlv_state<'a, T: SplDiscriminate>( + tlv_state: &'a TlvStateBorrowed, + ) -> Result, ProgramError> { + let bytes = tlv_state.get_first_bytes::()?; + PodSlice::::unpack(bytes) + } + + /// Get the byte size required to hold `num_items` items + pub fn size_of(num_items: usize) -> Result { + Ok(TlvStateBorrowed::get_base_len() + .saturating_add(PodSlice::::size_of(num_items)?)) + } + + /// Checks provided account infos against validation data, using + /// instruction data and program ID to resolve any dynamic PDAs + /// if necessary. 
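As a rough sketch of how an onchain program might call the validation helper defined just below: the fragment is illustrative only, since the generic `T`, the `verify_extra_accounts` helper name, and the assumption that `validation_account` holds data written by `ExtraAccountMetaList::init` are not part of this diff.

```rust
use {
    solana_account_info::AccountInfo,
    solana_program_error::ProgramError,
    solana_pubkey::Pubkey,
    spl_discriminator::SplDiscriminate,
    spl_tlv_account_resolution::state::ExtraAccountMetaList,
};

// Hypothetical processor helper: checks that the extra accounts passed to the
// instruction match the resolved `ExtraAccountMeta`s, in order.
fn verify_extra_accounts<T: SplDiscriminate>(
    program_id: &Pubkey,
    instruction_data: &[u8],
    all_account_infos: &[AccountInfo<'_>],
    validation_account: &AccountInfo<'_>,
) -> Result<(), ProgramError> {
    let data = validation_account.try_borrow_data()?;
    // Returns `AccountResolutionError::IncorrectAccount` (as a `ProgramError`)
    // if any resolved meta is missing or out of position.
    ExtraAccountMetaList::check_account_infos::<T>(
        all_account_infos,
        instruction_data,
        program_id,
        &data,
    )
}
```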
+ /// + /// Note: this function will also verify all extra required accounts + /// have been provided in the correct order + pub fn check_account_infos( + account_infos: &[AccountInfo], + instruction_data: &[u8], + program_id: &Pubkey, + data: &[u8], + ) -> Result<(), ProgramError> { + let state = TlvStateBorrowed::unpack(data).unwrap(); + let extra_meta_list = ExtraAccountMetaList::unpack_with_tlv_state::(&state)?; + let extra_account_metas = extra_meta_list.data(); + + let initial_accounts_len = account_infos.len() - extra_account_metas.len(); + + // Convert to `AccountMeta` to check resolved metas + let provided_metas = account_infos + .iter() + .map(account_info_to_meta) + .collect::>(); + + for (i, config) in extra_account_metas.iter().enumerate() { + let meta = { + // Create a list of `Ref`s so we can reference account data in the + // resolution step + let account_key_data_refs = account_infos + .iter() + .map(|info| { + let key = *info.key; + let data = info.try_borrow_data()?; + Ok((key, data)) + }) + .collect::, ProgramError>>()?; + + config.resolve(instruction_data, program_id, |usize| { + account_key_data_refs + .get(usize) + .map(|(pubkey, opt_data)| (pubkey, Some(opt_data.as_ref()))) + })? + }; + + // Ensure the account is in the correct position + let expected_index = i + .checked_add(initial_accounts_len) + .ok_or::(AccountResolutionError::CalculationFailure.into())?; + if provided_metas.get(expected_index) != Some(&meta) { + return Err(AccountResolutionError::IncorrectAccount.into()); + } + } + + Ok(()) + } + + /// Add the additional account metas to an existing instruction + pub async fn add_to_instruction( + instruction: &mut Instruction, + fetch_account_data_fn: F, + data: &[u8], + ) -> Result<(), ProgramError> + where + F: Fn(Pubkey) -> Fut, + Fut: Future, + { + let state = TlvStateBorrowed::unpack(data)?; + let bytes = state.get_first_bytes::()?; + let extra_account_metas = PodSlice::::unpack(bytes)?; + + // Fetch account data for each of the instruction accounts + let mut account_key_datas = vec![]; + for meta in instruction.accounts.iter() { + let account_data = fetch_account_data_fn(meta.pubkey) + .await + .map_err::(|_| { + AccountResolutionError::AccountFetchFailed.into() + })?; + account_key_datas.push((meta.pubkey, account_data)); + } + + for extra_meta in extra_account_metas.data().iter() { + let mut meta = + extra_meta.resolve(&instruction.data, &instruction.program_id, |usize| { + account_key_datas + .get(usize) + .map(|(pubkey, opt_data)| (pubkey, opt_data.as_ref().map(|x| x.as_slice()))) + })?; + de_escalate_account_meta(&mut meta, &instruction.accounts); + + // Fetch account data for the new account + account_key_datas.push(( + meta.pubkey, + fetch_account_data_fn(meta.pubkey) + .await + .map_err::(|_| { + AccountResolutionError::AccountFetchFailed.into() + })?, + )); + instruction.accounts.push(meta); + } + Ok(()) + } + + /// Add the additional account metas and account infos for a CPI + pub fn add_to_cpi_instruction<'a, T: SplDiscriminate>( + cpi_instruction: &mut Instruction, + cpi_account_infos: &mut Vec>, + data: &[u8], + account_infos: &[AccountInfo<'a>], + ) -> Result<(), ProgramError> { + let state = TlvStateBorrowed::unpack(data)?; + let bytes = state.get_first_bytes::()?; + let extra_account_metas = PodSlice::::unpack(bytes)?; + + for extra_meta in extra_account_metas.data().iter() { + let mut meta = { + // Create a list of `Ref`s so we can reference account data in the + // resolution step + let account_key_data_refs = cpi_account_infos + 
.iter() + .map(|info| { + let key = *info.key; + let data = info.try_borrow_data()?; + Ok((key, data)) + }) + .collect::, ProgramError>>()?; + + extra_meta.resolve( + &cpi_instruction.data, + &cpi_instruction.program_id, + |usize| { + account_key_data_refs + .get(usize) + .map(|(pubkey, opt_data)| (pubkey, Some(opt_data.as_ref()))) + }, + )? + }; + de_escalate_account_meta(&mut meta, &cpi_instruction.accounts); + + let account_info = account_infos + .iter() + .find(|&x| *x.key == meta.pubkey) + .ok_or(AccountResolutionError::IncorrectAccount)? + .clone(); + + cpi_instruction.accounts.push(meta); + cpi_account_infos.push(account_info); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{pubkey_data::PubkeyData, seeds::Seed}, + solana_instruction::AccountMeta, + solana_program_test::tokio, + solana_pubkey::Pubkey, + spl_discriminator::{ArrayDiscriminator, SplDiscriminate}, + std::collections::HashMap, + }; + + pub struct TestInstruction; + impl SplDiscriminate for TestInstruction { + const SPL_DISCRIMINATOR: ArrayDiscriminator = + ArrayDiscriminator::new([1; ArrayDiscriminator::LENGTH]); + } + + pub struct TestOtherInstruction; + impl SplDiscriminate for TestOtherInstruction { + const SPL_DISCRIMINATOR: ArrayDiscriminator = + ArrayDiscriminator::new([2; ArrayDiscriminator::LENGTH]); + } + + pub struct MockRpc<'a> { + cache: HashMap>, + } + impl<'a> MockRpc<'a> { + pub fn setup(account_infos: &'a [AccountInfo<'a>]) -> Self { + let mut cache = HashMap::new(); + for info in account_infos { + cache.insert(*info.key, info); + } + Self { cache } + } + + pub async fn get_account_data(&self, pubkey: Pubkey) -> AccountDataResult { + Ok(self + .cache + .get(&pubkey) + .map(|account| account.try_borrow_data().unwrap().to_vec())) + } + } + + #[tokio::test] + async fn init_with_metas() { + let metas = [ + AccountMeta::new(Pubkey::new_unique(), false).into(), + AccountMeta::new(Pubkey::new_unique(), true).into(), + AccountMeta::new_readonly(Pubkey::new_unique(), true).into(), + AccountMeta::new_readonly(Pubkey::new_unique(), false).into(), + ]; + let account_size = ExtraAccountMetaList::size_of(metas.len()).unwrap(); + let mut buffer = vec![0; account_size]; + + ExtraAccountMetaList::init::(&mut buffer, &metas).unwrap(); + + let mock_rpc = MockRpc::setup(&[]); + + let mut instruction = Instruction::new_with_bytes(Pubkey::new_unique(), &[], vec![]); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + let check_metas = metas + .iter() + .map(|e| AccountMeta::try_from(e).unwrap()) + .collect::>(); + + assert_eq!(instruction.accounts, check_metas,); + } + + #[tokio::test] + async fn init_with_infos() { + let program_id = Pubkey::new_unique(); + + let pubkey1 = Pubkey::new_unique(); + let mut lamports1 = 0; + let mut data1 = []; + let pubkey2 = Pubkey::new_unique(); + let mut lamports2 = 0; + let mut data2 = [4, 4, 4, 6, 6, 6, 8, 8]; + let pubkey3 = Pubkey::new_unique(); + let mut lamports3 = 0; + let mut data3 = []; + let owner = Pubkey::new_unique(); + let account_infos = [ + AccountInfo::new( + &pubkey1, + false, + true, + &mut lamports1, + &mut data1, + &owner, + false, + 0, + ), + AccountInfo::new( + &pubkey2, + true, + false, + &mut lamports2, + &mut data2, + &owner, + false, + 0, + ), + AccountInfo::new( + &pubkey3, + false, + false, + &mut lamports3, + &mut data3, + &owner, + false, + 0, + ), + ]; + + let required_pda = ExtraAccountMeta::new_with_seeds( + &[ + Seed::AccountKey 
{ index: 0 }, + Seed::AccountData { + account_index: 1, + data_index: 2, + length: 4, + }, + ], + false, + true, + ) + .unwrap(); + + // Convert to `ExtraAccountMeta` + let required_extra_accounts = [ + ExtraAccountMeta::from(&account_infos[0]), + ExtraAccountMeta::from(&account_infos[1]), + ExtraAccountMeta::from(&account_infos[2]), + required_pda, + ]; + + let account_size = ExtraAccountMetaList::size_of(required_extra_accounts.len()).unwrap(); + let mut buffer = vec![0; account_size]; + + ExtraAccountMetaList::init::(&mut buffer, &required_extra_accounts) + .unwrap(); + + let mock_rpc = MockRpc::setup(&account_infos); + + let mut instruction = Instruction::new_with_bytes(program_id, &[], vec![]); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + let (check_required_pda, _) = Pubkey::find_program_address( + &[ + account_infos[0].key.as_ref(), // Account key + &account_infos[1].try_borrow_data().unwrap()[2..6], // Account data + ], + &program_id, + ); + + // Convert to `AccountMeta` to check instruction + let check_metas = [ + account_info_to_meta(&account_infos[0]), + account_info_to_meta(&account_infos[1]), + account_info_to_meta(&account_infos[2]), + AccountMeta::new(check_required_pda, false), + ]; + + assert_eq!(instruction.accounts, check_metas,); + + assert_eq!( + instruction.accounts.get(3).unwrap().pubkey, + check_required_pda + ); + } + + #[tokio::test] + async fn init_with_extra_account_metas() { + let program_id = Pubkey::new_unique(); + + let extra_meta3_literal_str = "seed_prefix"; + + let ix_account1 = AccountMeta::new(Pubkey::new_unique(), false); + let ix_account2 = AccountMeta::new(Pubkey::new_unique(), true); + + let extra_meta1 = AccountMeta::new(Pubkey::new_unique(), false); + let extra_meta2 = AccountMeta::new(Pubkey::new_unique(), true); + let extra_meta3 = ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: extra_meta3_literal_str.as_bytes().to_vec(), + }, + Seed::InstructionData { + index: 1, + length: 1, // u8 + }, + Seed::AccountKey { index: 0 }, + Seed::AccountKey { index: 2 }, + ], + false, + true, + ) + .unwrap(); + let extra_meta4 = ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::InstructionData { index: 4 }, + false, + true, + ) + .unwrap(); + + let metas = [ + ExtraAccountMeta::from(&extra_meta1), + ExtraAccountMeta::from(&extra_meta2), + extra_meta3, + extra_meta4, + ]; + + let mut ix_data = vec![1, 2, 3, 4]; + let check_extra_meta4_pubkey = Pubkey::new_unique(); + ix_data.extend_from_slice(check_extra_meta4_pubkey.as_ref()); + + let ix_accounts = vec![ix_account1.clone(), ix_account2.clone()]; + let mut instruction = Instruction::new_with_bytes(program_id, &ix_data, ix_accounts); + + let account_size = ExtraAccountMetaList::size_of(metas.len()).unwrap(); + let mut buffer = vec![0; account_size]; + + ExtraAccountMetaList::init::(&mut buffer, &metas).unwrap(); + + let mock_rpc = MockRpc::setup(&[]); + + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + let check_extra_meta3_u8_arg = ix_data[1]; + let check_extra_meta3_pubkey = Pubkey::find_program_address( + &[ + extra_meta3_literal_str.as_bytes(), + &[check_extra_meta3_u8_arg], + ix_account1.pubkey.as_ref(), + extra_meta1.pubkey.as_ref(), + ], + &program_id, + ) + .0; + let check_metas = [ + ix_account1, + ix_account2, + extra_meta1, + extra_meta2, + 
AccountMeta::new(check_extra_meta3_pubkey, false), + AccountMeta::new(check_extra_meta4_pubkey, false), + ]; + + assert_eq!( + instruction.accounts.get(4).unwrap().pubkey, + check_extra_meta3_pubkey, + ); + assert_eq!( + instruction.accounts.get(5).unwrap().pubkey, + check_extra_meta4_pubkey, + ); + assert_eq!(instruction.accounts, check_metas,); + } + + #[tokio::test] + async fn init_multiple() { + let extra_meta5_literal_str = "seed_prefix"; + let extra_meta5_literal_u32 = 4u32; + let other_meta2_literal_str = "other_seed_prefix"; + + let extra_meta1 = AccountMeta::new(Pubkey::new_unique(), false); + let extra_meta2 = AccountMeta::new(Pubkey::new_unique(), true); + let extra_meta3 = AccountMeta::new_readonly(Pubkey::new_unique(), true); + let extra_meta4 = AccountMeta::new_readonly(Pubkey::new_unique(), false); + let extra_meta5 = ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: extra_meta5_literal_str.as_bytes().to_vec(), + }, + Seed::Literal { + bytes: extra_meta5_literal_u32.to_le_bytes().to_vec(), + }, + Seed::InstructionData { + index: 5, + length: 1, // u8 + }, + Seed::AccountKey { index: 2 }, + ], + false, + true, + ) + .unwrap(); + let extra_meta6 = ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::InstructionData { index: 8 }, + false, + true, + ) + .unwrap(); + + let other_meta1 = AccountMeta::new(Pubkey::new_unique(), false); + let other_meta2 = ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: other_meta2_literal_str.as_bytes().to_vec(), + }, + Seed::InstructionData { + index: 1, + length: 4, // u32 + }, + Seed::AccountKey { index: 0 }, + ], + false, + true, + ) + .unwrap(); + let other_meta3 = ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::InstructionData { index: 7 }, + false, + true, + ) + .unwrap(); + + let metas = [ + ExtraAccountMeta::from(&extra_meta1), + ExtraAccountMeta::from(&extra_meta2), + ExtraAccountMeta::from(&extra_meta3), + ExtraAccountMeta::from(&extra_meta4), + extra_meta5, + extra_meta6, + ]; + let other_metas = [ + ExtraAccountMeta::from(&other_meta1), + other_meta2, + other_meta3, + ]; + + let account_size = ExtraAccountMetaList::size_of(metas.len()).unwrap() + + ExtraAccountMetaList::size_of(other_metas.len()).unwrap(); + let mut buffer = vec![0; account_size]; + + ExtraAccountMetaList::init::(&mut buffer, &metas).unwrap(); + ExtraAccountMetaList::init::(&mut buffer, &other_metas).unwrap(); + + let mock_rpc = MockRpc::setup(&[]); + + let program_id = Pubkey::new_unique(); + + let mut ix_data = vec![0, 0, 0, 0, 0, 7, 0, 0]; + let check_extra_meta6_pubkey = Pubkey::new_unique(); + ix_data.extend_from_slice(check_extra_meta6_pubkey.as_ref()); + + let ix_accounts = vec![]; + + let mut instruction = Instruction::new_with_bytes(program_id, &ix_data, ix_accounts); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + let check_extra_meta5_u8_arg = ix_data[5]; + let check_extra_meta5_pubkey = Pubkey::find_program_address( + &[ + extra_meta5_literal_str.as_bytes(), + extra_meta5_literal_u32.to_le_bytes().as_ref(), + &[check_extra_meta5_u8_arg], + extra_meta3.pubkey.as_ref(), + ], + &program_id, + ) + .0; + let check_metas = [ + extra_meta1, + extra_meta2, + extra_meta3, + extra_meta4, + AccountMeta::new(check_extra_meta5_pubkey, false), + AccountMeta::new(check_extra_meta6_pubkey, false), + ]; + + assert_eq!( + instruction.accounts.get(4).unwrap().pubkey, + check_extra_meta5_pubkey, + ); + assert_eq!( + 
instruction.accounts.get(5).unwrap().pubkey, + check_extra_meta6_pubkey, + ); + assert_eq!(instruction.accounts, check_metas,); + + let program_id = Pubkey::new_unique(); + + let ix_account1 = AccountMeta::new(Pubkey::new_unique(), false); + let ix_account2 = AccountMeta::new(Pubkey::new_unique(), true); + let ix_accounts = vec![ix_account1.clone(), ix_account2.clone()]; + + let mut ix_data = vec![0, 26, 0, 0, 0, 0, 0]; + let check_other_meta3_pubkey = Pubkey::new_unique(); + ix_data.extend_from_slice(check_other_meta3_pubkey.as_ref()); + + let mut instruction = Instruction::new_with_bytes(program_id, &ix_data, ix_accounts); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + let check_other_meta2_u32_arg = u32::from_le_bytes(ix_data[1..5].try_into().unwrap()); + let check_other_meta2_pubkey = Pubkey::find_program_address( + &[ + other_meta2_literal_str.as_bytes(), + check_other_meta2_u32_arg.to_le_bytes().as_ref(), + ix_account1.pubkey.as_ref(), + ], + &program_id, + ) + .0; + let check_other_metas = [ + ix_account1, + ix_account2, + other_meta1, + AccountMeta::new(check_other_meta2_pubkey, false), + AccountMeta::new(check_other_meta3_pubkey, false), + ]; + + assert_eq!( + instruction.accounts.get(3).unwrap().pubkey, + check_other_meta2_pubkey, + ); + assert_eq!( + instruction.accounts.get(4).unwrap().pubkey, + check_other_meta3_pubkey, + ); + assert_eq!(instruction.accounts, check_other_metas,); + } + + #[tokio::test] + async fn init_mixed() { + let extra_meta5_literal_str = "seed_prefix"; + let extra_meta6_literal_u64 = 28u64; + + let pubkey1 = Pubkey::new_unique(); + let mut lamports1 = 0; + let mut data1 = []; + let pubkey2 = Pubkey::new_unique(); + let mut lamports2 = 0; + let mut data2 = []; + let pubkey3 = Pubkey::new_unique(); + let mut lamports3 = 0; + let mut data3 = []; + let owner = Pubkey::new_unique(); + let account_infos = [ + AccountInfo::new( + &pubkey1, + false, + true, + &mut lamports1, + &mut data1, + &owner, + false, + 0, + ), + AccountInfo::new( + &pubkey2, + true, + false, + &mut lamports2, + &mut data2, + &owner, + false, + 0, + ), + AccountInfo::new( + &pubkey3, + false, + false, + &mut lamports3, + &mut data3, + &owner, + false, + 0, + ), + ]; + + let extra_meta1 = AccountMeta::new(Pubkey::new_unique(), false); + let extra_meta2 = AccountMeta::new(Pubkey::new_unique(), true); + let extra_meta3 = AccountMeta::new_readonly(Pubkey::new_unique(), true); + let extra_meta4 = AccountMeta::new_readonly(Pubkey::new_unique(), false); + let extra_meta5 = ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: extra_meta5_literal_str.as_bytes().to_vec(), + }, + Seed::InstructionData { + index: 1, + length: 8, // [u8; 8] + }, + Seed::InstructionData { + index: 9, + length: 32, // Pubkey + }, + Seed::AccountKey { index: 2 }, + ], + false, + true, + ) + .unwrap(); + let extra_meta6 = ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: extra_meta6_literal_u64.to_le_bytes().to_vec(), + }, + Seed::AccountKey { index: 1 }, + Seed::AccountKey { index: 4 }, + ], + false, + true, + ) + .unwrap(); + let extra_meta7 = ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::InstructionData { index: 41 }, // After the other pubkey arg. 
+ false, + true, + ) + .unwrap(); + + let test_ix_required_extra_accounts = account_infos + .iter() + .map(ExtraAccountMeta::from) + .collect::>(); + let test_other_ix_required_extra_accounts = [ + ExtraAccountMeta::from(&extra_meta1), + ExtraAccountMeta::from(&extra_meta2), + ExtraAccountMeta::from(&extra_meta3), + ExtraAccountMeta::from(&extra_meta4), + extra_meta5, + extra_meta6, + extra_meta7, + ]; + + let account_size = ExtraAccountMetaList::size_of(test_ix_required_extra_accounts.len()) + .unwrap() + + ExtraAccountMetaList::size_of(test_other_ix_required_extra_accounts.len()).unwrap(); + let mut buffer = vec![0; account_size]; + + ExtraAccountMetaList::init::( + &mut buffer, + &test_ix_required_extra_accounts, + ) + .unwrap(); + ExtraAccountMetaList::init::( + &mut buffer, + &test_other_ix_required_extra_accounts, + ) + .unwrap(); + + let mock_rpc = MockRpc::setup(&account_infos); + + let program_id = Pubkey::new_unique(); + let mut instruction = Instruction::new_with_bytes(program_id, &[], vec![]); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + let test_ix_check_metas = account_infos + .iter() + .map(account_info_to_meta) + .collect::>(); + assert_eq!(instruction.accounts, test_ix_check_metas,); + + let program_id = Pubkey::new_unique(); + + let instruction_u8array_arg = [1, 2, 3, 4, 5, 6, 7, 8]; + let instruction_pubkey_arg = Pubkey::new_unique(); + let instruction_key_data_pubkey_arg = Pubkey::new_unique(); + + let mut instruction_data = vec![0]; + instruction_data.extend_from_slice(&instruction_u8array_arg); + instruction_data.extend_from_slice(instruction_pubkey_arg.as_ref()); + instruction_data.extend_from_slice(instruction_key_data_pubkey_arg.as_ref()); + + let mut instruction = Instruction::new_with_bytes(program_id, &instruction_data, vec![]); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + let check_extra_meta5_pubkey = Pubkey::find_program_address( + &[ + extra_meta5_literal_str.as_bytes(), + &instruction_u8array_arg, + instruction_pubkey_arg.as_ref(), + extra_meta3.pubkey.as_ref(), + ], + &program_id, + ) + .0; + + let check_extra_meta6_pubkey = Pubkey::find_program_address( + &[ + extra_meta6_literal_u64.to_le_bytes().as_ref(), + extra_meta2.pubkey.as_ref(), + check_extra_meta5_pubkey.as_ref(), // The first PDA should be at index 4 + ], + &program_id, + ) + .0; + + let test_other_ix_check_metas = vec![ + extra_meta1, + extra_meta2, + extra_meta3, + extra_meta4, + AccountMeta::new(check_extra_meta5_pubkey, false), + AccountMeta::new(check_extra_meta6_pubkey, false), + AccountMeta::new(instruction_key_data_pubkey_arg, false), + ]; + + assert_eq!( + instruction.accounts.get(4).unwrap().pubkey, + check_extra_meta5_pubkey, + ); + assert_eq!( + instruction.accounts.get(5).unwrap().pubkey, + check_extra_meta6_pubkey, + ); + assert_eq!( + instruction.accounts.get(6).unwrap().pubkey, + instruction_key_data_pubkey_arg, + ); + assert_eq!(instruction.accounts, test_other_ix_check_metas,); + } + + #[tokio::test] + async fn cpi_instruction() { + // Say we have a program that CPIs to another program. + // + // Say that _other_ program will need extra account infos. 
+ + // This will be our program + let program_id = Pubkey::new_unique(); + let owner = Pubkey::new_unique(); + + // Some seeds used by the program for PDAs + let required_pda1_literal_string = "required_pda1"; + let required_pda2_literal_u32 = 4u32; + let required_key_data_instruction_data = Pubkey::new_unique(); + + // Define instruction data + // - 0: u8 + // - 1-8: [u8; 8] + // - 9-16: u64 + let instruction_u8array_arg = [1, 2, 3, 4, 5, 6, 7, 8]; + let instruction_u64_arg = 208u64; + let mut instruction_data = vec![0]; + instruction_data.extend_from_slice(&instruction_u8array_arg); + instruction_data.extend_from_slice(instruction_u64_arg.to_le_bytes().as_ref()); + instruction_data.extend_from_slice(required_key_data_instruction_data.as_ref()); + + // Define known instruction accounts + let ix_accounts = vec![ + AccountMeta::new(Pubkey::new_unique(), false), + AccountMeta::new(Pubkey::new_unique(), false), + ]; + + // Define extra account metas required by the program we will CPI to + let extra_meta1 = AccountMeta::new(Pubkey::new_unique(), false); + let extra_meta2 = AccountMeta::new(Pubkey::new_unique(), true); + let extra_meta3 = AccountMeta::new_readonly(Pubkey::new_unique(), false); + let required_accounts = [ + ExtraAccountMeta::from(&extra_meta1), + ExtraAccountMeta::from(&extra_meta2), + ExtraAccountMeta::from(&extra_meta3), + ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: required_pda1_literal_string.as_bytes().to_vec(), + }, + Seed::InstructionData { + index: 1, + length: 8, // [u8; 8] + }, + Seed::AccountKey { index: 1 }, + ], + false, + true, + ) + .unwrap(), + ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: required_pda2_literal_u32.to_le_bytes().to_vec(), + }, + Seed::InstructionData { + index: 9, + length: 8, // u64 + }, + Seed::AccountKey { index: 5 }, + ], + false, + true, + ) + .unwrap(), + ExtraAccountMeta::new_with_seeds( + &[ + Seed::InstructionData { + index: 0, + length: 1, // u8 + }, + Seed::AccountData { + account_index: 2, + data_index: 0, + length: 8, + }, + ], + false, + true, + ) + .unwrap(), + ExtraAccountMeta::new_with_seeds( + &[ + Seed::AccountData { + account_index: 5, + data_index: 4, + length: 4, + }, // This one is a PDA! 
+ ], + false, + true, + ) + .unwrap(), + ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::InstructionData { index: 17 }, + false, + true, + ) + .unwrap(), + ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::AccountData { + account_index: 6, + data_index: 0, + }, + false, + true, + ) + .unwrap(), + ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::AccountData { + account_index: 7, + data_index: 8, + }, + false, + true, + ) + .unwrap(), + ]; + + // Now here we're going to build the list of account infos + // We'll need to include: + // - The instruction account infos for the program to CPI to + // - The extra account infos for the program to CPI to + // - Some other arbitrary account infos our program may use + + // First we need to manually derive each PDA + let check_required_pda1_pubkey = Pubkey::find_program_address( + &[ + required_pda1_literal_string.as_bytes(), + &instruction_u8array_arg, + ix_accounts.get(1).unwrap().pubkey.as_ref(), // The second account + ], + &program_id, + ) + .0; + let check_required_pda2_pubkey = Pubkey::find_program_address( + &[ + required_pda2_literal_u32.to_le_bytes().as_ref(), + instruction_u64_arg.to_le_bytes().as_ref(), + check_required_pda1_pubkey.as_ref(), // The first PDA should be at index 5 + ], + &program_id, + ) + .0; + let check_required_pda3_pubkey = Pubkey::find_program_address( + &[ + &[0], // Instruction "discriminator" (u8) + &[8; 8], // The first 8 bytes of the data for account at index 2 (extra account 1) + ], + &program_id, + ) + .0; + let check_required_pda4_pubkey = Pubkey::find_program_address( + &[ + &[7; 4], /* 4 bytes starting at index 4 of the data for account at index 5 (extra + * pda 1) */ + ], + &program_id, + ) + .0; + let check_key_data1_pubkey = required_key_data_instruction_data; + let check_key_data2_pubkey = Pubkey::new_from_array([8; 32]); + let check_key_data3_pubkey = Pubkey::new_from_array([9; 32]); + + // The instruction account infos for the program to CPI to + let pubkey_ix_1 = ix_accounts.first().unwrap().pubkey; + let mut lamports_ix_1 = 0; + let mut data_ix_1 = []; + let pubkey_ix_2 = ix_accounts.get(1).unwrap().pubkey; + let mut lamports_ix_2 = 0; + let mut data_ix_2 = []; + + // The extra account infos for the program to CPI to + let mut lamports1 = 0; + let mut data1 = [8; 12]; + let mut lamports2 = 0; + let mut data2 = []; + let mut lamports3 = 0; + let mut data3 = []; + let mut lamports_pda1 = 0; + let mut data_pda1 = [7; 12]; + let mut lamports_pda2 = 0; + let mut data_pda2 = [8; 32]; + let mut lamports_pda3 = 0; + let mut data_pda3 = [0; 40]; + data_pda3[8..].copy_from_slice(&[9; 32]); // Add pubkey data for pubkey data pubkey 3. 
+ let mut lamports_pda4 = 0; + let mut data_pda4 = []; + let mut data_key_data1 = []; + let mut lamports_key_data1 = 0; + let mut data_key_data2 = []; + let mut lamports_key_data2 = 0; + let mut data_key_data3 = []; + let mut lamports_key_data3 = 0; + + // Some other arbitrary account infos our program may use + let pubkey_arb_1 = Pubkey::new_unique(); + let mut lamports_arb_1 = 0; + let mut data_arb_1 = []; + let pubkey_arb_2 = Pubkey::new_unique(); + let mut lamports_arb_2 = 0; + let mut data_arb_2 = []; + + let all_account_infos = [ + AccountInfo::new( + &pubkey_ix_1, + ix_accounts.first().unwrap().is_signer, + ix_accounts.first().unwrap().is_writable, + &mut lamports_ix_1, + &mut data_ix_1, + &owner, + false, + 0, + ), + AccountInfo::new( + &pubkey_ix_2, + ix_accounts.get(1).unwrap().is_signer, + ix_accounts.get(1).unwrap().is_writable, + &mut lamports_ix_2, + &mut data_ix_2, + &owner, + false, + 0, + ), + AccountInfo::new( + &extra_meta1.pubkey, + required_accounts.first().unwrap().is_signer.into(), + required_accounts.first().unwrap().is_writable.into(), + &mut lamports1, + &mut data1, + &owner, + false, + 0, + ), + AccountInfo::new( + &extra_meta2.pubkey, + required_accounts.get(1).unwrap().is_signer.into(), + required_accounts.get(1).unwrap().is_writable.into(), + &mut lamports2, + &mut data2, + &owner, + false, + 0, + ), + AccountInfo::new( + &extra_meta3.pubkey, + required_accounts.get(2).unwrap().is_signer.into(), + required_accounts.get(2).unwrap().is_writable.into(), + &mut lamports3, + &mut data3, + &owner, + false, + 0, + ), + AccountInfo::new( + &check_required_pda1_pubkey, + required_accounts.get(3).unwrap().is_signer.into(), + required_accounts.get(3).unwrap().is_writable.into(), + &mut lamports_pda1, + &mut data_pda1, + &owner, + false, + 0, + ), + AccountInfo::new( + &check_required_pda2_pubkey, + required_accounts.get(4).unwrap().is_signer.into(), + required_accounts.get(4).unwrap().is_writable.into(), + &mut lamports_pda2, + &mut data_pda2, + &owner, + false, + 0, + ), + AccountInfo::new( + &check_required_pda3_pubkey, + required_accounts.get(5).unwrap().is_signer.into(), + required_accounts.get(5).unwrap().is_writable.into(), + &mut lamports_pda3, + &mut data_pda3, + &owner, + false, + 0, + ), + AccountInfo::new( + &check_required_pda4_pubkey, + required_accounts.get(6).unwrap().is_signer.into(), + required_accounts.get(6).unwrap().is_writable.into(), + &mut lamports_pda4, + &mut data_pda4, + &owner, + false, + 0, + ), + AccountInfo::new( + &check_key_data1_pubkey, + required_accounts.get(7).unwrap().is_signer.into(), + required_accounts.get(7).unwrap().is_writable.into(), + &mut lamports_key_data1, + &mut data_key_data1, + &owner, + false, + 0, + ), + AccountInfo::new( + &check_key_data2_pubkey, + required_accounts.get(8).unwrap().is_signer.into(), + required_accounts.get(8).unwrap().is_writable.into(), + &mut lamports_key_data2, + &mut data_key_data2, + &owner, + false, + 0, + ), + AccountInfo::new( + &check_key_data3_pubkey, + required_accounts.get(9).unwrap().is_signer.into(), + required_accounts.get(9).unwrap().is_writable.into(), + &mut lamports_key_data3, + &mut data_key_data3, + &owner, + false, + 0, + ), + AccountInfo::new( + &pubkey_arb_1, + false, + true, + &mut lamports_arb_1, + &mut data_arb_1, + &owner, + false, + 0, + ), + AccountInfo::new( + &pubkey_arb_2, + false, + true, + &mut lamports_arb_2, + &mut data_arb_2, + &owner, + false, + 0, + ), + ]; + + // Let's use a mock RPC and set up a test instruction to check the CPI + // instruction against later 
+ let rpc_account_infos = all_account_infos.clone(); + let mock_rpc = MockRpc::setup(&rpc_account_infos); + + let account_size = ExtraAccountMetaList::size_of(required_accounts.len()).unwrap(); + let mut buffer = vec![0; account_size]; + ExtraAccountMetaList::init::(&mut buffer, &required_accounts).unwrap(); + + let mut instruction = + Instruction::new_with_bytes(program_id, &instruction_data, ix_accounts.clone()); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + &buffer, + ) + .await + .unwrap(); + + // Perform the account resolution for the CPI instruction + + // Create the instruction itself + let mut cpi_instruction = + Instruction::new_with_bytes(program_id, &instruction_data, ix_accounts); + + // Start with the known account infos + let mut cpi_account_infos = + vec![all_account_infos[0].clone(), all_account_infos[1].clone()]; + + // Mess up the ordering of the account infos to make it harder! + let mut messed_account_infos = all_account_infos.clone(); + messed_account_infos.swap(0, 4); + messed_account_infos.swap(1, 2); + messed_account_infos.swap(3, 4); + messed_account_infos.swap(5, 6); + messed_account_infos.swap(8, 7); + + // Resolve the rest! + ExtraAccountMetaList::add_to_cpi_instruction::( + &mut cpi_instruction, + &mut cpi_account_infos, + &buffer, + &messed_account_infos, + ) + .unwrap(); + + // Our CPI instruction should match the check instruction. + assert_eq!(cpi_instruction, instruction); + + // CPI account infos should have the instruction account infos + // and the extra required account infos from the validation account, + // and they should be in the correct order. + // Note: The two additional arbitrary account infos for the currently + // executing program won't be present in the CPI instruction's account + // infos, so we will omit them (hence the `..9`). + let check_account_infos = &all_account_infos[..12]; + assert_eq!(cpi_account_infos.len(), check_account_infos.len()); + for (a, b) in std::iter::zip(cpi_account_infos, check_account_infos) { + assert_eq!(a.key, b.key); + assert_eq!(a.is_signer, b.is_signer); + assert_eq!(a.is_writable, b.is_writable); + } + } + + async fn update_and_assert_metas( + program_id: Pubkey, + buffer: &mut Vec, + updated_metas: &[ExtraAccountMeta], + check_metas: &[AccountMeta], + ) { + // resize buffer if necessary + let account_size = ExtraAccountMetaList::size_of(updated_metas.len()).unwrap(); + if account_size > buffer.len() { + buffer.resize(account_size, 0); + } + + // update + ExtraAccountMetaList::update::(buffer, updated_metas).unwrap(); + + // retrieve metas and assert + let state = TlvStateBorrowed::unpack(buffer).unwrap(); + let unpacked_metas_pod = + ExtraAccountMetaList::unpack_with_tlv_state::(&state).unwrap(); + let unpacked_metas = unpacked_metas_pod.data(); + assert_eq!( + unpacked_metas, updated_metas, + "The ExtraAccountMetas in the buffer should match the expected ones." 
+ ); + + let mock_rpc = MockRpc::setup(&[]); + + let mut instruction = Instruction::new_with_bytes(program_id, &[], vec![]); + ExtraAccountMetaList::add_to_instruction::( + &mut instruction, + |pubkey| mock_rpc.get_account_data(pubkey), + buffer, + ) + .await + .unwrap(); + + assert_eq!(instruction.accounts, check_metas,); + } + + #[tokio::test] + async fn update_extra_account_meta_list() { + let program_id = Pubkey::new_unique(); + + // Create list of initial metas + let initial_metas = [ + ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), false, true).unwrap(), + ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), true, false).unwrap(), + ]; + + // initialize + let initial_account_size = ExtraAccountMetaList::size_of(initial_metas.len()).unwrap(); + let mut buffer = vec![0; initial_account_size]; + ExtraAccountMetaList::init::(&mut buffer, &initial_metas).unwrap(); + + // Create updated metas list of the same size + let updated_metas_1 = [ + ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), true, true).unwrap(), + ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), false, false).unwrap(), + ]; + let check_metas_1 = updated_metas_1 + .iter() + .map(|e| AccountMeta::try_from(e).unwrap()) + .collect::>(); + update_and_assert_metas(program_id, &mut buffer, &updated_metas_1, &check_metas_1).await; + + // Create updated and larger list of metas + let updated_metas_2 = [ + ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), true, true).unwrap(), + ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), false, false).unwrap(), + ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), false, true).unwrap(), + ]; + let check_metas_2 = updated_metas_2 + .iter() + .map(|e| AccountMeta::try_from(e).unwrap()) + .collect::>(); + update_and_assert_metas(program_id, &mut buffer, &updated_metas_2, &check_metas_2).await; + + // Create updated and smaller list of metas + let updated_metas_3 = + [ExtraAccountMeta::new_with_pubkey(&Pubkey::new_unique(), true, true).unwrap()]; + let check_metas_3 = updated_metas_3 + .iter() + .map(|e| AccountMeta::try_from(e).unwrap()) + .collect::>(); + update_and_assert_metas(program_id, &mut buffer, &updated_metas_3, &check_metas_3).await; + + // Create updated list of metas with a simple PDA + let seed_pubkey = Pubkey::new_unique(); + let updated_metas_4 = [ + ExtraAccountMeta::new_with_pubkey(&seed_pubkey, true, true).unwrap(), + ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: b"seed-prefix".to_vec(), + }, + Seed::AccountKey { index: 0 }, + ], + false, + true, + ) + .unwrap(), + ]; + let simple_pda = Pubkey::find_program_address( + &[ + b"seed-prefix", // Literal prefix + seed_pubkey.as_ref(), // Account at index 0 + ], + &program_id, + ) + .0; + let check_metas_4 = [ + AccountMeta::new(seed_pubkey, true), + AccountMeta::new(simple_pda, false), + ]; + + update_and_assert_metas(program_id, &mut buffer, &updated_metas_4, &check_metas_4).await; + } + + #[test] + fn check_account_infos_test() { + let program_id = Pubkey::new_unique(); + let owner = Pubkey::new_unique(); + + // Create a list of required account metas + let pubkey1 = Pubkey::new_unique(); + let pubkey2 = Pubkey::new_unique(); + let required_accounts = [ + ExtraAccountMeta::new_with_pubkey(&pubkey1, false, true).unwrap(), + ExtraAccountMeta::new_with_pubkey(&pubkey2, false, false).unwrap(), + ExtraAccountMeta::new_with_seeds( + &[ + Seed::Literal { + bytes: b"lit_seed".to_vec(), + }, + Seed::InstructionData { + index: 0, + length: 4, + }, + Seed::AccountKey { 
index: 0 }, + ], + false, + true, + ) + .unwrap(), + ExtraAccountMeta::new_with_pubkey_data( + &PubkeyData::InstructionData { index: 8 }, + false, + true, + ) + .unwrap(), + ]; + + // Create the validation data + let account_size = ExtraAccountMetaList::size_of(required_accounts.len()).unwrap(); + let mut buffer = vec![0; account_size]; + ExtraAccountMetaList::init::(&mut buffer, &required_accounts).unwrap(); + + // Create the instruction data + let mut instruction_data = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let key_data_pubkey = Pubkey::new_unique(); + instruction_data.extend_from_slice(key_data_pubkey.as_ref()); + + // Set up a list of the required accounts as account infos, + // with two instruction accounts + let pubkey_ix_1 = Pubkey::new_unique(); + let mut lamports_ix_1 = 0; + let mut data_ix_1 = []; + let pubkey_ix_2 = Pubkey::new_unique(); + let mut lamports_ix_2 = 0; + let mut data_ix_2 = []; + let mut lamports1 = 0; + let mut data1 = []; + let mut lamports2 = 0; + let mut data2 = []; + let mut lamports3 = 0; + let mut data3 = []; + let mut lamports4 = 0; + let mut data4 = []; + let pda = Pubkey::find_program_address( + &[b"lit_seed", &instruction_data[..4], pubkey_ix_1.as_ref()], + &program_id, + ) + .0; + let account_infos = [ + // Instruction account 1 + AccountInfo::new( + &pubkey_ix_1, + false, + true, + &mut lamports_ix_1, + &mut data_ix_1, + &owner, + false, + 0, + ), + // Instruction account 2 + AccountInfo::new( + &pubkey_ix_2, + false, + true, + &mut lamports_ix_2, + &mut data_ix_2, + &owner, + false, + 0, + ), + // Required account 1 + AccountInfo::new( + &pubkey1, + false, + true, + &mut lamports1, + &mut data1, + &owner, + false, + 0, + ), + // Required account 2 + AccountInfo::new( + &pubkey2, + false, + false, + &mut lamports2, + &mut data2, + &owner, + false, + 0, + ), + // Required account 3 (PDA) + AccountInfo::new( + &pda, + false, + true, + &mut lamports3, + &mut data3, + &owner, + false, + 0, + ), + // Required account 4 (pubkey data) + AccountInfo::new( + &key_data_pubkey, + false, + true, + &mut lamports4, + &mut data4, + &owner, + false, + 0, + ), + ]; + + // Create another list of account infos to intentionally mess up + let mut messed_account_infos = account_infos.clone().to_vec(); + messed_account_infos.swap(0, 2); + messed_account_infos.swap(1, 4); + messed_account_infos.swap(3, 2); + messed_account_infos.swap(5, 4); + + // Account info check should fail for the messed list + assert_eq!( + ExtraAccountMetaList::check_account_infos::( + &messed_account_infos, + &instruction_data, + &program_id, + &buffer, + ) + .unwrap_err(), + AccountResolutionError::IncorrectAccount.into(), + ); + + // Account info check should pass for the correct list + assert_eq!( + ExtraAccountMetaList::check_account_infos::( + &account_infos, + &instruction_data, + &program_id, + &buffer, + ), + Ok(()), + ); + } +} diff --git a/type-length-value-derive-test/Cargo.toml b/type-length-value-derive-test/Cargo.toml new file mode 100644 index 00000000..8ad6c307 --- /dev/null +++ b/type-length-value-derive-test/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "spl-type-length-value-derive-test" +version = "0.1.0" +description = "Testing Derive Macro Library for SPL Type Length Value traits" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" + +[dev-dependencies] +borsh = "1.5.3" +solana-borsh = "2.1.0" +spl-discriminator = { version = "0.4.0", path = "../discriminator" } +spl-type-length-value = 
{ version = "0.7.0", path = "../type-length-value", features = [ + "derive", +] } diff --git a/type-length-value-derive-test/src/lib.rs b/type-length-value-derive-test/src/lib.rs new file mode 100644 index 00000000..5e32d4f5 --- /dev/null +++ b/type-length-value-derive-test/src/lib.rs @@ -0,0 +1,59 @@ +//! Test crate to avoid making `borsh` a direct dependency of +//! `spl-type-length-value`. You can't use a derive macro from within the same +//! crate that the macro is defined, so we need this extra crate for just +//! testing the macro itself. + +#[cfg(test)] +pub mod test { + use { + borsh::{BorshDeserialize, BorshSerialize}, + solana_borsh::v1::{get_instance_packed_len, try_from_slice_unchecked}, + spl_discriminator::SplDiscriminate, + spl_type_length_value::{variable_len_pack::VariableLenPack, SplBorshVariableLenPack}, + }; + + #[derive( + Clone, + Debug, + Default, + PartialEq, + BorshDeserialize, + BorshSerialize, + SplDiscriminate, + SplBorshVariableLenPack, + )] + #[discriminator_hash_input("vehicle::my_vehicle")] + pub struct Vehicle { + vin: [u8; 8], + plate: [u8; 7], + } + + #[test] + fn test_derive() { + let vehicle = Vehicle { + vin: [0; 8], + plate: [0; 7], + }; + + assert_eq!( + get_instance_packed_len::(&vehicle).unwrap(), + vehicle.get_packed_len().unwrap() + ); + + let dst1 = &mut [0u8; 15]; + borsh::to_writer(&mut dst1[..], &vehicle).unwrap(); + + let dst2 = &mut [0u8; 15]; + vehicle.pack_into_slice(&mut dst2[..]).unwrap(); + + assert_eq!(dst1, dst2,); + + let mut buffer = [0u8; 15]; + buffer.copy_from_slice(&dst1[..]); + + assert_eq!( + try_from_slice_unchecked::(&buffer).unwrap(), + Vehicle::unpack_from_slice(&buffer).unwrap() + ); + } +} diff --git a/type-length-value/Cargo.toml b/type-length-value/Cargo.toml new file mode 100644 index 00000000..63da3aee --- /dev/null +++ b/type-length-value/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "spl-type-length-value" +version = "0.7.0" +description = "Solana Program Library Type-Length-Value Management" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" +exclude = ["js/**"] + +[features] +derive = ["dep:spl-type-length-value-derive"] + +[dependencies] +bytemuck = { version = "1.20.0", features = ["derive"] } +num-derive = "0.4" +num-traits = "0.2" +solana-account-info = "2.1.0" +solana-decode-error = "2.1.0" +solana-msg = "2.1.0" +solana-program-error = "2.1.0" +spl-discriminator = { version = "0.4.0", path = "../discriminator" } +spl-type-length-value-derive = { version = "0.1", path = "./derive", optional = true } +spl-pod = { version = "0.5.0", path = "../pod" } +thiserror = "2.0" + +[lib] +crate-type = ["cdylib", "lib"] + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] diff --git a/type-length-value/README.md b/type-length-value/README.md new file mode 100644 index 00000000..aa7a0850 --- /dev/null +++ b/type-length-value/README.md @@ -0,0 +1,196 @@ +# Type-Length-Value + +Library with utilities for working with Type-Length-Value structures. + +## Example usage + +This simple examples defines a zero-copy type with its discriminator. 
+ +```rust +use { + bytemuck::{Pod, Zeroable}, + spl_discriminator::{ArrayDiscriminator, SplDiscriminate}, + spl_type_length_value::{ + state::{TlvState, TlvStateBorrowed, TlvStateMut} + }, +}; + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +struct MyPodValue { + data: [u8; 32], +} +impl SplDiscriminate for MyPodValue { + // Give it a unique discriminator, can also be generated using a hash function + const SPL_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([1; ArrayDiscriminator::LENGTH]); +} +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)] +struct MyOtherPodValue { + data: u8, +} +// Give this type a non-derivable implementation of `Default` to write some data +impl Default for MyOtherPodValue { + fn default() -> Self { + Self { + data: 10, + } + } +} +impl SplDiscriminate for MyOtherPodValue { + // Some other unique discriminator + const SPL_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([2; ArrayDiscriminator::LENGTH]); +} + +// Account will have two sets of `get_base_len()` (8-byte discriminator and 4-byte length), +// and enough room for a `MyPodValue` and a `MyOtherPodValue` +let account_size = TlvStateMut::get_base_len() + + std::mem::size_of::() + + TlvStateMut::get_base_len() + + std::mem::size_of::() + + TlvStateMut::get_base_len() + + std::mem::size_of::(); + +// Buffer likely comes from a Solana `solana_account_info::AccountInfo`, +// but this example just uses a vector. +let mut buffer = vec![0; account_size]; + +// Unpack the base buffer as a TLV structure +let mut state = TlvStateMut::unpack(&mut buffer).unwrap(); + +// Init and write default value +// Note: you'll need to provide a boolean whether or not to allow repeating +// values with the same TLV discriminator. +// If set to false, this function will error when an existing entry is detected. +// Note the function also returns the repetition number, which can be used to +// fetch the value again. +let (value, _repetition_number) = state.init_value::(false).unwrap(); +// Update it in-place +value.data[0] = 1; + +// Init and write another default value +// This time, we're going to allow repeating values. +let (other_value1, other_value1_repetition_number) = + state.init_value::(true).unwrap(); +assert_eq!(other_value1.data, 10); +// Update it in-place +other_value1.data = 2; + +// Let's do it again, since we can now have repeating values! +let (other_value2, other_value2_repetition_number) = + state.init_value::(true).unwrap(); +assert_eq!(other_value2.data, 10); +// Update it in-place +other_value2.data = 4; + +// Later on, to work with it again, we can just get the first value we +// encounter, because we did _not_ allow repeating entries for `MyPodValue`. +let value = state.get_first_value_mut::().unwrap(); + +// Or fetch it from an immutable buffer +let state = TlvStateBorrowed::unpack(&buffer).unwrap(); +let value1 = state.get_first_value::().unwrap(); + +// Since we used repeating entries for `MyOtherPodValue`, we can grab either one by +// its repetition number +let value1 = state + .get_value_with_repetition::(other_value1_repetition_number) + .unwrap(); +let value2 = state + .get_value_with_repetition::(other_value2_repetition_number) + .unwrap(); + +``` + +## Motivation + +The Solana blockchain exposes slabs of bytes to on-chain programs, allowing program +writers to interpret these bytes and change them however they wish. Currently, +programs interpret account bytes as being only of one type. 
For example, a token +mint account is only ever a token mint, an AMM pool account is only ever an AMM pool, +a token metadata account can only hold token metadata, etc. + +In a world of interfaces, a program will likely implement multiple interfaces. +As a concrete and important example, imagine a token program where mints hold +their own metadata. This means that a single account can be both a mint and +metadata. + +To allow easy implementation of multiple interfaces, accounts must be able to +hold multiple different types within one opaque slab of bytes. The +[type-length-value](https://en.wikipedia.org/wiki/Type%E2%80%93length%E2%80%93value) +scheme facilitates this exact case. + +## How it works + +This library allows for holding multiple disparate types within the same account +by encoding the type, then length, then value. + +The type is an 8-byte `ArrayDiscriminator`, which can be set to anything. + +The length is a little-endian `u32`. + +The value is a slab of `length` bytes that can be used however a program desires. + +When searching through the buffer for a particular type, the library looks at +the first 8-byte discriminator. If it's all zeroes, this means it's uninitialized. +If not, it reads the next 4-byte length. If the discriminator matches, it returns +the next `length` bytes. If not, it jumps ahead `length` bytes and reads the +next 8-byte discriminator. + +## Serialization of variable-length types + +The initial example works using the `bytemuck` crate for zero-copy serialization +and deserialization. It's possible to use Borsh by implementing the `VariableLenPack` +trait on your type. + +```rust +use { + borsh::{BorshDeserialize, BorshSerialize}, + solana_borsh::v1::{get_instance_packed_len, try_from_slice_unchecked}, + solana_program_error::ProgramError, + spl_discriminator::{ArrayDiscriminator, SplDiscriminate}, + spl_type_length_value::{ + state::{TlvState, TlvStateMut}, + variable_len_pack::VariableLenPack + }, +}; +#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize)] +struct MyVariableLenType { + data: String, // variable length type +} +impl SplDiscriminate for MyVariableLenType { + const SPL_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([5; ArrayDiscriminator::LENGTH]); +} +impl VariableLenPack for MyVariableLenType { + fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> { + borsh::to_writer(&mut dst[..], self).map_err(Into::into) + } + + fn unpack_from_slice(src: &[u8]) -> Result { + try_from_slice_unchecked(src).map_err(Into::into) + } + + fn get_packed_len(&self) -> Result { + get_instance_packed_len(self).map_err(Into::into) + } +} +let initial_data = "This is a pretty cool test!"; +// Allocate exactly the right size for the string, can go bigger if desired +let tlv_size = 4 + initial_data.len(); +let account_size = TlvStateMut::get_base_len() + tlv_size; + +// Buffer likely comes from a Solana `solana_account_info::AccountInfo`, +// but this example just uses a vector. +let mut buffer = vec![0; account_size]; +let mut state = TlvStateMut::unpack(&mut buffer).unwrap(); + +// No need to hold onto the bytes since we'll serialize back into the right place +// For this example, let's _not_ allow repeating entries. 
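+// `alloc` writes the discriminator and the 4-byte length into the first open
+// slot and hands back the value slice plus its repetition number; both can be
+// discarded here, since `pack_first_variable_len_value` finds the slice again
+// by discriminator.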
+let _ = state.alloc::(tlv_size, false).unwrap(); +let my_variable_len = MyVariableLenType { + data: initial_data.to_string() +}; +state.pack_first_variable_len_value(&my_variable_len).unwrap(); +let deser = state.get_first_variable_len_value::().unwrap(); +assert_eq!(deser, my_variable_len); +``` diff --git a/type-length-value/derive/Cargo.toml b/type-length-value/derive/Cargo.toml new file mode 100644 index 00000000..4cc7cfed --- /dev/null +++ b/type-length-value/derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "spl-type-length-value-derive" +version = "0.1.0" +description = "Derive Macro Library for SPL Type Length Value traits" +authors = ["Solana Labs Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +license = "Apache-2.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "2.0", features = ["full"] } diff --git a/type-length-value/derive/src/builder.rs b/type-length-value/derive/src/builder.rs new file mode 100644 index 00000000..4f2d5a4a --- /dev/null +++ b/type-length-value/derive/src/builder.rs @@ -0,0 +1,91 @@ +//! The actual token generator for the macro +use { + proc_macro2::{Span, TokenStream}, + quote::{quote, ToTokens}, + syn::{parse::Parse, Generics, Ident, Item, ItemEnum, ItemStruct, WhereClause}, +}; + +pub struct SplBorshVariableLenPackBuilder { + /// The struct/enum identifier + pub ident: Ident, + /// The item's generic arguments (if any) + pub generics: Generics, + /// The item's where clause for generics (if any) + pub where_clause: Option, +} + +impl TryFrom for SplBorshVariableLenPackBuilder { + type Error = syn::Error; + + fn try_from(item_enum: ItemEnum) -> Result { + let ident = item_enum.ident; + let where_clause = item_enum.generics.where_clause.clone(); + let generics = item_enum.generics; + Ok(Self { + ident, + generics, + where_clause, + }) + } +} + +impl TryFrom for SplBorshVariableLenPackBuilder { + type Error = syn::Error; + + fn try_from(item_struct: ItemStruct) -> Result { + let ident = item_struct.ident; + let where_clause = item_struct.generics.where_clause.clone(); + let generics = item_struct.generics; + Ok(Self { + ident, + generics, + where_clause, + }) + } +} + +impl Parse for SplBorshVariableLenPackBuilder { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let item = Item::parse(input)?; + match item { + Item::Enum(item_enum) => item_enum.try_into(), + Item::Struct(item_struct) => item_struct.try_into(), + _ => { + return Err(syn::Error::new( + Span::call_site(), + "Only enums and structs are supported", + )) + } + } + .map_err(|e| syn::Error::new(input.span(), format!("Failed to parse item: {}", e))) + } +} + +impl ToTokens for SplBorshVariableLenPackBuilder { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + tokens.extend::(self.into()); + } +} + +impl From<&SplBorshVariableLenPackBuilder> for TokenStream { + fn from(builder: &SplBorshVariableLenPackBuilder) -> Self { + let ident = &builder.ident; + let generics = &builder.generics; + let where_clause = &builder.where_clause; + quote! 
{ + impl #generics spl_type_length_value::variable_len_pack::VariableLenPack for #ident #generics #where_clause { + fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), spl_type_length_value::solana_program_error::ProgramError> { + borsh::to_writer(&mut dst[..], self).map_err(Into::into) + } + + fn unpack_from_slice(src: &[u8]) -> Result { + solana_borsh::v1::try_from_slice_unchecked(src).map_err(Into::into) + } + + fn get_packed_len(&self) -> Result { + solana_borsh::v1::get_instance_packed_len(self).map_err(Into::into) + } + } + } + } +} diff --git a/type-length-value/derive/src/lib.rs b/type-length-value/derive/src/lib.rs new file mode 100644 index 00000000..b3f0f12a --- /dev/null +++ b/type-length-value/derive/src/lib.rs @@ -0,0 +1,22 @@ +//! Crate defining a derive macro for a basic borsh implementation of +//! the trait `VariableLenPack`. + +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +extern crate proc_macro; + +mod builder; + +use { + builder::SplBorshVariableLenPackBuilder, proc_macro::TokenStream, quote::ToTokens, + syn::parse_macro_input, +}; + +/// Derive macro to add `VariableLenPack` trait for borsh-implemented types +#[proc_macro_derive(SplBorshVariableLenPack)] +pub fn spl_borsh_variable_len_pack(input: TokenStream) -> TokenStream { + parse_macro_input!(input as SplBorshVariableLenPackBuilder) + .to_token_stream() + .into() +} diff --git a/type-length-value/js/.eslintignore b/type-length-value/js/.eslintignore new file mode 100644 index 00000000..6da325ef --- /dev/null +++ b/type-length-value/js/.eslintignore @@ -0,0 +1,5 @@ +docs +lib +test-ledger + +package-lock.json diff --git a/type-length-value/js/.eslintrc b/type-length-value/js/.eslintrc new file mode 100644 index 00000000..5aef10a4 --- /dev/null +++ b/type-length-value/js/.eslintrc @@ -0,0 +1,34 @@ +{ + "root": true, + "extends": [ + "eslint:recommended", + "plugin:@typescript-eslint/recommended", + "plugin:prettier/recommended", + "plugin:require-extensions/recommended" + ], + "parser": "@typescript-eslint/parser", + "plugins": [ + "@typescript-eslint", + "prettier", + "require-extensions" + ], + "rules": { + "@typescript-eslint/ban-ts-comment": "off", + "@typescript-eslint/no-explicit-any": "off", + "@typescript-eslint/no-unused-vars": "off", + "@typescript-eslint/no-empty-interface": "off", + "@typescript-eslint/consistent-type-imports": "error" + }, + "overrides": [ + { + "files": [ + "examples/**/*", + "test/**/*" + ], + "rules": { + "require-extensions/require-extensions": "off", + "require-extensions/require-index": "off" + } + } + ] +} diff --git a/type-length-value/js/.gitignore b/type-length-value/js/.gitignore new file mode 100644 index 00000000..21f33db8 --- /dev/null +++ b/type-length-value/js/.gitignore @@ -0,0 +1,13 @@ +.idea +.vscode +.DS_Store + +node_modules + +pnpm-lock.yaml +yarn.lock + +docs +lib +test-ledger +*.tsbuildinfo diff --git a/type-length-value/js/.mocharc.json b/type-length-value/js/.mocharc.json new file mode 100644 index 00000000..451c14c3 --- /dev/null +++ b/type-length-value/js/.mocharc.json @@ -0,0 +1,5 @@ +{ + "extension": ["ts"], + "node-option": ["experimental-specifier-resolution=node", "loader=ts-node/esm"], + "timeout": 5000 +} diff --git a/type-length-value/js/.nojekyll b/type-length-value/js/.nojekyll new file mode 100644 index 00000000..e69de29b diff --git a/type-length-value/js/LICENSE b/type-length-value/js/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/type-length-value/js/LICENSE @@ -0,0 +1,202 @@ + + Apache 
License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/type-length-value/js/README.md b/type-length-value/js/README.md new file mode 100644 index 00000000..d0d14462 --- /dev/null +++ b/type-length-value/js/README.md @@ -0,0 +1,18 @@ +# Type-Length-Value-js + +Library with utilities for working with Type-Length-Value structures in js. + +## Example usage + +```ts +import { TlvState, SplDiscriminator } from '@solana/spl-type-length-value'; + +const tlv = new TlvState(tlvData, discriminatorSize, lengthSize); +const discriminator = await splDiscriminate("", discriminatorSize); + +const firstValue = tlv.firstBytes(discriminator); + +const allValues = tlv.bytesRepeating(discriminator); + +const firstThreeValues = tlv.bytesRepeating(discriminator, 3); +``` diff --git a/type-length-value/js/package.json b/type-length-value/js/package.json new file mode 100644 index 00000000..87503484 --- /dev/null +++ b/type-length-value/js/package.json @@ -0,0 +1,65 @@ +{ + "name": "@solana/spl-type-length-value", + "description": "SPL Type Length Value Library", + "version": "0.2.0", + "author": "Solana Labs Maintainers ", + "repository": "https://github.com/solana-labs/solana-program-library", + "license": "Apache-2.0", + "type": "module", + "sideEffects": false, + "engines": { + "node": ">=19" + }, + "files": [ + "lib", + "src", + "LICENSE", + "README.md" + ], + "publishConfig": { + "access": "public" + }, + "main": "./lib/cjs/index.js", + "module": "./lib/esm/index.js", + "types": "./lib/types/index.d.ts", + "exports": { + "types": "./lib/types/index.d.ts", + "require": "./lib/cjs/index.js", + "import": "./lib/esm/index.js" + }, + "scripts": { + "build": "tsc --build --verbose tsconfig.all.json", + "clean": "shx rm -rf lib **/*.tsbuildinfo || true", + "deploy": "npm run deploy:docs", + "deploy:docs": "npm run docs && gh-pages --dest type-length-value/js --dist docs --dotfiles", + "docs": "shx rm -rf docs && typedoc && shx cp .nojekyll docs/", + "lint": "eslint --max-warnings 0 .", + "lint:fix": "eslint --fix .", + "nuke": "shx rm -rf node_modules package-lock.json || true", + "postbuild": "shx echo '{ \"type\": \"commonjs\" }' > lib/cjs/package.json", + "reinstall": "npm run nuke && npm install", + "release": "npm run clean && npm run build", + "test": "mocha test", + "watch": "tsc --build --verbose --watch tsconfig.all.json" + }, + "dependencies": { + "@solana/assertions": "^2.0.0", + "buffer": "^6.0.3" + }, + "devDependencies": { + "@types/chai": "^5.0.1", + "@types/mocha": "^10.0.10", + "@types/node": "^22.10.2", + "@typescript-eslint/eslint-plugin": "^8.4.0", + "@typescript-eslint/parser": "^8.4.0", + "chai": "^5.1.2", + "eslint": "^8.57.0", + "eslint-plugin-require-extensions": "^0.1.1", + "gh-pages": "^6.2.0", + "mocha": "^11.0.1", + "shx": "^0.3.4", + "ts-node": "^10.9.2", + "typedoc": "^0.27.4", + "typescript": "^5.7.2" + } +} diff --git a/type-length-value/js/src/errors.ts b/type-length-value/js/src/errors.ts new file mode 100644 index 00000000..1b128cb8 --- /dev/null +++ b/type-length-value/js/src/errors.ts @@ -0,0 +1,11 @@ +/** Base class for errors */ +export abstract class TlvError extends Error { + constructor(message?: string) { + 
super(message); + } +} + +/** Thrown if the byte length of an tlv buffer doesn't match the expected size */ +export class TlvInvalidAccountDataError extends TlvError { + name = 'TlvInvalidAccountDataError'; +} diff --git a/type-length-value/js/src/index.ts b/type-length-value/js/src/index.ts new file mode 100644 index 00000000..68afecb9 --- /dev/null +++ b/type-length-value/js/src/index.ts @@ -0,0 +1,3 @@ +export * from './splDiscriminate.js'; +export * from './tlvState.js'; +export * from './errors.js'; diff --git a/type-length-value/js/src/splDiscriminate.ts b/type-length-value/js/src/splDiscriminate.ts new file mode 100644 index 00000000..01a32ac6 --- /dev/null +++ b/type-length-value/js/src/splDiscriminate.ts @@ -0,0 +1,8 @@ +import { assertDigestCapabilityIsAvailable } from '@solana/assertions'; + +export async function splDiscriminate(discriminator: string, length = 8): Promise { + assertDigestCapabilityIsAvailable(); + const bytes = new TextEncoder().encode(discriminator); + const digest = await crypto.subtle.digest('SHA-256', bytes); + return new Uint8Array(digest).subarray(0, length); +} diff --git a/type-length-value/js/src/tlvState.ts b/type-length-value/js/src/tlvState.ts new file mode 100644 index 00000000..816a2428 --- /dev/null +++ b/type-length-value/js/src/tlvState.ts @@ -0,0 +1,109 @@ +import { TlvInvalidAccountDataError } from './errors.js'; + +export type LengthSize = 1 | 2 | 4 | 8; + +export type Discriminator = Uint8Array; + +export class TlvState { + private readonly tlvData: Buffer; + private readonly discriminatorSize: number; + private readonly lengthSize: LengthSize; + + public constructor(buffer: Buffer, discriminatorSize = 2, lengthSize: LengthSize = 2, offset: number = 0) { + this.tlvData = buffer.subarray(offset); + this.discriminatorSize = discriminatorSize; + this.lengthSize = lengthSize; + } + + /** + * Get the raw tlv data + * + * @return the raw tlv data + */ + public get data(): Buffer { + return this.tlvData; + } + + private readEntryLength(size: LengthSize, offset: number, constructor: (x: number | bigint) => T): T { + switch (size) { + case 1: + return constructor(this.tlvData.readUInt8(offset)); + case 2: + return constructor(this.tlvData.readUInt16LE(offset)); + case 4: + return constructor(this.tlvData.readUInt32LE(offset)); + case 8: + return constructor(this.tlvData.readBigUInt64LE(offset)); + } + } + + /** + * Get a single entry from the tlv data. This function returns the first entry with the given type. + * + * @param type the type of the entry to get + * + * @return the entry from the tlv data or null + */ + public firstBytes(discriminator: Discriminator): Buffer | null { + const entries = this.bytesRepeating(discriminator, 1); + return entries.length > 0 ? entries[0] : null; + } + + /** + * Get a multiple entries from the tlv data. This function returns `count` or less entries with the given type. 
+ * + * @param type the type of the entry to get + * @param count the number of entries to get (0 for all entries) + * + * @return the entry from the tlv data or null + */ + public bytesRepeating(discriminator: Discriminator, count = 0): Buffer[] { + const entries: Buffer[] = []; + let offset = 0; + while (offset < this.tlvData.length) { + if (offset + this.discriminatorSize + this.lengthSize > this.tlvData.length) { + throw new TlvInvalidAccountDataError(); + } + const type = this.tlvData.subarray(offset, offset + this.discriminatorSize); + offset += this.discriminatorSize; + const entryLength = this.readEntryLength(this.lengthSize, offset, Number); + offset += this.lengthSize; + if (offset + entryLength > this.tlvData.length) { + throw new TlvInvalidAccountDataError(); + } + if (type.equals(discriminator)) { + entries.push(this.tlvData.subarray(offset, offset + entryLength)); + } + if (count > 0 && entries.length >= count) { + break; + } + offset += entryLength; + } + return entries; + } + + /** + * Get all the discriminators from the tlv data. This function will return a type multiple times if it occurs multiple times in the tlv data. + * + * @return a list of the discriminators. + */ + public discriminators(): Buffer[] { + const discriminators: Buffer[] = []; + let offset = 0; + while (offset < this.tlvData.length) { + if (offset + this.discriminatorSize + this.lengthSize > this.tlvData.length) { + throw new TlvInvalidAccountDataError(); + } + const type = this.tlvData.subarray(offset, offset + this.discriminatorSize); + discriminators.push(type); + offset += this.discriminatorSize; + const entryLength = this.readEntryLength(this.lengthSize, offset, Number); + offset += this.lengthSize; + if (offset + entryLength > this.tlvData.length) { + throw new TlvInvalidAccountDataError(); + } + offset += entryLength; + } + return discriminators; + } +} diff --git a/type-length-value/js/test/splDiscriminate.test.ts b/type-length-value/js/test/splDiscriminate.test.ts new file mode 100644 index 00000000..8d2d62b8 --- /dev/null +++ b/type-length-value/js/test/splDiscriminate.test.ts @@ -0,0 +1,38 @@ +import { expect } from 'chai'; +import { splDiscriminate } from '../src/splDiscriminate'; + +const testVectors = [ + 'hello', + 'this-is-a-test', + 'test-namespace:this-is-a-test', + 'test-namespace:this-is-a-test:with-a-longer-name', +]; + +const testExpectedBytes = await Promise.all( + testVectors.map(x => + crypto.subtle.digest('SHA-256', new TextEncoder().encode(x)).then(digest => new Uint8Array(digest)), + ), +); + +describe('splDiscrimintor', () => { + const testSplDiscriminator = async (length: number) => { + for (let i = 0; i < testVectors.length; i++) { + const discriminator = await splDiscriminate(testVectors[i], length); + const expectedBytes = testExpectedBytes[i].subarray(0, length); + expect(discriminator).to.have.length(length); + expect(discriminator).to.deep.equal(expectedBytes); + } + }; + + it('should produce the expected bytes', () => { + testSplDiscriminator(8); + testSplDiscriminator(4); + testSplDiscriminator(2); + }); + + it('should produce the same bytes as rust library', async () => { + const expectedBytes = Buffer.from([105, 37, 101, 197, 75, 251, 102, 26]); + const discriminator = await splDiscriminate('spl-transfer-hook-interface:execute'); + expect(discriminator).to.deep.equal(expectedBytes); + }); +}); diff --git a/type-length-value/js/test/tlvData.test.ts b/type-length-value/js/test/tlvData.test.ts new file mode 100644 index 00000000..d43604ff --- /dev/null +++ 
b/type-length-value/js/test/tlvData.test.ts @@ -0,0 +1,147 @@ +import type { LengthSize } from '../src/tlvState'; +import { TlvState } from '../src/tlvState'; +import { expect } from 'chai'; + +describe('tlvData', () => { + // typeLength 1, lengthSize 2 + const tlvData1 = Buffer.concat([ + Buffer.from([0]), + Buffer.from([0, 0]), + Buffer.from([]), + Buffer.from([1]), + Buffer.from([1, 0]), + Buffer.from([1]), + Buffer.from([2]), + Buffer.from([2, 0]), + Buffer.from([1, 2]), + Buffer.from([0]), + Buffer.from([3, 0]), + Buffer.from([1, 2, 3]), + ]); + + // typeLength 2, lengthSize 1 + const tlvData2 = Buffer.concat([ + Buffer.from([0, 0]), + Buffer.from([0]), + Buffer.from([]), + Buffer.from([1, 0]), + Buffer.from([1]), + Buffer.from([1]), + Buffer.from([2, 0]), + Buffer.from([2]), + Buffer.from([1, 2]), + Buffer.from([0, 0]), + Buffer.from([3]), + Buffer.from([1, 2, 3]), + ]); + + // typeLength 4, lengthSize 8 + const tlvData3 = Buffer.concat([ + Buffer.from([0, 0, 0, 0]), + Buffer.from([0, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([]), + Buffer.from([1, 0, 0, 0]), + Buffer.from([1, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([1]), + Buffer.from([2, 0, 0, 0]), + Buffer.from([2, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([1, 2]), + Buffer.from([0, 0, 0, 0]), + Buffer.from([3, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([1, 2, 3]), + ]); + + // typeLength 8, lengthSize 4 + const tlvData4 = Buffer.concat([ + Buffer.from([0, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([0, 0, 0, 0]), + Buffer.from([]), + Buffer.from([1, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([1, 0, 0, 0]), + Buffer.from([1]), + Buffer.from([2, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([2, 0, 0, 0]), + Buffer.from([1, 2]), + Buffer.from([0, 0, 0, 0, 0, 0, 0, 0]), + Buffer.from([3, 0, 0, 0]), + Buffer.from([1, 2, 3]), + ]); + + const testRawData = (tlvData: Buffer, discriminatorSize: number, lengthSize: LengthSize) => { + const tlv = new TlvState(tlvData, discriminatorSize, lengthSize); + expect(tlv.data).to.be.deep.equal(tlvData); + const tlvWithOffset = new TlvState(tlvData, discriminatorSize, lengthSize, discriminatorSize + lengthSize); + expect(tlvWithOffset.data).to.be.deep.equal(tlvData.subarray(discriminatorSize + lengthSize)); + }; + + it('should get the raw tlv data', () => { + testRawData(tlvData1, 1, 2); + testRawData(tlvData2, 2, 1); + testRawData(tlvData3, 4, 8); + testRawData(tlvData4, 8, 4); + }); + + const testIndividualEntries = (tlvData: Buffer, discriminatorSize: number, lengthSize: LengthSize) => { + const tlv = new TlvState(tlvData, discriminatorSize, lengthSize); + + const type = Buffer.alloc(discriminatorSize); + type[0] = 0; + expect(tlv.firstBytes(type)).to.be.deep.equal(Buffer.from([])); + type[0] = 1; + expect(tlv.firstBytes(type)).to.be.deep.equal(Buffer.from([1])); + type[0] = 2; + expect(tlv.firstBytes(type)).to.be.deep.equal(Buffer.from([1, 2])); + type[0] = 3; + expect(tlv.firstBytes(type)).to.equal(null); + }; + + it('should get the entries individually', () => { + testIndividualEntries(tlvData1, 1, 2); + testIndividualEntries(tlvData2, 2, 1); + testIndividualEntries(tlvData3, 4, 8); + testIndividualEntries(tlvData4, 8, 4); + }); + + const testRepeatingEntries = (tlvData: Buffer, discriminatorSize: number, lengthSize: LengthSize) => { + const tlv = new TlvState(tlvData, discriminatorSize, lengthSize); + + const bufferDiscriminator = tlv.bytesRepeating(Buffer.alloc(discriminatorSize)); + expect(bufferDiscriminator).to.have.length(2); + expect(bufferDiscriminator[0]).to.be.deep.equal(Buffer.from([])); + 
expect(bufferDiscriminator[1]).to.be.deep.equal(Buffer.from([1, 2, 3])); + + const bufferDiscriminatorWithCount = tlv.bytesRepeating(Buffer.alloc(discriminatorSize), 1); + expect(bufferDiscriminatorWithCount).to.have.length(1); + expect(bufferDiscriminatorWithCount[0]).to.be.deep.equal(Buffer.from([])); + }; + + it('should get the repeating entries', () => { + testRepeatingEntries(tlvData1, 1, 2); + testRepeatingEntries(tlvData2, 2, 1); + testRepeatingEntries(tlvData3, 4, 8); + testRepeatingEntries(tlvData4, 8, 4); + }); + + const testDiscriminators = (tlvData: Buffer, discriminatorSize: number, lengthSize: LengthSize) => { + const tlv = new TlvState(tlvData, discriminatorSize, lengthSize); + const discriminators = tlv.discriminators(); + expect(discriminators).to.have.length(4); + + const type = Buffer.alloc(discriminatorSize); + type[0] = 0; + expect(discriminators[0]).to.be.deep.equal(type); + type[0] = 1; + expect(discriminators[1]).to.be.deep.equal(type); + type[0] = 2; + expect(discriminators[2]).to.be.deep.equal(type); + type[0] = 0; + expect(discriminators[3]).to.be.deep.equal(type); + }; + + it('should get the discriminators', () => { + testDiscriminators(tlvData1, 1, 2); + testDiscriminators(tlvData2, 2, 1); + testDiscriminators(tlvData3, 4, 8); + testDiscriminators(tlvData4, 8, 4); + }); +}); diff --git a/type-length-value/js/tsconfig.all.json b/type-length-value/js/tsconfig.all.json new file mode 100644 index 00000000..98551325 --- /dev/null +++ b/type-length-value/js/tsconfig.all.json @@ -0,0 +1,11 @@ +{ + "extends": "./tsconfig.root.json", + "references": [ + { + "path": "./tsconfig.cjs.json" + }, + { + "path": "./tsconfig.esm.json" + } + ] +} diff --git a/type-length-value/js/tsconfig.base.json b/type-length-value/js/tsconfig.base.json new file mode 100644 index 00000000..90620c4e --- /dev/null +++ b/type-length-value/js/tsconfig.base.json @@ -0,0 +1,14 @@ +{ + "include": [], + "compilerOptions": { + "target": "ESNext", + "module": "ESNext", + "moduleResolution": "Node", + "esModuleInterop": true, + "isolatedModules": true, + "noEmitOnError": true, + "resolveJsonModule": true, + "strict": true, + "stripInternal": true + } +} diff --git a/type-length-value/js/tsconfig.cjs.json b/type-length-value/js/tsconfig.cjs.json new file mode 100644 index 00000000..2db9b715 --- /dev/null +++ b/type-length-value/js/tsconfig.cjs.json @@ -0,0 +1,10 @@ +{ + "extends": "./tsconfig.base.json", + "include": ["src"], + "compilerOptions": { + "outDir": "lib/cjs", + "target": "ES2016", + "module": "CommonJS", + "sourceMap": true + } +} diff --git a/type-length-value/js/tsconfig.esm.json b/type-length-value/js/tsconfig.esm.json new file mode 100644 index 00000000..25e7e25e --- /dev/null +++ b/type-length-value/js/tsconfig.esm.json @@ -0,0 +1,13 @@ +{ + "extends": "./tsconfig.base.json", + "include": ["src"], + "compilerOptions": { + "outDir": "lib/esm", + "declarationDir": "lib/types", + "target": "ES2020", + "module": "ES2020", + "sourceMap": true, + "declaration": true, + "declarationMap": true + } +} diff --git a/type-length-value/js/tsconfig.json b/type-length-value/js/tsconfig.json new file mode 100644 index 00000000..2f9b239b --- /dev/null +++ b/type-length-value/js/tsconfig.json @@ -0,0 +1,8 @@ +{ + "extends": "./tsconfig.all.json", + "include": ["src", "test"], + "compilerOptions": { + "noEmit": true, + "skipLibCheck": true + } +} diff --git a/type-length-value/js/tsconfig.root.json b/type-length-value/js/tsconfig.root.json new file mode 100644 index 00000000..fadf294a --- /dev/null +++ 
b/type-length-value/js/tsconfig.root.json @@ -0,0 +1,6 @@ +{ + "extends": "./tsconfig.base.json", + "compilerOptions": { + "composite": true + } +} diff --git a/type-length-value/js/typedoc.json b/type-length-value/js/typedoc.json new file mode 100644 index 00000000..c39fc53a --- /dev/null +++ b/type-length-value/js/typedoc.json @@ -0,0 +1,5 @@ +{ + "entryPoints": ["src/index.ts"], + "out": "docs", + "readme": "README.md" +} diff --git a/type-length-value/src/error.rs b/type-length-value/src/error.rs new file mode 100644 index 00000000..203b200c --- /dev/null +++ b/type-length-value/src/error.rs @@ -0,0 +1,50 @@ +//! Error types +use { + solana_decode_error::DecodeError, + solana_msg::msg, + solana_program_error::{PrintProgramError, ProgramError}, +}; + +/// Errors that may be returned by the Token program. +#[repr(u32)] +#[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)] +pub enum TlvError { + /// Type not found in TLV data + #[error("Type not found in TLV data")] + TypeNotFound = 1_202_666_432, + /// Type already exists in TLV data + #[error("Type already exists in TLV data")] + TypeAlreadyExists, +} + +impl From for ProgramError { + fn from(e: TlvError) -> Self { + ProgramError::Custom(e as u32) + } +} + +impl DecodeError for TlvError { + fn type_of() -> &'static str { + "TlvError" + } +} + +impl PrintProgramError for TlvError { + fn print(&self) + where + E: 'static + + std::error::Error + + DecodeError + + PrintProgramError + + num_traits::FromPrimitive, + { + match self { + TlvError::TypeNotFound => { + msg!("Type not found in TLV data") + } + TlvError::TypeAlreadyExists => { + msg!("Type already exists in TLV data") + } + } + } +} diff --git a/type-length-value/src/length.rs b/type-length-value/src/length.rs new file mode 100644 index 00000000..3bab4a26 --- /dev/null +++ b/type-length-value/src/length.rs @@ -0,0 +1,25 @@ +//! Module for the length portion of a Type-Length-Value structure +use { + bytemuck::{Pod, Zeroable}, + solana_program_error::ProgramError, + spl_pod::primitives::PodU32, +}; + +/// Length in TLV structure +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct Length(PodU32); +impl TryFrom for usize { + type Error = ProgramError; + fn try_from(n: Length) -> Result { + Self::try_from(u32::from(n.0)).map_err(|_| ProgramError::AccountDataTooSmall) + } +} +impl TryFrom for Length { + type Error = ProgramError; + fn try_from(n: usize) -> Result { + u32::try_from(n) + .map(|v| Self(PodU32::from(v))) + .map_err(|_| ProgramError::AccountDataTooSmall) + } +} diff --git a/type-length-value/src/lib.rs b/type-length-value/src/lib.rs new file mode 100644 index 00000000..ed7f0a2a --- /dev/null +++ b/type-length-value/src/lib.rs @@ -0,0 +1,18 @@ +//! Crate defining an interface for managing type-length-value entries in a slab +//! of bytes, to be used with Solana accounts. 
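+//!
+//! As a quick orientation (a minimal sketch; the `state` module docs and the
+//! README walk through the full API), a TLV buffer is a sequence of
+//! `[discriminator | length | value]` entries, where the discriminator is an
+//! arbitrary 8 bytes and the length is a little-endian `u32`:
+//!
+//! ```
+//! use {
+//!     spl_discriminator::ArrayDiscriminator,
+//!     spl_type_length_value::state::{TlvState, TlvStateBorrowed},
+//! };
+//!
+//! let buffer = [
+//!     42, 42, 42, 42, 42, 42, 42, 42, // discriminator (arbitrary for this sketch)
+//!     2, 0, 0, 0,                     // length = 2
+//!     7, 9,                           // value bytes
+//! ];
+//! let state = TlvStateBorrowed::unpack(&buffer).unwrap();
+//! assert_eq!(
+//!     state.get_discriminators().unwrap(),
+//!     vec![ArrayDiscriminator::new([42; 8])]
+//! );
+//! ```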
+ +#![allow(clippy::arithmetic_side_effects)] +#![deny(missing_docs)] +#![cfg_attr(not(test), forbid(unsafe_code))] + +pub mod error; +pub mod length; +pub mod state; +pub mod variable_len_pack; + +// Export current sdk types for downstream users building with a different sdk +// version +// Expose derive macro on feature flag +#[cfg(feature = "derive")] +pub use spl_type_length_value_derive::SplBorshVariableLenPack; +pub use {solana_account_info, solana_decode_error, solana_program_error}; diff --git a/type-length-value/src/state.rs b/type-length-value/src/state.rs new file mode 100644 index 00000000..3016580e --- /dev/null +++ b/type-length-value/src/state.rs @@ -0,0 +1,1365 @@ +//! Type-length-value structure definition and manipulation + +use { + crate::{error::TlvError, length::Length, variable_len_pack::VariableLenPack}, + bytemuck::Pod, + solana_account_info::AccountInfo, + solana_program_error::ProgramError, + spl_discriminator::{ArrayDiscriminator, SplDiscriminate}, + spl_pod::bytemuck::{pod_from_bytes, pod_from_bytes_mut}, + std::{cmp::Ordering, mem::size_of}, +}; + +/// Get the current TlvIndices from the current spot +const fn get_indices_unchecked(type_start: usize, value_repetition_number: usize) -> TlvIndices { + let length_start = type_start.saturating_add(size_of::()); + let value_start = length_start.saturating_add(size_of::()); + TlvIndices { + type_start, + length_start, + value_start, + value_repetition_number, + } +} + +/// Internal helper struct for returning the indices of the type, length, and +/// value in a TLV entry +#[derive(Debug)] +struct TlvIndices { + pub type_start: usize, + pub length_start: usize, + pub value_start: usize, + pub value_repetition_number: usize, +} + +fn get_indices( + tlv_data: &[u8], + value_discriminator: ArrayDiscriminator, + init: bool, + repetition_number: Option, +) -> Result { + let mut current_repetition_number = 0; + let mut start_index = 0; + while start_index < tlv_data.len() { + let tlv_indices = get_indices_unchecked(start_index, current_repetition_number); + if tlv_data.len() < tlv_indices.value_start { + return Err(ProgramError::InvalidAccountData); + } + let discriminator = ArrayDiscriminator::try_from( + &tlv_data[tlv_indices.type_start..tlv_indices.length_start], + )?; + if discriminator == value_discriminator { + if let Some(desired_repetition_number) = repetition_number { + if current_repetition_number == desired_repetition_number { + return Ok(tlv_indices); + } + } + current_repetition_number += 1; + // got to an empty spot, init here, or error if we're searching, since + // nothing is written after an Uninitialized spot + } else if discriminator == ArrayDiscriminator::UNINITIALIZED { + if init { + return Ok(tlv_indices); + } else { + return Err(TlvError::TypeNotFound.into()); + } + } + let length = + pod_from_bytes::(&tlv_data[tlv_indices.length_start..tlv_indices.value_start])?; + let value_end_index = tlv_indices + .value_start + .saturating_add(usize::try_from(*length)?); + start_index = value_end_index; + } + Err(ProgramError::InvalidAccountData) +} + +// This function is doing two separate things at once, and would probably be +// better served by some custom iterator, but let's leave that for another day. 
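+// Concretely, it walks the TLV entries to (1) collect every discriminator it
+// encounters and (2) report the index at which initialized entries end, i.e.
+// the start of any zeroed, uninitialized tail (or the end of the buffer).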
+fn get_discriminators_and_end_index( + tlv_data: &[u8], +) -> Result<(Vec, usize), ProgramError> { + let mut discriminators = vec![]; + let mut start_index = 0; + while start_index < tlv_data.len() { + // This function is not concerned with repetitions, so we can just + // arbitrarily pass `0` here + let tlv_indices = get_indices_unchecked(start_index, 0); + if tlv_data.len() < tlv_indices.length_start { + // we got to the end, but there might be some uninitialized data after + let remainder = &tlv_data[tlv_indices.type_start..]; + if remainder.iter().all(|&x| x == 0) { + return Ok((discriminators, tlv_indices.type_start)); + } else { + return Err(ProgramError::InvalidAccountData); + } + } + let discriminator = ArrayDiscriminator::try_from( + &tlv_data[tlv_indices.type_start..tlv_indices.length_start], + )?; + if discriminator == ArrayDiscriminator::UNINITIALIZED { + return Ok((discriminators, tlv_indices.type_start)); + } else { + if tlv_data.len() < tlv_indices.value_start { + // not enough bytes to store the length, malformed + return Err(ProgramError::InvalidAccountData); + } + discriminators.push(discriminator); + let length = pod_from_bytes::( + &tlv_data[tlv_indices.length_start..tlv_indices.value_start], + )?; + + let value_end_index = tlv_indices + .value_start + .saturating_add(usize::try_from(*length)?); + if value_end_index > tlv_data.len() { + // value blows past the size of the slice, malformed + return Err(ProgramError::InvalidAccountData); + } + start_index = value_end_index; + } + } + Ok((discriminators, start_index)) +} + +fn get_bytes( + tlv_data: &[u8], + repetition_number: usize, +) -> Result<&[u8], ProgramError> { + let TlvIndices { + type_start: _, + length_start, + value_start, + value_repetition_number: _, + } = get_indices( + tlv_data, + V::SPL_DISCRIMINATOR, + false, + Some(repetition_number), + )?; + // get_indices has checked that tlv_data is long enough to include these + // indices + let length = pod_from_bytes::(&tlv_data[length_start..value_start])?; + let value_end = value_start.saturating_add(usize::try_from(*length)?); + if tlv_data.len() < value_end { + return Err(ProgramError::InvalidAccountData); + } + Ok(&tlv_data[value_start..value_end]) +} + +/// Trait for all TLV state +/// +/// Stores data as any number of type-length-value structures underneath, where: +/// +/// * the "type" is an `ArrayDiscriminator`, 8 bytes +/// * the "length" is a `Length`, 4 bytes +/// * the "value" is a slab of "length" bytes +/// +/// With this structure, it's possible to hold onto any number of entries with +/// unique discriminators, provided that the total underlying data has enough +/// bytes for every entry. 
+/// +/// For example, if we have two distinct types, one which is an 8-byte array +/// of value `[0, 1, 0, 0, 0, 0, 0, 0]` and discriminator +/// `[1, 1, 1, 1, 1, 1, 1, 1]`, and another which is just a single `u8` of value +/// `4` with the discriminator `[2, 2, 2, 2, 2, 2, 2, 2]`, we can deserialize +/// this buffer as follows: +/// +/// ``` +/// use { +/// bytemuck::{Pod, Zeroable}, +/// spl_discriminator::{ArrayDiscriminator, SplDiscriminate}, +/// spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, +/// }; +/// #[repr(C)] +/// #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +/// struct MyPodValue { +/// data: [u8; 8], +/// } +/// impl SplDiscriminate for MyPodValue { +/// const SPL_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([1; ArrayDiscriminator::LENGTH]); +/// } +/// #[repr(C)] +/// #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +/// struct MyOtherPodValue { +/// data: u8, +/// } +/// impl SplDiscriminate for MyOtherPodValue { +/// const SPL_DISCRIMINATOR: ArrayDiscriminator = ArrayDiscriminator::new([2; ArrayDiscriminator::LENGTH]); +/// } +/// let buffer = [ +/// 1, 1, 1, 1, 1, 1, 1, 1, // first type's discriminator +/// 8, 0, 0, 0, // first type's length +/// 0, 1, 0, 0, 0, 0, 0, 0, // first type's value +/// 2, 2, 2, 2, 2, 2, 2, 2, // second type's discriminator +/// 1, 0, 0, 0, // second type's length +/// 4, // second type's value +/// ]; +/// let state = TlvStateBorrowed::unpack(&buffer).unwrap(); +/// let value = state.get_first_value::().unwrap(); +/// assert_eq!(value.data, [0, 1, 0, 0, 0, 0, 0, 0]); +/// let value = state.get_first_value::().unwrap(); +/// assert_eq!(value.data, 4); +/// ``` +/// +/// See the README and tests for more examples on how to use these types. 
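+///
+/// Note that a discriminator may appear more than once in the buffer; the
+/// `*_with_repetition` methods select a particular occurrence by its zero-based
+/// repetition number, while the `get_first_*` methods return the first match.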
+pub trait TlvState {
+    /// Get the full buffer containing all TLV data
+    fn get_data(&self) -> &[u8];
+
+    /// Unpack a portion of the TLV data as the desired Pod type for the
+    /// entry number specified
+    fn get_value_with_repetition<V: SplDiscriminate + Pod>(
+        &self,
+        repetition_number: usize,
+    ) -> Result<&V, ProgramError> {
+        let data = get_bytes::<V>(self.get_data(), repetition_number)?;
+        pod_from_bytes::<V>(data)
+    }
+
+    /// Unpack a portion of the TLV data as the desired Pod type for the first
+    /// entry found
+    fn get_first_value<V: SplDiscriminate + Pod>(&self) -> Result<&V, ProgramError> {
+        self.get_value_with_repetition::<V>(0)
+    }
+
+    /// Unpacks a portion of the TLV data as the desired variable-length type
+    /// for the entry number specified
+    fn get_variable_len_value_with_repetition<V: SplDiscriminate + VariableLenPack>(
+        &self,
+        repetition_number: usize,
+    ) -> Result<V, ProgramError> {
+        let data = get_bytes::<V>(self.get_data(), repetition_number)?;
+        V::unpack_from_slice(data)
+    }
+
+    /// Unpacks a portion of the TLV data as the desired variable-length type
+    /// for the first entry found
+    fn get_first_variable_len_value<V: SplDiscriminate + VariableLenPack>(
+        &self,
+    ) -> Result<V, ProgramError> {
+        self.get_variable_len_value_with_repetition::<V>(0)
+    }
+
+    /// Unpack a portion of the TLV data as bytes for the entry number specified
+    fn get_bytes_with_repetition<V: SplDiscriminate>(
+        &self,
+        repetition_number: usize,
+    ) -> Result<&[u8], ProgramError> {
+        get_bytes::<V>(self.get_data(), repetition_number)
+    }
+
+    /// Unpack a portion of the TLV data as bytes for the first entry found
+    fn get_first_bytes<V: SplDiscriminate>(&self) -> Result<&[u8], ProgramError> {
+        self.get_bytes_with_repetition::<V>(0)
+    }
+
+    /// Iterates through the TLV entries, returning only the types
+    fn get_discriminators(&self) -> Result<Vec<ArrayDiscriminator>, ProgramError> {
+        get_discriminators_and_end_index(self.get_data()).map(|v| v.0)
+    }
+
+    /// Get the base size required for TLV data
+    fn get_base_len() -> usize {
+        get_base_len()
+    }
+}
+
+/// Encapsulates owned TLV data
+#[derive(Debug, PartialEq)]
+pub struct TlvStateOwned {
+    /// Raw TLV data, deserialized on demand
+    data: Vec<u8>,
+}
+impl TlvStateOwned {
+    /// Unpacks TLV state data
+    ///
+    /// Fails if no state is initialized or if data is too small
+    pub fn unpack(data: Vec<u8>) -> Result<Self, ProgramError> {
+        check_data(&data)?;
+        Ok(Self { data })
+    }
+}
+impl TlvState for TlvStateOwned {
+    fn get_data(&self) -> &[u8] {
+        &self.data
+    }
+}
+
+/// Encapsulates immutable base state data (mint or account) with possible
+/// extensions
+#[derive(Debug, PartialEq)]
+pub struct TlvStateBorrowed<'data> {
+    /// Slice of data containing all TLV data, deserialized on demand
+    data: &'data [u8],
+}
+impl<'data> TlvStateBorrowed<'data> {
+    /// Unpacks TLV state data
+    ///
+    /// Fails if no state is initialized or if data is too small
+    pub fn unpack(data: &'data [u8]) -> Result<Self, ProgramError> {
+        check_data(data)?;
+        Ok(Self { data })
+    }
+}
+impl<'a> TlvState for TlvStateBorrowed<'a> {
+    fn get_data(&self) -> &[u8] {
+        self.data
+    }
+}
+
+/// Encapsulates mutable base state data (mint or account) with possible
+/// extensions
+#[derive(Debug, PartialEq)]
+pub struct TlvStateMut<'data> {
+    /// Slice of data containing all TLV data, deserialized on demand
+    data: &'data mut [u8],
+}
+impl<'data> TlvStateMut<'data> {
+    /// Unpacks TLV state data
+    ///
+    /// Fails if no state is initialized or if data is too small
+    pub fn unpack(data: &'data mut [u8]) -> Result<Self, ProgramError> {
+        check_data(data)?;
+        Ok(Self { data })
+    }
+
+    /// Unpack a portion of the TLV data as the desired type that allows
+    /// modifying the type for the entry number specified
+    pub fn get_value_with_repetition_mut<V: SplDiscriminate + Pod>(
+        &mut self,
+        repetition_number: usize,
+    ) -> Result<&mut V, ProgramError> {
+        let data = self.get_bytes_with_repetition_mut::<V>(repetition_number)?;
+        pod_from_bytes_mut::<V>(data)
+    }
+
+    /// Unpack a portion of the TLV data as the desired type that allows
+    /// modifying the type for the first entry found
+    pub fn get_first_value_mut<V: SplDiscriminate + Pod>(
+        &mut self,
+    ) -> Result<&mut V, ProgramError> {
+        self.get_value_with_repetition_mut::<V>(0)
+    }
+
+    /// Unpack a portion of the TLV data as mutable bytes for the entry number
+    /// specified
+    pub fn get_bytes_with_repetition_mut<V: SplDiscriminate>(
+        &mut self,
+        repetition_number: usize,
+    ) -> Result<&mut [u8], ProgramError> {
+        let TlvIndices {
+            type_start: _,
+            length_start,
+            value_start,
+            value_repetition_number: _,
+        } = get_indices(
+            self.data,
+            V::SPL_DISCRIMINATOR,
+            false,
+            Some(repetition_number),
+        )?;
+
+        let length = pod_from_bytes::<Length>(&self.data[length_start..value_start])?;
+        let value_end = value_start.saturating_add(usize::try_from(*length)?);
+        if self.data.len() < value_end {
+            return Err(ProgramError::InvalidAccountData);
+        }
+        Ok(&mut self.data[value_start..value_end])
+    }
+
+    /// Unpack a portion of the TLV data as mutable bytes for the first entry
+    /// found
+    pub fn get_first_bytes_mut<V: SplDiscriminate>(&mut self) -> Result<&mut [u8], ProgramError> {
+        self.get_bytes_with_repetition_mut::<V>(0)
+    }
+
+    /// Packs the default TLV data into the first open slot in the data buffer.
+    /// Handles repetition based on the boolean arg provided:
+    /// * `true`: Adds a new entry to the next open slot, even if the
+    ///   discriminator is already present in the buffer.
+    /// * `false`: Returns an error if an entry with the same discriminator is
+    ///   already found in the buffer.
+    pub fn init_value<V: SplDiscriminate + Pod + Default>(
+        &mut self,
+        allow_repetition: bool,
+    ) -> Result<(&mut V, usize), ProgramError> {
+        let length = size_of::<V>();
+        let (buffer, repetition_number) = self.alloc::<V>(length, allow_repetition)?;
+        let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
+        *extension_ref = V::default();
+        Ok((extension_ref, repetition_number))
+    }
+
+    /// Packs a variable-length value into its appropriate data segment, where
+    /// repeating discriminators _are_ allowed
+    pub fn pack_variable_len_value_with_repetition<V: SplDiscriminate + VariableLenPack>(
+        &mut self,
+        value: &V,
+        repetition_number: usize,
+    ) -> Result<(), ProgramError> {
+        let data = self.get_bytes_with_repetition_mut::<V>(repetition_number)?;
+        // NOTE: Do *not* use `pack`, since the length check will cause
+        // reallocations to smaller sizes to fail
+        value.pack_into_slice(data)
+    }
+
+    /// Packs a variable-length value into its appropriate data segment, where
+    /// no repeating discriminators are allowed
+    pub fn pack_first_variable_len_value<V: SplDiscriminate + VariableLenPack>(
+        &mut self,
+        value: &V,
+    ) -> Result<(), ProgramError> {
+        self.pack_variable_len_value_with_repetition::<V>(value, 0)
+    }
+
+    /// Allocate the given number of bytes for the given SplDiscriminate
+    pub fn alloc<V: SplDiscriminate>(
+        &mut self,
+        length: usize,
+        allow_repetition: bool,
+    ) -> Result<(&mut [u8], usize), ProgramError> {
+        let TlvIndices {
+            type_start,
+            length_start,
+            value_start,
+            value_repetition_number,
+        } = get_indices(
+            self.data,
+            V::SPL_DISCRIMINATOR,
+            true,
+            if allow_repetition { None } else { Some(0) },
+        )?;
+
+        let discriminator = ArrayDiscriminator::try_from(&self.data[type_start..length_start])?;
+        if discriminator == ArrayDiscriminator::UNINITIALIZED {
+            // write type
+            let discriminator_ref = &mut self.data[type_start..length_start];
+            discriminator_ref.copy_from_slice(V::SPL_DISCRIMINATOR.as_ref());
+            // write length
+            let length_ref =
+                pod_from_bytes_mut::<Length>(&mut self.data[length_start..value_start])?;
+            *length_ref = Length::try_from(length)?;
+
+            let value_end = value_start.saturating_add(length);
+            if self.data.len() < value_end {
+                return Err(ProgramError::InvalidAccountData);
+            }
+            Ok((
+                &mut self.data[value_start..value_end],
+                value_repetition_number,
+            ))
+        } else {
+            Err(TlvError::TypeAlreadyExists.into())
+        }
+    }
+
+    /// Allocates and serializes a new TLV entry from a `VariableLenPack` type
+    pub fn alloc_and_pack_variable_len_entry<V: SplDiscriminate + VariableLenPack>(
+        &mut self,
+        value: &V,
+        allow_repetition: bool,
+    ) -> Result<usize, ProgramError> {
+        let length = value.get_packed_len()?;
+        let (data, repetition_number) = self.alloc::<V>(length, allow_repetition)?;
+        value.pack_into_slice(data)?;
+        Ok(repetition_number)
+    }
+
+    /// Reallocate the given number of bytes for the given SplDiscriminate. If
+    /// the new length is smaller, it will compact the rest of the buffer
+    /// and zero out the difference at the end. If it's larger, it will move
+    /// the rest of the buffer data and zero out the new data.
+    pub fn realloc_with_repetition<V: SplDiscriminate>(
+        &mut self,
+        length: usize,
+        repetition_number: usize,
+    ) -> Result<&mut [u8], ProgramError> {
+        let TlvIndices {
+            type_start: _,
+            length_start,
+            value_start,
+            value_repetition_number: _,
+        } = get_indices(
+            self.data,
+            V::SPL_DISCRIMINATOR,
+            false,
+            Some(repetition_number),
+        )?;
+        let (_, end_index) = get_discriminators_and_end_index(self.data)?;
+        let data_len = self.data.len();
+
+        let length_ref = pod_from_bytes_mut::<Length>(&mut self.data[length_start..value_start])?;
+        let old_length = usize::try_from(*length_ref)?;
+
+        // check that we're not going to panic during `copy_within`
+        if old_length < length {
+            let new_end_index = end_index.saturating_add(length.saturating_sub(old_length));
+            if new_end_index > data_len {
+                return Err(ProgramError::InvalidAccountData);
+            }
+        }
+
+        // write new length after the check, to avoid getting into a bad situation
+        // if trying to recover from an error
+        *length_ref = Length::try_from(length)?;
+
+        let old_value_end = value_start.saturating_add(old_length);
+        let new_value_end = value_start.saturating_add(length);
+        self.data
+            .copy_within(old_value_end..end_index, new_value_end);
+        match old_length.cmp(&length) {
+            Ordering::Greater => {
+                // realloc to smaller, fill the end
+                let new_end_index = end_index.saturating_sub(old_length.saturating_sub(length));
+                self.data[new_end_index..end_index].fill(0);
+            }
+            Ordering::Less => {
+                // realloc to bigger, fill the moved part
+                self.data[old_value_end..new_value_end].fill(0);
+            }
+            Ordering::Equal => {} // nothing needed!
+        }
+
+        Ok(&mut self.data[value_start..new_value_end])
+    }
+
+    /// Reallocate the given number of bytes for the given SplDiscriminate,
+    /// where no repeating discriminators are allowed
+    pub fn realloc_first<V: SplDiscriminate>(
+        &mut self,
+        length: usize,
+    ) -> Result<&mut [u8], ProgramError> {
+        self.realloc_with_repetition::<V>(length, 0)
+    }
+}
+
+impl<'a> TlvState for TlvStateMut<'a> {
+    fn get_data(&self) -> &[u8] {
+        self.data
+    }
+}
+
+/// Packs a variable-length value into an existing TLV space, reallocating
+/// the account and TLV as needed to accommodate for any change in space
+pub fn realloc_and_pack_variable_len_with_repetition<V: SplDiscriminate + VariableLenPack>(
+    account_info: &AccountInfo,
+    value: &V,
+    repetition_number: usize,
+) -> Result<(), ProgramError> {
+    let previous_length = {
+        let data = account_info.try_borrow_data()?;
+        let TlvIndices {
+            type_start: _,
+            length_start,
+            value_start,
+            value_repetition_number: _,
+        } = get_indices(&data, V::SPL_DISCRIMINATOR, false, Some(repetition_number))?;
+        usize::try_from(*pod_from_bytes::<Length>(&data[length_start..value_start])?)?
+    };
+    let new_length = value.get_packed_len()?;
+    let previous_account_size = account_info.try_data_len()?;
+    if previous_length < new_length {
+        // size increased, so realloc the account, then the TLV entry, then write data
+        let additional_bytes = new_length
+            .checked_sub(previous_length)
+            .ok_or(ProgramError::AccountDataTooSmall)?;
+        account_info.realloc(previous_account_size.saturating_add(additional_bytes), true)?;
+        let mut buffer = account_info.try_borrow_mut_data()?;
+        let mut state = TlvStateMut::unpack(&mut buffer)?;
+        state.realloc_with_repetition::<V>(new_length, repetition_number)?;
+        state.pack_variable_len_value_with_repetition(value, repetition_number)?;
+    } else {
+        // do it backwards otherwise, write the state, realloc TLV, then the account
+        let mut buffer = account_info.try_borrow_mut_data()?;
+        let mut state = TlvStateMut::unpack(&mut buffer)?;
+        state.pack_variable_len_value_with_repetition(value, repetition_number)?;
+        let removed_bytes = previous_length
+            .checked_sub(new_length)
+            .ok_or(ProgramError::AccountDataTooSmall)?;
+        if removed_bytes > 0 {
+            // we decreased the size, so need to realloc the TLV, then the account
+            state.realloc_with_repetition::<V>(new_length, repetition_number)?;
+            // this is probably fine, but be safe and avoid invalidating references
+            drop(buffer);
+            account_info.realloc(previous_account_size.saturating_sub(removed_bytes), false)?;
+        }
+    }
+    Ok(())
+}
+
+/// Packs a variable-length value into an existing TLV space, where no
+/// repeating discriminators are allowed
+pub fn realloc_and_pack_first_variable_len<V: SplDiscriminate + VariableLenPack>(
+    account_info: &AccountInfo,
+    value: &V,
+) -> Result<(), ProgramError> {
+    realloc_and_pack_variable_len_with_repetition::<V>(account_info, value, 0)
+}
+
+/// Get the base size required for TLV data
+const fn get_base_len() -> usize {
+    get_indices_unchecked(0, 0).value_start
+}
+
+fn check_data(tlv_data: &[u8]) -> Result<(), ProgramError> {
+    // should be able to iterate through all entries in the TLV structure
+    let _ = get_discriminators_and_end_index(tlv_data)?;
+    Ok(())
+}
+
+#[cfg(test)]
+mod test {
+    use {
+        super::*,
+        bytemuck::{Pod, Zeroable},
+    };
+
+    const TEST_BUFFER: &[u8] = &[
+        1, 1, 1, 1, 1, 1, 1, 1, // discriminator
+        32, 0, 0, 0, // length
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, // value
+        0, 0, // empty, not enough for a discriminator
+    ];
+
+    const TEST_BIG_BUFFER: &[u8] = &[
+        1, 1, 1, 1, 1, 1, 1, 1, // discriminator
+        32, 0, 0, 0, // length
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, // value
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, // empty, but enough for a discriminator and empty value
+    ];
+
+    #[repr(C)]
+    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
+    struct TestValue {
+        data: [u8; 32],
+    }
+    impl SplDiscriminate for TestValue {
+        const SPL_DISCRIMINATOR: ArrayDiscriminator =
+            ArrayDiscriminator::new([1; ArrayDiscriminator::LENGTH]);
+    }
+
+    #[repr(C)]
+    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
+    struct TestSmallValue {
+        data: [u8; 3],
+    }
+    impl SplDiscriminate for TestSmallValue {
+        const SPL_DISCRIMINATOR: ArrayDiscriminator =
+            ArrayDiscriminator::new([2; ArrayDiscriminator::LENGTH]);
+    }
+
+    #[repr(transparent)]
+    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
+    struct TestEmptyValue;
+    impl SplDiscriminate for TestEmptyValue {
+        const SPL_DISCRIMINATOR: ArrayDiscriminator =
+            ArrayDiscriminator::new([3; ArrayDiscriminator::LENGTH]);
+    }
+
+    #[repr(C)]
+    #[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
+    struct TestNonZeroDefault {
+        data: [u8; 5],
+    }
+    const TEST_NON_ZERO_DEFAULT_DATA: [u8; 5] = [4; 5];
+    impl SplDiscriminate for TestNonZeroDefault {
+        const SPL_DISCRIMINATOR: ArrayDiscriminator =
+            ArrayDiscriminator::new([4; ArrayDiscriminator::LENGTH]);
+    }
+    impl Default for TestNonZeroDefault {
+        fn default() -> Self {
+            Self {
+                data: TEST_NON_ZERO_DEFAULT_DATA,
+            }
+        }
+    }
+
+    #[test]
+    fn unpack_opaque_buffer() {
+        let state = TlvStateBorrowed::unpack(TEST_BUFFER).unwrap();
+        let value = state.get_first_value::<TestValue>().unwrap();
+        assert_eq!(value.data, [1; 32]);
+        assert_eq!(
+            state.get_first_value::<TestEmptyValue>(),
+            Err(ProgramError::InvalidAccountData)
+        );
+
+        let mut test_buffer = TEST_BUFFER.to_vec();
+        let state = TlvStateMut::unpack(&mut test_buffer).unwrap();
+        let value = state.get_first_value::<TestValue>().unwrap();
+        assert_eq!(value.data, [1; 32]);
+        let state = TlvStateOwned::unpack(test_buffer).unwrap();
+        let value = state.get_first_value::<TestValue>().unwrap();
+        assert_eq!(value.data, [1; 32]);
+    }
+
+    #[test]
+    fn fail_unpack_opaque_buffer() {
+        // input buffer too small
+        let mut buffer = vec![0, 3];
+        assert_eq!(
+            TlvStateBorrowed::unpack(&buffer),
+            Err(ProgramError::InvalidAccountData)
+        );
+        assert_eq!(
+            TlvStateMut::unpack(&mut buffer),
+            Err(ProgramError::InvalidAccountData)
+        );
+        assert_eq!(
+            TlvStateMut::unpack(&mut buffer),
+            Err(ProgramError::InvalidAccountData)
+        );
+
+        // tweak the discriminator
+        let mut buffer = TEST_BUFFER.to_vec();
+        buffer[0] += 1;
+        let state = TlvStateMut::unpack(&mut buffer).unwrap();
+        assert_eq!(
+            state.get_first_value::<TestValue>(),
+            Err(ProgramError::InvalidAccountData)
+        );
+
+        // tweak the length, too big
+        let mut buffer = TEST_BUFFER.to_vec();
+        buffer[ArrayDiscriminator::LENGTH] += 10;
+        assert_eq!(
+            TlvStateMut::unpack(&mut buffer),
+            Err(ProgramError::InvalidAccountData)
+        );
+
+        // tweak the length, too small
+        let mut buffer = TEST_BIG_BUFFER.to_vec();
+        buffer[ArrayDiscriminator::LENGTH] -= 1;
+        let state = TlvStateMut::unpack(&mut buffer).unwrap();
+        assert_eq!(
+            state.get_first_value::<TestValue>(),
+            Err(ProgramError::InvalidArgument)
+        );
+
+        // data buffer is too small for type
+        let buffer = &TEST_BUFFER[..TEST_BUFFER.len() - 5];
+        assert_eq!(
+            TlvStateBorrowed::unpack(buffer),
+            Err(ProgramError::InvalidAccountData)
+        );
+    }
+
+    #[test]
+    fn get_discriminators_with_opaque_buffer() {
+        // incorrect due to the length
+        assert_eq!(
+            get_discriminators_and_end_index(&[1, 0, 1, 1]).unwrap_err(),
+            ProgramError::InvalidAccountData,
+        );
+        // correct due to the good discriminator length and zero length
+        assert_eq!(
+            get_discriminators_and_end_index(&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(),
+            (vec![ArrayDiscriminator::from(1)], 12)
+        );
+        // correct since it's just uninitialized data
+        assert_eq!(
+            get_discriminators_and_end_index(&[0, 0, 0, 0, 0, 0, 0, 0]).unwrap(),
+            (vec![], 0)
+        );
+    }
+
+    #[test]
+    fn value_pack_unpack() {
+        let account_size =
+            get_base_len() + size_of::<TestValue>() + get_base_len() + size_of::<TestSmallValue>();
+        let mut buffer = vec![0; account_size];
+
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        // success init and write value
+        let value = state.init_value::<TestValue>(false).unwrap().0;
+        let data = [100; 32];
+        value.data = data;
+        assert_eq!(
+            &state.get_discriminators().unwrap(),
+            &[TestValue::SPL_DISCRIMINATOR],
+        );
+        assert_eq!(&state.get_first_value::<TestValue>().unwrap().data, &data,);
+
+        // fail init extension when already initialized
+        assert_eq!(
+            state.init_value::<TestValue>(false).unwrap_err(),
+            TlvError::TypeAlreadyExists.into(),
+        );
+
+        // check raw buffer
+        let mut expect = vec![];
+        expect.extend_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
+        expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
+        expect.extend_from_slice(&data);
+        expect.extend_from_slice(&[0; size_of::<ArrayDiscriminator>()]);
+        expect.extend_from_slice(&[0; size_of::<Length>()]);
+        expect.extend_from_slice(&[0; size_of::<TestSmallValue>()]);
+        assert_eq!(expect, buffer);
+
+        // check unpacking
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+        let unpacked = state.get_first_value_mut::<TestValue>().unwrap();
+        assert_eq!(*unpacked, TestValue { data });
+
+        // update extension
+        let new_data = [101; 32];
+        unpacked.data = new_data;
+
+        // check updates are propagated
+        let state = TlvStateBorrowed::unpack(&buffer).unwrap();
+        let unpacked = state.get_first_value::<TestValue>().unwrap();
+        assert_eq!(*unpacked, TestValue { data: new_data });
+
+        // check raw buffer
+        let mut expect = vec![];
+        expect.extend_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
+        expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
+        expect.extend_from_slice(&new_data);
+        expect.extend_from_slice(&[0; size_of::<ArrayDiscriminator>()]);
+        expect.extend_from_slice(&[0; size_of::<Length>()]);
+        expect.extend_from_slice(&[0; size_of::<TestSmallValue>()]);
+        assert_eq!(expect, buffer);
+
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+        // init one more value
+        let new_value = state.init_value::<TestSmallValue>(false).unwrap().0;
+        let small_data = [102; 3];
+        new_value.data = small_data;
+
+        assert_eq!(
+            &state.get_discriminators().unwrap(),
+            &[
+                TestValue::SPL_DISCRIMINATOR,
+                TestSmallValue::SPL_DISCRIMINATOR
+            ]
+        );
+
+        // check raw buffer
+        let mut expect = vec![];
+        expect.extend_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
+        expect.extend_from_slice(&u32::try_from(size_of::<TestValue>()).unwrap().to_le_bytes());
+        expect.extend_from_slice(&new_data);
+        expect.extend_from_slice(TestSmallValue::SPL_DISCRIMINATOR.as_ref());
+        expect.extend_from_slice(
+            &u32::try_from(size_of::<TestSmallValue>())
+                .unwrap()
+                .to_le_bytes(),
+        );
+        expect.extend_from_slice(&small_data);
+        assert_eq!(expect, buffer);
+
+        // fail to init one more extension that does not fit
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+        assert_eq!(
+            state.init_value::<TestEmptyValue>(false),
+            Err(ProgramError::InvalidAccountData),
+        );
+    }
+
+    #[test]
+    fn value_any_order() {
+        let account_size = get_base_len()
+            + size_of::<TestValue>()
+            + get_base_len()
+            + size_of::<TestSmallValue>();
+        let mut buffer = vec![0; account_size];
+
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        let data = [99; 32];
+        let small_data = [98; 3];
+
+        // write values
+        let value = state.init_value::<TestValue>(false).unwrap().0;
+        value.data = data;
+        let value = state.init_value::<TestSmallValue>(false).unwrap().0;
+        value.data = small_data;
+
+        assert_eq!(
+            &state.get_discriminators().unwrap(),
+            &[
+                TestValue::SPL_DISCRIMINATOR,
+                TestSmallValue::SPL_DISCRIMINATOR,
+            ]
+        );
+
+        // write values in a different order
+        let mut other_buffer = vec![0; account_size];
+        let mut state = TlvStateMut::unpack(&mut other_buffer).unwrap();
+
+        let value = state.init_value::<TestSmallValue>(false).unwrap().0;
+        value.data = small_data;
+        let value = state.init_value::<TestValue>(false).unwrap().0;
+        value.data = data;
+
+        assert_eq!(
+            &state.get_discriminators().unwrap(),
+            &[
+                TestSmallValue::SPL_DISCRIMINATOR,
+                TestValue::SPL_DISCRIMINATOR,
+            ]
+        );
+
+        // buffers are NOT the same because written in a different order
+        assert_ne!(buffer, other_buffer);
+        let state = TlvStateBorrowed::unpack(&buffer).unwrap();
+        let other_state = TlvStateBorrowed::unpack(&other_buffer).unwrap();
+
+        // BUT values are the same
+        assert_eq!(
+            state.get_first_value::<TestValue>().unwrap(),
+            other_state.get_first_value::<TestValue>().unwrap()
+        );
+        assert_eq!(
+            state.get_first_value::<TestSmallValue>().unwrap(),
+            other_state.get_first_value::<TestSmallValue>().unwrap()
+        );
+    }
+
+    #[test]
+    fn init_nonzero_default() {
+        let account_size = get_base_len() + size_of::<TestNonZeroDefault>();
+        let mut buffer = vec![0; account_size];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+        let value = state.init_value::<TestNonZeroDefault>(false).unwrap().0;
+        assert_eq!(value.data, TEST_NON_ZERO_DEFAULT_DATA);
+    }
+
+    #[test]
+    fn init_buffer_too_small() {
+        let account_size = get_base_len() + size_of::<TestValue>();
+        let mut buffer = vec![0; account_size - 1];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+        let err = state.init_value::<TestValue>(false).unwrap_err();
+        assert_eq!(err, ProgramError::InvalidAccountData);
+
+        // hack the buffer to look like it was initialized, still fails
+        let discriminator_ref = &mut state.data[0..ArrayDiscriminator::LENGTH];
+        discriminator_ref.copy_from_slice(TestValue::SPL_DISCRIMINATOR.as_ref());
+        state.data[ArrayDiscriminator::LENGTH] = 32;
+        let err = state.get_first_value::<TestValue>().unwrap_err();
+        assert_eq!(err, ProgramError::InvalidAccountData);
+        assert_eq!(
+            state.get_discriminators().unwrap_err(),
+            ProgramError::InvalidAccountData
+        );
+    }
+
+    #[test]
+    fn value_with_no_data() {
+        let account_size = get_base_len() + size_of::<TestEmptyValue>();
+        let mut buffer = vec![0; account_size];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        assert_eq!(
+            state.get_first_value::<TestEmptyValue>().unwrap_err(),
+            TlvError::TypeNotFound.into(),
+        );
+
+        state.init_value::<TestEmptyValue>(false).unwrap();
+        state.get_first_value::<TestEmptyValue>().unwrap();
+
+        // re-init fails
+        assert_eq!(
+            state.init_value::<TestEmptyValue>(false).unwrap_err(),
+            TlvError::TypeAlreadyExists.into(),
+        );
+    }
+
+    #[test]
+    fn alloc_first() {
+        let tlv_size = 1;
+        let account_size = get_base_len() + tlv_size;
+        let mut buffer = vec![0; account_size];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        // not enough room
+        let data = state.alloc::<TestValue>(tlv_size, false).unwrap().0;
+        assert_eq!(
+            pod_from_bytes_mut::<TestValue>(data).unwrap_err(),
+            ProgramError::InvalidArgument,
+        );
+
+        // can't double alloc
+        assert_eq!(
+            state.alloc::<TestValue>(tlv_size, false).unwrap_err(),
+            TlvError::TypeAlreadyExists.into(),
+        );
+    }
+
+    #[test]
+    fn alloc_with_repetition() {
+        let tlv_size = 1;
+        let account_size = (get_base_len() + tlv_size) * 2;
+        let mut buffer = vec![0; account_size];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        let (data, repetition_number) = state.alloc::<TestValue>(tlv_size, true).unwrap();
+        assert_eq!(repetition_number, 0);
+
+        // not enough room
+        assert_eq!(
+            pod_from_bytes_mut::<TestValue>(data).unwrap_err(),
+            ProgramError::InvalidArgument,
+        );
+
+        // Can alloc again!
+        let (_data, repetition_number) = state.alloc::<TestValue>(tlv_size, true).unwrap();
+        assert_eq!(repetition_number, 1);
+    }
+
+    #[test]
+    fn realloc_first() {
+        const TLV_SIZE: usize = 10;
+        const EXTRA_SPACE: usize = 5;
+        const SMALL_SIZE: usize = 2;
+        const ACCOUNT_SIZE: usize = get_base_len()
+            + TLV_SIZE
+            + EXTRA_SPACE
+            + get_base_len()
+            + size_of::<TestNonZeroDefault>();
+        let mut buffer = vec![0; ACCOUNT_SIZE];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        // alloc both types
+        let _ = state.alloc::<TestValue>(TLV_SIZE, false).unwrap();
+        let _ = state.init_value::<TestNonZeroDefault>(false).unwrap();
+
+        // realloc first entry to larger, all 0
+        let data = state
+            .realloc_first::<TestValue>(TLV_SIZE + EXTRA_SPACE)
+            .unwrap();
+        assert_eq!(data, [0; TLV_SIZE + EXTRA_SPACE]);
+        let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
+        assert_eq!(*value, TestNonZeroDefault::default());
+
+        // realloc to smaller, still all 0
+        let data = state.realloc_first::<TestValue>(SMALL_SIZE).unwrap();
+        assert_eq!(data, [0; SMALL_SIZE]);
+        let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
+        assert_eq!(*value, TestNonZeroDefault::default());
+        let (_, end_index) = get_discriminators_and_end_index(&buffer).unwrap();
+        assert_eq!(
+            &buffer[end_index..ACCOUNT_SIZE],
+            [0; TLV_SIZE + EXTRA_SPACE - SMALL_SIZE]
+        );
+
+        // unpack again since we dropped the last `state`
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+        // realloc too much, fails
+        assert_eq!(
+            state
+                .realloc_first::<TestValue>(TLV_SIZE + EXTRA_SPACE + 1)
+                .unwrap_err(),
+            ProgramError::InvalidAccountData,
+        );
+    }
+
+    #[test]
+    fn realloc_with_repeating_entries() {
+        const TLV_SIZE: usize = 10;
+        const EXTRA_SPACE: usize = 5;
+        const SMALL_SIZE: usize = 2;
+        const ACCOUNT_SIZE: usize = get_base_len()
+            + TLV_SIZE
+            + EXTRA_SPACE
+            + get_base_len()
+            + TLV_SIZE
+            + get_base_len()
+            + size_of::<TestNonZeroDefault>();
+        let mut buffer = vec![0; ACCOUNT_SIZE];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        // alloc both types, two for the first type and one for the second
+        let _ = state.alloc::<TestValue>(TLV_SIZE, true).unwrap();
+        let _ = state.alloc::<TestValue>(TLV_SIZE, true).unwrap();
+        let _ = state.init_value::<TestNonZeroDefault>(true).unwrap();
+
+        // realloc first entry to larger, all 0
+        let data = state
+            .realloc_with_repetition::<TestValue>(TLV_SIZE + EXTRA_SPACE, 0)
+            .unwrap();
+        assert_eq!(data, [0; TLV_SIZE + EXTRA_SPACE]);
+        let value = state.get_bytes_with_repetition::<TestValue>(0).unwrap();
+        assert_eq!(*value, [0; TLV_SIZE + EXTRA_SPACE]);
+        let value = state.get_bytes_with_repetition::<TestValue>(1).unwrap();
+        assert_eq!(*value, [0; TLV_SIZE]);
+        let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
+        assert_eq!(*value, TestNonZeroDefault::default());
+
+        // realloc to smaller, still all 0
+        let data = state
+            .realloc_with_repetition::<TestValue>(SMALL_SIZE, 0)
+            .unwrap();
+        assert_eq!(data, [0; SMALL_SIZE]);
+        let value = state.get_bytes_with_repetition::<TestValue>(0).unwrap();
+        assert_eq!(*value, [0; SMALL_SIZE]);
+        let value = state.get_bytes_with_repetition::<TestValue>(1).unwrap();
+        assert_eq!(*value, [0; TLV_SIZE]);
+        let value = state.get_first_value::<TestNonZeroDefault>().unwrap();
+        assert_eq!(*value, TestNonZeroDefault::default());
+        let (_, end_index) = get_discriminators_and_end_index(&buffer).unwrap();
+        assert_eq!(
+            &buffer[end_index..ACCOUNT_SIZE],
+            [0; TLV_SIZE + EXTRA_SPACE - SMALL_SIZE]
+        );
+
+        // unpack again since we dropped the last `state`
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+        // realloc too much, fails
+        assert_eq!(
+            state
+                .realloc_with_repetition::<TestValue>(TLV_SIZE + EXTRA_SPACE + 1, 0)
+                .unwrap_err(),
+            ProgramError::InvalidAccountData,
+        );
+    }
+
+    #[derive(Clone, Debug, PartialEq)]
+    struct TestVariableLen {
+        data: String, // test with a variable length type
+    }
+    impl SplDiscriminate for TestVariableLen {
+        const SPL_DISCRIMINATOR: ArrayDiscriminator =
+            ArrayDiscriminator::new([5; ArrayDiscriminator::LENGTH]);
+    }
+    impl VariableLenPack for TestVariableLen {
+        fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
+            let bytes = self.data.as_bytes();
+            let end = 8 + bytes.len();
+            if dst.len() < end {
+                Err(ProgramError::InvalidAccountData)
+            } else {
+                dst[..8].copy_from_slice(&self.data.len().to_le_bytes());
+                dst[8..end].copy_from_slice(bytes);
+                Ok(())
+            }
+        }
+        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
+            let length = u64::from_le_bytes(src[..8].try_into().unwrap()) as usize;
+            if src[8..8 + length].len() != length {
+                return Err(ProgramError::InvalidAccountData);
+            }
+            let data = std::str::from_utf8(&src[8..8 + length])
+                .unwrap()
+                .to_string();
+            Ok(Self { data })
+        }
+        fn get_packed_len(&self) -> Result<usize, ProgramError> {
+            Ok(size_of::<u64>().saturating_add(self.data.len()))
+        }
+    }
+
+    #[test]
+    fn first_variable_len_value() {
+        let initial_data = "This is a pretty cool test!";
+        // exactly the right size
+        let tlv_size = 8 + initial_data.len();
+        let account_size = get_base_len() + tlv_size;
+        let mut buffer = vec![0; account_size];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        // don't actually need to hold onto the data!
+        let _ = state.alloc::<TestVariableLen>(tlv_size, false).unwrap();
+        let test_variable_len = TestVariableLen {
+            data: initial_data.to_string(),
+        };
+        state
+            .pack_first_variable_len_value(&test_variable_len)
+            .unwrap();
+        let deser = state
+            .get_first_variable_len_value::<TestVariableLen>()
+            .unwrap();
+        assert_eq!(deser, test_variable_len);
+
+        // writing too much data fails
+        let too_much_data = "This is a pretty cool test!?";
+        assert_eq!(
+            state
+                .pack_first_variable_len_value(&TestVariableLen {
+                    data: too_much_data.to_string(),
+                })
+                .unwrap_err(),
+            ProgramError::InvalidAccountData
+        );
+    }
+
+    #[test]
+    fn variable_len_value_with_repetition() {
+        let variable_len_1 = TestVariableLen {
+            data: "Let's see if we can pack multiple variable length values".to_string(),
+        };
+        let tlv_size_1 = 8 + variable_len_1.data.len();
+
+        let variable_len_2 = TestVariableLen {
+            data: "I think we can".to_string(),
+        };
+        let tlv_size_2 = 8 + variable_len_2.data.len();
+
+        let variable_len_3 = TestVariableLen {
+            data: "In fact, I know we can!".to_string(),
+        };
+        let tlv_size_3 = 8 + variable_len_3.data.len();
+
+        let variable_len_4 = TestVariableLen {
+            data: "How cool is this?".to_string(),
+        };
+        let tlv_size_4 = 8 + variable_len_4.data.len();
+
+        let account_size = get_base_len()
+            + tlv_size_1
+            + get_base_len()
+            + tlv_size_2
+            + get_base_len()
+            + tlv_size_3
+            + get_base_len()
+            + tlv_size_4;
+        let mut buffer = vec![0; account_size];
+        let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+
+        let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_1, true).unwrap();
+        state
+            .pack_variable_len_value_with_repetition(&variable_len_1, repetition_number)
+            .unwrap();
+        assert_eq!(repetition_number, 0);
+        assert_eq!(
+            state
+                .get_first_variable_len_value::<TestVariableLen>()
+                .unwrap(),
+            variable_len_1,
+        );
+
+        let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_2, true).unwrap();
+        state
+            .pack_variable_len_value_with_repetition(&variable_len_2, repetition_number)
+            .unwrap();
+        assert_eq!(repetition_number, 1);
+        assert_eq!(
+            state
+                .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
+                .unwrap(),
+            variable_len_2,
+        );
+
+        let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_3, true).unwrap();
+        state
+            .pack_variable_len_value_with_repetition(&variable_len_3, repetition_number)
+            .unwrap();
+        assert_eq!(repetition_number, 2);
+        assert_eq!(
+            state
+                .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
+                .unwrap(),
+            variable_len_3,
+        );
+
+        let (_, repetition_number) = state.alloc::<TestVariableLen>(tlv_size_4, true).unwrap();
+        state
+            .pack_variable_len_value_with_repetition(&variable_len_4, repetition_number)
+            .unwrap();
+        assert_eq!(repetition_number, 3);
+        assert_eq!(
+            state
+                .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
+                .unwrap(),
+            variable_len_4,
+        );
+    }
+
+    #[test]
+    fn add_entry_mix_and_match() {
+        let mut buffer = vec![];
+
+        // Add an entry for a fixed length value
+        let fixed_data = TestValue { data: [1; 32] };
+        let tlv_size = get_base_len() + size_of::<TestValue>();
+        buffer.extend(vec![0; tlv_size]);
+        {
+            let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+            let (value, repetition_number) = state.init_value::<TestValue>(true).unwrap();
+            value.data = fixed_data.data;
+            assert_eq!(repetition_number, 0);
+            assert_eq!(*value, fixed_data);
+        }
+
+        // Add an entry for a variable length value
+        let variable_data = TestVariableLen {
+            data: "This is my first variable length entry!".to_string(),
+        };
+        let tlv_size = get_base_len() + 8 + variable_data.data.len();
+        buffer.extend(vec![0; tlv_size]);
+        {
+            let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+            let repetition_number = state
+                .alloc_and_pack_variable_len_entry(&variable_data, true)
+                .unwrap();
+            let value = state
+                .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
+                .unwrap();
+            assert_eq!(repetition_number, 0);
+            assert_eq!(value, variable_data);
+        }
+
+        // Add another entry for a variable length value
+        let variable_data = TestVariableLen {
+            data: "This is actually my second variable length entry!".to_string(),
+        };
+        let tlv_size = get_base_len() + 8 + variable_data.data.len();
+        buffer.extend(vec![0; tlv_size]);
+        {
+            let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+            let repetition_number = state
+                .alloc_and_pack_variable_len_entry(&variable_data, true)
+                .unwrap();
+            let value = state
+                .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
+                .unwrap();
+            assert_eq!(repetition_number, 1);
+            assert_eq!(value, variable_data);
+        }
+
+        // Add another entry for a fixed length value
+        let fixed_data = TestValue { data: [2; 32] };
+        let tlv_size = get_base_len() + size_of::<TestValue>();
+        buffer.extend(vec![0; tlv_size]);
+        {
+            let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+            let (value, repetition_number) = state.init_value::<TestValue>(true).unwrap();
+            value.data = fixed_data.data;
+            assert_eq!(repetition_number, 1);
+            assert_eq!(*value, fixed_data);
+        }
+
+        // Add another entry for a fixed length value
+        let fixed_data = TestValue { data: [3; 32] };
+        let tlv_size = get_base_len() + size_of::<TestValue>();
+        buffer.extend(vec![0; tlv_size]);
+        {
+            let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+            let (value, repetition_number) = state.init_value::<TestValue>(true).unwrap();
+            value.data = fixed_data.data;
+            assert_eq!(repetition_number, 2);
+            assert_eq!(*value, fixed_data);
+        }
+
+        // Add another entry for a variable length value
+        let variable_data = TestVariableLen {
+            data: "Wow! My third variable length entry!".to_string(),
+        };
+        let tlv_size = get_base_len() + 8 + variable_data.data.len();
+        buffer.extend(vec![0; tlv_size]);
+        {
+            let mut state = TlvStateMut::unpack(&mut buffer).unwrap();
+            let repetition_number = state
+                .alloc_and_pack_variable_len_entry(&variable_data, true)
+                .unwrap();
+            let value = state
+                .get_variable_len_value_with_repetition::<TestVariableLen>(repetition_number)
+                .unwrap();
+            assert_eq!(repetition_number, 2);
+            assert_eq!(value, variable_data);
+        }
+    }
+}
diff --git a/type-length-value/src/variable_len_pack.rs b/type-length-value/src/variable_len_pack.rs
new file mode 100644
index 00000000..b2750d25
--- /dev/null
+++ b/type-length-value/src/variable_len_pack.rs
@@ -0,0 +1,27 @@
+//! The [`VariableLenPack`] serialization trait.
+
+use solana_program_error::ProgramError;
+
+/// Trait that mimics a lot of the functionality of
+/// `solana_program_pack::Pack` but specifically works for
+/// variable-size types.
+pub trait VariableLenPack {
+    /// Writes the serialized form of the instance into the given slice
+    fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError>;
+
+    /// Deserializes the type from the given slice
+    fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError>
+    where
+        Self: Sized;
+
+    /// Gets the packed length for a given instance of the type
+    fn get_packed_len(&self) -> Result<usize, ProgramError>;
+
+    /// Safely write the contents of the type into the given slice
+    fn pack(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
+        if dst.len() != self.get_packed_len()? {
+            return Err(ProgramError::InvalidAccountData);
+        }
+        self.pack_into_slice(dst)
+    }
+}
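
Downstream consumers will typically implement `VariableLenPack` by delegating to an existing serializer rather than hand-rolling a length-prefixed format as the test above does. A minimal sketch, assuming `borsh` (with its `derive` feature enabled) and a hypothetical `Note` type, neither of which is part of this change:

```rust
use {
    borsh::{BorshDeserialize, BorshSerialize},
    solana_program_error::ProgramError,
    spl_type_length_value::variable_len_pack::VariableLenPack,
};

/// Hypothetical variable-length type, serialized with borsh
#[derive(BorshDeserialize, BorshSerialize)]
pub struct Note {
    pub text: String,
}

impl VariableLenPack for Note {
    fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
        // Tolerate `dst` being larger than needed, so reallocations to
        // smaller sizes still work (see the note in
        // `pack_variable_len_value_with_repetition`)
        let bytes = borsh::to_vec(self).map_err(|_| ProgramError::InvalidAccountData)?;
        if dst.len() < bytes.len() {
            return Err(ProgramError::InvalidAccountData);
        }
        dst[..bytes.len()].copy_from_slice(&bytes);
        Ok(())
    }
    fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
        Self::try_from_slice(src).map_err(|_| ProgramError::InvalidAccountData)
    }
    fn get_packed_len(&self) -> Result<usize, ProgramError> {
        // Length of the serialized form, used by `alloc_and_pack_variable_len_entry`
        borsh::to_vec(self)
            .map(|v| v.len())
            .map_err(|_| ProgramError::InvalidAccountData)
    }
}
```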