From bdf3cf606f4273046b48a9c200ab0d88fb6fb4f0 Mon Sep 17 00:00:00 2001 From: ananas-block Date: Wed, 2 Jul 2025 22:22:48 +0100 Subject: [PATCH] feat: zero-copy-derive fix: light-zero-copy tests comment derive mut commented byte len fix: derive macro for non mut pre bytelen refactor: detach bytelen trait stash adding config simple config derive works stash stash new at stash new at Compressed Account stash man InstructionDataInvoke new_zero_copy works stash simple config zero_copy_new tests stash refactor fixed lifetime issue stash instruction data tests work move byte_len to init_mut added randomized tests stash got failing random tests fixed u8 and bool remove bytelen renamed trait fix lint fix tests apply feedback meta_struct use syn to parse options instead of strings primitive types Replace string-based type comparisons with proper syn AST matching replace parse_str with parse_quote replace empty quote with unreachable! add byte len check borsh_vec_u8_as_slice_mut converted unimplemented to panic cleanup redundant as u64 etc fix docs cleanup cleanup commented code cleanup mut conditionals remove bytelen derive cleanup refactor: replace duplicate code with generate_deserialize_call refactor detecting copy moved to internal refactor: add error handling cleanup cleanup file structure stash wip transform all primitive types to zero copy types simplify analyze_struct_fields fix empty meta struct generation stash zero copy changes unified some with Deserialize::Output unified integer field type enum rename VecNonStaticZeroCopy -> VecDynamicZeroCopy Simplify Option inner type extraction using syn utilities. Add bounds check before writing discriminant byte. 
improve generate_field_initialization remove debug test Incorrect type conversion from u8 to u32, add note options in arrays are not supported Error context lost in conversion format and add heap allocation check Check the last path segment for accurate type detection fix: test fix: test improve cache robustness --- .github/workflows/rust.yml | 3 +- .gitignore | 2 + Cargo.lock | 44 + Cargo.toml | 2 + .../src/instruction_data/with_account_info.rs | 8 +- .../src/instruction_data/with_readonly.rs | 10 +- program-libs/zero-copy-derive/Cargo.toml | 26 + program-libs/zero-copy-derive/README.md | 103 ++ program-libs/zero-copy-derive/src/lib.rs | 166 ++ .../zero-copy-derive/src/shared/from_impl.rs | 242 +++ .../src/shared/meta_struct.rs | 57 + .../zero-copy-derive/src/shared/mod.rs | 6 + .../zero-copy-derive/src/shared/utils.rs | 437 +++++ .../zero-copy-derive/src/shared/z_struct.rs | 630 ++++++++ .../src/shared/zero_copy_new.rs | 391 +++++ .../zero-copy-derive/src/zero_copy.rs | 636 ++++++++ .../zero-copy-derive/src/zero_copy_eq.rs | 265 ++++ .../zero-copy-derive/src/zero_copy_mut.rs | 93 ++ .../zero-copy-derive/tests/config_test.rs | 430 +++++ .../tests/cross_crate_copy.rs | 295 ++++ .../zero-copy-derive/tests/from_test.rs | 77 + .../tests/instruction_data.rs | 1401 +++++++++++++++++ program-libs/zero-copy-derive/tests/random.rs | 651 ++++++++ program-libs/zero-copy/Cargo.toml | 4 + program-libs/zero-copy/README.md | 3 - program-libs/zero-copy/src/borsh.rs | 669 +++++++- program-libs/zero-copy/src/borsh_mut.rs | 965 ++++++++++++ program-libs/zero-copy/src/init_mut.rs | 268 ++++ program-libs/zero-copy/src/lib.rs | 20 +- program-libs/zero-copy/src/slice_mut.rs | 13 + program-libs/zero-copy/tests/borsh.rs | 335 ++++ program-libs/zero-copy/tests/borsh_2.rs | 559 +++++++ 32 files changed, 8790 insertions(+), 21 deletions(-) create mode 100644 program-libs/zero-copy-derive/Cargo.toml create mode 100644 program-libs/zero-copy-derive/README.md create mode 100644 
program-libs/zero-copy-derive/src/lib.rs create mode 100644 program-libs/zero-copy-derive/src/shared/from_impl.rs create mode 100644 program-libs/zero-copy-derive/src/shared/meta_struct.rs create mode 100644 program-libs/zero-copy-derive/src/shared/mod.rs create mode 100644 program-libs/zero-copy-derive/src/shared/utils.rs create mode 100644 program-libs/zero-copy-derive/src/shared/z_struct.rs create mode 100644 program-libs/zero-copy-derive/src/shared/zero_copy_new.rs create mode 100644 program-libs/zero-copy-derive/src/zero_copy.rs create mode 100644 program-libs/zero-copy-derive/src/zero_copy_eq.rs create mode 100644 program-libs/zero-copy-derive/src/zero_copy_mut.rs create mode 100644 program-libs/zero-copy-derive/tests/config_test.rs create mode 100644 program-libs/zero-copy-derive/tests/cross_crate_copy.rs create mode 100644 program-libs/zero-copy-derive/tests/from_test.rs create mode 100644 program-libs/zero-copy-derive/tests/instruction_data.rs create mode 100644 program-libs/zero-copy-derive/tests/random.rs create mode 100644 program-libs/zero-copy/src/borsh_mut.rs create mode 100644 program-libs/zero-copy/src/init_mut.rs create mode 100644 program-libs/zero-copy/tests/borsh.rs create mode 100644 program-libs/zero-copy/tests/borsh_2.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index efa24bff78..2054edec76 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -53,7 +53,8 @@ jobs: cargo test -p light-account-checks --all-features cargo test -p light-verifier --all-features cargo test -p light-merkle-tree-metadata --all-features - cargo test -p light-zero-copy --features std + cargo test -p light-zero-copy --features "std, mut, derive" + cargo test -p light-zero-copy-derive --features "mut" cargo test -p light-hash-set --all-features - name: program-libs-slow packages: light-bloom-filter light-indexed-merkle-tree light-batched-merkle-tree diff --git a/.gitignore b/.gitignore index b7e754f3b1..5129907005 100644 
--- a/.gitignore +++ b/.gitignore @@ -86,3 +86,5 @@ output1.txt .zed **/.claude/**/* + +expand.rs diff --git a/Cargo.lock b/Cargo.lock index 03057b4669..6dec5247d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2365,6 +2365,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "governor" version = "0.6.3" @@ -3786,6 +3792,8 @@ dependencies = [ name = "light-zero-copy" version = "0.2.0" dependencies = [ + "borsh 0.10.4", + "light-zero-copy-derive", "pinocchio", "rand 0.8.5", "solana-program-error", @@ -3793,6 +3801,21 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "light-zero-copy-derive" +version = "0.1.0" +dependencies = [ + "borsh 0.10.4", + "lazy_static", + "light-zero-copy", + "proc-macro2", + "quote", + "rand 0.8.5", + "syn 2.0.103", + "trybuild", + "zerocopy", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -9054,6 +9077,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-triple" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" + [[package]] name = "tarpc" version = "0.29.0" @@ -9626,6 +9655,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c9bf9513a2f4aeef5fdac8677d7d349c79fdbcc03b9c86da6e9d254f1e43be2" +dependencies = 
[ + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml 0.8.23", +] + [[package]] name = "tungstenite" version = "0.20.1" diff --git a/Cargo.toml b/Cargo.toml index 176115adb1..87af2c81ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "program-libs/hash-set", "program-libs/indexed-merkle-tree", "program-libs/indexed-array", + "program-libs/zero-copy-derive", "programs/account-compression", "programs/system", "programs/compressed-token", @@ -167,6 +168,7 @@ light-compressed-account = { path = "program-libs/compressed-account", version = light-account-checks = { path = "program-libs/account-checks", version = "0.3.0" } light-verifier = { path = "program-libs/verifier", version = "2.1.0" } light-zero-copy = { path = "program-libs/zero-copy", version = "0.2.0" } +light-zero-copy-derive = { path = "program-libs/zero-copy-derive", version = "0.1.0" } photon-api = { path = "sdk-libs/photon-api", version = "0.51.0" } forester-utils = { path = "forester-utils", version = "2.0.0" } account-compression = { path = "programs/account-compression", version = "2.0.0", features = [ diff --git a/program-libs/compressed-account/src/instruction_data/with_account_info.rs b/program-libs/compressed-account/src/instruction_data/with_account_info.rs index 599ad9cd0b..57b49e5e78 100644 --- a/program-libs/compressed-account/src/instruction_data/with_account_info.rs +++ b/program-libs/compressed-account/src/instruction_data/with_account_info.rs @@ -399,9 +399,13 @@ impl<'a> Deserialize<'a> for InstructionDataInvokeCpiWithAccountInfo { let (account_infos, bytes) = { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory let mut slices = Vec::with_capacity(num_slices); + if bytes.len() < num_slices { + return 
Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } for _ in 0..num_slices { let (slice, _bytes) = CompressedAccountInfo::zero_copy_at_with_owner( bytes, diff --git a/program-libs/compressed-account/src/instruction_data/with_readonly.rs b/program-libs/compressed-account/src/instruction_data/with_readonly.rs index 59b9c27bd7..e591f45444 100644 --- a/program-libs/compressed-account/src/instruction_data/with_readonly.rs +++ b/program-libs/compressed-account/src/instruction_data/with_readonly.rs @@ -347,8 +347,14 @@ impl<'a> Deserialize<'a> for InstructionDataInvokeCpiWithReadOnly { let (input_compressed_accounts, bytes) = { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } let mut slices = Vec::with_capacity(num_slices); for _ in 0..num_slices { let (slice, _bytes) = diff --git a/program-libs/zero-copy-derive/Cargo.toml b/program-libs/zero-copy-derive/Cargo.toml new file mode 100644 index 0000000000..1cdc8254e8 --- /dev/null +++ b/program-libs/zero-copy-derive/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "light-zero-copy-derive" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" +description = "Proc macro for zero-copy deserialization" + +[features] +default = [] +mut = [] + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "2.0", features = ["full", "extra-traits"] } +lazy_static = "1.4" + +[dev-dependencies] +trybuild = "1.0" +rand = "0.8" +borsh = { workspace = true } +light-zero-copy = { workspace = true, 
features = ["std", "derive"] } +zerocopy = { workspace = true, features = ["derive"] } diff --git a/program-libs/zero-copy-derive/README.md b/program-libs/zero-copy-derive/README.md new file mode 100644 index 0000000000..8e17fbbb25 --- /dev/null +++ b/program-libs/zero-copy-derive/README.md @@ -0,0 +1,103 @@ +# Light-Zero-Copy-Derive + +A procedural macro for deriving zero-copy deserialization for Rust structs used with Solana programs. + +## Features + +This crate provides two key derive macros: + +1. `#[derive(ZeroCopy)]` - Implements zero-copy deserialization with: + - The `zero_copy_at` and `zero_copy_at_mut` methods for deserialization + - Full Borsh compatibility for serialization/deserialization + - Efficient memory representation with no copying of data + - `From>` and `FromMut>` implementations for easy conversion back to the original struct + +2. `#[derive(ZeroCopyEq)]` - Adds equality comparison support: + - Compare zero-copy instances with regular struct instances + - Can be used alongside `ZeroCopy` for complete functionality + - Derivation for Options is not robust and may not compile. + +## Rules for Zero-Copy Deserialization + +The macro follows these rules when generating code: + +1. Creates a `ZStruct` for your struct that follows zero-copy principles + 1. Fields are extracted into a meta struct until reaching a `Vec`, `Option` or non-`Copy` type + 2. Vectors are represented as `ZeroCopySlice` and not included in the meta struct + 3. Integer types are replaced with their zerocopy equivalents (e.g., `u16` → `U16`) + 4. Fields after the first vector are directly included in the `ZStruct` and deserialized one by one + 5. If a vector contains a nested vector (non-`Copy` type), it must implement `Deserialize` + 6. Elements in an `Option` must implement `Deserialize` + 7. 
Types that don't implement `Copy` must implement `Deserialize` and are deserialized one by one + +## Usage + +### Basic Usage + +```rust +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::ZeroCopy; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut}; + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} +let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, +}; +// Use the struct with zero-copy deserialization +let mut bytes = my_struct.try_to_vec().unwrap(); + +// Immutable zero-copy deserialization +let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); + +// Convert back to original struct using From implementation +let converted: MyStruct = zero_copy.clone().into(); +assert_eq!(converted, my_struct); + +// Mutable zero-copy deserialization with modification +let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut bytes).unwrap(); +zero_copy_mut.a = 42; + +// The change is reflected when we convert back to the original struct +let modified: MyStruct = zero_copy_mut.into(); +assert_eq!(modified.a, 42); + +// And also when we deserialize directly from the modified bytes +let borsh = MyStruct::try_from_slice(&bytes).unwrap(); +assert_eq!(borsh.a, 42u8); +``` + +### With Equality Comparison + +```rust +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::ZeroCopy; + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} +let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, +}; +// Use the struct with zero-copy deserialization +let mut bytes = my_struct.try_to_vec().unwrap(); +let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); +assert_eq!(zero_copy, my_struct); +``` diff --git 
a/program-libs/zero-copy-derive/src/lib.rs b/program-libs/zero-copy-derive/src/lib.rs new file mode 100644 index 0000000000..becac18087 --- /dev/null +++ b/program-libs/zero-copy-derive/src/lib.rs @@ -0,0 +1,166 @@ +//! Procedural macros for zero-copy deserialization. +//! +//! This crate provides derive macros that generate efficient zero-copy data structures +//! and deserialization code, eliminating the need for data copying during parsing. +//! +//! ## Main Macros +//! +//! - `ZeroCopy`: Generates zero-copy structs and deserialization traits +//! - `ZeroCopyMut`: Adds mutable zero-copy support +//! - `ZeroCopyEq`: Adds PartialEq implementation for comparing with original structs +//! - `ZeroCopyNew`: Generates configuration structs for initialization + +use proc_macro::TokenStream; + +mod shared; +mod zero_copy; +mod zero_copy_eq; +#[cfg(feature = "mut")] +mod zero_copy_mut; + +/// ZeroCopy derivation macro for zero-copy deserialization +/// +/// # Usage +/// +/// Basic usage: +/// ```rust +/// use light_zero_copy_derive::ZeroCopy; +/// #[derive(ZeroCopy)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +/// +/// To derive PartialEq as well, use ZeroCopyEq in addition to ZeroCopy: +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +/// +/// # Macro Rules +/// 1. Create zero copy structs Z and ZMut for the struct +/// 1.1. The first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy +/// 1.2. Represent vectors to ZeroCopySlice & don't include these into the meta struct +/// 1.3. Replace u16 with U16, u32 with U32, etc +/// 1.4. Every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +/// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +/// 1.6. 
Elements in an Option must implement Deserialize +/// 1.7. A type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 +/// 1.8. is u8 deserialized as u8::zero_copy_at instead of Ref<&'a [u8], u8> for non mut, for mut it is Ref<&'a mut [u8], u8> +/// 2. Implement Deserialize and DeserializeMut which return Z and ZMut +/// 3. Implement From> for StructName and FromMut> for StructName +/// +/// Note: Options are not supported in ZeroCopyEq +#[proc_macro_derive(ZeroCopy)] +pub fn derive_zero_copy(input: TokenStream) -> TokenStream { + let res = zero_copy::derive_zero_copy_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +/// ZeroCopyEq implementation to add PartialEq for zero-copy structs. +/// +/// Use this in addition to ZeroCopy when you want the generated struct to implement PartialEq: +/// +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[proc_macro_derive(ZeroCopyEq)] +pub fn derive_zero_copy_eq(input: TokenStream) -> TokenStream { + let res = zero_copy_eq::derive_zero_copy_eq_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +/// ZeroCopyMut derivation macro for mutable zero-copy deserialization +/// +/// This macro generates mutable zero-copy implementations including: +/// - DeserializeMut trait implementation +/// - Mutable Z-struct with `Mut` suffix +/// - byte_len() method implementation +/// - Mutable ZeroCopyStructInner implementation +/// +/// # Usage +/// +/// ```rust +/// use light_zero_copy_derive::ZeroCopyMut; +/// +/// #[derive(ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// pub vec: Vec, +/// } +/// ``` +/// +/// This will generate: +/// - `MyStruct::zero_copy_at_mut()` method +/// - `ZMyStructMut<'a>` type for mutable zero-copy access +/// - `MyStruct::byte_len()` 
method +/// +/// For both immutable and mutable functionality, use both derives: +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; +/// +/// #[derive(ZeroCopy, ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[cfg(feature = "mut")] +#[proc_macro_derive(ZeroCopyMut)] +pub fn derive_zero_copy_mut(input: TokenStream) -> TokenStream { + let res = zero_copy_mut::derive_zero_copy_mut_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +// /// ZeroCopyNew derivation macro for configuration-based zero-copy initialization +// /// +// /// This macro generates configuration structs and initialization methods for structs +// /// with Vec and Option fields that need to be initialized with specific configurations. +// /// +// /// # Usage +// /// +// /// ```ignore +// /// use light_zero_copy_derive::ZeroCopyNew; +// /// +// /// #[derive(ZeroCopyNew)] +// /// pub struct MyStruct { +// /// pub a: u8, +// /// pub vec: Vec, +// /// pub option: Option, +// /// } +// /// ``` +// /// +// /// This will generate: +// /// - `MyStructConfig` struct with configuration fields +// /// - `ZeroCopyNew` implementation for `MyStruct` +// /// - `new_zero_copy(bytes, config)` method for initialization +// /// +// /// The configuration struct will have fields based on the complexity of the original fields: +// /// - `Vec` → `field_name: u32` (length) +// /// - `Option` → `field_name: bool` (is_some) +// /// - `Vec` → `field_name: Vec` (config per element) +// /// - `Option` → `field_name: Option` (config if some) +// #[cfg(feature = "mut")] +// #[proc_macro_derive(ZeroCopyNew)] +// pub fn derive_zero_copy_config(input: TokenStream) -> TokenStream { +// let res = zero_copy_new::derive_zero_copy_config_impl(input); +// TokenStream::from(match res { +// Ok(res) => res, +// Err(err) => err.to_compile_error(), +// }) +// } diff --git a/program-libs/zero-copy-derive/src/shared/from_impl.rs 
b/program-libs/zero-copy-derive/src/shared/from_impl.rs new file mode 100644 index 0000000000..1aab0eb9b3 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/from_impl.rs @@ -0,0 +1,242 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Field, Ident}; + +use super::{ + utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generates code for the From> for StructName implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_from_impl( + name: &Ident, + z_struct_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + + // Generate the conversion code for meta fields + let meta_field_conversions = if !meta_fields.is_empty() { + let field_types = analyze_struct_fields(meta_fields)?; + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => { + quote! { #field_name: value.__meta.#field_name, } + } + _ if utils::is_specific_primitive_type(field_type, "bool") => { + quote! { #field_name: value.__meta.#field_name > 0, } + } + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { #field_name: #field_type::from(value.__meta.#field_name), } + } + } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: value.__meta.#field_name, } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: value.__meta.#field_name, } + } + _ => { + let field_name = field_type.name(); + quote! 
{ #field_name: value.__meta.#field_name.into(), } + } + } + }); + conversions.collect::>() + } else { + vec![] + }; + + // Generate the conversion code for struct fields + let struct_field_conversions = if !struct_fields.is_empty() { + let field_types = analyze_struct_fields(struct_fields)?; + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { #field_name: value.#field_name.to_vec(), } + } + FieldType::VecCopy(field_name, _) => { + quote! { #field_name: value.#field_name.to_vec(), } + } + FieldType::VecDynamicZeroCopy(field_name, _) => { + // For non-copy vectors, clone each element directly + // We need to convert into() for Zstructs + quote! { + #field_name: { + value.#field_name.iter().map(|item| (*item).clone().into()).collect() + }, + } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: *value.#field_name, } + } + FieldType::Option(field_name, field_type) => { + // Extract inner type from Option + let inner_type = utils::get_option_inner_type(field_type).expect( + "Failed to extract inner type from Option - expected Option format", + ); + let field_type = inner_type; + // For Option types, use a direct copy of the value when possible + quote! { + #field_name: if value.#field_name.is_some() { + // Create a clone of the Some value - for compressed proofs and other structs + // For instruction_data.rs, we just need to clone the value directly + Some((#field_type::from(*value.#field_name.as_ref().unwrap()).clone())) + } else { + None + }, + } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: *value.#field_name, } + } + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => { + if MUT { + quote! { #field_name: *value.#field_name, } + } else { + quote! 
{ #field_name: value.#field_name, } + } + } + _ if utils::is_specific_primitive_type(field_type, "bool") => { + if MUT { + quote! { #field_name: *value.#field_name > 0, } + } else { + quote! { #field_name: value.#field_name > 0, } + } + } + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { #field_name: #field_type::from(*value.#field_name), } + } + } + } + FieldType::Copy(field_name, _) => { + quote! { #field_name: value.#field_name, } + } + FieldType::OptionU64(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u64::from(**x)), } + } + FieldType::OptionU32(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u32::from(**x)), } + } + FieldType::OptionU16(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u16::from(**x)), } + } + FieldType::DynamicZeroCopy(field_name, field_type) => { + // For complex non-copy types, dereference and clone directly + quote! { #field_name: #field_type::from(&value.#field_name), } + } + } + }); + conversions.collect::>() + } else { + vec![] + }; + + // Combine all the field conversions + let all_field_conversions = [meta_field_conversions, struct_field_conversions].concat(); + + // Return the final From implementation without generic From implementations + let result = quote! 
{ + impl<'a> From<#z_struct_name<'a>> for #name { + fn from(value: #z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + + impl<'a> From<&#z_struct_name<'a>> for #name { + fn from(value: &#z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + }; + Ok(result) +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use syn::{parse_quote, Field}; + + use super::*; + + #[test] + fn test_generate_from_impl() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: u16); + let field_c: Field = parse_quote!(pub c: Vec); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation + let result = + generate_from_impl::(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.unwrap().to_string(); + + // Check that the implementation contains required elements + assert!(result_str.contains("impl < 'a > From < ZTestStruct < 'a >> for TestStruct")); + + // Check field handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For u16 fields + assert!(result_str.contains("c :")); // For Vec fields + } + + #[test] + fn test_generate_from_impl_mut() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: bool); + let field_c: Field = parse_quote!(pub c: Option); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation for mutable 
version + let result = + generate_from_impl::(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.unwrap().to_string(); + + // Check that the implementation contains required elements + assert!(result_str.contains("impl < 'a > From < ZTestStructMut < 'a >> for TestStruct")); + + // Check field handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For bool fields + assert!(result_str.contains("c :")); // For Option fields + } +} diff --git a/program-libs/zero-copy-derive/src/shared/meta_struct.rs b/program-libs/zero-copy-derive/src/shared/meta_struct.rs new file mode 100644 index 0000000000..0dbf9cda3a --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/meta_struct.rs @@ -0,0 +1,57 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::Field; + +use super::utils::convert_to_zerocopy_type; + +/// Generates the meta struct definition as a TokenStream +/// The `MUT` parameter determines if the struct should be generated for mutable access +pub fn generate_meta_struct( + z_struct_meta_name: &syn::Ident, + meta_fields: &[&Field], + hasher: bool, +) -> syn::Result { + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + + // Generate the meta struct fields with converted types + let meta_fields_with_converted_types = meta_fields.iter().map(|field| { + let field_name = &field.ident; + let attributes = if hasher { + field + .attrs + .iter() + .map(|attr| { + quote! { #attr } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let field_type = convert_to_zerocopy_type(&field.ty); + quote! { + #(#attributes)* + pub #field_name: #field_type + } + }); + let hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + + // Return the complete meta struct definition + let result = quote! 
{ + #[repr(C)] + #[derive(Debug, PartialEq, light_zero_copy::KnownLayout, light_zero_copy::Immutable, light_zero_copy::Unaligned, light_zero_copy::FromBytes, light_zero_copy::IntoBytes #hasher)] + pub struct #z_struct_meta_name { + #(#meta_fields_with_converted_types,)* + } + }; + Ok(result) +} diff --git a/program-libs/zero-copy-derive/src/shared/mod.rs b/program-libs/zero-copy-derive/src/shared/mod.rs new file mode 100644 index 0000000000..c7b406b530 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/mod.rs @@ -0,0 +1,6 @@ +pub mod from_impl; +pub mod meta_struct; +pub mod utils; +pub mod z_struct; +#[cfg(feature = "mut")] +pub mod zero_copy_new; diff --git a/program-libs/zero-copy-derive/src/shared/utils.rs b/program-libs/zero-copy-derive/src/shared/utils.rs new file mode 100644 index 0000000000..e92e56bb29 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/utils.rs @@ -0,0 +1,437 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Attribute, Data, DeriveInput, Field, Fields, FieldsNamed, Ident, Type, TypePath}; + +// Global cache for storing whether a struct implements Copy +lazy_static::lazy_static! 
{ + static ref COPY_IMPL_CACHE: Arc>> = Arc::new(Mutex::new(HashMap::new())); +} + +/// Creates a unique cache key for a type using span information to avoid collisions +/// between types with the same name from different modules/locations +fn create_unique_type_key(ident: &Ident) -> String { + format!("{}:{:?}", ident, ident.span()) +} + +/// Process the derive input to extract the struct information +pub fn process_input( + input: &DeriveInput, +) -> syn::Result<( + &Ident, // Original struct name + proc_macro2::Ident, // Z-struct name + proc_macro2::Ident, // Z-struct meta name + &FieldsNamed, // Struct fields +)> { + let name = &input.ident; + let z_struct_name = format_ident!("Z{}", name); + let z_struct_meta_name = format_ident!("Z{}Meta", name); + + // Populate the cache by checking if this struct implements Copy + let _ = struct_implements_copy(input); + + let fields = match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => fields, + _ => { + return Err(syn::Error::new_spanned( + &data.fields, + "ZeroCopy only supports structs with named fields", + )) + } + }, + _ => { + return Err(syn::Error::new_spanned( + input, + "ZeroCopy only supports structs", + )) + } + }; + + Ok((name, z_struct_name, z_struct_meta_name, fields)) +} + +pub fn process_fields(fields: &FieldsNamed) -> (Vec<&Field>, Vec<&Field>) { + let mut meta_fields = Vec::new(); + let mut struct_fields = Vec::new(); + let mut reached_vec_or_option = false; + + for field in fields.named.iter() { + if !reached_vec_or_option { + if is_vec_or_option(&field.ty) || !is_copy_type(&field.ty) { + reached_vec_or_option = true; + struct_fields.push(field); + } else { + meta_fields.push(field); + } + } else { + struct_fields.push(field); + } + } + + (meta_fields, struct_fields) +} + +pub fn is_vec_or_option(ty: &Type) -> bool { + is_vec_type(ty) || is_option_type(ty) +} + +pub fn is_vec_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. 
}) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Vec"; + } + } + false +} + +pub fn is_option_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Option"; + } + } + false +} + +pub fn get_vec_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + if segment.ident == "Vec" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn get_option_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + if segment.ident == "Option" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn is_primitive_integer(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + return ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "u8" + || ident == "i8"; + } + } + false +} + +pub fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "bool"; + } + } + false +} + +/// Check if a type is a specific primitive type (u8, u16, u32, u64, bool, etc.) +pub fn is_specific_primitive_type(ty: &Type, type_name: &str) -> bool { + if let Type::Path(TypePath { path, .. 
}) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == type_name; + } + } + false +} + +pub fn is_pubkey_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Pubkey"; + } + } + false +} + +pub fn convert_to_zerocopy_type(ty: &Type) -> TokenStream { + match ty { + Type::Path(TypePath { path, .. }) => { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + + // Handle primitive types first + match ident.to_string().as_str() { + "u16" => quote! { light_zero_copy::little_endian::U16 }, + "u32" => quote! { light_zero_copy::little_endian::U32 }, + "u64" => quote! { light_zero_copy::little_endian::U64 }, + "bool" => quote! { u8 }, + _ => { + // Handle container types recursively + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + let transformed_args: Vec = args + .args + .iter() + .map(|arg| { + if let syn::GenericArgument::Type(inner_type) = arg { + convert_to_zerocopy_type(inner_type) + } else { + quote! { #arg } + } + }) + .collect(); + + quote! { #ident<#(#transformed_args),*> } + } else { + quote! { #ty } + } + } + } + } else { + quote! { #ty } + } + } + _ => { + quote! { #ty } + } + } +} + +/// Checks if a struct has a derive(Copy) attribute +fn struct_has_copy_derive(attrs: &[Attribute]) -> bool { + attrs.iter().any(|attr| { + attr.path().is_ident("derive") && { + let mut found_copy = false; + // Use parse_nested_meta as the primary and only approach - it's the syn 2.0 standard + // for parsing comma-separated derive items like #[derive(Copy, Clone, Debug)] + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("Copy") { + found_copy = true; + } + Ok(()) // Continue parsing other derive items + }) + .is_ok() + && found_copy + } + }) +} + +/// Determines whether a struct implements Copy by checking for the #[derive(Copy)] attribute. +/// Results are cached for performance. 
+/// +/// In Rust, a struct can only implement Copy if: +/// 1. It explicitly has a #[derive(Copy)] attribute, AND +/// 2. All of its fields implement Copy +/// +/// The Rust compiler will enforce the second condition at compile time, so we only need to check +/// for the derive attribute here. +pub fn struct_implements_copy(input: &DeriveInput) -> bool { + let cache_key = create_unique_type_key(&input.ident); + + // Check the cache first + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&cache_key) { + return *implements_copy; + } + + // Check if the struct has a derive(Copy) attribute + let implements_copy = struct_has_copy_derive(&input.attrs); + + // Cache the result + COPY_IMPL_CACHE + .lock() + .unwrap() + .insert(cache_key, implements_copy); + + implements_copy +} + +/// Determines whether a type implements Copy +/// 1. check whether type is a primitive type that implements Copy +/// 2. check whether type is an array type (which is always Copy if the element type is Copy) +/// 3. check whether type is struct -> check in the COPY_IMPL_CACHE if we know whether it has a #[derive(Copy)] attribute +/// +/// For struct types, this relies on the cache populated by struct_implements_copy. If we don't have cached +/// information, it assumes the type does not implement Copy. This is a limitation of our approach, but it +/// works well in practice because process_input will call struct_implements_copy for all structs before +/// they might be referenced by other structs. +pub fn is_copy_type(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { path, .. 
}) => { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + + // Check if it's a primitive type that implements Copy + if ident == "u8" + || ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i8" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "bool" // bool is a Copy type + || ident == "char" + || ident == "Pubkey" + // Pubkey is hardcoded as copy type for now. + { + return true; + } + + // Check if we have cached information about this type + let cache_key = create_unique_type_key(ident); + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&cache_key) { + return *implements_copy; + } + } + } + // Handle array types (which are always Copy if the element type is Copy) + Type::Array(array) => { + // Arrays are Copy if their element type is Copy + return is_copy_type(&array.elem); + } + // For struct types not in cache, we'd need the derive input to check attributes + _ => {} + } + false +} + +#[cfg(test)] +mod tests { + use syn::parse_quote; + + use super::*; + + // Helper function to check if a struct implements Copy + fn check_struct_implements_copy(input: syn::DeriveInput) -> bool { + struct_implements_copy(&input) + } + + #[test] + fn test_struct_implements_copy() { + // Ensure the cache is cleared and the lock is released immediately + COPY_IMPL_CACHE.lock().unwrap().clear(); + // Test case 1: Empty struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct EmptyStruct {} + }; + assert!( + check_struct_implements_copy(input), + "EmptyStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 2: Simple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! 
{ + #[derive(Copy, Clone)] + struct SimpleStruct { + a: u8, + b: u16, + } + }; + assert!( + check_struct_implements_copy(input), + "SimpleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 3: Struct with #[derive(Clone)] but not Copy + let input: syn::DeriveInput = parse_quote! { + #[derive(Clone)] + struct StructWithoutCopy { + a: u8, + b: u16, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 4: Struct with a non-Copy field but with derive(Copy) + // Note: In real Rust code, this would not compile, but for our test we only check attributes + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct StructWithVec { + a: u8, + b: Vec, + } + }; + assert!( + check_struct_implements_copy(input), + "StructWithVec has #[derive(Copy)] so our function returns true" + ); + + // Test case 5: Struct with all Copy fields but without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct StructWithCopyFields { + a: u8, + b: u16, + c: i32, + d: bool, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithCopyFields should not implement Copy without #[derive(Copy)]" + ); + + // Test case 6: Unit struct without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct UnitStructWithoutCopy; + }; + assert!( + !check_struct_implements_copy(input), + "UnitStructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 7: Unit struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct UnitStructWithCopy; + }; + assert!( + check_struct_implements_copy(input), + "UnitStructWithCopy should implement Copy with #[derive(Copy)]" + ); + + // Test case 8: Tuple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! 
{ + #[derive(Copy, Clone)] + struct TupleStruct(u32, bool, char); + }; + assert!( + check_struct_implements_copy(input), + "TupleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 9: Multiple derives including Copy + let input: syn::DeriveInput = parse_quote! { + #[derive(Debug, PartialEq, Copy, Clone)] + struct MultipleDerivesStruct { + a: u8, + } + }; + assert!( + check_struct_implements_copy(input), + "MultipleDerivesStruct should implement Copy with #[derive(Copy)]" + ); + } +} diff --git a/program-libs/zero-copy-derive/src/shared/z_struct.rs b/program-libs/zero-copy-derive/src/shared/z_struct.rs new file mode 100644 index 0000000000..77aa085c49 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/z_struct.rs @@ -0,0 +1,630 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, TokenStreamExt}; +use syn::{parse_quote, Field, Ident, Type}; + +use super::utils; + +/// Enum representing the different field types for zero-copy struct +/// (Name, Type) +/// Note: Arrays with Option elements are not currently supported +#[derive(Debug)] +pub enum FieldType<'a> { + VecU8(&'a Ident), + VecCopy(&'a Ident, &'a Type), + VecDynamicZeroCopy(&'a Ident, &'a Type), + Array(&'a Ident, &'a Type), // Static arrays only - no Option elements supported + Option(&'a Ident, &'a Type), + OptionU64(&'a Ident), + OptionU32(&'a Ident), + OptionU16(&'a Ident), + Pubkey(&'a Ident), + Primitive(&'a Ident, &'a Type), + Copy(&'a Ident, &'a Type), + DynamicZeroCopy(&'a Ident, &'a Type), +} + +impl<'a> FieldType<'a> { + /// Get the name of the field + pub fn name(&self) -> &'a Ident { + match self { + FieldType::VecU8(name) => name, + FieldType::VecCopy(name, _) => name, + FieldType::VecDynamicZeroCopy(name, _) => name, + FieldType::Array(name, _) => name, + FieldType::Option(name, _) => name, + FieldType::OptionU64(name) => name, + FieldType::OptionU32(name) => name, + FieldType::OptionU16(name) => name, + FieldType::Pubkey(name) => name, + 
FieldType::Primitive(name, _) => name, + FieldType::Copy(name, _) => name, + FieldType::DynamicZeroCopy(name, _) => name, + } + } +} + +/// Classify a Vec type based on its inner type +fn classify_vec_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, + inner_type: &'a Type, +) -> FieldType<'a> { + if utils::is_specific_primitive_type(inner_type, "u8") { + FieldType::VecU8(field_name) + } else if utils::is_copy_type(inner_type) { + FieldType::VecCopy(field_name, inner_type) + } else { + FieldType::VecDynamicZeroCopy(field_name, field_type) + } +} + +/// Classify an Option type based on its inner type +fn classify_option_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, + inner_type: &'a Type, +) -> FieldType<'a> { + if utils::is_primitive_integer(inner_type) { + match () { + _ if utils::is_specific_primitive_type(inner_type, "u64") => { + FieldType::OptionU64(field_name) + } + _ if utils::is_specific_primitive_type(inner_type, "u32") => { + FieldType::OptionU32(field_name) + } + _ if utils::is_specific_primitive_type(inner_type, "u16") => { + FieldType::OptionU16(field_name) + } + _ => FieldType::Option(field_name, field_type), + } + } else { + FieldType::Option(field_name, field_type) + } +} + +/// Classify a primitive integer type +fn classify_integer_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, +) -> syn::Result> { + match () { + _ if utils::is_specific_primitive_type(field_type, "u64") + | utils::is_specific_primitive_type(field_type, "u32") + | utils::is_specific_primitive_type(field_type, "u16") + | utils::is_specific_primitive_type(field_type, "u8") => + { + Ok(FieldType::Primitive(field_name, field_type)) + } + _ => Err(syn::Error::new_spanned( + field_type, + "Unsupported integer type. 
Only u8, u16, u32, and u64 are supported", + )), + } +} + +/// Classify a Copy type +fn classify_copy_type<'a>(field_name: &'a Ident, field_type: &'a Type) -> FieldType<'a> { + if utils::is_specific_primitive_type(field_type, "u8") + || utils::is_specific_primitive_type(field_type, "bool") + { + FieldType::Primitive(field_name, field_type) + } else { + FieldType::Copy(field_name, field_type) + } +} + +/// Classify a single field into its FieldType +fn classify_field<'a>(field_name: &'a Ident, field_type: &'a Type) -> syn::Result> { + // Vec types + if utils::is_vec_type(field_type) { + return match utils::get_vec_inner_type(field_type) { + Some(inner_type) => Ok(classify_vec_type(field_name, field_type, inner_type)), + None => Err(syn::Error::new_spanned( + field_type, + "Could not determine inner type of Vec", + )), + }; + } + + // Array types + if let Type::Array(_) = field_type { + return Ok(FieldType::Array(field_name, field_type)); + } + + // Option types + if utils::is_option_type(field_type) { + return match utils::get_option_inner_type(field_type) { + Some(inner_type) => Ok(classify_option_type(field_name, field_type, inner_type)), + None => Ok(FieldType::Option(field_name, field_type)), + }; + } + + // Simple type dispatch + match () { + _ if utils::is_pubkey_type(field_type) => Ok(FieldType::Pubkey(field_name)), + _ if utils::is_bool_type(field_type) => Ok(FieldType::Primitive(field_name, field_type)), + _ if utils::is_primitive_integer(field_type) => { + classify_integer_type(field_name, field_type) + } + _ if utils::is_copy_type(field_type) => Ok(classify_copy_type(field_name, field_type)), + _ => Ok(FieldType::DynamicZeroCopy(field_name, field_type)), + } +} + +/// Analyze struct fields and return vector of FieldType enums +pub fn analyze_struct_fields<'a>( + struct_fields: &'a [&'a Field], +) -> syn::Result>> { + struct_fields + .iter() + .map(|field| { + let field_name = field + .ident + .as_ref() + .ok_or_else(|| syn::Error::new_spanned(field, 
"Field must have a name"))?; + classify_field(field_name, &field.ty) + }) + .collect() +} + +/// Generate struct fields with zerocopy types based on field type enum +fn generate_struct_fields_with_zerocopy_types<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], + hasher: &'a bool, +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + let iterator = field_types + .into_iter() + .zip(struct_fields.iter()) + .map(|(field_type, field)| { + let attributes = if *hasher { + field + .attrs + .iter() + .map(|attr| { + quote! { #attr } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let (mutability, import_path, import_slice, camel_case_suffix): ( + syn::Type, + syn::Ident, + syn::Ident, + String, + ) = if MUT { + ( + parse_quote!(&'a mut [u8]), + format_ident!("borsh_mut"), + format_ident!("slice_mut"), + String::from("Mut"), + ) + } else { + ( + parse_quote!(&'a [u8]), + format_ident!("borsh"), + format_ident!("slice"), + String::new(), + ) + }; + let deserialize_ident = format_ident!("Deserialize{}", camel_case_suffix); + let trait_name: syn::Type = parse_quote!(light_zero_copy::#import_path::#deserialize_ident); + let slice_ident = format_ident!("ZeroCopySlice{}Borsh", camel_case_suffix); + let slice_name: syn::Type = parse_quote!(light_zero_copy::#import_slice::#slice_ident); + let struct_inner_ident = format_ident!("ZeroCopyStructInner{}", camel_case_suffix); + let inner_ident = format_ident!("ZeroCopyInner{}", camel_case_suffix); + let struct_inner_trait_name: syn::Type = parse_quote!(light_zero_copy::#import_path::#struct_inner_ident::#inner_ident); + match field_type { + FieldType::VecU8(field_name) => { + quote! 
{ + #(#attributes)* + pub #field_name: #mutability + } + } + FieldType::VecCopy(field_name, inner_type) => { + // For primitive Copy types, use the zerocopy converted type directly + // For complex Copy types, use the ZeroCopyStructInner trait + if utils::is_primitive_integer(inner_type) || utils::is_bool_type(inner_type) || utils::is_pubkey_type(inner_type) { + let zerocopy_type = utils::convert_to_zerocopy_type(inner_type); + quote! { + #(#attributes)* + pub #field_name: #slice_name<'a, #zerocopy_type> + } + } else { + let inner_type = utils::convert_to_zerocopy_type(inner_type); + quote! { + #(#attributes)* + pub #field_name: #slice_name<'a, <#inner_type as #struct_inner_trait_name>> + } + } + } + FieldType::VecDynamicZeroCopy(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::Array(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #field_type> + } + } + FieldType::Option(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::Pubkey(field_name) => { + quote! 
{ + #(#attributes)* + pub #field_name: >::Output + } + } + FieldType::Primitive(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + // FieldType::Bool(field_name) => { + // quote! { + // #(#attributes)* + // pub #field_name: >::Output + // } + // } + FieldType::Copy(field_name, field_type) => { + let zerocopy_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #zerocopy_type> + } + } + FieldType::DynamicZeroCopy(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + } + }); + Ok(iterator) +} + +/// Generate accessor methods for boolean fields in struct_fields. +/// We need accessors because booleans are stored as u8. +fn generate_bool_accessor_methods<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + struct_fields.iter().filter_map(|field| { + let field_name = &field.ident; + let field_type = &field.ty; + + if utils::is_bool_type(field_type) { + let comparison = if MUT { + quote! { *self.#field_name > 0 } + } else { + quote! { self.#field_name > 0 } + }; + + Some(quote! { + pub fn #field_name(&self) -> bool { + #comparison + } + }) + } else { + None + } + }) +} + +/// Generates the ZStruct definition as a TokenStream +pub fn generate_z_struct( + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_fields: &[&Field], + hasher: bool, +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + let mutability: syn::Type = if MUT { + parse_quote!(&'a mut [u8]) + } else { + parse_quote!(&'a [u8]) + }; + + let derive_clone = if MUT { + quote! {} + } else { + quote! 
{, Clone } + }; + let struct_fields_with_zerocopy_types: Vec = + generate_struct_fields_with_zerocopy_types::(struct_fields, &hasher)?.collect(); + + let derive_hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + let hasher_flatten = if hasher { + quote! { + #[flatten] + } + } else { + quote! {} + }; + + let partial_eq_derive = if MUT { quote!() } else { quote!(, PartialEq) }; + + let mut z_struct = if meta_fields.is_empty() { + quote! { + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #(#struct_fields_with_zerocopy_types,)* + } + } + } else { + let mut tokens = quote! { + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #hasher_flatten + __meta: light_zero_copy::Ref<#mutability, #z_struct_meta_name>, + #(#struct_fields_with_zerocopy_types,)* + } + impl<'a> core::ops::Deref for #z_struct_name<'a> { + type Target = light_zero_copy::Ref<#mutability , #z_struct_meta_name>; + + fn deref(&self) -> &Self::Target { + &self.__meta + } + } + }; + + if MUT { + tokens.append_all(quote! { + impl<'a> core::ops::DerefMut for #z_struct_name<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.__meta + } + } + }); + } + tokens + }; + + if !meta_fields.is_empty() { + let meta_bool_accessor_methods = generate_bool_accessor_methods::(meta_fields); + z_struct.append_all(quote! { + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#meta_bool_accessor_methods)* + } + }) + }; + + if !struct_fields.is_empty() { + let bool_accessor_methods = generate_bool_accessor_methods::(struct_fields); + z_struct.append_all(quote! 
{ + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#bool_accessor_methods)* + } + + }); + } + Ok(z_struct) +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; + use syn::parse_quote; + + use super::*; + + /// Generate a safe field name for testing + fn random_ident(rng: &mut StdRng) -> String { + // Use predetermined safe field names + const FIELD_NAMES: &[&str] = &[ + "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", + "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", + "status", + ]; + + FIELD_NAMES.choose(rng).unwrap().to_string() + } + + /// Generate a random Rust type + fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { + // Define our available types + let types = [0, 1, 2, 3, 4, 5, 6, 7]; + + // Randomly select a type index + let selected = *types.choose(rng).unwrap(); + + // Return the corresponding type + match selected { + 0 => parse_quote!(u8), + 1 => parse_quote!(u16), + 2 => parse_quote!(u32), + 3 => parse_quote!(u64), + 4 => parse_quote!(bool), + 5 => parse_quote!(Vec), + 6 => parse_quote!(Vec), + 7 => parse_quote!(Vec), + _ => unreachable!(), + } + } + + /// Generate a random field + fn random_field(rng: &mut StdRng) -> Field { + let name = random_ident(rng); + let ty = random_type(rng, 0); + + // Use a safer approach to create the field + let name_ident = format_ident!("{}", name); + parse_quote!(pub #name_ident: #ty) + } + + /// Generate a list of random fields + fn random_fields(rng: &mut StdRng, count: usize) -> Vec { + (0..count).map(|_| random_field(rng)).collect() + } + + #[test] + fn test_fuzz_generate_z_struct() { + // Set up RNG with a seed for reproducibility + let seed = thread_rng().gen(); + println!("seed {}", seed); + let mut rng = StdRng::seed_from_u64(seed); + + // Now that the test is working, run with 10,000 iterations + let num_iters = 10000; 
+ + for i in 0..num_iters { + // Generate a random struct name + let struct_name = format_ident!("{}", random_ident(&mut rng)); + let z_struct_name = format_ident!("Z{}", struct_name); + let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + + // Generate random number of fields (1-10) + let field_count = rng.gen_range(1..11); + let fields = random_fields(&mut rng, field_count); + + // Create a named fields collection that lives longer than the process_fields call + let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); + let fields_named = syn::FieldsNamed { + brace_token: syn::token::Brace::default(), + named: syn_fields, + }; + + // Split into meta fields and struct fields + let (meta_fields, struct_fields) = crate::shared::utils::process_fields(&fields_named); + + // Call the function we're testing + let result = generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + false, + ); + + // Get the generated code as a string for validation + let result_str = result.unwrap().to_string(); + + // Validate the generated code + + // Verify the result contains expected struct elements + // Basic validation - must be non-empty + assert!( + !result_str.is_empty(), + "Failed to generate TokenStream for iteration {}", + i + ); + + // Validate that the generated code contains the expected struct definition + let struct_pattern = format!("struct {} < 'a >", z_struct_name); + assert!( + result_str.contains(&struct_pattern), + "Generated code missing struct definition for iteration {}. 
Expected: {}", + i, + struct_pattern + ); + + if meta_fields.is_empty() { + // Validate the meta field is present + assert!( + !result_str.contains("meta :"), + "Generated code had meta field for iteration {}", + i + ); + // Validate Deref implementation + assert!( + !result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code missing Deref implementation for iteration {}", + i + ); + } else { + // Validate the meta field is present + assert!( + result_str.contains("meta :"), + "Generated code missing meta field for iteration {}", + i + ); + // Validate Deref implementation + assert!( + result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code missing Deref implementation for iteration {}", + i + ); + // Validate Target type + assert!( + result_str.contains("type Target"), + "Generated code missing Target type for iteration {}", + i + ); + // Check that the deref method is implemented + assert!( + result_str.contains("fn deref (& self)"), + "Generated code missing deref method for iteration {}", + i + ); + + // Check for light_zero_copy::Ref reference + assert!( + result_str.contains("light_zero_copy :: Ref"), + "Generated code missing light_zero_copy::Ref for iteration {}", + i + ); + } + + // Make sure derive attributes are present + assert!( + result_str.contains("# [derive (Debug , PartialEq , Clone)]"), + "Generated code missing derive attributes for iteration {}", + i + ); + } + } +} diff --git a/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs b/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs new file mode 100644 index 0000000000..495977cbf0 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs @@ -0,0 +1,391 @@ +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::Ident; + +use crate::shared::{ + utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generate ZeroCopyNew implementation with new_at method for a struct +pub fn generate_init_mut_impl( + 
struct_name: &syn::Ident, + meta_fields: &[&syn::Field], + struct_fields: &[&syn::Field], +) -> syn::Result { + let config_name = quote::format_ident!("{}Config", struct_name); + let z_meta_name = quote::format_ident!("Z{}MetaMut", struct_name); + let z_struct_mut_name = quote::format_ident!("Z{}Mut", struct_name); + + // Use the pre-separated fields from utils::process_fields (consistent with other derives) + let struct_field_types = analyze_struct_fields(struct_fields)?; + + // Generate field initialization code for struct fields only (meta fields are part of __meta) + let field_initializations: Result, syn::Error> = + struct_field_types + .iter() + .map(|field_type| generate_field_initialization(field_type)) + .collect(); + let field_initializations = field_initializations?; + + // Generate struct construction - only include struct fields that were initialized + // Meta fields are accessed via __meta.field_name in the generated ZStruct + let struct_field_names: Vec = struct_field_types + .iter() + .map(|field_type| { + let field_name = field_type.name(); + quote! { #field_name, } + }) + .collect(); + + // Check if there are meta fields to determine whether to include __meta + let has_meta_fields = !meta_fields.is_empty(); + + let meta_initialization = if has_meta_fields { + quote! { + // Handle the meta struct (fixed-size fields at the beginning) + let (__meta, bytes) = Ref::<&mut [u8], #z_meta_name>::from_prefix(bytes)?; + } + } else { + quote! { + // No meta fields, skip meta struct initialization + } + }; + + let struct_construction = if has_meta_fields { + quote! { + let result = #z_struct_mut_name { + __meta, + #(#struct_field_names)* + }; + } + } else { + quote! 
{ + let result = #z_struct_mut_name { + #(#struct_field_names)* + }; + } + }; + + // Generate byte_len calculation for each field type + let byte_len_calculations: Result, syn::Error> = + struct_field_types + .iter() + .map(|field_type| generate_byte_len_calculation(field_type)) + .collect(); + let byte_len_calculations = byte_len_calculations?; + + // Calculate meta size if there are meta fields + let meta_size_calculation = if has_meta_fields { + quote! { + core::mem::size_of::<#z_meta_name>() + } + } else { + quote! { 0 } + }; + + let result = quote! { + impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for #struct_name { + type Config = #config_name; + type Output = >::Output; + + fn byte_len(config: &Self::Config) -> usize { + #meta_size_calculation #(+ #byte_len_calculations)* + } + + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { + use zerocopy::Ref; + + #meta_initialization + + #(#field_initializations)* + + #struct_construction + + Ok((result, bytes)) + } + } + }; + Ok(result) +} + +// Configuration system functions moved from config.rs + +/// Determine if this field type requires configuration for initialization +pub fn requires_config(field_type: &FieldType) -> bool { + match field_type { + // Vec types always need length configuration + FieldType::VecU8(_) | FieldType::VecCopy(_, _) | FieldType::VecDynamicZeroCopy(_, _) => { + true + } + // Option types need Some/None configuration + FieldType::Option(_, _) => true, + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::Primitive(_, _) + | FieldType::Copy(_, _) => false, + // DynamicZeroCopy types might need configuration if they contain Vec/Option + FieldType::DynamicZeroCopy(_, _) => true, // Conservative: assume they need config + // Option integer types need config to determine if they're enabled + FieldType::OptionU64(_) | 
FieldType::OptionU32(_) | FieldType::OptionU16(_) => true, + } +} + +/// Generate the config type for this field +pub fn config_type(field_type: &FieldType) -> syn::Result { + let result = match field_type { + // Simple Vec types: just need length + FieldType::VecU8(_) => quote! { u32 }, + FieldType::VecCopy(_, _) => quote! { u32 }, + + // Complex Vec types: need config for each element + FieldType::VecDynamicZeroCopy(_, vec_type) => { + if let Some(inner_type) = utils::get_vec_inner_type(vec_type) { + quote! { Vec<<#inner_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config> } + } else { + return Err(syn::Error::new_spanned( + vec_type, + "Could not determine inner type for VecDynamicZeroCopy config", + )); + } + } + + // Option types: delegate to the Option's Config type + FieldType::Option(_, option_type) => { + quote! { <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config } + } + + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::Primitive(_, _) + | FieldType::Copy(_, _) => quote! { () }, + + // Option integer types: use bool config to determine if enabled + FieldType::OptionU64(_) | FieldType::OptionU32(_) | FieldType::OptionU16(_) => { + quote! { bool } + } + + // DynamicZeroCopy types: delegate to their Config type (Config is typically 'static) + FieldType::DynamicZeroCopy(_, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! 
{ <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config } + } + }; + Ok(result) +} + +/// Generate a configuration struct for a given struct +pub fn generate_config_struct( + struct_name: &Ident, + field_types: &[FieldType], +) -> syn::Result<TokenStream> { + let config_name = quote::format_ident!("{}Config", struct_name); + + // Generate config fields only for fields that require configuration + let config_fields: Result<Vec<TokenStream>, syn::Error> = field_types + .iter() + .filter(|field_type| requires_config(field_type)) + .map(|field_type| -> syn::Result<TokenStream> { + let field_name = field_type.name(); + let config_type = config_type(field_type)?; + Ok(quote! { + pub #field_name: #config_type, + }) + }) + .collect(); + let config_fields = config_fields?; + + let result = if config_fields.is_empty() { + // If no fields require configuration, create an empty config struct + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name; + } + } else { + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name { + #(#config_fields)* + } + } + }; + Ok(result) +} + +/// Generate initialization logic for a field based on its configuration +pub fn generate_field_initialization(field_type: &FieldType) -> syn::Result<TokenStream> { + let result = match field_type { + FieldType::VecU8(field_name) => { + quote! { + // Initialize the length prefix but don't use the returned ZeroCopySliceMut + { + light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<u8>::new_at( + config.#field_name.into(), + bytes + )?; + } + // Split off the length prefix (4 bytes) and get the slice + let (_, bytes) = bytes.split_at_mut(4); + let (#field_name, bytes) = bytes.split_at_mut(config.#field_name as usize); + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote!
{ + let (#field_name, bytes) = light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<#inner_type>::new_at( + config.#field_name.into(), + bytes + )?; + } + } + + FieldType::VecDynamicZeroCopy(field_name, vec_type) + | FieldType::DynamicZeroCopy(field_name, vec_type) + | FieldType::Option(field_name, vec_type) => { + quote! { + let (#field_name, bytes) = <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'a>>::new_zero_copy( + bytes, + config.#field_name + )?; + } + } + + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + let option_type = match field_type { + FieldType::OptionU64(_) => quote! { Option<u64> }, + FieldType::OptionU32(_) => quote! { Option<u32> }, + FieldType::OptionU16(_) => quote! { Option<u16> }, + _ => unreachable!(), + }; + quote! { + let (#field_name, bytes) = <#option_type as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( + bytes, + (config.#field_name, ()) + )?; + } + } + + // Fixed-size types that are struct fields (not meta fields) need initialization with () config + FieldType::Primitive(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut(bytes)?; + } + } + + // Array fields that are struct fields (come after Vec/Option) + FieldType::Array(field_name, array_type) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + #array_type + >::from_prefix(bytes)?; + } + } + + FieldType::Pubkey(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + Pubkey + >::from_prefix(bytes)?; + } + } + + FieldType::Copy(field_name, field_type) => { + quote!
{ + let (#field_name, bytes) = <#field_type as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy(bytes)?; + } + } + }; + Ok(result) +} + +/// Generate byte length calculation for a field based on its configuration
pub fn generate_byte_len_calculation(field_type: &FieldType) -> syn::Result<TokenStream> { + let result = match field_type { + // Vec types that require configuration + FieldType::VecU8(field_name) => { + quote! { + (4 + config.#field_name as usize) // 4 bytes for length + actual data + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote! { + (4 + (config.#field_name as usize * core::mem::size_of::<#inner_type>())) + } + } + + FieldType::VecDynamicZeroCopy(field_name, vec_type) => { + quote! { + <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + // Option types + FieldType::Option(field_name, option_type) => { + quote! { + <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + FieldType::OptionU64(field_name) => { + quote! { + <Option<u64> as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU32(field_name) => { + quote! { + <Option<u32> as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU16(field_name) => { + quote! { + <Option<u16> as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + // Fixed-size types don't need configuration and have known sizes + FieldType::Primitive(_, field_type) => { + let zerocopy_type = utils::convert_to_zerocopy_type(field_type); + quote! { + core::mem::size_of::<#zerocopy_type>() + } + } + + FieldType::Array(_, array_type) => { + quote! { + core::mem::size_of::<#array_type>() + } + } + + FieldType::Pubkey(_) => { + quote!
{ + 32 // Pubkey is always 32 bytes + } + } + + // Meta field types (should not appear in struct fields, but handle gracefully) + FieldType::Copy(_, field_type) => { + quote! { + core::mem::size_of::<#field_type>() + } + } + + FieldType::DynamicZeroCopy(field_name, field_type) => { + quote! { + <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + }; + Ok(result) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy.rs b/program-libs/zero-copy-derive/src/zero_copy.rs new file mode 100644 index 0000000000..bbae45e207 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy.rs @@ -0,0 +1,636 @@ +use proc_macro::TokenStream as ProcTokenStream; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_quote, DeriveInput, Field, Ident}; + +use crate::shared::{ + meta_struct, utils, + z_struct::{analyze_struct_fields, generate_z_struct, FieldType}, +}; + +/// Helper function to generate deserialize call pattern for a given type +fn generate_deserialize_call<const MUT: bool>( + field_name: &syn::Ident, + field_type: &syn::Type, +) -> TokenStream { + let field_type = utils::convert_to_zerocopy_type(field_type); + let trait_path = if MUT { + quote!( as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut) + } else { + quote!( as light_zero_copy::borsh::Deserialize>::zero_copy_at) + }; + + quote!
{ + let (#field_name, bytes) = <#field_type #trait_path(bytes)?; + } +} + +/// Generates field deserialization code for the Deserialize implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_fields<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> syn::Result<impl Iterator<Item = TokenStream> + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + + let iterator = field_types.into_iter().map(move |field_type| { + let mutability_tokens = if MUT { + quote!(&'a mut [u8]) + } else { + quote!(&'a [u8]) + }; + match field_type { + FieldType::VecU8(field_name) => { + if MUT { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh_mut::borsh_vec_u8_as_slice_mut(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh::borsh_vec_u8_as_slice(bytes)?; + } + } + }, + FieldType::VecCopy(field_name, inner_type) => { + let inner_type = utils::convert_to_zerocopy_type(inner_type); + + let trait_path = if MUT { + quote!(light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<'a, <#inner_type as light_zero_copy::borsh_mut::ZeroCopyStructInnerMut>::ZeroCopyInnerMut>) + } else { + quote!(light_zero_copy::slice::ZeroCopySliceBorsh::<'a, <#inner_type as light_zero_copy::borsh::ZeroCopyStructInner>::ZeroCopyInner>) + }; + quote! { + let (#field_name, bytes) = #trait_path::from_bytes_at(bytes)?; + } + }, + FieldType::VecDynamicZeroCopy(field_name, field_type) => { + generate_deserialize_call::<MUT>(field_name, field_type) + }, + FieldType::Array(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote!
{ + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_type>::from_prefix(bytes)?; + } + }, + FieldType::Option(field_name, field_type) => { + generate_deserialize_call::<MUT>(field_name, field_type) + }, + FieldType::Pubkey(field_name) => { + generate_deserialize_call::<MUT>(field_name, &parse_quote!(Pubkey)) + }, + FieldType::Primitive(field_name, field_type) => { + if MUT { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh::Deserialize>::zero_copy_at(bytes)?; + } + } + }, + FieldType::Copy(field_name, field_type) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(field_type); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_ty_zerocopy>::from_prefix(bytes)?; + } + }, + FieldType::DynamicZeroCopy(field_name, field_type) => { + generate_deserialize_call::<MUT>(field_name, field_type) + }, + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + generate_deserialize_call::<MUT>(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + }, + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + generate_deserialize_call::<MUT>(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + }, + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + generate_deserialize_call::<MUT>(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + } + } + }); + Ok(iterator) +} + +/// Generates field initialization code for the Deserialize implementation +pub fn generate_init_fields<'a>( + struct_fields: &'a [&'a Field], +) -> impl Iterator<Item = TokenStream> + 'a { + struct_fields.iter().map(|field| { + let field_name = &field.ident; + quote!
{ #field_name } + }) +} + +/// Generates the Deserialize implementation as a TokenStream +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_impl( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_is_empty: bool, + byte_len_impl: TokenStream, +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + + // Define trait and types based on mutability + let (trait_name, mutability, method_name) = if MUT { + ( + quote!(light_zero_copy::borsh_mut::DeserializeMut), + quote!(mut), + quote!(zero_copy_at_mut), + ) + } else { + ( + quote!(light_zero_copy::borsh::Deserialize), + quote!(), + quote!(zero_copy_at), + ) + }; + let (meta_des, meta) = if meta_is_empty { + (quote!(), quote!()) + } else { + ( + quote! { + let (__meta, bytes) = light_zero_copy::Ref::< &'a #mutability [u8], #z_struct_meta_name>::from_prefix(bytes)?; + }, + quote!(__meta,), + ) + }; + let deserialize_fields = generate_deserialize_fields::(struct_fields)?; + let init_fields = generate_init_fields(struct_fields); + + let result = quote! 
{ + impl<'a> #trait_name<'a> for #name { + type Output = #z_struct_name<'a>; + + fn #method_name(bytes: &'a #mutability [u8]) -> Result<(Self::Output, &'a #mutability [u8]), light_zero_copy::errors::ZeroCopyError> { + #meta_des + #(#deserialize_fields)* + Ok(( + #z_struct_name { + #meta + #(#init_fields,)* + }, + bytes + )) + } + + #byte_len_impl + } + }; + Ok(result) +} + +// #[cfg(test)] +// mod tests { +// use quote::format_ident; +// use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; +// use syn::parse_quote; + +// use super::*; + +// /// Generate a safe field name for testing +// fn random_ident(rng: &mut StdRng) -> String { +// // Use predetermined safe field names +// const FIELD_NAMES: &[&str] = &[ +// "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", +// "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", +// "status", +// ]; + +// FIELD_NAMES.choose(rng).unwrap().to_string() +// } + +// /// Generate a random Rust type +// fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { +// // Define our available types +// let types = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + +// // Randomly select a type index +// let selected = *types.choose(rng).unwrap(); + +// // Return the corresponding type +// match selected { +// 0 => parse_quote!(u8), +// 1 => parse_quote!(u16), +// 2 => parse_quote!(u32), +// 3 => parse_quote!(u64), +// 4 => parse_quote!(bool), +// 5 => parse_quote!(Vec), +// 6 => parse_quote!(Vec), +// 7 => parse_quote!(Vec), +// 8 => parse_quote!([u32; 12]), +// 9 => parse_quote!([Vec; 12]), +// 10 => parse_quote!([Vec; 20]), +// _ => unreachable!(), +// } +// } + +// /// Generate a random field +// fn random_field(rng: &mut StdRng) -> Field { +// let name = random_ident(rng); +// let ty = random_type(rng, 0); + +// // Use a safer approach to create the field +// let name_ident = format_ident!("{}", name); +// parse_quote!(pub #name_ident: #ty) +// } + +// 
/// Generate a list of random fields +// fn random_fields(rng: &mut StdRng, count: usize) -> Vec { +// (0..count).map(|_| random_field(rng)).collect() +// } + +// // Test for Vec field deserialization +// #[test] +// fn test_deserialize_vec_u8() { +// let field: Field = parse_quote!(pub data: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = +// "let (data , bytes) = light_zero_copy :: borsh :: borsh_vec_u8_as_slice (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Vec with Copy inner type deserialization +// #[test] +// fn test_deserialize_vec_copy_type() { +// let field: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u32 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Vec with non-Copy inner type deserialization +// #[test] +// fn test_deserialize_vec_non_copy_type() { +// // This is a synthetic test as we're treating String as a non-Copy type +// let field: Field = parse_quote!(pub names: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (names , bytes) = < Vec < String > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Option type deserialization +// #[test] +// fn test_deserialize_option_type() { +// let field: Field = parse_quote!(pub maybe_value: Option); +// let struct_fields = vec![&field]; 
+ +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (maybe_value , bytes) = < Option < u32 > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for non-Copy type deserialization +// #[test] +// fn test_deserialize_non_copy_type() { +// // Using String as a non-Copy type example +// let field: Field = parse_quote!(pub name: String); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (name , bytes) = < String as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Copy type deserialization (primitive types) +// #[test] +// fn test_deserialize_copy_type() { +// let field: Field = parse_quote!(pub count: u32); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (count , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?"; +// println!("{}", result_str); +// assert!(result_str.contains(expected)); +// } + +// // Test for boolean type deserialization +// #[test] +// fn test_deserialize_bool_type() { +// let field: Field = parse_quote!(pub flag: bool); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = +// "let (flag , bytes) = < u8 as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; +// println!("{}", result_str); +// assert!(result_str.contains(expected)); +// } + +// // Test for field initialization code generation +// #[test] +// 
fn test_init_fields() { +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub name: String); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_init_fields(&struct_fields).collect::>(); +// let result_str = format!("{} {}", result[0], result[1]); +// assert!(result_str.contains("id")); +// assert!(result_str.contains("name")); +// } + +// // Test for complete deserialize implementation generation +// #[test] +// fn test_generate_deserialize_impl() { +// let struct_name = format_ident!("TestStruct"); +// let z_struct_name = format_ident!("ZTestStruct"); +// let z_struct_meta_name = format_ident!("ZTestStructMeta"); + +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// false, +// ) +// .to_string(); + +// // Check impl header +// assert!(result +// .contains("impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for TestStruct")); + +// // Check Output type +// assert!(result.contains("type Output = ZTestStruct < 'a >")); + +// // Check method signature +// assert!(result.contains("fn zero_copy_at (bytes : & 'a [u8]) -> Result")); + +// // Check meta field extraction +// assert!(result.contains("let (__meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , ZTestStructMeta > :: from_prefix (bytes) ?")); + +// // Check field deserialization +// assert!(result.contains("let (id , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?")); +// assert!(result.contains("let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u16 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?")); + +// // Check result structure +// 
assert!(result.contains("Ok ((ZTestStruct { __meta , id , values ,")); +// } + +// // Test for complete deserialize implementation generation +// #[test] +// fn test_generate_deserialize_impl_no_meta() { +// let struct_name = format_ident!("TestStruct"); +// let z_struct_name = format_ident!("ZTestStruct"); +// let z_struct_meta_name = format_ident!("ZTestStructMeta"); + +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// true, +// ) +// .to_string(); + +// // Check impl header +// assert!(result +// .contains("impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for TestStruct")); + +// // Check Output type +// assert!(result.contains("type Output = ZTestStruct < 'a >")); + +// // Check method signature +// assert!(result.contains("fn zero_copy_at (bytes : & 'a [u8]) -> Result")); + +// // Check meta field extraction +// assert!(!result.contains("let (meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , ZTestStructMeta > :: from_prefix (bytes) ?")); + +// // Check field deserialization +// assert!(result.contains("let (id , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?")); +// assert!(result.contains("let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u16 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?")); + +// // Check result structure +// assert!(result.contains("Ok ((ZTestStruct { id , values ,")); +// } + +// #[test] +// fn test_fuzz_generate_deserialize_impl() { +// // Set up RNG with a seed for reproducibility +// let seed = thread_rng().gen(); +// println!("seed {}", seed); +// let mut rng = StdRng::seed_from_u64(seed); + +// // Number of iterations for the test +// 
let num_iters = 10000; + +// for i in 0..num_iters { +// // Generate a random struct name +// let struct_name = format_ident!("{}", random_ident(&mut rng)); +// let z_struct_name = format_ident!("Z{}", struct_name); +// let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + +// // Generate random number of fields (1-10) +// let field_count = rng.gen_range(1..11); +// let fields = random_fields(&mut rng, field_count); + +// // Create a named fields collection +// let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); +// let fields_named = syn::FieldsNamed { +// brace_token: syn::token::Brace::default(), +// named: syn_fields, +// }; + +// // Split into meta fields and struct fields +// let (_, struct_fields) = crate::utils::process_fields(&fields_named); + +// // Call the function we're testing +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// false, +// ); + +// // Get the generated code as a string for validation +// let result_str = result.to_string(); + +// // Print the first result for debugging +// if i == 0 { +// println!("Generated deserialize_impl code format:\n{}", result_str); +// } + +// // Verify the result contains expected elements +// // Basic validation - must be non-empty +// assert!( +// !result_str.is_empty(), +// "Failed to generate TokenStream for iteration {}", +// i +// ); + +// // Validate that the generated code contains the expected impl definition +// let impl_pattern = format!( +// "impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for {}", +// struct_name +// ); +// assert!( +// result_str.contains(&impl_pattern), +// "Generated code missing impl definition for iteration {}. 
Expected: {}", +// i, +// impl_pattern +// ); + +// // Validate type Output is defined +// let output_pattern = format!("type Output = {} < 'a >", z_struct_name); +// assert!( +// result_str.contains(&output_pattern), +// "Generated code missing Output type for iteration {}. Expected: {}", +// i, +// output_pattern +// ); + +// // Validate the zero_copy_at method is present +// assert!( +// result_str.contains("fn zero_copy_at (bytes : & 'a [u8])"), +// "Generated code missing zero_copy_at method for iteration {}", +// i +// ); + +// // Check for meta field extraction +// let meta_extraction_pattern = format!( +// "let (__meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , {} > :: from_prefix (bytes) ?", +// z_struct_meta_name +// ); +// assert!( +// result_str.contains(&meta_extraction_pattern), +// "Generated code missing meta field extraction for iteration {}", +// i +// ); + +// // Check for return with Ok pattern +// assert!( +// result_str.contains("Ok (("), +// "Generated code missing Ok return statement for iteration {}", +// i +// ); + +// // Check for the struct initialization +// let struct_init_pattern = format!("{} {{", z_struct_name); +// assert!( +// result_str.contains(&struct_init_pattern), +// "Generated code missing struct initialization for iteration {}", +// i +// ); + +// // Check for meta field in the returned struct +// assert!( +// result_str.contains("__meta ,"), +// "Generated code missing meta field in struct initialization for iteration {}", +// i +// ); +// } +// } +// } + +/// Generates the ZeroCopyStructInner implementation as a TokenStream +pub fn generate_zero_copy_struct_inner( + name: &Ident, + z_struct_name: &Ident, +) -> syn::Result { + let result = if MUT { + quote! { + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh_mut::ZeroCopyStructInnerMut for #name { + type ZeroCopyInnerMut = #z_struct_name<'static>; + } + } + } else { + quote! 
{ + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh::ZeroCopyStructInner for #name { + type ZeroCopyInner = #z_struct_name<'static>; + } + } + }; + Ok(result) +} + +pub fn derive_zero_copy_impl(input: ProcTokenStream) -> syn::Result<TokenStream> { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(input)?; + + let hasher = false; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + let meta_struct_def = if !meta_fields.is_empty() { + meta_struct::generate_meta_struct::<false>(&z_struct_meta_name, &meta_fields, hasher)? + } else { + quote! {} + }; + + let z_struct_def = generate_z_struct::<false>( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + )?; + + let zero_copy_struct_inner_impl = + generate_zero_copy_struct_inner::<false>(name, &z_struct_name)?; + + let deserialize_impl = generate_deserialize_impl::<false>( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! {}, + )?; + + // Combine all implementations + let expanded = quote!
{ + #meta_struct_def + #z_struct_def + #zero_copy_struct_inner_impl + #deserialize_impl + }; + + Ok(expanded) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy_eq.rs b/program-libs/zero-copy-derive/src/zero_copy_eq.rs new file mode 100644 index 0000000000..94b06b51a6 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy_eq.rs @@ -0,0 +1,265 @@ +use proc_macro::TokenStream as ProcTokenStream; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{DeriveInput, Field, Ident}; + +use crate::shared::{ + from_impl, utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generates meta field comparisons for PartialEq implementation +pub fn generate_meta_field_comparisons<'a>( + meta_fields: &'a [&'a Field], +) -> syn::Result<impl Iterator<Item = TokenStream> + 'a> { + let field_types = analyze_struct_fields(meta_fields)?; + + let iterator = field_types.into_iter().map(|field_type| match field_type { + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => quote! { + if other.#field_name != meta.#field_name { + return false; + } + }, + _ if utils::is_specific_primitive_type(field_type, "bool") => quote! { + if other.#field_name != (meta.#field_name > 0) { + return false; + } + }, + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { + if other.#field_name != #field_type::from(meta.#field_name) { + return false; + } + } + } + } + } + _ => { + let field_name = field_type.name(); + quote!
{ + if other.#field_name != meta.#field_name { + return false; + } + } + } + }); + Ok(iterator) +} + +/// Generates struct field comparisons for PartialEq implementation +pub fn generate_struct_field_comparisons<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + if field_types + .iter() + .any(|x| matches!(x, FieldType::Option(_, _))) + { + return Err(syn::Error::new_spanned( + struct_fields[0], + "Options are not supported in ZeroCopyEq", + )); + } + + let iterator = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { + if self.#field_name != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecDynamicZeroCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::Array(field_name, _) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::Option(field_name, field_type) => { + if utils::is_specific_primitive_type(field_type, "u8") { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if self.#field_name.as_ref().unwrap() != other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + // TODO: handle issue that structs need * == *, arrays need ** == * + // else if crate::utils::is_copy_type(field_type) { + // quote! 
{ + // if self.#field_name.is_some() && other.#field_name.is_some() { + // if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + // return false; + // } + // } else if self.#field_name.is_some() || other.#field_name.is_some() { + // return false; + // } + // } + // } + else { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + + } + FieldType::Pubkey(field_name) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => + if MUT { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } else { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + }, + _ if utils::is_specific_primitive_type(field_type, "bool") => + if MUT { + quote! { + if (*self.#field_name > 0) != other.#field_name { + return false; + } + } + } else { + quote! { + if (self.#field_name > 0) != other.#field_name { + return false; + } + } + }, + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { + if #field_type::from(*self.#field_name) != other.#field_name { + return false; + } + } + } + } + } + FieldType::Copy(field_name, _) + | FieldType::DynamicZeroCopy(field_name, _) => { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + }, + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + quote! 
{ + if self.#field_name != other.#field_name { + return false; + } + } + } + } + }); + Ok(iterator) +} + +/// Generates the PartialEq implementation as a TokenStream +pub fn generate_partial_eq_impl<const MUT: bool>( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> syn::Result<TokenStream> { + let struct_field_comparisons = generate_struct_field_comparisons::<MUT>(struct_fields)?; + let result = if !meta_fields.is_empty() { + let meta_field_comparisons = generate_meta_field_comparisons(meta_fields)?; + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + let meta: &#z_struct_meta_name = &self.__meta; + #(#meta_field_comparisons)* + #(#struct_field_comparisons)* + true + } + } + } + } else { + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + #(#struct_field_comparisons)* + true + } + } + + } + }; + Ok(result) +} + +pub fn derive_zero_copy_eq_impl(input: ProcTokenStream) -> syn::Result<TokenStream> { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(input)?; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Generate the PartialEq implementation + let partial_eq_impl = generate_partial_eq_impl::<false>( + name, + &z_struct_name, + &z_struct_meta_name, + &meta_fields, + &struct_fields, + )?; + + // Generate From implementations + let from_impl = + from_impl::generate_from_impl::<false>(name, &z_struct_name, &meta_fields, &struct_fields)?; + + Ok(quote!
{ + #partial_eq_impl + #from_impl + }) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy_mut.rs b/program-libs/zero-copy-derive/src/zero_copy_mut.rs new file mode 100644 index 0000000000..ad52bba4d5 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy_mut.rs @@ -0,0 +1,93 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::DeriveInput; + +use crate::{ + shared::{ + meta_struct, utils, + z_struct::{self, analyze_struct_fields}, + zero_copy_new::{generate_config_struct, generate_init_mut_impl}, + }, + zero_copy, +}; + +pub fn derive_zero_copy_mut_impl(fn_input: TokenStream) -> syn::Result { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(fn_input.clone())?; + + let hasher = false; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + let meta_struct_def_mut = if !meta_fields.is_empty() { + meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher)? + } else { + quote! {} + }; + + let z_struct_def_mut = z_struct::generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + )?; + + let zero_copy_struct_inner_impl_mut = zero_copy::generate_zero_copy_struct_inner::( + name, + &format_ident!("{}Mut", z_struct_name), + )?; + + let deserialize_impl_mut = zero_copy::generate_deserialize_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! 
{}, + )?; + + // Re-parse the original TokenStream for ZeroCopyNew config generation (the first parse above consumed a clone) + let input: DeriveInput = syn::parse(fn_input)?; + + // Process the input to extract struct information + let (name, _z_struct_name, _z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Use the same field processing logic as other derive macros for consistency + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Process ALL fields uniformly by type (no position dependency for config generation) + let all_fields: Vec<&syn::Field> = meta_fields + .iter() + .chain(struct_fields.iter()) + .cloned() + .collect(); + let all_field_types = analyze_struct_fields(&all_fields)?; + + // Generate configuration struct based on all fields that need config (type-based) + let config_struct = generate_config_struct(name, &all_field_types)?; + + // Generate ZeroCopyNew implementation using the existing field separation + let init_mut_impl = generate_init_mut_impl(name, &meta_fields, &struct_fields)?; + + // Combine all mutable implementations + let expanded = quote! 
{ + #config_struct + + #init_mut_impl + + #meta_struct_def_mut + + #z_struct_def_mut + + #zero_copy_struct_inner_impl_mut + + #deserialize_impl_mut + }; + + Ok(expanded) +} diff --git a/program-libs/zero-copy-derive/tests/config_test.rs b/program-libs/zero-copy-derive/tests/config_test.rs new file mode 100644 index 0000000000..990b2f6a18 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/config_test.rs @@ -0,0 +1,430 @@ +#![cfg(feature = "mut")] + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::borsh_mut::DeserializeMut; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + +/// Simple struct with just a Vec field to test basic config functionality +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct SimpleVecStruct { + pub a: u8, + pub vec: Vec, + pub b: u16, +} + +#[test] +fn test_simple_config_generation() { + // This test verifies that the ZeroCopyNew derive macro generates the expected config struct + // and ZeroCopyNew implementation + + // The config should have been generated as SimpleVecStructConfig + let config = SimpleVecStructConfig { + vec: 10, // Vec should have u32 config (length) + }; + + // Test that we can create a configuration + assert_eq!(config.vec, 10); + + println!("Config generation test passed!"); +} + +#[test] +fn test_simple_vec_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test the new_zero_copy method generated by ZeroCopyNew + let config = SimpleVecStructConfig { + vec: 5, // Vec with capacity 5 + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleVecStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // Use the generated new_zero_copy method + let result = SimpleVecStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number 
of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta fields + simple_struct.__meta.a = 42; + + // Test that we can write to the vec slice + simple_struct.vec[0] = 10; + simple_struct.vec[1] = 20; + simple_struct.vec[2] = 30; + + // Test that we can set the b field + *simple_struct.b = 12345u16.into(); + + // Verify the values we set + assert_eq!(simple_struct.__meta.a, 42); + assert_eq!(simple_struct.vec[0], 10); + assert_eq!(simple_struct.vec[1], 20); + assert_eq!(simple_struct.vec[2], 30); + assert_eq!(u16::from(*simple_struct.b), 12345); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = SimpleVecStruct::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[0], 10); + assert_eq!(deserialized.vec[1], 20); + assert_eq!(deserialized.vec[2], 30); + assert_eq!(u16::from(*deserialized.b), 12345); + + println!("new_zero_copy initialization test passed!"); +} + +/// Struct with Option field to test Option config +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct SimpleOptionStruct { + pub a: u8, + pub option: Option, +} + +#[test] +fn test_simple_option_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option enabled + let config = SimpleOptionStructConfig { + option: true, // Option should have bool config (enabled/disabled) + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // 
Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta field + simple_struct.__meta.a = 123; + + // Test that option is Some and we can set its value + assert!(simple_struct.option.is_some()); + if let Some(ref mut opt_val) = simple_struct.option { + **opt_val = 98765u64.into(); + } + + // Verify the values + assert_eq!(simple_struct.__meta.a, 123); + if let Some(ref opt_val) = simple_struct.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 123); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + println!("Option new_zero_copy test passed!"); +} + +#[test] +fn test_simple_option_struct_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option disabled + let config = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + simple_struct.__meta.a = 200; + + // Test that option is None + assert!(simple_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 200); + assert!(deserialized.option.is_none()); + + println!("Option disabled new_zero_copy test passed!"); +} + 
+/// Test both Vec and Option in one struct +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct MixedStruct { + pub a: u8, + pub vec: Vec, + pub option: Option, + pub b: u16, +} + +#[test] +fn test_mixed_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with both vec and option enabled + let config = MixedStructConfig { + vec: 8, // Vec -> u32 length + option: true, // Option -> bool enabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + mixed_struct.__meta.a = 77; + + // Set vec data + mixed_struct.vec[0] = 11; + mixed_struct.vec[3] = 44; + mixed_struct.vec[7] = 88; + + // Set option value + assert!(mixed_struct.option.is_some()); + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 123456789u64.into(); + } + + // Set b field + *mixed_struct.b = 54321u16.into(); + + // Verify all values + assert_eq!(mixed_struct.__meta.a, 77); + assert_eq!(mixed_struct.vec[0], 11); + assert_eq!(mixed_struct.vec[3], 44); + assert_eq!(mixed_struct.vec[7], 88); + if let Some(ref opt_val) = mixed_struct.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*mixed_struct.b), 54321); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 77); + assert_eq!(deserialized.vec[0], 11); + assert_eq!(deserialized.vec[3], 44); + assert_eq!(deserialized.vec[7], 88); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = 
deserialized.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*deserialized.b), 54321); + + println!("Mixed struct new_zero_copy test passed!"); +} + +#[test] +fn test_mixed_struct_option_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with vec enabled but option disabled + let config = MixedStructConfig { + vec: 3, // Vec -> u32 length + option: false, // Option -> bool disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mixed_struct.__meta.a = 99; + mixed_struct.vec[0] = 255; + mixed_struct.vec[2] = 128; + *mixed_struct.b = 9999u16.into(); + + // Verify option is None + assert!(mixed_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 99); + assert_eq!(deserialized.vec[0], 255); + assert_eq!(deserialized.vec[2], 128); + assert!(deserialized.option.is_none()); + assert_eq!(u16::from(*deserialized.b), 9999); + + println!("Mixed struct option disabled test passed!"); +} + +#[test] +fn test_byte_len_calculation() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test SimpleVecStruct byte_len calculation + let config = SimpleVecStructConfig { + vec: 10, // Vec with capacity 10 + }; + + let expected_size = 1 + // a: u8 (meta field) + 4 + 10 + // vec: 4 bytes length + 10 bytes data + 2; // b: u16 + + let calculated_size = SimpleVecStruct::byte_len(&config); + assert_eq!(calculated_size, expected_size); + println!( + "SimpleVecStruct byte_len: calculated={}, expected={}", 
+ calculated_size, expected_size + ); + + // Test SimpleOptionStruct byte_len calculation + let config_some = SimpleOptionStructConfig { + option: true, // Option enabled + }; + + let expected_size_some = 1 + // a: u8 (meta field) + 1 + 8; // option: 1 byte discriminant + 8 bytes u64 + + let calculated_size_some = SimpleOptionStruct::byte_len(&config_some); + assert_eq!(calculated_size_some, expected_size_some); + println!( + "SimpleOptionStruct (Some) byte_len: calculated={}, expected={}", + calculated_size_some, expected_size_some + ); + + let config_none = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + let expected_size_none = 1 + // a: u8 (meta field) + 1; // option: 1 byte discriminant for None + + let calculated_size_none = SimpleOptionStruct::byte_len(&config_none); + assert_eq!(calculated_size_none, expected_size_none); + println!( + "SimpleOptionStruct (None) byte_len: calculated={}, expected={}", + calculated_size_none, expected_size_none + ); + + // Test MixedStruct byte_len calculation + let config_mixed = MixedStructConfig { + vec: 5, // Vec with capacity 5 + option: true, // Option enabled + }; + + let expected_size_mixed = 1 + // a: u8 (meta field) + 4 + 5 + // vec: 4 bytes length + 5 bytes data + 1 + 8 + // option: 1 byte discriminant + 8 bytes u64 + 2; // b: u16 + + let calculated_size_mixed = MixedStruct::byte_len(&config_mixed); + assert_eq!(calculated_size_mixed, expected_size_mixed); + println!( + "MixedStruct byte_len: calculated={}, expected={}", + calculated_size_mixed, expected_size_mixed + ); + + println!("All byte_len calculation tests passed!"); +} + +#[test] +fn test_dynamic_buffer_allocation_with_byte_len() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Example of how to use byte_len for dynamic buffer allocation + let config = MixedStructConfig { + vec: 12, // Vec with capacity 12 + option: true, // Option enabled + }; + + // Calculate the exact buffer size needed + let required_size = 
MixedStruct::byte_len(&config); + println!("Required buffer size: {} bytes", required_size); + + // Allocate exactly the right amount of memory + let mut bytes = vec![0u8; required_size]; + + // Initialize the structure + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the right amount of bytes (no remaining bytes) + assert_eq!( + remaining.len(), + 0, + "Should have used exactly the calculated number of bytes" + ); + + // Set some values to verify it works + mixed_struct.__meta.a = 42; + mixed_struct.vec[5] = 123; + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 9999u64.into(); + } + *mixed_struct.b = 7777u16.into(); + + // Verify round-trip works + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[5], 123); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 9999); + } + assert_eq!(u16::from(*deserialized.b), 7777); + + println!("Dynamic buffer allocation test passed!"); +} diff --git a/program-libs/zero-copy-derive/tests/cross_crate_copy.rs b/program-libs/zero-copy-derive/tests/cross_crate_copy.rs new file mode 100644 index 0000000000..e827908046 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/cross_crate_copy.rs @@ -0,0 +1,295 @@ +#![cfg(feature = "mut")] +//! Test cross-crate Copy identification functionality +//! +//! This test validates that the zero-copy derive macro correctly identifies +//! which types implement Copy, both for built-in types and user-defined types. 
+ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + +// Test struct with primitive Copy types that should be in meta fields +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct PrimitiveCopyStruct { + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, + pub f: Vec, // Split point - this and following fields go to struct_fields + pub g: u32, // Should be in struct_fields due to field ordering rules +} + +// Test struct with primitive Copy types that should be in meta fields +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyEq, ZeroCopyMut)] +pub struct PrimitiveCopyStruct2 { + pub f: Vec, // Split point - this and following fields go to struct_fields + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, + pub g: u32, +} + +// Test struct with arrays that use u8 (which supports Unaligned) +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct ArrayCopyStruct { + pub fixed_u8: [u8; 4], + pub another_u8: [u8; 8], + pub data: Vec, // Split point + pub more_data: [u8; 3], // Should be in struct_fields due to field ordering +} + +// Test struct with Vec of primitive Copy types +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct VecPrimitiveStruct { + pub header: u32, + pub data: Vec, // Vec - special case + pub numbers: Vec, // Vec of Copy type + pub footer: u64, +} + +#[cfg(test)] +mod tests { + use light_zero_copy::borsh::Deserialize; + + use super::*; + + #[test] + fn test_primitive_copy_field_splitting() { + // This test validates that primitive Copy types are correctly + // identified and placed in meta_fields until we hit a Vec + + let data = PrimitiveCopyStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + f: vec![5, 6, 7], + g: 8, + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = 
PrimitiveCopyStruct::zero_copy_at(&serialized).unwrap(); + + // Verify we can access meta fields (should be zero-copy references) + assert_eq!(deserialized.a, 1); + assert_eq!(deserialized.b.get(), 2); // U16 type, use .get() + assert_eq!(deserialized.c.get(), 3); // U32 type, use .get() + assert_eq!(deserialized.d.get(), 4); // U64 type, use .get() + assert_eq!(deserialized.e(), true); // bool accessor method + + // Verify we can access struct fields + assert_eq!(deserialized.f, &[5, 6, 7]); + assert_eq!(deserialized.g.get(), 8); // U32 type in struct fields + } + + #[test] + fn test_array_copy_field_splitting() { + // Arrays should be treated as Copy types + let data = ArrayCopyStruct { + fixed_u8: [1, 2, 3, 4], + another_u8: [10, 20, 30, 40, 50, 60, 70, 80], + data: vec![5, 6], + more_data: [30, 40, 50], + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = ArrayCopyStruct::zero_copy_at(&serialized).unwrap(); + + // Arrays should be accessible (in meta_fields before Vec split) + assert_eq!(deserialized.fixed_u8.as_ref(), &[1, 2, 3, 4]); + assert_eq!( + deserialized.another_u8.as_ref(), + &[10, 20, 30, 40, 50, 60, 70, 80] + ); + + // After Vec split + assert_eq!(deserialized.data, &[5, 6]); + assert_eq!(deserialized.more_data.as_ref(), &[30, 40, 50]); + } + + #[test] + fn test_vec_primitive_types() { + // Test Vec with various primitive Copy element types + let data = VecPrimitiveStruct { + header: 1, + data: vec![10, 20, 30], + numbers: vec![100, 200, 300], + footer: 999, + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = VecPrimitiveStruct::zero_copy_at(&serialized).unwrap(); + + assert_eq!(deserialized.header.get(), 1); + + // Vec is special case - stored as slice + assert_eq!(deserialized.data, &[10, 20, 30]); + + // Vec should use ZeroCopySliceBorsh + assert_eq!(deserialized.numbers.len(), 3); + assert_eq!(deserialized.numbers[0].get(), 100); + assert_eq!(deserialized.numbers[1].get(), 200); + 
assert_eq!(deserialized.numbers[2].get(), 300); + + assert_eq!(deserialized.footer.get(), 999); + } + + #[test] + fn test_all_derives_with_vec_first() { + // This test validates PrimitiveCopyStruct2 which has Vec as the first field + // This means NO meta fields (all fields go to struct_fields due to field ordering) + // Also tests all derive macros: ZeroCopy, ZeroCopyEq, ZeroCopyMut + + use light_zero_copy::{borsh_mut::DeserializeMut, init_mut::ZeroCopyNew}; + + let data = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], // Vec first - causes all fields to be in struct_fields + a: 10, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, + }; + + // Test ZeroCopy (immutable) + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = PrimitiveCopyStruct2::zero_copy_at(&serialized).unwrap(); + + // Since Vec is first, ALL fields should be in struct_fields (no meta fields) + assert_eq!(deserialized.f, &[1, 2, 3]); + assert_eq!(deserialized.a, 10); // u8 direct access + assert_eq!(deserialized.b.get(), 20); // U16 via .get() + assert_eq!(deserialized.c.get(), 30); // U32 via .get() + assert_eq!(deserialized.d.get(), 40); // U64 via .get() + assert_eq!(deserialized.e(), true); // bool accessor method + assert_eq!(deserialized.g.get(), 50); // U32 via .get() + + // Test ZeroCopyEq (PartialEq implementation) + let original = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], + a: 10, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, + }; + + // Should be equal to original + assert_eq!(deserialized, original); + + // Test inequality + let different = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], + a: 11, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, // Different 'a' + }; + assert_ne!(deserialized, different); + + // Test ZeroCopyMut (mutable zero-copy) + #[cfg(feature = "mut")] + { + let mut serialized_mut = borsh::to_vec(&data).unwrap(); + let (deserialized_mut, _) = + PrimitiveCopyStruct2::zero_copy_at_mut(&mut serialized_mut).unwrap(); + + // Test mutable access + 
assert_eq!(deserialized_mut.f, &[1, 2, 3]); + assert_eq!(*deserialized_mut.a, 10); // Mutable u8 field + assert_eq!(deserialized_mut.b.get(), 20); + let (deserialized_mut, _) = + PrimitiveCopyStruct2::zero_copy_at(&mut serialized_mut).unwrap(); + + // Test From implementation (ZeroCopyEq generates this for immutable version) + let converted: PrimitiveCopyStruct2 = deserialized_mut.into(); + assert_eq!(converted.a, 10); + assert_eq!(converted.b, 20); + assert_eq!(converted.c, 30); + assert_eq!(converted.d, 40); + assert_eq!(converted.e, true); + assert_eq!(converted.f, vec![1, 2, 3]); + assert_eq!(converted.g, 50); + } + + // Test ZeroCopyNew (configuration-based initialization) + let config = super::PrimitiveCopyStruct2Config { + f: 3, // Vec length + // Other fields don't need config (they're primitives) + }; + + // Calculate required buffer size + let buffer_size = PrimitiveCopyStruct2::byte_len(&config); + let mut buffer = vec![0u8; buffer_size]; + + // Initialize the zero-copy struct + let (mut initialized, _) = + PrimitiveCopyStruct2::new_zero_copy(&mut buffer, config).unwrap(); + + // Verify we can access the initialized fields + assert_eq!(initialized.f.len(), 3); // Vec should have correct length + + // Set some values in the Vec + initialized.f[0] = 100; + initialized.f[1] = 101; + initialized.f[2] = 102; + *initialized.a = 200; + + // Verify the values were set correctly + assert_eq!(initialized.f, &[100, 101, 102]); + assert_eq!(*initialized.a, 200); + + println!("All derive macros (ZeroCopy, ZeroCopyEq, ZeroCopyMut) work correctly with Vec-first struct!"); + } + + #[test] + fn test_copy_identification_compilation() { + // The primary test is that our macro successfully processes all struct definitions + // above without panicking or generating invalid code. The fact that compilation + // succeeds demonstrates that our Copy identification logic works correctly. 
+ + // Test basic functionality to ensure the generated code is sound + let primitive_data = PrimitiveCopyStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + f: vec![1, 2], + g: 5, + }; + + let array_data = ArrayCopyStruct { + fixed_u8: [1, 2, 3, 4], + another_u8: [5, 6, 7, 8, 9, 10, 11, 12], + data: vec![13, 14], + more_data: [15, 16, 17], + }; + + let vec_data = VecPrimitiveStruct { + header: 42, + data: vec![1, 2, 3], + numbers: vec![10, 20], + footer: 99, + }; + + // Serialize and deserialize to verify the generated code works + let serialized = borsh::to_vec(&primitive_data).unwrap(); + let (_, _) = PrimitiveCopyStruct::zero_copy_at(&serialized).unwrap(); + + let serialized = borsh::to_vec(&array_data).unwrap(); + let (_, _) = ArrayCopyStruct::zero_copy_at(&serialized).unwrap(); + + let serialized = borsh::to_vec(&vec_data).unwrap(); + let (_, _) = VecPrimitiveStruct::zero_copy_at(&serialized).unwrap(); + + println!("Cross-crate Copy identification test passed - all structs compiled and work correctly!"); + } +} diff --git a/program-libs/zero-copy-derive/tests/from_test.rs b/program-libs/zero-copy-derive/tests/from_test.rs new file mode 100644 index 0000000000..20391c36dd --- /dev/null +++ b/program-libs/zero-copy-derive/tests/from_test.rs @@ -0,0 +1,77 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, ZeroCopyEq}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; + +// Simple struct with a primitive field and a vector +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct SimpleStruct { + pub a: u8, + pub b: Vec, +} + +// Basic struct with all basic numeric types +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct NumericStruct { + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, +} + +// use 
light_zero_copy::borsh_mut::DeserializeMut; // Not needed for non-mut derivations + +#[test] +fn test_simple_from_implementation() { + // Create an instance of our struct + let original = SimpleStruct { + a: 42, + b: vec![1, 2, 3, 4, 5], + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = SimpleStruct::zero_copy_at(&bytes).unwrap(); + let converted: SimpleStruct = zero_copy.into(); + assert_eq!(converted.a, 42); + assert_eq!(converted.b, vec![1, 2, 3, 4, 5]); + assert_eq!(converted, original); +} + +#[test] +fn test_numeric_from_implementation() { + // Create a struct with different primitive types + let original = NumericStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = NumericStruct::zero_copy_at(&bytes).unwrap(); + let converted: NumericStruct = zero_copy.clone().into(); + + // Verify all fields + assert_eq!(converted.a, 1); + assert_eq!(converted.b, 2); + assert_eq!(converted.c, 3); + assert_eq!(converted.d, 4); + assert!(converted.e); + + // Verify complete struct + assert_eq!(converted, original); +} diff --git a/program-libs/zero-copy-derive/tests/instruction_data.rs b/program-libs/zero-copy-derive/tests/instruction_data.rs new file mode 100644 index 0000000000..094248e4c8 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/instruction_data.rs @@ -0,0 +1,1401 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut, errors::ZeroCopyError}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; 
+use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned}; + +#[derive( + Debug, + Copy, + PartialEq, + Clone, + Immutable, + FromBytes, + IntoBytes, + KnownLayout, + BorshDeserialize, + BorshSerialize, + Default, + Unaligned, +)] +#[repr(C)] +pub struct Pubkey(pub(crate) [u8; 32]); + +impl Pubkey { + pub fn new_unique() -> Self { + use rand::Rng; + let mut rng = rand::thread_rng(); + let bytes = rng.gen::<[u8; 32]>(); + Pubkey(bytes) + } + + pub fn to_bytes(self) -> [u8; 32] { + self.0 + } +} + +impl<'a> Deserialize<'a> for Pubkey { + type Output = Ref<&'a [u8], Pubkey>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + Ok(Ref::<&'a [u8], Pubkey>::from_prefix(bytes)?) + } +} + +impl<'a> DeserializeMut<'a> for Pubkey { + type Output = Ref<&'a mut [u8], Pubkey>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&'a mut [u8], Pubkey>::from_prefix(bytes)?) 
+ } +} + +// We should not implement DeserializeMut for primitive types directly +// The implementation should be in the zero-copy crate + +impl PartialEq<>::Output> for Pubkey { + fn eq(&self, other: &>::Output) -> bool { + self.0 == other.0 + } +} + +impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for Pubkey { + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + 32 // Pubkey is always 32 bytes + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Self::zero_copy_at_mut(bytes) + } +} + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct InstructionDataInvoke { + pub proof: Option, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub new_address_params: Vec, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for InstructionDataInvoke { +// type Config = InstructionDataInvokeConfig; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { +// use zerocopy::Ref; +// +// // First handle the meta struct (empty for InstructionDataInvoke) +// let (__meta, bytes) = Ref::<&mut [u8], ZInstructionDataInvokeMetaMut>::from_prefix(bytes)?; +// +// // Initialize each field using the corresponding config, following DeserializeMut order +// let (proof, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.proof_config.is_some(), CompressedProofConfig {}) +// )?; +// +// let input_configs: Vec = config.input_accounts_configs +// .into_iter() +// .map(|compressed_account_config| PackedCompressedAccountWithMerkleContextConfig { +// compressed_account: 
CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// merkle_context: PackedMerkleContextConfig {}, +// }) +// .collect(); +// let (input_compressed_accounts_with_merkle_context, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// input_configs +// )?; +// +// let output_configs: Vec = config.output_accounts_configs +// .into_iter() +// .map(|compressed_account_config| OutputCompressedAccountWithPackedContextConfig { +// compressed_account: CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// }) +// .collect(); +// let (output_compressed_accounts, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// output_configs +// )?; +// +// let (relay_fee, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.relay_fee_config.is_some(), ()) +// )?; +// +// let new_address_configs: Vec = config.new_address_configs +// .into_iter() +// .map(|_| NewAddressParamsPackedConfig {}) +// .collect(); +// let (new_address_params, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// new_address_configs +// )?; +// +// let (compress_or_decompress_lamports, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.decompress_lamports_config.is_some(), ()) +// )?; +// +// let (is_compress, bytes) = ::new_zero_copy( +// bytes, +// () +// )?; +// +// Ok(( +// ZInstructionDataInvokeMut { +// proof, +// input_compressed_accounts_with_merkle_context, +// output_compressed_accounts, +// relay_fee, +// new_address_params, +// compress_or_decompress_lamports, +// is_compress, +// }, +// bytes, 
+// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct OutputCompressedAccountWithContext { + pub compressed_account: CompressedAccount, + pub merkle_tree: Pubkey, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct OutputCompressedAccountWithPackedContext { + pub compressed_account: CompressedAccount, + pub merkle_tree_index: u8, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for OutputCompressedAccountWithPackedContext { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZOutputCompressedAccountWithPackedContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_tree_index, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZOutputCompressedAccountWithPackedContextMut { +// compressed_account, +// merkle_tree_index, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, +)] +pub struct NewAddressParamsPacked { + pub seed: [u8; 32], + pub address_queue_account_index: u8, + pub address_merkle_tree_account_index: u8, + pub address_merkle_tree_root_index: u16, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for NewAddressParamsPacked { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZNewAddressParamsPackedMetaMut>::from_prefix(bytes)?; +// 
Ok((ZNewAddressParamsPackedMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct NewAddressParams { + pub seed: [u8; 32], + pub address_queue_pubkey: Pubkey, + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, +)] +pub struct PackedReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_root_index: u16, + pub address_merkle_tree_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Clone, + Copy, +)] +pub struct CompressedProof { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +impl Default for CompressedProof { + fn default() -> Self { + Self { + a: [0; 32], + b: [0; 64], + c: [0; 32], + } + } +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedProof { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedProofMetaMut>::from_prefix(bytes)?; +// Ok((ZCompressedProofMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Default, +)] +pub struct CompressedCpiContext { + /// Is set by the program that is invoking the CPI to signal that is should + /// 
set the cpi context. + pub set_context: bool, + /// Is set to clear the cpi context since someone could have set it before + /// with unrelated data. + pub first_set_context: bool, + /// Index of cpi context account in remaining accounts. + pub cpi_context_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct PackedCompressedAccountWithMerkleContext { + pub compressed_account: CompressedAccount, + pub merkle_context: PackedMerkleContext, + /// Index of root used in inclusion validity proof. + pub root_index: u16, + /// Placeholder to mark accounts read-only unimplemented set to false. + pub read_only: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedCompressedAccountWithMerkleContext { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedCompressedAccountWithMerkleContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_context, bytes) = ::new_zero_copy(bytes, ())?; +// let (root_index, bytes) = ::new_zero_copy(bytes, ())?; +// let (read_only, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZPackedCompressedAccountWithMerkleContextMut { +// compressed_account, +// merkle_context, +// root_index, +// read_only, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, +)] +pub struct MerkleContext { + pub merkle_tree_pubkey: Pubkey, + pub nullifier_queue_pubkey: Pubkey, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for MerkleContext { +// type Config = (); 
+// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZMerkleContextMetaMut>::from_prefix(bytes)?; +// +// Ok(( +// ZMerkleContextMut { +// __meta, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountWithMerkleContext { + pub compressed_account: CompressedAccount, + pub merkle_context: MerkleContext, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: MerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct PackedReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: PackedMerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, +)] +pub struct PackedMerkleContext { + pub merkle_tree_pubkey_index: u8, + pub nullifier_queue_pubkey_index: u8, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedMerkleContext { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedMerkleContextMetaMut>::from_prefix(bytes)?; +// Ok((ZPackedMerkleContextMut { __meta }, bytes)) +// } +// } + +#[derive(Debug, PartialEq, Default, Clone, Copy)] +pub struct 
CompressedAccountZeroCopyNew { + pub address_enabled: bool, + pub data_enabled: bool, + pub data_capacity: u32, +} + +// Manual InstructionDataInvokeConfig removed - now using generated config from ZeroCopyNew derive + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct CompressedAccount { + pub owner: [u8; 32], + pub lamports: u64, + pub address: Option<[u8; 32]>, + pub data: Option, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccount { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountMetaMut>::from_prefix(bytes)?; +// +// // Use generic Option implementation for address field +// let (address, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.address_enabled, ()) +// )?; +// +// // Use generic Option implementation for data field +// let (data, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.data_enabled, CompressedAccountDataConfig { data: config.data_capacity }) +// )?; +// +// Ok(( +// ZCompressedAccountMut { +// __meta, +// address, +// data, +// }, +// bytes, +// )) +// } +// } + +impl<'a> From> for CompressedAccount { + fn from(value: ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.map(|x| *x), + data: value.data.as_ref().map(|x| x.into()), + } + } +} + +impl<'a> From<&ZCompressedAccount<'a>> for CompressedAccount { + fn from(value: &ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.as_ref().map(|x| **x), + data: value.data.as_ref().map(|x| x.into()), + } + } +} + 
+impl PartialEq for ZCompressedAccount<'_> { + fn eq(&self, other: &CompressedAccount) -> bool { + // Check address: if both Some and unequal, return false + if self.address.is_some() + && other.address.is_some() + && *self.address.unwrap() != other.address.unwrap() + { + return false; + } + // Check address: if exactly one is Some, return false + if self.address.is_some() != other.address.is_some() { + return false; + } + + // Check data: if both Some and unequal, return false + if self.data.is_some() + && other.data.is_some() + && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() + { + return false; + } + // Check data: if exactly one is Some, return false + if self.data.is_some() != other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == other.lamports + } +} + +// Commented out because mutable derivation is disabled +// impl PartialEq for ZCompressedAccountMut<'_> { +// fn eq(&self, other: &CompressedAccount) -> bool { +// if self.address.is_some() +// && other.address.is_some() +// && **self.address.as_ref().unwrap() != *other.address.as_ref().unwrap() +// { +// return false; +// } +// if self.address.is_some() || other.address.is_some() { +// return false; +// } +// if self.data.is_some() +// && other.data.is_some() +// && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() +// { +// return false; +// } +// if self.data.is_some() || other.data.is_some() { +// return false; +// } + +// self.owner == other.owner && self.lamports == other.lamports +// } +// } +impl PartialEq> for CompressedAccount { + fn eq(&self, other: &ZCompressedAccount) -> bool { + // Check address: if both Some and unequal, return false + if self.address.is_some() + && other.address.is_some() + && self.address.unwrap() != *other.address.unwrap() + { + return false; + } + // Check address: if exactly one is Some, return false + if self.address.is_some() != other.address.is_some() { + return false; + } + + // Check data: if both Some and 
unequal, return false + if self.data.is_some() + && other.data.is_some() + && other.data.as_ref().unwrap() != self.data.as_ref().unwrap() + { + return false; + } + // Check data: if exactly one is Some, return false + if self.data.is_some() != other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == u64::from(other.lamports) + } +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +// COMMENTED OUT: Now using ZeroCopyNew derive macro instead +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccountData { +// type Config = u32; // data_capacity +// type Output = >::Output; + +// fn new_zero_copy( +// bytes: &'a mut [u8], +// data_capacity: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountDataMetaMut>::from_prefix(bytes)?; +// // For u8 slices we just use &mut [u8] so we init the len and the split mut separately. 
+// { +// light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::::new_at( +// data_capacity.into(), +// bytes, +// )?; +// } +// // Split off len for +// let (_, bytes) = bytes.split_at_mut(4); +// let (data, bytes) = bytes.split_at_mut(data_capacity as usize); +// let (data_hash, bytes) = Ref::<&mut [u8], [u8; 32]>::from_prefix(bytes)?; +// Ok(( +// ZCompressedAccountDataMut { +// __meta, +// data, +// data_hash, +// }, +// bytes, +// )) +// } +// } + +#[test] +fn test_compressed_account_data_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountDataConfig { data: 10 }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccountData::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccountData::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set discriminator + mut_account.__meta.discriminator = [1, 2, 3, 4, 5, 6, 7, 8]; + + // Test that we can write to data + mut_account.data[0] = 42; + mut_account.data[1] = 43; + + // Test that we can set data_hash + mut_account.data_hash[0] = 99; + mut_account.data_hash[1] = 100; + + assert_eq!(mut_account.__meta.discriminator, [1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(mut_account.data[0], 42); + assert_eq!(mut_account.data[1], 43); + assert_eq!(mut_account.data_hash[0], 99); + assert_eq!(mut_account.data_hash[1], 100); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = CompressedAccountData::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized_account, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!( + 
deserialized_account.__meta.discriminator, + [1, 2, 3, 4, 5, 6, 7, 8] + ); + assert_eq!(deserialized_account.data.len(), 10); + assert_eq!(deserialized_account.data[0], 42); + assert_eq!(deserialized_account.data[1], 43); + assert_eq!(deserialized_account.data_hash[0], 99); + assert_eq!(deserialized_account.data_hash[1], 100); +} + +#[test] +fn test_compressed_account_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountConfig { + address: (true, ()), + data: (true, CompressedAccountDataConfig { data: 10 }), + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccount::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccount::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mut_account.__meta.owner = [1u8; 32]; + mut_account.__meta.lamports = 12345u64.into(); + mut_account.address.as_mut().unwrap()[0] = 42; + mut_account.data.as_mut().unwrap().data[0] = 99; + + // Test deserialize + let (deserialized, _) = CompressedAccount::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.owner, [1u8; 32]); + assert_eq!(u64::from(deserialized.__meta.lamports), 12345u64); + assert_eq!(deserialized.address.as_ref().unwrap()[0], 42); + assert_eq!(deserialized.data.as_ref().unwrap().data[0], 99); +} + +#[test] +fn test_instruction_data_invoke_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + // Create different configs to test various combinations + let compressed_account_config1 = CompressedAccountZeroCopyNew { + address_enabled: true, + data_enabled: true, + data_capacity: 10, + }; + + let compressed_account_config2 = CompressedAccountZeroCopyNew { + address_enabled: false, + data_enabled: true, + 
data_capacity: 5, + }; + + let compressed_account_config3 = CompressedAccountZeroCopyNew { + address_enabled: true, + data_enabled: false, + data_capacity: 0, + }; + + let compressed_account_config4 = CompressedAccountZeroCopyNew { + address_enabled: false, + data_enabled: false, + data_capacity: 0, + }; + + let config = InstructionDataInvokeConfig { + proof: (true, CompressedProofConfig {}), // Enable proof + input_compressed_accounts_with_merkle_context: vec![ + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config1.address_enabled, ()), + data: ( + compressed_account_config1.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config1.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config2.address_enabled, ()), + data: ( + compressed_account_config2.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config2.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + ], + output_compressed_accounts: vec![ + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config3.address_enabled, ()), + data: ( + compressed_account_config3.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config3.data_capacity, + }, + ), + }, + }, + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config4.address_enabled, ()), + data: ( + compressed_account_config4.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config4.data_capacity, + }, + ), + }, + }, + ], + relay_fee: true, // Enable relay fee + new_address_params: vec![ + NewAddressParamsPackedConfig {}, + 
NewAddressParamsPackedConfig {}, + ], // Length 2 + compress_or_decompress_lamports: true, // Enable decompress lamports + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + if let Err(ref e) = result { + eprintln!("Error: {:?}", e); + } + assert!(result.is_ok()); + let (_instruction_data, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test deserialization round-trip first + let (mut deserialized, _) = InstructionDataInvoke::zero_copy_at_mut(&mut bytes).unwrap(); + + // Now set values and test again + *deserialized.is_compress = 1; + + // Set proof values + if let Some(proof) = &mut deserialized.proof { + proof.a[0] = 42; + proof.b[0] = 43; + proof.c[0] = 44; + } + + // Set relay fee value + if let Some(relay_fee) = &mut deserialized.relay_fee { + **relay_fee = 12345u64.into(); + } + + // Set decompress lamports value + if let Some(decompress_lamports) = &mut deserialized.compress_or_decompress_lamports { + **decompress_lamports = 67890u64.into(); + } + + // Set first input account values + let first_input = &mut deserialized.input_compressed_accounts_with_merkle_context[0]; + first_input.compressed_account.__meta.owner[0] = 11; + first_input.compressed_account.__meta.lamports = 1000u64.into(); + if let Some(address) = &mut first_input.compressed_account.address { + address[0] = 22; + } + if let Some(data) = &mut first_input.compressed_account.data { + data.__meta.discriminator[0] = 33; + data.data[0] = 99; + data.data_hash[0] = 55; + } + + // Set first output account values + let first_output = &mut deserialized.output_compressed_accounts[0]; + first_output.compressed_account.__meta.owner[0] = 77; + 
first_output.compressed_account.__meta.lamports = 2000u64.into(); + if let Some(address) = &mut first_output.compressed_account.address { + address[0] = 88; + } + + // Verify basic structure with vectors of length 2 + assert_eq!( + deserialized + .input_compressed_accounts_with_merkle_context + .len(), + 2 + ); // Length 2 + assert_eq!(deserialized.output_compressed_accounts.len(), 2); // Length 2 + assert_eq!(deserialized.new_address_params.len(), 2); // Length 2 + assert!(deserialized.proof.is_some()); // Enabled + assert!(deserialized.relay_fee.is_some()); // Enabled + assert!(deserialized.compress_or_decompress_lamports.is_some()); // Enabled + assert_eq!(*deserialized.is_compress, 1); + + // Test data access and modification + if let Some(proof) = &deserialized.proof { + // Verify we can access proof fields and our written values + assert_eq!(proof.a[0], 42); + assert_eq!(proof.b[0], 43); + assert_eq!(proof.c[0], 44); + } + + // Verify option integer values + if let Some(relay_fee) = &deserialized.relay_fee { + assert_eq!(u64::from(**relay_fee), 12345); + } + + if let Some(decompress_lamports) = &deserialized.compress_or_decompress_lamports { + assert_eq!(u64::from(**decompress_lamports), 67890); + } + + // Test accessing first input account (config1: address=true, data=true, capacity=10) + let first_input = &deserialized.input_compressed_accounts_with_merkle_context[0]; + assert_eq!(first_input.compressed_account.__meta.owner[0], 11); // Our written value + assert_eq!( + u64::from(first_input.compressed_account.__meta.lamports), + 1000 + ); // Our written value + assert!(first_input.compressed_account.address.is_some()); // Should be enabled + assert!(first_input.compressed_account.data.is_some()); // Should be enabled + if let Some(address) = &first_input.compressed_account.address { + assert_eq!(address[0], 22); // Our written value + } + if let Some(data) = &first_input.compressed_account.data { + assert_eq!(data.data.len(), 10); // Should have capacity 10 
+ assert_eq!(data.__meta.discriminator[0], 33); // Our written value + assert_eq!(data.data[0], 99); // Our written value + assert_eq!(data.data_hash[0], 55); // Our written value + } + + // Test accessing second input account (config2: address=false, data=true, capacity=5) + let second_input = &deserialized.input_compressed_accounts_with_merkle_context[1]; + assert_eq!(second_input.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_input.compressed_account.address.is_none()); // Should be disabled + assert!(second_input.compressed_account.data.is_some()); // Should be enabled + if let Some(data) = &second_input.compressed_account.data { + assert_eq!(data.data.len(), 5); // Should have capacity 5 + } + + // Test accessing first output account (config3: address=true, data=false, capacity=0) + let first_output = &deserialized.output_compressed_accounts[0]; + assert_eq!(first_output.compressed_account.__meta.owner[0], 77); // Our written value + assert_eq!( + u64::from(first_output.compressed_account.__meta.lamports), + 2000 + ); // Our written value + assert!(first_output.compressed_account.address.is_some()); // Should be enabled + assert!(first_output.compressed_account.data.is_none()); // Should be disabled + if let Some(address) = &first_output.compressed_account.address { + assert_eq!(address[0], 88); // Our written value + } + + // Test accessing second output account (config4: address=false, data=false, capacity=0) + let second_output = &deserialized.output_compressed_accounts[1]; + assert_eq!(second_output.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_output.compressed_account.address.is_none()); // Should be disabled + assert!(second_output.compressed_account.data.is_none()); // Should be disabled +} + +#[test] +fn readme() { + use borsh::{BorshDeserialize, BorshSerialize}; + use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + + #[repr(C)] + #[derive(Debug, 
PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct MyStructOption { + pub a: u8, + pub b: u16, + pub vec: Vec>, + pub c: Option, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, + )] + pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + // Test the new ZeroCopyNew functionality + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct TestConfigStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub option: Option, + } + + let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + // Use the struct with zero-copy deserialization + let bytes = my_struct.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), my_struct.byte_len()); + let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1); + let org_struct: MyStruct = zero_copy.into(); + assert_eq!(org_struct, my_struct); + // { + // let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut bytes).unwrap(); + // zero_copy_mut.a = 42; + // } + // let borsh = MyStruct::try_from_slice(&bytes).unwrap(); + // assert_eq!(borsh.a, 42u8); +} + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct InstructionDataInvokeCpi { + pub proof: Option, + pub new_address_params: Vec, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, + pub cpi_context: Option, +} + +impl PartialEq> for InstructionDataInvokeCpi { + fn eq(&self, other: &ZInstructionDataInvokeCpi) -> bool { + // Compare proof + match (&self.proof, &other.proof) { + (Some(ref self_proof), Some(ref other_proof)) => { + if self_proof.a != other_proof.a 
+ || self_proof.b != other_proof.b + || self_proof.c != other_proof.c + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare vectors lengths first + if self.new_address_params.len() != other.new_address_params.len() + || self.input_compressed_accounts_with_merkle_context.len() + != other.input_compressed_accounts_with_merkle_context.len() + || self.output_compressed_accounts.len() != other.output_compressed_accounts.len() + { + return false; + } + + // Compare new_address_params + for (self_param, other_param) in self + .new_address_params + .iter() + .zip(other.new_address_params.iter()) + { + if self_param.seed != other_param.seed + || self_param.address_queue_account_index != other_param.address_queue_account_index + || self_param.address_merkle_tree_account_index + != other_param.address_merkle_tree_account_index + || self_param.address_merkle_tree_root_index + != u16::from(other_param.address_merkle_tree_root_index) + { + return false; + } + } + + // Compare input accounts + for (self_input, other_input) in self + .input_compressed_accounts_with_merkle_context + .iter() + .zip(other.input_compressed_accounts_with_merkle_context.iter()) + { + if self_input != other_input { + return false; + } + } + + // Compare output accounts + for (self_output, other_output) in self + .output_compressed_accounts + .iter() + .zip(other.output_compressed_accounts.iter()) + { + if self_output != other_output { + return false; + } + } + + // Compare relay_fee + match (&self.relay_fee, &other.relay_fee) { + (Some(self_fee), Some(other_fee)) => { + if *self_fee != u64::from(**other_fee) { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare compress_or_decompress_lamports + match ( + &self.compress_or_decompress_lamports, + &other.compress_or_decompress_lamports, + ) { + (Some(self_lamports), Some(other_lamports)) => { + if *self_lamports != u64::from(**other_lamports) { + return false; + } + } + (None, None) => {} + 
_ => return false, + } + + // Compare is_compress (bool vs u8) + if self.is_compress != (other.is_compress != 0) { + return false; + } + + // Compare cpi_context + match (&self.cpi_context, &other.cpi_context) { + (Some(self_ctx), Some(other_ctx)) => { + if self_ctx.set_context != (other_ctx.set_context != 0) + || self_ctx.first_set_context != (other_ctx.first_set_context != 0) + || self_ctx.cpi_context_account_index != other_ctx.cpi_context_account_index + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + true + } +} + +impl PartialEq for ZInstructionDataInvokeCpi<'_> { + fn eq(&self, other: &InstructionDataInvokeCpi) -> bool { + other.eq(self) + } +} + +impl PartialEq> + for PackedCompressedAccountWithMerkleContext +{ + fn eq(&self, other: &ZPackedCompressedAccountWithMerkleContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != *other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_context + if self.merkle_context.merkle_tree_pubkey_index + != 
other.merkle_context.__meta.merkle_tree_pubkey_index + || self.merkle_context.nullifier_queue_pubkey_index + != other.merkle_context.__meta.nullifier_queue_pubkey_index + || self.merkle_context.leaf_index != u32::from(other.merkle_context.__meta.leaf_index) + || self.merkle_context.prove_by_index != other.merkle_context.prove_by_index() + { + return false; + } + + // Compare root_index and read_only + if self.root_index != u16::from(*other.root_index) + || self.read_only != (other.read_only != 0) + { + return false; + } + + true + } +} + +impl PartialEq> + for OutputCompressedAccountWithPackedContext +{ + fn eq(&self, other: &ZOutputCompressedAccountWithPackedContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != *other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_tree_index + if self.merkle_tree_index != other.merkle_tree_index { + return false; + } + + true + } +} diff --git a/program-libs/zero-copy-derive/tests/random.rs b/program-libs/zero-copy-derive/tests/random.rs new file mode 100644 
index 0000000000..993adef704 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/random.rs @@ -0,0 +1,651 @@ +#![cfg(feature = "mut")] +use std::assert_eq; + +use borsh::BorshDeserialize; +use light_zero_copy::{borsh::Deserialize, init_mut::ZeroCopyNew}; +use rand::{ + rngs::{StdRng, ThreadRng}, + Rng, +}; + +mod instruction_data; +use instruction_data::{ + CompressedAccount, + CompressedAccountConfig, + CompressedAccountData, + CompressedAccountDataConfig, + CompressedCpiContext, + CompressedCpiContextConfig, + CompressedProof, + CompressedProofConfig, + InstructionDataInvoke, + // Config types (generated by ZeroCopyNew derive) + InstructionDataInvokeConfig, + InstructionDataInvokeCpi, + InstructionDataInvokeCpiConfig, + NewAddressParamsPacked, + NewAddressParamsPackedConfig, + OutputCompressedAccountWithPackedContext, + OutputCompressedAccountWithPackedContextConfig, + PackedCompressedAccountWithMerkleContext, + PackedCompressedAccountWithMerkleContextConfig, + PackedMerkleContext, + PackedMerkleContextConfig, + Pubkey, + // Zero-copy mutable types + ZInstructionDataInvokeCpiMut, + ZInstructionDataInvokeMut, +}; + +// Function to populate mutable zero-copy structure with data from InstructionDataInvokeCpi +fn populate_invoke_cpi_zero_copy( + src: &InstructionDataInvokeCpi, + dst: &mut ZInstructionDataInvokeCpiMut, +) { + *dst.is_compress = if src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = 
src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); + dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + 
.compressed_account + .owner + .copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = (&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } + + // Copy cpi_context if present + if let (Some(src_ctx), Some(dst_ctx)) = (&src.cpi_context, &mut dst.cpi_context) { + dst_ctx.set_context = if src_ctx.set_context { 1 } else { 0 }; + dst_ctx.first_set_context = if src_ctx.first_set_context { 1 } else { 0 }; + dst_ctx.cpi_context_account_index = src_ctx.cpi_context_account_index; + } +} + +// Function to populate mutable zero-copy structure with data from InstructionDataInvoke +fn populate_invoke_zero_copy(src: &InstructionDataInvoke, dst: &mut ZInstructionDataInvokeMut) { + *dst.is_compress = if src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + 
dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); 
+ dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + .compressed_account + .owner + .copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = (&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } +} + +fn get_rnd_instruction_data_invoke_cpi(rng: &mut StdRng) -> InstructionDataInvokeCpi { + InstructionDataInvokeCpi { + proof: Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }), + 
new_address_params: vec![get_rnd_new_address_params(rng); rng.gen_range(0..10)], + input_compressed_accounts_with_merkle_context: vec![ + get_rnd_test_input_account(rng); + rng.gen_range(0..10) + ], + output_compressed_accounts: vec![get_rnd_test_output_account(rng); rng.gen_range(0..10)], + relay_fee: None, + compress_or_decompress_lamports: rng.gen(), + is_compress: rng.gen(), + cpi_context: Some(get_rnd_cpi_context(rng)), + } +} + +fn get_rnd_cpi_context(rng: &mut StdRng) -> CompressedCpiContext { + CompressedCpiContext { + first_set_context: rng.gen(), + set_context: rng.gen(), + cpi_context_account_index: rng.gen(), + } +} + +fn get_rnd_test_account_data(rng: &mut StdRng) -> CompressedAccountData { + CompressedAccountData { + discriminator: rng.gen(), + data: (0..100).map(|_| rng.gen()).collect::>(), + data_hash: rng.gen(), + } +} + +fn get_rnd_test_account(rng: &mut StdRng) -> CompressedAccount { + CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: rng.gen(), + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + } +} + +fn get_rnd_test_output_account(rng: &mut StdRng) -> OutputCompressedAccountWithPackedContext { + OutputCompressedAccountWithPackedContext { + compressed_account: get_rnd_test_account(rng), + merkle_tree_index: rng.gen(), + } +} + +fn get_rnd_test_input_account(rng: &mut StdRng) -> PackedCompressedAccountWithMerkleContext { + PackedCompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: 100, + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + }, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index: rng.gen(), + nullifier_queue_pubkey_index: rng.gen(), + leaf_index: rng.gen(), + prove_by_index: rng.gen(), + }, + root_index: rng.gen(), + read_only: false, + } +} + +fn get_rnd_new_address_params(rng: &mut StdRng) -> NewAddressParamsPacked { + 
NewAddressParamsPacked { + seed: rng.gen(), + address_queue_account_index: rng.gen(), + address_merkle_tree_account_index: rng.gen(), + address_merkle_tree_root_index: rng.gen(), + } +} + +// Generate config for InstructionDataInvoke based on the actual data +fn generate_random_invoke_config( + invoke_ref: &InstructionDataInvoke, +) -> InstructionDataInvokeConfig { + InstructionDataInvokeConfig { + proof: (invoke_ref.proof.is_some(), CompressedProofConfig {}), + input_compressed_accounts_with_merkle_context: invoke_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_ref.relay_fee.is_some(), + new_address_params: invoke_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + compress_or_decompress_lamports: invoke_ref.compress_or_decompress_lamports.is_some(), + } +} + +// Generate config for InstructionDataInvokeCpi based on the actual data +fn generate_random_invoke_cpi_config( + invoke_cpi_ref: &InstructionDataInvokeCpi, +) -> InstructionDataInvokeCpiConfig { + 
InstructionDataInvokeCpiConfig { + proof: (invoke_cpi_ref.proof.is_some(), CompressedProofConfig {}), + new_address_params: invoke_cpi_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + input_compressed_accounts_with_merkle_context: invoke_cpi_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_cpi_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_cpi_ref.relay_fee.is_some(), + compress_or_decompress_lamports: invoke_cpi_ref.compress_or_decompress_lamports.is_some(), + cpi_context: ( + invoke_cpi_ref.cpi_context.is_some(), + CompressedCpiContextConfig {}, + ), + } +} + +#[test] +fn test_invoke_ix_data_deserialize_rnd() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. 
+ println!("\n\ne2e test seed for invoke_ix_data {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 1000; + for i in 0..num_iters { + // Create randomized instruction data + let invoke_ref = InstructionDataInvoke { + proof: if rng.gen() { + Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }) + } else { + None + }, + input_compressed_accounts_with_merkle_context: if i % 5 == 0 { + // Only add inputs occasionally to keep test manageable + vec![get_rnd_test_input_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + output_compressed_accounts: if i % 4 == 0 { + vec![get_rnd_test_output_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + relay_fee: None, // Relay fee is currently not supported + new_address_params: if i % 3 == 0 { + vec![get_rnd_new_address_params(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + compress_or_decompress_lamports: if rng.gen() { Some(rng.gen()) } else { None }, + is_compress: rng.gen(), + }; + + // 1. Generate config based on the random data + let config = generate_random_invoke_config(&invoke_ref); + + // 2. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 3. Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 4. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 5. 
Populate the mutable zero-copy structure with random data + populate_invoke_zero_copy(&invoke_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvoke::deserialize(&mut bytes.as_slice()).unwrap(); + // 6. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvoke::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_ref, borsh_ref); + + // 7. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvoke with {} inputs, {} outputs, {} new_addresses", + invoke_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_ref.output_compressed_accounts.len(), + invoke_ref.new_address_params.len()); + } +} + +#[test] +fn test_instruction_data_invoke_cpi_rnd() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. + println!("\n\ne2e test seed {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 10_000; + for _ in 0..num_iters { + // 1. Generate random CPI instruction data + let invoke_cpi_ref = get_rnd_instruction_data_invoke_cpi(&mut rng); + + // 2. Generate config based on the random data + let config = generate_random_invoke_cpi_config(&invoke_cpi_ref); + + // 3. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvokeCpi::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 4. 
Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvokeCpi::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create CPI zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 5. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 6. Populate the mutable zero-copy structure with random data + populate_invoke_cpi_zero_copy(&invoke_cpi_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvokeCpi::deserialize(&mut bytes.as_slice()).unwrap(); + // 7. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvokeCpi::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_cpi_ref, borsh_ref); + + // 8. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvokeCpi with {} inputs, {} outputs, {} new_addresses", + invoke_cpi_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_cpi_ref.output_compressed_accounts.len(), + invoke_cpi_ref.new_address_params.len()); + } +} diff --git a/program-libs/zero-copy/Cargo.toml b/program-libs/zero-copy/Cargo.toml index ea683e5e48..7b67262bb2 100644 --- a/program-libs/zero-copy/Cargo.toml +++ b/program-libs/zero-copy/Cargo.toml @@ -11,13 +11,17 @@ default = [] solana = ["solana-program-error"] pinocchio = ["dep:pinocchio"] std = [] +derive = ["light-zero-copy-derive"] +mut = ["light-zero-copy-derive/mut"] [dependencies] solana-program-error = { workspace = true, optional = true } pinocchio = { workspace = true, optional = true } thiserror = { workspace = true } zerocopy = { workspace = true } +light-zero-copy-derive = { workspace = true, optional = true } 
[dev-dependencies] rand = { workspace = true } zerocopy = { workspace = true, features = ["derive"] } +borsh = { workspace = true } diff --git a/program-libs/zero-copy/README.md b/program-libs/zero-copy/README.md index d82ee39232..e28f535469 100644 --- a/program-libs/zero-copy/README.md +++ b/program-libs/zero-copy/README.md @@ -37,6 +37,3 @@ light-zero-copy = { version = "0.1.0", features = ["anchor"] } ### Security Considerations - do not use on a 32 bit target with length greater than u32 - only length until u64 is supported - -### Tests -- `cargo test --features std` diff --git a/program-libs/zero-copy/src/borsh.rs b/program-libs/zero-copy/src/borsh.rs index c7e4fbe4db..a60d73f2ba 100644 --- a/program-libs/zero-copy/src/borsh.rs +++ b/program-libs/zero-copy/src/borsh.rs @@ -5,7 +5,7 @@ use core::{ use std::vec::Vec; use zerocopy::{ - little_endian::{U16, U32, U64}, + little_endian::{I16, I32, I64, U16, U32, U64}, FromBytes, Immutable, KnownLayout, Ref, }; @@ -52,8 +52,6 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for Option { impl Deserialize<'_> for u8 { type Output = Self; - /// Not a zero copy but cheaper. - /// A u8 should not be deserialized on it's own but as part of a struct. #[inline] fn zero_copy_at(bytes: &[u8]) -> Result<(u8, &[u8]), ZeroCopyError> { if bytes.len() < size_of::() { @@ -64,23 +62,59 @@ impl Deserialize<'_> for u8 { } } +impl<'a> Deserialize<'a> for bool { + type Output = u8; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + let (bytes, remaining_bytes) = bytes.split_at(size_of::()); + Ok((bytes[0], remaining_bytes)) + } +} + macro_rules! 
impl_deserialize_for_primitive { - ($($t:ty),*) => { + ($(($native:ty, $zerocopy:ty)),*) => { $( - impl<'a> Deserialize<'a> for $t { - type Output = Ref<&'a [u8], $t>; + impl<'a> Deserialize<'a> for $native { + type Output = Ref<&'a [u8], $zerocopy>; #[inline] fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { - Self::Output::zero_copy_at(bytes) + Ref::<&'a [u8], $zerocopy>::from_prefix(bytes).map_err(ZeroCopyError::from) } } )* }; } -impl_deserialize_for_primitive!(u16, i16, u32, i32, u64, i64); -impl_deserialize_for_primitive!(U16, U32, U64); +impl_deserialize_for_primitive!( + (u16, U16), + (u32, U32), + (u64, U64), + (i16, I16), + (i32, I32), + (i64, I64), + (U16, U16), + (U32, U32), + (U64, U64), + (I16, I16), + (I32, I32), + (I64, I64) +); + +// Implement Deserialize for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> Deserialize<'a> for [T; N] { + type Output = Ref<&'a [u8], [T; N]>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} impl<'a, T: Deserialize<'a>> Deserialize<'a> for Vec { type Output = Vec; @@ -88,8 +122,14 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for Vec { fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } let mut slices = 
Vec::with_capacity(num_slices); for _ in 0..num_slices { let (slice, _bytes) = T::zero_copy_at(bytes)?; @@ -138,6 +178,55 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for VecU8 { } } +pub trait ZeroCopyStructInner { + type ZeroCopyInner; +} + +impl ZeroCopyStructInner for u64 { + type ZeroCopyInner = U64; +} +impl ZeroCopyStructInner for u32 { + type ZeroCopyInner = U32; +} +impl ZeroCopyStructInner for u16 { + type ZeroCopyInner = U16; +} +impl ZeroCopyStructInner for u8 { + type ZeroCopyInner = u8; +} + +impl ZeroCopyStructInner for U16 { + type ZeroCopyInner = U16; +} +impl ZeroCopyStructInner for U32 { + type ZeroCopyInner = U32; +} +impl ZeroCopyStructInner for U64 { + type ZeroCopyInner = U64; +} + +impl ZeroCopyStructInner for Vec { + type ZeroCopyInner = Vec; +} + +impl ZeroCopyStructInner for Option { + type ZeroCopyInner = Option; +} + +// Add ZeroCopyStructInner for array types +impl ZeroCopyStructInner for [u8; N] { + type ZeroCopyInner = Ref<&'static [u8], [u8; N]>; +} + +pub fn borsh_vec_u8_as_slice(bytes: &[u8]) -> Result<(&[u8], &[u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + if num_slices > bytes.len() { + return Err(ZeroCopyError::ArraySize(num_slices, bytes.len())); + } + Ok(bytes.split_at(num_slices)) +} + #[test] fn test_vecu8() { use std::vec; @@ -224,3 +313,561 @@ fn test_deserialize_vecu8() { assert_eq!(vec, std::vec![4, 5, 6]); assert_eq!(remaining, &[]); } + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::{ZeroCopyStructInner, *}; + use crate::slice::ZeroCopySliceBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. 
represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize + // 1.6. Elements in an Option must implement Deserialize + // 1.7. a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInner + // 2. Deserialize + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. + + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + // pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + + // } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, 
Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, + } + + impl<'a> 
Deref for ZStruct3<'a> { + type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + } + + impl<'a> Deref for 
ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + #[test] + fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, + } + + impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::ZeroCopyInner>>::zero_copy_at(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + 
assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> 
{ + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, + } + + impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), 
ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } +} diff --git a/program-libs/zero-copy/src/borsh_mut.rs b/program-libs/zero-copy/src/borsh_mut.rs new file mode 100644 index 0000000000..38d5df2b65 --- /dev/null +++ b/program-libs/zero-copy/src/borsh_mut.rs @@ -0,0 +1,965 @@ +use core::{ + mem::size_of, + ops::{Deref, DerefMut}, +}; +use std::vec::Vec; + +use zerocopy::{ + little_endian::{I16, I32, I64, U16, U32, U64}, + FromBytes, Immutable, KnownLayout, Ref, +}; + +use crate::errors::ZeroCopyError; + +pub trait DeserializeMut<'a> +where + Self: Sized, +{ + // TODO: rename to ZeroCopy, can be used as ::ZeroCopy + type Output; + fn zero_copy_at_mut(bytes: &'a mut [u8]) + -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>; +} + +// Implement DeserializeMut for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> DeserializeMut<'a> for [T; N] { + type Output = Ref<&'a mut [u8], [T; N]>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a mut [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Option { + type Output = Option; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, 
bytes.len())); + } + let (option_byte, bytes) = bytes.split_at_mut(1); + Ok(match option_byte[0] { + 0u8 => (None, bytes), + 1u8 => { + let (value, bytes) = T::zero_copy_at_mut(bytes)?; + (Some(value), bytes) + } + _ => return Err(ZeroCopyError::InvalidOptionByte(option_byte[0])), + }) + } +} + +impl<'a> DeserializeMut<'a> for u8 { + type Output = Ref<&'a mut [u8], u8>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], u8>::from_prefix(bytes).map_err(ZeroCopyError::from) + } +} + +impl<'a> DeserializeMut<'a> for bool { + type Output = Ref<&'a mut [u8], u8>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], u8>::from_prefix(bytes).map_err(ZeroCopyError::from) + } +} + +// Implementation for specific zerocopy little-endian types +impl<'a, T: KnownLayout + Immutable + FromBytes> DeserializeMut<'a> for Ref<&'a mut [u8], T> { + type Output = Ref<&'a mut [u8], T>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], T>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Vec { + type Output = Vec; + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } + let mut slices = Vec::with_capacity(num_slices); + for _ in 0..num_slices { + let (slice, _bytes) = 
T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +macro_rules! impl_deserialize_for_primitive { + ($(($native:ty, $zerocopy:ty)),*) => { + $( + impl<'a> DeserializeMut<'a> for $native { + type Output = Ref<&'a mut [u8], $zerocopy>; + + #[inline] + fn zero_copy_at_mut(bytes: &'a mut [u8]) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], $zerocopy>::from_prefix(bytes).map_err(ZeroCopyError::from) + } + } + )* + }; +} + +impl_deserialize_for_primitive!( + (u16, U16), + (u32, U32), + (u64, U64), + (i16, I16), + (i32, I32), + (i64, I64), + (U16, U16), + (U32, U32), + (U64, U64), + (I16, I16), + (I32, I32), + (I64, I64) +); + +pub fn borsh_vec_u8_as_slice_mut( + bytes: &mut [u8], +) -> Result<(&mut [u8], &mut [u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + if num_slices > bytes.len() { + return Err(ZeroCopyError::ArraySize(num_slices, bytes.len())); + } + Ok(bytes.split_at_mut(num_slices)) +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct VecU8(Vec); +impl VecU8 { + pub fn new() -> Self { + Self(Vec::new()) + } +} + +impl Deref for VecU8 { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for VecU8 { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for VecU8 { + type Output = Vec; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], u8>::from_prefix(bytes)?; + let mut slices = Vec::with_capacity(*num_slices as usize); + for _ in 0..(*num_slices as usize) { + let (slice, _bytes) = T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +pub trait ZeroCopyStructInnerMut { + type ZeroCopyInnerMut; +} + 
+impl ZeroCopyStructInnerMut for u64 { + type ZeroCopyInnerMut = U64; +} +impl ZeroCopyStructInnerMut for u32 { + type ZeroCopyInnerMut = U32; +} +impl ZeroCopyStructInnerMut for u16 { + type ZeroCopyInnerMut = U16; +} +impl ZeroCopyStructInnerMut for u8 { + type ZeroCopyInnerMut = u8; +} +impl ZeroCopyStructInnerMut for U64 { + type ZeroCopyInnerMut = U64; +} +impl ZeroCopyStructInnerMut for U32 { + type ZeroCopyInnerMut = U32; +} +impl ZeroCopyStructInnerMut for U16 { + type ZeroCopyInnerMut = U16; +} + +impl ZeroCopyStructInnerMut for Vec { + type ZeroCopyInnerMut = Vec; +} + +impl ZeroCopyStructInnerMut for Option { + type ZeroCopyInnerMut = Option; +} + +impl ZeroCopyStructInnerMut for [u8; N] { + type ZeroCopyInnerMut = Ref<&'static mut [u8], [u8; N]>; +} + +#[test] +fn test_vecu8() { + use std::vec; + let mut bytes = vec![8, 1u8, 2, 3, 4, 5, 6, 7, 8]; + let (vec, remaining_bytes) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + vec![1u8, 2, 3, 4, 5, 6, 7, 8] + ); + assert_eq!(remaining_bytes, &mut []); +} + +#[test] +fn test_deserialize_mut_ref() { + let mut bytes = [1, 0, 0, 0]; // Little-endian representation of 1 + let (ref_data, remaining) = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(u32::from(*ref_data), 1); + assert_eq!(remaining, &mut []); + let res = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_option_some() { + let mut bytes = [1, 2]; // 1 indicates Some, followed by the value 2 + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value.map(|x| *x), Some(2)); + assert_eq!(remaining, &mut []); + let res = Option::::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::ArraySize(1, 0))); + let mut bytes = [2, 0]; // 2 indicates invalid option byte + let res = Option::::zero_copy_at_mut(&mut bytes); + assert_eq!(res, 
Err(ZeroCopyError::InvalidOptionByte(2))); +} + +#[test] +fn test_deserialize_mut_option_none() { + let mut bytes = [0]; // 0 indicates None + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value, None); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_deserialize_mut_u8() { + let mut bytes = [0xFF]; // Value 255 + let (value, remaining) = u8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(*value, 255); + assert_eq!(remaining, &mut []); + let res = u8::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_u16() { + let mut bytes = 2323u16.to_le_bytes(); + let (value, remaining) = u16::zero_copy_at_mut(bytes.as_mut_slice()).unwrap(); + assert_eq!(*value, 2323u16); + assert_eq!(remaining, &mut []); + let mut value = [0u8]; + let res = u16::zero_copy_at_mut(&mut value); + + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_vec() { + let mut bytes = [2, 0, 0, 0, 1, 2]; // Length 2, followed by values 1 and 2 + let (vec, remaining) = Vec::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + std::vec![1u8, 2] + ); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_vecu8_deref() { + let data = std::vec![1, 2, 3]; + let vec_u8 = VecU8(data.clone()); + assert_eq!(&*vec_u8, &data); + + let mut vec = VecU8::new(); + vec.push(1u8); + assert_eq!(*vec, std::vec![1u8]); +} + +#[test] +fn test_deserialize_mut_vecu8() { + let mut bytes = [3, 4, 5, 6]; // Length 3, followed by values 4, 5, 6 + let (vec, remaining) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + std::vec![4, 5, 6] + ); + assert_eq!(remaining, &mut []); +} + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::*; + use 
crate::slice_mut::ZeroCopySliceMutBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement DeserializeMut + // 1.6. Elements in an Option must implement DeserializeMut + // 1.7. a type that does not implement Copy must implement DeserializeMut, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInnerMut + // 2. DeserializeMut + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. 
+ + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes, IntoBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + pub meta: Ref<&'a mut [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a mut [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl DerefMut for ZStruct1<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let ref_struct = Struct1 { a: 1, b: 2 }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (mut struct1, remaining) = Struct1::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &mut []); + struct1.meta.a = 2; + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a mut [u8], ZStruct2Meta>, + pub vec: &'a mut [u8], + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() 
{ + return false; + } + self.vec == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a mut [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct2Meta>::from_prefix(bytes)?; + let (len, bytes) = bytes.split_at_mut(4); + let len = U32::from_bytes( + len.try_into() + .map_err(|_| ZeroCopyError::ArraySize(4, len.len()))?, + ); + let (vec, bytes) = bytes.split_at_mut(u32::from(len) as usize); + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let ref_struct = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a mut [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, u8>, + pub c: Ref<&'a mut [u8], U64>, + } + + impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a mut [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], 
ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::zero_copy_at_mut(bytes)?; + let (c, bytes) = Ref::<&mut [u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let ref_struct = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + impl<'a> DeserializeMut<'a> for Struct4Nested { + type Output = ZStruct4Nested; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], ZStruct4Nested>::from_prefix(bytes)?; + Ok((*bytes, remaining_bytes)) + } + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInnerMut for Struct4Nested { + type ZeroCopyInnerMut = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInnerMut, + pub b: ::ZeroCopyInnerMut, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a mut [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + pub c: Ref<&'a mut [u8], ::ZeroCopyInnerMut>, + pub 
vec_2: + ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + } + + impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a mut [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&mut [u8], ::ZeroCopyInnerMut>::from_prefix( + bytes, + )?; + let (vec_2, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + /// TODO: + /// - add SIZE const generic DeserializeMut trait + /// - add new with data function + impl Struct4 { + // pub fn byte_len(&self) -> usize { + // size_of::() + // + size_of::() + // + size_of::() * self.vec.len() + // + size_of::() + // + size_of::() * self.vec_2.len() + // } + + pub fn new_with_data<'a>( + bytes: &'a mut [u8], + data: &Struct4, + ) -> (ZStruct4<'a>, &'a mut [u8]) { + let (mut zero_copy, bytes) = + ::zero_copy_at_mut(bytes).unwrap(); + zero_copy.meta.a = data.a; + zero_copy.meta.b = data.b.into(); + zero_copy + .vec + .iter_mut() + .zip(data.vec.iter()) + .for_each(|(x, y)| *x = *y); + (zero_copy, bytes) + } + } + + #[test] + fn test_struct_4() { + let ref_struct = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, 
&mut []); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInnerMut>>, + } + + impl<'a> DeserializeMut<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::< + ZeroCopySliceMutBorsh<::ZeroCopyInnerMut>, + >::zero_copy_at_mut(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let ref_struct = Struct5 { + a: vec![vec![1u8; 32]; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. 
+ #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let ref_struct = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a mut [u8], ZStruct7Meta>, + pub option: Option<>::Output>, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option.as_ref().map(|x| **x) == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a mut [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, 
bytes) = Ref::<&mut [u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as DeserializeMut<'a>>::zero_copy_at_mut(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let ref_struct = Struct7 { + a: 1, + b: 2, + option: Some(3), + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option.map(|x| *x), Some(3)); + assert_eq!(remaining, &mut []); + + let ref_struct = Struct7 { + a: 1, + b: 2, + option: None, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. 
+ #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: >::Output, + pub b: >::Output, + } + + impl<'a> DeserializeMut<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = + ::ZeroCopyInnerMut::zero_copy_at_mut(bytes)?; + let (b, bytes) = >::zero_copy_at_mut(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + *self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let ref_struct = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } +} diff --git a/program-libs/zero-copy/src/init_mut.rs b/program-libs/zero-copy/src/init_mut.rs new file mode 100644 index 0000000000..c16d371176 --- /dev/null +++ b/program-libs/zero-copy/src/init_mut.rs @@ -0,0 +1,268 @@ +use 
core::mem::size_of; +use std::vec::Vec; + +use crate::{borsh_mut::DeserializeMut, errors::ZeroCopyError}; + +/// Trait for types that can be initialized in mutable byte slices with configuration +/// +/// This trait provides a way to initialize structures in pre-allocated byte buffers +/// with specific configuration parameters that determine Vec lengths, Option states, etc. +pub trait ZeroCopyNew<'a> +where + Self: Sized, +{ + /// Configuration type needed to initialize this type + type Config; + + /// Output type - the mutable zero-copy view of this type + type Output; + + /// Calculate the byte length needed for this type with the given configuration + /// + /// This is essential for allocating the correct buffer size before calling new_zero_copy + fn byte_len(config: &Self::Config) -> usize; + + /// Initialize this type in a mutable byte slice with the given configuration + /// + /// Returns the initialized mutable view and remaining bytes + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>; +} + +// Generic implementation for Option +impl<'a, T> ZeroCopyNew<'a> for Option +where + T: ZeroCopyNew<'a>, +{ + type Config = (bool, T::Config); // (enabled, inner_config) + type Output = Option; + + fn byte_len(config: &Self::Config) -> usize { + let (enabled, inner_config) = config; + if *enabled { + // 1 byte for Some discriminant + inner type's byte_len + 1 + T::byte_len(inner_config) + } else { + // Just 1 byte for None discriminant + 1 + } + } + + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + if bytes.is_empty() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + + let (enabled, inner_config) = config; + + if enabled { + bytes[0] = 1; // Some discriminant + let (_, bytes) = bytes.split_at_mut(1); + let (value, bytes) = T::new_zero_copy(bytes, inner_config)?; + Ok((Some(value), bytes)) + } else { + 
bytes[0] = 0; // None discriminant + let (_, bytes) = bytes.split_at_mut(1); + Ok((None, bytes)) + } + } +} + +// Implementation for primitive types (no configuration needed) +impl<'a> ZeroCopyNew<'a> for u64 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U64 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for u32 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U32 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for u16 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U16 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?) 
+ } +} + +impl<'a> ZeroCopyNew<'a> for u8 { + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Use the DeserializeMut trait to create the proper output + Self::zero_copy_at_mut(bytes) + } +} + +impl<'a> ZeroCopyNew<'a> for bool { + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() // bool is serialized as u8 + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Treat bool as u8 + u8::zero_copy_at_mut(bytes) + } +} + +// Implementation for fixed-size arrays +impl< + 'a, + T: Copy + Default + zerocopy::KnownLayout + zerocopy::Immutable + zerocopy::FromBytes, + const N: usize, + > ZeroCopyNew<'a> for [T; N] +{ + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Use the DeserializeMut trait to create the proper output + Self::zero_copy_at_mut(bytes) + } +} + +// Implementation for zerocopy little-endian types +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U16 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?) 
+ } +} + +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U32 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U64 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?) + } +} + +// Implementation for Vec +impl<'a, T: ZeroCopyNew<'a>> ZeroCopyNew<'a> for Vec { + type Config = Vec; // Vector of configs for each item + type Output = Vec; + + fn byte_len(config: &Self::Config) -> usize { + // 4 bytes for length prefix + sum of byte_len for each element config + 4 + config + .iter() + .map(|config| T::byte_len(config)) + .sum::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + configs: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + use zerocopy::{little_endian::U32, Ref}; + + // Write length as U32 + let len = configs.len() as u32; + let (mut len_ref, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + *len_ref = U32::new(len); + + // Initialize each item with its config + let mut items = Vec::with_capacity(configs.len()); + for config in configs { + let (item, remaining_bytes) = T::new_zero_copy(bytes, config)?; + bytes = remaining_bytes; + items.push(item); + } + + Ok((items, bytes)) + } +} diff --git a/program-libs/zero-copy/src/lib.rs b/program-libs/zero-copy/src/lib.rs index 297c849d53..3ac6a38948 100644 --- 
a/program-libs/zero-copy/src/lib.rs +++ b/program-libs/zero-copy/src/lib.rs @@ -10,8 +10,24 @@ pub mod vec; use core::mem::{align_of, size_of}; #[cfg(feature = "std")] pub mod borsh; - -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; +#[cfg(feature = "std")] +pub mod borsh_mut; +#[cfg(feature = "std")] +pub mod init_mut; +#[cfg(feature = "std")] +pub use borsh::ZeroCopyStructInner; +#[cfg(feature = "std")] +pub use init_mut::ZeroCopyNew; +#[cfg(all(feature = "derive", feature = "mut"))] +pub use light_zero_copy_derive::ZeroCopyMut; +#[cfg(feature = "derive")] +pub use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +#[cfg(feature = "derive")] +pub use zerocopy::{ + little_endian::{self, U16, U32, U64}, + Ref, Unaligned, +}; +pub use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; #[cfg(feature = "std")] extern crate std; diff --git a/program-libs/zero-copy/src/slice_mut.rs b/program-libs/zero-copy/src/slice_mut.rs index 27cd2f776a..7a50b7e44d 100644 --- a/program-libs/zero-copy/src/slice_mut.rs +++ b/program-libs/zero-copy/src/slice_mut.rs @@ -276,3 +276,16 @@ where write!(f, "{:?}", self.as_slice()) } } + +#[cfg(feature = "std")] +impl<'a, T: ZeroCopyTraits + crate::borsh_mut::DeserializeMut<'a>> + crate::borsh_mut::DeserializeMut<'a> for ZeroCopySliceMutBorsh<'_, T> +{ + type Output = ZeroCopySliceMutBorsh<'a, T>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + ZeroCopySliceMutBorsh::from_bytes_at(bytes) + } +} diff --git a/program-libs/zero-copy/tests/borsh.rs b/program-libs/zero-copy/tests/borsh.rs new file mode 100644 index 0000000000..071b4e8df2 --- /dev/null +++ b/program-libs/zero-copy/tests/borsh.rs @@ -0,0 +1,335 @@ +#![cfg(all(feature = "std", feature = "derive", feature = "mut"))] +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{ + borsh::Deserialize, borsh_mut::DeserializeMut, ZeroCopy, ZeroCopyEq, ZeroCopyMut, +}; + +#[repr(C)] +#[derive(Debug, 
PartialEq, ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshDeserialize, BorshSerialize)] +pub struct Struct1Derived { + pub a: u8, + pub b: u16, +} + +#[test] +fn test_struct_1_derived() { + let ref_struct = Struct1Derived { a: 1, b: 2 }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + { + let (struct1, remaining) = Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(struct1, ref_struct); + assert_eq!(remaining, &[]); + } + { + let (mut struct1, _) = Struct1Derived::zero_copy_at_mut(&mut bytes).unwrap(); + struct1.a = 2; + struct1.b = 3.into(); + } + let borsh = Struct1Derived::deserialize(&mut &bytes[..]).unwrap(); + let (struct_1, _) = Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct_1.a, 2); // Modified value from mutable operations + assert_eq!(struct_1.b, 3); // Modified value from mutable operations + assert_eq!(struct_1, borsh); +} + +// Struct2 equivalent: Manual implementation that should match Struct2 +#[repr(C)] +#[derive( + Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct2Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[test] +fn test_struct_2_derived() { + let ref_struct = Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + assert_eq!(struct2, ref_struct); +} + +// Struct3 equivalent: fields should match Struct3 +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct3Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[test] +fn test_struct_3_derived() { + let ref_struct = Struct3Derived { + a: 1, + b: 2, + vec: vec![1u8; 
32], + c: 3, + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4NestedDerived { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[test] +fn test_struct_4_derived() { + let ref_struct = Struct4Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4NestedDerived { a: 1, b: 2 }; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + // Check vec_2 length is correct + assert_eq!(zero_copy.vec_2.len(), 32); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct5Derived { + pub a: Vec>, +} + +#[test] +fn test_struct_5_derived() { + let ref_struct = Struct5Derived { + a: vec![vec![1u8; 32]; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +// 
If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct6Derived { + pub a: Vec, +} + +#[test] +fn test_struct_6_derived() { + let ref_struct = Struct6Derived { + a: vec![ + Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct Struct7Derived { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[test] +fn test_struct_7_derived() { + let ref_struct = Struct7Derived { + a: 1, + b: 2, + option: Some(3), + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7Derived { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. 
+#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct8Derived { + pub a: Vec, +} + +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct NestedStructDerived { + pub a: u8, + pub b: Struct2Derived, +} + +#[test] +fn test_struct_8_derived() { + let ref_struct = Struct8Derived { + a: vec![ + NestedStructDerived { + a: 1, + b: Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8Derived::zero_copy_at(&bytes).unwrap(); + // Check length of vec matches + assert_eq!(zero_copy.a.len(), 32); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshSerialize, BorshDeserialize, PartialEq, Debug)] +pub struct ArrayStruct { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +#[test] +fn test_array_struct() -> Result<(), Box> { + let array_struct = ArrayStruct { + a: [1u8; 32], + b: [2u8; 64], + c: [3u8; 32], + }; + let bytes = array_struct.try_to_vec()?; + + let (zero_copy, remaining) = ArrayStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, [1u8; 32]); + assert_eq!(zero_copy.b, [2u8; 64]); + assert_eq!(zero_copy.c, [3u8; 32]); + assert_eq!(zero_copy, array_struct); + assert_eq!(remaining, &[]); + Ok(()) +} + +#[derive( + Debug, + PartialEq, + Default, + Clone, + BorshSerialize, + BorshDeserialize, + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +#[test] +fn test_compressed_account_data() { + let compressed_account_data = CompressedAccountData { + discriminator: [1u8; 8], + data: vec![2u8; 32], + data_hash: [3u8; 32], + }; + let bytes = compressed_account_data.try_to_vec().unwrap(); + + let (zero_copy, 
remaining) = CompressedAccountData::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.discriminator, [1u8; 8]); + // assert_eq!(zero_copy.data, compressed_account_data.data.as_slice()); + assert_eq!(*zero_copy.data_hash, [3u8; 32]); + assert_eq!(zero_copy, compressed_account_data); + assert_eq!(remaining, &[]); +} diff --git a/program-libs/zero-copy/tests/borsh_2.rs b/program-libs/zero-copy/tests/borsh_2.rs new file mode 100644 index 0000000000..aece86bb1c --- /dev/null +++ b/program-libs/zero-copy/tests/borsh_2.rs @@ -0,0 +1,559 @@ +#![cfg(all(feature = "std", feature = "derive"))] + +use std::{ops::Deref, vec}; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{ + borsh::Deserialize, errors::ZeroCopyError, slice::ZeroCopySliceBorsh, ZeroCopyStructInner, +}; +use zerocopy::{ + little_endian::{U16, U64}, + FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +// Rules: +// 1. create ZStruct for the struct +// 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct +// 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct +// 1.3. replace u16 with U16, u32 with U32, etc +// 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +// 1.6. Elements in an Option must implement Deserialize +// 1.7. a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + +// Derive Macro needs to derive: +// 1. ZeroCopyStructInner +// 2. Deserialize +// 3. PartialEq for ZStruct<'_> +// +// For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. 
+ +// Tests for manually implemented structures (without derive macro) + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct1 { + pub a: u8, + pub b: u16, +} + +// pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + +// } + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, +} +impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } +} + +#[test] +fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } +} + +impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) 
-> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } +} + +#[test] +fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, +} + +impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((ZStruct3 { meta, vec, c }, bytes)) + } +} + +#[test] +fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 
2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] +pub struct Struct4Nested { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, +)] +pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, +} + +impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] +pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, +} + +impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + ZStruct4 { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } +} + +#[test] +fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, 
remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct5 { + pub a: Vec>, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, +} + +impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = + Vec::::ZeroCopyInner>>::zero_copy_at( + bytes, + )?; + Ok((ZStruct5 { a }, bytes)) + } +} + +#[test] +fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. 
+#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct6 { + pub a: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } +} + +#[test] +fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } +} + +impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + 
Ok((ZStruct7 { meta, option }, bytes)) + } +} + +#[test] +fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct8 { + pub a: Vec, +} + +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct NestedStruct { + pub a: u8, + pub b: Struct2, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, +} + +impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } +} + +impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } +} + +#[test] +fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + 
NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +}