diff --git a/target_chains/stylus/contracts/pyth-receiver/src/error.rs b/target_chains/stylus/contracts/pyth-receiver/src/error.rs index b20cb5284d..4b68178d80 100644 --- a/target_chains/stylus/contracts/pyth-receiver/src/error.rs +++ b/target_chains/stylus/contracts/pyth-receiver/src/error.rs @@ -17,6 +17,8 @@ pub enum PythReceiverError { InsufficientFee, InvalidEmitterAddress, TooManyUpdates, + PriceFeedNotFoundWithinRange, + NoFreshUpdate, } impl core::fmt::Debug for PythReceiverError { @@ -43,6 +45,8 @@ impl From for Vec { PythReceiverError::InsufficientFee => 13, PythReceiverError::InvalidEmitterAddress => 14, PythReceiverError::TooManyUpdates => 15, + PythReceiverError::PriceFeedNotFoundWithinRange => 16, + PythReceiverError::NoFreshUpdate => 17, }] } } diff --git a/target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs b/target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs index 0d5cb57d0f..784e274f26 100644 --- a/target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs +++ b/target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs @@ -342,4 +342,14 @@ mod test { multiple_updates_diff_vaa_results()[1] ); } + + #[motsu::test] + fn test_multiple_updates_same_id_updates_latest( + pyth_contract: Contract, + wormhole_contract: Contract, + alice: Address, + ) { + pyth_wormhole_init(&pyth_contract, &wormhole_contract, &alice); + alice.fund(U256::from(200)); + } } diff --git a/target_chains/stylus/contracts/pyth-receiver/src/lib.rs b/target_chains/stylus/contracts/pyth-receiver/src/lib.rs index aab2d43a60..8fd7c270a0 100644 --- a/target_chains/stylus/contracts/pyth-receiver/src/lib.rs +++ b/target_chains/stylus/contracts/pyth-receiver/src/lib.rs @@ -15,7 +15,7 @@ mod test_data; #[cfg(test)] use mock_instant::global::MockClock; -use alloc::vec::Vec; +use alloc::{collections::BTreeMap, vec::Vec}; use stylus_sdk::{ alloy_primitives::{Address, FixedBytes, I32, I64, U16, U256, U32, U64}, call::Call, @@ -97,7 +97,6 @@ impl PythReceiver { for (i, chain_id) in data_source_emitter_chain_ids.iter().enumerate() { let emitter_address = FixedBytes::<32>::from(data_source_emitter_addresses[i]); - // Create a new data source storage slot let mut data_source = self.valid_data_sources.grow(); data_source.chain_id.set(U16::from(*chain_id)); data_source.emitter_address.set(emitter_address); @@ -178,7 +177,7 @@ impl PythReceiver { update_data: Vec>, ) -> Result<(), PythReceiverError> { for data in &update_data { - self.update_price_feeds_internal(data.clone())?; + self.update_price_feeds_internal(data.clone(), Vec::new(), 0, 0, false)?; } let total_fee = self.get_update_fee(update_data)?; @@ -193,17 +192,172 @@ impl PythReceiver { pub fn update_price_feeds_if_necessary( &mut self, - _update_data: Vec>, - _price_ids: Vec<[u8; 32]>, - _publish_times: Vec, - ) { - // dummy implementation + update_data: Vec>, + price_ids: Vec<[u8; 32]>, + publish_times: Vec, + ) -> Result<(), PythReceiverError> { + if (price_ids.len() != publish_times.len()) + || (price_ids.is_empty() && publish_times.is_empty()) + { + return Err(PythReceiverError::InvalidUpdateData); + } + + for i in 0..price_ids.len() { + if self.latest_price_info_publish_time(price_ids[i]) < publish_times[i] { + self.update_price_feeds(update_data.clone())?; + return Ok(()); + } + } + + return Err(PythReceiverError::NoFreshUpdate); + } + + fn latest_price_info_publish_time(&self, price_id: [u8; 32]) -> u64 { + let price_id_fb: FixedBytes<32> = FixedBytes::from(price_id); + let recent_price_info = self.latest_price_info.get(price_id_fb); + recent_price_info.publish_time.get().to::() } fn update_price_feeds_internal( &mut self, update_data: Vec, - ) -> Result<(), PythReceiverError> { + _price_ids: Vec<[u8; 32]>, + min_publish_time: u64, + max_publish_time: u64, + _unique: bool, + ) -> Result, PythReceiverError> { + let price_pairs = self.parse_price_feed_updates_internal( + update_data, + min_publish_time, + max_publish_time, + false, // check_uniqueness + )?; + + for (price_id, price_return) in price_pairs.clone() { + let price_id_fb: FixedBytes<32> = FixedBytes::from(price_id); + let mut recent_price_info = self.latest_price_info.setter(price_id_fb); + + if recent_price_info.publish_time.get() < price_return.0 + || recent_price_info.price.get() == I64::ZERO + { + recent_price_info.publish_time.set(price_return.0); + recent_price_info.expo.set(price_return.1); + recent_price_info.price.set(price_return.2); + recent_price_info.conf.set(price_return.3); + recent_price_info.ema_price.set(price_return.4); + recent_price_info.ema_conf.set(price_return.5); + } + } + + Ok(price_pairs) + } + + fn get_update_fee(&self, update_data: Vec>) -> Result { + let mut total_num_updates: u64 = 0; + for data in &update_data { + let update_data_array: &[u8] = &data; + let accumulator_update = AccumulatorUpdateData::try_from_slice(&update_data_array) + .map_err(|_| PythReceiverError::InvalidAccumulatorMessage)?; + match accumulator_update.proof { + Proof::WormholeMerkle { vaa: _, updates } => { + let num_updates = u64::try_from(updates.len()) + .map_err(|_| PythReceiverError::TooManyUpdates)?; + total_num_updates += num_updates; + } + } + } + Ok(self.get_total_fee(total_num_updates)) + } + + fn get_total_fee(&self, total_num_updates: u64) -> U256 { + U256::from(total_num_updates).saturating_mul(self.single_update_fee_in_wei.get()) + + self.transaction_fee_in_wei.get() + } + + pub fn get_twap_update_fee(&self, _update_data: Vec>) -> U256 { + U256::from(0u8) + } + + pub fn parse_price_feed_updates( + &mut self, + update_data: Vec, + price_ids: Vec<[u8; 32]>, + min_publish_time: u64, + max_publish_time: u64, + ) -> Result, PythReceiverError> { + let price_feeds = self.parse_price_feed_updates_with_config( + vec![update_data], + price_ids, + min_publish_time, + max_publish_time, + false, + false, + false, + ); + price_feeds + } + + pub fn parse_price_feed_updates_with_config( + &mut self, + update_data: Vec>, + price_ids: Vec<[u8; 32]>, + min_allowed_publish_time: u64, + max_allowed_publish_time: u64, + check_uniqueness: bool, + check_update_data_is_minimal: bool, + store_updates_if_fresh: bool, + ) -> Result, PythReceiverError> { + let mut all_parsed_price_pairs = Vec::new(); + for data in &update_data { + if store_updates_if_fresh { + all_parsed_price_pairs.extend(self.update_price_feeds_internal( + data.clone(), + price_ids.clone(), + min_allowed_publish_time, + max_allowed_publish_time, + check_uniqueness, + )?); + } else { + all_parsed_price_pairs.extend(self.parse_price_feed_updates_internal( + data.clone(), + min_allowed_publish_time, + max_allowed_publish_time, + check_uniqueness, + )?); + } + } + + if check_update_data_is_minimal && all_parsed_price_pairs.len() != price_ids.len() { + return Err(PythReceiverError::InvalidUpdateData); + } + + let mut result: Vec = Vec::with_capacity(price_ids.len()); + let mut price_map: BTreeMap<[u8; 32], PriceInfoReturn> = BTreeMap::new(); + + for (price_id, price_info) in all_parsed_price_pairs { + if !price_map.contains_key(&price_id) { + price_map.insert(price_id, price_info); + } + } + + for price_id in price_ids { + if let Some(price_info) = price_map.get(&price_id) { + result.push(*price_info); + } else { + return Err(PythReceiverError::PriceFeedNotFoundWithinRange); + } + } + + Ok(result) + } + + fn parse_price_feed_updates_internal( + &mut self, + update_data: Vec, + min_allowed_publish_time: u64, + max_allowed_publish_time: u64, + check_uniqueness: bool, + ) -> Result, PythReceiverError> { let update_data_array: &[u8] = &update_data; // Check the first 4 bytes of the update_data_array for the magic header if update_data_array.len() < 4 { @@ -220,6 +374,8 @@ impl PythReceiver { let accumulator_update = AccumulatorUpdateData::try_from_slice(&update_data_array) .map_err(|_| PythReceiverError::InvalidAccumulatorMessage)?; + let mut price_feeds: BTreeMap<[u8; 32], PriceInfoReturn> = BTreeMap::new(); + match accumulator_update.proof { Proof::WormholeMerkle { vaa, updates } => { let wormhole: IWormholeContract = IWormholeContract::new(self.wormhole.get()); @@ -228,10 +384,10 @@ impl PythReceiver { .parse_and_verify_vm(config, Vec::from(vaa.clone())) .map_err(|_| PythReceiverError::InvalidWormholeMessage)?; - let vaa = Vaa::read(&mut Vec::from(vaa.clone()).as_slice()) + let vaa_obj = Vaa::read(&mut Vec::from(vaa.clone()).as_slice()) .map_err(|_| PythReceiverError::VaaVerificationFailed)?; - let cur_emitter_address: &[u8; 32] = vaa + let cur_emitter_address: &[u8; 32] = vaa_obj .body .emitter_address .as_slice() @@ -239,7 +395,7 @@ impl PythReceiver { .map_err(|_| PythReceiverError::InvalidEmitterAddress)?; let cur_data_source = DataSource { - chain_id: U16::from(vaa.body.emitter_chain), + chain_id: U16::from(vaa_obj.body.emitter_chain), emitter_address: FixedBytes::from(cur_emitter_address), }; @@ -247,7 +403,7 @@ impl PythReceiver { return Err(PythReceiverError::InvalidWormholeMessage); } - let root_digest: MerkleRoot = parse_wormhole_proof(vaa)?; + let root_digest: MerkleRoot = parse_wormhole_proof(vaa_obj)?; for update in updates { let message_vec = Vec::from(update.message); @@ -262,33 +418,40 @@ impl PythReceiver { match msg { Message::PriceFeedMessage(price_feed_message) => { - let price_id_fb: FixedBytes<32> = - FixedBytes::from(price_feed_message.feed_id); - let mut recent_price_info = self.latest_price_info.setter(price_id_fb); + let publish_time = price_feed_message.publish_time; - if recent_price_info.publish_time.get() - < U64::from(price_feed_message.publish_time) - || recent_price_info.price.get() == I64::ZERO + if (min_allowed_publish_time > 0 + && publish_time < min_allowed_publish_time as i64) + || (max_allowed_publish_time > 0 + && publish_time > max_allowed_publish_time as i64) { - recent_price_info - .publish_time - .set(U64::from(price_feed_message.publish_time)); - recent_price_info.price.set(I64::from_le_bytes( - price_feed_message.price.to_le_bytes(), - )); - recent_price_info - .conf - .set(U64::from(price_feed_message.conf)); - recent_price_info.expo.set(I32::from_le_bytes( - price_feed_message.exponent.to_le_bytes(), - )); - recent_price_info.ema_price.set(I64::from_le_bytes( - price_feed_message.ema_price.to_le_bytes(), - )); - recent_price_info - .ema_conf - .set(U64::from(price_feed_message.ema_conf)); + return Err(PythReceiverError::PriceFeedNotFoundWithinRange); } + + if check_uniqueness { + let price_id_fb = + FixedBytes::<32>::from(price_feed_message.feed_id); + let prev_price_info = self.latest_price_info.get(price_id_fb); + let prev_publish_time = + prev_price_info.publish_time.get().to::(); + + if prev_publish_time > 0 + && min_allowed_publish_time <= prev_publish_time + { + return Err(PythReceiverError::PriceFeedNotFoundWithinRange); + } + } + + let price_info_return = ( + U64::from(publish_time), + I32::from_be_bytes(price_feed_message.exponent.to_be_bytes()), + I64::from_be_bytes(price_feed_message.price.to_be_bytes()), + U64::from(price_feed_message.conf), + I64::from_be_bytes(price_feed_message.ema_price.to_be_bytes()), + U64::from(price_feed_message.ema_conf), + ); + + price_feeds.insert(price_feed_message.feed_id, price_info_return); } _ => { return Err(PythReceiverError::InvalidAccumulatorMessageType); @@ -298,56 +461,7 @@ impl PythReceiver { } }; - Ok(()) - } - - fn get_update_fee(&self, update_data: Vec>) -> Result { - let mut total_num_updates: u64 = 0; - for data in &update_data { - let update_data_array: &[u8] = &data; - let accumulator_update = AccumulatorUpdateData::try_from_slice(&update_data_array) - .map_err(|_| PythReceiverError::InvalidAccumulatorMessage)?; - match accumulator_update.proof { - Proof::WormholeMerkle { vaa: _, updates } => { - let num_updates = u64::try_from(updates.len()) - .map_err(|_| PythReceiverError::TooManyUpdates)?; - total_num_updates += num_updates; - } - } - } - Ok(self.get_total_fee(total_num_updates)) - } - - fn get_total_fee(&self, total_num_updates: u64) -> U256 { - U256::from(total_num_updates).saturating_mul(self.single_update_fee_in_wei.get()) - + self.transaction_fee_in_wei.get() - } - - pub fn get_twap_update_fee(&self, _update_data: Vec>) -> U256 { - U256::from(0u8) - } - - pub fn parse_price_feed_updates( - &mut self, - _update_data: Vec>, - _price_ids: Vec<[u8; 32]>, - _min_publish_time: u64, - _max_publish_time: u64, - ) -> Vec { - Vec::new() - } - - pub fn parse_price_feed_updates_with_config( - &mut self, - _update_data: Vec>, - _price_ids: Vec<[u8; 32]>, - _min_allowed_publish_time: u64, - _max_allowed_publish_time: u64, - _check_uniqueness: bool, - _check_update_data_is_minimal: bool, - _store_updates_if_fresh: bool, - ) -> (Vec, Vec) { - (Vec::new(), Vec::new()) + Ok(price_feeds.into_iter().collect()) } pub fn parse_twap_price_feed_updates( @@ -360,12 +474,21 @@ impl PythReceiver { pub fn parse_price_feed_updates_unique( &mut self, - _update_data: Vec>, - _price_ids: Vec<[u8; 32]>, - _min_publish_time: u64, - _max_publish_time: u64, - ) -> Vec { - Vec::new() + update_data: Vec>, + price_ids: Vec<[u8; 32]>, + min_publish_time: u64, + max_publish_time: u64, + ) -> Result, PythReceiverError> { + let price_feeds = self.parse_price_feed_updates_with_config( + update_data, + price_ids, + min_publish_time, + max_publish_time, + true, + false, + false, + ); + price_feeds } fn is_no_older_than(&self, publish_time: U64, max_age: u64) -> bool { diff --git a/target_chains/stylus/contracts/wormhole/src/lib.rs b/target_chains/stylus/contracts/wormhole/src/lib.rs index 5b0aea296d..240f988fb0 100644 --- a/target_chains/stylus/contracts/wormhole/src/lib.rs +++ b/target_chains/stylus/contracts/wormhole/src/lib.rs @@ -500,7 +500,7 @@ mod tests { use core::str::FromStr; use k256::ecdsa::SigningKey; use stylus_sdk::alloy_primitives::keccak256; - + #[cfg(test)] use base64::engine::general_purpose; #[cfg(test)] @@ -543,7 +543,7 @@ mod tests { 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, ] } - + #[cfg(test)] fn current_guardians() -> Vec
{ vec![ @@ -634,7 +634,7 @@ mod tests { contract.initialize(guardians, 1, CHAIN_ID, GOVERNANCE_CHAIN_ID, governance_contract).unwrap(); contract } - + #[cfg(test)] fn deploy_with_current_mainnet_guardians() -> WormholeContract { let mut contract = WormholeContract::default(); @@ -802,7 +802,7 @@ mod tests { #[motsu::test] fn test_verification_multiple_guardian_sets() { let mut contract = deploy_with_current_mainnet_guardians(); - + let store_result = contract.store_gs(4, current_guardians(), 0); if let Err(_) = store_result { panic!("Error deploying multiple guardian sets"); @@ -816,7 +816,7 @@ mod tests { #[motsu::test] fn test_verification_incorrect_guardian_set() { let mut contract = deploy_with_current_mainnet_guardians(); - + let store_result = contract.store_gs(4, mock_guardian_set13(), 0); if let Err(_) = store_result { panic!("Error deploying guardian set"); @@ -1147,7 +1147,7 @@ mod tests { let mut contract = WormholeContract::default(); let guardians = current_guardians(); let governance_contract = Address::from_slice(&GOVERNANCE_CONTRACT.to_be_bytes::<32>()[12..32]); - + let result = contract.initialize(guardians.clone(), 4, CHAIN_ID, GOVERNANCE_CHAIN_ID, governance_contract); assert!(result.is_ok(), "Contract initialization should succeed"); } @@ -1222,5 +1222,5 @@ mod tests { assert!(result2.is_ok()); } - -} \ No newline at end of file + +}