diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index efa24bff78..2054edec76 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -53,7 +53,8 @@ jobs: cargo test -p light-account-checks --all-features cargo test -p light-verifier --all-features cargo test -p light-merkle-tree-metadata --all-features - cargo test -p light-zero-copy --features std + cargo test -p light-zero-copy --features "std, mut, derive" + cargo test -p light-zero-copy-derive --features "mut" cargo test -p light-hash-set --all-features - name: program-libs-slow packages: light-bloom-filter light-indexed-merkle-tree light-batched-merkle-tree diff --git a/.gitignore b/.gitignore index b7e754f3b1..5129907005 100644 --- a/.gitignore +++ b/.gitignore @@ -86,3 +86,5 @@ output1.txt .zed **/.claude/**/* + +expand.rs diff --git a/Cargo.lock b/Cargo.lock index 03057b4669..84f5b3b969 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -250,6 +250,27 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "anchor-compressed-token" +version = "2.0.0" +dependencies = [ + "account-compression", + "anchor-lang", + "anchor-spl", + "light-compressed-account", + "light-hasher", + "light-heap", + "light-system-program-anchor", + "light-zero-copy", + "num-bigint 0.4.6", + "rand 0.8.5", + "solana-sdk", + "solana-security-txt", + "spl-token", + "spl-token-2022 7.0.0", + "zerocopy", +] + [[package]] name = "anchor-derive-accounts" version = "0.31.1" @@ -1323,6 +1344,7 @@ dependencies = [ "rand 0.8.5", "serial_test", "solana-sdk", + "spl-pod", "spl-token", "tokio", ] @@ -2365,6 +2387,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "governor" version = 
"0.6.3" @@ -3360,17 +3388,25 @@ name = "light-compressed-token" version = "2.0.0" dependencies = [ "account-compression", + "anchor-compressed-token", "anchor-lang", - "anchor-spl", + "arrayvec", + "borsh 0.10.4", + "light-account-checks", "light-compressed-account", "light-hasher", "light-heap", + "light-sdk", + "light-sdk-pinocchio", + "light-sdk-types", "light-system-program-anchor", "light-zero-copy", "num-bigint 0.4.6", + "pinocchio", "rand 0.8.5", - "solana-sdk", + "solana-pubkey", "solana-security-txt", + "spl-pod", "spl-token", "spl-token-2022 7.0.0", "zerocopy", @@ -3786,6 +3822,8 @@ dependencies = [ name = "light-zero-copy" version = "0.2.0" dependencies = [ + "borsh 0.10.4", + "light-zero-copy-derive", "pinocchio", "rand 0.8.5", "solana-program-error", @@ -3793,6 +3831,21 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "light-zero-copy-derive" +version = "0.1.0" +dependencies = [ + "borsh 0.10.4", + "lazy_static", + "light-zero-copy", + "proc-macro2", + "quote", + "rand 0.8.5", + "syn 2.0.103", + "trybuild", + "zerocopy", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -9054,6 +9107,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-triple" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" + [[package]] name = "tarpc" version = "0.29.0" @@ -9626,6 +9685,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c9bf9513a2f4aeef5fdac8677d7d349c79fdbcc03b9c86da6e9d254f1e43be2" +dependencies = [ + "glob", + "serde", + "serde_derive", + 
"serde_json", + "target-triple", + "termcolor", + "toml 0.8.23", +] + [[package]] name = "tungstenite" version = "0.20.1" diff --git a/Cargo.toml b/Cargo.toml index 176115adb1..527de1de46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,11 @@ members = [ "program-libs/hash-set", "program-libs/indexed-merkle-tree", "program-libs/indexed-array", + "program-libs/zero-copy-derive", "programs/account-compression", "programs/system", - "programs/compressed-token", + "programs/compressed-token/program", + "programs/compressed-token/anchor", "programs/registry", "anchor-programs/system", "sdk-libs/client", @@ -105,6 +107,7 @@ solana-system-interface = { version = "1" } solana-security-txt = "1.1.1" spl-token = "7.0.0" spl-token-2022 = { version = "7", features = ["no-entrypoint"] } +spl-pod = "0.5.1" pinocchio = { version = "0.8.4" } bs58 = "^0.5.1" litesvm = "0.6.1" @@ -167,12 +170,13 @@ light-compressed-account = { path = "program-libs/compressed-account", version = light-account-checks = { path = "program-libs/account-checks", version = "0.3.0" } light-verifier = { path = "program-libs/verifier", version = "2.1.0" } light-zero-copy = { path = "program-libs/zero-copy", version = "0.2.0" } +light-zero-copy-derive = { path = "program-libs/zero-copy-derive", version = "0.1.0" } photon-api = { path = "sdk-libs/photon-api", version = "0.51.0" } forester-utils = { path = "forester-utils", version = "2.0.0" } account-compression = { path = "programs/account-compression", version = "2.0.0", features = [ "cpi", ] } -light-compressed-token = { path = "programs/compressed-token", version = "2.0.0", features = [ +light-compressed-token = { path = "programs/compressed-token/program", version = "2.0.0", features = [ "cpi", ] } light-system-program-anchor = { path = "anchor-programs/system", version = "2.0.0", features = [ diff --git a/metadata.md b/metadata.md new file mode 100644 index 0000000000..520a17636f --- /dev/null +++ b/metadata.md @@ -0,0 +1,242 @@ +# Token 2022 
Metadata Pointer Extension Analysis + +## Overview +The Token 2022 metadata pointer extension provides a mechanism for SPL Token 2022 mints to reference metadata accounts using a **Type-Length-Value (TLV)** encoding system. This allows metadata to be stored either directly in the mint account or pointed to external metadata accounts. + +## Core Architecture + +### 1. MetadataPointer Extension Structure +```rust +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +pub struct MetadataPointer { + /// Authority that can set the metadata address + pub authority: OptionalNonZeroPubkey, + /// Account address that holds the metadata + pub metadata_address: OptionalNonZeroPubkey, +} +``` + +### 2. TLV Extension System +Extensions are stored using TLV format: +- **Type**: 2 bytes (ExtensionType enum) +- **Length**: 2 bytes (data length) +- **Value**: Variable length data + +Account layout: +``` +[Base Mint: 82 bytes][Padding: 83 bytes][Account Type: 1 byte][TLV Extensions...] +``` + +### 3. Extension Types +- `MetadataPointer`: Points to metadata account +- `TokenMetadata`: Contains metadata directly +- Extensions are parsed sequentially through TLV data + +## Token 2022 Metadata Account Structure + +The account that a `MetadataPointer` points to contains the actual `TokenMetadata` stored in a **TLV (Type-Length-Value)** format. 
Here's the detailed structure: + +### Account Layout + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Complete Account Structure │ +├─────────────────────────────────────────────────────────────────┤ +│ Base Mint Data (82 bytes) │ +│ ┌─ supply: u64 │ +│ ├─ decimals: u8 │ +│ ├─ is_initialized: bool │ +│ ├─ freeze_authority: Option │ +│ └─ mint_authority: Option │ +├─────────────────────────────────────────────────────────────────┤ +│ Extension Data (Variable Length) │ +│ │ +│ ┌─ MetadataPointer Extension (TLV Entry) │ +│ │ ├─ Type: ExtensionType::MetadataPointer (2 bytes) │ +│ │ ├─ Length: 64 (4 bytes) │ +│ │ └─ Value: MetadataPointer struct (64 bytes) │ +│ │ ├─ authority: OptionalNonZeroPubkey (32 bytes) │ +│ │ └─ metadata_address: OptionalNonZeroPubkey (32 bytes) │ +│ │ │ +│ └─ TokenMetadata Extension (TLV Entry) │ +│ ├─ Type: ExtensionType::TokenMetadata (2 bytes) │ +│ ├─ Length: Variable (4 bytes) │ +│ └─ Value: Borsh-serialized TokenMetadata │ +│ ├─ update_authority: OptionalNonZeroPubkey (32 bytes) │ +│ ├─ mint: Pubkey (32 bytes) │ +│ ├─ name: String (4 bytes length + data) │ +│ ├─ symbol: String (4 bytes length + data) │ +│ ├─ uri: String (4 bytes length + data) │ +│ └─ additional_metadata: Vec<(String, String)> │ +│ └─ (4 bytes count + entries) │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### TokenMetadata Structure Details + +```rust +#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize)] +pub struct TokenMetadata { + /// Authority that can update the metadata + pub update_authority: OptionalNonZeroPubkey, + /// Associated mint (prevents spoofing) + pub mint: Pubkey, + /// Token name (e.g., "Solana Token") + pub name: String, + /// Token symbol (e.g., "SOL") + pub symbol: String, + /// URI to external metadata JSON + pub uri: String, + /// Additional key-value pairs + pub additional_metadata: Vec<(String, String)>, +} +``` + +### Two Storage Patterns + +#### Pattern 1: 
Self-Referential (Common) +``` +Mint Account (Same Account) +├─ MetadataPointer Extension +│ └─ metadata_address: [points to same account] +└─ TokenMetadata Extension + └─ [actual metadata data] +``` + +#### Pattern 2: External Account +``` +Mint Account External Metadata Account +├─ MetadataPointer Extension ├─ TokenMetadata Extension +│ └─ metadata_address ────────→│ └─ [actual metadata data] +└─ [no TokenMetadata] └─ [account owned by token program] +``` + +### Serialization Format + +The `TokenMetadata` is serialized using **Borsh** format: +- **Discriminator**: `[112, 132, 90, 90, 11, 88, 157, 87]` (not stored in account) +- **Variable Length**: Strings and Vec fields make the size dynamic +- **TLV Wrapper**: Type + Length headers allow efficient parsing + +## Key Functions + +### Metadata Creation Process +1. **Initialize MetadataPointer**: Set authority and metadata address +2. **Create/Update Metadata**: Store metadata in referenced account +3. **Authority Validation**: Ensure proper permissions for updates + +### Extension Parsing +- Sequential TLV parsing using `get_tlv_indices()` +- Type-based lookup for specific extensions +- Support for both fixed-size (Pod) and variable-length extensions + +## Integration with Compressed Token Mint + +### Current Implementation Analysis +Your compressed token mint in `programs/compressed-token/program/src/mint/state.rs`: + +```rust +pub struct CompressedMint { + pub spl_mint: Pubkey, + pub supply: u64, + pub decimals: u8, + pub is_decompressed: bool, + pub mint_authority: Option, + pub freeze_authority: Option, + pub num_extensions: u8, // ← Already supports extensions! +} +``` + +### Integration Recommendations + +#### 1. 
**Extension Data Structure** +Add metadata pointer extension to your compressed mint: + +```rust +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize)] +pub struct CompressedMintMetadataPointer { + pub authority: Option<Pubkey>, + pub metadata_address: Option<Pubkey>, +} + +// Add to extension system +pub enum CompressedMintExtension { + MetadataPointer(CompressedMintMetadataPointer), + // Other extensions... +} +``` + +#### 2. **Hashing Integration** +The metadata pointer would need to be included in the hash calculation: + +```rust +// In hash_with_hashed_values, add metadata pointer handling +if let Some(metadata_pointer) = metadata_pointer_extension { + // Hash metadata pointer data + let mut metadata_pointer_bytes = [0u8; 32]; + // Set prefix for metadata pointer + metadata_pointer_bytes[30] = 4; // metadata_pointer prefix + // Include in hash_inputs +} +``` + +#### 3. **Processing Integration** +Update `process_create_compressed_mint` to handle metadata pointer: + +```rust +// In processor.rs, add metadata pointer initialization +if let Some(metadata_pointer_data) = parsed_instruction_data.metadata_pointer { + // Validate metadata pointer authority + // Set metadata address + // Update num_extensions count +} +``` + +### Key Considerations + +#### 1. **Compression-Specific Challenges** +- **Hash State**: Metadata pointer must be included in compressed account hash +- **Proof Generation**: Changes to metadata pointer affect merkle tree proofs +- **Extension Counting**: `num_extensions` field needs proper management + +#### 2. **Authority Model** +- Metadata pointer authority separate from mint authority +- Authority validation needed for metadata updates +- Consider compressed account ownership model + +#### 3. **Storage Efficiency** +- Compressed accounts store data efficiently +- Metadata pointer adds minimal overhead (64 bytes) +- Consider storing metadata directly vs. pointer for small metadata + +### Implementation Steps + +1. 
**Define Extension Types**: Create compressed mint extension enum +2. **Update State Structure**: Add extension parsing to CompressedMint +3. **Modify Hash Function**: Include extensions in hash calculation +4. **Update Instructions**: Add metadata pointer initialization/update +5. **Authority Validation**: Implement permission checks +6. **Testing**: Ensure compatibility with existing compressed token functionality + +## Account Reading Process + +```rust +// 1. Load account data +let buffer = account_info.try_borrow_data()?; + +// 2. Parse as mint with extensions +let mint = PodStateWithExtensions::<PodMint>::unpack(&buffer)?; + +// 3. Get metadata pointer +let metadata_pointer = mint.get_extension::<MetadataPointer>()?; + +// 4. If self-referential, read metadata from same account +if metadata_pointer.metadata_address == Some(mint_pubkey) { + let metadata = mint.get_variable_len_extension::<TokenMetadata>()?; +} +``` + +## Summary + +The Token 2022 metadata pointer extension is well-designed for integration with compressed tokens, requiring mainly adaptation of the TLV parsing logic and hash computation for the compressed account model. The metadata account structure is designed for flexibility, allowing metadata to be stored either directly in the mint account or in a separate dedicated account, while maintaining efficient TLV parsing and Borsh serialization. 
\ No newline at end of file diff --git a/program-libs/account-checks/src/account_info/pinocchio.rs b/program-libs/account-checks/src/account_info/pinocchio.rs index 2b4f6ff2a9..b6b7c83134 100644 --- a/program-libs/account-checks/src/account_info/pinocchio.rs +++ b/program-libs/account-checks/src/account_info/pinocchio.rs @@ -19,14 +19,17 @@ impl AccountInfoTrait for pinocchio::account_info::AccountInfo { bytes } + #[inline(always)] fn is_writable(&self) -> bool { self.is_writable() } + #[inline(always)] fn is_signer(&self) -> bool { self.is_signer() } + #[inline(always)] fn executable(&self) -> bool { self.executable() } diff --git a/program-libs/compressed-account/Cargo.toml b/program-libs/compressed-account/Cargo.toml index 8623b20991..95ca99677e 100644 --- a/program-libs/compressed-account/Cargo.toml +++ b/program-libs/compressed-account/Cargo.toml @@ -18,7 +18,7 @@ new-unique = ["dep:solana-pubkey"] thiserror = { workspace = true } zerocopy = { workspace = true, features = ["derive"] } light-hasher = { workspace = true } -light-zero-copy = { workspace = true, features = ["std"] } +light-zero-copy = { workspace = true, features = ["std", "mut", "derive"] } light-macros = { workspace = true } pinocchio = { workspace = true, optional = true } solana-program-error = { workspace = true, optional = true } diff --git a/program-libs/compressed-account/src/compressed_account.rs b/program-libs/compressed-account/src/compressed_account.rs index 62159d135d..69e523362f 100644 --- a/program-libs/compressed-account/src/compressed_account.rs +++ b/program-libs/compressed-account/src/compressed_account.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use light_hasher::{Hasher, Poseidon}; +use light_zero_copy::{ZeroCopy, ZeroCopyMut}; use crate::{ address::pack_account, @@ -11,7 +12,7 @@ use crate::{ AnchorDeserialize, AnchorSerialize, CompressedAccountError, Pubkey, TreeType, }; -#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize)] +#[derive(Debug, 
PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize, ZeroCopyMut)] pub struct PackedCompressedAccountWithMerkleContext { pub compressed_account: CompressedAccount, pub merkle_context: PackedMerkleContext, @@ -133,7 +134,7 @@ pub struct ReadOnlyCompressedAccount { pub root_index: u16, } -#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize)] +#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize, ZeroCopyMut)] pub struct PackedReadOnlyCompressedAccount { pub account_hash: [u8; 32], pub merkle_context: PackedMerkleContext, @@ -149,7 +150,17 @@ pub struct MerkleContext { pub tree_type: TreeType, } -#[derive(Debug, Clone, Copy, AnchorSerialize, AnchorDeserialize, PartialEq, Default)] +#[derive( + Debug, + Clone, + Copy, + AnchorSerialize, + AnchorDeserialize, + PartialEq, + Default, + ZeroCopy, + ZeroCopyMut, +)] pub struct PackedMerkleContext { pub merkle_tree_pubkey_index: u8, pub queue_pubkey_index: u8, @@ -217,7 +228,7 @@ pub fn pack_merkle_context( .collect::>() } -#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize)] +#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize, ZeroCopyMut)] pub struct CompressedAccount { pub owner: Pubkey, pub lamports: u64, @@ -234,7 +245,7 @@ pub struct InCompressedAccount { pub address: Option<[u8; 32]>, } -#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize)] +#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize, ZeroCopyMut)] pub struct CompressedAccountData { pub discriminator: [u8; 8], pub data: Vec, diff --git a/program-libs/compressed-account/src/instruction_data/compressed_proof.rs b/program-libs/compressed-account/src/instruction_data/compressed_proof.rs index 9c79f9ca24..d5c69381d8 100644 --- a/program-libs/compressed-account/src/instruction_data/compressed_proof.rs +++ b/program-libs/compressed-account/src/instruction_data/compressed_proof.rs @@ -1,4 +1,4 @@ -use 
light_zero_copy::{borsh::Deserialize, errors::ZeroCopyError}; +use light_zero_copy::{borsh::Deserialize, errors::ZeroCopyError, ZeroCopyMut}; use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned}; use crate::{AnchorDeserialize, AnchorSerialize}; @@ -17,6 +17,7 @@ use crate::{AnchorDeserialize, AnchorSerialize}; FromBytes, IntoBytes, Unaligned, + ZeroCopyMut, )] pub struct CompressedProof { pub a: [u8; 32], diff --git a/program-libs/compressed-account/src/instruction_data/cpi_context.rs b/program-libs/compressed-account/src/instruction_data/cpi_context.rs index d91a4e11bb..05d9306559 100644 --- a/program-libs/compressed-account/src/instruction_data/cpi_context.rs +++ b/program-libs/compressed-account/src/instruction_data/cpi_context.rs @@ -1,6 +1,10 @@ +use light_zero_copy::ZeroCopyMut; + use crate::{AnchorDeserialize, AnchorSerialize}; -#[derive(AnchorSerialize, AnchorDeserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] +#[derive( + AnchorSerialize, AnchorDeserialize, Debug, Clone, Copy, PartialEq, Eq, Default, ZeroCopyMut, +)] pub struct CompressedCpiContext { /// Is set by the program that is invoking the CPI to signal that is should /// set the cpi context. 
diff --git a/program-libs/compressed-account/src/instruction_data/data.rs b/program-libs/compressed-account/src/instruction_data/data.rs index 4c5ff5c261..e7a50d5d2e 100644 --- a/program-libs/compressed-account/src/instruction_data/data.rs +++ b/program-libs/compressed-account/src/instruction_data/data.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; +use light_zero_copy::ZeroCopyMut; + use crate::{ compressed_account::{CompressedAccount, PackedCompressedAccountWithMerkleContext}, instruction_data::compressed_proof::CompressedProof, @@ -24,13 +26,15 @@ pub struct OutputCompressedAccountWithContext { pub merkle_tree: Pubkey, } -#[derive(Debug, PartialEq, Default, Clone, AnchorDeserialize, AnchorSerialize)] +#[derive(Debug, PartialEq, Default, Clone, AnchorDeserialize, AnchorSerialize, ZeroCopyMut)] pub struct OutputCompressedAccountWithPackedContext { pub compressed_account: CompressedAccount, pub merkle_tree_index: u8, } -#[derive(Debug, PartialEq, Default, Clone, Copy, AnchorDeserialize, AnchorSerialize)] +#[derive( + Debug, PartialEq, Default, Clone, Copy, AnchorDeserialize, AnchorSerialize, ZeroCopyMut, +)] pub struct NewAddressParamsPacked { pub seed: [u8; 32], pub address_queue_account_index: u8, @@ -38,7 +42,9 @@ pub struct NewAddressParamsPacked { pub address_merkle_tree_root_index: u16, } -#[derive(Debug, PartialEq, Default, Clone, Copy, AnchorDeserialize, AnchorSerialize)] +#[derive( + Debug, PartialEq, Default, Clone, Copy, AnchorDeserialize, AnchorSerialize, ZeroCopyMut, +)] pub struct NewAddressParamsAssignedPacked { pub seed: [u8; 32], pub address_queue_account_index: u8, @@ -86,7 +92,9 @@ pub struct NewAddressParamsAssigned { pub assigned_account_index: Option, } -#[derive(Debug, PartialEq, Default, Clone, Copy, AnchorDeserialize, AnchorSerialize)] +#[derive( + Debug, PartialEq, Default, Clone, Copy, AnchorDeserialize, AnchorSerialize, ZeroCopyMut, +)] pub struct PackedReadOnlyAddress { pub address: [u8; 32], pub address_merkle_tree_root_index: u16, 
diff --git a/program-libs/compressed-account/src/instruction_data/invoke_cpi.rs b/program-libs/compressed-account/src/instruction_data/invoke_cpi.rs index eaed16c3cd..59299dcaa1 100644 --- a/program-libs/compressed-account/src/instruction_data/invoke_cpi.rs +++ b/program-libs/compressed-account/src/instruction_data/invoke_cpi.rs @@ -1,3 +1,5 @@ +use light_zero_copy::ZeroCopyMut; + use super::{ cpi_context::CompressedCpiContext, data::{NewAddressParamsPacked, OutputCompressedAccountWithPackedContext}, @@ -8,7 +10,7 @@ use crate::{ }; #[repr(C)] -#[derive(Debug, PartialEq, Default, Clone, AnchorDeserialize, AnchorSerialize)] +#[derive(Debug, PartialEq, Default, Clone, AnchorDeserialize, AnchorSerialize, ZeroCopyMut)] pub struct InstructionDataInvokeCpi { pub proof: Option, pub new_address_params: Vec, diff --git a/program-libs/compressed-account/src/instruction_data/with_account_info.rs b/program-libs/compressed-account/src/instruction_data/with_account_info.rs index 599ad9cd0b..57b49e5e78 100644 --- a/program-libs/compressed-account/src/instruction_data/with_account_info.rs +++ b/program-libs/compressed-account/src/instruction_data/with_account_info.rs @@ -399,9 +399,13 @@ impl<'a> Deserialize<'a> for InstructionDataInvokeCpiWithAccountInfo { let (account_infos, bytes) = { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory let mut slices = Vec::with_capacity(num_slices); + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } for _ in 0..num_slices { let (slice, _bytes) = CompressedAccountInfo::zero_copy_at_with_owner( bytes, diff --git a/program-libs/compressed-account/src/instruction_data/with_readonly.rs b/program-libs/compressed-account/src/instruction_data/with_readonly.rs index 
59b9c27bd7..28b169b206 100644 --- a/program-libs/compressed-account/src/instruction_data/with_readonly.rs +++ b/program-libs/compressed-account/src/instruction_data/with_readonly.rs @@ -1,6 +1,8 @@ use std::ops::Deref; -use light_zero_copy::{borsh::Deserialize, errors::ZeroCopyError, slice::ZeroCopySliceBorsh}; +use light_zero_copy::{ + borsh::Deserialize, errors::ZeroCopyError, slice::ZeroCopySliceBorsh, ZeroCopyMut, +}; use zerocopy::{ little_endian::{U16, U32, U64}, FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, @@ -30,7 +32,7 @@ use crate::{ AnchorDeserialize, AnchorSerialize, CompressedAccountError, }; -#[derive(Debug, Default, PartialEq, Clone, AnchorSerialize, AnchorDeserialize)] +#[derive(Debug, Default, PartialEq, Clone, AnchorSerialize, AnchorDeserialize, ZeroCopyMut)] pub struct InAccount { pub discriminator: [u8; 8], /// Data hash @@ -193,7 +195,7 @@ impl<'a> Deref for ZInAccount<'a> { } } -#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize)] +#[derive(Debug, PartialEq, Default, Clone, AnchorSerialize, AnchorDeserialize, ZeroCopyMut)] pub struct InstructionDataInvokeCpiWithReadOnly { /// 0 With program ids /// 1 without program ids @@ -347,8 +349,14 @@ impl<'a> Deserialize<'a> for InstructionDataInvokeCpiWithReadOnly { let (input_compressed_accounts, bytes) = { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } let mut slices = Vec::with_capacity(num_slices); for _ in 0..num_slices { let (slice, _bytes) = diff --git 
a/program-libs/compressed-account/src/pubkey.rs b/program-libs/compressed-account/src/pubkey.rs index 9dc74ea35f..2f2929e7a1 100644 --- a/program-libs/compressed-account/src/pubkey.rs +++ b/program-libs/compressed-account/src/pubkey.rs @@ -1,6 +1,6 @@ #[cfg(feature = "bytemuck-des")] use bytemuck::{Pod, Zeroable}; -use light_zero_copy::{borsh::Deserialize, errors::ZeroCopyError}; +use light_zero_copy::{borsh::{Deserialize, ZeroCopyStructInner}, borsh_mut::{DeserializeMut, ZeroCopyStructInnerMut}, errors::ZeroCopyError, ZeroCopyNew}; use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned}; use crate::{AnchorDeserialize, AnchorSerialize}; @@ -46,6 +46,21 @@ pub struct Pubkey(pub(crate) [u8; 32]); #[repr(C)] pub struct Pubkey(pub(crate) [u8; 32]); +impl<'a> ZeroCopyNew<'a> for Pubkey { + type ZeroCopyConfig = (); + type Output = zerocopy::Ref<&'a mut [u8], Pubkey>; + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + 32 + } + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (key, rest) = zerocopy::Ref::from_prefix(bytes)?; + Ok((key, rest)) + } +} + impl Pubkey { pub fn new_from_array(array: [u8; 32]) -> Self { Self(array) @@ -91,6 +106,25 @@ impl<'a> Deserialize<'a> for Pubkey { Ok(Ref::<&[u8], Pubkey>::from_prefix(bytes)?) } } + +impl<'a> DeserializeMut<'a> for Pubkey { + type Output = Ref<&'a mut [u8], Pubkey>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&mut [u8], Pubkey>::from_prefix(bytes)?) 
+ } +} + +impl ZeroCopyStructInner for Pubkey { + type ZeroCopyInner = Pubkey; +} + +impl ZeroCopyStructInnerMut for Pubkey { + type ZeroCopyInnerMut = Pubkey; +} impl From for [u8; 32] { fn from(pubkey: Pubkey) -> Self { pubkey.to_bytes() diff --git a/program-libs/zero-copy-derive/Cargo.toml b/program-libs/zero-copy-derive/Cargo.toml new file mode 100644 index 0000000000..1cdc8254e8 --- /dev/null +++ b/program-libs/zero-copy-derive/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "light-zero-copy-derive" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" +description = "Proc macro for zero-copy deserialization" + +[features] +default = [] +mut = [] + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "2.0", features = ["full", "extra-traits"] } +lazy_static = "1.4" + +[dev-dependencies] +trybuild = "1.0" +rand = "0.8" +borsh = { workspace = true } +light-zero-copy = { workspace = true, features = ["std", "derive"] } +zerocopy = { workspace = true, features = ["derive"] } diff --git a/program-libs/zero-copy-derive/README.md b/program-libs/zero-copy-derive/README.md new file mode 100644 index 0000000000..8e17fbbb25 --- /dev/null +++ b/program-libs/zero-copy-derive/README.md @@ -0,0 +1,103 @@ +# Light-Zero-Copy-Derive + +A procedural macro for deriving zero-copy deserialization for Rust structs used with Solana programs. + +## Features + +This crate provides two key derive macros: + +1. `#[derive(ZeroCopy)]` - Implements zero-copy deserialization with: + - The `zero_copy_at` and `zero_copy_at_mut` methods for deserialization + - Full Borsh compatibility for serialization/deserialization + - Efficient memory representation with no copying of data + - `From>` and `FromMut>` implementations for easy conversion back to the original struct + +2. 
`#[derive(ZeroCopyEq)]` - Adds equality comparison support: + - Compare zero-copy instances with regular struct instances + - Can be used alongside `ZeroCopy` for complete functionality + - Derivation for Options is not robust and may not compile. + +## Rules for Zero-Copy Deserialization + +The macro follows these rules when generating code: + +1. Creates a `ZStruct` for your struct that follows zero-copy principles + 1. Fields are extracted into a meta struct until reaching a `Vec`, `Option` or non-`Copy` type + 2. Vectors are represented as `ZeroCopySlice` and not included in the meta struct + 3. Integer types are replaced with their zerocopy equivalents (e.g., `u16` → `U16`) + 4. Fields after the first vector are directly included in the `ZStruct` and deserialized one by one + 5. If a vector contains a nested vector (non-`Copy` type), it must implement `Deserialize` + 6. Elements in an `Option` must implement `Deserialize` + 7. Types that don't implement `Copy` must implement `Deserialize` and are deserialized one by one + +## Usage + +### Basic Usage + +```rust +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::ZeroCopy; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut}; + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} +let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, +}; +// Use the struct with zero-copy deserialization +let mut bytes = my_struct.try_to_vec().unwrap(); + +// Immutable zero-copy deserialization +let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); + +// Convert back to original struct using From implementation +let converted: MyStruct = zero_copy.clone().into(); +assert_eq!(converted, my_struct); + +// Mutable zero-copy deserialization with modification +let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut 
bytes).unwrap(); +zero_copy_mut.a = 42; + +// The change is reflected when we convert back to the original struct +let modified: MyStruct = zero_copy_mut.into(); +assert_eq!(modified.a, 42); + +// And also when we deserialize directly from the modified bytes +let borsh = MyStruct::try_from_slice(&bytes).unwrap(); +assert_eq!(borsh.a, 42u8); +``` + +### With Equality Comparison + +```rust +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::ZeroCopy; + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} +let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, +}; +// Use the struct with zero-copy deserialization +let mut bytes = my_struct.try_to_vec().unwrap(); +let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); +assert_eq!(zero_copy, my_struct); +``` diff --git a/program-libs/zero-copy-derive/src/lib.rs b/program-libs/zero-copy-derive/src/lib.rs new file mode 100644 index 0000000000..becac18087 --- /dev/null +++ b/program-libs/zero-copy-derive/src/lib.rs @@ -0,0 +1,166 @@ +//! Procedural macros for zero-copy deserialization. +//! +//! This crate provides derive macros that generate efficient zero-copy data structures +//! and deserialization code, eliminating the need for data copying during parsing. +//! +//! ## Main Macros +//! +//! - `ZeroCopy`: Generates zero-copy structs and deserialization traits +//! - `ZeroCopyMut`: Adds mutable zero-copy support +//! - `ZeroCopyEq`: Adds PartialEq implementation for comparing with original structs +//! 
- `ZeroCopyNew`: Generates configuration structs for initialization + +use proc_macro::TokenStream; + +mod shared; +mod zero_copy; +mod zero_copy_eq; +#[cfg(feature = "mut")] +mod zero_copy_mut; + +/// ZeroCopy derivation macro for zero-copy deserialization +/// +/// # Usage +/// +/// Basic usage: +/// ```rust +/// use light_zero_copy_derive::ZeroCopy; +/// #[derive(ZeroCopy)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +/// +/// To derive PartialEq as well, use ZeroCopyEq in addition to ZeroCopy: +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +/// +/// # Macro Rules +/// 1. Create zero copy structs Z and ZMut for the struct +/// 1.1. The first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy +/// 1.2. Represent vectors to ZeroCopySlice & don't include these into the meta struct +/// 1.3. Replace u16 with U16, u32 with U32, etc +/// 1.4. Every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +/// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +/// 1.6. Elements in an Option must implement Deserialize +/// 1.7. A type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 +/// 1.8. is u8 deserialized as u8::zero_copy_at instead of Ref<&'a [u8], u8> for non mut, for mut it is Ref<&'a mut [u8], u8> +/// 2. Implement Deserialize and DeserializeMut which return Z and ZMut +/// 3. 
Implement From> for StructName and FromMut> for StructName +/// +/// Note: Options are not supported in ZeroCopyEq +#[proc_macro_derive(ZeroCopy)] +pub fn derive_zero_copy(input: TokenStream) -> TokenStream { + let res = zero_copy::derive_zero_copy_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +/// ZeroCopyEq implementation to add PartialEq for zero-copy structs. +/// +/// Use this in addition to ZeroCopy when you want the generated struct to implement PartialEq: +/// +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[proc_macro_derive(ZeroCopyEq)] +pub fn derive_zero_copy_eq(input: TokenStream) -> TokenStream { + let res = zero_copy_eq::derive_zero_copy_eq_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +/// ZeroCopyMut derivation macro for mutable zero-copy deserialization +/// +/// This macro generates mutable zero-copy implementations including: +/// - DeserializeMut trait implementation +/// - Mutable Z-struct with `Mut` suffix +/// - byte_len() method implementation +/// - Mutable ZeroCopyStructInner implementation +/// +/// # Usage +/// +/// ```rust +/// use light_zero_copy_derive::ZeroCopyMut; +/// +/// #[derive(ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// pub vec: Vec, +/// } +/// ``` +/// +/// This will generate: +/// - `MyStruct::zero_copy_at_mut()` method +/// - `ZMyStructMut<'a>` type for mutable zero-copy access +/// - `MyStruct::byte_len()` method +/// +/// For both immutable and mutable functionality, use both derives: +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; +/// +/// #[derive(ZeroCopy, ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[cfg(feature = "mut")] +#[proc_macro_derive(ZeroCopyMut)] +pub fn derive_zero_copy_mut(input: 
TokenStream) -> TokenStream { + let res = zero_copy_mut::derive_zero_copy_mut_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +// /// ZeroCopyNew derivation macro for configuration-based zero-copy initialization +// /// +// /// This macro generates configuration structs and initialization methods for structs +// /// with Vec and Option fields that need to be initialized with specific configurations. +// /// +// /// # Usage +// /// +// /// ```ignore +// /// use light_zero_copy_derive::ZeroCopyNew; +// /// +// /// #[derive(ZeroCopyNew)] +// /// pub struct MyStruct { +// /// pub a: u8, +// /// pub vec: Vec, +// /// pub option: Option, +// /// } +// /// ``` +// /// +// /// This will generate: +// /// - `MyStructConfig` struct with configuration fields +// /// - `ZeroCopyNew` implementation for `MyStruct` +// /// - `new_zero_copy(bytes, config)` method for initialization +// /// +// /// The configuration struct will have fields based on the complexity of the original fields: +// /// - `Vec` → `field_name: u32` (length) +// /// - `Option` → `field_name: bool` (is_some) +// /// - `Vec` → `field_name: Vec` (config per element) +// /// - `Option` → `field_name: Option` (config if some) +// #[cfg(feature = "mut")] +// #[proc_macro_derive(ZeroCopyNew)] +// pub fn derive_zero_copy_config(input: TokenStream) -> TokenStream { +// let res = zero_copy_new::derive_zero_copy_config_impl(input); +// TokenStream::from(match res { +// Ok(res) => res, +// Err(err) => err.to_compile_error(), +// }) +// } diff --git a/program-libs/zero-copy-derive/src/shared/from_impl.rs b/program-libs/zero-copy-derive/src/shared/from_impl.rs new file mode 100644 index 0000000000..1aab0eb9b3 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/from_impl.rs @@ -0,0 +1,242 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Field, Ident}; + +use super::{ + utils, + z_struct::{analyze_struct_fields, FieldType}, 
+}; + +/// Generates code for the From<ZStruct<'a>> for StructName implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_from_impl<const MUT: bool>( + name: &Ident, + z_struct_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> syn::Result<TokenStream> { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + + // Generate the conversion code for meta fields + let meta_field_conversions = if !meta_fields.is_empty() { + let field_types = analyze_struct_fields(meta_fields)?; + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => { + quote! { #field_name: value.__meta.#field_name, } + } + _ if utils::is_specific_primitive_type(field_type, "bool") => { + quote! { #field_name: value.__meta.#field_name > 0, } + } + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { #field_name: #field_type::from(value.__meta.#field_name), } + } + } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: value.__meta.#field_name, } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: value.__meta.#field_name, } + } + _ => { + let field_name = field_type.name(); + quote! { #field_name: value.__meta.#field_name.into(), } + } + } + }); + conversions.collect::<Vec<_>>() + } else { + vec![] + }; + + // Generate the conversion code for struct fields + let struct_field_conversions = if !struct_fields.is_empty() { + let field_types = analyze_struct_fields(struct_fields)?; + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { #field_name: value.#field_name.to_vec(), } + } + FieldType::VecCopy(field_name, _) => { + quote!
{ #field_name: value.#field_name.to_vec(), } + } + FieldType::VecDynamicZeroCopy(field_name, _) => { + // For non-copy vectors, clone each element directly + // We need to convert into() for Zstructs + quote! { + #field_name: { + value.#field_name.iter().map(|item| (*item).clone().into()).collect() + }, + } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: *value.#field_name, } + } + FieldType::Option(field_name, field_type) => { + // Extract inner type from Option + let inner_type = utils::get_option_inner_type(field_type).expect( + "Failed to extract inner type from Option - expected Option format", + ); + let field_type = inner_type; + // For Option types, use a direct copy of the value when possible + quote! { + #field_name: if value.#field_name.is_some() { + // Create a clone of the Some value - for compressed proofs and other structs + // For instruction_data.rs, we just need to clone the value directly + Some((#field_type::from(*value.#field_name.as_ref().unwrap()).clone())) + } else { + None + }, + } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: *value.#field_name, } + } + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => { + if MUT { + quote! { #field_name: *value.#field_name, } + } else { + quote! { #field_name: value.#field_name, } + } + } + _ if utils::is_specific_primitive_type(field_type, "bool") => { + if MUT { + quote! { #field_name: *value.#field_name > 0, } + } else { + quote! { #field_name: value.#field_name > 0, } + } + } + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { #field_name: #field_type::from(*value.#field_name), } + } + } + } + FieldType::Copy(field_name, _) => { + quote! { #field_name: value.#field_name, } + } + FieldType::OptionU64(field_name) => { + quote! 
{ #field_name: value.#field_name.as_ref().map(|x| u64::from(**x)), } + } + FieldType::OptionU32(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u32::from(**x)), } + } + FieldType::OptionU16(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u16::from(**x)), } + } + FieldType::DynamicZeroCopy(field_name, field_type) => { + // For complex non-copy types, dereference and clone directly + quote! { #field_name: #field_type::from(&value.#field_name), } + } + } + }); + conversions.collect::<Vec<_>>() + } else { + vec![] + }; + + // Combine all the field conversions + let all_field_conversions = [meta_field_conversions, struct_field_conversions].concat(); + + // Return the final From implementation without generic From implementations + let result = quote! { + impl<'a> From<#z_struct_name<'a>> for #name { + fn from(value: #z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + + impl<'a> From<&#z_struct_name<'a>> for #name { + fn from(value: &#z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + }; + Ok(result) +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use syn::{parse_quote, Field}; + + use super::*; + + #[test] + fn test_generate_from_impl() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: u16); + let field_c: Field = parse_quote!(pub c: Vec<u8>); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation + let result = + generate_from_impl::<false>(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.unwrap().to_string(); + + // Check that the implementation contains required elements + 
assert!(result_str.contains("impl < 'a > From < ZTestStruct < 'a >> for TestStruct")); + + // Check field handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For u16 fields + assert!(result_str.contains("c :")); // For Vec fields + } + + #[test] + fn test_generate_from_impl_mut() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: bool); + let field_c: Field = parse_quote!(pub c: Option<u64>); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation for mutable version + let result = + generate_from_impl::<true>(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.unwrap().to_string(); + + // Check that the implementation contains required elements + assert!(result_str.contains("impl < 'a > From < ZTestStructMut < 'a >> for TestStruct")); + + // Check field handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For bool fields + assert!(result_str.contains("c :")); // For Option fields + } +} diff --git a/program-libs/zero-copy-derive/src/shared/meta_struct.rs b/program-libs/zero-copy-derive/src/shared/meta_struct.rs new file mode 100644 index 0000000000..0dbf9cda3a --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/meta_struct.rs @@ -0,0 +1,57 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::Field; + +use super::utils::convert_to_zerocopy_type; + +/// Generates the meta struct definition as a TokenStream +/// The `MUT` parameter determines if the struct should be generated for mutable access +pub fn generate_meta_struct<const MUT: bool>( + z_struct_meta_name: &syn::Ident, + meta_fields:
&[&Field], + hasher: bool, +) -> syn::Result { + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + + // Generate the meta struct fields with converted types + let meta_fields_with_converted_types = meta_fields.iter().map(|field| { + let field_name = &field.ident; + let attributes = if hasher { + field + .attrs + .iter() + .map(|attr| { + quote! { #attr } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let field_type = convert_to_zerocopy_type(&field.ty); + quote! { + #(#attributes)* + pub #field_name: #field_type + } + }); + let hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + + // Return the complete meta struct definition + let result = quote! { + #[repr(C)] + #[derive(Debug, PartialEq, light_zero_copy::KnownLayout, light_zero_copy::Immutable, light_zero_copy::Unaligned, light_zero_copy::FromBytes, light_zero_copy::IntoBytes #hasher)] + pub struct #z_struct_meta_name { + #(#meta_fields_with_converted_types,)* + } + }; + Ok(result) +} diff --git a/program-libs/zero-copy-derive/src/shared/mod.rs b/program-libs/zero-copy-derive/src/shared/mod.rs new file mode 100644 index 0000000000..c7b406b530 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/mod.rs @@ -0,0 +1,6 @@ +pub mod from_impl; +pub mod meta_struct; +pub mod utils; +pub mod z_struct; +#[cfg(feature = "mut")] +pub mod zero_copy_new; diff --git a/program-libs/zero-copy-derive/src/shared/utils.rs b/program-libs/zero-copy-derive/src/shared/utils.rs new file mode 100644 index 0000000000..e92e56bb29 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/utils.rs @@ -0,0 +1,437 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Attribute, Data, DeriveInput, Field, Fields, FieldsNamed, Ident, Type, TypePath}; + +// Global cache for storing whether a struct implements Copy 
+lazy_static::lazy_static! { + static ref COPY_IMPL_CACHE: Arc<Mutex<HashMap<String, bool>>> = Arc::new(Mutex::new(HashMap::new())); +} + +/// Creates a unique cache key for a type using span information to avoid collisions +/// between types with the same name from different modules/locations +fn create_unique_type_key(ident: &Ident) -> String { + format!("{}:{:?}", ident, ident.span()) +} + +/// Process the derive input to extract the struct information +pub fn process_input( + input: &DeriveInput, +) -> syn::Result<( + &Ident, // Original struct name + proc_macro2::Ident, // Z-struct name + proc_macro2::Ident, // Z-struct meta name + &FieldsNamed, // Struct fields +)> { + let name = &input.ident; + let z_struct_name = format_ident!("Z{}", name); + let z_struct_meta_name = format_ident!("Z{}Meta", name); + + // Populate the cache by checking if this struct implements Copy + let _ = struct_implements_copy(input); + + let fields = match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => fields, + _ => { + return Err(syn::Error::new_spanned( + &data.fields, + "ZeroCopy only supports structs with named fields", + )) + } + }, + _ => { + return Err(syn::Error::new_spanned( + input, + "ZeroCopy only supports structs", + )) + } + }; + + Ok((name, z_struct_name, z_struct_meta_name, fields)) +} + +pub fn process_fields(fields: &FieldsNamed) -> (Vec<&Field>, Vec<&Field>) { + let mut meta_fields = Vec::new(); + let mut struct_fields = Vec::new(); + let mut reached_vec_or_option = false; + + for field in fields.named.iter() { + if !reached_vec_or_option { + if is_vec_or_option(&field.ty) || !is_copy_type(&field.ty) { + reached_vec_or_option = true; + struct_fields.push(field); + } else { + meta_fields.push(field); + } + } else { + struct_fields.push(field); + } + } + + (meta_fields, struct_fields) +} + +pub fn is_vec_or_option(ty: &Type) -> bool { + is_vec_type(ty) || is_option_type(ty) +} + +pub fn is_vec_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path,
.. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Vec"; + } + } + false +} + +pub fn is_option_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Option"; + } + } + false +} + +pub fn get_vec_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + if segment.ident == "Vec" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn get_option_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + if segment.ident == "Option" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn is_primitive_integer(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + return ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "u8" + || ident == "i8"; + } + } + false +} + +pub fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "bool"; + } + } + false +} + +/// Check if a type is a specific primitive type (u8, u16, u32, u64, bool, etc.) +pub fn is_specific_primitive_type(ty: &Type, type_name: &str) -> bool { + if let Type::Path(TypePath { path, .. 
}) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == type_name; + } + } + false +} + +pub fn is_pubkey_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Pubkey"; + } + } + false +} + +pub fn convert_to_zerocopy_type(ty: &Type) -> TokenStream { + match ty { + Type::Path(TypePath { path, .. }) => { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + + // Handle primitive types first + match ident.to_string().as_str() { + "u16" => quote! { light_zero_copy::little_endian::U16 }, + "u32" => quote! { light_zero_copy::little_endian::U32 }, + "u64" => quote! { light_zero_copy::little_endian::U64 }, + "bool" => quote! { u8 }, + _ => { + // Handle container types recursively + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + let transformed_args: Vec = args + .args + .iter() + .map(|arg| { + if let syn::GenericArgument::Type(inner_type) = arg { + convert_to_zerocopy_type(inner_type) + } else { + quote! { #arg } + } + }) + .collect(); + + quote! { #ident<#(#transformed_args),*> } + } else { + quote! { #ty } + } + } + } + } else { + quote! { #ty } + } + } + _ => { + quote! { #ty } + } + } +} + +/// Checks if a struct has a derive(Copy) attribute +fn struct_has_copy_derive(attrs: &[Attribute]) -> bool { + attrs.iter().any(|attr| { + attr.path().is_ident("derive") && { + let mut found_copy = false; + // Use parse_nested_meta as the primary and only approach - it's the syn 2.0 standard + // for parsing comma-separated derive items like #[derive(Copy, Clone, Debug)] + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("Copy") { + found_copy = true; + } + Ok(()) // Continue parsing other derive items + }) + .is_ok() + && found_copy + } + }) +} + +/// Determines whether a struct implements Copy by checking for the #[derive(Copy)] attribute. +/// Results are cached for performance. 
+/// +/// In Rust, a struct can only implement Copy if: +/// 1. It explicitly has a #[derive(Copy)] attribute, AND +/// 2. All of its fields implement Copy +/// +/// The Rust compiler will enforce the second condition at compile time, so we only need to check +/// for the derive attribute here. +pub fn struct_implements_copy(input: &DeriveInput) -> bool { + let cache_key = create_unique_type_key(&input.ident); + + // Check the cache first + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&cache_key) { + return *implements_copy; + } + + // Check if the struct has a derive(Copy) attribute + let implements_copy = struct_has_copy_derive(&input.attrs); + + // Cache the result + COPY_IMPL_CACHE + .lock() + .unwrap() + .insert(cache_key, implements_copy); + + implements_copy +} + +/// Determines whether a type implements Copy +/// 1. check whether type is a primitive type that implements Copy +/// 2. check whether type is an array type (which is always Copy if the element type is Copy) +/// 3. check whether type is struct -> check in the COPY_IMPL_CACHE if we know whether it has a #[derive(Copy)] attribute +/// +/// For struct types, this relies on the cache populated by struct_implements_copy. If we don't have cached +/// information, it assumes the type does not implement Copy. This is a limitation of our approach, but it +/// works well in practice because process_input will call struct_implements_copy for all structs before +/// they might be referenced by other structs. +pub fn is_copy_type(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { path, .. 
}) => { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + + // Check if it's a primitive type that implements Copy + if ident == "u8" + || ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i8" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "bool" // bool is a Copy type + || ident == "char" + || ident == "Pubkey" + // Pubkey is hardcoded as copy type for now. + { + return true; + } + + // Check if we have cached information about this type + let cache_key = create_unique_type_key(ident); + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&cache_key) { + return *implements_copy; + } + } + } + // Handle array types (which are always Copy if the element type is Copy) + Type::Array(array) => { + // Arrays are Copy if their element type is Copy + return is_copy_type(&array.elem); + } + // For struct types not in cache, we'd need the derive input to check attributes + _ => {} + } + false +} + +#[cfg(test)] +mod tests { + use syn::parse_quote; + + use super::*; + + // Helper function to check if a struct implements Copy + fn check_struct_implements_copy(input: syn::DeriveInput) -> bool { + struct_implements_copy(&input) + } + + #[test] + fn test_struct_implements_copy() { + // Ensure the cache is cleared and the lock is released immediately + COPY_IMPL_CACHE.lock().unwrap().clear(); + // Test case 1: Empty struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct EmptyStruct {} + }; + assert!( + check_struct_implements_copy(input), + "EmptyStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 2: Simple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! 
{ + #[derive(Copy, Clone)] + struct SimpleStruct { + a: u8, + b: u16, + } + }; + assert!( + check_struct_implements_copy(input), + "SimpleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 3: Struct with #[derive(Clone)] but not Copy + let input: syn::DeriveInput = parse_quote! { + #[derive(Clone)] + struct StructWithoutCopy { + a: u8, + b: u16, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 4: Struct with a non-Copy field but with derive(Copy) + // Note: In real Rust code, this would not compile, but for our test we only check attributes + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct StructWithVec { + a: u8, + b: Vec, + } + }; + assert!( + check_struct_implements_copy(input), + "StructWithVec has #[derive(Copy)] so our function returns true" + ); + + // Test case 5: Struct with all Copy fields but without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct StructWithCopyFields { + a: u8, + b: u16, + c: i32, + d: bool, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithCopyFields should not implement Copy without #[derive(Copy)]" + ); + + // Test case 6: Unit struct without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct UnitStructWithoutCopy; + }; + assert!( + !check_struct_implements_copy(input), + "UnitStructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 7: Unit struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct UnitStructWithCopy; + }; + assert!( + check_struct_implements_copy(input), + "UnitStructWithCopy should implement Copy with #[derive(Copy)]" + ); + + // Test case 8: Tuple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! 
{ + #[derive(Copy, Clone)] + struct TupleStruct(u32, bool, char); + }; + assert!( + check_struct_implements_copy(input), + "TupleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 9: Multiple derives including Copy + let input: syn::DeriveInput = parse_quote! { + #[derive(Debug, PartialEq, Copy, Clone)] + struct MultipleDerivesStruct { + a: u8, + } + }; + assert!( + check_struct_implements_copy(input), + "MultipleDerivesStruct should implement Copy with #[derive(Copy)]" + ); + } +} diff --git a/program-libs/zero-copy-derive/src/shared/z_struct.rs b/program-libs/zero-copy-derive/src/shared/z_struct.rs new file mode 100644 index 0000000000..77aa085c49 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/z_struct.rs @@ -0,0 +1,630 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, TokenStreamExt}; +use syn::{parse_quote, Field, Ident, Type}; + +use super::utils; + +/// Enum representing the different field types for zero-copy struct +/// (Name, Type) +/// Note: Arrays with Option elements are not currently supported +#[derive(Debug)] +pub enum FieldType<'a> { + VecU8(&'a Ident), + VecCopy(&'a Ident, &'a Type), + VecDynamicZeroCopy(&'a Ident, &'a Type), + Array(&'a Ident, &'a Type), // Static arrays only - no Option elements supported + Option(&'a Ident, &'a Type), + OptionU64(&'a Ident), + OptionU32(&'a Ident), + OptionU16(&'a Ident), + Pubkey(&'a Ident), + Primitive(&'a Ident, &'a Type), + Copy(&'a Ident, &'a Type), + DynamicZeroCopy(&'a Ident, &'a Type), +} + +impl<'a> FieldType<'a> { + /// Get the name of the field + pub fn name(&self) -> &'a Ident { + match self { + FieldType::VecU8(name) => name, + FieldType::VecCopy(name, _) => name, + FieldType::VecDynamicZeroCopy(name, _) => name, + FieldType::Array(name, _) => name, + FieldType::Option(name, _) => name, + FieldType::OptionU64(name) => name, + FieldType::OptionU32(name) => name, + FieldType::OptionU16(name) => name, + FieldType::Pubkey(name) => name, + 
FieldType::Primitive(name, _) => name, + FieldType::Copy(name, _) => name, + FieldType::DynamicZeroCopy(name, _) => name, + } + } +} + +/// Classify a Vec type based on its inner type +fn classify_vec_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, + inner_type: &'a Type, +) -> FieldType<'a> { + if utils::is_specific_primitive_type(inner_type, "u8") { + FieldType::VecU8(field_name) + } else if utils::is_copy_type(inner_type) { + FieldType::VecCopy(field_name, inner_type) + } else { + FieldType::VecDynamicZeroCopy(field_name, field_type) + } +} + +/// Classify an Option type based on its inner type +fn classify_option_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, + inner_type: &'a Type, +) -> FieldType<'a> { + if utils::is_primitive_integer(inner_type) { + match () { + _ if utils::is_specific_primitive_type(inner_type, "u64") => { + FieldType::OptionU64(field_name) + } + _ if utils::is_specific_primitive_type(inner_type, "u32") => { + FieldType::OptionU32(field_name) + } + _ if utils::is_specific_primitive_type(inner_type, "u16") => { + FieldType::OptionU16(field_name) + } + _ => FieldType::Option(field_name, field_type), + } + } else { + FieldType::Option(field_name, field_type) + } +} + +/// Classify a primitive integer type +fn classify_integer_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, +) -> syn::Result> { + match () { + _ if utils::is_specific_primitive_type(field_type, "u64") + | utils::is_specific_primitive_type(field_type, "u32") + | utils::is_specific_primitive_type(field_type, "u16") + | utils::is_specific_primitive_type(field_type, "u8") => + { + Ok(FieldType::Primitive(field_name, field_type)) + } + _ => Err(syn::Error::new_spanned( + field_type, + "Unsupported integer type. 
Only u8, u16, u32, and u64 are supported", + )), + } +} + +/// Classify a Copy type +fn classify_copy_type<'a>(field_name: &'a Ident, field_type: &'a Type) -> FieldType<'a> { + if utils::is_specific_primitive_type(field_type, "u8") + || utils::is_specific_primitive_type(field_type, "bool") + { + FieldType::Primitive(field_name, field_type) + } else { + FieldType::Copy(field_name, field_type) + } +} + +/// Classify a single field into its FieldType +fn classify_field<'a>(field_name: &'a Ident, field_type: &'a Type) -> syn::Result> { + // Vec types + if utils::is_vec_type(field_type) { + return match utils::get_vec_inner_type(field_type) { + Some(inner_type) => Ok(classify_vec_type(field_name, field_type, inner_type)), + None => Err(syn::Error::new_spanned( + field_type, + "Could not determine inner type of Vec", + )), + }; + } + + // Array types + if let Type::Array(_) = field_type { + return Ok(FieldType::Array(field_name, field_type)); + } + + // Option types + if utils::is_option_type(field_type) { + return match utils::get_option_inner_type(field_type) { + Some(inner_type) => Ok(classify_option_type(field_name, field_type, inner_type)), + None => Ok(FieldType::Option(field_name, field_type)), + }; + } + + // Simple type dispatch + match () { + _ if utils::is_pubkey_type(field_type) => Ok(FieldType::Pubkey(field_name)), + _ if utils::is_bool_type(field_type) => Ok(FieldType::Primitive(field_name, field_type)), + _ if utils::is_primitive_integer(field_type) => { + classify_integer_type(field_name, field_type) + } + _ if utils::is_copy_type(field_type) => Ok(classify_copy_type(field_name, field_type)), + _ => Ok(FieldType::DynamicZeroCopy(field_name, field_type)), + } +} + +/// Analyze struct fields and return vector of FieldType enums +pub fn analyze_struct_fields<'a>( + struct_fields: &'a [&'a Field], +) -> syn::Result>> { + struct_fields + .iter() + .map(|field| { + let field_name = field + .ident + .as_ref() + .ok_or_else(|| syn::Error::new_spanned(field, 
"Field must have a name"))?; + classify_field(field_name, &field.ty) + }) + .collect() +} + +/// Generate struct fields with zerocopy types based on field type enum +fn generate_struct_fields_with_zerocopy_types<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], + hasher: &'a bool, +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + let iterator = field_types + .into_iter() + .zip(struct_fields.iter()) + .map(|(field_type, field)| { + let attributes = if *hasher { + field + .attrs + .iter() + .map(|attr| { + quote! { #attr } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let (mutability, import_path, import_slice, camel_case_suffix): ( + syn::Type, + syn::Ident, + syn::Ident, + String, + ) = if MUT { + ( + parse_quote!(&'a mut [u8]), + format_ident!("borsh_mut"), + format_ident!("slice_mut"), + String::from("Mut"), + ) + } else { + ( + parse_quote!(&'a [u8]), + format_ident!("borsh"), + format_ident!("slice"), + String::new(), + ) + }; + let deserialize_ident = format_ident!("Deserialize{}", camel_case_suffix); + let trait_name: syn::Type = parse_quote!(light_zero_copy::#import_path::#deserialize_ident); + let slice_ident = format_ident!("ZeroCopySlice{}Borsh", camel_case_suffix); + let slice_name: syn::Type = parse_quote!(light_zero_copy::#import_slice::#slice_ident); + let struct_inner_ident = format_ident!("ZeroCopyStructInner{}", camel_case_suffix); + let inner_ident = format_ident!("ZeroCopyInner{}", camel_case_suffix); + let struct_inner_trait_name: syn::Type = parse_quote!(light_zero_copy::#import_path::#struct_inner_ident::#inner_ident); + match field_type { + FieldType::VecU8(field_name) => { + quote! 
{ + #(#attributes)* + pub #field_name: #mutability + } + } + FieldType::VecCopy(field_name, inner_type) => { + // For primitive Copy types, use the zerocopy converted type directly + // For complex Copy types, use the ZeroCopyStructInner trait + if utils::is_primitive_integer(inner_type) || utils::is_bool_type(inner_type) || utils::is_pubkey_type(inner_type) { + let zerocopy_type = utils::convert_to_zerocopy_type(inner_type); + quote! { + #(#attributes)* + pub #field_name: #slice_name<'a, #zerocopy_type> + } + } else { + let inner_type = utils::convert_to_zerocopy_type(inner_type); + quote! { + #(#attributes)* + pub #field_name: #slice_name<'a, <#inner_type as #struct_inner_trait_name>> + } + } + } + FieldType::VecDynamicZeroCopy(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::Array(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #field_type> + } + } + FieldType::Option(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::Pubkey(field_name) => { + quote! 
{ + #(#attributes)* + pub #field_name: >::Output + } + } + FieldType::Primitive(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + // FieldType::Bool(field_name) => { + // quote! { + // #(#attributes)* + // pub #field_name: >::Output + // } + // } + FieldType::Copy(field_name, field_type) => { + let zerocopy_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #zerocopy_type> + } + } + FieldType::DynamicZeroCopy(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + } + }); + Ok(iterator) +} + +/// Generate accessor methods for boolean fields in struct_fields. +/// We need accessors because booleans are stored as u8. +fn generate_bool_accessor_methods<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + struct_fields.iter().filter_map(|field| { + let field_name = &field.ident; + let field_type = &field.ty; + + if utils::is_bool_type(field_type) { + let comparison = if MUT { + quote! { *self.#field_name > 0 } + } else { + quote! { self.#field_name > 0 } + }; + + Some(quote! { + pub fn #field_name(&self) -> bool { + #comparison + } + }) + } else { + None + } + }) +} + +/// Generates the ZStruct definition as a TokenStream +pub fn generate_z_struct( + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_fields: &[&Field], + hasher: bool, +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + let mutability: syn::Type = if MUT { + parse_quote!(&'a mut [u8]) + } else { + parse_quote!(&'a [u8]) + }; + + let derive_clone = if MUT { + quote! {} + } else { + quote! 
{, Clone } + }; + let struct_fields_with_zerocopy_types: Vec = + generate_struct_fields_with_zerocopy_types::(struct_fields, &hasher)?.collect(); + + let derive_hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + let hasher_flatten = if hasher { + quote! { + #[flatten] + } + } else { + quote! {} + }; + + let partial_eq_derive = if MUT { quote!() } else { quote!(, PartialEq) }; + + let mut z_struct = if meta_fields.is_empty() { + quote! { + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #(#struct_fields_with_zerocopy_types,)* + } + } + } else { + let mut tokens = quote! { + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #hasher_flatten + __meta: light_zero_copy::Ref<#mutability, #z_struct_meta_name>, + #(#struct_fields_with_zerocopy_types,)* + } + impl<'a> core::ops::Deref for #z_struct_name<'a> { + type Target = light_zero_copy::Ref<#mutability , #z_struct_meta_name>; + + fn deref(&self) -> &Self::Target { + &self.__meta + } + } + }; + + if MUT { + tokens.append_all(quote! { + impl<'a> core::ops::DerefMut for #z_struct_name<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.__meta + } + } + }); + } + tokens + }; + + if !meta_fields.is_empty() { + let meta_bool_accessor_methods = generate_bool_accessor_methods::(meta_fields); + z_struct.append_all(quote! { + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#meta_bool_accessor_methods)* + } + }) + }; + + if !struct_fields.is_empty() { + let bool_accessor_methods = generate_bool_accessor_methods::(struct_fields); + z_struct.append_all(quote! 
{ + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#bool_accessor_methods)* + } + + }); + } + Ok(z_struct) +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; + use syn::parse_quote; + + use super::*; + + /// Generate a safe field name for testing + fn random_ident(rng: &mut StdRng) -> String { + // Use predetermined safe field names + const FIELD_NAMES: &[&str] = &[ + "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", + "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", + "status", + ]; + + FIELD_NAMES.choose(rng).unwrap().to_string() + } + + /// Generate a random Rust type + fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { + // Define our available types + let types = [0, 1, 2, 3, 4, 5, 6, 7]; + + // Randomly select a type index + let selected = *types.choose(rng).unwrap(); + + // Return the corresponding type + match selected { + 0 => parse_quote!(u8), + 1 => parse_quote!(u16), + 2 => parse_quote!(u32), + 3 => parse_quote!(u64), + 4 => parse_quote!(bool), + 5 => parse_quote!(Vec), + 6 => parse_quote!(Vec), + 7 => parse_quote!(Vec), + _ => unreachable!(), + } + } + + /// Generate a random field + fn random_field(rng: &mut StdRng) -> Field { + let name = random_ident(rng); + let ty = random_type(rng, 0); + + // Use a safer approach to create the field + let name_ident = format_ident!("{}", name); + parse_quote!(pub #name_ident: #ty) + } + + /// Generate a list of random fields + fn random_fields(rng: &mut StdRng, count: usize) -> Vec { + (0..count).map(|_| random_field(rng)).collect() + } + + #[test] + fn test_fuzz_generate_z_struct() { + // Set up RNG with a seed for reproducibility + let seed = thread_rng().gen(); + println!("seed {}", seed); + let mut rng = StdRng::seed_from_u64(seed); + + // Now that the test is working, run with 10,000 iterations + let num_iters = 10000; 
+ + for i in 0..num_iters { + // Generate a random struct name + let struct_name = format_ident!("{}", random_ident(&mut rng)); + let z_struct_name = format_ident!("Z{}", struct_name); + let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + + // Generate random number of fields (1-10) + let field_count = rng.gen_range(1..11); + let fields = random_fields(&mut rng, field_count); + + // Create a named fields collection that lives longer than the process_fields call + let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); + let fields_named = syn::FieldsNamed { + brace_token: syn::token::Brace::default(), + named: syn_fields, + }; + + // Split into meta fields and struct fields + let (meta_fields, struct_fields) = crate::shared::utils::process_fields(&fields_named); + + // Call the function we're testing + let result = generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + false, + ); + + // Get the generated code as a string for validation + let result_str = result.unwrap().to_string(); + + // Validate the generated code + + // Verify the result contains expected struct elements + // Basic validation - must be non-empty + assert!( + !result_str.is_empty(), + "Failed to generate TokenStream for iteration {}", + i + ); + + // Validate that the generated code contains the expected struct definition + let struct_pattern = format!("struct {} < 'a >", z_struct_name); + assert!( + result_str.contains(&struct_pattern), + "Generated code missing struct definition for iteration {}. 
Expected: {}", + i, + struct_pattern + ); + + if meta_fields.is_empty() { + // Validate the meta field is present + assert!( + !result_str.contains("meta :"), + "Generated code had meta field for iteration {}", + i + ); + // Validate Deref implementation + assert!( + !result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code missing Deref implementation for iteration {}", + i + ); + } else { + // Validate the meta field is present + assert!( + result_str.contains("meta :"), + "Generated code missing meta field for iteration {}", + i + ); + // Validate Deref implementation + assert!( + result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code missing Deref implementation for iteration {}", + i + ); + // Validate Target type + assert!( + result_str.contains("type Target"), + "Generated code missing Target type for iteration {}", + i + ); + // Check that the deref method is implemented + assert!( + result_str.contains("fn deref (& self)"), + "Generated code missing deref method for iteration {}", + i + ); + + // Check for light_zero_copy::Ref reference + assert!( + result_str.contains("light_zero_copy :: Ref"), + "Generated code missing light_zero_copy::Ref for iteration {}", + i + ); + } + + // Make sure derive attributes are present + assert!( + result_str.contains("# [derive (Debug , PartialEq , Clone)]"), + "Generated code missing derive attributes for iteration {}", + i + ); + } + } +} diff --git a/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs b/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs new file mode 100644 index 0000000000..d6190fdefa --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs @@ -0,0 +1,391 @@ +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::Ident; + +use crate::shared::{ + utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generate ZeroCopyNew implementation with new_at method for a struct +pub fn generate_init_mut_impl( + 
struct_name: &syn::Ident, + meta_fields: &[&syn::Field], + struct_fields: &[&syn::Field], +) -> syn::Result { + let config_name = quote::format_ident!("{}Config", struct_name); + let z_meta_name = quote::format_ident!("Z{}MetaMut", struct_name); + let z_struct_mut_name = quote::format_ident!("Z{}Mut", struct_name); + + // Use the pre-separated fields from utils::process_fields (consistent with other derives) + let struct_field_types = analyze_struct_fields(struct_fields)?; + + // Generate field initialization code for struct fields only (meta fields are part of __meta) + let field_initializations: Result, syn::Error> = + struct_field_types + .iter() + .map(|field_type| generate_field_initialization(field_type)) + .collect(); + let field_initializations = field_initializations?; + + // Generate struct construction - only include struct fields that were initialized + // Meta fields are accessed via __meta.field_name in the generated ZStruct + let struct_field_names: Vec = struct_field_types + .iter() + .map(|field_type| { + let field_name = field_type.name(); + quote! { #field_name, } + }) + .collect(); + + // Check if there are meta fields to determine whether to include __meta + let has_meta_fields = !meta_fields.is_empty(); + + let meta_initialization = if has_meta_fields { + quote! { + // Handle the meta struct (fixed-size fields at the beginning) + let (__meta, bytes) = Ref::<&mut [u8], #z_meta_name>::from_prefix(bytes)?; + } + } else { + quote! { + // No meta fields, skip meta struct initialization + } + }; + + let struct_construction = if has_meta_fields { + quote! { + let result = #z_struct_mut_name { + __meta, + #(#struct_field_names)* + }; + } + } else { + quote! 
{ + let result = #z_struct_mut_name { + #(#struct_field_names)* + }; + } + }; + + // Generate byte_len calculation for each field type + let byte_len_calculations: Result, syn::Error> = + struct_field_types + .iter() + .map(|field_type| generate_byte_len_calculation(field_type)) + .collect(); + let byte_len_calculations = byte_len_calculations?; + + // Calculate meta size if there are meta fields + let meta_size_calculation = if has_meta_fields { + quote! { + core::mem::size_of::<#z_meta_name>() + } + } else { + quote! { 0 } + }; + + let result = quote! { + impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for #struct_name { + type ZeroCopyConfig = #config_name; + type Output = >::Output; + + fn byte_len(config: &Self::ZeroCopyConfig) -> usize { + #meta_size_calculation #(+ #byte_len_calculations)* + } + + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { + use zerocopy::Ref; + + #meta_initialization + + #(#field_initializations)* + + #struct_construction + + Ok((result, bytes)) + } + } + }; + Ok(result) +} + +// Configuration system functions moved from config.rs + +/// Determine if this field type requires configuration for initialization +pub fn requires_config(field_type: &FieldType) -> bool { + match field_type { + // Vec types always need length configuration + FieldType::VecU8(_) | FieldType::VecCopy(_, _) | FieldType::VecDynamicZeroCopy(_, _) => { + true + } + // Option types need Some/None configuration + FieldType::Option(_, _) => true, + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::Primitive(_, _) + | FieldType::Copy(_, _) => false, + // DynamicZeroCopy types might need configuration if they contain Vec/Option + FieldType::DynamicZeroCopy(_, _) => true, // Conservative: assume they need config + // Option integer types need config to determine if they're enabled + 
FieldType::OptionU64(_) | FieldType::OptionU32(_) | FieldType::OptionU16(_) => true, + } +} + +/// Generate the config type for this field +pub fn config_type(field_type: &FieldType) -> syn::Result { + let result = match field_type { + // Simple Vec types: just need length + FieldType::VecU8(_) => quote! { u32 }, + FieldType::VecCopy(_, _) => quote! { u32 }, + + // Complex Vec types: need config for each element + FieldType::VecDynamicZeroCopy(_, vec_type) => { + if let Some(inner_type) = utils::get_vec_inner_type(vec_type) { + quote! { Vec<<#inner_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::ZeroCopyConfig> } + } else { + return Err(syn::Error::new_spanned( + vec_type, + "Could not determine inner type for VecDynamicZeroCopy config", + )); + } + } + + // Option types: delegate to the Option's Config type + FieldType::Option(_, option_type) => { + quote! { <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::ZeroCopyConfig } + } + + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::Primitive(_, _) + | FieldType::Copy(_, _) => quote! { () }, + + // Option integer types: use bool config to determine if enabled + FieldType::OptionU64(_) | FieldType::OptionU32(_) | FieldType::OptionU16(_) => { + quote! { bool } + } + + // DynamicZeroCopy types: delegate to their Config type (Config is typically 'static) + FieldType::DynamicZeroCopy(_, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! 
{ <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::ZeroCopyConfig } + } + }; + Ok(result) +} + +/// Generate a configuration struct for a given struct +pub fn generate_config_struct( + struct_name: &Ident, + field_types: &[FieldType], +) -> syn::Result { + let config_name = quote::format_ident!("{}Config", struct_name); + + // Generate config fields only for fields that require configuration + let config_fields: Result, syn::Error> = field_types + .iter() + .filter(|field_type| requires_config(field_type)) + .map(|field_type| -> syn::Result { + let field_name = field_type.name(); + let config_type = config_type(field_type)?; + Ok(quote! { + pub #field_name: #config_type, + }) + }) + .collect(); + let config_fields = config_fields?; + + let result = if config_fields.is_empty() { + // If no fields require configuration, create an empty config struct + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name; + } + } else { + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name { + #(#config_fields)* + } + } + }; + Ok(result) +} + +/// Generate initialization logic for a field based on its configuration +pub fn generate_field_initialization(field_type: &FieldType) -> syn::Result { + let result = match field_type { + FieldType::VecU8(field_name) => { + quote! { + // Initialize the length prefix but don't use the returned ZeroCopySliceMut + { + light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::::new_at( + config.#field_name.into(), + bytes + )?; + } + // Split off the length prefix (4 bytes) and get the slice + let (_, bytes) = bytes.split_at_mut(4); + let (#field_name, bytes) = bytes.split_at_mut(config.#field_name as usize); + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote! 
{ + let (#field_name, bytes) = light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<#inner_type>::new_at( + config.#field_name.into(), + bytes + )?; + } + } + + FieldType::VecDynamicZeroCopy(field_name, vec_type) + | FieldType::DynamicZeroCopy(field_name, vec_type) + | FieldType::Option(field_name, vec_type) => { + quote! { + let (#field_name, bytes) = <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'a>>::new_zero_copy( + bytes, + config.#field_name + )?; + } + } + + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + let option_type = match field_type { + FieldType::OptionU64(_) => quote! { Option }, + FieldType::OptionU32(_) => quote! { Option }, + FieldType::OptionU16(_) => quote! { Option }, + _ => unreachable!(), + }; + quote! { + let (#field_name, bytes) = <#option_type as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( + bytes, + (config.#field_name, ()) + )?; + } + } + + // Fixed-size types that are struct fields (not meta fields) need initialization with () config + FieldType::Primitive(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut(bytes)?; + } + } + + // Array fields that are struct fields (come after Vec/Option) + FieldType::Array(field_name, array_type) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + #array_type + >::from_prefix(bytes)?; + } + } + + FieldType::Pubkey(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + Pubkey + >::from_prefix(bytes)?; + } + } + + FieldType::Copy(field_name, field_type) => { + quote! 
{ + let (#field_name, bytes) = <#field_type as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy(bytes)?; + } + } + }; + Ok(result) +} + +/// Generate byte length calculation for a field based on its configuration +pub fn generate_byte_len_calculation(field_type: &FieldType) -> syn::Result { + let result = match field_type { + // Vec types that require configuration + FieldType::VecU8(field_name) => { + quote! { + (4 + config.#field_name as usize) // 4 bytes for length + actual data + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote! { + (4 + (config.#field_name as usize * core::mem::size_of::<#inner_type>())) + } + } + + FieldType::VecDynamicZeroCopy(field_name, vec_type) => { + quote! { + <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + // Option types + FieldType::Option(field_name, option_type) => { + quote! { + <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + FieldType::OptionU64(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU32(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU16(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + // Fixed-size types don't need configuration and have known sizes + FieldType::Primitive(_, field_type) => { + let zerocopy_type = utils::convert_to_zerocopy_type(field_type); + quote! { + core::mem::size_of::<#zerocopy_type>() + } + } + + FieldType::Array(_, array_type) => { + quote! { + core::mem::size_of::<#array_type>() + } + } + + FieldType::Pubkey(_) => { + quote! 
{ + 32 // Pubkey is always 32 bytes + } + } + + // Meta field types (should not appear in struct fields, but handle gracefully) + FieldType::Copy(_, field_type) => { + quote! { + core::mem::size_of::<#field_type>() + } + } + + FieldType::DynamicZeroCopy(field_name, field_type) => { + quote! { + <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + }; + Ok(result) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy.rs b/program-libs/zero-copy-derive/src/zero_copy.rs new file mode 100644 index 0000000000..bbae45e207 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy.rs @@ -0,0 +1,636 @@ +use proc_macro::TokenStream as ProcTokenStream; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_quote, DeriveInput, Field, Ident}; + +use crate::shared::{ + meta_struct, utils, + z_struct::{analyze_struct_fields, generate_z_struct, FieldType}, +}; + +/// Helper function to generate deserialize call pattern for a given type +fn generate_deserialize_call( + field_name: &syn::Ident, + field_type: &syn::Type, +) -> TokenStream { + let field_type = utils::convert_to_zerocopy_type(field_type); + let trait_path = if MUT { + quote!( as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut) + } else { + quote!( as light_zero_copy::borsh::Deserialize>::zero_copy_at) + }; + + quote! 
{ + let (#field_name, bytes) = <#field_type #trait_path(bytes)?; + } +} + +/// Generates field deserialization code for the Deserialize implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_fields<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + + let iterator = field_types.into_iter().map(move |field_type| { + let mutability_tokens = if MUT { + quote!(&'a mut [u8]) + } else { + quote!(&'a [u8]) + }; + match field_type { + FieldType::VecU8(field_name) => { + if MUT { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh_mut::borsh_vec_u8_as_slice_mut(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh::borsh_vec_u8_as_slice(bytes)?; + } + } + }, + FieldType::VecCopy(field_name, inner_type) => { + let inner_type = utils::convert_to_zerocopy_type(inner_type); + + let trait_path = if MUT { + quote!(light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<'a, <#inner_type as light_zero_copy::borsh_mut::ZeroCopyStructInnerMut>::ZeroCopyInnerMut>) + } else { + quote!(light_zero_copy::slice::ZeroCopySliceBorsh::<'a, <#inner_type as light_zero_copy::borsh::ZeroCopyStructInner>::ZeroCopyInner>) + }; + quote! { + let (#field_name, bytes) = #trait_path::from_bytes_at(bytes)?; + } + }, + FieldType::VecDynamicZeroCopy(field_name, field_type) => { + generate_deserialize_call::(field_name, field_type) + }, + FieldType::Array(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! 
{ + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_type>::from_prefix(bytes)?; + } + }, + FieldType::Option(field_name, field_type) => { + generate_deserialize_call::(field_name, field_type) + }, + FieldType::Pubkey(field_name) => { + generate_deserialize_call::(field_name, &parse_quote!(Pubkey)) + }, + FieldType::Primitive(field_name, field_type) => { + if MUT { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh::Deserialize>::zero_copy_at(bytes)?; + } + } + }, + FieldType::Copy(field_name, field_type) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(field_type); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_ty_zerocopy>::from_prefix(bytes)?; + } + }, + FieldType::DynamicZeroCopy(field_name, field_type) => { + generate_deserialize_call::(field_name, field_type) + }, + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + generate_deserialize_call::(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + }, + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + generate_deserialize_call::(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + }, + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + generate_deserialize_call::(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + } + } + }); + Ok(iterator) +} + +/// Generates field initialization code for the Deserialize implementation +pub fn generate_init_fields<'a>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + struct_fields.iter().map(|field| { + let field_name = &field.ident; + quote! 
{ #field_name } + }) +} + +/// Generates the Deserialize implementation as a TokenStream +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_impl( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_is_empty: bool, + byte_len_impl: TokenStream, +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + + // Define trait and types based on mutability + let (trait_name, mutability, method_name) = if MUT { + ( + quote!(light_zero_copy::borsh_mut::DeserializeMut), + quote!(mut), + quote!(zero_copy_at_mut), + ) + } else { + ( + quote!(light_zero_copy::borsh::Deserialize), + quote!(), + quote!(zero_copy_at), + ) + }; + let (meta_des, meta) = if meta_is_empty { + (quote!(), quote!()) + } else { + ( + quote! { + let (__meta, bytes) = light_zero_copy::Ref::< &'a #mutability [u8], #z_struct_meta_name>::from_prefix(bytes)?; + }, + quote!(__meta,), + ) + }; + let deserialize_fields = generate_deserialize_fields::(struct_fields)?; + let init_fields = generate_init_fields(struct_fields); + + let result = quote! 
{ + impl<'a> #trait_name<'a> for #name { + type Output = #z_struct_name<'a>; + + fn #method_name(bytes: &'a #mutability [u8]) -> Result<(Self::Output, &'a #mutability [u8]), light_zero_copy::errors::ZeroCopyError> { + #meta_des + #(#deserialize_fields)* + Ok(( + #z_struct_name { + #meta + #(#init_fields,)* + }, + bytes + )) + } + + #byte_len_impl + } + }; + Ok(result) +} + +// #[cfg(test)] +// mod tests { +// use quote::format_ident; +// use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; +// use syn::parse_quote; + +// use super::*; + +// /// Generate a safe field name for testing +// fn random_ident(rng: &mut StdRng) -> String { +// // Use predetermined safe field names +// const FIELD_NAMES: &[&str] = &[ +// "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", +// "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", +// "status", +// ]; + +// FIELD_NAMES.choose(rng).unwrap().to_string() +// } + +// /// Generate a random Rust type +// fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { +// // Define our available types +// let types = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + +// // Randomly select a type index +// let selected = *types.choose(rng).unwrap(); + +// // Return the corresponding type +// match selected { +// 0 => parse_quote!(u8), +// 1 => parse_quote!(u16), +// 2 => parse_quote!(u32), +// 3 => parse_quote!(u64), +// 4 => parse_quote!(bool), +// 5 => parse_quote!(Vec), +// 6 => parse_quote!(Vec), +// 7 => parse_quote!(Vec), +// 8 => parse_quote!([u32; 12]), +// 9 => parse_quote!([Vec; 12]), +// 10 => parse_quote!([Vec; 20]), +// _ => unreachable!(), +// } +// } + +// /// Generate a random field +// fn random_field(rng: &mut StdRng) -> Field { +// let name = random_ident(rng); +// let ty = random_type(rng, 0); + +// // Use a safer approach to create the field +// let name_ident = format_ident!("{}", name); +// parse_quote!(pub #name_ident: #ty) +// } + +// 
/// Generate a list of random fields +// fn random_fields(rng: &mut StdRng, count: usize) -> Vec { +// (0..count).map(|_| random_field(rng)).collect() +// } + +// // Test for Vec field deserialization +// #[test] +// fn test_deserialize_vec_u8() { +// let field: Field = parse_quote!(pub data: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = +// "let (data , bytes) = light_zero_copy :: borsh :: borsh_vec_u8_as_slice (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Vec with Copy inner type deserialization +// #[test] +// fn test_deserialize_vec_copy_type() { +// let field: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u32 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Vec with non-Copy inner type deserialization +// #[test] +// fn test_deserialize_vec_non_copy_type() { +// // This is a synthetic test as we're treating String as a non-Copy type +// let field: Field = parse_quote!(pub names: Vec); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (names , bytes) = < Vec < String > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Option type deserialization +// #[test] +// fn test_deserialize_option_type() { +// let field: Field = parse_quote!(pub maybe_value: Option); +// let struct_fields = vec![&field]; 
+ +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (maybe_value , bytes) = < Option < u32 > as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for non-Copy type deserialization +// #[test] +// fn test_deserialize_non_copy_type() { +// // Using String as a non-Copy type example +// let field: Field = parse_quote!(pub name: String); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (name , bytes) = < String as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; + +// assert!(result_str.contains(expected)); +// } + +// // Test for Copy type deserialization (primitive types) +// #[test] +// fn test_deserialize_copy_type() { +// let field: Field = parse_quote!(pub count: u32); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = "let (count , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?"; +// println!("{}", result_str); +// assert!(result_str.contains(expected)); +// } + +// // Test for boolean type deserialization +// #[test] +// fn test_deserialize_bool_type() { +// let field: Field = parse_quote!(pub flag: bool); +// let struct_fields = vec![&field]; + +// let result = generate_deserialize_fields::(&struct_fields).collect::>(); +// let result_str = result[0].to_string(); +// let expected = +// "let (flag , bytes) = < u8 as light_zero_copy :: borsh :: Deserialize > :: zero_copy_at (bytes) ?"; +// println!("{}", result_str); +// assert!(result_str.contains(expected)); +// } + +// // Test for field initialization code generation +// #[test] +// 
fn test_init_fields() { +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub name: String); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_init_fields(&struct_fields).collect::>(); +// let result_str = format!("{} {}", result[0], result[1]); +// assert!(result_str.contains("id")); +// assert!(result_str.contains("name")); +// } + +// // Test for complete deserialize implementation generation +// #[test] +// fn test_generate_deserialize_impl() { +// let struct_name = format_ident!("TestStruct"); +// let z_struct_name = format_ident!("ZTestStruct"); +// let z_struct_meta_name = format_ident!("ZTestStructMeta"); + +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// false, +// ) +// .to_string(); + +// // Check impl header +// assert!(result +// .contains("impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for TestStruct")); + +// // Check Output type +// assert!(result.contains("type Output = ZTestStruct < 'a >")); + +// // Check method signature +// assert!(result.contains("fn zero_copy_at (bytes : & 'a [u8]) -> Result")); + +// // Check meta field extraction +// assert!(result.contains("let (__meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , ZTestStructMeta > :: from_prefix (bytes) ?")); + +// // Check field deserialization +// assert!(result.contains("let (id , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?")); +// assert!(result.contains("let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u16 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?")); + +// // Check result structure +// 
assert!(result.contains("Ok ((ZTestStruct { __meta , id , values ,")); +// } + +// // Test for complete deserialize implementation generation +// #[test] +// fn test_generate_deserialize_impl_no_meta() { +// let struct_name = format_ident!("TestStruct"); +// let z_struct_name = format_ident!("ZTestStruct"); +// let z_struct_meta_name = format_ident!("ZTestStructMeta"); + +// let field1: Field = parse_quote!(pub id: u32); +// let field2: Field = parse_quote!(pub values: Vec); +// let struct_fields = vec![&field1, &field2]; + +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// true, +// ) +// .to_string(); + +// // Check impl header +// assert!(result +// .contains("impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for TestStruct")); + +// // Check Output type +// assert!(result.contains("type Output = ZTestStruct < 'a >")); + +// // Check method signature +// assert!(result.contains("fn zero_copy_at (bytes : & 'a [u8]) -> Result")); + +// // Check meta field extraction +// assert!(!result.contains("let (meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , ZTestStructMeta > :: from_prefix (bytes) ?")); + +// // Check field deserialization +// assert!(result.contains("let (id , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , light_zero_copy :: little_endian :: U32 > :: from_prefix (bytes) ?")); +// assert!(result.contains("let (values , bytes) = light_zero_copy :: slice :: ZeroCopySliceBorsh :: < 'a , < u16 as light_zero_copy :: borsh :: ZeroCopyStructInner > :: ZeroCopyInner > :: from_bytes_at (bytes) ?")); + +// // Check result structure +// assert!(result.contains("Ok ((ZTestStruct { id , values ,")); +// } + +// #[test] +// fn test_fuzz_generate_deserialize_impl() { +// // Set up RNG with a seed for reproducibility +// let seed = thread_rng().gen(); +// println!("seed {}", seed); +// let mut rng = StdRng::seed_from_u64(seed); + +// // Number of iterations for the test +// 
let num_iters = 10000; + +// for i in 0..num_iters { +// // Generate a random struct name +// let struct_name = format_ident!("{}", random_ident(&mut rng)); +// let z_struct_name = format_ident!("Z{}", struct_name); +// let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + +// // Generate random number of fields (1-10) +// let field_count = rng.gen_range(1..11); +// let fields = random_fields(&mut rng, field_count); + +// // Create a named fields collection +// let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); +// let fields_named = syn::FieldsNamed { +// brace_token: syn::token::Brace::default(), +// named: syn_fields, +// }; + +// // Split into meta fields and struct fields +// let (_, struct_fields) = crate::utils::process_fields(&fields_named); + +// // Call the function we're testing +// let result = generate_deserialize_impl::( +// &struct_name, +// &z_struct_name, +// &z_struct_meta_name, +// &struct_fields, +// false, +// ); + +// // Get the generated code as a string for validation +// let result_str = result.to_string(); + +// // Print the first result for debugging +// if i == 0 { +// println!("Generated deserialize_impl code format:\n{}", result_str); +// } + +// // Verify the result contains expected elements +// // Basic validation - must be non-empty +// assert!( +// !result_str.is_empty(), +// "Failed to generate TokenStream for iteration {}", +// i +// ); + +// // Validate that the generated code contains the expected impl definition +// let impl_pattern = format!( +// "impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for {}", +// struct_name +// ); +// assert!( +// result_str.contains(&impl_pattern), +// "Generated code missing impl definition for iteration {}. 
Expected: {}", +// i, +// impl_pattern +// ); + +// // Validate type Output is defined +// let output_pattern = format!("type Output = {} < 'a >", z_struct_name); +// assert!( +// result_str.contains(&output_pattern), +// "Generated code missing Output type for iteration {}. Expected: {}", +// i, +// output_pattern +// ); + +// // Validate the zero_copy_at method is present +// assert!( +// result_str.contains("fn zero_copy_at (bytes : & 'a [u8])"), +// "Generated code missing zero_copy_at method for iteration {}", +// i +// ); + +// // Check for meta field extraction +// let meta_extraction_pattern = format!( +// "let (__meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , {} > :: from_prefix (bytes) ?", +// z_struct_meta_name +// ); +// assert!( +// result_str.contains(&meta_extraction_pattern), +// "Generated code missing meta field extraction for iteration {}", +// i +// ); + +// // Check for return with Ok pattern +// assert!( +// result_str.contains("Ok (("), +// "Generated code missing Ok return statement for iteration {}", +// i +// ); + +// // Check for the struct initialization +// let struct_init_pattern = format!("{} {{", z_struct_name); +// assert!( +// result_str.contains(&struct_init_pattern), +// "Generated code missing struct initialization for iteration {}", +// i +// ); + +// // Check for meta field in the returned struct +// assert!( +// result_str.contains("__meta ,"), +// "Generated code missing meta field in struct initialization for iteration {}", +// i +// ); +// } +// } +// } + +/// Generates the ZeroCopyStructInner implementation as a TokenStream +pub fn generate_zero_copy_struct_inner( + name: &Ident, + z_struct_name: &Ident, +) -> syn::Result { + let result = if MUT { + quote! { + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh_mut::ZeroCopyStructInnerMut for #name { + type ZeroCopyInnerMut = #z_struct_name<'static>; + } + } + } else { + quote! 
{ + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh::ZeroCopyStructInner for #name { + type ZeroCopyInner = #z_struct_name<'static>; + } + } + }; + Ok(result) +} + +pub fn derive_zero_copy_impl(input: ProcTokenStream) -> syn::Result { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(input)?; + + let hasher = false; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + let meta_struct_def = if !meta_fields.is_empty() { + meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher)? + } else { + quote! {} + }; + + let z_struct_def = generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + )?; + + let zero_copy_struct_inner_impl = + generate_zero_copy_struct_inner::(name, &z_struct_name)?; + + let deserialize_impl = generate_deserialize_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! {}, + )?; + + // Combine all implementations + let expanded = quote! 
{ + #meta_struct_def + #z_struct_def + #zero_copy_struct_inner_impl + #deserialize_impl + }; + + Ok(expanded) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy_eq.rs b/program-libs/zero-copy-derive/src/zero_copy_eq.rs new file mode 100644 index 0000000000..94b06b51a6 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy_eq.rs @@ -0,0 +1,265 @@ +use proc_macro::TokenStream as ProcTokenStream; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{DeriveInput, Field, Ident}; + +use crate::shared::{ + from_impl, utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generates meta field comparisons for PartialEq implementation +pub fn generate_meta_field_comparisons<'a>( + meta_fields: &'a [&'a Field], +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(meta_fields)?; + + let iterator = field_types.into_iter().map(|field_type| match field_type { + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => quote! { + if other.#field_name != meta.#field_name { + return false; + } + }, + _ if utils::is_specific_primitive_type(field_type, "bool") => quote! { + if other.#field_name != (meta.#field_name > 0) { + return false; + } + }, + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { + if other.#field_name != #field_type::from(meta.#field_name) { + return false; + } + } + } + } + } + _ => { + let field_name = field_type.name(); + quote! 
{ + if other.#field_name != meta.#field_name { + return false; + } + } + } + }); + Ok(iterator) +} + +/// Generates struct field comparisons for PartialEq implementation +pub fn generate_struct_field_comparisons<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + if field_types + .iter() + .any(|x| matches!(x, FieldType::Option(_, _))) + { + return Err(syn::Error::new_spanned( + struct_fields[0], + "Options are not supported in ZeroCopyEq", + )); + } + + let iterator = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { + if self.#field_name != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecDynamicZeroCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::Array(field_name, _) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::Option(field_name, field_type) => { + if utils::is_specific_primitive_type(field_type, "u8") { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if self.#field_name.as_ref().unwrap() != other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + // TODO: handle issue that structs need * == *, arrays need ** == * + // else if crate::utils::is_copy_type(field_type) { + // quote! 
{ + // if self.#field_name.is_some() && other.#field_name.is_some() { + // if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + // return false; + // } + // } else if self.#field_name.is_some() || other.#field_name.is_some() { + // return false; + // } + // } + // } + else { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + + } + FieldType::Pubkey(field_name) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => + if MUT { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } else { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + }, + _ if utils::is_specific_primitive_type(field_type, "bool") => + if MUT { + quote! { + if (*self.#field_name > 0) != other.#field_name { + return false; + } + } + } else { + quote! { + if (self.#field_name > 0) != other.#field_name { + return false; + } + } + }, + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { + if #field_type::from(*self.#field_name) != other.#field_name { + return false; + } + } + } + } + } + FieldType::Copy(field_name, _) + | FieldType::DynamicZeroCopy(field_name, _) => { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + }, + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + quote! 
{ + if self.#field_name != other.#field_name { + return false; + } + } + } + } + }); + Ok(iterator) +} + +/// Generates the PartialEq implementation as a TokenStream +pub fn generate_partial_eq_impl( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> syn::Result { + let struct_field_comparisons = generate_struct_field_comparisons::(struct_fields)?; + let result = if !meta_fields.is_empty() { + let meta_field_comparisons = generate_meta_field_comparisons(meta_fields)?; + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + let meta: &#z_struct_meta_name = &self.__meta; + #(#meta_field_comparisons)* + #(#struct_field_comparisons)* + true + } + } + } + } else { + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + #(#struct_field_comparisons)* + true + } + } + + } + }; + Ok(result) +} + +pub fn derive_zero_copy_eq_impl(input: ProcTokenStream) -> syn::Result { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(input)?; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Generate the PartialEq implementation + let partial_eq_impl = generate_partial_eq_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &meta_fields, + &struct_fields, + )?; + + // Generate From implementations + let from_impl = + from_impl::generate_from_impl::(name, &z_struct_name, &meta_fields, &struct_fields)?; + + Ok(quote! 
{ + #partial_eq_impl + #from_impl + }) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy_mut.rs b/program-libs/zero-copy-derive/src/zero_copy_mut.rs new file mode 100644 index 0000000000..ad52bba4d5 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy_mut.rs @@ -0,0 +1,93 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::DeriveInput; + +use crate::{ + shared::{ + meta_struct, utils, + z_struct::{self, analyze_struct_fields}, + zero_copy_new::{generate_config_struct, generate_init_mut_impl}, + }, + zero_copy, +}; + +pub fn derive_zero_copy_mut_impl(fn_input: TokenStream) -> syn::Result { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(fn_input.clone())?; + + let hasher = false; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + let meta_struct_def_mut = if !meta_fields.is_empty() { + meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher)? + } else { + quote! {} + }; + + let z_struct_def_mut = z_struct::generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + )?; + + let zero_copy_struct_inner_impl_mut = zero_copy::generate_zero_copy_struct_inner::( + name, + &format_ident!("{}Mut", z_struct_name), + )?; + + let deserialize_impl_mut = zero_copy::generate_deserialize_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! 
{}, + )?; + + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(fn_input)?; + + // Process the input to extract struct information + let (name, _z_struct_name, _z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Use the same field processing logic as other derive macros for consistency + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Process ALL fields uniformly by type (no position dependency for config generation) + let all_fields: Vec<&syn::Field> = meta_fields + .iter() + .chain(struct_fields.iter()) + .cloned() + .collect(); + let all_field_types = analyze_struct_fields(&all_fields)?; + + // Generate configuration struct based on all fields that need config (type-based) + let config_struct = generate_config_struct(name, &all_field_types)?; + + // Generate ZeroCopyNew implementation using the existing field separation + let init_mut_impl = generate_init_mut_impl(name, &meta_fields, &struct_fields)?; + + // Combine all mutable implementations + let expanded = quote! 
{ + #config_struct + + #init_mut_impl + + #meta_struct_def_mut + + #z_struct_def_mut + + #zero_copy_struct_inner_impl_mut + + #deserialize_impl_mut + }; + + Ok(expanded) +} diff --git a/program-libs/zero-copy-derive/tests/config_test.rs b/program-libs/zero-copy-derive/tests/config_test.rs new file mode 100644 index 0000000000..990b2f6a18 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/config_test.rs @@ -0,0 +1,430 @@ +#![cfg(feature = "mut")] + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::borsh_mut::DeserializeMut; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + +/// Simple struct with just a Vec field to test basic config functionality +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct SimpleVecStruct { + pub a: u8, + pub vec: Vec, + pub b: u16, +} + +#[test] +fn test_simple_config_generation() { + // This test verifies that the ZeroCopyNew derive macro generates the expected config struct + // and ZeroCopyNew implementation + + // The config should have been generated as SimpleVecStructConfig + let config = SimpleVecStructConfig { + vec: 10, // Vec should have u32 config (length) + }; + + // Test that we can create a configuration + assert_eq!(config.vec, 10); + + println!("Config generation test passed!"); +} + +#[test] +fn test_simple_vec_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test the new_zero_copy method generated by ZeroCopyNew + let config = SimpleVecStructConfig { + vec: 5, // Vec with capacity 5 + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleVecStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // Use the generated new_zero_copy method + let result = SimpleVecStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number 
of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta fields + simple_struct.__meta.a = 42; + + // Test that we can write to the vec slice + simple_struct.vec[0] = 10; + simple_struct.vec[1] = 20; + simple_struct.vec[2] = 30; + + // Test that we can set the b field + *simple_struct.b = 12345u16.into(); + + // Verify the values we set + assert_eq!(simple_struct.__meta.a, 42); + assert_eq!(simple_struct.vec[0], 10); + assert_eq!(simple_struct.vec[1], 20); + assert_eq!(simple_struct.vec[2], 30); + assert_eq!(u16::from(*simple_struct.b), 12345); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = SimpleVecStruct::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[0], 10); + assert_eq!(deserialized.vec[1], 20); + assert_eq!(deserialized.vec[2], 30); + assert_eq!(u16::from(*deserialized.b), 12345); + + println!("new_zero_copy initialization test passed!"); +} + +/// Struct with Option field to test Option config +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct SimpleOptionStruct { + pub a: u8, + pub option: Option, +} + +#[test] +fn test_simple_option_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option enabled + let config = SimpleOptionStructConfig { + option: true, // Option should have bool config (enabled/disabled) + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // 
Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta field + simple_struct.__meta.a = 123; + + // Test that option is Some and we can set its value + assert!(simple_struct.option.is_some()); + if let Some(ref mut opt_val) = simple_struct.option { + **opt_val = 98765u64.into(); + } + + // Verify the values + assert_eq!(simple_struct.__meta.a, 123); + if let Some(ref opt_val) = simple_struct.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 123); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + println!("Option new_zero_copy test passed!"); +} + +#[test] +fn test_simple_option_struct_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option disabled + let config = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + simple_struct.__meta.a = 200; + + // Test that option is None + assert!(simple_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 200); + assert!(deserialized.option.is_none()); + + println!("Option disabled new_zero_copy test passed!"); +} + 
+/// Test both Vec and Option in one struct +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct MixedStruct { + pub a: u8, + pub vec: Vec, + pub option: Option, + pub b: u16, +} + +#[test] +fn test_mixed_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with both vec and option enabled + let config = MixedStructConfig { + vec: 8, // Vec -> u32 length + option: true, // Option -> bool enabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + mixed_struct.__meta.a = 77; + + // Set vec data + mixed_struct.vec[0] = 11; + mixed_struct.vec[3] = 44; + mixed_struct.vec[7] = 88; + + // Set option value + assert!(mixed_struct.option.is_some()); + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 123456789u64.into(); + } + + // Set b field + *mixed_struct.b = 54321u16.into(); + + // Verify all values + assert_eq!(mixed_struct.__meta.a, 77); + assert_eq!(mixed_struct.vec[0], 11); + assert_eq!(mixed_struct.vec[3], 44); + assert_eq!(mixed_struct.vec[7], 88); + if let Some(ref opt_val) = mixed_struct.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*mixed_struct.b), 54321); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 77); + assert_eq!(deserialized.vec[0], 11); + assert_eq!(deserialized.vec[3], 44); + assert_eq!(deserialized.vec[7], 88); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = 
deserialized.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*deserialized.b), 54321); + + println!("Mixed struct new_zero_copy test passed!"); +} + +#[test] +fn test_mixed_struct_option_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with vec enabled but option disabled + let config = MixedStructConfig { + vec: 3, // Vec -> u32 length + option: false, // Option -> bool disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mixed_struct.__meta.a = 99; + mixed_struct.vec[0] = 255; + mixed_struct.vec[2] = 128; + *mixed_struct.b = 9999u16.into(); + + // Verify option is None + assert!(mixed_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 99); + assert_eq!(deserialized.vec[0], 255); + assert_eq!(deserialized.vec[2], 128); + assert!(deserialized.option.is_none()); + assert_eq!(u16::from(*deserialized.b), 9999); + + println!("Mixed struct option disabled test passed!"); +} + +#[test] +fn test_byte_len_calculation() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test SimpleVecStruct byte_len calculation + let config = SimpleVecStructConfig { + vec: 10, // Vec with capacity 10 + }; + + let expected_size = 1 + // a: u8 (meta field) + 4 + 10 + // vec: 4 bytes length + 10 bytes data + 2; // b: u16 + + let calculated_size = SimpleVecStruct::byte_len(&config); + assert_eq!(calculated_size, expected_size); + println!( + "SimpleVecStruct byte_len: calculated={}, expected={}", 
+ calculated_size, expected_size + ); + + // Test SimpleOptionStruct byte_len calculation + let config_some = SimpleOptionStructConfig { + option: true, // Option enabled + }; + + let expected_size_some = 1 + // a: u8 (meta field) + 1 + 8; // option: 1 byte discriminant + 8 bytes u64 + + let calculated_size_some = SimpleOptionStruct::byte_len(&config_some); + assert_eq!(calculated_size_some, expected_size_some); + println!( + "SimpleOptionStruct (Some) byte_len: calculated={}, expected={}", + calculated_size_some, expected_size_some + ); + + let config_none = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + let expected_size_none = 1 + // a: u8 (meta field) + 1; // option: 1 byte discriminant for None + + let calculated_size_none = SimpleOptionStruct::byte_len(&config_none); + assert_eq!(calculated_size_none, expected_size_none); + println!( + "SimpleOptionStruct (None) byte_len: calculated={}, expected={}", + calculated_size_none, expected_size_none + ); + + // Test MixedStruct byte_len calculation + let config_mixed = MixedStructConfig { + vec: 5, // Vec with capacity 5 + option: true, // Option enabled + }; + + let expected_size_mixed = 1 + // a: u8 (meta field) + 4 + 5 + // vec: 4 bytes length + 5 bytes data + 1 + 8 + // option: 1 byte discriminant + 8 bytes u64 + 2; // b: u16 + + let calculated_size_mixed = MixedStruct::byte_len(&config_mixed); + assert_eq!(calculated_size_mixed, expected_size_mixed); + println!( + "MixedStruct byte_len: calculated={}, expected={}", + calculated_size_mixed, expected_size_mixed + ); + + println!("All byte_len calculation tests passed!"); +} + +#[test] +fn test_dynamic_buffer_allocation_with_byte_len() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Example of how to use byte_len for dynamic buffer allocation + let config = MixedStructConfig { + vec: 12, // Vec with capacity 12 + option: true, // Option enabled + }; + + // Calculate the exact buffer size needed + let required_size = 
MixedStruct::byte_len(&config); + println!("Required buffer size: {} bytes", required_size); + + // Allocate exactly the right amount of memory + let mut bytes = vec![0u8; required_size]; + + // Initialize the structure + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the right amount of bytes (no remaining bytes) + assert_eq!( + remaining.len(), + 0, + "Should have used exactly the calculated number of bytes" + ); + + // Set some values to verify it works + mixed_struct.__meta.a = 42; + mixed_struct.vec[5] = 123; + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 9999u64.into(); + } + *mixed_struct.b = 7777u16.into(); + + // Verify round-trip works + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[5], 123); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 9999); + } + assert_eq!(u16::from(*deserialized.b), 7777); + + println!("Dynamic buffer allocation test passed!"); +} diff --git a/program-libs/zero-copy-derive/tests/cross_crate_copy.rs b/program-libs/zero-copy-derive/tests/cross_crate_copy.rs new file mode 100644 index 0000000000..e827908046 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/cross_crate_copy.rs @@ -0,0 +1,295 @@ +#![cfg(feature = "mut")] +//! Test cross-crate Copy identification functionality +//! +//! This test validates that the zero-copy derive macro correctly identifies +//! which types implement Copy, both for built-in types and user-defined types. 
+ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + +// Test struct with primitive Copy types that should be in meta fields +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct PrimitiveCopyStruct { + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, + pub f: Vec, // Split point - this and following fields go to struct_fields + pub g: u32, // Should be in struct_fields due to field ordering rules +} + +// Test struct with primitive Copy types that should be in meta fields +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyEq, ZeroCopyMut)] +pub struct PrimitiveCopyStruct2 { + pub f: Vec, // Split point - this and following fields go to struct_fields + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, + pub g: u32, +} + +// Test struct with arrays that use u8 (which supports Unaligned) +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct ArrayCopyStruct { + pub fixed_u8: [u8; 4], + pub another_u8: [u8; 8], + pub data: Vec, // Split point + pub more_data: [u8; 3], // Should be in struct_fields due to field ordering +} + +// Test struct with Vec of primitive Copy types +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct VecPrimitiveStruct { + pub header: u32, + pub data: Vec, // Vec - special case + pub numbers: Vec, // Vec of Copy type + pub footer: u64, +} + +#[cfg(test)] +mod tests { + use light_zero_copy::borsh::Deserialize; + + use super::*; + + #[test] + fn test_primitive_copy_field_splitting() { + // This test validates that primitive Copy types are correctly + // identified and placed in meta_fields until we hit a Vec + + let data = PrimitiveCopyStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + f: vec![5, 6, 7], + g: 8, + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = 
PrimitiveCopyStruct::zero_copy_at(&serialized).unwrap(); + + // Verify we can access meta fields (should be zero-copy references) + assert_eq!(deserialized.a, 1); + assert_eq!(deserialized.b.get(), 2); // U16 type, use .get() + assert_eq!(deserialized.c.get(), 3); // U32 type, use .get() + assert_eq!(deserialized.d.get(), 4); // U64 type, use .get() + assert_eq!(deserialized.e(), true); // bool accessor method + + // Verify we can access struct fields + assert_eq!(deserialized.f, &[5, 6, 7]); + assert_eq!(deserialized.g.get(), 8); // U32 type in struct fields + } + + #[test] + fn test_array_copy_field_splitting() { + // Arrays should be treated as Copy types + let data = ArrayCopyStruct { + fixed_u8: [1, 2, 3, 4], + another_u8: [10, 20, 30, 40, 50, 60, 70, 80], + data: vec![5, 6], + more_data: [30, 40, 50], + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = ArrayCopyStruct::zero_copy_at(&serialized).unwrap(); + + // Arrays should be accessible (in meta_fields before Vec split) + assert_eq!(deserialized.fixed_u8.as_ref(), &[1, 2, 3, 4]); + assert_eq!( + deserialized.another_u8.as_ref(), + &[10, 20, 30, 40, 50, 60, 70, 80] + ); + + // After Vec split + assert_eq!(deserialized.data, &[5, 6]); + assert_eq!(deserialized.more_data.as_ref(), &[30, 40, 50]); + } + + #[test] + fn test_vec_primitive_types() { + // Test Vec with various primitive Copy element types + let data = VecPrimitiveStruct { + header: 1, + data: vec![10, 20, 30], + numbers: vec![100, 200, 300], + footer: 999, + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = VecPrimitiveStruct::zero_copy_at(&serialized).unwrap(); + + assert_eq!(deserialized.header.get(), 1); + + // Vec is special case - stored as slice + assert_eq!(deserialized.data, &[10, 20, 30]); + + // Vec should use ZeroCopySliceBorsh + assert_eq!(deserialized.numbers.len(), 3); + assert_eq!(deserialized.numbers[0].get(), 100); + assert_eq!(deserialized.numbers[1].get(), 200); + 
assert_eq!(deserialized.numbers[2].get(), 300); + + assert_eq!(deserialized.footer.get(), 999); + } + + #[test] + fn test_all_derives_with_vec_first() { + // This test validates PrimitiveCopyStruct2 which has Vec as the first field + // This means NO meta fields (all fields go to struct_fields due to field ordering) + // Also tests all derive macros: ZeroCopy, ZeroCopyEq, ZeroCopyMut + + use light_zero_copy::{borsh_mut::DeserializeMut, init_mut::ZeroCopyNew}; + + let data = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], // Vec first - causes all fields to be in struct_fields + a: 10, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, + }; + + // Test ZeroCopy (immutable) + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = PrimitiveCopyStruct2::zero_copy_at(&serialized).unwrap(); + + // Since Vec is first, ALL fields should be in struct_fields (no meta fields) + assert_eq!(deserialized.f, &[1, 2, 3]); + assert_eq!(deserialized.a, 10); // u8 direct access + assert_eq!(deserialized.b.get(), 20); // U16 via .get() + assert_eq!(deserialized.c.get(), 30); // U32 via .get() + assert_eq!(deserialized.d.get(), 40); // U64 via .get() + assert_eq!(deserialized.e(), true); // bool accessor method + assert_eq!(deserialized.g.get(), 50); // U32 via .get() + + // Test ZeroCopyEq (PartialEq implementation) + let original = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], + a: 10, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, + }; + + // Should be equal to original + assert_eq!(deserialized, original); + + // Test inequality + let different = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], + a: 11, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, // Different 'a' + }; + assert_ne!(deserialized, different); + + // Test ZeroCopyMut (mutable zero-copy) + #[cfg(feature = "mut")] + { + let mut serialized_mut = borsh::to_vec(&data).unwrap(); + let (deserialized_mut, _) = + PrimitiveCopyStruct2::zero_copy_at_mut(&mut serialized_mut).unwrap(); + + // Test mutable access + 
assert_eq!(deserialized_mut.f, &[1, 2, 3]); + assert_eq!(*deserialized_mut.a, 10); // Mutable u8 field + assert_eq!(deserialized_mut.b.get(), 20); + let (deserialized_mut, _) = + PrimitiveCopyStruct2::zero_copy_at(&mut serialized_mut).unwrap(); + + // Test From implementation (ZeroCopyEq generates this for immutable version) + let converted: PrimitiveCopyStruct2 = deserialized_mut.into(); + assert_eq!(converted.a, 10); + assert_eq!(converted.b, 20); + assert_eq!(converted.c, 30); + assert_eq!(converted.d, 40); + assert_eq!(converted.e, true); + assert_eq!(converted.f, vec![1, 2, 3]); + assert_eq!(converted.g, 50); + } + + // Test ZeroCopyNew (configuration-based initialization) + let config = super::PrimitiveCopyStruct2Config { + f: 3, // Vec length + // Other fields don't need config (they're primitives) + }; + + // Calculate required buffer size + let buffer_size = PrimitiveCopyStruct2::byte_len(&config); + let mut buffer = vec![0u8; buffer_size]; + + // Initialize the zero-copy struct + let (mut initialized, _) = + PrimitiveCopyStruct2::new_zero_copy(&mut buffer, config).unwrap(); + + // Verify we can access the initialized fields + assert_eq!(initialized.f.len(), 3); // Vec should have correct length + + // Set some values in the Vec + initialized.f[0] = 100; + initialized.f[1] = 101; + initialized.f[2] = 102; + *initialized.a = 200; + + // Verify the values were set correctly + assert_eq!(initialized.f, &[100, 101, 102]); + assert_eq!(*initialized.a, 200); + + println!("All derive macros (ZeroCopy, ZeroCopyEq, ZeroCopyMut) work correctly with Vec-first struct!"); + } + + #[test] + fn test_copy_identification_compilation() { + // The primary test is that our macro successfully processes all struct definitions + // above without panicking or generating invalid code. The fact that compilation + // succeeds demonstrates that our Copy identification logic works correctly. 
+ + // Test basic functionality to ensure the generated code is sound + let primitive_data = PrimitiveCopyStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + f: vec![1, 2], + g: 5, + }; + + let array_data = ArrayCopyStruct { + fixed_u8: [1, 2, 3, 4], + another_u8: [5, 6, 7, 8, 9, 10, 11, 12], + data: vec![13, 14], + more_data: [15, 16, 17], + }; + + let vec_data = VecPrimitiveStruct { + header: 42, + data: vec![1, 2, 3], + numbers: vec![10, 20], + footer: 99, + }; + + // Serialize and deserialize to verify the generated code works + let serialized = borsh::to_vec(&primitive_data).unwrap(); + let (_, _) = PrimitiveCopyStruct::zero_copy_at(&serialized).unwrap(); + + let serialized = borsh::to_vec(&array_data).unwrap(); + let (_, _) = ArrayCopyStruct::zero_copy_at(&serialized).unwrap(); + + let serialized = borsh::to_vec(&vec_data).unwrap(); + let (_, _) = VecPrimitiveStruct::zero_copy_at(&serialized).unwrap(); + + println!("Cross-crate Copy identification test passed - all structs compiled and work correctly!"); + } +} diff --git a/program-libs/zero-copy-derive/tests/from_test.rs b/program-libs/zero-copy-derive/tests/from_test.rs new file mode 100644 index 0000000000..20391c36dd --- /dev/null +++ b/program-libs/zero-copy-derive/tests/from_test.rs @@ -0,0 +1,77 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, ZeroCopyEq}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; + +// Simple struct with a primitive field and a vector +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct SimpleStruct { + pub a: u8, + pub b: Vec, +} + +// Basic struct with all basic numeric types +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct NumericStruct { + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, +} + +// use 
light_zero_copy::borsh_mut::DeserializeMut; // Not needed for non-mut derivations + +#[test] +fn test_simple_from_implementation() { + // Create an instance of our struct + let original = SimpleStruct { + a: 42, + b: vec![1, 2, 3, 4, 5], + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = SimpleStruct::zero_copy_at(&bytes).unwrap(); + let converted: SimpleStruct = zero_copy.into(); + assert_eq!(converted.a, 42); + assert_eq!(converted.b, vec![1, 2, 3, 4, 5]); + assert_eq!(converted, original); +} + +#[test] +fn test_numeric_from_implementation() { + // Create a struct with different primitive types + let original = NumericStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = NumericStruct::zero_copy_at(&bytes).unwrap(); + let converted: NumericStruct = zero_copy.clone().into(); + + // Verify all fields + assert_eq!(converted.a, 1); + assert_eq!(converted.b, 2); + assert_eq!(converted.c, 3); + assert_eq!(converted.d, 4); + assert!(converted.e); + + // Verify complete struct + assert_eq!(converted, original); +} diff --git a/program-libs/zero-copy-derive/tests/instruction_data.rs b/program-libs/zero-copy-derive/tests/instruction_data.rs new file mode 100644 index 0000000000..74c7a76179 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/instruction_data.rs @@ -0,0 +1,1401 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut, errors::ZeroCopyError}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; 
+use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned}; + +#[derive( + Debug, + Copy, + PartialEq, + Clone, + Immutable, + FromBytes, + IntoBytes, + KnownLayout, + BorshDeserialize, + BorshSerialize, + Default, + Unaligned, +)] +#[repr(C)] +pub struct Pubkey(pub(crate) [u8; 32]); + +impl Pubkey { + pub fn new_unique() -> Self { + use rand::Rng; + let mut rng = rand::thread_rng(); + let bytes = rng.gen::<[u8; 32]>(); + Pubkey(bytes) + } + + pub fn to_bytes(self) -> [u8; 32] { + self.0 + } +} + +impl<'a> Deserialize<'a> for Pubkey { + type Output = Ref<&'a [u8], Pubkey>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + Ok(Ref::<&'a [u8], Pubkey>::from_prefix(bytes)?) + } +} + +impl<'a> DeserializeMut<'a> for Pubkey { + type Output = Ref<&'a mut [u8], Pubkey>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&'a mut [u8], Pubkey>::from_prefix(bytes)?) 
+ } +} + +// We should not implement DeserializeMut for primitive types directly +// The implementation should be in the zero-copy crate + +impl PartialEq<>::Output> for Pubkey { + fn eq(&self, other: &>::Output) -> bool { + self.0 == other.0 + } +} + +impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for Pubkey { + type ZeroCopyConfig = (); + type Output = >::Output; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + 32 // Pubkey is always 32 bytes + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Self::zero_copy_at_mut(bytes) + } +} + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct InstructionDataInvoke { + pub proof: Option, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub new_address_params: Vec, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for InstructionDataInvoke { +// type Config = InstructionDataInvokeConfig; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { +// use zerocopy::Ref; +// +// // First handle the meta struct (empty for InstructionDataInvoke) +// let (__meta, bytes) = Ref::<&mut [u8], ZInstructionDataInvokeMetaMut>::from_prefix(bytes)?; +// +// // Initialize each field using the corresponding config, following DeserializeMut order +// let (proof, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.proof_config.is_some(), CompressedProofConfig {}) +// )?; +// +// let input_configs: Vec = config.input_accounts_configs +// .into_iter() +// .map(|compressed_account_config| PackedCompressedAccountWithMerkleContextConfig { +// 
compressed_account: CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// merkle_context: PackedMerkleContextConfig {}, +// }) +// .collect(); +// let (input_compressed_accounts_with_merkle_context, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// input_configs +// )?; +// +// let output_configs: Vec = config.output_accounts_configs +// .into_iter() +// .map(|compressed_account_config| OutputCompressedAccountWithPackedContextConfig { +// compressed_account: CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// }) +// .collect(); +// let (output_compressed_accounts, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// output_configs +// )?; +// +// let (relay_fee, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.relay_fee_config.is_some(), ()) +// )?; +// +// let new_address_configs: Vec = config.new_address_configs +// .into_iter() +// .map(|_| NewAddressParamsPackedConfig {}) +// .collect(); +// let (new_address_params, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// new_address_configs +// )?; +// +// let (compress_or_decompress_lamports, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.decompress_lamports_config.is_some(), ()) +// )?; +// +// let (is_compress, bytes) = ::new_zero_copy( +// bytes, +// () +// )?; +// +// Ok(( +// ZInstructionDataInvokeMut { +// proof, +// input_compressed_accounts_with_merkle_context, +// output_compressed_accounts, +// relay_fee, +// new_address_params, +// compress_or_decompress_lamports, +// 
is_compress, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct OutputCompressedAccountWithContext { + pub compressed_account: CompressedAccount, + pub merkle_tree: Pubkey, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct OutputCompressedAccountWithPackedContext { + pub compressed_account: CompressedAccount, + pub merkle_tree_index: u8, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for OutputCompressedAccountWithPackedContext { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZOutputCompressedAccountWithPackedContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_tree_index, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZOutputCompressedAccountWithPackedContextMut { +// compressed_account, +// merkle_tree_index, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, +)] +pub struct NewAddressParamsPacked { + pub seed: [u8; 32], + pub address_queue_account_index: u8, + pub address_merkle_tree_account_index: u8, + pub address_merkle_tree_root_index: u16, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for NewAddressParamsPacked { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], 
ZNewAddressParamsPackedMetaMut>::from_prefix(bytes)?; +// Ok((ZNewAddressParamsPackedMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct NewAddressParams { + pub seed: [u8; 32], + pub address_queue_pubkey: Pubkey, + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, +)] +pub struct PackedReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_root_index: u16, + pub address_merkle_tree_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Clone, + Copy, +)] +pub struct CompressedProof { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +impl Default for CompressedProof { + fn default() -> Self { + Self { + a: [0; 32], + b: [0; 64], + c: [0; 32], + } + } +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedProof { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedProofMetaMut>::from_prefix(bytes)?; +// Ok((ZCompressedProofMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Default, +)] +pub struct CompressedCpiContext { + /// Is set by the 
program that is invoking the CPI to signal that is should + /// set the cpi context. + pub set_context: bool, + /// Is set to clear the cpi context since someone could have set it before + /// with unrelated data. + pub first_set_context: bool, + /// Index of cpi context account in remaining accounts. + pub cpi_context_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct PackedCompressedAccountWithMerkleContext { + pub compressed_account: CompressedAccount, + pub merkle_context: PackedMerkleContext, + /// Index of root used in inclusion validity proof. + pub root_index: u16, + /// Placeholder to mark accounts read-only unimplemented set to false. + pub read_only: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedCompressedAccountWithMerkleContext { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedCompressedAccountWithMerkleContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_context, bytes) = ::new_zero_copy(bytes, ())?; +// let (root_index, bytes) = ::new_zero_copy(bytes, ())?; +// let (read_only, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZPackedCompressedAccountWithMerkleContextMut { +// compressed_account, +// merkle_context, +// root_index, +// read_only, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, +)] +pub struct MerkleContext { + pub merkle_tree_pubkey: Pubkey, + pub nullifier_queue_pubkey: Pubkey, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> 
light_zero_copy::init_mut::ZeroCopyNew<'a> for MerkleContext { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZMerkleContextMetaMut>::from_prefix(bytes)?; +// +// Ok(( +// ZMerkleContextMut { +// __meta, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountWithMerkleContext { + pub compressed_account: CompressedAccount, + pub merkle_context: MerkleContext, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: MerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct PackedReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: PackedMerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, +)] +pub struct PackedMerkleContext { + pub merkle_tree_pubkey_index: u8, + pub nullifier_queue_pubkey_index: u8, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedMerkleContext { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedMerkleContextMetaMut>::from_prefix(bytes)?; +// Ok((ZPackedMerkleContextMut { __meta }, bytes)) +// } +// } + 
+#[derive(Debug, PartialEq, Default, Clone, Copy)] +pub struct CompressedAccountZeroCopyNew { + pub address_enabled: bool, + pub data_enabled: bool, + pub data_capacity: u32, +} + +// Manual InstructionDataInvokeConfig removed - now using generated config from ZeroCopyNew derive + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct CompressedAccount { + pub owner: [u8; 32], + pub lamports: u64, + pub address: Option<[u8; 32]>, + pub data: Option, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccount { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountMetaMut>::from_prefix(bytes)?; +// +// // Use generic Option implementation for address field +// let (address, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.address_enabled, ()) +// )?; +// +// // Use generic Option implementation for data field +// let (data, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.data_enabled, CompressedAccountDataConfig { data: config.data_capacity }) +// )?; +// +// Ok(( +// ZCompressedAccountMut { +// __meta, +// address, +// data, +// }, +// bytes, +// )) +// } +// } + +impl<'a> From> for CompressedAccount { + fn from(value: ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.map(|x| *x), + data: value.data.as_ref().map(|x| x.into()), + } + } +} + +impl<'a> From<&ZCompressedAccount<'a>> for CompressedAccount { + fn from(value: &ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.as_ref().map(|x| **x), + 
data: value.data.as_ref().map(|x| x.into()), + } + } +} + +impl PartialEq for ZCompressedAccount<'_> { + fn eq(&self, other: &CompressedAccount) -> bool { + // Check address: if both Some and unequal, return false + if self.address.is_some() + && other.address.is_some() + && *self.address.unwrap() != other.address.unwrap() + { + return false; + } + // Check address: if exactly one is Some, return false + if self.address.is_some() != other.address.is_some() { + return false; + } + + // Check data: if both Some and unequal, return false + if self.data.is_some() + && other.data.is_some() + && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() + { + return false; + } + // Check data: if exactly one is Some, return false + if self.data.is_some() != other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == other.lamports + } +} + +// Commented out because mutable derivation is disabled +// impl PartialEq for ZCompressedAccountMut<'_> { +// fn eq(&self, other: &CompressedAccount) -> bool { +// if self.address.is_some() +// && other.address.is_some() +// && **self.address.as_ref().unwrap() != *other.address.as_ref().unwrap() +// { +// return false; +// } +// if self.address.is_some() || other.address.is_some() { +// return false; +// } +// if self.data.is_some() +// && other.data.is_some() +// && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() +// { +// return false; +// } +// if self.data.is_some() || other.data.is_some() { +// return false; +// } + +// self.owner == other.owner && self.lamports == other.lamports +// } +// } +impl PartialEq> for CompressedAccount { + fn eq(&self, other: &ZCompressedAccount) -> bool { + // Check address: if both Some and unequal, return false + if self.address.is_some() + && other.address.is_some() + && self.address.unwrap() != *other.address.unwrap() + { + return false; + } + // Check address: if exactly one is Some, return false + if self.address.is_some() != other.address.is_some() 
{ + return false; + } + + // Check data: if both Some and unequal, return false + if self.data.is_some() + && other.data.is_some() + && other.data.as_ref().unwrap() != self.data.as_ref().unwrap() + { + return false; + } + // Check data: if exactly one is Some, return false + if self.data.is_some() != other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == u64::from(other.lamports) + } +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +// COMMENTED OUT: Now using ZeroCopyNew derive macro instead +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccountData { +// type Config = u32; // data_capacity +// type Output = >::Output; + +// fn new_zero_copy( +// bytes: &'a mut [u8], +// data_capacity: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountDataMetaMut>::from_prefix(bytes)?; +// // For u8 slices we just use &mut [u8] so we init the len and the split mut separately. 
+// { +// light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::::new_at( +// data_capacity.into(), +// bytes, +// )?; +// } +// // Split off len for +// let (_, bytes) = bytes.split_at_mut(4); +// let (data, bytes) = bytes.split_at_mut(data_capacity as usize); +// let (data_hash, bytes) = Ref::<&mut [u8], [u8; 32]>::from_prefix(bytes)?; +// Ok(( +// ZCompressedAccountDataMut { +// __meta, +// data, +// data_hash, +// }, +// bytes, +// )) +// } +// } + +#[test] +fn test_compressed_account_data_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountDataConfig { data: 10 }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccountData::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccountData::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set discriminator + mut_account.__meta.discriminator = [1, 2, 3, 4, 5, 6, 7, 8]; + + // Test that we can write to data + mut_account.data[0] = 42; + mut_account.data[1] = 43; + + // Test that we can set data_hash + mut_account.data_hash[0] = 99; + mut_account.data_hash[1] = 100; + + assert_eq!(mut_account.__meta.discriminator, [1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(mut_account.data[0], 42); + assert_eq!(mut_account.data[1], 43); + assert_eq!(mut_account.data_hash[0], 99); + assert_eq!(mut_account.data_hash[1], 100); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = CompressedAccountData::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized_account, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!( + 
deserialized_account.__meta.discriminator, + [1, 2, 3, 4, 5, 6, 7, 8] + ); + assert_eq!(deserialized_account.data.len(), 10); + assert_eq!(deserialized_account.data[0], 42); + assert_eq!(deserialized_account.data[1], 43); + assert_eq!(deserialized_account.data_hash[0], 99); + assert_eq!(deserialized_account.data_hash[1], 100); +} + +#[test] +fn test_compressed_account_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountConfig { + address: (true, ()), + data: (true, CompressedAccountDataConfig { data: 10 }), + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccount::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccount::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mut_account.__meta.owner = [1u8; 32]; + mut_account.__meta.lamports = 12345u64.into(); + mut_account.address.as_mut().unwrap()[0] = 42; + mut_account.data.as_mut().unwrap().data[0] = 99; + + // Test deserialize + let (deserialized, _) = CompressedAccount::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.owner, [1u8; 32]); + assert_eq!(u64::from(deserialized.__meta.lamports), 12345u64); + assert_eq!(deserialized.address.as_ref().unwrap()[0], 42); + assert_eq!(deserialized.data.as_ref().unwrap().data[0], 99); +} + +#[test] +fn test_instruction_data_invoke_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + // Create different configs to test various combinations + let compressed_account_config1 = CompressedAccountZeroCopyNew { + address_enabled: true, + data_enabled: true, + data_capacity: 10, + }; + + let compressed_account_config2 = CompressedAccountZeroCopyNew { + address_enabled: false, + data_enabled: true, + 
data_capacity: 5, + }; + + let compressed_account_config3 = CompressedAccountZeroCopyNew { + address_enabled: true, + data_enabled: false, + data_capacity: 0, + }; + + let compressed_account_config4 = CompressedAccountZeroCopyNew { + address_enabled: false, + data_enabled: false, + data_capacity: 0, + }; + + let config = InstructionDataInvokeConfig { + proof: (true, CompressedProofConfig {}), // Enable proof + input_compressed_accounts_with_merkle_context: vec![ + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config1.address_enabled, ()), + data: ( + compressed_account_config1.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config1.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config2.address_enabled, ()), + data: ( + compressed_account_config2.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config2.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + ], + output_compressed_accounts: vec![ + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config3.address_enabled, ()), + data: ( + compressed_account_config3.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config3.data_capacity, + }, + ), + }, + }, + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config4.address_enabled, ()), + data: ( + compressed_account_config4.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config4.data_capacity, + }, + ), + }, + }, + ], + relay_fee: true, // Enable relay fee + new_address_params: vec![ + NewAddressParamsPackedConfig {}, + 
NewAddressParamsPackedConfig {}, + ], // Length 2 + compress_or_decompress_lamports: true, // Enable decompress lamports + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + if let Err(ref e) = result { + eprintln!("Error: {:?}", e); + } + assert!(result.is_ok()); + let (_instruction_data, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test deserialization round-trip first + let (mut deserialized, _) = InstructionDataInvoke::zero_copy_at_mut(&mut bytes).unwrap(); + + // Now set values and test again + *deserialized.is_compress = 1; + + // Set proof values + if let Some(proof) = &mut deserialized.proof { + proof.a[0] = 42; + proof.b[0] = 43; + proof.c[0] = 44; + } + + // Set relay fee value + if let Some(relay_fee) = &mut deserialized.relay_fee { + **relay_fee = 12345u64.into(); + } + + // Set decompress lamports value + if let Some(decompress_lamports) = &mut deserialized.compress_or_decompress_lamports { + **decompress_lamports = 67890u64.into(); + } + + // Set first input account values + let first_input = &mut deserialized.input_compressed_accounts_with_merkle_context[0]; + first_input.compressed_account.__meta.owner[0] = 11; + first_input.compressed_account.__meta.lamports = 1000u64.into(); + if let Some(address) = &mut first_input.compressed_account.address { + address[0] = 22; + } + if let Some(data) = &mut first_input.compressed_account.data { + data.__meta.discriminator[0] = 33; + data.data[0] = 99; + data.data_hash[0] = 55; + } + + // Set first output account values + let first_output = &mut deserialized.output_compressed_accounts[0]; + first_output.compressed_account.__meta.owner[0] = 77; + 
first_output.compressed_account.__meta.lamports = 2000u64.into(); + if let Some(address) = &mut first_output.compressed_account.address { + address[0] = 88; + } + + // Verify basic structure with vectors of length 2 + assert_eq!( + deserialized + .input_compressed_accounts_with_merkle_context + .len(), + 2 + ); // Length 2 + assert_eq!(deserialized.output_compressed_accounts.len(), 2); // Length 2 + assert_eq!(deserialized.new_address_params.len(), 2); // Length 2 + assert!(deserialized.proof.is_some()); // Enabled + assert!(deserialized.relay_fee.is_some()); // Enabled + assert!(deserialized.compress_or_decompress_lamports.is_some()); // Enabled + assert_eq!(*deserialized.is_compress, 1); + + // Test data access and modification + if let Some(proof) = &deserialized.proof { + // Verify we can access proof fields and our written values + assert_eq!(proof.a[0], 42); + assert_eq!(proof.b[0], 43); + assert_eq!(proof.c[0], 44); + } + + // Verify option integer values + if let Some(relay_fee) = &deserialized.relay_fee { + assert_eq!(u64::from(**relay_fee), 12345); + } + + if let Some(decompress_lamports) = &deserialized.compress_or_decompress_lamports { + assert_eq!(u64::from(**decompress_lamports), 67890); + } + + // Test accessing first input account (config1: address=true, data=true, capacity=10) + let first_input = &deserialized.input_compressed_accounts_with_merkle_context[0]; + assert_eq!(first_input.compressed_account.__meta.owner[0], 11); // Our written value + assert_eq!( + u64::from(first_input.compressed_account.__meta.lamports), + 1000 + ); // Our written value + assert!(first_input.compressed_account.address.is_some()); // Should be enabled + assert!(first_input.compressed_account.data.is_some()); // Should be enabled + if let Some(address) = &first_input.compressed_account.address { + assert_eq!(address[0], 22); // Our written value + } + if let Some(data) = &first_input.compressed_account.data { + assert_eq!(data.data.len(), 10); // Should have capacity 10 
+ assert_eq!(data.__meta.discriminator[0], 33); // Our written value + assert_eq!(data.data[0], 99); // Our written value + assert_eq!(data.data_hash[0], 55); // Our written value + } + + // Test accessing second input account (config2: address=false, data=true, capacity=5) + let second_input = &deserialized.input_compressed_accounts_with_merkle_context[1]; + assert_eq!(second_input.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_input.compressed_account.address.is_none()); // Should be disabled + assert!(second_input.compressed_account.data.is_some()); // Should be enabled + if let Some(data) = &second_input.compressed_account.data { + assert_eq!(data.data.len(), 5); // Should have capacity 5 + } + + // Test accessing first output account (config3: address=true, data=false, capacity=0) + let first_output = &deserialized.output_compressed_accounts[0]; + assert_eq!(first_output.compressed_account.__meta.owner[0], 77); // Our written value + assert_eq!( + u64::from(first_output.compressed_account.__meta.lamports), + 2000 + ); // Our written value + assert!(first_output.compressed_account.address.is_some()); // Should be enabled + assert!(first_output.compressed_account.data.is_none()); // Should be disabled + if let Some(address) = &first_output.compressed_account.address { + assert_eq!(address[0], 88); // Our written value + } + + // Test accessing second output account (config4: address=false, data=false, capacity=0) + let second_output = &deserialized.output_compressed_accounts[1]; + assert_eq!(second_output.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_output.compressed_account.address.is_none()); // Should be disabled + assert!(second_output.compressed_account.data.is_none()); // Should be disabled +} + +#[test] +fn readme() { + use borsh::{BorshDeserialize, BorshSerialize}; + use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + + #[repr(C)] + #[derive(Debug, 
PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct MyStructOption { + pub a: u8, + pub b: u16, + pub vec: Vec>, + pub c: Option, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, + )] + pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + // Test the new ZeroCopyNew functionality + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct TestConfigStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub option: Option, + } + + let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + // Use the struct with zero-copy deserialization + let bytes = my_struct.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), my_struct.byte_len()); + let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1); + let org_struct: MyStruct = zero_copy.into(); + assert_eq!(org_struct, my_struct); + // { + // let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut bytes).unwrap(); + // zero_copy_mut.a = 42; + // } + // let borsh = MyStruct::try_from_slice(&bytes).unwrap(); + // assert_eq!(borsh.a, 42u8); +} + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct InstructionDataInvokeCpi { + pub proof: Option, + pub new_address_params: Vec, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, + pub cpi_context: Option, +} + +impl PartialEq> for InstructionDataInvokeCpi { + fn eq(&self, other: &ZInstructionDataInvokeCpi) -> bool { + // Compare proof + match (&self.proof, &other.proof) { + (Some(ref self_proof), Some(ref other_proof)) => { + if self_proof.a != other_proof.a 
+ || self_proof.b != other_proof.b + || self_proof.c != other_proof.c + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare vectors lengths first + if self.new_address_params.len() != other.new_address_params.len() + || self.input_compressed_accounts_with_merkle_context.len() + != other.input_compressed_accounts_with_merkle_context.len() + || self.output_compressed_accounts.len() != other.output_compressed_accounts.len() + { + return false; + } + + // Compare new_address_params + for (self_param, other_param) in self + .new_address_params + .iter() + .zip(other.new_address_params.iter()) + { + if self_param.seed != other_param.seed + || self_param.address_queue_account_index != other_param.address_queue_account_index + || self_param.address_merkle_tree_account_index + != other_param.address_merkle_tree_account_index + || self_param.address_merkle_tree_root_index + != u16::from(other_param.address_merkle_tree_root_index) + { + return false; + } + } + + // Compare input accounts + for (self_input, other_input) in self + .input_compressed_accounts_with_merkle_context + .iter() + .zip(other.input_compressed_accounts_with_merkle_context.iter()) + { + if self_input != other_input { + return false; + } + } + + // Compare output accounts + for (self_output, other_output) in self + .output_compressed_accounts + .iter() + .zip(other.output_compressed_accounts.iter()) + { + if self_output != other_output { + return false; + } + } + + // Compare relay_fee + match (&self.relay_fee, &other.relay_fee) { + (Some(self_fee), Some(other_fee)) => { + if *self_fee != u64::from(**other_fee) { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare compress_or_decompress_lamports + match ( + &self.compress_or_decompress_lamports, + &other.compress_or_decompress_lamports, + ) { + (Some(self_lamports), Some(other_lamports)) => { + if *self_lamports != u64::from(**other_lamports) { + return false; + } + } + (None, None) => {} + 
_ => return false, + } + + // Compare is_compress (bool vs u8) + if self.is_compress != (other.is_compress != 0) { + return false; + } + + // Compare cpi_context + match (&self.cpi_context, &other.cpi_context) { + (Some(self_ctx), Some(other_ctx)) => { + if self_ctx.set_context != (other_ctx.set_context != 0) + || self_ctx.first_set_context != (other_ctx.first_set_context != 0) + || self_ctx.cpi_context_account_index != other_ctx.cpi_context_account_index + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + true + } +} + +impl PartialEq for ZInstructionDataInvokeCpi<'_> { + fn eq(&self, other: &InstructionDataInvokeCpi) -> bool { + other.eq(self) + } +} + +impl PartialEq> + for PackedCompressedAccountWithMerkleContext +{ + fn eq(&self, other: &ZPackedCompressedAccountWithMerkleContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != *other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_context + if self.merkle_context.merkle_tree_pubkey_index + != 
other.merkle_context.__meta.merkle_tree_pubkey_index + || self.merkle_context.nullifier_queue_pubkey_index + != other.merkle_context.__meta.nullifier_queue_pubkey_index + || self.merkle_context.leaf_index != u32::from(other.merkle_context.__meta.leaf_index) + || self.merkle_context.prove_by_index != other.merkle_context.prove_by_index() + { + return false; + } + + // Compare root_index and read_only + if self.root_index != u16::from(*other.root_index) + || self.read_only != (other.read_only != 0) + { + return false; + } + + true + } +} + +impl PartialEq> + for OutputCompressedAccountWithPackedContext +{ + fn eq(&self, other: &ZOutputCompressedAccountWithPackedContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != *other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_tree_index + if self.merkle_tree_index != other.merkle_tree_index { + return false; + } + + true + } +} diff --git a/program-libs/zero-copy-derive/tests/random.rs b/program-libs/zero-copy-derive/tests/random.rs new file mode 100644 
index 0000000000..993adef704 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/random.rs @@ -0,0 +1,651 @@ +#![cfg(feature = "mut")] +use std::assert_eq; + +use borsh::BorshDeserialize; +use light_zero_copy::{borsh::Deserialize, init_mut::ZeroCopyNew}; +use rand::{ + rngs::{StdRng, ThreadRng}, + Rng, +}; + +mod instruction_data; +use instruction_data::{ + CompressedAccount, + CompressedAccountConfig, + CompressedAccountData, + CompressedAccountDataConfig, + CompressedCpiContext, + CompressedCpiContextConfig, + CompressedProof, + CompressedProofConfig, + InstructionDataInvoke, + // Config types (generated by ZeroCopyNew derive) + InstructionDataInvokeConfig, + InstructionDataInvokeCpi, + InstructionDataInvokeCpiConfig, + NewAddressParamsPacked, + NewAddressParamsPackedConfig, + OutputCompressedAccountWithPackedContext, + OutputCompressedAccountWithPackedContextConfig, + PackedCompressedAccountWithMerkleContext, + PackedCompressedAccountWithMerkleContextConfig, + PackedMerkleContext, + PackedMerkleContextConfig, + Pubkey, + // Zero-copy mutable types + ZInstructionDataInvokeCpiMut, + ZInstructionDataInvokeMut, +}; + +// Function to populate mutable zero-copy structure with data from InstructionDataInvokeCpi +fn populate_invoke_cpi_zero_copy( + src: &InstructionDataInvokeCpi, + dst: &mut ZInstructionDataInvokeCpiMut, +) { + *dst.is_compress = if src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = 
src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); + dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + 
.compressed_account + .owner + .copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = (&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } + + // Copy cpi_context if present + if let (Some(src_ctx), Some(dst_ctx)) = (&src.cpi_context, &mut dst.cpi_context) { + dst_ctx.set_context = if src_ctx.set_context { 1 } else { 0 }; + dst_ctx.first_set_context = if src_ctx.first_set_context { 1 } else { 0 }; + dst_ctx.cpi_context_account_index = src_ctx.cpi_context_account_index; + } +} + +// Function to populate mutable zero-copy structure with data from InstructionDataInvoke +fn populate_invoke_zero_copy(src: &InstructionDataInvoke, dst: &mut ZInstructionDataInvokeMut) { + *dst.is_compress = if src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + 
dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); 
+ dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + .compressed_account + .owner + .copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = (&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } +} + +fn get_rnd_instruction_data_invoke_cpi(rng: &mut StdRng) -> InstructionDataInvokeCpi { + InstructionDataInvokeCpi { + proof: Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }), + 
new_address_params: vec![get_rnd_new_address_params(rng); rng.gen_range(0..10)], + input_compressed_accounts_with_merkle_context: vec![ + get_rnd_test_input_account(rng); + rng.gen_range(0..10) + ], + output_compressed_accounts: vec![get_rnd_test_output_account(rng); rng.gen_range(0..10)], + relay_fee: None, + compress_or_decompress_lamports: rng.gen(), + is_compress: rng.gen(), + cpi_context: Some(get_rnd_cpi_context(rng)), + } +} + +fn get_rnd_cpi_context(rng: &mut StdRng) -> CompressedCpiContext { + CompressedCpiContext { + first_set_context: rng.gen(), + set_context: rng.gen(), + cpi_context_account_index: rng.gen(), + } +} + +fn get_rnd_test_account_data(rng: &mut StdRng) -> CompressedAccountData { + CompressedAccountData { + discriminator: rng.gen(), + data: (0..100).map(|_| rng.gen()).collect::>(), + data_hash: rng.gen(), + } +} + +fn get_rnd_test_account(rng: &mut StdRng) -> CompressedAccount { + CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: rng.gen(), + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + } +} + +fn get_rnd_test_output_account(rng: &mut StdRng) -> OutputCompressedAccountWithPackedContext { + OutputCompressedAccountWithPackedContext { + compressed_account: get_rnd_test_account(rng), + merkle_tree_index: rng.gen(), + } +} + +fn get_rnd_test_input_account(rng: &mut StdRng) -> PackedCompressedAccountWithMerkleContext { + PackedCompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: 100, + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + }, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index: rng.gen(), + nullifier_queue_pubkey_index: rng.gen(), + leaf_index: rng.gen(), + prove_by_index: rng.gen(), + }, + root_index: rng.gen(), + read_only: false, + } +} + +fn get_rnd_new_address_params(rng: &mut StdRng) -> NewAddressParamsPacked { + 
NewAddressParamsPacked { + seed: rng.gen(), + address_queue_account_index: rng.gen(), + address_merkle_tree_account_index: rng.gen(), + address_merkle_tree_root_index: rng.gen(), + } +} + +// Generate config for InstructionDataInvoke based on the actual data +fn generate_random_invoke_config( + invoke_ref: &InstructionDataInvoke, +) -> InstructionDataInvokeConfig { + InstructionDataInvokeConfig { + proof: (invoke_ref.proof.is_some(), CompressedProofConfig {}), + input_compressed_accounts_with_merkle_context: invoke_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_ref.relay_fee.is_some(), + new_address_params: invoke_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + compress_or_decompress_lamports: invoke_ref.compress_or_decompress_lamports.is_some(), + } +} + +// Generate config for InstructionDataInvokeCpi based on the actual data +fn generate_random_invoke_cpi_config( + invoke_cpi_ref: &InstructionDataInvokeCpi, +) -> InstructionDataInvokeCpiConfig { + 
InstructionDataInvokeCpiConfig { + proof: (invoke_cpi_ref.proof.is_some(), CompressedProofConfig {}), + new_address_params: invoke_cpi_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + input_compressed_accounts_with_merkle_context: invoke_cpi_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_cpi_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_cpi_ref.relay_fee.is_some(), + compress_or_decompress_lamports: invoke_cpi_ref.compress_or_decompress_lamports.is_some(), + cpi_context: ( + invoke_cpi_ref.cpi_context.is_some(), + CompressedCpiContextConfig {}, + ), + } +} + +#[test] +fn test_invoke_ix_data_deserialize_rnd() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. 
+ println!("\n\ne2e test seed for invoke_ix_data {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 1000; + for i in 0..num_iters { + // Create randomized instruction data + let invoke_ref = InstructionDataInvoke { + proof: if rng.gen() { + Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }) + } else { + None + }, + input_compressed_accounts_with_merkle_context: if i % 5 == 0 { + // Only add inputs occasionally to keep test manageable + vec![get_rnd_test_input_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + output_compressed_accounts: if i % 4 == 0 { + vec![get_rnd_test_output_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + relay_fee: None, // Relay fee is currently not supported + new_address_params: if i % 3 == 0 { + vec![get_rnd_new_address_params(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + compress_or_decompress_lamports: if rng.gen() { Some(rng.gen()) } else { None }, + is_compress: rng.gen(), + }; + + // 1. Generate config based on the random data + let config = generate_random_invoke_config(&invoke_ref); + + // 2. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 3. Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 4. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 5. 
Populate the mutable zero-copy structure with random data + populate_invoke_zero_copy(&invoke_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvoke::deserialize(&mut bytes.as_slice()).unwrap(); + // 6. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvoke::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_ref, borsh_ref); + + // 7. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvoke with {} inputs, {} outputs, {} new_addresses", + invoke_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_ref.output_compressed_accounts.len(), + invoke_ref.new_address_params.len()); + } +} + +#[test] +fn test_instruction_data_invoke_cpi_rnd() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. + println!("\n\ne2e test seed {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 10_000; + for _ in 0..num_iters { + // 1. Generate random CPI instruction data + let invoke_cpi_ref = get_rnd_instruction_data_invoke_cpi(&mut rng); + + // 2. Generate config based on the random data + let config = generate_random_invoke_cpi_config(&invoke_cpi_ref); + + // 3. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvokeCpi::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 4. 
Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvokeCpi::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create CPI zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 5. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 6. Populate the mutable zero-copy structure with random data + populate_invoke_cpi_zero_copy(&invoke_cpi_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvokeCpi::deserialize(&mut bytes.as_slice()).unwrap(); + // 7. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvokeCpi::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_cpi_ref, borsh_ref); + + // 8. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvokeCpi with {} inputs, {} outputs, {} new_addresses", + invoke_cpi_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_cpi_ref.output_compressed_accounts.len(), + invoke_cpi_ref.new_address_params.len()); + } +} diff --git a/program-libs/zero-copy/Cargo.toml b/program-libs/zero-copy/Cargo.toml index ea683e5e48..7b67262bb2 100644 --- a/program-libs/zero-copy/Cargo.toml +++ b/program-libs/zero-copy/Cargo.toml @@ -11,13 +11,17 @@ default = [] solana = ["solana-program-error"] pinocchio = ["dep:pinocchio"] std = [] +derive = ["light-zero-copy-derive"] +mut = ["light-zero-copy-derive/mut"] [dependencies] solana-program-error = { workspace = true, optional = true } pinocchio = { workspace = true, optional = true } thiserror = { workspace = true } zerocopy = { workspace = true } +light-zero-copy-derive = { workspace = true, optional = true } 
[dev-dependencies] rand = { workspace = true } zerocopy = { workspace = true, features = ["derive"] } +borsh = { workspace = true } diff --git a/program-libs/zero-copy/README.md b/program-libs/zero-copy/README.md index d82ee39232..e28f535469 100644 --- a/program-libs/zero-copy/README.md +++ b/program-libs/zero-copy/README.md @@ -37,6 +37,3 @@ light-zero-copy = { version = "0.1.0", features = ["anchor"] } ### Security Considerations - do not use on a 32 bit target with length greater than u32 - only length until u64 is supported - -### Tests -- `cargo test --features std` diff --git a/program-libs/zero-copy/src/borsh.rs b/program-libs/zero-copy/src/borsh.rs index c7e4fbe4db..a60d73f2ba 100644 --- a/program-libs/zero-copy/src/borsh.rs +++ b/program-libs/zero-copy/src/borsh.rs @@ -5,7 +5,7 @@ use core::{ use std::vec::Vec; use zerocopy::{ - little_endian::{U16, U32, U64}, + little_endian::{I16, I32, I64, U16, U32, U64}, FromBytes, Immutable, KnownLayout, Ref, }; @@ -52,8 +52,6 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for Option { impl Deserialize<'_> for u8 { type Output = Self; - /// Not a zero copy but cheaper. - /// A u8 should not be deserialized on it's own but as part of a struct. #[inline] fn zero_copy_at(bytes: &[u8]) -> Result<(u8, &[u8]), ZeroCopyError> { if bytes.len() < size_of::() { @@ -64,23 +62,59 @@ impl Deserialize<'_> for u8 { } } +impl<'a> Deserialize<'a> for bool { + type Output = u8; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + let (bytes, remaining_bytes) = bytes.split_at(size_of::()); + Ok((bytes[0], remaining_bytes)) + } +} + macro_rules! 
impl_deserialize_for_primitive { - ($($t:ty),*) => { + ($(($native:ty, $zerocopy:ty)),*) => { $( - impl<'a> Deserialize<'a> for $t { - type Output = Ref<&'a [u8], $t>; + impl<'a> Deserialize<'a> for $native { + type Output = Ref<&'a [u8], $zerocopy>; #[inline] fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { - Self::Output::zero_copy_at(bytes) + Ref::<&'a [u8], $zerocopy>::from_prefix(bytes).map_err(ZeroCopyError::from) } } )* }; } -impl_deserialize_for_primitive!(u16, i16, u32, i32, u64, i64); -impl_deserialize_for_primitive!(U16, U32, U64); +impl_deserialize_for_primitive!( + (u16, U16), + (u32, U32), + (u64, U64), + (i16, I16), + (i32, I32), + (i64, I64), + (U16, U16), + (U32, U32), + (U64, U64), + (I16, I16), + (I32, I32), + (I64, I64) +); + +// Implement Deserialize for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> Deserialize<'a> for [T; N] { + type Output = Ref<&'a [u8], [T; N]>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} impl<'a, T: Deserialize<'a>> Deserialize<'a> for Vec { type Output = Vec; @@ -88,8 +122,14 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for Vec { fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } let mut slices = 
Vec::with_capacity(num_slices); for _ in 0..num_slices { let (slice, _bytes) = T::zero_copy_at(bytes)?; @@ -138,6 +178,55 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for VecU8 { } } +pub trait ZeroCopyStructInner { + type ZeroCopyInner; +} + +impl ZeroCopyStructInner for u64 { + type ZeroCopyInner = U64; +} +impl ZeroCopyStructInner for u32 { + type ZeroCopyInner = U32; +} +impl ZeroCopyStructInner for u16 { + type ZeroCopyInner = U16; +} +impl ZeroCopyStructInner for u8 { + type ZeroCopyInner = u8; +} + +impl ZeroCopyStructInner for U16 { + type ZeroCopyInner = U16; +} +impl ZeroCopyStructInner for U32 { + type ZeroCopyInner = U32; +} +impl ZeroCopyStructInner for U64 { + type ZeroCopyInner = U64; +} + +impl ZeroCopyStructInner for Vec { + type ZeroCopyInner = Vec; +} + +impl ZeroCopyStructInner for Option { + type ZeroCopyInner = Option; +} + +// Add ZeroCopyStructInner for array types +impl ZeroCopyStructInner for [u8; N] { + type ZeroCopyInner = Ref<&'static [u8], [u8; N]>; +} + +pub fn borsh_vec_u8_as_slice(bytes: &[u8]) -> Result<(&[u8], &[u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + if num_slices > bytes.len() { + return Err(ZeroCopyError::ArraySize(num_slices, bytes.len())); + } + Ok(bytes.split_at(num_slices)) +} + #[test] fn test_vecu8() { use std::vec; @@ -224,3 +313,561 @@ fn test_deserialize_vecu8() { assert_eq!(vec, std::vec![4, 5, 6]); assert_eq!(remaining, &[]); } + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::{ZeroCopyStructInner, *}; + use crate::slice::ZeroCopySliceBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. 
represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize + // 1.6. Elements in an Option must implement Deserialize + // 1.7. a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInner + // 2. Deserialize + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. + + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + // pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + + // } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, 
Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, + } + + impl<'a> 
Deref for ZStruct3<'a> { + type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + } + + impl<'a> Deref for 
ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + #[test] + fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, + } + + impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::ZeroCopyInner>>::zero_copy_at(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + 
assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> 
{ + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, + } + + impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), 
ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } +} diff --git a/program-libs/zero-copy/src/borsh_mut.rs b/program-libs/zero-copy/src/borsh_mut.rs new file mode 100644 index 0000000000..38d5df2b65 --- /dev/null +++ b/program-libs/zero-copy/src/borsh_mut.rs @@ -0,0 +1,965 @@ +use core::{ + mem::size_of, + ops::{Deref, DerefMut}, +}; +use std::vec::Vec; + +use zerocopy::{ + little_endian::{I16, I32, I64, U16, U32, U64}, + FromBytes, Immutable, KnownLayout, Ref, +}; + +use crate::errors::ZeroCopyError; + +pub trait DeserializeMut<'a> +where + Self: Sized, +{ + // TODO: rename to ZeroCopy, can be used as ::ZeroCopy + type Output; + fn zero_copy_at_mut(bytes: &'a mut [u8]) + -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>; +} + +// Implement DeserializeMut for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> DeserializeMut<'a> for [T; N] { + type Output = Ref<&'a mut [u8], [T; N]>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a mut [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Option { + type Output = Option; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, 
bytes.len())); + } + let (option_byte, bytes) = bytes.split_at_mut(1); + Ok(match option_byte[0] { + 0u8 => (None, bytes), + 1u8 => { + let (value, bytes) = T::zero_copy_at_mut(bytes)?; + (Some(value), bytes) + } + _ => return Err(ZeroCopyError::InvalidOptionByte(option_byte[0])), + }) + } +} + +impl<'a> DeserializeMut<'a> for u8 { + type Output = Ref<&'a mut [u8], u8>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], u8>::from_prefix(bytes).map_err(ZeroCopyError::from) + } +} + +impl<'a> DeserializeMut<'a> for bool { + type Output = Ref<&'a mut [u8], u8>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], u8>::from_prefix(bytes).map_err(ZeroCopyError::from) + } +} + +// Implementation for specific zerocopy little-endian types +impl<'a, T: KnownLayout + Immutable + FromBytes> DeserializeMut<'a> for Ref<&'a mut [u8], T> { + type Output = Ref<&'a mut [u8], T>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], T>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Vec { + type Output = Vec; + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } + let mut slices = Vec::with_capacity(num_slices); + for _ in 0..num_slices { + let (slice, _bytes) = 
T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +macro_rules! impl_deserialize_for_primitive { + ($(($native:ty, $zerocopy:ty)),*) => { + $( + impl<'a> DeserializeMut<'a> for $native { + type Output = Ref<&'a mut [u8], $zerocopy>; + + #[inline] + fn zero_copy_at_mut(bytes: &'a mut [u8]) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], $zerocopy>::from_prefix(bytes).map_err(ZeroCopyError::from) + } + } + )* + }; +} + +impl_deserialize_for_primitive!( + (u16, U16), + (u32, U32), + (u64, U64), + (i16, I16), + (i32, I32), + (i64, I64), + (U16, U16), + (U32, U32), + (U64, U64), + (I16, I16), + (I32, I32), + (I64, I64) +); + +pub fn borsh_vec_u8_as_slice_mut( + bytes: &mut [u8], +) -> Result<(&mut [u8], &mut [u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + if num_slices > bytes.len() { + return Err(ZeroCopyError::ArraySize(num_slices, bytes.len())); + } + Ok(bytes.split_at_mut(num_slices)) +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct VecU8(Vec); +impl VecU8 { + pub fn new() -> Self { + Self(Vec::new()) + } +} + +impl Deref for VecU8 { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for VecU8 { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for VecU8 { + type Output = Vec; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], u8>::from_prefix(bytes)?; + let mut slices = Vec::with_capacity(*num_slices as usize); + for _ in 0..(*num_slices as usize) { + let (slice, _bytes) = T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +pub trait ZeroCopyStructInnerMut { + type ZeroCopyInnerMut; +} + 
+impl ZeroCopyStructInnerMut for u64 { + type ZeroCopyInnerMut = U64; +} +impl ZeroCopyStructInnerMut for u32 { + type ZeroCopyInnerMut = U32; +} +impl ZeroCopyStructInnerMut for u16 { + type ZeroCopyInnerMut = U16; +} +impl ZeroCopyStructInnerMut for u8 { + type ZeroCopyInnerMut = u8; +} +impl ZeroCopyStructInnerMut for U64 { + type ZeroCopyInnerMut = U64; +} +impl ZeroCopyStructInnerMut for U32 { + type ZeroCopyInnerMut = U32; +} +impl ZeroCopyStructInnerMut for U16 { + type ZeroCopyInnerMut = U16; +} + +impl ZeroCopyStructInnerMut for Vec { + type ZeroCopyInnerMut = Vec; +} + +impl ZeroCopyStructInnerMut for Option { + type ZeroCopyInnerMut = Option; +} + +impl ZeroCopyStructInnerMut for [u8; N] { + type ZeroCopyInnerMut = Ref<&'static mut [u8], [u8; N]>; +} + +#[test] +fn test_vecu8() { + use std::vec; + let mut bytes = vec![8, 1u8, 2, 3, 4, 5, 6, 7, 8]; + let (vec, remaining_bytes) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + vec![1u8, 2, 3, 4, 5, 6, 7, 8] + ); + assert_eq!(remaining_bytes, &mut []); +} + +#[test] +fn test_deserialize_mut_ref() { + let mut bytes = [1, 0, 0, 0]; // Little-endian representation of 1 + let (ref_data, remaining) = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(u32::from(*ref_data), 1); + assert_eq!(remaining, &mut []); + let res = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_option_some() { + let mut bytes = [1, 2]; // 1 indicates Some, followed by the value 2 + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value.map(|x| *x), Some(2)); + assert_eq!(remaining, &mut []); + let res = Option::::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::ArraySize(1, 0))); + let mut bytes = [2, 0]; // 2 indicates invalid option byte + let res = Option::::zero_copy_at_mut(&mut bytes); + assert_eq!(res, 
Err(ZeroCopyError::InvalidOptionByte(2))); +} + +#[test] +fn test_deserialize_mut_option_none() { + let mut bytes = [0]; // 0 indicates None + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value, None); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_deserialize_mut_u8() { + let mut bytes = [0xFF]; // Value 255 + let (value, remaining) = u8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(*value, 255); + assert_eq!(remaining, &mut []); + let res = u8::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_u16() { + let mut bytes = 2323u16.to_le_bytes(); + let (value, remaining) = u16::zero_copy_at_mut(bytes.as_mut_slice()).unwrap(); + assert_eq!(*value, 2323u16); + assert_eq!(remaining, &mut []); + let mut value = [0u8]; + let res = u16::zero_copy_at_mut(&mut value); + + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_vec() { + let mut bytes = [2, 0, 0, 0, 1, 2]; // Length 2, followed by values 1 and 2 + let (vec, remaining) = Vec::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + std::vec![1u8, 2] + ); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_vecu8_deref() { + let data = std::vec![1, 2, 3]; + let vec_u8 = VecU8(data.clone()); + assert_eq!(&*vec_u8, &data); + + let mut vec = VecU8::new(); + vec.push(1u8); + assert_eq!(*vec, std::vec![1u8]); +} + +#[test] +fn test_deserialize_mut_vecu8() { + let mut bytes = [3, 4, 5, 6]; // Length 3, followed by values 4, 5, 6 + let (vec, remaining) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + std::vec![4, 5, 6] + ); + assert_eq!(remaining, &mut []); +} + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::*; + use 
crate::slice_mut::ZeroCopySliceMutBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement DeserializeMut + // 1.6. Elements in an Option must implement DeserializeMut + // 1.7. a type that does not implement Copy must implement DeserializeMut, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInnerMut + // 2. DeserializeMut + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. 
+ + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes, IntoBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + pub meta: Ref<&'a mut [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a mut [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl DerefMut for ZStruct1<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let ref_struct = Struct1 { a: 1, b: 2 }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (mut struct1, remaining) = Struct1::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &mut []); + struct1.meta.a = 2; + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a mut [u8], ZStruct2Meta>, + pub vec: &'a mut [u8], + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() 
{ + return false; + } + self.vec == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a mut [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct2Meta>::from_prefix(bytes)?; + let (len, bytes) = bytes.split_at_mut(4); + let len = U32::from_bytes( + len.try_into() + .map_err(|_| ZeroCopyError::ArraySize(4, len.len()))?, + ); + let (vec, bytes) = bytes.split_at_mut(u32::from(len) as usize); + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let ref_struct = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a mut [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, u8>, + pub c: Ref<&'a mut [u8], U64>, + } + + impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a mut [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], 
ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::zero_copy_at_mut(bytes)?; + let (c, bytes) = Ref::<&mut [u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let ref_struct = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + impl<'a> DeserializeMut<'a> for Struct4Nested { + type Output = ZStruct4Nested; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], ZStruct4Nested>::from_prefix(bytes)?; + Ok((*bytes, remaining_bytes)) + } + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInnerMut for Struct4Nested { + type ZeroCopyInnerMut = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInnerMut, + pub b: ::ZeroCopyInnerMut, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a mut [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + pub c: Ref<&'a mut [u8], ::ZeroCopyInnerMut>, + pub 
vec_2: + ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + } + + impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a mut [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&mut [u8], ::ZeroCopyInnerMut>::from_prefix( + bytes, + )?; + let (vec_2, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + /// TODO: + /// - add SIZE const generic DeserializeMut trait + /// - add new with data function + impl Struct4 { + // pub fn byte_len(&self) -> usize { + // size_of::() + // + size_of::() + // + size_of::() * self.vec.len() + // + size_of::() + // + size_of::() * self.vec_2.len() + // } + + pub fn new_with_data<'a>( + bytes: &'a mut [u8], + data: &Struct4, + ) -> (ZStruct4<'a>, &'a mut [u8]) { + let (mut zero_copy, bytes) = + ::zero_copy_at_mut(bytes).unwrap(); + zero_copy.meta.a = data.a; + zero_copy.meta.b = data.b.into(); + zero_copy + .vec + .iter_mut() + .zip(data.vec.iter()) + .for_each(|(x, y)| *x = *y); + (zero_copy, bytes) + } + } + + #[test] + fn test_struct_4() { + let ref_struct = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, 
&mut []); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInnerMut>>, + } + + impl<'a> DeserializeMut<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::< + ZeroCopySliceMutBorsh<::ZeroCopyInnerMut>, + >::zero_copy_at_mut(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let ref_struct = Struct5 { + a: vec![vec![1u8; 32]; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. 
+ #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let ref_struct = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a mut [u8], ZStruct7Meta>, + pub option: Option<>::Output>, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option.as_ref().map(|x| **x) == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a mut [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, 
bytes) = Ref::<&mut [u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as DeserializeMut<'a>>::zero_copy_at_mut(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let ref_struct = Struct7 { + a: 1, + b: 2, + option: Some(3), + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option.map(|x| *x), Some(3)); + assert_eq!(remaining, &mut []); + + let ref_struct = Struct7 { + a: 1, + b: 2, + option: None, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. 
+ #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: >::Output, + pub b: >::Output, + } + + impl<'a> DeserializeMut<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = + ::ZeroCopyInnerMut::zero_copy_at_mut(bytes)?; + let (b, bytes) = >::zero_copy_at_mut(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + *self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let ref_struct = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } +} diff --git a/program-libs/zero-copy/src/init_mut.rs b/program-libs/zero-copy/src/init_mut.rs new file mode 100644 index 0000000000..acd41b62ee --- /dev/null +++ b/program-libs/zero-copy/src/init_mut.rs @@ -0,0 +1,268 @@ +use 
core::mem::size_of; +use std::vec::Vec; + +use crate::{borsh_mut::DeserializeMut, errors::ZeroCopyError}; + +/// Trait for types that can be initialized in mutable byte slices with configuration +/// +/// This trait provides a way to initialize structures in pre-allocated byte buffers +/// with specific configuration parameters that determine Vec lengths, Option states, etc. +pub trait ZeroCopyNew<'a> +where + Self: Sized, +{ + /// Configuration type needed to initialize this type + type ZeroCopyConfig; + + /// Output type - the mutable zero-copy view of this type + type Output; + + /// Calculate the byte length needed for this type with the given configuration + /// + /// This is essential for allocating the correct buffer size before calling new_zero_copy + fn byte_len(config: &Self::ZeroCopyConfig) -> usize; + + /// Initialize this type in a mutable byte slice with the given configuration + /// + /// Returns the initialized mutable view and remaining bytes + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>; +} + +// Generic implementation for Option +impl<'a, T> ZeroCopyNew<'a> for Option +where + T: ZeroCopyNew<'a>, +{ + type ZeroCopyConfig = (bool, T::ZeroCopyConfig); // (enabled, inner_config) + type Output = Option; + + fn byte_len(config: &Self::ZeroCopyConfig) -> usize { + let (enabled, inner_config) = config; + if *enabled { + // 1 byte for Some discriminant + inner type's byte_len + 1 + T::byte_len(inner_config) + } else { + // Just 1 byte for None discriminant + 1 + } + } + + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + if bytes.is_empty() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + + let (enabled, inner_config) = config; + + if enabled { + bytes[0] = 1; // Some discriminant + let (_, bytes) = bytes.split_at_mut(1); + let (value, bytes) = T::new_zero_copy(bytes, 
inner_config)?; + Ok((Some(value), bytes)) + } else { + bytes[0] = 0; // None discriminant + let (_, bytes) = bytes.split_at_mut(1); + Ok((None, bytes)) + } + } +} + +// Implementation for primitive types (no configuration needed) +impl<'a> ZeroCopyNew<'a> for u64 { + type ZeroCopyConfig = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U64 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for u32 { + type ZeroCopyConfig = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U32 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for u16 { + type ZeroCopyConfig = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U16 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?) 
+ } +} + +impl<'a> ZeroCopyNew<'a> for u8 { + type ZeroCopyConfig = (); + type Output = >::Output; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Use the DeserializeMut trait to create the proper output + Self::zero_copy_at_mut(bytes) + } +} + +impl<'a> ZeroCopyNew<'a> for bool { + type ZeroCopyConfig = (); + type Output = >::Output; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() // bool is serialized as u8 + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Treat bool as u8 + u8::zero_copy_at_mut(bytes) + } +} + +// Implementation for fixed-size arrays +impl< + 'a, + T: Copy + Default + zerocopy::KnownLayout + zerocopy::Immutable + zerocopy::FromBytes, + const N: usize, + > ZeroCopyNew<'a> for [T; N] +{ + type ZeroCopyConfig = (); + type Output = >::Output; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Use the DeserializeMut trait to create the proper output + Self::zero_copy_at_mut(bytes) + } +} + +// Implementation for zerocopy little-endian types +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U16 { + type ZeroCopyConfig = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?) 
+ } +} + +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U32 { + type ZeroCopyConfig = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U64 { + type ZeroCopyConfig = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>; + + fn byte_len(_config: &Self::ZeroCopyConfig) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?) + } +} + +// Implementation for Vec +impl<'a, T: ZeroCopyNew<'a>> ZeroCopyNew<'a> for Vec { + type ZeroCopyConfig = Vec; // Vector of configs for each item + type Output = Vec; + + fn byte_len(config: &Self::ZeroCopyConfig) -> usize { + // 4 bytes for length prefix + sum of byte_len for each element config + 4 + config + .iter() + .map(|config| T::byte_len(config)) + .sum::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + configs: Self::ZeroCopyConfig, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + use zerocopy::{little_endian::U32, Ref}; + + // Write length as U32 + let len = configs.len() as u32; + let (mut len_ref, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + *len_ref = U32::new(len); + + // Initialize each item with its config + let mut items = Vec::with_capacity(configs.len()); + for config in configs { + let (item, remaining_bytes) = T::new_zero_copy(bytes, config)?; + bytes = remaining_bytes; + items.push(item); + } + + Ok((items, bytes)) + } +} diff --git a/program-libs/zero-copy/src/lib.rs 
b/program-libs/zero-copy/src/lib.rs index 297c849d53..3ac6a38948 100644 --- a/program-libs/zero-copy/src/lib.rs +++ b/program-libs/zero-copy/src/lib.rs @@ -10,8 +10,24 @@ pub mod vec; use core::mem::{align_of, size_of}; #[cfg(feature = "std")] pub mod borsh; - -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; +#[cfg(feature = "std")] +pub mod borsh_mut; +#[cfg(feature = "std")] +pub mod init_mut; +#[cfg(feature = "std")] +pub use borsh::ZeroCopyStructInner; +#[cfg(feature = "std")] +pub use init_mut::ZeroCopyNew; +#[cfg(all(feature = "derive", feature = "mut"))] +pub use light_zero_copy_derive::ZeroCopyMut; +#[cfg(feature = "derive")] +pub use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +#[cfg(feature = "derive")] +pub use zerocopy::{ + little_endian::{self, U16, U32, U64}, + Ref, Unaligned, +}; +pub use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; #[cfg(feature = "std")] extern crate std; diff --git a/program-libs/zero-copy/src/slice_mut.rs b/program-libs/zero-copy/src/slice_mut.rs index 27cd2f776a..7a50b7e44d 100644 --- a/program-libs/zero-copy/src/slice_mut.rs +++ b/program-libs/zero-copy/src/slice_mut.rs @@ -276,3 +276,16 @@ where write!(f, "{:?}", self.as_slice()) } } + +#[cfg(feature = "std")] +impl<'a, T: ZeroCopyTraits + crate::borsh_mut::DeserializeMut<'a>> + crate::borsh_mut::DeserializeMut<'a> for ZeroCopySliceMutBorsh<'_, T> +{ + type Output = ZeroCopySliceMutBorsh<'a, T>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + ZeroCopySliceMutBorsh::from_bytes_at(bytes) + } +} diff --git a/program-libs/zero-copy/tests/borsh.rs b/program-libs/zero-copy/tests/borsh.rs new file mode 100644 index 0000000000..071b4e8df2 --- /dev/null +++ b/program-libs/zero-copy/tests/borsh.rs @@ -0,0 +1,335 @@ +#![cfg(all(feature = "std", feature = "derive", feature = "mut"))] +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{ + borsh::Deserialize, 
borsh_mut::DeserializeMut, ZeroCopy, ZeroCopyEq, ZeroCopyMut, +}; + +#[repr(C)] +#[derive(Debug, PartialEq, ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshDeserialize, BorshSerialize)] +pub struct Struct1Derived { + pub a: u8, + pub b: u16, +} + +#[test] +fn test_struct_1_derived() { + let ref_struct = Struct1Derived { a: 1, b: 2 }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + { + let (struct1, remaining) = Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(struct1, ref_struct); + assert_eq!(remaining, &[]); + } + { + let (mut struct1, _) = Struct1Derived::zero_copy_at_mut(&mut bytes).unwrap(); + struct1.a = 2; + struct1.b = 3.into(); + } + let borsh = Struct1Derived::deserialize(&mut &bytes[..]).unwrap(); + let (struct_1, _) = Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct_1.a, 2); // Modified value from mutable operations + assert_eq!(struct_1.b, 3); // Modified value from mutable operations + assert_eq!(struct_1, borsh); +} + +// Struct2 equivalent: Manual implementation that should match Struct2 +#[repr(C)] +#[derive( + Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct2Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[test] +fn test_struct_2_derived() { + let ref_struct = Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + assert_eq!(struct2, ref_struct); +} + +// Struct3 equivalent: fields should match Struct3 +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct3Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[test] 
+fn test_struct_3_derived() { + let ref_struct = Struct3Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4NestedDerived { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[test] +fn test_struct_4_derived() { + let ref_struct = Struct4Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4NestedDerived { a: 1, b: 2 }; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + // Check vec_2 length is correct + assert_eq!(zero_copy.vec_2.len(), 32); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct5Derived { + pub a: Vec>, +} + +#[test] +fn test_struct_5_derived() { + let ref_struct = Struct5Derived { + a: vec![vec![1u8; 32]; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + 
vec![vec![1u8; 32]; 32] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct6Derived { + pub a: Vec, +} + +#[test] +fn test_struct_6_derived() { + let ref_struct = Struct6Derived { + a: vec![ + Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct Struct7Derived { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[test] +fn test_struct_7_derived() { + let ref_struct = Struct7Derived { + a: 1, + b: 2, + option: Some(3), + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7Derived { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. 
+#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct8Derived { + pub a: Vec, +} + +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct NestedStructDerived { + pub a: u8, + pub b: Struct2Derived, +} + +#[test] +fn test_struct_8_derived() { + let ref_struct = Struct8Derived { + a: vec![ + NestedStructDerived { + a: 1, + b: Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8Derived::zero_copy_at(&bytes).unwrap(); + // Check length of vec matches + assert_eq!(zero_copy.a.len(), 32); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshSerialize, BorshDeserialize, PartialEq, Debug)] +pub struct ArrayStruct { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +#[test] +fn test_array_struct() -> Result<(), Box> { + let array_struct = ArrayStruct { + a: [1u8; 32], + b: [2u8; 64], + c: [3u8; 32], + }; + let bytes = array_struct.try_to_vec()?; + + let (zero_copy, remaining) = ArrayStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, [1u8; 32]); + assert_eq!(zero_copy.b, [2u8; 64]); + assert_eq!(zero_copy.c, [3u8; 32]); + assert_eq!(zero_copy, array_struct); + assert_eq!(remaining, &[]); + Ok(()) +} + +#[derive( + Debug, + PartialEq, + Default, + Clone, + BorshSerialize, + BorshDeserialize, + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +#[test] +fn test_compressed_account_data() { + let compressed_account_data = CompressedAccountData { + discriminator: [1u8; 8], + data: vec![2u8; 32], + data_hash: [3u8; 32], + }; + let bytes = compressed_account_data.try_to_vec().unwrap(); + + let (zero_copy, 
remaining) = CompressedAccountData::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.discriminator, [1u8; 8]); + // assert_eq!(zero_copy.data, compressed_account_data.data.as_slice()); + assert_eq!(*zero_copy.data_hash, [3u8; 32]); + assert_eq!(zero_copy, compressed_account_data); + assert_eq!(remaining, &[]); +} diff --git a/program-libs/zero-copy/tests/borsh_2.rs b/program-libs/zero-copy/tests/borsh_2.rs new file mode 100644 index 0000000000..aece86bb1c --- /dev/null +++ b/program-libs/zero-copy/tests/borsh_2.rs @@ -0,0 +1,559 @@ +#![cfg(all(feature = "std", feature = "derive"))] + +use std::{ops::Deref, vec}; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{ + borsh::Deserialize, errors::ZeroCopyError, slice::ZeroCopySliceBorsh, ZeroCopyStructInner, +}; +use zerocopy::{ + little_endian::{U16, U64}, + FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +// Rules: +// 1. create ZStruct for the struct +// 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct +// 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct +// 1.3. replace u16 with U16, u32 with U32, etc +// 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +// 1.6. Elements in an Option must implement Deserialize +// 1.7. a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + +// Derive Macro needs to derive: +// 1. ZeroCopyStructInner +// 2. Deserialize +// 3. PartialEq for ZStruct<'_> +// +// For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. 
+ +// Tests for manually implemented structures (without derive macro) + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct1 { + pub a: u8, + pub b: u16, +} + +// pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + +// } + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, +} +impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } +} + +#[test] +fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } +} + +impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) 
-> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } +} + +#[test] +fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, +} + +impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((ZStruct3 { meta, vec, c }, bytes)) + } +} + +#[test] +fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 
2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] +pub struct Struct4Nested { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, +)] +pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, +} + +impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] +pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, +} + +impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + ZStruct4 { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } +} + +#[test] +fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, 
remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct5 { + pub a: Vec>, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, +} + +impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = + Vec::::ZeroCopyInner>>::zero_copy_at( + bytes, + )?; + Ok((ZStruct5 { a }, bytes)) + } +} + +#[test] +fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. 
+#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct6 { + pub a: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } +} + +#[test] +fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } +} + +impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + 
Ok((ZStruct7 { meta, option }, bytes)) + } +} + +#[test] +fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct8 { + pub a: Vec, +} + +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct NestedStruct { + pub a: u8, + pub b: Struct2, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, +} + +impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } +} + +impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } +} + +#[test] +fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + 
NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +} diff --git a/program-tests/compressed-token-test/Cargo.toml b/program-tests/compressed-token-test/Cargo.toml index 8f7ba53810..fb296cdf4c 100644 --- a/program-tests/compressed-token-test/Cargo.toml +++ b/program-tests/compressed-token-test/Cargo.toml @@ -39,6 +39,7 @@ light-program-test = { workspace = true, features = ["devenv"] } tokio = { workspace = true } light-prover-client = { workspace = true, features = ["devenv"] } spl-token = { workspace = true } +spl-pod = { workspace = true } anchor-spl = { workspace = true } rand = { workspace = true } serial_test = { workspace = true } diff --git a/program-tests/compressed-token-test/tests/pinocchio.rs b/program-tests/compressed-token-test/tests/pinocchio.rs new file mode 100644 index 0000000000..2af3c0f287 --- /dev/null +++ b/program-tests/compressed-token-test/tests/pinocchio.rs @@ -0,0 +1,1284 @@ +// #![cfg(feature = "test-sbf")] + +use std::assert_eq; + +use anchor_lang::prelude::borsh::BorshSerialize; +use anchor_spl::token_2022::spl_token_2022; +use light_compressed_token::mint_to_compressed::instructions::{ + CompressedMintInput, CompressedMintInputs, MintToCompressedInstructionData, Recipient, +}; + +use anchor_lang::{prelude::AccountMeta, solana_program::program_pack::Pack, system_program}; +use light_client::indexer::Indexer; +use light_program_test::{LightProgramTest, ProgramTestConfig}; +use light_sdk::instruction::ValidityProof; +use light_test_utils::Rpc; +use light_verifier::CompressedProof; +use serial_test::serial; +use solana_sdk::{instruction::Instruction, pubkey::Pubkey, signature::Keypair, signer::Signer}; + 
+struct MultiTransferInput { + payer: Pubkey, + current_owner: Pubkey, + new_recipient: Pubkey, + mint: Pubkey, + input_amount: u64, + transfer_amount: u64, + input_lamports: u64, + transfer_lamports: u64, + change_lamports: u64, + leaf_index: u32, + merkle_tree: Pubkey, + output_queue: Pubkey, +} + +fn create_multi_transfer_instruction(input: &MultiTransferInput) -> Instruction { + // Create input token data + let input_token_data = + light_compressed_token::multi_transfer::instruction_data::MultiInputTokenDataWithContext { + amount: input.input_amount, + merkle_context: light_sdk::instruction::PackedMerkleContext { + merkle_tree_pubkey_index: 0, // Index for merkle tree in remaining accounts + queue_pubkey_index: 1, // Index for output queue in remaining accounts + leaf_index: input.leaf_index, + prove_by_index: true, + }, + root_index: 0, + mint: 2, // Index in remaining accounts + owner: 3, // Index in remaining accounts + with_delegate: false, + delegate: 0, // Unused + }; + + // Create output token data + let output_token_data = + light_compressed_token::multi_transfer::instruction_data::MultiTokenTransferOutputData { + owner: 4, // Index for new recipient in remaining accounts + amount: input.transfer_amount, + merkle_tree: 1, // Index for output queue in remaining accounts + delegate: 0, // No delegate + mint: 2, // Same mint index + }; + + // Create multi-transfer instruction data + let multi_transfer_data = light_compressed_token::multi_transfer::instruction_data::CompressedTokenInstructionDataMultiTransfer { + with_transaction_hash: false, + with_lamports_change_account_merkle_tree_index: false, + lamports_change_account_merkle_tree_index: 0, + lamports_change_account_owner_index: 0, + proof: None, + in_token_data: vec![input_token_data], + out_token_data: vec![output_token_data], + in_lamports: Some(vec![input.input_lamports]), // Include input lamports + out_lamports: Some(vec![input.transfer_lamports]), // Include output lamports + in_tlv: None, + 
out_tlv: None, + compressions: None, + cpi_context: None, + }; + + // Create multi-transfer accounts in the correct order expected by processor + let multi_transfer_accounts = vec![ + // Light system program account (index 0) - skipped in processor + AccountMeta::new_readonly(light_system_program::ID, false), // 0: light_system_program (skipped) + // System accounts for multi-transfer (exact order from processor) + AccountMeta::new(input.payer, true), // 1: fee_payer (signer, mutable) + AccountMeta::new_readonly( + light_compressed_token::process_transfer::get_cpi_authority_pda().0, + false, + ), // 2: authority (CPI authority PDA, signer via CPI) + AccountMeta::new_readonly( + light_system_program::utils::get_registered_program_pda(&light_system_program::ID), + false, + ), // 3: registered_program_pda + AccountMeta::new_readonly( + Pubkey::new_from_array(account_compression::utils::constants::NOOP_PUBKEY), + false, + ), // 4: noop_program + AccountMeta::new_readonly( + light_system_program::utils::get_cpi_authority_pda(&light_system_program::ID), + false, + ), // 5: account_compression_authority + AccountMeta::new_readonly(account_compression::ID, false), // 6: account_compression_program + AccountMeta::new_readonly(light_compressed_token::ID, false), // 7: invoking_program (self_program) + // No sol_pool_pda since we don't have SOL decompression + // No sol_decompression_recipient since we don't have SOL decompression + AccountMeta::new_readonly(system_program::ID, false), // 8: system_program + // No cpi_context_account since we don't use CPI context + // Remaining accounts for token transfer - trees and queues FIRST for CPI + AccountMeta::new(input.merkle_tree, false), // 9: merkle tree (index 0 in remaining) + AccountMeta::new(input.output_queue, false), // 10: output queue (index 1 in remaining) + AccountMeta::new_readonly(input.mint, false), // 11: mint (index 2 in remaining) + AccountMeta::new_readonly(input.current_owner, true), // 12: current owner (index 
3 in remaining) - must be signer + AccountMeta::new_readonly(input.new_recipient, false), // 13: new recipient (index 4 in remaining) + ]; + + Instruction { + program_id: light_compressed_token::ID, + accounts: multi_transfer_accounts, + data: [vec![104], multi_transfer_data.try_to_vec().unwrap()].concat(), // 104 is MultiTransfer discriminator + } +} + +fn derive_ctoken_ata(owner: &Pubkey, mint: &Pubkey) -> (Pubkey, u8) { + Pubkey::find_program_address( + &[ + owner.as_ref(), + light_compressed_token::ID.as_ref(), + mint.as_ref(), + ], + &light_compressed_token::ID, + ) +} + +fn create_ctoken_ata_instruction( + payer: &Pubkey, + owner: &Pubkey, + mint: &Pubkey, +) -> (Instruction, Pubkey) { + let (ctoken_ata_pubkey, bump) = derive_ctoken_ata(owner, mint); + + use light_compressed_account::Pubkey as LightPubkey; + use light_compressed_token::create_associated_token_account::instruction_data::CreateAssociatedTokenAccountInstructionData; + + let instruction_data = CreateAssociatedTokenAccountInstructionData { + owner: LightPubkey::from(owner.to_bytes()), + mint: LightPubkey::from(mint.to_bytes()), + bump, + }; + + let mut instruction_data_bytes = vec![103u8]; + instruction_data_bytes.extend_from_slice(&instruction_data.try_to_vec().unwrap()); + + let accounts = vec![ + AccountMeta::new(*payer, true), + AccountMeta::new(ctoken_ata_pubkey, false), + AccountMeta::new_readonly(*mint, false), + AccountMeta::new_readonly(*owner, false), + AccountMeta::new_readonly(system_program::ID, false), + ]; + + let create_ata_instruction = solana_sdk::instruction::Instruction { + program_id: light_compressed_token::ID, + accounts, + data: instruction_data_bytes, + }; + + (create_ata_instruction, ctoken_ata_pubkey) +} + +fn create_decompress_instruction( + proof: ValidityProof, + compressed_token_account: &[light_client::indexer::TokenAccount], + decompress_amount: u64, + spl_token_account: Pubkey, + payer: Pubkey, + output_queue: Pubkey, +) -> Instruction { + // Process all input 
token accounts + let mut in_token_data = Vec::with_capacity(8); + let mut in_lamports = Vec::with_capacity(8); + let mut total_amount = 0u64; + + // Calculate account indices dynamically + let merkle_tree_index = 0; + let output_queue_index = 1; + let mint_index = 2; + let owner_index = 3; + let spl_token_account_index = 4; + + for account in compressed_token_account { + total_amount += account.token.amount; + + in_token_data.push( + light_compressed_token::multi_transfer::instruction_data::MultiInputTokenDataWithContext { + amount: account.token.amount, + merkle_context: light_sdk::instruction::PackedMerkleContext { + merkle_tree_pubkey_index: merkle_tree_index, + queue_pubkey_index: output_queue_index, + leaf_index: account.account.leaf_index, + prove_by_index: true, + }, + root_index: 0, + mint: mint_index, + owner: owner_index, + with_delegate: false, + delegate: 0, + } + ); + + in_lamports.push(account.account.lamports); + } + + let remaining_amount = total_amount - decompress_amount; + + // Get merkle tree from first account + let merkle_tree = compressed_token_account[0].account.tree_info.tree; + + // Create output token data for remaining compressed tokens (if any) + let mut out_token_data = Vec::new(); + let mut out_lamports = Vec::new(); + + if remaining_amount > 0 { + out_token_data.push( + light_compressed_token::multi_transfer::instruction_data::MultiTokenTransferOutputData { + owner: owner_index, + amount: remaining_amount, + merkle_tree: output_queue_index, + delegate: 0, + mint: mint_index, + } + ); + out_lamports.push(compressed_token_account[0].account.lamports); + } + + // Create compression data for decompression + let compression_data = light_compressed_token::multi_transfer::instruction_data::Compression { + amount: decompress_amount, + is_compress: false, // This is decompression + mint: mint_index, + source_or_recipient: spl_token_account_index, + }; + + let multi_transfer_data = 
light_compressed_token::multi_transfer::instruction_data::CompressedTokenInstructionDataMultiTransfer { + with_transaction_hash: false, + with_lamports_change_account_merkle_tree_index: false, + lamports_change_account_merkle_tree_index: 0, // Index of output queue + lamports_change_account_owner_index: 0, // Index of owner + proof: None, + in_token_data, + out_token_data, + in_lamports: if in_lamports.is_empty() { None } else { Some(in_lamports) }, + out_lamports: if out_lamports.is_empty() { None } else { Some(out_lamports) }, + in_tlv: None, + out_tlv: None, + compressions: Some(vec![compression_data]), + cpi_context: None, + }; + + let multi_transfer_accounts = vec![ + AccountMeta::new_readonly(light_system_program::ID, false), + AccountMeta::new(payer, true), + AccountMeta::new_readonly( + light_compressed_token::process_transfer::get_cpi_authority_pda().0, + false, + ), + AccountMeta::new_readonly( + light_system_program::utils::get_registered_program_pda(&light_system_program::ID), + false, + ), + AccountMeta::new_readonly( + Pubkey::new_from_array(account_compression::utils::constants::NOOP_PUBKEY), + false, + ), + AccountMeta::new_readonly( + light_system_program::utils::get_cpi_authority_pda(&light_system_program::ID), + false, + ), + AccountMeta::new_readonly(account_compression::ID, false), + AccountMeta::new_readonly(light_compressed_token::ID, false), + AccountMeta::new_readonly(system_program::ID, false), + // Tree accounts + AccountMeta::new(merkle_tree, false), // 0: merkle tree + AccountMeta::new(output_queue, false), // 1: output queue + AccountMeta::new_readonly(compressed_token_account[0].token.mint, false), // 2: mint + AccountMeta::new_readonly(compressed_token_account[0].token.owner, true), // 3: current owner (signer) + AccountMeta::new(spl_token_account, false), // 4: SPL token account for decompression + ]; + + Instruction { + program_id: light_compressed_token::ID, + accounts: multi_transfer_accounts, + data: [vec![104], 
multi_transfer_data.try_to_vec().unwrap()].concat(), + } +} + +fn create_compressed_mint( + decimals: u8, + mint_authority: Pubkey, + freeze_authority: Option, + proof: CompressedProof, + mint_bump: u8, + address_merkle_tree_root_index: u16, + mint_signer: Pubkey, + payer: Pubkey, + address_tree_pubkey: Pubkey, + output_queue: Pubkey, +) -> Instruction { + let instruction_data = + light_compressed_token::mint::instructions::CreateCompressedMintInstructionData { + decimals, + mint_authority: mint_authority.into(), + freeze_authority: freeze_authority.map(|auth| auth.into()), + proof, + mint_bump, + address_merkle_tree_root_index, + }; + + let accounts = vec![ + // Static non-CPI accounts first + AccountMeta::new_readonly(mint_signer, true), // 0: mint_signer (signer) + AccountMeta::new_readonly(light_system_program::ID, false), // light system program + // CPI accounts in exact order expected by execute_cpi_invoke + AccountMeta::new(payer, true), // 1: fee_payer (signer, mutable) + AccountMeta::new_readonly( + light_compressed_token::process_transfer::get_cpi_authority_pda().0, + false, + ), // 2: cpi_authority_pda + AccountMeta::new_readonly( + light_system_program::utils::get_registered_program_pda(&light_system_program::ID), + false, + ), // 3: registered_program_pda + AccountMeta::new_readonly( + Pubkey::new_from_array(account_compression::utils::constants::NOOP_PUBKEY), + false, + ), // 4: noop_program + AccountMeta::new_readonly( + light_system_program::utils::get_cpi_authority_pda(&light_system_program::ID), + false, + ), // 5: account_compression_authority + AccountMeta::new_readonly(account_compression::ID, false), // 6: account_compression_program + AccountMeta::new_readonly(light_compressed_token::ID, false), // 7: invoking_program (self_program) + // AccountMeta::new_readonly(light_system_program::ID, false), // 8: sol_pool_pda placeholder + // AccountMeta::new_readonly(light_system_program::ID, false), // 9: decompression_recipient + 
AccountMeta::new_readonly(system_program::ID, false), // 10: system_program + // AccountMeta::new_readonly(light_system_program::ID, false), // 11: cpi_context_account placeholder + AccountMeta::new(address_tree_pubkey, false), // 12: address_merkle_tree (mutable) + AccountMeta::new(output_queue, false), // 13: output_queue (mutable) + ]; + + Instruction { + program_id: light_compressed_token::ID, + accounts, + data: [vec![100], instruction_data.try_to_vec().unwrap()].concat(), + } +} + +#[tokio::test] +#[serial] +async fn test_create_compressed_mint() { + let mut rpc = LightProgramTest::new(ProgramTestConfig::new_v2(false, None)) + .await + .unwrap(); + let payer = rpc.get_payer().insecure_clone(); + + // Test parameters + let decimals = 6u8; + let mint_authority_keypair = Keypair::new(); // Create keypair so we can sign + let mint_authority = mint_authority_keypair.pubkey(); + let freeze_authority = Pubkey::new_unique(); + let mint_signer = Keypair::new(); + + // Get address tree for creating compressed mint address + let address_tree_pubkey = rpc.get_address_merkle_tree_v2(); + let output_queue = rpc.get_random_state_tree_info().unwrap().queue; + let state_merkle_tree = rpc.get_random_state_tree_info().unwrap().tree; + + // Find mint PDA and bump + let (mint_pda, mint_bump) = Pubkey::find_program_address( + &[b"compressed_mint", mint_signer.pubkey().as_ref()], + &light_compressed_token::ID, + ); + + // Use the mint PDA as the seed for the compressed account address + let address_seed = mint_pda.to_bytes(); + + let compressed_mint_address = light_compressed_account::address::derive_address( + &address_seed, + &address_tree_pubkey.to_bytes(), + &light_compressed_token::ID.to_bytes(), + ); + + // Get validity proof for address creation + let rpc_result = rpc + .get_validity_proof( + vec![], + vec![light_program_test::AddressWithTree { + address: compressed_mint_address, + tree: address_tree_pubkey, + }], + None, + ) + .await + .unwrap() + .value; + + let 
address_merkle_tree_root_index = rpc_result.addresses[0].root_index; + + // Create instruction + let instruction = create_compressed_mint( + decimals, + mint_authority, + Some(freeze_authority), + rpc_result.proof.0.unwrap(), + mint_bump, + address_merkle_tree_root_index, + mint_signer.pubkey(), + payer.pubkey(), + address_tree_pubkey, + output_queue, + ); + + // Send transaction + rpc.create_and_send_transaction(&[instruction], &payer.pubkey(), &[&payer, &mint_signer]) + .await + .unwrap(); + + // Verify the compressed mint was created + let compressed_mint_account = rpc + .indexer() + .unwrap() + .get_compressed_account(compressed_mint_address, None) + .await + .unwrap() + .value; + + // Create expected compressed mint for comparison + let expected_compressed_mint = light_compressed_token::create_mint::CompressedMint { + spl_mint: mint_pda, + supply: 0, + decimals, + is_decompressed: false, + mint_authority: Some(mint_authority), + freeze_authority: Some(freeze_authority), + num_extensions: 0, + }; + + // Verify the account exists and has correct properties + assert_eq!( + compressed_mint_account.address.unwrap(), + compressed_mint_address + ); + assert_eq!(compressed_mint_account.owner, light_compressed_token::ID); + assert_eq!(compressed_mint_account.lamports, 0); + + // Verify the compressed mint data + let compressed_account_data = compressed_mint_account.data.unwrap(); + assert_eq!( + compressed_account_data.discriminator, + light_compressed_token::constants::COMPRESSED_MINT_DISCRIMINATOR + ); + + // Deserialize and verify the CompressedMint struct matches expected + let actual_compressed_mint: light_compressed_token::create_mint::CompressedMint = + anchor_lang::AnchorDeserialize::deserialize(&mut compressed_account_data.data.as_slice()) + .unwrap(); + + assert_eq!(actual_compressed_mint, expected_compressed_mint); + + // Test mint_to_compressed functionality + let recipient_keypair = Keypair::new(); + let recipient = recipient_keypair.pubkey(); + let 
mint_amount = 1000u64; + let lamports = Some(10000u64); + + // Get state tree for output token accounts + let state_tree_info = rpc.get_random_state_tree_info().unwrap(); + let state_tree_pubkey = state_tree_info.tree; + let state_output_queue = state_tree_info.queue; + println!("state_tree_pubkey {:?}", state_tree_pubkey); + println!("state_output_queue {:?}", state_output_queue); + + // Prepare compressed mint inputs for minting + let compressed_mint_inputs = CompressedMintInputs { + merkle_context: light_compressed_account::compressed_account::PackedMerkleContext { + merkle_tree_pubkey_index: 0, // Will be set in remaining accounts + queue_pubkey_index: 1, + leaf_index: compressed_mint_account.leaf_index, + prove_by_index: true, + }, + root_index: 0, + address: compressed_mint_address, + compressed_mint_input: CompressedMintInput { + spl_mint: expected_compressed_mint.spl_mint.into(), + supply: expected_compressed_mint.supply, // Current supply + decimals: expected_compressed_mint.decimals, + is_decompressed: expected_compressed_mint.is_decompressed, // Pure compressed mint + freeze_authority_is_set: expected_compressed_mint.freeze_authority.is_some(), + freeze_authority: expected_compressed_mint + .freeze_authority + .unwrap_or_default() + .into(), + num_extensions: 0, + }, + output_merkle_tree_index: 3, + }; + + // Create mint_to_compressed instruction + let mint_to_instruction_data = MintToCompressedInstructionData { + compressed_mint_inputs, + lamports, + recipients: vec![Recipient { + recipient: recipient.into(), + amount: mint_amount, + }], + proof: None, // No proof needed for this test + }; + + // Create accounts in the correct order for manual parsing + let mint_to_accounts = vec![ + // Static non-CPI accounts first + AccountMeta::new_readonly(mint_authority, true), // 0: authority (signer) + // AccountMeta::new(mint_pda, false), // 1: mint (mutable) + // AccountMeta::new(Pubkey::new_unique(), false), // 2: token_pool_pda (mutable) + // 
AccountMeta::new_readonly(spl_token::ID, false), // 3: token_program + AccountMeta::new_readonly(light_system_program::ID, false), // 4: light_system_program + // CPI accounts in exact order expected by InvokeCpiWithReadOnly + AccountMeta::new(payer.pubkey(), true), // 5: fee_payer (signer, mutable) + AccountMeta::new_readonly( + light_compressed_token::process_transfer::get_cpi_authority_pda().0, + false, + ), // 6: cpi_authority_pda + AccountMeta::new_readonly( + light_system_program::utils::get_registered_program_pda(&light_system_program::ID), + false, + ), // 7: registered_program_pda + AccountMeta::new_readonly( + Pubkey::new_from_array(account_compression::utils::constants::NOOP_PUBKEY), + false, + ), // 8: noop_program + AccountMeta::new_readonly( + light_system_program::utils::get_cpi_authority_pda(&light_system_program::ID), + false, + ), // 9: account_compression_authority + AccountMeta::new_readonly(account_compression::ID, false), // 10: account_compression_program + AccountMeta::new_readonly(light_compressed_token::ID, false), // 11: self_program + AccountMeta::new(light_system_program::utils::get_sol_pool_pda(), false), // 12: sol_pool_pda (mutable) + AccountMeta::new_readonly(Pubkey::default(), false), // 13: system_program + AccountMeta::new(state_merkle_tree, false), // 14: mint_merkle_tree (mutable) + AccountMeta::new(output_queue, false), // 15: mint_in_queue (mutable) + AccountMeta::new(output_queue, false), // 16: mint_out_queue (mutable) + AccountMeta::new(output_queue, false), // 17: tokens_out_queue (mutable) + ]; + println!("mint_to_accounts {:?}", mint_to_accounts); + println!("output_queue {:?}", output_queue); + println!("output_queue {:?}", output_queue); + println!( + "light_system_program::utils::get_sol_pool_pda() {:?}", + light_system_program::utils::get_sol_pool_pda() + ); + + let mut mint_instruction = Instruction { + program_id: light_compressed_token::ID, + accounts: mint_to_accounts, + data: [vec![101], 
mint_to_instruction_data.try_to_vec().unwrap()].concat(), + }; + + // Add remaining accounts: compressed mint's address tree, then output state tree + mint_instruction.accounts.extend_from_slice(&[ + AccountMeta::new(state_tree_pubkey, false), // Compressed mint's queue + ]); + + // Execute mint_to_compressed + // Note: We need the mint authority to sign since it's the authority for minting + rpc.create_and_send_transaction( + &[mint_instruction], + &payer.pubkey(), + &[&payer, &mint_authority_keypair], + ) + .await + .unwrap(); + + // Verify minted token account + let token_accounts = rpc + .indexer() + .unwrap() + .get_compressed_token_accounts_by_owner(&recipient, None, None) + .await + .unwrap() + .value + .items; + + assert_eq!( + token_accounts.len(), + 1, + "Should have exactly one token account" + ); + let token_account = &token_accounts[0].token; + assert_eq!( + token_account.mint, mint_pda, + "Token account should have correct mint" + ); + assert_eq!( + token_account.amount, mint_amount, + "Token account should have correct amount" + ); + assert_eq!( + token_account.owner, recipient, + "Token account should have correct owner" + ); + + // Verify updated compressed mint supply + let updated_compressed_mint_account = rpc + .indexer() + .unwrap() + .get_compressed_account(compressed_mint_address, None) + .await + .unwrap() + .value; + + let updated_compressed_mint: light_compressed_token::create_mint::CompressedMint = + anchor_lang::AnchorDeserialize::deserialize( + &mut updated_compressed_mint_account + .data + .unwrap() + .data + .as_slice(), + ) + .unwrap(); + + assert_eq!( + updated_compressed_mint.supply, mint_amount, + "Compressed mint supply should be updated to match minted amount" + ); + + // Test create_spl_mint functionality + println!("Creating SPL mint for the compressed mint..."); + + // Find token pool PDA and bump + let (token_pool_pda, token_pool_bump) = + 
light_compressed_token::instructions::create_token_pool::find_token_pool_pda_with_index( + &mint_pda, 0, + ); + + // Prepare compressed mint inputs for create_spl_mint + let compressed_mint_inputs_for_spl = CompressedMintInputs { + merkle_context: light_compressed_account::compressed_account::PackedMerkleContext { + merkle_tree_pubkey_index: 0, // Will be set in remaining accounts + queue_pubkey_index: 1, + leaf_index: updated_compressed_mint_account.leaf_index, + prove_by_index: true, + }, + root_index: address_merkle_tree_root_index, + address: compressed_mint_address, + compressed_mint_input: CompressedMintInput { + spl_mint: mint_pda.into(), + supply: mint_amount, // Current supply after minting + decimals, + is_decompressed: false, // Not yet decompressed + freeze_authority_is_set: true, + freeze_authority: freeze_authority.into(), + num_extensions: 0, + }, + output_merkle_tree_index: 2, + }; + + // Create create_spl_mint instruction data using the non-anchor pattern + let create_spl_mint_instruction_data = + light_compressed_token::create_spl_mint::instructions::CreateSplMintInstructionData { + mint_bump, + token_pool_bump, + decimals, + mint_authority: mint_authority.into(), + freeze_authority: Some(freeze_authority.into()), + compressed_mint_inputs: compressed_mint_inputs_for_spl, + proof: None, // No proof needed for this test + }; + + // Build accounts manually for non-anchor instruction (following account order from accounts.rs) + let create_spl_mint_accounts = vec![ + // Static non-CPI accounts first + AccountMeta::new_readonly(mint_authority, true), // 0: authority + AccountMeta::new(mint_pda, false), // 1: mint + AccountMeta::new_readonly(mint_signer.pubkey(), false), // 2: mint_signer + AccountMeta::new(token_pool_pda, false), // 3: token_pool_pda + AccountMeta::new_readonly(spl_token_2022::ID, false), // 4: token_program + AccountMeta::new_readonly(light_system_program::ID, false), // 5: light_system_program + // CPI accounts in exact order expected 
by light-system-program + AccountMeta::new(payer.pubkey(), true), // 5: fee_payer + AccountMeta::new_readonly( + light_compressed_token::process_transfer::get_cpi_authority_pda().0, + false, + ), // 6: cpi_authority_pda + AccountMeta::new_readonly( + light_system_program::utils::get_registered_program_pda(&light_system_program::ID), + false, + ), // 7: registered_program_pda + AccountMeta::new_readonly( + Pubkey::new_from_array(account_compression::utils::constants::NOOP_PUBKEY), + false, + ), // 8: noop_program + AccountMeta::new_readonly( + light_system_program::utils::get_cpi_authority_pda(&light_system_program::ID), + false, + ), // 9: account_compression_authority + AccountMeta::new_readonly(account_compression::ID, false), // 10: account_compression_program + AccountMeta::new_readonly(light_compressed_token::ID, false), // 11: self_program + AccountMeta::new_readonly(system_program::ID, false), // 13: system_program + AccountMeta::new(state_merkle_tree, false), // 14: in_merkle_tree + AccountMeta::new(output_queue, false), // 15: in_output_queue + AccountMeta::new(output_queue, false), // 16: out_output_queue + ]; + println!("create_spl_mint_accounts {:?}", create_spl_mint_accounts); + + let mut create_spl_mint_instruction = Instruction { + program_id: light_compressed_token::ID, + accounts: create_spl_mint_accounts, + data: [ + vec![102], + create_spl_mint_instruction_data.try_to_vec().unwrap(), + ] + .concat(), // 102 = CreateSplMint discriminator + }; + + // Add remaining accounts (address tree for compressed mint updates) + create_spl_mint_instruction.accounts.extend_from_slice(&[ + AccountMeta::new(address_tree_pubkey, false), // Address tree for compressed mint + ]); + + // Execute create_spl_mint + rpc.create_and_send_transaction( + &[create_spl_mint_instruction], + &payer.pubkey(), + &[&payer, &mint_authority_keypair], + ) + .await + .unwrap(); + + // Verify SPL mint was created + let mint_account_data = 
rpc.get_account(mint_pda).await.unwrap().unwrap(); + let spl_mint = spl_token_2022::state::Mint::unpack(&mint_account_data.data).unwrap(); + assert_eq!( + spl_mint.decimals, decimals, + "SPL mint should have correct decimals" + ); + assert_eq!( + spl_mint.supply, mint_amount, + "SPL mint should have minted supply" + ); + assert_eq!( + spl_mint.mint_authority.unwrap(), + mint_authority, + "SPL mint should have correct authority" + ); + + // Verify token pool was created and has the supply + let token_pool_account_data = rpc.get_account(token_pool_pda).await.unwrap().unwrap(); + let token_pool = spl_token_2022::state::Account::unpack(&token_pool_account_data.data).unwrap(); + assert_eq!( + token_pool.mint, mint_pda, + "Token pool should have correct mint" + ); + assert_eq!( + token_pool.amount, mint_amount, + "Token pool should have the minted supply" + ); + + // Verify compressed mint is now marked as decompressed + let final_compressed_mint_account = rpc + .indexer() + .unwrap() + .get_compressed_account(compressed_mint_address, None) + .await + .unwrap() + .value; + + let final_compressed_mint: light_compressed_token::create_mint::CompressedMint = + anchor_lang::AnchorDeserialize::deserialize( + &mut final_compressed_mint_account.data.unwrap().data.as_slice(), + ) + .unwrap(); + + assert!( + final_compressed_mint.is_decompressed, + "Compressed mint should now be marked as decompressed" + ); + + // Test decompression functionality + println!("Testing token decompression..."); + + // Create SPL token account for the recipient + let recipient_token_keypair = Keypair::new(); // Create keypair for token account + light_test_utils::spl::create_token_2022_account( + &mut rpc, + &mint_pda, + &recipient_token_keypair, + &payer, + true, // token_22 + ) + .await + .unwrap(); + + // Get the compressed token account for decompression + let compressed_token_accounts = rpc + .indexer() + .unwrap() + .get_compressed_token_accounts_by_owner(&recipient, None, None) + .await + 
.unwrap() + .value + .items; + + assert_eq!( + compressed_token_accounts.len(), + 1, + "Should have one compressed token account" + ); + let _input_compressed_account = compressed_token_accounts[0].clone(); + + // Decompress half of the tokens (500 out of 1000) + let _decompress_amount = mint_amount / 2; + let _output_merkle_tree_pubkey = state_tree_pubkey; + + // Since we need a keypair to sign, and tokens were minted to a pubkey, let's skip decompression test for now + // and just verify the basic create_spl_mint functionality worked + println!("✅ SPL mint creation and token pool setup completed successfully!"); + println!( + "Note: Decompression test skipped - would need token owner keypair to sign transaction" + ); + + // The SPL mint and token pool have been successfully created and verified + println!("✅ create_spl_mint test completed successfully!"); + println!(" - SPL mint created with supply: {}", mint_amount); + println!(" - Token pool created with balance: {}", mint_amount); + println!( + " - Compressed mint marked as decompressed: {}", + final_compressed_mint.is_decompressed + ); + + // Add a simple multi-transfer test: 1 input -> 1 output + println!("🔄 Testing multi-transfer..."); + + let new_recipient_keypair = Keypair::new(); + let new_recipient = new_recipient_keypair.pubkey(); + let transfer_amount = mint_amount; // Transfer all tokens (1000) + + let input_lamports = token_accounts[0].account.lamports; // Get the lamports from the token account + let transfer_lamports = (input_lamports * transfer_amount) / mint_amount; // Proportional lamports transfer + let change_lamports = 0; // No change in lamports since we're transferring proportionally + println!("owner {:?}", recipient); + let multi_transfer_input = MultiTransferInput { + payer: payer.pubkey(), + current_owner: recipient, + new_recipient, + mint: mint_pda, + input_amount: mint_amount, + transfer_amount, + input_lamports, + transfer_lamports, + change_lamports, + leaf_index: 
token_accounts[0].account.leaf_index, + merkle_tree: state_tree_pubkey, + output_queue: state_output_queue, + }; + + let multi_transfer_instruction = create_multi_transfer_instruction(&multi_transfer_input); + println!( + "Multi-transfer instruction: {:?}", + multi_transfer_instruction.accounts + ); + // Execute the multi-transfer instruction + rpc.create_and_send_transaction( + &[multi_transfer_instruction], + &payer.pubkey(), + &[&payer, &recipient_keypair], // Both payer and recipient need to sign + ) + .await + .unwrap(); + + // Verify the transfer was successful + let new_token_accounts = rpc + .indexer() + .unwrap() + .get_compressed_token_accounts_by_owner(&new_recipient, None, None) + .await + .unwrap() + .value + .items; + + assert_eq!( + new_token_accounts.len(), + 1, + "New recipient should have exactly one token account" + ); + assert_eq!( + new_token_accounts[0].token.amount, transfer_amount, + "New recipient should have the transferred amount" + ); + assert_eq!( + new_token_accounts[0].token.mint, mint_pda, + "New recipient token should have correct mint" + ); + + println!("✅ Multi-transfer executed successfully!"); + println!( + " - Transferred {} tokens from {} to {}", + transfer_amount, recipient, new_recipient + ); + + let compressed_token_account = &new_token_accounts[0]; + let decompress_amount = 300u64; + let remaining_amount = transfer_amount - decompress_amount; + + // Get the output queue from the token account's tree info + let output_queue = compressed_token_account.account.tree_info.queue; + + // Create compressed token associated token account for decompression + let (ctoken_ata_pubkey, _bump) = derive_ctoken_ata(&new_recipient, &mint_pda); + let (create_ata_instruction, _) = + create_ctoken_ata_instruction(&payer.pubkey(), &new_recipient, &mint_pda); + rpc.create_and_send_transaction(&[create_ata_instruction], &payer.pubkey(), &[&payer]) + .await + .unwrap(); + + // Get validity proof for the compressed token account + let 
validity_proof = rpc + .get_validity_proof(vec![compressed_token_account.account.hash], vec![], None) + .await + .unwrap() + .value; + + // Create decompression instruction using the wrapper + let decompress_instruction = create_decompress_instruction( + validity_proof.proof, + std::slice::from_ref(compressed_token_account), + decompress_amount, + ctoken_ata_pubkey, + payer.pubkey(), + output_queue, + ); + + println!("🔓 Sending decompression transaction..."); + println!(" - Decompress amount: {}", decompress_amount); + println!(" - Remaining amount: {}", remaining_amount); + println!(" - SPL token account: {}", ctoken_ata_pubkey); + println!(" metas {:?}", decompress_instruction.accounts); + // Send the decompression transaction + let tx_result = rpc + .create_and_send_transaction( + &[decompress_instruction], + &payer.pubkey(), + &[&payer, &new_recipient_keypair], + ) + .await; + + match tx_result { + Ok(_) => { + println!("✅ Decompression transaction sent successfully!"); + + // Verify the decompression worked + let ctoken_account = rpc.get_account(ctoken_ata_pubkey).await.unwrap().unwrap(); + + let token_account = + spl_token_2022::state::Account::unpack(&ctoken_account.data).unwrap(); + println!(" - CToken ATA balance: {}", token_account.amount); + + // Assert that the token account contains the expected decompressed amount + assert_eq!( + token_account.amount, decompress_amount, + "Token account should contain exactly the decompressed amount" + ); + + // Check remaining compressed tokens + let remaining_compressed = rpc + .indexer() + .unwrap() + .get_compressed_token_accounts_by_owner(&new_recipient, None, None) + .await + .unwrap() + .value + .items; + + if !remaining_compressed.is_empty() { + println!( + " - Remaining compressed tokens: {}", + remaining_compressed[0].token.amount + ); + } + } + Err(e) => { + println!("❌ Decompression transaction failed: {:?}", e); + panic!("Decompression transaction failed"); + } + } +} + +/// Creates a `InitializeAccount3` 
instruction. +pub fn initialize_account3( + token_program_id: &Pubkey, + account_pubkey: &Pubkey, + mint_pubkey: &Pubkey, + owner_pubkey: &Pubkey, +) -> Result { + let data = spl_token_2022::instruction::TokenInstruction::InitializeAccount3 { + owner: *owner_pubkey, + } + .pack(); + + let accounts = vec![ + AccountMeta::new(*account_pubkey, false), + AccountMeta::new_readonly(*mint_pubkey, false), + ]; + + Ok(solana_sdk::instruction::Instruction { + program_id: *token_program_id, + accounts, + data, + }) +} + +/// Creates a `CloseAccount` instruction. +pub fn close_account( + token_program_id: &Pubkey, + account_pubkey: &Pubkey, + destination_pubkey: &Pubkey, + owner_pubkey: &Pubkey, +) -> Result { + let data = spl_token_2022::instruction::TokenInstruction::CloseAccount.pack(); + + let accounts = vec![ + AccountMeta::new(*account_pubkey, false), + AccountMeta::new(*destination_pubkey, false), + AccountMeta::new_readonly(*owner_pubkey, true), // signer + ]; + + Ok(solana_sdk::instruction::Instruction { + program_id: *token_program_id, + accounts, + data, + }) +} + +#[tokio::test] +async fn test_create_and_close_token_account() { + use spl_pod::bytemuck::pod_from_bytes; + use spl_token_2022::pod::PodAccount; + use spl_token_2022::state::AccountState; + + let mut rpc = LightProgramTest::new(ProgramTestConfig::new_v2(false, None)) + .await + .unwrap(); + let payer = rpc.get_payer().insecure_clone(); + let payer_pubkey = payer.pubkey(); + + // Create a mock mint pubkey (we don't need actual mint for this test) + let mint_pubkey = Pubkey::new_unique(); + + // Create owner for the token account + let owner_keypair = Keypair::new(); + let owner_pubkey = owner_keypair.pubkey(); + + // Create a new keypair for the token account + let token_account_keypair = Keypair::new(); + let token_account_pubkey = token_account_keypair.pubkey(); + + // First create the account using system program + let create_account_system_ix = solana_sdk::system_instruction::create_account( + 
&payer_pubkey, + &token_account_pubkey, + rpc.get_minimum_balance_for_rent_exemption(165) + .await + .unwrap(), // SPL token account size + 165, + &light_compressed_token::ID, // Our program owns the account + ); + + // Then use SPL token SDK format but with our compressed token program ID + // This tests that our create_token_account instruction is compatible with SPL SDKs + let initialize_account_ix = initialize_account3( + &light_compressed_token::ID, // Use our program ID instead of spl_token_2022::ID + &token_account_pubkey, + &mint_pubkey, + &owner_pubkey, + ) + .unwrap(); + + // Execute both instructions in one transaction + let (blockhash, _) = rpc.get_latest_blockhash().await.unwrap(); + let transaction = solana_sdk::transaction::Transaction::new_signed_with_payer( + &[create_account_system_ix, initialize_account_ix], + Some(&payer_pubkey), + &[&payer, &token_account_keypair], + blockhash, + ); + + rpc.process_transaction(transaction.clone()) + .await + .expect("Failed to create token account using SPL SDK"); + + // Verify the token account was created correctly + let account_info = rpc + .get_account(token_account_pubkey) + .await + .unwrap() + .unwrap(); + + // Verify account exists and has correct owner + assert_eq!(account_info.owner, light_compressed_token::ID); + assert_eq!(account_info.data.len(), 165); // SPL token account size + + let pod_account = pod_from_bytes::(&account_info.data) + .expect("Failed to parse token account data"); + + // Verify the token account fields + assert_eq!(Pubkey::from(pod_account.mint), mint_pubkey); + assert_eq!(Pubkey::from(pod_account.owner), owner_pubkey); + assert_eq!(u64::from(pod_account.amount), 0); // Should start with zero balance + assert_eq!(pod_account.state, AccountState::Initialized as u8); + + // Now test closing the account using SPL SDK format + let destination_keypair = Keypair::new(); + let destination_pubkey = destination_keypair.pubkey(); + + // Airdrop some lamports to destination account so it 
exists + rpc.context.airdrop(&destination_pubkey, 1_000_000).unwrap(); + + // Get initial lamports before closing + let initial_token_account_lamports = rpc + .get_account(token_account_pubkey) + .await + .unwrap() + .unwrap() + .lamports; + let initial_destination_lamports = rpc + .get_account(destination_pubkey) + .await + .unwrap() + .unwrap() + .lamports; + + // Create close account instruction using SPL SDK format + let close_account_ix = close_account( + &light_compressed_token::ID, + &token_account_pubkey, + &destination_pubkey, + &owner_pubkey, + ) + .unwrap(); + + // Execute the close instruction + let (blockhash, _) = rpc.get_latest_blockhash().await.unwrap(); + let close_transaction = solana_sdk::transaction::Transaction::new_signed_with_payer( + &[close_account_ix], + Some(&payer_pubkey), + &[&payer, &owner_keypair], // Need owner to sign + blockhash, + ); + + rpc.process_transaction(close_transaction) + .await + .expect("Failed to close token account using SPL SDK"); + + // Verify the account was closed (data should be cleared, lamports should be 0) + let closed_account = rpc.get_account(token_account_pubkey).await.unwrap(); + if let Some(account) = closed_account { + // Account still exists, but should have 0 lamports and cleared data + assert_eq!(account.lamports, 0, "Closed account should have 0 lamports"); + assert!( + account.data.iter().all(|&b| b == 0), + "Closed account data should be cleared" + ); + } + + // Verify lamports were transferred to destination + let final_destination_lamports = rpc + .get_account(destination_pubkey) + .await + .unwrap() + .unwrap() + .lamports; + assert_eq!( + final_destination_lamports, + initial_destination_lamports + initial_token_account_lamports, + "Destination should receive all lamports from closed account" + ); +} + +#[tokio::test] +async fn test_create_associated_token_account() { + use spl_pod::bytemuck::pod_from_bytes; + use spl_token_2022::pod::PodAccount; + use spl_token_2022::state::AccountState; + + 
let mut rpc = LightProgramTest::new(ProgramTestConfig::new_v2(false, None)) + .await + .unwrap(); + let payer = rpc.get_payer().insecure_clone(); + let payer_pubkey = payer.pubkey(); + + // Create a mock mint pubkey + let mint_pubkey = Pubkey::new_unique(); + + // Create owner for the associated token account + let owner_keypair = Keypair::new(); + let owner_pubkey = owner_keypair.pubkey(); + + // Calculate the expected associated token account address + let (expected_ata_pubkey, bump) = Pubkey::find_program_address( + &[ + owner_pubkey.as_ref(), + light_compressed_token::ID.as_ref(), + mint_pubkey.as_ref(), + ], + &light_compressed_token::ID, + ); + + // Build the create_associated_token_account instruction + use light_compressed_account::Pubkey as LightPubkey; + use light_compressed_token::create_associated_token_account::instruction_data::CreateAssociatedTokenAccountInstructionData; + + let instruction_data = CreateAssociatedTokenAccountInstructionData { + owner: LightPubkey::from(owner_pubkey.to_bytes()), + mint: LightPubkey::from(mint_pubkey.to_bytes()), + bump, + }; + + let mut instruction_data_bytes = vec![103u8]; // CreateAssociatedTokenAccount discriminator + instruction_data_bytes.extend_from_slice(&instruction_data.try_to_vec().unwrap()); + + // Create the accounts for the instruction + let accounts = vec![ + AccountMeta::new(payer_pubkey, true), // fee_payer (signer) + AccountMeta::new(expected_ata_pubkey, false), // associated_token_account + AccountMeta::new_readonly(mint_pubkey, false), // mint + AccountMeta::new_readonly(owner_pubkey, false), // owner + AccountMeta::new_readonly(system_program::ID, false), // system_program + ]; + + let instruction = solana_sdk::instruction::Instruction { + program_id: light_compressed_token::ID, + accounts, + data: instruction_data_bytes, + }; + + // Execute the instruction + let (blockhash, _) = rpc.get_latest_blockhash().await.unwrap(); + let transaction = 
solana_sdk::transaction::Transaction::new_signed_with_payer( + &[instruction], + Some(&payer_pubkey), + &[&payer], + blockhash, + ); + + rpc.process_transaction(transaction.clone()) + .await + .expect("Failed to create associated token account"); + + // Verify the associated token account was created correctly + let account_info = rpc.get_account(expected_ata_pubkey).await.unwrap().unwrap(); + + // Verify account exists and has correct owner + assert_eq!(account_info.owner, light_compressed_token::ID); + assert_eq!(account_info.data.len(), 165); // SPL token account size + + let pod_account = pod_from_bytes::(&account_info.data) + .expect("Failed to parse token account data"); + + // Verify the token account fields + assert_eq!(Pubkey::from(pod_account.mint), mint_pubkey); + assert_eq!(Pubkey::from(pod_account.owner), owner_pubkey); + assert_eq!(u64::from(pod_account.amount), 0); // Should start with zero balance + assert_eq!(pod_account.state, AccountState::Initialized as u8); + + // Verify the PDA derivation is correct + let (derived_ata_pubkey, derived_bump) = Pubkey::find_program_address( + &[ + owner_pubkey.as_ref(), + light_compressed_token::ID.as_ref(), + mint_pubkey.as_ref(), + ], + &light_compressed_token::ID, + ); + assert_eq!(expected_ata_pubkey, derived_ata_pubkey); + assert_eq!(bump, derived_bump); +} diff --git a/program-tests/compressed-token-test/tests/test.rs b/program-tests/compressed-token-test/tests/test.rs index c8e165d856..4bb645205d 100644 --- a/program-tests/compressed-token-test/tests/test.rs +++ b/program-tests/compressed-token-test/tests/test.rs @@ -1,11 +1,16 @@ -#![cfg(feature = "test-sbf")] +// #![cfg(feature = "test-sbf")] use std::{assert_eq, str::FromStr}; +use anchor_lang::prelude::borsh::BorshSerialize; +use light_compressed_token::mint_to_compressed::instructions::{ + CompressedMintInput, CompressedMintInputs, MintToCompressedInstructionData, Recipient, +}; + use account_compression::errors::AccountCompressionErrorCode; use 
anchor_lang::{ - prelude::AccountMeta, system_program, AccountDeserialize, AnchorDeserialize, AnchorSerialize, - InstructionData, ToAccountMetas, + prelude::AccountMeta, solana_program::program_pack::Pack, system_program, AccountDeserialize, + AnchorDeserialize, InstructionData, ToAccountMetas, }; use anchor_spl::{ token::{Mint, TokenAccount}, diff --git a/programs/compressed-token/README.md b/programs/compressed-token/README.md index 764e509cdc..227bd71394 100644 --- a/programs/compressed-token/README.md +++ b/programs/compressed-token/README.md @@ -1,13 +1,2 @@ # Compressed Token Program - -A token program on the Solana blockchain using ZK Compression. - -This program provides an interface and implementation that third parties can utilize to create and use compressed tokens on Solana. - -Documentation is available at https://zkcompression.com - -Source code: https://github.com/Lightprotocol/light-protocol/tree/main/programs/compressed-token - -## Audit - -This code is unaudited. Use at your own risk. 
+- program wraps the anchor program and new optimized instructions diff --git a/programs/compressed-token/Cargo.toml b/programs/compressed-token/anchor/Cargo.toml similarity index 95% rename from programs/compressed-token/Cargo.toml rename to programs/compressed-token/anchor/Cargo.toml index 4c1604dcdf..c6c51bb089 100644 --- a/programs/compressed-token/Cargo.toml +++ b/programs/compressed-token/anchor/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "light-compressed-token" +name = "anchor-compressed-token" version = "2.0.0" description = "Generalized token compression on Solana" repository = "https://github.com/Lightprotocol/light-protocol" @@ -8,7 +8,7 @@ edition = "2021" [lib] crate-type = ["cdylib", "lib"] -name = "light_compressed_token" +name = "anchor_compressed_token" [features] no-entrypoint = [] diff --git a/programs/compressed-token/anchor/README.md b/programs/compressed-token/anchor/README.md new file mode 100644 index 0000000000..764e509cdc --- /dev/null +++ b/programs/compressed-token/anchor/README.md @@ -0,0 +1,13 @@ +# Compressed Token Program + +A token program on the Solana blockchain using ZK Compression. + +This program provides an interface and implementation that third parties can utilize to create and use compressed tokens on Solana. + +Documentation is available at https://zkcompression.com + +Source code: https://github.com/Lightprotocol/light-protocol/tree/main/programs/compressed-token + +## Audit + +This code is unaudited. Use at your own risk. 
diff --git a/programs/compressed-token/Xargo.toml b/programs/compressed-token/anchor/Xargo.toml similarity index 100% rename from programs/compressed-token/Xargo.toml rename to programs/compressed-token/anchor/Xargo.toml diff --git a/programs/compressed-token/src/batch_compress.rs b/programs/compressed-token/anchor/src/batch_compress.rs similarity index 100% rename from programs/compressed-token/src/batch_compress.rs rename to programs/compressed-token/anchor/src/batch_compress.rs diff --git a/programs/compressed-token/src/burn.rs b/programs/compressed-token/anchor/src/burn.rs similarity index 100% rename from programs/compressed-token/src/burn.rs rename to programs/compressed-token/anchor/src/burn.rs diff --git a/programs/compressed-token/src/constants.rs b/programs/compressed-token/anchor/src/constants.rs similarity index 72% rename from programs/compressed-token/src/constants.rs rename to programs/compressed-token/anchor/src/constants.rs index 67b9ab70f8..8043ec4d55 100644 --- a/programs/compressed-token/src/constants.rs +++ b/programs/compressed-token/anchor/src/constants.rs @@ -1,3 +1,5 @@ +// 1 in little endian (for compressed mint accounts) +pub const COMPRESSED_MINT_DISCRIMINATOR: [u8; 8] = [1, 0, 0, 0, 0, 0, 0, 0]; // 2 in little endian pub const TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR: [u8; 8] = [2, 0, 0, 0, 0, 0, 0, 0]; pub const BUMP_CPI_AUTHORITY: u8 = 254; diff --git a/programs/compressed-token/anchor/src/create_mint.rs b/programs/compressed-token/anchor/src/create_mint.rs new file mode 100644 index 0000000000..1e0675a6d3 --- /dev/null +++ b/programs/compressed-token/anchor/src/create_mint.rs @@ -0,0 +1,422 @@ +use anchor_lang::{ + prelude::{borsh, Pubkey}, + AnchorDeserialize, AnchorSerialize, +}; +use light_compressed_account::hash_to_bn254_field_size_be; +use light_hasher::{errors::HasherError, Hasher, Poseidon}; + +// TODO: add is native_compressed, this means that the compressed mint is always synced with the spl mint +// compressed mint accounts 
which are not native_compressed can be not in sync the spl mint account is the source of truth +// Order is optimized for hashing. +// freeze_authority option is skipped if None. +#[derive(Debug, PartialEq, Eq, AnchorSerialize, AnchorDeserialize, Clone)] +pub struct CompressedMint { + /// Pda with seed address of compressed mint + pub spl_mint: Pubkey, + /// Total supply of tokens. + pub supply: u64, + /// Number of base 10 digits to the right of the decimal place. + pub decimals: u8, + /// Extension, necessary for mint to. + pub is_decompressed: bool, + /// Optional authority used to mint new tokens. The mint authority may only + /// be provided during mint creation. If no mint authority is present + /// then the mint has a fixed supply and no further tokens may be + /// minted. + pub mint_authority: Option, + /// Optional authority to freeze token accounts. + pub freeze_authority: Option, + // Not necessary. + // /// Is `true` if this structure has been initialized + // pub is_initialized: bool, + pub num_extensions: u8, // TODO: check again how token22 does it +} + +impl CompressedMint { + pub fn hash(&self) -> std::result::Result<[u8; 32], HasherError> { + let hashed_spl_mint = hash_to_bn254_field_size_be(self.spl_mint.to_bytes().as_slice()); + let mut supply_bytes = [0u8; 32]; + supply_bytes[24..].copy_from_slice(self.supply.to_be_bytes().as_slice()); + + let hashed_mint_authority; + let hashed_mint_authority_option = if let Some(mint_authority) = self.mint_authority { + hashed_mint_authority = + hash_to_bn254_field_size_be(mint_authority.to_bytes().as_slice()); + Some(&hashed_mint_authority) + } else { + None + }; + + let hashed_freeze_authority; + let hashed_freeze_authority_option = if let Some(freeze_authority) = self.freeze_authority { + hashed_freeze_authority = + hash_to_bn254_field_size_be(freeze_authority.to_bytes().as_slice()); + Some(&hashed_freeze_authority) + } else { + None + }; + + Self::hash_with_hashed_values( + &hashed_spl_mint, + 
&supply_bytes, + self.decimals, + self.is_decompressed, + &hashed_mint_authority_option, + &hashed_freeze_authority_option, + self.num_extensions, + ) + } + + pub fn hash_with_hashed_values( + hashed_spl_mint: &[u8; 32], + supply_bytes: &[u8; 32], + decimals: u8, + is_decompressed: bool, + hashed_mint_authority: &Option<&[u8; 32]>, + hashed_freeze_authority: &Option<&[u8; 32]>, + num_extensions: u8, + ) -> std::result::Result<[u8; 32], HasherError> { + let mut hash_inputs = vec![hashed_spl_mint.as_slice(), supply_bytes.as_slice()]; + + // Add decimals with prefix if not 0 + let mut decimals_bytes = [0u8; 32]; + if decimals != 0 { + decimals_bytes[30] = 1; // decimals prefix + decimals_bytes[31] = decimals; + hash_inputs.push(&decimals_bytes[..]); + } + + // Add is_decompressed with prefix if true + let mut is_decompressed_bytes = [0u8; 32]; + if is_decompressed { + is_decompressed_bytes[30] = 2; // is_decompressed prefix + is_decompressed_bytes[31] = 1; // true as 1 + hash_inputs.push(&is_decompressed_bytes[..]); + } + + // Add mint authority if present + if let Some(hashed_mint_authority) = hashed_mint_authority { + hash_inputs.push(hashed_mint_authority.as_slice()); + } + + // Add freeze authority if present + let empty_authority = [0u8; 32]; + if let Some(hashed_freeze_authority) = hashed_freeze_authority { + // If there is freeze authority but no mint authority, add empty mint authority + if hashed_mint_authority.is_none() { + hash_inputs.push(&empty_authority[..]); + } + hash_inputs.push(hashed_freeze_authority.as_slice()); + } + + // Add num_extensions with prefix if not 0 + let mut num_extensions_bytes = [0u8; 32]; + if num_extensions != 0 { + num_extensions_bytes[30] = 3; // num_extensions prefix + num_extensions_bytes[31] = num_extensions; + hash_inputs.push(&num_extensions_bytes[..]); + } + + Poseidon::hashv(hash_inputs.as_slice()) + } +} + +#[cfg(test)] +pub mod test { + use rand::Rng; + + use super::*; + + #[test] + fn 
test_equivalency_of_hash_functions() { + let compressed_mint = CompressedMint { + spl_mint: Pubkey::new_unique(), + supply: 1000000, + decimals: 6, + is_decompressed: false, + mint_authority: Some(Pubkey::new_unique()), + freeze_authority: Some(Pubkey::new_unique()), + num_extensions: 2, + }; + + let hash_result = compressed_mint.hash().unwrap(); + + // Test with hashed values + let hashed_spl_mint = + hash_to_bn254_field_size_be(compressed_mint.spl_mint.to_bytes().as_slice()); + let mut supply_bytes = [0u8; 32]; + supply_bytes[24..].copy_from_slice(compressed_mint.supply.to_be_bytes().as_slice()); + + let hashed_mint_authority = hash_to_bn254_field_size_be( + compressed_mint + .mint_authority + .unwrap() + .to_bytes() + .as_slice(), + ); + let hashed_freeze_authority = hash_to_bn254_field_size_be( + compressed_mint + .freeze_authority + .unwrap() + .to_bytes() + .as_slice(), + ); + + let hash_with_hashed_values = CompressedMint::hash_with_hashed_values( + &hashed_spl_mint, + &supply_bytes, + compressed_mint.decimals, + compressed_mint.is_decompressed, + &Some(&hashed_mint_authority), + &Some(&hashed_freeze_authority), + compressed_mint.num_extensions, + ) + .unwrap(); + + assert_eq!(hash_result, hash_with_hashed_values); + } + + #[test] + fn test_equivalency_without_optional_fields() { + let compressed_mint = CompressedMint { + spl_mint: Pubkey::new_unique(), + supply: 500000, + decimals: 0, + is_decompressed: false, + mint_authority: None, + freeze_authority: None, + num_extensions: 0, + }; + + let hash_result = compressed_mint.hash().unwrap(); + + let hashed_spl_mint = + hash_to_bn254_field_size_be(compressed_mint.spl_mint.to_bytes().as_slice()); + let mut supply_bytes = [0u8; 32]; + supply_bytes[24..].copy_from_slice(compressed_mint.supply.to_be_bytes().as_slice()); + + let hash_with_hashed_values = CompressedMint::hash_with_hashed_values( + &hashed_spl_mint, + &supply_bytes, + compressed_mint.decimals, + compressed_mint.is_decompressed, + &None, + &None, + 
compressed_mint.num_extensions, + ) + .unwrap(); + + assert_eq!(hash_result, hash_with_hashed_values); + } + + fn equivalency_of_hash_functions_rnd_iters() { + let mut rng = rand::thread_rng(); + + for _ in 0..ITERS { + let compressed_mint = CompressedMint { + spl_mint: Pubkey::new_unique(), + supply: rng.gen(), + decimals: rng.gen_range(0..=18), + is_decompressed: rng.gen_bool(0.5), + mint_authority: if rng.gen_bool(0.5) { + Some(Pubkey::new_unique()) + } else { + None + }, + freeze_authority: if rng.gen_bool(0.5) { + Some(Pubkey::new_unique()) + } else { + None + }, + num_extensions: rng.gen_range(0..=10), + }; + + let hash_result = compressed_mint.hash().unwrap(); + + let hashed_spl_mint = + hash_to_bn254_field_size_be(compressed_mint.spl_mint.to_bytes().as_slice()); + let mut supply_bytes = [0u8; 32]; + supply_bytes[24..].copy_from_slice(compressed_mint.supply.to_be_bytes().as_slice()); + + let hashed_mint_authority; + let hashed_mint_authority_option = + if let Some(mint_authority) = compressed_mint.mint_authority { + hashed_mint_authority = + hash_to_bn254_field_size_be(mint_authority.to_bytes().as_slice()); + Some(&hashed_mint_authority) + } else { + None + }; + + let hashed_freeze_authority; + let hashed_freeze_authority_option = + if let Some(freeze_authority) = compressed_mint.freeze_authority { + hashed_freeze_authority = + hash_to_bn254_field_size_be(freeze_authority.to_bytes().as_slice()); + Some(&hashed_freeze_authority) + } else { + None + }; + + let hash_with_hashed_values = CompressedMint::hash_with_hashed_values( + &hashed_spl_mint, + &supply_bytes, + compressed_mint.decimals, + compressed_mint.is_decompressed, + &hashed_mint_authority_option, + &hashed_freeze_authority_option, + compressed_mint.num_extensions, + ) + .unwrap(); + + assert_eq!(hash_result, hash_with_hashed_values); + } + } + + #[test] + fn test_equivalency_random_iterations() { + equivalency_of_hash_functions_rnd_iters::<1000>(); + } + + #[test] + fn test_hash_collision_detection() 
{ + let mut vec_previous_hashes = Vec::new(); + + // Base compressed mint + let base_mint = CompressedMint { + spl_mint: Pubkey::new_unique(), + supply: 1000000, + decimals: 6, + is_decompressed: false, + mint_authority: None, + freeze_authority: None, + num_extensions: 0, + }; + + let base_hash = base_mint.hash().unwrap(); + vec_previous_hashes.push(base_hash); + + // Different spl_mint + let mut mint1 = base_mint.clone(); + mint1.spl_mint = Pubkey::new_unique(); + let hash1 = mint1.hash().unwrap(); + assert_to_previous_hashes(hash1, &mut vec_previous_hashes); + + // Different supply + let mut mint2 = base_mint.clone(); + mint2.supply = 2000000; + let hash2 = mint2.hash().unwrap(); + assert_to_previous_hashes(hash2, &mut vec_previous_hashes); + + // Different decimals + let mut mint3 = base_mint.clone(); + mint3.decimals = 9; + let hash3 = mint3.hash().unwrap(); + assert_to_previous_hashes(hash3, &mut vec_previous_hashes); + + // Different is_decompressed + let mut mint4 = base_mint.clone(); + mint4.is_decompressed = true; + let hash4 = mint4.hash().unwrap(); + assert_to_previous_hashes(hash4, &mut vec_previous_hashes); + + // Different mint_authority + let mut mint5 = base_mint.clone(); + mint5.mint_authority = Some(Pubkey::new_unique()); + let hash5 = mint5.hash().unwrap(); + assert_to_previous_hashes(hash5, &mut vec_previous_hashes); + + // Different freeze_authority + let mut mint6 = base_mint.clone(); + mint6.freeze_authority = Some(Pubkey::new_unique()); + let hash6 = mint6.hash().unwrap(); + assert_to_previous_hashes(hash6, &mut vec_previous_hashes); + + // Different num_extensions + let mut mint7 = base_mint.clone(); + mint7.num_extensions = 5; + let hash7 = mint7.hash().unwrap(); + assert_to_previous_hashes(hash7, &mut vec_previous_hashes); + + // Multiple fields different + let mut mint8 = base_mint.clone(); + mint8.decimals = 18; + mint8.is_decompressed = true; + mint8.mint_authority = Some(Pubkey::new_unique()); + mint8.freeze_authority = 
Some(Pubkey::new_unique()); + mint8.num_extensions = 3; + let hash8 = mint8.hash().unwrap(); + assert_to_previous_hashes(hash8, &mut vec_previous_hashes); + } + + #[test] + fn test_authority_hash_collision_prevention() { + // This is a critical security test: ensuring that different authority combinations + // with the same pubkey don't produce the same hash + let same_pubkey = Pubkey::new_unique(); + + let base_mint = CompressedMint { + spl_mint: Pubkey::new_unique(), + supply: 1000000, + decimals: 6, + is_decompressed: false, + mint_authority: None, + freeze_authority: None, + num_extensions: 0, + }; + + // Case 1: None mint_authority, Some freeze_authority + let mut mint1 = base_mint.clone(); + mint1.mint_authority = None; + mint1.freeze_authority = Some(same_pubkey); + let hash1 = mint1.hash().unwrap(); + + // Case 2: Some mint_authority, None freeze_authority (using same pubkey) + let mut mint2 = base_mint.clone(); + mint2.mint_authority = Some(same_pubkey); + mint2.freeze_authority = None; + let hash2 = mint2.hash().unwrap(); + + // These must be different hashes to prevent authority confusion + assert_ne!( + hash1, hash2, + "CRITICAL: Hash collision between different authority configurations!" + ); + + // Case 3: Both authorities present (should also be different) + let mut mint3 = base_mint.clone(); + mint3.mint_authority = Some(same_pubkey); + mint3.freeze_authority = Some(same_pubkey); + let hash3 = mint3.hash().unwrap(); + + assert_ne!( + hash1, hash3, + "Hash collision between freeze-only and both authorities!" + ); + assert_ne!( + hash2, hash3, + "Hash collision between mint-only and both authorities!" + ); + + // Test with different pubkeys for good measure + let different_pubkey = Pubkey::new_unique(); + let mut mint4 = base_mint.clone(); + mint4.mint_authority = Some(same_pubkey); + mint4.freeze_authority = Some(different_pubkey); + let hash4 = mint4.hash().unwrap(); + + assert_ne!( + hash1, hash4, + "Hash collision with different freeze authority!" 
+ ); + assert_ne!(hash2, hash4, "Hash collision with different authorities!"); + assert_ne!(hash3, hash4, "Hash collision with mixed authorities!"); + } + + fn assert_to_previous_hashes(hash: [u8; 32], previous_hashes: &mut Vec<[u8; 32]>) { + for previous_hash in previous_hashes.iter() { + assert_ne!(hash, *previous_hash, "Hash collision detected!"); + } + previous_hashes.push(hash); + } +} diff --git a/programs/compressed-token/src/delegation.rs b/programs/compressed-token/anchor/src/delegation.rs similarity index 100% rename from programs/compressed-token/src/delegation.rs rename to programs/compressed-token/anchor/src/delegation.rs diff --git a/programs/compressed-token/src/freeze.rs b/programs/compressed-token/anchor/src/freeze.rs similarity index 100% rename from programs/compressed-token/src/freeze.rs rename to programs/compressed-token/anchor/src/freeze.rs diff --git a/programs/compressed-token/src/instructions/burn.rs b/programs/compressed-token/anchor/src/instructions/burn.rs similarity index 100% rename from programs/compressed-token/src/instructions/burn.rs rename to programs/compressed-token/anchor/src/instructions/burn.rs diff --git a/programs/compressed-token/anchor/src/instructions/create_compressed_mint.rs b/programs/compressed-token/anchor/src/instructions/create_compressed_mint.rs new file mode 100644 index 0000000000..582ac1905c --- /dev/null +++ b/programs/compressed-token/anchor/src/instructions/create_compressed_mint.rs @@ -0,0 +1,48 @@ +use account_compression::program::AccountCompression; +use anchor_lang::prelude::*; +use light_system_program::program::LightSystemProgram; + +use crate::program::LightCompressedToken; + +/// Creates a compressed mint stored as a compressed account +#[derive(Accounts)] +pub struct CreateCompressedMintInstruction<'info> { + #[account(mut)] + pub fee_payer: Signer<'info>, + + /// CPI authority for compressed account creation + pub cpi_authority_pda: AccountInfo<'info>, + + /// Light system program for compressed 
account creation + pub light_system_program: Program<'info, LightSystemProgram>, + + /// Account compression program + pub account_compression_program: Program<'info, AccountCompression>, + + /// Registered program PDA for light system program + pub registered_program_pda: AccountInfo<'info>, + + /// NoOp program for event emission + pub noop_program: AccountInfo<'info>, + + /// Authority for account compression + pub account_compression_authority: AccountInfo<'info>, + + /// Self program reference + pub self_program: Program<'info, LightCompressedToken>, + + pub system_program: Program<'info, System>, + + /// Address merkle tree for compressed account creation + /// CHECK: Validated by light-system-program + #[account(mut)] + pub address_merkle_tree: AccountInfo<'info>, + + /// Output queue account where compressed mint will be stored + /// CHECK: Validated by light-system-program + #[account(mut)] + pub output_queue: AccountInfo<'info>, + + /// Signer used as seed for PDA derivation (ensures uniqueness) + pub mint_signer: Signer<'info>, +} diff --git a/programs/compressed-token/src/instructions/create_token_pool.rs b/programs/compressed-token/anchor/src/instructions/create_token_pool.rs similarity index 100% rename from programs/compressed-token/src/instructions/create_token_pool.rs rename to programs/compressed-token/anchor/src/instructions/create_token_pool.rs diff --git a/programs/compressed-token/src/instructions/freeze.rs b/programs/compressed-token/anchor/src/instructions/freeze.rs similarity index 100% rename from programs/compressed-token/src/instructions/freeze.rs rename to programs/compressed-token/anchor/src/instructions/freeze.rs diff --git a/programs/compressed-token/src/instructions/generic.rs b/programs/compressed-token/anchor/src/instructions/generic.rs similarity index 100% rename from programs/compressed-token/src/instructions/generic.rs rename to programs/compressed-token/anchor/src/instructions/generic.rs diff --git 
a/programs/compressed-token/src/instructions/mod.rs b/programs/compressed-token/anchor/src/instructions/mod.rs similarity index 74% rename from programs/compressed-token/src/instructions/mod.rs rename to programs/compressed-token/anchor/src/instructions/mod.rs index c934aac35a..b27b424afa 100644 --- a/programs/compressed-token/src/instructions/mod.rs +++ b/programs/compressed-token/anchor/src/instructions/mod.rs @@ -1,10 +1,12 @@ pub mod burn; +pub mod create_compressed_mint; pub mod create_token_pool; pub mod freeze; pub mod generic; pub mod transfer; pub use burn::*; +pub use create_compressed_mint::*; pub use create_token_pool::*; pub use freeze::*; pub use generic::*; diff --git a/programs/compressed-token/src/instructions/transfer.rs b/programs/compressed-token/anchor/src/instructions/transfer.rs similarity index 100% rename from programs/compressed-token/src/instructions/transfer.rs rename to programs/compressed-token/anchor/src/instructions/transfer.rs diff --git a/programs/compressed-token/src/lib.rs b/programs/compressed-token/anchor/src/lib.rs similarity index 98% rename from programs/compressed-token/src/lib.rs rename to programs/compressed-token/anchor/src/lib.rs index 97cf581510..f03919ea87 100644 --- a/programs/compressed-token/src/lib.rs +++ b/programs/compressed-token/anchor/src/lib.rs @@ -16,6 +16,7 @@ pub use instructions::*; pub mod burn; pub use burn::*; pub mod batch_compress; +pub mod create_mint; use light_compressed_account::instruction_data::cpi_context::CompressedCpiContext; use crate::process_transfer::CompressedTokenInstructionDataTransfer; @@ -46,7 +47,7 @@ pub mod light_compressed_token { pub fn create_token_pool<'info>( ctx: Context<'_, '_, '_, 'info, CreateTokenPoolInstruction<'info>>, ) -> Result<()> { - create_token_pool::assert_mint_extensions( + instructions::create_token_pool::assert_mint_extensions( &ctx.accounts.mint.to_account_info().try_borrow_data()?, ) } @@ -276,4 +277,9 @@ pub enum ErrorCode { NoMatchingBumpFound, 
NoAmount, AmountsAndAmountProvided, + MintIsNone, + InvalidMintPda, + InputsOutOfOrder, + TooManyMints, + InvalidExtensionType, } diff --git a/programs/compressed-token/src/process_compress_spl_token_account.rs b/programs/compressed-token/anchor/src/process_compress_spl_token_account.rs similarity index 100% rename from programs/compressed-token/src/process_compress_spl_token_account.rs rename to programs/compressed-token/anchor/src/process_compress_spl_token_account.rs diff --git a/programs/compressed-token/src/process_mint.rs b/programs/compressed-token/anchor/src/process_mint.rs similarity index 58% rename from programs/compressed-token/src/process_mint.rs rename to programs/compressed-token/anchor/src/process_mint.rs index 719eeda736..7267e5943b 100644 --- a/programs/compressed-token/src/process_mint.rs +++ b/programs/compressed-token/anchor/src/process_mint.rs @@ -2,7 +2,11 @@ use account_compression::program::AccountCompression; use anchor_lang::prelude::*; use anchor_spl::token_interface::{TokenAccount, TokenInterface}; use light_compressed_account::{ - instruction_data::data::OutputCompressedAccountWithPackedContext, pubkey::AsPubkey, + compressed_account::PackedCompressedAccountWithMerkleContext, + instruction_data::{ + compressed_proof::CompressedProof, data::OutputCompressedAccountWithPackedContext, + }, + pubkey::AsPubkey, }; use light_system_program::program::LightSystemProgram; use light_zero_copy::num_trait::ZeroCopyNumTrait; @@ -10,8 +14,8 @@ use light_zero_copy::num_trait::ZeroCopyNumTrait; use { crate::{ check_spl_token_pool_derivation_with_index, - process_transfer::create_output_compressed_accounts, - process_transfer::get_cpi_signer_seeds, spl_compression::spl_token_transfer, + process_transfer::{create_output_compressed_accounts, get_cpi_signer_seeds}, + spl_compression::spl_token_transfer, }, light_compressed_account::hash_to_bn254_field_size_be, light_heap::{bench_sbf_end, bench_sbf_start, GLOBAL_ALLOCATOR}, @@ -58,6 +62,7 @@ pub fn 
process_mint_to_or_compress<'info, const IS_MINT_TO: bool>( #[cfg(target_os = "solana")] { let option_compression_lamports = if lamports.unwrap_or(0) == 0 { 0 } else { 8 }; + let inputs_len = 1 + 4 + 4 + 4 + amounts.len() * 162 + 1 + 1 + 1 + 1 + option_compression_lamports; // inputs_len = @@ -75,11 +80,15 @@ pub fn process_mint_to_or_compress<'info, const IS_MINT_TO: bool>( let pre_compressed_acounts_pos = GLOBAL_ALLOCATOR.get_heap_pos(); bench_sbf_start!("tm_mint_spl_to_pool_pda"); - let mint = if IS_MINT_TO { - // 7,978 CU + let (mint, compressed_mint_update_data) = if IS_MINT_TO { + // EXISTING SPL MINT PATH mint_spl_to_pool_pda(&ctx, &amounts)?; - ctx.accounts.mint.as_ref().unwrap().key() + ( + ctx.accounts.mint.as_ref().unwrap().key(), + None::, + ) } else { + // EXISTING BATCH COMPRESS PATH let mut amount = 0u64; for a in amounts { amount += (*a).into(); @@ -103,7 +112,7 @@ pub fn process_mint_to_or_compress<'info, const IS_MINT_TO: bool>( ctx.accounts.token_program.to_account_info(), amount, )?; - mint + (mint, None) }; let hashed_mint = hash_to_bn254_field_size_be(mint.as_ref()); @@ -126,10 +135,15 @@ pub fn process_mint_to_or_compress<'info, const IS_MINT_TO: bool>( )?; bench_sbf_end!("tm_output_compressed_accounts"); - cpi_execute_compressed_transaction_mint_to( + // Create compressed mint update data if needed + let (input_compressed_accounts, proof) = (vec![], None); + // Execute single CPI call with updated serialization + cpi_execute_compressed_transaction_mint_to::( &ctx, + input_compressed_accounts.as_slice(), output_compressed_accounts, &mut inputs, + proof, pre_compressed_acounts_pos, )?; @@ -147,12 +161,123 @@ pub fn process_mint_to_or_compress<'info, const IS_MINT_TO: bool>( Ok(()) } +// #[cfg(target_os = "solana")] +// fn mint_with_compressed_mint<'info>( +// ctx: &Context<'_, '_, '_, 'info, MintToInstruction<'info>>, +// amounts: &[impl ZeroCopyNumTrait], +// compressed_inputs: &CompressedMintInputs, +// ) -> Result<( +// Pubkey, +// Option<( 
+// PackedCompressedAccountWithMerkleContext, +// OutputCompressedAccountWithPackedContext, +// )>, +// )> { +// let mint_pubkey = ctx +// .accounts +// .mint +// .as_ref() +// .ok_or(crate::ErrorCode::MintIsNone)? +// .key(); +// let compressed_mint: CompressedMint = CompressedMint { +// mint_authority: Some(ctx.accounts.authority.key()), +// freeze_authority: if compressed_inputs +// .compressed_mint_input +// .freeze_authority_is_set +// { +// Some(compressed_inputs.compressed_mint_input.freeze_authority) +// } else { +// None +// }, +// spl_mint: mint_pubkey, +// supply: compressed_inputs.compressed_mint_input.supply, +// decimals: compressed_inputs.compressed_mint_input.decimals, +// is_decompressed: compressed_inputs.compressed_mint_input.is_decompressed, +// num_extensions: compressed_inputs.compressed_mint_input.num_extensions, +// }; +// // Create input compressed account for existing mint +// let input_compressed_account = PackedCompressedAccountWithMerkleContext { +// compressed_account: CompressedAccount { +// owner: crate::ID.into(), +// lamports: 0, +// address: Some(compressed_inputs.address), +// data: Some(CompressedAccountData { +// discriminator: COMPRESSED_MINT_DISCRIMINATOR, +// data: Vec::new(), +// // TODO: hash with hashed inputs +// data_hash: compressed_mint.hash().map_err(ProgramError::from)?, +// }), +// }, +// merkle_context: compressed_inputs.merkle_context, +// root_index: compressed_inputs.root_index, +// read_only: false, +// }; +// let total_mint_amount: u64 = amounts.iter().map(|a| (*a).into()).sum(); +// let updated_compressed_mint = if compressed_mint.is_decompressed { +// // SYNC WITH SPL MINT (SPL is source of truth) + +// // Mint to SPL token pool as normal +// mint_spl_to_pool_pda(ctx, amounts)?; + +// // Read updated SPL mint state for sync +// let spl_mint_info = ctx +// .accounts +// .mint +// .as_ref() +// .ok_or(crate::ErrorCode::MintIsNone)?; +// let spl_mint_data = spl_mint_info.data.borrow(); +// let spl_mint = 
anchor_spl::token::Mint::try_deserialize(&mut &spl_mint_data[..])?; + +// // Create updated compressed mint with synced state +// let mut updated_compressed_mint = compressed_mint; +// updated_compressed_mint.supply = spl_mint.supply; +// updated_compressed_mint +// } else { +// // PURE COMPRESSED MINT - no SPL backing +// let mut updated_compressed_mint = compressed_mint; +// updated_compressed_mint.supply = updated_compressed_mint +// .supply +// .checked_add(total_mint_amount) +// .ok_or(crate::ErrorCode::MintTooLarge)?; +// updated_compressed_mint +// }; +// let updated_data_hash = updated_compressed_mint +// .hash() +// .map_err(|_| crate::ErrorCode::HashToFieldError)?; + +// let mut updated_mint_bytes = Vec::new(); +// updated_compressed_mint.serialize(&mut updated_mint_bytes)?; + +// let updated_compressed_account_data = CompressedAccountData { +// discriminator: COMPRESSED_MINT_DISCRIMINATOR, +// data: updated_mint_bytes, +// data_hash: updated_data_hash, +// }; + +// let output_compressed_mint_account = OutputCompressedAccountWithPackedContext { +// compressed_account: CompressedAccount { +// owner: crate::ID.into(), +// lamports: 0, +// address: Some(compressed_inputs.address), +// data: Some(updated_compressed_account_data), +// }, +// merkle_tree_index: compressed_inputs.output_merkle_tree_index, +// }; + +// Ok(( +// mint_pubkey, +// Some((input_compressed_account, output_compressed_mint_account)), +// )) +// } + #[cfg(target_os = "solana")] #[inline(never)] -pub fn cpi_execute_compressed_transaction_mint_to<'info>( - ctx: &Context<'_, '_, '_, 'info, MintToInstruction>, +pub fn cpi_execute_compressed_transaction_mint_to<'info, const IS_MINT_TO: bool>( + ctx: &Context<'_, '_, '_, 'info, MintToInstruction<'info>>, + mint_to_compressed_account: &[PackedCompressedAccountWithMerkleContext], output_compressed_accounts: Vec, inputs: &mut Vec, + proof: Option, pre_compressed_acounts_pos: usize, ) -> Result<()> { bench_sbf_start!("tm_cpi"); @@ -162,7 +287,12 @@ 
pub fn cpi_execute_compressed_transaction_mint_to<'info>( // 4300 CU for 10 accounts // 6700 CU for 20 accounts // 7,978 CU for 25 accounts - serialize_mint_to_cpi_instruction_data(inputs, &output_compressed_accounts); + serialize_mint_to_cpi_instruction_data_with_inputs( + inputs, + mint_to_compressed_account, + &output_compressed_accounts, + proof, + ); GLOBAL_ALLOCATOR.free_heap(pre_compressed_acounts_pos)?; @@ -181,7 +311,7 @@ pub fn cpi_execute_compressed_transaction_mint_to<'info>( }; // 1300 CU - let account_infos = vec![ + let mut account_infos = vec![ ctx.accounts.fee_payer.to_account_info(), ctx.accounts.cpi_authority_pda.to_account_info(), ctx.accounts.registered_program_pda.to_account_info(), @@ -195,9 +325,16 @@ pub fn cpi_execute_compressed_transaction_mint_to<'info>( ctx.accounts.light_system_program.to_account_info(), // none cpi_context_account ctx.accounts.merkle_tree.to_account_info(), // first remaining account ]; + // Don't add for batch compress + if IS_MINT_TO { + // Add remaining account metas (compressed mint merkle tree should be writable) + for remaining in ctx.remaining_accounts { + account_infos.push(remaining.to_account_info()); + } + } // account_metas take 1k cu - let accounts = vec![ + let mut accounts = vec![ AccountMeta { pubkey: account_infos[0].key(), is_signer: true, @@ -255,7 +392,18 @@ pub fn cpi_execute_compressed_transaction_mint_to<'info>( is_writable: true, }, ]; - + // Don't add for batch compress + if IS_MINT_TO { + // Add remaining account metas (compressed mint merkle tree should be writable) + for remaining in &account_infos[12..] 
{ + msg!(" remaining.key() {:?}", remaining.key()); + accounts.push(AccountMeta { + pubkey: remaining.key(), + is_signer: false, + is_writable: remaining.is_writable, + }); + } + } let instruction = anchor_lang::solana_program::instruction::Instruction { program_id: light_system_program::ID, accounts, @@ -274,26 +422,41 @@ pub fn cpi_execute_compressed_transaction_mint_to<'info>( } #[inline(never)] -pub fn serialize_mint_to_cpi_instruction_data( +pub fn serialize_mint_to_cpi_instruction_data_with_inputs( inputs: &mut Vec, + input_compressed_accounts: &[PackedCompressedAccountWithMerkleContext], output_compressed_accounts: &[OutputCompressedAccountWithPackedContext], + proof: Option, ) { - let len = output_compressed_accounts.len(); - // proof (option None) - inputs.extend_from_slice(&[0u8]); - // two empty vecs 4 bytes of zeroes each: address_params, + // proof (option) + if let Some(proof) = proof { + inputs.extend_from_slice(&[1u8]); // Some + proof.serialize(inputs).unwrap(); + } else { + inputs.extend_from_slice(&[0u8]); // None + } + + // new_address_params (empty for mint operations) + inputs.extend_from_slice(&[0u8; 4]); + // input_compressed_accounts_with_merkle_context - inputs.extend_from_slice(&[0u8; 8]); - // lenght of output_compressed_accounts vec as u32 - inputs.extend_from_slice(&[(len as u8), 0, 0, 0]); - let mut sum_lamports = 0u64; + let input_len = input_compressed_accounts.len(); + inputs.extend_from_slice(&[(input_len as u8), 0, 0, 0]); + for input_account in input_compressed_accounts.iter() { + input_account.serialize(inputs).unwrap(); + } + // output_compressed_accounts + let output_len = output_compressed_accounts.len(); + inputs.extend_from_slice(&[(output_len as u8), 0, 0, 0]); + let mut sum_lamports = 0u64; for compressed_account in output_compressed_accounts.iter() { compressed_account.serialize(inputs).unwrap(); sum_lamports = sum_lamports .checked_add(compressed_account.compressed_account.lamports) .unwrap(); } + // None relay_fee 
inputs.extend_from_slice(&[0u8; 1]); @@ -309,6 +472,158 @@ pub fn serialize_mint_to_cpi_instruction_data( inputs.extend_from_slice(&[0u8]); } +// #[cfg(target_os = "solana")] +// fn create_compressed_mint_update_accounts( +// updated_compressed_mint: CompressedMint, +// compressed_inputs: CompressedMintInputs, +// ) -> Result<( +// PackedCompressedAccountWithMerkleContext, +// OutputCompressedAccountWithPackedContext, +// )> { +// // Create input compressed account for existing mint +// let input_compressed_account = PackedCompressedAccountWithMerkleContext { +// compressed_account: CompressedAccount { +// owner: crate::ID.into(), +// lamports: 0, +// address: Some(compressed_inputs.address), +// data: Some(CompressedAccountData { +// discriminator: COMPRESSED_MINT_DISCRIMINATOR, +// data: Vec::new(), +// data_hash: updated_compressed_mint.hash().map_err(ProgramError::from)?, +// }), +// }, +// merkle_context: compressed_inputs.merkle_context, +// root_index: compressed_inputs.root_index, +// read_only: false, +// }; +// msg!( +// "compressed_inputs.merkle_context: {:?}", +// compressed_inputs.merkle_context +// ); + +// // Create output compressed account for updated mint +// let mut updated_mint_bytes = Vec::new(); +// updated_compressed_mint.serialize(&mut updated_mint_bytes)?; +// let updated_data_hash = updated_compressed_mint +// .hash() +// .map_err(|_| crate::ErrorCode::HashToFieldError)?; + +// let updated_compressed_account_data = CompressedAccountData { +// discriminator: COMPRESSED_MINT_DISCRIMINATOR, +// data: updated_mint_bytes, +// data_hash: updated_data_hash, +// }; + +// let output_compressed_mint_account = OutputCompressedAccountWithPackedContext { +// compressed_account: CompressedAccount { +// owner: crate::ID.into(), +// lamports: 0, +// address: Some(compressed_inputs.address), +// data: Some(updated_compressed_account_data), +// }, +// merkle_tree_index: compressed_inputs.output_merkle_tree_index, +// }; +// msg!( +// 
"compressed_inputs.output_merkle_tree_index {}", +// compressed_inputs.output_merkle_tree_index +// ); + +// Ok((input_compressed_account, output_compressed_mint_account)) +// } + +// #[cfg(target_os = "solana")] +// #[inline(never)] +// pub fn cpi_execute_compressed_transaction_mint_to_with_inputs<'info>( +// ctx: &Context<'_, '_, '_, 'info, MintToInstruction<'info>>, +// input_compressed_accounts: Vec, +// output_compressed_accounts: Vec, +// proof: Option, +// inputs: &mut Vec, +// pre_compressed_accounts_pos: usize, +// ) -> Result<()> { +// bench_sbf_start!("tm_cpi_mint_update"); + +// let signer_seeds = get_cpi_signer_seeds(); + +// // Serialize CPI instruction data with inputs +// serialize_mint_to_cpi_instruction_data_with_inputs( +// inputs, +// &input_compressed_accounts, +// &output_compressed_accounts, +// proof, +// ); + +// GLOBAL_ALLOCATOR.free_heap(pre_compressed_accounts_pos)?; + +// use anchor_lang::InstructionData; + +// let instructiondata = light_system_program::instruction::InvokeCpi { +// inputs: inputs.to_owned(), +// }; + +// let (sol_pool_pda, is_writable) = if let Some(pool_pda) = ctx.accounts.sol_pool_pda.as_ref() { +// (pool_pda.to_account_info(), true) +// } else { +// (ctx.accounts.light_system_program.to_account_info(), false) +// }; + +// // Build account infos including both output merkle tree and remaining accounts (compressed mint merkle tree) +// let mut account_infos = vec![ +// ctx.accounts.fee_payer.to_account_info(), +// ctx.accounts.cpi_authority_pda.to_account_info(), +// ctx.accounts.registered_program_pda.to_account_info(), +// ctx.accounts.noop_program.to_account_info(), +// ctx.accounts.account_compression_authority.to_account_info(), +// ctx.accounts.account_compression_program.to_account_info(), +// ctx.accounts.self_program.to_account_info(), +// sol_pool_pda, +// ctx.accounts.light_system_program.to_account_info(), +// ctx.accounts.system_program.to_account_info(), +// 
ctx.accounts.light_system_program.to_account_info(), // cpi_context_account placeholder +// ctx.accounts.merkle_tree.to_account_info(), // output merkle tree +// ]; + +// // Add remaining accounts (compressed mint merkle tree, etc.) +// account_infos.extend_from_slice(ctx.remaining_accounts); + +// // Build account metas +// let mut accounts = vec![ +// AccountMeta::new(account_infos[0].key(), true), // fee_payer +// AccountMeta::new_readonly(account_infos[1].key(), true), // cpi_authority_pda (signer) +// AccountMeta::new_readonly(account_infos[2].key(), false), // registered_program_pda +// AccountMeta::new_readonly(account_infos[3].key(), false), // noop_program +// AccountMeta::new_readonly(account_infos[4].key(), false), // account_compression_authority +// AccountMeta::new_readonly(account_infos[5].key(), false), // account_compression_program +// AccountMeta::new_readonly(account_infos[6].key(), false), // self_program +// AccountMeta::new(account_infos[7].key(), is_writable), // sol_pool_pda +// AccountMeta::new_readonly(account_infos[8].key(), false), // decompression_recipient placeholder +// AccountMeta::new_readonly(account_infos[9].key(), false), // system_program +// AccountMeta::new_readonly(account_infos[10].key(), false), // cpi_context_account placeholder +// AccountMeta::new(account_infos[11].key(), false), // output merkle tree (writable) +// ]; + +// // Add remaining account metas (compressed mint merkle tree should be writable) +// for remaining in &account_infos[12..] 
{ +// accounts.push(AccountMeta::new(remaining.key(), false)); +// } + +// let instruction = anchor_lang::solana_program::instruction::Instruction { +// program_id: light_system_program::ID, +// accounts, +// data: instructiondata.data(), +// }; + +// bench_sbf_end!("tm_cpi_mint_update"); +// bench_sbf_start!("tm_invoke_mint_update"); +// anchor_lang::solana_program::program::invoke_signed( +// &instruction, +// account_infos.as_slice(), +// &[&signer_seeds[..]], +// )?; +// bench_sbf_end!("tm_invoke_mint_update"); +// Ok(()) +// } + #[inline(never)] pub fn mint_spl_to_pool_pda( ctx: &Context, @@ -580,7 +895,12 @@ mod test { } let mut inputs = Vec::::new(); - serialize_mint_to_cpi_instruction_data(&mut inputs, &output_compressed_accounts); + serialize_mint_to_cpi_instruction_data_with_inputs( + &mut inputs, + &[], + &output_compressed_accounts, + None, + ); let inputs_struct = InstructionDataInvokeCpi { relay_fee: None, input_compressed_accounts_with_merkle_context: Vec::with_capacity(0), @@ -643,17 +963,67 @@ mod test { merkle_tree_index: 0, }; } + + // Randomly test with or without compressed mint inputs + let (input_compressed_accounts, expected_inputs, proof) = if rng.gen_bool(0.5) { + // Test with compressed mint inputs (50% chance) + let input_mint_account = PackedCompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: crate::ID.into(), + lamports: 0, + address: Some([rng.gen::(); 32]), + data: Some(CompressedAccountData { + discriminator: crate::constants::COMPRESSED_MINT_DISCRIMINATOR, + data: vec![rng.gen::(); 32], + data_hash: [rng.gen::(); 32], + }), + }, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index: rng.gen_range(0..10), + queue_pubkey_index: rng.gen_range(0..10), + leaf_index: rng.gen_range(0..1000), + prove_by_index: rng.gen_bool(0.5), + }, + root_index: rng.gen_range(0..100), + read_only: false, + }; + + let proof = if rng.gen_bool(0.3) { + Some(CompressedProof { + a: [rng.gen::(); 32], + b: 
[rng.gen::(); 64], + c: [rng.gen::(); 32], + }) + } else { + None + }; + + ( + vec![input_mint_account.clone()], + vec![input_mint_account], + proof, + ) + } else { + // Test without compressed mint inputs (50% chance) + (Vec::new(), Vec::new(), None) + }; + let mut inputs = Vec::::new(); - serialize_mint_to_cpi_instruction_data(&mut inputs, &output_compressed_accounts); + serialize_mint_to_cpi_instruction_data_with_inputs( + &mut inputs, + &input_compressed_accounts, + &output_compressed_accounts, + proof, + ); let sum = output_compressed_accounts .iter() .map(|x| x.compressed_account.lamports) .sum::(); let inputs_struct = InstructionDataInvokeCpi { relay_fee: None, - input_compressed_accounts_with_merkle_context: Vec::with_capacity(0), + input_compressed_accounts_with_merkle_context: expected_inputs, output_compressed_accounts: output_compressed_accounts.clone(), - proof: None, + proof, new_address_params: Vec::with_capacity(0), compress_or_decompress_lamports: Some(sum), is_compress: true, diff --git a/programs/compressed-token/src/process_transfer.rs b/programs/compressed-token/anchor/src/process_transfer.rs similarity index 100% rename from programs/compressed-token/src/process_transfer.rs rename to programs/compressed-token/anchor/src/process_transfer.rs diff --git a/programs/compressed-token/src/spl_compression.rs b/programs/compressed-token/anchor/src/spl_compression.rs similarity index 100% rename from programs/compressed-token/src/spl_compression.rs rename to programs/compressed-token/anchor/src/spl_compression.rs diff --git a/programs/compressed-token/src/token_data.rs b/programs/compressed-token/anchor/src/token_data.rs similarity index 100% rename from programs/compressed-token/src/token_data.rs rename to programs/compressed-token/anchor/src/token_data.rs diff --git a/programs/compressed-token/program/Cargo.toml b/programs/compressed-token/program/Cargo.toml new file mode 100644 index 0000000000..7402cd396c --- /dev/null +++ 
b/programs/compressed-token/program/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "light-compressed-token" +version = "2.0.0" +description = "Generalized token compression on Solana" +repository = "https://github.com/Lightprotocol/light-protocol" +license = "Apache-2.0" +edition = "2021" + +[lib] +crate-type = ["cdylib", "lib"] +name = "light_compressed_token" + +[features] +no-entrypoint = [] +no-log-ix-name = [] +cpi = ["no-entrypoint"] +custom-heap = ["light-heap"] +mem-profiling = [] +default = ["custom-heap"] +test-sbf = [] +bench-sbf = [] +cpi-context = [] +cpi-without-program-ids = [] + +[dependencies] +anchor-lang = { workspace = true } +spl-token = { workspace = true, features = ["no-entrypoint"] } +account-compression = { workspace = true, features = ["cpi", "no-idl"] } +light-system-program-anchor = { workspace = true, features = ["cpi"] } +solana-security-txt = "1.1.0" +light-hasher = { workspace = true } +light-heap = { workspace = true, optional = true } +light-compressed-account = { workspace = true, features = ["anchor"] } +spl-token-2022 = { workspace = true } +spl-pod = { workspace = true } +light-zero-copy = { workspace = true, features = ["mut", "std", "derive"] } +zerocopy = { workspace = true } +anchor-compressed-token = { path = "../anchor", features = ["cpi"] } +light-account-checks = { workspace = true, features = ["solana", "pinocchio"] } +light-sdk = { workspace = true } +borsh = { workspace = true } +light-sdk-types = { workspace = true } +solana-pubkey = { workspace = true } +arrayvec = { workspace = true } +pinocchio = { workspace = true, features = ["std"] } +light-sdk-pinocchio = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } +num-bigint = { workspace = true } + +[lints.rust.unexpected_cfgs] +level = "allow" +check-cfg = [ + 'cfg(target_os, values("solana"))', + 'cfg(feature, values("frozen-abi", "no-entrypoint"))', +] diff --git a/programs/compressed-token/program/README.md 
b/programs/compressed-token/program/README.md new file mode 100644 index 0000000000..764e509cdc --- /dev/null +++ b/programs/compressed-token/program/README.md @@ -0,0 +1,13 @@ +# Compressed Token Program + +A token program on the Solana blockchain using ZK Compression. + +This program provides an interface and implementation that third parties can utilize to create and use compressed tokens on Solana. + +Documentation is available at https://zkcompression.com + +Source code: https://github.com/Lightprotocol/light-protocol/tree/main/programs/compressed-token + +## Audit + +This code is unaudited. Use at your own risk. diff --git a/programs/compressed-token/program/Xargo.toml b/programs/compressed-token/program/Xargo.toml new file mode 100644 index 0000000000..475fb71ed1 --- /dev/null +++ b/programs/compressed-token/program/Xargo.toml @@ -0,0 +1,2 @@ +[target.bpfel-unknown-unknown.dependencies.std] +features = [] diff --git a/programs/compressed-token/program/src/close_token_account/accounts.rs b/programs/compressed-token/program/src/close_token_account/accounts.rs new file mode 100644 index 0000000000..27a1dbdae1 --- /dev/null +++ b/programs/compressed-token/program/src/close_token_account/accounts.rs @@ -0,0 +1,30 @@ +use anchor_lang::prelude::ProgramError; +use light_account_checks::checks::{check_mut, check_signer}; +use pinocchio::account_info::AccountInfo; + +pub struct CloseTokenAccountAccounts<'a> { + pub token_account: &'a AccountInfo, + pub destination: &'a AccountInfo, + pub authority: &'a AccountInfo, +} + +impl<'a> CloseTokenAccountAccounts<'a> { + pub fn new(accounts: &'a [AccountInfo]) -> Result { + Ok(Self { + token_account: &accounts[0], + destination: &accounts[1], + authority: &accounts[2], + }) + } + + pub fn get_checked(accounts: &'a [AccountInfo]) -> Result { + let accounts_struct = Self::new(accounts)?; + + // Basic validations using light_account_checks + check_mut(accounts_struct.token_account)?; + check_mut(accounts_struct.destination)?; + 
check_signer(accounts_struct.authority)?; + + Ok(accounts_struct) + } +} \ No newline at end of file diff --git a/programs/compressed-token/program/src/close_token_account/mod.rs b/programs/compressed-token/program/src/close_token_account/mod.rs new file mode 100644 index 0000000000..b96a2596f4 --- /dev/null +++ b/programs/compressed-token/program/src/close_token_account/mod.rs @@ -0,0 +1,2 @@ +pub mod accounts; +pub mod processor; \ No newline at end of file diff --git a/programs/compressed-token/program/src/close_token_account/processor.rs b/programs/compressed-token/program/src/close_token_account/processor.rs new file mode 100644 index 0000000000..223ed08cac --- /dev/null +++ b/programs/compressed-token/program/src/close_token_account/processor.rs @@ -0,0 +1,67 @@ +use anchor_lang::prelude::ProgramError; +use light_account_checks::AccountInfoTrait; +use pinocchio::account_info::AccountInfo; +use spl_pod::bytemuck::pod_from_bytes; +use spl_token_2022::pod::PodAccount; +use spl_token_2022::state::AccountState; + +use super::accounts::CloseTokenAccountAccounts; + +/// Process the close token account instruction +pub fn process_close_token_account( + account_infos: &[AccountInfo], + _instruction_data: &[u8], +) -> Result<(), ProgramError> { + // Validate and get accounts + let accounts = CloseTokenAccountAccounts::get_checked(account_infos)?; + + // Validate token account state and balance + { + let token_account_data = AccountInfoTrait::try_borrow_data(accounts.token_account) + .map_err(|_| ProgramError::InvalidAccountData)?; + let pod_account = pod_from_bytes::(&token_account_data) + .map_err(|_| ProgramError::InvalidAccountData)?; + + // Check that the account is initialized + if pod_account.state != AccountState::Initialized as u8 { + return Err(ProgramError::UninitializedAccount); + } + + // Check that the account has zero balance + let balance: u64 = pod_account.amount.into(); + if balance != 0 { + return Err(ProgramError::InvalidAccountData); + } + + // 
Verify the authority matches the account owner + let account_owner = solana_pubkey::Pubkey::from(pod_account.owner); + let authority_key = solana_pubkey::Pubkey::new_from_array(*accounts.authority.key()); + if account_owner != authority_key { + return Err(ProgramError::InvalidAccountOwner); + } + } + // TODO: double check that it is safely closed. + // Transfer all lamports from token account to destination + let token_account_lamports = AccountInfoTrait::lamports(accounts.token_account); + + // Set token account lamports to 0 + unsafe { + *accounts.token_account.borrow_mut_lamports_unchecked() = 0; + } + + // Add lamports to destination + let destination_lamports = AccountInfoTrait::lamports(accounts.destination); + let new_destination_lamports = destination_lamports + .checked_add(token_account_lamports) + .ok_or(ProgramError::ArithmeticOverflow)?; + + unsafe { + *accounts.destination.borrow_mut_lamports_unchecked() = new_destination_lamports; + } + // Clear the token account data + let mut token_account_data = AccountInfoTrait::try_borrow_mut_data(accounts.token_account) + .map_err(|_| ProgramError::InvalidAccountData)?; + token_account_data.fill(0); + + Ok(()) +} diff --git a/programs/compressed-token/program/src/constants.rs b/programs/compressed-token/program/src/constants.rs new file mode 100644 index 0000000000..2309eb6902 --- /dev/null +++ b/programs/compressed-token/program/src/constants.rs @@ -0,0 +1,5 @@ +// Compressed mint discriminator +pub const COMPRESSED_MINT_DISCRIMINATOR: [u8; 8] = [1, 0, 0, 0, 0, 0, 0, 0]; + +// CPI authority bump +pub const BUMP_CPI_AUTHORITY: u8 = 254; \ No newline at end of file diff --git a/programs/compressed-token/program/src/create_associated_token_account/accounts.rs b/programs/compressed-token/program/src/create_associated_token_account/accounts.rs new file mode 100644 index 0000000000..1e43301284 --- /dev/null +++ b/programs/compressed-token/program/src/create_associated_token_account/accounts.rs @@ -0,0 +1,72 @@ +use 
anchor_lang::prelude::ProgramError; +use anchor_lang::solana_program::program_pack::IsInitialized; +use light_account_checks::{checks::{check_mut, check_non_mut, check_signer}, AccountInfoTrait}; +use pinocchio::account_info::AccountInfo; +use spl_pod::bytemuck::pod_from_bytes; +use spl_token_2022::pod::PodMint; + +pub struct CreateAssociatedTokenAccountAccounts<'a> { + pub fee_payer: &'a AccountInfo, + pub associated_token_account: &'a AccountInfo, + pub mint: Option<&'a AccountInfo>, + pub system_program: &'a AccountInfo, +} + +impl<'a> CreateAssociatedTokenAccountAccounts<'a> { + pub fn new( + accounts: &'a [AccountInfo], + mint_is_decompressed: bool, + ) -> Result { + let (mint, system_program_index) = if mint_is_decompressed { + (Some(&accounts[2]), 3) + } else { + (None, 2) + }; + Ok(Self { + fee_payer: &accounts[0], + associated_token_account: &accounts[1], + mint, + system_program: &accounts[system_program_index], + }) + } + + pub fn get_checked( + accounts: &'a [AccountInfo], + mint: &[u8; 32], + mint_is_decompressed: bool, + ) -> Result { + let accounts_struct = Self::new(accounts, mint_is_decompressed)?; + + // Basic validations using light_account_checks + check_signer(accounts_struct.fee_payer)?; + check_mut(accounts_struct.fee_payer)?; + check_mut(accounts_struct.associated_token_account)?; + check_non_mut(accounts_struct.system_program)?; + // ata derivation is checked implicitly by cpi + + if let Some(mint_account_info) = accounts_struct.mint { + if AccountInfoTrait::key(mint_account_info) != *mint { + return Err(ProgramError::InvalidAccountData); + } + + // Check if owned by either spl-token or spl-token-2022 program + let spl_token_id = spl_token::id().to_bytes(); + let spl_token_2022_id = spl_token_2022::id().to_bytes(); + let owner = unsafe { *mint_account_info.owner() }; + if owner != spl_token_id && owner != spl_token_2022_id { + return Err(ProgramError::IncorrectProgramId); + } + + let mint_data = 
AccountInfoTrait::try_borrow_data(mint_account_info) + .map_err(|_| ProgramError::InvalidAccountData)?; + let pod_mint = pod_from_bytes::(&mint_data) + .map_err(|_| ProgramError::InvalidAccountData)?; + + if !pod_mint.is_initialized() { + return Err(ProgramError::UninitializedAccount); + } + } + + Ok(accounts_struct) + } +} diff --git a/programs/compressed-token/program/src/create_associated_token_account/instruction_data.rs b/programs/compressed-token/program/src/create_associated_token_account/instruction_data.rs new file mode 100644 index 0000000000..731fd597e2 --- /dev/null +++ b/programs/compressed-token/program/src/create_associated_token_account/instruction_data.rs @@ -0,0 +1,12 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::Pubkey; +use light_zero_copy::ZeroCopy; + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct CreateAssociatedTokenAccountInstructionData { + /// The owner of the associated token account + pub owner: Pubkey, + /// The mint for the associated token account + pub mint: Pubkey, + pub bump: u8, +} diff --git a/programs/compressed-token/program/src/create_associated_token_account/mod.rs b/programs/compressed-token/program/src/create_associated_token_account/mod.rs new file mode 100644 index 0000000000..a3b881274a --- /dev/null +++ b/programs/compressed-token/program/src/create_associated_token_account/mod.rs @@ -0,0 +1,5 @@ +pub mod accounts; +pub mod instruction_data; +pub mod processor; + +pub use processor::process_create_associated_token_account; \ No newline at end of file diff --git a/programs/compressed-token/program/src/create_associated_token_account/processor.rs b/programs/compressed-token/program/src/create_associated_token_account/processor.rs new file mode 100644 index 0000000000..80c1d4201c --- /dev/null +++ b/programs/compressed-token/program/src/create_associated_token_account/processor.rs @@ -0,0 +1,115 @@ +use anchor_lang::prelude::{ProgramError, SolanaSysvar}; 
+use anchor_lang::solana_program::{rent::Rent, system_instruction}; +use light_account_checks::AccountInfoTrait; +use light_zero_copy::borsh::Deserialize; +use pinocchio::account_info::AccountInfo; + +use super::{ + accounts::CreateAssociatedTokenAccountAccounts, + instruction_data::CreateAssociatedTokenAccountInstructionData, +}; +use crate::shared::initialize_token_account::initialize_token_account; + +/// Note: +/// - we don't validate the mint because it would be very expensive with compressed mints +/// - it is possible to create an associated token account for non existing mints +/// - accounts with non existing mints can never have a balance +/// Process the create associated token account instruction +pub fn process_create_associated_token_account<'info>( + account_infos: &'info [AccountInfo], + instruction_data: &[u8], +) -> Result<(), ProgramError> { + // Parse instruction data using zero-copy + let (inputs, _) = CreateAssociatedTokenAccountInstructionData::zero_copy_at(instruction_data) + .map_err(ProgramError::from)?; + + // Validate and get accounts + let accounts = CreateAssociatedTokenAccountAccounts::get_checked( + account_infos, + &inputs.mint.to_bytes(), + false, + )?; + + { + let owner = inputs.owner.to_bytes(); + let mint = inputs.mint.to_bytes(); + // Define the PDA seeds for signing + use pinocchio::instruction::{Seed, Signer}; + let bump_bytes = [inputs.bump]; + let seed_array = [ + Seed::from(owner.as_ref()), + Seed::from(crate::ID.as_ref()), + Seed::from(mint.as_ref()), + Seed::from(bump_bytes.as_ref()), + ]; + let signer = Signer::from(&seed_array); + + // Calculate rent for SPL token account (165 bytes) + let token_account_size = 165_usize; + let rent = Rent::get()?; + let rent_lamports = rent.minimum_balance(token_account_size); + + // Create the associated token account + let fee_payer_key = + solana_pubkey::Pubkey::new_from_array(AccountInfoTrait::key(accounts.fee_payer)); + let ata_key = 
solana_pubkey::Pubkey::new_from_array(AccountInfoTrait::key( + accounts.associated_token_account, + )); + let create_account_instruction = system_instruction::create_account( + &fee_payer_key, + &ata_key, + rent_lamports, + token_account_size as u64, + &crate::ID, + ); + + // Execute the create account instruction with PDA signing + let instruction_data = create_account_instruction.data; + let pinocchio_instruction = pinocchio::instruction::Instruction { + program_id: &create_account_instruction.program_id.to_bytes(), + accounts: &[ + pinocchio::instruction::AccountMeta { + pubkey: accounts.fee_payer.key(), + is_signer: true, + is_writable: true, + }, + pinocchio::instruction::AccountMeta { + pubkey: accounts.associated_token_account.key(), + is_signer: true, + is_writable: true, + }, + pinocchio::instruction::AccountMeta { + pubkey: accounts.system_program.key(), + is_signer: false, + is_writable: false, + }, + ], + data: &instruction_data, + }; + + match pinocchio::program::invoke_signed( + &pinocchio_instruction, + &[ + accounts.fee_payer, + accounts.associated_token_account, + accounts.system_program, + ], + &[signer], + ) { + Ok(()) => {} + Err(e) => { + anchor_lang::solana_program::msg!("invoke_signed failed: {:?}", e); + return Err(ProgramError::Custom(u64::from(e) as u32)); + } + } + } + + // Initialize the token account using shared utility + initialize_token_account( + accounts.associated_token_account, + &inputs.mint.to_bytes(), + &inputs.owner.to_bytes(), + )?; + + Ok(()) +} diff --git a/programs/compressed-token/program/src/create_spl_mint/accounts.rs b/programs/compressed-token/program/src/create_spl_mint/accounts.rs new file mode 100644 index 0000000000..721ded6bad --- /dev/null +++ b/programs/compressed-token/program/src/create_spl_mint/accounts.rs @@ -0,0 +1,86 @@ +use anchor_lang::solana_program::program_error::ProgramError; +use light_account_checks::checks::{check_mut, check_signer}; +use pinocchio::account_info::AccountInfo; +use 
crate::shared::AccountIterator; + +pub struct CreateSplMintAccounts<'info> { + pub fee_payer: &'info AccountInfo, + pub authority: &'info AccountInfo, + pub mint: &'info AccountInfo, + pub mint_signer: &'info AccountInfo, + pub token_pool_pda: &'info AccountInfo, + pub token_program: &'info AccountInfo, + pub cpi_authority_pda: &'info AccountInfo, + pub light_system_program: &'info AccountInfo, + pub registered_program_pda: &'info AccountInfo, + pub noop_program: &'info AccountInfo, + pub account_compression_authority: &'info AccountInfo, + pub account_compression_program: &'info AccountInfo, + pub system_program: &'info AccountInfo, + pub self_program: &'info AccountInfo, + pub in_merkle_tree: &'info AccountInfo, + pub in_output_queue: &'info AccountInfo, + pub out_output_queue: &'info AccountInfo, +} + +impl<'info> CreateSplMintAccounts<'info> { + + pub fn validate_and_parse( + accounts: &'info [AccountInfo], + ) -> Result { + if accounts.len() < 17 { + return Err(ProgramError::NotEnoughAccountKeys); + } + + let mut iter = AccountIterator::new(accounts); + + // Static non-CPI accounts first + let authority = iter.next()?; + let mint = iter.next()?; + let mint_signer = iter.next()?; + let token_pool_pda = iter.next()?; + let token_program = iter.next()?; + let light_system_program = iter.next()?; + + // CPI accounts in exact order expected by light-system-program + let fee_payer = iter.next()?; + let cpi_authority_pda = iter.next()?; + let registered_program_pda = iter.next()?; + let noop_program = iter.next()?; + let account_compression_authority = iter.next()?; + let account_compression_program = iter.next()?; + let self_program = iter.next()?; + + let system_program = iter.next()?; + let in_merkle_tree = iter.next()?; + let in_output_queue = iter.next()?; + let out_output_queue = iter.next()?; + + // Validate fee_payer: must be signer and mutable + check_signer(fee_payer).map_err(ProgramError::from)?; + check_mut(fee_payer).map_err(ProgramError::from)?; + + // 
Validate authority: must be signer + check_signer(authority).map_err(ProgramError::from)?; + + Ok(CreateSplMintAccounts { + fee_payer, + authority, + mint, + mint_signer, + token_pool_pda, + token_program, + cpi_authority_pda, + light_system_program, + registered_program_pda, + noop_program, + account_compression_authority, + account_compression_program, + system_program, + self_program, + in_merkle_tree, + in_output_queue, + out_output_queue, + }) + } +} diff --git a/programs/compressed-token/program/src/create_spl_mint/instructions.rs b/programs/compressed-token/program/src/create_spl_mint/instructions.rs new file mode 100644 index 0000000000..de9efece68 --- /dev/null +++ b/programs/compressed-token/program/src/create_spl_mint/instructions.rs @@ -0,0 +1,16 @@ +use crate::mint_to_compressed::instructions::CompressedMintInputs; +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::{instruction_data::compressed_proof::CompressedProof, Pubkey}; +use light_zero_copy::ZeroCopy; + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct CreateSplMintInstructionData { + pub mint_bump: u8, + pub token_pool_bump: u8, + // TODO: remove decimals, duplicate input + pub decimals: u8, + pub mint_authority: Pubkey, + pub compressed_mint_inputs: CompressedMintInputs, + pub freeze_authority: Option, + pub proof: Option, +} diff --git a/programs/compressed-token/program/src/create_spl_mint/mod.rs b/programs/compressed-token/program/src/create_spl_mint/mod.rs new file mode 100644 index 0000000000..c31719e252 --- /dev/null +++ b/programs/compressed-token/program/src/create_spl_mint/mod.rs @@ -0,0 +1,3 @@ +pub mod accounts; +pub mod instructions; +pub mod processor; \ No newline at end of file diff --git a/programs/compressed-token/program/src/create_spl_mint/processor.rs b/programs/compressed-token/program/src/create_spl_mint/processor.rs new file mode 100644 index 0000000000..bb3a4ecf24 --- /dev/null +++ 
b/programs/compressed-token/program/src/create_spl_mint/processor.rs @@ -0,0 +1,463 @@ +use anchor_lang::solana_program::{ + program_error::ProgramError, rent::Rent, system_instruction, sysvar::Sysvar, +}; +use arrayvec::ArrayVec; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut, ZeroCopyNew}; +use pinocchio::{account_info::AccountInfo, pubkey::Pubkey}; +use spl_token::solana_program::log::sol_log_compute_units; + +use crate::{ + constants::POOL_SEED, + create_spl_mint::{ + accounts::CreateSplMintAccounts, + instructions::{CreateSplMintInstructionData, ZCreateSplMintInstructionData}, + }, + shared::cpi::execute_cpi_invoke, +}; +// TODO: check and handle extensions +pub fn process_create_spl_mint( + program_id: Pubkey, + accounts: &[AccountInfo], + instruction_data: &[u8], +) -> Result<(), ProgramError> { + sol_log_compute_units(); + + // Parse instruction data using zero-copy + let (parsed_instruction_data, _) = CreateSplMintInstructionData::zero_copy_at(instruction_data) + .map_err(|_| ProgramError::InvalidInstructionData)?; + + sol_log_compute_units(); + + // Validate and parse accounts + let validated_accounts = CreateSplMintAccounts::validate_and_parse(accounts)?; + + // Verify mint PDA matches the spl_mint field in compressed mint inputs + let expected_mint: [u8; 32] = parsed_instruction_data + .compressed_mint_inputs + .compressed_mint_input + .spl_mint + .into(); + if validated_accounts.mint.key() != &expected_mint { + return Err(ProgramError::InvalidAccountData); + } + + // Create the mint account manually (PDA derived from our program, owned by token program) + create_mint_account( + &validated_accounts, + &program_id, + parsed_instruction_data.mint_bump, + )?; + + // Initialize the mint account using Token-2022's initialize_mint2 instruction + initialize_mint_account(&validated_accounts, &parsed_instruction_data)?; + + // Create the token pool account manually (PDA derived from our program, owned by token program) + 
create_token_pool_account_manual(&validated_accounts, &program_id)?; + + // Initialize the token pool account + initialize_token_pool_account(&validated_accounts)?; + + // Mint the existing supply to the token pool if there's any supply + if parsed_instruction_data + .compressed_mint_inputs + .compressed_mint_input + .supply + > 0 + { + mint_existing_supply_to_pool(&validated_accounts, &parsed_instruction_data)?; + } + // Update the compressed mint to mark it as is_decompressed = true + update_compressed_mint_to_decompressed( + accounts, + &validated_accounts, + &parsed_instruction_data, + &program_id, + )?; + + sol_log_compute_units(); + Ok(()) +} + +fn update_compressed_mint_to_decompressed<'info>( + all_accounts: &'info [AccountInfo], + accounts: &CreateSplMintAccounts<'info>, + instruction_data: &ZCreateSplMintInstructionData, + program_id: &pinocchio::pubkey::Pubkey, +) -> Result<(), ProgramError> { + use crate::mint::{ + input::create_input_compressed_mint_account, output::create_output_compressed_mint_account, + }; + use crate::shared::{ + context::TokenContext, + cpi_bytes_size::{ + allocate_invoke_with_read_only_cpi_bytes, cpi_bytes_config, CpiConfigInput, + }, + }; + use light_compressed_account::instruction_data::with_readonly::InstructionDataInvokeCpiWithReadOnly; + + // Build configuration for CPI instruction data - 1 input, 1 output, with optional proof + let config_input = CpiConfigInput { + input_accounts: ArrayVec::new(), + output_accounts: ArrayVec::new(), + has_proof: instruction_data.proof.is_some(), + compressed_mint: true, + compressed_mint_with_freeze_authority: instruction_data.freeze_authority.is_some(), + }; + + let config = cpi_bytes_config(config_input); + let mut cpi_bytes = allocate_invoke_with_read_only_cpi_bytes(&config); + + let (mut cpi_instruction_struct, _) = + InstructionDataInvokeCpiWithReadOnly::new_zero_copy(&mut cpi_bytes[8..], config) + .map_err(ProgramError::from)?; + + cpi_instruction_struct.bump = 
crate::LIGHT_CPI_SIGNER.bump; + cpi_instruction_struct.invoking_program_id = crate::LIGHT_CPI_SIGNER.program_id.into(); + + let mut context = TokenContext::new(); + let hashed_mint_authority = context.get_or_hash_pubkey(accounts.authority.key()); + + // Process input compressed mint account (before is_decompressed = true) + create_input_compressed_mint_account( + &mut cpi_instruction_struct.input_compressed_accounts[0], + &mut context, + &instruction_data.compressed_mint_inputs, + &hashed_mint_authority, + )?; + + // Process output compressed mint account (with is_decompressed = true) + let mint_inputs = &instruction_data + .compressed_mint_inputs + .compressed_mint_input; + let mint_pda = mint_inputs.spl_mint; + let decimals = instruction_data.decimals; + let freeze_authority = if mint_inputs.freeze_authority_is_set() { + Some(mint_inputs.freeze_authority) + } else { + None + }; + + let mint_config = crate::mint::state::CompressedMintConfig { + mint_authority: (true, ()), + freeze_authority: (mint_inputs.freeze_authority_is_set(), ()), + }; + let compressed_account_address = *instruction_data.compressed_mint_inputs.address; + let supply = mint_inputs.supply; // Keep same supply, just mark as decompressed + create_output_compressed_mint_account( + &mut cpi_instruction_struct.output_compressed_accounts[0], + mint_pda, + decimals, + freeze_authority, + Some(instruction_data.mint_authority), + supply, + &program_id.into(), + mint_config, + compressed_account_address, + instruction_data + .compressed_mint_inputs + .output_merkle_tree_index, + )?; + + // Set proof data if provided + if let Some(instruction_proof) = &instruction_data.proof { + if let Some(proof) = cpi_instruction_struct.proof.as_deref_mut() { + proof.a = instruction_proof.a; + proof.b = instruction_proof.b; + proof.c = instruction_proof.c; + } + } + + // Override the output compressed mint to set is_decompressed = true + // The create_output_compressed_mint_account function sets is_decompressed = false 
by default + { + let output_account = &mut cpi_instruction_struct.output_compressed_accounts[0]; + if let Some(data) = output_account.compressed_account.data.as_mut() { + let (mut compressed_mint, _) = + crate::mint::state::CompressedMint::zero_copy_at_mut(data.data) + .map_err(ProgramError::from)?; + compressed_mint.is_decompressed = 1; // Override to mark as decompressed (1 = true) + + // Recalculate hash with is_decompressed = true + *data.data_hash = compressed_mint + .hash() + .map_err(|_| ProgramError::InvalidAccountData)?; + } + } + + // Extract tree accounts for the generalized CPI call + let tree_accounts = [ + accounts.in_merkle_tree.key(), + accounts.in_output_queue.key(), + accounts.out_output_queue.key(), + ]; + // Execute CPI to light system program to update the compressed mint + execute_cpi_invoke( + &all_accounts[6..], // Skip first 6 non-CPI accounts + cpi_bytes, + &tree_accounts, + false, // no sol_pool_pda + None, // no cpi_context_account + )?; + + Ok(()) +} + +/// Creates the mint account manually as a PDA derived from our program but owned by the token program +fn create_mint_account( + accounts: &CreateSplMintAccounts<'_>, + program_id: &pinocchio::pubkey::Pubkey, + mint_bump: u8, +) -> Result<(), ProgramError> { + let mint_account_size = 82; // Size of Token-2022 Mint account + let rent = Rent::get()?; + let lamports = rent.minimum_balance(mint_account_size); + + // Derive the mint PDA seeds using provided bump + let program_id_pubkey = solana_pubkey::Pubkey::new_from_array(*program_id); + let expected_mint = solana_pubkey::Pubkey::create_program_address( + &[ + b"compressed_mint", + accounts.mint_signer.key().as_ref(), + &[mint_bump], + ], + &program_id_pubkey, + ) + .map_err(|_| ProgramError::InvalidAccountData)?; + + // Verify the provided mint account matches the expected PDA + if accounts.mint.key() != &expected_mint.to_bytes() { + return Err(ProgramError::InvalidAccountData); + } + + use pinocchio::instruction::{Seed, Signer}; + let 
mint_signer_key = accounts.mint_signer.key(); + let bump_bytes = [mint_bump]; + let seed_array = [ + Seed::from(b"compressed_mint"), + Seed::from(mint_signer_key.as_ref()), + Seed::from(bump_bytes.as_ref()), + ]; + let signer = Signer::from(&seed_array); + + // Create account owned by token program but derived from our program + let fee_payer_pubkey = solana_pubkey::Pubkey::new_from_array(*accounts.fee_payer.key()); + let mint_pubkey = solana_pubkey::Pubkey::new_from_array(*accounts.mint.key()); + let token_program_pubkey = solana_pubkey::Pubkey::new_from_array(*accounts.token_program.key()); + let create_account_ix = system_instruction::create_account( + &fee_payer_pubkey, + &mint_pubkey, + lamports, + mint_account_size as u64, + &token_program_pubkey, // Owned by token program + ); + + let pinocchio_instruction = pinocchio::instruction::Instruction { + program_id: &create_account_ix.program_id.to_bytes(), + accounts: &[ + pinocchio::instruction::AccountMeta::new(accounts.fee_payer.key(), true, true), + pinocchio::instruction::AccountMeta::new(accounts.mint.key(), true, true), + pinocchio::instruction::AccountMeta::readonly(accounts.system_program.key()), + ], + data: &create_account_ix.data, + }; + + match pinocchio::program::invoke_signed( + &pinocchio_instruction, + &[accounts.fee_payer, accounts.mint, accounts.system_program], + &[signer], // Signed with our program's PDA seeds + ) { + Ok(()) => {} + Err(e) => { + return Err(ProgramError::Custom(u64::from(e) as u32)); + } + } + + Ok(()) +} + +/// Initializes the mint account using Token-2022's initialize_mint2 instruction +fn initialize_mint_account( + accounts: &CreateSplMintAccounts<'_>, + instruction_data: &ZCreateSplMintInstructionData, +) -> Result<(), ProgramError> { + let spl_ix = spl_token_2022::instruction::initialize_mint2( + &solana_pubkey::Pubkey::new_from_array(*accounts.token_program.key()), + &solana_pubkey::Pubkey::new_from_array(*accounts.mint.key()), + 
&solana_pubkey::Pubkey::new_from_array(instruction_data.mint_authority.into()), + instruction_data + .freeze_authority + .as_ref() + .map(|f| solana_pubkey::Pubkey::new_from_array((**f).into())) + .as_ref(), + instruction_data.decimals, + )?; + + let initialize_mint_ix = pinocchio::instruction::Instruction { + program_id: accounts.token_program.key(), + accounts: &[pinocchio::instruction::AccountMeta::new( + accounts.mint.key(), + true, // is_writable: true (we're initializing the mint) + false, + )], + data: &spl_ix.data, + }; + + match pinocchio::program::invoke(&initialize_mint_ix, &[accounts.mint]) { + Ok(()) => {} + Err(e) => { + return Err(ProgramError::Custom(u64::from(e) as u32)); + } + } + + Ok(()) +} + +/// Creates the token pool account manually as a PDA derived from our program but owned by the token program +fn create_token_pool_account_manual( + accounts: &CreateSplMintAccounts<'_>, + program_id: &pinocchio::pubkey::Pubkey, +) -> Result<(), ProgramError> { + let token_account_size = 165; // Size of Token account + let rent = Rent::get()?; + let lamports = rent.minimum_balance(token_account_size); + + // Derive the token pool PDA seeds and bump + let mint_key = accounts.mint.key(); + let program_id_pubkey = solana_pubkey::Pubkey::new_from_array(*program_id); + let (expected_token_pool, bump) = solana_pubkey::Pubkey::find_program_address( + &[POOL_SEED, mint_key.as_ref()], + &program_id_pubkey, + ); + + // Verify the provided token pool account matches the expected PDA + if accounts.token_pool_pda.key() != &expected_token_pool.to_bytes() { + return Err(ProgramError::InvalidAccountData); + } + + use pinocchio::instruction::{Seed, Signer}; + let bump_bytes = [bump]; + let seed_array = [ + Seed::from(POOL_SEED), + Seed::from(mint_key.as_ref()), + Seed::from(bump_bytes.as_ref()), + ]; + let signer = Signer::from(&seed_array); + + // Create account owned by token program but derived from our program + let fee_payer_pubkey = 
solana_pubkey::Pubkey::new_from_array(*accounts.fee_payer.key()); + let token_pool_pubkey = solana_pubkey::Pubkey::new_from_array(*accounts.token_pool_pda.key()); + let token_program_pubkey = solana_pubkey::Pubkey::new_from_array(*accounts.token_program.key()); + let create_account_ix = system_instruction::create_account( + &fee_payer_pubkey, + &token_pool_pubkey, + lamports, + token_account_size as u64, + &token_program_pubkey, // Owned by token program + ); + + let pinocchio_instruction = pinocchio::instruction::Instruction { + program_id: &create_account_ix.program_id.to_bytes(), + accounts: &[ + pinocchio::instruction::AccountMeta::new(accounts.fee_payer.key(), true, true), + pinocchio::instruction::AccountMeta::new(accounts.token_pool_pda.key(), true, true), + pinocchio::instruction::AccountMeta::readonly(accounts.system_program.key()), + ], + data: &create_account_ix.data, + }; + + match pinocchio::program::invoke_signed( + &pinocchio_instruction, + &[ + accounts.fee_payer, + accounts.token_pool_pda, + accounts.system_program, + ], + &[signer], // Signed with our program's PDA seeds + ) { + Ok(()) => {} + Err(e) => { + return Err(ProgramError::Custom(u64::from(e) as u32)); + } + } + + Ok(()) +} + +/// Initializes the token pool account (assumes account already exists) +fn initialize_token_pool_account(accounts: &CreateSplMintAccounts<'_>) -> Result<(), ProgramError> { + let initialize_account_ix = pinocchio::instruction::Instruction { + program_id: accounts.token_program.key(), + accounts: &[ + pinocchio::instruction::AccountMeta::new(accounts.token_pool_pda.key(), true, false), // writable=true for initialization + pinocchio::instruction::AccountMeta::readonly(accounts.mint.key()), + ], + data: &spl_token_2022::instruction::initialize_account3( + &solana_pubkey::Pubkey::new_from_array(*accounts.token_program.key()), + &solana_pubkey::Pubkey::new_from_array(*accounts.token_pool_pda.key()), + &solana_pubkey::Pubkey::new_from_array(*accounts.mint.key()), + 
&solana_pubkey::Pubkey::new_from_array(*accounts.cpi_authority_pda.key()), + )? + .data, + }; + + match pinocchio::program::invoke( + &initialize_account_ix, + &[accounts.token_pool_pda, accounts.mint], + ) { + Ok(()) => {} + Err(e) => { + return Err(ProgramError::Custom(u64::from(e) as u32)); + } + } + Ok(()) +} + +/// Mints the existing supply from compressed mint to the token pool +fn mint_existing_supply_to_pool( + accounts: &CreateSplMintAccounts<'_>, + instruction_data: &ZCreateSplMintInstructionData, +) -> Result<(), ProgramError> { + // Only mint if the authority matches + if accounts.authority.key() != &instruction_data.mint_authority.to_bytes() { + return Err(ProgramError::InvalidAccountData); + } + + let supply = instruction_data + .compressed_mint_inputs + .compressed_mint_input + .supply + .into(); + + // Create SPL mint_to instruction and use its account structure + let spl_mint_to_ix = spl_token_2022::instruction::mint_to( + &solana_pubkey::Pubkey::new_from_array(*accounts.token_program.key()), + &solana_pubkey::Pubkey::new_from_array(*accounts.mint.key()), + &solana_pubkey::Pubkey::new_from_array(*accounts.token_pool_pda.key()), + &solana_pubkey::Pubkey::new_from_array(*accounts.authority.key()), + &[], + supply, + )?; + + // Mint tokens to the pool + let mint_to_ix = pinocchio::instruction::Instruction { + program_id: accounts.token_program.key(), + accounts: &[ + pinocchio::instruction::AccountMeta::new(accounts.mint.key(), true, false), // writable + pinocchio::instruction::AccountMeta::new(accounts.token_pool_pda.key(), true, false), // writable + pinocchio::instruction::AccountMeta::new(accounts.authority.key(), false, true), // signer + ], + data: &spl_mint_to_ix.data, + }; + + match pinocchio::program::invoke( + &mint_to_ix, + &[accounts.mint, accounts.token_pool_pda, accounts.authority], + ) { + Ok(()) => {} + Err(e) => { + return Err(ProgramError::Custom(u64::from(e) as u32)); + } + } + + Ok(()) +} diff --git 
a/programs/compressed-token/program/src/create_token_account/accounts.rs b/programs/compressed-token/program/src/create_token_account/accounts.rs new file mode 100644 index 0000000000..f67503695c --- /dev/null +++ b/programs/compressed-token/program/src/create_token_account/accounts.rs @@ -0,0 +1,27 @@ +use anchor_lang::prelude::ProgramError; +use light_account_checks::checks::{check_mut, check_non_mut}; +use pinocchio::account_info::AccountInfo; + +pub struct CreateTokenAccountAccounts<'a> { + pub token_account: &'a AccountInfo, + pub mint: &'a AccountInfo, +} + +impl<'a> CreateTokenAccountAccounts<'a> { + pub fn new(accounts: &'a [AccountInfo]) -> Result { + Ok(Self { + token_account: &accounts[0], + mint: &accounts[1], + }) + } + + pub fn get_checked(accounts: &'a [AccountInfo]) -> Result { + let accounts_struct = Self::new(accounts)?; + + // Basic validations using light_account_checks + check_mut(accounts_struct.token_account)?; + check_non_mut(accounts_struct.mint)?; + + Ok(accounts_struct) + } +} \ No newline at end of file diff --git a/programs/compressed-token/program/src/create_token_account/instruction_data.rs b/programs/compressed-token/program/src/create_token_account/instruction_data.rs new file mode 100644 index 0000000000..98798be397 --- /dev/null +++ b/programs/compressed-token/program/src/create_token_account/instruction_data.rs @@ -0,0 +1,9 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::Pubkey; +use light_zero_copy::ZeroCopy; + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct CreateTokenAccountInstructionData { + /// The owner of the token account + pub owner: Pubkey, +} \ No newline at end of file diff --git a/programs/compressed-token/program/src/create_token_account/mod.rs b/programs/compressed-token/program/src/create_token_account/mod.rs new file mode 100644 index 0000000000..e9133a6782 --- /dev/null +++ b/programs/compressed-token/program/src/create_token_account/mod.rs @@ 
-0,0 +1,5 @@ +pub mod accounts; +pub mod instruction_data; +pub mod processor; + +pub use processor::process_create_token_account; \ No newline at end of file diff --git a/programs/compressed-token/program/src/create_token_account/processor.rs b/programs/compressed-token/program/src/create_token_account/processor.rs new file mode 100644 index 0000000000..0fb38eb3f0 --- /dev/null +++ b/programs/compressed-token/program/src/create_token_account/processor.rs @@ -0,0 +1,30 @@ +use anchor_lang::prelude::ProgramError; +use light_zero_copy::borsh::Deserialize; +use pinocchio::account_info::AccountInfo; + +use super::{ + accounts::CreateTokenAccountAccounts, instruction_data::CreateTokenAccountInstructionData, +}; +use crate::shared::initialize_token_account::initialize_token_account; + +/// Process the create token account instruction +pub fn process_create_token_account<'info>( + account_infos: &'info [AccountInfo], + instruction_data: &[u8], +) -> Result<(), ProgramError> { + // Parse instruction data using zero-copy + let (inputs, _) = CreateTokenAccountInstructionData::zero_copy_at(instruction_data) + .map_err(ProgramError::from)?; + + // Validate and get accounts + let accounts = CreateTokenAccountAccounts::get_checked(account_infos)?; + + // Initialize the token account (assumes account already exists and is owned by our program) + initialize_token_account( + accounts.token_account, + accounts.mint.key(), + &inputs.owner.to_bytes(), + )?; + + Ok(()) +} diff --git a/programs/compressed-token/program/src/extensions/metadata_pointer.rs b/programs/compressed-token/program/src/extensions/metadata_pointer.rs new file mode 100644 index 0000000000..d8a72ffd93 --- /dev/null +++ b/programs/compressed-token/program/src/extensions/metadata_pointer.rs @@ -0,0 +1,45 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::Pubkey; +use light_sdk::LightHasher; +use light_zero_copy::ZeroCopy; + +/// Metadata pointer extension data for compressed mints. 
+#[derive(Debug, Clone, PartialEq, BorshSerialize, ZeroCopy, BorshDeserialize, LightHasher)] +pub struct MetadataPointer { + /// Authority that can set the metadata address + #[hash] + pub authority: Option, + /// Compressed address that holds the metadata (in token 22) + #[hash] + // TODO: implement manually, because there is no need to hash the compressed metadata_address + pub metadata_address: Option, +} + +#[derive( + Debug, PartialEq, Default, Clone, Copy, Eq, BorshSerialize, BorshDeserialize, ZeroCopy, +)] +pub struct NewAddressParamsAssignedPackedWithAddress { + pub address: [u8; 32], + pub seed: [u8; 32], + pub address_merkle_tree_account_index: u8, + pub address_merkle_tree_root_index: u16, +} + +impl MetadataPointer { + /// Validate metadata pointer - at least one field must be provided + pub fn validate(&self) -> Result<(), anchor_lang::prelude::ProgramError> { + if self.authority.is_none() && self.metadata_address.is_none() { + return Err(anchor_lang::prelude::ProgramError::InvalidInstructionData); + } + Ok(()) + } +} + +/// Instruction data for initializing metadata pointer +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct InitializeMetadataPointerInstructionData { + /// The authority that can set the metadata address + pub authority: Option, + /// The account address that holds the metadata + pub metadata_address_params: Option, +} diff --git a/programs/compressed-token/program/src/extensions/mod.rs b/programs/compressed-token/program/src/extensions/mod.rs new file mode 100644 index 0000000000..a2ae3cb190 --- /dev/null +++ b/programs/compressed-token/program/src/extensions/mod.rs @@ -0,0 +1,26 @@ +use anchor_compressed_token::ErrorCode; + +pub mod metadata_pointer; +pub mod processor; +pub mod token_metadata; + +pub enum ExtensionType { + /// Mint contains a pointer to another account (or the same account) that + /// holds metadata + MetadataPointer, + /// Mint contains token-metadata + TokenMetadata, +} +// use 
spl_token_2022::extension::ExtensionType SplExtensionType; + +impl TryFrom for ExtensionType { + type Error = ErrorCode; + + fn try_from(value: u8) -> Result { + match value { + 18 => Ok(ExtensionType::MetadataPointer), + 19 => Ok(ExtensionType::TokenMetadata), + _ => Err(ErrorCode::InvalidExtensionType), + } + } +} diff --git a/programs/compressed-token/program/src/extensions/processor.rs b/programs/compressed-token/program/src/extensions/processor.rs new file mode 100644 index 0000000000..6e2907c2d3 --- /dev/null +++ b/programs/compressed-token/program/src/extensions/processor.rs @@ -0,0 +1,123 @@ +use anchor_lang::prelude::ProgramError; +use borsh::BorshDeserialize; +use light_compressed_account::{ + compressed_account::ZCompressedAccountDataMut, + instruction_data::with_readonly::ZInstructionDataInvokeCpiWithReadOnlyMut, +}; + +use crate::{ + extensions::{ + metadata_pointer::InitializeMetadataPointerInstructionData, + token_metadata::{TokenMetadata, TOKEN_METADATA_DISCRIMINATOR}, + ExtensionType, + }, + mint::instructions::ZExtensionInstructionData, +}; + +// Applying extension(s) to compressed accounts. +pub fn process_create_extensions<'a>( + extensions: &[ZExtensionInstructionData], + cpi_data: &mut ZInstructionDataInvokeCpiWithReadOnlyMut<'a>, + mint_data_len: usize, +) -> Result<(), ProgramError> { + for extension in extensions { + match ExtensionType::try_from(extension.extension_type).unwrap() { + ExtensionType::MetadataPointer => { + // TODO: add a second new address params for the other address. 
+ + // deserialize metadata pointer ix data + let has_address = create_metadata_pointer(extension.data, cpi_data, mint_data_len)?; + // only go ahed if has address, probably duplicate + if has_address { + create_token_metadata_account( + extension.data, + cpi_data.output_compressed_accounts[0] + .compressed_account + .data + .as_mut() + .unwrap(), + )?; + } + } + _ => return Err(ProgramError::InvalidInstructionData), + } + } + Ok(()) +} + +// We need to return the hash to add it to the overall output hash. +// TODO: remove the hash value and possibly the len of the instruction data +// TODO: do compatibility token 22 deserialization for all accounts. +// TODO: fix +fn create_metadata_pointer<'a>( + instruction_data: &[u8], + cpi_instruction_struct: &mut ZInstructionDataInvokeCpiWithReadOnlyMut<'a>, + mint_data_len: usize, +) -> Result { + use light_zero_copy::borsh::Deserialize; + // 1. Deserialize the metadata pointer instruction data + let (metadata_pointer_data, _) = + InitializeMetadataPointerInstructionData::zero_copy_at(instruction_data) + .map_err(|_| ProgramError::InvalidInstructionData)?; + if let Some(metadata_address_params) = metadata_pointer_data.metadata_address_params.as_ref() { + **cpi_instruction_struct.output_compressed_accounts[1] + .compressed_account + .address + .as_mut() + .unwrap() = metadata_address_params.address; + + cpi_instruction_struct.new_address_params[1].seed = metadata_address_params.seed; + cpi_instruction_struct.new_address_params[1].address_merkle_tree_root_index = + metadata_address_params.address_merkle_tree_root_index; + cpi_instruction_struct.new_address_params[1].assigned_account_index = 1; + // Note we can skip address derivation since we are assigning it to the account in index 0. 
+ cpi_instruction_struct.new_address_params[1].assigned_to_account = 1; + cpi_instruction_struct.new_address_params[1].address_merkle_tree_account_index = + metadata_address_params.address_merkle_tree_account_index; + } + + let cpi_data = cpi_instruction_struct.output_compressed_accounts[1] + .compressed_account + .data + .as_mut() + .ok_or(ProgramError::InvalidInstructionData)?; + + if metadata_pointer_data.authority.is_none() + && metadata_pointer_data.metadata_address_params.is_none() + { + return Err(anchor_lang::prelude::ProgramError::InvalidInstructionData); + } + let start_offset = mint_data_len; + let mut end_offset = start_offset; + if metadata_pointer_data.authority.is_some() { + end_offset += 33; + } else { + end_offset += 1; + } + let hash_address = metadata_pointer_data.metadata_address_params.is_some(); + if metadata_pointer_data.metadata_address_params.is_some() { + end_offset += 33; + } else { + end_offset += 1; + } + // TODO: double test this is risky but should be ok + // The layout is also Option<[u8;32]>, Option<[u8;32], ..> but we cut off after 32 bytes. 
+ cpi_data.data[start_offset..end_offset].copy_from_slice(&instruction_data); + + Ok(hash_address) +} + +// Could be ok +fn create_token_metadata_account<'a>( + mut instruction_data: &[u8], + cpi_data: &mut ZCompressedAccountDataMut<'a>, +) -> Result<(), ProgramError> { + // TODO: use zero copy (need to add string support or manual impl) + let token_metadata = TokenMetadata::deserialize(&mut instruction_data) + .map_err(|_| ProgramError::InvalidInstructionData)?; + let hash = TokenMetadata::hash(&token_metadata)?; + *cpi_data.data_hash = hash; + cpi_data.discriminator = TOKEN_METADATA_DISCRIMINATOR; + (*cpi_data.data).copy_from_slice(instruction_data); + Ok(()) +} diff --git a/programs/compressed-token/program/src/extensions/token_metadata.rs b/programs/compressed-token/program/src/extensions/token_metadata.rs new file mode 100644 index 0000000000..b3463f6de7 --- /dev/null +++ b/programs/compressed-token/program/src/extensions/token_metadata.rs @@ -0,0 +1,139 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::Pubkey; +use light_hasher::{ + hash_to_field_size::hashv_to_bn254_field_size_be_const_array, DataHasher, Hasher, HasherError, + Keccak, Poseidon, Sha256, +}; +use light_sdk::LightHasher; +use light_zero_copy::ZeroCopy; + +// TODO: decide whether to keep Shaflat +pub enum Version { + Poseidon, + Sha256, + Keccak256, + Sha256Flat, +} +// Same as extesion type enum TODO: check token 2022 equivalent. 
+pub const TOKEN_METADATA_DISCRIMINATOR: [u8; 8] = [0, 0, 0, 0, 0, 0, 0, 19]; + +impl TryFrom for Version { + type Error = HasherError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Version::Poseidon), + 1 => Ok(Version::Sha256), + 2 => Ok(Version::Keccak256), + 3 => Ok(Version::Sha256Flat), + // TODO: use real error + _ => Err(HasherError::InvalidInputLength(value as usize, 3)), + } + } +} +// TODO: impl string for zero copy +// TODO: test deserialization equivalence +/// Used for onchain serialization +#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)] +pub struct TokenMetadata { + /// The authority that can sign to update the metadata + pub update_authority: Option, + /// The associated mint, used to counter spoofing to be sure that metadata + /// belongs to a particular mint + pub mint: Pubkey, + pub metadata: Metadata, + /// Any additional metadata about the token as key-value pairs. The program + /// must avoid storing the same key twice. + pub additional_metadata: Vec, + /// 0: Poseidon, 1: Sha256, 2: Keccak256, 3: Sha256Flat + pub version: u8, +} + +impl TokenMetadata { + pub fn hash(&self) -> Result<[u8; 32], HasherError> { + match Version::try_from(self.version)? { + Version::Poseidon => ::hash::(self), + Version::Sha256 => ::hash::(self), + Version::Keccak256 => ::hash::(self), + Version::Sha256Flat => self.sha_flat(), + } + } + fn sha_flat(&self) -> Result<[u8; 32], HasherError> { + use borsh::BorshSerialize; + let vec = self.try_to_vec().map_err(|_| HasherError::BorshError)?; + Sha256::hash(vec.as_slice()) + } +} + +impl DataHasher for TokenMetadata { + fn hash(&self) -> Result<[u8; 32], HasherError> { + let mut vec = [[0u8; 32]; 5]; + let mut slice_vec: [&[u8]; 5] = [&[]; 5]; + if let Some(update_authority) = self.update_authority { + vec[0].copy_from_slice( + hashv_to_bn254_field_size_be_const_array::<2>(&[&update_authority.to_bytes()])? 
+ .as_slice(), + ); + } + + vec[1] = hashv_to_bn254_field_size_be_const_array::<2>(&[&self.mint.to_bytes()])?; + vec[2] = self.metadata.hash::()?; + + for additional_metadata in &self.additional_metadata { + // TODO: add check is poseidon and throw meaningful error. + vec[3] = H::hashv(&[ + vec[3].as_slice(), + additional_metadata.key.as_bytes(), + additional_metadata.value.as_bytes(), + ])?; + } + vec[4][31] = self.version; + + slice_vec[0] = vec[0].as_slice(); + slice_vec[1] = vec[1].as_slice(); + slice_vec[2] = vec[2].as_slice(); + slice_vec[3] = vec[3].as_slice(); + + slice_vec[4] = vec[4].as_slice(); + if vec[4] != [0u8; 32] { + H::hashv(&slice_vec[..4]) + } else { + H::hashv(slice_vec.as_slice()) + } + } +} + +// TODO: if version 0 we check all string len for less than 31 bytes +#[derive(Debug, LightHasher, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)] +pub struct Metadata { + /// The longer name of the token + pub name: String, + /// The shortened symbol for the token + pub symbol: String, + /// The URI pointing to richer metadata + pub uri: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)] +pub struct AdditionalMetadata { + /// The key of the metadata + pub key: String, + /// The value of the metadata + pub value: String, +} + +// Small instruction data input. +// TODO: impl hash fn that is consistent with full hash fn +pub struct SmallTokenMetadata { + /// The authority that can sign to update the metadata + pub update_authority: Option, + /// The associated mint, used to counter spoofing to be sure that metadata + /// belongs to a particular mint + pub mint: Pubkey, + pub metadata_hash: [u8; 32], + /// Any additional metadata about the token as key-value pairs. The program + /// must avoid storing the same key twice. 
+ pub additional_metadata: Option>, + /// 0: Poseidon, 1: Sha256, 2: Keccak256, 3: Sha256Flat + pub version: u8, +} diff --git a/programs/compressed-token/program/src/lib.rs b/programs/compressed-token/program/src/lib.rs new file mode 100644 index 0000000000..578488c4f5 --- /dev/null +++ b/programs/compressed-token/program/src/lib.rs @@ -0,0 +1,198 @@ +use std::mem::ManuallyDrop; + +use anchor_lang::solana_program::program_error::ProgramError; + +use light_sdk::{cpi::CpiSigner, derive_light_cpi_signer}; +use pinocchio::account_info::AccountInfo; +use spl_token::instruction::TokenInstruction; + +pub mod close_token_account; +pub mod create_associated_token_account; +pub mod create_spl_mint; +pub mod create_token_account; +pub mod extensions; +pub mod mint; +pub mod mint_to_compressed; +pub mod multi_transfer; +pub mod shared; + +// Reexport the wrapped anchor program. +pub use ::anchor_compressed_token::*; +use close_token_account::processor::process_close_token_account; +use create_associated_token_account::processor::process_create_associated_token_account; +use create_spl_mint::processor::process_create_spl_mint; +use create_token_account::processor::process_create_token_account; +use mint::processor::process_create_compressed_mint; +use mint_to_compressed::processor::process_mint_to_compressed; + +pub const LIGHT_CPI_SIGNER: CpiSigner = + derive_light_cpi_signer!("cTokenmWW8bLPjZEBAUgYy3zKxQZW6VKi7bqNFEVv3m"); + +pub const MAX_ACCOUNTS: usize = 30; + +// Start light token instructions at 100 to skip spl-token program instrutions. +// When adding new instructions check anchor discriminators for collisions! 
+#[repr(u8)] +pub enum InstructionType { + DecompressedTransfer = 3, + CloseTokenAccount = 9, // SPL Token CloseAccount + CreateCompressedMint = 100, + MintToCompressed = 101, + CreateSplMint = 102, + CreateAssociatedTokenAccount = 103, + MultiTransfer = 104, + CreateTokenAccount = 18, // SPL Token InitializeAccount3 + Other, +} + +impl From for InstructionType { + fn from(value: u8) -> Self { + match value { + 3 => InstructionType::DecompressedTransfer, + 9 => InstructionType::CloseTokenAccount, + 100 => InstructionType::CreateCompressedMint, + 101 => InstructionType::MintToCompressed, + 102 => InstructionType::CreateSplMint, + 103 => InstructionType::CreateAssociatedTokenAccount, + 104 => InstructionType::MultiTransfer, + 18 => InstructionType::CreateTokenAccount, + _ => InstructionType::Other, + } + } +} + +#[cfg(not(feature = "cpi"))] +use pinocchio::program_entrypoint; + +use crate::multi_transfer::processor::process_multi_transfer; + +#[cfg(not(feature = "cpi"))] +program_entrypoint!(process_instruction); + +pub fn process_instruction( + program_id: &pinocchio::pubkey::Pubkey, + accounts: &[AccountInfo], + instruction_data: &[u8], +) -> Result<(), ProgramError> { + let discriminator = InstructionType::from(instruction_data[0]); + match discriminator { + InstructionType::DecompressedTransfer => { + let instruction = TokenInstruction::unpack(instruction_data)?; + match instruction { + TokenInstruction::Transfer { amount } => { + let account_infos = unsafe { convert_account_infos::(accounts)? 
}; + let program_id_pubkey = solana_pubkey::Pubkey::new_from_array(*program_id); + spl_token::processor::Processor::process_transfer( + &program_id_pubkey, + &account_infos, + amount, + None, + )?; + } + _ => return Err(ProgramError::InvalidInstructionData), + } + } + InstructionType::CreateCompressedMint => { + anchor_lang::solana_program::msg!("CreateCompressedMint"); + process_create_compressed_mint(*program_id, accounts, &instruction_data[1..])?; + } + InstructionType::MintToCompressed => { + anchor_lang::solana_program::msg!("MintToCompressed"); + process_mint_to_compressed(*program_id, accounts, &instruction_data[1..])?; + } + InstructionType::CreateSplMint => { + anchor_lang::solana_program::msg!("CreateSplMint"); + process_create_spl_mint(*program_id, accounts, &instruction_data[1..])?; + } + InstructionType::CreateAssociatedTokenAccount => { + anchor_lang::solana_program::msg!("CreateAssociatedTokenAccount"); + process_create_associated_token_account(accounts, &instruction_data[1..])?; + } + InstructionType::CreateTokenAccount => { + anchor_lang::solana_program::msg!("CreateTokenAccount"); + process_create_token_account(accounts, &instruction_data[1..])?; + } + InstructionType::CloseTokenAccount => { + anchor_lang::solana_program::msg!("CloseTokenAccount"); + process_close_token_account(accounts, &instruction_data[1..])?; + } + InstructionType::MultiTransfer => { + anchor_lang::solana_program::msg!("MultiTransfer"); + process_multi_transfer(accounts, &instruction_data[1..])?; + } + // anchor instructions have no discriminator conflicts with InstructionType + _ => { + let account_infos = unsafe { convert_account_infos::(accounts)? 
}; + let account_infos = ManuallyDrop::new(account_infos); + let solana_program_id = solana_pubkey::Pubkey::new_from_array(*program_id); + + entry( + &solana_program_id, + account_infos.as_slice(), + instruction_data, + )?; + } + } + Ok(()) +} + +/// Convert Pinocchio AccountInfo to Solana AccountInfo with minimal safety overhead +/// +/// # SAFETY +/// - `pinocchio_accounts` must remain valid for lifetime 'a +/// - No other code may mutably borrow these accounts during 'a +/// - Pinocchio runtime must have properly deserialized the accounts +/// - Caller must ensure no concurrent access to returned AccountInfo +#[inline(always)] +pub unsafe fn convert_account_infos<'a, const N: usize>( + pinocchio_accounts: &'a [AccountInfo], +) -> Result, N>, ProgramError> { + if pinocchio_accounts.len() > N { + return Err(ProgramError::MaxAccountsDataAllocationsExceeded); + } + + use std::cell::RefCell; + use std::rc::Rc; + + // Compile-time type safety: Ensure Pubkey types are layout-compatible + const _: () = { + assert!( + std::mem::size_of::() + == std::mem::size_of::() + ); + assert!( + std::mem::align_of::() + == std::mem::align_of::() + ); + }; + + let mut solana_accounts = arrayvec::ArrayVec::, N>::new(); + for pinocchio_account in pinocchio_accounts { + let key: &'a solana_pubkey::Pubkey = + &*(pinocchio_account.key() as *const _ as *const solana_pubkey::Pubkey); + + let owner: &'a solana_pubkey::Pubkey = + &*(pinocchio_account.owner() as *const _ as *const solana_pubkey::Pubkey); + + let lamports = Rc::new(RefCell::new( + pinocchio_account.borrow_mut_lamports_unchecked(), + )); + + let data = Rc::new(RefCell::new(pinocchio_account.borrow_mut_data_unchecked())); + + let account_info = anchor_lang::prelude::AccountInfo { + key, + lamports, + data, + owner, + rent_epoch: 0, // Pinocchio doesn't track rent epoch + is_signer: pinocchio_account.is_signer(), + is_writable: pinocchio_account.is_writable(), + executable: pinocchio_account.executable(), + }; + + 
solana_accounts.push(account_info); + } + + Ok(solana_accounts) +} diff --git a/programs/compressed-token/program/src/mint/accounts.rs b/programs/compressed-token/program/src/mint/accounts.rs new file mode 100644 index 0000000000..96825ba1c7 --- /dev/null +++ b/programs/compressed-token/program/src/mint/accounts.rs @@ -0,0 +1,92 @@ +use crate::constants::BUMP_CPI_AUTHORITY; +use account_compression::utils::constants::CPI_AUTHORITY_PDA_SEED; +use anchor_lang::solana_program::program_error::ProgramError; +use light_account_checks::checks::{ + check_mut, check_non_mut, check_pda_seeds_with_bump, check_program, check_signer, +}; +use light_compressed_account::constants::ACCOUNT_COMPRESSION_PROGRAM_ID; +use pinocchio::{account_info::AccountInfo, pubkey::Pubkey}; + +pub struct CreateCompressedMintAccounts<'info> { + pub address_merkle_tree: &'info AccountInfo, + pub mint_signer: &'info AccountInfo, +} + +impl<'info> CreateCompressedMintAccounts<'info> { + pub fn validate_and_parse( + accounts: &'info [AccountInfo], + program_id: &Pubkey, + ) -> Result { + if accounts.len() != 12 { + return Err(ProgramError::NotEnoughAccountKeys); + } + + // Static non-CPI accounts first + let mint_signer = &accounts[0]; + let light_system_program = &accounts[1]; + + // CPI accounts in exact order expected by InvokeCpiWithReadOnly + let fee_payer = &accounts[2]; + let cpi_authority_pda = &accounts[3]; + let registered_program_pda = &accounts[4]; + let noop_program = &accounts[5]; + let account_compression_authority = &accounts[6]; + let account_compression_program = &accounts[7]; + let self_program = &accounts[8]; + // let sol_pool_pda_placeholder = &accounts[9]; // light_system_program placeholder + // let _decompression_recipient_placeholder = &accounts[10]; // light_system_program placeholder + let system_program = &accounts[9]; + // let _cpi_context_placeholder = &accounts[12]; // light_system_program placeholder + let address_merkle_tree = &accounts[10]; + let output_queue = 
&accounts[11]; + + // Validate fee_payer: must be signer and mutable + check_signer(fee_payer).map_err(ProgramError::from)?; + check_mut(fee_payer).map_err(ProgramError::from)?; + + // Validate cpi_authority_pda: must be the correct PDA + let expected_seeds = &[CPI_AUTHORITY_PDA_SEED, &[BUMP_CPI_AUTHORITY]]; + check_pda_seeds_with_bump(expected_seeds, program_id, cpi_authority_pda) + .map_err(ProgramError::from)?; + + // Validate light_system_program: must be the correct program + // The placeholders are always None -> no need for an extra light system program account info. + let light_system_program_id = light_system_program::id(); + check_program(&light_system_program_id.to_bytes(), light_system_program) + .map_err(ProgramError::from)?; + + // Validate account_compression_program: must be the correct program + check_program(&ACCOUNT_COMPRESSION_PROGRAM_ID, account_compression_program) + .map_err(ProgramError::from)?; + + // Validate registered_program_pda: non-mutable + check_non_mut(registered_program_pda).map_err(ProgramError::from)?; + + // Validate noop_program: non-mutable + check_non_mut(noop_program).map_err(ProgramError::from)?; + + // Validate account_compression_authority: non-mutable + check_non_mut(account_compression_authority).map_err(ProgramError::from)?; + + // Validate self_program: must be this program + check_program(program_id, self_program).map_err(ProgramError::from)?; + + // Validate system_program: must be the system program + let system_program_id = anchor_lang::solana_program::system_program::ID; + check_program(&system_program_id.to_bytes(), system_program).map_err(ProgramError::from)?; + + // Validate address_merkle_tree: mutable + check_mut(address_merkle_tree).map_err(ProgramError::from)?; + + // Validate output_queue: mutable + check_mut(output_queue).map_err(ProgramError::from)?; + + // Validate mint_signer: must be signer + check_signer(mint_signer).map_err(ProgramError::from)?; + + Ok(CreateCompressedMintAccounts { + 
address_merkle_tree, + mint_signer, + }) + } +} diff --git a/programs/compressed-token/program/src/mint/input.rs b/programs/compressed-token/program/src/mint/input.rs new file mode 100644 index 0000000000..2c8d1cbebc --- /dev/null +++ b/programs/compressed-token/program/src/mint/input.rs @@ -0,0 +1,99 @@ +use anchor_lang::solana_program::program_error::ProgramError; +use light_compressed_account::instruction_data::with_readonly::ZInAccountMut; + +use crate::{ + constants::COMPRESSED_MINT_DISCRIMINATOR, mint::state::CompressedMint, + mint_to_compressed::instructions::ZCompressedMintInputs, shared::context::TokenContext, +}; + +/// Creates and validates an input compressed mint account. +/// This function follows the same pattern as create_output_compressed_mint_account +/// but processes existing compressed mint accounts as inputs. +/// +/// Steps: +/// 1. Set InAccount fields (discriminator, merkle context, address) +/// 2. Validate the compressed mint data matches expected values +/// 3. Compute data hash using TokenContext for caching +/// 4. Return validated CompressedMint data for output processing +pub fn create_input_compressed_mint_account( + input_compressed_account: &mut ZInAccountMut, + context: &mut TokenContext, + compressed_mint_inputs: &ZCompressedMintInputs, + hashed_mint_authority: &[u8; 32], +) -> Result<(), ProgramError> { + // 1. 
Set InAccount fields + { + input_compressed_account.discriminator = COMPRESSED_MINT_DISCRIMINATOR; + // Set merkle context fields manually due to mutability constraints + input_compressed_account + .merkle_context + .merkle_tree_pubkey_index = compressed_mint_inputs + .merkle_context + .merkle_tree_pubkey_index; + input_compressed_account.merkle_context.queue_pubkey_index = + compressed_mint_inputs.merkle_context.queue_pubkey_index; + input_compressed_account + .merkle_context + .leaf_index + .set(compressed_mint_inputs.merkle_context.leaf_index.get()); + input_compressed_account.merkle_context.prove_by_index = + compressed_mint_inputs.merkle_context.prove_by_index; + input_compressed_account + .root_index + .set(compressed_mint_inputs.root_index.get()); + + input_compressed_account + .address + .as_mut() + .ok_or(ProgramError::InvalidAccountData)? + .copy_from_slice(compressed_mint_inputs.address.as_ref()); + } + + // 2. Extract and validate compressed mint data + let compressed_mint_input = &compressed_mint_inputs.compressed_mint_input; + + // // Create the expected CompressedMint structure for validation + // let compressed_mint = CompressedMint { + // spl_mint: compressed_mint_input.spl_mint, + // supply: compressed_mint_input.supply.get(), + // decimals: compressed_mint_input.decimals, + // is_decompressed: compressed_mint_input.is_decompressed(), + // mint_authority: None, // Will be set based on validation + // freeze_authority: if compressed_mint_input.freeze_authority_is_set() { + // Some(compressed_mint_input.freeze_authority) + // } else { + // None + // }, + // num_extensions: compressed_mint_input.num_extensions, + // }; + + // 3. Compute data hash using TokenContext for caching + { + let hashed_spl_mint = context.get_or_hash_mint(&compressed_mint_input.spl_mint.into())?; + let mut supply_bytes = [0u8; 32]; + supply_bytes[24..] 
+ .copy_from_slice(compressed_mint_input.supply.get().to_be_bytes().as_slice()); + + let hashed_freeze_authority = if compressed_mint_input.freeze_authority_is_set() { + Some(context.get_or_hash_pubkey(&compressed_mint_input.freeze_authority.into())) + } else { + None + }; + + // Compute the data hash using the CompressedMint hash function + let data_hash = CompressedMint::hash_with_hashed_values( + &hashed_spl_mint, + &supply_bytes, + compressed_mint_input.decimals, + compressed_mint_input.is_decompressed(), + &Some(hashed_mint_authority), // pre-hashed mint_authority from signer + &hashed_freeze_authority.as_ref(), + compressed_mint_input.num_extensions, + ) + .map_err(|_| ProgramError::InvalidAccountData)?; + + input_compressed_account.data_hash = data_hash; + } + + Ok(()) +} diff --git a/programs/compressed-token/program/src/mint/instructions.rs b/programs/compressed-token/program/src/mint/instructions.rs new file mode 100644 index 0000000000..eedb3a0a34 --- /dev/null +++ b/programs/compressed-token/program/src/mint/instructions.rs @@ -0,0 +1,22 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::{instruction_data::compressed_proof::CompressedProof, Pubkey}; +use light_zero_copy::ZeroCopy; + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct CreateCompressedMintInstructionData { + pub decimals: u8, + pub mint_authority: Pubkey, + pub proof: CompressedProof, + pub mint_bump: u8, + pub address_merkle_tree_root_index: u16, + // compressed address TODO: make a type CompressedAddress + pub mint_address: [u8; 32], + pub freeze_authority: Option, + pub extensions: Option>, +} + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct ExtensionInstructionData { + pub extension_type: u8, + pub data: Vec, +} diff --git a/programs/compressed-token/program/src/mint/mod.rs b/programs/compressed-token/program/src/mint/mod.rs new file mode 100644 index 0000000000..9370c97179 --- /dev/null +++ 
b/programs/compressed-token/program/src/mint/mod.rs @@ -0,0 +1,6 @@ +pub mod accounts; +pub mod input; +pub mod instructions; +pub mod output; +pub mod processor; +pub mod state; diff --git a/programs/compressed-token/program/src/mint/output.rs b/programs/compressed-token/program/src/mint/output.rs new file mode 100644 index 0000000000..29289baf85 --- /dev/null +++ b/programs/compressed-token/program/src/mint/output.rs @@ -0,0 +1,76 @@ +use anchor_lang::solana_program::program_error::ProgramError; +use light_compressed_account::{ + instruction_data::data::ZOutputCompressedAccountWithPackedContextMut, Pubkey, +}; + +use light_zero_copy::ZeroCopyNew; +use zerocopy::little_endian::U64; + +use crate::{ + constants::COMPRESSED_MINT_DISCRIMINATOR, + mint::state::{CompressedMint, CompressedMintConfig}, +}; +// TODO: pass in struct +#[allow(clippy::too_many_arguments)] +pub fn create_output_compressed_mint_account( + output_compressed_account: &mut ZOutputCompressedAccountWithPackedContextMut, + mint_pda: Pubkey, + decimals: u8, + freeze_authority: Option, + mint_authority: Option, + supply: U64, + program_id: &Pubkey, + mint_config: CompressedMintConfig, + compressed_account_address: [u8; 32], + merkle_tree_index: u8, +) -> Result<(), ProgramError> { + // 3. Create output compressed account + { + // TODO: create helper to assign output_compressed_account + output_compressed_account.compressed_account.owner = *program_id; + + if let Some(address) = output_compressed_account + .compressed_account + .address + .as_deref_mut() + { + *address = compressed_account_address; + } else { + panic!("Compressed account address is required"); + } + *output_compressed_account.merkle_tree_index = merkle_tree_index; + } + // 4. 
Create CompressedMint account data & compute hash + { + // TODO: create helper to assign compressed account data + let compressed_account_data = output_compressed_account + .compressed_account + .data + .as_mut() + .ok_or(ProgramError::InvalidAccountData)?; + + compressed_account_data.discriminator = COMPRESSED_MINT_DISCRIMINATOR; + let (mut compressed_mint, _) = + CompressedMint::new_zero_copy(compressed_account_data.data, mint_config) + .map_err(ProgramError::from)?; + compressed_mint.spl_mint = mint_pda; + compressed_mint.decimals = decimals; + compressed_mint.supply = supply; + if let Some(freeze_auth) = freeze_authority { + if let Some(z_freeze_authority) = compressed_mint.freeze_authority.as_deref_mut() { + *z_freeze_authority = freeze_auth; + } + } + if let Some(mint_auth) = mint_authority { + if let Some(z_mint_authority) = compressed_mint.mint_authority.as_deref_mut() { + *z_mint_authority = mint_auth; + } + } + + *compressed_account_data.data_hash = compressed_mint + .hash() + .map_err(|_| ProgramError::InvalidAccountData)?; + } + + Ok(()) +} diff --git a/programs/compressed-token/program/src/mint/processor.rs b/programs/compressed-token/program/src/mint/processor.rs new file mode 100644 index 0000000000..c07b6f8fbf --- /dev/null +++ b/programs/compressed-token/program/src/mint/processor.rs @@ -0,0 +1,198 @@ +use anchor_lang::{prelude::msg, solana_program::program_error::ProgramError}; +use light_compressed_account::{ + address::derive_address, + compressed_account::{CompressedAccountConfig, CompressedAccountDataConfig}, + instruction_data::{ + compressed_proof::CompressedProofConfig, + cpi_context::CompressedCpiContextConfig, + data::{NewAddressParamsPackedConfig, OutputCompressedAccountWithPackedContextConfig}, + invoke_cpi::{InstructionDataInvokeCpi, InstructionDataInvokeCpiConfig}, + with_readonly::{ + InstructionDataInvokeCpiWithReadOnly, InstructionDataInvokeCpiWithReadOnlyConfig, + }, + }, + Pubkey, +}; +use 
light_sdk_pinocchio::NewAddressParamsAssignedPackedConfig; +use light_zero_copy::borsh::Deserialize; +use pinocchio::account_info::AccountInfo; +use spl_token::solana_program::log::sol_log_compute_units; + +use crate::{ + extensions::{ + metadata_pointer::InitializeMetadataPointerInstructionData, + processor::process_create_extensions, ExtensionType, + }, + mint::{ + accounts::CreateCompressedMintAccounts, + instructions::CreateCompressedMintInstructionData, + output::create_output_compressed_mint_account, + state::{CompressedMint, CompressedMintConfig}, + }, + shared::cpi::execute_cpi_invoke, +}; + +pub fn process_create_compressed_mint( + program_id: pinocchio::pubkey::Pubkey, + accounts: &[AccountInfo], + instruction_data: &[u8], +) -> Result<(), ProgramError> { + sol_log_compute_units(); + let (parsed_instruction_data, _) = + CreateCompressedMintInstructionData::zero_copy_at(instruction_data) + .map_err(|_| ProgramError::InvalidInstructionData)?; + sol_log_compute_units(); + + // Validate and parse accounts + let validated_accounts = + CreateCompressedMintAccounts::validate_and_parse(accounts, &program_id.into())?; + // 1. Create mint PDA using provided bump + let mint_pda: Pubkey = solana_pubkey::Pubkey::create_program_address( + &[ + b"compressed_mint", + validated_accounts.mint_signer.key().as_slice(), + &[parsed_instruction_data.mint_bump], + ], + &program_id.into(), + )? 
+ .into(); + use light_zero_copy::ZeroCopyNew; + + let mint_size_config: ::ZeroCopyConfig = CompressedMintConfig { + mint_authority: (true, ()), + freeze_authority: (parsed_instruction_data.freeze_authority.is_some(), ()), + }; + let compressed_mint_len = CompressedMint::byte_len(&mint_size_config) as u32; + let mut output_compressed_accounts = vec![OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (true, ()), + data: ( + true, + CompressedAccountDataConfig { + data: compressed_mint_len, + }, + ), + }, + }]; + let mut new_address_params = vec![NewAddressParamsAssignedPackedConfig {}]; + if parsed_instruction_data.extensions.is_some() { + for extension in parsed_instruction_data.extensions.as_ref().unwrap().iter() { + match ExtensionType::try_from(extension.extension_type).unwrap() { + ExtensionType::MetadataPointer => { + let (extension, token_metadata) = + InitializeMetadataPointerInstructionData::zero_copy_at(extension.data) + .map_err(|_| ProgramError::InvalidInstructionData)?; + let mut data_len = 0; + if extension.authority.is_some() { + data_len += 33; + } else { + data_len += 1; + }; + if extension.metadata_address_params.is_some() { + data_len += 33; + } else { + data_len += 1; + }; + // increased mint account data len + output_compressed_accounts[0].compressed_account.data.1.data += data_len; + // set token metadata account data len + if !token_metadata.is_empty() { + new_address_params.push(NewAddressParamsAssignedPackedConfig {}); + output_compressed_accounts.push( + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (true, ()), + data: ( + true, + CompressedAccountDataConfig { + data: token_metadata.len() as u32, + }, + ), + }, + }, + ); + } + } + _ => return Err(ProgramError::InvalidInstructionData), + } + } + } + let final_compressed_mint_len = output_compressed_accounts[0].compressed_account.data.1.data; + let config = 
InstructionDataInvokeCpiWithReadOnlyConfig { + cpi_context: CompressedCpiContextConfig {}, + input_compressed_accounts: vec![], + proof: (true, CompressedProofConfig {}), + read_only_accounts: vec![], + read_only_addresses: vec![], + new_address_params, + output_compressed_accounts, + }; + // TODO: InstructionDataInvokeCpi::Output -> InstructionDataInvokeCpi::ZeroCopyMut and InstructionDataInvokeCpi::ZeroCopy + // TODO: hardcode since len is constant + let vec_len = InstructionDataInvokeCpiWithReadOnly::byte_len(&config); + msg!("vec len {}", vec_len); + // + discriminator len + vector len + let mut cpi_bytes = vec![0u8; vec_len + 8 + 4]; + cpi_bytes[0..8] + .copy_from_slice(&light_compressed_account::discriminators::DISCRIMINATOR_INVOKE_CPI); + cpi_bytes[8..12].copy_from_slice(&(vec_len as u32).to_le_bytes()); + + sol_log_compute_units(); + let (mut cpi_instruction_struct, _) = + InstructionDataInvokeCpiWithReadOnly::new_zero_copy(&mut cpi_bytes[12..], config) + .map_err(ProgramError::from)?; + sol_log_compute_units(); + + let proof = cpi_instruction_struct + .proof + .as_deref_mut() + .ok_or(ProgramError::InvalidInstructionData)?; + proof.a = parsed_instruction_data.proof.a; + proof.b = parsed_instruction_data.proof.b; + proof.c = parsed_instruction_data.proof.c; + // 1. Create NewAddressParams + cpi_instruction_struct.new_address_params[0].seed = mint_pda.to_bytes(); + cpi_instruction_struct.new_address_params[0].address_merkle_tree_root_index = + *parsed_instruction_data.address_merkle_tree_root_index; + cpi_instruction_struct.new_address_params[0].assigned_account_index = 0; + // Note we can skip address derivation since we are assigning it to the account in index 0. + cpi_instruction_struct.new_address_params[0].assigned_to_account = 1; + // 2. process token extensions. 
+ if let Some(extensions) = parsed_instruction_data.extensions.as_ref() { + process_create_extensions( + extensions, + &mut cpi_instruction_struct, + final_compressed_mint_len as usize, + )?; + } + // 2. Create compressed mint account data + create_output_compressed_mint_account( + &mut cpi_instruction_struct.output_compressed_accounts[0], + mint_pda, + parsed_instruction_data.decimals, + parsed_instruction_data.freeze_authority.map(|fa| *fa), + Some(parsed_instruction_data.mint_authority), + 0.into(), + &program_id.into(), + mint_size_config, + *parsed_instruction_data.mint_address, + 1, + )?; + sol_log_compute_units(); + // 3. Execute CPI to light-system-program + // Extract tree accounts for the generalized CPI call + let tree_accounts = [accounts[10].key(), accounts[11].key()]; // address_merkle_tree, output_queue + let _accounts = accounts[1..] + .iter() + .map(|account| account.key()) + .collect::>(); + msg!("tree_accounts {:?}", tree_accounts); + msg!("accounts {:?}", _accounts); + execute_cpi_invoke( + &accounts[2..], // Skip first non-CPI account (mint_signer) + cpi_bytes, + tree_accounts.as_slice(), + false, // no sol_pool_pda for create_compressed_mint + None, // no cpi_context_account for create_compressed_mint + ) +} diff --git a/programs/compressed-token/program/src/mint/state.rs b/programs/compressed-token/program/src/mint/state.rs new file mode 100644 index 0000000000..cd0fe2a9c5 --- /dev/null +++ b/programs/compressed-token/program/src/mint/state.rs @@ -0,0 +1,164 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::{hash_to_bn254_field_size_be, Pubkey}; +use light_hasher::{errors::HasherError, Hasher, Poseidon}; +use light_zero_copy::ZeroCopyMut; +use zerocopy::IntoBytes; + +// Order is optimized for hashing. +// freeze_authority option is skipped if None. 
+#[derive(Debug, PartialEq, Eq, Clone, BorshSerialize, BorshDeserialize, ZeroCopyMut)] +pub struct CompressedMint { + /// Pda with seed address of compressed mint + pub spl_mint: Pubkey, + /// Total supply of tokens. + pub supply: u64, + /// Number of base 10 digits to the right of the decimal place. + pub decimals: u8, + /// Extension, necessary for mint to. + pub is_decompressed: bool, + /// Optional authority used to mint new tokens. The mint authority may only + /// be provided during mint creation. If no mint authority is present + /// then the mint has a fixed supply and no further tokens may be + /// minted. + pub mint_authority: Option, + /// Optional authority to freeze token accounts. + pub freeze_authority: Option, + // TODO: add extension hash to hash + pub num_extensions: u8, + // use nested token metadata layout for data extension + pub extension_hash: [u8; 32], +} + +impl CompressedMint { + #[allow(dead_code)] + pub fn hash(&self) -> std::result::Result<[u8; 32], HasherError> { + let hashed_spl_mint = hash_to_bn254_field_size_be(self.spl_mint.to_bytes().as_slice()); + let mut supply_bytes = [0u8; 32]; + supply_bytes[24..].copy_from_slice(self.supply.to_be_bytes().as_slice()); + + let hashed_mint_authority; + let hashed_mint_authority_option = if let Some(mint_authority) = self.mint_authority { + hashed_mint_authority = + hash_to_bn254_field_size_be(mint_authority.to_bytes().as_slice()); + Some(&hashed_mint_authority) + } else { + None + }; + + let hashed_freeze_authority; + let hashed_freeze_authority_option = if let Some(freeze_authority) = self.freeze_authority { + hashed_freeze_authority = + hash_to_bn254_field_size_be(freeze_authority.to_bytes().as_slice()); + Some(&hashed_freeze_authority) + } else { + None + }; + + Self::hash_with_hashed_values( + &hashed_spl_mint, + &supply_bytes, + self.decimals, + self.is_decompressed, + &hashed_mint_authority_option, + &hashed_freeze_authority_option, + self.num_extensions, + ) + } + + pub fn 
hash_with_hashed_values( + hashed_spl_mint: &[u8; 32], + supply_bytes: &[u8; 32], + decimals: u8, + is_decompressed: bool, + hashed_mint_authority: &Option<&[u8; 32]>, + hashed_freeze_authority: &Option<&[u8; 32]>, + num_extensions: u8, + ) -> std::result::Result<[u8; 32], HasherError> { + let mut hash_inputs = vec![hashed_spl_mint.as_slice(), supply_bytes.as_slice()]; + + // Add decimals with prefix if not 0 + let mut decimals_bytes = [0u8; 32]; + if decimals != 0 { + decimals_bytes[30] = 1; // decimals prefix + decimals_bytes[31] = decimals; + hash_inputs.push(&decimals_bytes[..]); + } + + // Add is_decompressed with prefix if true + let mut is_decompressed_bytes = [0u8; 32]; + if is_decompressed { + is_decompressed_bytes[30] = 2; // is_decompressed prefix + is_decompressed_bytes[31] = 1; // true as 1 + hash_inputs.push(&is_decompressed_bytes[..]); + } + + // Add mint authority if present + if let Some(hashed_mint_authority) = hashed_mint_authority { + hash_inputs.push(hashed_mint_authority.as_slice()); + } + + // Add freeze authority if present + let empty_authority = [0u8; 32]; + if let Some(hashed_freeze_authority) = hashed_freeze_authority { + // If there is freeze authority but no mint authority, add empty mint authority + if hashed_mint_authority.is_none() { + hash_inputs.push(&empty_authority[..]); + } + hash_inputs.push(hashed_freeze_authority.as_slice()); + } + + // Add num_extensions with prefix if not 0 + let mut num_extensions_bytes = [0u8; 32]; + if num_extensions != 0 { + num_extensions_bytes[30] = 3; // num_extensions prefix + num_extensions_bytes[31] = num_extensions; + hash_inputs.push(&num_extensions_bytes[..]); + } + + Poseidon::hashv(hash_inputs.as_slice()) + } +} + +impl ZCompressedMintMut<'_> { + pub fn hash(&self) -> std::result::Result<[u8; 32], HasherError> { + let hashed_spl_mint = hash_to_bn254_field_size_be(self.spl_mint.to_bytes().as_slice()); + let mut supply_bytes = [0u8; 32]; + // TODO: copy from slice + self.supply + .as_bytes() + 
.iter() + .rev() + .zip(supply_bytes[24..].iter_mut()) + .for_each(|(x, y)| *y = *x); + + let hashed_mint_authority; + let hashed_mint_authority_option = + if let Some(mint_authority) = self.mint_authority.as_ref() { + hashed_mint_authority = + hash_to_bn254_field_size_be(mint_authority.to_bytes().as_slice()); + Some(&hashed_mint_authority) + } else { + None + }; + + let hashed_freeze_authority; + let hashed_freeze_authority_option = + if let Some(freeze_authority) = self.freeze_authority.as_ref() { + hashed_freeze_authority = + hash_to_bn254_field_size_be(freeze_authority.to_bytes().as_slice()); + Some(&hashed_freeze_authority) + } else { + None + }; + + CompressedMint::hash_with_hashed_values( + &hashed_spl_mint, + &supply_bytes, + self.decimals, + self.is_decompressed(), + &hashed_mint_authority_option, + &hashed_freeze_authority_option, + *self.num_extensions, + ) + } +} diff --git a/programs/compressed-token/program/src/mint_to_compressed/accounts.rs b/programs/compressed-token/program/src/mint_to_compressed/accounts.rs new file mode 100644 index 0000000000..9e7aea66bd --- /dev/null +++ b/programs/compressed-token/program/src/mint_to_compressed/accounts.rs @@ -0,0 +1,113 @@ +use anchor_lang::solana_program::program_error::ProgramError; +use light_account_checks::checks::{check_mut, check_signer}; +use pinocchio::account_info::AccountInfo; +use crate::shared::AccountIterator; + +pub struct MintToCompressedAccounts<'info> { + pub fee_payer: &'info AccountInfo, + pub authority: &'info AccountInfo, + pub cpi_authority_pda: &'info AccountInfo, + pub mint: Option<&'info AccountInfo>, + pub token_pool_pda: Option<&'info AccountInfo>, + pub token_program: Option<&'info AccountInfo>, + pub light_system_program: &'info AccountInfo, + pub registered_program_pda: &'info AccountInfo, + pub noop_program: &'info AccountInfo, + pub account_compression_authority: &'info AccountInfo, + pub account_compression_program: &'info AccountInfo, + pub self_program: &'info AccountInfo, 
+ pub system_program: &'info AccountInfo, + pub sol_pool_pda: Option<&'info AccountInfo>, + pub mint_in_merkle_tree: &'info AccountInfo, + pub mint_in_queue: &'info AccountInfo, + pub mint_out_queue: &'info AccountInfo, + pub tokens_out_queue: &'info AccountInfo, +} + +impl<'info> MintToCompressedAccounts<'info> { + + pub fn validate_and_parse( + accounts: &'info [AccountInfo], + with_lamports: bool, + is_decompressed: bool, + ) -> Result { + // Calculate minimum accounts needed + let mut base_accounts = 13; + + if with_lamports { + base_accounts += 1; + }; + if is_decompressed { + base_accounts += 3; // Add mint, token_pool_pda, token_program + }; + if accounts.len() < base_accounts { + return Err(ProgramError::NotEnoughAccountKeys); + } + + let mut iter = AccountIterator::new(accounts); + + // Static non-CPI accounts first + let authority = iter.next()?; + + let (mint, token_pool_pda, token_program) = if is_decompressed { + ( + Some(iter.next()?), + Some(iter.next()?), + Some(iter.next()?), + ) + } else { + (None, None, None) + }; + + let light_system_program = iter.next()?; + + // CPI accounts in exact order expected by InvokeCpiWithReadOnly + let fee_payer = iter.next()?; + let cpi_authority_pda = iter.next()?; + let registered_program_pda = iter.next()?; + let noop_program = iter.next()?; + let account_compression_authority = iter.next()?; + let account_compression_program = iter.next()?; + let self_program = iter.next()?; + let system_program = iter.next()?; + + let sol_pool_pda = if with_lamports { + Some(iter.next()?) 
+ } else { + None + }; + + let mint_in_merkle_tree = iter.next()?; + let mint_in_queue = iter.next()?; + let mint_out_queue = iter.next()?; + let tokens_out_queue = iter.next()?; + + // Validate fee_payer: must be signer and mutable + check_signer(fee_payer).map_err(ProgramError::from)?; + check_mut(fee_payer).map_err(ProgramError::from)?; + + // Validate authority: must be signer + check_signer(authority).map_err(ProgramError::from)?; + + Ok(MintToCompressedAccounts { + fee_payer, + authority, + cpi_authority_pda, + mint, + token_pool_pda, + token_program, + light_system_program, + registered_program_pda, + noop_program, + account_compression_authority, + account_compression_program, + system_program, + sol_pool_pda, + self_program, + mint_in_merkle_tree, + mint_in_queue, + mint_out_queue, + tokens_out_queue, + }) + } +} diff --git a/programs/compressed-token/program/src/mint_to_compressed/input_compressed_mint.rs b/programs/compressed-token/program/src/mint_to_compressed/input_compressed_mint.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/programs/compressed-token/program/src/mint_to_compressed/instructions.rs b/programs/compressed-token/program/src/mint_to_compressed/instructions.rs new file mode 100644 index 0000000000..4b6b224faa --- /dev/null +++ b/programs/compressed-token/program/src/mint_to_compressed/instructions.rs @@ -0,0 +1,40 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::{ + compressed_account::PackedMerkleContext, instruction_data::compressed_proof::CompressedProof, + Pubkey, +}; +use light_zero_copy::ZeroCopy; + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct CompressedMintInputs { + pub merkle_context: PackedMerkleContext, + pub root_index: u16, + pub address: [u8; 32], + pub compressed_mint_input: CompressedMintInput, + pub output_merkle_tree_index: u8, +} + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct CompressedMintInput { + 
pub spl_mint: Pubkey, + pub supply: u64, + pub decimals: u8, + pub is_decompressed: bool, + pub freeze_authority_is_set: bool, + pub freeze_authority: Pubkey, + pub num_extensions: u8, +} + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct Recipient { + pub recipient: Pubkey, + pub amount: u64, +} + +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct MintToCompressedInstructionData { + pub compressed_mint_inputs: CompressedMintInputs, + pub lamports: Option, + pub recipients: Vec, + pub proof: Option, +} diff --git a/programs/compressed-token/program/src/mint_to_compressed/mod.rs b/programs/compressed-token/program/src/mint_to_compressed/mod.rs new file mode 100644 index 0000000000..c31719e252 --- /dev/null +++ b/programs/compressed-token/program/src/mint_to_compressed/mod.rs @@ -0,0 +1,3 @@ +pub mod accounts; +pub mod instructions; +pub mod processor; \ No newline at end of file diff --git a/programs/compressed-token/program/src/mint_to_compressed/processor.rs b/programs/compressed-token/program/src/mint_to_compressed/processor.rs new file mode 100644 index 0000000000..c8b5661112 --- /dev/null +++ b/programs/compressed-token/program/src/mint_to_compressed/processor.rs @@ -0,0 +1,199 @@ +use anchor_lang::solana_program::program_error::ProgramError; +use light_compressed_account::{ + hash_to_bn254_field_size_be, + instruction_data::with_readonly::InstructionDataInvokeCpiWithReadOnly, Pubkey, +}; +use light_zero_copy::{borsh::Deserialize, ZeroCopyNew}; +use pinocchio::account_info::AccountInfo; +use spl_token::solana_program::log::sol_log_compute_units; +use zerocopy::little_endian::U64; + +use crate::{ + mint::{ + input::create_input_compressed_mint_account, output::create_output_compressed_mint_account, + }, + mint_to_compressed::{ + accounts::MintToCompressedAccounts, instructions::MintToCompressedInstructionData, + }, + shared::{ + context::TokenContext, + cpi::execute_cpi_invoke, + cpi_bytes_size::{ + 
allocate_invoke_with_read_only_cpi_bytes, cpi_bytes_config, CpiConfigInput, + }, + outputs::create_output_compressed_account, + }, + LIGHT_CPI_SIGNER, +}; + +pub fn process_mint_to_compressed( + program_id: pinocchio::pubkey::Pubkey, + accounts: &[AccountInfo], + instruction_data: &[u8], +) -> Result<(), ProgramError> { + sol_log_compute_units(); + + // Parse instruction data using zero-copy + let (parsed_instruction_data, _) = + MintToCompressedInstructionData::zero_copy_at(instruction_data) + .map_err(|_| ProgramError::InvalidInstructionData)?; + + sol_log_compute_units(); + + // Validate and parse accounts + let validated_accounts = MintToCompressedAccounts::validate_and_parse( + accounts, + parsed_instruction_data.lamports.is_some(), + parsed_instruction_data + .compressed_mint_inputs + .compressed_mint_input + .is_decompressed(), + )?; + // Build configuration for CPI instruction data using the generalized function + let compressed_mint_with_freeze_authority = parsed_instruction_data + .compressed_mint_inputs + .compressed_mint_input + .freeze_authority_is_set + != 0; + + let config_input = CpiConfigInput::mint_to_compressed( + parsed_instruction_data.recipients.len(), + parsed_instruction_data.proof.is_some(), + compressed_mint_with_freeze_authority, + ); + + let config = cpi_bytes_config(config_input); + let mut cpi_bytes = allocate_invoke_with_read_only_cpi_bytes(&config); + + sol_log_compute_units(); + let (mut cpi_instruction_struct, _) = + InstructionDataInvokeCpiWithReadOnly::new_zero_copy(&mut cpi_bytes[8..], config) + .map_err(ProgramError::from)?; + cpi_instruction_struct.bump = LIGHT_CPI_SIGNER.bump; + cpi_instruction_struct.invoking_program_id = LIGHT_CPI_SIGNER.program_id.into(); + if let Some(lamports) = parsed_instruction_data.lamports { + cpi_instruction_struct.compress_or_decompress_lamports = + U64::from(parsed_instruction_data.recipients.len() as u64) * *lamports; + cpi_instruction_struct.is_compress = 1; + } + + let mut context = 
TokenContext::new(); + let mint = parsed_instruction_data + .compressed_mint_inputs + .compressed_mint_input + .spl_mint; + + let hashed_mint = hash_to_bn254_field_size_be(mint.as_ref()); + let hashed_mint_authority = context.get_or_hash_pubkey(validated_accounts.authority.key()); + + { + // Process input compressed mint account + create_input_compressed_mint_account( + &mut cpi_instruction_struct.input_compressed_accounts[0], + &mut context, + &parsed_instruction_data.compressed_mint_inputs, + &hashed_mint_authority, + )?; + let mint_inputs = &parsed_instruction_data + .compressed_mint_inputs + .compressed_mint_input; + let mint_pda = mint_inputs.spl_mint; + let decimals = mint_inputs.decimals; + // TODO: make option in ix data. + let freeze_authority = if mint_inputs.freeze_authority_is_set() { + Some(mint_inputs.freeze_authority) + } else { + None + }; + use crate::mint::state::CompressedMintConfig; + let mint_config = CompressedMintConfig { + mint_authority: (true, ()), + freeze_authority: (mint_inputs.freeze_authority_is_set(), ()), + }; + let compressed_account_address = *parsed_instruction_data.compressed_mint_inputs.address; + let sum_amounts: U64 = parsed_instruction_data + .recipients + .iter() + .map(|x| u64::from(x.amount)) + .sum::() + .into(); + let supply = mint_inputs.supply + sum_amounts; + + // Compressed mint account is the last output + create_output_compressed_mint_account( + &mut cpi_instruction_struct.output_compressed_accounts + [parsed_instruction_data.recipients.len()], + mint_pda, + decimals, + freeze_authority, + Some(Pubkey::from(*validated_accounts.authority.key())), + supply, + &program_id.into(), + mint_config, + compressed_account_address, + 2, + )?; + } + + let is_decompressed = parsed_instruction_data + .compressed_mint_inputs + .compressed_mint_input + .is_decompressed(); + // Create output token accounts + create_output_compressed_token_accounts( + parsed_instruction_data, + cpi_instruction_struct, + &mut context, + mint, + 
hashed_mint, + )?; + + // Extract tree accounts for the generalized CPI call + let tree_accounts = [ + validated_accounts.mint_in_merkle_tree.key(), + validated_accounts.mint_in_queue.key(), + validated_accounts.mint_out_queue.key(), + validated_accounts.tokens_out_queue.key(), + ]; + let start_index = if is_decompressed { 5 } else { 2 }; + + execute_cpi_invoke( + &accounts[start_index..], // Skip first 5 non-CPI accounts (authority, mint, token_pool_pda, token_program, light_system_program) + cpi_bytes, + tree_accounts.as_slice(), + validated_accounts.sol_pool_pda.is_some(), + None, // no cpi_context_account for mint_to_compressed + )?; + Ok(()) +} + +fn create_output_compressed_token_accounts( + parsed_instruction_data: super::instructions::ZMintToCompressedInstructionData<'_>, + mut cpi_instruction_struct: light_compressed_account::instruction_data::with_readonly::ZInstructionDataInvokeCpiWithReadOnlyMut<'_>, + context: &mut TokenContext, + mint: Pubkey, + hashed_mint: [u8; 32], +) -> Result<(), ProgramError> { + let lamports = parsed_instruction_data + .lamports + .map(|lamports| u64::from(*lamports)); + for (recipient, output_account) in parsed_instruction_data + .recipients + .iter() + .zip(cpi_instruction_struct.output_compressed_accounts.iter_mut()) + { + let output_delegate = None; + create_output_compressed_account( + output_account, + context, + recipient.recipient, + output_delegate, + recipient.amount, + lamports, + mint, + &hashed_mint, + 2, + )?; + } + Ok(()) +} diff --git a/programs/compressed-token/program/src/multi_transfer/accounts.rs b/programs/compressed-token/program/src/multi_transfer/accounts.rs new file mode 100644 index 0000000000..0182e39390 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/accounts.rs @@ -0,0 +1,125 @@ +use crate::shared::AccountIterator; +use anchor_lang::solana_program::program_error::ProgramError; +use light_account_checks::checks::{check_mut, check_signer}; +use 
pinocchio::account_info::AccountInfo; + +/// Validated system accounts for multi-transfer instruction +/// Accounts are ordered to match light-system-program CPI expectation +pub struct MultiTransferValidatedAccounts<'info> { + /// Fee payer account (index 0) - signer, mutable + pub fee_payer: &'info AccountInfo, + /// CPI authority PDA (index 1) - signer (via CPI) + pub authority: &'info AccountInfo, + /// Registered program PDA (index 2) - non-mutable + pub registered_program_pda: &'info AccountInfo, + /// Noop program (index 3) - non-mutable + pub noop_program: &'info AccountInfo, + /// Account compression authority (index 4) - non-mutable + pub account_compression_authority: &'info AccountInfo, + /// Account compression program (index 5) - non-mutable + pub account_compression_program: &'info AccountInfo, + /// Invoking program (index 6) - self program, non-mutable + pub invoking_program: &'info AccountInfo, + /// Sol pool PDA (index 7) - optional, mutable if present + pub sol_pool_pda: Option<&'info AccountInfo>, + /// SOL decompression recipient (index 8) - optional, mutable, for SOL decompression + pub sol_decompression_recipient: Option<&'info AccountInfo>, + /// System program (index 9) - non-mutable + pub system_program: &'info AccountInfo, + /// CPI context account (index 10) - optional, non-mutable + pub cpi_context_account: Option<&'info AccountInfo>, +} + +/// Dynamic accounts slice for index-based access +/// Contains mint, owner, delegate, merkle tree, and queue accounts +pub struct MultiTransferPackedAccounts<'info> { + /// Remaining accounts slice starting at index 11 + pub accounts: &'info [AccountInfo], +} + +impl MultiTransferPackedAccounts<'_> { + /// Get account by index with bounds checking + pub fn get(&self, index: usize) -> Result<&AccountInfo, ProgramError> { + self.accounts + .get(index) + .ok_or(ProgramError::NotEnoughAccountKeys) + } + + /// Get account by u8 index with bounds checking + pub fn get_u8(&self, index: u8) -> 
Result<&AccountInfo, ProgramError> { + self.get(index as usize) + } +} + +impl<'info> MultiTransferValidatedAccounts<'info> { + /// Validate and parse accounts from the instruction accounts slice + pub fn validate_and_parse( + accounts: &'info [AccountInfo], + with_sol_pool: bool, + with_cpi_context: bool, + ) -> Result<(Self, MultiTransferPackedAccounts<'info>), ProgramError> { + // Calculate minimum required accounts + let min_accounts = + 11 + if with_sol_pool { 1 } else { 0 } + if with_cpi_context { 1 } else { 0 }; + + if accounts.len() < min_accounts { + return Err(ProgramError::NotEnoughAccountKeys); + } + + // Parse system accounts from fixed positions + let mut iter = AccountIterator::new(accounts); + let fee_payer = iter.next()?; + let authority = iter.next()?; + let registered_program_pda = iter.next()?; + let noop_program = iter.next()?; + let account_compression_authority = iter.next()?; + let account_compression_program = iter.next()?; + let invoking_program = iter.next()?; + + let sol_pool_pda = if with_sol_pool { + Some(iter.next()?) + } else { + None + }; + + let sol_decompression_recipient = if with_sol_pool { + Some(iter.next()?) + } else { + None + }; + + let system_program = iter.next()?; + + let cpi_context_account = if with_cpi_context { + Some(iter.next()?) 
+ } else { + None + }; + + // Validate fee_payer: must be signer and mutable + check_signer(fee_payer).map_err(ProgramError::from)?; + check_mut(fee_payer).map_err(ProgramError::from)?; + // Extract remaining accounts slice for dynamic indexing + let remaining_accounts = iter.remaining(); + + let validated_accounts = MultiTransferValidatedAccounts { + fee_payer, + authority, + registered_program_pda, + noop_program, + account_compression_authority, + account_compression_program, + invoking_program, + sol_pool_pda, + sol_decompression_recipient, + system_program, + cpi_context_account, + }; + + let packed_accounts = MultiTransferPackedAccounts { + accounts: remaining_accounts, + }; + + Ok((validated_accounts, packed_accounts)) + } +} diff --git a/programs/compressed-token/program/src/multi_transfer/assign_inputs.rs b/programs/compressed-token/program/src/multi_transfer/assign_inputs.rs new file mode 100644 index 0000000000..c71870da1e --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/assign_inputs.rs @@ -0,0 +1,47 @@ +use anchor_lang::prelude::ProgramError; +use light_compressed_account::instruction_data::with_readonly::ZInstructionDataInvokeCpiWithReadOnlyMut; + +use crate::{ + multi_transfer::{ + accounts::MultiTransferPackedAccounts, + instruction_data::ZCompressedTokenInstructionDataMultiTransfer, + }, + shared::{context::TokenContext, inputs::create_input_compressed_account}, +}; + +/// Process input compressed accounts and return total input lamports +pub fn assign_input_compressed_accounts( + cpi_instruction_struct: &mut ZInstructionDataInvokeCpiWithReadOnlyMut, + context: &mut TokenContext, + inputs: &ZCompressedTokenInstructionDataMultiTransfer, + packed_accounts: &MultiTransferPackedAccounts, +) -> Result { + let mut total_input_lamports = 0u64; + + for (i, input_data) in inputs.in_token_data.iter().enumerate() { + let input_lamports = if let Some(lamports) = inputs.in_lamports.as_ref() { + if let Some(input_lamports) = 
lamports.get(i) { + input_lamports.get() + } else { + 0 + } + } else { + 0 + }; + + total_input_lamports += input_lamports; + + create_input_compressed_account::( + cpi_instruction_struct + .input_compressed_accounts + .get_mut(i) + .ok_or(ProgramError::InvalidAccountData)?, + context, + input_data, + packed_accounts.accounts, + input_lamports, + )?; + } + + Ok(total_input_lamports) +} diff --git a/programs/compressed-token/program/src/multi_transfer/assign_outputs.rs b/programs/compressed-token/program/src/multi_transfer/assign_outputs.rs new file mode 100644 index 0000000000..7f4a1e5484 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/assign_outputs.rs @@ -0,0 +1,71 @@ +use anchor_lang::prelude::ProgramError; +use light_compressed_account::instruction_data::with_readonly::ZInstructionDataInvokeCpiWithReadOnlyMut; + +use crate::{ + multi_transfer::{ + accounts::MultiTransferPackedAccounts, + instruction_data::ZCompressedTokenInstructionDataMultiTransfer, + }, + shared::{context::TokenContext, outputs::create_output_compressed_account}, +}; + +/// Process output compressed accounts and return total output lamports +pub fn assign_output_compressed_accounts( + cpi_instruction_struct: &mut ZInstructionDataInvokeCpiWithReadOnlyMut, + context: &mut TokenContext, + inputs: &ZCompressedTokenInstructionDataMultiTransfer, + packed_accounts: &MultiTransferPackedAccounts, +) -> Result { + let mut total_output_lamports = 0u64; + + for (i, output_data) in inputs.out_token_data.iter().enumerate() { + let output_lamports = if let Some(lamports) = inputs.out_lamports.as_ref() { + if let Some(lamports) = lamports.get(i) { + lamports.get() + } else { + 0 + } + } else { + 0 + }; + + total_output_lamports += output_lamports; + + let mint_index = output_data.mint; + let mint_account = packed_accounts.get_u8(mint_index)?; + let hashed_mint = context.get_or_hash_pubkey(mint_account.key()); + + // Get owner account using owner index + let owner_account = 
packed_accounts.get_u8(output_data.owner)?; + let owner_pubkey = *owner_account.key(); + + // Get delegate if present + let delegate_pubkey = if output_data.delegate != 0 { + let delegate_account = packed_accounts.get_u8(output_data.delegate)?; + Some(*delegate_account.key()) + } else { + None + }; + + create_output_compressed_account( + cpi_instruction_struct + .output_compressed_accounts + .get_mut(i) + .ok_or(ProgramError::InvalidAccountData)?, + context, + owner_pubkey.into(), + delegate_pubkey.map(|d| d.into()), + output_data.amount, + if output_lamports > 0 { + Some(output_lamports) + } else { + None + }, + mint_account.key().into(), + &hashed_mint, + output_data.merkle_tree, + )?; + } + + Ok(total_output_lamports) +} diff --git a/programs/compressed-token/program/src/multi_transfer/change_account.rs b/programs/compressed-token/program/src/multi_transfer/change_account.rs new file mode 100644 index 0000000000..7e69ec20b6 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/change_account.rs @@ -0,0 +1,92 @@ +use anchor_lang::prelude::ProgramError; +use light_compressed_account::instruction_data::with_readonly::ZInstructionDataInvokeCpiWithReadOnlyMut; + +use crate::multi_transfer::{ + accounts::MultiTransferPackedAccounts, + instruction_data::ZCompressedTokenInstructionDataMultiTransfer, +}; + +/// Create a change account for excess lamports (following anchor program pattern) +pub fn assign_change_account( + cpi_instruction_struct: &mut ZInstructionDataInvokeCpiWithReadOnlyMut, + inputs: &ZCompressedTokenInstructionDataMultiTransfer, + packed_accounts: &MultiTransferPackedAccounts, + change_lamports: u64, +) -> Result<(), ProgramError> { + // Find the next available output account slot + let current_output_count = inputs.out_token_data.len(); + + // Get the change account slot (should be pre-allocated by CPI config) + let change_account = cpi_instruction_struct + .output_compressed_accounts + .get_mut(current_output_count) + 
.ok_or(ProgramError::InvalidAccountData)?; + anchor_lang::solana_program::log::msg!("inputs {:?}", inputs); + + // Get merkle tree index - use specified index + let merkle_tree_index = if inputs.with_lamports_change_account_merkle_tree_index != 0 { + inputs.lamports_change_account_merkle_tree_index + } else { + return Err(ProgramError::InvalidInstructionData); + }; + + // Get the owner account using the specified index + let owner_account = packed_accounts.get_u8(inputs.lamports_change_account_owner_index)?; + let owner_pubkey = *owner_account.key(); + + // Set up the change account as a lamports-only account (no token data) + let compressed_account = &mut change_account.compressed_account; + + // Set owner from the specified account index + compressed_account.owner = owner_pubkey.into(); + + // Set lamports amount + compressed_account.lamports.set(change_lamports); + + // No token data for change account + + if compressed_account.data.is_some() { + unimplemented!("lamports change account shouldn't have data.") + } + + // Set merkle tree index + *change_account.merkle_tree_index = merkle_tree_index; + + Ok(()) +} + +pub fn process_change_lamports( + inputs: &ZCompressedTokenInstructionDataMultiTransfer<'_>, + packed_accounts: &MultiTransferPackedAccounts<'_>, + mut cpi_instruction_struct: ZInstructionDataInvokeCpiWithReadOnlyMut<'_>, + total_input_lamports: u64, + total_output_lamports: u64, +) -> Result<(), ProgramError> { + if total_input_lamports != total_output_lamports { + let (change_lamports, is_compress) = if total_input_lamports > total_output_lamports { + ( + total_input_lamports.saturating_sub(total_output_lamports), + 0, + ) + } else { + ( + total_output_lamports.saturating_sub(total_input_lamports), + 1, + ) + }; + // Set CPI instruction fields for compression/decompression + cpi_instruction_struct + .compress_or_decompress_lamports + .set(change_lamports); + cpi_instruction_struct.is_compress = is_compress; + // Create change account with the lamports 
difference + assign_change_account( + &mut cpi_instruction_struct, + inputs, + packed_accounts, + change_lamports, + )?; + } + + Ok(()) +} diff --git a/programs/compressed-token/program/src/multi_transfer/cpi.rs b/programs/compressed-token/program/src/multi_transfer/cpi.rs new file mode 100644 index 0000000000..cab0967f74 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/cpi.rs @@ -0,0 +1,41 @@ +use arrayvec::ArrayVec; +use light_compressed_account::instruction_data::with_readonly::InstructionDataInvokeCpiWithReadOnlyConfig; + +use crate::{ + multi_transfer::instruction_data::ZCompressedTokenInstructionDataMultiTransfer, + shared::cpi_bytes_size::{ + allocate_invoke_with_read_only_cpi_bytes, cpi_bytes_config, CpiConfigInput, + }, +}; + +/// Build CPI configuration from instruction data +pub fn allocate_cpi_bytes( + inputs: &ZCompressedTokenInstructionDataMultiTransfer, +) -> (Vec, InstructionDataInvokeCpiWithReadOnlyConfig) { + // Build CPI configuration based on delegate flags + let mut input_delegate_flags = ArrayVec::new(); + for input_data in inputs.in_token_data.iter() { + input_delegate_flags.push(input_data.with_delegate != 0); + } + + let mut output_delegate_flags = ArrayVec::new(); + for output_data in inputs.out_token_data.iter() { + // Check if output has delegate (delegate index != 0 means delegate is present) + output_delegate_flags.push(output_data.delegate != 0); + } + + // Add extra output account for change account if needed (no delegate, no token data) + if inputs.with_lamports_change_account_merkle_tree_index != 0 { + output_delegate_flags.push(false); + } + + let config_input = CpiConfigInput { + input_accounts: input_delegate_flags, + output_accounts: output_delegate_flags, + has_proof: inputs.proof.is_some(), + compressed_mint: false, + compressed_mint_with_freeze_authority: false, + }; + let config = cpi_bytes_config(config_input); + (allocate_invoke_with_read_only_cpi_bytes(&config), config) +} diff --git 
a/programs/compressed-token/program/src/multi_transfer/instruction_data.rs b/programs/compressed-token/program/src/multi_transfer/instruction_data.rs new file mode 100644 index 0000000000..abd8dac35c --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/instruction_data.rs @@ -0,0 +1,110 @@ +use std::fmt::Debug; + +use anchor_lang::{prelude::ProgramError, AnchorDeserialize, AnchorSerialize}; +use light_compressed_account::instruction_data::{ + compressed_proof::CompressedProof, cpi_context::CompressedCpiContext, +}; +use light_sdk::instruction::PackedMerkleContext; +use light_zero_copy::{ZeroCopy, ZeroCopyMut}; + +#[derive(Debug, Clone, Default, AnchorSerialize, AnchorDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct MultiInputTokenDataWithContext { + pub amount: u64, + pub merkle_context: PackedMerkleContext, + pub root_index: u16, + // From remaining accounts. + pub mint: u8, + pub owner: u8, + pub with_delegate: bool, + // Only used if with_delegate is true + pub delegate: u8, + // // Only used if with_delegate is true + // pub delegate_change_account: u8, + // pub lamports: Option, move into separate vector to opt zero copy + // pub tlv: Option>, move into separate vector to opt zero copy +} + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Eq, + AnchorSerialize, + AnchorDeserialize, + ZeroCopy, + ZeroCopyMut, +)] +pub struct MultiTokenTransferOutputData { + pub owner: u8, + pub amount: u64, + pub merkle_tree: u8, + pub delegate: u8, + pub mint: u8, +} + +#[derive( + Clone, Copy, Debug, PartialEq, Eq, AnchorSerialize, AnchorDeserialize, ZeroCopy, ZeroCopyMut, +)] +pub struct Compression { + pub amount: u64, + pub is_compress: bool, + pub mint: u8, + pub source_or_recipient: u8, +} + +// #[derive( +// Clone, Copy, Debug, PartialEq, Eq, AnchorSerialize, AnchorDeserialize, ZeroCopy, ZeroCopyMut, +// )] +// pub struct MultiTokenTransferDelegateOutputData { +// pub delegate: u8, +// pub owner: u8, +// pub amount: u64, +// pub 
merkle_tree: u8, +// } + +#[derive(Debug, Clone, AnchorSerialize, AnchorDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct CompressedTokenInstructionDataMultiTransfer { + pub with_transaction_hash: bool, + pub with_lamports_change_account_merkle_tree_index: bool, + // Set zero if unused + pub lamports_change_account_merkle_tree_index: u8, + pub lamports_change_account_owner_index: u8, + pub proof: Option, + pub in_token_data: Vec, + pub out_token_data: Vec, + // pub delegate_out_token_data: Option>, + // put accounts with lamports first, stop adding values after TODO: only access by get to prevent oob errors + pub in_lamports: Option>, + // TODO: put accounts with lamports first, stop adding values after TODO: only access by get to prevent oob errors + pub out_lamports: Option>, + // TODO: put accounts with tlv first, stop adding values after TODO: only access by get to prevent oob errors + pub in_tlv: Option>>, + pub out_tlv: Option>>, + pub compressions: Option>, + pub cpi_context: Option, +} + +/// Validate instruction data consistency (lamports and TLV checks) +pub fn validate_instruction_data( + inputs: &ZCompressedTokenInstructionDataMultiTransfer, +) -> Result<(), ProgramError> { + if let Some(ref in_lamports) = inputs.in_lamports { + if in_lamports.len() > inputs.in_token_data.len() { + unimplemented!("Tlv is unimplemented"); + } + } + if let Some(ref out_lamports) = inputs.out_lamports { + if out_lamports.len() > inputs.out_token_data.len() { + unimplemented!("Tlv is unimplemented"); + } + } + if inputs.in_tlv.is_some() { + unimplemented!("Tlv is unimplemented"); + } + if inputs.out_tlv.is_some() { + unimplemented!("Tlv is unimplemented"); + } + Ok(()) +} diff --git a/programs/compressed-token/program/src/multi_transfer/mod.rs b/programs/compressed-token/program/src/multi_transfer/mod.rs new file mode 100644 index 0000000000..b726111d42 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/mod.rs @@ -0,0 +1,9 @@ +pub mod accounts; +pub mod 
assign_inputs; +pub mod assign_outputs; +pub mod change_account; +pub mod cpi; +pub mod instruction_data; +pub mod native_compression; +pub mod processor; +pub mod sum_check; diff --git a/programs/compressed-token/program/src/multi_transfer/native_compression.rs b/programs/compressed-token/program/src/multi_transfer/native_compression.rs new file mode 100644 index 0000000000..4cd30a4491 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/native_compression.rs @@ -0,0 +1,74 @@ +use anchor_lang::prelude::ProgramError; +use pinocchio::{account_info::AccountInfo, msg}; +use spl_pod::bytemuck::pod_from_bytes_mut; +use spl_token_2022::pod::PodAccount; + +use crate::multi_transfer::{ + accounts::MultiTransferPackedAccounts, + instruction_data::{ZCompressedTokenInstructionDataMultiTransfer, ZCompression}, +}; +use crate::LIGHT_CPI_SIGNER; +const ID: &[u8; 32] = &LIGHT_CPI_SIGNER.program_id; +/// Process native compressions/decompressions with token accounts +pub fn process_token_compression( + inputs: &ZCompressedTokenInstructionDataMultiTransfer, + packed_accounts: &MultiTransferPackedAccounts, +) -> Result<(), ProgramError> { + if let Some(compressions) = inputs.compressions.as_ref() { + for compression in compressions { + let source_or_recipient = packed_accounts.get_u8(compression.source_or_recipient)?; + use anchor_lang::solana_program::log::msg; + msg!( + "source_or_recipient: {:?}", + solana_pubkey::Pubkey::new_from_array(*source_or_recipient.key()) + ); + + match unsafe { source_or_recipient.owner() } { + ID => { + process_native_compressions(compression, source_or_recipient)?; + } + _ => return Err(ProgramError::InvalidInstructionData), + } + } + } + Ok(()) +} + +/// Process compression/decompression for token accounts using zero-copy PodAccount +fn process_native_compressions( + compression: &ZCompression, + token_account_info: &AccountInfo, +) -> Result<(), ProgramError> { + msg!("process_native_compressions"); + + // Access token account 
data as mutable bytes + let mut token_account_data = token_account_info + .try_borrow_mut_data() + .map_err(|_| ProgramError::AccountBorrowFailed)?; + msg!("pre pod"); + // Use zero-copy PodAccount to access the token account + let pod_account = pod_from_bytes_mut::(&mut token_account_data) + .map_err(|e| ProgramError::Custom(u64::from(e) as u32))?; + msg!(format!("pod_account {:?}", pod_account).as_str()); + + // Get current balance + let current_balance: u64 = pod_account.amount.into(); + + // Update balance based on compression type + let new_balance = if compression.is_compress() { + // Compress: subtract balance (tokens are being compressed) + current_balance + .checked_sub(compression.amount.into()) + .ok_or(ProgramError::InsufficientFunds)? + } else { + // Decompress: add balance (tokens are being decompressed) + current_balance + .checked_add(compression.amount.into()) + .ok_or(ProgramError::ArithmeticOverflow)? + }; + + // Update the balance in the pod account + pod_account.amount = new_balance.into(); + + Ok(()) +} diff --git a/programs/compressed-token/program/src/multi_transfer/processor.rs b/programs/compressed-token/program/src/multi_transfer/processor.rs new file mode 100644 index 0000000000..b16363a613 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/processor.rs @@ -0,0 +1,187 @@ +use anchor_lang::prelude::ProgramError; +use light_compressed_account::instruction_data::with_readonly::InstructionDataInvokeCpiWithReadOnly; +use light_heap::{bench_sbf_end, bench_sbf_start}; +use light_zero_copy::{borsh::Deserialize, ZeroCopyNew}; +use pinocchio::account_info::AccountInfo; + +use crate::{ + multi_transfer::{ + accounts::{MultiTransferPackedAccounts, MultiTransferValidatedAccounts}, + assign_inputs::assign_input_compressed_accounts, + assign_outputs::assign_output_compressed_accounts, + change_account::process_change_lamports, + cpi::allocate_cpi_bytes, + instruction_data::{ + validate_instruction_data, 
CompressedTokenInstructionDataMultiTransfer, + ZCompressedTokenInstructionDataMultiTransfer, + }, + native_compression::process_token_compression, + sum_check::sum_check_multi_mint, + }, + shared::{context::TokenContext, cpi::execute_cpi_invoke}, + LIGHT_CPI_SIGNER, +}; + +/// Process a token transfer instruction +/// build inputs -> sum check -> build outputs -> add token data to inputs -> invoke cpi +/// 1. Unpack compressed input accounts and input token data, this uses +/// standardized signer / delegate and will fail in proof verification in +/// case either is invalid. +/// 2. Check that compressed accounts are of same mint. +/// 3. Check that sum of input compressed accounts is equal to sum of output +/// compressed accounts +/// 4. create_output_compressed_accounts +/// 5. Serialize and add token_data data to in compressed_accounts. +/// 6. Invoke light_system_program::execute_compressed_transaction. +#[inline(always)] +pub fn process_multi_transfer( + accounts: &[AccountInfo], + instruction_data: &[u8], +) -> Result<(), ProgramError> { + // Parse instruction data first to determine optional accounts + let (inputs, _) = CompressedTokenInstructionDataMultiTransfer::zero_copy_at(instruction_data) + .map_err(ProgramError::from)?; + + let total_input_lamports = if let Some(inputs) = inputs.in_lamports.as_ref() { + inputs.iter().map(|input| u64::from(**input)).sum() + } else { + 0 + }; + let total_output_lamports = if let Some(inputs) = inputs.out_lamports.as_ref() { + inputs.iter().map(|input| u64::from(**input)).sum() + } else { + 0 + }; + // Determine optional account flags from instruction data + let with_sol_pool = total_input_lamports != total_output_lamports; + let with_cpi_context = inputs.cpi_context.is_some(); + + // Skip first account (light-system-program) and validate remaining accounts + let (validated_accounts, packed_accounts) = MultiTransferValidatedAccounts::validate_and_parse( + &accounts[1..], + with_sol_pool, + with_cpi_context, + )?; + use 
anchor_lang::solana_program::msg; + // Validate instruction data consistency + validate_instruction_data(&inputs)?; + msg!("validate_instruction_data"); + bench_sbf_start!("t_context_and_check_sig"); + anchor_lang::solana_program::log::msg!("inputs {:?}", inputs); + + // Create TokenContext for hash caching + let mut context = TokenContext::new(); + + // Allocate CPI bytes and create zero-copy structure + let (mut cpi_bytes, config) = allocate_cpi_bytes(&inputs); + + let (mut cpi_instruction_struct, _) = + InstructionDataInvokeCpiWithReadOnly::new_zero_copy(&mut cpi_bytes[8..], config) + .map_err(ProgramError::from)?; + + // Set CPI signer information + cpi_instruction_struct.bump = LIGHT_CPI_SIGNER.bump; + cpi_instruction_struct.invoking_program_id = LIGHT_CPI_SIGNER.program_id.into(); + msg!("pre assign_input_compressed_accounts"); + + // Process input compressed accounts + assign_input_compressed_accounts( + &mut cpi_instruction_struct, + &mut context, + &inputs, + &packed_accounts, + )?; + msg!("pre sum_check_multi_mint"); + bench_sbf_end!("t_context_and_check_sig"); + bench_sbf_start!("t_sum_check"); + sum_check_multi_mint( + &inputs.in_token_data, + &inputs.out_token_data, + inputs.compressions.as_deref(), + ) + .map_err(|e| ProgramError::Custom(e as u32))?; + bench_sbf_end!("t_sum_check"); + msg!("pre assign_output_compressed_accounts"); + + // Process output compressed accounts + assign_output_compressed_accounts( + &mut cpi_instruction_struct, + &mut context, + &inputs, + &packed_accounts, + )?; + bench_sbf_end!("t_create_output_compressed_accounts"); + + msg!("pre process_change_lamports"); + process_change_lamports( + &inputs, + &packed_accounts, + cpi_instruction_struct, + total_input_lamports, + total_output_lamports, + )?; + // Process token compressions/decompressions + // TODO: support spl + process_token_compression(&inputs, &packed_accounts)?; + + // Extract tree accounts using highest index approach + let (tree_accounts, tree_accounts_count) = 
extract_tree_accounts(&inputs, &packed_accounts); + + // Calculate static accounts count after skipping index 0 (system accounts only) + let static_accounts_count = + 8 + if with_sol_pool { 2 } else { 0 } + if with_cpi_context { 1 } else { 0 }; + + // Include static CPI accounts + tree accounts based on highest index + let cpi_accounts_end = 1 + static_accounts_count + tree_accounts_count; + let cpi_accounts = &accounts[1..cpi_accounts_end]; + let solana_tree_accounts = tree_accounts + .iter() + .map(|&x| solana_pubkey::Pubkey::new_from_array(*x)) + .collect::>(); + msg!("solana_tree_accounts {:?}", solana_tree_accounts); + let _cpi_accounts = cpi_accounts + .iter() + .map(|x| solana_pubkey::Pubkey::new_from_array(*x.key())) + .collect::>(); + msg!("cpi_accounts {:?}", _cpi_accounts); + // Execute CPI call to light-system-program + execute_cpi_invoke( + cpi_accounts, + cpi_bytes, + tree_accounts.as_slice(), + with_sol_pool, + validated_accounts.cpi_context_account.map(|x| *x.key()), + )?; + + Ok(()) +} + +/// Extract tree accounts by finding the highest tree index and using it as closing offset +fn extract_tree_accounts<'a>( + inputs: &ZCompressedTokenInstructionDataMultiTransfer, + packed_accounts: &'a MultiTransferPackedAccounts<'a>, +) -> (Vec<&'a pinocchio::pubkey::Pubkey>, usize) { + // Find highest tree index from input and output data to determine tree accounts range + let mut highest_tree_index = 0u8; + for input_data in inputs.in_token_data.iter() { + highest_tree_index = + highest_tree_index.max(input_data.merkle_context.merkle_tree_pubkey_index); + highest_tree_index = highest_tree_index.max(input_data.merkle_context.queue_pubkey_index); + } + for output_data in inputs.out_token_data.iter() { + highest_tree_index = highest_tree_index.max(output_data.merkle_tree); + } + + // Tree accounts span from index 0 to highest_tree_index in remaining accounts + let tree_accounts_count = (highest_tree_index + 1) as usize; + + // Extract tree account pubkeys from the 
determined range + let mut tree_accounts = Vec::new(); + for i in 0..tree_accounts_count { + if let Some(account) = packed_accounts.accounts.get(i) { + tree_accounts.push(account.key()); + } + } + + (tree_accounts, tree_accounts_count) +} diff --git a/programs/compressed-token/program/src/multi_transfer/sum_check.rs b/programs/compressed-token/program/src/multi_transfer/sum_check.rs new file mode 100644 index 0000000000..804c56eed1 --- /dev/null +++ b/programs/compressed-token/program/src/multi_transfer/sum_check.rs @@ -0,0 +1,132 @@ +use anchor_compressed_token::ErrorCode; +use arrayvec::ArrayVec; + +use crate::multi_transfer::instruction_data::{ + ZCompression, ZMultiInputTokenDataWithContext, ZMultiTokenTransferOutputData, +}; + +/// Process inputs and add amounts to mint sums with order validation +#[inline(always)] +fn sum_inputs( + inputs: &[ZMultiInputTokenDataWithContext], + mint_sums: &mut ArrayVec<(u8, u64), 5>, +) -> Result<(), ErrorCode> { + let mut prev_mint_index = 0u8; + for (i, input) in inputs.iter().enumerate() { + let mint_index = input.mint; + + // Validate incremental order + if i > 0 && mint_index < prev_mint_index { + return Err(ErrorCode::InputsOutOfOrder); + } + + // Find or create mint entry + if let Some(entry) = mint_sums.iter_mut().find(|(idx, _)| *idx == mint_index) { + entry.1 = entry + .1 + .checked_add(input.amount.into()) + .ok_or(ErrorCode::ComputeInputSumFailed)?; + } else { + if mint_sums.is_full() { + return Err(ErrorCode::TooManyMints); + } + mint_sums.push((mint_index, input.amount.into())); + } + + prev_mint_index = mint_index; + } + Ok(()) +} + +/// Process compressions and adjust mint sums (add for compress, subtract for decompress) +#[inline(always)] +fn sum_compressions( + compressions: &[ZCompression], + mint_sums: &mut ArrayVec<(u8, u64), 5>, +) -> Result<(), ErrorCode> { + for compression in compressions.iter() { + let mint_index = compression.mint; + + // Find mint entry (create if doesn't exist for compression) + if 
let Some(entry) = mint_sums.iter_mut().find(|(idx, _)| *idx == mint_index) { + if compression.is_compress() { + // Compress: add to balance + entry.1 = entry + .1 + .checked_add(compression.amount.into()) + .ok_or(ErrorCode::ComputeCompressSumFailed)?; + } else { + // Decompress: subtract from balance + entry.1 = entry + .1 + .checked_sub(compression.amount.into()) + .ok_or(ErrorCode::ComputeDecompressSumFailed)?; + } + } else { + // Create new entry if compressing + if compression.is_compress() { + if mint_sums.is_full() { + return Err(ErrorCode::TooManyMints); + } + mint_sums.push((mint_index, compression.amount.into())); + } else { + // Cannot decompress if no balance exists + return Err(ErrorCode::SumCheckFailed); + } + } + } + Ok(()) +} + +/// Process outputs and subtract amounts from mint sums +#[inline(always)] +fn sum_outputs( + outputs: &[ZMultiTokenTransferOutputData], + mint_sums: &mut ArrayVec<(u8, u64), 5>, +) -> Result<(), ErrorCode> { + for output in outputs.iter() { + let mint_index = output.mint; + + // Find mint entry (create if doesn't exist for output-only mints) + if let Some(entry) = mint_sums.iter_mut().find(|(idx, _)| *idx == mint_index) { + entry.1 = entry + .1 + .checked_sub(output.amount.into()) + .ok_or(ErrorCode::ComputeOutputSumFailed)?; + } else { + // Output mint not in inputs or compressions - invalid + return Err(ErrorCode::ComputeOutputSumFailed); + } + } + Ok(()) +} + +/// Sum check for multi-mint transfers with ordered mint validation and compression support +pub fn sum_check_multi_mint( + inputs: &[ZMultiInputTokenDataWithContext], + outputs: &[ZMultiTokenTransferOutputData], + compressions: Option<&[ZCompression]>, +) -> Result<(), ErrorCode> { + // ArrayVec with 5 entries: (mint_index, sum) + let mut mint_sums: ArrayVec<(u8, u64), 5> = ArrayVec::new(); + + // Process inputs - increase sums + sum_inputs(inputs, &mut mint_sums)?; + + // Process compressions if present + if let Some(compressions) = compressions { + 
sum_compressions(compressions, &mut mint_sums)?; + } + + // Process outputs - decrease sums + sum_outputs(outputs, &mut mint_sums)?; + + // Verify all sums are zero + for (_, sum) in mint_sums.iter() { + if *sum != 0 { + return Err(ErrorCode::SumCheckFailed); + } + } + + Ok(()) +} diff --git a/programs/compressed-token/program/src/shared/context.rs b/programs/compressed-token/program/src/shared/context.rs new file mode 100644 index 0000000000..1048281b72 --- /dev/null +++ b/programs/compressed-token/program/src/shared/context.rs @@ -0,0 +1,60 @@ +use anchor_lang::solana_program::program_error::ProgramError; +use arrayvec::ArrayVec; +use light_compressed_account::hash_to_bn254_field_size_be; +use pinocchio::pubkey::Pubkey; + +/// Context for caching hashed values to avoid recomputation +pub struct TokenContext { + /// Cache for mint hashes: (mint_pubkey, hashed_mint) + pub hashed_mints: ArrayVec<(Pubkey, [u8; 32]), 5>, + /// Cache for pubkey hashes: (pubkey, hashed_pubkey) + pub hashed_pubkeys: Vec<(Pubkey, [u8; 32])>, +} + +impl TokenContext { + /// Create a new empty context + pub fn new() -> Self { + Self { + hashed_mints: ArrayVec::new(), + hashed_pubkeys: Vec::new(), + } + } + + /// Get or compute hash for a mint pubkey + pub fn get_or_hash_mint(&mut self, mint: &Pubkey) -> Result<[u8; 32], ProgramError> { + let hashed_mint = self.hashed_mints.iter().find(|a| &a.0 == mint).map(|a| a.1); + match hashed_mint { + Some(hashed_mint) => Ok(hashed_mint), + None => { + let hashed_mint = hash_to_bn254_field_size_be(mint); + self.hashed_mints + .try_push((*mint, hashed_mint)) + .map_err(|_| ProgramError::InvalidAccountData)?; + Ok(hashed_mint) + } + } + } + + /// Get or compute hash for a pubkey (owner, delegate, etc.) 
+ pub fn get_or_hash_pubkey(&mut self, pubkey: &Pubkey) -> [u8; 32] { + let hashed_pubkey = self + .hashed_pubkeys + .iter() + .find(|a| &a.0 == pubkey) + .map(|a| a.1); + match hashed_pubkey { + Some(hashed_pubkey) => hashed_pubkey, + None => { + let hashed_pubkey = hash_to_bn254_field_size_be(pubkey); + self.hashed_pubkeys.push((*pubkey, hashed_pubkey)); + hashed_pubkey + } + } + } +} + +impl Default for TokenContext { + fn default() -> Self { + Self::new() + } +} diff --git a/programs/compressed-token/program/src/shared/cpi.rs b/programs/compressed-token/program/src/shared/cpi.rs new file mode 100644 index 0000000000..d9473a354b --- /dev/null +++ b/programs/compressed-token/program/src/shared/cpi.rs @@ -0,0 +1,186 @@ +use std::mem::MaybeUninit; + +use account_compression::utils::constants::NOOP_PUBKEY; +use anchor_lang::solana_program::program_error::ProgramError; +use light_sdk_types::{ + ACCOUNT_COMPRESSION_AUTHORITY_PDA, ACCOUNT_COMPRESSION_PROGRAM_ID, CPI_AUTHORITY_PDA_SEED, + LIGHT_SYSTEM_PROGRAM_ID, REGISTERED_PROGRAM_PDA, +}; +use pinocchio::{ + account_info::{AccountInfo, BorrowState}, + cpi::{invoke_signed_unchecked, MAX_CPI_ACCOUNTS}, + instruction::{Account, AccountMeta, Instruction, Seed, Signer}, + msg, + pubkey::Pubkey, +}; + +use crate::LIGHT_CPI_SIGNER; + +/// Generalized CPI function for invoking light-system-program +/// +/// This function builds the standard account meta structure for light-system-program CPI +/// and appends dynamic tree accounts (merkle trees, queues, etc.) to the account metas. 
+/// +/// # Arguments +/// * `accounts` - All account infos passed to the instruction +/// * `cpi_bytes` - The CPI instruction data bytes +/// * `tree_accounts` - Slice of tree account pubkeys to append (will be marked as mutable) +/// * `sol_pool_pda` - Optional sol pool PDA pubkey +/// * `cpi_context_account` - Optional CPI context account pubkey +/// +/// # Returns +/// * `Result<(), ProgramError>` - Success or error from the CPI call +pub fn execute_cpi_invoke( + accounts: &[AccountInfo], + cpi_bytes: Vec, + tree_accounts: &[&Pubkey], + with_sol_pool: bool, + cpi_context_account: Option, +) -> Result<(), ProgramError> { + // Build account metas with capacity for standard accounts + dynamic tree accounts + let capacity = 11 + tree_accounts.len(); // 11 standard accounts + dynamic tree accounts + let mut account_metas = Vec::with_capacity(capacity); + + // Standard account metas for light-system-program CPI + // Account order must match light-system program's InvokeCpiInstruction expectation: + // 0: fee_payer, 1: authority, 2: registered_program_pda, 3: noop_program, + // 4: account_compression_authority, 5: account_compression_program, 6: invoking_program, + // 7: sol_pool_pda (optional), 8: decompression_recipient (optional), 9: system_program, + // 10: cpi_context_account (optional), then remaining accounts (merkle trees, etc.) 
+ let inner_pool = + solana_pubkey::pubkey!("CHK57ywWSDncAoRu1F8QgwYJeXuAJyyBYT4LixLXvMZ1").to_bytes(); + let sol_pool_pda = if with_sol_pool { + AccountMeta::new(&inner_pool, true, false) + } else { + AccountMeta::new(&LIGHT_SYSTEM_PROGRAM_ID, false, false) + }; + account_metas.extend_from_slice(&[ + AccountMeta::new(accounts[0].key(), true, true), // 0 fee_payer (signer, mutable) + AccountMeta::new(&LIGHT_CPI_SIGNER.cpi_signer, false, true), // 1 authority (cpi_authority_pda) + AccountMeta::new(®ISTERED_PROGRAM_PDA, false, false), // 2 registered_program_pda + AccountMeta::new(&NOOP_PUBKEY, false, false), // 3 noop_program + AccountMeta::new(&ACCOUNT_COMPRESSION_AUTHORITY_PDA, false, false), // 4 account_compression_authority + AccountMeta::new(&ACCOUNT_COMPRESSION_PROGRAM_ID, false, false), // 5 account_compression_program + AccountMeta::new(&LIGHT_CPI_SIGNER.program_id, false, false), // 6 invoking_program (self_program) + sol_pool_pda, // 7 sol_pool_pda + AccountMeta::new(&LIGHT_SYSTEM_PROGRAM_ID, false, false), // 8 decompression_recipient (None, using default) + AccountMeta::new(&[0u8; 32], false, false), // system_program + AccountMeta::new( + if let Some(cpi_context) = cpi_context_account.as_ref() { + cpi_context + } else { + &LIGHT_SYSTEM_PROGRAM_ID + }, + false, + false, + ), // cpi_context_account + ]); + + // Append dynamic tree accounts (merkle trees, queues, etc.) 
as mutable accounts + for tree_account in tree_accounts { + account_metas.push(AccountMeta::new(tree_account, true, false)); + } + msg!( + "account_metas {:?}", + account_metas + .iter() + .map(|meta| solana_pubkey::Pubkey::new_from_array(*meta.pubkey)) + .collect::>() + ); + let instruction = Instruction { + program_id: &LIGHT_SYSTEM_PROGRAM_ID, + accounts: account_metas.as_slice(), + data: cpi_bytes.as_slice(), + }; + + // Use the precomputed CPI signer and bump from the config + let bump_seed = [LIGHT_CPI_SIGNER.bump]; + let seed_array = [ + Seed::from(CPI_AUTHORITY_PDA_SEED), + Seed::from(bump_seed.as_slice()), + ]; + let signer = Signer::from(&seed_array); + + match slice_invoke_signed(&instruction, accounts, &[signer]) { + Ok(()) => {} + Err(e) => { + msg!(format!("slice_invoke_signed failed: {:?}", e).as_str()); + return Err(ProgramError::InvalidArgument); + } + } + + Ok(()) +} + +#[inline] +pub fn slice_invoke_signed( + instruction: &Instruction, + account_infos: &[AccountInfo], + signers_seeds: &[Signer], +) -> pinocchio::ProgramResult { + use pinocchio::program_error::ProgramError; + if instruction.accounts.len() < account_infos.len() { + return Err(ProgramError::NotEnoughAccountKeys); + } + + if account_infos.len() > MAX_CPI_ACCOUNTS { + return Err(ProgramError::InvalidArgument); + } + + const UNINIT: MaybeUninit = MaybeUninit::::uninit(); + let mut accounts = [UNINIT; MAX_CPI_ACCOUNTS]; + let mut len = 0; + + for (account_info, account_meta) in account_infos.iter().zip( + instruction + .accounts + .iter() + .filter(|x| x.pubkey != instruction.program_id), + ) { + if account_info.key() != account_meta.pubkey { + use std::format; + msg!(format!( + "Received account key: {:?}", + solana_pubkey::Pubkey::new_from_array(*account_info.key()) + ) + .as_str()); + msg!(format!( + "Expected account key: {:?}", + solana_pubkey::Pubkey::new_from_array(*account_meta.pubkey) + ) + .as_str()); + + return Err(ProgramError::InvalidArgument); + } + + let state = if 
account_meta.is_writable { + BorrowState::Borrowed + } else { + BorrowState::MutablyBorrowed + }; + + if account_info.is_borrowed(state) { + return Err(ProgramError::AccountBorrowFailed); + } + + // SAFETY: The number of accounts has been validated to be less than + // `MAX_CPI_ACCOUNTS`. + unsafe { + accounts + .get_unchecked_mut(len) + .write(Account::from(account_info)); + } + + len += 1; + } + // SAFETY: The accounts have been validated. + unsafe { + invoke_signed_unchecked( + instruction, + core::slice::from_raw_parts(accounts.as_ptr() as _, len), + signers_seeds, + ); + } + + Ok(()) +} diff --git a/programs/compressed-token/program/src/shared/cpi_bytes_size.rs b/programs/compressed-token/program/src/shared/cpi_bytes_size.rs new file mode 100644 index 0000000000..fb303a2958 --- /dev/null +++ b/programs/compressed-token/program/src/shared/cpi_bytes_size.rs @@ -0,0 +1,143 @@ +use anchor_lang::Discriminator; +use arrayvec::ArrayVec; +use light_compressed_account::{ + compressed_account::{ + CompressedAccountConfig, CompressedAccountDataConfig, PackedMerkleContextConfig, + }, + instruction_data::{ + compressed_proof::CompressedProofConfig, + cpi_context::CompressedCpiContextConfig, + data::OutputCompressedAccountWithPackedContextConfig, + with_readonly::{ + InAccountConfig, InstructionDataInvokeCpiWithReadOnly, + InstructionDataInvokeCpiWithReadOnlyConfig, + }, + }, +}; +use light_zero_copy::ZeroCopyNew; + +const MAX_INPUT_ACCOUNTS: usize = 8; +const MAX_OUTPUT_ACCOUNTS: usize = 35; + +#[derive(Debug, Clone)] +pub struct CpiConfigInput { + pub input_accounts: ArrayVec, // Per-input account delegate flag + pub output_accounts: ArrayVec, // Per-output account delegate flag + pub has_proof: bool, + pub compressed_mint: bool, + pub compressed_mint_with_freeze_authority: bool, +} + +impl CpiConfigInput { + /// Helper to create config for mint_to_compressed with no delegates + pub fn mint_to_compressed( + num_recipients: usize, + has_proof: bool, + 
compressed_mint_with_freeze_authority: bool, + ) -> Self { + let mut output_delegates = ArrayVec::new(); + for _ in 0..num_recipients { + output_delegates.push(false); // No delegates for simple mint + } + + Self { + input_accounts: ArrayVec::new(), // No input accounts for mint_to_compressed + output_accounts: output_delegates, + has_proof, + compressed_mint: true, + compressed_mint_with_freeze_authority, + } + } +} + +// TODO: add version of this function with hardcoded values that just calculates the cpi_byte_size, with a randomized test vs this function +pub fn cpi_bytes_config(input: CpiConfigInput) -> InstructionDataInvokeCpiWithReadOnlyConfig { + let input_compressed_accounts = { + let mut inputs_capacity = input.input_accounts.len(); + if input.compressed_mint { + inputs_capacity += 1; + } + let mut input_compressed_accounts = Vec::with_capacity(inputs_capacity); + + // Add regular input accounts (token accounts) + for _ in input.input_accounts { + input_compressed_accounts.push(InAccountConfig { + merkle_context: PackedMerkleContextConfig {}, // Default merkle context + address: (false, ()), // Token accounts don't have addresses + }); + } + + // Add compressed mint input account if needed + if input.compressed_mint { + input_compressed_accounts.push(InAccountConfig { + merkle_context: PackedMerkleContextConfig {}, // Default merkle context + address: (true, ()), + }); + } + + input_compressed_accounts + }; + + let output_compressed_accounts = { + { + let total_outputs = input.output_accounts.len() + if input.has_proof { 1 } else { 0 }; + let mut outputs = Vec::with_capacity(total_outputs); + for has_delegate in input.output_accounts { + let token_data_size = if has_delegate { 107 } else { 75 }; // 75 + 32 (delegate) = 107 + + outputs.push(OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (false, ()), // Token accounts don't have addresses + data: ( + true, + CompressedAccountDataConfig { + data: 
token_data_size, // Size depends on delegate: 75 without, 107 with + }, + ), + }, + }); + } + + // Add compressed mint update if needed (last output account) + if input.compressed_mint { + use crate::mint::state::{CompressedMint, CompressedMintConfig}; + let mint_size_config = CompressedMintConfig { + mint_authority: (input.compressed_mint, ()), + freeze_authority: (input.compressed_mint_with_freeze_authority, ()), + }; + outputs.push(OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (true, ()), // Compressed mint has an address + data: ( + true, + CompressedAccountDataConfig { + data: CompressedMint::byte_len(&mint_size_config) as u32, + }, + ), + }, + }); + } + outputs + } + }; + InstructionDataInvokeCpiWithReadOnlyConfig { + cpi_context: CompressedCpiContextConfig {}, + proof: (input.has_proof, CompressedProofConfig {}), + new_address_params: vec![], // No new addresses for mint_to_compressed + input_compressed_accounts, + output_compressed_accounts, + read_only_addresses: vec![], + read_only_accounts: vec![], + } +} + +/// Allocate CPI instruction bytes with discriminator and length prefix +pub fn allocate_invoke_with_read_only_cpi_bytes( + config: &InstructionDataInvokeCpiWithReadOnlyConfig, +) -> Vec { + let vec_len = InstructionDataInvokeCpiWithReadOnly::byte_len(config); + let mut cpi_bytes = vec![0u8; vec_len + 8]; + cpi_bytes[0..8] + .copy_from_slice(light_system_program::instruction::InvokeCpiWithReadOnly::DISCRIMINATOR); + cpi_bytes +} diff --git a/programs/compressed-token/program/src/shared/initialize_token_account.rs b/programs/compressed-token/program/src/shared/initialize_token_account.rs new file mode 100644 index 0000000000..8fdc8653d8 --- /dev/null +++ b/programs/compressed-token/program/src/shared/initialize_token_account.rs @@ -0,0 +1,29 @@ +use anchor_lang::prelude::ProgramError; +use light_account_checks::AccountInfoTrait; +use pinocchio::account_info::AccountInfo; +use 
spl_pod::bytemuck::pod_from_bytes_mut; +use spl_token_2022::pod::PodAccount; +use spl_token_2022::state::AccountState; + +/// Initialize a token account using spl-pod with zero balance and default settings +pub fn initialize_token_account( + token_account_info: &AccountInfo, + mint_pubkey: &[u8; 32], + owner_pubkey: &[u8; 32], +) -> Result<(), ProgramError> { + // Access the token account data as mutable bytes + let mut token_account_data = AccountInfoTrait::try_borrow_mut_data(token_account_info) + .map_err(|_| ProgramError::InvalidAccountData)?; + + // Use zero-copy PodAccount to initialize the token account + let pod_account = pod_from_bytes_mut::(&mut token_account_data) + .map_err(|_| ProgramError::InvalidAccountData)?; + + // Initialize the token account fields + pod_account.mint = solana_pubkey::Pubkey::from(*mint_pubkey); + pod_account.owner = solana_pubkey::Pubkey::from(*owner_pubkey); + pod_account.delegate = spl_token_2022::pod::PodCOption::none(); // No delegate + pod_account.state = AccountState::Initialized as u8; // Set to Initialized state + + Ok(()) +} diff --git a/programs/compressed-token/program/src/shared/inputs.rs b/programs/compressed-token/program/src/shared/inputs.rs new file mode 100644 index 0000000000..e23499786b --- /dev/null +++ b/programs/compressed-token/program/src/shared/inputs.rs @@ -0,0 +1,110 @@ +use anchor_compressed_token::token_data::TokenData; +use anchor_lang::solana_program::program_error::ProgramError; +use light_account_checks::checks::check_signer; +use light_compressed_account::instruction_data::with_readonly::ZInAccountMut; +use pinocchio::account_info::AccountInfo; + +use super::context::TokenContext; +use crate::{ + constants::TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR, + multi_transfer::instruction_data::ZMultiInputTokenDataWithContext, +}; + +/// Creates an input compressed account using zero-copy patterns and index-based account lookup. 
+/// +/// Validates signer authorization (owner or delegate), populates the zero-copy account structure, +/// and computes the appropriate token data hash based on frozen state. +pub fn create_input_compressed_account( + input_compressed_account: &mut ZInAccountMut, + context: &mut TokenContext, + input_token_data: &ZMultiInputTokenDataWithContext, + remaining_accounts: &[AccountInfo], + lamports: u64, +) -> std::result::Result<(), ProgramError> { + anchor_lang::solana_program::msg!("create_input_compressed_account"); + anchor_lang::solana_program::msg!("remaining_accounts len {}", remaining_accounts.len()); + // Get owner from remaining accounts using the owner index + let owner_account = &remaining_accounts[input_token_data.owner as usize]; + let owner = *owner_account.key(); + anchor_lang::solana_program::msg!("owner_account"); + + // Verify signer authorization using light-account-checks + let hashed_delegate = if input_token_data.with_delegate() { + // If delegate is used, delegate must be signer + let delegate_account = &remaining_accounts[input_token_data.delegate as usize]; + + check_signer(delegate_account).map_err(|e| { + anchor_lang::solana_program::msg!( + "Delegate signer: {:?}", + solana_pubkey::Pubkey::new_from_array(*delegate_account.key()) + ); + anchor_lang::solana_program::msg!("Delegate signer check failed: {:?}", e); + ProgramError::from(e) + })?; + Some(context.get_or_hash_pubkey(delegate_account.key())) + } else { + // If no delegate, owner must be signer + + check_signer(owner_account).map_err(|e| { + anchor_lang::solana_program::msg!( + "Checking owner signer: {:?}", + solana_pubkey::Pubkey::new_from_array(*owner_account.key()) + ); + anchor_lang::solana_program::msg!("Owner signer check failed: {:?}", e); + ProgramError::from(e) + })?; + None + }; + + // Create ZInAccountMut with proper fields + input_compressed_account.lamports.set(lamports); + input_compressed_account.discriminator = TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR; + // Set 
merkle context fields manually due to mutability constraints + input_compressed_account + .merkle_context + .merkle_tree_pubkey_index = input_token_data.merkle_context.merkle_tree_pubkey_index; + input_compressed_account.merkle_context.queue_pubkey_index = + input_token_data.merkle_context.queue_pubkey_index; + input_compressed_account + .merkle_context + .leaf_index + .set(input_token_data.merkle_context.leaf_index.into()); + input_compressed_account.merkle_context.prove_by_index = + input_token_data.merkle_context.prove_by_index; + input_compressed_account + .root_index + .set(input_token_data.root_index.get()); + input_compressed_account.address = None; + + // TLV handling is now done separately in the parent instruction data + // Compute data hash using TokenContext for caching + let hashed_owner = context.get_or_hash_pubkey(&owner); + + // Get mint hash from context + let mint_account = &remaining_accounts[input_token_data.mint as usize]; + let hashed_mint = context.get_or_hash_mint(mint_account.key())?; + + let mut amount_bytes = [0u8; 32]; + amount_bytes[24..].copy_from_slice(input_token_data.amount.get().to_be_bytes().as_slice()); + + // Use appropriate hash function based on frozen state + input_compressed_account.data_hash = if !IS_FROZEN { + TokenData::hash_with_hashed_values( + &hashed_mint, + &hashed_owner, + &amount_bytes, + &hashed_delegate.as_ref(), + ) + .map_err(ProgramError::from)? + } else { + TokenData::hash_frozen_with_hashed_values( + &hashed_mint, + &hashed_owner, + &amount_bytes, + &hashed_delegate.as_ref(), + ) + .map_err(ProgramError::from)? 
+ }; + + Ok(()) +} diff --git a/programs/compressed-token/program/src/shared/mod.rs b/programs/compressed-token/program/src/shared/mod.rs new file mode 100644 index 0000000000..990fc5ce7a --- /dev/null +++ b/programs/compressed-token/program/src/shared/mod.rs @@ -0,0 +1,36 @@ +pub mod context; +pub mod cpi; +pub mod cpi_bytes_size; +pub mod inputs; +pub mod outputs; +pub mod initialize_token_account; + +use anchor_lang::solana_program::program_error::ProgramError; +use pinocchio::account_info::AccountInfo; + +pub struct AccountIterator<'info> { + accounts: &'info [AccountInfo], + position: usize, +} + +impl<'info> AccountIterator<'info> { + pub fn new(accounts: &'info [AccountInfo]) -> Self { + Self { + accounts, + position: 0, + } + } + + pub fn next(&mut self) -> Result<&'info AccountInfo, ProgramError> { + if self.position >= self.accounts.len() { + return Err(ProgramError::NotEnoughAccountKeys); + } + let account = &self.accounts[self.position]; + self.position += 1; + Ok(account) + } + + pub fn remaining(&self) -> &'info [AccountInfo] { + &self.accounts[self.position..] 
+ } +} diff --git a/programs/compressed-token/program/src/shared/outputs.rs b/programs/compressed-token/program/src/shared/outputs.rs new file mode 100644 index 0000000000..26bf36c65e --- /dev/null +++ b/programs/compressed-token/program/src/shared/outputs.rs @@ -0,0 +1,119 @@ +use anchor_lang::{ + prelude::borsh, solana_program::program_error::ProgramError, AnchorDeserialize, AnchorSerialize, +}; +use light_compressed_account::{ + instruction_data::data::ZOutputCompressedAccountWithPackedContextMut, Pubkey, +}; +use light_zero_copy::{num_trait::ZeroCopyNumTrait, ZeroCopyMut, ZeroCopyNew}; + +use super::context::TokenContext; + +use crate::constants::TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR; + +// Import the anchor TokenData for hash computation +use anchor_compressed_token::token_data::TokenData as AnchorTokenData; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, AnchorSerialize, AnchorDeserialize)] +#[repr(u8)] +pub enum AccountState { + Initialized, + Frozen, +} + +#[derive(Debug, PartialEq, Eq, AnchorSerialize, AnchorDeserialize, Clone, ZeroCopyMut)] +pub struct TokenData { + /// The mint associated with this account + pub mint: Pubkey, + /// The owner of this account. + pub owner: Pubkey, + /// The amount of tokens this account holds. 
+ pub amount: u64, + /// If `delegate` is `Some` then `delegated_amount` represents + /// the amount authorized by the delegate + pub delegate: Option, + /// The account's state (u8: 0 = Initialized, 1 = Frozen) + pub state: u8, + /// Placeholder for TokenExtension tlv data (unimplemented) + pub tlv: Option>, +} + +#[allow(clippy::too_many_arguments)] +pub fn create_output_compressed_account( + output_compressed_account: &mut ZOutputCompressedAccountWithPackedContextMut<'_>, + context: &mut TokenContext, + owner: Pubkey, + delegate: Option, + amount: impl ZeroCopyNumTrait, + lamports: Option, + mint_pubkey: Pubkey, + hashed_mint: &[u8; 32], + merkle_tree_index: u8, +) -> Result<(), ProgramError> { + // Get compressed account data from CPI struct + let compressed_account_data = output_compressed_account + .compressed_account + .data + .as_mut() + .ok_or(ProgramError::InvalidAccountData)?; + + // Set discriminator + compressed_account_data.discriminator = TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR; + // Create TokenData using zero-copy + { + // Create token data config based on delegate presence + let token_config: ::ZeroCopyConfig = TokenDataConfig { + delegate: (delegate.is_some(), ()), + tlv: (false, vec![]), + }; + + let (mut token_data, _) = + TokenData::new_zero_copy(compressed_account_data.data, token_config) + .map_err(ProgramError::from)?; + + // Set token data fields directly on zero-copy struct + token_data.mint = mint_pubkey; + token_data.owner = owner; + token_data.amount.set(amount.into()); + if let Some(z_delegate) = token_data.delegate.as_deref_mut() { + let delegate_pubkey = delegate.ok_or(ProgramError::InvalidAccountData)?; + *z_delegate = delegate_pubkey; + } + *token_data.state = AccountState::Initialized as u8; + } + // Compute data hash using the anchor TokenData hash_with_hashed_values method + { + let hashed_owner = context.get_or_hash_pubkey(&owner.into()); + let mut amount_bytes = [0u8; 32]; + 
amount_bytes[24..].copy_from_slice(amount.to_bytes_be().as_slice()); + + let hashed_delegate = + delegate.map(|delegate_pubkey| context.get_or_hash_pubkey(&delegate_pubkey.into())); + + let hash_result = AnchorTokenData::hash_with_hashed_values( + hashed_mint, + &hashed_owner, + &amount_bytes, + &hashed_delegate.as_ref(), + ) + .map_err(ProgramError::from)?; + compressed_account_data + .data_hash + .copy_from_slice(&hash_result); + } + + // Set other compressed account fields + { + output_compressed_account.compressed_account.owner = crate::ID.into(); + + let lamports_value = lamports.unwrap_or(0u64.into()); + output_compressed_account + .compressed_account + .lamports + .set(lamports_value.into()); + + // Set merkle tree index from parameter + *output_compressed_account.merkle_tree_index = merkle_tree_index; + } + + Ok(()) +} diff --git a/programs/compressed-token/program/tests/inputs.rs b/programs/compressed-token/program/tests/inputs.rs new file mode 100644 index 0000000000..2940272018 --- /dev/null +++ b/programs/compressed-token/program/tests/inputs.rs @@ -0,0 +1,196 @@ +use anchor_compressed_token::token_data::TokenData as AnchorTokenData; +use anchor_lang::{prelude::*, solana_program::account_info::AccountInfo}; +use arrayvec::ArrayVec; +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::instruction_data::{ + with_readonly::InAccount, with_readonly::InstructionDataInvokeCpiWithReadOnly, +}; +use light_compressed_token::{ + constants::TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR, + multi_transfer::instruction_data::MultiInputTokenDataWithContext, + shared::{ + context::TokenContext, + cpi_bytes_size::{ + allocate_invoke_with_read_only_cpi_bytes, cpi_bytes_config, CpiConfigInput, + }, + inputs::create_input_compressed_account, + }, +}; +use light_sdk::instruction::PackedMerkleContext; +use light_zero_copy::{borsh::Deserialize, ZeroCopyNew}; +use rand::Rng; + +/* +#[test] +fn test_rnd_create_input_compressed_account() { + let mut rng = 
rand::thread_rng(); + let iter = 1000; + + for _ in 0..iter { + // Generate random parameters + let mint_pubkey = Pubkey::new_from_array(rng.gen::<[u8; 32]>()); + let owner_pubkey = Pubkey::new_from_array(rng.gen::<[u8; 32]>()); + let delegate_pubkey = Pubkey::new_from_array(rng.gen::<[u8; 32]>()); + + // Random amount from 0 to u64::MAX + let amount = rng.gen::(); + let lamports = rng.gen_range(0..=1000000u64); + + // Random delegate flag (30% chance) + let with_delegate = rng.gen_bool(0.3); + + // Random merkle context fields + let merkle_tree_pubkey_index = rng.gen_range(0..=255u8); + let queue_pubkey_index = rng.gen_range(0..=255u8); + let leaf_index = rng.gen::(); + let prove_by_index = rng.gen_bool(0.5); + let root_index = rng.gen::(); + + // Create input token data + let input_token_data = MultiInputTokenDataWithContext { + amount, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index, + queue_pubkey_index, + leaf_index, + prove_by_index, + }, + root_index, + mint: 0, // mint is at index 0 in remaining_accounts + owner: 1, // owner is at index 1 in remaining_accounts + with_delegate, + delegate: if with_delegate { 2 } else { 0 }, // delegate at index 2 if present + }; + + // Serialize and get zero-copy reference + let input_data = input_token_data.try_to_vec().unwrap(); + let (z_input_data, _) = MultiInputTokenDataWithContext::zero_copy_at(&input_data).unwrap(); + + // Create mock remaining accounts + let mut mock_accounts = vec![ + create_mock_account(mint_pubkey, false), // mint at index 0 + create_mock_account(owner_pubkey, !with_delegate), // owner at index 1, signer if no delegate + ]; + + if with_delegate { + mock_accounts.push(create_mock_account(delegate_pubkey, true)); // delegate at index 2, signer + } + + let remaining_accounts: Vec = mock_accounts; + + // Test both frozen and unfrozen states + for is_frozen in [false, true] { + // Allocate CPI bytes structure like in other tests + let config_input = CpiConfigInput { + input_accounts: 
{ + let mut arr = ArrayVec::new(); + arr.push(false); // Basic input account + arr + }, + output_accounts: ArrayVec::new(), + has_proof: false, + compressed_mint: false, + compressed_mint_with_freeze_authority: false, + }; + + let config = cpi_bytes_config(config_input); + let mut cpi_bytes = allocate_invoke_with_read_only_cpi_bytes(&config); + let (mut cpi_instruction_struct, _) = + InstructionDataInvokeCpiWithReadOnly::new_zero_copy(&mut cpi_bytes[8..], config) + .unwrap(); + + // Get the input account reference + let input_account = &mut cpi_instruction_struct.input_compressed_accounts[0]; + + let mut context = TokenContext::new(); + + // Call the function under test + let result = if is_frozen { + create_input_compressed_account::( + input_account, + &mut context, + &z_input_data, + &remaining_accounts, + lamports, + ) + } else { + create_input_compressed_account::( + input_account, + &mut context, + &z_input_data, + &remaining_accounts, + lamports, + ) + }; + + assert!(result.is_ok(), "Function failed: {:?}", result.err()); + + // Deserialize for validation using borsh pattern like other tests + let cpi_borsh = + InstructionDataInvokeCpiWithReadOnly::deserialize(&mut &cpi_bytes[8..]).unwrap(); + + // Create expected token data for validation + let expected_owner = owner_pubkey; + let expected_delegate = if with_delegate { + Some(delegate_pubkey) + } else { + None + }; + + let expected_token_data = AnchorTokenData { + mint: mint_pubkey, + owner: expected_owner, + amount, + delegate: expected_delegate, + state: if is_frozen { + anchor_compressed_token::token_data::AccountState::Frozen + } else { + anchor_compressed_token::token_data::AccountState::Initialized + }, + tlv: None, + }; + + // Calculate expected data hash + let expected_hash = expected_token_data.hash().unwrap(); + + // Build expected input account + let expected_input_account = InAccount { + discriminator: TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR, + data_hash: expected_hash, + merkle_context: 
PackedMerkleContext { + merkle_tree_pubkey_index, + queue_pubkey_index, + leaf_index, + prove_by_index, + }, + root_index, + lamports, + address: None, + }; + + let expected = InstructionDataInvokeCpiWithReadOnly { + input_compressed_accounts: vec![expected_input_account], + ..Default::default() + }; + + assert_eq!(cpi_borsh, expected); + } + } +} +*/ + +// Helper function to create mock AccountInfo +fn create_mock_account(pubkey: Pubkey, is_signer: bool) -> AccountInfo<'static> { + let lamports = Box::leak(Box::new(0u64)); + let data = Box::leak(Box::new(vec![])); + AccountInfo::new( + Box::leak(Box::new(pubkey)), + is_signer, + false, + lamports, + data, + Box::leak(Box::new(Pubkey::default())), + false, + 0, + ) +} diff --git a/programs/compressed-token/program/tests/mint.rs b/programs/compressed-token/program/tests/mint.rs new file mode 100644 index 0000000000..7c3748d03a --- /dev/null +++ b/programs/compressed-token/program/tests/mint.rs @@ -0,0 +1,217 @@ +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::{ + address::derive_address, + compressed_account::{CompressedAccount, CompressedAccountData}, + instruction_data::{ + data::OutputCompressedAccountWithPackedContext, + with_readonly::InstructionDataInvokeCpiWithReadOnly, + }, + Pubkey, +}; +use light_compressed_token::{ + constants::COMPRESSED_MINT_DISCRIMINATOR, + mint::{ + output::create_output_compressed_mint_account, + state::{CompressedMint, CompressedMintConfig}, + }, + shared::cpi_bytes_size::{ + allocate_invoke_with_read_only_cpi_bytes, cpi_bytes_config, CpiConfigInput, + }, +}; +use light_zero_copy::ZeroCopyNew; +use rand::Rng; + +#[test] +fn test_rnd_create_compressed_mint_account() { + let mut rng = rand::thread_rng(); + let iter = 100; + + for _ in 0..iter { + // Generate random mint parameters + let mint_pda = Pubkey::new_from_array(rng.gen::<[u8; 32]>()); + let decimals = rng.gen_range(0..=18u8); + let program_id = Pubkey::new_from_array(rng.gen::<[u8; 32]>()); + let 
address_merkle_tree = Pubkey::new_from_array(rng.gen::<[u8; 32]>()); + + // Random freeze authority (50% chance) + let freeze_authority = if rng.gen_bool(0.5) { + Some(Pubkey::new_from_array(rng.gen::<[u8; 32]>())) + } else { + None + }; + + let mint_authority = Some(Pubkey::new_from_array(rng.gen::<[u8; 32]>())); + + // // Create mint config - match the real usage pattern (always reserve mint_authority space) + let mint_config = CompressedMintConfig { + mint_authority: (true, ()), // Always true like in cpi_bytes_config and mint_to_compressed + freeze_authority: (freeze_authority.is_some(), ()), + }; + // Derive compressed account address + let compressed_account_address = derive_address( + &mint_pda.to_bytes(), + &address_merkle_tree.to_bytes(), + &program_id.to_bytes(), + ); + + // Create a simple test structure for just the output account + let config_input = CpiConfigInput { + input_accounts: arrayvec::ArrayVec::new(), + output_accounts: arrayvec::ArrayVec::new(), + has_proof: false, + compressed_mint: true, + compressed_mint_with_freeze_authority: freeze_authority.is_some(), + }; + + let config = cpi_bytes_config(config_input); + let mut cpi_bytes = allocate_invoke_with_read_only_cpi_bytes(&config); + let (mut cpi_instruction_struct, _) = + light_compressed_account::instruction_data::with_readonly::InstructionDataInvokeCpiWithReadOnly::new_zero_copy( + &mut cpi_bytes[8..], + config, + ) + .unwrap(); + + // Get the input and output compressed accounts + let input_account = &mut cpi_instruction_struct.input_compressed_accounts[0]; + let output_account = &mut cpi_instruction_struct.output_compressed_accounts[0]; + + // Create mock input data for the input compressed mint account test + use light_compressed_account::compressed_account::PackedMerkleContext; + use light_compressed_token::mint_to_compressed::instructions::CompressedMintInputs; + use light_compressed_token::shared::context::TokenContext; + use light_zero_copy::borsh::Deserialize; + + // Generate 
random values for more comprehensive testing + let input_supply = rng.gen_range(0..=u64::MAX); + let output_supply = rng.gen_range(0..=u64::MAX); // Random supply for output account + let is_decompressed = rng.gen_bool(0.1); // 10% chance + let num_extensions = rng.gen_range(0..=255u8); + let merkle_tree_pubkey_index = rng.gen_range(0..=255u8); + let queue_pubkey_index = rng.gen_range(0..=255u8); + let leaf_index = rng.gen::(); + let prove_by_index = rng.gen_bool(0.5); + let root_index = rng.gen::(); + let output_merkle_tree_index = rng.gen_range(0..=255u8); + + // Create mock input compressed mint data + let input_compressed_mint = CompressedMintInputs { + compressed_mint_input: + light_compressed_token::mint_to_compressed::instructions::CompressedMintInput { + spl_mint: mint_pda, + supply: input_supply, + decimals, + is_decompressed, + freeze_authority_is_set: freeze_authority.is_some(), + freeze_authority: freeze_authority.unwrap_or_default(), + num_extensions, + }, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index, + queue_pubkey_index, + leaf_index, + prove_by_index, + }, + root_index, + address: compressed_account_address, + output_merkle_tree_index, + }; + + // Serialize and get zero-copy reference + let input_data = input_compressed_mint.try_to_vec().unwrap(); + let (z_compressed_mint_inputs, _) = + CompressedMintInputs::zero_copy_at(&input_data).unwrap(); + + // Create token context and call input function + let mut context = TokenContext::new(); + let hashed_mint_authority = context.get_or_hash_pubkey(&mint_authority.unwrap().into()); + light_compressed_token::mint::input::create_input_compressed_mint_account( + input_account, + &mut context, + &z_compressed_mint_inputs, + &hashed_mint_authority, + ) + .unwrap(); + + // Call the function under test + create_output_compressed_mint_account( + output_account, + mint_pda, + decimals, + freeze_authority, + mint_authority, + output_supply.into(), // supply parameter (U64 type) + &program_id, + 
mint_config, + compressed_account_address, + output_merkle_tree_index, + ) + .unwrap(); + + // Final comparison with borsh deserialization - same pattern as token account tests + let cpi_borsh = + InstructionDataInvokeCpiWithReadOnly::deserialize(&mut &cpi_bytes[8..]).unwrap(); + + // Build expected output + let expected_compressed_mint = CompressedMint { + spl_mint: mint_pda, + supply: output_supply, + decimals, + is_decompressed: false, + mint_authority, + freeze_authority, + num_extensions: 0, + }; + + let expected_data_hash = expected_compressed_mint.hash().unwrap(); + + let expected_account = OutputCompressedAccountWithPackedContext { + compressed_account: CompressedAccount { + address: Some(compressed_account_address), + owner: program_id, + lamports: 0, + data: Some(CompressedAccountData { + data: expected_compressed_mint.try_to_vec().unwrap(), + discriminator: COMPRESSED_MINT_DISCRIMINATOR, + data_hash: expected_data_hash, + }), + }, + merkle_tree_index: output_merkle_tree_index, + }; + + // Create expected input account data that matches what the input function should produce + let expected_input_compressed_mint = CompressedMint { + spl_mint: mint_pda, + supply: input_supply, + decimals, + is_decompressed, + mint_authority, // Use the actual mint authority passed to the function + freeze_authority, + num_extensions, + }; + let expected_input_data_hash = expected_input_compressed_mint.hash().unwrap(); + + let expected_input_account = + light_compressed_account::instruction_data::with_readonly::InAccount { + discriminator: COMPRESSED_MINT_DISCRIMINATOR, + data_hash: expected_input_data_hash, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index, + queue_pubkey_index, + leaf_index, + prove_by_index, + }, + root_index, + lamports: 0, + address: Some(compressed_account_address), + }; + + let expected = InstructionDataInvokeCpiWithReadOnly { + input_compressed_accounts: vec![expected_input_account], + output_compressed_accounts: 
vec![expected_account], + ..Default::default() + }; + + assert_eq!(cpi_borsh, expected); + } +} diff --git a/programs/compressed-token/program/tests/multi_sum_check.rs b/programs/compressed-token/program/tests/multi_sum_check.rs new file mode 100644 index 0000000000..26e744d1fb --- /dev/null +++ b/programs/compressed-token/program/tests/multi_sum_check.rs @@ -0,0 +1,364 @@ +use anchor_compressed_token::ErrorCode; +use anchor_lang::AnchorSerialize; +use light_compressed_token::multi_transfer::{ + instruction_data::{Compression, MultiInputTokenDataWithContext, MultiTokenTransferOutputData}, + sum_check::sum_check_multi_mint, +}; +use light_zero_copy::borsh::Deserialize; +use std::collections::HashMap; + +type Result = std::result::Result; +// TODO: check test coverage +#[test] +fn test_multi_sum_check() { + // SUCCEED: no relay fee, compression + multi_sum_check_test(&[100, 50], &[150], None, false).unwrap(); + multi_sum_check_test(&[75, 25, 25], &[25, 25, 25, 25, 12, 13], None, false).unwrap(); + + // FAIL: no relay fee, compression + multi_sum_check_test(&[100, 50], &[150 + 1], None, false).unwrap_err(); + multi_sum_check_test(&[100, 50], &[150 - 1], None, false).unwrap_err(); + multi_sum_check_test(&[100, 50], &[], None, false).unwrap_err(); + multi_sum_check_test(&[], &[100, 50], None, false).unwrap_err(); + + // SUCCEED: empty + multi_sum_check_test(&[], &[], None, true).unwrap(); + multi_sum_check_test(&[], &[], None, false).unwrap(); + // FAIL: empty + multi_sum_check_test(&[], &[], Some(1), false).unwrap_err(); + multi_sum_check_test(&[], &[], Some(1), true).unwrap_err(); + + // SUCCEED: with compress + multi_sum_check_test(&[100], &[123], Some(23), true).unwrap(); + multi_sum_check_test(&[], &[150], Some(150), true).unwrap(); + // FAIL: compress + multi_sum_check_test(&[], &[150], Some(150 - 1), true).unwrap_err(); + multi_sum_check_test(&[], &[150], Some(150 + 1), true).unwrap_err(); + + // SUCCEED: with decompress + multi_sum_check_test(&[100, 50], &[100], 
Some(50), false).unwrap(); + multi_sum_check_test(&[100, 50], &[], Some(150), false).unwrap(); + // FAIL: decompress + multi_sum_check_test(&[100, 50], &[], Some(150 - 1), false).unwrap_err(); + multi_sum_check_test(&[100, 50], &[], Some(150 + 1), false).unwrap_err(); +} + +fn multi_sum_check_test( + input_amounts: &[u64], + output_amounts: &[u64], + compress_or_decompress_amount: Option, + is_compress: bool, +) -> Result<()> { + // Create normal types + let inputs: Vec<_> = input_amounts + .iter() + .map(|&amount| MultiInputTokenDataWithContext { + amount, + ..Default::default() + }) + .collect(); + + let outputs: Vec<_> = output_amounts + .iter() + .map(|&amount| MultiTokenTransferOutputData { + amount, + ..Default::default() + }) + .collect(); + + let compressions = compress_or_decompress_amount.map(|amount| { + vec![Compression { + amount, + is_compress, + mint: 0, // Same mint + source_or_recipient: 0, + }] + }); + + // Serialize to bytes using borsh + let input_bytes = inputs.try_to_vec().unwrap(); + let output_bytes = outputs.try_to_vec().unwrap(); + let compression_bytes = compressions.as_ref().map(|c| c.try_to_vec().unwrap()); + + // Deserialize as zero-copy + let (inputs_zc, _) = Vec::::zero_copy_at(&input_bytes).unwrap(); + let (outputs_zc, _) = Vec::::zero_copy_at(&output_bytes).unwrap(); + let compressions_zc = if let Some(ref bytes) = compression_bytes { + let (comp, _) = Vec::::zero_copy_at(bytes).unwrap(); + Some(comp) + } else { + None + }; + + // Call our sum check function + sum_check_multi_mint(&inputs_zc, &outputs_zc, compressions_zc.as_deref()) +} + +#[test] +fn test_simple_multi_mint_cases() { + // First test a simple known case + test_simple_multi_mint().unwrap(); +} + +#[test] +fn test_multi_mint_randomized() { + // Test multiple scenarios with different mint combinations + for scenario in 0..3000 { + println!("Testing scenario {}", scenario); + + // Create test case with multiple mints + let seed = scenario as u64; + 
test_randomized_scenario(seed).unwrap(); + } +} +#[test] +fn test_failing_multi_mint_cases() { + // Test specific failure cases + test_failing_cases().unwrap(); +} +fn test_simple_multi_mint() -> Result<()> { + // Simple test: mint 0: input 100, output 100; mint 1: input 200, output 200 + let inputs = vec![(0, 100), (1, 200)]; + let outputs = vec![(0, 100), (1, 200)]; + let compressions = vec![]; + + test_multi_mint_scenario(&inputs, &outputs, &compressions)?; + + // Test with compression: mint 0: input 100 + compress 50 = output 150 + let inputs = vec![(0, 100)]; + let outputs = vec![(0, 150)]; + let compressions = vec![(0, 50, true)]; + + test_multi_mint_scenario(&inputs, &outputs, &compressions)?; + + // Test with decompression: mint 0: input 200 - decompress 50 = output 150 + let inputs = vec![(0, 200)]; + let outputs = vec![(0, 150)]; + let compressions = vec![(0, 50, false)]; + + test_multi_mint_scenario(&inputs, &outputs, &compressions) +} + +fn test_randomized_scenario(seed: u64) -> Result<()> { + let mut rng_state = seed; + + // Simple LCG for deterministic randomness + let mut next_rand = || { + rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345); + rng_state + }; + + // Generate 2-4 mints + let num_mints = 2 + (next_rand() % 3) as usize; + let mint_ids: Vec = (0..num_mints as u8).collect(); + + // Track balances per mint + let mut mint_balances: HashMap = HashMap::new(); + + // Generate inputs (1-6 inputs) + let num_inputs = 1 + (next_rand() % 6) as usize; + let mut inputs = Vec::new(); + + for _ in 0..num_inputs { + let mint = mint_ids[(next_rand() % num_mints as u64) as usize]; + let amount = 100 + (next_rand() % 1000); + + inputs.push((mint, amount)); + *mint_balances.entry(mint).or_insert(0) += amount as i128; + } + + // Generate compressions (0-3 compressions) + let num_compressions = (next_rand() % 4) as usize; + let mut compressions = Vec::new(); + + for _ in 0..num_compressions { + let mint = mint_ids[(next_rand() % num_mints as 
u64) as usize]; + let amount = 50 + (next_rand() % 500); + let is_compress = (next_rand() % 2) == 0; + + compressions.push((mint, amount, is_compress)); + + if is_compress { + *mint_balances.entry(mint).or_insert(0) += amount as i128; + } else { + // Only allow decompress if the mint has sufficient balance + let current_balance = *mint_balances.entry(mint).or_insert(0); + if current_balance >= amount as i128 { + *mint_balances.entry(mint).or_insert(0) -= amount as i128; + } else { + // Convert to compress instead to avoid negative balance + compressions.last_mut().unwrap().2 = true; + *mint_balances.entry(mint).or_insert(0) += amount as i128; + } + } + } + + // Ensure all balances are non-negative (adjust decompressions if needed) + for (&mint, balance) in mint_balances.iter_mut() { + if *balance < 0 { + // Add compression to make balance positive + let needed = (-*balance) as u64; + compressions.push((mint, needed, true)); + *balance += needed as i128; + } + } + + // Generate outputs that exactly match the remaining balances + let mut outputs = Vec::new(); + for (&mint, &balance) in mint_balances.iter() { + if balance > 0 { + // Split the balance into 1-3 outputs + let num_outputs = 1 + (next_rand() % 3) as usize; + let mut remaining = balance as u64; + + for i in 0..num_outputs { + let amount = if i == num_outputs - 1 { + // Last output gets the remainder + remaining + } else if remaining <= 1 { + break; // Don't create zero-amount outputs + } else { + let max_amount = remaining / (num_outputs - i) as u64; + if max_amount == 0 { + break; + } else { + 1 + (next_rand() % max_amount.max(1)) + } + }; + + if amount > 0 && remaining >= amount { + outputs.push((mint, amount)); + remaining -= amount; + } else { + break; + } + } + + // Add any remaining amount as final output + if remaining > 0 { + outputs.push((mint, remaining)); + } + } + } + + // Debug print for first scenario only + if seed == 0 { + println!( + "Debug scenario {}: inputs={:?}, compressions={:?}, 
outputs={:?}", + seed, inputs, compressions, outputs + ); + println!("Balances: {:?}", mint_balances); + } + + // Sort inputs by mint for order validation + inputs.sort_by_key(|(mint, _)| *mint); + // Sort outputs by mint for order validation + outputs.sort_by_key(|(mint, _)| *mint); + + // Test the sum check + test_multi_mint_scenario(&inputs, &outputs, &compressions) +} + +fn test_failing_cases() -> Result<()> { + // Test case 1: Wrong output amount + let inputs = vec![(0, 100), (1, 200)]; + let outputs = vec![(0, 100), (1, 201)]; // Wrong amount + let compressions = vec![]; + + match test_multi_mint_scenario(&inputs, &outputs, &compressions) { + Err(ErrorCode::ComputeOutputSumFailed) => {} // Expected + Err(e) => panic!("Expected ComputeOutputSumFailed, got: {:?}", e), + Ok(_) => panic!("Expected ComputeOutputSumFailed, but transaction succeeded"), + } + + // Test case 2: Output for non-existent mint + let inputs = vec![(0, 100)]; + let outputs = vec![(0, 50), (1, 50)]; // Mint 1 not in inputs + let compressions = vec![]; + + match test_multi_mint_scenario(&inputs, &outputs, &compressions) { + Err(ErrorCode::ComputeOutputSumFailed) => {} // Expected + _ => panic!("Should have failed with SumCheckFailed"), + } + + // Test case 3: Too many mints (>5) + let inputs = vec![(0, 10), (1, 10), (2, 10), (3, 10), (4, 10), (5, 10)]; + let outputs = vec![(0, 10), (1, 10), (2, 10), (3, 10), (4, 10), (5, 10)]; + let compressions = vec![]; + + match test_multi_mint_scenario(&inputs, &outputs, &compressions) { + Err(ErrorCode::TooManyMints) => {} // Expected + _ => panic!("Should have failed with TooManyMints"), + } + + // Test case 4: Inputs out of order + let inputs = vec![(1, 100), (0, 200)]; // Wrong order + let outputs = vec![(0, 200), (1, 100)]; + let compressions = vec![]; + + match test_multi_mint_scenario(&inputs, &outputs, &compressions) { + Err(ErrorCode::InputsOutOfOrder) => {} // Expected + _ => panic!("Should have failed with InputsOutOfOrder"), + } + + Ok(()) +} 
+ +fn test_multi_mint_scenario( + inputs: &[(u8, u64)], // (mint, amount) + outputs: &[(u8, u64)], // (mint, amount) + compressions: &[(u8, u64, bool)], // (mint, amount, is_compress) +) -> Result<()> { + // Create input structures + let input_structs: Vec<_> = inputs + .iter() + .map(|&(mint, amount)| MultiInputTokenDataWithContext { + amount, + mint, + ..Default::default() + }) + .collect(); + + // Create output structures + let output_structs: Vec<_> = outputs + .iter() + .map(|&(mint, amount)| MultiTokenTransferOutputData { + amount, + mint, + ..Default::default() + }) + .collect(); + + // Create compression structures + let compression_structs: Vec<_> = compressions + .iter() + .map(|&(mint, amount, is_compress)| Compression { + amount, + is_compress, + mint, + source_or_recipient: 0, + }) + .collect(); + + // Serialize to bytes + let input_bytes = input_structs.try_to_vec().unwrap(); + let output_bytes = output_structs.try_to_vec().unwrap(); + let compression_bytes = if compression_structs.is_empty() { + None + } else { + Some(compression_structs.try_to_vec().unwrap()) + }; + + // Deserialize as zero-copy + let (inputs_zc, _) = Vec::::zero_copy_at(&input_bytes).unwrap(); + let (outputs_zc, _) = Vec::::zero_copy_at(&output_bytes).unwrap(); + let compressions_zc = if let Some(ref bytes) = compression_bytes { + let (comp, _) = Vec::::zero_copy_at(bytes).unwrap(); + Some(comp) + } else { + None + }; + + // Call sum check + sum_check_multi_mint(&inputs_zc, &outputs_zc, compressions_zc.as_deref()) +} diff --git a/programs/compressed-token/program/tests/outputs.rs b/programs/compressed-token/program/tests/outputs.rs new file mode 100644 index 0000000000..c1f9ac0925 --- /dev/null +++ b/programs/compressed-token/program/tests/outputs.rs @@ -0,0 +1,159 @@ +use anchor_compressed_token::token_data::TokenData as AnchorTokenData; +use arrayvec::ArrayVec; +use borsh::{BorshDeserialize, BorshSerialize}; +use light_compressed_account::{ + 
compressed_account::{CompressedAccount, CompressedAccountData}, + hash_to_bn254_field_size_be, + instruction_data::{ + data::OutputCompressedAccountWithPackedContext, + with_readonly::InstructionDataInvokeCpiWithReadOnly, + }, + Pubkey, +}; +use light_compressed_token::{ + constants::TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR, + shared::{ + context::TokenContext, + cpi_bytes_size::{ + allocate_invoke_with_read_only_cpi_bytes, cpi_bytes_config, CpiConfigInput, + }, + outputs::create_output_compressed_account, + }, +}; +use light_zero_copy::ZeroCopyNew; + +#[test] +fn test_rnd_create_output_compressed_accounts() { + use rand::Rng; + let mut rng = rand::rngs::ThreadRng::default(); + + let iter = 1000; + for _ in 0..iter { + let mint_pubkey = Pubkey::new_from_array(rng.gen::<[u8; 32]>()); + let hashed_mint = hash_to_bn254_field_size_be(mint_pubkey.to_bytes().as_slice()); + + // Random number of output accounts (0-35 max) + let num_outputs = rng.gen_range(0..=35); + + // Generate random owners and amounts + let mut owner_pubkeys = Vec::new(); + let mut amounts = Vec::new(); + let mut delegate_flags = Vec::new(); + let mut lamports_vec = Vec::new(); + let mut merkle_tree_indices = Vec::new(); + + for _ in 0..num_outputs { + owner_pubkeys.push(Pubkey::new_from_array(rng.gen::<[u8; 32]>())); + amounts.push(rng.gen_range(1..=u64::MAX)); + delegate_flags.push(rng.gen_bool(0.3)); // 30% chance of having delegate + lamports_vec.push(if rng.gen_bool(0.2) { + Some(rng.gen_range(1..=1000000)) + } else { + None + }); + merkle_tree_indices.push(rng.gen_range(0..=255u8)); + } + + // Random delegate + let delegate = if delegate_flags.iter().any(|&has_delegate| has_delegate) { + Some(Pubkey::new_from_array(rng.gen::<[u8; 32]>())) + } else { + None + }; + + let lamports = if lamports_vec.iter().any(|l| l.is_some()) { + Some(lamports_vec.clone()) + } else { + None + }; + + // Create output config + let mut outputs = ArrayVec::new(); + for &has_delegate in &delegate_flags { + 
outputs.push(has_delegate); + } + + let config_input = CpiConfigInput { + input_accounts: ArrayVec::new(), + output_accounts: outputs, + has_proof: false, + compressed_mint: false, + compressed_mint_with_freeze_authority: false, + }; + + let config = cpi_bytes_config(config_input.clone()); + let mut cpi_bytes = allocate_invoke_with_read_only_cpi_bytes(&config); + let (mut cpi_instruction_struct, _) = InstructionDataInvokeCpiWithReadOnly::new_zero_copy( + &mut cpi_bytes[8..], + config.clone(), + ) + .unwrap(); + + let mut context = TokenContext::new(); + for (index, output_account) in cpi_instruction_struct + .output_compressed_accounts + .iter_mut() + .enumerate() + { + let output_delegate = if delegate_flags[index] { + delegate + } else { + None + }; + + create_output_compressed_account( + output_account, + &mut context, + owner_pubkeys[index], + output_delegate, + amounts[index], + lamports.as_ref().and_then(|l| l[index]), + mint_pubkey, + &hashed_mint, + merkle_tree_indices[index], + ) + .unwrap(); + } + + let cpi_borsh = + InstructionDataInvokeCpiWithReadOnly::deserialize(&mut &cpi_bytes[8..]).unwrap(); + + // Build expected output + let mut expected_accounts = Vec::new(); + + for i in 0..num_outputs { + let token_delegate = if delegate_flags[i] { delegate } else { None }; + let account_lamports = lamports_vec[i].unwrap_or(0); + + let token_data = AnchorTokenData { + mint: mint_pubkey.into(), + owner: owner_pubkeys[i].into(), + amount: amounts[i], + delegate: token_delegate.map(|d| d.into()), + state: anchor_compressed_token::token_data::AccountState::Initialized, + tlv: None, + }; + let data_hash = token_data.hash().unwrap(); + + expected_accounts.push(OutputCompressedAccountWithPackedContext { + compressed_account: CompressedAccount { + address: None, + owner: light_compressed_token::ID.into(), + lamports: account_lamports, + data: Some(CompressedAccountData { + data: token_data.try_to_vec().unwrap(), + discriminator: TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR, + 
data_hash, + }), + }, + merkle_tree_index: merkle_tree_indices[i], + }); + } + + let expected = InstructionDataInvokeCpiWithReadOnly { + output_compressed_accounts: expected_accounts, + ..Default::default() + }; + assert_eq!(cpi_borsh, expected); + } +} diff --git a/programs/package.json b/programs/package.json index eff8a5590d..54e7cf2ffb 100644 --- a/programs/package.json +++ b/programs/package.json @@ -3,7 +3,7 @@ "version": "0.3.0", "license": "Apache-2.0", "scripts": { - "build": "cd system/ && cargo build-sbf && cd .. && cd account-compression/ && cargo build-sbf && cd .. && cd registry/ && cargo build-sbf && cd .. && cd compressed-token/ && cargo build-sbf && cd ..", + "build": "cd system/ && cargo build-sbf && cd .. && cd account-compression/ && cargo build-sbf && cd .. && cd registry/ && cargo build-sbf && cd .. && cd compressed-token/program && cargo build-sbf && cd ../..", "build-compressed-token-small": "cd compressed-token/ && cargo build-sbf --features cpi-without-program-ids && cd ..", "build-system": "anchor build --program-name light_system_program -- --features idl-build custom-heap", "build-compressed-token": "anchor build --program-name light_compressed_token -- --features idl-build custom-heap", diff --git a/scripts/devenv.sh b/scripts/devenv.sh index 0349400d9c..35f10067c3 100755 --- a/scripts/devenv.sh +++ b/scripts/devenv.sh @@ -87,6 +87,7 @@ export CARGO_HOME export NPM_CONFIG_PREFIX export LIGHT_PROTOCOL_TOPLEVEL export LIGHT_PROTOCOL_DEVENV +export SBF_OUT_DIR=./target/deploy # Set Redis URL if not already set export REDIS_URL="${REDIS_URL:-redis://localhost:6379}" diff --git a/sdk-libs/sdk-types/src/constants.rs b/sdk-libs/sdk-types/src/constants.rs index 455cbdfd22..add2861b98 100644 --- a/sdk-libs/sdk-types/src/constants.rs +++ b/sdk-libs/sdk-types/src/constants.rs @@ -31,3 +31,5 @@ pub const TOKEN_COMPRESSED_ACCOUNT_DISCRIMINATOR: [u8; 8] = [2, 0, 0, 0, 0, 0, 0 pub const ADDRESS_TREE_V1: [u8; 32] = 
pubkey_array!("amt1Ayt45jfbdw5YSo7iz6WZxUmnZsQTYXy82hVwyC2"); pub const ADDRESS_QUEUE_V1: [u8; 32] = pubkey_array!("aq1S9z4reTSQAdgWHGD2zDaS39sjGrAxbR31vxJ2F4F"); +pub const ACCOUNT_COMPRESSION_AUTHORITY_PDA: [u8; 32] = + pubkey_array!("HwXnGK3tPkkVY6P439H2p68AxpeuWXd5PcrAxFpbmfbA");