diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index efa24bff78..2054edec76 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -53,7 +53,8 @@ jobs: cargo test -p light-account-checks --all-features cargo test -p light-verifier --all-features cargo test -p light-merkle-tree-metadata --all-features - cargo test -p light-zero-copy --features std + cargo test -p light-zero-copy --features "std, mut, derive" + cargo test -p light-zero-copy-derive --features "mut" cargo test -p light-hash-set --all-features - name: program-libs-slow packages: light-bloom-filter light-indexed-merkle-tree light-batched-merkle-tree diff --git a/Cargo.lock b/Cargo.lock index 03057b4669..6dec5247d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2365,6 +2365,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "governor" version = "0.6.3" @@ -3786,6 +3792,8 @@ dependencies = [ name = "light-zero-copy" version = "0.2.0" dependencies = [ + "borsh 0.10.4", + "light-zero-copy-derive", "pinocchio", "rand 0.8.5", "solana-program-error", @@ -3793,6 +3801,21 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "light-zero-copy-derive" +version = "0.1.0" +dependencies = [ + "borsh 0.10.4", + "lazy_static", + "light-zero-copy", + "proc-macro2", + "quote", + "rand 0.8.5", + "syn 2.0.103", + "trybuild", + "zerocopy", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -9054,6 +9077,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-triple" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" + [[package]] name = "tarpc" version = "0.29.0" @@ -9626,6 +9655,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c9bf9513a2f4aeef5fdac8677d7d349c79fdbcc03b9c86da6e9d254f1e43be2" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml 0.8.23", +] + [[package]] name = "tungstenite" version = "0.20.1" diff --git a/Cargo.toml b/Cargo.toml index 176115adb1..87af2c81ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "program-libs/hash-set", "program-libs/indexed-merkle-tree", "program-libs/indexed-array", + "program-libs/zero-copy-derive", "programs/account-compression", "programs/system", "programs/compressed-token", @@ -167,6 +168,7 @@ light-compressed-account = { path = "program-libs/compressed-account", version = light-account-checks = { path = "program-libs/account-checks", version = "0.3.0" } light-verifier = { path = "program-libs/verifier", version = "2.1.0" } light-zero-copy = { path = "program-libs/zero-copy", version = "0.2.0" } +light-zero-copy-derive = { path = "program-libs/zero-copy-derive", version = "0.1.0" } photon-api = { path = "sdk-libs/photon-api", version = "0.51.0" } 
forester-utils = { path = "forester-utils", version = "2.0.0" }
account-compression = { path = "programs/account-compression", version = "2.0.0", features = [
diff --git a/program-libs/expand.rs b/program-libs/expand.rs
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/program-libs/zero-copy-derive/Cargo.toml b/program-libs/zero-copy-derive/Cargo.toml
new file mode 100644
index 0000000000..1cdc8254e8
--- /dev/null
+++ b/program-libs/zero-copy-derive/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "light-zero-copy-derive"
+version = "0.1.0"
+edition = "2021"
+license = "Apache-2.0"
+description = "Proc macro for zero-copy deserialization"
+
+[features]
+default = []
+mut = []
+
+[lib]
+proc-macro = true
+
+[dependencies]
+proc-macro2 = "1.0"
+quote = "1.0"
+syn = { version = "2.0", features = ["full", "extra-traits"] }
+lazy_static = "1.4"
+
+[dev-dependencies]
+trybuild = "1.0"
+rand = "0.8"
+borsh = { workspace = true }
+light-zero-copy = { workspace = true, features = ["std", "derive"] }
+zerocopy = { workspace = true, features = ["derive"] }
diff --git a/program-libs/zero-copy-derive/README.md b/program-libs/zero-copy-derive/README.md
new file mode 100644
index 0000000000..8e17fbbb25
--- /dev/null
+++ b/program-libs/zero-copy-derive/README.md
@@ -0,0 +1,103 @@
+# Light-Zero-Copy-Derive
+
+A procedural macro for deriving zero-copy deserialization for Rust structs used with Solana programs.
+
+## Features
+
+This crate provides two key derive macros:
+
+1. `#[derive(ZeroCopy)]` - Implements zero-copy deserialization with:
+   - The `zero_copy_at` and `zero_copy_at_mut` methods for deserialization
+   - Full Borsh compatibility for serialization/deserialization
+   - Efficient memory representation with no copying of data
+   - `From<ZStruct>` and `FromMut<ZStructMut>` implementations for easy conversion back to the original struct
+
+2. `#[derive(ZeroCopyEq)]` - Adds equality comparison support:
+   - Compare zero-copy instances with regular struct instances
+   - Can be used alongside `ZeroCopy` for complete functionality
+   - Derivation for `Option` fields is not robust and may not compile.
+
+## Rules for Zero-Copy Deserialization
+
+The macro follows these rules when generating code:
+
+1. Creates a `ZStruct` for your struct that follows zero-copy principles
+   1. Fields are extracted into a meta struct until reaching a `Vec`, `Option` or non-`Copy` type
+   2. Vectors are represented as `ZeroCopySlice` and not included in the meta struct
+   3. Integer types are replaced with their zerocopy equivalents (e.g., `u16` → `U16`)
+   4. Fields after the first vector are directly included in the `ZStruct` and deserialized one by one
+   5. If a vector contains a nested vector (non-`Copy` type), it must implement `Deserialize`
+   6. Elements in an `Option` must implement `Deserialize`
+   7. Types that don't implement `Copy` must implement `Deserialize` and are deserialized one by one
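+
+The sketch below illustrates roughly what these rules produce for a small struct. The layout and
+names are simplified for illustration; the exact generated code differs.
+
+```rust
+// Input:
+pub struct MyStruct {
+    pub a: u8,
+    pub b: u16,
+    pub vec: Vec<u8>,
+    pub c: u64,
+}
+
+// Roughly what gets generated (simplified):
+use light_zero_copy::{little_endian::{U16, U64}, Ref};
+
+#[repr(C)]
+pub struct ZMyStructMeta {
+    pub a: u8,
+    pub b: U16, // integers are replaced by zerocopy little-endian wrappers
+}
+
+pub struct ZMyStruct<'a> {
+    __meta: Ref<&'a [u8], ZMyStructMeta>, // fixed-size leading fields, read as one reference
+    pub vec: &'a [u8],                    // Vec<u8> becomes a borrowed slice
+    pub c: Ref<&'a [u8], U64>,            // fields after the first vector are read one by one
+}
+```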
+
+## Usage
+
+### Basic Usage
+
+```rust
+use borsh::{BorshDeserialize, BorshSerialize};
+use light_zero_copy_derive::ZeroCopy;
+use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut};
+
+#[repr(C)]
+#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)]
+pub struct MyStruct {
+    pub a: u8,
+    pub b: u16,
+    pub vec: Vec<u8>,
+    pub c: u64,
+}
+let my_struct = MyStruct {
+    a: 1,
+    b: 2,
+    vec: vec![1u8; 32],
+    c: 3,
+};
+// Use the struct with zero-copy deserialization
+let mut bytes = my_struct.try_to_vec().unwrap();
+
+// Immutable zero-copy deserialization
+let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap();
+
+// Convert back to original struct using From implementation
+let converted: MyStruct = zero_copy.clone().into();
+assert_eq!(converted, my_struct);
+
+// Mutable zero-copy deserialization with modification
+let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut bytes).unwrap();
+zero_copy_mut.a = 42;
+
+// The change is reflected when we convert back to the original struct
+let modified: MyStruct = zero_copy_mut.into();
+assert_eq!(modified.a, 42);
+
+// And also when we deserialize directly from the modified bytes
+let borsh = MyStruct::try_from_slice(&bytes).unwrap();
+assert_eq!(borsh.a, 42u8);
+```
+
+### With Equality Comparison
+
+```rust
+use borsh::{BorshDeserialize, BorshSerialize};
+use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq};
+
+#[repr(C)]
+#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyEq)]
+pub struct MyStruct {
+    pub a: u8,
+    pub b: u16,
+    pub vec: Vec<u8>,
+    pub c: u64,
+}
+let my_struct = MyStruct {
+    a: 1,
+    b: 2,
+    vec: vec![1u8; 32],
+    c: 3,
+};
+// Use the struct with zero-copy deserialization
+let bytes = my_struct.try_to_vec().unwrap();
+let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap();
+assert_eq!(zero_copy, my_struct);
+```
diff --git a/program-libs/zero-copy-derive/src/byte_len_derive.rs b/program-libs/zero-copy-derive/src/byte_len_derive.rs
new file mode 100644
index 0000000000..c260b5720a
--- /dev/null
+++ b/program-libs/zero-copy-derive/src/byte_len_derive.rs
@@ -0,0 +1,132 @@
+use proc_macro2::TokenStream;
+use quote::quote;
+use syn::{Field, Ident};
+
+use crate::{
+    utils,
+    z_struct::{analyze_struct_fields, FieldType},
+};
+
+/// Generates ByteLen implementation for structs
+///
+/// RULES AND EXCEPTIONS FROM borsh_mut.rs:
+///
+/// DEFAULT RULE: Call byte_len() on each field and sum the results
+///
+/// EXCEPTIONS:
+/// 1. Boolean fields: Use core::mem::size_of::<u8>() (1 byte) instead of byte_len()
+///    * See line 97 where booleans use a special case
+///
+/// NOTES ON TYPE-SPECIFIC IMPLEMENTATIONS:
+/// * Primitive types: self.field.byte_len() delegates to size_of::<T>()
+///   - u8, u16, u32, u64, etc.
all use size_of::() in their implementations +/// - See implementations in lines 88-90, 146-148, and macro in lines 135-151 +/// +/// * Arrays [T; N]: use size_of::() in implementation (line 41) +/// +/// * Vec: 4 bytes for length prefix + sum of byte_len() for each element +/// - The Vec implementation in line 131 is: 4 + self.iter().map(|t| t.byte_len()).sum::() +/// - Special case in Struct4 (line 650-657): explicitly sums the byte_len of each item +/// +/// * VecU8: Uses 1 byte for length prefix instead of regular Vec's 4 bytes +/// - Implementation in line 205 shows: 1 + size_of::() +/// +/// * Option: 1 byte for discriminator + value's byte_len if Some, or just 1 byte if None +/// - See implementation in lines 66-72 +/// +/// * Fixed-size types: Generally implement as their own fixed size +/// - Pubkey (line 45-46): hard-coded as 32 bytes +pub fn generate_byte_len_derive_impl<'a>( + _name: &Ident, + meta_fields: &'a [&'a Field], + struct_fields: &'a [&'a Field], +) -> TokenStream { + let field_types = analyze_struct_fields(struct_fields); + + // Generate statements for calculating byte_len for each field + let meta_byte_len = if !meta_fields.is_empty() { + meta_fields + .iter() + .map(|field| { + let field_name = &field.ident; + // Handle boolean fields specially by using size_of instead of byte_len + if utils::is_bool_type(&field.ty) { + quote! { core::mem::size_of::() } + } else { + quote! { self.#field_name.byte_len() } + } + }) + .reduce(|acc, item| { + quote! { #acc + #item } + }) + } else { + None + }; + + // Generate byte_len calculations for struct fields + // Default rule: Use self.field.byte_len() for all fields + // Exception: Use core::mem::size_of::() for boolean fields + let struct_byte_len = field_types.into_iter().map(|field_type| { + match field_type { + // Exception 1: Booleans use size_of::() directly + FieldType::Bool(_) | FieldType::CopyU8Bool(_) => { + quote! { core::mem::size_of::() } + } + // All other types delegate to their own byte_len implementation + FieldType::VecU8(field_name) + | FieldType::VecCopy(field_name, _) + | FieldType::VecNonCopy(field_name, _) + | FieldType::Array(field_name, _) + | FieldType::Option(field_name, _) + | FieldType::Pubkey(field_name) + | FieldType::IntegerU64(field_name) + | FieldType::IntegerU32(field_name) + | FieldType::IntegerU16(field_name) + | FieldType::IntegerU8(field_name) + | FieldType::Copy(field_name, _) + | FieldType::NonCopy(field_name, _) => { + quote! { self.#field_name.byte_len() } + }, + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + quote! { self.#field_name.as_ref().map_or(1, |x| 1 + x.byte_len()) } + } + } + }); + + // Combine meta fields and struct fields for total byte_len calculation + let combined_byte_len = match meta_byte_len { + Some(meta) => { + let struct_bytes = struct_byte_len.fold(quote!(), |acc, item| { + if acc.is_empty() { + item + } else { + quote! { #acc + #item } + } + }); + + if struct_bytes.is_empty() { + meta + } else { + quote! { #meta + #struct_bytes } + } + } + None => struct_byte_len.fold(quote!(), |acc, item| { + if acc.is_empty() { + item + } else { + quote! { #acc + #item } + } + }), + }; + + // Generate the final implementation + quote! 
{ + impl light_zero_copy::ByteLen for #_name { + fn byte_len(&self) -> usize { + #combined_byte_len + } + } + } +} \ No newline at end of file diff --git a/program-libs/zero-copy-derive/src/config.rs b/program-libs/zero-copy-derive/src/config.rs new file mode 100644 index 0000000000..9a522f12e1 --- /dev/null +++ b/program-libs/zero-copy-derive/src/config.rs @@ -0,0 +1,368 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::Ident; + +use crate::{utils, z_struct::FieldType}; + +/// Configuration system for zero-copy initialization +/// +/// This module provides functionality to generate configuration structs and +/// initialization logic for zero-copy structures with Vec and Option fields. +/// Helper functions for FieldType to support configuration +/// Determine if this field type requires configuration for initialization +pub fn requires_config(field_type: &FieldType) -> bool { + match field_type { + // Vec types always need length configuration + FieldType::VecU8(_) | FieldType::VecCopy(_, _) | FieldType::VecNonCopy(_, _) => true, + // Option types need Some/None configuration + FieldType::Option(_, _) => true, + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::IntegerU64(_) + | FieldType::IntegerU32(_) + | FieldType::IntegerU16(_) + | FieldType::IntegerU8(_) + | FieldType::Bool(_) + | FieldType::CopyU8Bool(_) + | FieldType::Copy(_, _) => false, + // NonCopy types might need configuration if they contain Vec/Option + FieldType::NonCopy(_, _) => true, // Conservative: assume they need config + // Option integer types need config to determine if they're enabled + FieldType::OptionU64(_) | FieldType::OptionU32(_) | FieldType::OptionU16(_) => true, + } +} + +/// Generate the config type for this field +pub fn config_type(field_type: &FieldType) -> TokenStream { + match field_type { + // Simple Vec types: just need length + FieldType::VecU8(_) => quote! { u32 }, + FieldType::VecCopy(_, _) => quote! { u32 }, + + // Complex Vec types: need config for each element + FieldType::VecNonCopy(_, vec_type) => { + if let Some(inner_type) = utils::get_vec_inner_type(vec_type) { + quote! { Vec<<#inner_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config> } + } else { + panic!("Could not determine inner type for VecNonCopy config"); + } + } + + // Option types: delegate to the Option's Config type + FieldType::Option(_, option_type) => { + quote! { <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config } + } + + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::IntegerU64(_) + | FieldType::IntegerU32(_) + | FieldType::IntegerU16(_) + | FieldType::IntegerU8(_) + | FieldType::Bool(_) + | FieldType::CopyU8Bool(_) + | FieldType::Copy(_, _) => quote! { () }, + + // Option integer types: use bool config to determine if enabled + FieldType::OptionU64(_) | FieldType::OptionU32(_) | FieldType::OptionU16(_) => { + quote! { bool } + } + + // NonCopy types: delegate to their Config type (Config is typically 'static) + FieldType::NonCopy(_, field_type) => { + quote! 
{ <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config } + } + } +} + +/// Generate a configuration struct for a given struct +pub fn generate_config_struct(struct_name: &Ident, field_types: &[FieldType]) -> TokenStream { + let config_name = quote::format_ident!("{}Config", struct_name); + + // Generate config fields only for fields that require configuration + let config_fields: Vec = field_types + .iter() + .filter(|field_type| requires_config(field_type)) + .map(|field_type| { + let field_name = field_type.name(); + let config_type = config_type(field_type); + quote! { + pub #field_name: #config_type, + } + }) + .collect(); + + if config_fields.is_empty() { + // If no fields require configuration, create an empty config struct + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name; + } + } else { + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name { + #(#config_fields)* + } + } + } +} + +/// Generate initialization logic for a field based on its configuration +pub fn generate_field_initialization(field_type: &FieldType) -> TokenStream { + match field_type { + FieldType::VecU8(field_name) => { + quote! { + // Initialize the length prefix but don't use the returned ZeroCopySliceMut + { + light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::::new_at( + config.#field_name.into(), + bytes + )?; + } + // Split off the length prefix (4 bytes) and get the slice + let (_, bytes) = bytes.split_at_mut(4); + let (#field_name, bytes) = bytes.split_at_mut(config.#field_name as usize); + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote! { + let (#field_name, bytes) = light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<#inner_type>::new_at( + config.#field_name.into(), + bytes + )?; + } + } + + FieldType::VecNonCopy(field_name, vec_type) => { + quote! { + let (#field_name, bytes) = <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'a>>::new_zero_copy( + bytes, + config.#field_name + )?; + } + } + + FieldType::Option(field_name, option_type) => { + quote! { + let (#field_name, bytes) = <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'a>>::new_zero_copy(bytes, config.#field_name)?; + } + } + + // Fixed-size types that are struct fields (not meta fields) need initialization with () config + FieldType::IntegerU64(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + light_zero_copy::little_endian::U64 + >::from_prefix(bytes)?; + } + } + + FieldType::IntegerU32(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + light_zero_copy::little_endian::U32 + >::from_prefix(bytes)?; + } + } + + FieldType::IntegerU16(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + light_zero_copy::little_endian::U16 + >::from_prefix(bytes)?; + } + } + + FieldType::IntegerU8(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<&mut [u8], u8>::from_prefix(bytes)?; + } + } + + FieldType::Bool(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<&mut [u8], u8>::from_prefix(bytes)?; + } + } + + // Array fields that are struct fields (come after Vec/Option) + FieldType::Array(field_name, array_type) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + #array_type + >::from_prefix(bytes)?; + } + } + + FieldType::Pubkey(field_name) => { + quote! 
{ + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + Pubkey + >::from_prefix(bytes)?; + } + } + + // Types that are truly meta fields (shouldn't reach here for struct fields) + FieldType::CopyU8Bool(_) | FieldType::Copy(_, _) => { + quote! { + // Should not reach here for struct fields - these should be meta fields + } + } + + FieldType::OptionU64(field_name) => { + quote! { + let (#field_name, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( + bytes, + (config.#field_name, ()) + )?; + } + } + + FieldType::OptionU32(field_name) => { + quote! { + let (#field_name, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( + bytes, + (config.#field_name, ()) + )?; + } + } + + FieldType::OptionU16(field_name) => { + quote! { + let (#field_name, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( + bytes, + (config.#field_name, ()) + )?; + } + } + + FieldType::NonCopy(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'a>>::new_zero_copy( + bytes, + config.#field_name + )?; + } + } + } +} + +/// Generate byte length calculation for a field based on its configuration +pub fn generate_byte_len_calculation(field_type: &FieldType) -> TokenStream { + match field_type { + // Vec types that require configuration + FieldType::VecU8(field_name) => { + quote! { + (4 + config.#field_name as usize) // 4 bytes for length + actual data + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote! { + (4 + (config.#field_name as usize * core::mem::size_of::<#inner_type>())) + } + } + + FieldType::VecNonCopy(field_name, vec_type) => { + quote! { + <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + // Option types + FieldType::Option(field_name, option_type) => { + quote! { + <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + FieldType::OptionU64(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU32(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU16(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + // Fixed-size types don't need configuration and have known sizes + FieldType::IntegerU64(_) => { + quote! { + core::mem::size_of::() + } + } + + FieldType::IntegerU32(_) => { + quote! { + core::mem::size_of::() + } + } + + FieldType::IntegerU16(_) => { + quote! { + core::mem::size_of::() + } + } + + FieldType::IntegerU8(_) => { + quote! { + core::mem::size_of::() + } + } + + FieldType::Bool(_) => { + quote! { + core::mem::size_of::() // bool is serialized as u8 + } + } + + FieldType::Array(_, array_type) => { + quote! { + core::mem::size_of::<#array_type>() + } + } + + FieldType::Pubkey(_) => { + quote! { + 32 // Pubkey is always 32 bytes + } + } + + // Meta field types (should not appear in struct fields, but handle gracefully) + FieldType::CopyU8Bool(_) => { + quote! { + core::mem::size_of::() + } + } + + FieldType::Copy(_, field_type) => { + quote! { + core::mem::size_of::<#field_type>() + } + } + + FieldType::NonCopy(field_name, field_type) => { + quote! 
{ + <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + } +} diff --git a/program-libs/zero-copy-derive/src/deserialize_impl.rs b/program-libs/zero-copy-derive/src/deserialize_impl.rs new file mode 100644 index 0000000000..3a50e21024 --- /dev/null +++ b/program-libs/zero-copy-derive/src/deserialize_impl.rs @@ -0,0 +1,583 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_quote, Field, Ident}; + +use crate::{ + utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generates field deserialization code for the Deserialize implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_fields<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + let field_types = analyze_struct_fields(struct_fields); + + field_types.into_iter().map(move |field_type| { + let trait_path = if MUT { + quote!( as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut) + } else { + quote!( as light_zero_copy::borsh::Deserialize>::zero_copy_at) + }; + let mutability_tokens = if MUT { + quote!(&'a mut [u8]) + } else { + quote!(&'a [u8]) + }; + match field_type { + FieldType::VecU8(field_name) => { + if MUT { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh_mut::borsh_vec_u8_as_slice_mut(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh::borsh_vec_u8_as_slice(bytes)?; + } + } + }, + FieldType::VecCopy(field_name, inner_type) => { + let trait_path = if MUT { + quote!(light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<'a, <#inner_type as light_zero_copy::borsh_mut::ZeroCopyStructInnerMut>::ZeroCopyInnerMut>) + } else { + quote!(light_zero_copy::slice::ZeroCopySliceBorsh::<'a, <#inner_type as light_zero_copy::borsh::ZeroCopyStructInner>::ZeroCopyInner>) + }; + quote! { + let (#field_name, bytes) = #trait_path::from_bytes_at(bytes)?; + } + }, + FieldType::VecNonCopy(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type #trait_path(bytes)?; + } + }, + FieldType::Array(field_name, field_type) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_type>::from_prefix(bytes)?; + } + }, + FieldType::Option(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type #trait_path(bytes)?; + } + }, + FieldType::Pubkey(field_name) => { + quote! { + let (#field_name, bytes) = { + if MUT { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, u8>::from_prefix(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_ty_zerocopy>::from_prefix(bytes)?; + } + }, + FieldType::IntegerU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_ty_zerocopy>::from_prefix(bytes)?; + } + }, + FieldType::IntegerU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_ty_zerocopy>::from_prefix(bytes)?; + } + }, + FieldType::IntegerU8(field_name) => { + if MUT { + quote! 
{ + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, u8>::from_prefix(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(field_type); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_ty_zerocopy>::from_prefix(bytes)?; + } + }, + FieldType::NonCopy(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type #trait_path(bytes)?; + } + }, + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + quote! { + let (#field_name, bytes) = #trait_path(bytes)?; + } + }, + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + quote! { + let (#field_name, bytes) = #trait_path(bytes)?; + } + }, + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + quote! { + let (#field_name, bytes) = #trait_path(bytes)?; + } + } + } + }) +} + +/// Generates field initialization code for the Deserialize implementation +pub fn generate_init_fields<'a>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + struct_fields.iter().map(|field| { + let field_name = &field.ident; + quote! { #field_name } + }) +} + +/// Generates the Deserialize implementation as a TokenStream +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_impl( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_is_empty: bool, + byte_len_impl: TokenStream, +) -> TokenStream { + let mut z_struct_name = z_struct_name.clone(); + let mut z_struct_meta_name = z_struct_meta_name.clone(); + + // Define trait and types based on mutability + let (trait_name, mutability, method_name) = if MUT { + z_struct_name = format_ident!("{}Mut", z_struct_name); + z_struct_meta_name = format_ident!("{}Mut", z_struct_meta_name); + ( + quote!(light_zero_copy::borsh_mut::DeserializeMut), + quote!(mut), + quote!(zero_copy_at_mut), + ) + } else { + ( + quote!(light_zero_copy::borsh::Deserialize), + quote!(), + quote!(zero_copy_at), + ) + }; + let (meta_des, meta) = if meta_is_empty { + (quote!(), quote!()) + } else { + ( + quote! { + let (__meta, bytes) = light_zero_copy::Ref::< &'a #mutability [u8], #z_struct_meta_name>::from_prefix(bytes)?; + }, + quote!(__meta,), + ) + }; + let deserialize_fields = generate_deserialize_fields::(struct_fields); + let init_fields = generate_init_fields(struct_fields); + + quote! 
{ + impl<'a> #trait_name<'a> for #name { + type Output = #z_struct_name<'a>; + + fn #method_name(bytes: &'a #mutability [u8]) -> Result<(Self::Output, &'a #mutability [u8]), light_zero_copy::errors::ZeroCopyError> { + #meta_des + #(#deserialize_fields)* + Ok(( + #z_struct_name { + #meta + #(#init_fields,)* + }, + bytes + )) + } + + #byte_len_impl + } + } +} + +// #[cfg(test)] +// mod tests { +// use quote::format_ident; +// use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; +// use syn::parse_quote; + +// use super::*; + +// /// Generate a safe field name for testing +// fn random_ident(rng: &mut StdRng) -> String { +// // Use predetermined safe field names +// const FIELD_NAMES: &[&str] = &[ +// "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", +// "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", +// "status", +// ]; + +// FIELD_NAMES.choose(rng).unwrap().to_string() +// } + +// /// Generate a random Rust type +// fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { +// // Define our available types +// let types = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + +// // Randomly select a type index +// let selected = *types.choose(rng).unwrap(); + +// // Return the corresponding type +// match selected { +// 0 => parse_quote!(u8), +// 1 => parse_quote!(u16), +// 2 => parse_quote!(u32), +// 3 => parse_quote!(u64), +// 4 => parse_quote!(bool), +// 5 => parse_quote!(Vec), +// 6 => parse_quote!(Vec), +// 7 => parse_quote!(Vec), +// 8 => parse_quote!([u32; 12]), +// 9 => parse_quote!([Vec; 12]), +// 10 => parse_quote!([Vec; 20]), +// _ => unreachable!(), +// } +// } + +// /// Generate a random field +// fn random_field(rng: &mut StdRng) -> Field { +// let name = random_ident(rng); +// let ty = random_type(rng, 0); + +// // Use a safer approach to create the field +// let name_ident = format_ident!("{}", name); +// parse_quote!(pub #name_ident: #ty) +// } + +// /// Generate a list of random fields +// fn random_fields(rng: &mut StdRng, count: usize) -> Vec { +// (0..count).map(|_| random_field(rng)).collect() +// } + +// // Test for Vec field deserialization +// #[test] +// fn test_deserialize_vec_u8() { +// let field: Field = parse_quote!(pub data: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = +// "let (data , bytes) = light_zero_copy :: borsh :: borsh_vec_u8_as_slice (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Vec with Copy inner type deserialization +// #[test] +// fn test_deserialize_vec_copy_type() { +// let field: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u32 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Vec with non-Copy inner type deserialization +// #[test] +// fn test_deserialize_vec_non_copy_type() { +// // This is a synthetic test as we're treating String as a non-Copy type +// let field: Field = parse_quote!(pub names: Vec); +// let struct_fields = vec![&field]; + +// let result = 
generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (names , bytes) = < Vec < String > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Option type deserialization +// #[test] +// fn test_deserialize_option_type() { +// let field: Field = parse_quote!(pub maybe_value: Option); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (maybe_value , bytes) = < Option < u32 > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for non-Copy type deserialization +// #[test] +// fn test_deserialize_non_copy_type() { +// // Using String as a non-Copy type example +// let field: Field = parse_quote!(pub name: String); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (name , bytes) = < String as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Copy type deserialization (primitive types) +// #[test] +// fn test_deserialize_copy_type() { +// let field: Field = parse_quote!(pub count: u32); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (count , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?"; +// println!("{}", result_str); +// assert!(result_str.contains(expected)); +// } + +// // Test for boolean type deserialization +// #[test] +// fn test_deserialize_bool_type() { +// let field: Field = parse_quote!(pub flag: bool); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = +// "let (flag , bytes) = < u8 as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; +// println!("{}", result_str); +// assert!(result_str.contains(expected)); +// } + +// // Test for field initialization code generation +// #[test] +// fn test_init_fields() { +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub name: String); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_init_fields(&struct_fields).collect::>(); +// let result_str = format!("{} {}", result[0], result[1]); +// assert!(result_str.contains("id")); +// assert!(result_str.contains("name")); +// } + +// // Test for complete deserialize implementation generation +// #[test] +// fn test_generate_deserialize_impl() { +// let struct_name = format_ident!("TestStruct"); +// let z_struct_name = format_ident!("ZTestStruct"); +// let z_struct_meta_name = format_ident!("ZTestStructMeta"); + +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// false, +// ) +// .to_string(); + +// // Check impl header +// assert!(result +// 
.contains("impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for TestStruct")); + +// // Check Output type +// assert!(result.contains("type Output = ZTestStruct < 'a >")); + +// // Check method signature +// assert!(result.contains("fn zero_copy_at (bytes : & 'a [u8]) -> Result")); + +// // Check meta field extraction +// assert!(result.contains("let (__meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , ZTestStructMeta > :: from_prefix (bytes) ?")); + +// // Check field deserialization +// assert!(result.contains("let (id , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?")); +// assert!(result.contains("let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u16 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?")); + +// // Check result structure +// assert!(result.contains("Ok ((ZTestStruct { __meta , id , values ,")); +// } + +// // Test for complete deserialize implementation generation +// #[test] +// fn test_generate_deserialize_impl_no_meta() { +// let struct_name = format_ident!("TestStruct"); +// let z_struct_name = format_ident!("ZTestStruct"); +// let z_struct_meta_name = format_ident!("ZTestStructMeta"); + +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// true, +// ) +// .to_string(); + +// // Check impl header +// assert!(result +// .contains("impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for TestStruct")); + +// // Check Output type +// assert!(result.contains("type Output = ZTestStruct < 'a >")); + +// // Check method signature +// assert!(result.contains("fn zero_copy_at (bytes : & 'a [u8]) -> Result")); + +// // Check meta field extraction +// assert!(!result.contains("let (meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , ZTestStructMeta > :: from_prefix (bytes) ?")); + +// // Check field deserialization +// assert!(result.contains("let (id , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?")); +// assert!(result.contains("let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u16 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?")); + +// // Check result structure +// assert!(result.contains("Ok ((ZTestStruct { id , values ,")); +// } + +// #[test] +// fn test_fuzz_generate_deserialize_impl() { +// // Set up RNG with a seed for reproducibility +// let seed = thread_rng().gen(); +// println!("seed {}", seed); +// let mut rng = StdRng::seed_from_u64(seed); + +// // Number of iterations for the test +// let num_iters = 10000; + +// for i in 0..num_iters { +// // Generate a random struct name +// let struct_name = format_ident!("{}", random_ident(&mut rng)); +// let z_struct_name = format_ident!("Z{}", struct_name); +// let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + +// // Generate random number of fields (1-10) +// let field_count = rng.gen_range(1..11); +// let fields = random_fields(&mut rng, field_count); + +// // Create a named fields collection +// let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); +// let fields_named = syn::FieldsNamed { +// brace_token: 
syn::token::Brace::default(), +// named: syn_fields, +// }; + +// // Split into meta fields and struct fields +// let (_, struct_fields) = crate::utils::process_fields(&fields_named); + +// // Call the function we're testing +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// false, +// ); + +// // Get the generated code as a string for validation +// let result_str = result.to_string(); + +// // Print the first result for debugging +// if i == 0 { +// println!("Generated deserialize_impl code format:\n{}", result_str); +// } + +// // Verify the result contains expected elements +// // Basic validation - must be non-empty +// assert!( +// !result_str.is_empty(), +// "Failed to generate TokenStream for iteration {}", +// i +// ); + +// // Validate that the generated code contains the expected impl definition +// let impl_pattern = format!( +// "impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for {}", +// struct_name +// ); +// assert!( +// result_str.contains(&impl_pattern), +// "Generated code missing impl definition for iteration {}. Expected: {}", +// i, +// impl_pattern +// ); + +// // Validate type Output is defined +// let output_pattern = format!("type Output = {} < 'a >", z_struct_name); +// assert!( +// result_str.contains(&output_pattern), +// "Generated code missing Output type for iteration {}. Expected: {}", +// i, +// output_pattern +// ); + +// // Validate the zero_copy_at method is present +// assert!( +// result_str.contains("fn zero_copy_at (bytes : & 'a [u8])"), +// "Generated code missing zero_copy_at method for iteration {}", +// i +// ); + +// // Check for meta field extraction +// let meta_extraction_pattern = format!( +// "let (__meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , {} > :: from_prefix (bytes) ?", +// z_struct_meta_name +// ); +// assert!( +// result_str.contains(&meta_extraction_pattern), +// "Generated code missing meta field extraction for iteration {}", +// i +// ); + +// // Check for return with Ok pattern +// assert!( +// result_str.contains("Ok (("), +// "Generated code missing Ok return statement for iteration {}", +// i +// ); + +// // Check for the struct initialization +// let struct_init_pattern = format!("{} {{", z_struct_name); +// assert!( +// result_str.contains(&struct_init_pattern), +// "Generated code missing struct initialization for iteration {}", +// i +// ); + +// // Check for meta field in the returned struct +// assert!( +// result_str.contains("__meta ,"), +// "Generated code missing meta field in struct initialization for iteration {}", +// i +// ); +// } +// } +// } diff --git a/program-libs/zero-copy-derive/src/from_impl.rs b/program-libs/zero-copy-derive/src/from_impl.rs new file mode 100644 index 0000000000..6aa22478e4 --- /dev/null +++ b/program-libs/zero-copy-derive/src/from_impl.rs @@ -0,0 +1,254 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Field, Ident, Type}; + +use crate::z_struct::{analyze_struct_fields, FieldType}; + +/// Generates code for the From> for StructName implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_from_impl( + name: &Ident, + z_struct_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> TokenStream { + let mut z_struct_name = z_struct_name.clone(); + + if MUT { + z_struct_name = format_ident!("{}Mut", z_struct_name); + } + + let _z_struct_meta_name = if MUT { + 
format_ident!("{}MetaMut", z_struct_name) + } else { + format_ident!("{}Meta", z_struct_name) + }; + + // Generate the conversion code for meta fields + let meta_field_conversions = if !meta_fields.is_empty() { + let field_types = analyze_struct_fields(meta_fields); + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::IntegerU64(field_name) => { + quote! { #field_name: u64::from(value.__meta.#field_name), } + } + FieldType::IntegerU32(field_name) => { + quote! { #field_name: u32::from(value.__meta.#field_name), } + } + FieldType::IntegerU16(field_name) => { + quote! { #field_name: u16::from(value.__meta.#field_name), } + } + FieldType::IntegerU8(field_name) => { + quote! { #field_name: value.__meta.#field_name, } + } + FieldType::Bool(field_name) => { + quote! { #field_name: value.__meta.#field_name > 0, } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: value.__meta.#field_name, } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: value.__meta.#field_name, } + } + _ => { + let field_name = field_type.name(); + quote! { #field_name: value.__meta.#field_name.into(), } + } + } + }); + conversions.collect::>() + } else { + vec![] + }; + + // Generate the conversion code for struct fields + let struct_field_conversions = if !struct_fields.is_empty() { + let field_types = analyze_struct_fields(struct_fields); + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { #field_name: value.#field_name.to_vec(), } + } + FieldType::VecCopy(field_name, _) => { + quote! { #field_name: value.#field_name.to_vec(), } + } + FieldType::VecNonCopy(field_name, _) => { + // For non-copy vectors, clone each element directly + // We need to convert into() for Zstructs + quote! { + #field_name: { + value.#field_name.iter().map(|item| (*item).clone().into()).collect() + }, + } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: *value.#field_name, } + } + FieldType::Option(field_name, field_type) => { + fn extract_inner(s: &str) -> Option<&str> { + s.strip_prefix("Option <")?.strip_suffix(">") + } + use quote::ToTokens; + let string = field_type.to_token_stream().to_string(); + println!("option string {}", string); + let cleaned_type = extract_inner(&string).unwrap(); + let field_type = syn::parse_str::(cleaned_type).unwrap(); + // For Option types, use a direct copy of the value when possible + quote! { + #field_name: if value.#field_name.is_some() { + // Create a clone of the Some value - for compressed proofs and other structs + // For instruction_data.rs, we just need to clone the value directly + Some((#field_type::from(*value.#field_name.as_ref().unwrap()).clone())) + } else { + None + }, + } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: *value.#field_name, } + } + FieldType::Bool(field_name) => { + if MUT { + quote! { #field_name: *value.#field_name > 0, } + } else { + quote! { #field_name: value.#field_name > 0, } + } + } + FieldType::CopyU8Bool(field_name) => { + quote! { #field_name: value.#field_name > 0, } + } + FieldType::IntegerU64(field_name) => { + quote! { #field_name: u64::from(*value.#field_name), } + } + FieldType::IntegerU32(field_name) => { + quote! { #field_name: u32::from(*value.#field_name), } + } + FieldType::IntegerU16(field_name) => { + quote! 
{ #field_name: u16::from(*value.#field_name), } + } + FieldType::IntegerU8(field_name) => { + if MUT { + quote! { #field_name: *value.#field_name, } + } else { + quote! { #field_name: value.#field_name, } + } + } + FieldType::Copy(field_name, _) => { + quote! { #field_name: value.#field_name, } + } + FieldType::OptionU64(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u64::from(**x)), } + } + FieldType::OptionU32(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u32::from(**x)), } + } + FieldType::OptionU16(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u16::from(**x)), } + } + FieldType::NonCopy(field_name, field_type) => { + // For complex non-copy types, dereference and clone directly + quote! { #field_name: #field_type::from(&value.#field_name), } + } + } + }); + conversions.collect::>() + } else { + vec![] + }; + + // Combine all the field conversions + let all_field_conversions = [meta_field_conversions, struct_field_conversions].concat(); + + // Return the final From implementation without generic From implementations + quote! { + impl<'a> From<#z_struct_name<'a>> for #name { + fn from(value: #z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + + impl<'a> From<&#z_struct_name<'a>> for #name { + fn from(value: &#z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + } +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use syn::{parse_quote, Field}; + + use super::*; + + #[test] + fn test_generate_from_impl() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: u16); + let field_c: Field = parse_quote!(pub c: Vec); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation + let result = + generate_from_impl::(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.to_string(); + + // Check that the implementation contains required elements + println!("Generated code: {}", result_str); + assert!(result_str.contains("impl < 'a > From < ZTestStruct < 'a >> for TestStruct")); + + // Check field handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For u16 fields + assert!(result_str.contains("c :")); // For Vec fields + } + + #[test] + fn test_generate_from_impl_mut() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: bool); + let field_c: Field = parse_quote!(pub c: Option); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation for mutable version + let result = + generate_from_impl::(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.to_string(); + + // Check that the implementation contains required elements + println!("Generated code (mut): {}", result_str); + assert!(result_str.contains("impl < 'a > From < ZTestStructMut < 'a >> for TestStruct")); + + // Check field 
handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For bool fields + assert!(result_str.contains("c :")); // For Option fields + } +} diff --git a/program-libs/zero-copy-derive/src/lib.rs b/program-libs/zero-copy-derive/src/lib.rs new file mode 100644 index 0000000000..4a7bac1860 --- /dev/null +++ b/program-libs/zero-copy-derive/src/lib.rs @@ -0,0 +1,1337 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, DeriveInput}; + +// mod byte_len; +// mod byte_len_derive; +mod config; +mod deserialize_impl; +mod from_impl; +mod meta_struct; +mod partial_eq_impl; +mod utils; +mod z_struct; +mod zero_copy_struct_inner; + +/// ZeroCopy derivation macro for zero-copy deserialization +/// +/// # Usage +/// +/// Basic usage: +/// no_rust''' +/// #[derive(ZeroCopy)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ''' +/// +/// To derive PartialEq as well, use ZeroCopyEq in addition to ZeroCopy: +/// no_rust''' +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ''' +/// +/// # Macro Rules +/// 1. Create zero copy structs Z and ZMut for the struct +/// 1.1. The first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy +/// 1.2. Represent vectors to ZeroCopySlice & don't include these into the meta struct +/// 1.3. Replace u16 with U16, u32 with U32, etc +/// 1.4. Every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +/// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +/// 1.6. Elements in an Option must implement Deserialize +/// 1.7. A type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 +/// 1.8. is u8 deserialized as u8::zero_copy_at instead of Ref<&'a [u8], u8> for non mut, for mut it is Ref<&'a mut [u8], u8> +/// 2. Implement Deserialize and DeserializeMut which return Z and ZMut +/// 3. Implement From> for StructName and FromMut> for StructName +/// +/// TODOs: +/// 1. test and fix boolean support for mut derivation (is just represented as u8) +/// 2. add more tests in particular for mut derivation +/// 3. rename deserialize traits to ZeroCopy and ZeroCopyMut +/// 4. check generated code by hand +/// 5. 
fix partial eq generation for options +#[proc_macro_derive(ZeroCopy)] +pub fn derive_zero_copy(input: TokenStream) -> TokenStream { + // Parse the input DeriveInput + let input = parse_macro_input!(input as DeriveInput); + + // // Check for both the poseidon_hasher attribute and LightHasher in derive + // let hasher = input.attrs.iter().any(|attr| { + // if attr.path().is_ident("poseidon") { + // return true; + // } + // false + // }); + let hasher = false; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input); + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + // let hasher = false; + // Generate each implementation part using the respective modules + // let meta_struct_def_mut = + // meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher); + let meta_struct_def = + meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher); + + // let z_struct_def_mut = z_struct::generate_z_struct::( + // &z_struct_name, + // &z_struct_meta_name, + // &struct_fields, + // &meta_fields, + // hasher, + // ); + let z_struct_def = z_struct::generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + ); + + // let zero_copy_struct_inner_impl_mut = + // // For mutable version, we use the Mut suffix for the ZeroCopyInner type + // zero_copy_struct_inner::generate_zero_copy_struct_inner::( + // name, + // &format_ident!("{}Mut", z_struct_name), + // ); + let zero_copy_struct_inner_impl = + zero_copy_struct_inner::generate_zero_copy_struct_inner::(name, &z_struct_name); + + // let _byte_len_impl = byte_len::generate_byte_len_impl(name, &meta_fields, &struct_fields); + + // let deserialize_impl_mut = deserialize_impl::generate_deserialize_impl::( + // name, + // &z_struct_name, + // &z_struct_meta_name, + // &struct_fields, + // meta_fields.is_empty(), + // quote! {}, + // ); + + let deserialize_impl = deserialize_impl::generate_deserialize_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! {}, + ); + + // Combine all implementations + let expanded = quote! { + #meta_struct_def + + // #meta_struct_def_mut + + #z_struct_def + + // #z_struct_def_mut + + #zero_copy_struct_inner_impl + + // #zero_copy_struct_inner_impl_mut + + #deserialize_impl + + // #deserialize_impl_mut + + // Don't derive byte_len for non-mut derivations + // impl #name { + // #byte_len_impl + // } + + }; + + // For testing, we could add assertions here to verify the output + TokenStream::from(expanded) +} + +/// ZeroCopyEq implementation to add PartialEq for zero-copy structs. +/// +/// Use this in addition to ZeroCopy when you want the generated struct to implement PartialEq: +/// +/// no_rust``` +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[proc_macro_derive(ZeroCopyEq)] +pub fn derive_zero_copy_eq(input: TokenStream) -> TokenStream { + // Parse the input DeriveInput + let input = parse_macro_input!(input as DeriveInput); + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input); + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Generate the PartialEq implementation. 
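+    // The generated PartialEq compares the zero-copy view (Z-struct) against the original
+    // struct, which is what allows `assert_eq!(zero_copy, my_struct)` in the README example.
+    // Note: per TODO 5 above, the comparison generated for Option fields is not yet robust.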
+ let partial_eq_impl = partial_eq_impl::generate_partial_eq_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &meta_fields, + &struct_fields, + ); + // Generate From implementations + let from_impl = + from_impl::generate_from_impl::(name, &z_struct_name, &meta_fields, &struct_fields); + // let from_impl_mut = + // from_impl::generate_from_impl::(name, &z_struct_name, &meta_fields, &struct_fields); + + let _z_struct_name = format_ident!("{}Mut", z_struct_name); + let _z_struct_meta_name = format_ident!("{}Mut", z_struct_meta_name); + // let mut_partial_eq_impl = partial_eq_impl::generate_partial_eq_impl( + // name, + // &z_struct_name, + // &z_struct_meta_name, + // &meta_fields, + // &struct_fields, + // ); + + TokenStream::from(quote! { + #partial_eq_impl + // #mut_partial_eq_impl + + + #from_impl + + // #from_impl_mut + }) +} + +/// ZeroCopyMut derivation macro for mutable zero-copy deserialization +/// +/// This macro generates mutable zero-copy implementations including: +/// - DeserializeMut trait implementation +/// - Mutable Z-struct with `Mut` suffix +/// - byte_len() method implementation +/// - Mutable ZeroCopyStructInner implementation +/// +/// # Usage +/// +/// ```rust +/// use light_zero_copy_derive::ZeroCopyMut; +/// +/// #[derive(ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// pub vec: Vec, +/// } +/// ``` +/// +/// This will generate: +/// - `MyStruct::zero_copy_at_mut()` method +/// - `ZMyStructMut<'a>` type for mutable zero-copy access +/// - `MyStruct::byte_len()` method +/// +/// For both immutable and mutable functionality, use both derives: +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; +/// +/// #[derive(ZeroCopy, ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[cfg(feature = "mut")] +#[proc_macro_derive(ZeroCopyMut)] +pub fn derive_zero_copy_mut(input: TokenStream) -> TokenStream { + // Parse the input DeriveInput + let input = parse_macro_input!(input as DeriveInput); + + let hasher = false; // Keep consistent with ZeroCopy macro + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input); + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Generate mutable-specific implementations + let meta_struct_def_mut = + meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher); + + let z_struct_def_mut = z_struct::generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + ); + + let zero_copy_struct_inner_impl_mut = zero_copy_struct_inner::generate_zero_copy_struct_inner::< + true, + >(name, &format_ident!("{}Mut", z_struct_name)); + + let deserialize_impl_mut = deserialize_impl::generate_deserialize_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! {}, // No byte_len implementation in DeserializeMut anymore + ); + + // Combine all mutable implementations + let expanded = quote! 
{ + #meta_struct_def_mut + + #z_struct_def_mut + + #zero_copy_struct_inner_impl_mut + + #deserialize_impl_mut + }; + + TokenStream::from(expanded) +} + +// ByteLen derivation macro has been merged into ZeroCopyNew trait +// +// The ByteLen functionality is now available as a static method on ZeroCopyNew: +// ```rust +// use light_zero_copy::init_mut::ZeroCopyNew; +// +// // Calculate buffer size needed for configuration +// let config = MyStructConfig { /* ... */ }; +// let buffer_size = MyStruct::byte_len(&config); +// let mut buffer = vec![0u8; buffer_size]; +// ``` +// +// This provides more accurate sizing since it accounts for the actual configuration +// rather than just the current state of an existing struct instance. +// #[proc_macro_derive(ByteLen)] +// pub fn derive_byte_len(input: TokenStream) -> TokenStream { +// // Parse the input DeriveInput +// let input = parse_macro_input!(input as DeriveInput); + +// // Process the input to extract struct information +// let (name, _z_struct_name, _z_struct_meta_name, fields) = utils::process_input(&input); + +// // Process the fields to separate meta fields and struct fields +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Generate ByteLen implementation +// let byte_len_impl = +// byte_len_derive::generate_byte_len_derive_impl(&name, &meta_fields, &struct_fields); + +// TokenStream::from(byte_len_impl) +// } + +/// ZeroCopyConfig derivation macro for configuration-based zero-copy initialization +/// +/// This macro generates configuration structs and initialization methods for structs +/// with Vec and Option fields that need to be initialized with specific configurations. +/// +/// # Usage +/// +/// ```ignore +/// use light_zero_copy_derive::ZeroCopyConfig; +/// +/// #[derive(ZeroCopyConfig)] +/// pub struct MyStruct { +/// pub a: u8, +/// pub vec: Vec, +/// pub option: Option, +/// } +/// ``` +/// +/// This will generate: +/// - `MyStructConfig` struct with configuration fields +/// - `ZeroCopyNew` implementation for `MyStruct` +/// - `new_zero_copy(bytes, config)` method for initialization +/// +/// The configuration struct will have fields based on the complexity of the original fields: +/// - `Vec` → `field_name: u32` (length) +/// - `Option` → `field_name: bool` (is_some) +/// - `Vec` → `field_name: Vec` (config per element) +/// - `Option` → `field_name: Option` (config if some) +#[proc_macro_derive(ZeroCopyConfig)] +pub fn derive_zero_copy_config(input: TokenStream) -> TokenStream { + // Parse the input DeriveInput + let input = parse_macro_input!(input as DeriveInput); + + // Process the input to extract struct information + let (name, _z_struct_name, _z_struct_meta_name, fields) = utils::process_input(&input); + + // Use the same field processing logic as other derive macros for consistency + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Process ALL fields uniformly by type (no position dependency for config generation) + let all_fields: Vec<&syn::Field> = meta_fields + .iter() + .chain(struct_fields.iter()) + .cloned() + .collect(); + let all_field_types = z_struct::analyze_struct_fields(&all_fields); + + // Generate configuration struct based on all fields that need config (type-based) + let config_struct = config::generate_config_struct(name, &all_field_types); + + // Generate ZeroCopyNew implementation using the existing field separation + let init_mut_impl = generate_init_mut_impl(name, &meta_fields, &struct_fields); + + let expanded = quote! 
{ + #config_struct + + #init_mut_impl + }; + + TokenStream::from(expanded) +} + +/// Generate ZeroCopyNew implementation with new_at method for a struct +fn generate_init_mut_impl( + struct_name: &syn::Ident, + _meta_fields: &[&syn::Field], + struct_fields: &[&syn::Field], +) -> proc_macro2::TokenStream { + let config_name = quote::format_ident!("{}Config", struct_name); + let z_meta_name = quote::format_ident!("Z{}MetaMut", struct_name); + let z_struct_mut_name = quote::format_ident!("Z{}Mut", struct_name); + + // Use the pre-separated fields from utils::process_fields (consistent with other derives) + let struct_field_types = z_struct::analyze_struct_fields(struct_fields); + + // Generate field initialization code for struct fields only (meta fields are part of __meta) + let field_initializations: Vec = struct_field_types + .iter() + .map(|field_type| config::generate_field_initialization(field_type)) + .collect(); + + // Generate struct construction - only include struct fields that were initialized + // Meta fields are accessed via __meta.field_name in the generated ZStruct + let struct_field_names: Vec = struct_field_types + .iter() + .map(|field_type| { + let field_name = field_type.name(); + quote! { #field_name, } + }) + .collect(); + + // Check if there are meta fields to determine whether to include __meta + let has_meta_fields = !_meta_fields.is_empty(); + + let meta_initialization = if has_meta_fields { + quote! { + // Handle the meta struct (fixed-size fields at the beginning) + let (__meta, bytes) = Ref::<&mut [u8], #z_meta_name>::from_prefix(bytes)?; + } + } else { + quote! { + // No meta fields, skip meta struct initialization + } + }; + + let struct_construction = if has_meta_fields { + quote! { + let result = #z_struct_mut_name { + __meta, + #(#struct_field_names)* + }; + } + } else { + quote! { + let result = #z_struct_mut_name { + #(#struct_field_names)* + }; + } + }; + + // Generate byte_len calculation for each field type + let byte_len_calculations: Vec = struct_field_types + .iter() + .map(|field_type| config::generate_byte_len_calculation(field_type)) + .collect(); + + // Calculate meta size if there are meta fields + let meta_size_calculation = if has_meta_fields { + quote! { + core::mem::size_of::<#z_meta_name>() + } + } else { + quote! { 0 } + }; + + quote! { + impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for #struct_name { + type Config = #config_name; + type Output = >::Output; + + fn byte_len(config: &Self::Config) -> usize { + #meta_size_calculation #(+ #byte_len_calculations)* + } + + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { + use zerocopy::Ref; + + #meta_initialization + + #(#field_initializations)* + + // Construct the final struct + #struct_construction + + Ok((result, bytes)) + } + } + } +} + +// #[cfg(test)] +// mod tests { +// use quote::{format_ident, quote}; +// use syn::{parse_quote, DeriveInput, Field}; + +// use super::*; +// use crate::utils::process_input; + +// // Test case setup struct for easier management of field definitions and expected results +// struct TestCase { +// name: &'static str, +// fields: Vec, +// expected_meta_fields: usize, +// expected_struct_fields: usize, +// assertions: Vec<(&'static str, bool)>, // pattern, should_contain +// } + +// // Basic test for the From implementation +// #[test] +// fn test_from_implementation() { +// // Create a simple struct for testing +// let input: DeriveInput = parse_quote! 
{ +// #[repr(C)] +// #[derive(Debug, PartialEq)] +// pub struct SimpleStruct { +// pub a: u8, +// pub b: u16, +// pub vec: Vec, +// pub c: u64, +// } +// }; + +// // Process the input to extract struct information +// let (name, z_struct_name, _z_struct_meta_name, fields) = utils::process_input(&input); + +// // Process the fields to separate meta fields and struct fields +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Generate the From implementation +// let from_impl = from_impl::generate_from_impl::( +// name, +// &z_struct_name, +// &meta_fields, +// &struct_fields, +// ); + +// // Generate the mut From implementation +// let from_impl_mut = from_impl::generate_from_impl::( +// name, +// &z_struct_name, +// &meta_fields, +// &struct_fields, +// ); + +// // Convert to string for validation +// let from_impl_str = from_impl.to_string(); +// let from_impl_mut_str = from_impl_mut.to_string(); + +// // Check that the implementations are generated correctly +// assert!(from_impl_str.contains("impl < 'a > From < ZSimpleStruct < 'a >> for SimpleStruct")); +// assert!(from_impl_mut_str +// .contains("impl < 'a > From < ZSimpleStructMut < 'a >> for SimpleStruct")); + +// // Check field handling for both implementations +// assert!(from_impl_str.contains("a :")); +// assert!(from_impl_str.contains("b :")); +// assert!(from_impl_str.contains("vec :")); +// assert!(from_impl_str.contains("c :")); + +// assert!(from_impl_mut_str.contains("a :")); +// assert!(from_impl_mut_str.contains("b :")); +// assert!(from_impl_mut_str.contains("vec :")); +// assert!(from_impl_mut_str.contains("c :")); +// } + +// #[test] +// fn test_simple_struct_generation() { +// // Create a simple struct for testing +// let input: DeriveInput = parse_quote! { +// #[repr(C)] +// #[derive(Debug, PartialEq)] +// pub struct TestStruct { +// pub a: u8, +// pub b: u16, +// } +// }; + +// // Process the input using our utility function +// let (name, z_struct_name, z_struct_meta_name, fields) = process_input(&input); + +// // Run the function that processes the fields +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Check that the names are correct +// assert_eq!(name.to_string(), "TestStruct"); +// assert_eq!(z_struct_name.to_string(), "ZTestStruct"); +// assert_eq!(z_struct_meta_name.to_string(), "ZTestStructMeta"); + +// // Check that fields are correctly identified +// assert_eq!(meta_fields.len(), 2); +// assert_eq!(struct_fields.len(), 0); + +// assert_eq!(meta_fields[0].ident.as_ref().unwrap().to_string(), "a"); +// assert_eq!(meta_fields[1].ident.as_ref().unwrap().to_string(), "b"); +// } + +// #[test] +// fn test_compressed_account_struct() { +// // No need to mock Pubkey, parse_quote handles it + +// // Define CompressedAccountData struct first (used within CompressedAccount) +// let compressed_account_data_input: DeriveInput = parse_quote! { +// #[repr(C)] +// pub struct CompressedAccountData { +// pub discriminator: [u8; 8], +// pub data: Vec, +// pub data_hash: [u8; 32], +// } +// }; + +// // Define CompressedAccount struct with the complex fields +// let compressed_account_input: DeriveInput = parse_quote! 
{ +// #[repr(C)] +// pub struct CompressedAccount { +// pub owner: Pubkey, +// pub lamports: u64, +// pub address: Option<[u8; 32]>, +// pub data: Option, +// } +// }; + +// // Process CompressedAccountData first +// let (_, _, _, fields) = process_input(&compressed_account_data_input); + +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Verify CompressedAccountData field splitting +// // discriminator ([u8; 8]) is a Copy type, so it should be in meta_fields +// assert_eq!(meta_fields.len(), 1); +// assert_eq!(struct_fields.len(), 2); // Vec and [u8; 32] are in struct_fields + +// // Process CompressedAccount +// let (name, z_struct_name, z_struct_meta_name, fields) = +// process_input(&compressed_account_input); + +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Check struct naming +// assert_eq!(name.to_string(), "CompressedAccount"); +// assert_eq!(z_struct_name.to_string(), "ZCompressedAccount"); +// assert_eq!(z_struct_meta_name.to_string(), "ZCompressedAccountMeta"); + +// // Check field splitting +// // Since we added Pubkey as a Copy type, owner should be in meta_fields +// // And all other fields should be in struct_fields due to field ordering rules +// assert_eq!(meta_fields.len(), 2); +// assert_eq!(struct_fields.len(), 2); + +// // Check struct fields are correctly identified +// assert_eq!(meta_fields[0].ident.as_ref().unwrap().to_string(), "owner"); +// assert_eq!( +// meta_fields[1].ident.as_ref().unwrap().to_string(), +// "lamports" +// ); +// assert_eq!( +// struct_fields[0].ident.as_ref().unwrap().to_string(), +// "address" +// ); +// assert_eq!(struct_fields[1].ident.as_ref().unwrap().to_string(), "data"); + +// // Generate full implementation to verify - use internal functions directly instead of proc macro +// let (name, z_struct_name, z_struct_meta_name, fields) = +// process_input(&compressed_account_input); +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Generate each implementation part using the respective modules +// let meta_struct_def = +// meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, false); +// let z_struct_def = z_struct::generate_z_struct::( +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// &meta_fields, +// false, +// ); +// let zero_copy_struct_inner_impl = +// zero_copy_struct_inner::generate_zero_copy_struct_inner::(name, &z_struct_name); +// let deserialize_impl = deserialize_impl::generate_deserialize_impl::( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// meta_fields.is_empty(), +// ); +// // let partial_eq_impl = partial_eq_impl::generate_partial_eq_impl( +// // name, +// // &z_struct_name, +// // &z_struct_meta_name, +// // &meta_fields, +// // &struct_fields, +// // ); + +// // Combine all implementations +// let expanded = quote! 
{ +// #meta_struct_def +// #z_struct_def +// #zero_copy_struct_inner_impl +// #deserialize_impl +// // #partial_eq_impl +// }; + +// let result = expanded.to_string(); + +// // Create a standardized format for comparison by removing whitespace and normalizing syntax +// fn normalize_code(code: &str) -> String { +// code.chars() +// .filter(|c| !c.is_whitespace()) +// .collect::() +// } + +// // Get the normalized actual result first +// let normalized_result = normalize_code(&result); +// println!("Generated code normalized:\n{}", normalized_result); +// // Print the generated code for debugging purposes +// println!( +// "Generated code normalized:\n{}", +// String::from_utf8(rustfmt(result)).unwrap() +// ); + +// // Directly verify key structural elements instead of doing a full string comparison +// assert!(normalized_result.contains("pubstructZCompressedAccountMeta{")); +// assert!(normalized_result.contains("pubstructZCompressedAccount<'a>")); +// assert!( +// normalized_result.contains("meta:light_zero_copy::Ref<&'a[u8],ZCompressedAccountMeta>") +// ); +// assert!(normalized_result.contains("pubowner:")); +// assert!(normalized_result.contains("publamports:")); +// assert!(normalized_result.contains("pubaddress:")); +// assert!(normalized_result.contains("pubdata:")); +// assert!(normalized_result +// .contains("impllight_zero_copy::borsh::DeserializeforCompressedAccount")); +// assert!(normalized_result.contains("typeZeroCopyInner=ZCompressedAccount<'static>")); +// } + +// use std::{ +// env, +// io::{self, prelude::*}, +// process::{Command, Stdio}, +// thread::spawn, +// }; +// pub fn rustfmt(code: String) -> Vec { +// let mut cmd = match env::var_os("RUSTFMT") { +// Some(r) => Command::new(r), +// _ => Command::new("rustfmt"), +// }; + +// let mut cmd = cmd +// .stdin(Stdio::piped()) +// .stdout(Stdio::piped()) +// .stderr(Stdio::piped()) +// .spawn() +// .unwrap(); + +// let mut stdin = cmd.stdin.take().unwrap(); +// let mut stdout = cmd.stdout.take().unwrap(); + +// let stdin_handle = spawn(move || { +// stdin.write_all(code.as_bytes()).unwrap(); +// }); + +// let mut formatted_code = vec![]; +// io::copy(&mut stdout, &mut formatted_code).unwrap(); + +// let _ = cmd.wait(); +// stdin_handle.join().unwrap(); +// formatted_code +// } +// #[test] +// fn test_empty_struct() { +// // Create an empty struct for testing +// let input: DeriveInput = parse_quote! { +// #[repr(C)] +// pub struct EmptyStruct {} +// }; + +// // Process the input +// let (name, z_struct_name, z_struct_meta_name, fields) = process_input(&input); + +// // Split into meta fields and struct fields +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Generate each implementation part +// let meta_struct_def = +// meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, false); +// let z_struct_def = z_struct::generate_z_struct::( +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// &meta_fields, +// false, +// ); +// let zero_copy_struct_inner_impl = +// zero_copy_struct_inner::generate_zero_copy_struct_inner::(name, &z_struct_name); +// let deserialize_impl = deserialize_impl::generate_deserialize_impl::( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// meta_fields.is_empty(), +// ); +// let partial_eq_impl = partial_eq_impl::generate_partial_eq_impl( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &meta_fields, +// &struct_fields, +// ); + +// // Combine all implementations +// let expanded = quote! 
{ +// #meta_struct_def +// #z_struct_def +// #zero_copy_struct_inner_impl +// #deserialize_impl +// #partial_eq_impl +// }; + +// // Convert to string for validation +// let result = expanded.to_string(); + +// // Verify the output contains what we expect +// assert!(result.contains("struct ZEmptyStructMeta")); +// assert!(result.contains("struct ZEmptyStruct < 'a >")); +// assert!( +// result.contains("impl light_zero_copy :: borsh :: ZeroCopyStructInner for EmptyStruct") +// ); +// assert!(result.contains( +// "impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for EmptyStruct" +// )); +// } + +// #[test] +// fn test_struct_with_bool() { +// // Create a struct with bool fields for testing +// let input: DeriveInput = parse_quote! { +// #[repr(C)] +// pub struct BoolStruct { +// pub a: bool, +// pub b: u8, +// pub c: Vec, +// pub d: bool, +// } +// }; + +// // Process the input +// let (_, z_struct_name, z_struct_meta_name, fields) = process_input(&input); + +// // Split into meta fields and struct fields +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Check that fields are correctly identified +// assert_eq!(meta_fields.len(), 2); // 'a' and 'b' should be in meta_fields +// assert_eq!(struct_fields.len(), 2); // 'c' and 'd' should be in struct_fields + +// // Generate the implementation +// let meta_struct_def = +// meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, false); +// let z_struct_def = z_struct::generate_z_struct::( +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// &meta_fields, +// false, +// ); + +// // Check meta struct has bool converted to u8 +// let meta_struct_str = meta_struct_def.to_string(); +// println!("meta_struct_str {}", meta_struct_str); +// assert!(meta_struct_str.contains("pub a : u8")); + +// // Check z_struct has methods for boolean fields +// let z_struct_str = z_struct_def.to_string(); +// println!("z_struct_str {}", z_struct_str); +// assert!(z_struct_str.contains("pub fn a (& self) -> bool {")); +// assert!(z_struct_str.contains("self . a > 0")); +// assert!(z_struct_str.contains("pub fn d (& self) -> bool {")); +// assert!(z_struct_str.contains("self . d > 0")); +// } + +// #[test] +// fn test_zero_copy_eq() { +// // Create a test input +// let input: DeriveInput = parse_quote! { +// #[repr(C)] +// pub struct TestStruct { +// pub a: u8, +// pub b: u16, +// } +// }; + +// // Process the input for ZeroCopy +// let (name, z_struct_name, z_struct_meta_name, fields) = process_input(&input); +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Generate code from ZeroCopy +// let meta_struct_def = +// meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, false); +// let z_struct_def = z_struct::generate_z_struct::( +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// &meta_fields, +// false, +// ); +// let zero_copy_struct_inner_impl = +// zero_copy_struct_inner::generate_zero_copy_struct_inner::(name, &z_struct_name); +// let deserialize_impl = deserialize_impl::generate_deserialize_impl::( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// meta_fields.is_empty(), +// ); + +// let zero_copy_expanded = quote! 
{ +// #meta_struct_def +// #z_struct_def +// #zero_copy_struct_inner_impl +// #deserialize_impl +// }; + +// // Generate code from ZeroCopyEq +// let partial_eq_impl = partial_eq_impl::generate_partial_eq_impl( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &meta_fields, +// &struct_fields, +// ); + +// // Verify ZeroCopy output doesn't include PartialEq +// let zero_copy_result = zero_copy_expanded.to_string(); +// assert!( +// !zero_copy_result.contains("impl < 'a > PartialEq < TestStruct >"), +// "ZeroCopy alone should not include PartialEq implementation" +// ); + +// // Verify ZeroCopyEq output is just the PartialEq implementation +// let zero_copy_eq_result = partial_eq_impl.to_string(); +// assert!( +// zero_copy_eq_result.contains("impl < 'a > PartialEq < TestStruct >"), +// "ZeroCopyEq should include PartialEq implementation" +// ); + +// // Verify that combining both gives us the complete implementation +// let combined = quote! { +// #zero_copy_expanded + +// #partial_eq_impl +// }; + +// let combined_result = combined.to_string(); +// assert!( +// combined_result.contains("impl < 'a > PartialEq < TestStruct >"), +// "Combining ZeroCopy and ZeroCopyEq should include PartialEq implementation" +// ); +// } + +// #[test] +// fn test_struct_with_vector() { +// // Create a struct with Vec field for testing +// let input: DeriveInput = parse_quote! { +// #[repr(C)] +// pub struct VecStruct { +// pub a: u8, +// pub b: Vec, +// pub c: u32, +// } +// }; + +// // Process the input +// let (_name, _z_struct_name, z_struct_meta_name, fields) = process_input(&input); + +// // Split into meta fields and struct fields +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Check that fields are correctly identified +// assert_eq!(meta_fields.len(), 1); // Only 'a' should be in meta_fields +// assert_eq!(struct_fields.len(), 2); // 'b' and 'c' should be in struct_fields + +// // The field names should be correct +// assert_eq!(meta_fields[0].ident.as_ref().unwrap().to_string(), "a"); +// assert_eq!(struct_fields[0].ident.as_ref().unwrap().to_string(), "b"); +// assert_eq!(struct_fields[1].ident.as_ref().unwrap().to_string(), "c"); + +// // Generate the implementation +// let meta_struct_def = +// meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, false); +// let result = meta_struct_def.to_string(); + +// // Verify the meta struct has the right fields +// assert!(result.contains("pub a : u8")); +// assert!(!result.contains("pub b : Vec < u8 >")); +// } + +// #[test] +// fn test_mutable_attribute() { +// // Create a simple struct with the mutable attribute +// let input: DeriveInput = parse_quote! 
{ +// #[derive(ZeroCopy)] +// #[zero_copy(mutable)] +// pub struct MutableStruct { +// pub a: u8, +// pub b: Vec, +// } +// }; + +// // Check for the mutable attribute +// let mut is_mutable = false; +// for attr in &input.attrs { +// if attr.path().is_ident("zero_copy") { +// let _ = attr.parse_nested_meta(|meta| { +// if meta.path.is_ident("mutable") { +// is_mutable = true; +// } +// Ok(()) +// }); +// } +// } + +// // Verify the mutable attribute is detected +// assert!(is_mutable, "Mutable attribute should be detected"); + +// // Process the input +// let (name, z_struct_name, z_struct_meta_name, fields) = process_input(&input); +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Generate the expanded code +// let meta_struct_def = +// meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, false); +// let z_struct_def = z_struct::generate_z_struct::( +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// &meta_fields, +// false, +// ); +// let zero_copy_struct_inner_impl = zero_copy_struct_inner::generate_zero_copy_struct_inner::< +// false, +// >(name, &format_ident!("{}Mut", z_struct_name)); +// let deserialize_impl = deserialize_impl::generate_deserialize_impl::( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// meta_fields.is_empty(), +// ); + +// // Combine all implementations +// let expanded = quote! { +// #meta_struct_def +// #z_struct_def +// #zero_copy_struct_inner_impl +// #deserialize_impl +// }; + +// let result = expanded.to_string(); + +// // Verify mutable-specific code generation +// println!("Generated code: {}", result); +// assert!( +// result.contains("ZMutableStructMut"), +// "Mutable implementation should add Mut suffix to type name" +// ); +// assert!( +// result.contains("light_zero_copy :: borsh_mut ::"), +// "Mutable implementation should use borsh_mut" +// ); +// assert!( +// result.contains("& 'a mut [u8]"), +// "Mutable implementation should use & 'a mut [u8]" +// ); +// assert!( +// result.contains("borsh_vec_u8_as_slice_mut"), +// "Mutable implementation should use borsh_vec_u8_as_slice_mut" +// ); +// } + +// #[test] +// fn test_derive_zero_copy_edge_cases() { +// // Define test cases covering edge cases based on the rules +// let test_cases = vec![ +// // Case 1: Empty struct +// TestCase { +// name: "EmptyStruct", +// fields: vec![], +// expected_meta_fields: 0, +// expected_struct_fields: 0, +// assertions: vec![ +// ("struct ZEmptyStructMeta { }", true), +// ("impl light_zero_copy :: borsh :: Deserialize for EmptyStruct", true), +// ], +// }, + +// // Case 2: All primitive Copy types +// TestCase { +// name: "AllPrimitives", +// fields: vec![ +// parse_quote!(pub a: u8), +// parse_quote!(pub b: u16), +// parse_quote!(pub c: u32), +// parse_quote!(pub d: u64), +// parse_quote!(pub e: bool), +// ], +// expected_meta_fields: 5, +// expected_struct_fields: 0, +// assertions: vec![ +// ("pub a : u8", true), +// ("pub b : light_zero_copy :: little_endian :: U16", true), // Rule 1.3: Replace u16 with U16 +// ("pub c : light_zero_copy :: little_endian :: U32", true), // Rule 1.3: Replace u32 with U32 +// ("pub d : light_zero_copy :: little_endian :: U64", true), // Rule 1.3: Replace u64 with U64 +// ("pub e : u8", true), +// ("meta : light_zero_copy :: Ref < & 'a [u8] , ZAllPrimitivesMeta >", true), +// ], +// }, + +// // Case 3: Vec at start (Rule 1.1) +// TestCase { +// name: "VecAtStart", +// fields: vec![ +// parse_quote!(pub data: Vec), +// parse_quote!(pub a: u8), 
+// parse_quote!(pub b: u16), +// ], +// expected_meta_fields: 0, +// expected_struct_fields: 3, +// assertions: vec![ +// ("pub data : & 'a [u8]", true), // Rule 1.2: Vec represented as slice +// ("struct ZVecAtStartMeta { }", true), // Empty meta struct +// ("let (data , bytes) = light_zero_copy :: borsh :: borsh_vec_u8_as_slice (bytes) ?", true), +// ], +// }, + +// // Case 4: Vec in middle (Rule 1.1, 1.4) +// TestCase { +// name: "VecInMiddle", +// fields: vec![ +// parse_quote!(pub a: u8), +// parse_quote!(pub b: u16), +// parse_quote!(pub data: Vec), // Split point +// parse_quote!(pub c: u32), +// parse_quote!(pub d: u64), +// ], +// expected_meta_fields: 2, +// expected_struct_fields: 3, +// assertions: vec![ +// ("struct ZVecInMiddleMeta { pub a : u8 , pub b : light_zero_copy :: little_endian :: U16 , }", true), +// ("pub data : & 'a [u8]", true), +// ("let (c , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?", true), +// ("let (d , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U64 > :: from_prefix (bytes) ?", true), +// ], +// }, + +// // Case 5: Mixed Vec types (Rules 1.5) +// TestCase { +// name: "MixedVecTypes", +// fields: vec![ +// parse_quote!(pub a: u8), +// parse_quote!(pub bytes: Vec), // Vec special case +// parse_quote!(pub numbers: Vec), // Vec with Copy type +// ], +// expected_meta_fields: 1, +// expected_struct_fields: 2, +// assertions: vec![ +// ("pub bytes : & 'a [u8]", true), // Rule 1.2: Vec as slice +// ("pub numbers : light_zero_copy :: slice :: ZeroCopySliceBorsh < 'a ,", true), // Using ZeroCopySliceBorsh for Copy types +// ("let (bytes , bytes) = light_zero_copy :: borsh :: borsh_vec_u8_as_slice (bytes) ?", true), +// ("let (numbers , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh", true), +// ], +// }, + +// // Case 6: Option type splitting boundary (Rule 1.6) +// TestCase { +// name: "OptionTypeStruct", +// fields: vec![ +// parse_quote!(pub a: u8), +// parse_quote!(pub b: Option), // Split point +// parse_quote!(pub c: u64), +// ], +// expected_meta_fields: 1, +// expected_struct_fields: 2, +// assertions: vec![ +// ("struct ZOptionTypeStructMeta { pub a : u8 , }", true), +// ("pub b : < Option < u32 > as light_zero_copy :: borsh :: Deserialize> :: Output< 'a >", true), +// ("let (b , bytes) = < Option < u32 > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?", true), +// ], +// }, + +// // Case 7: Arrays should be treated as Copy types +// TestCase { +// name: "ArrayTypes", +// fields: vec![ +// parse_quote!(pub a: [u8; 4]), +// parse_quote!(pub b: [u32; 2]), +// parse_quote!(pub c: Vec), // Split point +// ], +// expected_meta_fields: 2, +// expected_struct_fields: 1, +// assertions: vec![ +// // Just check for the existence of the array field types, not exact formatting +// ("pub a : [u8 ; 4]", true), +// ("pub b : [u32 ; 2]", true), // Arrays don't use zerocopy types +// ("pub c : & 'a [u8]", true), +// ], +// }, + +// // Case 8: Test field after Option (Rule 1.4) +// TestCase { +// name: "FieldsAfterNonCopy", +// fields: vec![ +// parse_quote!(pub a: u8), +// parse_quote!(pub opt: Option), // Split point +// parse_quote!(pub b: u16), // After non-Copy, should be in struct_fields +// parse_quote!(pub c: u32), +// ], +// expected_meta_fields: 1, +// expected_struct_fields: 3, +// assertions: vec![ +// ("let (opt , bytes) = < Option < u16 > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?", 
true), +// ("let (b , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U16 > :: from_prefix (bytes) ?", true), +// ("let (c , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?", true), +// ], +// }, +// ]; + +// // Run all test cases +// for (i, test_case) in test_cases.iter().enumerate() { +// println!("Testing case {}: {}", i, test_case.name); + +// // Create struct +// let struct_name = format_ident!("{}", test_case.name); +// let mut fields_punctuated = +// syn::punctuated::Punctuated::::new(); +// for field in &test_case.fields { +// fields_punctuated.push(field.clone()); +// } + +// let input = parse_quote! { +// #[repr(C)] +// pub struct #struct_name { +// #fields_punctuated +// } +// }; + +// // Process input +// let (name, z_struct_name, z_struct_meta_name, fields) = process_input(&input); +// let (meta_fields, struct_fields) = utils::process_fields(fields); + +// // Verify field counts +// assert_eq!( +// meta_fields.len(), +// test_case.expected_meta_fields, +// "Case {}: Expected {} meta fields, got {}", +// i, +// test_case.expected_meta_fields, +// meta_fields.len() +// ); +// assert_eq!( +// struct_fields.len(), +// test_case.expected_struct_fields, +// "Case {}: Expected {} struct fields, got {}", +// i, +// test_case.expected_struct_fields, +// struct_fields.len() +// ); + +// // Generate code +// let meta_struct_def = meta_struct::generate_meta_struct::( +// &z_struct_meta_name, +// &meta_fields, +// false, +// ); +// let z_struct_def = z_struct::generate_z_struct::( +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// &meta_fields, +// false, +// ); +// let zero_copy_struct_inner_impl = +// zero_copy_struct_inner::generate_zero_copy_struct_inner::( +// name, +// &z_struct_name, +// ); +// let deserialize_impl = deserialize_impl::generate_deserialize_impl::( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// meta_fields.is_empty(), +// ); +// let partial_eq_impl = if test_case.name != "OptionTypeStruct" { +// partial_eq_impl::generate_partial_eq_impl( +// name, +// &z_struct_name, +// &z_struct_meta_name, +// &meta_fields, +// &struct_fields, +// ) +// } else { +// quote! {} +// }; + +// // Combine all implementations +// let expanded = quote! 
{ +// #meta_struct_def +// #z_struct_def +// #zero_copy_struct_inner_impl +// #deserialize_impl +// #partial_eq_impl +// }; + +// // Convert to string for validation +// let result = expanded.to_string(); + +// // For debugging in case of a failure +// if false { +// // Only enable when debugging +// println!("Generated code sample for case {}: {:.500}...", i, result); +// } + +// // Verify assertions +// for (pattern, should_contain) in &test_case.assertions { +// let contains = result.contains(pattern); +// assert_eq!( +// contains, +// *should_contain, +// "Case {}: Expected '{}' to be {} in the generated code", +// i, +// pattern, +// if *should_contain { "present" } else { "absent" } +// ); +// } +// } +// } +// } diff --git a/program-libs/zero-copy-derive/src/meta_struct.rs b/program-libs/zero-copy-derive/src/meta_struct.rs new file mode 100644 index 0000000000..413cef3ff6 --- /dev/null +++ b/program-libs/zero-copy-derive/src/meta_struct.rs @@ -0,0 +1,56 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::Field; + +use crate::utils::convert_to_zerocopy_type; + +/// Generates the meta struct definition as a TokenStream +/// The `MUT` parameter determines if the struct should be generated for mutable access +pub fn generate_meta_struct( + z_struct_meta_name: &syn::Ident, + meta_fields: &[&Field], + hasher: bool, +) -> TokenStream { + let mut z_struct_meta_name = z_struct_meta_name.clone(); + if MUT { + z_struct_meta_name = format_ident!("{}Mut", z_struct_meta_name); + } + + // Generate the meta struct fields with converted types + let meta_fields_with_converted_types = meta_fields.iter().map(|field| { + let field_name = &field.ident; + let attributes = if hasher { + field + .attrs + .iter() + .map(|attr| { + let path = attr; + quote! { #path } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let field_type = convert_to_zerocopy_type(&field.ty); + quote! { + #(#attributes)* + pub #field_name: #field_type + } + }); + let hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + + // Return the complete meta struct definition + quote! { + #[repr(C)] + #[derive(Debug, PartialEq, light_zero_copy::KnownLayout, light_zero_copy::Immutable, light_zero_copy::Unaligned, light_zero_copy::FromBytes, light_zero_copy::IntoBytes #hasher)] + pub struct #z_struct_meta_name { + #(#meta_fields_with_converted_types,)* + } + } +} diff --git a/program-libs/zero-copy-derive/src/partial_eq_impl.rs b/program-libs/zero-copy-derive/src/partial_eq_impl.rs new file mode 100644 index 0000000000..11449d843c --- /dev/null +++ b/program-libs/zero-copy-derive/src/partial_eq_impl.rs @@ -0,0 +1,251 @@ +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{Field, Ident}; + +use crate::z_struct::{analyze_struct_fields, FieldType}; + +/// Generates meta field comparisons for PartialEq implementation +pub fn generate_meta_field_comparisons<'a>( + meta_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + let field_types = analyze_struct_fields(meta_fields); + + field_types.into_iter().map(|field_type| match field_type { + FieldType::IntegerU64(field_name) => { + quote! { + if other.#field_name != u64::from(meta.#field_name) as u64 { + return false; + } + } + } + FieldType::IntegerU32(field_name) => { + quote! { + if other.#field_name != u64::from(meta.#field_name) as u32 { + return false; + } + } + } + FieldType::IntegerU16(field_name) => { + quote! 
{ + if other.#field_name != u64::from(meta.#field_name) as u16 { + return false; + } + } + } + FieldType::IntegerU8(field_name) => { + quote! { + if other.#field_name != u64::from(meta.#field_name) as u8 { + return false; + } + } + } + FieldType::Bool(field_name) => { + quote! { + if other.#field_name != (meta.#field_name > 0) { + return false; + } + } + } + _ => { + let field_name = field_type.name(); + quote! { + if other.#field_name != meta.#field_name { + return false; + } + } + } + }) +} + +/// Generates struct field comparisons for PartialEq implementation +pub fn generate_struct_field_comparisons<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + let field_types = analyze_struct_fields(struct_fields); + if field_types + .iter() + .any(|x| matches!(x, FieldType::Option(_, _))) + { + unimplemented!("Options are not supported in ZeroCopyEq"); + } + + field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { + if self.#field_name != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecNonCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::Array(field_name, _) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::Option(field_name, field_type) => { + if field_type.to_token_stream().to_string() == "u8" { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if self.#field_name.as_ref().unwrap() != other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + // TODO: handle issue that structs need * == *, arrays need ** == * + // else if crate::utils::is_copy_type(field_type) { + // quote! { + // if self.#field_name.is_some() && other.#field_name.is_some() { + // if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + // return false; + // } + // } else if self.#field_name.is_some() || other.#field_name.is_some() { + // return false; + // } + // } + // } + else { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + + } + FieldType::Pubkey(field_name) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::IntegerU64(field_name) => { + quote! { + if u64::from(*self.#field_name) != other.#field_name { + return false; + } + } + } + FieldType::IntegerU32(field_name) => { + quote! { + if u32::from(*self.#field_name) != other.#field_name { + return false; + } + } + } + FieldType::IntegerU16(field_name) => { + quote! { + if u16::from(*self.#field_name) != other.#field_name { + return false; + } + } + } + FieldType::IntegerU8(field_name) => { + if MUT { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } else { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + } + } + FieldType::Bool(field_name) => { + if MUT { + quote! { + if (*self.#field_name > 0) != other.#field_name { + return false; + } + } + } else { + quote! 
{ + if (self.#field_name > 0) != other.#field_name { + return false; + } + } + } + } + FieldType::CopyU8Bool(field_name) + | FieldType::Copy(field_name, _) + | FieldType::NonCopy(field_name, _) => { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + }, + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + } + } + }) +} + +/// Generates the PartialEq implementation as a TokenStream +pub fn generate_partial_eq_impl( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> TokenStream { + let struct_field_comparisons = generate_struct_field_comparisons::(struct_fields); + if !meta_fields.is_empty() { + let meta_field_comparisons = generate_meta_field_comparisons(meta_fields); + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + let meta: &#z_struct_meta_name = &self.__meta; + #(#meta_field_comparisons)* + #(#struct_field_comparisons)* + true + } + } + } + } else { + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + #(#struct_field_comparisons)* + true + } + } + + } + } +} diff --git a/program-libs/zero-copy-derive/src/utils.rs b/program-libs/zero-copy-derive/src/utils.rs new file mode 100644 index 0000000000..f3db9d4eea --- /dev/null +++ b/program-libs/zero-copy-derive/src/utils.rs @@ -0,0 +1,408 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::{Attribute, Data, DeriveInput, Field, Fields, FieldsNamed, Ident, Type, TypePath}; + +// Global cache for storing whether a struct implements Copy +lazy_static::lazy_static! { + static ref COPY_IMPL_CACHE: Arc>> = Arc::new(Mutex::new(HashMap::new())); +} + +/// Process the derive input to extract the struct information +pub fn process_input( + input: &DeriveInput, +) -> ( + &Ident, // Original struct name + proc_macro2::Ident, // Z-struct name + proc_macro2::Ident, // Z-struct meta name + &FieldsNamed, // Struct fields +) { + let name = &input.ident; + let z_struct_name = format_ident!("Z{}", name); + let z_struct_meta_name = format_ident!("Z{}Meta", name); + + // Populate the cache by checking if this struct implements Copy + let _ = struct_implements_copy(input); + + let fields = match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => fields, + _ => panic!("ZStruct only supports structs with named fields"), + }, + _ => panic!("ZStruct only supports structs"), + }; + + (name, z_struct_name, z_struct_meta_name, fields) +} + +pub fn process_fields(fields: &FieldsNamed) -> (Vec<&Field>, Vec<&Field>) { + let mut meta_fields = Vec::new(); + let mut struct_fields = Vec::new(); + let mut reached_vec_or_option = false; + + for field in fields.named.iter() { + if !reached_vec_or_option { + if is_vec_or_option(&field.ty) || !is_copy_type(&field.ty) { + reached_vec_or_option = true; + struct_fields.push(field); + } else { + meta_fields.push(field); + } + } else { + struct_fields.push(field); + } + } + + (meta_fields, struct_fields) +} + +pub fn is_vec_or_option(ty: &Type) -> bool { + is_vec_type(ty) || is_option_type(ty) +} + +pub fn is_vec_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. 
}) = ty { + if let Some(segment) = path.segments.first() { + return segment.ident == "Vec"; + } + } + false +} + +pub fn is_option_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.first() { + return segment.ident == "Option"; + } + } + false +} + +pub fn get_vec_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.first() { + if segment.ident == "Vec" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn get_option_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.first() { + if segment.ident == "Option" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn is_primitive_integer(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.first() { + let ident = &segment.ident; + return ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "u8" + || ident == "i8"; + } + } + false +} + +pub fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.first() { + return segment.ident == "bool"; + } + } + false +} + +pub fn is_pubkey_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.first() { + return segment.ident == "Pubkey"; + } + } + false +} + +pub fn convert_to_zerocopy_type(ty: &Type) -> TokenStream { + match ty { + Type::Path(TypePath { path, .. }) => { + if let Some(segment) = path.segments.first() { + let ident = &segment.ident; + match ident.to_string().as_str() { + "u16" => quote! { light_zero_copy::little_endian::U16 }, + "u32" => quote! { light_zero_copy::little_endian::U32 }, + "u64" => quote! { light_zero_copy::little_endian::U64 }, + "bool" => quote! { u8 }, + _ => quote! { #ty }, + } + } else { + quote! { #ty } + } + } + _ => { + quote! { #ty } + } + } +} + +/// Checks if a struct has a derive(Copy) attribute +fn struct_has_copy_derive(attrs: &[Attribute]) -> bool { + // Check each attribute, printing debug info for test troubleshooting + for attr in attrs { + if attr.path().is_ident("derive") { + // More reliable approach to check for Copy in derive attributes + if let Ok(expr) = attr.parse_args::() { + // Check if any of the segments in the path is "Copy" + for segment in expr.path.segments.iter() { + if segment.ident == "Copy" { + return true; + } + } + } else { + // Parse derive attribute contents directly as a string and check for "Copy" + let content = attr.to_token_stream().to_string(); + if content.contains("Copy") { + return true; + } + } + + // Fallback to parse_nested_meta as before + let mut found_copy = false; + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("Copy") { + found_copy = true; + } + Ok(()) + }); + if found_copy { + return true; + } + } + } + false +} + +/// Determines whether a struct implements Copy by checking for the #[derive(Copy)] attribute. +/// Results are cached for performance. +/// +/// In Rust, a struct can only implement Copy if: +/// 1. 
It explicitly has a #[derive(Copy)] attribute, AND +/// 2. All of its fields implement Copy +/// +/// The Rust compiler will enforce the second condition at compile time, so we only need to check +/// for the derive attribute here. +pub fn struct_implements_copy(input: &DeriveInput) -> bool { + let struct_name = input.ident.to_string(); + + // Check the cache first + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&struct_name) { + return *implements_copy; + } + + // Check if the struct has a derive(Copy) attribute + let implements_copy = struct_has_copy_derive(&input.attrs); + + // Cache the result + COPY_IMPL_CACHE + .lock() + .unwrap() + .insert(struct_name, implements_copy); + + implements_copy +} + +/// Determines whether a type implements Copy +/// 1. check whether type is a primitive type that implements Copy +/// 2. check whether type is an array type (which is always Copy if the element type is Copy) +/// 3. check whether type is struct -> check in the COPY_IMPL_CACHE if we know whether it has a #[derive(Copy)] attribute +/// +/// For struct types, this relies on the cache populated by struct_implements_copy. If we don't have cached +/// information, it assumes the type does not implement Copy. This is a limitation of our approach, but it +/// works well in practice because process_input will call struct_implements_copy for all structs before +/// they might be referenced by other structs. +pub fn is_copy_type(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { path, .. }) => { + if let Some(segment) = path.segments.first() { + let ident = &segment.ident; + let ident_str = ident.to_string(); + + // Check if it's a primitive type that implements Copy + if ident == "u8" + || ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i8" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "bool" // bool is a Copy type + || ident == "char" + || ident == "Pubkey" + // Pubkey is hardcoded as copy type for now. + { + return true; + } + + // Check if we have cached information about this type + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&ident_str) { + return *implements_copy; + } + } + } + // Handle array types (which are always Copy if the element type is Copy) + Type::Array(array) => { + // Arrays are Copy if their element type is Copy + return is_copy_type(&array.elem); + } + // For struct types not in cache, we'd need the derive input to check attributes + _ => {} + } + false +} + +#[cfg(test)] +mod tests { + use syn::parse_quote; + + use super::*; + + // Helper function to check if a struct implements Copy + fn check_struct_implements_copy(input: syn::DeriveInput) -> bool { + struct_implements_copy(&input) + } + + #[test] + fn test_struct_implements_copy() { + // Ensure the cache is cleared and the lock is released immediately + COPY_IMPL_CACHE.lock().unwrap().clear(); + // Test case 1: Empty struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct EmptyStruct {} + }; + assert!( + check_struct_implements_copy(input), + "EmptyStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 2: Simple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! 
{ + #[derive(Copy, Clone)] + struct SimpleStruct { + a: u8, + b: u16, + } + }; + assert!( + check_struct_implements_copy(input), + "SimpleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 3: Struct with #[derive(Clone)] but not Copy + let input: syn::DeriveInput = parse_quote! { + #[derive(Clone)] + struct StructWithoutCopy { + a: u8, + b: u16, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 4: Struct with a non-Copy field but with derive(Copy) + // Note: In real Rust code, this would not compile, but for our test we only check attributes + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct StructWithVec { + a: u8, + b: Vec, + } + }; + assert!( + check_struct_implements_copy(input), + "StructWithVec has #[derive(Copy)] so our function returns true" + ); + + // Test case 5: Struct with all Copy fields but without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct StructWithCopyFields { + a: u8, + b: u16, + c: i32, + d: bool, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithCopyFields should not implement Copy without #[derive(Copy)]" + ); + + // Test case 6: Unit struct without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct UnitStructWithoutCopy; + }; + assert!( + !check_struct_implements_copy(input), + "UnitStructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 7: Unit struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct UnitStructWithCopy; + }; + assert!( + check_struct_implements_copy(input), + "UnitStructWithCopy should implement Copy with #[derive(Copy)]" + ); + + // Test case 8: Tuple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct TupleStruct(u32, bool, char); + }; + assert!( + check_struct_implements_copy(input), + "TupleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 9: Multiple derives including Copy + let input: syn::DeriveInput = parse_quote! 
{ + #[derive(Debug, PartialEq, Copy, Clone)] + struct MultipleDerivesStruct { + a: u8, + } + }; + assert!( + check_struct_implements_copy(input), + "MultipleDerivesStruct should implement Copy with #[derive(Copy)]" + ); + } +} diff --git a/program-libs/zero-copy-derive/src/z_struct.rs b/program-libs/zero-copy-derive/src/z_struct.rs new file mode 100644 index 0000000000..2f4783dbec --- /dev/null +++ b/program-libs/zero-copy-derive/src/z_struct.rs @@ -0,0 +1,624 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; +use syn::{parse_quote, parse_str, Field, Ident, Type}; + +use crate::utils; + +/// Enum representing the different field types for zero-copy struct +/// (Name, Type) +#[derive(Debug)] +pub enum FieldType<'a> { + VecU8(&'a Ident), + VecCopy(&'a Ident, &'a Type), + VecNonCopy(&'a Ident, &'a Type), + Array(&'a Ident, &'a Type), + Option(&'a Ident, &'a Type), + OptionU64(&'a Ident), + OptionU32(&'a Ident), + OptionU16(&'a Ident), + Pubkey(&'a Ident), + IntegerU64(&'a Ident), + IntegerU32(&'a Ident), + IntegerU16(&'a Ident), + IntegerU8(&'a Ident), + Bool(&'a Ident), + CopyU8Bool(&'a Ident), + Copy(&'a Ident, &'a Type), + NonCopy(&'a Ident, &'a Type), +} + +impl<'a> FieldType<'a> { + /// Get the name of the field + pub fn name(&self) -> &'a Ident { + match self { + FieldType::VecU8(name) => name, + FieldType::VecCopy(name, _) => name, + FieldType::VecNonCopy(name, _) => name, + FieldType::Array(name, _) => name, + FieldType::Option(name, _) => name, + FieldType::OptionU64(name) => name, + FieldType::OptionU32(name) => name, + FieldType::OptionU16(name) => name, + FieldType::Pubkey(name) => name, + FieldType::IntegerU64(name) => name, + FieldType::IntegerU32(name) => name, + FieldType::IntegerU16(name) => name, + FieldType::IntegerU8(name) => name, + FieldType::Bool(name) => name, + FieldType::CopyU8Bool(name) => name, + FieldType::Copy(name, _) => name, + FieldType::NonCopy(name, _) => name, + } + } +} + +/// Analyze struct fields and return vector of FieldType enums +pub fn analyze_struct_fields<'a>(struct_fields: &'a [&'a Field]) -> Vec> { + struct_fields + .iter() + .map(|field| { + if let Some(field_name) = &field.ident { + let field_type = &field.ty; + + if utils::is_vec_type(field_type) { + if let Some(inner_type) = utils::get_vec_inner_type(field_type) { + if inner_type.to_token_stream().to_string() == "u8" { + FieldType::VecU8(field_name) + } else if utils::is_copy_type(inner_type) { + FieldType::VecCopy(field_name, inner_type) + } else { + FieldType::VecNonCopy(field_name, field_type) + } + } else { + panic!("Could not determine inner type of Vec {:?}", field_type); + } + } else if let Type::Array(_) = field_type { + FieldType::Array(field_name, field_type) + } else if utils::is_option_type(field_type) { + // Check the inner type of the Option and convert to appropriate FieldType + if let Some(inner_type) = utils::get_option_inner_type(field_type) { + if utils::is_primitive_integer(inner_type) { + let field_ty_str = inner_type.to_token_stream().to_string(); + match field_ty_str.as_str() { + "u64" => FieldType::OptionU64(field_name), + "u32" => FieldType::OptionU32(field_name), + "u16" => FieldType::OptionU16(field_name), + _ => FieldType::Option(field_name, field_type), + } + } else { + FieldType::Option(field_name, field_type) + } + } else { + FieldType::Option(field_name, field_type) + } + } else if utils::is_pubkey_type(field_type) { + FieldType::Pubkey(field_name) + } else if utils::is_bool_type(field_type) { + 
FieldType::Bool(field_name) + } else if utils::is_primitive_integer(field_type) { + let field_ty_str = field_type.to_token_stream().to_string(); + match field_ty_str.as_str() { + "u64" => FieldType::IntegerU64(field_name), + "u32" => FieldType::IntegerU32(field_name), + "u16" => FieldType::IntegerU16(field_name), + "u8" => FieldType::IntegerU8(field_name), + _ => unimplemented!("Unsupported integer type: {}", field_ty_str), + } + } else if utils::is_copy_type(field_type) { + if field_type.to_token_stream().to_string() == "u8" + || field_type.to_token_stream().to_string() == "bool" + { + FieldType::CopyU8Bool(field_name) + } else { + FieldType::Copy(field_name, field_type) + } + } else { + FieldType::NonCopy(field_name, field_type) + } + } else { + panic!("Could not determine field name"); + } + }) + .collect() +} + +/// Generate struct fields with zerocopy types based on field type enum +fn generate_struct_fields_with_zerocopy_types<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], + hasher: &'a bool, +) -> impl Iterator + 'a { + let field_types = analyze_struct_fields(struct_fields); + field_types + .into_iter() + .zip(struct_fields.iter()) + .map(|(field_type, field)| { + let attributes = if *hasher { + field + .attrs + .iter() + .map(|attr| { + let path = attr; + quote! { #path } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let (mutability, import_path, import_slice, camel_case_suffix): ( + syn::Type, + syn::Ident, + syn::Ident, + String, + ) = if MUT { + ( + parse_str("&'a mut [u8]").unwrap(), + format_ident!("borsh_mut"), + format_ident!("slice_mut"), + String::from("Mut"), + ) + } else { + ( + parse_str("&'a [u8]").unwrap(), + format_ident!("borsh"), + format_ident!("slice"), + String::new(), + ) + }; + let trait_name: syn::Type = parse_str( + format!( + "light_zero_copy::{}::Deserialize{}", + import_path, camel_case_suffix + ) + .as_str(), + ) + .unwrap(); + let slice_name: syn::Type = parse_str( + format!( + "light_zero_copy::{}::ZeroCopySlice{}Borsh", + import_slice, camel_case_suffix + ) + .as_str(), + ) + .unwrap(); + let struct_inner_trait_name: syn::Type = parse_str( + format!( + "light_zero_copy::{}::ZeroCopyStructInner{1}::ZeroCopyInner{1}", + import_path, camel_case_suffix + ) + .as_str(), + ) + .unwrap(); + match field_type { + FieldType::VecU8(field_name) => { + quote! { + #(#attributes)* + pub #field_name: #mutability + } + } + FieldType::VecCopy(field_name, inner_type) => { + quote! { + #(#attributes)* + pub #field_name: #slice_name<'a, <#inner_type as #struct_inner_trait_name>> + } + } + FieldType::VecNonCopy(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::Array(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #field_type> + } + } + FieldType::Option(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + quote! 
{ + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::Pubkey(field_name) => { + quote! { + #(#attributes)* + pub #field_name: >::Output + } + } + FieldType::IntegerU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability, #field_ty_zerocopy> + } + } + FieldType::IntegerU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability, #field_ty_zerocopy> + } + } + FieldType::IntegerU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability, #field_ty_zerocopy> + } + } + FieldType::IntegerU8(field_name) => { + if MUT { + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability, u8> + } + } else { + quote! { + #(#attributes)* + pub #field_name: >::Output + } + } + } + FieldType::Bool(field_name) => { + if MUT { + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability, u8> + } + } else { + quote! { + #(#attributes)* + pub #field_name: >::Output + } + } + } + FieldType::CopyU8Bool(field_name) => { + quote! { + #(#attributes)* + pub #field_name: >::Output + } + } + FieldType::Copy(field_name, field_type) => { + let zerocopy_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #zerocopy_type> + } + } + FieldType::NonCopy(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + } + }) +} + +/// Generate accessor methods for boolean fields in struct_fields. +/// We need accessors because booleans are stored as u8. +fn generate_bool_accessor_methods<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + struct_fields.iter().filter_map(|field| { + let field_name = &field.ident; + let field_type = &field.ty; + + if utils::is_bool_type(field_type) { + let comparison = if MUT { + quote! { *self.#field_name > 0 } + } else { + quote! { self.#field_name > 0 } + }; + + Some(quote! { + pub fn #field_name(&self) -> bool { + #comparison + } + }) + } else { + None + } + }) +} + +/// Generates the ZStruct definition as a TokenStream +pub fn generate_z_struct( + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_fields: &[&Field], + hasher: bool, +) -> TokenStream { + let mut z_struct_name = z_struct_name.clone(); + let mut z_struct_meta_name = z_struct_meta_name.clone(); + let mutability: syn::Type = if MUT { + z_struct_name = format_ident!("{}Mut", z_struct_name); + z_struct_meta_name = format_ident!("{}Mut", z_struct_meta_name); + parse_str("&'a mut [u8]").unwrap() + } else { + parse_str("&'a [u8] ").unwrap() + }; + + let derive_clone = if MUT { + quote! {} + } else { + quote! {, Clone } + }; + let struct_fields_with_zerocopy_types = + generate_struct_fields_with_zerocopy_types::(struct_fields, &hasher); + + let derive_hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + let hasher_flatten = if hasher { + quote! { + #[flatten] + } + } else { + quote! {} + }; + + let partial_eq_derive = if MUT { quote!() } else { quote!(, PartialEq) }; + + let mut z_struct = if meta_fields.is_empty() { + quote! 
{ + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #(#struct_fields_with_zerocopy_types,)* + } + } + } else { + let mut tokens = quote! { + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #hasher_flatten + __meta: light_zero_copy::Ref<#mutability, #z_struct_meta_name>, + #(#struct_fields_with_zerocopy_types,)* + } + impl<'a> core::ops::Deref for #z_struct_name<'a> { + type Target = light_zero_copy::Ref<#mutability , #z_struct_meta_name>; + + fn deref(&self) -> &Self::Target { + &self.__meta + } + } + }; + + if MUT { + tokens.append_all(quote! { + impl<'a> core::ops::DerefMut for #z_struct_name<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.__meta + } + } + }); + } + tokens + }; + + if !meta_fields.is_empty() { + let meta_bool_accessor_methods = generate_bool_accessor_methods::(meta_fields); + z_struct.append_all(quote! { + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#meta_bool_accessor_methods)* + } + }) + }; + + if !struct_fields.is_empty() { + let bool_accessor_methods = generate_bool_accessor_methods::(struct_fields); + z_struct.append_all(quote! { + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#bool_accessor_methods)* + } + + }); + } + z_struct +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; + use syn::parse_quote; + + use super::*; + + /// Generate a safe field name for testing + fn random_ident(rng: &mut StdRng) -> String { + // Use predetermined safe field names + const FIELD_NAMES: &[&str] = &[ + "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", + "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", + "status", + ]; + + FIELD_NAMES.choose(rng).unwrap().to_string() + } + + /// Generate a random Rust type + fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { + // Define our available types + let types = [0, 1, 2, 3, 4, 5, 6, 7]; + + // Randomly select a type index + let selected = *types.choose(rng).unwrap(); + + // Return the corresponding type + match selected { + 0 => parse_quote!(u8), + 1 => parse_quote!(u16), + 2 => parse_quote!(u32), + 3 => parse_quote!(u64), + 4 => parse_quote!(bool), + 5 => parse_quote!(Vec), + 6 => parse_quote!(Vec), + 7 => parse_quote!(Vec), + _ => unreachable!(), + } + } + + /// Generate a random field + fn random_field(rng: &mut StdRng) -> Field { + let name = random_ident(rng); + let ty = random_type(rng, 0); + + // Use a safer approach to create the field + let name_ident = format_ident!("{}", name); + parse_quote!(pub #name_ident: #ty) + } + + /// Generate a list of random fields + fn random_fields(rng: &mut StdRng, count: usize) -> Vec { + (0..count).map(|_| random_field(rng)).collect() + } + + #[test] + fn test_fuzz_generate_z_struct() { + // Set up RNG with a seed for reproducibility + let seed = thread_rng().gen(); + println!("seed {}", seed); + let mut rng = StdRng::seed_from_u64(seed); + + // Now that the test is working, run with 10,000 iterations + let num_iters = 10000; + + for i in 0..num_iters { + // Generate a random struct name + let struct_name = format_ident!("{}", random_ident(&mut rng)); + let z_struct_name = format_ident!("Z{}", struct_name); + let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + + // Generate random number of fields (1-10) + let field_count 
= rng.gen_range(1..11); + let fields = random_fields(&mut rng, field_count); + + // Create a named fields collection that lives longer than the process_fields call + let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); + let fields_named = syn::FieldsNamed { + brace_token: syn::token::Brace::default(), + named: syn_fields, + }; + + // Split into meta fields and struct fields + let (meta_fields, struct_fields) = crate::utils::process_fields(&fields_named); + + // Call the function we're testing (MUT = false, matching the derive list asserted below) + let result = generate_z_struct::<false>( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + false, + ); + + // Get the generated code as a string for validation + let result_str = result.to_string(); + + // Print the generated code for debugging + println!("Generated code format:\n{}", result_str); + + // Verify the result contains expected struct elements + // Basic validation - must be non-empty + assert!( + !result_str.is_empty(), + "Failed to generate TokenStream for iteration {}", + i + ); + + // Validate that the generated code contains the expected struct definition + let struct_pattern = format!("struct {} < 'a >", z_struct_name); + assert!( + result_str.contains(&struct_pattern), + "Generated code missing struct definition for iteration {}. Expected: {}", + i, + struct_pattern + ); + + if meta_fields.is_empty() { + // Validate that no meta field is generated + assert!( + !result_str.contains("meta :"), + "Generated code had meta field for iteration {}", + i + ); + // Validate that no Deref implementation is generated + assert!( + !result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code unexpectedly contains a Deref implementation for iteration {}", + i + ); + } else { + // Validate the meta field is present + assert!( + result_str.contains("meta :"), + "Generated code missing meta field for iteration {}", + i + ); + // Validate Deref implementation + assert!( + result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code missing Deref implementation for iteration {}", + i + ); + // Validate Target type + assert!( + result_str.contains("type Target"), + "Generated code missing Target type for iteration {}", + i + ); + // Check that the deref method is implemented + assert!( + result_str.contains("fn deref (& self)"), + "Generated code missing deref method for iteration {}", + i + ); + + // Check for light_zero_copy::Ref reference + assert!( + result_str.contains("light_zero_copy :: Ref"), + "Generated code missing light_zero_copy::Ref for iteration {}", + i + ); + } + + // Make sure derive attributes are present + assert!( + result_str.contains("# [derive (Debug , PartialEq , Clone)]"), + "Generated code missing derive attributes for iteration {}", + i + ); + } + } +}
diff --git a/program-libs/zero-copy-derive/src/zero_copy_struct_inner.rs b/program-libs/zero-copy-derive/src/zero_copy_struct_inner.rs new file mode 100644 index 0000000000..e260955682 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy_struct_inner.rs @@ -0,0 +1,25 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::Ident; + +/// Generates the ZeroCopyStructInner implementation as a TokenStream +pub fn generate_zero_copy_struct_inner<const MUT: bool>( + name: &Ident, + z_struct_name: &Ident, +) -> TokenStream { + if MUT { + quote! { + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh_mut::ZeroCopyStructInnerMut for #name { + type ZeroCopyInnerMut = #z_struct_name<'static>; + } + } + } else { + quote!
{ + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh::ZeroCopyStructInner for #name { + type ZeroCopyInner = #z_struct_name<'static>; + } + } + } +} diff --git a/program-libs/zero-copy-derive/tests/config_test.rs b/program-libs/zero-copy-derive/tests/config_test.rs new file mode 100644 index 0000000000..1a98783382 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/config_test.rs @@ -0,0 +1,443 @@ +#![cfg(feature = "mut")] + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::borsh_mut::DeserializeMut; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyConfig, ZeroCopyEq, ZeroCopyMut}; + +/// Simple struct with just a Vec field to test basic config functionality +#[repr(C)] +#[derive( + Debug, + PartialEq, + BorshSerialize, + BorshDeserialize, + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + ZeroCopyConfig, +)] +pub struct SimpleVecStruct { + pub a: u8, + pub vec: Vec, + pub b: u16, +} + +#[test] +fn test_simple_config_generation() { + // This test verifies that the ZeroCopyConfig derive macro generates the expected config struct + // and ZeroCopyNew implementation + + // The config should have been generated as SimpleVecStructConfig + let config = SimpleVecStructConfig { + vec: 10, // Vec should have u32 config (length) + }; + + // Test that we can create a configuration + assert_eq!(config.vec, 10); + + println!("Config generation test passed!"); +} + +#[test] +fn test_simple_vec_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test the new_zero_copy method generated by ZeroCopyConfig + let config = SimpleVecStructConfig { + vec: 5, // Vec with capacity 5 + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleVecStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // Use the generated new_zero_copy method + let result = SimpleVecStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta fields + simple_struct.__meta.a = 42; + + // Test that we can write to the vec slice + simple_struct.vec[0] = 10; + simple_struct.vec[1] = 20; + simple_struct.vec[2] = 30; + + // Test that we can set the b field + *simple_struct.b = 12345u16.into(); + + // Verify the values we set + assert_eq!(simple_struct.__meta.a, 42); + assert_eq!(simple_struct.vec[0], 10); + assert_eq!(simple_struct.vec[1], 20); + assert_eq!(simple_struct.vec[2], 30); + assert_eq!(u16::from(*simple_struct.b), 12345); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = SimpleVecStruct::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[0], 10); + assert_eq!(deserialized.vec[1], 20); + assert_eq!(deserialized.vec[2], 30); + assert_eq!(u16::from(*deserialized.b), 12345); + + println!("new_zero_copy initialization test passed!"); +} + +/// Struct with Option field to test Option config +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyConfig, +)] +pub struct SimpleOptionStruct { + pub a: u8, + pub option: Option, +} + +#[test] +fn 
test_simple_option_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option enabled + let config = SimpleOptionStructConfig { + option: true, // Option should have bool config (enabled/disabled) + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta field + simple_struct.__meta.a = 123; + + // Test that option is Some and we can set its value + assert!(simple_struct.option.is_some()); + if let Some(ref mut opt_val) = simple_struct.option { + **opt_val = 98765u64.into(); + } + + // Verify the values + assert_eq!(simple_struct.__meta.a, 123); + if let Some(ref opt_val) = simple_struct.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 123); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + println!("Option new_zero_copy test passed!"); +} + +#[test] +fn test_simple_option_struct_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option disabled + let config = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + simple_struct.__meta.a = 200; + + // Test that option is None + assert!(simple_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 200); + assert!(deserialized.option.is_none()); + + println!("Option disabled new_zero_copy test passed!"); +} + +/// Test both Vec and Option in one struct +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyConfig, +)] +pub struct MixedStruct { + pub a: u8, + pub vec: Vec, + pub option: Option, + pub b: u16, +} + +#[test] +fn test_mixed_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with both vec and option enabled + let config = MixedStructConfig { + vec: 8, // Vec -> u32 length + option: true, // Option -> bool enabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + 
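// `a` comes before the first Vec field, so the derive places it in the generated `__meta` struct; hence the `__meta.a` access below. +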
mixed_struct.__meta.a = 77; + + // Set vec data + mixed_struct.vec[0] = 11; + mixed_struct.vec[3] = 44; + mixed_struct.vec[7] = 88; + + // Set option value + assert!(mixed_struct.option.is_some()); + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 123456789u64.into(); + } + + // Set b field + *mixed_struct.b = 54321u16.into(); + + // Verify all values + assert_eq!(mixed_struct.__meta.a, 77); + assert_eq!(mixed_struct.vec[0], 11); + assert_eq!(mixed_struct.vec[3], 44); + assert_eq!(mixed_struct.vec[7], 88); + if let Some(ref opt_val) = mixed_struct.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*mixed_struct.b), 54321); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 77); + assert_eq!(deserialized.vec[0], 11); + assert_eq!(deserialized.vec[3], 44); + assert_eq!(deserialized.vec[7], 88); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*deserialized.b), 54321); + + println!("Mixed struct new_zero_copy test passed!"); +} + +#[test] +fn test_mixed_struct_option_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with vec enabled but option disabled + let config = MixedStructConfig { + vec: 3, // Vec -> u32 length + option: false, // Option -> bool disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mixed_struct.__meta.a = 99; + mixed_struct.vec[0] = 255; + mixed_struct.vec[2] = 128; + *mixed_struct.b = 9999u16.into(); + + // Verify option is None + assert!(mixed_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 99); + assert_eq!(deserialized.vec[0], 255); + assert_eq!(deserialized.vec[2], 128); + assert!(deserialized.option.is_none()); + assert_eq!(u16::from(*deserialized.b), 9999); + + println!("Mixed struct option disabled test passed!"); +} + +#[test] +fn test_byte_len_calculation() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test SimpleVecStruct byte_len calculation + let config = SimpleVecStructConfig { + vec: 10, // Vec with capacity 10 + }; + + let expected_size = 1 + // a: u8 (meta field) + 4 + 10 + // vec: 4 bytes length + 10 bytes data + 2; // b: u16 + + let calculated_size = SimpleVecStruct::byte_len(&config); + assert_eq!(calculated_size, expected_size); + println!( + "SimpleVecStruct byte_len: calculated={}, expected={}", + calculated_size, expected_size + ); + + // Test SimpleOptionStruct byte_len calculation + let config_some = SimpleOptionStructConfig { + option: true, // Option enabled + }; + + let expected_size_some = 1 + // a: u8 (meta field) + 1 + 8; // option: 1 byte discriminant + 8 bytes u64 + + let calculated_size_some = SimpleOptionStruct::byte_len(&config_some); + assert_eq!(calculated_size_some, expected_size_some); + println!( + "SimpleOptionStruct (Some) byte_len: calculated={}, expected={}", + calculated_size_some, expected_size_some 
+ ); + + let config_none = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + let expected_size_none = 1 + // a: u8 (meta field) + 1; // option: 1 byte discriminant for None + + let calculated_size_none = SimpleOptionStruct::byte_len(&config_none); + assert_eq!(calculated_size_none, expected_size_none); + println!( + "SimpleOptionStruct (None) byte_len: calculated={}, expected={}", + calculated_size_none, expected_size_none + ); + + // Test MixedStruct byte_len calculation + let config_mixed = MixedStructConfig { + vec: 5, // Vec with capacity 5 + option: true, // Option enabled + }; + + let expected_size_mixed = 1 + // a: u8 (meta field) + 4 + 5 + // vec: 4 bytes length + 5 bytes data + 1 + 8 + // option: 1 byte discriminant + 8 bytes u64 + 2; // b: u16 + + let calculated_size_mixed = MixedStruct::byte_len(&config_mixed); + assert_eq!(calculated_size_mixed, expected_size_mixed); + println!( + "MixedStruct byte_len: calculated={}, expected={}", + calculated_size_mixed, expected_size_mixed + ); + + println!("All byte_len calculation tests passed!"); +} + +#[test] +fn test_dynamic_buffer_allocation_with_byte_len() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Example of how to use byte_len for dynamic buffer allocation + let config = MixedStructConfig { + vec: 12, // Vec with capacity 12 + option: true, // Option enabled + }; + + // Calculate the exact buffer size needed + let required_size = MixedStruct::byte_len(&config); + println!("Required buffer size: {} bytes", required_size); + + // Allocate exactly the right amount of memory + let mut bytes = vec![0u8; required_size]; + + // Initialize the structure + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the right amount of bytes (no remaining bytes) + assert_eq!( + remaining.len(), + 0, + "Should have used exactly the calculated number of bytes" + ); + + // Set some values to verify it works + mixed_struct.__meta.a = 42; + mixed_struct.vec[5] = 123; + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 9999u64.into(); + } + *mixed_struct.b = 7777u16.into(); + + // Verify round-trip works + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[5], 123); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 9999); + } + assert_eq!(u16::from(*deserialized.b), 7777); + + println!("Dynamic buffer allocation test passed!"); +} diff --git a/program-libs/zero-copy-derive/tests/from_test.rs b/program-libs/zero-copy-derive/tests/from_test.rs new file mode 100644 index 0000000000..20391c36dd --- /dev/null +++ b/program-libs/zero-copy-derive/tests/from_test.rs @@ -0,0 +1,77 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, ZeroCopyEq}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; + +// Simple struct with a primitive field and a vector +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct SimpleStruct { + pub a: u8, + pub b: Vec, +} + +// Basic struct with all basic numeric types +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct NumericStruct { + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, +} + +// use 
light_zero_copy::borsh_mut::DeserializeMut; // Not needed for non-mut derivations + +#[test] +fn test_simple_from_implementation() { + // Create an instance of our struct + let original = SimpleStruct { + a: 42, + b: vec![1, 2, 3, 4, 5], + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = SimpleStruct::zero_copy_at(&bytes).unwrap(); + let converted: SimpleStruct = zero_copy.into(); + assert_eq!(converted.a, 42); + assert_eq!(converted.b, vec![1, 2, 3, 4, 5]); + assert_eq!(converted, original); +} + +#[test] +fn test_numeric_from_implementation() { + // Create a struct with different primitive types + let original = NumericStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = NumericStruct::zero_copy_at(&bytes).unwrap(); + let converted: NumericStruct = zero_copy.clone().into(); + + // Verify all fields + assert_eq!(converted.a, 1); + assert_eq!(converted.b, 2); + assert_eq!(converted.c, 3); + assert_eq!(converted.d, 4); + assert!(converted.e); + + // Verify complete struct + assert_eq!(converted, original); +} diff --git a/program-libs/zero-copy-derive/tests/instruction_data.rs b/program-libs/zero-copy-derive/tests/instruction_data.rs new file mode 100644 index 0000000000..a9c004b651 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/instruction_data.rs @@ -0,0 +1,1423 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut, errors::ZeroCopyError}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyConfig, ZeroCopyEq, ZeroCopyMut}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned}; + +#[derive( + Debug, + Copy, + PartialEq, + Clone, + Immutable, + FromBytes, + IntoBytes, + KnownLayout, + BorshDeserialize, + BorshSerialize, + Default, + Unaligned, +)] +#[repr(C)] +pub struct Pubkey(pub(crate) [u8; 32]); + +impl Pubkey { + pub fn new_unique() -> Self { + use rand::Rng; + let mut rng = rand::thread_rng(); + let bytes = rng.gen::<[u8; 32]>(); + Pubkey(bytes) + } + + pub fn to_bytes(self) -> [u8; 32] { + self.0 + } +} + +impl<'a> Deserialize<'a> for Pubkey { + type Output = Ref<&'a [u8], Pubkey>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + Ok(Ref::<&'a [u8], Pubkey>::from_prefix(bytes)?) + } +} + +impl<'a> DeserializeMut<'a> for Pubkey { + type Output = Ref<&'a mut [u8], Pubkey>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&'a mut [u8], Pubkey>::from_prefix(bytes)?) 
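+ // Ref::from_prefix splits a 32-byte Pubkey off the front of `bytes` and returns it together with the untouched remainder of the buffer.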
+ } +} + +// We should not implement DeserializeMut for primitive types directly +// The implementation should be in the zero-copy crate + +impl PartialEq<>::Output> for Pubkey { + fn eq(&self, other: &>::Output) -> bool { + self.0 == other.0 + } +} + +impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for Pubkey { + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + 32 // Pubkey is always 32 bytes + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Self::zero_copy_at_mut(bytes) + } +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + ZeroCopyConfig, +)] +pub struct InstructionDataInvoke { + pub proof: Option, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub new_address_params: Vec, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for InstructionDataInvoke { +// type Config = InstructionDataInvokeConfig; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { +// use zerocopy::Ref; +// +// // First handle the meta struct (empty for InstructionDataInvoke) +// let (__meta, bytes) = Ref::<&mut [u8], ZInstructionDataInvokeMetaMut>::from_prefix(bytes)?; +// +// // Initialize each field using the corresponding config, following DeserializeMut order +// let (proof, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.proof_config.is_some(), CompressedProofConfig {}) +// )?; +// +// let input_configs: Vec = config.input_accounts_configs +// .into_iter() +// .map(|compressed_account_config| PackedCompressedAccountWithMerkleContextConfig { +// compressed_account: CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// merkle_context: PackedMerkleContextConfig {}, +// }) +// .collect(); +// let (input_compressed_accounts_with_merkle_context, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// input_configs +// )?; +// +// let output_configs: Vec = config.output_accounts_configs +// .into_iter() +// .map(|compressed_account_config| OutputCompressedAccountWithPackedContextConfig { +// compressed_account: CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// }) +// .collect(); +// let (output_compressed_accounts, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// output_configs +// )?; +// +// let (relay_fee, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.relay_fee_config.is_some(), ()) +// )?; +// +// let new_address_configs: Vec = config.new_address_configs +// .into_iter() +// .map(|_| NewAddressParamsPackedConfig {}) +// .collect(); +// let (new_address_params, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// new_address_configs +// )?; +// +// let (compress_or_decompress_lamports, 
bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.decompress_lamports_config.is_some(), ()) +// )?; +// +// let (is_compress, bytes) = ::new_zero_copy( +// bytes, +// () +// )?; +// +// Ok(( +// ZInstructionDataInvokeMut { +// proof, +// input_compressed_accounts_with_merkle_context, +// output_compressed_accounts, +// relay_fee, +// new_address_params, +// compress_or_decompress_lamports, +// is_compress, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct OutputCompressedAccountWithContext { + pub compressed_account: CompressedAccount, + pub merkle_tree: Pubkey, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + ZeroCopyConfig, +)] +pub struct OutputCompressedAccountWithPackedContext { + pub compressed_account: CompressedAccount, + pub merkle_tree_index: u8, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for OutputCompressedAccountWithPackedContext { +// type Config = CompressedAccountZeroCopyConfig; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZOutputCompressedAccountWithPackedContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_tree_index, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZOutputCompressedAccountWithPackedContextMut { +// compressed_account, +// merkle_tree_index, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, + ZeroCopyConfig, +)] +pub struct NewAddressParamsPacked { + pub seed: [u8; 32], + pub address_queue_account_index: u8, + pub address_merkle_tree_account_index: u8, + pub address_merkle_tree_root_index: u16, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for NewAddressParamsPacked { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZNewAddressParamsPackedMetaMut>::from_prefix(bytes)?; +// Ok((ZNewAddressParamsPackedMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct NewAddressParams { + pub seed: [u8; 32], + pub address_queue_pubkey: Pubkey, + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, +)] +pub struct PackedReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_root_index: u16, + pub address_merkle_tree_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + 
PartialEq, + Clone, + Copy, + ZeroCopyConfig, +)] +pub struct CompressedProof { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +impl Default for CompressedProof { + fn default() -> Self { + Self { + a: [0; 32], + b: [0; 64], + c: [0; 32], + } + } +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedProof { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedProofMetaMut>::from_prefix(bytes)?; +// Ok((ZCompressedProofMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Default, + ZeroCopyConfig, +)] +pub struct CompressedCpiContext { + /// Is set by the program that is invoking the CPI to signal that is should + /// set the cpi context. + pub set_context: bool, + /// Is set to clear the cpi context since someone could have set it before + /// with unrelated data. + pub first_set_context: bool, + /// Index of cpi context account in remaining accounts. + pub cpi_context_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + ZeroCopyConfig, +)] +pub struct PackedCompressedAccountWithMerkleContext { + pub compressed_account: CompressedAccount, + pub merkle_context: PackedMerkleContext, + /// Index of root used in inclusion validity proof. + pub root_index: u16, + /// Placeholder to mark accounts read-only unimplemented set to false. + pub read_only: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedCompressedAccountWithMerkleContext { +// type Config = CompressedAccountZeroCopyConfig; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedCompressedAccountWithMerkleContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_context, bytes) = ::new_zero_copy(bytes, ())?; +// let (root_index, bytes) = ::new_zero_copy(bytes, ())?; +// let (read_only, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZPackedCompressedAccountWithMerkleContextMut { +// compressed_account, +// merkle_context, +// root_index, +// read_only, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, + ZeroCopyConfig, +)] +pub struct MerkleContext { + pub merkle_tree_pubkey: Pubkey, + pub nullifier_queue_pubkey: Pubkey, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for MerkleContext { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZMerkleContextMetaMut>::from_prefix(bytes)?; +// +// Ok(( +// ZMerkleContextMut { +// __meta, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountWithMerkleContext { + pub 
compressed_account: CompressedAccount, + pub merkle_context: MerkleContext, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: MerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct PackedReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: PackedMerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, + ZeroCopyConfig, +)] +pub struct PackedMerkleContext { + pub merkle_tree_pubkey_index: u8, + pub nullifier_queue_pubkey_index: u8, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedMerkleContext { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedMerkleContextMetaMut>::from_prefix(bytes)?; +// Ok((ZPackedMerkleContextMut { __meta }, bytes)) +// } +// } + +#[derive(Debug, PartialEq, Default, Clone, Copy)] +pub struct CompressedAccountZeroCopyConfig { + pub address_enabled: bool, + pub data_enabled: bool, + pub data_capacity: u32, +} + +// Manual InstructionDataInvokeConfig removed - now using generated config from ZeroCopyConfig derive + +#[derive( + ZeroCopy, + ZeroCopyMut, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + ZeroCopyConfig, +)] +pub struct CompressedAccount { + pub owner: [u8; 32], + pub lamports: u64, + pub address: Option<[u8; 32]>, + pub data: Option, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccount { +// type Config = CompressedAccountZeroCopyConfig; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountMetaMut>::from_prefix(bytes)?; +// +// // Use generic Option implementation for address field +// let (address, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.address_enabled, ()) +// )?; +// +// // Use generic Option implementation for data field +// let (data, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.data_enabled, CompressedAccountDataConfig { data: config.data_capacity }) +// )?; +// +// Ok(( +// ZCompressedAccountMut { +// __meta, +// address, +// data, +// }, +// bytes, +// )) +// } +// } + +impl<'a> From> for CompressedAccount { + fn from(value: ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.map(|x| *x), + data: value.data.as_ref().map(|x| x.into()), + } + } +} + +impl<'a> From<&ZCompressedAccount<'a>> for CompressedAccount { + fn from(value: &ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.as_ref().map(|x| **x), + data: value.data.as_ref().map(|x| x.into()), + } + } +} + +impl PartialEq for ZCompressedAccount<'_> { + fn eq(&self, 
other: &CompressedAccount) -> bool { + if self.address.is_some() + && other.address.is_some() + && *self.address.unwrap() != other.address.unwrap() + { + return false; + } + if self.address.is_some() || other.address.is_some() { + return false; + } + if self.data.is_some() + && other.data.is_some() + && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() + { + return false; + } + if self.data.is_some() || other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == other.lamports + } +} + +// Commented out because mutable derivation is disabled +// impl PartialEq for ZCompressedAccountMut<'_> { +// fn eq(&self, other: &CompressedAccount) -> bool { +// if self.address.is_some() +// && other.address.is_some() +// && **self.address.as_ref().unwrap() != *other.address.as_ref().unwrap() +// { +// return false; +// } +// if self.address.is_some() || other.address.is_some() { +// return false; +// } +// if self.data.is_some() +// && other.data.is_some() +// && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() +// { +// return false; +// } +// if self.data.is_some() || other.data.is_some() { +// return false; +// } + +// self.owner == other.owner && self.lamports == other.lamports +// } +// } +impl PartialEq> for CompressedAccount { + fn eq(&self, other: &ZCompressedAccount) -> bool { + if self.address.is_some() + && other.address.is_some() + && self.address.unwrap() != *other.address.unwrap() + { + return false; + } + if self.address.is_some() || other.address.is_some() { + return false; + } + if self.data.is_some() + && other.data.is_some() + && other.data.as_ref().unwrap() != self.data.as_ref().unwrap() + { + return false; + } + if self.data.is_some() || other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == u64::from(other.lamports) + } +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + ZeroCopyConfig, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +// COMMENTED OUT: Now using ZeroCopyConfig derive macro instead +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccountData { +// type Config = u32; // data_capacity +// type Output = >::Output; + +// fn new_zero_copy( +// bytes: &'a mut [u8], +// data_capacity: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountDataMetaMut>::from_prefix(bytes)?; +// // For u8 slices we just use &mut [u8] so we init the len and the split mut separately. 
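+// Rough layout this manual impl lays down (assuming ZeroCopySliceMutBorsh writes a 4-byte little-endian length prefix):
+// [discriminator: 8][data length: 4][data: data_capacity][data_hash: 32]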
+// { +// light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::::new_at( +// data_capacity.into(), +// bytes, +// )?; +// } +// // Split off len for +// let (_, bytes) = bytes.split_at_mut(4); +// let (data, bytes) = bytes.split_at_mut(data_capacity as usize); +// let (data_hash, bytes) = Ref::<&mut [u8], [u8; 32]>::from_prefix(bytes)?; +// Ok(( +// ZCompressedAccountDataMut { +// __meta, +// data, +// data_hash, +// }, +// bytes, +// )) +// } +// } + +#[test] +fn test_compressed_account_data_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountDataConfig { data: 10 }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccountData::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccountData::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set discriminator + mut_account.__meta.discriminator = [1, 2, 3, 4, 5, 6, 7, 8]; + + // Test that we can write to data + mut_account.data[0] = 42; + mut_account.data[1] = 43; + + // Test that we can set data_hash + mut_account.data_hash[0] = 99; + mut_account.data_hash[1] = 100; + + assert_eq!(mut_account.__meta.discriminator, [1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(mut_account.data[0], 42); + assert_eq!(mut_account.data[1], 43); + assert_eq!(mut_account.data_hash[0], 99); + assert_eq!(mut_account.data_hash[1], 100); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = CompressedAccountData::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized_account, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!( + deserialized_account.__meta.discriminator, + [1, 2, 3, 4, 5, 6, 7, 8] + ); + assert_eq!(deserialized_account.data.len(), 10); + assert_eq!(deserialized_account.data[0], 42); + assert_eq!(deserialized_account.data[1], 43); + assert_eq!(deserialized_account.data_hash[0], 99); + assert_eq!(deserialized_account.data_hash[1], 100); +} + +#[test] +fn test_compressed_account_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountConfig { + address: (true, ()), + data: (true, CompressedAccountDataConfig { data: 10 }), + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccount::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccount::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mut_account.__meta.owner = [1u8; 32]; + mut_account.__meta.lamports = 12345u64.into(); + mut_account.address.as_mut().unwrap()[0] = 42; + mut_account.data.as_mut().unwrap().data[0] = 99; + + // Test deserialize + let (deserialized, _) = CompressedAccount::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.owner, [1u8; 32]); + assert_eq!(u64::from(deserialized.__meta.lamports), 12345u64); + assert_eq!(deserialized.address.as_ref().unwrap()[0], 42); + assert_eq!(deserialized.data.as_ref().unwrap().data[0], 99); +} + +#[test] +fn 
test_instruction_data_invoke_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + // Create different configs to test various combinations + let compressed_account_config1 = CompressedAccountZeroCopyConfig { + address_enabled: true, + data_enabled: true, + data_capacity: 10, + }; + + let compressed_account_config2 = CompressedAccountZeroCopyConfig { + address_enabled: false, + data_enabled: true, + data_capacity: 5, + }; + + let compressed_account_config3 = CompressedAccountZeroCopyConfig { + address_enabled: true, + data_enabled: false, + data_capacity: 0, + }; + + let compressed_account_config4 = CompressedAccountZeroCopyConfig { + address_enabled: false, + data_enabled: false, + data_capacity: 0, + }; + + let config = InstructionDataInvokeConfig { + proof: (true, CompressedProofConfig {}), // Enable proof + input_compressed_accounts_with_merkle_context: vec![ + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config1.address_enabled, ()), + data: ( + compressed_account_config1.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config1.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config2.address_enabled, ()), + data: ( + compressed_account_config2.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config2.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + ], + output_compressed_accounts: vec![ + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config3.address_enabled, ()), + data: ( + compressed_account_config3.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config3.data_capacity, + }, + ), + }, + }, + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config4.address_enabled, ()), + data: ( + compressed_account_config4.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config4.data_capacity, + }, + ), + }, + }, + ], + relay_fee: true, // Enable relay fee + new_address_params: vec![ + NewAddressParamsPackedConfig {}, + NewAddressParamsPackedConfig {}, + ], // Length 2 + compress_or_decompress_lamports: true, // Enable decompress lamports + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + if let Err(ref e) = result { + eprintln!("Error: {:?}", e); + } + assert!(result.is_ok()); + let (_instruction_data, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test deserialization round-trip first + let (mut deserialized, _) = InstructionDataInvoke::zero_copy_at_mut(&mut bytes).unwrap(); + + // Now set values and test again + *deserialized.is_compress = 1; + + // Set proof values + if let Some(proof) = &mut deserialized.proof { + proof.a[0] = 42; + proof.b[0] = 43; + proof.c[0] = 44; + } + + // Set relay fee value + if let Some(relay_fee) = &mut deserialized.relay_fee { + **relay_fee = 12345u64.into(); + } + + // Set decompress lamports value + 
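// Option<u64> fields deserialize to an Option over a zero-copy Ref (see the OptionU64 handling in z_struct.rs), so the writes below go through a double deref. +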
if let Some(decompress_lamports) = &mut deserialized.compress_or_decompress_lamports { + **decompress_lamports = 67890u64.into(); + } + + // Set first input account values + let first_input = &mut deserialized.input_compressed_accounts_with_merkle_context[0]; + first_input.compressed_account.__meta.owner[0] = 11; + first_input.compressed_account.__meta.lamports = 1000u64.into(); + if let Some(address) = &mut first_input.compressed_account.address { + address[0] = 22; + } + if let Some(data) = &mut first_input.compressed_account.data { + data.__meta.discriminator[0] = 33; + data.data[0] = 99; + data.data_hash[0] = 55; + } + + // Set first output account values + let first_output = &mut deserialized.output_compressed_accounts[0]; + first_output.compressed_account.__meta.owner[0] = 77; + first_output.compressed_account.__meta.lamports = 2000u64.into(); + if let Some(address) = &mut first_output.compressed_account.address { + address[0] = 88; + } + + // Verify basic structure with vectors of length 2 + assert_eq!( + deserialized + .input_compressed_accounts_with_merkle_context + .len(), + 2 + ); // Length 2 + assert_eq!(deserialized.output_compressed_accounts.len(), 2); // Length 2 + assert_eq!(deserialized.new_address_params.len(), 2); // Length 2 + assert!(deserialized.proof.is_some()); // Enabled + assert!(deserialized.relay_fee.is_some()); // Enabled + assert!(deserialized.compress_or_decompress_lamports.is_some()); // Enabled + assert_eq!(*deserialized.is_compress, 1); + + // Test data access and modification + if let Some(proof) = &deserialized.proof { + // Verify we can access proof fields and our written values + assert_eq!(proof.a[0], 42); + assert_eq!(proof.b[0], 43); + assert_eq!(proof.c[0], 44); + } + + // Verify option integer values + if let Some(relay_fee) = &deserialized.relay_fee { + assert_eq!(u64::from(**relay_fee), 12345); + } + + if let Some(decompress_lamports) = &deserialized.compress_or_decompress_lamports { + assert_eq!(u64::from(**decompress_lamports), 67890); + } + + // Test accessing first input account (config1: address=true, data=true, capacity=10) + let first_input = &deserialized.input_compressed_accounts_with_merkle_context[0]; + assert_eq!(first_input.compressed_account.__meta.owner[0], 11); // Our written value + assert_eq!( + u64::from(first_input.compressed_account.__meta.lamports), + 1000 + ); // Our written value + assert!(first_input.compressed_account.address.is_some()); // Should be enabled + assert!(first_input.compressed_account.data.is_some()); // Should be enabled + if let Some(address) = &first_input.compressed_account.address { + assert_eq!(address[0], 22); // Our written value + } + if let Some(data) = &first_input.compressed_account.data { + assert_eq!(data.data.len(), 10); // Should have capacity 10 + assert_eq!(data.__meta.discriminator[0], 33); // Our written value + assert_eq!(data.data[0], 99); // Our written value + assert_eq!(data.data_hash[0], 55); // Our written value + } + + // Test accessing second input account (config2: address=false, data=true, capacity=5) + let second_input = &deserialized.input_compressed_accounts_with_merkle_context[1]; + assert_eq!(second_input.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_input.compressed_account.address.is_none()); // Should be disabled + assert!(second_input.compressed_account.data.is_some()); // Should be enabled + if let Some(data) = &second_input.compressed_account.data { + assert_eq!(data.data.len(), 5); // Should have capacity 5 + } + + // 
Test accessing first output account (config3: address=true, data=false, capacity=0) + let first_output = &deserialized.output_compressed_accounts[0]; + assert_eq!(first_output.compressed_account.__meta.owner[0], 77); // Our written value + assert_eq!( + u64::from(first_output.compressed_account.__meta.lamports), + 2000 + ); // Our written value + assert!(first_output.compressed_account.address.is_some()); // Should be enabled + assert!(first_output.compressed_account.data.is_none()); // Should be disabled + if let Some(address) = &first_output.compressed_account.address { + assert_eq!(address[0], 88); // Our written value + } + + // Test accessing second output account (config4: address=false, data=false, capacity=0) + let second_output = &deserialized.output_compressed_accounts[1]; + assert_eq!(second_output.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_output.compressed_account.address.is_none()); // Should be disabled + assert!(second_output.compressed_account.data.is_none()); // Should be disabled +} + +#[test] +fn readme() { + use borsh::{BorshDeserialize, BorshSerialize}; + use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct MyStructOption { + pub a: u8, + pub b: u16, + pub vec: Vec>, + pub c: Option, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, + )] + pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + // Test the new ZeroCopyConfig functionality + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct TestConfigStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub option: Option, + } + + let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + // Use the struct with zero-copy deserialization + let bytes = my_struct.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), my_struct.byte_len()); + let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1); + let org_struct: MyStruct = zero_copy.into(); + assert_eq!(org_struct, my_struct); + // { + // let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut bytes).unwrap(); + // zero_copy_mut.a = 42; + // } + // let borsh = MyStruct::try_from_slice(&bytes).unwrap(); + // assert_eq!(borsh.a, 42u8); +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + ZeroCopyConfig, +)] +pub struct InstructionDataInvokeCpi { + pub proof: Option, + pub new_address_params: Vec, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, + pub cpi_context: Option, +} + +impl PartialEq> for InstructionDataInvokeCpi { + fn eq(&self, other: &ZInstructionDataInvokeCpi) -> bool { + // Compare proof + match (&self.proof, &other.proof) { + (Some(ref self_proof), Some(ref other_proof)) => { + if self_proof.a != other_proof.a + || self_proof.b != other_proof.b + || self_proof.c != other_proof.c + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare vectors lengths first + if self.new_address_params.len() != other.new_address_params.len() + || 
self.input_compressed_accounts_with_merkle_context.len() + != other.input_compressed_accounts_with_merkle_context.len() + || self.output_compressed_accounts.len() != other.output_compressed_accounts.len() + { + return false; + } + + // Compare new_address_params + for (self_param, other_param) in self + .new_address_params + .iter() + .zip(other.new_address_params.iter()) + { + if self_param.seed != other_param.seed + || self_param.address_queue_account_index != other_param.address_queue_account_index + || self_param.address_merkle_tree_account_index + != other_param.address_merkle_tree_account_index + || self_param.address_merkle_tree_root_index + != u16::from(other_param.address_merkle_tree_root_index) + { + return false; + } + } + + // Compare input accounts + for (self_input, other_input) in self + .input_compressed_accounts_with_merkle_context + .iter() + .zip(other.input_compressed_accounts_with_merkle_context.iter()) + { + if self_input != other_input { + return false; + } + } + + // Compare output accounts + for (self_output, other_output) in self + .output_compressed_accounts + .iter() + .zip(other.output_compressed_accounts.iter()) + { + if self_output != other_output { + return false; + } + } + + // Compare relay_fee + match (&self.relay_fee, &other.relay_fee) { + (Some(self_fee), Some(other_fee)) => { + if *self_fee != u64::from(**other_fee) { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare compress_or_decompress_lamports + match ( + &self.compress_or_decompress_lamports, + &other.compress_or_decompress_lamports, + ) { + (Some(self_lamports), Some(other_lamports)) => { + if *self_lamports != u64::from(**other_lamports) { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare is_compress (bool vs u8) + if self.is_compress != (other.is_compress != 0) { + return false; + } + + // Compare cpi_context + match (&self.cpi_context, &other.cpi_context) { + (Some(self_ctx), Some(other_ctx)) => { + if self_ctx.set_context != (other_ctx.set_context != 0) + || self_ctx.first_set_context != (other_ctx.first_set_context != 0) + || self_ctx.cpi_context_account_index != other_ctx.cpi_context_account_index + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + true + } +} + +impl PartialEq for ZInstructionDataInvokeCpi<'_> { + fn eq(&self, other: &InstructionDataInvokeCpi) -> bool { + other.eq(self) + } +} + +impl PartialEq> + for PackedCompressedAccountWithMerkleContext +{ + fn eq(&self, other: &ZPackedCompressedAccountWithMerkleContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != 
*other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_context + if self.merkle_context.merkle_tree_pubkey_index + != other.merkle_context.__meta.merkle_tree_pubkey_index + || self.merkle_context.nullifier_queue_pubkey_index + != other.merkle_context.__meta.nullifier_queue_pubkey_index + || self.merkle_context.leaf_index != u32::from(other.merkle_context.__meta.leaf_index) + || self.merkle_context.prove_by_index != other.merkle_context.prove_by_index() + { + return false; + } + + // Compare root_index and read_only + if self.root_index != u16::from(*other.root_index) + || self.read_only != (other.read_only != 0) + { + return false; + } + + true + } +} + +impl PartialEq> + for OutputCompressedAccountWithPackedContext +{ + fn eq(&self, other: &ZOutputCompressedAccountWithPackedContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != *other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_tree_index + if self.merkle_tree_index != other.merkle_tree_index { + return false; + } + + true + } +} diff --git a/program-libs/zero-copy-derive/tests/random.rs b/program-libs/zero-copy-derive/tests/random.rs new file mode 100644 index 0000000000..e678d0d806 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/random.rs @@ -0,0 +1,651 @@ +#![cfg(feature = "mut")] +use std::assert_eq; + +use borsh::BorshDeserialize; +use light_zero_copy::{borsh::Deserialize, init_mut::ZeroCopyNew}; +use rand::{ + rngs::{StdRng, ThreadRng}, + Rng, +}; + +mod instruction_data; +use instruction_data::{ + CompressedAccount, + CompressedAccountConfig, + CompressedAccountData, + CompressedAccountDataConfig, + CompressedCpiContext, + CompressedCpiContextConfig, + CompressedProof, + CompressedProofConfig, + InstructionDataInvoke, + // Config types (generated by ZeroCopyConfig derive) + InstructionDataInvokeConfig, + InstructionDataInvokeCpi, + InstructionDataInvokeCpiConfig, + NewAddressParamsPacked, + NewAddressParamsPackedConfig, + OutputCompressedAccountWithPackedContext, + OutputCompressedAccountWithPackedContextConfig, + PackedCompressedAccountWithMerkleContext, + PackedCompressedAccountWithMerkleContextConfig, + PackedMerkleContext, + PackedMerkleContextConfig, + Pubkey, + // Zero-copy mutable types + ZInstructionDataInvokeCpiMut, + ZInstructionDataInvokeMut, +}; + +// Function to populate mutable zero-copy structure with data from InstructionDataInvokeCpi +fn populate_invoke_cpi_zero_copy( + src: &InstructionDataInvokeCpi, + dst: &mut ZInstructionDataInvokeCpiMut, +) { + *dst.is_compress = if 
src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); + dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + .compressed_account + .owner + .copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = 
(&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } + + // Copy cpi_context if present + if let (Some(src_ctx), Some(dst_ctx)) = (&src.cpi_context, &mut dst.cpi_context) { + dst_ctx.set_context = if src_ctx.set_context { 1 } else { 0 }; + dst_ctx.first_set_context = if src_ctx.first_set_context { 1 } else { 0 }; + dst_ctx.cpi_context_account_index = src_ctx.cpi_context_account_index; + } +} + +// Function to populate mutable zero-copy structure with data from InstructionDataInvoke +fn populate_invoke_zero_copy(src: &InstructionDataInvoke, dst: &mut ZInstructionDataInvokeMut) { + *dst.is_compress = if src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); + dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + .compressed_account + .owner + 
.copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = (&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } +} + +fn get_rnd_instruction_data_invoke_cpi(rng: &mut StdRng) -> InstructionDataInvokeCpi { + InstructionDataInvokeCpi { + proof: Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }), + new_address_params: vec![get_rnd_new_address_params(rng); rng.gen_range(0..10)], + input_compressed_accounts_with_merkle_context: vec![ + get_rnd_test_input_account(rng); + rng.gen_range(0..10) + ], + output_compressed_accounts: vec![get_rnd_test_output_account(rng); rng.gen_range(0..10)], + relay_fee: None, + compress_or_decompress_lamports: rng.gen(), + is_compress: rng.gen(), + cpi_context: Some(get_rnd_cpi_context(rng)), + } +} + +fn get_rnd_cpi_context(rng: &mut StdRng) -> CompressedCpiContext { + CompressedCpiContext { + first_set_context: rng.gen(), + set_context: rng.gen(), + cpi_context_account_index: rng.gen(), + } +} + +fn get_rnd_test_account_data(rng: &mut StdRng) -> CompressedAccountData { + CompressedAccountData { + discriminator: rng.gen(), + data: (0..100).map(|_| rng.gen()).collect::>(), + data_hash: rng.gen(), + } +} + +fn get_rnd_test_account(rng: &mut StdRng) -> CompressedAccount { + CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: rng.gen(), + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + } +} + +fn get_rnd_test_output_account(rng: &mut StdRng) -> OutputCompressedAccountWithPackedContext { + OutputCompressedAccountWithPackedContext { + compressed_account: get_rnd_test_account(rng), + merkle_tree_index: rng.gen(), + } +} + +fn get_rnd_test_input_account(rng: &mut StdRng) -> PackedCompressedAccountWithMerkleContext { + PackedCompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: 100, + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + }, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index: rng.gen(), + nullifier_queue_pubkey_index: rng.gen(), + leaf_index: rng.gen(), + prove_by_index: rng.gen(), + }, + root_index: rng.gen(), + read_only: false, + } +} + +fn get_rnd_new_address_params(rng: &mut StdRng) -> NewAddressParamsPacked { + NewAddressParamsPacked { + seed: rng.gen(), + 
address_queue_account_index: rng.gen(), + address_merkle_tree_account_index: rng.gen(), + address_merkle_tree_root_index: rng.gen(), + } +} + +// Generate config for InstructionDataInvoke based on the actual data +fn generate_random_invoke_config( + invoke_ref: &InstructionDataInvoke, +) -> InstructionDataInvokeConfig { + InstructionDataInvokeConfig { + proof: (invoke_ref.proof.is_some(), CompressedProofConfig {}), + input_compressed_accounts_with_merkle_context: invoke_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_ref.relay_fee.is_some(), + new_address_params: invoke_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + compress_or_decompress_lamports: invoke_ref.compress_or_decompress_lamports.is_some(), + } +} + +// Generate config for InstructionDataInvokeCpi based on the actual data +fn generate_random_invoke_cpi_config( + invoke_cpi_ref: &InstructionDataInvokeCpi, +) -> InstructionDataInvokeCpiConfig { + InstructionDataInvokeCpiConfig { + proof: (invoke_cpi_ref.proof.is_some(), CompressedProofConfig {}), + new_address_params: invoke_cpi_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + input_compressed_accounts_with_merkle_context: invoke_cpi_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_cpi_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_cpi_ref.relay_fee.is_some(), + compress_or_decompress_lamports: invoke_cpi_ref.compress_or_decompress_lamports.is_some(), + cpi_context: ( + invoke_cpi_ref.cpi_context.is_some(), + CompressedCpiContextConfig {}, + ), + } +} + +#[test] +fn test_invoke_ix_data_deserialize_rnd() { + use 
rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. + println!("\n\ne2e test seed for invoke_ix_data {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 1000; + for i in 0..num_iters { + // Create randomized instruction data + let invoke_ref = InstructionDataInvoke { + proof: if rng.gen() { + Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }) + } else { + None + }, + input_compressed_accounts_with_merkle_context: if i % 5 == 0 { + // Only add inputs occasionally to keep test manageable + vec![get_rnd_test_input_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + output_compressed_accounts: if i % 4 == 0 { + vec![get_rnd_test_output_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + relay_fee: None, // Relay fee is currently not supported + new_address_params: if i % 3 == 0 { + vec![get_rnd_new_address_params(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + compress_or_decompress_lamports: if rng.gen() { Some(rng.gen()) } else { None }, + is_compress: rng.gen(), + }; + + // 1. Generate config based on the random data + let config = generate_random_invoke_config(&invoke_ref); + + // 2. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 3. Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 4. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 5. Populate the mutable zero-copy structure with random data + populate_invoke_zero_copy(&invoke_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvoke::deserialize(&mut bytes.as_slice()).unwrap(); + // 6. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvoke::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_ref, borsh_ref); + + // 7. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvoke with {} inputs, {} outputs, {} new_addresses", + invoke_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_ref.output_compressed_accounts.len(), + invoke_ref.new_address_params.len()); + } +} + +#[test] +fn test_instruction_data_invoke_cpi_rnd() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. + println!("\n\ne2e test seed {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 10_000; + for _ in 0..num_iters { + // 1. Generate random CPI instruction data + let invoke_cpi_ref = get_rnd_instruction_data_invoke_cpi(&mut rng); + + // 2. 
Generate config based on the random data + let config = generate_random_invoke_cpi_config(&invoke_cpi_ref); + + // 3. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvokeCpi::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 4. Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvokeCpi::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create CPI zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 5. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 6. Populate the mutable zero-copy structure with random data + populate_invoke_cpi_zero_copy(&invoke_cpi_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvokeCpi::deserialize(&mut bytes.as_slice()).unwrap(); + // 7. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvokeCpi::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_cpi_ref, borsh_ref); + + // 8. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvokeCpi with {} inputs, {} outputs, {} new_addresses", + invoke_cpi_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_cpi_ref.output_compressed_accounts.len(), + invoke_cpi_ref.new_address_params.len()); + } +} diff --git a/program-libs/zero-copy/Cargo.toml b/program-libs/zero-copy/Cargo.toml index ea683e5e48..7b67262bb2 100644 --- a/program-libs/zero-copy/Cargo.toml +++ b/program-libs/zero-copy/Cargo.toml @@ -11,13 +11,17 @@ default = [] solana = ["solana-program-error"] pinocchio = ["dep:pinocchio"] std = [] +derive = ["light-zero-copy-derive"] +mut = ["light-zero-copy-derive/mut"] [dependencies] solana-program-error = { workspace = true, optional = true } pinocchio = { workspace = true, optional = true } thiserror = { workspace = true } zerocopy = { workspace = true } +light-zero-copy-derive = { workspace = true, optional = true } [dev-dependencies] rand = { workspace = true } zerocopy = { workspace = true, features = ["derive"] } +borsh = { workspace = true } diff --git a/program-libs/zero-copy/README.md b/program-libs/zero-copy/README.md index d82ee39232..e28f535469 100644 --- a/program-libs/zero-copy/README.md +++ b/program-libs/zero-copy/README.md @@ -37,6 +37,3 @@ light-zero-copy = { version = "0.1.0", features = ["anchor"] } ### Security Considerations - do not use on a 32 bit target with length greater than u32 - only length until u64 is supported - -### Tests -- `cargo test --features std` diff --git a/program-libs/zero-copy/src/borsh.rs b/program-libs/zero-copy/src/borsh.rs index c7e4fbe4db..a0c91d7473 100644 --- a/program-libs/zero-copy/src/borsh.rs +++ b/program-libs/zero-copy/src/borsh.rs @@ -82,6 +82,17 @@ macro_rules! 
impl_deserialize_for_primitive { impl_deserialize_for_primitive!(u16, i16, u32, i32, u64, i64); impl_deserialize_for_primitive!(U16, U32, U64); +// Implement Deserialize for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> Deserialize<'a> for [T; N] { + type Output = Ref<&'a [u8], [T; N]>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + impl<'a, T: Deserialize<'a>> Deserialize<'a> for Vec { type Output = Vec; #[inline] @@ -138,6 +149,42 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for VecU8 { } } +pub trait ZeroCopyStructInner { + type ZeroCopyInner; +} + +impl ZeroCopyStructInner for u64 { + type ZeroCopyInner = U64; +} +impl ZeroCopyStructInner for u32 { + type ZeroCopyInner = U32; +} +impl ZeroCopyStructInner for u16 { + type ZeroCopyInner = U16; +} +impl ZeroCopyStructInner for u8 { + type ZeroCopyInner = u8; +} + +impl ZeroCopyStructInner for Vec { + type ZeroCopyInner = Vec; +} + +impl ZeroCopyStructInner for Option { + type ZeroCopyInner = Option; +} + +// Add ZeroCopyStructInner for array types +impl ZeroCopyStructInner for [u8; N] { + type ZeroCopyInner = Ref<&'static [u8], [u8; N]>; +} + +pub fn borsh_vec_u8_as_slice(bytes: &[u8]) -> Result<(&[u8], &[u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + Ok(bytes.split_at(num_slices)) +} + #[test] fn test_vecu8() { use std::vec; @@ -224,3 +271,561 @@ fn test_deserialize_vecu8() { assert_eq!(vec, std::vec![4, 5, 6]); assert_eq!(remaining, &[]); } + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::{ZeroCopyStructInner, *}; + use crate::slice::ZeroCopySliceBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize + // 1.6. Elements in an Option must implement Deserialize + // 1.7. a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInner + // 2. Deserialize + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. 
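+    // Worked byte-layout sketch for the rules above (Borsh encoding; uses `Struct3`
+    // as defined below, values chosen purely for illustration):
+    //   a: u8          -> 1 byte
+    //   b: u16 (U16)   -> 2 bytes, little-endian                         (rule 1.3)
+    //   vec: Vec<u8>   -> 4-byte little-endian length prefix + elements  (rule 1.2)
+    //   c: u64 (U64)   -> 8 bytes, little-endian, read after the slice   (rule 1.4)
+    // e.g. Struct3 { a: 1, b: 2, vec: vec![1u8; 2], c: 3 } serializes to
+    //   [1, 2, 0, 2, 0, 0, 0, 1, 1, 3, 0, 0, 0, 0, 0, 0, 0]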
+ + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + // pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + + // } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, + } + + impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let 
(meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + } + + impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + #[test] + fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, + } + + impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::ZeroCopyInner>>::zero_copy_at(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = 
Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. 
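+    // (e.g. `NestedStruct` below holds a `Struct2`, which owns a `Vec<u8>`, so
+    //  `NestedStruct` gets its own `Deserialize` impl and the elements of the
+    //  outer vector are deserialized one by one.)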
+ #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, + } + + impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } +} diff --git a/program-libs/zero-copy/src/borsh_mut.rs b/program-libs/zero-copy/src/borsh_mut.rs new file mode 100644 index 0000000000..e8143ab4ce --- /dev/null +++ b/program-libs/zero-copy/src/borsh_mut.rs @@ -0,0 +1,954 @@ +use core::{ + mem::size_of, + ops::{Deref, DerefMut}, +}; +use std::vec::Vec; + +use zerocopy::{ + little_endian::{U16, U32, U64}, + FromBytes, Immutable, KnownLayout, Ref, +}; + +use crate::errors::ZeroCopyError; + +pub trait DeserializeMut<'a> +where + Self: Sized, +{ + // TODO: rename to ZeroCopy, can be used as ::ZeroCopy + type Output; + fn zero_copy_at_mut(bytes: &'a mut [u8]) + -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>; +} + +// Implement DeserializeMut for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> DeserializeMut<'a> for [T; N] { + type Output = Ref<&'a mut [u8], [T; N]>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a mut [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Option { + type Output = Option; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + let (option_byte, bytes) = bytes.split_at_mut(1); + Ok(match option_byte[0] { + 0u8 => (None, bytes), + 1u8 => { + let (value, bytes) = T::zero_copy_at_mut(bytes)?; + (Some(value), bytes) + } + _ => return Err(ZeroCopyError::InvalidOptionByte(option_byte[0])), + }) + } +} + +impl<'a> DeserializeMut<'a> for u8 { + type Output = Self; + + /// Not a zero copy but cheaper. + /// A u8 should not be deserialized on it's own but as part of a struct. 
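+    /// (The byte is copied out of the buffer rather than referenced, so writes to
+    /// the returned `u8` do not propagate back into the underlying bytes.)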
+ #[inline] + fn zero_copy_at_mut(bytes: &'a mut [u8]) -> Result<(u8, &'a mut [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + let (bytes, remaining_bytes) = bytes.split_at_mut(size_of::()); + Ok((bytes[0], remaining_bytes)) + } +} + +// Implementation for specific zerocopy little-endian types +impl<'a, T: KnownLayout + Immutable + FromBytes> DeserializeMut<'a> for Ref<&'a mut [u8], T> { + type Output = Ref<&'a mut [u8], T>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], T>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Vec { + type Output = Vec; + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + // TODO: add check that remaining data is enough to read num_slices + // This prevents agains invalid data allocating a lot of heap memory + let mut slices = Vec::with_capacity(num_slices); + for _ in 0..num_slices { + let (slice, _bytes) = T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +macro_rules! impl_deserialize_for_primitive { + ($($t:ty),*) => { + $( + impl<'a> DeserializeMut<'a> for $t { + type Output = Ref<&'a mut [u8], $t>; + + #[inline] + fn zero_copy_at_mut(bytes: &'a mut [u8]) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Self::Output::zero_copy_at_mut(bytes) + } + } + )* + }; +} + +impl_deserialize_for_primitive!(u16, u32, u64, i16, i32, i64); + +// Add DeserializeMut for zerocopy little-endian types +impl<'a> DeserializeMut<'a> for zerocopy::little_endian::U16 { + type Output = Ref<&'a mut [u8], zerocopy::little_endian::U16>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix( + bytes, + )?) + } +} + +impl<'a> DeserializeMut<'a> for zerocopy::little_endian::U32 { + type Output = Ref<&'a mut [u8], zerocopy::little_endian::U32>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix( + bytes, + )?) + } +} + +impl<'a> DeserializeMut<'a> for zerocopy::little_endian::U64 { + type Output = Ref<&'a mut [u8], zerocopy::little_endian::U64>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix( + bytes, + )?) 
+ } +} + +pub fn borsh_vec_u8_as_slice_mut( + bytes: &mut [u8], +) -> Result<(&mut [u8], &mut [u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + Ok(bytes.split_at_mut(num_slices)) +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct VecU8(Vec); +impl VecU8 { + pub fn new() -> Self { + Self(Vec::new()) + } +} + +impl Deref for VecU8 { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for VecU8 { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for VecU8 { + type Output = Vec; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], u8>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + let mut slices = Vec::with_capacity(num_slices); + for _ in 0..num_slices { + let (slice, _bytes) = T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +pub trait ZeroCopyStructInnerMut { + type ZeroCopyInnerMut; +} + +impl ZeroCopyStructInnerMut for u64 { + type ZeroCopyInnerMut = U64; +} +impl ZeroCopyStructInnerMut for u32 { + type ZeroCopyInnerMut = U32; +} +impl ZeroCopyStructInnerMut for u16 { + type ZeroCopyInnerMut = U16; +} +impl ZeroCopyStructInnerMut for u8 { + type ZeroCopyInnerMut = u8; +} + +impl ZeroCopyStructInnerMut for Vec { + type ZeroCopyInnerMut = Vec; +} + +impl ZeroCopyStructInnerMut for Option { + type ZeroCopyInnerMut = Option; +} + +impl ZeroCopyStructInnerMut for [u8; N] { + type ZeroCopyInnerMut = Ref<&'static mut [u8], [u8; N]>; +} + +#[test] +fn test_vecu8() { + use std::vec; + let mut bytes = vec![8, 1u8, 2, 3, 4, 5, 6, 7, 8]; + let (vec, remaining_bytes) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(vec, vec![1u8, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(remaining_bytes, &mut []); +} + +#[test] +fn test_deserialize_mut_ref() { + let mut bytes = [1, 0, 0, 0]; // Little-endian representation of 1 + let (ref_data, remaining) = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(u32::from(*ref_data), 1); + assert_eq!(remaining, &mut []); + let res = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_option_some() { + let mut bytes = [1, 2]; // 1 indicates Some, followed by the value 2 + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value, Some(2)); + assert_eq!(remaining, &mut []); + let res = Option::::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::ArraySize(1, 0))); + let mut bytes = [2, 0]; // 2 indicates invalid option byte + let res = Option::::zero_copy_at_mut(&mut bytes); + assert_eq!(res, Err(ZeroCopyError::InvalidOptionByte(2))); +} + +#[test] +fn test_deserialize_mut_option_none() { + let mut bytes = [0]; // 0 indicates None + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value, None); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_deserialize_mut_u8() { + let mut bytes = [0xFF]; // Value 255 + let (value, remaining) = u8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(value, 255); + assert_eq!(remaining, &mut []); + let res = u8::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::ArraySize(1, 
0))); +} + +#[test] +fn test_deserialize_mut_u16() { + let mut bytes = 2323u16.to_le_bytes(); + let (value, remaining) = u16::zero_copy_at_mut(bytes.as_mut_slice()).unwrap(); + assert_eq!(*value, 2323u16); + assert_eq!(remaining, &mut []); + let mut value = [0u8]; + let res = u16::zero_copy_at_mut(&mut value); + // TODO: investigate why error is not Size as in borsh.rs test. + assert_eq!(res, Err(ZeroCopyError::UnalignedPointer)); +} + +#[test] +fn test_deserialize_mut_vec() { + let mut bytes = [2, 0, 0, 0, 1, 2]; // Length 2, followed by values 1 and 2 + let (vec, remaining) = Vec::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(vec, std::vec![1, 2]); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_vecu8_deref() { + let data = std::vec![1, 2, 3]; + let vec_u8 = VecU8(data.clone()); + assert_eq!(&*vec_u8, &data); + + let mut vec = VecU8::new(); + vec.push(1u8); + assert_eq!(*vec, std::vec![1u8]); +} + +#[test] +fn test_deserialize_mut_vecu8() { + let mut bytes = [3, 4, 5, 6]; // Length 3, followed by values 4, 5, 6 + let (vec, remaining) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(vec, std::vec![4, 5, 6]); + assert_eq!(remaining, &mut []); +} + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::*; + use crate::slice_mut::ZeroCopySliceMutBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement DeserializeMut + // 1.6. Elements in an Option must implement DeserializeMut + // 1.7. a type that does not implement Copy must implement DeserializeMut, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInnerMut + // 2. DeserializeMut + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. 
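+    // Round-trip sketch for the mutable rules above (mirrors `test_struct_1` below;
+    // assumes `Struct1`/`ZStruct1Meta` as defined in this module):
+    //     let mut bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap();
+    //     {
+    //         let (mut z, _rest) = Struct1::zero_copy_at_mut(&mut bytes).unwrap();
+    //         z.meta.b = 42u16.into(); // written into the buffer as little-endian U16
+    //     }
+    //     assert_eq!(Struct1::try_from_slice(&bytes).unwrap().b, 42);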
+ + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes, IntoBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + pub meta: Ref<&'a mut [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a mut [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl DerefMut for ZStruct1<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let ref_struct = Struct1 { a: 1, b: 2 }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (mut struct1, remaining) = Struct1::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &mut []); + struct1.meta.a = 2; + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a mut [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInnerMut>::ZeroCopyInnerMut, + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a mut [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as DeserializeMut<'a>>::zero_copy_at_mut(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let ref_struct = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a mut [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, u8>, + pub c: Ref<&'a mut [u8], U64>, + } + + impl<'a> Deref for ZStruct3<'a> { + 
type Target = Ref<&'a mut [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::zero_copy_at_mut(bytes)?; + let (c, bytes) = Ref::<&mut [u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let ref_struct = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + impl<'a> DeserializeMut<'a> for Struct4Nested { + type Output = ZStruct4Nested; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], ZStruct4Nested>::from_prefix(bytes)?; + Ok((*bytes, remaining_bytes)) + } + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInnerMut for Struct4Nested { + type ZeroCopyInnerMut = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInnerMut, + pub b: ::ZeroCopyInnerMut, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a mut [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + pub c: Ref<&'a mut [u8], ::ZeroCopyInnerMut>, + pub vec_2: + ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + } + + impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a mut [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&mut [u8], ::ZeroCopyInnerMut>::from_prefix( + bytes, + )?; + let (vec_2, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + /// TODO: + /// - add SIZE const generic DeserializeMut trait + /// - add new with data function + impl Struct4 { + // pub fn byte_len(&self) -> usize { + // size_of::() + // + size_of::() + // + size_of::() * self.vec.len() + // + size_of::() + // + size_of::() * self.vec_2.len() + // } + + pub fn new_with_data<'a>( + bytes: &'a mut [u8], + data: &Struct4, + ) -> (ZStruct4<'a>, &'a mut [u8]) { + let (mut zero_copy, bytes) = + 
::zero_copy_at_mut(bytes).unwrap(); + zero_copy.meta.a = data.a; + zero_copy.meta.b = data.b.into(); + zero_copy + .vec + .iter_mut() + .zip(data.vec.iter()) + .for_each(|(x, y)| *x = *y); + (zero_copy, bytes) + } + } + + #[test] + fn test_struct_4() { + let ref_struct = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInnerMut>>, + } + + impl<'a> DeserializeMut<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::< + ZeroCopySliceMutBorsh<::ZeroCopyInnerMut>, + >::zero_copy_at_mut(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let ref_struct = Struct5 { + a: vec![vec![1u8; 32]; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. 
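    // Before the nested case below, a brief aside on the flat case (an illustrative
    // sketch added for clarity, not part of the original patch): borsh writes a
    // `Vec<u8>` as a little-endian u32 length prefix followed by the raw elements,
    // which is the layout that the mutable slice view over `ZStruct2.vec` reinterprets
    // in place.
    #[test]
    fn test_struct_2_vec_layout() {
        let bytes = Struct2 {
            a: 1,
            b: 2,
            vec: vec![7u8; 3],
        }
        .try_to_vec()
        .unwrap();
        // a (1 byte) + b (2 bytes, LE) + vec length prefix (4 bytes, LE) + 3 elements
        assert_eq!(bytes, vec![1, 2, 0, 3, 0, 0, 0, 7, 7, 7]);
    }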
+ #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let ref_struct = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a mut [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInnerMut>::ZeroCopyInnerMut, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a mut [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as DeserializeMut<'a>>::zero_copy_at_mut(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let ref_struct = Struct7 { + a: 1, + b: 2, + option: Some(3), + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &mut []); + + let ref_struct = Struct7 { + a: 1, + b: 2, + option: None, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. 
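    // Aside on the `Option` field exercised above (an illustrative sketch added for
    // clarity, not part of the original patch): borsh encodes an `Option` as a
    // one-byte discriminant (0 = None, 1 = Some) followed by the payload, which is the
    // layout the `Option` handling in `zero_copy_at_mut` walks over.
    #[test]
    fn test_struct_7_option_layout() {
        let some_bytes = Struct7 {
            a: 1,
            b: 2,
            option: Some(3),
        }
        .try_to_vec()
        .unwrap();
        // The first three bytes are `a` and little-endian `b`; the fourth is the
        // Some discriminant, followed by the payload.
        assert_eq!(some_bytes[..4], [1, 2, 0, 1]);

        let none_bytes = Struct7 {
            a: 1,
            b: 2,
            option: None,
        }
        .try_to_vec()
        .unwrap();
        // a + b (LE) + None discriminant, no payload
        assert_eq!(none_bytes, vec![1, 2, 0, 0]);
    }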
+ #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInnerMut, + pub b: >::Output, + } + + impl<'a> DeserializeMut<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = + ::ZeroCopyInnerMut::zero_copy_at_mut(bytes)?; + let (b, bytes) = >::zero_copy_at_mut(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let ref_struct = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } +} diff --git a/program-libs/zero-copy/src/init_mut.rs b/program-libs/zero-copy/src/init_mut.rs new file mode 100644 index 0000000000..b93350f9e2 --- /dev/null +++ b/program-libs/zero-copy/src/init_mut.rs @@ -0,0 +1,264 @@ +use core::mem::size_of; +use std::vec::Vec; + +use crate::{borsh_mut::DeserializeMut, errors::ZeroCopyError}; + +/// Trait for types that can be initialized in mutable byte slices with configuration +/// +/// This trait provides a way to initialize structures in pre-allocated byte buffers +/// with specific configuration parameters that determine Vec lengths, Option states, etc. 
+pub trait ZeroCopyNew<'a>
+where
+    Self: Sized,
+{
+    /// Configuration type needed to initialize this type
+    type Config;
+
+    /// Output type - the mutable zero-copy view of this type
+    type Output;
+
+    /// Calculate the byte length needed for this type with the given configuration
+    ///
+    /// This is essential for allocating the correct buffer size before calling new_zero_copy
+    fn byte_len(config: &Self::Config) -> usize;
+
+    /// Initialize this type in a mutable byte slice with the given configuration
+    ///
+    /// Returns the initialized mutable view and remaining bytes
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>;
+}
+
+// Generic implementation for Option<T>
+impl<'a, T> ZeroCopyNew<'a> for Option<T>
+where
+    T: ZeroCopyNew<'a>,
+{
+    type Config = (bool, T::Config); // (enabled, inner_config)
+    type Output = Option<T::Output>;
+
+    fn byte_len(config: &Self::Config) -> usize {
+        let (enabled, inner_config) = config;
+        if *enabled {
+            // 1 byte for Some discriminant + inner type's byte_len
+            1 + T::byte_len(inner_config)
+        } else {
+            // Just 1 byte for None discriminant
+            1
+        }
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        let (enabled, inner_config) = config;
+
+        if enabled {
+            bytes[0] = 1; // Some discriminant
+            let (_, bytes) = bytes.split_at_mut(1);
+            let (value, bytes) = T::new_zero_copy(bytes, inner_config)?;
+            Ok((Some(value), bytes))
+        } else {
+            bytes[0] = 0; // None discriminant
+            let (_, bytes) = bytes.split_at_mut(1);
+            Ok((None, bytes))
+        }
+    }
+}
+
+// Implementation for primitive types (no configuration needed)
+impl<'a> ZeroCopyNew<'a> for u64 {
+    type Config = ();
+    type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<u64>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        // Return U64 little-endian type for generated structs
+        Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?)
+    }
+}
+
+impl<'a> ZeroCopyNew<'a> for u32 {
+    type Config = ();
+    type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<u32>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        // Return U32 little-endian type for generated structs
+        Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?)
+    }
+}
+
+impl<'a> ZeroCopyNew<'a> for u16 {
+    type Config = ();
+    type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<u16>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        // Return U16 little-endian type for generated structs
+        Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?)
+    }
+}
+
+impl<'a> ZeroCopyNew<'a> for u8 {
+    type Config = ();
+    type Output = <u8 as DeserializeMut<'a>>::Output;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<u8>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        // Use the DeserializeMut trait to create the proper output
+        Self::zero_copy_at_mut(bytes)
+    }
+}
+
+impl<'a> ZeroCopyNew<'a> for bool {
+    type Config = ();
+    type Output = <u8 as DeserializeMut<'a>>::Output;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<u8>() // bool is serialized as u8
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        // Treat bool as u8
+        u8::zero_copy_at_mut(bytes)
+    }
+}
+
+// Implementation for fixed-size arrays
+impl<
+        'a,
+        T: Copy + Default + zerocopy::KnownLayout + zerocopy::Immutable + zerocopy::FromBytes,
+        const N: usize,
+    > ZeroCopyNew<'a> for [T; N]
+{
+    type Config = ();
+    type Output = <Self as DeserializeMut<'a>>::Output;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<Self>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        // Use the DeserializeMut trait to create the proper output
+        Self::zero_copy_at_mut(bytes)
+    }
+}
+
+// Implementation for zerocopy little-endian types
+impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U16 {
+    type Config = ();
+    type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<Self>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?)
+    }
+}
+
+impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U32 {
+    type Config = ();
+    type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<Self>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?)
+    }
+}
+
+impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U64 {
+    type Config = ();
+    type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>;
+
+    fn byte_len(_config: &Self::Config) -> usize {
+        size_of::<Self>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        _config: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?)
+    }
+}
+
+// Implementation for Vec<T>
+impl<'a, T: ZeroCopyNew<'a>> ZeroCopyNew<'a> for Vec<T> {
+    type Config = Vec<T::Config>; // Vector of configs for each item
+    type Output = Vec<T::Output>;
+
+    fn byte_len(config: &Self::Config) -> usize {
+        // 4 bytes for length prefix + sum of byte_len for each element config
+        4 + config
+            .iter()
+            .map(|config| T::byte_len(config))
+            .sum::<usize>()
+    }
+
+    fn new_zero_copy(
+        bytes: &'a mut [u8],
+        configs: Self::Config,
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        use zerocopy::{little_endian::U32, Ref};
+
+        // Write length as U32
+        let len = configs.len() as u32;
+        let (mut len_ref, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?;
+        *len_ref = U32::new(len);
+
+        // Initialize each item with its config
+        let mut items = Vec::with_capacity(configs.len());
+        for config in configs {
+            let (item, remaining_bytes) = T::new_zero_copy(bytes, config)?;
+            bytes = remaining_bytes;
+            items.push(item);
+        }
+
+        Ok((items, bytes))
+    }
+}
diff --git a/program-libs/zero-copy/src/lib.rs b/program-libs/zero-copy/src/lib.rs
index 297c849d53..360781bf92 100644
--- a/program-libs/zero-copy/src/lib.rs
+++ b/program-libs/zero-copy/src/lib.rs
@@ -10,8 +10,26 @@ pub mod vec;
 use core::mem::{align_of, size_of};
 #[cfg(feature = "std")]
 pub mod borsh;
-
-use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
+#[cfg(feature = "std")]
+pub mod borsh_mut;
+#[cfg(feature = "std")]
+pub mod init_mut;
+#[cfg(feature = "std")]
+pub use borsh::ZeroCopyStructInner;
+#[cfg(feature = "std")]
+pub use init_mut::ZeroCopyNew;
+#[cfg(all(feature = "derive", feature = "std", feature = "mut"))]
+pub use light_zero_copy_derive::ZeroCopyConfig;
+#[cfg(all(feature = "derive", feature = "mut"))]
+pub use light_zero_copy_derive::ZeroCopyMut;
+#[cfg(feature = "derive")]
+pub use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq};
+#[cfg(feature = "derive")]
+pub use zerocopy::{
+    little_endian::{self, U16, U32, U64},
+    Ref, Unaligned,
+};
+pub use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
 
 #[cfg(feature = "std")]
 extern crate std;
diff --git a/program-libs/zero-copy/src/slice_mut.rs b/program-libs/zero-copy/src/slice_mut.rs
index 27cd2f776a..7a50b7e44d 100644
--- a/program-libs/zero-copy/src/slice_mut.rs
+++ b/program-libs/zero-copy/src/slice_mut.rs
@@ -276,3 +276,16 @@ where
         write!(f, "{:?}", self.as_slice())
     }
 }
+
+#[cfg(feature = "std")]
+impl<'a, T: ZeroCopyTraits + crate::borsh_mut::DeserializeMut<'a>>
+    crate::borsh_mut::DeserializeMut<'a> for ZeroCopySliceMutBorsh<'_, T>
+{
+    type Output = ZeroCopySliceMutBorsh<'a, T>;
+
+    fn zero_copy_at_mut(
+        bytes: &'a mut [u8],
+    ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> {
+        ZeroCopySliceMutBorsh::from_bytes_at(bytes)
+    }
+}
diff --git a/program-libs/zero-copy/tests/borsh.rs b/program-libs/zero-copy/tests/borsh.rs
new file mode 100644
index 0000000000..071b4e8df2
--- /dev/null
+++ b/program-libs/zero-copy/tests/borsh.rs
@@ -0,0 +1,335 @@
+#![cfg(all(feature = "std", feature = "derive", feature = "mut"))]
+use borsh::{BorshDeserialize, BorshSerialize};
+use light_zero_copy::{
+    borsh::Deserialize, borsh_mut::DeserializeMut, ZeroCopy, ZeroCopyEq, ZeroCopyMut,
+};
+
+#[repr(C)]
+#[derive(Debug, PartialEq, ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshDeserialize, BorshSerialize)]
+pub struct Struct1Derived {
+    pub a: u8,
+    pub b: u16,
+}
+
+#[test]
+fn test_struct_1_derived() {
+    let ref_struct = Struct1Derived { a: 1, b: 2 };
+    let mut bytes = ref_struct.try_to_vec().unwrap();
+
+    {
+        let (struct1, remaining) =
Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(struct1, ref_struct); + assert_eq!(remaining, &[]); + } + { + let (mut struct1, _) = Struct1Derived::zero_copy_at_mut(&mut bytes).unwrap(); + struct1.a = 2; + struct1.b = 3.into(); + } + let borsh = Struct1Derived::deserialize(&mut &bytes[..]).unwrap(); + let (struct_1, _) = Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct_1.a, 2); // Modified value from mutable operations + assert_eq!(struct_1.b, 3); // Modified value from mutable operations + assert_eq!(struct_1, borsh); +} + +// Struct2 equivalent: Manual implementation that should match Struct2 +#[repr(C)] +#[derive( + Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct2Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[test] +fn test_struct_2_derived() { + let ref_struct = Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + assert_eq!(struct2, ref_struct); +} + +// Struct3 equivalent: fields should match Struct3 +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct3Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[test] +fn test_struct_3_derived() { + let ref_struct = Struct3Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4NestedDerived { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[test] +fn test_struct_4_derived() { + let ref_struct = Struct4Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4NestedDerived { a: 1, b: 2 }; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + // Check vec_2 length is correct + assert_eq!(zero_copy.vec_2.len(), 32); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct5Derived { + pub a: Vec>, +} + +#[test] +fn test_struct_5_derived() { + let ref_struct = Struct5Derived { + a: vec![vec![1u8; 32]; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| 
x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct6Derived { + pub a: Vec, +} + +#[test] +fn test_struct_6_derived() { + let ref_struct = Struct6Derived { + a: vec![ + Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct Struct7Derived { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[test] +fn test_struct_7_derived() { + let ref_struct = Struct7Derived { + a: 1, + b: 2, + option: Some(3), + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7Derived { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. 
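// Aside (an illustrative sketch added for clarity, not part of the original patch):
// elements of the derived `Struct6Derived` view are themselves zero-copy views, so
// nested fields can be read without materializing the inner structs. This assumes the
// derived element view exposes its fields the same way the manual `ZStruct2` does.
#[test]
fn test_struct_6_derived_element_access() {
    let ref_struct = Struct6Derived {
        a: vec![
            Struct2Derived {
                a: 7,
                b: 8,
                vec: vec![9u8; 4],
            };
            2
        ],
    };
    let bytes = ref_struct.try_to_vec().unwrap();

    let (zero_copy, _) = Struct6Derived::zero_copy_at(&bytes).unwrap();
    assert_eq!(zero_copy.a.len(), 2);
    assert_eq!(zero_copy.a[0].a, 7u8);
    assert_eq!(zero_copy.a[1].vec.to_vec(), vec![9u8; 4]);
}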
+#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct8Derived { + pub a: Vec, +} + +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct NestedStructDerived { + pub a: u8, + pub b: Struct2Derived, +} + +#[test] +fn test_struct_8_derived() { + let ref_struct = Struct8Derived { + a: vec![ + NestedStructDerived { + a: 1, + b: Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8Derived::zero_copy_at(&bytes).unwrap(); + // Check length of vec matches + assert_eq!(zero_copy.a.len(), 32); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshSerialize, BorshDeserialize, PartialEq, Debug)] +pub struct ArrayStruct { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +#[test] +fn test_array_struct() -> Result<(), Box> { + let array_struct = ArrayStruct { + a: [1u8; 32], + b: [2u8; 64], + c: [3u8; 32], + }; + let bytes = array_struct.try_to_vec()?; + + let (zero_copy, remaining) = ArrayStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, [1u8; 32]); + assert_eq!(zero_copy.b, [2u8; 64]); + assert_eq!(zero_copy.c, [3u8; 32]); + assert_eq!(zero_copy, array_struct); + assert_eq!(remaining, &[]); + Ok(()) +} + +#[derive( + Debug, + PartialEq, + Default, + Clone, + BorshSerialize, + BorshDeserialize, + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +#[test] +fn test_compressed_account_data() { + let compressed_account_data = CompressedAccountData { + discriminator: [1u8; 8], + data: vec![2u8; 32], + data_hash: [3u8; 32], + }; + let bytes = compressed_account_data.try_to_vec().unwrap(); + + let (zero_copy, remaining) = CompressedAccountData::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.discriminator, [1u8; 8]); + // assert_eq!(zero_copy.data, compressed_account_data.data.as_slice()); + assert_eq!(*zero_copy.data_hash, [3u8; 32]); + assert_eq!(zero_copy, compressed_account_data); + assert_eq!(remaining, &[]); +} diff --git a/program-libs/zero-copy/tests/borsh_2.rs b/program-libs/zero-copy/tests/borsh_2.rs new file mode 100644 index 0000000000..aece86bb1c --- /dev/null +++ b/program-libs/zero-copy/tests/borsh_2.rs @@ -0,0 +1,559 @@ +#![cfg(all(feature = "std", feature = "derive"))] + +use std::{ops::Deref, vec}; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{ + borsh::Deserialize, errors::ZeroCopyError, slice::ZeroCopySliceBorsh, ZeroCopyStructInner, +}; +use zerocopy::{ + little_endian::{U16, U64}, + FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +// Rules: +// 1. create ZStruct for the struct +// 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct +// 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct +// 1.3. replace u16 with U16, u32 with U32, etc +// 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +// 1.6. Elements in an Option must implement Deserialize +// 1.7. 
a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + +// Derive Macro needs to derive: +// 1. ZeroCopyStructInner +// 2. Deserialize +// 3. PartialEq for ZStruct<'_> +// +// For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. + +// Tests for manually implemented structures (without derive macro) + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct1 { + pub a: u8, + pub b: u16, +} + +// pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + +// } + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, +} +impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } +} + +#[test] +fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } +} + +impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } +} + +#[test] +fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, +} + +impl<'a> Deref for ZStruct3<'a> { 
+ type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((ZStruct3 { meta, vec, c }, bytes)) + } +} + +#[test] +fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] +pub struct Struct4Nested { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, +)] +pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, +} + +impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] +pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, +} + +impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + ZStruct4 { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } +} + +#[test] +fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct5 { + pub a: Vec>, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, +} + +impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = + 
Vec::::ZeroCopyInner>>::zero_copy_at( + bytes, + )?; + Ok((ZStruct5 { a }, bytes)) + } +} + +#[test] +fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct6 { + pub a: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } +} + +#[test] +fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } +} + +impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } +} + +#[test] +fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. 
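// Aside (an illustrative sketch added for clarity, not part of the original patch):
// `zero_copy_at` only consumes the prefix it needs and returns the rest, so several
// values can be read back to back from one buffer.
#[test]
fn test_struct_1_remaining_bytes() {
    let mut bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap();
    bytes.extend_from_slice(&Struct1 { a: 3, b: 4 }.try_to_vec().unwrap());

    let (first, rest) = Struct1::zero_copy_at(&bytes).unwrap();
    let (second, rest) = Struct1::zero_copy_at(rest).unwrap();
    assert_eq!(first.a, 1u8);
    assert_eq!(first.b, 2u16);
    assert_eq!(second.a, 3u8);
    assert_eq!(second.b, 4u16);
    assert_eq!(rest, &[]);
}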
+#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct8 { + pub a: Vec, +} + +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct NestedStruct { + pub a: u8, + pub b: Struct2, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, +} + +impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } +} + +impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } +} + +#[test] +fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +}
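// Closing aside (an illustrative sketch added for clarity, not part of the original
// patch): the `ZeroCopyNew` trait from init_mut.rs can size and pre-initialize a
// buffer before writing into it. For a `Vec<u8>` the config is one `()` per element,
// so `byte_len` is the 4-byte length prefix plus one byte per element.
#[test]
fn test_vec_u8_zero_copy_new() {
    use light_zero_copy::init_mut::ZeroCopyNew;

    let config = vec![(); 3];
    let len = <Vec<u8> as ZeroCopyNew<'_>>::byte_len(&config);
    assert_eq!(len, 4 + 3);

    let mut bytes = vec![0u8; len];
    let (items, remaining) =
        <Vec<u8> as ZeroCopyNew<'_>>::new_zero_copy(&mut bytes, config).unwrap();
    assert_eq!(items.len(), 3);
    assert!(remaining.is_empty());
}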